import importlib
import time
import timeit

import torch
from pykeops.torch import generic_logsumexp, generic_sum

# Module-level settings, assumed here so that the snippets below are
# self-contained -- the original scripts define them elsewhere:
backend = "keops"  # Switch used by Sinkhorn_ops below: "keops" or "pytorch"
D = 3              # Ambient dimension used by the benchmark routines


def logconv(self, C, dtype):
    if len(C) == 2:  # C = (x_i, y_j): plain point clouds
        D = C[0].shape[1]
        log_conv = generic_logsumexp(
            "( B - (P * " + self.cost.formula + " ) )",
            "A = Vi(1)",
            "X = Vi({})".format(D),
            "Y = Vj({})".format(D),
            "B = Vj(1)",
            "P = Pm(1)",
            dtype=dtype)
    else:  # C = (x_i, y_j, z_i, l_j): point clouds with extra feature channels
        D = C[0].shape[1]
        E = C[2].shape[1]  # Dimension of the extra feature channels
        log_conv = generic_logsumexp(
            "( B - (P * " + f"Sqrt(Pow({self.cost.formula},2) + Pow(Z|L, 2))" + " ) )",
            "A = Vi(1)",
            "X = Vi({})".format(D),
            "Y = Vj({})".format(D),
            "Z = Vi({})".format(E),
            "L = Vj({})".format(E),
            "B = Vj(1)",
            "P = Pm(1)",
            dtype=dtype)
    return log_conv
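
# Hedged usage sketch for logconv (illustrative, not from the source: the
# _Cost/_Solver shells below are hypothetical, only meant to exhibit the
# expected `self.cost.formula` attribute and the call signature of the
# returned routine):
class _Cost:
    formula = "SqDist(X,Y)"  # squared Euclidean distance, in KeOps syntax

class _Solver:
    cost = _Cost()
    logconv = logconv  # reuse the function above as a bound method

def _logconv_demo(dtype="float32"):
    x, y = torch.randn(500, 3).cuda(), torch.randn(800, 3).cuda()
    b = torch.randn(800, 1).cuda()
    p = torch.Tensor([10.]).cuda()  # P = 1/ε, here ε = 0.1
    conv = _Solver().logconv((x, y), dtype)  # two clouds -> plain cost branch
    return conv(x, y, b, p)  # a_i = log Σ_j exp( b_j - |x_i-y_j|² / ε )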
def keops_lse(cost, D, dtype="float32"):
    """Returns a KeOps routine computing a_i = log sum_j exp( b_j - P * cost(x_i, y_j) ),
    where the scalar parameter P is typically 1/ε."""
    log_conv = generic_logsumexp(
        "( B - (P * " + cost + " ) )",
        "A = Vi(1)",
        "X = Vi({})".format(D),
        "Y = Vj({})".format(D),
        "B = Vj(1)",
        "P = Pm(1)",
        dtype=dtype)
    return log_conv
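
# A hedged usage sketch for keops_lse (illustrative variable names, not from
# the source): with P = 1/ε, the returned routine computes the Sinkhorn
# "softmin"  a_i = log Σ_j exp( b_j - cost(x_i, y_j) / ε ).
def _keops_lse_demo(eps=.1):
    x, y = torch.randn(1000, 3).cuda(), torch.randn(2000, 3).cuda()
    b = torch.randn(2000, 1).cuda()
    softmin = keops_lse("SqDist(X,Y)", 3)  # squared Euclidean cost
    return softmin(x, y, b, torch.Tensor([1 / eps]).cuda())  # (1000, 1) tensor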
def lse(v_ij):
    """[lse(v_ij)]_i = log sum_j exp(v_ij), with numerical accuracy."""
    V_i = torch.max(v_ij, 1)[0].view(-1, 1)
    return V_i + (v_ij - V_i).exp().sum(1).log().view(-1, 1)


def Sinkhorn_ops(p, ε, x_i, y_j):
    """
    Given:
    - an exponent p = 1 or 2,
    - a regularization strength ε > 0,
    - point clouds x_i and y_j, encoded as N-by-D and M-by-D torch arrays,

    returns a pair of routines S_x, S_y such that
        [S_x(f_i)]_j = -log sum_i exp( f_i - |x_i-y_j|^p / ε )
        [S_y(f_j)]_i = -log sum_j exp( f_j - |x_i-y_j|^p / ε )

    This may look like a strange level of abstraction, but it is the most
    convenient way of working with KeOps and vanilla PyTorch (with a
    pre-computed cost matrix) at the same time.
    """
    if backend == "keops":  # Memory-efficient GPU implementation: ONline logsumexp
        # We create a KeOps GPU routine...
        if p == 1:
            formula = "Fj - (Sqrt(SqDist(Xi,Yj)) / E)"
        elif p == 2:
            formula = "Fj - (SqDist(Xi,Yj) / E)"
        else:
            formula = "Fj - (Powf(SqDist(Xi,Yj),R) / E)"
            raise NotImplementedError(
                "I should fix the derivative at 0 of Powf, in KeOps's core.")

        D = x_i.shape[1]  # Dimension of the ambient space (typically 2 or 3)
        routine = generic_logsumexp(
            formula, "outi = Vx(1)",  # Formula, output...
            # ... and input variables: ε, x_i, y_j, f_j, p/2,
            # given with their respective dimensions:
            "E = Pm(1)", "Xi = Vx({})".format(D), "Yj = Vy({})".format(D),
            "Fj = Vy(1)", "R = Pm(1)")

        # Before wrapping it up in a simple pair of operators - don't forget the minus!
        ε = torch.Tensor([ε]).type_as(x_i)
        r = torch.Tensor([p / 2]).type_as(x_i)
        S_x = lambda f_i: -routine(ε, y_j, x_i, f_i, r)
        S_y = lambda f_j: -routine(ε, x_i, y_j, f_j, r)
        return S_x, S_y

    elif backend == "pytorch":  # Naive matrix-vector implementation: OFFline logsumexp
        # We precompute the |x_i-y_j|^p matrix once and for all...
        x_y = x_i.unsqueeze(1) - y_j.unsqueeze(0)
        if p == 1:
            C_e = x_y.norm(dim=2) / ε
        elif p == 2:
            C_e = (x_y ** 2).sum(2) / ε
        else:
            C_e = x_y.norm(dim=2) ** p / ε  # |x_i-y_j|^p, as in the docstring
        CT_e = C_e.t()

        # Before wrapping it up in a simple pair of operators - don't forget the minus!
        S_x = lambda f_i: -lse(f_i.view(1, -1) - CT_e)
        S_y = lambda f_j: -lse(f_j.view(1, -1) - C_e)
        return S_x, S_y
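
# Hedged sketch of how S_x, S_y are meant to be used (an illustration, not the
# original solver). With C_ij = |x_i-y_j|^p, the balanced Sinkhorn updates
#     f_i = -ε log Σ_j β_j exp( (g_j - C_ij)/ε ) = ε S_y( g_j/ε + log β_j )
#     g_j = -ε log Σ_i α_i exp( (f_i - C_ij)/ε ) = ε S_x( f_i/ε + log α_i )
# translate directly into a fixed-point loop on the dual potentials:
def _sinkhorn_loop_sketch(α_i, x_i, β_j, y_j, p=2, ε=.1, nits=20):
    S_x, S_y = Sinkhorn_ops(p, ε, x_i, y_j)
    f_i, g_j = torch.zeros_like(α_i), torch.zeros_like(β_j)
    for _ in range(nits):
        g_j = ε * S_x(f_i / ε + α_i.log())
        f_i = ε * S_y(g_j / ε + β_j.log())
    return f_i, g_j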
def lse_genred(cost, D, dtype="float32"):
    """Legacy "Genred" implementation, with low-level KeOps formulas."""
    log_conv = generic_logsumexp(
        "( B - (P * " + cost + " ) )",
        "A = Vi(1)",
        "X = Vi({})".format(D),
        "Y = Vj({})".format(D),
        "B = Vj(1)",
        "P = Pm(1)",
        dtype=dtype,
    )
    return log_conv
def benchmark(bench_name, N, dev, backend, loops=10, enable_GC=True, fidelity=None):
    importlib.reload(torch)  # Reset the torch/CUDA state between runs
    device = torch.device(dev)

    # Draw two random point clouds of size N, with positive weights that sum to one:
    x_i = torch.randn(N, D, dtype=torch.float32, device=device, requires_grad=True)
    y_j = torch.randn(N, D, dtype=torch.float32, device=device)
    α_i = torch.randn(N, 1, dtype=torch.float32, device=device)
    β_j = torch.randn(N, 1, dtype=torch.float32, device=device)
    α_i, β_j = α_i.abs(), β_j.abs()
    α_i, β_j = α_i / α_i.sum(), β_j / β_j.sum()

    s2v = lambda x: torch.tensor([x], dtype=torch.float32, device=device)

    def scal(α, f):
        return torch.dot(α.view(-1), f.view(-1))

    if bench_name == "energy_distance":
        keops_conv = generic_sum(
            "Sqrt(SqDist(Xi,Yj)) * Bj",
            "out_i = Vx(1)",  # Formula, output...
            # ... and input variables: x_i, y_j, β_j, given with their respective dimensions:
            "Xi = Vx({})".format(D),
            "Yj = Vy({})".format(D),
            "Bj = Vy(1)")

        def vanilla_conv(x, y, β):
            XmY2 = ((x.unsqueeze(1) - y.unsqueeze(0)) ** 2).sum(2)
            K = XmY2.sqrt()
            return K @ β

        def bench(α, x, β, y):
            if backend == "GPU_1D":
                conv = keops_conv
            elif backend == "pytorch":
                conv = vanilla_conv
            cost = scal(α, conv(x, y, β) - .5 * conv(x, x, α)) \
                   - .5 * scal(β, conv(y, y, β))
            cost.backward()
            return cost

        code = '_ = bench(α_i,x_i,β_j,y_j)'
        task = "Energy Distances"

    elif bench_name == "LogSumExp":
        keops_lse = generic_logsumexp(
            "Sqrt(SqDist(Xi,Yj))",
            "out_i = Vx(1)",  # Formula, output...
            # ... and input variables: x_i, y_j, given with their respective dimensions:
            "Xi = Vx({})".format(D),
            "Yj = Vy({})".format(D))

        def lse(v_ij):
            """[lse(v_ij)]_i = log sum_j exp(v_ij), with numerical accuracy."""
            V_i = torch.max(v_ij, 1)[0].view(-1, 1)
            return V_i + (v_ij - V_i).exp().sum(1).log().view(-1, 1)

        def vanilla_lse(x, y):
            XmY2 = ((x.unsqueeze(1) - y.unsqueeze(0)) ** 2).sum(2)
            K = XmY2.sqrt()
            return lse(K)

        def bench(x, y):
            if backend == "GPU_1D":
                return keops_lse(x, y)
            elif backend == "pytorch":
                return vanilla_lse(x, y)
            else:
                raise NotImplementedError()

        code = '_ = bench(x_i,y_j)'
        task = "LSEs"

    elif bench_name == "fidelities":
        from divergences import kernel_divergence, regularized_ot, \
            hausdorff_divergence, sinkhorn_divergence

        if fidelity == "energy_distance":
            params = ("energy", None)
            code = "c = kernel_divergence(α_i,x_i, β_j,y_j, k=params ) ; c.backward()"
        elif fidelity == "hausdorff":
            params = {"p": 1, "eps": .1, "nits": 3, "tol": 0.}
            code = "c = hausdorff_divergence(α_i,x_i, β_j,y_j, **params ) ; c.backward()"
        elif fidelity == "sinkhorn":
            params = {
                "p": 1, "eps": .1, "nits": (20, 3),
                "assume_convergence": True,  # True in practice, and lets us win a x2 factor
                "tol": 0.,
            }
            code = "c = sinkhorn_divergence(α_i,x_i, β_j,y_j, **params ) ; c.backward()"
        elif fidelity == "sinkhorn_nocv":
            params = {
                "p": 1, "eps": .1, "nits": (20, 3),
                "assume_convergence": False,
                "tol": 0.,
            }
            code = "c = sinkhorn_divergence(α_i,x_i, β_j,y_j, **params ) ; c.backward()"
        task = "fidelities"

    exec(code, locals())  # Run once, to compile the CUDA kernels

    import gc
    GC = 'gc.enable();' if enable_GC else 'pass;'
    print("{:3} NxN {}, with N ={:7}: {:3}x".format(loops, task, N, loops), end="")

    exec(code, locals())  # Warmup run
    elapsed = timeit.Timer(code, GC, globals=locals(), timer=time.time).timeit(loops)

    print("{:3.6f}s".format(elapsed / loops))
    return elapsed / loops
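
# Illustrative driver (an assumption about how `benchmark` is meant to be
# called, not the original experiment script): time the Sinkhorn fidelity
# over a log-spaced range of problem sizes.
def _full_bench_sketch(dev="cuda", backend="GPU_1D"):
    return [benchmark("fidelities", n, dev, backend, loops=10, fidelity="sinkhorn")
            for n in (100, 1000, 10000, 100000)]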
def keops_lse(cost, D):
    """Same log-sum-exp reduction as above, with the older Vx/Vy variable syntax."""
    log_conv = generic_logsumexp(
        "( B - (P * " + cost + " ) )",
        "A = Vx(1)",
        "X = Vx({})".format(D),
        "Y = Vy({})".format(D),
        "B = Vy(1)",
        "P = Pm(1)")
    return log_conv