def generator_py_func(iterator_id):
        """Fetches, converts and validates the next element of an iterator.

        Args:
          iterator_id: Identifier handed to `generator_state` to look up the
            iterator that should be advanced.

        Returns:
          A list with one converted array per flattened component of the
          element produced by the iterator.

        Raises:
          StopIteration: If the iterator has no more elements.
          TypeError: If a converted component does not have the expected dtype.
          ValueError: If a converted component's shape is not compatible with
            the expected shape.
        """
        try:
          element = next(generator_state.get_iterator(iterator_id))
        except StopIteration:
          # Let `generator_state` know this iterator is done before signalling
          # the end of the sequence to the caller.
          generator_state.iterator_completed(iterator_id)
          raise StopIteration("Iteration finished.")

        # Convert eagerly, with the same `_convert` helper that `py_func()`
        # uses, so that the resulting arrays can be inspected below.
        components = nest.flatten_up_to(output_types, element)
        ret_arrays = []
        for component, component_type in zip(components, flattened_types):
          # pylint: disable=protected-access
          ret_arrays.append(
              script_ops.FuncRegistry._convert(
                  component, dtype=component_type.as_numpy_dtype))
          # pylint: enable=protected-access

        # Check every converted component against the declared
        # `output_types` and `output_shapes`.
        expectations = zip(ret_arrays, flattened_types, flattened_shapes)
        for array, wanted_dtype, wanted_shape in expectations:
          if array.dtype != wanted_dtype.as_numpy_dtype:
            raise TypeError(
                "`generator` yielded an element of type %s where an element "
                "of type %s was expected." % (array.dtype,
                                              wanted_dtype.as_numpy_dtype))
          if not wanted_shape.is_compatible_with(array.shape):
            raise ValueError(
                "`generator` yielded an element of shape %s where an element "
                "of shape %s was expected." % (array.shape, wanted_shape))

        return ret_arrays
Exemple #2
0
 def _merge_output_shapes(original_shapes, expected_shapes):
   """Merges each original shape with the matching expected shape.

   Args:
     original_shapes: A nested structure of shapes; also defines the structure
       of the result.
     expected_shapes: Shapes to merge in, flattened up to the structure of
       `original_shapes`.

   Returns:
     The merged shapes, packed like `original_shapes`.
   """
   shape_pairs = zip(nest.flatten(original_shapes),
                     nest.flatten_up_to(original_shapes, expected_shapes))
   merged_shapes = []
   for current_shape, wanted_shape in shape_pairs:
     merged_shapes.append(current_shape.merge_with(wanted_shape))
   return nest.pack_sequence_as(original_shapes, merged_shapes)
Exemple #3
0
 def _merge_output_shapes(original_shapes, expected_shapes):
   """Combines the original shapes with the expected ones, pairwise.

   Args:
     original_shapes: A nested structure of shapes; the result is packed with
       this same structure.
     expected_shapes: Shapes to merge in, flattened up to the structure of
       `original_shapes`.

   Returns:
     A structure like `original_shapes` holding the result of `merge_with`
     for every pair of shapes.
   """
   originals = nest.flatten(original_shapes)
   expecteds = nest.flatten_up_to(original_shapes, expected_shapes)
   merged = []
   for original, expected in zip(originals, expecteds):
     merged.append(original.merge_with(expected))
   return nest.pack_sequence_as(original_shapes, merged)
Exemple #4
0
 def _get_flat_item(i):
   """Returns item `i` as a flat list of arrays with the expected dtypes.

   Args:
     i: Encoded key; it is decoded before being passed to `self.get_item`.

   Returns:
     A list of numpy arrays, one per component of the item flattened up to
     `output_types`, each built with the matching dtype from
     `flattened_types`.
   """
   item = self.get_item(i.decode())
   components = nest.flatten_up_to(output_types, item)
   # Cast every component to the dtype it is expected to have.
   cast_components = []
   for component, component_type in zip(components, flattened_types):
     cast_components.append(np.array(component, component_type.as_numpy_dtype))
   return cast_components
