def optimizer_fn(net_params, step_size=1e-3):
  opt = trax_opt.Adam(step_size=step_size, b1=0.9, b2=0.999, eps=1e-08)
  opt_init = lambda x: (x, opt.tree_init(x))
  opt_update = lambda i, g, s: opt.tree_update(i, g, s[0], s[1])
  get_params = lambda x: x[0]
  opt_state = opt_init(net_params)
  return opt_state, opt_update, get_params
def optimizer_fn(net_params, learning_rate=1e-3):
  """Exposes a convenient interface for the optimizer.

  Args:
    net_params: A nested structure of network parameters.
    learning_rate: Learning rate.

  Returns:
    A tuple (opt_state, opt_update, get_params), where:
      opt_state: Pair (net_params, opt_slots) - initial optimization state.
      opt_update: Function (step, grads, opt_state) -> opt_state doing one
        optimization step.
      get_params: Function opt_state -> net_params for extracting the network
        parameters from the optimization state.
  """
  opt = trax_opt.Adam(learning_rate=learning_rate, b1=0.9, b2=0.999, eps=1e-08)
  (init_slots, init_nontrainable_slots) = opt.tree_init(net_params)
  init_state = (net_params, init_slots)

  def opt_update(step, grads, opt_state):
    (params, slots) = opt_state
    # Pass the initial nontrainable_slots as we don't tune them during
    # training. (yet!)
    return opt.tree_update(step, grads, params, slots, init_nontrainable_slots)

  def get_params(opt_state):
    (params, _) = opt_state
    return params

  return init_state, opt_update, get_params
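
# --- Usage sketch (not part of the original source) ---
# A minimal example of how the (opt_state, opt_update, get_params) triple
# documented above could drive a training loop, relying only on the interface
# described in the docstring. `loss_fn`, `net_params`, and `batches` are
# hypothetical stand-ins assumed to be defined elsewhere.
import jax

opt_state, opt_update, get_params = optimizer_fn(net_params, learning_rate=1e-3)
for step, batch in enumerate(batches):
  params = get_params(opt_state)                  # Extract current parameters.
  grads = jax.grad(loss_fn)(params, batch)        # Gradients of a scalar loss w.r.t. params.
  opt_state = opt_update(step, grads, opt_state)  # One optimization step -> new (params, slots).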