Example #1
0
File: display.py — Project: mcimpoi/unreal
class Display(object):
    def __init__(self, display_size):
        pygame.init()

        self.surface = pygame.display.set_mode(display_size, 0, 24)
        pygame.display.set_caption('UNREAL')

        self.action_size = Environment.get_action_size(flags.env_type,
                                                       flags.env_name)
        self.objective_size = Environment.get_objective_size(
            flags.env_type, flags.env_name)
        self.global_network = UnrealModel(self.action_size,
                                          self.objective_size,
                                          -1,
                                          flags.use_lstm,
                                          flags.use_pixel_change,
                                          flags.use_value_replay,
                                          flags.use_reward_prediction,
                                          0.0,
                                          0.0,
                                          "/cpu:0",
                                          for_display=True)
        self.environment = Environment.create_environment(
            flags.env_type,
            flags.env_name,
            env_args={
                'episode_schedule': flags.split,
                'log_action_trace': flags.log_action_trace,
                'max_states_per_scene': flags.episodes_per_scene,
                'episodes_per_scene_test': flags.episodes_per_scene
            })
        self.font = pygame.font.SysFont(None, 20)
        self.value_history = ValueHistory()
        self.state_history = StateHistory()
        self.episode_reward = 0

    def update(self, sess):
        self.surface.fill(BLACK)
        self.process(sess)
        pygame.display.update()

    def choose_action(self, pi_values):
        return np.random.choice(range(len(pi_values)), p=pi_values)

    def scale_image(self, image, scale):
        return image.repeat(scale, axis=0).repeat(scale, axis=1)

    def draw_text(self, str, left, top, color=WHITE):
        text = self.font.render(str, True, color, BLACK)
        text_rect = text.get_rect()
        text_rect.left = left
        text_rect.top = top
        self.surface.blit(text, text_rect)

    def draw_center_text(self, str, center_x, top):
        text = self.font.render(str, True, WHITE, BLACK)
        text_rect = text.get_rect()
        text_rect.centerx = center_x
        text_rect.top = top
        self.surface.blit(text, text_rect)

    def show_pixel_change(self, pixel_change, left, top, rate, label):
        """
    Show pixel change
    """
        pixel_change_ = np.clip(pixel_change * 255.0 * rate, 0.0, 255.0)
        data = pixel_change_.astype(np.uint8)
        data = np.stack([data for _ in range(3)], axis=2)
        data = self.scale_image(data, 4)
        image = pygame.image.frombuffer(data, (20 * 4, 20 * 4), 'RGB')
        self.surface.blit(image, (left + 8 + 4, top + 8 + 4))
        self.draw_center_text(label, left + 100 / 2, top + 100)

    def show_policy(self, pi):
        """
    Show action probability.
    """
        start_x = 10

        y = 150

        for i in range(len(pi)):
            width = pi[i] * 100
            pygame.draw.rect(self.surface, WHITE, (start_x, y, width, 10))
            y += 20
        self.draw_center_text("PI", 50, y)

    def show_image(self, state):
        """
    Show input image
    """
        state_ = state * 255.0
        data = state_.astype(np.uint8)
        image = pygame.image.frombuffer(data, (84, 84), 'RGB')
        self.surface.blit(image, (8, 8))
        self.draw_center_text("input", 50, 100)

    def show_value(self):
        if self.value_history.is_empty:
            return

        min_v = float("inf")
        max_v = float("-inf")

        values = self.value_history.values

        for v in values:
            min_v = min(min_v, v)
            max_v = max(max_v, v)

        top = 150
        left = 150
        width = 100
        height = 100
        bottom = top + width
        right = left + height

        d = max_v - min_v
        last_r = 0.0
        for i, v in enumerate(values):
            r = (v - min_v) / d
            if i > 0:
                x0 = i - 1 + left
                x1 = i + left
                y0 = bottom - last_r * height
                y1 = bottom - r * height
                pygame.draw.line(self.surface, BLUE, (x0, y0), (x1, y1), 1)
            last_r = r

        pygame.draw.line(self.surface, WHITE, (left, top), (left, bottom), 1)
        pygame.draw.line(self.surface, WHITE, (right, top), (right, bottom), 1)
        pygame.draw.line(self.surface, WHITE, (left, top), (right, top), 1)
        pygame.draw.line(self.surface, WHITE, (left, bottom), (right, bottom),
                         1)

        self.draw_center_text("V", left + width / 2, bottom + 10)

    def show_reward_prediction(self, rp_c, reward):
        start_x = 310
        reward_index = 0
        if reward == 0:
            reward_index = 0
        elif reward > 0:
            reward_index = 1
        elif reward < 0:
            reward_index = 2

        y = 150

        labels = ["0", "+", "-"]

        for i in range(len(rp_c)):
            width = rp_c[i] * 100

            if i == reward_index:
                color = RED
            else:
                color = WHITE
            pygame.draw.rect(self.surface, color, (start_x + 15, y, width, 10))
            self.draw_text(labels[i], start_x, y - 1, color)
            y += 20

        self.draw_center_text("RP", start_x + 100 / 2, y)

    def show_reward(self):
        self.draw_text("REWARD: {}".format(int(self.episode_reward)), 310, 10)

    def process(self, sess):
        last_action = self.environment.last_action
        last_reward = np.clip(self.environment.last_reward, -1, 1)
        last_action_reward = ExperienceFrame.concat_action_and_reward(
            last_action, self.action_size, last_reward,
            self.environment.last_state)

        if not flags.use_pixel_change:
            pi_values, v_value = self.global_network.run_base_policy_and_value(
                sess, self.environment.last_state, last_action_reward)
        else:
            pi_values, v_value, pc_q = self.global_network.run_base_policy_value_pc_q(
                sess, self.environment.last_state, last_action_reward)
        self.value_history.add_value(v_value)

        action = self.choose_action(pi_values)
        state, reward, terminal, pixel_change = self.environment.process(
            action)
        self.episode_reward += reward

        if terminal:
            self.environment.reset()
            self.episode_reward = 0

        self.show_image(state['image'])
        self.show_policy(pi_values)
        self.show_value()
        self.show_reward()

        if flags.use_pixel_change:
            self.show_pixel_change(pixel_change, 100, 0, 3.0, "PC")
            self.show_pixel_change(pc_q[:, :, action], 200, 0, 0.4, "PC Q")

        if flags.use_reward_prediction:
            if self.state_history.is_full:
                rp_c = self.global_network.run_rp_c(sess,
                                                    self.state_history.states)
                self.show_reward_prediction(rp_c, reward)

        self.state_history.add_state(state)

    def get_frame(self):
        data = self.surface.get_buffer().raw
        return data
