Example #1
def get_distribution_inputs_and_class(policy,
                                      model,
                                      obs_batch,
                                      *,
                                      explore=True,
                                      is_training=False,
                                      **kwargs):
    """Returns distribution inputs (Q-values) and the action dist class."""
    q_vals = compute_q_values(policy, model, obs_batch, explore, is_training)
    # `compute_q_values` may return a tuple (e.g. Q-values plus logits/probs);
    # keep only the plain Q-value tensor.
    q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

    policy.q_values = q_vals
    # The empty list means: no RNN state outputs.
    return policy.q_values, TorchCategorical, []  # state-out
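The function above feeds raw Q-values in as the logits of a categorical action distribution. As a minimal, framework-free sketch of the same idea (plain PyTorch only; the toy network and its sizes are made up for illustration):

import torch
import torch.nn as nn

# Hypothetical toy Q-network; only the shape of its output matters here.
q_net = nn.Linear(4, 2)                   # 4-dim observation, 2 discrete actions
obs_batch = torch.randn(8, 4)

q_values = q_net(obs_batch)               # [batch, num_actions]
dist = torch.distributions.Categorical(logits=q_values)
sampled_actions = dist.sample()           # stochastic ("explore") actions
greedy_actions = q_values.argmax(dim=1)   # deterministic actions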
Example #2
def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(policy,
                                  model,
                                  train_batch,
                                  state_batches=state_batches,
                                  seq_lens=train_batch.get("seq_lens"),
                                  explore=False,
                                  is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(policy,
                                         policy.target_q_model,
                                         train_batch,
                                         state_batches=state_batches,
                                         seq_lens=train_batch.get("seq_lens"),
                                         explore=False,
                                         is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=policy.device))
        * one_hot_selection, 1)

    if config["double_q"]:
        best_actions = torch.argmax(q, dim=1)
    else:
        best_actions = torch.argmax(q_target, dim=1)

    best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
    q_target_best = torch.sum(
        torch.where(q_target > FLOAT_MIN, q_target,
                    torch.tensor(0.0, device=policy.device)) *
        best_actions_one_hot,
        dim=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat(
            [q_target_best[1:],
             torch.tensor([0.0], device=policy.device)])

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch["seq_lens"], T)[:, :-1]
        # Also mask out the burn-in period at the beginning of each sequence.
        burn_in = policy.config["burn_in"]
        if 0 < burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]
        policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
        policy._td_error = td_error.reshape([-1])
        policy._loss_stats = {
            "mean_q": reduce_mean_valid(q_selected),
            "min_q": torch.min(q_selected),
            "max_q": torch.max(q_selected),
            "mean_td_error": reduce_mean_valid(td_error),
        }

    return policy._total_loss
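The loss above relies on h_function and h_inverse, the invertible value rescaling from the R2D2 paper (Kapturowski et al., 2019). Assuming the standard formulation h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x and its closed-form inverse, a sketch of what these helpers compute:

import torch

def h_function(x, epsilon=1e-3):
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + epsilon * x

def h_inverse(x, epsilon=1e-3):
    # Closed-form inverse of h, obtained by solving h(y) = x for y.
    return torch.sign(x) * (
        ((torch.sqrt(1.0 + 4.0 * epsilon * (torch.abs(x) + 1.0 + epsilon)) - 1.0)
         / (2.0 * epsilon)) ** 2 - 1.0)

Rescaling the target with h keeps the effective magnitude of returns bounded, which stabilizes training without reward clipping.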
Example #3
def build_q_losses_wt_additional_logs(
    policy: Policy, model, _, train_batch: SampleBatch
) -> TensorType:
    """
    Copy of build_q_losses with additional values saved into the policy
    Made only 2 changes, see in comments.
    """
    config = policy.config
    # Q-network evaluation.
    q_t, q_logits_t, q_probs_t = compute_q_values(
        policy,
        policy.q_model,
        train_batch[SampleBatch.CUR_OBS],
        explore=False,
        is_training=True,
    )

    # Addition 1 out of 2
    policy.last_q_t = q_t.clone()

    # Target Q-network evaluation.
    q_tp1, q_logits_tp1, q_probs_tp1 = compute_q_values(
        policy,
        policy.target_q_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=True,
    )

    # Addition 2 out of 2
    policy.last_target_q_t = q_tp1.clone()

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(
        train_batch[SampleBatch.ACTIONS], policy.action_space.n
    )
    q_t_selected = torch.sum(
        torch.where(
            q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=policy.device)
        )
        * one_hot_selection,
        1,
    )
    q_logits_t_selected = torch.sum(
        q_logits_t * torch.unsqueeze(one_hot_selection, -1), 1
    )

    # Compute an estimate of the best possible value starting from the state at t+1.
    if config["double_q"]:
        (
            q_tp1_using_online_net,
            q_logits_tp1_using_online_net,
            q_dist_tp1_using_online_net,
        ) = compute_q_values(
            policy,
            policy.q_model,
            train_batch[SampleBatch.NEXT_OBS],
            explore=False,
            is_training=True,
        )
        q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
        q_tp1_best_one_hot_selection = F.one_hot(
            q_tp1_best_using_online_net, policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN,
                q_tp1,
                torch.tensor(0.0, device=policy.device),
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )
    else:
        q_tp1_best_one_hot_selection = F.one_hot(
            torch.argmax(q_tp1, 1), policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN,
                q_tp1,
                torch.tensor(0.0, device=policy.device),
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )

    if PRIO_WEIGHTS not in train_batch.keys():
        assert config["prioritized_replay"] is False
        prio_weights = torch.tensor(
            [1.0] * len(train_batch[SampleBatch.REWARDS])
        ).to(policy.device)
    else:
        prio_weights = train_batch[PRIO_WEIGHTS]

    policy.q_loss = QLoss(
        q_t_selected,
        q_logits_t_selected,
        q_tp1_best,
        q_probs_tp1_best,
        prio_weights,
        train_batch[SampleBatch.REWARDS],
        train_batch[SampleBatch.DONES].float(),
        config["gamma"],
        config["n_step"],
        config["num_atoms"],
        config["v_min"],
        config["v_max"],
    )

    return policy.q_loss.loss
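The two tensors stored by the additions exist purely for logging. A hypothetical stats helper (function name and metric keys are invented here, not part of the original code) that turns them into scalar metrics could look like:

import torch

def q_value_stats(policy):
    # Reads the tensors saved by build_q_losses_wt_additional_logs and
    # reduces them to scalars suitable for a metrics dict.
    stats = {}
    if hasattr(policy, "last_q_t"):
        stats["mean_q_t"] = torch.mean(policy.last_q_t).item()
        stats["max_q_t"] = torch.max(policy.last_q_t).item()
    if hasattr(policy, "last_target_q_t"):
        stats["mean_target_q_tp1"] = torch.mean(policy.last_target_q_t).item()
    return stats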
Example #4
def build_drq_q_losses(policy, model, _, train_batch):
    """ use input augmentation on Q target and Q updates
    """
    config = policy.config
    aug_num = config["aug_num"]

    # target q network evaluation
    q_tp1_best_avg = 0
    orig_nxt_obs = train_batch[SampleBatch.NEXT_OBS].clone()
    for _ in range(aug_num):
        # augment obs
        aug_nxt_obs = model.trans(orig_nxt_obs.permute(0, 3, 1,
                                                       2).float()).permute(
                                                           0, 2, 3, 1)

        q_tp1 = compute_q_values(policy,
                                 policy.target_q_model,
                                 aug_nxt_obs,
                                 explore=False,
                                 is_training=True)

        # compute estimate of best possible value starting from state at t + 1
        if config["double_q"]:
            q_tp1_using_online_net = compute_q_values(policy,
                                                      policy.q_model,
                                                      aug_nxt_obs,
                                                      explore=False,
                                                      is_training=True)
            q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net,
                                                       1)
            q_tp1_best_one_hot_selection = F.one_hot(
                q_tp1_best_using_online_net, policy.action_space.n)
            q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
        else:
            q_tp1_best_one_hot_selection = F.one_hot(torch.argmax(q_tp1, 1),
                                                     policy.action_space.n)
            q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)

        # accumulate target Q with augmented next obs
        q_tp1_best_avg += q_tp1_best
    q_tp1_best_avg /= aug_num

    # q network evaluation
    aug_loss = 0
    orig_cur_obs = train_batch[SampleBatch.CUR_OBS].clone()
    for _ in range(aug_num):
        # augment obs
        aug_cur_obs = model.trans(orig_cur_obs.permute(0, 3, 1,
                                                       2).float()).permute(
                                                           0, 2, 3, 1)

        q_t = compute_q_values(policy,
                               policy.q_model,
                               aug_cur_obs,
                               explore=False,
                               is_training=True)

        # q scores for actions which we know were selected in the given state.
        one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS],
                                      policy.action_space.n)
        q_t_selected = torch.sum(q_t * one_hot_selection, 1)

        # Bellman error
        policy.q_loss = QLoss(q_t_selected, q_tp1_best_avg,
                              train_batch[PRIO_WEIGHTS],
                              train_batch[SampleBatch.REWARDS],
                              train_batch[SampleBatch.DONES].float(),
                              config["gamma"], config["n_step"],
                              config["num_atoms"], config["v_min"],
                              config["v_max"])
        # accumulate loss with augmented obs
        aug_loss += policy.q_loss.loss
    return aug_loss / aug_num
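model.trans above is the image augmentation applied to the NHWC observations after they are permuted to NCHW. As an illustrative stand-in (not the repository's actual implementation), a DrQ-style pad-and-random-crop shift could be written as:

import torch
import torch.nn as nn
import torch.nn.functional as F

class RandomShiftAug(nn.Module):
    # DrQ-style random shift: replicate-pad the image, then randomly crop
    # it back to its original size. The pad of 4 pixels is an assumption
    # taken from the DrQ paper, not from the code above.
    def __init__(self, pad=4):
        super().__init__()
        self.pad = pad

    def forward(self, x):  # x: [N, C, H, W], float
        n, _, h, w = x.shape
        x = F.pad(x, [self.pad] * 4, mode="replicate")
        top = torch.randint(0, 2 * self.pad + 1, (n,))
        left = torch.randint(0, 2 * self.pad + 1, (n,))
        return torch.stack(
            [x[i, :, top[i]:top[i] + h, left[i]:left[i] + w] for i in range(n)])

Because each call draws fresh crop offsets, evaluating the Q-network aug_num times on the same batch yields aug_num differently shifted views, which is exactly what the averaging loops above exploit.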