def __init__(self, input_dataset, num_workers):
  """Creates a dataset that rebatches elements from `input_dataset`.

  Args:
    input_dataset: The batched dataset whose batches will be divided.
    num_workers: Number of workers; each static batch dimension must be
      evenly divisible by this value.

  Raises:
    ValueError: If an element shape is scalar (has no batch dimension).
    errors.InvalidArgumentError: If a static batch dimension is not
      divisible by `num_workers`.
  """
  self._input_dataset = input_dataset

  def recalculate_output_shapes(output_shapes):
    """Recalculates the output_shapes after dividing it by num_workers."""
    if len(output_shapes) < 1:
      raise ValueError(
          "Input shape should have at least one dimension.")
    batch_dim = tensor_shape.dimension_value(output_shapes[0])
    # Compare against None explicitly: the previous truthiness test
    # (`if batch_dim and ...`) silently skipped validation when the static
    # batch size was 0, which is a legal (if degenerate) dimension value.
    if batch_dim is not None and batch_dim % num_workers != 0:
      raise errors.InvalidArgumentError(
          None, None,
          "First dim of input shape: %d is not divisible by num_workers: %d" %
          (output_shapes[0], num_workers))
    output_dims = [d for d in output_shapes.dims]
    output_dims[0] = output_dims[0] // num_workers
    return tensor_shape.TensorShape(output_dims)

  input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
  input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
  input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
  # Per-component shapes with the batch dimension divided by num_workers.
  output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)
  self._structure = structure.convert_legacy_structure(
      input_types, output_shapes, input_classes)
  variant_tensor = ged_ops.experimental_rebatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_workers=num_workers,
      **dataset_ops.flat_structure(self))
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_workers):
  """Creates a dataset whose batches are split across `num_workers`.

  The batch dimension of each component shape is replaced by its
  ceiling-division by `num_workers`.

  Args:
    input_dataset: The batched dataset to rebatch.
    num_workers: Number of workers the batches are divided across.

  Raises:
    ValueError: If an element shape has rank 0 (no batch dimension).
  """
  self._input_dataset = input_dataset

  def _divided_shape(shape):
    """Returns `shape` with its leading dimension ceil-divided by num_workers."""
    if len(shape) < 1:
      raise ValueError(
          "Input shape should have at least one dimension. "
          "Perhaps your input dataset is not batched?")
    dims = list(shape.dims)
    dims[0] = (dims[0] + num_workers - 1) // num_workers
    return tensor_shape.TensorShape(dims)

  dataset = self._input_dataset
  types = dataset_ops.get_legacy_output_types(dataset)
  classes = dataset_ops.get_legacy_output_classes(dataset)
  shapes = nest.map_structure(
      _divided_shape, dataset_ops.get_legacy_output_shapes(dataset))
  self._structure = structure.convert_legacy_structure(types, shapes, classes)

  variant_tensor = ged_ops.experimental_rebatch_dataset(
      dataset._variant_tensor,  # pylint: disable=protected-access
      num_workers=num_workers,
      **self._flat_structure)
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_workers):
  """Creates a dataset that rebatches elements from `input_dataset`.

  Args:
    input_dataset: The batched dataset whose batches will be divided.
    num_workers: Number of workers; each static batch dimension must be
      evenly divisible by this value.

  Raises:
    ValueError: If an element shape is scalar (has no batch dimension).
    errors.InvalidArgumentError: If a static batch dimension is not
      divisible by `num_workers`.
  """
  self._input_dataset = input_dataset

  def recalculate_output_shapes(output_shapes):
    """Recalculates the output_shapes after dividing it by num_workers."""
    if len(output_shapes) < 1:
      raise ValueError("Input shape should have at least one dimension.")
    batch_dim = tensor_shape.dimension_value(output_shapes[0])
    # Compare against None explicitly: the previous truthiness test
    # (`if batch_dim and ...`) silently skipped validation when the static
    # batch size was 0, which is a legal (if degenerate) dimension value.
    if batch_dim is not None and batch_dim % num_workers != 0:
      raise errors.InvalidArgumentError(
          None, None,
          "First dim of input shape: %d is not divisible by num_workers: %d" %
          (output_shapes[0], num_workers))
    output_dims = [d for d in output_shapes.dims]
    output_dims[0] = output_dims[0] // num_workers
    return tensor_shape.TensorShape(output_dims)

  input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
  input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
  input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
  # Per-component shapes with the batch dimension divided by num_workers.
  output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)
  self._structure = structure.convert_legacy_structure(
      input_types, output_shapes, input_classes)
  variant_tensor = ged_ops.experimental_rebatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_workers=num_workers,
      **dataset_ops.flat_structure(self))
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_workers, use_fallback=True):
  """Creates a dataset that divides batches across `num_workers` workers.

  When the static batch size is known and divisible by `num_workers`, the
  rebatched shape keeps a static batch dimension; otherwise the batch
  dimension is left unknown, since minibatch sizes may vary.

  Args:
    input_dataset: The batched dataset to rebatch.
    num_workers: Number of workers the batches are divided across.
    use_fallback: Whether the kernel may fall back to a slower rebatching
      path; only forwarded to the op when the forward-compatibility window
      allows it.

  Raises:
    ValueError: If an element shape has rank 0 (no batch dimension).
  """
  self._input_dataset = input_dataset

  def _rebatched_shape(shape):
    """Returns `shape` with its batch dimension divided by num_workers."""
    if len(shape) < 1:
      raise ValueError(
          "Input shape should have at least one dimension. "
          "Perhaps your input dataset is not batched?")
    dims = [d.value for d in shape.dims]
    batch = dims[0]
    if batch is None or batch % num_workers != 0:
      # Unknown batch size, or one that num_workers does not divide evenly:
      # the resulting minibatches may differ in size, so leave it unknown.
      dims[0] = None
    else:
      dims[0] = batch // num_workers
    return tensor_shape.TensorShape(dims)

  dataset = self._input_dataset
  types = dataset_ops.get_legacy_output_types(dataset)
  classes = dataset_ops.get_legacy_output_classes(dataset)
  shapes = nest.map_structure(
      _rebatched_shape, dataset_ops.get_legacy_output_shapes(dataset))
  self._element_spec = structure.convert_legacy_structure(
      types, shapes, classes)

  input_variant = dataset._variant_tensor  # pylint: disable=protected-access
  if compat.forward_compatible(2019, 8, 13) or not use_fallback:
    variant_tensor = ged_ops.rebatch_dataset(
        input_variant,
        num_workers=num_workers,
        use_fallback=use_fallback,
        **self._flat_structure)
  elif compat.forward_compatible(2019, 8, 3):
    # Window where the op exists but does not yet accept `use_fallback`.
    variant_tensor = ged_ops.rebatch_dataset(
        input_variant,
        num_workers=num_workers,
        **self._flat_structure)
  else:
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        input_variant,
        num_workers=num_workers,
        **self._flat_structure)
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_workers):
  """Creates a dataset that rebatches elements from `input_dataset`.

  Args:
    input_dataset: The batched dataset whose batches will be divided.
      NOTE(review): this variant assumes a single flat output shape (it
      indexes `output_shapes` directly rather than mapping over a nested
      structure) — confirm against callers.
    num_workers: Number of workers; the static batch dimension must be
      known and evenly divisible by this value.

  Raises:
    ValueError: If the element shape is scalar, if the batch size is
      statically unknown, or if it is not divisible by `num_workers`.
  """
  self._input_dataset = input_dataset
  output_shapes = input_dataset.output_shapes
  if len(output_shapes) < 1:
    raise ValueError("Input shape should have at least one dimension.")
  # Compare against None explicitly: `not dim.value` also fires for a
  # static batch size of 0, which would be misreported as "unknown".
  if output_shapes.dims[0].value is None:
    raise ValueError("Cannot rebatch unknown batch size datasets.")
  if output_shapes.dims[0].value % num_workers != 0:
    raise ValueError(
        "First dim of input shape: %d is not divisible by num_workers: %d" %
        (output_shapes[0], num_workers))
  output_dims = [d for d in output_shapes.dims]
  output_dims[0] = output_dims[0] // num_workers
  # `tensor_shape.TensorShape` is the public constructor; `TensorShapeV1`
  # is not a name exported by the tensor_shape module.
  output_shapes = tensor_shape.TensorShape(output_dims)
  self._structure = structure.convert_legacy_structure(
      self._input_dataset.output_types, output_shapes,
      self._input_dataset.output_classes)
  variant_tensor = ged_ops.experimental_rebatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_workers=num_workers,
      **dataset_ops.flat_structure(self))
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)