def loss(image: Tensor, mask: ProbabilityMeasure):
    """Alignment loss between ``mask`` and the module-level ``barycenter``.

    A linear map ``A`` and translation ``T`` are obtained from
    ``LinearTransformOT.forward(mask, barycenter)`` with gradients disabled,
    so no gradient flows through ``A`` or ``T``. Three ``Samples_Loss`` terms
    are then combined with the module-level weights ``ca``, ``cw`` and ``ct``:

    * ``t_loss``: ``mask`` versus ``mask.detach() + T``;
    * ``a_loss``: the centered mask versus its detached image under ``A``;
    * ``w_loss``: the centered mask's image under ``A`` versus the detached,
      centered ``barycenter``.

    Args:
        image: Not used by this function; kept for interface compatibility.
        mask: The measure being aligned.

    Returns:
        ``a_loss * ca + w_loss * cw + t_loss * ct``.
    """
    with torch.no_grad():
        A, T = LinearTransformOT.forward(mask, barycenter)

    # Center ``mask`` and apply ``A`` once, then reuse both results below
    # instead of recomputing them for each loss term.
    centered = mask.centered()
    transformed = centered.multiply(A)

    t_loss = Samples_Loss(scaling=0.8, border=0.0001)(mask, mask.detach() + T)
    a_loss = Samples_Loss(scaling=0.8, border=0.0001)(
        centered, transformed.detach())
    w_loss = Samples_Loss(scaling=0.85, border=0.00001)(
        transformed, barycenter.centered().detach())
    return a_loss * ca + w_loss * cw + t_loss * ct
def forward(pred: ProbabilityMeasure, targets: ProbabilityMeasure):
    """Fit a linear map ``A`` and a translation ``T`` from ``pred`` to ``targets``.

    A transport plan ``P`` between the centered coordinates of the two
    measures is computed by ``compute_ot_matrix_par`` on NumPy copies, with
    gradients disabled. ``A`` is the closed-form solution of

        (xs^T diag(a) xs) A = xs^T P xt

    where ``xs``/``xt`` are the centered source/target coordinates and ``a``
    is the renormalized ``pred.probability``. This is a weighted linear
    least-squares fit under the plan ``P``; it presumably assumes the rows of
    ``P`` sum to the source weights (confirm against
    ``compute_ot_matrix_par``).

    Args:
        pred: Source measure. Its ``coord`` must be a 3-D batched tensor
            (required by ``transpose(1, 2)`` / ``bmm``); presumably
            ``(batch, n_points, dim)``.
        targets: Target measure, batched like ``pred``.

    Returns:
        ``(A, T)``. ``A`` is cast to ``pred.coord``'s type and still carries
        gradients through ``xs``. ``T = targets.mean() - pred.mean()`` is
        detached.
    """
    with torch.no_grad():
        P = compute_ot_matrix_par(
            pred.centered().coord.detach().cpu().numpy(),
            targets.centered().coord.detach().cpu().numpy(),
        )
        # Match the device and dtype of the input coordinates instead of
        # hard-coding ``.cuda()``, so this also works for CPU tensors and for
        # inputs living on a non-default GPU.
        P = torch.from_numpy(P).to(device=pred.coord.device,
                                   dtype=pred.coord.dtype)

    xs = pred.centered().coord
    xsT = xs.transpose(1, 2)
    xt = targets.centered().coord

    # Epsilon keeps every weight strictly positive before renormalizing
    # per batch element; reshape to (batch, n, 1) so it broadcasts over xs.
    a: Tensor = pred.probability + 1e-8
    a /= a.sum(dim=1, keepdim=True)
    a = a.reshape(a.shape[0], -1, 1)

    A = torch.inverse(xsT.bmm(a * xs)).bmm(xsT.bmm(P.bmm(xt)))
    T = targets.mean() - pred.mean()

    return A.type_as(pred.coord), T.detach()
def forward(pred: ProbabilityMeasure, targets: ProbabilityMeasure,
            iters: int = 200, lambd: float = 0.002):
    """Fit a linear map ``A`` and a translation ``T`` from ``pred`` to ``targets``.

    The transport plan ``P`` between the centered measures comes from
    ``SOT(iters, lambd).forward(...)`` with gradients disabled. ``A`` is the
    closed-form solution of

        (xs^T diag(a) xs) A = xs^T P xt

    where ``xs``/``xt`` are the centered source/target coordinates and ``a``
    is the renormalized ``pred.probability``.

    Args:
        pred: Source measure. Its ``coord`` must be a 3-D batched tensor
            (required by ``transpose(1, 2)`` / ``bmm``); presumably
            ``(batch, n_points, dim)``.
        targets: Target measure, batched like ``pred``.
        iters: Iteration count passed to ``SOT``.
        lambd: Regularization weight passed to ``SOT``. The default matches
            the previously hard-coded value. Presumably this is the entropic
            (Sinkhorn) regularization; confirm against ``SOT``.

    Returns:
        ``(A, T)``. ``A`` is cast to ``pred.coord``'s type and still carries
        gradients through ``xs``. ``T = targets.mean() - pred.mean()`` is
        detached.
    """
    with torch.no_grad():
        P = SOT(iters, lambd).forward(pred.centered(), targets.centered())

    xs = pred.centered().coord
    xsT = xs.transpose(1, 2)
    xt = targets.centered().coord

    # Epsilon keeps every weight strictly positive before renormalizing
    # per batch element; reshape to (batch, n, 1) so it broadcasts over xs.
    a = pred.probability + 1e-8
    a /= a.sum(dim=1, keepdim=True)
    a = a.reshape(a.shape[0], -1, 1)

    A = torch.inverse(xsT.bmm(a * xs)).bmm(xsT.bmm(P.bmm(xt)))
    T = targets.mean() - pred.mean()

    return A.type_as(pred.coord), T.detach()