def testScaleSplitToInfeedGPU(self, use_per_host_infeed, split_size):
  """Checks scale_split_to_infeed on a 128-GPU test cluster.

  On GPU the per-split batch is scaled up by the number of splits,
  regardless of use_per_host_infeed.
  """
  expected_splits = 128 // split_size
  with cluster_factory.ForTestingWorker(
      gpus=128, split_size=split_size) as cluster:
    self.assertEqual(cluster.num_splits_per_client, expected_splits)
    scaled = batch_utils.scale_split_to_infeed(1024, use_per_host_infeed)
    self.assertEqual(scaled, 1024 * expected_splits)
def testScaleSplitToInfeedTPU(self, use_per_host_infeed, split_size,
                              num_tpu_hosts):
  """Checks scale_split_to_infeed on a 128-TPU test cluster.

  With per-host infeed the scaled batch is divided across the TPU hosts;
  otherwise a single infeed carries the whole scaled batch.
  """
  expected_splits = 128 // split_size
  expected_infeeds = num_tpu_hosts if use_per_host_infeed else 1
  with cluster_factory.ForTestingWorker(
      tpus=128, split_size=split_size,
      num_tpu_hosts=num_tpu_hosts) as cluster:
    self.assertEqual(cluster.num_splits_per_client, expected_splits)
    scaled = batch_utils.scale_split_to_infeed(1024, use_per_host_infeed)
    self.assertEqual(scaled, 1024 * expected_splits // expected_infeeds)
def infeed_bucket_batch_limit(self):
  """Returns the bucket batch limit for one infeed host.

  Each per-split limit in params.bucket_batch_limit is rescaled via
  batch_utils.scale_split_to_infeed for the current cluster configuration.
  """
  p = self.params
  per_infeed_limits = []
  for limit in p.bucket_batch_limit:
    per_infeed_limits.append(
        batch_utils.scale_split_to_infeed(limit, p.use_per_host_infeed))
  tf.logging.info(
      'infeed_bucket_batch_limit={} num_splits_per_client={} bucket_batch_limit={}'
      .format(per_infeed_limits, self.cluster.num_splits_per_client,
              p.bucket_batch_limit))
  return per_infeed_limits
def Transform(self, dataset):
  """Batches a dataset containing NestedMaps of tensors.

  Pipeline: tag each example with a bucket key (its sequence length),
  drop examples longer than the largest bucket bound, then bucket-batch
  with padding. On TPU the batch dimension is additionally pinned to a
  static size.

  Args:
    dataset: a tf.data.Dataset of NestedMaps (presumably produced by this
      input generator's earlier stages — confirm against caller).

  Returns:
    The batched tf.data.Dataset.

  Raises:
    ValueError: on TPU, when bucket_batch_limit entries are not all equal
      (TPU requires a constant batch size).
  """
  p = self.params
  # Deterministic map order is required both when the config asks for it
  # and during eval, so results are reproducible.
  require_sequential_order = p.require_sequential_order or self.do_eval
  # p.seqlen_fn names a method on the input generator that computes the
  # bucketing key for one example.
  seqlen_fn = getattr(self._input_generator, p.seqlen_fn)

  def SetBucketKeys(example):
    # Attach the bucketing key so bucket_by_sequence_length can read it.
    example.bucket_keys = seqlen_fn(example)
    return example

  dataset = dataset.map(SetBucketKeys,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE,
                        deterministic=require_sequential_order)
  # Drop examples that exceed the largest bucket's upper bound.
  dataset = dataset.filter(
      lambda x: x.bucket_keys <= p.bucket_upper_bound[-1])
  dataset_structure = py_utils.NestedMap.FromNestedDict(
      tf.data.experimental.get_structure(dataset))
  # p.input_shape_fn / p.input_padding_fn name generator methods that give,
  # per NestedMap key, the padded shape and the padding value.
  input_shape_fn = getattr(self._input_generator, p.input_shape_fn)
  padded_shapes = dataset_structure.TransformWithKey(
      lambda k, _: tf.TensorShape(input_shape_fn(k)))
  input_padding_fn = getattr(self._input_generator, p.input_padding_fn)
  padding_values = dataset_structure.TransformWithKey(input_padding_fn)
  dataset_structure.VLog(0, 'dataset_structure:')
  padded_shapes.VLog(0, 'padded_shapes:')
  # Scale each per-split limit to the per-infeed batch size.
  bucket_batch_limit = [
      batch_utils.scale_split_to_infeed(
          b, self._input_generator.params.use_per_host_infeed)
      for b in p.bucket_batch_limit
  ]
  dataset = dataset.apply(
      tf.data.experimental.bucket_by_sequence_length(
          lambda x: x.bucket_keys,
          # Upper-bound for bucket_by_sequence_length is exclusive, so add 1
          # TODO(jeffreyzhao): There is a off-by-one bug with the upper bound
          # boundary check, so add 2 instead. Remove when fixed.
          [x + 2 for x in p.bucket_upper_bound],
          # Trailing [1] is the (unused) limit for the overflow bucket;
          # examples that large were already filtered out above.
          bucket_batch_limit + [1],
          padded_shapes=padded_shapes,
          padding_values=padding_values,
          pad_to_bucket_boundary=True,
          drop_remainder=py_utils.use_tpu()))
  if py_utils.use_tpu():
    # Set static shapes for TPU.
    if min(bucket_batch_limit) != max(bucket_batch_limit):
      raise ValueError('TPU requires constant batch sizes.')
    else:
      b = bucket_batch_limit[0]

      def SetShape(element):
        # Pin the leading (batch) dimension; trailing dims are already
        # static thanks to pad_to_bucket_boundary.
        for t in element.Flatten():
          t.set_shape((b, ) + t.shape[1:])
        return element

      dataset = dataset.map(
          SetShape,
          num_parallel_calls=tf.data.experimental.AUTOTUNE,
          deterministic=require_sequential_order)
  return dataset
def InfeedBatchSize(self):
  """Returns the batch size of the input batch: int or dynamic int tensor."""
  p = self.params
  scaled = batch_utils.scale_split_to_infeed(p.batch_size,
                                             p.use_per_host_infeed)
  tf.logging.info('batch_per_input: %d', scaled)
  return scaled
def Transform(self, dataset):
  """Batches a dataset containing NestedMaps of tensors.

  Tags each example with a sequence-length bucket key, filters out
  examples beyond the largest bucket bound, and applies padded
  bucket-batching. When the batch size is constant across buckets, the
  batch dimension is pinned statically (unless sequential input order
  means the last batch may be short).
  """
  p = self.params
  compute_seqlen = getattr(self._input_generator, p.seqlen_fn)

  def _AttachBucketKey(example):
    example.bucket_keys = compute_seqlen(example)
    return example

  dataset = dataset.map(_AttachBucketKey, **self._map_args)
  dataset = dataset.filter(
      lambda ex: ex.bucket_keys <= p.bucket_upper_bound[-1])

  structure = py_utils.NestedMap.FromNestedDict(
      tf.data.experimental.get_structure(dataset))
  shape_fn = getattr(self._input_generator, p.input_shape_fn)
  padded_shapes = structure.TransformWithKey(
      lambda key, _: tf.TensorShape(shape_fn(key)))
  padding_values = structure.TransformWithKey(
      getattr(self._input_generator, p.input_padding_fn))
  structure.VLog(0, 'dataset_structure:')
  padded_shapes.VLog(0, 'padded_shapes:')

  use_per_host = self._input_generator.params.use_per_host_infeed
  infeed_limits = []
  for limit in p.bucket_batch_limit:
    infeed_limits.append(
        batch_utils.scale_split_to_infeed(limit, use_per_host))

  dataset = dataset.apply(
      tf.data.experimental.bucket_by_sequence_length(
          lambda ex: ex.bucket_keys,
          # bucket_by_sequence_length treats its boundaries as exclusive
          # upper bounds, hence the +1.
          [bound + 1 for bound in p.bucket_upper_bound],
          # The extra trailing limit covers the overflow bucket, which the
          # filter above keeps empty.
          infeed_limits + [1],
          padded_shapes=padded_shapes,
          padding_values=padding_values,
          pad_to_bucket_boundary=True,
          drop_remainder=py_utils.use_tpu()))

  # Set static shapes if possible.
  if self.cluster.require_sequential_input_order:
    # The input is not repeated, so only one epoch is available and the
    # final batch may be smaller; the batch dimension stays dynamic.
    pass
  elif min(infeed_limits) == max(infeed_limits):
    static_b = infeed_limits[0]

    def _FixBatchDim(element):
      for t in element.Flatten():
        t.set_shape((static_b,) + t.shape[1:])
      return element

    dataset = dataset.map(_FixBatchDim, **self._map_args)
  elif py_utils.use_tpu():
    raise ValueError('TPU requires constant batch sizes.')
  return dataset