예제 #1
0
def deserialize_sequence_value(sequence_value_proto):
    """Deserializes a `tf.data.Dataset`.

    Args:
      sequence_value_proto: `Sequence` protocol buffer message.

    Returns:
      A tuple of `(tf.data.Dataset, tff.Type)`.

    Raises:
      NotImplementedError: If the sequence's `value` oneof is set to anything
        other than `zipped_saved_model`.
    """
    py_typecheck.check_type(sequence_value_proto, executor_pb2.Value.Sequence)

    which_value = sequence_value_proto.WhichOneof('value')
    # Only the zipped SavedModel encoding is supported by this deserializer.
    if which_value != 'zipped_saved_model':
        raise NotImplementedError(
            'Deserializing Sequences encoded as {!s} has not been implemented'
            .format(which_value))
    ds = tensorflow_serialization.deserialize_dataset(
        sequence_value_proto.zipped_saved_model)

    element_type = type_serialization.deserialize_type(
        sequence_value_proto.element_type)

    # If a serialized dataset had elements of nested structures of tensors
    # (e.g. `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
    # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
    #
    # Since the dataset will only be used inside TFF, we wrap the dictionary
    # coming from TF in an `OrderedDict` when necessary (a type that both TF and
    # TFF understand), using the field order stored in the TFF type stored during
    # serialization.
    ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        ds, element_type)

    return ds, computation_types.SequenceType(element=element_type)
예제 #2
0
  def test_coerce_dataset_elements_nested_structure(self):
    # Builds a dataset whose elements are a dict mixing plain tensors, a tuple,
    # a namedtuple and an OrderedDict, then checks that coercing it to the
    # equivalent TFF struct type yields an equivalent element_spec.
    pair_cls = collections.namedtuple('TestTuple', ['u', 'v'])

    def _to_nested(value):
      int_part = tf.cast(value, tf.int32)
      mixed_part = (
          value,
          pair_cls(value * 2, value * 3),
          collections.OrderedDict([('x', value**2), ('y', value**3)]),
      )
      float_part = tf.cast(value, tf.float32)
      return {'b': int_part, 'a': mixed_part, 'c': float_part}

    dataset = tf.data.Dataset.range(5).map(_to_nested)

    int64_pair = pair_cls(tf.int64, tf.int64)
    xy_struct = computation_types.StructType([('x', tf.int64), ('y', tf.int64)])
    mixed_struct = computation_types.StructType([
        (None, tf.int64),
        (None, int64_pair),
        (None, xy_struct),
    ])
    expected_type = computation_types.StructType([
        ('a', mixed_struct),
        ('b', tf.int32),
        ('c', tf.float32),
    ])

    coerced = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        dataset, expected_type)

    coerced_type = computation_types.to_type(coerced.element_spec)
    coerced_type.check_equivalent_to(expected_type)
예제 #3
0
def _deserialize_sequence_value(
    sequence_value_proto: executor_pb2.Value.Sequence
) -> _DeserializeReturnType:
    """Deserializes a `tf.data.Dataset`.

    Args:
      sequence_value_proto: `Sequence` protocol buffer message.

    Returns:
      A tuple of `(tf.data.Dataset, tff.Type)`.

    Raises:
      NotImplementedError: If the sequence's `value` oneof is neither
        `zipped_saved_model` nor `serialized_graph_def`.
    """
    element_type = type_serialization.deserialize_type(
        sequence_value_proto.element_type)
    which_value = sequence_value_proto.WhichOneof('value')
    if which_value == 'zipped_saved_model':
        # Legacy encoding: still accepted, but callers are told to migrate.
        warnings.warn(
            'Deserializing a sequence value that was encoded as a zipped SavedModel.'
            ' This is a deprecated path, please update the binary that is '
            'serializing the sequences.', DeprecationWarning)
        ds = _deserialize_dataset_from_zipped_saved_model(
            sequence_value_proto.zipped_saved_model)
        # The SavedModel path does not receive the TFF type, so element
        # containers are coerced to match `element_type` afterwards.
        ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
            ds, element_type)
    elif which_value == 'serialized_graph_def':
        # The graph-def helper is given `element_type` directly.
        ds = _deserialize_dataset_from_graph_def(
            sequence_value_proto.serialized_graph_def, element_type)
    else:
        raise NotImplementedError(
            'Deserializing Sequences encoded as {!s} has not been implemented'
            .format(which_value))
    return ds, computation_types.SequenceType(element=element_type)
예제 #4
0
 def test_coerce_dataset_elements_noop(self):
   # Coercing a dataset to the TFF type it already has must leave its element
   # structure unchanged. Uses the `element_spec` property rather than the
   # deprecated `tf.data.experimental.get_structure`.
   x = tf.data.Dataset.range(5)
   y = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
       x, computation_types.TensorType(tf.int64))
   self.assertEqual(x.element_spec, y.element_spec)
