예제 #1
0
def test_accumulator_detached():
    """The metric's running sum must not carry autograd history.

    The predictions are created with ``requires_grad=True``. After a single
    ``update`` the internal ``_sum_of_absolute_errors`` accumulator is checked
    to have ``requires_grad`` unset, i.e. it was accumulated detached from
    the prediction graph.
    """
    metric = MeanAbsoluteError()

    predictions = torch.tensor([[2.0], [-2.0]], requires_grad=True)
    targets = torch.zeros(2)
    metric.update((predictions, targets))

    accumulator = metric._sum_of_absolute_errors
    assert not accumulator.requires_grad
예제 #2
0
def test_compute():
    """``compute`` returns the MAE of the updates seen since the last reset.

    Round 1 feeds predictions ``[[2.0], [-2.0]]`` against zero targets and
    expects 2.0. After ``reset`` the second round feeds ``[[3.0], [-3.0]]``
    and expects exactly 3.0. The exact value shows the first round's state
    was discarded rather than blended in.
    """
    mae = MeanAbsoluteError()

    # ``torch.tensor`` (the factory used by the sibling tests) rather than the
    # legacy ``torch.Tensor`` constructor; both give float32 tensors here.
    y_pred = torch.tensor([[2.0], [-2.0]])
    y = torch.zeros(2)
    mae.update((y_pred, y))
    assert mae.compute() == 2.0

    # After reset only the second batch should contribute to the result.
    mae.reset()
    y_pred = torch.tensor([[3.0], [-3.0]])
    y = torch.zeros(2)
    mae.update((y_pred, y))
    assert mae.compute() == 3.0
예제 #3
0
def _test_distrib_accumulator_device(device):
    """Check that MeanAbsoluteError keeps its state on the requested device.

    The candidate metric devices are CPU, plus ``idist.device()`` whenever
    ``device`` is not an XLA device. For each candidate, the metric's
    ``_device`` and the device of its ``_sum_of_absolute_errors`` accumulator
    must equal the requested device. This is checked right after construction
    and again after one ``update``.
    """

    def _check_devices(metric, expected):
        # Gather both devices first, then assert each one against ``expected``.
        observed = (metric._device, metric._sum_of_absolute_errors.device)
        for dev in observed:
            assert dev == expected, f"{type(dev)}:{dev} vs {type(expected)}:{expected}"

    candidates = [torch.device("cpu")]
    if device.type != "xla":
        candidates.append(idist.device())

    for metric_device in candidates:
        mae = MeanAbsoluteError(device=metric_device)
        _check_devices(mae, metric_device)

        mae.update((torch.tensor([[2.0], [-2.0]]), torch.zeros(2)))
        _check_devices(mae, metric_device)
def _test_distrib_accumulator_device(device):
    """Check that MeanAbsoluteError keeps its state on the requested device.

    The candidate metric devices are CPU, plus ``idist.device()`` whenever
    ``device`` is not an XLA device. For each candidate, both the metric's
    ``_device`` and the device of its ``_sum_of_absolute_errors`` accumulator
    must equal the requested device. This is checked right after construction
    and again after one ``update``.

    NOTE(review): a function with this same name is defined just above; when
    both live in one module this later ``def`` rebinds the name. Consider
    keeping only one of them.
    """
    metric_devices = [torch.device("cpu")]
    if device.type != "xla":
        metric_devices.append(idist.device())
    for metric_device in metric_devices:
        mae = MeanAbsoluteError(device=metric_device)

        for dev in [mae._device, mae._sum_of_absolute_errors.device]:
            assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

        y_pred = torch.tensor([[2.0], [-2.0]])
        y = torch.zeros(2)
        mae.update((y_pred, y))

        # Re-check ``_device`` as well as the accumulator after the update,
        # matching the checks made right after construction.
        for dev in [mae._device, mae._sum_of_absolute_errors.device]:
            assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"