def test_two_listeners_with_default_saver(self):
  with ops.Graph().as_default():
    global_step = variables.get_or_create_global_step()
    train_op = state_ops.assign_add(global_step, 1)
    listener1 = MockCheckpointSaverListener()
    listener2 = MockCheckpointSaverListener()
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir,
        save_steps=1,
        listeners=[listener1, listener2])
    with monitored_session.SingularMonitoredSession(
        hooks=[hook], checkpoint_dir=self.model_dir) as sess:
      sess.run(train_op)
      sess.run(train_op)
      global_step_val = sess.run(global_step)
    listener1_counts = listener1.get_counts()
    listener2_counts = listener2.get_counts()

    self.assertEqual(2, global_step_val)
    self.assertEqual({
        'begin': 1,
        'before_save': 2,
        'after_save': 2,
        'end': 1
    }, listener1_counts)
    self.assertEqual(listener1_counts, listener2_counts)

  with ops.Graph().as_default():
    global_step = variables.get_or_create_global_step()
    with monitored_session.SingularMonitoredSession(
        checkpoint_dir=self.model_dir) as sess2:
      global_step_saved_val = sess2.run(global_step)
    self.assertEqual(2, global_step_saved_val)
def test_listener_with_monitored_session(self):
  with ops.Graph().as_default():
    scaffold = monitored_session.Scaffold()
    global_step = variables.get_or_create_global_step()
    train_op = state_ops.assign_add(global_step, 1)
    listener = MockCheckpointSaverListener()
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir,
        save_steps=1,
        scaffold=scaffold,
        listeners=[listener])
    with monitored_session.SingularMonitoredSession(
        hooks=[hook],
        scaffold=scaffold,
        checkpoint_dir=self.model_dir) as sess:
      sess.run(train_op)
      sess.run(train_op)
      global_step_val = sess.run(global_step)
    listener_counts = listener.get_counts()
  self.assertEqual(2, global_step_val)
  self.assertEqual({
      'begin': 1,
      'before_save': 2,
      'after_save': 2,
      'end': 1
  }, listener_counts)
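# The tests above rely on a `MockCheckpointSaverListener` helper that is not
# shown in this section. A minimal sketch of such a counting listener,
# assuming only the standard `CheckpointSaverListener` callbacks (`begin`,
# `before_save`, `after_save`, `end`), might look like this:
class MockCheckpointSaverListener(
    basic_session_run_hooks.CheckpointSaverListener):
  """Counts how often each listener callback fires (illustrative sketch)."""

  def __init__(self):
    self.begin_count = 0
    self.before_save_count = 0
    self.after_save_count = 0
    self.end_count = 0

  def begin(self):
    self.begin_count += 1

  def before_save(self, session, global_step_value):
    self.before_save_count += 1

  def after_save(self, session, global_step_value):
    self.after_save_count += 1

  def end(self, session, global_step_value):
    self.end_count += 1

  def get_counts(self):
    return {
        'begin': self.begin_count,
        'before_save': self.before_save_count,
        'after_save': self.after_save_count,
        'end': self.end_count
    }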
def test_save_secs_calls_listeners_periodically(self):
  with self.graph.as_default():
    listener = MockCheckpointSaverListener()
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir,
        save_secs=2,
        scaffold=self.scaffold,
        listeners=[listener])
    hook.begin()
    self.scaffold.finalize()
    with session_lib.Session() as sess:
      sess.run(self.scaffold.init_op)
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(self.train_op)  # hook runs here
      mon_sess.run(self.train_op)
      time.sleep(2.5)
      mon_sess.run(self.train_op)  # hook runs here
      mon_sess.run(self.train_op)
      mon_sess.run(self.train_op)
      time.sleep(2.5)
      mon_sess.run(self.train_op)  # hook runs here
      mon_sess.run(self.train_op)  # hook won't run here, so it does at end
      hook.end(sess)  # hook runs here
    self.assertEqual({
        'begin': 1,
        'before_save': 4,
        'after_save': 4,
        'end': 1
    }, listener.get_counts())
def test_save_steps_saves_periodically(self):
  with self.graph.as_default():
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir, save_steps=2, scaffold=self.scaffold)
    hook.begin()
    self.scaffold.finalize()
    with session_lib.Session() as sess:
      sess.run(self.scaffold.init_op)
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(self.train_op)
      mon_sess.run(self.train_op)
      # Not saved
      self.assertEqual(1,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
      mon_sess.run(self.train_op)
      # Saved
      self.assertEqual(3,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
      mon_sess.run(self.train_op)
      # Not saved
      self.assertEqual(3,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
      mon_sess.run(self.train_op)
      # Saved
      self.assertEqual(5,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
def testResumeTrainAchievesRoughlyTheSameLoss(self):
  number_of_steps = [300, 1, 5]
  logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')

  for i in range(len(number_of_steps)):
    with ops.Graph().as_default():
      random_seed.set_random_seed(i)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      loss_ops.log_loss(tf_predictions, tf_labels)
      total_loss = loss_ops.get_total_loss()

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(total_loss, optimizer)

      saver = saver_lib.Saver()

      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.StopAtStepHook(
                  num_steps=number_of_steps[i]),
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=50, saver=saver),
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)
def train_mnist():
  with tf.get_default_graph().as_default():
    input_pipe = MnistData(BATCH_SIZE)
    train_features, train_labels = input_pipe.build_train_data_tensor()
    global_step = 0

    train_op, loss = train_graph(train_features, train_labels)
    assign_ops = get_assign_ops()
    checkpoint_hook = basic_session_run_hooks.CheckpointSaverHook(
        MODEL_DIR, save_steps=TEST_FREQ)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = global_step

      def before_run(self, run_context):
        self._step += 1  # Count runs so the periodic logging below triggers.
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for the loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step > 0 and self._step % 100 == 0:
          num_examples_per_step = BATCH_SIZE
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)
          format_str = ('%s: step %d, loss = %.2f '
                        '(%.1f examples/sec; %.3f sec/batch)')
          tf.logging.info(format_str % (datetime.datetime.now(), self._step,
                                        loss_value, examples_per_sec,
                                        sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=MODEL_DIR,
        hooks=[_LoggerHook(), checkpoint_hook],
        save_checkpoint_secs=None,
        config=tf.ConfigProto(log_device_placement=False)) as sess:
      if INIT_DICT is not None:
        sess.run(assign_ops)
        dict_str = '; '.join(
            scope + '/' + ', '.join(INIT_DICT[scope].keys())
            for scope in INIT_DICT)
        tf.logging.info('Instantiated tensors with assign ops: ' + dict_str)
      for i in range(TRAIN_STEPS):
        sess.run(train_op)
def test_saves_when_saver_and_scaffold_both_missing(self):
  with self.graph.as_default():
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir, save_steps=1)
    hook.begin()
    self.scaffold.finalize()
    with session_lib.Session() as sess:
      sess.run(self.scaffold.init_op)
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(self.train_op)
      self.assertEqual(1,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
def MonitoredTrainingSession(master='',  # pylint: disable=invalid-name
                             is_chief=True,
                             checkpoint_dir=None,
                             hooks=None,
                             scaffold=None,
                             config=None):
  """Creates a `MonitoredSession` for training.

  For a chief, this utility sets proper session initializer/restorer. It also
  creates hooks related to checkpoint and summary saving. For workers, this
  utility sets a proper session creator which waits for the chief to
  initialize/restore.

  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief
      to initialize or recover the TensorFlow session.
    checkpoint_dir: A string. Optional path to a directory where to restore
      variables.
    hooks: Optional list of `SessionRunHook` objects.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
      not specified, a default one is created. It's used to finalize the
      graph.
    config: `ConfigProto` proto used to configure the session.

  Returns:
    A `MonitoredSession` object.
  """
  hooks = hooks or []
  scaffold = scaffold or Scaffold()
  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold, master=master, config=config)
  else:
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=checkpoint_dir,
        master=master,
        config=config)
    hooks.extend([
        basic_session_run_hooks.StepCounterHook(output_dir=checkpoint_dir),
        basic_session_run_hooks.SummarySaverHook(
            scaffold=scaffold, output_dir=checkpoint_dir),
        basic_session_run_hooks.CheckpointSaverHook(
            checkpoint_dir, save_secs=600, scaffold=scaffold),
    ])

  return MonitoredSession(session_creator=session_creator, hooks=hooks)
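# A minimal single-process usage sketch of the utility above; `train_op` and
# `TRAIN_STEPS` are illustrative placeholders from the caller's graph:
with MonitoredTrainingSession(is_chief=True,
                              checkpoint_dir='/tmp/train_logs') as sess:
  for _ in range(TRAIN_STEPS):
    sess.run(train_op)  # The CheckpointSaverHook saves every 600 seconds.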
def test_save_secs_saves_periodically(self, mock_time):
  # Let's have a realistic start time.
  current_time = 1484695987.209386

  with self.graph.as_default():
    mock_time.return_value = current_time
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir, save_secs=2, scaffold=self.scaffold)
    hook.begin()
    self.scaffold.finalize()

    with session_lib.Session() as sess:
      sess.run(self.scaffold.init_op)
      mon_sess = monitored_session._HookedSession(sess, [hook])

      mock_time.return_value = current_time
      mon_sess.run(self.train_op)  # Saved.

      mock_time.return_value = current_time + 0.5
      mon_sess.run(self.train_op)  # Not saved.

      self.assertEqual(1,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))

      # Simulate 2.5 seconds of sleep.
      mock_time.return_value = current_time + 2.5
      mon_sess.run(self.train_op)  # Saved.

      mock_time.return_value = current_time + 2.6
      mon_sess.run(self.train_op)  # Not saved.

      mock_time.return_value = current_time + 2.7
      mon_sess.run(self.train_op)  # Not saved.

      self.assertEqual(3,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))

      # Simulate 7.5 more seconds of sleep (10 seconds from start).
      mock_time.return_value = current_time + 10
      mon_sess.run(self.train_op)  # Saved.
      self.assertEqual(6,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
def test_save_secs_calls_listeners_periodically(self, mock_time):
  # Let's have a realistic start time.
  current_time = 1484695987.209386

  with self.graph.as_default():
    mock_time.return_value = current_time
    listener = MockCheckpointSaverListener()
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir,
        save_secs=2,
        scaffold=self.scaffold,
        listeners=[listener])
    hook.begin()
    self.scaffold.finalize()

    with session_lib.Session() as sess:
      sess.run(self.scaffold.init_op)
      mon_sess = monitored_session._HookedSession(sess, [hook])

      mock_time.return_value = current_time + 0.5
      mon_sess.run(self.train_op)  # hook runs here

      mock_time.return_value = current_time + 0.5
      mon_sess.run(self.train_op)

      mock_time.return_value = current_time + 3.0
      mon_sess.run(self.train_op)  # hook runs here

      mock_time.return_value = current_time + 3.5
      mon_sess.run(self.train_op)

      mock_time.return_value = current_time + 4.0
      mon_sess.run(self.train_op)

      mock_time.return_value = current_time + 6.5
      mon_sess.run(self.train_op)  # hook runs here

      mock_time.return_value = current_time + 7.0
      mon_sess.run(self.train_op)  # hook won't run here, so it does at end

      mock_time.return_value = current_time + 7.5
      hook.end(sess)  # hook runs here

    self.assertEqual({
        'begin': 1,
        'before_save': 4,
        'after_save': 4,
        'end': 1
    }, listener.get_counts())
def __init__(self, estimator):
  """Initializes a `CheckpointInputPipelineHook`.

  Args:
    estimator: Estimator.

  Raises:
    ValueError: One of `save_steps` or `save_secs` should be set.
    ValueError: At most one of saver or scaffold should be set.
  """
  # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
  # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
  # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
  # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
  # to be different to avoid conflicts with the model checkpoint.

  # pylint: disable=protected-access
  checkpoint_prefix = "input"
  if estimator._config.num_worker_replicas > 1:
    # Distributed setting.
    suffix = "_{}_{}".format(estimator._config.task_type,
                             estimator._config.task_id)
    checkpoint_prefix += suffix
  # pylint: enable=protected-access

  # We use a composition paradigm instead of inheriting from
  # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
  # to check whether a `CheckpointSaverHook` is already present in the list
  # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
  # would thwart this behavior. This hook checkpoints *only the iterators*
  # and not the graph variables.
  self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
      estimator.model_dir,
      save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
      save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
      checkpoint_basename=checkpoint_prefix + ".ckpt")

  # Name for the protocol buffer file that will contain the list of most
  # recent checkpoints stored as a `CheckpointState` protocol buffer.
  # This file, kept in the same directory as the checkpoint files, is
  # automatically managed by the `Saver` to keep track of recent checkpoints.
  # The default name used by the `Saver` for this file is "checkpoint". Here
  # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
  # `checkpoint_dir` is the same as the model checkpoint directory, there are
  # no conflicts during restore.
  self._latest_filename = "checkpoint_" + checkpoint_prefix
  self._first_run = True
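# A sketch of how this hook is typically wired into an `Estimator`
# train/evaluate loop (mirroring the usage example in the later variant of
# this constructor; `model_fn`, `train_input_fn`, `eval_input_fn`,
# `train_steps_per_eval`, and `should_stop_the_training` are placeholders):
est = tf.estimator.Estimator(model_fn)
while True:
  est.train(
      train_input_fn,
      hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
      steps=train_steps_per_eval)
  # Note: the hook is not passed to evaluate().
  metrics = est.evaluate(eval_input_fn)
  if should_stop_the_training(metrics):
    break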
def test_summary_writer_defs(self):
  fake_summary_writer.FakeSummaryWriter.install()
  writer_cache.FileWriterCache.clear()
  summary_writer = writer_cache.FileWriterCache.get(self.model_dir)

  with self.graph.as_default():
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir, save_steps=2, scaffold=self.scaffold)
    hook.begin()
    self.scaffold.finalize()
    with session_lib.Session() as sess:
      sess.run(self.scaffold.init_op)
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(self.train_op)
    summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.model_dir,
        expected_added_meta_graphs=[
            meta_graph.create_meta_graph_def(
                graph_def=self.graph.as_graph_def(add_shapes=True),
                saver_def=self.scaffold.saver.saver_def)
        ])

  fake_summary_writer.FakeSummaryWriter.uninstall()
def train_and_evaluate(self):
  """Interleaves training and evaluation.

  The frequency of evaluation is controlled by the constructor arg
  `min_eval_frequency`. When this parameter is 0, evaluation happens only
  after training has completed. Note that evaluation cannot happen more
  frequently than checkpoints are taken. If no new snapshots are available
  when evaluation is supposed to occur, then evaluation doesn't happen for
  another `min_eval_frequency` steps (assuming a checkpoint is available at
  that point). Thus, setting `min_eval_frequency` to 1 means that the model
  will be evaluated every time there is a new checkpoint.

  This is particularly useful for a "Master" task in the cloud, whose
  responsibility it is to take checkpoints, evaluate those checkpoints, and
  write out summaries. Participating in training as the supervisor allows
  such a task to accomplish the first and last items, while performing
  evaluation allows for the second.

  Returns:
    The result of the `evaluate` call to the `Estimator` as well as the
    export results using the specified `ExportStrategy`.
  """
  # The directory to which evaluation summaries are written is determined
  # by adding a suffix to 'eval'; that suffix is the 'name' parameter to
  # the various evaluate(...) methods. By setting it to None, we force
  # the directory name to simply be 'eval'.
  eval_dir_suffix = None

  # We set every_n_steps to 1, but evaluation only occurs when a new
  # snapshot is available. If, by the time we finish evaluation
  # there is a new snapshot, then we just evaluate again. Otherwise,
  # we keep training until one becomes available.
  with _new_attr_context(self, "_train_monitors"):
    self._train_monitors = self._train_monitors or []
    config = self._estimator.config
    intermediate_export = self._checkpoint_and_export and (
        config.save_checkpoints_secs or config.save_checkpoints_steps)
    if intermediate_export:
      # Create a partially specified evaluate function with the desired
      # arguments. This will be executed by the _EvalAndExportListener,
      # which will specify the latest checkpoint path.
      eval_fn = functools.partial(
          self._call_evaluate,
          input_fn=self._eval_input_fn,
          steps=self._eval_steps,
          metrics=self._eval_metrics,
          hooks=self._eval_hooks)

      export_listener = _EvalAndExportListener(
          eval_fn=eval_fn,
          export_fn=self._maybe_export,
          model_dir=self._estimator.model_dir)

      saver_hook = basic_session_run_hooks.CheckpointSaverHook(
          checkpoint_dir=self._estimator.model_dir,
          save_secs=config.save_checkpoints_secs,
          save_steps=config.save_checkpoints_steps,
          listeners=[export_listener])
      self._train_monitors += [saver_hook]
    else:
      if self._min_eval_frequency:
        self._train_monitors += [
            monitors.ValidationMonitor(
                input_fn=self._eval_input_fn,
                eval_steps=self._eval_steps,
                metrics=self._eval_metrics,
                every_n_steps=self._min_eval_frequency,
                name=eval_dir_suffix,
                hooks=self._eval_hooks)
        ]
    self.train(delay_secs=0)

  # If the checkpoint_and_export flag and appropriate estimator configuration
  # parameters are set, then model evaluations and exports are done during
  # the training process. In particular, this will always occur at the end of
  # training, so we return the most recent results to avoid performing a
  # duplicate evaluation and model export.
  if intermediate_export:
    return export_listener.eval_result, export_listener.export_results
  else:
    eval_result = self._call_evaluate(
        input_fn=self._eval_input_fn,
        steps=self._eval_steps,
        metrics=self._eval_metrics,
        name=eval_dir_suffix,
        hooks=self._eval_hooks)
    export_results = self._maybe_export(eval_result)
    return eval_result, export_results
def MonitoredTrainingSession(master='',
                             is_chief=True,
                             checkpoint_dir=None,
                             scaffold=None,
                             hooks=None,
                             chief_only_hooks=None,
                             save_checkpoint_secs=USE_DEFAULT,
                             save_summaries_steps=USE_DEFAULT,
                             save_summaries_secs=USE_DEFAULT,
                             config=None,
                             stop_grace_period_secs=120,
                             log_step_count_steps=100,
                             save_checkpoint_steps=USE_DEFAULT,
                             summary_dir=None):
  if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT:
    save_summaries_steps = 100
    save_summaries_secs = None
  elif save_summaries_secs == USE_DEFAULT:
    save_summaries_secs = None
  elif save_summaries_steps == USE_DEFAULT:
    save_summaries_steps = None

  if (save_checkpoint_steps == USE_DEFAULT and
      save_checkpoint_secs == USE_DEFAULT):
    save_checkpoint_steps = None
    save_checkpoint_secs = 600
  elif save_checkpoint_secs == USE_DEFAULT:
    save_checkpoint_secs = None
  elif save_checkpoint_steps == USE_DEFAULT:
    save_checkpoint_steps = None

  scaffold = scaffold or Scaffold()

  all_hooks = []
  if is_chief and chief_only_hooks:
    all_hooks.extend(chief_only_hooks)
  session_creator = ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_dir=checkpoint_dir,
      master=master,
      config=config)

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir:
    if (save_summaries_steps and save_summaries_steps > 0) or (
        save_summaries_secs and save_summaries_secs > 0):
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))

  if checkpoint_dir:
    if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
        save_checkpoint_steps and save_checkpoint_steps > 0):
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold))

  if hooks:
    all_hooks.extend(hooks)

  hvd_info_rank0('all hooks {}'.format(all_hooks))
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)
def train(args):
  """Train CIFAR-10 for a number of steps.

  Args:
    args: The command line arguments.
  """
  with tf.Graph().as_default(), tf.device('/cpu:0'):
    # Create the global step.
    global_step = tf.contrib.framework.create_global_step()

    # Calculate the learning rate schedule.
    num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                             args.batch_size)
    decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY)

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE,
                                    global_step,
                                    decay_steps,
                                    cifar10.LEARNING_RATE_DECAY_FACTOR,
                                    staircase=True)

    # Create an optimizer that performs gradient descent.
    opt = tf.train.GradientDescentOptimizer(lr)

    # Calculate the gradients for each model tower.
    tower_grads = []
    for i in xrange(args.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
          # Calculate the loss for one tower of the CIFAR model. This
          # function constructs the entire CIFAR model but shares the
          # variables across all towers.
          loss = tower_loss(scope, args)

          # Reuse variables for the next tower.
          tf.get_variable_scope().reuse_variables()

          # Retain the summaries from the final tower.
          summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)

          # Calculate the gradients for the batch of data on this tower.
          grads = opt.compute_gradients(loss)

          # Keep track of the gradients across all towers.
          tower_grads.append(grads)

    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grads = average_gradients(tower_grads)

    # Add a summary to track the learning rate.
    summaries.append(tf.summary.scalar('learning_rate', lr))

    # Add histograms for gradients.
    for grad, var in grads:
      if grad is not None:
        summaries.append(
            tf.summary.histogram(var.op.name + '/gradients', grad))

    # Apply the gradients to adjust the shared variables.
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # Add histograms for trainable variables.
    for var in tf.trainable_variables():
      summaries.append(tf.summary.histogram(var.op.name, var))

    # Track the moving averages of all trainable variables.
    # To understand why the following line is necessary, see:
    # https://github.com/carpedm20/DCGAN-tensorflow/issues/59
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
      variable_averages = tf.train.ExponentialMovingAverage(
          cifar10.MOVING_AVERAGE_DECAY, global_step)
      variables_averages_op = variable_averages.apply(
          tf.trainable_variables())

    # Group all updates into a single train op.
    train_op = tf.group(apply_gradient_op, variables_averages_op)

    # Build the summary operation from the last tower summaries.
    summary_op = tf.summary.merge(summaries)

    scaffold = monitored_session.Scaffold(summary_op=summary_op)

    # allow_soft_placement must be set to True to build towers on GPU, as
    # some of the ops do not have GPU implementations.
    session_creator = monitored_session.ChiefSessionCreator(
        scaffold,
        checkpoint_dir=args.train_dir,
        config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=args.log_device_placement))

    hooks = [
        # Hook to save the model every N steps and at the end.
        basic_session_run_hooks.CheckpointSaverHook(
            args.train_dir,
            checkpoint_basename=CHECKPOINT_BASENAME,
            save_steps=args.checkpoint_interval_steps,
            scaffold=scaffold),
        # Hook to save a summary every N steps.
        basic_session_run_hooks.SummarySaverHook(
            save_steps=args.summary_interval_steps,
            output_dir=args.train_dir,
            scaffold=scaffold),
        # Hook to stop at step N.
        basic_session_run_hooks.StopAtStepHook(last_step=args.train_max_steps)
    ]

    # Start a new monitored session. This will automatically restart the
    # sessions if the parameter servers are preempted.
    with monitored_session.MonitoredSession(
        session_creator=session_creator, hooks=hooks) as sess:
      while not sess.should_stop():
        start_time = time.time()
        _, loss_value, global_step_value = sess.run(
            [train_op, loss, global_step])
        duration = time.time() - start_time

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if global_step_value % 10 == 0:
          num_examples_per_step = args.batch_size * args.num_gpus
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = duration / args.num_gpus

          format_str = ('%s: step %d, loss = %.2f '
                        '(%.1f examples/sec; %.3f sec/batch)')
          logging.info(format_str % (datetime.now(), global_step_value,
                                     loss_value, examples_per_sec,
                                     sec_per_batch))
def test_raise_in_none_secs_and_steps(self):
  with self.assertRaises(ValueError):
    basic_session_run_hooks.CheckpointSaverHook(self.model_dir)
def test_raise_when_saver_and_scaffold_both_present(self):
  with self.assertRaises(ValueError):
    basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir, saver=self.scaffold.saver, scaffold=self.scaffold)
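# Taken together, these two tests pin down the constructor contract in the
# version they target: one of `save_steps`/`save_secs` must be given, and at
# most one of `saver`/`scaffold`. Two valid constructions for contrast (the
# paths and saver/scaffold objects are illustrative):
hook_a = basic_session_run_hooks.CheckpointSaverHook(
    '/tmp/model_dir', save_steps=100, saver=tf.train.Saver())
hook_b = basic_session_run_hooks.CheckpointSaverHook(
    '/tmp/model_dir', save_secs=600, scaffold=tf.train.Scaffold())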
def testTrainWithAlteredGradients(self):
  # Use the same learning rate but different gradient multipliers to train
  # two models. The model with the equivalently larger learning rate (i.e.,
  # learning_rate * gradient_multiplier) has a smaller training loss.
  logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs6/')
  logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs7/')

  if gfile.Exists(logdir1):
    gfile.DeleteRecursively(logdir1)
  if gfile.Exists(logdir2):
    gfile.DeleteRecursively(logdir2)

  multipliers = [1., 1000.]
  number_of_steps = 10
  losses = []
  learning_rate = 0.001

  # First, train the model with the equivalently smaller learning rate.
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    train_op = self.create_train_op(
        learning_rate=learning_rate, gradient_multiplier=multipliers[0])

    saver = saver_lib.Saver()

    loss = training.train(
        train_op,
        logdir1,
        hooks=[
            basic_session_run_hooks.StopAtStepHook(num_steps=number_of_steps),
            basic_session_run_hooks.CheckpointSaverHook(
                logdir1, save_steps=50, saver=saver),
        ])

    losses.append(loss)
    self.assertGreater(loss, .5)

  # Second, train the model with the equivalently larger learning rate.
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    train_op = self.create_train_op(
        learning_rate=learning_rate, gradient_multiplier=multipliers[1])
    saver = saver_lib.Saver()

    loss = training.train(
        train_op,
        logdir2,
        hooks=[
            basic_session_run_hooks.StopAtStepHook(num_steps=number_of_steps),
            basic_session_run_hooks.CheckpointSaverHook(
                logdir2, save_steps=50, saver=saver),
        ])

    losses.append(loss)
    self.assertIsNotNone(loss)
    self.assertLess(loss, .5)

  # The loss of the model trained with the larger learning rate should
  # be smaller.
  self.assertGreater(losses[0], losses[1])
def testTrainWithInitFromCheckpoint(self):
  logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
  logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

  if gfile.Exists(logdir1):  # For running on jenkins.
    gfile.DeleteRecursively(logdir1)
  if gfile.Exists(logdir2):  # For running on jenkins.
    gfile.DeleteRecursively(logdir2)

  # First, train the model one step (make sure the error is high).
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    train_op = self.create_train_op()
    saver = saver_lib.Saver()
    loss = training.train(
        train_op,
        logdir1,
        hooks=[
            basic_session_run_hooks.CheckpointSaverHook(
                logdir1, save_steps=1, saver=saver),
            basic_session_run_hooks.StopAtStepHook(num_steps=1),
        ],
        save_checkpoint_secs=None)
    self.assertGreater(loss, .5)

  # Next, train the model to convergence.
  with ops.Graph().as_default():
    random_seed.set_random_seed(1)
    train_op = self.create_train_op()
    saver = saver_lib.Saver()
    loss = training.train(
        train_op,
        logdir1,
        hooks=[
            basic_session_run_hooks.CheckpointSaverHook(
                logdir1, save_steps=1, saver=saver),
            basic_session_run_hooks.StopAtStepHook(num_steps=300),
        ],
        save_checkpoint_secs=None)
    self.assertIsNotNone(loss)
    self.assertLess(loss, .02)

  # Finally, advance the model a single step and validate that the loss is
  # still low.
  with ops.Graph().as_default():
    random_seed.set_random_seed(2)
    train_op = self.create_train_op()

    model_variables = variables_lib2.global_variables()
    model_path = os.path.join(logdir1, 'model.ckpt-300')

    assign_fn = variables_lib.assign_from_checkpoint_fn(model_path,
                                                        model_variables)

    def init_fn(_, session):
      assign_fn(session)

    loss = training.train(
        train_op,
        logdir2,
        scaffold=monitored_session.Scaffold(init_fn=init_fn),
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])

    self.assertIsNotNone(loss)
    self.assertLess(loss, .02)
def _monitored_train(graph,
                     output_dir,
                     train_op,
                     loss_op,
                     global_step_tensor=None,
                     init_op=None,
                     init_feed_dict=None,
                     init_fn=None,
                     log_every_steps=10,
                     supervisor_is_chief=True,
                     supervisor_master='',
                     supervisor_save_model_secs=600,
                     supervisor_save_model_steps=None,
                     keep_checkpoint_max=5,
                     supervisor_save_summaries_secs=None,
                     supervisor_save_summaries_steps=100,
                     feed_fn=None,
                     steps=None,
                     fail_on_nan_loss=True,
                     hooks=None,
                     max_steps=None):
  """Trains a model via monitored_session.

  Given `graph`, a directory to write outputs to (`output_dir`), and some ops,
  run a training loop. The given `train_op` performs one step of training on
  the model and is expected to increment the `global_step_tensor`, a scalar
  integer tensor counting training steps. The `loss_op` represents the
  objective function of the training.

  This function uses `Supervisor` to initialize the graph (from a checkpoint
  if one is available in `output_dir`), write summaries defined in the graph,
  and write regular checkpoints as defined by `supervisor_save_model_secs`.

  Training continues until `global_step_tensor` evaluates to `max_steps`, or,
  if `fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In that case the
  program is terminated with exit code 1.

  Args:
    graph: A graph to train. It is expected that this graph is not in use
      elsewhere.
    output_dir: A directory to write outputs to.
    train_op: An op that performs one training step when run.
    loss_op: A scalar loss tensor.
    global_step_tensor: A tensor representing the global step. If none is
      given, one is extracted from the graph using the same logic as in
      `Supervisor`.
    init_op: An op that initializes the graph. If `None`, use `Supervisor`'s
      default.
    init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
      This feed dictionary will be used when `init_op` is evaluated.
    init_fn: Optional callable passed to Supervisor to initialize the model.
    log_every_steps: Output logs regularly. The logs contain timing data and
      the current loss. A `0` or negative value disables logging.
    supervisor_is_chief: Whether the current process is the chief supervisor
      in charge of restoring the model and running standard services.
    supervisor_master: The master string to use when preparing the session.
    supervisor_save_model_secs: Save checkpoints every this many seconds.
      Cannot be specified with `supervisor_save_model_steps`.
    supervisor_save_model_steps: Save checkpoints every this many steps.
      Cannot be specified with `supervisor_save_model_secs`.
    keep_checkpoint_max: The maximum number of recent checkpoint files to
      keep. As new files are created, older files are deleted. If None or 0,
      all checkpoint files are kept. This is simply passed as the max_to_keep
      arg to the `tf.Saver` constructor.
    supervisor_save_summaries_secs: Save summaries every
      `supervisor_save_summaries_secs` seconds when training.
    supervisor_save_summaries_steps: Save summaries every
      `supervisor_save_summaries_steps` steps when training. Exactly one of
      `supervisor_save_model_steps` and `supervisor_save_model_secs` should be
      specified, and the other should be None.
    feed_fn: A function that is called every iteration to produce a
      `feed_dict` passed to `session.run` calls. Optional.
    steps: Trains for this many steps (e.g. current global step + `steps`).
    fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
      evaluates to `NaN`. If false, continue training as if nothing happened.
    hooks: List of `SessionRunHook` subclass instances. Used for callbacks
      inside the training loop.
    max_steps: Number of total steps for which to train the model. If `None`,
      train forever. Two calls to fit(steps=100) means 200 training
      iterations. On the other hand, two calls to fit(max_steps=100) means
      that the second call will not do any iteration, since the first call
      did all 100 steps.

  Returns:
    The final loss value.

  Raises:
    ValueError: If `output_dir`, `train_op`, `loss_op`, or
      `global_step_tensor` is not provided. See
      `tf.contrib.framework.get_global_step` for how we look up the latter if
      not provided explicitly.
    NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
      evaluates to `NaN`.
    ValueError: If both `steps` and `max_steps` are not `None`.
  """
  if (steps is not None) and (max_steps is not None):
    raise ValueError('Can not provide both steps and max_steps.')
  if not output_dir:
    raise ValueError('Output directory should be non-empty %s.' % output_dir)
  if train_op is None:
    raise ValueError('Missing train_op.')
  if loss_op is None:
    raise ValueError('Missing loss_op.')
  if hooks is None:
    hooks = []
  if not isinstance(hooks, list):
    raise ValueError('Hooks should be a list.')

  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)
  if global_step_tensor is None:
    raise ValueError('No "global_step" was provided or found in the graph.')

  if max_steps is not None:
    try:
      start_step = load_variable(output_dir, global_step_tensor.name)
      if max_steps <= start_step:
        logging.info('Skipping training since max_steps has already been '
                     'reached.')
        return None
    except:  # pylint: disable=bare-except
      pass

  # Adapted SessionRunHooks such as ExportMonitor depend on the
  # CheckpointSaverHook to be executed before they should be executed.
  # The `hooks` param comprises deprecated monitor hooks
  # (such as ExportMonitor). Appending them after the basic_session_run_hooks.
  all_hooks = []
  with graph.as_default():
    all_hooks.append(
        basic_session_run_hooks.NanTensorHook(
            loss_op, fail_on_nan_loss=fail_on_nan_loss))
    if log_every_steps > 0:
      all_hooks.append(
          basic_session_run_hooks.LoggingTensorHook(
              {
                  'loss': loss_op.name,
                  'step': global_step_tensor.name
              },
              every_n_iter=log_every_steps))

    def make_saver():
      return tf_saver.Saver(
          sharded=True,
          max_to_keep=keep_checkpoint_max,
          defer_build=True,
          write_version=saver_pb2.SaverDef.V1)

    scaffold = monitored_session.Scaffold(
        init_op=init_op,
        init_feed_dict=init_feed_dict,
        init_fn=init_fn,
        saver=monitored_session.Scaffold.get_or_default(
            'saver', ops.GraphKeys.SAVERS, make_saver))

    if not supervisor_is_chief:
      session_creator = monitored_session.WorkerSessionCreator(
          scaffold=scaffold, master=supervisor_master)
    else:
      session_creator = monitored_session.ChiefSessionCreator(
          scaffold=scaffold,
          checkpoint_dir=output_dir,
          master=supervisor_master)
      summary_writer = summary_io.SummaryWriterCache.get(output_dir)
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              summary_writer=summary_writer))
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              save_secs=supervisor_save_summaries_secs,
              save_steps=supervisor_save_summaries_steps,
              summary_writer=summary_writer,
              scaffold=scaffold))
      if (supervisor_save_model_secs is not None or
          supervisor_save_model_steps is not None):
        all_hooks.append(
            basic_session_run_hooks.CheckpointSaverHook(
                output_dir,
                save_secs=supervisor_save_model_secs,
                save_steps=supervisor_save_model_steps,
                scaffold=scaffold))

    if steps is not None or max_steps is not None:
      all_hooks.append(
          basic_session_run_hooks.StopAtStepHook(steps, max_steps))
    all_hooks.extend(hooks)

    with monitored_session.MonitoredSession(
        session_creator=session_creator, hooks=all_hooks) as super_sess:
      loss = None
      while not super_sess.should_stop():
        _, loss = super_sess.run([train_op, loss_op],
                                 feed_fn() if feed_fn else None)

    summary_io.SummaryWriterCache.clear()
    return loss
def PartialRestoreSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    restore_var_list=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=monitored_session.USE_DEFAULT,
    save_summaries_secs=monitored_session.USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100):
  """Creates a `MonitoredSession` for training.

  Supports partial restoration from checkpoints via the parameter
  `restore_var_list`, by adding a `CheckpointRestorerHook`.

  For a chief, this utility sets proper session initializer/restorer. It also
  creates hooks related to checkpoint and summary saving. For workers, this
  utility sets a proper session creator which waits for the chief to
  initialize/restore. Please check `tf.train.MonitoredSession` for more
  information.

  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief
      to initialize or recover the TensorFlow session.
    checkpoint_dir: A string. Optional path to a directory where to restore
      variables.
    restore_var_list: An optional list of variables; set it if not all
      variables should be recovered from the checkpoint. Useful when changing
      network structures during training, i.e., finetuning a pretrained model
      with new layers.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
      not specified, a default one is created. It's used to finalize the
      graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks
      if `is_chief==True`, ignore otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
      saved using a default checkpoint saver. If `save_checkpoint_secs` is set
      to `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default 100.
    save_summaries_secs: The frequency, in secs, that the summaries are
      written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default not enabled.
    config: an instance of `tf.ConfigProto` proto used to configure the
      session. It's the `config` argument of constructor of `tf.Session`.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.

  Returns:
    A `MonitoredSession` object.
  """
  if (save_summaries_steps == monitored_session.USE_DEFAULT and
      save_summaries_secs == monitored_session.USE_DEFAULT):
    save_summaries_steps = 100
    save_summaries_secs = None
  elif save_summaries_secs == monitored_session.USE_DEFAULT:
    save_summaries_secs = None
  elif save_summaries_steps == monitored_session.USE_DEFAULT:
    save_summaries_steps = None

  scaffold = scaffold or monitored_session.Scaffold()
  if not is_chief:
    session_creator = monitored_session.WorkerSessionCreator(
        scaffold=scaffold, master=master, config=config)
    return monitored_session.MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  if chief_only_hooks:
    all_hooks.extend(chief_only_hooks)

  if restore_var_list is None:
    restore_checkpoint_dir = checkpoint_dir
  else:
    restore_checkpoint_dir = None
    all_hooks.append(
        CheckpointRestorerHook(checkpoint_dir, var_list=restore_var_list))
    all_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    missing_vars = [v for v in all_vars if v not in restore_var_list]
    logging.warning("MonitoredTrainingSession not restoring %s", missing_vars)

  session_creator = monitored_session.ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_dir=restore_checkpoint_dir,
      master=master,
      config=config)

  if checkpoint_dir:
    all_hooks.append(
        basic_session_run_hooks.StepCounterHook(
            output_dir=checkpoint_dir, every_n_steps=log_step_count_steps))

    if (save_summaries_steps and save_summaries_steps > 0) or (
        save_summaries_secs and save_summaries_secs > 0):
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=checkpoint_dir))

    if save_checkpoint_secs and save_checkpoint_secs > 0:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold))

  if hooks:
    all_hooks.extend(hooks)

  return monitored_session.MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)
def train(train_op,
          logdir,
          master='',
          is_chief=True,
          scaffold=None,
          hooks=None,
          chief_only_hooks=None,
          save_checkpoint_secs=600,
          save_summaries_steps=100,
          config=None):
  """Runs the training loop.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where the graph and checkpoints are saved.
    master: The URL of the master.
    is_chief: Specifies whether or not the training is being run by the
      primary replica during replica training.
    scaffold: A tf.train.Scaffold instance.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside
      the training loop.
    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are
      run inside the training loop for the chief trainer only.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
      saved using a default checkpoint saver. If `save_checkpoint_secs` is
      set to `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If
      `save_summaries_steps` is set to `None`, then the default summary saver
      isn't used.
    config: An instance of `tf.ConfigProto`.

  Returns:
    the value of the loss function after training.

  Raises:
    ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
      `save_summaries_steps` is not `None`.
  """
  # TODO(nsilberman): move this logic into monitored_session.py
  scaffold = scaffold or monitored_session.Scaffold()

  hooks = hooks or []

  if is_chief:
    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=scaffold, checkpoint_dir=logdir, master=master, config=config)

    if chief_only_hooks:
      hooks.extend(chief_only_hooks)

    hooks.append(basic_session_run_hooks.StepCounterHook(output_dir=logdir))

    if save_summaries_steps:
      if logdir is None:
        raise ValueError(
            'logdir cannot be None when save_summaries_steps is set')
      hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              output_dir=logdir))

    if save_checkpoint_secs:
      if logdir is None:
        raise ValueError(
            'logdir cannot be None when save_checkpoint_secs is set')
      hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
  else:
    session_creator = monitored_session.WorkerSessionCreator(
        scaffold=scaffold, master=master, config=config)

  with monitored_session.MonitoredSession(
      session_creator=session_creator, hooks=hooks) as session:
    loss = None
    while not session.should_stop():
      loss = session.run(train_op)
  return loss
def _model_fn(features, labels, mode):
  """Function that returns predictions, training loss, and training op."""
  weights = None
  if weights_name and weights_name in features:
    weights = features.pop(weights_name)

  keys = None
  if keys_name and keys_name in features:
    keys = features.pop(keys_name)

  # If we're doing eval, optionally ignore device_assigner.
  # Also ignore device assigner if we're exporting (mode == INFER).
  dev_assn = device_assigner
  if (mode == model_fn_lib.ModeKeys.INFER or
      (local_eval and mode == model_fn_lib.ModeKeys.EVAL)):
    dev_assn = None

  graph_builder = graph_builder_class(params, device_assigner=dev_assn)
  inference = {}
  output_alternatives = None
  if (mode == model_fn_lib.ModeKeys.EVAL or
      mode == model_fn_lib.ModeKeys.INFER):
    inference[eval_metrics.INFERENCE_PROB_NAME] = (
        graph_builder.inference_graph(features))

    if params.regression:
      predictions = {None: inference[eval_metrics.INFERENCE_PROB_NAME]}
      output_alternatives = {
          None: (constants.ProblemType.LINEAR_REGRESSION, predictions)
      }
    else:
      inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax(
          inference[eval_metrics.INFERENCE_PROB_NAME], 1)

      predictions = {
          prediction_key.PredictionKey.PROBABILITIES:
              inference[eval_metrics.INFERENCE_PROB_NAME],
          prediction_key.PredictionKey.CLASSES:
              inference[eval_metrics.INFERENCE_PRED_NAME]
      }
      output_alternatives = {
          None: (constants.ProblemType.CLASSIFICATION, predictions)
      }

    if keys is not None:
      inference[keys_name] = keys

  # labels might be None if we're doing prediction (which brings up the
  # question of why we force everything to adhere to a single model_fn).
  loss_deps = []
  training_graph = None
  training_hooks = []
  scaffold = None
  if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
    training_graph = control_flow_ops.group(
        graph_builder.training_graph(
            features,
            labels,
            input_weights=weights,
            num_trainers=num_trainers,
            trainer_id=trainer_id),
        state_ops.assign_add(contrib_framework.get_global_step(), 1))
    loss_deps.append(training_graph)
    if hasattr(graph_builder, 'finalize_training'):
      finalize_listener = EveryCheckpointPreSaveListener(
          graph_builder.finalize_training())
      scaffold = monitored_session.Scaffold()
      training_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              model_dir,
              save_secs=600,
              save_steps=None,
              scaffold=scaffold,
              listeners=[finalize_listener]))

  training_loss = None
  if (mode == model_fn_lib.ModeKeys.EVAL or
      mode == model_fn_lib.ModeKeys.TRAIN):
    with ops.control_dependencies(loss_deps):
      training_loss = graph_builder.training_loss(
          features, labels, name=LOSS_NAME)

  # Put weights back in.
  if weights is not None:
    features[weights_name] = weights

  if early_stopping_rounds:
    training_hooks.append(TensorForestLossHook(early_stopping_rounds))

  if report_feature_importances:
    training_hooks.append(
        TensorForestRunOpAtEndHook(
            {'feature_importances': graph_builder.feature_importances()}))

  return model_fn_lib.ModelFnOps(
      mode=mode,
      predictions=inference,
      loss=training_loss,
      train_op=training_graph,
      training_hooks=training_hooks,
      scaffold=scaffold,
      output_alternatives=output_alternatives)
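# `EveryCheckpointPreSaveListener` is referenced above but not defined in this
# section. In the tensor_forest code it is essentially a
# `CheckpointSaverListener` that runs a given op before every save; a minimal
# sketch:
class EveryCheckpointPreSaveListener(
    basic_session_run_hooks.CheckpointSaverListener):
  """Runs a given op before each checkpoint save."""

  def __init__(self, op):
    """Initializes the object.

    Args:
      op: An op to run before each checkpoint save.
    """
    self._op = op

  def before_save(self, session, global_step_value):
    session.run(self._op)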
def random_forest_model_fn(features, labels, mode, params, config):
  """Function that returns predictions, training loss, and training op."""
  labels_tensor = labels
  if isinstance(labels, dict) and len(labels) == 1:
    labels_tensor = list(labels.values())[0]

  weights_name = params["weights_name"]
  keys_name = params["keys_name"]
  num_classes = tf.identity(params['num_classes'], name='num_classes')

  params_toGraphs = tensor_forest.ForestHParams(
      num_classes=params['num_classes'],
      num_features=params['num_features'],
      num_trees=params['num_trees'],
      max_nodes=params['max_nodes'],
      regression=params['regression'],
      split_after_samples=params['split_after_samples'])
  # Note the fill() call around line 90 of:
  # https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/contrib
  # /tensor_forest/python/tensor_forest.py
  params_toGraphs = params_toGraphs.fill()
  graph_builder_class = tensor_forest.RandomForestGraphs

  early_stopping_rounds = params["early_stopping_rounds"]
  num_trainers = 1
  trainer_id = 0
  report_feature_importances = False
  model_dir = None
  local_eval = False
  device_assigner = None

  weights = None
  if weights_name and weights_name in features:
    weights = features.pop(weights_name)

  keys = None
  if keys_name and keys_name in features:
    keys = features.pop(keys_name)

  # If we're doing eval, optionally ignore device_assigner.
  # Also ignore device assigner if we're exporting (mode == INFER).
  dev_assn = device_assigner
  if (mode == model_fn_lib.ModeKeys.INFER or
      (local_eval and mode == model_fn_lib.ModeKeys.EVAL)):
    dev_assn = None

  graph_builder = graph_builder_class(params_toGraphs,
                                      device_assigner=dev_assn)

  inference = {}
  predictions = {}
  output_alternatives = None

  # if (mode == model_fn_lib.ModeKeys.EVAL or
  #     mode == model_fn_lib.ModeKeys.INFER):
  if True:
    inference[eval_metrics.INFERENCE_PROB_NAME] = (
        graph_builder.inference_graph(features))

    if params_toGraphs.regression:
      predictions = {None: inference[eval_metrics.INFERENCE_PROB_NAME]}
      output_alternatives = {
          None: (constants.ProblemType.LINEAR_REGRESSION, predictions)
      }
    else:
      inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax(
          inference[eval_metrics.INFERENCE_PROB_NAME], 1)
      predictions = {
          prediction_key.PredictionKey.PROBABILITIES:
              inference[eval_metrics.INFERENCE_PROB_NAME],
          prediction_key.PredictionKey.CLASSES:
              inference[eval_metrics.INFERENCE_PRED_NAME]
      }
      output_alternatives = {
          None: (constants.ProblemType.CLASSIFICATION, predictions)
      }

    if report_feature_importances:
      inference[eval_metrics.FEATURE_IMPORTANCE_NAME] = (
          graph_builder.feature_importances())

    if keys is not None:
      inference[keys_name] = keys

  # labels might be None if we're doing prediction (which brings up the
  # question of why we force everything to adhere to a single model_fn).
  loss_deps = []
  training_graph = None
  training_hooks = []
  scaffold = None
  if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
    training_graph = control_flow_ops.group(
        graph_builder.training_graph(
            features,
            labels,
            input_weights=weights,
            num_trainers=num_trainers,
            trainer_id=trainer_id),
        state_ops.assign_add(contrib_framework.get_global_step(), 1))
    loss_deps.append(training_graph)
    if hasattr(graph_builder, 'finalize_training'):
      finalize_listener = EveryCheckpointPreSaveListener(
          graph_builder.finalize_training())
      scaffold = monitored_session.Scaffold()
      training_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              model_dir,
              save_secs=600,
              save_steps=None,
              scaffold=scaffold,
              listeners=[finalize_listener]))

  training_loss = None
  if (mode == model_fn_lib.ModeKeys.EVAL or
      mode == model_fn_lib.ModeKeys.TRAIN):
    with ops.control_dependencies(loss_deps):
      # Named so the loss tensor can be looked up and passed to a hook.
      training_loss = graph_builder.training_loss(
          features, labels, name='rf_training_loss')
      if not params['regression']:
        confusion_matrix_print = confusion_matrix(
            labels=labels_tensor,
            predictions=predictions['classes'],
            num_classes=num_classes,
        )
        confusion_matrix_print = tf.identity(
            confusion_matrix_print, name='confusion_matrix_print')
      else:
        confusion_matrix_print = tf.identity(
            0, name='confusion_matrix_print')
      regression_ornot = tf.identity(
          params['regression'], name='regression_ornot')

  # Put weights back in.
  if weights is not None:
    features[weights_name] = weights

  if early_stopping_rounds:
    training_hooks.append(TensorForestLossHook(early_stopping_rounds))

  metrics = {}
  # metrics[metric_key.MetricKey.AUC] = metrics_lib.streaming_auc(
  #     labels=labels_tensor,
  #     predictions=inference[eval_metrics.INFERENCE_PRED_NAME])
  if not params_toGraphs.regression:
    metrics['eval_confusion_matrix'] = confusion_matrix(
        labels=labels_tensor,
        predictions=predictions['classes'],
        num_classes=params['num_classes'],
    )

  return model_fn_lib.ModelFnOps(
      mode=mode,
      predictions=inference,
      loss=training_loss,
      train_op=training_graph,
      training_hooks=training_hooks,
      scaffold=scaffold,
      eval_metric_ops=metrics,
      output_alternatives=output_alternatives)
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=100,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100):
  """Creates a `MonitoredSession` for training.

  For a chief, this utility sets proper session initializer/restorer. It also
  creates hooks related to checkpoint and summary saving. For workers, this
  utility sets a proper session creator which waits for the chief to
  initialize/restore.

  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief
      to initialize or recover the TensorFlow session.
    checkpoint_dir: A string. Optional path to a directory where to restore
      variables.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
      not specified, a default one is created. It's used to finalize the
      graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks
      if `is_chief==True`, ignore otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
      saved using a default checkpoint saver. If `save_checkpoint_secs` is set
      to `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used.
    save_summaries_secs: The frequency, in secs, that the summaries are
      written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used.
    config: an instance of `tf.ConfigProto` proto used to configure the
      session. It's the `config` argument of constructor of `tf.Session`.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.

  Returns:
    A `MonitoredSession` object.
  """
  scaffold = scaffold or Scaffold()
  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold, master=master, config=config)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  if chief_only_hooks:
    all_hooks.extend(chief_only_hooks)
  session_creator = ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_dir=checkpoint_dir,
      master=master,
      config=config)

  if checkpoint_dir:
    all_hooks.append(
        basic_session_run_hooks.StepCounterHook(
            output_dir=checkpoint_dir, every_n_steps=log_step_count_steps))

    if (save_summaries_steps and save_summaries_steps > 0) or (
        save_summaries_secs and save_summaries_secs > 0):
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=checkpoint_dir))

    if save_checkpoint_secs and save_checkpoint_secs > 0:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold))

  if hooks:
    all_hooks.extend(hooks)

  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)
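# A sketch of how chief and worker processes might call the utility above in
# a distributed setup (the master addresses and `train_op` are illustrative):
#
# Chief: initializes/restores the session and owns the checkpoint and
# summary hooks.
with MonitoredTrainingSession(master='grpc://chief:2222',
                              is_chief=True,
                              checkpoint_dir='/tmp/train_logs',
                              save_checkpoint_secs=600) as sess:
  while not sess.should_stop():
    sess.run(train_op)

# Worker: waits for the chief to initialize/recover the session; no saving
# hooks are created.
with MonitoredTrainingSession(master='grpc://worker0:2222',
                              is_chief=False) as sess:
  while not sess.should_stop():
    sess.run(train_op)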
def train(args):
  """Train CIFAR-10 for a number of steps.

  Args:
    args: The command line arguments.
  """
  with tf.Graph().as_default():
    # Create the global step.
    global_step = tf.contrib.framework.create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs(args.data_dir, args.batch_size,
                                              args.use_fp16)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images, args.batch_size, args.use_fp16)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step, args.batch_size)

    scaffold = monitored_session.Scaffold()

    session_creator = monitored_session.ChiefSessionCreator(
        scaffold,
        checkpoint_dir=args.train_dir,
        config=tf.ConfigProto(
            log_device_placement=args.log_device_placement))

    hooks = [
        # Hook to save the model every N steps and at the end.
        basic_session_run_hooks.CheckpointSaverHook(
            args.train_dir,
            checkpoint_basename=CHECKPOINT_BASENAME,
            save_steps=args.checkpoint_interval_steps,
            scaffold=scaffold),
        # Hook to save a summary every N steps.
        basic_session_run_hooks.SummarySaverHook(
            save_steps=args.summary_interval_steps,
            output_dir=args.train_dir,
            scaffold=scaffold),
        # Hook to stop at step N.
        basic_session_run_hooks.StopAtStepHook(last_step=args.train_max_steps)
    ]

    # Start a new monitored session. This will automatically restart the
    # sessions if the parameter servers are preempted.
    with monitored_session.MonitoredSession(
        session_creator=session_creator, hooks=hooks) as sess:
      while not sess.should_stop():
        start_time = time.time()
        _, loss_value, global_step_value = sess.run(
            [train_op, loss, global_step])
        duration = time.time() - start_time

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if global_step_value % 10 == 0:
          num_examples_per_step = args.batch_size
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          logging.info(
              '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)',
              datetime.now(), global_step_value, loss_value,
              examples_per_sec, sec_per_batch)
def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
  logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
  if gfile.Exists(logdir):  # For running on jenkins.
    gfile.DeleteRecursively(logdir)

  # First, train only the weights of the model.
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    total_loss = self.ModelLoss()
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    weights = variables_lib.get_variables_by_name('weights')

    train_op = training.create_train_op(
        total_loss, optimizer, variables_to_train=weights)

    saver = saver_lib.Saver()
    loss = training.train(
        train_op,
        logdir,
        hooks=[
            basic_session_run_hooks.CheckpointSaverHook(
                logdir, save_steps=1, saver=saver),
            basic_session_run_hooks.StopAtStepHook(num_steps=200),
        ])
    self.assertGreater(loss, .015)
    self.assertLess(loss, .05)

  # Next, train the biases of the model.
  with ops.Graph().as_default():
    random_seed.set_random_seed(1)
    total_loss = self.ModelLoss()
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    biases = variables_lib.get_variables_by_name('biases')

    train_op = training.create_train_op(
        total_loss, optimizer, variables_to_train=biases)

    saver = saver_lib.Saver()
    loss = training.train(
        train_op,
        logdir,
        hooks=[
            basic_session_run_hooks.CheckpointSaverHook(
                logdir, save_steps=1, saver=saver),
            basic_session_run_hooks.StopAtStepHook(num_steps=300),
        ])
    self.assertGreater(loss, .015)
    self.assertLess(loss, .05)

  # Finally, train both weights and biases to get a lower loss.
  with ops.Graph().as_default():
    random_seed.set_random_seed(2)
    total_loss = self.ModelLoss()
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

    train_op = training.create_train_op(total_loss, optimizer)
    saver = saver_lib.Saver()
    loss = training.train(
        train_op,
        logdir,
        hooks=[
            basic_session_run_hooks.CheckpointSaverHook(
                logdir, save_steps=1, saver=saver),
            basic_session_run_hooks.StopAtStepHook(num_steps=400),
        ])
    self.assertIsNotNone(loss)
    self.assertLess(loss, .015)
def test_raise_when_saver_and_scaffold_both_missing(self):
  with self.assertRaises(ValueError):
    basic_session_run_hooks.CheckpointSaverHook(self.model_dir)
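# For contrast with the ValueError case above, a hedged sketch of the two
# valid ways to construct the hook: pass an explicit `saver`, or pass a
# `scaffold` (but never both). `model_dir` stands in for self.model_dir.
#
#   hook = basic_session_run_hooks.CheckpointSaverHook(
#       model_dir, save_steps=1, saver=saver_lib.Saver())
#
#   hook = basic_session_run_hooks.CheckpointSaverHook(
#       model_dir, save_steps=1, scaffold=monitored_session.Scaffold())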
def __init__(self, estimator, external_state_policy=None):
  """Initializes a `CheckpointInputPipelineHook`.

  If the input pipeline depends on external state (e.g. seeds for
  RandomUniform) beyond the input pipeline, this hook would be unable to
  serialize and deserialize that state. If it's acceptable to ignore that
  state, change the `external_state_policy` argument to 'warn' or 'ignore'.
  For example:

  ```python
  est = tf.estimator.Estimator(model_fn)
  while True:
    est.train(
        train_input_fn,
        hooks=[tf.data.experimental.CheckpointInputPipelineHook(
            est, external_state_policy='warn')],
        steps=train_steps_per_eval)
    # Note: We do not pass the hook here.
    metrics = est.evaluate(eval_input_fn)
    if should_stop_the_training(metrics):
      break
  ```

  Args:
    estimator: Estimator.
    external_state_policy: A string that identifies how to handle input
      pipelines that depend on external state. Possible values are
      'ignore': The external state is silently ignored.
      'warn': The external state is ignored, logging a warning.
      'fail': The operation fails upon encountering external state.
      By default we set it to 'fail'.

  Raises:
    ValueError: One of `save_steps` or `save_secs` should be set.
    ValueError: At most one of `saver` or `scaffold` should be set.
    ValueError: If `external_state_policy` is not one of 'warn', 'ignore'
      or 'fail'.
  """
  if external_state_policy is None:
    external_state_policy = "fail"
  self._external_state_policy = _convert_external_state_policy_to_enum(
      external_state_policy)
  # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
  # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
  # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
  # "model.ckpt". We intentionally choose the input pipeline checkpoint
  # prefix to be different to avoid conflicts with the model checkpoint.

  # pylint: disable=protected-access
  checkpoint_prefix = "input"
  if estimator._config.num_worker_replicas > 1:
    # Distributed setting.
    suffix = "_{}_{}".format(estimator._config.task_type,
                             estimator._config.task_id)
    checkpoint_prefix += suffix
  # pylint: enable=protected-access

  # We use a composition paradigm instead of inheriting from
  # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
  # to check whether a `CheckpointSaverHook` is already present in the list
  # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
  # would thwart this behavior. This hook checkpoints *only the iterators*
  # and not the graph variables.
  self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
      estimator.model_dir,
      save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
      save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
      checkpoint_basename=checkpoint_prefix + ".ckpt")

  # Name for the protocol buffer file that will contain the list of most
  # recent checkpoints stored as a `CheckpointState` protocol buffer.
  # This file, kept in the same directory as the checkpoint files, is
  # automatically managed by the `Saver` to keep track of recent checkpoints.
  # The default name used by the `Saver` for this file is "checkpoint". Here
  # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
  # `checkpoint_dir` is the same as the model checkpoint directory, there
  # are no conflicts during restore.
  self._latest_filename = "checkpoint_" + checkpoint_prefix
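# A minimal sketch, assuming TF 1.x's `tf.train.latest_checkpoint`, of why
# the custom `_latest_filename` above matters: the input-pipeline checkpoint
# state file can then live alongside the model's default "checkpoint" file in
# the same directory without clobbering it. `model_dir` is an assumed
# placeholder for a non-distributed pipeline (prefix "input").
#
#   import tensorflow as tf
#   model_ckpt = tf.train.latest_checkpoint(model_dir)  # reads "checkpoint"
#   input_ckpt = tf.train.latest_checkpoint(
#       model_dir, latest_filename='checkpoint_input')  # reads the hook's file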
def train_mnist():
  """The main function which trains.

  Does basically everything in one function.
  """
  with tf.get_default_graph().as_default():
    if DATA == 'MNIST':
      input_pipe = MnistData(BATCH_SIZE)
      train_features, train_labels = input_pipe.build_train_data_tensor()
    elif DATA == 'CIFAR10':
      train_features, train_labels = cifar10_input.distorted_inputs(
          BATCH_SIZE)
    else:
      raise ValueError('DATA value not supported: ' + DATA)

    checkpoint = tf.train.get_checkpoint_state(MODEL_DIR)
    if checkpoint is None:
      tf.logging.info('No checkpoint found; training from scratch')
      global_step = 0
    else:
      # Assuming model_checkpoint_path looks something like:
      #   /my-favorite-path/cifar10_train/model.ckpt-0,
      # extract global_step from it.
      # Needs to happen before graph setup.
      global_step = float(
          checkpoint.model_checkpoint_path.split('/')[-1].split('-')[-1])
      tf.logging.info('Continuing training from step: ' + str(global_step))

    train_op, mask_update_op, pruning_obj, loss = train_graph(
        train_features, train_labels)
    assign_ops = get_assign_ops()
    checkpoint_hook = basic_session_run_hooks.CheckpointSaverHook(
        MODEL_DIR, save_steps=TEST_FREQ)

    # Needs to happen after graph setup.
    if checkpoint is None:
      if RESTORE:
        var_list = []
        for var_str, shape in zip(RESTORE_VARS, RESTORE_SHAPES):
          i = var_str.index('/')
          scope = var_str[:i]
          with tf.variable_scope(scope, reuse=True):
            var = tf.get_variable(var_str[i + 1:], shape=shape)
          var_list.append(var)
        saver = tf.train.Saver(var_list)
    else:
      # cifar_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
      # cifar_vars = cifar_vars[:-1]
      # for var in cifar_vars:
      #   tf.logging.info(str(var))
      # saver2 = tf.train.Saver(cifar_vars)
      saver2 = tf.train.Saver()

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -0.5 + global_step

      def before_run(self, run_context):
        # This is a hack so that it correctly counts
        # (train_op, mask_update_op) as one step.
        self._step += 0.5
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step > 0 and self._step % 100 == 0:
          num_examples_per_step = BATCH_SIZE
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)
          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec'
                        '; %.3f sec/batch)')
          tf.logging.info(format_str %
                          (datetime.datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=MODEL_DIR,
        hooks=[_LoggerHook(), checkpoint_hook],
        save_checkpoint_secs=None,
        config=tf.ConfigProto(log_device_placement=False)) as sess:
      if checkpoint is not None:
        saver2.restore(sess, checkpoint.model_checkpoint_path)
      elif RESTORE:
        ckpt_path = tf.train.latest_checkpoint(RESTORE_DIR)
        if not ckpt_path:
          raise ValueError(
              "Restore dir can't find checkpoint files: " + RESTORE_DIR)
        saver.restore(sess, ckpt_path)

      if INIT_DICT is not None:
        sess.run(assign_ops)
        dict_str = '; '.join(
            scope + '/' + ', '.join(INIT_DICT[scope].keys())
            for scope in INIT_DICT)
        tf.logging.info('Instantiated tensors with assign ops: ' + dict_str)

      if PRUNING_CKPTS:
        pruning_dir = MODEL_DIR + '/../' + MODEL_NAME + '_prune_ckpts/'
        if not os.path.isdir(pruning_dir):
          os.mkdir(pruning_dir)
        j = 0
        max_j = len(PRUNING_CKPTS)

      for i in range(TRAIN_STEPS):
        sess.run(train_op)
        sess.run(mask_update_op)
        if (PRUNING_CKPTS and j < max_j and
            sess.run(pruning_obj._sparsity) > PRUNING_CKPTS[j]):
          tf.logging.info('Saving pruning ckpt: ' + str(PRUNING_CKPTS[j]))
          new_dir = pruning_dir + str(PRUNING_CKPTS[j]) + '/'
          utils.duplicate_saved(MODEL_DIR, new_dir)
          j += 1
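# `utils.duplicate_saved` above is project-specific and its implementation is
# not shown here. A plausible sketch (an assumption, not the actual helper) is
# to copy the checkpoint files and the checkpoint-state file into the target
# directory:
def duplicate_saved_sketch(src_dir, dst_dir):
  import glob
  import shutil
  if not os.path.isdir(dst_dir):
    os.makedirs(dst_dir)
  # Copy the data/index/meta files of every checkpoint in src_dir, plus the
  # "checkpoint" state file if present.
  paths = glob.glob(os.path.join(src_dir, 'model.ckpt*'))
  paths.append(os.path.join(src_dir, 'checkpoint'))
  for path in paths:
    if os.path.exists(path):
      shutil.copy(path, dst_dir)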