Esempio n. 1
0
 def testPreserveTensorArrayShape(self):
     """A declared TensorArray element shape survives a flatten/restore."""
     source_array = tensor_array_ops.TensorArray(
         dtype=dtypes.int32, size=1, element_shape=(3,))
     array_spec = structure.type_spec_from_value(source_array)
     # Flatten to a tensor list, then rebuild from that list using the spec.
     flat_components = structure.to_tensor_list(array_spec, source_array)
     restored_array = structure.from_tensor_list(array_spec, flat_components)
     self.assertEqual(restored_array.element_shape.as_list(), [3])
Esempio n. 2
0
 def testPreserveInferredTensorArrayShape(self):
     """An element shape inferred from a write survives a flatten/restore."""
     # No element_shape is declared; the shape is inferred from the write.
     written_array = tensor_array_ops.TensorArray(
         dtype=dtypes.int32, size=1).write(0, [1, 2, 3])
     array_spec = structure.type_spec_from_value(written_array)
     restored_array = structure.from_tensor_list(
         array_spec, structure.to_tensor_list(array_spec, written_array))
     self.assertEqual(restored_array.element_shape.as_list(), [3])
Esempio n. 3
0
  def testPreserveStaticShape(self):
    """Round-tripping composite tensors keeps their static component shapes.

    A RaggedTensor and a SparseTensor are each flattened with
    `structure.to_tensor_list` and rebuilt with `structure.from_tensor_list`.
    The static shapes of their components are then checked.

    The method name carries the `test` prefix so that unittest's default
    loader collects it. The name `preserveStaticShape` (defined right after
    this method) remains available as an alias.
    """
    rt = ragged_factory_ops.constant([[1, 2], [], [3]])
    rt_s = structure.type_spec_from_value(rt)
    rt_after = structure.from_tensor_list(rt_s,
                                          structure.to_tensor_list(rt_s, rt))
    self.assertEqual(rt_after.row_splits.shape.as_list(),
                     rt.row_splits.shape.as_list())
    self.assertEqual(rt_after.values.shape.as_list(), [None])

    st = sparse_tensor.SparseTensor(
        indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
    st_s = structure.type_spec_from_value(st)
    st_after = structure.from_tensor_list(st_s,
                                          structure.to_tensor_list(st_s, st))
    self.assertEqual(st_after.indices.shape.as_list(), [None, 2])
    self.assertEqual(st_after.values.shape.as_list(), [None])
    self.assertEqual(st_after.dense_shape.shape.as_list(),
                     st.dense_shape.shape.as_list())

  # Backward-compatible alias for the previous name. It has no `test` prefix,
  # so it is not collected a second time by the default loader.
  preserveStaticShape = testPreserveStaticShape
 def py_function_wrapper(*args):
     """Adapts flat tensor arguments to `self._func` and flattens its result.

     The flat `args` are rebuilt into `self._input_structure`. The result is
     splatted into `self._func` when `_should_unpack` says so, and otherwise
     passed as a single argument. The return value is flattened with
     `self._output_structure` and each component is converted to a tensor.
     """
     nested_args = structure.from_compatible_tensor_list(
         self._input_structure, args)
     call_args = nested_args if _should_unpack(nested_args) else (nested_args,)
     result = self._func(*call_args)
     if _should_pack(result):
         result = tuple(result)
     # `self._output_structure` is read only after `self._func` has run.
     flat_result = structure.to_tensor_list(self._output_structure, result)
     return [ops.convert_to_tensor(component) for component in flat_result]
Esempio n. 5
0
def compress(element):
    """Compresses a single dataset element into a variant tensor.

    Args:
      element: A nested structure of values whose types TensorFlow supports.

    Returns:
      A variant tensor holding the compressed element; pass it to `uncompress`
      to recover the original element.
    """
    spec = structure.type_spec_from_value(element)
    flat_components = structure.to_tensor_list(spec, element)
    return ged_ops.compress_element(flat_components)
Esempio n. 6
0
    def _next_func(string_handle):
      """Returns the next element of the iterator behind `string_handle`.

      The iterator is rebuilt from the handle inside an
      `ops.device(self._source_device_string)` scope. `get_next` is called
      outside that scope, and its result is flattened with `self.element_spec`.

      Args:
        string_handle: An iterator string handle created by `_init_func`.

      Returns:
        The next element generated from `input_dataset`, as a tensor list.
      """
      with ops.device(self._source_device_string):
        output_types = dataset_ops.get_legacy_output_types(self)
        output_shapes = dataset_ops.get_legacy_output_shapes(self)
        output_classes = dataset_ops.get_legacy_output_classes(self)
        source_iterator = iterator_ops.Iterator.from_string_handle(
            string_handle, output_types, output_shapes, output_classes)
      next_element = source_iterator.get_next()
      return structure.to_tensor_list(self.element_spec, next_element)
Esempio n. 7
0
    def from_value(value):
        """Builds an `Optional` around `value`.

        Args:
          value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor`
            objects.

        Returns:
          An `Optional` that wraps `value`.
        """
        # The "value" name scope is nested inside the "optional" one; the
        # outer scope name is reused to name the `optional_from_value` op.
        with ops.name_scope("optional") as optional_scope, \
                ops.name_scope("value"):
            spec = structure.type_spec_from_value(value)
            flat_value = structure.to_tensor_list(spec, value)

        variant = gen_dataset_ops.optional_from_value(flat_value,
                                                      name=optional_scope)
        return _OptionalImpl(variant, spec)
Esempio n. 8
0
  def from_value(value):
    """Builds an `Optional` containing `value`.

    Args:
      value: The value to wrap. It must be convertible to a `Tensor` or a
        `CompositeTensor`.

    Returns:
      An `Optional` that wraps `value`.
    """
    with ops.name_scope("optional") as scope:
      with ops.name_scope("value"):
        spec = structure.type_spec_from_value(value)
        components = structure.to_tensor_list(spec, value)

    # Name the op after the outer "optional" scope.
    optional_variant = gen_dataset_ops.optional_from_value(components,
                                                           name=scope)
    return _OptionalImpl(optional_variant, spec)
Esempio n. 9
0
    def from_value(value):
        """Creates a `tf.experimental.Optional` holding `value`.

        >>> optional = tf.experimental.Optional.from_value(42)
        >>> print(optional.has_value())
        tf.Tensor(True, shape=(), dtype=bool)
        >>> print(optional.get_value())
        tf.Tensor(42, shape=(), dtype=int32)

        Args:
          value: The value to wrap; it must be convertible to `tf.Tensor` or
            `tf.CompositeTensor`.

        Returns:
          A `tf.experimental.Optional` that wraps `value`.
        """
        with ops.name_scope("optional") as outer_scope:
            with ops.name_scope("value"):
                element_spec = structure.type_spec_from_value(value)
                flat_value = structure.to_tensor_list(element_spec, value)

        variant = gen_dataset_ops.optional_from_value(flat_value,
                                                      name=outer_scope)
        return _OptionalImpl(variant, element_spec)
Esempio n. 10
0
  def __init__(self,
               input_dataset,
               initial_state,
               scan_func,
               use_default_device=None):
    """See `scan()` for details.

    Traces `scan_func` repeatedly, relaxing the static shapes of the state
    until they stop changing, then builds the `scan_dataset` op.

    Args:
      input_dataset: The dataset whose elements are passed to `scan_func`.
      initial_state: The initial state; normalized with
        `structure.normalize_element` before use.
      scan_func: A function of `(state, element)` that must return a
        `(new_state, output)` pair.
      use_default_device: Optional value for the op's `use_default_device`
        attribute. See the comment at the op call for when it is passed.

    Raises:
      TypeError: If `scan_func` does not return a pair, or if the classes or
        dtypes of the new state differ from those of the initial state.
    """
    self._input_dataset = input_dataset
    self._initial_state = structure.normalize_element(initial_state)

    # Compute initial values for the state classes, shapes and types based on
    # the initial state. The shapes may be refined by tracing `scan_func` one
    # or more times below.
    self._state_structure = structure.type_spec_from_value(self._initial_state)

    # Iteratively retrace the scan function until the state shapes recorded in
    # `self._state_structure` reach a fixed point.
    need_to_rerun = True
    while need_to_rerun:

      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          scan_func,
          self._transformation_name(),
          input_structure=(self._state_structure,
                           input_dataset.element_spec),
          add_to_graph=False)
      if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
              and len(wrapped_func.output_types) == 2):
        raise TypeError("The scan function must return a pair comprising the "
                        "new state and the output value.")

      # Extract and validate class information from the returned values.
      new_state_classes, output_classes = wrapped_func.output_classes
      # Also record the output classes on the instance; this attribute may be
      # read outside this method.
      self._output_classes = output_classes
      old_state_classes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
          self._state_structure)
      for new_state_class, old_state_class in zip(
          nest.flatten(new_state_classes),
          nest.flatten(old_state_classes)):
        if not issubclass(new_state_class, old_state_class):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_classes, new_state_classes))

      # Extract and validate type information from the returned values.
      new_state_types, output_types = wrapped_func.output_types
      old_state_types = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
          self._state_structure)
      for new_state_type, old_state_type in zip(
          nest.flatten(new_state_types), nest.flatten(old_state_types)):
        if new_state_type != old_state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_types, new_state_types))

      # Extract shape information from the returned values.
      new_state_shapes, output_shapes = wrapped_func.output_shapes
      old_state_shapes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
          self._state_structure)
      self._element_spec = structure.convert_legacy_structure(
          output_types, output_shapes, output_classes)

      # Relax each state shape to the most specific shape compatible with
      # both the previous state shape and the one `scan_func` produced.
      flat_state_shapes = nest.flatten(old_state_shapes)
      flat_new_state_shapes = nest.flatten(new_state_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # Retrace only if some known-rank state shape was actually relaxed.
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # TODO(b/110122868): Support a "most specific compatible structure"
        # method for combining structures, to avoid using legacy structures
        # in this method.
        self._state_structure = structure.convert_legacy_structure(
            old_state_types,
            nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
            old_state_classes)

    self._scan_func = wrapped_func
    self._scan_func.function.add_to_graph(ops.get_default_graph())

    # `use_default_device` is passed only in two cases: once the 2019-10-15
    # forward-compatibility horizon has been reached, or when the caller set
    # it explicitly. Presumably older registrations of the op do not accept
    # the attribute.
    scan_kwargs = dict(self._flat_structure)
    if (compat.forward_compatible(2019, 10, 15) or
        use_default_device is not None):
      scan_kwargs["use_default_device"] = use_default_device

    # pylint: disable=protected-access
    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
        self._input_dataset._variant_tensor,
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        **scan_kwargs)
    super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
