def do_decode(self, value, decode_fn):
  """Decodes a `tensor_spec_value` StructuredValue into a `tf.TensorSpec`.

  Args:
    value: A `struct_pb2.StructuredValue` whose `tensor_spec_value` is set.
    decode_fn: Callback used to decode the nested shape and dtype protos.

  Returns:
    A `tensor_spec.TensorSpec` with the decoded shape, dtype and name. An
    empty proto name is returned as `None`.
  """
  spec_proto = value.tensor_spec_value
  name = spec_proto.name
  return tensor_spec.TensorSpec(
      shape=decode_fn(
          struct_pb2.StructuredValue(tensor_shape_value=spec_proto.shape)),
      dtype=decode_fn(
          struct_pb2.StructuredValue(tensor_dtype_value=spec_proto.dtype)),
      # Proto string fields default to "", so a spec encoded without a name
      # comes back as "". Map it to None so the decoded spec equals the
      # original unnamed spec (the BoundedTensorSpec decoder does the same).
      name=(name if name else None))
def do_decode(self, value, decode_fn):
  """Rebuilds a `tf.BoundedTensorSpec` from `value.bounded_tensor_spec_value`.

  Shape and dtype are decoded through `decode_fn`. The minimum and maximum
  tensor protos are converted back to ndarrays, and an empty name becomes
  `None`.
  """
  spec_proto = value.bounded_tensor_spec_value
  decoded_shape = decode_fn(
      struct_pb2.StructuredValue(tensor_shape_value=spec_proto.shape))
  decoded_dtype = decode_fn(
      struct_pb2.StructuredValue(tensor_dtype_value=spec_proto.dtype))
  lower_bound = tensor_util.MakeNdarray(spec_proto.minimum)
  upper_bound = tensor_util.MakeNdarray(spec_proto.maximum)
  spec_name = spec_proto.name or None
  return tensor_spec.BoundedTensorSpec(
      shape=decoded_shape,
      dtype=decoded_dtype,
      minimum=lower_bound,
      maximum=upper_bound,
      name=spec_name)
def testEncodeDecodeBoundedTensorSpecNoName(self):
  """An unnamed BoundedTensorSpec encodes with name "" and round-trips."""
  spec = tensor_spec.BoundedTensorSpec((28, 28, 3), dtypes.float64, -2,
                                       (1, 1, 20))
  structure = [spec]
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  bounded = expected.list_value.values.add().bounded_tensor_spec_value
  for size in (28, 28, 3):
    bounded.shape.dim.add().size = size
  bounded.name = ""
  bounded.dtype = dtypes.float64.as_datatype_enum
  # The scalar minimum stays a scalar proto; the per-channel maximum has
  # shape [3].
  bounded.minimum.CopyFrom(
      tensor_util.make_tensor_proto([-2], dtype=dtypes.float64, shape=[]))
  bounded.maximum.CopyFrom(
      tensor_util.make_tensor_proto([1, 1, 20], dtype=dtypes.float64,
                                    shape=[3]))
  self.assertEqual(expected, encoded)

  decoded = coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def _get_element_from_tensor_info(tensor_info, graph):
  """Simplified copy of the deprecated `get_tensor_from_tensor_info`.

  Args:
    tensor_info: A TensorInfo proto.
    graph: The graph in which element names are resolved.

  Returns:
    For `name` encodings, whatever `graph.as_graph_element` returns. For
    `coo_sparse`, a `SparseTensor`. For `composite_tensor`, the value built
    from the decoded TypeSpec and its component tensors.

  Raises:
    ValueError: If the TensorInfo encoding is not one of the above.
  """
  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    # We may get operations here in some cases. TensorInfo is a bit of a
    # misnomer if so.
    return graph.as_graph_element(tensor_info.name)
  if encoding == "coo_sparse":
    coo = tensor_info.coo_sparse
    indices = graph.get_tensor_by_name(coo.indices_tensor_name)
    values = graph.get_tensor_by_name(coo.values_tensor_name)
    dense_shape = graph.get_tensor_by_name(coo.dense_shape_tensor_name)
    return sparse_tensor.SparseTensor(indices, values, dense_shape)
  if encoding == "composite_tensor":
    composite = tensor_info.composite_tensor
    coder = nested_structure_coder.StructureCoder()
    spec = coder.decode_proto(
        struct_pb2.StructuredValue(type_spec_value=composite.type_spec))
    parts = [graph.get_tensor_by_name(part.name)
             for part in composite.components]
    return spec._from_components(parts)  # pylint: disable=protected-access
  raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
