def __init__(self, input_dataset, batch_size, row_shape):
  """See `Dataset.dense_to_sparse_batch()` for more details.

  Args:
    input_dataset: The dataset to batch. Its legacy output types must be a
      single `DType`, i.e. each element has exactly one component.
    batch_size: Batch size forwarded to the underlying dataset op.
    row_shape: Shape of each row. It is converted with
      `convert.partial_shape_to_tensor` for the op, and appended after an
      unknown leading (batch) dimension to form the element shape.

  Raises:
    TypeError: If the legacy output types of `input_dataset` are not a single
      `DType` (e.g. a tuple or dict structure of components).
  """
  input_types = dataset_ops.get_legacy_output_types(input_dataset)
  if not isinstance(input_types, dtypes.DType):
    # Wrap in a 1-tuple: with a bare tuple operand, `%` would treat the tuple
    # as its argument list and fail (or print only the first element) instead
    # of formatting the whole structure.
    raise TypeError(
        "DenseToSparseDataset requires an input whose elements "
        "have a single component, whereas the input has %r." %
        (input_types,))
  self._input_dataset = input_dataset
  self._batch_size = batch_size
  self._row_shape = row_shape
  self._structure = structure.SparseTensorStructure(
      input_types,
      tensor_shape.vector(None).concatenate(self._row_shape))
  # Both ops take identical arguments; only the op name depends on whether
  # the 2019-08-03 forward-compatibility date has passed (presumably so
  # graphs stay loadable by older binaries until then -- see
  # `compat.forward_compatible`).
  if compat.forward_compatible(2019, 8, 3):
    dataset_op = ged_ops.dense_to_sparse_batch_dataset
  else:
    dataset_op = ged_ops.experimental_dense_to_sparse_batch_dataset
  variant_tensor = dataset_op(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._batch_size,
      row_shape=convert.partial_shape_to_tensor(self._row_shape),
      **self._flat_structure)
  super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                   variant_tensor)
def __init__(self, input_dataset, batch_size, row_shape):
  """See `Dataset.dense_to_sparse_batch()` for more details.

  Args:
    input_dataset: The dataset to batch. Its `output_types` must be a single
      `DType`, i.e. each element has exactly one component.
    batch_size: Batch size forwarded to the underlying dataset op.
    row_shape: Shape of each row. It is converted with
      `convert.partial_shape_to_tensor` for the op, and appended after an
      unknown leading (batch) dimension to form the element shape.

  Raises:
    TypeError: If `input_dataset.output_types` is not a single `DType`
      (e.g. a tuple or dict structure of components).
  """
  input_types = input_dataset.output_types
  if not isinstance(input_types, dtypes.DType):
    # Wrap in a 1-tuple: with a bare tuple operand, `%` would treat the tuple
    # as its argument list and fail (or print only the first element) instead
    # of formatting the whole structure.
    raise TypeError("DenseToSparseDataset requires an input whose elements "
                    "have a single component, whereas the input has %r." %
                    (input_types,))
  self._input_dataset = input_dataset
  self._batch_size = batch_size
  self._row_shape = row_shape
  self._structure = structure.SparseTensorStructure(
      input_types,
      tensor_shape.vector(None).concatenate(self._row_shape))
  variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._batch_size,
      row_shape=convert.partial_shape_to_tensor(self._row_shape),
      **dataset_ops.flat_structure(self))
  super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                   variant_tensor)
def _as_variant_tensor(self):
  """Returns the variant tensor for this dense-to-sparse batching dataset.

  Calls `ged_ops.experimental_dense_to_sparse_batch_dataset` with the input
  dataset's variant tensor, the batch size, the row shape converted to a
  tensor, and the keyword arguments from `dataset_ops.flat_structure(self)`.
  """
  make_dataset = ged_ops.experimental_dense_to_sparse_batch_dataset
  input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  batch = self._batch_size
  row_shape_tensor = convert.partial_shape_to_tensor(self._row_shape)
  structure_kwargs = dataset_ops.flat_structure(self)
  return make_dataset(input_variant, batch, row_shape=row_shape_tensor,
                      **structure_kwargs)
def _as_variant_tensor(self):
  """Builds this dataset's variant tensor from the wrapped input dataset.

  The input's variant tensor, the batch size and the row shape (converted
  with `convert.partial_shape_to_tensor`) are passed positionally/by keyword
  to `ged_ops.experimental_dense_to_sparse_batch_dataset`, together with the
  flat-structure keyword arguments for this dataset.
  """
  make_dataset = ged_ops.experimental_dense_to_sparse_batch_dataset
  input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  batch = self._batch_size
  row_shape_tensor = convert.partial_shape_to_tensor(self._row_shape)
  structure_kwargs = dataset_ops.flat_structure(self)
  return make_dataset(input_variant, batch, row_shape=row_shape_tensor,
                      **structure_kwargs)