def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
    if self._num_gpus > 1:
        # Each Horovod worker reads its own disjoint subset of the files.
        dataset = dataset.shard(self._num_gpus, hvd.rank())

    if self.is_training:
        # Shuffle the input files. BUG FIX: `Dataset.shuffle` returns a new
        # dataset rather than mutating in place; the original dropped the
        # return value, making the file-level shuffle a no-op.
        dataset = dataset.shuffle(buffer_size=self._file_shuffle_buffer_size)

    if self.is_training and not self._cache:
        dataset = dataset.repeat()

    # Read the data from disk in parallel.
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=10,
        block_length=1,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._cache:
        dataset = dataset.cache()

    if self.is_training:
        # Shuffle decoded records and repeat for multi-epoch training.
        dataset = dataset.shuffle(self._shuffle_buffer_size)
        dataset = dataset.repeat()

    # Parse, pre-process, and batch the data in parallel.
    preprocess = self.parse_record
    dataset = dataset.map(preprocess,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._num_gpus > 1:
        # The batch size of the dataset will be multiplied by the number of
        # replicas automatically when strategy.distribute_datasets_from_function
        # is called, so we use local batch size here.
        dataset = dataset.batch(self.local_batch_size,
                                drop_remainder=self.is_training)
    else:
        dataset = dataset.batch(self.global_batch_size,
                                drop_remainder=self.is_training)

    # Apply Mixup only while training; alpha == 0.0 disables mixing at eval.
    mixup_alpha = self.mixup_alpha if self.is_training else 0.0
    dataset = dataset.map(
        functools.partial(self.mixup, self.local_batch_size, mixup_alpha),
        num_parallel_calls=64)

    # Prefetch overlaps in-feed with training.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset
def shard(self, data: tf.data.Dataset) -> tf.data.Dataset:
    """Return this worker's shard of the dataset.

    Parameters
    ----------
    data
        data

    Returns
    -------
    data_sharded
        shard of the data
    """
    return data.shard(self.number_of_shards, self.shard_index)
def __init__(
    self,
    dataset: tf.data.Dataset,
    shard_num: int,
    shard_mod: int,
    name: str = "DATA",
    save_folder: Path = Path("./"),
):
    """Wrap one shard of a dataset and derive its save-file path.

    Args:
      dataset: Source dataset; shard `shard_mod` of `shard_num` is taken.
      shard_num: Total number of shards.
      shard_mod: Index of the shard this instance holds.
      name: Prefix used for the on-disk file name.
      save_folder: Directory for the shard file. BUG FIX: the default was the
        string "./", contradicting the `Path` annotation; `Path("./")` keeps
        the same directory while matching the declared type. String arguments
        remain accepted since the value is passed through `Path(...)` below.
    """
    super(Shard, self).__init__()
    self.data = dataset
    # Materialize only this worker's slice of the dataset.
    self.shard = dataset.shard(shard_num, shard_mod)
    # e.g. "DATA-8-3" for shard 3 of 8.
    filename = f"{name}-{shard_num}-{shard_mod}"
    self.save_file = Path(save_folder) / filename
    # Running element count, updated elsewhere (starts at zero).
    self.count = 0
def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
    if (self.config.builder != 'tfds' and self.input_context and
            self.input_context.num_input_pipelines > 1):
        dataset = dataset.shard(self.input_context.num_input_pipelines,
                                self.input_context.input_pipeline_id)
        # BUG FIX: the format string names input_pipeline_id first, but the
        # arguments were passed num_input_pipelines first, so the two values
        # were swapped in the log output.
        logging.info(
            'Sharding the dataset: input_pipeline_id=%d '
            'num_input_pipelines=%d', self.input_context.input_pipeline_id,
            self.input_context.num_input_pipelines)

    if self.is_training and self.config.builder == 'records':
        # Shuffle the input files. BUG FIX: `Dataset.shuffle` returns a new
        # dataset; the original discarded the result, making this a no-op.
        dataset = dataset.shuffle(
            buffer_size=self.config.file_shuffle_buffer_size)

    if self.is_training and not self.config.cache:
        dataset = dataset.repeat()

    if self.config.builder == 'records':
        # Read the data from disk in parallel.
        dataset = dataset.interleave(
            tf.data.TFRecordDataset,
            cycle_length=10,
            block_length=1,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self.config.cache:
        dataset = dataset.cache()

    if self.is_training:
        dataset = dataset.shuffle(self.config.shuffle_buffer_size)
        dataset = dataset.repeat()

    # Parse, pre-process, and batch the data in parallel.
    if self.config.builder == 'records':
        preprocess = self.parse_record
    else:
        preprocess = self.preprocess
    dataset = dataset.map(preprocess,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self.input_context and self.config.num_devices > 1:
        if not self.config.use_per_replica_batch_size:
            raise ValueError(
                'The builder does not support a global batch size with more than '
                'one replica. Got {} replicas. Please set a '
                '`per_replica_batch_size` and enable '
                '`use_per_replica_batch_size=True`.'.format(
                    self.config.num_devices))
        # The batch size of the dataset will be multiplied by the number of
        # replicas automatically when strategy.distribute_datasets_from_function
        # is called, so we use local batch size here.
        dataset = dataset.batch(self.local_batch_size,
                                drop_remainder=self.is_training)
    else:
        dataset = dataset.batch(self.global_batch_size,
                                drop_remainder=self.is_training)

    # Prefetch overlaps in-feed with training.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    if self.config.tf_data_service:
        if not hasattr(tf.data.experimental, 'service'):
            raise ValueError(
                'The tf_data_service flag requires Tensorflow version '
                '>= 2.3.0, but the version is {}'.format(tf.__version__))
        # Distribute the preprocessing work to a tf.data service cluster.
        dataset = dataset.apply(
            tf.data.experimental.service.distribute(
                processing_mode='parallel_epochs',
                service=self.config.tf_data_service,
                job_name='resnet_train'))
        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset
def pipeline(
    self,
    dataset: tf.data.Dataset,
    input_context: tf.distribute.InputContext = None
) -> tf.data.Dataset:
    """Assemble the input pipeline: shard, read, shuffle, preprocess, batch.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.
      input_context: An optional context provided by `tf.distribute` for
        cross-replica training. This isn't necessary if using Keras
        compile/fit.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
    # With several input pipelines, give each one a disjoint slice of files.
    if input_context and input_context.num_input_pipelines > 1:
        dataset = dataset.shard(input_context.num_input_pipelines,
                                input_context.input_pipeline_id)

    if self.is_training and not self.config.cache:
        dataset = dataset.repeat()

    if self.config.builder == 'records':
        # Pull records off disk in parallel, 8 MiB of read-ahead per file.
        def read_records(name):
            return tf.data.TFRecordDataset(name, buffer_size=8 * 1024 * 1024)

        dataset = dataset.interleave(
            read_records,
            cycle_length=16,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.prefetch(self.global_batch_size)

    if self.config.cache:
        dataset = dataset.cache()

    if self.is_training:
        dataset = dataset.shuffle(self.config.shuffle_buffer_size)
        dataset = dataset.repeat()

    # Decode and augment examples in parallel, then batch them.
    parse_fn = (self.parse_record if self.config.builder == 'records'
                else self.preprocess)
    dataset = dataset.map(parse_fn,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(self.batch_size, drop_remainder=self.is_training)

    # Note: we could do image normalization here, but we defer it to the model
    # which can perform it much faster on a GPU/TPU.
    # TODO(dankondratyuk): if we fix prefetching, we can do it here
    if self.is_training and self.config.deterministic_train is not None:
        options = tf.data.Options()
        options.experimental_deterministic = self.config.deterministic_train
        options.experimental_slack = self.config.use_slack
        options.experimental_optimization.parallel_batch = True
        options.experimental_optimization.map_fusion = True
        options.experimental_optimization.map_vectorization.enabled = True
        options.experimental_optimization.map_parallelization = True
        dataset = dataset.with_options(options)

    # Prefetch overlaps in-feed with training.
    # Note: autotune here is not recommended, as this can lead to memory
    # leaks. Instead, use a constant prefetch size like the number of devices.
    return dataset.prefetch(self.config.num_devices)
def pipeline(
    self,
    dataset: tf.data.Dataset,
    input_context: tf.distribute.InputContext = None
) -> tf.data.Dataset:
    """Build the fetch/shuffle/preprocess/batch input pipeline.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.
      input_context: An optional context provided by `tf.distribute` for
        cross-replica training. If set with more than one replica, this
        function assumes `use_per_replica_batch_size=True`.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
    # Split the input files across pipelines when more than one is running.
    if input_context and input_context.num_input_pipelines > 1:
        dataset = dataset.shard(input_context.num_input_pipelines,
                                input_context.input_pipeline_id)

    if self.is_training and not self.config.cache:
        dataset = dataset.repeat()

    if self.config.builder == 'records':
        # Read the data from disk in parallel with 8 MiB read buffers.
        read_buffer_bytes = 8 * 1024 * 1024
        dataset = dataset.interleave(
            lambda name: tf.data.TFRecordDataset(
                name, buffer_size=read_buffer_bytes),
            cycle_length=16,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self.config.cache:
        dataset = dataset.cache()

    if self.is_training:
        dataset = dataset.shuffle(self.config.shuffle_buffer_size)
        dataset = dataset.repeat()

    # Parse, pre-process, and batch the data in parallel.
    parse_fn = (self.parse_record if self.config.builder == 'records'
                else self.preprocess)
    dataset = dataset.map(parse_fn,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if input_context and self.config.num_devices > 1:
        if not self.config.use_per_replica_batch_size:
            raise ValueError(
                'The builder does not support a global batch size with more than '
                'one replica. Got {} replicas. Please set a '
                '`per_replica_batch_size` and enable '
                '`use_per_replica_batch_size=True`.'.format(
                    self.config.num_devices))
        # strategy.distribute_datasets_from_function multiplies the batch
        # size by the replica count, so batch with the per-replica size here.
        dataset = dataset.batch(self.local_batch_size,
                                drop_remainder=self.is_training)
    else:
        dataset = dataset.batch(self.global_batch_size,
                                drop_remainder=self.is_training)

    if self.is_training:
        options = tf.data.Options()
        options.experimental_deterministic = self.config.deterministic_train
        options.experimental_slack = self.config.use_slack
        options.experimental_optimization.parallel_batch = True
        options.experimental_optimization.map_fusion = True
        options.experimental_optimization.map_vectorization.enabled = True
        options.experimental_optimization.map_parallelization = True
        dataset = dataset.with_options(options)

    # Prefetch overlaps in-feed with training.
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
def hydra_predict_and_write_to_csv(all_eval_data: tf.data.Dataset, hydra_model: tf.keras.Model, missing_outcomes: bool, prediction_file: str, num_shards=100): """ Make predictions from eval_data using hydra_model, and write them to a tsv at prediction file :param all_eval_data: data to make predictions for :param hydra_model: model to use for predictions :param missing_outcomes: whether there are missing outcomes (changes signature of data labels and model output) :param prediction_file: file to write tsv :param num_shards: running predictions on large data can take a long time, so we shard the data into num_shards as a form of checkpointing. Each shard will have an intermediate TSV written out :return: """ # running evaluation on very large data can take an obscenely long time, so we shard the data as checkpointing for idx in range(num_shards): eval_data = all_eval_data.shard(num_shards=num_shards, index=idx) outputs = hydra_model.predict(x=eval_data) out_dict = {} if missing_outcomes: for t, g0 in enumerate(tf.unstack(outputs[0], axis=-1)): out_dict['g0_' + str(t)] = g0.numpy() for t, g1 in enumerate(tf.unstack(outputs[1], axis=-1)): out_dict['g1_' + str(t)] = g1.numpy() out_dict['prob_y_obs'] = np.squeeze(outputs[2]) for out, q in enumerate(outputs[3:]): out_dict['q' + str(out)] = np.squeeze(q) else: for t, g in enumerate(tf.unstack(outputs[0], axis=-1)): out_dict['g_' + str(t)] = g.numpy() for out, q in enumerate(outputs[1:]): out_dict['q' + str(out)] = np.squeeze(q) predictions = pd.DataFrame(out_dict) label_dataset = eval_data.map( lambda f, l: l, num_parallel_calls=tf.data.experimental.AUTOTUNE) data_df = dataset_to_pandas_df(label_dataset) outs = data_df.join(predictions) with tf.io.gfile.GFile(prediction_file + f'_index{idx}', "w") as writer: writer.write(outs.to_csv(sep="\t")) # merge shards and write final predictions csv dfs = [] for idx in range(num_shards): dfs.append(pd.read_csv(prediction_file + f'_index{idx}', sep='\t')) full_df = pd.concat(dfs) 
with tf.io.gfile.GFile(prediction_file, "w") as writer: writer.write(full_df.to_csv(sep="\t"))
def compute_predictions_jax(
    model: PredictionModel, dataset: tf.data.Dataset, batch_size: int
) -> Iterator[Tuple[types.ModelPredictions, types.Features]]:
  """Yield the predictions of the given JAX model on the given dataset.

  Note that this also works in multi-host configurations. You have to make
  sure that this function gets called on all hosts. The results will be yielded
  only to the host with a jax.host_id() equal to 0.

  Args:
    model: A function that takes tensor-valued features and returns a vector of
      predictions.
    dataset: The dataset that the function consumes to produce the predictions.
    batch_size: The batch size that should be used.

  Yields:
    The predictions of the model on the dataset.
  """

  def _gather(inputs):
    # Collect the per-device values from every device participating in the
    # pmap along axis "i".
    return jax.lax.all_gather(inputs, "i")
  gather = jax.pmap(_gather, axis_name="i")

  def infer(features):
    # Run the model, then reshape each output to a leading
    # (local_device_count, per-device-batch) layout so it can be gathered.
    probabilities = model(features)
    return_vals = (probabilities, features["metadata"], features["mask"])
    return_vals_reshaped = jax.tree_map(
        lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]),
        return_vals
    )
    # After all_gather every device holds identical data; index [0] keeps a
    # single copy.
    return jax.tree_map(lambda x: x[0], gather(return_vals_reshaped))

  # A negative cardinality means UNKNOWN/INFINITE, which breaks the padding
  # arithmetic below.
  if dataset.cardinality() < 0:
    raise ValueError(
        "The cardinality must be known when running JAX multi-host models.")
  total_batches = math.ceil(dataset.cardinality() / batch_size)
  lcm = lambda x, y: (x * y) // math.gcd(x, y)
  # We want each shard (host) to get an equal number of batches.
  total_batches_padded = lcm(jax.host_count(), total_batches)
  logging.info("Total batches %d, rounded up to %d",
               total_batches, total_batches_padded)

  def pad_strings(array):
    # Pack a string tensor into a fixed-size [1024] codepoint tensor so it
    # survives the device gather (which needs static shapes).
    if array.dtype != tf.string:
      return array
    array_bytes = tf.strings.unicode_decode(array, "UTF-8")
    # The return type is either Tensor or RaggedTensor.
    try:
      # When it is a RaggedTensor we need to convert it to a dense Tensor;
      # to_tensor() adds a leading dimension of size 1, which we drop.
      array_bytes = array_bytes.to_tensor()[0]
    except AttributeError:
      pass
    array_size = tf.size(array_bytes)
    # Guard against strings longer than the fixed 1024-codepoint budget.
    with tf.control_dependencies([
        tf.compat.v1.assert_less_equal(array_size, 1024)]):
      packed = tf.pad(array_bytes, [[0, 1024 - array_size]])
    return {"__packed": tf.ensure_shape(packed, [1024])}

  def unpad_strings(array):
    # Inverse of pad_strings: re-encode the codepoints and strip the padding.
    if isinstance(array, dict):
      with_trailing_zeros = bytes(tf.strings.unicode_encode(
          np.asarray(array["__packed"]).reshape(-1), "UTF-8").numpy())
      return with_trailing_zeros.rstrip(b"\x00")
    else:
      return np.asarray(array)

  def pad_strings_in_metadata(features):
    """Only padding of the strings subject to a gather operation."""
    features["metadata"] = tf.nest.map_structure(pad_strings,
                                                 features["metadata"])
    return features

  # Pad the dataset so every host receives the same number of batches.
  dataset = clu_dd.pad_dataset(
      dataset.map(pad_strings_in_metadata),
      batch_dims=[batch_size],
      pad_up_to_batches=total_batches_padded,
      cardinality=None,  # It will be inferred from the dataset.
  ).batch(batch_size)

  # The shard for the current host.
  dataset_shard = dataset.shard(jax.host_count(), jax.host_id())
  logging.info("Batches per host: %d", dataset_shard.cardinality())
  for features in dataset_shard.as_numpy_iterator():
    time_start = time.time()
    # There is a bug in XLA, the following fails for int8s.
    features["mask"] = features["mask"].astype(np.int32)

    # Collapse the (device, per-device-batch) leading dims back into one.
    flatten = lambda array: array.reshape((-1,) + array.shape[2:])
    predictions, metadatas, masks = jax.tree_map(flatten, infer(features))
    time_end = time.time()
    # NOTE(review): this averages batch wall time over examples, so it
    # includes gather/host-transfer overhead, not pure model time.
    time_delta_per_example = (time_end - time_start) / predictions.shape[0]
    predictions = np.asarray(predictions)  # Materialize.
    if jax.host_id() == 0:
      for i in range(predictions.shape[0]):
        # Skip padded examples (mask == 0).
        if masks[i]:
          predictions_i = types.ModelPredictions(
              predictions=[predictions[i]], time_in_s=time_delta_per_example)
          metadata_i = _slice_dictionary(metadatas, i)
          is_leaf_fn = lambda x: isinstance(x, dict) and "__packed" in x
          metadata_i_unpadded = jax.tree_map(
              unpad_strings, metadata_i, is_leaf=is_leaf_fn)
          yield predictions_i, metadata_i_unpadded
def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
    # This can help resolve OOM issues when using only 1 GPU for training.
    options = tf.data.Options()
    options.experimental_optimization.map_parallelization = (
        not self.disable_map_parallelization)
    dataset = dataset.with_options(options)

    if self._num_gpus > 1:
        # For multi-host training, we want each host to always process the
        # same subset of files. Each host only sees a subset of the entire
        # dataset, allowing us to cache larger datasets in memory.
        dataset = dataset.shard(self._num_gpus, hvd.rank())

    if self.is_training:
        # Shuffle the input files. BUG FIX: `Dataset.shuffle` returns a new
        # dataset rather than mutating in place; the original dropped the
        # return value, making the file-level shuffle a no-op.
        dataset = dataset.shuffle(buffer_size=self._file_shuffle_buffer_size)

    if self.is_training and not self._cache:
        dataset = dataset.repeat()

    # Read the data from disk in parallel.
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=10,
        block_length=1,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._cache:
        dataset = dataset.cache()

    if self.is_training:
        dataset = dataset.shuffle(self._shuffle_buffer_size)
        dataset = dataset.repeat()

    # Parse, pre-process, and batch the data in parallel.
    preprocess = self.parse_record
    dataset = dataset.map(preprocess,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._num_gpus > 1:
        # The batch size of the dataset will be multiplied by the number of
        # replicas automatically when strategy.distribute_datasets_from_function
        # is called, so we use local batch size here.
        dataset = dataset.batch(self.local_batch_size,
                                drop_remainder=self.is_training)
    else:
        dataset = dataset.batch(self.global_batch_size,
                                drop_remainder=self.is_training)

    # Apply Mixup/CutMix only during training, if requested in the data
    # pipeline, otherwise they will be applied in the model module on device.
    mixup_alpha = self.mixup_alpha if self.is_training else 0.0
    cutmix_alpha = self.cutmix_alpha if self.is_training else 0.0
    dataset = dataset.map(
        functools.partial(mixing, self.local_batch_size, mixup_alpha,
                          cutmix_alpha, self.defer_img_mixing),
        num_parallel_calls=64)

    # Prefetch overlaps in-feed with training.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset