Example #1
def run_eval_loop(sess=None,
                  fetches_to_collect=None,
                  other_ops=(),
                  hooks=(),
                  checkpoint_dir=None,
                  load_path=None,
                  max_steps=None,
                  max_seconds=None,
                  init_fn=None):
    # If the fetches are given as a dict, collect the listed tensors and return the results
    # under the same keys as an AttrDict.
    if isinstance(fetches_to_collect, dict):
        keys, values = zip(*fetches_to_collect.items())
        results = run_eval_loop(sess, list(values), other_ops, hooks,
                                checkpoint_dir, load_path, max_steps,
                                max_seconds, init_fn)
        return attrdict.AttrDict(dict(zip(keys, results)))

    sess_creator = None if sess else make_session_creator(
        checkpoint_dir, load_path, init_fn)
    collect_hook = session_hooks.collect_hook(fetches_to_collect)
    hooks = [collect_hook, *hooks]
    if max_seconds or max_steps:
        stop_hook = session_hooks.stop_after_steps_or_seconds_hook(
            max_seconds, max_steps)
        hooks.append(stop_hook)

    tfasync.main_loop(sess=sess,
                      sess_creator=sess_creator,
                      ops=other_ops,
                      hooks=hooks)
    return collect_hook.result
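A minimal usage sketch (not part of the original listing), assuming the surrounding module is imported and that `total_loss` and `accuracy` are tensors already built in the graph; the checkpoint directory is a placeholder:

results = run_eval_loop(
    fetches_to_collect={'loss': total_loss, 'accuracy': accuracy},
    checkpoint_dir='/path/to/checkpoints',  # placeholder; a session is created from here
    max_steps=1000)
# Each attribute holds whatever collect_hook gathered for that tensor over the evaluated steps.
print(results.loss, results.accuracy)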
Example #2
def run_train_loop(
        train_op, sess=None, checkpoint_dir=None, load_path=None, max_steps=None, hooks=(),
        init_fn=None):
    sess_creator = None if sess else make_session_creator(checkpoint_dir, load_path, init_fn)
    if max_steps:
        stop_hook = session_hooks.stop_after_steps_or_seconds_hook(steps_limit=max_steps)
        hooks = [stop_hook, *hooks]

    hooks = [*hooks, session_hooks.stop_on_signal_hook()]

    tfasync.main_loop(sess=sess, sess_creator=sess_creator, ops=train_op, hooks=hooks)
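A hedged usage sketch, assuming `train_op` is the optimizer's update op and the path is a placeholder; `make_session_creator` presumably restores state from `checkpoint_dir` or `load_path` when available:

run_train_loop(
    train_op,
    checkpoint_dir='/path/to/checkpoints',  # placeholder
    max_steps=100_000)
# The loop also stops cleanly on a stop signal, via the appended stop_on_signal_hook.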
Example #3
async def validation_hook(
        metrics, checkpoint_path_prefix=None, summary_output_dir=None, max_steps=None,
        max_seconds=None):
    """Evaluates the model and logs the resulting average evaluation metrics.

    Furthermore, if `checkpoint_path_prefix` is given, it also saves a checkpoint whenever the
    loss reaches a new record low. In this case the loss tensor is assumed to be the first metric.
    In other words, if checkpointing is desired, the first metric must be a loss, for which lower
    values are better.

    This hook runs a separate validation loop within the existing session.

    Args:
        metrics: The metric objects to evaluate on the validation set. Each one supplies the
            tensor to fetch (`metric.tensor`), a `name`, a `format` method for pretty printing
            in the log, an aggregation rule (`get_aggregated_value`) and a comparison rule
            (`is_first_better`).
        checkpoint_path_prefix: Prefix of the path where the best model's checkpoint should be
            saved. If None, no checkpointing is done.
        summary_output_dir: TensorBoard summary directory where the metrics should be written.
        max_steps: The maximum number of steps to run in the validation loop.
        max_seconds: The maximum time to spend on validation. After this the loop is stopped and the
            average metrics are calculated based on the steps that were run.

    Returns:
        A hook that can be used in a training loop.
    """

    saver = tf.compat.v1.train.Saver(max_to_keep=1, save_relative_paths=True)
    summary_writer = (tf.compat.v1.summary.FileWriterCache.get(summary_output_dir)
                      if summary_output_dir else None)

    # Read the best loss from filesystem. The loss value is encoded into the filename of the
    # checkpoint, such as "something-val-0.003348-158585.index"
    if checkpoint_path_prefix:
        paths = glob.glob(checkpoint_path_prefix + '*')
        filename_matches = [re.match(r'.+?-val-(.+?)-.+?\.index', path) for path in paths]
        losses = [float(match[1]) for match in filename_matches if match]
        best_main_metric_value = min(losses) if losses else np.inf
    else:
        best_main_metric_value = np.inf

    # We create some hooks for the internal, nested validation loop. (Hooks within a hook!)
    # First a counter to count steps.
    counter = tfu.get_or_create_counter('validation')
    counter_h = counter_hook(counter)
    eta_h = eta_hook(max_steps, step_tensor=counter.var, init_phase_seconds=30, every_n_secs=10)
    collect_h = collect_hook([m.tensor for m in metrics])
    stop_h = stop_after_steps_or_seconds_hook(seconds_limit=max_seconds, steps_limit=max_steps)
    sigint_h = stop_on_signal_hook(sig=signal.SIGINT)
    sigterm_h = stop_on_signal_hook(sig=signal.SIGTERM, is_additional=True)
    inner_hooks = [counter_h, eta_h, collect_h, stop_h, sigint_h, sigterm_h]
    global_step_tensor = tf.compat.v1.train.get_global_step()

    async for run_context, run_values in tfasync.run_detailed_iter(global_step_tensor):
        global_step_value = run_values.results

        # Skip validation at global step 0 (before any training step); nothing to validate yet.
        if global_step_value == 0:
            continue

        logging.info('Running validation')

        # Run the evaluation loop in the existing session that this validation hook operates in
        tfasync.main_loop(sess=run_context.session, hooks=inner_hooks)

        aggregated_values = [
            metric.get_aggregated_value(result)
            for metric, result in zip(metrics, collect_h.result)]

        # Write to log
        for metric, aggregated_value in zip(metrics, aggregated_values):
            logging.info(metric.format(aggregated_value))

        # Write summaries for TensorBoard
        if summary_writer:
            summary = tfu.scalar_dict_to_summary({
                f'validation/{metric.name}': np.nanmean(value)
                for metric, value in zip(metrics, aggregated_values)})
            summary_writer.add_summary(summary, global_step=global_step_value)
            summary_writer.flush()

        # Save checkpoint if the loss improved
        aggregated_main_value = np.nanmean(aggregated_values[0])
        if metrics[0].is_first_better(aggregated_main_value, best_main_metric_value):
            best_main_metric_value = aggregated_main_value
            logging.info(f'Main metric: {metrics[0].format(aggregated_main_value)} (new record!)')
            if checkpoint_path_prefix:
                saver.save(
                    run_context.session,
                    f'{checkpoint_path_prefix}-val-{aggregated_main_value:.6f}',
                    global_step=global_step_value)
        else:
            logging.info(f'Main metric: {aggregated_main_value:.6f}')
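Putting the pieces together, a sketch of how this hook could be attached to the training loop from Example #2; the metric objects, tensors and paths here are hypothetical, and since `validation_hook` is an async function, calling it is assumed to produce a coroutine hook that `tfasync.main_loop` drives:

val_hook = validation_hook(
    metrics,                                          # loss metric first, then the others
    checkpoint_path_prefix='/models/run1/best',       # placeholder prefix for best-model checkpoints
    summary_output_dir='/models/run1/val_summaries',  # placeholder TensorBoard directory
    max_steps=500)

run_train_loop(
    train_op,
    checkpoint_dir='/models/run1',  # placeholder
    max_steps=200_000,
    hooks=[val_hook])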