示例#1
0
    def __init__(self, input_dataset):
        """Checks the input's component shapes and creates the unbatch op.

        See `unbatch()` for more details.

        Args:
          input_dataset: The dataset to unbatch.

        Raises:
          ValueError: If any flattened component shape has rank 0, or if the
            leading dimensions of the components cannot be merged together.
        """
        batched_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
        component_shapes = nest.flatten(batched_shapes)

        # First pass: every component needs a leading dimension to strip.
        for shape in component_shapes:
            if shape.ndims == 0:
                raise ValueError(
                    "Cannot unbatch an input with scalar components.")

        # Second pass: fold all leading dimensions into one; the merged value is
        # used only for validation, a failed merge means the sizes disagree.
        merged_batch_dim = tensor_shape.Dimension(None)
        for shape in component_shapes:
            try:
                merged_batch_dim = merged_batch_dim.merge_with(shape[0])
            except ValueError:
                raise ValueError(
                    "Cannot unbatch an input whose components have "
                    "different batch sizes.")

        self._input_dataset = input_dataset

        def _strip_leading_dim(shape):
            return shape[1:]

        # The element structure keeps the input's types and classes, with the
        # leading dimension removed from each shape.
        element_types = dataset_ops.get_legacy_output_types(input_dataset)
        element_shapes = nest.map_structure(_strip_leading_dim, batched_shapes)
        element_classes = dataset_ops.get_legacy_output_classes(input_dataset)
        self._structure = structure.convert_legacy_structure(
            element_types, element_shapes, element_classes)

        upstream_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
        # NOTE(review): flat_structure(self) presumably reads self._structure,
        # which is why it is assigned above - confirm.
        op_kwargs = dataset_ops.flat_structure(self)
        variant_tensor = ged_ops.experimental_unbatch_dataset(
            upstream_variant, **op_kwargs)
        super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
示例#2
0
  def __init__(self, input_dataset):
    """Checks the input's component shapes and creates the unbatch op.

    See `unbatch()` for more details.

    Args:
      input_dataset: The dataset to unbatch.

    Raises:
      ValueError: If any flattened component shape has rank 0, or if the
        leading dimensions of the components cannot be merged together.
    """
    batched_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
    component_shapes = nest.flatten(batched_shapes)

    # First pass: every component needs a leading dimension to strip.
    for shape in component_shapes:
      if shape.ndims == 0:
        raise ValueError("Cannot unbatch an input with scalar components.")

    # Second pass: fold all leading dimensions into one; the merged value is
    # used only for validation, a failed merge means the sizes disagree.
    merged_batch_dim = tensor_shape.Dimension(None)
    for shape in component_shapes:
      try:
        merged_batch_dim = merged_batch_dim.merge_with(shape[0])
      except ValueError:
        raise ValueError("Cannot unbatch an input whose components have "
                         "different batch sizes.")

    self._input_dataset = input_dataset

    # The element structure is derived from the input's own structure.
    batched_structure = dataset_ops.get_structure(input_dataset)
    self._structure = batched_structure._unbatch()  # pylint: disable=protected-access

    upstream_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    # NOTE(review): flat_structure(self) presumably reads self._structure,
    # which is why it is assigned above - confirm.
    op_kwargs = dataset_ops.flat_structure(self)
    variant_tensor = ged_ops.experimental_unbatch_dataset(
        upstream_variant, **op_kwargs)
    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
示例#3
0
 def _as_variant_tensor(self):
     """Returns the result of the unbatch op applied to the input's variant.

     The op receives the input dataset's variant tensor plus the keyword
     arguments produced by `dataset_ops.flat_structure(self)`.
     """
     unbatch_op = ged_ops.experimental_unbatch_dataset
     upstream_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
     op_kwargs = dataset_ops.flat_structure(self)
     return unbatch_op(upstream_variant, **op_kwargs)
示例#4
0
 def _as_variant_tensor(self):
   """Returns the result of the unbatch op applied to the input's variant.

   The op receives the input dataset's variant tensor plus the keyword
   arguments produced by `dataset_ops.flat_structure(self)`.
   """
   unbatch_op = ged_ops.experimental_unbatch_dataset
   upstream_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
   op_kwargs = dataset_ops.flat_structure(self)
   return unbatch_op(upstream_variant, **op_kwargs)