Example #1
0
def subsampled_sinkhorn(x,
                        la,
                        y,
                        lb,
                        n_iter=100,
                        batch_size: int = 10,
                        epsilon=1,
                        save_trace=False,
                        ref=None,
                        precompute_C=True,
                        max_calls=None,
                        trace_every=1):
    """Run plain Sinkhorn on a one-shot random subsample of both measures.

    When ``batch_size`` is set and differs from the full support size of
    either measure, ``batch_size`` points are drawn once from each measure
    and Sinkhorn is run on that reduced problem; otherwise the full inputs
    are used. Extra keyword arguments are forwarded to ``sinkhorn``.
    """
    needs_subsample = batch_size is not None and (len(x) != batch_size
                                                  or len(y) != batch_size)
    if needs_subsample:
        # Single draw per measure; the returned indices are not needed here.
        x, la, _ = Subsampler(x, la)(batch_size)
        y, lb, _ = Subsampler(y, lb)(batch_size)
    return sinkhorn(x, la, y, lb, n_iter, epsilon,
                    save_trace=save_trace,
                    ref=ref,
                    precompute_C=precompute_C,
                    max_calls=max_calls,
                    trace_every=trace_every)
Example #2
0
def subsampled_sinkhorn(x, la, y, lb, n_iter=100, batch_size: int = 10, epsilon=1, save_trace=False, ref=None,
                        precompute_C=True, count_recompute=False, max_calls=None,):
    """Subsample ``batch_size`` points from each measure once, then run
    plain Sinkhorn on the reduced problem, forwarding tracing options."""
    # One-shot draw from each measure; indices are discarded.
    sampler_x = Subsampler(x, la)
    sampler_y = Subsampler(y, lb)
    x, la, _ = sampler_x(batch_size)
    y, lb, _ = sampler_y(batch_size)
    return sinkhorn(x, la, y, lb, n_iter, epsilon,
                    save_trace=save_trace,
                    ref=ref,
                    precompute_C=precompute_C,
                    count_recompute=count_recompute,
                    max_calls=max_calls)
Example #3
0
def test_subsampler(device):
    """Subsampler moved to `device` must return tensors on that device
    and a plain Python list of indices."""
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')
    x, la, _, _, _, _ = make_data('gmm_1d', 100)
    sampler = Subsampler(x, la).to(device)
    xs, las, idx = sampler(10)
    assert xs.device.type == device
    assert las.device.type == device
    assert isinstance(idx, list)
Example #4
0
def random_sinkhorn(x_sampler=None, y_sampler=None, x=None, la=None, y=None, lb=None, use_finite=True, n_iter=100,
                    epsilon=1, max_calls=None,
                    batch_sizes: Union[List[int], int] = 10, save_trace=False, ref=None):
    """Random (mini-batch) Sinkhorn: at every iteration, draw fresh batches
    from both measures and rebuild both dual potentials from scratch.

    Parameters
    ----------
    x_sampler, y_sampler : samplers used when ``use_finite`` is False;
        overwritten by ``Subsampler``s over (x, la) / (y, lb) when True.
    x, la, y, lb : support points and log-weights of the two measures
        (required when ``use_finite`` is True).
    n_iter : iteration budget; if None, ``max_calls`` must be given and a
        large cap (1e6) is used in its place.
    epsilon : entropic regularization strength.
    max_calls : optional budget on potential evaluations; the loop stops
        once it is exceeded.
    batch_sizes : single int (used every iteration) or a per-iteration
        list; a list overrides ``n_iter``.
    save_trace, ref : when tracing, per-iteration errors against the
        reference potentials in ``ref`` are recorded.

    Returns
    -------
    (F, G), plus the trace list when ``save_trace`` is True.
    """
    trace, ref = check_trace(save_trace, ref=ref, ref_needed=True)

    if n_iter is None:
        assert max_calls is not None
        n_iter = int(1e6)  # effectively unbounded; max_calls stops the loop

    if isinstance(batch_sizes, int):
        batch_sizes = [batch_sizes for _ in range(n_iter)]
    else:
        # A list of batch sizes defines the iteration count.
        n_iter = len(batch_sizes)
    if use_finite:
        x_sampler = Subsampler(x, la)
        y_sampler = Subsampler(y, lb)
    F, G = None, None
    n_calls = 0
    for i in range(n_iter):
        if max_calls is not None and n_calls > max_calls:
            break
        x, la, _ = x_sampler(batch_sizes[i])
        y, lb, _ = y_sampler(batch_sizes[i])
        # First iteration: no potentials exist yet, start from 0.
        eG = 0 if i == 0 else G(y)
        F = FinitePotential(y, eG + lb, epsilon=epsilon)
        # NOTE(review): eF is evaluated with the F rebuilt just above, not
        # the previous iterate — presumably a deliberate alternating
        # (Gauss-Seidel-style) update; confirm against the reference.
        eF = 0 if i == 0 else F(x)
        G = FinitePotential(x, eF + la, epsilon=epsilon)
        n_samples = F.n_samples_ + G.n_samples_
        n_calls += F.n_calls_ + G.n_calls_
        if save_trace:
            this_trace = dict(n_iter=i + 1, n_calls=n_calls, n_samples=n_samples, algorithm='random')
            for name, (fr, xr, gr, yr) in ref.items():
                # Variational-norm error against each reference potential pair.
                this_trace[f'ref_err_{name}'] = (var_norm(F(xr, free=True) - fr) + var_norm(G(yr, free=True) - gr)).item()
            print(' '.join(f'{k}:{v}' for k, v in this_trace.items()))
            trace.append(this_trace)
    if save_trace:
        return F, G, trace
    else:
        return F, G
Example #5
0
def make_data(data_source, n_samples):
    """Build the measures (x, la), (y, lb) and their samplers.

    ``data_source`` is either ``'dragon_<size>'`` (mesh data, full support
    returned as-is) or the name of a synthetic generator, in which case
    ``n_samples`` points are drawn from each sampler.
    """
    if 'dragon' in data_source:
        # 'dragon_<size>': dragon mesh vs. a sphere with matching support.
        _, size = data_source.split('_')
        y, lb = make_dragon(size=int(size))
        x, la = make_sphere(len(y))
        return x, la, y, lb, Subsampler(x, la), Subsampler(y, lb)

    factories = {
        'gmm_1d': make_gmm_1d,
        'gmm_2d': make_gmm_2d,
        'gmm_10d': lambda: make_gmm(10, 5),
        'gaussian_2d': lambda: make_gaussian(2),
        'gaussian_10d': lambda: make_gaussian(10),
    }
    if data_source not in factories:
        raise ValueError
    x_sampler, y_sampler = factories[data_source]()
    x, la, _ = x_sampler(n_samples)
    y, lb, _ = y_sampler(n_samples)
    return x, la, y, lb, x_sampler, y_sampler
Example #6
0
def make_data(data_source, n_samples):
    """Build the measures (x, la), (y, lb) and their samplers.

    ``'dragon'`` returns the (rescaled) sphere/dragon clouds in full;
    any other recognized name draws ``n_samples`` points from the
    corresponding synthetic generator. Raises ValueError otherwise.
    """
    if data_source == 'dragon':
        x, la = make_sphere()
        y, lb = make_dragon()
        # Rescale both clouds by a factor of 2 (in place).
        x *= 2
        y *= 2
        return x, la, y, lb, Subsampler(x, la), Subsampler(y, lb)

    makers = {
        'gmm_1d': make_gmm_1d,
        'gmm_2d': make_gmm_2d,
        'gmm_10d': lambda: make_gmm(10, 5),
        'gaussian_2d': lambda: make_gaussian(2),
        'gaussian_10d': lambda: make_gaussian(10),
    }
    if data_source not in makers:
        raise ValueError
    x_sampler, y_sampler = makers[data_source]()
    x, la, _ = x_sampler(n_samples)
    y, lb, _ = y_sampler(n_samples)
    return x, la, y, lb, x_sampler, y_sampler
Example #7
0
def random_sinkhorn(x_sampler=None,
                    y_sampler=None,
                    x=None,
                    la=None,
                    y=None,
                    lb=None,
                    use_finite=True,
                    n_iter=100,
                    epsilon=1,
                    max_calls=None,
                    start_time=0,
                    batch_sizes: Union[List[int], int] = 10,
                    save_trace=False,
                    ref=None,
                    verbose=True,
                    trace_every=1):
    """Random (mini-batch) Sinkhorn with timed tracing.

    At every iteration, fresh batches are drawn from both measures and
    both dual potentials are rebuilt from scratch. Tracing is throttled:
    a trace entry is recorded only once ``trace_every`` additional calls
    have accumulated since the previous entry, and the time spent in
    evaluation is subtracted from the reported wall-clock time.

    Parameters
    ----------
    x_sampler, y_sampler : samplers used when ``use_finite`` is False;
        overwritten by ``Subsampler``s over (x, la) / (y, lb) when True.
    x, la, y, lb : support points and log-weights of the two measures
        (required when ``use_finite`` is True).
    n_iter : iteration budget; if None, ``max_calls`` must be given and
        a large cap (1e6) is used in its place.
    epsilon : entropic regularization strength.
    max_calls : optional budget on potential evaluations.
    start_time : offset added to recorded wall-clock times (for resumed
        runs).
    batch_sizes : int or per-iteration list; a list overrides ``n_iter``.
    save_trace, ref, verbose, trace_every : tracing controls.

    Returns
    -------
    (F, G), plus the trace list when ``save_trace`` is True.
    """
    eval_time = 0
    t0 = time.perf_counter()
    trace, ref = check_trace(save_trace, ref=ref, ref_needed=True)

    if n_iter is None:
        assert max_calls is not None
        n_iter = int(1e6)  # effectively unbounded; max_calls stops the loop

    if isinstance(batch_sizes, int):
        batch_sizes = [batch_sizes for _ in range(n_iter)]
    else:
        # A list of batch sizes defines the iteration count.
        n_iter = len(batch_sizes)
    if use_finite:
        x_sampler = Subsampler(x, la)
        y_sampler = Subsampler(y, lb)
    F, G = None, None
    n_calls = 0
    call_trace = 0  # next n_calls threshold at which to record a trace entry
    for i in range(n_iter):
        if max_calls is not None and n_calls > max_calls:
            break
        x, la, _ = x_sampler(batch_sizes[i])
        y, lb, _ = y_sampler(batch_sizes[i])
        # First iteration: no potentials exist yet, start from 0.
        eG = 0 if i == 0 else G(y)
        F = FinitePotential(y, eG + lb, epsilon=epsilon)
        # NOTE(review): eF is evaluated with the F rebuilt just above, not
        # the previous iterate — presumably a deliberate alternating
        # update; confirm against the reference implementation.
        eF = 0 if i == 0 else F(x)
        G = FinitePotential(x, eF + la, epsilon=epsilon)
        n_samples = F.n_samples_ + G.n_samples_
        n_calls += F.n_calls_ + G.n_calls_
        if save_trace and n_calls >= call_trace:
            this_trace = dict(n_iter=i + 1,
                              n_calls=n_calls,
                              n_samples=n_samples,
                              algorithm='random')
            eval_t0 = time.perf_counter()
            fixed_err, ref_err = evaluate(F, G, epsilon, ref)
            if torch.cuda.is_available():
                # Ensure GPU work is finished before timing the evaluation.
                torch.cuda.synchronize()
            eval_time += time.perf_counter() - eval_t0
            for name, err in fixed_err.items():
                this_trace[f'fixed_err_{name}'] = err
            for name, err in ref_err.items():
                this_trace[f'ref_err_{name}'] = err
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            # Report algorithm time only: subtract evaluation overhead.
            this_trace['time'] = time.perf_counter(
            ) - t0 - eval_time + start_time
            this_trace['eval_time'] = eval_time
            trace.append(this_trace)
            call_trace = n_calls + trace_every
            if verbose:
                print(' '.join(
                    f'{k}:{v:.2e}' if type(v) in [int, float] else f'{k}:{v}'
                    for k, v in this_trace.items()))
    if save_trace:
        return F, G, trace
    else:
        return F, G
Example #8
0
def online_sinkhorn(x_sampler=None,
                    y_sampler=None,
                    x=None,
                    la=None,
                    y=None,
                    lb=None,
                    use_finite=True,
                    epsilon=1.,
                    max_length=100000,
                    trim_every=None,
                    refit=False,
                    precompute_C=False,
                    n_iter=100,
                    force_full=False,
                    batch_sizes: Optional[Union[List[int], int]] = 10,
                    max_calls=None,
                    verbose=True,
                    start_time=0,
                    trace_every=1,
                    lrs: Union[List[float], float] = .1,
                    save_trace=False,
                    ref=None):
    """Online Sinkhorn: enrich the dual potentials with fresh samples at
    each iteration (instead of rebuilding them, as random Sinkhorn does).

    Parameters
    ----------
    x_sampler, y_sampler : samplers used when ``use_finite`` is False
        (built from x/y if None); overwritten when ``use_finite`` is True.
    x, la, y, lb : support points and log-weights of the two measures
        (required when ``use_finite`` is True).
    use_finite : use FinitePotentials indexed over the fixed supports.
    epsilon : entropic regularization strength.
    max_length, trim_every : capacity / trimming controls for
        InfinitePotentials (``use_finite`` is False).
    refit : replace the learning-rate update by a refit of each potential
        against the other.
    precompute_C : with ``force_full``, cache pairwise costs incrementally
        so the final full-Sinkhorn phase can reuse them.
    n_iter : iteration budget; if None, ``max_calls`` must be given and a
        large cap (1e6) is used in its place. A list-valued
        ``batch_sizes`` or ``lrs`` overrides it.
    force_full : once both potentials are full, switch to plain Sinkhorn
        on the saved full measures (requires ``use_finite``).
    batch_sizes, lrs : scalars (repeated) or per-iteration schedules;
        after expansion all schedules must have length ``n_iter``.
    max_calls, verbose, start_time, trace_every, save_trace, ref :
        stopping / tracing controls (evaluation time is excluded from the
        reported wall-clock time).

    Returns
    -------
    (F, G), plus the trace list when ``save_trace`` is True.
    """
    eval_time = 0
    t0 = time.perf_counter()

    if n_iter is None:
        assert max_calls is not None
        n_iter = int(1e6)  # effectively unbounded; max_calls stops the loop

    if force_full:
        assert use_finite

    # Derive n_iter from whichever schedule is given as a list.
    # Bug fix: the second clause previously re-tested `batch_sizes`
    # (`or not isinstance(batch_sizes, int)`), so a list of lrs together
    # with a scalar batch size never set n_iter from len(lrs) and the
    # length assert below would fire.
    if not isinstance(batch_sizes, int):
        n_iter = len(batch_sizes)
    elif not isinstance(lrs, (float, int)):
        n_iter = len(lrs)
    if isinstance(batch_sizes, int):
        batch_sizes = [batch_sizes for _ in range(n_iter)]
    if isinstance(lrs, (float, int)):
        lrs = [lrs for _ in range(n_iter)]
    assert (n_iter == len(lrs) == len(batch_sizes))

    if use_finite:
        x_sampler = Subsampler(x, la)
        y_sampler = Subsampler(y, lb)
        F = FinitePotential(y, epsilon=epsilon).to(x_sampler.device)
        G = FinitePotential(x, epsilon=epsilon).to(y_sampler.device)
    else:
        if x_sampler is None:
            x_sampler = Subsampler(x, la)
        if y_sampler is None:
            y_sampler = Subsampler(y, lb)
        F = InfinitePotential(max_length=max_length,
                              dimension=y_sampler.dimension,
                              epsilon=epsilon).to(x_sampler.device)
        G = InfinitePotential(max_length=max_length,
                              dimension=x_sampler.dimension,
                              epsilon=epsilon).to(y_sampler.device)

    if force_full:  # save full measures for the final full-Sinkhorn phase
        xf, laf, yf, lbf = x, la, y, lb
        if precompute_C:
            C = torch.empty((xf.shape[0], yf.shape[0], 1), device=xf.device)
        else:
            C = None
    else:
        xf, laf, yf, lbf = None, None, None, None
        C = None
    # Init: seed each potential with one batch from the opposite measure.
    x, la, xidx = x_sampler(batch_sizes[0])
    y, lb, yidx = y_sampler(batch_sizes[0])
    if force_full and precompute_C:
        this_C = compute_distance(x, y, lazy=False)
        scatter(C[..., 0], xidx, yidx, this_C[..., 0].transpose(0, 1))

    F.push(yidx if use_finite else y, la)
    G.push(xidx if use_finite else x, lb)

    trace, ref = check_trace(save_trace, ref=ref, ref_needed=True)

    call_trace = 0  # next n_calls threshold at which to record a trace entry
    for i in range(n_iter):
        n_calls = F.n_calls_ + G.n_calls_
        n_samples = F.n_samples_ + G.n_samples_
        if max_calls is not None and n_calls > max_calls:
            break
        if save_trace and n_calls >= call_trace:
            eval_t0 = time.perf_counter()
            this_trace = dict(n_iter=i,
                              n_calls=n_calls,
                              n_samples=n_samples,
                              algorithm='online')
            fixed_err, ref_err = evaluate(F, G, epsilon, ref)
            if torch.cuda.is_available():
                # Ensure GPU work is finished before timing the evaluation.
                torch.cuda.synchronize()
            eval_time += time.perf_counter() - eval_t0
            for name, err in fixed_err.items():
                this_trace[f'fixed_err_{name}'] = err
            for name, err in ref_err.items():
                this_trace[f'ref_err_{name}'] = err
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            # Report algorithm time only: subtract evaluation overhead.
            this_trace['time'] = time.perf_counter(
            ) - t0 - eval_time + start_time
            this_trace['eval_time'] = eval_time
            trace.append(this_trace)
            call_trace = n_calls + trace_every
            if verbose:
                print(' '.join(
                    f'{k}:{v:.2e}' if type(v) in [int, float] else f'{k}:{v}'
                    for k, v in this_trace.items()))
        y, lb, yidx = y_sampler(batch_sizes[i])
        if refit:
            F.push(yidx if use_finite else y, -float('inf'))
            F.refit(G)
        else:
            # Geometric averaging in log-domain: downweight old entries by
            # (1 - lr), push new ones with weight lr.
            F.add_weight(safe_log(1 - lrs[i]))
            if force_full and precompute_C:
                eG, this_C = G(y, return_C=True)
                xidx = check_idx(len(G.positions), G.seen)
                scatter(C[..., 0], xidx, yidx, this_C[..., 0].transpose(0, 1))
            else:
                eG = G(y)
            F.push(yidx if use_finite else y, np.log(lrs[i]) + eG + lb)
        x, la, xidx = x_sampler(batch_sizes[i])
        if refit:
            G.push(xidx if use_finite else x, -float('inf'))
            G.refit(F)
        else:
            G.add_weight(safe_log(1 - lrs[i]))
            if force_full and precompute_C:
                eF, this_C = F(x, return_C=True)
                yidx = check_idx(len(F.positions), F.seen)
                scatter(C[..., 0], xidx, yidx, this_C[..., 0])
            else:
                eF = F(x)
            G.push(xidx if use_finite else x, np.log(lrs[i]) + eF + la)
        if not use_finite and trim_every is not None and i % trim_every == 0:
            G.trim()
            F.trim()

        if force_full and G.full and F.full:
            # Every support point has been seen: hand off to full Sinkhorn.
            start_time = time.perf_counter() - t0 - eval_time + start_time
            res = sinkhorn(xf,
                           laf,
                           yf,
                           lbf,
                           F=F,
                           G=G,
                           save_trace=save_trace,
                           trace=trace,
                           start_iter=i,
                           n_iter=n_iter - 1,
                           ref=ref,
                           precompute_C=C if precompute_C else False,
                           max_calls=max_calls,
                           trace_every=trace_every,
                           start_time=start_time,
                           epsilon=epsilon)
            if save_trace:
                F, G, trace = res
            else:
                F, G = res
            break

    # Fix the additive gauge freedom of the potentials: anchor F at the
    # origin and shift G oppositely so the sum is unchanged.
    anchor = F(torch.zeros_like(x[[0]]))
    F.add_weight(anchor)
    G.add_weight(-anchor)
    if save_trace:
        return F, G, trace
    else:
        return F, G
