def rmse(gt, approx):
    r"""
    Computes the RMSE between two tensors (torch or tntorch).

    The result is the Euclidean distance between ``gt`` and ``approx`` divided
    by the square root of the number of entries of ``gt``.

    :param gt: a torch or tntorch tensor
    :param approx: a torch or tntorch tensor

    :return: a scalar :math:`\ge 0`
    """

    gt, approx = _process(gt, approx)

    # ``torch.sqrt`` only accepts Tensors, so it fails on the plain Python int
    # that ``torch.Tensor.numel()`` returns. Converting to float (which also
    # works if a tntorch ``numel()`` yields a numpy/torch scalar -- TODO confirm
    # its return type) and using ``np.sqrt`` makes both branches agree.
    sqrt_n = np.sqrt(float(gt.numel()))

    if isinstance(gt, torch.Tensor) and isinstance(approx, torch.Tensor):
        return torch.dist(gt, approx) / sqrt_n
    return tn.dist(gt, approx) / sqrt_n
def r_squared(gt, approx):
    """
    Computes the coefficient of determination :math:`R^2` between two tensors
    (torch or tntorch).

    It is one minus the ratio of the squared distance between ``gt`` and
    ``approx`` to the squared distance between ``gt`` and its mean. Since
    both quantities are non-negative, the score never exceeds 1.

    :param gt: a torch or tntorch tensor
    :param approx: a torch or tntorch tensor

    :return: a scalar <= 1
    """

    gt, approx = _process(gt, approx)
    dense_pair = isinstance(gt, torch.Tensor) and isinstance(approx, torch.Tensor)

    if not dense_pair:
        # At least one operand is a tntorch tensor: use tntorch's distance and
        # squared-norm routines for the residual and total terms.
        residual = tn.dist(gt, approx)**2
        total = tn.normsq(gt-tn.mean(gt))
        return 1 - residual / total

    # Both operands are plain torch tensors; the 0-d mean broadcasts in torch.dist.
    residual = torch.dist(gt, approx)**2
    total = torch.dist(gt, torch.mean(gt))**2
    return 1 - residual / total
def __eq__(self, other):
    """
    Tests approximate equality with another tensor.

    Two tensors compare equal when ``tn.dist(self, other)`` is at most an
    absolute tolerance of 1e-14.

    :param other: the tensor to compare against

    :return: the result of the ``<=`` comparison on the distance (presumably a
        torch boolean scalar rather than a Python bool -- TODO confirm)
    """

    # NOTE(review): defining __eq__ sets __hash__ to None unless the class also
    # defines __hash__; the tolerance is absolute, not relative to the norms.
    tolerance = 1e-14
    distance = tn.dist(self, other)
    return distance <= tolerance