def save_best_models_under_current_metrics(
    self, model: BaseModel, metrics_holder: dict, metric_func: dict, **kwargs
):
    """Record the current metrics and save checkpoints for every metric that improved.

    For the "train" stage only the latest weights are stored. For any other stage,
    each metric in metrics_holder["current_metrics"] is compared with the best value
    recorded for that stage in the previous stats entry, using the comparison function
    that find_func_from_metric_name picks from ``metric_func``. When the current value
    equals the resulting best value and ``stage`` is the selection stage, the model
    weights are stored under "best_<metric_name>".

    Arguments:
        model {[BaseModel]} -- [Model whose state_dict, optimizer and schedulers are saved]
        metrics_holder {[Dict]} -- [Need to contain stage, epoch, current_metrics]
        metric_func {[Dict]} -- [Mapping handed to find_func_from_metric_name to obtain,
            per metric, a function called as func(best_so_far, current) -> new best]
        **kwargs -- forwarded unchanged to self._checkpoint.save_objects
    """
    metrics = metrics_holder["current_metrics"]
    stage = metrics_holder["stage"]
    epoch = metrics_holder["epoch"]

    stats = self._checkpoint.stats
    # Snapshot the weights so the stored copy is independent of further training updates.
    state_dict = copy.deepcopy(model.state_dict())

    current_stat = {"epoch": epoch}

    models_to_save = self._checkpoint.models
    if stage not in stats:
        stats[stage] = []

    if stage == "train":
        models_to_save[Checkpoint._LATEST] = state_dict
    elif stats[stage]:
        latest_stats = stats[stage][-1]
        improvements = []

        for metric_name, current_metric_value in metrics.items():
            best_key = "best_{}".format(metric_name)
            current_stat[metric_name] = current_metric_value

            # Bind to a separate local: reassigning ``metric_func`` here would replace the
            # lookup mapping with a function and break the lookup for the next metric.
            compare = self.find_func_from_metric_name(metric_name, metric_func)
            best_metric_from_stats = latest_stats.get(best_key, current_metric_value)
            best_value = compare(best_metric_from_stats, current_metric_value)
            current_stat[best_key] = best_value

            # The current value is the best under ``compare``; keep these weights, but only
            # when evaluating on the stage used for model selection.
            if self._selection_stage == stage and current_metric_value == best_value:
                models_to_save[best_key] = state_dict
                improvements.append("{}: {} -> {}".format(metric_name, best_metric_from_stats, best_value))

        if improvements:
            colored_print(COLORS.VAL_COLOR, ", ".join(improvements))
    else:
        # stats[stage] is empty: the first record for this stage is trivially its own best.
        # NOTE(review): unlike the branch above, this stores best_* weights without checking
        # self._selection_stage — confirm this is intended.
        for metric_name, metric_value in metrics.items():
            best_key = "best_{}".format(metric_name)
            current_stat[metric_name] = metric_value
            current_stat[best_key] = metric_value
            models_to_save[best_key] = state_dict

    self._checkpoint.stats[stage].append(current_stat)
    self._checkpoint.save_objects(models_to_save, stage, current_stat, model.optimizer, model.schedulers, **kwargs)
def log_optimizers(self):
    """Print the optimizer, both schedulers and the gradient-accumulation step, one line each."""
    # (message template, attribute on self that fills it), printed in this order.
    report_lines = (
        ("Optimizer: {}", "_optimizer"),
        ("Learning Rate Scheduler: {}", "_lr_scheduler"),
        ("BatchNorm Scheduler: {}", "_bn_scheduler"),
        ("Accumulated gradients: {}", "_accumulated_gradient_step"),
    )
    for template, attr_name in report_lines:
        colored_print(COLORS.Green, template.format(getattr(self, attr_name)))