Example #1
0
 def update(params, updates, init_key=None):
   """Rescales `updates` by the root of a centered second-moment estimate.

   Two state variables, 'mu' and 'nu', are refreshed with `_update_moment`
   (orders 1 and 2, using `decay`); presumably these are exponential moving
   averages of the updates and their squares — confirm against
   `_update_moment`. Each leaf is then divided by sqrt(nu - mu**2 + eps).

   Args:
     params: Unused.
     updates: Pytree of updates to rescale.
     init_key: PRNG key used to initialize the state variables; required.

   Returns:
     Pytree with the same structure as `updates`, rescaled leaf-wise.

   Raises:
     ValueError: if `init_key` is None.
   """
   del params
   if init_key is None:
     raise ValueError('`init_key` cannot be `None`.')
   first_key, second_key = random.split(init_key)

   def _zeros_like_updates():
     return tree_map(lambda leaf: np.zeros(leaf.shape), updates)

   first = state.variable(_zeros_like_updates(), key=first_key, name='mu')
   second = state.variable(_zeros_like_updates(), key=second_key, name='nu')
   first = state.assign(_update_moment(updates, first, decay, 1), name='mu')
   second = state.assign(_update_moment(updates, second, decay, 2), name='nu')

   def _rescale(grad, mean, mean_sq):
     # mean_sq - mean**2 is the centered (variance-like) term.
     return grad / np.sqrt(mean_sq - mean**2 + eps)

   return tree_map(_rescale, updates, first, second)
Example #2
0
 def update(params, updates, init_key=None):
   """Adds annealed Gaussian noise to `updates`.

   A 'count' state variable starts at 0 and is incremented by 1 on every call.
   The noise variance for the current call is eta / (1 + count)**gamma. Each
   leaf receives its own standard-normal sample, drawn from a PRNG key kept
   in the 'rng_key' state variable and scaled by the standard deviation
   sqrt(variance).

   Args:
     params: Unused.
     updates: Pytree of updates to perturb.
     init_key: PRNG key used to initialize the state variables; required.

   Returns:
     Pytree with the same structure as `updates`, with noise added.

   Raises:
     ValueError: if `init_key` is None.
   """
   del params
   if init_key is None:
     raise ValueError('`init_key` cannot be `None`.')
   count_key, seed_key = random.split(init_key)
   count = state.variable(0., key=count_key, name='count')
   rng_key = state.variable(random.PRNGKey(seed), key=seed_key, name='rng_key')
   num_vars = len(tree_leaves(updates))
   treedef = tree_structure(updates)
   variance = eta / (1 + count)**gamma
   # `random.normal` draws unit-variance samples; to obtain noise with the
   # variance above they must be scaled by its square root, not the variance.
   stddev = variance**0.5
   # One key per leaf, plus one carried forward as the next 'rng_key'.
   all_keys = random.split(rng_key, num_vars + 1)
   noise = tree_map(lambda g, k: random.normal(k, shape=g.shape), updates,
                    tree_unflatten(treedef, all_keys[1:]))
   updates = tree_map(lambda g, n: g + stddev * n, updates, noise)
   # Tie the state writes to the returned updates so they are not dropped.
   updates, count, rng_key = primitive.tie_all(
       updates, state.assign(count + 1., name='count', key=updates),
       state.assign(all_keys[0], name='rng_key', key=updates))
   return updates
Example #3
0
 def update(params, updates, init_key=None):
   """Rescales `updates` using bias-corrected first and second moments.

   State variables: 'count' (starts at 0, incremented by 1 per call), and
   'mu'/'nu' refreshed by `_update_moment` with orders 1 and 2 using `b1` and
   `b2` respectively. After the increment, the moments are divided by
   (1 - b1**count) and (1 - b2**count). Each leaf of the result is
   mu_hat / (sqrt(nu_hat) + eps).

   Args:
     params: Unused.
     updates: Pytree of updates.
     init_key: PRNG key used to initialize the state variables; required.

   Returns:
     Pytree with the same structure as `updates`.

   Raises:
     ValueError: if `init_key` is None.
   """
   del params
   if init_key is None:
     raise ValueError('`init_key` cannot be `None`.')
   step_key, first_key, second_key = random.split(init_key, 3)
   step = state.variable(0., key=step_key, name='count')

   def _zeros_like_updates():
     return tree_map(lambda leaf: np.zeros(leaf.shape), updates)

   first = state.variable(_zeros_like_updates(), key=first_key, name='mu')
   second = state.variable(_zeros_like_updates(), key=second_key, name='nu')
   first = state.assign(_update_moment(updates, first, b1, 1), name='mu')
   second = state.assign(_update_moment(updates, second, b2, 2), name='nu')
   step = state.assign(step + 1., name='count', key=updates)

   def _adam_direction(m, v):
     m_hat = m / (1 - b1**step)
     v_hat = v / (1 - b2**step)
     return m_hat / (np.sqrt(v_hat) + eps)

   return tree_map(_adam_direction, first, second)
Example #4
0
 def update(params, updates, init_key=None):
     """Scales every leaf of `updates` by `step_size_fn(count)`.

     'count' is a state variable that starts at 0 and is incremented by 1 on
     each call, after the scaling has been computed.

     Args:
         params: Unused.
         updates: Pytree of updates to scale.
         init_key: PRNG key used to initialize the 'count' variable; required.

     Returns:
         Pytree with the same structure as `updates`, scaled leaf-wise.

     Raises:
         ValueError: if `init_key` is None.
     """
     del params
     if init_key is None:
         raise ValueError('`init_key` cannot be `None`.')
     step = state.variable(0., key=init_key, name='count')

     def _scale(leaf):
         return step_size_fn(step) * leaf

     scaled = tree_map(_scale, updates)
     # Tie the counter update to the output so the state write is kept.
     scaled, _ = primitive.tie_all(
         scaled, state.assign(step + 1., name='count', key=scaled))
     return scaled