Esempio n. 11
0
    def testIncompatibleNestedStructure(self):
        """Mismatched nested structures must be rejected in both directions.

        Three mutually incompatible nested values are built. The test then
        checks two things:
        1. Flattening one value with another value's spec raises.
        2. Rebuilding one value's flat tensor list with another value's spec
           raises.
        """
        not_same_nest = (
            "The two structures don't have the same nested structure.")

        dense_value = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        dense_spec = structure.type_spec_from_value(dense_value)
        dense_flat = structure.to_tensor_list(dense_spec, dense_value)

        # Same nesting as `dense_value`, but "b" is a SparseTensor.
        sparse_value = {
            "a": constant_op.constant(37.0),
            "b": sparse_tensor.SparseTensor(indices=[[0, 0]],
                                            values=[1],
                                            dense_shape=[1, 1])
        }
        sparse_spec = structure.type_spec_from_value(sparse_value)
        sparse_flat = structure.to_tensor_list(sparse_spec, sparse_value)

        # "b" is a pair of SparseTensors, so the nesting differs from both
        # values above.
        pair_value = {
            "a": constant_op.constant(37.0),
            "b": (sparse_tensor.SparseTensor(indices=[[0, 0]],
                                             values=[1],
                                             dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(indices=[[3, 4]],
                                             values=[-1],
                                             dense_shape=[4, 5]))
        }
        pair_spec = structure.type_spec_from_value(pair_value)
        pair_flat = structure.to_tensor_list(pair_spec, pair_value)

        # (spec, value, expected exception type, expected message regex)
        flatten_cases = [
            (dense_spec, sparse_value, ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*int32.* and shape \(3,\)"),
            (dense_spec, pair_value, ValueError, not_same_nest),
            (sparse_spec, dense_value, TypeError,
             "neither a SparseTensor nor SparseTensorValue"),
            (sparse_spec, pair_value, ValueError, not_same_nest),
            (pair_spec, dense_value, ValueError, not_same_nest),
            (pair_spec, sparse_value, ValueError, not_same_nest),
        ]
        for spec, value, error_type, pattern in flatten_cases:
            with self.assertRaisesRegex(error_type, pattern):
                structure.to_tensor_list(spec, value)

        # (spec, flat tensor list, expected ValueError message regex)
        unflatten_cases = [
            (dense_spec, sparse_flat,
             r"Cannot create a Tensor from the tensor list"),
            (dense_spec, pair_flat, "Expected 2 tensors but got 3"),
            (sparse_spec, dense_flat,
             "Cannot create a SparseTensor from the tensor list"),
            (sparse_spec, pair_flat, "Expected 2 tensors but got 3"),
            (pair_spec, dense_flat, "Expected 3 tensors but got 2"),
            (pair_spec, sparse_flat, "Expected 3 tensors but got 2"),
        ]
        for spec, flat, pattern in unflatten_cases:
            with self.assertRaisesRegex(ValueError, pattern):
                structure.from_tensor_list(spec, flat)
Esempio n. 12
0
    def testIncompatibleStructure(self):
        """Mutually incompatible structures must be rejected in both directions.

        A dense tensor, a SparseTensor and a dict of tensors are built. The
        test then checks two things:
        1. Flattening one value with another value's spec raises.
        2. Rebuilding one value's flat tensor list with another value's spec
           raises.
        """
        not_same_nest = (
            "The two structures don't have the same nested structure.")

        tensor_value = constant_op.constant(42.0)
        tensor_spec = structure.type_spec_from_value(tensor_value)
        tensor_flat = structure.to_tensor_list(tensor_spec, tensor_value)

        sparse_value = sparse_tensor.SparseTensor(indices=[[0, 0]],
                                                  values=[1],
                                                  dense_shape=[1, 1])
        sparse_spec = structure.type_spec_from_value(sparse_value)
        sparse_flat = structure.to_tensor_list(sparse_spec, sparse_value)

        nest_value = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        nest_spec = structure.type_spec_from_value(nest_value)
        nest_flat = structure.to_tensor_list(nest_spec, nest_value)

        # (spec, value, expected exception type, expected message regex)
        flatten_cases = [
            (tensor_spec, sparse_value, ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*float32.* and shape \(\)"),
            (tensor_spec, nest_value, ValueError, not_same_nest),
            (sparse_spec, tensor_value, TypeError,
             "neither a SparseTensor nor SparseTensorValue"),
            (sparse_spec, nest_value, ValueError, not_same_nest),
            (nest_spec, tensor_value, ValueError, not_same_nest),
            (nest_spec, sparse_value, ValueError, not_same_nest),
        ]
        for spec, value, error_type, pattern in flatten_cases:
            with self.assertRaisesRegex(error_type, pattern):
                structure.to_tensor_list(spec, value)

        # (spec, flat tensor list, expected ValueError message regex)
        unflatten_cases = [
            (tensor_spec, sparse_flat,
             "Cannot create a Tensor from the tensor list because item 0 "
             ".*tf.Tensor.* is incompatible with the expected TypeSpec "
             ".*TensorSpec.*"),
            (tensor_spec, nest_flat, "Expected 1 tensors but got 2."),
            (sparse_spec, tensor_flat,
             "Cannot create a SparseTensor from the tensor list because "
             "item 0 .*tf.Tensor.* is incompatible with the expected TypeSpec "
             ".*TensorSpec.*"),
            (sparse_spec, nest_flat, "Expected 1 tensors but got 2."),
            (nest_spec, tensor_flat, "Expected 2 tensors but got 1."),
            (nest_spec, sparse_flat, "Expected 2 tensors but got 1."),
        ]
        for spec, flat, pattern in unflatten_cases:
            with self.assertRaisesRegex(ValueError, pattern):
                structure.from_tensor_list(spec, flat)
 def wrapped_fn(*args):
     """Runs `wrapper_helper` and returns its result as a list of tensors.

     The result is flattened with `self._output_structure`, which is read
     only after `wrapper_helper` has returned. Each component is then passed
     through `ops.convert_to_tensor`.
     """
     outputs = wrapper_helper(*args)
     flat_outputs = structure.to_tensor_list(self._output_structure, outputs)
     return [ops.convert_to_tensor(component) for component in flat_outputs]
 def wrapped_fn(*args):
     """Runs `wrapper_helper` and flattens its result.

     `self._output_structure` is read only after `wrapper_helper` has
     returned, and the flat tensor list is returned without further
     conversion.
     """
     outputs = wrapper_helper(*args)
     flat_outputs = structure.to_tensor_list(self._output_structure, outputs)
     return flat_outputs