Example #1
    def test_dict_properties_of_sample_batches(self):
        base_dict = {
            "a": np.array([1, 2, 3]),
            "b": np.array([[0.1, 0.2], [0.3, 0.4]]),
            "c": True,
        }
        batch = SampleBatch(base_dict)
        try:
            SampleBatch(base_dict)
        except AssertionError:
            pass  # expected
        keys_ = list(base_dict.keys())
        values_ = list(base_dict.values())
        items_ = list(base_dict.items())
        assert list(batch.keys()) == keys_
        assert list(batch.values()) == values_
        assert list(batch.items()) == items_

        # Add an item and check whether it's in the "added" list.
        batch["d"] = np.array(1)
        assert batch.added_keys == {"d"}, batch.added_keys
        # Access two keys and check whether they are in the "accessed" list.
        print(batch["a"], batch["b"])
        assert batch.accessed_keys == {"a", "b"}, batch.accessed_keys
        # Delete a key and check whether it's in the "deleted" list.
        del batch["c"]
        assert batch.deleted_keys == {"c"}, batch.deleted_keys
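
A minimal usage sketch grounded in the dict-style API exercised above: a SampleBatch built from named columns behaves like a plain dict and reports its row count via .count (the same attribute the later examples rely on). The column names and values below are made up for illustration.

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({
    SampleBatch.OBS: np.array([[0.0, 1.0], [1.0, 0.0]]),
    SampleBatch.ACTIONS: np.array([0, 1]),
    SampleBatch.REWARDS: np.array([1.0, -1.0]),
})
print(batch.count)                  # 2 rows, inferred from the column lengths
print(list(batch.keys()))           # the three column names above
print(batch[SampleBatch.REWARDS])   # plain dict-style access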
Example #2
def _log_action_prob_pytorch(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """
    Log the mean of the probability of each actions, over the training batch.
    Also log the probabilities of the last step.
    Works only with the torch framework
    """
    # TODO make it work for other spaces than Discrete
    # TODO make it work for nested spaces
    # TODO add entropy
    to_log = {}
    if isinstance(policy.action_space, gym.spaces.Discrete):
        # print("train_batch", train_batch)
        # Nested discrete spaces are not supported.
        assert train_batch["action_dist_inputs"].dim() == 2

        action_dist_inputs_avg = train_batch["action_dist_inputs"].mean(axis=0)
        action_dist_inputs_single = train_batch["action_dist_inputs"][-1, :]

        for action_i in range(policy.action_space.n):
            to_log[f"act_dist_inputs_avg_{action_i}"] = action_dist_inputs_avg[action_i]
            to_log[f"act_dist_inputs_single_{action_i}"] = action_dist_inputs_single[action_i]

        assert train_batch["action_prob"].dim() == 1
        to_log[f"action_prob_avg"] = train_batch["action_prob"].mean(axis=0)
        to_log[f"action_prob_single"] = train_batch["action_prob"][-1]

        if "q_values" in train_batch.keys():
            assert train_batch["q_values"].dim() == 2
            q_values_avg = train_batch["q_values"].mean(axis=0)
            q_values_single = train_batch["q_values"][-1, :]

            for action_i in range(policy.action_space.n):
                to_log[f"q_values_avg_{action_i}"] = q_values_avg[action_i]
                to_log[f"q_values_single_{action_i}"] = q_values_single[action_i]


    else:
        raise NotImplementedError()
    return to_log
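
The helper above only reads policy.action_space, train_batch["action_dist_inputs"] and train_batch["action_prob"], so it can be exercised outside a training run. The snippet below is a sketch only: a stand-in namespace plays the role of the policy and a plain dict of torch tensors plays the role of the SampleBatch; all numbers are made up.

import types
import gym
import torch

fake_policy = types.SimpleNamespace(action_space=gym.spaces.Discrete(2))
fake_batch = {
    "action_dist_inputs": torch.tensor([[0.2, 0.8], [0.6, 0.4]]),
    "action_prob": torch.tensor([0.8, 0.6]),
}
# Yields act_dist_inputs_avg_0/1, act_dist_inputs_single_0/1,
# action_prob_avg and action_prob_single as torch scalars.
print(_log_action_prob_pytorch(fake_policy, fake_batch))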
Example #3
    def train_q(self, batch: SampleBatch) -> TensorType:
        """Trains self.q_model using Q-Reg loss on given batch.

        Args:
            batch: A SampleBatch of episodes to train on

        Returns:
            A list of losses for each training iteration
        """
        losses = []
        obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
        actions = torch.tensor(batch[SampleBatch.ACTIONS], device=self.device)
        ps = torch.zeros([batch.count], device=self.device)
        returns = torch.zeros([batch.count], device=self.device)
        discounts = torch.zeros([batch.count], device=self.device)

        # Necessary if the policy uses a recurrent/attention model.
        num_state_inputs = 0
        for k in batch.keys():
            if k.startswith("state_in_"):
                num_state_inputs += 1
        state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]

        # Get rewards and old/new action log-probabilities.
        rewards = batch[SampleBatch.REWARDS]
        old_log_prob = torch.tensor(batch[SampleBatch.ACTION_LOGP])
        new_log_prob = (self.policy.compute_log_likelihoods(
            actions=batch[SampleBatch.ACTIONS],
            obs_batch=batch[SampleBatch.OBS],
            state_batches=[batch[k] for k in state_keys],
            prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
            actions_normalized=False,
        ).detach().cpu())
        prob_ratio = torch.exp(new_log_prob - old_log_prob)

        eps_begin = 0
        for episode in batch.split_by_episode():
            eps_end = eps_begin + episode.count

            # Calculate cumulative importance ratios and importance-weighted returns.

            for t in range(episode.count):
                discounts[eps_begin + t] = self.gamma**t
                if t == 0:
                    pt_prev = 1.0
                else:
                    pt_prev = ps[eps_begin + t - 1]
                ps[eps_begin + t] = pt_prev * prob_ratio[eps_begin + t]

                # O(n^3)
                # ret = 0
                # for t_prime in range(t, episode.count):
                #     gamma = self.gamma ** (t_prime - t)
                #     rho_t_1_t_prime = 1.0
                #     for k in range(t + 1, min(t_prime + 1, episode.count)):
                #         rho_t_1_t_prime = rho_t_1_t_prime * prob_ratio[eps_begin + k]
                #     r = rewards[eps_begin + t_prime]
                #     ret += gamma * rho_t_1_t_prime * r

                # O(n^2)
                ret = 0
                rho = 1
                for t_ in reversed(range(t, episode.count)):
                    ret = rewards[eps_begin + t_] + self.gamma * rho * ret
                    rho = prob_ratio[eps_begin + t_]

                returns[eps_begin + t] = ret

            # Update before next episode
            eps_begin = eps_end

        indices = np.arange(batch.count)
        for _ in range(self.n_iters):
            minibatch_losses = []
            np.random.shuffle(indices)
            for idx in range(0, batch.count, self.batch_size):
                idxs = indices[idx:idx + self.batch_size]
                q_values, _ = self.q_model({"obs": obs[idxs]}, [], None)
                q_acts = torch.gather(q_values, -1,
                                      actions[idxs].unsqueeze(-1)).squeeze(-1)
                loss = discounts[idxs] * ps[idxs] * (returns[idxs] - q_acts)**2
                loss = torch.mean(loss)
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad.clip_grad_norm_(self.q_model.variables(),
                                                   self.clip_grad_norm)
                self.optimizer.step()
                minibatch_losses.append(loss.item())
            iter_loss = sum(minibatch_losses) / len(minibatch_losses)
            losses.append(iter_loss)
            if iter_loss < self.delta:
                break
        return losses
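
The target built in the inner loop above is the importance-weighted discounted return G_t = sum over t' >= t of gamma^(t' - t) * (prod of prob_ratio over k = t+1..t') * r_t'. The sketch below (plain NumPy, made-up episode data, no RLlib) checks that the backward recursion from the "O(n^2)" block matches the direct double loop kept as a comment in the code.

import numpy as np

gamma = 0.9
rewards = np.array([1.0, 0.5, 2.0])
prob_ratio = np.array([1.1, 0.8, 1.2])  # rho_t = pi_new(a_t|s_t) / pi_old(a_t|s_t)
T = len(rewards)

# Direct computation (mirrors the commented-out O(n^3) reference).
direct = np.zeros(T)
for t in range(T):
    for t_prime in range(t, T):
        rho = np.prod(prob_ratio[t + 1:t_prime + 1])  # empty product -> 1.0
        direct[t] += gamma ** (t_prime - t) * rho * rewards[t_prime]

# Backward recursion (mirrors the O(n^2) block, one pass per start index t).
backward = np.zeros(T)
for t in range(T):
    ret, rho = 0.0, 1.0
    for t_ in reversed(range(t, T)):
        ret = rewards[t_] + gamma * rho * ret
        rho = prob_ratio[t_]
    backward[t] = ret

assert np.allclose(direct, backward)
print(backward)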
Example #4
def build_q_losses_wt_additional_logs(
    policy: Policy, model, _, train_batch: SampleBatch
) -> TensorType:
    """
    Copy of build_q_losses with additional values saved into the policy
    Made only 2 changes, see in comments.
    """
    config = policy.config
    # Q-network evaluation.
    q_t, q_logits_t, q_probs_t = compute_q_values(
        policy,
        policy.q_model,
        train_batch[SampleBatch.CUR_OBS],
        explore=False,
        is_training=True,
    )

    # Addition 1 out of 2
    policy.last_q_t = q_t.clone()

    # Target Q-network evaluation.
    q_tp1, q_logits_tp1, q_probs_tp1 = compute_q_values(
        policy,
        policy.target_q_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=True,
    )

    # Addition 2 out of 2
    policy.last_target_q_t = q_tp1.clone()

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(
        train_batch[SampleBatch.ACTIONS], policy.action_space.n
    )
    q_t_selected = torch.sum(
        torch.where(
            q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=policy.device)
        )
        * one_hot_selection,
        1,
    )
    q_logits_t_selected = torch.sum(
        q_logits_t * torch.unsqueeze(one_hot_selection, -1), 1
    )

    # Compute an estimate of the best possible value starting from state t + 1.
    if config["double_q"]:
        (
            q_tp1_using_online_net,
            q_logits_tp1_using_online_net,
            q_dist_tp1_using_online_net,
        ) = compute_q_values(
            policy,
            policy.q_model,
            train_batch[SampleBatch.NEXT_OBS],
            explore=False,
            is_training=True,
        )
        q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
        q_tp1_best_one_hot_selection = F.one_hot(
            q_tp1_best_using_online_net, policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN,
                q_tp1,
                torch.tensor(0.0, device=policy.device),
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )
    else:
        q_tp1_best_one_hot_selection = F.one_hot(
            torch.argmax(q_tp1, 1), policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN,
                q_tp1,
                torch.tensor(0.0, device=policy.device),
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )

    if PRIO_WEIGHTS not in train_batch.keys():
        assert config["prioritized_replay"] is False
        prio_weights = torch.tensor(
            [1.0] * len(train_batch[SampleBatch.REWARDS])
        ).to(policy.device)
    else:
        prio_weights = train_batch[PRIO_WEIGHTS]

    policy.q_loss = QLoss(
        q_t_selected,
        q_logits_t_selected,
        q_tp1_best,
        q_probs_tp1_best,
        prio_weights,
        train_batch[SampleBatch.REWARDS],
        train_batch[SampleBatch.DONES].float(),
        config["gamma"],
        config["n_step"],
        config["num_atoms"],
        config["v_min"],
        config["v_max"],
    )

    return policy.q_loss.loss
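
A toy illustration of the double-Q branch above (plain PyTorch, made-up numbers): the online network chooses the argmax action at t + 1, while the target network's Q-value for that action forms the bootstrap value. The FLOAT_MIN masking from the original code is omitted here for brevity.

import torch
import torch.nn.functional as F

q_tp1_online = torch.tensor([[1.0, 3.0], [2.0, 0.5]])  # online net, shape [batch, num_actions]
q_tp1_target = torch.tensor([[0.8, 2.5], [1.9, 0.7]])  # target net, same shape

best_actions = torch.argmax(q_tp1_online, 1)           # tensor([1, 0])
one_hot = F.one_hot(best_actions, num_classes=2).float()
q_tp1_best = torch.sum(q_tp1_target * one_hot, 1)      # tensor([2.5000, 1.9000])
print(q_tp1_best)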
Example #5
    def train_q(self, batch: SampleBatch) -> TensorType:
        """Trains self.q_model using FQE loss on given batch.

        Args:
            batch: A SampleBatch of episodes to train on

        Returns:
            A list of losses for each training iteration
        """
        losses = []
        for _ in range(self.n_iters):
            minibatch_losses = []
            batch.shuffle()
            for idx in range(0, batch.count, self.batch_size):
                minibatch = batch[idx : idx + self.batch_size]
                obs = torch.tensor(minibatch[SampleBatch.OBS], device=self.device)
                actions = torch.tensor(
                    minibatch[SampleBatch.ACTIONS], device=self.device
                )
                rewards = torch.tensor(
                    minibatch[SampleBatch.REWARDS], device=self.device
                )
                next_obs = torch.tensor(
                    minibatch[SampleBatch.NEXT_OBS], device=self.device
                )
                dones = torch.tensor(minibatch[SampleBatch.DONES], device=self.device)

                # Necessary if the policy uses a recurrent/attention model.
                num_state_inputs = 0
                for k in batch.keys():
                    if k.startswith("state_in_"):
                        num_state_inputs += 1
                state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]

                # Compute action_probs for next_obs as in FQE
                all_actions = torch.zeros([minibatch.count, self.policy.action_space.n])
                all_actions[:] = torch.arange(self.policy.action_space.n)
                next_action_prob = self.policy.compute_log_likelihoods(
                    actions=all_actions.T,
                    obs_batch=next_obs,
                    state_batches=[minibatch[k] for k in state_keys],
                    prev_action_batch=minibatch[SampleBatch.ACTIONS],
                    prev_reward_batch=minibatch[SampleBatch.REWARDS],
                    actions_normalized=False,
                )
                next_action_prob = (
                    torch.exp(next_action_prob.T).to(self.device).detach()
                )

                q_values, _ = self.q_model({"obs": obs}, [], None)
                q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
                with torch.no_grad():
                    next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)
                next_v = torch.sum(next_q_values * next_action_prob, axis=-1)
                targets = rewards + ~dones * self.gamma * next_v
                loss = (targets - q_acts) ** 2
                loss = torch.mean(loss)
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad.clip_grad_norm_(
                    self.q_model.variables(), self.clip_grad_norm
                )
                self.optimizer.step()
                minibatch_losses.append(loss.item())
            iter_loss = sum(minibatch_losses) / len(minibatch_losses)
            losses.append(iter_loss)
            if iter_loss < self.delta:
                break
            self.update_target()
        return losses
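
A minimal sketch of the FQE target used above (plain PyTorch, made-up numbers): V(s') is the action-probability-weighted sum of the next-state Q-values, and ~dones zeroes the bootstrap term at terminal steps so the target reduces to the reward there.

import torch

gamma = 0.99
rewards = torch.tensor([1.0, 0.0, 2.0])
dones = torch.tensor([False, False, True])
next_q_values = torch.tensor([[1.0, 3.0], [2.0, 2.0], [0.0, 5.0]])
next_action_prob = torch.tensor([[0.25, 0.75], [0.5, 0.5], [0.9, 0.1]])

next_v = torch.sum(next_q_values * next_action_prob, dim=-1)  # tensor([2.5000, 2.0000, 0.5000])
targets = rewards + ~dones * gamma * next_v                   # terminal row keeps only its reward
print(targets)  # tensor([3.4750, 1.9800, 2.0000])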
Example #6
def pad_batch_to_sequences_of_same_size(
    batch: SampleBatch,
    max_seq_len: int,
    shuffle: bool = False,
    batch_divisibility_req: int = 1,
    feature_keys: Optional[List[str]] = None,
    _use_trajectory_view_api: bool = False,
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure the divisibility requirement is
    met, then pads the batch ([B, ...]) into same-size chunks ([B, ...])
    without adding a time dimension (yet).
    Padding depends on the episodes found in `batch` and on `max_seq_len`.

    Args:
        batch (SampleBatch): The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len (int): The max. sequence length to use for chopping.
        shuffle (bool): Whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req (int): The int by which the batch dimension
            must be divisible.
        feature_keys (Optional[List[str]]): An optional list of keys to apply
            sequence-chopping to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
        _use_trajectory_view_api (bool): Whether we are using the Trajectory
            View API to collect and process samples.
    """
    if _use_trajectory_view_api:
        if batch.time_major is not None:
            batch["seq_lens"] = torch.tensor(batch.seq_lens)
            t = 0 if batch.time_major else 1
            for col in batch.data.keys():
                # Cut time-dim from states.
                if "state_" in col[:6]:
                    batch[col] = batch[col][t]
                # Flatten all other data.
                else:
                    # Cut time-dim at `max_seq_len`.
                    if batch.time_major:
                        batch[col] = batch[col][:batch.max_seq_len]
                    batch[col] = batch[col].reshape((-1, ) +
                                                    batch[col].shape[2:])
        return

    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)
    else:
        meets_divisibility_reqs = True

    # RNN-case.
    if "state_in_0" in batch or "state_out_0" in batch:
        dynamic_max = True
    # Multi-agent case.
    elif not meets_divisibility_reqs:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
    # Simple case: not RNN nor do we need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k in batch.keys():
        if "state_in_" in k:
            state_keys.append(k)
        elif not feature_keys and "state_out_" not in k and k != "infos":
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = \
        chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX],
            [batch[k] for k in feature_keys_],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max,
            shuffle=shuffle)
    for i, k in enumerate(feature_keys_):
        batch[k] = feature_sequences[i]
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch["seq_lens"] = seq_lens

    if log_once("rnn_ma_feed_dict"):
        logger.info("Padded input for RNN:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))
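
To make the chopping and padding concrete, here is a toy, self-contained illustration (plain NumPy, no RLlib). It only looks at episode IDs, whereas chop_into_sequences also considers unroll IDs and agent indices: rows are split into chunks of at most max_seq_len that never cross an episode boundary, each chunk is zero-padded to max_seq_len, and seq_lens records the true lengths.

import numpy as np

max_seq_len = 3
eps_ids = np.array([0, 0, 0, 0, 1, 1])        # two episodes, of length 4 and 2
features = np.arange(1, 7, dtype=np.float32)  # a single feature column, shape [B]

padded, seq_lens = [], []
start = 0
for end in range(1, len(eps_ids) + 1):
    boundary = end == len(eps_ids) or eps_ids[end] != eps_ids[start]
    if end - start == max_seq_len or boundary:
        chunk = features[start:end]
        seq_lens.append(len(chunk))
        padded.append(np.pad(chunk, (0, max_seq_len - len(chunk))))
        start = end

print(np.stack(padded))  # [[1. 2. 3.] [4. 0. 0.] [5. 6. 0.]]
print(seq_lens)          # [3, 1, 2]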