Code example #1
File: vtrace_torch.py  Project: wallacetroy/ray
def from_logits(behaviour_policy_logits,
                target_policy_logits,
                actions,
                discounts,
                rewards,
                values,
                bootstrap_value,
                dist_class=TorchCategorical,
                model=None,
                clip_rho_threshold=1.0,
                clip_pg_rho_threshold=1.0):
    """multi_from_logits wrapper used only for tests"""

    res = multi_from_logits(
        [behaviour_policy_logits], [target_policy_logits], [actions],
        discounts,
        rewards,
        values,
        bootstrap_value,
        dist_class,
        model,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    assert len(res.behaviour_action_log_probs) == 1
    assert len(res.target_action_log_probs) == 1
    return VTraceFromLogitsReturns(
        vs=res.vs,
        pg_advantages=res.pg_advantages,
        log_rhos=res.log_rhos,
        behaviour_action_log_probs=res.behaviour_action_log_probs[0],
        target_action_log_probs=res.target_action_log_probs[0],
    )
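Below is a minimal, hypothetical usage sketch of this test-only wrapper. The shapes, values, and the T/B/NUM_ACTIONS names are made up for illustration, and it assumes the rest of vtrace_torch.py (and RLlib's TorchCategorical default) is importable alongside the snippet above.

import torch

T, B, NUM_ACTIONS = 5, 2, 3  # illustrative time, batch, and action-count sizes

behaviour_logits = torch.randn(T, B, NUM_ACTIONS)
target_logits = torch.randn(T, B, NUM_ACTIONS)
actions = torch.randint(0, NUM_ACTIONS, (T, B))
discounts = torch.full((T, B), 0.99)
rewards = torch.randn(T, B)
values = torch.randn(T, B)
bootstrap_value = torch.randn(B)

res = from_logits(
    behaviour_logits,
    target_logits,
    actions,
    discounts,
    rewards,
    values,
    bootstrap_value)

# vs and pg_advantages come back with the same [T, B] layout as rewards/values.
assert res.vs.shape == (T, B)
assert res.pg_advantages.shape == (T, B)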
Code example #2
File: vtrace_torch.py  Project: wallacetroy/ray
def multi_from_logits(behaviour_policy_logits,
                      target_policy_logits,
                      actions,
                      discounts,
                      rewards,
                      values,
                      bootstrap_value,
                      dist_class,
                      model,
                      behaviour_action_log_probs=None,
                      clip_rho_threshold=1.0,
                      clip_pg_rho_threshold=1.0):
    """V-trace for softmax policies.

    Calculates V-trace actor-critic targets for softmax policies as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    Target policy refers to the policy we are interested in improving and
    behaviour policy refers to the policy that generated the given
    rewards and actions.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1. B refers to the batch size and
    ACTION_SPACE refers to a list of integers, each giving the number of
    actions in one component of the (possibly multi-discrete) action space.

    Args:
        behaviour_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing the softmax behavior policy.
        target_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing the softmax target policy.
        actions: A list with length of ACTION_SPACE of tensors of shapes
            [T, B, ...], ..., [T, B, ...]
            with actions sampled from the behavior policy.
        discounts: A float32 tensor of shape [T, B] with the discounts
            encountered when following the behavior policy.
        rewards: A float32 tensor of shape [T, B] with the rewards generated by
            following the behavior policy.
        values: A float32 tensor of shape [T, B] with the value function
            estimates wrt. the target policy.
        bootstrap_value: A float32 tensor of shape [B] with the value function
            estimate at time T.
        dist_class: The action distribution class for the logits.
        model: The backing ModelV2 instance.
        behaviour_action_log_probs: Optional precomputed log-probabilities of
            the behaviour actions under the behaviour policy.
        clip_rho_threshold: A scalar float32 tensor with the clipping threshold
            for importance weights (rho) when calculating the baseline targets
            (vs). rho^bar in the paper.
        clip_pg_rho_threshold: A scalar float32 tensor with the clipping
            threshold on rho_s in:
            \rho_s \delta \log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).

    Returns:
        A `VTraceFromLogitsReturns` namedtuple with the following fields:
        vs: A float32 tensor of shape [T, B]. Can be used as target to train a
            baseline (V(x_t) - vs_t)^2.
        pg_advantages: A float32 tensor of shape [T, B]. Can be used as an
            estimate of the advantage in the calculation of policy gradients.
        log_rhos: A float32 tensor of shape [T, B] containing the log
            importance sampling weights (log rhos).
        behaviour_action_log_probs: A float32 tensor of shape [T, B] containing
            behaviour policy action log probabilities (log \mu(a_t)).
        target_action_log_probs: A float32 tensor of shape [T, B] containing
            target policy action log probabilities (log \pi(a_t)).
    """

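    # Inputs may arrive as numpy arrays; normalize the logits and actions to
    # CPU torch tensors before the shape checks below.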
    behaviour_policy_logits = convert_to_torch_tensor(
        behaviour_policy_logits, device="cpu")
    target_policy_logits = convert_to_torch_tensor(
        target_policy_logits, device="cpu")
    actions = convert_to_torch_tensor(actions, device="cpu")

    for i in range(len(behaviour_policy_logits)):
        # Make sure tensor ranks are as expected.
        # The rest will be checked by from_action_log_probs.
        assert len(behaviour_policy_logits[i].size()) == 3
        assert len(target_policy_logits[i].size()) == 3

    target_action_log_probs = multi_log_probs_from_logits_and_actions(
        target_policy_logits, actions, dist_class, model)

    if (len(behaviour_policy_logits) > 1
            or behaviour_action_log_probs is None):
        # Can't use the precalculated values; recompute them. Note that
        # recomputing won't work well for autoregressive action dists,
        # which may have variables not captured by 'logits'.
        behaviour_action_log_probs = (multi_log_probs_from_logits_and_actions(
            behaviour_policy_logits, actions, dist_class, model))

    behaviour_action_log_probs = force_list(behaviour_action_log_probs)
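    # log_rhos = log pi(a|s) - log mu(a|s), summed across the ACTION_SPACE
    # components (see get_log_rhos).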
    log_rhos = get_log_rhos(target_action_log_probs,
                            behaviour_action_log_probs)

    vtrace_returns = from_importance_weights(
        log_rhos=log_rhos,
        discounts=discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    return VTraceFromLogitsReturns(
        log_rhos=log_rhos,
        behaviour_action_log_probs=behaviour_action_log_probs,
        target_action_log_probs=target_action_log_probs,
        **vtrace_returns._asdict())
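For reference, the returned vs and pg_advantages come out of the V-trace recursion that from_importance_weights evaluates downstream. The sketch below is a hedged, self-contained reference implementation of that recursion as described in the IMPALA paper, not Ray's code; the function name vtrace_reference and its default thresholds are made up for illustration.

import torch


def vtrace_reference(log_rhos, discounts, rewards, values, bootstrap_value,
                     clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0):
    # delta_t = rho_t * (r_t + gamma_t * V(x_{t+1}) - V(x_t))
    # vs_t    = V(x_t) + delta_t + gamma_t * c_t * (vs_{t+1} - V(x_{t+1}))
    # where rho_t and c_t are clipped importance weights exp(log_rhos).
    rhos = torch.exp(log_rhos)
    clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold)
    cs = torch.clamp(rhos, max=1.0)  # c_bar fixed at 1, as in the paper

    # values_t_plus_1[t] = V(x_{t+1}); the bootstrap value closes the sequence.
    values_t_plus_1 = torch.cat(
        [values[1:], bootstrap_value.unsqueeze(0)], dim=0)
    deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

    # Backward recursion for the correction term vs_t - V(x_t).
    vs_minus_v = torch.zeros_like(values)
    acc = torch.zeros_like(bootstrap_value)
    for t in reversed(range(values.shape[0])):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        vs_minus_v[t] = acc
    vs = vs_minus_v + values

    # Policy-gradient advantages bootstrap from vs_{t+1}.
    vs_t_plus_1 = torch.cat([vs[1:], bootstrap_value.unsqueeze(0)], dim=0)
    clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold)
    pg_advantages = clipped_pg_rhos * (
        rewards + discounts * vs_t_plus_1 - values)
    return vs, pg_advantages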