import torch
import numpy.random as rd
from copy import deepcopy

# AgentDQN, QNetTwin and soft_target_update are assumed to be defined elsewhere
# in this repo (the base agent and network modules).


class AgentDoubleDQN(AgentDQN):
    def __init__(self, net_dim, state_dim, action_dim, learning_rate=1e-4):
        super().__init__(net_dim, state_dim, action_dim, learning_rate)
        self.explore_rate = 0.25  # epsilon-greedy: the rate of choosing a random action
        self.softmax = torch.nn.Softmax(dim=1)
        self.action_dim = action_dim

        self.act = QNetTwin(net_dim, state_dim, action_dim).to(self.device)
        self.act_target = deepcopy(self.act)
        self.criterion = torch.nn.SmoothL1Loss()
        self.optimizer = torch.optim.Adam(self.act.parameters(), lr=learning_rate)

    def select_actions(self, states):  # for discrete action spaces
        states = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions = self.act(states)
        if rd.rand() < self.explore_rate:  # epsilon-greedy: explore
            # sample each action from the softmax distribution over its Q values
            a_prob_l = self.softmax(actions).detach().cpu().numpy()
            a_int = [rd.choice(self.action_dim, p=a_prob) for a_prob in a_prob_l]
        else:  # exploit: take the arg-max of the Q values
            a_int = actions.argmax(dim=1).detach().cpu().numpy()
        return a_int

    def update_policy(self, buffer, max_step, batch_size, repeat_times):
        """Contribution of DDQN (Double DQN)

        1. Twin Q-Network. Use min(q1, q2) to reduce over-estimation.
        """
        buffer.update__now_len__before_sample()

        next_q = obj_critic = None
        for _ in range(int(max_step * repeat_times)):
            with torch.no_grad():
                reward, mask, action, state, next_s = buffer.random_sample(batch_size)
                # bootstrap target: best next-state Q value from the target network
                next_q = self.act_target(next_s).max(dim=1, keepdim=True)[0]
                # mask acts as the discount factor and zeroes out terminal transitions
                q_label = reward + mask * next_q

            action = action.type(torch.long)
            # evaluate both Q heads at the actions actually taken
            q_eval1, q_eval2 = [qs.gather(1, action) for qs in self.act.get__q1_q2(state)]
            obj_critic = self.criterion(q_eval1, q_label) + self.criterion(q_eval2, q_label)

            self.optimizer.zero_grad()
            obj_critic.backward()
            self.optimizer.step()
            soft_target_update(self.act_target, self.act)
        # report the mean bootstrap Q value and the average per-head critic loss
        return next_q.mean().item(), obj_critic.item() / 2
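The agent above only touches QNetTwin through its constructor, its forward pass (in select_actions and the target bootstrap), and get__q1_q2 (in the twin critic loss). Below is a minimal sketch of such a twin Q-network, assuming a shared MLP trunk with two Q-value heads. The layer sizes, and the choice of having forward return the element-wise minimum of the two heads (which is what would make the bootstrap target in update_policy realize the docstring's "min(q1, q2) to reduce over-estimation"), are assumptions for illustration, not the repo's actual network module.

import torch
import torch.nn as nn


class QNetTwin(nn.Module):
    """Sketch of a twin Q-network: a shared state trunk feeding two Q heads.

    Only the constructor signature, forward, and get__q1_q2 are taken from the
    agent code above; the layer layout is an illustrative assumption.
    """

    def __init__(self, mid_dim, state_dim, action_dim):
        super().__init__()
        self.net_state = nn.Sequential(nn.Linear(state_dim, mid_dim), nn.ReLU(),
                                       nn.Linear(mid_dim, mid_dim), nn.ReLU())
        self.net_q1 = nn.Linear(mid_dim, action_dim)  # first Q head
        self.net_q2 = nn.Linear(mid_dim, action_dim)  # second Q head

    def forward(self, state):
        # return the element-wise minimum of the two heads (assumed here so that
        # self.act_target(next_s) in update_policy bootstraps from min(q1, q2))
        q1, q2 = self.get__q1_q2(state)
        return torch.min(q1, q2)

    def get__q1_q2(self, state):
        tmp = self.net_state(state)
        return self.net_q1(tmp), self.net_q2(tmp)  # both heads, for the twin critic loss

Taking the minimum of the two heads as the bootstrap value is the clipped-double-Q idea: since each head over-estimates somewhat independently, the minimum biases the target downward and counteracts the over-estimation introduced by the max operator in the Q-learning target.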