Example #1
    def _init_model(self) -> None:
        self._init_paths()

        self.estimator = tf.estimator.Estimator(
            model_fn=self.estimator._model_fn,
            config=self._init_run_config(self.estimator.config),
            params=self.estimator.params if self.estimator.params != {} else None,
            warm_start_from=self.estimator._warm_start_settings,
        )

        check.is_instance(
            self.estimator,
            tf.estimator.Estimator,
            "Please modify your model definition's build_estimator() implementation to return "
            "an instance of `tf.estimator.Estimator`.",
        )
        check.is_instance(
            self.user_train_spec,
            tf.estimator.TrainSpec,
            "Please modify your model definition's build_train_spec() implementation to return "
            "an instance of `tf.estimator.TrainSpec`.",
        )
        check.is_instance(
            self.val_spec,
            tf.estimator.EvalSpec,
            "Please modify your model definition's build_validation_spec() implementation "
            "to return an instance of `tf.estimator.EvalSpec`.",
        )

        all_hooks = [*self.user_train_spec.hooks]

        if self.hvd_config.use:
            all_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        # It is important that this hook is the final one in the list so
        # that, if any other hooks need to run _before_ the training step
        # ends, they have their chance.
        all_hooks.append(DeterminedControlHook(self))

        # TODO(DET-834): Separate step ID from data loader state.
        #
        # During warm start, we initialize model weights, optimizer state
        # and input state from the checkpoint, and we set the step ID to
        # 1. Trials typically use the step ID as an index into the data
        # sequence, which means there is an inconsistency between the
        # step ID (as data index) and the optimizer state and input state.
        #
        # In the short term, behave like other trials and reset input
        # state if we are warm started. This will create an inconsistency
        # wrt saved optimizer state.

        # Repeat training dataset so we never run out of data.
        repeating_train_fn = self._check_and_repeat_train_input_fn(
            self.user_train_spec.input_fn)

        self.train_spec = tf.estimator.TrainSpec(input_fn=repeating_train_fn,
                                                 hooks=all_hooks)
        self.eval_spec = tf.estimator.EvalSpec(input_fn=self.val_spec.input_fn,
                                               steps=None)
Example #2
    def _init_train_hooks(self) -> None:
        self.train_hooks = [*self.user_train_spec.hooks]

        if self.hvd_config.use:
            self.train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        # It is important that this hook is the final one in the list so
        # that, if any other hooks need to run _before_ the training step
        # ends, they have their chance.
        self.train_hooks.append(DeterminedControlHook(self))
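The ordering note above relies on the Estimator invoking each hook's before_run and after_run in list order, so the control hook that ends a Determined training step must come last. As a hedged illustration, here is a small custom SessionRunHook of the kind that might sit earlier in train_hooks; the StepTimerHook name is made up for this sketch.

import time

import tensorflow as tf

class StepTimerHook(tf.estimator.SessionRunHook):
    # Toy hook that times each training step.
    def before_run(self, run_context):
        # Hooks are invoked in list order, so this runs before any hook
        # appended after it, including the final control hook.
        self._start = time.time()

    def after_run(self, run_context, run_values):
        print("training step took %.3fs" % (time.time() - self._start))

Such a hook would be passed to training through the hooks argument of tf.estimator.TrainSpec, exactly as Example #1 does after appending DeterminedControlHook last.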
Example #3
    def _init_train_hooks(self) -> None:
        self.train_hooks = [*self.user_train_spec.hooks]

        self.train_hooks.append(DeterminedEarlyStoppingHook(self.context))

        if self.context.distributed.size > 1:
            self.train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        # It is important that this hook is the final one in the list so
        # that, if any other hooks need to run _before_ the training step
        # ends, they have their chance.
        self.train_hooks.append(DeterminedControlHook(self))
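Both distributed variants above append hvd.BroadcastGlobalVariablesHook(0) so that rank 0's initial variable values are broadcast to every worker before training begins. Outside of Determined this is the usual Horovod-with-Estimator pattern; the standalone sketch below assumes Horovod is installed, and the model_dir and feature column are illustrative placeholders rather than anything from the original code.

import horovod.tensorflow as hvd
import tensorflow as tf

hvd.init()

# Broadcast rank 0's initial variable values to all workers so that every
# process starts training from identical weights.
hooks = [hvd.BroadcastGlobalVariablesHook(0)]

# Typically only rank 0 writes checkpoints, to avoid conflicting writes.
model_dir = "/tmp/estimator_ckpts" if hvd.rank() == 0 else None

estimator = tf.estimator.LinearClassifier(
    feature_columns=[tf.feature_column.numeric_column("x")],
    model_dir=model_dir,
)
# The broadcast hook is then passed into training, e.g.:
# estimator.train(input_fn=some_input_fn, hooks=hooks)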