def update(params, updates, init_key=None):
  """Divides each update by a running estimate of its centered spread.

  Two state variables, ``mu`` and ``nu``, are refreshed with
  ``_update_moment`` (orders 1 and 2, rate ``decay``). Each update ``g`` is
  then divided by ``sqrt(nu - mu**2 + eps)``.

  Args:
    params: Ignored.
    updates: Pytree of incoming updates.
    init_key: PRNG key, split to initialize ``mu`` and ``nu``. Required.

  Returns:
    A pytree with the structure of ``updates``.

  Raises:
    ValueError: If ``init_key`` is ``None``.
  """
  del params
  if init_key is None:
    raise ValueError('`init_key` cannot be `None`.')
  mu_key, nu_key = random.split(init_key)

  def zeros_like_updates():
    return tree_map(lambda leaf: np.zeros(leaf.shape), updates)

  first = state.variable(zeros_like_updates(), key=mu_key, name='mu')
  second = state.variable(zeros_like_updates(), key=nu_key, name='nu')
  first = state.assign(_update_moment(updates, first, decay, 1), name='mu')
  second = state.assign(_update_moment(updates, second, decay, 2), name='nu')

  def normalize(grad, mean, sq_mean):
    # Second moment minus squared first moment, i.e. a centered estimate.
    centered = sq_mean - mean**2
    return grad / np.sqrt(centered + eps)

  return tree_map(normalize, updates, first, second)
def update(params, updates, init_key=None):
  """Adds annealed Gaussian noise to every leaf of ``updates``.

  At step ``t`` (the ``count`` state, starting at 0) the noise variance is
  ``eta / (1 + t)**gamma``. Standard-normal samples are therefore scaled by
  the square root of that variance, i.e. the standard deviation.

  Args:
    params: Ignored.
    updates: Pytree of incoming updates.
    init_key: PRNG key, split to initialize the ``count`` and ``rng_key``
      state variables. Required.

  Returns:
    ``updates`` with per-leaf noise added. The ``count`` and ``rng_key``
    state updates are tied to the result.

  Raises:
    ValueError: If ``init_key`` is ``None``.
  """
  del params
  if init_key is None:
    raise ValueError('`init_key` cannot be `None`.')
  count_key, seed_key = random.split(init_key)
  count = state.variable(0., key=count_key, name='count')
  rng_key = state.variable(random.PRNGKey(seed), key=seed_key, name='rng_key')
  num_vars = len(tree_leaves(updates))
  treedef = tree_structure(updates)
  variance = eta / (1 + count)**gamma
  # random.normal draws unit-variance samples, so scaling them by the
  # standard deviation (not the variance) gives noise with variance `variance`.
  stddev = np.sqrt(variance)
  # all_keys[0] becomes the next `rng_key`; the remaining keys seed the leaves.
  all_keys = random.split(rng_key, num_vars + 1)
  noise = tree_map(lambda g, key: random.normal(key, shape=g.shape), updates,
                   tree_unflatten(treedef, all_keys[1:]))
  updates = tree_map(lambda g, n: g + stddev * n, updates, noise)
  updates, count, rng_key = primitive.tie_all(
      updates,
      state.assign(count + 1., name='count', key=updates),
      state.assign(all_keys[0], name='rng_key', key=updates))
  return updates
def update(params, updates, init_key=None):
  """Adam-style scaling with bias-corrected first and second moments.

  The ``mu`` and ``nu`` state variables are refreshed with ``_update_moment``
  (orders 1 and 2, rates ``b1`` and ``b2``). The step counter ``count`` is
  incremented, and each moment is divided by ``1 - b**count``. The result is
  ``mu_hat / (sqrt(nu_hat) + eps)``.

  Args:
    params: Ignored.
    updates: Pytree of incoming updates.
    init_key: PRNG key, split three ways to initialize ``count``, ``mu`` and
      ``nu``. Required.

  Returns:
    A pytree with the structure of ``updates``.

  Raises:
    ValueError: If ``init_key`` is ``None``.
  """
  del params
  if init_key is None:
    raise ValueError('`init_key` cannot be `None`.')
  count_key, mu_key, nu_key = random.split(init_key, 3)

  def zeros_like_updates():
    return tree_map(lambda leaf: np.zeros(leaf.shape), updates)

  step = state.variable(0., key=count_key, name='count')
  first = state.variable(zeros_like_updates(), key=mu_key, name='mu')
  second = state.variable(zeros_like_updates(), key=nu_key, name='nu')
  first = state.assign(_update_moment(updates, first, b1, 1), name='mu')
  second = state.assign(_update_moment(updates, second, b2, 2), name='nu')
  # Bias correction uses the already-incremented step counter.
  step = state.assign(step + 1., name='count', key=updates)

  def adam_direction(m, v):
    m_hat = m / (1 - b1**step)
    v_hat = v / (1 - b2**step)
    return m_hat / (np.sqrt(v_hat) + eps)

  return tree_map(adam_direction, first, second)
def update(params, updates, init_key=None):
  """Multiplies every update by ``step_size_fn(count)``.

  ``count`` is a state variable starting at 0. Its increment is tied to the
  scaled updates.

  Args:
    params: Ignored.
    updates: Pytree of incoming updates.
    init_key: Key used to initialize the ``count`` state. Required.

  Returns:
    The scaled updates.

  Raises:
    ValueError: If ``init_key`` is ``None``.
  """
  del params
  if init_key is None:
    raise ValueError('`init_key` cannot be `None`.')
  step = state.variable(0., key=init_key, name='count')
  scaled = tree_map(lambda leaf: step_size_fn(step) * leaf, updates)
  # The counter write is keyed on (and tied to) the scaled updates.
  bumped = state.assign(step + 1., name='count', key=scaled)
  scaled, _ = primitive.tie_all(scaled, bumped)
  return scaled
