def test_accumulator_detached():
    """The running sum must not carry autograd history after ``update``.

    The prediction tensor is created with ``requires_grad=True``. After a
    single update, the metric's internal ``_sum_of_distances`` accumulator
    is expected to report ``requires_grad == False``.
    """
    metric = MeanPairwiseDistance()
    target = torch.zeros(2, 2)
    prediction = torch.tensor([[3.0, 4.0], [-3.0, -4.0]], requires_grad=True)

    metric.update((prediction, target))

    accumulator = metric._sum_of_distances
    assert not accumulator.requires_grad
def test_compute():
    """``compute`` returns the mean per-row distance; ``reset`` clears state.

    Each case compares a batch against an all-zero target. The expected
    values therefore equal the Euclidean norm of each row:
    ``[3, 4]`` has norm 5, and ``[4, 4, 4, 4]`` has norm 8.
    ``approx`` is used because the metric's result is a float that may
    differ slightly from the exact norm (e.g. a small epsilon term — verify
    against the metric's implementation).
    """
    mpd = MeanPairwiseDistance()

    # Both rows have Euclidean norm 5, so their mean is 5.
    y_pred = torch.tensor([[3.0, 4.0], [-3.0, -4.0]])
    y = torch.zeros(2, 2)
    mpd.update((y_pred, y))
    assert mpd.compute() == approx(5.0)

    # After reset only the new batch should count; both rows have norm 8.
    mpd.reset()
    y_pred = torch.tensor([[4.0, 4.0, 4.0, 4.0], [-4.0, -4.0, -4.0, -4.0]])
    y = torch.zeros(2, 4)
    mpd.update((y_pred, y))
    assert mpd.compute() == approx(8.0)
def _test_distrib_accumulator_device(device):
    """Check that the metric and its accumulator live on the requested device.

    The metric is always built on CPU. Unless ``device`` is an XLA device, it
    is also built on ``idist.device()``. For each metric device, both
    ``mpd._device`` and the ``_sum_of_distances`` accumulator must match that
    device, both right after construction and after one ``update``.

    Args:
        device: a ``torch.device``; only its ``type`` is inspected here.
    """

    def _check_devices(mpd, metric_device):
        # Shared assertion used before and after ``update``.
        for dev in (mpd._device, mpd._sum_of_distances.device):
            assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

    metric_devices = [torch.device("cpu")]
    # On XLA only the CPU metric device is exercised.
    if device.type != "xla":
        metric_devices.append(idist.device())

    for metric_device in metric_devices:
        mpd = MeanPairwiseDistance(device=metric_device)
        _check_devices(mpd, metric_device)

        y_pred = torch.tensor([[3.0, 4.0], [-3.0, -4.0]])
        y = torch.zeros(2, 2)
        mpd.update((y_pred, y))

        # Updating with CPU inputs must not move the accumulator off metric_device.
        _check_devices(mpd, metric_device)