def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if len(self._predictions) < 1 or len(self._targets) < 1:
        raise NotComputableError("PrecisionRecallCurve must have at least one example before it can be computed.")

    _prediction_tensor = torch.cat(self._predictions, dim=0)
    _target_tensor = torch.cat(self._targets, dim=0)

    ws = idist.get_world_size()
    if ws > 1 and not self._is_reduced:
        # All gather across all processes
        _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
        _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
    self._is_reduced = True

    if idist.get_rank() == 0:
        # Run compute_fn on zero rank only
        precision, recall, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
        precision = torch.tensor(precision)
        recall = torch.tensor(recall)
        # thresholds can have negative strides, not compatible with torch tensors
        # https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
        thresholds = torch.tensor(thresholds.copy())
    else:
        precision, recall, thresholds = None, None, None

    if ws > 1:
        # broadcast result to all processes
        precision = idist.broadcast(precision, src=0, safe_mode=True)
        recall = idist.broadcast(recall, src=0, safe_mode=True)
        thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)

    return precision, recall, thresholds
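# A minimal sketch of the gather -> compute-on-rank-0 -> broadcast pattern used above,
# outside any metric class. safe_mode=True lets non-source ranks pass None to
# idist.broadcast and still receive rank 0's value. Names here are illustrative only.
import torch
import ignite.distributed as idist

def gathered_compute(local_values: torch.Tensor) -> torch.Tensor:
    values = idist.all_gather(local_values)  # concatenated along dim 0 on every rank
    result = None
    if idist.get_rank() == 0:
        result = values.mean().reshape(1)  # stand-in for an expensive compute_fn
    if idist.get_world_size() > 1:
        result = idist.broadcast(result, src=0, safe_mode=True)
    return result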
def compute(self) -> Union[torch.Tensor, float]:
    is_scalar = not isinstance(self._positives, torch.Tensor) or self._positives.ndim == 0
    if is_scalar and self._positives == 0:
        raise NotComputableError(
            f"{self.__class__.__name__} must have at least one example before it can be computed."
        )

    if not self._is_reduced:
        if not (self._type == "multilabel" and not self._average):
            self._true_positives = idist.all_reduce(self._true_positives)  # type: ignore[assignment]
            self._positives = idist.all_reduce(self._positives)  # type: ignore[assignment]
        else:
            self._true_positives = cast(torch.Tensor, idist.all_gather(self._true_positives))
            self._positives = cast(torch.Tensor, idist.all_gather(self._positives))
        self._is_reduced = True  # type: bool

    result = self._true_positives / (self._positives + self.eps)

    if self._average:
        return cast(torch.Tensor, result).mean().item()
    else:
        return result
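# A hedged illustration of why the branch above picks a different collective:
# summable per-class counts are combined with all_reduce (default op is SUM),
# while per-sample scores (multilabel with average=False) must be concatenated
# with all_gather so no sample is lost. Shapes below are made up for the demo.
import torch
import ignite.distributed as idist

counts = torch.tensor([3.0, 5.0, 2.0])      # per-class counts on this worker
counts = idist.all_reduce(counts)           # same shape, values summed across workers

per_sample = torch.rand(8, 4)               # one row per local sample
per_sample = idist.all_gather(per_sample)   # (8 * world_size, 4): rows concatenated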
def _test(y_pred, y, batch_size, metric_device):
    metric_device = torch.device(metric_device)
    ap = AveragePrecision(device=metric_device)

    torch.manual_seed(10 + rank)

    ap.reset()
    if batch_size > 1:
        # slicing past the end of the tensor just yields a shorter (or empty) last batch
        n_iters = y.shape[0] // batch_size + 1
        for i in range(n_iters):
            idx = i * batch_size
            ap.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        ap.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y = y.cpu().numpy()
    np_y_pred = y_pred.cpu().numpy()

    res = ap.compute()
    assert isinstance(res, float)
    assert average_precision_score(np_y, np_y_pred) == pytest.approx(res)
def _test(metric_device):
    criterion = nn.NLLLoss().to(device)
    loss = Loss(criterion, device=metric_device)

    y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], device=device).log()
    y = torch.tensor([2, 2], device=device).long()
    loss.update((y_pred, y))
    n = loss._num_examples
    assert n == len(y)
    res = loss.compute()
    assert n * idist.get_world_size() == loss._num_examples

    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)
    true_loss_value = criterion(y_pred, y)
    assert_almost_equal(res, true_loss_value.item())

    loss.reset()
    y_pred = torch.tensor([[0.1, 0.3, 0.6], [0.6, 0.2, 0.2], [0.2, 0.7, 0.1]], device=device).log()
    y = torch.tensor([2, 0, 2], device=device).long()
    loss.update((y_pred, y))
    n = loss._num_examples
    res = loss.compute()
    assert n * idist.get_world_size() == loss._num_examples

    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)
    true_loss_value = criterion(y_pred, y)
    if tol is None:
        assert_almost_equal(res, true_loss_value.item())
    else:
        assert pytest.approx(res, rel=tol) == true_loss_value.item()
def _test(y_pred, y, batch_size, metric_device):
    metric_device = torch.device(metric_device)
    roc_auc = ROC_AUC(device=metric_device)

    torch.manual_seed(10 + rank)

    roc_auc.reset()
    if batch_size > 1:
        n_iters = y.shape[0] // batch_size + 1
        for i in range(n_iters):
            idx = i * batch_size
            roc_auc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        roc_auc.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y = y.cpu().numpy()
    np_y_pred = y_pred.cpu().numpy()

    res = roc_auc.compute()
    assert isinstance(res, float)
    assert roc_auc_score(np_y, np_y_pred) == pytest.approx(res)
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = MedianAbsoluteError(device=metric_device)

    torch.manual_seed(10 + rank)

    size = 105
    y_pred = torch.randint(1, 10, size=(size, 1), dtype=torch.double, device=device)
    y = torch.randint(1, 10, size=(size, 1), dtype=torch.double, device=device)
    m.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y_pred = y_pred.cpu().numpy().ravel()
    np_y = y.cpu().numpy().ravel()

    res = m.compute()
    np_median_absolute_error = np.median(np.abs(np_y - np_y_pred))

    assert np_median_absolute_error == pytest.approx(res)
def _test(y_pred, y, batch_size, metric_device):
    metric_device = torch.device(metric_device)
    prc = PrecisionRecallCurve(device=metric_device)

    torch.manual_seed(10 + rank)

    prc.reset()
    if batch_size > 1:
        n_iters = y.shape[0] // batch_size + 1
        for i in range(n_iters):
            idx = i * batch_size
            prc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        prc.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y = y.cpu().numpy()
    np_y_pred = y_pred.cpu().numpy()

    res = prc.compute()
    assert isinstance(res, Tuple)

    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
    assert sk_precision == pytest.approx(res[0].cpu().numpy())
    assert sk_recall == pytest.approx(res[1].cpu().numpy())
    assert sk_thresholds == pytest.approx(res[2].cpu().numpy())
def compute(self) -> float:
    if len(self._predictions) < 1 or len(self._targets) < 1:
        raise NotComputableError("EpochMetric must have at least one example before it can be computed.")

    _prediction_tensor = torch.cat(self._predictions, dim=0)
    _target_tensor = torch.cat(self._targets, dim=0)

    ws = idist.get_world_size()
    if ws > 1 and not self._is_reduced:
        # All gather across all processes
        _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
        _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
    self._is_reduced = True

    result = 0.0
    if idist.get_rank() == 0:
        # Run compute_fn on zero rank only
        result = self.compute_fn(_prediction_tensor, _target_tensor)

    if ws > 1:
        # broadcast result to all processes
        result = cast(float, idist.broadcast(result, src=0))

    return result
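# A short usage sketch for the EpochMetric above: any callable taking the two
# gathered tensors can serve as compute_fn. The sigmoid + sklearn combination
# here is illustrative, not the only option.
import torch
from ignite.metrics import EpochMetric
from sklearn.metrics import roc_auc_score

def sigmoid_roc_auc(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float:
    return roc_auc_score(y_targets.numpy(), torch.sigmoid(y_preds).numpy())

metric = EpochMetric(sigmoid_roc_auc, check_compute_fn=False)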
def _test(metric_device, y_test_1, y_test_2):
    criterion = nn.NLLLoss().to(device)
    loss = Loss(criterion, device=metric_device)

    y_pred, y, _ = y_test_1
    loss.update((y_pred, y))
    n = loss._num_examples
    assert n == len(y)
    res = loss.compute()
    assert n * idist.get_world_size() == loss._num_examples

    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)
    true_loss_value = criterion(y_pred, y)
    assert_almost_equal(res, true_loss_value.item())

    loss.reset()
    y_pred, y, _ = y_test_2
    loss.update((y_pred, y))
    n = loss._num_examples
    res = loss.compute()
    assert n * idist.get_world_size() == loss._num_examples

    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)
    true_loss_value = criterion(y_pred, y)
    if tol is None:
        assert_almost_equal(res, true_loss_value.item())
    else:
        assert pytest.approx(res, rel=tol) == true_loss_value.item()
def _test(y_pred, y, n_iters, metric_device):
    metric_device = torch.device(metric_device)
    ck = CohenKappa(device=metric_device)

    torch.manual_seed(10 + rank)

    ck.reset()
    if n_iters > 1:
        batch_size = y.shape[0] // n_iters + 1
        for i in range(n_iters):
            idx = i * batch_size
            ck.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        ck.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y = y.cpu().numpy()
    np_y_pred = y_pred.cpu().numpy()

    res = ck.compute()
    assert isinstance(res, float)
    assert cohen_kappa_score(np_y, np_y_pred) == pytest.approx(res)
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = MedianAbsolutePercentageError(device=metric_device)

    torch.manual_seed(10 + rank)

    size = 105
    y_pred = torch.randint(1, 10, size=(size, 1), dtype=torch.double, device=device)
    y = torch.randint(1, 10, size=(size, 1), dtype=torch.double, device=device)
    m.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y_pred = y_pred.cpu().numpy().ravel()
    np_y = y.cpu().numpy().ravel()

    res = m.compute()

    e = np.abs(np_y - np_y_pred) / np.abs(np_y)

    # numpy.median() and torch.median() are inconsistent when the array/tensor length
    # is even (torch returns the lower of the two middle values). Prepending a duplicate
    # of the first element makes the length odd as a workaround.
    # issue: https://github.com/pytorch/pytorch/issues/1837
    if np_y_pred.shape[0] % 2 == 0:
        e_prepend = np.insert(e, 0, e[0], axis=0)
        np_res_prepend = 100.0 * np.median(e_prepend)
        assert pytest.approx(res) == np_res_prepend
    else:
        np_res = 100.0 * np.median(e)
        assert pytest.approx(res) == np_res
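# A two-line demonstration of the even-length median mismatch the test above works
# around: numpy averages the two middle values, torch returns the lower one.
import numpy as np
import torch

vals = [1.0, 2.0, 3.0, 4.0]
print(np.median(np.array(vals)))         # 2.5
print(torch.median(torch.tensor(vals)))  # tensor(2.)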
def _test(metric_device):
    num_classes = 3
    cm = MultiLabelConfusionMatrix(num_classes=num_classes, device=metric_device)

    y_true, y_pred = get_y_true_y_pred()

    # Compute confusion matrix with sklearn
    sklearn_CM = multilabel_confusion_matrix(
        y_true.transpose((0, 2, 3, 1)).reshape(-1, 3), y_pred.transpose((0, 2, 3, 1)).reshape(-1, 3)
    )

    # Update metric
    output = (torch.tensor(y_pred).to(device), torch.tensor(y_true).to(device))
    cm.update(output)

    ignite_CM = cm.compute().cpu().numpy()

    assert np.all(ignite_CM == sklearn_CM)

    # Another test on batch of 2 images
    num_classes = 3
    cm = MultiLabelConfusionMatrix(num_classes=num_classes, device=metric_device)

    # Create a batch of two images:
    th_y_true1 = torch.tensor(y_true)
    th_y_true2 = torch.tensor(y_true.transpose(0, 1, 3, 2))
    th_y_true = torch.cat([th_y_true1, th_y_true2], dim=0)
    th_y_true = th_y_true.to(device)

    th_y_pred1 = torch.tensor(y_pred)
    th_y_pred2 = torch.tensor(y_pred.transpose(0, 1, 3, 2))
    th_y_pred = torch.cat([th_y_pred1, th_y_pred2], dim=0)
    th_y_pred = th_y_pred.to(device)

    # Update metric & compute
    output = (th_y_pred, th_y_true)
    cm.update(output)
    ignite_CM = cm.compute().cpu().numpy()

    # Compute confusion matrix with sklearn
    th_y_true = idist.all_gather(th_y_true)
    th_y_pred = idist.all_gather(th_y_pred)

    np_y_true = th_y_true.cpu().numpy().transpose((0, 2, 3, 1)).reshape(-1, 3)
    np_y_pred = th_y_pred.cpu().numpy().transpose((0, 2, 3, 1)).reshape(-1, 3)
    sklearn_CM = multilabel_confusion_matrix(np_y_true, np_y_pred)

    assert np.all(ignite_CM == sklearn_CM)
def _test_distrib_all_gather(device):
    res = torch.tensor(idist.all_gather(10), device=device)
    true_res = torch.tensor([10] * idist.get_world_size(), device=device)
    assert (res == true_res).all()

    t = torch.tensor(idist.get_rank(), device=device)
    res = idist.all_gather(t)
    true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
    assert (res == true_res).all()

    x = "test-test"
    if idist.get_rank() == 0:
        x = "abc"
    res = idist.all_gather(x)
    true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
    assert res == true_res

    base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
    x = base_x
    if idist.get_rank() == 0:
        x = "abc"
    res = idist.all_gather(x)
    true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
    assert res == true_res

    t = torch.arange(100, device=device).reshape(4, 25) * (idist.get_rank() + 1)
    in_dtype = t.dtype
    res = idist.all_gather(t)
    assert res.shape == (idist.get_world_size() * 4, 25)
    assert res.dtype == in_dtype
    true_res = torch.zeros(idist.get_world_size() * 4, 25, device=device)
    for i in range(idist.get_world_size()):
        true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1)
    assert (res == true_res).all()

    if idist.get_world_size() > 1:
        with pytest.raises(TypeError, match=r"Unhandled input type"):
            idist.all_reduce([0, 1, 2])
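# A self-contained, hedged way to exercise idist.all_gather outside the test harness:
# idist.Parallel spawns the workers and sets up the "gloo" backend. The worker
# function and the process count below are arbitrary choices for the demo.
import torch
import ignite.distributed as idist

def _demo(local_rank):
    t = torch.tensor([idist.get_rank()], dtype=torch.float32)
    gathered = idist.all_gather(t)  # tensor of shape (world_size,) on every rank
    print(f"rank {idist.get_rank()}: {gathered}")

if __name__ == "__main__":
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(_demo)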
def __init__(self, logger: TrainsLogger = None, output_uri: str = None, dirname: str = None, *args, **kwargs):
    self._setup_check_trains(logger, output_uri)

    if not dirname:
        dirname = ""
        if idist.get_rank() == 0:
            dirname = tempfile.mkdtemp(
                prefix="ignite_checkpoints_{}".format(datetime.now().strftime("%Y_%m_%d_%H_%M_%S_"))
            )
        if idist.get_world_size() > 1:
            dirname = idist.all_gather(dirname)[0]

        warnings.warn("TrainsSaver created a temporary checkpoints directory: {}".format(dirname))
        idist.barrier()

    # Let's set non-atomic tmp dir saving behaviour
    if "atomic" not in kwargs:
        kwargs["atomic"] = False

    super(TrainsSaver, self).__init__(dirname=dirname, *args, **kwargs)
def compute(self) -> float:
    if len(self._probs) < 1:
        raise NotComputableError("Inception score must have at least one example before it can be computed.")

    ws = idist.get_world_size()
    _probs_tensor = torch.cat(self._probs, dim=0)

    if ws > 1 and not self._is_reduced:
        _probs_tensor = cast(torch.Tensor, idist.all_gather(_probs_tensor))
        self._is_reduced = True

    result = 0.0
    if idist.get_rank() == 0:
        N = _probs_tensor.shape[0]
        scores = torch.zeros((self.n_splits,))
        for i in range(self.n_splits):
            part = _probs_tensor[i * (N // self.n_splits) : (i + 1) * (N // self.n_splits)]
            kl = part * (torch.log(part) - torch.log(torch.mean(part, dim=0)))
            kl = torch.mean(torch.sum(kl, dim=1))
            scores[i] = torch.exp(kl)
        result = torch.mean(scores).item()

    if ws > 1:
        result = cast(float, idist.broadcast(result, src=0))

    return result
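# For reference, the loop above implements the standard Inception Score over
# n_splits chunks of the gathered class probabilities p(y|x):
#   IS = exp( E_x [ KL( p(y|x) || p(y) ) ] ),  with  p(y) ~= mean_x p(y|x)
# i.e. per split: kl = mean_x sum_y p(y|x) * (log p(y|x) - log p(y)), score = exp(kl),
# and the reported result is the mean of the per-split scores.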
def __init__(
    self,
    logger: Optional[ClearMLLogger] = None,
    output_uri: Optional[str] = None,
    dirname: Optional[str] = None,
    *args: Any,
    **kwargs: Any,
) -> None:
    self._setup_check_clearml(logger, output_uri)

    if not dirname:
        dirname = ""
        if idist.get_rank() == 0:
            dirname = tempfile.mkdtemp(prefix=f"ignite_checkpoints_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S_')}")
        if idist.get_world_size() > 1:
            dirname = idist.all_gather(dirname)[0]  # type: ignore[index, assignment]

        warnings.warn(f"ClearMLSaver created a temporary checkpoints directory: {dirname}")
        idist.barrier()

    # Let's set non-atomic tmp dir saving behaviour
    if "atomic" not in kwargs:
        kwargs["atomic"] = False

    self._checkpoint_slots = defaultdict(list)  # type: DefaultDict[Union[str, Tuple[str, str]], List[Any]]

    super(ClearMLSaver, self).__init__(dirname=dirname, *args, **kwargs)  # type: ignore[misc]
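# The TrainsSaver/ClearMLSaver constructors above share one idiom: rank 0 creates a
# temporary directory, and all_gather on a string returns a list of per-rank values,
# so indexing [0] makes every process adopt rank 0's path. A minimal sketch of just
# that idiom (the prefix below is an arbitrary stand-in):
import tempfile
import ignite.distributed as idist

dirname = ""
if idist.get_rank() == 0:
    dirname = tempfile.mkdtemp(prefix="ignite_checkpoints_")
if idist.get_world_size() > 1:
    dirname = idist.all_gather(dirname)[0]  # everyone takes rank 0's entry
idist.barrier()  # make sure the directory exists before anyone uses it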
def update_cta_rates():
    batch = trainer.state.batch
    x, y = batch["cta_probe_batch"]["image"], batch["cta_probe_batch"]["target"]
    if x.device != device:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

    policies = batch["cta_probe_batch"]["policy"]

    ema_model.eval()
    with torch.no_grad():
        y_pred = ema_model(x)
        y_probas = torch.softmax(y_pred, dim=1)  # (N, C)

        # gathering is only needed (and only makes sense) in the distributed branch
        if not distributed:
            for y_proba, t, policy in zip(y_probas, y, policies):
                error = y_proba
                error[t] -= 1
                error = torch.abs(error).sum()
                cta.update_rates(policy, 1.0 - 0.5 * error.item())
        else:
            error_per_op = []
            for y_proba, t, policy in zip(y_probas, y, policies):
                error = y_proba
                error[t] -= 1
                error = torch.abs(error).sum()
                for k, bins in policy:
                    error_per_op.append(pack_as_tensor(k, bins, error))
            error_per_op = torch.stack(error_per_op)

            # all gather
            tensor_list = idist.all_gather(error_per_op)

            # update cta rates
            for t in tensor_list:
                k, bins, error = unpack_from_tensor(t)
                cta.update_rates([(k, bins)], 1.0 - 0.5 * error)
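# pack_as_tensor / unpack_from_tensor are helpers from the surrounding example; the
# sketch below is only a plausible guess at their contract: encode the op name, the
# bin values, and the scalar error into one fixed-size float tensor so a batch of
# them can be stacked and all_gathered. The op list, size, and layout are assumptions.
import torch

OP_NAMES = ["autocontrast", "blur", "brightness"]  # illustrative only

def pack_as_tensor(k, bins, error, size=10):
    out = torch.zeros(size, dtype=torch.float32, device=error.device)
    out[0] = float(OP_NAMES.index(k))                          # op name as an index
    out[1] = float(len(bins))                                  # number of valid bins
    out[2 : 2 + len(bins)] = torch.as_tensor(bins, dtype=torch.float32)
    out[-1] = error                                            # scalar error, last slot
    return out

def unpack_from_tensor(t):
    k = OP_NAMES[int(t[0].item())]
    n = int(t[1].item())
    bins = t[2 : 2 + n].tolist()
    error = t[-1].item()
    return k, bins, error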
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = CanberraMetric(device=metric_device)

    torch.manual_seed(10 + rank)

    y_pred = torch.randint(0, 10, size=(10,), device=device).float()
    y = torch.randint(0, 10, size=(10,), device=device).float()

    m.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y_pred = y_pred.cpu().numpy()
    np_y = y.cpu().numpy()

    res = m.compute()
    assert canberra.pairwise([np_y_pred, np_y])[0][1] == pytest.approx(res)
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = R2Score(device=metric_device)

    torch.manual_seed(10 + rank)

    y_pred = torch.randint(0, 10, size=(10,), device=device).float()
    y = torch.randint(0, 10, size=(10,), device=device).float()

    m.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y_pred = y_pred.cpu().numpy()
    np_y = y.cpu().numpy()

    res = m.compute()
    assert r2_score(np_y, np_y_pred) == pytest.approx(res, abs=tol)
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = GeometricMeanRelativeAbsoluteError(device=metric_device)

    torch.manual_seed(10 + rank)

    y_pred = torch.rand(size=(100,), device=device)
    y = torch.rand(size=(100,), device=device)

    m.update((y_pred, y))

    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y = y.cpu().numpy()
    np_y_pred = y_pred.cpu().numpy()

    np_gmrae = np.exp(np.log(np.abs(np_y - np_y_pred) / np.abs(np_y - np_y.mean())).mean())
    assert m.compute() == pytest.approx(np_gmrae, rel=1e-4)
def test_distrib_cpu(local_rank, distributed_context_single_node_gloo):
    # Perform some ops, otherwise the next tests fail
    if idist.get_world_size() > 1:
        _test_warning()
    device = "cpu"
    y = torch.rand(10, 12, device=device)
    y = idist.all_gather(y)
    assert isinstance(y, torch.Tensor)
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = MeanError(device=metric_device)

    torch.manual_seed(10 + rank)

    y_pred = torch.rand(size=(100,), device=device)
    y = torch.rand(size=(100,), device=device)

    m.update((y_pred, y))

    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y = y.cpu().numpy()
    np_y_pred = y_pred.cpu().numpy()

    np_sum = (np_y - np_y_pred).sum()
    np_len = len(np_y_pred)
    np_ans = np_sum / np_len

    assert m.compute() == pytest.approx(np_ans)
def compute(self) -> float:
    if len(self._predictions) < 1 or len(self._targets) < 1:
        raise NotComputableError(
            "GeometricMeanRelativeAbsoluteError must have at least one example before it can be computed."
        )

    _prediction_tensor = torch.cat(self._predictions, dim=0)
    _target_tensor = torch.cat(self._targets, dim=0)

    # All gather across all processes
    _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
    _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))

    result = torch.exp(
        torch.log(
            torch.abs(_target_tensor - _prediction_tensor) / torch.abs(_target_tensor - _target_tensor.mean())
        ).mean()
    ).item()

    return result
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = WaveHedgesDistance(device=metric_device)

    torch.manual_seed(10 + rank)

    y_pred = torch.randint(0, 10, size=(10,), device=device).float()
    y = torch.randint(0, 10, size=(10,), device=device).float()

    m.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y_pred = y_pred.cpu().numpy()
    np_y = y.cpu().numpy()

    res = m.compute()
    np_sum = (np.abs(np_y - np_y_pred) / (np.maximum.reduce([np_y_pred, np_y]) + 1e-30)).sum()

    assert np_sum == pytest.approx(res)
def _test(metric_device):
    metric_device = torch.device(metric_device)
    ck_metric = CohenKappa(device=metric_device)

    torch.manual_seed(10 + rank)

    y_pred = torch.randint(0, 2, size=(100, 1), device=device)
    y = torch.randint(0, 2, size=(100, 1), device=device)
    ck_metric.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y_pred = y_pred.cpu().numpy()
    np_y = y.cpu().numpy()

    np_ck = cohen_kappa_score(np_y, np_y_pred)
    res = ck_metric.compute()

    assert res == pytest.approx(np_ck)
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = MaximumAbsoluteError(device=metric_device)

    torch.manual_seed(10 + rank)

    y_pred = torch.randint(0, 10, size=(10,), device=device).float()
    y = torch.randint(0, 10, size=(10,), device=device).float()

    m.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y_pred = y_pred.cpu().numpy()
    np_y = y.cpu().numpy()

    res = m.compute()
    np_max = np.max(np.abs(np_y_pred - np_y))

    assert np_max == pytest.approx(res)
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = FractionalAbsoluteError(device=metric_device)

    torch.manual_seed(10 + rank)

    y_pred = torch.rand(size=(100,), device=device)
    y = torch.rand(size=(100,), device=device)

    m.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y = y.cpu().numpy()
    np_y_pred = y_pred.cpu().numpy()

    np_sum = (2 * np.abs(np_y_pred - np_y) / (np.abs(np_y_pred) + np.abs(np_y))).sum()
    np_len = len(np_y_pred)
    np_ans = np_sum / np_len

    assert m.compute() == pytest.approx(np_ans)
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = MeanNormalizedBias(device=metric_device)

    torch.manual_seed(10 + rank)

    # targets are drawn from [1, 10] so the denominator np_y is never zero
    y_pred = torch.randint(1, 11, size=(10,), device=device).float()
    y = torch.randint(1, 11, size=(10,), device=device).float()

    m.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y_pred = y_pred.cpu().numpy()
    np_y = y.cpu().numpy()

    res = m.compute()
    np_sum = ((np_y - np_y_pred) / np_y).sum()
    np_len = len(np_y_pred)
    np_ans = np_sum / np_len

    assert np_ans == pytest.approx(res)
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = FractionalBias(device=metric_device)

    torch.manual_seed(10 + rank)

    y_pred = torch.randint(0, 10, size=(10,), device=device).float()
    y = torch.randint(0, 10, size=(10,), device=device).float()

    m.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y_pred = y_pred.cpu().numpy()
    np_y = y.cpu().numpy()

    res = m.compute()
    np_sum = (2 * (np_y - np_y_pred) / (np_y_pred + np_y + 1e-30)).sum()
    np_len = len(y_pred)
    np_ans = np_sum / np_len

    assert np_ans == pytest.approx(res, rel=tol)
def _test(metric_device):
    metric_device = torch.device(metric_device)
    m = GeometricMeanAbsoluteError(device=metric_device)

    torch.manual_seed(10 + rank)

    y_pred = torch.randint(0, 10, size=(10,), device=device).float()
    y = torch.randint(0, 10, size=(10,), device=device).float()

    m.update((y_pred, y))

    # gather y_pred, y
    y_pred = idist.all_gather(y_pred)
    y = idist.all_gather(y)

    np_y_pred = y_pred.cpu().numpy()
    np_y = y.cpu().numpy()

    res = m.compute()
    sum_errors = (np.log(np.abs(np_y - np_y_pred))).sum()
    np_len = len(y_pred)
    np_ans = np.exp(sum_errors / np_len)

    assert np_ans == pytest.approx(res)