Example #1
0
    def __init__(self, input_dataset, batch_size, row_shape):
        """See `Dataset.dense_to_sparse_batch()` for more details."""
        if not isinstance(dataset_ops.get_legacy_output_types(input_dataset),
                          dtypes.DType):
            raise TypeError(
                "DenseToSparseDataset requires an input whose elements "
                "have a single component, whereas the input has %r." %
                dataset_ops.get_legacy_output_types(input_dataset))
        self._input_dataset = input_dataset
        self._batch_size = batch_size
        self._row_shape = row_shape
        self._structure = structure.SparseTensorStructure(
            dataset_ops.get_legacy_output_types(input_dataset),
            tensor_shape.vector(None).concatenate(self._row_shape))

        if compat.forward_compatible(2019, 8, 3):
            variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                self._batch_size,
                row_shape=convert.partial_shape_to_tensor(self._row_shape),
                **self._flat_structure)
        else:
            variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                self._batch_size,
                row_shape=convert.partial_shape_to_tensor(self._row_shape),
                **self._flat_structure)
        super(_DenseToSparseBatchDataset,
              self).__init__(input_dataset, variant_tensor)
Example #2
0
  def __init__(self, input_dataset, batch_size, row_shape):
    """See `Dataset.dense_to_sparse_batch()` for more details."""
    if not isinstance(
        dataset_ops.get_legacy_output_types(input_dataset), dtypes.DType):
      raise TypeError("`dense_to_sparse_batch` requires an input dataset whose "
                      "elements have a single component, but the given dataset "
                      "has the following component types: "
                      f"{dataset_ops.get_legacy_output_types(input_dataset)}.")
    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._row_shape = row_shape
    self._element_spec = sparse_tensor.SparseTensorSpec(
        tensor_shape.TensorShape([None]).concatenate(self._row_shape),
        dataset_ops.get_legacy_output_types(input_dataset))

    variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._batch_size,
        row_shape=convert.partial_shape_to_tensor(self._row_shape),
        **self._flat_structure)
    super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                     variant_tensor)