def gather(output_dirs, output_path=None):
    """Collect the traces of sacred runs into a single DataFrame and pickle it.

    Every entry of every directory in ``output_dirs`` is treated as one run.
    A run is kept only if it has readable ``config.json`` and ``run.json``
    files (the latter with a ``'status'`` key) and an
    ``artifacts/results.pkl`` loadable with ``torch.load`` that holds a
    ``'trace'`` entry. Other runs are reported with a printed message and
    skipped.

    Parameters
    ----------
    output_dirs : iterable of str
        Directories whose entries are individual run directories.
    output_path : str, optional
        Where the concatenated frame is pickled. Defaults to
        ``join(get_output_dir(), 'all.pkl')``.

    Returns
    -------
    pandas.DataFrame
        The rows of every kept trace. Each config value is broadcast as a
        column, except that ``n_iter`` is stored as ``total_n_iter`` and
        ``n_samples`` is not copied. ``exp_dir`` and ``status`` columns
        identify the run.

    Raises
    ------
    ValueError
        If no run produced a trace.
    """
    traces = []
    for output_dir in output_dirs:
        for exp_dir in os.listdir(output_dir):
            run_dir = join(output_dir, exp_dir)
            try:
                with open(join(run_dir, 'config.json'), 'r') as f:
                    conf = json.load(f)
                with open(join(run_dir, 'run.json'), 'r') as f:
                    status = json.load(f)['status']
            except (OSError, ValueError, KeyError, TypeError):
                # Missing file / not a directory (OSError), invalid JSON
                # (JSONDecodeError is a ValueError), or no 'status' entry.
                print(f'No trace for {exp_dir}')
                continue
            try:
                trace = torch.load(join(run_dir, 'artifacts', 'results.pkl'),
                                   map_location=torch.device('cpu'))['trace']
            except Exception:
                # Best effort: unfinished or failed runs have no usable
                # results file; skip them but still let Ctrl-C through.
                print(f'No trace for {exp_dir}, {status}')
                continue
            print(output_dir, exp_dir, conf, len(trace))
            trace = pd.DataFrame(trace)
            for key, value in conf.items():
                if key == 'n_iter':
                    # Keep the configured budget apart from the per-row
                    # iteration counter the trace may already contain.
                    trace['total_n_iter'] = value
                elif key != 'n_samples':
                    trace[key] = value
            trace['exp_dir'] = exp_dir
            trace['status'] = status
            traces.append(trace)
    if not traces:
        raise ValueError(f'No trace found in {list(output_dirs)!r}')
    traces = pd.concat(traces)
    if output_path is None:
        output_path = join(get_output_dir(), 'all.pkl')
    traces.to_pickle(output_path)
    return traces
def create_one(config, index=0):
    """Write one SLURM job script for ``config`` and return its path.

    The script is ``run_{index}.slurm`` in ``<output dir>/online/jobs``,
    filled from ``SLURM_TEMPLATE`` with the project root (the current working
    directory), the script path and ``config`` rendered as space-separated
    ``key=value`` pairs.

    Parameters
    ----------
    config : dict
        Configuration entries written as ``key=value``.
    index : int
        Suffix of the job file name.

    Returns
    -------
    str
        Path of the written script.
    """
    job_folder = join(get_output_dir(), 'online', 'jobs')
    project_root = os.path.abspath(os.getcwd())
    # exist_ok removes the check-then-create race when several processes
    # generate job files at the same time.
    os.makedirs(job_folder, exist_ok=True)
    filename = join(job_folder, f'run_{index}.slurm')
    config_str = ' '.join(f'{key}={value}' for key, value in config.items())
    with open(filename, 'w') as f:
        f.write(SLURM_TEMPLATE.format(project_root=project_root,
                                      filename=filename,
                                      config_str=config_str))
    return filename
import os
from os.path import join

import numpy as np
import torch
from sacred import Experiment
from sacred.observers import FileStorageObserver

