def test_aggregate_metrics(self):
  metrics_dict = [
      collections.OrderedDict(
          loss=metrics.MeanMetric(total=2, count=11),
          num_examples=metrics.CountMetric(count=3)),
      collections.OrderedDict(
          loss=metrics.MeanMetric(total=6, count=9),
          num_examples=metrics.CountMetric(count=7))
  ]
  aggregated = evaluation_util.aggregate_metrics(metrics_dict)
  self.assertAlmostEqual(aggregated['loss'], 0.4)
  self.assertEqual(aggregated['num_examples'], 10.)
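# A minimal sketch of the aggregation the test above exercises, assuming each
# Metric exposes a merge(other) combinator and a result() accessor (assumed
# names; the real evaluation_util.aggregate_metrics may be implemented
# differently). Means combine by summing totals and counts, (2 + 6) / (11 + 9)
# = 0.4, and counts combine by summing, 3 + 7 = 10.
import collections


def aggregate_metrics_sketch(metrics_list):
  # Fold the per-batch metric dicts together, then reduce each to a scalar.
  combined = metrics_list[0]
  for batch_metrics in metrics_list[1:]:
    combined = collections.OrderedDict(
        (name, combined[name].merge(metric))
        for name, metric in batch_metrics.items())
  return collections.OrderedDict(
      (name, metric.result()) for name, metric in combined.items())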
def evaluate(self, params: Params, batch: Batch) -> Dict[str, metrics.Metric]:
  """Evaluates model on input batch."""
  # No PRNG is needed at evaluation time, so apply_fn receives rng=None.
  rng = None
  preds = self.apply_fn(params, rng, batch, **self.test_kwargs)
  loss_metric = self.loss_fn(batch, preds)
  # The loss metric's count doubles as the number of examples in the batch.
  num_examples = loss_metric.count
  metrics_dict = collections.OrderedDict(
      loss=loss_metric,
      regularizer=metrics.MeanMetric(
          total=self.reg_fn(params), count=num_examples),
      num_examples=metrics.CountMetric(count=num_examples))
  # Any user-supplied metric functions are evaluated on the same predictions.
  for metric_name, metric_fn in self.metrics_fn_map.items():
    metrics_dict[metric_name] = metric_fn(batch, preds)
  return metrics_dict
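# Illustrative glue between evaluate() and aggregate_metrics: run the model
# over every batch, then reduce the per-batch metrics to scalars. The name
# evaluate_dataset_sketch and the `model` / `batches` parameters are
# hypothetical, not part of the real API.
def evaluate_dataset_sketch(model, params, batches):
  per_batch_metrics = [model.evaluate(params, batch) for batch in batches]
  return evaluation_util.aggregate_metrics(per_batch_metrics)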
def test_raises_on_non_scalar_value(self):
  with self.assertRaises(TypeError):
    metrics.MeanMetric(total=jnp.array([1]), count=jnp.array([2]))
  with self.assertRaises(TypeError):
    metrics.CountMetric(count=jnp.array([3]))
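# A plausible shape for the validation the test above relies on; this is an
# assumption, and the real MeanMetric/CountMetric may validate differently.
# jnp.ndim treats Python scalars as 0-d, so plain numbers pass while
# jnp.array([1]) (ndim 1) raises.
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class MeanMetricSketch:
  total: float
  count: float

  def __post_init__(self):
    # Reject any field whose value is not a scalar (0-d).
    for name in ('total', 'count'):
      if jnp.ndim(getattr(self, name)) != 0:
        raise TypeError(
            f'{name} must be a scalar, got {getattr(self, name)!r}')

  def result(self):
    return self.total / self.count

  def __str__(self):
    # Dataclass repr followed by the computed mean, mirroring the expected
    # 'MeanMetric(total=2, count=4) => 0.5' format tested below.
    return f'{self!r} => {self.result()}'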
def test_str_version_of_metric(self):
  metric = metrics.MeanMetric(total=2, count=4)
  self.assertEqual('MeanMetric(total=2, count=4) => 0.5', str(metric))
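# Checking the sketch class defined earlier against the same format; only the
# class-name prefix differs from the real MeanMetric, since the dataclass repr
# uses the sketch's own name.
assert str(MeanMetricSketch(total=2, count=4)) == (
    'MeanMetricSketch(total=2, count=4) => 0.5')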