def test_mean_absolute_error_computation(self, device: str):
    """MAE accumulates total absolute error and element count across calls until reset()."""
    mae = MeanAbsoluteError()
    preds = torch.tensor(
        [[1.0, 1.5, 1.0], [2.0, 3.0, 3.5], [4.0, 5.0, 5.5], [6.0, 7.0, 7.5]],
        device=device,
    )
    gold = torch.tensor(
        [[0.0, 1.0, 0.0], [2.0, 2.0, 0.0], [4.0, 5.0, 0.0], [7.0, 7.0, 0.0]],
        device=device,
    )

    # First call: total absolute error 21.0 over 12 elements.
    mae(preds, gold)
    assert mae.get_metric() == 21.0 / 12.0

    # Masked call: only the first two columns count (error 3.5 over 8 elements),
    # accumulated on top of the previous call's totals.
    keep = torch.tensor(
        [[True, True, False], [True, True, False], [True, True, False], [True, True, False]],
        device=device,
    )
    mae(preds, gold, keep)
    assert mae.get_metric() == (21.0 + 3.5) / (12.0 + 8.0)

    # Unmasked call against permuted targets adds error 32.0 over 12 elements.
    permuted_gold = torch.tensor(
        [[2.0, 2.0, 0.0], [0.0, 1.0, 0.0], [7.0, 7.0, 0.0], [4.0, 5.0, 0.0]],
        device=device,
    )
    mae(preds, permuted_gold)
    assert mae.get_metric() == (21.0 + 3.5 + 32.0) / (12.0 + 8.0 + 12.0)

    # reset() clears the running totals; only the post-reset call contributes.
    mae.reset()
    mae(preds, permuted_gold)
    assert mae.get_metric() == 32.0 / 12.0
def test_mean_absolute_error_computation(self):
    """MAE accumulates total absolute error and element count across calls until reset().

    Kept consistent with the device-parameterized variant of this test: uses the
    dtype-inferring ``torch.tensor`` factory rather than the legacy ``torch.Tensor``
    constructor, and a boolean mask rather than a 0/1 float mask. All inputs and
    expected values are unchanged.
    """
    mae = MeanAbsoluteError()
    predictions = torch.tensor(
        [[1.0, 1.5, 1.0], [2.0, 3.0, 3.5], [4.0, 5.0, 5.5], [6.0, 7.0, 7.5]]
    )
    targets = torch.tensor(
        [[0.0, 1.0, 0.0], [2.0, 2.0, 0.0], [4.0, 5.0, 0.0], [7.0, 7.0, 0.0]]
    )

    # First call: total absolute error 21.0 over 12 elements.
    mae(predictions, targets)
    assert mae.get_metric() == 21.0 / 12.0

    # Masked call: only the first two columns count (error 3.5 over 8 elements),
    # accumulated on top of the previous call's totals.
    mask = torch.tensor(
        [[True, True, False], [True, True, False], [True, True, False], [True, True, False]]
    )
    mae(predictions, targets, mask)
    assert mae.get_metric() == (21.0 + 3.5) / (12.0 + 8.0)

    # Unmasked call against permuted targets adds error 32.0 over 12 elements.
    new_targets = torch.tensor(
        [[2.0, 2.0, 0.0], [0.0, 1.0, 0.0], [7.0, 7.0, 0.0], [4.0, 5.0, 0.0]]
    )
    mae(predictions, new_targets)
    assert mae.get_metric() == (21.0 + 3.5 + 32.0) / (12.0 + 8.0 + 12.0)

    # reset() clears the running totals; only the post-reset call contributes.
    mae.reset()
    mae(predictions, new_targets)
    assert mae.get_metric() == 32.0 / 12.0
def multiple_runs(
    global_rank: int,
    world_size: int,
    gpu_id: Union[int, torch.device],
    metric: MeanAbsoluteError,
    metric_kwargs: Dict[str, List[Any]],
    desired_values: Dict[str, Any],
    exact: Union[bool, Tuple[float, float]] = True,
):
    """Per-process worker for a distributed metric test.

    Feeds ``metric`` 200 times with this rank's slice of the arguments, then
    checks the aggregated "mae" value against the expected one.
    """
    # Each entry of metric_kwargs holds one value per rank; pick ours.
    kwargs = {name: per_rank[global_rank] for name, per_rank in metric_kwargs.items()}

    # Repeated identical updates must not drift the aggregated value.
    for _ in range(200):
        metric(**kwargs)

    assert metric.get_metric()["mae"] == desired_values["mae"]