コード例 #1
0
def test_to_config_dict():
    """
    Test that the second-level helper function ``cls_conf_to_config_dict`` is
    called appropriately and that its result is directly returned.
    """
    expected_ret_val = 'expected return value'

    # (T1 constructor kwargs, configuration the helper should receive)
    cases = (
        ({}, {'foo': 1, 'bar': 'baz'}),
        ({'foo': 8}, {'foo': 8, 'bar': 'baz'}),
    )
    for init_kwargs, expected_conf in cases:
        # Patch afresh for each case so ``assert_called_once_with`` only sees
        # the single call made by that case.
        with mock.patch('smqtk.utils.configuration.cls_conf_to_config_dict') \
                as m_cctcd:
            m_cctcd.return_value = expected_ret_val
            inst = T1(**init_kwargs)
            ret = to_config_dict(inst)
            m_cctcd.assert_called_once_with(T1, expected_conf)
            assert ret == expected_ret_val
コード例 #2
0
    def test_config_cycle_imagemean_nonetyped(self, m_cdg_setup_network):
        """
        Test being able to get an instance's config and use that config to
        construct an equivalently parameterized instance, where the second
        instance is configured with a None-typed 'image_mean' parameter.
        """
        # ``_setup_network`` is mocked (``m_cdg_setup_network``) so no caffe
        # functionality is hit during this test.

        # First generator: only the required parameters, image_mean not given.
        gen_a = CaffeDescriptorGenerator(self.dummy_net_topo_elem,
                                         self.dummy_caffe_model_elem)
        config_a = gen_a.get_config()

        # Second generator: a copy of the first config with image_mean passed
        # as a None-typed plugin configuration.
        config_for_b = dict(config_a, image_mean={'type': None})
        gen_b = CaffeDescriptorGenerator.from_config(config_for_b)

        expected_config = {
            'network_prototxt': to_config_dict(self.dummy_net_topo_elem),
            'network_model': to_config_dict(self.dummy_caffe_model_elem),
            'image_mean': None,
            'return_layer': 'fc7',
            'batch_size': 1,
            'use_gpu': False,
            'gpu_device_id': 0,
            'network_is_bgr': True,
            'data_layer': 'data',
            'load_truncated_images': False,
            'pixel_rescale': None,
            'input_scale': None,
            'threads': None,
        }
        assert config_a == gen_b.get_config() == expected_config
コード例 #3
0
    def get_config(self):
        """
        Return a JSON-compliant dictionary that could be passed to this
        class's ``from_config`` method to produce an instance with identical
        configuration.

        Keys are named after the constructor's arguments, as if the dictionary
        were to be passed to the constructor via keyword expansion. The
        ``network_prototxt``, ``network_model`` and ``image_mean`` components
        are converted with ``to_config_dict``; an unset (``None``)
        ``image_mean`` is reported as ``None``.

        :return: JSON type compliant configuration dictionary.
        :rtype: dict

        """
        image_mean = self.image_mean
        image_mean_config = (None if image_mean is None
                             else to_config_dict(image_mean))
        config = {
            "network_prototxt": to_config_dict(self.network_prototxt),
            "network_model": to_config_dict(self.network_model),
            "image_mean": image_mean_config,
        }
        # Remaining parameters are copied straight from the same-named
        # attributes.
        for key in ("return_layer", "batch_size", "use_gpu", "gpu_device_id",
                    "network_is_bgr", "data_layer", "load_truncated_images",
                    "pixel_rescale", "input_scale", "threads"):
            config[key] = getattr(self, key)
        return config
コード例 #4
0
ファイル: __init__.py プロジェクト: sanyarud/SMQTK
 def get_config(self):
     """
     Return this instance's configuration as a JSON-compliant dictionary.

     Nested components are converted with ``to_config_dict``. ``hash_index``
     is optional and is reported as ``None`` when not set.

     :return: JSON type compliant configuration dictionary.
     :rtype: dict
     """
     if self.hash_index is None:
         hash_index_config = None
     else:
         hash_index_config = to_config_dict(self.hash_index)
     return dict(
         lsh_functor=to_config_dict(self.lsh_functor),
         descriptor_index=to_config_dict(self.descriptor_index),
         hash_index=hash_index_config,
         hash2uuids_kvstore=to_config_dict(self.hash2uuids_kvstore),
         distance_method=self.distance_method,
         read_only=self.read_only,
     )
