def train_fn(hparams, num_workers):
  """Copy of train function from estimator.py."""
  # TODO: Merge improvements into the original.
  # pylint: disable=protected-access
  hparams.tgt_sos_id, hparams.tgt_eos_id = nmt_estimator._get_tgt_sos_eos_id(
      hparams)
  model_fn = nmt_estimator.make_model_fn(hparams)

  def print_log():
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=0)
    mlperf_log.gnmt_print(
        key=mlperf_log.INPUT_SIZE, value=hparams.num_examples_per_epoch)

  if hparams.use_tpu_low_level_api:
    # Low-level API path: drive the TPU directly through a train runner
    # instead of going through TPUEstimator.
    runner = create_train_runner(hparams, num_workers)
    mlperf_log.gnmt_print(key=mlperf_log.RUN_START)
    input_fn = DistributedPipeline(hparams, num_workers)
    runner.initialize(input_fn, {})
    runner.build_model(model_fn, {})
    print_log()
    runner.train(0, hparams.num_train_steps)
    return 0.0

  # cluster = tf.contrib.cluster_resolver.TPUClusterResolver(hparams.tpu_name)
  # cluster_spec = cluster.cluster_spec()
  # print('cluster_spec: %s' % cluster_spec)
  # num_workers = cluster_spec.num_tasks('tpu_worker')
  # print('num_workers: %s' % num_workers)

  pipeline = DistributedPipeline(hparams, num_workers)
  print_log()

  if hparams.use_tpu:
    run_config = nmt_estimator._get_tpu_run_config(hparams, True)
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        config=run_config,
        use_tpu=hparams.use_tpu,
        train_batch_size=hparams.batch_size,
        eval_batch_size=hparams.batch_size,
        predict_batch_size=hparams.infer_batch_size)
  else:
    raise ValueError("Distributed input pipeline only supported on TPUs.")

  # The pipeline doubles as a SessionRunHook so it can coordinate the
  # distributed input workers around session creation.
  hooks = [pipeline]
  if hparams.use_async_checkpoint:
    hooks.append(
        async_checkpoint.AsyncCheckpointSaverHook(
            checkpoint_dir=hparams.out_dir,
            save_steps=int(
                hparams.num_examples_per_epoch / hparams.batch_size)))

  estimator.train(
      input_fn=pipeline, max_steps=hparams.num_train_steps, hooks=hooks)
  # Return value is not used.
  return 0.0
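# `train_fn` above passes the same DistributedPipeline object as both the
# `input_fn` and a training hook. A minimal sketch of the dual contract this
# implies is below; `_PipelineContractSketch` and its body are hypothetical
# illustrations, not the repo's actual DistributedPipeline, and it assumes
# the module-level `import tensorflow as tf`.
class _PipelineContractSketch(tf.train.SessionRunHook):
  """Callable input_fn that is also a SessionRunHook.

  TPUEstimator calls the object to build the input graph, while the hook
  methods let the same object coordinate distributed input workers around
  session creation.
  """

  def __init__(self, hparams, num_workers):
    self._hparams = hparams
    self._num_workers = num_workers

  def __call__(self, params):
    # TPUEstimator supplies the per-shard batch size via params.
    batch_size = params["batch_size"]
    # Placeholder dataset; the real pipeline serves preprocessed NMT data.
    dataset = tf.data.Dataset.from_tensors(tf.zeros([10], tf.int32))
    return dataset.repeat().batch(batch_size, drop_remainder=True)

  def after_create_session(self, session, coord):
    # Hook point where the distributed input workers would be started.
    pass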
def train_and_eval_with_low_level_api(hparams, num_workers):
  """Trains and evaluates per epoch using the low-level TPU runner API."""
  # pylint: disable=protected-access
  hparams.tgt_sos_id, hparams.tgt_eos_id = 1, 2
  model_fn = nmt_estimator.make_model_fn(hparams)
  train_runner = create_train_runner(hparams, num_workers)
  eval_runner = nmt_estimator.create_eval_runner(hparams, model_fn)
  mlperf_log.gnmt_print(key=mlperf_log.RUN_START)
  train_input_fn = DistributedPipeline(hparams, num_workers)
  train_runner.initialize(train_input_fn, {})
  train_runner.build_model(model_fn, {})

  eval_input_fn = nmt_estimator.make_input_fn(
      hparams, tf.contrib.learn.ModeKeys.INFER)
  # Inference batches are split evenly across the TPU shards.
  params = {
      "infer_batch_size": int(hparams.infer_batch_size / hparams.num_shards)
  }
  eval_runner.initialize(eval_input_fn, params)
  eval_runner.build_model(model_fn, params)

  score = 0.0
  mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
  mlperf_log.gnmt_print(key=mlperf_log.EVAL_TARGET, value=hparams.target_bleu)
  current_step = 0
  for i in range(hparams.max_train_epochs):
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=i)
    tf.logging.info("Start training epoch %d", i)
    mlperf_log.gnmt_print(
        key=mlperf_log.INPUT_SIZE, value=hparams.num_examples_per_epoch)
    # Each epoch is run as two chunks of steps_per_epoch // 2 steps; an odd
    # trailing step, if any, is dropped.
    steps_per_epoch = int(hparams.num_examples_per_epoch / hparams.batch_size)
    train_runner.train(current_step, steps_per_epoch // 2)
    current_step += steps_per_epoch // 2
    train_runner.train(current_step, steps_per_epoch // 2)
    current_step += steps_per_epoch // 2
    mlperf_log.gnmt_print(
        key=mlperf_log.TRAIN_CHECKPOINT, value=("Under " + hparams.out_dir))
    tf.logging.info("End training epoch %d", i)

    mlperf_log.gnmt_print(key=mlperf_log.EVAL_START)
    predictions = list(eval_runner.predict())
    score = nmt_estimator.get_metric(hparams, predictions, current_step)
    tf.logging.info("Score after epoch %d: %f", i, score)
    mlperf_log.gnmt_print(
        key=mlperf_log.EVAL_ACCURACY, value={"value": score, "epoch": i})
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_STOP, value=i)
    # Early exit on reaching the target BLEU is currently disabled, so the
    # loop always runs for max_train_epochs and reports success=False.
    # if score >= hparams.target_bleu:
    #   mlperf_log.gnmt_print(mlperf_log.RUN_STOP, {"success": True})
    #   return score

  mlperf_log.gnmt_print(key=mlperf_log.RUN_STOP, value={"success": False})
  return score
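# A minimal sketch of the runner interface that
# `train_and_eval_with_low_level_api` relies on, inferred from the calls
# above. Method names match the usage in this file; the class itself and its
# docstrings are illustrative assumptions, not the actual runners returned by
# create_train_runner / create_eval_runner.
class _RunnerInterfaceSketch(object):
  """Assumed surface of the low-level TPU train/eval runners."""

  def initialize(self, input_fn, params):
    """Builds the input pipeline / infeed from input_fn."""

  def build_model(self, model_fn, params):
    """Compiles the TPU program for model_fn."""

  def train(self, start_step, num_steps):
    """Runs num_steps training steps starting at global step start_step."""

  def predict(self):
    """Yields per-example predictions over the eval dataset."""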
def train_and_eval_fn(hparams, num_workers):
  """Trains and evaluates per epoch using TPUEstimator."""
  # pylint: disable=protected-access
  mlperf_log.gnmt_print(key=mlperf_log.RUN_START)
  hparams.tgt_sos_id, hparams.tgt_eos_id = 1, 2
  model_fn = nmt_estimator.make_model_fn(hparams)
  pipeline = DistributedPipeline(hparams, num_workers)
  run_config = nmt_estimator._get_tpu_run_config(hparams)
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      use_tpu=hparams.use_tpu,
      train_batch_size=hparams.batch_size,
      eval_batch_size=hparams.batch_size,
      predict_batch_size=hparams.infer_batch_size)

  score = 0.0
  mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
  mlperf_log.gnmt_print(key=mlperf_log.EVAL_TARGET, value=hparams.target_bleu)
  for i in range(hparams.max_train_epochs):
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=i)
    tf.logging.info("Start training epoch %d", i)
    mlperf_log.gnmt_print(
        key=mlperf_log.INPUT_SIZE, value=hparams.num_examples_per_epoch)
    steps_per_epoch = int(hparams.num_examples_per_epoch / hparams.batch_size)
    # max_steps is cumulative, so each estimator.train call advances the
    # global step by exactly one epoch.
    max_steps = steps_per_epoch * (i + 1)
    estimator.train(input_fn=pipeline, max_steps=max_steps, hooks=[pipeline])
    mlperf_log.gnmt_print(
        key=mlperf_log.TRAIN_CHECKPOINT, value=("Under " + hparams.out_dir))
    tf.logging.info("End training epoch %d", i)

    mlperf_log.gnmt_print(key=mlperf_log.EVAL_START)
    score = nmt_estimator.get_metric_from_estimator(hparams, estimator)
    tf.logging.info("Score after epoch %d: %f", i, score)
    mlperf_log.gnmt_print(
        key=mlperf_log.EVAL_ACCURACY, value={"value": score, "epoch": i})
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_STOP, value=i)
    if score >= hparams.target_bleu:
      mlperf_log.gnmt_print(key=mlperf_log.RUN_STOP, value={"success": True})
      return score

  mlperf_log.gnmt_print(key=mlperf_log.RUN_STOP, value={"success": False})
  return score
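# An illustrative driver showing the hparams fields the three functions above
# actually read. The values are hypothetical placeholders (the real hparams
# come from the NMT flag/hparams machinery), and `_example_run` itself is not
# part of the original module.
def _example_run():
  hparams = tf.contrib.training.HParams(
      use_tpu=True,
      use_tpu_low_level_api=False,
      use_async_checkpoint=False,
      batch_size=4096,
      infer_batch_size=512,
      num_shards=16,
      num_examples_per_epoch=3442299,
      num_train_steps=840,
      max_train_epochs=2,
      target_bleu=21.8,
      out_dir="/tmp/gnmt")
  return train_and_eval_fn(hparams, num_workers=1)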