import os

import numpy as np
import torch
import torch.optim as optim
from torch.distributions import Normal

# SoftQNetwork, PolicyNetwork, ReplayBuffer and soft_update are project-local
# (defined elsewhere in the repository).


class SAC_Agent:
    def __init__(self, env, batch_size=256, gamma=0.99, tau=0.005,
                 actor_lr=3e-4, critic_lr=3e-4, alpha_lr=3e-4):
        # Environment
        self.env = env
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]

        # Hyperparameters
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau

        # Entropy temperature (learned); target entropy uses the -|A| heuristic
        self.alpha = 1.0
        self.target_entropy = -np.prod(env.action_space.shape).item()
        self.log_alpha = torch.zeros(1, requires_grad=True, device="cuda")
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)

        # Twin soft Q-networks with target copies
        self.Q1 = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q1_target = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q1_target.load_state_dict(self.Q1.state_dict())
        self.Q1_optimizer = optim.Adam(self.Q1.parameters(), lr=critic_lr)

        self.Q2 = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q2_target = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q2_target.load_state_dict(self.Q2.state_dict())
        self.Q2_optimizer = optim.Adam(self.Q2.parameters(), lr=critic_lr)

        # Actor
        self.actor = PolicyNetwork(state_dim, action_dim).cuda()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)

        self.loss_function = torch.nn.MSELoss()
        self.replay_buffer = ReplayBuffer()

    def act(self, state, deterministic=True):
        state = torch.tensor(state, dtype=torch.float, device="cuda")
        mean, log_std = self.actor(state)
        if deterministic:
            action = torch.tanh(mean)
        else:
            std = log_std.exp()
            normal = Normal(mean, std)
            z = normal.sample()
            action = torch.tanh(z)
        return action.detach().cpu().numpy()

    def update(self, state, action, next_state, reward, done):
        self.replay_buffer.add_transition(state, action, next_state, reward, done)

        # Sample the next batch and perform a batch update
        batch_states, batch_actions, batch_next_states, batch_rewards, batch_dones = \
            self.replay_buffer.next_batch(self.batch_size)

        # Map to tensors
        batch_states = torch.tensor(batch_states, dtype=torch.float, device="cuda")            # B, S_D
        batch_actions = torch.tensor(batch_actions, dtype=torch.float, device="cuda")          # B, A_D
        batch_next_states = torch.tensor(batch_next_states, dtype=torch.float, device="cuda")  # B, S_D
        batch_rewards = torch.tensor(batch_rewards, dtype=torch.float, device="cuda").unsqueeze(-1)  # B, 1
        # float (not uint8) so (1 - done) * gamma * ... stays a float expression
        batch_dones = torch.tensor(batch_dones, dtype=torch.float, device="cuda").unsqueeze(-1)      # B, 1

        # Policy evaluation: soft Bellman backup with clipped double-Q targets
        with torch.no_grad():
            policy_actions, log_pi = self.actor.sample(batch_next_states)
            Q1_next_target = self.Q1_target(batch_next_states, policy_actions)
            Q2_next_target = self.Q2_target(batch_next_states, policy_actions)
            Q_next_target = torch.min(Q1_next_target, Q2_next_target)
            td_target = batch_rewards + (1 - batch_dones) * self.gamma * \
                (Q_next_target - self.alpha * log_pi)

        Q1_value = self.Q1(batch_states, batch_actions)
        self.Q1_optimizer.zero_grad()
        loss = self.loss_function(Q1_value, td_target)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.Q1.parameters(), 1)
        self.Q1_optimizer.step()

        Q2_value = self.Q2(batch_states, batch_actions)
        self.Q2_optimizer.zero_grad()
        loss = self.loss_function(Q2_value, td_target)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.Q2.parameters(), 1)
        self.Q2_optimizer.step()

        # Policy improvement: maximize Q - alpha * log_pi (minimize the negation)
        policy_actions, log_pi = self.actor.sample(batch_states)
        Q1_value = self.Q1(batch_states, policy_actions)
        Q2_value = self.Q2(batch_states, policy_actions)
        Q_value = torch.min(Q1_value, Q2_value)

        self.actor_optimizer.zero_grad()
        loss = (self.alpha * log_pi - Q_value).mean()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1)
        self.actor_optimizer.step()

        # Update the entropy temperature
        alpha_loss = (self.log_alpha * (-log_pi - self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        # detach so alpha enters the critic/actor losses as a constant
        self.alpha = self.log_alpha.exp().detach()

        # Update target networks (Polyak averaging)
        soft_update(self.Q1_target, self.Q1, self.tau)
        soft_update(self.Q2_target, self.Q2, self.tau)

    def save(self, file_name):
        torch.save({'actor_dict': self.actor.state_dict(),
                    'Q1_dict': self.Q1.state_dict(),
                    'Q2_dict': self.Q2.state_dict()}, file_name)

    def load(self, file_name):
        if os.path.isfile(file_name):
            print("=> loading checkpoint...")
            checkpoint = torch.load(file_name)
            self.actor.load_state_dict(checkpoint['actor_dict'])
            self.Q1.load_state_dict(checkpoint['Q1_dict'])
            self.Q2.load_state_dict(checkpoint['Q2_dict'])
            print("done!")
        else:
            print("no checkpoint found...")
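# NOTE (sketch): SAC_Agent relies on PolicyNetwork.sample() and a module-level
# soft_update() that are defined elsewhere in the project. The following is a
# minimal sketch of both under the standard assumption of a tanh-squashed
# Gaussian policy; the project's actual architecture and signatures may differ.
import torch
import torch.nn as nn
from torch.distributions import Normal


class PolicyNetworkSketch(nn.Module):
    """Hypothetical stand-in for PolicyNetwork: Gaussian policy with tanh squashing."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU(),
                                  nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        h = self.body(state)
        return self.mean_head(h), self.log_std_head(h).clamp(-20, 2)

    def sample(self, state):
        mean, log_std = self.forward(state)
        normal = Normal(mean, log_std.exp())
        z = normal.rsample()              # reparameterization trick keeps gradients
        action = torch.tanh(z)
        # change-of-variables correction for the tanh squashing
        log_pi = normal.log_prob(z) - torch.log(1 - action.pow(2) + 1e-6)
        return action, log_pi.sum(dim=-1, keepdim=True)


def soft_update(target, source, tau):
    """Polyak averaging: target <- tau * source + (1 - tau) * target."""
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)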
def calc_po_best_response_PER(poacher, target_poacher, po_copy_op, po_good_copy_op,
                              patrollers, pa_s, pa_type, iteration, sess, env, args,
                              final_utility, starting_e, train_episode_num=None):
    '''
    Given a list of patrollers and their types (DQN, PARAM, RS),
    train a DQN poacher as the approximate best response.
    Args:
        poacher: DQN poacher
        target_poacher: target DQN poacher
        po_copy_op: tensorflow copy operations, copy the weights from the DQN to the target DQN
        po_good_copy_op: tensorflow copy operations, save the best poacher DQN seen so far
        patrollers: a list of patrollers
        pa_s: the patroller mixed strategy over the list of patrollers
        pa_type: a list specifying the type of each patroller, in {'DQN', 'PARAM', 'RS'}
        iteration: the current DO iteration
        sess: tensorflow session
        env: the game environment
        args: configuration arguments
        final_utility: records the best response utility
        starting_e: the starting training episode
    Return:
        Nothing is explicitly returned due to multithreading.
        The best response utility is returned through $final_utility$.
        The best response DQN is copied through $po_good_copy_op$.
    '''
    # print('FIND_poacher_best_response iteration: ' + str(iteration))

    if train_episode_num is None:
        train_episode_num = args.po_episode_num

    decrease_time = 1.0 / args.epsilon_decrease
    epsilon_decrease_every = train_episode_num // decrease_time

    if not args.PER:
        replay_buffer = ReplayBuffer(args, args.po_replay_buffer_size)
    else:
        replay_buffer = PERMemory(args)

    pa_strategy = pa_s
    best_utility = -10000.0
    test_utility = []

    if starting_e == 0:
        log = open(args.save_path + 'po_log_train_iter_' + str(iteration) + '.dat', 'w')
        test_log = open(args.save_path + 'po_log_test_iter_' + str(iteration) + '.dat', 'w')
    else:
        log = open(args.save_path + 'po_log_train_iter_' + str(iteration) + '.dat', 'a')
        test_log = open(args.save_path + 'po_log_test_iter_' + str(iteration) + '.dat', 'a')

    epsilon = 1.0
    learning_rate = args.po_initial_lr
    global_step = 0

    action_id = {
        ('still', 1): 0,
        ('up', 0): 1,
        ('down', 0): 2,
        ('left', 0): 3,
        ('right', 0): 4
    }

    sess.run(po_copy_op)

    for e in range(starting_e, starting_e + train_episode_num):
        if e > 0 and e % epsilon_decrease_every == 0:
            epsilon = max(0.1, epsilon - args.epsilon_decrease)

        # Periodically resample which patroller pure strategy to play against
        if e % args.mix_every_episode == 0 or e == starting_e:
            pa_chosen_strat = np.argmax(np.random.multinomial(1, pa_strategy))
            patroller = patrollers[pa_chosen_strat]
            cur_pa_type = pa_type[pa_chosen_strat]

        # if args.gui == 1 and e > 0 and e % args.gui_every_episode == 0:
        #     test_gui(poacher, patroller, sess, args, pah=heurestic_flag, poh=False)

        ### reset the environment
        poacher.reset_snare_num()
        pa_state, po_state = env.reset_game()
        episode_reward = 0.0
        pa_action = 'still'

        for t in range(args.max_time):
            global_step += 1
            transition = []

            ### transition adds the current state
            transition.append(po_state)

            ### poacher chooses an action, if it has not been caught/returned home
            if not env.catch_flag and not env.home_flag:
                po_state = np.array([po_state])
                snare_flag, po_action = poacher.infer_action(
                    sess=sess, states=po_state, policy="epsilon_greedy",
                    epsilon=epsilon, po_loc=env.po_loc,
                    animal_density=env.animal_density)
            else:
                snare_flag = True
                po_action = 'still'

            transition.append(action_id[(po_action, snare_flag)])

            ### patroller chooses an action
            ### note that heuristic and DQN agents have different APIs
            if cur_pa_type == 'DQN':
                pa_state = np.array([pa_state])  # make it 2-D: [batch_size(1), state_size]
                pa_action = patroller.infer_action(
                    sess=sess, states=pa_state, policy="greedy",
                    pa_loc=env.pa_loc, animal_density=env.animal_density)
            elif cur_pa_type == 'PARAM':
                pa_loc = env.pa_loc
                pa_action = patroller.infer_action(
                    pa_loc, env.get_local_po_trace(pa_loc), 1.5, -2.0, 8.0)
            elif cur_pa_type == 'RS':
                pa_loc = env.pa_loc
                footprints = []
                actions = ['up', 'down', 'left', 'right']
                for i in range(4, 8):
                    if env.po_trace[pa_loc[0], pa_loc[1]][i] == 1:
                        footprints.append(actions[i - 4])
                pa_action = patroller.infer_action(pa_loc, pa_action, footprints)

            pa_state, _, po_state, po_reward, end_game = \
                env.step(pa_action, po_action, snare_flag)

            ### transition adds the reward and the new state
            transition.append(po_reward)
            transition.append(po_state)
            episode_reward += po_reward

            ### add transition to the replay buffer
            replay_buffer.add_transition(transition)

            ### start training: sample a minibatch once the buffer is large enough
            if replay_buffer.size >= args.batch_size:
                if not args.PER:
                    train_state, train_action, train_reward, train_new_state = \
                        replay_buffer.sample_batch(args.batch_size)
                else:
                    train_state, train_action, train_reward, train_new_state, \
                        idx_batch, weight_batch = replay_buffer.sample_batch(args.batch_size)

                ### Double DQN target: online net selects, target net evaluates
                max_index = poacher.get_max_q_index(sess=sess, states=train_new_state)
                max_q = target_poacher.get_q_by_index(sess=sess, states=train_new_state,
                                                      index=max_index)
                q_target = train_reward + args.reward_gamma * max_q

                if args.PER:
                    q_pred = sess.run(poacher.output, {poacher.input_state: train_state})
                    q_pred = q_pred[np.arange(args.batch_size), train_action]
                    TD_error_batch = np.abs(q_target - q_pred)
                    replay_buffer.update(idx_batch, TD_error_batch)

                weight = weight_batch if args.PER else np.ones(args.batch_size)

                ### update parameters
                feed = {
                    poacher.input_state: train_state,
                    poacher.actions: train_action,
                    poacher.q_target: q_target,
                    poacher.learning_rate: learning_rate,
                    poacher.loss_weight: weight
                }
                sess.run(poacher.train_op, feed_dict=feed)

            ### update the target network
            if global_step > 0 and global_step % args.target_update_every == 0:
                sess.run(po_copy_op)

            ### the game ends when 1) the patroller catches the poacher and removes
            ### all the snares, or 2) the maximum time step is reached
            if end_game or (t == args.max_time - 1):
                info = "episode\t%s\tlength\t%s\ttotal_reward\t%s\taverage_reward\t%s" % \
                    (e, t + 1, episode_reward, 1. * episode_reward / (t + 1))
                if e % args.print_every == 0:
                    log.write(info + '\n')
                    print('po ' + info)
                    # log.flush()
                break

        ### save the model
        if (e > 0 and e % args.save_every_episode == 0) or e == train_episode_num - 1:
            save_name = args.save_path + 'iteration_' + str(iteration) + \
                '_epoch_' + str(e) + "_po_model.ckpt"
            poacher.save(sess=sess, filename=save_name)
            # print('Save model to ' + save_name)

        ### test the agent
        if e == train_episode_num - 1 or (e > 0 and e % args.test_every_episode == 0):
            po_utility = 0.0
            test_total_reward = np.zeros(len(pa_strategy))

            ### test against each patroller strategy in the current strategy set
            for pa_strat in range(len(pa_strategy)):
                if pa_strategy[pa_strat] > 1e-10:
                    _, test_total_reward[pa_strat], _ = test_(
                        patrollers[pa_strat], poacher, env, sess, args, iteration, e,
                        poacher_type='DQN', patroller_type=pa_type[pa_strat])
                    po_utility += pa_strategy[pa_strat] * test_total_reward[pa_strat]

            test_utility.append(po_utility)

            if po_utility > best_utility and \
                    (e > min(50000, train_episode_num / 2) or args.row_num == 3):
                best_utility = po_utility
                sess.run(po_good_copy_op)
                final_utility[1] = po_utility

            info = [str(po_utility)] + [str(x) for x in test_total_reward]
            info = 'test ' + str(e) + ' ' + '\t'.join(info) + '\n'
            print('po ' + info)
            test_log.write(info)
            test_log.flush()

    test_log.close()
    log.close()
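# NOTE (sketch): the target computation above is Double DQN: the online network
# chooses the argmax action at the next state and the target network evaluates
# it, which reduces the overestimation bias of vanilla DQN. Below is a minimal
# NumPy sketch of the same rule, with hypothetical q_online_next / q_target_next
# arrays standing in for the two networks' outputs. Also note that the code
# above does not zero out the bootstrap term at terminal states; if that
# matters for the game, a (1 - done) mask would multiply gamma * max_q.
import numpy as np


def double_dqn_targets(rewards, q_online_next, q_target_next, gamma):
    """rewards: [B]; q_online_next, q_target_next: [B, n_actions]."""
    max_index = np.argmax(q_online_next, axis=1)                # selection: online net
    max_q = q_target_next[np.arange(len(rewards)), max_index]   # evaluation: target net
    return rewards + gamma * max_q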
def calc_pa_best_response_PER(patroller, target_patroller, pa_copy_op, pa_good_copy_op,
                              poachers, po_strategy, po_type, iteration, sess, env, args,
                              final_utility, starting_e, train_episode_num=None,
                              po_locations=None):
    '''
    po_locations: if None, the game runs in the purely global mode;
        otherwise it is the local + global retrain mode, and each entry of
        po_locations specifies the local mode of that poacher.
    Everything else is analogous to the function 'calc_po_best_response_PER'.
    '''
    po_location = None
    # print('FIND_patroller_best_response iteration: ' + str(iteration))

    if train_episode_num is None:
        train_episode_num = args.pa_episode_num

    decrease_time = 1.0 / args.epsilon_decrease
    epsilon_decrease_every = train_episode_num // decrease_time

    if not args.PER:
        replay_buffer = ReplayBuffer(args, args.pa_replay_buffer_size)
    else:
        replay_buffer = PERMemory(args)

    best_utility = -10000.0
    test_utility = []

    if starting_e == 0:
        log = open(args.save_path + 'pa_log_train_iter_' + str(iteration) + '.dat', 'w')
        test_log = open(args.save_path + 'pa_log_test_iter_' + str(iteration) + '.dat', 'w')
    else:
        log = open(args.save_path + 'pa_log_train_iter_' + str(iteration) + '.dat', 'a')
        test_log = open(args.save_path + 'pa_log_test_iter_' + str(iteration) + '.dat', 'a')

    epsilon = 1.0
    learning_rate = args.po_initial_lr  # note: reuses the poacher initial lr
    global_step = 0

    action_id = {
        'still': 0,
        'up': 1,
        'down': 2,
        'left': 3,
        'right': 4
    }

    sess.run(pa_copy_op)

    for e in range(starting_e, starting_e + train_episode_num):
        if e > 0 and e % epsilon_decrease_every == 0:
            epsilon = max(0.1, epsilon - args.epsilon_decrease)

        # Periodically resample which poacher pure strategy to play against
        if e % args.mix_every_episode == 0 or e == starting_e:
            po_chosen_strat = np.argmax(np.random.multinomial(1, po_strategy))
            poacher = poachers[po_chosen_strat]
            cur_po_type = po_type[po_chosen_strat]
            if po_locations is not None:  # local + global mode: set the poacher mode
                po_location = po_locations[po_chosen_strat]

        ### reset the environment
        poacher.reset_snare_num()
        pa_state, po_state = env.reset_game(po_location)
        episode_reward = 0.0
        pa_action = 'still'

        for t in range(args.max_time):
            global_step += 1

            ### transition records the (s, a, r, s') tuple
            transition = []

            ### poacher chooses an action
            ### heuristic and DQN agents have different infer_action APIs
            if cur_po_type == 'DQN':
                if not env.catch_flag and not env.home_flag:
                    # the poacher can still act if it has not been caught
                    po_state = np.array([po_state])
                    snare_flag, po_action = poacher.infer_action(
                        sess=sess, states=po_state, policy="greedy",
                        po_loc=env.po_loc, animal_density=env.animal_density)
                else:
                    # once caught, the poacher stays still and does nothing
                    snare_flag = 0
                    po_action = 'still'
            elif cur_po_type == 'PARAM':
                po_loc = env.po_loc
                if not env.catch_flag and not env.home_flag:
                    snare_flag, po_action = poacher.infer_action(
                        loc=po_loc,
                        local_trace=env.get_local_pa_trace(po_loc),
                        local_snare=env.get_local_snare(po_loc),
                        initial_loc=env.po_initial_loc)
                else:
                    snare_flag = 0
                    po_action = 'still'

            ### transition appends the current state
            transition.append(pa_state)

            ### patroller chooses an action
            pa_state = np.array([pa_state])
            pa_action = patroller.infer_action(
                sess=sess, states=pa_state, policy="epsilon_greedy", epsilon=epsilon,
                pa_loc=env.pa_loc, animal_density=env.animal_density)

            ### transition adds the action
            transition.append(action_id[pa_action])

            ### the game moves on one step
            pa_state, pa_reward, po_state, _, end_game = \
                env.step(pa_action, po_action, snare_flag)

            ### transition adds the reward and the next state
            episode_reward += pa_reward
            transition.append(pa_reward)
            transition.append(pa_state)

            ### add transition to the replay buffer
            replay_buffer.add_transition(transition)

            ### start training: sample a minibatch once the buffer is large enough
            if replay_buffer.size >= args.batch_size:
                if not args.PER:
                    train_state, train_action, train_reward, train_new_state = \
                        replay_buffer.sample_batch(args.batch_size)
                else:
                    train_state, train_action, train_reward, train_new_state, \
                        idx_batch, weight_batch = replay_buffer.sample_batch(args.batch_size)

                ### Double DQN target: online net selects, target net evaluates
                max_index = patroller.get_max_q_index(sess=sess, states=train_new_state)
                max_q = target_patroller.get_q_by_index(sess=sess, states=train_new_state,
                                                        index=max_index)
                q_target = train_reward + args.reward_gamma * max_q

                if args.PER:
                    q_pred = sess.run(patroller.output, {patroller.input_state: train_state})
                    q_pred = q_pred[np.arange(args.batch_size), train_action]
                    TD_error_batch = np.abs(q_target - q_pred)
                    replay_buffer.update(idx_batch, TD_error_batch)

                weight = weight_batch if args.PER else np.ones(args.batch_size)

                ### update parameters
                feed = {
                    patroller.input_state: train_state,
                    patroller.actions: train_action,
                    patroller.q_target: q_target,
                    patroller.learning_rate: learning_rate,
                    patroller.weight_loss: weight
                }
                sess.run(patroller.train_op, feed_dict=feed)

            ### update the target network
            if global_step % args.target_update_every == 0:
                sess.run(pa_copy_op)

            ### the game ends when 1) the patroller catches the poacher and removes
            ### all the snares, or 2) the maximum time step is reached
            if end_game or (t == args.max_time - 1):
                info = "episode\t%s\tlength\t%s\ttotal_reward\t%s\taverage_reward\t%s" % \
                    (e, t + 1, episode_reward, 1. * episode_reward / (t + 1))
                if e % args.print_every == 0:
                    log.write(info + '\n')
                    print('pa ' + info)
                    # log.flush()
                break

        ### save the model
        if (e > 0 and e % args.save_every_episode == 0) or e == train_episode_num - 1:
            save_name = args.save_path + 'iteration_' + str(iteration) + \
                '_epoch_' + str(e) + "_pa_model.ckpt"
            patroller.save(sess=sess, filename=save_name)

        ### test the agent
        if e == train_episode_num - 1 or (e > 0 and e % args.test_every_episode == 0):
            ### test against each poacher strategy in the current mix,
            ### computing the expected utility
            pa_utility = 0.0
            test_total_reward = np.zeros(len(po_strategy))

            for po_strat in range(len(po_strategy)):
                if po_strategy[po_strat] > 1e-10:
                    if po_locations is None:  # purely global mode
                        tmp_po_location = None
                    else:  # local + global retrain mode: set the poacher mode
                        tmp_po_location = po_locations[po_strat]
                    test_total_reward[po_strat], _, _ = test_(
                        patroller, poachers[po_strat], env, sess, args, iteration, e,
                        patroller_type='DQN', poacher_type=po_type[po_strat],
                        po_location=tmp_po_location)
                    ### update the expected utility
                    pa_utility += po_strategy[po_strat] * test_total_reward[po_strat]

            test_utility.append(pa_utility)

            if pa_utility > best_utility and \
                    (e > min(50000, train_episode_num / 2) or args.row_num == 3):
                best_utility = pa_utility
                sess.run(pa_good_copy_op)
                final_utility[0] = pa_utility

            info = [str(pa_utility)] + [str(x) for x in test_total_reward]
            info = 'test ' + str(e) + ' ' + '\t'.join(info) + '\n'
            print('pa ' + info)
            test_log.write(info)
            test_log.flush()

    test_log.close()
    log.close()
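# NOTE (sketch): PERMemory is assumed to implement proportional prioritized
# experience replay: sample_batch() additionally returns indices and
# importance-sampling weights, and update() re-prioritizes the sampled
# transitions by their TD errors. Below is a minimal sketch of proportional
# prioritization under those assumptions; the project's actual class likely
# uses a sum-tree and returns unpacked (s, a, r, s') arrays instead of raw
# transitions.
import numpy as np


class SimplePERMemory:
    def __init__(self, capacity, alpha=0.6, beta=0.4, eps=1e-6):
        self.capacity, self.alpha, self.beta, self.eps = capacity, alpha, beta, eps
        self.data, self.priorities = [], []

    def add_transition(self, transition):
        # new transitions get the current max priority so they are sampled at least once
        priority = max(self.priorities, default=1.0)
        if len(self.data) >= self.capacity:
            self.data.pop(0)
            self.priorities.pop(0)
        self.data.append(transition)
        self.priorities.append(priority)

    def sample_batch(self, batch_size):
        probs = np.array(self.priorities) ** self.alpha
        probs /= probs.sum()
        idx = np.random.choice(len(self.data), batch_size, p=probs)
        # importance-sampling weights correct for the non-uniform sampling
        weights = (len(self.data) * probs[idx]) ** (-self.beta)
        weights /= weights.max()
        return [self.data[i] for i in idx], idx, weights

    def update(self, idx_batch, td_errors):
        for i, err in zip(idx_batch, td_errors):
            self.priorities[i] = abs(err) + self.eps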
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# Q, ReplayBuffer, EpisodeStats, soft_update, angle_normalize, tt (to-tensor
# helper), acton_discrete, device and the epsilon schedulers are project-local.

DEFAULT_CONF = {'lr': 0.001,
                'bs': 64,
                'loss': nn.MSELoss,
                'hidden_dim': 64,
                'activation': 'relu',
                'mem_size': 50000,
                'epsilon': 1.,
                'eps_scheduler': 'exp',
                'n_episodes': 1000,
                'n_cycles': 1,
                'subtract': 0.,
                'dropout_rate': 0.0}  # dropout_rate added: __init__ reads it below


class DQN:
    def __init__(self, state_dim, action_dim, gamma, conf=None):
        if conf is None:  # avoid a mutable default argument
            conf = DEFAULT_CONF

        if conf['activation'] == 'relu':
            activation = torch.relu
        elif conf['activation'] == 'tanh':
            activation = torch.tanh
        else:
            raise ValueError("unknown activation: " + str(conf['activation']))

        self._q = Q(state_dim, action_dim, non_linearity=activation,
                    hidden_dim=conf['hidden_dim'],
                    dropout_rate=conf['dropout_rate']).to(device)
        self._q_target = Q(state_dim, action_dim, non_linearity=activation,
                           hidden_dim=conf['hidden_dim'],
                           dropout_rate=0.0).to(device)
        self._gamma = gamma

        ############################
        # exploration-exploitation tradeoff
        self.epsilon = conf['epsilon']
        self.n_episodes = conf['n_episodes']
        self.n_cycles = conf['n_cycles']
        self.eps_scheduler = conf['eps_scheduler']

        ############################
        # network
        self.bs = conf['bs']
        self._loss_function = nn.MSELoss()  # conf['loss']
        self._q_optimizer = optim.Adam(self._q.parameters(), lr=conf['lr'])
        self._action_dim = action_dim
        self._replay_buffer = ReplayBuffer(conf['mem_size'])
        self.scheduler = StepLR(self._q_optimizer, step_size=1, gamma=0.99)

        ############################
        # actions
        self.action = acton_discrete(action_dim)

    def get_action(self, x, epsilon):
        # epsilon-greedy: random action with probability epsilon, else greedy
        u = np.argmax(self._q(tt(x)).cpu().detach().numpy())
        r = np.random.uniform()
        if r < epsilon:
            return np.random.randint(self._action_dim)
        return u

    def train(self, episodes, time_steps, env, conf):
        # statistics for each episode
        # start_time = time.time()
        stats = EpisodeStats(episode_lengths=np.zeros(episodes),
                             episode_rewards=np.zeros(episodes),
                             episode_loss=np.zeros(episodes),
                             episode_epsilon=np.zeros(episodes))

        ############################
        # opt: continuous plotting every 20 episodes (only works without BOHB)
        # cont_plot = continous_plot()

        ############################
        # loop over episodes
        eps = self.epsilon
        for e in range(episodes):
            # anneal epsilon according to the configured schedule
            if self.eps_scheduler == 'cos':
                eps = cos_ann_w_restarts(e, self.n_episodes, self.n_cycles, self.epsilon)
            elif self.eps_scheduler == 'exp':
                eps = exponential_decay_w_restarts(e, self.n_episodes, self.n_cycles,
                                                   conf['epsilon'], 0.03,
                                                   conf['decay_rate'])
            stats.episode_epsilon[e] = eps

            ############################
            # opt: continuous plotting every 20 episodes (only works without BOHB)
            # if e % 20 == 0:
            #     cont_plot.plot_stats(stats)
            ############################

            stats.episode_lengths[e] = 0
            s = env.reset()
            for t in range(time_steps):
                ############################
                # opt: render env every 5 episodes
                # if e % 5 == 0:
                #     env.render()
                ############################

                # act and observe the result
                a = self.get_action(s, eps)
                ns, r, d, _ = env.step(self.action.act(a))
                ns[2] = angle_normalize(ns[2])  # environment-specific state fix-up
                stats.episode_rewards[e] += r

                self._replay_buffer.add_transition(s, a, ns, r, d)
                batch_states, batch_actions, batch_next_states, batch_rewards, \
                    batch_terminal_flags = self._replay_buffer.random_next_batch(self.bs)

                # Double DQN target: online net selects the action, target net evaluates
                target = (batch_rewards + (1 - batch_terminal_flags) * self._gamma *
                          self._q_target(batch_next_states)[
                              torch.arange(self.bs).long(),
                              torch.argmax(self._q(batch_next_states), dim=1)])

                # Q-values of the taken actions under the online network
                current_prediction = self._q(batch_states)[
                    torch.arange(self.bs).long(), batch_actions.long()]

                ############################
                # update the acting network
                loss = self._loss_function(current_prediction, target.detach())
                stats.episode_loss[e] += loss.cpu().detach()
                self._q_optimizer.zero_grad()
                loss.backward()
                self._q_optimizer.step()

                # update the target network (Polyak averaging)
                soft_update(self._q_target, self._q, 0.01)

                ############################
                # stop the episode if the cart leaves the boundaries
                if d:
                    stats.episode_lengths[e] = t
                    break
                s = ns

            ############################
            # if the episode didn't fail, its length is the maximum time
            if stats.episode_lengths[e] == 0:
                stats.episode_lengths[e] = time_steps

            self.scheduler.step()

        return stats
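# NOTE (sketch): cos_ann_w_restarts and exponential_decay_w_restarts are
# project-local epsilon schedulers referenced in DQN.train(). Below is a
# plausible minimal implementation matching the call sites above, assuming
# each schedule restarts every n_episodes / n_cycles episodes; the project's
# actual formulas may differ.
import numpy as np


def cos_ann_w_restarts(episode, n_episodes, n_cycles, eps_max, eps_min=0.0):
    """Cosine annealing from eps_max to eps_min, restarting every cycle."""
    cycle_len = max(n_episodes // n_cycles, 1)
    t = (episode % cycle_len) / max(cycle_len - 1, 1)
    return eps_min + 0.5 * (eps_max - eps_min) * (1 + np.cos(np.pi * t))


def exponential_decay_w_restarts(episode, n_episodes, n_cycles,
                                 eps_max, eps_min, decay_rate):
    """Exponential decay from eps_max toward eps_min, restarting every cycle."""
    cycle_len = max(n_episodes // n_cycles, 1)
    return max(eps_min, eps_max * np.exp(-decay_rate * (episode % cycle_len)))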