예제 #1
0
파일: train_itq.py 프로젝트: Kitware/SMQTK
def main():
    """
    Fit the configured ITQ functor on descriptors from a descriptor index.

    Configuration comes from ``bin_utils.utility_main_helper`` using the
    parsed command-line arguments. The keys read here are
    ``uuids_list_filepath``, ``itq_config`` and ``descriptor_index``.

    When ``uuids_list_filepath`` names an existing file, every non-blank
    line (with surrounding whitespace stripped) is used as a descriptor UUID
    and only those descriptors are passed to ``functor.fit``. Otherwise the
    whole descriptor index is passed; if a path was configured but is not a
    file, a warning is logged before falling back.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):
        def uuids_iter():
            # Generator so the UUID list is streamed rather than loaded
            # fully into memory.
            with open(uuids_list_filepath) as f:
                for line in f:
                    uid = line.strip()
                    # Skip blank lines: they would otherwise be requested
                    # from the index as empty-string UUIDs.
                    if uid:
                        yield uid
        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # A path was configured but is unusable; make the fallback to
            # the full index visible instead of silent.
            log.warning("Configured UUIDs list file does not exist: %s; "
                        "falling back to all descriptors in the index",
                        uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
예제 #2
0
    def test_configuration_with_caches(self):
        """
        Construct an ItqFunctor from a config whose cache elements carry a
        pre-serialized mean vector and rotation matrix, then check that all
        configured parameters are reflected on the instance.
        """
        # This should run without error in both python
        # 2 and 3, as str/unicode are JSON compliant in both.
        mem_elem_type = \
            'smqtk.representation.data_element.memory_element.DataMemoryElement'

        def npy_config_str(array):
            # Serialize with numpy.save, then decode the raw bytes into the
            # string form placed under the 'bytes' config key.
            buf = six.BytesIO()
            # noinspection PyTypeChecker
            numpy.save(buf, array)
            return buf.getvalue().decode(BYTES_CONFIG_ENCODING)

        def memory_element_config(payload):
            # Plugin-style config dict selecting the in-memory data element.
            return {
                mem_elem_type: {'bytes': payload},
                'type': mem_elem_type,
            }

        mean_vec_str = npy_config_str(numpy.array([1, 2, 3]))
        rotation_str = npy_config_str(numpy.eye(3))

        overrides = {
            'mean_vec_cache': memory_element_config(mean_vec_str),
            'rotation_cache': memory_element_config(rotation_str),
            'bit_length': 153,
            'itq_iterations': 7,
            'normalize': 2,
            'random_seed': 58,
        }
        full_config = merge_dict(ItqFunctor.get_default_config(), overrides)

        itq = ItqFunctor.from_config(full_config)

        # The cache elements must yield the intended vector/matrix ...
        numpy.testing.assert_equal(itq.mean_vec, [1, 2, 3])
        numpy.testing.assert_equal(
            itq.rotation,
            [[1, 0, 0],
             [0, 1, 0],
             [0, 0, 1]],
        )
        # ... and the scalar parameters must match what was configured.
        self.assertEqual(itq.bit_length, 153)
        self.assertEqual(itq.itq_iterations, 7)
        self.assertEqual(itq.normalize, 2)
        self.assertEqual(itq.random_seed, 58)
예제 #3
0
def main():
    """
    Fit the configured ITQ functor on descriptors from a descriptor index.

    Arguments and configuration come from ``bin_utils.utility_main_helper``,
    which is given the default configuration and the tool description text.
    The configuration keys read here are ``uuids_list_filepath``,
    ``itq_config`` and ``descriptor_index``.

    When ``uuids_list_filepath`` names an existing file, every non-blank
    line (with surrounding whitespace stripped) is used as a descriptor UUID
    and only those descriptors are passed to ``functor.fit``. Otherwise the
    whole descriptor index is passed; if a path was configured but is not a
    file, a warning is logged before falling back.
    """
    description = """
    Tool for training the ITQ functor algorithm's model on descriptors in an
    index.

    By default, we use all descriptors in the configured index
    (``uuids_list_filepath`` is not given a value).

    The ``uuids_list_filepath`` configuration property is optional and should
    be used to specify a sub-set of descriptors in the configured index to
    train on. This only works if the stored descriptors' UUID is a type of
    string.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):

        def uuids_iter():
            # Generator so the UUID list is streamed rather than loaded
            # fully into memory.
            with open(uuids_list_filepath) as f:
                for line in f:
                    uid = line.strip()
                    # Skip blank lines: they would otherwise be requested
                    # from the index as empty-string UUIDs.
                    if uid:
                        yield uid

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # A path was configured but is unusable; make the fallback to
            # the full index visible instead of silent.
            log.warning("Configured UUIDs list file does not exist: %s; "
                        "falling back to all descriptors in the index",
                        uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
예제 #4
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
    def test_configuration_with_caches(self):
        """
        Construct an ItqFunctor from a config whose cache elements carry a
        pre-serialized mean vector and rotation matrix, then check that all
        configured parameters are reflected on the instance.
        """
        def npy_bytes(array):
            # Raw bytes produced by numpy.save for the given array.
            buf = BytesIO()
            # noinspection PyTypeChecker
            numpy.save(buf, array)
            return buf.getvalue()

        def memory_element_config(payload):
            # Plugin-style config dict selecting the in-memory data element.
            return {
                'DataMemoryElement': {'bytes': payload},
                'type': 'DataMemoryElement',
            }

        mean_vec_payload = npy_bytes(numpy.array([1, 2, 3]))
        rotation_payload = npy_bytes(numpy.eye(3))

        overrides = {
            'mean_vec_cache': memory_element_config(mean_vec_payload),
            'rotation_cache': memory_element_config(rotation_payload),
            'bit_length': 153,
            'itq_iterations': 7,
            'normalize': 2,
            'random_seed': 58,
        }
        full_config = merge_dict(ItqFunctor.get_default_config(), overrides)

        itq = ItqFunctor.from_config(full_config)

        # The cache elements must yield the intended vector/matrix ...
        numpy.testing.assert_equal(itq.mean_vec, [1, 2, 3])
        numpy.testing.assert_equal(
            itq.rotation,
            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
        )
        # ... and the scalar parameters must match what was configured.
        self.assertEqual(itq.bit_length, 153)
        self.assertEqual(itq.itq_iterations, 7)
        self.assertEqual(itq.normalize, 2)
        self.assertEqual(itq.random_seed, 58)
예제 #5
0
    def test_configuration_with_caches(self):
        """
        Check that an ItqFunctor built from a config with populated cache
        elements exposes the cached mean vector / rotation matrix and the
        configured scalar parameters.
        """
        # Serialize the two model arrays with numpy.save, mean vector first.
        serialized = []
        for array in (numpy.array([1, 2, 3]), numpy.eye(3)):
            stream = BytesIO()
            # noinspection PyTypeChecker
            numpy.save(stream, array)
            serialized.append(stream.getvalue())
        mean_vec_raw, rotation_raw = serialized

        element_type = 'DataMemoryElement'
        config_overrides = {
            'mean_vec_cache': {
                element_type: {'bytes': mean_vec_raw},
                'type': element_type,
            },
            'rotation_cache': {
                element_type: {'bytes': rotation_raw},
                'type': element_type,
            },
            'bit_length': 153,
            'itq_iterations': 7,
            'normalize': 2,
            'random_seed': 58,
        }
        merged = merge_dict(ItqFunctor.get_default_config(),
                            config_overrides)

        itq = ItqFunctor.from_config(merged)

        # Loaded cache contents must match the arrays serialized above.
        numpy.testing.assert_equal(itq.mean_vec, [1, 2, 3])
        numpy.testing.assert_equal(itq.rotation, [[1, 0, 0],
                                                  [0, 1, 0],
                                                  [0, 0, 1]])
        # Scalar parameters must match the overrides.
        self.assertEqual(itq.bit_length, 153)
        self.assertEqual(itq.itq_iterations, 7)
        self.assertEqual(itq.normalize, 2)
        self.assertEqual(itq.random_seed, 58)
예제 #6
0
파일: train_itq.py 프로젝트: dhandeo/SMQTK
def main():
    """
    Fit the configured ITQ functor on descriptors from a descriptor index.

    Arguments and configuration come from ``bin_utils.utility_main_helper``,
    which is given the default configuration and the tool description text.
    The configuration keys read here are ``uuids_list_filepath``,
    ``itq_config`` and ``descriptor_index``.

    When ``uuids_list_filepath`` names an existing file, every non-blank
    line (with surrounding whitespace stripped) is used as a descriptor UUID
    and only those descriptors are passed to ``functor.fit``. Otherwise the
    whole descriptor index is passed; if a path was configured but is not a
    file, a warning is logged before falling back.
    """
    description = """
    Tool for training the ITQ functor algorithm's model on descriptors in an
    index.

    By default, we use all descriptors in the configured index
    (``uuids_list_filepath`` is not given a value).

    The ``uuids_list_filepath`` configuration property is optional and should
    be used to specify a sub-set of descriptors in the configured index to
    train on. This only works if the stored descriptors' UUID is a type of
    string.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):
        def uuids_iter():
            # Generator so the UUID list is streamed rather than loaded
            # fully into memory.
            with open(uuids_list_filepath) as f:
                for line in f:
                    uid = line.strip()
                    # Skip blank lines: they would otherwise be requested
                    # from the index as empty-string UUIDs.
                    if uid:
                        yield uid
        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # A path was configured but is unusable; make the fallback to
            # the full index visible instead of silent.
            log.warning("Configured UUIDs list file does not exist: %s; "
                        "falling back to all descriptors in the index",
                        uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
예제 #7
0
파일: train_itq.py 프로젝트: fangi22/SMQTK
def main():
    """
    Fit the configured ITQ functor on descriptors from a descriptor set.

    Configuration comes from ``cli.utility_main_helper`` using the parsed
    command-line arguments. The keys read here are ``uuids_list_filepath``,
    ``itq_config`` and ``descriptor_set``.

    When ``uuids_list_filepath`` names an existing file, every non-blank
    line (with surrounding whitespace stripped) is used as a descriptor UUID
    and only those descriptors are passed to ``functor.fit``. Otherwise the
    whole descriptor set is passed; if a path was configured but is not a
    file, a warning is logged before falling back.
    """
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorSet [type=%s]",
             config['descriptor_set']['type'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(
        config['descriptor_set'],
        DescriptorSet.get_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):

        def uuids_iter():
            # Generator so the UUID list is streamed rather than loaded
            # fully into memory.
            with open(uuids_list_filepath) as f:
                for line in f:
                    uid = line.strip()
                    # Skip blank lines: they would otherwise be requested
                    # from the set as empty-string UUIDs.
                    if uid:
                        yield uid

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_set.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # A path was configured but is unusable; make the fallback to
            # the full set visible instead of silent.
            log.warning("Configured UUIDs list file does not exist: %s; "
                        "falling back to all descriptors in the set",
                        uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorSet (count=%d)",
                 len(descriptor_set))
        d_iter = descriptor_set

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
예제 #8
0
 def test_default_configuration(self):
     """The default config must round-trip through from_config/get_config."""
     default_cfg = ItqFunctor.get_default_config()
     functor = ItqFunctor.from_config(default_cfg)
     self.assertEqual(functor.get_config(), default_cfg)
예제 #9
0
파일: test_itq.py 프로젝트: Kitware/SMQTK
 def test_default_configuration(self):
     """
     An ItqFunctor built from the default configuration should report
     that same configuration back from ``get_config``.
     """
     defaults = ItqFunctor.get_default_config()
     reported = ItqFunctor.from_config(defaults).get_config()
     self.assertEqual(reported, defaults)