def test_encode_decode_type_spec(self):
    """TensorSpec round-trips through the JSON encoder; an unregistered
    TypeSpec name raises ValueError on decode."""
    spec = tf.TensorSpec((1, 5), tf.float32)
    string = json_utils.Encoder().encode(spec)
    loaded = json_utils.decode(string)
    self.assertEqual(spec, loaded)

    # A TypeSpec payload whose class was never registered must fail to decode.
    invalid_type_spec = {
        "class_name": "TypeSpec",
        "type_spec": "Invalid Type",
        "serialized": None,
    }
    string = json_utils.Encoder().encode(invalid_type_spec)
    # assertRaisesRegexp is a deprecated alias (since Python 3.2); use the
    # non-deprecated assertRaisesRegex.
    with self.assertRaisesRegex(ValueError, "No TypeSpec has been registered"):
        loaded = json_utils.decode(string)
def test_encode_decode_type_spec(self):
    """TensorSpec round-trips through the JSON encoder; an unregistered
    TypeSpec name raises ValueError on decode."""
    spec = tf.TensorSpec((1, 5), tf.float32)
    string = json_utils.Encoder().encode(spec)
    loaded = json_utils.decode(string)
    self.assertEqual(spec, loaded)

    # A TypeSpec payload whose class was never registered must fail to decode.
    invalid_type_spec = {
        'class_name': 'TypeSpec',
        'type_spec': 'Invalid Type',
        'serialized': None
    }
    string = json_utils.Encoder().encode(invalid_type_spec)
    # assertRaisesRegexp is a deprecated alias (since Python 3.2); use the
    # non-deprecated assertRaisesRegex.
    with self.assertRaisesRegex(ValueError, 'No TypeSpec has been registered'):
        loaded = json_utils.decode(string)
def test_encode_decode_tuple(self):
    """Tuples (including nested ones) survive a JSON round trip."""
    original = {"key1": (3, 5), "key2": [(1, (3, 4)), (1, )]}
    decoded = json_utils.decode(json_utils.Encoder().encode(original))
    self.assertEqual(set(decoded.keys()), {"key1", "key2"})
    self.assertAllEqual(decoded["key1"], (3, 5))
    self.assertAllEqual(decoded["key2"], [(1, (3, 4)), (1, )])
def test_encode_decode_tuple(self):
    """Tuples (including nested ones) survive a JSON round trip."""
    original = {'key1': (3, 5), 'key2': [(1, (3, 4)), (1, )]}
    decoded = json_utils.decode(json_utils.Encoder().encode(original))
    self.assertEqual(set(decoded.keys()), {'key1', 'key2'})
    self.assertAllEqual(decoded['key1'], (3, 5))
    self.assertAllEqual(decoded['key2'], [(1, (3, 4)), (1, )])
def test_encode_decode_enum(self):
    """Enum members are encoded by value and decoded as plain strings."""

    class Enum(enum.Enum):
        CLASS_A = "a"
        CLASS_B = "b"

    payload = {"key": Enum.CLASS_A, "key2": Enum.CLASS_B}
    decoded = json_utils.decode(json_utils.Encoder().encode(payload))
    self.assertAllEqual({"key": "a", "key2": "b"}, decoded)
def test_encode_decode_enum(self):
    """Enum members are encoded by value and decoded as plain strings."""

    class Enum(enum.Enum):
        CLASS_A = 'a'
        CLASS_B = 'b'

    payload = {'key': Enum.CLASS_A, 'key2': Enum.CLASS_B}
    decoded = json_utils.decode(json_utils.Encoder().encode(payload))
    self.assertAllEqual({'key': 'a', 'key2': 'b'}, decoded)
def tracking_metadata(self):
    """String stored in metadata field in the SavedModel proto.

    Returns:
      A serialized JSON storing information necessary for recreating this
      layer.
    """
    # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
    # object is in the python property)
    encoder = json_utils.Encoder()
    return encoder.encode(self.python_properties)
def test_encode_decode_extension_type_tensor(self):
    """ExtensionType values round-trip through the JSON encoder."""

    class MaskedTensor(tf.experimental.ExtensionType):
        __name__ = 'MaskedTensor'
        values: tf.Tensor
        mask: tf.Tensor

    original = MaskedTensor(
        values=[[1, 2, 3], [4, 5, 6]],
        mask=[[True, True, False], [True, False, True]])
    restored = json_utils.decode(json_utils.Encoder().encode(original))
    self.assertAllEqual(restored, original)