def lookup_tensor_or_sparse_or_composite_tensor(tensor_info):
  """Returns the remapped value described by a TensorInfo proto.

  Every tensor name is resolved through `lookup_remapped_tensor`.

  Args:
    tensor_info: A TensorInfo proto.

  Returns:
    The remapped tensor for a `name` encoding, a `tf.SparseTensor` for
    `coo_sparse`, or a composite value rebuilt from its decoded TypeSpec for
    `composite_tensor`.

  Raises:
    ValueError: If the encoding is none of the above.
  """
  encoding = tensor_info.WhichOneof('encoding')
  if encoding == 'name':
    return lookup_remapped_tensor(tensor_info.name)
  if encoding == 'coo_sparse':
    coo = tensor_info.coo_sparse
    return tf.SparseTensor(
        lookup_remapped_tensor(coo.indices_tensor_name),
        lookup_remapped_tensor(coo.values_tensor_name),
        lookup_remapped_tensor(coo.dense_shape_tensor_name))
  if encoding == 'composite_tensor':
    composite = tensor_info.composite_tensor
    remapped = [lookup_remapped_tensor(part.name)
                for part in composite.components]
    spec = nested_structure_coder.StructureCoder().decode_proto(
        struct_pb2.StructuredValue(type_spec_value=composite.type_spec))
    return spec._from_components(remapped)  # pylint: disable=protected-access
  raise ValueError('Unsupported TensorInfo encoding %s' % encoding)
def testDecodeUnknownTensorSpec(self):
  """Decoding a TypeSpec with an unknown class name raises ValueError."""
  encoded = struct_pb2.StructuredValue()
  type_spec_proto = encoded.type_spec_value
  type_spec_proto.type_spec_class = 0
  type_spec_proto.type_spec_class_name = "FutureTensorSpec"
  expected_message = "The type 'FutureTensorSpec' is not supported"
  with self.assertRaisesRegex(ValueError, expected_message):
    self._coder.decode_proto(encoded)
