def get_metrics(config, metrics):
  """Collects per-step metrics onto the host.

  Regular runs delegate to `common_utils.get_metrics`. Debug runs stack the
  metrics with `common_utils.stack_forest` and then fetch them with
  `jax.device_get`.

  Args:
    config: Configuration object; only `config.debug_run` is read.
    metrics: Per-step metric pytrees to collect.

  Returns:
    The collected metrics.
  """
  if config.debug_run:
    stacked = common_utils.stack_forest(metrics)
    return jax.device_get(stacked)
  return common_utils.get_metrics(metrics)
def maybe_eval_and_log(self, eval_summary, master, step, tick, train_metrics,
                       train_summary):
  """Maybe evaluate and log based on the current step value.

  Evaluation runs every `self.eval_frequency` steps and on the final step
  (`step == self.total_steps`). On any other step all inputs are returned
  unchanged.

  Args:
    eval_summary: Summary of the previous evaluation; returned unchanged when
      no evaluation happens at this step.
    master: bool; Whether this process writes the train summary.
    step: int; Current training step.
    tick: float; `time.time()` at the start of the current logging window.
    train_metrics: list; Training metrics accumulated since the last
      evaluation (assumed to hold one entry per training step — confirm with
      the caller).
    train_summary: Summary of the previous cycle; returned unchanged when no
      evaluation happens at this step.

  Returns:
    Tuple `(eval_summary, train_metrics, train_summary, tick)`. After an
    evaluation, `train_metrics` is a fresh empty list and `tick` is the time
    at which the train summary was computed.
  """
  if step % self.eval_frequency != 0 and step != self.total_steps:
    return eval_summary, train_metrics, train_summary, tick

  # Count the entries actually accumulated in this window before stacking.
  # The final window (step == total_steps) may be shorter than
  # eval_frequency, so dividing by eval_frequency would overstate throughput.
  num_steps = len(train_metrics)
  train_metrics = common_utils.get_metrics(train_metrics)
  train_summary = pipeline_utils.compute_global_mean_metrics(train_metrics)
  tock = time.time()
  steps_per_sec = num_steps / (tock - tick)
  tick = tock

  # log train summary
  if master:
    self.write_train_summary(step=step, metric_dict=train_metrics,
                             summary=train_summary,
                             steps_per_sec=steps_per_sec)

  # reset metric accumulation for next evaluation cycle
  train_metrics = []

  # sync model state across replicas
  self.train_state = pipeline_utils.sync_model_state_across_replicas(
      self.train_state)

  # evaluate and log the results
  eval_summary, _ = self.eval(step, self.train_state)
  return eval_summary, train_metrics, train_summary, tick
def test(optimizer, state, p_eval_step, step, test_ds, summary_writer):
  """Evaluates the flax module held by `optimizer` on the whole of `test_ds`.

  Per-batch metrics are summed over the dataset and divided by the summed
  'denominator' entry. The loss and accuracy are logged. On host 0 every
  summary value is also written as 'test_<name>'.

  Args:
    optimizer: flax optimizer (contains flax module).
    state: model state, e.g. batch statistics.
    p_eval_step: fn; Pmapped evaluation step function.
    step: int; Number of training steps passed so far.
    test_ds: tf.dataset; Test dataset.
    summary_writer: tensorflow summary writer.
  """
  collected = []
  for raw_batch in iter(test_ds):
    # pylint: disable=protected-access
    numpy_batch = jax.tree_map(lambda t: t._numpy(), raw_batch)
    # pylint: enable=protected-access
    collected.append(
        p_eval_step(optimizer.target, state, common_utils.shard(numpy_batch)))

  totals = jax.tree_map(jnp.sum, common_utils.get_metrics(collected))
  denominator = totals.pop('denominator')
  test_summary = jax.tree_map(lambda t: t / denominator, totals)

  logging.info('test in step: %d, loss: %.4f, acc: %.4f', step,
               test_summary['loss'], test_summary['accuracy'])

  if jax.host_id() != 0:
    return
  for name, value in test_summary.items():
    summary_writer.scalar(f'test_{name}', value, step)
  summary_writer.flush()
def combine_metrics(step_metrics):
  """Combines a list of per-step metric dicts into a single summary dict.

  Args:
    step_metrics: A list of per-step metric dicts with (metric name, metric
      value) items, as accepted by `common_utils.get_metrics`. Each dict holds
      summed metrics and the corresponding 'denominator' (the number of
      next-token prediction instances). It may also hold a 'learning_rate'
      entry, which is averaged instead of normalized. Each metric value is
      expected to have at least one dimension.

  Returns:
    A dict with (metric name, metric value) items containing combined
    metrics. Every summed metric is divided by the total denominator. The
    dict also holds the mean 'learning_rate' when present, and a
    'perplexity' clipped at 1e4 when 'loss' is present.
  """
  metrics_all = common_utils.get_metrics(step_metrics)
  lr = None
  if 'learning_rate' in metrics_all:
    # The learning rate is averaged, not normalized by the denominator.
    lr = metrics_all.pop('learning_rate').mean()
  metrics_sums = jax.tree_map(jnp.sum, metrics_all)
  denominator = metrics_sums.pop('denominator')
  summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
  if lr is not None:
    summary['learning_rate'] = lr
  # Calculate (clipped) perplexity after averaging log-perplexities:
  if 'loss' in summary:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
  return summary
def predict_once(run_configuration, optimizer=None, max_examples=200):
  """Runs prediction on the first `max_examples` elements of the dataset.

  Restores the checkpoint at `run_configuration.original_checkpoint_path`,
  then runs the adapter's predict step (pmapped over the 'batch' axis) on
  each preprocessed element. Every result is passed to
  `adapter.handle_predict`.

  Args:
    run_configuration: Provides `adapter`, `original_checkpoint_path` and
      `dataset_info.dataset`.
    optimizer: Optional optimizer to restore the checkpoint into. When None,
      one is created with `adapter.create_optimizer(run_configuration)`.
    max_examples: Maximum number of preprocessed elements to predict on. The
      default (200) preserves the previously hard-coded limit. Pass None to
      consume the whole preprocessed iterator.

  Returns:
    The per-step metrics collected with `common_utils.get_metrics` and summed
    with `jnp.sum` over all processed elements.
  """
  adapter = run_configuration.adapter
  checkpoint_path = run_configuration.original_checkpoint_path
  optimizer = optimizer or adapter.create_optimizer(run_configuration)
  dataset = run_configuration.dataset_info.dataset

  # Restore checkpoint
  optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer)

  # Replicate optimizer.
  optimizer = flax.jax_utils.replicate(optimizer)
  predict_step = adapter.make_predict_step()
  predict_step_parallel = jax.pmap(predict_step, axis_name='batch')

  # Perform inference
  dataset_iter_raw = iter(dataset)
  dataset_iter = adapter.preprocess(dataset_iter_raw)
  metrics_all = []
  for example in itertools.islice(dataset_iter, max_examples):
    train_inputs = adapter.get_train_inputs(example)
    metrics, logits, state = predict_step_parallel(optimizer.target,
                                                   train_inputs)
    adapter.handle_predict(metrics, logits, state)
    metrics_all.append(metrics)
  metrics_all = common_utils.get_metrics(metrics_all)
  metrics = jax.tree_map(jnp.sum, metrics_all)
  return metrics