コード例 #5
0
def test_to_config_dict_given_non_configurable():
    """
    Test that ``to_config_dict`` errors when passed an instance that does not
    descend from ``Configurable``.
    """
    class SomeOtherClassType(object):
        pass

    inst = SomeOtherClassType()
    # Raw strings: ``\.`` is a regex escape for a literal period. In a normal
    # string literal it is an invalid escape sequence (DeprecationWarning,
    # SyntaxWarning on Python 3.12+), even though the resulting value is the
    # same.
    with pytest.raises(ValueError,
                       match=r"c_inst must be an instance and its type must "
                             r"subclass from Configurable\."):
        # noinspection PyTypeChecker
        to_config_dict(inst)
コード例 #6
0
 def get_config(self):
     """
     Return this instance's configuration as a JSON-compliant dictionary.

     Starts from the class default configuration and overlays the current
     ``bit_length``, ``itq_iterations``, ``normalize`` and ``random_seed``.
     The ``mean_vec_cache`` / ``rotation_cache`` entries are only replaced
     when the corresponding cache element is set (truthy); otherwise the
     default plugin configurations are kept.

     :return: JSON type compliant configuration dictionary.
     :rtype: dict
     """
     defaults = self.get_default_config()
     config = merge_dict(defaults, {
         "bit_length": self.bit_length,
         "itq_iterations": self.itq_iterations,
         "normalize": self.normalize,
         "random_seed": self.random_seed,
     })
     cache_entries = (
         ("mean_vec_cache", self.mean_vec_cache_elem),
         ("rotation_cache", self.rotation_cache_elem),
     )
     for key, elem in cache_entries:
         if elem:
             config[key] = to_config_dict(elem)
     return config
コード例 #7
0
    def test_get_config(self, _m_cdg_setupNetwork):
        """
        Test that ``get_config`` reports every constructor parameter, with the
        DataElement parameters converted to nested configuration dictionaries.
        """
        # ``_setupNetwork`` is mocked (``_m_cdg_setupNetwork``) so nothing
        # caffe-related is actually initialized by this test.
        expected_params = {
            'network_prototxt': DataMemoryElement(),
            'network_model': DataMemoryElement(),
            'image_mean': DataMemoryElement(),
            'return_layer': 'layer name',
            'batch_size': 777,
            'use_gpu': False,
            'gpu_device_id': 8,
            'network_is_bgr': False,
            'data_layer': 'data-other',
            'load_truncated_images': True,
            'pixel_rescale': (.2, .8),
            'input_scale': 1.5,
            'threads': 14,
        }
        # Guard against constructor parameters being added without this test
        # accounting for them.
        default_keys = set(CaffeDescriptorGenerator.get_default_config())
        assert default_keys == set(expected_params)

        g = CaffeDescriptorGenerator(**expected_params)

        # get_config() reports DataElement parameters as sub-configurations.
        elem_keys = ('network_prototxt', 'network_model', 'image_mean')
        expected_params.update(
            {k: to_config_dict(expected_params[k]) for k in elem_keys}
        )
        assert g.get_config() == expected_params
コード例 #8
0
 def get_config(self):
     """
     Return this instance's configuration as a JSON-compliant dictionary.

     Starts from the class default configuration and overlays the current
     ``pickle_protocol``. When a cache element is set, its configuration is
     merged into the ``cache_element`` sub-configuration.

     :return: JSON type compliant configuration dictionary.
     :rtype: dict
     """
     c = merge_dict(self.get_default_config(), {
         "pickle_protocol": self.pickle_protocol,
     })
     if self.cache_element:
         # Store the merge result explicitly instead of discarding it and
         # relying on ``merge_dict`` mutating its first argument in place.
         c['cache_element'] = merge_dict(c['cache_element'],
                                         to_config_dict(self.cache_element))
     return c
コード例 #9
0
ファイル: test_caffe.py プロジェクト: sanyarud/SMQTK
 def test_get_config(self, _m_cdg_setupNetwork):
     """
     Test that ``get_config`` reports every constructor parameter, with the
     DataElement parameters converted to nested configuration dictionaries.
     """
     # Mocking set_network so we don't have to worry about actually
     # initializing any caffe things for this test.
     expected_params = {
         'network_prototxt': DataMemoryElement(),
         'network_model': DataMemoryElement(),
         'image_mean': DataMemoryElement(),
         'return_layer': 'layer name',
         'batch_size': 777,
         'use_gpu': False,
         'gpu_device_id': 8,
         'network_is_bgr': False,
         'data_layer': 'data-other',
         'load_truncated_images': True,
         'pixel_rescale': (.2, .8),
         'input_scale': 1.5,
     }
     # make sure that we're considering all constructor parameter options
     try:
         # ``inspect.getargspec`` was deprecated in Python 3.0 and removed in
         # 3.11; ``getfullargspec`` exposes the same ``args`` list.
         get_arg_spec = inspect.getfullargspec
     except AttributeError:
         # Python 2 only provides ``getargspec``.
         get_arg_spec = inspect.getargspec
     expected_param_keys = \
         set(get_arg_spec(CaffeDescriptorGenerator.__init__).args[1:])
     self.assertSetEqual(set(expected_params.keys()),
                         expected_param_keys)
     g = CaffeDescriptorGenerator(**expected_params)
     # get_config() reports DataElement parameters as sub-configurations.
     for key in ('network_prototxt', 'network_model', 'image_mean'):
         expected_params[key] = to_config_dict(expected_params[key])
     self.assertEqual(g.get_config(), expected_params)
