def __init__(self, input_dataset):
  """Creates a dataset that splits each input element along its first axis.

  See `unbatch()` for more details.

  Args:
    input_dataset: The dataset whose elements are to be unbatched.

  Raises:
    ValueError: If any component of `input_dataset` has rank 0, or if the
      leading dimensions of its components cannot be merged with one
      another.
  """
  legacy_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
  component_shapes = nest.flatten(legacy_shapes)

  # A rank-0 component has no leading axis that could be split off.
  if any(shape.ndims == 0 for shape in component_shapes):
    raise ValueError("Cannot unbatch an input with scalar components.")

  # Fold every component's axis-0 size into one dimension, starting from an
  # unknown dimension; a failed merge is reported as a batch-size mismatch.
  batch_dim = tensor_shape.Dimension(None)
  for shape in component_shapes:
    try:
      batch_dim = batch_dim.merge_with(shape[0])
    except ValueError:
      raise ValueError(
          "Cannot unbatch an input whose components have "
          "different batch sizes.")

  self._input_dataset = input_dataset

  def _drop_leading_dim(shape):
    # Output elements keep each input component's shape minus axis 0.
    return shape[1:]

  legacy_types = dataset_ops.get_legacy_output_types(input_dataset)
  element_shapes = nest.map_structure(_drop_leading_dim, legacy_shapes)
  legacy_classes = dataset_ops.get_legacy_output_classes(input_dataset)
  self._structure = structure.convert_legacy_structure(
      legacy_types, element_shapes, legacy_classes)

  # `self._structure` must be assigned before `flat_structure(self)` runs.
  unbatched_variant = ged_ops.experimental_unbatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      **dataset_ops.flat_structure(self))
  super(_UnbatchDataset, self).__init__(input_dataset, unbatched_variant)
def __init__(self, input_dataset):
  """Creates a dataset that splits each input element along its first axis.

  See `unbatch()` for more details.

  Args:
    input_dataset: The dataset whose elements are to be unbatched.

  Raises:
    ValueError: If any component of `input_dataset` has rank 0, or if the
      leading dimensions of its components cannot be merged with one
      another.
  """
  component_shapes = nest.flatten(
      dataset_ops.get_legacy_output_shapes(input_dataset))

  # Reject rank-0 components up front: there is no axis 0 to unbatch.
  if any(shape.ndims == 0 for shape in component_shapes):
    raise ValueError("Cannot unbatch an input with scalar components.")

  # Every component's leading dimension must be mergeable into a single
  # (possibly unknown) batch dimension.
  merged_batch_dim = tensor_shape.Dimension(None)
  for shape in component_shapes:
    leading_dim = shape[0]
    try:
      merged_batch_dim = merged_batch_dim.merge_with(leading_dim)
    except ValueError:
      raise ValueError("Cannot unbatch an input whose components have "
                       "different batch sizes.")

  self._input_dataset = input_dataset
  # The element structure is derived via the input structure's private
  # `_unbatch()` helper (presumably strips the leading dimension -- its
  # definition is not visible here).
  input_structure = dataset_ops.get_structure(input_dataset)
  self._structure = input_structure._unbatch()  # pylint: disable=protected-access

  # `flat_structure(self)` is evaluated after `self._structure` is set.
  unbatched_variant = ged_ops.experimental_unbatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      **dataset_ops.flat_structure(self))
  super(_UnbatchDataset, self).__init__(input_dataset, unbatched_variant)
def _as_variant_tensor(self):
  """Builds the variant tensor for this unbatched dataset.

  Returns:
    The result of `ged_ops.experimental_unbatch_dataset` applied to the
    input dataset's variant tensor, with this dataset's flattened structure
    forwarded as keyword arguments.
  """
  # Same evaluation order as a single nested call: input variant first,
  # then the structure keyword arguments.
  input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  structure_kwargs = dataset_ops.flat_structure(self)
  return ged_ops.experimental_unbatch_dataset(input_variant, **structure_kwargs)