Example #1
class AtariGame(Game):
    def __init__(self,
                 rom_path=_default_rom_path,
                 frame_skip=4, history_length=4,
                 resize_mode='scale', resized_rows=84, resized_cols=84, crop_offset=8,
                 display_screen=False, max_null_op=30,
                 replay_memory_size=1000000,
                 replay_start_size=100,
                 death_end_episode=True):
        super(AtariGame, self).__init__()
        self.rng = get_numpy_rng()
        self.ale = ale_load_from_rom(rom_path=rom_path, display_screen=display_screen)
        self.start_lives = self.ale.lives()
        self.action_set = self.ale.getMinimalActionSet()
        self.resize_mode = resize_mode
        self.resized_rows = resized_rows
        self.resized_cols = resized_cols
        self.crop_offset = crop_offset
        self.frame_skip = frame_skip
        self.history_length = history_length
        self.max_null_op = max_null_op
        self.death_end_episode = death_end_episode
        self.screen_buffer_length = 2
        self.screen_buffer = numpy.empty((self.screen_buffer_length,
                                          self.ale.getScreenDims()[1], self.ale.getScreenDims()[0]),
                                         dtype='uint8')
        self.replay_memory = ReplayMemory(state_dim=(resized_rows, resized_cols),
                                          history_length=history_length,
                                          memory_size=replay_memory_size,
                                          replay_start_size=replay_start_size)
        self.start()

    def start(self):
        self.ale.reset_game()
        null_op_num = self.rng.randint(self.screen_buffer_length,
                                       max(self.max_null_op + 1, self.screen_buffer_length + 1))
        for i in range(null_op_num):
            self.ale.act(0)
            self.ale.getScreenGrayscale(self.screen_buffer[i % self.screen_buffer_length, :, :])
        self.total_reward = 0
        self.episode_reward = 0
        self.episode_step = 0
        self.max_episode_step = DEFAULT_MAX_EPISODE_STEP
        self.start_lives = self.ale.lives()

    def force_restart(self):
        self.start()
        self.replay_memory.clear()


    def begin_episode(self, max_episode_step=DEFAULT_MAX_EPISODE_STEP):
        """
            Begin an episode of a game instance. We can play the game for a maximum of
            `max_episode_step` and after that, we are forced to restart
        """
        if self.episode_step > self.max_episode_step or self.ale.game_over():
            self.start()
        else:
            for i in range(self.screen_buffer_length):
                self.ale.act(0)
                self.ale.getScreenGrayscale(self.screen_buffer[i % self.screen_buffer_length, :, :])
        self.max_episode_step = max_episode_step
        self.start_lives = self.ale.lives()
        self.episode_reward = 0
        self.episode_step = 0

    @property
    def episode_terminate(self):
        termination_flag = self.ale.game_over() or self.episode_step >= self.max_episode_step
        if self.death_end_episode:
            return (self.ale.lives() < self.start_lives) or termination_flag
        else:
            return termination_flag

    @property
    def state_enabled(self):
        return self.replay_memory.size >= self.replay_memory.history_length

    def get_observation(self):
        image = self.screen_buffer.max(axis=0)
        if self.resize_mode == 'crop':
            original_rows, original_cols = image.shape
            new_resized_rows = int(round(
                float(original_rows) * self.resized_cols / original_cols))
            resized = cv2.resize(image, (self.resized_cols, new_resized_rows),
                                 interpolation=cv2.INTER_LINEAR)
            crop_y_cutoff = new_resized_rows - self.crop_offset - self.resized_rows
            img = resized[crop_y_cutoff:crop_y_cutoff + self.resized_rows, :]
            return img
        else:
            return cv2.resize(image, (self.resized_cols, self.resized_rows),
                              interpolation=cv2.INTER_LINEAR)

    def play(self, a):
        assert not self.episode_terminate, \
            "Warning, the episode seems to have terminated. " \
            "Call game.begin_episode(max_episode_step) to start a new episode " \
            "or game.start() to force a restart."
        self.episode_step += 1
        reward = 0.0
        action = self.action_set[a]
        for i in range(self.frame_skip):
            reward += self.ale.act(action)
            self.ale.getScreenGrayscale(self.screen_buffer[i % self.screen_buffer_length, :, :])
        self.total_reward += reward
        self.episode_reward += reward
        ob = self.get_observation()
        terminate_flag = self.episode_terminate
        self.replay_memory.append(ob, a, numpy.clip(reward, -1, 1), terminate_flag)
        return reward, terminate_flag
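
A minimal usage sketch (hypothetical, not part of the original example; it relies only on the methods defined above and on the repository's default ROM path):

# Hypothetical usage sketch: the random-action loop is illustrative only.
game = AtariGame(display_screen=False)          # uses the repository's default ROM path
game.begin_episode(max_episode_step=10000)
while not game.episode_terminate:
    a = game.rng.randint(len(game.action_set))  # random index into the minimal action set
    reward, terminal = game.play(a)
print('episode reward:', game.episode_reward)
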
Example #2
class Trainer():
    def __init__(self,
                 env,
                 network,
                 update_timestep=2000,
                 batch_size=512,
                 gamma=0.99,
                 epsilon=0.2,
                 c1=0.5,
                 c2=0.01,
                 lr=0.01,
                 weight_decay=0.0,
                 min_std=0.1):
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.memory = ReplayMemory(update_timestep)
        self.env = env
        self.policy_net = network
        self.policy_net.to(self.device)

        self.gamma = gamma
        self.epsilon = epsilon
        self.c1 = c1
        self.c2 = c2

        self.update_timestep = update_timestep
        self.min_std = min_std

        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(),
                                          lr=lr,
                                          weight_decay=weight_decay,
                                          betas=(0.9, 0.999))
        self.mse = nn.MSELoss()

        self.reward_log = []
        self.loss_log = []
        self.time_log = []
        self.num_updates = 0

    def ppo_update(self, num_epochs):
        #TODO: Fix wrong calculation of q_values

        self.num_updates += 1

        experience = self.memory.sample()
        # The state dimension needs to be squeezed so it does not cause
        # trouble in the log_probs computation inside the evaluate function.
        exp_states = torch.stack(experience.state).squeeze().float()
        exp_actions = torch.stack(experience.action)
        exp_rewards = experience.reward
        exp_dones = experience.done
        exp_log_probs = torch.stack(experience.log_prob).squeeze()

        #calculate q-values
        q_values = []
        discounted_reward = 0
        for reward, done in zip(reversed(exp_rewards), reversed(exp_dones)):
            if done:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            q_values.insert(0, discounted_reward)

        # build the return tensor as float32 so it matches the network outputs
        q_values = torch.tensor(q_values, dtype=torch.float32, device=self.device)
        q_values = (q_values - q_values.mean()) / (q_values.std() + 1e-5)

        dataset = TensorDataset(exp_states, exp_actions, exp_log_probs,
                                q_values)
        trainloader = DataLoader(dataset,
                                 batch_size=self.batch_size,
                                 shuffle=False)

        train_loss = 0
        num_iterations = 0
        for _ in range(num_epochs):
            for state_batch, action_batch, log_probs_batch, q_value_batch in trainloader:

                #evaluate previous states
                state_values, log_probs, dist_entropy = self.policy_net.evaluate(
                    state_batch, action_batch)

                # Calculate ratio (pi_theta / pi_theta__old):
                ratios = torch.exp(log_probs - log_probs_batch.detach())

                # Calculate Surrogate Loss:
                advantages = q_value_batch - state_values.detach()
                surr1 = ratios * advantages
                surr2 = torch.clamp(ratios, 1 - self.epsilon,
                                    1 + self.epsilon) * advantages

                actor_loss = torch.min(surr1, surr2)
                critic_loss = self.mse(state_values, q_value_batch)

                # - because of gradient ascent
                loss = (-actor_loss + self.c1 * critic_loss -
                        self.c2 * dist_entropy).mean()

                # detach to a Python float so the running log does not keep
                # the computation graph alive
                train_loss += loss.item()

                # take gradient step
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                num_iterations += 1

        self.loss_log.append(train_loss / num_iterations)

    def train(self, num_episodes, num_epochs, max_timesteps, render=False):
        timestep = 0
        for i_episode in range(1, num_episodes + 1):
            state = self.env.reset()
            running_reward = 0
            for i_timestep in range(max_timesteps):
                timestep += 1

                # compute action
                state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
                prev_state = state
                with torch.no_grad():
                    action, action_log_prob = self.policy_net.act(state)

                state, reward, done, _ = self.env.step(action.cpu().numpy())
                running_reward += reward
                transition = Transition(prev_state, action, reward,
                                        action_log_prob, done)
                self.memory.push(transition)

                #Update policy network
                if timestep % self.update_timestep == 0:
                    self.ppo_update(num_epochs)
                    print("Policy updated")
                    self.memory.clear()
                    timestep = 0

                if render:
                    self.env.render()

                if done:
                    break

            print('Episode {} Done, \t length: {} \t reward: {}'.format(
                i_episode, i_timestep, running_reward))
            self.reward_log.append(int(running_reward))
            self.time_log.append(i_timestep)

    def plot_rewards(self):
        plt.plot(self.reward_log)

    def plot_loss(self):
        plt.plot(self.loss_log)
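
A minimal training sketch for the Trainer above, assuming a Gym environment with the old 4-tuple `step` API and an actor-critic `network` that exposes the `act(state)` and `evaluate(states, actions)` methods used in `ppo_update` (the environment name and the `ActorCritic` class are assumptions, not part of the original example):

import gym

# Hypothetical setup: ActorCritic is assumed to implement the act()/evaluate()
# interface that Trainer relies on.
env = gym.make('Pendulum-v1')
network = ActorCritic(state_dim=env.observation_space.shape[0],
                      action_dim=env.action_space.shape[0])
trainer = Trainer(env, network, update_timestep=2000, batch_size=512, lr=3e-4)
trainer.train(num_episodes=500, num_epochs=4, max_timesteps=300)
trainer.plot_rewards()
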
class AtariGame(Game):
    def __init__(self,
                 rom_path=_default_rom_path,
                 frame_skip=4,
                 history_length=4,
                 resize_mode='clip',
                 resized_rows=84,
                 resized_cols=84,
                 crop_offset=8,
                 display_screen=False,
                 max_null_op=30,
                 replay_memory_size=1000000,
                 replay_start_size=100,
                 death_end_episode=True):
        super(AtariGame, self).__init__()
        self.rng = get_numpy_rng()
        self.ale = ale_load_from_rom(rom_path=rom_path,
                                     display_screen=display_screen)
        self.start_lives = self.ale.lives()
        self.action_set = self.ale.getMinimalActionSet()
        self.resize_mode = resize_mode
        self.resized_rows = resized_rows
        self.resized_cols = resized_cols
        self.crop_offset = crop_offset
        self.frame_skip = frame_skip
        self.history_length = history_length
        self.max_null_op = max_null_op
        self.death_end_episode = death_end_episode
        self.screen_buffer_length = 2
        self.screen_buffer = numpy.empty(
            (self.screen_buffer_length, self.ale.getScreenDims()[1],
             self.ale.getScreenDims()[0]),
            dtype='uint8')
        self.replay_memory = ReplayMemory(state_dim=(resized_rows,
                                                     resized_cols),
                                          history_length=history_length,
                                          memory_size=replay_memory_size,
                                          replay_start_size=replay_start_size)
        self.start()

    def start(self):
        self.ale.reset_game()
        null_op_num = self.rng.randint(
            self.screen_buffer_length,
            max(self.max_null_op + 1, self.screen_buffer_length + 1))
        for i in range(null_op_num):
            self.ale.act(0)
            self.ale.getScreenGrayscale(
                self.screen_buffer[i % self.screen_buffer_length, :, :])
        self.total_reward = 0
        self.episode_reward = 0
        self.episode_step = 0
        self.max_episode_step = DEFAULT_MAX_EPISODE_STEP
        self.start_lives = self.ale.lives()

    def force_restart(self):
        self.start()
        self.replay_memory.clear()

    def begin_episode(self, max_episode_step=DEFAULT_MAX_EPISODE_STEP):
        """
            Begin an episode of a game instance. We can play the game for a maximum of
            `max_episode_step` and after that, we are forced to restart
        """
        if self.episode_step > self.max_episode_step or self.ale.game_over():
            self.start()
        else:
            for i in range(self.screen_buffer_length):
                self.ale.act(0)
                self.ale.getScreenGrayscale(
                    self.screen_buffer[i % self.screen_buffer_length, :, :])
        self.max_episode_step = max_episode_step
        self.start_lives = self.ale.lives()
        self.episode_reward = 0
        self.episode_step = 0

    @property
    def episode_terminate(self):
        termination_flag = (self.ale.game_over()
                            or self.episode_step >= self.max_episode_step)
        if self.death_end_episode:
            return (self.ale.lives() < self.start_lives) or termination_flag
        else:
            return termination_flag

    @property
    def state_enabled(self):
        return self.replay_memory.size >= self.replay_memory.history_length

    def get_observation(self):
        image = self.screen_buffer.max(axis=0)

        if self.resize_mode == 'crop':
            original_rows, original_cols = image.shape
            new_resized_rows = int(
                round(
                    float(original_rows) * self.resized_cols / original_cols))
            resized = cv2.resize(image, (self.resized_cols, new_resized_rows),
                                 interpolation=cv2.INTER_LINEAR)
            crop_y_cutoff = new_resized_rows - self.crop_offset - self.resized_rows
            img = resized[crop_y_cutoff:crop_y_cutoff + self.resized_rows, :]
            return img
        else:
            # plt.imshow(image, cmap='gray')
            # plt.show()
            return cv2.resize(image, (self.resized_cols, self.resized_rows),
                              interpolation=cv2.INTER_LINEAR)

    def play(self, a):
        assert not self.episode_terminate, \
            "Warning, the episode seems to have terminated. " \
            "Call game.begin_episode(max_episode_step) to start a new episode " \
            "or game.start() to force a restart."
        self.episode_step += 1
        reward = 0.0
        action = self.action_set[int(a)]
        for i in range(self.frame_skip):
            reward += self.ale.act(action)
            self.ale.getScreenGrayscale(
                self.screen_buffer[i % self.screen_buffer_length, :, :])
        self.total_reward += reward
        self.episode_reward += reward
        ob = self.get_observation()
        # plt.imshow(ob, cmap="gray")
        # plt.show()

        terminate_flag = self.episode_terminate

        self.replay_memory.append(ob, a, numpy.clip(reward, -1, 1),
                                  terminate_flag)
        return reward, terminate_flag
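
For completeness, a hedged warm-up sketch for this variant that uses only the methods and properties defined above to fill the replay memory with random transitions until full-history states are available:

# Hypothetical warm-up loop; the random action choice is illustrative only.
game = AtariGame()                        # this variant defaults to 'clip' resize mode
game.begin_episode()
while not game.state_enabled:
    if game.episode_terminate:
        game.begin_episode()
    a = game.rng.randint(len(game.action_set))
    game.play(a)                          # appends (ob, a, clipped reward, done) to replay_memory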