コード例 #1
0
# Select the loss (and, for classification datasets, the accuracy metric)
# matching args.dataset.
if args.dataset == 'sinusoid':
    def loss(fx, y_hat):
        """Return half the mean squared error between `fx` and `y_hat`."""
        return 0.5 * np.mean((fx - y_hat) ** 2)
elif args.dataset in ['omniglot', 'circle']:
    def loss(fx, targets):
        """Return the cross-entropy of logits `fx` against `targets`.

        The elementwise product with log-softmax(fx) is summed over all
        entries and divided by `targets.shape[0]` (presumably the batch
        size). Assumes `targets` is one-hot along the last axis -- TODO
        confirm.
        """
        return -np.sum(logsoftmax(fx) * targets) / targets.shape[0]

    def acc(fx, targets):
        """Return the fraction of rows whose predicted argmax matches the target argmax."""
        return np.mean(
            np.argmax(logsoftmax(fx), axis=-1) == np.argmax(targets, axis=-1))

    param_acc = jit(lambda p, x, y: acc(f(p, x), y))
else:
    # Name the offending value instead of raising a message-less error.
    raise ValueError(f'Unsupported dataset: {args.dataset!r}')

def _loss_of_params(p, x, y):
    """Evaluate `loss` on the network output f(p, x) against targets y."""
    return loss(f(p, x), y)


# Jitted gradient (w.r.t. the first argument, p) and jitted value of the loss.
grad_loss = jit(grad(_loss_of_params))
param_loss = jit(_loss_of_params)

# Optimizers for the outer and inner loops. select_opt(alg, step_size) returns
# a callable that yields an (init, update, get_params) triple.
# TODO: separate optimizers for nonlinear and linear?
_outer_opt_factory = select_opt(args.outer_opt_alg, args.outer_step_size)
outer_opt_init, outer_opt_update, outer_get_params = _outer_opt_factory()

_inner_opt_factory = select_opt(args.inner_opt_alg, args.inner_step_size)
inner_opt_init, inner_opt_update, inner_get_params = _inner_opt_factory()

# A single fixed evaluation task, reused so plotted evaluations stay consistent.
# NOTE(review): there is no branch for 'circle' (which the loss selection
# accepts), so task_eval stays undefined for that dataset -- confirm intended.
if args.dataset == 'sinusoid':
    task_eval = sinusoid_task(
        n_support=args.n_support,
        n_query=args.n_query,
        noise_std=args.noise_std,
    )
elif args.dataset == 'omniglot':
    omniglot_splits = load_omniglot(n_support=args.n_support, n_query=args.n_query)
    # Evaluation tasks are drawn from the 'val' split.
    val_split = omniglot_splits['val']
    task_eval = omniglot_task(
        val_split,
        n_way=args.n_way,
        n_support=args.n_support,
        n_query=args.n_query,
    )
コード例 #2
0
# Data: a task sampler with episode sizes and noise level bound from `args`.
if args.dataset == 'sinusoid':
    task_fn = partial(
        sinusoid_task,
        n_support=args.n_support,
        n_query=args.n_query,
        noise_std=args.noise_std,
    )
else:
    # Only the sinusoid dataset is wired up in this script (the network and
    # loss sections below handle 'sinusoid' only). Fail fast here rather than
    # leaving task_fn, f, params and loss undefined for later code.
    raise ValueError(f'Unsupported dataset: {args.dataset!r}')

# Build the network via `mlp` with a single output unit; all other
# architecture options come from `args`.
if args.dataset == 'sinusoid':
    net_init, f = mlp(
        n_output=1,
        n_hidden_layer=args.n_hidden_layer,
        n_hidden_unit=args.n_hidden_unit,
        bias_coef=args.bias_coef,
        activation=args.activation,
        norm=args.norm,
    )
    # Fixed key 42, so initialization is identical across runs. net_init
    # returns a pair; only the second element (the parameters) is kept.
    # input_shape (-1, 1) presumably means (batch, 1) -- TODO confirm.
    _init_rng = random.PRNGKey(42)
    _, params = net_init(rng=_init_rng, input_shape=(-1, 1))

# Optimizer: select_opt(...) returns a callable that yields the
# (init, update, get_params) triple.
_opt_factory = select_opt(args.opt_alg, args.step_size)
opt_init, opt_update, get_params = _opt_factory()

# Loss function (only the sinusoid dataset is handled).
if args.dataset == 'sinusoid':
    def loss(fx, y):
        """Return half the mean squared error between predictions fx and targets y."""
        squared_error = (fx - y) ** 2
        return 0.5 * np.mean(squared_error)

def param_loss(p, x, y):
    """Evaluate `loss` on the network output f(p, x) against targets y."""
    return loss(f(p, x), y)


# Gradient of param_loss with respect to its first argument, the parameters p.
grad_loss = grad(param_loss)

def pytree_to_array(pytree):
    """Flatten every leaf of `pytree` and concatenate them into one 1-D array.

    Leaves are taken in `tree_flatten` order. Scalar leaves (e.g. Python
    floats) are accepted. An empty pytree yields an empty array instead of
    raising.

    Args:
        pytree: a pytree whose leaves are arrays or scalars.

    Returns:
        A 1-D array holding every leaf entry, in flatten order.
    """
    leaves = tree_flatten(pytree)[0]
    if not leaves:
        # concatenate() raises on an empty sequence.
        return np.zeros((0,))
    # np.ravel (unlike the .flatten() method) also accepts Python scalars.
    return np.concatenate([np.ravel(leaf) for leaf in leaves])


# if args.stop_gradient:
#     def loss_alignment(g_support, g_query):