Example #1
0
    def test_fit_with_cache(self):
        """
        Fit ITQ on collinear 2-D data with in-memory cache elements and check
        that the learned model is held on the functor and written to caches.
        """
        # Five descriptors with uids 0..4 placed at (-2, -2) ... (2, 2),
        # i.e. evenly spaced along the line y = x.
        fit_descriptors = []
        for uid in range(5):
            coord = uid - 2.0
            descriptor = DescriptorMemoryElement(six.b('test'), uid)
            descriptor.set_vector([coord, coord])
            fit_descriptors.append(descriptor)

        itq = ItqFunctor(DataMemoryElement(), DataMemoryElement(),
                         bit_length=1, random_seed=0)
        itq.fit(fit_descriptors)

        # Why these values: the five points are symmetric about the origin,
        # so their mean is (0, 0). Once centered they all lie on y = x, so
        # the only direction carrying variance is the unit vector
        # (1, 1) / sqrt(2). With bit_length=1 the learned 2x1 rotation is
        # expected to be that direction; its sign is presumably fixed by
        # random_seed=0 (this test pins the positive orientation).
        inv_sqrt2 = 1 / sqrt(2)
        expected_mean = [0, 0]
        expected_rotation = [[inv_sqrt2], [inv_sqrt2]]

        numpy.testing.assert_array_almost_equal(itq.mean_vec, expected_mean)
        numpy.testing.assert_array_almost_equal(itq.rotation,
                                                expected_rotation)

        # Each cache element must exist and hold an .npy payload equal to
        # the corresponding in-memory value.
        self.assertIsNotNone(itq.mean_vec_cache_elem)
        mean_bytes = itq.mean_vec_cache_elem.get_bytes()
        numpy.testing.assert_array_almost_equal(
            numpy.load(six.BytesIO(mean_bytes)), expected_mean)

        self.assertIsNotNone(itq.rotation_cache_elem)
        rotation_bytes = itq.rotation_cache_elem.get_bytes()
        numpy.testing.assert_array_almost_equal(
            numpy.load(six.BytesIO(rotation_bytes)), expected_rotation)
Example #2
0
def itq(task, folderId, **kwargs):
    """
    Celery task for training ITQ on a given folder.

    This trains ITQ on all descriptors within the index. Since this
    is typically called after computing descriptors, it will often
    only contain what's in the folder.

    The functor is given two ``GirderDataElement`` caches, backed by the
    ``mean_vec.npy`` and ``rotation.npy`` items in the folder's ``.smqtk``
    subfolder. ``fit`` is expected to write the learned mean vector and
    rotation matrix into them.

    :param task: Celery provided task object.
    :param folderId: The folder to train ITQ for, note this is only used to
        infer the descriptor index.
    :param kwargs: Extra task keyword arguments; accepted but unused.
    :raises ValueError: If the inferred descriptor index is empty.
    """
    task.job_manager.updateProgress(message='Training ITQ', forceFlush=True)
    girderClient = task.girder_client
    index = descriptorIndexFromFolderId(girderClient, folderId)

    if not index.count():
        # TODO SMQTK should account for this?
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers still catch it.
        raise ValueError('Descriptor index is empty, cannot train ITQ.')

    # Resolve API location and credentials once, before touching Girder, so
    # a malformed job spec fails before any items are created.
    apiUrl = task.request.apiUrl
    token = task.request.jobInfoSpec['headers']['Girder-Token']

    smqtkFolder = getCreateFolder(girderClient, folderId, '.smqtk')

    def _cacheElement(name):
        # Make item ``name`` in the .smqtk folder (createOverwriteItem),
        # initialize it with a file (initializeItemWithFile) and wrap that
        # file as a data element usable as an ItqFunctor cache.
        fileDoc = initializeItemWithFile(
            girderClient,
            createOverwriteItem(girderClient, smqtkFolder['_id'], name))
        return GirderDataElement(fileDoc['_id'], api_root=apiUrl,
                                 token=token)

    functor = ItqFunctor(mean_vec_cache=_cacheElement('mean_vec.npy'),
                         rotation_cache=_cacheElement('rotation.npy'))

    # NOTE(review): multiprocessing is explicitly disabled, presumably to
    # keep all work inside the Celery worker process -- confirm.
    functor.fit(index.iterdescriptors(), use_multiprocessing=False)
