Example #1
import json
import os
from os.path import join

import pandas as pd
import torch

from onlikhorn.dataset import get_output_dir


def gather(output_dirs):
    """Collect the trace of every experiment folder into a single DataFrame."""
    traces = []

    for output_dir in output_dirs:
        for exp_dir in os.listdir(output_dir):
            try:
                with open(join(output_dir, exp_dir, 'config.json'), 'r') as f:
                    conf = json.load(f)
                with open(join(output_dir, exp_dir, 'run.json'), 'r') as f:
                    run = json.load(f)
                status = run['status']
            except (OSError, ValueError, KeyError):
                print(f'No trace for {exp_dir}')
                continue
            try:
                trace = torch.load(join(output_dir, exp_dir, 'artifacts',
                                        'results.pkl'),
                                   map_location=torch.device('cpu'))['trace']
            except (OSError, RuntimeError, KeyError):
                print(f'No trace for {exp_dir}, {status}')
                continue
            print(output_dir, exp_dir, conf, len(trace))
            trace = pd.DataFrame(trace)
            # Broadcast the experiment config into columns of the trace.
            for k, v in conf.items():
                if k not in ['n_iter', 'n_samples']:
                    trace[k] = v
                elif k == 'n_iter':
                    trace['total_n_iter'] = v
            trace['exp_dir'] = exp_dir
            trace['status'] = status
            traces.append(trace)
    traces = pd.concat(traces)
    traces.to_pickle(join(get_output_dir(), 'all.pkl'))
    return traces
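A minimal driver sketch for the function above, assuming the output layout produced by the experiment scripts below (the grid folder name 'online_grid12' is taken from Example #7 and is only illustrative):

# Hypothetical usage: gather every trace found under one or more grid folders.
output_dirs = [join(get_output_dir(), 'online_grid12')]
df = gather(output_dirs)
print(df.groupby(['data_source', 'method']).size())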
Example #2
def create_one(config, index=0):
    """Write a SLURM submission script for a single configuration."""
    job_folder = join(get_output_dir(), 'online', 'jobs')
    project_root = os.path.abspath(os.getcwd())
    os.makedirs(job_folder, exist_ok=True)
    filename = join(job_folder, f'run_{index}.slurm')
    config_str = ' '.join(f'{key}={value}' for key, value in config.items())
    with open(filename, 'w') as f:
        f.write(
            SLURM_TEMPLATE.format(project_root=project_root,
                                  filename=filename,
                                  config_str=config_str))
    return filename
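SLURM_TEMPLATE is defined elsewhere in the original script and is not shown here. A minimal sketch of a compatible template, assuming the job launches a Sacred script with `python <script> with key=value ...`; the time limit and the script name grid_online.py are placeholders, not the authors' settings:

SLURM_TEMPLATE = """#!/bin/bash
#SBATCH --job-name=online_sinkhorn
#SBATCH --output={filename}.log
#SBATCH --time=24:00:00

cd {project_root}
python grid_online.py with {config_str}
"""

Only the three named fields appear as braces, so str.format fills them directly; any other literal brace in a real template would need to be doubled.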
Example #3
import os
from os.path import join

import numpy as np
import torch
from sacred import Experiment
from sacred.observers import FileStorageObserver

from onlikhorn.algorithm import sinkhorn, online_sinkhorn, random_sinkhorn, subsampled_sinkhorn, schedule
from onlikhorn.cache import torch_cached
from onlikhorn.dataset import get_output_dir, make_data
from onlikhorn.gaussian import sinkhorn_gaussian

exp_name = 'online_grid_quiver_2'
exp = Experiment(exp_name)
exp_dir = join(get_output_dir(), exp_name)
exp.observers = [FileStorageObserver(exp_dir)]


@exp.config
def config():
    data_source = 'gmm_1d'
    n_samples = 10000
    max_length = 20000
    device = 'cuda'

    # Overridden
    batch_size = 100
    seed = 0
    epsilon = 1e-2
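With the FileStorageObserver attached above, the experiment can be launched programmatically as well as from the command line. A small sketch, assuming the experiment's main function is defined further down in the original script; the override values are arbitrary:

# Hypothetical launch: Sacred merges these overrides into the config() defaults
# and stores the run under exp_dir via the FileStorageObserver.
exp.run(config_updates={'epsilon': 1e-3, 'batch_size': 200, 'seed': 1})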
Example #4
def plot_gaussian(df, ytype='test', epsilon=1e-2, refit=False):
    df = df.query(
        '(method in ["online", "sinkhorn", "subsampled"]) & data_source != "dragon" &'
        f'(method != "online" | (refit ==  {refit} & lr_exp == "auto")) &'
        f'epsilon == {epsilon}')

    if refit:  # Remove saturated memory
        df.loc[(df['method'] == 'online') & (df['n_samples'] == 40000),
               ['ref_err_test', 'fixed_err']] = np.nan

    pk = [
        'data_source', 'epsilon', 'method', 'refit', 'batch_exp', 'lr_exp',
        'batch_size', 'n_iter'
    ]
    df = df.groupby(by=pk).agg(['mean', 'std']).reset_index('n_iter')

    NAMES = {'gaussian_2d': '2D Gaussian', 'gaussian_10d': '10D Gaussian'}

    df1 = df
    fig, axes = plt.subplots(1, 2, figsize=(width, width * 0.2))
    order = {'gaussian_2d': 0, 'gaussian_10d': 1}
    for (data_source, epsilon), df2 in df1.groupby(['data_source', 'epsilon']):
        iter_at_prec = {}
        for index, df3 in df2.groupby(
            ['method', 'refit', 'batch_exp', 'lr_exp', 'batch_size']):
            n_calls = df3['n_calls']
            if ytype == 'train':
                y = df3['ref_err_train']
            elif ytype == 'test':
                y = df3['ref_err_test']
            elif ytype == 'err':
                y = df3['fixed_err']
            else:
                raise ValueError
            if index[0] == 'sinkhorn':
                label = 'Sinkhorn $n = 10^4$'
            elif index[0] == 'subsampled':
                label = f'Sinkhorn $n = {index[-1]}$'
            else:
                if index[2] == 0:
                    label = 'O-S $n(t) = 100$'
                else:
                    label = rf'O-S $n(t) \propto t^{{{index[2]}}}$'
            axes[order[data_source]].plot(n_calls['mean'],
                                          y['mean'],
                                          label=label,
                                          linewidth=2,
                                          alpha=0.8)

        axes[order[data_source]].annotate(NAMES[data_source],
                                          xy=(.5, .83),
                                          xycoords="axes fraction",
                                          ha='center',
                                          va='bottom')
    axes[0].annotate('Computations',
                     xy=(-.2, -.25),
                     xycoords="axes fraction",
                     ha='center',
                     va='bottom')
    axes[1].legend(loc='center left',
                   frameon=False,
                   bbox_to_anchor=(1.03, 0.5),
                   ncol=1)
    for ax in axes:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.tick_params(axis='both', which='major', labelsize=5)
        ax.tick_params(axis='both', which='minor', labelsize=5)
        ax.minorticks_on()
    if ytype == 'err':
        axes[0].set_ylabel(r'$\Vert T(\hat f){-}\hat g\Vert_{\textrm{var}}$',
                           fontsize=5)
    elif ytype == 'train':
        axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    else:
        axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    sns.despine(fig)
    fig.subplots_adjust(right=0.75, bottom=0.21)
    fig.savefig(
        join(get_output_dir(),
             f'online_{epsilon}_{refit}_{ytype}_gaussian.pdf'))
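A hedged driver for this figure, mirroring the pipeline pattern of Example #7 (the 'all.pkl' file is the one produced by gather in Example #1; the epsilon values swept here are illustrative):

df = pd.read_pickle(join(get_output_dir(), 'all.pkl'))
for epsilon in [1e-2, 1e-3]:
    for refit in [False, True]:
        plot_gaussian(df, ytype='test', epsilon=epsilon, refit=refit)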
Example #5
def plot_random(df, ytype='test', epsilon=1e-2, name=''):
    df = df.query(
        '(method in ["random", "sinkhorn", "subsampled"]) & data_source != "dragon" &'
        f'epsilon == {epsilon}')

    pk = [
        'data_source', 'epsilon', 'method', 'refit', 'batch_exp', 'lr_exp',
        'batch_size', 'n_iter'
    ]
    df = df.groupby(by=pk).agg(['mean', 'std']).reset_index('n_iter')

    NAMES = {'gmm_1d': '1D GMM', 'gmm_10d': '10D GMM', 'gmm_2d': '2D GMM'}

    df1 = df
    fig, axes = plt.subplots(1, 3, figsize=(width, width * 0.2))
    order = {'gmm_1d': 0, 'gmm_10d': 2, 'gmm_2d': 1}
    for (data_source, epsilon), df2 in df1.groupby(['data_source', 'epsilon']):
        iter_at_prec = {}
        for index, df3 in df2.groupby(
            ['method', 'refit', 'batch_exp', 'lr_exp', 'batch_size']):
            n_calls = df3['n_calls']
            if ytype == 'train':
                y = df3['ref_err_train']
            elif ytype == 'test':
                y = df3['ref_err_test']
            elif ytype == 'err':
                y = df3['fixed_err']
            else:
                raise ValueError
            if index[0] == 'sinkhorn':
                label = 'Sinkhorn $n = 10^4$'
            elif index[0] == 'subsampled':
                label = f'Sinkhorn $n = {index[-1]}$'
            else:
                # Both batch-growth regimes share the same label here.
                label = f'R-S $n = {index[-1]}$'
            axes[order[data_source]].plot(n_calls['mean'],
                                          y['mean'],
                                          label=label,
                                          linewidth=2,
                                          alpha=0.8)

        axes[order[data_source]].annotate(NAMES[data_source],
                                          xy=(.5, .83),
                                          xycoords="axes fraction",
                                          ha='center',
                                          va='bottom')
    axes[0].annotate('Computations',
                     xy=(-.3, -.25),
                     xycoords="axes fraction",
                     ha='center',
                     va='bottom')
    axes[2].legend(loc='center left',
                   frameon=False,
                   bbox_to_anchor=(1.03, 0.5),
                   ncol=1)
    for ax in axes:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.tick_params(axis='both', which='major', labelsize=5)
        ax.tick_params(axis='both', which='minor', labelsize=5)
        ax.minorticks_on()
    if ytype == 'err':
        axes[0].set_ylabel(r'$\Vert T(\hat f){-}\hat g\Vert_{\textrm{var}}$',
                           fontsize=5)
    elif ytype == 'train':
        axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    else:
        axes[0].set_ylabel(r'$\Vert \hat f {-} f_0^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    sns.despine(fig)
    fig.subplots_adjust(right=0.75, bottom=0.21)
    fig.savefig(join(get_output_dir(), f'random_{epsilon}_{ytype}_{name}.pdf'))
Example #6
def plot_warmup(df, ytype='err', epsilon=1e-2):
    df = df.query(
        '(method in ["sinkhorn_precompute", "online_as_warmup"]) & '
        '(data_source in ["gmm_2d", "gmm_10d", "dragon"]) & '
        'batch_size == 100 & '
        f'epsilon == {epsilon} & '
        '((refit == False & batch_exp == .5 & lr_exp == 0) | method != "online_as_warmup")'
    )

    # df.loc[df['fixed_err'].isna(), 'fixed_err'] = df.loc[df['fixed_err'].isna(), 'var_err_train'].values

    pk = [
        'data_source', 'epsilon', 'method', 'refit', 'batch_exp', 'lr_exp',
        'batch_size', 'n_iter'
    ]
    df = df.groupby(by=pk).agg(['mean', 'std']).reset_index('n_iter')

    NAMES = {'dragon': 'Stanford 3D', 'gmm_10d': '10D GMM', 'gmm_2d': '2D GMM'}
    fig, axes = plt.subplots(1, 3, figsize=(0.7 * width, width * 0.2 * 0.7))
    order = {'gmm_2d': 0, 'gmm_10d': 1, 'dragon': 2}
    speedups = []
    for (data_source, epsilon), df2 in df.groupby(['data_source', 'epsilon']):
        iter_at_prec = {}
        for index, df3 in df2.groupby(
            ['method', 'refit', 'batch_exp', 'lr_exp', 'batch_size']):
            n_calls = df3['n_calls']
            if ytype == 'train':
                y = df3['ref_err_train']
            elif ytype == 'test':
                y = df3['ref_err_test']
            elif ytype == 'err':
                y = df3['fixed_err']
            else:
                raise ValueError
            if index[0] == 'sinkhorn_precompute':
                label = 'Standard\nSinkhorn'
            else:
                label = 'Online\nSinkhorn\nwarmup'
            axes[order[data_source]].plot(
                n_calls['mean'],
                y['mean'],
                label=label,
                linewidth=2,
                alpha=0.8,
                color='C3' if index[0] == 'sinkhorn_precompute' else 'C0')
            # if index[0] != 'sinkhorn_precompute':
            #     axes[i].fill_between(n_calls['mean'], y['mean'] - y['std'], y['mean'] + y['std'],
            #                          alpha=0.2)
            try:
                iter_at_prec[index[0]] = n_calls['mean'].iloc[np.where(
                    y['mean'] < 1e-3)[0][0]]
            except IndexError:
                iter_at_prec[index[0]] = float('inf')  # precision never reached
        axes[order[data_source]].annotate(NAMES[data_source],
                                          xy=(.5, .8),
                                          xycoords="axes fraction",
                                          ha='center',
                                          va='bottom')
        speedups.append(
            dict(data_source=data_source,
                 epsilon=epsilon,
                 speedup=iter_at_prec['sinkhorn_precompute'] /
                 iter_at_prec['online_as_warmup'],
                 ytype=ytype))
    axes[0].annotate('Comput.',
                     xy=(-.27, -.32),
                     xycoords="axes fraction",
                     ha='center',
                     va='bottom')
    if ytype == 'err' and epsilon == 1e-3:
        axes[0].set_ylim([1e-4, 0.1])
    axes[2].legend(loc='center left',
                   frameon=False,
                   bbox_to_anchor=(.95, 0.5),
                   ncol=1)
    for ax in axes:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.tick_params(axis='both', which='major', labelsize=5)
        ax.tick_params(axis='both', which='minor', labelsize=5)
        ax.minorticks_on()
        # ax.set_xlim([0.8e9, 1.2e11])
        # ax.set_xticks([1e9, 1e10, 1e11])
    if ytype == 'err':
        axes[0].set_ylabel(r'$\Vert T(\hat f){-} \hat g\Vert_{\textrm{var}}$',
                           fontsize=5)
    elif ytype == 'train':
        axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    else:
        axes[0].set_ylabel(r'$\Vert f {-} f_0^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    sns.despine(fig)
    fig.subplots_adjust(right=0.8, bottom=0.23)
    fig.savefig(join(get_output_dir(), f'online+full_{epsilon}_{ytype}.pdf'))
    return speedups
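plot_warmup returns the measured speedups as a list of dicts. A short sketch, assuming the same DataFrame df as in the previous drivers, that turns them into a summary table:

all_speedups = []
for epsilon in [1e-2, 1e-3]:
    all_speedups += plot_warmup(df, ytype='err', epsilon=epsilon)
speedup_df = pd.DataFrame(all_speedups)
print(speedup_df.pivot(index='data_source', columns='epsilon', values='speedup'))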
Example #7
        axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    else:
        axes[0].set_ylabel(r'$\Vert \hat f {-} f^\star\Vert_{\textrm{var}}$',
                           fontsize=5)
    sns.despine(fig)
    fig.subplots_adjust(right=0.75, bottom=0.21)
    fig.savefig(
        join(get_output_dir(),
             f'online_{epsilon}_{refit}_{ytype}_gaussian.pdf'))


pipeline = ['random']

if 'gather' in pipeline:
    output_dirs = [join(get_output_dir(), 'online_grid12')]
    df = gather(output_dirs)
    df.to_pickle(join(get_output_dir(), 'all_warmup.pkl'))

# Figure 1
if 'figure_1' in pipeline:
    output_dirs = [
        join(get_output_dir(), 'online_grid10'),
        join(get_output_dir(), 'online_grid11')
    ]
    df = pd.read_pickle(join(get_output_dir(), 'all.pkl'))
    for ytype in ['test']:
        for epsilon in np.logspace(-4, -1, 4):
            for refit in [False, True]:
                plot_online(df, refit=refit, epsilon=epsilon, ytype=ytype)
    del df
Example #8
def one_dimensional_exp():
    eps = 1e-1

    grid = torch.linspace(-4, 12, 500)[:, None]
    C = compute_distance(grid, grid)

    x_sampler = Sampler(mean=torch.tensor([[1.], [2], [3]]), cov=torch.tensor([[[.1]], [[.1]], [[.1]]]),
                        p=torch.ones(3) / 3)
    y_sampler = Sampler(mean=torch.tensor([[0.], [3], [5]]), cov=torch.tensor([[[.1]], [[.1]], [[.4]]]),
                        p=torch.ones(3) / 3)
    torch.manual_seed(100)
    np.random.seed(100)
    lpx = x_sampler.log_prob(grid)
    lpy = y_sampler.log_prob(grid)
    lpx -= torch.logsumexp(lpx, dim=0)
    lpy -= torch.logsumexp(lpy, dim=0)
    px = torch.exp(lpx)
    py = torch.exp(lpy)

    fevals = []
    gevals = []
    labels = []
    plans = []

    mem = Memory(location=expanduser('~/cache'))
    n_samples = 5000
    x, loga = x_sampler(n_samples)
    y, logb = y_sampler(n_samples)
    f, g = mem.cache(sinkhorn)(x.numpy(), y.numpy(), eps=eps, n_iter=100)
    f = torch.from_numpy(f).float()
    g = torch.from_numpy(g).float()
    distance = compute_distance(grid, y)
    feval = - eps * torch.logsumexp((- distance + g[None, :]) / eps + logb[None, :], dim=1)
    distance = compute_distance(grid, x)
    geval = - eps * torch.logsumexp((- distance + f[None, :]) / eps + loga[None, :], dim=1)

    plan = (lpx[:, None] + feval[:, None] / eps + lpy[None, :]
            + geval[None, :] / eps - C / eps)

    plans.append((plan, grid, grid))

    fevals.append(feval)
    gevals.append(geval)
    labels.append('True potential')

    m = 50
    hatf, posx, hatg, posy, sto_fevals, sto_gevals = sampling_sinkhorn(x_sampler, y_sampler, m=m, eps=eps,
                                                                       n_iter=100,
                                                                       grid=grid)
    feval = evaluate_potential(hatg, posy, grid, eps)
    geval = evaluate_potential(hatf, posx, grid, eps)
    plan = (lpx[:, None] + feval[:, None] / eps + lpy[None, :]
            + geval[None, :] / eps - C / eps)
    plans.append((plan, grid, grid))

    fevals.append(feval)
    gevals.append(geval)
    labels.append('Online Sinkhorn')

    alpha, posx, posy, _, _ = rkhs_sinkhorn(x_sampler, y_sampler, m=m, eps=eps,
                                                                 n_iter=100, sigma=1,
                                                                 grid=grid)
    feval = evaluate_kernel(alpha, posx, grid, sigma=1)
    geval = evaluate_kernel(alpha, posy, grid, sigma=1)
    plan = (lpx[:, None] + feval[:, None] / eps + lpy[None, :]
            + geval[None, :] / eps - C / eps)
    plans.append((plan, grid, grid))
    fevals.append(feval)
    gevals.append(geval)
    labels.append('RKHS')

    fevals = torch.cat([feval[None, :] for feval in fevals], dim=0)
    gevals = torch.cat([geval[None, :] for geval in gevals], dim=0)

    fig = plt.figure(figsize=(width, .21 * width))
    gs = gridspec.GridSpec(ncols=6, nrows=1, width_ratios=[1, 1.5, 1.5, .8, .8, .8], figure=fig)
    plt.subplots_adjust(right=0.97, left=0.05, bottom=0.27, top=0.85)
    ax0 = fig.add_subplot(gs[0])
    ax0.plot(grid, px, label=r'$\alpha$')
    ax0.plot(grid, py, label=r'$\beta$')
    ax0.legend(frameon=False)
    # ax0.axis('off')
    ax1 = fig.add_subplot(gs[1])
    ax2 = fig.add_subplot(gs[2])
    ax3 = fig.add_subplot(gs[3])
    ax4 = fig.add_subplot(gs[4])
    ax5 = fig.add_subplot(gs[5])

    for i, (label, feval, geval, (plan, x, y)) in enumerate(zip(labels, fevals, gevals, plans)):
        if label == 'True potential':
            # Inside this branch the label is always 'True potential', so the
            # reference potentials are drawn thick and on top of the online estimates.
            ax1.plot(grid, feval, label=label, zorder=100, linewidth=3, color='C3')
            ax2.plot(grid, geval, label=None, zorder=100, linewidth=3, color='C3')
            plan = plan.numpy()
            ax3.contour(y[:, 0], x[:, 0], plan, levels=30)
        elif label == 'RKHS':
            ax1.plot(grid, feval, label=label, linewidth=2, color='C4')
            ax2.plot(grid, geval, label=None, linewidth=2, color='C4')
            plan = plan.numpy()
            ax4.contour(y[:, 0], x[:, 0], plan, levels=30)
        else:
            plan = plan.numpy()
            ax5.contour(y[:, 0], x[:, 0], plan, levels=30)
    ax0.set_title('Distributions')
    ax1.set_title('Estimated $f$')
    ax2.set_title('Estimated $g$')
    ax3.set_title('True OT plan')
    ax4.set_title('RKHS')
    ax5.set_title('O-S')
    # ax4.set_title('Estimated OT plan')
    colors = plt.cm.get_cmap('Blues')(np.linspace(0.2, 1, len(sto_fevals[::2])))
    for i, sto_feval in enumerate(sto_fevals[::2]):
        ax1.plot(grid, sto_feval, color=colors[i],
                 linewidth=2, label=f'O-S $n_t={i * 10 * 2 * 50}$' if i % 2 == 0 else None,
                 zorder=1)
    for i, sto_geval in enumerate(sto_gevals[::2]):
        ax2.plot(grid, sto_geval, color=colors[i],
                 linewidth=2, label=None,
                 zorder=1)
    for ax in (ax1, ax2, ax3):
        ax.tick_params(axis='both', which='major', labelsize=5)
        ax.tick_params(axis='both', which='minor', labelsize=5)
        ax.minorticks_on()
    ax1.legend(frameon=False, bbox_to_anchor=(-1, -0.53), ncol=5, loc='lower left')
    # ax2.legend(frameon=False, bbox_to_anchor=(0., 1), loc='upper left')
    sns.despine(fig)
    for ax in [ax3, ax4, ax5]:
        ax.axis('off')
    ax0.axes.get_yaxis().set_visible(False)
    plt.savefig(join(get_output_dir(), 'continuous.pdf'))
    plt.show()
Example #9
        'batch_exp': [0, .5, 1],
        'lr_exp': ['auto']
    })
    subsampled = ParameterGrid({
        'data_source': data_sources,
        'n_samples': [1000],
        'batch_size': [10],
        'n_iter': [10000],
        'seed': seeds,
        'epsilon': epsilons,
        'max_calls': [1e8],
        'method': ['sinkhorn'],
    })
    grids = [reference, compete, subsampled]

job_folder = join(get_output_dir(), 'online', 'jobs')
project_root = os.path.abspath(os.getcwd())
if not os.path.exists(job_folder):
    os.makedirs(job_folder)

config_str = ''
nb_jobs = 0
for grid in grids:
    for index, config in enumerate(grid):
        config_str += ' '.join(f'{key}={value}'
                               for key, value in config.items()) + '\n'
        nb_jobs += 1

print(nb_jobs)
config_file = join(job_folder, 'config.txt')
with open(config_file, 'w+') as f: