def subsampled_sinkhorn(x, la, y, lb, n_iter=100, batch_size: int = 10,
                        epsilon=1, save_trace=False, ref=None,
                        precompute_C=True, max_calls=None, trace_every=1):
    if batch_size is not None and (batch_size != len(x) or batch_size != len(y)):
        # Draw one fixed subsample of each measure, then run plain Sinkhorn on it.
        x_sampler = Subsampler(x, la)
        y_sampler = Subsampler(y, lb)
        x, la, xidx = x_sampler(batch_size)
        y, lb, yidx = y_sampler(batch_size)
    return sinkhorn(x, la, y, lb, n_iter, epsilon, save_trace=save_trace,
                    ref=ref, precompute_C=precompute_C, max_calls=max_calls,
                    trace_every=trace_every)
def subsampled_sinkhorn(x, la, y, lb, n_iter=100, batch_size: int = 10,
                        epsilon=1, save_trace=False, ref=None,
                        precompute_C=True, count_recompute=False,
                        max_calls=None):
    x_sampler = Subsampler(x, la)
    y_sampler = Subsampler(y, lb)
    x, la, xidx = x_sampler(batch_size)
    y, lb, yidx = y_sampler(batch_size)
    return sinkhorn(x, la, y, lb, n_iter, epsilon, save_trace=save_trace,
                    ref=ref, precompute_C=precompute_C,
                    count_recompute=count_recompute, max_calls=max_calls)
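# Hedged usage sketch for subsampled_sinkhorn: a minimal call, assuming the
# make_data helper (below) and the sinkhorn function are importable from this
# package; the dataset name and all numeric values are illustrative only.
def _example_subsampled_sinkhorn():
    x, la, y, lb, _, _ = make_data('gmm_2d', 1000)
    # Subsample 50 points per measure once, then run plain Sinkhorn on them.
    F, G = subsampled_sinkhorn(x, la, y, lb, n_iter=100, batch_size=50,
                               epsilon=1e-1)
    return F, G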
def test_subsampler(device):
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('No cuda')
    x, la, y, lb, x_sampler, y_sampler = make_data('gmm_1d', 100)
    x_sampler = Subsampler(x, la).to(device)
    x, la, xidx = x_sampler(10)
    assert x.device.type == device
    assert la.device.type == device
    assert isinstance(xidx, list)
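# Minimal sketch of calling the Subsampler directly, mirroring the test above;
# the 3-tuple interface (positions, log-weights, drawn indices) is assumed
# from its uses elsewhere in this file.
def _example_subsampler():
    x, la, y, lb, _, _ = make_data('gmm_1d', 100)
    sampler = Subsampler(x, la)
    xb, lab, idx = sampler(10)  # batch of 10 positions with their log-weights
    return xb, lab, idx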
def random_sinkhorn(x_sampler=None, y_sampler=None, x=None, la=None, y=None,
                    lb=None, use_finite=True, n_iter=100, epsilon=1,
                    max_calls=None, batch_sizes: Union[List[int], int] = 10,
                    save_trace=False, ref=None):
    trace, ref = check_trace(save_trace, ref=ref, ref_needed=True)
    if n_iter is None:
        assert max_calls is not None
        n_iter = int(1e6)
    if isinstance(batch_sizes, int):
        batch_sizes = [batch_sizes for _ in range(n_iter)]
    else:
        n_iter = len(batch_sizes)
    if use_finite:
        x_sampler = Subsampler(x, la)
        y_sampler = Subsampler(y, lb)
    F, G = None, None
    n_calls = 0
    for i in range(n_iter):
        if max_calls is not None and n_calls > max_calls:
            break
        x, la, _ = x_sampler(batch_sizes[i])
        y, lb, _ = y_sampler(batch_sizes[i])
        # Random Sinkhorn: discard the previous potentials and rebuild them
        # from the current batch at every iteration.
        eG = 0 if i == 0 else G(y)
        F = FinitePotential(y, eG + lb, epsilon=epsilon)
        eF = 0 if i == 0 else F(x)
        G = FinitePotential(x, eF + la, epsilon=epsilon)
        n_samples = F.n_samples_ + G.n_samples_
        n_calls += F.n_calls_ + G.n_calls_
        if save_trace:
            this_trace = dict(n_iter=i + 1, n_calls=n_calls,
                              n_samples=n_samples, algorithm='random')
            for name, (fr, xr, gr, yr) in ref.items():
                this_trace[f'ref_err_{name}'] = (
                    var_norm(F(xr, free=True) - fr)
                    + var_norm(G(yr, free=True) - gr)).item()
            print(' '.join(f'{k}:{v}' for k, v in this_trace.items()))
            trace.append(this_trace)
    if save_trace:
        return F, G, trace
    else:
        return F, G
def make_data(data_source, n_samples):
    if 'dragon' in data_source:
        _, size = data_source.split('_')
        y, lb = make_dragon(size=int(size))
        x, la = make_sphere(len(y))
        x_sampler = Subsampler(x, la)
        y_sampler = Subsampler(y, lb)
    else:
        if data_source == 'gmm_1d':
            x_sampler, y_sampler = make_gmm_1d()
        elif data_source == 'gmm_2d':
            x_sampler, y_sampler = make_gmm_2d()
        elif data_source == 'gmm_10d':
            x_sampler, y_sampler = make_gmm(10, 5)
        elif data_source == 'gaussian_2d':
            x_sampler, y_sampler = make_gaussian(2)
        elif data_source == 'gaussian_10d':
            x_sampler, y_sampler = make_gaussian(10)
        else:
            raise ValueError(f'Unknown data_source: {data_source}')
    x, la, _ = x_sampler(n_samples)
    y, lb, _ = y_sampler(n_samples)
    return x, la, y, lb, x_sampler, y_sampler
def make_data(data_source, n_samples):
    if data_source == 'dragon':
        x, la = make_sphere()
        y, lb = make_dragon()
        x *= 2
        y *= 2
        x_sampler = Subsampler(x, la)
        y_sampler = Subsampler(y, lb)
    else:
        if data_source == 'gmm_1d':
            x_sampler, y_sampler = make_gmm_1d()
        elif data_source == 'gmm_2d':
            x_sampler, y_sampler = make_gmm_2d()
        elif data_source == 'gmm_10d':
            x_sampler, y_sampler = make_gmm(10, 5)
        elif data_source == 'gaussian_2d':
            x_sampler, y_sampler = make_gaussian(2)
        elif data_source == 'gaussian_10d':
            x_sampler, y_sampler = make_gaussian(10)
        else:
            raise ValueError(f'Unknown data_source: {data_source}')
    x, la, _ = x_sampler(n_samples)
    y, lb, _ = y_sampler(n_samples)
    return x, la, y, lb, x_sampler, y_sampler
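# Quick sketch of make_data's contract: it returns full point clouds with
# log-weights plus samplers over them. The dataset name is one handled above;
# everything else here is illustrative.
def _example_make_data():
    x, la, y, lb, x_sampler, y_sampler = make_data('gaussian_2d', 500)
    xb, lab, idx = x_sampler(64)  # draw a minibatch from the source measure
    return x.shape, y.shape, xb.shape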
def random_sinkhorn(x_sampler=None, y_sampler=None, x=None, la=None, y=None,
                    lb=None, use_finite=True, n_iter=100, epsilon=1,
                    max_calls=None, start_time=0,
                    batch_sizes: Union[List[int], int] = 10,
                    save_trace=False, ref=None, verbose=True, trace_every=1):
    eval_time = 0
    t0 = time.perf_counter()
    trace, ref = check_trace(save_trace, ref=ref, ref_needed=True)
    if n_iter is None:
        assert max_calls is not None
        n_iter = int(1e6)
    if isinstance(batch_sizes, int):
        batch_sizes = [batch_sizes for _ in range(n_iter)]
    else:
        n_iter = len(batch_sizes)
    if use_finite:
        x_sampler = Subsampler(x, la)
        y_sampler = Subsampler(y, lb)
    F, G = None, None
    n_calls = 0
    call_trace = 0
    for i in range(n_iter):
        if max_calls is not None and n_calls > max_calls:
            break
        x, la, _ = x_sampler(batch_sizes[i])
        y, lb, _ = y_sampler(batch_sizes[i])
        eG = 0 if i == 0 else G(y)
        F = FinitePotential(y, eG + lb, epsilon=epsilon)
        eF = 0 if i == 0 else F(x)
        G = FinitePotential(x, eF + la, epsilon=epsilon)
        n_samples = F.n_samples_ + G.n_samples_
        n_calls += F.n_calls_ + G.n_calls_
        if save_trace and n_calls >= call_trace:
            this_trace = dict(n_iter=i + 1, n_calls=n_calls,
                              n_samples=n_samples, algorithm='random')
            eval_t0 = time.perf_counter()
            fixed_err, ref_err = evaluate(F, G, epsilon, ref)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            eval_time += time.perf_counter() - eval_t0
            for name, err in fixed_err.items():
                this_trace[f'fixed_err_{name}'] = err
            for name, err in ref_err.items():
                this_trace[f'ref_err_{name}'] = err
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            # Report wall-clock time with evaluation overhead subtracted.
            this_trace['time'] = time.perf_counter() - t0 - eval_time + start_time
            this_trace['eval_time'] = eval_time
            trace.append(this_trace)
            call_trace = n_calls + trace_every
            if verbose:
                print(' '.join(
                    f'{k}:{v:.2e}' if type(v) in [int, float] else f'{k}:{v}'
                    for k, v in this_trace.items()))
    if save_trace:
        return F, G, trace
    else:
        return F, G
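# Usage sketch for random_sinkhorn in the streaming regime (use_finite=False):
# the gmm samplers returned by make_data are assumed to expose the same
# __call__ interface as Subsampler; values are illustrative.
def _example_random_sinkhorn():
    _, _, _, _, x_sampler, y_sampler = make_data('gmm_2d', 100)
    F, G = random_sinkhorn(x_sampler=x_sampler, y_sampler=y_sampler,
                           use_finite=False, n_iter=200, batch_sizes=32,
                           epsilon=1e-1)
    return F, G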
def online_sinkhorn(x_sampler=None, y_sampler=None, x=None, la=None, y=None,
                    lb=None, use_finite=True, epsilon=1., max_length=100000,
                    trim_every=None, refit=False, precompute_C=False,
                    n_iter=100, force_full=False,
                    batch_sizes: Optional[Union[List[int], int]] = 10,
                    max_calls=None, verbose=True, start_time=0, trace_every=1,
                    lrs: Union[List[float], float] = .1,
                    save_trace=False, ref=None):
    eval_time = 0
    t0 = time.perf_counter()
    if n_iter is None:
        assert max_calls is not None
        n_iter = int(1e6)
    if force_full:
        assert use_finite
    if not isinstance(batch_sizes, int) or not isinstance(lrs, (float, int)):
        # A schedule given as a list overrides n_iter.
        if not isinstance(batch_sizes, int):
            n_iter = len(batch_sizes)
        else:
            n_iter = len(lrs)
    if isinstance(batch_sizes, int):
        batch_sizes = [batch_sizes for _ in range(n_iter)]
    if isinstance(lrs, (float, int)):
        lrs = [lrs for _ in range(n_iter)]
    assert n_iter == len(lrs) == len(batch_sizes)
    if use_finite:
        x_sampler = Subsampler(x, la)
        y_sampler = Subsampler(y, lb)
        F = FinitePotential(y, epsilon=epsilon).to(x_sampler.device)
        G = FinitePotential(x, epsilon=epsilon).to(y_sampler.device)
    else:
        if x_sampler is None:
            x_sampler = Subsampler(x, la)
        if y_sampler is None:
            y_sampler = Subsampler(y, lb)
        F = InfinitePotential(max_length=max_length,
                              dimension=y_sampler.dimension,
                              epsilon=epsilon).to(x_sampler.device)
        G = InfinitePotential(max_length=max_length,
                              dimension=x_sampler.dimension,
                              epsilon=epsilon).to(y_sampler.device)
    if force_full:
        # Save the full measures for the finishing full Sinkhorn iterations.
        xf, laf, yf, lbf = x, la, y, lb
        if precompute_C:
            C = torch.empty((xf.shape[0], yf.shape[0], 1), device=xf.device)
        else:
            C = None
    else:
        xf, laf, yf, lbf = None, None, None, None
        C = None
    # Init
    x, la, xidx = x_sampler(batch_sizes[0])
    y, lb, yidx = y_sampler(batch_sizes[0])
    if force_full and precompute_C:
        this_C = compute_distance(x, y, lazy=False)
        scatter(C[..., 0], xidx, yidx, this_C[..., 0].transpose(0, 1))
    F.push(yidx if use_finite else y, la)
    G.push(xidx if use_finite else x, lb)
    trace, ref = check_trace(save_trace, ref=ref, ref_needed=True)
    call_trace = 0
    for i in range(n_iter):
        n_calls = F.n_calls_ + G.n_calls_
        n_samples = F.n_samples_ + G.n_samples_
        if max_calls is not None and n_calls > max_calls:
            break
        if save_trace and n_calls >= call_trace:
            eval_t0 = time.perf_counter()
            this_trace = dict(n_iter=i, n_calls=n_calls, n_samples=n_samples,
                              algorithm='online')
            fixed_err, ref_err = evaluate(F, G, epsilon, ref)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            eval_time += time.perf_counter() - eval_t0
            for name, err in fixed_err.items():
                this_trace[f'fixed_err_{name}'] = err
            for name, err in ref_err.items():
                this_trace[f'ref_err_{name}'] = err
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            this_trace['time'] = time.perf_counter() - t0 - eval_time + start_time
            this_trace['eval_time'] = eval_time
            trace.append(this_trace)
            call_trace = n_calls + trace_every
            if verbose:
                print(' '.join(
                    f'{k}:{v:.2e}' if type(v) in [int, float] else f'{k}:{v}'
                    for k, v in this_trace.items()))
        y, lb, yidx = y_sampler(batch_sizes[i])
        if refit:
            F.push(yidx if use_finite else y, -float('inf'))
            F.refit(G)
        else:
            F.add_weight(safe_log(1 - lrs[i]))
            if force_full and precompute_C:
                eG, this_C = G(y, return_C=True)
                xidx = check_idx(len(G.positions), G.seen)
                scatter(C[..., 0], xidx, yidx, this_C[..., 0].transpose(0, 1))
            else:
                eG = G(y)
            F.push(yidx if use_finite else y, np.log(lrs[i]) + eG + lb)
        x, la, xidx = x_sampler(batch_sizes[i])
        if refit:
            G.push(xidx if use_finite else x, -float('inf'))
            G.refit(F)
        else:
            G.add_weight(safe_log(1 - lrs[i]))
            if force_full and precompute_C:
                eF, this_C = F(x, return_C=True)
                yidx = check_idx(len(F.positions), F.seen)
                scatter(C[..., 0], xidx, yidx, this_C[..., 0])
            else:
                eF = F(x)
            G.push(xidx if use_finite else x, np.log(lrs[i]) + eF + la)
        if not use_finite and trim_every is not None and i % trim_every == 0:
            G.trim()
            F.trim()
        if force_full and G.full and F.full:
            # Every point has been observed: switch to full Sinkhorn iterations.
            start_time = time.perf_counter() - t0 - eval_time + start_time
            res = sinkhorn(xf, laf, yf, lbf, F=F, G=G, save_trace=save_trace,
                           trace=trace, start_iter=i, n_iter=n_iter - 1,
                           ref=ref, precompute_C=C if precompute_C else False,
                           max_calls=max_calls, trace_every=trace_every,
                           start_time=start_time, epsilon=epsilon)
            if save_trace:
                F, G, trace = res
            else:
                F, G = res
            break
    # Anchor the potentials to remove the translation invariance.
    anchor = F(torch.zeros_like(x[[0]]))
    F.add_weight(anchor)
    G.add_weight(-anchor)
    if save_trace:
        return F, G, trace
    else:
        return F, G
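# Usage sketch for online_sinkhorn in the finite setting, with a decaying
# learning-rate schedule; the 1/sqrt(t) decay is an illustrative choice, not
# a schedule prescribed elsewhere in this codebase.
def _example_online_sinkhorn():
    x, la, y, lb, _, _ = make_data('gmm_2d', 1000)
    lrs = [1. / np.sqrt(t + 1) for t in range(100)]  # n_iter follows len(lrs)
    F, G = online_sinkhorn(x=x, la=la, y=y, lb=lb, use_finite=True,
                           batch_sizes=50, lrs=lrs, epsilon=1e-1)
    return F, G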
def online_sinkhorn(x_sampler=None, y_sampler=None, x=None, la=None, y=None,
                    lb=None, use_finite=True, epsilon=1., max_length=100000,
                    trim_every=None, refit=False, n_iter=100, force_full=True,
                    batch_sizes: Optional[Union[List[int], int]] = 10,
                    max_calls=None, lrs: Union[List[float], float] = .1,
                    save_trace=False, ref=None):
    if n_iter is None:
        assert max_calls is not None
        n_iter = int(1e6)
    if not isinstance(batch_sizes, int) or not isinstance(lrs, (float, int)):
        if not isinstance(batch_sizes, int):
            n_iter = len(batch_sizes)
        else:
            n_iter = len(lrs)
    if isinstance(batch_sizes, int):
        batch_sizes = [batch_sizes for _ in range(n_iter)]
    if isinstance(lrs, (float, int)):
        lrs = [lrs for _ in range(n_iter)]
    assert n_iter == len(lrs) == len(batch_sizes)
    if use_finite:
        x_sampler = Subsampler(x, la)
        y_sampler = Subsampler(y, lb)
        F = FinitePotential(y, epsilon=epsilon)
        G = FinitePotential(x, epsilon=epsilon)
    else:
        if x_sampler is None:
            x_sampler = Subsampler(x, la)
        if y_sampler is None:
            y_sampler = Subsampler(y, lb)
        F = InfinitePotential(max_length=max_length,
                              dimension=y_sampler.dimension,
                              epsilon=epsilon).to(x_sampler.device)
        G = InfinitePotential(max_length=max_length,
                              dimension=x_sampler.dimension,
                              epsilon=epsilon).to(y_sampler.device)
    if force_full:
        # Save for later
        xf, laf, yf, lbf = x, la, y, lb
    else:
        xf, laf, yf, lbf = None, None, None, None
    # Init
    x, la, xidx = x_sampler(batch_sizes[0])
    y, lb, yidx = y_sampler(batch_sizes[0])
    F.push(yidx if use_finite else y, la)
    G.push(xidx if use_finite else x, lb)
    trace, ref = check_trace(save_trace, ref=ref, ref_needed=True)
    for i in range(n_iter):
        n_calls = F.n_calls_ + G.n_calls_
        n_samples = F.n_samples_ + G.n_samples_
        if max_calls is not None and n_calls > max_calls:
            break
        if save_trace:
            this_trace = dict(n_iter=i, n_calls=n_calls, n_samples=n_samples,
                              algorithm='online')
            for name, (fr, xr, gr, yr) in ref.items():
                f = F(xr, free=True)
                g = G(yr, free=True)
                this_trace[f'ref_err_{name}'] = (var_norm(f - fr)
                                                 + var_norm(g - gr)).item()
                gg = FinitePotential(xr, fr - np.log(len(fr)))(yr)
                ff = FinitePotential(yr, gr - np.log(len(gr)))(xr)
                this_trace[f'fixed_err_{name}'] = (var_norm(f - ff)
                                                   + var_norm(g - gg)).item()
            trace.append(this_trace)
            print(' '.join(f'{k}:{v}' for k, v in this_trace.items()))
        if use_finite and force_full:
            if G.full and F.full:
                # Force full iterations once every point has been observed.
                # A better implementation would not require recomputing C.
                res = sinkhorn(xf, laf, yf, lbf, F=F, G=G,
                               save_trace=save_trace, trace=trace,
                               start_iter=i, n_iter=n_iter - 1, ref=ref,
                               precompute_C=True, precompute_for_free=True,
                               epsilon=epsilon)
                if save_trace:
                    F, G, trace = res
                else:
                    F, G = res
                break
        y, lb, yidx = y_sampler(batch_sizes[i])
        if refit:
            F.push(yidx if use_finite else y, -float('inf'))
            F.refit(G)
        else:
            F.add_weight(safe_log(1 - lrs[i]))
            F.push(yidx if use_finite else y, np.log(lrs[i]) + G(y) + lb)
        x, la, xidx = x_sampler(batch_sizes[i])
        if refit:
            G.push(xidx if use_finite else x, -float('inf'))
            G.refit(F)
        else:
            G.add_weight(safe_log(1 - lrs[i]))
            G.push(xidx if use_finite else x, np.log(lrs[i]) + F(x) + la)
        if not use_finite and trim_every is not None and i % trim_every == 0:
            G.trim()
            F.trim()
    anchor = F(torch.zeros_like(x[[0]]))
    F.add_weight(anchor)
    G.add_weight(-anchor)
    if save_trace:
        return F.cpu(), G.cpu(), trace
    else:
        return F.cpu(), G.cpu()
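# Sketch of the fully online regime (use_finite=False) with continuous
# samplers: InfinitePotential accumulates support points as batches stream in,
# and trim_every is assumed to bound memory by compressing that support.
# All values here are illustrative.
def _example_online_sinkhorn_streaming():
    _, _, _, _, x_sampler, y_sampler = make_data('gmm_1d', 100)
    F, G = online_sinkhorn(x_sampler=x_sampler, y_sampler=y_sampler,
                           use_finite=False, n_iter=300, batch_sizes=16,
                           lrs=.1, epsilon=1e-1, trim_every=50)
    return F, G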