Example #3
0
    def test_fit_with_cache(self):
        """
        ITQ fitted with in-memory cache elements exposes the learned mean
        and rotation and also serializes both into those caches.
        """
        def make_descriptor(uid):
            # Descriptor ``uid`` (0..4) is placed at (uid - 2, uid - 2),
            # so the set is evenly spaced along the line y = x.
            elem = DescriptorMemoryElement(six.b('test'), uid)
            elem.set_vector([-2. + uid, -2. + uid])
            return elem

        fit_descriptors = [make_descriptor(uid) for uid in range(5)]

        itq = ItqFunctor(DataMemoryElement(), DataMemoryElement(),
                         bit_length=1, random_seed=0)
        itq.fit(fit_descriptors)

        # Why these values: the points are symmetric about the origin, so
        # the mean is (0, 0). Centered, they all lie on y = x, whose unit
        # direction is (1, 1) / sqrt(2). With one bit the 2x1 rotation is
        # expected to be that direction; the sign is presumably pinned by
        # random_seed=0.
        inv_sqrt2 = 1 / sqrt(2)
        expected_mean = [0, 0]
        expected_rotation = [[inv_sqrt2], [inv_sqrt2]]

        numpy.testing.assert_array_almost_equal(itq.mean_vec, expected_mean)
        numpy.testing.assert_array_almost_equal(itq.rotation,
                                                expected_rotation)

        # Each cache element must exist and hold an .npy payload equal to
        # the in-memory value (checked mean first, then rotation).
        for cache_attr, expected in (('mean_vec_cache_elem', expected_mean),
                                     ('rotation_cache_elem',
                                      expected_rotation)):
            cache_elem = getattr(itq, cache_attr)
            self.assertIsNotNone(cache_elem)
            numpy.testing.assert_array_almost_equal(
                numpy.load(BytesIO(cache_elem.get_bytes())), expected)
Example #4
0
    def test_fit(self):
        """
        ITQ fitted without cache elements learns the expected mean and
        rotation and leaves both cache slots unset.
        """
        # Training data: uids 0..4 at (-2, -2), (-1, -1), ..., (2, 2).
        fit_descriptors = []
        for uid, coord in enumerate(range(-2, 3)):
            descriptor = DescriptorMemoryElement('test', uid)
            descriptor.set_vector([float(coord)] * 2)
            fit_descriptors.append(descriptor)

        # No cache elements are passed to the functor here.
        itq = ItqFunctor(bit_length=1, random_seed=0)
        itq.fit(fit_descriptors)

        # Why these values: the points are symmetric about the origin, so
        # the mean is (0, 0); centered, they lie on y = x, whose unit
        # direction is (1, 1) / sqrt(2). With one bit the 2x1 rotation is
        # expected to be that direction (sign presumably fixed by the seed).
        inv_sqrt2 = 1 / sqrt(2)
        numpy.testing.assert_array_almost_equal(itq.mean_vec, [0, 0])
        numpy.testing.assert_array_almost_equal(itq.rotation,
                                                [[inv_sqrt2], [inv_sqrt2]])

        # Without caches supplied, neither cache element should be set.
        for cache_attr in ('mean_vec_cache_elem', 'rotation_cache_elem'):
            NT.assert_is_none(getattr(itq, cache_attr))
Example #5
0
def itq(task, folderId, **kwargs):
    """
    Celery task for training ITQ on a given folder.

    This trains ITQ on all descriptors within the index. Since this
    is typically called after computing descriptors, it will often
    only contain what's in the folder.

    The functor is given two ``GirderDataElement`` caches, backed by the
    ``mean_vec.npy`` and ``rotation.npy`` items in the folder's ``.smqtk``
    subfolder. ``fit`` is expected to write the learned mean vector and
    rotation matrix into them.

    :param task: Celery provided task object.
    :param folderId: The folder to train ITQ for, note this is only used to
        infer the descriptor index.
    :param kwargs: Extra task keyword arguments; accepted but unused.
    :raises ValueError: If the inferred descriptor index is empty.
    """
    task.job_manager.updateProgress(message='Training ITQ', forceFlush=True)
    girderClient = task.girder_client
    index = descriptorIndexFromFolderId(girderClient, folderId)

    if not index.count():
        # TODO SMQTK should account for this?
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers still catch it.
        raise ValueError('Descriptor index is empty, cannot train ITQ.')

    # Resolve API location and credentials once, before touching Girder, so
    # a malformed job spec fails before any items are created.
    apiUrl = task.request.apiUrl
    token = task.request.jobInfoSpec['headers']['Girder-Token']

    smqtkFolder = getCreateFolder(girderClient, folderId, '.smqtk')

    def _cacheElement(name):
        # Make item ``name`` in the .smqtk folder (createOverwriteItem),
        # initialize it with a file (initializeItemWithFile) and wrap that
        # file as a data element usable as an ItqFunctor cache.
        fileDoc = initializeItemWithFile(
            girderClient,
            createOverwriteItem(girderClient, smqtkFolder['_id'], name))
        return GirderDataElement(fileDoc['_id'], api_root=apiUrl,
                                 token=token)

    functor = ItqFunctor(mean_vec_cache=_cacheElement('mean_vec.npy'),
                         rotation_cache=_cacheElement('rotation.npy'))

    # NOTE(review): multiprocessing is explicitly disabled, presumably to
    # keep all work inside the Celery worker process -- confirm.
    functor.fit(index.iterdescriptors(), use_multiprocessing=False)