Esempio n. 1
0
 def template(x, init_key=None):
   """Applies a stateful `ScalarMul` layer to `x` twice and returns element 0.

   The layer is initialized through `state.init` under the name 'scalar_mul',
   invoked twice via `call_and_update`, and the module returned by the second
   call is written back with `state.assign` under the same name.

   Args:
     x: Input handed to the layer; the final output is indexed with `[0]`.
     init_key: Key forwarded to the initializer produced by `state.init`.

   Returns:
     Element 0 of the output of the second layer application.
   """
   make_layer = state.init(ScalarMul(2 * jnp.ones(1)), name='scalar_mul')
   scalar_mul = make_layer(init_key, x)
   # Each call returns the output together with the updated module; the
   # second call consumes the output of the first.
   out, scalar_mul = scalar_mul.call_and_update(x)
   out, scalar_mul = scalar_mul.call_and_update(out)
   state.assign(scalar_mul, name='scalar_mul')
   return out[0]
Esempio n. 2
0
 def update(params, updates, init_key=None):
   """Rescales `updates` by `sqrt(nu - mu**2 + eps)` using stored moments.

   Two state variables, 'mu' and 'nu', start as zero trees shaped like
   `updates` and are refreshed through `_update_moment` with orders 1 and 2
   (presumably running first/second moment estimates -- confirm against
   `_update_moment`). `decay` and `eps` come from the enclosing scope.

   Args:
     params: Ignored.
     updates: Tree of update values to rescale.
     init_key: Key split in two to initialize the 'mu' and 'nu' variables.

   Returns:
     A tree with the structure of `updates` holding the rescaled values.

   Raises:
     ValueError: If `init_key` is `None`.
   """
   del params
   if init_key is None:
     raise ValueError('`init_key` cannot be `None`.')

   def zero_tree():
     # Fresh zero tree with the same leaf shapes as `updates`.
     return tree_map(lambda leaf: np.zeros(leaf.shape), updates)

   first_key, second_key = random.split(init_key)
   first = state.variable(zero_tree(), key=first_key, name='mu')
   second = state.variable(zero_tree(), key=second_key, name='nu')
   first = state.assign(_update_moment(updates, first, decay, 1), name='mu')
   second = state.assign(_update_moment(updates, second, decay, 2), name='nu')

   def rescale(grad, m, n):
     return grad / np.sqrt(n - m**2 + eps)

   return tree_map(rescale, updates, first, second)
Esempio n. 3
0
 def update(params, updates, init_key=None):
   """Adds scaled normal noise to every leaf of `updates`.

   Keeps two state variables: 'count' (starts at 0.) and 'rng_key' (starts
   at `random.PRNGKey(seed)`). The noise scale is
   `eta / (1 + count)**gamma`; `seed`, `eta` and `gamma` come from the
   enclosing scope.

   Args:
     params: Ignored.
     updates: Tree of update values to perturb.
     init_key: Key split in two to initialize 'count' and 'rng_key'.

   Returns:
     A tree with the structure of `updates` holding the perturbed values.

   Raises:
     ValueError: If `init_key` is `None`.
   """
   del params
   # Fail early with the same message the sibling transforms use, instead of
   # an opaque error from `random.split`.
   if init_key is None:
     raise ValueError('`init_key` cannot be `None`.')
   count_key, seed_key = random.split(init_key)
   count = state.variable(0., key=count_key, name='count')
   rng_key = state.variable(random.PRNGKey(seed), key=seed_key, name='rng_key')
   num_vars = len(tree_leaves(updates))
   treedef = tree_structure(updates)
   variance = eta / (1 + count)**gamma
   # One key per leaf plus one extra (index 0) that becomes the next 'rng_key'.
   all_keys = random.split(rng_key, num_vars + 1)
   noise = tree_map(lambda g, k: random.normal(k, shape=g.shape), updates,
                    tree_unflatten(treedef, all_keys[1:]))
   # NOTE: the noise is multiplied by `variance` itself, not its square root.
   updates = tree_map(lambda g, n: g + variance * n, updates, noise)
   # Tie the state writes to the returned updates; presumably this keeps the
   # assignments from being dropped -- confirm against `primitive.tie_all`.
   updates, count, rng_key = primitive.tie_all(
       updates, state.assign(count + 1., name='count', key=updates),
       state.assign(all_keys[0], name='rng_key', key=updates))
   return updates
