for i, task_batch in tqdm(enumerate(
        taskbatch(task_fn=task_fn,
                  batch_size=args.task_batch_size,
                  n_task=args.n_train_task)),
        total=args.n_train_task // args.task_batch_size):
    aux = dict()

    # ntk: periodically measure the kernel on the evaluation task
    if i == 0 or (i + 1) % (args.n_train_task // args.task_batch_size // ntk_frequency) == 0:
        ntk = tangents.ntk(f, batch_size=100)(
            outer_get_params(outer_state), task_eval['x_train'])
        aux['ntk_train_rank_eval'] = onp.linalg.matrix_rank(ntk)

        f_lin = tangents.linearize(f, outer_get_params(outer_state_lin))
        ntk_lin = tangents.ntk(f_lin, batch_size=100)(
            outer_get_params(outer_state_lin), task_eval['x_train'])
        aux['ntk_train_rank_eval_lin'] = onp.linalg.matrix_rank(ntk_lin)

        log.append([(key, aux[key]) for key in win_rank_eval_keys])

        # spectrum
        evals, evecs = onp.linalg.eigh(ntk)  # eigenvectors are columns
        for j in range(len(evals)):
            aux[f'ntk_spectrum_{j}_eval'] = evals[j]
        log.append([(key, aux[key]) for key in win_spectrum_eval_keys])

        evals = evals.clip(min=1e-10)
        ind = onp.arange(len(evals)) + 1  # +1 so rank starts at 1, since we take logs
        ind = ind[::-1]  # eigh sorts ascending; reverse so rank 1 is the largest eigenvalue
        X = onp.stack([ind, evals], axis=1)
        # don't ignore the clipped eigenvalues when doing linear regression
        logX = onp.log10(X)
        slope, intercept, r_value, p_value, std_err = stats.linregress(logX)
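# For reference: the linregress over log10(rank) vs. log10(eigenvalue) above
# fits a power law to the NTK spectrum, so `slope` is the empirical decay
# exponent. A minimal self-contained sanity check on a synthetic spectrum
# (everything below is illustrative, not part of the training script):

import numpy as onp
from scipy import stats

k = onp.arange(1, 101)                           # eigenvalue rank, largest first
synthetic_evals = k.astype(onp.float64) ** -2.0  # known decay: lambda_k = k^-2
logX = onp.log10(onp.stack([k, synthetic_evals], axis=1))
slope, intercept, r_value, p_value, std_err = stats.linregress(logX)
# slope recovers -2.0 and r_value**2 is 1.0 for this clean power-law spectrum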
plotter = VisdomPlotter(viz)
plot_update_period = 100
eval_period = 1
total_iters = args.n_train_task // args.task_batch_size

for i, task_batch in tqdm(enumerate(
        taskbatch(task_fn=task_fn,
                  batch_size=args.task_batch_size,
                  n_task=args.n_train_task)),
        total=total_iters):
    # optimization step
    state, aux = step(i, state, task_batch=(
        task_batch['x_train'], task_batch['y_train'],
        task_batch['x_test'], task_batch['y_test']))

    # append iteration and training aux info to log
    log.append([('update', i)])
    log.append([(key, aux[key]) for key in training_keys])

    if i == 0 or (i + 1) % eval_period == 0:
        aux_eval = eval(i, state)
        log.append([(key, aux_eval[key]) for key in eval_keys])

    if (i + 1) % plot_update_period == 0:
        plotter.log_to_line(
            win_name='loss',
            log=log,
            plot_keys=win_loss_keys,
            title='losses, training tasks',
            xlabel='update',
            ylabel='loss',
            X=log['update'],
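# `VisdomPlotter.log_to_line` is defined elsewhere in the repo; below is a
# hedged sketch of what it plausibly does, assuming only visdom's documented
# `viz.line` API. This is an illustration, not the repo's implementation.

class VisdomPlotterSketch:
    def __init__(self, viz):
        self.viz = viz
        self.wins = {}  # window name -> visdom window handle

    def log_to_line(self, win_name, log, plot_keys, title, xlabel, ylabel, X):
        # one column per logged key; visdom draws each column as a line
        Y = onp.stack([onp.asarray(log[key]) for key in plot_keys], axis=1)
        opts = dict(title=title, xlabel=xlabel, ylabel=ylabel,
                    legend=list(plot_keys))
        if win_name not in self.wins:
            self.wins[win_name] = self.viz.line(X=onp.asarray(X), Y=Y, opts=opts)
        else:
            self.viz.line(X=onp.asarray(X), Y=Y, win=self.wins[win_name],
                          opts=opts, update='replace')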
        raise ValueError  # unrecognized dataset (closes the check preceding this excerpt)

for i, task_batch in tqdm(enumerate(
        taskbatch(task_fn=task_fn,
                  batch_size=args.task_batch_size,
                  n_task=args.n_train_task)),
        total=args.n_train_task // args.task_batch_size):
    outer_state, aux = outer_step(
        i=i,
        state=outer_state,
        task_batch=(task_batch['x_train'], task_batch['y_train'],
                    task_batch['x_test'], task_batch['y_test']))
    log.append([('update', i)])
    log.append([(key, aux[key]) for key in all_plot_keys])

    if (i + 1) % (args.n_train_task // args.task_batch_size // 100) == 0:
        if args.dataset == 'sinusoid':
            title = 'maml sinusoid regression'
            ylabel = 'half l2 loss'
        elif args.dataset == 'omniglot':
            title = f'maml omniglot {args.n_way}-way {args.n_support}-shot classification'
            ylabel = 'cross-entropy loss'
        else:
            raise ValueError(f'unrecognized dataset: {args.dataset}')
        plotter.log_to_line(
            win_name='loss',
            log=log,
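# `taskbatch` is the task iterator shared by all of these training loops. A
# hypothetical sketch of its contract (the real generator lives elsewhere in
# the repo): yield n_task // batch_size dicts of stacked task arrays.

def taskbatch_sketch(task_fn, batch_size, n_task, **task_kwargs):
    assert n_task % batch_size == 0
    for _ in range(n_task // batch_size):
        tasks = [task_fn(**task_kwargs) for _ in range(batch_size)]
        # keys are e.g. 'x_train', 'y_train', 'x_test', 'y_test', each array
        # gaining a leading batch dimension of size batch_size
        yield {key: onp.stack([task[key] for task in tasks]) for key in tasks[0]}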
    # tail of the linearized step function: apply the linearized-loss gradient
    return opt_apply(i, grad_loss_lin(params, x_train, y_train), state), l_train, l_test

win_loss, win_rmse = None, None
for i in tqdm(range(args.n_inner_step)):
    rmse_train = rmse(state, state_lin, x_train)
    rmse_test = rmse(state, state_lin, x_test)

    state, l_train, l_test = step(i, state)
    state_lin, l_train_lin, l_test_lin = step_lin(i, state_lin)

    # analytic linearized dynamics after (i + 1) gradient steps
    fx_train, fx_test = predictor(fx_train_ana_init, fx_test_ana_init,
                                  args.inner_step_size * (i + 1))
    l_train_lin_ana = loss(fx_train, y_train)
    l_test_lin_ana = loss(fx_test, y_test)

    log.append([('iteration', i)])
    log.append([('loss_train', l_train), ('loss_test', l_test),
                ('loss_train_lin', l_train_lin), ('loss_test_lin', l_test_lin)])
    log.append([('loss_train_lin_ana', l_train_lin_ana),
                ('loss_test_lin_ana', l_test_lin_ana)])
    log.append([('rmse_train', rmse_train), ('rmse_test', rmse_test)])

    if (i + 1) % (args.n_inner_step // 100) == 0:
        win_loss_keys = [
            'loss_train', 'loss_test',
            'loss_train_lin', 'loss_test_lin',
            'loss_train_lin_ana', 'loss_test_lin_ana'
        ]
        win_rmse_keys = ['rmse_train', 'rmse_test']
        if win_loss is None:
            win_loss = viz.line(
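# `rmse(state, state_lin, x)` measures how far the trained network has drifted
# from its linearization. A hedged sketch, assuming apply functions `f` /
# `f_lin` and parameter getters matching the two optimizers above; all names
# here are assumptions, not the script's definitions:

def rmse_sketch(state, state_lin, x):
    fx = f(get_params(state), x)                  # nonlinear network output
    fx_lin = f_lin(get_params_lin(state_lin), x)  # linearized network output
    return onp.sqrt(onp.mean((fx - fx_lin) ** 2))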
log_eval = Log(keys=['update'] + aux_eval_keys)
win_loss, win_loss_eval = None, None

_, params = net_init(rng=random.PRNGKey(42), input_shape=(-1, 1))
state = outer_opt_init(params)

for i, task_batch in tqdm(enumerate(
        taskbatch(task_fn=sinusoid_task,
                  batch_size=args.task_batch_size,
                  n_task=args.n_train_task,
                  n_support=args.n_support,
                  n_query=args.n_query))):
    state, aux = outer_step(
        i=i,
        state=state,
        task=(task_batch['x_train'], task_batch['y_train'],
              task_batch['x_test'], task_batch['y_test']))
    aux, aux_eval = aux

    log.append([(key, aux[key]) for key in aux_keys] + [('update', i)])
    log_eval.append([(key, aux_eval[key]) for key in aux_eval_keys] + [('update', i)])

    if (i + 1) % (args.n_train_task // args.task_batch_size // 100) == 0:
        win_loss_X = log['update']
        # one column per training key, plus a smoothed train loss
        # ([0.05] * 20 is a 20-tap moving-average kernel summing to 1.0)
        win_loss_Y = onp.stack(
            [log[key] for key in aux_keys]
            + [onp.convolve(log['loss_train'], [0.05] * 20, 'same')],
            axis=1)
        if win_loss is None:
            win_loss = viz.line(
                X=win_loss_X,
                Y=win_loss_Y,
                opts=dict(title='maml sinusoid regression on meta-training tasks',
                          xlabel='update',
                          ylabel='half l2 loss',
                          legend=['train', 'test', 'train_lin', 'test_lin', 'train_smooth'],
                          dash=onp.array(['dot' for _ in range(len(aux_keys) + 1)]))
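# `Log` is the append-only metric store used throughout. A minimal sketch
# consistent with how it is indexed above (an assumption about the real
# class, which is defined elsewhere in the repo):

class LogSketch:
    def __init__(self, keys):
        self.data = {key: [] for key in keys}

    def append(self, pairs):
        # pairs is a list of (key, scalar) tuples, as in the loops above
        for key, value in pairs:
            self.data[key].append(value)

    def __getitem__(self, key):
        # log[key] returns the series as an array, ready to pass to viz.line
        return onp.asarray(self.data[key])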