import torch

import ignite.distributed as idist
from ignite.metrics import MeanAbsoluteError


def test_accumulator_detached():
    # The internal accumulator must not track gradients, even when the
    # predictions passed to update() require grad.
    mae = MeanAbsoluteError()

    y_pred = torch.tensor([[2.0], [-2.0]], requires_grad=True)
    y = torch.zeros(2)
    mae.update((y_pred, y))

    assert not mae._sum_of_absolute_errors.requires_grad
def test_compute():
    mae = MeanAbsoluteError()

    y_pred = torch.tensor([[2.0], [-2.0]])
    y = torch.zeros(2)
    mae.update((y_pred, y))
    # (|2 - 0| + |-2 - 0|) / 2 = 2.0
    assert mae.compute() == 2.0

    mae.reset()
    y_pred = torch.tensor([[3.0], [-3.0]])
    y = torch.zeros(2)
    mae.update((y_pred, y))
    # After reset, only the new batch counts: (3 + 3) / 2 = 3.0
    assert mae.compute() == 3.0
def _test_distrib_accumulator_device(device):
    metric_devices = [torch.device("cpu")]
    if device.type != "xla":
        metric_devices.append(idist.device())

    for metric_device in metric_devices:
        mae = MeanAbsoluteError(device=metric_device)

        # The accumulator should live on the requested device both right after
        # construction and after the first update().
        for dev in (mae._device, mae._sum_of_absolute_errors.device):
            assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

        y_pred = torch.tensor([[2.0], [-2.0]])
        y = torch.zeros(2)
        mae.update((y_pred, y))

        for dev in (mae._device, mae._sum_of_absolute_errors.device):
            assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
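# Hypothetical usage sketch (not from the original suite): the helper above is
# normally called from backend-specific distributed tests with
# device = idist.device(); as a stand-in, this smoke test simply drives it with
# the plain CPU device. The test name is ours, and it assumes a CPU-only
# environment.
def test_accumulator_device_cpu_smoke():
    _test_distrib_accumulator_device(torch.device("cpu"))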