Beispiel #1
0
def _deserialize_sequence_value(
    sequence_value_proto: executor_pb2.Value.Sequence
) -> _DeserializeReturnType:
  """Deserializes a `tf.data.Dataset`.

  Only sequences stored in the `zipped_saved_model` field of the `value`
  oneof are supported.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.

  Returns:
    A tuple of `(tf.data.Dataset, tff.Type)`, where the type is a
    `computation_types.SequenceType` whose element type is deserialized from
    `sequence_value_proto.element_type`.

  Raises:
    NotImplementedError: If the `value` oneof holds anything other than
      `zipped_saved_model` (including when it is unset).
  """
  which_value = sequence_value_proto.WhichOneof('value')
  if which_value == 'zipped_saved_model':
    ds = _deserialize_dataset(sequence_value_proto.zipped_saved_model)
  else:
    raise NotImplementedError(
        'Deserializing Sequences encoded as {!s} has not been implemented'
        .format(which_value))

  element_type = type_serialization.deserialize_type(
      sequence_value_proto.element_type)

  # If a serialized dataset had elements of nested structures of tensors (e.g.
  # `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
  # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
  #
  # Since the dataset will only be used inside TFF, we wrap the dictionary
  # coming from TF in an `OrderedDict` when necessary (a type that both TF and
  # TFF understand), using the field order stored in the TFF type stored during
  # serialization.
  ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      ds, element_type)

  return ds, computation_types.SequenceType(element=element_type)
Beispiel #2
0
def _deserialize_sequence_value(
    sequence_value_proto: executor_pb2.Value.Sequence
) -> _DeserializeReturnType:
    """Deserializes a `tf.data.Dataset`.

    Two encodings of the `value` oneof are supported:
      * `zipped_saved_model` (deprecated; emits a `DeprecationWarning`), and
      * `serialized_graph_def`.

    Args:
      sequence_value_proto: `Sequence` protocol buffer message.

    Returns:
      A tuple of `(tf.data.Dataset, tff.Type)`, where the type is a
      `computation_types.SequenceType` whose element type is deserialized
      from `sequence_value_proto.element_type`.

    Raises:
      NotImplementedError: If the `value` oneof holds any other field (or is
        unset).
    """
    element_type = type_serialization.deserialize_type(
        sequence_value_proto.element_type)
    which_value = sequence_value_proto.WhichOneof('value')
    if which_value == 'zipped_saved_model':
        warnings.warn(
            'Deserializing a sequence value that was encoded as a zipped SavedModel.'
            ' This is a deprecated path, please update the binary that is '
            'serializing the sequences.', DeprecationWarning)
        ds = _deserialize_dataset_from_zipped_saved_model(
            sequence_value_proto.zipped_saved_model)
        # The SavedModel round trip can turn `OrderedDict` elements into plain
        # containers; coerce them back to the structure recorded in the TFF
        # element type.
        ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
            ds, element_type)
    elif which_value == 'serialized_graph_def':
        # The element type is passed to the helper directly, so no separate
        # coercion step is applied on this path.
        ds = _deserialize_dataset_from_graph_def(
            sequence_value_proto.serialized_graph_def, element_type)
    else:
        raise NotImplementedError(
            'Deserializing Sequences encoded as {!s} has not been implemented'
            .format(which_value))
    return ds, computation_types.SequenceType(element=element_type)
def _deserialize_sequence_value(
    sequence_value_proto: executor_pb2.Value.Sequence,
    type_hint: Optional[computation_types.Type] = None
) -> _DeserializeReturnType:
    """Deserializes a `tf.data.Dataset`.

    The element type is taken from `sequence_value_proto.element_type` when
    that field is set; otherwise it is taken from `type_hint.element`.

    Two encodings of the `value` oneof are supported:
      * `zipped_saved_model` (deprecated; emits a `DeprecationWarning`), and
      * `serialized_graph_def`.

    Args:
      sequence_value_proto: `Sequence` protocol buffer message.
      type_hint: A `computation_types.Type` that hints at what the value type
        should be for executors that only return values. If the
        `sequence_value_proto.element_type` field was not set, the
        `type_hint` is used instead. It must expose an `.element` attribute
        (e.g. a `computation_types.SequenceType`).

    Returns:
      A tuple of `(tf.data.Dataset, tff.Type)`, where the type is a
      `computation_types.SequenceType` over the resolved element type.

    Raises:
      ValueError: If neither `element_type` is set on the proto nor a
        `type_hint` is given.
      NotImplementedError: If the `value` oneof holds any other field (or is
        unset).
    """
    if sequence_value_proto.HasField('element_type'):
        element_type = type_serialization.deserialize_type(
            sequence_value_proto.element_type)
    elif type_hint is not None:
        element_type = type_hint.element
    else:
        raise ValueError(
            'Cannot deserialize a sequence Value proto without either the '
            '`element_type` proto field or a `type_hint`')
    which_value = sequence_value_proto.WhichOneof('value')
    if which_value == 'zipped_saved_model':
        warnings.warn(
            'Deserializing a sequence value that was encoded as a zipped SavedModel.'
            ' This is a deprecated path, please update the binary that is '
            'serializing the sequences.', DeprecationWarning)
        ds = _deserialize_dataset_from_zipped_saved_model(
            sequence_value_proto.zipped_saved_model)
        # The SavedModel round trip can turn `OrderedDict` elements into plain
        # containers; coerce them back to the structure recorded in the TFF
        # element type.
        ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
            ds, element_type)
    elif which_value == 'serialized_graph_def':
        # The element type is passed to the helper directly, so no separate
        # coercion step is applied on this path.
        ds = _deserialize_dataset_from_graph_def(
            sequence_value_proto.serialized_graph_def, element_type)
    else:
        raise NotImplementedError(
            'Deserializing Sequences encoded as {!s} has not been implemented'
            .format(which_value))
    return ds, computation_types.SequenceType(element=element_type)