Code example #1
File: SAC.py  Project: henry-prior/jax-rl
def get_td_target(
    rng: PRNGSequence,
    state: jnp.ndarray,
    action: jnp.ndarray,
    next_state: jnp.ndarray,
    reward: jnp.ndarray,
    not_done: jnp.ndarray,
    discount: float,
    max_action: float,
    action_dim: int,
    actor_params: FrozenDict,
    critic_target_params: FrozenDict,
    log_alpha_params: FrozenDict,
) -> jnp.ndarray:
    next_action, next_log_p = apply_gaussian_policy_model(
        actor_params, action_dim, max_action, next_state, rng, True, False)

    target_Q1, target_Q2 = apply_double_critic_model(critic_target_params,
                                                     next_state, next_action,
                                                     False)
    target_Q = (jnp.minimum(target_Q1, target_Q2) -
                jnp.exp(apply_constant_model(log_alpha_params, -3.5, False)) *
                next_log_p)
    target_Q = reward + not_done * discount * target_Q

    return target_Q
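
The returned value is the soft (entropy-regularized) TD target used by SAC: reward + discount * not_done * (min(Q1, Q2) - exp(log_alpha) * log_pi(a'|s')), with the next action sampled from the current actor at the next state. Below is a sketch of how it might be jit-compiled and called; the jit wrapping and the choice of static argument indices are assumptions, not code from the repository:

import jax

# Hypothetical wiring: the scalar configuration arguments (discount, max_action,
# action_dim) are marked static here; which of them actually must be static
# depends on how the apply_* helpers construct their modules.
get_td_target_jit = jax.jit(get_td_target, static_argnums=(6, 7, 8))

target_Q = get_td_target_jit(
    rng, state, action, next_state, reward, not_done,
    discount, max_action, action_dim,
    actor_params, critic_target_params, log_alpha_params,
)
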
Code example #2
def get_td_target(
    rng: PRNGSequence,
    state: jnp.ndarray,
    action: jnp.ndarray,
    next_state: jnp.ndarray,
    reward: jnp.ndarray,
    not_done: jnp.ndarray,
    discount: float,
    max_action: float,
    action_dim: int,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
) -> jnp.ndarray:
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, next_state, None, False, True
    )
    next_action = mu + jnp.exp(log_sig) * random.normal(rng, mu.shape)
    next_action = max_action * nn.tanh(next_action)

    target_Q1, target_Q2 = apply_double_critic_model(
        critic_target_params, next_state, next_action, False
    )
    target_Q = jnp.minimum(target_Q1, target_Q2)
    target_Q = reward + not_done * discount * target_Q

    return target_Q
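
Compared with example #1, this variant samples the next action from a target actor and drops the entropy bonus, so the target reduces to reward + discount * not_done * min(Q1, Q2). The target parameters it reads are usually maintained by Polyak averaging of the online parameters; the helper below is a common pattern sketched under that assumption, not code from the repository:

import jax

def soft_update(params, target_params, tau=0.005):
    # Polyak averaging: target <- tau * online + (1 - tau) * target
    return jax.tree_util.tree_map(
        lambda p, tp: tau * p + (1.0 - tau) * tp, params, target_params
    )

critic_target_params = soft_update(critic_params, critic_target_params)
actor_target_params = soft_update(actor_params, actor_target_params)
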
Code example #3
File: SAC.py  Project: henry-prior/jax-rl
def loss_fn(actor_params):
    # sample a reparameterized action from the current policy at `state`
    actor_action, log_p = apply_gaussian_policy_model(
        actor_params, action_dim, max_action, state, rng, True, False)
    q1, q2 = apply_double_critic_model(
        critic_params, state, actor_action, False)
    min_q = jnp.minimum(q1, q2)
    # the log temperature is treated as a constant for the actor update
    # via stop_gradient
    partial_loss_fn = jax.vmap(
        partial(
            actor_loss_fn,
            jax.lax.stop_gradient(
                apply_constant_model(log_alpha_params, -3.5, False)),
        ))
    actor_loss = partial_loss_fn(log_p, min_q)
    return jnp.mean(actor_loss), log_p
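
Because this closure returns the scalar loss together with log_p as an auxiliary value, the surrounding actor update has to differentiate it with has_aux=True. A minimal sketch of that wiring (how the gradients are then applied is left to the training loop):

import jax

# log_p is returned alongside the loss; in SAC it is typically reused for the
# temperature (alpha) update.
(actor_loss, log_p), actor_grads = jax.value_and_grad(loss_fn, has_aux=True)(actor_params)
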
Code example #4
def sample_actions_and_evaluate(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    To build our nonparametric policy, q(s, a), we sample `action_sample_size`
    actions from each policy in the batch and evaluate their Q-values.
    """
    # get the policy distribution for each state and sample `action_sample_size`
    # actions from each
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, state, None, False, True
    )
    mu = jnp.expand_dims(mu, axis=1)
    sig = jnp.expand_dims(jnp.exp(log_sig), axis=1)
    sampled_actions = (
        mu + random.normal(rng, (batch_size, action_sample_size, action_dim)) * sig
    )
    sampled_actions = sampled_actions.reshape(
        (batch_size * action_sample_size, action_dim)
    )

    sampled_actions = jax.lax.stop_gradient(sampled_actions)

    states_repeated = jnp.repeat(state, action_sample_size, axis=0)

    # evaluate each of the sampled actions at their corresponding state
    # we keep the `sampled_actions` array unsquashed because we need to calculate
    # the log probabilities using it, but we pass the squashed actions to the critic
    Q1 = apply_double_critic_model(
        critic_target_params,
        states_repeated,
        max_action * nn.tanh(sampled_actions),
        True,
    )
    Q1 = Q1.reshape((batch_size, action_sample_size))

    Q1 = jax.lax.stop_gradient(Q1)

    return Q1, sampled_actions
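
In MPO these Q-values are typically converted into per-state weights over the sampled actions with a temperature-scaled softmax (the E-step), which is the nonparametric policy the docstring refers to. A minimal sketch, assuming a dual temperature variable eta maintained elsewhere:

import jax

# weight each candidate action by exp(Q / eta), normalized within each state
weights = jax.nn.softmax(Q1 / eta, axis=1)  # (batch_size, action_sample_size)
weights = jax.lax.stop_gradient(weights)
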
Code example #5
File: MPO.py  Project: d3sm0/jax-rl
def sample_actions_and_evaluate(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    To build our nonparametric policy, q(s, a), we sample `action_sample_size`
    actions from each policy in the batch and evaluate their Q-values.
    """
    # get the policy distribution for each state and sample `action_sample_size`
    # actions from each
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, state, None, False, True
    )
    # broadcast the per-state mean and scale against the sample dimension
    mu = jnp.expand_dims(mu, axis=1)
    sig = jnp.expand_dims(jnp.exp(log_sig), axis=1)
    sampled_actions = (
        mu + random.normal(rng, (batch_size, action_sample_size, action_dim)) * sig
    )
    # this variant squashes the actions before evaluating them with the critic
    sampled_actions = max_action * nn.tanh(sampled_actions)
    sampled_actions = sampled_actions.reshape(
        (batch_size * action_sample_size, action_dim)
    )

    sampled_actions = jax.lax.stop_gradient(sampled_actions)

    states_repeated = jnp.repeat(state, action_sample_size, axis=0)

    # evaluate each of the sampled actions at their corresponding state
    Q1 = apply_double_critic_model(
        critic_target_params, states_repeated, sampled_actions, True
    )
    Q1 = Q1.reshape((batch_size, action_sample_size))

    Q1 = jax.lax.stop_gradient(Q1)

    return Q1, sampled_actions
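
Since batch_size, action_sample_size, and action_dim determine array shapes inside this function, they need to be static if it is jit-compiled. The wrapping below and the argument indices are assumptions, not code from the repository:

import jax

sample_actions_and_evaluate_jit = jax.jit(
    sample_actions_and_evaluate, static_argnums=(4, 6, 7)
)
Q1, sampled_actions = sample_actions_and_evaluate_jit(
    rng, actor_target_params, critic_target_params,
    max_action, action_dim, state, batch_size, action_sample_size,
)
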
Code example #6
File: SAC.py  Project: henry-prior/jax-rl
def loss_fn(critic_params):
    current_Q1, current_Q2 = apply_double_critic_model(
        critic_params, state, action, False)
    # regress both critic heads onto the shared TD target
    critic_loss = double_mse(current_Q1, current_Q2, target_Q)
    return jnp.mean(critic_loss)
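
double_mse is not shown in this listing; a common definition regresses both critic heads onto the same stopped-gradient target, and the closure is then differentiated directly. Both pieces below are assumptions about the surrounding code, not excerpts from the repository:

import jax
import jax.numpy as jnp

def double_mse(q1, q2, target):
    # sum of squared errors of the two critic heads against a fixed TD target
    target = jax.lax.stop_gradient(target)
    return jnp.square(q1 - target) + jnp.square(q2 - target)

critic_loss, critic_grads = jax.value_and_grad(loss_fn)(critic_params)
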