def _add_log_prob_and_value_to_episodes_recurrent(
    episodes,
    model,
    phi,
    batch_states,
    obs_normalizer,
    device,
):
    # Sort desc by lengths so that pack_sequence does not change the order
    episodes = sorted(episodes, key=len, reverse=True)

    # Prepare data for a recurrent model
    seqs_states = []
    seqs_next_states = []
    for ep in episodes:
        states = batch_states(
            [transition["state"] for transition in ep], device, phi
        )
        next_states = batch_states(
            [transition["next_state"] for transition in ep], device, phi
        )
        if obs_normalizer:
            states = obs_normalizer(states, update=False)
            next_states = obs_normalizer(next_states, update=False)
        seqs_states.append(states)
        seqs_next_states.append(next_states)

    flat_transitions = flatten_sequences_time_first(episodes)

    # Predict values using a recurrent model
    with torch.no_grad(), pfrl.utils.evaluating(model):
        rs = concatenate_recurrent_states(
            [ep[0]["recurrent_state"] for ep in episodes]
        )
        next_rs = concatenate_recurrent_states(
            [ep[0]["next_recurrent_state"] for ep in episodes]
        )
        assert (rs is None) or (next_rs is None) or (len(rs) == len(next_rs))
        (flat_distribs, flat_vs), _ = pack_and_forward(model, seqs_states, rs)
        (_, flat_next_vs), _ = pack_and_forward(model, seqs_next_states, next_rs)

        flat_actions = torch.tensor(
            [b["action"] for b in flat_transitions], device=device
        )
        flat_log_probs = flat_distribs.log_prob(flat_actions).cpu().numpy()
        flat_vs = flat_vs.cpu().numpy()
        flat_next_vs = flat_next_vs.cpu().numpy()

    # Add predicted values to transitions
    for transition, log_prob, v, next_v in zip(
        flat_transitions, flat_log_probs, flat_vs, flat_next_vs
    ):
        transition["log_prob"] = float(log_prob)
        transition["v_pred"] = float(v)
        transition["next_v_pred"] = float(next_v)
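# A minimal, self-contained sketch (not part of the library) of why episodes
# are sorted in descending order of length above: pack_sequence defaults to
# enforce_sorted=True, which rejects unsorted input, so sorting once up front
# keeps the batch order stable across packing. demo_pack_requires_sorted is
# an illustrative name.
import torch
from torch.nn.utils.rnn import pack_sequence


def demo_pack_requires_sorted():
    seqs = [torch.ones(4, 2), torch.ones(1, 2), torch.ones(3, 2)]
    # Unsorted input raises with the default enforce_sorted=True.
    raised = False
    try:
        pack_sequence(seqs)
    except RuntimeError:
        raised = True
    assert raised
    # After sorting desc by length, packing succeeds without reordering.
    seqs = sorted(seqs, key=len, reverse=True)
    packed = pack_sequence(seqs)
    # batch_sizes counts how many sequences are still alive at each time step.
    assert packed.batch_sizes.tolist() == [3, 2, 2, 1]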
def _update_vf_once_recurrent(self, episodes):
    # Sort episodes desc by length for pack_sequence
    episodes = sorted(episodes, key=len, reverse=True)
    flat_transitions = flatten_sequences_time_first(episodes)

    # Prepare data for a recurrent model
    seqs_states = []
    for ep in episodes:
        states = self.batch_states(
            [transition["state"] for transition in ep], self.device, self.phi
        )
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)
        seqs_states.append(states)

    flat_vs_teacher = torch.as_tensor(
        [[transition["v_teacher"]] for transition in flat_transitions],
        device=self.device,
        dtype=torch.float,
    )

    with torch.no_grad():
        vf_rs = concatenate_recurrent_states(
            _collect_first_recurrent_states_of_vf(episodes)
        )

    # The forward pass must run with gradients enabled, outside the
    # no_grad block above, so that vf_loss.backward() can reach the
    # value function's parameters.
    flat_vs_pred, _ = pack_and_forward(self.vf, seqs_states, vf_rs)

    vf_loss = F.mse_loss(flat_vs_pred, flat_vs_teacher)
    self.vf.zero_grad()
    vf_loss.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
    self.vf_optimizer.step()
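# A minimal sketch of the value-function regression step above, in plain
# PyTorch and without the recurrent machinery. clip_l2_grad_norm_ is assumed
# here to behave like torch.nn.utils.clip_grad_norm_; all names in this
# sketch are illustrative.
import torch
import torch.nn.functional as F


def demo_vf_update_step(max_grad_norm=0.5):
    vf = torch.nn.Linear(4, 1)
    vf_optimizer = torch.optim.SGD(vf.parameters(), lr=1e-2)
    states = torch.rand(8, 4)
    vs_teacher = torch.rand(8, 1)
    # Regress predicted values toward the targets with an MSE loss.
    vf_loss = F.mse_loss(vf(states), vs_teacher)
    vf.zero_grad()
    vf_loss.backward()
    # Clip the global L2 norm of the gradients before stepping.
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(vf.parameters(), max_grad_norm)
    vf_optimizer.step()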
def _test_non_lstm(self, gpu, name):
    in_size = 2
    out_size = 3
    device = "cuda:{}".format(gpu) if gpu >= 0 else "cpu"
    seqs_x = [
        torch.rand(4, in_size, device=device),
        torch.rand(1, in_size, device=device),
        torch.rand(3, in_size, device=device),
    ]
    seqs_x = torch.nn.utils.rnn.pack_sequence(seqs_x, enforce_sorted=False)

    self.assertTrue(name in ("GRU", "RNN"))
    cls = getattr(nn, name)
    link = cls(num_layers=1, input_size=in_size, hidden_size=out_size)
    link.to(device)

    # Forward twice: with None and non-None recurrent states
    y0, h0 = link(seqs_x, None)
    y1, h1 = link(seqs_x, h0)
    y0, _ = torch.nn.utils.rnn.pad_packed_sequence(y0, batch_first=True)
    y1, _ = torch.nn.utils.rnn.pad_packed_sequence(y1, batch_first=True)
    self.assertEqual(h0.shape, (1, 3, out_size))
    self.assertEqual(h1.shape, (1, 3, out_size))
    self.assertEqual(y0.shape, (3, 4, out_size))
    self.assertEqual(y1.shape, (3, 4, out_size))

    # Masked at 0
    rs0_mask0 = mask_recurrent_state_at(h0, 0)
    y1m0, _ = link(seqs_x, rs0_mask0)
    y1m0, _ = torch.nn.utils.rnn.pad_packed_sequence(y1m0, batch_first=True)
    torch_assert_allclose(y1m0[0], y0[0])
    torch_assert_allclose(y1m0[1], y1[1])
    torch_assert_allclose(y1m0[2], y1[2])

    # Masked at (1, 2)
    rs0_mask12 = mask_recurrent_state_at(h0, (1, 2))
    y1m12, _ = link(seqs_x, rs0_mask12)
    y1m12, _ = torch.nn.utils.rnn.pad_packed_sequence(y1m12, batch_first=True)
    torch_assert_allclose(y1m12[0], y1[0])
    torch_assert_allclose(y1m12[1], y0[1])
    torch_assert_allclose(y1m12[2], y0[2])

    # Get at 1 and concat with None
    rs0_get1 = get_recurrent_state_at(h0, 1, detach=False)
    assert rs0_get1.requires_grad
    torch_assert_allclose(rs0_get1, h0[:, 1])
    concat_rs_get1 = concatenate_recurrent_states([None, rs0_get1, None])
    y1g1, _ = link(seqs_x, concat_rs_get1)
    y1g1, _ = torch.nn.utils.rnn.pad_packed_sequence(y1g1, batch_first=True)
    torch_assert_allclose(y1g1[0], y0[0])
    torch_assert_allclose(y1g1[1], y1[1])
    torch_assert_allclose(y1g1[2], y0[2])

    # Get at 1 with detach=True
    rs0_get1_detach = get_recurrent_state_at(h0, 1, detach=True)
    assert not rs0_get1_detach.requires_grad
    torch_assert_allclose(rs0_get1_detach, h0[:, 1])
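# A standalone sketch of the masking semantics the test above relies on:
# zeroing the hidden state of one batch entry makes the next forward behave,
# for that entry, as if no previous state had been given (None means a zero
# initial state for nn.GRU). Uses a plain nn.GRU; demo_mask_hidden_at is an
# illustrative name, not a library function.
import torch
import torch.nn as nn


def demo_mask_hidden_at():
    torch.manual_seed(0)
    gru = nn.GRU(input_size=2, hidden_size=3, num_layers=1)
    x = torch.rand(5, 2, 2)  # (seq_len, batch=2, input_size)
    y_fresh, h = gru(x, None)
    # Zero the carried hidden state of batch entry 0 only.
    masked_h = h.detach().clone()
    masked_h[:, 0] = 0
    y_masked, _ = gru(x, masked_h)
    # Entry 0 now matches a from-scratch forward exactly.
    torch.testing.assert_close(y_masked[:, 0], y_fresh[:, 0])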
def _line_search(self, full_step, dataset, advs, action_distrib_old, gain):
    """Do line search for a safe step size."""
    policy_params = list(self.policy.parameters())
    policy_params_sizes = [param.numel() for param in policy_params]
    policy_params_shapes = [param.shape for param in policy_params]
    step_size = 1.0
    flat_params = _flatten_and_concat_variables(policy_params).detach()
    if self.recurrent:
        seqs_states = []
        for ep in dataset:
            states = self.batch_states(
                [transition["state"] for transition in ep], self.device, self.phi
            )
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            seqs_states.append(states)

        with torch.no_grad(), pfrl.utils.evaluating(self.model):
            policy_rs = concatenate_recurrent_states(
                _collect_first_recurrent_states_of_policy(dataset)
            )

        def evaluate_current_policy():
            distrib, _ = pack_and_forward(self.policy, seqs_states, policy_rs)
            return distrib

    else:
        states = self.batch_states(
            [transition["state"] for transition in dataset], self.device, self.phi
        )
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)

        def evaluate_current_policy():
            return self.policy(states)

    flat_transitions = (
        flatten_sequences_time_first(dataset) if self.recurrent else dataset
    )
    actions = torch.tensor(
        [transition["action"] for transition in flat_transitions],
        device=self.device,
    )
    log_prob_old = torch.tensor(
        [transition["log_prob"] for transition in flat_transitions],
        device=self.device,
        dtype=torch.float,
    )

    for i in range(self.line_search_max_backtrack + 1):
        self.logger.info("Line search iteration: %s step size: %s", i, step_size)
        new_flat_params = flat_params + step_size * full_step
        new_params = _split_and_reshape_to_ndarrays(
            new_flat_params,
            sizes=policy_params_sizes,
            shapes=policy_params_shapes,
        )
        _replace_params_data(policy_params, new_params)
        with torch.no_grad(), pfrl.utils.evaluating(self.policy):
            new_action_distrib = evaluate_current_policy()
            new_gain = self._compute_gain(
                log_prob=new_action_distrib.log_prob(actions),
                log_prob_old=log_prob_old,
                entropy=new_action_distrib.entropy(),
                advs=advs,
            )
            new_kl = torch.mean(
                torch.distributions.kl_divergence(
                    action_distrib_old, new_action_distrib
                )
            )
        improve = float(new_gain) - float(gain)
        self.logger.info("Surrogate objective improve: %s", improve)
        self.logger.info("KL divergence: %s", float(new_kl))
        if not torch.isfinite(new_gain):
            self.logger.info("Surrogate objective is not finite. Backtracking...")
        elif not torch.isfinite(new_kl):
            self.logger.info("KL divergence is not finite. Backtracking...")
        elif improve < 0:
            self.logger.info("Surrogate objective didn't improve. Backtracking...")
        elif float(new_kl) > self.max_kl:
            self.logger.info("KL divergence exceeds max_kl. Backtracking...")
        else:
            self.kl_record.append(float(new_kl))
            self.policy_step_size_record.append(step_size)
            break
        step_size *= 0.5
    else:
        self.logger.info(
            "Line search couldn't find a good step size. The policy was not"
            " updated."
        )
        self.policy_step_size_record.append(0.0)
        _replace_params_data(
            policy_params,
            _split_and_reshape_to_ndarrays(
                flat_params, sizes=policy_params_sizes, shapes=policy_params_shapes
            ),
        )
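# A generic sketch of the backtracking loop above, stripped of the policy
# machinery: halve the step until a candidate both improves the objective
# and satisfies the constraint, otherwise fall back to the starting point
# with step size 0. All names here are illustrative.
import math


def backtracking_line_search(x, full_step, objective, constraint_ok,
                             max_backtrack=10):
    base = objective(x)
    step_size = 1.0
    for _ in range(max_backtrack + 1):
        candidate = x + step_size * full_step
        gain = objective(candidate)
        # Accept only finite, strictly improving, constraint-satisfying steps.
        if math.isfinite(gain) and gain > base and constraint_ok(candidate):
            return candidate, step_size
        step_size *= 0.5
    # No acceptable step found: keep the original parameters.
    return x, 0.0


# E.g., maximizing -(v - 0.5)^2 from v=0 with full_step=2 settles at v=0.5:
assert backtracking_line_search(
    0.0, 2.0, lambda v: -(v - 0.5) ** 2, lambda v: abs(v) < 1.0
) == (0.5, 0.25)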
def _update_policy_recurrent(self, dataset):
    """Update the policy using a given dataset.

    The policy is updated via CG and line search.
    """
    # Sort episodes desc by length for pack_sequence
    dataset = sorted(dataset, key=len, reverse=True)
    flat_transitions = flatten_sequences_time_first(dataset)

    # Prepare data for a recurrent model
    seqs_states = []
    for ep in dataset:
        states = self.batch_states(
            [transition["state"] for transition in ep],
            self.device,
            self.phi,
        )
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)
        seqs_states.append(states)

    flat_actions = torch.as_tensor(
        [transition["action"] for transition in flat_transitions],
        device=self.device,
    )
    flat_advs = torch.as_tensor(
        [transition["adv"] for transition in flat_transitions],
        device=self.device,
        dtype=torch.float,
    )

    if self.standardize_advantages:
        std_advs, mean_advs = torch.std_mean(flat_advs, unbiased=False)
        flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)

    with torch.no_grad():
        policy_rs = concatenate_recurrent_states(
            _collect_first_recurrent_states_of_policy(dataset)
        )

    # Forward with gradients enabled: the gain must be differentiable for
    # the KL-constrained step computation below.
    flat_distribs, _ = pack_and_forward(self.policy, seqs_states, policy_rs)
    log_prob_old = torch.tensor(
        [transition["log_prob"] for transition in flat_transitions],
        device=self.device,
        dtype=torch.float,
    )
    gain = self._compute_gain(
        log_prob=flat_distribs.log_prob(flat_actions),
        log_prob_old=log_prob_old,
        entropy=flat_distribs.entropy(),
        advs=flat_advs,
    )

    # Distribution to compute KL div against
    with torch.no_grad():
        # torch.distributions.Distribution cannot be deepcopied
        action_distrib_old, _ = pack_and_forward(
            self.policy, seqs_states, policy_rs
        )

    full_step = self._compute_kl_constrained_step(
        action_distrib=flat_distribs,
        action_distrib_old=action_distrib_old,
        gain=gain,
    )

    self._line_search(
        full_step=full_step,
        dataset=dataset,
        advs=flat_advs,
        action_distrib_old=action_distrib_old,
        gain=gain,
    )
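# A small sketch of the advantage standardization used above: zero-mean,
# unit-variance scaling with a small epsilon for numerical safety.
# Illustrative only.
import torch


def demo_standardize_advantages():
    flat_advs = torch.tensor([1.0, 2.0, 3.0, 4.0])
    std_advs, mean_advs = torch.std_mean(flat_advs, unbiased=False)
    flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)
    assert abs(float(flat_advs.mean())) < 1e-6
    assert abs(float(flat_advs.std(unbiased=False)) - 1.0) < 1e-4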
def batch_recurrent_experiences(
    experiences, device, phi, gamma, batch_states=batch_states
):
    """Batch experiences for recurrent model updates.

    Args:
        experiences: List of episodes. Each episode is a list containing
            between 1 and n dicts, each containing:
              - state (object): State
              - action (object): Action
              - reward (float): Reward
              - is_state_terminal (bool): True iff next state is terminal
              - next_state (object): Next state
            The list must be sorted in descending order of length so that it
            can later be packed by `torch.nn.utils.rnn.pack_sequence`.
        device: GPU or CPU the tensors should be placed on
        phi: Preprocessing function
        gamma: Discount factor
        batch_states: Function that converts a list of states to a batch

    Returns:
        dict of batched transitions
    """
    assert _is_sorted_desc_by_lengths(experiences)
    flat_transitions = flatten_sequences_time_first(experiences)
    batch_exp = {
        "state": [
            batch_states([transition["state"] for transition in ep], device, phi)
            for ep in experiences
        ],
        "action": torch.as_tensor(
            [transition["action"] for transition in flat_transitions],
            device=device,
        ),
        "reward": torch.as_tensor(
            [transition["reward"] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        ),
        "next_state": [
            batch_states(
                [transition["next_state"] for transition in ep], device, phi
            )
            for ep in experiences
        ],
        "is_state_terminal": torch.as_tensor(
            [transition["is_state_terminal"] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        ),
        "discount": torch.full(
            (len(flat_transitions),), gamma, dtype=torch.float, device=device
        ),
        "recurrent_state": recurrent_state_from_numpy(
            concatenate_recurrent_states(
                [ep[0]["recurrent_state"] for ep in experiences]
            ),
            device,
        ),
        "next_recurrent_state": recurrent_state_from_numpy(
            concatenate_recurrent_states(
                [ep[0]["next_recurrent_state"] for ep in experiences]
            ),
            device,
        ),
    }
    # Batch next actions only when all the transitions have them
    if all(transition["next_action"] is not None for transition in flat_transitions):
        batch_exp["next_action"] = torch.as_tensor(
            [transition["next_action"] for transition in flat_transitions],
            device=device,
        )
    return batch_exp
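# A hedged sketch of what flatten_sequences_time_first is assumed to do in
# this file: flatten episodes in time-major order (all first transitions,
# then all second transitions, ...), matching the element order of a packed
# sequence so flat tensors line up with packed model outputs. This is an
# illustrative reimplementation, not the library's code.
def demo_flatten_time_first(sequences):
    flat = []
    for t in range(max(len(seq) for seq in sequences)):
        for seq in sequences:
            if t < len(seq):
                flat.append(seq[t])
    return flat


# Example: lengths 3 and 2 interleave as [a0, b0, a1, b1, a2].
assert demo_flatten_time_first([["a0", "a1", "a2"], ["b0", "b1"]]) == [
    "a0", "b0", "a1", "b1", "a2",
]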
def _update_once_recurrent(self, episodes, mean_advs, std_advs):
    assert std_advs is None or std_advs > 0

    device = self.device

    # Sort desc by lengths so that pack_sequence does not change the order
    episodes = sorted(episodes, key=len, reverse=True)

    flat_transitions = flatten_sequences_time_first(episodes)

    # Prepare data for a recurrent model
    seqs_states = []
    for ep in episodes:
        states = self.batch_states(
            [transition["state"] for transition in ep],
            self.device,
            self.phi,
        )
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)
        seqs_states.append(states)

    flat_actions = torch.tensor(
        [transition["action"] for transition in flat_transitions],
        device=device,
    )
    flat_advs = torch.tensor(
        [transition["adv"] for transition in flat_transitions],
        dtype=torch.float,
        device=device,
    )
    if self.standardize_advantages:
        flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)
    flat_log_probs_old = torch.tensor(
        [transition["log_prob"] for transition in flat_transitions],
        dtype=torch.float,
        device=device,
    )
    flat_vs_pred_old = torch.tensor(
        [[transition["v_pred"]] for transition in flat_transitions],
        dtype=torch.float,
        device=device,
    )
    flat_vs_teacher = torch.tensor(
        [[transition["v_teacher"]] for transition in flat_transitions],
        dtype=torch.float,
        device=device,
    )

    with torch.no_grad(), pfrl.utils.evaluating(self.model):
        rs = concatenate_recurrent_states(
            [ep[0]["recurrent_state"] for ep in episodes]
        )

    # The forward pass that produces the loss must run with gradients
    # enabled, outside the no_grad block above, so that loss.backward()
    # can reach the model's parameters.
    (flat_distribs, flat_vs_pred), _ = pack_and_forward(
        self.model, seqs_states, rs
    )
    flat_log_probs = flat_distribs.log_prob(flat_actions)
    flat_entropy = flat_distribs.entropy()

    self.model.zero_grad()
    loss = self._lossfun(
        entropy=flat_entropy,
        vs_pred=flat_vs_pred,
        log_probs=flat_log_probs,
        vs_pred_old=flat_vs_pred_old,
        log_probs_old=flat_log_probs_old,
        advs=flat_advs,
        vs_teacher=flat_vs_teacher,
    )
    loss.backward()
    if self.max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
    self.optimizer.step()
    self.n_updates += 1
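# _lossfun above is not shown; for reference, this is a minimal sketch of
# the standard PPO clipped surrogate it is assumed to resemble (clipped
# policy term + value term - entropy bonus). Names and coefficients here
# are illustrative, not the library's.
import torch
import torch.nn.functional as F


def demo_ppo_loss(log_probs, log_probs_old, advs, vs_pred, vs_teacher,
                  entropy, clip_eps=0.2, value_coef=1.0, entropy_coef=0.01):
    # Probability ratio between the current and the data-collecting policy.
    ratio = torch.exp(log_probs - log_probs_old)
    # Clipped surrogate: take the pessimistic of the two estimates.
    policy_loss = -torch.mean(
        torch.min(
            ratio * advs,
            torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advs,
        )
    )
    value_loss = F.mse_loss(vs_pred, vs_teacher)
    return policy_loss + value_coef * value_loss - entropy_coef * entropy.mean()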
def get_and_concat_rs_forward():
    _, rs = one_step_forward(par, x_t0, None)
    rs0 = get_recurrent_state_at(rs, 0, detach=True)
    rs1 = get_recurrent_state_at(rs, 1, detach=True)
    concat_rs = concatenate_recurrent_states([rs0, rs1])
    return one_step_forward(par, x_t1, concat_rs)
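# Why detach=True above: detaching the carried recurrent state cuts the
# autograd graph between the two one-step forwards, so the backward pass of
# the second step cannot flow into the first (truncated backprop through
# time). A plain-GRU sketch; demo names are illustrative.
import torch
import torch.nn as nn


def demo_detach_truncates_bptt():
    gru = nn.GRU(input_size=2, hidden_size=3)
    x_t0 = torch.rand(1, 1, 2)
    x_t1 = torch.rand(1, 1, 2)
    _, h = gru(x_t0, None)
    assert h.requires_grad  # still connected to the first forward
    detached = h.detach()
    assert not detached.requires_grad
    y, _ = gru(x_t1, detached)
    # Gradients reach gru's parameters only through the second forward.
    y.sum().backward()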