def rnd_update_step(
    state: RNDTrainingState,
    transitions: types.Transition,
    loss_fn: RNDLoss,
    optimizer: optax.GradientTransformation
) -> Tuple[RNDTrainingState, Dict[str, jnp.ndarray]]:
  """Runs an update step on the given transitions.

  Args:
    state: The learner state.
    transitions: Transitions to update on.
    loss_fn: The loss function.
    optimizer: The optimizer of the predictor network.

  Returns:
    A new state and metrics.
  """
  loss, grads = jax.value_and_grad(loss_fn)(
      state.params, state.target_params, transitions=transitions)
  update, optimizer_state = optimizer.update(grads, state.optimizer_state)
  params = optax.apply_updates(state.params, update)

  new_state = RNDTrainingState(
      optimizer_state=optimizer_state,
      params=params,
      target_params=state.target_params,
      steps=state.steps + 1,
  )
  return new_state, {'rnd_loss': loss}
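
# Example (a minimal sketch, not part of the original module): inside a learner
# the update step above is typically jitted once, with the loss function and
# optimizer closed over, and then driven from a replay iterator. The names
# `rnd_loss`, `rnd_optimizer` and `replay_iterator` below are assumed
# placeholders, not names defined elsewhere in this file.
def _example_rnd_update_loop(state: RNDTrainingState,
                             rnd_loss: RNDLoss,
                             rnd_optimizer: optax.GradientTransformation,
                             replay_iterator,
                             num_steps: int) -> RNDTrainingState:
  """Sketch of repeatedly applying rnd_update_step to sampled transitions."""
  update = jax.jit(
      functools.partial(rnd_update_step, loss_fn=rnd_loss,
                        optimizer=rnd_optimizer))
  for _ in range(num_steps):
    transitions = next(replay_iterator)  # A batch of types.Transition.
    state, metrics = update(state, transitions)
    del metrics  # Would normally be logged, e.g. metrics['rnd_loss'].
  return state
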
def ail_update_step(
    state: DiscriminatorTrainingState,
    data: Tuple[types.Transition, types.Transition],
    optimizer: optax.GradientTransformation,
    ail_network: ail_networks.AILNetworks,
    loss_fn: losses.Loss
) -> Tuple[DiscriminatorTrainingState, losses.Metrics]:
  """Runs an update step on the given transitions.

  Args:
    state: The learner state.
    data: A pair of demonstration and replay-buffer transitions.
    optimizer: Discriminator optimizer.
    ail_network: AIL networks.
    loss_fn: Discriminator loss to minimize.

  Returns:
    A new state and metrics.
  """
  demo_transitions, rb_transitions = data
  key, discriminator_key, loss_key = jax.random.split(state.key, 3)

  def compute_loss(
      discriminator_params: networks_lib.Params) -> losses.LossOutput:
    discriminator_fn = functools.partial(
        ail_network.discriminator_network.apply,
        discriminator_params,
        state.policy_params,
        is_training=True,
        rng=discriminator_key)
    return loss_fn(discriminator_fn, state.discriminator_state,
                   demo_transitions, rb_transitions, loss_key)

  loss_grad = jax.grad(compute_loss, has_aux=True)
  grads, (loss, new_discriminator_state) = loss_grad(
      state.discriminator_params)
  update, optimizer_state = optimizer.update(
      grads, state.optimizer_state, params=state.discriminator_params)
  discriminator_params = optax.apply_updates(state.discriminator_params,
                                             update)

  new_state = DiscriminatorTrainingState(
      optimizer_state=optimizer_state,
      discriminator_params=discriminator_params,
      discriminator_state=new_discriminator_state,
      policy_params=state.policy_params,  # Not modified.
      key=key,
      steps=state.steps + 1,
  )
  return new_state, loss
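
# Example (a minimal sketch, not part of the original module): the `data`
# argument pairs one batch of demonstration transitions with one batch of
# replay-buffer transitions of the same size. `demo_iterator` and
# `replay_iterator` are assumed placeholders yielding types.Transition batches.
def _example_discriminator_update(
    state: DiscriminatorTrainingState,
    demo_iterator,
    replay_iterator,
    disc_optimizer: optax.GradientTransformation,
    ail_network: ail_networks.AILNetworks,
    disc_loss: losses.Loss
) -> Tuple[DiscriminatorTrainingState, losses.Metrics]:
  """Sketch of one discriminator update from paired demo/replay iterators."""
  data = (next(demo_iterator), next(replay_iterator))
  return ail_update_step(state, data, disc_optimizer, ail_network, disc_loss)
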
def train(network_def: nn.Module,
          optim: optax.GradientTransformation,
          alpha_optim: optax.GradientTransformation,
          optimizer_state: jnp.ndarray,
          alpha_optimizer_state: jnp.ndarray,
          network_params: flax.core.FrozenDict,
          target_params: flax.core.FrozenDict,
          log_alpha: jnp.ndarray,
          key: jnp.ndarray,
          states: jnp.ndarray,
          actions: jnp.ndarray,
          next_states: jnp.ndarray,
          rewards: jnp.ndarray,
          terminals: jnp.ndarray,
          cumulative_gamma: float,
          target_entropy: float,
          reward_scale_factor: float) -> Mapping[str, Any]:
  """Runs the training step.

  Returns a mapping of updated values and losses.

  Args:
    network_def: The SAC network definition.
    optim: The SAC optimizer (which also wraps the SAC parameters).
    alpha_optim: The optimizer for alpha.
    optimizer_state: The SAC optimizer state.
    alpha_optimizer_state: The alpha optimizer state.
    network_params: Parameters for SAC's online network.
    target_params: Parameters for SAC's target network.
    log_alpha: SAC's log_alpha parameter.
    key: An RNG key to use for random action selection.
    states: A batch of states.
    actions: A batch of actions.
    next_states: A batch of next states.
    rewards: A batch of rewards.
    terminals: A batch of terminals.
    cumulative_gamma: The discount factor to use.
    target_entropy: The target entropy for the agent.
    reward_scale_factor: A factor by which to scale rewards.

  Returns:
    A mapping from string keys to values, including updated optimizers and
    training statistics.
  """
  # Get the models from all the optimizers.
  frozen_params = network_params  # Used in loss_fn without applying gradients.

  batch_size = states.shape[0]
  actions = jnp.reshape(actions, (batch_size, -1))  # Flatten.

  def loss_fn(
      params: flax.core.FrozenDict,
      log_alpha: flax.core.FrozenDict,
      state: jnp.ndarray,
      action: jnp.ndarray,
      reward: jnp.ndarray,
      next_state: jnp.ndarray,
      terminal: jnp.ndarray,
      rng: jnp.ndarray) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Calculates the loss for one transition.

    Args:
      params: Parameters for the SAC network.
      log_alpha: SAC's log_alpha parameter.
      state: A single state vector.
      action: A single action vector.
      reward: A reward scalar.
      next_state: A next state vector.
      terminal: A terminal scalar.
      rng: An RNG key to use for sampling actions.

    Returns:
      A tuple containing 1) the combined SAC loss and 2) a mapping containing
      statistics from the loss step.
    """
    rng1, rng2 = jax.random.split(rng, 2)

    # J_Q(\theta) from equation (5) in the paper.
    q_value_1, q_value_2 = network_def.apply(
        params, state, action, method=network_def.critic)
    q_value_1 = jnp.squeeze(q_value_1)
    q_value_2 = jnp.squeeze(q_value_2)

    target_outputs = network_def.apply(target_params, next_state, rng1, True)
    target_q_value_1, target_q_value_2 = target_outputs.critic
    target_q_value = jnp.squeeze(
        jnp.minimum(target_q_value_1, target_q_value_2))

    alpha_value = jnp.exp(log_alpha)
    log_prob = target_outputs.actor.log_probability
    target = reward_scale_factor * reward + cumulative_gamma * (
        target_q_value - alpha_value * log_prob) * (1. - terminal)
    target = jax.lax.stop_gradient(target)
    critic_loss_1 = losses.mse_loss(q_value_1, target)
    critic_loss_2 = losses.mse_loss(q_value_2, target)
    critic_loss = jnp.mean(critic_loss_1 + critic_loss_2)

    # J_{\pi}(\phi) from equation (9) in the paper.
    mean_action, sampled_action, action_log_prob = network_def.apply(
        params, state, rng2, method=network_def.actor)
    # We use frozen_params so that gradients can flow back to the actor
    # without being used to update the critic.
    q_value_no_grad_1, q_value_no_grad_2 = network_def.apply(
        frozen_params, state, sampled_action, method=network_def.critic)
    no_grad_q_value = jnp.squeeze(
        jnp.minimum(q_value_no_grad_1, q_value_no_grad_2))
    alpha_value = jnp.exp(jax.lax.stop_gradient(log_alpha))
    policy_loss = jnp.mean(alpha_value * action_log_prob - no_grad_q_value)

    # J(\alpha) from equation (18) in the paper.
    entropy_diff = -action_log_prob - target_entropy
    alpha_loss = jnp.mean(log_alpha * jax.lax.stop_gradient(entropy_diff))

    # Giving a smaller weight to the critic empirically gives better results.
    combined_loss = 0.5 * critic_loss + 1.0 * policy_loss + 1.0 * alpha_loss
    return combined_loss, {
        'critic_loss': critic_loss,
        'policy_loss': policy_loss,
        'alpha_loss': alpha_loss,
        'critic_value_1': q_value_1,
        'critic_value_2': q_value_2,
        'target_value_1': target_q_value_1,
        'target_value_2': target_q_value_2,
        'mean_action': mean_action
    }

  grad_fn = jax.vmap(
      jax.value_and_grad(loss_fn, argnums=(0, 1), has_aux=True),
      in_axes=(None, None, 0, 0, 0, 0, 0, 0))

  rng = jnp.stack(jax.random.split(key, num=batch_size))
  (_, aux_vars), gradients = grad_fn(network_params, log_alpha, states,
                                     actions, rewards, next_states, terminals,
                                     rng)

  # This calculates the mean gradient/aux_vars using the individual
  # gradients/aux_vars from each item in the batch.
  gradients = jax.tree_map(functools.partial(jnp.mean, axis=0), gradients)
  aux_vars = jax.tree_map(functools.partial(jnp.mean, axis=0), aux_vars)
  network_gradient, alpha_gradient = gradients

  # Apply gradients to all the optimizers.
  updates, optimizer_state = optim.update(network_gradient, optimizer_state,
                                          params=network_params)
  network_params = optax.apply_updates(network_params, updates)
  alpha_updates, alpha_optimizer_state = alpha_optim.update(
      alpha_gradient, alpha_optimizer_state, params=log_alpha)
  log_alpha = optax.apply_updates(log_alpha, alpha_updates)

  # Compile everything into a dict.
  returns = {
      'network_params': network_params,
      'log_alpha': log_alpha,
      'optimizer_state': optimizer_state,
      'alpha_optimizer_state': alpha_optimizer_state,
      'Losses/Critic': aux_vars['critic_loss'],
      'Losses/Actor': aux_vars['policy_loss'],
      'Losses/Alpha': aux_vars['alpha_loss'],
      'Values/CriticValues1': jnp.mean(aux_vars['critic_value_1']),
      'Values/CriticValues2': jnp.mean(aux_vars['critic_value_2']),
      'Values/TargetValues1': jnp.mean(aux_vars['target_value_1']),
      'Values/TargetValues2': jnp.mean(aux_vars['target_value_2']),
      'Values/Alpha': jnp.exp(log_alpha),
  }
  for i, a in enumerate(aux_vars['mean_action']):
    returns.update({f'Values/MeanActions{i}': a})
  return returns
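
# Example (a minimal sketch, not taken from the surrounding agent): `train` is
# a pure function of its inputs, so it can be jitted once and reused across
# training steps. The network definition, the optax transformations and the
# scalar hyperparameters are treated as compile-time constants here by marking
# them static; `jitted_train` is an assumed name, not part of the original
# code.
jitted_train = jax.jit(
    train,
    static_argnames=('network_def', 'optim', 'alpha_optim', 'cumulative_gamma',
                     'target_entropy', 'reward_scale_factor'))
# Calling jitted_train(...) with the same arguments as train(...) returns the
# mapping above; the caller would read back 'network_params', 'log_alpha' and
# the two optimizer states for the next step and log the remaining entries.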