def test_accumulator_detached():
    """The accumulated distance sum must not require grad after an update."""
    predictions = torch.tensor([[3.0, 4.0], [-3.0, -4.0]], requires_grad=True)
    targets = torch.zeros(2, 2)

    metric = MeanPairwiseDistance()
    metric.update((predictions, targets))

    # The predictions require grad; the internal accumulator must not.
    assert not metric._sum_of_distances.requires_grad
def test_compute():
    """Check ``compute()`` on two batches separated by a ``reset()``.

    Every prediction row is compared against an all-zero target row; the
    expected values (5.0 and 8.0) equal the Euclidean norm of those rows.
    """
    mpd = MeanPairwiseDistance()

    # Rows (3, 4) and (-3, -4) both have Euclidean norm 5.
    # ``torch.tensor`` is used instead of the legacy ``torch.Tensor``
    # constructor, matching the other tests in this file.
    y_pred = torch.tensor([[3.0, 4.0], [-3.0, -4.0]])
    y = torch.zeros(2, 2)
    mpd.update((y_pred, y))
    assert mpd.compute() == approx(5.0)

    # The expected 8.0 below only holds if the first batch no longer counts.
    mpd.reset()
    # Rows of four +/-4 entries have Euclidean norm sqrt(4 * 16) = 8.
    y_pred = torch.tensor([[4.0, 4.0, 4.0, 4.0], [-4.0, -4.0, -4.0, -4.0]])
    y = torch.zeros(2, 4)
    mpd.update((y_pred, y))
    assert mpd.compute() == approx(8.0)
# Example #3
# 0
def _test_distrib_accumulator_device(device):
    """Check the metric and its accumulator stay on the requested device.

    The check runs for the CPU device and, unless ``device`` is an XLA
    device, also for ``idist.device()``. It is performed both right after
    construction and after one ``update()`` call.

    Args:
        device: object exposing a ``type`` attribute; only ``device.type``
            is read, to decide whether ``idist.device()`` is also tested.
    """

    def _assert_devices(mpd, metric_device):
        # Both the metric's device and its accumulator must match the request.
        for dev in (mpd._device, mpd._sum_of_distances.device):
            assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

    metric_devices = [torch.device("cpu")]
    if device.type != "xla":
        metric_devices.append(idist.device())

    for metric_device in metric_devices:
        mpd = MeanPairwiseDistance(device=metric_device)
        _assert_devices(mpd, metric_device)

        # ``torch.tensor`` replaces the legacy ``torch.Tensor`` constructor.
        y_pred = torch.tensor([[3.0, 4.0], [-3.0, -4.0]])
        y = torch.zeros(2, 2)
        mpd.update((y_pred, y))

        # Same check again after the update.
        _assert_devices(mpd, metric_device)