Beispiel #1
0
    def forward(pred: ProbabilityMeasure, targets: ProbabilityMeasure):
        """Estimate a linear map ``A`` and translation ``T`` from ``pred`` to ``targets``.

        A transport plan ``P`` between the centered point clouds is obtained
        from ``compute_ot_matrix_par`` on NumPy copies of the coordinates,
        with gradients disabled. The linear part is then the closed form

            A = inverse(xs^T (a * xs)) @ (xs^T @ P @ xt)

        where ``xs``/``xt`` are the centered coordinates of ``pred``/``targets``
        and ``a`` is ``pred.probability`` shifted by 1e-8 and renormalized
        along dim 1. ``P`` is computed under ``torch.no_grad()``, so no
        gradient flows through it.

        Args:
            pred: Source measure. ``pred.coord`` is used with ``bmm`` and
                ``transpose(1, 2)``, so it must be a 3-D batch, presumably
                (batch, n_points, dim) (TODO confirm).
            targets: Target measure whose coordinates share the batch and
                feature dimensions of ``pred.coord``.

        Returns:
            A tuple ``(A, T)``. ``A`` is the batched square matrix above, cast
            to ``pred.coord``'s type. ``T = targets.mean() - pred.mean()``,
            detached from the autograd graph.
        """
        with torch.no_grad():
            P = compute_ot_matrix_par(pred.centered().coord.cpu().numpy(),
                                      targets.centered().coord.cpu().numpy())
            # Place the plan on pred's own device/dtype instead of calling
            # .cuda(): .cuda() fails for CPU-only inputs and always targets the
            # current GPU, not necessarily the one pred.coord lives on.
            P = torch.from_numpy(P).to(device=pred.coord.device,
                                       dtype=pred.coord.dtype)

        xs = pred.centered().coord
        xsT = xs.transpose(1, 2)
        xt = targets.centered().coord

        # Shift by 1e-8 to keep weights away from zero, renormalize along
        # dim 1, then add a trailing axis so the weights broadcast against xs
        # in ``a * xs``.
        a: Tensor = pred.probability + 1e-8
        a = a / a.sum(dim=1, keepdim=True)
        a = a.reshape(a.shape[0], -1, 1)

        A = torch.inverse(xsT.bmm(a * xs)).bmm(xsT.bmm(P.bmm(xt)))

        T = targets.mean() - pred.mean()

        return A.type_as(pred.coord), T.detach()
Beispiel #2
0
    def forward(pred: ProbabilityMeasure,
                targets: ProbabilityMeasure,
                iters: int = 200,
                lambd: float = 0.002):
        """Estimate a linear map ``A`` and translation ``T`` from ``pred`` to ``targets``.

        A transport plan ``P`` between the centered measures is computed by
        ``SOT(iters, lambd)`` with gradients disabled. The linear part is then
        the closed form

            A = inverse(xs^T (a * xs)) @ (xs^T @ P @ xt)

        where ``xs``/``xt`` are the centered coordinates of ``pred``/``targets``
        and ``a`` is ``pred.probability`` shifted by 1e-8 and renormalized
        along dim 1. ``P`` is computed under ``torch.no_grad()``, so no
        gradient flows through it.

        Args:
            pred: Source measure. ``pred.coord`` is used with ``bmm`` and
                ``transpose(1, 2)``, so it must be a 3-D batch, presumably
                (batch, n_points, dim) (TODO confirm).
            targets: Target measure whose coordinates share the batch and
                feature dimensions of ``pred.coord``.
            iters: Iteration count passed as the first argument to ``SOT``.
            lambd: Second argument passed to ``SOT``, presumably its
                regularization strength (confirm against ``SOT``).
                Defaults to 0.002.

        Returns:
            A tuple ``(A, T)``. ``A`` is the batched square matrix above, cast
            to ``pred.coord``'s type. ``T = targets.mean() - pred.mean()``,
            detached from the autograd graph.
        """
        with torch.no_grad():
            P = SOT(iters, lambd).forward(pred.centered(), targets.centered())

        xs = pred.centered().coord
        xsT = xs.transpose(1, 2)
        xt = targets.centered().coord

        # Shift by 1e-8 to keep weights away from zero, renormalize along
        # dim 1, then add a trailing axis so the weights broadcast against xs
        # in ``a * xs``.
        a = pred.probability + 1e-8
        a = a / a.sum(dim=1, keepdim=True)
        a = a.reshape(a.shape[0], -1, 1)

        A = torch.inverse(xsT.bmm(a * xs)).bmm(xsT.bmm(P.bmm(xt)))

        T = targets.mean() - pred.mean()

        return A.type_as(pred.coord), T.detach()