Example #1
def test_sliced_backend(nx):

    n = 100
    rng = np.random.RandomState(0)

    x = rng.randn(n, 2)
    y = rng.randn(2 * n, 2)

    P = rng.randn(2, 20)
    P = P / np.sqrt((P**2).sum(0, keepdims=True))

    n_projections = 20

    xb, yb, Pb = nx.from_numpy(x, y, P)

    val0 = ot.sliced_wasserstein_distance(x, y, projections=P)

    val = ot.sliced_wasserstein_distance(xb,
                                         yb,
                                         n_projections=n_projections,
                                         seed=0)
    val2 = ot.sliced_wasserstein_distance(xb,
                                          yb,
                                          n_projections=n_projections,
                                          seed=0)

    assert val > 0
    assert val == val2

    valb = nx.to_numpy(ot.sliced_wasserstein_distance(xb, yb, projections=Pb))

    assert np.allclose(val0, valb)
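# Note: these snippets read like excerpts from POT's test suite, so they
# assume a shared preamble and, for the `nx` argument, a pytest fixture
# iterating over the available backends. A minimal sketch of those
# assumptions (illustrative, not POT's exact conftest.py):

import numpy as np
import pytest
import matplotlib.pylab as pl  # used by the plotting snippets below
import tensorflow as tf        # used by the TensorFlow backend snippet
import torch                   # used by the gradient-flow snippet

import ot


@pytest.fixture(params=ot.backend.get_backend_list())
def nx(request):
    # Yield one POT backend instance per parameterized test run.
    return request.param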
Example #2
def test_sliced_same_dist():
    n = 100
    rng = np.random.RandomState(0)

    x = rng.randn(n, 2)
    u = ot.utils.unif(n)

    res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
    np.testing.assert_almost_equal(res, 0.)
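# With identical inputs every 1-D projection has zero transport cost, so
# the average is zero. A rough sketch of the quantity being computed,
# assuming uniform weights and the default p=2 (an illustration, not
# POT's actual implementation):

def swd_manual(x, y, n_proj=10, seed=0):
    # Draw n_proj random unit directions, project both clouds onto each
    # one, average the squared 1-D Wasserstein costs and take the root.
    rng = np.random.RandomState(seed)
    P = rng.randn(x.shape[1], n_proj)
    P /= np.sqrt((P ** 2).sum(0, keepdims=True))  # unit-norm columns
    costs = [ot.emd2_1d(x @ P[:, k], y @ P[:, k]) for k in range(n_proj)]
    return np.sqrt(np.mean(costs))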
Example #3
def test_sliced_different_dists():
    n = 100
    rng = np.random.RandomState(0)

    x = rng.randn(n, 2)
    u = ot.utils.unif(n)
    y = rng.randn(n, 2)

    res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
    assert res > 0.
Example #4
def test_sliced_bad_shapes():
    n = 100
    rng = np.random.RandomState(0)

    x = rng.randn(n, 2)
    y = rng.randn(n, 4)
    u = ot.utils.unif(n)

    with pytest.raises(ValueError):
        _ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
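    # The ValueError above is a dimension mismatch: both clouds are
    # projected with the same (d, n_projections) matrix, so they must
    # share the ambient dimension d. A hypothetical fix matching y:
    x_ok = rng.randn(n, 4)
    val = ot.sliced_wasserstein_distance(x_ok, y, u, u, 10, seed=rng)
    assert val > 0.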
Example #5
def test_sliced_backend_device_tf():
    nx = ot.backend.TensorflowBackend()
    n = 100
    rng = np.random.RandomState(0)
    x = rng.randn(n, 2)
    y = rng.randn(2 * n, 2)
    P = rng.randn(2, 20)
    P = P / np.sqrt((P**2).sum(0, keepdims=True))

    # Check that everything stays on the CPU
    with tf.device("/CPU:0"):
        xb, yb, Pb = nx.from_numpy(x, y, P)
        valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
        nx.assert_same_dtype_device(xb, valb)

    if len(tf.config.list_physical_devices('GPU')) > 0:
        # Check that everything happens on the GPU
        xb, yb, Pb = nx.from_numpy(x, y, P)
        valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
        nx.assert_same_dtype_device(xb, valb)
        assert nx.dtype_device(valb)[1].startswith("GPU")
Example #6
def test_1d_sliced_equals_emd():
    n = 100
    m = 120
    rng = np.random.RandomState(0)

    x = rng.randn(n, 1)
    a = rng.uniform(0, 1, n)
    a /= a.sum()
    y = rng.randn(m, 1)
    u = ot.utils.unif(m)
    res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42)
    expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u)
    np.testing.assert_almost_equal(res**2, expected)
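    # In 1-D every unit projection is (up to sign) the identity, so each
    # projected cost equals the plain 1-D cost and the sliced distance
    # reduces to the ordinary one; emd2_1d's default sqeuclidean metric
    # returns the squared W2, hence the res**2. Assuming ot.wasserstein_1d
    # follows the same p-th-power convention, an equivalent check:
    w2_sq = ot.wasserstein_1d(x.squeeze(), y.squeeze(), a, u, p=2)
    np.testing.assert_almost_equal(res ** 2, w2_sq)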
Example #7
def test_sliced_backend_type_devices(nx):
    n = 100
    rng = np.random.RandomState(0)

    x = rng.randn(n, 2)
    y = rng.randn(2 * n, 2)

    P = rng.randn(2, 20)
    P = P / np.sqrt((P**2).sum(0, keepdims=True))

    for tp in nx.__type_list__:
        print(nx.dtype_device(tp))

        xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp)

        valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)

        nx.assert_same_dtype_device(xb, valb)
Example #8
def test_sliced_log():
    n = 100
    rng = np.random.RandomState(0)

    x = rng.randn(n, 4)
    y = rng.randn(n, 4)
    u = ot.utils.unif(n)

    res, log = ot.sliced_wasserstein_distance(x,
                                              y,
                                              u,
                                              u,
                                              10,
                                              seed=rng,
                                              log=True)
    assert len(log) == 2
    projections = log["projections"]
    projected_emds = log["projected_emds"]

    assert len(projections) == len(projected_emds) == 10
    for emd in projected_emds:
        assert emd > 0
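    # The log is enough to rebuild the returned value by hand; assuming
    # projected_emds holds the per-direction squared 1-D Wasserstein
    # costs (p=2 is POT's default), the following should hold:
    np.testing.assert_almost_equal(res, np.sqrt(np.mean(projected_emds)))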
Example #9
def empirical_dist(Xs, Xt):
    # Sliced Wasserstein distance between the empirical distributions
    # supported on the source samples Xs and the target samples Xt
    # (uniform weights, default number of projections and seed).
    res = ot.sliced_wasserstein_distance(Xs, Xt)

    return res
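# Hypothetical usage of the helper above on synthetic data (names and
# values are illustrative):
rng = np.random.RandomState(0)
Xs = rng.randn(200, 2)               # source cloud
Xt = rng.randn(200, 2) + [4.0, 0.0]  # shifted target cloud
print(empirical_dist(Xs, Xt))        # strictly positive for shifted clouds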
Example #10
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')

###################################################################################
# Compute Sliced Wasserstein distance for different seeds and number of projections
# ---------------------------------------------------------------------------------

n_seed = 50
n_projections_arr = np.logspace(0, 3, 25, dtype=int)
res = np.empty((n_seed, n_projections_arr.shape[0]))

# %% Compute statistics
for seed in range(n_seed):
    for i, n_projections in enumerate(n_projections_arr):
        res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed=seed)

res_mean = np.mean(res, axis=0)
res_std = np.std(res, axis=0)

###################################################################################
# Plot Sliced Wasserstein Distance
# --------------------------------

pl.figure(2)
pl.plot(n_projections_arr, res_mean, label="SWD")
pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)

pl.legend()
pl.xscale('log')
###################################################################################
# Minimize the Sliced Wasserstein distance with projected gradient descent (PyTorch)
# ----------------------------------------------------------------------------------

lr = 1e3
nb_iter_max = 100
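# The loop below uses x1, x1_torch and x2_torch without defining them.
# A minimal stand-in setup (hypothetical data, not the original
# example's), with x1_torch requiring gradients so the loss can be
# backpropagated into the point positions:
rng = np.random.RandomState(0)
x1 = rng.randn(100, 2) * 0.5          # cloud to be moved
x2 = rng.randn(100, 2) + [3.0, 3.0]   # fixed target cloud

x1_torch = torch.tensor(x1, requires_grad=True)
x2_torch = torch.tensor(x2)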

x_all = np.zeros((nb_iter_max, x1.shape[0], 2))

loss_iter = []

# generator for random permutations
gen = torch.Generator()
gen.manual_seed(42)

for i in range(nb_iter_max):

    loss = ot.sliced_wasserstein_distance(x1_torch,
                                          x2_torch,
                                          n_projections=20,
                                          seed=gen)

    loss_iter.append(loss.clone().detach().cpu().numpy())
    loss.backward()

    # performs a step of projected gradient descent
    with torch.no_grad():
        grad = x1_torch.grad
        x1_torch -= grad * lr / (1 + i / 5e1)  # step
        x1_torch.grad.zero_()
        x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy()

xb = x1_torch.clone().detach().cpu().numpy()

pl.figure(2, (8, 4))