def model_fn(features, labels, mode):
  """The model function for creating an Estimator."""
  del labels
  # Count the non-empty input references (end index > start index) and
  # export the count as a training summary.
  input_count = tf.reduce_sum(
      tf.to_int32(tf.greater(features["input_refs"][:, :, 1],
                             features["input_refs"][:, :, 0])))
  tf.summary.scalar("input_count", input_count)
  loss_dict, pred_dict, areas = seq2act_model.core_graph(
      features, hparams, mode, compute_additional_loss_fn)
  if mode == tf.estimator.ModeKeys.PREDICT:
    pred_dict["sequences"] = decode_sequence(
        features, areas, hparams, decode_length,
        post_processing=FLAGS.post_processing)
    return tf.estimator.EstimatorSpec(mode, predictions=pred_dict)
  elif mode == tf.estimator.ModeKeys.EVAL:
    metrics = {}
    _eval(metrics, pred_dict, loss_dict, features, areas,
          compute_seq_accuracy, hparams,
          metric_types=FLAGS.metric_types.split(","),
          decode_length=decode_length)
    if compute_additional_metric_fn:
      compute_additional_metric_fn(metrics, pred_dict, features)
    return tf.estimator.EstimatorSpec(mode, loss=loss_dict["total_loss"],
                                      eval_metric_ops=metrics)
  else:
    assert mode == tf.estimator.ModeKeys.TRAIN
    loss = loss_dict["total_loss"]
    # Summarize each individual loss component, skipping the aggregate
    # total and the per-step loss collections.
    for loss_name in loss_dict:
      if loss_name == "total_loss":
        continue
      if loss_name.endswith("losses"):
        continue
      tf.summary.scalar(loss_name, loss_dict[loss_name])
    # Build the learning rate by multiplying the factors named in the
    # schedule string, e.g. "constant*linear_warmup*rsqrt_decay".
    step_num = tf.to_float(tf.train.get_global_step())
    schedule_string = hparams.learning_rate_schedule
    names = schedule_string.split("*")
    names = [name.strip() for name in names if name.strip()]
    ret = tf.constant(1.0)
    for name in names:
      ret *= learning_rate.learning_rate_factor(name, step_num, hparams)
    train_op = optimize.optimize(loss, ret, hparams)
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
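
# --- Usage sketch (not from the source): wiring a model_fn like the one
# above into a TF 1.x Estimator. The checkpoint directory, step count,
# and the stub `train_input_fn` below are hypothetical. The closure
# variables model_fn reads (hparams, FLAGS, decode_length, and the
# optional compute_additional_loss_fn / compute_additional_metric_fn
# callbacks) must already be bound in the enclosing scope.
def train_input_fn():
  # Dummy features with the shape model_fn expects: input_refs is
  # [batch, seq_len, 2]. A real pipeline would parse TFRecords here.
  return {"input_refs": tf.zeros([8, 16, 2], dtype=tf.int32)}, None

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir="/tmp/seq2act_model")  # hypothetical output directory
estimator.train(input_fn=train_input_fn, steps=1000)
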
def learning_rate_schedule(hparams):
  """Learning rate schedule based on hparams."""
  mlperf_log.transformer_print(key=mlperf_log.OPT_LR, deferred=True)
  mlperf_log.transformer_print(key=mlperf_log.OPT_LR_WARMUP_STEPS,
                               value=hparams.learning_rate_warmup_steps)
  step_num = _global_step(hparams)
  # Simulate pretraining the encoder, decoder and posterior with the same
  # learning rate schedule, and then restoring the parameters. Using
  # `warm_start_from` is not compatible with actnorm DDI on TPUs.
  step_num = tf.where(step_num < hparams.kl_startup_steps,
                      step_num,
                      step_num - hparams.kl_startup_steps)
  schedule_string = hparams.learning_rate_schedule
  names = schedule_string.split("*")
  names = [name.strip() for name in names if name.strip()]
  ret = tf.constant(1.0)
  for name in names:
    ret *= lr.learning_rate_factor(name, step_num, hparams)
  return ret
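
# --- Illustrative sketch (not from the source): how a schedule string
# such as "linear_warmup*rsqrt_decay" decomposes into multiplicative
# factors. These NumPy definitions are simplified stand-ins for the
# factors produced by tensor2tensor's learning_rate.learning_rate_factor;
# the warmup-step constant is an assumed example value.
import numpy as np

def linear_warmup(step, warmup_steps=16000.0):
  # Ramp linearly from 0 to 1 over the warmup period, then hold at 1.
  return np.minimum(1.0, step / warmup_steps)

def rsqrt_decay(step, warmup_steps=16000.0):
  # Stay flat during warmup, then decay as 1/sqrt(step).
  return 1.0 / np.sqrt(np.maximum(step, warmup_steps))

def schedule_factor(step, schedule_string="linear_warmup*rsqrt_decay"):
  # Mirrors the parsing above: split on "*", strip whitespace, and
  # multiply the named factors together.
  factors = {"linear_warmup": linear_warmup, "rsqrt_decay": rsqrt_decay}
  ret = 1.0
  for name in schedule_string.split("*"):
    name = name.strip()
    if name:
      ret *= factors[name](step)
  return ret

# Early steps are scaled down by the warmup factor; after warmup the
# rate decays smoothly with the inverse square root of the step.
print(schedule_factor(100.0), schedule_factor(16000.0), schedule_factor(64000.0))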