# TF1-style runner: train or continuously evaluate via TPUEstimator.
# load_runner_config, model_fn_builder, and input_fn_reader are project-local
# helpers assumed to be defined elsewhere in this package.
from absl import logging
import tensorflow.compat.v1 as tf


def main(_):
  runner_config = load_runner_config()

  if FLAGS.output_dir:
    tf.gfile.MakeDirs(FLAGS.output_dir)

  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.estimator.tpu.RunConfig(
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=runner_config["save_checkpoints_steps"],
      keep_checkpoint_max=20,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=runner_config["iterations_per_loop"],
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  model_fn = model_fn_builder(runner_config)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  batch_size = runner_config["batch_size"]
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size)

  if FLAGS.runner_mode == "train":
    train_input_fn = input_fn_reader.create_input_fn(
        runner_config=runner_config,
        mode=tf.estimator.ModeKeys.TRAIN,
        drop_remainder=True)
    estimator.train(
        input_fn=train_input_fn, max_steps=runner_config["train_steps"])
  elif FLAGS.runner_mode == "eval":
    # TPU needs fixed shapes, so if the last batch is smaller, we drop it.
    eval_input_fn = input_fn_reader.create_input_fn(
        runner_config=runner_config,
        mode=tf.estimator.ModeKeys.EVAL,
        drop_remainder=True)
    # Re-evaluate each time a new checkpoint lands in output_dir; stop after
    # 600 seconds without one.
    for _ in tf.train.checkpoints_iterator(FLAGS.output_dir, timeout=600):
      result = estimator.evaluate(input_fn=eval_input_fn)
      for key in sorted(result):
        logging.info(" %s = %s", key, str(result[key]))
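# A minimal sketch of the flag definitions and entry point this TF1 runner
# expects. The flag names mirror the FLAGS references in main above; the
# defaults, help strings, and use of absl flags here are assumptions, not the
# original module's definitions.
from absl import flags

flags.DEFINE_string("output_dir", None, "Directory for checkpoints and eval.")
flags.DEFINE_string("master", None, "TensorFlow master URL.")
flags.DEFINE_integer("num_tpu_cores", 8, "Number of TPU cores to shard over.")
flags.DEFINE_bool("use_tpu", False, "Whether to run on a TPU.")
flags.DEFINE_string("runner_mode", "train", "One of 'train' or 'eval'.")

FLAGS = flags.FLAGS

if __name__ == "__main__":
  tf.app.run(main)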
# TF2-style runner: the same training path rewritten as an eager custom
# training loop. compute_loss, load_runner_config, model_fn_builder, and
# input_fn_reader are project-local helpers assumed to be defined elsewhere.
from absl import logging
import tensorflow as tf
from tensorflow import estimator as tf_estimator


def main(_):
  runner_config = load_runner_config()

  if FLAGS.output_dir:
    tf.io.gfile.makedirs(FLAGS.output_dir)

  train_model = model_fn_builder(runner_config, tf_estimator.ModeKeys.TRAIN)
  optimizer = tf.keras.optimizers.Adam()
  train_input_fn = input_fn_reader.create_input_fn(
      runner_config=runner_config,
      mode=tf_estimator.ModeKeys.TRAIN,
      drop_remainder=True)
  params = {"batch_size": runner_config["batch_size"]}
  train_ds = train_input_fn(params)
  train_loss = tf.keras.metrics.Mean(name="train_loss")

  @tf.function
  def train_step(features):
    with tf.GradientTape() as tape:
      logits = train_model(features["projection"], features["seq_length"])
      loss = compute_loss(logits, features["label"],
                          runner_config["model_config"],
                          tf_estimator.ModeKeys.TRAIN)
    gradients = tape.gradient(loss, train_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, train_model.trainable_variables))
    train_loss(loss)

  for epoch in range(1):
    train_loss.reset_states()
    for features in train_ds:
      train_step(features)
      step = optimizer.iterations.numpy()
      if step % 100 == 0:
        logging.info("Running step %s in epoch %s", step, epoch)
        logging.info("Training loss: %s, epoch: %s, step: %s",
                     round(train_loss.result().numpy(), 4), epoch, step)
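# Hedged sketch: an absl entry point for the TF2 runner, plus a note on
# checkpointing, which the loop above omits (the TF1 runner gets it from
# RunConfig). The flag definition, checkpoint cadence, and max_to_keep value
# below are assumptions, not the original module's code.
from absl import app
from absl import flags

flags.DEFINE_string("output_dir", None, "Directory for model artifacts.")
FLAGS = flags.FLAGS

# Inside main, after building train_model and optimizer, one option is:
#   ckpt = tf.train.Checkpoint(model=train_model, optimizer=optimizer)
#   manager = tf.train.CheckpointManager(ckpt, FLAGS.output_dir,
#                                        max_to_keep=20)
# then calling manager.save(checkpoint_number=step) on the logging cadence.

if __name__ == "__main__":
  app.run(main)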