コード例 #1
0
def test_algorithm(save_trace, algorithm, device):
    """Smoke-test one discrete Sinkhorn variant on a 100-point 1-D GMM problem.

    ``save_trace`` may be ``'ref'``. In that case reference potentials are first
    computed with plain ``sinkhorn`` and passed to the solver, and tracing is
    switched on. A truthy ``save_trace`` combined with ``'random_sinkhorn'``
    (without a reference) is skipped as unsupported.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')
    n_iter = 10
    x, la, y, lb, x_sampler, y_sampler = (
        item.to(device) for item in make_data('gmm_1d', 100))

    if save_trace == 'ref':
        F_ref, G_ref = sinkhorn(x=x, la=la, y=y, lb=lb, n_iter=n_iter)
        # The same finite sample serves as both the train and the test reference.
        ref = {split: (F_ref(x), x, G_ref(y), y) for split in ('train', 'test')}
        save_trace = True
    elif save_trace and algorithm == 'random_sinkhorn':
        pytest.skip('Unsupported configuration')
    else:
        ref = None

    solver = {
        'sinkhorn': sinkhorn,
        'subsampled_sinkhorn': subsampled_sinkhorn,
        'random_sinkhorn': random_sinkhorn,
    }[algorithm]
    outcome = solver(x=x, la=la, y=y, lb=lb, save_trace=save_trace, ref=ref, n_iter=n_iter)

    if save_trace:
        F, G, trace = outcome
        # One trace entry per iteration.
        assert len(trace) == n_iter
    else:
        F, G = outcome
    for potential, points in ((F, x), (G, y)):
        assert not np.isnan(potential(points).sum().item())
コード例 #2
0
def test_gmm_sampler(device):
    """Check that samples and weights drawn from a moved GMM sampler live on ``device``."""
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')
    _, _, _, _, sampler, _ = make_data('gmm_1d', 100)
    sampler = sampler.to(device)
    # The sampler returns (points, weights, indices); only the first two are checked.
    samples, weights, _ = sampler(10)
    for tensor in (samples, weights):
        assert tensor.device.type == device
コード例 #3
0
def test_gaussian_algorithm(device):
    """Smoke-test ``sinkhorn_gaussian`` on a 2-D Gaussian problem.

    The potentials are built from the samplers alone. The test checks that they
    evaluate to non-NaN sums on the drawn points.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')
    x, _, y, _, x_sampler, y_sampler = (
        item.to(device) for item in make_data('gaussian_2d', 100))
    F, G = sinkhorn_gaussian(x_sampler=x_sampler, y_sampler=y_sampler)
    for potential, points in ((F, x), (G, y)):
        assert not np.isnan(potential(points).sum().item())
コード例 #4
0
def test_online_sinkhorn(save_trace, refit, force_full, use_finite, lr,
                         batch_size, device):
    """Smoke-test ``online_sinkhorn`` over its data-access and schedule options.

    ``use_finite`` selects whether the solver gets the finite sample
    (``x, la, y, lb``) or the samplers. ``force_full`` without ``use_finite``
    is skipped as invalid. ``batch_size`` / ``lr`` equal to ``'constant'``
    give scalar schedules. Any other value gives one entry per iteration.
    When ``save_trace`` is set, plain ``sinkhorn`` potentials are passed as
    the reference, and the trace must have one entry per iteration.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')
    if not use_finite and force_full:
        pytest.skip('Invalid configuration')
    n_iter = 15
    x, la, y, lb, x_sampler, y_sampler = make_data('gmm_1d', 10)
    x, la, y, lb, x_sampler, y_sampler = (x.to(device), la.to(device),
                                          y.to(device), lb.to(device),
                                          x_sampler.to(device),
                                          y_sampler.to(device))
    if save_trace:
        # Reference potentials from plain Sinkhorn on the same finite sample.
        F, G = sinkhorn(x=x, la=la, y=y, lb=lb, n_iter=10)
        ref = {'train': (F(x), x, G(y), y), 'test': (F(x), x, G(y), y)}
    else:
        ref = None
    # Named data_kwargs (not `input`) to avoid shadowing the builtin.
    if use_finite:
        data_kwargs = dict(x=x, la=la, y=y, lb=lb)
    else:
        data_kwargs = dict(x_sampler=x_sampler, y_sampler=y_sampler)

    # Scalar schedule, or one value per iteration (growing batches / 1/sqrt(t) steps).
    if batch_size == 'constant':
        batch_sizes = 10
    else:
        batch_sizes = [10 * (i + 1) for i in range(n_iter)]
    if lr == 'constant':
        lrs = 1
    else:
        lrs = [1 / np.sqrt(i + 1) for i in range(n_iter)]

    res = online_sinkhorn(**data_kwargs,
                          save_trace=save_trace,
                          force_full=force_full,
                          use_finite=use_finite,
                          ref=ref,
                          n_iter=n_iter,
                          batch_sizes=batch_sizes,
                          refit=refit,
                          lrs=lrs)
    if save_trace:
        F, G, trace = res
        # Tied to n_iter, so changing the iteration count keeps the check valid.
        assert len(trace) == n_iter
    else:
        F, G = res
    assert not np.isnan(F(x).sum().item())
    assert not np.isnan(G(y).sum().item())
コード例 #5
0
def test_precompute_C():
    """Check that precomputing the cost matrix does not change Sinkhorn's potentials."""
    x, la, y, lb, _, _ = make_data('gmm_1d', 100)
    shared = dict(x=x, la=la, y=y, lb=lb, n_iter=10)
    F, G = sinkhorn(**shared, precompute_C=False)
    Fp, Gp = sinkhorn(**shared, precompute_C=True)
    assert_allclose(Fp(x), F(x))
    assert_allclose(Gp(y), G(y))
