示例#1
0
    def test_encode_decode_type_spec(self):
        """Round-trip a TensorSpec and reject an unregistered TypeSpec name.

        A valid `TensorSpec` must decode back to an equal spec. A dict tagged
        with `'class_name': 'TypeSpec'` whose `'type_spec'` value is not a
        registered spec name must make `json_utils.decode` raise a ValueError
        mentioning that no TypeSpec has been registered.
        """
        spec = tensor_spec.TensorSpec((1, 5), dtypes.float32)
        string = json_utils.Encoder().encode(spec)
        loaded = json_utils.decode(string)
        self.assertEqual(spec, loaded)

        invalid_type_spec = {
            'class_name': 'TypeSpec',
            'type_spec': 'Invalid Type',
            'serialized': None
        }
        string = json_utils.Encoder().encode(invalid_type_spec)
        # `assertRaisesRegexp` is a deprecated alias that was removed from
        # unittest in Python 3.12; `assertRaisesRegex` is the supported name.
        with self.assertRaisesRegex(ValueError,
                                    'No TypeSpec has been registered'):
            json_utils.decode(string)
示例#2
0
    def test_encode_decode_tuple(self):
        """Check that tuples, including ones nested in a list, round-trip.

        The decoded mapping must have exactly the original keys, and each
        value must compare equal to the value that was encoded.
        """
        original = {'key1': (3, 5), 'key2': [(1, (3, 4)), (1, )]}
        decoded = json_utils.decode(json_utils.Encoder().encode(original))

        self.assertEqual(set(decoded.keys()), {'key1', 'key2'})
        for key, expected in original.items():
            self.assertAllEqual(decoded[key], expected)
示例#3
0
    def tracking_metadata(self):
        """Return the string stored in the metadata field of the SavedModel proto.

        Returns:
          A serialized JSON string (produced by `json_utils.Encoder`) of
          `self.python_properties`, storing information necessary for
          recreating this layer.
        """
        # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
        # object is in the python property)
        encoder = json_utils.Encoder()
        return encoder.encode(self.python_properties)

    def test_encode_decode_enum(self):
        """Check that Enum members decode to their plain underlying values."""

        class Letter(enum.Enum):
            CLASS_A = 'a'
            CLASS_B = 'b'

        encoded = json_utils.Encoder().encode(
            {'key': Letter.CLASS_A, 'key2': Letter.CLASS_B})
        expected = {'key': 'a', 'key2': 'b'}
        self.assertAllEqual(expected, json_utils.decode(encoded))
示例#5
0
    def test_encode_decode_tensor_shape(self):
        """Round-trip TensorShapes of unknown and partially known dimensions.

        Covers a shape of unknown rank (`TensorShape(None)`) stored directly
        under a key, plus a rank-1 shape with an unknown dimension and a
        rank-3 shape with an unknown middle dimension, both nested in a list.
        """
        metadata = {
            'key1': tensor_shape.TensorShape(None),
            'key2': [
                tensor_shape.TensorShape([None]),
                tensor_shape.TensorShape([3, None, 5]),
            ],
        }
        string = json_utils.Encoder().encode(metadata)
        loaded = json_utils.decode(string)

        self.assertEqual(set(loaded.keys()), {'key1', 'key2'})
        # Check the unknown rank directly rather than routing None through a
        # numpy array comparison.
        self.assertIsNone(loaded['key1'].rank)
        # `as_list()` yields a Python list of dims (None for unknown), so a
        # plain list equality check states the intent directly.
        self.assertEqual(loaded['key2'][0].as_list(), [None])
        self.assertEqual(loaded['key2'][1].as_list(), [3, None, 5])