def train_for_one_epoch(
    dataset_source: dataset_source_lib.DatasetSource,
    optimizer: flax.optim.Optimizer, state: flax.nn.Collection,
    prng_key: jnp.ndarray, pmapped_train_step: _TrainStep,
    pmapped_update_ema: Optional[_EMAUpdateStep],
    moving_averages: Optional[efficientnet_optim.ExponentialMovingAverage],
    summary_writer: tensorboard.SummaryWriter
) -> Tuple[flax.optim.Optimizer, flax.nn.Collection,
           Optional[efficientnet_optim.ExponentialMovingAverage]]:
  """Trains the model for one epoch.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm
      MA).
    prng_key: A PRNG key to use for stochasticity (e.g. for sampling an
      eventual dropout mask). Is not used for shuffling the dataset. It is not
      split or returned; each step folds the current step counter into it.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    pmapped_update_ema: Function to update the parameter moving average. Can be
      None if we don't use EMA.
    moving_averages: Parameters moving average if used.
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), the updated
    state and the updated parameter moving averages. The moving averages are
    returned as given (e.g. None) when EMA is not used.
  """
  start_time = time.time()
  num_steps = 0
  train_metrics = []
  for batch in dataset_source.get_train(use_augmentations=True):
    # Generate a PRNG key that will be rolled into the batch.
    step_key = jax.random.fold_in(prng_key, optimizer.state.step[0])
    # Load and shard the TF batch.
    batch = tensorflow_to_numpy(batch)
    batch = shard_batch(batch)
    # Shard the step PRNG key.
    sharded_keys = common_utils.shard_prng_key(step_key)

    optimizer, state, metrics, lr = pmapped_train_step(
        optimizer, state, batch, sharded_keys)
    num_steps += 1
    if moving_averages is not None:
      moving_averages = pmapped_update_ema(optimizer, state, moving_averages)

    train_metrics.append(metrics)
  train_metrics = common_utils.get_metrics(train_metrics)
  # Get training epoch summary for logging.
  train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
  # Log the learning rate returned by the last step (first replica).
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  logging.info('Whole training step done in %s (%s steps)',
               time.time() - start_time, num_steps)
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, moving_averages
def eval_split(self, train_state, split_name, eval_env_ids=None):
  """Evaluation loop on the specified split.

  When `self.steps_per_eval` is a dict, each environment in `eval_env_ids` is
  evaluated on its own iterator for
  `self.steps_per_eval[split_name][str(env_id)]` steps. Otherwise
  `self.steps_per_eval` steps are run over the iterators of all environments
  of the split together, with env id -1.

  Args:
    train_state: TrainState; Object containing training state.
    split_name: str; Name of the data split we want to evaluate the model on.
    eval_env_ids: list(int); Eval environments ids. Defaults to every
      environment id of the split. Only used when `self.steps_per_eval` is a
      dict.

  Returns:
    eval_summary, eval_metrics: the global mean metrics and the collected
    metrics they were computed from.
  """
  data_iters = self.task.dataset.data_iters[split_name]
  if eval_env_ids is None:
    eval_env_ids = list(map(int, data_iters.keys()))

  if isinstance(self.steps_per_eval, dict):
    eval_metrics = {}
    for env_id in eval_env_ids:
      env_id_str = str(env_id)
      env_eval_metrics = []
      for _ in range(self.steps_per_eval[split_name][env_id_str]):
        env_eval_batches = self.get_next_batch([data_iters[env_id_str]])
        e_metrics = self.pmapped_eval_step(train_state, env_eval_batches,
                                           env_id)
        env_eval_metrics.append(e_metrics)

      env_eval_metrics = common_utils.get_metrics(env_eval_metrics)
      # NOTE(review): `update` overwrites entries of previously evaluated
      # environments that use the same metric names. Unless the eval step
      # emits env-specific names, the summary reflects only the last such
      # environment — confirm this is intended.
      eval_metrics.update(env_eval_metrics)
  else:
    _, data_iters = list(zip(*dict(data_iters).items()))
    eval_metrics = []
    for _ in range(self.steps_per_eval):
      env_eval_batches = self.get_next_batch(data_iters)
      e_metrics = self.pmapped_eval_step(train_state, env_eval_batches, -1)
      eval_metrics.append(e_metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)

  eval_summary = pipeline_utils.compute_global_mean_metrics(eval_metrics)
  return eval_summary, eval_metrics
def write_train_metric(summary_writer, train_metrics, train_time, step):
  """Writes the training time and every per-step training metric.

  Each metric history is written value by value. Its last value lands on
  `step` and earlier values on the immediately preceding steps.
  """
  summary_writer.scalar("train_time", train_time, step)
  stacked = get_metrics(train_metrics)
  for name, history in stacked.items():
    first_step = step - len(history) + 1
    for offset, value in enumerate(history):
      summary_writer.scalar(f"train_{name}", value, first_step + offset)
def eval_once(run_configuration, checkpoint_path, optimizer=None):
  """Evaluates a single checkpoint on a single epoch of data.

  Runs at most `config.eval_steps` evaluation steps and logs the loss. On
  host 0 it also writes throughput, the summary metrics and the adapter's
  own summaries to a TensorBoard writer under `<run_dir>/<eval_name>`.

  Args:
    run_configuration: Provides `config`, `run_dir`, `adapter` and
      `dataset_info` (with `dataset` and `info`).
    checkpoint_path: Checkpoint restored into `optimizer` before evaluation.
    optimizer: Optional optimizer. When None, one is created with
      `adapter.create_optimizer(run_configuration)`.
  """
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  adapter = run_configuration.adapter
  optimizer = optimizer or adapter.create_optimizer(run_configuration)
  dataset = run_configuration.dataset_info.dataset
  info = run_configuration.dataset_info.info

  eval_name = config.eval_name or 'eval'
  log_dir = os.path.join(run_dir, eval_name)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(log_dir)

  # Restore checkpoint
  optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer)
  step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = flax.jax_utils.replicate(optimizer)
  eval_step = adapter.make_eval_step()
  eval_step_parallel = jax.pmap(eval_step, axis_name='batch')

  # Perform evaluation
  tick = time.time()
  metrics_all = []

  example = None
  dataset_iter_raw = iter(dataset)
  dataset_iter = adapter.preprocess(dataset_iter_raw)
  for unused_eval_step, example in zip(range(config.eval_steps), dataset_iter):
    train_inputs = adapter.get_train_inputs(example)
    metrics, logits, state = eval_step_parallel(optimizer.target, train_inputs)
    metrics_all.append(metrics)

  # Count the steps BEFORE stacking. `get_metrics` turns the list into a dict
  # keyed by metric name, whose length is the number of metrics, not the
  # number of evaluation steps.
  num_steps = len(metrics_all)

  # Write results.
  metrics_all = common_utils.get_metrics(metrics_all)
  metrics_sums = jax.tree_map(jnp.sum, metrics_all)
  denominator = metrics_sums.pop('denominator')
  summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
  summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
  logging.info('eval @ train step: %d, loss: %.4f', step, summary['loss'])
  if jax.host_id() == 0:
    tock = time.time()
    steps_per_sec = num_steps / (tock - tick)
    examples_per_sec = denominator / (tock - tick)
    summary_writer.scalar('per-second/steps', steps_per_sec, step)
    summary_writer.scalar('per-second/examples', examples_per_sec, step)
    for key, val in summary.items():
      summary_writer.scalar(key, val, step)
    # Summaries are written for the last evaluated example only.
    adapter.write_summaries(example, logits, summary_writer, info, step, state)
    summary_writer.flush()
def write_metric(train_metrics, eval_metrics, train_time, step):
  """Writes the train time, per-step train metrics and eval metrics.

  `summary_writer` is not a parameter. It is resolved from an enclosing
  (presumably module-level) scope.

  Per-step train metric histories end at `step`. Eval metrics are single
  values written at `step` with an 'eval_' prefix.
  """
  summary_writer.scalar("train_time", train_time, step)

  stacked_train = get_metrics(train_metrics)
  for name, history in stacked_train.items():
    tag = f"train_{name}"
    base = step - len(history) + 1
    for offset, value in enumerate(history):
      summary_writer.scalar(tag, value, base + offset)

  for name, value in eval_metrics.items():
    summary_writer.scalar(f"eval_{name}", value, step)
