def test_gradient_unbalanced_weight_and_position_sym(solv, div, entropy, reach,
                                                     p, m):
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    a = m * a
    a.requires_grad = True
    x.requires_grad = True
    _, f = solv.sinkhorn_asym(a, x, a, x, cost, entropy)
    func = entropy.output_regularized(a, x, a, x, cost, f, f)
    [grad_num_x, grad_num_a] = torch.autograd.grad(func, [x, a])
    grad_th_a = (
        -2 * entropy.legendre_entropy(-f) + 2 * entropy.blur * m -
        2 * entropy.blur *
        ((f[:, :, None] + f[:, None, :] - dist_matrix(x, x, p)).exp() *
         a[:, None, :]).sum(dim=2))
    pi = (a[:, :, None] * a[:, None, :] *
          ((f[:, :, None] + f[:, None, :] - dist_matrix(x, x, p)) /
           entropy.blur).exp())
    grad_th_x = 4 * x * pi.sum(dim=2)[:, :, None] - 4 * torch.einsum(
        "ijk, ikl->ijl", pi, x)
    print(f"Symmetric potential = {f}")
    print(f"gradient ratio = {grad_num_a / grad_th_a}")
    print(f"gradient ratio = {grad_num_x / grad_th_x}")
    assert torch.allclose(grad_th_a, grad_num_a, rtol=1e-5)
    assert torch.allclose(grad_th_x, grad_num_x, rtol=1e-5)
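For reference, here is the closed-form position gradient that grad_th_x encodes, written as a sketch under the assumption that p = 2, i.e. dist_matrix(x, x, p) holds squared Euclidean distances (epsilon denotes entropy.blur and f the symmetric potential):

\[
\pi_{jk} = a_j\, a_k\, \exp\!\big((f_j + f_k - \lVert x_j - x_k\rVert^2)/\varepsilon\big),
\qquad
\frac{\partial F}{\partial x_j} = 4 \sum_k \pi_{jk}\,(x_j - x_k).
\]

The factor 4 combines the symmetric doubling (the same measure sits on both marginals) with the derivative \(\partial_{x_j}\lVert x_j - x_k\rVert^2 = 2\,(x_j - x_k)\).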
Example #2
def test_divergence_positivity(div, entropy, reach, p, m, n):
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    b, y = generate_measure(1, 6, 2)
    func = div(m * a, x, n * b, y, cost, entropy, solver=solver)
    assert torch.ge(func, 0.0).all()
Example #3
def test_sinkhorn_consistency_exp_log_asym(entropy, rtol, p, m, reach):
    """Test if the exp sinkhorn is consistent with its log form"""
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(2, 5, 3)
    b, y = generate_measure(2, 6, 3)
    solver1 = BatchVanillaSinkhorn(nits=10000,
                                   nits_grad=10,
                                   tol=1e-12,
                                   assume_convergence=True)
    solver2 = BatchExpSinkhorn(nits=10000,
                               nits_grad=10,
                               tol=1e-12,
                               assume_convergence=True)
    f_a, g_a = solver1.sinkhorn_asym(m * a,
                                     x,
                                     m * b,
                                     y,
                                     cost=cost,
                                     entropy=entropy)
    u_a, v_a = solver2.sinkhorn_asym(m * a,
                                     x,
                                     m * b,
                                     y,
                                     cost=cost,
                                     entropy=entropy)
    assert torch.allclose(f_a, u_a, rtol=rtol)
    assert torch.allclose(g_a, v_a, rtol=rtol)
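The two solvers are expected to agree because they parametrize the same fixed point differently. As a sketch for the balanced case, in one common convention (the unbalanced entropies add a pointwise correction to each update), with scalings u = e^{f/eps}, v = e^{g/eps} and Gibbs kernel K_{ij} = e^{-C_{ij}/eps}:

\[
f_i \leftarrow -\varepsilon \log \sum_j b_j\, e^{(g_j - C_{ij})/\varepsilon}
\quad\Longleftrightarrow\quad
u_i \leftarrow \Big(\sum_j K_{ij}\, b_j\, v_j\Big)^{-1}.
\]

The log-domain form is the numerically stable one at small blur; here both solvers are run to a tight tolerance (tol=1e-12), so their outputs should match up to rtol.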
Example #4
def test_sanity_control_exp_sinkhorn_small(entropy):
    a, x = generate_measure(2, 5, 3)
    b, y = generate_measure(2, 6, 3)
    solver = BatchExpSinkhorn(nits=10000,
                              nits_grad=10,
                              tol=1e-12,
                              assume_convergence=True)
    f, g = solver.sinkhorn_asym(a,
                                x,
                                b,
                                y,
                                cost=euclidean_cost(1),
                                entropy=entropy)
    _, h = solver.sinkhorn_sym(a, x, cost=euclidean_cost(1), entropy=entropy)
    assert f is None
    assert g is None
    assert h is None
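The exp-domain solver is expected to bail out here (all returned potentials are None), presumably because at very small blur the Gibbs kernel underflows. A minimal sketch of that failure mode, independent of the library:

import torch

# Hypothetical illustration (not part of the test suite): with a tiny blur,
# the Gibbs kernel K = exp(-C / blur) underflows to exactly zero in floating
# point, so multiplicative (exp-domain) Sinkhorn updates end up dividing by zero.
C = torch.cdist(torch.randn(5, 3), torch.randn(6, 3)) ** 2  # pairwise squared distances
K = torch.exp(-C / 1e-12)
print(K.max().item())  # 0.0 -> the kernel is numerically degenerate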
Example #5
def test_consistency_infinite_blur_sinkhorn_div(entropy, reach, p):
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    b, y = generate_measure(1, 6, 2)
    control = energyDistance(a, x, b, y, p)
    func = sinkhorn_divergence(a, x, b, y, cost, entropy, solver)
    assert torch.allclose(func, control, rtol=1e-0)
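The control relies on the known large-blur limit of the Sinkhorn divergence: as epsilon goes to infinity it converges to an MMD-type quantity, which for the cost C(x, y) = |x - y|^p (0 < p < 2) is the energy distance computed by energyDistance, up to that function's normalization convention. A sketch of the limit in the balanced case (the inputs here have unit mass):

\[
S_\varepsilon(\alpha,\beta) \;\xrightarrow{\ \varepsilon\to\infty\ }\;
\int C \,\mathrm d\alpha\,\mathrm d\beta
\;-\; \tfrac12 \int C \,\mathrm d\alpha\,\mathrm d\alpha
\;-\; \tfrac12 \int C \,\mathrm d\beta\,\mathrm d\beta .
\]

The loose tolerance (rtol=1e-0) reflects that the entropy fixture presumably uses a large but finite blur.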
Example #6
def test_sinkhorn_sym_infinite_blur(solv, entropy, atol, p, m, reach):
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(2, 5, 3)
    f_c, _ = entropy.init_potential(m * a, x, m * a, x, cost=cost)
    _, f = solv.sinkhorn_sym(m * a, x, cost=cost, entropy=entropy)
    assert torch.allclose(entropy.error_sink(f, f_c),
                          torch.tensor([0.0]),
                          atol=atol)
Example #7
def test_consistency_regularized_sym_asym(entropy, reach, p, m):
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    f_xy, g_xy = solver.sinkhorn_asym(a, x, a, x, cost, entropy)
    _, f_xx = solver.sinkhorn_sym(a, x, cost, entropy)
    func_asym = entropy.output_regularized(a, x, a, x, cost, f_xy, g_xy)
    func_sym = entropy.output_regularized(a, x, a, x, cost, f_xx, f_xx)
    assert torch.allclose(func_asym, func_sym, rtol=1e-6)
Example #8
def test_consistency_infinite_blur_regularized_ot_balanced(entropy, p):
    """Control consistency in OT_eps when eps goes to infinity,
    especially for balanced OT"""
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    b, y = generate_measure(1, 6, 2)
    f, g = convolution(a, x, b, y, cost)
    control = scal(a, f)
    func = regularized_ot(a, x, b, y, cost, entropy, solver)
    assert torch.allclose(func, control, rtol=1e-0)
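The control value scal(a, f), with f, g = convolution(a, x, b, y, cost), is presumably the integral of the cost against the product measure, which is the known epsilon-to-infinity limit of balanced entropic OT: the optimal plan relaxes to the product coupling of the two measures, at which point the KL term vanishes. As a sketch:

\[
\mathrm{OT}_\varepsilon(\alpha,\beta) \;\xrightarrow{\ \varepsilon\to\infty\ }\;
\int\!\!\int C(x,y)\,\mathrm d\alpha(x)\,\mathrm d\beta(y).
\]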
def test_gradient_balanced_zero_grad_sinkhorn(solv, div, entropy, p, m):
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    a = m * a
    b, y = torch.zeros_like(a), torch.zeros_like(x)
    b.copy_(a)
    y.copy_(x)
    a.requires_grad = True
    x.requires_grad = True
    func = sinkhorn_divergence(a, x, b, y, cost, entropy, solver=solv)
    [grad_num_x, grad_num_a] = torch.autograd.grad(func, [x, a])
    assert torch.allclose(torch.zeros_like(grad_num_a), grad_num_a, atol=1e-5)
    assert torch.allclose(torch.zeros_like(grad_num_x), grad_num_x, atol=1e-5)
Example #10
def test_consistency_infinite_blur_regularized_ot_unbalanced(
        entropy, reach, p, m, n):
    entropy.reach = reach
    cost = euclidean_cost(p)
    torch.set_default_dtype(torch.float64)
    a, x = generate_measure(1, 5, 2)
    b, y = generate_measure(1, 6, 2)
    phi = entropy.entropy
    f, g = convolution(a, x, b, y, cost)
    control = (scal(a, f) + m * phi(torch.Tensor([n])) +
               n * phi(torch.Tensor([m])))
    func = regularized_ot(m * a, x, n * b, y, cost, entropy, solver=solver)
    assert torch.allclose(func, control, rtol=1e-0)
Example #11
def test_sinkhorn_asym_infinite_blur_unbalanced(solv, entropy, atol, p, m, n,
                                                reach):
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(2, 5, 3)
    b, y = generate_measure(2, 6, 3)
    f_c, g_c = entropy.init_potential(m * a, x, n * b, y, p)
    f, g = solv.sinkhorn_asym(m * a, x, n * b, y, cost=cost, entropy=entropy)
    assert torch.allclose(torch.tensor([0.0]),
                          entropy.error_sink(f, f_c),
                          atol=atol)
    assert torch.allclose(torch.tensor([0.0]),
                          entropy.error_sink(g, g_c),
                          atol=atol)
def test_gradient_balanced_weight_and_position_sym(solv, div, entropy, p, m):
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    a = m * a
    a.requires_grad = True
    x.requires_grad = True
    _, f = solv.sinkhorn_sym(a, x, cost, entropy)
    func = entropy.output_regularized(a, x, a, x, cost, f, f)
    [grad_num_x, grad_num_a] = torch.autograd.grad(func, [x, a])
    grad_th_a = 2 * f
    pi = (a[:, :, None] * a[:, None, :] *
          ((f[:, :, None] + f[:, None, :] - dist_matrix(x, x, p)) /
           entropy.blur).exp())
    grad_th_x = 2 * x * pi.sum(dim=2)[:, :, None] - 2 * torch.einsum(
        "ijk, ikl->ijl", pi, x)
    assert torch.allclose(grad_th_a, grad_num_a, rtol=1e-5)
    assert torch.allclose(grad_th_x, grad_num_x, rtol=1e-5)
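The weight gradient grad_th_a = 2 * f is the usual envelope-theorem identity for balanced entropic OT: holding the optimal potentials fixed, the first variation of OT_eps with respect to the first (resp. second) set of weights is f (resp. g), and the factor 2 appears because a enters both marginals of this symmetric problem. As a sketch:

\[
\frac{\partial}{\partial a_j}\,\mathrm{OT}_\varepsilon(\alpha,\beta) = f_j,
\qquad
\frac{\partial}{\partial b_j}\,\mathrm{OT}_\varepsilon(\alpha,\beta) = g_j
\;\;\Longrightarrow\;\;
\frac{\partial}{\partial a_j}\,\mathrm{OT}_\varepsilon(\alpha,\alpha) = 2 f_j .
\]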
Example #13
def test_sinkhorn_consistency_sym_asym(solv, entropy, atol, p, m, reach):
    """Test if the symmetric and assymetric Sinkhorn
    output the same results when (a,x)=(b,y)"""
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(2, 5, 3)
    f_a, g_a = solv.sinkhorn_asym(m * a,
                                  x,
                                  m * a,
                                  x,
                                  cost=cost,
                                  entropy=entropy)
    _, f_s = solv.sinkhorn_sym(m * a, x, cost=cost, entropy=entropy)
    assert torch.allclose(entropy.error_sink(f_a, f_s),
                          torch.tensor([0.0]),
                          atol=atol)
    assert torch.allclose(entropy.error_sink(g_a, f_s),
                          torch.tensor([0.0]),
                          atol=atol)
def test_gradient_unbalanced_weight_and_position_asym(solv, div, entropy,
                                                      reach, p, m, n):
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    a = m * a
    a.requires_grad = True
    x.requires_grad = True
    b, y = generate_measure(1, 6, 2)
    f, g = solv.sinkhorn_asym(a, x, n * b, y, cost, entropy)
    func = entropy.output_regularized(a, x, n * b, y, cost, f, g)
    [grad_num_x, grad_num_a] = torch.autograd.grad(func, [x, a])
    grad_th_a = (
        -entropy.legendre_entropy(-f) + entropy.blur * n - entropy.blur *
        ((f[:, :, None] + g[:, None, :] - dist_matrix(x, y, p)).exp() * n *
         b[:, None, :]).sum(dim=2))
    pi = (n * a[:, :, None] * b[:, None, :] *
          ((f[:, :, None] + g[:, None, :] - dist_matrix(x, y, p)) /
           entropy.blur).exp())
    grad_th_x = 2 * x * pi.sum(dim=2)[:, :, None] - 2 * torch.einsum(
        "ijk, ikl->ijl", pi, y)
    assert torch.allclose(grad_th_a, grad_num_a, rtol=1e-5)
    assert torch.allclose(grad_th_x, grad_num_x, rtol=1e-5)
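For context, the functional that output_regularized presumably evaluates here is the standard primal-dual expression for entropic unbalanced OT, sketched as follows (with phi* = entropy.legendre_entropy, eps = entropy.blur, and m(.) the total mass of a measure):

\[
F = \int -\varphi^*(-f)\,\mathrm d\alpha \;+\; \int -\varphi^*(-g)\,\mathrm d\beta
\;+\; \varepsilon\Big( m(\alpha)\, m(\beta) \;-\; \int\!\!\int e^{(f(x)+g(y)-C(x,y))/\varepsilon}\,\mathrm d\alpha(x)\,\mathrm d\beta(y) \Big).
\]

By the envelope theorem, the gradients with respect to a and x are obtained by differentiating this expression with f and g held fixed, which is what grad_th_a and grad_th_x spell out above.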
    b = b / np.sum(b)

    return a, x, b, y


# Init of measures and solvers
a, x, b, y = template_measure(500)
A, X, B, Y = (
    torch.from_numpy(a)[None, :],
    torch.from_numpy(x)[None, :, None],
    torch.from_numpy(b)[None, :],
    torch.from_numpy(y)[None, :, None],
)
blur = 1e-3
reach = np.array([10 ** x for x in np.linspace(-2, np.log10(0.5), 4)])
cost = euclidean_cost(2)
solver = BatchVanillaSinkhorn(
    nits=10000, nits_grad=2, tol=1e-8, assume_convergence=True
)
list_entropy = [KullbackLeibler(blur, reach[0]),
                TotalVariation(blur, reach[0])]

# Init of plot
blue = (0.55, 0.55, 0.95)
red = (0.95, 0.55, 0.55)

# Plotting transport marginals for each entropy
for i in range(len(list_entropy)):
    for j in range(len(reach)):
        fig = plt.figure(figsize=(8, 4))
        entropy = list_entropy[i]
Example #16
def test_sinkhorn_no_bug(entropy, solv):
    a, x = generate_measure(2, 5, 3)
    b, y = generate_measure(2, 6, 3)
    solv.sinkhorn_asym(a, x, b, y, cost=euclidean_cost(1), entropy=entropy)
    solv.sinkhorn_sym(a, x, cost=euclidean_cost(1), entropy=entropy, y_j=y)
        if isinstance(entropy, Balanced):
            a_i.data *= (-(2 * lr_a * m)).exp()
            a_i.data /= a_i.data.sum(1)
        else:
            a_i.data *= (-(2 * lr_a * m)).exp()
        print(f"At step {i} the total mass is {a_i.sum().item()}")
        fname = path + f"/unbalanced_flow_{func.__name__}_p{p}_" \
                       f"{entropy.__dict__}_lrx{lr_x}_lra{lr_a}_" \
                       f"steps{Nsteps}_frame{i}.eps"
        plt.savefig(fname, format='eps')
        plt.cla()


if __name__ == '__main__':
    setting = 0
    p, cost = 2, euclidean_cost(2)
    solver = BatchVanillaSinkhorn(nits=5000,
                                  nits_grad=15,
                                  tol=1e-8,
                                  assume_convergence=True)

    if setting == 0:  # Compare KL for various mass steps
        gradient_flow(sinkhorn_divergence,
                      entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver,
                      cost=cost,
                      p=p,
                      lr_x=60.,
                      lr_a=0.,
                      Nsteps=300)
        gradient_flow(sinkhorn_divergence,
Example #18
def test_divergence_zero(div, entropy, reach, p, m):
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    func = div(m * a, x, m * a, x, cost, entropy, solver=solver)
    assert torch.allclose(func, torch.Tensor([0.0]), rtol=1e-6)
        if isinstance(entropy, Balanced):
            a_i.data *= (- (2 * lr_a * m)).exp()
            a_i.data /= a_i.data.sum(1)
        else:
            a_i.data *= (- (2 * lr_a * m)).exp()
        print(f"At step {i} the total mass is {a_i.sum().item()}")
    plt.title(f"t = {i / (Nsteps - 1):1.2f}, elapsed time: "
              f"{(time.time() - t_0) / Nsteps} s/it",
              fontsize=30)
    plt.savefig(fname)
    plt.show()


if __name__ == '__main__':
    setting = 0
    p, cost = 2, euclidean_cost(2)
    solver = BatchVanillaSinkhorn(nits=5000, nits_grad=15, tol=1e-8,
                                  assume_convergence=True)

    if setting == 0:  # Compare KL for various mass steps
        gradient_flow(sinkhorn_divergence, entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver, cost=cost, p=p,
                      lr_x=60., lr_a=0., Nsteps=300)
        gradient_flow(sinkhorn_divergence, entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver, cost=cost, p=p,
                      lr_x=60., lr_a=0.3, Nsteps=300)
        gradient_flow(sinkhorn_divergence, entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver, cost=cost, p=p,
                      lr_x=60., lr_a=0.5, Nsteps=300)
        gradient_flow(sinkhorn_divergence, entropy=KullbackLeibler(1e-2, 0.3),
                      solver=solver, cost=cost, p=p,