state = torch.Tensor([env.reset()]).to(device)
while True:
    if args.render_train:
        env.render()

    action = agent.calc_action(state, ou_noise)
    next_state, reward, done, _ = env.step(action.cpu().numpy()[0])
    timestep += 1
    epoch_return += reward

    mask = torch.Tensor([done]).to(device)
    reward = torch.Tensor([reward]).to(device)
    next_state = torch.Tensor([next_state]).to(device)

    memory.push(state, action, mask, next_state, reward)

    state = next_state

    epoch_value_loss = 0
    epoch_policy_loss = 0

    if len(memory) > args.batch_size:
        transitions = memory.sample(args.batch_size)
        # Transpose the batch
        # (see http://stackoverflow.com/a/19343/3343043 for a detailed explanation).
        batch = Transition(*zip(*transitions))

        # Update actor and critic according to the sampled batch
        value_loss, policy_loss = agent.update_params(batch)
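# The DDPG-style loop above draws exploration noise from `ou_noise` before every
# action. A minimal sketch of an Ornstein-Uhlenbeck noise process of the kind
# commonly paired with DDPG follows; the class name, constructor parameters, and
# call convention are illustrative assumptions, not necessarily what
# `agent.calc_action` expects in this codebase.
import numpy as np


class OrnsteinUhlenbeckNoise:
    """Temporally correlated exploration noise that relaxes toward mu (sketch)."""

    def __init__(self, action_dim, mu=0.0, theta=0.15, sigma=0.2, dt=1e-2):
        self.mu = mu * np.ones(action_dim)
        self.theta = theta
        self.sigma = sigma
        self.dt = dt
        self.state = np.copy(self.mu)

    def reset(self):
        # Call at the start of each episode to drop correlations from the last one.
        self.state = np.copy(self.mu)

    def __call__(self):
        # dx = theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, I)
        dx = (self.theta * (self.mu - self.state) * self.dt
              + self.sigma * np.sqrt(self.dt) * np.random.randn(len(self.state)))
        self.state = self.state + dx
        return self.state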
class DQNDoubleQAgent(BaseAgent):
    def __init__(self):
        super(DQNDoubleQAgent, self).__init__()
        self.training = False
        self.max_frames = 2000000
        self._epsilon = Epsilon(start=1.0, end=0.1, update_increment=0.0001)
        self.gamma = 0.99
        self.train_q_per_step = 4
        self.train_q_batch_size = 256
        self.steps_before_training = 10000
        self.target_q_update_frequency = 50000

        self._Q_weights_path = "./data/SC2DoubleQAgent"
        self._Q = DQNCNN()
        if os.path.isfile(self._Q_weights_path):
            self._Q.load_state_dict(torch.load(self._Q_weights_path))
            print("Loading weights:", self._Q_weights_path)
        self._Qt = copy.deepcopy(self._Q)
        self._Q.cuda()
        self._Qt.cuda()
        self._optimizer = optim.Adam(self._Q.parameters(), lr=1e-8)
        self._criterion = nn.MSELoss()
        self._memory = ReplayMemory(100000)

        self._loss = deque(maxlen=1000)
        self._max_q = deque(maxlen=1000)
        self._action = None
        self._screen = None
        self._fig = plt.figure()
        self._plot = [plt.subplot(2, 2, i + 1) for i in range(4)]
        self._screen_size = 28

    def get_env_action(self, action, obs):
        action = np.unravel_index(action, [1, self._screen_size, self._screen_size])
        target = [action[2], action[1]]
        command = _MOVE_SCREEN  # action[0] -- unit selection is removed from the equation
        # if command == 0:
        #     command = _SELECT_POINT
        # else:
        #     command = _MOVE_SCREEN
        if command in obs.observation["available_actions"]:
            return actions.FunctionCall(command, [[0], target])
        else:
            return actions.FunctionCall(_NO_OP, [])

    def get_action(self, s):
        """
        :param s: obs.observation["screen"]
        :returns: the argmax action index when acting greedily, otherwise a
            random action index
        """
        # greedy
        if np.random.rand() > self._epsilon.value():
            # print("greedy action")
            s = Variable(torch.from_numpy(s).cuda())
            s = s.unsqueeze(0).float()
            self._action = self._Q(s).squeeze().cpu().data.numpy()
            return self._action.argmax()
        # explore
        else:
            # print("random choice")
            # action = np.random.choice([0, 1])
            action = 0
            target = np.random.randint(0, self._screen_size, size=2)
            return (action * self._screen_size * self._screen_size
                    + target[0] * self._screen_size + target[1])

    def select_friendly_action(self, obs):
        player_relative = obs.observation["screen"][_PLAYER_RELATIVE]
        friendly_y, friendly_x = (player_relative == _PLAYER_FRIENDLY).nonzero()
        target = [int(friendly_x.mean()), int(friendly_y.mean())]
        return actions.FunctionCall(_SELECT_POINT, [[0], target])

    def train(self, env, training=True):
        self._epsilon.isTraining = training
        self.run_loop(env, self.max_frames)
        if self._epsilon.isTraining:
            torch.save(self._Q.state_dict(), self._Q_weights_path)

    def run_loop(self, env, max_frames=0):
        """A run loop to have agents and an environment interact."""
        total_frames = 0
        start_time = time.time()

        action_spec = env.action_spec()
        observation_spec = env.observation_spec()
        self.setup(observation_spec, action_spec)
        try:
            while True:
                obs = env.reset()[0]
                # Remove unit selection from the equation by selecting the
                # friendly unit at the start of every new game.
                select_friendly = self.select_friendly_action(obs)
                obs = env.step([select_friendly])[0]
                # distance = self.get_reward(obs.observation["screen"])
                self.reset()
                while True:
                    total_frames += 1
                    self._screen = obs.observation["screen"][5]
                    s = np.expand_dims(obs.observation["screen"][5], 0)
                    # plt.imshow(s[5])
                    # plt.pause(0.00001)
                    if max_frames and total_frames >= max_frames:
                        print("max frames reached")
                        return
                    if obs.last():
                        print("total frames:", total_frames,
                              "Epsilon:", self._epsilon.value())
                        self._epsilon.increment()
                        break
                    action = self.get_action(s)
                    env_actions = self.get_env_action(action, obs)
                    obs = env.step([env_actions])[0]

                    r = obs.reward
                    s1 = np.expand_dims(obs.observation["screen"][5], 0)
                    done = r > 0
                    if self._epsilon.isTraining:
                        transition = Transition(s, action, s1, r, done)
                        self._memory.push(transition)

                    if (total_frames % self.train_q_per_step == 0
                            and total_frames > self.steps_before_training
                            and self._epsilon.isTraining):
                        self.train_q()

                    if (total_frames % self.target_q_update_frequency == 0
                            and total_frames > self.steps_before_training
                            and self._epsilon.isTraining):
                        self._Qt = copy.deepcopy(self._Q)
                        self.show_chart()

                    if (total_frames % 1000 == 0
                            and total_frames > self.steps_before_training
                            and self._epsilon.isTraining):
                        self.show_chart()

                    if not self._epsilon.isTraining and total_frames % 3 == 0:
                        self.show_chart()
        except KeyboardInterrupt:
            pass
        finally:
            print("finished")
            elapsed_time = time.time() - start_time
            print("Took %.3f seconds for %s steps: %.3f fps" % (
                elapsed_time, total_frames, total_frames / elapsed_time))

    def get_reward(self, s):
        player_relative = s[_PLAYER_RELATIVE]
        neutral_y, neutral_x = (player_relative == _PLAYER_NEUTRAL).nonzero()
        neutral_target = [int(neutral_x.mean()), int(neutral_y.mean())]

        friendly_y, friendly_x = (player_relative == _PLAYER_FRIENDLY).nonzero()
        if len(friendly_y) == 0 or len(friendly_x) == 0:
            # no friendly units visible
            return 0
        friendly_target = [int(friendly_x.mean()), int(friendly_y.mean())]

        distance_2 = ((neutral_target[0] - friendly_target[0]) ** 2
                      + (neutral_target[1] - friendly_target[1]) ** 2)
        distance = math.sqrt(distance_2)
        return -distance

    def show_chart(self):
        self._plot[0].clear()
        self._plot[0].set_xlabel('Last 1000 Training Cycles')
        self._plot[0].set_ylabel('Loss')
        self._plot[0].plot(list(self._loss))

        self._plot[1].clear()
        self._plot[1].set_xlabel('Last 1000 Training Cycles')
        self._plot[1].set_ylabel('Max Q')
        self._plot[1].plot(list(self._max_q))

        self._plot[2].clear()
        self._plot[2].set_title("screen")
        self._plot[2].imshow(self._screen)

        self._plot[3].clear()
        self._plot[3].set_title("action")
        self._plot[3].imshow(self._action)
        plt.pause(0.00001)

    def train_q(self):
        if self.train_q_batch_size >= len(self._memory):
            return

        s, a, s_1, r, done = self._memory.sample(self.train_q_batch_size)

        s = Variable(torch.from_numpy(s).cuda()).float()
        a = Variable(torch.from_numpy(a).cuda()).long()
        s_1 = Variable(torch.from_numpy(s_1).cuda(), volatile=True).float()
        r = Variable(torch.from_numpy(r).cuda()).float()
        done = Variable(torch.from_numpy(1 - done).cuda()).float()

        # Q_sa = r + gamma * max(Q_s'a')
        Q = self._Q(s)
        Q = Q.view(self.train_q_batch_size, -1)
        Q = Q.gather(1, a)

        Qt = self._Qt(s_1).view(self.train_q_batch_size, -1)

        # double Q: the online net selects the action, the target net evaluates it
        best_action = self._Q(s_1).view(self.train_q_batch_size, -1).max(dim=1, keepdim=True)[1]
        y = r + done * self.gamma * Qt.gather(1, best_action)

        # vanilla Q target:
        # y = r + done * self.gamma * Qt.max(dim=1)[0].unsqueeze(1)

        y.volatile = False
        loss = self._criterion(Q, y)
        self._loss.append(loss.sum().cpu().data.numpy())
        self._max_q.append(Q.max().cpu().data.numpy()[0])

        self._optimizer.zero_grad()  # zero the gradient buffers
        loss.backward()
        self._optimizer.step()
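# `train_q` above uses PyTorch 0.3-era `Variable(..., volatile=True)` semantics and
# computes the double Q-learning target: the online network `_Q` picks the argmax
# action for s', and the target network `_Qt` evaluates it. The sketch below shows
# the same target computation with current autograd idioms (`torch.no_grad()`
# instead of the volatile flag); the function name and argument layout are
# illustrative assumptions, not part of the class above.
import torch


def double_q_target(q_online, q_target, s_1, r, not_done, gamma):
    """Double Q-learning target; r and not_done are (batch, 1) column tensors."""
    with torch.no_grad():  # replaces volatile=True / y.volatile = False
        q_online_next = q_online(s_1).view(s_1.size(0), -1)
        q_target_next = q_target(s_1).view(s_1.size(0), -1)
        # Online net selects the action, target net evaluates it.
        best_action = q_online_next.max(dim=1, keepdim=True)[1]
        return r + not_done * gamma * q_target_next.gather(1, best_action)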
class TestReplayMemory(unittest.TestCase):
    def setUp(self):
        self.memory = ReplayMemory(capacity=10)

    def test_append(self):
        for i in range(20):
            a = Transition([0, 1, 2, 3], 0, [4, 5, 6, 7], 0, True)
            self.memory.push(a)
        self.assertEqual(len(self.memory.memory), 10)

    def test_sample(self):
        for i in range(10):
            a = Transition([0, 1, 2, i], 0, [4, 5, 6, i * i], 0, True)
            self.memory.push(a)
        s, a, s1, r, done = self.memory.sample(2)
        self.assertEqual(s.shape, (2, 4))
        self.assertEqual(a.shape, (2, 1))
        self.assertEqual(s1.shape, (2, 4))
        self.assertEqual(r.shape, (2, 1))
        self.assertEqual(done.shape, (2, 1))

    def test_multi_step(self):
        self.memory = ReplayMemory(capacity=10, multi_step_n=2)
        for i in range(5):
            a = Transition([0, 1, 2, i], 0, [4, 5, 6, i * i], 1, False)
            self.memory.push(a)
        final = Transition([0, 1, 2, 10], 0, [4, 5, 6, 100], 10, True)
        self.memory.push(final)
        self.assertEqual(self.memory.memory[0].r, 2.9701)
        self.assertEqual(self.memory.memory[3].r, 11.791)
        self.assertEqual(self.memory.memory[4].r, 10.9)
        self.assertEqual(self.memory.memory[5].r, 10)

    def test_zero_step(self):
        self.memory = ReplayMemory(capacity=10, multi_step_n=0)
        for i in range(5):
            a = Transition([0, 1, 2, i], 0, [4, 5, 6, i * i], 1, False)
            self.memory.push(a)
        final = Transition([0, 1, 2, 10], 0, [4, 5, 6, 100], 10, True)
        self.memory.push(final)
        self.assertEqual(self.memory.memory[0].r, 1)
        self.assertEqual(self.memory.memory[3].r, 1)
        self.assertEqual(self.memory.memory[4].r, 1)
        self.assertEqual(self.memory.memory[5].r, 10)
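# The multi-step tests above expect `push` to fold each new reward back into the
# previous `multi_step_n` transitions, discounted by gamma = 0.99 and truncated at
# episode boundaries (e.g. memory[0].r becomes 1 + 0.99 + 0.99**2 = 2.9701). Below
# is a sketch of a buffer that satisfies those expectations; the implementation
# under test may differ, and `sample` (which the shape tests rely on) is omitted.
from collections import deque, namedtuple

Transition = namedtuple('Transition', ['s', 'a', 's_1', 'r', 'done'])


class ReplayMemory:
    """Bounded replay buffer with optional n-step reward folding (sketch)."""

    def __init__(self, capacity, multi_step_n=0, gamma=0.99):
        self.memory = deque(maxlen=capacity)
        self.multi_step_n = multi_step_n
        self.gamma = gamma

    def push(self, transition):
        # Fold this transition's reward into the previous n entries, discounted
        # by distance in time, stopping at an episode boundary.
        factor = self.gamma
        for i in range(1, self.multi_step_n + 1):
            if i > len(self.memory):
                break
            prev = self.memory[-i]
            if prev.done:
                break
            self.memory[-i] = prev._replace(r=prev.r + factor * transition.r)
            factor *= self.gamma
        self.memory.append(transition)

    def __len__(self):
        return len(self.memory)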
class NEC:
    def __init__(self, env, args, device='cpu'):
        """
        Instantiate an NEC Agent
        ----------
        env: gym.Env
            gym environment to train on
        args: args class from argparser
            args are from train.py: see train.py for help with each arg
        device: string
            'cpu' or 'cuda:0' depending on the use_cuda flag from train.py
        """
        self.environment_type = args.environment_type
        self.env = env
        self.device = device
        # Hyperparameters
        self.epsilon = args.initial_epsilon
        self.final_epsilon = args.final_epsilon
        self.epsilon_decay = args.epsilon_decay
        self.gamma = args.gamma
        self.N = args.N
        # Transition queue and replay memory
        self.transition_queue = []
        self.replay_every = args.replay_every
        self.replay_buffer_size = args.replay_buffer_size
        self.replay_memory = ReplayMemory(self.replay_buffer_size)
        # CNN for state embedding network
        self.frames_to_stack = args.frames_to_stack
        self.embedding_size = args.embedding_size
        self.in_height = args.in_height
        self.in_width = args.in_width
        self.cnn = CNN(self.frames_to_stack, self.embedding_size,
                       self.in_height, self.in_width).to(self.device)
        # Differentiable Neural Dictionary (DND): one for each action
        self.kernel = inverse_distance
        self.num_neighbors = args.num_neighbors
        self.max_memory = args.max_memory
        self.lr = args.lr
        self.dnd_list = []
        for i in range(env.action_space.n):
            self.dnd_list.append(
                DND(self.kernel, self.num_neighbors, self.max_memory,
                    args.optimizer, self.lr))
        # Optimizer for state embedding CNN
        self.q_lr = args.q_lr
        self.batch_size = args.batch_size
        self.optimizer = get_optimizer(args.optimizer,
                                       self.cnn.parameters(), self.lr)

    def choose_action(self, state_embedding):
        """
        Choose an epsilon-greedy action according to the Q-estimates from the DNDs
        """
        if random.uniform(0, 1) < self.epsilon:
            return random.randint(0, self.env.action_space.n - 1)
        else:
            qs = [dnd.lookup(state_embedding) for dnd in self.dnd_list]
            action = torch.argmax(torch.cat(qs))
            return action

    def Q_lookahead(self, t, warmup=False):
        """
        Return the N-step Q-value lookahead from time t in the transition queue
        """
        if warmup or len(self.transition_queue) <= t + self.N:
            lookahead = [tr.reward for tr in self.transition_queue[t:]]
            discounted = discount(lookahead, self.gamma)
            Q_N = torch.tensor([discounted], requires_grad=True)
            return Q_N
        else:
            lookahead = [
                tr.reward for tr in self.transition_queue[t:t + self.N]
            ]
            discounted = discount(lookahead, self.gamma)
            state = self.transition_queue[t + self.N].state
            state = torch.tensor(state).permute(2, 0, 1).unsqueeze(0)  # (N,C,H,W)
            state = state.to(self.device)
            state_embedding = self.cnn(state)
            Q_a = [dnd.lookup(state_embedding) for dnd in self.dnd_list]
            maxQ = torch.cat(Q_a).max()
            Q_N = discounted + (self.gamma ** self.N) * maxQ
            Q_N = torch.tensor([Q_N], requires_grad=True)
            return Q_N

    def Q_update(self, Q, Q_N):
        """
        Return the Q-update for DND updates
        """
        return Q + self.q_lr * (Q_N - Q)

    def update(self):
        """
        Iterate through the transition queue and make NEC updates
        """
        # Insert transitions into DNDs
        for t in range(len(self.transition_queue)):
            tr = self.transition_queue[t]
            action = tr.action
            state = torch.tensor(tr.state).permute(2, 0, 1)  # (C,H,W)
            state = state.unsqueeze(0).to(self.device)  # (N,C,H,W)
            state_embedding = self.cnn(state)
            dnd = self.dnd_list[action]

            Q_N = self.Q_lookahead(t).to(self.device)
            embedding_index = dnd.get_index(state_embedding)
            if embedding_index is None:
                dnd.insert(state_embedding.detach(), Q_N.detach().unsqueeze(0))
            else:
                Q = self.Q_update(dnd.values[embedding_index], Q_N)
                dnd.update(Q.detach(), embedding_index)
            Q_N = Q_N.detach().to(self.device)
            self.replay_memory.push(tr.state, action, Q_N)
        # Commit inserts
        for dnd in self.dnd_list:
            dnd.commit_insert()
        # Train CNN on minibatch
        for t in range(len(self.transition_queue)):
            if t % self.replay_every == 0 or t == len(self.transition_queue) - 1:
                # Train on a random mini-batch from self.replay_memory
                batch = self.replay_memory.sample(self.batch_size)
                actual_Qs = torch.cat([sample.Q_N for sample in batch])
                predicted_Qs = []
                for sample in batch:
                    state = torch.tensor(sample.state).permute(2, 0, 1)  # (C,H,W)
                    state = state.unsqueeze(0).to(self.device)  # (N,C,H,W)
                    state_embedding = self.cnn(state)
                    dnd = self.dnd_list[sample.action]
                    predicted_Q = dnd.lookup(state_embedding, update_flag=True)
                    predicted_Qs.append(predicted_Q)
                predicted_Qs = torch.cat(predicted_Qs).to(self.device)
                loss = torch.dist(actual_Qs, predicted_Qs)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                for dnd in self.dnd_list:
                    dnd.update_params()
        # Clear out the transition queue
        self.transition_queue = []

    def run_episode(self):
        """
        Train an NEC agent for a single episode:
            Interact with the environment
            Append (state, action, reward) transitions to the transition queue
            Call update at the end of the episode
        """
        if self.epsilon > self.final_epsilon:
            self.epsilon = self.epsilon * self.epsilon_decay
        state = self.env.reset()
        if self.environment_type == 'fourrooms':
            fewest_steps = self.env.shortest_path_length(self.env.state)
        total_steps = 0
        total_reward = 0
        total_frames = 0
        done = False
        while not done:
            state_embedding = torch.tensor(state).permute(2, 0, 1)  # (C,H,W)
            state_embedding = state_embedding.unsqueeze(0).to(self.device)
            state_embedding = self.cnn(state_embedding)
            action = self.choose_action(state_embedding)
            next_state, reward, done, _ = self.env.step(action)
            self.transition_queue.append(Transition(state, action, reward))
            total_reward += reward
            total_frames += self.env.skip
            total_steps += 1
            state = next_state
        self.update()
        if self.environment_type == 'fourrooms':
            n_extra_steps = total_steps - fewest_steps
            return n_extra_steps, total_frames, total_reward
        else:
            return total_frames, total_reward

    def warmup(self):
        """
        Warm up the DNDs with values from an episode run under a random policy
        """
        state = self.env.reset()
        total_reward = 0
        total_frames = 0
        done = False
        while not done:
            action = random.randint(0, self.env.action_space.n - 1)
            next_state, reward, done, _ = self.env.step(action)
            total_reward += reward
            total_frames += self.env.skip
            self.transition_queue.append(Transition(state, action, reward))
            state = next_state

        for t in range(len(self.transition_queue)):
            tr = self.transition_queue[t]
            state_embedding = torch.tensor(tr.state).permute(2, 0, 1)  # (C,H,W)
            state_embedding = state_embedding.unsqueeze(0).to(self.device)
            state_embedding = self.cnn(state_embedding)
            action = tr.action
            dnd = self.dnd_list[action]

            Q_N = self.Q_lookahead(t, True).to(self.device)
            if dnd.keys_to_be_inserted is None and dnd.keys is None:
                dnd.insert(state_embedding, Q_N.detach().unsqueeze(0))
            else:
                embedding_index = dnd.get_index(state_embedding)
                if embedding_index is None:
                    state_embedding = state_embedding.detach()
                    dnd.insert(state_embedding, Q_N.detach().unsqueeze(0))
                else:
                    Q = self.Q_update(dnd.values[embedding_index], Q_N)
                    dnd.update(Q.detach(), embedding_index)
            self.replay_memory.push(tr.state, action, Q_N.detach())
        for dnd in self.dnd_list:
            dnd.commit_insert()
        # Clear out the transition queue
        self.transition_queue = []
        return total_frames, total_reward
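# The NEC agent above depends on two helpers defined elsewhere: the
# `inverse_distance` kernel used to weight DND neighbors and the `discount`
# function used in `Q_lookahead`. Plausible sketches consistent with how they are
# called above follow; the exact signatures in the accompanying modules may differ.
import torch


def inverse_distance(h, h_i, delta=1e-3):
    # Kernel weight for a DND neighbor: closer embeddings get a larger weight,
    # with delta keeping the denominator away from zero.
    return 1.0 / (torch.dist(h, h_i) ** 2 + delta)


def discount(rewards, gamma):
    # Discounted sum of a finite reward list: r_0 + gamma*r_1 + gamma^2*r_2 + ...
    total, factor = 0.0, 1.0
    for r in rewards:
        total += factor * r
        factor *= gamma
    return total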
class DQN_Agent():
    '''
    Regular Q-learning agent with a single deep network (DQN) that predicts the
    Q-value of each action in a given state, i.e. both Q(s, a) and Q(s', a')
    used in the loss calculation.
    '''

    def __init__(self, state_size, n_actions, args,
                 device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        self.device = device

        # Exploration / exploitation params.
        self.steps_done = 0
        self.eps_threshold = 1
        self.eps_start = args.eps_start
        self.eps_end = args.eps_end
        self.eps_decay = args.eps_decay

        # RL params
        self.target_update = args.target_update
        self.discount = args.discount

        # Env params
        self.n_actions = n_actions
        self.state_size = state_size

        # Deep Q-network params
        self.layers = args.layers
        self.batch_size = args.batch_size
        self.policy_net = DQN(state_size, n_actions,
                              layers=self.layers).to(self.device).float()
        self.target_net = None
        self.grad_clip = args.grad_clip

        if str(args.optimizer).lower() == 'adam':
            self.optimizer = optim.Adam(self.policy_net.parameters())
        elif str(args.optimizer).lower() == 'rmsprop':
            self.optimizer = optim.RMSprop(self.policy_net.parameters())
        else:
            raise NotImplementedError

        self.memory = ReplayMemory(args.replay_size)

        # Performance buffers.
        self.rewards_list = []

    def add_to_memory(self, state, action, next_state, reward):
        self.rewards_list.append(reward)
        state = torch.from_numpy(state).float()
        action = torch.tensor([action])
        next_state = torch.from_numpy(next_state).float()
        reward = torch.tensor([reward])
        self.memory.push(state, action, next_state, reward)

    def select_action(self, state):
        sample = random.random()
        self.eps_threshold = self.eps_end + (self.eps_start - self.eps_end) * \
            math.exp(-1. * self.steps_done / self.eps_decay)
        self.steps_done += 1
        if sample > self.eps_threshold:
            with torch.no_grad():
                # t.max(1) returns the largest column value of each row; the
                # second element of the result is the index of the max, so we
                # pick the action with the larger expected reward.
                state = torch.from_numpy(state).float().to(self.device)  # Convert to tensor.
                state = state.unsqueeze(0)  # Add batch dimension.
                # Return the action index as a plain int (matches the explore branch).
                return self.policy_net(state).max(1)[1].item()
        else:
            return torch.tensor([[random.randrange(self.n_actions)]],
                                device=self.device, dtype=torch.long).item()

    def optimize_model(self):
        if len(self.memory) < self.batch_size:
            return
        transitions = self.memory.sample(self.batch_size)
        # This converts a batch-array of Transitions to a Transition of batch-arrays.
        batch = Transition(*zip(*transitions))

        next_states_batch = torch.cat(batch.next_state).view(
            self.batch_size, -1).to(self.device)
        state_batch = torch.cat(batch.state).view(self.batch_size, -1).to(self.device)
        action_batch = torch.cat(batch.action).view(self.batch_size, -1).to(self.device)
        reward_batch = torch.cat(batch.reward).view(self.batch_size, -1).to(self.device)

        # Compute loss
        loss = self._compute_loss(state_batch, action_batch,
                                  next_states_batch, reward_batch)

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()

        # Clip gradients if configured
        if self.grad_clip is not None:
            for param in self.policy_net.parameters():
                param.grad.data.clamp_(-self.grad_clip, self.grad_clip)

        # Update policy net weights
        self.optimizer.step()

        # Update target net weights
        self._update_target()

    def _compute_loss(self, state_batch, action_batch, next_states_batch,
                      reward_batch):
        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of the actions taken. These are the actions which would've
        # been taken for each batch state according to policy_net.
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # Compute V(s_{t+1}) for all next states using the same policy net.
        next_state_values = torch.zeros(self.batch_size, device=self.device)
        next_state_values = self.policy_net(next_states_batch).max(1)[0].detach()

        # Compute the expected Q values
        expected_state_action_values = (next_state_values.unsqueeze(1) *
                                        self.discount) + reward_batch

        # Compute Huber loss
        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)
        return loss

    def _update_target(self):
        if self.target_net is None:
            # There is nothing to update.
            return
        # Update the target network, copying all weights and biases in the DQN.
        if self.target_update > 1:
            # Hard copy of weights.
            if self.steps_done % self.target_update == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())
            return
        elif self.target_update < 1 and self.target_update > 0:
            # Polyak averaging:
            tau = self.target_update
            for target_param, param in zip(self.target_net.parameters(),
                                           self.policy_net.parameters()):
                target_param.data.copy_(tau * param + (1 - tau) * target_param)
            return
        else:
            raise NotImplementedError

    def save_ckpt(self, ckpt_folder):
        '''
        Saves a checkpoint of the policy net in ckpt_folder.
        :param ckpt_folder: path to a folder.
        '''
        ckpt_path = os.path.join(ckpt_folder, 'policy_net_state_dict.pth')
        torch.save(self.policy_net.state_dict(), ckpt_path)
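# A hypothetical driver loop for the agent above, shown for context only. The gym
# environment, the `Namespace` hyperparameters, and the episode count are
# placeholder assumptions; they mirror the fields read in `__init__` but are not
# taken from the original training script.
from argparse import Namespace

import gym

args = Namespace(eps_start=0.9, eps_end=0.05, eps_decay=200, target_update=0,
                 discount=0.99, layers=[64, 64], batch_size=128,
                 optimizer='adam', grad_clip=None, replay_size=10000)

env = gym.make("CartPole-v1")
agent = DQN_Agent(state_size=env.observation_space.shape[0],
                  n_actions=env.action_space.n,
                  args=args)

for episode in range(500):
    state = env.reset()
    done = False
    while not done:
        action = agent.select_action(state)            # epsilon-greedy action index
        next_state, reward, done, _ = env.step(action)
        agent.add_to_memory(state, action, next_state, reward)
        agent.optimize_model()                         # skipped until memory >= batch_size
        state = next_state

agent.save_ckpt("./checkpoints")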