Example 1
# Assumes Trax's optimizers module is imported as `trax_opt`
# (e.g. `from trax import optimizers as trax_opt`; the exact path may vary by version).
def optimizer_fn(net_params, step_size=1e-3):
    opt = trax_opt.Adam(step_size=step_size, b1=0.9, b2=0.999, eps=1e-08)
    # Optimization state is the pair (params, optimizer slots).
    opt_init = lambda x: (x, opt.tree_init(x))
    # One optimization step: unpack (params, slots) and return the updated pair.
    opt_update = lambda i, g, s: opt.tree_update(i, g, s[0], s[1])
    # Extract the network parameters from the optimization state.
    get_params = lambda x: x[0]
    opt_state = opt_init(net_params)
    return opt_state, opt_update, get_params
Example 2
# As above, assumes `trax_opt` refers to Trax's optimizers module.
def optimizer_fn(net_params, learning_rate=1e-3):
  """Exposes a convenient interface for the optimizer.

  Args:
    net_params: A nested structure of network parameters.
    learning_rate: Learning rate.

  Returns:
    A tuple (opt_state, opt_update, get_params), where:
      opt_state: Pair (net_params, opt_slots) - initial optimization state.
      opt_update: Function (step, grads, opt_state) -> opt_state doing one
        optimization step.
      get_params: Function opt_state -> net_params for extracting the network
        parameters from the optimization state.
  """
  opt = trax_opt.Adam(learning_rate=learning_rate, b1=0.9, b2=0.999, eps=1e-08)
  (init_slots, init_nontrainable_slots) = opt.tree_init(net_params)
  init_state = (net_params, init_slots)

  def opt_update(step, grads, opt_state):
    (params, slots) = opt_state
    # Pass the initial nontrainable_slots as we don't tune them during training.
    # (yet!)
    return opt.tree_update(step, grads, params, slots, init_nontrainable_slots)

  def get_params(opt_state):
    (params, _) = opt_state
    return params

  return init_state, opt_update, get_params
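
For context, here is a minimal sketch of how the returned (opt_state, opt_update, get_params) triple could drive a training loop. It assumes the optimizer_fn from Example 2 is in scope with a working trax_opt import; loss_fn, init_params, and the dummy batch are made up purely for illustration and are not part of the snippets above.

import jax
import jax.numpy as jnp

# Hypothetical loss for illustration only; any JAX-differentiable function
# of the parameters would work here.
def loss_fn(params, batch):
  inputs, targets = batch
  preds = jnp.dot(inputs, params)  # stand-in for a real network apply
  return jnp.mean((preds - targets) ** 2)

init_params = jnp.zeros((4, 1))  # stand-in for real network parameters
opt_state, opt_update, get_params = optimizer_fn(init_params, learning_rate=1e-3)

def train_step(step, opt_state, batch):
  params = get_params(opt_state)
  grads = jax.grad(loss_fn)(params, batch)
  # Per the docstring, opt_update returns the new (params, slots) state.
  return opt_update(step, grads, opt_state)

for step in range(10):
  batch = (jnp.ones((8, 4)), jnp.zeros((8, 1)))  # dummy data
  opt_state = train_step(step, opt_state, batch)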