def _maybe_apply_data_service(
    self,
    dataset: tf.data.Dataset,
    input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
    """Potentially distributes a dataset."""
    if self._enable_tf_data_service and input_context:
        if self._enable_round_robin_tf_data_service:
            replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
                input_context.num_input_pipelines)
            base_consumer_index = input_context.input_pipeline_id * (
                replicas_per_input_pipeline)
            num_consumers = input_context.num_input_pipelines * (
                replicas_per_input_pipeline)
            range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
            tfds_kwargs = {
                'processing_mode': 'parallel_epochs',
                'service': self._tf_data_service_address,
                'job_name': self._tf_data_service_job_name,
                'num_consumers': num_consumers
            }
            if self._enable_shared_tf_data_service_between_parallel_trainers:
                raise ValueError('Shared tf.data service does not support '
                                 'round-robin tf.data service.')
            dataset = range_dataset.map(
                lambda i: dataset.apply(  # pylint: disable=g-long-lambda
                    tf.data.experimental.service.distribute(
                        consumer_index=base_consumer_index + i, **tfds_kwargs)))
            # Use parallel interleave to read multiple batches from a tf.data
            # service worker in parallel.
            dataset = dataset.interleave(
                lambda x: x,
                cycle_length=replicas_per_input_pipeline,
                num_parallel_calls=replicas_per_input_pipeline,
                deterministic=True)
        else:
            tfds_kwargs = {
                'processing_mode': 'parallel_epochs',
                'service': self._tf_data_service_address,
                'job_name': self._tf_data_service_job_name,
            }
            if self._enable_shared_tf_data_service_between_parallel_trainers:
                tfds_kwargs.update({
                    'processing_mode':
                        tf.data.experimental.service.ShardingPolicy.OFF,
                    'cross_trainer_cache':
                        tf.data.experimental.service.CrossTrainerCache(
                            trainer_id=self._trainer_id)
                })
            dataset = dataset.apply(
                tf.data.experimental.service.distribute(**tfds_kwargs))
    return dataset
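# A minimal, self-contained sketch of the `tf.data.experimental.service.distribute`
# call used above, run against an in-process dispatcher/worker pair. This local
# setup is illustrative only; the snippets above assume an external tf.data service.
def _example_local_tf_data_service():
    dispatcher = tf.data.experimental.service.DispatchServer()
    worker = tf.data.experimental.service.WorkerServer(  # keep a reference alive
        tf.data.experimental.service.WorkerConfig(
            dispatcher_address=dispatcher.target.split("://")[1]))
    dataset = tf.data.Dataset.range(10)
    dataset = dataset.apply(
        tf.data.experimental.service.distribute(
            processing_mode="parallel_epochs", service=dispatcher.target))
    return dispatcher, worker, dataset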
def apply(
    dataset: tf.data.Dataset,
    transform: Union[Transform, Iterable[Optional[Transform]]],
) -> tf.data.Dataset:
    if callable(transform):
        return dataset.apply(transform)
    for tr in transform:
        if tr is not None:
            dataset = dataset.apply(tr)
    return dataset
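# Hedged usage sketch for the `apply` helper above (the example transforms are
# illustrative, not from the original source); `None` entries are skipped.
def _example_apply_usage() -> tf.data.Dataset:
    ds = tf.data.Dataset.range(10)
    return apply(ds, [
        lambda d: d.filter(lambda x: x % 2 == 0),  # keep even values
        None,                                      # ignored by the helper
        lambda d: d.batch(2),                      # then batch pairs
    ])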
def get_bucket_iter(self,
                    dataset: tf.data.Dataset,
                    batch_size=32,
                    train=True) -> tf.data.Dataset:
    padded_shapes = self._padded_shapes()
    padding_values = self._padding_values()
    if train:
        bucket_boundaries = self._bucket_boundaries(batch_size)
        bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
        dataset = dataset.apply(
            tf.data.experimental.bucket_by_sequence_length(
                self.element_length_func,
                bucket_boundaries,
                bucket_batch_sizes,
                padded_shapes=padded_shapes,
                padding_values=padding_values))
        dataset = dataset.shuffle(100)
    else:
        dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes)
    dataset = dataset.map(self._collate_fn,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
def get_augmented_data(
    dataset: tf.data.Dataset,
    batch_size: int,
    map_func: Callable,
    shuffle_buffer: Optional[int] = None,
    shuffle_seed: Optional[int] = None,
    augment_seed: Optional[int] = None,
    use_stateless_map: bool = False,
) -> RepeatedData:
    if shuffle_buffer is not None:
        dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed)
    dataset = dataset.batch(batch_size)
    steps_per_epoch = tf.keras.backend.get_value(dataset.cardinality())
    # repeat before map so stateless map is different across epochs
    dataset = dataset.repeat()
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    if use_stateless_map:
        dataset = dataset.apply(
            tfrng.data.stateless_map(
                map_func,
                seed=augment_seed,
                num_parallel_calls=AUTOTUNE,
            ))
    else:
        # if map_func has random elements this won't be deterministic
        dataset = dataset.map(map_func, num_parallel_calls=AUTOTUNE)
    dataset = dataset.prefetch(AUTOTUNE)
    return RepeatedData(dataset, steps_per_epoch)
def prepare_dataset(self,
                    dataset: tf.data.Dataset,
                    buckets: List[int],
                    batch_sizes: List[int],
                    shuffle: bool = False) -> tf.data.Dataset:
    dataset = dataset.map(self._deserialization_func, num_parallel_calls=128)

    buckets_array = np.array(buckets)
    batch_sizes_array = np.array(batch_sizes)
    if np.any(batch_sizes_array == 0) and shuffle:
        iszero = np.where(batch_sizes_array == 0)[0][0]
        filterlen = buckets_array[iszero - 1]
        print("Filtering sequences of length {}".format(filterlen))
        dataset = dataset.filter(
            lambda example: example['protein_length'] < filterlen)
    else:
        batch_sizes_array[batch_sizes_array <= 0] = 1

    dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)
    batch_fun = tf.data.experimental.bucket_by_sequence_length(
        operator.itemgetter('protein_length'), buckets_array,
        batch_sizes_array)
    dataset = dataset.apply(batch_fun)
    return dataset
def batch_dataset(dataset: tf.data.Dataset,
                  model: LineRecognizer,
                  batch_size=32,
                  bucket_boundaries=None,
                  padded=True):
    # add image widths and text length
    dataset = dataset.map(lambda i, t: (i, tf.shape(i)[1], t,
                                        tf.strings.length(t, unit='UTF8_CHAR')))
    dataset = dataset.map(
        lambda image, width, text, length:
        (image, width, model.encoder.encode(text), length))

    output_shapes = (model.image_shape, [], [None], [])

    if bucket_boundaries:
        if isinstance(batch_size, int):
            batch_size = [batch_size] * (len(bucket_boundaries) + 1)
        dataset = dataset.apply(
            tf.data.experimental.bucket_by_sequence_length(
                lambda i, w, label, length: w,
                bucket_boundaries=bucket_boundaries,
                bucket_batch_sizes=batch_size,
                padded_shapes=output_shapes))
    elif padded:
        dataset = dataset.padded_batch(batch_size=batch_size,
                                       padded_shapes=output_shapes)
    else:
        dataset = dataset.batch(batch_size)

    return dataset
def bucket_batch_dataset(dataset: tf.data.Dataset, batches: List[int],
                         boundaries: List[int]) -> tf.data.Dataset:
    op = tf.contrib.data.bucket_by_sequence_length(
        element_length_func=lambda x, y: tf.shape(x)[0],
        bucket_batch_sizes=batches,
        bucket_boundaries=boundaries)
    return dataset.apply(op)
def block_diagonal_batch_with_batch_size(dataset: tf.data.Dataset,
                                         batch_size: int):
    """Batch the input dataset block diagonally up to the given batch size.

    Args:
        dataset: tf.data.Dataset with spec ((nodes, (link*)), labels).
            nodes: [V?, ...] node features.
            link: [E?, 2] int edge/link indices.
            labels: [V?, ...] or [...] label data.
        batch_size: number of examples in the resulting batch.

    Returns:
        dataset with spec:
            nodes: [B, V?, ...] ragged node features.
            links: [E, 2] indices into flattened nodes.
            labels: [BV, ...] or [B, ...]

        B = batch_size
        BV = sum_b V_b
    """
    dataset = dataset.apply(
        tf.data.experimental.dense_to_ragged_batch(batch_size,
                                                   row_splits_dtype=tf.int32))
    return dataset.map(
        lambda *args: _block_diagonalize_batched(*_unpack_dataset(*args)))
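# Self-contained sketch of the ragged batching step used above, on toy data
# (`_unpack_dataset` / `_block_diagonalize_batched` are assumed to be defined
# elsewhere, so only `dense_to_ragged_batch` is exercised here).
def _example_dense_to_ragged_batch() -> tf.data.Dataset:
    toy = tf.data.Dataset.from_tensor_slices(
        tf.ragged.constant([[1, 2], [3], [4, 5, 6]]))  # variable-length rows
    return toy.apply(
        tf.data.experimental.dense_to_ragged_batch(batch_size=2,
                                                   row_splits_dtype=tf.int32))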
def prefetch_to_available_gpu_device(dataset: tf.data.Dataset,
                                     buffer_size: Optional[int] = None,
                                     use_workaround: bool = False):
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        assert len(gpus) == 1, (
            'Expected to find exactly 1 GPU, but found: {}'.format(gpus))
        if use_workaround:
            dataset = dataset.apply(
                tf.data.experimental.copy_to_device("/GPU:0"))
            with tf.device("/GPU:0"):
                return dataset.prefetch(buffer_size)
        else:
            return dataset.apply(
                tf.data.experimental.prefetch_to_device('/GPU:0', buffer_size))
    else:
        return dataset
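# Hedged usage sketch: on a single-GPU machine this prefetches batches onto the
# GPU; on a CPU-only machine the dataset is returned unchanged.
def _example_prefetch_to_gpu() -> tf.data.Dataset:
    ds = tf.data.Dataset.range(100).batch(8)
    return prefetch_to_available_gpu_device(ds, buffer_size=2)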
def accumulated_batch(
    dataset: tf.data.Dataset,
    accumulator: Union[Accumulator, Mapping, Iterable],
    **map_kwargs,
):
    accumulator = accumulator_structure(accumulator)

    def initial_map_fn(*args):
        if len(args) == 1:
            (args,) = args
        return args, False

    @tf.function
    def scan_fn(state, el_and_final):
        el, final = el_and_final
        new_state = accumulator.append(state, el)
        valid = tf.reduce_all(accumulator.valid_conditions(new_state))
        invalid = tf.logical_not(valid)
        if invalid:
            new_state = accumulator.append(accumulator.initial_state(), el)
        return new_state, (state, tf.logical_or(invalid, final))

    def filter_fn(state, invalid):
        del state
        return invalid

    def map_fn(state, invalid):
        del invalid
        return accumulator.finalize(state)

    cardinality = tf.data.experimental.cardinality(dataset)
    dataset = dataset.map(initial_map_fn)
    if cardinality != tf.data.experimental.INFINITE_CARDINALITY:
        # append (empty, True) element to ensure final elements are generated
        state_spec = dataset.element_spec[0]
        empty_el = tf.nest.map_structure(
            lambda spec: tf.zeros(
                [1, *(0 if s is None else s for s in spec.shape)],
                dtype=spec.dtype),
            state_spec,
        )
        true_el = tf.ones((1,), dtype=tf.bool)
        dataset = dataset.concatenate(
            tf.data.Dataset.from_tensor_slices((empty_el, true_el)))

    dataset = dataset.apply(
        tf.data.experimental.scan(accumulator.initial_state(), scan_fn))
    dataset = dataset.filter(filter_fn)
    dataset = dataset.map(map_fn, **map_kwargs)
    return dataset
def batch(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    bounds = list(range(self.hist_min, self.hist_max, self.hist_step))
    logging.info("Quantile bucketing from %d-%d with %d buckets" %
                 (bounds[0], bounds[-1], len(bounds)))
    return dataset.apply(
        ops.bucket_by_quantiles(
            len_fn=lambda x: tf.shape(x[PREMISE_KEY])[0],
            batch_size=self.batch_size,
            n_buckets=self.n_buckets,
            hist_bounds=bounds))
def prepare_dataset(self,
                    dataset: tf.data.Dataset,
                    buckets: List[int],
                    batch_sizes: List[int],
                    shuffle: bool = False) -> tf.data.Dataset:
    dataset = dataset.map(self._deserialization_func, 128)
    dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)
    batch_fun = tf.data.experimental.bucket_by_sequence_length(
        lambda example: tf.maximum(example['first']['protein_length'],
                                   example['second']['protein_length']),
        buckets, batch_sizes)
    dataset = dataset.apply(batch_fun)
    return dataset
def _maybe_apply_data_service(
    self,
    dataset: tf.data.Dataset,
    input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
    """Potentially distributes a dataset."""
    if self._enable_tf_data_service and input_context:
        if self._enable_round_robin_tf_data_service:
            replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
                input_context.num_input_pipelines)
            base_consumer_index = input_context.input_pipeline_id * (
                replicas_per_input_pipeline)
            num_consumers = input_context.num_input_pipelines * (
                replicas_per_input_pipeline)
            range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
            dataset = range_dataset.map(
                lambda i: dataset.apply(  # pylint: disable=g-long-lambda
                    tf.data.experimental.service.distribute(
                        processing_mode='parallel_epochs',
                        service=self._tf_data_service_address,
                        job_name=self._tf_data_service_job_name,
                        consumer_index=base_consumer_index + i,
                        num_consumers=num_consumers)))
            # Use parallel interleave to read multiple batches from a tf.data
            # service worker in parallel.
            dataset = dataset.interleave(
                lambda x: x,
                cycle_length=replicas_per_input_pipeline,
                num_parallel_calls=replicas_per_input_pipeline,
                deterministic=True)
        else:
            dataset = dataset.apply(
                tf.data.experimental.service.distribute(
                    processing_mode='parallel_epochs',
                    service=self._tf_data_service_address,
                    job_name=self._tf_data_service_job_name))
    return dataset
def reduce_func(unused_key, window: tf.data.Dataset):
    # ToDo: use padded_batch instead of padded_batch_and_drop_remainder.
    # Currently this model only works with static batch size.
    apply_fn = tf.contrib.data.padded_batch_and_drop_remainder(
        batch_size,
        padded_shapes=(
            PreparedSourceData(
                id=tf.TensorShape([]),
                text=tf.TensorShape([]),
                source=tf.TensorShape([None]),
                source_length=tf.TensorShape([]),
                text_positions=tf.TensorShape([None]),
                text2=tf.TensorShape([]),
                source2=tf.TensorShape([None]),
                source_length2=tf.TensorShape([]),
                text_positions2=tf.TensorShape([None]),
            ),
            _PreparedTargetData(
                id=tf.TensorShape([]),
                spec=tf.TensorShape([None, self.hparams.fft_size // 2 + 1]),
                spec_width=tf.TensorShape([]),
                mel=tf.TensorShape([None, self.hparams.num_mels]),
                mel_width=tf.TensorShape([]),
                target_length=tf.TensorShape([]),
                done=tf.TensorShape([None]),
            ),
        ),
        padding_values=(
            PreparedSourceData(
                id=tf.to_int64(0),
                text="",
                source=tf.to_int64(0),
                source_length=tf.to_int64(0),
                text_positions=tf.to_int64(0),
                text2="",
                source2=tf.to_int64(0),
                source_length2=tf.to_int64(0),
                text_positions2=tf.to_int64(0),
            ),
            _PreparedTargetData(
                id=tf.to_int64(0),
                spec=tf.to_float(0),
                spec_width=tf.to_int64(0),
                mel=tf.to_float(0),
                mel_width=tf.to_int64(0),
                target_length=tf.to_int64(0),
                done=tf.to_float(1),
            ),
        ))
    return window.apply(apply_fn)
def prepare_dataset(
    dataset: tf.data.Dataset,
    model_image_size: Tuple[int, int],
    augmentation_fn: Optional[ImageDataMapFn] = None,
    num_epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
    shuffle_buffer_size: Optional[int] = None,
    num_parallel_calls: Optional[int] = None,
    prefetch_buffer_size: Optional[int] = None,
    prefetch_to_device: Optional[str] = None,
) -> tf.data.Dataset:
    # apply data augmentation
    if augmentation_fn is not None:
        dataset = dataset.map(
            map_image_data(augmentation_fn),
            num_parallel_calls=num_parallel_calls,
        )

    if shuffle_buffer_size is not None:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(
        map_image_data(prepare_for_batching(model_image_size)),
        num_parallel_calls=num_parallel_calls,
    )

    # batching and padding
    if batch_size is not None:
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=get_padding_shapes(
                dataset, spatial_image_shape=model_image_size),
            drop_remainder=True,
        )

    # try to prefetch dataset on certain device
    if prefetch_to_device is not None:
        prefetch_fn = tf.data.experimental.prefetch_to_device(
            device=prefetch_to_device, buffer_size=prefetch_buffer_size)
        dataset = dataset.apply(prefetch_fn)
    elif prefetch_buffer_size is not None:
        dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)

    return dataset
def send_to_data_service(dataset: tf.data.Dataset,
                         compute_config: TfDataServiceConfig,
                         rank: int,
                         size: Optional[int] = None,
                         processing_mode: str = 'distributed_epoch',
                         reuse_dataset: bool = False,
                         round_robin: bool = False) -> tf.data.Dataset:
    if compute_config.dispatcher_side == 'training':
        raise RuntimeError('training side dispatcher not supported, '
                           'use tf_data_service context manager instead')

    with tf_data_service(compute_config, rank) as dispatcher_address:
        return dataset.apply(
            tf.data.experimental.service.distribute(
                processing_mode=processing_mode,
                service=dispatcher_address,
                job_name='job' if reuse_dataset else None,
                consumer_index=rank if reuse_dataset and round_robin else None,
                num_consumers=size if reuse_dataset and round_robin else None))
def get_top_tokens(corpus: tf.data.Dataset,
                   n_top: int = 1000) -> Tuple[dict, int, int]:
    """Builds the token mapping which is used to initialize the word embeddings
    in the model. Get the most frequent terms which appear in the training corpus.

    Parameters
    ----------
    corpus : tf.data.Dataset
        Entire dataset object
    n_top : int, optional
        Number of most frequent vocab terms to keep for training, by default 1000

    Returns
    -------
    (dict, int, int)
        (token->integer lookup, maximum sequence length, size of data set)
    """
    lookup = Counter()
    max_sequence_length, data_set_size = 0, 0

    corpus = corpus.map(lambda x: tf.strings.split(x, sep=''),
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    for tokens_list in corpus.apply(
            tf.data.experimental.dense_to_ragged_batch(32)).prefetch(5):
        lookup.update(tokens_list.flat_values.numpy())

        max_batch_seq_len = int(tokens_list.row_lengths().numpy().max())
        if max_batch_seq_len > max_sequence_length:
            max_sequence_length = max_batch_seq_len
        data_set_size += int(tokens_list.nrows())

    # tensorflow converts strings to bytes, let's maintain that (no decoding)
    hash_map = {
        key: idx + 2
        for idx, (key, value) in enumerate(lookup.most_common(n_top))
    }
    hash_map["<s>".encode('utf8')] = 0
    hash_map["</s>".encode('utf8')] = 1
    return hash_map, max_sequence_length, data_set_size
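# Hedged usage sketch for `get_top_tokens` on a tiny in-memory corpus
# (the corpus and the expected numbers below are illustrative, not from the source).
def _example_get_top_tokens():
    corpus = tf.data.Dataset.from_tensor_slices(
        ["the cat sat", "the dog sat down"])
    hash_map, max_sequence_length, data_set_size = get_top_tokens(corpus, n_top=3)
    # Expected: data_set_size == 2, max_sequence_length == 4, and hash_map maps
    # the 3 most frequent byte tokens to ids >= 2, with b"<s>" -> 0 and b"</s>" -> 1.
    return hash_map, max_sequence_length, data_set_size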
def prepare_dataset(
        self,  # type: ignore
        filenames: tf.data.Dataset,
        buckets: List[int],
        batch_sizes: List[int],
        shuffle: bool,
        is_holdout: bool,
        holdout_clans: set,
        holdout_families: set) -> tf.data.Dataset:

    def _check_membership(tensor, array):
        iscontained = tf.py_func(lambda t: t in array, [tensor], tf.bool)
        iscontained.set_shape(())
        return iscontained

    def _filter_fn(example):
        is_holdout_example = \
            _check_membership(example['clan'], holdout_clans) | \
            _check_membership(example['family'], holdout_families)
        return ~(is_holdout ^ is_holdout_example)

    def _load_records_and_preprocess(fname: tf.Tensor):
        dataset = tf.data.TFRecordDataset(fname)
        dataset = dataset.map(self._deserialization_func)
        # Hold out a prespecified set of families and clans
        dataset = dataset.filter(_filter_fn)
        return dataset

    dataset = filenames.apply(
        tf.data.experimental.parallel_interleave(
            _load_records_and_preprocess,
            sloppy=True,
            cycle_length=128,
            buffer_output_elements=32))

    dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)
    batch_fun = tf.data.experimental.bucket_by_sequence_length(
        lambda example: example['protein_length'], buckets, batch_sizes)
    dataset = dataset.apply(batch_fun)
    return dataset
def read_dataset(dataset: tf.data.Dataset) -> Tuple[float, int]:
    dataset = dataset.apply(
        tf.data.experimental.map_and_batch(dataset_parser,
                                           batch_size=1,
                                           num_parallel_batches=2,
                                           drop_remainder=True))
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    next_element_from_dataset = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        data_samples = 0
        dataset_read_start_time = time.time()
        while True:
            try:
                sess.run(next_element_from_dataset)
                data_samples += 1
            except tf.errors.OutOfRangeError:
                break
        dataset_read_time = time.time() - dataset_read_start_time

    return dataset_read_time, data_samples
def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
        dataset: A `tf.data.Dataset` that loads raw files.

    Returns:
        A TensorFlow dataset outputting batched images and labels.
    """
    if (self.config.builder != 'tfds' and self.input_context and
            self.input_context.num_input_pipelines > 1):
        dataset = dataset.shard(self.input_context.num_input_pipelines,
                                self.input_context.input_pipeline_id)
        logging.info(
            'Sharding the dataset: input_pipeline_id=%d '
            'num_input_pipelines=%d', self.input_context.num_input_pipelines,
            self.input_context.input_pipeline_id)

    if self.is_training and self.config.builder == 'records':
        # Shuffle the input files.
        dataset = dataset.shuffle(
            buffer_size=self.config.file_shuffle_buffer_size)

    if self.is_training and not self.config.cache:
        dataset = dataset.repeat()

    if self.config.builder == 'records':
        # Read the data from disk in parallel.
        dataset = dataset.interleave(
            tf.data.TFRecordDataset,
            cycle_length=10,
            block_length=1,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self.config.cache:
        dataset = dataset.cache()

    if self.is_training:
        dataset = dataset.shuffle(self.config.shuffle_buffer_size)
        dataset = dataset.repeat()

    # Parse, pre-process, and batch the data in parallel.
    if self.config.builder == 'records':
        preprocess = self.parse_record
    else:
        preprocess = self.preprocess
    dataset = dataset.map(preprocess,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self.input_context and self.config.num_devices > 1:
        if not self.config.use_per_replica_batch_size:
            raise ValueError(
                'The builder does not support a global batch size with more '
                'than one replica. Got {} replicas. Please set a '
                '`per_replica_batch_size` and enable '
                '`use_per_replica_batch_size=True`.'.format(
                    self.config.num_devices))

        # The batch size of the dataset will be multiplied by the number of
        # replicas automatically when strategy.distribute_datasets_from_function
        # is called, so we use local batch size here.
        dataset = dataset.batch(self.local_batch_size,
                                drop_remainder=self.is_training)
    else:
        dataset = dataset.batch(self.global_batch_size,
                                drop_remainder=self.is_training)

    # Prefetch overlaps in-feed with training.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    if self.config.tf_data_service:
        if not hasattr(tf.data.experimental, 'service'):
            raise ValueError(
                'The tf_data_service flag requires TensorFlow version '
                '>= 2.3.0, but the version is {}'.format(tf.__version__))
        dataset = dataset.apply(
            tf.data.experimental.service.distribute(
                processing_mode='parallel_epochs',
                service=self.config.tf_data_service,
                job_name='resnet_train'))
        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset
def __call__(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    return dataset.apply(
        tf.data.experimental.dense_to_ragged_batch(self._batch_size,
                                                   self._drop_remainder))
def get_bucketed_batches(
    dataset: tf.data.Dataset,
    batch_size: int,
    bucket_size: int,
    max_length: int,
    padded_shapes: Any,
    example_size_fn: Any,
    shuffle: bool = False,
    shuffle_seed: Optional[int] = None,
    drop_remainder: bool = False,
) -> tf.data.Dataset:
    """Returns padded batches of shuffled examples bucketed by length.

    This shuffles the examples randomly each epoch. The random order is
    deterministic and controlled by the seed.

    Batches are padded because sentences have different lengths. Sentences that
    are shorter in a batch will get 0s added at the end, until all sentences in
    the batch have the same length.

    For performance, examples of similar lengths are bucketed together. However,
    the contents of the buckets and their order is random each epoch, and
    controlled by the seed.

    Args:
        dataset: A TF Dataset with SST examples to be shuffled and batched.
        batch_size: The size of each batch. The remainder is dropped.
        bucket_size: How many different lengths go in each bucket.
        max_length: The maximum length to provide a bucket for.
        padded_shapes: A nested structure representing the shape to which the
            respective component of each input element should be padded prior
            to batching. See `tf.data.Dataset.padded_batch` for examples.
        example_size_fn: A TF function that returns the size of an example to
            determine in which bucket it goes. E.g., the sentence length.
        shuffle: Shuffle the dataset each epoch using seed.
        shuffle_seed: The seed that determines the shuffling order, with a
            different order each epoch.
        drop_remainder: Drop the last batch if it is not of size batch_size.

    Returns:
        A TF Dataset containing padded bucketed batches.
    """
    if shuffle:
        assert shuffle_seed is not None, 'When shuffling you must provide a seed.'

    # For bucket_size 8 and max length 24, we get bucket boundaries [9, 17, 25].
    bucket_boundaries = get_bucket_boundaries(bucket_size, max_length)
    logging.info('Batching bucket boundaries: %r', bucket_boundaries)

    # One batch size for each bucket plus one additional one (per requirement).
    bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
    bucket_fn = tf.data.experimental.bucket_by_sequence_length(
        example_size_fn,
        bucket_boundaries,
        bucket_batch_sizes,
        padded_shapes=padded_shapes,
        pad_to_bucket_boundary=True,
        drop_remainder=drop_remainder)

    if shuffle:
        # For shuffling we need to know how many training examples we have.
        num_examples = get_num_examples(dataset)
        num_batches = num_examples // batch_size
        return dataset.shuffle(
            num_examples, seed=shuffle_seed,
            reshuffle_each_iteration=True).apply(bucket_fn).shuffle(
                num_batches, seed=shuffle_seed,
                reshuffle_each_iteration=True).prefetch(
                    tf.data.experimental.AUTOTUNE)

    return dataset.apply(bucket_fn).prefetch(tf.data.experimental.AUTOTUNE)
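# Illustration (on toy data) of the `bucket_by_sequence_length` transform that
# `get_bucketed_batches` builds; helpers such as `get_bucket_boundaries` are
# assumed to be defined elsewhere, so they are not exercised here.
def _example_bucket_by_sequence_length() -> tf.data.Dataset:
    toy = tf.data.Dataset.from_tensor_slices(
        tf.ragged.constant([[1], [1, 2], [1, 2, 3, 4], [5, 6]]))
    return toy.apply(
        tf.data.experimental.bucket_by_sequence_length(
            element_length_func=lambda x: tf.shape(x)[0],
            bucket_boundaries=[3],        # two buckets: len < 3 and len >= 3
            bucket_batch_sizes=[2, 2]))   # one batch size per bucket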
def batch(dataset: tf.data.Dataset, batch_size: int):
    return dataset.apply(
        tf.data.experimental.dense_to_ragged_batch(batch_size))
def build_meta_model_trainable(
    meta_model_func: Callable,
    train_dataset: tf.data.Dataset,
    validation_dataset: tf.data.Dataset,
    batcher: Transform,
    shuffle_buffer: int,
    compiler: Optional[Callable[[tf.keras.Model], Any]] = None,
    cache_factory: Optional[Callable[[str], Transform]] = None,
    cache_dir: Optional[str] = None,
    cache_repeats: Optional[int] = None,
    train_augment_func: Optional[Callable] = None,
    validation_augment_func: Optional[Callable] = None,
    callbacks: Iterable[tf.keras.callbacks.Callback] = (),
    seed: Optional[int] = None,
):
    # this version DOES bleed examples from 1 train epoch to the next
    # via a final full batch and shuffle buffer
    # Hopefully it gets rid of the memory leak we see?
    if cache_factory is not None:
        assert cache_dir is not None
        assert cache_repeats is not None
        cache_dir = expand(cache_dir)
    else:
        assert cache_dir is None
        assert cache_repeats is None

    if validation_augment_func:
        validation_dataset = validation_dataset.map(
            validation_augment_func, num_parallel_calls=AUTOTUNE)

    pipeline, model = pl.build_pipelined_model(
        meta_model_func, validation_dataset.element_spec, batcher)
    if compiler is not None:
        compiler(model)

    # finalize validation_dataset
    pre_cache, pre_batch, post_batch = _get_map_funcs(pipeline, training=False)
    validation_dataset = (validation_dataset.map(
        chain_map_funcs(pre_cache, pre_batch),
        AUTOTUNE).apply(batcher).map(post_batch, AUTOTUNE))
    if cache_factory is not None:
        validation_dataset = validation_dataset.apply(
            cache_factory(os.path.join(cache_dir, "validation")))

    # train_data
    steps_per_epoch = tf.keras.backend.get_value(
        train_dataset.apply(batcher).cardinality())
    pre_cache, pre_batch, post_batch = _get_map_funcs(pipeline, training=True)

    if cache_factory is None:
        train_dataset = train_dataset.repeat().shuffle(shuffle_buffer, seed=seed)
        if train_augment_func is None:
            train_dataset = (train_dataset.map(
                chain_map_funcs(pre_cache, pre_batch),
                num_parallel_calls=AUTOTUNE,
            ).apply(batcher).map(post_batch, AUTOTUNE))
        else:
            # cache_factory is None, train_augment_func is not
            train_dataset = (train_dataset.repeat().apply(
                tfrng.data.stateless_map(
                    chain_map_funcs(train_augment_func, pre_cache, pre_batch),
                    seed=seed,
                    num_parallel_calls=AUTOTUNE,
                )).apply(batcher).map(post_batch, AUTOTUNE))
    else:
        # cache_factory is not None
        # We create separately cached datasets and flat_map over them.
        # This allows us to reuse the same caches if we want to change
        # cache_repeats.
        assert cache_repeats < 1e4  # for unique path naming
        paths = [
            os.path.join(cache_dir, "train", f"repeat-{i:04d}")
            for i in range(cache_repeats)
        ]
        if train_augment_func is None:
            # No augmentation
            assert cache_repeats == 1
            (path,) = paths
            train_dataset = (train_dataset.map(
                pre_cache, num_parallel_calls=AUTOTUNE).apply(
                    cache_factory(path)).repeat())
        else:

            def get_cached(epoch_seed, path):
                return train_dataset.apply(
                    tfrng.data.stateless_map(
                        chain_map_funcs(train_augment_func, pre_cache),
                        seed=epoch_seed,
                        num_parallel_calls=AUTOTUNE,
                    )).apply(cache_factory(path))

            # train_dataset = (
            #     paths.apply(tfrng.data.with_seed(seed))
            #     .repeat()
            #     .flat_map(get_cached)
            # )
            # create cached datasets in eager mode.
            # For some cache_factory implementations this will mean all
            # `cache_repeat` cache files are run ahead of time.
            # This allows us to save iterators for training.
            datasets = [
                get_cached(s, p)
                for s, p in zip(tf.data.experimental.RandomDataset(seed), paths)
            ]
            train_dataset = (tf.data.Dataset.from_tensor_slices(
                datasets).repeat().flat_map(lambda ds: ds))

        train_dataset = (
            train_dataset.shuffle(shuffle_buffer, seed=seed).map(
                pre_batch, num_parallel_calls=AUTOTUNE).apply(batcher).map(
                    post_batch, num_parallel_calls=AUTOTUNE).repeat()
            # HACK: because the assert_cardinality below raises on iteration
            # .apply(tf.data.experimental.assert_cardinality(
            #     tf.data.INFINITE_CARDINALITY))
        )
        # https://github.com/tensorflow/tensorflow/issues/45894

    # assert specs the same
    train_spec = train_dataset.element_spec
    val_spec = validation_dataset.element_spec
    tf.nest.assert_same_structure(train_spec, val_spec)
    flat_train = tf.nest.flatten(train_spec)
    flat_val = tf.nest.flatten(val_spec)
    for t, v in zip(flat_train, flat_val):
        assert t == v

    return Trainable(
        model=model,
        train_data=train_dataset,
        steps_per_epoch=steps_per_epoch,
        validation_data=validation_dataset,
        callbacks=tuple(callbacks),
    )
def transform(dataset: tf.data.Dataset) -> tf.data.Dataset:
    return dataset.apply(with_seed(seed, 2)).map(
        actual_map_func,
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic,
    )