def _init_model(self) -> None:
    self._init_paths()

    self.estimator = tf.estimator.Estimator(
        model_fn=self.estimator._model_fn,
        config=self._init_run_config(self.estimator.config),
        params=self.estimator.params if self.estimator.params != {} else None,
        warm_start_from=self.estimator._warm_start_settings,
    )

    check.is_instance(
        self.estimator,
        tf.estimator.Estimator,
        "Please modify your model definition's build_estimator() implementation to return "
        "an instance of `tf.estimator.Estimator`.",
    )
    check.is_instance(
        self.user_train_spec,
        tf.estimator.TrainSpec,
        "Please modify your model definition's build_train_spec() implementation to return "
        "an instance of `tf.estimator.TrainSpec`.",
    )
    check.is_instance(
        self.val_spec,
        tf.estimator.EvalSpec,
        "Please modify your model definition's build_validation_spec() implementation "
        "to return an instance of `tf.estimator.EvalSpec`.",
    )

    all_hooks = [*self.user_train_spec.hooks]

    if self.hvd_config.use:
        all_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

    # It is important that this hook is the last one in the list, so that any
    # other hooks that need to run _before_ the training step ends get their
    # chance to do so.
    all_hooks.append(DeterminedControlHook(self))

    # TODO(DET-834): Separate step ID from data loader state.
    #
    # During warm start, we initialize model weights, optimizer state, and
    # input state from the checkpoint, and we set the step ID to 1. Trials
    # typically use the step ID as an index into the data sequence, which
    # means there is an inconsistency between the step ID (as data index)
    # and the optimizer state and input state.
    #
    # In the short term, behave like other trials and reset input state if
    # we are warm-started. This will create an inconsistency with respect
    # to saved optimizer state.

    # Repeat the training dataset so we never run out of data.
    repeating_train_fn = self._check_and_repeat_train_input_fn(self.user_train_spec.input_fn)

    self.train_spec = tf.estimator.TrainSpec(input_fn=repeating_train_fn, hooks=all_hooks)
    self.eval_spec = tf.estimator.EvalSpec(input_fn=self.val_spec.input_fn, steps=None)
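
# The body of _check_and_repeat_train_input_fn is not shown here. Below is a
# minimal sketch of what such a wrapper could look like, assuming the user's
# input_fn returns a tf.data.Dataset and the wrapper only needs to make that
# dataset repeat indefinitely. The name and structure are illustrative, not
# Determined's actual implementation.
def _repeat_train_input_fn_sketch(input_fn):
    def wrapped_input_fn(*args, **kwargs):
        dataset = input_fn(*args, **kwargs)
        # tf.data.Dataset.repeat() with no count argument repeats forever,
        # so training never exhausts the data between checkpoints.
        return dataset.repeat()

    return wrapped_input_fn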
def _init_train_hooks(self) -> None:
    self.train_hooks = [*self.user_train_spec.hooks]

    if self.hvd_config.use:
        self.train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

    # It is important that this hook is the last one in the list, so that any
    # other hooks that need to run _before_ the training step ends get their
    # chance to do so.
    self.train_hooks.append(DeterminedControlHook(self))
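
# hvd.BroadcastGlobalVariablesHook(0) is Horovod's documented hook for
# broadcasting the initial variable state from rank 0 to all other workers,
# so every rank starts training from identical weights. For context, a
# minimal sketch of the standard Horovod + Estimator setup, following
# Horovod's own documentation rather than Determined's internals:
#
#     import horovod.tensorflow as hvd
#
#     hvd.init()
#     hooks = [hvd.BroadcastGlobalVariablesHook(0)]
#     estimator.train(input_fn=train_input_fn, hooks=hooks)
#
# Determined builds the equivalent hook list for the user here instead.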
def _init_train_hooks(self) -> None:
    self.train_hooks = [*self.user_train_spec.hooks]
    self.train_hooks.append(DeterminedEarlyStoppingHook(self.context))

    if self.context.distributed.size > 1:
        self.train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

    # It is important that this hook is the last one in the list, so that any
    # other hooks that need to run _before_ the training step ends get their
    # chance to do so.
    self.train_hooks.append(DeterminedControlHook(self))
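
# The ordering comment above relies on tf.estimator calling each hook's
# after_run() in the order the hooks appear in the list, so any hook placed
# before DeterminedControlHook finishes its per-step work first. A minimal
# sketch of a custom hook, to illustrate the generic
# tf.estimator.SessionRunHook interface (not DeterminedControlHook's actual
# body):
import tensorflow as tf

class StepLoggingHookSketch(tf.estimator.SessionRunHook):
    def after_run(self, run_context, run_values):
        # Per-step work here runs before hooks appended later in the list,
        # which is why DeterminedControlHook is always appended last.
        print("finished one training step")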