コード例 #1
0
class MyAgent(AbstractAgent):
    """Evaluation agent that acts greedily with a DQN loaded from a checkpoint."""

    def __init__(self,
                 observation_space,
                 action_space,
                 checkpoint_path="checkpoints/40000.pth",
                 replay_capacity=int(5e3)):
        """
        Build a channel-first observation space, a replay buffer and a vanilla
        DQN, then load the saved weights into the network.

        The observation space is rebuilt with the last axis of the incoming
        shape moved to the front, which (per the original author) matches the
        frames produced by the PyTorchFrame wrapper.

        :param observation_space: observation space of the environment; only
            its ``shape`` is read here
        :param action_space: the action space of the environment
        :param checkpoint_path: path of the state dict loaded into the network
        :param replay_capacity: size argument handed to ``ReplayBuffer``
        """
        shape = observation_space.shape
        # NOTE(review): bounds of [0.0, 1.0] are declared together with an
        # integer dtype (uint8). Left unchanged because the network may have
        # been trained against this exact space definition - confirm whether
        # np.float32 is what is actually intended.
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=(shape[-1], shape[0], shape[1]),
            dtype=np.uint8)
        self.action_space = action_space
        self.memory = ReplayBuffer(replay_capacity)
        self.policy_network = DQN(self.observation_space, self.action_space)
        # map_location lets a checkpoint saved on another device be restored
        # onto the device this process is using.
        state_dict = torch.load(checkpoint_path,
                                map_location=torch.device(device))
        self.policy_network.load_state_dict(state_dict)
        # Inference only: put the network in evaluation mode.
        self.policy_network.eval()

    def act(self, observation):
        """
        Given a state, choose the action with the highest Q-value according to
        the loaded policy network.

        The observation's axis 2 is moved to the front and the values are
        divided by 255 before being fed to the network, so the input has the
        same layout as the frames the network is built for.

        :param observation: array-like observation with at least three axes
        :return: the index of the chosen action as a Python int
        """
        # np.moveaxis(x, 2, 0) is the recommended equivalent of np.rollaxis(x, 2).
        observation = np.moveaxis(observation, 2, 0)
        observation = np.asarray(observation) / 255.0
        # Add a leading batch axis of size 1 for the network.
        batch = torch.from_numpy(observation).float().unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = self.policy_network(batch)
        return q_values.argmax(dim=1).item()
コード例 #2
0
class DQNAgent:
    """DQN / double-DQN agent with a policy network and a target network."""

    def __init__(
        self,
        observation_space: spaces.Box,
        action_space: spaces.Discrete,
        replay_buffer: ReplayBuffer,
        use_double_dqn,
        lr,
        batch_size,
        gamma,
    ):
        """
        Initialise the DQN algorithm using the Adam optimiser
        :param observation_space: the state space of the environment
        :param action_space: the action space of the environment
        :param replay_buffer: storage for experience replay
        :param use_double_dqn: if truthy, pick the next action with the policy
            network and evaluate it with the target network
        :param lr: the learning rate for Adam
        :param batch_size: the batch size
        :param gamma: the discount factor
        """

        self.memory = replay_buffer
        self.batch_size = batch_size
        self.use_double_dqn = use_double_dqn
        self.gamma = gamma
        self.policy_network = DQN(observation_space, action_space).to(device)
        self.target_network = DQN(observation_space, action_space).to(device)
        # Start with both networks holding identical weights.
        self.update_target_network()
        # The target network is never optimised directly.
        self.target_network.eval()
        self.optimiser = torch.optim.Adam(self.policy_network.parameters(),
                                          lr=lr)

    def optimise_td_loss(self):
        """
        Optimise the TD-error over a single minibatch of transitions
        :return: the loss as a Python float
        """
        states, actions, rewards, next_states, dones = self.memory.sample(
            self.batch_size)
        # Stored frames are divided by 255 before being fed to the networks.
        states = torch.from_numpy(np.array(states) / 255.0).float().to(device)
        next_states = torch.from_numpy(
            np.array(next_states) / 255.0).float().to(device)
        actions = torch.from_numpy(actions).long().to(device)
        rewards = torch.from_numpy(rewards).float().to(device)
        dones = torch.from_numpy(dones).float().to(device)

        # Targets are constants with respect to the optimisation step.
        with torch.no_grad():
            if self.use_double_dqn:
                # Double DQN: policy network selects, target network evaluates.
                best_next_actions = self.policy_network(next_states).argmax(
                    dim=1, keepdim=True)
                # squeeze(1) drops only the gathered column, so the batch
                # axis survives even when the batch holds a single sample.
                max_next_q_values = self.target_network(next_states).gather(
                    1, best_next_actions).squeeze(1)
            else:
                max_next_q_values, _ = self.target_network(next_states).max(1)
            # (1 - dones) removes the bootstrap term for terminal transitions.
            target_q_values = rewards + (
                1 - dones) * self.gamma * max_next_q_values

        # Q-value of the action actually taken in each sampled transition.
        input_q_values = self.policy_network(states).gather(
            1, actions.unsqueeze(1)).squeeze(1)

        loss = F.smooth_l1_loss(input_q_values, target_q_values)

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        return loss.item()

    def update_target_network(self):
        """
        Update the target Q-network by copying the weights from the current Q-network
        """
        self.target_network.load_state_dict(self.policy_network.state_dict())

    def act(self, state: np.ndarray):
        """
        Select an action greedily from the Q-network given the state
        :param state: the current state
        :return: the action to take
        """
        scaled = np.asarray(state) / 255.0
        # Add a leading batch axis of size 1 for the network.
        batch = torch.from_numpy(scaled).float().unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = self.policy_network(batch)
        return q_values.argmax(dim=1).item()
