Example #1
        def loss(trajectory: buffer.Trajectory, rnn_unroll_state: RNNState):
            """"Computes a linear combination of the policy gradient loss and value loss
      and regularizes it with an entropy term."""
            inputs = pack(trajectory)

            # Dynamically unroll the network. This Haiku utility function unpacks the
            # list of input tensors such that the i^{th} row from each input tensor
            # is presented to the i^{th} unrolled RNN module.
            (logits, values, _, _,
             state_embeddings), new_rnn_unroll_state = hk.dynamic_unroll(
                 network, inputs, rnn_unroll_state)
            trajectory_len = trajectory.actions.shape[0]

            # Compute the combined loss given the output of the model.
            td_errors = rlax.td_lambda(
                v_tm1=values[:-1, 0],
                r_t=jnp.squeeze(trajectory.rewards, -1),
                discount_t=trajectory.discounts * discount,
                v_t=values[1:, 0],
                lambda_=jnp.array(td_lambda))
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1, 0],
                a_t=jnp.squeeze(trajectory.actions, 1),
                adv_t=td_errors,
                w_t=jnp.ones(trajectory_len))
            entropy_loss = jnp.mean(
                rlax.entropy_loss(logits[:-1, 0], jnp.ones(trajectory_len)))

            combined_loss = (actor_loss + critic_cost * critic_loss +
                             entropy_cost * entropy_loss)

            return combined_loss, new_rnn_unroll_state
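
The Haiku utility used above, hk.dynamic_unroll, steps a recurrent core along the leading (time) axis of its inputs while threading the recurrent state from one step to the next. Below is a minimal, self-contained sketch of that pattern in isolation; it is not part of the original example, and the core size, input shapes, and zero inputs are illustrative assumptions.

import haiku as hk
import jax
import jax.numpy as jnp

def unroll_fn(inputs, state):
    # Apply the same LSTM core to inputs[0], inputs[1], ..., carrying the
    # recurrent state from one time step to the next.
    core = hk.LSTM(hidden_size=8)
    return hk.dynamic_unroll(core, inputs, state)

unroll = hk.without_apply_rng(hk.transform(unroll_fn))

inputs = jnp.zeros((10, 1, 4))  # [time, batch, features]
state = hk.LSTMState(hidden=jnp.zeros((1, 8)), cell=jnp.zeros((1, 8)))

params = unroll.init(jax.random.PRNGKey(0), inputs, state)
outputs, final_state = unroll.apply(params, inputs, state)
print(outputs.shape)  # (10, 1, 8)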
Example #2
        def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
            """"Actor-critic loss."""
            logits, values = network(trajectory.observations)
            td_errors = rlax.td_lambda(
                v_tm1=values[:-1],
                r_t=trajectory.rewards,
                discount_t=trajectory.discounts * discount,
                v_t=values[1:],
                lambda_=jnp.array(td_lambda),
            )
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=trajectory.actions,
                adv_t=td_errors,
                w_t=jnp.ones_like(td_errors))

            return actor_loss + critic_loss
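
For reference, the sketch below spells out the shapes rlax.td_lambda consumes for a length-T trajectory: T + 1 value estimates (the last one is the bootstrap value) against T rewards and discounts, yielding T temporal-difference errors. It is not part of the original example; the toy numbers are assumptions.

import jax.numpy as jnp
import rlax

T = 5
values = jnp.zeros(T + 1)      # V(s_0), ..., V(s_T); the final entry bootstraps the return
rewards = jnp.ones(T)          # r_1, ..., r_T
discounts = jnp.full(T, 0.99)  # per-step discount, e.g. gamma * (1 - done)

td_errors = rlax.td_lambda(
    v_tm1=values[:-1],
    r_t=rewards,
    discount_t=discounts,
    v_t=values[1:],
    lambda_=jnp.array(0.9),
)
print(td_errors.shape)  # (5,)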
Example #3
        def loss(trajectory: buffer.Trajectory) -> jnp.ndarray:
            """"Actor-critic loss."""
            observations, rewards, actions = pack(trajectory)
            logits, values, _, _, _ = network(observations, rewards, actions)

            td_errors = rlax.td_lambda(
                v_tm1=values[:-1],
                r_t=jnp.squeeze(trajectory.rewards, -1),
                discount_t=trajectory.discounts * discount,
                v_t=values[1:],
                lambda_=jnp.array(td_lambda))
            critic_loss = jnp.mean(td_errors**2)
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits[:-1],
                a_t=jnp.squeeze(trajectory.actions, 1),
                adv_t=td_errors,
                w_t=jnp.ones_like(td_errors))

            entropy_loss = jnp.mean(
                rlax.entropy_loss(logits[:-1], jnp.ones_like(td_errors)))

            return actor_loss + critic_cost * critic_loss + entropy_cost * entropy_loss
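
Similarly, the sketch below (not from the original source; the toy tensors are assumptions) shows what rlax.policy_gradient_loss expects: per-step logits over A actions, the integer actions taken, an advantage estimate such as the TD(lambda) errors, and per-step weights, all reduced to a single scalar loss.

import jax
import jax.numpy as jnp
import rlax

T, A = 5, 3
logits = jax.random.normal(jax.random.PRNGKey(0), (T, A))  # unnormalised policy logits
actions = jnp.zeros(T, dtype=jnp.int32)                    # actions actually taken
advantages = jnp.ones(T)                                   # e.g. the TD(lambda) errors
weights = jnp.ones(T)                                      # per-step loss weights

actor_loss = rlax.policy_gradient_loss(
    logits_t=logits,
    a_t=actions,
    adv_t=advantages,
    w_t=weights)
print(actor_loss.shape)  # () -- a scalar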
Example #4
    def loss(
        weights,
        observations,
        actions,
        rewards,
        td_lambda=0.2,
        discount=0.99,
        policy_cost=0.25,
        entropy_cost=1e-3,
    ):
        """Actor-critic loss."""
        logits, values = network(weights, observations)
        values = jnp.append(values, jnp.sum(rewards))

        # replace -inf values by tiny finite value
        logits = jnp.maximum(logits, MINIMUM_LOGIT)

        td_errors = rlax.td_lambda(
            v_tm1=values[:-1],
            r_t=rewards,
            discount_t=jnp.full_like(rewards, discount),
            v_t=values[1:],
            lambda_=td_lambda,
        )
        critic_loss = jnp.mean(td_errors ** 2)

        if type_ == "a2c":
            actor_loss = rlax.policy_gradient_loss(
                logits_t=logits,
                a_t=actions,
                adv_t=td_errors,
                w_t=jnp.ones(td_errors.shape[0]),
            )
        elif type_ == "supervised":
            actor_loss = jnp.mean(cross_entropy(logits, actions))

        entropy_loss = -jnp.mean(entropy(logits))

        return policy_cost * actor_loss, critic_loss, entropy_cost * entropy_loss
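
Since this variant takes the parameters explicitly and returns its three terms separately, one plausible training step (a sketch, not from the original source) is to sum the terms and differentiate with jax.grad; `weights`, the batch tensors, and the optax optimiser settings below are assumptions.

import jax
import optax

def total_loss(weights, observations, actions, rewards):
    # Fold the separately returned terms back into a single scalar objective.
    actor_term, critic_term, entropy_term = loss(weights, observations, actions, rewards)
    return actor_term + critic_term + entropy_term

optimizer = optax.adam(3e-4)  # assumed optimiser and learning rate
opt_state = optimizer.init(weights)

grads = jax.grad(total_loss)(weights, observations, actions, rewards)
updates, opt_state = optimizer.update(grads, opt_state)
weights = optax.apply_updates(weights, updates)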
Example #5
    def loss(trajectory: sequence.Trajectory, rnn_unroll_state: LSTMState):
      """"Actor-critic loss."""
      (logits, values), new_rnn_unroll_state = hk.dynamic_unroll(
          network, trajectory.observations[:, None, ...], rnn_unroll_state)
      seq_len = trajectory.actions.shape[0]
      td_errors = rlax.td_lambda(
          v_tm1=values[:-1, 0],
          r_t=trajectory.rewards,
          discount_t=trajectory.discounts * discount,
          v_t=values[1:, 0],
          lambda_=jnp.array(td_lambda),
      )
      critic_loss = jnp.mean(td_errors**2)
      actor_loss = rlax.policy_gradient_loss(
          logits_t=logits[:-1, 0],
          a_t=trajectory.actions,
          adv_t=td_errors,
          w_t=jnp.ones(seq_len))
      entropy_loss = jnp.mean(
          rlax.entropy_loss(logits[:-1, 0], jnp.ones(seq_len)))

      combined_loss = actor_loss + critic_loss + entropy_cost * entropy_loss

      return combined_loss, new_rnn_unroll_state
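
Because this loss returns the new RNN state alongside the scalar, gradients are typically taken with has_aux=True so that JAX differentiates only the first output and passes the second through untouched. The call below is a sketch; loss_fn, params, trajectory, and rnn_state are assumed to come from wrapping the function above with hk.transform.

import jax

# (grads, aux) is returned because has_aux=True; aux is the new LSTM state.
grads, new_rnn_state = jax.grad(loss_fn, has_aux=True)(params, trajectory, rnn_state)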