def _train_model(self, input_fn, hooks):
    """Builds the training graph and runs training until asked to stop.

    Args:
        input_fn: Zero-argument callable returning a `(features, labels)`
            pair. It is invoked under the `/cpu:0` device scope.
        hooks: Iterable of hooks run on every worker, in addition to the
            NaN and logging hooks installed here.

    Returns:
        The loss fetched by the last executed training step, or `None` if
        the session reported `should_stop()` before any step ran.
    """
    all_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step_tensor = training.create_global_step(g)
        with ops.device('/cpu:0'):
            features, labels = input_fn()
        estimator_spec = self._call_model_fn(features, labels,
                                             model_fn_lib.ModeKeys.FIT)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
        all_hooks.extend([
            training.NanTensorHook(estimator_spec.loss),
            training.LoggingTensorHook(
                {
                    'loss': estimator_spec.loss,
                    'step': global_step_tensor
                },
                every_n_iter=100)
        ])
        all_hooks.extend(hooks)
        all_hooks.extend(estimator_spec.training_hooks)

        scaffold = estimator_spec.scaffold or training.Scaffold()
        # Register a default saver only if neither the scaffold nor the graph
        # collection already provides one.
        if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,
                training.Saver(
                    sharded=True,
                    max_to_keep=self._config.keep_checkpoint_max,
                    defer_build=True))

        # Normalise to a list: the spec may carry its chief hooks as a tuple
        # (or any other iterable), and `list + tuple` raises TypeError.
        spec_chief_hooks = list(estimator_spec.training_chief_hooks or ())

        chief_hooks = []
        if (self._config.save_checkpoints_secs or
                self._config.save_checkpoints_steps):
            saver_hook_exists = any(
                isinstance(h, training.CheckpointSaverHook)
                for h in all_hooks + spec_chief_hooks)
            if not saver_hook_exists:
                chief_hooks = [
                    training.CheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=scaffold)
                ]
        with training.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=all_hooks,
                chief_only_hooks=chief_hooks + spec_chief_hooks,
                save_checkpoint_secs=0,  # Saving is handled by a hook.
                save_summaries_steps=self._config.save_summary_steps,
                config=config_pb2.ConfigProto(
                    allow_soft_placement=True)) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(
                    [estimator_spec.train_op, estimator_spec.loss])
        return loss
def _restore(self, path):
    """Restores this estimator from given path.

    Note: will rebuild the graph and initialize all parameters,
    and will ignore provided model.

    The checkpoint is located before the session is opened, so that a
    missing checkpoint does not leave an unused open session behind.

    Args:
        path: Path to checkpoints and other information.

    Raises:
        ValueError: If the folder lacks the `endpoints`, `graph.pbtxt` or
            `saver.pbtxt` file, or if no checkpoint can be found in it.
    """
    # Currently Saver requires absolute path to work correctly.
    path = os.path.abspath(path)

    self._graph = ops.Graph()
    with self._graph.as_default():
        endpoints_filename = os.path.join(path, 'endpoints')
        if not os.path.exists(endpoints_filename):
            raise ValueError("Restore folder doesn't contain endpoints.")
        with gfile.Open(endpoints_filename) as foutputs:
            endpoints = foutputs.read().split('\n')

        graph_filename = os.path.join(path, 'graph.pbtxt')
        if not os.path.exists(graph_filename):
            raise ValueError(
                "Restore folder doesn't contain graph definition.")
        with gfile.Open(graph_filename) as fgraph:
            graph_def = graph_pb2.GraphDef()
            text_format.Merge(fgraph.read(), graph_def)
        # The endpoints file lists, in order, the names of the input,
        # output, predictions and loss elements to pull from the graph.
        (self._inp, self._out, self._model_predictions,
         self._model_loss) = importer.import_graph_def(
             graph_def, name='', return_elements=endpoints)

        saver_filename = os.path.join(path, 'saver.pbtxt')
        if not os.path.exists(saver_filename):
            raise ValueError(
                "Restore folder doesn't contain saver definition.")
        with gfile.Open(saver_filename) as fsaver:
            saver_def = train.SaverDef()
            text_format.Merge(fsaver.read(), saver_def)
        self._saver = train.Saver(saver_def=saver_def)

        # Restore trainer
        self._global_step = self._graph.get_tensor_by_name('global_step:0')
        self._train = self._graph.get_operation_by_name('OptimizeLoss/train')

        # Restore summaries.
        self._summaries = self._graph.get_operation_by_name(
            'MergeSummary/MergeSummary')

        # Fail before opening a session if there is nothing to restore.
        checkpoint_path = train.latest_checkpoint(path)
        if checkpoint_path is None:
            raise ValueError(
                'Missing checkpoint files in the %s. Please '
                'make sure you are you have checkpoint file that describes '
                'latest checkpoints and appropriate checkpoints are there. '
                'If you have moved the folder, you at this point need to '
                'update manually update the paths in the checkpoint file.'
                % path)

        # Restore session.
        if not isinstance(self._config, RunConfig):
            self._config = RunConfig(verbose=self.verbose)
        self._session = session.Session(self._config.master,
                                        config=self._config.tf_config)
        self._saver.restore(self._session, checkpoint_path)

    # Set to be initialized.
    self._initialized = True
