def __init__(self, input_dataset, num_replicas, use_fallback=True):
  """Creates a dataset that divides each batch of `input_dataset`.

  Args:
    input_dataset: `Dataset` whose batches will be split.
    num_replicas: Number of sub-batches to split each input batch into.
    use_fallback: Unused here; retained for interface compatibility.
  """
  self._input_dataset = input_dataset

  def recalculate_output_shapes(output_shapes):
    """Recalculates the output_shapes after dividing it by num_replicas."""
    # Fix: a shape of unknown rank has no `dims`, and `len()` on it raises.
    # Keep the whole shape unknown instead of crashing (matches how the
    # TypeSpec-based variant of this class handles `rank is None`).
    if output_shapes.rank is None:
      return tensor_shape.TensorShape(None)
    if len(output_shapes) < 1:
      raise ValueError(
          "Input shape should have at least one dimension. "
          "Perhaps your input dataset is not batched?")
    output_dims = [d.value for d in output_shapes.dims]
    if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
      output_dims[0] = output_dims[0] // num_replicas
    else:
      # Set the batch dimension to unknown. If the global batch size does not
      # divide num_replicas evenly, the minibatches may have different sizes.
      output_dims[0] = None
    return tensor_shape.TensorShape(output_dims)

  input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
  input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
  input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
  output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

  self._element_spec = structure.convert_legacy_structure(
      input_types, output_shapes, input_classes)
  variant_tensor = ged_ops.rebatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_replicas=num_replicas,
      **self._flat_structure)
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_replicas, use_fallback=True):
  """Creates a dataset that splits each input batch into `num_replicas` parts."""

  def _minibatch_size(shape):
    """Returns the per-replica batch size implied by `shape`, or None."""
    # Unknown rank: nothing we can divide, so leave the batch size unknown.
    if shape.rank is None:
      return None
    if len(shape) < 1:
      raise ValueError(
          "Input shape should have at least one dimension. "
          "Perhaps your input dataset is not batched?")
    dims = [d.value for d in shape.dims]
    batch_dim = dims[0]
    if batch_dim is not None and batch_dim % num_replicas == 0:
      return batch_dim // num_replicas
    # Set the batch dimension to unknown. If the global batch size does not
    # divide num_replicas evenly, the minibatches may have different sizes.
    return None

  def _rebatch_spec(spec):
    # pylint: disable=protected-access
    size = _minibatch_size(spec._to_legacy_output_shapes())
    return spec._unbatch()._batch(size)
    # pylint: enable=protected-access

  self._element_spec = nest.map_structure(
      _rebatch_spec, dataset_ops.get_structure(input_dataset))
  variant_tensor = ged_ops.rebatch_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_replicas=num_replicas,
      **self._flat_structure)
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_workers):
  """Creates a dataset that rebatches elements of `input_dataset` per worker."""
  self._input_dataset = input_dataset

  def _divide_shape(shape):
    """Recalculates an output shape after dividing its batch dim by num_workers."""
    if len(shape) < 1:
      raise ValueError(
          "Input shape should have at least one dimension. "
          "Perhaps your input dataset is not batched?")
    dims = list(shape.dims)
    # Ceiling division so batches that do not divide evenly still cover
    # every element.
    dims[0] = (dims[0] + num_workers - 1) // num_workers
    return tensor_shape.TensorShape(dims)

  input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
  input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
  input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
  output_shapes = nest.map_structure(_divide_shape, input_shapes)
  self._element_spec = structure.convert_legacy_structure(
      input_types, output_shapes, input_classes)

  # Select the op once: the stable name inside the forward-compatibility
  # window, the experimental name otherwise.
  if compat.forward_compatible(2019, 8, 3):
    rebatch_op = ged_ops.rebatch_dataset
  else:
    rebatch_op = ged_ops.experimental_rebatch_dataset
  variant_tensor = rebatch_op(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_workers=num_workers,
      **self._flat_structure)
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_replicas):
  """Creates a _LegacyRebatchDataset.

  Args:
    input_dataset: `Dataset` to rebatch.
    num_replicas: A `tf.int64` scalar, representing the number of sub-batches
      to split each batch from `input_dataset` into.
  """

  def recalculate_batch_size(type_spec):
    """Recalculates the output_shape after dividing it by num_replicas."""
    output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    # NOTE(review): specs whose legacy output shape is not a TensorShape get
    # an unknown batch size here — presumably `_batch(None)` below handles
    # that; confirm against the TypeSpec implementations.
    if not isinstance(output_shape, tensor_shape.TensorShape):
      return None

    # If the output shape is unknown, we set the batch dimension to unknown.
    if output_shape.rank is None:
      return None

    if len(output_shape) < 1:
      raise ValueError(
          "Invalid `input_dataset`. Expected a dataset whose elements "
          "have rank >= 1 but found a dataset whose elements are scalars. "
          "Fix the issue by adding the `batch` transformation to the "
          "dataset.")
    output_dims = [d.value for d in output_shape.dims]

    # A statically-known batch size that divides evenly yields a known
    # per-replica size.
    if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
      return output_dims[0] // num_replicas

    # Set the batch dimension to unknown. If the global batch size does not
    # divide num_replicas evenly, the minibatches may have different sizes.
    return None

  def rebatch(type_spec):
    # Replace each component spec's batch dimension with the per-replica size.
    # pylint: disable=protected-access
    batch_size = recalculate_batch_size(type_spec)
    return type_spec._unbatch()._batch(batch_size)
    # pylint: enable=protected-access

  self._element_spec = nest.map_structure(
      rebatch, dataset_ops.get_structure(input_dataset))

  # auto_shard rewrite assumes that there's normalize_to_dense before
  # rebatch_dataset.
  # LINT.IfChange
  input_dataset = dataset_ops.normalize_to_dense(input_dataset)
  variant_tensor = ged_ops.rebatch_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_replicas=num_replicas,
      **self._flat_structure)
  # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc)
  super(_LegacyRebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_replicas, use_fallback=True):
  """Splits each batch of `input_dataset` into `num_replicas` minibatches."""

  def _per_replica_batch_size(spec):
    """Returns the divided batch size for `spec`, or None when unknown."""
    legacy_shape = spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    if not isinstance(legacy_shape, tensor_shape.TensorShape):
      return None
    # Unknown rank: the batch dimension stays unknown.
    if legacy_shape.rank is None:
      return None
    if len(legacy_shape) < 1:
      raise ValueError("Expected a dataset whose elements have rank >= 1 "
                       "but found a dataset whose elements are scalars. "
                       "You can fix the issue by adding the `batch` "
                       "transformation to the dataset.")
    first_dim = legacy_shape.dims[0].value
    if first_dim is not None and first_dim % num_replicas == 0:
      return first_dim // num_replicas
    # Set the batch dimension to unknown. If the global batch size does not
    # divide num_replicas evenly, the minibatches may have different sizes.
    return None

  def _respec(spec):
    # pylint: disable=protected-access
    return spec._unbatch()._batch(_per_replica_batch_size(spec))
    # pylint: enable=protected-access

  self._element_spec = nest.map_structure(
      _respec, dataset_ops.get_structure(input_dataset))
  input_dataset = dataset_ops.normalize_to_dense(input_dataset)
  variant_tensor = ged_ops.rebatch_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_replicas=num_replicas,
      **self._flat_structure)
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_workers):
  """Creates a dataset dividing each input batch evenly across workers."""
  self._input_dataset = input_dataset

  def _divided_shape(shape):
    """Recalculates the output shape after dividing it by num_workers."""
    if len(shape) < 1:
      raise ValueError(
          "Input shape should have at least one dimension.")
    first = tensor_shape.dimension_value(shape[0])
    # A statically-known batch dimension must divide evenly across workers.
    if first and first % num_workers != 0:
      raise errors.InvalidArgumentError(
          None, None,
          "First dim of input shape: %d is not divisible by num_workers: %d" %
          (shape[0], num_workers))
    dims = list(shape.dims)
    dims[0] = dims[0] // num_workers
    return tensor_shape.TensorShape(dims)

  input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
  input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
  input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
  output_shapes = nest.map_structure(_divided_shape, input_shapes)
  self._structure = structure.convert_legacy_structure(
      input_types, output_shapes, input_classes)

  if compat.forward_compatible(2019, 8, 3):
    variant_tensor = ged_ops.rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **self._flat_structure)
  else:
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **self._flat_structure)
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)