示例#1
0
    def testNestedNestedStructure(self):
        """Checks that a doubly nested tuple round-trips by tensor identity.

        The structure is `(int64, (float32, string))`. Flattening must yield
        the very same tensor objects, in that order, and both
        `from_tensor_list` and `from_compatible_tensor_list` must hand the
        same objects back in the original nesting.
        """
        spec = (
            tensor_spec.TensorSpec([], dtypes.int64),
            (
                tensor_spec.TensorSpec([], dtypes.float32),
                tensor_spec.TensorSpec([], dtypes.string),
            ),
        )

        int64_t = constant_op.constant(37, dtype=dtypes.int64)
        float32_t = constant_op.constant(42.0)
        string_t = constant_op.constant("Foo")
        leaves = [int64_t, float32_t, string_t]

        flat = structure.to_tensor_list(spec, (int64_t, (float32_t, string_t)))
        for want, got in zip(leaves, flat):
            self.assertIs(want, got)

        # Both restore entry points are checked against the same expectations.
        for restore in (structure.from_tensor_list,
                        structure.from_compatible_tensor_list):
            got_int64, (got_float32, got_string) = restore(spec, flat)
            self.assertIs(int64_t, got_int64)
            self.assertIs(float32_t, got_float32)
            self.assertIs(string_t, got_string)
示例#2
0
    def testRoundTripConversion(self, value_fn):
        """Checks that flattening a value and restoring it yields an equal value.

        Args:
          value_fn: Zero-argument callable returning the value to round-trip.
        """
        value = value_fn()
        s = structure.type_spec_from_value(value)

        def maybe_stack_ta(v):
            # A TensorArray is stacked into a tensor before it is evaluated.
            if isinstance(v, tensor_array_ops.TensorArray):
                return v.stack()
            return v

        before = self.evaluate(maybe_stack_ta(value))
        round_tripped = structure.from_tensor_list(
            s, structure.to_tensor_list(s, value))
        after = self.evaluate(maybe_stack_ta(round_tripped))

        flat_before = nest.flatten(before)
        flat_after = nest.flatten(after)
        # `zip` stops at the shorter input, so without this check a round trip
        # that dropped components would still pass.
        self.assertEqual(len(flat_before), len(flat_after))
        for b, a in zip(flat_before, flat_after):
            if isinstance(b, sparse_tensor.SparseTensorValue):
                # Sparse values are compared component by component.
                self.assertAllEqual(b.indices, a.indices)
                self.assertAllEqual(b.values, a.values)
                self.assertAllEqual(b.dense_shape, a.dense_shape)
            else:
                # Ragged and all remaining values use the same comparison.
                self.assertAllEqual(b, a)
示例#3
0
 def testPreserveTensorArrayShape(self):
   """Checks that an explicit `element_shape` survives a flatten/restore."""
   original = tensor_array_ops.TensorArray(
       dtype=dtypes.int32, size=1, element_shape=(3,))
   spec = structure.type_spec_from_value(original)
   flat = structure.to_tensor_list(spec, original)
   restored = structure.from_tensor_list(spec, flat)
   self.assertEqual(restored.element_shape.as_list(), [3])
示例#4
0
    def testToBatchedTensorList(self, value_fn, element_0_fn):
        """Checks `to_batched_tensor_list` and the 0th unbatched element.

        Args:
          value_fn: Zero-argument callable returning the batched value.
          element_0_fn: Zero-argument callable returning the value expected at
            position 0 of the batch.
        """
        batched_value = value_fn()
        s = structure.type_spec_from_value(batched_value)
        batched_tensor_list = structure.to_batched_tensor_list(
            s, batched_value)

        # The batch dimension is 2 for all of the test cases.
        # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
        # tensors in which we store sparse tensors.
        for t in batched_tensor_list:
            if t.dtype != dtypes.variant:
                self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))

        # Test that the 0th element from the unbatched tensor is equal to the
        # expected value.
        expected_element_0 = self.evaluate(element_0_fn())
        unbatched_s = nest.map_structure(
            lambda component_spec: component_spec._unbatch(), s)
        actual_element_0 = structure.from_tensor_list(
            unbatched_s, [t[0] for t in batched_tensor_list])

        flat_expected = nest.flatten(expected_element_0)
        flat_actual = nest.flatten(actual_element_0)
        # `zip` stops at the shorter input, so a missing component would
        # otherwise go unnoticed.
        self.assertEqual(len(flat_expected), len(flat_actual))
        for expected, actual in zip(flat_expected, flat_actual):
            if sparse_tensor.is_sparse(expected):
                self.assertSparseValuesEqual(expected, actual)
            else:
                # Ragged and all remaining values use the same comparison.
                self.assertAllEqual(expected, actual)
