35 #ifndef OPENCV_FLANN_LSH_INDEX_H_
36 #define OPENCV_FLANN_LSH_INDEX_H_
47 #include "result_set.h"
49 #include "lsh_table.h"
50 #include "allocator.h"
56 #pragma warning(disable: 4702)
// Parameter bundle for the LSH index. Derives from IndexParams and fills
// itself through operator[] with string keys.
62 struct LshIndexParams :
public IndexParams
// Constructor defaults: 12 tables, key size 20, multi-probe level 2.
// NOTE(review): key_size is presumably the hash key length in bits and
// multi_probe_level the number of bit flips probed around a key - confirm
// against the FLANN/OpenCV documentation.
64 LshIndexParams(
int table_number = 12,
int key_size = 20,
int multi_probe_level = 2)
// Algorithm tag identifying this parameter set as LSH.
66 (*this)[
"algorithm"] = FLANN_INDEX_LSH;
// The three tuning values are stored under the same keys that the LshIndex
// constructor later reads back with get_param().
68 (*this)[
"table_number"] = table_number;
70 (*this)[
"key_size"] = key_size;
72 (*this)[
"multi_probe_level"] = multi_probe_level;
// Index class template parameterised on a Distance policy type; derives from
// NNIndex<Distance>.
82 template<
typename Distance>
83 class LshIndex :
public NNIndex<Distance>
// Element type of the stored vectors and the type returned by the distance
// functor, both taken from the Distance policy.
86 typedef typename Distance::ElementType ElementType;
87 typedef typename Distance::ResultType DistanceType;
// Constructor.
//   input_data - matrix of feature vectors used to initialise dataset_
//   params     - index parameters; defaults to a default-constructed LshIndexParams
//   d          - distance functor; defaults to a default-constructed Distance
94 LshIndex(
const Matrix<ElementType>& input_data,
const IndexParams& params = LshIndexParams(),
95 Distance d = Distance()) :
96 dataset_(input_data), index_params_(params), distance_(d)
// Read the tuning values from the parameter set. The fallback values
// (12, 20, 2) match the defaults of LshIndexParams.
100 table_number_ = get_param(index_params_,
"table_number",12);
101 key_size_ = get_param(index_params_,
"key_size",20);
102 multi_probe_level_ = get_param(index_params_,
"multi_probe_level",2);
// The feature size is the column count of the dataset matrix.
104 feature_size_ = (unsigned)dataset_.cols;
// Populate xor_masks_, starting from key 0 with key_size_ as the initial
// lowest_index and multi_probe_level_ as the recursion depth.
105 fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_);
// Copy constructor and copy assignment are only declared here; no definition
// is visible in this listing.
// NOTE(review): presumably left undefined on purpose to forbid copying -
// confirm against the full header.
109 LshIndex(
const LshIndex&);
110 LshIndex& operator=(
const LshIndex&);
// Creates the hash tables: tables_ is resized to table_number_ entries and each
// entry is assigned a freshly constructed LshTable built from feature_size_
// and key_size_.
// NOTE(review): the enclosing function header and any statements that follow
// the assignment inside the loop are not visible in this listing.
117 tables_.resize(table_number_);
118 for (
int i = 0; i < table_number_; ++i) {
119 lsh::LshTable<ElementType>& table = tables_[i];
120 table = lsh::LshTable<ElementType>(feature_size_, key_size_);
// Returns the LSH algorithm tag, the same value LshIndexParams stores under
// "algorithm".
// NOTE(review): presumably the body of getType(), whose header is not visible
// in this listing.
129 return FLANN_INDEX_LSH;
// Writes the index state to `stream`: the three tuning values followed by the
// dataset. The load sequence reads the values back in this same order.
// NOTE(review): the enclosing function header is not visible in this listing.
135 save_value(stream,table_number_);
136 save_value(stream,key_size_);
137 save_value(stream,multi_probe_level_);
138 save_value(stream, dataset_);
// Reads the index state from `stream` in the order it was saved: the three
// tuning values, then the dataset.
// NOTE(review): the enclosing function header and the original lines between
// 146 and 150 are not visible in this listing.
143 load_value(stream, table_number_);
144 load_value(stream, key_size_);
145 load_value(stream, multi_probe_level_);
146 load_value(stream, dataset_);
// Refresh index_params_ so it reflects the values just loaded.
150 index_params_[
"algorithm"] = getType();
151 index_params_[
"table_number"] = table_number_;
152 index_params_[
"key_size"] = key_size_;
153 index_params_[
"multi_probe_level"] = multi_probe_level_;
// Bodies of four accessors; their headers are not visible in this listing.
// Returns the number of rows in the dataset matrix.
161 return dataset_.rows;
// Returns the feature size recorded by the constructor (dataset_.cols).
169 return feature_size_;
// Returns dataset_.rows * sizeof(int), cast to int.
// NOTE(review): presumably a memory-usage estimate of one int per stored
// vector - confirm.
178 return (
int)(dataset_.rows *
sizeof(int));
// Returns the parameter set held by this index.
184 return index_params_;
// K-nearest-neighbour search over a batch of queries.
//   queries - one query vector per row
//   indices - output; row i receives the neighbour indices for query i
//   dists   - output; row i receives the matching distances for query i
//   knn     - capacity of the result set and the count passed to the copy calls
//   params  - search parameters; the "sorted" flag (default true) is read here
195 virtual void knnSearch(
const Matrix<ElementType>& queries, Matrix<int>& indices, Matrix<DistanceType>& dists,
int knn,
const SearchParams& params)
CV_OVERRIDE
// A single result set is declared outside the loop and used for every query.
204 KNNUniqueResultSet<DistanceType> resultSet(knn);
205 for (
size_t i = 0; i < queries.rows; i++) {
// NOTE(review): original lines 206-208 are not visible in this listing;
// presumably they reset the result set and output rows before each query -
// confirm against the full header.
209 findNeighbors(resultSet, queries[i], params);
// Copy results into the output rows; sortAndCopy is used when "sorted" is set.
210 if (get_param(params,
"sorted",
true)) resultSet.sortAndCopy(indices[i], dists[i], knn);
211 else resultSet.copy(indices[i], dists[i], knn);
// Collects neighbours of `vec` into `result` by delegating to the
// two-argument getNeighbors overload. The SearchParams argument is unnamed
// and unused here.
225 void findNeighbors(ResultSet<DistanceType>&
result,
const ElementType* vec,
const SearchParams& )
CV_OVERRIDE
227 getNeighbors(vec,
result);
// Comparison functor over two ScoreIndexPair values.
// NOTE(review): the body of operator() is not visible in this listing; from
// the name it presumably orders pairs by their `second` member - confirm.
234 struct SortScoreIndexPairOnSecond
236 bool operator()(
const ScoreIndexPair& left,
const ScoreIndexPair& right)
const
// Recursive helper that builds XOR masks for multi-probe lookups.
//   key          - mask accumulated so far
//   lowest_index - passed down as `index` on each recursive call
//   level        - remaining recursion depth; recursion stops at 0
// NOTE(review): the final parameter (the output mask container, passed as
// `xor_masks` below), the statement that records `key`, and the loop that
// defines `index` are not visible in this listing.
248 void fill_xor_mask(lsh::BucketKey key,
int lowest_index,
unsigned int level,
// Base case: no more bits may be set.
252 if (level == 0)
return;
// Set bit `index` in the current key.
// NOTE(review): `1 << index` is evaluated on an int literal; confirm that
// `index` always stays below the bit width of int, or cast the literal to
// lsh::BucketKey before shifting.
255 lsh::BucketKey new_key = key | (1 <<
index);
// Recurse with one level less, passing `index` as the new lowest_index.
256 fill_xor_mask(new_key,
index, level - 1, xor_masks);
// Neighbour search that collects (distance, index) pairs into
// score_index_heap. Two table scans are visible: one bounded by k_nn and
// worst_score, and one bounded by radius.
// NOTE(review): the remaining parameters, the branch that selects between the
// two scans (presumably `do_k`), and the iterator and heap set-up lines are
// not visible in this listing. The second parameter is unnamed.
268 void getNeighbors(
const ElementType* vec,
bool ,
float radius,
bool do_k,
unsigned int k_nn,
// First scan: keep candidates whose distance is below worst_score.
277 for (; table != table_end; ++table) {
// Hash the query in this table.
278 size_t key = table->getKey(vec);
281 for (; xor_mask != xor_mask_end; ++xor_mask) {
// Derive a neighbouring bucket key by XOR-ing the key with the current mask.
282 size_t sub_key = key ^ (*xor_mask);
283 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
// A null bucket pointer means there is nothing to scan for this sub-key.
284 if (bucket == 0)
continue;
289 DistanceType hamming_distance;
292 for (; training_index < last_training_index; ++training_index) {
// Distance between the query and the dataset row named by the bucket entry.
293 hamming_distance = distance_(vec, dataset_[*training_index], dataset_.cols);
295 if (hamming_distance < worst_score) {
// NOTE(review): `training_index` is passed here without dereferencing, while
// the distance call above and the ResultSet overload use `*training_index`.
// This looks like it should be `*training_index` - confirm.
297 score_index_heap.
push_back(ScoreIndexPair(hamming_distance, training_index));
// Once more than k_nn candidates are held, worst_score is refreshed from the
// front element. The lines between 300 and 305 are not visible.
300 if (score_index_heap.
size() > (
unsigned int)k_nn) {
305 worst_score = score_index_heap.
front().first;
// Second scan: keep every candidate whose distance is below `radius`.
315 for (; table != table_end; ++table) {
316 size_t key = table->getKey(vec);
319 for (; xor_mask != xor_mask_end; ++xor_mask) {
320 size_t sub_key = key ^ (*xor_mask);
321 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
322 if (bucket == 0)
continue;
327 DistanceType hamming_distance;
330 for (; training_index < last_training_index; ++training_index) {
332 hamming_distance = distance_(vec, dataset_[*training_index], dataset_.cols);
// NOTE(review): same un-dereferenced `training_index` as in the first scan.
333 if (hamming_distance <
radius) score_index_heap.
push_back(ScoreIndexPair(hamming_distance, training_index));
// Neighbour search that feeds every candidate found in the probed buckets
// into `result`. This is the overload findNeighbors() delegates to.
// NOTE(review): the iterator set-up lines (table, xor_mask, training_index and
// their end markers) are not visible in this listing.
344 void getNeighbors(
const ElementType* vec, ResultSet<DistanceType>&
result)
348 for (; table != table_end; ++table) {
// Hash the query in this table.
349 size_t key = table->getKey(vec);
352 for (; xor_mask != xor_mask_end; ++xor_mask) {
// Derive a neighbouring bucket key by XOR-ing the key with the current mask.
353 size_t sub_key = key ^ (*xor_mask);
354 const lsh::Bucket* bucket = table->getBucketFromKey((lsh::BucketKey)sub_key);
// A null bucket pointer means there is nothing to scan for this sub-key.
355 if (bucket == 0)
continue;
360 DistanceType hamming_distance;
363 for (; training_index < last_training_index; ++training_index) {
// Compute the distance to the dataset row named by the bucket entry and hand
// it to the result set. No filtering happens in the visible lines.
365 hamming_distance = distance_(vec, dataset_[*training_index], (
int)dataset_.cols);
366 result.addPoint(hamming_distance, *training_index);
// Data members visible in this listing; others (for example tables_,
// xor_masks_, distance_) are used above, but their declarations are not shown.
// Feature vectors, initialised from the constructor's input_data.
376 Matrix<ElementType> dataset_;
// Column count of dataset_, recorded by the constructor.
379 unsigned int feature_size_;
// Parameter set supplied at construction and refreshed after loading.
381 IndexParams index_params_;
// Depth passed to fill_xor_mask when the XOR masks are generated.
388 int multi_probe_level_;
int index
Definition: core_c.h:634
CvSize size
Definition: core_c.h:112
const CvArr const CvArr CvArr * result
Definition: core_c.h:1423
#define CV_OVERRIDE
Definition: cvdef.h:792
#define CV_Assert(expr)
Checks a condition at runtime and throws exception if it fails.
Definition: base.hpp:342
CvPoint2D32f float * radius
Definition: imgproc_c.h:534
QTextStream & left(QTextStream &stream)
QTextStream & right(QTextStream &stream)