Example #1
0
 def update_loss(y, logits):
     """Feed the unreduced span/cross-entropy loss for a batch into `loss_metric`.

     Args:
       y: Ground-truth targets, passed straight to the loss function.
       logits: Model outputs, passed straight to the loss function.

     Returns:
       Whatever `loss_metric` returns when called on the loss values.
       `loss_metric` comes from the enclosing scope; presumably a Keras metric
       such as `tf.keras.metrics.Mean`, to be confirmed against the caller.
     """
     # Reduction.NONE keeps the loss unreduced so that `loss_metric` does the
     # aggregation itself.
     per_example_loss = modeling.SpanOrCrossEntropyLoss(
         reduction=tf.keras.losses.Reduction.NONE)(y, logits)
     return loss_metric(per_example_loss)
Example #2
0
def fit(model,
        strategy,
        train_dataset,
        model_dir,
        init_checkpoint_path=None,
        evaluate_fn=None,
        learning_rate=1e-5,
        learning_rate_polynomial_decay_rate=1.,
        weight_decay_rate=1e-1,
        num_warmup_steps=5000,
        num_decay_steps=51000,
        num_epochs=6):
    """Train and evaluate.

    Compiles `model` under `strategy` with an `AdamWeightDecay` optimizer.
    The learning rate follows `nlp_optimization.WarmUp` over
    `num_warmup_steps` steps. After warm-up it follows a polynomial decay from
    `learning_rate` down to 0 over `num_decay_steps` steps.

    Training runs one epoch at a time until `num_epochs` checkpoints exist in
    `model_dir`, so an interrupted run resumes where it stopped. If no
    checkpoint exists yet and `init_checkpoint_path` is given, only
    `model.encoder` is initialised from that checkpoint.

    Args:
      model: Keras model to train. It must expose an `encoder` attribute when
        `init_checkpoint_path` is used.
      strategy: Distribution strategy; the optimizer, model compilation and
        checkpoint restore happen inside `strategy.scope()`.
      train_dataset: Input passed to `model.fit` each epoch. Presumably a
        `tf.data.Dataset` of (features, labels); confirm against callers.
      model_dir: Directory for checkpoints, TensorBoard logs (training logs at
        the top level, validation logs in `val/`) and the `export/`
        SavedModel.
      init_checkpoint_path: Optional checkpoint used to initialise the
        encoder when `model_dir` has no checkpoint yet.
      evaluate_fn: Optional zero-argument callable run after each epoch. It
        must return a dict of scalar metrics containing an 'exact_match' key.
      learning_rate: Peak learning rate.
      learning_rate_polynomial_decay_rate: `power` of the polynomial decay.
      weight_decay_rate: Decoupled weight-decay rate. LayerNorm and bias
        variables are excluded.
      num_warmup_steps: Number of warm-up steps.
      num_decay_steps: Number of steps over which the rate decays to 0.
      num_epochs: Total number of epochs (and checkpoints) to produce.
    """
    hparams = dict(learning_rate=learning_rate,
                   learning_rate_polynomial_decay_rate=(
                       learning_rate_polynomial_decay_rate),
                   num_decay_steps=num_decay_steps,
                   num_warmup_steps=num_warmup_steps,
                   num_epochs=num_epochs,
                   weight_decay_rate=weight_decay_rate,
                   dropout_rate=FLAGS.dropout_rate,
                   attention_dropout_rate=FLAGS.attention_dropout_rate,
                   label_smoothing=FLAGS.label_smoothing)
    logging.info(hparams)
    learning_rate_schedule = nlp_optimization.WarmUp(
        learning_rate,
        tf.keras.optimizers.schedules.PolynomialDecay(
            learning_rate,
            num_decay_steps,
            end_learning_rate=0.,
            power=learning_rate_polynomial_decay_rate), num_warmup_steps)
    with strategy.scope():
        optimizer = nlp_optimization.AdamWeightDecay(
            learning_rate_schedule,
            weight_decay_rate=weight_decay_rate,
            epsilon=1e-6,
            exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
        model.compile(optimizer, loss=modeling.SpanOrCrossEntropyLoss())

    def init_fn(init_checkpoint_path):
        # Restore only the encoder weights (e.g. a pretrained backbone); the
        # rest of the model keeps its fresh initialisation.
        ckpt = tf.train.Checkpoint(encoder=model.encoder)
        ckpt.restore(init_checkpoint_path).assert_existing_objects_matched()

    with worker_context():
        # max_to_keep=None keeps every checkpoint, so the number of saved
        # checkpoints equals the number of completed epochs (used below to
        # resume).
        ckpt_manager = tf.train.CheckpointManager(
            tf.train.Checkpoint(model=model, optimizer=optimizer),
            model_dir,
            max_to_keep=None,
            init_fn=(functools.partial(init_fn, init_checkpoint_path)
                     if init_checkpoint_path else None))
        with strategy.scope():
            ckpt_manager.restore_or_initialize()
        val_summary_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'val'))
        # NOTE(review): this is not persisted across restarts. After resuming,
        # the first evaluated epoch is always exported, even if an earlier run
        # exported a better model.
        best_exact_match = 0.
        for epoch in range(len(ckpt_manager.checkpoints), num_epochs):
            # Train exactly one epoch. Passing initial_epoch/epochs gives Keras
            # (and the TensorBoard callback) the real epoch index; otherwise
            # every call is epoch 0 and all epoch summaries land on step 0.
            model.fit(train_dataset,
                      initial_epoch=epoch,
                      epochs=epoch + 1,
                      callbacks=[
                          tf.keras.callbacks.TensorBoard(model_dir,
                                                         write_graph=False),
                      ])
            ckpt_path = ckpt_manager.save()
            if evaluate_fn is None:
                continue
            metrics = evaluate_fn()
            logging.info('Epoch %d: %s', epoch + 1, metrics)
            if best_exact_match < metrics['exact_match']:
                best_exact_match = metrics['exact_match']
                model.save(os.path.join(model_dir, 'export'),
                           include_optimizer=False)
                logging.info('Exporting %s as SavedModel.', ckpt_path)
            with val_summary_writer.as_default():
                for name, data in metrics.items():
                    tf.summary.scalar(name, data, epoch + 1)
            # Make this epoch's validation scalars visible immediately instead
            # of waiting for the writer's periodic flush.
            val_summary_writer.flush()