Example #1
  def _policy_fun_all_timesteps(self, observations, lengths, state, rng):
    # Runs the joint policy-and-value network over every timestep of
    # `observations`, threading through the network weights, state and rng.
    return policy_based_utils.run_policy_all_timesteps(
        self._policy_and_value_net_apply,
        observations,
        self._policy_and_value_net_weights,
        state,
        rng,
        self.train_env.action_space,
    )
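For context, here is a toy sketch of the kind of joint policy-and-value apply function that gets passed into `run_policy_all_timesteps` above. The function `toy_policy_and_value_apply`, its weight layout and all shapes below are invented for illustration and are not the network from the source; the only grounded part is that the joint network yields per-timestep action log-probabilities and value predictions, which Example #2 below unpacks.

import jax
import jax.numpy as jnp


def toy_policy_and_value_apply(observations, weights):
  # Hypothetical stand-in: the real self._policy_and_value_net_apply also
  # threads state and rng; this only illustrates the two outputs and shapes.
  batch, time = observations.shape[:2]
  flat = observations.reshape(batch, time, -1)
  logits = flat @ weights['policy']            # (batch, time, n_actions)
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  values = (flat @ weights['value'])[..., 0]   # (batch, time)
  return log_probs, values


obs = jnp.ones((2, 3, 4))                      # (batch, time, obs_dim)
w = {'policy': jnp.ones((4, 5)), 'value': jnp.ones((4, 1))}
log_probs, values = toy_policy_and_value_apply(obs, w)
print(log_probs.shape, values.shape)           # (2, 3, 5) (2, 3)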
Example #2
# `critic_loss`, `actor_loss`, `entropy` and `policy_based_utils` are assumed
# to be defined or imported elsewhere in the same module (not shown here).
import jax.numpy as jnp


def combined_loss(new_weights,
                  observations,
                  actions,
                  target_values,
                  advantage_weights,
                  policy_and_value_net_apply,
                  action_space,
                  state=None,
                  rng=None):
    """Returns the critic loss, actor loss, entropy and the final state."""

    # TODO(afrozm): This is where we should eventually feed in the observations
    #  that precede this one. Currently the replay buffer returns only the
    #  observation as-is; for transformer-like policies we should also get all
    #  of the earlier observations, and the extra dimension will then be time.
    #  For now we reshape to (batch, 1, *obs_shape).
    observations = jnp.expand_dims(observations, axis=1)

    (log_probab_actions_new, value_predictions_new, state_new,
     unused_rng_new) = (policy_based_utils.run_policy_all_timesteps(
         policy_and_value_net_apply,
         observations,
         new_weights,
         state,
         rng,
         action_space,
     ))

    critic_loss_val, intermediate_state = critic_loss(observations,
                                                      target_values,
                                                      value_predictions_new,
                                                      state=state_new)
    actor_loss_val, final_state = actor_loss(actions,
                                             advantage_weights,
                                             log_probab_actions_new,
                                             state=intermediate_state)
    entropy_val = entropy(log_probab_actions_new)

    return critic_loss_val, actor_loss_val, entropy_val, final_state
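The helpers `critic_loss`, `actor_loss` and `entropy` used above come from the surrounding module and are not shown. As a rough, self-contained sketch of the three components `combined_loss` returns, the snippet below computes a squared-error critic term, an advantage-weighted negative-log-probability actor term and the entropy of the action distribution with plain `jax.numpy`. The function `example_loss_components` and all of its inputs are made up here; the module's actual loss functions may differ.

import jax
import jax.numpy as jnp


def example_loss_components(log_probab_actions, value_predictions, actions,
                            target_values, advantage_weights):
    # Critic: squared error between predicted and target values.
    critic = jnp.mean(jnp.square(value_predictions - target_values))
    # Actor: negative log-probability of the taken actions, weighted by the
    # given advantage weights (one common choice, not necessarily the one
    # implemented by the module's actor_loss).
    logp_taken = jnp.take_along_axis(
        log_probab_actions, actions[..., None], axis=-1)[..., 0]
    actor = -jnp.mean(advantage_weights * logp_taken)
    # Entropy of the discrete action distribution, averaged over batch/time.
    probs = jnp.exp(log_probab_actions)
    ent = -jnp.mean(jnp.sum(probs * log_probab_actions, axis=-1))
    return critic, actor, ent


# Toy inputs: batch of 4, a single timestep, 3 discrete actions.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
log_probs = jax.nn.log_softmax(jax.random.normal(k1, (4, 1, 3)))
values = jax.random.normal(k2, (4, 1))
actions = jnp.array([[0], [2], [1], [0]])
print(example_loss_components(log_probs, values, actions,
                              jnp.zeros((4, 1)), jnp.ones((4, 1))))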