def test_train_mnist(self):
  """Trains an MNIST MLP (almost) to completion for cross-implementation checks.

  Cross-entropy loss and accuracy are evaluated every 50 steps; the values
  appear in the test log.
  """
  total_steps = 1000
  gin.parse_config([
      'batch_fn.batch_size_per_device = 256',
      'batch_fn.eval_batch_size = 256',
  ])

  # Two hidden layers of 512 units, 10-way log-softmax output.
  model = tl.Serial(
      tl.Flatten(),
      tl.Dense(512), tl.Relu(),
      tl.Dense(512), tl.Relu(),
      tl.Dense(10), tl.LogSoftmax(),
  )

  def every_fifty_steps(step_n):
    return step_n % 50 == 0

  train_task = training.TrainTask(
      itertools.cycle(_mnist_dataset().train_stream(1)),
      tl.CrossEntropyLoss(),
      adafactor.Adafactor(.02))
  eval_task = training.EvalTask(
      itertools.cycle(_mnist_dataset().eval_stream(1)),
      [tl.CrossEntropyLoss(), tl.AccuracyScalar()],
      names=['CrossEntropyLoss', 'AccuracyScalar'],
      eval_at=every_fifty_steps,
      eval_N=10)

  loop = training.Loop(model, train_task, eval_task=eval_task)
  loop.run(n_steps=total_steps)
  self.assertEqual(loop.current_step(), total_steps)
def finetune(output_dir, model=gin.REQUIRED, dataset_name=gin.REQUIRED,
             batch_size=16, num_train_epochs=3.0):
  """Fine-tuning loop for GLUE, largely following the BERT recipe.

  Picks a regression or classification head from the dataset's label dtype,
  trains with a warmup + cosine-decay schedule, evaluates on the inputs'
  extra streams, writes dev-set metrics to `eval_results.txt`, and writes
  test-set predictions to per-task TSV files in the trainer's output dir.

  Args:
    output_dir: Directory passed to the Trainer; eval results and prediction
      files are written to the trainer's output directory.
    model: Model constructor; it is wrapped with `functools.partial(model,
      head=...)` before being handed to the Trainer (gin-configured).
    dataset_name: TFDS dataset name such as 'glue/mnli' (gin-configured).
    batch_size: Training batch size; also used to derive the total number of
      training steps.
    num_train_epochs: Number of (possibly fractional) passes over the
      training split.

  Returns:
    A `(trainer, preds_labels_idxs)` pair, where `preds_labels_idxs` holds one
    `(predictions, labels, indices)` triple per extra input stream.

  Raises:
    ValueError: If `dataset_name` has a prediction file mapping but no known
      label set (should be unreachable).
  """
  ds_info = tfds.builder(dataset_name).info
  is_regression_task = (ds_info.features.dtype['label'] == onp.float32)
  if is_regression_task:
    # Regression task
    loss_fn = tl.L2Loss()
    metrics = {
        'loss': tl.L2Loss(),
        'weights_per_batch_per_core': tl.SumOfWeights(),
    }
    model = functools.partial(model, head=trax.models.BERTRegressionHead)
  else:
    # Classification task
    loss_fn = tl.CrossEntropyLoss()
    metrics = {
        'loss': tl.CrossEntropyLoss(),
        'accuracy': tl.AccuracyScalar(),
        'weights_per_batch_per_core': tl.SumOfWeights(),
    }
    n_classes = ds_info.features['label'].num_classes
    with gin.unlock_config():
      gin.parse_config(f'BERTClassifierHead.n_classes = {n_classes}')
    model = functools.partial(model, head=trax.models.BERTClassifierHead)

  num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples
  total_steps = int(num_train_examples * num_train_epochs // batch_size)
  # 10% of the steps are linear warmup; the rest is the decay phase.
  warmup_steps = int(0.1 * total_steps)
  cooldown_steps = total_steps - warmup_steps

  # TODO(kitaev): Re-think how configuration works for this setup.
  with gin.unlock_config():
    gin.parse_config(f"""
# TODO(kitaev): Devlin et al. use linear decay, not cosine decay
MultifactorSchedule.factors = 'constant * linear_warmup * cosine_decay'
MultifactorSchedule.warmup_steps = {warmup_steps}
MultifactorSchedule.steps_per_cycle = {cooldown_steps}
# TODO(kitaev): Devlin et al. use 0.01, but exclude biases from weight decay
Adam.weight_decay_rate=0.0
Adam.b1 = 0.9
Adam.b2 = 0.999
Adam.eps = 1e-6
glue_inputs.dataset_name = '{dataset_name}'
glue_inputs.batch_size = {batch_size}
glue_inputs.tokenizer = @bert_tokenizer
""")

  trainer = Trainer(
      model, loss_fn,
      optimizer=trax.optimizers.Adam,
      lr_schedule=trax.lr_schedules.MultifactorSchedule,
      inputs=glue_inputs,
      output_dir=output_dir,
      random_seed=None,
      n_devices=None,  # Use all available.
      checkpoints_at=None,
      nontrainable_param_map=None,
      metrics=metrics,
      id_to_mask=None,
      checkpoint_lowest=None,
      checkpoint_highest=None,
  )

  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()
  # One step before save_gin(), presumably so the saved config reflects what
  # was actually constructed during training -- TODO confirm.
  trainer.train_epoch(n_steps=1, n_eval_steps=10)
  trainer.save_gin()
  trainer.train_epoch(n_steps=warmup_steps - 1, n_eval_steps=10)
  trainer.train_epoch(n_steps=cooldown_steps, n_eval_steps=10)
  trainer.log_step('Training done')

  # Evaluation
  # pylint: disable=protected-access
  def my_jit(forward, n_devices):
    """Returns a JIT-compiled forward function running on n_devices."""
    model_predict = trax.layers.base._accelerate(forward, n_devices)
    if n_devices == 1:
      def predict1(x, weights, state):
        res, state = model_predict(x, weights, state,
                                   rng=jax.random.PRNGKey(0))
        return res
      return predict1

    def predict(x, weights, state):
      """Predict function jited and parallelized as requested."""
      # One copy of the same PRNG key per device. The leading dimension must
      # match the n_devices split done by reshape_by_device below; it used to
      # be hard-coded to 8, which failed on any other multi-device setup.
      rngs = np.broadcast_to(jax.random.PRNGKey(0)[None, :], (n_devices, 2))
      res, state = trax.layers.base._combine_devices(
          model_predict(
              trax.layers.base.reshape_by_device(x, n_devices),
              weights,
              state,
              rngs))
      return res
    return predict

  fwd = functools.partial(
      my_jit(trainer._model_predict_eval.pure_fn, trainer._n_devices),
      weights=trainer._opt_state[0][0],
      state=trainer._model_state[0])

  def run_model(stream):
    """Run forward pass on a dataset; returns (outputs, labels, indices)."""
    all_out = []
    all_idx = []
    all_labels = []
    for input_ids, type_ids, idx, labels in stream:
      # Zero-pad the batch so it splits evenly across devices; the padded
      # rows are sliced off the output again below.
      remainder = labels.shape[0] % trainer._n_devices
      if remainder:
        pad_amount = trainer._n_devices - remainder
        input_ids = onp.pad(
            input_ids, ((0, pad_amount), (0, 0)), mode='constant')
        type_ids = onp.pad(
            type_ids, ((0, pad_amount), (0, 0)), mode='constant')
        padded_idx = onp.pad(idx, ((0, pad_amount),), mode='constant')
      else:
        padded_idx = idx
      out = onp.array(fwd((input_ids, type_ids, padded_idx)))
      if remainder:
        out = out[:-pad_amount]
      all_out.append(out)
      all_idx.append(idx)
      all_labels.append(labels)
    all_out = onp.concatenate(all_out, axis=0)
    all_idx = onp.concatenate(all_idx, axis=0)
    all_labels = onp.concatenate(all_labels, axis=0)
    return all_out, all_labels, all_idx

  eval_metrics = {}
  if is_regression_task:
    eval_metrics['pearsonr'] = get_pearsonr
  else:
    eval_metrics['accuracy'] = get_accuracy
  if dataset_name == 'glue/cola':
    eval_metrics['mcc'] = get_mcc
  elif dataset_name in ('glue/mrpc', 'glue/qqp'):
    eval_metrics['f1_accuracy_mean'] = get_f1_accuracy_mean

  preds_labels_idxs = [
      run_model(stream)
      for stream in trainer._inputs.extra_streams(trainer._n_devices)
  ]

  def write_eval_metrics(f, prefix, guess, gold):
    """Writes each eval metric to `f` and the trainer log, sorted by name."""
    for name, fn in sorted(eval_metrics.items()):
      val = fn(guess, gold)
      line = f'eval_{prefix}{name} = {val:.06f}\n'
      f.write(line)
      trainer.log_step(line)

  # Log results on development data
  eval_results_path = os.path.join(trainer._output_dir, 'eval_results.txt')
  with tf.io.gfile.GFile(eval_results_path, 'w') as f:
    guess, gold, _ = preds_labels_idxs[0]
    if is_regression_task:
      guess = guess[:, 0]
    else:
      guess = guess.argmax(-1)
    write_eval_metrics(f, '', guess, gold)
    if dataset_name == 'glue/mnli':
      # The second extra stream is the mismatched dev set for MNLI.
      guess, gold, _ = preds_labels_idxs[1]
      write_eval_metrics(f, 'mismatched_', guess.argmax(-1), gold)
    f.write(f'global_step = {trainer.step}\n')

  def write_predictions(filename, idxs, predictions):
    """Writes a tab-separated index/prediction file and logs its path."""
    path = os.path.join(trainer._output_dir, filename)
    with tf.io.gfile.GFile(path, 'w') as out_file:
      out_file.write('index\tprediction\n')
      for idx, val in zip(idxs, predictions):
        out_file.write(f'{idx}\t{val}\n')
    trainer.log_step(f'Predictions written to {path}')

  # Write predictions for test data
  path_map = {
      'glue/cola': 'CoLA.tsv',
      'glue/mrpc': 'MRPC.tsv',
      'glue/qqp': 'QQP.tsv',
      'glue/sst2': 'SST-2.tsv',
      'glue/mnli': 'MNLI-mm.tsv',
      'glue/qnli': 'QNLI.tsv',
      'glue/rte': 'RTE.tsv',
      # No eval on WNLI for now. BERT accuracy on WNLI is below baseline,
      # unless special training recipe is used.
      # 'glue/wnli': 'WNLI.tsv',
  }
  if dataset_name == 'glue/stsb':
    outputs, _, idxs = preds_labels_idxs[-1]
    write_predictions('STS-B.tsv', idxs,
                      [f'{val:.06f}' for val in outputs[:, 0]])
  elif dataset_name in path_map:
    if dataset_name in ('glue/cola', 'glue/mrpc', 'glue/qqp', 'glue/sst2'):
      label_set = ['0', '1']
    elif dataset_name in ('glue/qnli', 'glue/rte'):
      label_set = ['entailment', 'not_entailment']
    elif dataset_name == 'glue/mnli':
      label_set = ['entailment', 'neutral', 'contradiction']
    else:
      raise ValueError(f'Unexpected dataset_name {dataset_name}')

    outputs, _, idxs = preds_labels_idxs[-1]
    write_predictions(path_map[dataset_name], idxs,
                      [label_set[val] for val in outputs.argmax(-1)])

    if dataset_name == 'glue/mnli':
      # For MNLI the last stream is test-mismatched and the one before it is
      # test-matched.
      outputs, _, idxs = preds_labels_idxs[-2]
      write_predictions('MNLI-m.tsv', idxs,
                        [label_set[val] for val in outputs.argmax(-1)])

  return trainer, preds_labels_idxs
'opt_state', # OptState. 'history', # trax.history.History. 'model_state', # Auxilliary state of the model. ]) OptState = collections.namedtuple( '_OptState', [ 'weights', # Model weights. 'slots', # Per-parameter optimizer state, e.g. gradient moments. 'opt_params', # Optimizer (hyper)parameters, e.g. learning rate, momentum. ]) _DEFAULT_METRICS = { 'loss': tl.CrossEntropyLoss(), 'accuracy': tl.AccuracyScalar(), 'sequence_accuracy': tl.SequenceAccuracyScalar(), 'neg_log_perplexity': tl.Serial(tl.CrossEntropyLoss(), tl.Negate()), 'weights_per_batch_per_core': tl.SumOfWeights(), } class Trainer(object): """Trax trainer. A trainer allows to make training steps, train for full epochs, save the training state and access evaluation data. """ def __init__(self, model, loss_fn,
def test_accuracy_scalar(self):
  """AccuracyScalar reduces its three inputs to a rank-0 result.

  The inputs are all-ones arrays (presumably model outputs, targets and
  weights -- the layer's input contract is defined elsewhere).
  """
  leading_dims = (9, 4, 4)
  inputs = [
      np.ones(leading_dims + (20,)),
      np.ones(leading_dims),
      np.ones(leading_dims),
  ]
  accuracy = tl.AccuracyScalar()
  result = accuracy(inputs)
  self.assertEqual(result.shape, ())