예제 #1
0
    def test_aggregate_metrics(self):
        """Checks `aggregate_metrics` on a list of two metric dictionaries."""
        first = collections.OrderedDict(
            loss=metrics.MeanMetric(total=2, count=11),
            num_examples=metrics.CountMetric(count=3))
        second = collections.OrderedDict(
            loss=metrics.MeanMetric(total=6, count=9),
            num_examples=metrics.CountMetric(count=7))

        result = evaluation_util.aggregate_metrics([first, second])

        # Expected values are consistent with pooling the inputs:
        # (2 + 6) / (11 + 9) == 0.4 for 'loss' and 3 + 7 == 10 for
        # 'num_examples'.
        self.assertAlmostEqual(result['loss'], 0.4)
        self.assertEqual(result['num_examples'], 10.)
예제 #2
0
파일: model.py 프로젝트: alabid/fedjax
 def evaluate(self, params: Params,
              batch: Batch) -> Dict[str, metrics.Metric]:
     """Computes evaluation metrics for a single input batch.

     Args:
       params: Model parameters, forwarded to `apply_fn` and `reg_fn`.
       batch: Input batch, forwarded to `apply_fn`, `loss_fn` and every
         function in `metrics_fn_map`.

     Returns:
       Ordered mapping holding 'loss', 'regularizer' and 'num_examples',
       followed by one entry per item of `metrics_fn_map`.
     """
     # No rng is passed when evaluating; `test_kwargs` are forwarded as-is.
     predictions = self.apply_fn(params, None, batch, **self.test_kwargs)
     loss = self.loss_fn(batch, predictions)
     example_count = loss.count
     # The regularizer value is paired with the same count as the loss.
     regularizer = metrics.MeanMetric(total=self.reg_fn(params),
                                      count=example_count)
     results = collections.OrderedDict(
         loss=loss,
         regularizer=regularizer,
         num_examples=metrics.CountMetric(count=example_count))
     # Extra metrics are added in `metrics_fn_map` iteration order.
     results.update((name, fn(batch, predictions))
                    for name, fn in self.metrics_fn_map.items())
     return results
예제 #3
0
 def test_raises_on_non_scalar_value(self):
     """Checks that array-valued arguments make both metrics raise TypeError."""
     # Construction is wrapped in lambdas so that everything, including the
     # array creation, runs inside the checked call.
     self.assertRaises(
         TypeError,
         lambda: metrics.MeanMetric(total=jnp.array([1]),
                                    count=jnp.array([2])))
     self.assertRaises(
         TypeError,
         lambda: metrics.CountMetric(count=jnp.array([3])))
예제 #4
0
 def test_str_version_of_metric(self):
     """Checks the string form of a MeanMetric built with total=2, count=4."""
     expected = 'MeanMetric(total=2, count=4) => 0.5'
     rendered = str(metrics.MeanMetric(total=2, count=4))
     self.assertEqual(expected, rendered)