Exemple #5
0
  def is_compatible_with(self, value):
    """Returns whether `value` is compatible with this nested structure.

    Args:
      value: The candidate to compare against `self._nested_structure`.

    Returns:
      `False` if `value` fails the shallow-structure check or if any component
      structure reports that its counterpart in `value` is incompatible;
      `True` otherwise.
    """
    try:
      nest.assert_shallow_structure(self._nested_structure, value)
    except (ValueError, TypeError):
      return False

    component_structures = nest.flatten(self._nested_structure)
    component_values = nest.flatten_up_to(self._nested_structure, value)
    for component_structure, component_value in zip(component_structures,
                                                    component_values):
      if not component_structure.is_compatible_with(component_value):
        return False
    return True
Exemple #6
0
  def __init__(self, dataset, output_types, output_shapes=None):
    """Wraps `dataset` so that it reports the given output types and shapes.

    `dataset` must be convertible to the requested structure:
    * The flattened `dataset.output_types` must equal the flattened
      `output_types`, i.e. the two must be the same modulo nesting.
    * Each shape in `dataset.output_shapes` must be compatible with the
      corresponding shape in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
        If omitted, the shapes are inherited from `dataset`.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not compatible
        with the structure of `dataset`.
    """
    super(_RestructuredDataset, self).__init__()
    self._dataset = dataset

    # The types must agree component by component once nesting is ignored.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if nest.flatten(dataset.output_types) != nest.flatten(output_types):
      raise ValueError(
          "Dataset with output types %r cannot be restructured to have output "
          "types %r" % (dataset.output_types, output_types))
    self._output_types = output_types

    if output_shapes is not None:
      # Explicit shapes must mirror the type structure, and each one must be
      # compatible with the shape it replaces.
      nest.assert_same_structure(output_types, output_shapes)
      shape_pairs = zip(nest.flatten(dataset.output_shapes),
                        nest.flatten_up_to(output_types, output_shapes))
      for current_shape, requested_shape in shape_pairs:
        if current_shape.is_compatible_with(requested_shape):
          continue
        raise ValueError(
            "Dataset with output shapes %r cannot be restructured to have "
            "incompatible output shapes %r" % (dataset.output_shapes,
                                               output_shapes))
      self._output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    else:
      # No shapes requested: reuse the dataset's shapes in the new nesting.
      self._output_shapes = nest.pack_sequence_as(
          output_types, nest.flatten(dataset.output_shapes))
Exemple #7
0
  def __init__(self, dataset, output_types, output_shapes=None):
    """Creates a view of `dataset` with the requested output types and shapes.

    The given `dataset` must have a structure that is convertible:
    * `dataset.output_types` must be the same as `output_types` modulo
      nesting (their flattened forms are compared for equality).
    * Each shape in `dataset.output_shapes` must be compatible with each shape
      in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
        If omitted, the shapes will be inherited from `dataset`.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not compatible
        with the structure of `dataset`.
    """
    super(_RestructuredDataset, self).__init__()
    self._dataset = dataset

    # Type validation: compare the flattened type lists.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    types_before = nest.flatten(dataset.output_types)
    types_after = nest.flatten(output_types)
    types_match = types_before == types_after
    if not types_match:
      raise ValueError(
          "Dataset with output types %r cannot be restructured to have output "
          "types %r" % (dataset.output_types, output_types))
    self._output_types = output_types

    inherit_shapes = output_shapes is None
    if inherit_shapes:
      # Repack the dataset's own shapes to follow the new type structure.
      flat_shapes = nest.flatten(dataset.output_shapes)
      self._output_shapes = nest.pack_sequence_as(output_types, flat_shapes)
    else:
      # Shape validation: same structure as the types, pairwise compatible.
      nest.assert_same_structure(output_types, output_shapes)
      shapes_before = nest.flatten(dataset.output_shapes)
      shapes_after = nest.flatten_up_to(output_types, output_shapes)
      for shape_before, shape_after in zip(shapes_before, shapes_after):
        compatible = shape_before.is_compatible_with(shape_after)
        if not compatible:
          raise ValueError(
              "Dataset with output shapes %r cannot be restructured to have "
              "incompatible output shapes %r" % (dataset.output_shapes,
                                                 output_shapes))
      self._output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