Example #2
0
File: display.py — Project: kvas7andy/unreal
class Display(object):
    def __init__(self, display_size):
        pygame.init()

        self.surface = pygame.display.set_mode(display_size, 0, 24)
        name = 'UNREAL' if flags.segnet == 0 else "A3C ErfNet"
        pygame.display.set_caption(name)

        env_config = sim_config.get(flags.env_name)
        self.image_shape = [
            env_config.get('height', 88),
            env_config.get('width', 88)
        ]
        segnet_param_dict = {'segnet_mode': flags.segnet}
        is_training = tf.placeholder(tf.bool, name="training")
        map_file = env_config.get('objecttypes_file', '../../objectTypes.csv')
        self.label_mapping = pd.read_csv(map_file, sep=',', header=0)
        self.get_col_index()

        self.action_size = Environment.get_action_size(flags.env_type,
                                                       flags.env_name)
        self.objective_size = Environment.get_objective_size(
            flags.env_type, flags.env_name)
        self.global_network = UnrealModel(self.action_size,
                                          self.objective_size,
                                          -1,
                                          flags.use_lstm,
                                          flags.use_pixel_change,
                                          flags.use_value_replay,
                                          flags.use_reward_prediction,
                                          0.0,
                                          0.0,
                                          "/gpu:0",
                                          segnet_param_dict=segnet_param_dict,
                                          image_shape=self.image_shape,
                                          is_training=is_training,
                                          n_classes=flags.n_classes,
                                          segnet_lambda=flags.segnet_lambda,
                                          dropout=flags.dropout,
                                          for_display=True)
        self.environment = Environment.create_environment(
            flags.env_type,
            flags.env_name,
            flags.termination_time_sec,
            env_args={
                'episode_schedule': flags.split,
                'log_action_trace': flags.log_action_trace,
                'max_states_per_scene': flags.episodes_per_scene,
                'episodes_per_scene_test': flags.episodes_per_scene
            })
        self.font = pygame.font.SysFont(None, 20)
        self.value_history = ValueHistory()
        self.state_history = StateHistory()
        self.episode_reward = 0

    def update(self, sess):
        self.surface.fill(BLACK)
        self.process(sess)
        pygame.display.update()

    def choose_action(self, pi_values):
        return np.random.choice(range(len(pi_values)), p=pi_values)

    def scale_image(self, image, scale):
        return image.repeat(scale, axis=0).repeat(scale, axis=1)

    def draw_text(self, str, left, top, color=WHITE):
        text = self.font.render(str, True, color, BLACK)
        text_rect = text.get_rect()
        text_rect.left = left
        text_rect.top = top
        self.surface.blit(text, text_rect)

    def draw_center_text(self, str, center_x, top):
        text = self.font.render(str, True, WHITE, BLACK)
        text_rect = text.get_rect()
        text_rect.centerx = center_x
        text_rect.top = top
        self.surface.blit(text, text_rect)

    def show_pixel_change(self, pixel_change, left, top, rate, label):
        """
    Show pixel change
    """
        if "PC" in label:
            pixel_change_ = np.clip(pixel_change * 255.0 * rate, 0.0, 255.0)
            data = pixel_change_.astype(np.uint8)
            data = np.stack([data for _ in range(3)], axis=2)
            data = self.scale_image(data, 4)
            #print("PC shape", data.shape)
            image = pygame.image.frombuffer(data, (20 * 4, 20 * 4), 'RGB')
        else:
            pixel_change = self.scale_image(pixel_change, 2)
            #print("Preds shape", pixel_change.shape)
            image = pygame.image.frombuffer(pixel_change.astype(
                np.uint8), (self.image_shape[0] * 2, self.image_shape[1] * 2),
                                            'RGB')
        self.surface.blit(image, (2 * left + 16 + 8, 2 * top + 16 + 8))
        self.draw_center_text(label, 2 * left + 200 / 2, 2 * top + 200)

    def show_policy(self, pi):
        """
    Show action probability.
    """
        start_x = 10

        y = 150

        for i in range(len(pi)):
            width = pi[i] * 100
            pygame.draw.rect(self.surface, WHITE,
                             (2 * start_x, 2 * y, 2 * width, 2 * 10))
            y += 20
        self.draw_center_text("PI", 2 * 50, 2 * y)

    def show_image(self, state):
        """
    Show input image
    """
        state_ = state * 255.0
        data = state_.astype(np.uint8)
        data = self.scale_image(data, 2)
        image = pygame.image.frombuffer(
            data, (self.image_shape[0] * 2, self.image_shape[1] * 2), 'RGB')
        self.surface.blit(image, (8 * 2, 8 * 2))
        self.draw_center_text("input", 2 * 50, 2 * 100)

    def show_value(self):
        if self.value_history.is_empty:
            return

        min_v = float("inf")
        max_v = float("-inf")

        values = self.value_history.values

        for v in values:
            min_v = min(min_v, v)
            max_v = max(max_v, v)

        top = 150 * 2
        left = 150 * 2
        width = 100 * 2
        height = 100 * 2
        bottom = top + width
        right = left + height

        d = max_v - min_v
        last_r = 0.0
        for i, v in enumerate(values):
            r = (v - min_v) / d
            if i > 0:
                x0 = i - 1 + left
                x1 = i + left
                y0 = bottom - last_r * height
                y1 = bottom - r * height
                pygame.draw.line(self.surface, BLUE, (x0, y0), (x1, y1), 1)
            last_r = r

        pygame.draw.line(self.surface, WHITE, (left, top), (left, bottom), 1)
        pygame.draw.line(self.surface, WHITE, (right, top), (right, bottom), 1)
        pygame.draw.line(self.surface, WHITE, (left, top), (right, top), 1)
        pygame.draw.line(self.surface, WHITE, (left, bottom), (right, bottom),
                         1)

        self.draw_center_text("V", left + width / 2, bottom + 10)

    def show_reward_prediction(self, rp_c, reward):
        start_x = 310
        reward_index = 0
        if reward == 0:
            reward_index = 0
        elif reward > 0:
            reward_index = 1
        elif reward < 0:
            reward_index = 2

        y = 150

        labels = ["0", "+", "-"]

        for i in range(len(rp_c)):
            width = rp_c[i] * 100

            if i == reward_index:
                color = RED
            else:
                color = WHITE
            pygame.draw.rect(self.surface, color,
                             (2 * start_x + 2 * 15, 2 * y, 2 * width, 2 * 10))
            self.draw_text(labels[i], 2 * start_x, 2 * y - 2 * 1, color)
            y += 20

        self.draw_center_text("RP", 2 * start_x + 2 * 100 / 2, y)

    def show_reward(self):
        self.draw_text("REWARD: {:.4}".format(float(self.episode_reward)), 300,
                       2 * 10)

    def process(self, sess):
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        #sess.run(tf.initialize_all_variables())

        last_action = self.environment.last_action
        last_reward = self.environment.last_reward
        last_action_reward = ExperienceFrame.concat_action_and_reward(
            last_action, self.action_size, last_reward,
            self.environment.last_state)
        preds = None
        mode = "segnet" if flags.segnet >= 2 else ""
        mode = ""  #don't want preds
        if not flags.use_pixel_change:
            pi_values, v_value, preds = self.global_network.run_base_policy_and_value(
                sess,
                self.environment.last_state,
                last_action_reward,
                mode=mode)
        else:
            pi_values, v_value, pc_q = self.global_network.run_base_policy_value_pc_q(
                sess, self.environment.last_state, last_action_reward)

        #print(preds)
        self.value_history.add_value(v_value)

        prev_state = self.environment.last_state

        action = self.choose_action(pi_values)
        state, reward, terminal, pixel_change = self.environment.process(
            action)
        self.episode_reward += reward

        if terminal:
            self.environment.reset()
            self.episode_reward = 0

        self.show_image(state['image'])
        self.show_policy(pi_values)
        self.show_value()
        self.show_reward()

        if not flags.use_pixel_change:
            if preds is not None:
                self.show_pixel_change(self.label_to_rgb(preds), 100, 0, 3.0,
                                       "Preds")
                self.show_pixel_change(self.label_to_rgb(state['objectType']),
                                       200, 0, 0.4, "Segm Mask")
        else:
            self.show_pixel_change(pixel_change, 100, 0, 3.0, "PC")
            self.show_pixel_change(pc_q[:, :, action], 200, 0, 0.4, "PC Q")

        if flags.use_reward_prediction:
            if self.state_history.is_full:
                rp_c = self.global_network.run_rp_c(sess,
                                                    self.state_history.states)
                self.show_reward_prediction(rp_c, reward)

        self.state_history.add_state(state)

    def get_frame(self):
        data = self.surface.get_buffer().raw
        return data

    def get_col_index(self):
        ind_col = self.label_mapping[["index", "color"]].values
        index = ind_col[:, 0].astype(np.int)
        self.index, ind = np.unique(index, return_index=True)
        self.col = np.array([[int(x) for x in col.split('_')]
                             for col in ind_col[ind, 1]])

    def label_to_rgb(self, labels):
        #print(self.col)
        rgb_img = self.col[np.where(self.index[np.newaxis, :] == labels.ravel(
        )[:, np.newaxis])[1]].reshape(labels.shape + (3, ))
        return rgb_img