def test_encode_decode_tensor_shape(self):
    """TensorShapes (unknown rank and partial dims) survive a round trip."""
    original = {
        "key1": tf.TensorShape(None),
        "key2": [tf.TensorShape([None]), tf.TensorShape([3, None, 5])],
    }
    restored = json_utils.decode(json_utils.Encoder().encode(original))
    self.assertEqual(set(restored.keys()), {"key1", "key2"})
    # Unknown-rank shape decodes with rank None.
    self.assertAllEqual(restored["key1"].rank, None)
    self.assertAllEqual(restored["key2"][0].as_list(), [None])
    self.assertAllEqual(restored["key2"][1].as_list(), [3, None, 5])
def test_encode_decode_tensor_shape(self):
    """TensorShapes (unknown rank and partial dims) survive a round trip."""
    original = {
        'key1': tf.TensorShape(None),
        'key2': [tf.TensorShape([None]), tf.TensorShape([3, None, 5])]
    }
    restored = json_utils.decode(json_utils.Encoder().encode(original))
    self.assertEqual(set(restored.keys()), {'key1', 'key2'})
    # Unknown-rank shape decodes with rank None.
    self.assertAllEqual(restored['key1'].rank, None)
    self.assertAllEqual(restored['key2'][0].as_list(), [None])
    self.assertAllEqual(restored['key2'][1].as_list(), [3, None, 5])
def _update_to_current_version(metadata):
    """Applies version updates to the metadata proto for backwards compat."""
    # Only version-1 model-like nodes need updating.
    updatable_identifiers = [
        constants.MODEL_IDENTIFIER, constants.SEQUENTIAL_IDENTIFIER,
        constants.NETWORK_IDENTIFIER]
    for node in metadata.nodes:
        if node.version.producer != 1:
            continue
        if node.identifier not in updatable_identifiers:
            continue
        node_metadata = json_utils.decode(node.metadata)
        save_spec = node_metadata.get('save_spec')
        if save_spec is not None:
            # Newer versions store the call signature as (args, kwargs).
            node_metadata['full_save_spec'] = ([save_spec], {})
            node.metadata = json_utils.Encoder().encode(node_metadata)
    return metadata
def testAddFullSaveSpec(self):
    """Version update wraps a v1 'save_spec' into 'full_save_spec'."""
    save_spec = tf.TensorSpec([3, 5], dtype=tf.int32)
    encoded_metadata = json_utils.Encoder().encode({'save_spec': save_spec})

    metadata = saved_metadata_pb2.SavedMetadata()
    metadata.nodes.add(
        version=versions_pb2.VersionDef(
            producer=1, min_consumer=1, bad_consumers=[]),
        identifier='_tf_keras_model',
        metadata=encoded_metadata)

    # pylint: disable=protected-access
    updated = keras_load._update_to_current_version(metadata)
    node_metadata = json_utils.decode(updated.nodes[0].metadata)

    expected_full_spec = ([tf.TensorSpec(shape=(3, 5), dtype=tf.int32)], {})
    self.assertAllEqual(expected_full_spec, node_metadata.get('full_save_spec'))
def _serialize_keras_tensor(t):
    """Serializes a single Tensor passed to `call`."""
    # Symbolic Keras tensors serialize to [layer name, node index, tensor
    # index]; the node index is remapped through node_conversion_map.
    if hasattr(t, '_keras_history'):
        history = t._keras_history
        node_key = make_node_key(history.layer.name, history.node_index)
        remapped_node_index = node_conversion_map.get(node_key, 0)
        return [history.layer.name, remapped_node_index, history.tensor_index]

    if isinstance(t, np.ndarray):
        return t.tolist()
    if isinstance(t, tf.Tensor):
        return backend.get_value(t).tolist()

    # Not using json_utils to serialize both constant Tensor and constant
    # CompositeTensor for saving format backward compatibility.
    if isinstance(t, tf.__internal__.CompositeTensor):
        return (_COMPOSITE_TYPE, json_utils.Encoder().encode(t))

    return t
def test_encode_decode_ragged_tensor(self):
    """RaggedTensors round-trip through the JSON encoder."""
    original = tf.ragged.constant([[1.0, 2.0], [3.0]])
    restored = json_utils.decode(json_utils.Encoder().encode(original))
    self.assertAllEqual(restored, original)