def bestfit_via_grad_descent(param):  # gradient descent
  # Keep mu fixed; optimize only the remaining parameters.
  mu, np = param[0], param[1:]
  g = jax.grad(cnll)(np)
  updates, _ = gradient_descent.update(g, gradient_descent.init(np))
  np = optix.apply_updates(np, updates)
  param = jax.numpy.concatenate([jax.numpy.asarray([mu]), np])
  return param
def update(params: hk.Params, opt_state: OptState,
           batch: Batch) -> Tuple[hk.Params, OptState]:
  """Learning rule (stochastic gradient descent)."""
  grads = jax.grad(loss)(params, batch)
  updates, opt_state = opt.update(grads, opt_state)
  new_params = optix.apply_updates(params, updates)
  return new_params, opt_state
def train_step(self, data_batch, algo_func, algo_state):
  """Update all functions."""
  # Update Pi.
  pi_grads = jax.grad(_pi_loss_batched)(algo_state.pi_params, data_batch,
                                        algo_func, algo_state)
  pi_updates, new_pi_opt_state = algo_func.pi_optimizer.update(
      pi_grads, algo_state.pi_opt_state)
  new_pi_params = optix.apply_updates(algo_state.pi_params, pi_updates)
  # Update Q.
  q_grads = jax.grad(_q_loss_batched)(algo_state.q_params, data_batch,
                                      algo_func, algo_state)
  q_updates, new_q_opt_state = algo_func.q_optimizer.update(
      q_grads, algo_state.q_opt_state)
  new_q_params = optix.apply_updates(algo_state.q_params, q_updates)
  return DDPGState(new_pi_params, new_q_params,
                   new_pi_opt_state, new_q_opt_state)
def sgd_step(state: TrainingState,
             trajectory: sequence.Trajectory) -> TrainingState:
  """Does a step of SGD over a trajectory."""
  gradients = jax.grad(loss_fn)(state.params, trajectory)
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optix.apply_updates(state.params, updates)
  return TrainingState(params=new_params, opt_state=new_opt_state)
def sgd_step(
    state: TrainingState,
    samples: reverb.ReplaySample) -> Tuple[TrainingState, LearnerOutputs]:
  grad_fn = jax.grad(loss, has_aux=True)
  gradients, (keys, priorities) = grad_fn(state.params, state.target_params,
                                          samples)
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optix.apply_updates(state.params, updates)
  steps = state.steps + 1

  # Periodically update target networks.
  target_params = utils.update_periodically(
      steps, self._target_update_period, new_params, state.target_params)

  new_state = TrainingState(
      params=new_params,
      target_params=target_params,
      opt_state=new_opt_state,
      steps=steps)
  outputs = LearnerOutputs(keys=keys, priorities=priorities)
  return new_state, outputs
def jitted_update(params: hk.Params, opt_state: OptState, x: jnp.ndarray,
                  y: jnp.ndarray) -> Tuple[hk.Params, OptState]:
  grads = jax.grad(self.loss_function)(params, x, y)
  updates, opt_state = self.optimizer.update(grads, opt_state)
  params = optix.apply_updates(params, updates)
  return params, opt_state
def update(state: TrainingState, batch: dataset.Batch) -> TrainingState:
  """Does a step of SGD given inputs & targets."""
  # Both optix.adam and hk.transform return (init, update/apply) pairs;
  # only the second halves are needed here.
  _, optimizer = optix.adam(FLAGS.learning_rate)
  _, loss_fn = hk.without_apply_rng(hk.transform(sequence_loss))
  gradients = jax.grad(loss_fn)(state.params, batch)
  updates, new_opt_state = optimizer(gradients, state.opt_state)
  new_params = optix.apply_updates(state.params, updates)
  return TrainingState(params=new_params, opt_state=new_opt_state)
def update_actor(self, actor_params: hk.Params, critic_params: hk.Params,
                 actor_opt_state: OptState,
                 state: np.ndarray) -> Tuple[hk.Params, OptState]:
  """Learning rule (stochastic gradient descent)."""
  _, gradient = jax.value_and_grad(self.actor_loss)(actor_params,
                                                    critic_params, state)
  updates, opt_state = self.actor_opt_update(gradient, actor_opt_state)
  new_params = optix.apply_updates(actor_params, updates)
  return new_params, opt_state
def update(params: hk.Params, state: State, opt_state: OptState,
           batch: Batch) -> Tuple[hk.Params, State, OptState]:
  """Update the params."""
  (_, new_state), grads = jax.value_and_grad(
      train_loss, has_aux=True)(params, state, batch)
  updates, new_opt_state = optimizer.update(grads, opt_state)
  new_params = optix.apply_updates(params, updates)
  return new_params, new_state, new_opt_state
def update(
    self,
    params: hk.Params,
    opt_state: OptState,
    trajs: Transition,
) -> Tuple[hk.Params, OptState]:
  g = jax.grad(self._agent.loss)(params, trajs)
  updates, new_opt_state = self._opt_update(g, opt_state)
  return optix.apply_updates(params, updates), new_opt_state
def train_step(carry, batch_indices):
  batch = jax.tree_map(lambda x: x[batch_indices], train_set)
  params_, net_state_, opt_state_ = carry
  loss, grad, net_state_ = _perdevice_log_prob_and_grad(
      batch, params_, net_state_)
  # Sum per-device gradients across devices before the optimizer step.
  grad = jax.lax.psum(grad, axis_name='i')
  updates, opt_state_ = optimizer.update(grad, opt_state_)
  params_ = optix.apply_updates(params_, updates)
  return (params_, net_state_, opt_state_), loss
def update(
    params: hk.Params,
    rng_key: PRNGKey,
    opt_state: OptState,
    batch: Batch,
) -> Tuple[hk.Params, OptState]:
  """Single SGD update step."""
  grads = jax.grad(loss_fn)(params, rng_key, batch)
  updates, new_opt_state = optimizer.update(grads, opt_state)
  new_params = optix.apply_updates(params, updates)
  return new_params, new_opt_state
def learner_step(self, params, data, learner_state, unused_key):
  is_update_time = (learner_state.count % self._target_period == 0)
  target_params = rlax.periodic_update(params.online, params.target,
                                       is_update_time)
  dloss_dtheta = jax.grad(self._loss)(params.online, target_params, *data)
  updates, opt_state = self._optimizer.update(dloss_dtheta,
                                              learner_state.opt_state)
  online_params = optix.apply_updates(params.online, updates)
  return (Params(online_params, target_params),
          LearnerState(learner_state.count + 1, opt_state))
def update(
    params: hk.Params,
    opt_state: OptState,
    batch_obs: jnp.DeviceArray,
    batch_acts: jnp.DeviceArray,
    batch_returns: jnp.DeviceArray
) -> Tuple[hk.Params, OptState, jnp.DeviceArray]:
  batch_loss, g = jax.value_and_grad(compute_loss)(params, batch_obs,
                                                   batch_acts, batch_returns)
  updates, opt_state = opt_update(g, opt_state)
  new_params = optix.apply_updates(params, updates)
  return new_params, opt_state, batch_loss
def update(
    params: Params,
    opt_state: OptState,
    inputs: np.ndarray,
    targets: np.ndarray,
) -> Tuple[Params, OptState]:
  """Learning rule (stochastic gradient descent)."""
  _, gradient = jax.value_and_grad(loss)(params, inputs, targets)
  updates, opt_state = opt.update(gradient, opt_state)
  new_params = optix.apply_updates(params, updates)
  return new_params, opt_state
def sgd_step(state: AgentState,
             trajectory: sequence.Trajectory) -> AgentState:
  """Does a step of SGD over a trajectory."""
  gradients, new_rnn_state = jax.grad(loss_fn, has_aux=True)(
      state.params, trajectory, state.rnn_unroll_state)
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optix.apply_updates(state.params, updates)
  return state._replace(params=new_params,
                        opt_state=new_opt_state,
                        rnn_unroll_state=new_rnn_state)
def update(net_params, opt_state, obs_tm1, a_tm1, r_t, discount_t, q_t):
  """Update network weights wrt Q-learning loss."""

  def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
    q_tm1 = network.apply(net_params, obs_tm1)
    td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
    return rlax.l2_loss(td_error)

  dloss_dtheta = jax.grad(q_learning_loss)(net_params, obs_tm1, a_tm1, r_t,
                                           discount_t, q_t)
  updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
  net_params = optix.apply_updates(net_params, updates)
  return net_params, opt_state
def train_step(params, opt_state, batch):
  """Train for a single step."""
  value_and_grad_fn = jax.value_and_grad(loss_fn)
  loss, grad = value_and_grad_fn(params, batch)
  # Note this is not the usual optix api as we additionally need parameter
  # values.
  # updates, opt_state = opt.update(grad, opt_state)
  updates, opt_state = opt.update_with_params(grad, params, opt_state)
  new_params = optix.apply_updates(params, updates)
  return new_params, opt_state, loss
def sgd_step(state: TrainingState,
             transitions: Sequence[jnp.ndarray]) -> TrainingState:
  """Performs an SGD step on a batch of transitions."""
  gradients = jax.grad(loss)(state.params, state.target_params, transitions)
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optix.apply_updates(state.params, updates)
  return TrainingState(params=new_params,
                       target_params=state.target_params,
                       opt_state=new_opt_state,
                       step=state.step + 1)
def test_apply_every(self):
  # The frequency of the application of sgd.
  k = 4
  zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

  # experimental/optix.py sgd
  optix_sgd_params = self.init_params
  sgd = optix.sgd(LR, 0.0)
  state_sgd = sgd.init(optix_sgd_params)

  # experimental/optix.py sgd apply every
  optix_sgd_apply_every_params = self.init_params
  sgd_apply_every = optix.chain(optix.apply_every(k=k),
                                optix.trace(decay=0, nesterov=False),
                                optix.scale(-LR))
  state_sgd_apply_every = sgd_apply_every.init(optix_sgd_apply_every_params)

  for i in range(STEPS):
    # Apply a step of sgd.
    updates_sgd, state_sgd = sgd.update(self.per_step_updates, state_sgd)
    optix_sgd_params = optix.apply_updates(optix_sgd_params, updates_sgd)

    # Apply a step of sgd_apply_every.
    updates_sgd_apply_every, state_sgd_apply_every = sgd_apply_every.update(
        self.per_step_updates, state_sgd_apply_every)
    optix_sgd_apply_every_params = optix.apply_updates(
        optix_sgd_apply_every_params, updates_sgd_apply_every)

    if i % k == k - 1:
      # Check equivalence.
      for x, y in zip(tree_leaves(optix_sgd_apply_every_params),
                      tree_leaves(optix_sgd_params)):
        np.testing.assert_allclose(x, y, atol=1e-6, rtol=1e-5)
    else:
      # Check the update is zero.
      for x, y in zip(tree_leaves(updates_sgd_apply_every),
                      tree_leaves(zero_update)):
        np.testing.assert_allclose(x, y, atol=1e-10, rtol=1e-5)
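# For context, a standalone sketch (illustrative, not from the source) of the
# `apply_every` chain exercised by the test above: optix.apply_every(k)
# accumulates incoming gradients and emits their sum only on every k-th step
# (zeros otherwise), so the chained SGD takes effect once every k steps. The
# toy parameters and gradient below are hypothetical.
import jax.numpy as jnp
from jax.experimental import optix

k, lr = 4, 0.1
params = jnp.array([1.0, 2.0])
opt = optix.chain(optix.apply_every(k=k),
                  optix.trace(decay=0, nesterov=False),
                  optix.scale(-lr))
state = opt.init(params)
grad = jnp.array([0.5, -0.5])  # constant toy gradient
for i in range(2 * k):
  updates, state = opt.update(grad, state)
  params = optix.apply_updates(params, updates)
  # `updates` is all zeros except when i % k == k - 1, where it equals
  # -lr * k * grad (the accumulated gradient, scaled).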
def update_critic(self, critic_params: hk.Params,
                  target_critic_params: hk.Params,
                  target_actor_params: hk.Params,
                  critic_opt_state: OptState, state: np.ndarray,
                  action: np.ndarray, next_state: np.ndarray,
                  reward: np.ndarray, not_done: np.ndarray,
                  rng: jnp.ndarray) -> Tuple[hk.Params, OptState]:
  """Learning rule (stochastic gradient descent)."""
  _, gradient = jax.value_and_grad(self.critic_loss)(
      critic_params, target_critic_params, target_actor_params, state,
      action, next_state, reward, not_done, rng)
  updates, opt_state = self.critic_opt_update(gradient, critic_opt_state)
  new_params = optix.apply_updates(critic_params, updates)
  return new_params, opt_state
def sgd_step(state: TrainingState,
             transitions: Sequence[jnp.ndarray]) -> TrainingState:
  """Does a step of SGD for the whole ensemble over `transitions`."""
  gradients = jax.grad(loss)(state.params, state.target_params, transitions)
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optix.apply_updates(state.params, updates)
  return TrainingState(params=new_params,
                       target_params=state.target_params,
                       opt_state=new_opt_state,
                       step=state.step + 1)
def update(self, params, opt_state, batch: util.Transition):
  """The actual update function."""
  (_, logs), grads = jax.value_and_grad(
      self._loss, has_aux=True)(params, batch)
  grad_norm_unclipped = optimizers.l2_norm(grads)
  updates, updated_opt_state = self._opt.update(grads, opt_state)
  params = optix.apply_updates(params, updates)
  weight_norm = optimizers.l2_norm(params)
  logs.update({
      'grad_norm_unclipped': grad_norm_unclipped,
      'weight_norm': weight_norm,
  })
  return params, updated_opt_state, logs
def update(net_params, target_params, opt_state, batch):
  """Update network weights wrt Q-learning loss."""

  def dqn_learning_loss(net_params, target_params, batch):
    obs_tm1, obs_t, a_tm1, r_t, discount_t = batch
    q_tm1 = network.apply(net_params, obs_tm1)
    q_t_value = network.apply(target_params, obs_t)
    q_t_selector = network.apply(net_params, obs_t)
    td_error = batched_loss(q_tm1, a_tm1, r_t, discount_t, q_t_value,
                            q_t_selector)
    return jnp.mean(rlax.l2_loss(td_error))

  loss, dloss_dtheta = jax.value_and_grad(dqn_learning_loss)(
      net_params, target_params, batch)
  updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
  net_params = optix.apply_updates(net_params, updates)
  return net_params, opt_state, loss
def update(self, state, data):
  """Updates the state using some data and returns metrics."""
  rng, new_rng = jax.random.split(state['rng'])
  params = state['params']
  loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)
  updates, opt_state = self._opt.update(g, state['opt_state'])
  params = optix.apply_updates(params, updates)
  new_state = dict(
      step=state['step'] + 1,
      rng=new_rng,
      opt_state=opt_state,
      params=params,
  )
  metrics = dict(step=state['step'], loss=loss)
  return new_state, metrics
def sgd_step(
    state: TrainingState,
    samples: reverb.ReplaySample) -> Tuple[TrainingState, LearnerOutputs]:
  grad_fn = jax.grad(loss, has_aux=True)
  gradients, (keys, priorities) = grad_fn(state.params, state.target_params,
                                          samples)
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optix.apply_updates(state.params, updates)
  new_state = TrainingState(params=new_params,
                            target_params=state.target_params,
                            opt_state=new_opt_state,
                            step=state.step + 1)
  outputs = LearnerOutputs(keys=keys, priorities=priorities)
  return new_state, outputs
def optimize(
    fn_loss: Any,
    opt: Any,
    opt_state: Any,
    params_to_update: hk.Params,
    max_grad_norm: Optional[float],
    *args,
    **kwargs,
) -> Tuple[Any, hk.Params, jnp.ndarray, Any]:
  (loss, aux), grad = jax.value_and_grad(fn_loss, has_aux=True)(
      params_to_update, *args, **kwargs)
  if max_grad_norm is not None:
    grad = clip_gradient_norm(grad, max_grad_norm)
  update, opt_state = opt(grad, opt_state)
  params_to_update = optix.apply_updates(params_to_update, update)
  return opt_state, params_to_update, loss, aux
def sgd_step(state: TrainingState, sample: reverb.ReplaySample):
  # Compute gradients and optionally apply clipping.
  batch_loss = jax.vmap(loss, in_axes=(None, 0))
  mean_loss = lambda p, s: jnp.mean(batch_loss(p, s))
  grad_fn = jax.value_and_grad(mean_loss)
  loss_value, gradients = grad_fn(state.params, sample)
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optix.apply_updates(state.params, updates)
  metrics = {
      'loss': loss_value,
  }
  new_state = TrainingState(params=new_params, opt_state=new_opt_state)
  return new_state, metrics
def test_sgd(self):
  # experimental/optimizers.py
  jax_params = self.init_params
  opt_init, opt_update, get_params = optimizers.sgd(LR)
  state = opt_init(jax_params)
  for i in range(STEPS):
    state = opt_update(i, self.per_step_updates, state)
  jax_params = get_params(state)

  # experimental/optix.py
  optix_params = self.init_params
  sgd = optix.sgd(LR, 0.0)
  state = sgd.init(optix_params)
  for _ in range(STEPS):
    updates, state = sgd.update(self.per_step_updates, state)
    optix_params = optix.apply_updates(optix_params, updates)

  # Check equivalence.
  for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)):
    np.testing.assert_allclose(x, y, rtol=1e-5)
def inner_loop(rln_params, pln_params, x_spt, y_spt, opt_state):
  for i, (_x, _y) in enumerate(zip(x_spt, y_spt)):
    (loss, acc), grads = value_and_grad(loss_acc_fn, 1, has_aux=True)(
        rln_params, pln_params, _x, _y)
    if i == 0:
      initial_loss = loss
      initial_acc = acc
    updates, opt_state = opt_update_fn(grads, opt_state, pln_params)
    pln_params = optix.apply_updates(pln_params, updates)
  return (
      pln_params,
      {
          "initial_loss": initial_loss,
          "initial_acc": initial_acc,
          "final_loss": loss,
          "final_acc": acc,
          "opt_state": opt_state,
      },
  )
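# A minimal, self-contained sketch of the pattern shared by all of the
# snippets above: take gradients with jax.grad, transform them with the
# optimizer's update function, and apply them with optix.apply_updates.
# The linear model, loss, and data here are hypothetical; `optix` is
# jax.experimental.optix (the same API later moved to the optax package).
import jax
import jax.numpy as jnp
from jax.experimental import optix

def loss(params, x, y):
  pred = jnp.dot(x, params['w']) + params['b']
  return jnp.mean((pred - y) ** 2)

params = {'w': jnp.zeros(3), 'b': jnp.zeros(())}
opt = optix.sgd(1e-2, 0.0)  # a GradientTransformation(init, update) pair
opt_state = opt.init(params)

@jax.jit
def step(params, opt_state, x, y):
  grads = jax.grad(loss)(params, x, y)
  updates, opt_state = opt.update(grads, opt_state)
  return optix.apply_updates(params, updates), opt_state

x = jnp.ones((8, 3))
y = jnp.ones((8,))
for _ in range(10):
  params, opt_state = step(params, opt_state, x, y)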