Beispiel #1
0
        def loss(image: Tensor, mask: ProbabilityMeasure):
            """Return the weighted sum of three ``Samples_Loss`` terms for ``mask``.

            A linear map and a shift are obtained from
            ``LinearTransformOT.forward(mask, barycenter)`` without gradient
            tracking and then used to build detached targets:

            * shift term: ``mask`` against ``mask.detach() + shift``;
            * linear term: centered ``mask`` against its detached, mapped copy;
            * barycenter term: centered, mapped ``mask`` against the detached,
              centered ``barycenter``.

            ``barycenter`` and the weights ``ca``, ``cw`` and ``ct`` are taken
            from the enclosing scope. ``image`` is accepted but not used here.
            """
            # The transform only serves to build targets, so it is estimated
            # with autograd disabled.
            with torch.no_grad():
                linear_map, shift = LinearTransformOT.forward(mask, barycenter)

            # Shift term.
            shift_criterion = Samples_Loss(scaling=0.8, border=0.0001)
            shift_term = shift_criterion(mask, mask.detach() + shift)

            # Linear term: only the first argument keeps its graph.
            linear_criterion = Samples_Loss(scaling=0.8, border=0.0001)
            centered_mask = mask.centered()
            mapped_target = mask.centered().multiply(linear_map).detach()
            linear_term = linear_criterion(centered_mask, mapped_target)

            # Barycenter term: the mapped mask keeps its graph, the barycenter
            # side is detached.
            barycenter_criterion = Samples_Loss(scaling=0.85, border=0.00001)
            mapped_mask = mask.centered().multiply(linear_map)
            barycenter_target = barycenter.centered().detach()
            barycenter_term = barycenter_criterion(mapped_mask, barycenter_target)

            return linear_term * ca + barycenter_term * cw + shift_term * ct
Beispiel #2
0
    def forward(pred: ProbabilityMeasure, targets: ProbabilityMeasure):
        """Estimate a linear map and a shift that carry ``pred`` onto ``targets``.

        A coupling matrix ``P`` between the centered coordinates of both
        measures is computed by ``compute_ot_matrix_par`` (presumably an
        optimal-transport plan -- confirm against that helper). The linear
        map is then obtained in closed form as::

            A = (Xs^T diag(a) Xs)^-1 Xs^T P Xt

        where ``Xs`` / ``Xt`` are the centered coordinates of ``pred`` /
        ``targets`` and ``a`` is ``pred.probability`` shifted by ``1e-8`` and
        renormalised along dim 1.

        Args:
            pred: source measure; its ``coord`` must be a batched 3-D tensor
                (it is used with ``transpose(1, 2)`` and ``bmm``).
            targets: target measure with the same layout.

        Returns:
            A tuple ``(A, T)``: ``A`` cast to the type of ``pred.coord`` and
            ``T = targets.mean() - pred.mean()``, detached from the graph.
        """
        with torch.no_grad():
            # The solver consumes NumPy arrays, so the coordinates are copied
            # to host memory first.
            P = compute_ot_matrix_par(pred.centered().coord.cpu().numpy(),
                                      targets.centered().coord.cpu().numpy())
            # Put the plan on the same device and dtype as the inputs.  An
            # unconditional ``.cuda()`` would break CPU runs and inputs that
            # live on a non-default GPU.
            P = torch.from_numpy(P).to(device=pred.coord.device,
                                       dtype=pred.coord.dtype)

        # Computed outside the no_grad block so these operations are recorded
        # by autograd when the inputs require gradients.
        xs = pred.centered().coord
        xsT = xs.transpose(1, 2)
        xt = targets.centered().coord

        # Per-point weights: the epsilon keeps every weight strictly positive
        # before renormalising along dim 1.
        a: Tensor = pred.probability + 1e-8
        a = a / a.sum(dim=1, keepdim=True)
        a = a.reshape(a.shape[0], -1, 1)

        # Closed-form solution of the weighted normal equations.
        A = torch.inverse(xsT.bmm(a * xs)).bmm(xsT.bmm(P.bmm(xt)))

        T = targets.mean() - pred.mean()

        return A.type_as(pred.coord), T.detach()
Beispiel #3
0
    def forward(pred: ProbabilityMeasure,
                targets: ProbabilityMeasure,
                iters: int = 200,
                lambd: float = 0.002):
        """Estimate a linear map and a shift that carry ``pred`` onto ``targets``.

        A coupling matrix ``P`` between the centered measures is computed by
        ``SOT(iters, lambd).forward`` without gradient tracking. The linear
        map is then obtained in closed form as::

            A = (Xs^T diag(a) Xs)^-1 Xs^T P Xt

        where ``Xs`` / ``Xt`` are the centered coordinates of ``pred`` /
        ``targets`` and ``a`` is ``pred.probability`` shifted by ``1e-8`` and
        renormalised along dim 1.

        Args:
            pred: source measure; its ``coord`` must be a batched 3-D tensor
                (it is used with ``transpose(1, 2)`` and ``bmm``).
            targets: target measure with the same layout.
            iters: first argument passed to ``SOT`` (presumably the number of
                solver iterations -- confirm against ``SOT``).
            lambd: second argument passed to ``SOT`` (presumably a
                regularisation strength -- confirm against ``SOT``). The
                default ``0.002`` is the value that used to be hard-coded.

        Returns:
            A tuple ``(A, T)``: ``A`` cast to the type of ``pred.coord`` and
            ``T = targets.mean() - pred.mean()``, detached from the graph.
        """
        # The coupling is treated as a constant: no graph is recorded for it.
        with torch.no_grad():
            P = SOT(iters, lambd).forward(pred.centered(), targets.centered())

        # Computed outside the no_grad block so these operations are recorded
        # by autograd when the inputs require gradients.
        xs = pred.centered().coord
        xsT = xs.transpose(1, 2)
        xt = targets.centered().coord

        # Per-point weights: the epsilon keeps every weight strictly positive
        # before renormalising along dim 1.
        a = pred.probability + 1e-8
        a = a / a.sum(dim=1, keepdim=True)
        a = a.reshape(a.shape[0], -1, 1)

        # Closed-form solution of the weighted normal equations.
        A = torch.inverse(xsT.bmm(a * xs)).bmm(xsT.bmm(P.bmm(xt)))

        T = targets.mean() - pred.mean()

        return A.type_as(pred.coord), T.detach()