def test_sliced_backend(nx):
    """Backend arrays: seeded runs are reproducible and fixed projections match NumPy."""
    gen = np.random.RandomState(0)
    n_samples = 100

    source = gen.randn(n_samples, 2)
    target = gen.randn(2 * n_samples, 2)

    # one projection direction per column, rescaled to unit length
    dirs = gen.randn(2, 20)
    dirs /= np.sqrt(np.sum(dirs ** 2, axis=0, keepdims=True))

    n_proj = 20

    source_b, target_b, dirs_b = nx.from_numpy(source, target, dirs)

    # reference value computed on plain NumPy inputs
    expected = ot.sliced_wasserstein_distance(source, target, projections=dirs)

    # the same integer seed must give the same random projections
    first = ot.sliced_wasserstein_distance(source_b, target_b, n_projections=n_proj, seed=0)
    second = ot.sliced_wasserstein_distance(source_b, target_b, n_projections=n_proj, seed=0)

    assert first > 0
    assert first == second

    got = nx.to_numpy(
        ot.sliced_wasserstein_distance(source_b, target_b, projections=dirs_b)
    )
    assert np.allclose(expected, got)
def test_sliced_same_dist():
    """A point cloud is at sliced distance zero from itself."""
    gen = np.random.RandomState(0)
    n_samples = 100

    samples = gen.randn(n_samples, 2)
    weights = ot.utils.unif(n_samples)

    dist = ot.sliced_wasserstein_distance(samples, samples, weights, weights, 10, seed=gen)
    np.testing.assert_almost_equal(dist, 0.)
def test_sliced_different_dists():
    """Two independently drawn point clouds have a strictly positive sliced distance."""
    gen = np.random.RandomState(0)
    n_samples = 100

    source = gen.randn(n_samples, 2)
    weights = ot.utils.unif(n_samples)
    target = gen.randn(n_samples, 2)

    dist = ot.sliced_wasserstein_distance(source, target, weights, weights, 10, seed=gen)
    assert dist > 0.
def test_sliced_bad_shapes():
    """Samples with a different number of features are rejected with ValueError."""
    gen = np.random.RandomState(0)
    n_samples = 100

    source = gen.randn(n_samples, 2)
    target = gen.randn(n_samples, 4)
    weights = ot.utils.unif(n_samples)

    with pytest.raises(ValueError):
        ot.sliced_wasserstein_distance(source, target, weights, weights, 10, seed=gen)
def test_sliced_backend_device_tf():
    """With the TensorFlow backend the result lives on the same device as the inputs."""
    backend = ot.backend.TensorflowBackend()
    gen = np.random.RandomState(0)
    n_samples = 100

    source = gen.randn(n_samples, 2)
    target = gen.randn(2 * n_samples, 2)
    dirs = gen.randn(2, 20)
    dirs = dirs / np.sqrt(np.sum(dirs ** 2, axis=0, keepdims=True))

    # Forced CPU placement: output must stay on the CPU with the inputs
    with tf.device("/CPU:0"):
        source_b, target_b, dirs_b = backend.from_numpy(source, target, dirs)
        dist = ot.sliced_wasserstein_distance(source_b, target_b, projections=dirs_b)
        backend.assert_same_dtype_device(source_b, dist)

    has_gpu = len(tf.config.list_physical_devices('GPU')) > 0
    if not has_gpu:
        return

    # Default placement with a GPU available: everything should run on the GPU
    source_b, target_b, dirs_b = backend.from_numpy(source, target, dirs)
    dist = ot.sliced_wasserstein_distance(source_b, target_b, projections=dirs_b)
    backend.assert_same_dtype_device(source_b, dist)
    assert backend.dtype_device(dist)[1].startswith("GPU")
def test_1d_sliced_equals_emd():
    """For 1D samples, the squared sliced distance coincides with ``ot.emd2_1d``."""
    gen = np.random.RandomState(0)
    n_source, n_target = 100, 120

    source = gen.randn(n_source, 1)
    # random (non-uniform) source weights, normalised to sum to one
    source_w = gen.uniform(0, 1, n_source)
    source_w /= source_w.sum()

    target = gen.randn(n_target, 1)
    target_w = ot.utils.unif(n_target)

    swd = ot.sliced_wasserstein_distance(source, target, source_w, target_w, 10, seed=42)
    exact = ot.emd2_1d(source.squeeze(), target.squeeze(), source_w, target_w)
    np.testing.assert_almost_equal(swd ** 2, exact)
