예제 #1
0
    def test_configuration_with_caches(self) -> None:
        """Build an ItqFunctor whose caches are in-memory elements and check it.

        The mean vector and rotation matrix are serialized with ``numpy.save``
        and embedded, as decoded strings, in ``DataMemoryElement`` cache
        configurations. Constructing the functor via ``from_config`` on the
        merged configuration should load those arrays back and carry over the
        scalar parameters unchanged.
        """
        expected_mean_vec = numpy.array([1, 2, 3])
        expected_rotation = numpy.eye(3)
        mem_elem_type = \
            'smqtk_dataprovider.impls.data_element.memory.DataMemoryElement'

        def to_cache_config(arr: numpy.ndarray) -> dict:
            # Serialize ``arr`` in .npy format and decode the bytes to ``str``
            # so the payload is JSON compliant inside the configuration.
            buff = BytesIO()
            # noinspection PyTypeChecker
            numpy.save(buff, arr)
            return {
                mem_elem_type: {
                    'bytes': buff.getvalue().decode(BYTES_CONFIG_ENCODING)
                },
                'type': mem_elem_type,
            }

        new_parts = {
            'mean_vec_cache': to_cache_config(expected_mean_vec),
            'rotation_cache': to_cache_config(expected_rotation),
            'bit_length': 153,
            'itq_iterations': 7,
            'normalize': 2,
            'random_seed': 58,
        }
        c = merge_dict(ItqFunctor.get_default_config(), new_parts)

        itq = ItqFunctor.from_config(c)

        # The cache elements must yield the arrays that were serialized into
        # them, and the scalar parameters must be set as configured.
        numpy.testing.assert_equal(itq.mean_vec, expected_mean_vec)
        numpy.testing.assert_equal(itq.rotation, expected_rotation)
        self.assertEqual(itq.bit_length, 153)
        self.assertEqual(itq.itq_iterations, 7)
        self.assertEqual(itq.normalize, 2)
        self.assertEqual(itq.random_seed, 58)
예제 #2
0
 def test_default_configuration(self) -> None:
     """Check the default configuration survives a from_config/get_config round trip."""
     default_cfg = ItqFunctor.get_default_config()
     round_tripped = ItqFunctor.from_config(default_cfg).get_config()
     self.assertEqual(round_tripped, default_cfg)