def test_from_config(self):
    """
    Check that ``from_config`` carries the ``file_cache`` config value
    through to the instance's ``_file_cache`` attribute.
    """
    # No cache file configured.
    uncached = MemoryCodeIndex.from_config({'file_cache': None})
    ntools.assert_is_none(uncached._file_cache)

    # A configured path string should be stored unchanged.
    cache_path = '/doesnt/exist/yet'
    cached = MemoryCodeIndex.from_config({'file_cache': cache_path})
    ntools.assert_equal(cached._file_cache, cache_path)
def test_add_descriptor(self):
    """
    Check that ``add_descriptor`` stores each descriptor in the internal
    table, keyed first by its code and then by its UUID.
    """
    index = MemoryCodeIndex()
    for code in (0, 5213):
        descr = random_descriptor()
        index.add_descriptor(code, descr)
        ntools.assert_equal(index._table[code][descr.uuid()], descr)
def test_count(self):
    """
    Check that ``count`` starts at zero and goes up by one for every
    descriptor added.
    """
    index = MemoryCodeIndex()
    ntools.assert_equal(index.count(), 0)

    # Add one descriptor under each code, checking the running total.
    for expected_total, code in enumerate((0, 1), 1):
        index.add_descriptor(code, random_descriptor())
        ntools.assert_equal(index.count(), expected_total)
def test_init_with_cache(self):
    """
    Check that constructing an index with a path to a pickled table loads
    that table and reports the matching descriptor count.
    """
    # Three codes holding four entries in total.
    expected_table = {
        0: {0: 'foo'},
        5: {0: 'bar'},
        2354: {0: 'baz', 1: "fak"},
    }

    handle, cache_path = tempfile.mkstemp()
    os.close(handle)
    try:
        with open(cache_path, 'w') as cache_file:
            cPickle.dump(expected_table, cache_file)

        loaded = MemoryCodeIndex(cache_path)
        ntools.assert_equal(loaded._num_descr, 4)
        ntools.assert_equal(loaded._file_cache, cache_path)
        ntools.assert_equal(loaded._table, expected_table)
    finally:
        # Always remove the temporary file, pass or fail.
        os.remove(cache_path)
def test_get_descriptors(self):
    """
    Check ``get_descriptors`` for a code holding one descriptor, a code
    holding two, and a query made with a list of codes.
    """
    # Codes 0 appears twice, so it maps to two descriptors.
    codes = (0, 1, 3, 0, 8)
    code_descrs = [(code, random_descriptor()) for code in codes]
    descrs = [pair[1] for pair in code_descrs]

    index = MemoryCodeIndex()
    index.add_many_descriptors(code_descrs)

    # single descriptor reference
    found = list(index.get_descriptors(1))
    ntools.assert_equal(len(found), 1)
    ntools.assert_equal(found[0], descrs[1])

    # multiple descriptor reference
    found = list(index.get_descriptors(0))
    ntools.assert_equal(len(found), 2)
    ntools.assert_equal(set(found), {descrs[0], descrs[3]})

    # multiple code query
    found = list(index.get_descriptors([0, 3, 8]))
    ntools.assert_equal(len(found), 4)
    ntools.assert_equal(
        set(found),
        {descrs[0], descrs[2], descrs[3], descrs[4]}
    )
def test_init_nonexistant_file_cache(self):
    """
    Check that constructing an index with a cache path that has no file
    behind it yields an empty index that still remembers the path.
    """
    # Reserve a unique path, then delete the file so only the name remains.
    handle, missing_path = tempfile.mkstemp()
    os.close(handle)
    os.remove(missing_path)

    index = MemoryCodeIndex(missing_path)
    ntools.assert_equal(index._num_descr, 0)
    ntools.assert_equal(index._file_cache, missing_path)
    ntools.assert_equal(index._table, {})
def test_add_many(self):
    """
    Check that ``add_many_descriptors`` stores every (code, descriptor)
    pair it is given in the internal table.
    """
    code_descrs = [
        (0, random_descriptor()),
        (1, random_descriptor()),
        (3, random_descriptor()),
        (0, random_descriptor()),
        (8, random_descriptor()),
    ]
    index = MemoryCodeIndex()
    index.add_many_descriptors(code_descrs)

    # Compare code keys of input to code keys in internal table
    ntools.assert_equal(set(index._table),
                        {code for code, _ in code_descrs})

    # Get the set of descriptors in the internal table and compare it with
    # the set of generated random descriptors. A plain loop is used because
    # ``set.update`` is called purely for its side effect.
    stored = set()
    for bucket in index._table.values():
        stored.update(bucket.values())
    ntools.assert_equal({descr for _, descr in code_descrs}, stored)
def _make_inst(self, dist_method, bits=8):
    """
    Build an ``ITQNearestNeighborsIndex`` for a test case.

    :param dist_method: Passed through as the index's ``distance_method``.
    :param bits: Passed through as the index's ``bit_length``.

    :return: Newly constructed ``ITQNearestNeighborsIndex`` instance.
    """
    self._make_cache_files()
    # don't want the files to actually exist
    self._clean_cache_files()

    # A new MemoryCodeIndex is passed explicitly on every call; relying on
    # the constructor default would keep the same code index between
    # constructions.
    build_kwargs = dict(
        code_index=MemoryCodeIndex(),
        bit_length=bits,
        distance_method=dist_method,
        random_seed=self.RANDOM_SEED,
    )
    return ITQNearestNeighborsIndex(self.ITQ_MEAN_VEC,
                                    self.ITQ_ROTATION_MAT,
                                    **build_kwargs)
def test_get_descriptors(self):
    """
    Exercise ``get_descriptors`` with a single-entry code, a two-entry
    code, and a list of several codes.
    """
    d0, d1, d2, d3, d4 = [random_descriptor() for _ in range(5)]
    # Descriptors d0 and d3 share code 0.
    code_descrs = [(0, d0), (1, d1), (3, d2), (0, d3), (8, d4)]

    index = MemoryCodeIndex()
    index.add_many_descriptors(code_descrs)

    # single descriptor reference
    result = list(index.get_descriptors(1))
    ntools.assert_equal(len(result), 1)
    ntools.assert_equal(result[0], d1)

    # multiple descriptor reference
    result = list(index.get_descriptors(0))
    ntools.assert_equal(len(result), 2)
    ntools.assert_equal(set(result), {d0, d3})

    # multiple code query
    result = list(index.get_descriptors([0, 3, 8]))
    ntools.assert_equal(len(result), 4)
    ntools.assert_equal(set(result), {d0, d2, d3, d4})
def test_add_many(self):
    """
    Check that every (code, descriptor) pair handed to
    ``add_many_descriptors`` ends up in the internal table.
    """
    code_descrs = [
        (0, random_descriptor()),
        (1, random_descriptor()),
        (3, random_descriptor()),
        (0, random_descriptor()),
        (8, random_descriptor()),
    ]
    index = MemoryCodeIndex()
    index.add_many_descriptors(code_descrs)

    # Compare code keys of input to code keys in internal table
    expected_codes = {pair[0] for pair in code_descrs}
    ntools.assert_equal(set(index._table), expected_codes)

    # Get the set of descriptors in the internal table and compare it with
    # the set of generated random descriptors. An explicit loop replaces the
    # former list comprehension, which built a list only to discard it.
    table_descrs = set()
    for per_code in index._table.values():
        table_descrs.update(per_code.values())
    expected_descrs = {pair[1] for pair in code_descrs}
    ntools.assert_equal(expected_descrs, table_descrs)
def test_default_config(self):
    """
    Check that the default configuration has a single ``file_cache`` entry
    set to None.
    """
    expected = {"file_cache": None}
    ntools.assert_equal(MemoryCodeIndex.get_default_config(), expected)
def test_is_usable(self):
    """
    Check that ``MemoryCodeIndex.is_usable`` returns a value equal to True.
    """
    usable = MemoryCodeIndex.is_usable()
    ntools.assert_equal(usable, True)
