class DQN(Base_Agent):
    """A deep Q learning agent"""
    agent_name = "DQN"

    def __init__(self, config):
        Base_Agent.__init__(self, config)
        model_path = self.config.model_path if self.config.model_path else 'Models'
        self.memory = Replay_Buffer(self.hyperparameters["buffer_size"],
                                    self.hyperparameters["batch_size"], config.seed)
        self.q_network_local = self.create_NN(input_dim=self.state_size,
                                              output_dim=self.action_size)
        self.q_network_local_path = os.path.join(
            model_path, "{}_q_network_local.pt".format(self.agent_name))
        if self.config.load_model: self.locally_load_policy()
        self.q_network_optimizer = optim.Adam(self.q_network_local.parameters(),
                                              lr=self.hyperparameters["learning_rate"],
                                              eps=1e-4)
        self.exploration_strategy = Epsilon_Greedy_Exploration(config)

    def reset_game(self):
        super(DQN, self).reset_game()
        self.update_learning_rate(self.hyperparameters["learning_rate"],
                                  self.q_network_optimizer)

    def step(self):
        """Runs a step within a game including a learning step if required"""
        while not self.done:
            self.action = self.pick_action()
            self.conduct_action(self.action)
            if self.time_for_q_network_to_learn():
                for _ in range(self.hyperparameters["learning_iterations"]):
                    self.learn()
            self.save_experience()
            self.state = self.next_state  # this sets the state for the next iteration
            self.global_step_number += 1
        self.episode_number += 1

    def pick_action(self, state=None):
        """Uses the local Q network and an epsilon greedy policy to pick an action"""
        # PyTorch only accepts mini-batches and not single observations, so we use
        # unsqueeze to add a "fake" dimension and turn the observation into a mini-batch
        if state is None: state = self.state
        if isinstance(state, np.int64) or isinstance(state, int): state = np.array([state])
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        if len(state.shape) < 2: state = state.unsqueeze(0)
        self.q_network_local.eval()  # puts network in evaluation mode
        with torch.no_grad():
            action_values = self.q_network_local(state)
        self.q_network_local.train()  # puts network back in training mode
        action = self.exploration_strategy.perturb_action_for_exploration_purposes({
            "action_values": action_values,
            "turn_off_exploration": self.turn_off_exploration,
            "episode_number": self.episode_number
        })
        self.logger.info("Q values {} -- Action chosen {}".format(action_values, action))
        return action

    def learn(self, experiences=None):
        """Runs a learning iteration for the Q network"""
        if experiences is None:
            states, actions, rewards, next_states, dones = self.sample_experiences()  # sample experiences
        else:
            states, actions, rewards, next_states, dones = experiences
        loss = self.compute_loss(states, next_states, rewards, actions, dones)
        actions_list = [action_X.item() for action_X in actions]
        self.logger.info("Action counts {}".format(Counter(actions_list)))
        self.take_optimisation_step(self.q_network_optimizer, self.q_network_local, loss,
                                    self.hyperparameters["gradient_clipping_norm"])

    def compute_loss(self, states, next_states, rewards, actions, dones):
        """Computes the loss required to train the Q network"""
        with torch.no_grad():
            Q_targets = self.compute_q_targets(next_states, rewards, dones)
        Q_expected = self.compute_expected_q_values(states, actions)
        loss = F.mse_loss(Q_expected, Q_targets)
        return loss

    def compute_q_targets(self, next_states, rewards, dones):
        """Computes the q_targets we will compare to predicted q values to create the loss to train the Q network"""
        Q_targets_next = self.compute_q_values_for_next_states(next_states)
        Q_targets = self.compute_q_values_for_current_states(rewards, Q_targets_next, dones)
        return Q_targets

    def compute_q_values_for_next_states(self, next_states):
        """Computes the q_values for the next states we will use to create the loss to train the Q network"""
        Q_targets_next = self.q_network_local(next_states).detach().max(1)[0].unsqueeze(1)
        return Q_targets_next

    def compute_q_values_for_current_states(self, rewards, Q_targets_next, dones):
        """Computes the q_values for the current states we will use to create the loss to train the Q network"""
        Q_targets_current = rewards + (self.hyperparameters["discount_rate"] *
                                       Q_targets_next * (1 - dones))
        return Q_targets_current

    def compute_expected_q_values(self, states, actions):
        """Computes the expected q_values we will use to create the loss to train the Q network"""
        Q_expected = self.q_network_local(states).gather(
            1, actions.long())  # actions must be converted to long so they can be used as an index
        return Q_expected

    def time_for_q_network_to_learn(self):
        """Returns a boolean indicating whether enough steps have been taken for learning to begin
        and there are enough experiences in the replay buffer to learn from"""
        return self.right_amount_of_steps_taken() and self.enough_experiences_to_learn_from()

    def right_amount_of_steps_taken(self):
        """Returns a boolean indicating whether enough steps have been taken for learning to begin"""
        return self.global_step_number % self.hyperparameters["update_every_n_steps"] == 0

    def sample_experiences(self):
        """Draws a random sample of experience from the memory buffer"""
        experiences = self.memory.sample()
        states, actions, rewards, next_states, dones = experiences
        return states, actions, rewards, next_states, dones

    def locally_save_policy(self):
        """Saves the policy"""
        torch.save(self.q_network_local.state_dict(), self.q_network_local_path)

    def locally_load_policy(self):
        """Loads the policy from disk if a saved checkpoint exists"""
        print("locally_load_policy")
        if os.path.isfile(self.q_network_local_path):
            try:
                self.q_network_local.load_state_dict(torch.load(self.q_network_local_path))
                print("loaded q_network_local from {}".format(self.q_network_local_path))
            except Exception as e:
                self.logger.info("Failed to load q_network_local: {}".format(e))
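
# A minimal usage sketch for the agent above, in hedged form: `_Config` is a
# hypothetical stand-in for this repo's Config object, the hyperparameter values
# are illustrative only, and run_n_episodes is assumed to be the entry point
# inherited from Base_Agent elsewhere in this codebase.
def _dqn_usage_sketch():
    import gym

    class _Config:  # hypothetical stand-in for the repo's Config container
        pass

    config = _Config()
    config.seed = 1
    config.environment = gym.make("CartPole-v0")
    config.model_path = "Models"
    config.load_model = False
    config.hyperparameters = {
        "buffer_size": 40000, "batch_size": 256, "learning_rate": 0.01,
        "discount_rate": 0.99, "update_every_n_steps": 1,
        "learning_iterations": 1, "gradient_clipping_norm": 0.7,
    }
    agent = DQN(config)
    agent.run_n_episodes()  # assumed Base_Agent entry point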
class DDQN_Wrapper(Base_Agent):
    def __init__(self, config, global_action_id_to_primitive_actions,
                 action_length_reward_bonus, end_of_episode_symbol="/"):
        super().__init__(config)
        self.end_of_episode_symbol = end_of_episode_symbol
        self.global_action_id_to_primitive_actions = global_action_id_to_primitive_actions
        self.memory = Replay_Buffer(self.hyperparameters["buffer_size"],
                                    self.hyperparameters["batch_size"], config.seed)
        self.exploration_strategy = Epsilon_Greedy_Exploration(config)
        self.oracle = self.create_oracle()
        self.oracle_optimizer = optim.Adam(self.oracle.parameters(),
                                           lr=self.hyperparameters["learning_rate"])
        self.q_network_local = self.create_NN(input_dim=self.state_size + 1,
                                              output_dim=self.action_size)
        self.q_network_local.print_model_summary()
        self.q_network_optimizer = optim.Adam(self.q_network_local.parameters(),
                                              lr=self.hyperparameters["learning_rate"])
        self.q_network_target = self.create_NN(input_dim=self.state_size + 1,
                                               output_dim=self.action_size)
        Base_Agent.copy_model_over(from_model=self.q_network_local,
                                   to_model=self.q_network_target)
        self.action_length_reward_bonus = action_length_reward_bonus
        self.abandon_ship = config.hyperparameters["abandon_ship"]

    def create_oracle(self):
        """Creates the network we will use to predict the next state"""
        oracle_hyperparameters = copy.deepcopy(self.hyperparameters)
        oracle_hyperparameters["columns_of_data_to_be_embedded"] = []
        oracle_hyperparameters["embedding_dimensions"] = []
        oracle_hyperparameters["linear_hidden_units"] = [5, 5]
        oracle_hyperparameters["final_layer_activation"] = [None, "tanh"]
        oracle = self.create_NN(input_dim=self.state_size + 2,
                                output_dim=[self.state_size + 1, 1],
                                hyperparameters=oracle_hyperparameters)
        oracle.print_model_summary()
        return oracle

    def run_n_episodes(self, num_episodes, episodes_to_run_with_no_exploration):
        self.turn_on_any_epsilon_greedy_exploration()
        self.round_of_macro_actions = []
        self.episode_actions_scores_and_exploration_status = []
        num_episodes_to_get_to = self.episode_number + num_episodes
        while self.episode_number < num_episodes_to_get_to:
            self.reset_game()
            self.step()
            self.save_and_print_result()
            if num_episodes_to_get_to - self.episode_number == episodes_to_run_with_no_exploration:
                self.turn_off_any_epsilon_greedy_exploration()
        assert len(self.episode_actions_scores_and_exploration_status) == num_episodes, \
            "{} vs. {}".format(len(self.episode_actions_scores_and_exploration_status), num_episodes)
        assert len(self.episode_actions_scores_and_exploration_status[0]) == 3
        assert self.episode_actions_scores_and_exploration_status[0][2] in [True, False]
        assert isinstance(self.episode_actions_scores_and_exploration_status[0][1], list)
        assert isinstance(self.episode_actions_scores_and_exploration_status[0][1][0], int)
        assert isinstance(self.episode_actions_scores_and_exploration_status[0][0], (int, float))
        return self.episode_actions_scores_and_exploration_status, self.round_of_macro_actions

    def step(self):
        """Runs a step within a game including a learning step if required"""
        step_number = 0.0
        self.state = np.append(self.state, step_number / 200.0)  # divide by 200 because there are 200 steps in cart pole
        self.total_episode_score_so_far = 0
        episode_macro_actions = []
        while not self.done:
            surprised = False
            macro_action = self.pick_action()
            primitive_actions = self.global_action_id_to_primitive_actions[macro_action]
            primitive_actions_conducted = 0
            for ix, action in enumerate(primitive_actions):
                if self.abandon_ship and primitive_actions_conducted > 0:
                    if self.abandon_macro_action(action):
                        break
                step_number += 1
                self.action = action
                self.next_state, self.reward, self.done, _ = self.environment.step(action)
                self.next_state = np.append(self.next_state, step_number / 200.0)  # divide by 200 because there are 200 steps in cart pole
                self.total_episode_score_so_far += self.reward
                if self.hyperparameters["clip_rewards"]:
                    self.reward = max(min(self.reward, 1.0), -1.0)
                primitive_actions_conducted += 1
                self.track_episodes_data()
                self.save_experience()
                if len(primitive_actions) > 1:
                    surprised = self.am_i_surprised()
                self.state = self.next_state
                if self.time_for_q_network_to_learn():
                    for _ in range(self.hyperparameters["learning_iterations"]):
                        self.q_network_learn()
                        self.oracle_learn()
                if self.done or surprised: break
            episode_macro_actions.append(macro_action)
            self.round_of_macro_actions.append(macro_action)
        if random.random() < 0.1: print(Counter(episode_macro_actions))
        self.save_episode_actions_with_score()
        self.episode_number += 1
        self.logger.info("END OF EPISODE")

    def am_i_surprised(self):
        """Returns a boolean indicating whether the next_state was a surprise or not"""
        with torch.no_grad():
            state = torch.from_numpy(self.state).float().unsqueeze(0).to(self.device)
            action = torch.Tensor([[self.action]])
            states_and_actions = torch.cat((state, action), dim=1)  # must change this for all games besides cart pole
            predictions = self.oracle(states_and_actions)
            predicted_next_state = predictions[0, :-1]
            difference = F.mse_loss(predicted_next_state, torch.Tensor(self.next_state))
            if difference > 0.5:
                print("Surprise! Loss {} -- {} vs. {}".format(difference, predicted_next_state,
                                                              self.next_state))
                return True
            else:
                return False

    def abandon_macro_action(self, action):
        """Returns a boolean indicating whether to abandon the macro action or not"""
        state = torch.from_numpy(self.state).float().unsqueeze(0).to(self.device)
        with torch.no_grad():
            primitive_q_values = self.calculate_q_values(state, local=True,
                                                         primitive_actions_only=True)
        q_value_highest = torch.max(primitive_q_values)
        q_values_action = primitive_q_values[:, action]
        if q_value_highest > 0.0: multiplier = 0.7
        else: multiplier = 1.3
        if q_values_action < multiplier * q_value_highest:
            print("BREAKING Action {} -- Q Values {}".format(action, primitive_q_values))
            return True
        else:
            return False

    def pick_action(self, state=None):
        """Uses the local Q network and an epsilon greedy policy to pick an action"""
        if state is None: state = self.state
        if isinstance(state, np.int64) or isinstance(state, int): state = np.array([state])
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        if len(state.shape) < 2: state = state.unsqueeze(0)
        self.q_network_local.eval()  # puts network in evaluation mode
        with torch.no_grad():
            action_values = self.calculate_q_values(state, local=True,
                                                    primitive_actions_only=False)
        self.q_network_local.train()  # puts network back in training mode
        action = self.exploration_strategy.perturb_action_for_exploration_purposes({
            "action_values": action_values,
            "turn_off_exploration": self.turn_off_exploration,
            "episode_number": self.episode_number
        })
        self.logger.info("Q values {} -- Action chosen {}".format(action_values, action))
        return action

    def calculate_q_values(self, states, local, primitive_actions_only):
        """Calculates the q values using the local q network"""
        if local:
            primitive_q_values = self.q_network_local(states)
        else:
            primitive_q_values = self.q_network_target(states)
        num_actions = len(self.global_action_id_to_primitive_actions)
        if primitive_actions_only or num_actions <= self.action_size:
            return primitive_q_values
        extra_q_values = self.calculate_macro_action_q_values(states, num_actions)
        extra_q_values = torch.Tensor([extra_q_values])
        all_q_values = torch.cat((primitive_q_values, extra_q_values), dim=1)
        return all_q_values

    def calculate_macro_action_q_values(self, state, num_actions):
        assert state.shape[0] == 1
        q_values = []
        for action_id in range(self.action_size, num_actions):
            macro_action = self.global_action_id_to_primitive_actions[action_id]
            predicted_next_state = state
            cumulated_reward = 0
            action_ix = 0
            for action in macro_action[:-1]:
                predictions = self.oracle(torch.cat((predicted_next_state,
                                                     torch.Tensor([[action]])), dim=1))
                rewards = predictions[:, -1]
                predicted_next_state = predictions[:, :-1]
                cumulated_reward += (rewards.item() + self.action_length_reward_bonus) * \
                                    self.hyperparameters["discount_rate"] ** action_ix
                action_ix += 1
            final_action = macro_action[-1]
            final_q_value = self.q_network_local(predicted_next_state)[0, final_action]
            total_q_value = cumulated_reward + final_q_value * \
                            self.hyperparameters["discount_rate"] ** action_ix
            q_values.append(total_q_value)
        return q_values

    def time_for_q_network_to_learn(self):
        """Returns a boolean indicating whether enough steps have been taken for learning to begin
        and there are enough experiences in the replay buffer to learn from"""
        return self.right_amount_of_steps_taken() and self.enough_experiences_to_learn_from()

    def right_amount_of_steps_taken(self):
        """Returns a boolean indicating whether enough steps have been taken for learning to begin"""
        return self.global_step_number % self.hyperparameters["update_every_n_steps"] == 0

    def q_network_learn(self, experiences=None):
        """Runs a learning iteration for the Q network"""
        if experiences is None:
            states, actions, rewards, next_states, dones = self.sample_experiences()  # sample experiences
        else:
            states, actions, rewards, next_states, dones = experiences
        loss = self.compute_loss(states, next_states, rewards, actions, dones)
        self.take_optimisation_step(self.q_network_optimizer, self.q_network_local, loss,
                                    self.hyperparameters["gradient_clipping_norm"])
        self.soft_update_of_target_network(self.q_network_local, self.q_network_target,
                                           self.hyperparameters["tau"])

    def sample_experiences(self):
        """Draws a random sample of experience from the memory buffer"""
        experiences = self.memory.sample()
        states, actions, rewards, next_states, dones = experiences
        return states, actions, rewards, next_states, dones

    def compute_loss(self, states, next_states, rewards, actions, dones):
        """Computes the loss required to train the Q network"""
        with torch.no_grad():
            max_action_indexes = self.calculate_q_values(
                next_states, local=True, primitive_actions_only=True).detach().argmax(1)
            Q_targets_next = self.calculate_q_values(
                next_states, local=False, primitive_actions_only=True).gather(
                    1, max_action_indexes.unsqueeze(1))
            Q_targets = rewards + (self.hyperparameters["discount_rate"] *
                                   Q_targets_next * (1 - dones))
        Q_expected = self.calculate_q_values(states, local=True,
                                             primitive_actions_only=True).gather(
            1, actions.long())  # actions must be converted to long so they can be used as an index
        loss = F.mse_loss(Q_expected, Q_targets)
        return loss

    def save_episode_actions_with_score(self):
        self.episode_actions_scores_and_exploration_status.append([
            self.total_episode_score_so_far,
            self.episode_actions + [self.end_of_episode_symbol],
            self.turn_off_exploration
        ])

    def oracle_learn(self):
        states, actions, rewards, next_states, _ = self.sample_experiences()  # sample experiences
        states_and_actions = torch.cat((states, actions), dim=1)  # must change this for all games besides cart pole
        predictions = self.oracle(states_and_actions)
        loss = F.mse_loss(torch.cat((next_states, rewards), dim=1),
                          predictions) / float(next_states.shape[1] + 1.0)
        self.take_optimisation_step(self.oracle_optimizer, self.oracle, loss,
                                    self.hyperparameters["gradient_clipping_norm"])
        self.logger.info("Oracle Loss {}".format(loss))
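
# A self-contained sketch of the double-DQN target computed in compute_loss
# above: the local network selects the greedy next action and the target network
# evaluates it, which reduces the overestimation bias of vanilla DQN. The toy
# linear networks, shapes, and values below are assumptions for illustration.
def _double_dqn_target_sketch():
    import torch
    import torch.nn as nn

    local, target = nn.Linear(4, 2), nn.Linear(4, 2)
    next_states = torch.randn(8, 4)
    rewards, dones = torch.randn(8, 1), torch.zeros(8, 1)
    discount_rate = 0.99
    with torch.no_grad():
        best_actions = local(next_states).argmax(1, keepdim=True)  # action selection: local net
        q_next = target(next_states).gather(1, best_actions)       # action evaluation: target net
        q_targets = rewards + discount_rate * q_next * (1 - dones)
    return q_targets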
class DQN(Base_Agent):
    """A deep Q learning agent"""
    agent_name = "DQN"

    def __init__(self, config):
        Base_Agent.__init__(self, config)
        self.memory = Replay_Buffer(self.hyperparameters["buffer_size"],
                                    self.hyperparameters["batch_size"], config.seed)
        self.q_network_local = self.create_NN(input_dim=self.state_size,
                                              output_dim=self.action_size)
        self.q_network_optimizer = optim.SGD(self.q_network_local.parameters(),
                                             lr=self.hyperparameters["learning_rate"],
                                             weight_decay=5e-4)
        self.exploration_strategy = Epsilon_Greedy_Exploration(config)

    def reset_game(self):
        super(DQN, self).reset_game()
        self.update_learning_rate(self.hyperparameters["learning_rate"],
                                  self.q_network_optimizer)

    def step(self):
        """Runs a step within a game including a learning step if required"""
        while not self.done:
            # print('state:', self.state)
            # self.environment.render()
            self.action = self.pick_action()
            self.conduct_action(self.action)
            if self.time_for_q_network_to_learn():
                for _ in range(self.hyperparameters["learning_iterations"]):
                    # pause the environment while learning if it supports it
                    if hasattr(self.environment, "pause"):
                        self.environment.pause()
                        self.learn()
                        self.environment.resume()
                    else:
                        self.learn()
            self.save_experience()
            self.state = self.next_state  # this sets the state for the next iteration
            self.global_step_number += 1
        self.episode_number += 1

    def pick_action(self, state=None):
        """Uses the local Q network and an epsilon greedy policy to pick an action"""
        # PyTorch only accepts mini-batches and not single observations, so we use
        # unsqueeze to add a "fake" dimension and turn the observation into a mini-batch
        if state is None: state = self.state
        if isinstance(state, np.int64) or isinstance(state, int): state = np.array([state])
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        if len(state.shape) < 2: state = state.unsqueeze(0)
        self.q_network_local.eval()  # puts network in evaluation mode
        with torch.no_grad():
            action_values = self.q_network_local(state)
        self.q_network_local.train()  # puts network back in training mode
        force_explore = self.config.force_explore_mode and self.need_to_force_explore()
        if force_explore: print('explore...')
        action = self.exploration_strategy.perturb_action_for_exploration_purposes({
            "action_values": action_values,
            "turn_off_exploration": self.turn_off_exploration,
            "episode_number": self.episode_number,
            "force_explore": force_explore
        })
        # self.logger.info("Q values {} -- Action chosen {}".format(action_values, action))
        return action

    def learn(self, experiences=None):
        """Runs a learning iteration for the Q network"""
        if experiences is None:
            states, actions, rewards, next_states, dones = self.sample_experiences()  # sample experiences
        else:
            states, actions, rewards, next_states, dones = experiences
        loss = self.compute_loss(states, next_states, rewards, actions, dones)
        actions_list = [action_X.item() for action_X in actions]
        self.logger.info("Action counts {}".format(Counter(actions_list)))
        self.take_optimisation_step(self.q_network_optimizer, self.q_network_local, loss,
                                    self.hyperparameters["gradient_clipping_norm"])

    def compute_loss(self, states, next_states, rewards, actions, dones):
        """Computes the loss required to train the Q network"""
        with torch.no_grad():
            Q_targets = self.compute_q_targets(next_states, rewards, dones)
        Q_expected = self.compute_expected_q_values(states, actions)
        # loss = F.mse_loss(Q_expected, Q_targets)
        loss = nn.MSELoss(reduction='sum')(Q_expected, Q_targets)  # size_average=False is deprecated; reduction='sum' is the equivalent
        return loss

    def compute_q_targets(self, next_states, rewards, dones):
        """Computes the q_targets we will compare to predicted q values to create the loss to train the Q network"""
        Q_targets_next = self.compute_q_values_for_next_states(next_states)
        Q_targets = self.compute_q_values_for_current_states(rewards, Q_targets_next, dones)
        return Q_targets

    def compute_q_values_for_next_states(self, next_states):
        """Computes the q_values for the next states we will use to create the loss to train the Q network"""
        Q_targets_next = self.q_network_local(next_states).detach().max(1)[0].unsqueeze(1)
        return Q_targets_next

    def compute_q_values_for_current_states(self, rewards, Q_targets_next, dones):
        """Computes the q_values for the current states we will use to create the loss to train the Q network"""
        Q_targets_current = rewards + (self.hyperparameters["discount_rate"] *
                                       Q_targets_next * (1 - dones))
        return Q_targets_current

    def compute_expected_q_values(self, states, actions):
        """Computes the expected q_values we will use to create the loss to train the Q network"""
        Q_expected = self.q_network_local(states).gather(
            1, actions.long())  # actions must be converted to long so they can be used as an index
        return Q_expected

    def time_for_q_network_to_learn(self):
        """Returns a boolean indicating whether enough steps have been taken for learning to begin
        and there are enough experiences in the replay buffer to learn from"""
        return self.right_amount_of_steps_taken() and self.enough_experiences_to_learn_from()

    def right_amount_of_steps_taken(self):
        """Returns a boolean indicating whether enough steps have been taken for learning to begin"""
        return self.global_step_number % self.hyperparameters["update_every_n_steps"] == 0

    def sample_experiences(self):
        """Draws a random sample of experience from the memory buffer"""
        experiences = self.memory.sample()
        states, actions, rewards, next_states, dones = experiences
        return states, actions, rewards, next_states, dones

    def locally_save_policy(self, best=True, episode=None):
        if self.agent_name != "DQN":
            state = {
                'episode': self.episode_number,
                'q_network_local': self.q_network_local.state_dict(),
                'q_network_target': self.q_network_target.state_dict()
            }
        else:
            state = {
                'episode': self.episode_number,
                'q_network_local': self.q_network_local.state_dict()
            }
        model_root = os.path.join('Models', self.config.env_title, self.agent_name,
                                  self.config.log_base)
        if not os.path.exists(model_root): os.makedirs(model_root)
        if best:
            last_best_file = glob.glob(os.path.join(model_root, 'rolling_score*'))
            if last_best_file: os.remove(last_best_file[0])
            save_name = model_root + "/rolling_score_%.4f.model" % (self.rolling_results[-1])
            torch.save(state, save_name)
            self.logger.info('Model-%s save success...' % (save_name))
        else:
            save_name = model_root + "/%s_%d.model" % (self.agent_name, self.episode_number)
            torch.save(state, save_name)
            self.logger.info('Model-%s save success...' % (save_name))

    def load_resume(self, resume_path):
        save = torch.load(resume_path)
        if self.agent_name != "DQN":
            self.q_network_local.load_state_dict(save['q_network_local'], strict=True)
            self.q_network_target.load_state_dict(save['q_network_target'], strict=True)
        else:
            self.q_network_local.load_state_dict(save['q_network_local'], strict=True)
        self.logger.info('load resume model success...')
        file_name = os.path.basename(resume_path)
        episode_str = re.findall(r"\d+\.?\d*", file_name)[0]
        episode_list = episode_str.split('.')
        # episode-numbered checkpoints like "DQN_123.model" match as "123." so the part
        # after the dot is empty; rolling-score checkpoints carry a float score rather
        # than an episode number, so we fall back to episode 0 for those
        if len(episode_list) > 1 and not episode_list[1]:
            episode = int(episode_list[0])
        else:
            episode = 0
        if not self.config.retrain:
            self.episode_number = episode
        else:
            self.episode_number = 0
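
# A short round-trip sketch of the checkpoint format used by locally_save_policy
# and load_resume above: a plain dict holding the episode counter and the network
# state_dict, written with torch.save. The toy network and the temporary path are
# assumptions for illustration.
def _checkpoint_roundtrip_sketch():
    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)
    state = {"episode": 42, "q_network_local": net.state_dict()}
    torch.save(state, "/tmp/DQN_42.model")  # the file name encodes the episode number

    restored = torch.load("/tmp/DQN_42.model")
    net.load_state_dict(restored["q_network_local"], strict=True)
    assert restored["episode"] == 42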
class DQN(Base_Agent):
    """A deep Q learning agent"""
    agent_name = "DQN"

    def __init__(self, config):
        Base_Agent.__init__(self, config)
        self.agent_dic = self.create_agent_dic()
        self.exploration_strategy = Epsilon_Greedy_Exploration(config)
        # self.environment.utils.visualize_gat_properties(self.config.GAT)
        # self.environment.utils.vis_intersec_id_embedding(agent_id='20953772', transform_func=self.get_intersection_id_embedding)

    def reset_game(self):
        super(DQN, self).reset_game()
        # self.update_learning_rate(self.hyperparameters["learning_rate"])

    def pick_action(self, states):
        """Uses the local Q network and an epsilon greedy policy to pick an action"""
        if len(states) == 0: return []
        states_batch = torch.vstack([state['embeding'] for state in states])
        network_states_batch = states_batch[:, self.intersection_id_size:]
        if self.config.does_need_network_state:
            if self.config.does_need_network_state_embeding:
                self.config.GAT.eval()
                with torch.no_grad():
                    network_state_embedings = self.config.GAT(network_states_batch).view(
                        states_batch.shape[0], -1, self.config.network_embed_size)
                self.config.GAT.train()
            else:
                network_state_embedings = network_states_batch.view(
                    states_batch.shape[0], -1, self.config.network_state_size)
        else:
            batch_size = states_batch.size()[0]
            network_size = self.config.network_state.size()[0]
            network_state_embedings = torch.empty(batch_size, network_size, 0).to(self.device)
        actions = []
        for state, network_state_embeding in zip(states, network_state_embedings):
            agent_id = self.get_agent_id(state)
            intersection_state_embeding = network_state_embeding[state['agent_idx']]
            destination_id = state['embeding'][0:self.intersection_id_size]
            destination_id_embeding = self.get_intersection_id_embedding(agent_id,
                                                                         destination_id,
                                                                         eval=True)
            embeding = torch.cat((destination_id_embeding, intersection_state_embeding), 0)
            action_values = self.get_action_values(agent_id, embeding.unsqueeze(0), eval=True)
            action_data = {
                "action_values": action_values,
                "state": state,
                "turn_off_exploration": self.turn_off_exploration,
                "episode_number": self.env_episode_number
            }
            action = self.exploration_strategy.perturb_action_for_exploration_purposes(action_data)
            self.logger.info("Q values {} -- Action chosen {}".format(action_values, action))
            actions.append(action)
        return actions

    def learn(self):
        """Runs a learning iteration for the Q network on each agent"""
        for _ in range(self.hyperparameters["learning_iterations"]):
            agents_losses = [self.compute_loss(agent_id) for agent_id in self.agent_dic
                             if self.time_for_q_network_to_learn(agent_id)]
            self.take_optimisation_step(agents_losses,
                                        self.hyperparameters["gradient_clipping_norm"],
                                        retain_graph=True)

    def compute_loss(self, agent_id):
        """Computes the loss required to train the Q network"""
        memory = self.agent_dic[agent_id]["memory"]
        states, actions, rewards, next_states, dones = self.sample_experiences(memory)  # sample experiences
        with torch.no_grad():
            Q_values_next_states = self.compute_q_values_for_next_states(next_states, dones)
            Q_targets = rewards + (self.hyperparameters["discount_rate"] *
                                   Q_values_next_states * (1 - dones))
        Q_expected = self.compute_expected_q_values(agent_id, states, actions)
        loss = F.mse_loss(Q_expected, Q_targets)
        return (agent_id, loss)

    def compute_q_values_for_next_states(self, next_states, dones):
        """Computes the q_values for the next states we will use to create the loss to train the Q network"""
        batch_size = dones.size()[0]
        Q_targets_next = torch.zeros(batch_size, 1).to(self.device)
        dummy_embed = None
        for state in next_states:  # find a dummy embeding to stand in for None states
            if state is not None:
                dummy_embed = state['embeding']
                break
        next_states_embedings = [state['embeding'] if state is not None else dummy_embed
                                 for state in next_states]
        next_states_embedings_batch = torch.vstack(next_states_embedings)
        network_states_batch = next_states_embedings_batch[:, self.intersection_id_size:]
        if self.config.does_need_network_state:
            if self.config.does_need_network_state_embeding:
                network_state_embeding_batch = self.config.GAT(network_states_batch)
            else:
                network_state_embeding_batch = network_states_batch.view(
                    batch_size, -1, self.config.network_state_size)
        else:
            network_size = self.config.network_state.size()[0]
            network_state_embeding_batch = torch.empty(batch_size, network_size, 0).to(self.device)
        masks_dic = {}
        for i in range(0, batch_size):
            if dones[i] == 1: continue
            agent_id = self.get_agent_id(next_states[i])
            if agent_id not in masks_dic:
                masks_dic[agent_id] = {}
                masks_dic[agent_id]["mask"] = [False] * batch_size
                masks_dic[agent_id]["batch_indexs"] = []
                masks_dic[agent_id]["network_index"] = next_states[i]['agent_idx']
            masks_dic[agent_id]["mask"][i] = True
            masks_dic[agent_id]["batch_indexs"].append(i)
        for agent_id in masks_dic:
            agent_mask = torch.Tensor(masks_dic[agent_id]["mask"]).unsqueeze(1).to(
                self.device, dtype=torch.bool)
            agent_states_action_mask = torch.vstack([
                next_states[i]['action_mask'] for i in masks_dic[agent_id]["batch_indexs"]
            ])
            destination_ids = next_states_embedings_batch[masks_dic[agent_id]["batch_indexs"],
                                                          0:self.intersection_id_size]
            destination_ids_embedings = self.get_intersection_id_embedding(agent_id,
                                                                           destination_ids)
            intersec_states_embeding = network_state_embeding_batch[
                masks_dic[agent_id]["batch_indexs"], masks_dic[agent_id]["network_index"]]
            agent_states_embedings = torch.cat((destination_ids_embedings,
                                                intersec_states_embeding), 1)
            # max(1): find the max in every row of the batch
            # max(0): find the max in every column of the batch
            # max(1)[0]: value of the max in every row of the batch
            # max(1)[1]: index of the max in every row of the batch
            agent_Q_targets_next = (self.agent_dic[agent_id]["policy"](agent_states_embedings) +
                                    agent_states_action_mask).detach().max(1)[0].unsqueeze(1)
            Q_targets_next.masked_scatter_(agent_mask, agent_Q_targets_next)
        return Q_targets_next

    def compute_expected_q_values(self, agent_id, states, actions):
        """Computes the expected q_values we will use to create the loss to train the Q network"""
        network_index = states[0]['agent_idx']
        states_batch = torch.vstack([state['embeding'] for state in states])
        network_states_batch = states_batch[:, self.intersection_id_size:]
        if self.config.does_need_network_state:
            if self.config.does_need_network_state_embeding:
                network_state_embeding_batch = self.config.GAT(network_states_batch)
            else:
                network_state_embeding_batch = network_states_batch.view(
                    network_states_batch.size()[0], -1, self.config.network_state_size)
        else:
            batch_size = actions.size()[0]
            network_size = self.config.network_state.size()[0]
            network_state_embeding_batch = torch.empty(batch_size, network_size, 0).to(self.device)
        destination_ids = states_batch[:, 0:self.intersection_id_size]
        destination_ids_embedings = self.get_intersection_id_embedding(agent_id, destination_ids)
        intersec_states_embeding = network_state_embeding_batch[:, network_index]
        states_embedings = torch.cat((destination_ids_embedings, intersec_states_embeding), 1)
        Q_expected = self.agent_dic[agent_id]["policy"](states_embedings).gather(
            1, actions.long())  # actions must be converted to long so they can be used as an index
        return Q_expected
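
# A small sketch of the masked_scatter_ pattern used in
# compute_q_values_for_next_states above: per-agent Q-targets computed on a
# sub-batch are written back into the full batch tensor through a boolean mask.
# The shapes and values are illustrative assumptions.
def _masked_scatter_sketch():
    import torch

    batch_size = 5
    q_targets_next = torch.zeros(batch_size, 1)
    mask = torch.tensor([True, False, True, False, False]).unsqueeze(1)
    agent_q = torch.tensor([[1.5], [2.5]])  # targets for the two masked rows
    q_targets_next.masked_scatter_(mask, agent_q)
    # q_targets_next is now [[1.5], [0.0], [2.5], [0.0], [0.0]]
    return q_targets_next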