def test_gradient_unbalanced_weight_and_position_sym(solv, div, entropy, reach, p, m):
    """Check autograd gradients of the symmetric regularized OT functional
    (a, x) -> OT_eps(a, x, a, x) against closed-form expressions, for an
    unbalanced entropy with finite reach.

    Arguments are pytest fixtures: `solv` is a Sinkhorn solver, `entropy`
    the entropy functional (with `blur` and `reach`), `p` the cost
    exponent, `m` a total-mass rescaling (`div` is unused here).
    """
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    a = m * a  # rescale the total mass so the problem is genuinely unbalanced
    a.requires_grad = True
    x.requires_grad = True
    # Symmetric potential: solve the asymmetric problem with (a, x) on both
    # sides, so f plays the role of both potentials below.
    _, f = solv.sinkhorn_asym(a, x, a, x, cost, entropy)
    func = entropy.output_regularized(a, x, a, x, cost, f, f)
    [grad_num_x, grad_num_a] = torch.autograd.grad(func, [x, a])
    # Closed-form gradient w.r.t. the weights; the factor 2 reflects `a`
    # appearing in both marginals of OT(a, x, a, x).
    # NOTE(review): the .exp() argument here is NOT divided by entropy.blur,
    # unlike the plan `pi` below — confirm this matches the potential
    # normalization used by the solver.
    grad_th_a = (
        -2 * entropy.legendre_entropy(-f) + 2 * entropy.blur * m
        - 2 * entropy.blur
        * ((f[:, :, None] + f[:, None, :] - dist_matrix(x, x, p)).exp()
           * a[:, None, :]).sum(dim=2))
    # Primal plan pi_ij = a_i a_j exp((f_i + f_j - C_ij) / blur).
    pi = (a[:, :, None] * a[:, None, :]
          * ((f[:, :, None] + f[:, None, :] - dist_matrix(x, x, p))
             / entropy.blur).exp())
    # Closed-form gradient w.r.t. positions; factor 4 = two symmetric sides
    # times the chain rule of the cost (presumably p == 2 — TODO confirm).
    grad_th_x = 4 * x * pi.sum(dim=2)[:, :, None] - 4 * torch.einsum(
        "ijk, ikl->ijl", pi, x)
    print(f"Symmetric potential = {f}")
    print(f"gradient ratio = {grad_num_a / grad_th_a}")
    print(f"gradient ratio = {grad_num_x / grad_th_x}")
    assert torch.allclose(grad_th_a, grad_num_a, rtol=1e-5)
    assert torch.allclose(grad_th_x, grad_num_x, rtol=1e-5)
def test_divergence_positivity(div, entropy, reach, p, m, n):
    """A divergence between two (rescaled) random measures is non-negative."""
    entropy.reach = reach
    weights_a, support_x = generate_measure(1, 5, 2)
    weights_b, support_y = generate_measure(1, 6, 2)
    value = div(m * weights_a, support_x, n * weights_b, support_y,
                euclidean_cost(p), entropy, solver=solver)
    assert (value >= 0.0).all()
def test_sinkhorn_consistency_exp_log_asym(entropy, rtol, p, m, reach):
    """The log-domain and exp-domain Sinkhorn solvers must return the same
    pair of potentials on the asymmetric problem."""
    entropy.reach = reach
    cost = euclidean_cost(p)
    mu, x = generate_measure(2, 5, 3)
    nu, y = generate_measure(2, 6, 3)
    # Identical solver settings, only the numerical domain differs.
    opts = dict(nits=10000, nits_grad=10, tol=1e-12, assume_convergence=True)
    log_solver = BatchVanillaSinkhorn(**opts)
    exp_solver = BatchExpSinkhorn(**opts)
    f_log, g_log = log_solver.sinkhorn_asym(m * mu, x, m * nu, y, cost=cost,
                                            entropy=entropy)
    f_exp, g_exp = exp_solver.sinkhorn_asym(m * mu, x, m * nu, y, cost=cost,
                                            entropy=entropy)
    assert torch.allclose(f_log, f_exp, rtol=rtol)
    assert torch.allclose(g_log, g_exp, rtol=rtol)
