def test_sinkhorn_consistency_exp_log_asym(entropy, rtol, p, m, reach):
    """Test if the exp sinkhorn is consistent with its log form"""
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(2, 5, 3)
    b, y = generate_measure(2, 6, 3)
    solver1 = BatchVanillaSinkhorn(nits=10000, nits_grad=10, tol=1e-12,
                                   assume_convergence=True)
    solver2 = BatchExpSinkhorn(nits=10000, nits_grad=10, tol=1e-12,
                               assume_convergence=True)
    f_a, g_a = solver1.sinkhorn_asym(m * a, x, m * b, y, cost=cost,
                                     entropy=entropy)
    u_a, v_a = solver2.sinkhorn_asym(m * a, x, m * b, y, cost=cost,
                                     entropy=entropy)
    assert torch.allclose(f_a, u_a, rtol=rtol)
    assert torch.allclose(g_a, v_a, rtol=rtol)
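
# --- Hedged standalone sketch (not part of the test suite) ---
# The same consistency check run by hand for a single entropy, reusing only
# calls that appear in the tests above; it assumes the same imports as the
# surrounding test module.
if __name__ == "__main__":
    entropy = KullbackLeibler(1e0, 1e0)
    cost = euclidean_cost(2)
    a, x = generate_measure(2, 5, 3)
    b, y = generate_measure(2, 6, 3)
    log_solver = BatchVanillaSinkhorn(nits=10000, nits_grad=10, tol=1e-12,
                                      assume_convergence=True)
    exp_solver = BatchExpSinkhorn(nits=10000, nits_grad=10, tol=1e-12,
                                  assume_convergence=True)
    f, g = log_solver.sinkhorn_asym(a, x, b, y, cost=cost, entropy=entropy)
    u, v = exp_solver.sinkhorn_asym(a, x, b, y, cost=cost, entropy=entropy)
    # The two parametrizations should agree up to the solver tolerance.
    print(torch.allclose(f, u, rtol=1e-5), torch.allclose(g, v, rtol=1e-5))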
    return a, x, b, y


# Init of measures and solvers
a, x, b, y = template_measure(500)
A, X, B, Y = (
    torch.from_numpy(a)[None, :],
    torch.from_numpy(x)[None, :, None],
    torch.from_numpy(b)[None, :],
    torch.from_numpy(y)[None, :, None],
)
blur = 1e-3
reach = np.array([10 ** x for x in np.linspace(-2, np.log10(0.5), 4)])
cost = euclidean_cost(2)
solver = BatchVanillaSinkhorn(
    nits=10000, nits_grad=2, tol=1e-8, assume_convergence=True
)
list_entropy = [KullbackLeibler(blur, reach[0]), TotalVariation(blur, reach[0])]

# Init of plot
blue = (0.55, 0.55, 0.95)
red = (0.95, 0.55, 0.55)

# Plotting transport marginals for each entropy
for i in range(len(list_entropy)):
    for j in range(len(reach)):
        fig = plt.figure(figsize=(8, 4))
        entropy = list_entropy[i]
        entropy.reach = reach[j]
        f, g = solver.sinkhorn_asym(A, X, B, Y, cost, entropy)
                                    (1.5, 2.0)])
@pytest.mark.parametrize(
    "entropy",
    [
        KullbackLeibler(1e0, 1e0),
        TotalVariation(1e0, 1e0),
        Range(1e0, 0.3, 2),
        PowerEntropy(1e0, 1e0, 0),
        PowerEntropy(1e0, 1e0, -1),
    ],
)
@pytest.mark.parametrize("div", [regularized_ot])
@pytest.mark.parametrize(
    "solv",
    [
        BatchVanillaSinkhorn(nits=5000, nits_grad=20, tol=1e-14,
                             assume_convergence=True),
        BatchExpSinkhorn(nits=5000, nits_grad=20, tol=1e-14,
                         assume_convergence=True),
    ],
)
def test_gradient_unbalanced_weight_and_position_asym(solv, div, entropy,
                                                      reach, p, m, n):
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    a = m * a
    a.requires_grad = True
    x.requires_grad = True
    b, y = generate_measure(1, 6, 2)
    f, g = solv.sinkhorn_asym(a, x, n * b, y, cost, entropy)
    func = entropy.output_regularized(a, x, n * b, y, cost, f, g)
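    # Hedged illustration (the test's own assertions are elided above): with
    # func the scalar value of the regularized functional, the weight and
    # position gradients can be pulled out with the standard torch.autograd
    # API; func.sum() keeps the call valid even if func carries a batch
    # dimension.
    # grad_a, grad_x = torch.autograd.grad(func.sum(), (a, x))
    # assert not torch.any(torch.isnan(grad_a))
    # assert not torch.any(torch.isnan(grad_x))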
                a_i.data /= a_i.data.sum(1)
            else:
                a_i.data *= (-(2 * lr_a * m)).exp()
            print(f"At step {i} the total mass is {a_i.sum().item()}")
            fname = path + f"/unbalanced_flow_{func.__name__}_p{p}_" \
                           f"{entropy.__dict__}_lrx{lr_x}_lra{lr_a}_" \
                           f"steps{Nsteps}_frame{i}.eps"
            plt.savefig(fname, format='eps')
            plt.cla()


if __name__ == '__main__':
    setting = 0
    p, cost = 2, euclidean_cost(2)
    solver = BatchVanillaSinkhorn(nits=5000, nits_grad=15, tol=1e-8,
                                  assume_convergence=True)

    if setting == 0:
        # Compare KL for various mass steps
        gradient_flow(sinkhorn_divergence, entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver, cost=cost, p=p, lr_x=60., lr_a=0.,
                      Nsteps=300)
        gradient_flow(sinkhorn_divergence, entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver, cost=cost,
@pytest.mark.parametrize(
    "entropy",
    [
        Balanced(1e1),
        KullbackLeibler(1e1, 1e0),
        TotalVariation(1e1, 1e0),
        Range(1e1, 0.3, 2),
        PowerEntropy(1e1, 1e0, 0),
        PowerEntropy(1e1, 1e0, -1),
    ],
)
@pytest.mark.parametrize(
    "solv",
    [
        BatchVanillaSinkhorn(nits=10, nits_grad=10, tol=1e-5,
                             assume_convergence=True),
        BatchVanillaSinkhorn(nits=10, nits_grad=10, tol=1e-5,
                             assume_convergence=False),
        BatchScalingSinkhorn(budget=10, nits_grad=10, assume_convergence=True),
        BatchScalingSinkhorn(budget=10, nits_grad=10,
                             assume_convergence=False),
        BatchExpSinkhorn(nits=10, nits_grad=10, tol=1e-5,
                         assume_convergence=True),
        BatchExpSinkhorn(nits=10, nits_grad=10, tol=1e-5,
                         assume_convergence=False),
    ],
)
def test_sinkhorn_no_bug(entropy, solv):
    a, x = generate_measure(2, 5, 3)
    b, y = generate_measure(2, 6, 3)
    solv.sinkhorn_asym(a, x, b, y, cost=euclidean_cost(1), entropy=entropy)
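
# --- Hedged usage sketch (not part of the test suite) ---
# The smoke test above translated into a plain loop: run the vanilla solver
# for two entropies and print the shapes of the returned potentials. Assumes
# the same imports as the surrounding test module.
if __name__ == "__main__":
    for ent in [Balanced(1e1), KullbackLeibler(1e1, 1e0)]:
        a, x = generate_measure(2, 5, 3)
        b, y = generate_measure(2, 6, 3)
        solver = BatchVanillaSinkhorn(nits=10, nits_grad=10, tol=1e-5,
                                      assume_convergence=True)
        f, g = solver.sinkhorn_asym(a, x, b, y, cost=euclidean_cost(1),
                                    entropy=ent)
        print(type(ent).__name__, tuple(f.shape), tuple(g.shape))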