Example #1
    def __init__(self, input_dataset, batch_sizes, drop_remainder=False):
        """Creates a _RebatchDataset.

        Args:
          input_dataset: `Dataset` to rebatch.
          batch_sizes: A `tf.int64` scalar or vector, representing the size of
            batches to produce. If this argument is a vector, these values are
            cycled through in order.
          drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
            whether the last batch should be dropped in the case it has fewer
            than `batch_sizes[cycle_index]` elements; the default behavior is
            not to drop the smaller batch.
        """
        self._input_dataset = input_dataset
        self._batch_sizes = ops.convert_to_tensor(batch_sizes,
                                                  dtype=dtypes.int64,
                                                  name="batch_sizes")
        self._drop_remainder = ops.convert_to_tensor(drop_remainder,
                                                     dtype=dtypes.bool,
                                                     name="drop_remainder")
        new_batch_dim = self._compute_static_batch_dim()

        # pylint: disable=protected-access
        self._element_spec = nest.map_structure(
            lambda ts: ts._unbatch()._batch(new_batch_dim),
            dataset_ops.get_structure(input_dataset))
        # pylint: enable=protected-access

        input_dataset = dataset_ops.normalize_to_dense(input_dataset)
        variant_tensor = ged_ops.rebatch_dataset_v2(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            batch_sizes=batch_sizes,
            drop_remainder=drop_remainder,
            **self._flat_structure)
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
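A minimal usage sketch for the rebatching behavior shown above, assuming a recent TF 2.x release where the public `tf.data.Dataset.rebatch` method (backed by `_RebatchDataset`) is available; the sizes and values below are illustrative only:

import tensorflow as tf

# Batches of 4 are re-split into batches of 2 without a full unbatch/batch
# round trip through individual elements.
ds = tf.data.Dataset.range(8).batch(4)   # yields [0 1 2 3], [4 5 6 7]
rebatched = ds.rebatch(2)                # yields [0 1], [2 3], [4 5], [6 7]

for batch in rebatched:
    print(batch.numpy())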
Example #2
    def __init__(self, input_dataset, num_replicas):
        """Creates a _LegacyRebatchDataset.

        Args:
          input_dataset: `Dataset` to rebatch.
          num_replicas: A `tf.int64` scalar, representing the number of
            sub-batches to split each batch from `input_dataset` into.
        """
        def recalculate_batch_size(type_spec):
            """Recalculates the output_shape after dividing it by num_replicas."""
            output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
            if not isinstance(output_shape, tensor_shape.TensorShape):
                return None

            # If the output shape is unknown, we set the batch dimension to unknown.
            if output_shape.rank is None:
                return None

            if len(output_shape) < 1:
                raise ValueError(
                    "Invalid `input_dataset`. Expected a dataset whose elements "
                    "have rank >= 1 but found a dataset whose elements are scalars. "
                    "Fix the issue by adding the `batch` transformation to the "
                    "dataset.")
            output_dims = [d.value for d in output_shape.dims]

            if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
                return output_dims[0] // num_replicas

            # Set the batch dimension to unknown. If the global batch size does not
            # divide num_replicas evenly, the minibatches may have different sizes.
            return None

        def rebatch(type_spec):
            # pylint: disable=protected-access
            batch_size = recalculate_batch_size(type_spec)
            return type_spec._unbatch()._batch(batch_size)
            # pylint: enable=protected-access

        self._element_spec = nest.map_structure(
            rebatch, dataset_ops.get_structure(input_dataset))

        # auto_shard rewrite assumes that there's normalize_to_dense before
        # rebatch_dataset.
        # LINT.IfChange
        input_dataset = dataset_ops.normalize_to_dense(input_dataset)
        variant_tensor = ged_ops.rebatch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_replicas=num_replicas,
            **self._flat_structure)
        # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc)
        super(_LegacyRebatchDataset, self).__init__(input_dataset,
                                                    variant_tensor)
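A hedged sketch of how `_LegacyRebatchDataset` is exercised, assuming the internal module path `tensorflow.python.data.experimental.ops.distribute` (internal APIs can move between TF versions); values are illustrative:

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute  # internal API

# A global batch of 8 split across 4 replicas yields sub-batches of 2.
ds = tf.data.Dataset.range(8).batch(8)
legacy = distribute._LegacyRebatchDataset(ds, num_replicas=4)  # pylint: disable=protected-access

for batch in legacy:
    print(batch.numpy())  # [0 1], [2 3], [4 5], [6 7]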
Example #3
    def __init__(self, input_dataset, num_replicas, use_fallback=True):
        def recalculate_batch_size(type_spec):
            """Recalculates the output_shape after dividing it by num_replicas."""
            output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
            if not isinstance(output_shape, tensor_shape.TensorShape):
                return None

            # If the output shape is unknown, we set the batch dimension to unknown.
            if output_shape.rank is None:
                return None

            if len(output_shape) < 1:
                raise ValueError(
                    "Expected a dataset whose elements have rank >= 1 "
                    "but found a dataset whose elements are scalars. "
                    "You can fix the issue by adding the `batch` "
                    "transformation to the dataset.")
            output_dims = [d.value for d in output_shape.dims]

            if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
                return output_dims[0] // num_replicas

            # Set the batch dimension to unknown. If the global batch size does not
            # divide num_replicas evenly, the minibatches may have different sizes.
            return None

        def rebatch(type_spec):
            # pylint: disable=protected-access
            batch_size = recalculate_batch_size(type_spec)
            return type_spec._unbatch()._batch(batch_size)
            # pylint: enable=protected-access

        self._element_spec = nest.map_structure(
            rebatch, dataset_ops.get_structure(input_dataset))
        input_dataset = dataset_ops.normalize_to_dense(input_dataset)
        variant_tensor = ged_ops.rebatch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_replicas=num_replicas,
            **self._flat_structure)
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)

    def __init__(self, input_dataset, num_replicas, use_fallback=True):
        def recalculate_batch_size(output_shape):
            """Recalculates the output_shape after dividing it by num_replicas."""
            # If the output shape is unknown, we set the batch dimension to unknown.
            if output_shape.rank is None:
                return None

            if len(output_shape) < 1:
                raise ValueError(
                    "Input shape should have at least one dimension. "
                    "Perhaps your input dataset is not batched?")
            output_dims = [d.value for d in output_shape.dims]

            if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
                return output_dims[0] // num_replicas

            # Set the batch dimension to unknown. If the global batch size does not
            # divide num_replicas evenly, the minibatches may have different sizes.
            return None

        def rebatch(type_spec):
            # pylint: disable=protected-access
            batch_size = recalculate_batch_size(
                type_spec._to_legacy_output_shapes())
            return type_spec._unbatch()._batch(batch_size)
            # pylint: enable=protected-access

        self._element_spec = nest.map_structure(
            rebatch, dataset_ops.get_structure(input_dataset))
        input_dataset = dataset_ops.normalize_to_dense(input_dataset)
        variant_tensor = ged_ops.rebatch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_replicas=num_replicas,
            **self._flat_structure)
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
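The static-shape arithmetic in `recalculate_batch_size` can be checked in isolation. A standalone sketch (the helper name is hypothetical; it mirrors the logic above, with an assumed `num_replicas` of 4):

import tensorflow as tf

def recalculated_batch_dim(output_shape, num_replicas):
    """Mirrors recalculate_batch_size above for a single tf.TensorShape."""
    if output_shape.rank is None:
        return None                     # fully unknown shape -> unknown batch dim
    if len(output_shape) < 1:
        raise ValueError("Expected elements of rank >= 1; did you call `batch`?")
    dims = [d.value for d in output_shape.dims]
    if dims[0] is not None and dims[0] % num_replicas == 0:
        return dims[0] // num_replicas  # statically known and evenly divisible
    return None                         # unknown or uneven: leave batch dim unknown

print(recalculated_batch_dim(tf.TensorShape([8, 224, 224, 3]), 4))   # 2
print(recalculated_batch_dim(tf.TensorShape([10, 224, 224, 3]), 4))  # None
print(recalculated_batch_dim(tf.TensorShape(None), 4))               # None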