import torch
import torch.nn.functional as F

from ray.rllib.policy.sample_batch import SampleBatch
# NOTE: the import path of this helper varies across Ray versions
# (ray.rllib.utils.torch_ops in older releases).
from ray.rllib.utils.torch_ops import reduce_mean_ignore_inf


def compute_rainbow_q_values(policy, model, obs, explore, is_training=False):
    """Compute Q-values for both plain DQN and distributional (C51) DQN.

    Reference: https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/dqn_tf_policy.py
    """
    config = policy.config

    # Toggle train/eval mode so noisy layers (if any) can condition their
    # exploration noise on the training flag.
    if is_training:
        model.train()
    else:
        model.eval()
    model_out, state = model.get_embeddings({
        SampleBatch.CUR_OBS: obs,
        "is_training": is_training,
    }, [], None)

    out = model.get_advantages_or_q_values(model_out)
    if config["num_atoms"] > 1:
        (advantages_or_q_values, z, support_logits_per_action, logits, dist) = out
    else:
        (advantages_or_q_values, logits, dist) = out

    if config["dueling"]:
        state_value = model.get_state_value(model_out)

        if config["num_atoms"] > 1:
            support_logits_per_action_mean = torch.mean(support_logits_per_action, 1)

            support_logits_per_action_centered = support_logits_per_action - torch.unsqueeze(
                support_logits_per_action_mean, 1)

            support_logits_per_action = torch.unsqueeze(
                state_value, 1) + support_logits_per_action_centered

            support_prob_per_action = F.softmax(support_logits_per_action, dim=-1)

            q_values = torch.sum(z * support_prob_per_action, dim=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            advantages_mean = reduce_mean_ignore_inf(advantages_or_q_values, 1)
            advantages_centered = advantages_or_q_values - torch.unsqueeze(
                advantages_mean, 1)
            q_values = state_value + advantages_centered
    else:
        q_values = advantages_or_q_values

    return q_values, logits, dist
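
The `num_atoms > 1` branch above implements the distributional (C51-style) reduction: the model outputs per-action logits over a fixed support `z`, and the Q-values are the expectation of that categorical distribution. The standalone sketch below illustrates just that reduction; the shapes and the v_min/v_max values are illustrative assumptions, not taken from the snippet.

import torch
import torch.nn.functional as F

batch, num_actions, num_atoms = 2, 3, 51
v_min, v_max = -10.0, 10.0

# Fixed support z: the atom values the return distribution is defined over.
z = torch.linspace(v_min, v_max, num_atoms)                       # [num_atoms]

# Per-action logits over the support, as produced by the model head.
support_logits_per_action = torch.randn(batch, num_actions, num_atoms)

# Softmax over the atom dimension gives one categorical distribution per action.
support_prob_per_action = F.softmax(support_logits_per_action, dim=-1)

# Q(s, a) is the expected return under that distribution: sum_i z_i * p_i(s, a).
q_values = torch.sum(z * support_prob_per_action, dim=-1)         # [batch, num_actions]
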
Example #2
def compute_q_values(policy: Policy,
                     model: ModelV2,
                     obs: TensorType,
                     explore,
                     is_training: bool = False):
    config = policy.config

    model_out, state = model(
        {
            SampleBatch.CUR_OBS: obs,
            "is_training": is_training,
        }, [], None)

    if config["num_atoms"] > 1:
        (action_scores, z, support_logits_per_action, logits,
         probs_or_logits) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits,
         probs_or_logits) = model.get_q_value_distributions(model_out)

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if policy.config["num_atoms"] > 1:
            support_logits_per_action_mean = torch.mean(
                support_logits_per_action, dim=1)
            support_logits_per_action_centered = (
                support_logits_per_action -
                torch.unsqueeze(support_logits_per_action_mean, dim=1))
            support_logits_per_action = torch.unsqueeze(
                state_score, dim=1) + support_logits_per_action_centered
            support_prob_per_action = nn.functional.softmax(
                support_logits_per_action, dim=-1)
            value = torch.sum(z * support_prob_per_action, dim=-1)
            logits = support_logits_per_action
            probs_or_logits = support_prob_per_action
        else:
            advantages_mean = reduce_mean_ignore_inf(action_scores, 1)
            advantages_centered = action_scores - torch.unsqueeze(
                advantages_mean, 1)
            value = state_score + advantages_centered
    else:
        value = action_scores

    return value, logits, probs_or_logits
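
Both dueling branches average the advantages with `reduce_mean_ignore_inf` rather than a plain mean, so actions that have been masked to `-inf` do not drag the baseline down. The helper comes from RLlib's torch ops utilities (the exact module path varies across Ray versions); the sketch below shows the typical implementation idea and is an assumption, not a verbatim copy.

import torch


def reduce_mean_ignore_inf(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Mean over `dim` that treats -inf entries (e.g. masked-out actions) as absent."""
    mask = torch.ne(x, float("-inf"))
    x_zeroed = torch.where(mask, x, torch.zeros_like(x))
    return torch.sum(x_zeroed, dim) / torch.sum(mask.float(), dim)
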
Example #3
def compute_q_values(policy, model, obs, explore, is_training=False):
    if policy.config["num_atoms"] > 1:
        raise ValueError("torch DQN does not support distributional DQN yet!")

    model_out, state = model({
        SampleBatch.CUR_OBS: obs,
        "is_training": is_training,
    }, [], None)

    advantages_or_q_values = model.get_advantages_or_q_values(model_out)

    if policy.config["dueling"]:
        state_value = model.get_state_value(model_out)
        advantages_mean = reduce_mean_ignore_inf(advantages_or_q_values, 1)
        advantages_centered = advantages_or_q_values - torch.unsqueeze(
            advantages_mean, 1)
        q_values = state_value + advantages_centered
    else:
        q_values = advantages_or_q_values

    return q_values
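
In every non-distributional dueling branch the aggregation is Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)), which keeps the value/advantage decomposition identifiable by forcing the advantages to be zero-mean across actions. A minimal standalone sketch of that aggregation follows; the shapes are illustrative and a plain mean stands in for reduce_mean_ignore_inf.

import torch

batch, num_actions = 2, 4
state_value = torch.randn(batch, 1)             # V(s): one scalar per state
advantages = torch.randn(batch, num_actions)    # A(s, a): one score per action

# Center the advantages so they average to zero across actions, then add V(s).
advantages_mean = torch.mean(advantages, dim=1)
advantages_centered = advantages - torch.unsqueeze(advantages_mean, 1)
q_values = state_value + advantages_centered    # [batch, num_actions]
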