def testEncodeDecodeSparseTensorSpec(self):
  """SparseTensorSpec encodes as a SPARSE_TENSOR_SPEC TypeSpecProto."""
  structure = [sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)]
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = text_format.Parse(
      r"""
      list_value {
        values {
          type_spec_value {
            type_spec_class: SPARSE_TENSOR_SPEC
            type_spec_class_name: 'SparseTensorSpec'
            num_flat_components: 3
            type_state {
              tuple_value {
                # spec._shape
                values {
                  tensor_shape_value {
                    dim { size: 10 }
                    dim { size: 20 }
                  }
                }
                # spec._dtype
                values { tensor_dtype_value: DT_FLOAT }
              }
            }
          }
        }
      }
      """, struct_pb2.StructuredValue())
  self.assertEqual(expected, encoded)

  decoded = coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def do_encode(self, type_spec_value, encode_fn):
  """Encodes a `tf.TypeSpec` as a `type_spec_value` StructuredValue.

  If the spec's exact type appears in `TYPE_SPEC_CLASS_TO_PROTO`, that enum
  value is used and the class name is the Python `__name__`. Otherwise the
  registered name from `type_spec.get_name` is used, with class
  EXTENSION_TYPE_SPEC for `ExtensionTypeSpec` instances and
  REGISTERED_TYPE_SPEC for everything else.

  Args:
    type_spec_value: The TypeSpec to encode.
    encode_fn: Callback used to encode the spec's serialized state.

  Returns:
    A `struct_pb2.StructuredValue` with `type_spec_value` set.
  """
  spec_type = type(type_spec_value)
  type_spec_class = self.TYPE_SPEC_CLASS_TO_PROTO.get(spec_type)
  if type_spec_class is not None:
    type_spec_class_name = spec_type.__name__
  else:
    type_spec_class_name = type_spec.get_name(spec_type)
    if isinstance(type_spec_value, extension_type.ExtensionTypeSpec):
      type_spec_class = struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC
    else:
      type_spec_class = struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC
      # Saving registered TypeSpecs is experimental; warn the caller that the
      # type must be importable and registered when the proto is loaded.
      warnings.warn("Encoding a StructuredValue with type %s; loading this "
                    "StructuredValue will require that this type be "
                    "imported and registered." % type_spec_class_name)

  type_state = type_spec_value._serialize()  # pylint: disable=protected-access
  component_specs = type_spec_value._component_specs  # pylint: disable=protected-access
  flat_components = nest.flatten(component_specs, expand_composites=True)
  type_spec_proto = struct_pb2.TypeSpecProto(
      type_spec_class=type_spec_class,
      type_state=encode_fn(type_state),
      type_spec_class_name=type_spec_class_name,
      num_flat_components=len(flat_components))
  return struct_pb2.StructuredValue(type_spec_value=type_spec_proto)
def do_encode(self, tensor_spec_value, encode_fn):
  """Encodes a `tf.TensorSpec` as a `tensor_spec_value` StructuredValue.

  The shape and dtype are encoded through `encode_fn`; the name is copied
  as-is.
  """
  shape_proto = encode_fn(tensor_spec_value.shape).tensor_shape_value
  dtype_enum = encode_fn(tensor_spec_value.dtype).tensor_dtype_value
  spec_proto = struct_pb2.TensorSpecProto(
      shape=shape_proto, dtype=dtype_enum, name=tensor_spec_value.name)
  return struct_pb2.StructuredValue(tensor_spec_value=spec_proto)
def testBool(self):
  """A list holding `False` encodes via `bool_value` and round-trips."""
  structure = [False]
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  expected.list_value.values.add().bool_value = False
  self.assertEqual(expected, encoded)

  self.assertEqual(structure, coder.decode_proto(encoded))
def do_encode(self, type_spec_value, encode_fn):
  """Encodes a `tf.TypeSpec` as a `type_spec_value` StructuredValue.

  The proto class enum is looked up by the spec's exact type in
  `TYPE_SPEC_CLASS_TO_PROTO`; a type missing from that mapping raises
  `KeyError`. The spec's serialized state is encoded with `encode_fn`.
  """
  proto_class = self.TYPE_SPEC_CLASS_TO_PROTO[type(type_spec_value)]
  serialized_state = type_spec_value._serialize()  # pylint: disable=protected-access
  type_spec_proto = struct_pb2.TypeSpecProto(
      type_spec_class=proto_class, type_state=encode_fn(serialized_state))
  return struct_pb2.StructuredValue(type_spec_value=type_spec_proto)
def do_encode(self, named_tuple_value, encode_fn):
  """Encodes a namedtuple as a `named_tuple_value` StructuredValue.

  The proto stores the namedtuple's class name plus one key/value pair per
  field, in `_fields` order. Each field value is encoded with `encode_fn`.

  Args:
    named_tuple_value: The namedtuple instance to encode.
    encode_fn: Callback used to encode each field value.

  Returns:
    A `struct_pb2.StructuredValue` with `named_tuple_value` set.
  """
  encoded_named_tuple = struct_pb2.StructuredValue()
  # Copy in an empty message so the oneof is set even for a zero-field tuple.
  encoded_named_tuple.named_tuple_value.CopyFrom(struct_pb2.NamedTupleValue())
  named_tuple_proto = encoded_named_tuple.named_tuple_value
  named_tuple_proto.name = named_tuple_value.__class__.__name__
  # Build the field mapping once instead of calling _asdict() per field.
  field_values = named_tuple_value._asdict()  # pylint: disable=protected-access
  for key in named_tuple_value._fields:  # pylint: disable=protected-access
    pair = named_tuple_proto.values.add()
    pair.key = key
    pair.value.CopyFrom(encode_fn(field_values[key]))
  return encoded_named_tuple
