def deserialize_sequence_value(sequence_value_proto):
  """Deserializes a `tf.data.Dataset`.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.

  Returns:
    A tuple of `(tf.data.Dataset, tff.Type)`.

  Raises:
    NotImplementedError: If the `value` oneof of `sequence_value_proto` holds
      anything other than `zipped_saved_model`.
  """
  py_typecheck.check_type(sequence_value_proto, executor_pb2.Value.Sequence)

  which_value = sequence_value_proto.WhichOneof('value')
  if which_value == 'zipped_saved_model':
    ds = tensorflow_serialization.deserialize_dataset(
        sequence_value_proto.zipped_saved_model)
  else:
    raise NotImplementedError(
        'Deserializing Sequences encoded as {!s} has not been implemented'
        .format(which_value))

  element_type = type_serialization.deserialize_type(
      sequence_value_proto.element_type)

  # If a serialized dataset had elements of nested structures of tensors (e.g.
  # `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
  # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
  #
  # Since the dataset will only be used inside TFF, we wrap the dictionary
  # coming from TF in an `OrderedDict` when necessary (a type that both TF and
  # TFF understand), using the field order stored in the TFF type stored during
  # serialization.
  ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      ds, element_type)

  return ds, computation_types.SequenceType(element=element_type)
def test_coerce_dataset_elements_nested_structure(self):
  """Checks coercion of a dataset whose elements are nested structures.

  The elements mix a `dict`, a `tuple`, a `namedtuple` and an `OrderedDict`;
  after coercion the dataset's element spec must be equivalent to the
  requested TFF struct type.
  """
  pair_type = collections.namedtuple('TestTuple', ['u', 'v'])

  def _to_nested_element(x):
    as_int32 = tf.cast(x, tf.int32)
    unnamed_fields = (
        x,
        pair_type(x * 2, x * 3),
        collections.OrderedDict([('x', x**2), ('y', x**3)]),
    )
    as_float32 = tf.cast(x, tf.float32)
    # Keys are deliberately not in alphabetical order.
    return {'b': as_int32, 'a': unnamed_fields, 'c': as_float32}

  source_ds = tf.data.Dataset.range(5).map(_to_nested_element)

  inner_type = computation_types.StructType([
      (None, tf.int64),
      (None, pair_type(tf.int64, tf.int64)),
      (None, computation_types.StructType([('x', tf.int64),
                                           ('y', tf.int64)])),
  ])
  expected_type = computation_types.StructType([
      ('a', inner_type),
      ('b', tf.int32),
      ('c', tf.float32),
  ])

  coerced_ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      source_ds, expected_type)
  actual_type = computation_types.to_type(coerced_ds.element_spec)
  actual_type.check_equivalent_to(expected_type)
def _deserialize_sequence_value(
    sequence_value_proto: executor_pb2.Value.Sequence
) -> _DeserializeReturnType:
  """Deserializes a `tf.data.Dataset`.

  Two encodings are handled: `zipped_saved_model` (deprecated, emits a
  `DeprecationWarning`) and `serialized_graph_def`.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.

  Returns:
    A tuple of `(tf.data.Dataset, tff.Type)`.

  Raises:
    NotImplementedError: If the `value` oneof of `sequence_value_proto` holds
      neither `zipped_saved_model` nor `serialized_graph_def`.
  """
  element_type = type_serialization.deserialize_type(
      sequence_value_proto.element_type)

  which_value = sequence_value_proto.WhichOneof('value')
  if which_value == 'zipped_saved_model':
    warnings.warn(
        'Deserializing a sequence value that was encoded as a zipped '
        'SavedModel. This is a deprecated path, please update the binary that '
        'is serializing the sequences.', DeprecationWarning)
    ds = _deserialize_dataset_from_zipped_saved_model(
        sequence_value_proto.zipped_saved_model)
    # The SavedModel path does not carry the TFF type, so the element
    # containers are coerced to match `element_type` here.
    ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        ds, element_type)
  elif which_value == 'serialized_graph_def':
    ds = _deserialize_dataset_from_graph_def(
        sequence_value_proto.serialized_graph_def, element_type)
  else:
    raise NotImplementedError(
        'Deserializing Sequences encoded as {!s} has not been implemented'
        .format(which_value))

  return ds, computation_types.SequenceType(element=element_type)
def test_coerce_dataset_elements_noop(self):
  """Checks that coercing to an already-matching type keeps the structure."""
  x = tf.data.Dataset.range(5)
  y = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      x, computation_types.TensorType(tf.int64))
  # `tf.data.experimental.get_structure` is deprecated; the `element_spec`
  # property exposes the same structure.
  self.assertEqual(x.element_spec, y.element_spec)
