Beispiel #1
0
def rmse(gt, approx):
    r"""
    Computes the root-mean-square error (RMSE) between two tensors (torch or tntorch).

    The RMSE is the Euclidean distance between ``gt`` and ``approx`` divided by
    the square root of the number of elements of ``gt``.

    :param gt: a torch or tntorch tensor (the reference)
    :param approx: a torch or tntorch tensor (the approximation)

    :return: a scalar :math:`\ge 0`
    """

    gt, approx = _process(gt, approx)
    if isinstance(gt, torch.Tensor) and isinstance(approx, torch.Tensor):
        distance = torch.dist(gt, approx)
    else:
        distance = tn.dist(gt, approx)
    # ``numel()`` of a torch.Tensor is a plain int, which torch.sqrt rejects.
    # Going through float() accepts either an int or a 0-dim tensor (whichever
    # the tntorch side returns — TODO confirm), so both branches share one
    # normalisation instead of the dense branch using np.sqrt and the other
    # torch.sqrt.
    return distance / np.sqrt(float(gt.numel()))
Beispiel #2
0
def r_squared(gt, approx):
    """
    Computes the coefficient of determination :math:`R^2` of ``approx`` with respect to ``gt``.

    The score is ``1 - SS_res / SS_tot``. ``SS_res`` is the squared distance
    between ``gt`` and ``approx``. ``SS_tot`` is the squared distance between
    ``gt`` and its mean.

    :param gt: a torch or tntorch tensor (the reference)
    :param approx: a torch or tntorch tensor (the approximation)

    :return: a scalar <= 1
    """

    gt, approx = _process(gt, approx)
    both_dense = isinstance(gt, torch.Tensor) and isinstance(approx, torch.Tensor)
    if not both_dense:
        # tntorch path: residual via tn.dist, total variation via tn.normsq.
        residual_sq = tn.dist(gt, approx)**2
        total_sq = tn.normsq(gt-tn.mean(gt))
        return 1 - residual_sq / total_sq
    # Dense torch path: torch.dist broadcasts the scalar mean against gt.
    residual_sq = torch.dist(gt, approx)**2
    total_sq = torch.dist(gt, torch.mean(gt))**2
    return 1 - residual_sq / total_sq
Beispiel #3
0
 def __eq__(self, other):
     """
     Test whether this tensor and ``other`` are numerically equal.

     The tensors compare equal when their distance, as computed by
     ``tn.dist``, does not exceed a fixed absolute tolerance of 1e-14.

     :param other: the tensor to compare against
     :return: the result of the comparison ``tn.dist(self, other) <= 1e-14``
     """
     # NOTE(review): the tolerance is absolute, not relative to the tensors'
     # norms. Also, defining __eq__ without __hash__ makes instances
     # unhashable. Confirm both are intended.
     tolerance = 1e-14
     distance = tn.dist(self, other)
     return distance <= tolerance