Exemple #1
0
    def test_construct_metric_value_holders_one_metric_for_all_labels(self):
        """Check the names of the value holders built for two metric groups.

        Each of the two metric groups supplies a single metric function while
        two label columns are given; the resulting holders are expected to be
        named ``group_<group index>_<label column>``.
        """
        def fake_allreduce(tensor, name):
            # Stand-in for hvd.allreduce: simply doubles the tensor.
            return 2 * tensor

        def fake_local_size():
            # Stand-in for hvd.local_size.
            return 2

        hvd_mock = mock.MagicMock()
        hvd_mock.allreduce = fake_allreduce
        hvd_mock.local_size = fake_local_size
        metric_class = remote._metric_cls()

        def torch_dummy_metric(outputs, labels):
            # Counts the (output, label) pairs; the values are ignored.
            count = torch.tensor(0.)
            for _pair in zip(outputs, labels):
                count += 1
            return count

        label_columns = ['l1', 'l2']
        metric_fn_groups = [[torch_dummy_metric] for _ in range(2)]

        build_holders = remote._construct_metric_value_holders_fn()
        metric_values = build_holders(
            metric_class, metric_fn_groups, label_columns, hvd_mock)

        # (group index, label index, expected holder name)
        expected_names = [
            (0, 0, 'group_0_l1'),
            (0, 1, 'group_0_l2'),
            (1, 0, 'group_1_l1'),
            (1, 1, 'group_1_l2'),
        ]
        for group_idx, label_idx, expected in expected_names:
            assert metric_values[group_idx][label_idx].name == expected
Exemple #2
0
    def test_metric_class(self):
        """Check sum, count and average of a metric after two updates.

        The Horovod handle is a mock whose ``allreduce`` doubles its input.
        After updating with 1.0 and then 2.0, the metric is expected to
        report a sum of 6.0, a count of 2.0 and an average of 6.0 / 2.0.
        """
        def fake_allreduce(tensor, name):
            # Stand-in for hvd.allreduce: simply doubles the tensor.
            return 2 * tensor

        def fake_local_size():
            # Stand-in for hvd.local_size.
            return 2

        hvd_mock = mock.MagicMock()
        hvd_mock.allreduce = fake_allreduce
        hvd_mock.local_size = fake_local_size

        metric_class = remote._metric_cls()
        metric = metric_class('dummy_metric', hvd_mock)
        for value in (1.0, 2.0):
            metric.update(torch.tensor(value))

        expected_sum = 6.0
        expected_n = 2.0
        assert metric.sum.item() == expected_sum
        assert metric.n.item() == expected_n
        assert metric.avg.item() == expected_sum / expected_n