Example #1
0
    def sinusoid():
        """Fit an MLP to one sinusoid regression task and print its MSE.

        The network (2 hidden layers of 40 ReLU units, batch norm, 1 output)
        is trained with momentum (step size 1e-2, mass 0.9) on minibatches of
        256 for 1000 epochs. Train and test losses over the full splits are
        printed after the first step and after every 100th step.
        """
        init_fn, apply_fn = mlp(n_output=1,
                                n_hidden_layer=2,
                                bias_coef=1.0,
                                n_hidden_unit=40,
                                activation='relu',
                                norm='batch_norm')

        # Fixed seed; the (-1, 1) input shape presumably means (batch, 1).
        _, init_params = init_fn(random.PRNGKey(42), (-1, 1))

        def mse_loss(params, batch):
            """Mean squared error of the network on an ``(inputs, targets)`` pair."""
            xs, ys = batch
            residual = apply_fn(params, xs) - ys
            return np.mean(residual ** 2)

        opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-2,
                                                               mass=0.9)
        opt_update = jit(opt_update)

        @jit
        def step(i, opt_state, batch):
            """Apply one momentum update computed from ``batch`` at iteration ``i``."""
            gradient = grad(mse_loss)(get_params(opt_state), batch)
            return opt_update(i, gradient, opt_state)

        task = sinusoid_task(n_support=1000, n_query=100)

        opt_state = opt_init(init_params)
        batches = minibatch(task['x_train'], task['y_train'],
                            batch_size=256, train_epochs=1000)
        for it, (xb, yb) in enumerate(batches):
            opt_state = step(it, opt_state, batch=(xb, yb))
            # Report only after the very first step and every 100th step.
            if it != 0 and (it + 1) % 100 != 0:
                continue
            current = get_params(opt_state)
            train_loss = mse_loss(current, (task['x_train'], task['y_train']))
            test_loss = mse_loss(current, (task['x_test'], task['y_test']))
            print(f"train loss: {train_loss},\ttest loss: {test_loss}")
else:
    raise ValueError

def _param_loss(p, x, y):
    """Return ``loss(f(p, x), y)``.

    ``loss`` and ``f`` are the module-level names, looked up at call time.
    """
    return loss(f(p, x), y)


# Jitted loss and its gradient w.r.t. the parameters ``p``.
grad_loss = jit(grad(_param_loss))
param_loss = jit(_param_loss)

# optimizers    # TODO: separate optimizers for nonlinear and linear?
# `select_opt(alg, step_size)` returns a factory; calling it yields the three
# optimizer functions unpacked here.
outer_opt_init, outer_opt_update, outer_get_params = select_opt(
    args.outer_opt_alg, args.outer_step_size)()
inner_opt_init, inner_opt_update, inner_get_params = select_opt(
    args.inner_opt_alg, args.inner_step_size)()

# consistent task for plotting eval: built once here from the CLI arguments
if args.dataset == 'sinusoid':
    task_eval = sinusoid_task(n_support=args.n_support,
                              n_query=args.n_query,
                              noise_std=args.noise_std)
elif args.dataset == 'omniglot':
    omniglot_splits = load_omniglot(n_support=args.n_support,
                                    n_query=args.n_query)
    task_eval = omniglot_task(omniglot_splits['val'],
                              n_way=args.n_way,
                              n_support=args.n_support,
                              n_query=args.n_query)
elif args.dataset == 'circle':
    task_eval = circle_task(n_way=args.n_way,
                            n_support=args.n_support,
                            n_query=args.n_query)
else:
    raise ValueError(f"unknown dataset: {args.dataset!r}")
for run in tqdm(range(args.n_repeat)):
    # build network (one scalar output)
    net_init, f = mlp(n_output=1,
                      n_hidden_layer=args.n_hidden_layer,
                      n_hidden_unit=args.n_hidden_unit,
                      bias_coef=args.bias_coef,
                      activation=args.activation,
                      norm=args.norm)

    # initialize network; seeding with the repeat index gives each repeat a
    # distinct but reproducible initialization
    key = random.PRNGKey(run)
    _, params = net_init(key, (-1, 1))

    # data
    # NOTE(review): n_query / noise_std are not forwarded here (task_eval
    # passes them), so sinusoid_task's defaults apply -- confirm intended.
    task = sinusoid_task(n_support=args.n_support)
    x_train, y_train = task['x_train'], task['y_train']
    x_test, y_test = task['x_test'], task['y_test']

    # linearized network (presumably f expanded to first order around params)
    f_lin = tangents.linearize(f, params)

    # Create an MSE predictor to solve the NTK equation in function space.
    theta = tangents.ntk(f, batch_size=32)
    g_dd = theta(params, x_train)          # kernel on (x_train, x_train)
    g_td = theta(params, x_test, x_train)  # kernel on (x_test, x_train)
    predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
from data import sinusoid_task
import ipdb
import numpy as onp
from sklearn import gaussian_process
import matplotlib.pyplot as plt

# Fixed regression task: 100 support and 10 query points, noise_std=0.05.
task = sinusoid_task(n_support=100, n_query=10, noise_std=0.05)

