Example #1
    return softmax_xent, (logits, updated_state)


@partial(jax.jit, static_argnames=('is_training', ))
def loss(params, state, inputs, targets, theta, is_training):
    """Scalar loss wrapper: keeps only the loss and updated state (drops the logits) so it can be passed to value_and_grad."""
    softmax_xent, (logits,
                   updated_state) = loss_with_logits(params, state, inputs,
                                                     targets, theta,
                                                     is_training)
    return softmax_xent, updated_state


loss_and_grad = jax.jit(jax.value_and_grad(loss, has_aux=True),
                        static_argnames=('is_training', ))
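
# Usage sketch (assumption: the names `params`, `state`, `batch_inputs`,
# `batch_targets`, and `theta` below are hypothetical placeholders, not part of
# this example). Because `loss` was wrapped with has_aux=True, the jitted
# wrapper returns the (loss, aux) pair together with the gradient w.r.t.
# `params`:
#
#   (loss_value, new_state), grads = loss_and_grad(
#       params, state, batch_inputs, batch_targets, theta, is_training=True)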

opt_funcs = inner_optim.init_optimizer(args.inner_optimizer)
reset_opt_params = opt_funcs['reset_opt_params']
opt_step = opt_funcs['opt_step']

default_values = {
    'lr': args.lr,
    'b1': args.b1,
    'b2': args.b2,
    'eps': args.eps,
    'mom': args.mom,
    'wd': args.wd,
    'mask': 0.5,
}


class InnerState(NamedTuple):
Example #2
                                                 jnp.array(list(range(K))))

  result = jnp.sum(loss_values)
  return result, state_updated


@partial(jax.jit, static_argnums=3)
def unroll_and_L_v(w, theta, inner_optim_params, K):
  """Unroll K inner optimization steps from w, then evaluate the validation loss L_v."""
  w_unrolled, _ = unroll(w, theta, inner_optim_params, K)
  return L_v(w_unrolled)


grad_unroll_and_L_v = jax.jit(
    jax.grad(unroll_and_L_v, argnums=1), static_argnums=3)
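
# Usage sketch (hypothetical names): `grad_unroll_and_L_v` gives the gradient
# of the K-step-unrolled validation loss with respect to the outer parameters
# `theta` (argnums=1); K must be a concrete Python int because it is static.
#
#   theta_grad = grad_unroll_and_L_v(w_init, theta_init, opt_hparams, 10)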

opt_funcs = inner_optim.init_optimizer('sgd')
reset_opt_params = opt_funcs['reset_opt_params']
opt_step = opt_funcs['opt_step']

init_opt_params = {'lr': 0.001, 'wd': 0.0}
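
# A self-contained sketch of what an init_optimizer('sgd')-style factory could
# return (assumption: the real inner_optim module is not shown in this example;
# the helper below is illustrative, not the library's implementation).
def _sgd_init_optimizer_sketch():
  def reset_opt_params(params, hparams):
    # Plain SGD carries no per-parameter state, so the "optimizer params" are
    # just a copy of the hyperparameters (e.g. {'lr': 0.001, 'wd': 0.0}).
    return dict(hparams)

  def opt_step(params, grads, opt_params):
    # One SGD step with L2 weight decay added to the gradient.
    lr, wd = opt_params['lr'], opt_params['wd']
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - lr * (g + wd * p), params, grads)
    return new_params, opt_params

  return {'reset_opt_params': reset_opt_params, 'opt_step': opt_step}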


class InnerState(NamedTuple):
  """Inner-loop state: parameters, optimizer state, step counter, and optional perturbation accumulators."""
  inner_state: jnp.ndarray
  inner_opt_state: Any
  t: jnp.ndarray
  pert_accums: Optional[jnp.ndarray] = None


def init_state_fn(rng):
  """Initialize the inner parameters."""