def do_encode(self, dict_value, encode_fn):
  """Encodes a dict as an `ordered_dict_value_int_key` StructuredValue.

  Entries are written in the dict's iteration order, and each value is
  encoded with `encode_fn`. The proto type name suggests the keys are
  integers (TODO confirm against the OrderedDictValueIntKey definition).
  """
  encoded_dict = struct_pb2.StructuredValue()
  # Set the oneof explicitly so an empty dict still selects this variant.
  encoded_dict.ordered_dict_value_int_key.CopyFrom(
      struct_pb2.OrderedDictValueIntKey())
  entries = encoded_dict.ordered_dict_value_int_key.values
  for dict_key, item in dict_value.items():
    entry = entries.add()
    entry.key = dict_key
    entry.value.CopyFrom(encode_fn(item))
  return encoded_dict
def testDtype(self):
  """A DType inside a list encodes via `tensor_dtype_value`."""
  structure = [dtypes.int64]
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  expected.list_value.values.add().tensor_dtype_value = (
      dtypes.int64.as_datatype_enum)
  self.assertEqual(expected, encoded)

  self.assertEqual(structure, coder.decode_proto(encoded))
def testNone(self):
  """`None` in a list encodes as an empty `none_value` and round-trips."""
  structure = [1.0, None]
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  elements = expected.list_value.values
  elements.add().float64_value = 1.0
  elements.add().none_value.CopyFrom(struct_pb2.NoneValue())
  self.assertEqual(expected, encoded)

  self.assertEqual(structure, coder.decode_proto(encoded))
def experimental_from_proto(cls, proto: struct_pb2.TypeSpecProto) -> "TypeSpec":
  """Returns a TypeSpec instance built from a serialized `TypeSpecProto`.

  Do NOT override for custom non-TF types.

  Args:
    proto: Proto generated using 'experimental_as_proto'.

  Returns:
    The TypeSpec produced by `nested_structure_coder.decode_proto` after
    wrapping `proto` in a StructuredValue.
  """
  wrapped = struct_pb2.StructuredValue(type_spec_value=proto)
  return nested_structure_coder.decode_proto(wrapped)
def composite_tensor_info_to_type_spec(tensor_info):
  """Decodes the `TypeSpec` stored in a composite-tensor `TensorInfo`.

  Args:
    tensor_info: A TensorInfo proto whose encoding is `composite_tensor`.

  Returns:
    The decoded `TypeSpec` object.

  Raises:
    ValueError: If `nested_structure_coder` or `struct_pb2` is None, or if
      `tensor_info` does not use the `composite_tensor` encoding.
  """
  if nested_structure_coder is None or struct_pb2 is None:
    raise ValueError("This version of TensorFlow does not support "
                     "composite tensors.")
  encoding = tensor_info.WhichOneof("encoding")
  if encoding != "composite_tensor":
    raise ValueError("Expected a TensorInfo with encoding=composite_tensor")
  wrapped_spec = struct_pb2.StructuredValue(
      type_spec_value=tensor_info.composite_tensor.type_spec)
  return nested_structure_coder.StructureCoder().decode_proto(wrapped_spec)
def testEncodeDecodeList(self):
  """A flat list of floats encodes element-wise as `float64_value`."""
  structure = [1.5, 2.5, 3.0]
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  for number in structure:
    expected.list_value.values.add().float64_value = number
  self.assertEqual(expected, encoded)

  self.assertEqual(structure, coder.decode_proto(encoded))
