forked from Kitware/SMQTK
/
__init__.py
109 lines (82 loc) · 3.82 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Interface for generic element-wise nearest-neighbor computation.
"""
import abc
import os
from smqtk.algorithms import SmqtkAlgorithm
from smqtk.utils.plugin import get_plugins
__author__ = "paul.tunison@kitware.com"
class NearestNeighborsIndex (SmqtkAlgorithm):
"""
Common interface for descriptor-based nearest-neighbor computation over a
built index of descriptors.
Implementations, if they allow persistent storage of their index, should
take the necessary parameters at construction time. Persistent storage
content should be (over)written ``build_index`` is called.
"""
def __len__(self):
return self.count()
@abc.abstractmethod
def count(self):
"""
:return: Number of elements in this index.
:rtype: int
"""
@abc.abstractmethod
def build_index(self, descriptors):
"""
Build the index over the descriptor data elements.
Subsequent calls to this method should rebuild the index, not add to
it, or raise an exception to as to protect the current index.
:raises ValueError: No data available in the given iterable.
:param descriptors: Iterable of descriptor elements to build index
over.
:type descriptors:
collections.Iterable[smqtk.representation.DescriptorElement]
"""
@abc.abstractmethod
def nn(self, d, n=1):
"""
Return the nearest `N` neighbors to the given descriptor element.
:param d: Descriptor element to compute the neighbors of.
:type d: smqtk.representation.DescriptorElement
:param n: Number of nearest neighbors to find.
:type n: int
:return: Tuple of nearest N DescriptorElement instances, and a tuple of
the distance values to those neighbors.
:rtype: (tuple[smqtk.representation.DescriptorElement], tuple[float])
"""
if not d.has_vector():
raise ValueError("Query descriptor did not have a vector set!")
elif not self.count():
raise ValueError("No index currently set to query from!")
def get_nn_index_impls(reload_modules=False):
"""
Discover and return discovered ``NearestNeighborsIndex`` classes. Keys in
the returned map are the names of the discovered classes, and the paired
values are the actual class type objects.
We search for implementation classes in:
- modules next to this file this function is defined in (ones that begin
with an alphanumeric character),
- python modules listed in the environment variable ``NN_INDEX_PATH``
- This variable should contain a sequence of python module
specifications, separated by the platform specific PATH separator
character (``;`` for Windows, ``:`` for unix)
Within a module we first look for a helper variable by the name
``NN_INDEX_CLASS``, which can either be a single class object or
an iterable of class objects, to be specifically exported. If the variable
is set to None, we skip that module and do not import anything. If the
variable is not present, we look at attributes defined in that module for
classes that descend from the given base class type. If none of the above
are found, or if an exception occurs, the module is skipped.
:param reload_modules: Explicitly reload discovered modules from source.
:type reload_modules: bool
:return: Map of discovered class object of type ``NearestNeighborsIndex``
whose keys are the string names of the classes.
:rtype: dict[str, type]
"""
this_dir = os.path.abspath(os.path.dirname(__file__))
env_var = "NN_INDEX_PATH"
helper_var = "NN_INDEX_CLASS"
return get_plugins(__name__, this_dir, env_var, helper_var,
NearestNeighborsIndex, reload_modules=reload_modules)