コード例 #10
0
 def get_config(self):
     """
     Return this instance's configuration as a JSON-compliant dictionary.

     The only key is ``cache_element``. If the cache element has a
     ``get_config`` attribute, its configuration is produced by
     ``to_config_dict``. Otherwise (e.g. no cache element), the default
     configuration over the ``DataElement`` implementations, with no type
     chosen, is reported.

     :return: JSON type compliant configuration dictionary.
     :rtype: dict
     """
     cache = self._cache_element
     if not hasattr(cache, 'get_config'):
         # No usable cache element: fall back to the default plugin config.
         return {'cache_element': make_default_config(DataElement.get_impls())}
     return {'cache_element': to_config_dict(cache)}
コード例 #11
0
 def get_config(self):
     """
     Return this instance's configuration as a JSON-compliant dictionary.

     The class default configuration is updated with the current
     ``leaf_size`` and ``random_seed``. When a cache element is set, its
     configuration is merged into the ``cache_element`` entry.

     :return: JSON type compliant configuration dictionary.
     :rtype: dict
     """
     config = self.get_default_config()
     config = merge_dict(config, {
         'leaf_size': self.leaf_size,
         'random_seed': self.random_seed,
     })
     if not self.cache_element:
         return config
     config['cache_element'] = merge_dict(
         config['cache_element'], to_config_dict(self.cache_element))
     return config
コード例 #12
0
    def get_config(self):
        """
        Return this instance's configuration as a JSON-compliant dictionary.

        The ``descriptor_set``, ``uid2idx_kvs`` and ``idx2uid_kvs`` components
        are converted with ``to_config_dict``. The ``index_element`` and
        ``index_param_element`` entries are only included when the
        corresponding element is set.

        :return: JSON type compliant configuration dictionary.
        :rtype: dict
        """
        config = dict(
            descriptor_set=to_config_dict(self._descriptor_set),
            uid2idx_kvs=to_config_dict(self._uid2idx_kvs),
            idx2uid_kvs=to_config_dict(self._idx2uid_kvs),
            factory_string=self.factory_string,
            ivf_nprobe=self._ivf_nprobe,
            read_only=self.read_only,
            random_seed=self.random_seed,
            use_gpu=self._use_gpu,
            gpu_id=self._gpu_id,
        )
        optional_elements = (
            ('index_element', self._index_element),
            ('index_param_element', self._index_param_element),
        )
        for key, element in optional_elements:
            if element:
                config[key] = to_config_dict(element)
        return config
コード例 #13
0
ファイル: mrpt.py プロジェクト: vishalbelsare/SMQTK
 def get_config(self):
     """
     Return this instance's configuration as a JSON-compliant dictionary.

     The descriptor set is converted with ``to_config_dict``. All other
     entries are copied from the corresponding private attributes.

     :return: JSON type compliant configuration dictionary.
     :rtype: dict
     """
     config = {"descriptor_set": to_config_dict(self._descriptor_set)}
     config.update({
         "index_filepath": self._index_filepath,
         "parameters_filepath": self._index_param_filepath,
         "read_only": self._read_only,
         "random_seed": self._rand_seed,
         "pickle_protocol": self._pickle_protocol,
         "use_multiprocessing": self._use_multiprocessing,
         "depth": self._depth,
         "num_trees": self._num_trees,
     })
     return config
コード例 #14
0
def test_to_config_dict_given_type():
    """
    Test that ``to_config_dict`` errors when passed a type rather than an
    instance.
    """
    # Raw strings: ``\.`` is a regex escape for a literal period. In a normal
    # string literal it is an invalid escape sequence (DeprecationWarning,
    # SyntaxWarning on Python 3.12+).
    expected_msg = (r"c_inst must be an instance and its type must "
                    r"subclass from Configurable\.")

    # Just with `object`.
    with pytest.raises(ValueError, match=expected_msg):
        # noinspection PyTypeChecker
        to_config_dict(object)

    # Literally the Configurable interface (abstract class)
    with pytest.raises(ValueError, match=expected_msg):
        # noinspection PyTypeChecker
        to_config_dict(Configurable)

    # New sub-class implementing Configurable
    class SomeConfigurableType(Configurable):
        def get_config(self):
            return {}

    with pytest.raises(ValueError):
        # noinspection PyTypeChecker
        to_config_dict(SomeConfigurableType)
