import tensorflow as tf

from tensorflow.python import debug

from tensor2tensor.utils import metrics_hook


def create_hooks(use_tfdbg=False,
                 use_dbgprofile=False,
                 dbgprofile_kwargs=None,
                 use_validation_monitor=False,
                 validation_monitor_kwargs=None,
                 use_early_stopping=False,
                 early_stopping_kwargs=None):
  """Create train and eval hooks for Experiment."""
  train_monitors = []
  eval_hooks = []

  if use_tfdbg:
    hook = debug.LocalCLIDebugHook()
    train_monitors.append(hook)
    eval_hooks.append(hook)

  if use_dbgprofile:
    # Recorded traces can be visualized with chrome://tracing/.
    # The memory/tensor lifetime is also profiled.
    tf.logging.info("Using ProfilerHook")
    defaults = dict(save_steps=10, show_dataflow=True, show_memory=True)
    # Tolerate the None defaults on the kwargs arguments.
    defaults.update(dbgprofile_kwargs or {})
    train_monitors.append(tf.contrib.hooks.ProfilerHook(**defaults))

  if use_validation_monitor:
    tf.logging.info("Using ValidationMonitor")
    train_monitors.append(
        tf.contrib.learn.monitors.ValidationMonitor(
            hooks=eval_hooks, **(validation_monitor_kwargs or {})))

  if use_early_stopping:
    tf.logging.info("Using EarlyStoppingHook")
    hook = metrics_hook.EarlyStoppingHook(**(early_stopping_kwargs or {}))
    # Add to both train and eval so that eval aborts as well.
    train_monitors.append(hook)
    eval_hooks.append(hook)

  return train_monitors, eval_hooks
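# Usage sketch (an assumption, not part of the original source): the returned
# lists plug into tf.contrib.learn.Experiment, whose constructor accepts
# train_monitors and eval_hooks. `my_estimator`, `train_input_fn`, and
# `eval_input_fn` are hypothetical placeholders.
train_monitors, eval_hooks = create_hooks(
    use_dbgprofile=True,
    dbgprofile_kwargs={"output_dir": "/tmp/profiler_traces"})
experiment = tf.contrib.learn.Experiment(
    estimator=my_estimator,
    train_input_fn=train_input_fn,
    eval_input_fn=eval_input_fn,
    train_monitors=train_monitors,
    eval_hooks=eval_hooks)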
def create_hooks(use_tfdbg=False,
                 use_dbgprofile=False,
                 dbgprofile_kwargs=None,
                 use_validation_monitor=False,
                 validation_monitor_kwargs=None,
                 use_early_stopping=False,
                 early_stopping_kwargs=None):
  """Create train and eval hooks for Experiment."""
  train_hooks = []
  eval_hooks = []

  if use_tfdbg:
    hook = debug.LocalCLIDebugHook()
    train_hooks.append(hook)
    eval_hooks.append(hook)

  if use_dbgprofile:
    # Recorded traces can be visualized with chrome://tracing/.
    # The memory/tensor lifetime is also profiled.
    tf.logging.info("Using ProfilerHook")
    defaults = dict(save_steps=10, show_dataflow=True, show_memory=True)
    # Tolerate the None defaults on the kwargs arguments.
    defaults.update(dbgprofile_kwargs or {})
    train_hooks.append(tf.train.ProfilerHook(**defaults))

  if use_validation_monitor:
    tf.logging.info("Using ValidationMonitor")
    # Fathom
    # continuous_train_and_eval breaks early stopping.
    FLAGS = tf.flags.FLAGS
    assert FLAGS.schedule != 'continuous_train_and_eval'
    train_hooks.append(
        FathomValidationMonitor(
            restart_after_eval=FLAGS.restart_after_eval,
            hooks=eval_hooks,
            **(validation_monitor_kwargs or {})))

  if use_early_stopping:
    tf.logging.info("Using EarlyStoppingHook")
    hook = metrics_hook.EarlyStoppingHook(**(early_stopping_kwargs or {}))
    # Add to both train and eval so that eval aborts as well.
    train_hooks.append(hook)
    eval_hooks.append(hook)

  # NOTE: An attempt at adding better OOM feedback via MemoryReportingHook
  # didn't seem to actually work, so it stays commented out for now:
  # train_hooks.append(MemoryReportingHook())
  # eval_hooks.append(MemoryReportingHook())

  return train_hooks, eval_hooks
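# The Fathom variant reads FLAGS.schedule and FLAGS.restart_after_eval, which
# must be defined elsewhere in the binary. A minimal sketch of those
# definitions; the defaults and help strings here are assumptions.
flags = tf.flags

flags.DEFINE_string(
    "schedule", "train_and_evaluate",
    "Experiment schedule; continuous_train_and_eval breaks early stopping.")
flags.DEFINE_bool(
    "restart_after_eval", False,
    "Whether FathomValidationMonitor restarts training after each eval.")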
def testEarlyStoppingHook(self):
  global_step = tf.train.create_global_step()
  counter = tf.get_variable("count", initializer=0, dtype=tf.int32)
  tf.summary.scalar("count", counter)
  incr_global_step = tf.assign_add(global_step, 1)
  incr_counter = tf.assign_add(counter, 1)

  # Stop if the counter summary ("count_1") has not gone up by more than 1
  # in 20 steps.
  ckpt_dir = self.ckpt_dir("early")
  stop_hook = metrics_hook.EarlyStoppingHook(
      ckpt_dir,
      "count_1",
      num_plateau_steps=20,
      plateau_delta=1.,
      plateau_decrease=False,
      every_n_steps=10)
  with self.sess(stop_hook, ckpt_dir) as sess:
    for _ in range(20):
      sess.run((incr_global_step, incr_counter))
    # Summary files should now have 2 values in them.
    self.flush()

    # Run for more steps so that the hook gets triggered, and verify that we
    # don't stop.
    for _ in range(30):
      sess.run((incr_global_step, incr_counter))
    self.flush()

    # Run without incrementing the counter.
    for _ in range(40):
      sess.run(incr_global_step)
    # Metrics should be written such that now the counter has gone >20 steps
    # without being incremented.
    self.flush()

    # Check that we ask for stop.
    with self.assertRaisesRegexp(RuntimeError, "after should_stop requested"):
      for _ in range(30):
        sess.run(incr_global_step)
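# The test above leans on ckpt_dir/sess/flush helpers from its TestCase. A
# minimal sketch of what such helpers could look like, assuming a
# tf.train.MonitoredTrainingSession that writes checkpoints and summaries to
# ckpt_dir so EarlyStoppingHook can read the event files back. This is a
# hypothetical harness, not the original one.
import contextlib
import os

import tensorflow as tf


class MetricsHookTest(tf.test.TestCase):

  def ckpt_dir(self, name):
    # A fresh checkpoint/summary directory per test case.
    return os.path.join(self.get_temp_dir(), name)

  @contextlib.contextmanager
  def sess(self, hook, ckpt_dir):
    # Save checkpoints and summaries aggressively so the hook sees fresh
    # metric values during the short test loops.
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=ckpt_dir,
        save_checkpoint_secs=1,
        save_summaries_steps=1,
        hooks=[hook]) as mon_sess:
      self._ckpt_dir = ckpt_dir
      yield mon_sess

  def flush(self):
    # Force pending events to disk so the hook's next metric read is current.
    tf.summary.FileWriterCache.get(self._ckpt_dir).flush()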
def run(target, unused_is_chief, device_fn, use_tpu):
  """Run training.

  Args:
    target: The target of the TensorFlow standard server to use. Can be the
      empty string to run locally using an in-process server.
    unused_is_chief: Unused; kept to match the expected signature.
    device_fn: Device function used to assign ops to devices.
    use_tpu: turn on tpu code path.
  """
  if not FLAGS.dataset_config_pbtxt:
    logging.error('Need to specify --dataset_config_pbtxt')
    return

  g = tf.Graph()
  with g.as_default():
    with tf.device(device_fn):
      # If ps_tasks is zero, the local device is used. When using multiple
      # (non-local) replicas, the ReplicaDeviceSetter distributes the
      # variables across the different devices.
      tf_dataset = data_providers.get_input_fn_from_dataset(
          dataset_config_filename=FLAGS.dataset_config_pbtxt,
          mode=tf.estimator.ModeKeys.TRAIN,
          max_examples=FLAGS.max_examples,
          use_tpu=use_tpu)
      model = modeling.get_model(FLAGS.model_name)
      logging.info('Running training on %s with model %s and tpu %s',
                   tf_dataset, FLAGS.model_name, use_tpu)

      batches_per_epoch = tf_dataset.num_examples // FLAGS.batch_size
      logging.info('Batches per epoch %s', batches_per_epoch)
      params = dict(batches_per_epoch=batches_per_epoch)
      estimator = model.make_estimator(
          batch_size=FLAGS.batch_size,
          model_dir=FLAGS.train_dir,
          params=params,
          use_tpu=use_tpu,
          master=target,
          start_from_checkpoint=FLAGS.start_from_checkpoint,
      )

      training_hooks = None
      if FLAGS.use_early_stopping:
        # The early stopping hook depends on the existence of an events
        # directory.
        eval_dir = os.path.join(FLAGS.train_dir,
                                FLAGS.early_stopping_directory)
        tf.gfile.MakeDirs(eval_dir)

        plateau_decrease = True
        if FLAGS.early_stopping_metric_direction == 'increase':
          plateau_decrease = False

        early_stopping_hook = metrics_hook.EarlyStoppingHook(
            events_dir=eval_dir,
            tag=FLAGS.early_stopping_tag,
            num_plateau_steps=FLAGS.early_stopping_num_plateau_steps,
            plateau_delta=FLAGS.early_stopping_plateau_delta,
            plateau_decrease=plateau_decrease,
            every_n_steps=FLAGS.early_stopping_every_n_steps)

        training_hooks = [early_stopping_hook]

      estimator.train(
          input_fn=tf_dataset,
          max_steps=FLAGS.number_of_steps,
          hooks=training_hooks)
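# run() reads a family of early-stopping flags that are defined elsewhere in
# the training binary. A minimal sketch of those definitions; the defaults
# and help strings here are assumptions, not the original values.
import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_bool('use_early_stopping', False,
                  'Whether to attach an EarlyStoppingHook to training.')
flags.DEFINE_string('early_stopping_directory', 'eval',
                    'Subdirectory of --train_dir holding eval event files.')
flags.DEFINE_string('early_stopping_tag', 'loss',
                    'Summary tag watched for plateaus.')
flags.DEFINE_string('early_stopping_metric_direction', 'decrease',
                    'Whether improvement means the metric should decrease '
                    'or increase.')
flags.DEFINE_integer('early_stopping_num_plateau_steps', 10000,
                     'Steps without sufficient change before stopping.')
flags.DEFINE_float('early_stopping_plateau_delta', 0.1,
                   'Minimum change that counts as progress.')
flags.DEFINE_integer('early_stopping_every_n_steps', 100,
                     'How often to check the watched metric.')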