import logging
import os
import time
from itertools import count

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# ReplayMemory, ValueNetwork and preprocess_state are assumed to be provided
# by the surrounding project.


class DQNAgent(object):

    def __init__(self, env, args, work_dir):
        self.env = env
        self.args = args
        self.work_dir = work_dir
        self.n_action = self.env.action_space.n
        self.arr_actions = np.arange(self.n_action)
        self.memory = ReplayMemory(self.args.buffer_size, self.args.device)
        self.qNetwork = ValueNetwork(self.n_action, self.env).to(self.args.device)
        self.targetNetwork = ValueNetwork(self.n_action, self.env).to(self.args.device)
        self.qNetwork.train()
        self.targetNetwork.eval()
        self.optimizer = optim.RMSprop(
            self.qNetwork.parameters(), lr=0.00025, eps=0.001, alpha=0.95)
        self.crit = nn.MSELoss()
        self.eps = max(self.args.eps, self.args.eps_min)
        self.eps_delta = (self.eps - self.args.eps_min) / self.args.exploration_decay_speed

    def reset(self):
        # Stack the initial frame four times to build the first state.
        return torch.cat([preprocess_state(self.env.reset(), self.env)] * 4, 1)

    def select_action(self, state):
        # Epsilon-greedy action selection over the online Q-network.
        action_prob = np.zeros(self.n_action, np.float32)
        action_prob.fill(self.eps / self.n_action)
        max_q, max_q_index = self.qNetwork(
            state.to(self.args.device)).data.cpu().max(1)
        action_prob[max_q_index[0]] += 1 - self.eps
        action = np.random.choice(self.arr_actions, p=action_prob)
        next_state, reward, done, _ = self.env.step(action)
        # Drop the oldest of the four stacked frames and append the newest one.
        next_state = torch.cat(
            [state.narrow(1, 1, 3), preprocess_state(next_state, self.env)], 1)
        self.memory.push(
            (state, torch.LongTensor([int(action)]), torch.Tensor([reward]),
             next_state, torch.Tensor([done])))
        return next_state, reward, done, max_q[0]

    def run(self):
        state = self.reset()
        # Fill the replay buffer with initial experience before learning starts.
        for _ in range(self.args.buffer_init_size):
            next_state, _, done, _ = self.select_action(state)
            state = self.reset() if done else next_state

        total_frame = 0
        reward_list = np.zeros(self.args.log_size, np.float32)
        qval_list = np.zeros(self.args.log_size, np.float32)
        start_time = time.time()
        for epi in count():
            reward_list[epi % self.args.log_size] = 0
            qval_list[epi % self.args.log_size] = -1e9
            state = self.reset()
            done = False
            ep_len = 0
            if epi % self.args.save_freq == 0:
                model_file = os.path.join(self.work_dir, 'model.th')
                with open(model_file, 'wb') as f:
                    torch.save(self.qNetwork, f)
            while not done:
                # Periodically sync the target network with the online network.
                if total_frame % self.args.sync_period == 0:
                    self.targetNetwork.load_state_dict(self.qNetwork.state_dict())
                self.eps = max(self.args.eps_min, self.eps - self.eps_delta)
                next_state, reward, done, qval = self.select_action(state)
                reward_list[epi % self.args.log_size] += reward
                qval_list[epi % self.args.log_size] = max(
                    qval_list[epi % self.args.log_size], qval)
                state = next_state
                total_frame += 1
                ep_len += 1
                if ep_len % self.args.learn_freq == 0:
                    batch_state, batch_action, batch_reward, batch_next_state, batch_done = \
                        self.memory.sample(self.args.batch_size)
                    batch_q = self.qNetwork(batch_state).gather(
                        1, batch_action.unsqueeze(1)).squeeze(1)
                    # One-step TD target computed from the frozen target network;
                    # terminal transitions contribute the reward only.
                    batch_next_q = self.targetNetwork(batch_next_state).detach().max(
                        1)[0] * self.args.gamma * (1 - batch_done)
                    loss = self.crit(batch_q, batch_reward + batch_next_q)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
            output_str = ('episode %d frame %d time %.2fs cur_rew %.3f mean_rew %.3f '
                          'cur_maxq %.3f mean_maxq %.3f' % (
                              epi, total_frame, time.time() - start_time,
                              reward_list[epi % self.args.log_size], np.mean(reward_list),
                              qval_list[epi % self.args.log_size], np.mean(qval_list)))
            print(output_str)
            logging.info(output_str)
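# A minimal usage sketch, not part of the original code: it assumes an
# Atari-style gym environment with a Discrete action space and fills in
# plausible values for every `args` field that DQNAgent reads. The env id,
# work directory and numbers below are illustrative only.
def _example_dqn_usage():
    import argparse
    import gym

    args = argparse.Namespace(
        device='cuda' if torch.cuda.is_available() else 'cpu',
        buffer_size=100000,            # replay capacity
        buffer_init_size=10000,        # frames collected before learning starts
        eps=1.0, eps_min=0.1,          # epsilon-greedy schedule
        exploration_decay_speed=1000000,
        gamma=0.99,
        batch_size=32,
        learn_freq=4,                  # one gradient step every 4 frames
        sync_period=10000,             # target-network sync interval (frames)
        log_size=100,                  # episodes averaged in the log line
        save_freq=50,                  # save the model every 50 episodes
    )
    work_dir = './runs/dqn_pong'
    os.makedirs(work_dir, exist_ok=True)
    env = gym.make('PongNoFrameskip-v4')
    DQNAgent(env, args, work_dir).run()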
# Actor, Critic, ReplayMemory, Transition, to_tensor_unsqueeze and
# update_params are assumed to be provided by the surrounding project.


class SAC:

    def __init__(self, env, lr=3e-4, gamma=0.99, polyak=5e-3, alpha=0.2,
                 reward_scale=1.0, cuda=True, writer=None):
        state_size = env.observation_space.shape[0]
        action_size = env.action_space.shape[0]
        self.actor = Actor(state_size, action_size)
        self.critic = Critic(state_size, action_size)
        self.target_critic = Critic(state_size, action_size).eval()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.q1_optimizer = optim.Adam(self.critic.q1.parameters(), lr=lr)
        self.q2_optimizer = optim.Adam(self.critic.q2.parameters(), lr=lr)
        self.target_critic.load_state_dict(self.critic.state_dict())
        for param in self.target_critic.parameters():
            param.requires_grad = False
        self.memory = ReplayMemory()
        self.gamma = gamma
        self.alpha = alpha
        # Mixing coefficient in (0, 1); with the update rule in target_update it
        # is usually small (e.g. 5e-3).
        self.polyak = polyak
        self.reward_scale = reward_scale
        self.writer = writer
        self.cuda = cuda
        if cuda:
            self.actor = self.actor.to('cuda')
            self.critic = self.critic.to('cuda')
            self.target_critic = self.target_critic.to('cuda')

    def explore(self, state):
        # Sample a stochastic action for data collection.
        state = torch.tensor(state, dtype=torch.float).unsqueeze(0)
        if self.cuda:
            state = state.to('cuda')
        action, _, _ = self.actor.sample(state)
        return action.cpu().detach().numpy().reshape(-1)

    def exploit(self, state):
        # Use the deterministic (mean) action for evaluation.
        state = torch.tensor(state, dtype=torch.float).unsqueeze(0)
        if self.cuda:
            state = state.to('cuda')
        _, _, action = self.actor.sample(state)
        return action.cpu().detach().numpy().reshape(-1)

    def store_step(self, state, action, next_state, reward, terminal):
        state = to_tensor_unsqueeze(state)
        if action.dtype == np.float32:
            action = torch.from_numpy(action)
        next_state = to_tensor_unsqueeze(next_state)
        reward = torch.from_numpy(np.array([reward], dtype=np.float32))
        terminal = torch.from_numpy(np.array([terminal], dtype=np.uint8))
        self.memory.push(state, action, next_state, reward, terminal)

    def target_update(self, target_net, net):
        # Polyak averaging of the target network towards the online network:
        # t <- (1 - polyak) * t + polyak * s
        for t, s in zip(target_net.parameters(), net.parameters()):
            t.data.mul_(1.0 - self.polyak)
            t.data.add_(self.polyak * s.data)

    def calc_target_q(self, next_states, rewards, terminals):
        with torch.no_grad():
            # The second value returned by actor.sample (named `entropy` here)
            # is used like a log-probability: subtracting alpha * entropy from
            # the minimum target Q gives the usual soft Q target.
            next_action, entropy, _ = self.actor.sample(next_states)
            next_q1, next_q2 = self.target_critic(next_states, next_action)
            next_q = torch.min(next_q1, next_q2) - self.alpha * entropy
            target_q = rewards * self.reward_scale + (1. - terminals) * self.gamma * next_q
        return target_q

    def calc_critic_loss(self, states, actions, next_states, rewards, terminals):
        q1, q2 = self.critic(states, actions)
        target_q = self.calc_target_q(next_states, rewards, terminals)
        q1_loss = torch.mean((q1 - target_q).pow(2))
        q2_loss = torch.mean((q2 - target_q).pow(2))
        return q1_loss, q2_loss

    def calc_actor_loss(self, states):
        action, entropy, _ = self.actor.sample(states)
        q1, q2 = self.critic(states, action)
        q = torch.min(q1, q2)
        actor_loss = (self.alpha * entropy - q).mean()
        return actor_loss, entropy

    def train(self, timestep, batch_size=256):
        if len(self.memory) < batch_size:
            return
        transitions = self.memory.sample(batch_size)
        transitions = Transition(*zip(*transitions))
        if self.cuda:
            states = torch.cat(transitions.state).to('cuda')
            actions = torch.stack(transitions.action).to('cuda')
            next_states = torch.cat(transitions.next_state).to('cuda')
            rewards = torch.stack(transitions.reward).to('cuda')
            terminals = torch.stack(transitions.terminal).to('cuda')
        else:
            states = torch.cat(transitions.state)
            actions = torch.stack(transitions.action)
            next_states = torch.cat(transitions.next_state)
            rewards = torch.stack(transitions.reward)
            terminals = torch.stack(transitions.terminal)

        # Critic losses against the soft Q target, then the actor loss.
        q1_loss, q2_loss = self.calc_critic_loss(states, actions, next_states,
                                                 rewards, terminals)
        actor_loss, _ = self.calc_actor_loss(states)

        update_params(self.q1_optimizer, self.critic.q1, q1_loss)
        update_params(self.q2_optimizer, self.critic.q2, q2_loss)
        update_params(self.actor_optimizer, self.actor, actor_loss)

        # Soft update of the target critic.
        self.target_update(self.target_critic, self.critic)

        # Log every 100 timesteps.
        if timestep % 100 == 0 and self.writer:
            self.writer.add_scalar('Loss/Actor', actor_loss.item(), timestep)
            self.writer.add_scalar('Loss/Critic', q1_loss.item(), timestep)

    def save_weights(self, path):
        self.actor.save(os.path.join(path, 'actor'))
        self.critic.save(os.path.join(path, 'critic'))
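# The SAC class above relies on an `update_params` helper and a Transition-based
# replay memory that are not shown in this excerpt. The sketch below is one
# plausible way they could look, not the original implementation: update_params
# does a plain zero_grad / backward / step (the `net` argument is kept only in
# case gradient clipping is wanted), and the usage function shows how explore,
# store_step and train fit together, assuming the classic 4-tuple gym step API
# used elsewhere in this code. The env id and step counts are illustrative.
import random
from collections import deque, namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'terminal'))


def update_params(optimizer, net, loss):
    optimizer.zero_grad()
    loss.backward()
    # Optional: torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=10.0)
    optimizer.step()


class SimpleReplayMemory:
    def __init__(self, capacity=1000000):
        self.buffer = deque(maxlen=capacity)

    def push(self, *args):
        self.buffer.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)


def _example_sac_usage():
    import gym
    env = gym.make('Pendulum-v1')
    agent = SAC(env, cuda=torch.cuda.is_available())
    state, timestep = env.reset(), 0
    for _ in range(100000):
        action = agent.explore(state)
        next_state, reward, done, _ = env.step(action)
        agent.store_step(state, action, next_state, reward, done)
        agent.train(timestep)
        state = env.reset() if done else next_state
        timestep += 1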