예제 #5
0
def _deserialize_sequence_value(
    sequence_value_proto: serialization_bindings.Sequence,
    type_hint: Optional[computation_types.Type] = None
) -> _DeserializeReturnType:
    """Deserializes a `tf.data.Dataset`.

    Args:
      sequence_value_proto: `Sequence` protocol buffer message.
      type_hint: A `computation_types.Type` that hints at what the value type
        should be for executors that only return values. If the
        `sequence_value_proto.element_type` field was not set, the `type_hint` is
        used instead. Only its `.element` attribute is read, so it is presumably
        a sequence type.

    Returns:
      A tuple of `(tf.data.Dataset, tff.Type)`.

    Raises:
      ValueError: If neither `sequence_value_proto.element_type` nor
        `type_hint` is provided.
      NotImplementedError: If the sequence's `value` oneof is neither
        `zipped_saved_model` nor `serialized_graph_def`.
    """
    # The proto's own element type takes precedence over the caller's hint.
    if sequence_value_proto.HasField('element_type'):
        element_type = type_serialization.deserialize_type(
            sequence_value_proto.element_type)
    elif type_hint is not None:
        element_type = type_hint.element
    else:
        raise ValueError(
            'Cannot deserialize a sequence Value proto without either the '
            '`element_type` proto field or a `type_hint`')
    which_value = sequence_value_proto.WhichOneof('value')
    if which_value == 'zipped_saved_model':
        # Legacy encoding: still accepted, but callers are told to migrate.
        warnings.warn(
            'Deserializing a sequence value that was encoded as a zipped SavedModel.'
            ' This is a deprecated path, please update the binary that is '
            'serializing the sequences.', DeprecationWarning)
        ds = _deserialize_dataset_from_zipped_saved_model(
            sequence_value_proto.zipped_saved_model)
        # The SavedModel path does not receive the TFF type, so element
        # containers are coerced to match `element_type` afterwards.
        ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
            ds, element_type)
    elif which_value == 'serialized_graph_def':
        # The graph-def helper is given `element_type` directly.
        ds = _deserialize_dataset_from_graph_def(
            sequence_value_proto.serialized_graph_def, element_type)
    else:
        raise NotImplementedError(
            'Deserializing Sequences encoded as {!s} has not been implemented'
            .format(which_value))
    return ds, computation_types.SequenceType(element=element_type)
예제 #6
0
 def test_coerce_dataset_elements_noop(self):
   # A dataset that already matches the requested TFF tensor type should come
   # out of coercion with an identical element_spec.
   original_ds = tf.data.Dataset.range(5)
   int64_type = computation_types.TensorType(tf.int64)
   coerced_ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
       original_ds, int64_type)
   expected_spec = original_ds.element_spec
   self.assertEqual(expected_spec, coerced_ds.element_spec)
예제 #7
0
def _deserialize_dataset_from_graph_def(serialized_graph_def: bytes,
                                        element_type: computation_types.Type):
    """Deserializes a serialized `tf.compat.v1.GraphDef` to a `tf.data.Dataset`.

    Args:
      serialized_graph_def: `bytes` object produced by
        `tensorflow_serialization.serialize_dataset`
      element_type: a `tff.Type` object representing the type structure of the
        elements yielded from the dataset.

    Returns:
      A `tf.data.Dataset` instance.

    Raises:
      TypeError: If `element_type` contains a struct whose fields are only
        partially named, since such a structure has no TF representation.
    """
    py_typecheck.check_type(element_type, computation_types.Type)
    type_analysis.check_tensorflow_compatible_type(element_type)

    def transform_to_tff_known_type(
        type_spec: computation_types.Type
    ) -> Tuple[computation_types.Type, bool]:
        """Transforms `StructType` to `StructWithPythonType`.

        Returns the (possibly converted) type and a flag that is True only when
        a conversion was performed.
        """
        if not type_spec.is_struct() or type_spec.is_struct_with_python():
            return type_spec, False
        # Materialize once; the elements are inspected and then reused.
        elements = list(structure.iter_elements(type_spec))
        field_is_named = [name is not None for name, _ in elements]
        if all(field_is_named):
            # Fully named (including empty) structs map to `OrderedDict`.
            container_type = collections.OrderedDict
        elif not any(field_is_named):
            container_type = tuple
        else:
            raise TypeError(
                'Cannot represent TFF type in TF because it contains '
                f'partially named structures. Type: {type_spec}')
        return computation_types.StructWithPythonType(
            elements=elements, container_type=container_type), True

    if element_type.is_struct():
        # TF doesn't support `structure.Struct` types, so we must transform the
        # `StructType` into a `StructWithPythonType` for use as the
        # `tf.data.Dataset.element_spec` later.
        tf_compatible_type, _ = type_transformations.transform_type_postorder(
            element_type, transform_to_tff_known_type)
    else:
        # We've checked this is only a struct or tensors, so we know this is a
        # `TensorType` here and will use as-is.
        tf_compatible_type = element_type

    def type_to_tensorspec(t: computation_types.TensorType) -> tf.TensorSpec:
        return tf.TensorSpec(shape=t.shape, dtype=t.dtype)

    element_spec = type_conversions.structure_from_tensor_type_tree(
        type_to_tensorspec, tf_compatible_type)
    ds = tf.data.experimental.from_variant(
        tf.raw_ops.DatasetFromGraph(graph_def=serialized_graph_def),
        structure=element_spec)
    # If a serialized dataset had elements of nested structures of tensors
    # (e.g. `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
    # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
    #
    # Since the dataset will only be used inside TFF, we wrap the dictionary
    # coming from TF in an `OrderedDict` when necessary (a type that both TF and
    # TFF understand), using the field order stored in the TFF type stored during
    # serialization.
    return tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        ds, tf_compatible_type)