import os
import time
import logging
from itertools import count

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# ReplayMemory, ValueNetwork and preprocess_state are used below but assumed to be
# defined elsewhere in the project.


class DQNAgent(object):
    def __init__(self, env, args, work_dir):
        self.env = env
        self.args = args
        self.work_dir = work_dir
        self.n_action = self.env.action_space.n
        self.arr_actions = np.arange(self.n_action)
        self.memory = ReplayMemory(self.args.buffer_size, self.args.device)
        self.qNetwork = ValueNetwork(self.n_action, self.env).to(self.args.device)
        self.targetNetwork = ValueNetwork(self.n_action, self.env).to(self.args.device)
        self.qNetwork.train()
        self.targetNetwork.eval()
        self.optimizer = optim.RMSprop(self.qNetwork.parameters(),
                                       lr=0.00025, eps=0.001, alpha=0.95)
        self.crit = nn.MSELoss()
        # Linearly annealed epsilon-greedy exploration.
        self.eps = max(self.args.eps, self.args.eps_min)
        self.eps_delta = (self.eps - self.args.eps_min) / self.args.exploration_decay_speed

    def reset(self):
        # Initial state: the first preprocessed frame repeated 4 times along the channel dim.
        return torch.cat([preprocess_state(self.env.reset(), self.env)] * 4, 1)

    def select_action(self, state):
        # Epsilon-greedy distribution over the discrete action set.
        action_prob = np.zeros(self.n_action, np.float32)
        action_prob.fill(self.eps / self.n_action)
        max_q, max_q_index = self.qNetwork(state.to(self.args.device)).data.cpu().max(1)
        action_prob[max_q_index[0]] += 1 - self.eps
        action = np.random.choice(self.arr_actions, p=action_prob)
        next_state, reward, done, _ = self.env.step(action)
        # Drop the oldest frame and append the newest one to keep a 4-frame stack.
        next_state = torch.cat([state.narrow(1, 1, 3),
                                preprocess_state(next_state, self.env)], 1)
        self.memory.push((state, torch.LongTensor([int(action)]), torch.Tensor([reward]),
                          next_state, torch.Tensor([done])))
        return next_state, reward, done, max_q[0]

    def run(self):
        state = self.reset()
        # Fill the replay buffer before learning starts.
        for _ in range(self.args.buffer_init_size):
            next_state, _, done, _ = self.select_action(state)
            state = self.reset() if done else next_state

        total_frame = 0
        reward_list = np.zeros(self.args.log_size, np.float32)
        qval_list = np.zeros(self.args.log_size, np.float32)
        start_time = time.time()
        for epi in count():
            reward_list[epi % self.args.log_size] = 0
            qval_list[epi % self.args.log_size] = -1e9
            state = self.reset()
            done = False
            ep_len = 0
            if epi % self.args.save_freq == 0:
                model_file = os.path.join(self.work_dir, 'model.th')
                with open(model_file, 'wb') as f:
                    torch.save(self.qNetwork, f)
            while not done:
                # Periodically sync the target network with the online network.
                if total_frame % self.args.sync_period == 0:
                    self.targetNetwork.load_state_dict(self.qNetwork.state_dict())
                self.eps = max(self.args.eps_min, self.eps - self.eps_delta)
                next_state, reward, done, qval = self.select_action(state)
                reward_list[epi % self.args.log_size] += reward
                qval_list[epi % self.args.log_size] = max(
                    qval_list[epi % self.args.log_size], qval)
                state = next_state
                total_frame += 1
                ep_len += 1
                if ep_len % self.args.learn_freq == 0:
                    # One-step TD target computed with the frozen target network.
                    batch_state, batch_action, batch_reward, batch_next_state, batch_done = \
                        self.memory.sample(self.args.batch_size)
                    batch_q = self.qNetwork(batch_state).gather(
                        1, batch_action.unsqueeze(1)).squeeze(1)
                    batch_next_q = self.targetNetwork(batch_next_state).detach().max(1)[0] \
                        * self.args.gamma * (1 - batch_done)
                    loss = self.crit(batch_q, batch_reward + batch_next_q)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
            output_str = ('episode %d frame %d time %.2fs cur_rew %.3f mean_rew %.3f '
                          'cur_maxq %.3f mean_maxq %.3f' % (
                              epi, total_frame, time.time() - start_time,
                              reward_list[epi % self.args.log_size], np.mean(reward_list),
                              qval_list[epi % self.args.log_size], np.mean(qval_list)))
            print(output_str)
            logging.info(output_str)
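# Hypothetical driver for the class above (not part of the original file). DQNAgent owns
# its whole training loop in run(), so a caller only supplies an Atari-style gym
# environment, a namespace carrying the self.args.* fields referenced above, and a
# working directory for checkpoints. All hyperparameter values below are assumptions.
def train_dqn(env, work_dir='./runs/dqn'):
    from types import SimpleNamespace

    args = SimpleNamespace(
        device='cuda' if torch.cuda.is_available() else 'cpu',
        buffer_size=100_000,                # replay capacity passed to ReplayMemory
        buffer_init_size=10_000,            # warm-up transitions before learning
        batch_size=32,
        eps=1.0,                            # initial epsilon for epsilon-greedy
        eps_min=0.1,                        # epsilon floor
        exploration_decay_speed=1_000_000,  # frames over which epsilon is annealed
        gamma=0.99,                         # discount factor
        sync_period=10_000,                 # frames between target-network syncs
        learn_freq=4,                       # environment steps per gradient update
        save_freq=100,                      # episodes between checkpoints
        log_size=100,                       # window for running reward / max-Q stats
    )
    agent = DQNAgent(env, args, work_dir)
    agent.run()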
import torch
import torch.nn.functional as F
import torch.optim as optim

# Replay, ValueNetwork, PolicyNetwork and the module-level constants (device, BATCH_SIZE,
# CRITIC_LEARNING_RATE, ACTOR_LEARNING_RATE, DISCOUNT_RATE, TAU) are assumed to be
# defined elsewhere in the project.


class Agent():
    def __init__(self, state_size, action_size, num_agents):
        state_dim = state_size
        # agent_input_state_dim = state_size * 2  # Previous state is passed in with the current state.
        action_dim = action_size
        self.num_agents = num_agents
        max_size = 100000
        self.replay = Replay(max_size)
        hidden_dim = 128

        self.critic_net = ValueNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_critic_net = ValueNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.actor_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_actor_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)

        # Hard-copy the online weights into the target networks.
        for target_param, param in zip(self.target_critic_net.parameters(),
                                       self.critic_net.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.target_actor_net.parameters(),
                                       self.actor_net.parameters()):
            target_param.data.copy_(param.data)

        self.critic_optimizer = optim.Adam(self.critic_net.parameters(), lr=CRITIC_LEARNING_RATE)
        self.actor_optimizer = optim.Adam(self.actor_net.parameters(), lr=ACTOR_LEARNING_RATE)

    def get_action(self, state):
        return self.actor_net.get_action(state)[0]

    def add_replay(self, state, action, reward, next_state, done):
        # Store one transition per agent in the shared replay buffer.
        for i in range(self.num_agents):
            self.replay.add(state[i], action[i], reward[i], next_state[i], done[i])

    def learning_step(self):
        # Skip learning until the replay buffer holds at least one full batch.
        if self.replay.cursize < BATCH_SIZE:
            return

        # Sample a batch of transitions.
        state, action, reward, next_state, done = self.replay.get(BATCH_SIZE)

        # Actor loss: maximize the critic's estimate of the actor's actions.
        actor_loss = self.critic_net(state, self.actor_net(state))
        actor_loss = -actor_loss.mean()

        # Critic loss: one-step TD error against the target networks.
        next_action = self.target_actor_net(next_state)
        target_value = self.target_critic_net(next_state, next_action.detach())
        expected_value = reward + (1.0 - done) * DISCOUNT_RATE * target_value
        value = self.critic_net(state, action)
        critic_loss = F.mse_loss(value, expected_value.detach())

        # Backprop.
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Soft-update the target networks towards the online networks.
        self.soft_update(self.critic_net, self.target_critic_net, TAU)
        self.soft_update(self.actor_net, self.target_actor_net, TAU)

    def save(self, name):
        torch.save(self.critic_net.state_dict(), name + "_critic")
        torch.save(self.actor_net.state_dict(), name + "_actor")

    def load(self, name):
        self.critic_net.load_state_dict(torch.load(name + "_critic"))
        self.critic_net.eval()
        self.actor_net.load_state_dict(torch.load(name + "_actor"))
        self.actor_net.eval()
        for target_param, param in zip(self.target_critic_net.parameters(),
                                       self.critic_net.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.target_actor_net.parameters(),
                                       self.actor_net.parameters()):
            target_param.data.copy_(param.data)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.

        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
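# Hypothetical episode loop for the Agent above (not part of the original file), showing
# how get_action, add_replay and learning_step are meant to interact with a multi-agent
# environment that returns per-agent arrays of states, rewards and dones. The environment
# step API used here (gym-style 4-tuple) is an assumption.
def run_episode(env, agent, max_steps=1000):
    states = env.reset()                      # shape: (num_agents, state_size)
    episode_reward = 0.0
    for _ in range(max_steps):
        actions = [agent.get_action(s) for s in states]
        next_states, rewards, dones, _ = env.step(actions)
        # One replay entry per agent, then a single shared learning step.
        agent.add_replay(states, actions, rewards, next_states, dones)
        agent.learning_step()
        states = next_states
        episode_reward += float(sum(rewards)) / len(rewards)
        if any(dones):
            break
    return episode_reward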
import numpy as np
import torch
from torch import from_numpy
from torch.optim import Adam

# Memory, Transition, PolicyNetwork, QvalueNetwork and ValueNetwork are assumed to be
# defined elsewhere in the project.


class SAC:
    def __init__(self, env_name, n_states, n_actions, memory_size, batch_size,
                 gamma, alpha, lr, action_bounds, reward_scale):
        self.env_name = env_name
        self.n_states = n_states
        self.n_actions = n_actions
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.alpha = alpha
        self.lr = lr
        self.action_bounds = action_bounds
        self.reward_scale = reward_scale
        self.memory = Memory(memory_size=self.memory_size)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.policy_network = PolicyNetwork(n_states=self.n_states,
                                            n_actions=self.n_actions,
                                            action_bounds=self.action_bounds).to(self.device)
        self.q_value_network1 = QvalueNetwork(n_states=self.n_states,
                                              n_actions=self.n_actions).to(self.device)
        self.q_value_network2 = QvalueNetwork(n_states=self.n_states,
                                              n_actions=self.n_actions).to(self.device)
        self.value_network = ValueNetwork(n_states=self.n_states).to(self.device)
        self.value_target_network = ValueNetwork(n_states=self.n_states).to(self.device)
        self.value_target_network.load_state_dict(self.value_network.state_dict())
        self.value_target_network.eval()

        self.value_loss = torch.nn.MSELoss()
        self.q_value_loss = torch.nn.MSELoss()

        self.value_opt = Adam(self.value_network.parameters(), lr=self.lr)
        self.q_value1_opt = Adam(self.q_value_network1.parameters(), lr=self.lr)
        self.q_value2_opt = Adam(self.q_value_network2.parameters(), lr=self.lr)
        self.policy_opt = Adam(self.policy_network.parameters(), lr=self.lr)

    def store(self, state, reward, done, action, next_state):
        # Keep the replay buffer on CPU; tensors are moved to self.device when sampled.
        state = from_numpy(state).float().to("cpu")
        reward = torch.Tensor([reward]).to("cpu")
        done = torch.Tensor([done]).to("cpu")
        action = torch.Tensor([action]).to("cpu")
        next_state = from_numpy(next_state).float().to("cpu")
        self.memory.add(state, reward, done, action, next_state)

    def unpack(self, batch):
        batch = Transition(*zip(*batch))
        states = torch.cat(batch.state).view(self.batch_size, self.n_states).to(self.device)
        rewards = torch.cat(batch.reward).view(self.batch_size, 1).to(self.device)
        dones = torch.cat(batch.done).view(self.batch_size, 1).to(self.device)
        actions = torch.cat(batch.action).view(-1, self.n_actions).to(self.device)
        next_states = torch.cat(batch.next_state).view(self.batch_size,
                                                       self.n_states).to(self.device)
        return states, rewards, dones, actions, next_states

    def train(self):
        if len(self.memory) < self.batch_size:
            return 0, 0, 0
        else:
            batch = self.memory.sample(self.batch_size)
            states, rewards, dones, actions, next_states = self.unpack(batch)

            # Calculating the value target.
            reparam_actions, log_probs = self.policy_network.sample_or_likelihood(states)
            q1 = self.q_value_network1(states, reparam_actions)
            q2 = self.q_value_network2(states, reparam_actions)
            q = torch.min(q1, q2)
            target_value = q.detach() - self.alpha * log_probs.detach()
            value = self.value_network(states)
            value_loss = self.value_loss(value, target_value)

            # Calculating the Q-value target.
            with torch.no_grad():
                target_q = self.reward_scale * rewards + \
                    self.gamma * self.value_target_network(next_states) * (1 - dones)
            q1 = self.q_value_network1(states, actions)
            q2 = self.q_value_network2(states, actions)
            q1_loss = self.q_value_loss(q1, target_q)
            q2_loss = self.q_value_loss(q2, target_q)

            # Policy loss: alpha * log_prob minus the minimum of the two Q estimates,
            # evaluated on reparameterized actions.
            policy_loss = (self.alpha * log_probs - q).mean()

            self.policy_opt.zero_grad()
            policy_loss.backward()
            self.policy_opt.step()

            self.value_opt.zero_grad()
            value_loss.backward()
            self.value_opt.step()

            self.q_value1_opt.zero_grad()
            q1_loss.backward()
            self.q_value1_opt.step()

            self.q_value2_opt.zero_grad()
            q2_loss.backward()
            self.q_value2_opt.step()

            self.soft_update_target_network(self.value_network, self.value_target_network)

            return value_loss.item(), 0.5 * (q1_loss + q2_loss).item(), policy_loss.item()

    def choose_action(self, states):
        states = np.expand_dims(states, axis=0)
        states = from_numpy(states).float().to(self.device)
        action, _ = self.policy_network.sample_or_likelihood(states)
        return action.detach().cpu().numpy()[0]

    @staticmethod
    def soft_update_target_network(local_network, target_network, tau=0.005):
        # Polyak averaging of the value network into its target copy.
        for target_param, local_param in zip(target_network.parameters(),
                                             local_network.parameters()):
            target_param.data.copy_(tau * local_param.data + (1 - tau) * target_param.data)

    def save_weights(self):
        torch.save(self.policy_network.state_dict(), self.env_name + "_weights.pth")

    def load_weights(self):
        self.policy_network.load_state_dict(torch.load(self.env_name + "_weights.pth"))

    def set_to_eval_mode(self):
        self.policy_network.eval()
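# Hypothetical training loop for the SAC class above (not part of the original file),
# showing how choose_action, store and train fit together for a continuous-control gym
# task. The episode counts and the gym-style 4-tuple step API are assumptions.
def train_sac(env, agent, n_episodes=500, max_steps=1000):
    for episode in range(n_episodes):
        state = env.reset()
        episode_reward = 0.0
        value_loss = q_loss = policy_loss = 0.0
        for _ in range(max_steps):
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            # Transitions are stored on CPU; train() samples a batch (once enough data
            # is available) and returns (value_loss, q_loss, policy_loss) for logging.
            agent.store(state, reward, done, action, next_state)
            value_loss, q_loss, policy_loss = agent.train()
            state = next_state
            episode_reward += reward
            if done:
                break
        print(f"episode {episode} reward {episode_reward:.2f} "
              f"value_loss {value_loss:.3f} q_loss {q_loss:.3f} policy_loss {policy_loss:.3f}")
    agent.save_weights()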