class DQN(Base_Agent):
    """A deep Q learning agent"""
    agent_name = "DQN"

    def __init__(self, config):
        Base_Agent.__init__(self, config)
        model_path = self.config.model_path if self.config.model_path else 'Models'
        self.memory = Replay_Buffer(self.hyperparameters["buffer_size"],
                                    self.hyperparameters["batch_size"],
                                    config.seed)
        self.q_network_local = self.create_NN(input_dim=self.state_size,
                                              output_dim=self.action_size)
        self.q_network_local_path = os.path.join(
            model_path, "{}_q_network_local.pt".format(self.agent_name))

        if self.config.load_model: self.locally_load_policy()
        self.q_network_optimizer = optim.Adam(
            self.q_network_local.parameters(),
            lr=self.hyperparameters["learning_rate"],
            eps=1e-4)
        self.exploration_strategy = Epsilon_Greedy_Exploration(config)

    def reset_game(self):
        super(DQN, self).reset_game()
        self.update_learning_rate(self.hyperparameters["learning_rate"],
                                  self.q_network_optimizer)

    def step(self):
        """Runs a step within a game including a learning step if required"""
        while not self.done:
            self.action = self.pick_action()
            self.conduct_action(self.action)
            if self.time_for_q_network_to_learn():
                for _ in range(self.hyperparameters["learning_iterations"]):
                    self.learn()
            self.save_experience()
            self.state = self.next_state  #this is to set the state for the next iteration
            self.global_step_number += 1
        self.episode_number += 1

    def pick_action(self, state=None):
        """Uses the local Q network and an epsilon greedy policy to pick an action"""
        # PyTorch only accepts mini-batches and not single observations so we have to use unsqueeze to add
        # a "fake" dimension to make it a mini-batch rather than a single observation
        if state is None: state = self.state
        if isinstance(state, np.int64) or isinstance(state, int):
            state = np.array([state])
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        if len(state.shape) < 2: state = state.unsqueeze(0)
        self.q_network_local.eval()  #puts network in evaluation mode
        with torch.no_grad():
            action_values = self.q_network_local(state)
        self.q_network_local.train()  #puts network back in training mode
        action = self.exploration_strategy.perturb_action_for_exploration_purposes(
            {
                "action_values": action_values,
                "turn_off_exploration": self.turn_off_exploration,
                "episode_number": self.episode_number
            })
        self.logger.info("Q values {} -- Action chosen {}".format(
            action_values, action))
        return action

    def learn(self, experiences=None):
        """Runs a learning iteration for the Q network"""
        if experiences is None:
            states, actions, rewards, next_states, dones = self.sample_experiences(
            )  #Sample experiences
        else:
            states, actions, rewards, next_states, dones = experiences
        loss = self.compute_loss(states, next_states, rewards, actions, dones)

        actions_list = [action_X.item() for action_X in actions]

        self.logger.info("Action counts {}".format(Counter(actions_list)))
        self.take_optimisation_step(
            self.q_network_optimizer, self.q_network_local, loss,
            self.hyperparameters["gradient_clipping_norm"])

    def compute_loss(self, states, next_states, rewards, actions, dones):
        """Computes the loss required to train the Q network"""
        with torch.no_grad():
            Q_targets = self.compute_q_targets(next_states, rewards, dones)
        Q_expected = self.compute_expected_q_values(states, actions)
        loss = F.mse_loss(Q_expected, Q_targets)
        return loss

    def compute_q_targets(self, next_states, rewards, dones):
        """Computes the q_targets we will compare to predicted q values to create the loss to train the Q network"""
        Q_targets_next = self.compute_q_values_for_next_states(next_states)
        Q_targets = self.compute_q_values_for_current_states(
            rewards, Q_targets_next, dones)
        return Q_targets

    def compute_q_values_for_next_states(self, next_states):
        """Computes the q_values for next state we will use to create the loss to train the Q network"""
        Q_targets_next = self.q_network_local(next_states).detach().max(
            1)[0].unsqueeze(1)
        return Q_targets_next

    def compute_q_values_for_current_states(self, rewards, Q_targets_next,
                                            dones):
        """Computes the q_values for current state we will use to create the loss to train the Q network"""
        Q_targets_current = rewards + (self.hyperparameters["discount_rate"] *
                                       Q_targets_next * (1 - dones))
        return Q_targets_current

    def compute_expected_q_values(self, states, actions):
        """Computes the expected q_values we will use to create the loss to train the Q network"""
        Q_expected = self.q_network_local(states).gather(1, actions.long(
        ))  #must convert actions to long so can be used as index
        return Q_expected

    def time_for_q_network_to_learn(self):
        """Returns boolean indicating whether enough steps have been taken for learning to begin and there are
        enough experiences in the replay buffer to learn from"""
        return self.right_amount_of_steps_taken(
        ) and self.enough_experiences_to_learn_from()

    def right_amount_of_steps_taken(self):
        """Returns boolean indicating whether enough steps have been taken for learning to begin"""
        return self.global_step_number % self.hyperparameters[
            "update_every_n_steps"] == 0

    def sample_experiences(self):
        """Draws a random sample of experience from the memory buffer"""
        experiences = self.memory.sample()
        states, actions, rewards, next_states, dones = experiences
        return states, actions, rewards, next_states, dones

    def locally_save_policy(self):
        """Saves the policy"""
        """保存策略,待添加"""
        torch.save(self.q_network_local.state_dict(),
                   self.q_network_local_path)

    def locally_load_policy(self):
        """Loads previously saved weights for the local Q network if they exist"""
        print("locally_load_policy")
        if os.path.isfile(self.q_network_local_path):
            try:
                self.q_network_local.load_state_dict(
                    torch.load(self.q_network_local_path))
                print("Loaded q_network_local weights from {}".format(
                    self.q_network_local_path))
            except Exception as e:
                print("Failed to load {}: {}".format(
                    self.q_network_local_path, e))
Example #2
class DDQN_Wrapper(Base_Agent):
    def __init__(self,
                 config,
                 global_action_id_to_primitive_actions,
                 action_length_reward_bonus,
                 end_of_episode_symbol="/"):
        super().__init__(config)
        self.end_of_episode_symbol = end_of_episode_symbol
        self.global_action_id_to_primitive_actions = global_action_id_to_primitive_actions
        self.memory = Replay_Buffer(self.hyperparameters["buffer_size"],
                                    self.hyperparameters["batch_size"],
                                    config.seed)
        self.exploration_strategy = Epsilon_Greedy_Exploration(config)

        self.oracle = self.create_oracle()
        self.oracle_optimizer = optim.Adam(
            self.oracle.parameters(), lr=self.hyperparameters["learning_rate"])

        self.q_network_local = self.create_NN(input_dim=self.state_size + 1,
                                              output_dim=self.action_size)
        self.q_network_local.print_model_summary()
        self.q_network_optimizer = optim.Adam(
            self.q_network_local.parameters(),
            lr=self.hyperparameters["learning_rate"])
        self.q_network_target = self.create_NN(input_dim=self.state_size + 1,
                                               output_dim=self.action_size)
        Base_Agent.copy_model_over(from_model=self.q_network_local,
                                   to_model=self.q_network_target)

        self.action_length_reward_bonus = action_length_reward_bonus
        self.abandon_ship = config.hyperparameters["abandon_ship"]

    def create_oracle(self):
        """Creates the network we will use to predict the next state"""
        oracle_hyperparameters = copy.deepcopy(self.hyperparameters)
        oracle_hyperparameters["columns_of_data_to_be_embedded"] = []
        oracle_hyperparameters["embedding_dimensions"] = []
        oracle_hyperparameters["linear_hidden_units"] = [5, 5]
        oracle_hyperparameters["final_layer_activation"] = [None, "tanh"]
        oracle = self.create_NN(input_dim=self.state_size + 2,
                                output_dim=[self.state_size + 1, 1],
                                hyperparameters=oracle_hyperparameters)
        oracle.print_model_summary()
        return oracle

    def run_n_episodes(self, num_episodes,
                       episodes_to_run_with_no_exploration):
        self.turn_on_any_epsilon_greedy_exploration()
        self.round_of_macro_actions = []
        self.episode_actions_scores_and_exploration_status = []
        num_episodes_to_get_to = self.episode_number + num_episodes
        while self.episode_number < num_episodes_to_get_to:
            self.reset_game()
            self.step()
            self.save_and_print_result()
            if num_episodes_to_get_to - self.episode_number == episodes_to_run_with_no_exploration:
                self.turn_off_any_epsilon_greedy_exploration()
        assert len(self.episode_actions_scores_and_exploration_status
                   ) == num_episodes, "{} vs. {}".format(
                       len(self.episode_actions_scores_and_exploration_status),
                       num_episodes)
        assert len(self.episode_actions_scores_and_exploration_status[0]) == 3
        assert self.episode_actions_scores_and_exploration_status[0][2] in [
            True, False
        ]
        assert isinstance(
            self.episode_actions_scores_and_exploration_status[0][1], list)
        assert isinstance(
            self.episode_actions_scores_and_exploration_status[0][1][0], int)
        assert isinstance(
            self.episode_actions_scores_and_exploration_status[0][0],
            int) or isinstance(
                self.episode_actions_scores_and_exploration_status[0][0],
                float)
        return self.episode_actions_scores_and_exploration_status, self.round_of_macro_actions

    def step(self):
        """Runs a step within a game including a learning step if required"""
        step_number = 0.0
        self.state = np.append(
            self.state, step_number /
            200.0)  #Divide by 200 because there are 200 steps in cart pole

        self.total_episode_score_so_far = 0
        episode_macro_actions = []
        while not self.done:
            surprised = False
            macro_action = self.pick_action()
            primitive_actions = self.global_action_id_to_primitive_actions[
                macro_action]
            primitive_actions_conducted = 0
            for ix, action in enumerate(primitive_actions):
                if self.abandon_ship and primitive_actions_conducted > 0:
                    if self.abandon_macro_action(action):
                        break

                step_number += 1
                self.action = action
                self.next_state, self.reward, self.done, _ = self.environment.step(
                    action)
                self.next_state = np.append(
                    self.next_state, step_number / 200.0
                )  #Divide by 200 because there are 200 steps in cart pole

                self.total_episode_score_so_far += self.reward
                if self.hyperparameters["clip_rewards"]:
                    self.reward = max(min(self.reward, 1.0), -1.0)
                primitive_actions_conducted += 1
                self.track_episodes_data()
                self.save_experience()

                if len(primitive_actions) > 1:
                    surprised = self.am_i_surprised()
                self.state = self.next_state
                if self.time_for_q_network_to_learn():
                    for _ in range(
                            self.hyperparameters["learning_iterations"]):
                        self.q_network_learn()
                        self.oracle_learn()
                if self.done or surprised: break
            episode_macro_actions.append(macro_action)
            self.round_of_macro_actions.append(macro_action)
        if random.random() < 0.1: print(Counter(episode_macro_actions))
        self.save_episode_actions_with_score()
        self.episode_number += 1
        self.logger.info("END OF EPISODE")

    def am_i_surprised(self):
        """Returns boolean indicating whether the next_state was a surprise or not"""
        with torch.no_grad():
            state = torch.from_numpy(self.state).float().unsqueeze(0).to(
                self.device)
            action = torch.Tensor([[self.action]])

            states_and_actions = torch.cat(
                (state, action),
                dim=1)  #must change this for all games besides cart pole
            predictions = self.oracle(states_and_actions)
            predicted_next_state = predictions[0, :-1]

            difference = F.mse_loss(predicted_next_state,
                                    torch.Tensor(self.next_state))
            if difference > 0.5:
                print("Surprise! Loss {} -- {} vs. {}".format(
                    difference, predicted_next_state, self.next_state))
                return True
            else:
                return False

    def abandon_macro_action(self, action):
        """Returns boolean indicating whether to abandon macro action or not"""
        state = torch.from_numpy(self.state).float().unsqueeze(0).to(
            self.device)
        with torch.no_grad():
            primitive_q_values = self.calculate_q_values(
                state, local=True, primitive_actions_only=True)
        q_value_highest = torch.max(primitive_q_values)
        q_values_action = primitive_q_values[:, action]
        if q_value_highest > 0.0: multiplier = 0.7
        else: multiplier = 1.3
        if q_values_action < multiplier * q_value_highest:
            print("BREAKING Action {} -- Q Values {}".format(
                action, primitive_q_values))
            return True
        else:
            return False

    def pick_action(self, state=None):
        """Uses the local Q network and an epsilon greedy policy to pick an action"""
        if state is None: state = self.state
        if isinstance(state, np.int64) or isinstance(state, int):
            state = np.array([state])
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        if len(state.shape) < 2: state = state.unsqueeze(0)
        self.q_network_local.eval()  #puts network in evaluation mode
        with torch.no_grad():
            action_values = self.calculate_q_values(
                state, local=True, primitive_actions_only=False)
        self.q_network_local.train()  #puts network back in training mode
        action = self.exploration_strategy.perturb_action_for_exploration_purposes(
            {
                "action_values": action_values,
                "turn_off_exploration": self.turn_off_exploration,
                "episode_number": self.episode_number
            })
        self.logger.info("Q values {} -- Action chosen {}".format(
            action_values, action))
        return action

    def calculate_q_values(self, states, local, primitive_actions_only):
        """Calculates the q values using the local q network"""
        if local:
            primitive_q_values = self.q_network_local(states)
        else:
            primitive_q_values = self.q_network_target(states)

        num_actions = len(self.global_action_id_to_primitive_actions)
        if primitive_actions_only or num_actions <= self.action_size:
            return primitive_q_values

        extra_q_values = self.calculate_macro_action_q_values(
            states, num_actions)
        extra_q_values = torch.Tensor([extra_q_values])
        all_q_values = torch.cat((primitive_q_values, extra_q_values), dim=1)

        return all_q_values

    def calculate_macro_action_q_values(self, state, num_actions):
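        """Estimates Q values for macro actions by rolling the oracle forward through all but the
        last primitive action, accumulating discounted predicted rewards (plus the action-length
        bonus), and bootstrapping with the local Q network at the predicted final state"""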
        assert state.shape[0] == 1
        q_values = []
        for action_id in range(self.action_size, num_actions):
            macro_action = self.global_action_id_to_primitive_actions[
                action_id]
            predicted_next_state = state
            cumulated_reward = 0
            action_ix = 0
            for action in macro_action[:-1]:
                predictions = self.oracle(
                    torch.cat((predicted_next_state, torch.Tensor([[action]])),
                              dim=1))
                rewards = predictions[:, -1]
                predicted_next_state = predictions[:, :-1]
                cumulated_reward += (
                    rewards.item() + self.action_length_reward_bonus
                ) * self.hyperparameters["discount_rate"]**(action_ix)
                action_ix += 1
            final_action = macro_action[-1]
            final_q_value = self.q_network_local(predicted_next_state)[
                0, final_action]
            total_q_value = cumulated_reward + final_q_value * self.hyperparameters[
                "discount_rate"]**(action_ix)
            q_values.append(total_q_value)
        return q_values

    def time_for_q_network_to_learn(self):
        """Returns boolean indicating whether enough steps have been taken for learning to begin and there are
        enough experiences in the replay buffer to learn from"""
        return self.right_amount_of_steps_taken(
        ) and self.enough_experiences_to_learn_from()

    def right_amount_of_steps_taken(self):
        """Returns boolean indicating whether enough steps have been taken for learning to begin"""
        return self.global_step_number % self.hyperparameters[
            "update_every_n_steps"] == 0

    def q_network_learn(self, experiences=None):
        """Runs a learning iteration for the Q network"""
        if experiences is None:
            states, actions, rewards, next_states, dones = self.sample_experiences(
            )  #Sample experiences
        else:
            states, actions, rewards, next_states, dones = experiences
        loss = self.compute_loss(states, next_states, rewards, actions, dones)
        self.take_optimisation_step(
            self.q_network_optimizer, self.q_network_local, loss,
            self.hyperparameters["gradient_clipping_norm"])
        self.soft_update_of_target_network(self.q_network_local,
                                           self.q_network_target,
                                           self.hyperparameters["tau"])

    def sample_experiences(self):
        """Draws a random sample of experience from the memory buffer"""
        experiences = self.memory.sample()
        states, actions, rewards, next_states, dones = experiences
        return states, actions, rewards, next_states, dones

    def compute_loss(self, states, next_states, rewards, actions, dones):
        """Computes the loss required to train the Q network"""
        with torch.no_grad():
            max_action_indexes = self.calculate_q_values(
                next_states, local=True,
                primitive_actions_only=True).detach().argmax(1)
            Q_targets_next = self.calculate_q_values(
                next_states, local=False, primitive_actions_only=True).gather(
                    1, max_action_indexes.unsqueeze(1))
            Q_targets = rewards + (self.hyperparameters["discount_rate"] *
                                   Q_targets_next * (1 - dones))
        Q_expected = self.calculate_q_values(
            states, local=True,
            primitive_actions_only=True).gather(1, actions.long(
            ))  # must convert actions to long so can be used as index
        loss = F.mse_loss(Q_expected, Q_targets)
        return loss

    def save_episode_actions_with_score(self):
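        """Records the episode score, the actions taken (terminated with the end-of-episode
        symbol) and whether exploration was switched off for this episode"""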

        self.episode_actions_scores_and_exploration_status.append([
            self.total_episode_score_so_far,
            self.episode_actions + [self.end_of_episode_symbol],
            self.turn_off_exploration
        ])

    def oracle_learn(self):
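        """Runs a learning iteration for the oracle so it learns to predict the next state and
        reward from the current state and action"""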
        states, actions, rewards, next_states, _ = self.sample_experiences(
        )  # Sample experiences
        states_and_actions = torch.cat(
            (states, actions),
            dim=1)  #must change this for all games besides cart pole
        predictions = self.oracle(states_and_actions)
        loss = F.mse_loss(torch.cat((next_states, rewards), dim=1),
                          predictions) / float(next_states.shape[1] + 1.0)
        self.take_optimisation_step(
            self.oracle_optimizer, self.oracle, loss,
            self.hyperparameters["gradient_clipping_norm"])
        self.logger.info("Oracle Loss {}".format(loss))
class DQN(Base_Agent):
    """A deep Q learning agent"""
    agent_name = "DQN"

    def __init__(self, config):
        Base_Agent.__init__(self, config)
        self.memory = Replay_Buffer(self.hyperparameters["buffer_size"],
                                    self.hyperparameters["batch_size"],
                                    config.seed)
        self.q_network_local = self.create_NN(input_dim=self.state_size,
                                              output_dim=self.action_size)
        self.q_network_optimizer = optim.SGD(
            self.q_network_local.parameters(),
            lr=self.hyperparameters["learning_rate"],
            weight_decay=5e-4)
        self.exploration_strategy = Epsilon_Greedy_Exploration(config)

    def reset_game(self):
        super(DQN, self).reset_game()
        self.update_learning_rate(self.hyperparameters["learning_rate"],
                                  self.q_network_optimizer)

    def step(self):
        """Runs a step within a game including a learning step if required"""
        while not self.done:
            # print('state:', self.state)
            # self.environment.render()
            self.action = self.pick_action()
            self.conduct_action(self.action)
            if self.time_for_q_network_to_learn():
                for _ in range(self.hyperparameters["learning_iterations"]):
                    if hasattr(self.environment, "pause"):
                        # Pause the environment while learning, for environments
                        # that support pausing
                        self.environment.pause()
                        self.learn()
                        self.environment.resume()
                    else:
                        self.learn()
            self.save_experience()
            self.state = self.next_state  #this is to set the state for the next iteration
            self.global_step_number += 1
        self.episode_number += 1

    def pick_action(self, state=None):
        """Uses the local Q network and an epsilon greedy policy to pick an action"""
        # PyTorch only accepts mini-batches and not single observations so we have to use unsqueeze to add
        # a "fake" dimension to make it a mini-batch rather than a single observation
        if state is None: state = self.state
        if isinstance(state, np.int64) or isinstance(state, int):
            state = np.array([state])
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        if len(state.shape) < 2: state = state.unsqueeze(0)
        self.q_network_local.eval()  #puts network in evaluation mode
        with torch.no_grad():
            action_values = self.q_network_local(state)
        self.q_network_local.train()  #puts network back in training mode

        force_explore = self.config.force_explore_mode and self.need_to_force_explore(
        )

        if force_explore:
            print('explore...')

        action = self.exploration_strategy.perturb_action_for_exploration_purposes(
            {
                "action_values": action_values,
                "turn_off_exploration": self.turn_off_exploration,
                "episode_number": self.episode_number,
                "force_explore": force_explore
            })
        # self.logger.info("Q values {} -- Action chosen {}".format(action_values, action))
        return action

    def learn(self, experiences=None):
        """Runs a learning iteration for the Q network"""
        if experiences is None:
            states, actions, rewards, next_states, dones = self.sample_experiences(
            )  #Sample experiences
        else:
            states, actions, rewards, next_states, dones = experiences
        loss = self.compute_loss(states, next_states, rewards, actions, dones)

        actions_list = [action_X.item() for action_X in actions]

        self.logger.info("Action counts {}".format(Counter(actions_list)))
        self.take_optimisation_step(
            self.q_network_optimizer, self.q_network_local, loss,
            self.hyperparameters["gradient_clipping_norm"])

    def compute_loss(self, states, next_states, rewards, actions, dones):
        """Computes the loss required to train the Q network"""
        with torch.no_grad():
            Q_targets = self.compute_q_targets(next_states, rewards, dones)
        Q_expected = self.compute_expected_q_values(states, actions)
        # loss = F.mse_loss(Q_expected, Q_targets)

        loss = nn.MSELoss(reduction='sum')(Q_expected, Q_targets)  # sum (not mean) over the batch
        return loss

    def compute_q_targets(self, next_states, rewards, dones):
        """Computes the q_targets we will compare to predicted q values to create the loss to train the Q network"""
        Q_targets_next = self.compute_q_values_for_next_states(next_states)
        Q_targets = self.compute_q_values_for_current_states(
            rewards, Q_targets_next, dones)
        return Q_targets

    def compute_q_values_for_next_states(self, next_states):
        """Computes the q_values for next state we will use to create the loss to train the Q network"""
        Q_targets_next = self.q_network_local(next_states).detach().max(
            1)[0].unsqueeze(1)
        return Q_targets_next

    def compute_q_values_for_current_states(self, rewards, Q_targets_next,
                                            dones):
        """Computes the q_values for current state we will use to create the loss to train the Q network"""
        Q_targets_current = rewards + (self.hyperparameters["discount_rate"] *
                                       Q_targets_next * (1 - dones))
        return Q_targets_current

    def compute_expected_q_values(self, states, actions):
        """Computes the expected q_values we will use to create the loss to train the Q network"""
        Q_expected = self.q_network_local(states).gather(1, actions.long(
        ))  #must convert actions to long so can be used as index
        return Q_expected

    def time_for_q_network_to_learn(self):
        """Returns boolean indicating whether enough steps have been taken for learning to begin and there are
        enough experiences in the replay buffer to learn from"""
        return self.right_amount_of_steps_taken(
        ) and self.enough_experiences_to_learn_from()

    def right_amount_of_steps_taken(self):
        """Returns boolean indicating whether enough steps have been taken for learning to begin"""
        return self.global_step_number % self.hyperparameters[
            "update_every_n_steps"] == 0

    def sample_experiences(self):
        """Draws a random sample of experience from the memory buffer"""
        experiences = self.memory.sample()
        states, actions, rewards, next_states, dones = experiences
        return states, actions, rewards, next_states, dones

    def locally_save_policy(self, best=True, episode=None):
        if self.agent_name != "DQN":
            state = {
                'episode': self.episode_number,
                'q_network_local': self.q_network_local.state_dict(),
                'q_network_target': self.q_network_target.state_dict()
            }
        else:
            state = {
                'episode': self.episode_number,
                'q_network_local': self.q_network_local.state_dict()
            }

        model_root = os.path.join('Models', self.config.env_title,
                                  self.agent_name, self.config.log_base)
        if not os.path.exists(model_root):
            os.makedirs(model_root)

        if best:
            last_best_file = glob.glob(
                os.path.join(model_root, 'rolling_score*'))
            if last_best_file:
                os.remove(last_best_file[0])

            save_name = os.path.join(
                model_root,
                "rolling_score_%.4f.model" % self.rolling_results[-1])
        else:
            save_name = os.path.join(
                model_root,
                "%s_%d.model" % (self.agent_name, self.episode_number))
        torch.save(state, save_name)
        self.logger.info('Model %s saved successfully...' % save_name)

    def load_resume(self, resume_path):
        save = torch.load(resume_path)
        if self.agent_name != "DQN":
            q_network_local_dict = save['q_network_local']
            q_network_target_dict = save['q_network_target']
            self.q_network_local.load_state_dict(q_network_local_dict,
                                                 strict=True)
            self.q_network_target.load_state_dict(q_network_target_dict,
                                                  strict=True)
        else:
            q_network_local_dict = save['q_network_local']
            self.q_network_local.load_state_dict(q_network_local_dict,
                                                 strict=True)
        self.logger.info('load resume model success...')

        file_name = os.path.basename(resume_path)
        episode_str = re.findall(r"\d+\.?\d*", file_name)[0]
        episode_list = episode_str.split('.')
        # Filenames such as "DQN_120.model" encode the episode number; rolling-score
        # checkpoints such as "rolling_score_195.1234.model" do not, so start from 0
        if len(episode_list) < 2 or not episode_list[1]:
            episode = int(episode_list[0])
        else:
            episode = 0

        if not self.config.retrain:
            self.episode_number = episode
        else:
            self.episode_number = 0
Example #4
class DQN(Base_Agent):
    """A deep Q learning agent"""
    agent_name = "DQN"

    def __init__(self, config):
        Base_Agent.__init__(self, config)
        self.agent_dic = self.create_agent_dic()
        self.exploration_strategy = Epsilon_Greedy_Exploration(config)
        # self.environment.utils.visualize_gat_properties(self.config.GAT)
        # self.environment.utils.vis_intersec_id_embedding(agent_id='20953772',transform_func=self.get_intersection_id_embedding)

    def reset_game(self):
        super(DQN, self).reset_game()
        # self.update_learning_rate(self.hyperparameters["learning_rate"])

    def pick_action(self, states):
        """Uses the local Q network and an epsilon greedy policy to pick an action"""
        if len(states) == 0:
            return []

        states_batch = torch.vstack([state['embeding'] for state in states])
        network_states_batch = states_batch[:, self.intersection_id_size:]
        if self.config.does_need_network_state:
            if self.config.does_need_network_state_embeding:
                self.config.GAT.eval()
                # breakpoint()
                with torch.no_grad():
                    network_state_embedings = self.config.GAT(
                        network_states_batch).view(
                            states_batch.shape[0], -1,
                            self.config.network_embed_size)
                self.config.GAT.train()
            else:
                network_state_embedings = network_states_batch.view(
                    states_batch.shape[0], -1, self.config.network_state_size)
        else:
            batch_size = states_batch.size()[0]
            network_size = self.config.network_state.size()[0]
            network_state_embedings = torch.empty(batch_size, network_size,
                                                  0).to(self.device)
        # breakpoint()
        actions = []
        for state, network_state_embeding in zip(states,
                                                 network_state_embedings):
            agent_id = self.get_agent_id(state)
            try:
                intersection_state_embeding = network_state_embeding[
                    state['agent_idx']]
            except:
                breakpoint()

            destination_id = state['embeding'][0:self.intersection_id_size]
            destination_id_embeding = self.get_intersection_id_embedding(
                agent_id, destination_id, eval=True)
            embeding = torch.cat(
                (destination_id_embeding, intersection_state_embeding), 0)
            action_values = self.get_action_values(agent_id,
                                                   embeding.unsqueeze(0),
                                                   eval=True)
            action_data = {
                "action_values": action_values,
                "state": state,
                "turn_off_exploration": self.turn_off_exploration,
                "episode_number": self.env_episode_number
            }
            action = self.exploration_strategy.perturb_action_for_exploration_purposes(
                action_data)

            self.logger.info("Q values {} -- Action chosen {}".format(
                action_values, action))
            actions.append(action)

        return actions

    def learn(self):
        """Runs a learning iteration for the Q network on each agent"""
        for _ in range(self.hyperparameters["learning_iterations"]):
            agents_losses = [
                self.compute_loss(agent_id) for agent_id in self.agent_dic
                if self.time_for_q_network_to_learn(agent_id)
            ]
            try:
                self.take_optimisation_step(
                    agents_losses,
                    self.hyperparameters["gradient_clipping_norm"],
                    retain_graph=True)
            except Exception as e:
                breakpoint()

    def compute_loss(self, agent_id):
        """Computes the loss required to train the Q network"""
        memory = self.agent_dic[agent_id]["memory"]
        states, actions, rewards, next_states, dones = self.sample_experiences(
            memory)  #Sample experiences

        with torch.no_grad():
            Q_values_next_states = self.compute_q_values_for_next_states(
                next_states, dones)
            Q_targets = rewards + (self.hyperparameters["discount_rate"] *
                                   Q_values_next_states * (1 - dones))

        Q_expected = self.compute_expected_q_values(agent_id, states, actions)
        loss = F.mse_loss(Q_expected, Q_targets)
        return (agent_id, loss)

    def compute_q_values_for_next_states(self, next_states, dones):
        """Computes the q_values for next state we will use to create the loss to train the Q network"""
        batch_size = dones.size()[0]
        Q_targets_next = torch.zeros(batch_size, 1).to(self.device)

        # Use any non-None embedding as a dummy stand-in for terminal (None) next states;
        # those rows are skipped below because their dones flag is 1
        for state in next_states:
            if state is not None:
                dummy_embed = state['embeding']
                break

        next_states_embedings = [
            state['embeding'] if state is not None else dummy_embed
            for state in next_states
        ]
        # not_Non_next_states_batch_index_dic={id(not_Non_next_states[idx]):idx for idx in range(len(not_Non_next_states))}
        next_states_embedings_batch = torch.vstack(next_states_embedings)
        network_states_batch = next_states_embedings_batch[
            :, self.intersection_id_size:]

        if self.config.does_need_network_state:
            if self.config.does_need_network_state_embeding:
                network_state_embeding_batch = self.config.GAT(
                    network_states_batch)
            else:
                network_state_embeding_batch = network_states_batch.view(
                    batch_size, -1, self.config.network_state_size)
                # breakpoint()
        else:
            network_size = self.config.network_state.size()[0]
            network_state_embeding_batch = torch.empty(batch_size,
                                                       network_size,
                                                       0).to(self.device)

        masks_dic = {}

        for i in range(0, batch_size):
            if dones[i] == 1:
                continue
            agent_id = self.get_agent_id(next_states[i])
            if not agent_id in masks_dic:
                masks_dic[agent_id] = {}
                masks_dic[agent_id]["mask"] = [False] * batch_size
                masks_dic[agent_id]["batch_indexs"] = []
                masks_dic[agent_id]["network_index"] = next_states[i][
                    'agent_idx']
            masks_dic[agent_id]["mask"][i] = True
            masks_dic[agent_id]["batch_indexs"].append(i)

        for agent_id in masks_dic:
            agent_mask = torch.Tensor(
                masks_dic[agent_id]["mask"]).unsqueeze(1).to(self.device,
                                                             dtype=torch.bool)
            agent_states_action_mask = torch.vstack([
                agent_state['action_mask'] for agent_state in next_states[
                    masks_dic[agent_id]["batch_indexs"]]
            ])
            destination_ids = next_states_embedings_batch[
                masks_dic[agent_id]["batch_indexs"],
                0:self.intersection_id_size]
            destination_ids_embedings = self.get_intersection_id_embedding(
                agent_id, destination_ids)
            intersec_states_embeding = network_state_embeding_batch[
                masks_dic[agent_id]["batch_indexs"],
                masks_dic[agent_id]["network_index"]]
            agent_states_embedings = torch.cat(
                (destination_ids_embedings, intersec_states_embeding), 1)
            try:
                agent_Q_targets_next = (
                    self.agent_dic[agent_id]["policy"](agent_states_embedings)
                    + agent_states_action_mask).detach().max(1)[0].unsqueeze(1)
            except Exception as e:
                breakpoint()
            Q_targets_next.masked_scatter_(agent_mask, agent_Q_targets_next)

        return Q_targets_next

        # max(1): find the max in every row of the batch
        # max(0): find the max in every column of the batch
        # max(1)[0]: value of the max in every row of the batch
        # max(1)[1]: batch_indexs of the max in every row of the batch

    def compute_expected_q_values(self, agent_id, states, actions):
        """Computes the expected q_values we will use to create the loss to train the Q network"""
        network_index = states[0]['agent_idx']
        states_batch = torch.vstack([state['embeding'] for state in states])
        network_states_batch = states_batch[:, self.intersection_id_size:]

        if self.config.does_need_network_state:
            if self.config.does_need_network_state_embeding:
                network_state_embeding_batch = self.config.GAT(
                    network_states_batch)
            else:
                network_state_embeding_batch = network_states_batch.view(
                    network_states_batch.size()[0], -1,
                    self.config.network_state_size)
        else:
            batch_size = actions.size()[0]
            network_size = self.config.network_state.size()[0]
            network_state_embeding_batch = torch.empty(batch_size,
                                                       network_size,
                                                       0).to(self.device)

        destination_ids = states_batch[:, 0:self.intersection_id_size]
        destination_ids_embedings = self.get_intersection_id_embedding(
            agent_id, destination_ids)
        intersec_states_embeding = network_state_embeding_batch[:,
                                                                network_index]
        states_embedings = torch.cat(
            (destination_ids_embedings, intersec_states_embeding), 1)
        try:
            Q_expected = self.agent_dic[agent_id]["policy"](
                states_embedings).gather(
                    1, actions.long()
                )  #must convert actions to long so can be used as batch_indexs
        except Exception as e:
            breakpoint()
        return Q_expected