Example #1
    def __init__(self, selector_input, data_inputs):
        self._selector_input = selector_input
        self._data_inputs = list(data_inputs)

        first_output_types = dataset_ops.get_legacy_output_types(
            data_inputs[0])
        first_output_classes = dataset_ops.get_legacy_output_classes(
            data_inputs[0])

        for data_input in data_inputs[1:]:
            if (dataset_ops.get_legacy_output_types(data_input) !=
                    first_output_types
                    or dataset_ops.get_legacy_output_classes(data_input) !=
                    first_output_classes):
                raise TypeError(
                    "All datasets must have the same type and class.")

        output_shapes = dataset_ops.get_legacy_output_shapes(
            self._data_inputs[0])
        for data_input in self._data_inputs[1:]:
            output_shapes = nest.pack_sequence_as(output_shapes, [
                ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
                    nest.flatten(output_shapes),
                    nest.flatten(
                        dataset_ops.get_legacy_output_shapes(data_input)))
            ])

        self._element_spec = structure.convert_legacy_structure(
            first_output_types, output_shapes, first_output_classes)
        super(_DirectedInterleaveDataset, self).__init__()
Example #2
    def __init__(self, selector_input, data_inputs):
        self._selector_input = selector_input
        self._data_inputs = list(data_inputs)

        first_output_types = dataset_ops.get_legacy_output_types(
            data_inputs[0])
        first_output_classes = dataset_ops.get_legacy_output_classes(
            data_inputs[0])

        for data_input in data_inputs[1:]:
            if (dataset_ops.get_legacy_output_types(data_input) !=
                    first_output_types
                    or dataset_ops.get_legacy_output_classes(data_input) !=
                    first_output_classes):
                raise TypeError(
                    "All datasets must have the same type and class.")

        output_shapes = dataset_ops.get_legacy_output_shapes(
            self._data_inputs[0])
        for data_input in self._data_inputs[1:]:
            output_shapes = nest.pack_sequence_as(output_shapes, [
                ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
                    nest.flatten(output_shapes),
                    nest.flatten(
                        dataset_ops.get_legacy_output_shapes(data_input)))
            ])

        self._element_spec = structure.convert_legacy_structure(
            first_output_types, output_shapes, first_output_classes)
        # pylint: disable=protected-access
        variant_tensor = gen_experimental_dataset_ops.directed_interleave_dataset(
            self._selector_input._variant_tensor,
            [data_input._variant_tensor for data_input in self._data_inputs],
            **self._flat_structure)
        super(_DirectedInterleaveDataset, self).__init__(variant_tensor)
Example #3
    def __init__(self,
                 selector_input,
                 data_inputs,
                 stop_on_empty_dataset=False):
        self._selector_input = selector_input
        self._data_inputs = list(data_inputs)
        self._stop_on_empty_dataset = stop_on_empty_dataset

        first_output_types = dataset_ops.get_legacy_output_types(
            data_inputs[0])
        first_output_classes = dataset_ops.get_legacy_output_classes(
            data_inputs[0])

        for i, data_input in enumerate(data_inputs[1:]):
            if (dataset_ops.get_legacy_output_types(data_input) !=
                    first_output_types
                    or dataset_ops.get_legacy_output_classes(data_input) !=
                    first_output_classes):
                raise TypeError(
                    "All datasets must have the same type and class.\n"
                    "dataset 0 vs dataset %s types: %s ; %s\n"
                    "classes: %s ; %s" %
                    (i + 1, first_output_types,
                     dataset_ops.get_legacy_output_types(data_input),
                     first_output_classes,
                     dataset_ops.get_legacy_output_classes(data_input)))

        output_shapes = dataset_ops.get_legacy_output_shapes(
            self._data_inputs[0])
        for data_input in self._data_inputs[1:]:
            output_shapes = nest.pack_sequence_as(output_shapes, [
                ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
                    nest.flatten(output_shapes),
                    nest.flatten(
                        dataset_ops.get_legacy_output_shapes(data_input)))
            ])
        self._element_spec = structure.convert_legacy_structure(
            first_output_types, output_shapes, first_output_classes)

        compat_kwargs = {}
        if compat.forward_compatible(2021, 5,
                                     14) or self._stop_on_empty_dataset:
            compat_kwargs[
                "stop_on_empty_dataset"] = self._stop_on_empty_dataset

        # pylint: disable=protected-access
        variant_tensor = (
            gen_experimental_dataset_ops.directed_interleave_dataset(
                self._selector_input._variant_tensor, [
                    data_input._variant_tensor
                    for data_input in self._data_inputs
                ], **compat_kwargs, **self._flat_structure))

        super(_DirectedInterleaveDataset, self).__init__(variant_tensor)
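The constructors above back the public sampling/choosing transformations. Below is a hedged usage sketch, assuming TF 2.x eager execution and that `tf.data.experimental.choose_from_datasets` is available; it shows roughly how a selector dataset drives the directed interleave from user code, without touching `_DirectedInterleaveDataset` directly.

import tensorflow as tf

datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
            tf.data.Dataset.from_tensors("bar").repeat()]
# The selector yields indices into `datasets`: 0, 1, 0, 1, 0, 1.
selector = tf.data.Dataset.range(2).repeat(3)
chosen = tf.data.experimental.choose_from_datasets(datasets, selector)
for element in chosen:
    print(element.numpy())  # b'foo', b'bar', b'foo', b'bar', b'foo', b'bar'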
Example #4
  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
    # are normalized to the rank-1 dense representation, so that the
    # sparse-oblivious unbatching logic will slice them
    # appropriately. This leads to a somewhat inefficient re-encoding step
    # for all SparseTensor components.
    # TODO(mrry): Consider optimizing this in future if it turns out to be
    # a bottleneck.
    def normalize(arg, *rest):
      # pylint: disable=protected-access
      if rest:
        return dataset._element_structure._to_batched_tensor_list((arg,) + rest)
      else:
        return dataset._element_structure._to_batched_tensor_list(arg)

    normalized_dataset = dataset.map(normalize)

    # NOTE(mrry): Our `map()` has lost information about the sparseness
    # of any SparseTensor components, so re-apply the structure of the
    # original dataset.
    restructured_dataset = _RestructuredDataset(
        normalized_dataset,
        dataset_ops.get_legacy_output_types(dataset),
        dataset_ops.get_legacy_output_shapes(dataset),
        dataset_ops.get_legacy_output_classes(dataset),
        allow_unsafe_cast=True)
    return _UnbatchDataset(restructured_dataset)
Example #5
def padded_batch_window(dataset, padded_shape, padding_value=None):
  """Batches a window of tensors with padding.

  Args:
    dataset: the input dataset.
    padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like
      object representing the shape to which the input elements should be padded
      prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
      `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
      maximum size of that dimension in each batch.
    padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
      padding value to use. Defaults are `0` for numeric types and the empty
      string for string types. If `dataset` contains `tf.SparseTensor`, this
      value is ignored.

  Returns:
    A `Tensor` representing the batch of the entire input dataset.

  Raises:
    ValueError: if invalid arguments are provided.
  """
  dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset)
  if not issubclass(dataset_output_classes,
                    (ops.Tensor, sparse_tensor.SparseTensor)):
    raise TypeError("Input dataset expected to have a single tensor component")
  if issubclass(dataset_output_classes, (ops.Tensor)):
    return _padded_batch_dense_window(dataset, padded_shape, padding_value)
  elif issubclass(dataset_output_classes, (sparse_tensor.SparseTensor)):
    if padding_value is not None:
      raise ValueError("Padding value not allowed for sparse tensors")
    return _padded_batch_sparse_window(dataset, padded_shape)
  else:
    raise TypeError("Unsupported dataset type: %s" % dataset_output_classes)
Example #6
  def __init__(self, input_dataset, num_workers):
    self._input_dataset = input_dataset

    def recalculate_output_shapes(output_shapes):
      """Recalculates the output_shapes after dividing it by num_workers."""
      if len(output_shapes) < 1:
        raise ValueError(
            "Input shape should have at least one dimension. "
            "Perhaps your input dataset is not batched?")
      output_dims = [d for d in output_shapes.dims]
      output_dims[0] = (output_dims[0] + num_workers - 1) // num_workers
      return tensor_shape.TensorShape(output_dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
    output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, output_shapes, input_classes)
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **self._flat_structure)
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
Example #7
    def __init__(self, input_dataset, num_replicas, use_fallback=True):
        self._input_dataset = input_dataset

        def recalculate_output_shapes(output_shapes):
            """Recalculates the output_shapes after dividing it by num_replicas."""
            if len(output_shapes) < 1:
                raise ValueError(
                    "Input shape should have at least one dimension. "
                    "Perhaps your input dataset is not batched?")
            output_dims = [d.value for d in output_shapes.dims]

            if output_dims[
                    0] is not None and output_dims[0] % num_replicas == 0:
                output_dims[0] = output_dims[0] // num_replicas
            else:
                # Set the batch dimension to unknown. If the global batch size does not
                # divide num_replicas evenly, the minibatches may have different sizes.
                output_dims[0] = None
            return tensor_shape.TensorShape(output_dims)

        input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
        input_shapes = dataset_ops.get_legacy_output_shapes(
            self._input_dataset)
        input_classes = dataset_ops.get_legacy_output_classes(
            self._input_dataset)
        output_shapes = nest.map_structure(recalculate_output_shapes,
                                           input_shapes)

        self._element_spec = structure.convert_legacy_structure(
            input_types, output_shapes, input_classes)
        variant_tensor = ged_ops.rebatch_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_replicas=num_replicas,
            **self._flat_structure)
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
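A minimal sketch of the shape arithmetic in `recalculate_output_shapes` above, isolated from any dataset: a statically known batch dimension that divides evenly by `num_replicas` shrinks, otherwise it becomes unknown. The helper name below is illustrative only.

from tensorflow.python.framework import tensor_shape

def rebatched_shape(shape, num_replicas):
    dims = [d.value for d in shape.dims]
    if dims[0] is not None and dims[0] % num_replicas == 0:
        dims[0] = dims[0] // num_replicas
    else:
        dims[0] = None  # uneven split: per-replica batch size is unknown
    return tensor_shape.TensorShape(dims)

print(rebatched_shape(tensor_shape.TensorShape([64, 28, 28]), 4))  # (16, 28, 28)
print(rebatched_shape(tensor_shape.TensorShape([10, 28, 28]), 4))  # (None, 28, 28)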
Example #8
def padded_batch_window(dataset, padded_shape, padding_value=None):
  """Batches a window of tensors with padding.

  Args:
    dataset: the input dataset.
    padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like
      object representing the shape to which the input elements should be padded
      prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
      `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
      maximum size of that dimension in each batch.
    padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
      padding value to use. Defaults are `0` for numeric types and the empty
      string for string types. If `dataset` contains `tf.SparseTensor`, this
      value is ignored.

  Returns:
    A `Tensor` representing the batch of the entire input dataset.

  Raises:
    ValueError: if invalid arguments are provided.
  """
  dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset)
  if not issubclass(dataset_output_classes,
                    (ops.Tensor, sparse_tensor.SparseTensor)):
    raise TypeError("Input dataset expected to have a single tensor component")
  if issubclass(dataset_output_classes, (ops.Tensor)):
    return _padded_batch_dense_window(dataset, padded_shape, padding_value)
  elif issubclass(dataset_output_classes, (sparse_tensor.SparseTensor)):
    if padding_value is not None:
      raise ValueError("Padding value not allowed for sparse tensors")
    return _padded_batch_sparse_window(dataset, padded_shape)
  else:
    raise TypeError("Unsupported dataset type: %s" % dataset_output_classes)
Example #9
  def __init__(self, input_dataset, num_workers):
    self._input_dataset = input_dataset

    def recalculate_output_shapes(output_shapes):
      """Recalculates the output_shapes after dividing it by num_workers."""
      if len(output_shapes) < 1:
        raise ValueError("Input shape should have at least one dimension.")
      if (tensor_shape.dimension_value(output_shapes[0]) and
          tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0):
        raise errors.InvalidArgumentError(
            None, None,
            "First dim of input shape: %d is not divisible by num_workers: %d" %
            (output_shapes[0], num_workers))
      output_dims = [d for d in output_shapes.dims]
      output_dims[0] = output_dims[0] // num_workers
      return tensor_shape.TensorShape(output_dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
    output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, output_shapes, input_classes)
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **dataset_ops.flat_structure(self))
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
Example #10
    def _apply_fn(dataset):
        """Function from `Dataset` to `Dataset` that applies the transformation."""

        # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
        # are normalized to the rank-1 dense representation, so that the
        # sparse-oblivious unbatching logic will slice them
        # appropriately. This leads to a somewhat inefficient re-encoding step
        # for all SparseTensor components.
        # TODO(mrry): Consider optimizing this in future if it turns out to be
        # a bottleneck.
        def normalize(arg, *rest):
            # pylint: disable=protected-access
            if rest:
                return dataset._element_structure._to_batched_tensor_list(
                    (arg, ) + rest)
            else:
                return dataset._element_structure._to_batched_tensor_list(arg)

        normalized_dataset = dataset.map(normalize)

        # NOTE(mrry): Our `map()` has lost information about the sparseness
        # of any SparseTensor components, so re-apply the structure of the
        # original dataset.
        restructured_dataset = _RestructuredDataset(
            normalized_dataset,
            dataset_ops.get_legacy_output_types(dataset),
            dataset_ops.get_legacy_output_shapes(dataset),
            dataset_ops.get_legacy_output_classes(dataset),
            allow_unsafe_cast=True)
        return _UnbatchDataset(restructured_dataset)
Example #11
    def __init__(self, input_dataset):
        """See `unbatch()` for more details."""
        input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
        flat_shapes = nest.flatten(input_shapes)
        if any(s.ndims == 0 for s in flat_shapes):
            raise ValueError("Cannot unbatch an input with scalar components.")
        known_batch_dim = tensor_shape.Dimension(None)
        for s in flat_shapes:
            try:
                known_batch_dim = known_batch_dim.merge_with(s[0])
            except ValueError:
                raise ValueError(
                    "Cannot unbatch an input whose components have "
                    "different batch sizes.")
        self._input_dataset = input_dataset

        self._structure = structure.convert_legacy_structure(
            dataset_ops.get_legacy_output_types(input_dataset),
            nest.map_structure(lambda s: s[1:], input_shapes),
            dataset_ops.get_legacy_output_classes(input_dataset))

        variant_tensor = ged_ops.experimental_unbatch_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            **dataset_ops.flat_structure(self))
        super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
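For reference, a hedged sketch of the user-facing equivalent (assumes TF 2.x, where `Dataset.unbatch()` is available): unbatching strips the leading batch dimension that the constructor above validates, and the legacy shape getter reflects that.

import tensorflow as tf
from tensorflow.python.data.ops import dataset_ops

batched = tf.data.Dataset.from_tensors([[1, 2], [3, 4]])  # one element of shape (2, 2)
print(dataset_ops.get_legacy_output_shapes(batched))       # (2, 2)
unbatched = batched.unbatch()                               # two elements of shape (2,)
print(dataset_ops.get_legacy_output_shapes(unbatched))      # (2,)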
Example #12
  def __init__(self, input_dataset, num_workers):
    self._input_dataset = input_dataset

    def recalculate_output_shapes(output_shapes):
      """Recalculates the output_shapes after dividing it by num_workers."""
      if len(output_shapes) < 1:
        raise ValueError("Input shape should have at least one dimension.")
      if (tensor_shape.dimension_value(output_shapes[0]) and
          tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0):
        raise errors.InvalidArgumentError(
            None, None,
            "First dim of input shape: %d is not divisible by num_workers: %d" %
            (output_shapes[0], num_workers))
      output_dims = [d for d in output_shapes.dims]
      output_dims[0] = output_dims[0] // num_workers
      return tensor_shape.TensorShape(output_dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
    output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, output_shapes, input_classes)
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **dataset_ops.flat_structure(self))
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
Example #13
 def _apply_fn(dataset):
   output_shapes = _merge_output_shapes(
       dataset_ops.get_legacy_output_shapes(dataset), expected_shapes)
   # pylint: disable=protected-access
   return batching._RestructuredDataset(
       dataset.map(_check_shape),
       dataset_ops.get_legacy_output_types(dataset),
       output_shapes=output_shapes,
       output_classes=dataset_ops.get_legacy_output_classes(dataset))
Example #14
 def _apply_fn(dataset):
     output_shapes = _merge_output_shapes(
         dataset_ops.get_legacy_output_shapes(dataset), expected_shapes)
     # pylint: disable=protected-access
     return batching._RestructuredDataset(
         dataset.map(_check_shape),
         dataset_ops.get_legacy_output_types(dataset),
         output_shapes=output_shapes,
         output_classes=dataset_ops.get_legacy_output_classes(dataset))
Example #15
 def testIndefiniteRepeatShapeInference(self):
     dataset = self.make_batch_feature(filenames=self._filenames[0],
                                       label_key="label",
                                       num_epochs=None,
                                       batch_size=32)
     for shape, clazz in zip(
             nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)),
             nest.flatten(dataset_ops.get_legacy_output_classes(dataset))):
         if issubclass(clazz, ops.Tensor):
             self.assertEqual(32, shape[0])
Example #16
  def testIndefiniteRepeatShapeInference(self):
   dataset = self.make_batch_feature(
       filenames=self.test_filenames[0],
       label_key="label",
       num_epochs=None,
       batch_size=32)
   for shape, clazz in zip(
       nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)),
       nest.flatten(dataset_ops.get_legacy_output_classes(dataset))):
     if issubclass(clazz, ops.Tensor):
       self.assertEqual(32, shape[0])
Example #17
def print_info_data(dataset, print_example=True, n_example=3):
    # Prints structure/shape information about a GLUE TensorFlow dataset.
    print('# Structure of the data:\n\n   {}'.format(dataset))
    print('\n# Output shape of one entry:\n   {}'.format(dataset_ops.get_legacy_output_shapes(dataset)))
    print('\n# Output types of one entry:\n   {}'.format(dataset_ops.get_legacy_output_types(dataset)))
    print('\n# Output classes of one entry:\n   {}'.format(dataset_ops.get_legacy_output_classes(dataset)))
    print(' \n')
    np_array = np.array(list(dataset.as_numpy_iterator()))
    print('# Shape of the data:\n\n   {}'.format(np.shape(np_array)))
    if len(np_array) > 0:
        if type(np_array[0]) is dict:
            structure = list(np_array[0].keys())
            print('   ---> {} entries'.format(np.shape(np_array)[0]))
            print('   ---> {} dim'.format(np_array.ndim))
            print('        dict structure')
            print('           dim: {}'.format(len(structure)))
            print('           [{:9} / {:9} / {:9}]'.format(structure[0], structure[1], structure[2]))

            print('           [{:9} / {:9} / {:9}]'.format(str(np.shape(np_array[0].get(structure[0]))),
                                                           str(np.shape(np_array[0].get(structure[1]))),
                                                           str(np.shape(np_array[0].get(structure[2])))))
            print('           [{:9} / {:9} / {:9}]'.format(type(np_array[0].get(structure[0])).__name__,
                                                           type(np_array[0].get(structure[1])).__name__,
                                                           type(np_array[0].get(structure[2])).__name__))

        if type(np_array[0]) is np.ndarray:
            if type(np_array[0][0]) is dict:
                structure = list(np_array[0][0].keys())
                print('   ---> {} batches'.format(np.shape(np_array)[0]))
                print('   ---> {} dim'.format(np_array.ndim))
                print('        label')
                print('           shape: {}'.format(np_array[0][1].shape))
                print('        dict structure')
                print('           dim: {}'.format(len(structure)))
                print('           [{:15} / {:15} / {:15}]'.format(structure[0], structure[1], structure[2]))
                print('           [{:15} / {:15} / {:15}]'.format(str(np_array[0][0].get(structure[0]).shape),
                                                                  str(np_array[0][0].get(structure[1]).shape),
                                                                  str(np_array[0][0].get(structure[2]).shape)))
                print('           [{:15} / {:15} / {:15}]'.format(type(np_array[0][0].get(structure[0])).__name__,
                                                                  type(np_array[0][0].get(structure[1])).__name__,
                                                                  type(np_array[0][0].get(structure[2])).__name__))
            else:
                print('   ---> {} entries'.format(np.shape(np_array)[0]))
                print('   ---> {} dim'.format(np_array.ndim))
                print('           [{:15} / {:15} ]'.format('text', 'label'))
                print('           [{:15} / {:15} ]'.format(str(np_array[0][0].shape), str(np_array[0][1].shape)))
                print('           [{:15} / {:15} ]'.format(str(np_array[0][0].dtype), str(np_array[0][1].dtype)))

    if print_example:
        print('\n\n# Examples of data:')
        for i, ex in enumerate(np_array):
            print('{}'.format(pprint.pformat(ex)))
            if i + 1 >= n_example:  # stop after printing n_example entries
                break
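A hedged usage sketch for the helper above, with a small dict-structured dataset standing in for a GLUE dataset. It assumes `print_info_data` and its imports (`numpy as np`, `pprint`, and `dataset_ops`) are already in scope; the field names are illustrative only.

import numpy as np
import tensorflow as tf

toy_dataset = tf.data.Dataset.from_tensor_slices({
    "idx": np.arange(4, dtype=np.int64),
    "sentence": np.array([b"a", b"bb", b"ccc", b"dddd"]),
    "label": np.array([0, 1, 0, 1], dtype=np.int64),
})
print_info_data(toy_dataset, print_example=True, n_example=2)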
Example #18
  def __init__(self, selector_input, data_inputs, stop_on_empty_dataset=False):
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)
    self._stop_on_empty_dataset = stop_on_empty_dataset

    first_output_types = dataset_ops.get_legacy_output_types(data_inputs[0])
    first_output_classes = dataset_ops.get_legacy_output_classes(data_inputs[0])

    for i, data_input in enumerate(data_inputs[1:]):
      if (dataset_ops.get_legacy_output_types(data_input) != first_output_types
          or dataset_ops.get_legacy_output_classes(data_input) !=
          first_output_classes):
        raise TypeError("All datasets must have the same type and class.\n"
                        "dataset 0 vs dataset %s types: %s ; %s\n"
                        "classes: %s ; %s" %
                        (i + 1, first_output_types,
                         dataset_ops.get_legacy_output_types(data_input),
                         first_output_classes,
                         dataset_ops.get_legacy_output_classes(data_input)))

    spec = self._data_inputs[0].element_spec
    for data_input in self._data_inputs[1:]:
      spec = nest.pack_sequence_as(spec, [
          x.most_specific_compatible_type(y) for (x, y) in zip(
              nest.flatten(spec),
              nest.flatten(data_input.element_spec))
      ])
    self._element_spec = spec

    # pylint: disable=protected-access
    variant_tensor = (
        gen_experimental_dataset_ops.directed_interleave_dataset(
            self._selector_input._variant_tensor,
            [data_input._variant_tensor for data_input in self._data_inputs],
            stop_on_empty_dataset=self._stop_on_empty_dataset,
            **self._flat_structure))

    super(_DirectedInterleaveDataset, self).__init__(variant_tensor)
Example #19
        def _next_func(string_handle):
            """Calls get_next for created iterator.

            Args:
              string_handle: An iterator string handle created by _init_func

            Returns:
              The elements generated from `input_dataset`
            """
            with ops.device(self._source_device_string):
                iterator = iterator_ops.Iterator.from_string_handle(
                    string_handle, dataset_ops.get_legacy_output_types(self),
                    dataset_ops.get_legacy_output_shapes(self),
                    dataset_ops.get_legacy_output_classes(self))
            return self._element_structure._to_tensor_list(iterator.get_next())  # pylint: disable=protected-access
Example #20
  def __init__(self, selector_input, data_inputs):
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)

    first_output_types = dataset_ops.get_legacy_output_types(data_inputs[0])
    first_output_classes = dataset_ops.get_legacy_output_classes(data_inputs[0])

    for data_input in data_inputs[1:]:
      if (dataset_ops.get_legacy_output_types(data_input) != first_output_types
          or dataset_ops.get_legacy_output_classes(data_input)
          != first_output_classes):
        raise TypeError("All datasets must have the same type and class.")

    output_shapes = dataset_ops.get_legacy_output_shapes(self._data_inputs[0])
    for data_input in self._data_inputs[1:]:
      output_shapes = nest.pack_sequence_as(output_shapes, [
          ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
              nest.flatten(output_shapes),
              nest.flatten(dataset_ops.get_legacy_output_shapes(data_input)))
      ])

    self._structure = structure.convert_legacy_structure(
        first_output_types, output_shapes, first_output_classes)
    super(_DirectedInterleaveDataset, self).__init__()
Example #21
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      return self._element_structure._to_tensor_list(iterator.get_next())  # pylint: disable=protected-access
Example #22
    def testIteratorStructure(self, tf_value_fn, expected_element_structure,
                              expected_output_classes, expected_output_types,
                              expected_output_shapes):
        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))
Example #23
def batch_window(dataset):
  """Batches a window of tensors.

  Args:
    dataset: the input dataset.

  Returns:
    A `Tensor` representing the batch of the entire input dataset.
  """
  dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset)
  if isinstance(dataset_output_classes, tuple):
    raise TypeError("Input dataset expected to have a single component")
  if dataset_output_classes is ops.Tensor:
    return _batch_dense_window(dataset)
  elif dataset_output_classes is sparse_tensor.SparseTensor:
    return _batch_sparse_window(dataset)
  else:
    raise TypeError("Unsupported dataset type: %s" % dataset_output_classes)
Example #24
def batch_window(dataset):
  """Batches a window of tensors.

  Args:
    dataset: the input dataset.

  Returns:
    A `Tensor` representing the batch of the entire input dataset.
  """
  dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset)
  if isinstance(dataset_output_classes, tuple):
    raise TypeError("Input dataset expected to have a single component")
  if dataset_output_classes is ops.Tensor:
    return _batch_dense_window(dataset)
  elif dataset_output_classes is sparse_tensor.SparseTensor:
    return _batch_sparse_window(dataset)
  else:
    raise TypeError("Unsupported dataset type: %s" % dataset_output_classes)
Example #25
  def testIteratorStructure(self, tf_value_fn, expected_element_structure,
                            expected_output_classes, expected_output_types,
                            expected_output_shapes):
    tf_value = tf_value_fn()
    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.from_tensors(tf_value))

    self.assertTrue(expected_element_structure.is_compatible_with(
        iterator._element_structure))
    self.assertTrue(iterator._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual(expected_output_classes,
                     dataset_ops.get_legacy_output_classes(iterator))
    self.assertEqual(expected_output_types,
                     dataset_ops.get_legacy_output_types(iterator))
    self.assertEqual(expected_output_shapes,
                     dataset_ops.get_legacy_output_shapes(iterator))
Example #26
    def testSparseTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        def tf_value_fn():
            return sparse_tensor.SparseTensor(indices=[[0]],
                                              values=constant_op.constant(
                                                  [0], dtype=dtypes.int32),
                                              dense_shape=[1])

        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))
Example #27
  def __init__(self, input_dataset):
    """See `unbatch()` for more details."""
    input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
    flat_shapes = nest.flatten(input_shapes)
    if any(s.ndims == 0 for s in flat_shapes):
      raise ValueError("Cannot unbatch an input with scalar components.")
    known_batch_dim = tensor_shape.Dimension(None)
    for s in flat_shapes:
      try:
        known_batch_dim = known_batch_dim.merge_with(s[0])
      except ValueError:
        raise ValueError("Cannot unbatch an input whose components have "
                         "different batch sizes.")
    self._input_dataset = input_dataset

    self._structure = structure.convert_legacy_structure(
        dataset_ops.get_legacy_output_types(input_dataset),
        nest.map_structure(lambda s: s[1:], input_shapes),
        dataset_ops.get_legacy_output_classes(input_dataset))

    variant_tensor = ged_ops.experimental_unbatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        **dataset_ops.flat_structure(self))
    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
Example #28
    def testNestedTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        def tf_value_fn():
            return {
                "a": constant_op.constant(37.0),
                "b":
                (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
            }

        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))
Example #29
  def __init__(self,
               input_dataset,
               batch_size,
               padded_shapes,
               padding_values,
               drop_remainder,
               ):
    """See `Dataset.batch()` for details."""
    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._padded_shapes = padded_shapes
    self._padding_values = padding_values
    self._drop_remainder = drop_remainder
    
    def _padded_shape_to_batch_shape(s):
      return tensor_shape.TensorShape([
          tensor_util.constant_value(self._batch_size)
          if smart_cond.smart_constant_value(self._drop_remainder) else None
      ]).concatenate(tensor_util.constant_value_as_shape(s))

    output_shapes = nest.map_structure(
        _padded_shape_to_batch_shape, self._padded_shapes)
    self._structure = structure.convert_legacy_structure(
        ds.get_legacy_output_types(self._input_dataset), output_shapes,
        ds.get_legacy_output_classes(self._input_dataset))

    variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          batch_size=self._batch_size,
          padded_shapes=[ ops.convert_to_tensor(s, dtype=dtypes.int64)
                          for s in nest.flatten(self._padded_shapes)
                        ],
          padding_values=nest.flatten(self._padding_values),
          drop_remainder=self._drop_remainder,
          output_shapes=structure.get_flat_tensor_shapes(self._structure))
    super(TntPaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
Example #30
 def _get_output_classes(self, ds_fn):
   with ops.Graph().as_default():
     return dataset_ops.get_legacy_output_classes(ds_fn())
Example #31
 def output_classes(self):
     return dataset_ops.get_legacy_output_classes(self._iterator)
Example #32
    def __init__(self,
                 dataset,
                 output_types,
                 output_shapes=None,
                 output_classes=None,
                 allow_unsafe_cast=False):
        """Creates a new dataset with the given output types and shapes.

    The given `dataset` must have a structure that is convertible:
    * `dataset.output_types` must be the same as `output_types` modulo nesting.
    * Each shape in `dataset.output_shapes` must be compatible with each shape
      in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
        If omitted, the shapes will be inherited from `dataset`.
      output_classes: (Optional.) A nested structure of class types. If omitted,
        the class types will be inherited from `dataset`.
      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
        reported output types and shapes of the restructured dataset, e.g. to
        switch a sparse tensor represented as `tf.variant` to its user-visible
        type and shape.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not compatible
        with the structure of `dataset`.
    """
        self._input_dataset = dataset

        input_types = dataset_ops.get_legacy_output_types(dataset)
        if not allow_unsafe_cast:
            # Validate that the types are compatible.
            output_types = nest.map_structure(dtypes.as_dtype, output_types)
            flat_original_types = nest.flatten(input_types)
            flat_new_types = nest.flatten(output_types)
            if flat_original_types != flat_new_types:
                raise ValueError(
                    "Dataset with output types %r cannot be restructured to have "
                    "output types %r" %
                    (dataset_ops.get_legacy_output_types(dataset),
                     output_types))

        input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
        if output_shapes is None:
            # Inherit shapes from the original `dataset`.
            output_shapes = nest.pack_sequence_as(output_types,
                                                  nest.flatten(input_shapes))
        else:
            if not allow_unsafe_cast:
                # Validate that the shapes are compatible.
                nest.assert_same_structure(output_types, output_shapes)
                flat_original_shapes = nest.flatten(input_shapes)
                flat_new_shapes = nest.flatten_up_to(output_types,
                                                     output_shapes)

                for original_shape, new_shape in zip(flat_original_shapes,
                                                     flat_new_shapes):
                    if not original_shape.is_compatible_with(new_shape):
                        raise ValueError(
                            "Dataset with output shapes %r cannot be restructured to have "
                            "incompatible output shapes %r" %
                            (input_shapes, output_shapes))
            output_shapes = nest.map_structure_up_to(output_types,
                                                     tensor_shape.as_shape,
                                                     output_shapes)

        input_classes = dataset_ops.get_legacy_output_classes(dataset)
        if output_classes is None:
            # Inherit class types from the original `dataset`.
            output_classes = nest.pack_sequence_as(output_types,
                                                   nest.flatten(input_classes))

        self._structure = structure.convert_legacy_structure(
            output_types, output_shapes, output_classes)
        variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
        super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
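A minimal sketch of the shape-compatibility check the constructor performs when `allow_unsafe_cast` is `False` (pure `TensorShape` logic, no dataset involved):

from tensorflow.python.framework import tensor_shape

original = tensor_shape.TensorShape([None, 28, 28])
print(original.is_compatible_with(tensor_shape.TensorShape([32, 28, 28])))  # True
print(original.is_compatible_with(tensor_shape.TensorShape([32, 28])))      # False: different rank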
Example #33
  def __init__(self,
               dataset,
               output_types,
               output_shapes=None,
               output_classes=None,
               allow_unsafe_cast=False):
    """Creates a new dataset with the given output types and shapes.

    The given `dataset` must have a structure that is convertible:
    * `dataset.output_types` must be the same as `output_types` modulo nesting.
    * Each shape in `dataset.output_shapes` must be compatible with each shape
      in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
        If omitted, the shapes will be inherited from `dataset`.
      output_classes: (Optional.) A nested structure of class types.
        If omitted, the class types will be inherited from `dataset`.
      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
        reported output types and shapes of the restructured dataset, e.g. to
        switch a sparse tensor represented as `tf.variant` to its user-visible
        type and shape.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not compatible
        with the structure of `dataset`.
    """
    self._input_dataset = dataset

    input_types = dataset_ops.get_legacy_output_types(dataset)
    if not allow_unsafe_cast:
      # Validate that the types are compatible.
      output_types = nest.map_structure(dtypes.as_dtype, output_types)
      flat_original_types = nest.flatten(input_types)
      flat_new_types = nest.flatten(output_types)
      if flat_original_types != flat_new_types:
        raise ValueError(
            "Dataset with output types %r cannot be restructured to have "
            "output types %r" %
            (dataset_ops.get_legacy_output_types(dataset), output_types))

    input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    if output_shapes is None:
      # Inherit shapes from the original `dataset`.
      output_shapes = nest.pack_sequence_as(
          output_types, nest.flatten(input_shapes))
    else:
      if not allow_unsafe_cast:
        # Validate that the shapes are compatible.
        nest.assert_same_structure(output_types, output_shapes)
        flat_original_shapes = nest.flatten(input_shapes)
        flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)

        for original_shape, new_shape in zip(flat_original_shapes,
                                             flat_new_shapes):
          if not original_shape.is_compatible_with(new_shape):
            raise ValueError(
                "Dataset with output shapes %r cannot be restructured to have "
                "incompatible output shapes %r" % (input_shapes,
                                                   output_shapes))
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)

    input_classes = dataset_ops.get_legacy_output_classes(dataset)
    if output_classes is None:
      # Inherit class types from the original `dataset`.
      output_classes = nest.pack_sequence_as(
          output_types, nest.flatten(input_classes))

    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
Example #34
 def _get_output_classes(self, ds_fn):
     assert not context.executing_eagerly()
     with ops.Graph().as_default():
         return dataset_ops.get_legacy_output_classes(ds_fn())
Example #35
 def _get_output_classes(self, ds_fn):
     with ops.Graph().as_default():
         return dataset_ops.get_legacy_output_classes(ds_fn())
Example #36
 def output_classes(self):
   return dataset_ops.get_legacy_output_classes(self._iterator)
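Finally, a hedged note on equivalents outside TensorFlow internals: the information returned by `get_legacy_output_classes`/`get_legacy_output_types`/`get_legacy_output_shapes` is also exposed through the public `tf.compat.v1.data.get_output_*` helpers, and through `element_spec` in TF 2.x. A small sketch, assuming TF 2.x eager execution:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
print(tf.compat.v1.data.get_output_classes(ds))  # <class '...Tensor'>
print(tf.compat.v1.data.get_output_types(ds))    # <dtype: 'int32'>
print(tf.compat.v1.data.get_output_shapes(ds))   # (2,)
print(ds.element_spec)                           # TensorSpec(shape=(2,), dtype=tf.int32, name=None)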