Example #1
    def __init__(self, input_dataset, num_workers):
        self._input_dataset = input_dataset

        def recalculate_output_shapes(output_shapes):
            """Recalculates the output_shapes after dividing it by num_workers."""
            if len(output_shapes) < 1:
                raise ValueError(
                    "Input shape should have at least one dimension.")
            if (tensor_shape.dimension_value(output_shapes[0])
                    and tensor_shape.dimension_value(output_shapes[0]) %
                    num_workers != 0):
                raise errors.InvalidArgumentError(
                    None, None,
                    "First dim of input shape: %d is not divisible by num_workers: %d"
                    % (output_shapes[0], num_workers))
            output_dims = [d for d in output_shapes.dims]
            output_dims[0] = output_dims[0] // num_workers
            return tensor_shape.TensorShape(output_dims)

        input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
        input_shapes = dataset_ops.get_legacy_output_shapes(
            self._input_dataset)
        input_classes = dataset_ops.get_legacy_output_classes(
            self._input_dataset)
        output_shapes = nest.map_structure(recalculate_output_shapes,
                                           input_shapes)

        self._structure = structure.convert_legacy_structure(
            input_types, output_shapes, input_classes)
        variant_tensor = ged_ops.experimental_rebatch_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_workers=num_workers,
            **dataset_ops.flat_structure(self))
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
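
The core of recalculate_output_shapes above is the batch-dimension arithmetic: the first dimension of each component shape is divided by num_workers. A minimal, self-contained sketch of that step, assuming only the public tensor_shape module (the shape [8, 32] and num_workers=4 are made-up values for illustration, not taken from the snippets):

from tensorflow.python.framework import tensor_shape

num_workers = 4
global_shape = tensor_shape.TensorShape([8, 32])  # [global_batch, features]
per_worker_dims = list(global_shape.dims)
# Divide the batch (first) dimension across workers; Dimension // int yields a Dimension.
per_worker_dims[0] = per_worker_dims[0] // num_workers
print(tensor_shape.TensorShape(per_worker_dims))  # prints (2, 32)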
Example #2
  def __init__(self, input_dataset, num_workers):
    self._input_dataset = input_dataset

    def recalculate_output_shapes(output_shapes):
      """Recalculates the output_shapes after dividing it by num_workers."""
      if len(output_shapes) < 1:
        raise ValueError(
            "Input shape should have at least one dimension. "
            "Perhaps your input dataset is not batched?")
      output_dims = [d for d in output_shapes.dims]
      output_dims[0] = (output_dims[0] + num_workers - 1) // num_workers
      return tensor_shape.TensorShape(output_dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
    output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, output_shapes, input_classes)
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **self._flat_structure)
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
Example #3
  def __init__(self, input_dataset, num_workers):
    self._input_dataset = input_dataset

    def recalculate_output_shapes(output_shapes):
      """Recalculates the output_shapes after dividing it by num_workers."""
      if len(output_shapes) < 1:
        raise ValueError("Input shape should have at least one dimension.")
      if (tensor_shape.dimension_value(output_shapes[0]) and
          tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0):
        raise errors.InvalidArgumentError(
            None, None,
            "First dim of input shape: %d is not divisible by num_workers: %d" %
            (output_shapes[0], num_workers))
      output_dims = [d for d in output_shapes.dims]
      output_dims[0] = output_dims[0] // num_workers
      return tensor_shape.TensorShape(output_dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
    output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, output_shapes, input_classes)
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **dataset_ops.flat_structure(self))
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
Example #4
    def __init__(self, input_dataset, num_workers, use_fallback=True):
        self._input_dataset = input_dataset

        def recalculate_output_shapes(output_shapes):
            """Recalculates the output_shapes after dividing it by num_workers."""
            if len(output_shapes) < 1:
                raise ValueError(
                    "Input shape should have at least one dimension. "
                    "Perhaps your input dataset is not batched?")
            output_dims = [d.value for d in output_shapes.dims]

            if output_dims[0] is not None and output_dims[0] % num_workers == 0:
                output_dims[0] = output_dims[0] // num_workers
            else:
                # Set the batch dimension to unknown. If the global batch size is not
                # evenly divisible by num_workers, the per-worker minibatches may have
                # different sizes.
                output_dims[0] = None
            return tensor_shape.TensorShape(output_dims)

        input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
        input_shapes = dataset_ops.get_legacy_output_shapes(
            self._input_dataset)
        input_classes = dataset_ops.get_legacy_output_classes(
            self._input_dataset)
        output_shapes = nest.map_structure(recalculate_output_shapes,
                                           input_shapes)

        self._element_spec = structure.convert_legacy_structure(
            input_types, output_shapes, input_classes)
        if compat.forward_compatible(2019, 8, 13) or not use_fallback:
            variant_tensor = ged_ops.rebatch_dataset(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                num_workers=num_workers,
                use_fallback=use_fallback,
                **self._flat_structure)
        elif compat.forward_compatible(2019, 8, 3):
            variant_tensor = ged_ops.rebatch_dataset(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                num_workers=num_workers,
                **self._flat_structure)
        else:
            variant_tensor = ged_ops.experimental_rebatch_dataset(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                num_workers=num_workers,
                **self._flat_structure)
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
Example #5
  def __init__(self, input_dataset, num_workers):
    self._input_dataset = input_dataset
    output_shapes = input_dataset.output_shapes
    if len(output_shapes) < 1:
      raise ValueError("Input shape should have at least one dimension.")
    if not output_shapes.dims[0].value:
      raise ValueError("Cannot rebatch unknown batch size datasets.")
    if output_shapes.dims[0].value % num_workers != 0:
      raise ValueError(
          "First dim of input shape: %d is not divisible by num_workers: %d" %
          (output_shapes[0], num_workers))
    output_dims = [d for d in output_shapes.dims]
    output_dims[0] = output_dims[0] // num_workers
    output_shapes = tensor_shape.TensorShapeV1(output_dims)
    self._structure = structure.convert_legacy_structure(
        self._input_dataset.output_types, output_shapes,
        self._input_dataset.output_classes)
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **dataset_ops.flat_structure(self))
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
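
For context, here is a hedged usage sketch of how a _RebatchDataset like the ones above is typically driven. It assumes the TF 1.14-era internal layout in which the class lives in tensorflow.python.data.experimental.ops.distribute; _RebatchDataset is private, so this mirrors how the unit tests exercise it rather than any public API, and the batch sizes are illustrative.

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.experimental.ops import distribute

# A dataset whose batch (first) dimension is statically known and divisible by num_workers.
dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
# Split each global batch of 4 across 2 workers, yielding per-worker minibatches of 2.
rebatched = distribute._RebatchDataset(dataset, num_workers=2)  # pylint: disable=protected-access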