Example 1
def get_td_target(
    rng: PRNGSequence,
    state: jnp.ndarray,
    action: jnp.ndarray,
    next_state: jnp.ndarray,
    reward: jnp.ndarray,
    not_done: jnp.ndarray,
    discount: float,
    max_action: float,
    action_dim: int,
    actor_params: FrozenDict,
    critic_target_params: FrozenDict,
    log_alpha_params: FrozenDict,
) -> jnp.ndarray:
    next_action, next_log_p = apply_gaussian_policy_model(
        actor_params, action_dim, max_action, next_state, rng, True, False)

    target_Q1, target_Q2 = apply_double_critic_model(critic_target_params,
                                                     next_state, next_action,
                                                     False)
    target_Q = (jnp.minimum(target_Q1, target_Q2) -
                jnp.exp(apply_constant_model(log_alpha_params, -3.5, False)) *
                next_log_p)
    target_Q = reward + not_done * discount * target_Q

    return target_Q
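
Example 1 is the entropy-regularized (soft) TD target used by SAC: the minimum of the two target Q-heads minus exp(log_alpha) times the next-action log-probability. The temperature exp(log_alpha) produced by `apply_constant_model` is usually tuned with its own objective. Below is a generic sketch of that temperature loss, not the project's exact helper; `log_alpha`, `log_p`, and `target_entropy` (commonly set to `-action_dim`) are assumed inputs.

import jax
import jax.numpy as jnp


def alpha_loss_fn(log_alpha: jnp.ndarray,
                  log_p: jnp.ndarray,
                  target_entropy: float) -> jnp.ndarray:
    # standard SAC temperature objective: raise alpha when the policy's
    # entropy drops below `target_entropy`, lower it otherwise. log_p is
    # stop-gradient'd so only the temperature is updated here.
    return -(log_alpha * jax.lax.stop_gradient(log_p + target_entropy)).mean()
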
Example 2
def get_td_target(
    rng: PRNGSequence,
    state: jnp.ndarray,
    action: jnp.ndarray,
    next_state: jnp.ndarray,
    reward: jnp.ndarray,
    not_done: jnp.ndarray,
    discount: float,
    max_action: float,
    action_dim: int,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
) -> jnp.ndarray:
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, next_state, None, False, True
    )
    next_action = mu + jnp.exp(log_sig) * random.normal(rng, mu.shape)
    next_action = max_action * nn.tanh(next_action)

    target_Q1, target_Q2 = apply_double_critic_model(
        critic_target_params, next_state, next_action, False
    )
    target_Q = jnp.minimum(target_Q1, target_Q2)
    target_Q = reward + not_done * discount * target_Q

    return target_Q
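
Either TD target above (Example 1's soft target or Example 2's clipped double-Q target) is typically regressed onto by both critic heads. A minimal sketch of that critic loss, assuming the `apply_double_critic_model` call seen in these examples; the project's actual training step may differ.

import jax
import jax.numpy as jnp


def critic_loss_fn(critic_params, state, action, target_Q):
    # regress both Q-heads onto the fixed (stop-gradient) TD target
    current_Q1, current_Q2 = apply_double_critic_model(
        critic_params, state, action, False)
    target_Q = jax.lax.stop_gradient(target_Q)
    return (jnp.square(current_Q1 - target_Q) +
            jnp.square(current_Q2 - target_Q)).mean()

# usage sketch (names as in the examples above):
# grads = jax.grad(critic_loss_fn)(critic_params, state, action, target_Q)
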
Example 3
    def loss_fn(mlo, slo, actor_params):
        # get the distribution of the actor network (current policy)
        mu, log_sig = apply_gaussian_policy_model(
            actor_params, action_dim, max_action, state, None, False, True
        )
        sig = jnp.exp(log_sig)
        # get the distribution of the target network (old policy)
        target_mu, target_log_sig = apply_gaussian_policy_model(
            actor_target_params, action_dim, max_action, state, None, False, True
        )
        target_mu = jax.lax.stop_gradient(target_mu)
        target_log_sig = jax.lax.stop_gradient(target_log_sig)
        target_sig = jnp.exp(target_log_sig)

        # get the log likelihoods of the sampled actions according to the
        # decoupled distributions described in Section 4.2.1 of
        # Relative Entropy Regularized Policy Iteration. this ensures that
        # the nonparametric policy won't collapse to give a probability of 1
        # to the best action, which is a risk when we use the on-policy
        # distribution to calculate the likelihood.
        actor_log_prob = gaussian_likelihood(sampled_actions, target_mu, log_sig)
        actor_log_prob += gaussian_likelihood(sampled_actions, mu, target_log_sig)
        actor_log_prob = actor_log_prob.transpose((0, 1))

        mu_kl = kl_mvg_diag(target_mu, target_sig, mu, target_sig).mean()
        sig_kl = kl_mvg_diag(target_mu, target_sig, target_mu, sig).mean()

        mlo = mu_lagrange_step(mlo, eps_mu - jax.lax.stop_gradient(mu_kl))
        slo = sig_lagrange_step(slo, eps_sig - jax.lax.stop_gradient(sig_kl))

        # maximize the weighted log likelihood, regularized by the divergence
        # between the target policy and the current policy. the goal is to fit
        # the parametric policy so that it has minimal divergence from the
        # nonparametric distribution built from the sampled actions.
        actor_loss = -(actor_log_prob * weights).sum(axis=1).mean()
        actor_loss -= jax.lax.stop_gradient(
            apply_constant_model(mlo.target, 1.0, True)
        ) * (eps_mu - mu_kl)
        actor_loss -= jax.lax.stop_gradient(
            apply_constant_model(slo.target, 100.0, True)
        ) * (eps_sig - sig_kl)
        return actor_loss.mean(), (mlo, slo)
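
Example 3 leans on two helpers that are not shown in the snippet: `gaussian_likelihood`, the log-density of a diagonal Gaussian, and `kl_mvg_diag`, the KL divergence between two diagonal-covariance Gaussians. The sketches below follow the standard closed-form expressions; the exact signatures, shapes, and broadcasting in the project may differ.

import jax.numpy as jnp


def gaussian_likelihood(x, mu, log_sig):
    # log N(x; mu, diag(exp(log_sig)^2)), summed over the last axis
    pre_sum = -0.5 * (((x - mu) / jnp.exp(log_sig)) ** 2
                      + 2.0 * log_sig + jnp.log(2.0 * jnp.pi))
    return pre_sum.sum(axis=-1)


def kl_mvg_diag(mu_1, sig_1, mu_2, sig_2):
    # KL( N(mu_1, diag(sig_1^2)) || N(mu_2, diag(sig_2^2)) ) in closed form
    return 0.5 * (
        jnp.log(sig_2 ** 2 / sig_1 ** 2)
        + (sig_1 ** 2 + (mu_1 - mu_2) ** 2) / sig_2 ** 2
        - 1.0
    ).sum(axis=-1)
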
Example 4
    def select_action(self, state: jnp.ndarray) -> jnp.ndarray:
        mu, _ = apply_gaussian_policy_model(
            self.actor_optimizer.target,
            self.action_dim,
            self.max_action,
            state,
            None,
            False,
            False,
        )
        return mu.flatten()
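
A short evaluation-loop sketch for the deterministic `select_action` above. The `agent` and `env` objects and the classic Gym `reset`/`step` API are assumptions for illustration only.

import jax.numpy as jnp

# hypothetical evaluation episode; `agent` wraps the method above
state, done, episode_return = env.reset(), False, 0.0
while not done:
    action = agent.select_action(jnp.asarray(state).reshape(1, -1))
    state, reward, done, _ = env.step(action)
    episode_return += reward
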
Example 5
File: MPO.py  Project: d3sm0/jax-rl
    def select_action(self, state: jnp.ndarray) -> jnp.ndarray:
        mu, _ = apply_gaussian_policy_model(
            self.actor_optimizer.target,
            self.state_dim,
            self.max_action,
            state.reshape(1, -1),
            None,
            False,
            True,
        )
        return mu
Example 6
    def sample_action(
        self, rng: PRNGSequence, state: jnp.ndarray
    ) -> jnp.ndarray:
        mu, log_sig = apply_gaussian_policy_model(
            self.actor_optimizer.target,
            self.action_dim,
            self.max_action,
            state,
            None,
            False,
            False,
        )
        return mu + jax.random.normal(rng, mu.shape) * jnp.exp(log_sig)
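
`sample_action` consumes a PRNG key, so the caller has to split the key every step to avoid reusing randomness. A minimal sketch of that pattern; `agent`, `state`, and the step count are assumed names.

import jax

rng = jax.random.PRNGKey(0)
num_steps = 1000  # illustration only
for _ in range(num_steps):
    # split so each call to sample_action gets a fresh, never-reused key
    rng, key = jax.random.split(rng)
    action = agent.sample_action(key, state)
    # ... step the environment with `action` and store the transition
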
Example 7
    def loss_fn(actor_params):
        actor_action, log_p = apply_gaussian_policy_model(
            actor_params, action_dim, max_action, state, rng, True, False)
        q1, q2 = apply_double_critic_model(critic_params, state, actor_action,
                                           False)
        min_q = jnp.minimum(q1, q2)
        partial_loss_fn = jax.vmap(
            partial(
                actor_loss_fn,
                jax.lax.stop_gradient(
                    apply_constant_model(log_alpha_params, -3.5, False)),
            ),
        )
        actor_loss = partial_loss_fn(log_p, min_q)
        return jnp.mean(actor_loss), log_p
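
Example 7 vmaps a per-sample `actor_loss_fn` whose definition is not shown, with the (stop-gradient) log-temperature bound as the first argument. Under the standard SAC actor objective it would look roughly like this sketch, which is an assumption rather than the project's exact helper.

import jax.numpy as jnp


def actor_loss_fn(log_alpha: jnp.ndarray,
                  log_p: jnp.ndarray,
                  min_q: jnp.ndarray) -> jnp.ndarray:
    # per-sample SAC actor objective: alpha * log_pi(a|s) - min(Q1, Q2)
    return jnp.exp(log_alpha) * log_p - min_q
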
Example 8
def sample_actions_and_evaluate(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    To build our nonparametric policy, q(s, a), we sample `action_sample_size`
    actions from each policy in the batch and evaluate their Q-values.
    """
    # get the policy distribution for each state and sample `action_sample_size`
    # actions from each
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, state, None, False, True
    )
    mu = jnp.expand_dims(mu, axis=1)
    sig = jnp.expand_dims(jnp.exp(log_sig), axis=1)
    sampled_actions = (
        mu + random.normal(rng, (batch_size, action_sample_size, action_dim)) * sig
    )
    sampled_actions = sampled_actions.reshape(
        (batch_size * action_sample_size, action_dim)
    )

    sampled_actions = jax.lax.stop_gradient(sampled_actions)

    states_repeated = jnp.repeat(state, action_sample_size, axis=0)

    # evaluate each of the sampled actions at their corresponding state.
    # we keep the `sampled_actions` array unsquashed because we need to calculate
    # the log probabilities from it, but we pass the squashed actions to the critic
    Q1 = apply_double_critic_model(
        critic_target_params,
        states_repeated,
        max_action * nn.tanh(sampled_actions),
        True,
    )
    Q1 = Q1.reshape((batch_size, action_sample_size))

    Q1 = jax.lax.stop_gradient(Q1)

    return Q1, sampled_actions
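
The `Q1` values returned here drive the MPO E-step: the `weights` used in Example 3 come from a temperature-scaled softmax over the sampled actions at each state, with the temperature itself obtained by minimizing a dual function. A heavily simplified sketch with a fixed temperature (the dual optimization and the project's exact weighting are omitted/assumed):

import jax
import jax.numpy as jnp


def nonparametric_weights(Q1: jnp.ndarray, temp: float) -> jnp.ndarray:
    # Q1: (batch_size, action_sample_size). Softmax over the sampled actions
    # turns Q-values into per-state weights for the nonparametric policy.
    weights = jax.nn.softmax(Q1 / temp, axis=1)
    return jax.lax.stop_gradient(weights)
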
Example 9
File: MPO.py  Project: d3sm0/jax-rl
def sample_actions_and_evaluate(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    To build our nonparametric policy, q(s, a), we sample `action_sample_size`
    actions from each policy in the batch and evaluate their Q-values.
    """
    # get the policy distribution for each state and sample `action_sample_size`
    # actions from each
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, state, None, False, True
    )
    mu = jnp.expand_dims(mu, axis=1)
    sig = jnp.expand_dims(jnp.exp(log_sig), axis=1)
    sampled_actions = (
        mu + random.normal(rng, (batch_size, action_sample_size, action_dim)) * sig
    )
    sampled_actions = max_action * nn.tanh(sampled_actions)
    sampled_actions = sampled_actions.reshape(
        (batch_size * action_sample_size, action_dim)
    )

    sampled_actions = jax.lax.stop_gradient(sampled_actions)

    states_repeated = jnp.repeat(state, action_sample_size, axis=0)

    # evaluate each of the sampled actions at their corresponding state
    Q1 = apply_double_critic_model(
        critic_target_params, states_repeated, sampled_actions, True
    )
    Q1 = Q1.reshape((batch_size, action_sample_size))

    Q1 = jax.lax.stop_gradient(Q1)

    return Q1, sampled_actions
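
Because `action_dim`, `batch_size`, and `action_sample_size` determine array shapes inside `sample_actions_and_evaluate`, they have to be static if the function is wrapped in `jax.jit`. A usage sketch, not necessarily how the project compiles it:

import jax

# positional arguments 4, 6 and 7 are action_dim, batch_size and
# action_sample_size; they fix array shapes, so they are marked static
sample_and_evaluate_jit = jax.jit(
    sample_actions_and_evaluate, static_argnums=(4, 6, 7)
)
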