from onlikhorn.algorithm import sinkhorn, online_sinkhorn, random_sinkhorn, subsampled_sinkhorn, schedule
from onlikhorn.cache import torch_cached
from onlikhorn.dataset import get_output_dir, make_data
from onlikhorn.gaussian import sinkhorn_gaussian

# Sacred experiment. Each run is recorded by a FileStorageObserver in
# <get_output_dir()>/<exp_name>/.
exp_name = 'online_grid_quiver_2'
exp = Experiment(exp_name)
exp_dir = join(get_output_dir(), exp_name)
exp.observers = [FileStorageObserver(exp_dir)]


# Default configuration. Sacred uses the local variable names below as the
# config keys (overridable with `with key=value`), so they must not be
# renamed.
@exp.config
def config():
    data_source = 'gmm_1d'
    n_samples = 10000
    max_length = 20000
    device = 'cuda'  # Overridden; NOTE(review): where is not visible here
    batch_size = 100
    seed = 0
    epsilon = 1e-2
def plot_gaussian(df, ytype='test', epsilon=1e-2, refit=False):
    """Plot an error metric against computations on the Gaussian sources.

    Compares online Sinkhorn ("O-S"), Sinkhorn and subsampled Sinkhorn runs
    at a single ``epsilon`` on a 1x2 grid (2D and 10D Gaussian). Curves are
    averaged over the remaining keys (e.g. seeds). The figure is saved as
    ``online_{epsilon}_{refit}_{ytype}_gaussian.pdf`` in ``get_output_dir()``.

    Parameters
    ----------
    df : pandas.DataFrame
        Gathered traces.
    ytype : {'train', 'test', 'err'}
        Plotted column: 'ref_err_train', 'ref_err_test' or 'fixed_err'.
    epsilon : float
        Only rows with this ``epsilon`` are used.
    refit : bool
        Online runs are kept only if their ``refit`` equals this value. When
        True, online runs with ``n_samples == 40000`` are blanked out.

    Raises
    ------
    ValueError
        If ``ytype`` is not supported.
    """
    metric_columns = {'train': 'ref_err_train', 'test': 'ref_err_test',
                      'err': 'fixed_err'}
    if ytype not in metric_columns:
        raise ValueError(f'Unknown ytype {ytype!r}')
    names = {'gaussian_2d': '2D Gaussian', 'gaussian_10d': '10D Gaussian'}
    order = {'gaussian_2d': 0, 'gaussian_10d': 1}

    # .copy(): the NaN assignment below must not write into a view of the
    # caller's frame (pandas SettingWithCopyWarning).
    df = df.query(
        '(method in ["online", "sinkhorn", "subsampled"]) & data_source != "dragon" &'
        f'(method != "online" | (refit == {refit} & lr_exp == "auto")) &'
        f'epsilon == {epsilon}').copy()
    if refit:
        # Remove saturated memory
        df.loc[(df['method'] == 'online') & (df['n_samples'] == 40000),
               ['ref_err_test', 'fixed_err']] = np.nan
    pk = [
        'data_source', 'epsilon', 'method', 'refit', 'batch_exp', 'lr_exp',
        'batch_size', 'n_iter'
    ]
    # Aggregate numeric/bool columns only: pandas >= 2 raises on string
    # columns such as 'exp_dir' or 'status' instead of dropping them.
    value_cols = [col for col in
                  df.select_dtypes(include=['number', 'bool']).columns
                  if col not in pk]
    df = df.groupby(by=pk)[value_cols].agg(['mean', 'std']).reset_index('n_iter')

    fig, axes = plt.subplots(1, 2, figsize=(width, width * 0.2))
    for (data_source, _), df2 in df.groupby(['data_source', 'epsilon']):
        if data_source not in order:
            # The filter above keeps every non-"dragon" source; only the
            # Gaussian ones have a panel here.
            continue
        ax = axes[order[data_source]]
        for index, df3 in df2.groupby(
                ['method', 'refit', 'batch_exp', 'lr_exp', 'batch_size']):
            method, batch_exp, batch_size = index[0], index[2], index[-1]
            if method == 'sinkhorn':
                label = 'Sinkhorn $n = 10^4$'
            elif method == 'subsampled':
                label = f'Sinkhorn $n = {batch_size}$'
            elif batch_exp == 0:
                label = 'O-S $n(t) = 100$'
            else:
                label = rf'O-S $n(t) \propto t^{{{batch_exp}}}$'
            ax.plot(df3['n_calls']['mean'], df3[metric_columns[ytype]]['mean'],
                    label=label, linewidth=2, alpha=0.8)
        ax.annotate(names[data_source], xy=(.5, .83), xycoords="axes fraction",
                    ha='center', va='bottom')

    axes[0].annotate('Computations', xy=(-.2, -.25), xycoords="axes fraction",
                     ha='center', va='bottom')
    axes[1].legend(loc='center left', frameon=False, bbox_to_anchor=(1.03, 0.5),
                   ncol=1)
    for ax in axes:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.tick_params(axis='both', which='major', labelsize=5)
        ax.tick_params(axis='both', which='minor', labelsize=5)
        ax.minorticks_on()
    if ytype == 'err':
        axes[0].set_ylabel(r'$\Vert T(\hat f){-}\hat g\Vert_{\textrm{var}}$',
                           fontsize=5)
    elif ytype == 'train':
        axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    else:
        axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    sns.despine(fig)
    fig.subplots_adjust(right=0.75, bottom=0.21)
    fig.savefig(
        join(get_output_dir(), f'online_{epsilon}_{refit}_{ytype}_gaussian.pdf'))
