def _batch_act_train(self, batch_obs):
    assert self.training
    b_state = self.batch_states(batch_obs, self.device, self.phi)

    if self.obs_normalizer:
        b_state = self.obs_normalizer(b_state, update=False)

    num_envs = len(batch_obs)
    if self.batch_last_episode is None:
        self._initialize_batch_variables(num_envs)
    assert len(self.batch_last_episode) == num_envs
    assert len(self.batch_last_state) == num_envs
    assert len(self.batch_last_action) == num_envs

    # action_distrib will be recomputed when computing gradients
    with torch.no_grad(), pfrl.utils.evaluating(self.model):
        if self.recurrent:
            assert self.train_prev_recurrent_states is None
            self.train_prev_recurrent_states = self.train_recurrent_states
            (
                (action_distrib, batch_value),
                self.train_recurrent_states,
            ) = one_step_forward(
                self.model, b_state, self.train_prev_recurrent_states
            )
        else:
            action_distrib, batch_value = self.model(b_state)
        batch_action = action_distrib.sample().cpu().numpy()
        self.entropy_record.extend(action_distrib.entropy().cpu().numpy())
        self.value_record.extend(batch_value.cpu().numpy())

    self.batch_last_state = list(batch_obs)
    self.batch_last_action = list(batch_action)

    return batch_action

def _act_train(self, obs):
    batch_obs = self.batch_states([obs], self.device, self.phi)
    if self.recurrent:
        action_distrib, self.train_recurrent_states = one_step_forward(
            self.model, batch_obs, self.train_recurrent_states
        )
    else:
        action_distrib = self.model(batch_obs)
    batch_action = action_distrib.sample()

    # Save values used to compute losses
    self.log_prob_sequences[-1].append(action_distrib.log_prob(batch_action))
    self.entropy_sequences[-1].append(action_distrib.entropy())

    action = batch_action.cpu().numpy()[0]

    self.logger.debug("t:%s a:%s", self.t, action)

    # Update stats
    self.average_entropy += (1 - self.average_entropy_decay) * (
        float(action_distrib.entropy()) - self.average_entropy
    )

    return action

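# Hedged illustration (not the library's exact update rule) of how the log-prob and
# entropy tensors saved in _act_train above feed a REINFORCE-style loss once the
# episode's rewards are known. The helper name, the return-to-go computation, and the
# entropy coefficient `beta` are assumptions made for this sketch only.
def reinforce_episode_loss(log_probs, entropies, rewards, gamma=0.99, beta=0.01):
    # Discounted return-to-go G_t = r_t + gamma * G_{t+1} over one episode.
    returns, G = [], 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    # Policy-gradient surrogate: -log pi(a_t|s_t) * G_t, minus an entropy bonus.
    loss = 0.0
    for log_prob, entropy, G_t in zip(log_probs, entropies, returns):
        loss = loss - log_prob * G_t - beta * entropy
    return loss
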
def update_on_policy(self, statevar):
    assert self.t_start < self.t

    if not self.disable_online_update:
        if statevar is None:
            R = 0
        else:
            with torch.no_grad():
                if self.recurrent:
                    (_, _, v), _ = one_step_forward(
                        self.model, statevar, self.train_recurrent_states
                    )
                else:
                    _, _, v = self.model(statevar)
            R = float(v)
        self.update(
            t_start=self.t_start,
            t_stop=self.t,
            R=R,
            actions=self.past_actions,
            rewards=self.past_rewards,
            values=self.past_values,
            action_values=self.past_action_values,
            action_distribs=self.past_action_distrib,
            action_distribs_mu=None,
            avg_action_distribs=self.past_avg_action_distrib,
        )

    self.init_history_data_for_online_update()
    self.train_recurrent_states = detach_recurrent_state(
        self.train_recurrent_states
    )

def _act_train(self, obs):
    statevar = batch_states([obs], self.device, self.phi)

    if self.recurrent:
        (
            (action_distrib, action_value, v),
            self.train_recurrent_states,
        ) = one_step_forward(self.model, statevar, self.train_recurrent_states)
    else:
        action_distrib, action_value, v = self.model(statevar)
    self.past_action_values[self.t] = action_value
    action = action_distrib.sample()[0]

    # Save values for a later update
    self.past_values[self.t] = v
    self.past_action_distrib[self.t] = action_distrib
    with torch.no_grad():
        if self.recurrent:
            (
                (avg_action_distrib, _, _),
                self.shared_recurrent_states,
            ) = one_step_forward(
                self.shared_average_model,
                statevar,
                self.shared_recurrent_states,
            )
        else:
            avg_action_distrib, _, _ = self.shared_average_model(statevar)
    self.past_avg_action_distrib[self.t] = avg_action_distrib

    self.past_actions[self.t] = action

    # Update stats
    self.average_value += (1 - self.average_value_decay) * (
        float(v) - self.average_value
    )
    self.average_entropy += (1 - self.average_entropy_decay) * (
        float(action_distrib.entropy()) - self.average_entropy
    )

    self.last_state = obs
    self.last_action = action.numpy()
    self.last_action_distrib = deepcopy_distribution(action_distrib)

    return self.last_action

def _evaluate_model_and_update_recurrent_states(
    self, batch_obs: Sequence[Any]
) -> ActionValue:
    batch_xs = self.batch_states(batch_obs, self.device, self.phi)
    if self.recurrent:
        if self.training:
            self.train_prev_recurrent_states = self.train_recurrent_states
            batch_av, self.train_recurrent_states = one_step_forward(
                self.model, batch_xs, self.train_recurrent_states
            )
        else:
            batch_av, self.test_recurrent_states = one_step_forward(
                self.model, batch_xs, self.test_recurrent_states
            )
    else:
        batch_av = self.model(batch_xs)
    return batch_av

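# Hedged sketch of how the ActionValue returned above is typically consumed when
# acting greedily: `greedy_actions` picks the argmax-Q action per environment.
# The method name `_batch_act_greedy` and the omission of any explorer (e.g.
# epsilon-greedy) are assumptions; only the ActionValue accessor comes from PFRL.
def _batch_act_greedy(self, batch_obs):
    with torch.no_grad(), pfrl.utils.evaluating(self.model):
        batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs)
        return batch_av.greedy_actions.cpu().numpy()
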
def _act_eval(self, obs):
    # Use the process-local model for acting
    with torch.no_grad(), pfrl.utils.evaluating(self.model):
        statevar = self.batch_states([obs], self.device, self.phi)
        if self.recurrent:
            (pout, _), self.test_recurrent_states = one_step_forward(
                self.model, statevar, self.test_recurrent_states
            )
        else:
            pout, _ = self.model(statevar)
        if self.act_deterministically:
            return mode_of_distribution(pout).cpu().numpy()[0]
        else:
            return pout.sample().cpu().numpy()[0]

def _act_eval(self, obs):
    with torch.no_grad():
        batch_obs = self.batch_states([obs], self.device, self.phi)
        if self.recurrent:
            action_distrib, self.test_recurrent_states = one_step_forward(
                self.model, batch_obs, self.test_recurrent_states
            )
        else:
            action_distrib = self.model(batch_obs)
        if self.act_deterministically:
            return mode_of_distribution(action_distrib).cpu().numpy()[0]
        else:
            return action_distrib.sample().cpu().numpy()[0]

def _act_eval(self, obs):
    # Use the process-local model for acting
    with torch.no_grad():
        statevar = batch_states([obs], self.device, self.phi)
        if self.recurrent:
            (action_distrib, _, _), self.test_recurrent_states = one_step_forward(
                self.model, statevar, self.test_recurrent_states
            )
        else:
            action_distrib, _, _ = self.model(statevar)
        if self.act_deterministically:
            return mode_of_distribution(action_distrib).numpy()[0]
        else:
            return action_distrib.sample().numpy()[0]

def _batch_act_eval(self, batch_obs):
    assert not self.training
    b_state = self.batch_states(batch_obs, self.device, self.phi)

    if self.obs_normalizer:
        b_state = self.obs_normalizer(b_state, update=False)

    with torch.no_grad(), pfrl.utils.evaluating(self.model):
        if self.recurrent:
            (action_distrib, _), self.test_recurrent_states = one_step_forward(
                self.model, b_state, self.test_recurrent_states
            )
        else:
            action_distrib, _ = self.model(b_state)
        if self.act_deterministically:
            action = mode_of_distribution(action_distrib).cpu().numpy()
        else:
            action = action_distrib.sample().cpu().numpy()

    return action

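# Minimal sketch of how _batch_act_train (earlier) and _batch_act_eval (above) are
# usually dispatched from a single entry point that switches on the agent's training
# flag; this `batch_act` wrapper is an assumption about the caller, not code taken
# from the excerpts themselves.
def batch_act(self, batch_obs):
    if self.training:
        return self._batch_act_train(batch_obs)
    else:
        return self._batch_act_eval(batch_obs)
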
def _act_train(self, obs):
    self.past_obs[self.t] = obs

    with torch.no_grad():
        statevar = self.batch_states([obs], self.device, self.phi)
        if self.recurrent:
            self.past_recurrent_state[self.t] = self.train_recurrent_states
            (pout, vout), self.train_recurrent_states = one_step_forward(
                self.model, statevar, self.train_recurrent_states
            )
        else:
            pout, vout = self.model(statevar)
        # Do not backprop through sampled actions
        action = pout.sample()

    self.past_action[self.t] = action[0].detach()
    action = action.cpu().numpy()[0]

    # Update stats
    self.average_value += (1 - self.average_value_decay) * (
        float(vout) - self.average_value
    )
    self.average_entropy += (1 - self.average_entropy_decay) * (
        float(pout.entropy()) - self.average_entropy
    )

    return action

def update(self, statevar):
    assert self.t_start < self.t
    n = self.t - self.t_start

    self.assert_shared_memory()

    if statevar is None:
        R = 0
    else:
        with torch.no_grad(), pfrl.utils.evaluating(self.model):
            if self.recurrent:
                (_, vout), _ = one_step_forward(
                    self.model, statevar, self.train_recurrent_states
                )
            else:
                _, vout = self.model(statevar)
        R = float(vout)

    pi_loss_factor = self.pi_loss_coef
    v_loss_factor = self.v_loss_coef

    # Normalize the loss of sequences truncated by terminal states
    if self.keep_loss_scale_same and self.t - self.t_start < self.t_max:
        factor = self.t_max / (self.t - self.t_start)
        pi_loss_factor *= factor
        v_loss_factor *= factor

    if self.normalize_grad_by_t_max:
        pi_loss_factor /= self.t - self.t_start
        v_loss_factor /= self.t - self.t_start

    # Batch re-compute for efficient backprop
    batch_obs = self.batch_states(
        [self.past_obs[i] for i in range(self.t_start, self.t)],
        self.device,
        self.phi,
    )
    if self.recurrent:
        (batch_distrib, batch_v), _ = pack_and_forward(
            self.model,
            [batch_obs],
            self.past_recurrent_state[self.t_start],
        )
    else:
        batch_distrib, batch_v = self.model(batch_obs)
    batch_action = torch.stack(
        [self.past_action[i] for i in range(self.t_start, self.t)]
    )
    batch_log_prob = batch_distrib.log_prob(batch_action)
    batch_entropy = batch_distrib.entropy()
    rev_returns = []
    for i in reversed(range(self.t_start, self.t)):
        R *= self.gamma
        R += self.past_rewards[i]
        rev_returns.append(R)
    batch_return = torch.as_tensor(
        list(reversed(rev_returns)), dtype=torch.float
    )
    batch_adv = batch_return - batch_v.detach().squeeze(-1)
    assert batch_log_prob.shape == (n,)
    assert batch_adv.shape == (n,)
    assert batch_entropy.shape == (n,)
    pi_loss = torch.sum(
        -batch_adv * batch_log_prob - self.beta * batch_entropy, dim=0
    )
    assert batch_v.shape == (n, 1)
    assert batch_return.shape == (n,)
    v_loss = F.mse_loss(batch_v, batch_return[..., None], reduction="sum") / 2

    if pi_loss_factor != 1.0:
        pi_loss *= pi_loss_factor

    if v_loss_factor != 1.0:
        v_loss *= v_loss_factor

    if self.process_idx == 0:
        logger.debug("pi_loss:%s v_loss:%s", pi_loss, v_loss)

    total_loss = torch.squeeze(pi_loss) + torch.squeeze(v_loss)

    # Compute gradients using thread-specific model
    self.model.zero_grad()
    total_loss.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
    # Copy the gradients to the globally shared model
    copy_param.copy_grad(target_link=self.shared_model, source_link=self.model)
    # Update the globally shared model
    self.optimizer.step()
    if self.process_idx == 0:
        logger.debug("update")

    self.sync_parameters()

    self.past_obs = {}
    self.past_action = {}
    self.past_rewards = {}
    self.past_recurrent_state = {}

    self.t_start = self.t

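# Standalone sketch (illustrative only) of what the reversed loop in update() above
# computes: n-step discounted returns bootstrapped from the tail estimate R, i.e.
# G_t = r_t + gamma * G_{t+1} with G_n = R.
def n_step_returns(rewards, R, gamma):
    returns = []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.append(R)
    return list(reversed(returns))

# Example: n_step_returns([1.0, 0.0, 2.0], R=0.5, gamma=0.9) -> [2.9845, 2.205, 2.45]
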
def update_from_replay(self):

    if self.replay_buffer is None:
        return

    if len(self.replay_buffer) < self.replay_start_size:
        return

    episode = self.replay_buffer.sample_episodes(1, self.t_max)[0]

    model_recurrent_state = None
    shared_recurrent_state = None
    rewards = {}
    actions = {}
    action_distribs = {}
    action_distribs_mu = {}
    avg_action_distribs = {}
    action_values = {}
    values = {}
    for t, transition in enumerate(episode):
        bs = batch_states([transition["state"]], self.device, self.phi)
        if self.recurrent:
            (
                (action_distrib, action_value, v),
                model_recurrent_state,
            ) = one_step_forward(self.model, bs, model_recurrent_state)
        else:
            action_distrib, action_value, v = self.model(bs)
        with torch.no_grad():
            if self.recurrent:
                (
                    (avg_action_distrib, _, _),
                    shared_recurrent_state,
                ) = one_step_forward(
                    self.shared_average_model,
                    bs,
                    shared_recurrent_state,
                )
            else:
                avg_action_distrib, _, _ = self.shared_average_model(bs)
        actions[t] = transition["action"]
        values[t] = v
        action_distribs[t] = action_distrib
        avg_action_distribs[t] = avg_action_distrib
        rewards[t] = transition["reward"]
        action_distribs_mu[t] = transition["mu"]
        action_values[t] = action_value
    last_transition = episode[-1]
    if last_transition["is_state_terminal"]:
        R = 0
    else:
        with torch.no_grad():
            last_s = batch_states(
                [last_transition["next_state"]], self.device, self.phi
            )
            if self.recurrent:
                (_, _, last_v), _ = one_step_forward(
                    self.model, last_s, model_recurrent_state
                )
            else:
                _, _, last_v = self.model(last_s)
        R = float(last_v)
    return self.update(
        R=R,
        t_start=0,
        t_stop=len(episode),
        rewards=rewards,
        actions=actions,
        values=values,
        action_distribs=action_distribs,
        action_distribs_mu=action_distribs_mu,
        avg_action_distribs=avg_action_distribs,
        action_values=action_values,
    )

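# Hedged sketch of the episodic replay buffer that update_from_replay() relies on:
# transitions are stored as dicts (including the extra "mu" field read back above),
# and sample_episodes() returns whole, possibly truncated, episodes. The capacity,
# the dummy transition data, and the Categorical stand-in for the behavior policy
# are assumptions made for illustration.
import numpy as np
import torch
from pfrl.replay_buffers import EpisodicReplayBuffer

replay_buffer = EpisodicReplayBuffer(capacity=10 ** 4)
replay_buffer.append(
    state=np.zeros(4, dtype=np.float32),
    action=1,
    reward=0.0,
    next_state=np.zeros(4, dtype=np.float32),
    is_state_terminal=True,
    mu=torch.distributions.Categorical(probs=torch.ones(3) / 3),
)
replay_buffer.stop_current_episode()
# A list with one episode, itself a list of transition dicts as used above.
episode = replay_buffer.sample_episodes(1, max_len=16)[0]
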
def get_and_concat_rs_forward():
    _, rs = one_step_forward(par, x_t0, None)
    rs0 = get_recurrent_state_at(rs, 0, detach=True)
    rs1 = get_recurrent_state_at(rs, 1, detach=True)
    concat_rs = concatenate_recurrent_states([rs0, rs1])
    return one_step_forward(par, x_t1, concat_rs)

def mask01_forward_twice():
    _, rs = one_step_forward(par, x_t0, None)
    rs = mask_recurrent_state_at(rs, [0, 1])
    return one_step_forward(par, x_t1, rs)

def no_mask_forward_twice():
    _, rs = one_step_forward(par, x_t0, None)
    return one_step_forward(par, x_t1, rs)

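# Hedged, self-contained sketch of the recurrent-state helpers exercised by the three
# functions above, using a small pfrl.nn.RecurrentSequential model. The layer sizes,
# batch size, and the concrete definitions of `par`, `x_t0`, and `x_t1` are
# illustrative assumptions; only the helper functions themselves come from PFRL.
import torch
from torch import nn
from pfrl.nn import RecurrentSequential
from pfrl.utils.recurrent import (
    concatenate_recurrent_states,
    get_recurrent_state_at,
    mask_recurrent_state_at,
    one_step_forward,
)

par = RecurrentSequential(nn.LSTM(input_size=3, hidden_size=4))
x_t0 = torch.rand(2, 3)  # batch of 2 observations at t = 0
x_t1 = torch.rand(2, 3)  # batch of 2 observations at t = 1

# One step of the recurrent model, threading per-environment hidden state.
y0, rs = one_step_forward(par, x_t0, None)

# Pull out the state of individual environments, recombine them, or zero selected
# entries (e.g. when an episode ends and the hidden state must be reset).
rs0 = get_recurrent_state_at(rs, 0, detach=True)
rs1 = get_recurrent_state_at(rs, 1, detach=True)
rs_concat = concatenate_recurrent_states([rs0, rs1])
rs_masked = mask_recurrent_state_at(rs, [0])

y1, _ = one_step_forward(par, x_t1, rs_concat)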