Example #1
    def __init__(self, input_dataset, num_replicas, use_fallback=True):
        self._input_dataset = input_dataset

        def recalculate_output_shapes(output_shapes):
            """Recalculates the output_shapes after dividing it by num_replicas."""
            if len(output_shapes) < 1:
                raise ValueError(
                    "Input shape should have at least one dimension. "
                    "Perhaps your input dataset is not batched?")
            output_dims = [d.value for d in output_shapes.dims]

            if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
                output_dims[0] = output_dims[0] // num_replicas
            else:
                # Set the batch dimension to unknown. If the global batch size is not
                # divisible by num_replicas, the minibatches may have different sizes.
                output_dims[0] = None
            return tensor_shape.TensorShape(output_dims)

        input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
        input_shapes = dataset_ops.get_legacy_output_shapes(
            self._input_dataset)
        input_classes = dataset_ops.get_legacy_output_classes(
            self._input_dataset)
        output_shapes = nest.map_structure(recalculate_output_shapes,
                                           input_shapes)

        self._element_spec = structure.convert_legacy_structure(
            input_types, output_shapes, input_classes)
        variant_tensor = ged_ops.rebatch_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_replicas=num_replicas,
            **self._flat_structure)
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
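
For reference, here is a minimal standalone sketch of the batch-dimension rule in recalculate_output_shapes above (plain Python; the helper name per_replica_batch is hypothetical, not TensorFlow API): a statically known batch size that is divisible by num_replicas is divided, and anything else becomes an unknown dimension.

# Sketch of the rule in recalculate_output_shapes; `per_replica_batch` is a
# hypothetical helper, not part of TensorFlow.
def per_replica_batch(batch_dim, num_replicas):
    if batch_dim is not None and batch_dim % num_replicas == 0:
        return batch_dim // num_replicas
    # Unknown, or not evenly divisible: the per-replica size is not static.
    return None

assert per_replica_batch(8, 4) == 2
assert per_replica_batch(10, 4) is None    # uneven split -> unknown
assert per_replica_batch(None, 4) is None  # unknown batch stays unknown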
Example #2
    def __init__(self, input_dataset, num_replicas, use_fallback=True):
        def recalculate_batch_size(output_shapes):
            """Recalculates the output_shapes after dividing it by num_replicas."""
            # If the output shape is unknown, we set the batch dimension to unknown.
            if output_shapes.rank is None:
                return None

            if len(output_shapes) < 1:
                raise ValueError(
                    "Input shape should have at least one dimension. "
                    "Perhaps your input dataset is not batched?")
            output_dims = [d.value for d in output_shapes.dims]

            if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
                return output_dims[0] // num_replicas

            # Set the batch dimension to unknown. If the global batch size is not
            # divisible by num_replicas, the minibatches may have different sizes.
            return None

        def rebatch(type_spec):
            # pylint: disable=protected-access
            batch_size = recalculate_batch_size(
                type_spec._to_legacy_output_shapes())
            return type_spec._unbatch()._batch(batch_size)
            # pylint: enable=protected-access

        self._element_spec = nest.map_structure(
            rebatch, dataset_ops.get_structure(input_dataset))
        variant_tensor = ged_ops.rebatch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_replicas=num_replicas,
            **self._flat_structure)
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
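
The rebatch helper above relies on the private TypeSpec methods _unbatch() and _batch() to swap out the leading dimension. A rough public-API sketch of the same effect for a plain dense spec (assuming TensorFlow 2.x; illustration only, not what the op itself does):

import tensorflow as tf

# Illustration only: replace the leading (batch) dimension of a dense
# TensorSpec, mimicking the effect of type_spec._unbatch()._batch(2).
spec = tf.TensorSpec(shape=[8, 224, 224, 3], dtype=tf.float32)
rebatched = tf.TensorSpec(shape=[2] + spec.shape.as_list()[1:], dtype=spec.dtype)
print(rebatched)  # TensorSpec(shape=(2, 224, 224, 3), dtype=tf.float32, name=None)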
Example #3
  def __init__(self, input_dataset, num_workers):
    self._input_dataset = input_dataset

    def recalculate_output_shapes(output_shapes):
      """Recalculates the output_shapes after dividing it by num_workers."""
      if len(output_shapes) < 1:
        raise ValueError(
            "Input shape should have at least one dimension. "
            "Perhaps your input dataset is not batched?")
      output_dims = [d for d in output_shapes.dims]
      output_dims[0] = (output_dims[0] + num_workers - 1) // num_workers
      return tensor_shape.TensorShape(output_dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
    output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

    self._element_spec = structure.convert_legacy_structure(
        input_types, output_shapes, input_classes)
    if compat.forward_compatible(2019, 8, 3):
      variant_tensor = ged_ops.rebatch_dataset(
          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
          num_workers=num_workers,
          **self._flat_structure)
    else:
      variant_tensor = ged_ops.experimental_rebatch_dataset(
          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
          num_workers=num_workers,
          **self._flat_structure)
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
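
Unlike the divisibility-based variants, this version uses ceiling division, so the recalculated static batch dimension is an upper bound and the final minibatch on a worker may be smaller. A quick sketch of the arithmetic:

# Ceiling division as in Example #3: (b + n - 1) // n.
num_workers = 4
for global_batch in (8, 10):
    per_worker = (global_batch + num_workers - 1) // num_workers
    print(global_batch, "->", per_worker)  # prints: 8 -> 2, then 10 -> 3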
Example #4
    def __init__(self, input_dataset, num_replicas):
        """Creates a _LegacyRebatchDataset.

        Args:
          input_dataset: `Dataset` to rebatch.
          num_replicas: A `tf.int64` scalar, representing the number of
            sub-batches to split each batch from `input_dataset` into.
        """

        def recalculate_batch_size(type_spec):
            """Recalculates the output_shape after dividing it by num_replicas."""
            output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
            if not isinstance(output_shape, tensor_shape.TensorShape):
                return None

            # If the output shape is unknown, we set the batch dimension to unknown.
            if output_shape.rank is None:
                return None

            if len(output_shape) < 1:
                raise ValueError(
                    "Invalid `input_dataset`. Expected a dataset whose elements "
                    "have rank >= 1 but found a dataset whose elements are scalars. "
                    "Fix the issue by adding the `batch` transformation to the "
                    "dataset.")
            output_dims = [d.value for d in output_shape.dims]

            if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
                return output_dims[0] // num_replicas

            # Set the batch dimension to unknown. If the global batch size is not
            # divisible by num_replicas, the minibatches may have different sizes.
            return None

        def rebatch(type_spec):
            # pylint: disable=protected-access
            batch_size = recalculate_batch_size(type_spec)
            return type_spec._unbatch()._batch(batch_size)
            # pylint: enable=protected-access

        self._element_spec = nest.map_structure(
            rebatch, dataset_ops.get_structure(input_dataset))

        # The auto_shard rewrite assumes that normalize_to_dense is applied
        # before rebatch_dataset.
        # LINT.IfChange
        input_dataset = dataset_ops.normalize_to_dense(input_dataset)
        variant_tensor = ged_ops.rebatch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_replicas=num_replicas,
            **self._flat_structure)
        # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc)
        super(_LegacyRebatchDataset, self).__init__(input_dataset,
                                                    variant_tensor)
Example #5
    def __init__(self, input_dataset, num_replicas, use_fallback=True):
        def recalculate_batch_size(type_spec):
            """Recalculates the output_shape after dividing it by num_replicas."""
            output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
            if not isinstance(output_shape, tensor_shape.TensorShape):
                return None

            # If the output shape is unknown, we set the batch dimension to unknown.
            if output_shape.rank is None:
                return None

            if len(output_shape) < 1:
                raise ValueError(
                    "Expected a dataset whose elements have rank >= 1 "
                    "but found a dataset whose elements are scalars. "
                    "You can fix the issue by adding the `batch` "
                    "transformation to the dataset.")
            output_dims = [d.value for d in output_shape.dims]

            if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
                return output_dims[0] // num_replicas

            # Set the batch dimension to unknown. If the global batch size is not
            # divisible by num_replicas, the minibatches may have different sizes.
            return None

        def rebatch(type_spec):
            # pylint: disable=protected-access
            batch_size = recalculate_batch_size(type_spec)
            return type_spec._unbatch()._batch(batch_size)
            # pylint: enable=protected-access

        self._element_spec = nest.map_structure(
            rebatch, dataset_ops.get_structure(input_dataset))
        input_dataset = dataset_ops.normalize_to_dense(input_dataset)
        variant_tensor = ged_ops.rebatch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_replicas=num_replicas,
            **self._flat_structure)
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
Example #6
    def __init__(self, input_dataset, num_workers):
        self._input_dataset = input_dataset

        def recalculate_output_shapes(output_shapes):
            """Recalculates the output_shapes after dividing it by num_workers."""
            if len(output_shapes) < 1:
                raise ValueError(
                    "Input shape should have at least one dimension.")
            dim0 = tensor_shape.dimension_value(output_shapes[0])
            # Compare against None explicitly so a (valid) batch size of 0 is
            # not silently skipped by truthiness.
            if dim0 is not None and dim0 % num_workers != 0:
                raise errors.InvalidArgumentError(
                    None, None,
                    "First dim of input shape: %d is not divisible by "
                    "num_workers: %d" % (dim0, num_workers))
            output_dims = [d for d in output_shapes.dims]
            output_dims[0] = output_dims[0] // num_workers
            return tensor_shape.TensorShape(output_dims)

        input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
        input_shapes = dataset_ops.get_legacy_output_shapes(
            self._input_dataset)
        input_classes = dataset_ops.get_legacy_output_classes(
            self._input_dataset)
        output_shapes = nest.map_structure(recalculate_output_shapes,
                                           input_shapes)

        self._structure = structure.convert_legacy_structure(
            input_types, output_shapes, input_classes)
        if compat.forward_compatible(2019, 8, 3):
            variant_tensor = ged_ops.rebatch_dataset(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                num_workers=num_workers,
                **self._flat_structure)
        else:
            variant_tensor = ged_ops.experimental_rebatch_dataset(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                num_workers=num_workers,
                **self._flat_structure)
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
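
This variant is stricter than Examples #1 through #5: a statically known batch dimension that is not a multiple of num_workers raises an error instead of becoming an unknown dimension. A plain-Python sketch of that policy (the helper name strict_per_worker_batch is hypothetical):

# Sketch of the strict policy in Example #6: a known batch dimension must
# divide evenly across workers, otherwise rebatching fails.
def strict_per_worker_batch(batch_dim, num_workers):
    if batch_dim is not None and batch_dim % num_workers != 0:
        raise ValueError(
            "First dim of input shape: %d is not divisible by num_workers: %d"
            % (batch_dim, num_workers))
    return None if batch_dim is None else batch_dim // num_workers

assert strict_per_worker_batch(8, 4) == 2
# strict_per_worker_batch(10, 4) raises ValueError.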