    def test_update_discrete_type2(self):
        env = self.env_discrete
        func_q = self.func_q_type2
        transition_batch = self.transition_discrete

        q1 = Q(func_q, env)
        q2 = Q(func_q, env)
        q_targ1 = q1.copy()
        q_targ2 = q2.copy()
        updater1 = ClippedDoubleQLearning(q1,
                                          q_targ_list=[q_targ1, q_targ2],
                                          optimizer=sgd(1.0))
        updater2 = ClippedDoubleQLearning(q2,
                                          q_targ_list=[q_targ1, q_targ2],
                                          optimizer=sgd(1.0))

        params1 = deepcopy(q1.params)
        params2 = deepcopy(q2.params)
        function_state1 = deepcopy(q1.function_state)
        function_state2 = deepcopy(q2.function_state)

        updater1.update(transition_batch)
        updater2.update(transition_batch)

        self.assertPytreeNotEqual(params1, q1.params)
        self.assertPytreeNotEqual(params2, q2.params)
        self.assertPytreeNotEqual(function_state1, q1.function_state)
        self.assertPytreeNotEqual(function_state2, q2.function_state)
Example No. 2
    def test_value_transform(self):
        env = self.env_discrete
        func_v = self.func_v
        transition_batch = self.transition_discrete

        v = V(func_v, env, random_seed=11)

        params_init = deepcopy(v.params)
        function_state_init = deepcopy(v.function_state)

        # first update without value transform
        updater = SimpleTD(v, optimizer=sgd(1.0))
        updater.update(transition_batch)
        params_without_reg = deepcopy(v.params)
        function_state_without_reg = deepcopy(v.function_state)
        self.assertPytreeNotEqual(params_without_reg, params_init)
        self.assertPytreeNotEqual(function_state_without_reg,
                                  function_state_init)

        # reset weights
        v = V(func_v, env, value_transform=LogTransform(), random_seed=11)
        self.assertPytreeAlmostEqual(params_init, v.params)
        self.assertPytreeAlmostEqual(function_state_init, v.function_state)

        # then update with value transform
        updater = SimpleTD(v, optimizer=sgd(1.0))
        updater.update(transition_batch)
        params_with_reg = deepcopy(v.params)
        function_state_with_reg = deepcopy(v.function_state)
        self.assertPytreeNotEqual(params_with_reg, params_init)
        self.assertPytreeNotEqual(function_state_with_reg, function_state_init)
        self.assertPytreeNotEqual(params_with_reg, params_without_reg)
        self.assertPytreeAlmostEqual(function_state_with_reg,
                                     function_state_without_reg)  # same!
Example No. 3
    def test_policyreg(self):
        env = self.env_discrete
        func_p = self.func_p_type1
        transition_batch = self.transition_discrete

        p = StochasticTransitionModel(func_p, env, random_seed=11)

        params_init = deepcopy(p.params)
        function_state_init = deepcopy(p.function_state)

        # first update without policy regularizer
        updater = ModelUpdater(p, optimizer=sgd(1.0))
        updater.update(transition_batch)
        params_without_reg = deepcopy(p.params)
        function_state_without_reg = deepcopy(p.function_state)
        self.assertPytreeNotEqual(params_without_reg, params_init)
        self.assertPytreeNotEqual(function_state_without_reg, function_state_init)

        # reset weights
        p = StochasticTransitionModel(func_p, env, random_seed=11)
        self.assertPytreeAlmostEqual(params_init, p.params)
        self.assertPytreeAlmostEqual(function_state_init, p.function_state)

        # then update with policy regularizer
        reg = EntropyRegularizer(p, beta=1.0)
        updater = ModelUpdater(p, optimizer=sgd(1.0), regularizer=reg)
        updater.update(transition_batch)
        params_with_reg = deepcopy(p.params)
        function_state_with_reg = deepcopy(p.function_state)
        self.assertPytreeNotEqual(params_with_reg, params_init)
        self.assertPytreeNotEqual(params_with_reg, params_without_reg)  # <---- important
        self.assertPytreeNotEqual(function_state_with_reg, function_state_init)
        self.assertPytreeAlmostEqual(function_state_with_reg, function_state_without_reg)  # same!
    def test_update_boxspace(self):
        env = self.env_boxspace
        func_q = self.func_q_type1
        func_pi = self.func_pi_boxspace
        transition_batch = self.transition_boxspace

        q1 = Q(func_q, env)
        q2 = Q(func_q, env)
        pi1 = Policy(func_pi, env)
        pi2 = Policy(func_pi, env)
        q_targ1 = q1.copy()
        q_targ2 = q2.copy()
        updater1 = ClippedDoubleQLearning(q1,
                                          pi_targ_list=[pi1, pi2],
                                          q_targ_list=[q_targ1, q_targ2],
                                          optimizer=sgd(1.0))
        updater2 = ClippedDoubleQLearning(q2,
                                          pi_targ_list=[pi1, pi2],
                                          q_targ_list=[q_targ1, q_targ2],
                                          optimizer=sgd(1.0))

        params1 = deepcopy(q1.params)
        params2 = deepcopy(q2.params)
        function_state1 = deepcopy(q1.function_state)
        function_state2 = deepcopy(q2.function_state)

        updater1.update(transition_batch)
        updater2.update(transition_batch)

        self.assertPytreeNotEqual(params1, q1.params)
        self.assertPytreeNotEqual(params2, q2.params)
        self.assertPytreeNotEqual(function_state1, q1.function_state)
        self.assertPytreeNotEqual(function_state2, q2.function_state)
Example No. 5
def create_optax_optim(name,
                       learning_rate=None,
                       momentum=0.9,
                       weight_decay=0,
                       **kwargs):
    """ Optimizer Factory

    Args:
        learning_rate (float): specify learning rate or leave up to scheduler / optim if None
        weight_decay (float): weight decay to apply to all params, not applied if 0
        **kwargs: optional / optimizer specific params that override defaults

    With regards to the kwargs, I've tried to keep the param naming incoming via kwargs from
    config file more consistent so there is less variation. Names of common args such as eps,
    beta1, beta2 etc will be remapped where possible (even if optimizer impl uses a diff name)
    and removed when not needed. A list of some common params to use in config files as named:
        eps (float): default stability / regularization epsilon value
        beta1 (float): moving average / momentum coefficient for gradient
        beta2 (float): moving average / momentum coefficient for gradient magnitude (squared grad)
    """
    name = name.lower()
    opt_args = dict(learning_rate=learning_rate, **kwargs)
    _rename(opt_args, ('beta1', 'beta2'), ('b1', 'b2'))
    if name == 'sgd' or name == 'momentum' or name == 'nesterov':
        _erase(opt_args, ('eps', ))
        if name == 'momentum':
            optimizer = optax.sgd(momentum=momentum, **opt_args)
        elif name == 'nesterov':
            optimizer = optax.sgd(momentum=momentum, nesterov=True, **opt_args)
        else:
            assert name == 'sgd'
            optimizer = optax.sgd(momentum=0, **opt_args)
    elif name == 'adabelief':
        optimizer = optax.adabelief(**opt_args)
    elif name == 'adam' or name == 'adamw':
        if name == 'adamw':
            optimizer = optax.adamw(weight_decay=weight_decay, **opt_args)
        else:
            optimizer = optax.adam(**opt_args)
    elif name == 'lamb':
        optimizer = optax.lamb(weight_decay=weight_decay, **opt_args)
    elif name == 'lars':
        optimizer = lars(weight_decay=weight_decay, **opt_args)
    elif name == 'rmsprop':
        optimizer = optax.rmsprop(momentum=momentum, **opt_args)
    elif name == 'rmsproptf':
        optimizer = optax.rmsprop(momentum=momentum,
                                  initial_scale=1.0,
                                  **opt_args)
    else:
        assert False, f"Invalid optimizer name specified ({name})"

    return optimizer
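A brief usage sketch for the factory above (hedged: it assumes the private _rename/_erase helpers and the lars transform referenced in the code exist in the same module; the hyperparameter values are made up for illustration):

import jax.numpy as jnp
import optax

# 'beta1'/'beta2' coming from a config file are remapped to optax's 'b1'/'b2' by _rename above
optim = create_optax_optim('adamw', learning_rate=1e-3, weight_decay=1e-4,
                           beta1=0.9, beta2=0.999, eps=1e-6)

params = {'w': jnp.zeros(3)}
opt_state = optim.init(params)  # standard optax GradientTransformation interface
grads = {'w': jnp.ones(3)}
updates, opt_state = optim.update(grads, opt_state, params)  # adamw needs params for weight decay
params = optax.apply_updates(params, updates)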
Example No. 6
def get_optimizer(optimizer_name: OptimizerName,
                  learning_rate: float,
                  momentum: float = 0.0,
                  adam_beta1: float = 0.9,
                  adam_beta2: float = 0.999,
                  adam_epsilon: float = 1e-8,
                  rmsprop_decay: float = 0.9,
                  rmsprop_epsilon: float = 1e-8,
                  adagrad_init_accumulator: float = 0.1,
                  adagrad_epsilon: float = 1e-6) -> Optimizer:
  """Given parameters, returns the corresponding optimizer.

  Args:
    optimizer_name: One of SGD, MOMENTUM, ADAM, RMSPROP, or ADAGRAD.
    learning_rate: Learning rate for all optimizers.
    momentum: Momentum parameter for MOMENTUM.
    adam_beta1: beta1 parameter for ADAM.
    adam_beta2: beta2 parameter for ADAM.
    adam_epsilon: epsilon parameter for ADAM.
    rmsprop_decay: decay parameter for RMSPROP.
    rmsprop_epsilon: epsilon parameter for RMSPROP.
    adagrad_init_accumulator: initial accumulator for ADAGRAD.
    adagrad_epsilon: epsilon parameter for ADAGRAD.

  Returns:
    Returns the Optimizer with the specified properties.

  Raises:
    ValueError: if the optimizer name is not one of SGD, MOMENTUM, ADAM,
      RMSPROP, or ADAGRAD.
  """
  if optimizer_name == OptimizerName.SGD:
    return Optimizer(*optax.sgd(learning_rate))
  elif optimizer_name == OptimizerName.MOMENTUM:
    return Optimizer(*optax.sgd(learning_rate, momentum))
  elif optimizer_name == OptimizerName.ADAM:
    return Optimizer(*optax.adam(
        learning_rate, b1=adam_beta1, b2=adam_beta2, eps=adam_epsilon))
  elif optimizer_name == OptimizerName.RMSPROP:
    return Optimizer(
        *optax.rmsprop(learning_rate, decay=rmsprop_decay, eps=rmsprop_epsilon))
  elif optimizer_name == OptimizerName.ADAGRAD:
    return Optimizer(*optax.adagrad(
        learning_rate,
        initial_accumulator_value=adagrad_init_accumulator,
        eps=adagrad_epsilon))
  else:
    raise ValueError(f'Unsupported optimizer_name {optimizer_name}.')
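A small usage sketch, assuming Optimizer is simply a named wrapper around optax's (init, update) function pair; OptimizerName and Optimizer are defined elsewhere in the original project, and the field names init_fn/update_fn below are a guess for illustration, not the project's actual attribute names:

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones(2)}
opt = get_optimizer(OptimizerName.MOMENTUM, learning_rate=0.1, momentum=0.9)
opt_state = opt.init_fn(params)  # assumed field name on the Optimizer wrapper

grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2))(params)
updates, opt_state = opt.update_fn(grads, opt_state)  # assumed field name
params = optax.apply_updates(params, updates)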
Example No. 7
    def test_integration(self):
        env = catch.Catch()
        action_spec = env.action_spec()
        num_actions = action_spec.num_values
        obs_spec = env.observation_spec()
        agent = agent_lib.Agent(
            num_actions=num_actions,
            obs_spec=obs_spec,
            net_factory=haiku_nets.CatchNet,
        )
        unroll_length = 20
        learner = learner_lib.Learner(
            agent=agent,
            rng_key=jax.random.PRNGKey(42),
            opt=optax.sgd(1e-2),
            batch_size=1,
            discount_factor=0.99,
            frames_per_iter=unroll_length,
        )
        actor = actor_lib.Actor(
            agent=agent,
            env=env,
            learner=learner,
            unroll_length=unroll_length,
        )
        frame_count, params = actor.pull_params()
        actor.unroll_and_push(frame_count=frame_count, params=params)
        learner.run(max_iterations=1)
def run_svgd(key, lr, full_data=False, progress_bar=False):
    key, subkey = random.split(key)
    init_particles = ravel(*sample_from_prior(subkey, n_particles))
    svgd_opt = optax.sgd(lr)

    svgd_grad = models.KernelGradient(
        get_target_logp=lambda batch: get_minibatch_logp(*batch), scaled=True)
    particles = models.Particles(key,
                                 svgd_grad.gradient,
                                 init_particles,
                                 custom_optimizer=svgd_opt)

    test_batches = get_batches(x_test, y_test, 2 *
                               NUM_VALS) if full_data else get_batches(
                                   x_val, y_val, 2 * NUM_VALS)
    train_batches = get_batches(xx, yy, NUM_STEPS +
                                1) if full_data else get_batches(
                                    x_train, y_train, NUM_STEPS + 1)
    for i, batch in tqdm(enumerate(train_batches),
                         total=NUM_STEPS,
                         disable=not progress_bar):
        particles.step(batch)
        if i % (NUM_STEPS // NUM_VALS) == 0:
            test_logp = get_minibatch_logp(*next(test_batches))
            stepdata = {
                "accuracy":
                compute_test_accuracy(unravel(particles.particles)[0]),
                "test_logp": test_logp(particles.particles),
            }
            metrics.append_to_log(particles.rundata, stepdata)

    particles.done()
    return particles
Example No. 9
def Sgd(learning_rate: float):
    r"""Stochastic Gradient Descent Optimizer.
    The `Stochastic Gradient Descent <https://en.wikipedia.org/wiki/Stochastic_gradient_descent>`_
    is one of the most popular optimizers in machine learning applications.
    Given a stochastic estimate of the gradient of the cost function (:math:`G(\mathbf{p})`),
    it performs the update:

    .. math:: p^\prime_k = p_k -\eta G_k(\mathbf{p}),

    where :math:`\eta` is the so-called learning rate.
    NetKet also implements two extensions to simple SGD: the first is
    :math:`L_2` regularization, and the second is the possibility to set a decay
    factor :math:`\gamma \leq 1` for the learning rate, such that
    at iteration :math:`n` the learning rate is :math:`\eta \gamma^n`.

    Args:
       learning_rate: The learning rate :math:`\eta`.

    Examples:
       Simple SGD optimizer.

       >>> from netket.optimizer import Sgd
       >>> op = Sgd(learning_rate=0.05)
    """
    from optax import sgd

    return sgd(learning_rate)
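Since the returned object is plain optax.sgd, the update rule in the docstring (p' = p - eta * G) can be checked numerically in one step; a minimal sketch, not part of NetKet itself:

import jax.numpy as jnp
import optax

op = Sgd(learning_rate=0.05)
params = jnp.array([1.0, -2.0])
grads = jnp.array([0.5, 0.5])

state = op.init(params)
updates, state = op.update(grads, state)
new_params = optax.apply_updates(params, updates)
assert jnp.allclose(new_params, params - 0.05 * grads)  # p' = p - eta * G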
Example No. 10
def train(  # pylint: disable=invalid-name
    Phi,
    Psi,
    num_epochs,
    learning_rate,
    key,
    estimator,
    alpha,
    optimizer,
    use_l2_reg,
    reg_coeff,
    use_penalty,
    j,
    num_rows,
    skipsize=1):
  """Training function."""
  Phis = [Phi]  # pylint: disable=invalid-name
  grads = []
  if optimizer == 'sgd':
    optim = optax.sgd(learning_rate)
  elif optimizer == 'adam':
    optim = optax.adam(learning_rate)
  opt_state = optim.init(Phi)
  for i in tqdm(range(num_epochs)):
    key, subkey = jax.random.split(key)
    Phi, opt_state, grad = estimates.nabla_phi_analytical(
        Phi, Psi, subkey, optim, opt_state, estimator, alpha, use_l2_reg,
        reg_coeff, use_penalty, j, num_rows)
    Phis.append(Phi)
    grads.append(grad)
    if i % skipsize == 0:
      Phis.append(Phi)
      grads.append(grad)
  return jnp.stack(Phis), jnp.stack(grads)
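The optimizer state created by optim.init(Phi) is threaded through estimates.nabla_phi_analytical above; presumably that helper performs a standard optax step internally, roughly like this sketch (the quadratic loss is a stand-in, not the estimator used in the original code):

import jax
import jax.numpy as jnp
import optax

def sgd_step(Phi, opt_state, optim, loss_fn):
    # one generic optax step of the kind the training loop above relies on
    grad = jax.grad(loss_fn)(Phi)
    updates, opt_state = optim.update(grad, opt_state)
    Phi = optax.apply_updates(Phi, updates)
    return Phi, opt_state, grad

optim = optax.sgd(1e-2)
Phi = jnp.ones((4, 3))
opt_state = optim.init(Phi)
Phi, opt_state, grad = sgd_step(Phi, opt_state, optim,
                                loss_fn=lambda P: jnp.sum(P ** 2))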
Example No. 11
def Momentum(learning_rate: float, beta: float = 0.9, nesterov: bool = False):
    r"""Momentum-based Optimizer.
        The momentum update incorporates an exponentially weighted moving average
        over previous gradients to speed up descent
        `Qian, N. (1999) <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.57.5612&rep=rep1&type=pdf>`_.
        The momentum vector :math:`\mathbf{m}` is initialized to zero.
        Given a stochastic estimate of the gradient of the cost function
        :math:`G(\mathbf{p})`, the updates for the parameter :math:`p_k` and
        corresponding component of the momentum :math:`m_k` are

        .. math:: m^\prime_k &= \beta m_k + (1-\beta)G_k(\mathbf{p})\\
                  p^\prime_k &= p_k - \eta m^\prime_k


        Args:
           learning_rate: The learning rate :math:`\eta`
           beta: Momentum exponential decay rate, should be in [0,1].
           nesterov: Flag to use nesterov momentum correction

        Examples:
           Momentum optimizer.

           >>> from netket.optimizer import Momentum
           >>> op = Momentum(learning_rate=0.01)
    """
    from optax import sgd

    return sgd(learning_rate, momentum=beta, nesterov=nesterov)
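A numeric sketch of what the wrapped optax.sgd actually returns per step. Note that optax accumulates a plain momentum trace m' = beta * m + G (without the (1 - beta) factor written in the docstring above) and emits the update -eta * m'; the values below are made up:

import jax.numpy as jnp

op = Momentum(learning_rate=0.1, beta=0.9)
params = jnp.array([1.0])
g1 = jnp.array([1.0])
g2 = jnp.array([2.0])

state = op.init(params)
u1, state = op.update(g1, state)
u2, state = op.update(g2, state)

assert jnp.allclose(u1, -0.1 * g1)               # first step reduces to plain SGD
assert jnp.allclose(u2, -0.1 * (g2 + 0.9 * g1))  # second step: -eta * (G + beta * m)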
Example No. 12
def create_train_state(rng, config):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(config.learning_rate, config.momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
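A sketch of how such a TrainState is typically consumed in a jitted train step (hedged: assumes Flax's train_state.TrainState API and a made-up batch with 'image'/'label' fields; not part of the original example):

import jax
import jax.numpy as jnp
import optax

@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        one_hot = jax.nn.one_hot(batch['label'], logits.shape[-1])
        return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)  # applies the optax.sgd transform configured above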
Example No. 13
    def test_nn(self):
        seed = 33
        key = jax.random.PRNGKey(seed)

        class SimpleNN(nn.Module):
            @nn.compact
            def __call__(self, x):
                x = nn.Dense(features=100)(x)
                x = jax.nn.relu(x)
                x = nn.Dense(features=100)(x)
                x = jax.nn.relu(x)
                x = nn.Dense(features=1)(x)
                return x

        x = jax.random.normal(key, [20, 2])
        y = jax.random.normal(key, [20, 1])
        lr = 0.001
        model = SimpleNN()
        params = model.init(key, x)

        m = optax.sgd(learning_rate=-1., momentum=0.9)
        d = optax.sgd(learning_rate=-2., momentum=0.9)
        g = kitchen_sink(chains=[m, d],
                         combinator='grafting',
                         learning_rate=lr)
        s_m = m.init(params)
        s_d = d.init(params)
        s_g = g.init(params)

        def loss_fn(params):
            yhat = model.apply(params, x)
            loss = jnp.sum((y - yhat)**2)
            return loss

        for _ in range(10):
            grad_fn = jax.value_and_grad(loss_fn)
            _, grad = grad_fn(params)
            u_m, s_m = m.update(grad, s_m)
            u_d, s_d = d.update(grad, s_d)
            u_g, s_g = g.update(grad, s_g)

            u_m_n = jax.tree_map(jnp.linalg.norm, u_m)
            u_d_n = jax.tree_map(jnp.linalg.norm, u_d)
            u_g2 = jax.tree_multimap(
                lambda m, d, dn: -lr * d / (dn + 1e-6) * m, u_m_n, u_d, u_d_n)

            chex.assert_trees_all_close(u_g, u_g2)
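kitchen_sink is project-specific, but the grafting rule that the reference computation u_g2 checks can be written standalone: take each leaf's norm from one update tree and its direction from another. A minimal sketch mirroring the test above:

import jax
import jax.numpy as jnp

def graft(lr, u_magnitude, u_direction, eps=1e-6):
    # per-leaf: keep the direction of u_direction, rescaled to the norm of u_magnitude
    norms_m = jax.tree_util.tree_map(jnp.linalg.norm, u_magnitude)
    norms_d = jax.tree_util.tree_map(jnp.linalg.norm, u_direction)
    return jax.tree_util.tree_map(lambda nm, d, nd: -lr * d / (nd + eps) * nm,
                                  norms_m, u_direction, norms_d)

u_m = {'w': jnp.array([3.0, 4.0])}  # magnitude source, norm 5
u_d = {'w': jnp.array([0.0, 2.0])}  # direction source, norm 2
print(graft(0.001, u_m, u_d))       # direction of u_d, rescaled to lr * 5 -> [0., -0.005]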
Example No. 14
    def test_policyreg_boxspace(self):
        env = self.env_boxspace
        func_q = self.func_q_type1
        func_pi = self.func_pi_boxspace
        transition_batch = self.transition_boxspace

        q = Q(func_q, env, random_seed=11)
        pi = Policy(func_pi, env, random_seed=17)
        q_targ = q.copy()

        params_init = deepcopy(q.params)
        function_state_init = deepcopy(q.function_state)

        # first update without policy regularizer
        policy_reg = EntropyRegularizer(pi, beta=1.0)
        updater = Sarsa(q, q_targ, optimizer=sgd(1.0))
        updater.update(transition_batch)
        params_without_reg = deepcopy(q.params)
        function_state_without_reg = deepcopy(q.function_state)
        self.assertPytreeNotEqual(params_without_reg, params_init)
        self.assertPytreeNotEqual(function_state_without_reg,
                                  function_state_init)

        # reset weights
        q = Q(func_q, env, random_seed=11)
        pi = Policy(func_pi, env, random_seed=17)
        q_targ = q.copy()
        self.assertPytreeAlmostEqual(params_init, q.params, decimal=10)
        self.assertPytreeAlmostEqual(function_state_init,
                                     q.function_state,
                                     decimal=10)

        # then update with policy regularizer
        policy_reg = EntropyRegularizer(pi, beta=1.0)
        updater = Sarsa(q,
                        q_targ,
                        optimizer=sgd(1.0),
                        policy_regularizer=policy_reg)
        updater.update(transition_batch)
        params_with_reg = deepcopy(q.params)
        function_state_with_reg = deepcopy(q.function_state)
        self.assertPytreeNotEqual(params_with_reg, params_init)
        self.assertPytreeNotEqual(params_with_reg,
                                  params_without_reg)  # <--- important
        self.assertPytreeNotEqual(function_state_with_reg, function_state_init)
        self.assertPytreeAlmostEqual(function_state_with_reg,
                                     function_state_without_reg)  # same!
Example No. 15
    def test_policyreg_discrete(self):
        env = self.env_discrete
        func_q = self.func_q_type1
        func_pi = self.func_pi_discrete
        transition_batch = self.transition_discrete

        q = Q(func_q, env, random_seed=11)
        pi = Policy(func_pi, env, random_seed=17)
        q_targ = q.copy()

        params_init = deepcopy(q.params)
        function_state_init = deepcopy(q.function_state)

        # first update without policy regularizer
        policy_reg = EntropyRegularizer(pi, beta=1.0)
        updater = Sarsa(q, q_targ, optimizer=sgd(1.0))
        updater.update(transition_batch)
        params_without_reg = deepcopy(q.params)
        function_state_without_reg = deepcopy(q.function_state)
        self.assertPytreeNotEqual(params_without_reg, params_init)
        self.assertPytreeNotEqual(function_state_without_reg,
                                  function_state_init)

        # reset weights
        q = Q(func_q, env, random_seed=11)
        pi = Policy(func_pi, env, random_seed=17)
        q_targ = q.copy()
        self.assertPytreeAlmostEqual(params_init, q.params)
        self.assertPytreeAlmostEqual(function_state_init, q.function_state)

        # then update with policy regularizer
        policy_reg = EntropyRegularizer(pi, beta=1.0)
        updater = Sarsa(q,
                        q_targ,
                        optimizer=sgd(1.0),
                        policy_regularizer=policy_reg)
        print('updater.target_params:', updater.target_params)
        print('updater.target_function_state:', updater.target_function_state)
        updater.update(transition_batch)
        params_with_reg = deepcopy(q.params)
        function_state_with_reg = deepcopy(q.function_state)
        self.assertPytreeNotEqual(params_with_reg, params_init)
        self.assertPytreeNotEqual(function_state_with_reg, function_state_init)
        self.assertPytreeNotEqual(params_with_reg, params_without_reg)
        self.assertPytreeAlmostEqual(function_state_with_reg,
                                     function_state_without_reg)  # same!
Example No. 16
    def test_1d(self):
        """Test grafting."""
        x = 10.
        lr = 0.01
        m = optax.sgd(learning_rate=-1.)
        d = optax.sgd(learning_rate=-2.)
        g = kitchen_sink(chains=[m, d],
                         combinator='grafting',
                         learning_rate=lr)
        state = g.init(x)
        for _ in range(10):
            grad_fn = jax.value_and_grad(lambda x: x**2)
            _, grad = grad_fn(x)
            updates, state = g.update(grad, state)
            dx = 2 * x
            x -= lr * dx
            self.assertAlmostEqual(updates, -lr * dx, places=4)
Example No. 17
def optimizer(hyperparameters: spec.Hyperparamters, num_train_examples: int):
  steps_per_epoch = num_train_examples // get_batch_size('imagenet')
  learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch)
  opt_init_fn, opt_update_fn = optax.sgd(
      nesterov=True,
      momentum=hyperparameters.momentum,
      learning_rate=learning_rate_fn)
  return opt_init_fn, opt_update_fn
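optax.sgd also accepts a schedule (a callable from step count to learning rate) in place of a float, which is what create_learning_rate_fn supplies above. A minimal stand-in using a built-in optax schedule (the warmup/decay shape of the original function is not reproduced here):

import optax

# stand-in schedule; the real create_learning_rate_fn builds one from the hyperparameters
learning_rate_fn = optax.cosine_decay_schedule(init_value=0.1, decay_steps=10_000)
opt_init_fn, opt_update_fn = optax.sgd(
    nesterov=True,
    momentum=0.9,
    learning_rate=learning_rate_fn)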
Example No. 18
def create_train_state(rng, config: ml_collections.ConfigDict, model):
    """Create initial training state."""
    params = get_initial_params(rng, model)
    tx = optax.chain(
        optax.sgd(learning_rate=config.learning_rate,
                  momentum=config.momentum),
        optax.additive_weight_decay(weight_decay=config.weight_decay))
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return state
Example No. 19
    def test_policyreg_boxspace(self):
        env = self.env_boxspace
        func_v = self.func_v
        func_pi = self.func_pi_boxspace
        transition_batch = self.transition_boxspace

        v = V(func_v, env, random_seed=11)
        pi = Policy(func_pi, env, random_seed=17)
        v_targ = v.copy()

        params_init = deepcopy(v.params)
        function_state_init = deepcopy(v.function_state)

        # first update without policy regularizer
        policy_reg = EntropyRegularizer(pi, beta=1.0)
        updater = SimpleTD(v, v_targ, optimizer=sgd(1.0))
        updater.update(transition_batch)
        params_without_reg = deepcopy(v.params)
        function_state_without_reg = deepcopy(v.function_state)
        self.assertPytreeNotEqual(params_without_reg, params_init)
        self.assertPytreeNotEqual(function_state_without_reg,
                                  function_state_init)

        # reset weights
        v = V(func_v, env, random_seed=11)
        pi = Policy(func_pi, env, random_seed=17)
        v_targ = v.copy()
        self.assertPytreeAlmostEqual(params_init, v.params)
        self.assertPytreeAlmostEqual(function_state_init, v.function_state)

        # then update with policy regularizer
        policy_reg = EntropyRegularizer(pi, beta=1.0)
        updater = SimpleTD(v,
                           v_targ,
                           optimizer=sgd(1.0),
                           policy_regularizer=policy_reg)
        updater.update(transition_batch)
        params_with_reg = deepcopy(v.params)
        function_state_with_reg = deepcopy(v.function_state)
        self.assertPytreeNotEqual(params_with_reg, params_init)
        self.assertPytreeNotEqual(function_state_with_reg, function_state_init)
        self.assertPytreeNotEqual(params_with_reg, params_without_reg)
        self.assertPytreeAlmostEqual(function_state_with_reg,
                                     function_state_without_reg)  # same!
Example No. 20
    def test_update_discrete_nogrid(self):
        env = self.env_discrete
        func_q = self.func_q_type1

        q = Q(func_q, env)
        q_targ = q.copy()

        msg = r"len\(q_targ_list\) must be at least 2"
        with self.assertRaisesRegex(ValueError, msg):
            ClippedDoubleQLearning(q, q_targ_list=[q_targ], optimizer=sgd(1.0))
Example No. 21
    def test_update_boxspace(self):
        env = self.env_boxspace
        func_q = self.func_q_type1

        q = Q(func_q, env, random_seed=11)
        q_targ = q.copy()

        msg = r"SoftQLearning class is only implemented for discrete actions spaces"
        with self.assertRaisesRegex(NotImplementedError, msg):
            SoftQLearning(q, q_targ, optimizer=sgd(1.0))
Example No. 22
def main(_):
    train_dataset = datasets.load_image_dataset('mnist', FLAGS.batch_size)
    train_dataset = list(train_dataset.as_numpy_iterator())
    test_dataset = datasets.load_image_dataset('mnist', NUM_EXAMPLES,
                                               datasets.Split.TEST)
    full_test_batch = next(test_dataset.as_numpy_iterator())

    if FLAGS.dpsgd:
        tx = optax.dpsgd(learning_rate=FLAGS.learning_rate,
                         l2_norm_clip=FLAGS.l2_norm_clip,
                         noise_multiplier=FLAGS.noise_multiplier,
                         seed=FLAGS.seed)
    else:
        tx = optax.sgd(learning_rate=FLAGS.learning_rate)

    @jax.jit
    def train_step(params, opt_state, batch):
        grad_fn = jax.grad(loss_fn, has_aux=True)
        if FLAGS.dpsgd:
            # Insert dummy dimension in axis 1 to use jax.vmap over the batch
            batch = jax.tree_map(lambda x: x[:, None], batch)
            # Use jax.vmap across the batch to extract per-example gradients
            grad_fn = jax.vmap(grad_fn, in_axes=(None, 0))

        grads, _ = grad_fn(params, batch)
        updates, new_opt_state = tx.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state

    key = jax.random.PRNGKey(FLAGS.seed)
    _, params = init_random_params(key, (-1, 28, 28, 1))
    opt_state = tx.init(params)

    print('\nStarting training...')
    for epoch in range(1, FLAGS.epochs + 1):
        start_time = time.time()
        for batch in train_dataset:
            params, opt_state = train_step(params, opt_state, batch)
        epoch_time = time.time() - start_time
        print(f'Epoch {epoch} in {epoch_time:0.2f} seconds.')

        # Evaluate test accuracy
        test_loss, test_acc = test_step(params, full_test_batch)
        print(
            f'Test Loss: {test_loss:.2f}  Test Accuracy (%): {test_acc:.2f}).')

        # Determine privacy loss so far
        if FLAGS.dpsgd:
            steps = epoch * NUM_EXAMPLES // FLAGS.batch_size
            eps = compute_epsilon(steps, FLAGS.delta)
            print(
                f'For delta={FLAGS.delta:.0e}, the current epsilon is: {eps:.2f}.'
            )
        else:
            print('Trained with vanilla non-private SGD optimizer.')
Example No. 23
def setup_models():
    models = GAN(
        hk.transform(lambda latents: nx.SkipGenerator(
            32, max_hidden_feature_size=128)(latents)),
        hk.without_apply_rng(
            hk.transform(lambda images: nx.ResidualDiscriminator(
                32, max_hidden_feature_size=128)(images))),
        hk.without_apply_rng(
            hk.transform(lambda latents: nx.style_embedding_network(
                final_embedding_size=128, intermediate_latent_size=128)
                         (latents))),
    )

    optimizers = GAN(
        optax.sgd(0.01, momentum=0.9),
        optax.sgd(0.01, momentum=0.9),
        optax.sgd(0.01, momentum=0.9),
    )

    return Trainer(models, optimizers)
Example No. 24
    def test_bad_q_targ_list(self):
        env = self.env_boxspace
        func = self.func_pi_boxspace
        transitions = self.transitions_boxspace
        print(transitions)

        pi = Policy(func, env)
        q1_targ = Q(self.func_q_type1, env)

        msg = f"q_targ_list must be a list or a tuple, got: {type(q1_targ)}"
        with self.assertRaisesRegex(TypeError, msg):
            SoftPG(pi, q1_targ, optimizer=sgd(1.0))
Example No. 25
def krylov_constraint_solve_upto_r(C,r,tol=1e-5,lr=1e-2):#,W0=None):
    """ Iterative routine to compute the solution basis to the constraint CQ=0 and QᵀQ=I
        up to the rank r, with given tolerance. Uses gradient descent (+ momentum) on the
        objective |CQ|^2, which provably converges at an exponential rate."""
    W = np.random.randn(C.shape[-1],r)/np.sqrt(C.shape[-1])# if W0 is None else W0
    W = device_put(W)
    opt_init,opt_update = optax.sgd(lr,.9)
    opt_state = opt_init(W)  # init stats

    def loss(W):
        return (jnp.absolute(C@W)**2).sum()/2 # added absolute for complex support

    loss_and_grad = jit(jax.value_and_grad(loss))
    # setup progress bar
    pbar = tqdm(total=100,desc=f'Krylov Solving for Equivariant Subspace r<={r}',
    bar_format="{l_bar}{bar}| {n:.3g}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")
    prog_val = 0
    lstart, _ = loss_and_grad(W)
    
    for i in range(20000):
        
        lossval, grad = loss_and_grad(W)
        updates, opt_state = opt_update(grad, opt_state, W)
        W = optax.apply_updates(W, updates)
        # update progress bar
        progress = max(100*np.log(lossval/lstart)/np.log(tol**2/lstart)-prog_val,0)
        progress = min(100-prog_val,progress)
        if progress>0:
            prog_val += progress
            pbar.update(progress)

        if jnp.sqrt(lossval) <tol: # check convergence condition
            pbar.close()
            break # has converged
        if lossval>2e3 and i>100: # Solve diverged due to too high learning rate
            logging.warning(f"Constraint solving diverged, trying lower learning rate {lr/3:.2e}")
            if lr < 1e-4: raise ConvergenceError(f"Failed to converge even with smaller learning rate {lr:.2e}")
            return krylov_constraint_solve_upto_r(C,r,tol,lr=lr/3)
    else: raise ConvergenceError("Failed to converge.")
    # Orthogonalize solution at the end
    U,S,VT = np.linalg.svd(np.array(W),full_matrices=False) 
    # Would like to do economy SVD here (to not have the unnecessary O(n^2) memory cost)
    # but this is not supported in numpy (or Jax) unfortunately.
    rank = (S>10*tol).sum()
    Q = device_put(U[:,:rank])
    # final_L
    final_L = loss_and_grad(Q)[0]
    assert final_L <tol, f"Normalized basis has too high error {final_L:.2e} for tol {tol:.2e}"
    scutoff = (S[rank] if r>rank else 0)
    assert rank==0 or scutoff < S[rank-1]/100, f"Singular value gap too small: {S[rank-1]:.2e} \
        above cutoff {scutoff:.2e} below cutoff. Final L {final_L:.2e}, earlier {S[rank-5:rank]}"
    #logging.debug(f"found Rank {r}, above cutoff {S[rank-1]:.3e} after {S[rank] if r>rank else np.inf:.3e}. Loss {final_L:.1e}")
    return Q
Example No. 26
def test_basic():

    w = 2.0
    grads = 1.5
    lr = 1.0
    rng = elegy.RNGSeq(42)

    go = generalize_optimizer(optax.sgd(lr))

    states = go.init(rng, w)
    w, states = go.apply(w, grads, states, rng)

    assert w == 0.5
Example No. 27
    def test_boxspace_without_pi(self):
        env = self.env_boxspace
        func_q = self.func_q_type1

        q1 = Q(func_q, env)
        q2 = Q(func_q, env)
        q_targ1 = q1.copy()
        q_targ2 = q2.copy()

        msg = r"pi_targ_list must be provided if action space is not discrete"
        with self.assertRaisesRegex(TypeError, msg):
            ClippedDoubleQLearning(q1,
                                   q_targ_list=[q_targ1, q_targ2],
                                   optimizer=sgd(1.0))
Example No. 28
    def test_update_type2(self):
        env = self.env_discrete
        func_p = self.func_p_type2
        transition_batch = self.transition_discrete

        p = StochasticTransitionModel(func_p, env, random_seed=11)
        updater = ModelUpdater(p, optimizer=sgd(1.0))

        params = deepcopy(p.params)
        function_state = deepcopy(p.function_state)

        updater.update(transition_batch)

        self.assertPytreeNotEqual(params, p.params)
        self.assertPytreeNotEqual(function_state, p.function_state)
Example No. 29
    def test_update_discrete_type2(self):
        env = self.env_discrete
        func_q = self.func_q_type2

        q = Q(func_q, env)
        q_targ = q.copy()
        updater = DoubleQLearning(q, q_targ=q_targ, optimizer=sgd(1.0))

        params = deepcopy(q.params)
        function_state = deepcopy(q.function_state)

        updater.update(self.transition_discrete)

        self.assertPytreeNotEqual(params, q.params)
        self.assertPytreeNotEqual(function_state, q.function_state)
Example No. 30
    def test_update_boxspace_nogrid(self):
        env = self.env_boxspace
        func_q = self.func_q_type1
        func_pi = self.func_pi_boxspace

        q = Q(func_q, env)
        pi = Policy(func_pi, env)
        q_targ = q.copy()

        msg = r"len\(q_targ_list\) \* len\(pi_targ_list\) must be at least 2"
        with self.assertRaisesRegex(ValueError, msg):
            ClippedDoubleQLearning(q,
                                   pi_targ_list=[pi],
                                   q_targ_list=[q_targ],
                                   optimizer=sgd(1.0))