Example No. 1
class Trainer:
    """Train DQN model."""
    def __init__(self, params):

        self.params = params
        self.device = self.params.device

        self.game = params.game

        if self.game == "health":
            viz_env = "VizdoomHealthGathering-v0"
            self.load_path = "models/health/"
        elif self.game == "defend":
            viz_env = "VizdoomBasic-v0"
            self.load_path = "models/defend/"
        elif self.game == "center":
            viz_env = "VizdoomDefendCenter-v0"
            self.load_path = "models/center/"
        else:
            raise ValueError("Unknown game: {}".format(self.game))

        # Initialize the environment
        self.env = gym.make(viz_env)
        self.num_actions = self.env.action_space.n

        # Initialize both deep Q networks
        self.target_net = DQN(60, 80,
                              num_actions=self.num_actions).to(self.device)
        self.pred_net = DQN(60, 80,
                            num_actions=self.num_actions).to(self.device)

        self.optimizer = torch.optim.Adam(self.pred_net.parameters(), lr=2e-5)

        # load a pretrained model
        if self.params.load_model:

            checkpoint = torch.load(self.load_path + "full_model.pk",
                                    map_location=torch.device(self.device))

            self.pred_net.load_state_dict(checkpoint["model_state_dict"])

            self.optimizer.load_state_dict(
                checkpoint["optimizer_state_dict"])

            self.replay_memory = checkpoint["replay_memory"]
            self.steps = checkpoint["steps"]
            self.learning_steps = checkpoint["learning_steps"]
            self.losses = checkpoint["losses"]
            self.frame_stack = checkpoint["frame_stack"]
            self.params = checkpoint["params"]
            self.params.start_decay = params.start_decay
            self.params.end_decay = params.end_decay
            self.episode = checkpoint["episode"]
            self.epsilon = checkpoint["epsilon"]
            self.stack_size = self.params.stack_size

        # training from scratch
        else:
            # weight init
            self.pred_net.apply(init_weights)

            # init replay memory
            self.replay_memory = ReplayMemory(10000)

            # init frame stack
            self.stack_size = self.params.stack_size
            self.frame_stack = deque(maxlen=self.stack_size)

            # track steps for target network update control
            self.steps = 0
            self.learning_steps = 0

            # loss logs
            self.losses = AverageMeter()

            self.episode = 0

            # epsilon decay parameters
            self.epsilon = self.params.eps_start

        # set target network to prediction network
        self.target_net.load_state_dict(self.pred_net.state_dict())
        self.target_net.eval()

        # move models to GPU
        if self.device == "cuda:0":
            self.target_net = self.target_net.to(self.device)
            self.pred_net = self.pred_net.to(self.device)

        # epsilon decay
        self.epsilon_start = self.params.eps_start

        # tensorboard
        self.writer = SummaryWriter()

    def reset_stack(self):
        """reset frame stack."""

        self.frame_stack = deque(maxlen=self.stack_size)

    def update_stack(self, observation: np.ndarray):
        """
        Update the frame stack with an observation.
        This builds stacks of four consecutive preprocessed
        frames for the model to learn from.
        """

        image = self.preprocess(observation)

        self.frame_stack.append(image)

    def preprocess(self, observation: np.ndarray) -> Tensor:
        """
        Preprocess images to be used by DQN.
        This function takes a screen grab of the current time step of the game
        and converts it from an RGB image to grayscale, while resizing the
        height and width.

        Parameters
        ----------
        observation : np.ndarray
            RGB image (buffer) from DOOM env.


        Returns
        ----------
        Tensor
            Scaled down and grayscaled version of buffer image.

        """

        # convert to grayscale and resize to a quarter of the original height and width
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Grayscale(),
            transforms.Resize((60, 80))
        ])

        # return the transformed image
        return transform(observation)

    def epsilon_decay(self):
        """Calculate decay of epsilon value over time."""

        # only decay between start_decay and end_decay learning steps
        if (self.learning_steps < self.params.start_decay
                or self.learning_steps > self.params.end_decay):
            return

        decay_rate = (self.params.eps_start - self.params.eps_end) / (
            self.params.end_decay - self.params.start_decay)
        self.epsilon = self.params.eps_start - (
            decay_rate * (self.learning_steps - self.params.start_decay))
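
        # Worked example (hypothetical values): with eps_start=1.0,
        # eps_end=0.1, start_decay=100_000 and end_decay=300_000,
        # decay_rate = 0.9 / 200_000 = 4.5e-6, so after 200_000 learning
        # steps epsilon = 1.0 - 4.5e-6 * 100_000 = 0.55.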

    def select_action(self, state: Tensor, num_actions: int) -> Tensor:
        """
        Select an action using an epsilon-greedy policy: either the greedy
        action with the highest predicted Q-value, or a random action for
        exploration.

        Parameters
        ----------
        state : Tensor
            Stacked observations.
        num_actions : int
            Total number of actions available to the agent.

        Returns
        ----------
        Tensor
            Single scalar action value stored in tensor.
        """
        # threshold for exploration
        sample = random.random()

        # update epsilon value
        self.epsilon_decay()

        # exploit
        if sample > self.epsilon:
            with torch.no_grad():
                # choose action with highest q-value
                state = state.unsqueeze(0).to(self.device)
                return self.pred_net(state).max(1)[1]

        # explore
        else:
            # select randomly from set of actions
            return torch.tensor([random.randrange(num_actions)],
                                dtype=torch.long,
                                device=self.device)

    def shape_reward(self, reward: int, action: int, done: bool) -> int:
        """
        Shape the reward returned by the environment to facilitate faster
        learning; the large raw values provided by VizDoom destabilize learning.

        Parameters
        ----------
        reward : int
            Reward from environment.
        action : int
            Value corresponding to agent action.
        done : bool
            Boolean indicating episode end.

        Returns
        ----------
        int
            Reshaped reward value

        """

        if self.game == "defend":
            # missed shot
            if action == 2 and reward < 0:
                reward = -0.1

            # terminal state
            elif done:
                reward = 0

            # movement is small negative
            elif reward < 0:
                reward = -0.01

            # shot hit
            elif reward > 0:
                reward = 1.0

        elif self.game == "health":
            if reward > 0:
                reward = 0.01

            else:
                reward = -1.0

        elif self.game == "center":
            if action == 2 and reward <= 0:
                reward = -0.1

        return reward

    def end_condition(self, reward: int) -> bool:
        """
        Determine end condition of game (shooting monster, dying).

        Parameters
        ----------
        reward : int
            Reward received.

        Returns
        ----------
        bool
            True if the episode-ending condition for the current game is met.

        """

        if self.game == "health":
            return reward < 0

        elif self.game == "defend":
            return reward > 0

        elif self.game == "center":
            return reward == -1.0

    def train_dqn(self):
        """
        Perform optimization on the DQN given a batch of randomly
        sampled transitions from the replay buffer.
        """

        # wait until the replay buffer holds enough transitions before training
        if len(self.replay_memory) < 5000:
            return

        self.learning_steps += 1

        # sample batch from replay memory
        batch = self.replay_memory.sample(self.params.batch_size)

        # extract new states
        new_states = [x[2] for x in batch]

        # can't transition to None (termination)
        non_terminating_filter = torch.tensor(
            tuple(map(lambda s: s is not None, new_states)),
            device=self.device,
        )
        non_terminating = torch.stack([s for s in new_states
                                       if s is not None]).to(self.device)

        # extract states, actions and rewards
        states = torch.stack([x[0] for x in batch]).to(self.device)
        actions = torch.stack([x[1] for x in batch]).to(self.device)
        rewards = torch.stack([x[3] for x in batch]).to(self.device)

        # network predictions
        predicted = self.pred_net(states)
        action_val = predicted.gather(1, actions)

        # init 0 for terminal transitions
        target_vals = torch.zeros(self.params.batch_size, device=self.device)

        # take the maximum Q-value over actions for the new states
        target_vals[non_terminating_filter] = (
            self.target_net(non_terminating).max(dim=1)[0]).detach()

        target_vals = target_vals.unsqueeze(1)

        # target for TD error
        target_update = (self.params.gamma * target_vals) + rewards

        # huber loss for TD error
        huber_loss = F.smooth_l1_loss(action_val, target_update)

        # compute gradient values
        self.optimizer.zero_grad()
        huber_loss.backward()

        # gradient clipping for stability
        for param in self.pred_net.parameters():
            param.grad.data.clamp_(-1, 1)

        # parameter update
        self.optimizer.step()

    def train(self):
        """Run a single epoch of training."""

        self.steps = 0
        self.pred_net.train()
        self.target_net.eval()

        # progress bar over episodes
        pbar = tqdm(range(self.episode, self.params.episodes), unit="episodes")
        for episode in pbar:
            # tracking number of steps
            pbar.set_description("eps: {} ls: {}, steps: {}".format(
                self.epsilon, self.learning_steps, self.steps))

            episode_steps = 0
            episode_sum = 0
            done = False
            self.env.reset()

            # frame skipping vars
            num_skipped = 0
            skipped_rewards = 0

            # first action selected randomly
            action = torch.tensor([self.env.action_space.sample()],
                                  device=self.device)

            self.reset_stack()

            # until episode termination
            while not done:
                # take action
                action_val = action.detach().clone().item()
                observation, reward, done, _ = self.env.step(action_val)

                reward = self.shape_reward(reward, action_val, done)

                # cumulative sum of skipped frame rewards
                skipped_rewards += reward

                episode_sum += reward
                episode_steps += 1

                # only stack a frame every skip_frames steps, or on negative reward / episode end
                if num_skipped == self.params.skip_frames or reward < 0 or done:

                    # reset counter
                    num_skipped = 0

                    # get old stack, and update stack with current observation
                    if len(self.frame_stack) > 0:
                        old_stack = torch.cat(tuple(self.frame_stack),
                                              axis=0).to(self.device)
                        curr_size, _, _ = old_stack.shape

                    else:
                        curr_size = 0

                    self.update_stack(observation)

                    # build the updated stack unless the episode has ended
                    if not done:
                        updated_stack = torch.cat(tuple(self.frame_stack),
                                                  axis=0).to(self.device)
                    else:
                        # when we've reached a terminal state
                        updated_stack = None

                    # need two stacks for transition
                    if curr_size == self.params.stack_size:
                        self.replay_memory.add_memory(
                            old_stack,
                            action,
                            updated_stack,
                            torch.tensor([skipped_rewards],
                                         device=self.device),
                        )

                    skipped_rewards = 0

                    # if we can select action using frame stack
                    if len(self.frame_stack) == 4:
                        action = self.select_action(
                            torch.cat(tuple(self.frame_stack),
                                      axis=0).to(self.device),
                            self.num_actions,
                        )
                        self.steps += 1

                    self.train_dqn()

                    # update target network
                    if self.steps % 2000 == 0:
                        self.target_net.load_state_dict(
                            self.pred_net.state_dict())
                        self.target_net.eval()

                    # full parameter saving (expensive)
                    if self.learning_steps > 0 and self.learning_steps % 10000 == 0:

                        torch.save(self.pred_net.state_dict(),
                                   self.load_path + "model.pk")

                        # save full model in case of restart
                        torch.save(
                            {
                                "episode": episode,
                                "steps": self.steps,
                                "learning_steps": self.learning_steps,
                                "model_state_dict": self.pred_net.state_dict(),
                                "optimizer_state_dict":
                                self.optimizer.state_dict(),
                                "replay_memory": self.replay_memory,
                                "losses": self.losses,
                                "params": self.params,
                                "frame_stack": self.frame_stack,
                                "epsilon": self.epsilon,
                            },
                            self.load_path + "full_model.pk",
                        )

                else:
                    num_skipped += 1

            self.writer.add_scalar("Average Reward",
                                   episode_sum / episode_steps, episode)

        self.env.close()

    def evaluate(self):
        """
        Visually evaluate model performance by having the agent follow a
        greedy policy defined by the DQN.
        """

        # load pretrained model
        self.pred_net.load_state_dict(
            torch.load(self.load_path + "model.pk",
                       map_location=torch.device(self.device)))
        self.pred_net.eval()

        self.epsilon = 0.1  # small exploration rate during evaluation

        steps = 0
        for episode in tqdm(range(self.params.episodes),
                            desc="episodes",
                            unit="episodes"):

            episode_sum = 0
            episode_steps = 0
            done = False
            self.env.reset()

            # For frame skipping
            num_skipped = 0

            action = self.env.action_space.sample()

            self.reset_stack()

            while not done:
                self.env.render()
                observation, reward, done, _ = self.env.step(action)

                # only stack a frame every skip_frames steps
                if num_skipped == self.params.skip_frames - 1:

                    # reset counter
                    num_skipped = 0

                    # get old stack, and update stack with current observation
                    if len(self.frame_stack) > 0:
                        old_stack = torch.cat(tuple(self.frame_stack), axis=0)
                        curr_size, _, _ = old_stack.shape

                    else:
                        curr_size = 0

                    self.update_stack(observation)

                    # if we can select action using frame stack
                    if len(self.frame_stack) == 4:
                        action = self.select_action(
                            torch.cat(tuple(self.frame_stack)),
                            self.num_actions)

                else:
                    num_skipped += 1
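
A minimal usage sketch for this example (not part of the original code): it assumes the imports the class already relies on (torch, gym, the DQN / ReplayMemory / helper definitions) and a hypothetical namespace-like params object exposing the attributes the Trainer reads (device, game, load_model, eps_start, eps_end, start_decay, end_decay, stack_size, skip_frames, batch_size, gamma, episodes). The values below are placeholders, not the author's settings.

import argparse

import torch

# Hypothetical hyperparameters; the attribute names mirror what Trainer reads.
params = argparse.Namespace(
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    game="defend",
    load_model=False,
    eps_start=1.0,
    eps_end=0.1,
    start_decay=100_000,
    end_decay=300_000,
    stack_size=4,
    skip_frames=3,
    batch_size=32,
    gamma=0.99,
    episodes=1_000,
)

trainer = Trainer(params)
trainer.train()       # train from scratch
# trainer.evaluate()  # or render a greedy rollout of a saved model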
Example No. 2
class Agent:
    """Distributional (C51-style) DQN agent with TD-error-prioritized replay."""

    def __init__(self, max_memory, batch_size, action_size, atom_size, input_size, kernel_size):
        self.z = np.linspace(V_MIN, V_MAX, ATOM_SIZE)
        self.action_size = action_size
        self.epsilon = EPSILON
        self.batch_size = batch_size
        self.atom_size = atom_size
        self.memory = ReplayMemory(max_memory)
        self.brain = RainbowDQN(action_size=action_size, atom_size=atom_size,
                                input_size=input_size, kernel_size=kernel_size)
        self.target_brain = RainbowDQN(action_size=action_size, atom_size=atom_size,
                                       input_size=input_size, kernel_size=kernel_size)
        self.target_brain.load_state_dict(self.brain.state_dict())
        self.optim = optim.Adam(self.brain.parameters(), lr=0.001)

    def step(self, state_input):
        probs = self.brain(state_input)
        best_action = self.select_best_action(probs)
        return best_action

    def select_best_action(self, probs):
        numpy_probs = self.variable_to_numpy(probs)
        z_probs = np.multiply(numpy_probs, self.z)
        best_action = np.sum(z_probs, axis=1).argmax()
        # best_action = np.argmax(numpy_probs, axis=1)
        return best_action

    def store_states(self, states, best_action, reward, done, next_states):
        td = self.calculate_td(states, best_action, reward, done, next_states)
        self.memory.add_memory(states, best_action, reward, done, next_states, td=td)

    def variable_to_numpy(self, probs):
        # probs is a tensor of softmax probabilities; convert it to NumPy
        numpy_probs = probs.data.numpy()
        return numpy_probs

    #TODO find out why td does not get -100 reward
    def calculate_td(self, states, best_action, reward, done, next_states):
        probs = self.brain(states)
        numpy_probs = self.variable_to_numpy(probs)
        # states_prob = np.multiply(numpy_probs, self.z)
        # states_q_value = np.sum(states_prob, axis=1)[best_action]
        states_q_value = numpy_probs[0][best_action]

        next_probs = self.brain(next_states)
        numpy_next_probs = self.variable_to_numpy(next_probs)
        # next_states_prob = np.multiply(numpy_next_probs, self.z)
        # max_next_states_q_value = np.sum(next_states_prob, axis=1).max()
        max_next_states_q_value = np.max(numpy_next_probs, axis=1)[0]

        if done:
            td = reward - states_q_value
        else:
            td = (reward + gamma * max_next_states_q_value) - states_q_value

        return abs(td)

    def learn(self):
        # make sure there are at least batch_size transitions in memory before training
        if self.memory.count < self.batch_size:
            return

        tree_indexes, tds, batches = self.memory.get_memory(self.batch_size)
        total_loss = None
        for index, batch in enumerate(batches):

            # FIXME: handle the None entries returned by the memory
            if batch is None:
                continue

            state_input = batch[0]
            best_action = batch[1]
            reward = batch[2]
            done = batch[3]
            next_state_input = batch[4]

            current_q = self.brain(state_input)
            next_best_action = self.step(next_state_input)
            # max_current_q = torch.max(current_q)

            next_z_prob = self.target_brain(next_state_input)
            next_z_prob = self.variable_to_numpy(next_z_prob)

            # target = reward + (1 - done) * gamma * next_z_prob.data[0][next_best_action]
            # target = Variable(torch.FloatTensor([target]))

            #TODO finish single dqn with per
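
            # The block below is a categorical (C51-style) projection of the
            # target distribution onto the fixed support self.z: each atom z_j
            # is mapped through Tz = clip(r + gamma * z_j, V_MIN, V_MAX),
            # b = (Tz - V_MIN) / dz, and its probability mass is split between
            # the neighbouring support indices floor(b) and ceil(b).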

            target_z_prob = np.zeros([self.action_size, ATOM_SIZE], dtype=np.float32)
            if done:
                Tz = min(V_MAX, max(V_MIN, reward))
                b = (Tz - V_MIN) / (self.z[1] - self.z[0])
                m_l = math.floor(b)
                m_u = math.ceil(b)
                target_z_prob[best_action][m_l] += (m_u - b)
                target_z_prob[best_action][m_u] += (b - m_l)
            else:
                for z_index in range(len(next_z_prob)):
                    Tz = min(V_MAX, max(V_MIN, reward + gamma * self.z[z_index]))
                    b = (Tz - V_MIN) / (self.z[1] - self.z[0])
                    m_l = math.floor(b)
                    m_u = math.ceil(b)

                    target_z_prob[best_action][m_l] += next_z_prob[next_best_action][z_index] * (m_u - b)
                    target_z_prob[best_action][m_u] += next_z_prob[next_best_action][z_index] * (b - m_l)
            target_z_prob = Variable(torch.from_numpy(target_z_prob))

            # backward propagate
            output_prob = self.brain(state_input)[0]
            loss = -torch.sum(target_z_prob * torch.log(output_prob + 1e-8))

            # loss = F.mse_loss(max_current_q, target)
            total_loss = loss if total_loss is None else total_loss + loss

            # update td
            td = self.calculate_td(state_input, best_action, reward, done, next_state_input)
            tds[index] = td

        self.optim.zero_grad()
        total_loss.backward()
        self.optim.step()

        # copy the online network's weights into the target network
        self.target_brain.load_state_dict(self.brain.state_dict())

        self.memory.update_memory(tree_indexes, tds)
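
A minimal usage sketch for this example (not part of the original code): it assumes the module-level names the Agent already depends on (V_MIN, V_MAX, ATOM_SIZE, EPSILON, gamma, ReplayMemory, RainbowDQN), plus a hypothetical Gym-style environment id ENV_NAME, placeholder INPUT_SIZE / KERNEL_SIZE values matching what this RainbowDQN expects, and a hypothetical to_input() helper that converts a raw observation into the network's input tensor.

import gym

env = gym.make(ENV_NAME)  # hypothetical environment id
agent = Agent(max_memory=10_000, batch_size=32,
              action_size=env.action_space.n, atom_size=ATOM_SIZE,
              input_size=INPUT_SIZE, kernel_size=KERNEL_SIZE)

# Hypothetical driver loop: act greedily from the distributional network,
# store transitions with their TD priority, and learn after every step.
for episode in range(500):
    state = to_input(env.reset())
    done = False
    while not done:
        action = agent.step(state)
        observation, reward, done, _ = env.step(int(action))
        next_state = to_input(observation)
        agent.store_states(state, action, reward, done, next_state)
        agent.learn()
        state = next_state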