def _add_log_prob_and_value_to_episodes(
    episodes,
    model,
    phi,
    batch_states,
    obs_normalizer,
    device,
):
    dataset = list(itertools.chain.from_iterable(episodes))

    # Compute v_pred and next_v_pred
    states = batch_states([b["state"] for b in dataset], device, phi)
    next_states = batch_states([b["next_state"] for b in dataset], device, phi)

    if obs_normalizer:
        states = obs_normalizer(states, update=False)
        next_states = obs_normalizer(next_states, update=False)

    with torch.no_grad(), pfrl.utils.evaluating(model):
        distribs, vs_pred = model(states)
        _, next_vs_pred = model(next_states)

        actions = torch.tensor([b["action"] for b in dataset], device=device)
        log_probs = distribs.log_prob(actions).cpu().numpy()
        vs_pred = vs_pred.cpu().numpy().ravel()
        next_vs_pred = next_vs_pred.cpu().numpy().ravel()

    for transition, log_prob, v_pred, next_v_pred in zip(
        dataset, log_probs, vs_pred, next_vs_pred
    ):
        transition["log_prob"] = log_prob
        transition["v_pred"] = v_pred
        transition["next_v_pred"] = next_v_pred

def batch_experiences(experiences, device, phi, gamma, batch_states=batch_states):
    """Takes a batch of k experiences, each of which contains j
    consecutive transitions, and vectorizes them, where j is between 1 and n.

    Args:
        experiences: list of experiences. Each experience is a list
            containing between 1 and n dicts containing
              - state (object): State
              - action (object): Action
              - reward (float): Reward
              - is_state_terminal (bool): True iff next state is terminal
              - next_state (object): Next state
        device: GPU or CPU the tensor should be placed on
        phi: Preprocessing function
        gamma: discount factor
        batch_states: function that converts a list to a batch
    Returns:
        dict of batched transitions
    """
    batch_exp = {
        "state": batch_states(
            [elem[0]["state"] for elem in experiences], device, phi
        ),
        "action": torch.as_tensor(
            [elem[0]["action"] for elem in experiences], device=device
        ),
        "reward": torch.as_tensor(
            [
                sum((gamma ** i) * exp[i]["reward"] for i in range(len(exp)))
                for exp in experiences
            ],
            dtype=torch.float32,
            device=device,
        ),
        "next_state": batch_states(
            [elem[-1]["next_state"] for elem in experiences], device, phi
        ),
        "is_state_terminal": torch.as_tensor(
            [
                any(transition["is_state_terminal"] for transition in exp)
                for exp in experiences
            ],
            dtype=torch.float32,
            device=device,
        ),
        "discount": torch.as_tensor(
            [(gamma ** len(elem)) for elem in experiences],
            dtype=torch.float32,
            device=device,
        ),
    }
    # Batch next actions only when all the experiences have them
    if all(elem[-1]["next_action"] is not None for elem in experiences):
        batch_exp["next_action"] = torch.as_tensor(
            [elem[-1]["next_action"] for elem in experiences], device=device
        )
    return batch_exp

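# A minimal usage sketch for batch_experiences (hedged: the dummy transition,
# the identity phi, and the CPU device below are illustrative assumptions, not
# part of this module):
#
#     import numpy as np
#     import torch
#
#     experiences = [
#         [
#             dict(
#                 state=np.zeros(4, dtype=np.float32),
#                 action=0,
#                 reward=1.0,
#                 next_state=np.ones(4, dtype=np.float32),
#                 is_state_terminal=False,
#                 next_action=None,
#             )
#         ],
#     ]
#     batch = batch_experiences(
#         experiences, device=torch.device("cpu"), phi=lambda x: x, gamma=0.99
#     )
#     # batch["reward"] holds the discounted sum over each sub-sequence and
#     # batch["discount"] holds gamma ** (sub-sequence length).
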
def _add_log_prob_and_value_to_episodes_recurrent(
    episodes,
    model,
    phi,
    batch_states,
    obs_normalizer,
    device,
):
    # Sort desc by lengths so that pack_sequence does not change the order
    episodes = sorted(episodes, key=len, reverse=True)

    # Prepare data for a recurrent model
    seqs_states = []
    seqs_next_states = []
    for ep in episodes:
        states = batch_states(
            [transition["state"] for transition in ep], device, phi
        )
        next_states = batch_states(
            [transition["next_state"] for transition in ep], device, phi
        )
        if obs_normalizer:
            states = obs_normalizer(states, update=False)
            next_states = obs_normalizer(next_states, update=False)
        seqs_states.append(states)
        seqs_next_states.append(next_states)

    flat_transitions = flatten_sequences_time_first(episodes)

    # Predict values using a recurrent model
    with torch.no_grad(), pfrl.utils.evaluating(model):
        rs = concatenate_recurrent_states(
            [ep[0]["recurrent_state"] for ep in episodes]
        )
        next_rs = concatenate_recurrent_states(
            [ep[0]["next_recurrent_state"] for ep in episodes]
        )
        assert (rs is None) or (next_rs is None) or (len(rs) == len(next_rs))

        (flat_distribs, flat_vs), _ = pack_and_forward(model, seqs_states, rs)
        (_, flat_next_vs), _ = pack_and_forward(model, seqs_next_states, next_rs)

        flat_actions = torch.tensor(
            [b["action"] for b in flat_transitions], device=device
        )
        flat_log_probs = flat_distribs.log_prob(flat_actions).cpu().numpy()
        flat_vs = flat_vs.cpu().numpy()
        flat_next_vs = flat_next_vs.cpu().numpy()

    # Add predicted values to transitions
    for transition, log_prob, v, next_v in zip(
        flat_transitions, flat_log_probs, flat_vs, flat_next_vs
    ):
        transition["log_prob"] = float(log_prob)
        transition["v_pred"] = float(v)
        transition["next_v_pred"] = float(next_v)

def _update_vf(self, dataset):
    """Update the value function using a given dataset.

    The value function is updated via SGD to minimize TD(lambda) errors.
    """
    assert "state" in dataset[0]
    assert "v_teacher" in dataset[0]

    for batch in _yield_minibatches(
        dataset, minibatch_size=self.vf_batch_size, num_epochs=self.vf_epochs
    ):
        states = batch_states([b["state"] for b in batch], self.device, self.phi)
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)
        vs_teacher = torch.as_tensor(
            [b["v_teacher"] for b in batch],
            device=self.device,
            dtype=torch.float,
        )
        vs_pred = self.vf(states)
        vf_loss = F.mse_loss(vs_pred, vs_teacher[..., None])
        self.vf.zero_grad()
        vf_loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
        self.vf_optimizer.step()

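# Note on targets (an assumption about the surrounding agent code, which is not
# shown in this section): "v_teacher" is typically the lambda-return used as the
# regression target, e.g. v_teacher[t] = adv[t] + v_pred[t] when advantages are
# computed with GAE before this update runs.
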
def _update_policy(self, dataset):
    """Update the policy using a given dataset.

    The policy is updated via CG and line search.
    """
    assert "state" in dataset[0]
    assert "action" in dataset[0]
    assert "adv" in dataset[0]

    # Use full-batch
    states = batch_states([b["state"] for b in dataset], self.device, self.phi)
    if self.obs_normalizer:
        states = self.obs_normalizer(states, update=False)
    actions = torch.as_tensor([b["action"] for b in dataset], device=self.device)
    advs = torch.as_tensor(
        [b["adv"] for b in dataset], device=self.device, dtype=torch.float
    )
    if self.standardize_advantages:
        std_advs, mean_advs = torch.std_mean(advs, unbiased=False)
        advs = (advs - mean_advs) / (std_advs + 1e-8)

    # Recompute action distributions for batch backprop
    action_distrib = self.policy(states)

    log_prob_old = torch.as_tensor(
        [transition["log_prob"] for transition in dataset],
        device=self.device,
        dtype=torch.float,
    )

    gain = self._compute_gain(
        log_prob=action_distrib.log_prob(actions),
        log_prob_old=log_prob_old,
        entropy=action_distrib.entropy(),
        advs=advs,
    )

    # Distribution to compute KL div against
    with torch.no_grad():
        # torch.distributions.Distribution cannot be deepcopied
        action_distrib_old = self.policy(states)

    full_step = self._compute_kl_constrained_step(
        action_distrib=action_distrib,
        action_distrib_old=action_distrib_old,
        gain=gain,
    )

    self._line_search(
        full_step=full_step,
        dataset=dataset,
        advs=advs,
        action_distrib_old=action_distrib_old,
        gain=gain,
    )

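# A rough sketch of the surrogate objective built above (hedged: the exact form
# lives in self._compute_gain, which is not part of this section):
#
#     ratio = torch.exp(log_prob - log_prob_old)
#     gain = torch.mean(ratio * advs) + entropy_coef * torch.mean(entropy)
#
# The KL-constrained step is then obtained with conjugate gradient, and a
# backtracking line search picks the final step size.
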
def _act_eval(self, obs):
    # Use the process-local model for acting
    with torch.no_grad():
        statevar = batch_states([obs], self.device, self.phi)
        if self.recurrent:
            (
                (action_distrib, _, _),
                self.test_recurrent_states,
            ) = one_step_forward(self.model, statevar, self.test_recurrent_states)
        else:
            action_distrib, _, _ = self.model(statevar)
        if self.act_deterministically:
            return mode_of_distribution(action_distrib).numpy()[0]
        else:
            return action_distrib.sample().numpy()[0]

def batch_trajectory(trajectory, device, phi, batch_states=batch_states):
    batch_tr = {
        'state': batch_states(
            [elem['state'] for elem in trajectory], device, phi
        ),
        'action': np.asarray(
            [elem['action'] for elem in trajectory], dtype=np.int32
        ),
        'reward': np.asarray(
            [elem['reward'] for elem in trajectory], dtype=np.float32
        ),
        'is_state_terminal': np.asarray(
            [elem['is_state_terminal'] for elem in trajectory], dtype=np.float32
        ),
        # 'feature': [elem['feature'].cpu().detach().clone().numpy()
        #             for elem in trajectory],
        'feature': np.asarray(
            [elem['feature'] for elem in trajectory], dtype=np.float32
        ),
    }
    return batch_tr

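# A minimal usage sketch for batch_trajectory (hedged: the dummy transition and
# its per-step 'feature' array are illustrative only):
#
#     import numpy as np
#     import torch
#
#     trajectory = [
#         dict(
#             state=np.zeros(4, dtype=np.float32),
#             action=1,
#             reward=0.5,
#             is_state_terminal=False,
#             feature=np.zeros(8, dtype=np.float32),
#         ),
#     ]
#     batch = batch_trajectory(trajectory, torch.device("cpu"), phi=lambda x: x)
#     # Only batch['state'] is converted by batch_states; the other entries stay
#     # NumPy arrays, unlike batch_experiences above.
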
def _act_train(self, obs):
    statevar = batch_states([obs], self.device, self.phi)

    if self.recurrent:
        (
            (action_distrib, action_value, v),
            self.train_recurrent_states,
        ) = one_step_forward(self.model, statevar, self.train_recurrent_states)
    else:
        action_distrib, action_value, v = self.model(statevar)
    self.past_action_values[self.t] = action_value
    action = action_distrib.sample()[0]

    # Save values for a later update
    self.past_values[self.t] = v
    self.past_action_distrib[self.t] = action_distrib
    with torch.no_grad():
        if self.recurrent:
            (
                (avg_action_distrib, _, _),
                self.shared_recurrent_states,
            ) = one_step_forward(
                self.shared_average_model,
                statevar,
                self.shared_recurrent_states,
            )
        else:
            avg_action_distrib, _, _ = self.shared_average_model(statevar)
    self.past_avg_action_distrib[self.t] = avg_action_distrib

    self.past_actions[self.t] = action

    # Update stats
    self.average_value += (1 - self.average_value_decay) * (
        float(v) - self.average_value
    )
    self.average_entropy += (1 - self.average_entropy_decay) * (
        float(action_distrib.entropy()) - self.average_entropy
    )

    self.last_state = obs
    self.last_action = action.numpy()
    self.last_action_distrib = deepcopy_distribution(action_distrib)

    return self.last_action

def _observe_train(self, state, reward, done, reset):
    assert self.last_state is not None
    assert self.last_action is not None

    # Add a transition to the replay buffer
    if self.replay_buffer is not None:
        self.replay_buffer.append(
            state=self.last_state,
            action=self.last_action,
            reward=reward,
            next_state=state,
            is_state_terminal=done,
            mu=self.last_action_distrib,
        )
        if done or reset:
            self.replay_buffer.stop_current_episode()

    self.t += 1
    self.past_rewards[self.t - 1] = reward
    if self.process_idx == 0:
        self.logger.debug(
            "t:%s r:%s a:%s",
            self.t,
            reward,
            self.last_action,
        )
    if self.t - self.t_start == self.t_max or done or reset:
        if done:
            statevar = None
        else:
            statevar = batch_states([state], self.device, self.phi)
        self.update_on_policy(statevar)
        for _ in range(self.n_times_replay):
            self.update_from_replay()
    if done or reset:
        self.train_recurrent_states = None
        self.shared_recurrent_states = None
    self.last_state = None
    self.last_action = None
    self.last_action_distrib = None

def _update_obs_normalizer(self, dataset):
    assert self.obs_normalizer
    states = batch_states([b["state"] for b in dataset], self.device, self.phi)
    self.obs_normalizer.experience(states)

def update_from_replay(self):
    if self.replay_buffer is None:
        return

    if len(self.replay_buffer) < self.replay_start_size:
        return

    episode = self.replay_buffer.sample_episodes(1, self.t_max)[0]

    model_recurrent_state = None
    shared_recurrent_state = None
    rewards = {}
    actions = {}
    action_distribs = {}
    action_distribs_mu = {}
    avg_action_distribs = {}
    action_values = {}
    values = {}
    for t, transition in enumerate(episode):
        bs = batch_states([transition["state"]], self.device, self.phi)
        if self.recurrent:
            (
                (action_distrib, action_value, v),
                model_recurrent_state,
            ) = one_step_forward(self.model, bs, model_recurrent_state)
        else:
            action_distrib, action_value, v = self.model(bs)
        with torch.no_grad():
            if self.recurrent:
                (
                    (avg_action_distrib, _, _),
                    shared_recurrent_state,
                ) = one_step_forward(
                    self.shared_average_model,
                    bs,
                    shared_recurrent_state,
                )
            else:
                avg_action_distrib, _, _ = self.shared_average_model(bs)
        actions[t] = transition["action"]
        values[t] = v
        action_distribs[t] = action_distrib
        avg_action_distribs[t] = avg_action_distrib
        rewards[t] = transition["reward"]
        action_distribs_mu[t] = transition["mu"]
        action_values[t] = action_value

    last_transition = episode[-1]
    if last_transition["is_state_terminal"]:
        R = 0
    else:
        with torch.no_grad():
            last_s = batch_states(
                [last_transition["next_state"]], self.device, self.phi
            )
            if self.recurrent:
                (_, _, last_v), _ = one_step_forward(
                    self.model, last_s, model_recurrent_state
                )
            else:
                _, _, last_v = self.model(last_s)
        R = float(last_v)
    return self.update(
        R=R,
        t_start=0,
        t_stop=len(episode),
        rewards=rewards,
        actions=actions,
        values=values,
        action_distribs=action_distribs,
        action_distribs_mu=action_distribs_mu,
        avg_action_distribs=avg_action_distribs,
        action_values=action_values,
    )

def batch_recurrent_experiences(
    experiences, device, phi, gamma, batch_states=batch_states
):
    """Batch experiences for recurrent model updates.

    Args:
        experiences: list of episodes. Each episode is a list containing
            between 1 and n dicts, each containing:
              - state (object): State
              - action (object): Action
              - reward (float): Reward
              - is_state_terminal (bool): True iff next state is terminal
              - next_state (object): Next state
            The list must be sorted desc by lengths to be packed by
            `torch.nn.utils.rnn.pack_sequence` later.
        device: GPU or CPU the tensor should be placed on
        phi: Preprocessing function
        gamma: discount factor
        batch_states: function that converts a list to a batch
    Returns:
        dict of batched transitions
    """
    assert _is_sorted_desc_by_lengths(experiences)
    flat_transitions = flatten_sequences_time_first(experiences)
    batch_exp = {
        "state": [
            batch_states([transition["state"] for transition in ep], device, phi)
            for ep in experiences
        ],
        "action": torch.as_tensor(
            [transition["action"] for transition in flat_transitions],
            device=device,
        ),
        "reward": torch.as_tensor(
            [transition["reward"] for transition in flat_transitions],
            dtype=torch.float,
            device=device,
        ),
        "next_state": [
            batch_states(
                [transition["next_state"] for transition in ep], device, phi
            )
            for ep in experiences
        ],
        "is_state_terminal": torch.as_tensor(
            [
                transition["is_state_terminal"]
                for transition in flat_transitions
            ],
            dtype=torch.float,
            device=device,
        ),
        "discount": torch.full(
            (len(flat_transitions),), gamma, dtype=torch.float, device=device
        ),
        "recurrent_state": recurrent_state_from_numpy(
            concatenate_recurrent_states(
                [ep[0]["recurrent_state"] for ep in experiences]
            ),
            device,
        ),
        "next_recurrent_state": recurrent_state_from_numpy(
            concatenate_recurrent_states(
                [ep[0]["next_recurrent_state"] for ep in experiences]
            ),
            device,
        ),
    }
    # Batch next actions only when all the transitions have them
    if all(
        transition["next_action"] is not None for transition in flat_transitions
    ):
        batch_exp["next_action"] = torch.as_tensor(
            [transition["next_action"] for transition in flat_transitions],
            device=device,
        )
    return batch_exp
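
# Usage note (hedged): callers are expected to pre-sort episodes by length in
# descending order, e.g.
#
#     episodes = sorted(episodes, key=len, reverse=True)
#     batch = batch_recurrent_experiences(episodes, device, phi, gamma)
#
# so that the per-episode "state"/"next_state" lists can later be packed with
# torch.nn.utils.rnn.pack_sequence without reordering. Each episode's first
# transition must also carry "recurrent_state" and "next_recurrent_state"
# entries produced by the recurrent agent.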