def test_sliced_backend_type_devices(nx):
    """Result dtype/device follows the inputs for every type the backend lists."""
    gen = np.random.RandomState(0)
    n_samples = 100

    source = gen.randn(n_samples, 2)
    target = gen.randn(2 * n_samples, 2)
    dirs = gen.randn(2, 20)
    dirs /= np.sqrt(np.sum(dirs ** 2, axis=0, keepdims=True))

    for ref in nx.__type_list__:
        # printed so a failing type is identifiable in pytest's captured output
        print(nx.dtype_device(ref))

        source_b, target_b, dirs_b = nx.from_numpy(source, target, dirs, type_as=ref)
        dist = ot.sliced_wasserstein_distance(source_b, target_b, projections=dirs_b)
        nx.assert_same_dtype_device(source_b, dist)
def test_sliced_log():
    """With ``log=True`` the solver also returns its projections and per-projection EMDs.

    The returned ``projections`` are expected to share the layout of the
    ``projections=`` argument used elsewhere in this file, i.e. one unit
    direction per column: ``(dim, n_projections)``. ``len(projections)`` would
    therefore give the data dimension (4 here), not the number of projections,
    so the count is read from axis 1 instead.
    """
    n = 100
    rng = np.random.RandomState(0)

    x = rng.randn(n, 4)
    y = rng.randn(n, 4)
    u = ot.utils.unif(n)

    _, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True)
    assert len(log) == 2
    projections = log["projections"]
    projected_emds = log["projected_emds"]

    # assumes (dim, n_projections) layout: axis 1 indexes the 10 projections
    assert projections.shape[1] == len(projected_emds) == 10
    for emd in projected_emds:
        assert emd > 0
def empirical_dist(Xs, Xt, **kwargs):
    """Return the sliced Wasserstein distance between two empirical samples.

    Thin wrapper around ``ot.sliced_wasserstein_distance``. Earlier versions
    also built a normalised squared-Euclidean cost matrix that was never used;
    that O(n * m) work (and its 0/0 division when all points coincide) has been
    removed.

    Parameters
    ----------
    Xs : array-like
        Source samples, presumably one sample per row.
    Xt : array-like
        Target samples; must have the same number of columns as ``Xs``
        (``ot.sliced_wasserstein_distance`` raises ``ValueError`` otherwise).
    **kwargs
        Forwarded unchanged to ``ot.sliced_wasserstein_distance``, e.g.
        ``n_projections=`` or ``seed=`` for reproducible results. With no
        extra arguments the call is identical to the previous behaviour.

    Returns
    -------
    The value returned by ``ot.sliced_wasserstein_distance``.
    """
    return ot.sliced_wasserstein_distance(Xs, Xt, **kwargs)
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')

###############################################################################
# Compute Sliced Wasserstein distance for different seeds and numbers of projections
# ----------------------------------------------------------------------------------

n_seed = 50
# 25 log-spaced projection counts between 1 and 1000; the int cast truncates,
# so the smallest counts repeat (e.g. 1, 1, 1, 2, ...).
n_projections_arr = np.logspace(0, 3, 25, dtype=int)
# one row per seed, one column per projection count (width derived, not hard-coded)
res = np.empty((n_seed, len(n_projections_arr)))

# %% Compute statistics
for seed in range(n_seed):
    for i, n_projections in enumerate(n_projections_arr):
        res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed)

# statistics across seeds, for each projection count
res_mean = np.mean(res, axis=0)
res_std = np.std(res, axis=0)

###############################################################################
# Plot Sliced Wasserstein Distance
# --------------------------------

pl.figure(2)
pl.plot(n_projections_arr, res_mean, label="SWD")
# shaded band: mean +/- 2 standard deviations across seeds
pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)

pl.legend()
pl.xscale('log')
lr = 1e3
nb_iter_max = 100

# x_all[i] holds the source point cloud after gradient step i
x_all = np.zeros((nb_iter_max, x1.shape[0], 2))

loss_iter = []

# Seeded generator passed as ``seed`` so the random projection directions
# drawn at every iteration are reproducible from run to run.
gen = torch.Generator()
gen.manual_seed(42)

for i in range(nb_iter_max):
    loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen)

    loss_iter.append(loss.clone().detach().cpu().numpy())
    loss.backward()

    # Plain (unconstrained) gradient-descent step on the source points; the
    # step size decays as lr / (1 + i / 50). No projection is applied.
    with torch.no_grad():
        grad = x1_torch.grad
        x1_torch -= grad * lr / (1 + i / 5e1)  # step
        x1_torch.grad.zero_()  # reset accumulated gradient for the next iteration
        x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy()

xb = x1_torch.clone().detach().cpu().numpy()

pl.figure(2, (8, 4))