Beispiel #1
0
    def test_encode_decode_type_spec(self):
        """Round-trips a TensorSpec and rejects an unregistered TypeSpec name.

        A ``tf.TensorSpec`` encoded with ``json_utils.Encoder`` must decode
        back to an equal spec. A hand-built payload that names a TypeSpec
        ("Invalid Type") which was never registered must make
        ``json_utils.decode`` raise ``ValueError``.
        """
        spec = tf.TensorSpec((1, 5), tf.float32)
        string = json_utils.Encoder().encode(spec)
        loaded = json_utils.decode(string)
        self.assertEqual(spec, loaded)

        invalid_type_spec = {
            "class_name": "TypeSpec",
            "type_spec": "Invalid Type",
            "serialized": None,
        }
        string = json_utils.Encoder().encode(invalid_type_spec)
        # Use assertRaisesRegex: the old `assertRaisesRegexp` alias was
        # deprecated and has been removed from unittest in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    "No TypeSpec has been registered"):
            json_utils.decode(string)
Beispiel #2
0
    def test_encode_decode_type_spec(self):
        """Round-trips a TensorSpec and rejects an unregistered TypeSpec name.

        A ``tf.TensorSpec`` encoded with ``json_utils.Encoder`` must decode
        back to an equal spec. A hand-built payload that names a TypeSpec
        ('Invalid Type') which was never registered must make
        ``json_utils.decode`` raise ``ValueError``.
        """
        spec = tf.TensorSpec((1, 5), tf.float32)
        string = json_utils.Encoder().encode(spec)
        loaded = json_utils.decode(string)
        self.assertEqual(spec, loaded)

        invalid_type_spec = {
            'class_name': 'TypeSpec',
            'type_spec': 'Invalid Type',
            'serialized': None
        }
        string = json_utils.Encoder().encode(invalid_type_spec)
        # Use assertRaisesRegex: the old `assertRaisesRegexp` alias was
        # deprecated and has been removed from unittest in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    'No TypeSpec has been registered'):
            json_utils.decode(string)
Beispiel #3
0
    def test_encode_decode_tuple(self):
        """Tuples, including ones nested inside lists and tuples, round-trip."""
        original = {"key1": (3, 5), "key2": [(1, (3, 4)), (1,)]}
        encoded = json_utils.Encoder().encode(original)
        decoded = json_utils.decode(encoded)

        self.assertEqual(set(decoded.keys()), {"key1", "key2"})
        self.assertAllEqual(decoded["key1"], (3, 5))
        self.assertAllEqual(decoded["key2"], [(1, (3, 4)), (1,)])
Beispiel #4
0
    def test_encode_decode_tuple(self):
        """Tuples, including ones nested inside lists and tuples, round-trip."""
        original = {'key1': (3, 5), 'key2': [(1, (3, 4)), (1,)]}
        encoded = json_utils.Encoder().encode(original)
        decoded = json_utils.decode(encoded)

        self.assertEqual(set(decoded.keys()), {'key1', 'key2'})
        self.assertAllEqual(decoded['key1'], (3, 5))
        self.assertAllEqual(decoded['key2'], [(1, (3, 4)), (1,)])
Beispiel #5
0
    def test_encode_decode_enum(self):
        """Enum members decode back as their plain underlying values."""
        # Local name chosen so it does not read like the `enum.Enum` base.
        class Letter(enum.Enum):
            CLASS_A = "a"
            CLASS_B = "b"

        encoded = json_utils.Encoder().encode(
            {"key": Letter.CLASS_A, "key2": Letter.CLASS_B})
        decoded = json_utils.decode(encoded)
        self.assertAllEqual({"key": "a", "key2": "b"}, decoded)
Beispiel #6
0
    def test_encode_decode_enum(self):
        """Enum members decode back as their plain underlying values."""
        # Local name chosen so it does not read like the `enum.Enum` base.
        class Letter(enum.Enum):
            CLASS_A = 'a'
            CLASS_B = 'b'

        encoded = json_utils.Encoder().encode(
            {'key': Letter.CLASS_A, 'key2': Letter.CLASS_B})
        decoded = json_utils.decode(encoded)
        self.assertAllEqual({'key': 'a', 'key2': 'b'}, decoded)
    def tracking_metadata(self):
        """String stored in the metadata field of the SavedModel proto.

        Returns:
          A JSON string, produced by ``json_utils.Encoder``, that encodes
          ``self.python_properties`` -- the information necessary for
          recreating this layer.
        """
        # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
        # object is in the python property)
        encoder = json_utils.Encoder()
        return encoder.encode(self.python_properties)
Beispiel #8
0
    def test_encode_decode_extension_type_tensor(self):
        """An ExtensionType value survives an encode/decode round trip."""
        class MaskedTensor(tf.experimental.ExtensionType):
            # NOTE(review): presumably the name this ExtensionType is
            # registered under for (de)serialization -- confirm.
            __name__ = 'MaskedTensor'
            values: tf.Tensor
            mask: tf.Tensor

        masked = MaskedTensor(
            values=[[1, 2, 3], [4, 5, 6]],
            mask=[[True, True, False], [True, False, True]],
        )
        encoded = json_utils.Encoder().encode(masked)
        decoded = json_utils.decode(encoded)
        self.assertAllEqual(decoded, masked)