def testEmptyStructures(self):
  """Empty list, dict and tuple each encode as their (empty) proto variant."""
  structure = [[], {}, ()]
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  elements = expected.list_value.values
  elements.add().list_value.CopyFrom(struct_pb2.ListValue())
  elements.add().dict_value.CopyFrom(struct_pb2.DictValue())
  elements.add().tuple_value.CopyFrom(struct_pb2.TupleValue())
  self.assertEqual(expected, encoded)

  self.assertEqual(structure, coder.decode_proto(encoded))
def testEncodeDecodeDict(self):
  """A dict with int and list values round-trips, keeping ints as `int`."""
  structure = {"a": 3, "b": [7, 2.5]}
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  fields = expected.dict_value.fields
  fields["a"].int64_value = 3
  b_list = fields["b"].list_value
  b_list.values.add().int64_value = 7
  b_list.values.add().float64_value = 2.5
  self.assertEqual(expected, encoded)

  decoded = coder.decode_proto(encoded)
  self.assertIsInstance(decoded["a"], int)
  self.assertEqual(structure, decoded)
def do_encode(self, bounded_tensor_spec_value, encode_fn):
  """Encodes a `tf.BoundedTensorSpec` as `bounded_tensor_spec_value`.

  Shape and dtype go through `encode_fn`. The minimum and maximum are
  converted with `tensor_util.make_tensor_proto`, and the name is copied
  as-is.
  """
  spec = bounded_tensor_spec_value
  shape_proto = encode_fn(spec.shape).tensor_shape_value
  dtype_enum = encode_fn(spec.dtype).tensor_dtype_value
  minimum_proto = tensor_util.make_tensor_proto(spec.minimum)
  maximum_proto = tensor_util.make_tensor_proto(spec.maximum)
  bounded_proto = struct_pb2.BoundedTensorSpecProto(
      shape=shape_proto,
      dtype=dtype_enum,
      name=spec.name,
      minimum=minimum_proto,
      maximum=maximum_proto)
  return struct_pb2.StructuredValue(bounded_tensor_spec_value=bounded_proto)
def testEncodeDecodeTuple(self):
  """Nested tuples and lists keep their container kinds through encoding."""
  structure = ("hello", [3, (2, 1)])
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  outer = expected.tuple_value.values
  outer.add().string_value = "hello"
  inner_list = outer.add().list_value
  inner_list.values.add().int64_value = 3
  inner_tuple = inner_list.values.add().tuple_value
  for number in (2, 1):
    inner_tuple.values.add().int64_value = number
  self.assertEqual(expected, encoded)

  self.assertEqual(structure, coder.decode_proto(encoded))