def test_sanity_control_exp_sinkhorn_small(entropy):
    """The exp-domain solver signals numerical failure on this small-blur
    setup by returning None for every potential."""
    mu, x = generate_measure(2, 5, 3)
    nu, y = generate_measure(2, 6, 3)
    exp_solver = BatchExpSinkhorn(nits=10000, nits_grad=10, tol=1e-12,
                                  assume_convergence=True)
    cost = euclidean_cost(1)
    f, g = exp_solver.sinkhorn_asym(mu, x, nu, y, cost=cost, entropy=entropy)
    _, h = exp_solver.sinkhorn_sym(mu, x, cost=cost, entropy=entropy)
    for potential in (f, g, h):
        assert potential is None
def test_consistency_infinite_blur_sinkhorn_div(entropy, reach, p):
    """In the large-blur limit the Sinkhorn divergence approaches the
    energy distance (hence the very loose rtol=1e-0)."""
    entropy.reach = reach
    mu, x = generate_measure(1, 5, 2)
    nu, y = generate_measure(1, 6, 2)
    expected = energyDistance(mu, x, nu, y, p)
    actual = sinkhorn_divergence(mu, x, nu, y, euclidean_cost(p), entropy,
                                 solver)
    assert torch.allclose(actual, expected, rtol=1e-0)
def test_sinkhorn_sym_infinite_blur(solv, entropy, atol, p, m, reach):
    """The symmetric Sinkhorn potential must match the analytic
    initialization returned by entropy.init_potential (infinite-blur
    regime)."""
    entropy.reach = reach
    cost = euclidean_cost(p)
    mu, x = generate_measure(2, 5, 3)
    expected, _ = entropy.init_potential(m * mu, x, m * mu, x, cost=cost)
    _, actual = solv.sinkhorn_sym(m * mu, x, cost=cost, entropy=entropy)
    gap = entropy.error_sink(actual, expected)
    assert torch.allclose(gap, torch.tensor([0.0]), atol=atol)
def test_consistency_regularized_sym_asym(entropy, reach, p, m):
    """The regularized OT cost computed from the asymmetric potentials
    (f, g) equals the one computed from the symmetric potential when both
    input measures coincide. (`m` is an unused fixture kept for signature
    compatibility.)"""
    entropy.reach = reach
    cost = euclidean_cost(p)
    mu, x = generate_measure(1, 5, 2)
    f_asym, g_asym = solver.sinkhorn_asym(mu, x, mu, x, cost, entropy)
    _, f_sym = solver.sinkhorn_sym(mu, x, cost, entropy)
    cost_asym = entropy.output_regularized(mu, x, mu, x, cost, f_asym, g_asym)
    cost_sym = entropy.output_regularized(mu, x, mu, x, cost, f_sym, f_sym)
    assert torch.allclose(cost_asym, cost_sym, rtol=1e-6)
def test_consistency_infinite_blur_regularized_ot_balanced(entropy, p):
    """Control consistency in OT_eps when eps goes to infinity, especially
    for balanced OT: the value reduces to <a, f> with (f, g) obtained by a
    plain convolution of the measures."""
    cost = euclidean_cost(p)
    mu, x = generate_measure(1, 5, 2)
    nu, y = generate_measure(1, 6, 2)
    f, g = convolution(mu, x, nu, y, cost)
    expected = scal(mu, f)
    actual = regularized_ot(mu, x, nu, y, cost, entropy, solver)
    assert torch.allclose(actual, expected, rtol=1e-0)
def test_gradient_balanced_zero_grad_sinkhorn(solv, div, entropy, p, m):
    """S(a, b) is minimized at b = a, so the gradient of the Sinkhorn
    divergence w.r.t. (a, x) must vanish when the second measure is an
    identical detached copy of the first."""
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    a = m * a
    # Detached copies: gradients flow only through the first measure.
    b, y = a.clone(), x.clone()
    a.requires_grad = True
    x.requires_grad = True
    value = sinkhorn_divergence(a, x, b, y, cost, entropy, solver=solv)
    [grad_num_x, grad_num_a] = torch.autograd.grad(value, [x, a])
    for grad in (grad_num_a, grad_num_x):
        assert torch.allclose(torch.zeros_like(grad), grad, atol=1e-5)
