def sgd_step(
    state: TrainingState,
    samples: reverb.ReplaySample
) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
  """Performs an update step, averaging over pmap replicas.

  Args:
    state: Current training state (online/target params, optimizer state,
      step counter and PRNG key).
    samples: Batch of replay samples to compute the loss on.

  Returns:
    A ``(new_state, priorities, metrics)`` triple; ``metrics`` carries the
    scalar loss under the key ``'loss'``.
  """
  # Advance the PRNG: one sub-key for this gradient step, one carried forward.
  next_key, grad_key = jax.random.split(state.random_key)

  # Loss and gradients in one pass; the auxiliary output is the per-sample
  # priorities returned alongside the new state.
  loss_and_grad = jax.value_and_grad(loss, has_aux=True)
  (loss_value, priorities), grads = loss_and_grad(
      state.params, state.target_params, grad_key, samples)

  # Synchronize gradients across pmap replicas before the optimizer update.
  grads = jax.lax.pmean(grads, _PMAP_AXIS_NAME)

  updates, opt_state = optimizer.update(grads, state.opt_state)
  params = optax.apply_updates(state.params, updates)

  # Copy the online params into the target network every
  # `self._target_update_period` steps; otherwise keep the old targets.
  steps = state.steps + 1
  target_params = optax.periodic_update(
      params, state.target_params, steps, self._target_update_period)

  new_state = TrainingState(
      params=params,
      target_params=target_params,
      opt_state=opt_state,
      steps=steps,
      random_key=next_key)
  return new_state, priorities, {'loss': loss_value}
def learner_step(self, params, data, learner_state, unused_key):
  """Runs one SGD step on the online network.

  Args:
    self: Agent holding `_loss`, `_optimizer` and `_target_period`.
    params: `Params` pair with `online` and `target` network parameters.
    data: Loss arguments, unpacked positionally after the parameters.
    learner_state: `LearnerState` with the step `count` and `opt_state`.
    unused_key: Ignored PRNG key (kept for interface compatibility).

  Returns:
    A `(Params, LearnerState)` pair with updated online parameters,
    (possibly) refreshed target parameters, and an incremented counter.
  """
  # Refresh the target network every `self._target_period` steps; otherwise
  # the previous target parameters are kept unchanged.
  target = optax.periodic_update(
      params.online, params.target, learner_state.count, self._target_period)
  # Differentiate the loss w.r.t. the online parameters only.
  grads = jax.grad(self._loss)(params.online, target, *data)
  updates, new_opt_state = self._optimizer.update(
      grads, learner_state.opt_state)
  new_online = optax.apply_updates(params.online, updates)
  new_params = Params(new_online, target)
  new_learner_state = LearnerState(learner_state.count + 1, new_opt_state)
  return (new_params, new_learner_state)
def sgd_step(
    state: TrainingState,
    transitions: types.Transition,
) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
  """Performs one actor-critic SGD step.

  Args:
    state: Current training state (policy/critic params, their targets,
      optimizer states, step counter and PRNG key).
    transitions: Batch of transitions to train on.

  Returns:
    A ``(new_state, metrics)`` pair; ``metrics`` reports 'policy_loss'
    and 'critic_loss'.
  """
  # Split the key: one sub-key per loss, one carried into the new state.
  next_key, policy_key, critic_key = jax.random.split(state.key, 3)

  # Losses and gradients for both networks.
  policy_loss_value, policy_grads = policy_loss_and_grad(
      state.policy_params, state.critic_params, transitions, policy_key)
  critic_loss_value, critic_grads = critic_loss_and_grad(
      state.critic_params, state.target_policy_params,
      state.target_critic_params, transitions, critic_key)

  # Optimizer transformations for each parameter set.
  policy_updates, new_policy_opt_state = policy_optimizer.update(
      policy_grads, state.policy_opt_state)
  critic_updates, new_critic_opt_state = critic_optimizer.update(
      critic_grads, state.critic_opt_state)

  new_policy_params = optax.apply_updates(state.policy_params, policy_updates)
  new_critic_params = optax.apply_updates(state.critic_params, critic_updates)

  # Both target networks are refreshed together every
  # `target_update_period` steps; otherwise the old targets are kept.
  steps = state.steps + 1
  new_target_policy_params, new_target_critic_params = optax.periodic_update(
      (new_policy_params, new_critic_params),
      (state.target_policy_params, state.target_critic_params),
      steps,
      target_update_period)

  new_state = TrainingState(
      policy_params=new_policy_params,
      target_policy_params=new_target_policy_params,
      critic_params=new_critic_params,
      target_critic_params=new_target_critic_params,
      policy_opt_state=new_policy_opt_state,
      critic_opt_state=new_critic_opt_state,
      steps=steps,
      key=next_key,
  )
  metrics = {
      'policy_loss': policy_loss_value,
      'critic_loss': critic_loss_value,
  }
  return new_state, metrics
