コード例 #1
0
ファイル: fixed.py プロジェクト: zqhfpjlswsqy/google-research
def eval_driver(params, checkpoint_dir, model_dir, input_fn, model_fn,
                max_wait_mins=120, poll_secs=30):
    """Evaluate every new checkpoint until training is done or stalls.

    Polls `checkpoint_dir` for its latest checkpoint. Each checkpoint not yet
    evaluated is evaluated once and its perplexity (exp of the
    `log_ppl/<task_mode>` metric) is logged. When the checkpoint number
    reaches `params.start_moving_average`, `params.moving_average` is set to
    True and the estimator is rebuilt so later evaluations use it.

    The process exits with status 0 (via `sys.exit`) either when a `done`
    file exists in `checkpoint_dir` and there is no unevaluated checkpoint,
    or after `max_wait_mins` minutes pass without a new checkpoint.

    Args:
        params: hyper-parameters. Reads `start_moving_average`,
            `moving_average`, `num_eval_steps` and `task_mode`; may set
            `moving_average` to True in place via `set_hparam`.
        checkpoint_dir: where the checkpoints live and where `done` is found.
        model_dir: where to dump eval TensorBoard logs.
        input_fn: for `Estimator`.
        model_fn: for `Estimator`.
        max_wait_mins: stop after this many minutes with no new checkpoint.
            Defaults to 120.
        poll_secs: seconds to sleep between polls for a new checkpoint, and
            before retrying a failed evaluation. Defaults to 30.
    """
    estimator = utils.create_estimator(params, model_dir, model_fn)
    eval_hooks = []

    prev_checkpoint = None
    num_mins_waited = 0
    while True:
        curr_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if curr_checkpoint is not None and curr_checkpoint != prev_checkpoint:
            tf.logging.info('Eval at {0}'.format(curr_checkpoint))
            # The checkpoint number is the integer after the last '-' in the
            # file name (presumably the global step, e.g. `model.ckpt-1000`).
            checkpoint_number = int(
                os.path.basename(curr_checkpoint).split('-')[-1])
            if (checkpoint_number >= params.start_moving_average
                    and not params.moving_average):
                # One-time switch: rebuild the estimator so it picks up the
                # updated `moving_average` hyper-parameter.
                tf.logging.info('From now on, use moving average for eval')
                del estimator
                params.set_hparam('moving_average', True)
                estimator = utils.create_estimator(params, model_dir, model_fn)
            try:
                results = estimator.evaluate(input_fn=input_fn,
                                             steps=params.num_eval_steps,
                                             hooks=eval_hooks,
                                             checkpoint_path=curr_checkpoint)
                log_ppl = results['log_ppl/{0}'.format(params.task_mode)]
                ppl = np.exp(log_ppl)
                tf.logging.info('Eval step={0} {1}_ppl={2:<.2f}'.format(
                    results['global_step'], params.task_mode, ppl))
            except Warning as w:
                # Non-fatal: log it and treat the checkpoint as evaluated.
                tf.logging.info(w)
            except Exception:  # pylint: disable=broad-except
                traceback.print_exc()
                tf.logging.info('Eval failed. Retrying...')
                # Back off before retrying so a persistent failure does not
                # become a busy loop re-running evaluation back to back.
                time.sleep(poll_secs)
                continue
            prev_checkpoint = curr_checkpoint
            num_mins_waited = 0
        elif gfile.Exists(os.path.join(checkpoint_dir, 'done')):
            # Only reached when there is no unevaluated checkpoint, so the
            # final checkpoint is evaluated before exiting.
            tf.logging.info('Finished')
            sys.exit(0)
        else:
            time.sleep(poll_secs)
            num_mins_waited += poll_secs / 60.0
            tf.logging.info('Waited {0:<.1f} mins'.format(num_mins_waited))
            if num_mins_waited >= max_wait_mins:
                tf.logging.info(
                    'No new checkpoint for {0:<.1f} mins. Giving up.'.format(
                        num_mins_waited))
                sys.exit(0)
コード例 #2
0
ファイル: fixed.py プロジェクト: zqhfpjlswsqy/google-research
def train_driver(params, model_dir, input_fn, model_fn):
    """Train up to `params.num_train_steps`, retrying failed attempts.

    Training is attempted at most `MAX_RETRIES` times. An attempt that raises
    a `Warning` or any other `Exception` is logged and counts as a failed
    attempt. When training ends, a `done` file is written to `model_dir`:
    'Job finished' on success. If every attempt fails, it instead records the
    failure and the process exits with status 1.

    Args:
        params: hyper-parameters; `num_train_steps` is the training budget.
        model_dir: estimator model directory; also receives the `done` file.
        input_fn: for `Estimator`.
        model_fn: for `Estimator`.
    """

    def _write_done_marker(message):
        # Record the terminal status of the job in `<model_dir>/done`.
        done_path = os.path.join(model_dir, 'done')
        with gfile.GFile(done_path, 'w') as fout:
            fout.write(message)
            fout.flush()

    estimator = utils.create_estimator(params, model_dir, model_fn)
    tf.logging.info('Train for {0} steps.'.format(params.num_train_steps))

    succeeded = False
    for attempt in range(MAX_RETRIES):
        try:
            estimator.train(input_fn=input_fn,
                            max_steps=params.num_train_steps)
        except Warning as warning:
            tf.logging.info(warning)
        except Exception as error:  # pylint: disable=broad-except
            tf.logging.info(error)
            traceback.print_exc()
            tf.logging.info('Failed {0} times. Retry!'.format(attempt + 1))
        else:
            succeeded = True
            break

    if not succeeded:
        _write_done_marker(
            'Job failed after {0} retries'.format(MAX_RETRIES))
        sys.exit(1)

    _write_done_marker('Job finished')