def test_consistency_infinite_blur_regularized_ot_unbalanced(
        entropy, reach, p, m, n):
    """Control consistency of OT_eps in the infinite-blur limit for
    unbalanced OT: the value reduces to <a, f> plus the entropy penalties
    of the total masses m and n.

    Fix: the original called torch.set_default_dtype(torch.float64) and
    never restored it, leaking global state into every later test. The
    previous default dtype is now restored in a finally block.
    """
    entropy.reach = reach
    cost = euclidean_cost(p)
    prev_dtype = torch.get_default_dtype()
    # float64 is needed for the loose asymptotic comparison below.
    torch.set_default_dtype(torch.float64)
    try:
        a, x = generate_measure(1, 5, 2)
        b, y = generate_measure(1, 6, 2)
        phi = entropy.entropy
        f, g = convolution(a, x, b, y, cost)
        # Analytic infinite-blur value: linear term plus mass penalties.
        control = (scal(a, f) + m * phi(torch.Tensor([n]))
                   + n * phi(torch.Tensor([m])))
        func = regularized_ot(m * a, x, n * b, y, cost, entropy,
                              solver=solver)
        assert torch.allclose(func, control, rtol=1e-0)
    finally:
        torch.set_default_dtype(prev_dtype)
def test_sinkhorn_asym_infinite_blur_unbalanced(solv, entropy, atol, p, m, n,
                                                reach):
    """In the infinite-blur regime the asymmetric Sinkhorn potentials must
    match the analytic initialization from entropy.init_potential.

    Fix: init_potential is now called with the cost function (`cost=cost`),
    consistently with the symmetric infinite-blur test, instead of passing
    the bare exponent `p` in the cost argument slot.
    """
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(2, 5, 3)
    b, y = generate_measure(2, 6, 3)
    f_c, g_c = entropy.init_potential(m * a, x, n * b, y, cost=cost)
    f, g = solv.sinkhorn_asym(m * a, x, n * b, y, cost=cost, entropy=entropy)
    assert torch.allclose(torch.tensor([0.0]), entropy.error_sink(f, f_c),
                          atol=atol)
    assert torch.allclose(torch.tensor([0.0]), entropy.error_sink(g, g_c),
                          atol=atol)
def test_gradient_balanced_weight_and_position_sym(solv, div, entropy, p, m):
    """Check autograd gradients of the balanced symmetric regularized OT
    functional (a, x) -> OT_eps(a, x, a, x) against closed-form
    expressions. (`div` is an unused fixture.)
    """
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    a = m * a  # rescale the total mass
    a.requires_grad = True
    x.requires_grad = True
    _, f = solv.sinkhorn_sym(a, x, cost, entropy)
    func = entropy.output_regularized(a, x, a, x, cost, f, f)
    [grad_num_x, grad_num_a] = torch.autograd.grad(func, [x, a])
    # Balanced case: the weight gradient is the potential itself, doubled
    # because `a` enters both marginals.
    grad_th_a = 2 * f
    # Primal plan pi_ij = a_i a_j exp((f_i + f_j - C_ij) / blur).
    pi = (a[:, :, None] * a[:, None, :]
          * ((f[:, :, None] + f[:, None, :] - dist_matrix(x, x, p))
             / entropy.blur).exp())
    # Position gradient; factor 2 from the two symmetric sides (cost chain
    # rule presumably assumes p == 2 — TODO confirm).
    grad_th_x = 2 * x * pi.sum(dim=2)[:, :, None] - 2 * torch.einsum(
        "ijk, ikl->ijl", pi, x)
    assert torch.allclose(grad_th_a, grad_num_a, rtol=1e-5)
    assert torch.allclose(grad_th_x, grad_num_x, rtol=1e-5)
def test_sinkhorn_consistency_sym_asym(solv, entropy, atol, p, m, reach):
    """The symmetric and asymmetric Sinkhorn loops must produce the same
    potentials when the two input measures are identical, (a, x) = (b, y)."""
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(2, 5, 3)
    f_asym, g_asym = solv.sinkhorn_asym(m * a, x, m * a, x, cost=cost,
                                        entropy=entropy)
    _, f_sym = solv.sinkhorn_sym(m * a, x, cost=cost, entropy=entropy)
    zero = torch.tensor([0.0])
    for potential in (f_asym, g_asym):
        assert torch.allclose(entropy.error_sink(potential, f_sym), zero,
                              atol=atol)
def test_gradient_unbalanced_weight_and_position_asym(solv, div, entropy,
                                                      reach, p, m, n):
    """Check autograd gradients of the asymmetric unbalanced regularized OT
    functional (a, x) -> OT_eps(a, x, n*b, y) against closed-form
    expressions. (`div` is an unused fixture.)
    """
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    a = m * a  # rescale the first measure's mass
    a.requires_grad = True
    x.requires_grad = True
    b, y = generate_measure(1, 6, 2)
    f, g = solv.sinkhorn_asym(a, x, n * b, y, cost, entropy)
    func = entropy.output_regularized(a, x, n * b, y, cost, f, g)
    [grad_num_x, grad_num_a] = torch.autograd.grad(func, [x, a])
    # Closed-form gradient w.r.t. the weights a (single-sided here, no
    # factor 2, since b is fixed).
    # NOTE(review): the .exp() argument is NOT divided by entropy.blur,
    # unlike the plan `pi` below — confirm against the solver's potential
    # normalization (same remark applies to the symmetric variant).
    grad_th_a = (
        -entropy.legendre_entropy(-f) + entropy.blur * n
        - entropy.blur
        * ((f[:, :, None] + g[:, None, :] - dist_matrix(x, y, p)).exp()
           * n * b[:, None, :]).sum(dim=2))
    # Primal plan pi_ij = a_i (n b_j) exp((f_i + g_j - C_ij) / blur).
    pi = (n * a[:, :, None] * b[:, None, :]
          * ((f[:, :, None] + g[:, None, :] - dist_matrix(x, y, p))
             / entropy.blur).exp())
    # Position gradient (cost chain rule presumably assumes p == 2 — TODO
    # confirm).
    grad_th_x = 2 * x * pi.sum(dim=2)[:, :, None] - 2 * torch.einsum(
        "ijk, ikl->ijl", pi, y)
    assert torch.allclose(grad_th_a, grad_num_a, rtol=1e-5)
    assert torch.allclose(grad_th_x, grad_num_x, rtol=1e-5)
    # (Tail of template_measure, whose definition starts above this chunk.)
    b = b / np.sum(b)  # normalize the second weight vector to unit mass
    return a, x, b, y


# Init of measures and solvers
a, x, b, y = template_measure(500)
# Batchify: add a leading batch axis, and a trailing coordinate axis to the
# 1-D supports — presumably the layout expected by the batched solvers
# (TODO confirm against the solver API).
A, X, B, Y = (
    torch.from_numpy(a)[None, :],
    torch.from_numpy(x)[None, :, None],
    torch.from_numpy(b)[None, :],
    torch.from_numpy(y)[None, :, None],
)
blur = 1e-3
# Four reach values, log-spaced between 1e-2 and 0.5.
reach = np.array([10 ** x for x in np.linspace(-2, np.log10(0.5), 4)])
cost = euclidean_cost(2)
solver = BatchVanillaSinkhorn(
    nits=10000, nits_grad=2, tol=1e-8, assume_convergence=True
)
list_entropy = [KullbackLeibler(blur, reach[0]), TotalVariation(blur, reach[0])]

# Init of plot
blue = (0.55, 0.55, 0.95)
red = (0.95, 0.55, 0.55)

# Plotting transport marginals for each entropy
for i in range(len(list_entropy)):
    for j in range(len(reach)):
        fig = plt.figure(figsize=(8, 4))
        entropy = list_entropy[i]
        # (loop body continues in the next chunk)