def testEncodeDecodeTensorShape(self):
  """A TensorShape followed by a string round-trips through the coder."""
  structure = [tensor_shape.TensorShape([1, 2, 3]), "hello"]
  self.assertTrue(self._coder.can_encode(structure))
  encoded = self._coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  expected_list = expected.list_value
  expected_tensor_shape = expected_list.values.add().tensor_shape_value
  expected_tensor_shape.dim.add().size = 1
  expected_tensor_shape.dim.add().size = 2
  expected_tensor_shape.dim.add().size = 3
  # The second list element is the plain string; assign it directly
  # without rebinding the shape handle.
  expected_list.values.add().string_value = "hello"
  self.assertEqual(expected, encoded)

  decoded = self._coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEncodeDecodeTensorSpec(self):
  """A named TensorSpec encodes to `tensor_spec_value` and round-trips."""
  structure = [tensor_spec.TensorSpec([1, 2, 3], dtypes.int64, "hello")]
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  spec_proto = expected.list_value.values.add().tensor_spec_value
  for size in (1, 2, 3):
    spec_proto.shape.dim.add().size = size
  spec_proto.name = "hello"
  spec_proto.dtype = dtypes.int64.as_datatype_enum
  self.assertEqual(expected, encoded)

  self.assertEqual(structure, coder.decode_proto(encoded))
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Looks up the Tensor, SparseTensor or CompositeTensor for a TensorInfo.

  Args:
    tensor_info: A TensorInfo proto describing a Tensor, SparseTensor or
      CompositeTensor.
    graph: The tf.Graph to search. Falls back to the current default graph
      when not given.
    import_scope: When not None, this scope is prepended to every name in
      `tensor_info` before it is looked up.

  Returns:
    The Tensor, SparseTensor or CompositeTensor in `graph` that
    `tensor_info` describes.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed.
  """
  graph = graph or ops.get_default_graph()

  def _get_tensor(name):
    scoped_name = ops.prepend_name_scope(name, import_scope=import_scope)
    return graph.get_tensor_by_name(scoped_name)

  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    return _get_tensor(tensor_info.name)
  if encoding == "coo_sparse":
    coo = tensor_info.coo_sparse
    return sparse_tensor.SparseTensor(
        _get_tensor(coo.indices_tensor_name),
        _get_tensor(coo.values_tensor_name),
        _get_tensor(coo.dense_shape_tensor_name))
  if encoding == "composite_tensor":
    composite = tensor_info.composite_tensor
    spec = nested_structure_coder.StructureCoder().decode_proto(
        struct_pb2.StructuredValue(type_spec_value=composite.type_spec))
    flat_components = [_get_tensor(part.name) for part in composite.components]
    # Components are stored flat, so repack them against the spec with
    # composite expansion.
    return nest.pack_sequence_as(spec, flat_components, expand_composites=True)
  raise ValueError(
      f"Invalid TensorInfo.encoding: {encoding}. Expected `"
      "coo_sparse`, `composite_tensor`, or `name` for a dense "
      "tensor.")
def testBuildTensorInfoRagged(self):
  """build_tensor_info on a RaggedTensor records components and type_spec."""
  x = ragged_factory_ops.constant([[1, 2], [3]])
  x_tensor_info = utils.build_tensor_info(x)
  composite = x_tensor_info.composite_tensor

  # Component 0 is the flat values, component 1 the row splits.
  expected_components = [
      (x.values.name, types_pb2.DT_INT32),
      (x.row_splits.name, types_pb2.DT_INT64),
  ]
  for index, (name, dtype) in enumerate(expected_components):
    self.assertEqual(name, composite.components[index].name)
    self.assertEqual(dtype, composite.components[index].dtype)

  # The stored type_spec must decode back to the tensor's own spec.
  spec = nested_structure_coder.decode_proto(
      struct_pb2.StructuredValue(type_spec_value=composite.type_spec))
  self.assertEqual(spec, x._type_spec)  # pylint: disable=protected-access
def testEncodeDecodeExtensionTypeSpec(self):
  """An ExtensionType spec encodes as EXTENSION_TYPE_SPEC and round-trips."""

  class Zoo(extension_type.ExtensionType):
    __name__ = "tf.nested_structure_coder_test.Zoo"
    zookeepers: typing.Tuple[str, ...]
    animals: typing.Mapping[str, ops.Tensor]

  zoo_spec = Zoo.Spec(
      zookeepers=["Zoey", "Zack"],
      animals={"tiger": tensor_spec.TensorSpec([16])})
  structure = [zoo_spec]
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = text_format.Parse(
      r"""
      list_value {
        values {
          type_spec_value {
            type_spec_class: EXTENSION_TYPE_SPEC
            type_spec_class_name: "tf.nested_structure_coder_test.Zoo.Spec"
            num_flat_components: 1
            type_state {
              tuple_value {
                values {
                  tuple_value {
                    values { string_value: "zookeepers" }
                    values {
                      tuple_value {
                        values { string_value: "Zoey" }
                        values { string_value: "Zack" }
                      }
                    }
                  }
                }
                values {
                  tuple_value {
                    values { string_value: "animals" }
                    values {
                      dict_value {
                        fields {
                          key: "tiger"
                          value {
                            tensor_spec_value {
                              shape { dim { size: 16 } }
                              dtype: DT_FLOAT
                            }
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
      """, struct_pb2.StructuredValue())
  self.assertEqual(expected, encoded)

  decoded = coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEncodeDecodeNamedTuple(self):
  """A namedtuple encodes its class name and fields, and decodes back."""
  named_tuple_type = collections.namedtuple("NamedTuple", ["x", "y"])
  named_tuple = named_tuple_type(x=[1, 2], y="hello")
  coder = self._coder
  self.assertTrue(coder.can_encode(named_tuple))
  encoded = coder.encode_structure(named_tuple)

  expected = struct_pb2.StructuredValue()
  expected_named_tuple = expected.named_tuple_value
  expected_named_tuple.name = "NamedTuple"
  x_pair = expected_named_tuple.values.add()
  x_pair.key = "x"
  for element in (1, 2):
    x_pair.value.list_value.values.add().int64_value = element
  y_pair = expected_named_tuple.values.add()
  y_pair.key = "y"
  y_pair.value.string_value = "hello"
  self.assertEqual(expected, encoded)

  # The decoded class is regenerated, so compare contents and class name
  # rather than the instances themselves.
  decoded = coder.decode_proto(encoded)
  self.assertEqual(named_tuple._asdict(), decoded._asdict())
  self.assertEqual(type(named_tuple).__name__, type(decoded).__name__)
def testEncodeDecodeRaggedTensorSpec(self):
  """RaggedTensorSpec encodes as a RAGGED_TENSOR_SPEC TypeSpecProto."""
  structure = [
      ragged_tensor.RaggedTensorSpec([1, 2, 3], dtypes.int64, 2, dtypes.int32)
  ]
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = text_format.Parse(
      r"""
      list_value {
        values {
          type_spec_value {
            type_spec_class: RAGGED_TENSOR_SPEC
            type_spec_class_name: 'RaggedTensorSpec'
            num_flat_components: 3
            type_state {
              tuple_value {
                # spec._shape
                values {
                  tensor_shape_value {
                    dim { size: 1 }
                    dim { size: 2 }
                    dim { size: 3 }
                  }
                }
                # spec._dtype
                values { tensor_dtype_value: DT_INT64 }
                # spec._ragged_rank
                values { int64_value: 2 }
                # spec._row_splits_dtype
                values { tensor_dtype_value: DT_INT32 }
              }
            }
          }
        }
      }
      """, struct_pb2.StructuredValue())
  self.assertEqual(expected, encoded)

  decoded = coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def testEncodeDecodeBoundedTensorSpec(self):
  """A named BoundedTensorSpec with scalar bounds round-trips."""
  spec = tensor_spec.BoundedTensorSpec([1, 2, 3], dtypes.int64, 0, 10,
                                       "hello-0-10")
  structure = [spec]
  coder = self._coder
  self.assertTrue(coder.can_encode(structure))
  encoded = coder.encode_structure(structure)

  expected = struct_pb2.StructuredValue()
  bounded = expected.list_value.values.add().bounded_tensor_spec_value
  for size in (1, 2, 3):
    bounded.shape.dim.add().size = size
  bounded.name = "hello-0-10"
  bounded.dtype = dtypes.int64.as_datatype_enum
  # Both bounds are scalars, so their tensor protos have shape [].
  bounded.minimum.CopyFrom(
      tensor_util.make_tensor_proto([0], dtype=dtypes.int64, shape=[]))
  bounded.maximum.CopyFrom(
      tensor_util.make_tensor_proto([10], dtype=dtypes.int64, shape=[]))
  self.assertEqual(expected, encoded)

  decoded = coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)