Exemple #8
0
  def _to_tensor_list(self, value):
    """Converts `value` into a flat list via its component structures.

    Args:
      value: A value that is flattened up to `self._nested_structure`.

    Returns:
      The concatenation of `_to_tensor_list` applied by each component
      structure to its matching component of `value`.

    Raises:
      ValueError: If `value` cannot be flattened up to the nested structure, or
        if a component is not compatible with its structure.
    """
    try:
      components = nest.flatten_up_to(self._nested_structure, value)
    except (ValueError, TypeError):
      raise ValueError("The value %r is not compatible with the nested "
                       "structure %r." % (value, self._nested_structure))

    tensors = []
    for component, component_structure in zip(components,
                                              self._flat_nested_structure):
      observed_structure = Structure.from_value(component)
      if component_structure.is_compatible_with(observed_structure):
        # pylint: disable=protected-access
        tensors.extend(component_structure._to_tensor_list(component))
        # pylint: enable=protected-access
      else:
        raise ValueError("Component value %r is not compatible with the nested "
                         "structure %r." % (component, component_structure))
    return tensors
Exemple #9
0
  def _to_batched_tensor_list(self, value):
    """Converts `value` into a flat list via the batched component converters.

    Args:
      value: A value that is flattened up to `self._nested_structure`.

    Returns:
      The concatenation of `_to_batched_tensor_list` applied by each component
      structure to its matching component of `value`.

    Raises:
      ValueError: If `value` cannot be flattened up to the nested structure, or
        if a component is not compatible with its structure.
    """
    try:
      components = nest.flatten_up_to(self._nested_structure, value)
    except (ValueError, TypeError):
      raise ValueError("The value %r is not compatible with the nested "
                       "structure %r." % (value, self._nested_structure))

    batched_tensors = []
    for component, component_structure in zip(components,
                                              self._flat_nested_structure):
      observed_structure = Structure.from_value(component)
      if component_structure.is_compatible_with(observed_structure):
        # pylint: disable=protected-access
        batched_tensors.extend(
            component_structure._to_batched_tensor_list(component))
        # pylint: enable=protected-access
      else:
        raise ValueError("Component value %r is not compatible with the nested "
                         "structure %r." % (component, component_structure))
    return batched_tensors
Exemple #10
0
            def generator_py_func(iterator_id):
                """Produces the next validated element of the given iterator.

                Args:
                  iterator_id: Identifier used to look up the iterator in
                    `generator_state`.

                Returns:
                  A list of converted arrays, one for each flattened component
                  of the element the iterator produced.

                Raises:
                  StopIteration: If the iterator is exhausted.
                  TypeError: If a component has an unexpected dtype.
                  ValueError: If a component has an incompatible shape.
                """
                try:
                    yielded = next(generator_state.get_iterator(iterator_id))
                except StopIteration:
                    # Record completion before propagating the end of
                    # iteration.
                    generator_state.iterator_completed(iterator_id)
                    raise StopIteration("Iteration finished.")

                # Reuse the `_convert` function from the py_func()
                # implementation so the values become arrays early and can be
                # inspected below.
                flat_yielded = nest.flatten_up_to(output_types, yielded)
                ret_arrays = []
                for item, item_type in zip(flat_yielded, flattened_types):
                    # pylint: disable=protected-access
                    converted = script_ops.FuncRegistry._convert(
                        item, dtype=item_type.as_numpy_dtype)
                    # pylint: enable=protected-access
                    ret_arrays.append(converted)

                # Verify that each component agrees with the `output_types`
                # and `output_shapes` arguments.
                for converted, item_type, item_shape in zip(
                        ret_arrays, flattened_types, flattened_shapes):
                    dtype_matches = converted.dtype == item_type.as_numpy_dtype
                    if not dtype_matches:
                        raise TypeError(
                            "`generator` yielded an element of type %s where an element "
                            "of type %s was expected." %
                            (converted.dtype, item_type.as_numpy_dtype))
                    if not item_shape.is_compatible_with(converted.shape):
                        raise ValueError(
                            "`generator` yielded an element of shape %s where an element "
                            "of shape %s was expected." %
                            (converted.shape, item_shape))

                return ret_arrays
Exemple #11
0
    def __init__(self,
                 dataset,
                 output_types,
                 output_shapes=None,
                 output_classes=None,
                 allow_unsafe_cast=False):
        """Wraps `dataset`, reporting the given types, shapes and classes.

        Unless `allow_unsafe_cast` is set, `dataset` must be convertible:
        * Its legacy output types must be the same as `output_types` modulo
          nesting (their flattened forms are compared for equality).
        * Each of its legacy output shapes must be compatible with the
          corresponding shape in `output_shapes` (if given).

        Note: This helper permits "unsafe casts" for shapes, equivalent to
        using `tf.Tensor.set_shape()` where domain-specific knowledge is
        available.

        Args:
          dataset: A `Dataset` object.
          output_types: A nested structure of `tf.DType` objects.
          output_shapes: (Optional.) A nested structure of `tf.TensorShape`
            objects. If omitted, the shapes are inherited from `dataset`.
          output_classes: (Optional.) A nested structure of class types. If
            omitted, the class types are inherited from `dataset`.
          allow_unsafe_cast: (Optional.) If `True`, the type and shape
            validation is skipped, so the caller may switch the reported
            output types and shapes of the restructured dataset, e.g. to
            switch a sparse tensor represented as `tf.variant` to its
            user-visible type and shape.

        Raises:
          ValueError: If either `output_types` or `output_shapes` is not
            compatible with the structure of `dataset`.
        """
        self._input_dataset = dataset

        input_types = dataset_ops.get_legacy_output_types(dataset)
        if not allow_unsafe_cast:
            # The types must agree component by component, ignoring nesting.
            output_types = nest.map_structure(dtypes.as_dtype, output_types)
            if nest.flatten(input_types) != nest.flatten(output_types):
                raise ValueError(
                    "Dataset with output types %r cannot be restructured to have "
                    "output types %r" %
                    (dataset_ops.get_legacy_output_types(dataset),
                     output_types))

        input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
        if output_shapes is not None:
            if not allow_unsafe_cast:
                # Explicit shapes must mirror the type structure and be
                # pairwise compatible with the shapes they replace.
                nest.assert_same_structure(output_types, output_shapes)
                shape_pairs = zip(
                    nest.flatten(input_shapes),
                    nest.flatten_up_to(output_types, output_shapes))
                for input_shape, requested_shape in shape_pairs:
                    if input_shape.is_compatible_with(requested_shape):
                        continue
                    raise ValueError(
                        "Dataset with output shapes %r cannot be restructured to have "
                        "incompatible output shapes %r" %
                        (input_shapes, output_shapes))
            output_shapes = nest.map_structure_up_to(
                output_types, tensor_shape.as_shape, output_shapes)
        else:
            # No shapes requested: reuse the input shapes in the new nesting.
            output_shapes = nest.pack_sequence_as(output_types,
                                                  nest.flatten(input_shapes))

        input_classes = dataset_ops.get_legacy_output_classes(dataset)
        if output_classes is None:
            # No classes requested: reuse the input classes in the new
            # nesting.
            output_classes = nest.pack_sequence_as(
                output_types, nest.flatten(input_classes))

        self._structure = structure.convert_legacy_structure(
            output_types, output_shapes, output_classes)
        variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
        super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
