def load_metrics(self):
    """
    AUTHORS:
    --------

    Samuel Westlake, Alix Leroy

    DESCRIPTION:
    ------------

    Wrap the metrics configuration in a deeplodocus Metrics object and hand
    that same object to every inference component that has been created
    (trainer, validator, tester, predictor)

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    return: None
    """
    self.loading_message("Metrics")
    self.metrics = Metrics(self.config.metrics.get())

    # Components that have not been instantiated are None and are skipped
    components = (self.trainer, self.validator, self.tester, self.predictor)
    for component in filter(lambda c: c is not None, components):
        # Every component shares the single Metrics instance created above
        component.metrics = self.metrics
        Notification(DEEP_NOTIF_INFO, "%s : Metrics updated" % component.name)
def __init__(self,
             dataset: Dataset,
             model,
             transform_manager,
             losses: Losses,
             metrics: Union[Metrics, None] = None,
             batch_size: int = 32,
             num_workers: int = 1,
             shuffle: Flag = DEEP_SHUFFLE_NONE,
             name: str = "Inferer"):
    """
    DESCRIPTION:
    ------------

    Store the components needed for inference and build the DataLoader
    that iterates over the dataset

    PARAMETERS:
    -----------

    :param dataset->Dataset: The dataset to iterate over
    :param model: The model to run
    :param transform_manager: The transform manager
    :param losses->Losses: The losses
    :param metrics->Union[Metrics, None]: The metrics (an empty Metrics container is created if None)
    :param batch_size->int: Number of samples per batch given to the DataLoader
    :param num_workers->int: Number of DataLoader worker processes
    :param shuffle->Flag: The shuffling flag, resolved against DEEP_LIST_SHUFFLE
    :param name->str: Name of this inferer

    RETURN:
    -------

    :return: None
    """
    # Inference components
    self.dataset = dataset
    self.model = model
    self.transform_manager = transform_manager
    self.losses = losses
    self.metrics = metrics if metrics is not None else Metrics()

    # Loading parameters
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.name = name

    # Non-fatal lookup; presumably falls back to DEEP_SHUFFLE_NONE when the flag is unknown
    self.shuffle = get_corresponding_flag(
        DEEP_LIST_SHUFFLE,
        shuffle,
        fatal=False,
        default=DEEP_SHUFFLE_NONE,
    )

    # NOTE(review): the DataLoader never shuffles; self.shuffle is presumably
    # applied elsewhere (e.g. by the dataset) - confirm
    self.dataloader = DataLoader(
        dataset=self.dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
    )
def load_metrics(self):
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Load every metric listed in the configuration, wrap each one in a Metric
    object and store the resulting Metrics container in self.metrics

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None (the loaded metrics are stored in self.metrics)
    """
    metrics = {}
    metrics_config = self.config.metrics.get()

    if metrics_config:
        for key, config in metrics_config.items():
            metric_name = str(key)

            # Describe where the metric is expected to come from (for notification purposes)
            if config.module is None:
                metric_path = "%s : %s from default modules" % (key, config.name)
            else:
                metric_path = "%s : %s from %s" % (key, config.name, config.module)

            # Notify the user which metric is being collected and from where
            Notification(DEEP_NOTIF_INFO, DEEP_MSG_METRIC_LOADING % metric_path)

            # Look the name up in both the metric and the loss modules
            metric, module = get_module(
                name=config.name,
                module=config.module,
                browse={**DEEP_MODULE_METRICS, **DEEP_MODULE_LOSSES},
                silence=True
            )

            # NOTE(review): assumes a FATAL notification aborts execution;
            # otherwise the code below would continue with metric = None
            if metric is None:
                Notification(DEEP_NOTIF_FATAL, DEEP_MSG_METRIC_NOT_FOUND % config.name)

            # Classes are instantiated with their configured kwargs; plain functions are used as-is
            if inspect.isclass(metric):
                method = metric(**config.kwargs.get())
            else:
                method = metric

            # Register the metric once under its configuration key and notify of success
            metrics[metric_name] = Metric(name=metric_name, method=method)
            Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_METRIC_LOADED % (key, config.name, module))
    else:
        Notification(DEEP_NOTIF_INFO, DEEP_MSG_METRIC_NONE)

    self.metrics = Metrics(metrics)