Visual Computing Library
Loading...
Searching...
No Matches
kd_tree.h
1/*****************************************************************************
2 * VCLib *
3 * Visual Computing Library *
4 * *
5 * Copyright(C) 2021-2025 *
6 * Visual Computing Lab *
7 * ISTI - Italian National Research Council *
8 * *
9 * All rights reserved. *
10 * *
11 * This program is free software; you can redistribute it and/or modify *
12 * it under the terms of the Mozilla Public License Version 2.0 as published *
13 * by the Mozilla Foundation; either version 2 of the License, or *
14 * (at your option) any later version. *
15 * *
16 * This program is distributed in the hope that it will be useful, *
17 * but WITHOUT ANY WARRANTY; without even the implied warranty of *
18 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
19 * Mozilla Public License Version 2.0 *
20 * (https://www.mozilla.org/en-US/MPL/2.0/) for more details. *
21 ****************************************************************************/
22
23#ifndef VCL_SPACE_COMPLEX_KD_TREE_H
24#define VCL_SPACE_COMPLEX_KD_TREE_H
25
#include <vclib/concepts/mesh.h>
#include <vclib/space/core/box.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>
#include <queue>
#include <type_traits>
#include <utility>
#include <vector>
32
33namespace vcl {
34
35template<PointConcept PointType>
36class KDTree
37{
38 using Scalar = PointType::ScalarType;
39
40 struct Node
41 {
42 union
43 {
44 // standard node
45 struct
46 {
47 Scalar splitValue;
48 uint firstChildId : 24;
49 uint dim : 2;
50 uint leaf : 1;
51 };
52
53 // leaf
54 struct
55 {
56 uint start;
57 unsigned short size;
58 };
59 };
60 };
61
62 struct QueryNode
63 {
64 QueryNode() {}
65
66 QueryNode(uint id) : nodeId(id) {}
67
68 uint nodeId; // id of the next node
69 Scalar sq; // squared distance to the next node
70 };
71
72 // dummy values
73 inline static Scalar dummyScalar;
74 inline static std::vector<Scalar> dummyScalars;
75
76 std::vector<PointType> mPoints;
77 std::vector<uint> mIndices;
78 std::vector<Node> mNodes;
79
80 uint mPointsPerCell = 16; // min number of point in a leaf
81 uint mMaxDepth = 64; // max tree depth
82 uint mDepth = 0; // actual tree depth
83
84public:
85 KDTree() {}
86
87 KDTree(
88 const std::vector<PointType>& points,
89 uint pointsPerCell = 16,
90 uint maxDepth = 64,
91 bool balanced = false) :
92 mPoints(points), mIndices(points.size()),
93 mPointsPerCell(pointsPerCell), mMaxDepth(maxDepth)
94 {
95 std::iota(std::begin(mIndices), std::end(mIndices), 0);
96 mNodes.resize(1);
97 mNodes.back().leaf = 0;
98
99 mDepth = createTree(0, 0, points.size(), 1, balanced);
100 }
101
114 template<MeshConcept MeshType>
116 const MeshType& m,
117 uint pointsPerCell = 16,
118 uint maxDepth = 64,
119 bool balanced = false)
120 requires (std::is_same_v<
121 typename MeshType::VertexType::CoordType,
122 PointType>)
123 :
124 mPoints(m.vertexNumber()), mIndices(m.vertexNumber()),
125 mPointsPerCell(pointsPerCell), mMaxDepth(maxDepth)
126 {
127 using VertexType = MeshType::VertexType;
128
129 uint i = 0;
130 for (const VertexType& v : m.vertices()) {
131 mPoints[i] = v.coord();
132 mIndices[i] = m.index(v);
133 i++;
134 }
135 mNodes.resize(1);
136 mNodes.back().leaf = 0;
137
138 mDepth = createTree(0, 0, mPoints.size(), 1, balanced);
139 }
140
148 const PointType& queryPoint,
149 Scalar& dist = dummyScalar) const
150 {
151 std::vector<QueryNode> mNodeStack(mDepth + 1);
152 mNodeStack[0].nodeId = 0;
153 mNodeStack[0].sq = 0.;
154 unsigned int count = 1;
155
156 int minIndex = mIndices.size() / 2;
157 Scalar minDist = queryPoint.squaredDist(mPoints[minIndex]);
158 minIndex = mIndices[minIndex];
159
160 while (count) {
162 const Node& node = mNodes[qnode.nodeId];
163
164 if (qnode.sq < minDist) {
165 if (node.leaf) {
166 --count; // pop
167 uint end = node.start + node.size;
168 for (uint i = node.start; i < end; ++i) {
169 Scalar pointSquareDist =
170 queryPoint.squaredDist(mPoints[i]);
171 if (pointSquareDist < minDist) {
172 minDist = pointSquareDist;
173 minIndex = mIndices[i];
174 }
175 }
176 }
177 else {
178 // replace the stack top by the farthest and push the
179 // closest
180 Scalar new_off = queryPoint[node.dim] - node.splitValue;
181 if (new_off < 0.) {
182 mNodeStack[count].nodeId = node.firstChildId;
183 qnode.nodeId = node.firstChildId + 1;
184 }
185 else {
186 mNodeStack[count].nodeId = node.firstChildId + 1;
187 qnode.nodeId = node.firstChildId;
188 }
189 mNodeStack[count].sq = qnode.sq;
190 qnode.sq = new_off * new_off;
191 ++count;
192 }
193 }
194 else {
195 // pop
196 --count;
197 }
198 }
199 dist = std::sqrt(minDist);
200 return minIndex;
201 }
202
203 PointType nearestNeighbor(
204 const PointType& queryPoint,
205 Scalar& dist = dummyScalar) const
206 {
207 return mPoints[nearestNeighborIndex(queryPoint, dist)];
208 }
209
227 std::vector<uint> kNearestNeighborsIndices(
228 const PointType& queryPoint,
229 uint k,
230 std::vector<Scalar>& distances = dummyScalars) const
231 {
232 struct P
233 {
234 P(const PointType& p, int i) : p(p), i(i) {}
235
236 PointType p;
237 uint i;
238 };
239
240 struct Comparator
241 {
242 const PointType qp;
243
244 Comparator(const PointType& qp) : qp(qp) {}
245
246 bool operator()(const P& p1, const P& p2)
247 {
248 if (qp.squaredDist(p1.p) > qp.squaredDist(p2.p))
249 return true;
250 return false;
251 }
252 };
253
255
256 std::priority_queue<P, std::vector<P>, Comparator> neighborQueue(cmp);
257
258 std::vector<QueryNode> mNodeStack(mDepth + 1);
259 mNodeStack[0].nodeId = 0;
260 mNodeStack[0].sq = 0.;
261 unsigned int count = 1;
262
263 while (count) {
264 // we select the last node (AABB) inserted in the stack
266
267 // while going down the tree qnode.nodeId is the nearest sub-tree,
268 // otherwise, in backtracking, qnode.nodeId is the other sub-tree
269 // that will be visited iff the actual nearest node is further than
270 // the split distance.
271 const Node& node = mNodes[qnode.nodeId];
272
273 // if the distance is less than the top of the max-heap, it could be
274 // one of the k-nearest neighbours
275 if (neighborQueue.size() < k ||
276 qnode.sq < queryPoint.squaredDist(neighborQueue.top().p)) {
277 // when we arrive to a leaf
278 if (node.leaf) {
279 --count; // pop of the leaf
280
281 // end is the index of the last element of the leaf in
282 // mPoints
283 unsigned int end = node.start + node.size;
284 // adding the element of the leaf to the heap
285 for (unsigned int i = node.start; i < end; ++i)
286 neighborQueue.push(P(mPoints[i], mIndices[i]));
287 }
288 // otherwise, if we're not on a leaf
289 else {
290 // the new offset is the distance between the searched point
291 // and the actual split coordinate
292 Scalar new_off = queryPoint[node.dim] - node.splitValue;
293
294 // left sub-tree
295 if (new_off < 0.) {
296 mNodeStack[count].nodeId = node.firstChildId;
297 // in the father's nodeId we save the index of the other
298 // sub-tree (for backtracking)
299 qnode.nodeId = node.firstChildId + 1;
300 }
301 // right sub-tree (same as above)
302 else {
303 mNodeStack[count].nodeId = node.firstChildId + 1;
304 qnode.nodeId = node.firstChildId;
305 }
306 // distance is inherited from the father (while descending
307 // the tree it's equal to 0)
308 mNodeStack[count].sq = qnode.sq;
309 // distance of the father is the squared distance from the
310 // split plane
311 qnode.sq = new_off * new_off;
312 ++count;
313 }
314 }
315 else {
316 // pop
317 --count;
318 }
319 }
320 distances.clear();
321 std::vector<uint> res;
322
323 uint i = 0;
324 while (!neighborQueue.empty() && i < k) {
325 res.push_back(neighborQueue.top().i);
326 distances.push_back(queryPoint.dist(neighborQueue.top().p));
327 neighborQueue.pop();
328 i++;
329 }
330 return res;
331 }
332
333 std::vector<PointType> kNearestNeighbors(
334 const PointType& queryPoint,
335 uint k,
336 std::vector<Scalar>& distances = dummyScalars) const
337 {
338 std::vector<uint> dists =
340 std::vector<PointType> res;
341 res.reserve(dists.size());
342 for (uint k : dists) {
343 res.push_back(mPoints[k]);
344 }
345 return res;
346 }
347
355 std::vector<uint> neighborsIndicesInDistance(
356 const PointType& queryPoint,
357 Scalar dist,
358 std::vector<Scalar>& distances = dummyScalars) const
359 {
360 std::vector<uint> queryPoints;
361 distances.clear();
362 std::vector<QueryNode> mNodeStack(mDepth + 1);
363 mNodeStack[0].nodeId = 0;
364 mNodeStack[0].sq = 0.;
365 unsigned int count = 1;
366
367 Scalar squareDist = dist * dist;
368 while (count) {
370 const Node& node = mNodes[qnode.nodeId];
371
372 if (qnode.sq < squareDist) {
373 if (node.leaf) {
374 --count; // pop
375 unsigned int end = node.start + node.size;
376 for (unsigned int i = node.start; i < end; ++i) {
377 Scalar pointSquareDist =
378 queryPoint.squareDist(mPoints[i]);
380 queryPoints.push_back(mIndices[i]);
381 distances.push_back(queryPoint.dist(mPoints[i]));
382 }
383 }
384 }
385 else {
386 // replace the stack top by the farthest and push the
387 // closest
388 Scalar new_off = queryPoint[node.dim] - node.splitValue;
389 if (new_off < 0.) {
390 mNodeStack[count].nodeId = node.firstChildId;
391 qnode.nodeId = node.firstChildId + 1;
392 }
393 else {
394 mNodeStack[count].nodeId = node.firstChildId + 1;
395 qnode.nodeId = node.firstChildId;
396 }
397 mNodeStack[count].sq = qnode.sq;
398 qnode.sq = new_off * new_off;
399 ++count;
400 }
401 }
402 else {
403 // pop
404 --count;
405 }
406 }
407 return queryPoints;
408 }
409
410 std::vector<PointType> neighborsInDistance(
411 const PointType& queryPoint,
412 Scalar dist,
413 std::vector<Scalar>& distances = dummyScalars) const
414 {
415 std::vector<uint> dists =
417 std::vector<PointType> res;
418 res.reserve(dists.size());
419 for (uint k : dists) {
420 res.push_back(mPoints[k]);
421 }
422 return res;
423 }
424
425private:
448 uint nodeId,
449 uint start,
450 uint end,
451 uint level,
452 bool balanced)
453 {
454 // select the first node
455 Node& node = mNodes[nodeId];
457
458 // putting all the points in the bounding box
459 aabb.add(mPoints[start]);
460 for (uint i = start + 1; i < end; ++i)
461 aabb.add(mPoints[i]);
462
463 // bounding box diagonal
464 PointType diag = aabb.max() - aabb.min();
465
466 // the split "dim" is the dimension of the box with the biggest value
467 uint dim = 0;
468 Scalar max = std::numeric_limits<double>::lowest();
469 for (uint i = 0; i < PointType::DIM; ++i) {
470 if (diag[i] > max) {
471 max = diag[i];
472 dim = i;
473 }
474 }
475
476 node.dim = dim;
477 // we divide the points using the median value along the "dim" dimension
478 if (balanced) {
479 std::vector<Scalar> tempVector;
480 for (uint i = start + 1; i < end; ++i)
481 tempVector.push_back(mPoints[i][dim]);
482 std::sort(tempVector.begin(), tempVector.end());
483 node.splitValue = (tempVector[tempVector.size() / 2.0] +
484 tempVector[tempVector.size() / 2.0 + 1]) /
485 2.0;
486 }
487 // we divide the bounding box in 2 partitions, considering the average
488 // of the "dim" dimension
489 else {
490 node.splitValue = Scalar(0.5 * (aabb.max()[dim] + aabb.min()[dim]));
491 }
492
493 // midId is the index of the first element in the second partition
494 unsigned int midId = split(start, end, dim, node.splitValue);
495
496 node.firstChildId = mNodes.size();
497 mNodes.resize(mNodes.size() + 2);
498 bool flag = (midId == start) || (midId == end);
500
501 // left child
502 unsigned int childId = mNodes[nodeId].firstChildId;
503 Node& childL = mNodes[childId];
504 if (flag || (midId - start) <= mPointsPerCell || level >= mMaxDepth) {
505 childL.leaf = 1;
506 childL.start = start;
507 childL.size = midId - start;
509 }
510 else {
511 childL.leaf = 0;
513 }
514
515 // right child
516 childId = mNodes[nodeId].firstChildId + 1;
517 Node& childR = mNodes[childId];
518 if (flag || (end - midId) <= mPointsPerCell || level >= mMaxDepth) {
519 childR.leaf = 1;
520 childR.start = midId;
521 childR.size = end - midId;
523 }
524 else {
525 childR.leaf = 0;
527 }
528
529 if (leftLevel > rightLevel)
530 return leftLevel;
531 return rightLevel;
532 }
533
540 uint split(uint start, uint end, uint dim, Scalar splitValue)
541 {
542 uint l, r;
543 for (l = start, r = end - 1; l < r; ++l, --r) {
544 while (l < end && mPoints[l][dim] < splitValue)
545 l++;
546 while (r >= start && mPoints[r][dim] >= splitValue)
547 r--;
548 if (l > r)
549 break;
550 std::swap(mPoints[l], mPoints[r]);
551 std::swap(mIndices[l], mIndices[r]);
552 }
553
554 // returns the index of the first element on the second part
555 return (mPoints[l][dim] < splitValue ? l + 1 : l);
556 }
557};
558
559/* Deduction guides */
560
561template<MeshConcept MeshType>
562KDTree(const MeshType& m) -> KDTree<typename MeshType::VertexType::CoordType>;
563
564template<MeshConcept MeshType>
565KDTree(const MeshType& m, uint pointsPerCell)
566 -> KDTree<typename MeshType::VertexType::CoordType>;
567
568template<MeshConcept MeshType>
569KDTree(const MeshType& m, uint pointsPerCell, uint maxDepth)
570 -> KDTree<typename MeshType::VertexType::CoordType>;
571
572template<MeshConcept MeshType>
573KDTree(const MeshType& m, uint pointsPerCell, uint maxDepth, bool balanced)
574 -> KDTree<typename MeshType::VertexType::CoordType>;
575
576} // namespace vcl
577
578#endif // VCL_SPACE_COMPLEX_KD_TREE_H
Definition kd_tree.h:37
std::vector< uint > neighborsIndicesInDistance(const PointType &queryPoint, Scalar dist, std::vector< Scalar > &distances=dummyScalars) const
Performs the distance query.
Definition kd_tree.h:355
uint nearestNeighborIndex(const PointType &queryPoint, Scalar &dist=dummyScalar) const
Searches for the closest point.
Definition kd_tree.h:147
uint createTree(uint nodeId, uint start, uint end, uint level, bool balanced)
Recursively builds the kd-tree.
Definition kd_tree.h:447
uint split(uint start, uint end, uint dim, Scalar splitValue)
Splits the subarray between start and end into two parts, one with the elements less than splitValue,...
Definition kd_tree.h:540
KDTree(const MeshType &m, uint pointsPerCell=16, uint maxDepth=64, bool balanced=false)
Builds the KDTree starting from the given mesh.
Definition kd_tree.h:115
std::vector< uint > kNearestNeighborsIndices(const PointType &queryPoint, uint k, std::vector< Scalar > &distances=dummyScalars) const
Performs the k nearest neighbour query.
Definition kd_tree.h:227
A class representing a line segment in n-dimensional space. The class is parameterized by a PointConc...
Definition segment.h:43
constexpr auto max(const T &p1, const T &p2)
Returns the maximum between the two parameters.
Definition min_max.h:83
Definition kd_tree.h:41
Definition kd_tree.h:63