Esempio n. 1
0
    def load_metrics(self):
        """
        AUTHORS:
        --------
        Samuel Westlake, Alix Leroy

        DESCRIPTION:
        ------------
        Build a deeplodocus Metrics object from the metrics configuration and
        hand the same object to every inferer that currently exists
        (trainer, validator, tester and predictor).

        PARAMETERS:
        -----------
        None

        RETURN:
        -------
        return: None
        """
        self.loading_message("Metrics")
        metrics_config = self.config.metrics.get()
        self.metrics = Metrics(metrics_config)

        # Inferers that have not been created yet are left as None and skipped
        inferers = [self.trainer, self.validator, self.tester, self.predictor]
        for inferer in (x for x in inferers if x is not None):
            # Every inferer shares the very same Metrics instance
            inferer.metrics = self.metrics
            Notification(DEEP_NOTIF_INFO, "%s : Metrics updated" % inferer.name)
Esempio n. 2
0
 def __init__(self,
              dataset: Dataset,
              model,
              transform_manager,
              losses: Losses,
              metrics: Union[Metrics, None] = None,
              batch_size: int = 32,
              num_workers: int = 1,
              shuffle: Flag = DEEP_SHUFFLE_NONE,
              name: str = "Inferer"):
     """
     DESCRIPTION:
     ------------
     Store the inference components and build a DataLoader over the dataset.

     PARAMETERS:
     -----------
     :param dataset: The dataset to iterate over
     :param model: The model (stored as-is)
     :param transform_manager: The transform manager (stored as-is)
     :param losses: The losses (stored as-is)
     :param metrics: The metrics; an empty Metrics() is created when None
     :param batch_size: Batch size passed to the DataLoader
     :param num_workers: Number of workers passed to the DataLoader
     :param shuffle: Shuffle flag, resolved against DEEP_LIST_SHUFFLE
                     (non-fatal, falls back to DEEP_SHUFFLE_NONE)
     :param name: Name of this inferer

     RETURN:
     -------
     :return: None
     """
     # Components kept by reference
     self.dataset = dataset
     self.model = model
     self.transform_manager = transform_manager
     self.losses = losses
     if metrics is None:
         metrics = Metrics()
     self.metrics = metrics

     # Loader settings and identification
     self.name = name
     self.batch_size = batch_size
     self.num_workers = num_workers
     self.shuffle = get_corresponding_flag(
         DEEP_LIST_SHUFFLE,
         shuffle,
         fatal=False,
         default=DEEP_SHUFFLE_NONE,
     )

     # The DataLoader is always built with shuffle=False; the resolved flag is
     # only stored on self.shuffle here (NOTE(review): presumably applied
     # elsewhere, e.g. by the dataset itself — confirm).
     self.dataloader = DataLoader(
         dataset=dataset,
         batch_size=batch_size,
         shuffle=False,
         num_workers=num_workers,
     )
Esempio n. 3
0
    def load_metrics(self):
        """
        AUTHORS:
        --------
        :author: Alix Leroy

        DESCRIPTION:
        ------------
        Load the metrics described in the metrics configuration and store them
        in self.metrics as a Metrics collection.

        Each configured entry is resolved with get_module, which searches the
        default metric and loss modules. A resolved class is instantiated
        with the entry's kwargs; any other object (e.g. a plain function) is
        used as-is. If no metrics are configured, an empty collection is stored.

        PARAMETERS:
        -----------
        None

        RETURN:
        -------
        :return: None (the result is stored in self.metrics)
        """
        metrics = {}
        metrics_config = self.config.metrics.get()
        if metrics_config:
            # Metrics can come from either the default metric or loss modules;
            # the lookup table is the same for every entry, so build it once
            browse = {**DEEP_MODULE_METRICS, **DEEP_MODULE_LOSSES}
            for key, config in metrics_config.items():

                # Describe where the metric is expected to come from (for notification purposes)
                if config.module is None:
                    metric_path = "%s : %s from default modules" % (key, config.name)
                else:
                    metric_path = "%s : %s from %s" % (key, config.name, config.module)

                # Notify the user which metric is being collected and from where
                Notification(DEEP_NOTIF_INFO, DEEP_MSG_METRIC_LOADING % metric_path)

                # Get the metric object
                metric, module = get_module(
                    name=config.name,
                    module=config.module,
                    browse=browse,
                    silence=True
                )

                # If the metric is not found by get_module
                # NOTE(review): assumes a DEEP_NOTIF_FATAL Notification aborts
                # execution; otherwise the code below would run with metric=None — confirm.
                if metric is None:
                    Notification(DEEP_NOTIF_FATAL, DEEP_MSG_METRIC_NOT_FOUND % config.name)

                # Classes are instantiated with their configured kwargs; other objects are used directly
                if inspect.isclass(metric):
                    method = metric(**config.kwargs.get())
                else:
                    method = metric

                # Add to the dictionary of metrics (built once) and notify of success
                metric_name = str(key)
                metrics[metric_name] = Metric(name=metric_name, method=method)
                Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_METRIC_LOADED % (key, config.name, module))
        else:
            Notification(DEEP_NOTIF_INFO, DEEP_MSG_METRIC_NONE)

        self.metrics = Metrics(metrics)