Exemple #1
0
    def test_roundtrip_sequence_of_scalars(self):
        """Serializes then deserializes a scalar dataset; spec and values must match."""
        source_ds = tf.data.Dataset.range(5).map(lambda v: v * 2)
        restored_ds = tensorflow_serialization.deserialize_dataset(
            tensorflow_serialization.serialize_dataset(source_ds))

        self.assertEqual(source_ds.element_spec, restored_ds.element_spec)
        restored_values = list(restored_ds)
        expected_values = [i * 2 for i in range(5)]
        self.assertAllEqual(restored_values, expected_values)
Exemple #2
0
def deserialize_sequence_value(sequence_value_proto):
    """Deserializes a `tf.data.Dataset` from a `Sequence` value proto.

    Args:
      sequence_value_proto: An `executor_pb2.Value.Sequence` protocol buffer
        message.

    Returns:
      A tuple of `(tf.data.Dataset, tff.Type)`. The type is a
      `computation_types.SequenceType` whose element type is deserialized from
      `sequence_value_proto.element_type`.

    Raises:
      NotImplementedError: If the sequence's `value` oneof is set to anything
        other than `zipped_saved_model`.
      Whatever `py_typecheck.check_type` raises (presumably `TypeError`) when
        the argument is not an `executor_pb2.Value.Sequence`.
    """
    py_typecheck.check_type(sequence_value_proto, executor_pb2.Value.Sequence)

    # Only the zipped SavedModel encoding is supported; reject anything else
    # up front so the happy path below stays flat.
    which_value = sequence_value_proto.WhichOneof('value')
    if which_value != 'zipped_saved_model':
        raise NotImplementedError(
            'Deserializing Sequences encoded as {!s} has not been implemented'
            .format(which_value))

    ds = tensorflow_serialization.deserialize_dataset(
        sequence_value_proto.zipped_saved_model)

    element_type = type_serialization.deserialize_type(
        sequence_value_proto.element_type)

    # If a serialized dataset had elements of nested structures of tensors
    # (e.g. `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
    # `tuple`, or `namedtuple` (`collections.OrderedDict` is lost in the
    # conversion).
    #
    # Since the dataset will only be used inside TFF, we wrap the dictionary
    # coming from TF in an `OrderedDict` when necessary (a type that both TF and
    # TFF understand), using the field order recorded in the TFF element type
    # at serialization time.
    ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        ds, element_type)

    return ds, computation_types.SequenceType(element=element_type)
    def test_roundtrip_sequence_of_scalars(self):
        """Round-trips a dataset of scalars through (de)serialization.

        Structures are compared via `Dataset.element_spec`, which supersedes the
        deprecated `tf.data.experimental.get_structure`.
        """
        x = tf.data.Dataset.range(5).map(lambda v: v * 2)
        serialized_bytes = tensorflow_serialization.serialize_dataset(x)
        y = tensorflow_serialization.deserialize_dataset(serialized_bytes)

        self.assertEqual(x.element_spec, y.element_spec)
        self.assertAllEqual(list(y), [i * 2 for i in range(5)])
Exemple #4
0
    def test_roundtrip_sequence_of_singleton_tuples(self):
        """Checks that 1-tuple dataset elements keep their spec and values."""
        original = tf.data.Dataset.range(5).map(lambda v: (v,))
        payload = tensorflow_serialization.serialize_dataset(original)
        restored = tensorflow_serialization.deserialize_dataset(payload)

        self.assertEqual(original.element_spec, restored.element_spec)
        self.assertAllEqual([(i,) for i in range(5)],
                            self.evaluate(list(restored)))
Exemple #5
0
    def test_roundtrip_sequence_of_tuples(self):
        """Checks that mixed-dtype 3-tuple elements survive a round trip.

        Each element is (doubled range value, value cast to int32, value minus
        one cast to float32).
        """

        def _to_triple(v):
            return (v * 2, tf.cast(v, tf.int32), tf.cast(v - 1, tf.float32))

        original = tf.data.Dataset.range(5).map(_to_triple)
        payload = tensorflow_serialization.serialize_dataset(original)
        restored = tensorflow_serialization.deserialize_dataset(payload)

        self.assertEqual(original.element_spec, restored.element_spec)
        actual = self.evaluate(list(restored))
        expected = [(i * 2, i, i - 1.) for i in range(5)]
        self.assertAllEqual(actual, expected)
    def test_roundtrip_sequence_of_nested_structures(self):
        """Round-trips a dataset whose elements nest dicts, tuples, namedtuples.

        TF does not preserve `collections.OrderedDict` across serialization, so
        the deserialized element spec and values are expected to contain plain
        `dict`s where the source had `OrderedDict`s.
        """
        test_tuple_type = collections.namedtuple('TestTuple', ['u', 'v'])

        def _make_nested_tf_structure(x):
            return collections.OrderedDict([
                ('b', tf.cast(x, tf.int32)),
                ('a', (
                    x,
                    test_tuple_type(x * 2, x * 3),
                    collections.OrderedDict([('x', x**2), ('y', x**3)]),
                )),
            ])

        x = tf.data.Dataset.range(5).map(_make_nested_tf_structure)
        serialized_bytes = tensorflow_serialization.serialize_dataset(x)
        y = tensorflow_serialization.deserialize_dataset(serialized_bytes)

        # NOTE: TF loses the `OrderedDict` during serialization, so the
        # expectation here is for a `dict` in the result.
        int64_spec = tf.TensorSpec([], tf.int64)
        self.assertEqual(
            y.element_spec, {
                'b': tf.TensorSpec([], tf.int32),
                'a': (
                    int64_spec,
                    test_tuple_type(int64_spec, int64_spec),
                    {'x': int64_spec, 'y': int64_spec},
                ),
            })

        def _build_expected_structure(x):
            return {
                'b': x,
                'a': (x, test_tuple_type(x * 2, x * 3), {'x': x**2, 'y': x**3}),
            }

        actual_values = self.evaluate(list(y))
        expected_values = [_build_expected_structure(x) for x in range(5)]
        # `zip` silently truncates, so pin the element count explicitly;
        # otherwise a dataset that lost elements would still pass.
        self.assertEqual(len(actual_values), len(expected_values))
        for actual, expected in zip(actual_values, expected_values):
            self.assertAllClose(actual, expected)
Exemple #7
0
  def test_roundtrip_sequence_of_namedtuples(self):
    """Checks that namedtuple dataset elements keep their spec and values."""
    test_tuple_type = collections.namedtuple('TestTuple', ['a', 'b', 'c'])

    original = tf.data.Dataset.range(5).map(
        lambda v: test_tuple_type(
            a=v * 2, b=tf.cast(v, tf.int32), c=tf.cast(v - 1, tf.float32)))
    restored = tensorflow_serialization.deserialize_dataset(
        tensorflow_serialization.serialize_dataset(original))

    self.assertEqual(original.element_spec, restored.element_spec)
    actual = self.evaluate(list(restored))
    expected = [test_tuple_type(a=i * 2, b=i, c=i - 1.) for i in range(5)]
    self.assertAllEqual(actual, expected)