示例#5
0
def at(dataset, index):
    """Returns the element at a specific index in a dataset.

    Args:
      dataset: The `tf.data.Dataset` to fetch the element from.
      index: The index at which to fetch the element.

    Returns:
      A (nested) structure of values matching `tf.data.Dataset.element_spec`.

    Raises:
      UnimplementedError: If random access is not yet supported for a dataset.
        Currently, random access is supported for the following tf.data ops:
        `tf.data.Dataset.from_tensor_slices`, `tf.data.Dataset.shuffle`,
        `tf.data.Dataset.batch`, `tf.data.Dataset.shard`,
        `tf.data.Dataset.map`, `tf.data.Dataset.range`,
        `tf.data.Dataset.skip`, and `tf.data.Dataset.repeat`.
    """
    # pylint: disable=protected-access
    spec = dataset.element_spec
    # The op returns a flat list of components; the spec supplies the flat
    # dtypes/shapes it needs and is then used to rebuild the nesting.
    flat_components = gen_experimental_dataset_ops.get_element_at_index(
        dataset._variant_tensor,
        index,
        output_types=structure.get_flat_tensor_types(spec),
        output_shapes=structure.get_flat_tensor_shapes(spec))
    return structure.from_tensor_list(spec, flat_components)
示例#6
0
 def testPreserveInferredTensorArrayShape(self):
     """Checks that an element shape inferred from a write is preserved."""
     # No `element_shape` is given; the write below is what fixes it.
     written = tensor_array_ops.TensorArray(
         dtype=dtypes.int32, size=1).write(0, [1, 2, 3])
     spec = structure.type_spec_from_value(written)
     flat = structure.to_tensor_list(spec, written)
     restored = structure.from_tensor_list(spec, flat)
     self.assertEqual(restored.element_shape.as_list(), [3])
示例#7
0
    def get_next(self, name=None):
        """Returns the next element.

        In graph mode this method is typically called *once*, and its result
        is used as the input to another computation. A typical loop then calls
        `tf.Session.run` on the result of that computation and terminates when
        the `Iterator.get_next()` operation raises
        `tf.errors.OutOfRangeError`. Skeleton of a training loop:

        ```python
        dataset = ...  # A `tf.data.Dataset` object.
        iterator = dataset.make_initializable_iterator()
        next_element = iterator.get_next()

        # Build a TensorFlow graph that does something with each element.
        loss = model_function(next_element)
        optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
        train_op = optimizer.minimize(loss)

        with tf.compat.v1.Session() as sess:
          try:
            while True:
              sess.run(train_op)
          except tf.errors.OutOfRangeError:
            pass
        ```

        NOTE: Calling `Iterator.get_next()` several times is legitimate, e.g.
        when distributing different elements to multiple devices in a single
        step. A common pitfall, however, is calling it in every iteration of a
        training loop: each call adds ops to the graph, and executing each op
        allocates resources (including threads), which leads to slowdown and
        eventual resource exhaustion. To guard against this, a warning is
        issued once the number of calls crosses a fixed threshold.

        Args:
          name: (Optional.) A name for the created operation.

        Returns:
          A (nested) structure of values matching
          `tf.data.Iterator.element_spec`.
        """
        self._get_next_call_count += 1
        too_many_calls = (
            self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD)
        if too_many_calls:
            warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

        resource = self._iterator_resource
        # TODO(b/169442955): Investigate the need for this colocation constraint.
        with ops.colocate_with(resource):
            # pylint: disable=protected-access
            flat_components = gen_dataset_ops.iterator_get_next(
                resource,
                output_types=self._flat_tensor_types,
                output_shapes=self._flat_tensor_shapes,
                name=name)
            # Restructuring also happens inside the colocation scope.
            return structure.from_tensor_list(
                self._element_spec, flat_components)
