Example 1
        def loss(trajectory: buffer.Trajectory, rnn_unroll_state: RNNState):
            """"Computes a linear combination of the policy gradient loss and value loss
      and regularizes it with an entropy term."""
            inputs = pack(trajectory)

            # Dynamically unroll the network. This Haiku utility function unpacks the
            # list of input tensors such that the i^{th} row from each input tensor
            # is presented to the i^{th} unrolled RNN module.
            (logits, values, _, _,
             state_embeddings), new_rnn_unroll_state = hk.dynamic_unroll(
                 network, inputs, rnn_unroll_state)
            trajectory_len = trajectory.actions.shape[0]

            # Compute the combined loss given the output of the model.
            td_errors = rlax.td_lambda(v_tm1=values[:-1, 0],
                                       r_t=jnp.squeeze(trajectory.rewards, -1),
                                       discount_t=trajectory.discounts *
                                       discount,
                                       v_t=values[1:, 0],
                                       lambda_=jnp.array(td_lambda))
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1, 0],
                a_t=jnp.squeeze(trajectory.actions, 1),
                adv_t=td_errors,
                w_t=jnp.ones(trajectory_len))
            entropy_loss = jnp.mean(
                rlax.entropy_loss(logits[:-1, 0], jnp.ones(trajectory_len)))

            combined_loss = (actor_loss + critic_cost * critic_loss +
                             entropy_cost * entropy_loss)

            return combined_loss, new_rnn_unroll_state
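
Because the loss above calls hk.dynamic_unroll, it has to run inside a Haiku transform, and its second return value (the new RNN unroll state) is carried as an auxiliary output. Below is a minimal sketch of how such a loss is typically consumed; the optimizer, the learning rate and the names sgd_step / opt_state are illustrative assumptions, not part of the example.

import haiku as hk
import jax
import optax

# Make the parameters explicit: apply(params, trajectory, rnn_state).
loss_fn = hk.without_apply_rng(hk.transform(loss)).apply
optimizer = optax.adam(3e-4)  # assumed optimizer and learning rate

@jax.jit
def sgd_step(params, opt_state, trajectory, rnn_state):
    # has_aux=True because the loss also returns the new RNN unroll state.
    grads, new_rnn_state = jax.grad(loss_fn, has_aux=True)(
        params, trajectory, rnn_state)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, new_rnn_state
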
Example 2
  def loss_fn(params: hk.Params,
              sample: reverb.ReplaySample
              ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Batched, entropy-regularised actor-critic loss with V-trace."""

    # Extract the data.
    data = sample.data
    observations, actions, rewards, discounts, extra = (data.observation,
                                                        data.action,
                                                        data.reward,
                                                        data.discount,
                                                        data.extras)
    initial_state = tree.map_structure(lambda s: s[0], extra['core_state'])
    behaviour_logits = extra['logits']

    # Apply reward clipping.
    rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

    # Unroll current policy over observations.
    (logits, values), _ = unroll_fn(params, observations, initial_state)

    # Compute importance sampling weights: current policy / behavior policy.
    rhos = rlax.categorical_importance_sampling_ratios(logits[:-1],
                                                       behaviour_logits[:-1],
                                                       actions[:-1])

    # Critic loss.
    vtrace_returns = rlax.vtrace_td_error_and_advantage(
        v_tm1=values[:-1],
        v_t=values[1:],
        r_t=rewards[:-1],
        discount_t=discounts[:-1] * discount,
        rho_tm1=rhos)
    critic_loss = jnp.square(vtrace_returns.errors)

    # Policy gradient loss.
    policy_gradient_loss = rlax.policy_gradient_loss(
        logits_t=logits[:-1],
        a_t=actions[:-1],
        adv_t=vtrace_returns.pg_advantage,
        w_t=jnp.ones_like(rewards[:-1]))

    # Entropy regulariser.
    entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards[:-1]))

    # Combine the actor, critic and entropy losses into a weighted sum, averaged over the sequence.
    mean_loss = jnp.mean(policy_gradient_loss + baseline_cost * critic_loss +
                         entropy_cost * entropy_loss)  # []

    metrics = {
        'policy_loss': jnp.mean(policy_gradient_loss),
        'critic_loss': jnp.mean(baseline_cost * critic_loss),
        'entropy_loss': jnp.mean(entropy_cost * entropy_loss),
        'entropy': jnp.mean(entropy_loss),
    }

    return mean_loss, metrics
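
For reference, the importance weights computed above are the per-step ratios rho_t = pi(a_t|s_t) / mu(a_t|s_t) between the current policy and the behaviour policy that generated the data. Below is a toy, self-contained sketch (the logits and action are made up for illustration) of what rlax.categorical_importance_sampling_ratios returns.

import jax
import jax.numpy as jnp
import rlax

pi_logits = jnp.array([[2.0, 0.0, 0.0]])  # current policy logits, shape [T=1, A=3]
mu_logits = jnp.array([[0.0, 0.0, 2.0]])  # behaviour policy logits
a_t = jnp.array([0])                      # action actually taken at each step

rhos = rlax.categorical_importance_sampling_ratios(pi_logits, mu_logits, a_t)
# The same quantity written out by hand: exp(log pi(a|s) - log mu(a|s)).
by_hand = jnp.exp(jax.nn.log_softmax(pi_logits)[0, a_t[0]] -
                  jax.nn.log_softmax(mu_logits)[0, a_t[0]])
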
Example 3
        def loss(params: hk.Params,
                 sample: reverb.ReplaySample) -> jnp.ndarray:
            """Entropy-regularised actor-critic loss."""

            # Extract the data.
            observations, actions, rewards, discounts, extra = sample.data
            initial_state = tree.map_structure(lambda s: s[0],
                                               extra['core_state'])
            behaviour_logits = extra['logits']

            # Slice off the final transition and apply reward clipping.
            actions = actions[:-1]  # [T-1]
            rewards = rewards[:-1]  # [T-1]
            discounts = discounts[:-1]  # [T-1]
            rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

            # Unroll current policy over observations.
            net = functools.partial(network.apply, params)
            (logits, values), _ = hk.static_unroll(net, observations,
                                                   initial_state)

            # Compute importance sampling weights: current policy / behavior policy.
            rhos = rlax.categorical_importance_sampling_ratios(
                logits[:-1], behaviour_logits[:-1], actions)

            # Critic loss.
            vtrace_returns = rlax.vtrace_td_error_and_advantage(
                v_tm1=values[:-1],
                v_t=values[1:],
                r_t=rewards,
                discount_t=discounts * discount,
                rho_t=rhos)
            critic_loss = jnp.square(vtrace_returns.errors)

            # Policy gradient loss.
            policy_gradient_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=actions,
                adv_t=vtrace_returns.pg_advantage,
                w_t=jnp.ones_like(rewards))

            # Entropy regulariser.
            entropy_loss = rlax.entropy_loss(logits[:-1],
                                             jnp.ones_like(rewards))

            # Combine the actor, critic and entropy losses into a weighted sum.
            mean_loss = jnp.mean(policy_gradient_loss +
                                 baseline_cost * critic_loss +
                                 entropy_cost * entropy_loss)

            return mean_loss
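
Example 2 handles a whole batch of sequences inside the loss, whereas this example is written for a single [T, ...] sequence. One common way to lift such a per-sequence loss to a batch is sketched below, assuming every leaf of the replay sample carries a leading batch axis; the names batched_loss, mean_loss and samples are assumptions.

import jax
import jax.numpy as jnp

# Share params across the batch (in_axes=None) and map over the leading
# batch axis of the sample (in_axes=0), then average the per-sequence losses.
batched_loss = jax.vmap(loss, in_axes=(None, 0))

def mean_loss(params, samples):
    return jnp.mean(batched_loss(params, samples))
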
Example 4
        def loss(trajectory: buffer.Trajectory) -> jnp.ndarray:
            """"Actor-critic loss."""
            observations, rewards, actions = pack(trajectory)
            logits, values, _, _, _ = network(observations, rewards, actions)

            td_errors = rlax.td_lambda(v_tm1=values[:-1],
                                       r_t=jnp.squeeze(trajectory.rewards, -1),
                                       discount_t=trajectory.discounts *
                                       discount,
                                       v_t=values[1:],
                                       lambda_=jnp.array(td_lambda))
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=jnp.squeeze(trajectory.actions, 1),
                adv_t=td_errors,
                w_t=jnp.ones_like(td_errors))

            entropy_loss = jnp.mean(
                rlax.entropy_loss(logits[:-1], jnp.ones_like(td_errors)))

            return actor_loss + critic_cost * critic_loss + entropy_cost * entropy_loss
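
As a reminder of what the critic is regressing on, rlax.td_lambda returns the TD(lambda) errors, i.e. the lambda-return minus the baseline value estimate at each step. Below is a toy sketch with made-up values; the arrays and the 0.99 / 0.9 constants are illustrative, not taken from the example.

import jax.numpy as jnp
import rlax

values = jnp.array([1.0, 2.0, 3.0, 0.5])   # v(s_0) .. v(s_3), length T + 1
rewards = jnp.array([0.0, 0.0, 1.0])       # r_1 .. r_3, length T
discounts = jnp.array([1.0, 1.0, 0.0])     # zero marks the episode end

td_errors = rlax.td_lambda(v_tm1=values[:-1],
                           r_t=rewards,
                           discount_t=discounts * 0.99,
                           v_t=values[1:],
                           lambda_=jnp.array(0.9))
# With lambda_=1 these become full-return (Monte-Carlo) errors; with
# lambda_=0 they reduce to one-step TD errors r_t + gamma*v(s_{t+1}) - v(s_t).
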
Example 5
    def loss(trajectory: sequence.Trajectory, rnn_unroll_state: LSTMState):
      """"Actor-critic loss."""
      (logits, values), new_rnn_unroll_state = hk.dynamic_unroll(
          network, trajectory.observations[:, None, ...], rnn_unroll_state)
      seq_len = trajectory.actions.shape[0]
      td_errors = rlax.td_lambda(
          v_tm1=values[:-1, 0],
          r_t=trajectory.rewards,
          discount_t=trajectory.discounts * discount,
          v_t=values[1:, 0],
          lambda_=jnp.array(td_lambda),
      )
      critic_loss = jnp.mean(td_errors**2)
      actor_loss = rlax.policy_gradient_loss(
          logits_t=logits[:-1, 0],
          a_t=trajectory.actions,
          adv_t=td_errors,
          w_t=jnp.ones(seq_len))
      entropy_loss = jnp.mean(
          rlax.entropy_loss(logits[:-1, 0], jnp.ones(seq_len)))

      combined_loss = actor_loss + critic_loss + entropy_cost * entropy_loss

      return combined_loss, new_rnn_unroll_state