Example #5
0
 def init(self, init_key, *args, name=None, **kwargs):
   """Initializes a Template into a Layer.

   Args:
     init_key: Key forwarded to `_template_build` through `template_init_p`.
     *args: Example inputs; every leaf is turned into an array spec with
       `state.make_array_spec`, and the resulting specs are passed to the build.
     name: Optional name. If given, the whole layer is registered as one
       `state.variable` under this name. Otherwise each entry of
       `layer.variables()` is registered individually under its own key.
     **kwargs: Accepted but not used; they are NOT forwarded to the build.

   Returns:
     The initialized layer, registered through `state.variable` as above.
   """
   # `kwargs` is not forwarded to the build. This matches prior behavior,
   # where the parameter was shadowed by the build-argument dict.
   # NOTE(review): confirm whether callers expect these to be used.
   del kwargs
   specs = jax.tree_util.tree_map(state.make_array_spec, args)
   build_kwargs = dict(
       cls=self.cls,
       specs=specs,
       init_args=self.init_args,
       init_kwargs=self.init_kwargs,
   )
   layer = primitive.call_bind(template_init_p)(
       _template_build)(init_key, name=name, **build_kwargs)
   if name is not None:
     return state.variable(layer, name=name)
   layer_params = {k: state.variable(v, name=k)
                   for k, v in layer.variables().items()}
   return layer.replace(**layer_params)
Example #6
0
 def update(params, updates, init_key=None):
   """Rescales `updates` by the root of a second-moment estimate.

   A 'nu' state variable is refreshed with `_update_moment` (order 2, using
   `decay`); each leaf is then divided by sqrt(nu + eps).

   Args:
     params: Unused.
     updates: Pytree of updates to rescale.
     init_key: PRNG key used to initialize the 'nu' variable; required.

   Returns:
     Pytree with the same structure as `updates`, rescaled leaf-wise.

   Raises:
     ValueError: if `init_key` is None.
   """
   del params
   if init_key is None:
     raise ValueError('`init_key` cannot be `None`.')
   zero_tree = tree_map(lambda leaf: np.zeros(leaf.shape), updates)
   second = state.variable(zero_tree, key=init_key, name='nu')
   second = state.assign(_update_moment(updates, second, decay, 2), name='nu')

   def _normalize(grad, mean_sq):
     return grad / np.sqrt(mean_sq + eps)

   return tree_map(_normalize, updates, second)
Example #7
0
    def update(params, updates, init_key=None):
        """Accumulates `updates` and emits their sum once every `k` calls.

        A 'grad_acc' state variable holds the running sum.
        - When the step counter (mod `k`) is 0, the sum restarts from the
          current updates.
        - When the counter (mod `k`) is `k - 1`, the accumulated sum is
          returned.
        - On all other calls, zeros are returned.

        Args:
            params: Unused.
            updates: Pytree of updates to accumulate.
            init_key: PRNG key used to initialize the state; required.

        Returns:
            Pytree with the same structure as `updates`: the accumulated sum
            on emitting calls, zeros otherwise.

        Raises:
            ValueError: if `init_key` is None.
        """
        del params
        if init_key is None:
            raise ValueError('`init_key` cannot be `None`.')
        count_key, grad_acc_key = random.split(init_key)
        count = state.variable(0., key=count_key, name='count')
        grad_acc = state.variable(tree_map(lambda g: np.zeros(g.shape),
                                           updates),
                                  key=grad_acc_key,
                                  name='grad_acc')

        c = count % k
        # False on the first call of each cycle, which restarts the sum.
        acc = c != 0
        grad_acc = state.assign(tree_map(lambda g, ga: acc * ga + g, updates,
                                         grad_acc),
                                name='grad_acc')
        emit = c == (k - 1)
        updates = tree_map(lambda ga: emit * ga, grad_acc)
        # Only `count % k` is ever used, so keep the counter modulo `k`. An
        # unbounded float counter (float32 under JAX defaults) stops changing
        # at 2**24, since x + 1. == x there, which would freeze the
        # accumulate/emit cycle.
        updates, count = primitive.tie_all(
            updates, state.assign((count + 1.) % k, name='count', key=updates))
        return updates
Example #8
0
 def update(params, updates, init_key=None):
     """Maintains a decaying trace of `updates` (momentum-style).

     The 'trace' state variable is replaced by updates + decay * trace. Without
     `nesterov`, the new trace is returned. With `nesterov`, the result is
     updates + decay * new_trace.

     Args:
         params: Unused.
         updates: Pytree of updates.
         init_key: PRNG key used to initialize the 'trace' variable; required.

     Returns:
         Pytree with the same structure as `updates`.

     Raises:
         ValueError: if `init_key` is None.
     """
     del params
     if init_key is None:
         raise ValueError('`init_key` cannot be `None`.')
     zero_tree = tree_map(lambda leaf: np.zeros(leaf.shape), updates)
     trace = state.variable(zero_tree, key=init_key, name='trace')

     def _accumulate(grad, prev):
         return grad + decay * prev

     new_trace = state.assign(tree_map(_accumulate, updates, trace),
                              name='trace')
     if not nesterov:
         return new_trace
     return tree_map(_accumulate, updates, new_trace)
def random_normal(rng, name=None):
    """Draws a sample via `random_normal_p` and registers it as a state variable.

    Args:
        rng: Key passed to `random_normal_p.bind`.
        name: Optional name, passed both to the primitive and to
            `state.variable`.

    Returns:
        Whatever `state.variable` returns for the sampled value.
    """
    sample = random_normal_p.bind(rng, name=name)
    return state.variable(sample, name=name)