コード例 #6
0
ファイル: online.py プロジェクト: arthurmensch/onlikorn
def run(data_source, n_samples, epsilon, n_iter, device, method, max_calls,
        batch_exp, batch_size, lr, lr_exp, max_length, refit, _seed, _run):
    """Run one optimal-transport solver on ``data_source`` and save the results.

    Steps:

    1. Seed numpy and torch with ``_seed``.
    2. Draw a train set with ``make_data`` and move it to ``device``.
    3. Build reference potentials.
    4. Run the solver chosen by ``method`` against those references.
    5. ``torch.save`` the data, potentials and trace to
       ``<first observer dir>/artifacts/results.pkl``.

    References (``ref``):

    * ``'gaussian'`` sources: closed-form ``sinkhorn_gaussian`` potentials,
      evaluated on a fresh draw (``ref['test']`` only).
    * other sources: cached discrete ``sinkhorn`` runs on the train draw
      (``ref['train']``) and on a fresh draw (``ref['test']``).

    ``max_calls`` may be ``None``. For non-Gaussian sources it is then
    replaced by the ``'n_calls'`` entry of the last trace record of the test
    reference run. ``refit`` forces ``max_length = 20000``. ``_seed`` and
    ``_run`` look like values injected by the experiment framework
    (``_run.info`` is written to) -- NOTE(review): confirm.

    Raises:
        ValueError: if ``method`` is not a supported solver name.
    """
    if refit:
        # Refitting runs always use this fixed max_length, overriding the config value.
        max_length = 20000
    np.random.seed(_seed)
    torch.manual_seed(_seed)
    output_dir = join(exp.observers[0].dir, 'artifacts')
    os.makedirs(output_dir, exist_ok=True)

    x, la, y, lb, x_sampler, y_sampler = make_data(data_source, n_samples)

    x, y, la, lb = x.to(device), y.to(device), la.to(device), lb.to(device)
    # Reassign the result: correct whether `.to` moves in place and returns
    # self, or returns a new sampler.
    x_sampler = x_sampler.to(device)
    y_sampler = y_sampler.to(device)

    if 'gaussian' in data_source:
        F, G = sinkhorn_gaussian(x_sampler, y_sampler, epsilon=epsilon)
        xr, lar, yr, lbr, _, _ = make_data(data_source, n_samples)
        xr, yr, lar, lbr = xr.to(device), yr.to(device), lar.to(device), lbr.to(device)
        ref = {'test': (F(xr), xr, G(yr), yr)}
    else:
        # Reference runs get 4x the solver budget. A missing budget (None) is
        # forwarded as-is instead of crashing on `None * 4`, so the
        # `max_calls is None` fallback below can actually be reached.
        ref_calls = None if max_calls is None else max_calls * 4
        F, G, _ = torch_cached(sinkhorn)(x, la, y, lb, n_iter=n_iter, epsilon=epsilon, save_trace=True,
                                         verbose=False, max_calls=ref_calls,
                                         count_recompute=True)
        xr, lar, yr, lbr, _, _ = make_data(data_source, n_samples)
        xr, yr, lar, lbr = xr.to(device), yr.to(device), lar.to(device), lbr.to(device)
        Fr, Gr, tracer = torch_cached(sinkhorn)(xr, lar, yr, lbr, n_iter=n_iter, epsilon=epsilon, save_trace=True,
                                                verbose=False, max_calls=ref_calls,
                                                count_recompute=True)
        # The test reference must use the test-set potentials for both sides
        # (Gr, not the train-set G).
        ref = {'train': (F(x), x, G(y), y), 'test': (Fr(xr), xr, Gr(yr), yr)}
        if max_calls is None:
            max_calls = tracer[-1]['n_calls']

    if method == 'sinkhorn_precompute':
        F, G, trace = sinkhorn(x, la, y, lb, n_iter=n_iter, epsilon=epsilon, save_trace=True, ref=ref,
                               max_calls=max_calls)
    elif method == 'sinkhorn':
        F, G, trace = sinkhorn(x, la, y, lb, n_iter=n_iter, epsilon=epsilon, save_trace=True, ref=ref,
                               max_calls=max_calls,
                               count_recompute=True)
    elif method == 'subsampled':
        F, G, trace = subsampled_sinkhorn(x, la, y, lb, n_iter=n_iter, batch_size=batch_size,
                                          max_calls=max_calls,
                                          epsilon=epsilon, save_trace=True, ref=ref, count_recompute=True)
    elif method == 'random':
        F, G, trace = random_sinkhorn(x_sampler=x_sampler, y_sampler=y_sampler, n_iter=n_iter,
                                      epsilon=epsilon, save_trace=True, ref=ref, use_finite=False,
                                      batch_sizes=batch_size,
                                      max_calls=max_calls)
    elif method in ('online', 'online_on_finite', 'online_as_warmup'):
        if n_iter is None:
            # No iteration cap given: effectively run until max_calls is hit.
            n_iter = int(1e6)
        batch_sizes, lrs, lr_exp = schedule(batch_exp, batch_size, lr, lr_exp, max_length, n_iter, refit)
        print(f'Using lr_exp={lr_exp}')
        _run.info['lr_exp'] = lr_exp
        if method == 'online':
            F, G, trace = online_sinkhorn(x_sampler=x_sampler, y_sampler=y_sampler, batch_sizes=batch_sizes,
                                          refit=refit, force_full=False,
                                          lrs=lrs, n_iter=n_iter, use_finite=False, max_length=max_length,
                                          epsilon=epsilon, save_trace=True, ref=ref, max_calls=max_calls)
        elif method == 'online_on_finite':
            # NOTE(review): this passes the finite sample with use_finite=False
            # and force_full=True; test_online_sinkhorn skips
            # `not use_finite and force_full` as 'Invalid configuration'.
            # Confirm this combination is intended.
            F, G, trace = online_sinkhorn(x=x, la=la, y=y, lb=lb, batch_sizes=batch_sizes,
                                          refit=refit, force_full=True,
                                          lrs=lrs, n_iter=n_iter, use_finite=False, max_length=max_length,
                                          max_calls=max_calls,
                                          epsilon=epsilon, save_trace=True, ref=ref)
        else:  # 'online_as_warmup'
            F, G, trace = online_sinkhorn(x=x, la=la, y=y, lb=lb, batch_sizes=batch_sizes,
                                          refit=refit, force_full=True,
                                          lrs=lrs, n_iter=n_iter, use_finite=True, max_length=max_length,
                                          max_calls=max_calls,
                                          epsilon=epsilon, save_trace=True, ref=ref)
    else:
        raise ValueError(f'Unknown method {method!r}')

    torch.save(dict(x=x, la=la, y=y, lb=lb, F=F, G=G, trace=trace), join(output_dir, 'results.pkl'))
コード例 #7
0
ファイル: online.py プロジェクト: hedgefair/online_sinkhorn
def run(data_source, n_samples, epsilon, n_iter, device, method, max_calls, compare_with_ref, use_test,
        n_eval, precompute_C, force_full, batch_exp, batch_size, lr, lr_exp, max_length, refit, _seed, _run):
    """Run one optimal-transport solver on ``data_source`` and save the results.

    Steps:

    1. Seed numpy and torch with ``_seed``.
    2. Draw a train set and, when ``use_test`` is set, a fresh test draw.
    3. Build the ``ref`` dict passed to the solver.
    4. Run the solver chosen by ``method``.
    5. ``torch.save`` the data, potentials and trace to
       ``<first observer dir>/artifacts/results.pkl``.

    ``ref`` contents:

    * with ``compare_with_ref``: reference potentials from cached ``sinkhorn``
      runs with a ``4 * max_calls`` budget. For a ``'gaussian'`` source the
      test reference is the closed-form ``sinkhorn_gaussian`` solution.
    * without ``compare_with_ref``: the potentials are ``None``; only the
      points are provided.

    ``use_test`` is forced off for ``'dragon'`` sources. ``n_iter`` defaults
    to ``10_000`` when ``None``. The solver is traced every
    ``max_calls // n_eval`` calls, so ``max_calls`` must be an int here.
    ``_seed`` and ``_run`` look like values injected by the experiment
    framework (``_run.info`` is written to) -- NOTE(review): confirm.

    Raises:
        ValueError: if ``method`` is not ``'sinkhorn'``, ``'random'`` or ``'online'``.
    """
    np.random.seed(_seed)
    torch.manual_seed(_seed)
    output_dir = join(exp.observers[0].dir, 'artifacts')
    os.makedirs(output_dir, exist_ok=True)

    x, la, y, lb, x_sampler, y_sampler = make_data(data_source, n_samples)

    x, y, la, lb = x.to(device), y.to(device), la.to(device), lb.to(device)
    # Reassign the result: correct whether `.to` moves in place and returns
    # self, or returns a new sampler.
    x_sampler = x_sampler.to(device)
    y_sampler = y_sampler.to(device)

    if n_iter is None:
        n_iter = int(1e4)

    if 'dragon' in data_source:
        use_test = False  # A separate test draw is considered useless for this source.

    if use_test:
        xr, lar, yr, lbr, _, _ = make_data(data_source, n_samples)
        xr, yr, lar, lbr = xr.to(device), yr.to(device), lar.to(device), lbr.to(device)

    if compare_with_ref:
        # Reference runs get 4x the solver's call budget.
        ref_calls = max_calls * 4
        F, G = torch_cached(sinkhorn)(x, la, y, lb, n_iter=n_iter, epsilon=epsilon, save_trace=False,
                                      verbose=False, max_calls=ref_calls,
                                      count_recompute=True)
        ref = {'train': (F(x), x, G(y), y)}
        if use_test:
            if 'gaussian' in data_source:
                Fr, Gr = sinkhorn_gaussian(x_sampler, y_sampler, epsilon=epsilon)  # Exact reference
            else:
                Fr, Gr = torch_cached(sinkhorn)(xr, lar, yr, lbr, n_iter=n_iter, epsilon=epsilon, save_trace=False,
                                                verbose=False, max_calls=ref_calls,
                                                count_recompute=True)
            ref['test'] = (Fr(xr), xr, Gr(yr), yr)
    else:
        # No reference potentials: only the evaluation points are provided.
        ref = {'train': (None, x, None, y)}
        if use_test:
            ref['test'] = (None, xr, None, yr)

    if method == 'sinkhorn':
        n_iter = min(n_iter, int(2e3))  # Cap iterations to keep this baseline fast.
        F, G, trace = subsampled_sinkhorn(x, la, y, lb, n_iter=n_iter, batch_size=batch_size,
                                          max_calls=max_calls, precompute_C=precompute_C,
                                          trace_every=max_calls // n_eval,
                                          epsilon=epsilon, save_trace=True, ref=ref)
    elif method == 'random':
        F, G, trace = random_sinkhorn(x_sampler=x_sampler, y_sampler=y_sampler, n_iter=n_iter,
                                      epsilon=epsilon, save_trace=True, ref=ref, use_finite=False,
                                      batch_sizes=batch_size,
                                      trace_every=max_calls // n_eval,
                                      max_calls=max_calls)
    elif method == 'online':
        batch_sizes, lrs, lr_exp = schedule(batch_exp, batch_size, lr, lr_exp, max_length, n_iter, refit)
        print(f'Using lr_exp={lr_exp}')
        _run.info['lr_exp'] = lr_exp
        # force_full doubles as use_finite: the finite sample is used exactly
        # when full updates are forced.
        F, G, trace = online_sinkhorn(x=x, la=la, y=y, lb=lb, x_sampler=x_sampler, y_sampler=y_sampler,
                                      batch_sizes=batch_sizes,
                                      refit=refit, force_full=force_full, precompute_C=precompute_C,
                                      trace_every=max_calls // n_eval,
                                      lrs=lrs, n_iter=n_iter, use_finite=force_full, max_length=max_length,
                                      epsilon=epsilon, save_trace=True, ref=ref, max_calls=max_calls)
    else:
        raise ValueError(f'Unknown method {method!r}')

    torch.save(dict(x=x, la=la, y=y, lb=lb, F=F, G=G, trace=trace), join(output_dir, 'results.pkl'))