def plot_random(df, ytype='test', epsilon=1e-2, name=''):
    """Plot an error metric against computations on the GMM sources.

    Compares random Sinkhorn ("R-S"), Sinkhorn and subsampled Sinkhorn runs
    at a single ``epsilon`` on a 1x3 grid (1D, 2D and 10D GMM). Curves are
    averaged over the remaining keys (e.g. seeds). The figure is saved as
    ``random_{epsilon}_{ytype}_{name}.pdf`` in ``get_output_dir()``.

    Parameters
    ----------
    df : pandas.DataFrame
        Gathered traces.
    ytype : {'train', 'test', 'err'}
        Plotted column: 'ref_err_train', 'ref_err_test' or 'fixed_err'.
    epsilon : float
        Only rows with this ``epsilon`` are used.
    name : str
        Suffix of the output file name.

    Raises
    ------
    ValueError
        If ``ytype`` is not supported.
    """
    metric_columns = {'train': 'ref_err_train', 'test': 'ref_err_test',
                      'err': 'fixed_err'}
    if ytype not in metric_columns:
        raise ValueError(f'Unknown ytype {ytype!r}')
    names = {'gmm_1d': '1D GMM', 'gmm_10d': '10D GMM', 'gmm_2d': '2D GMM'}
    order = {'gmm_1d': 0, 'gmm_10d': 2, 'gmm_2d': 1}

    df = df.query(
        '(method in ["random", "sinkhorn", "subsampled"]) & data_source != "dragon" &'
        f'epsilon == {epsilon}')
    pk = [
        'data_source', 'epsilon', 'method', 'refit', 'batch_exp', 'lr_exp',
        'batch_size', 'n_iter'
    ]
    # Aggregate numeric/bool columns only: pandas >= 2 raises on string
    # columns such as 'exp_dir' or 'status' instead of dropping them.
    value_cols = [col for col in
                  df.select_dtypes(include=['number', 'bool']).columns
                  if col not in pk]
    df = df.groupby(by=pk)[value_cols].agg(['mean', 'std']).reset_index('n_iter')

    fig, axes = plt.subplots(1, 3, figsize=(width, width * 0.2))
    for (data_source, _), df2 in df.groupby(['data_source', 'epsilon']):
        if data_source not in order:
            # The filter above keeps every non-"dragon" source; only the
            # GMM ones have a panel here.
            continue
        ax = axes[order[data_source]]
        for index, df3 in df2.groupby(
                ['method', 'refit', 'batch_exp', 'lr_exp', 'batch_size']):
            method, batch_size = index[0], index[-1]
            if method == 'sinkhorn':
                label = 'Sinkhorn $n = 10^4$'
            elif method == 'subsampled':
                label = f'Sinkhorn $n = {batch_size}$'
            else:
                label = f'R-S $n = {batch_size}$'
            ax.plot(df3['n_calls']['mean'], df3[metric_columns[ytype]]['mean'],
                    label=label, linewidth=2, alpha=0.8)
        ax.annotate(names[data_source], xy=(.5, .83), xycoords="axes fraction",
                    ha='center', va='bottom')

    axes[0].annotate('Computations', xy=(-.3, -.25), xycoords="axes fraction",
                     ha='center', va='bottom')
    axes[2].legend(loc='center left', frameon=False, bbox_to_anchor=(1.03, 0.5),
                   ncol=1)
    for ax in axes:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.tick_params(axis='both', which='major', labelsize=5)
        ax.tick_params(axis='both', which='minor', labelsize=5)
        ax.minorticks_on()
    if ytype == 'err':
        axes[0].set_ylabel(r'$\Vert T(\hat f){-}\hat g\Vert_{\textrm{var}}$',
                           fontsize=5)
    elif ytype == 'train':
        axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    else:
        axes[0].set_ylabel(r'$\Vert \hat f {-} f_0^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    sns.despine(fig)
    fig.subplots_adjust(right=0.75, bottom=0.21)
    fig.savefig(join(get_output_dir(), f'random_{epsilon}_{ytype}_{name}.pdf'))
def plot_warmup(df, ytype='err', epsilon=1e-2, precision=1e-3):
    """Compare standard Sinkhorn with online-Sinkhorn warm-up and report speedups.

    Uses runs with ``batch_size == 100`` at a single ``epsilon`` on the
    2D GMM, 10D GMM and "dragon" sources. The online runs are restricted to
    ``refit == False``, ``batch_exp == .5`` and ``lr_exp == 0``. Curves are
    averaged over the remaining keys. The figure is saved as
    ``online+full_{epsilon}_{ytype}.pdf`` in ``get_output_dir()``.

    Parameters
    ----------
    df : pandas.DataFrame
        Gathered traces.
    ytype : {'train', 'test', 'err'}
        Plotted column: 'ref_err_train', 'ref_err_test' or 'fixed_err'.
    epsilon : float
        Only rows with this ``epsilon`` are used.
    precision : float
        Error level for the speedup. For each method, the first mean
        ``n_calls`` at which the plotted error is below it (inf if never).

    Returns
    -------
    list of dict
        One dict per (data_source, epsilon) with keys ``data_source``,
        ``epsilon``, ``speedup`` and ``ytype``. ``speedup`` is the
        sinkhorn_precompute / online_as_warmup ratio of those computation
        counts, or NaN when one of the two methods has no data.

    Raises
    ------
    ValueError
        If ``ytype`` is not supported.
    """
    metric_columns = {'train': 'ref_err_train', 'test': 'ref_err_test',
                      'err': 'fixed_err'}
    if ytype not in metric_columns:
        raise ValueError(f'Unknown ytype {ytype!r}')
    names = {'dragon': 'Stanford 3D', 'gmm_10d': '10D GMM', 'gmm_2d': '2D GMM'}
    order = {'gmm_2d': 0, 'gmm_10d': 1, 'dragon': 2}

    df = df.query(
        '(method in ["sinkhorn_precompute", "online_as_warmup"]) & '
        '(data_source in ["gmm_2d", "gmm_10d", "dragon"]) & '
        'batch_size == 100 & '
        f'epsilon == {epsilon} & '
        '((refit == False & batch_exp == .5 & lr_exp == 0) | method != "online_as_warmup")'
    )
    pk = [
        'data_source', 'epsilon', 'method', 'refit', 'batch_exp', 'lr_exp',
        'batch_size', 'n_iter'
    ]
    # Aggregate numeric/bool columns only: pandas >= 2 raises on string
    # columns such as 'exp_dir' or 'status' instead of dropping them.
    value_cols = [col for col in
                  df.select_dtypes(include=['number', 'bool']).columns
                  if col not in pk]
    df = df.groupby(by=pk)[value_cols].agg(['mean', 'std']).reset_index('n_iter')

    fig, axes = plt.subplots(1, 3, figsize=(0.7 * width, width * 0.2 * 0.7))
    speedups = []
    for (data_source, group_epsilon), df2 in df.groupby(['data_source', 'epsilon']):
        ax = axes[order[data_source]]
        iter_at_prec = {}
        for index, df3 in df2.groupby(
                ['method', 'refit', 'batch_exp', 'lr_exp', 'batch_size']):
            method = index[0]
            n_calls = df3['n_calls']['mean']
            y = df3[metric_columns[ytype]]['mean']
            is_standard = method == 'sinkhorn_precompute'
            ax.plot(n_calls, y,
                    label='Standard\nSinkhorn' if is_standard
                    else 'Online\nSinkhorn\nwarmup',
                    linewidth=2, alpha=0.8,
                    color='C3' if is_standard else 'C0')
            # First computation count at which the error drops below
            # `precision`; inf if it never does.
            below = np.flatnonzero(y < precision)
            iter_at_prec[method] = (n_calls.iloc[below[0]] if below.size
                                    else float('inf'))
        ax.annotate(names[data_source], xy=(.5, .8), xycoords="axes fraction",
                    ha='center', va='bottom')
        speedups.append(
            dict(data_source=data_source, epsilon=group_epsilon,
                 speedup=iter_at_prec.get('sinkhorn_precompute', np.nan)
                 / iter_at_prec.get('online_as_warmup', np.nan),
                 ytype=ytype))

    axes[0].annotate('Comput.', xy=(-.27, -.32), xycoords="axes fraction",
                     ha='center', va='bottom')
    if ytype == 'err' and epsilon == 1e-3:
        # NOTE(review): the collapsed original set three limits in a row
        # here; only the last one ever took effect.
        axes[0].set_ylim([1e-4, 0.1])
    axes[2].legend(loc='center left', frameon=False, bbox_to_anchor=(.95, 0.5),
                   ncol=1)
    for ax in axes:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.tick_params(axis='both', which='major', labelsize=5)
        ax.tick_params(axis='both', which='minor', labelsize=5)
        ax.minorticks_on()
    if ytype == 'err':
        axes[0].set_ylabel(r'$\Vert T(\hat f){-} \hat g\Vert_{\textrm{var}}$',
                           fontsize=5)
    elif ytype == 'train':
        axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    else:
        # Raw string: in the plain literal '\t' of '\textrm' became a TAB.
        axes[0].set_ylabel(r'$\Vert f {-} f_0^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    sns.despine(fig)
    fig.subplots_adjust(right=0.8, bottom=0.23)
    fig.savefig(join(get_output_dir(), f'online+full_{epsilon}_{ytype}.pdf'))
    return speedups
axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$', fontsize=5) else: axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$', fontsize=5) sns.despine(fig) fig.subplots_adjust(right=0.75, bottom=0.21) fig.savefig( join(get_output_dir(), f'online_{epsilon}_{refit}_{ytype}_gaussian.pdf')) pipeline = ['random'] if 'gather' in pipeline: output_dirs = [join(get_output_dir(), 'online_grid12')] df = gather(output_dirs) df.to_pickle(join(get_output_dir(), 'all_warmup.pkl')) # Figure 1 if 'figure_1' in pipeline: output_dirs = [ join(get_output_dir(), 'online_grid10'), join(get_output_dir(), 'online_grid11') ] df = pd.read_pickle(join(get_output_dir(), 'all.pkl')) for ytype in ['test']: for epsilon in np.logspace(-4, -1, 4): for refit in [False, True]: plot_online(df, refit=refit, epsilon=epsilon, ytype=ytype) del df
def one_dimensional_exp():
    """Compare dual potentials and plans on a 1D toy problem; save and show a figure.

    Two 3-component 1D samplers (presumably Gaussian mixtures, given the
    ``mean``/``cov``/``p`` arguments) are evaluated on 500 grid points in
    [-4, 12]. Three estimates of the potentials ``f``/``g`` are computed and
    evaluated on that grid:

    * 'True potential': ``sinkhorn`` on 5000 samples per side, memoized
      through ``Memory(location='~/cache')``, and extended to the grid with
      ``f(z) = -eps * logsumexp_j((g_j - d(z, y_j)) / eps + log b_j)``
      (and symmetrically for ``g``);
    * 'Online Sinkhorn': ``sampling_sinkhorn`` with ``m = 50``, which also
      returns intermediate grid evaluations drawn as O-S iterates;
    * 'RKHS': ``rkhs_sinkhorn`` with ``sigma = 1``.

    For each estimate the grid x grid quantity
    ``lpx + f/eps + lpy + g/eps - C/eps`` (presumably the log of the
    entropic plan) is drawn as a contour plot. The figure is saved to
    ``continuous.pdf`` in ``get_output_dir()`` and then shown.
    """
    eps = 1e-1
    grid = torch.linspace(-4, 12, 500)[:, None]
    C = compute_distance(grid, grid)
    x_sampler = Sampler(mean=torch.tensor([[1.], [2], [3]]),
                        cov=torch.tensor([[[.1]], [[.1]], [[.1]]]),
                        p=torch.ones(3) / 3)
    y_sampler = Sampler(mean=torch.tensor([[0.], [3], [5]]),
                        cov=torch.tensor([[[.1]], [[.1]], [[.4]]]),
                        p=torch.ones(3) / 3)
    torch.manual_seed(100)
    np.random.seed(100)

    # Grid log-densities, normalized so that they sum to one over the grid.
    lpx = x_sampler.log_prob(grid)
    lpy = y_sampler.log_prob(grid)
    lpx -= torch.logsumexp(lpx, dim=0)
    lpy -= torch.logsumexp(lpy, dim=0)
    px = torch.exp(lpx)
    py = torch.exp(lpy)

    fevals = []
    gevals = []
    labels = []
    plans = []

    def record(label, feval, geval):
        # Store one estimate together with its grid x grid log-plan.
        plan = (lpx[:, None] + feval[:, None] / eps + lpy[None, :]
                + geval[None, :] / eps - C / eps)
        plans.append((plan, grid, grid))
        fevals.append(feval)
        gevals.append(geval)
        labels.append(label)

    # Reference: batch Sinkhorn on a large sample, extended to the grid.
    mem = Memory(location=expanduser('~/cache'))
    n_samples = 5000
    x, loga = x_sampler(n_samples)
    y, logb = y_sampler(n_samples)
    f, g = mem.cache(sinkhorn)(x.numpy(), y.numpy(), eps=eps, n_iter=100)
    f = torch.from_numpy(f).float()
    g = torch.from_numpy(g).float()
    distance = compute_distance(grid, y)
    feval = - eps * torch.logsumexp((- distance + g[None, :]) / eps + logb[None, :], dim=1)
    distance = compute_distance(grid, x)
    geval = - eps * torch.logsumexp((- distance + f[None, :]) / eps + loga[None, :], dim=1)
    record('True potential', feval, geval)

    m = 50
    hatf, posx, hatg, posy, sto_fevals, sto_gevals = sampling_sinkhorn(
        x_sampler, y_sampler, m=m, eps=eps, n_iter=100, grid=grid)
    record('Online Sinkhorn',
           evaluate_potential(hatg, posy, grid, eps),
           evaluate_potential(hatf, posx, grid, eps))

    alpha, posx, posy, _, _ = rkhs_sinkhorn(x_sampler, y_sampler, m=m, eps=eps,
                                            n_iter=100, sigma=1, grid=grid)
    record('RKHS',
           evaluate_kernel(alpha, posx, grid, sigma=1),
           evaluate_kernel(alpha, posy, grid, sigma=1))

    fevals = torch.cat([fe[None, :] for fe in fevals], dim=0)
    gevals = torch.cat([ge[None, :] for ge in gevals], dim=0)

    fig = plt.figure(figsize=(width, .21 * width))
    gs = gridspec.GridSpec(ncols=6, nrows=1,
                           width_ratios=[1, 1.5, 1.5, .8, .8, .8], figure=fig)
    plt.subplots_adjust(right=0.97, left=0.05, bottom=0.27, top=0.85)
    ax0 = fig.add_subplot(gs[0])
    ax0.plot(grid, px, label=r'$\alpha$')
    ax0.plot(grid, py, label=r'$\beta$')
    ax0.legend(frameon=False)
    ax1 = fig.add_subplot(gs[1])
    ax2 = fig.add_subplot(gs[2])
    ax3 = fig.add_subplot(gs[3])
    ax4 = fig.add_subplot(gs[4])
    ax5 = fig.add_subplot(gs[5])

    for label, feval, geval, (plan, x, y) in zip(labels, fevals, gevals, plans):
        plan = plan.numpy()
        if label == 'True potential':
            # Thick and on top of the O-S iterates (zorder=1) drawn below.
            ax1.plot(grid, feval, label=label, zorder=100, linewidth=3,
                     color='C3')
            ax2.plot(grid, geval, label=None, zorder=100, linewidth=3,
                     color='C3')
            ax3.contour(y[:, 0], x[:, 0], plan, levels=30)
        elif label == 'RKHS':
            ax1.plot(grid, feval, label=label, linewidth=2, color='C4')
            ax2.plot(grid, geval, label=None, linewidth=2, color='C4')
            ax4.contour(y[:, 0], x[:, 0], plan, levels=30)
        else:
            ax5.contour(y[:, 0], x[:, 0], plan, levels=30)
    ax0.set_title('Distributions')
    ax1.set_title('Estimated $f$')
    ax2.set_title('Estimated $g$')
    ax3.set_title('True OT plan')
    ax4.set_title('RKHS')
    ax5.set_title('O-S')

    # Every second recorded O-S iterate, shaded light to dark blue.
    # pyplot.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    colors = plt.get_cmap('Blues')(np.linspace(0.2, 1, len(sto_fevals[::2])))
    for i, sto_feval in enumerate(sto_fevals[::2]):
        # NOTE(review): 2 is the [::2] stride, 50 presumably m and 10
        # presumably the recording period of sampling_sinkhorn -- confirm.
        ax1.plot(grid, sto_feval, color=colors[i], linewidth=2,
                 label=f'O-S $n_t={i * 10 * 2 * 50}$' if i % 2 == 0 else None,
                 zorder=1)
    for i, sto_geval in enumerate(sto_gevals[::2]):
        ax2.plot(grid, sto_geval, color=colors[i], linewidth=2, label=None,
                 zorder=1)
    for ax in (ax1, ax2, ax3):
        ax.tick_params(axis='both', which='major', labelsize=5)
        ax.tick_params(axis='both', which='minor', labelsize=5)
        ax.minorticks_on()
    ax1.legend(frameon=False, bbox_to_anchor=(-1, -0.53), ncol=5,
               loc='lower left')
    sns.despine(fig)
    for ax in [ax3, ax4, ax5]:
        ax.axis('off')
    ax0.axes.get_yaxis().set_visible(False)
    plt.savefig(join(get_output_dir(), 'continuous.pdf'))
    plt.show()
'batch_exp': [0, .5, 1], 'lr_exp': ['auto'] }) subsampled = ParameterGrid({ 'data_source': data_sources, 'n_samples': [1000], 'batch_size': [10], 'n_iter': [10000], 'seed': seeds, 'epsilon': epsilons, 'max_calls': [1e8], 'method': ['sinkhorn'], }) grids = [reference, compete, subsampled] job_folder = join(get_output_dir(), 'online', 'jobs') project_root = os.path.abspath(os.getcwd()) if not os.path.exists(job_folder): os.makedirs(job_folder) config_str = '' nb_jobs = 0 for grid in grids: for index, config in enumerate(grid): config_str += ' '.join(f'{key}={value}' for key, value in config.items()) + '\n' nb_jobs += 1 print(nb_jobs) config_file = join(job_folder, f'config.txt') with open(config_file, 'w+') as f: