def template(x, init_key=None):
    """Run `x` through the 'scalar_mul' layer twice and return element 0.

    The layer is built with `state.init` around `ScalarMul(2 * jnp.ones(1))`
    and written back with `state.assign` under the same name once both calls
    have been made.

    Args:
        x: Input handed to the layer; the final value is indexed with `[0]`.
        init_key: Forwarded as the first argument when the layer is built.

    Returns:
        Element 0 of the value produced by the second layer call.
    """
    build_layer = state.init(ScalarMul(2 * jnp.ones(1)), name='scalar_mul')
    scalar_mul = build_layer(init_key, x)
    # Every call returns the output together with the updated layer, so the
    # second pass runs on the layer produced by the first one.
    for _ in range(2):
        x, scalar_mul = scalar_mul.call_and_update(x)
    state.assign(scalar_mul, name='scalar_mul')
    return x[0]
def update(params, updates, init_key=None):
    """Rescale `updates` leaf-wise by `sqrt(nu - mu**2 + eps)`.

    Two state variables, 'mu' and 'nu', start as zeros shaped like the leaves
    of `updates` and are refreshed through `_update_moment` with orders 1 and
    2 before the rescaling. `decay` and `eps` come from the enclosing scope.
    NOTE(review): `_update_moment` is defined elsewhere; presumably it yields
    a decayed running moment of the given order -- confirm.

    Args:
        params: Ignored (deleted immediately).
        updates: Tree of arrays to rescale.
        init_key: Key that is split in two, one half per state variable.

    Returns:
        A tree shaped like `updates` holding the rescaled values.

    Raises:
        ValueError: If `init_key` is `None`.
    """
    del params
    if init_key is None:
        raise ValueError('`init_key` cannot be `None`.')
    key_for_mu, key_for_nu = random.split(init_key)

    def zero_tree():
        # One zero array per leaf, matching that leaf's shape.
        return tree_map(lambda leaf: np.zeros(leaf.shape), updates)

    mu_state = state.variable(zero_tree(), key=key_for_mu, name='mu')
    nu_state = state.variable(zero_tree(), key=key_for_nu, name='nu')
    mu_state = state.assign(
        _update_moment(updates, mu_state, decay, 1), name='mu')
    nu_state = state.assign(
        _update_moment(updates, nu_state, decay, 2), name='nu')

    def rescale(grad, mu_leaf, nu_leaf):
        return grad / np.sqrt(nu_leaf - mu_leaf**2 + eps)

    return tree_map(rescale, updates, mu_state, nu_state)
def update(params, updates, init_key=None):
    """Add decaying Gaussian noise to every leaf of `updates`.

    Keeps two state variables: 'count' (starts at 0. and is incremented by 1.
    per call) and 'rng_key' (starts at `random.PRNGKey(seed)` and is replaced
    by a fresh key per call). `eta`, `gamma` and `seed` come from the
    enclosing scope.

    Args:
        params: Ignored (deleted immediately).
        updates: Tree of arrays to perturb.
        init_key: Key that is split in two, one half per state variable.

    Returns:
        A tree shaped like `updates` with noise added to each leaf.

    Raises:
        ValueError: If `init_key` is `None`.
    """
    del params
    # Fail early with the same message the sibling update functions use,
    # instead of an unrelated error from `random.split(None)`.
    if init_key is None:
        raise ValueError('`init_key` cannot be `None`.')
    count_key, seed_key = random.split(init_key)
    count = state.variable(0., key=count_key, name='count')
    rng_key = state.variable(random.PRNGKey(seed), key=seed_key,
                             name='rng_key')
    num_vars = len(tree_leaves(updates))
    treedef = tree_structure(updates)
    # NOTE(review): the noise below is multiplied by `variance` itself, not
    # by its square root. If `eta` is meant to be a variance, the scale
    # should be `sqrt(variance)` -- confirm intent before changing.
    variance = eta / (1 + count)**gamma
    # One key per leaf plus one extra (index 0) that becomes the next
    # stored 'rng_key'.
    all_keys = random.split(rng_key, num_vars + 1)
    noise = tree_map(lambda g, k: random.normal(k, shape=g.shape),
                     updates, tree_unflatten(treedef, all_keys[1:]))
    updates = tree_map(lambda g, n: g + variance * n, updates, noise)
    # Both state writes are keyed on `updates` and tied to the returned value.
    updates, count, rng_key = primitive.tie_all(
        updates,
        state.assign(count + 1., name='count', key=updates),
        state.assign(all_keys[0], name='rng_key', key=updates))
    return updates
def update(params, updates, init_key=None):
    """Replace `updates` by `mu_hat / (sqrt(nu_hat) + eps)`, leaf by leaf.

    Keeps three state variables: 'count' (starts at 0.), 'mu' and 'nu' (both
    start as zeros shaped like the leaves of `updates`). 'mu' and 'nu' are
    refreshed through `_update_moment` with `b1`/order 1 and `b2`/order 2,
    then divided by `1 - b1**count` and `1 - b2**count` using the already
    incremented count. `b1`, `b2` and `eps` come from the enclosing scope.
    NOTE(review): `_update_moment` is defined elsewhere; presumably it yields
    a decayed running moment of the given order -- confirm.

    Args:
        params: Ignored (deleted immediately).
        updates: Tree of arrays feeding the two moment estimates.
        init_key: Key that is split in three, one part per state variable.

    Returns:
        A tree shaped like `updates` with the rescaled values.

    Raises:
        ValueError: If `init_key` is `None`.
    """
    del params
    if init_key is None:
        raise ValueError('`init_key` cannot be `None`.')
    key_for_count, key_for_mu, key_for_nu = random.split(init_key, 3)

    def zero_tree():
        # One zero array per leaf, matching that leaf's shape.
        return tree_map(lambda leaf: np.zeros(leaf.shape), updates)

    step = state.variable(0., key=key_for_count, name='count')
    mu_state = state.variable(zero_tree(), key=key_for_mu, name='mu')
    nu_state = state.variable(zero_tree(), key=key_for_nu, name='nu')
    mu_state = state.assign(
        _update_moment(updates, mu_state, b1, 1), name='mu')
    nu_state = state.assign(
        _update_moment(updates, nu_state, b2, 2), name='nu')
    step = state.assign(step + 1., name='count', key=updates)

    def debias_mu(leaf):
        return leaf / (1 - b1**step)

    def debias_nu(leaf):
        return leaf / (1 - b2**step)

    def ratio(mu_leaf, nu_leaf):
        return mu_leaf / (np.sqrt(nu_leaf) + eps)

    corrected_mu = tree_map(debias_mu, mu_state)
    corrected_nu = tree_map(debias_nu, nu_state)
    return tree_map(ratio, corrected_mu, corrected_nu)
def update(params, updates, init_key=None):
    """Multiply every leaf of `updates` by `step_size_fn(count)`.

    'count' is a state variable that starts at 0. and is incremented by 1.
    after the scaling, so the factor uses the pre-increment value.
    `step_size_fn` comes from the enclosing scope.

    Args:
        params: Ignored (deleted immediately).
        updates: Tree of arrays to scale.
        init_key: Key used for the 'count' state variable.

    Returns:
        A tree shaped like `updates` with the scaled values.

    Raises:
        ValueError: If `init_key` is `None`.
    """
    del params
    if init_key is None:
        raise ValueError('`init_key` cannot be `None`.')
    step = state.variable(0., key=init_key, name='count')

    def scale(grad):
        return step_size_fn(step) * grad

    scaled = tree_map(scale, updates)
    # The counter write is keyed on the scaled tree and tied to the result.
    bumped = state.assign(step + 1., name='count', key=scaled)
    scaled, _ = primitive.tie_all(scaled, bumped)
    return scaled
def update(params, updates, init_key=None):
    """Divide every leaf of `updates` by `sqrt(nu + eps)`.

    'nu' is a state variable that starts as zeros shaped like the leaves of
    `updates` and is refreshed through `_update_moment` with order 2 before
    the division. `decay` and `eps` come from the enclosing scope.
    NOTE(review): `_update_moment` is defined elsewhere; presumably it yields
    a decayed running moment of the given order -- confirm.

    Args:
        params: Ignored (deleted immediately).
        updates: Tree of arrays to rescale.
        init_key: Key used for the 'nu' state variable.

    Returns:
        A tree shaped like `updates` with the rescaled values.

    Raises:
        ValueError: If `init_key` is `None`.
    """
    del params
    if init_key is None:
        raise ValueError('`init_key` cannot be `None`.')
    zero_tree = tree_map(lambda leaf: np.zeros(leaf.shape), updates)
    nu_state = state.variable(zero_tree, key=init_key, name='nu')
    nu_state = state.assign(
        _update_moment(updates, nu_state, decay, 2), name='nu')

    def rescale(grad, nu_leaf):
        # `eps` sits inside the square root here.
        return grad / (np.sqrt(nu_leaf + eps))

    return tree_map(rescale, updates, nu_state)
def update(params, updates, init_key=None):
    """Accumulate `updates` and emit the sum only on every `k`-th call.

    Keeps two state variables: 'count' (starts at 0., incremented by 1. per
    call) and 'grad_acc' (starts as zeros shaped like the leaves of
    `updates`). `k` comes from the enclosing scope.

    Args:
        params: Ignored (deleted immediately).
        updates: Tree of arrays to accumulate.
        init_key: Key that is split in two, one half per state variable.

    Returns:
        A tree shaped like `updates`: the accumulated values when
        `count % k == k - 1`, and the accumulator multiplied by a false mask
        (i.e. zeros) otherwise.

    Raises:
        ValueError: If `init_key` is `None`.
    """
    del params
    # Fail early with the same message the sibling update functions use,
    # instead of an unrelated error from `random.split(None)`.
    if init_key is None:
        raise ValueError('`init_key` cannot be `None`.')
    count_key, grad_acc_key = random.split(init_key)
    count = state.variable(0., key=count_key, name='count')
    grad_acc = state.variable(
        tree_map(lambda g: np.zeros(g.shape), updates),
        key=grad_acc_key, name='grad_acc')
    c = count % k
    # Multiplicative mask: false on the first call of each k-call window, so
    # the old accumulator is dropped there and kept on the other calls.
    acc = c != 0
    grad_acc = state.assign(
        tree_map(lambda g, ga: acc * ga + g, updates, grad_acc),
        name='grad_acc')
    # Multiplicative mask: true only on the last call of each window.
    emit = c == (k - 1)
    updates = tree_map(lambda ga: emit * ga, grad_acc)
    # The counter write is keyed on `updates` and tied to the returned value.
    updates, count = primitive.tie_all(
        updates, state.assign(count + 1., name='count', key=updates))
    return updates
def update(params, updates, init_key=None):
    """Fold `updates` into a decayed running 'trace' and return it.

    'trace' is a state variable that starts as zeros shaped like the leaves
    of `updates`; each call stores `updates + decay * trace`. When `nesterov`
    is truthy the same combination is applied once more, to `updates` and the
    new trace, and that result is returned. `decay` and `nesterov` come from
    the enclosing scope.

    Args:
        params: Ignored (deleted immediately).
        updates: Tree of arrays to accumulate.
        init_key: Key used for the 'trace' state variable.

    Returns:
        The new trace, or its combination with `updates` when `nesterov` is
        truthy.

    Raises:
        ValueError: If `init_key` is `None`.
    """
    del params
    if init_key is None:
        raise ValueError('`init_key` cannot be `None`.')
    zero_tree = tree_map(lambda leaf: np.zeros(leaf.shape), updates)
    old_trace = state.variable(zero_tree, key=init_key, name='trace')

    def blend(grad, trace_leaf):
        return grad + decay * trace_leaf

    new_trace = state.assign(tree_map(blend, updates, old_trace),
                             name='trace')
    if nesterov:
        return tree_map(blend, updates, new_trace)
    return new_trace
def run(params, init_key=None):
    """Scan the 'opt' module over `params` once per entry of `np.arange(num_iters)`.

    The module is built with `state.init` around
    `gradient_descent(update, objective)`; after the scan it is written back
    with `state.assign` under the name 'opt' and tied to the returned
    parameters. `update`, `objective` and `num_iters` come from the enclosing
    scope.

    Args:
        params: Initial parameters handed to the module.
        init_key: Forwarded as the first argument when the module is built.

    Returns:
        The parameters carried out of the final scan iteration.
    """
    build_opt = state.init(gradient_descent(update, objective), name='opt')
    optimizer = build_opt(init_key, params)

    def one_iteration(carry, _):
        current_opt, current_params = carry
        # The call returns the new parameters first, the updated module second.
        current_params, current_opt = current_opt.call_and_update(
            current_params)
        return (current_opt, current_params), ()

    final_carry, _ = lax.scan(
        one_iteration, (optimizer, params), np.arange(num_iters))
    optimizer, params = final_carry
    stored_opt = state.assign(optimizer, name='opt')
    _, params = primitive.tie_all(stored_opt, params)
    return params
def step(key, state, init_key=None):
    """Scan the 'kernel' module over `num_steps` keys and collect every state.

    The module is built with `st.init` around `kernel_fn`. On each scan
    iteration it is called with that iteration's key and the current state,
    then every entry of `callbacks` is invoked with the kernel and state and
    its result is tied to them. After the scan the final kernel is written
    back with `st.assign` under the name 'kernel'. `kernel_fn`, `callbacks`
    and `num_steps` come from the enclosing scope.

    Args:
        key: Key used when building the kernel and split into `num_steps`
            per-iteration keys.
        state: Initial state handed to the kernel.
        init_key: Forwarded as the first argument when the kernel is built.

    Returns:
        The per-iteration states stacked by `lax.scan`, tied to the kernel
        assignment.
    """
    build_kernel = st.init(kernel_fn, name='kernel')
    kernel = build_kernel(init_key, key, state)

    def one_step(carry, step_key):
        current_kernel, current_state = carry
        # The call returns the new state first, the updated kernel second.
        current_state, current_kernel = current_kernel.call_and_update(
            step_key, current_state)
        for callback in callbacks:
            callback_out = callback(current_kernel, current_state)
            current_kernel, current_state, _ = primitive.tie_all(
                current_kernel, current_state, callback_out)
        return (current_kernel, current_state), current_state

    step_keys = random.split(key, num_steps)
    (final_kernel, _), states = lax.scan(one_step, (kernel, state), step_keys)
    stored_kernel = st.assign(final_kernel, name='kernel')
    return primitive.tie_in(stored_kernel, states)