def _compute_target_values(self, exp_batch):
    batch_next_state = exp_batch["next_state"]

    with evaluating(self.model):
        if self.recurrent:
            next_qout, _ = pack_and_forward(
                self.model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            next_qout = self.model(batch_next_state)

    if self.recurrent:
        target_next_qout, _ = pack_and_forward(
            self.target_model,
            batch_next_state,
            exp_batch["next_recurrent_state"],
        )
    else:
        target_next_qout = self.target_model(batch_next_state)

    next_q_max = target_next_qout.evaluate_actions(next_qout.greedy_actions)

    batch_rewards = exp_batch["reward"]
    batch_terminal = exp_batch["is_state_terminal"]
    discount = exp_batch["discount"]

    return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max

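# For reference, the method above computes the Double DQN target: the online model
# selects the greedy next action and the target model evaluates it. A minimal,
# self-contained sketch with plain (batch_size, n_actions) Q-value tensors instead
# of PFRL's ActionValue API (the function name and arguments are illustrative):
import torch

def double_dqn_target(reward, terminal, discount, next_q_online, next_q_target):
    # a* = argmax_a Q_online(s', a)
    greedy_actions = next_q_online.argmax(dim=1, keepdim=True)
    # Q_target(s', a*)
    next_q = next_q_target.gather(1, greedy_actions).squeeze(1)
    return reward + discount * (1.0 - terminal) * next_q
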
def _compute_y_and_t(self, exp_batch):
    batch_state = exp_batch["state"]
    batch_size = len(exp_batch["reward"])

    if self.recurrent:
        qout, _ = pack_and_forward(
            self.model,
            batch_state,
            exp_batch["recurrent_state"],
        )
    else:
        qout = self.model(batch_state)

    batch_actions = exp_batch["action"]
    batch_q = qout.evaluate_actions(batch_actions)

    # Compute target values
    batch_next_state = exp_batch["next_state"]

    with torch.no_grad():
        if self.recurrent:
            target_qout, _ = pack_and_forward(
                self.target_model,
                batch_state,
                exp_batch["recurrent_state"],
            )
            target_next_qout, _ = pack_and_forward(
                self.target_model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            target_qout = self.target_model(batch_state)
            target_next_qout = self.target_model(batch_next_state)

        next_q_max = torch.reshape(target_next_qout.max, (batch_size,))

        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]

        # T Q: Bellman operator
        t_q = (
            batch_rewards
            + exp_batch["discount"] * (1.0 - batch_terminal) * next_q_max
        )

        # T_PAL Q: persistent advantage learning operator
        cur_advantage = torch.reshape(
            target_qout.compute_advantage(batch_actions), (batch_size,)
        )
        next_advantage = torch.reshape(
            target_next_qout.compute_advantage(batch_actions), (batch_size,)
        )
        tpal_q = t_q + self.alpha * torch.max(cur_advantage, next_advantage)

    return batch_q, tpal_q

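# In equation form, the persistent advantage learning target assembled above is
#     T_PAL Q(s_t, a_t) = T Q(s_t, a_t) + alpha * max(A'(s_t, a_t), A'(s_{t+1}, a_t)),
# where T is the one-step Bellman operator and A'(s, a) = Q'(s, a) - max_a Q'(s, a)
# is the advantage under the target network. A minimal sketch with plain
# (batch_size, n_actions) target-network Q-value tensors (names are illustrative,
# not PFRL API):
import torch

def pal_target(t_q, q_target_cur, q_target_next, actions, alpha):
    # A'(s, a) = Q'(s, a) - max_a Q'(s, a), evaluated at the taken action a_t
    cur_adv = (
        q_target_cur.gather(1, actions[:, None]).squeeze(1)
        - q_target_cur.max(dim=1).values
    )
    next_adv = (
        q_target_next.gather(1, actions[:, None]).squeeze(1)
        - q_target_next.max(dim=1).values
    )
    return t_q + alpha * torch.max(cur_adv, next_adv)
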
def _add_log_prob_and_value_to_episodes_recurrent(
    episodes,
    model,
    phi,
    batch_states,
    obs_normalizer,
    device,
):
    # Sort desc by lengths so that pack_sequence does not change the order
    episodes = sorted(episodes, key=len, reverse=True)

    # Prepare data for a recurrent model
    seqs_states = []
    seqs_next_states = []
    for ep in episodes:
        states = batch_states(
            [transition["state"] for transition in ep], device, phi
        )
        next_states = batch_states(
            [transition["next_state"] for transition in ep], device, phi
        )
        if obs_normalizer:
            states = obs_normalizer(states, update=False)
            next_states = obs_normalizer(next_states, update=False)
        seqs_states.append(states)
        seqs_next_states.append(next_states)

    flat_transitions = flatten_sequences_time_first(episodes)

    # Predict values using a recurrent model
    with torch.no_grad(), pfrl.utils.evaluating(model):
        rs = concatenate_recurrent_states(
            [ep[0]["recurrent_state"] for ep in episodes]
        )
        next_rs = concatenate_recurrent_states(
            [ep[0]["next_recurrent_state"] for ep in episodes]
        )
        assert (rs is None) or (next_rs is None) or (len(rs) == len(next_rs))

        (flat_distribs, flat_vs), _ = pack_and_forward(model, seqs_states, rs)
        (_, flat_next_vs), _ = pack_and_forward(model, seqs_next_states, next_rs)

        flat_actions = torch.tensor(
            [b["action"] for b in flat_transitions], device=device
        )
        flat_log_probs = flat_distribs.log_prob(flat_actions).cpu().numpy()
        flat_vs = flat_vs.cpu().numpy()
        flat_next_vs = flat_next_vs.cpu().numpy()

    # Add predicted values to transitions
    for transition, log_prob, v, next_v in zip(
        flat_transitions, flat_log_probs, flat_vs, flat_next_vs
    ):
        transition["log_prob"] = float(log_prob)
        transition["v_pred"] = float(v)
        transition["next_v_pred"] = float(next_v)

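# flatten_sequences_time_first is assumed to flatten the list of (descending-length)
# episodes into a single time-major list: all step-0 transitions first, then all
# step-1 transitions, and so on, so that its order matches the flattened output of
# pack_and_forward. A hedged sketch of that assumed behaviour (not PFRL's actual
# implementation):
def flatten_sequences_time_first_sketch(sequences):
    flat = []
    max_len = max(len(seq) for seq in sequences)
    for t in range(max_len):
        for seq in sequences:
            if t < len(seq):
                flat.append(seq[t])
    return flat
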
def _compute_y_and_t(self, exp_batch):
    """Compute a batch of predicted/target return distributions."""

    batch_size = exp_batch["reward"].shape[0]

    # Compute Q-values for current states
    batch_state = exp_batch["state"]

    # (batch_size, n_actions, n_atoms)
    if self.recurrent:
        qout, _ = pack_and_forward(
            self.model, batch_state, exp_batch["recurrent_state"]
        )
    else:
        qout = self.model(batch_state)

    n_atoms = qout.z_values.size()[0]

    batch_actions = exp_batch["action"]
    batch_q = qout.evaluate_actions_as_distribution(batch_actions)
    assert batch_q.shape == (batch_size, n_atoms)

    with torch.no_grad():
        batch_q_target = self._compute_target_values(exp_batch)
        assert batch_q_target.shape == (batch_size, n_atoms)

    batch_q_scalars = qout.evaluate_actions(batch_actions)
    self.q_record.extend(batch_q_scalars.detach().cpu().numpy().ravel())

    return batch_q, batch_q_target

def _update_vf_once_recurrent(self, episodes):

    # Sort episodes desc by length for pack_sequence
    episodes = sorted(episodes, key=len, reverse=True)

    flat_transitions = flatten_sequences_time_first(episodes)

    # Prepare data for a recurrent model
    seqs_states = []
    for ep in episodes:
        states = self.batch_states(
            [transition["state"] for transition in ep], self.device, self.phi
        )
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)
        seqs_states.append(states)

    flat_vs_teacher = torch.as_tensor(
        [[transition["v_teacher"]] for transition in flat_transitions],
        device=self.device,
        dtype=torch.float,
    )

    with torch.no_grad():
        vf_rs = concatenate_recurrent_states(
            _collect_first_recurrent_states_of_vf(episodes)
        )

    flat_vs_pred, _ = pack_and_forward(self.vf, seqs_states, vf_rs)

    vf_loss = F.mse_loss(flat_vs_pred, flat_vs_teacher)
    self.vf.zero_grad()
    vf_loss.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
    self.vf_optimizer.step()

def _compute_target_values(self, exp_batch):
    """Compute a batch of target return distributions."""

    batch_next_state = exp_batch["next_state"]

    if self.recurrent:
        target_next_qout, _ = pack_and_forward(
            self.target_model,
            batch_next_state,
            exp_batch["next_recurrent_state"],
        )
    else:
        target_next_qout = self.target_model(batch_next_state)

    batch_rewards = exp_batch["reward"]
    batch_terminal = exp_batch["is_state_terminal"]

    batch_size = exp_batch["reward"].shape[0]
    z_values = target_next_qout.z_values
    n_atoms = z_values.size()[0]

    # next_q_max: (batch_size, n_atoms)
    next_q_max = target_next_qout.max_as_distribution.detach()
    assert next_q_max.shape == (batch_size, n_atoms), next_q_max.shape

    # Tz: (batch_size, n_atoms)
    Tz = (
        batch_rewards[..., None]
        + (1.0 - batch_terminal[..., None])
        * torch.unsqueeze(exp_batch["discount"], 1)
        * z_values[None]
    )
    return _apply_categorical_projection(Tz, next_q_max, z_values)

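# _apply_categorical_projection is referenced but not shown. It is assumed to
# implement the standard C51 projection of the shifted atoms Tz (carrying the
# probability masses next_q_max) back onto the fixed support z_values. A hedged
# sketch under the assumption that z_values is sorted ascending and evenly spaced
# (not necessarily PFRL's implementation):
import torch

def categorical_projection_sketch(Tz, probs, z_values):
    # Tz, probs: (batch_size, n_atoms); z_values: (n_atoms,)
    v_min = float(z_values[0])
    v_max = float(z_values[-1])
    n_atoms = z_values.numel()
    delta_z = (v_max - v_min) / (n_atoms - 1)
    b = (Tz.clamp(v_min, v_max) - v_min) / delta_z   # fractional atom indices
    lower = b.floor().long()
    upper = b.ceil().long()
    proj = torch.zeros_like(probs)
    # Split each mass between the two neighbouring atoms, proportionally to distance
    proj.scatter_add_(1, lower, probs * (upper.float() - b))
    proj.scatter_add_(1, upper, probs * (b - lower.float()))
    # Atoms landing exactly on the grid get their full mass (both weights above are 0)
    proj.scatter_add_(1, lower, probs * (upper == lower).float())
    return proj
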
def _compute_y_and_taus(self, exp_batch):
    """Compute a batch of predicted return distributions.

    Returns:
        torch.Tensor: Predicted return distributions. (batch_size, N).
    """

    batch_size = exp_batch["reward"].shape[0]

    # Compute Q-values for current states
    batch_state = exp_batch["state"]

    # (batch_size, n_actions, n_atoms)
    if self.recurrent:
        tau2av, _ = pack_and_forward(
            self.model,
            batch_state,
            exp_batch["recurrent_state"],
        )
    else:
        tau2av = self.model(batch_state)
    taus = torch.rand(
        batch_size,
        self.quantile_thresholds_N,
        device=self.device,
        dtype=torch.float,
    )
    av = tau2av(taus)
    batch_actions = exp_batch["action"]
    y = av.evaluate_actions_as_quantiles(batch_actions)

    self.q_record.extend(av.q_values.detach().cpu().numpy().ravel())

    return y, taus

def _compute_y_and_t(self, exp_batch):
    batch_state = exp_batch["state"]
    batch_size = len(exp_batch["reward"])

    if self.recurrent:
        qout, _ = pack_and_forward(
            self.model,
            batch_state,
            exp_batch["recurrent_state"],
        )
    else:
        qout = self.model(batch_state)

    batch_actions = exp_batch["action"]
    # Q(s_t, a_t)
    batch_q = qout.evaluate_actions(batch_actions).reshape((batch_size, 1))

    with torch.no_grad():
        # Compute target values
        if self.recurrent:
            target_qout, _ = pack_and_forward(
                self.target_model,
                batch_state,
                exp_batch["recurrent_state"],
            )
        else:
            target_qout = self.target_model(batch_state)

        # Q'(s_t, a_t)
        target_q = target_qout.evaluate_actions(batch_actions).reshape(
            (batch_size, 1)
        )

        # LQ'(s_t, a)
        target_q_expect = self._l_operator(target_qout).reshape((batch_size, 1))

        # r + g * LQ'(s_{t+1}, a)
        batch_q_target = self._compute_target_values(exp_batch).reshape(
            (batch_size, 1)
        )

        # Q'(s_t, a_t) + r + g * LQ'(s_{t+1}, a) - LQ'(s_t, a)
        t = target_q + batch_q_target - target_q_expect

    return batch_q, t

def _compute_target_values(self, exp_batch):
    """Compute a batch of target return distributions."""

    batch_next_state = exp_batch["next_state"]
    batch_rewards = exp_batch["reward"]
    batch_terminal = exp_batch["is_state_terminal"]

    with pfrl.utils.evaluating(self.target_model), pfrl.utils.evaluating(
        self.model
    ):
        if self.recurrent:
            target_next_qout, _ = pack_and_forward(
                self.target_model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
            next_qout, _ = pack_and_forward(
                self.model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            target_next_qout = self.target_model(batch_next_state)
            next_qout = self.model(batch_next_state)

    batch_size = batch_rewards.shape[0]
    z_values = target_next_qout.z_values
    n_atoms = z_values.numel()

    # next_q_max: (batch_size, n_atoms)
    next_q_max = target_next_qout.evaluate_actions_as_distribution(
        next_qout.greedy_actions.detach()
    )
    assert next_q_max.shape == (batch_size, n_atoms), next_q_max.shape

    # Tz: (batch_size, n_atoms)
    Tz = (
        batch_rewards[..., None]
        + (1.0 - batch_terminal[..., None])
        * exp_batch["discount"][..., None]
        * z_values[None]
    )
    return _apply_categorical_projection(Tz, next_q_max, z_values)

def _compute_target_values(self, exp_batch):
    """Compute a batch of target return distributions.

    Returns:
        torch.Tensor: (batch_size, N_prime).
    """
    batch_next_state = exp_batch["next_state"]
    batch_size = len(exp_batch["reward"])
    taus_tilde = torch.rand(
        batch_size,
        self.quantile_thresholds_K,
        device=self.device,
        dtype=torch.float,
    )

    if self.recurrent:
        target_next_tau2av, _ = pack_and_forward(
            self.target_model,
            batch_next_state,
            exp_batch["next_recurrent_state"],
        )
    else:
        target_next_tau2av = self.target_model(batch_next_state)
    greedy_actions = target_next_tau2av(taus_tilde).greedy_actions
    taus_prime = torch.rand(
        batch_size,
        self.quantile_thresholds_N_prime,
        device=self.device,
        dtype=torch.float,
    )
    target_next_maxz = target_next_tau2av(
        taus_prime
    ).evaluate_actions_as_quantiles(greedy_actions)

    batch_rewards = exp_batch["reward"]
    batch_terminal = exp_batch["is_state_terminal"]
    batch_discount = exp_batch["discount"]
    assert batch_rewards.shape == (batch_size,)
    assert batch_terminal.shape == (batch_size,)
    assert batch_discount.shape == (batch_size,)
    batch_rewards = batch_rewards.unsqueeze(-1)
    batch_terminal = batch_terminal.unsqueeze(-1)
    batch_discount = batch_discount.unsqueeze(-1)

    return (
        batch_rewards + batch_discount * (1.0 - batch_terminal) * target_next_maxz
    )

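# The predicted quantiles y and thresholds taus from _compute_y_and_taus and the
# target quantiles returned above are typically combined with the quantile Huber
# loss of QR-DQN/IQN. A hedged sketch of one common formulation (not necessarily
# PFRL's internal implementation):
import torch

def quantile_huber_loss_sketch(y, taus, target, kappa=1.0):
    # y, taus: (batch_size, N); target: (batch_size, N_prime), treated as constant
    delta = target[:, None, :] - y[:, :, None]   # pairwise TD errors, (batch, N, N')
    abs_delta = delta.abs()
    huber = torch.where(
        abs_delta <= kappa,
        0.5 * delta ** 2,
        kappa * (abs_delta - 0.5 * kappa),
    )
    # Asymmetric weight |tau - 1{delta < 0}| implements quantile regression
    weight = (taus[:, :, None] - (delta.detach() < 0).float()).abs()
    return (weight * huber / kappa).mean(dim=2).sum(dim=1).mean()
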
def _compute_target_values(self, exp_batch):
    batch_next_state = exp_batch["next_state"]

    if self.recurrent:
        target_next_qout, _ = pack_and_forward(
            self.target_model,
            batch_next_state,
            exp_batch["next_recurrent_state"],
        )
    else:
        target_next_qout = self.target_model(batch_next_state)
    next_q_expect = self._l_operator(target_next_qout)

    batch_rewards = exp_batch["reward"]
    batch_terminal = exp_batch["is_state_terminal"]

    return (
        batch_rewards
        + exp_batch["discount"] * (1 - batch_terminal) * next_q_expect
    )

def _compute_target_values(self, exp_batch):
    batch_next_state = exp_batch["next_state"]

    if self.recurrent:
        target_next_qout, _ = pack_and_forward(
            self.target_model,
            batch_next_state,
            exp_batch["next_recurrent_state"],
        )
    else:
        target_next_qout = self.target_model(batch_next_state)
    next_q_max = target_next_qout.max

    batch_rewards = exp_batch["reward"]
    batch_terminal = exp_batch["is_state_terminal"]
    discount = exp_batch["discount"]

    return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max

def _compute_y_and_t(self, exp_batch):
    batch_size = exp_batch["reward"].shape[0]

    # Compute Q-values for current states
    batch_state = exp_batch["state"]

    if self.recurrent:
        qout, _ = pack_and_forward(
            self.model, batch_state, exp_batch["recurrent_state"]
        )
    else:
        qout = self.model(batch_state)

    batch_actions = exp_batch["action"]
    batch_q = torch.reshape(qout.evaluate_actions(batch_actions), (batch_size, 1))

    with torch.no_grad():
        batch_q_target = torch.reshape(
            self._compute_target_values(exp_batch), (batch_size, 1)
        )

    return batch_q, batch_q_target

def evaluate_current_policy():
    distrib, _ = pack_and_forward(self.policy, seqs_states, policy_rs)
    return distrib

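# pack_and_forward itself is not shown in these snippets. It is assumed to pack the
# per-episode state tensors with torch.nn.utils.rnn.pack_sequence, run the recurrent
# module once over the packed batch, and return a flattened, time-major output
# together with the updated recurrent state. A hedged sketch of that assumed
# behaviour for a plain nn.LSTM-style module (not PFRL's actual implementation,
# which also handles structured outputs and recurrent-state containers):
import torch
from torch.nn.utils.rnn import pack_sequence

def pack_and_forward_sketch(rnn, sequences, recurrent_state):
    # sequences: list of (T_i, obs_size) tensors, sorted by length in descending order
    packed = pack_sequence(sequences)
    packed_out, new_recurrent_state = rnn(packed, recurrent_state)
    # packed_out.data is time-major, matching flatten_sequences_time_first ordering
    return packed_out.data, new_recurrent_state
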
def _update_policy_recurrent(self, dataset):
    """Update the policy using a given dataset.

    The policy is updated via CG and line search.
    """
    # Sort episodes desc by length for pack_sequence
    dataset = sorted(dataset, key=len, reverse=True)

    flat_transitions = flatten_sequences_time_first(dataset)

    # Prepare data for a recurrent model
    seqs_states = []
    for ep in dataset:
        states = self.batch_states(
            [transition["state"] for transition in ep],
            self.device,
            self.phi,
        )
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)
        seqs_states.append(states)

    flat_actions = torch.as_tensor(
        [transition["action"] for transition in flat_transitions],
        device=self.device,
    )
    flat_advs = torch.as_tensor(
        [transition["adv"] for transition in flat_transitions],
        device=self.device,
        dtype=torch.float,
    )

    if self.standardize_advantages:
        std_advs, mean_advs = torch.std_mean(flat_advs, unbiased=False)
        flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)

    with torch.no_grad():
        policy_rs = concatenate_recurrent_states(
            _collect_first_recurrent_states_of_policy(dataset)
        )

    flat_distribs, _ = pack_and_forward(self.policy, seqs_states, policy_rs)

    log_prob_old = torch.tensor(
        [transition["log_prob"] for transition in flat_transitions],
        device=self.device,
        dtype=torch.float,
    )

    gain = self._compute_gain(
        log_prob=flat_distribs.log_prob(flat_actions),
        log_prob_old=log_prob_old,
        entropy=flat_distribs.entropy(),
        advs=flat_advs,
    )

    # Distribution to compute KL div against
    with torch.no_grad():
        # torch.distributions.Distribution cannot be deepcopied
        action_distrib_old, _ = pack_and_forward(
            self.policy, seqs_states, policy_rs
        )

    full_step = self._compute_kl_constrained_step(
        action_distrib=flat_distribs,
        action_distrib_old=action_distrib_old,
        gain=gain,
    )

    self._line_search(
        full_step=full_step,
        dataset=dataset,
        advs=flat_advs,
        action_distrib_old=action_distrib_old,
        gain=gain,
    )

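# _compute_kl_constrained_step is not shown here. In TRPO it typically amounts to
# solving F x = g with conjugate gradient, where F is the Fisher information of the
# old policy and g is the gradient of the surrogate gain, then rescaling x so that
# the quadratic KL estimate 0.5 * x^T F x equals the trust-region radius max_kl.
# A hedged sketch of that rescaling; conjugate_gradient and the Fisher-vector-product
# callable fvp are hypothetical helpers, not part of the code above:
import torch

def kl_constrained_step_sketch(fvp, gain_grad, max_kl, conjugate_gradient):
    x = conjugate_gradient(fvp, gain_grad)        # x ~= F^{-1} g
    quad = torch.dot(x, fvp(x))                   # x^T F x
    scale = torch.sqrt(2.0 * max_kl / (quad + 1e-8))
    return scale * x
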
def update(self, statevar):
    assert self.t_start < self.t

    n = self.t - self.t_start

    self.assert_shared_memory()

    if statevar is None:
        R = 0
    else:
        with torch.no_grad(), pfrl.utils.evaluating(self.model):
            if self.recurrent:
                (_, vout), _ = one_step_forward(
                    self.model, statevar, self.train_recurrent_states
                )
            else:
                _, vout = self.model(statevar)
        R = float(vout)

    pi_loss_factor = self.pi_loss_coef
    v_loss_factor = self.v_loss_coef

    # Normalize the loss of sequences truncated by terminal states
    if self.keep_loss_scale_same and self.t - self.t_start < self.t_max:
        factor = self.t_max / (self.t - self.t_start)
        pi_loss_factor *= factor
        v_loss_factor *= factor

    if self.normalize_grad_by_t_max:
        pi_loss_factor /= self.t - self.t_start
        v_loss_factor /= self.t - self.t_start

    # Batch re-compute for efficient backprop
    batch_obs = self.batch_states(
        [self.past_obs[i] for i in range(self.t_start, self.t)],
        self.device,
        self.phi,
    )
    if self.recurrent:
        (batch_distrib, batch_v), _ = pack_and_forward(
            self.model,
            [batch_obs],
            self.past_recurrent_state[self.t_start],
        )
    else:
        batch_distrib, batch_v = self.model(batch_obs)
    batch_action = torch.stack(
        [self.past_action[i] for i in range(self.t_start, self.t)]
    )
    batch_log_prob = batch_distrib.log_prob(batch_action)
    batch_entropy = batch_distrib.entropy()
    rev_returns = []
    for i in reversed(range(self.t_start, self.t)):
        R *= self.gamma
        R += self.past_rewards[i]
        rev_returns.append(R)
    batch_return = torch.as_tensor(list(reversed(rev_returns)), dtype=torch.float)
    batch_adv = batch_return - batch_v.detach().squeeze(-1)
    assert batch_log_prob.shape == (n,)
    assert batch_adv.shape == (n,)
    assert batch_entropy.shape == (n,)
    pi_loss = torch.sum(
        -batch_adv * batch_log_prob - self.beta * batch_entropy, dim=0
    )
    assert batch_v.shape == (n, 1)
    assert batch_return.shape == (n,)
    v_loss = F.mse_loss(batch_v, batch_return[..., None], reduction="sum") / 2

    if pi_loss_factor != 1.0:
        pi_loss *= pi_loss_factor

    if v_loss_factor != 1.0:
        v_loss *= v_loss_factor

    if self.process_idx == 0:
        logger.debug("pi_loss:%s v_loss:%s", pi_loss, v_loss)

    total_loss = torch.squeeze(pi_loss) + torch.squeeze(v_loss)

    # Compute gradients using thread-specific model
    self.model.zero_grad()
    total_loss.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
    # Copy the gradients to the globally shared model
    copy_param.copy_grad(target_link=self.shared_model, source_link=self.model)
    # Update the globally shared model
    self.optimizer.step()
    if self.process_idx == 0:
        logger.debug("update")

    self.sync_parameters()

    self.past_obs = {}
    self.past_action = {}
    self.past_rewards = {}
    self.past_recurrent_state = {}

    self.t_start = self.t

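# The reversed loop above accumulates discounted n-step returns
#     R_i = r_i + gamma * R_{i+1},
# bootstrapped from the value estimate of the last state (or 0 after a terminal
# state). A standalone sketch of the same recursion (function name is illustrative):
def discounted_returns(rewards, gamma, bootstrap_value):
    R = bootstrap_value
    returns = []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.append(R)
    return list(reversed(returns))
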
def _update_once_recurrent(self, episodes, mean_advs, std_advs):

    assert std_advs is None or std_advs > 0

    device = self.device

    # Sort desc by lengths so that pack_sequence does not change the order
    episodes = sorted(episodes, key=len, reverse=True)

    flat_transitions = flatten_sequences_time_first(episodes)

    # Prepare data for a recurrent model
    seqs_states = []
    for ep in episodes:
        states = self.batch_states(
            [transition["state"] for transition in ep],
            self.device,
            self.phi,
        )
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)
        seqs_states.append(states)

    flat_actions = torch.tensor(
        [transition["action"] for transition in flat_transitions],
        device=device,
    )
    flat_advs = torch.tensor(
        [transition["adv"] for transition in flat_transitions],
        dtype=torch.float,
        device=device,
    )

    if self.standardize_advantages:
        flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)

    flat_log_probs_old = torch.tensor(
        [transition["log_prob"] for transition in flat_transitions],
        dtype=torch.float,
        device=device,
    )
    flat_vs_pred_old = torch.tensor(
        [[transition["v_pred"]] for transition in flat_transitions],
        dtype=torch.float,
        device=device,
    )
    flat_vs_teacher = torch.tensor(
        [[transition["v_teacher"]] for transition in flat_transitions],
        dtype=torch.float,
        device=device,
    )

    with torch.no_grad(), pfrl.utils.evaluating(self.model):
        rs = concatenate_recurrent_states(
            [ep[0]["recurrent_state"] for ep in episodes]
        )

    (flat_distribs, flat_vs_pred), _ = pack_and_forward(
        self.model, seqs_states, rs
    )

    flat_log_probs = flat_distribs.log_prob(flat_actions)
    flat_entropy = flat_distribs.entropy()

    self.model.zero_grad()
    loss = self._lossfun(
        entropy=flat_entropy,
        vs_pred=flat_vs_pred,
        log_probs=flat_log_probs,
        vs_pred_old=flat_vs_pred_old,
        log_probs_old=flat_log_probs_old,
        advs=flat_advs,
        vs_teacher=flat_vs_teacher,
    )
    loss.backward()
    if self.max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
    self.optimizer.step()
    self.n_updates += 1

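# self._lossfun is not shown here. A common formulation of the PPO clipped surrogate
# objective it is expected to resemble is sketched below; clip_eps, value_func_coef,
# and entropy_coef are illustrative hyperparameters, and PFRL's actual loss may in
# addition clip the value predictions against vs_pred_old:
import torch
import torch.nn.functional as F

def ppo_loss_sketch(log_probs, log_probs_old, advs, vs_pred, vs_teacher, entropy,
                    clip_eps=0.2, value_func_coef=0.5, entropy_coef=0.01):
    ratio = torch.exp(log_probs - log_probs_old)   # pi_new(a|s) / pi_old(a|s)
    surrogate = torch.min(
        ratio * advs,
        torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advs,
    )
    policy_loss = -surrogate.mean()
    value_loss = F.mse_loss(vs_pred, vs_teacher)
    return policy_loss + value_func_coef * value_loss - entropy_coef * entropy.mean()
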