class Metrics(object):
    """Metrics class of all metrics defined in cfg.

    :param metric_cfg: metric part of config; when falsy, the class-level
        ``config`` is used instead
    :type metric_cfg: dict or Config
    """

    config = MetricsConfig()

    def __init__(self, metric_cfg=None):
        """Init Metrics.

        Each config item must carry a ``type`` key naming a registered metric
        and may carry ``params`` that are forwarded to it. Function metrics
        are bound with ``functools.partial``; other metrics are instantiated.
        """
        self.mdict = {}
        # Honour the caller-supplied config (it used to be silently ignored).
        # Copy it, because ``pop('type')`` below mutates each item.
        if metric_cfg:
            metric_config = deepcopy(metric_cfg)
        else:
            metric_config = self.config.to_dict()
        if not isinstance(metric_config, list):
            metric_config = [metric_config]
        for metric_item in metric_config:
            # Return value is unused; presumably kept so an unregistered
            # default type fails early. NOTE(review): confirm it is needed.
            ClassFactory.get_cls(ClassType.METRIC, self.config.type)
            metric_name = metric_item.pop('type')
            metric_class = ClassFactory.get_cls(ClassType.METRIC, metric_name)
            params = metric_item.get("params", {})
            if isfunction(metric_class):
                metric_class = partial(metric_class, **params)
            else:
                metric_class = metric_class(**params)
            self.mdict[metric_name] = metric_class
        self.mdict = Config(self.mdict)
        self.metric_results = dict()

    def __call__(self, output=None, target=None, *args, **kwargs):
        """Calculate all supported metrics by using output and target.

        :param output: predicted output by networks
        :type output: torch tensor
        :param target: target label data
        :type target: torch tensor
        :return: performance of metrics, merged from every metric's result
        :rtype: dict
        """
        pfms = {}
        for key in self.mdict:
            metric = self.mdict[key]
            pfms.update(metric(output, target, *args, **kwargs))
        # Register every reported key with a None placeholder; actual values
        # are supplied later through ``update``.
        for key in pfms:
            self.metric_results[key] = None
        return pfms

    def reset(self):
        """Reset states for new evaluation after each epoch."""
        self.metric_results = dict()

    @property
    def results(self):
        """Return a deep copy of the metrics results."""
        return deepcopy(self.metric_results)

    @property
    def objectives(self):
        """Return the ``objective`` attribute of every metric, keyed by name."""
        return {name: self.mdict.get(name).objective for name in self.mdict}

    def update(self, metrics):
        """Update the metrics results.

        :param metrics: outside metrics; every key is written (or overwritten)
            in the stored results
        :type metrics: dict
        """
        for key in metrics:
            self.metric_results[key] = metrics[key]
class Metrics(object):
    """Metrics class of all metrics defined in cfg.

    :param metric_cfg: metric part of config; when falsy, the class-level
        ``config`` is used instead
    :type metric_cfg: dict or Config
    """

    config = MetricsConfig()

    def __init__(self, metric_cfg=None):
        """Init Metrics.

        Each config item must carry a ``type`` key naming a registered metric
        and may carry ``params`` that are forwarded to it. Function metrics
        are bound with ``functools.partial``; other metrics are instantiated.
        """
        self.mdict = {}
        # Copy the caller's config because ``pop('type')`` mutates each item.
        if metric_cfg:
            metric_config = deepcopy(metric_cfg)
        else:
            metric_config = self.config.to_dict()
        if not isinstance(metric_config, list):
            metric_config = [metric_config]
        for metric_item in metric_config:
            # Return value is unused; presumably kept so an unregistered
            # default type fails early. NOTE(review): confirm it is needed.
            ClassFactory.get_cls(ClassType.METRIC, self.config.type)
            metric_name = metric_item.pop('type')
            metric_class = ClassFactory.get_cls(ClassType.METRIC, metric_name)
            params = metric_item.get("params", {})
            if isfunction(metric_class):
                metric_class = partial(metric_class, **params)
            else:
                metric_class = metric_class(**params)
            self.mdict[metric_name] = metric_class
        self.mdict = Config(self.mdict)

    def __call__(self, output=None, target=None, *args, **kwargs):
        """Calculate all supported metrics by using output and target.

        :param output: predicted output by networks
        :type output: torch tensor
        :param target: target label data
        :type target: torch tensor
        :return: performance of metrics, one entry per metric in ``mdict``
        :rtype: list
        """
        pfms = []
        for key in self.mdict:
            metric = self.mdict[key]
            pfms.append(metric(output, target, *args, **kwargs))
        return pfms

    def reset(self):
        """Reset states for new evaluation after each epoch."""
        # NOTE(review): metrics bound via ``partial`` have no ``reset``
        # attribute, so function metrics would raise here.
        for val in self.mdict.values():
            val.reset()

    @property
    def results(self):
        """Return metrics results merged from every metric's ``result``."""
        res = {}
        for metric in self.mdict.values():
            res.update(metric.result)
        return res

    @property
    def objectives(self):
        """Return objectives results.

        A metric whose ``objective`` is a dict contributes all of its entries;
        otherwise the objective is stored under the metric's name.
        """
        _objs = {}
        for name in self.mdict:
            objective = self.mdict.get(name).objective
            if isinstance(objective, dict):
                _objs.update(objective)
            else:
                _objs[name] = objective
        return _objs

    def __getattr__(self, key):
        """Get a metric by key name.

        Only invoked when normal attribute lookup fails. ``mdict`` is read
        straight from the instance ``__dict__`` so an instance without it
        (e.g. one created by copy/pickle reconstruction) does not recurse
        back into ``__getattr__``. Unknown names raise ``AttributeError``
        (not ``KeyError``) so ``hasattr``, ``getattr`` with a default,
        ``copy`` and ``pickle`` work as expected.

        :param key: metric name
        :type key: str
        :raises AttributeError: if no metric with that name exists
        """
        mdict = self.__dict__.get('mdict')
        if mdict is not None:
            try:
                return mdict[key]
            except KeyError:
                pass
        raise AttributeError("{!r} object has no attribute {!r}".format(
            type(self).__name__, key))