def evaluate(p_eval_step, state, eval_ds, num_eval_steps=-1):
  """Evaluate on the given dataset.

  Args:
    p_eval_step: Evaluation step, called as `p_eval_step(batch=..., state=...)`.
    state: State passed to every `p_eval_step` call.
    eval_ds: Iterable of batches; each leaf is converted with `np.asarray`.
    num_eval_steps: Number of batches to evaluate. Any non-positive value
      evaluates the whole of `eval_ds`.

  Returns:
    `train_utils.metrics_summary` of the collected metrics with the 'eval'
    prefix.
  """
  logging.info('Starting evaluating.')
  batches = iter(eval_ds)
  if num_eval_steps > 0:
    # islice stops after exactly num_eval_steps batches without pulling more.
    batches = itertools.islice(batches, num_eval_steps)
  per_step = [
      p_eval_step(batch=jax.tree_map(np.asarray, raw), state=state)
      for raw in batches
  ]
  stacked = common_utils.get_metrics(per_step)
  return train_utils.metrics_summary(stacked, 'eval')
def evaluate(p_eval_step, params, eval_ds, rng):
  """Evaluate the target and return a dictionary with the metrics.

  The PRNG key is threaded through the evaluation steps: each call to
  `p_eval_step(rng, params, batch)` returns the key used for the next one.

  Returns:
    Tuple `(eval_summary, rng)`. `eval_summary` is the `np.mean` of each
    collected metric. `rng` is the key returned by the last step.
  """
  logging.info('Gathering evaluation metrics.')
  collected = []
  for raw in eval_ds:
    as_numpy = jax.tree_map(lambda t: t._numpy(), raw)  # pylint: disable=protected-access
    step_metrics, rng = p_eval_step(rng, params, common_utils.shard(as_numpy))
    collected.append(step_metrics)
  summary = jax.tree_map(np.mean, common_utils.get_metrics(collected))
  return summary, rng
def eval_policy(policy, rng, state, model, test_ds):
  """Evaluate the target with policy and return a dictionary with the metrics.

  `eval_step_policy` is resolved from an enclosing scope. It is called once
  per sharded batch with `rng` threaded through the calls. The result is the
  `np.mean` of every collected metric.
  """
  per_batch = []
  for raw in test_ds:
    sharded = common_utils.shard(
        jax.tree_map(lambda t: t._numpy(), raw))  # pylint: disable=protected-access
    batch_metrics, rng = eval_step_policy(rng, sharded, state, model, policy)
    # Leave metrics on device; they are off-loaded together after the loop.
    per_batch.append(batch_metrics)
  return jax.tree_map(np.mean, common_utils.get_metrics(per_batch))
def write_train_metric(train_metrics, train_time, step):
  """Logs per-step training metrics and the training time to wandb.

  Args:
    train_metrics: Per-step training metrics. They are converted with
      `get_metrics` into a dict mapping each metric name to its history.
    train_time: Training time, logged at `step` under "train_time".
    step: Step of the last entry of every metric history. Earlier entries are
      logged on the immediately preceding steps.

  Raises:
    ValueError: If the metric histories do not all have the same length.
  """
  train_metrics = get_metrics(train_metrics)

  # Validate explicitly: an `assert` would be stripped under `python -O`.
  lengths = {len(vals) for vals in train_metrics.values()}
  if len(lengths) > 1:
    raise ValueError(
        "All train metrics must have the same number of steps, got lengths "
        f"{sorted(lengths)}.")
  # An empty metrics dict has no per-step rows; only train_time is logged.
  num_steps = lengths.pop() if lengths else 0

  for i in range(num_steps):
    row = {f"train_{key}": vals[i] for key, vals in train_metrics.items()}
    wandb.log(row, step=step - num_steps + i + 1)

  wandb.log({"train_time": train_time}, step=step)
def evaluate(*, p_eval_step, target, eval_ds):
  """Evaluate the target and return a dictionary with the metrics.

  Each numpy batch of `eval_ds` is sharded and unpacked into
  `(inputs, outputs, programs, _)`. Metrics are summed over all batches and
  divided by the summed 'denominator'.
  """
  per_step = []
  for numpy_batch in eval_ds.as_numpy_iterator():
    inputs, outputs, programs, _ = common_utils.shard(numpy_batch)
    per_step.append(p_eval_step(target, inputs, outputs, programs))

  totals = jax.tree_map(jnp.sum, common_utils.get_metrics(per_step))
  denom = totals.pop('denominator')
  return jax.tree_map(lambda total: total / denom, totals)
def evaluate(*, p_eval_step, target, eval_ds, num_eval_steps):
  """Evaluate the target and return a dictionary with the metrics.

  At most `num_eval_steps` batches are evaluated. Metrics are summed over
  those batches and divided by the summed 'denominator'.
  """
  logging.info('Gathering evaluation metrics.')
  per_step = []
  batch_source = iter(eval_ds)  # pytype: disable=wrong-arg-types
  for unused_step, raw in zip(range(num_eval_steps), batch_source):
    as_numpy = jax.tree_map(lambda t: t._numpy(), raw)  # pylint: disable=protected-access
    per_step.append(p_eval_step(target, common_utils.shard(as_numpy)))

  totals = jax.tree_map(jnp.sum, common_utils.get_metrics(per_step))
  denom = totals.pop('denominator')
  return jax.tree_map(lambda total: total / denom, totals)
def train_for_one_epoch(
    dataset_source, optimizer, state, prng_key, pmapped_train_step,
    summary_writer
):
  """Trains the model for one pass over the augmented training data.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm
      MA).
    prng_key: A PRNG key to use for stochasticity (e.g. for sampling an
      eventual dropout mask). Is not used for shuffling the dataset. A fresh
      subkey is split off for every step.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), state and PRNG
    key.
  """
  collected = []
  for raw_batch in dataset_source.get_train(use_augmentations=True):
    # Split off a per-step key and keep the remainder for the next steps.
    step_key, prng_key = jax.random.split(prng_key)
    sharded_batch = shard_batch(tensorflow_to_numpy(raw_batch))
    optimizer, state, step_metrics, lr = pmapped_train_step(
        optimizer, state, sharded_batch, common_utils.shard_prng_key(step_key))
    collected.append(step_metrics)

  stacked = common_utils.get_metrics(collected)
  epoch_summary = jax.tree_map(lambda x: x.mean(), stacked)
  # Learning rate from the last step, first replica.
  epoch_summary['learning_rate'] = lr[0]
  global_step = int(optimizer.state.step[0])
  for name, value in epoch_summary.items():
    summary_writer.scalar(name, value, global_step)
  summary_writer.flush()
  return optimizer, state, prng_key
def eval_on_dataset(
    model: flax.nn.Model, state: flax.nn.Collection, dataset: tf.data.Dataset,
    pmapped_eval_step: _EvalStep):
  """Evaluates the model on every batch of `dataset`.

  Args:
    model: The model to evaluate.
    state: Current state associated with the model (contains the batch norm
      MA).
    dataset: Dataset on which the model should be evaluated. Should already
      be batched.
    pmapped_eval_step: A pmapped version of the `eval_step` function (see its
      documentation for more details).

  Returns:
    A dictionary with each metric summed over the whole dataset and divided
    by the number of samples. When a batch has a 'mask', its samples are
    counted by summing the mask across devices.
  """
  per_batch = []
  num_samples = 0
  psum_across_devices = jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')
  for raw_batch in dataset:
    batch = load_and_shard_tf_batch(raw_batch)
    # Metrics are summed over all observations in the batch.
    per_batch.append(pmapped_eval_step(model, state, batch))
    if 'mask' in batch:
      num_samples += psum_across_devices(batch['mask'])[0].sum()
    else:
      # The first two label dims are multiplied together, then by the host
      # count (assumed to be num_replicas * per_replica_batch_size per host).
      label_shape = batch['label'].shape
      num_samples += label_shape[0] * label_shape[1] * jax.host_count()

  # Metrics were psum-ed in the eval step, so every replica holds the same
  # values and get_metrics keeps one of them.
  stacked = common_utils.get_metrics(per_batch)
  return jax.tree_map(lambda x: x.sum() / num_samples, stacked)