def sgd_step(
    state: TrainingState,
    transitions: types.Transition,
) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
  """Performs one actor-critic SGD step (no PRNG key in the state).

  Args:
    state: Current training state (policy/critic params, their targets,
      optimizer states and step counter).
    transitions: Batch of transitions to train on.

  Returns:
    A ``(new_state, metrics)`` pair; ``metrics`` reports 'policy_loss'
    and 'critic_loss'.
  """
  # TODO(jaslanides): Use a shared forward pass for efficiency.
  policy_loss_and_grad = jax.value_and_grad(policy_loss)
  critic_loss_and_grad = jax.value_and_grad(critic_loss)

  # Losses and gradients. NOTE(review): `critic_loss` receives the whole
  # `state` (not just target params) — presumably it pulls the targets out
  # itself; confirm against its definition.
  policy_loss_value, policy_grads = policy_loss_and_grad(
      state.policy_params, state.critic_params, transitions.next_observation)
  critic_loss_value, critic_grads = critic_loss_and_grad(
      state.critic_params, state, transitions)

  # Optimizer transformations for each parameter set.
  policy_updates, new_policy_opt_state = policy_optimizer.update(  # pytype: disable=attribute-error
      policy_grads, state.policy_opt_state)
  critic_updates, new_critic_opt_state = critic_optimizer.update(  # pytype: disable=attribute-error
      critic_grads, state.critic_opt_state)

  new_policy_params = optax.apply_updates(state.policy_params, policy_updates)
  new_critic_params = optax.apply_updates(state.critic_params, critic_updates)

  # Both target networks are refreshed together every
  # `self._target_update_period` steps; otherwise the old targets are kept.
  steps = state.steps + 1
  new_target_policy_params, new_target_critic_params = optax.periodic_update(
      (new_policy_params, new_critic_params),
      (state.target_policy_params, state.target_critic_params),
      steps,
      self._target_update_period)

  new_state = TrainingState(
      policy_params=new_policy_params,
      critic_params=new_critic_params,
      target_policy_params=new_target_policy_params,
      target_critic_params=new_target_critic_params,
      policy_opt_state=new_policy_opt_state,
      critic_opt_state=new_critic_opt_state,
      steps=steps,
  )
  metrics = {
      'policy_loss': policy_loss_value,
      'critic_loss': critic_loss_value,
  }
  return new_state, metrics
def sgd_step(
    state: TrainingState,
    batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
  """Implements one SGD step of the loss and updates the training state.

  Args:
    state: Current training state (online/target params, optimizer state,
      step counter and PRNG key).
    batch: Replay sample to compute the loss on.

  Returns:
    A ``(new_state, extra)`` pair; ``extra.metrics`` additionally contains
    the scalar loss under 'total_loss'.
  """
  # One sub-key is consumed by the loss; the other is carried forward.
  next_rng_key, rng_key = jax.random.split(state.rng_key)

  # Loss, auxiliary outputs and gradients in a single differentiation pass.
  (loss, extra), grads = jax.value_and_grad(self._loss, has_aux=True)(
      state.params, state.target_params, batch, rng_key)
  extra.metrics.update({'total_loss': loss})

  # Optimizer update applied to the online parameters.
  updates, new_opt_state = optimizer.update(grads, state.opt_state)
  new_params = optax.apply_updates(state.params, updates)

  # Copy online params into the target network every
  # `target_update_period` steps; otherwise keep the old targets.
  steps = state.steps + 1
  new_target_params = optax.periodic_update(
      new_params, state.target_params, steps, target_update_period)

  new_training_state = TrainingState(
      new_params, new_target_params, new_opt_state, steps, next_rng_key)
  return new_training_state, extra