Example 1
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      # pylint: disable=protected-access
      if dataset_ops._should_unpack_args(nested_args):
        ret = key_func(*nested_args)
      # pylint: enable=protected-access
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret)
      if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar():
        raise ValueError(
            "`key_func` must return a single tf.int64 tensor. "
            "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape()))
      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access
      return ret
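
For context, here is a minimal sketch of the kind of `key_func` this wrapper validates, written against the public `tf.data.experimental.group_by_window` API rather than the internal wrapper; the dataset contents and window size are illustrative assumptions.

    import tensorflow as tf

    # Group integers by parity; key_func must return a scalar tf.int64,
    # which is exactly what the wrapper above checks for.
    dataset = tf.data.Dataset.range(10)
    grouped = dataset.apply(
        tf.data.experimental.group_by_window(
            key_func=lambda x: x % 2,            # already tf.int64 for range()
            reduce_func=lambda key, window: window.batch(5),
            window_size=5))
    for batch in grouped:
        print(batch.numpy())  # e.g. [0 2 4 6 8] then [1 3 5 7 9]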
Example 2
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access
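
The `Dataset`-returning contract checked above is the same one the public `tf.data.Dataset.flat_map` imposes on its `map_func`; a minimal sketch under that assumption (the input values are illustrative):

    import tensorflow as tf

    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])

    # map_func must return a Dataset object, not a plain tensor; otherwise a
    # TypeError is raised, much like the check in the wrapper above.
    def repeat_each(x):
        return tf.data.Dataset.from_tensors(x).repeat(3)

    flattened = dataset.flat_map(repeat_each)
    print(list(flattened.as_numpy_iterator()))  # [1, 1, 1, 2, 2, 2, 3, 3, 3]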
Example 3
        def tf_key_func(*args):
            """A wrapper for Defun that facilitates shape inference."""
            # Pass in shape information from the input_dataset.
            dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                                  input_dataset.output_classes)
            for arg, shape in zip(args, nest.flatten(dense_shapes)):
                arg.set_shape(shape)

            nested_args = nest.pack_sequence_as(input_dataset.output_types,
                                                args)
            nested_args = sparse.deserialize_sparse_tensors(
                nested_args, input_dataset.output_types,
                input_dataset.output_shapes, input_dataset.output_classes)
            # pylint: disable=protected-access
            if dataset_ops._should_unpack_args(nested_args):
                ret = key_func(*nested_args)
            # pylint: enable=protected-access
            else:
                ret = key_func(nested_args)
            ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
            if ret.dtype != dtypes.int64:
                raise ValueError(
                    "`key_func` must return a single tf.int64 tensor.")
            dataset_ops._warn_if_collections(
                "tf.contrib.data.group_by_window()")  # pylint: disable=protected-access
            return ret
Example 4
        def tf_key_func(*args):
            """A wrapper for Defun that facilitates shape inference."""
            # Pass in shape information from the input_dataset.
            dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                                  input_dataset.output_classes)
            for arg, shape in zip(args, nest.flatten(dense_shapes)):
                arg.set_shape(shape)

            nested_args = nest.pack_sequence_as(input_dataset.output_types,
                                                args)
            nested_args = sparse.deserialize_sparse_tensors(
                nested_args, input_dataset.output_types,
                input_dataset.output_shapes, input_dataset.output_classes)
            # pylint: disable=protected-access
            if dataset_ops._should_unpack_args(nested_args):
                ret = key_func(*nested_args)
            # pylint: enable=protected-access
            else:
                ret = key_func(nested_args)
            ret = ops.convert_to_tensor(ret)
            if (ret.dtype != dtypes.int64
                    or ret.get_shape() != tensor_shape.scalar()):
                raise ValueError(
                    "`key_func` must return a single tf.int64 tensor. "
                    "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape()))
            return ret
Example 5
  def __init__(self, input_dataset, row_splits_dtype):
    """Constructs a new _DenseToRaggedDataset.

    Args:
      input_dataset: The dataset whose tf.Tensor elements should be made ragged.
      row_splits_dtype: The dtype that should be used for the `row_splits` of
        any new ragged tensors.  Existing `tf.RaggedTensor` elements do *not*
        have their row_splits dtype changed.
    """
    # Replace each TensorSpec in the input dataset's structure with a
    # corresponding RaggedTensorSpec.
    def to_ragged_spec(spec):
      """Returns the new spec based on RaggedTensors."""
      if (not isinstance(spec, tensor_spec.TensorSpec) or
          spec.shape.rank is None or
          spec.shape.is_fully_defined()):
        return spec
      else:
        ragged_rank = max([
            axis for (axis, size) in enumerate(spec.shape.as_list())
            if size is None
        ])
        return ragged_tensor.RaggedTensorSpec(
            shape=spec.shape,
            dtype=spec.dtype,
            ragged_rank=ragged_rank,
            row_splits_dtype=row_splits_dtype)

    self._structure = nest.map_structure(to_ragged_spec,
                                         input_dataset.element_spec)

    # Replace each tf.Tensor value in the input dataset with a variant-encoded
    # RaggedTensor. Since we're updating the corresponding structure to be
    # a RaggedTensorSpec, this variant-encoded tensor will be decoded with
    # RaggedTensorSpec._from_tensor_list.
    def to_ragged_variant(value):
      """Re-encode Tensors as RaggedTensors."""
      if (not isinstance(value, ops.Tensor) or
          value.shape.rank is None or
          value.shape.is_fully_defined()):
        return value
      else:
        spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
        if spec._ragged_rank > 0:  # pylint: disable=protected-access
          value = ragged_tensor.RaggedTensor.from_tensor(
              value, ragged_rank=spec._ragged_rank)  # pylint: disable=protected-access
        return spec._to_tensor_list(value)[0]  # pylint: disable=protected-access

    # Tuples are automatically unpacked by `dataset.map` so we repack them.
    if dataset_ops._should_unpack_args(input_dataset.element_spec):  # pylint: disable=protected-access
      map_fn = lambda *value: nest.map_structure(to_ragged_variant, value)
    else:
      map_fn = lambda value: nest.map_structure(to_ragged_variant, value)

    self._mapped_dataset = input_dataset.map(map_fn)

    variant = self._mapped_dataset._variant_tensor  # pylint: disable=protected-access
    super(_DenseToRaggedDataset, self).__init__(input_dataset, variant)
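
A dataset like this underlies ragged batching in the public API; a minimal usage sketch with `tf.data.experimental.dense_to_ragged_batch` (the element values are illustrative assumptions):

    import tensorflow as tf

    # Elements have a variable-length first dimension, so plain .batch() would
    # fail; dense_to_ragged_batch batches them into a tf.RaggedTensor instead.
    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])
    dataset = dataset.map(lambda x: tf.range(x))   # shapes [1], [2], ..., [5]
    batched = dataset.apply(
        tf.data.experimental.dense_to_ragged_batch(batch_size=2))
    for batch in batched:
        print(batch)  # tf.RaggedTensor batches such as [[0], [0, 1]]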
Example 6
    def __init__(self, input_dataset, row_splits_dtype):
        """Constructs a new _DenseToRaggedDataset.

    Args:
      input_dataset: The dataset whose tf.Tensor elements should be made ragged.
      row_splits_dtype: The dtype that should be used for the `row_splits` of
        any new ragged tensors.  Existing `tf.RaggedTensor` elements do *not*
        have their row_splits dtype changed.
    """

        # Replace each TensorSpec in the input dataset's structure with a
        # corresponding RaggedTensorSpec.
        def to_ragged_spec(spec):
            if isinstance(spec,
                          tensor_spec.TensorSpec) and spec.shape.ndims != 0:
                return ragged_tensor.RaggedTensorSpec(
                    shape=spec.shape,
                    dtype=spec.dtype,
                    ragged_rank=0,
                    row_splits_dtype=row_splits_dtype)
            else:
                return spec

        self._structure = nest.map_structure(to_ragged_spec,
                                             input_dataset.element_spec)

        # Replace each tf.Tensor value in the input dataset with a variant-encoded
        # RaggedTensor.  Since we're updating the corresponding structure to be
        # a RaggedTensorSpec, this variant-encoded tensor will be decoded with
        # RaggedTensorSpec._from_tensor_list.
        def to_ragged_variant(value):
            if isinstance(value, ops.Tensor) and value.shape.ndims != 0:
                spec = to_ragged_spec(
                    tensor_spec.TensorSpec.from_tensor(value))
                return spec._to_tensor_list(value)[0]  # pylint: disable=protected-access
            else:
                return value

        # Tuples are automatically unpacked by `dataset.map` so we repack them.
        if dataset_ops._should_unpack_args(input_dataset.element_spec):  # pylint: disable=protected-access
            map_fn = lambda *value: nest.map_structure(
                to_ragged_variant, value)
        else:
            map_fn = lambda value: nest.map_structure(to_ragged_variant, value)

        self._mapped_dataset = input_dataset.map(map_fn)

        variant = self._mapped_dataset._variant_tensor  # pylint: disable=protected-access
        super(_DenseToRaggedDataset, self).__init__(input_dataset, variant)
Example 7
 def tf_key_func(*args):
   """A wrapper for Defun that facilitates shape inference."""
   # Pass in shape information from the input_dataset.
   for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
     arg.set_shape(shape)
   nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
   # pylint: disable=protected-access
   if dataset_ops._should_unpack_args(nested_args):
     ret = key_func(*nested_args)
   # pylint: enable=protected-access
   else:
     ret = key_func(nested_args)
   ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
   if ret.dtype != dtypes.int64:
     raise ValueError("`key_func` must return a single tf.int64 tensor.")
   return ret
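
All of these wrappers hinge on the same convention that the public `tf.data.Dataset.map` exposes: tuple elements are unpacked into separate positional arguments, while single tensors or dicts are passed as one argument. A minimal sketch of that behavior (the element values are illustrative assumptions):

    import tensorflow as tf

    # Tuple elements are unpacked into separate positional arguments ...
    pairs = tf.data.Dataset.from_tensor_slices(([1, 2], [10, 20]))
    summed = pairs.map(lambda x, y: x + y)        # two positional args
    print(list(summed.as_numpy_iterator()))       # [11, 22]

    # ... while dict (or single-tensor) elements arrive as one argument.
    records = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [10, 20]})
    summed = records.map(lambda d: d["a"] + d["b"])
    print(list(summed.as_numpy_iterator()))       # [11, 22]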