def test_reproduce_paper_evals(self, num_chunks):
  """Reproduce results from https://www.aclweb.org/anthology/P18-1009.pdf."""
  total = self.labels.shape[0]
  stride = total // num_chunks
  chunk_metrics = []
  for start in range(0, num_samples_bound(total) if False else total, stride):
    end = min(start + stride, total)
    current = ultra_fine_entity_typing_task.get_prediction_recall_metrics(
        self.labels[start:end], self.predictions[start:end])
    # Add a leading axis of size 1 to every metric before collecting them
    # (presumably the per-device axis common_utils.get_metrics indexes into).
    chunk_metrics.append(
        jax.tree_map(lambda x: jnp.expand_dims(x, 0), current))

  summed = jax.tree_map(jnp.sum, common_utils.get_metrics(chunk_metrics))
  processed = metric_utils.process_metrics(summed)

  expected = (
      ('total_precision_value', 0.481),
      ('total_recall_value', 0.232),
      ('coarse_grained_precision_value', 0.603),
      ('coarse_grained_recall_value', 0.616),
      ('fine_grained_precision_value', 0.404),
      ('fine_grained_recall_value', 0.384),
      ('ultra_fine_grained_precision_value', 0.428),
      ('ultra_fine_grained_recall_value', 0.088),
  )
  for key, value in expected:
    self.assertAlmostEqual(processed[key], value, places=3)
def eval_split(self, train_state, split_name):
  """Evaluation loop on the specified split.

  Runs `self.steps_per_eval` evaluation steps on batches drawn from the
  split's data iterators.

  Args:
    train_state: TrainState; Object containing training state.
    split_name: str; Name of the data split we want to evaluate the model on.

  Returns:
    eval_summary, eval_metrics: the global mean metrics and the per-step
    metrics (collected with `common_utils.get_metrics`) they were computed
    from.
  """
  data_iters = self.task.dataset.data_iters[split_name]
  eval_metrics = []
  for _ in range(self.steps_per_eval):
    env_eval_batches = self.get_next_batch(data_iters)
    e_metrics = self.pmapped_eval_step(train_state, env_eval_batches)
    eval_metrics.append(e_metrics)
  eval_metrics = common_utils.get_metrics(eval_metrics)
  eval_summary = pipeline_utils.compute_global_mean_metrics(eval_metrics)

  return eval_summary, eval_metrics
def evaluate(
    eval_step_fn,
    train_state: ts.TrainState,
    model_vars: Dict[str, Any],
    eval_data: Sequence[Dict[str, Any]],
) -> Tuple[Dict[str, Any], Sequence[Tuple[Dict[str, Any], Optional[Dict[
    str, Any]]]]]:
  """Evaluate current parameters and return a dictionary with metrics.

  Args:
    eval_step_fn: partial eval step that takes in model params and inputs
      only. It is called as `eval_step_fn(train_state, model_vars, batch)` and
      returns `(metrics, auxiliary_output)`.
    train_state: contains model params, loss fn, grad update fn.
    model_vars: model variables that are not optimized.
    eval_data: sequence of evaluation data.

  Returns:
    A tuple `(eval_summary, eval_auxiliary)`.
    - `eval_summary`: the metrics summed over all evaluation steps and
      processed by `metric_utils.process_metrics` with the 'eval' prefix.
    - `eval_auxiliary`: a list holding one `(batch, auxiliary_output)` pair,
      fetched to host with `jax.device_get`, for EVERY evaluated batch (not
      only the first one). Host memory therefore grows with `eval_data`.
  """
  logging.info('Performing evaluation.')
  eval_metrics = []
  eval_auxiliary = []
  for batch in eval_data:
    batch = jax.tree_map(jnp.asarray, batch)
    metrics, auxiliary_output = eval_step_fn(
        train_state,
        model_vars,
        batch,
    )
    eval_metrics.append(metrics)
    batch_auxiliary = (jax.device_get(batch), jax.device_get(auxiliary_output))
    eval_auxiliary.append(batch_auxiliary)
  eval_metrics = common_utils.get_metrics(eval_metrics)
  eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
  eval_summary = metric_utils.process_metrics(eval_metrics_sums, prefix='eval')
  return eval_summary, eval_auxiliary
def run_eval(eval_ds, num_eval_steps=-1):
  """Evaluates `optimizer.target` on `eval_ds` and returns summary metrics.

  `p_eval_step` and `optimizer` are free names resolved from an enclosing
  scope. `num_eval_steps == -1` evaluates the whole dataset; any other value
  caps the number of batches at `range(num_eval_steps)`. Summed metrics are
  divided by the summed 'denominator', and a perplexity clipped at 1e4 is
  added.
  """
  batch_source = iter(eval_ds)
  step_counter = (itertools.count() if num_eval_steps == -1
                  else range(num_eval_steps))
  per_step = []
  for _, raw_batch in zip(step_counter, batch_source):
    # pylint: disable=protected-access
    sharded = common_utils.shard(jax.tree_map(lambda x: x._numpy(), raw_batch))
    # pylint: enable=protected-access
    per_step.append(p_eval_step(optimizer.target, sharded))

  totals = jax.tree_map(jnp.sum, common_utils.get_metrics(per_step))
  denominator = totals.pop('denominator')
  eval_summary = jax.tree_map(lambda x: x / denominator, totals)
  # Calculate (clipped) perplexity after averaging log-perplexities:
  eval_summary['perplexity'] = jnp.clip(
      jnp.exp(eval_summary['loss']), a_max=1.0e4)
  return eval_summary
