def update_loss(y, logits):
  """Folds the batch loss into the module-level `loss_metric` accumulator.

  Builds the span-or-cross-entropy criterion with `Reduction.NONE` so the
  metric receives per-example losses rather than a pre-reduced scalar.

  Args:
    y: Ground-truth labels, as expected by `modeling.SpanOrCrossEntropyLoss`.
    logits: Model predictions paired with `y`.

  Returns:
    The result of updating `loss_metric` with the per-example losses.
  """
  criterion = modeling.SpanOrCrossEntropyLoss(
      reduction=tf.keras.losses.Reduction.NONE)
  per_example_loss = criterion(y, logits)
  return loss_metric(per_example_loss)
def fit(model,
        strategy,
        train_dataset,
        model_dir,
        init_checkpoint_path=None,
        evaluate_fn=None,
        learning_rate=1e-5,
        learning_rate_polynomial_decay_rate=1.,
        weight_decay_rate=1e-1,
        num_warmup_steps=5000,
        num_decay_steps=51000,
        num_epochs=6):
  """Train and evaluate."""
  # Log the effective hyperparameters (CLI flags included) for this run.
  hparams = {
      'learning_rate': learning_rate,
      'num_decay_steps': num_decay_steps,
      'num_warmup_steps': num_warmup_steps,
      'num_epochs': num_epochs,
      'weight_decay_rate': weight_decay_rate,
      'dropout_rate': FLAGS.dropout_rate,
      'attention_dropout_rate': FLAGS.attention_dropout_rate,
      'label_smoothing': FLAGS.label_smoothing,
  }
  logging.info(hparams)

  # Linear warmup followed by polynomial decay down to zero.
  lr_schedule = nlp_optimization.WarmUp(
      learning_rate,
      tf.keras.optimizers.schedules.PolynomialDecay(
          learning_rate,
          num_decay_steps,
          end_learning_rate=0.,
          power=learning_rate_polynomial_decay_rate),
      num_warmup_steps)

  # Optimizer and loss must be created under the distribution strategy scope.
  with strategy.scope():
    optimizer = nlp_optimization.AdamWeightDecay(
        lr_schedule,
        weight_decay_rate=weight_decay_rate,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
    model.compile(optimizer, loss=modeling.SpanOrCrossEntropyLoss())

  def _restore_encoder(path):
    # Warm-start only the encoder from a pretrained checkpoint; every object
    # present in that checkpoint must be matched.
    ckpt = tf.train.Checkpoint(encoder=model.encoder)
    ckpt.restore(path).assert_existing_objects_matched()

  with worker_context():
    # The init_fn runs only when no checkpoint exists yet in model_dir, so
    # resuming a run never clobbers trained weights with the init checkpoint.
    ckpt_manager = tf.train.CheckpointManager(
        tf.train.Checkpoint(model=model, optimizer=optimizer),
        model_dir,
        max_to_keep=None,
        init_fn=(functools.partial(_restore_encoder, init_checkpoint_path)
                 if init_checkpoint_path else None))
    with strategy.scope():
      ckpt_manager.restore_or_initialize()

    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(model_dir, 'val'))
    best_exact_match = 0.

    # One checkpoint is saved per epoch and none are deleted
    # (max_to_keep=None), so the checkpoint count doubles as the index of the
    # next epoch to run when resuming.
    for epoch in range(len(ckpt_manager.checkpoints), num_epochs):
      model.fit(
          train_dataset,
          callbacks=[
              tf.keras.callbacks.TensorBoard(model_dir, write_graph=False),
          ])
      saved_path = ckpt_manager.save()

      if evaluate_fn is None:
        continue
      metrics = evaluate_fn()
      logging.info('Epoch %d: %s', epoch + 1, metrics)

      # Export a SavedModel whenever validation exact-match improves.
      if metrics['exact_match'] > best_exact_match:
        best_exact_match = metrics['exact_match']
        model.save(os.path.join(model_dir, 'export'), include_optimizer=False)
        logging.info('Exporting %s as SavedModel.', saved_path)

      with val_summary_writer.as_default():
        for metric_name, metric_value in metrics.items():
          tf.summary.scalar(metric_name, metric_value, step=epoch + 1)