def test_construct_metric_value_holders_one_metric_for_all_labels(self):
    """Check metric value holder names for two groups over two label columns.

    Two metric groups, each holding a single dummy metric, are paired with
    the label columns 'l1' and 'l2'. The test asserts that the holder at
    position [group][label] is named 'group_<group>_<label>'.
    """
    fake_hvd = mock.MagicMock()

    def _fake_allreduce(tensor, name):
        # Stand-in for hvd.allreduce: doubles its input.
        return 2 * tensor

    fake_hvd.allreduce = _fake_allreduce
    fake_hvd.local_size = lambda: 2

    holder_cls = remote._metric_cls()

    def torch_dummy_metric(outputs, labels):
        # Counts the number of (output, label) pairs as a float tensor.
        total = torch.tensor(0.)
        for _pair in zip(outputs, labels):
            total += 1
        return total

    fn_groups = [[torch_dummy_metric], [torch_dummy_metric]]
    labels = ['l1', 'l2']

    build_holders = remote._construct_metric_value_holders_fn()
    holders = build_holders(holder_cls, fn_groups, labels, fake_hvd)

    # Expected names indexed as [group][label], checked in row-major order.
    expected_names = [
        ['group_0_l1', 'group_0_l2'],
        ['group_1_l1', 'group_1_l2'],
    ]
    for group_idx, row in enumerate(expected_names):
        for label_idx, expected in enumerate(row):
            assert holders[group_idx][label_idx].name == expected
def test_metric_class(self):
    """Check the metric class's sum, n and avg after two updates.

    The fake horovod module's allreduce doubles each tensor. The expected
    values below (sum 6.0, n 2.0, avg 3.0) presumably reflect how the metric
    class applies allreduce during update(); confirm against
    remote._metric_cls if they change.
    """
    fake_hvd = mock.MagicMock()
    fake_hvd.allreduce = lambda tensor, name: 2 * tensor
    fake_hvd.local_size = lambda: 2

    metric = remote._metric_cls()('dummy_metric', fake_hvd)
    for value in (1.0, 2.0):
        metric.update(torch.tensor(value))

    assert metric.sum.item() == 6.0
    assert metric.n.item() == 2.0
    assert metric.avg.item() == 6.0 / 2.0