def test_algorithm(save_trace, algorithm, device):
    """Smoke-test a Sinkhorn variant: finite potentials, full trace when requested.

    ``save_trace`` may be True/False or the sentinel string ``'ref'``, which
    additionally builds a reference solution to trace against.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')

    x, la, y, lb, x_sampler, y_sampler = make_data('gmm_1d', 100)
    x, la, y, lb, x_sampler, y_sampler = (
        x.to(device), la.to(device), y.to(device),
        lb.to(device), x_sampler.to(device), y_sampler.to(device),
    )

    if save_trace == 'ref':
        # Build a reference (potentials evaluated on train/test points) first,
        # then trace against it.
        F, G = sinkhorn(x=x, la=la, y=y, lb=lb, n_iter=10)
        ref = {'train': (F(x), x, G(y), y), 'test': (F(x), x, G(y), y)}
        save_trace = True
    else:
        if save_trace and algorithm == 'random_sinkhorn':
            pytest.skip('Unsupported configuration')
        ref = None

    # Dispatch table: algorithm name -> implementation under test.
    dispatch = {
        'sinkhorn': sinkhorn,
        'subsampled_sinkhorn': subsampled_sinkhorn,
        'random_sinkhorn': random_sinkhorn,
    }
    result = dispatch[algorithm](
        x=x, la=la, y=y, lb=lb, save_trace=save_trace, ref=ref, n_iter=10)

    if save_trace:
        F, G, trace = result
        assert len(trace) == 10  # one trace entry per iteration
    else:
        F, G = result

    # Potentials must evaluate to finite values on the data.
    assert not np.isnan(F(x).sum().item())
    assert not np.isnan(G(y).sum().item())
def test_online_sinkhorn(save_trace, refit, force_full, use_finite, lr, batch_size, device):
    """Smoke-test online Sinkhorn across schedule/flag combinations.

    Checks that the returned potentials are finite and, when tracing, that the
    trace has one entry per iteration. ``lr`` / ``batch_size`` select constant
    vs growing schedules.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')
    if not use_finite and force_full:
        pytest.skip('Invalid configuration')

    x, la, y, lb, x_sampler, y_sampler = make_data('gmm_1d', 10)
    x, la, y, lb, x_sampler, y_sampler = (
        x.to(device), la.to(device), y.to(device),
        lb.to(device), x_sampler.to(device), y_sampler.to(device),
    )

    if save_trace:
        # Reference solution to compare the online trace against.
        F, G = sinkhorn(x=x, la=la, y=y, lb=lb, n_iter=10)
        ref = {'train': (F(x), x, G(y), y), 'test': (F(x), x, G(y), y)}
    else:
        ref = None

    # FIX: was `input = dict(...)`, shadowing the builtin `input`.
    if use_finite:
        data_kwargs = dict(x=x, la=la, y=y, lb=lb)
    else:
        data_kwargs = dict(x_sampler=x_sampler, y_sampler=y_sampler)

    n_iter = 15
    if batch_size == 'constant':
        batch_sizes = 10
    else:
        batch_sizes = [10 * (i + 1) for i in range(n_iter)]
    if lr == 'constant':
        lrs = 1
    else:
        lrs = [1 / np.sqrt(i + 1) for i in range(n_iter)]

    res = online_sinkhorn(**data_kwargs, save_trace=save_trace,
                          force_full=force_full, use_finite=use_finite,
                          ref=ref, n_iter=n_iter, batch_sizes=batch_sizes,
                          refit=refit, lrs=lrs)
    if save_trace:
        F, G, trace = res
        # FIX: assert against n_iter instead of a duplicated hard-coded 15.
        assert len(trace) == n_iter
    else:
        F, G = res

    # Potentials must evaluate to finite values on the data.
    assert not np.isnan(F(x).sum().item())
    assert not np.isnan(G(y).sum().item())
def test_precompute_C():
    """Sinkhorn potentials must not depend on whether the cost C is precomputed."""
    x, la, y, lb, _, _ = make_data('gmm_1d', 100)
    common = dict(x=x, la=la, y=y, lb=lb, n_iter=10)
    F_lazy, G_lazy = sinkhorn(**common, precompute_C=False)
    F_pre, G_pre = sinkhorn(**common, precompute_C=True)
    assert_allclose(F_pre(x), F_lazy(x))
    assert_allclose(G_pre(y), G_lazy(y))
