import pytest

import torch

from common.sinkhorn import BatchVanillaSinkhorn, BatchScalingSinkhorn, BatchExpSinkhorn
from common.entropy import KullbackLeibler, Balanced, TotalVariation, Range, PowerEntropy
from common.utils import generate_measure, euclidean_cost

torch.set_default_tensor_type(torch.DoubleTensor)



@pytest.mark.parametrize('entropy', [Balanced(1e1), KullbackLeibler(1e1, 1e0), TotalVariation(1e1, 1e0),
                                     Range(1e1, 0.3, 2), PowerEntropy(1e1, 1e0, 0), PowerEntropy(1e1, 1e0, -1)])
@pytest.mark.parametrize('solv', [BatchVanillaSinkhorn(nits=10, nits_grad=10, tol=1e-5, assume_convergence=True),
                                  BatchVanillaSinkhorn(nits=10, nits_grad=10, tol=1e-5, assume_convergence=False),
                                  BatchScalingSinkhorn(budget=10, nits_grad=10, assume_convergence=True),
                                  BatchScalingSinkhorn(budget=10, nits_grad=10, assume_convergence=False),
                                  BatchExpSinkhorn(nits=10, nits_grad=10, tol=1e-5, assume_convergence=True),
                                  BatchExpSinkhorn(nits=10, nits_grad=10, tol=1e-5, assume_convergence=False)])
def test_sinkhorn_no_bug(entropy, solv):
    a, x = generate_measure(2, 5, 3)
    b, y = generate_measure(2, 6, 3)
    solv.sinkhorn_asym(a, x, b, y, cost=euclidean_cost(1), entropy=entropy)
    solv.sinkhorn_sym(a, x, cost=euclidean_cost(1), entropy=entropy, y_j=y)
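

# A hedged, standalone sketch (not a test; the leading underscore keeps pytest
# from collecting it): the solver API used directly, assuming sinkhorn_asym
# returns the pair of dual potentials (f, g) as in the gradient test further
# below. The solver and entropy parameters here are illustrative only.
def _example_potentials():
    entropy = KullbackLeibler(1e1, 1e0)
    solver = BatchVanillaSinkhorn(nits=100, nits_grad=10, tol=1e-9, assume_convergence=True)
    a, x = generate_measure(2, 5, 3)
    b, y = generate_measure(2, 6, 3)
    f, g = solver.sinkhorn_asym(a, x, b, y, cost=euclidean_cost(1), entropy=entropy)
    return f, g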


# TODO: Adapt the error function for TV due to translation invariance when masses are both 1
@pytest.mark.parametrize('p', [1, 1.5, 2])
@pytest.mark.parametrize('reach', [0.5, 1., 2.])
@pytest.mark.parametrize('m,n', [(1., 1.), (0.7, 2.), (0.5, 0.7), (1.5, 2.)])
Example #2
import pytest

import torch

from common.functional import sinkhorn_divergence, hausdorff_divergence
from common.sinkhorn import BatchVanillaSinkhorn
from common.entropy import KullbackLeibler, Balanced, TotalVariation, Range, PowerEntropy
from common.utils import generate_measure, convolution, scal, euclidean_cost

torch.set_printoptions(precision=10)
torch.set_default_tensor_type(torch.DoubleTensor)
solver = BatchVanillaSinkhorn(nits=5000,
                              nits_grad=5,
                              tol=1e-15,
                              assume_convergence=True)


@pytest.mark.parametrize('p', [1, 1.5, 2])
@pytest.mark.parametrize('reach', [0.5, 1., 2.])
@pytest.mark.parametrize('m', [1., 0.7, 2.])
@pytest.mark.parametrize('entropy', [
    KullbackLeibler(1e0, 1e0),
    Balanced(1e0),
    TotalVariation(1e0, 1e0),
    Range(1e0, 0.3, 2),
    PowerEntropy(1e0, 1e0, 0),
    PowerEntropy(1e0, 1e0, -1)
])
@pytest.mark.parametrize('div', [sinkhorn_divergence, hausdorff_divergence])
def test_divergence_zero(div, entropy, reach, p, m):
    entropy.reach = reach
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    func = div(m * a, x, m * a, x, cost, entropy, solver=solver)
    assert torch.allclose(func, torch.Tensor([0.0]), atol=1e-6)
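

# A hedged, standalone sketch (not a test) of the same zero-divergence property,
# plus an evaluation between two distinct measures. It reuses the module-level
# solver; output shapes follow the batched convention suggested by the assert above.
def _example_divergences():
    entropy = KullbackLeibler(1e0, 1e0)
    cost = euclidean_cost(2)
    a, x = generate_measure(1, 5, 2)
    b, y = generate_measure(1, 6, 2)
    self_div = sinkhorn_divergence(a, x, a, x, cost, entropy, solver=solver)   # expected close to zero
    cross_div = hausdorff_divergence(a, x, b, y, cost, entropy, solver=solver)
    return self_div, cross_div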


@pytest.mark.parametrize('p', [1, 1.5, 2])
Example #3
                      lr_a=0.3,
                      Nsteps=300)

    if setting == 2:  # Compute the KL dynamic for an almost-L1 metric
        gradient_flow(sinkhorn_divergence,
                      KullbackLeibler(1e-2, 0.3),
                      solver=solver,
                      cost=euclidean_cost(1.1),
                      p=1.1,
                      lr_x=10.,
                      lr_a=0.3,
                      Nsteps=300)

    if setting == 3:  # Compare Balanced OT with and without mass creation allowed
        gradient_flow(sinkhorn_divergence,
                      entropy=Balanced(1e-3),
                      solver=solver,
                      cost=cost,
                      p=p,
                      lr_x=60.,
                      lr_a=0.,
                      Nsteps=300)
        gradient_flow(sinkhorn_divergence,
                      entropy=Balanced(1e-3),
                      solver=solver,
                      cost=cost,
                      p=p,
                      lr_x=60.,
                      lr_a=0.3,
                      Nsteps=300)
Example #4
import numpy as np
import torch
import matplotlib.pyplot as plt
from common.entropy import Balanced, KullbackLeibler, TotalVariation, Range, PowerEntropy

x = torch.linspace(-5, 5, 200)
L_entropy = [Balanced(1e0), KullbackLeibler(1e0, 1e0), TotalVariation(1e0, 1e0), Range(1e0, 0.5, 2),
             PowerEntropy(1e0, 1e0, 0), PowerEntropy(1e0, 1e0, -1)]
L_name = ['Balanced', 'KL', 'TV', '$RG_{[0.5,2]}$', 'Berg', 'Hellinger']
for entropy, name in zip(L_entropy, L_name):
    aprox = entropy.aprox
    x_, y_ = x.data.numpy(), (-aprox(-x)).squeeze().data.numpy()
    plt.plot(x_, y_, label=name)

plt.xlabel('p', fontsize=16)
plt.ylabel('-aprox(-p)', fontsize=16)
plt.legend(fontsize=13)
plt.tight_layout()
plt.savefig('output/fig_aprox.eps', format='eps', transparent=True)
plt.show()
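
# A quick numerical sketch of the same operator without matplotlib: entropy.aprox
# is applied to a few dual values, exactly as in the plotting loop above; the
# printed numbers are only meant for eyeballing the curves, nothing is asserted.
p_vals = torch.linspace(-5, 5, 5)
for entropy, name in zip(L_entropy, L_name):
    print(name, (-entropy.aprox(-p_vals)).squeeze())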
Example #5
import pytest

import torch
from common.functional import regularized_ot, hausdorff_divergence, sinkhorn_divergence, energyDistance
from common.sinkhorn import BatchVanillaSinkhorn
from common.entropy import KullbackLeibler, Balanced, TotalVariation, Range, PowerEntropy
from common.utils import generate_measure, euclidean_cost

torch.set_default_tensor_type(torch.cuda.FloatTensor)
solver = BatchVanillaSinkhorn(nits=10, tol=0, assume_convergence=True)

@pytest.mark.parametrize('entropy', [KullbackLeibler(1e0, 1e0), Balanced(1e0), TotalVariation(1e0, 1e0),
                                     Range(1e0, 0.3, 2), PowerEntropy(1e0, 1e0, 0), PowerEntropy(1e0, 1e0, -1)])
def test_divergence_zero(entropy):
    a, x = generate_measure(1, 5, 2)
    a, x = a.float().cuda(), x.float().cuda()
    b, y = generate_measure(1, 6, 2)
    b, y = b.float().cuda(), y.float().cuda()
    sinkhorn_divergence(a, x, b, y, cost=euclidean_cost(2), entropy=entropy, solver=solver)
Example #6
    func = entropy.output_regularized(a, x, n * b, y, cost, f, g)
    # Numerical gradients from autograd, compared against the closed forms below.
    grad_num_x, grad_num_a = torch.autograd.grad(func, [x, a])
    # Closed-form gradient w.r.t. the weights a, written from the dual potentials.
    grad_th_a = - entropy.legendre_entropy(-f) + entropy.blur * n - entropy.blur * \
        ((f[:, :, None] + g[:, None, :] - dist_matrix(x, y, p)).exp() * n * b[:, None, :]).sum(dim=2)
    # Primal transport plan recovered from the potentials (the second marginal is n * b).
    pi = n * a[:, :, None] * b[:, None, :] * (
        (f[:, :, None] + g[:, None, :] - dist_matrix(x, y, p)) /
        entropy.blur).exp()
    # Closed-form gradient w.r.t. the positions x for the squared Euclidean cost.
    grad_th_x = 2 * x * pi.sum(dim=2)[:, :, None] - 2 * torch.einsum('ijk, ikl->ijl', pi, y)
    assert torch.allclose(grad_th_a, grad_num_a, rtol=1e-5)
    assert torch.allclose(grad_th_x, grad_num_x, rtol=1e-5)


@pytest.mark.parametrize('p', [2])
@pytest.mark.parametrize('m', [1., 0.7, 1.5])
@pytest.mark.parametrize('entropy', [Balanced(1e0)])
@pytest.mark.parametrize('div', [regularized_ot])
@pytest.mark.parametrize('solv', [
    BatchVanillaSinkhorn(
        nits=5000, nits_grad=1, tol=1e-14, assume_convergence=True),
    BatchExpSinkhorn(
        nits=5000, nits_grad=1, tol=1e-14, assume_convergence=True)
])
def test_gradient_balanced_weight_and_position_asym(solv, div, entropy, p, m):
    cost = euclidean_cost(p)
    a, x = generate_measure(1, 5, 2)
    b, y = generate_measure(1, 6, 2)
    a, b = m * a, m * b
    a.requires_grad = True
    x.requires_grad = True
    f, g = solv.sinkhorn_asym(a, x, b, y, cost, entropy)
Example #7
    b2 = b2 / np.sum(b2)

    y = np.concatenate((y1, y2))
    b = np.concatenate((0.45 * b1, 0.55 * b2))
    b = b / np.sum(b)

    return a, x, b, y

# Init of measures and solvers
a, x, b, y = template_measure(250)
A, X = torch.from_numpy(a)[None, :], torch.from_numpy(x)[None, :, None]
B, Y = torch.from_numpy(b)[None, :], torch.from_numpy(y)[None, :, None]
p, blur, reach = 2, 1e-3, 0.1
cost = euclidean_cost(p)
solver = BatchVanillaSinkhorn(nits=10000, nits_grad=1, tol=1e-5, assume_convergence=True)
list_entropy = [Balanced(blur), KullbackLeibler(blur, reach), TotalVariation(blur, reach), Range(blur, 0.7, 1.3),
                PowerEntropy(blur, reach, 0.)]

# Init of plot
blue = (.55, .55, .95)
red = (.95, .55, .55)
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(40, 12))
ax[0, 0].fill_between(x, 0, a, color='b')
ax[0, 0].fill_between(y, 0, b, color='r')
ax[0, 0].set_title('Input Marginals', fontsize=50)
ax[0, 0].set_yticklabels([])
ax[0, 0].set_xticklabels([])

# Plotting transport marginals for each entropy
k = 1
for entropy in list_entropy: