# NOTE(review): this chunk begins mid-statement — the line below is the tail of
# a `task_fn = partial(...)` call whose opening lives before this view.
n_way=args.n_way, n_support=args.n_support, n_query=args.n_query)
elif args.dataset == 'circle':
    # Circle-classification task sampler, parameterized from the CLI args.
    task_fn = partial(circle_task, n_way=args.n_way, n_support=args.n_support, n_query=args.n_query)
else:
    # Unrecognized dataset name; bare ValueError (no message) matches file style.
    raise ValueError

# Target number of NTK recomputations / plot refreshes over the whole run.
ntk_frequency = 50
plot_update_frequency = 100

# Meta-training loop: one iteration per batch of sampled tasks.
for i, task_batch in tqdm(enumerate(
        taskbatch(task_fn=task_fn, batch_size=args.task_batch_size, n_task=args.n_train_task)),
        total=args.n_train_task // args.task_batch_size):
    aux = dict()
    # ntk
    # At step 0 and then every (total_steps // ntk_frequency) steps, compute the
    # neural tangent kernel on the eval task's training inputs and log its rank
    # (matrix_rank via numpy) for both the full and the linearized model.
    if i == 0 or (i + 1) % (args.n_train_task // args.task_batch_size // ntk_frequency) == 0:
        ntk = tangents.ntk(f, batch_size=100)(outer_get_params(outer_state), task_eval['x_train'])
        aux['ntk_train_rank_eval'] = onp.linalg.matrix_rank(ntk)
        # Linearize f around the current linearized-track parameters, then its NTK.
        f_lin = tangents.linearize(f, outer_get_params(outer_state_lin))
        ntk_lin = tangents.ntk(f_lin, batch_size=100)(
            outer_get_params(outer_state_lin), task_eval['x_train'])
        aux['ntk_train_rank_eval_lin'] = onp.linalg.matrix_rank(ntk_lin)
        # presumably logged only on NTK steps — original indentation was lost in
        # the paste; TODO confirm this belongs inside the `if`.
        log.append([(key, aux[key]) for key in win_rank_eval_keys])
# aggregate batch aux for key in list(aux.keys()): aux[f'{key}_batch'] = aux[key] aux[key] = np.mean(aux[key]) return np.mean(loss_b), aux @jit def step(i, state, task_batch): params = get_params(state) g, aux = grad(loss_batch, has_aux=True)(params, task_batch) return opt_update(i, g, state), aux task_batch_eval = list(taskbatch(task_fn=task_fn, batch_size=64, n_task=64))[0] def eval_inner(params, task): x_support, y_support, x_query, y_query = task grads = grad_loss(params, x_support, y_support) # step_size = 1 / np.linalg.norm(pytree_to_array(grads)) step_size = args.alignment_coefficient inner_sgd_fn = lambda g, p: p - step_size * g params_updated = tree_multimap(inner_sgd_fn, grads, params) loss = param_loss(params_updated, x_query, y_query) return loss @jit def eval(i, state):
# Assemble the complete set of plotted keys and the log that tracks them.
all_plot_keys = all_plot_keys + win_acc_keys + win_acc_eval_keys
log = Log(keys=['update'] + all_plot_keys)
outer_state = outer_opt_init(params)
plotter = VisdomPlotter(viz)

# Pick the task sampler for the chosen dataset.
if args.dataset == 'sinusoid':
    task_fn = partial(sinusoid_task, n_support=args.n_support, n_query=args.n_query)
elif args.dataset == 'omniglot':
    task_fn = partial(omniglot_task, split_dict=omniglot_splits['train'],
                      n_way=args.n_way, n_support=args.n_support, n_query=args.n_query)
else:
    # Unrecognized dataset name.
    raise ValueError

# Meta-training loop: one outer_step per batch of tasks, logging every step.
for i, task_batch in tqdm(enumerate(taskbatch(task_fn=task_fn,
                                              batch_size=args.task_batch_size,
                                              n_task=args.n_train_task)),
                          total=args.n_train_task // args.task_batch_size):
    outer_state, aux = outer_step(
        i=i, state=outer_state,
        task_batch=(task_batch['x_train'], task_batch['y_train'],
                    task_batch['x_test'], task_batch['y_test']))
    log.append([('update', i)])
    log.append([(key, aux[key]) for key in all_plot_keys])
    # Fires roughly 100 times over the whole run.
    # NOTE(review): chunk is truncated mid-`if` — its body continues beyond
    # this view.
    if (i + 1) % (args.n_train_task // args.task_batch_size // 100) == 0:
# Initialize network parameters (input shape (-1, 1)) and the optimizer state.
_, net_params = net_init(rng, (-1, 1))
opt_state = opt_init(net_params)
loss_np = onp.array(
    [])  # faster to use onp here, maybe because jax.np uses GPU
iter_np = onp.array([])
# Visdom line plot for the MAML training loss; seeded with a dummy point.
win_loss = viz.line(np.array([-1]), np.array([-1]),
                    opts=dict(title='maml loss',
                              xlabel=f'updates ({args.task_batch_size} tasks per)',
                              ylabel='mse'))
# Training loop over sinusoid task batches.
for i, task_batch in tqdm.tqdm(
        enumerate(taskbatch(sinusoid_task,
                            batch_size=args.task_batch_size,
                            n_task=args.n_train_task,
                            n_support=args.n_support))):
    opt_state, l = step(i, opt_state,
                        (task_batch['x_train'], task_batch['y_train'],
                         task_batch['x_test'], task_batch['y_test']))
    # wandb.log({'iteration': onp.array(i), 'loss': onp.array(l)})
    # Accumulate loss/iteration history on the host for plotting.
    loss_np = onp.append(loss_np, onp.array(l))
    iter_np = onp.append(iter_np, onp.array(i))
    if (i + 1) % 1000 == 0:
        # Refresh the plot and print progress every 1000 updates.
        viz.line(loss_np, iter_np, win=win_loss, update='replace')
        print(f"iteration {i}:\tmaml loss: {l}")
net_params = get_params(opt_state)
# Qualitative check: predictions vs. sin(x) on a grid over [-5, 5].
xrange_inputs = np.linspace(-5, 5, 100).reshape(-1, 1)
targets = np.sin(xrange_inputs)
predictions = net_fn(net_params, xrange_inputs)
# NOTE(review): the two lines below are the tail of an outer-step function
# (gradient of maml_loss + one optimizer update) whose `def` lies before
# this view.
    g, aux = grad(maml_loss, has_aux=True)(params, task)
    return outer_opt_update(i, g, state), aux


# logging, plotting
# Tracked loss keys for both the standard and linearized model variants.
aux_keys = ['loss_train', 'loss_test', 'loss_train_lin', 'loss_test_lin']
aux_eval_keys = aux_keys
log = Log(keys=['update'] + aux_keys)
log_eval = Log(keys=['update'] + aux_keys)
win_loss, win_loss_eval = None, None  # visdom windows, created lazily
_, params = net_init(rng=random.PRNGKey(42), input_shape=(-1, 1))
state = outer_opt_init(params)
# Meta-training loop over sinusoid task batches.
for i, task_batch in tqdm(enumerate(taskbatch(task_fn=sinusoid_task,
                                              batch_size=args.task_batch_size,
                                              n_task=args.n_train_task,
                                              n_support=args.n_support,
                                              n_query=args.n_query))):
    state, aux = outer_step(i=i, state=state,
                            task=(task_batch['x_train'], task_batch['y_train'],
                                  task_batch['x_test'], task_batch['y_test']))
    # outer_step's aux is a (train aux, eval aux) pair — unpack both.
    aux, aux_eval = aux
    log.append([(key, aux[key]) for key in aux_keys] + [('update', i)])
    log_eval.append([(key, aux_eval[key]) for key in aux_eval_keys] + [('update', i)])
    # Roughly 100 plot refreshes over the whole run.
    if (i + 1) % (args.n_train_task // args.task_batch_size // 100) == 0:
        win_loss_X = log['update']
        # Stack each logged curve plus a 20-tap box-smoothed train loss
        # (np.convolve with 'same' keeps the original length) as plot columns.
        win_loss_Y = onp.stack([log[key] for key in aux_keys] +
                               [onp.convolve(log['loss_train'], [0.05] * 20, 'same')],
                               axis=1)
        # NOTE(review): chunk is truncated mid-`if`; the window-creation branch
        # continues beyond this view.
        if win_loss is None: