def _train_model(self, input_fn, hooks):
  all_hooks = []
  with ops.Graph().as_default() as g, g.device(self._device_fn):
    random_seed.set_random_seed(self._config.tf_random_seed)
    global_step_tensor = training.create_global_step(g)
    with ops.device('/cpu:0'):
      features, labels = input_fn()
    estimator_spec = self._call_model_fn(features, labels,
                                         model_fn_lib.ModeKeys.FIT)
    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
    all_hooks.extend([
        training.NanTensorHook(estimator_spec.loss),
        training.LoggingTensorHook(
            {
                'loss': estimator_spec.loss,
                'step': global_step_tensor
            },
            every_n_iter=100)
    ])
    all_hooks.extend(hooks)
    all_hooks.extend(estimator_spec.training_hooks)

    scaffold = estimator_spec.scaffold or training.Scaffold()
    if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
      ops.add_to_collection(
          ops.GraphKeys.SAVERS,
          training.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              defer_build=True))

    chief_hooks = []
    if (self._config.save_checkpoints_secs or
        self._config.save_checkpoints_steps):
      saver_hook_exists = any([
          isinstance(h, training.CheckpointSaverHook)
          for h in (all_hooks + chief_hooks +
                    estimator_spec.training_chief_hooks)
      ])
      if not saver_hook_exists:
        chief_hooks = [
            training.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=scaffold)
        ]
    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=scaffold,
        hooks=all_hooks,
        chief_only_hooks=chief_hooks + estimator_spec.training_chief_hooks,
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=self._config.save_summary_steps,
        config=config_pb2.ConfigProto(allow_soft_placement=True)) as mon_sess:
      loss = None
      while not mon_sess.should_stop():
        _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
    return loss
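# Illustrative sketch, not part of the original source: the `estimator_spec`
# consumed by the training loop above is produced by a user-supplied model_fn.
# For the TRAIN path it must provide at least `loss` and `train_op`, and may
# attach extra SessionRunHooks via `training_hooks`. The names below
# (`example_model_fn`, feature key 'x') are made up for illustration.
import tensorflow as tf

def example_model_fn(features, labels, mode):
  logits = tf.layers.dense(features['x'], units=10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  train_op = tf.train.AdamOptimizer(1e-3).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)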
def _train_model(self, input_fn, hooks, saving_listeners):
  worker_hooks = []
  with ops.Graph().as_default() as g, g.device(self._device_fn):
    random_seed.set_random_seed(self._config.tf_random_seed)
    global_step_tensor = self._create_and_assert_global_step(g)
    global_step_read_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    features, labels = self._get_features_and_labels_from_input_fn(
        input_fn, model_fn_lib.ModeKeys.TRAIN)
    with ops.control_dependencies([global_step_read_tensor]):
      estimator_spec = self._call_model_fn(
          features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)

    # Check if the user created a loss summary, and add one if they didn't.
    # We assume here that the summary is called 'loss'. If it is not, we will
    # make another one with the name 'loss' to ensure it shows up in the right
    # graph in TensorBoard.
    if not any([x.op.name == 'loss'
                for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
      summary.scalar('loss', estimator_spec.loss)
    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
    worker_hooks.extend(hooks)
    worker_hooks.extend([
        training.NanTensorHook(estimator_spec.loss),
        training.LoggingTensorHook(
            {
                'loss': estimator_spec.loss,
                'step': global_step_tensor
            },
            every_n_iter=100)
    ])
    worker_hooks.extend(estimator_spec.training_hooks)

    if not (estimator_spec.scaffold.saver or
            ops.get_collection(ops.GraphKeys.SAVERS)):
      ops.add_to_collection(
          ops.GraphKeys.SAVERS,
          training.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              keep_checkpoint_every_n_hours=(
                  self._config.keep_checkpoint_every_n_hours),
              defer_build=True,
              save_relative_paths=True))

    chief_hooks = []
    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
    saver_hooks = [
        h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
    if (self._config.save_checkpoints_secs or
        self._config.save_checkpoints_steps):
      if not saver_hooks:
        chief_hooks = [
            training.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=estimator_spec.scaffold)
        ]
        saver_hooks = [chief_hooks[0]]
    if saving_listeners:
      if not saver_hooks:
        raise ValueError(
            'There should be a CheckpointSaverHook to use saving_listeners. '
            'Please set one of the RunConfig.save_checkpoints_steps or '
            'RunConfig.save_checkpoints_secs.')
      else:
        # It is expected to have one CheckpointSaverHook. If multiple, we pick
        # up the first one to add listener.
        saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=worker_hooks,
        chief_only_hooks=(
            tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=self._config.save_summary_steps,
        config=self._session_config,
        log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
      loss = None
      while not mon_sess.should_stop():
        _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
    return loss
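# Hedged usage sketch (not from the original file): the `saving_listeners`
# threaded through the version above originate from Estimator.train(). Each
# listener subclasses tf.train.CheckpointSaverListener and is attached to the
# single expected CheckpointSaverHook found or created above. The class name
# below is illustrative only.
import tensorflow as tf

class LogOnSaveListener(tf.train.CheckpointSaverListener):
  def after_save(self, session, global_step_value):
    tf.logging.info('Checkpoint written at step %d', global_step_value)

# Typical call site (assuming an `estimator` and `train_input_fn` exist):
# estimator.train(input_fn=train_input_fn,
#                 saving_listeners=[LogOnSaveListener()])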
def _train_model(self, input_fn, hooks):
  all_hooks = []
  with ops.Graph().as_default() as g, g.device(self._device_fn):
    random_seed.set_random_seed(self._config.tf_random_seed)
    global_step_tensor = self._create_and_assert_global_step(g)
    features, labels = self._get_features_and_labels_from_input_fn(
        input_fn, model_fn_lib.ModeKeys.TRAIN)
    estimator_spec = self._call_model_fn(features, labels,
                                         model_fn_lib.ModeKeys.TRAIN,
                                         self.config)
    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
    all_hooks.extend(hooks)
    all_hooks.extend([
        training.NanTensorHook(estimator_spec.loss),
        training.LoggingTensorHook(
            {
                'loss': estimator_spec.loss,
                'step': global_step_tensor
            },
            every_n_iter=100)
    ])
    all_hooks.extend(estimator_spec.training_hooks)

    if not (estimator_spec.scaffold.saver or
            ops.get_collection(ops.GraphKeys.SAVERS)):
      ops.add_to_collection(
          ops.GraphKeys.SAVERS,
          training.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              keep_checkpoint_every_n_hours=(
                  self._config.keep_checkpoint_every_n_hours),
              defer_build=True,
              save_relative_paths=True))

    chief_hooks = []
    if (self._config.save_checkpoints_secs or
        self._config.save_checkpoints_steps):
      saver_hook_exists = any([
          isinstance(h, training.CheckpointSaverHook)
          for h in (all_hooks + chief_hooks +
                    list(estimator_spec.training_chief_hooks))
      ])
      if not saver_hook_exists:
        chief_hooks = [
            training.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=estimator_spec.scaffold)
        ]
    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=all_hooks,
        chief_only_hooks=(
            tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=self._config.save_summary_steps,
        config=self._session_config,
        log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
      loss = None
      while not mon_sess.should_stop():
        _, loss = mon_sess.run(
            [estimator_spec.train_op, estimator_spec.loss])
    return loss
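# Illustrative sketch, not from the source: the `hooks` argument merged into
# `all_hooks` above is a list of tf.train.SessionRunHook objects. A hook can
# piggyback extra fetches onto every call of mon_sess.run(), for example to
# record the loss history (the class below is a hypothetical helper).
import tensorflow as tf

class LossHistoryHook(tf.train.SessionRunHook):
  def __init__(self, loss_tensor):
    self._loss_tensor = loss_tensor
    self.history = []

  def before_run(self, run_context):
    # Ask the session to also fetch the loss tensor on this step.
    return tf.train.SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    self.history.append(run_values.results)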
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                               global_step_tensor, saving_listeners,
                               save_best_ckpt):
  """Train a model with the given Estimator Spec."""
  if self._warm_start_settings:
    logging.info('Warm-starting with WarmStartSettings: %s' %
                 (self._warm_start_settings,))
    warm_starting_util.warm_start(*self._warm_start_settings)

  worker_hooks.extend(hooks)
  worker_hooks.append(training.NanTensorHook(estimator_spec.loss))
  if self._config.log_step_count_steps is not None:
    tensors = {"loss": estimator_spec.loss, "step": global_step_tensor}
    tensors.update({
        key.replace("/", ""): val
        for key, val in estimator_spec.predictions.items()
        if "/" in key
    })
    worker_hooks.append(
        training.LoggingTensorHook(
            tensors, every_n_iter=self._config.log_step_count_steps))
  worker_hooks.extend(estimator_spec.training_hooks)

  # Create Saver object
  if not (estimator_spec.scaffold.saver or
          ops.get_collection(ops.GraphKeys.SAVERS)):
    ops.add_to_collection(
        ops.GraphKeys.SAVERS,
        training.Saver(
            sharded=True,
            max_to_keep=self._config.keep_checkpoint_max,
            keep_checkpoint_every_n_hours=(
                self._config.keep_checkpoint_every_n_hours),
            defer_build=True,
            save_relative_paths=True))

  chief_hooks = []
  all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
  saver_hooks = [
      h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)
  ]
  if (self._config.save_checkpoints_secs or
      self._config.save_checkpoints_steps):
    if not saver_hooks:
      chief_hooks = [
          training.CheckpointSaverHook(
              self._model_dir,
              save_secs=self._config.save_checkpoints_secs,
              save_steps=self._config.save_checkpoints_steps,
              scaffold=estimator_spec.scaffold)
      ]
      saver_hooks = [chief_hooks[0]]
  if saving_listeners:
    if not saver_hooks:
      raise ValueError(
          'There should be a CheckpointSaverHook to use saving_listeners. '
          'Please set one of the RunConfig.save_checkpoints_steps or '
          'RunConfig.save_checkpoints_secs.')
    else:
      # It is expected to have one CheckpointSaverHook. If multiple, we pick
      # up the first one to add listener.
      saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

  if self._train_with_eval:
    self.dataset_handle_hook = IteratorStringHandleHook(
        self.train_iterator, self.eval_iterator)
    worker_hooks.append(self.dataset_handle_hook)
    self._predict_keys = estimator_spec.predictions

  if save_best_ckpt:
    EvaluatorCls = self._params.get("evaluator", None)
    if EvaluatorCls is None or not issubclass(EvaluatorCls, EvaluateBase):
      raise TypeError(
          "Parameter `evaluator` must be a subclass of EvaluateBase, "
          "but got {}".format(EvaluatorCls))
    eval_kwargs = self._params.get("eval_kwargs", {})
    eval_steps = self._params.get("eval_steps", 2500)
    primary_metric = self._params.get("primary_metric", None)
    secondary_metric = self._params.get("secondary_metric", None)

    # We must construct the Evaluator inside a graph scope.
    evaluator = EvaluatorCls(self, **eval_kwargs)

    worker_hooks.append(
        BestCheckpointSaverHook(
            evaluator=evaluator,
            checkpoint_dir=self._model_dir,
            compare_fn=partial(evaluator.compare,
                               primary_metric=primary_metric,
                               secondary_metric=secondary_metric),
            tag=self._params["args"].tag,
            save_steps=eval_steps))

  # Training session monitor
  with training.MonitoredTrainingSession(
      master=self._config.master,
      is_chief=self._config.is_chief,
      checkpoint_dir=self._model_dir,
      scaffold=estimator_spec.scaffold,
      hooks=worker_hooks,
      chief_only_hooks=(tuple(chief_hooks) +
                        tuple(estimator_spec.training_chief_hooks)),
      save_checkpoint_secs=0,  # Saving is handled by a hook.
      save_summaries_steps=self._config.save_summary_steps,
      config=self._session_config,
      log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
    loss = None
    # Read self.dataset_handle_hook.xxx_handle only after the
    # MonitoredTrainingSession has been created.
    self._feed_dict = _add_key_value(
        self._feed_dict, self.handler, self.dataset_handle_hook.train_handle)
    while not mon_sess.should_stop():
      _, loss = mon_sess.run(
          [estimator_spec.train_op, estimator_spec.loss], self._feed_dict)
  return loss
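# The IteratorStringHandleHook used above is project-specific and not defined
# in this snippet; the class below is only a guess at its shape, based on the
# standard feedable-iterator pattern (tf.data.Iterator.from_string_handle).
# It resolves the train/eval iterator string handles once the session exists,
# which is why the code above reads dataset_handle_hook.train_handle only
# after MonitoredTrainingSession has been created.
import tensorflow as tf

class IteratorStringHandleHookSketch(tf.train.SessionRunHook):
  def __init__(self, train_iterator, eval_iterator):
    self._train_string_handle = train_iterator.string_handle()
    self._eval_string_handle = eval_iterator.string_handle()
    self.train_handle = None
    self.eval_handle = None

  def after_create_session(self, session, coord):
    # Materialize the string handles so they can be fed to the handle
    # placeholder on each training or evaluation step.
    self.train_handle, self.eval_handle = session.run(
        [self._train_string_handle, self._eval_string_handle])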
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                               global_step_tensor, saving_listeners):
  """Train a model with the given Estimator Spec."""
  if self._warm_start_settings:
    logging.info('Warm-starting with WarmStartSettings: %s' %
                 (self._warm_start_settings,))
    warm_starting_util.warm_start(*self._warm_start_settings)

  # Check if the user created a loss summary, and add one if they didn't.
  # We assume here that the summary is called 'loss'. If it is not, we will
  # make another one with the name 'loss' to ensure it shows up in the right
  # graph in TensorBoard.
  # if not any([x.op.name == 'loss'
  #             for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
  #   summary.scalar('loss', estimator_spec.loss)
  ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
  worker_hooks.extend(hooks)
  # worker_hooks.extend([
  #     training.NanTensorHook(estimator_spec.loss)
  # ])
  worker_hooks.extend(estimator_spec.training_hooks)

  if not (estimator_spec.scaffold.saver or
          ops.get_collection(ops.GraphKeys.SAVERS)):
    ops.add_to_collection(
        ops.GraphKeys.SAVERS,
        training.Saver(
            sharded=True,
            max_to_keep=self._config.keep_checkpoint_max,
            keep_checkpoint_every_n_hours=(
                self._config.keep_checkpoint_every_n_hours),
            defer_build=True,
            save_relative_paths=True))

  chief_hooks = []
  all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
  saver_hooks = [
      h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)
  ]
  if (self._config.save_checkpoints_secs or
      self._config.save_checkpoints_steps):
    if not saver_hooks:
      chief_hooks = [
          training.CheckpointSaverHook(
              self._model_dir,
              save_secs=self._config.save_checkpoints_secs,
              save_steps=self._config.save_checkpoints_steps,
              scaffold=estimator_spec.scaffold)
      ]
      saver_hooks = [chief_hooks[0]]
  if saving_listeners:
    if not saver_hooks:
      raise ValueError(
          'There should be a CheckpointSaverHook to use saving_listeners. '
          'Please set one of the RunConfig.save_checkpoints_steps or '
          'RunConfig.save_checkpoints_secs.')
    else:
      # It is expected to have one CheckpointSaverHook. If multiple, we pick
      # up the first one to add listener.
      saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

  if is_rank0():
    log_step_count_steps = self._config.log_step_count_steps
    checkpoint_dir = self.model_dir
    chief_only_hooks = (tuple(chief_hooks) +
                        tuple(estimator_spec.training_chief_hooks))
  else:
    log_step_count_steps = None
    checkpoint_dir = None
    chief_only_hooks = None
  with MonitoredTrainingSession(
      master=self._config.master,
      is_chief=is_rank0(),
      checkpoint_dir=checkpoint_dir,
      scaffold=estimator_spec.scaffold,
      hooks=worker_hooks,
      chief_only_hooks=chief_only_hooks,
      save_checkpoint_secs=0,  # Saving is handled by a hook.
      save_summaries_steps=self._config.save_summary_steps,
      config=self._session_config,
      log_step_count_steps=log_step_count_steps) as mon_sess:
    loss = None
    while not mon_sess.should_stop():
      _, loss = mon_sess.run(
          [estimator_spec.train_op, estimator_spec.loss])
  return loss
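# The distribution helpers referenced above (`is_rank0`, the bare
# `MonitoredTrainingSession` import) are not defined in this snippet. Assuming
# a Horovod-style setup, `is_rank0` would reduce to something like the sketch
# below: only rank 0 acts as chief, so only it writes checkpoints, summaries,
# and step-count logs, while the other ranks run the same training loop.
import horovod.tensorflow as hvd

def is_rank0():
  # Assumes hvd.init() has already been called during process startup.
  return hvd.rank() == 0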