Esempio n. 4
0
 def update(params, updates, init_key=None):
   """Replaces `updates` with `mu_hat / (sqrt(nu_hat) + eps)`.

   Maintains three state variables: 'count' (starts at 0.), and 'mu' / 'nu'
   (zero trees shaped like `updates`, refreshed through `_update_moment`
   with orders 1 and 2 and decays `b1` / `b2`). After 'count' is
   incremented, both moments are divided by `1 - b**count`. `b1`, `b2` and
   `eps` come from the enclosing scope.

   Args:
     params: Ignored.
     updates: Tree of update values.
     init_key: Key split in three to initialize 'count', 'mu' and 'nu'.

   Returns:
     A tree with the structure of the stored moments holding the new updates.

   Raises:
     ValueError: If `init_key` is `None`.
   """
   del params
   if init_key is None:
     raise ValueError('`init_key` cannot be `None`.')

   def zero_tree():
     # Fresh zero tree with the same leaf shapes as `updates`.
     return tree_map(lambda leaf: np.zeros(leaf.shape), updates)

   step_key, first_key, second_key = random.split(init_key, 3)
   step = state.variable(0., key=step_key, name='count')
   first = state.variable(zero_tree(), key=first_key, name='mu')
   second = state.variable(zero_tree(), key=second_key, name='nu')
   first = state.assign(_update_moment(updates, first, b1, 1), name='mu')
   second = state.assign(_update_moment(updates, second, b2, 2), name='nu')
   # The corrections below use the already-incremented counter.
   step = state.assign(step + 1., name='count', key=updates)

   def correct_first(t):
     return t / (1 - b1**step)

   def correct_second(t):
     return t / (1 - b2**step)

   first_hat = tree_map(correct_first, first)
   second_hat = tree_map(correct_second, second)

   def ratio(m, v):
     return m / (np.sqrt(v) + eps)

   return tree_map(ratio, first_hat, second_hat)
Esempio n. 5
0
 def update(params, updates, init_key=None):
     """Multiplies every leaf of `updates` by `step_size_fn(count)`.

     'count' is a state variable that starts at 0. and is incremented by one
     on each call; `step_size_fn` comes from the enclosing scope.

     Args:
         params: Ignored.
         updates: Tree of update values to scale.
         init_key: Key used to initialize the 'count' variable.

     Returns:
         A tree with the structure of `updates` holding the scaled values.

     Raises:
         ValueError: If `init_key` is `None`.
     """
     del params
     if init_key is None:
         raise ValueError('`init_key` cannot be `None`.')
     step = state.variable(0., key=init_key, name='count')

     def scale(grad):
         # Uses the counter value from before the increment.
         return step_size_fn(step) * grad

     scaled = tree_map(scale, updates)
     next_step = state.assign(step + 1., name='count', key=scaled)
     scaled, _ = primitive.tie_all(scaled, next_step)
     return scaled
Esempio n. 6
0
 def update(params, updates, init_key=None):
   """Divides every leaf of `updates` by `sqrt(nu + eps)`.

   'nu' is a state variable that starts as a zero tree shaped like `updates`
   and is refreshed through `_update_moment` with order 2; `decay` and `eps`
   come from the enclosing scope.

   Args:
     params: Ignored.
     updates: Tree of update values to rescale.
     init_key: Key used to initialize the 'nu' variable.

   Returns:
     A tree with the structure of `updates` holding the rescaled values.

   Raises:
     ValueError: If `init_key` is `None`.
   """
   del params
   if init_key is None:
     raise ValueError('`init_key` cannot be `None`.')
   zeros = tree_map(lambda leaf: np.zeros(leaf.shape), updates)
   second = state.variable(zeros, key=init_key, name='nu')
   second = state.assign(_update_moment(updates, second, decay, 2), name='nu')

   def normalize(grad, sq):
     return grad / (np.sqrt(sq + eps))

   return tree_map(normalize, updates, second)
Esempio n. 7
0
    def update(params, updates, init_key=None):
        """Accumulates `updates` and emits the sum only every `k`-th call.

        Keeps two state variables: 'count' (starts at 0.) and 'grad_acc'
        (a zero tree shaped like `updates`). On calls where
        `count % k != k - 1` the returned updates are multiplied by zero;
        `k` comes from the enclosing scope.

        Args:
            params: Ignored.
            updates: Tree of update values to accumulate.
            init_key: Key split in two to initialize 'count' and 'grad_acc'.

        Returns:
            A tree with the structure of `updates`: the accumulated values
            when `count % k == k - 1`, otherwise those values times zero.

        Raises:
            ValueError: If `init_key` is `None`.
        """
        del params
        # Fail early with the same message the sibling transforms use,
        # instead of an opaque error from `random.split`.
        if init_key is None:
            raise ValueError('`init_key` cannot be `None`.')
        count_key, grad_acc_key = random.split(init_key)
        count = state.variable(0., key=count_key, name='count')
        grad_acc = state.variable(tree_map(lambda g: np.zeros(g.shape),
                                           updates),
                                  key=grad_acc_key,
                                  name='grad_acc')

        c = count % k
        # `acc` is false at the start of each window of `k` calls, which
        # drops the previous accumulator and restarts from `g`.
        acc = c != 0
        grad_acc = state.assign(tree_map(lambda g, ga: acc * ga + g, updates,
                                         grad_acc),
                                name='grad_acc')
        # Only the last call of each window lets the accumulated value through.
        emit = c == (k - 1)
        updates = tree_map(lambda ga: emit * ga, grad_acc)
        updates, count = primitive.tie_all(
            updates, state.assign(count + 1., name='count', key=updates))
        return updates
