Ejemplo n.º 1
0
def test_barycenter_stabilized_vs_sinkhorn():
    """Stabilized and plain Sinkhorn must give the same unbalanced barycenter.

    A random 2-D point cloud provides the ground cost, and the two columns of
    the histogram matrix are scaled differently so their masses differ. The
    barycenter from ``method="sinkhorn_stabilized"`` (with ``tau=100``) is then
    compared with the one from ``method="sinkhorn"`` at ``atol=1e-05``.
    """
    n_points = 100
    generator = np.random.RandomState(42)

    points = generator.randn(n_points, 2)
    hists = generator.rand(n_points, 2)

    # scale the second column by 4 so the inputs are unbalanced
    hists = hists * np.array([1, 4])[None, :]
    cost = ot.dist(points, points)
    entropic_reg = 0.5
    marginal_reg = 10

    # log=True is requested in both calls; the log dicts themselves are unused
    bary_stable, _ = barycenter_unbalanced(
        hists, cost, reg=entropic_reg, reg_m=marginal_reg, log=True,
        tau=100, method="sinkhorn_stabilized",
    )
    bary_plain, _ = barycenter_unbalanced(
        hists, cost, reg=entropic_reg, reg_m=marginal_reg,
        method="sinkhorn", log=True,
    )

    np.testing.assert_allclose(bary_plain, bary_stable, atol=1e-05)
Ejemplo n.º 2
0
def test_unbalanced_barycenter(method):
    """Returned log-scalings must satisfy the unbalanced barycenter fixed point.

    ``method`` names the solver passed to ``barycenter_unbalanced``. It is
    presumably supplied by a pytest parametrization that is not visible here.

    With ``K = exp(-M / reg)`` and exponent ``reg_m / (reg_m + reg)``, the test
    recomputes one scaling update from the returned barycenter and log dict.
    It then asserts the result matches ``log["logu"]`` / ``log["logv"]`` at
    ``atol=1e-05``.
    """
    n_points = 100
    generator = np.random.RandomState(42)

    points = generator.randn(n_points, 2)
    hists = generator.rand(n_points, 2)

    # scale the second column by 2 so the inputs are unbalanced
    hists = hists * np.array([1, 2])[None, :]
    cost = ot.dist(points, points)
    eps = 1.
    rho = 1.

    bary, info = barycenter_unbalanced(hists, cost, reg=eps, reg_m=rho,
                                       method=method, log=True)

    # exponent applied to every scaling update: reg_m / (reg_m + reg)
    power = rho / (rho + eps)
    log_u = info["logu"]
    log_v = info["logv"]
    scaled_cost = cost[:, :, None] / eps

    # log(K^T u): reduce over the first (source) index
    log_Ktu = logsumexp(log_u[:, None, :] - scaled_cost, axis=0)
    # log(K v): reduce over the second (target) index
    log_Kv = logsumexp(log_v[None, :] - scaled_cost, axis=1)

    # 1e-16 keeps log() finite where an entry is exactly zero
    expected_logv = power * (np.log(bary + 1e-16)[:, None] - log_Ktu)
    expected_logu = power * (np.log(hists + 1e-16) - log_Kv)

    np.testing.assert_allclose(expected_logu, info["logu"], atol=1e-05)
    np.testing.assert_allclose(expected_logv, info["logv"], atol=1e-05)
Ejemplo n.º 3
0
def test_implemented_methods(nx):
    """Check how every unbalanced entry point dispatches on ``method``.

    ``ot.unbalanced.sinkhorn_unbalanced``, ``ot.unbalanced.sinkhorn_unbalanced2``
    and ``barycenter_unbalanced`` are each expected to:

    * run for every name in ``IMPLEMENTED_METHODS``;
    * emit a ``UserWarning`` matching ``'not implemented'`` for every name in
      ``TO_BE_IMPLEMENTED_METHODS``;
    * raise ``ValueError`` for every token in ``NOT_VALID_TOKENS``.

    Each (entry point, method) pair is checked inside its own
    ``pytest.warns`` / ``pytest.raises`` context. A single ``pytest.raises``
    around all calls exits at the first exception, so the remaining entry
    points would never run. A single ``pytest.warns`` passes as soon as any
    one call warns.

    ``nx`` provides ``from_numpy`` and is presumably the backend fixture;
    confirm against conftest.
    """
    IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
    TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
    NOT_VALID_TOKENS = ['foo']
    # tiny problem: only the method dispatch is under test here
    n = 3
    rng = np.random.RandomState(42)

    x = rng.randn(n, 2)
    a = ot.utils.unif(n)

    # make dists unbalanced
    b = ot.utils.unif(n) * 1.5
    A = rng.rand(n, 2)
    M = ot.dist(x, x)
    epsilon = 1.
    reg_m = 1.

    a, b, M, A = nx.from_numpy(a, b, M, A)

    # One callable per public entry point, parameterized only by the method
    # name, so each entry point can be checked in isolation below.
    solvers = [
        lambda method: ot.unbalanced.sinkhorn_unbalanced(
            a, b, M, epsilon, reg_m, method=method),
        lambda method: ot.unbalanced.sinkhorn_unbalanced2(
            a, b, M, epsilon, reg_m, method=method),
        lambda method: barycenter_unbalanced(
            A, M, reg=epsilon, reg_m=reg_m, method=method),
    ]

    for method in IMPLEMENTED_METHODS:
        for solve in solvers:
            solve(method)

    for method in TO_BE_IMPLEMENTED_METHODS:
        for solve in solvers:
            with pytest.warns(UserWarning, match='not implemented'):
                solve(method)

    for method in NOT_VALID_TOKENS:
        for solve in solvers:
            with pytest.raises(ValueError):
                solve(method)