예제 #1
0
파일: tasks.py 프로젝트: Kitware/SMQTK
def itq(task, folderId, **kwargs):
    """
    Celery task that trains an ITQ model for the descriptor index inferred
    from a Girder folder.

    Training uses every descriptor in the index. Since this task is usually
    run right after descriptors are computed, the index will often contain
    only the descriptors of that folder.

    Fresh ``mean_vec.npy`` and ``rotation.npy`` items are created (or
    overwritten) in the folder's ``.smqtk`` sub-folder and handed to the
    functor as its cache elements.

    :param task: Celery provided task object.
    :param folderId: The folder to train ITQ for; only used to infer the
        descriptor index and to locate the ``.smqtk`` sub-folder.
    :raises Exception: If the descriptor index is empty.
    """
    client = task.girder_client
    task.job_manager.updateProgress(message='Training ITQ', forceFlush=True)
    index = descriptorIndexFromFolderId(client, folderId)

    if not index.count():
        # TODO SMQTK should account for this?
        raise Exception('Descriptor index is empty, cannot train ITQ.')

    smqtkFolder = getCreateFolder(client, folderId, '.smqtk')

    def _freshFile(name):
        # Create/overwrite the named item in .smqtk and attach a file to it.
        item = createOverwriteItem(client, smqtkFolder['_id'], name)
        return initializeItemWithFile(client, item)

    meanVecFile = _freshFile('mean_vec.npy')
    rotationFile = _freshFile('rotation.npy')

    # Both cache elements talk to the same API with the job's token.
    apiRoot = task.request.apiUrl
    token = task.request.jobInfoSpec['headers']['Girder-Token']

    meanVecCache = GirderDataElement(meanVecFile['_id'], api_root=apiRoot,
                                     token=token)
    rotationCache = GirderDataElement(rotationFile['_id'], api_root=apiRoot,
                                      token=token)
    functor = ItqFunctor(mean_vec_cache=meanVecCache,
                         rotation_cache=rotationCache)

    functor.fit(index.iterdescriptors(), use_multiprocessing=False)
예제 #2
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
    def test_fit_with_cache(self):
        # Five 2-D descriptors on the line y == x, symmetric about the origin.
        fit_descriptors = []
        for idx in range(5):
            elem = DescriptorMemoryElement(six.b('test'), idx)
            coord = -2. + idx
            elem.set_vector([coord, coord])
            fit_descriptors.append(elem)

        itq = ItqFunctor(DataMemoryElement(), DataMemoryElement(),
                         bit_length=1, random_seed=0)
        itq.fit(fit_descriptors)

        expected_mean = [0, 0]
        expected_rot = [[1 / sqrt(2)],
                        [1 / sqrt(2)]]

        # The points are symmetric about the origin, so the mean is zero. All
        # their variance lies along (1, 1), so the single projection column is
        # (1, 1)/sqrt(2). The sign is presumably pinned by random_seed=0.
        numpy.testing.assert_array_almost_equal(itq.mean_vec, expected_mean)
        numpy.testing.assert_array_almost_equal(itq.rotation, expected_rot)

        # Fitting must also have stored both components, as .npy bytes, in
        # the supplied cache elements.
        self.assertIsNotNone(itq.mean_vec_cache_elem)
        cached_mean = numpy.load(BytesIO(itq.mean_vec_cache_elem.get_bytes()))
        numpy.testing.assert_array_almost_equal(cached_mean, expected_mean)

        self.assertIsNotNone(itq.rotation_cache_elem)
        cached_rot = numpy.load(BytesIO(itq.rotation_cache_elem.get_bytes()))
        numpy.testing.assert_array_almost_equal(cached_rot, expected_rot)
예제 #3
0
    def test_fit_with_cache(self):
        def make_descriptor(i):
            # Descriptor ``i`` sits at (i - 2, i - 2) on the line y == x.
            d = DescriptorMemoryElement(six.b('test'), i)
            d.set_vector([-2. + i, -2. + i])
            return d

        fit_descriptors = [make_descriptor(i) for i in range(5)]

        itq = ItqFunctor(DataMemoryElement(), DataMemoryElement(),
                         bit_length=1, random_seed=0)
        itq.fit(fit_descriptors)

        inv_sqrt2 = 1 / sqrt(2)
        expected_mean_vec = [0, 0]
        expected_rotation = [[inv_sqrt2], [inv_sqrt2]]

        # TODO: Explanation as to why this is the expected result.
        numpy.testing.assert_array_almost_equal(itq.mean_vec,
                                                expected_mean_vec)
        numpy.testing.assert_array_almost_equal(itq.rotation,
                                                expected_rotation)

        # Each cache element must exist and hold the matching .npy payload.
        cache_checks = (
            (itq.mean_vec_cache_elem, expected_mean_vec),
            (itq.rotation_cache_elem, expected_rotation),
        )
        for cache_elem, expected in cache_checks:
            self.assertIsNotNone(cache_elem)
            loaded = numpy.load(six.BytesIO(cache_elem.get_bytes()))
            numpy.testing.assert_array_almost_equal(loaded, expected)
예제 #4
0
 def test_fit_has_model(self):
     # fit() must refuse to run once a mean vector and rotation are already
     # set on the functor; the strings stand in for real model components.
     itq = ItqFunctor()
     itq.mean_vec, itq.rotation = 'sim vec', 'sim rot'
     with self.assertRaisesRegex(RuntimeError,
                                 "Model components have already been loaded."):
         itq.fit([])