def model_fn(features, labels, mode):
    """Returns a fixed EstimatorSpec; `features` and `labels` are ignored."""
    del features, labels  # Unused.
    # Ops are created in the same order as the keyword arguments below.
    loss = constant_op.constant(3.)
    scaffold = training.Scaffold(saver=training.Saver())
    train_op = constant_op.constant(5.)
    metric_ops = {
        'mean_of_features': metrics_lib.mean(constant_op.constant(2.))
    }
    return estimator_lib.EstimatorSpec(
        mode,
        loss=loss,
        scaffold=scaffold,
        train_op=train_op,
        eval_metric_ops=metric_ops)
def model_fn(features, labels, mode):
    """Returns a fixed EstimatorSpec whose eval metric is a `Mean` object.

    `features` and `labels` are ignored.
    """
    del features, labels  # Unused.
    mean_metric = metrics_module.Mean()
    mean_metric.update_state(constant_op.constant(2.))
    # Ops are created in the same order as the keyword arguments below.
    loss = constant_op.constant(3.)
    scaffold = training.Scaffold(saver=training.Saver())
    train_op = constant_op.constant(5.)
    return estimator_lib.EstimatorSpec(
        mode,
        loss=loss,
        scaffold=scaffold,
        train_op=train_op,
        eval_metric_ops={'mean_of_features': mean_metric})
def test_stop_if_checkpoint_step_is_laststep(self):
    """Hook requests stop without sleeping when the checkpoint is at last_step.

    A checkpoint is saved with the global step set to 10; a
    `_StopAtCheckpointStepHook` with `last_step=10` must then make the
    monitored session report `should_stop()` after one run, and
    `time.sleep` must not have been called.
    """
    import shutil  # Local import: only needed for temp-dir cleanup.

    model_dir = tempfile.mkdtemp()
    # Remove the checkpoint directory once the test finishes, pass or fail.
    self.addCleanup(shutil.rmtree, model_dir, ignore_errors=True)
    with ops.Graph().as_default():
        step = training.create_global_step()
        assign_ten = step.assign(10)
        no_op = control_flow_ops.no_op()
        hook = hooks_lib._StopAtCheckpointStepHook(
            model_dir=model_dir, last_step=10)
        with tf_session.Session() as sess:
            sess.run(assign_ten)
            training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
        with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
            mon_sess.raw_session().run(assign_ten)
            with test.mock.patch.object(time, 'sleep') as mock_sleep:
                mon_sess.run(no_op)
                self.assertFalse(mock_sleep.called)
            self.assertTrue(mon_sess.should_stop())
