def test_scaling_match():
    """Check that four ways of applying the kernel ``K = exp(-M)`` agree.

    A random non-negative ``width x width`` image (single task) is scaled
    by the kernel in four ways:

    * a dense product with the full ground metric on the flattened image;
    * the same product computed in the log domain with ``utils.logsumexp``;
    * ``utils.klconv1d_list`` with the metric built from ``width`` alone;
    * the log-domain counterpart ``utils.kls`` of that convolution.

    All four results must match to ``assert_allclose`` default tolerances.
    """
    rng = np.random.RandomState(42)
    width = 16
    n_tasks = 1
    n_features = width ** 2
    coefs = rng.rand(width, width, n_tasks)
    coefs_flat = coefs.reshape(n_features, -1)

    # Full metric on the flattened grid: dense and log-domain products.
    M_full = utils.groundmetric2d(n_features, p=2, normed=False)
    scaling = np.exp(-M_full).dot(coefs_flat).flatten()
    log_terms = np.log(coefs_flat)[None, :] - M_full
    scaling_log = np.exp(utils.logsumexp(log_terms, axis=1))

    # Metric built from ``width`` only (presumably one image axis), as
    # consumed by the convolution helpers.
    M_axis = utils.groundmetric(width, p=2, normed=False)
    scaling_conv = utils.klconv1d_list(coefs, np.exp(-M_axis)).flatten()
    scaling_conv_log = np.exp(utils.kls(np.log(coefs), -M_axis))

    assert_allclose(scaling, scaling_log.flatten())
    assert_allclose(scaling, scaling_conv)
    assert_allclose(scaling_conv_log.flatten(), scaling_conv)