예제 #5
0
    def test_norm_vector_n2(self):
        # With L2 normalization, a unit vector comes back unchanged and a
        # longer vector is scaled down to unit length.
        itq = ItqFunctor(normalize=2)
        inv_root2 = 1. / sqrt(2)
        cases = (
            (numpy.array([1, 0]), [1, 0]),
            (numpy.array([1, 1]), [inv_root2, inv_root2]),
        )
        for vec, expected in cases:
            numpy.testing.assert_array_almost_equal(itq._norm_vector(vec),
                                                    expected)
예제 #6
0
 def test_get_config_no_cache(self):
     # Scalar parameters must be echoed back by get_config(), and with no
     # cache elements set both cache "type" entries must be None.
     itq = ItqFunctor(bit_length=1, itq_iterations=2, normalize=3,
                      random_seed=4)
     c = itq.get_config()
     expected = (
         ('bit_length', 1),
         ('itq_iterations', 2),
         ('normalize', 3),
         ('random_seed', 4),
     )
     for key, value in expected:
         NT.assert_equal(c[key], value)
     for cache_key in ('mean_vec_cache', 'rotation_cache'):
         NT.assert_is_none(c[cache_key]['type'])
예제 #7
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
 def test_get_config_no_cache(self):
     # get_config() must report the constructor parameters verbatim and a
     # None type for both (absent) cache elements.
     params = dict(bit_length=1, itq_iterations=2, normalize=3,
                   random_seed=4)
     itq = ItqFunctor(**params)
     c = itq.get_config()
     for name in ('bit_length', 'itq_iterations', 'normalize', 'random_seed'):
         self.assertEqual(c[name], params[name])
     self.assertIsNone(c['mean_vec_cache']['type'])
     self.assertIsNone(c['rotation_cache']['type'])
예제 #8
0
    def test_norm_vector_no_normalization(self):
        # normalize=None: vectors of any shape/content come back unchanged.
        itq = ItqFunctor(normalize=None)
        samples = (
            numpy.array([0, 1]),
            numpy.array([[0, 1, 1, .4, .1]]),
            numpy.array([0] * 128),
        )
        for v in samples:
            numpy.testing.assert_array_equal(itq._norm_vector(v), v)
예제 #9
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
 def test_fit_has_model(self):
     # When trying to run fit where there is already a mean vector and
     # rotation set, a RuntimeError must be raised.
     itq = ItqFunctor()
     itq.mean_vec = 'sim vec'
     itq.rotation = 'sim rot'
     # assertRaisesRegex replaces the deprecated ``assertRaisesRegexp``
     # alias, which was removed from unittest in Python 3.12.
     self.assertRaisesRegex(
         RuntimeError,
         "Model components have already been loaded.",
         itq.fit, []
     )
예제 #10
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
    def test_norm_vector_no_normalization(self):
        itq = ItqFunctor(normalize=None)

        def assert_unchanged(vec):
            # Without normalization the exact input must be handed back.
            numpy.testing.assert_array_equal(itq._norm_vector(vec), vec)

        assert_unchanged(numpy.array([0, 1]))
        assert_unchanged(numpy.array([[0, 1, 1, .4, .1]]))
        assert_unchanged(numpy.array([0] * 128))
예제 #11
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
    def test_save_model_no_caches(self):
        expected_mean_vec = numpy.array([1, 2, 3])
        expected_rotation = numpy.eye(3)

        # Cache variables should remain None after save.
        itq = ItqFunctor()
        itq.mean_vec = expected_mean_vec
        itq.rotation = expected_rotation
        itq.save_model()
        # Check both cache elements (the rotation cache was previously never
        # checked because the mean vector cache was asserted twice).
        self.assertIsNone(itq.mean_vec_cache_elem)
        self.assertIsNone(itq.rotation_cache_elem)
예제 #12
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
    def test_norm_vector_n2(self):
        itq = ItqFunctor(normalize=2)

        # Already unit length under L2: expected back as-is.
        unit = numpy.array([1, 0])
        numpy.testing.assert_array_almost_equal(itq._norm_vector(unit), [1, 0])

        # Length sqrt(2): each component is scaled by 1/sqrt(2).
        diag = numpy.array([1, 1])
        scaled = 1. / sqrt(2)
        numpy.testing.assert_array_almost_equal(itq._norm_vector(diag),
                                                [scaled, scaled])
예제 #13
0
    def test_configuration_with_caches(self):
        # This should run without error in both python
        # 2 and 3, as str/unicode are JSON compliant in both.
        expected_mean_vec = numpy.array([1, 2, 3])
        expected_rotation = numpy.eye(3)
        elem_type = (
            'smqtk.representation.data_element.memory_element.DataMemoryElement'
        )

        def npy_str(arr):
            # .npy-serialize ``arr`` and decode it into a JSON-safe string.
            buf = six.BytesIO()
            # noinspection PyTypeChecker
            numpy.save(buf, arr)
            return buf.getvalue().decode(BYTES_CONFIG_ENCODING)

        def memory_cache_config(arr):
            # Plugin configuration for a DataMemoryElement holding ``arr``.
            return {
                elem_type: {'bytes': npy_str(arr)},
                'type': elem_type,
            }

        new_parts = {
            'mean_vec_cache': memory_cache_config(expected_mean_vec),
            'rotation_cache': memory_cache_config(expected_rotation),
            'bit_length': 153,
            'itq_iterations': 7,
            'normalize': 2,
            'random_seed': 58,
        }
        c = merge_dict(ItqFunctor.get_default_config(), new_parts)

        itq = ItqFunctor.from_config(c)

        # The model must have been loaded from the configured cache elements,
        # and every scalar parameter must carry over.
        numpy.testing.assert_equal(itq.mean_vec, [1, 2, 3])
        numpy.testing.assert_equal(itq.rotation,
                                   [[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        scalar_expectations = (
            ('bit_length', 153),
            ('itq_iterations', 7),
            ('normalize', 2),
            ('random_seed', 58),
        )
        for attr, value in scalar_expectations:
            self.assertEqual(getattr(itq, attr), value)
예제 #14
0
    def test_save_model_with_writable_caches(self):
        # When both cache elements are writable, save_model() should store the
        # .npy serialization of the mean vector and rotation matrix in them.
        expected_mean_vec = numpy.array([1, 2, 3])
        expected_rotation = numpy.eye(3)

        mean_vec_buf = six.BytesIO()
        # noinspection PyTypeChecker
        numpy.save(mean_vec_buf, expected_mean_vec)
        expected_mean_vec_bytes = mean_vec_buf.getvalue()

        rotation_buf = six.BytesIO()
        # noinspection PyTypeChecker
        numpy.save(rotation_buf, expected_rotation)
        expected_rotation_bytes = rotation_buf.getvalue()

        itq = ItqFunctor()
        itq.mean_vec = expected_mean_vec
        itq.rotation = expected_rotation
        itq.mean_vec_cache_elem = DataMemoryElement(readonly=False)
        itq.rotation_cache_elem = DataMemoryElement(readonly=False)

        itq.save_model()
        self.assertEqual(itq.mean_vec_cache_elem.get_bytes(),
                         expected_mean_vec_bytes)
        self.assertEqual(itq.rotation_cache_elem.get_bytes(),
                         expected_rotation_bytes)
예제 #15
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
    def test_get_config_with_cache_elements(self):
        # get_config() must report the scalar parameters and, for each cache
        # element, its plugin type and stored bytes.
        itq = ItqFunctor(bit_length=5, itq_iterations=6, normalize=7,
                         random_seed=8)
        itq.mean_vec_cache_elem = DataMemoryElement('cached vec bytes')
        itq.rotation_cache_elem = DataMemoryElement('cached rot bytes')

        c = itq.get_config()
        self.assertEqual(c['bit_length'], 5)
        self.assertEqual(c['itq_iterations'], 6)
        self.assertEqual(c['normalize'], 7)
        self.assertEqual(c['random_seed'], 8)
        self.assertEqual(c['mean_vec_cache']['type'], "DataMemoryElement")
        # Check the rotation cache type as well, symmetric with the mean cache.
        self.assertEqual(c['rotation_cache']['type'], "DataMemoryElement")
        self.assertEqual(c['mean_vec_cache']['DataMemoryElement']['bytes'],
                         'cached vec bytes')
        self.assertEqual(c['rotation_cache']['DataMemoryElement']['bytes'],
                         'cached rot bytes')
예제 #16
0
파일: train_itq.py 프로젝트: Kitware/SMQTK
def main():
    """
    Train an ITQ functor model on descriptors from a configured index.

    If ``uuids_list_filepath`` names an existing file, only the descriptors
    whose UUIDs are listed in it (one per line) are used. Otherwise every
    descriptor in the configured index is used.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):
        def uuids_iter():
            # Yield one UUID per line with surrounding whitespace removed.
            # Blank lines are skipped so they are not looked up in the index
            # as empty-string UUIDs.
            with open(uuids_list_filepath) as f:
                for line in f:
                    uid = line.strip()
                    if uid:
                        yield uid
        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # A path was configured but is not a file. Report it rather than
            # silently training on the whole index.
            log.warning("UUIDs list file not found, training on all "
                        "descriptors in the index: %s", uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
예제 #17
0
    def nearestNeighborIndex(item, user, descriptorSet):
        """
        Get the nearest neighbor index from a given item and descriptor set.

        The ITQ model files and the hash-to-UUIDs table are read from the
        ``.smqtk`` folder inside the item's folder. That folder is created if
        it does not exist yet.

        :param item: Item to find the nn index from, usually the item that the
            user is performing the nearest neighbors search on.
        :param user: The owner of the .smqtk folder.
        :param descriptorSet: The relevant descriptor set.
        :returns: A read-only LSHNearestNeighborIndex, or None if any of the
            required SMQTK files could not be located.
        """
        folder = ModelImporter.model('folder')

        # All data elements below use this API root and the current token.
        _GirderDataElement = functools.partial(GirderDataElement,
                                               api_root=getApiUrl(),
                                               token=getCurrentToken()['_id'])

        smqtkFolder = folder.createFolder(folder.load(item['folderId'], user=user), '.smqtk',
                                          reuseExisting=True)

        try:
            meanVecFileId = localSmqtkFileIdFromName(smqtkFolder, 'mean_vec.npy')
            rotationFileId = localSmqtkFileIdFromName(smqtkFolder, 'rotation.npy')
            hash2uuidsFileId = localSmqtkFileIdFromName(smqtkFolder, 'hash2uuids.pickle')
        except Exception:
            # NOTE(review): the catch stays broad because the exception type
            # localSmqtkFileIdFromName raises for a missing file isn't visible
            # here. ``warning`` replaces the deprecated ``warn`` alias, and the
            # message is formatted lazily by logging.
            logger.warning('SMQTK files didn\'t exist for performing NN on %s',
                           item['_id'])
            return None

        # TODO Should these be Girder data elements? Unnecessary HTTP requests.
        functor = ItqFunctor(mean_vec_cache=_GirderDataElement(meanVecFileId),
                             rotation_cache=_GirderDataElement(rotationFileId))

        hash2uuidsKV = MemoryKeyValueStore(_GirderDataElement(hash2uuidsFileId))

        return LSHNearestNeighborIndex(functor, descriptorSet,
                                       hash2uuidsKV, read_only=True)
예제 #18
0
    def test_fit(self):
        # Five 2-D descriptors on y == x, symmetric about the origin.
        fit_descriptors = []
        for uid in range(5):
            descriptor = DescriptorMemoryElement('test', uid)
            descriptor.set_vector([-2. + uid, -2. + uid])
            fit_descriptors.append(descriptor)

        itq = ItqFunctor(bit_length=1, random_seed=0)
        itq.fit(fit_descriptors)

        # The mean is zero because the points are symmetric about the origin.
        # All variance lies along (1, 1), so the single projection column is
        # (1, 1)/sqrt(2). The sign is presumably fixed by random_seed=0.
        inv_root2 = 1 / sqrt(2)
        numpy.testing.assert_array_almost_equal(itq.mean_vec, [0, 0])
        numpy.testing.assert_array_almost_equal(itq.rotation,
                                                [[inv_root2], [inv_root2]])
        # No cache elements were supplied, so none should be present.
        NT.assert_is_none(itq.mean_vec_cache_elem)
        NT.assert_is_none(itq.rotation_cache_elem)
예제 #19
0
    def _make_ftor_itq(self, bits=32):
        """
        Build a seeded ITQ functor and a fitting callback for it.

        :param bits: Bit length of the hash codes the functor produces.
        :returns: ``(functor, fit_callable)``; ``fit_callable(D)`` fits the
            functor on descriptors ``D`` and returns None.
        """
        functor = ItqFunctor(bit_length=bits, random_seed=self.RANDOM_SEED)

        def fit_fn(D):
            functor.fit(D)

        return functor, fit_fn
예제 #20
0
    def _nearestNeighborIndex(sid, descriptor_set):
        """
        Build the nearest neighbor index for a session.

        :param sid: ID of the session item.
        :param descriptor_set: The descriptor set corresponding to the session
            id, see _descriptorSetFromSessionId.
        :returns: Nearest neighbor index, or None if no session exists.
        :rtype: LSHNearestNeighborIndex|None
        """
        session = ModelImporter.model('item').findOne({'_id': ObjectId(sid)})
        if not session:
            return None

        smqtkFolder = {'_id': ObjectId(session['meta']['smqtk_folder_id'])}

        def elementFor(name):
            # Wrap the named file from the session's .smqtk folder as an
            # SMQTK data element.
            return smqtkDataElementFromGirderFileId(
                localSmqtkFileIdFromName(smqtkFolder, name))

        functor = ItqFunctor(elementFor('mean_vec.npy'),
                             elementFor('rotation.npy'))
        hash2uuidsKV = MemoryKeyValueStore(elementFor('hash2uuids.pickle'))

        return LSHNearestNeighborIndex(functor, descriptor_set, hash2uuidsKV,
                                       read_only=True)
예제 #21
0
    def test_fit_short_descriptors_for_bit_length(self):
        # Fitting must fail when the descriptors have fewer dimensions than
        # the requested hash-code bit length (a limitation of the PCA step
        # currently used), and the model must stay unset afterwards.
        fit_descriptors = []
        for idx in range(3):
            elem = DescriptorMemoryElement(six.b('test'), idx)
            elem.set_vector([-1 + idx, -1 + idx])
            fit_descriptors.append(elem)

        itq = ItqFunctor(bit_length=8)
        expected_msg = \
            "Input descriptors have fewer features than requested bit encoding"

        # The same failure is expected for a list and for a plain iterator.
        input_factories = (
            lambda: fit_descriptors,
            lambda: iter(fit_descriptors),
        )
        for make_input in input_factories:
            self.assertRaisesRegex(ValueError, expected_msg,
                                   itq.fit, make_input())
            self.assertIsNone(itq.mean_vec)
            self.assertIsNone(itq.rotation)
예제 #22
0
    def test_get_config_with_cache_elements(self):
        # get_config() must report the scalar parameters and, for each cache
        # element, its plugin type and stored bytes.
        itq = ItqFunctor(bit_length=5, itq_iterations=6, normalize=7,
                         random_seed=8)
        itq.mean_vec_cache_elem = DataMemoryElement('cached vec bytes')
        itq.rotation_cache_elem = DataMemoryElement('cached rot bytes')

        c = itq.get_config()
        NT.assert_equal(c['bit_length'], 5)
        NT.assert_equal(c['itq_iterations'], 6)
        NT.assert_equal(c['normalize'], 7)
        NT.assert_equal(c['random_seed'], 8)
        NT.assert_equal(c['mean_vec_cache']['type'], "DataMemoryElement")
        # Check the rotation cache type as well, symmetric with the mean cache.
        NT.assert_equal(c['rotation_cache']['type'], "DataMemoryElement")
        NT.assert_equal(c['mean_vec_cache']['DataMemoryElement']['bytes'],
                        'cached vec bytes')
        NT.assert_equal(c['rotation_cache']['DataMemoryElement']['bytes'],
                        'cached rot bytes')
예제 #23
0
    def test_get_config_with_cache_elements(self):
        # get_config() must report the scalar parameters and, for each cache
        # element, its plugin type and stored bytes.
        itq = ItqFunctor(bit_length=5, itq_iterations=6, normalize=7,
                         random_seed=8)
        itq.mean_vec_cache_elem = DataMemoryElement(b'cached vec bytes')
        itq.rotation_cache_elem = DataMemoryElement(b'cached rot bytes')

        c = itq.get_config()
        self.assertEqual(c['bit_length'], 5)
        self.assertEqual(c['itq_iterations'], 6)
        self.assertEqual(c['normalize'], 7)
        self.assertEqual(c['random_seed'], 8)
        self.assertEqual(c['mean_vec_cache']['type'], "DataMemoryElement")
        # Check the rotation cache type as well, symmetric with the mean cache.
        self.assertEqual(c['rotation_cache']['type'], "DataMemoryElement")
        # Check using string encodings of set bytes (JSON compliant).
        self.assertEqual(c['mean_vec_cache']['DataMemoryElement']['bytes'],
                         'cached vec bytes')
        self.assertEqual(c['rotation_cache']['DataMemoryElement']['bytes'],
                         'cached rot bytes')
예제 #24
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
    def test_configuration_with_caches(self):
        expected_mean_vec = numpy.array([1, 2, 3])
        expected_rotation = numpy.eye(3)

        def npy_bytes(arr):
            # Raw .npy serialization of ``arr``.
            buf = BytesIO()
            # noinspection PyTypeChecker
            numpy.save(buf, arr)
            return buf.getvalue()

        def memory_element_config(arr):
            # DataMemoryElement plugin configuration holding ``arr``'s bytes.
            return {
                'DataMemoryElement': {'bytes': npy_bytes(arr)},
                'type': 'DataMemoryElement',
            }

        new_parts = {
            'mean_vec_cache': memory_element_config(expected_mean_vec),
            'rotation_cache': memory_element_config(expected_rotation),
            'bit_length': 153,
            'itq_iterations': 7,
            'normalize': 2,
            'random_seed': 58,
        }
        c = merge_dict(ItqFunctor.get_default_config(), new_parts)

        itq = ItqFunctor.from_config(c)

        # The model components must have been read back from the configured
        # cache elements, and the scalar parameters set from the config.
        numpy.testing.assert_equal(itq.mean_vec, [1, 2, 3])
        numpy.testing.assert_equal(itq.rotation, numpy.eye(3).tolist())
        self.assertEqual(itq.bit_length, 153)
        self.assertEqual(itq.itq_iterations, 7)
        self.assertEqual(itq.normalize, 2)
        self.assertEqual(itq.random_seed, 58)
예제 #25
0
    def test_configuration_with_caches(self):
        expected_mean_vec = numpy.array([1, 2, 3])
        expected_rotation = numpy.eye(3)

        # .npy bytes for each cache, keyed by its configuration slot.
        cache_payloads = []
        for slot, arr in (('mean_vec_cache', expected_mean_vec),
                          ('rotation_cache', expected_rotation)):
            buf = BytesIO()
            # noinspection PyTypeChecker
            numpy.save(buf, arr)
            cache_payloads.append((slot, buf.getvalue()))

        new_parts = dict(
            (slot, {'DataMemoryElement': {'bytes': payload},
                    'type': 'DataMemoryElement'})
            for slot, payload in cache_payloads
        )
        new_parts.update({
            'bit_length': 153,
            'itq_iterations': 7,
            'normalize': 2,
            'random_seed': 58,
        })
        c = merge_dict(ItqFunctor.get_default_config(), new_parts)

        itq = ItqFunctor.from_config(c)

        # Loaded parameters must match, and the cache elements must yield the
        # intended vector/matrix.
        numpy.testing.assert_equal(itq.mean_vec, [1, 2, 3])
        numpy.testing.assert_equal(itq.rotation,
                                   [[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        for attr, value in (('bit_length', 153), ('itq_iterations', 7),
                            ('normalize', 2), ('random_seed', 58)):
            self.assertEqual(getattr(itq, attr), value)
예제 #26
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
    def test_save_model_with_writable_caches(self):
        # When both cache elements are writable, save_model() should store the
        # .npy serialization of the mean vector and rotation matrix in them.
        expected_mean_vec = numpy.array([1, 2, 3])
        expected_rotation = numpy.eye(3)

        mean_vec_buf = BytesIO()
        # noinspection PyTypeChecker
        numpy.save(mean_vec_buf, expected_mean_vec)
        expected_mean_vec_bytes = mean_vec_buf.getvalue()

        rotation_buf = BytesIO()
        # noinspection PyTypeChecker
        numpy.save(rotation_buf, expected_rotation)
        expected_rotation_bytes = rotation_buf.getvalue()

        itq = ItqFunctor()
        itq.mean_vec = expected_mean_vec
        itq.rotation = expected_rotation
        itq.mean_vec_cache_elem = DataMemoryElement(readonly=False)
        itq.rotation_cache_elem = DataMemoryElement(readonly=False)

        itq.save_model()
        self.assertEqual(itq.mean_vec_cache_elem.get_bytes(),
                         expected_mean_vec_bytes)
        self.assertEqual(itq.rotation_cache_elem.get_bytes(),
                         expected_rotation_bytes)
예제 #27
0
파일: tasks.py 프로젝트: spongezhang/SMQTK
def itq(task, folderId, **kwargs):
    """
    Train ITQ on the descriptor index that belongs to a Girder folder.

    All descriptors in the index are used. Because this normally runs right
    after descriptor computation, that usually means just the folder's
    descriptors. ``mean_vec.npy`` and ``rotation.npy`` items are created (or
    overwritten) in the folder's ``.smqtk`` sub-folder and used as the
    functor's cache elements.

    :param task: Celery provided task object.
    :param folderId: The folder to train ITQ for; only used to infer the
        descriptor index and locate the ``.smqtk`` sub-folder.
    :raises Exception: If the descriptor index is empty.
    """
    client = task.girder_client
    task.job_manager.updateProgress(message='Training ITQ', forceFlush=True)
    index = descriptorIndexFromFolderId(client, folderId)

    if not index.count():
        # TODO SMQTK should account for this?
        raise Exception('Descriptor index is empty, cannot train ITQ.')

    smqtkFolder = getCreateFolder(client, folderId, '.smqtk')
    meanVecFile, rotationFile = [
        initializeItemWithFile(
            client, createOverwriteItem(client, smqtkFolder['_id'], name))
        for name in ('mean_vec.npy', 'rotation.npy')
    ]

    def asDataElement(girderFile):
        # Expose a Girder file as an SMQTK data element, authenticated with
        # the job's Girder token.
        return GirderDataElement(
            girderFile['_id'],
            api_root=task.request.apiUrl,
            token=task.request.jobInfoSpec['headers']['Girder-Token'])

    functor = ItqFunctor(mean_vec_cache=asDataElement(meanVecFile),
                         rotation_cache=asDataElement(rotationFile))

    functor.fit(index.iterdescriptors(), use_multiprocessing=False)
예제 #28
0
 def test_configuration(self):
     # Build a read-only LSH index from in-memory parts. Every instance that
     # configuration_test_helper yields (presumably rebuilt from this
     # index's configuration) must keep the same component types and settings.
     index = LSHNearestNeighborIndex(
         lsh_functor=ItqFunctor(),
         descriptor_set=MemoryDescriptorSet(),
         hash2uuids_kvstore=MemoryKeyValueStore(),
         hash_index=LinearHashIndex(),
         distance_method='euclidean',
         read_only=True,
     )
     expected_types = (
         ('lsh_functor', LshFunctor),
         ('descriptor_set', MemoryDescriptorSet),
         ('hash_index', LinearHashIndex),
         ('hash2uuids_kvstore', MemoryKeyValueStore),
     )
     for inst in configuration_test_helper(index):  # type: LSHNearestNeighborIndex
         for attr, cls in expected_types:
             assert isinstance(getattr(inst, attr), cls)
         assert inst.distance_method == 'euclidean'
         assert inst.read_only is True
예제 #29
0
def main():
    description = """
    Tool for training the ITQ functor algorithm's model on descriptors in an
    index.

    By default, we use all descriptors in the configured index
    (``uuids_list_filepath`` is not given a value).

    The ``uuids_list_filepath`` configuration property is optional and should
    be used to specify a sub-set of descriptors in the configured index to
    train on. This only works if the stored descriptors' UUID is a type of
    string.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):

        def uuids_iter():
            # Yield one UUID per line with surrounding whitespace removed.
            # Blank lines are skipped so they are not looked up in the index
            # as empty-string UUIDs.
            with open(uuids_list_filepath) as f:
                for line in f:
                    uid = line.strip()
                    if uid:
                        yield uid

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # A path was configured but is not a file. Report it rather than
            # silently training on the whole index.
            log.warning("UUIDs list file not found, training on all "
                        "descriptors in the index: %s", uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
예제 #30
0
    def test_save_model_no_caches(self):
        expected_mean_vec = numpy.array([1, 2, 3])
        expected_rotation = numpy.eye(3)

        # Cache variables should remain None after save.
        itq = ItqFunctor()
        itq.mean_vec = expected_mean_vec
        itq.rotation = expected_rotation
        itq.save_model()
        # Check both cache elements (the rotation cache was previously never
        # checked because the mean vector cache was asserted twice).
        self.assertIsNone(itq.mean_vec_cache_elem)
        self.assertIsNone(itq.rotation_cache_elem)
예제 #31
0
파일: train_itq.py 프로젝트: dhandeo/SMQTK
def main():
    description = """
    Tool for training the ITQ functor algorithm's model on descriptors in an
    index.

    By default, we use all descriptors in the configured index
    (``uuids_list_filepath`` is not given a value).

    The ``uuids_list_filepath`` configuration property is optional and should
    be used to specify a sub-set of descriptors in the configured index to
    train on. This only works if the stored descriptors' UUID is a type of
    string.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):
        def uuids_iter():
            # Yield one UUID per line with surrounding whitespace removed.
            # Blank lines are skipped so they are not looked up in the index
            # as empty-string UUIDs.
            with open(uuids_list_filepath) as f:
                for line in f:
                    uid = line.strip()
                    if uid:
                        yield uid
        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # A path was configured but is not a file. Report it rather than
            # silently training on the whole index.
            log.warning("UUIDs list file not found, training on all "
                        "descriptors in the index: %s", uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
예제 #32
0
파일: tasks.py 프로젝트: spongezhang/SMQTK
def compute_hash_codes(task, folderId, **kwargs):
    """
    Celery task for computing hash codes on a given folder (descriptor index).

    The ITQ model files ``mean_vec.npy`` and ``rotation.npy`` are looked up
    by name in the folder's ``.smqtk`` sub-folder. The resulting hash mapping
    is pickled and uploaded into the ``hash2uuids.pickle`` item obtained from
    ``createOverwriteItem``.

    :param task: Celery provided task object.
    :param folderId: The folder to compute hash codes for, note this is only
        used to infer the descriptor index and locate its ``.smqtk`` folder.
    """
    task.job_manager.updateProgress(message='Computing Hash Codes',
                                    forceFlush=True)

    index = descriptorIndexFromFolderId(task.girder_client, folderId)
    smqtk_folder = getCreateFolder(task.girder_client, folderId, '.smqtk')

    # Girder file IDs of the previously trained ITQ model components.
    mean_vec_file_id = smqtkFileIdFromName(task.girder_client, smqtk_folder,
                                           'mean_vec.npy')
    rotation_file_id = smqtkFileIdFromName(task.girder_client, smqtk_folder,
                                           'rotation.npy')

    # Destination item/file for the pickled hash mapping.
    output_item = createOverwriteItem(task.girder_client, smqtk_folder['_id'],
                                      'hash2uuids.pickle')
    output_file = initializeItemWithFile(task.girder_client, output_item)

    api_root = task.request.apiUrl
    girder_token = task.request.jobInfoSpec['headers']['Girder-Token']
    functor = ItqFunctor(
        mean_vec_cache=GirderDataElement(mean_vec_file_id,
                                         api_root=api_root,
                                         token=girder_token),
        rotation_cache=GirderDataElement(rotation_file_id,
                                         api_root=api_root,
                                         token=girder_token))

    pairs = compute_functions.compute_hash_codes(index.iterkeys(),
                                                 index,
                                                 functor,
                                                 use_mp=False)

    # NOTE(review): the pairs are presumably (uuid, hash_code); the mapping
    # below inverts them to hash_code -> uuid, so when several descriptors
    # share a hash code only the last UUID seen is kept, despite the
    # "hash2uuids" name. Confirm what the consumer of this pickle expects.
    hash_to_uuid = {code: uuid for uuid, code in pairs}
    payload = pickle.dumps(hash_to_uuid)
    task.girder_client.uploadFileContents(output_file['_id'],
                                          six.BytesIO(payload), len(payload))
예제 #33
0
파일: train_itq.py 프로젝트: fangi22/SMQTK
def main():
    """
    Train an ITQ functor model on descriptors from a configured descriptor
    set.

    Configuration is obtained via ``cli.utility_main_helper`` from
    ``default_config`` and the arguments parsed by ``cli_parser()``. When
    ``uuids_list_filepath`` names an existing file, only the descriptors whose
    UUIDs are listed in it (one per line, blank lines skipped) are used for
    fitting; otherwise every descriptor in the configured set is used.
    """
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorSet [type=%s]",
             config['descriptor_set']['type'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(
        config['descriptor_set'],
        DescriptorSet.get_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):

        def uuids_iter():
            # One UUID per line. Blank lines (e.g. trailing empty lines) are
            # skipped so they are not looked up as an empty-string UUID.
            with open(uuids_list_filepath) as f:
                for line in f:
                    uuid = line.strip()
                    if uuid:
                        yield uuid

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_set.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # A path was configured but is not a file: make the fallback to
            # the whole set visible instead of silent.
            log.warning("UUIDs list file not found, using all descriptors "
                        "instead: %s", uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorSet (count=%d)",
                 len(descriptor_set))
        d_iter = descriptor_set

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
예제 #34
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
 def test_is_usable(self):
     # ItqFunctor has no non-standard dependencies, so it should always
     # report itself as usable.
     usable = ItqFunctor.is_usable()
     self.assertTrue(usable)
예제 #35
0
    def test_get_hash(self):
        # With a zero mean vector and the single-column "rotation" matrix
        # below, any 2-feature descriptor to the right of the line ``y = -x``
        # should hash to True, and to the left to False. If on the line, it
        # should be True.
        # (An unused list of fit descriptors was previously built here; no
        # fitting happens in this test, so that dead setup was removed.)
        itq = ItqFunctor(bit_length=1, random_seed=0)
        itq.mean_vec = numpy.array([0., 0.])
        itq.rotation = numpy.array([[1. / sqrt(2)], [1. / sqrt(2)]])

        # Points on either side of ``y = -x`` along the diagonal ``y = x``.
        numpy.testing.assert_array_equal(itq.get_hash(numpy.array([1, 1])),
                                         [True])
        numpy.testing.assert_array_equal(itq.get_hash(numpy.array([-1, -1])),
                                         [False])

        # Upper-left quadrant: on the line, just left of it, just right of it.
        numpy.testing.assert_array_equal(itq.get_hash(numpy.array([-1, 1])),
                                         [True])
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([-1.001, 1])), [False])
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([-1, 1.001])), [True])

        # Lower-right quadrant: on the line, just left of it, just right of it.
        numpy.testing.assert_array_equal(itq.get_hash(numpy.array([1, -1])),
                                         [True])
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([1, -1.001])), [False])
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([1.001, -1])), [True])
예제 #36
0
def default_config():
    """
    Build the default configuration dictionary for the ITQ training tool.

    Keys: ``itq_config`` (ItqFunctor's default config),
    ``uuids_list_filepath`` (None by default), and ``descriptor_index``
    (a plugin config over the available DescriptorIndex implementations).
    """
    config = {}
    config["itq_config"] = ItqFunctor.get_default_config()
    config["uuids_list_filepath"] = None
    index_impls = get_descriptor_index_impls()
    config["descriptor_index"] = plugin.make_config(index_impls)
    return config
예제 #37
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
    def test_get_hash(self):
        # With a zero mean vector and the single-column "rotation" matrix
        # below, any 2-feature descriptor to the right of the line ``y = -x``
        # should hash to True, and to the left to False. If on the line, it
        # should be True.
        # (An unused list of fit descriptors was previously built here; no
        # fitting happens in this test, so that dead setup was removed.)
        itq = ItqFunctor(bit_length=1, random_seed=0)
        itq.mean_vec = numpy.array([0., 0.])
        itq.rotation = numpy.array([[1. / sqrt(2)],
                                    [1. / sqrt(2)]])

        # Points on either side of ``y = -x`` along the diagonal ``y = x``.
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([1, 1])), [True])
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([-1, -1])), [False])

        # Upper-left quadrant: on the line, just left of it, just right of it.
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([-1, 1])), [True])
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([-1.001, 1])), [False])
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([-1, 1.001])), [True])

        # Lower-right quadrant: on the line, just left of it, just right of it.
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([1, -1])), [True])
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([1, -1.001])), [False])
        numpy.testing.assert_array_equal(
            itq.get_hash(numpy.array([1.001, -1])), [True])
예제 #38
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
    def test_save_model_with_read_only_cache(self):
        # save_model() must write nothing to either cache element when at
        # least one of them (or both) is read-only.
        mean_vec = numpy.array([1, 2, 3])
        rotation = numpy.eye(3)

        itq = ItqFunctor()
        itq.mean_vec = mean_vec
        itq.rotation = rotation

        # (mean-vec cache read-only?, rotation cache read-only?)
        read_only_cases = [
            (True, False),   # read-only mean-vec cache
            (False, True),   # read-only rotation cache
            (True, True),    # both read-only
        ]
        for mean_vec_ro, rotation_ro in read_only_cases:
            itq.mean_vec_cache_elem = DataMemoryElement(readonly=mean_vec_ro)
            itq.rotation_cache_elem = DataMemoryElement(readonly=rotation_ro)
            itq.save_model()
            self.assertEqual(itq.mean_vec_cache_elem.get_bytes(), six.b(''))
            self.assertEqual(itq.rotation_cache_elem.get_bytes(), six.b(''))
예제 #39
0
 def test_default_configuration(self):
     # A functor built from the default configuration should report that
     # same configuration back from get_config().
     default_cfg = ItqFunctor.get_default_config()
     functor = ItqFunctor.from_config(default_cfg)
     self.assertEqual(functor.get_config(), default_cfg)
예제 #40
0
    def test_save_model_with_read_only_cache(self):
        # If one or both cache elements are read-only, save_model() should
        # leave both cache elements empty.
        mean_vec = numpy.array([1, 2, 3])
        rotation = numpy.eye(3)

        itq = ItqFunctor()
        itq.mean_vec = mean_vec
        itq.rotation = rotation

        def check_nothing_saved(mean_vec_readonly, rotation_readonly):
            # Install fresh cache elements, attempt a save, expect no bytes.
            itq.mean_vec_cache_elem = \
                DataMemoryElement(readonly=mean_vec_readonly)
            itq.rotation_cache_elem = \
                DataMemoryElement(readonly=rotation_readonly)
            itq.save_model()
            self.assertEqual(itq.mean_vec_cache_elem.get_bytes(), six.b(''))
            self.assertEqual(itq.rotation_cache_elem.get_bytes(), six.b(''))

        check_nothing_saved(True, False)   # read-only mean-vec cache
        check_nothing_saved(False, True)   # read-only rotation cache
        check_nothing_saved(True, True)    # both read-only
예제 #41
0
 def test_is_usable(self):
     # No non-standard dependencies are needed, so this must always hold.
     is_usable = ItqFunctor.is_usable()
     self.assertTrue(is_usable)
예제 #42
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
 def test_default_configuration(self):
     # Round-trip the default config through from_config()/get_config().
     cfg = ItqFunctor.get_default_config()
     round_tripped = ItqFunctor.from_config(cfg).get_config()
     self.assertEqual(round_tripped, cfg)
예제 #43
0
 def test_build_index_read_only(self):
     # An LSH index constructed with read_only=True must raise ReadOnlyError
     # from build_index().
     functor = ItqFunctor()
     descriptor_index = MemoryDescriptorIndex()
     hash_kvstore = MemoryKeyValueStore()
     index = LSHNearestNeighborIndex(functor, descriptor_index, hash_kvstore,
                                     read_only=True)
     ntools.assert_raises(ReadOnlyError, index.build_index, [])
예제 #44
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
 def test_has_model(self):
     itq = ItqFunctor()
     # Freshly constructed, with no vector/rotation set: no model.
     self.assertFalse(itq.has_model())

     # has_model() is True only when both attributes are non-None.
     cases = [
         ('mean vec', None, False),
         (None, 'rotation', False),
         ('mean vec', 'rotation', True),
     ]
     for mean_vec, rotation, expected in cases:
         itq.mean_vec = mean_vec
         itq.rotation = rotation
         if expected:
             self.assertTrue(itq.has_model())
         else:
             self.assertFalse(itq.has_model())
예제 #45
0
파일: train_itq.py 프로젝트: fangi22/SMQTK
def default_config():
    """
    Build the default configuration dictionary for the ITQ training tool.

    Keys: ``itq_config`` (ItqFunctor's default config),
    ``uuids_list_filepath`` (None by default), and ``descriptor_set``
    (a default config over the available DescriptorSet implementations).
    """
    itq_defaults = ItqFunctor.get_default_config()
    set_defaults = make_default_config(DescriptorSet.get_impls())
    return {
        "itq_config": itq_defaults,
        "uuids_list_filepath": None,
        "descriptor_set": set_defaults,
    }
예제 #46
0
 def test_has_model(self):
     # has_model() should be True only when both mean_vec and rotation are
     # set (non-None).
     itq = ItqFunctor()

     def has_model_with(mean_vec, rotation):
         # Assign both model components, then report has_model().
         itq.mean_vec = mean_vec
         itq.rotation = rotation
         return itq.has_model()

     # Nothing set on a fresh instance.
     self.assertFalse(itq.has_model())
     # Only one of the two set.
     self.assertFalse(has_model_with('mean vec', None))
     self.assertFalse(has_model_with(None, 'rotation'))
     # Both set.
     self.assertTrue(has_model_with('mean vec', 'rotation'))
예제 #47
0
파일: train_itq.py 프로젝트: Kitware/SMQTK
def default_config():
    """
    Build the default configuration dictionary for the ITQ training tool.

    Contains the ITQ functor's default config, an unset (None)
    ``uuids_list_filepath``, and a plugin config for the descriptor index.
    """
    return dict(
        itq_config=ItqFunctor.get_default_config(),
        uuids_list_filepath=None,
        descriptor_index=plugin.make_config(get_descriptor_index_impls()),
    )