Example #9
0
def online_sinkhorn(x_sampler=None, y_sampler=None,
                    x=None, la=None, y=None, lb=None, use_finite=True,
                    epsilon=1., max_length=100000, trim_every=None,
                    refit=False,
                    n_iter=100, force_full=True,
                    batch_sizes: Optional[Union[List[int], int]] = 10, max_calls=None,
                    lrs: Union[List[float], float] = .1, save_trace=False, ref=None):
    """Online Sinkhorn: enrich the dual potentials with fresh samples at
    each iteration instead of rebuilding them.

    Parameters
    ----------
    x_sampler, y_sampler : samplers used when ``use_finite`` is False
        (built from x/y if None); overwritten when ``use_finite`` is True.
    x, la, y, lb : support points and log-weights of the two measures
        (required when ``use_finite`` is True).
    use_finite : use FinitePotentials indexed over the fixed supports.
    epsilon : entropic regularization strength.
    max_length, trim_every : capacity / trimming controls for
        InfinitePotentials (``use_finite`` is False).
    refit : replace the learning-rate update by a refit of each potential
        against the other.
    n_iter : iteration budget; if None, ``max_calls`` must be given and a
        large cap (1e6) is used in its place. A list-valued
        ``batch_sizes`` or ``lrs`` overrides it.
    force_full : with ``use_finite``, switch to plain Sinkhorn on the
        saved full measures once both potentials are full.
    batch_sizes, lrs : scalars (repeated) or per-iteration schedules;
        after expansion all schedules must have length ``n_iter``.
    max_calls, save_trace, ref : stopping / tracing controls.

    Returns
    -------
    (F.cpu(), G.cpu()), plus the trace list when ``save_trace`` is True.
    """
    if n_iter is None:
        assert max_calls is not None
        n_iter = int(1e6)  # effectively unbounded; max_calls stops the loop

    # Derive n_iter from whichever schedule is given as a list.
    # Bug fix: the second clause previously re-tested `batch_sizes`
    # (`or not isinstance(batch_sizes, int)`), so a list of lrs together
    # with a scalar batch size never set n_iter from len(lrs) and the
    # length assert below would fire.
    if not isinstance(batch_sizes, int):
        n_iter = len(batch_sizes)
    elif not isinstance(lrs, (float, int)):
        n_iter = len(lrs)
    if isinstance(batch_sizes, int):
        batch_sizes = [batch_sizes for _ in range(n_iter)]
    if isinstance(lrs, (float, int)):
        lrs = [lrs for _ in range(n_iter)]
    assert (n_iter == len(lrs) == len(batch_sizes))

    if use_finite:
        x_sampler = Subsampler(x, la)
        y_sampler = Subsampler(y, lb)
        F = FinitePotential(y, epsilon=epsilon)
        G = FinitePotential(x, epsilon=epsilon)
    else:
        if x_sampler is None:
            x_sampler = Subsampler(x, la)
        if y_sampler is None:
            y_sampler = Subsampler(y, lb)
        F = InfinitePotential(max_length=max_length, dimension=y_sampler.dimension, epsilon=epsilon).to(x_sampler.device)
        G = InfinitePotential(max_length=max_length, dimension=x_sampler.dimension, epsilon=epsilon).to(y_sampler.device)

    if force_full:  # save full measures for the final full-Sinkhorn phase
        xf, laf, yf, lbf = x, la, y, lb
    else:
        xf, laf, yf, lbf = None, None, None, None
    # Init: seed each potential with one batch from the opposite measure.
    x, la, xidx = x_sampler(batch_sizes[0])
    y, lb, yidx = y_sampler(batch_sizes[0])
    F.push(yidx if use_finite else y, la)
    G.push(xidx if use_finite else x, lb)

    trace, ref = check_trace(save_trace, ref=ref, ref_needed=True)

    for i in range(n_iter):
        n_calls = F.n_calls_ + G.n_calls_
        n_samples = F.n_samples_ + G.n_samples_
        if max_calls is not None and n_calls > max_calls:
            break
        if save_trace:
            this_trace = dict(n_iter=i, n_calls=n_calls, n_samples=n_samples, algorithm='online')
            for name, (fr, xr, gr, yr) in ref.items():
                f = F(xr, free=True)
                g = G(yr, free=True)
                # Error against the reference potentials.
                this_trace[f'ref_err_{name}'] = (var_norm(f - fr) + var_norm(g - gr)).item()

                # Fixed-point residual: how far (f, g) is from one extra
                # exact Sinkhorn update of the reference potentials.
                gg = FinitePotential(xr, fr - np.log(len(fr)))(yr)
                ff = FinitePotential(yr, gr - np.log(len(gr)))(xr)
                this_trace[f'fixed_err_{name}'] = (var_norm(f - ff) + var_norm(g - gg)).item()

            trace.append(this_trace)
            print(' '.join(f'{k}:{v}' for k, v in this_trace.items()))
        if use_finite and force_full:
            if G.full and F.full:  # Force full iterations once every point has been observed
                res = sinkhorn(xf, laf, yf, lbf, F=F, G=G, save_trace=save_trace, trace=trace, start_iter=i,
                               n_iter=n_iter - 1, ref=ref, precompute_C=True, precompute_for_free=True,
                               # A better implementation would not require to recompute C
                               epsilon=epsilon)
                if save_trace:
                    F, G, trace = res
                else:
                    F, G = res
                break

        y, lb, yidx = y_sampler(batch_sizes[i])
        if refit:
            F.push(yidx if use_finite else y, -float('inf'))
            F.refit(G)
        else:
            # Geometric averaging in log-domain: downweight old entries by
            # (1 - lr), push new ones with weight lr.
            F.add_weight(safe_log(1 - lrs[i]))
            F.push(yidx if use_finite else y, np.log(lrs[i]) + G(y) + lb)
        x, la, xidx = x_sampler(batch_sizes[i])
        if refit:
            G.push(xidx if use_finite else x, -float('inf'))
            G.refit(F)
        else:
            G.add_weight(safe_log(1 - lrs[i]))
            G.push(xidx if use_finite else x, np.log(lrs[i]) + F(x) + la)
        if not use_finite and trim_every is not None and i % trim_every == 0:
            G.trim()
            F.trim()
    # Fix the additive gauge freedom of the potentials: anchor F at the
    # origin and shift G oppositely so the sum is unchanged.
    anchor = F(torch.zeros_like(x[[0]]))
    F.add_weight(anchor)
    G.add_weight(-anchor)
    if save_trace:
        return F.cpu(), G.cpu(), trace
    else:
        return F.cpu(), G.cpu()