# Periodic (ExpSineSquared) kernel starting from its default hyperparameters.
periodic_kernel = gaussian_process.kernels.ExpSineSquared()
# When fitted, the regressor re-optimizes the kernel hyperparameters with 100
# optimizer restarts; `alpha` is added to the diagonal of the training kernel.
# NOTE(review): sklearn treats `alpha` as a noise variance, whereas the task's
# noise_std=0.05 is a standard deviation (variance 0.0025) -- confirm intent.
gpr = gaussian_process.GaussianProcessRegressor(
    kernel=periodic_kernel,
    alpha=0.05,
    n_restarts_optimizer=100,
)


def assess(gpr, task):
    """Print and return the test-set mean squared error of a regressor.

    Args:
        gpr: Object exposing a scikit-learn style ``predict(X=...)`` method.
        task: Mapping with ``'x_test'`` inputs and ``'y_test'`` targets.

    Returns:
        The test MSE as a Python float.
    """
    y_true = onp.asarray(task['y_test'])
    # Align the prediction shape with the targets before subtracting: a model
    # may return shape (n,) for targets stored as (n, 1), and broadcasting
    # those two would silently average over all n*n pairwise differences.
    y_pred = onp.reshape(gpr.predict(X=task['x_test']), y_true.shape)
    mse = onp.mean(onp.square(y_pred - y_true))
    print(f"test MSE: {mse}")
    return float(mse)


def get_evals(kernel, task):
    """Return the eigenvalues of the train and test Gram matrices of ``kernel``.

    Args:
        kernel: Callable mapping inputs ``X`` to a symmetric matrix
            ``kernel(X)`` (e.g. a scikit-learn kernel object).
        task: Mapping with ``'x_train'`` and ``'x_test'`` inputs.

    Returns:
        ``(evals_train, evals_test)``: 1-D arrays of eigenvalues in ascending
        order.
    """
    # Only the spectrum is needed, so eigvalsh skips computing eigenvectors.
    evals_train = onp.linalg.eigvalsh(kernel(task['x_train']))
    evals_test = onp.linalg.eigvalsh(kernel(task['x_test']))
    return evals_train, evals_test


# Baseline before fitting: score the unfitted regressor and record the kernel
# spectra under the initial kernel hyperparameters.
assess(gpr=gpr, task=task)
evals_train_pre, evals_test_pre = get_evals(kernel=periodic_kernel, task=task)

# Fit the regressor on the training split.
gpr.fit(task['x_train'], task['y_train'])
    def loss(fx, targets):
        # Cross-entropy of logits `fx` against (presumably one-hot) `targets`,
        # averaged over the leading (batch) axis.
        return -np.sum(logsoftmax(fx) * targets) / targets.shape[0]

    def acc(fx, targets):
        # Fraction of rows whose arg-max prediction matches the arg-max target.
        predicted = np.argmax(logsoftmax(fx), axis=-1)
        expected = np.argmax(targets, axis=-1)
        return np.mean(predicted == expected)

    param_acc = jit(lambda p, x, y: acc(f(p, x), y))
else:
    raise ValueError

def _param_loss(p, x, y):
    """Return ``loss(f(p, x), y)``.

    ``loss`` and ``f`` are the module-level names, looked up at call time.
    """
    return loss(f(p, x), y)


# Jitted loss and its gradient w.r.t. the parameters ``p``.
grad_loss = jit(grad(_param_loss))
param_loss = jit(_param_loss)

# optimizers    # TODO: separate optimizers for nonlinear and linear?
# `select_opt(alg, step_size)` returns a factory; calling it yields the three
# optimizer functions unpacked here.
outer_opt_init, outer_opt_update, outer_get_params = select_opt(
    args.outer_opt_alg, args.outer_step_size)()
inner_opt_init, inner_opt_update, inner_get_params = select_opt(
    args.inner_opt_alg, args.inner_step_size)()

# consistent task for plotting eval: built once here from the CLI arguments
if args.dataset == 'sinusoid':
    task_eval = sinusoid_task(n_support=args.n_support, n_query=args.n_query)
elif args.dataset == 'omniglot':
    omniglot_splits = load_omniglot(n_support=args.n_support,
                                    n_query=args.n_query)
    task_eval = omniglot_task(omniglot_splits['val'],
                              n_way=args.n_way,
                              n_support=args.n_support,
                              n_query=args.n_query)
elif args.dataset == 'circle':
    task_eval = circle_task(n_way=args.n_way,
                            n_support=args.n_support,
                            n_query=args.n_query)
else:
    raise ValueError(f"unknown dataset: {args.dataset!r}")

# Outer-loop optimizer states, both initialized from the same `params`
# (the `_lin` state presumably tracks the linearized model -- confirm).
outer_state, outer_state_lin = outer_opt_init(params), outer_opt_init(params)

plotter = VisdomPlotter(viz)

if args.dataset == 'sinusoid':
    task_fn = partial(sinusoid_task, n_support=args.n_support, n_query=args.n_query)