def init(self, init_key, *args, name=None, **kwargs):
  """Initializes a Template into a Layer.

  Args:
    init_key: Key forwarded to the ``template_init_p`` call.
    *args: Example inputs. Each leaf is turned into an array spec with
      ``state.make_array_spec`` and passed to the build as ``specs``.
    name: Optional name. When given, the whole layer is registered as a single
      ``state.variable`` under this name. Otherwise, each entry of
      ``layer.variables()`` is registered as its own ``state.variable``.
    **kwargs: Accepted but not used.

  Returns:
    The initialized layer.
  """
  # NOTE(review): caller-supplied keyword arguments are discarded here; confirm
  # that this is intentional.
  del kwargs
  specs = jax.tree_map(state.make_array_spec, args)
  build_options = {
      'cls': self.cls,
      'specs': specs,
      'init_args': self.init_args,
      'init_kwargs': self.init_kwargs,
  }
  build = primitive.call_bind(template_init_p)(_template_build)
  layer = build(init_key, name=name, **build_options)
  if name is None:
    named_vars = {
        var_name: state.variable(value, name=var_name)
        for var_name, value in layer.variables().items()
    }
    return layer.replace(**named_vars)
  return state.variable(layer, name=name)
def update(params, updates, init_key=None):
  """Divides each update by the root of a running second-moment estimate.

  The ``nu`` state variable is refreshed with ``_update_moment`` (order 2,
  rate ``decay``). Each update ``g`` becomes ``g / sqrt(nu + eps)``.

  Args:
    params: Ignored.
    updates: Pytree of incoming updates.
    init_key: Key used to initialize the ``nu`` state. Required.

  Returns:
    A pytree with the structure of ``updates``.

  Raises:
    ValueError: If ``init_key`` is ``None``.
  """
  del params
  if init_key is None:
    raise ValueError('`init_key` cannot be `None`.')
  zeros = tree_map(lambda leaf: np.zeros(leaf.shape), updates)
  second = state.variable(zeros, key=init_key, name='nu')
  second = state.assign(_update_moment(updates, second, decay, 2), name='nu')

  def rescale(grad, sq_avg):
    return grad / np.sqrt(sq_avg + eps)

  return tree_map(rescale, updates, second)
def update(params, updates, init_key=None):
  """Accumulates updates over ``k`` calls and emits the sum on the k-th call.

  The ``grad_acc`` state holds a running sum of ``updates``. The sum restarts
  on the first call of each window of ``k`` calls. Only the last call of a
  window returns the accumulated sum; the other calls return zeros.

  Args:
    params: Ignored.
    updates: Pytree of incoming updates.
    init_key: PRNG key, split to initialize ``count`` and ``grad_acc``.
      Required.

  Returns:
    Either the accumulated sum (last call of a window) or zeros. The result
    has the structure of ``updates``.

  Raises:
    ValueError: If ``init_key`` is ``None``.
  """
  del params
  if init_key is None:
    raise ValueError('`init_key` cannot be `None`.')
  count_key, grad_acc_key = random.split(init_key)
  count = state.variable(0., key=count_key, name='count')
  grad_acc = state.variable(tree_map(lambda g: np.zeros(g.shape), updates),
                            key=grad_acc_key, name='grad_acc')
  # Position of this call inside the current window of `k` calls.
  c = count % k
  # c == 0 starts a new window, so the previous sum is dropped.
  acc = c != 0
  grad_acc = state.assign(
      tree_map(lambda g, ga: acc * ga + g, updates, grad_acc),
      name='grad_acc')
  # Only the final call of the window emits the sum; other calls emit zeros.
  emit = c == (k - 1)
  updates = tree_map(lambda ga: emit * ga, grad_acc)
  # Keep the stored counter in [0, k). A float counter that grows without
  # bound stops incrementing at 2**24 in float32, which would freeze `c`.
  updates, count = primitive.tie_all(
      updates, state.assign((count + 1.) % k, name='count', key=updates))
  return updates
def update(params, updates, init_key=None):
  """Maintains a decaying trace of updates, with an optional Nesterov step.

  The trace is refreshed as ``g + decay * trace``. If ``nesterov`` is truthy,
  the same formula is applied once more against the new trace. Otherwise the
  new trace itself is returned.

  Args:
    params: Ignored.
    updates: Pytree of incoming updates.
    init_key: Key used to initialize the ``trace`` state. Required.

  Returns:
    A pytree with the structure of ``updates``.

  Raises:
    ValueError: If ``init_key`` is ``None``.
  """
  del params
  if init_key is None:
    raise ValueError('`init_key` cannot be `None`.')
  zeros = tree_map(lambda leaf: np.zeros(leaf.shape), updates)
  previous = state.variable(zeros, key=init_key, name='trace')

  def decayed_sum(grad, prior):
    return grad + decay * prior

  refreshed = state.assign(tree_map(decayed_sum, updates, previous),
                           name='trace')
  if not nesterov:
    return refreshed
  return tree_map(decayed_sum, updates, refreshed)
def random_normal(rng, name=None):
  """Binds ``random_normal_p`` on ``rng`` and wraps the result as state.

  Args:
    rng: Key passed to the ``random_normal_p`` primitive.
    name: Optional name, forwarded both to the primitive and to
      ``state.variable``.

  Returns:
    The value produced by ``state.variable`` for the sampled result.
  """
  sample = random_normal_p.bind(rng, name=name)
  return state.variable(sample, name=name)