def sgd_step(
    state: TrainingState,
    samples: reverb.ReplaySample
) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
  """Performs an update step, averaging over pmap replicas."""

  # Compute loss and gradients.
  grad_fn = jax.value_and_grad(loss, has_aux=True)
  key, key_grad = jax.random.split(state.random_key)
  (loss_value, priorities), gradients = grad_fn(
      state.params, state.target_params, key_grad, samples)

  # Average gradients over pmap replicas before optimizer update.
  gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

  # Apply optimizer updates.
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optax.apply_updates(state.params, updates)

  # Periodically update target networks.
  steps = state.steps + 1
  target_params = rlax.periodic_update(
      new_params, state.target_params, steps, self._target_update_period)

  new_state = TrainingState(
      params=new_params,
      target_params=target_params,
      opt_state=new_opt_state,
      steps=steps,
      random_key=key)
  return new_state, priorities, {'loss': loss_value}
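# The jax.lax.pmean above only resolves when the step runs under jax.pmap with
# a matching axis name. A self-contained sketch of that pattern follows; the
# axis-name value, toy function and toy data are illustrative assumptions, not
# the setup surrounding the snippet above.
import jax
import jax.numpy as jnp

_PMAP_AXIS_NAME = 'devices'

def _toy_step(x):
  # Average a per-device value across all pmap replicas.
  return jax.lax.pmean(x * 2.0, _PMAP_AXIS_NAME)

num_devices = jax.local_device_count()
xs = jnp.arange(num_devices, dtype=jnp.float32)  # one entry per device
ys = jax.pmap(_toy_step, axis_name=_PMAP_AXIS_NAME)(xs)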
def sgd_step(
    state: TrainingState,
    samples: reverb.ReplaySample
) -> Tuple[TrainingState, LearnerOutputs]:
  # Compute gradients of the loss; the auxiliary output carries the replay
  # keys and updated priorities of the sampled items.
  grad_fn = jax.grad(loss, has_aux=True)
  gradients, (keys, priorities) = grad_fn(
      state.params, state.target_params, samples)

  # Apply the optimizer update.
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optax.apply_updates(state.params, updates)

  steps = state.steps + 1

  # Periodically update target networks.
  target_params = rlax.periodic_update(
      new_params, state.target_params, steps, self._target_update_period)

  new_state = TrainingState(
      params=new_params,
      target_params=target_params,
      opt_state=new_opt_state,
      steps=steps)

  outputs = LearnerOutputs(keys=keys, priorities=priorities)

  return new_state, outputs
def learner_step(self, params, data, learner_state, unused_key):
  # Periodically sync the target parameters to the online parameters.
  target_params = rlax.periodic_update(
      params.online, params.target, learner_state.count, self._target_period)
  # Differentiate the loss with respect to the online parameters only.
  dloss_dtheta = jax.grad(self._loss)(params.online, target_params, *data)
  updates, opt_state = self._optimizer.update(
      dloss_dtheta, learner_state.opt_state)
  online_params = optax.apply_updates(params.online, updates)
  return (
      Params(online_params, target_params),
      LearnerState(learner_state.count + 1, opt_state))
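# For reference, rlax.periodic_update(new, old, steps, period), used above and
# below, keeps the target parameters frozen and copies the online parameters
# into them only when `steps` is a multiple of `period`. A rough functional
# equivalent of those semantics (a sketch only, not the rlax implementation):
import jax
import jax.numpy as jnp

def periodic_update_sketch(new_params, old_params, steps, period):
  # Select `new_params` on update steps, otherwise keep `old_params`.
  do_update = steps % period == 0
  return jax.tree_util.tree_map(
      lambda new, old: jnp.where(do_update, new, old), new_params, old_params)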
def sgd_step(
    state: TrainingState,
    transitions: types.Transition,
) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

  # TODO(jaslanides): Use a shared forward pass for efficiency.
  policy_loss_and_grad = jax.value_and_grad(policy_loss)
  critic_loss_and_grad = jax.value_and_grad(critic_loss)

  # Compute losses and their gradients.
  policy_loss_value, policy_gradients = policy_loss_and_grad(
      state.policy_params, state.critic_params, transitions.next_observation)
  critic_loss_value, critic_gradients = critic_loss_and_grad(
      state.critic_params, state, transitions)

  # Get optimizer updates and state.
  policy_updates, policy_opt_state = policy_optimizer.update(  # pytype: disable=attribute-error
      policy_gradients, state.policy_opt_state)
  critic_updates, critic_opt_state = critic_optimizer.update(  # pytype: disable=attribute-error
      critic_gradients, state.critic_opt_state)

  # Apply optimizer updates to parameters.
  policy_params = optax.apply_updates(state.policy_params, policy_updates)
  critic_params = optax.apply_updates(state.critic_params, critic_updates)

  steps = state.steps + 1

  # Periodically update target networks.
  target_policy_params, target_critic_params = rlax.periodic_update(
      (policy_params, critic_params),
      (state.target_policy_params, state.target_critic_params),
      steps,
      self._target_update_period)

  new_state = TrainingState(
      policy_params=policy_params,
      critic_params=critic_params,
      target_policy_params=target_policy_params,
      target_critic_params=target_critic_params,
      policy_opt_state=policy_opt_state,
      critic_opt_state=critic_opt_state,
      steps=steps,
  )

  metrics = {
      'policy_loss': policy_loss_value,
      'critic_loss': critic_loss_value,
  }

  return new_state, metrics
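# The actor-critic step above assumes two independent optax optimizers whose
# states live alongside the parameters in TrainingState. A minimal sketch of
# how they might be constructed and initialised; the Adam choice, learning
# rates and the toy parameter pytrees are assumptions for illustration, not
# the source configuration.
import jax.numpy as jnp
import optax

policy_optimizer = optax.adam(1e-4)
critic_optimizer = optax.adam(1e-4)

# Toy parameter pytrees standing in for the real network parameters.
toy_policy_params = {'w': jnp.zeros((3, 2))}
toy_critic_params = {'w': jnp.zeros((5, 1))}
policy_opt_state = policy_optimizer.init(toy_policy_params)
critic_opt_state = critic_optimizer.init(toy_critic_params)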
def sgd_step(state: TrainingState,
             batch: reverb.ReplaySample,
             key: jnp.DeviceArray) -> Tuple[TrainingState, LossExtra]:
  # Implements one SGD step of the loss and updates training state.
  (loss, extra), grads = jax.value_and_grad(self._loss, has_aux=True)(
      state.params, state.target_params, batch, key)
  extra.metrics.update({'total_loss': loss})

  # Apply the optimizer updates.
  updates, new_opt_state = optimizer.update(grads, state.opt_state)
  new_params = optax.apply_updates(state.params, updates)

  # Periodically update target networks.
  steps = state.steps + 1
  target_params = rlax.periodic_update(
      new_params, state.target_params, steps, target_update_period)

  new_training_state = TrainingState(
      new_params, target_params, new_opt_state, steps)
  return new_training_state, extra
def _learner_step(self, all_params, all_states, batch):
  # Periodically refresh the target parameters from the online parameters.
  target_params = rlax.periodic_update(
      all_params.online,
      all_params.target,
      all_states.learner_steps,
      self._target_update_interval,
  )
  # Gradients are taken with respect to all parameters, but only the online
  # leaf is passed to the optimizer update.
  grad, info = jax.grad(self._loss, has_aux=True)(all_params, batch)
  updates, optimizer_state = self._optimizer.update(
      grad.online, all_states.optimizer)
  online_params = optax.apply_updates(all_params.online, updates)
  return (
      AllParams(online=online_params, target=target_params),
      AllStates(
          optimizer=optimizer_state,
          learner_steps=all_states.learner_steps + 1,
          actor_steps=all_states.actor_steps,
      ),
      info,
  )
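# AllParams / AllStates above are simple parameter and state containers. A
# hypothetical sketch of how they could be declared, with field names taken
# from the usage above; the NamedTuple choice and the `Any` typing are
# assumptions, not the source definitions.
from typing import Any, NamedTuple

class AllParams(NamedTuple):
  online: Any  # online network parameters
  target: Any  # target network parameters

class AllStates(NamedTuple):
  optimizer: Any      # optax optimizer state
  learner_steps: Any  # number of learner updates performed
  actor_steps: Any    # number of environment steps taken by actors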