def _deserialize_sequence_value(
    sequence_value_proto: serialization_bindings.Sequence,
    type_hint: Optional[computation_types.Type] = None
) -> _DeserializeReturnType:
  """Deserializes a `tf.data.Dataset`.

  Two encodings are handled: `zipped_saved_model` (deprecated, emits a
  `DeprecationWarning`) and `serialized_graph_def`.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.
    type_hint: A `computation_types.Type` that hints at what the value type
      should be for executors that only return values. If the
      `sequence_value_proto.element_type` field was not set, the `element`
      attribute of `type_hint` is used as the element type instead.

  Returns:
    A tuple of `(tf.data.Dataset, tff.Type)`.

  Raises:
    ValueError: If `sequence_value_proto.element_type` is not set and no
      `type_hint` was given.
    NotImplementedError: If the `value` oneof of `sequence_value_proto` holds
      neither `zipped_saved_model` nor `serialized_graph_def`.
  """
  # The type stored in the proto takes precedence over the caller's hint.
  if sequence_value_proto.HasField('element_type'):
    element_type = type_serialization.deserialize_type(
        sequence_value_proto.element_type)
  elif type_hint is not None:
    element_type = type_hint.element
  else:
    raise ValueError(
        'Cannot deserialize a sequence Value proto without one of the '
        '`element_type` proto field or the `type_hint` argument.')

  which_value = sequence_value_proto.WhichOneof('value')
  if which_value == 'zipped_saved_model':
    warnings.warn(
        'Deserializing a sequence value that was encoded as a zipped '
        'SavedModel. This is a deprecated path, please update the binary that '
        'is serializing the sequences.', DeprecationWarning)
    ds = _deserialize_dataset_from_zipped_saved_model(
        sequence_value_proto.zipped_saved_model)
    # The SavedModel path does not carry the TFF type, so the element
    # containers are coerced to match `element_type` here.
    ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        ds, element_type)
  elif which_value == 'serialized_graph_def':
    ds = _deserialize_dataset_from_graph_def(
        sequence_value_proto.serialized_graph_def, element_type)
  else:
    raise NotImplementedError(
        'Deserializing Sequences encoded as {!s} has not been implemented'
        .format(which_value))

  return ds, computation_types.SequenceType(element=element_type)
def test_coerce_dataset_elements_noop(self):
  """Checks that coercing to an already-matching type keeps the element spec."""
  source_ds = tf.data.Dataset.range(5)
  int64_type = computation_types.TensorType(tf.int64)
  coerced_ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      source_ds, int64_type)
  self.assertEqual(source_ds.element_spec, coerced_ds.element_spec)
def _deserialize_dataset_from_graph_def(serialized_graph_def: bytes,
                                        element_type: computation_types.Type):
  """Rebuilds a `tf.data.Dataset` from a serialized `tf.compat.v1.GraphDef`.

  Args:
    serialized_graph_def: `bytes` object produced by
      `tensorflow_serialization.serialize_dataset`.
    element_type: a `tff.Type` object representing the type structure of the
      elements yielded from the dataset.

  Returns:
    A `tf.data.Dataset` instance.

  Raises:
    TypeError: If `element_type` contains a structure in which only some of
      the fields are named.
  """
  py_typecheck.check_type(element_type, computation_types.Type)
  type_analysis.check_tensorflow_compatible_type(element_type)

  def _with_python_container(
      type_spec: computation_types.Type
  ) -> Tuple[computation_types.Type, bool]:
    """Converts a plain `StructType` into a `StructWithPythonType`."""
    if not type_spec.is_struct() or type_spec.is_struct_with_python():
      # Nothing to do: not a struct, or it already carries a container.
      return type_spec, False
    names_present = [
        name is not None for name, _ in structure.iter_elements(type_spec)
    ]
    if all(names_present):
      container = collections.OrderedDict
    elif not any(names_present):
      container = tuple
    else:
      raise TypeError(
          'Cannot represent TFF type in TF because it contains '
          f'partially named structures. Type: {type_spec}')
    converted = computation_types.StructWithPythonType(
        elements=structure.iter_elements(type_spec), container_type=container)
    return converted, True

  if not element_type.is_struct():
    # The compatibility check above only admits structs and tensors, so this
    # is a `TensorType` and can be used as-is.
    tf_compatible_type = element_type
  else:
    # TF doesn't support `structure.Struct` types, so the `StructType` is
    # turned into a `StructWithPythonType` for use as the
    # `tf.data.Dataset.element_spec` later.
    tf_compatible_type, _ = type_transformations.transform_type_postorder(
        element_type, _with_python_container)

  def _as_tensor_spec(t: computation_types.TensorType) -> tf.TensorSpec:
    return tf.TensorSpec(shape=t.shape, dtype=t.dtype)

  element_spec = type_conversions.structure_from_tensor_type_tree(
      _as_tensor_spec, tf_compatible_type)
  variant = tf.raw_ops.DatasetFromGraph(graph_def=serialized_graph_def)
  dataset = tf.data.experimental.from_variant(variant, structure=element_spec)

  # If a serialized dataset had elements of nested structures of tensors (e.g.
  # `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
  # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
  #
  # Since the dataset will only be used inside TFF, the dictionary coming from
  # TF is wrapped in an `OrderedDict` when necessary (a type that both TF and
  # TFF understand), using the field order recorded in the TFF type during
  # serialization.
  return tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      dataset, tf_compatible_type)