Esempio n. 8
0
 def update(params, updates, init_key=None):
     """Folds `updates` into a decayed running trace and returns the result.

     'trace' is a state variable that starts as a zero tree shaped like
     `updates` and is replaced by `updates + decay * trace`. When `nesterov`
     is truthy, the same combination is applied a second time using the new
     trace. `decay` and `nesterov` come from the enclosing scope.

     Args:
         params: Ignored.
         updates: Tree of update values.
         init_key: Key used to initialize the 'trace' variable.

     Returns:
         The new trace, or `updates + decay * new_trace` when `nesterov` is
         truthy.

     Raises:
         ValueError: If `init_key` is `None`.
     """
     del params
     if init_key is None:
         raise ValueError('`init_key` cannot be `None`.')
     zeros = tree_map(lambda leaf: np.zeros(leaf.shape), updates)
     previous = state.variable(zeros, key=init_key, name='trace')

     def combine(grad, acc):
         return grad + decay * acc

     new_trace = state.assign(tree_map(combine, updates, previous),
                              name='trace')
     if nesterov:
         return tree_map(combine, updates, new_trace)
     return new_trace
Esempio n. 9
0
    def run(params, init_key=None):
        """Runs the 'opt' module for `num_iters` scan steps and returns params.

        The module is built from `gradient_descent(update, objective)` via
        `state.init`, advanced with `lax.scan`, and the final module is
        stored back with `state.assign` under the name 'opt'. `update`,
        `objective` and `num_iters` come from the enclosing scope.

        Args:
            params: Initial parameters handed to the optimizer module.
            init_key: Key forwarded to the initializer from `state.init`.

        Returns:
            The parameters produced by the last scan step.
        """
        make_opt = state.init(gradient_descent(update, objective), name='opt')
        optimizer = make_opt(init_key, params)

        def scan_step(carry, _):
            current_opt, current_params = carry
            current_params, current_opt = current_opt.call_and_update(
                current_params)
            return (current_opt, current_params), ()

        scan_result = lax.scan(scan_step, (optimizer, params),
                               np.arange(num_iters))
        # Only the final carry matters; the per-step outputs are empty tuples.
        optimizer, params = scan_result[0]
        stored = state.assign(optimizer, name='opt')
        _, params = primitive.tie_all(stored, params)
        return params
Esempio n. 10
0
    def step(key, state, init_key=None):
        """Runs the 'kernel' module for `num_steps` scan steps.

        The module is built from `kernel_fn` via `st.init`, advanced with
        `lax.scan` over keys obtained from `random.split(key, num_steps)`,
        and the final module is stored back with `st.assign` under the name
        'kernel'. After each kernel call, every entry of `callbacks` is
        invoked with the current kernel and state and tied to them through
        `primitive.tie_all`. `kernel_fn`, `callbacks` and `num_steps` come
        from the enclosing scope.

        Args:
            key: Key passed to the initializer and split into per-step keys.
            state: Initial state handed to the kernel.
            init_key: Key forwarded to the initializer from `st.init`.

        Returns:
            The per-step states collected by the scan, tied to the kernel
            assignment via `primitive.tie_in`.
        """
        kernel = st.init(kernel_fn, name='kernel')(init_key, key, state)

        def scan_step(carry, step_key):
            current_kernel, current_state = carry
            current_state, current_kernel = current_kernel.call_and_update(
                step_key, current_state)
            for callback in callbacks:
                tied = primitive.tie_all(
                    current_kernel, current_state,
                    callback(current_kernel, current_state))
                # The callback's own output is discarded after tying.
                current_kernel, current_state, _ = tied
            return (current_kernel, current_state), current_state

        step_keys = random.split(key, num_steps)
        (kernel, _), states = lax.scan(scan_step, (kernel, state), step_keys)
        return primitive.tie_in(st.assign(kernel, name='kernel'), states)