Exemplo n.º 1
0
    ])

# Immutable record grouping the three pieces of optimizer-related state.
OptState = collections.namedtuple('_OptState', (
    'weights',  # Model weights.
    'slots',  # Per-parameter optimizer state, e.g. gradient moments.
    'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
))

# Maps each default metric's name to the `tl` object that computes it.
_DEFAULT_METRICS = dict(
    loss=tl.CrossEntropyLoss(),
    accuracy=tl.Accuracy(),
    sequence_accuracy=tl.SequenceAccuracy(),
    # Cross-entropy loss followed by a negation.
    neg_log_perplexity=tl.Serial(tl.CrossEntropyLoss(), tl.Negate()),
    weights_per_batch_per_core=tl.SumOfWeights(),
)


class Trainer(object):
    """Trax trainer.

  A trainer allows to make training steps, train for full epochs,
  save the training state and access evaluation data.
  """
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
Exemplo n.º 2
0
def finetune(output_dir,
             model=gin.REQUIRED,
             dataset_name=gin.REQUIRED,
             batch_size=16,
             num_train_epochs=3.0):
    """Fine-tuning loop for GLUE, largely following the BERT recipe.

    Trains `model` on `dataset_name` with Adam and a warmup + cosine-decay
    schedule, then runs the trained model over every stream returned by the
    inputs' `extra_streams`. Metrics for the first stream (and the second one
    for MNLI) are written to `eval_results.txt`; for the tasks listed in
    `path_map` (and STS-B) predictions for the last stream(s) are written as
    TSV files. All files go to the trainer's output directory.

    Args:
      output_dir: Output directory handed to the `Trainer`.
      model: Model constructor accepting a `head` keyword argument; the
        regression or classifier head is chosen from the label dtype.
      dataset_name: TFDS dataset name, e.g. 'glue/mnli'.
      batch_size: Batch size configured for `glue_inputs`; also used to derive
        the total number of training steps.
      num_train_epochs: Number of passes over the training split (may be
        fractional).

    Returns:
      A `(trainer, preds_labels_idxs)` tuple, where `preds_labels_idxs` holds
      one `(outputs, labels, indices)` triple per extra input stream.

    Raises:
      AssertionError: If `dataset_name` is in `path_map` but has no known
        label set.
    """
    ds_info = tfds.builder(dataset_name).info
    is_regression_task = (ds_info.features.dtype['label'] == onp.float32)

    if is_regression_task:
        # Regression task
        loss_fn = tl.L2Loss()
        metrics = {
            'loss': tl.L2Loss(),
            'weights_per_batch_per_core': tl.SumOfWeights(),
        }
        model = functools.partial(model, head=trax.models.BERTRegressionHead)
    else:
        # Classification task
        loss_fn = tl.CrossEntropyLoss()
        metrics = {
            'loss': tl.CrossEntropyLoss(),
            'accuracy': tl.AccuracyScalar(),
            'weights_per_batch_per_core': tl.SumOfWeights(),
        }
        n_classes = ds_info.features['label'].num_classes
        with gin.unlock_config():
            gin.parse_config(f'BERTClassifierHead.n_classes = {n_classes}')
        model = functools.partial(model, head=trax.models.BERTClassifierHead)

    # The first 10% of the steps warm the learning rate up; the remaining
    # steps form a single decay cycle.
    num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples
    total_steps = int(num_train_examples * num_train_epochs // batch_size)
    warmup_steps = int(0.1 * total_steps)
    cooldown_steps = total_steps - warmup_steps

    # TODO(kitaev): Re-think how configuration works for this setup.
    with gin.unlock_config():
        gin.parse_config(f"""
    # TODO(kitaev): Devlin et al. use linear decay, not cosine decay
    MultifactorSchedule.factors = 'constant * linear_warmup * cosine_decay'
    MultifactorSchedule.warmup_steps = {warmup_steps}
    MultifactorSchedule.steps_per_cycle = {cooldown_steps}

    # TODO(kitaev): Devlin et al. use 0.01, but exclude biases from weight decay
    Adam.weight_decay_rate=0.0
    Adam.b1 = 0.9
    Adam.b2 = 0.999
    Adam.eps = 1e-6

    glue_inputs.dataset_name = '{dataset_name}'
    glue_inputs.batch_size = {batch_size}
    glue_inputs.tokenizer = @bert_tokenizer
    """)

    trainer = Trainer(
        model,
        loss_fn,
        optimizer=trax.optimizers.Adam,
        lr_schedule=trax.lr_schedules.MultifactorSchedule,
        inputs=glue_inputs,
        output_dir=output_dir,
        random_seed=None,
        n_devices=None,  # Use all available.
        checkpoints_at=None,
        nontrainable_param_map=None,
        metrics=metrics,
        id_to_mask=None,
        checkpoint_lowest=None,
        checkpoint_highest=None,
    )

    trainer.log_step('Starting training using %d devices' % trainer.n_devices)
    trainer.print_n_weights()

    # Run a single step before saving the gin config, then finish the warmup
    # and run the decay phase.
    trainer.train_epoch(n_steps=1, n_eval_steps=10)
    trainer.save_gin()
    trainer.train_epoch(n_steps=warmup_steps - 1, n_eval_steps=10)
    trainer.train_epoch(n_steps=cooldown_steps, n_eval_steps=10)

    trainer.log_step('Training done')

    # Evaluation
    # pylint: disable=protected-access
    def my_jit(forward, n_devices):
        """Returns a JIT-compiled forward function running on n_devices."""
        model_predict = trax.layers.base._accelerate(forward, n_devices)
        if n_devices == 1:

            def predict1(x, weights, state):
                """Single-device prediction; the updated state is discarded."""
                res, _ = model_predict(x,
                                       weights,
                                       state,
                                       rng=jax.random.PRNGKey(0))
                return res

            return predict1

        def predict(x, weights, state):
            """Predict function jited and parallelized as requested."""
            # One copy of the RNG key per device: the leading axis follows
            # n_devices, matching how the inputs are split by
            # reshape_by_device, instead of assuming a fixed device count.
            rngs = np.broadcast_to(jax.random.PRNGKey(0)[None, :],
                                   (n_devices, 2))
            res, _ = trax.layers.base._combine_devices(
                model_predict(
                    trax.layers.base.reshape_by_device(x, n_devices), weights,
                    state, rngs))
            return res

        return predict

    fwd = functools.partial(my_jit(trainer._model_predict_eval.pure_fn,
                                   trainer._n_devices),
                            weights=trainer._opt_state[0][0],
                            state=trainer._model_state[0])

    def run_model(stream):
        """Run forward pass on a dataset.

        Returns `(outputs, labels, indices)`, each concatenated over all
        batches of `stream` along axis 0.
        """
        all_out = []
        all_idx = []
        all_labels = []
        for input_ids, type_ids, idx, labels in stream:
            # Pad the batch up to a multiple of the device count so it can be
            # split evenly; the padded rows are dropped from the output below.
            remainder = labels.shape[0] % trainer._n_devices
            if remainder:
                pad_amount = trainer._n_devices - remainder
                input_ids = onp.pad(input_ids, ((0, pad_amount), (0, 0)),
                                    mode='constant')
                type_ids = onp.pad(type_ids, ((0, pad_amount), (0, 0)),
                                   mode='constant')
                padded_idx = onp.pad(idx, ((0, pad_amount), ), mode='constant')
            else:
                padded_idx = idx
            out = onp.array(fwd((input_ids, type_ids, padded_idx)))
            if remainder:
                out = out[:-pad_amount]
            all_out.append(out)
            all_idx.append(idx)
            all_labels.append(labels)
        all_out = onp.concatenate(all_out, axis=0)
        all_idx = onp.concatenate(all_idx, axis=0)
        all_labels = onp.concatenate(all_labels, axis=0)

        return all_out, all_labels, all_idx

    eval_metrics = {}
    if is_regression_task:
        eval_metrics['pearsonr'] = get_pearsonr
    else:
        eval_metrics['accuracy'] = get_accuracy

    if dataset_name == 'glue/cola':
        eval_metrics['mcc'] = get_mcc
    elif dataset_name in ('glue/mrpc', 'glue/qqp'):
        eval_metrics['f1_accuracy_mean'] = get_f1_accuracy_mean

    preds_labels_idxs = [
        run_model(stream)
        for stream in trainer._inputs.extra_streams(trainer._n_devices)
    ]

    def report_metrics(f, prefix, guess, gold):
        """Writes every eval metric to file `f` and to the trainer log."""
        for name, fn in sorted(eval_metrics.items()):
            val = fn(guess, gold)
            line = f'{prefix}{name} = {val:.06f}\n'
            f.write(line)
            trainer.log_step(line)

    def write_predictions(path, idxs, values):
        """Writes an 'index<TAB>prediction' TSV file to `path`."""
        with tf.io.gfile.GFile(path, 'w') as f:
            f.write('index\tprediction\n')
            for idx, val in zip(idxs, values):
                f.write(f'{idx}\t{val}\n')

    # Log results on development data
    eval_results_path = os.path.join(trainer._output_dir, 'eval_results.txt')
    with tf.io.gfile.GFile(eval_results_path, 'w') as f:
        guess, gold, _ = preds_labels_idxs[0]
        if is_regression_task:
            guess = guess[:, 0]
        else:
            guess = guess.argmax(-1)
        report_metrics(f, 'eval_', guess, gold)

        if dataset_name == 'glue/mnli':
            guess, gold, _ = preds_labels_idxs[1]
            guess = guess.argmax(-1)
            report_metrics(f, 'eval_mismatched_', guess, gold)

        f.write(f'global_step = {trainer.step}\n')

    # Write predictions for test data
    path_map = {
        'glue/cola': 'CoLA.tsv',
        'glue/mrpc': 'MRPC.tsv',
        'glue/qqp': 'QQP.tsv',
        'glue/sst2': 'SST-2.tsv',
        'glue/mnli': 'MNLI-mm.tsv',
        'glue/qnli': 'QNLI.tsv',
        'glue/rte': 'RTE.tsv',
        # No eval on WNLI for now. BERT accuracy on WNLI is below baseline, unless
        # special training recipe is used.
        # 'glue/wnli': 'WNLI.tsv',
    }

    if dataset_name == 'glue/stsb':
        test_results_path = os.path.join(trainer._output_dir, 'STS-B.tsv')
        test_out, _, test_idxs = preds_labels_idxs[-1]
        write_predictions(test_results_path, test_idxs,
                          (f'{val:.06f}' for val in test_out[:, 0]))
    elif dataset_name in path_map:
        if dataset_name in ('glue/cola', 'glue/mrpc', 'glue/qqp', 'glue/sst2'):
            label_set = ['0', '1']
        elif dataset_name in ('glue/qnli', 'glue/rte'):
            label_set = ['entailment', 'not_entailment']
        elif dataset_name == 'glue/mnli':
            label_set = ['entailment', 'neutral', 'contradiction']
        else:
            # Explicit raise so the guard survives `python -O`.
            raise AssertionError(f'Unexpected dataset_name {dataset_name}')

        test_results_path = os.path.join(trainer._output_dir,
                                         path_map[dataset_name])

        test_out, _, test_idxs = preds_labels_idxs[-1]
        write_predictions(test_results_path, test_idxs,
                          (label_set[val] for val in test_out.argmax(-1)))

        trainer.log_step(f'Predictions written to {test_results_path}')

        if dataset_name == 'glue/mnli':
            # The last stream went to MNLI-mm.tsv above; the one before it
            # goes to MNLI-m.tsv.
            test_results_path = os.path.join(trainer._output_dir, 'MNLI-m.tsv')
            test_out, _, test_idxs = preds_labels_idxs[-2]
            write_predictions(test_results_path, test_idxs,
                              (label_set[val] for val in test_out.argmax(-1)))
            trainer.log_step(f'Predictions written to {test_results_path}')

    return trainer, preds_labels_idxs