def eval_driver(params, checkpoint_dir, model_dir, input_fn, model_fn):
  """Evaluates each new checkpoint that appears in `checkpoint_dir`.

  Args:
    params: hyper-parameters.
    checkpoint_dir: where the checkpoints live and where `done` is found.
    model_dir: where to dump eval TensorBoard logs.
    input_fn: for `Estimator`.
    model_fn: for `Estimator`.
  """
  estimator = utils.create_estimator(params, model_dir, model_fn)
  eval_hooks = []

  prev_checkpoint = None
  num_mins_waited = 0
  while True:
    curr_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if curr_checkpoint is not None and curr_checkpoint != prev_checkpoint:
      tf.logging.info('Eval at {0}'.format(curr_checkpoint))
      checkpoint_number = int(curr_checkpoint.split('/')[-1].split('-')[-1])
      # Once training has passed `start_moving_average` steps, rebuild the
      # estimator so that eval uses the moving-average variables.
      if (checkpoint_number >= params.start_moving_average and
          not params.moving_average):
        tf.logging.info('From now on, use moving average for eval')
        del estimator
        params.set_hparam('moving_average', True)
        estimator = utils.create_estimator(params, model_dir, model_fn)
      try:
        results = estimator.evaluate(input_fn=input_fn,
                                     steps=params.num_eval_steps,
                                     hooks=eval_hooks,
                                     checkpoint_path=curr_checkpoint)
        log_ppl = results['log_ppl/{0}'.format(params.task_mode)]
        ppl = np.exp(log_ppl)
        tf.logging.info('Eval step={0} {1}_ppl={2:<.2f}'.format(
            results['global_step'], params.task_mode, ppl))
      except Warning as w:
        tf.logging.info(w)
      except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
        tf.logging.info('Eval failed. Retrying...')
        continue
      prev_checkpoint = curr_checkpoint
      num_mins_waited = 0
    elif gfile.Exists(os.path.join(checkpoint_dir, 'done')):
      # The trainer signalled completion; stop evaluating.
      tf.logging.info('Finished')
      sys.exit(0)
    else:
      # No new checkpoint yet: poll every 30 seconds, give up after 2 hours.
      time.sleep(30)
      num_mins_waited += 0.5
      tf.logging.info('Waited {0:<.1f} mins'.format(num_mins_waited))
      if num_mins_waited >= 120:
        sys.exit(0)

def train_driver(params, model_dir, input_fn, model_fn):
  """Trains the model, retrying on failure; writes `done` when finished."""
  estimator = utils.create_estimator(params, model_dir, model_fn)
  tf.logging.info('Train for {0} steps.'.format(params.num_train_steps))

  for trial_id in range(MAX_RETRIES):
    try:
      estimator.train(input_fn=input_fn, max_steps=params.num_train_steps)
      break
    except Warning as w:
      tf.logging.info(w)
    except Exception as e:  # pylint: disable=broad-except
      tf.logging.info(e)
      traceback.print_exc()
      tf.logging.info('Failed {0} times. Retry!'.format(trial_id + 1))
      continue
  else:
    # Runs only if the loop exhausted all retries without a `break`.
    with gfile.GFile(os.path.join(model_dir, 'done'), 'w') as fout:
      fout.write('Job failed after {0} retries'.format(MAX_RETRIES))
      fout.flush()
    sys.exit(1)

  with gfile.GFile(os.path.join(model_dir, 'done'), 'w') as fout:
    fout.write('Job finished')
    fout.flush()
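

# A minimal sketch of how these drivers could be wired together from a main
# entry point. `FLAGS`, `build_params`, `get_input_fn`, and `get_model_fn` are
# hypothetical helpers assumed here for illustration; they are not defined in
# this module.
def main(unused_argv):
  params = build_params(FLAGS)      # hypothetical: HParams built from flags.
  input_fn = get_input_fn(params)   # hypothetical: input pipeline factory.
  model_fn = get_model_fn(params)   # hypothetical: model_fn factory.
  if params.task_mode == 'train':
    train_driver(params, FLAGS.output_dir, input_fn, model_fn)
  else:
    # Eval polls the trainer's directory for checkpoints and the `done` file,
    # and writes its own TensorBoard logs into a per-split subdirectory.
    eval_driver(params, FLAGS.output_dir,
                os.path.join(FLAGS.output_dir, params.task_mode),
                input_fn, model_fn)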