def test_sinkhorn_no_bug(entropy, solv):
    """Smoke test: both Sinkhorn routines run to completion without
    raising."""
    mu, x = generate_measure(2, 5, 3)
    nu, y = generate_measure(2, 6, 3)
    cost = euclidean_cost(1)
    solv.sinkhorn_asym(mu, x, nu, y, cost=cost, entropy=entropy)
    solv.sinkhorn_sym(mu, x, cost=cost, entropy=entropy, y_j=y)
        # (Tail of the gradient-flow loop; the enclosing function starts
        # above this chunk.)
        # Mass update: multiplicative step on the weights — presumably a
        # mirror-descent step driven by the gradient `m`; confirm against
        # the definition of `m` earlier in gradient_flow.
        if isinstance(entropy, Balanced):
            a_i.data *= (-(2 * lr_a * m)).exp()
            # Balanced OT: re-normalize so the weights keep unit mass.
            a_i.data /= a_i.data.sum(1)
        else:
            a_i.data *= (-(2 * lr_a * m)).exp()
        print(f"At step {i} the total mass is {a_i.sum().item()}")
        # Frame filename encodes the functional, cost exponent, entropy
        # parameters and step counts so runs do not overwrite each other.
        fname = path + f"/unbalanced_flow_{func.__name__}_p{p}_" \
                f"{entropy.__dict__}_lrx{lr_x}_lra{lr_a}_" \
                f"steps{Nsteps}_frame{i}.eps"
        plt.savefig(fname, format='eps')
        plt.cla()


if __name__ == '__main__':
    setting = 0
    p, cost = 2, euclidean_cost(2)
    solver = BatchVanillaSinkhorn(nits=5000, nits_grad=15, tol=1e-8,
                                  assume_convergence=True)

    if setting == 0:
        # Compare KL for various mass steps
        gradient_flow(sinkhorn_divergence, entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver, cost=cost, p=p, lr_x=60., lr_a=0.,
                      Nsteps=300)
        # (this call continues in the next chunk)
        gradient_flow(sinkhorn_divergence,
def test_divergence_zero(div, entropy, reach, p, m):
    """A divergence between a measure and itself must vanish."""
    entropy.reach = reach
    weights, support = generate_measure(1, 5, 2)
    value = div(m * weights, support, m * weights, support,
                euclidean_cost(p), entropy, solver=solver)
    assert torch.allclose(value, torch.Tensor([0.0]), rtol=1e-6)
        # (Tail of the gradient-flow loop; the enclosing function starts
        # above this chunk.)
        # Mass update: multiplicative step on the weights — presumably a
        # mirror-descent step driven by the gradient `m`; confirm against
        # the definition of `m` earlier in gradient_flow.
        if isinstance(entropy, Balanced):
            a_i.data *= (- (2 * lr_a * m)).exp()
            # Balanced OT: re-normalize so the weights keep unit mass.
            a_i.data /= a_i.data.sum(1)
        else:
            a_i.data *= (- (2 * lr_a * m)).exp()
        print(f"At step {i} the total mass is {a_i.sum().item()}")
        plt.title(f"t = {i / (Nsteps - 1):1.2f}, elapsed time: "
                  f"{(time.time() - t_0) / Nsteps} s/it", fontsize=30)
        # NOTE(review): `fname` is not defined in this chunk — it must be
        # set earlier in gradient_flow; verify before refactoring.
        plt.savefig(fname)
        plt.show()


if __name__ == '__main__':
    setting = 0
    p, cost = 2, euclidean_cost(2)
    solver = BatchVanillaSinkhorn(nits=5000, nits_grad=15, tol=1e-8,
                                  assume_convergence=True)

    if setting == 0:
        # Compare KL for various mass steps (lr_a = 0., 0.3, 0.5, ...).
        gradient_flow(sinkhorn_divergence, entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver, cost=cost, p=p, lr_x=60., lr_a=0.,
                      Nsteps=300)
        gradient_flow(sinkhorn_divergence, entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver, cost=cost, p=p, lr_x=60., lr_a=0.3,
                      Nsteps=300)
        gradient_flow(sinkhorn_divergence, entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver, cost=cost, p=p, lr_x=60., lr_a=0.5,
                      Nsteps=300)
        # (this call continues in the next chunk)
        gradient_flow(sinkhorn_divergence, entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver, cost=cost, p=p,