Example #1
    def test_roundtrip_sequence_of_scalars(self):
        x = tf.data.Dataset.range(5).map(lambda x: x * 2)
        serialized_bytes = tensorflow_serialization.serialize_dataset(x)
        y = tensorflow_serialization.deserialize_dataset(serialized_bytes)

        self.assertEqual(x.element_spec, y.element_spec)
        self.assertAllEqual([y_val for y_val in y], [x * 2 for x in range(5)])
    def test_roundtrip_sequence_of_scalars(self):
        x = tf.data.Dataset.range(5).map(lambda x: x * 2)
        serialized_bytes = tensorflow_serialization.serialize_dataset(x)
        y = tensorflow_serialization.deserialize_dataset(serialized_bytes)

        self.assertEqual(tf.data.experimental.get_structure(x),
                         tf.data.experimental.get_structure(y))
        self.assertAllEqual([y_val for y_val in y], [x * 2 for x in range(5)])
Example #3
    def test_roundtrip_sequence_of_singleton_tuples(self):
        x = tf.data.Dataset.range(5).map(lambda x: (x, ))
        serialized_bytes = tensorflow_serialization.serialize_dataset(x)
        y = tensorflow_serialization.deserialize_dataset(serialized_bytes)

        self.assertEqual(x.element_spec, y.element_spec)
        expected_values = [(x, ) for x in range(5)]
        actual_values = self.evaluate([y_val for y_val in y])
        self.assertAllEqual(expected_values, actual_values)
Example #4
    def test_roundtrip_sequence_of_tuples(self):
        x = tf.data.Dataset.range(5).map(lambda x: (x * 2, tf.cast(
            x, tf.int32), tf.cast(x - 1, tf.float32)))
        serialized_bytes = tensorflow_serialization.serialize_dataset(x)
        y = tensorflow_serialization.deserialize_dataset(serialized_bytes)

        self.assertEqual(x.element_spec, y.element_spec)
        self.assertAllEqual(self.evaluate([y_val for y_val in y]),
                            [(x * 2, x, x - 1.) for x in range(5)])
    def test_roundtrip_sequence_of_nested_structures(self):
        test_tuple_type = collections.namedtuple('TestTuple', ['u', 'v'])

        def _make_nested_tf_structure(x):
            return collections.OrderedDict([
                ('b', tf.cast(x, tf.int32)),
                ('a',
                 tuple([
                     x,
                     test_tuple_type(x * 2, x * 3),
                     collections.OrderedDict([('x', x**2), ('y', x**3)])
                 ])),
            ])

        x = tf.data.Dataset.range(5).map(_make_nested_tf_structure)
        serialized_bytes = tensorflow_serialization.serialize_dataset(x)
        y = tensorflow_serialization.deserialize_dataset(serialized_bytes)

        # NOTE: TF loses the `OrderedDict` during serialization, so the expectation
        # here is for a `dict` in the result.
        self.assertEqual(
            tf.data.experimental.get_structure(y), {
                'b':
                tf.TensorSpec([], tf.int32),
                'a':
                tuple([
                    tf.TensorSpec([], tf.int64),
                    test_tuple_type(
                        tf.TensorSpec([], tf.int64),
                        tf.TensorSpec([], tf.int64),
                    ),
                    {
                        'x': tf.TensorSpec([], tf.int64),
                        'y': tf.TensorSpec([], tf.int64),
                    },
                ]),
            })

        def _build_expected_structure(x):
            return {
                'b':
                x,
                'a':
                tuple(
                    [x,
                     test_tuple_type(x * 2, x * 3), {
                         'x': x**2,
                         'y': x**3
                     }])
            }

        actual_values = self.evaluate([y_val for y_val in y])
        expected_values = [_build_expected_structure(x) for x in range(5)]
        for actual, expected in zip(actual_values, expected_values):
            self.assertAllClose(actual, expected)
Example #6
  def test_roundtrip_sequence_of_namedtuples(self):
    test_tuple_type = collections.namedtuple('TestTuple', ['a', 'b', 'c'])

    def make_test_tuple(x):
      return test_tuple_type(
          a=x * 2, b=tf.cast(x, tf.int32), c=tf.cast(x - 1, tf.float32))

    x = tf.data.Dataset.range(5).map(make_test_tuple)
    serialized_bytes = tensorflow_serialization.serialize_dataset(x)
    y = tensorflow_serialization.deserialize_dataset(serialized_bytes)

    self.assertEqual(x.element_spec, y.element_spec)
    self.assertAllEqual(
        self.evaluate([y_val for y_val in y]),
        [test_tuple_type(a=x * 2, b=x, c=x - 1.) for x in range(5)])
def serialize_sequence_value(value):
    """Serializes a `tf.data.Dataset` value into `executor_pb2.Value`.

    Args:
      value: A `tf.data.Dataset`, or equivalent.

    Returns:
      An instance of `executor_pb2.Value` with the serialized content of
      `value`.
    """
    py_typecheck.check_type(value, type_utils.TF_DATASET_REPRESENTATION_TYPES)
    # TFF must store the type spec here because TF will lose the ordering of the
    # names for `tf.data.Dataset` that return elements of `collections.Mapping`
    # type. This allows TFF to preserve and restore the key ordering upon
    # deserialization.
    element_type = computation_types.to_type(value.element_spec)
    return executor_pb2.Value(sequence=executor_pb2.Value.Sequence(
        zipped_saved_model=tensorflow_serialization.serialize_dataset(value),
        element_type=type_serialization.serialize_type(element_type)))
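
The comment above explains why the element type is stored next to the zipped dataset. As a hedged usage sketch only (the helper and the TFF modules it relies on, such as `executor_pb2` and `tensorflow_serialization`, are assumed to be in scope as in the original file; this is an illustration, not a documented API):

import collections
import tensorflow as tf

# A dataset whose elements are OrderedDicts with a deliberately
# non-alphabetical key order ('b' before 'a').
ds = tf.data.Dataset.range(3).map(
    lambda x: collections.OrderedDict([('b', tf.cast(x, tf.int32)), ('a', x)]))

value_proto = serialize_sequence_value(ds)
# `sequence.zipped_saved_model` carries the serialized dataset itself, while
# `sequence.element_type` carries the TFF type with the original key ordering,
# which is what lets deserialization restore 'b' before 'a' even though TF's
# own serialization comes back with a plain `dict`.
print(value_proto.sequence.element_type)
print(len(value_proto.sequence.zipped_saved_model))
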
Example #8
def _serialize_sequence_value(
    value: Union[type_conversions.TF_DATASET_REPRESENTATION_TYPES],
    type_spec: computation_types.SequenceType) -> _SerializeReturnType:
  """Serializes a `tf.data.Dataset` value into `executor_pb2.Value`.

  Args:
    value: A `tf.data.Dataset`, or equivalent.
    type_spec: A `computation_types.Type` specifying the TFF sequence type of
      `value`.

  Returns:
    A tuple `(value_proto, type_spec)` in which `value_proto` is an instance
    of `executor_pb2.Value` with the serialized content of `value`, and
    `type_spec` is the type of the serialized value.
  """
  if not isinstance(value, type_conversions.TF_DATASET_REPRESENTATION_TYPES):
    raise TypeError(
        'Cannot serialize Python type {!s} as TFF type {!s}.'.format(
            py_typecheck.type_string(type(value)),
            type_spec if type_spec is not None else 'unknown'))

  value_type = computation_types.SequenceType(
      computation_types.to_type(value.element_spec))
  if not type_spec.is_assignable_from(value_type):
    raise TypeError(
        'Cannot serialize dataset with elements of type {!s} as TFF type {!s}.'
        .format(value_type, type_spec if type_spec is not None else 'unknown'))

  # TFF must store the type spec here because TF will lose the ordering of the
  # names for `tf.data.Dataset` that return elements of `collections.Mapping`
  # type. This allows TFF to preserve and restore the key ordering upon
  # deserialization.
  element_type = computation_types.to_type(value.element_spec)
  return executor_pb2.Value(
      sequence=executor_pb2.Value.Sequence(
          zipped_saved_model=tensorflow_serialization.serialize_dataset(value),
          element_type=type_serialization.serialize_type(
              element_type))), type_spec
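
For illustration, a hedged sketch of how the type check in the snippet above behaves (it assumes the TFF modules imported in the original file, e.g. `computation_types`, are in scope; the variable names are made up for this example):

import tensorflow as tf

# Elements of `tf.data.Dataset.range` are int64 scalars.
dataset = tf.data.Dataset.range(5)
int64_sequence = computation_types.SequenceType(
    computation_types.TensorType(tf.int64))
int32_sequence = computation_types.SequenceType(
    computation_types.TensorType(tf.int32))

# Succeeds: the dataset's element type is assignable to the declared sequence type.
value_proto, _ = _serialize_sequence_value(dataset, int64_sequence)

# Raises TypeError: int64 elements cannot be serialized as an int32 sequence.
# _serialize_sequence_value(dataset, int32_sequence)
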
    def test_serialize_sequence_bytes_too_large(self):
        with self.assertRaisesRegex(
                ValueError, r'Serialized size .* exceeds maximum allowed'):
            _ = tensorflow_serialization.serialize_dataset(
                tf.data.Dataset.range(5), max_serialized_size_bytes=0)
    def test_serialize_sequence_not_a_dataset(self):
        with self.assertRaisesRegex(TypeError,
                                    r'Expected .*Dataset.* found int'):
            _ = tensorflow_serialization.serialize_dataset(5)
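
The two tests above pin down the observable contract of `tensorflow_serialization.serialize_dataset`: it accepts only a `tf.data.Dataset`, returns bytes, and enforces an optional `max_serialized_size_bytes` limit. A minimal stand-in sketch with the same contract (this is not TFF's implementation; the save-then-zip approach and the `serialize_dataset_sketch` name are assumptions made here for illustration):

import io
import os
import tempfile
import zipfile

import tensorflow as tf


def serialize_dataset_sketch(dataset, max_serialized_size_bytes=None):
    """Stand-in for illustration: serializes `dataset` into a single bytes blob."""
    if not isinstance(dataset, tf.data.Dataset):
        raise TypeError('Expected tf.data.Dataset, found {}.'.format(
            type(dataset).__name__))
    buffer = io.BytesIO()
    with tempfile.TemporaryDirectory() as tmp:
        # Persist the dataset to disk, then zip the directory into memory.
        tf.data.experimental.save(dataset, tmp)
        with zipfile.ZipFile(buffer, 'w') as zf:
            for root, _, files in os.walk(tmp):
                for name in files:
                    path = os.path.join(root, name)
                    zf.write(path, arcname=os.path.relpath(path, tmp))
    serialized = buffer.getvalue()
    if (max_serialized_size_bytes is not None
            and len(serialized) > max_serialized_size_bytes):
        raise ValueError('Serialized size {} exceeds maximum allowed {}.'.format(
            len(serialized), max_serialized_size_bytes))
    return serialized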