コード例 #3
0
ファイル: agent.py プロジェクト: TheWronskians/obstacle_tower
class DQNAgent:
    """DQN / double-DQN agent with a policy network and a target network."""

    def __init__(self,
                 observation_space: spaces.Box,
                 action_space: spaces.Discrete,
                 replay_buffer: ReplayBuffer,
                 use_double_dqn=True,
                 lr=1e-4,
                 batch_size=32,
                 gamma=0.99):
        """
        Initialise the DQN algorithm using the Adam optimiser
        :param observation_space: the state space of the environment
        :param action_space: the action space of the environment
        :param replay_buffer: storage for experience replay
        :param use_double_dqn: if truthy, pick the next action with the policy
            network and evaluate it with the target network
        :param lr: the learning rate for Adam
        :param batch_size: the batch size
        :param gamma: the discount factor
        """

        self.memory = replay_buffer
        self.batch_size = batch_size
        self.use_double_dqn = use_double_dqn
        self.gamma = gamma
        self.policy_network = DQN(observation_space, action_space).to(device)
        self.target_network = DQN(observation_space, action_space).to(device)
        # Start with both networks holding identical weights.
        self.update_target_network()
        # The target network is never optimised directly.
        self.target_network.eval()
        self.optimiser = torch.optim.Adam(self.policy_network.parameters(),
                                          lr=lr)

    def optimise_td_loss(self):
        """
        Optimise the TD-error over a single minibatch of transitions
        :return: the loss as a Python float
        """
        states, actions, rewards, next_states, dones = self.memory.sample(
            self.batch_size)
        # Stored frames are divided by 255 before being fed to the networks.
        states = torch.from_numpy(np.array(states) / 255.0).float().to(device)
        next_states = torch.from_numpy(
            np.array(next_states) / 255.0).float().to(device)
        actions = torch.from_numpy(actions).long().to(device)
        rewards = torch.from_numpy(rewards).float().to(device)
        dones = torch.from_numpy(dones).float().to(device)

        # Targets are constants with respect to the optimisation step.
        with torch.no_grad():
            if self.use_double_dqn:
                # Double DQN: policy network selects, target network evaluates.
                best_next_actions = self.policy_network(next_states).argmax(
                    dim=1, keepdim=True)
                # squeeze(1) drops only the gathered column, so the batch
                # axis survives even when the batch holds a single sample.
                max_next_q_values = self.target_network(next_states).gather(
                    1, best_next_actions).squeeze(1)
            else:
                max_next_q_values, _ = self.target_network(next_states).max(1)
            # (1 - dones) removes the bootstrap term for terminal transitions.
            target_q_values = rewards + (
                1 - dones) * self.gamma * max_next_q_values

        # Q-value of the action actually taken in each sampled transition.
        input_q_values = self.policy_network(states).gather(
            1, actions.unsqueeze(1)).squeeze(1)

        loss = F.smooth_l1_loss(input_q_values, target_q_values)

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        return loss.item()

    def update_target_network(self):
        """
        Update the target Q-network by copying the weights from the current Q-network
        """
        self.target_network.load_state_dict(self.policy_network.state_dict())

    def save_models(self, episode_count):
        """
        Save the policy network's state dict under ``./Models``.

        Only the policy network is written; the target network is not saved.
        :param episode_count: the count of episodes iterated (used in the file name)
        :return: None
        """
        import os  # local import: the file's import block is not visible here

        # torch.save does not create missing parent directories.
        os.makedirs('./Models', exist_ok=True)
        torch.save(self.policy_network.state_dict(),
                   './Models/' + str(episode_count) + '_policy.pt')
        print('Models saved successfully')

    # Disabled counterpart of save_models, kept for reference:
    # def load_models(self, episode):
    #     """
    #     loads the saved policy model and copies it onto the target model
    #     :param episode: the count of episodes iterated (used to find the file name)
    #     :return:
    #     """
    #     self.policy_network.load_state_dict(torch.load('./Models/' + str(episode) + '_policy.pt'))
    #     utils2.hard_update(self.target_network, self.policy_network)
    #     print('Models loaded succesfully')

    def act(self, state: np.ndarray):
        """
        Select an action greedily from the Q-network given the state
        :param state: the current state
        :return: the action to take
        """
        scaled = np.asarray(state) / 255.0
        # Add a leading batch axis of size 1 for the network.
        batch = torch.from_numpy(scaled).float().unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = self.policy_network(batch)
        return q_values.argmax(dim=1).item()