n_way=args.n_way,
                      n_support=args.n_support,
                      n_query=args.n_query)
elif args.dataset == 'circle':
    task_fn = partial(circle_task,
                      n_way=args.n_way,
                      n_support=args.n_support,
                      n_query=args.n_query)
else:
    raise ValueError

# How many times over the whole run to recompute the NTK / refresh plots.
ntk_frequency = 50
plot_update_frequency = 100
for i, task_batch in tqdm(enumerate(
        taskbatch(task_fn=task_fn,
                  batch_size=args.task_batch_size,
                  n_task=args.n_train_task)),
                          total=args.n_train_task // args.task_batch_size):
    aux = dict()
    # ntk
    # Periodically (first iteration, then every total_steps // ntk_frequency
    # steps) evaluate the neural tangent kernel of the current model and of
    # its linearized counterpart on the held-out task's train inputs, and
    # log the matrix rank of each kernel.
    if i == 0 or (i + 1) % (args.n_train_task // args.task_batch_size //
                            ntk_frequency) == 0:
        ntk = tangents.ntk(f, batch_size=100)(outer_get_params(outer_state),
                                              task_eval['x_train'])
        aux['ntk_train_rank_eval'] = onp.linalg.matrix_rank(ntk)
        f_lin = tangents.linearize(f, outer_get_params(outer_state_lin))
        ntk_lin = tangents.ntk(f_lin, batch_size=100)(
            outer_get_params(outer_state_lin), task_eval['x_train'])
        aux['ntk_train_rank_eval_lin'] = onp.linalg.matrix_rank(ntk_lin)
        log.append([(key, aux[key]) for key in win_rank_eval_keys])
    # aggregate batch aux
    # Keep the raw per-batch values under a '_batch' suffix and replace each
    # original entry by its mean across the batch.
    for key in list(aux.keys()):
        aux[f'{key}_batch'] = aux[key]
        aux[key] = np.mean(aux[key])

    # NOTE(review): `return` inside a module-level loop is a SyntaxError, and
    # `loss_b` is never defined in the visible code — this line appears to be
    # the tail of a function whose `def` was lost when the file was assembled.
    return np.mean(loss_b), aux


@jit
def step(i, state, task_batch):
    """One JIT-compiled optimizer update on a batch of tasks.

    Differentiates `loss_batch` w.r.t. the current parameters (keeping its
    auxiliary outputs) and applies the gradient through `opt_update`.
    Returns the new optimizer state and the auxiliary dict.
    """
    grads, aux = grad(loss_batch, has_aux=True)(get_params(state), task_batch)
    new_state = opt_update(i, grads, state)
    return new_state, aux


# Fixed evaluation set: a single batch of 64 tasks, drawn once and reused.
task_batch_eval = list(taskbatch(task_fn=task_fn, batch_size=64, n_task=64))[0]


def eval_inner(params, task):
    """Score one inner-loop adaptation step for a single task.

    Takes one gradient step on the support set with a fixed step size
    (args.alignment_coefficient) and returns the resulting loss on the
    query set.
    """
    x_support, y_support, x_query, y_query = task
    support_grads = grad_loss(params, x_support, y_support)
    # Fixed inner learning rate from the CLI; an earlier variant normalized
    # by the gradient norm instead.
    lr = args.alignment_coefficient
    adapted_params = tree_multimap(lambda g, p: p - lr * g,
                                   support_grads, params)
    return param_loss(adapted_params, x_query, y_query)

@jit
def eval(i, state):
# Exemple #3  (snippet-boundary marker left by the source aggregation — the
# body of `eval` above was lost at this boundary)
# 0
    # NOTE(review): this orphan-indented line appears to belong to a
    # conditional that was lost at a snippet boundary above.
    all_plot_keys = all_plot_keys + win_acc_keys + win_acc_eval_keys
# Training log keyed by update index plus every plotted metric.
log = Log(keys=['update'] + all_plot_keys)

outer_state = outer_opt_init(params)

plotter = VisdomPlotter(viz)
# Select the task sampler for the chosen dataset.
if args.dataset == 'sinusoid':
    task_fn = partial(sinusoid_task, n_support=args.n_support, n_query=args.n_query)
elif args.dataset == 'omniglot':
    task_fn = partial(omniglot_task, split_dict=omniglot_splits['train'],
                      n_way=args.n_way, n_support=args.n_support,
                      n_query=args.n_query)
else:
    # Fail loudly with the offending value rather than a bare ValueError,
    # so a typo in --dataset is immediately diagnosable.
    raise ValueError(f"Unsupported dataset: {args.dataset!r}")

# Outer (meta-)training loop: one `outer_step` per batch of tasks.
for i, task_batch in tqdm(enumerate(taskbatch(task_fn=task_fn,
                                              batch_size=args.task_batch_size,
                                              n_task=args.n_train_task)),
                          total=args.n_train_task // args.task_batch_size):
    outer_state, aux = outer_step(
        i=i,
        state=outer_state,
        task_batch=(
            task_batch['x_train'],
            task_batch['y_train'],
            task_batch['x_test'],
            task_batch['y_test']))

    log.append([('update', i)])
    log.append([(key, aux[key]) for key in all_plot_keys])

    # Refresh plots ~100 times over the whole run.
    # NOTE(review): the body of this `if` is missing — it was cut off at the
    # snippet boundary below.
    if (i + 1) % (args.n_train_task // args.task_batch_size // 100) == 0:
# Exemple #4  (snippet-boundary marker left by the source aggregation)
# 0
# Initialize network parameters for 1-D regression input and wrap them in
# the optimizer state.
_, net_params = net_init(rng, (-1, 1))
opt_state = opt_init(net_params)

loss_np = onp.array(
    [])  # faster to use onp here, maybe because jax.np uses GPU
iter_np = onp.array([])
# Visdom line plot for the MAML training loss, seeded with a dummy point.
win_loss = viz.line(np.array([-1]),
                    np.array([-1]),
                    opts=dict(
                        title='maml loss',
                        xlabel=f'updates ({args.task_batch_size} tasks per)',
                        ylabel='mse'))
# Meta-train on sinusoid tasks: one `step` per task batch, accumulating the
# loss history for plotting.
for i, task_batch in tqdm.tqdm(
        enumerate(
            taskbatch(sinusoid_task,
                      batch_size=args.task_batch_size,
                      n_task=args.n_train_task,
                      n_support=args.n_support))):
    opt_state, l = step(i, opt_state,
                        (task_batch['x_train'], task_batch['y_train'],
                         task_batch['x_test'], task_batch['y_test']))
    # wandb.log({'iteration': onp.array(i), 'loss': onp.array(l)})
    loss_np = onp.append(loss_np, onp.array(l))
    iter_np = onp.append(iter_np, onp.array(i))
    # Redraw the full loss curve every 1000 updates.
    if (i + 1) % 1000 == 0:
        viz.line(loss_np, iter_np, win=win_loss, update='replace')
        print(f"iteration {i}:\tmaml loss: {l}")

# Qualitative check: predict sin(x) on a dense grid with the trained params.
net_params = get_params(opt_state)
xrange_inputs = np.linspace(-5, 5, 100).reshape(-1, 1)
targets = np.sin(xrange_inputs)
predictions = net_fn(net_params, xrange_inputs)
    # NOTE(review): these two orphan-indented lines look like the tail of an
    # `outer_step`-style function whose `def` was lost at the snippet
    # boundary above; as written they are a SyntaxError at module level.
    g, aux = grad(maml_loss, has_aux=True)(params, task)
    return outer_opt_update(i, g, state), aux


# logging, plotting
# Metrics tracked for both the standard and the linearized ('_lin') model.
aux_keys = ['loss_train', 'loss_test', 'loss_train_lin', 'loss_test_lin']
aux_eval_keys = aux_keys
log = Log(keys=['update'] + aux_keys)
log_eval = Log(keys=['update'] + aux_keys)
win_loss, win_loss_eval = None, None

# Fresh parameters and outer-optimizer state; meta-train on sinusoid tasks.
_, params = net_init(rng=random.PRNGKey(42), input_shape=(-1, 1))
state = outer_opt_init(params)
for i, task_batch in tqdm(enumerate(taskbatch(task_fn=sinusoid_task,
                                              batch_size=args.task_batch_size,
                                              n_task=args.n_train_task,
                                              n_support=args.n_support,
                                              n_query=args.n_query))):
    state, aux = outer_step(i=i, state=state, task=(task_batch['x_train'],
                                                    task_batch['y_train'],
                                                    task_batch['x_test'],
                                                    task_batch['y_test']))
    # outer_step packs (train aux, eval aux); log the two streams separately.
    aux, aux_eval = aux
    log.append([(key, aux[key]) for key in aux_keys] + [('update', i)])
    log_eval.append([(key, aux_eval[key]) for key in aux_eval_keys] + [('update', i)])

    # Refresh plots ~100 times over the run; the extra column is loss_train
    # smoothed with a 20-tap moving-average kernel.
    if (i + 1) % (args.n_train_task // args.task_batch_size // 100) == 0:
        win_loss_X = log['update']
        win_loss_Y = onp.stack([log[key] for key in aux_keys] +
                               [onp.convolve(log['loss_train'], [0.05] * 20, 'same')], axis=1)
        # NOTE(review): the body of this `if` is cut off at the end of the
        # visible source.
        if win_loss is None: