def batch_act(self, batch_obs):
    with torch.no_grad(), evaluating(self.model):
        batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs)
        batch_argmax = batch_av.greedy_actions.cpu().numpy()
    if self.training:
        batch_action = [
            self.explorer.select_action(
                self.t,
                lambda: batch_argmax[i],
                action_value=batch_av[i:i + 1],
            )
            for i in range(len(batch_obs))
        ]
        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)
    else:
        # stochastic
        batch_action = [
            self.explorer.select_action(
                self.t,
                lambda: batch_argmax[i],
                action_value=batch_av[i:i + 1],
            )
            for i in range(len(batch_obs))
        ]
        # deterministic
        # batch_action = batch_argmax
    return batch_action
def batch_act(self, batch_obs: Sequence[Any]) -> Sequence[Any]:
    with torch.no_grad(), evaluating(self.model):
        batch_av, self.batch_h = self._evaluate_model_and_update_recurrent_states(batch_obs)
        batch_argmax = batch_av.greedy_actions.detach().cpu().numpy()
    if self.training:
        batch_action = [
            self.explorer.select_action(
                self.t,
                lambda: batch_argmax[i],
                action_value=batch_av[i:i + 1],
            )
            for i in range(len(batch_obs))
        ]
        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)
    else:
        batch_action = batch_argmax
    return batch_action
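Both batch_act variants above map a batch of observations to a batch of actions; they differ only in whether evaluation-time actions come from the explorer (stochastic) or from the greedy argmax, and in whether a recurrent state self.batch_h is tracked. The loop below is a minimal sketch of how such a method is typically driven; agent, envs, and max_steps are hypothetical names, and the environments are assumed to follow the Gym reset/step interface.

    # Hypothetical batched acting loop (not part of the agent code above).
    obss = [env.reset() for env in envs]
    for _ in range(max_steps):
        actions = agent.batch_act(obss)  # one action per environment
        steps = [env.step(a) for env, a in zip(envs, actions)]
        obss = [obs for obs, reward, done, info in steps]
        # episode resets on done are omitted for brevity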
def _trajectory_centric_planning(self, trajectories):
    state_shape = tuple(trajectories[0]["state"].shape)[1:]  # torch.Size -> tuple
    # Pad every trajectory to n_trj_step so all states can be evaluated
    # in a single forward pass on the GPU.
    # For Atari, the empty batch starts with shape (0, 4, 84, 84).
    batch_states = torch.empty((0, ) + state_shape, dtype=torch.float32)
    for trajectory in trajectories:
        bs = torch.empty((self.n_trj_step, ) + state_shape, dtype=torch.float32)
        bs[:len(trajectory["state"])] = trajectory["state"]
        # torch equivalent of numpy.vstack
        batch_states = torch.cat((batch_states, bs), dim=0)
    batch_states = batch_states.to(self.device)

    with torch.no_grad(), evaluating(self.model):
        batch_q, _ = self.model(batch_states)
    q_theta_arr = batch_q.q_values.cpu()
    q_theta_arr = q_theta_arr.reshape(
        (len(trajectories), self.n_trj_step, self.n_actions))

    q_np_arr = torch.empty((0, self.n_actions), dtype=torch.float32)
    for q_np, trajectory in zip(q_theta_arr, trajectories):
        # batch_state = trajectory['state']
        batch_action = trajectory['action']
        batch_reward = trajectory['reward']
        q_np = q_np[:len(batch_action)]
        for t in range(len(batch_action) - 2, -1, -1):  # t = T-2, ..., 0
            V_np = torch.max(q_np[t + 1])  # V_NP(s_t+1) := max_a Q(s_t+1, a)
            q_np[t, batch_action[t]] = batch_reward[t] + self.gamma * V_np
        q_np_arr = torch.cat((q_np_arr, q_np.reshape(-1, self.n_actions)), dim=0)
    return q_np_arr.to(self.device)
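The routine above performs a backward Bellman backup along each stored trajectory: starting at t = T-2, it overwrites Q(s_t, a_t) with r_t + gamma * max_a Q(s_{t+1}, a), where the values at later steps come from the current network. Each trajectory is a dict with "state", "action", and "reward" entries of common length T, with T no larger than n_trj_step. The snippet below is a minimal sketch of that input layout; agent is a hypothetical instance whose n_trj_step, n_actions, gamma, model, and device are already configured, and the Atari-style state shape is only an example.

    # Hypothetical call showing the expected trajectory layout.
    trajectory = {
        # states s_0..s_2, shape (T,) + state_shape, here Atari-style (4, 84, 84)
        "state": torch.zeros((3, 4, 84, 84), dtype=torch.float32),
        "action": [1, 0, 2],        # a_0..a_2 (discrete action indices)
        "reward": [0.0, 1.0, 0.0],  # r_0..r_2
    }
    q_np = agent._trajectory_centric_planning([trajectory])
    # q_np has shape (3, n_actions): one row of backed-up Q-values per step.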
def _batch_select_greedy_actions(self, batch_obs):
    with torch.no_grad(), evaluating(self.policy):
        batch_xs = self.batch_states(batch_obs, self.device, self.phi)
        batch_action = self.policy(batch_xs).sample()
        return batch_action.cpu().numpy()