def body(carry, key):
  """Run one kernel step, then every callback, threading results through.

  Args:
    carry: A ``(kernel, state)`` pair produced by the previous step.
    key: Forwarded as the first argument of ``kernel.call_and_update``
      (presumably a PRNG key -- confirm against the caller).

  Returns:
    ``((kernel, state), state)``: the updated carry, plus the new state as
    this step's per-iteration output.
  """
  current_kernel, current_state = carry
  current_state, current_kernel = current_kernel.call_and_update(
      key, current_state)
  for callback in callbacks:
    callback_out = callback(current_kernel, current_state)
    # Tie the callback output to kernel/state; the tied callback value itself
    # is discarded. NOTE(review): presumably this keeps the callback from
    # being pruned as unused -- confirm tie_all semantics.
    current_kernel, current_state, _ = primitive.tie_all(
        current_kernel, current_state, callback_out)
  new_carry = (current_kernel, current_state)
  return new_carry, current_state
def update(params, updates, init_key=None):
  """Scale every update leaf by ``step_size_fn`` of the iteration counter.

  Args:
    params: Unused.
    updates: Pytree of updates; each leaf is multiplied by
      ``step_size_fn(count)``.
    init_key: Key used to initialize the ``count`` state variable.

  Returns:
    The scaled updates, tied to the assignment of ``count + 1.``.

  Raises:
    ValueError: If ``init_key`` is ``None``.
  """
  del params
  if init_key is None:
    raise ValueError('`init_key` cannot be `None`.')
  step = state.variable(0., key=init_key, name='count')
  scaled = tree_map(lambda leaf: step_size_fn(step) * leaf, updates)
  next_step = state.assign(step + 1., name='count', key=scaled)
  scaled, _ = primitive.tie_all(scaled, next_step)
  return scaled
def run(params, init_key=None):
  """Run ``num_iters`` optimizer steps and persist the optimizer state.

  The optimizer is ``gradient_descent(update, objective)``, initialized
  through ``state.init`` under the name ``'opt'``. Its final value is written
  back with ``state.assign`` under the same name.

  Args:
    params: Initial parameters passed to the optimizer's initialization.
    init_key: Key forwarded to the initialization returned by ``state.init``.

  Returns:
    The parameters after the last step, tied to the optimizer-state
    assignment.
  """
  optimizer_fn = gradient_descent(update, objective)
  opt = state.init(optimizer_fn, name='opt')(init_key, params)

  def step(carry, unused_x):
    del unused_x  # The scanned-over indices are only used to fix the length.
    optimizer, current = carry
    current, optimizer = optimizer.call_and_update(current)
    return (optimizer, current), ()

  final_carry, _ = lax.scan(step, (opt, params), np.arange(num_iters))
  opt, params = final_carry
  opt, params = primitive.tie_all(state.assign(opt, name='opt'), params)
  return params
def update(params, updates, init_key=None):
  """Add annealed Gaussian noise to every update leaf.

  At call ``t`` (the value of the ``count`` state variable), the noise
  variance is ``eta / (1 + t) ** gamma``. Each leaf receives independent
  standard-normal draws scaled by the corresponding standard deviation.
  ``eta``, ``gamma`` and ``seed`` come from the enclosing scope.

  Args:
    params: Unused.
    updates: Pytree of update arrays to perturb.
    init_key: Key split to initialize the ``count`` and ``rng_key`` state
      variables.

  Returns:
    The perturbed updates, tied to the assignments of the incremented
    ``count`` and the advanced ``rng_key``.

  Raises:
    ValueError: If ``init_key`` is ``None``.
  """
  del params
  if init_key is None:
    raise ValueError('`init_key` cannot be `None`.')
  count_key, seed_key = random.split(init_key)
  count = state.variable(0., key=count_key, name='count')
  rng_key = state.variable(random.PRNGKey(seed), key=seed_key, name='rng_key')
  num_vars = len(tree_leaves(updates))
  treedef = tree_structure(updates)
  variance = eta / (1 + count)**gamma
  # Standard-normal draws must be scaled by the standard deviation, not the
  # variance, for the injected noise to have variance `variance`.
  stddev = variance ** 0.5
  # One fresh key per leaf, plus one (index 0) carried forward as rng_key.
  all_keys = random.split(rng_key, num_vars + 1)
  noise = tree_map(lambda g, k: random.normal(k, shape=g.shape), updates,
                   tree_unflatten(treedef, all_keys[1:]))
  updates = tree_map(lambda g, n: g + stddev * n, updates, noise)
  updates, count, rng_key = primitive.tie_all(
      updates,
      state.assign(count + 1., name='count', key=updates),
      state.assign(all_keys[0], name='rng_key', key=updates))
  return updates
def update(params, updates, init_key=None):
  """Accumulate updates over windows of ``k`` calls and emit their sum.

  The first call of each window restarts the accumulator from the incoming
  updates. Later calls add to it. Only the last call of the window
  (``count % k == k - 1``) returns the accumulated sum; every other call
  returns zeros. The sum is not divided by ``k``. ``k`` comes from the
  enclosing scope.

  Args:
    params: Unused.
    updates: Pytree of update arrays to accumulate.
    init_key: Key split to initialize the ``count`` and ``grad_acc`` state
      variables.

  Returns:
    The emitted updates, tied to the assignment of ``count + 1.``.

  Raises:
    ValueError: If ``init_key`` is ``None``.
  """
  del params
  if init_key is None:
    raise ValueError('`init_key` cannot be `None`.')
  count_key, grad_acc_key = random.split(init_key)
  # NOTE(review): `count` is a float (initialized to 0.). If it is float32,
  # `count + 1.` stops changing after 2**24 calls, freezing the window
  # position -- confirm expected run length or consider an integer counter.
  count = state.variable(0., key=count_key, name='count')
  grad_acc = state.variable(tree_map(lambda g: np.zeros(g.shape), updates),
                            key=grad_acc_key, name='grad_acc')
  c = count % k  # Position of this call within the current window.
  # False on the first call of a window, which discards the previous sum.
  acc = c != 0
  grad_acc = state.assign(
      tree_map(lambda g, ga: acc * ga + g, updates, grad_acc),
      name='grad_acc')
  # Only the window's last call emits the sum; the others emit zeros.
  emit = c == (k - 1)
  updates = tree_map(lambda ga: emit * ga, grad_acc)
  updates, count = primitive.tie_all(
      updates, state.assign(count + 1., name='count', key=updates))
  return updates