Exemplo n.º 1
0
 def body(carry, key):
     """Advance the kernel one step, then pass it through every callback.

     The `(carry, x) -> (carry, y)` signature looks like a `lax.scan` body;
     confirm against the caller.

     Args:
       carry: `(kernel, state)` pair from the previous step.
       key: Forwarded as the first argument to `kernel.call_and_update`
         (presumably a PRNG key -- TODO confirm).

     Returns:
       `((kernel, state), state)`: the new carry and the per-step output,
       which is the post-callback state.
     """
     kern, st = carry
     # call_and_update yields the new state first, then the updated kernel.
     st, kern = kern.call_and_update(key, st)
     for callback in callbacks:
         callback_out = callback(kern, st)
         # Tie the callback's output to the kernel/state; the third tied
         # value (the callback output itself) is discarded.
         kern, st, _ = primitive.tie_all(kern, st, callback_out)
     return (kern, st), st
Exemplo n.º 2
0
 def update(params, updates, init_key=None):
     """Scale every leaf of `updates` by `step_size_fn(count)`.

     `count` is a state variable named 'count' that starts at 0. and is
     assigned `count + 1.` after the scaled updates are built.

     Args:
       params: Ignored.
       updates: Pytree of updates to scale.
       init_key: Key passed to `state.variable` for the 'count' variable.

     Returns:
       The scaled pytree of updates.

     Raises:
       ValueError: If `init_key` is None.
     """
     del params  # Only the updates are transformed.
     if init_key is None:
         raise ValueError('`init_key` cannot be `None`.')
     step = state.variable(0., key=init_key, name='count')

     def scale_leaf(leaf):
         # step_size_fn is evaluated once per leaf, as in a per-leaf lambda.
         return step_size_fn(step) * leaf

     scaled = tree_map(scale_leaf, updates)
     # Tie the counter increment to the scaled updates; the returned counter
     # value itself is not needed here.
     scaled, _ = primitive.tie_all(
         scaled, state.assign(step + 1., name='count', key=scaled))
     return scaled
Exemplo n.º 3
0
    def run(params, init_key=None):
        """Run `num_iters` optimizer steps on `params` and return the result.

        The optimizer is `gradient_descent(update, objective)`, initialized via
        `state.init(..., name='opt')` with `init_key` and `params`. After the
        loop, its final value is written back with `state.assign(name='opt')`.

        Args:
          params: Initial parameters.
          init_key: Key used to initialize the optimizer state.

        Returns:
          The parameters after `num_iters` calls to `call_and_update`.
        """
        optimizer = state.init(gradient_descent(update, objective),
                               name='opt')(init_key, params)

        def step(carry, _):
            current_opt, current_params = carry
            # call_and_update returns the new params first, then the new opt.
            new_params, new_opt = current_opt.call_and_update(current_params)
            return (new_opt, new_params), ()

        (optimizer, params), _ = lax.scan(step, (optimizer, params),
                                          np.arange(num_iters))
        optimizer, params = primitive.tie_all(
            state.assign(optimizer, name='opt'), params)
        return params
Exemplo n.º 4
0
 def update(params, updates, init_key=None):
   """Add annealed Gaussian noise to every leaf of `updates`.

   Each leaf gets independent noise with variance `eta / (1 + count)**gamma`.
   `count` is a state variable named 'count' that starts at 0. and is
   incremented by 1. per call. Noise is drawn from a state variable named
   'rng_key', seeded with `random.PRNGKey(seed)` and replaced by a fresh
   split on every call.

   Args:
     params: Ignored.
     updates: Pytree of updates to perturb.
     init_key: PRNG key split to initialize the 'count' and 'rng_key' state
       variables. Must not be None.

   Returns:
     `updates` with noise added to each leaf.

   Raises:
     ValueError: If `init_key` is None.
   """
   del params
   if init_key is None:
     # Fail with a clear message instead of an obscure error from
     # random.split(None).
     raise ValueError('`init_key` cannot be `None`.')
   count_key, seed_key = random.split(init_key)
   count = state.variable(0., key=count_key, name='count')
   rng_key = state.variable(random.PRNGKey(seed), key=seed_key, name='rng_key')
   num_vars = len(tree_leaves(updates))
   treedef = tree_structure(updates)
   variance = eta / (1 + count)**gamma
   # Unit-normal samples must be scaled by the standard deviation, not the
   # variance, for the injected noise to actually have `variance`.
   stddev = variance**0.5
   # One key carries over as the next rng state; the rest, one per leaf.
   all_keys = random.split(rng_key, num_vars + 1)
   noise = tree_map(lambda g, k: random.normal(k, shape=g.shape), updates,
                    tree_unflatten(treedef, all_keys[1:]))
   updates = tree_map(lambda g, n: g + stddev * n, updates, noise)
   updates, count, rng_key = primitive.tie_all(
       updates, state.assign(count + 1., name='count', key=updates),
       state.assign(all_keys[0], name='rng_key', key=updates))
   return updates
Exemplo n.º 5
0
    def update(params, updates, init_key=None):
        """Accumulate updates and emit their sum once every `k` calls.

        A state variable 'grad_acc' holds the running sum. It is reset on the
        first call of each window of `k` calls (count % k == 0). Zeros are
        emitted on every call except the last of the window
        (count % k == k - 1), where the accumulated sum is emitted.

        Args:
          params: Ignored.
          updates: Pytree of updates to accumulate.
          init_key: PRNG key split to initialize the 'count' and 'grad_acc'
            state variables. Must not be None.

        Returns:
          A pytree shaped like `updates`: the accumulated sum on the last call
          of a window, zeros otherwise.

        Raises:
          ValueError: If `init_key` is None.
        """
        del params
        if init_key is None:
            # Fail with a clear message instead of an obscure error from
            # random.split(None).
            raise ValueError('`init_key` cannot be `None`.')
        count_key, grad_acc_key = random.split(init_key)
        count = state.variable(0., key=count_key, name='count')
        # Match each leaf's dtype so the accumulator (and hence the emitted
        # updates) is not silently promoted, e.g. float16 -> float32.
        grad_acc = state.variable(
            tree_map(lambda g: np.zeros(g.shape, dtype=g.dtype), updates),
            key=grad_acc_key,
            name='grad_acc')

        # Position of this call within the current window of `k` calls.
        c = count % k
        # False on the first call of a window, which discards the old sum.
        acc = c != 0
        grad_acc = state.assign(tree_map(lambda g, ga: acc * ga + g, updates,
                                         grad_acc),
                                name='grad_acc')
        # Only the last call of a window passes the accumulated sum through.
        emit = c == (k - 1)
        updates = tree_map(lambda ga: emit * ga, grad_acc)
        updates, count = primitive.tie_all(
            updates, state.assign(count + 1., name='count', key=updates))
        return updates