def _train_model(self, input_fn, hooks, saving_listeners):
    """Builds the TRAIN-mode graph and runs the training loop.

    Args:
        input_fn: Input function handed to
            `_get_features_and_labels_from_input_fn`.
        hooks: Iterable of hooks run on every worker.
        saving_listeners: Listeners attached to the first
            `CheckpointSaverHook` that is found or created.

    Returns:
        The loss fetched by the last training step, or `None` if no step ran.

    Raises:
        ValueError: If `saving_listeners` is non-empty but there is no
            `CheckpointSaverHook` to attach them to.
    """
    worker_hooks = []
    with ops.Graph().as_default() as graph, graph.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step_tensor = self._create_and_assert_global_step(graph)
        step_read = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
        features, labels = self._get_features_and_labels_from_input_fn(
            input_fn, model_fn_lib.ModeKeys.TRAIN)
        with ops.control_dependencies([step_read]):
            spec = self._call_model_fn(
                features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)

        # Add a summary named 'loss' unless one with that op name already
        # exists, so that it shows up in the expected TensorBoard graph.
        has_loss_summary = any(
            s.op.name == 'loss'
            for s in ops.get_collection(ops.GraphKeys.SUMMARIES))
        if not has_loss_summary:
            summary.scalar('loss', spec.loss)
        ops.add_to_collection(ops.GraphKeys.LOSSES, spec.loss)

        worker_hooks.extend(hooks)
        worker_hooks.append(training.NanTensorHook(spec.loss))
        worker_hooks.append(
            training.LoggingTensorHook(
                {
                    'loss': spec.loss,
                    'step': global_step_tensor
                },
                every_n_iter=100))
        worker_hooks.extend(spec.training_hooks)

        # Register a default saver when none has been supplied.
        if (not spec.scaffold.saver and
                not ops.get_collection(ops.GraphKeys.SAVERS)):
            default_saver = training.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True)
            ops.add_to_collection(ops.GraphKeys.SAVERS, default_saver)

        chief_hooks = []
        saver_hooks = [
            hook
            for hook in worker_hooks + list(spec.training_chief_hooks)
            if isinstance(hook, training.CheckpointSaverHook)
        ]
        wants_checkpoints = (self._config.save_checkpoints_secs or
                             self._config.save_checkpoints_steps)
        if wants_checkpoints and not saver_hooks:
            default_saver_hook = training.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=spec.scaffold)
            chief_hooks = [default_saver_hook]
            saver_hooks = [default_saver_hook]

        if saving_listeners:
            if not saver_hooks:
                raise ValueError(
                    'There should be a CheckpointSaverHook to use saving_listeners. '
                    'Please set one of the RunConfig.save_checkpoints_steps or '
                    'RunConfig.save_checkpoints_secs.')
            # Only one CheckpointSaverHook is expected; with several, the
            # first one receives the listeners.
            saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

        chief_only = tuple(chief_hooks) + tuple(spec.training_chief_hooks)
        with training.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=spec.scaffold,
                hooks=worker_hooks,
                chief_only_hooks=chief_only,
                save_checkpoint_secs=0,  # Saving is handled by a hook.
                save_summaries_steps=self._config.save_summary_steps,
                config=self._session_config,
                log_step_count_steps=(
                    self._config.log_step_count_steps)) as mon_sess:
            loss = None
            fetches = [spec.train_op, spec.loss]
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(fetches)
        return loss
