Exemplo n.º 1
0
Arquivo: train.py Projeto: Gs-001/quad
def _define_loop(graph, logdir, train_steps, eval_steps):
    """Build a loop with a training phase followed by an evaluation phase.

    Args:
      graph: Object exposing the graph elements used by the loop as attributes
        (step, should_log, do_report, force_reset, done, score, summary,
        is_training).
      logdir: Directory handed to the loop for checkpoints and summaries.
      train_steps: Number of steps in each training phase (per epoch).
      eval_steps: Number of steps in each evaluation phase (per epoch).

    Returns:
      The configured ``tools.Loop`` instance.
    """
    loop = tools.Loop(
        logdir, graph.step, graph.should_log, graph.do_report, graph.force_reset)

    def register(name, steps, training, **schedule):
        # Both phases share the same done/score/summary elements; they differ
        # only in length, reporting/logging/checkpoint cadence, and the value
        # fed to the is_training placeholder.
        loop.add_phase(
            name, graph.done, graph.score, graph.summary, steps,
            feed={graph.is_training: training}, **schedule)

    register('train', train_steps, True,
             report_every=None,
             log_every=train_steps // 2,
             checkpoint_every=None)
    # The eval phase reports once per phase length, logs at half of it, and
    # sets checkpoint_every to ten times the eval phase length.
    register('eval', eval_steps, False,
             report_every=eval_steps,
             log_every=eval_steps // 2,
             checkpoint_every=10 * eval_steps)
    return loop
Exemplo n.º 2
0
def _define_loop(graph, eval_steps):
  """Build a loop that contains only an evaluation phase.

  The loop is constructed without a log directory (``None``). The phase passes
  ``log_every=None`` and ``checkpoint_every=None``, and reports every
  ``eval_steps`` steps.

  Args:
    graph: Object exposing the graph elements used by the loop as attributes
      (step, should_log, do_report, force_reset, done, score, summary,
      is_training).
    eval_steps: Number of steps in each evaluation phase (per epoch).

  Returns:
    The configured ``tools.Loop`` instance.
  """
  eval_schedule = dict(
      report_every=eval_steps,
      log_every=None,
      checkpoint_every=None,
  )
  loop = tools.Loop(
      None, graph.step, graph.should_log, graph.do_report, graph.force_reset)
  # Evaluation always runs with the is_training placeholder fed as False.
  loop.add_phase(
      'eval', graph.done, graph.score, graph.summary, eval_steps,
      feed={graph.is_training: False}, **eval_schedule)
  return loop