コード例 #1
0
ファイル: training.py プロジェクト: nuhame/mlpug
    def set_model_components_state(self, state):
        """
        Set the state of every model component named in `state`.

        Each (name, model_state) pair is handled independently: when a
        component can't be found, or applying its state raises, the problem
        is logged and the remaining entries are still attempted.

        :param state: object with a callable `items` attribute that yields
                      (component name, component state) pairs
        :return: success (True or False). False when `state` is invalid,
                 when a named component is not found, or when setting the
                 state of any component raised an exception.
        """
        if not _.is_callable(getattr(state, 'items', None)):
            self._log.error(
                "State is invalid, unable to set model components state")
            return False

        success = True
        for name, model_state in state.items():
            model = self.get_model_component(name)
            if model is None:
                self._log.error(
                    f"No {name} model found, unable to set state")
                success = False
                continue

            try:
                self._set_model_state(model, model_state, name)
            except Exception as e:
                # Best effort: one failing component must not prevent the
                # other components from receiving their state.
                _.log_exception(self._log,
                                f"Unable to set state for model {name}", e)
                success = False

        return success
コード例 #2
0
ファイル: training.py プロジェクト: nuhame/mlpug
    def set_optimizers_state(self, state):
        """
        Set the state of every optimizer named in `state`.

        Each (name, optimizer_state) pair is handled independently: when an
        optimizer can't be found, or applying its state raises, the problem
        is logged and the remaining entries are still attempted.

        :param state: object with a callable `items` attribute that yields
                      (optimizer name, optimizer state) pairs
        :return: success (True or False). False when `state` is invalid,
                 when a named optimizer is not found, or when setting the
                 state of any optimizer raised an exception.
        """
        if not _.is_callable(getattr(state, 'items', None)):
            self._log.error("State is invalid, unable to set optimizers state")
            return False

        success = True
        for name, optimizer_state in state.items():
            optimizer = self.get_optimizer(name)
            if optimizer is None:
                self._log.error(
                    f"No {name} optimizer found, unable to set state")
                success = False
                continue

            try:
                self._set_optimizer_state(optimizer, optimizer_state, name)
            except Exception as e:
                # Best effort: one failing optimizer must not prevent the
                # other optimizers from receiving their state.
                _.log_exception(self._log,
                                f"Unable to set state for optimizer {name}", e)
                success = False

        return success
コード例 #3
0
    def _inspect_data(self, generated_batch_data, predicted_batch_data,
                      batch_metrics_data):
        """
        Log the batch metrics and pass the batch data on to the inspector.

        When `self._inspector` has no callable `inspect` attribute an error
        is logged and nothing else happens.

        :param generated_batch_data: forwarded unchanged to the inspector
        :param predicted_batch_data: forwarded unchanged to the inspector
        :param batch_metrics_data: object whose `items()` yields
                                   (metric name, value) pairs; every value is
                                   formatted with `%f`
        :return: None
        """
        if not (hasattr(self._inspector, "inspect")
                and _.is_callable(self._inspector.inspect)):
            self._log.error(
                "No valid inspector given, unable to inspect random samples")
            return

        # Plain header line: no arguments are passed to the logger here, so
        # the message must not carry a "%s" placeholder (it would be emitted
        # literally instead of being substituted).
        self._log.info("Current batch performance metrics :")
        for metric, value in batch_metrics_data.items():
            self._log.info("%s : %f" % (metric, value))

        self._inspector.inspect(generated_batch_data, predicted_batch_data,
                                self._num_samples_to_inspect)
コード例 #4
0
ファイル: training.py プロジェクト: nuhame/mlpug
    def set_optimizers_state(self, state):
        """
        Keep the given optimizers state so that it can be applied later.

        Nothing is applied to any optimizer here; a valid state is only
        stored in `self._deferred_optimizers_state`.

        :param state: object with a callable `items` attribute
        :return: success (True or False); False when `state` is invalid
        """
        items_attr = getattr(state, 'items', None)
        if _.is_callable(items_attr):
            # Valid state: remember it, applying it is deferred.
            self._deferred_optimizers_state = state
            self._log.debug(
                "Optimizers checkpoint state received; "
                "deferred setting the state until training has started")
            return True

        self._log.error("State is invalid, unable to set optimizers state")
        return False
コード例 #5
0
    def _check_settings(self):
        """
        Validate the checkpoint settings, logging and correcting conflicts.

        * Logs an error when `self._model` has no callable `save_weights`.
        * When the checkpoint interval is negative while the archive interval
          is positive, archiving is disabled (archive interval set to -1).
        * When both intervals are positive and the archive interval is not an
          exact multiple of the checkpoint interval, the archive interval is
          changed to 10 times the checkpoint interval.

        :return: None
        """
        model_can_save = (hasattr(self._model, 'save_weights')
                          and _.is_callable(self._model.save_weights))
        if not model_can_save:
            self._log.error(
                "No valid model provided, creating checkpoints will fail ...")

        checkpoint_every = self._create_checkpoint_every
        if checkpoint_every < 0:
            if self._archive_last_checkpoint_every > 0:
                self._log.error(
                    "archive_last_checkpoint_every can't be > 0 while _create_checkpoint_every < 0, "
                    "disabling archiving ... ")
                self._archive_last_checkpoint_every = -1
        elif checkpoint_every > 0:
            archive_every = self._archive_last_checkpoint_every
            if archive_every > 0 and archive_every % checkpoint_every != 0:
                # Fall back to an archive interval that is guaranteed to be
                # a multiple of the checkpoint interval.
                self._archive_last_checkpoint_every = 10 * checkpoint_every
                self._log.error(
                    "archive_last_checkpoint_every must be exact multiple of _create_checkpoint_every, "
                    "changing archive_last_checkpoint_every to [%d]" %
                    self._archive_last_checkpoint_every)