Example #1
def rnd_update_step(
    state: RNDTrainingState, transitions: types.Transition, loss_fn: RNDLoss,
    optimizer: optax.GradientTransformation
) -> Tuple[RNDTrainingState, Dict[str, jnp.ndarray]]:
    """Run an update steps on the given transitions.

  Args:
    state: The learner state.
    transitions: Transitions to update on.
    loss_fn: The loss function.
    optimizer: The optimizer of the predictor network.

  Returns:
    A new state and metrics.
  """
    loss, grads = jax.value_and_grad(loss_fn)(state.params,
                                              state.target_params,
                                              transitions=transitions)

    update, optimizer_state = optimizer.update(grads, state.optimizer_state)
    params = optax.apply_updates(state.params, update)

    new_state = RNDTrainingState(
        optimizer_state=optimizer_state,
        params=params,
        target_params=state.target_params,
        steps=state.steps + 1,
    )
    return new_state, {'rnd_loss': loss}
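
A minimal, self-contained sketch of how this kind of update step is typically driven: a toy stand-in for the RND setup (plain linear maps instead of the predictor/target networks, and a hypothetical ToyState container instead of RNDTrainingState) that shows the same value_and_grad / optimizer.update / apply_updates pattern being initialized, jitted, and called.

import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp
import optax


class ToyState(NamedTuple):
  params: jnp.ndarray          # Predictor parameters (trained).
  target_params: jnp.ndarray   # Fixed random target (never trained).
  optimizer_state: optax.OptState
  steps: jnp.ndarray


def toy_loss(params, target_params, observations):
  # RND regresses a predictor onto a frozen, randomly initialized target;
  # plain linear maps stand in for the networks here.
  prediction = observations @ params
  target = jax.lax.stop_gradient(observations @ target_params)
  return jnp.mean((prediction - target) ** 2)


def toy_update_step(state, observations, optimizer):
  loss, grads = jax.value_and_grad(toy_loss)(
      state.params, state.target_params, observations)
  update, optimizer_state = optimizer.update(grads, state.optimizer_state)
  params = optax.apply_updates(state.params, update)
  new_state = state._replace(
      params=params, optimizer_state=optimizer_state, steps=state.steps + 1)
  return new_state, {'rnd_loss': loss}


optimizer = optax.adam(1e-3)
key_p, key_t, key_o = jax.random.split(jax.random.PRNGKey(0), 3)
params = jax.random.normal(key_p, (4, 2))
state = ToyState(params, jax.random.normal(key_t, (4, 2)),
                 optimizer.init(params), jnp.zeros([], jnp.int32))
observations = jax.random.normal(key_o, (32, 4))

update = jax.jit(functools.partial(toy_update_step, optimizer=optimizer))
state, metrics = update(state, observations)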
Example #2
def ail_update_step(
        state: DiscriminatorTrainingState, data: Tuple[types.Transition,
                                                       types.Transition],
        optimizer: optax.GradientTransformation,
        ail_network: ail_networks.AILNetworks, loss_fn: losses.Loss
) -> Tuple[DiscriminatorTrainingState, losses.Metrics]:
    """Run an update steps on the given transitions.

  Args:
    state: The learner state.
    data: Demonstration (demo) and replay-buffer (rb) transitions.
    optimizer: Discriminator optimizer.
    ail_network: AIL networks.
    loss_fn: Discriminator loss to minimize.

  Returns:
    A new state and metrics.
  """
    demo_transitions, rb_transitions = data
    key, discriminator_key, loss_key = jax.random.split(state.key, 3)

    def compute_loss(
            discriminator_params: networks_lib.Params) -> losses.LossOutput:
        discriminator_fn = functools.partial(
            ail_network.discriminator_network.apply,
            discriminator_params,
            state.policy_params,
            is_training=True,
            rng=discriminator_key)
        return loss_fn(discriminator_fn, state.discriminator_state,
                       demo_transitions, rb_transitions, loss_key)

    loss_grad = jax.grad(compute_loss, has_aux=True)

    grads, (loss, new_discriminator_state) = loss_grad(
        state.discriminator_params)

    update, optimizer_state = optimizer.update(
        grads, state.optimizer_state, params=state.discriminator_params)
    discriminator_params = optax.apply_updates(state.discriminator_params,
                                               update)

    new_state = DiscriminatorTrainingState(
        optimizer_state=optimizer_state,
        discriminator_params=discriminator_params,
        discriminator_state=new_discriminator_state,
        policy_params=state.policy_params,  # Not modified.
        key=key,
        steps=state.steps + 1,
    )
    return new_state, loss
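
The piece worth highlighting here is jax.grad(..., has_aux=True): compute_loss returns the scalar loss to differentiate together with auxiliary outputs (the metrics and the updated discriminator state), so a single backward pass yields both the gradients and the new state. A self-contained toy illustration of that pattern follows; the logistic-style loss, the integer state, and the names below are placeholders, not the actual AIL networks or losses.

import jax
import jax.numpy as jnp


def toy_discriminator_loss(params, inputs, state):
  logits = inputs @ params
  loss = jnp.mean(jax.nn.softplus(-logits))  # Placeholder discriminator loss.
  metrics = {'loss': loss}
  new_state = state + 1                      # e.g. updated counters/statistics.
  return loss, (metrics, new_state)          # (value to differentiate, aux)


grad_fn = jax.grad(toy_discriminator_loss, has_aux=True)
grads, (metrics, new_state) = grad_fn(
    jnp.ones((3, 1)), jnp.ones((8, 3)), jnp.zeros([], jnp.int32))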
Example #3
def train(network_def: nn.Module,
          optim: optax.GradientTransformation,
          alpha_optim: optax.GradientTransformation,
          optimizer_state: jnp.ndarray,
          alpha_optimizer_state: jnp.ndarray,
          network_params: flax.core.FrozenDict,
          target_params: flax.core.FrozenDict,
          log_alpha: jnp.ndarray,
          key: jnp.ndarray,
          states: jnp.ndarray,
          actions: jnp.ndarray,
          next_states: jnp.ndarray,
          rewards: jnp.ndarray,
          terminals: jnp.ndarray,
          cumulative_gamma: float,
          target_entropy: float,
          reward_scale_factor: float) -> Mapping[str, Any]:
  """Run the training step.

  Returns a mapping of updated values and losses.

  Args:
    network_def: The SAC network definition.
    optim: The SAC optimizer (which also wraps the SAC parameters).
    alpha_optim: The optimizer for alpha.
    optimizer_state: The SAC optimizer state.
    alpha_optimizer_state: The alpha optimizer state.
    network_params: Parameters for SAC's online network.
    target_params: The parameters for SAC's target network.
    log_alpha: Parameters for alpha network.
    key: An rng key to use for random action selection.
    states: A batch of states.
    actions: A batch of actions.
    next_states: A batch of next states.
    rewards: A batch of rewards.
    terminals: A batch of terminals.
    cumulative_gamma: The discount factor to use.
    target_entropy: The target entropy for the agent.
    reward_scale_factor: A factor by which to scale rewards.

  Returns:
    A mapping from string keys to values, including updated optimizers and
      training statistics.
  """
  # Get the models from all the optimizers.
  frozen_params = network_params  # Used in loss_fn without receiving gradient updates.

  batch_size = states.shape[0]
  actions = jnp.reshape(actions, (batch_size, -1))  # Flatten

  def loss_fn(
      params: flax.core.FrozenDict, log_alpha: flax.core.FrozenDict,
      state: jnp.ndarray, action: jnp.ndarray, reward: jnp.ndarray,
      next_state: jnp.ndarray, terminal: jnp.ndarray,
      rng: jnp.ndarray) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Calculates the loss for one transition.

    Args:
      params: Parameters for the SAC network.
      log_alpha: SAC's log_alpha parameter.
      state: A single state vector.
      action: A single action vector.
      reward: A reward scalar.
      next_state: A next state vector.
      terminal: A terminal scalar.
      rng: An RNG key to use for sampling actions.

    Returns:
      A tuple containing 1) the combined SAC loss and 2) a mapping containing
        statistics from the loss step.
    """
    rng1, rng2 = jax.random.split(rng, 2)

    # J_Q(\theta) from equation (5) in paper.
    q_value_1, q_value_2 = network_def.apply(
        params, state, action, method=network_def.critic)
    q_value_1 = jnp.squeeze(q_value_1)
    q_value_2 = jnp.squeeze(q_value_2)

    target_outputs = network_def.apply(target_params, next_state, rng1, True)
    target_q_value_1, target_q_value_2 = target_outputs.critic
    target_q_value = jnp.squeeze(
        jnp.minimum(target_q_value_1, target_q_value_2))

    alpha_value = jnp.exp(log_alpha)
    log_prob = target_outputs.actor.log_probability
    target = reward_scale_factor * reward + cumulative_gamma * (
        target_q_value - alpha_value * log_prob) * (1. - terminal)
    target = jax.lax.stop_gradient(target)
    critic_loss_1 = losses.mse_loss(q_value_1, target)
    critic_loss_2 = losses.mse_loss(q_value_2, target)
    critic_loss = jnp.mean(critic_loss_1 + critic_loss_2)

    # J_{\pi}(\phi) from equation (9) in paper.
    mean_action, sampled_action, action_log_prob = network_def.apply(
        params, state, rng2, method=network_def.actor)

    # We use frozen_params so that gradients can flow back to the actor without
    # being used to update the critic.
    q_value_no_grad_1, q_value_no_grad_2 = network_def.apply(
        frozen_params, state, sampled_action, method=network_def.critic)
    no_grad_q_value = jnp.squeeze(
        jnp.minimum(q_value_no_grad_1, q_value_no_grad_2))
    alpha_value = jnp.exp(jax.lax.stop_gradient(log_alpha))
    policy_loss = jnp.mean(alpha_value * action_log_prob - no_grad_q_value)

    # J(\alpha) from equation (18) in paper.
    entropy_diff = -action_log_prob - target_entropy
    alpha_loss = jnp.mean(log_alpha * jax.lax.stop_gradient(entropy_diff))

    # Giving a smaller weight to the critic empirically gives better results.
    combined_loss = 0.5 * critic_loss + 1.0 * policy_loss + 1.0 * alpha_loss
    return combined_loss, {
        'critic_loss': critic_loss,
        'policy_loss': policy_loss,
        'alpha_loss': alpha_loss,
        'critic_value_1': q_value_1,
        'critic_value_2': q_value_2,
        'target_value_1': target_q_value_1,
        'target_value_2': target_q_value_2,
        'mean_action': mean_action
    }

  grad_fn = jax.vmap(
      jax.value_and_grad(loss_fn, argnums=(0, 1), has_aux=True),
      in_axes=(None, None, 0, 0, 0, 0, 0, 0))

  rng = jnp.stack(jax.random.split(key, num=batch_size))
  (_, aux_vars), gradients = grad_fn(network_params, log_alpha, states, actions,
                                     rewards, next_states, terminals, rng)

  # This calculates the mean gradient/aux_vars using the individual
  # gradients/aux_vars from each item in the batch.
  gradients = jax.tree_map(functools.partial(jnp.mean, axis=0), gradients)
  aux_vars = jax.tree_map(functools.partial(jnp.mean, axis=0), aux_vars)
  network_gradient, alpha_gradient = gradients

  # Apply gradients to all the optimizers.
  updates, optimizer_state = optim.update(network_gradient, optimizer_state,
                                          params=network_params)
  network_params = optax.apply_updates(network_params, updates)
  alpha_updates, alpha_optimizer_state = alpha_optim.update(
      alpha_gradient, alpha_optimizer_state, params=log_alpha)
  log_alpha = optax.apply_updates(log_alpha, alpha_updates)

  # Compile everything in a dict.
  returns = {
      'network_params': network_params,
      'log_alpha': log_alpha,
      'optimizer_state': optimizer_state,
      'alpha_optimizer_state': alpha_optimizer_state,
      'Losses/Critic': aux_vars['critic_loss'],
      'Losses/Actor': aux_vars['policy_loss'],
      'Losses/Alpha': aux_vars['alpha_loss'],
      'Values/CriticValues1': jnp.mean(aux_vars['critic_value_1']),
      'Values/CriticValues2': jnp.mean(aux_vars['critic_value_2']),
      'Values/TargetValues1': jnp.mean(aux_vars['target_value_1']),
      'Values/TargetValues2': jnp.mean(aux_vars['target_value_2']),
      'Values/Alpha': jnp.exp(log_alpha),
  }
  for i, a in enumerate(aux_vars['mean_action']):
    returns.update({f'Values/MeanActions{i}': a})
  return returns
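
A compact, self-contained sketch of the per-example gradient pattern used above: jax.value_and_grad (with argnums covering both parameter sets and has_aux for the statistics) is vmapped over the batch, and the per-example gradients and aux values are then averaged with jax.tree_map. The quadratic objective below is a placeholder, not the SAC loss.

import functools

import jax
import jax.numpy as jnp


def per_example_loss(params, log_alpha, state, action):
  # Placeholder objective over a single (state, action) pair.
  q = jnp.dot(state, params)
  loss = (q - jnp.dot(action, action)) ** 2 + jnp.exp(log_alpha) * q
  return loss, {'q': q}


grad_fn = jax.vmap(
    jax.value_and_grad(per_example_loss, argnums=(0, 1), has_aux=True),
    in_axes=(None, None, 0, 0))

params = jnp.ones(4)
log_alpha = jnp.zeros([])
states = jnp.ones((32, 4))
actions = jnp.ones((32, 2))

(_, aux), grads = grad_fn(params, log_alpha, states, actions)
# Average per-example gradients and statistics over the batch dimension.
grads = jax.tree_map(functools.partial(jnp.mean, axis=0), grads)
aux = jax.tree_map(functools.partial(jnp.mean, axis=0), aux)
param_grad, alpha_grad = grads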