Пример #1
0
    def __init__(self, input_dataset, batch_size, row_shape):
        """See `Dataset.dense_to_sparse_batch()` for more details.

        Args:
          input_dataset: The dataset to batch. Its legacy output types must be
            a single `DType` (i.e. each element has a single component).
          batch_size: The batch size; forwarded unchanged to the batching op.
          row_shape: Shape of each row. The element shape recorded in
            `self._structure` is this shape with an unknown-size leading
            dimension prepended.

        Raises:
          TypeError: If the legacy output types of `input_dataset` are not a
            single `DType`.
        """
        output_types = dataset_ops.get_legacy_output_types(input_dataset)
        if not isinstance(output_types, dtypes.DType):
            # Wrap in a 1-tuple: for a multi-component input the types may be
            # a tuple, and `"%r" % some_tuple` would consume the tuple as the
            # format arguments and raise "not all arguments converted" instead
            # of producing this message.
            raise TypeError(
                "DenseToSparseDataset requires an input whose elements "
                "have a single component, whereas the input has %r." %
                (output_types,))
        self._input_dataset = input_dataset
        self._batch_size = batch_size
        self._row_shape = row_shape
        self._structure = structure.SparseTensorStructure(
            output_types,
            tensor_shape.vector(None).concatenate(self._row_shape))

        # Both kernels take identical arguments; only the op name differs
        # depending on the forward-compatibility window.
        if compat.forward_compatible(2019, 8, 3):
            batch_op = ged_ops.dense_to_sparse_batch_dataset
        else:
            batch_op = ged_ops.experimental_dense_to_sparse_batch_dataset
        variant_tensor = batch_op(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._batch_size,
            row_shape=convert.partial_shape_to_tensor(self._row_shape),
            **self._flat_structure)
        super(_DenseToSparseBatchDataset,
              self).__init__(input_dataset, variant_tensor)
Пример #2
0
  def __init__(self, input_dataset, batch_size, row_shape):
    """See `Dataset.dense_to_sparse_batch()` for more details.

    Args:
      input_dataset: The dataset to batch. Its `output_types` must be a single
        `DType` (i.e. each element has a single component).
      batch_size: The batch size; forwarded unchanged to the batching op.
      row_shape: Shape of each row. The element shape recorded in
        `self._structure` is this shape with an unknown-size leading
        dimension prepended.

    Raises:
      TypeError: If `input_dataset.output_types` is not a single `DType`.
    """
    output_types = input_dataset.output_types
    if not isinstance(output_types, dtypes.DType):
      # Wrap in a 1-tuple: for a multi-component input the types may be a
      # tuple, and `"%r" % some_tuple` would consume the tuple as the format
      # arguments and raise "not all arguments converted" instead of
      # producing this message.
      raise TypeError("DenseToSparseDataset requires an input whose elements "
                      "have a single component, whereas the input has %r." %
                      (output_types,))
    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._row_shape = row_shape
    self._structure = structure.SparseTensorStructure(
        output_types,
        tensor_shape.vector(None).concatenate(self._row_shape))

    variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._batch_size,
        row_shape=convert.partial_shape_to_tensor(self._row_shape),
        **dataset_ops.flat_structure(self))
    super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                     variant_tensor)
Пример #3
0
 def _as_variant_tensor(self):
     """Builds the variant tensor for this dense-to-sparse batching dataset.

     Returns:
       The result of `ged_ops.experimental_dense_to_sparse_batch_dataset`
       applied to the input dataset's variant tensor, `self._batch_size`,
       the row shape converted by `convert.partial_shape_to_tensor`, and the
       keyword arguments produced by `dataset_ops.flat_structure(self)`.
     """
     # Evaluated in the same order as the arguments of the op call.
     upstream_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
     batch_size = self._batch_size
     row_shape_tensor = convert.partial_shape_to_tensor(self._row_shape)
     structure_kwargs = dataset_ops.flat_structure(self)
     return ged_ops.experimental_dense_to_sparse_batch_dataset(
         upstream_variant,
         batch_size,
         row_shape=row_shape_tensor,
         **structure_kwargs)
Пример #4
0
 def _as_variant_tensor(self):
   """Returns the variant tensor produced by the dense-to-sparse batch op.

   The op receives the upstream dataset's variant tensor and the batch size
   positionally, the converted row shape as `row_shape`, and the flattened
   structure keyword arguments from `dataset_ops.flat_structure(self)`.
   """
   positional_args = (
       self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
       self._batch_size,
   )
   converted_row_shape = convert.partial_shape_to_tensor(self._row_shape)
   return ged_ops.experimental_dense_to_sparse_batch_dataset(
       *positional_args,
       row_shape=converted_row_shape,
       **dataset_ops.flat_structure(self))