def test_algorithm(save_trace, algorithm, device):
    """Run one sinkhorn variant end-to-end and check the potentials are NaN-free.

    When ``save_trace == 'ref'``, a plain sinkhorn run first builds reference
    potentials and the traced run is checked to have one entry per iteration.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')
    data = make_data('gmm_1d', 100)
    x, la, y, lb, x_sampler, y_sampler = (t.to(device) for t in data)
    if save_trace == 'ref':
        # Build the reference from a plain sinkhorn run on the same data.
        F, G = sinkhorn(x=x, la=la, y=y, lb=lb, n_iter=10)
        ref = {'train': (F(x), x, G(y), y), 'test': (F(x), x, G(y), y)}
        save_trace = True
    else:
        if save_trace and algorithm == 'random_sinkhorn':
            pytest.skip('Unsupported configuration')
        ref = None
    solver = {'sinkhorn': sinkhorn,
              'subsampled_sinkhorn': subsampled_sinkhorn,
              'random_sinkhorn': random_sinkhorn}[algorithm]
    res = solver(x=x, la=la, y=y, lb=lb, save_trace=save_trace, ref=ref, n_iter=10)
    if save_trace:
        F, G, trace = res
        # One trace entry per iteration.
        assert len(trace) == 10
    else:
        F, G = res
    # Potentials evaluated on the training measures must be finite.
    assert not np.isnan(F(x).sum().item())
    assert not np.isnan(G(y).sum().item())
def test_gmm_sampler(device):
    """A sampler moved to a device must yield samples and weights on that device."""
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')
    _, _, _, _, sampler, _ = make_data('gmm_1d', 100)
    sampler = sampler.to(device)
    x, la, xidx = sampler(10)
    assert x.device.type == device
    assert la.device.type == device
def test_gaussian_algorithm(device):
    """The closed-form gaussian sinkhorn must produce finite potentials."""
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')
    data = make_data('gaussian_2d', 100)
    x, la, y, lb, x_sampler, y_sampler = (t.to(device) for t in data)
    F, G = sinkhorn_gaussian(x_sampler=x_sampler, y_sampler=y_sampler)
    # Both potentials, evaluated on their respective supports, must be NaN-free.
    for potential, support in ((F, x), (G, y)):
        assert not np.isnan(potential(support).sum().item())
def test_online_sinkhorn(save_trace, refit, force_full, use_finite, lr, batch_size, device):
    """Exercise online_sinkhorn across scheduling, refitting and input configurations."""
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')
    if not use_finite and force_full:
        pytest.skip('Invalid configuration')
    data = make_data('gmm_1d', 10)
    x, la, y, lb, x_sampler, y_sampler = (t.to(device) for t in data)
    if save_trace:
        # Plain sinkhorn potentials serve as the trace reference.
        F, G = sinkhorn(x=x, la=la, y=y, lb=lb, n_iter=10)
        ref = {'train': (F(x), x, G(y), y), 'test': (F(x), x, G(y), y)}
    else:
        ref = None
    # Finite mode consumes fixed measures; otherwise samplers are drawn from.
    problem = (dict(x=x, la=la, y=y, lb=lb) if use_finite
               else dict(x_sampler=x_sampler, y_sampler=y_sampler))
    n_iter = 15
    # Either constant or linearly-growing batches / 1-over-sqrt learning rates.
    batch_sizes = 10 if batch_size == 'constant' else [10 * (i + 1) for i in range(n_iter)]
    lrs = 1 if lr == 'constant' else [1 / np.sqrt(i + 1) for i in range(n_iter)]
    res = online_sinkhorn(**problem, save_trace=save_trace, force_full=force_full,
                          use_finite=use_finite, ref=ref, n_iter=n_iter,
                          batch_sizes=batch_sizes, refit=refit, lrs=lrs)
    if save_trace:
        F, G, trace = res
        assert len(trace) == 15
    else:
        F, G = res
    assert not np.isnan(F(x).sum().item())
    assert not np.isnan(G(y).sum().item())
def test_precompute_C():
    """precompute_C is an optimization: it must not change the sinkhorn potentials."""
    x, la, y, lb, _, _ = make_data('gmm_1d', 100)
    # Run the same problem with the flag off, then on (same order as before).
    results = {flag: sinkhorn(x=x, la=la, y=y, lb=lb, n_iter=10, precompute_C=flag)
               for flag in (False, True)}
    F, G = results[False]
    Fp, Gp = results[True]
    assert_allclose(Fp(x), F(x))
    assert_allclose(Gp(y), G(y))
def run(data_source, n_samples, epsilon, n_iter, device, method, max_calls, batch_exp, batch_size, lr, lr_exp, max_length, refit, _seed, _run):
    """Run one OT solver on `data_source` and save potentials + trace as an artifact.

    Parameters mirror the experiment config: `method` selects the solver,
    `max_calls` bounds the computational budget, and the batch/lr parameters
    feed `schedule` for the online variants. `_seed` and `_run` are injected
    by the experiment framework.
    """
    if refit:
        max_length = 20000
    np.random.seed(_seed)
    torch.manual_seed(_seed)
    output_dir = join(exp.observers[0].dir, 'artifacts')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    x, la, y, lb, x_sampler, y_sampler = make_data(data_source, n_samples)
    x, y, la, lb = x.to(device), y.to(device), la.to(device), lb.to(device)
    x_sampler.to(device)
    y_sampler.to(device)
    if 'gaussian' in data_source:
        # Closed-form potentials are available: use them as an exact test reference.
        F, G = sinkhorn_gaussian(x_sampler, y_sampler, epsilon=epsilon)
        xr, lar, yr, lbr, _, _ = make_data(data_source, n_samples)
        xr, yr, lar, lbr = xr.to(device), yr.to(device), lar.to(device), lbr.to(device)
        ref = {'test': (F(xr), xr, G(yr), yr)}
    else:
        # Long (4x budget) sinkhorn runs on the train and on freshly resampled
        # test data serve as references. Guard the 4x multiplication against
        # max_calls=None, which is resolved from the reference trace below
        # (the original `max_calls * 4` raised TypeError for None).
        ref_calls = max_calls * 4 if max_calls is not None else None
        F, G, trace = torch_cached(sinkhorn)(x, la, y, lb, n_iter=n_iter, epsilon=epsilon,
                                             save_trace=True, verbose=False,
                                             max_calls=ref_calls, count_recompute=True)
        xr, lar, yr, lbr, _, _ = make_data(data_source, n_samples)
        xr, yr, lar, lbr = xr.to(device), yr.to(device), lar.to(device), lbr.to(device)
        Fr, Gr, tracer = torch_cached(sinkhorn)(xr, lar, yr, lbr, n_iter=n_iter, epsilon=epsilon,
                                                save_trace=True, verbose=False,
                                                max_calls=ref_calls, count_recompute=True)
        # BUG FIX: the test reference must use the test-run potential Gr, not the
        # train-run potential G (Fr was already used for the first slot).
        ref = {'train': (F(x), x, G(y), y), 'test': (Fr(xr), xr, Gr(yr), yr)}
        if max_calls is None:
            # Inherit the budget actually spent by the reference run.
            max_calls = tracer[-1]['n_calls']
    if method == 'sinkhorn_precompute':
        F, G, trace = sinkhorn(x, la, y, lb, n_iter=n_iter, epsilon=epsilon,
                               save_trace=True, ref=ref, max_calls=max_calls)
    elif method == 'sinkhorn':
        F, G, trace = sinkhorn(x, la, y, lb, n_iter=n_iter, epsilon=epsilon,
                               save_trace=True, ref=ref, max_calls=max_calls,
                               count_recompute=True)
    elif method == 'subsampled':
        F, G, trace = subsampled_sinkhorn(x, la, y, lb, n_iter=n_iter, batch_size=batch_size,
                                          max_calls=max_calls, epsilon=epsilon,
                                          save_trace=True, ref=ref, count_recompute=True)
    elif method == 'random':
        F, G, trace = random_sinkhorn(x_sampler=x_sampler, y_sampler=y_sampler,
                                      n_iter=n_iter, epsilon=epsilon, save_trace=True,
                                      ref=ref, use_finite=False, batch_sizes=batch_size,
                                      max_calls=max_calls)
    elif method in ['online', 'online_on_finite', 'online_as_warmup']:
        if n_iter is None:
            n_iter = int(1e6)
        batch_sizes, lrs, lr_exp = schedule(batch_exp, batch_size, lr, lr_exp,
                                            max_length, n_iter, refit)
        print(f'Using lr_exp={lr_exp}')
        _run.info['lr_exp'] = lr_exp
        if method == 'online':
            F, G, trace = online_sinkhorn(x_sampler=x_sampler, y_sampler=y_sampler,
                                          batch_sizes=batch_sizes, refit=refit,
                                          force_full=False, lrs=lrs, n_iter=n_iter,
                                          use_finite=False, max_length=max_length,
                                          epsilon=epsilon, save_trace=True, ref=ref,
                                          max_calls=max_calls)
        elif method == 'online_on_finite':
            # NOTE(review): force_full=True together with use_finite=False is the
            # combination the unit test skips as invalid — confirm this is intended.
            F, G, trace = online_sinkhorn(x=x, la=la, y=y, lb=lb,
                                          batch_sizes=batch_sizes, refit=refit,
                                          force_full=True, lrs=lrs, n_iter=n_iter,
                                          use_finite=False, max_length=max_length,
                                          max_calls=max_calls, epsilon=epsilon,
                                          save_trace=True, ref=ref)
        elif method == 'online_as_warmup':
            F, G, trace = online_sinkhorn(x=x, la=la, y=y, lb=lb,
                                          batch_sizes=batch_sizes, refit=refit,
                                          force_full=True, lrs=lrs, n_iter=n_iter,
                                          use_finite=True, max_length=max_length,
                                          max_calls=max_calls, epsilon=epsilon,
                                          save_trace=True, ref=ref)
        else:
            raise ValueError
    else:
        raise ValueError
    torch.save(dict(x=x, la=la, y=y, lb=lb, F=F, G=G, trace=trace),
               join(output_dir, 'results.pkl'))
def run(data_source, n_samples, epsilon, n_iter, device, method, max_calls, compare_with_ref, use_test, n_eval, precompute_C, force_full, batch_exp, batch_size, lr, lr_exp, max_length, refit, _seed, _run):
    """Benchmark one OT solver, optionally against reference potentials, and save results.

    `method` selects the solver; `compare_with_ref` controls whether a long
    reference run is computed; `use_test` adds a held-out evaluation set.
    `_seed` and `_run` are injected by the experiment framework.
    """
    np.random.seed(_seed)
    torch.manual_seed(_seed)
    output_dir = join(exp.observers[0].dir, 'artifacts')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    x, la, y, lb, x_sampler, y_sampler = make_data(data_source, n_samples)
    x, y, la, lb = x.to(device), y.to(device), la.to(device), lb.to(device)
    x_sampler.to(device)
    y_sampler.to(device)
    if n_iter is None:
        n_iter = int(1e4)
    if 'dragon' in data_source:
        use_test = False  # A held-out set is not useful for this dataset
    if use_test:
        # Fresh resample of the same distribution as held-out evaluation data.
        xr, lar, yr, lbr, _, _ = make_data(data_source, n_samples)
        xr, yr, lar, lbr = xr.to(device), yr.to(device), lar.to(device), lbr.to(device)
    if compare_with_ref:
        # Reference potentials from a run with 4x the evaluated budget.
        ref_calls = max_calls * 4
        F, G = torch_cached(sinkhorn)(x, la, y, lb, n_iter=n_iter, epsilon=epsilon,
                                      save_trace=False, verbose=False,
                                      max_calls=ref_calls, count_recompute=True)
        ref = {'train': (F(x), x, G(y), y)}
        if use_test:
            if 'gaussian' in data_source:
                Fr, Gr = sinkhorn_gaussian(x_sampler, y_sampler, epsilon=epsilon)  # Exact closed-form reference
            else:
                Fr, Gr = torch_cached(sinkhorn)(xr, lar, yr, lbr, n_iter=n_iter,
                                                epsilon=epsilon, save_trace=False,
                                                verbose=False, max_calls=ref_calls,
                                                count_recompute=True)
        if use_test:
            ref['test'] = (Fr(xr), xr, Gr(yr), yr)
    else:
        # No reference potentials: trace only records the evaluation points.
        ref = {'train': (None, x, None, y)}
        if use_test:
            ref['test'] = (None, xr, None, yr)
    if method == 'sinkhorn':
        n_iter = min(n_iter, int(2e3))  # Cap iterations to keep this branch fast
        # NOTE(review): the 'sinkhorn' method dispatches to subsampled_sinkhorn —
        # confirm this is intentional (full sinkhorn may be a special case of it).
        F, G, trace = subsampled_sinkhorn(x, la, y, lb, n_iter=n_iter,
                                          batch_size=batch_size, max_calls=max_calls,
                                          precompute_C=precompute_C,
                                          trace_every=max_calls // n_eval,
                                          epsilon=epsilon, save_trace=True, ref=ref)
    elif method == 'random':
        F, G, trace = random_sinkhorn(x_sampler=x_sampler, y_sampler=y_sampler,
                                      n_iter=n_iter, epsilon=epsilon, save_trace=True,
                                      ref=ref, use_finite=False, batch_sizes=batch_size,
                                      trace_every=max_calls // n_eval,
                                      max_calls=max_calls)
    elif method == 'online':
        batch_sizes, lrs, lr_exp = schedule(batch_exp, batch_size, lr, lr_exp,
                                            max_length, n_iter, refit)
        print(f'Using lr_exp={lr_exp}')
        _run.info['lr_exp'] = lr_exp
        # NOTE(review): this inner check is always true here (the enclosing
        # elif already established method == 'online'); kept for safety.
        if method == 'online':
            F, G, trace = online_sinkhorn(x=x, la=la, y=y, lb=lb,
                                          x_sampler=x_sampler, y_sampler=y_sampler,
                                          batch_sizes=batch_sizes, refit=refit,
                                          force_full=force_full,
                                          precompute_C=precompute_C,
                                          trace_every=max_calls // n_eval,
                                          lrs=lrs, n_iter=n_iter,
                                          use_finite=force_full,
                                          max_length=max_length, epsilon=epsilon,
                                          save_trace=True, ref=ref,
                                          max_calls=max_calls)
        else:
            raise ValueError
    else:
        raise ValueError
    torch.save(dict(x=x, la=la, y=y, lb=lb, F=F, G=G, trace=trace),
               join(output_dir, 'results.pkl'))