def test_barycenter_match():
    """Check that the four KL barycenter solvers agree.

    The same problem is solved with ``barycenterkl`` and
    ``barycenterkl_log``, which use the full flattened metric, and with
    ``barycenterkl_img`` and ``barycenterkl_img_log``, which use the metric
    built from ``width``.

    The test checks three things:

    * each log-domain variant matches its standard counterpart (the log
      variants return ``b`` in log space, hence ``np.exp``);
    * the flattened ``q`` matches the image ``q``;
    * the flattened ``ms`` matches the image ``ms``.

    One row of the input is zeroed so the solvers must handle zero mass.
    """
    rng = np.random.RandomState(42)
    width = 16
    n_tasks = 2
    n_features = width ** 2
    coefs_flat = rng.rand(n_features, n_tasks)
    coefs_flat[n_features // 2] = 0.

    M_full = utils.groundmetric2d(n_features, p=2, normed=False)
    m = np.median(M_full)
    epsilon = 5. / n_features
    # Both metrics are turned into ``-M / (median * epsilon)`` before use.
    scale = m * epsilon
    options = dict(P=coefs_flat, M=-M_full / scale, epsilon=epsilon,
                   gamma=1., maxiter=50, tol=0.)
    _, _, ms, b, q = otfunctions.barycenterkl(**options)
    _, _, msl, bl, ql = otfunctions.barycenterkl_log(**options)

    M_axis = utils.groundmetric(width, p=2, normed=False)
    options["M"] = -M_axis / scale
    _, _, msc, bc, qc = otfunctions.barycenterkl_img(**options)
    _, _, mscl, bcl, qcl = otfunctions.barycenterkl_img_log(**options)

    tols = dict(rtol=1e-5, atol=1e-5)
    # Flattened solver: standard vs log domain.
    assert_allclose(q, ql, **tols)
    assert_allclose(ms, msl, **tols)
    assert_allclose(b, np.exp(bl), **tols)
    # Image solver: standard vs log domain.
    assert_allclose(qc, qcl, **tols)
    assert_allclose(msc, mscl, **tols)
    assert_allclose(bc, np.exp(bcl), **tols)
    # Flattened vs image solver.
    assert_allclose(q, qc.reshape(n_features), **tols)
    assert_allclose(ms, msc.reshape(-1, n_features), **tols)
def test_mtw_convolution(positive, alpha, epsilon):
    """Check that MTW with convolutional OT matches MTW with plain Sinkhorn.

    Two MTW models are fit on the same simulated data:

    * the first gets the metric built from ``width`` alone, which makes the
      model compute OT barycenters with convolutions;
    * the second gets the full ``n_features x n_features`` metric, which
      uses standard Sinkhorn.

    Both metrics are divided by the median of the full metric, and all
    other settings are shared. Each fit must converge
    (``dloss < 1e-4``), and the fitted coefficients must agree to
    ``atol = rtol = 1e-5``.

    Parameters
    ----------
    positive : bool
        Forwarded to ``generate_dirac_images`` and to ``MTW``
        (presumably a bool, as in the parametrization).
    alpha : float
        Penalty weight; divided by ``n_samples`` before use.
    epsilon : float
        Entropic regularization; divided by ``n_features`` before use.
    """
    # Simulation params
    seed = 42
    width, n_tasks = 12, 2
    nnz = 2
    overlap = 0.
    denoising = False
    binary = False
    corr = 0.9
    snr = 4  # Gaussian noise with std = 1 / snr
    n_features = width ** 2
    n_samples = n_features // 2

    # OT params: both metrics are normalized by the median of the full one.
    epsilon = epsilon / n_features
    stable = False
    M_full = utils.groundmetric2d(n_features, p=2, normed=False)
    m = np.median(M_full)
    M_conv = utils.groundmetric(width, p=2, normed=False) / m

    # Generate coefs and X, Y data
    coefs = generate_dirac_images(width, n_tasks, nnz=nnz, seed=seed,
                                  overlap=overlap, binary=binary,
                                  positive=positive)
    coefs_flat = coefs.reshape(-1, n_tasks)
    X, Y = gaussian_design(n_samples, coefs_flat, corr=corr, sigma=1 / snr,
                           denoising=denoising, scaled=True, seed=seed)

    # Hyperparameters
    betamax = np.array([abs(x.T.dot(y)) for x, y in zip(X, Y)]).max()
    beta_fr = 0.5
    beta = beta_fr * betamax / n_samples
    alpha = alpha / n_samples
    gamma = utils.compute_gamma(0.8, M_conv)

    # Settings shared by both models; only the ground metric differs.
    mtw_params = dict(alpha=alpha, beta=beta, epsilon=epsilon, gamma=gamma,
                      stable=stable, tol_ot=1e-8, tol=1e-5, maxiter_ot=30,
                      maxiter=200, positive=positive, callback=True,
                      x_real=coefs_flat, verbose=True, rate=1,
                      prc_only=False)

    # Model 1: metric built from ``width`` -> barycenters via convolutions.
    mtw_model = MTW(M=M_conv, **mtw_params)
    mtw_model.fit(X, Y)
    assert mtw_model.log_['dloss'][-1] < 1e-4

    # Model 2: full metric -> standard Sinkhorn barycenters.
    mtw_model2 = MTW(M=M_full / m, **mtw_params)
    mtw_model2.fit(X, Y)
    assert mtw_model2.log_['dloss'][-1] < 1e-4

    assert_allclose(mtw_model.coefs_, mtw_model2.coefs_, atol=1e-5,
                    rtol=1e-5)

    # get positive / negative parts
    # NOTE(review): coefs1 / coefs2 are not checked here; presumably meant
    # for further assertions. Kept so get_unsigned is still exercised.
    coefs1, coefs2 = utils.get_unsigned(mtw_model.coefs_)
n_features = width ** 2
n_samples = n_features // 2

# Generate the true coefficients (Dirac images) and the X, Y data.
coefs = generate_dirac_images(width, n_tasks, nnz=nnz, positive=positive,
                              seed=seed, overlap=overlap)
coefs_flat = coefs.reshape(-1, n_tasks)
std = 0.25
X, Y = gaussian_design(n_samples, coefs_flat, corr=0.95, sigma=std,
                       scaled=True, seed=seed)

###############################################################################
# set ot params

# Entropic regularization scaled by the problem size, full normalized
# ground metric on the flattened grid, and the gamma derived from it.
epsilon = 2.5 / n_features
M = utils.groundmetric2d(n_features, p=2, normed=True)
gamma = utils.compute_gamma(0.8, M)

###############################################################################
# set hyperparameters and fit MTW

# beta is a fraction of the largest per-task correlation x^T y / n_samples.
betamax = np.array([x.T.dot(y) for x, y in zip(X, Y)]).max() / n_samples
alpha = 10. / n_samples
beta_fr = 0.35
beta = beta_fr * betamax
callback_options = dict(callback=True,
                        x_real=coefs.reshape(-1, n_tasks),
                        verbose=True,
                        rate=1)
print("Fitting MTW model...")