Exemple #12
0
  def testFlattenUpTo(self):
    """Checks `nest.flatten_up_to` on nested, scalar and dict trees."""
    # Nested tuples: the input is flattened only as deep as the shallow tree,
    # so its innermost pairs are kept as leaves.
    deep = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
    shallow = ((True, True), (False, True))
    self.assertEqual(nest.flatten_up_to(shallow, deep),
                     [(2, 2), (3, 3), (4, 9), (5, 5)])
    self.assertEqual(nest.flatten_up_to(shallow, shallow),
                     [True, True, False, True])

    # Partial flattening compared with a full `nest.flatten`.
    deep = (("a", 1), (("b", 2), (("c", 3), ("d", 4))))
    shallow = ("level_1", ("level_2", ("level_3", "level_4")))
    partially_flat = nest.flatten_up_to(shallow, deep)
    fully_flat = nest.flatten(deep)
    self.assertEqual(partially_flat,
                     [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
    self.assertEqual(fully_flat, ["a", 1, "b", 2, "c", 3, "d", 4])

    # The shallow tree is not a sequence, so the whole input (sequence or not)
    # comes back as a single leaf.
    single_leaf_cases = (
        # Shallow non-list edge-case, iterable elements.
        (["input_tree"], "shallow_tree"),
        (("input_tree_0", "input_tree_1"), "shallow_tree"),
        # Shallow non-list edge-case, non-iterable elements.
        ((0,), 9),
        ((0, 1), 9),
        # Both non-list edge-case, iterable then non-iterable elements.
        ("input_tree", "shallow_tree"),
        (0, 0),
    )
    for deep, shallow in single_leaf_cases:
      self.assertEqual(nest.flatten_up_to(shallow, deep), [deep])
      self.assertEqual(nest.flatten_up_to(shallow, shallow), [shallow])

    # Input non-list edge-case: the shallow tree is a sequence but the input
    # is not, which must raise a TypeError naming the input's type.
    str_message = ("If shallow structure is a sequence, input must also "
                   "be a sequence. Input has type: <(type|class) 'str'>.")
    int_message = ("If shallow structure is a sequence, input must also "
                   "be a sequence. Input has type: <(type|class) 'int'>.")
    mismatch_cases = (
        ("input_tree", ("shallow_tree",), str_message),
        ("input_tree", ("shallow_tree_9", "shallow_tree_8"), str_message),
        (0, (9,), int_message),
        (0, (9, 8), int_message),
    )
    for deep, shallow, message in mismatch_cases:
      with self.assertRaisesRegexp(TypeError, message):
        nest.flatten_up_to(shallow, deep)
      self.assertEqual(nest.flatten_up_to(shallow, shallow), list(shallow))

    # Using dict.
    deep = {"a": ((2, 2), (3, 3)), "b": ((4, 9), (5, 5))}
    shallow = {"a": (True, True), "b": (False, True)}
    self.assertEqual(nest.flatten_up_to(shallow, deep),
                     [(2, 2), (3, 3), (4, 9), (5, 5)])
    self.assertEqual(nest.flatten_up_to(shallow, shallow),
                     [True, True, False, True])
Exemple #13
0
    def testFlattenUpTo(self):
        """Verifies `nest.flatten_up_to` across tuple, scalar and dict inputs."""
        expected_pairs = [(2, 2), (3, 3), (4, 9), (5, 5)]
        expected_flags = [True, True, False, True]

        # Tuples and dicts with the same leaves: flattening stops at the depth
        # of the shallow tree, so the innermost pairs remain intact.
        tuple_input = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
        tuple_shallow = ((True, True), (False, True))
        self.assertEqual(nest.flatten_up_to(tuple_shallow, tuple_input),
                         expected_pairs)
        self.assertEqual(nest.flatten_up_to(tuple_shallow, tuple_shallow),
                         expected_flags)

        # Partial flattening compared with a full `nest.flatten`.
        nested_input = (("a", 1), (("b", 2), (("c", 3), ("d", 4))))
        nested_shallow = ("level_1", ("level_2", ("level_3", "level_4")))
        up_to_result = nest.flatten_up_to(nested_shallow, nested_input)
        full_result = nest.flatten(nested_input)
        self.assertEqual(up_to_result,
                         [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
        self.assertEqual(full_result, ["a", 1, "b", 2, "c", 3, "d", 4])

        # A non-sequence shallow tree turns the entire input into one leaf,
        # whether or not the input itself is a sequence.
        for input_tree, shallow_tree in [
                # Shallow non-list edge-case, iterable elements.
                (["input_tree"], "shallow_tree"),
                (("input_tree_0", "input_tree_1"), "shallow_tree"),
                # Shallow non-list edge-case, non-iterable elements.
                ((0, ), 9),
                ((0, 1), 9),
                # Both non-list edge-case.
                ("input_tree", "shallow_tree"),
                (0, 0),
        ]:
            self.assertEqual(nest.flatten_up_to(shallow_tree, input_tree),
                             [input_tree])
            self.assertEqual(nest.flatten_up_to(shallow_tree, shallow_tree),
                             [shallow_tree])

        # Input non-list edge-case: a sequence shallow tree with a
        # non-sequence input must raise a TypeError naming the input's type.
        message_for_str = (
            "If shallow structure is a sequence, input must also "
            "be a sequence. Input has type: <(type|class) 'str'>.")
        message_for_int = (
            "If shallow structure is a sequence, input must also "
            "be a sequence. Input has type: <(type|class) 'int'>.")
        for input_tree, shallow_tree, expected_message in [
                ("input_tree", ("shallow_tree", ), message_for_str),
                ("input_tree", ("shallow_tree_9", "shallow_tree_8"),
                 message_for_str),
                (0, (9, ), message_for_int),
                (0, (9, 8), message_for_int),
        ]:
            with self.assertRaisesRegexp(TypeError, expected_message):
                nest.flatten_up_to(shallow_tree, input_tree)
            self.assertEqual(nest.flatten_up_to(shallow_tree, shallow_tree),
                             list(shallow_tree))

        # Using dict.
        dict_input = {"a": ((2, 2), (3, 3)), "b": ((4, 9), (5, 5))}
        dict_shallow = {"a": (True, True), "b": (False, True)}
        self.assertEqual(nest.flatten_up_to(dict_shallow, dict_input),
                         expected_pairs)
        self.assertEqual(nest.flatten_up_to(dict_shallow, dict_shallow),
                         expected_flags)
Exemple #14
0
def normalize_element(element, element_signature=None):
    """Normalizes every component of a nested element structure.

    Each component is handled according to its spec:

    * `SparseTensorSpec` components are converted with
      `SparseTensor.from_value`.
    * `RaggedTensorSpec` components are converted with
      `convert_to_tensor_or_ragged_tensor`.
    * `TensorArraySpec` and `DatasetSpec` components are passed through.
    * `NoneTensorSpec` components are replaced by a new `NoneTensor`.
    * `CompositeTensor` components are passed through.
    * Everything else, including components for which no spec can be
      computed, is converted with `ops.convert_to_tensor`.

    Args:
      element: A nested structure of individual components.
      element_signature: (Optional.) A nested structure of specs, one per
        component of `element`. When given, `element` is flattened up to this
        structure, and each spec's `dtype` attribute (if it has one) is used
        as the dtype when a component is converted with
        `ops.convert_to_tensor`. When omitted, a spec is computed from each
        component's value.

    Returns:
      The normalized components, packed like `element_signature` if it was
      given and like `element` otherwise.
    """
    if element_signature is not None:
        pack_as = element_signature
        flattened_signature = nest.flatten(element_signature)
        components = nest.flatten_up_to(element_signature, element)
    else:
        pack_as = element
        components = nest.flatten(element)
        flattened_signature = [None] * len(components)

    normalized_components = []
    with ops.name_scope("normalize_element"):
        # Imported here to avoid circular dependency.
        from tensorflow.python.data.ops import dataset_ops  # pylint: disable=g-import-not-at-top
        for index, (component, spec) in enumerate(
                zip(components, flattened_signature)):
            component_name = "component_%d" % index

            if spec is None:
                try:
                    spec = type_spec_from_value(component, use_fallback=False)
                except TypeError:
                    # TypeError indicates it was not possible to compute a
                    # `TypeSpec` for the value. As a fallback try converting
                    # the value to a tensor.
                    normalized_components.append(
                        ops.convert_to_tensor(component, name=component_name))
                    continue

            if isinstance(spec, sparse_tensor.SparseTensorSpec):
                normalized = sparse_tensor.SparseTensor.from_value(component)
            elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
                normalized = ragged_tensor.convert_to_tensor_or_ragged_tensor(
                    component, name=component_name)
            elif isinstance(spec, (tensor_array_ops.TensorArraySpec,
                                   dataset_ops.DatasetSpec)):
                normalized = component
            elif isinstance(spec, NoneTensorSpec):
                normalized = NoneTensor()
            elif isinstance(component, composite_tensor.CompositeTensor):
                normalized = component
            else:
                normalized = ops.convert_to_tensor(
                    component,
                    name=component_name,
                    dtype=getattr(spec, "dtype", None))
            normalized_components.append(normalized)

    return nest.pack_sequence_as(pack_as, normalized_components)
Exemple #15
0
 def flatten_to_numpy(self, data):
     """Flattens `data` into numpy arrays with the stored dtypes and shapes.

     Args:
       data: A value that is flattened up to `self._input_types`.

     Returns:
       A list with one array per flattened component, built with the matching
       dtype from `self._flat_types` and reshaped to the matching entry of
       `self._flat_shapes`.
     """
     components = nest.flatten_up_to(self._input_types, data)
     arrays = []
     for component, shape, dtype in zip(components, self._flat_shapes,
                                        self._flat_types):
         typed = np.array(component, dtype=dtype.as_numpy_dtype)
         arrays.append(np.reshape(typed, shape))
     return arrays
Exemple #16
0
  def __init__(self,
               dataset,
               output_types,
               output_shapes=None,
               output_classes=None,
               allow_unsafe_cast=False):
    """Creates a view of `dataset` with the requested types, shapes and classes.

    Unless `allow_unsafe_cast` is set, `dataset` must have a structure that is
    convertible:
    * Its legacy output types must be the same as `output_types` modulo
      nesting (their flattened forms are compared for equality).
    * Each of its legacy output shapes must be compatible with each shape in
      `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
        If omitted, the shapes will be inherited from `dataset`.
      output_classes: (Optional.) A nested structure of class types.
        If omitted, the class types will be inherited from `dataset`.
      allow_unsafe_cast: (Optional.) If `True`, type and shape validation is
        skipped, so the caller may switch the reported output types and shapes
        of the restructured dataset, e.g. to switch a sparse tensor
        represented as `tf.variant` to its user-visible type and shape.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not compatible
        with the structure of `dataset`.
    """
    self._input_dataset = dataset
    validate = not allow_unsafe_cast

    input_types = dataset_ops.get_legacy_output_types(dataset)
    if validate:
      # Type validation: compare the flattened type lists.
      output_types = nest.map_structure(dtypes.as_dtype, output_types)
      types_before = nest.flatten(input_types)
      types_after = nest.flatten(output_types)
      if types_before != types_after:
        raise ValueError(
            "Dataset with output types %r cannot be restructured to have "
            "output types %r" %
            (dataset_ops.get_legacy_output_types(dataset), output_types))

    input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    if output_shapes is None:
      # Repack the input shapes so they follow the new type structure.
      flat_input_shapes = nest.flatten(input_shapes)
      output_shapes = nest.pack_sequence_as(output_types, flat_input_shapes)
    else:
      if validate:
        # Shape validation: same structure as the types, pairwise compatible.
        nest.assert_same_structure(output_types, output_shapes)
        shapes_before = nest.flatten(input_shapes)
        shapes_after = nest.flatten_up_to(output_types, output_shapes)
        for shape_before, shape_after in zip(shapes_before, shapes_after):
          compatible = shape_before.is_compatible_with(shape_after)
          if not compatible:
            raise ValueError(
                "Dataset with output shapes %r cannot be restructured to have "
                "incompatible output shapes %r" % (input_shapes,
                                                   output_shapes))
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)

    input_classes = dataset_ops.get_legacy_output_classes(dataset)
    if output_classes is None:
      # Repack the input classes so they follow the new type structure.
      flat_input_classes = nest.flatten(input_classes)
      output_classes = nest.pack_sequence_as(output_types, flat_input_classes)

    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)