示例#8
0
  def preserveStaticShape(self):
    """Checks static component shapes after a flatten/restore round trip.

    Covers a ragged tensor and a sparse tensor.

    NOTE(review): this name lacks the `test` prefix, so a default unittest
    loader would presumably not collect it; confirm whether that is
    intentional before renaming.
    """
    ragged = ragged_factory_ops.constant([[1, 2], [], [3]])
    ragged_spec = structure.type_spec_from_value(ragged)
    ragged_flat = structure.to_tensor_list(ragged_spec, ragged)
    ragged_after = structure.from_tensor_list(ragged_spec, ragged_flat)
    self.assertEqual(ragged_after.row_splits.shape.as_list(),
                     ragged.row_splits.shape.as_list())
    self.assertEqual(ragged_after.values.shape.as_list(), [None])

    sparse = sparse_tensor.SparseTensor(
        indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
    sparse_spec = structure.type_spec_from_value(sparse)
    sparse_flat = structure.to_tensor_list(sparse_spec, sparse)
    sparse_after = structure.from_tensor_list(sparse_spec, sparse_flat)
    self.assertEqual(sparse_after.indices.shape.as_list(), [None, 2])
    self.assertEqual(sparse_after.values.shape.as_list(), [None])
    self.assertEqual(sparse_after.dense_shape.shape.as_list(),
                     sparse.dense_shape.shape.as_list())
示例#9
0
def at(dataset, index):
    """Returns the element at a specific index in a dataset.

    Currently, random access is supported for the following tf.data
    operations:

      - `tf.data.Dataset.from_tensor_slices`,
      - `tf.data.Dataset.from_tensors`,
      - `tf.data.Dataset.shuffle`,
      - `tf.data.Dataset.batch`,
      - `tf.data.Dataset.shard`,
      - `tf.data.Dataset.map`,
      - `tf.data.Dataset.range`,
      - `tf.data.Dataset.zip`,
      - `tf.data.Dataset.skip`,
      - `tf.data.Dataset.repeat`,
      - `tf.data.Dataset.list_files`,
      - `tf.data.Dataset.SSTableDataset`,
      - `tf.data.Dataset.concatenate`,
      - `tf.data.Dataset.enumerate`,
      - `tf.data.Dataset.parallel_map`,
      - `tf.data.Dataset.prefetch`,
      - `tf.data.Dataset.take`,
      - `tf.data.Dataset.cache` (in-memory only)

    The cache operation can be used to enable random access for any dataset,
    even one built from transformations that are not on this list. E.g., to
    fetch the element at index 3 of a TFDS dataset:

    ```python
    ds = tfds.load("mnist", split="train").cache()
    elem = tf.data.Dataset.experimental.at(ds, 3)
    ```

    Args:
      dataset: The `tf.data.Dataset` to fetch the element from.
      index: The index at which to fetch the element.

    Returns:
      A (nested) structure of values matching `tf.data.Dataset.element_spec`.

    Raises:
      UnimplementedError: If random access is not yet supported for a dataset.
    """
    # pylint: disable=protected-access
    spec = dataset.element_spec
    # The op returns a flat list of components; the spec supplies the flat
    # dtypes/shapes it needs and is then used to rebuild the nesting.
    flat_components = gen_experimental_dataset_ops.get_element_at_index(
        dataset._variant_tensor,
        index,
        output_types=structure.get_flat_tensor_types(spec),
        output_shapes=structure.get_flat_tensor_shapes(spec))
    return structure.from_tensor_list(spec, flat_components)
示例#10
0
 def get_value(self, name=None):
     """Returns the value held by this optional, packed per its structure.

     Args:
       name: (Optional.) Name handed to `ops.name_scope` for the created op.

     Returns:
       The flat outputs of `optional_get_value`, restructured according to
       `self._value_structure`.
     """
     # TODO(b/110122868): Consolidate the restructuring logic with similar logic
     # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
     with ops.name_scope(name, "OptionalGetValue",
                         [self._variant_tensor]) as scope:
         value_structure = self._value_structure
         flat_components = gen_dataset_ops.optional_get_value(
             self._variant_tensor,
             name=scope,
             output_types=structure.get_flat_tensor_types(value_structure),
             output_shapes=structure.get_flat_tensor_shapes(value_structure))
         return structure.from_tensor_list(value_structure, flat_components)
示例#11
0
 def get_value(self, name=None):
   """Returns the value held by this optional, packed per its element spec.

   Args:
     name: (Optional.) Name handed to `ops.name_scope` for the created op.

   Returns:
     The flat outputs of `optional_get_value`, restructured according to
     `self._element_spec`.
   """
   # TODO(b/110122868): Consolidate the restructuring logic with similar logic
   # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
   with ops.name_scope(name, "OptionalGetValue",
                       [self._variant_tensor]) as scope:
     variant = self._variant_tensor
     with ops.colocate_with(variant):
       flat_components = gen_dataset_ops.optional_get_value(
           variant,
           name=scope,
           output_types=structure.get_flat_tensor_types(self._element_spec),
           output_shapes=structure.get_flat_tensor_shapes(self._element_spec))
     # NOTE: We do not colocate the deserialization of composite tensors
     # because not all ops are guaranteed to have non-GPU kernels.
     return structure.from_tensor_list(self._element_spec, flat_components)
示例#12
0
def uncompress(element, output_spec):
    """Uncompress a compressed dataset element.

    Args:
      element: A scalar variant tensor to uncompress. The element should have
        been created by calling `compress`.
      output_spec: A nested structure of `tf.TypeSpec` representing the
        type(s) of the uncompressed element.

    Returns:
      The uncompressed element.
    """
    # The op works on flat component lists, so the spec is flattened for the
    # call and then reused to rebuild the nested result.
    flat_components = ged_ops.uncompress_element(
        element,
        output_types=structure.get_flat_tensor_types(output_spec),
        output_shapes=structure.get_flat_tensor_shapes(output_spec))
    return structure.from_tensor_list(output_spec, flat_components)
示例#13
0
    def _dequeue(self):
        """Returns the next element of the infeed queue as nested `tf.Tensor`s.

        This function should not be called directly; instead the infeed should
        be passed to a loop from `tensorflow.python.ipu.loops`.

        Returns:
          A nested structure of `tf.Tensor` objects.
        """
        flat_components = gen_pop_datastream_ops.pop_datastream_infeed_dequeue(
            feed_id=self._id,
            replication_factor=self._replication_factor,
            io_batch_size=self._io_batch_size,
            prefetch_depth=self._prefetch_depth,
            **self._flat_structure)
        # The flag is set as soon as the dequeue op exists, before the result
        # is restructured.
        self._dequeued = True
        return structure.from_tensor_list(self._structure, flat_components)
示例#14
0
def at(dataset, index):
    """Returns the element at a specific index in a dataset.

    Args:
      dataset: The `tf.data.Dataset` to fetch the element from.
      index: The index at which to fetch the element.

    Returns:
      A (nested) structure of values matching `tf.data.Dataset.element_spec`.

    Raises:
      UnimplementedError: If random access is not yet supported for a dataset.
    """
    # pylint: disable=protected-access
    # Hand the op the dataset's variant tensor rather than the Python
    # `Dataset` object itself.
    flat_components = gen_experimental_dataset_ops.get_element_at_index(
        dataset._variant_tensor,
        index,
        output_types=dataset._flat_types,
        output_shapes=dataset._flat_shapes)
    return structure.from_tensor_list(dataset.element_spec, flat_components)
示例#15
0
    def testIncompatibleNestedStructure(self):
        """Checks that mismatched nested specs and values are rejected.

        Three mutually incompatible nested values/structures are defined, and
        the test asserts that:
        1. Using one structure to flatten a value with an incompatible
           structure fails.
        2. Using one structure to restructure a flattened value with an
           incompatible structure fails.
        """
        value_0 = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3]),
        }
        s_0 = structure.type_spec_from_value(value_0)
        flat_s_0 = structure.to_tensor_list(s_0, value_0)

        # `value_1` has compatible nested structure with `value_0`, but
        # different classes.
        value_1 = {
            "a": constant_op.constant(37.0),
            "b": sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
        }
        s_1 = structure.type_spec_from_value(value_1)
        flat_s_1 = structure.to_tensor_list(s_1, value_1)

        # `value_2` has incompatible nested structure with `value_0` and
        # `value_1`.
        value_2 = {
            "a": constant_op.constant(37.0),
            "b": (
                sparse_tensor.SparseTensor(
                    indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
            ),
        }
        s_2 = structure.type_spec_from_value(value_2)
        flat_s_2 = structure.to_tensor_list(s_2, value_2)

        # NOTE(mrry): The repr of the dictionaries is not sorted. The patterns
        # below contain no dictionary keys, so they do not depend on that
        # ordering. It might be worth adding a deterministic repr for these
        # error messages (among other improvements).
        nested_mismatch = (
            "The two structures don't have the same nested structure.")

        # (expected exception, message pattern, spec, value to flatten)
        flatten_cases = (
            (ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*int32.* and shape \(3,\)",
             s_0, value_1),
            (ValueError, nested_mismatch, s_0, value_2),
            (TypeError, "neither a SparseTensor nor SparseTensorValue",
             s_1, value_0),
            (ValueError, nested_mismatch, s_1, value_2),
            (ValueError, nested_mismatch, s_2, value_0),
            (ValueError, nested_mismatch, s_2, value_1),
        )
        for error_type, pattern, spec, value in flatten_cases:
            with self.assertRaisesRegex(error_type, pattern):
                structure.to_tensor_list(spec, value)

        # (message pattern, spec, flat tensor list to restructure)
        restore_cases = (
            (r"Cannot create a Tensor from the tensor list", s_0, flat_s_1),
            ("Expected 2 tensors but got 3", s_0, flat_s_2),
            ("Cannot create a SparseTensor from the tensor list",
             s_1, flat_s_0),
            ("Expected 2 tensors but got 3", s_1, flat_s_2),
            ("Expected 3 tensors but got 2", s_2, flat_s_0),
            ("Expected 3 tensors but got 2", s_2, flat_s_1),
        )
        for pattern, spec, flat in restore_cases:
            with self.assertRaisesRegex(ValueError, pattern):
                structure.from_tensor_list(spec, flat)
示例#16
0
    def testIncompatibleStructure(self):
        """Checks that mismatched specs and values are rejected.

        Three mutually incompatible values/structures are defined (a dense
        tensor, a sparse tensor and a dict of tensors), and the test asserts
        that:
        1. Using one structure to flatten a value with an incompatible
           structure fails.
        2. Using one structure to restructure a flattened value with an
           incompatible structure fails.
        """
        value_tensor = constant_op.constant(42.0)
        s_tensor = structure.type_spec_from_value(value_tensor)
        flat_tensor = structure.to_tensor_list(s_tensor, value_tensor)

        value_sparse_tensor = sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])
        s_sparse_tensor = structure.type_spec_from_value(value_sparse_tensor)
        flat_sparse_tensor = structure.to_tensor_list(
            s_sparse_tensor, value_sparse_tensor)

        value_nest = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3]),
        }
        s_nest = structure.type_spec_from_value(value_nest)
        flat_nest = structure.to_tensor_list(s_nest, value_nest)

        nested_mismatch = (
            "The two structures don't have the same nested structure.")

        # (expected exception, message pattern, spec, value to flatten)
        flatten_cases = (
            (ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*float32.* and shape \(\)",
             s_tensor, value_sparse_tensor),
            (ValueError, nested_mismatch, s_tensor, value_nest),
            (TypeError, "neither a SparseTensor nor SparseTensorValue",
             s_sparse_tensor, value_tensor),
            (ValueError, nested_mismatch, s_sparse_tensor, value_nest),
            (ValueError, nested_mismatch, s_nest, value_tensor),
            (ValueError, nested_mismatch, s_nest, value_sparse_tensor),
        )
        for error_type, pattern, spec, value in flatten_cases:
            with self.assertRaisesRegex(error_type, pattern):
                structure.to_tensor_list(spec, value)

        # (message pattern, spec, flat tensor list to restructure)
        restore_cases = (
            ("Cannot create a Tensor from the tensor list because item 0 "
             ".*tf.Tensor.* is incompatible with the expected TypeSpec "
             ".*TensorSpec.*",
             s_tensor, flat_sparse_tensor),
            ("Expected 1 tensors but got 2.", s_tensor, flat_nest),
            ("Cannot create a SparseTensor from the tensor list because "
             "item 0 .*tf.Tensor.* is incompatible with the expected TypeSpec "
             ".*TensorSpec.*",
             s_sparse_tensor, flat_tensor),
            ("Expected 1 tensors but got 2.", s_sparse_tensor, flat_nest),
            ("Expected 2 tensors but got 1.", s_nest, flat_tensor),
            ("Expected 2 tensors but got 1.", s_nest, flat_sparse_tensor),
        )
        for pattern, spec, flat in restore_cases:
            with self.assertRaisesRegex(ValueError, pattern):
                structure.from_tensor_list(spec, flat)