def _setup_training(self): """Sets up graph, model and trainer.""" # Create config if not given. if self._config is None: self._config = RunConfig(verbose=self.verbose) # Create new graph. self._graph = ops.Graph() self._graph.add_to_collection("IS_TRAINING", True) with self._graph.as_default(): random_seed.set_random_seed(self._config.tf_random_seed) self._global_step = variables.Variable(0, name="global_step", trainable=False) # Setting up inputs and outputs. self._inp, self._out = self._data_feeder.input_builder() # If class weights are provided, add them to the graph. # Different loss functions can use this tensor by name. if self.class_weight: self._class_weight_node = constant_op.constant( self.class_weight, name='class_weight') # Add histograms for X and y if they are floats. if self._data_feeder.input_dtype in (np.float32, np.float64): logging_ops.histogram_summary("X", self._inp) if self._data_feeder.output_dtype in (np.float32, np.float64): logging_ops.histogram_summary("y", self._out) # Create model's graph. self._model_predictions, self._model_loss = self.model_fn( self._inp, self._out) # Set up a single operator to merge all the summaries self._summaries = logging_ops.merge_all_summaries() # Create trainer and augment graph with gradients and optimizer. # Additionally creates initialization ops. learning_rate = self.learning_rate optimizer = self.optimizer if callable(learning_rate): learning_rate = learning_rate(self._global_step) if callable(optimizer): optimizer = optimizer(learning_rate) self._train = optimizers.optimize_loss( self._model_loss, self._global_step, learning_rate=learning_rate, optimizer=optimizer, clip_gradients=self.clip_gradients) # Update ops during training, e.g. batch_norm_ops self._train = control_flow_ops.group( self._train, *ops.get_collection('update_ops')) # Get all initializers for all trainable variables. 
self._initializers = variables.initialize_all_variables() # Create model's saver capturing all the nodes created up until now. self._saver = train.Saver( max_to_keep=self._config.keep_checkpoint_max, keep_checkpoint_every_n_hours=self._config. keep_checkpoint_every_n_hours) # Enable monitor to create validation data dict with appropriate tf placeholders self._monitor.create_val_feed_dict(self._inp, self._out) # Create session to run model with. self._session = session.Session(self._config.tf_master, config=self._config.tf_config) # Run parameter initializers. self._session.run(self._initializers)
def _train_model(self, input_fn, hooks):
    """Builds the TRAIN-mode graph and runs the training loop.

    Args:
        input_fn: Input function handed to
            `_get_features_and_labels_from_input_fn`.
        hooks: Iterable of hooks run on every worker.

    Returns:
        The loss fetched by the last training step, or `None` if no step ran.
    """
    all_hooks = []
    with ops.Graph().as_default() as graph, graph.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step_tensor = self._create_and_assert_global_step(graph)
        features, labels = self._get_features_and_labels_from_input_fn(
            input_fn, model_fn_lib.ModeKeys.TRAIN)
        spec = self._call_model_fn(
            features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
        ops.add_to_collection(ops.GraphKeys.LOSSES, spec.loss)

        all_hooks.extend(hooks)
        all_hooks.append(training.NanTensorHook(spec.loss))
        all_hooks.append(
            training.LoggingTensorHook(
                {
                    'loss': spec.loss,
                    'step': global_step_tensor
                },
                every_n_iter=100))
        all_hooks.extend(spec.training_hooks)

        # Register a default saver when none has been supplied.
        if (not spec.scaffold.saver and
                not ops.get_collection(ops.GraphKeys.SAVERS)):
            default_saver = training.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True)
            ops.add_to_collection(ops.GraphKeys.SAVERS, default_saver)

        chief_hooks = []
        if (self._config.save_checkpoints_secs or
                self._config.save_checkpoints_steps):
            # Install a checkpoint hook only if nobody supplied one.
            candidates = all_hooks + list(spec.training_chief_hooks)
            has_saver_hook = any(
                isinstance(hook, training.CheckpointSaverHook)
                for hook in candidates)
            if not has_saver_hook:
                chief_hooks.append(
                    training.CheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=spec.scaffold))

        chief_only = tuple(chief_hooks) + tuple(spec.training_chief_hooks)
        with training.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=spec.scaffold,
                hooks=all_hooks,
                chief_only_hooks=chief_only,
                save_checkpoint_secs=0,  # Saving is handled by a hook.
                save_summaries_steps=self._config.save_summary_steps,
                config=self._session_config,
                log_step_count_steps=(
                    self._config.log_step_count_steps)) as mon_sess:
            loss = None
            fetches = [spec.train_op, spec.loss]
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(fetches)
        return loss
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                               global_step_tensor, saving_listeners,
                               save_best_ckpt):
    """Train a model with the given Estimator Spec.

    Args:
        estimator_spec: Spec whose `loss`, `train_op`, `predictions`,
            `scaffold`, `training_hooks` and `training_chief_hooks` are used.
        worker_hooks: List of hooks; it is extended in place with `hooks`
            and the hooks created here.
        hooks: Extra hooks appended to `worker_hooks`.
        global_step_tensor: Tensor logged under the key "step".
        saving_listeners: Listeners attached to the first
            `CheckpointSaverHook` that is found or created.
        save_best_ckpt: If truthy, a `BestCheckpointSaverHook` is added,
            driven by the evaluator class found in `self._params`.

    Returns:
        The loss fetched by the last training step, or `None` if no step ran.

    Raises:
        ValueError: If `saving_listeners` is non-empty but there is no
            `CheckpointSaverHook` to attach them to.
        TypeError: If `save_best_ckpt` is set and `self._params["evaluator"]`
            is missing or is not a subclass of `EvaluateBase`.
    """
    if self._warm_start_settings:
        logging.info('Warm-starting with WarmStartSettings: %s' %
                     (self._warm_start_settings, ))
        warm_starting_util.warm_start(*self._warm_start_settings)

    worker_hooks.extend(hooks)
    worker_hooks.append(training.NanTensorHook(estimator_spec.loss))
    if self._config.log_step_count_steps is not None:
        tensors = {"loss": estimator_spec.loss, "step": global_step_tensor}
        # Also log every prediction whose key contains "/", under the key
        # with the slashes removed. Predictions that are not a dict (e.g. a
        # single tensor) have no keys to inspect and are skipped.
        predictions = estimator_spec.predictions
        if isinstance(predictions, dict):
            tensors.update({
                key.replace("/", ""): val
                for key, val in predictions.items()
                if "/" in key
            })
        worker_hooks.append(
            training.LoggingTensorHook(
                tensors, every_n_iter=self._config.log_step_count_steps))
    worker_hooks.extend(estimator_spec.training_hooks)

    # Create Saver object
    if not (estimator_spec.scaffold.saver or
            ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            training.Saver(sharded=True,
                           max_to_keep=self._config.keep_checkpoint_max,
                           keep_checkpoint_every_n_hours=(
                               self._config.keep_checkpoint_every_n_hours),
                           defer_build=True,
                           save_relative_paths=True))

    chief_hooks = []
    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
    saver_hooks = [
        h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)
    ]
    if (self._config.save_checkpoints_secs or
            self._config.save_checkpoints_steps):
        if not saver_hooks:
            chief_hooks = [
                training.CheckpointSaverHook(
                    self._model_dir,
                    save_secs=self._config.save_checkpoints_secs,
                    save_steps=self._config.save_checkpoints_steps,
                    scaffold=estimator_spec.scaffold)
            ]
            saver_hooks = [chief_hooks[0]]
    if saving_listeners:
        if not saver_hooks:
            raise ValueError(
                'There should be a CheckpointSaverHook to use saving_listeners. '
                'Please set one of the RunConfig.save_checkpoints_steps or '
                'RunConfig.save_checkpoints_secs.')
        else:
            # It is expected to have one CheckpointSaverHook. If multiple, we pick
            # up the first one to add listener.
            saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

    if self._train_with_eval:
        self.dataset_handle_hook = IteratorStringHandleHook(
            self.train_iterator, self.eval_iterator)
        worker_hooks.append(self.dataset_handle_hook)
        self._predict_keys = estimator_spec.predictions

    if save_best_ckpt:
        evaluator_cls = self._params.get("evaluator", None)
        # `issubclass` itself raises an opaque TypeError for None or any
        # non-class value, so check for a class first and report the
        # offending value rather than its metaclass.
        if not (isinstance(evaluator_cls, type) and
                issubclass(evaluator_cls, EvaluateBase)):
            raise TypeError(
                "Parameter `evaluator` must be a subclass of EvaluateBase, "
                "but got {!r}".format(evaluator_cls))
        eval_kwargs = self._params.get("eval_kwargs", {})
        eval_steps = self._params.get("eval_steps", 2500)
        primary_metric = self._params.get("primary_metric", None)
        secondary_metric = self._params.get("secondary_metric", None)

        # We must construct Evaluator inside a graph scope
        evaluator = evaluator_cls(self, **eval_kwargs)
        worker_hooks.append(
            BestCheckpointSaverHook(evaluator=evaluator,
                                    checkpoint_dir=self._model_dir,
                                    compare_fn=partial(
                                        evaluator.compare,
                                        primary_metric=primary_metric,
                                        secondary_metric=secondary_metric),
                                    tag=self._params["args"].tag,
                                    save_steps=eval_steps))

    # Training session monitor
    with training.MonitoredTrainingSession(
            master=self._config.master,
            is_chief=self._config.is_chief,
            checkpoint_dir=self._model_dir,
            scaffold=estimator_spec.scaffold,
            hooks=worker_hooks,
            chief_only_hooks=(tuple(chief_hooks) +
                              tuple(estimator_spec.training_chief_hooks)),
            save_checkpoint_secs=0,
            save_summaries_steps=self._config.save_summary_steps,
            config=self._session_config,
            log_step_count_steps=self._config.log_step_count_steps
    ) as mon_sess:
        loss = None
        # The dataset handle must only be read after the MonitoredSession
        # has been created.
        # NOTE(review): `self.dataset_handle_hook` is assigned in this
        # function only when `self._train_with_eval` is set -- confirm it
        # is otherwise initialised elsewhere.
        self._feed_dict = _add_key_value(
            self._feed_dict, self.handler,
            self.dataset_handle_hook.train_handle)
        while not mon_sess.should_stop():
            _, loss = mon_sess.run(
                [estimator_spec.train_op, estimator_spec.loss],
                self._feed_dict)
    return loss
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                               global_step_tensor, saving_listeners):
    """Train a model with the given Estimator Spec.

    Checkpoint directory, step-count logging and chief-only hooks are
    passed to the session only when `is_rank0()` is true; otherwise they
    are passed as `None`.

    Args:
        estimator_spec: Spec whose `loss`, `train_op`, `scaffold`,
            `training_hooks` and `training_chief_hooks` are used.
        worker_hooks: List of hooks; it is extended in place.
        hooks: Extra hooks appended to `worker_hooks`.
        global_step_tensor: Not used by this implementation.
        saving_listeners: Listeners attached to the first
            `CheckpointSaverHook` that is found or created.

    Returns:
        The loss fetched by the last training step, or `None` if no step ran.

    Raises:
        ValueError: If `saving_listeners` is non-empty but there is no
            `CheckpointSaverHook` to attach them to.
    """
    spec = estimator_spec
    if self._warm_start_settings:
        logging.info('Warm-starting with WarmStartSettings: %s' %
                     (self._warm_start_settings, ))
        warm_starting_util.warm_start(*self._warm_start_settings)

    # NOTE: the automatic 'loss' summary and the NanTensorHook are disabled
    # in this version (they were commented out in the original code).
    ops.add_to_collection(ops.GraphKeys.LOSSES, spec.loss)
    worker_hooks.extend(hooks)
    worker_hooks.extend(spec.training_hooks)

    # Register a default saver when none has been supplied.
    if (not spec.scaffold.saver and
            not ops.get_collection(ops.GraphKeys.SAVERS)):
        default_saver = training.Saver(
            sharded=True,
            max_to_keep=self._config.keep_checkpoint_max,
            keep_checkpoint_every_n_hours=(
                self._config.keep_checkpoint_every_n_hours),
            defer_build=True,
            save_relative_paths=True)
        ops.add_to_collection(ops.GraphKeys.SAVERS, default_saver)

    chief_hooks = []
    saver_hooks = [
        hook
        for hook in worker_hooks + list(spec.training_chief_hooks)
        if isinstance(hook, training.CheckpointSaverHook)
    ]
    wants_checkpoints = (self._config.save_checkpoints_secs or
                         self._config.save_checkpoints_steps)
    if wants_checkpoints and not saver_hooks:
        default_saver_hook = training.CheckpointSaverHook(
            self._model_dir,
            save_secs=self._config.save_checkpoints_secs,
            save_steps=self._config.save_checkpoints_steps,
            scaffold=spec.scaffold)
        chief_hooks = [default_saver_hook]
        saver_hooks = [default_saver_hook]

    if saving_listeners:
        if not saver_hooks:
            raise ValueError(
                'There should be a CheckpointSaverHook to use saving_listeners. '
                'Please set one of the RunConfig.save_checkpoints_steps or '
                'RunConfig.save_checkpoints_secs.')
        # Only one CheckpointSaverHook is expected; with several, the first
        # one receives the listeners.
        saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

    # Defaults for ranks other than 0.
    log_step_count_steps = None
    checkpoint_dir = None
    chief_only_hooks = None
    if is_rank0():
        log_step_count_steps = self._config.log_step_count_steps
        checkpoint_dir = self.model_dir
        chief_only_hooks = (tuple(chief_hooks) +
                            tuple(spec.training_chief_hooks))

    with MonitoredTrainingSession(
            master=self._config.master,
            is_chief=is_rank0(),
            checkpoint_dir=checkpoint_dir,
            scaffold=spec.scaffold,
            hooks=worker_hooks,
            chief_only_hooks=chief_only_hooks,
            save_checkpoint_secs=0,  # Saving is handled by a hook.
            save_summaries_steps=self._config.save_summary_steps,
            config=self._session_config,
            log_step_count_steps=log_step_count_steps) as mon_sess:
        loss = None
        fetches = [spec.train_op, spec.loss]
        while not mon_sess.should_stop():
            _, loss = mon_sess.run(fetches)
    return loss