    def save_best_models_under_current_metrics(
        self, model: BaseModel, metrics_holder: dict, metric_func: dict, **kwargs
    ):
        """Save checkpoints for the current metrics under their associated
        comparison functions (see DEFAULT_METRICS_FUNC).

        Arguments:
            model {BaseModel} -- model whose state_dict gets checkpointed
            metrics_holder {dict} -- must contain "stage", "epoch" and "current_metrics"
            metric_func {dict} -- maps each metric name to the function used to pick its best value
        """
        metrics = metrics_holder["current_metrics"]
        stage = metrics_holder["stage"]
        epoch = metrics_holder["epoch"]

        stats = self._checkpoint.stats
        state_dict = copy.deepcopy(model.state_dict())

        current_stat = {}
        current_stat["epoch"] = epoch

        models_to_save = self._checkpoint.models
        if stage not in stats:
            stats[stage] = []

        if stage == "train":
            models_to_save[Checkpoint._LATEST] = state_dict
        else:
            if len(stats[stage]) > 0:
                latest_stats = stats[stage][-1]

                msg = ""
                improved_metric = 0

                for metric_name, current_metric_value in metrics.items():
                    current_stat[metric_name] = current_metric_value

                    # Look up the comparison function for this metric; use a local
                    # name so the metric_func dict stays intact for later metrics.
                    func = self.find_func_from_metric_name(metric_name, metric_func)
                    best_metric_from_stats = latest_stats.get("best_{}".format(metric_name), current_metric_value)
                    best_value = func(best_metric_from_stats, current_metric_value)
                    current_stat["best_{}".format(metric_name)] = best_value

                    # This new value seems to be better under metric_func
                    if (self._selection_stage == stage) and (
                        current_metric_value == best_value
                    ):  # Update the model weights
                        models_to_save["best_{}".format(metric_name)] = state_dict

                        msg += "{}: {} -> {}, ".format(metric_name, best_metric_from_stats, best_value)
                        improved_metric += 1

                if improved_metric > 0:
                    colored_print(COLORS.VAL_COLOR, msg[:-2])
            else:
                # stats[stage] is empty.
                for metric_name, metric_value in metrics.items():
                    current_stat[metric_name] = metric_value
                    current_stat["best_{}".format(metric_name)] = metric_value
                    models_to_save["best_{}".format(metric_name)] = state_dict

        self._checkpoint.stats[stage].append(current_stat)
        self._checkpoint.save_objects(models_to_save, stage, current_stat, model.optimizer, model.schedulers, **kwargs)
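The metric_func dictionary is what decides which direction counts as an improvement for each metric. Below is a minimal standalone sketch of that convention; the metric names, values, and the inline dictionary lookup are illustrative assumptions, since the original method delegates the lookup to self.find_func_from_metric_name:

# Illustrative sketch only: metric names and values are made up, and the dict
# lookup stands in for self.find_func_from_metric_name from the class above.
metric_func = {"loss": min, "acc": max}  # lower loss is better, higher acc is better

latest_stats = {"best_loss": 0.50, "best_acc": 0.88}  # best values recorded so far
current_metrics = {"loss": 0.42, "acc": 0.91}         # metrics from the current epoch

for metric_name, current_value in current_metrics.items():
    func = metric_func[metric_name]
    best_from_stats = latest_stats.get("best_{}".format(metric_name), current_value)
    best_value = func(best_from_stats, current_value)
    if current_value == best_value:
        # At this point the method above would store a deep copy of
        # model.state_dict() under the key "best_<metric_name>".
        print("{}: {} -> {}".format(metric_name, best_from_stats, best_value))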
Example #2
def log_optimizers(self):
    colored_print(COLORS.Green, "Optimizer: {}".format(self._optimizer))
    colored_print(COLORS.Green, "Learning Rate Scheduler: {}".format(self._lr_scheduler))
    colored_print(COLORS.Green, "BatchNorm Scheduler: {}".format(self._bn_scheduler))
    colored_print(COLORS.Green, "Accumulated gradients: {}".format(self._accumulated_gradient_step))
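Both snippets rely on the library's colored_print helper and COLORS constants. The stand-in below is a hypothetical sketch, assuming only that COLORS exposes ANSI escape strings and that colored_print prefixes the message with one of them; the real implementation may differ:

# Hypothetical stand-in for colored_print/COLORS, so the snippets above can be
# exercised outside the original codebase.
class COLORS:
    Green = "\033[92m"      # assumed ANSI bright green
    VAL_COLOR = "\033[94m"  # assumed ANSI bright blue for validation messages
    END = "\033[0m"         # reset


def colored_print(color, msg):
    # Wrap the message in the given ANSI color code and reset afterwards.
    print("{}{}{}".format(color, msg, COLORS.END))


colored_print(COLORS.Green, "Optimizer: Adam")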