コード例 #1
0
 def test_roundtrip_sequence_of_scalars(self):
   """Serializes a dataset of scalars and checks the deserialized copy."""
   dataset = tf.data.Dataset.range(5).map(lambda v: v * 2)
   serialized = value_serialization._serialize_dataset(dataset)
   element_type = computation_types.to_type(dataset.element_spec)
   restored = value_serialization._deserialize_dataset_from_graph_def(
       serialized, element_type=element_type)
   # The element spec and every element must survive the round trip.
   self.assertEqual(dataset.element_spec, restored.element_spec)
   self.assertAllEqual(list(restored), [i * 2 for i in range(5)])
コード例 #2
0
 def test_roundtrip_sequence_of_tuples(self):
   """Serializes a dataset of mixed-dtype 3-tuples and checks the copy."""

   def to_triple(v):
     # (int64, int32, float32) element built from the int64 range value.
     return (v * 2, tf.cast(v, tf.int32), tf.cast(v - 1, tf.float32))

   dataset = tf.data.Dataset.range(5).map(to_triple)
   serialized = value_serialization._serialize_dataset(dataset)
   element_type = computation_types.to_type(dataset.element_spec)
   restored = value_serialization._deserialize_dataset_from_graph_def(
       serialized, element_type=element_type)
   self.assertEqual(dataset.element_spec, restored.element_spec)
   self.assertAllEqual(list(restored), [(i * 2, i, i - 1.) for i in range(5)])
コード例 #3
0
 def test_roundtrip_sequence_of_singleton_tuples(self):
   """Serializes a dataset of one-element tuples and checks the copy."""
   dataset = tf.data.Dataset.range(5).map(lambda v: (v,))
   serialized = value_serialization._serialize_dataset(dataset)
   restored = value_serialization._deserialize_dataset_from_graph_def(
       serialized,
       element_type=computation_types.to_type(dataset.element_spec))
   self.assertEqual(dataset.element_spec, restored.element_spec)
   # Expected values first, then the materialized restored elements.
   self.assertAllEqual([(i,) for i in range(5)], list(restored))
コード例 #4
0
  def test_roundtrip_sequence_of_namedtuples(self):
    """Serializes a dataset of `namedtuple` elements and checks the copy."""
    test_tuple_type = collections.namedtuple('TestTuple', ['a', 'b', 'c'])

    def to_test_tuple(v):
      # Each field gets a different dtype: int64, int32 and float32.
      return test_tuple_type(
          a=v * 2, b=tf.cast(v, tf.int32), c=tf.cast(v - 1, tf.float32))

    dataset = tf.data.Dataset.range(5).map(to_test_tuple)
    serialized = value_serialization._serialize_dataset(dataset)
    element_type = computation_types.to_type(dataset.element_spec)
    restored = value_serialization._deserialize_dataset_from_graph_def(
        serialized, element_type=element_type)
    self.assertEqual(dataset.element_spec, restored.element_spec)
    expected = [test_tuple_type(a=i * 2, b=i, c=i - 1.) for i in range(5)]
    self.assertAllEqual(list(restored), expected)
コード例 #5
0
  def test_roundtrip_sequence_of_nested_structures(self):
    """Serializes a dataset of nested dict/tuple/namedtuple elements.

    Checks both the element spec and every element value of the
    deserialized dataset.
    """
    test_tuple_type = collections.namedtuple('TestTuple', ['u', 'v'])

    def _make_nested_tf_structure(x):
      return collections.OrderedDict(
          b=tf.cast(x, tf.int32),
          a=(
              x,
              test_tuple_type(x * 2, x * 3),
              collections.OrderedDict(x=x**2, y=x**3),
          ))

    x = tf.data.Dataset.range(5).map(_make_nested_tf_structure)
    serialized_bytes = value_serialization._serialize_dataset(x)
    y = value_serialization._deserialize_dataset_from_graph_def(
        serialized_bytes,
        element_type=computation_types.to_type(x.element_spec))
    # Note: TF may lose the `OrderedDict` during serialization and return a
    # plain `dict`. Equality between an `OrderedDict` and a plain `dict`
    # ignores key order, so this `OrderedDict` expectation matches either.
    expected_element_spec = collections.OrderedDict(
        b=tf.TensorSpec([], tf.int32),
        a=(tf.TensorSpec([], tf.int64),
           test_tuple_type(
               tf.TensorSpec([], tf.int64), tf.TensorSpec([], tf.int64)),
           collections.OrderedDict(
               x=tf.TensorSpec([], tf.int64), y=tf.TensorSpec([], tf.int64))))
    self.assertEqual(y.element_spec, expected_element_spec)

    def _build_expected_structure(x):
      return collections.OrderedDict(
          b=x,
          a=(
              x,
              test_tuple_type(x * 2, x * 3),
              collections.OrderedDict(x=x**2, y=x**3),
          ))

    expected_values = [_build_expected_structure(i) for i in range(5)]
    actual_values = list(y)
    # Check the element count first: `zip` stops at the shorter sequence, so
    # without this a truncated or empty dataset would pass vacuously.
    self.assertEqual(len(actual_values), len(expected_values))
    for actual, expected in zip(actual_values, expected_values):
      self.assertAllClose(actual, expected)
コード例 #6
0
 def test_serialize_sequence_bytes_too_large(self):
   """A zero-byte size limit makes `_serialize_dataset` raise ValueError."""
   size_error_pattern = r'Serialized size .* exceeds maximum allowed'
   with self.assertRaisesRegex(ValueError, size_error_pattern):
     value_serialization._serialize_dataset(
         tf.data.Dataset.range(5), max_serialized_size_bytes=0)
コード例 #7
0
 def test_serialize_sequence_not_a_dataset(self):
   """Passing a plain `int` instead of a dataset raises TypeError."""
   type_error_pattern = r'Expected .*Dataset.* found int'
   with self.assertRaisesRegex(TypeError, type_error_pattern):
     value_serialization._serialize_dataset(5)