def __init__(self, mean_vec_filepath=None, rotation_filepath=None,
             code_index=MemoryCodeIndex(),
             # Index building parameters
             bit_length=8, itq_iterations=50, distance_method='cosine',
             random_seed=None):
    """
    Initialize ITQ similarity index instance.

    This implementation allows optional persistent storage of a built model
    via providing file paths for the ``mean_vec_filepath`` and
    ``rotation_filepath`` parameters.

    The Code Index
    --------------
    ``code_index`` should be an instance of a CodeIndex implementation
    class.

    The ``build_index`` call will clear anything in the index provided.
    For safety, make sure to check the index provided so as to not
    accidentally erase data.

    When providing existing mean_vector and rotation matrix caches, the
    ``code_index`` may be populated with codes. Pre-populated entries in
    the provided code index should have been generated from the same
    rotation and mean vector models provided, else nearest-neighbor query
    performance will not be as desired.

    A more advanced use case includes providing a code index that is
    update-able in the background. This is valid, assuming there are proper
    locking mechanisms in the code index.

    Build parameters
    ----------------
    Parameters after file path parameters are only related to building the
    index. When providing existing mean, rotation and code elements, these
    can be safely ignored.

    :raise ValueError: ``bit_length`` or ``itq_iterations`` is not greater
        than 0 once converted to an integer.

    :param mean_vec_filepath: Optional file location to load/store the
        mean vector when initialized and/or built. When None, this will
        only be stored in memory. This will use numpy to save/load, so this
        should have a ``.npy`` suffix, or one will be added at save time.
    :type mean_vec_filepath: str

    :param rotation_filepath: Optional file location to load/store the
        rotation matrix when initialized and/or built. When None, this will
        only be stored in memory. This will use numpy to save/load, so this
        should have a ``.npy`` suffix, or one will be added at save time.
    :type rotation_filepath: str

    :param code_index: CodeIndex instance to use. Note that the default
        value is a single instance created when this function is defined,
        so constructions that omit this argument share that one index.
    :type code_index: smqtk.representation.code_index.CodeIndex

    :param bit_length: Number of bits used to represent descriptors (hash
        code). This must be greater than 0.
    :type bit_length: int

    :param itq_iterations: Number of iterations for the ITQ algorithm to
        perform. This must be greater than 0.
    :type itq_iterations: int

    :param distance_method: String label of distance method to use. This
        must be one of the following:
            - "euclidean": Simple euclidean distance between two
                descriptors (L2 norm).
            - "cosine": Cosine angle distance/similarity between two
                descriptors.
            - "hik": Histogram intersection distance between two
                descriptors.
    :type distance_method: str

    :param random_seed: Integer to use as the random number generator
        seed.
    :type random_seed: int

    """
    super(ITQNearestNeighborsIndex, self).__init__()

    self._mean_vec_cache_filepath = mean_vec_filepath
    self._rotation_cache_filepath = rotation_filepath

    # maps small-codes to a list of DescriptorElements mapped by that code
    # NOTE(review): the default ``code_index`` is evaluated once at function
    # definition time and is therefore shared between instances that do not
    # pass their own. Left as-is here because other code may inspect this
    # default; confirm before changing it to a None sentinel.
    self._code_index = code_index

    # Number of bits we convert descriptors into
    self._bit_len = int(bit_length)
    # Number of iterations ITQ performs
    self._itq_iter_num = int(itq_iterations)
    # Optional fixed random seed
    self._rand_seed = None if random_seed is None else int(random_seed)

    # Validate the converted values (what is actually stored and used), and
    # raise instead of asserting so the checks survive ``python -O``.
    if self._bit_len <= 0:
        raise ValueError("Must be given a bit length greater than 0 "
                         "(zero)!")
    if self._itq_iter_num <= 0:
        raise ValueError("Must be given a number of iterations "
                         "greater than 0 (zero)!")

    # Vector of mean feature values. Center of "train" set, and used to
    # "center" additional descriptors when computing small codes.
    #: :type: numpy.core.multiarray.ndarray[float]
    self._mean_vector = None
    if self._mean_vec_cache_filepath and \
            osp.isfile(self._mean_vec_cache_filepath):
        self._log.debug("Loading existing descriptor vector mean")
        #: :type: numpy.core.multiarray.ndarray[float]
        self._mean_vector = numpy.load(self._mean_vec_cache_filepath)

    # rotation matrix of shape [d, b], found by ITQ process, to use to
    # transform new descriptors into binary hash decision vector.
    #: :type: numpy.core.multiarray.ndarray[float]
    self._r = None
    if self._rotation_cache_filepath and \
            osp.isfile(self._rotation_cache_filepath):
        self._log.debug("Loading existing descriptor rotation matrix")
        #: :type: numpy.core.multiarray.ndarray[float]
        self._r = numpy.load(self._rotation_cache_filepath)

    self._dist_method = distance_method
    self._dist_func = self._get_dist_func(distance_method)
def test_default_config(self):
    """
    Check that ``get_default_config`` returns only a ``file_cache`` key
    whose value is None.
    """
    default_config = MemoryCodeIndex.get_default_config()
    ntools.assert_equal(default_config, {"file_cache": None})
def test_init_no_cache(self):
    """
    Check that an index constructed without arguments starts empty and has
    no cache file path set.
    """
    inst = MemoryCodeIndex()
    ntools.assert_equal(inst._num_descr, 0)
    # ``assert_is_none`` takes the object and an optional failure message;
    # the extra ``None`` previously passed here was filling the message
    # slot, not acting as an expected value.
    ntools.assert_is_none(inst._file_cache)
    ntools.assert_equal(inst._table, {})