31 #ifndef OPENCV_FLANN_KDTREE_INDEX_H_
32 #define OPENCV_FLANN_KDTREE_INDEX_H_
41 #include "dynamic_bitset.h"
43 #include "result_set.h"
45 #include "allocator.h"
// Index parameters selecting the kd-tree index type.
// NOTE(review): extracted listing; original lines 54, 56 and 59+ (braces and
// the end of the struct) are not present here.
53 struct KDTreeIndexParams :
public IndexParams
// `trees`: number of kd-trees KDTreeIndex will build (default 4).
55 KDTreeIndexParams(
int trees = 4)
// Tag the parameter set with the kd-tree algorithm identifier.
57 (*this)[
"algorithm"] = FLANN_INDEX_KDTREE;
// Stored as an int; KDTreeIndex reads it back with get_param(..., "trees", 4).
58 (*this)[
"trees"] = trees;
// Nearest-neighbour index built from trees_ kd-trees. At each internal node the
// split dimension is drawn at random from the dimensions with the largest
// variance (see meanSplit / selectDivision).
// Distance is a functor type exposing ElementType and ResultType typedefs.
69 template <
typename Distance>
70 class KDTreeIndex :
public NNIndex<Distance>
// Type of the stored vector components, and the type distances are computed in.
73 typedef typename Distance::ElementType ElementType;
74 typedef typename Distance::ResultType DistanceType;
// Constructor.
//   inputData - dataset with one vector per row (size_ rows, veclen_ columns)
//   params    - index parameters; "trees" gives the number of trees (default 4)
//   d         - distance functor
// NOTE(review): extracted listing; original lines 87, 90, 93-95, 97-99 and
// 102-104 (braces, the loop body and any other statements) are not shown.
84 KDTreeIndex(
const Matrix<ElementType>& inputData,
const IndexParams& params = KDTreeIndexParams(),
85 Distance d = Distance() ) :
86 dataset_(inputData), index_params_(params), distance_(d)
// Number of points and dimensionality of each point.
88 size_ = dataset_.rows;
89 veclen_ = dataset_.cols;
// Read the tree count and allocate one root pointer per tree.
91 trees_ = get_param(index_params_,
"trees",4);
92 tree_roots_ =
new NodePtr[trees_];
// Per-point loop; body not shown (presumably initialises the vind_ index
// array that divideTree is later called on — TODO confirm).
96 for (
size_t i = 0; i < size_; ++i) {
// Per-dimension scratch buffers for the mean/variance computed in meanSplit.
100 mean_ =
new DistanceType[veclen_];
101 var_ =
new DistanceType[veclen_];
// Copy constructor and copy assignment are declared with no visible definition,
// presumably to make the class non-copyable. The class owns raw new[] buffers
// (tree_roots_, mean_, var_), so a shallow copy would double-delete them.
// NOTE(review): the access specifier for these declarations is not visible.
105 KDTreeIndex(
const KDTreeIndex&);
106 KDTreeIndex& operator=(
const KDTreeIndex&);
// Cleanup fragment (function header not shown; presumably the destructor):
// releases the tree-root array. The remaining statements are not in this extract.
// delete[] on NULL is already a no-op, so the check is only defensive.
113 if (tree_roots_!=NULL) {
114 delete[] tree_roots_;
// Build fragment (header not shown): constructs one tree per configured tree.
126 for (
int i = 0; i < trees_; i++) {
// Lines 129-133 are not shown. The guard suggests the index order is randomised
// differently depending on OPENCV_FLANN_USE_STD_RAND — TODO confirm in full source.
128 #ifndef OPENCV_FLANN_USE_STD_RAND
// Recursively build tree i over all size_ point indices.
134 tree_roots_[i] = divideTree(&vind_[0],
int(size_) );
// Index type identifier (header not shown; presumably getType()).
141 return FLANN_INDEX_KDTREE;
// Serialisation fragment (header not shown; presumably saveIndex(FILE*)):
// writes the tree count, then each tree in order. The load routine reads them
// back in the same order.
147 save_value(stream, trees_);
148 for (
int i=0; i<trees_; ++i) {
149 save_tree(stream, tree_roots_[i]);
// Deserialisation fragment (header not shown; presumably loadIndex(FILE*)):
// reads the tree count, replaces the root array, then loads each tree in the
// order save_tree wrote them.
// NOTE(review): tree nodes come from pool_ (see load_tree); nodes from a
// previous build/load are not released here.
157 load_value(stream, trees_);
// Drop the old root array before reallocating for the loaded tree count.
158 if (tree_roots_!=NULL) {
159 delete[] tree_roots_;
161 tree_roots_ =
new NodePtr[trees_];
162 for (
int i=0; i<trees_; ++i) {
163 load_tree(stream,tree_roots_[i]);
// Keep index_params_ consistent with the loaded index.
// Fix: "trees" is an int tree count everywhere else (KDTreeIndexParams and the
// constructor's get_param(index_params_, "trees", 4)). The previous code stored
// tree_roots_ (a NodePtr*), corrupting the parameter after a load.
166 index_params_[
"algorithm"] = getType();
167 index_params_[
"trees"] = trees_;
// Approximate memory footprint (presumably in bytes). It counts the used and
// wasted memory of the node pool plus one int per dataset row (presumably the
// vind_ index array). mean_, var_ and tree_roots_ are not included.
// NOTE(review): the int cast can overflow for very large indexes.
192 return int(pool_.usedMemory+pool_.wastedMemory+dataset_.rows*
sizeof(
int));
// Searches for neighbours of `vec`, adding candidates to `result`.
// Keys read from searchParams:
//   "checks"            - leaf-evaluation budget (default 32).
//                         FLANN_CHECKS_UNLIMITED selects exact search.
//   "eps"               - pruning slack; bounds are multiplied by 1+eps (default 0).
//   "explore_all_trees" - descend every tree regardless of the budget (default false).
204 void findNeighbors(ResultSet<DistanceType>&
result,
const ElementType* vec,
const SearchParams& searchParams)
CV_OVERRIDE
206 const int maxChecks = get_param(searchParams,
"checks", 32);
207 const float epsError = 1+get_param(searchParams,
"eps",0.0f);
208 const bool explore_all_trees = get_param(searchParams,
"explore_all_trees",
false);
// Unlimited budget: exact depth-first search (uses only the first tree).
210 if (maxChecks==FLANN_CHECKS_UNLIMITED) {
211 getExactNeighbors(
result, vec, epsError);
// Otherwise: approximate search over all trees guided by a branch heap.
214 getNeighbors(
result, vec, maxChecks, epsError, explore_all_trees);
// Returns the stored index parameters (header not shown; presumably getParameters()).
220 return index_params_;
// Fragment of the tree node definition (other fields not shown).
// A node with both children NULL is a leaf, and its divfeat holds a point index.
// Otherwise divfeat/divval are the split dimension and split value.
240 Node* child1, * child2;
// Tree nodes are allocated from pool_.
242 typedef Node* NodePtr;
// Heap entry pairing an unexplored subtree with its accumulated distance bound.
243 typedef BranchStruct<NodePtr, DistanceType> BranchSt;
244 typedef BranchSt* Branch;
// Recursively writes `tree` in pre-order: the node itself, then the child1
// subtree, then the child2 subtree (NULL children are skipped).
// The child pointer values written with the node are used by load_tree only as
// "child present" markers.
248 void save_tree(FILE* stream, NodePtr tree)
250 save_value(stream, *tree);
251 if (tree->child1!=NULL) {
252 save_tree(stream, tree->child1);
254 if (tree->child2!=NULL) {
255 save_tree(stream, tree->child2);
// Inverse of save_tree. It allocates a node from pool_ and reads its saved
// contents. For each child pointer that was non-NULL when saved, it loads that
// subtree recursively and overwrites the stale pointer with the new node.
260 void load_tree(FILE* stream, NodePtr& tree)
262 tree = pool_.allocate<Node>();
263 load_value(stream, *tree);
264 if (tree->child1!=NULL) {
265 load_tree(stream, tree->child1);
267 if (tree->child2!=NULL) {
268 load_tree(stream, tree->child2);
// Recursively builds a kd-tree over the `count` point indices starting at `ind`.
// NOTE(review): the leaf test, the second recursive call and the return
// (original lines 285-287, 290-294, 296, 300+) are not in this extract.
282 NodePtr divideTree(
int* ind,
int count)
284 NodePtr node = pool_.allocate<Node>();
// Leaf case (condition not shown): no children; divfeat stores the point index.
288 node->child1 = node->child2 = NULL;
289 node->divfeat = *ind;
// Internal node: meanSplit chooses the split and reorders ind so the first
// `idx` entries form child1's subset.
295 meanSplit(ind,
count,
idx, cutfeat, cutval);
297 node->divfeat = cutfeat;
298 node->divval = cutval;
299 node->child1 = divideTree(ind,
idx);
// Chooses a split for the `count` points indexed by ind[]. It:
//   - accumulates per-dimension means and squared deviations over the first
//     `cnt` points (cnt is set on a line not shown),
//   - picks cutfeat with selectDivision and sets cutval to that dimension's mean,
//   - partitions ind[] with planeSplit. `index` (the split position) is assigned
//     on lines not shown.
312 void meanSplit(
int* ind,
int count,
int&
index,
int& cutfeat, DistanceType& cutval)
// Reset the per-dimension accumulators.
314 memset(mean_,0,veclen_*
sizeof(DistanceType));
315 memset(var_,0,veclen_*
sizeof(DistanceType));
// Sum each dimension over the sampled points (inner loop body not shown).
321 for (
int j = 0; j < cnt; ++j) {
322 ElementType* v = dataset_[ind[j]];
323 for (
size_t k=0;
k<veclen_; ++
k) {
// Per-dimension loop, body not shown (presumably turns sums into means — TODO confirm).
327 for (
size_t k=0;
k<veclen_; ++
k) {
// Accumulate squared deviations from the mean per dimension into var_.
332 for (
int j = 0; j < cnt; ++j) {
333 ElementType* v = dataset_[ind[j]];
334 for (
size_t k=0;
k<veclen_; ++
k) {
335 DistanceType dist = v[
k] - mean_[
k];
336 var_[
k] += dist * dist;
// Split on a high-variance dimension, at its mean.
340 cutfeat = selectDivision(var_);
341 cutval = mean_[cutfeat];
// Partition ind[] by cutval; lim1/lim2 receive the partition boundaries.
344 planeSplit(ind,
count, cutfeat, cutval, lim1, lim2);
// Returns a random dimension among the (up to) RAND_DIM dimensions with the
// largest values in v (meanSplit passes the variance accumulators).
// topind is kept in descending order of v by insertion.
361 int selectDivision(DistanceType* v)
364 size_t topind[RAND_DIM];
367 for (
size_t i = 0; i < veclen_; ++i) {
// Consider dimension i if there is room, or if it beats the current smallest entry.
368 if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) {
370 if (num < RAND_DIM) {
// Move the new entry toward the front while it exceeds its predecessor.
378 while (j > 0 && v[topind[j]] > v[topind[j-1]]) {
// Pick one of the num collected dimensions at random.
385 int rnd = rand_int(num);
386 return (
int)topind[rnd];
// Partitions ind[0..count) by the value of dimension cutfeat relative to cutval,
// using two two-pointer passes. lim1/lim2 are assigned on lines not shown.
399 void planeSplit(
int* ind,
int count,
int cutfeat, DistanceType cutval,
int& lim1,
int& lim2)
// Pass 1: values < cutval collect on the left, values >= cutval on the right.
405 while (left<=right && dataset_[ind[left]][cutfeat]<cutval) ++
left;
406 while (left<=right && dataset_[ind[right]][cutfeat]>=cutval) --
right;
407 if (left>right)
break;
// Pass 2: values <= cutval on the left, values > cutval on the right.
413 while (left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++
left;
414 while (left<=right && dataset_[ind[right]][cutfeat]>cutval) --
right;
415 if (left>right)
break;
// Exact search: a depth-first search of the first tree only, pruned by the
// current worst distance (see searchLevelExact).
// NOTE(review): original lines 426-429 and 431-432, including the condition
// guarding the warning, are not in this extract.
425 void getExactNeighbors(ResultSet<DistanceType>&
result,
const ElementType* vec,
float epsError)
// Warn that additional trees are useless for exact search. The message now ends
// with '\n' so it does not run into subsequent stderr output.
430 fprintf(stderr,
"It doesn't make any sense to use more than one tree for exact search\n");
433 searchLevelExact(
result, vec, tree_roots_[0], 0.0, epsError);
// Approximate search over all trees using a heap of unexplored branches.
//   maxCheck          - leaf-evaluation budget
//   epsError          - 1+eps multiplier used when pruning branches
//   explore_all_trees - if true, descend every tree even after the budget is met
449 DynamicBitset checked(size_);
// checked (above) has one bit per dataset point, so a point reachable from
// several trees is evaluated only once.
443 void getNeighbors(ResultSet<DistanceType>&
result,
const ElementType* vec,
444 int maxCheck,
float epsError,
bool explore_all_trees =
false)
// Descends from `node` toward the leaf on vec's side of each split. At each
// internal node it queues the untaken child on `heap` if that child could still
// improve the result.
//   mindist    - distance bound accumulated on the way to this node
//   checkCount - running count of leaf evaluations, compared against maxCheck
//   checked    - points already evaluated
477 void searchLevel(ResultSet<DistanceType>& result_set,
const ElementType* vec, NodePtr node, DistanceType mindist,
int& checkCount,
int maxCheck,
478 float epsError,
const cv::Ptr<Heap<BranchSt>>& heap, DynamicBitset& checked,
bool explore_all_trees =
false)
// Subtree cannot beat the current worst result (body not shown; presumably returns).
480 if (result_set.worstDist()<mindist) {
// Leaf: evaluate its point unless already checked or the budget is exhausted.
486 if ((node->child1 == NULL)&&(node->child2 == NULL)) {
491 int index = node->divfeat;
492 if ( checked.test(
index) ||
493 (!explore_all_trees && (checkCount>=maxCheck) && result_set.full()) ) {
499 DistanceType dist = distance_(dataset_[
index], vec, veclen_);
500 result_set.addPoint(dist,
index);
// Internal node: take the child on vec's side of the split first.
506 ElementType val = vec[node->divfeat];
507 DistanceType diff = val - node->divval;
508 NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
509 NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
// Bound for the other side: mindist plus the split plane's contribution.
519 DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
// Queue the other child if, after eps scaling, it could still improve the result
// (or the result set is not yet full).
521 if ((new_distsq*epsError < result_set.worstDist())|| !result_set.full()) {
522 heap->insert( BranchSt(otherChild, new_distsq) );
// NOTE(review): explore_all_trees is not forwarded here. Below the first level
// the default (false) applies, so the leaf budget check is active again.
// Confirm whether that is intended.
526 searchLevel(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked);
// Exact recursive search. It visits the child on vec's side of the split first,
// then the other child only if its bound (scaled by epsError) does not exceed
// the current worst distance. There is no visited bitset; only one tree is searched.
532 void searchLevelExact(ResultSet<DistanceType>& result_set,
const ElementType* vec,
const NodePtr node, DistanceType mindist,
const float epsError)
// Leaf: evaluate its point.
535 if ((node->child1 == NULL)&&(node->child2 == NULL)) {
536 int index = node->divfeat;
537 DistanceType dist = distance_(dataset_[
index], vec, veclen_);
538 result_set.addPoint(dist,
index);
// Internal node: order the children by which side of the split vec lies on.
543 ElementType val = vec[node->divfeat];
544 DistanceType diff = val - node->divval;
545 NodePtr bestChild = (diff < 0) ? node->child1 : node->child2;
546 NodePtr otherChild = (diff < 0) ? node->child2 : node->child1;
// Bound for the other side: mindist plus the split plane's contribution.
556 DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
559 searchLevelExact(result_set, vec, bestChild, mindist, epsError);
// Only explore the far side if it can still contain a better point.
561 if (new_distsq*epsError<=result_set.worstDist()) {
562 searchLevelExact(result_set, vec, otherChild, new_distsq, epsError);
// Dataset the index is built over (one vector per row).
601 const Matrix<ElementType> dataset_;
// Parameters the index was created with, updated when an index is loaded.
603 IndexParams index_params_;
// Array of trees_ root pointers (allocated with new[], released with delete[]).
616 NodePtr* tree_roots_;
// Allocator all tree nodes come from (divideTree, load_tree).
625 PooledAllocator pool_;
CV_EXPORTS_W void randShuffle(InputOutputArray dst, double iterFactor=1., RNG *rng=0)
Shuffles the array elements randomly.
const int * idx
Definition: core_c.h:668
int index
Definition: core_c.h:634
CvSize size
Definition: core_c.h:112
int count
Definition: core_c.h:1413
const CvArr const CvArr CvArr * result
Definition: core_c.h:1423
#define CV_OVERRIDE
Definition: cvdef.h:792
#define CV_Assert(expr)
Checks a condition at runtime and throws exception if it fails.
Definition: base.hpp:342
CV_EXPORTS OutputArray int double double InputArray OutputArray int int bool double k
Definition: imgproc.hpp:2133
CV_EXPORTS int getThreadID()
QTextStream & left(QTextStream &stream)
QTextStream & right(QTextStream &stream)
T random_shuffle(T... args)
Definition: cvstd_wrapper.hpp:74