class Metrics(object):
    """Collection of all metrics defined in the configuration.

    :param metric_cfg: metric part of config
    :type metric_cfg: dict or Config
    """

    # Metric names that __call__/names/results will dispatch to.
    __supported_call__ = [
        'accuracy', 'DetMetric', 'IoUMetric', 'SRMetric', 'JDDTrainerPSNRMetric'
    ]

    def __init__(self, metric_cfg):
        """Init Metrics: instantiate every metric listed in metric_cfg.

        :param metric_cfg: metric part of config; a single item or a list of
            items, each carrying a 'type' key plus constructor kwargs.
        :type metric_cfg: dict or Config or list
        """
        metric_config = deepcopy(metric_cfg)
        self.mdict = {}
        if not isinstance(metric_config, list):
            metric_config = [metric_config]
        for metric_item in metric_config:
            metric_name = metric_item.pop('type')
            if ClassFactory.is_exists(ClassType.METRIC, metric_name):
                metric_class = ClassFactory.get_cls(ClassType.METRIC, metric_name)
            else:
                # Fall back to the metrics shipped with vega.core.metrics.
                metric_class = getattr(
                    importlib.import_module('vega.core.metrics'), metric_name)
            if isfunction(metric_class):
                # Plain metric functions get their config bound via partial.
                metric_class = partial(metric_class, **metric_item)
            else:
                metric_class = metric_class(**metric_item)
            self.mdict[metric_name] = metric_class
        self.mdict = Config(self.mdict)

    def __call__(self, output=None, target=None, *args, **kwargs):
        """Calculate all supported metrics by using output and target.

        :param output: predicted output by networks
        :type output: torch tensor
        :param target: target label data
        :type target: torch tensor
        :return: performance of metrics
        :rtype: list
        """
        return [
            self.mdict[key](output, target, *args, **kwargs)
            for key in self.mdict if key in self.__supported_call__
        ]

    def reset(self):
        """Reset states for new evaluation after each epoch."""
        for val in self.mdict.values():
            val.reset()

    @property
    def names(self):
        """Return the names of all supported metrics.

        :rtype: list
        """
        return [name for name in self.mdict if name in self.__supported_call__]

    @property
    def results(self):
        """Return a deep copy of the supported metrics' summaries.

        :rtype: list
        """
        results = [
            self.mdict[name].summary() for name in self.mdict
            if name in self.__supported_call__
        ]
        return deepcopy(results)

    @property
    def results_dict(self):
        """Return a deep copy of all metrics' summaries keyed by metric name.

        :rtype: dict
        """
        rdict = {key: self.mdict[key].summary() for key in self.mdict}
        return deepcopy(rdict)

    def __getattr__(self, key):
        """Get a metric by key name.

        :param key: metric name
        :type key: str
        :raises AttributeError: if the metric is not registered, so that
            hasattr/copy/pickle protocols behave correctly.
        """
        try:
            # Read mdict through __dict__ to avoid infinite recursion when
            # 'mdict' itself is not yet set (e.g. during unpickling).
            return self.__dict__['mdict'][key]
        except KeyError:
            raise AttributeError("Metric '{}' is not defined".format(key))
class Metrics(object):
    """Collection of all metrics defined in the configuration.

    :param metric_cfg: metric part of config
    :type metric_cfg: dict or Config
    """

    config = MetricsConfig()

    def __init__(self, metric_cfg=None):
        """Init Metrics: instantiate every metric listed in metric_cfg.

        Falls back to the class-level MetricsConfig when metric_cfg is None.

        :param metric_cfg: metric part of config; a single item or a list of
            items, each carrying a 'type' key and an optional 'params' dict.
        :type metric_cfg: dict or Config or list or None
        """
        self.mdict = {}
        metric_config = obj2config(
            self.config) if not metric_cfg else deepcopy(metric_cfg)
        if not isinstance(metric_config, list):
            metric_config = [metric_config]
        if metric_config:
            # Loop-invariant check hoisted out of the loop below: verify the
            # default metric type from config is registered. The result is
            # deliberately discarded — only the lookup's validation matters.
            ClassFactory.get_cls(ClassType.METRIC, self.config.type)
        for metric_item in metric_config:
            metric_name = metric_item.pop('type')
            metric_class = ClassFactory.get_cls(ClassType.METRIC, metric_name)
            params = metric_item.get("params", {})
            if isfunction(metric_class):
                # Plain metric functions get their params bound via partial.
                metric_class = partial(metric_class, **params)
            else:
                metric_class = metric_class(**params)
            self.mdict[metric_name] = metric_class
        self.mdict = Config(self.mdict)

    def __call__(self, output=None, target=None, *args, **kwargs):
        """Calculate all supported metrics by using output and target.

        :param output: predicted output by networks
        :type output: torch tensor
        :param target: target label data
        :type target: torch tensor
        :return: performance of metrics
        :rtype: list
        """
        return [
            self.mdict[key](output, target, *args, **kwargs)
            for key in self.mdict
        ]

    def reset(self):
        """Reset states for new evaluation after each epoch."""
        for val in self.mdict.values():
            val.reset()

    @property
    def results(self):
        """Merge every metric's result mapping into one dict.

        :rtype: dict
        """
        res = {}
        # Only the values are needed, so iterate .values() directly.
        for metric in self.mdict.values():
            res.update(metric.result)
        return res

    @property
    def objectives(self):
        """Return each metric's objective keyed by metric name.

        :rtype: dict
        """
        return {name: self.mdict.get(name).objective for name in self.mdict}

    def __getattr__(self, key):
        """Get a metric by key name.

        :param key: metric name
        :type key: str
        :raises AttributeError: if the metric is not registered, so that
            hasattr/copy/pickle protocols behave correctly.
        """
        try:
            # Read mdict through __dict__ to avoid infinite recursion when
            # 'mdict' itself is not yet set (e.g. during unpickling).
            return self.__dict__['mdict'][key]
        except KeyError:
            raise AttributeError("Metric '{}' is not defined".format(key))