コード例 #15
0
    def get_default_config(cls):
        """
        Generate and return a default configuration dictionary for this class.

        It is not guaranteed that the configuration dictionary returned from
        this method is valid for construction of an instance of this class.

        The ``kvstore`` entry inherited from the parent's default
        configuration is converted with ``to_config_dict``. It is then merged
        over a default plugin configuration built from
        ``KeyValueStore.get_impls()``.

        :return: Default configuration dictionary for the class.
        :rtype: dict
        """
        config = super(KVSDataSet, cls).get_default_config()
        plugin_defaults = make_default_config(KeyValueStore.get_impls())
        kvstore_config = to_config_dict(config['kvstore'])
        config['kvstore'] = merge_dict(plugin_defaults, kvstore_config)
        return config
コード例 #16
0
ファイル: memory_set.py プロジェクト: sanyarud/SMQTK
    def get_config(self):
        """
        Return a JSON-compliant configuration dictionary for this instance.

        The class default configuration is updated with the current
        ``pickle_protocol``. When a cache element is set, its configuration
        is merged into the ``cache_element`` sub-configuration.

        :return: JSON type compliant configuration dictionary.
        :rtype: dict

        """
        config = self.get_default_config()
        config = merge_dict(config, {"pickle_protocol": self.pickle_protocol})
        if not self.cache_element:
            return config
        config['cache_element'] = merge_dict(
            config['cache_element'], to_config_dict(self.cache_element))
        return config
コード例 #17
0
    def get_config(self):
        """
        Return a JSON-compliant dictionary that could be passed to this
        class's ``from_config`` method to produce an instance with identical
        configuration.

        In most cases, the dictionary keys are named after the constructor's
        parameters, as if the dictionary were passed to the constructor via
        keyword expansion. Some constructor parameters do not make sense to
        store as configuration values (i.e. they must be supplied at
        runtime), and this method's returned dictionary may leave them out.
        In such cases, the object's ``from_config`` class-method would also
        take additional positional arguments to fill in the missing
        parameters.

        Here the configuration consists solely of the ``image_reader``
        component's configuration.

        :return: JSON type compliant configuration dictionary.
        :rtype: dict

        """
        reader_config = to_config_dict(self._image_reader)
        return {'image_reader': reader_config}
コード例 #18
0
 def get_config(self):
     """
     Return a JSON-compliant configuration dictionary for this instance.

     :return: Dictionary whose only key, ``kvstore``, holds the configuration
         of ``self._kvstore`` as produced by ``to_config_dict``.
     :rtype: dict
     """
     kvstore_config = to_config_dict(self._kvstore)
     return dict(kvstore=kvstore_config)
コード例 #19
0
 def get_config(self):
     """
     Return a JSON-compliant configuration dictionary for this instance.

     Maps each label to the configuration of its classifier. The mapping is
     read while holding ``_label_to_classifier_lock``, presumably so it is
     not modified while being read.

     :return: ``{label: classifier configuration}`` dictionary.
     :rtype: dict
     """
     with self._label_to_classifier_lock:
         return {
             label: to_config_dict(classifier)
             for label, classifier in six.iteritems(self._label_to_classifier)
         }
コード例 #20
0
ファイル: iqr_search.py プロジェクト: vishalbelsare/SMQTK
 def get_config(self):
     """
     Return a JSON-compliant configuration dictionary for this instance.

     :return: Dictionary holding the IQR service URL, the working directory
         and the data set's configuration (via ``to_config_dict``).
     :rtype: dict
     """
     config = dict(
         iqr_service_url=self._iqr_service.url,
         working_directory=self._working_dir,
     )
     config['data_set'] = to_config_dict(self._data_set)
     return config
コード例 #21
0
ファイル: linear.py プロジェクト: sanyarud/SMQTK
 def get_config(self):
     """
     Return a JSON-compliant configuration dictionary for this instance.

     Based on the class default configuration. When a cache element is set,
     its configuration is merged into the ``cache_element`` entry.

     :return: JSON type compliant configuration dictionary.
     :rtype: dict
     """
     config = self.get_default_config()
     cache_elem = self.cache_element
     if cache_elem:
         config['cache_element'] = merge_dict(config['cache_element'],
                                              to_config_dict(cache_elem))
     return config