Beispiel #9
0
    def test_encode_decode_tensor_shape(self):
        """TensorShapes of unknown rank and of partially-known dims round-trip.

        ``tf.TensorShape(None)`` must decode with ``rank`` ``None``; shapes
        with ``None`` dimensions must decode to the same ``as_list()`` value.
        """
        metadata = {
            "key1": tf.TensorShape(None),
            "key2": [tf.TensorShape([None]), tf.TensorShape([3, None, 5])],
        }
        string = json_utils.Encoder().encode(metadata)
        loaded = json_utils.decode(string)

        self.assertEqual(set(loaded.keys()), {"key1", "key2"})
        # Check the unknown rank with assertIsNone rather than pushing None
        # through assertAllEqual's array comparison.
        self.assertIsNone(loaded["key1"].rank)
        self.assertAllEqual(loaded["key2"][0].as_list(), [None])
        self.assertAllEqual(loaded["key2"][1].as_list(), [3, None, 5])
Beispiel #10
0
    def test_encode_decode_tensor_shape(self):
        """TensorShapes of unknown rank and of partially-known dims round-trip.

        ``tf.TensorShape(None)`` must decode with ``rank`` ``None``; shapes
        with ``None`` dimensions must decode to the same ``as_list()`` value.
        """
        metadata = {
            'key1': tf.TensorShape(None),
            'key2': [tf.TensorShape([None]),
                     tf.TensorShape([3, None, 5])]
        }
        string = json_utils.Encoder().encode(metadata)
        loaded = json_utils.decode(string)

        self.assertEqual(set(loaded.keys()), {'key1', 'key2'})
        # Check the unknown rank with assertIsNone rather than pushing None
        # through assertAllEqual's array comparison.
        self.assertIsNone(loaded['key1'].rank)
        self.assertAllEqual(loaded['key2'][0].as_list(), [None])
        self.assertAllEqual(loaded['key2'][1].as_list(), [3, None, 5])
Beispiel #11
0
def _update_to_current_version(metadata):
  """Applies version updates to the metadata proto for backwards compat.

  For every node written by producer version 1 whose identifier is the model,
  sequential or network identifier, a `save_spec` entry (if present) is
  copied into a new `full_save_spec` entry of the form `([save_spec], {})`,
  and the node's metadata string is re-encoded in place.

  Args:
    metadata: The metadata proto whose `nodes` are updated in place.

  Returns:
    The same `metadata` object that was passed in.
  """
  for node in metadata.nodes:
    if node.version.producer != 1:
      continue
    if node.identifier not in (constants.MODEL_IDENTIFIER,
                               constants.SEQUENTIAL_IDENTIFIER,
                               constants.NETWORK_IDENTIFIER):
      continue

    decoded = json_utils.decode(node.metadata)
    legacy_spec = decoded.get('save_spec')
    if legacy_spec is None:
      # Nothing to upgrade; leave the node's metadata string untouched.
      continue

    decoded['full_save_spec'] = ([legacy_spec], {})
    node.metadata = json_utils.Encoder().encode(decoded)
  return metadata
Beispiel #12
0
  def testAddFullSaveSpec(self):
    """Producer-1 model metadata gains `full_save_spec` after upgrading.

    Builds a SavedMetadata node (producer version 1, identifier
    '_tf_keras_model') whose metadata only holds `save_spec`, runs
    `keras_load._update_to_current_version`, and checks that the node's
    decoded metadata now has `full_save_spec == ([save_spec], {})`.
    """
    save_spec = tf.TensorSpec([3, 5], dtype=tf.int32)
    # Encoded JSON string; kept distinct from the decoded dict read back below.
    encoded_metadata = json_utils.Encoder().encode({'save_spec': save_spec})

    metadata = saved_metadata_pb2.SavedMetadata()
    metadata.nodes.add(
        version=versions_pb2.VersionDef(
            producer=1, min_consumer=1, bad_consumers=[]),
        identifier='_tf_keras_model',
        metadata=encoded_metadata)

    # The pragma belongs on the line that actually touches the private helper.
    # pylint: disable=protected-access
    new_metadata = keras_load._update_to_current_version(metadata)
    # pylint: enable=protected-access
    node_metadata = json_utils.decode(new_metadata.nodes[0].metadata)
    expected_full_spec = ([tf.TensorSpec(shape=(3, 5), dtype=tf.int32)], {})
    self.assertAllEqual(expected_full_spec, node_metadata.get('full_save_spec'))
Beispiel #13
0
        def _serialize_keras_tensor(t):
            """Serializes a single Tensor passed to `call`.

            Checked in order:
              * objects with `_keras_history` become
                `[layer_name, remapped_node_index, tensor_index]`, where the
                node index is looked up in `node_conversion_map` (default 0);
              * `np.ndarray` values are converted with `.tolist()`;
              * `tf.Tensor` values are fetched with `backend.get_value` and
                converted with `.tolist()`;
              * `tf.__internal__.CompositeTensor` values become
                `(_COMPOSITE_TYPE, <json_utils-encoded string>)`;
              * anything else is returned unchanged.
            """
            if hasattr(t, '_keras_history'):
                history = t._keras_history
                node_index = history.node_index
                layer_name = history.layer.name
                remapped_index = node_conversion_map.get(
                    make_node_key(layer_name, node_index), 0)
                return [layer_name, remapped_index, history.tensor_index]

            if isinstance(t, np.ndarray):
                return t.tolist()

            if isinstance(t, tf.Tensor):
                return backend.get_value(t).tolist()

            # Not using json_utils to serialize both constant Tensor and constant
            # CompositeTensor for saving format backward compatibility.
            if isinstance(t, tf.__internal__.CompositeTensor):
                return (_COMPOSITE_TYPE, json_utils.Encoder().encode(t))

            return t
Beispiel #14
0
 def test_encode_decode_ragged_tensor(self):
     """A RaggedTensor survives an encode/decode round trip."""
     ragged = tf.ragged.constant([[1.0, 2.0], [3.0]])
     encoded = json_utils.Encoder().encode(ragged)
     decoded = json_utils.decode(encoded)
     self.assertAllEqual(decoded, ragged)