def main(argv):
  """Trains a Transformer mapping the CoNLL FORM attribute to XPOS.

  Builds vocabularies and datasets from FLAGS.train / FLAGS.dev and trains for
  FLAGS.num_train_steps steps. Every FLAGS.eval_frequency steps it logs train
  metrics and evaluates on the dev set, tracking the best dev accuracy. On
  host 0 the metrics are written to TensorBoard under FLAGS.model_dir.

  Args:
    argv: Command-line arguments; anything beyond the program name is
      rejected.

  Raises:
    app.UsageError: On extra arguments or a missing dev/train path.
    ValueError: If the batch size is not divisible by the device count.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  batch_size = FLAGS.batch_size
  learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  num_eval_steps = FLAGS.num_eval_steps
  eval_freq = FLAGS.eval_frequency
  max_length = FLAGS.max_length
  random_seed = FLAGS.random_seed

  if not FLAGS.dev:
    raise app.UsageError('Please provide path to dev set.')
  if not FLAGS.train:
    raise app.UsageError('Please provide path to training set.')

  # NOTE(review): parameter_path is not used anywhere in this function (see
  # the "TODO: save model." below).
  parameter_path = os.path.join(FLAGS.model_dir, FLAGS.experiment + '.params')
  # Summary writers only exist on host 0; every use below is guarded by the
  # same host check.
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  device_batch_size = batch_size // jax.device_count()

  # create the training and development dataset
  vocabs = input_pipeline.create_vocabs(FLAGS.train)

  attributes_input = [input_pipeline.CoNLLAttributes.FORM]
  attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
  train_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.train,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length)

  # repeat=1 presumably yields a single pass over the dev set, which the
  # num_eval_steps == -1 branch below relies on to terminate — confirm.
  eval_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.dev,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length,
      repeat=1)
  train_iter = iter(train_ds)

  # Equals batch_size, given the divisibility check above.
  bs = device_batch_size * jax.device_count()

  rng = random.PRNGKey(random_seed)
  rng, init_rng = random.split(rng)
  input_shape = (bs, max_length)
  transformer_kwargs = {
      'vocab_size': len(vocabs['forms']),
      'output_vocab_size': len(vocabs['xpos']),
      'emb_dim': 512,
      'num_heads': 8,
      'num_layers': 6,
      'qkv_dim': 512,
      'mlp_dim': 2048,
      'max_len': max_length,
  }
  model = create_model(init_rng, tuple(input_shape), transformer_kwargs)

  optimizer = create_optimizer(model, learning_rate)
  del model  # don't keep a copy of the initial model
  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)

  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  metrics_all = []
  tick = time.time()
  best_dev_score = 0
  for step, batch in zip(range(num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access

    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    if (step + 1) % eval_freq == 0:
      # Train summary: sums normalized by the summed denominator, plus the
      # mean learning rate over the window.
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
      if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        # tick is reset before the evaluation below runs, so the next window's
        # steps_per_sec also includes this evaluation's time.
        tick = tock
        train_summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          train_summary_writer.scalar(key, val, step)
        train_summary_writer.flush()
      # reset metric accumulation for next evaluation cycle.
      metrics_all = []

      eval_metrics = []
      eval_iter = iter(eval_ds)
      # -1 means "iterate the whole dev set"; otherwise cap the batch count.
      if num_eval_steps == -1:
        num_iter = itertools.repeat(1)
      else:
        num_iter = range(num_eval_steps)
      for _, eval_batch in zip(num_iter, eval_iter):
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        # Handle final odd-sized batch by padding instead of dropping it.
        # NOTE(review): padded examples are presumably given zero weight in
        # eval_step so they do not skew the sums — confirm.
        cur_pred_batch_size = eval_batch['inputs'].shape[0]
        if cur_pred_batch_size != batch_size:
          logging.info('Uneven batch size %d.', cur_pred_batch_size)
          eval_batch = jax.tree_map(
              lambda x: pad_examples(x, batch_size), eval_batch)
        eval_batch = common_utils.shard(eval_batch)

        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)

      # Calculate (clipped) perplexity after averaging log-perplexities:
      eval_summary['perplexity'] = jnp.clip(
          jnp.exp(eval_summary['loss']), a_max=1.0e4)
      logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])

      # Track the best dev accuracy seen so far; it is logged with every eval.
      if best_dev_score < eval_summary['accuracy']:
        best_dev_score = eval_summary['accuracy']
        # TODO: save model.
      eval_summary['best_dev_score'] = best_dev_score
      logging.info('best development model score %.4f', best_dev_score)
      if jax.host_id() == 0:
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(key, val, step)
        eval_summary_writer.flush()
def run_train(run_configuration):
  """Runs the training workflow.

  Restores (or initializes) the optimizer and trains for the steps in
  `[start_step, config.train.total_steps)`, where `start_step` comes from the
  restored optimizer state.

  Checkpoints are saved on host 0 every `config.logging.save_freq` steps and
  on the last step. Summaries are written every
  `config.logging.summary_freq` steps. A falsy frequency disables the
  corresponding feature.

  Args:
    run_configuration: Provides `config`, `run_dir`, `adapter`,
      `original_checkpoint_path` and `dataset_info` (with `dataset` and
      `info`).
  """
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  adapter = run_configuration.adapter
  log_dir = os.path.join(run_dir, 'train')
  checkpoint_path = run_configuration.original_checkpoint_path

  dataset = run_configuration.dataset_info.dataset
  info = run_configuration.dataset_info.info

  random_seed = 0
  rng = jax.random.PRNGKey(random_seed)
  # Make the key host-specific so hosts draw different dropout masks.
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)
  dropout_rngs = jax.random.split(rng, jax.local_device_count())

  # Set up optimizer.
  optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

  # Set up train step.
  train_step = adapter.make_train_step()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(log_dir)

  # Set up checkpointing.
  # TODO(dbieber): Set up phoenix.
  checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
  if checkpoint_path is None:
    checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
  optimizer = checkpoint_utils.handle_restart_behavior(
      checkpoint_path, optimizer, config)

  start_step = int(optimizer.state.step)
  num_train_steps = config.train.total_steps
  save_freq = config.logging.save_freq
  summary_freq = config.logging.summary_freq

  # Replicate optimizer.
  optimizer = flax.jax_utils.replicate(optimizer)

  # Begin training loop.
  dataset_iter_raw = iter(dataset)
  dataset_iter = adapter.preprocess(dataset_iter_raw)

  metrics_all = []
  tick = time.time()
  for step, example in zip(range(start_step, num_train_steps), dataset_iter):
    train_inputs = adapter.get_train_inputs(example)
    optimizer, metrics, dropout_rngs, logits, state = train_step(
        optimizer, train_inputs, dropout_rngs)
    metrics_all.append(metrics)

    # Save a Checkpoint. save_freq is checked first: a falsy value disables
    # checkpointing, and `step % save_freq` with save_freq == 0 would raise
    # ZeroDivisionError.
    if save_freq and jax.host_id() == 0 and (
        (step % save_freq == 0 and step > 0) or step == num_train_steps - 1):
      # Save unreplicated optimizer + model state.
      checkpoint_utils.save_checkpoint(
          checkpoint_dir, jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    if summary_freq and step % summary_freq == 0 and step > 0:
      # Count the accumulated steps before stacking turns the list into a
      # dict. The first window after (re)start can differ from summary_freq.
      num_steps = len(metrics_all)
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train step: %d, loss: %.4f', step, summary['loss'])
      if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = num_steps / (tock - tick)
        examples_per_sec = denominator / (tock - tick)
        tick = tock
        summary_writer.scalar('per-second/steps', steps_per_sec, step)
        summary_writer.scalar('per-second/examples', examples_per_sec, step)
        for key, val in summary.items():
          summary_writer.scalar(key, val, step)
        # Summaries use the most recent example / step outputs.
        adapter.write_summaries(example, logits, summary_writer, info, step,
                                state)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []
def main(argv):
  """Trains and evaluates an image classifier configured through FLAGS.

  Training runs for `FLAGS.num_epochs` epochs. After every epoch the
  per-step training metrics are written, batch statistics are synchronised
  across replicas, and `EVAL_IMAGES // batch_size` evaluation batches are
  run. A checkpoint is written every 10 epochs and after the final step.
  TensorBoard summaries are written by host 0 only.

  Args:
    argv: leftover command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: if extra command-line arguments are given.
    ValueError: if FLAGS.batch_size is not a multiple of the device count.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()
  # Hide all GPUs from TensorFlow so that it does not allocate GPU memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  is_chief = jax.host_id() == 0
  if is_chief:
    writer = tensorboard.SummaryWriter(FLAGS.model_dir)

  init_key = random.PRNGKey(0)
  image_size = 224

  batch_size = FLAGS.batch_size
  total_devices = jax.device_count()
  if batch_size % total_devices > 0:
    raise ValueError(
        'Batch size must be divisible by the number of devices')
  host_batch_size = batch_size // jax.host_count()
  per_device_batch_size = batch_size // total_devices

  # Half precision means bfloat16 on TPU and float16 on other platforms.
  on_tpu = jax.local_devices()[0].platform == 'tpu'
  if not FLAGS.half_precision:
    model_dtype, input_dtype = jnp.float32, tf.float32
  elif on_tpu:
    model_dtype, input_dtype = jnp.bfloat16, tf.bfloat16
  else:
    model_dtype, input_dtype = jnp.float16, tf.float16

  train_iter = create_input_iter(
      host_batch_size, image_size, input_dtype, train=True, cache=FLAGS.cache)
  eval_iter = create_input_iter(
      host_batch_size, image_size, input_dtype, train=False, cache=FLAGS.cache)

  num_epochs = FLAGS.num_epochs
  steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
  steps_per_eval = input_pipeline.EVAL_IMAGES // batch_size
  steps_per_checkpoint = 10 * steps_per_epoch
  num_steps = num_epochs * steps_per_epoch

  # Learning rate scaled linearly with batch size (reference batch of 256),
  # then divided by FLAGS.loss_scaling.
  base_learning_rate = (
      FLAGS.learning_rate * batch_size / 256.) / FLAGS.loss_scaling

  model, model_state = create_model(
      init_key, per_device_batch_size, image_size, model_dtype)
  optimizer = optim.Momentum(beta=FLAGS.momentum, nesterov=True).create(model)
  state = TrainState(step=0, optimizer=optimizer, model_state=model_state)
  del model, model_state  # do not keep a copy of the initial model

  state = restore_checkpoint(state)
  # Non-zero when training resumes from a checkpoint.
  first_step = int(state.step)
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_fn(
      base_learning_rate, steps_per_epoch, num_epochs)
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  pending_train_metrics = []
  epoch_start = time.time()
  for step, batch in zip(range(first_step, num_steps), train_iter):
    state, step_metrics = p_train_step(state, batch)
    pending_train_metrics.append(step_metrics)

    epoch_finished = (step + 1) % steps_per_epoch == 0
    if epoch_finished:
      epoch = step // steps_per_epoch
      train_stack = common_utils.get_metrics(pending_train_metrics)
      train_means = jax.tree_map(lambda x: x.mean(), train_stack)
      logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   train_means['loss'], train_means['accuracy'] * 100)
      steps_per_sec = steps_per_epoch / (time.time() - epoch_start)
      epoch_start = time.time()
      if is_chief:
        for name, history in train_stack.items():
          # history[k] is logged at global step (step - len(history) + 1 + k).
          first_logged = step - len(history) + 1
          for offset, value in enumerate(history):
            writer.scalar('train_%s' % name, value, first_logged + offset)
        writer.scalar('steps per second', steps_per_sec, step)

      pending_train_metrics = []

      # Synchronise batch statistics across replicas before evaluating.
      state = sync_batch_stats(state)
      eval_batch_metrics = [
          p_eval_step(state, next(eval_iter)) for _ in range(steps_per_eval)
      ]
      eval_stack = common_utils.get_metrics(eval_batch_metrics)
      eval_means = jax.tree_map(lambda x: x.mean(), eval_stack)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                   eval_means['loss'], eval_means['accuracy'] * 100)
      if is_chief:
        for name, values in eval_stack.items():
          writer.scalar('eval_%s' % name, values.mean(), step)
        writer.flush()

    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state)

  # Wait until all outstanding device computations are done before exiting.
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') if FLAGS.jax_backend_target: jax.config.FLAGS.jax_xla_backend = "tpu_driver" jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target # This seems to be necessary even when importing TF2? tf.enable_v2_behavior() # Number of local devices for this host. n_devices = jax.local_device_count() if jax.host_id() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'train')) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'eval')) if FLAGS.batch_size % n_devices: raise ValueError( 'Batch size must be divisible by the number of devices') vocab_path = FLAGS.vocab_path if vocab_path is None: vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model') # Load Dataset logging.info('Initializing dataset.') train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( n_devices=n_devices, dataset_name=FLAGS.dataset_name, eval_dataset_name=FLAGS.eval_dataset_name, shard_idx=jax.host_id(), shard_count=jax.host_count(), data_dir=FLAGS.data_dir, vocab_path=vocab_path, batch_size=FLAGS.batch_size, max_length=FLAGS.max_target_length, max_eval_length=FLAGS.max_eval_target_length) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_token = 2 # Default Sentencepiece EOS token. 
def decode_tokens(toks): valid_toks = toks[:np.argmax(toks == eos_token) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode('utf-8') logging.info('Initializing model, optimizer, and step functions.') # Build Model and Optimizer transformer_kwargs = { 'vocab_size': vocab_size, 'output_vocab_size': vocab_size, 'emb_dim': 1024, 'num_heads': 16, 'num_layers': 6, 'qkv_dim': 1024, 'mlp_dim': 4096, 'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length), 'share_embeddings': FLAGS.share_embeddings, 'logits_via_embedding': FLAGS.logits_via_embedding, } start_step = 0 rng = random.PRNGKey(FLAGS.random_seed) rng, init_rng = random.split(rng) input_shape = (FLAGS.batch_size, FLAGS.max_target_length) target_shape = (FLAGS.batch_size, FLAGS.max_target_length) model, cache_def = create_model(init_rng, input_shape, target_shape, transformer_kwargs) optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.weight_decay) # We access model only from optimizer below via optimizer.target. del model if FLAGS.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer) # Grab last step. start_step = int(optimizer.state.step) # Replicate optimizer. 
optimizer = jax_utils.replicate(optimizer) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps) p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn, label_smoothing=FLAGS.label_smoothing, use_bfloat16=FLAGS.use_bfloat16), axis_name='batch') p_eval_step = jax.pmap(functools.partial( eval_step, label_smoothing=FLAGS.label_smoothing, use_bfloat16=FLAGS.use_bfloat16), axis_name='batch') p_pred_step = jax.pmap( functools.partial(predict_step, use_bfloat16=FLAGS.use_bfloat16), axis_name='batch', static_broadcasted_argnums=(3, 4)) # eos token, max_length are constant # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. dropout_rngs = random.split(rng, n_devices) logging.info('Starting training loop.') metrics_all = [] t_loop_start = time.time() for step, batch in zip(range(start_step, FLAGS.num_train_steps), train_iter): # Shard data to devices and do a training step. batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access optimizer, metrics, dropout_rngs = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) # Save a checkpoint on one host after every checkpoint_freq steps. if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0 and step > 0 and jax.host_id() == 0): checkpoints.save_checkpoint(FLAGS.model_dir, jax_utils.unreplicate(optimizer), step) # Periodic metric handling. 
if step % FLAGS.eval_frequency != 0: continue logging.info('Gathering training metrics.') # Training Metrics metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr steps_per_eval = FLAGS.eval_frequency if step != 0 else 1 steps_per_sec = steps_per_eval / (time.time() - t_loop_start) t_loop_start = time.time() if jax.host_id() == 0: train_summary_writer.scalar('steps per second', steps_per_sec, step) for key, val in summary.items(): train_summary_writer.scalar(key, val, step) train_summary_writer.flush() metrics_all = [] logging.info('train in step: %d, loss: %.4f', step, summary['loss']) # Eval Metrics logging.info('Gathering evaluation metrics.') t_eval_start = time.time() eval_metrics = [] eval_iter = iter(eval_ds) for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter): eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access eval_batch = common_utils.shard(eval_batch) metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) if jax.host_id() == 0: for key, val in eval_summary.items(): eval_summary_writer.scalar(key, val, step) eval_summary_writer.flush() logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss']) logging.info('eval time: %.4f s step %d', time.time() - t_eval_start, step) # Translation and BLEU Score. 
logging.info('Translating evaluation dataset.') t_inference_start = time.time() predict_iter = iter(predict_ds) sources, references, predictions = [], [], [] for _, pred_batch in enumerate(predict_iter): pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch) # pylint: disable=protected-access # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch['inputs'].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) pred_batch = jax.tree_map( lambda x: pad_examples(x, padded_size), pred_batch) # pylint: disable=cell-var-from-loop pred_batch = common_utils.shard(pred_batch) per_device_batchsize = pred_batch['inputs'].shape[1] cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32 cache = jax_utils.replicate( cache_def.initialize_cache( (per_device_batchsize, FLAGS.max_predict_length), dtype=cache_dtype)) predicted = p_pred_step(pred_batch['inputs'], optimizer.target, cache, eos_token, FLAGS.max_predict_length) predicted = tohost(predicted) inputs = tohost(pred_batch['inputs']) targets = tohost(pred_batch['targets']) # Iterate through non-padding examples of batch. for i, s in enumerate(predicted[:cur_pred_batch_size]): sources.append(decode_tokens(inputs[i])) references.append(decode_tokens(targets[i])) predictions.append(decode_tokens(s)) logging.info('Translation: %d predictions %d references %d sources.', len(predictions), len(references), len(sources)) logging.info('Translation time: %.4f s step %d.', time.time() - t_inference_start, step) # Calculate BLEU score for translated eval corpus against reference. bleu_matches = bleu.bleu_partial(references, predictions) all_bleu_matches = per_host_sum_pmap(bleu_matches) bleu_score = bleu.complete_bleu(*all_bleu_matches) # Save translation samples for tensorboard. 
exemplars = '' for n in np.random.choice(np.arange(len(predictions)), 8): exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n' if jax.host_id() == 0: eval_summary_writer.scalar('bleu', bleu_score, step) eval_summary_writer.text('samples', exemplars, step) eval_summary_writer.flush() logging.info('Translation BLEU Score %.4f', bleu_score)
def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser( (ModelArguments, DataTrainingArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file( json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses( ) # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # information sent is the one passed as arguments along with your Python/PyTorch versions. send_example_telemetry("run_t5_mlm", model_args, data_args, framework="flax") if (os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir): raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome.") # Setup logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.INFO, datefmt="[%X]", ) # Log on each process the small summary: logger = logging.getLogger(__name__) # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") # Set seed before initializing model. 
set_seed(training_args.seed) # Handle the repository creation if training_args.push_to_hub: if training_args.hub_model_id is None: repo_name = get_full_repo_name(Path( training_args.output_dir).absolute().name, token=training_args.hub_token) else: repo_name = training_args.hub_model_id repo = Repository(training_args.output_dir, clone_from=repo_name) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). # # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called # 'text' is found. You can easily tweak this behavior (see below). if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. datasets = load_dataset( data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) if "validation" not in datasets.keys(): datasets["validation"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) datasets["train"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: data_files = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text" datasets = load_dataset( extension, data_files=data_files, cache_dir=model_args.cache_dir, 
use_auth_token=True if model_args.use_auth_token else None, ) if "validation" not in datasets.keys(): datasets["validation"] = load_dataset( extension, data_files=data_files, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) datasets["train"] = load_dataset( extension, data_files=data_files, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. # Load pretrained model and tokenizer if model_args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer, use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer, use_auth_token=True if model_args.use_auth_token else None, ) else: raise ValueError( "You are instantiating a new tokenizer from scratch. This is not supported by this script." "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
) if model_args.config_name: config = T5Config.from_pretrained( model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer), use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: config = T5Config.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning( "You are instantiating a new config instance from scratch.") # Preprocessing the datasets. # First we tokenize all the texts. if training_args.do_train: column_names = datasets["train"].column_names else: column_names = datasets["validation"].column_names text_column_name = "text" if "text" in column_names else column_names[0] max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. # Since we make sure that all sequences are of the same length, no attention_mask is needed. def tokenize_function(examples): return tokenizer(examples[text_column_name], return_attention_mask=False) tokenized_datasets = datasets.map( tokenize_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token. # To ensure that the input length is `max_seq_length`, we need to increase the maximum length # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly. 
expanded_inputs_length, targets_length = compute_input_and_target_lengths( inputs_length=max_seq_length, noise_density=data_args.mlm_probability, mean_noise_span_length=data_args.mean_noise_span_length, ) # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length. def group_texts(examples): # Concatenate all texts. concatenated_examples = { k: list(chain(*examples[k])) for k in examples.keys() } total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. if total_length >= expanded_inputs_length: total_length = (total_length // expanded_inputs_length) * expanded_inputs_length # Split by chunks of max_len. result = { k: [ t[i:i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length) ] for k, t in concatenated_examples.items() } return result # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value # might be slower to preprocess. # # To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information: # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map tokenized_datasets = tokenized_datasets.map( group_texts, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, ) # Enable tensorboard only on the master node has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter( log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable.") # Initialize our training rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) if model_args.model_name_or_path: model = FlaxT5ForConditionalGeneration.from_pretrained( model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), use_auth_token=True if model_args.use_auth_token else None, ) else: config.vocab_size = len(tokenizer) model = FlaxT5ForConditionalGeneration( config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), ) # Data collator # This one will take care of randomly masking the tokens. 
data_collator = FlaxDataCollatorForT5MLM( tokenizer=tokenizer, noise_density=data_args.mlm_probability, mean_noise_span_length=data_args.mean_noise_span_length, input_length=max_seq_length, target_length=targets_length, pad_token_id=model.config.pad_token_id, decoder_start_token_id=model.config.decoder_start_token_id, ) # Store some constant num_epochs = int(training_args.num_train_epochs) train_batch_size = int( training_args.per_device_train_batch_size) * jax.device_count() per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) eval_batch_size = per_device_eval_batch_size * jax.device_count() num_train_steps = len( tokenized_datasets["train"]) // train_batch_size * num_epochs num_of_hosts = jax.process_count() current_host_idx = jax.process_index() # Create learning rate schedule warmup_fn = optax.linear_schedule( init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps) decay_fn = optax.linear_schedule( init_value=training_args.learning_rate, end_value=0, transition_steps=num_train_steps - training_args.warmup_steps, ) linear_decay_lr_schedule_fn = optax.join_schedules( schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]) # We use Optax's "masking" functionality to not apply weight decay # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. 
def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) # find out all LayerNorm parameters layer_norm_candidates = ["layernorm", "layer_norm", "ln"] layer_norm_named_params = set([ layer[-2:] for layer_norm_name in layer_norm_candidates for layer in flat_params.keys() if layer_norm_name in "".join(layer).lower() ]) flat_mask = { path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params } return traverse_util.unflatten_dict(flat_mask) # create adam optimizer if training_args.adafactor: # We use the default parameters here to initialize adafactor, # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 optimizer = optax.adafactor( learning_rate=linear_decay_lr_schedule_fn, ) else: optimizer = optax.adamw( learning_rate=linear_decay_lr_schedule_fn, b1=training_args.adam_beta1, b2=training_args.adam_beta2, weight_decay=training_args.weight_decay, mask=decay_mask_fn, ) # Setup train state state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer) # Define gradient update step fn def train_step(state, batch, dropout_rng): dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) def loss_fn(params): labels = batch.pop("labels") logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] # compute loss loss = optax.softmax_cross_entropy( logits, onehot(labels, logits.shape[-1])).mean() return loss grad_fn = jax.value_and_grad(loss_fn) loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) metrics = jax.lax.pmean( { "loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step) }, axis_name="batch") return new_state, metrics, new_dropout_rng # Create parallel version of the train step p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, )) # Define eval fn def 
eval_step(params, batch): labels = batch.pop("labels") logits = model(**batch, params=params, train=False)[0] # compute loss loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) # compute accuracy accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) # summarize metrics metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()} metrics = jax.lax.pmean(metrics, axis_name="batch") return metrics p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, )) # Replicate the train state on each device state = jax_utils.replicate(state) train_time = 0 epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0) for epoch in epochs: # ======================== Training ================================ train_start = time.time() train_metrics = [] # Create sampling rng rng, input_rng = jax.random.split(rng) # Generate an epoch by shuffling sampling indices from the train dataset num_train_samples = len(tokenized_datasets["train"]) # Avoid using jax.numpy here in case of TPU training train_samples_idx = np.random.permutation(np.arange(num_train_samples)) train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) # Gather the indexes for creating the batch and do a training step for step, batch_idx in enumerate( tqdm(train_batch_idx, desc="Training...", position=1)): samples = [ tokenized_datasets["train"][int(idx)] for idx in batch_idx ] model_inputs = data_collator(samples) local_host_model_inputs = { key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx] for key, value in model_inputs.data.items() } # Model forward model_inputs = shard(local_host_model_inputs) state, train_metric, dropout_rngs = p_train_step( state, model_inputs, dropout_rngs) train_metrics.append(train_metric) cur_step = epoch * (num_train_samples // train_batch_size) + step if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_metric = jax_utils.unreplicate(train_metric) train_time += time.time() - 
train_start if has_tensorboard and jax.process_index() == 0: write_train_metric(summary_writer, train_metrics, train_time, cur_step) epochs.write( f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:" f" {train_metric['learning_rate'].mean()})") train_metrics = [] if cur_step % training_args.eval_steps == 0 and cur_step > 0: # ======================== Evaluating ============================== num_eval_samples = len(tokenized_datasets["validation"]) # Avoid using jax.numpy here in case of TPU training eval_samples_idx = np.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) eval_metrics = [] for i, batch_idx in enumerate( tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): samples = [ tokenized_datasets["validation"][int(idx)] for idx in batch_idx ] model_inputs = data_collator(samples) # Model forward metrics = pad_shard_unpad(p_eval_step, static_return=True)( state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size) eval_metrics.append(metrics) # get eval metrics eval_metrics = get_metrics(eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics) # Update progress bar epochs.write( f"Step... 
({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})" ) # Save metrics if has_tensorboard and jax.process_index() == 0: write_eval_metric(summary_writer, eval_metrics, cur_step) if cur_step % training_args.save_steps == 0 and cur_step > 0: # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get( jax.tree_map(lambda x: x[0], state.params)) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: repo.push_to_hub( commit_message= f"Saving weights and logs of step {cur_step}", blocking=False) # Eval after training if training_args.do_eval: num_eval_samples = len(tokenized_datasets["validation"]) # Avoid using jax.numpy here in case of TPU training eval_samples_idx = np.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) eval_metrics = [] for i, batch_idx in enumerate( tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): samples = [ tokenized_datasets["validation"][int(idx)] for idx in batch_idx ] model_inputs = data_collator(samples) # Model forward metrics = pad_shard_unpad(p_eval_step, static_return=True)( state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size) eval_metrics.append(metrics) # get eval metrics eval_metrics = get_metrics(eval_metrics) eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics) if jax.process_index() == 0: eval_metrics = { f"eval_{metric_name}": value for metric_name, value in eval_metrics.items() } path = os.path.join(training_args.output_dir, "eval_results.json") with open(path, "w") as f: json.dump(eval_metrics, f, indent=4, sort_keys=True)
Args: model: The model the evaluate. state: Model state containing state for stateful flax.nn functions, such as batch normalization. eval_dataset: Dataset to evaluate the model over. Returns: Dictionary containing the average loss and accuracy of the model on the given dataset. """ p_eval_step = jax.pmap(_eval_step, axis_name='batch') batch_sizes = [] metrics = [] for batch in eval_dataset: batch_size = len(batch[LABELKEY]) # These are required for pmap call. batch = _shard_batch(batch) batch_metrics = p_eval_step(model, state, batch) batch_sizes.append(batch_size) metrics.append(batch_metrics) # Note: use weighted mean, since we do mean of means with potentially # different batch sizes otherwise. batch_sizes = jnp.array(batch_sizes) weights = batch_sizes / jnp.sum(batch_sizes) eval_metrics = common_utils.get_metrics(metrics) return jax.tree_map(lambda x: (weights * x).sum(), eval_metrics)
num_eval_samples = len(tokenized_datasets["validation"]) eval_samples_idx = jnp.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = [] for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx] model_inputs = data_collator(samples, pad_to_multiple_of=16) # Model forward model_inputs = shard(model_inputs.data) metrics = p_eval_step(state.params, model_inputs) eval_metrics.append(metrics) # normalize eval metrics eval_metrics = get_metrics(eval_metrics) eval_metrics = jax.tree_map(jnp.sum, eval_metrics) eval_normalizer = eval_metrics.pop("normalizer") eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) # Update progress bar epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})" # Save metrics if has_tensorboard and jax.process_index() == 0: write_eval_metric(summary_writer, eval_metrics, cur_step) if cur_step % training_args.save_steps == 0 and cur_step > 0: # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
def _normalize_by_denominator(stacked_metrics):
  """Sums each stacked metric and divides it by the summed 'denominator'.

  Args:
    stacked_metrics: dict of metric arrays, as returned by
      `common_utils.get_metrics`; must contain a 'denominator' entry.

  Returns:
    dict mapping every other metric name to sum(metric) / sum(denominator).
  """
  metric_sums = jax.tree_map(jnp.sum, stacked_metrics)
  denominator = metric_sums.pop('denominator')
  return jax.tree_map(lambda x: x / denominator, metric_sums)


def _evaluate(p_eval_step, params, state, eval_ds, num_eval_steps):
  """Runs `p_eval_step` over `eval_ds` and returns normalized eval metrics.

  Args:
    p_eval_step: fn; Pmapped evaluation step function.
    params: model parameters, passed as the first argument of `p_eval_step`.
    state: model state, e.g. batch statistics.
    eval_ds: tf.dataset; Evaluation dataset.
    num_eval_steps: int; Number of evaluation batches, or -1 to consume the
      whole dataset.

  Returns:
    dict of denominator-normalized eval metrics, or None if no evaluation
    batch was produced (empty dataset or `num_eval_steps == 0`).
  """
  if num_eval_steps == -1:
    step_budget = itertools.repeat(1)
  else:
    step_budget = range(num_eval_steps)
  eval_metrics = []
  for _, eval_batch in zip(step_budget, iter(eval_ds)):
    # pylint: disable=protected-access
    eval_batch = common_utils.shard(
        jax.tree_map(lambda x: x._numpy(), eval_batch))
    # pylint: enable=protected-access
    eval_metrics.append(p_eval_step(params, state, eval_batch))
  if not eval_metrics:
    # get_metrics cannot stack an empty list; report "no result" instead.
    return None
  return _normalize_by_denominator(common_utils.get_metrics(eval_metrics))


def train_loop(config, dropout_rngs, eval_ds, eval_freq, num_eval_steps,
               num_train_steps, optimizer, state, p_eval_step, p_train_step,
               start_step, train_iter, summary_writer):
  """Training loop.

  Runs `p_train_step` for steps in [start_step, num_train_steps). It saves a
  checkpoint every `config.checkpoint_freq` steps and on the last step. Every
  `eval_freq` steps it logs the accumulated train metrics and evaluates on
  `eval_ds`.

  Args:
    config: experiment config; reads `checkpoint_freq`, `save_checkpoints`
      and `batch_size`.
    dropout_rngs: JAX PRNG key(s) passed to `p_train_step` as `dropout_rng`;
      the keys it returns are threaded into the next step.
    eval_ds: tf.dataset; Evaluation dataset.
    eval_freq: int; Train-metric logging and evaluation frequency, in steps.
    num_eval_steps: int; Number of evaluation steps, or -1 to evaluate on the
      whole of `eval_ds`.
    num_train_steps: int; Number of training steps (exclusive upper bound of
      the step index).
    optimizer: flax optimizer.
    state: model state, e.g. batch statistics.
    p_eval_step: fn; Pmapped evaluation step function.
    p_train_step: fn; Pmapped train step function.
    start_step: int; global training step to start (or resume) from.
    train_iter: iter(tf.dataset); Training data iterator.
    summary_writer: tensorflow summary writer.

  Returns:
    A tuple (optimizer, state, step): the trained optimizer, the model state,
    and the last executed step (or `start_step` if no step was run).
  """
  train_metrics = []
  tick = time.time()
  last_log_step = start_step
  logging.info('Starting training')
  logging.info('====================')

  # Keep the global step if the loop body never runs (e.g. training already
  # finished) instead of reporting step 0.
  step = start_step
  for step, batch in zip(range(start_step, num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, state, metrics, dropout_rngs = p_train_step(
        optimizer, state, batch, dropout_rng=dropout_rngs)
    train_metrics.append(metrics)

    # Save a checkpoint periodically and after the last step.
    if ((step % config.checkpoint_freq == 0 and step > 0) or
        step == num_train_steps - 1):
      if jax.host_id() == 0 and config.save_checkpoints:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(FLAGS.model_dir,
                                    (jax_utils.unreplicate(optimizer),
                                     jax_utils.unreplicate(state)), step)

    # Periodic metric handling; skip all other steps.
    if step % eval_freq != 0 or step <= 0:
      continue

    stacked = common_utils.get_metrics(train_metrics)
    # The learning rate is averaged directly, not normalized by 'denominator'.
    lr = stacked.pop('learning_rate').mean()
    summary = _normalize_by_denominator(stacked)
    summary['learning_rate'] = lr
    logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                 summary['loss'], summary['accuracy'])

    # Throughput over the steps actually run since the last log. This is not
    # always `eval_freq`, e.g. right after resuming. The interval also includes
    # the previous evaluation and any checkpoint saves.
    tock = time.time()
    steps_per_sec = (step - last_log_step) / (tock - tick)
    tick, last_log_step = tock, step
    if jax.host_id() == 0:
      summary_writer.scalar('examples_per_second',
                            steps_per_sec * config.batch_size, step)
      for key, val in summary.items():
        summary_writer.scalar(f'train_{key}', val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    train_metrics = []

    # Eval metrics.
    eval_summary = _evaluate(p_eval_step, optimizer.target, state, eval_ds,
                             num_eval_steps)
    if eval_summary is None:
      logging.warning('eval in step: %d produced no batches; skipping.', step)
    else:
      logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])
      if jax.host_id() == 0:
        for key, val in eval_summary.items():
          summary_writer.scalar(f'val_{key}', val, step)
        summary_writer.flush()
  return optimizer, state, step