def test_dict_properties_of_sample_batches(self):
    base_dict = {
        "a": np.array([1, 2, 3]),
        "b": np.array([[0.1, 0.2], [0.3, 0.4]]),
        "c": True,
    }
    batch = SampleBatch(base_dict)
    try:
        SampleBatch(base_dict)
    except AssertionError:
        pass  # expected

    keys_ = list(base_dict.keys())
    values_ = list(base_dict.values())
    items_ = list(base_dict.items())
    assert list(batch.keys()) == keys_
    assert list(batch.values()) == values_
    assert list(batch.items()) == items_

    # Add an item and check whether it's in the "added" list.
    batch["d"] = np.array(1)
    assert batch.added_keys == {"d"}, batch.added_keys

    # Access two keys and check whether they are in the "accessed" list.
    print(batch["a"], batch["b"])
    assert batch.accessed_keys == {"a", "b"}, batch.accessed_keys

    # Delete a key and check whether it's in the "deleted" list.
    del batch["c"]
    assert batch.deleted_keys == {"c"}, batch.deleted_keys
def _log_action_prob_pytorch(
    policy: Policy, train_batch: SampleBatch
) -> Dict[str, TensorType]:
    """Log the mean probability of each action over the training batch.

    Also logs the probabilities of the last step.
    Works only with the torch framework.
    """
    # TODO: make it work for spaces other than Discrete
    # TODO: make it work for nested spaces
    # TODO: add entropy
    to_log = {}
    if isinstance(policy.action_space, gym.spaces.Discrete):
        # Nested discrete spaces are not supported.
        assert train_batch["action_dist_inputs"].dim() == 2

        action_dist_inputs_avg = train_batch["action_dist_inputs"].mean(axis=0)
        action_dist_inputs_single = train_batch["action_dist_inputs"][-1, :]

        for action_i in range(policy.action_space.n):
            to_log[f"act_dist_inputs_avg_{action_i}"] = action_dist_inputs_avg[action_i]
            to_log[f"act_dist_inputs_single_{action_i}"] = action_dist_inputs_single[action_i]

        assert train_batch["action_prob"].dim() == 1
        to_log["action_prob_avg"] = train_batch["action_prob"].mean(axis=0)
        to_log["action_prob_single"] = train_batch["action_prob"][-1]

        if "q_values" in train_batch.keys():
            assert train_batch["q_values"].dim() == 2
            q_values_avg = train_batch["q_values"].mean(axis=0)
            q_values_single = train_batch["q_values"][-1, :]

            for action_i in range(policy.action_space.n):
                to_log[f"q_values_avg_{action_i}"] = q_values_avg[action_i]
                to_log[f"q_values_single_{action_i}"] = q_values_single[action_i]
    else:
        raise NotImplementedError()
    return to_log
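# The following smoke test is NOT part of the original code: it is a minimal,
# hedged sketch of how _log_action_prob_pytorch could be exercised on toy data.
# `mock_policy` and `mock_batch` are made-up stand-ins that only provide the
# fields the function actually reads (a Discrete action space, logits, and
# per-step action probabilities); a real call would pass a Policy/SampleBatch.
import types

import gym
import torch

mock_policy = types.SimpleNamespace(action_space=gym.spaces.Discrete(2))
mock_batch = {
    # [batch_size=3, n_actions=2] action distribution inputs (logits).
    "action_dist_inputs": torch.tensor([[0.1, 0.9], [0.5, 0.5], [0.8, 0.2]]),
    # [batch_size=3] probabilities of the actions actually taken.
    "action_prob": torch.tensor([0.9, 0.5, 0.8]),
}
print(_log_action_prob_pytorch(mock_policy, mock_batch))
# Expected keys: act_dist_inputs_avg_0/1, act_dist_inputs_single_0/1,
# action_prob_avg, action_prob_single.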
def train_q(self, batch: SampleBatch) -> TensorType:
    """Trains self.q_model using Q-Reg loss on given batch.

    Args:
        batch: A SampleBatch of episodes to train on

    Returns:
        A list of losses for each training iteration
    """
    losses = []
    obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
    actions = torch.tensor(batch[SampleBatch.ACTIONS], device=self.device)
    ps = torch.zeros([batch.count], device=self.device)
    returns = torch.zeros([batch.count], device=self.device)
    discounts = torch.zeros([batch.count], device=self.device)

    # Necessary if the policy uses a recurrent/attention model.
    num_state_inputs = 0
    for k in batch.keys():
        if k.startswith("state_in_"):
            num_state_inputs += 1
    state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]

    # Get rewards, old_prob, new_prob.
    rewards = batch[SampleBatch.REWARDS]
    old_log_prob = torch.tensor(batch[SampleBatch.ACTION_LOGP])
    new_log_prob = (
        self.policy.compute_log_likelihoods(
            actions=batch[SampleBatch.ACTIONS],
            obs_batch=batch[SampleBatch.OBS],
            state_batches=[batch[k] for k in state_keys],
            prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
            actions_normalized=False,
        )
        .detach()
        .cpu()
    )
    prob_ratio = torch.exp(new_log_prob - old_log_prob)

    eps_begin = 0
    for episode in batch.split_by_episode():
        eps_end = eps_begin + episode.count

        # Calculate importance ratios and returns.
        for t in range(episode.count):
            discounts[eps_begin + t] = self.gamma ** t
            if t == 0:
                pt_prev = 1.0
            else:
                pt_prev = ps[eps_begin + t - 1]
            ps[eps_begin + t] = pt_prev * prob_ratio[eps_begin + t]

            # O(n^3)
            # ret = 0
            # for t_prime in range(t, episode.count):
            #     gamma = self.gamma ** (t_prime - t)
            #     rho_t_1_t_prime = 1.0
            #     for k in range(t + 1, min(t_prime + 1, episode.count)):
            #         rho_t_1_t_prime = rho_t_1_t_prime * prob_ratio[eps_begin + k]
            #     r = rewards[eps_begin + t_prime]
            #     ret += gamma * rho_t_1_t_prime * r

            # O(n^2)
            ret = 0
            rho = 1
            for t_ in reversed(range(t, episode.count)):
                ret = rewards[eps_begin + t_] + self.gamma * rho * ret
                rho = prob_ratio[eps_begin + t_]
            returns[eps_begin + t] = ret

        # Update before next episode.
        eps_begin = eps_end

    indices = np.arange(batch.count)
    for _ in range(self.n_iters):
        minibatch_losses = []
        np.random.shuffle(indices)
        for idx in range(0, batch.count, self.batch_size):
            idxs = indices[idx:idx + self.batch_size]
            q_values, _ = self.q_model({"obs": obs[idxs]}, [], None)
            q_acts = torch.gather(
                q_values, -1, actions[idxs].unsqueeze(-1)
            ).squeeze(-1)
            loss = discounts[idxs] * ps[idxs] * (returns[idxs] - q_acts) ** 2
            loss = torch.mean(loss)
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad.clip_grad_norm_(
                self.q_model.variables(), self.clip_grad_norm
            )
            self.optimizer.step()
            minibatch_losses.append(loss.item())
        iter_loss = sum(minibatch_losses) / len(minibatch_losses)
        losses.append(iter_loss)
        if iter_loss < self.delta:
            break
    return losses
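# The following is NOT part of the original code: a small, self-contained
# numpy check illustrating why the O(n^2) backward recursion above computes
# the same importance-weighted return as the naive O(n^3) definition kept in
# the comments. All names below (rng, T, rewards, prob_ratio) are local to
# this sketch and only mimic the per-episode slices used in train_q.
import numpy as np

rng = np.random.default_rng(0)
gamma = 0.99
T = 5
rewards = rng.normal(size=T)
prob_ratio = rng.uniform(0.5, 1.5, size=T)  # pi_new / pi_old per step

for t in range(T):
    # Naive definition:
    # sum_{t'=t}^{T-1} gamma^(t'-t) * prod_{k=t+1}^{t'} rho_k * r_{t'}
    naive = 0.0
    for t_prime in range(t, T):
        weight = np.prod(prob_ratio[t + 1 : t_prime + 1])
        naive += gamma ** (t_prime - t) * weight * rewards[t_prime]

    # Backward recursion, as in train_q above.
    ret, rho = 0.0, 1.0
    for t_ in reversed(range(t, T)):
        ret = rewards[t_] + gamma * rho * ret
        rho = prob_ratio[t_]

    assert np.isclose(naive, ret)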
def build_q_losses_wt_additional_logs(
    policy: Policy, model, _, train_batch: SampleBatch
) -> TensorType:
    """Copy of build_q_losses with additional values saved into the policy.

    Only 2 changes were made; see the comments below.
    """
    config = policy.config
    # Q-network evaluation.
    q_t, q_logits_t, q_probs_t = compute_q_values(
        policy,
        policy.q_model,
        train_batch[SampleBatch.CUR_OBS],
        explore=False,
        is_training=True,
    )
    # Addition 1 out of 2.
    policy.last_q_t = q_t.clone()

    # Target Q-network evaluation.
    q_tp1, q_logits_tp1, q_probs_tp1 = compute_q_values(
        policy,
        policy.target_q_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=True,
    )
    # Addition 2 out of 2.
    policy.last_target_q_t = q_tp1.clone()

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(
        train_batch[SampleBatch.ACTIONS], policy.action_space.n
    )
    q_t_selected = torch.sum(
        torch.where(q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=policy.device))
        * one_hot_selection,
        1,
    )
    q_logits_t_selected = torch.sum(
        q_logits_t * torch.unsqueeze(one_hot_selection, -1), 1
    )

    # Compute estimate of the best possible value starting from state at t + 1.
    if config["double_q"]:
        (
            q_tp1_using_online_net,
            q_logits_tp1_using_online_net,
            q_dist_tp1_using_online_net,
        ) = compute_q_values(
            policy,
            policy.q_model,
            train_batch[SampleBatch.NEXT_OBS],
            explore=False,
            is_training=True,
        )
        q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
        q_tp1_best_one_hot_selection = F.one_hot(
            q_tp1_best_using_online_net, policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=policy.device)
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )
    else:
        q_tp1_best_one_hot_selection = F.one_hot(
            torch.argmax(q_tp1, 1), policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=policy.device)
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )

    if PRIO_WEIGHTS not in train_batch.keys():
        assert config["prioritized_replay"] is False
        prio_weights = torch.tensor(
            [1.0] * len(train_batch[SampleBatch.REWARDS])
        ).to(policy.device)
    else:
        prio_weights = train_batch[PRIO_WEIGHTS]

    policy.q_loss = QLoss(
        q_t_selected,
        q_logits_t_selected,
        q_tp1_best,
        q_probs_tp1_best,
        prio_weights,
        train_batch[SampleBatch.REWARDS],
        train_batch[SampleBatch.DONES].float(),
        config["gamma"],
        config["n_step"],
        config["num_atoms"],
        config["v_min"],
        config["v_max"],
    )
    return policy.q_loss.loss
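# The following helper is NOT part of the original code: it is a hypothetical
# sketch of how the values stashed on the policy by the two additions above
# (policy.last_q_t / policy.last_target_q_t) could be consumed by a custom
# stats function with the usual (policy, train_batch) signature.
def log_last_q_values(policy, train_batch):
    to_log = {}
    if hasattr(policy, "last_q_t"):
        # Summary statistics of the online Q-values from the last loss call.
        to_log["last_q_t_mean"] = policy.last_q_t.mean()
        to_log["last_q_t_max"] = policy.last_q_t.max()
    if hasattr(policy, "last_target_q_t"):
        # Same for the target network's Q-values.
        to_log["last_target_q_t_mean"] = policy.last_target_q_t.mean()
    return to_log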
def train_q(self, batch: SampleBatch) -> TensorType:
    """Trains self.q_model using FQE loss on given batch.

    Args:
        batch: A SampleBatch of episodes to train on

    Returns:
        A list of losses for each training iteration
    """
    losses = []
    for _ in range(self.n_iters):
        minibatch_losses = []
        batch.shuffle()
        for idx in range(0, batch.count, self.batch_size):
            minibatch = batch[idx : idx + self.batch_size]
            obs = torch.tensor(minibatch[SampleBatch.OBS], device=self.device)
            actions = torch.tensor(
                minibatch[SampleBatch.ACTIONS], device=self.device
            )
            rewards = torch.tensor(
                minibatch[SampleBatch.REWARDS], device=self.device
            )
            next_obs = torch.tensor(
                minibatch[SampleBatch.NEXT_OBS], device=self.device
            )
            dones = torch.tensor(minibatch[SampleBatch.DONES], device=self.device)

            # Necessary if the policy uses a recurrent/attention model.
            num_state_inputs = 0
            for k in batch.keys():
                if k.startswith("state_in_"):
                    num_state_inputs += 1
            state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]

            # Compute action_probs for next_obs as in FQE.
            all_actions = torch.zeros([minibatch.count, self.policy.action_space.n])
            all_actions[:] = torch.arange(self.policy.action_space.n)
            next_action_prob = self.policy.compute_log_likelihoods(
                actions=all_actions.T,
                obs_batch=next_obs,
                state_batches=[minibatch[k] for k in state_keys],
                prev_action_batch=minibatch[SampleBatch.ACTIONS],
                prev_reward_batch=minibatch[SampleBatch.REWARDS],
                actions_normalized=False,
            )
            next_action_prob = (
                torch.exp(next_action_prob.T).to(self.device).detach()
            )

            q_values, _ = self.q_model({"obs": obs}, [], None)
            q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
            with torch.no_grad():
                next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)
            next_v = torch.sum(next_q_values * next_action_prob, axis=-1)
            targets = rewards + ~dones * self.gamma * next_v
            loss = (targets - q_acts) ** 2
            loss = torch.mean(loss)
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad.clip_grad_norm_(
                self.q_model.variables(), self.clip_grad_norm
            )
            self.optimizer.step()
            minibatch_losses.append(loss.item())
        iter_loss = sum(minibatch_losses) / len(minibatch_losses)
        losses.append(iter_loss)
        if iter_loss < self.delta:
            break
        self.update_target()
    return losses
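# The following is NOT part of the original code: a tiny numeric sketch of the
# FQE bootstrap target used above,
#   target = r + gamma * (1 - done) * sum_a pi(a | s') * Q_target(s', a).
# All tensor values below are made-up toy numbers for illustration only.
import torch

gamma = 0.99
rewards = torch.tensor([1.0, 0.0])
dones = torch.tensor([False, True])
next_q_values = torch.tensor([[2.0, 4.0],       # Q_target(s', a) for 2 actions
                              [1.0, 3.0]])
next_action_prob = torch.tensor([[0.25, 0.75],  # pi(a | s') under the eval policy
                                 [0.50, 0.50]])

next_v = torch.sum(next_q_values * next_action_prob, dim=-1)  # [3.5, 2.0]
targets = rewards + ~dones * gamma * next_v                   # done => no bootstrap
print(targets)  # tensor([4.4650, 0.0000])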
def pad_batch_to_sequences_of_same_size(
    batch: SampleBatch,
    max_seq_len: int,
    shuffle: bool = False,
    batch_divisibility_req: int = 1,
    feature_keys: Optional[List[str]] = None,
    _use_trajectory_view_api: bool = False,
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure divisibility requirement is met,
    then pads the batch ([B, ...]) into same-size chunks ([B, ...]) w/o
    adding a time dimension (yet). Padding depends on episodes found in the
    batch and `max_seq_len`.

    Args:
        batch (SampleBatch): The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len (int): The max. sequence length to use for chopping.
        shuffle (bool): Whether to shuffle batch sequences. Shuffle may be
            done in-place. This only makes sense if you're further applying
            minibatch SGD after getting the outputs.
        batch_divisibility_req (int): The int by which the batch dimension
            must be dividable.
        feature_keys (Optional[List[str]]): An optional list of keys to apply
            sequence-chopping to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
        _use_trajectory_view_api (bool): Whether we are using the Trajectory
            View API to collect and process samples.
    """
    if _use_trajectory_view_api:
        if batch.time_major is not None:
            batch["seq_lens"] = torch.tensor(batch.seq_lens)
            t = 0 if batch.time_major else 1
            for col in batch.data.keys():
                # Cut time-dim from states.
                if "state_" in col[:6]:
                    batch[col] = batch[col][t]
                # Flatten all other data.
                else:
                    # Cut time-dim at `max_seq_len`.
                    if batch.time_major:
                        batch[col] = batch[col][:batch.max_seq_len]
                    batch[col] = batch[col].reshape((-1, ) + batch[col].shape[2:])
        return

    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)
    else:
        meets_divisibility_reqs = True

    # RNN-case.
    if "state_in_0" in batch or "state_out_0" in batch:
        dynamic_max = True
    # Multi-agent case.
    elif not meets_divisibility_reqs:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
    # Simple case: not RNN nor do we need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k in batch.keys():
        if "state_in_" in k:
            state_keys.append(k)
        elif not feature_keys and "state_out_" not in k and k != "infos":
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = \
        chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX],
            [batch[k] for k in feature_keys_],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max,
            shuffle=shuffle)

    for i, k in enumerate(feature_keys_):
        batch[k] = feature_sequences[i]
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch["seq_lens"] = seq_lens

    if log_once("rnn_ma_feed_dict"):
        logger.info("Padded input for RNN:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))
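# The following is NOT RLlib's chop_into_sequences and is NOT part of the
# original code: it is a self-contained, simplified sketch of the effect the
# padding above has on a single feature column. `chop_and_pad`, `obs`, and
# `eps_ids` are hypothetical names for this illustration only: a flat [B, ...]
# column is chopped into per-episode sequences of at most `max_seq_len` steps
# and zero-padded to equal length, and the true lengths are returned as
# seq_lens.
import numpy as np

def chop_and_pad(feature_col, eps_ids, max_seq_len):
    seqs, seq_lens = [], []
    start = 0
    for end in range(1, len(eps_ids) + 1):
        episode_boundary = end == len(eps_ids) or eps_ids[end] != eps_ids[start]
        if end - start == max_seq_len or episode_boundary:
            seqs.append(feature_col[start:end])
            seq_lens.append(end - start)
            start = end
    padded = np.zeros((len(seqs), max_seq_len) + feature_col.shape[1:])
    for i, s in enumerate(seqs):
        padded[i, : len(s)] = s
    return padded, np.array(seq_lens)

obs = np.arange(5, dtype=float).reshape(5, 1)  # flat [B=5, 1] feature column
eps_ids = np.array([0, 0, 0, 1, 1])            # episode 0: 3 steps, episode 1: 2 steps
padded, seq_lens = chop_and_pad(obs, eps_ids, max_seq_len=2)
print(seq_lens)             # [2 1 2]
print(padded.squeeze(-1))
# [[0. 1.]
#  [2. 0.]
#  [3. 4.]]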