def test_update_discrete_type2(self):
    env = self.env_discrete
    func_q = self.func_q_type2
    transition_batch = self.transition_discrete

    q1 = Q(func_q, env)
    q2 = Q(func_q, env)
    q_targ1 = q1.copy()
    q_targ2 = q2.copy()
    updater1 = ClippedDoubleQLearning(q1, q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))
    updater2 = ClippedDoubleQLearning(q2, q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))

    params1 = deepcopy(q1.params)
    params2 = deepcopy(q2.params)
    function_state1 = deepcopy(q1.function_state)
    function_state2 = deepcopy(q2.function_state)

    updater1.update(transition_batch)
    updater2.update(transition_batch)

    self.assertPytreeNotEqual(params1, q1.params)
    self.assertPytreeNotEqual(params2, q2.params)
    self.assertPytreeNotEqual(function_state1, q1.function_state)
    self.assertPytreeNotEqual(function_state2, q2.function_state)
def test_value_transform(self):
    env = self.env_discrete
    func_v = self.func_v
    transition_batch = self.transition_discrete

    v = V(func_v, env, random_seed=11)
    params_init = deepcopy(v.params)
    function_state_init = deepcopy(v.function_state)

    # first update without value transform
    updater = SimpleTD(v, optimizer=sgd(1.0))
    updater.update(transition_batch)
    params_without_reg = deepcopy(v.params)
    function_state_without_reg = deepcopy(v.function_state)
    self.assertPytreeNotEqual(params_without_reg, params_init)
    self.assertPytreeNotEqual(function_state_without_reg, function_state_init)

    # reset weights
    v = V(func_v, env, value_transform=LogTransform(), random_seed=11)
    self.assertPytreeAlmostEqual(params_init, v.params)
    self.assertPytreeAlmostEqual(function_state_init, v.function_state)

    # then update with value transform
    updater = SimpleTD(v, optimizer=sgd(1.0))
    updater.update(transition_batch)
    params_with_reg = deepcopy(v.params)
    function_state_with_reg = deepcopy(v.function_state)
    self.assertPytreeNotEqual(params_with_reg, params_init)
    self.assertPytreeNotEqual(function_state_with_reg, function_state_init)
    self.assertPytreeNotEqual(params_with_reg, params_without_reg)
    self.assertPytreeAlmostEqual(function_state_with_reg, function_state_without_reg)  # same!
def test_policyreg(self):
    env = self.env_discrete
    func_p = self.func_p_type1
    transition_batch = self.transition_discrete

    p = StochasticTransitionModel(func_p, env, random_seed=11)
    params_init = deepcopy(p.params)
    function_state_init = deepcopy(p.function_state)

    # first update without policy regularizer
    updater = ModelUpdater(p, optimizer=sgd(1.0))
    updater.update(transition_batch)
    params_without_reg = deepcopy(p.params)
    function_state_without_reg = deepcopy(p.function_state)
    self.assertPytreeNotEqual(params_without_reg, params_init)
    self.assertPytreeNotEqual(function_state_without_reg, function_state_init)

    # reset weights
    p = StochasticTransitionModel(func_p, env, random_seed=11)
    self.assertPytreeAlmostEqual(params_init, p.params)
    self.assertPytreeAlmostEqual(function_state_init, p.function_state)

    # then update with policy regularizer
    reg = EntropyRegularizer(p, beta=1.0)
    updater = ModelUpdater(p, optimizer=sgd(1.0), regularizer=reg)
    updater.update(transition_batch)
    params_with_reg = deepcopy(p.params)
    function_state_with_reg = deepcopy(p.function_state)
    self.assertPytreeNotEqual(params_with_reg, params_init)
    self.assertPytreeNotEqual(params_with_reg, params_without_reg)  # <---- important
    self.assertPytreeNotEqual(function_state_with_reg, function_state_init)
    self.assertPytreeAlmostEqual(function_state_with_reg, function_state_without_reg)  # same!
def test_update_boxspace(self):
    env = self.env_boxspace
    func_q = self.func_q_type1
    func_pi = self.func_pi_boxspace
    transition_batch = self.transition_boxspace

    q1 = Q(func_q, env)
    q2 = Q(func_q, env)
    pi1 = Policy(func_pi, env)
    pi2 = Policy(func_pi, env)
    q_targ1 = q1.copy()
    q_targ2 = q2.copy()
    updater1 = ClippedDoubleQLearning(
        q1, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))
    updater2 = ClippedDoubleQLearning(
        q2, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))

    params1 = deepcopy(q1.params)
    params2 = deepcopy(q2.params)
    function_state1 = deepcopy(q1.function_state)
    function_state2 = deepcopy(q2.function_state)

    updater1.update(transition_batch)
    updater2.update(transition_batch)

    self.assertPytreeNotEqual(params1, q1.params)
    self.assertPytreeNotEqual(params2, q2.params)
    self.assertPytreeNotEqual(function_state1, q1.function_state)
    self.assertPytreeNotEqual(function_state2, q2.function_state)
def create_optax_optim(name, learning_rate=None, momentum=0.9, weight_decay=0, **kwargs):
    """Optimizer factory.

    Args:
        learning_rate (float): specify learning rate or leave up to scheduler / optim if None
        weight_decay (float): weight decay to apply to all params, not applied if 0
        **kwargs: optional / optimizer-specific params that override defaults

    The kwarg names coming in from config files are kept consistent so there is less variation.
    Common args such as eps, beta1, beta2, etc. are remapped where possible (even if the
    optimizer impl uses a different name) and removed when not needed.

    A list of some common params to use in config files as named:
        eps (float): default stability / regularization epsilon value
        beta1 (float): moving average / momentum coefficient for gradient
        beta2 (float): moving average / momentum coefficient for gradient magnitude (squared grad)
    """
    name = name.lower()
    opt_args = dict(learning_rate=learning_rate, **kwargs)
    _rename(opt_args, ('beta1', 'beta2'), ('b1', 'b2'))
    if name == 'sgd' or name == 'momentum' or name == 'nesterov':
        _erase(opt_args, ('eps',))
        if name == 'momentum':
            optimizer = optax.sgd(momentum=momentum, **opt_args)
        elif name == 'nesterov':
            optimizer = optax.sgd(momentum=momentum, nesterov=True, **opt_args)
        else:
            assert name == 'sgd'
            optimizer = optax.sgd(momentum=0, **opt_args)
    elif name == 'adabelief':
        optimizer = optax.adabelief(**opt_args)
    elif name == 'adam' or name == 'adamw':
        if name == 'adamw':
            optimizer = optax.adamw(weight_decay=weight_decay, **opt_args)
        else:
            optimizer = optax.adam(**opt_args)
    elif name == 'lamb':
        optimizer = optax.lamb(weight_decay=weight_decay, **opt_args)
    elif name == 'lars':
        optimizer = lars(weight_decay=weight_decay, **opt_args)
    elif name == 'rmsprop':
        optimizer = optax.rmsprop(momentum=momentum, **opt_args)
    elif name == 'rmsproptf':
        optimizer = optax.rmsprop(momentum=momentum, initial_scale=1.0, **opt_args)
    else:
        assert False, f"Invalid optimizer name specified ({name})"
    return optimizer
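# Usage sketch for the factory above (illustrative values only): the optimizer name and
# the kwargs would normally come from a config file, and beta1/beta2/eps are remapped or
# dropped by the _rename/_erase helpers assumed to live in this module.
def _example_create_optimizer():
    opt = create_optax_optim('adamw', learning_rate=1e-3, weight_decay=1e-4,
                             beta1=0.9, beta2=0.99, eps=1e-8)
    # `opt` is a regular optax GradientTransformation, so the usual pattern applies:
    #   state = opt.init(params); updates, state = opt.update(grads, state, params)
    return opt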
def get_optimizer(optimizer_name: OptimizerName,
                  learning_rate: float,
                  momentum: float = 0.0,
                  adam_beta1: float = 0.9,
                  adam_beta2: float = 0.999,
                  adam_epsilon: float = 1e-8,
                  rmsprop_decay: float = 0.9,
                  rmsprop_epsilon: float = 1e-8,
                  adagrad_init_accumulator: float = 0.1,
                  adagrad_epsilon: float = 1e-6) -> Optimizer:
    """Given parameters, returns the corresponding optimizer.

    Args:
        optimizer_name: One of SGD, MOMENTUM, ADAM, RMSPROP, or ADAGRAD.
        learning_rate: Learning rate for all optimizers.
        momentum: Momentum parameter for MOMENTUM.
        adam_beta1: beta1 parameter for ADAM.
        adam_beta2: beta2 parameter for ADAM.
        adam_epsilon: epsilon parameter for ADAM.
        rmsprop_decay: decay parameter for RMSPROP.
        rmsprop_epsilon: epsilon parameter for RMSPROP.
        adagrad_init_accumulator: initial accumulator for ADAGRAD.
        adagrad_epsilon: epsilon parameter for ADAGRAD.

    Returns:
        The Optimizer with the specified properties.

    Raises:
        ValueError: if optimizer_name is not one of SGD, MOMENTUM, ADAM, RMSPROP,
            or ADAGRAD.
    """
    if optimizer_name == OptimizerName.SGD:
        return Optimizer(*optax.sgd(learning_rate))
    elif optimizer_name == OptimizerName.MOMENTUM:
        return Optimizer(*optax.sgd(learning_rate, momentum))
    elif optimizer_name == OptimizerName.ADAM:
        return Optimizer(*optax.adam(
            learning_rate, b1=adam_beta1, b2=adam_beta2, eps=adam_epsilon))
    elif optimizer_name == OptimizerName.RMSPROP:
        return Optimizer(
            *optax.rmsprop(learning_rate, decay=rmsprop_decay, eps=rmsprop_epsilon))
    elif optimizer_name == OptimizerName.ADAGRAD:
        return Optimizer(*optax.adagrad(
            learning_rate,
            initial_accumulator_value=adagrad_init_accumulator,
            eps=adagrad_epsilon))
    else:
        raise ValueError(f'Unsupported optimizer_name {optimizer_name}.')
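# Usage sketch (illustrative): OptimizerName and Optimizer are assumed to be the enum and
# named tuple defined alongside get_optimizer; only the keyword values below are made up.
def _example_get_optimizer():
    adam = get_optimizer(OptimizerName.ADAM, learning_rate=1e-3)
    sgd_momentum = get_optimizer(OptimizerName.MOMENTUM, learning_rate=0.1, momentum=0.9)
    # Each returned Optimizer wraps optax's (init_fn, update_fn) pair.
    return adam, sgd_momentum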
def test_integration(self):
    env = catch.Catch()
    action_spec = env.action_spec()
    num_actions = action_spec.num_values
    obs_spec = env.observation_spec()
    agent = agent_lib.Agent(
        num_actions=num_actions,
        obs_spec=obs_spec,
        net_factory=haiku_nets.CatchNet,
    )
    unroll_length = 20
    learner = learner_lib.Learner(
        agent=agent,
        rng_key=jax.random.PRNGKey(42),
        opt=optax.sgd(1e-2),
        batch_size=1,
        discount_factor=0.99,
        frames_per_iter=unroll_length,
    )
    actor = actor_lib.Actor(
        agent=agent,
        env=env,
        learner=learner,
        unroll_length=unroll_length,
    )
    frame_count, params = actor.pull_params()
    actor.unroll_and_push(frame_count=frame_count, params=params)
    learner.run(max_iterations=1)
def run_svgd(key, lr, full_data=False, progress_bar=False):
    key, subkey = random.split(key)
    init_particles = ravel(*sample_from_prior(subkey, n_particles))
    svgd_opt = optax.sgd(lr)
    svgd_grad = models.KernelGradient(
        get_target_logp=lambda batch: get_minibatch_logp(*batch), scaled=True)
    particles = models.Particles(key, svgd_grad.gradient, init_particles,
                                 custom_optimizer=svgd_opt)

    test_batches = get_batches(x_test, y_test, 2 * NUM_VALS) if full_data else get_batches(
        x_val, y_val, 2 * NUM_VALS)
    train_batches = get_batches(xx, yy, NUM_STEPS + 1) if full_data else get_batches(
        x_train, y_train, NUM_STEPS + 1)

    for i, batch in tqdm(enumerate(train_batches), total=NUM_STEPS, disable=not progress_bar):
        particles.step(batch)
        if i % (NUM_STEPS // NUM_VALS) == 0:
            test_logp = get_minibatch_logp(*next(test_batches))
            stepdata = {
                "accuracy": compute_test_accuracy(unravel(particles.particles)[0]),
                "test_logp": test_logp(particles.particles),
            }
            metrics.append_to_log(particles.rundata, stepdata)

    particles.done()
    return particles
def Sgd(learning_rate: float):
    r"""Stochastic Gradient Descent Optimizer.

    The `Stochastic Gradient Descent
    <https://en.wikipedia.org/wiki/Stochastic_gradient_descent>`_ is one of the most popular
    optimizers in machine learning applications. Given a stochastic estimate of the gradient
    of the cost function (:math:`G(\mathbf{p})`), it performs the update:

    .. math:: p^\prime_k = p_k - \eta G_k(\mathbf{p}),

    where :math:`\eta` is the so-called learning rate.
    NetKet also implements two extensions to the simple SGD, the first one is :math:`L_2`
    regularization, and the second one is the possibility to set a decay factor
    :math:`\gamma \leq 1` for the learning rate, such that at iteration :math:`n` the learning
    rate is :math:`\eta \gamma^n`.

    Args:
        learning_rate: The learning rate :math:`\eta`.

    Examples:
        Simple SGD optimizer.

        >>> from netket.optimizer import Sgd
        >>> op = Sgd(learning_rate=0.05)
    """
    from optax import sgd

    return sgd(learning_rate)
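# Minimal sketch (not part of NetKet) checking that the returned optax.sgd optimizer realizes
# the update p' = p - eta * G from the docstring, for a single scalar parameter.
import jax.numpy as jnp
import optax

def _sgd_single_step_example():
    eta = 0.05
    opt = Sgd(learning_rate=eta)   # equivalent to optax.sgd(eta)
    p = jnp.array(1.0)             # parameter p
    g = jnp.array(0.2)             # stochastic gradient estimate G(p)
    state = opt.init(p)
    updates, state = opt.update(g, state)
    p_new = optax.apply_updates(p, updates)
    # p_new == p - eta * g == 0.99
    return p_new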
def train(  # pylint: disable=invalid-name
        Phi, Psi, num_epochs, learning_rate, key, estimator, alpha, optimizer,
        use_l2_reg, reg_coeff, use_penalty, j, num_rows, skipsize=1):
    """Training function."""
    Phis = [Phi]  # pylint: disable=invalid-name
    grads = []
    if optimizer == 'sgd':
        optim = optax.sgd(learning_rate)
    elif optimizer == 'adam':
        optim = optax.adam(learning_rate)
    opt_state = optim.init(Phi)
    for i in tqdm(range(num_epochs)):
        key, subkey = jax.random.split(key)
        Phi, opt_state, grad = estimates.nabla_phi_analytical(
            Phi, Psi, subkey, optim, opt_state, estimator, alpha, use_l2_reg,
            reg_coeff, use_penalty, j, num_rows)
        if i % skipsize == 0:  # record only every `skipsize` iterations
            Phis.append(Phi)
            grads.append(grad)
    return jnp.stack(Phis), jnp.stack(grads)
def Momentum(learning_rate: float, beta: float = 0.9, nesterov: bool = False):
    r"""Momentum-based Optimizer.

    The momentum update incorporates an exponentially weighted moving average over previous
    gradients to speed up descent `Qian, N. (1999)
    <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.57.5612&rep=rep1&type=pdf>`_.
    The momentum vector :math:`\mathbf{m}` is initialized to zero. Given a stochastic estimate
    of the gradient of the cost function :math:`G(\mathbf{p})`, the updates for the parameter
    :math:`p_k` and corresponding component of the momentum :math:`m_k` are

    .. math::

        m^\prime_k &= \beta m_k + (1-\beta)G_k(\mathbf{p})\\
        p^\prime_k &= p_k - \eta m^\prime_k

    Args:
        learning_rate: The learning rate :math:`\eta`
        beta: Momentum exponential decay rate, should be in [0,1].
        nesterov: Flag to use nesterov momentum correction

    Examples:
        Momentum optimizer.

        >>> from netket.optimizer import Momentum
        >>> op = Momentum(learning_rate=0.01)
    """
    from optax import sgd

    return sgd(learning_rate, momentum=beta, nesterov=nesterov)
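# Minimal usage sketch (not part of NetKet): two steps of the momentum optimizer on a scalar
# parameter. The returned object is optax.sgd's GradientTransformation, whose state carries
# the momentum buffer between calls to update().
import jax.numpy as jnp
import optax

def _momentum_step_example():
    opt = Momentum(learning_rate=0.01, beta=0.9)
    p = jnp.array(1.0)
    state = opt.init(p)
    for g in (jnp.array(0.5), jnp.array(0.5)):  # two identical gradient estimates
        updates, state = opt.update(g, state)
        p = optax.apply_updates(p, updates)
    # The second step is larger than the first because the momentum buffer accumulates.
    return p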
def create_train_state(rng, config):
    """Creates initial `TrainState`."""
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
    tx = optax.sgd(config.learning_rate, config.momentum)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx)
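# Sketch of how the TrainState created above is typically consumed in a train step. It assumes
# jax is imported at module level and that a cross_entropy_loss helper and a batch dict with
# 'image'/'label' keys exist; those names are illustrative, not part of the original code.
@jax.jit
def _train_step_example(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        return cross_entropy_loss(logits, batch['label'])
    grads = jax.grad(loss_fn)(state.params)
    # apply_gradients runs the optax.sgd update and increments the step counter.
    return state.apply_gradients(grads=grads)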
def test_nn(self):
    seed = 33
    key = jax.random.PRNGKey(seed)

    class SimpleNN(nn.Module):

        @nn.compact
        def __call__(self, x):
            x = nn.Dense(features=100)(x)
            x = jax.nn.relu(x)
            x = nn.Dense(features=100)(x)
            x = jax.nn.relu(x)
            x = nn.Dense(features=1)(x)
            return x

    x = jax.random.normal(key, [20, 2])
    y = jax.random.normal(key, [20, 1])
    lr = 0.001

    model = SimpleNN()
    params = model.init(key, x)

    m = optax.sgd(learning_rate=-1., momentum=0.9)
    d = optax.sgd(learning_rate=-2., momentum=0.9)
    g = kitchen_sink(chains=[m, d], combinator='grafting', learning_rate=lr)

    s_m = m.init(params)
    s_d = d.init(params)
    s_g = g.init(params)

    def loss_fn(params):
        yhat = model.apply(params, x)
        loss = jnp.sum((y - yhat)**2)
        return loss

    for _ in range(10):
        grad_fn = jax.value_and_grad(loss_fn)
        _, grad = grad_fn(params)
        u_m, s_m = m.update(grad, s_m)
        u_d, s_d = d.update(grad, s_d)
        u_g, s_g = g.update(grad, s_g)

        u_m_n = jax.tree_map(jnp.linalg.norm, u_m)
        u_d_n = jax.tree_map(jnp.linalg.norm, u_d)
        u_g2 = jax.tree_multimap(
            lambda m, d, dn: -lr * d / (dn + 1e-6) * m, u_m_n, u_d, u_d_n)

        chex.assert_trees_all_close(u_g, u_g2)
def test_policyreg_boxspace(self):
    env = self.env_boxspace
    func_q = self.func_q_type1
    func_pi = self.func_pi_boxspace
    transition_batch = self.transition_boxspace

    q = Q(func_q, env, random_seed=11)
    pi = Policy(func_pi, env, random_seed=17)
    q_targ = q.copy()
    params_init = deepcopy(q.params)
    function_state_init = deepcopy(q.function_state)

    # first update without policy regularizer
    policy_reg = EntropyRegularizer(pi, beta=1.0)
    updater = Sarsa(q, q_targ, optimizer=sgd(1.0))
    updater.update(transition_batch)
    params_without_reg = deepcopy(q.params)
    function_state_without_reg = deepcopy(q.function_state)
    self.assertPytreeNotEqual(params_without_reg, params_init)
    self.assertPytreeNotEqual(function_state_without_reg, function_state_init)

    # reset weights
    q = Q(func_q, env, random_seed=11)
    pi = Policy(func_pi, env, random_seed=17)
    q_targ = q.copy()
    self.assertPytreeAlmostEqual(params_init, q.params, decimal=10)
    self.assertPytreeAlmostEqual(function_state_init, q.function_state, decimal=10)

    # then update with policy regularizer
    policy_reg = EntropyRegularizer(pi, beta=1.0)
    updater = Sarsa(q, q_targ, optimizer=sgd(1.0), policy_regularizer=policy_reg)
    updater.update(transition_batch)
    params_with_reg = deepcopy(q.params)
    function_state_with_reg = deepcopy(q.function_state)
    self.assertPytreeNotEqual(params_with_reg, params_init)
    self.assertPytreeNotEqual(params_with_reg, params_without_reg)  # <--- important
    self.assertPytreeNotEqual(function_state_with_reg, function_state_init)
    self.assertPytreeAlmostEqual(function_state_with_reg, function_state_without_reg)  # same!
def test_policyreg_discrete(self):
    env = self.env_discrete
    func_q = self.func_q_type1
    func_pi = self.func_pi_discrete
    transition_batch = self.transition_discrete

    q = Q(func_q, env, random_seed=11)
    pi = Policy(func_pi, env, random_seed=17)
    q_targ = q.copy()
    params_init = deepcopy(q.params)
    function_state_init = deepcopy(q.function_state)

    # first update without policy regularizer
    policy_reg = EntropyRegularizer(pi, beta=1.0)
    updater = Sarsa(q, q_targ, optimizer=sgd(1.0))
    updater.update(transition_batch)
    params_without_reg = deepcopy(q.params)
    function_state_without_reg = deepcopy(q.function_state)
    self.assertPytreeNotEqual(params_without_reg, params_init)
    self.assertPytreeNotEqual(function_state_without_reg, function_state_init)

    # reset weights
    q = Q(func_q, env, random_seed=11)
    pi = Policy(func_pi, env, random_seed=17)
    q_targ = q.copy()
    self.assertPytreeAlmostEqual(params_init, q.params)
    self.assertPytreeAlmostEqual(function_state_init, q.function_state)

    # then update with policy regularizer
    policy_reg = EntropyRegularizer(pi, beta=1.0)
    updater = Sarsa(q, q_targ, optimizer=sgd(1.0), policy_regularizer=policy_reg)
    print('updater.target_params:', updater.target_params)
    print('updater.target_function_state:', updater.target_function_state)
    updater.update(transition_batch)
    params_with_reg = deepcopy(q.params)
    function_state_with_reg = deepcopy(q.function_state)
    self.assertPytreeNotEqual(params_with_reg, params_init)
    self.assertPytreeNotEqual(function_state_with_reg, function_state_init)
    self.assertPytreeNotEqual(params_with_reg, params_without_reg)
    self.assertPytreeAlmostEqual(function_state_with_reg, function_state_without_reg)  # same!
def test_1d(self):
    """Test grafting."""
    x = 10.
    lr = 0.01
    m = optax.sgd(learning_rate=-1.)
    d = optax.sgd(learning_rate=-2.)
    g = kitchen_sink(chains=[m, d], combinator='grafting', learning_rate=lr)
    state = g.init(x)

    for _ in range(10):
        grad_fn = jax.value_and_grad(lambda x: x**2)
        _, grad = grad_fn(x)
        updates, state = g.update(grad, state)
        dx = 2 * x
        x -= lr * dx
        self.assertAlmostEqual(updates, -lr * dx, places=4)
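# Illustrative sketch of the combination the two grafting tests above (test_nn and test_1d)
# check: the grafted update takes its magnitude from the first chain (m) and its direction
# from the second chain (d). `graft_leaf` is a made-up helper, not the kitchen_sink
# implementation itself.
import jax.numpy as jnp

def graft_leaf(u_m, u_d, lr, eps=1e-6):
    """Per-leaf grafting: norm of u_m, direction of u_d, scaled by -lr."""
    return -lr * u_d / (jnp.linalg.norm(u_d) + eps) * jnp.linalg.norm(u_m)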
def optimizer(hyperparameters: spec.Hyperparamters, num_train_examples: int):
    steps_per_epoch = num_train_examples // get_batch_size('imagenet')
    learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch)
    opt_init_fn, opt_update_fn = optax.sgd(
        nesterov=True,
        momentum=hyperparameters.momentum,
        learning_rate=learning_rate_fn)
    return opt_init_fn, opt_update_fn
def create_train_state(rng, config: ml_collections.ConfigDict, model):
    """Create initial training state."""
    params = get_initial_params(rng, model)
    tx = optax.chain(
        optax.sgd(learning_rate=config.learning_rate, momentum=config.momentum),
        optax.additive_weight_decay(weight_decay=config.weight_decay))
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return state
def test_policyreg_boxspace(self):
    env = self.env_boxspace
    func_v = self.func_v
    func_pi = self.func_pi_boxspace
    transition_batch = self.transition_boxspace

    v = V(func_v, env, random_seed=11)
    pi = Policy(func_pi, env, random_seed=17)
    v_targ = v.copy()
    params_init = deepcopy(v.params)
    function_state_init = deepcopy(v.function_state)

    # first update without policy regularizer
    policy_reg = EntropyRegularizer(pi, beta=1.0)
    updater = SimpleTD(v, v_targ, optimizer=sgd(1.0))
    updater.update(transition_batch)
    params_without_reg = deepcopy(v.params)
    function_state_without_reg = deepcopy(v.function_state)
    self.assertPytreeNotEqual(params_without_reg, params_init)
    self.assertPytreeNotEqual(function_state_without_reg, function_state_init)

    # reset weights
    v = V(func_v, env, random_seed=11)
    pi = Policy(func_pi, env, random_seed=17)
    v_targ = v.copy()
    self.assertPytreeAlmostEqual(params_init, v.params)
    self.assertPytreeAlmostEqual(function_state_init, v.function_state)

    # then update with policy regularizer
    policy_reg = EntropyRegularizer(pi, beta=1.0)
    updater = SimpleTD(v, v_targ, optimizer=sgd(1.0), policy_regularizer=policy_reg)
    updater.update(transition_batch)
    params_with_reg = deepcopy(v.params)
    function_state_with_reg = deepcopy(v.function_state)
    self.assertPytreeNotEqual(params_with_reg, params_init)
    self.assertPytreeNotEqual(function_state_with_reg, function_state_init)
    self.assertPytreeNotEqual(params_with_reg, params_without_reg)
    self.assertPytreeAlmostEqual(function_state_with_reg, function_state_without_reg)  # same!
def test_update_discrete_nogrid(self):
    env = self.env_discrete
    func_q = self.func_q_type1

    q = Q(func_q, env)
    q_targ = q.copy()

    msg = r"len\(q_targ_list\) must be at least 2"
    with self.assertRaisesRegex(ValueError, msg):
        ClippedDoubleQLearning(q, q_targ_list=[q_targ], optimizer=sgd(1.0))
def test_update_boxspace(self):
    env = self.env_boxspace
    func_q = self.func_q_type1

    q = Q(func_q, env, random_seed=11)
    q_targ = q.copy()

    msg = r"SoftQLearning class is only implemented for discrete actions spaces"
    with self.assertRaisesRegex(NotImplementedError, msg):
        SoftQLearning(q, q_targ, optimizer=sgd(1.0))
def main(_):
    train_dataset = datasets.load_image_dataset('mnist', FLAGS.batch_size)
    train_dataset = list(train_dataset.as_numpy_iterator())
    test_dataset = datasets.load_image_dataset('mnist', NUM_EXAMPLES,
                                               datasets.Split.TEST)
    full_test_batch = next(test_dataset.as_numpy_iterator())

    if FLAGS.dpsgd:
        tx = optax.dpsgd(learning_rate=FLAGS.learning_rate,
                         l2_norm_clip=FLAGS.l2_norm_clip,
                         noise_multiplier=FLAGS.noise_multiplier,
                         seed=FLAGS.seed)
    else:
        tx = optax.sgd(learning_rate=FLAGS.learning_rate)

    @jax.jit
    def train_step(params, opt_state, batch):
        grad_fn = jax.grad(loss_fn, has_aux=True)
        if FLAGS.dpsgd:
            # Insert dummy dimension in axis 1 to use jax.vmap over the batch
            batch = jax.tree_map(lambda x: x[:, None], batch)
            # Use jax.vmap across the batch to extract per-example gradients
            grad_fn = jax.vmap(grad_fn, in_axes=(None, 0))

        grads, _ = grad_fn(params, batch)
        updates, new_opt_state = tx.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state

    key = jax.random.PRNGKey(FLAGS.seed)
    _, params = init_random_params(key, (-1, 28, 28, 1))
    opt_state = tx.init(params)

    print('\nStarting training...')
    for epoch in range(1, FLAGS.epochs + 1):
        start_time = time.time()
        for batch in train_dataset:
            params, opt_state = train_step(params, opt_state, batch)
        epoch_time = time.time() - start_time
        print(f'Epoch {epoch} in {epoch_time:0.2f} seconds.')

        # Evaluate test accuracy
        test_loss, test_acc = test_step(params, full_test_batch)
        print(f'Test Loss: {test_loss:.2f}  Test Accuracy (%): {test_acc:.2f}.')

        # Determine privacy loss so far
        if FLAGS.dpsgd:
            steps = epoch * NUM_EXAMPLES // FLAGS.batch_size
            eps = compute_epsilon(steps, FLAGS.delta)
            print(f'For delta={FLAGS.delta:.0e}, the current epsilon is: {eps:.2f}.')
        else:
            print('Trained with vanilla non-private SGD optimizer.')
def setup_models():
    models = GAN(
        hk.transform(lambda latents: nx.SkipGenerator(
            32, max_hidden_feature_size=128)(latents)),
        hk.without_apply_rng(
            hk.transform(lambda images: nx.ResidualDiscriminator(
                32, max_hidden_feature_size=128)(images))),
        hk.without_apply_rng(
            hk.transform(lambda latents: nx.style_embedding_network(
                final_embedding_size=128, intermediate_latent_size=128)(latents))),
    )
    optimizers = GAN(
        optax.sgd(0.01, momentum=0.9),
        optax.sgd(0.01, momentum=0.9),
        optax.sgd(0.01, momentum=0.9),
    )
    return Trainer(models, optimizers)
def test_bad_q_targ_list(self):
    env = self.env_boxspace
    func = self.func_pi_boxspace
    transitions = self.transitions_boxspace
    print(transitions)

    pi = Policy(func, env)
    q1_targ = Q(self.func_q_type1, env)

    msg = f"q_targ_list must be a list or a tuple, got: {type(q1_targ)}"
    with self.assertRaisesRegex(TypeError, msg):
        SoftPG(pi, q1_targ, optimizer=sgd(1.0))
def krylov_constraint_solve_upto_r(C, r, tol=1e-5, lr=1e-2):  # ,W0=None):
    """ Iterative routine to compute the solution basis to the constraint CQ=0 and QᵀQ=I
        up to the rank r, with given tolerance. Uses gradient descent (+ momentum) on the
        objective |CQ|^2, which provably converges at an exponential rate."""
    W = np.random.randn(C.shape[-1], r) / np.sqrt(C.shape[-1])  # if W0 is None else W0
    W = device_put(W)
    opt_init, opt_update = optax.sgd(lr, .9)
    opt_state = opt_init(W)  # init stats

    def loss(W):
        return (jnp.absolute(C @ W)**2).sum() / 2  # added absolute for complex support

    loss_and_grad = jit(jax.value_and_grad(loss))
    # setup progress bar
    pbar = tqdm(total=100, desc=f'Krylov Solving for Equivariant Subspace r<={r}',
                bar_format="{l_bar}{bar}| {n:.3g}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")
    prog_val = 0
    lstart, _ = loss_and_grad(W)

    for i in range(20000):
        lossval, grad = loss_and_grad(W)
        updates, opt_state = opt_update(grad, opt_state, W)
        W = optax.apply_updates(W, updates)
        # update progress bar
        progress = max(100 * np.log(lossval / lstart) / np.log(tol**2 / lstart) - prog_val, 0)
        progress = min(100 - prog_val, progress)
        if progress > 0:
            prog_val += progress
            pbar.update(progress)

        if jnp.sqrt(lossval) < tol:  # check convergence condition
            pbar.close()
            break  # has converged
        if lossval > 2e3 and i > 100:  # Solve diverged due to too high learning rate
            logging.warning(f"Constraint solving diverged, trying lower learning rate {lr/3:.2e}")
            if lr < 1e-4:
                raise ConvergenceError(f"Failed to converge even with smaller learning rate {lr:.2e}")
            return krylov_constraint_solve_upto_r(C, r, tol, lr=lr / 3)
    else:
        raise ConvergenceError("Failed to converge.")

    # Orthogonalize solution at the end
    # Would like to do economy SVD here (to not have the unnecessary O(n^2) memory cost)
    # but this is not supported in numpy (or Jax) unfortunately.
    U, S, VT = np.linalg.svd(np.array(W), full_matrices=False)
    rank = (S > 10 * tol).sum()
    Q = device_put(U[:, :rank])
    # final_L
    final_L = loss_and_grad(Q)[0]
    assert final_L < tol, f"Normalized basis has too high error {final_L:.2e} for tol {tol:.2e}"
    scutoff = (S[rank] if r > rank else 0)
    assert rank == 0 or scutoff < S[rank - 1] / 100, f"Singular value gap too small: {S[rank-1]:.2e} \
        above cutoff {scutoff:.2e} below cutoff. Final L {final_L:.2e}, earlier {S[rank-5:rank]}"
    # logging.debug(f"found Rank {r}, above cutoff {S[rank-1]:.3e} after {S[rank] if r>rank else np.inf:.3e}. Loss {final_L:.1e}")
    return Q
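# Hypothetical usage sketch for the solver above: build a constraint matrix C with a known
# 2-dimensional nullspace and recover an orthonormal basis Q for it. It assumes the same
# module-level imports the function relies on (numpy as np, jax.numpy as jnp, optax, etc.);
# the matrix construction below is purely illustrative.
def _krylov_solver_example():
    C = jnp.array(np.random.RandomState(0).randn(8, 10))
    C = C.at[:, -2:].set(0.0)  # zero out two columns -> nullspace of rank 2
    Q = krylov_constraint_solve_upto_r(C, r=2)
    # Columns of Q satisfy C @ Q ≈ 0 and Q.T @ Q ≈ I, up to the solver tolerance.
    return Q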
def test_basic():
    w = 2.0
    grads = 1.5
    lr = 1.0
    rng = elegy.RNGSeq(42)

    go = generalize_optimizer(optax.sgd(lr))
    states = go.init(rng, w)
    w, states = go.apply(w, grads, states, rng)

    assert w == 0.5
def test_boxspace_without_pi(self):
    env = self.env_boxspace
    func_q = self.func_q_type1

    q1 = Q(func_q, env)
    q2 = Q(func_q, env)
    q_targ1 = q1.copy()
    q_targ2 = q2.copy()

    msg = r"pi_targ_list must be provided if action space is not discrete"
    with self.assertRaisesRegex(TypeError, msg):
        ClippedDoubleQLearning(q1, q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))
def test_update_type2(self):
    env = self.env_discrete
    func_p = self.func_p_type2
    transition_batch = self.transition_discrete

    p = StochasticTransitionModel(func_p, env, random_seed=11)
    updater = ModelUpdater(p, optimizer=sgd(1.0))

    params = deepcopy(p.params)
    function_state = deepcopy(p.function_state)

    updater.update(transition_batch)

    self.assertPytreeNotEqual(params, p.params)
    self.assertPytreeNotEqual(function_state, p.function_state)
def test_update_discrete_type2(self):
    env = self.env_discrete
    func_q = self.func_q_type2

    q = Q(func_q, env)
    q_targ = q.copy()
    updater = DoubleQLearning(q, q_targ=q_targ, optimizer=sgd(1.0))

    params = deepcopy(q.params)
    function_state = deepcopy(q.function_state)

    updater.update(self.transition_discrete)

    self.assertPytreeNotEqual(params, q.params)
    self.assertPytreeNotEqual(function_state, q.function_state)
def test_update_boxspace_nogrid(self):
    env = self.env_boxspace
    func_q = self.func_q_type1
    func_pi = self.func_pi_boxspace

    q = Q(func_q, env)
    pi = Policy(func_pi, env)
    q_targ = q.copy()

    msg = r"len\(q_targ_list\) \* len\(pi_targ_list\) must be at least 2"
    with self.assertRaisesRegex(ValueError, msg):
        ClippedDoubleQLearning(q, pi_targ_list=[pi], q_targ_list=[q_targ], optimizer=sgd(1.0))