def run(data_source, n_samples, epsilon, n_iter, device, method, max_calls,
        batch_exp, batch_size, lr, lr_exp, max_length, refit, _seed, _run):
    """Run one OT experiment configuration and save potentials plus trace.

    Builds the data and a reference solution, dispatches on ``method`` to the
    requested Sinkhorn variant, and dumps ``x, la, y, lb, F, G, trace`` to
    ``results.pkl`` under the sacred run's artifact directory.
    """
    if refit:
        max_length = 20000
    np.random.seed(_seed)
    torch.manual_seed(_seed)

    output_dir = join(exp.observers[0].dir, 'artifacts')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    x, la, y, lb, x_sampler, y_sampler = make_data(data_source, n_samples)
    x, y, la, lb = x.to(device), y.to(device), la.to(device), lb.to(device)
    x_sampler.to(device)
    y_sampler.to(device)

    if 'gaussian' in data_source:
        # Closed-form reference for Gaussian data, evaluated on fresh samples.
        F, G = sinkhorn_gaussian(x_sampler, y_sampler, epsilon=epsilon)
        xr, lar, yr, lbr, _, _ = make_data(data_source, n_samples)
        xr, yr, lar, lbr = xr.to(device), yr.to(device), lar.to(device), lbr.to(device)
        ref = {'test': (F(xr), xr, G(yr), yr)}
    else:
        # Numerical reference: one run on the training data, one refit on
        # fresh data for the test reference.
        F, G, trace = torch_cached(sinkhorn)(
            x, la, y, lb, n_iter=n_iter, epsilon=epsilon, save_trace=True,
            verbose=False, max_calls=max_calls * 4, count_recompute=True)
        xr, lar, yr, lbr, _, _ = make_data(data_source, n_samples)
        xr, yr, lar, lbr = xr.to(device), yr.to(device), lar.to(device), lbr.to(device)
        Fr, Gr, tracer = torch_cached(sinkhorn)(
            xr, lar, yr, lbr, n_iter=n_iter, epsilon=epsilon, save_trace=True,
            verbose=False, max_calls=max_calls * 4, count_recompute=True)
        # BUG FIX: the test reference previously mixed the refit potential Fr
        # with the train potential G (`G(yr)`); use the refit pair consistently.
        ref = {'train': (F(x), x, G(y), y), 'test': (Fr(xr), xr, Gr(yr), yr)}
        if max_calls is None:
            # Default the call budget to what the reference run consumed.
            # NOTE(review): `tracer` only exists in this branch — confirm that
            # gaussian runs always pass an explicit max_calls.
            max_calls = tracer[-1]['n_calls']

    if method == 'sinkhorn_precompute':
        F, G, trace = sinkhorn(x, la, y, lb, n_iter=n_iter, epsilon=epsilon,
                               save_trace=True, ref=ref, max_calls=max_calls)
    elif method == 'sinkhorn':
        F, G, trace = sinkhorn(x, la, y, lb, n_iter=n_iter, epsilon=epsilon,
                               save_trace=True, ref=ref, max_calls=max_calls,
                               count_recompute=True)
    elif method == 'subsampled':
        F, G, trace = subsampled_sinkhorn(x, la, y, lb, n_iter=n_iter,
                                          batch_size=batch_size,
                                          max_calls=max_calls, epsilon=epsilon,
                                          save_trace=True, ref=ref,
                                          count_recompute=True)
    elif method == 'random':
        F, G, trace = random_sinkhorn(x_sampler=x_sampler, y_sampler=y_sampler,
                                      n_iter=n_iter, epsilon=epsilon,
                                      save_trace=True, ref=ref, use_finite=False,
                                      batch_sizes=batch_size,
                                      max_calls=max_calls)
    elif method in ['online', 'online_on_finite', 'online_as_warmup']:
        if n_iter is None:
            n_iter = int(1e6)
        batch_sizes, lrs, lr_exp = schedule(batch_exp, batch_size, lr, lr_exp,
                                            max_length, n_iter, refit)
        print(f'Using lr_exp={lr_exp}')
        _run.info['lr_exp'] = lr_exp
        if method == 'online':
            F, G, trace = online_sinkhorn(
                x_sampler=x_sampler, y_sampler=y_sampler,
                batch_sizes=batch_sizes, refit=refit, force_full=False,
                lrs=lrs, n_iter=n_iter, use_finite=False,
                max_length=max_length, epsilon=epsilon, save_trace=True,
                ref=ref, max_calls=max_calls)
        elif method == 'online_on_finite':
            # NOTE(review): use_finite=False combined with force_full=True is
            # treated as an invalid configuration in test_online_sinkhorn —
            # confirm these flags are intended here.
            F, G, trace = online_sinkhorn(
                x=x, la=la, y=y, lb=lb, batch_sizes=batch_sizes, refit=refit,
                force_full=True, lrs=lrs, n_iter=n_iter, use_finite=False,
                max_length=max_length, max_calls=max_calls, epsilon=epsilon,
                save_trace=True, ref=ref)
        elif method == 'online_as_warmup':
            F, G, trace = online_sinkhorn(
                x=x, la=la, y=y, lb=lb, batch_sizes=batch_sizes, refit=refit,
                force_full=True, lrs=lrs, n_iter=n_iter, use_finite=True,
                max_length=max_length, max_calls=max_calls, epsilon=epsilon,
                save_trace=True, ref=ref)
        else:
            raise ValueError(f'Unknown method: {method}')
    else:
        raise ValueError(f'Unknown method: {method}')

    torch.save(dict(x=x, la=la, y=y, lb=lb, F=F, G=G, trace=trace),
               join(output_dir, 'results.pkl'))