Example #1
import tensorflow as tf
from tensorflow.python import debug  # TF1 CLI debugger (LocalCLIDebugHook)
# metrics_hook is project-local; this path assumes the tensor2tensor layout.
from tensor2tensor.utils import metrics_hook


def create_hooks(use_tfdbg=False, use_dbgprofile=False, dbgprofile_kwargs=None,
                 use_validation_monitor=False, validation_monitor_kwargs=None,
                 use_early_stopping=False, early_stopping_kwargs=None):
  """Create train and eval hooks for Experiment."""
  train_monitors = []
  eval_hooks = []

  if use_tfdbg:
    hook = debug.LocalCLIDebugHook()
    train_monitors.append(hook)
    eval_hooks.append(hook)

  if use_dbgprofile:
    # Recorded traces can be visualized with chrome://tracing/
    # The memory/tensor lifetime is also profiled
    tf.logging.info("Using ProfilerHook")
    defaults = dict(save_steps=10, show_dataflow=True, show_memory=True)
    defaults.update(dbgprofile_kwargs or {})  # tolerate a None kwargs dict
    train_monitors.append(tf.contrib.hooks.ProfilerHook(**defaults))

  if use_validation_monitor:
    tf.logging.info("Using ValidationMonitor")
    train_monitors.append(
        tf.contrib.learn.monitors.ValidationMonitor(
            hooks=eval_hooks, **(validation_monitor_kwargs or {})))

  if use_early_stopping:
    tf.logging.info("Using EarlyStoppingHook")
    hook = metrics_hook.EarlyStoppingHook(**(early_stopping_kwargs or {}))
    # Adding to both training and eval so that eval aborts as well
    train_monitors.append(hook)
    eval_hooks.append(hook)

  return train_monitors, eval_hooks
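
A minimal usage sketch for Example #1, against the TF1-era tf.contrib.learn.Experiment API. The estimator and the input functions are placeholders, and the kwargs dicts are simply forwarded to the underlying hook constructors:

train_monitors, eval_hooks = create_hooks(
    use_dbgprofile=True,
    dbgprofile_kwargs={"save_steps": 100},  # forwarded to ProfilerHook
    use_validation_monitor=True,
    validation_monitor_kwargs={"input_fn": eval_input_fn,  # placeholder
                               "every_n_steps": 500},
)
experiment = tf.contrib.learn.Experiment(
    estimator=estimator,  # placeholder Estimator
    train_input_fn=train_input_fn,  # placeholder input_fns
    eval_input_fn=eval_input_fn,
    train_monitors=train_monitors,
    eval_hooks=eval_hooks,
)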
Example #2
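# Assumes the same imports as Example #1; FathomValidationMonitor is a
# project-local monitor defined elsewhere in this fork.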
def create_hooks(use_tfdbg=False,
                 use_dbgprofile=False,
                 dbgprofile_kwargs=None,
                 use_validation_monitor=False,
                 validation_monitor_kwargs=None,
                 use_early_stopping=False,
                 early_stopping_kwargs=None):
    """Create train and eval hooks for Experiment."""
    train_hooks = []
    eval_hooks = []

    if use_tfdbg:
        hook = debug.LocalCLIDebugHook()
        train_hooks.append(hook)
        eval_hooks.append(hook)

    if use_dbgprofile:
        # Recorded traces can be visualized with chrome://tracing/
        # The memory/tensor lifetime is also profiled
        tf.logging.info("Using ProfilerHook")
        defaults = dict(save_steps=10, show_dataflow=True, show_memory=True)
        defaults.update(dbgprofile_kwargs or {})  # tolerate a None kwargs dict
        train_hooks.append(tf.train.ProfilerHook(**defaults))

    if use_validation_monitor:
        tf.logging.info("Using ValidationMonitor")
        # Fathom
        # continuous_train_and_eval breaks early stopping
        flags = tf.flags
        FLAGS = flags.FLAGS
        assert FLAGS.schedule != 'continuous_train_and_eval'

        train_hooks.append(
            FathomValidationMonitor(
                restart_after_eval=FLAGS.restart_after_eval,
                hooks=eval_hooks,
                **(validation_monitor_kwargs or {})))

    if use_early_stopping:
        tf.logging.info("Using EarlyStoppingHook")
        hook = metrics_hook.EarlyStoppingHook(**(early_stopping_kwargs or {}))
        # Adding to both training and eval so that eval aborts as well
        train_hooks.append(hook)
        eval_hooks.append(hook)

    # NOTE: An attempt at adding better OOM feedback (see MemoryReportingHook).
    # Commented out for now because it doesn't seem to actually work.
    # train_hooks.append(MemoryReportingHook())
    # eval_hooks.append(MemoryReportingHook())

    return train_hooks, eval_hooks
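
For reference, a sketch of the early_stopping_kwargs shape that both variants forward to metrics_hook.EarlyStoppingHook. The parameter names are taken from Examples #3 and #4 below; the values here are purely illustrative:

early_stopping_kwargs = dict(
    events_dir="/tmp/train/eval",  # directory polled for eval summaries
    tag="loss",                    # summary tag to watch
    num_plateau_steps=1000,        # stop after this many steps w/o improvement
    plateau_delta=0.001,           # minimum change that counts as improvement
    plateau_decrease=True,         # True if the metric is expected to decrease
    every_n_steps=100,             # how often the hook checks
)
train_hooks, eval_hooks = create_hooks(
    use_early_stopping=True,
    early_stopping_kwargs=early_stopping_kwargs,
)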
Example #3
    def testEarlyStoppingHook(self):
        global_step = tf.train.create_global_step()
        counter = tf.get_variable("count", initializer=0, dtype=tf.int32)
        tf.summary.scalar("count", counter)
        incr_global_step = tf.assign_add(global_step, 1)
        incr_counter = tf.assign_add(counter, 1)

        # Stop if the "count_1" summary (the counter) has not gone up by more
        # than 1 over a window of 20 global steps.

        ckpt_dir = self.ckpt_dir("early")
        stop_hook = metrics_hook.EarlyStoppingHook(ckpt_dir,
                                                   "count_1",
                                                   num_plateau_steps=20,
                                                   plateau_delta=1.,
                                                   plateau_decrease=False,
                                                   every_n_steps=10)
        with self.sess(stop_hook, ckpt_dir) as sess:
            for _ in range(20):
                sess.run((incr_global_step, incr_counter))

            # Summary files should now have 2 values in them
            self.flush()

            # Run for more steps so that the hook gets triggered and we verify that we
            # don't stop.
            for _ in range(30):
                sess.run((incr_global_step, incr_counter))

            self.flush()

            # Run without incrementing the counter
            for _ in range(40):
                sess.run(incr_global_step)

            # Metrics should be written such that now the counter has gone >20 steps
            # without being incremented.
            self.flush()

            # Check that we ask for stop
            with self.assertRaisesRegexp(RuntimeError,
                                         "after should_stop requested"):
                for _ in range(30):
                    sess.run(incr_global_step)
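
The test above exercises the plateau rule. A minimal pure-Python paraphrase of that rule, under my reading of the parameters (not the hook's actual implementation):

def has_plateaued(readings, num_plateau_steps, plateau_delta,
                  plateau_decrease):
    """readings: list of (global_step, metric_value) pairs, oldest first.

    Returns True when the metric has failed to improve by more than
    plateau_delta over the trailing num_plateau_steps global steps.
    """
    if not readings or readings[-1][0] - readings[0][0] < num_plateau_steps:
        return False  # not enough history yet
    cutoff = readings[-1][0] - num_plateau_steps
    window = [v for s, v in readings if s >= cutoff]
    if plateau_decrease:
        improvement = window[0] - min(window)  # lower is better
    else:
        improvement = max(window) - window[0]  # higher is better
    return improvement <= plateau_delta

Under this reading, the first two loops keep the counter rising in step with the global step, far more than plateau_delta=1. per 20-step window, so no stop is requested; the later loops advance only the global step, the window shows zero improvement, and the hook asks the session to stop.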
Example #4
def run(target, unused_is_chief, device_fn, use_tpu):
  """Run training.

  Args:
     target: The target of the TensorFlow standard server to use. Can be the
       empty string to run locally using an inprocess server.
     device_fn: Device function used to assign ops to devices.
     use_tpu: turn on tpu code path.
  """
  if not FLAGS.dataset_config_pbtxt:
    logging.error('Need to specify --dataset_config_pbtxt')
    return

  g = tf.Graph()
  with g.as_default():
    with tf.device(device_fn):
      # If ps_tasks is zero, the local device is used. When using multiple
      # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
      # across the different devices.

      tf_dataset = data_providers.get_input_fn_from_dataset(
          dataset_config_filename=FLAGS.dataset_config_pbtxt,
          mode=tf.estimator.ModeKeys.TRAIN,
          max_examples=FLAGS.max_examples,
          use_tpu=use_tpu)
      model = modeling.get_model(FLAGS.model_name)
      logging.info('Running training on %s with model %s and tpu %s',
                   tf_dataset, FLAGS.model_name, use_tpu)

      batches_per_epoch = tf_dataset.num_examples // FLAGS.batch_size
      logging.info('Batches per epoch %s', batches_per_epoch)
      params = dict(batches_per_epoch=batches_per_epoch)
      estimator = model.make_estimator(
          batch_size=FLAGS.batch_size,
          model_dir=FLAGS.train_dir,
          params=params,
          use_tpu=use_tpu,
          master=target,
          start_from_checkpoint=FLAGS.start_from_checkpoint,
      )

      training_hooks = None
      if FLAGS.use_early_stopping:
        # Early stopping hook depends on existence of events directory.
        eval_dir = os.path.join(FLAGS.train_dir, FLAGS.early_stopping_directory)
        tf.gfile.MakeDirs(eval_dir)

        plateau_decrease = True
        if FLAGS.early_stopping_metric_direction == 'increase':
          plateau_decrease = False

        early_stopping_hook = metrics_hook.EarlyStoppingHook(
            events_dir=eval_dir,
            tag=FLAGS.early_stopping_tag,
            num_plateau_steps=FLAGS.early_stopping_num_plateau_steps,
            plateau_delta=FLAGS.early_stopping_plateau_delta,
            plateau_decrease=plateau_decrease,
            every_n_steps=FLAGS.early_stopping_every_n_steps)

        training_hooks = [early_stopping_hook]

      estimator.train(
          input_fn=tf_dataset,
          max_steps=FLAGS.number_of_steps,
          hooks=training_hooks)
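
Example #4 reads a number of flags defined elsewhere in its module. A sketch of the early-stopping subset using TF1's tf.flags; the flag names are taken from the body of run(), while the defaults and help strings here are illustrative:

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_boolean('use_early_stopping', False,
                     'Enable EarlyStoppingHook during training.')
flags.DEFINE_string('early_stopping_directory', 'eval',
                    'Subdirectory of --train_dir polled for eval summaries.')
flags.DEFINE_string('early_stopping_tag', 'loss',
                    'Summary tag watched for plateaus.')
flags.DEFINE_string('early_stopping_metric_direction', 'decrease',
                    'Direction the metric should move: decrease or increase.')
flags.DEFINE_integer('early_stopping_num_plateau_steps', 10000,
                     'Stop after this many steps without improvement.')
flags.DEFINE_float('early_stopping_plateau_delta', 1e-4,
                   'Minimum change that counts as improvement.')
flags.DEFINE_integer('early_stopping_every_n_steps', 100,
                     'How often the hook checks for a plateau.')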