31 #ifndef OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_
32 #define OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_
42 #include "result_set.h"
44 #include "allocator.h"
51 struct KDTreeSingleIndexParams :
public IndexParams
53 KDTreeSingleIndexParams(
int leaf_max_size = 10,
bool reorder =
true,
int dim = -1)
55 (*this)[
"algorithm"] = FLANN_INDEX_KDTREE_SINGLE;
56 (*this)[
"leaf_max_size"] = leaf_max_size;
57 (*this)[
"reorder"] = reorder;
69 template <
typename Distance>
70 class KDTreeSingleIndex :
public NNIndex<Distance>
73 typedef typename Distance::ElementType ElementType;
74 typedef typename Distance::ResultType DistanceType;
84 KDTreeSingleIndex(
const Matrix<ElementType>& inputData,
const IndexParams& params = KDTreeSingleIndexParams(),
85 Distance d = Distance() ) :
86 dataset_(inputData), index_params_(params), distance_(d)
88 size_ = dataset_.rows;
91 int dim_param = get_param(params,
"dim",-1);
92 if (dim_param>0) dim_ = dim_param;
93 leaf_max_size_ = get_param(params,
"leaf_max_size",10);
94 reorder_ = get_param(params,
"reorder",
true);
98 for (
size_t i = 0; i < size_; i++) {
103 KDTreeSingleIndex(
const KDTreeSingleIndex&);
104 KDTreeSingleIndex& operator=(
const KDTreeSingleIndex&);
111 if (reorder_)
delete[] data_.data;
119 computeBoundingBox(root_bbox_);
120 root_node_ = divideTree(0, (
int)size_, root_bbox_ );
124 data_ = cvflann::Matrix<ElementType>(
new ElementType[size_*dim_], size_, dim_);
125 for (
size_t i=0; i<size_; ++i) {
126 for (
size_t j=0; j<dim_; ++j) {
127 data_[i][j] = dataset_[vind_[i]][j];
138 return FLANN_INDEX_KDTREE_SINGLE;
144 save_value(stream, size_);
145 save_value(stream, dim_);
146 save_value(stream, root_bbox_);
147 save_value(stream, reorder_);
148 save_value(stream, leaf_max_size_);
149 save_value(stream, vind_);
151 save_value(stream, data_);
153 save_tree(stream, root_node_);
159 load_value(stream, size_);
160 load_value(stream, dim_);
161 load_value(stream, root_bbox_);
162 load_value(stream, reorder_);
163 load_value(stream, leaf_max_size_);
164 load_value(stream, vind_);
166 load_value(stream, data_);
171 load_tree(stream, root_node_);
174 index_params_[
"algorithm"] = getType();
175 index_params_[
"leaf_max_size"] = leaf_max_size_;
176 index_params_[
"reorder"] = reorder_;
201 return (
int)(pool_.usedMemory+pool_.wastedMemory+dataset_.rows*
sizeof(int));
213 void knnSearch(
const Matrix<ElementType>& queries, Matrix<int>& indices, Matrix<DistanceType>& dists,
int knn,
const SearchParams& params)
CV_OVERRIDE
221 KNNSimpleResultSet<DistanceType> resultSet(knn);
222 for (
size_t i = 0; i < queries.rows; i++) {
223 resultSet.init(indices[i], dists[i]);
224 findNeighbors(resultSet, queries[i], params);
230 return index_params_;
242 void findNeighbors(ResultSet<DistanceType>&
result,
const ElementType* vec,
const SearchParams& searchParams)
CV_OVERRIDE
244 float epsError = 1+get_param(searchParams,
"eps",0.0f);
247 DistanceType distsq = computeInitialDistances(vec, dists);
248 searchLevel(
result, vec, root_node_, distsq, dists, epsError);
268 DistanceType divlow, divhigh;
272 Node* child1, * child2;
274 typedef Node* NodePtr;
279 DistanceType low, high;
284 typedef BranchStruct<NodePtr, DistanceType> BranchSt;
285 typedef BranchSt* Branch;
290 void save_tree(FILE* stream, NodePtr tree)
292 save_value(stream, *tree);
293 if (tree->child1!=NULL) {
294 save_tree(stream, tree->child1);
296 if (tree->child2!=NULL) {
297 save_tree(stream, tree->child2);
302 void load_tree(FILE* stream, NodePtr& tree)
304 tree = pool_.allocate<Node>();
305 load_value(stream, *tree);
306 if (tree->child1!=NULL) {
307 load_tree(stream, tree->child1);
309 if (tree->child2!=NULL) {
310 load_tree(stream, tree->child2);
315 void computeBoundingBox(BoundingBox& bbox)
318 for (
size_t i=0; i<dim_; ++i) {
319 bbox[i].low = (DistanceType)dataset_[0][i];
320 bbox[i].high = (DistanceType)dataset_[0][i];
322 for (
size_t k=1;
k<dataset_.
rows; ++
k) {
323 for (
size_t i=0; i<dim_; ++i) {
324 if (dataset_[
k][i]<bbox[i].low) bbox[i].low = (DistanceType)dataset_[
k][i];
325 if (dataset_[
k][i]>bbox[i].high) bbox[i].high = (DistanceType)dataset_[
k][i];
340 NodePtr divideTree(
int left,
int right, BoundingBox& bbox)
342 NodePtr node = pool_.allocate<Node>();
345 if ( (right-left) <= leaf_max_size_) {
346 node->child1 = node->child2 = NULL;
351 for (
size_t i=0; i<dim_; ++i) {
352 bbox[i].low = (DistanceType)dataset_[vind_[left]][i];
353 bbox[i].high = (DistanceType)dataset_[vind_[left]][i];
356 for (
size_t i=0; i<dim_; ++i) {
357 if (bbox[i].low>dataset_[vind_[
k]][i]) bbox[i].low=(DistanceType)dataset_[vind_[
k]][i];
358 if (bbox[i].high<dataset_[vind_[
k]][i]) bbox[i].high=(DistanceType)dataset_[vind_[
k]][i];
366 middleSplit_(&vind_[0]+left, right-left,
idx, cutfeat, cutval, bbox);
368 node->divfeat = cutfeat;
370 BoundingBox left_bbox(bbox);
371 left_bbox[cutfeat].high = cutval;
372 node->child1 = divideTree(left, left+
idx, left_bbox);
374 BoundingBox right_bbox(bbox);
375 right_bbox[cutfeat].low = cutval;
376 node->child2 = divideTree(left+
idx, right, right_bbox);
378 node->divlow = left_bbox[cutfeat].high;
379 node->divhigh = right_bbox[cutfeat].low;
381 for (
size_t i=0; i<dim_; ++i) {
382 bbox[i].low =
std::min(left_bbox[i].low, right_bbox[i].low);
383 bbox[i].high =
std::max(left_bbox[i].high, right_bbox[i].high);
390 void computeMinMax(
int* ind,
int count,
int dim, ElementType& min_elem, ElementType& max_elem)
392 min_elem = dataset_[ind[0]][dim];
393 max_elem = dataset_[ind[0]][dim];
394 for (
int i=1; i<
count; ++i) {
395 ElementType val = dataset_[ind[i]][dim];
396 if (val<min_elem) min_elem = val;
397 if (val>max_elem) max_elem = val;
401 void middleSplit(
int* ind,
int count,
int&
index,
int& cutfeat, DistanceType& cutval,
const BoundingBox& bbox)
404 ElementType max_span = bbox[0].high-bbox[0].low;
406 cutval = (bbox[0].high+bbox[0].low)/2;
407 for (
size_t i=1; i<dim_; ++i) {
408 ElementType span = bbox[i].high-bbox[i].low;
412 cutval = (bbox[i].high+bbox[i].low)/2;
417 ElementType min_elem, max_elem;
418 computeMinMax(ind,
count, cutfeat, min_elem, max_elem);
419 cutval = (min_elem+max_elem)/2;
420 max_span = max_elem - min_elem;
424 for (
size_t i=0; i<dim_; ++i) {
426 ElementType span = bbox[i].high-bbox[i].low;
428 computeMinMax(ind,
count, i, min_elem, max_elem);
429 span = max_elem - min_elem;
433 cutval = (min_elem+max_elem)/2;
438 planeSplit(ind,
count, cutfeat, cutval, lim1, lim2);
446 void middleSplit_(
int* ind,
int count,
int&
index,
int& cutfeat, DistanceType& cutval,
const BoundingBox& bbox)
448 const float EPS=0.00001f;
449 DistanceType max_span = bbox[0].high-bbox[0].low;
450 for (
size_t i=1; i<dim_; ++i) {
451 DistanceType span = bbox[i].high-bbox[i].low;
456 DistanceType max_spread = -1;
458 for (
size_t i=0; i<dim_; ++i) {
459 DistanceType span = bbox[i].high-bbox[i].low;
460 if (span>(DistanceType)((1-EPS)*max_span)) {
461 ElementType min_elem, max_elem;
462 computeMinMax(ind,
count, (
int)i, min_elem, max_elem);
463 DistanceType spread = (DistanceType)(max_elem-min_elem);
464 if (spread>max_spread) {
471 DistanceType split_val = (bbox[cutfeat].low+bbox[cutfeat].high)/2;
472 ElementType min_elem, max_elem;
473 computeMinMax(ind,
count, cutfeat, min_elem, max_elem);
475 if (split_val<min_elem) cutval = (DistanceType)min_elem;
476 else if (split_val>max_elem) cutval = (DistanceType)max_elem;
477 else cutval = split_val;
480 planeSplit(ind,
count, cutfeat, cutval, lim1, lim2);
497 void planeSplit(
int* ind,
int count,
int cutfeat, DistanceType cutval,
int& lim1,
int& lim2)
503 while (left<=right && dataset_[ind[left]][cutfeat]<cutval) ++
left;
504 while (left<=right && dataset_[ind[right]][cutfeat]>=cutval) --
right;
505 if (left>right)
break;
514 while (left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++
left;
515 while (left<=right && dataset_[ind[right]][cutfeat]>cutval) --
right;
516 if (left>right)
break;
524 DistanceType distsq = 0.0;
526 for (
size_t i = 0; i < dim_; ++i) {
527 if (vec[i] < root_bbox_[i].low) {
528 dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].low, (
int)i);
531 if (vec[i] > root_bbox_[i].high) {
532 dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].high, (
int)i);
543 void searchLevel(ResultSet<DistanceType>& result_set,
const ElementType* vec,
const NodePtr node, DistanceType mindistsq,
547 if ((node->child1 == NULL)&&(node->child2 == NULL)) {
548 DistanceType worst_dist = result_set.worstDist();
550 for (
int i=node->left; i<node->right; ++i) {
551 DistanceType dist = distance_(vec, data_[i], dim_, worst_dist);
552 if (dist<worst_dist) {
553 result_set.addPoint(dist,vind_[i]);
557 for (
int i=node->left; i<node->right; ++i) {
558 DistanceType dist = distance_(vec, data_[vind_[i]], dim_, worst_dist);
559 if (dist<worst_dist) {
560 result_set.addPoint(dist,vind_[i]);
568 int idx = node->divfeat;
569 ElementType val = vec[
idx];
570 DistanceType diff1 = val - node->divlow;
571 DistanceType diff2 = val - node->divhigh;
575 DistanceType cut_dist;
576 if ((diff1+diff2)<0) {
577 bestChild = node->child1;
578 otherChild = node->child2;
579 cut_dist = distance_.accum_dist(val, node->divhigh,
idx);
582 bestChild = node->child2;
583 otherChild = node->child1;
584 cut_dist = distance_.accum_dist( val, node->divlow,
idx);
588 searchLevel(result_set, vec, bestChild, mindistsq, dists, epsError);
590 DistanceType
dst = dists[
idx];
591 mindistsq = mindistsq + cut_dist -
dst;
592 dists[
idx] = cut_dist;
593 if (mindistsq*epsError<=result_set.worstDist()) {
594 searchLevel(result_set, vec, otherChild, mindistsq, dists, epsError);
604 const Matrix<ElementType> dataset_;
606 IndexParams index_params_;
617 Matrix<ElementType> data_;
627 BoundingBox root_bbox_;
636 PooledAllocator pool_;
CvArr * dst
Definition: core_c.h:875
const int * idx
Definition: core_c.h:668
int index
Definition: core_c.h:634
CvSize size
Definition: core_c.h:112
int count
Definition: core_c.h:1413
const CvArr const CvArr CvArr * result
Definition: core_c.h:1423
#define CV_OVERRIDE
Definition: cvdef.h:792
#define CV_Assert(expr)
Checks a condition at runtime and throws exception if it fails.
Definition: base.hpp:342
CV_EXPORTS OutputArray int double double InputArray OutputArray int int bool double k
Definition: imgproc.hpp:2133
QTextStream & left(QTextStream &stream)
QTextStream & right(QTextStream &stream)