Example #1
class Evaluate(object):
  def __init__(self):
    self.action_size = Environment.get_action_size(flags.env_type, flags.env_name)
    self.objective_size = Environment.get_objective_size(flags.env_type, flags.env_name)
    self.global_network = UnrealModel(self.action_size,
                                      self.objective_size,
                                      -1,
                                      flags.use_lstm,
                                      flags.use_pixel_change,
                                      flags.use_value_replay,
                                      flags.use_reward_prediction,
                                      0.0,
                                      0.0,
                                      "/cpu:0",
                                      for_display=True)
    self.environment = Environment.create_environment(flags.env_type, flags.env_name,
                                                      env_args={'episode_schedule': flags.split,
                                                                'log_action_trace': flags.log_action_trace,
                                                                'max_states_per_scene': flags.episodes_per_scene,
                                                                'episodes_per_scene_test': flags.episodes_per_scene})
    self.episode_reward = 0

  def update(self, sess):
    self.process(sess)

  def is_done(self):
    return self.environment.is_all_scheduled_episodes_done()

  def choose_action(self, pi_values):
    return np.random.choice(range(len(pi_values)), p=pi_values)

  def process(self, sess):
    last_action = self.environment.last_action
    last_reward = np.clip(self.environment.last_reward, -1, 1)
    last_action_reward = ExperienceFrame.concat_action_and_reward(last_action, self.action_size,
                                                                  last_reward, self.environment.last_state)
    
    if not flags.use_pixel_change:
      pi_values, v_value = self.global_network.run_base_policy_and_value(sess,
                                                                         self.environment.last_state,
                                                                         last_action_reward)
    else:
      pi_values, v_value, pc_q = self.global_network.run_base_policy_value_pc_q(sess,
                                                                                self.environment.last_state,
                                                                                last_action_reward)
    action = self.choose_action(pi_values)
    state, reward, terminal, pixel_change = self.environment.process(action)
    self.episode_reward += reward
  
    if terminal:
      self.environment.reset()
      self.episode_reward = 0
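
The Evaluate class above only steps the agent; a minimal driver sketch is shown below. The TensorFlow 1.x session handling and the flags.checkpoint_dir flag are assumptions added for illustration, not part of the project snippet.

import tensorflow as tf

evaluate = Evaluate()

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Restore the shared network weights; the checkpoint directory flag is hypothetical.
saver = tf.train.Saver(evaluate.global_network.get_vars())
checkpoint = tf.train.get_checkpoint_state(flags.checkpoint_dir)
if checkpoint and checkpoint.model_checkpoint_path:
    saver.restore(sess, checkpoint.model_checkpoint_path)

# Step through every scheduled episode.
while not evaluate.is_done():
    evaluate.update(sess)

sess.close()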
Example #2
File: display.py Project: mcimpoi/unreal
class Display(object):
    def __init__(self, display_size):
        pygame.init()

        self.surface = pygame.display.set_mode(display_size, 0, 24)
        pygame.display.set_caption('UNREAL')

        self.action_size = Environment.get_action_size(flags.env_type,
                                                       flags.env_name)
        self.objective_size = Environment.get_objective_size(
            flags.env_type, flags.env_name)
        self.global_network = UnrealModel(self.action_size,
                                          self.objective_size,
                                          -1,
                                          flags.use_lstm,
                                          flags.use_pixel_change,
                                          flags.use_value_replay,
                                          flags.use_reward_prediction,
                                          0.0,
                                          0.0,
                                          "/cpu:0",
                                          for_display=True)
        self.environment = Environment.create_environment(
            flags.env_type,
            flags.env_name,
            env_args={
                'episode_schedule': flags.split,
                'log_action_trace': flags.log_action_trace,
                'max_states_per_scene': flags.episodes_per_scene,
                'episodes_per_scene_test': flags.episodes_per_scene
            })
        self.font = pygame.font.SysFont(None, 20)
        self.value_history = ValueHistory()
        self.state_history = StateHistory()
        self.episode_reward = 0

    def update(self, sess):
        self.surface.fill(BLACK)
        self.process(sess)
        pygame.display.update()

    def choose_action(self, pi_values):
        return np.random.choice(range(len(pi_values)), p=pi_values)

    def scale_image(self, image, scale):
        return image.repeat(scale, axis=0).repeat(scale, axis=1)

    def draw_text(self, str, left, top, color=WHITE):
        text = self.font.render(str, True, color, BLACK)
        text_rect = text.get_rect()
        text_rect.left = left
        text_rect.top = top
        self.surface.blit(text, text_rect)

    def draw_center_text(self, str, center_x, top):
        text = self.font.render(str, True, WHITE, BLACK)
        text_rect = text.get_rect()
        text_rect.centerx = center_x
        text_rect.top = top
        self.surface.blit(text, text_rect)

    def show_pixel_change(self, pixel_change, left, top, rate, label):
        """Show pixel change."""
        pixel_change_ = np.clip(pixel_change * 255.0 * rate, 0.0, 255.0)
        data = pixel_change_.astype(np.uint8)
        data = np.stack([data for _ in range(3)], axis=2)
        data = self.scale_image(data, 4)
        image = pygame.image.frombuffer(data, (20 * 4, 20 * 4), 'RGB')
        self.surface.blit(image, (left + 8 + 4, top + 8 + 4))
        self.draw_center_text(label, left + 100 / 2, top + 100)

    def show_policy(self, pi):
        """Show action probability."""
        start_x = 10

        y = 150

        for i in range(len(pi)):
            width = pi[i] * 100
            pygame.draw.rect(self.surface, WHITE, (start_x, y, width, 10))
            y += 20
        self.draw_center_text("PI", 50, y)

    def show_image(self, state):
        """Show input image."""
        state_ = state * 255.0
        data = state_.astype(np.uint8)
        image = pygame.image.frombuffer(data, (84, 84), 'RGB')
        self.surface.blit(image, (8, 8))
        self.draw_center_text("input", 50, 100)

    def show_value(self):
        if self.value_history.is_empty:
            return

        min_v = float("inf")
        max_v = float("-inf")

        values = self.value_history.values

        for v in values:
            min_v = min(min_v, v)
            max_v = max(max_v, v)

        top = 150
        left = 150
        width = 100
        height = 100
        bottom = top + width
        right = left + height

        d = max_v - min_v
        last_r = 0.0
        for i, v in enumerate(values):
            r = (v - min_v) / d
            if i > 0:
                x0 = i - 1 + left
                x1 = i + left
                y0 = bottom - last_r * height
                y1 = bottom - r * height
                pygame.draw.line(self.surface, BLUE, (x0, y0), (x1, y1), 1)
            last_r = r

        pygame.draw.line(self.surface, WHITE, (left, top), (left, bottom), 1)
        pygame.draw.line(self.surface, WHITE, (right, top), (right, bottom), 1)
        pygame.draw.line(self.surface, WHITE, (left, top), (right, top), 1)
        pygame.draw.line(self.surface, WHITE, (left, bottom), (right, bottom),
                         1)

        self.draw_center_text("V", left + width / 2, bottom + 10)

    def show_reward_prediction(self, rp_c, reward):
        start_x = 310
        reward_index = 0
        if reward == 0:
            reward_index = 0
        elif reward > 0:
            reward_index = 1
        elif reward < 0:
            reward_index = 2

        y = 150

        labels = ["0", "+", "-"]

        for i in range(len(rp_c)):
            width = rp_c[i] * 100

            if i == reward_index:
                color = RED
            else:
                color = WHITE
            pygame.draw.rect(self.surface, color, (start_x + 15, y, width, 10))
            self.draw_text(labels[i], start_x, y - 1, color)
            y += 20

        self.draw_center_text("RP", start_x + 100 / 2, y)

    def show_reward(self):
        self.draw_text("REWARD: {}".format(int(self.episode_reward)), 310, 10)

    def process(self, sess):
        last_action = self.environment.last_action
        last_reward = np.clip(self.environment.last_reward, -1, 1)
        last_action_reward = ExperienceFrame.concat_action_and_reward(
            last_action, self.action_size, last_reward,
            self.environment.last_state)

        if not flags.use_pixel_change:
            pi_values, v_value = self.global_network.run_base_policy_and_value(
                sess, self.environment.last_state, last_action_reward)
        else:
            pi_values, v_value, pc_q = self.global_network.run_base_policy_value_pc_q(
                sess, self.environment.last_state, last_action_reward)
        self.value_history.add_value(v_value)

        action = self.choose_action(pi_values)
        state, reward, terminal, pixel_change = self.environment.process(
            action)
        self.episode_reward += reward

        if terminal:
            self.environment.reset()
            self.episode_reward = 0

        self.show_image(state['image'])
        self.show_policy(pi_values)
        self.show_value()
        self.show_reward()

        if flags.use_pixel_change:
            self.show_pixel_change(pixel_change, 100, 0, 3.0, "PC")
            self.show_pixel_change(pc_q[:, :, action], 200, 0, 0.4, "PC Q")

        if flags.use_reward_prediction:
            if self.state_history.is_full:
                rp_c = self.global_network.run_rp_c(sess,
                                                    self.state_history.states)
                self.show_reward_prediction(rp_c, reward)

        self.state_history.add_state(state)

    def get_frame(self):
        data = self.surface.get_buffer().raw
        return data
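
A possible pygame main loop for driving the Display class above; the window size, frame rate, and session setup are illustrative assumptions.

import pygame
import tensorflow as tf

display = Display((640, 480))  # window size is an assumption

sess = tf.Session()
sess.run(tf.global_variables_initializer())
# (checkpoint restoring omitted; see the Evaluate driver sketch in Example #1)

clock = pygame.time.Clock()
running = True
while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
    display.update(sess)  # clears the surface, runs one environment step, redraws all panels
    clock.tick(15)        # cap at roughly 15 frames per second

sess.close()
pygame.quit()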
Example #3
class Trainer(object):
    def __init__(self, thread_index, global_network, initial_learning_rate,
                 learning_rate_input, grad_applier, env_type, env_name,
                 use_pixel_change, use_value_replay, use_reward_prediction,
                 pixel_change_lambda, entropy_beta, local_t_max, gamma,
                 gamma_pc, experience_history_size, max_global_time_step,
                 device):

        self.thread_index = thread_index
        self.learning_rate_input = learning_rate_input
        self.env_type = env_type
        self.env_name = env_name
        self.use_pixel_change = use_pixel_change
        self.use_value_replay = use_value_replay
        self.use_reward_prediction = use_reward_prediction
        self.local_t_max = local_t_max
        self.gamma = gamma
        self.gamma_pc = gamma_pc
        self.experience_history_size = experience_history_size
        self.max_global_time_step = max_global_time_step
        self.action_size = Environment.get_action_size(env_type, env_name)

        self.local_network = UnrealModel(self.action_size, thread_index,
                                         use_pixel_change, use_value_replay,
                                         use_reward_prediction,
                                         pixel_change_lambda, entropy_beta,
                                         device)
        self.local_network.prepare_loss()

        self.apply_gradients = grad_applier.minimize_local(
            self.local_network.total_loss, global_network.get_vars(),
            self.local_network.get_vars())

        self.sync = self.local_network.sync_from(global_network)
        self.experience = Experience(self.experience_history_size)
        self.local_t = 0
        self.initial_learning_rate = initial_learning_rate
        self.episode_reward = 0
        # For log output
        self.prev_local_t = 0

    def prepare(self):
        self.environment = Environment.create_environment(
            self.env_type, self.env_name)

    def stop(self):
        self.environment.stop()

    def _anneal_learning_rate(self, global_time_step):
        learning_rate = self.initial_learning_rate * (
            self.max_global_time_step -
            global_time_step) / self.max_global_time_step
        if learning_rate < 0.0:
            learning_rate = 0.0
        return learning_rate

    def choose_action(self, pi_values):
        return np.random.choice(range(len(pi_values)), p=pi_values)

    def _record_score(self, sess, summary_writer, summary_op, score_input,
                      score, global_t):
        summary_str = sess.run(summary_op, feed_dict={score_input: score})
        summary_writer.add_summary(summary_str, global_t)
        summary_writer.flush()

    def set_start_time(self, start_time):
        self.start_time = start_time

    def _fill_experience(self, sess):
        """Fill experience buffer until buffer is full."""
        prev_state = self.environment.last_state
        last_action = self.environment.last_action
        last_reward = self.environment.last_reward
        last_action_reward = ExperienceFrame.concat_action_and_reward(
            last_action, self.action_size, last_reward)

        pi_, _ = self.local_network.run_base_policy_and_value(
            sess, self.environment.last_state, last_action_reward)
        action = self.choose_action(pi_)

        new_state, reward, terminal, pixel_change = self.environment.process(
            action)

        frame = ExperienceFrame(prev_state, reward, action, terminal,
                                pixel_change, last_action, last_reward)
        self.experience.add_frame(frame)

        if terminal:
            self.environment.reset()
        if self.experience.is_full():
            self.environment.reset()
            print("Replay buffer filled")

    def _print_log(self, global_t):
        if (self.thread_index == 0) and (self.local_t - self.prev_local_t >=
                                         PERFORMANCE_LOG_INTERVAL):
            self.prev_local_t += PERFORMANCE_LOG_INTERVAL
            elapsed_time = time.time() - self.start_time
            steps_per_sec = global_t / elapsed_time
            print(
                "### Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour"
                .format(global_t, elapsed_time, steps_per_sec,
                        steps_per_sec * 3600 / 1000000.))

    def _process_base(self, sess, global_t, summary_writer, summary_op,
                      score_input):
        # [Base A3C]
        states = []
        last_action_rewards = []
        actions = []
        rewards = []
        values = []

        terminal_end = False

        start_lstm_state = self.local_network.base_lstm_state_out

        # t_max times loop
        for _ in range(self.local_t_max):
            # Prepare last action reward
            last_action = self.environment.last_action
            last_reward = self.environment.last_reward
            last_action_reward = ExperienceFrame.concat_action_and_reward(
                last_action, self.action_size, last_reward)
            #Modify Last State - with attention
            pi_, value_ = self.local_network.run_base_policy_and_value(
                sess, self.environment.last_state, last_action_reward)

            action = self.choose_action(pi_)

            states.append(self.environment.last_state)
            last_action_rewards.append(last_action_reward)
            actions.append(action)
            values.append(value_)

            if (self.thread_index == 0) and (self.local_t % LOG_INTERVAL == 0):
                print("pi={}".format(pi_))
                print(" V={}".format(value_))

            prev_state = self.environment.last_state

            # Process game
            new_state, reward, terminal, pixel_change = self.environment.process(
                action)  #Modify New State - with attention
            frame = ExperienceFrame(prev_state, reward, action, terminal,
                                    pixel_change, last_action, last_reward)

            # Store to experience
            self.experience.add_frame(frame)

            self.episode_reward += reward

            rewards.append(reward)

            self.local_t += 1

            if terminal:
                terminal_end = True
                print("score={}".format(self.episode_reward))

                self._record_score(sess, summary_writer, summary_op,
                                   score_input, self.episode_reward, global_t)

                self.episode_reward = 0
                self.environment.reset()
                self.local_network.reset_state()
                break

        R = 0.0
        if not terminal_end:
            R = self.local_network.run_base_value(
                sess, new_state, frame.get_action_reward(self.action_size))

        actions.reverse()
        states.reverse()
        rewards.reverse()
        values.reverse()

        batch_si = []
        batch_a = []
        batch_adv = []
        batch_R = []

        for (ai, ri, si, Vi) in zip(actions, rewards, states, values):
            R = ri + self.gamma * R
            adv = R - Vi
            a = np.zeros([self.action_size])
            a[ai] = 1.0

            batch_si.append(si)
            batch_a.append(a)
            batch_adv.append(adv)
            batch_R.append(R)

        batch_si.reverse()
        batch_a.reverse()
        batch_adv.reverse()
        batch_R.reverse()

        return batch_si, last_action_rewards, batch_a, batch_adv, batch_R, start_lstm_state

    def _process_pc(self, sess):
        # [pixel change]
        # Sample 20+1 frame (+1 for last next state)
        pc_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse sequence to calculate from the last
        pc_experience_frames.reverse()

        batch_pc_si = []
        batch_pc_a = []
        batch_pc_R = []
        batch_pc_last_action_reward = []

        pc_R = np.zeros([20, 20], dtype=np.float32)
        if not pc_experience_frames[1].terminal:
            pc_R = self.local_network.run_pc_q_max(
                sess, pc_experience_frames[0].state,
                pc_experience_frames[0].get_last_action_reward(
                    self.action_size))

        for frame in pc_experience_frames[1:]:
            pc_R = frame.pixel_change + self.gamma_pc * pc_R
            a = np.zeros([self.action_size])
            a[frame.action] = 1.0
            last_action_reward = frame.get_last_action_reward(self.action_size)

            batch_pc_si.append(frame.state)
            batch_pc_a.append(a)
            batch_pc_R.append(pc_R)
            batch_pc_last_action_reward.append(last_action_reward)

        batch_pc_si.reverse()
        batch_pc_a.reverse()
        batch_pc_R.reverse()
        batch_pc_last_action_reward.reverse()

        return batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R

    def _process_vr(self, sess):
        # [Value replay]
        # Sample 20+1 frame (+1 for last next state)
        vr_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse sequence to calculate from the last
        vr_experience_frames.reverse()

        batch_vr_si = []
        batch_vr_R = []
        batch_vr_last_action_reward = []

        vr_R = 0.0
        if not vr_experience_frames[1].terminal:
            vr_R = self.local_network.run_vr_value(
                sess, vr_experience_frames[0].state,
                vr_experience_frames[0].get_last_action_reward(
                    self.action_size))

        # t_max times loop
        for frame in vr_experience_frames[1:]:
            vr_R = frame.reward + self.gamma * vr_R
            batch_vr_si.append(frame.state)
            batch_vr_R.append(vr_R)
            last_action_reward = frame.get_last_action_reward(self.action_size)
            batch_vr_last_action_reward.append(last_action_reward)

        batch_vr_si.reverse()
        batch_vr_R.reverse()
        batch_vr_last_action_reward.reverse()

        return batch_vr_si, batch_vr_last_action_reward, batch_vr_R

    def _process_rp(self):
        # [Reward prediction]
        rp_experience_frames = self.experience.sample_rp_sequence()
        # 4 frames

        batch_rp_si = []
        batch_rp_c = []

        for i in range(3):
            batch_rp_si.append(rp_experience_frames[i].state)

        # one hot vector for target reward
        r = rp_experience_frames[3].reward
        rp_c = [0.0, 0.0, 0.0]
        if r == 0:
            rp_c[0] = 1.0  # zero
        elif r > 0:
            rp_c[1] = 1.0  # positive
        else:
            rp_c[2] = 1.0  # negative
        batch_rp_c.append(rp_c)
        return batch_rp_si, batch_rp_c

    def process(self, sess, global_t, summary_writer, summary_op, score_input):
        # Fill experience replay buffer
        if not self.experience.is_full():
            self._fill_experience(sess)
            return 0

        start_local_t = self.local_t

        cur_learning_rate = self._anneal_learning_rate(global_t)

        # Copy weights from shared to local
        sess.run(self.sync)

        # [Base]
        batch_si, batch_last_action_rewards, batch_a, batch_adv, batch_R, start_lstm_state = \
              self._process_base(sess,
                                 global_t,
                                 summary_writer,
                                 summary_op,
                                 score_input)
        feed_dict = {
            self.local_network.base_input: batch_si,
            self.local_network.base_last_action_reward_input:
            batch_last_action_rewards,
            self.local_network.base_a: batch_a,
            self.local_network.base_adv: batch_adv,
            self.local_network.base_r: batch_R,
            self.local_network.base_initial_lstm_state: start_lstm_state,
            # [common]
            self.learning_rate_input: cur_learning_rate
        }

        # [Pixel change]
        if self.use_pixel_change:
            batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R = self._process_pc(
                sess)

            pc_feed_dict = {
                self.local_network.pc_input: batch_pc_si,
                self.local_network.pc_last_action_reward_input:
                batch_pc_last_action_reward,
                self.local_network.pc_a: batch_pc_a,
                self.local_network.pc_r: batch_pc_R
            }
            feed_dict.update(pc_feed_dict)

        # [Value replay]
        if self.use_value_replay:
            batch_vr_si, batch_vr_last_action_reward, batch_vr_R = self._process_vr(
                sess)

            vr_feed_dict = {
                self.local_network.vr_input: batch_vr_si,
                self.local_network.vr_last_action_reward_input:
                batch_vr_last_action_reward,
                self.local_network.vr_r: batch_vr_R
            }
            feed_dict.update(vr_feed_dict)

        # [Reward prediction]
        if self.use_reward_prediction:
            batch_rp_si, batch_rp_c = self._process_rp()
            rp_feed_dict = {
                self.local_network.rp_input: batch_rp_si,
                self.local_network.rp_c_target: batch_rp_c
            }
            feed_dict.update(rp_feed_dict)

        # Calculate gradients and copy them to global network.
        sess.run(self.apply_gradients, feed_dict=feed_dict)

        self._print_log(global_t)

        # Return advanced local step size
        diff_local_t = self.local_t - start_local_t
        return diff_local_t
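
In _process_base the lists are reversed and the return is accumulated backwards as R_t = r_t + gamma * R_{t+1}, with advantage A_t = R_t - V_t. A standalone NumPy sketch of that arithmetic, using made-up numbers:

import numpy as np

def n_step_returns_and_advantages(rewards, values, bootstrap_value, gamma):
    # Accumulate from the last step backwards, mirroring the reversed loop above.
    R = bootstrap_value
    returns, advantages = [], []
    for r, v in zip(reversed(rewards), reversed(values)):
        R = r + gamma * R
        returns.append(R)
        advantages.append(R - v)
    returns.reverse()
    advantages.reverse()
    return np.array(returns), np.array(advantages)

# Illustrative values, not taken from the project.
returns, advs = n_step_returns_and_advantages(
    rewards=[0.0, 0.0, 1.0], values=[0.4, 0.5, 0.6],
    bootstrap_value=0.7, gamma=0.99)
# returns ~ [1.659, 1.676, 1.693]; advs == returns - values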
Example #4
File: display.py Project: kvas7andy/unreal
class Display(object):
    def __init__(self, display_size):
        pygame.init()

        self.surface = pygame.display.set_mode(display_size, 0, 24)
        name = 'UNREAL' if flags.segnet == 0 else "A3C ErfNet"
        pygame.display.set_caption(name)

        env_config = sim_config.get(flags.env_name)
        self.image_shape = [
            env_config.get('height', 88),
            env_config.get('width', 88)
        ]
        segnet_param_dict = {'segnet_mode': flags.segnet}
        is_training = tf.placeholder(tf.bool, name="training")
        map_file = env_config.get('objecttypes_file', '../../objectTypes.csv')
        self.label_mapping = pd.read_csv(map_file, sep=',', header=0)
        self.get_col_index()

        self.action_size = Environment.get_action_size(flags.env_type,
                                                       flags.env_name)
        self.objective_size = Environment.get_objective_size(
            flags.env_type, flags.env_name)
        self.global_network = UnrealModel(self.action_size,
                                          self.objective_size,
                                          -1,
                                          flags.use_lstm,
                                          flags.use_pixel_change,
                                          flags.use_value_replay,
                                          flags.use_reward_prediction,
                                          0.0,
                                          0.0,
                                          "/gpu:0",
                                          segnet_param_dict=segnet_param_dict,
                                          image_shape=self.image_shape,
                                          is_training=is_training,
                                          n_classes=flags.n_classes,
                                          segnet_lambda=flags.segnet_lambda,
                                          dropout=flags.dropout,
                                          for_display=True)
        self.environment = Environment.create_environment(
            flags.env_type,
            flags.env_name,
            flags.termination_time_sec,
            env_args={
                'episode_schedule': flags.split,
                'log_action_trace': flags.log_action_trace,
                'max_states_per_scene': flags.episodes_per_scene,
                'episodes_per_scene_test': flags.episodes_per_scene
            })
        self.font = pygame.font.SysFont(None, 20)
        self.value_history = ValueHistory()
        self.state_history = StateHistory()
        self.episode_reward = 0

    def update(self, sess):
        self.surface.fill(BLACK)
        self.process(sess)
        pygame.display.update()

    def choose_action(self, pi_values):
        return np.random.choice(range(len(pi_values)), p=pi_values)

    def scale_image(self, image, scale):
        return image.repeat(scale, axis=0).repeat(scale, axis=1)

    def draw_text(self, str, left, top, color=WHITE):
        text = self.font.render(str, True, color, BLACK)
        text_rect = text.get_rect()
        text_rect.left = left
        text_rect.top = top
        self.surface.blit(text, text_rect)

    def draw_center_text(self, str, center_x, top):
        text = self.font.render(str, True, WHITE, BLACK)
        text_rect = text.get_rect()
        text_rect.centerx = center_x
        text_rect.top = top
        self.surface.blit(text, text_rect)

    def show_pixel_change(self, pixel_change, left, top, rate, label):
        """Show pixel change."""
        if "PC" in label:
            pixel_change_ = np.clip(pixel_change * 255.0 * rate, 0.0, 255.0)
            data = pixel_change_.astype(np.uint8)
            data = np.stack([data for _ in range(3)], axis=2)
            data = self.scale_image(data, 4)
            #print("PC shape", data.shape)
            image = pygame.image.frombuffer(data, (20 * 4, 20 * 4), 'RGB')
        else:
            pixel_change = self.scale_image(pixel_change, 2)
            #print("Preds shape", pixel_change.shape)
            image = pygame.image.frombuffer(
                pixel_change.astype(np.uint8),
                (self.image_shape[0] * 2, self.image_shape[1] * 2), 'RGB')
        self.surface.blit(image, (2 * left + 16 + 8, 2 * top + 16 + 8))
        self.draw_center_text(label, 2 * left + 200 / 2, 2 * top + 200)

    def show_policy(self, pi):
        """Show action probability."""
        start_x = 10

        y = 150

        for i in range(len(pi)):
            width = pi[i] * 100
            pygame.draw.rect(self.surface, WHITE,
                             (2 * start_x, 2 * y, 2 * width, 2 * 10))
            y += 20
        self.draw_center_text("PI", 2 * 50, 2 * y)

    def show_image(self, state):
        """Show input image."""
        state_ = state * 255.0
        data = state_.astype(np.uint8)
        data = self.scale_image(data, 2)
        image = pygame.image.frombuffer(
            data, (self.image_shape[0] * 2, self.image_shape[1] * 2), 'RGB')
        self.surface.blit(image, (8 * 2, 8 * 2))
        self.draw_center_text("input", 2 * 50, 2 * 100)

    def show_value(self):
        if self.value_history.is_empty:
            return

        min_v = float("inf")
        max_v = float("-inf")

        values = self.value_history.values

        for v in values:
            min_v = min(min_v, v)
            max_v = max(max_v, v)

        top = 150 * 2
        left = 150 * 2
        width = 100 * 2
        height = 100 * 2
        bottom = top + width
        right = left + height

        d = max_v - min_v
        last_r = 0.0
        for i, v in enumerate(values):
            r = (v - min_v) / d
            if i > 0:
                x0 = i - 1 + left
                x1 = i + left
                y0 = bottom - last_r * height
                y1 = bottom - r * height
                pygame.draw.line(self.surface, BLUE, (x0, y0), (x1, y1), 1)
            last_r = r

        pygame.draw.line(self.surface, WHITE, (left, top), (left, bottom), 1)
        pygame.draw.line(self.surface, WHITE, (right, top), (right, bottom), 1)
        pygame.draw.line(self.surface, WHITE, (left, top), (right, top), 1)
        pygame.draw.line(self.surface, WHITE, (left, bottom), (right, bottom),
                         1)

        self.draw_center_text("V", left + width / 2, bottom + 10)

    def show_reward_prediction(self, rp_c, reward):
        start_x = 310
        reward_index = 0
        if reward == 0:
            reward_index = 0
        elif reward > 0:
            reward_index = 1
        elif reward < 0:
            reward_index = 2

        y = 150

        labels = ["0", "+", "-"]

        for i in range(len(rp_c)):
            width = rp_c[i] * 100

            if i == reward_index:
                color = RED
            else:
                color = WHITE
            pygame.draw.rect(self.surface, color,
                             (2 * start_x + 2 * 15, 2 * y, 2 * width, 2 * 10))
            self.draw_text(labels[i], 2 * start_x, 2 * y - 2 * 1, color)
            y += 20

        self.draw_center_text("RP", 2 * start_x + 2 * 100 / 2, y)

    def show_reward(self):
        self.draw_text("REWARD: {:.4}".format(float(self.episode_reward)), 300,
                       2 * 10)

    def process(self, sess):
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        #sess.run(tf.initialize_all_variables())

        last_action = self.environment.last_action
        last_reward = self.environment.last_reward
        last_action_reward = ExperienceFrame.concat_action_and_reward(
            last_action, self.action_size, last_reward,
            self.environment.last_state)
        preds = None
        mode = "segnet" if flags.segnet >= 2 else ""
        mode = ""  #don't want preds
        if not flags.use_pixel_change:
            pi_values, v_value, preds = self.global_network.run_base_policy_and_value(
                sess,
                self.environment.last_state,
                last_action_reward,
                mode=mode)
        else:
            pi_values, v_value, pc_q = self.global_network.run_base_policy_value_pc_q(
                sess, self.environment.last_state, last_action_reward)

        #print(preds)
        self.value_history.add_value(v_value)

        prev_state = self.environment.last_state

        action = self.choose_action(pi_values)
        state, reward, terminal, pixel_change = self.environment.process(
            action)
        self.episode_reward += reward

        if terminal:
            self.environment.reset()
            self.episode_reward = 0

        self.show_image(state['image'])
        self.show_policy(pi_values)
        self.show_value()
        self.show_reward()

        if not flags.use_pixel_change:
            if preds is not None:
                self.show_pixel_change(self.label_to_rgb(preds), 100, 0, 3.0,
                                       "Preds")
                self.show_pixel_change(self.label_to_rgb(state['objectType']),
                                       200, 0, 0.4, "Segm Mask")
        else:
            self.show_pixel_change(pixel_change, 100, 0, 3.0, "PC")
            self.show_pixel_change(pc_q[:, :, action], 200, 0, 0.4, "PC Q")

        if flags.use_reward_prediction:
            if self.state_history.is_full:
                rp_c = self.global_network.run_rp_c(sess,
                                                    self.state_history.states)
                self.show_reward_prediction(rp_c, reward)

        self.state_history.add_state(state)

    def get_frame(self):
        data = self.surface.get_buffer().raw
        return data

    def get_col_index(self):
        ind_col = self.label_mapping[["index", "color"]].values
        index = ind_col[:, 0].astype(np.int)
        self.index, ind = np.unique(index, return_index=True)
        self.col = np.array([[int(x) for x in col.split('_')]
                             for col in ind_col[ind, 1]])

    def label_to_rgb(self, labels):
        #print(self.col)
        rgb_img = self.col[np.where(
            self.index[np.newaxis, :] == labels.ravel()[:, np.newaxis]
        )[1]].reshape(labels.shape + (3,))
        return rgb_img
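
get_col_index reads an (index, "R_G_B") table from the objectTypes.csv mapping file, and label_to_rgb matches each label against that index to gather per-pixel colours. The same lookup with a small hand-made mapping (synthetic values, for illustration only):

import numpy as np

index = np.array([0, 1, 2])             # class indices
col = np.array([[0, 0, 0],              # class 0 -> black
                [255, 0, 0],            # class 1 -> red
                [0, 255, 0]])           # class 2 -> green

labels = np.array([[0, 1],
                   [2, 1]])             # a 2x2 segmentation mask

# Match every flattened label against the index table, take the matching column
# per row, gather the corresponding colour rows, and restore the spatial shape.
rgb = col[np.where(index[np.newaxis, :] == labels.ravel()[:, np.newaxis])[1]]
rgb = rgb.reshape(labels.shape + (3,))
assert (rgb[0, 1] == [255, 0, 0]).all()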
Example #5
class Trainer(object):
    def __init__(self, thread_index, global_network, initial_learning_rate,
                 learning_rate_input, grad_applier, env_type, env_name,
                 use_lstm, use_pixel_change, use_value_replay,
                 use_reward_prediction, pixel_change_lambda, entropy_beta,
                 local_t_max, n_step_TD, gamma, gamma_pc,
                 experience_history_size, max_global_time_step, device,
                 segnet_param_dict, image_shape, is_training, n_classes,
                 random_state, termination_time, segnet_lambda, dropout):

        self.thread_index = thread_index
        self.learning_rate_input = learning_rate_input
        self.env_type = env_type
        self.env_name = env_name
        self.use_lstm = use_lstm
        self.use_pixel_change = use_pixel_change
        self.use_value_replay = use_value_replay
        self.use_reward_prediction = use_reward_prediction
        self.local_t_max = local_t_max
        self.n_step_TD = n_step_TD
        self.gamma = gamma
        self.gamma_pc = gamma_pc
        self.experience_history_size = experience_history_size
        self.max_global_time_step = max_global_time_step
        self.action_size = Environment.get_action_size(env_type, env_name)
        self.objective_size = Environment.get_objective_size(
            env_type, env_name)

        self.segnet_param_dict = segnet_param_dict
        self.segnet_mode = self.segnet_param_dict.get("segnet_mode", None)

        self.is_training = is_training
        self.n_classes = n_classes
        self.segnet_lambda = segnet_lambda

        self.run_metadata = tf.RunMetadata()
        self.many_runs_timeline = TimeLiner()

        self.random_state = random_state
        self.termination_time = termination_time
        self.dropout = dropout

        try:
            self.local_network = UnrealModel(
                self.action_size,
                self.objective_size,
                thread_index,
                use_lstm,
                use_pixel_change,
                use_value_replay,
                use_reward_prediction,
                pixel_change_lambda,
                entropy_beta,
                device,
                segnet_param_dict=self.segnet_param_dict,
                image_shape=image_shape,
                is_training=is_training,
                n_classes=n_classes,
                segnet_lambda=self.segnet_lambda,
                dropout=dropout)

            self.local_network.prepare_loss()

            self.apply_gradients = grad_applier.minimize_local(
                self.local_network.total_loss, global_network.get_vars(),
                self.local_network.get_vars(), self.thread_index)

            self.sync = self.local_network.sync_from(global_network)
            self.experience = Experience(self.experience_history_size,
                                         random_state=self.random_state)
            self.local_t = 0
            self.initial_learning_rate = initial_learning_rate
            self.episode_reward = 0
            # For log output
            self.prev_local_t = -1
            self.prev_local_t_loss = 0
            self.sr_size = 50
            self.success_rates = deque(maxlen=self.sr_size)
        except Exception as e:
            print(str(e))  #, flush=True)
            raise Exception(
                "Problem in Trainer {} initialization".format(thread_index))

    def prepare(self, termination_time=50.0, termination_dist_value=-10.0):
        self.environment = Environment.create_environment(
            self.env_type,
            self.env_name,
            self.termination_time,
            thread_index=self.thread_index)

    def stop(self):
        self.environment.stop()

    def _anneal_learning_rate(self, global_time_step):
        learning_rate = self.initial_learning_rate * (
            self.max_global_time_step -
            global_time_step) / self.max_global_time_step
        if learning_rate < 0.0:
            learning_rate = 0.0
        return learning_rate

    def choose_action(self, pi_values):
        return self.random_state.choice(len(pi_values), p=pi_values)

    def _record_one(self, sess, summary_writer, summary_op, score_input, score,
                    global_t):
        if self.thread_index >= 0:
            summary_str = sess.run(summary_op, feed_dict={score_input: score})
            for sum_wr in summary_writer:
                sum_wr.add_summary(summary_str, global_t)

    def _record_all(self, sess, summary_writer, summary_op, dict_input,
                    dict_eval, global_t):
        if self.thread_index >= 0:
            assert set(dict_input.keys()) == set(dict_eval.keys()), print(
                dict_input.keys(), dict_eval.keys())

            feed_dict = {}
            for key in dict_input.keys():
                feed_dict.update({dict_input[key]: dict_eval[key]})
            summary_str = sess.run(summary_op, feed_dict=feed_dict)
            for sum_wr in summary_writer:
                sum_wr.add_summary(summary_str, global_t)

    def set_start_time(self, start_time):
        self.start_time = start_time

    def _fill_experience(self, sess):
        """Fill experience buffer until buffer is full."""
        #print("Start experience filling", flush=True)
        prev_state = self.environment.last_state
        last_action = self.environment.last_action
        last_reward = self.environment.last_reward
        last_action_reward = ExperienceFrame.concat_action_and_reward(
            last_action, self.action_size, last_reward, prev_state)

        #print("Local network run base policy, value!", flush=True)
        pi_, _, _ = self.local_network.run_base_policy_and_value(
            sess, self.environment.last_state, last_action_reward)
        action = self.choose_action(pi_)

        new_state, reward, terminal, pixel_change = self.environment.process(
            action, flag=0)

        frame = ExperienceFrame(
            {
                key: val
                for key, val in prev_state.items() if 'objectType' not in key
            }, reward, action, terminal, pixel_change, last_action,
            last_reward)
        self.experience.add_frame(frame)

        if terminal:
            self.environment.reset()
        if self.experience.is_full():
            self.environment.reset()
            print("Replay buffer filled")

    def _print_log(self, global_t):
        if (self.thread_index == 0) and (self.local_t - self.prev_local_t >=
                                         PERFORMANCE_LOG_INTERVAL):
            self.prev_local_t += PERFORMANCE_LOG_INTERVAL
            elapsed_time = time.time() - self.start_time
            steps_per_sec = global_t / elapsed_time
            print(
                "### Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour"
                .format(global_t, elapsed_time, steps_per_sec,
                        steps_per_sec * 3600 / 1000000.))  #, flush=True)
            # print("### Experience : {}".format(self.experience.get_debug_string()))

    def _process_base(self, sess, global_t, summary_writer, summary_op_dict,
                      summary_dict):  #, losses_input):
        # [Base A3C]
        states = []
        last_action_rewards = []
        actions = []
        rewards = []
        values = []

        terminal_end = False

        start_lstm_state = None
        if self.use_lstm:
            start_lstm_state = self.local_network.base_lstm_state_out

        mode = "segnet" if self.segnet_mode >= 2 else ""
        # Roll out up to n_step_TD environment steps
        flag = 0
        for _ in range(self.n_step_TD):
            # Prepare last action reward
            last_action = self.environment.last_action
            last_reward = self.environment.last_reward
            last_action_reward = ExperienceFrame.concat_action_and_reward(
                last_action, self.action_size, last_reward,
                self.environment.last_state)

            pi_, value_, losses = self.local_network.run_base_policy_and_value(
                sess, self.environment.last_state, last_action_reward, mode)

            action = self.choose_action(pi_)

            states.append(self.environment.last_state)
            last_action_rewards.append(last_action_reward)
            actions.append(action)
            values.append(value_)

            if (self.thread_index == 0) and (self.local_t % LOG_INTERVAL == 0):
                print("Trainer {}>>> Local step {}:".format(
                    self.thread_index, self.local_t))
                print("Trainer {}>>> pi={}".format(self.thread_index, pi_))
                print("Trainer {}>>> V={}".format(self.thread_index, value_))
                flag = 1

            prev_state = self.environment.last_state

            # Process game
            new_state, reward, terminal, pixel_change = self.environment.process(
                action, flag=flag)
            frame = ExperienceFrame(
                {
                    key: val
                    for key, val in prev_state.items()
                    if 'objectType' not in key
                }, reward, action, terminal, pixel_change, last_action,
                last_reward)

            # Store to experience
            self.experience.add_frame(frame)

            # Use to know about Experience collection
            #print(self.experience.get_debug_string())

            self.episode_reward += reward
            rewards.append(reward)
            self.local_t += 1

            if terminal:
                terminal_end = True
                print("Trainer {}>>> score={}".format(
                    self.thread_index, self.episode_reward))  #, flush=True)

                summary_dict['values'].update(
                    {'score_input': self.episode_reward})

                success = 1 if self.environment._last_full_state[
                    "success"] else 0
                #print("Type:", type(self.environment._last_full_state["success"]), len(self.success_rates), success)
                self.success_rates.append(success)
                summary_dict['values'].update({
                    'sr_input':
                    np.mean(self.success_rates)
                    if len(self.success_rates) == self.sr_size else 0
                })

                self.episode_reward = 0
                self.environment.reset()
                self.local_network.reset_state()
                if flag:
                    flag = 0
                break

        R = 0.0
        if not terminal_end:
            R = self.local_network.run_base_value(
                sess, new_state, frame.get_action_reward(self.action_size))

        actions.reverse()
        states.reverse()
        rewards.reverse()
        values.reverse()

        batch_si = []
        batch_a = []
        batch_adv = []
        batch_R = []
        batch_sobjT = []

        for (ai, ri, si, Vi) in zip(actions, rewards, states, values):
            R = ri + self.gamma * R
            adv = R - Vi
            a = np.zeros([self.action_size])
            a[ai] = 1.0

            batch_si.append(si['image'])
            batch_a.append(a)
            batch_adv.append(adv)
            batch_R.append(R)
            if self.segnet_param_dict["segnet_mode"] >= 2:
                batch_sobjT.append(si['objectType'])

        batch_si.reverse()
        batch_a.reverse()
        batch_adv.reverse()
        batch_R.reverse()
        batch_sobjT.reverse()

        #print(np.unique(batch_sobjT))

        ## HERE Mathematical Error A3C: only last values should be used for base/ or aggregate with last made

        return batch_si, batch_sobjT, last_action_rewards, batch_a, batch_adv, batch_R, start_lstm_state

    def _process_pc(self, sess):
        # [pixel change]
        # Sample 20+1 frame (+1 for last next state)
        #print(">>> Process run!", flush=True)
        pc_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse sequence to calculate from the last
        # pc_experience_frames.reverse()
        pc_experience_frames = pc_experience_frames[::-1]
        #print(">>> Process ran!", flush=True)

        batch_pc_si = []
        batch_pc_a = []
        batch_pc_R = []
        batch_pc_last_action_reward = []

        pc_R = np.zeros([20, 20], dtype=np.float32)
        if not pc_experience_frames[1].terminal:
            pc_R = self.local_network.run_pc_q_max(
                sess, pc_experience_frames[0].state,
                pc_experience_frames[0].get_last_action_reward(
                    self.action_size))

        #print(">>> Process run!", flush=True)

        for frame in pc_experience_frames[1:]:

            pc_R = frame.pixel_change + self.gamma_pc * pc_R
            a = np.zeros([self.action_size])
            a[frame.action] = 1.0
            last_action_reward = frame.get_last_action_reward(self.action_size)

            batch_pc_si.append(frame.state['image'])
            batch_pc_a.append(a)
            batch_pc_R.append(pc_R)
            batch_pc_last_action_reward.append(last_action_reward)

        batch_pc_si.reverse()
        batch_pc_a.reverse()
        batch_pc_R.reverse()
        batch_pc_last_action_reward.reverse()

        #print(">>> Process ended!", flush=True)
        return batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R

    def _process_vr(self, sess):
        # [Value replay]
        # Sample 20+1 frame (+1 for last next state)
        vr_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse sequence to calculate from the last
        vr_experience_frames.reverse()

        batch_vr_si = []
        batch_vr_R = []
        batch_vr_last_action_reward = []

        vr_R = 0.0
        if not vr_experience_frames[1].terminal:
            vr_R = self.local_network.run_vr_value(
                sess, vr_experience_frames[0].state,
                vr_experience_frames[0].get_last_action_reward(
                    self.action_size))

        # t_max times loop
        for frame in vr_experience_frames[1:]:
            vr_R = frame.reward + self.gamma * vr_R
            batch_vr_si.append(frame.state['image'])
            batch_vr_R.append(vr_R)
            last_action_reward = frame.get_last_action_reward(self.action_size)
            batch_vr_last_action_reward.append(last_action_reward)

        batch_vr_si.reverse()
        batch_vr_R.reverse()
        batch_vr_last_action_reward.reverse()

        return batch_vr_si, batch_vr_last_action_reward, batch_vr_R

    def _process_rp(self):
        # [Reward prediction]
        rp_experience_frames = self.experience.sample_rp_sequence()
        # 4 frames

        batch_rp_si = []
        batch_rp_c = []

        for i in range(3):
            batch_rp_si.append(rp_experience_frames[i].state['image'])

        # one hot vector for target reward
        r = rp_experience_frames[3].reward
        rp_c = [0.0, 0.0, 0.0]
        if -1e-10 < r < 1e-10:
            rp_c[0] = 1.0  # zero
        elif r > 0:
            rp_c[1] = 1.0  # positive
        else:
            rp_c[2] = 1.0  # negative
        batch_rp_c.append(rp_c)
        return batch_rp_si, batch_rp_c

    def process(self, sess, global_t, summary_writer, summary_op_dict,
                score_input, sr_input, eval_input, entropy_input,
                term_global_t, losses_input):

        if self.prev_local_t == -1 and self.segnet_mode >= 2:
            self.prev_local_t = 0
            sess.run(self.local_network.reset_evaluation_vars)
        # Fill experience replay buffer
        #print("Inside train process of thread!", flush=True)
        if not self.experience.is_full():
            self._fill_experience(sess)
            return 0, None

        start_local_t = self.local_t
        episode_score = None

        cur_learning_rate = self._anneal_learning_rate(global_t)

        #print("Weights copying!", flush=True)
        # Copy weights from shared to local
        sess.run(self.sync)
        #print("Weights copied successfully!", flush=True)

        summary_dict = {'placeholders': {}, 'values': {}}
        summary_dict['placeholders'].update(losses_input)

        # [Base]
        #print("[Base]", flush=True)
        batch_si, batch_sobjT, batch_last_action_rewards, batch_a, batch_adv, batch_R, start_lstm_state = \
              self._process_base(sess,
                                 global_t,
                                 summary_writer,
                                 summary_op_dict,
                                 summary_dict)
        if summary_dict['values'].get('score_input', None) is not None:
            self._record_one(sess, summary_writer,
                             summary_op_dict['score_input'], score_input,
                             summary_dict['values']['score_input'], global_t)
            self._record_one(sess, summary_writer, summary_op_dict['sr_input'],
                             sr_input, summary_dict['values']['sr_input'],
                             global_t)
            #self._record_one(sess, summary_writer, summary_op_dict['term_global_t'], term_global_t,
            #                 global_t, global_t)
            #summary_writer[0].flush()
            # summary_writer[1].flush()
            # Return advanced local step size
            episode_score = summary_dict['values'].get('score_input', None)
            summary_dict['values'] = {}

        feed_dict = {
            self.local_network.base_input: batch_si,
            self.local_network.base_last_action_reward_input:
            batch_last_action_rewards,
            self.local_network.base_a: batch_a,
            self.local_network.base_adv: batch_adv,
            self.local_network.base_r: batch_R,
            # [common]
            self.learning_rate_input: cur_learning_rate,
            self.is_training: True
        }

        if self.use_lstm:
            feed_dict[
                self.local_network.base_initial_lstm_state] = start_lstm_state

        if self.segnet_param_dict["segnet_mode"] >= 2:
            feed_dict[self.local_network.base_segm_mask] = batch_sobjT

        #print("[Pixel change]", flush=True)
        # [Pixel change]
        if self.use_pixel_change:
            batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R = self._process_pc(
                sess)

            pc_feed_dict = {
                self.local_network.pc_input: batch_pc_si,
                self.local_network.pc_last_action_reward_input:
                batch_pc_last_action_reward,
                self.local_network.pc_a: batch_pc_a,
                self.local_network.pc_r: batch_pc_R
            }
            feed_dict.update(pc_feed_dict)

        #print("[Value replay]", flush=True)
        # [Value replay]
        if self.use_value_replay:
            batch_vr_si, batch_vr_last_action_reward, batch_vr_R = self._process_vr(
                sess)

            vr_feed_dict = {
                self.local_network.vr_input: batch_vr_si,
                self.local_network.vr_last_action_reward_input:
                batch_vr_last_action_reward,
                self.local_network.vr_r: batch_vr_R
            }
            feed_dict.update(vr_feed_dict)

        # [Reward prediction]
        #print("[Reward prediction]", flush=True)
        if self.use_reward_prediction:
            batch_rp_si, batch_rp_c = self._process_rp()
            rp_feed_dict = {
                self.local_network.rp_input: batch_rp_si,
                self.local_network.rp_c_target: batch_rp_c
            }
            feed_dict.update(rp_feed_dict)
            #print(len(batch_rp_c), batch_rp_c)

        grad_check = None
        #if self.local_t - self.prev_local_t_loss >= LOSS_AND_EVAL_LOG_INTERVAL:
        #  grad_check = [tf.add_check_numerics_ops()]
        #print("Applying gradients in train!", flush=True)
        # Calculate gradients and copy them to global network.
        out_list = [self.apply_gradients]
        out_list += [
            self.local_network.total_loss, self.local_network.base_loss,
            self.local_network.policy_loss, self.local_network.value_loss,
            self.local_network.entropy
        ]
        if self.segnet_mode >= 2:
            out_list += [self.local_network.decoder_loss]
            out_list += [self.local_network.regul_loss]
        if self.use_pixel_change:
            out_list += [self.local_network.pc_loss]
        if self.use_value_replay:
            out_list += [self.local_network.vr_loss]
        if self.use_reward_prediction:
            out_list += [self.local_network.rp_loss]
        if self.segnet_mode >= 2:
            out_list += [self.local_network.update_evaluation_vars]
            if self.local_t - self.prev_local_t_loss >= LOSS_AND_EVAL_LOG_INTERVAL:
                out_list += [self.local_network.evaluation]

        import time

        now = time.time()
        with tf.control_dependencies(grad_check):
            if GPU_LOG:
                return_list = sess.run(out_list,
                                       feed_dict=feed_dict,
                                       options=run_options,
                                       run_metadata=self.run_metadata)
            else:
                return_list = sess.run(out_list,
                                       feed_dict=feed_dict,
                                       options=run_options)

        if time.time() - now > 30.0:
            print(
                "Too much time on sess.run: check tensorflow")  #, flush=True)
            sys.exit(0)
            raise ValueError("More than 100 seconds update in tensorflow!")  #

        gradients_tuple, total_loss, base_loss, policy_loss, value_loss, entropy = return_list[:6]
        grad_norm = gradients_tuple[1]
        return_list = return_list[6:]
        return_string = "Trainer {}>>> Total loss: {}, Base loss: {}\n".format(
            self.thread_index, total_loss, base_loss)
        return_string += "\t\tPolicy loss: {}, Value loss: {}, Grad norm: {}\nEntropy: {}\n".format(
            policy_loss, value_loss, grad_norm, entropy)
        losses_eval = {
            'all/total_loss': total_loss,
            'all/base_loss': base_loss,
            'all/policy_loss': policy_loss,
            'all/value_loss': value_loss,
            'all/loss/grad_norm': grad_norm
        }
        if self.segnet_mode >= 2:
            decoder_loss, l2_loss = return_list[:2]
            return_list = return_list[2:]
            return_string += "\t\tDecoder loss: {}, L2 weights loss: {}\n".format(
                decoder_loss, l2_loss)
            losses_eval.update({
                'all/decoder_loss': decoder_loss,
                'all/l2_weights_loss': l2_loss
            })
        if self.use_pixel_change:
            pc_loss = return_list[0]
            return_list = return_list[1:]
            return_string += "\t\tPC loss: {}\n".format(pc_loss)
            losses_eval.update({'all/pc_loss': pc_loss})
        if self.use_value_replay:
            vr_loss = return_list[0]
            return_list = return_list[1:]
            return_string += "\t\tVR loss: {}\n".format(vr_loss)
            losses_eval.update({'all/vr_loss': vr_loss})
        if self.use_reward_prediction:
            rp_loss = return_list[0]
            return_list = return_list[1:]
            return_string += "\t\tRP loss: {}\n".format(rp_loss)
            losses_eval.update({'all/rp_loss': rp_loss})
        if self.local_t - self.prev_local_t_loss >= LOSS_AND_EVAL_LOG_INTERVAL:
            if self.segnet_mode >= 2:
                return_string += "\t\tmIoU: {}\n".format(return_list[-1])

        summary_dict['values'].update(losses_eval)

        # Printing losses
        if self.local_t - self.prev_local_t_loss >= LOSS_AND_EVAL_LOG_INTERVAL:
            if self.segnet_mode >= 2:
                self._record_one(sess, summary_writer,
                                 summary_op_dict['eval_input'], eval_input,
                                 return_list[-1], global_t)
            self._record_one(sess, summary_writer, summary_op_dict['entropy'],
                             entropy_input, entropy, global_t)
            # summary_writer[0].flush()
            # summary_writer[1].flush()
            print(return_string)
            self.prev_local_t_loss += LOSS_AND_EVAL_LOG_INTERVAL

        if GPU_LOG:
            fetched_timeline = timeline.Timeline(self.run_metadata.step_stats)
            chrome_trace = fetched_timeline.generate_chrome_trace_format()
            self.many_runs_timeline.update_timeline(chrome_trace)

        self._print_log(global_t)

        #Recording score and losses
        self._record_all(sess, summary_writer, summary_op_dict['losses_input'],
                         summary_dict['placeholders'], summary_dict['values'],
                         global_t)

        # Return advanced local step size
        diff_local_t = self.local_t - start_local_t
        return diff_local_t, episode_score
예제 #6
0
class Tester(object):
    def __init__(self):
        self.img = np.zeros(shape=(HEIGHT, WIDTH, 3), dtype=np.uint8)
        self.action_size = Environment.get_action_size()
        self.global_network = UnrealModel(self.action_size,
                                          -1,
                                          "/cpu:0",
                                          for_display=True)
        self.env = Environment.create_environment()
        self.value_history = ValueHistory()
        self.state_history = StateHistory()
        self.ep_reward = 0
        self.mazemap = MazeMap()

    def process(self, sess):
        self.img = np.zeros(shape=(HEIGHT, WIDTH, 3), dtype=np.uint8)
        last_action = self.env.last_action
        last_reward = np.clip(self.env.last_reward, -1, 1)
        last_action_reward = ExperienceFrame.concat_action_and_reward(
            last_action, self.action_size, last_reward)
        if not USE_PIXEL_CHANGE:
            pi_values, v_value = self.global_network.run_base_policy_and_value(
                sess, self.env.last_state, last_action_reward)
        else:
            pi_values, v_value, pc_q = self.global_network.run_base_policy_value_pc_q(
                sess, self.env.last_state, last_action_reward)
        self.value_history.add_value(v_value)
        action = self.choose_action(pi_values)
        state, reward, terminal, pc, vtrans, vrot = self.env.process(action)
        self.state_history.add_state(state)
        self.ep_reward += reward
        self.mazemap.update(vtrans, vrot)
        if reward > 9:  # reaching the maze goal gives a reward of 10, so the map must be reset
            self.mazemap.reset()
        if terminal:  # the lab environment ends an episode after 3600 frames by default, not when the goal is reached
            self.env.reset()
            self.ep_reward = 0
            self.mazemap.reset()

        self.show_ob(state, 3, 3, "Observation")
        self.show_pc(pc, 100, 3, 3.0, "Pixel Change")
        if USE_PIXEL_CHANGE:
            # pc_q is only computed when the pixel-change head is enabled above.
            self.show_pc(pc_q[:, :, action], 200, 3, 0.4, "PC Q")
        self.show_map(300, 3, "Maze Map")
        self.show_pi(pi_values)
        self.show_reward()
        self.show_rp()
        self.show_value()

    def choose_action(self, pi_values):
        return np.random.choice(range(len(pi_values)), p=pi_values)

    def scale_image(self, image, scale):
        return image.repeat(scale, axis=0).repeat(scale, axis=1)

    def draw_text(self, text, left, bottom, color=WHITE):
        font = cv2.FONT_HERSHEY_COMPLEX
        cv2.putText(self.img, text, (left, bottom), font, 0.35, color)

    def show_pc(self, pc, left, top, rate, label):
        pc = np.clip(pc * 255.0 * rate, 0.0, 255.0)
        data = pc.astype(np.uint8)
        data = np.stack([data for _ in range(3)], axis=2)
        data = self.scale_image(data, 4)
        h = data.shape[0]
        w = data.shape[1]
        self.img[top:top + h, left:left + w, :] = data
        self.draw_text(label, (left + 2), (top + h + 15))

    def show_map(self, left, top, label):
        maze = self.mazemap.get_map(84, 84)
        maze = (maze * 255).astype(np.uint8)
        h = maze.shape[0]
        w = maze.shape[1]
        self.img[top:top + h, left:left + w, :] = maze
        self.draw_text(label, (left + 2), (top + h + 5))

    def show_pi(self, pi):
        for i in range(len(pi)):
            width = int(pi[i] * 100)
            cv2.rectangle(self.img, (3, 113 + 15 * i), (width, 120 + 15 * i),
                          WHITE)
        self.draw_text("Policy", 20, 120 + 15 * len(pi))

    def show_ob(self, state, left, top, label):
        state = (state * 255.0).astype(np.uint8)
        h = state.shape[0]
        w = state.shape[1]
        self.img[top:top + h, left:left + w, :] = state
        self.draw_text(label, (left + 2), (top + h + 15))

    def show_value(self, left=120, top=150, height=50, width=100):
        # Default layout values are assumed here; the original call site passes no
        # arguments, so the plot is placed to line up with the "Q Value" label below.
        if self.value_history.is_empty:
            return

        min_v = float("inf")
        max_v = float("-inf")
        values = self.value_history.values

        for v in values:
            min_v = min(min_v, v)
            max_v = max(max_v, v)

        bottom = top + height
        right = left + width
        d = max_v - min_v
        if d == 0:
            d = 1.0  # avoid division by zero when all values are equal
        last_r = 0.0
        for i, v in enumerate(values):
            r = (v - min_v) / d
            if i > 0:
                # cv2 points are (x, y); scale the normalized value into the plot box.
                x0 = i - 1 + left
                x1 = i + left
                y0 = bottom - last_r * height
                y1 = bottom - r * height
                cv2.line(self.img, (x0, int(y0)), (x1, int(y1)), BLUE, 2)
            last_r = r

        cv2.line(self.img, (left, top), (left, bottom), WHITE, 1)
        cv2.line(self.img, (right, top), (right, bottom), WHITE, 1)
        cv2.line(self.img, (left, top), (right, top), WHITE, 1)
        cv2.line(self.img, (left, bottom), (right, bottom), WHITE, 1)
        self.draw_text("Q Value", 120, 215)

    def show_rp(self):
        pass

    def show_reward(self):
        self.draw_text("Reward: {}".format(int(self.ep_reward)), 10, 230)

    def get_frame(self):
        return self.img
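
A minimal driver sketch (not part of the original example; the checkpoint path and session setup are assumptions) that restores weights, steps the Tester once per frame, and shows the composed image with OpenCV:

import cv2
import tensorflow as tf

checkpoint_dir = "./checkpoints"  # hypothetical checkpoint directory

tester = Tester()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
if checkpoint and checkpoint.model_checkpoint_path:
    saver.restore(sess, checkpoint.model_checkpoint_path)

while True:
    tester.process(sess)                    # advance one environment step and redraw all panels
    cv2.imshow("UNREAL tester", tester.get_frame())
    if cv2.waitKey(16) & 0xFF == ord('q'):  # ~60 FPS; press 'q' to quit
        break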
예제 #7
0
class Trainer(object):
    def __init__(self, thread_index, global_network, initial_learning_rate,
                 learning_rate_input, grad_applier, env_type, env_name,
                 use_pixel_change, use_value_replay, use_reward_prediction,
                 use_future_reward_prediction, use_autoencoder, reward_length,
                 pixel_change_lambda, entropy_beta, local_t_max, gamma,
                 gamma_pc, experience_history_size, max_global_time_step,
                 device, log_file, skip_step):

        self.thread_index = thread_index
        self.learning_rate_input = learning_rate_input
        self.env_type = env_type
        self.env_name = env_name
        self.use_pixel_change = use_pixel_change
        self.use_value_replay = use_value_replay
        self.use_reward_prediction = use_reward_prediction
        self.use_future_reward_prediction = use_future_reward_prediction
        self.use_autoencoder = use_autoencoder
        self.local_t_max = local_t_max
        self.gamma = gamma
        self.gamma_pc = gamma_pc
        self.experience_history_size = experience_history_size
        self.max_global_time_step = max_global_time_step
        self.skip_step = skip_step
        self.action_size = Environment.get_action_size(env_type, env_name)

        self.local_network = UnrealModel(self.action_size, thread_index,
                                         use_pixel_change, use_value_replay,
                                         use_reward_prediction,
                                         use_future_reward_prediction,
                                         use_autoencoder, pixel_change_lambda,
                                         entropy_beta, device)
        self.local_network.prepare_loss()

        self.apply_gradients = grad_applier.minimize_local(
            self.local_network.total_loss, global_network.get_vars(),
            self.local_network.get_vars())

        self.sync = self.local_network.sync_from(global_network)
        self.experience = Experience(self.experience_history_size,
                                     reward_length)
        self.local_t = 0
        self.initial_learning_rate = initial_learning_rate
        self.episode_reward = 0
        # For log output
        self.prev_local_t = 0
        self.log_file = log_file
        self.prediction_res_file = os.path.join(log_file, 'res.pkl')

    def prepare(self):
        self.environment = Environment.create_environment(
            self.env_type, self.env_name, self.skip_step)

    def stop(self):
        self.environment.stop()

    def add_summary(self, step, name, value, writer):
        summary = tf.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = float(value)
        summary_value.tag = name
        writer.add_summary(summary, step)

    def _anneal_learning_rate(self, global_time_step):
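        # Linearly decay the learning rate from its initial value down to zero
        # at max_global_time_step.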
        learning_rate = self.initial_learning_rate * (
            self.max_global_time_step -
            global_time_step) / self.max_global_time_step
        if learning_rate < 0.0:
            learning_rate = 0.0
        return learning_rate

    def choose_action(self, pi_values):
        return np.random.choice(range(len(pi_values)), p=pi_values)

    def _record_score(self, sess, summary_writer, summary_op, score_input,
                      score, global_t):
        summary_str = sess.run(summary_op, feed_dict={score_input: score})
        summary_writer.add_summary(summary_str, global_t)
        summary_writer.flush()

    def set_start_time(self, start_time):
        self.start_time = start_time

    def _fill_experience(self, sess):
        """
    Fill experience buffer until buffer is full.
    """
        prev_state = self.environment.last_state
        last_action = self.environment.last_action
        last_reward = self.environment.last_reward
        last_action_reward = ExperienceFrame.concat_action_and_reward(
            last_action, self.action_size, last_reward)

        pi_, _ = self.local_network.run_base_policy_and_value(
            sess, self.environment.last_state, last_action_reward)
        action = self.choose_action(pi_)

        new_state, reward, terminal, pixel_change = self.environment.process(
            action)

        frame = ExperienceFrame(prev_state, reward, action, terminal,
                                pixel_change, last_action, last_reward)
        self.experience.add_frame(frame)

        if terminal:
            self.environment.reset()
        if self.experience.is_full():
            self.environment.reset()
            print("Replay buffer filled")

    def _print_log(self, global_t):
        if (self.thread_index == 0) and (self.local_t - self.prev_local_t >=
                                         PERFORMANCE_LOG_INTERVAL):
            self.prev_local_t += PERFORMANCE_LOG_INTERVAL
            elapsed_time = time.time() - self.start_time
            steps_per_sec = global_t / elapsed_time
            print(
                "### Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour"
                .format(global_t, elapsed_time, steps_per_sec,
                        steps_per_sec * 3600 / 1000000.))

    def _process_base(self, sess, global_t, summary_writer, summary_op,
                      score_input):
        # [Base A3C]
        states = []
        last_action_rewards = []
        actions = []
        rewards = []
        values = []

        terminal_end = False

        start_lstm_state = self.local_network.base_lstm_state_out
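        # Snapshot the LSTM state at the start of the rollout so the unrolled
        # batch can later be replayed through the network from this state.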

        # t_max times loop
        for _ in range(self.local_t_max):
            # Prepare last action reward
            last_action = self.environment.last_action
            last_reward = self.environment.last_reward
            last_action_reward = ExperienceFrame.concat_action_and_reward(
                last_action, self.action_size, last_reward)

            pi_, value_ = self.local_network.run_base_policy_and_value(
                sess, self.environment.last_state, last_action_reward)

            action = self.choose_action(pi_)

            states.append(self.environment.last_state)
            last_action_rewards.append(last_action_reward)
            actions.append(action)
            values.append(value_)

            if (self.thread_index == 0) and (self.local_t % LOG_INTERVAL == 0):
                print("pi={}".format(pi_))
                print(" V={}".format(value_))

            prev_state = self.environment.last_state

            # Process game
            new_state, reward, terminal, pixel_change = self.environment.process(
                action)
            frame = ExperienceFrame(prev_state, reward, action, terminal,
                                    pixel_change, last_action, last_reward)

            # Store to experience
            self.experience.add_frame(frame)

            self.episode_reward += reward

            rewards.append(reward)

            self.local_t += 1

            if terminal:
                terminal_end = True
                print("score={}".format(self.episode_reward))

                self._record_score(sess, summary_writer, summary_op,
                                   score_input, self.episode_reward, global_t)

                self.episode_reward = 0
                self.environment.reset()
                self.local_network.reset_state()
                break

        R = 0.0
        if not terminal_end:
            R = self.local_network.run_base_value(
                sess, new_state, frame.get_action_reward(self.action_size))

        actions.reverse()
        states.reverse()
        rewards.reverse()
        values.reverse()

        batch_si = []
        batch_a = []
        batch_adv = []
        batch_R = []

        for (ai, ri, si, Vi) in zip(actions, rewards, states, values):
            R = ri + self.gamma * R
            adv = R - Vi
            a = np.zeros([self.action_size])
            a[ai] = 1.0

            batch_si.append(si)
            batch_a.append(a)
            batch_adv.append(adv)
            batch_R.append(R)

        batch_si.reverse()
        batch_a.reverse()
        batch_adv.reverse()
        batch_R.reverse()

        return batch_si, last_action_rewards, batch_a, batch_adv, batch_R, start_lstm_state

    def _process_pc(self, sess):
        # [pixel change]
        # Sample local_t_max+1 frames (+1 for the last next state used for bootstrapping)
        pc_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse sequence to calculate from the last
        pc_experience_frames.reverse()

        batch_pc_si = []
        batch_pc_a = []
        batch_pc_R = []
        batch_pc_last_action_reward = []

        pc_R = np.zeros([20, 20], dtype=np.float32)
        if not pc_experience_frames[1].terminal:
            pc_R = self.local_network.run_pc_q_max(
                sess, pc_experience_frames[0].state,
                pc_experience_frames[0].get_last_action_reward(
                    self.action_size))

        for frame in pc_experience_frames[1:]:
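            # Accumulate the discounted pixel-change pseudo-reward backwards
            # through the sequence (gamma_pc is the auxiliary-task discount).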
            pc_R = frame.pixel_change + self.gamma_pc * pc_R
            a = np.zeros([self.action_size])
            a[frame.action] = 1.0
            last_action_reward = frame.get_last_action_reward(self.action_size)

            batch_pc_si.append(frame.state)
            batch_pc_a.append(a)
            batch_pc_R.append(pc_R)
            batch_pc_last_action_reward.append(last_action_reward)

        batch_pc_si.reverse()
        batch_pc_a.reverse()
        batch_pc_R.reverse()
        batch_pc_last_action_reward.reverse()

        return batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R

    def _process_vr(self, sess):
        # [Value replay]
        # Sample local_t_max+1 frames (+1 for the last next state used for bootstrapping)
        vr_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse sequence to calculate from the last
        vr_experience_frames.reverse()

        batch_vr_si = []
        batch_vr_R = []
        batch_vr_last_action_reward = []

        vr_R = 0.0
        if not vr_experience_frames[1].terminal:
            vr_R = self.local_network.run_vr_value(
                sess, vr_experience_frames[0].state,
                vr_experience_frames[0].get_last_action_reward(
                    self.action_size))

        # t_max times loop
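        # Accumulate discounted value-replay returns backwards, bootstrapping
        # from vr_R computed above for the newest non-terminal frame.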
        for frame in vr_experience_frames[1:]:
            vr_R = frame.reward + self.gamma * vr_R
            batch_vr_si.append(frame.state)
            batch_vr_R.append(vr_R)
            last_action_reward = frame.get_last_action_reward(self.action_size)
            batch_vr_last_action_reward.append(last_action_reward)

        batch_vr_si.reverse()
        batch_vr_R.reverse()
        batch_vr_last_action_reward.reverse()

        return batch_vr_si, batch_vr_last_action_reward, batch_vr_R

    '''
  def _process_rp(self):
    # [Reward prediction]
    rp_experience_frames, total_raw_reward, _ = self.experience.sample_rp_sequence()
    # 4 frames

    batch_rp_si = []
    batch_rp_c = []
    
    for i in range(4):
      batch_rp_si.append(rp_experience_frames[i].state)

    # one hot vector for target reward
    r = total_raw_reward
    rp_c = [0.0, 0.0, 0.0]
    if r == 0:
      rp_c[0] = 1.0 # zero
    elif r > 0:
      rp_c[1] = 1.0 # positive
    else:
      rp_c[2] = 1.0 # negative
    batch_rp_c.append(rp_c)
    return batch_rp_si, batch_rp_c
  '''

    def _process_replay(self, action=False):
        # [Reward prediction]
        rp_experience_frames, total_raw_reward, next_frame = self.experience.sample_rp_sequence(
            flag=True)
        # 4 frames

        batch_rp_si = []
        batch_rp_c = []

        for i in range(4):
            batch_rp_si.append(rp_experience_frames[i].state)

        # one hot vector for target reward
        r = total_raw_reward
        rp_c = [0.0, 0.0, 0.0]
        if r == 0:
            rp_c[0] = 1.0  # zero
        elif r > 0:
            rp_c[1] = 1.0  # positive
        else:
            rp_c[2] = 1.0  # negative
        batch_rp_c.append(rp_c)

        result = [batch_rp_si, batch_rp_c, next_frame]

        if action:
            batch_rp_action = []
            action_index = rp_experience_frames[3].action
            action_one_hot = np.zeros([self.action_size])
            action_one_hot[action_index] = 1.0
            batch_rp_action.append(action_one_hot)
            result.append(batch_rp_action)
        return result

    def process(self, sess, global_t, summary_writer, summary_op, score_input):
        # Fill experience replay buffer
        if not self.experience.is_full():
            self._fill_experience(sess)
            return 0

        start_local_t = self.local_t

        cur_learning_rate = self._anneal_learning_rate(global_t)

        # Copy weights from shared to local
        sess.run(self.sync)

        # [Base]
        batch_si, batch_last_action_rewards, batch_a, batch_adv, batch_R, start_lstm_state = \
              self._process_base(sess,
                                 global_t,
                                 summary_writer,
                                 summary_op,
                                 score_input)
        feed_dict = {
            self.local_network.base_input: batch_si,
            self.local_network.base_last_action_reward_input:
            batch_last_action_rewards,
            self.local_network.base_a: batch_a,
            self.local_network.base_adv: batch_adv,
            self.local_network.base_r: batch_R,
            self.local_network.base_initial_lstm_state: start_lstm_state,
            # [common]
            self.learning_rate_input: cur_learning_rate
        }

        # [Pixel change]
        if self.use_pixel_change:
            batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R = self._process_pc(
                sess)

            pc_feed_dict = {
                self.local_network.pc_input: batch_pc_si,
                self.local_network.pc_last_action_reward_input:
                batch_pc_last_action_reward,
                self.local_network.pc_a: batch_pc_a,
                self.local_network.pc_r: batch_pc_R
            }
            feed_dict.update(pc_feed_dict)

        # [Value replay]
        if self.use_value_replay:
            batch_vr_si, batch_vr_last_action_reward, batch_vr_R = self._process_vr(
                sess)

            vr_feed_dict = {
                self.local_network.vr_input: batch_vr_si,
                self.local_network.vr_last_action_reward_input:
                batch_vr_last_action_reward,
                self.local_network.vr_r: batch_vr_R
            }
            feed_dict.update(vr_feed_dict)

        # [Reward prediction]
        next_frame = None
        if self.use_reward_prediction:
            batch_rp_si, batch_rp_c, next_frame = self._process_replay()
            rp_feed_dict = {
                self.local_network.rp_input: batch_rp_si,
                self.local_network.rp_c_target: batch_rp_c
            }
            feed_dict.update(rp_feed_dict)

        # [Future reward prediction]
        if self.use_future_reward_prediction:
            batch_frp_si, batch_frp_c, next_frame, batch_frp_action = self._process_replay(
                action=True)
            frp_feed_dict = {
                self.local_network.frp_input: batch_frp_si,
                self.local_network.frp_c_target: batch_frp_c,
                self.local_network.frp_action_input: batch_frp_action
            }
            feed_dict.update(frp_feed_dict)

        if next_frame and self.use_autoencoder:
            ae_feed_dict = {
                self.local_network.ground_truth:
                np.expand_dims(next_frame.state, axis=0)
            }
            feed_dict.update(ae_feed_dict)

        # Calculate gradients and copy them to global network.
        #sess.run( self.apply_gradients, feed_dict=feed_dict)
        ln = self.local_network
        if self.use_future_reward_prediction:
            if self.use_autoencoder:
                frp_c, decoder_loss, frp_loss, value_loss, policy_loss, _ = sess.run(
                    [
                        ln.frp_c, ln.decoder_loss, ln.frp_loss, ln.value_loss,
                        ln.policy_loss, self.apply_gradients
                    ],
                    feed_dict=feed_dict)
                self.add_summary(global_t, 'decoder_loss', decoder_loss,
                                 summary_writer)
                self.add_summary(global_t, 'frp_loss', frp_loss,
                                 summary_writer)
            else:
                frp_c, value_loss, policy_loss, _ = sess.run(
                    [
                        ln.frp_c, ln.value_loss, ln.policy_loss,
                        self.apply_gradients
                    ],
                    feed_dict=feed_dict)
            acc = ((frp_c == frp_c.max()) * batch_frp_c).sum()
            self.add_summary(global_t, 'reward prediction accuracy', acc,
                             summary_writer)
        else:
            value_loss, policy_loss, _ = sess.run(
                [ln.value_loss, ln.policy_loss, self.apply_gradients],
                feed_dict=feed_dict)

        self.add_summary(global_t, 'value_loss', value_loss, summary_writer)
        self.add_summary(global_t, 'policy_loss', policy_loss, summary_writer)
        self.add_summary(global_t, 'base_loss', policy_loss + value_loss,
                         summary_writer)

        if self.use_autoencoder and global_t % 25000 == 0:
            current_res = {
                'next_frame_ground_truth': next_frame,
                'step': global_t
            }
            if self.use_reward_prediction:
                predicted_frame, predicted_reward = sess.run(
                    [
                        self.local_network.encoder_output,
                        self.local_network.rp_c
                    ],
                    feed_dict=feed_dict)
                current_res['states'] = batch_rp_si
                current_res['target_reward'] = batch_rp_c
            elif self.use_future_reward_prediction:
                predicted_frame, predicted_reward = sess.run(
                    [
                        self.local_network.encoder_output,
                        self.local_network.frp_c
                    ],
                    feed_dict=feed_dict)
                current_res['states'] = batch_frp_si
                current_res['target_reward'] = batch_frp_c
                current_res['action'] = batch_frp_action
            current_res['next_frame_prediction'] = predicted_frame
            current_res['next_reward_prediction'] = predicted_reward
            if os.path.exists(self.prediction_res_file) and os.path.getsize(
                    self.prediction_res_file) > 0:
                with open(self.prediction_res_file, 'rb') as f:
                    res = pickle.load(f)
            else:
                res = []
            res.append(current_res)
            with open(self.prediction_res_file, 'wb') as f:
                pickle.dump(res, f)

        self._print_log(global_t)

        # Return advanced local step size
        diff_local_t = self.local_t - start_local_t
        return diff_local_t
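
A rough launch sketch (not from the original source) showing how several Trainer instances are usually driven A3C-style against one shared global network; the function name and its arguments are assumptions about the surrounding training script.

import threading
import time

def run_trainers(trainers, sess, summary_writer, summary_op, score_input,
                 max_global_time_step):
    # `trainers` share one global network; `sess` and the summary objects are
    # assumed to be built by the surrounding training script.
    global_t = 0
    lock = threading.Lock()

    def train_loop(trainer):
        nonlocal global_t
        trainer.prepare()
        trainer.set_start_time(time.time())
        while global_t < max_global_time_step:
            diff = trainer.process(sess, global_t, summary_writer, summary_op,
                                   score_input)
            with lock:
                global_t += diff
        trainer.stop()

    threads = [threading.Thread(target=train_loop, args=(t,)) for t in trainers]
    for th in threads:
        th.start()
    for th in threads:
        th.join()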
예제 #8
0
class Evaluate(object):
  def __init__(self):
    self.action_size = Environment.get_action_size(flags.env_type, flags.env_name)
    self.objective_size = Environment.get_objective_size(flags.env_type, flags.env_name)

    env_config = sim_config.get(flags.env_name)
    self.image_shape = [env_config['height'], env_config['width']]
    segnet_param_dict = {'segnet_mode': flags.segnet}
    is_training = tf.placeholder(tf.bool, name="training")  # training-mode placeholder; its value is handled by UnrealModel when for_display=True

    self.global_network = UnrealModel(self.action_size,
                                      self.objective_size,
                                      -1,
                                      flags.use_lstm,
                                      flags.use_pixel_change,
                                      flags.use_value_replay,
                                      flags.use_reward_prediction,
                                      0.0, #flags.pixel_change_lambda
                                      0.0, #flags.entropy_beta
                                      device,
                                      segnet_param_dict=segnet_param_dict,
                                      image_shape=self.image_shape,
                                      is_training=is_training,
                                      n_classes=flags.n_classes,
                                      segnet_lambda=flags.segnet_lambda,
                                      dropout=flags.dropout,
                                      for_display=True)
    self.environment = Environment.create_environment(flags.env_type, flags.env_name, flags.termination_time_sec,
                                                      env_args={'episode_schedule': flags.split,
                                                                'log_action_trace': flags.log_action_trace,
                                                                'max_states_per_scene': flags.episodes_per_scene,
                                                                'episodes_per_scene_test': flags.episodes_per_scene})

    self.global_network.prepare_loss()

    self.total_loss = []
    self.segm_loss = []
    self.episode_reward = [0]
    self.episode_roomtype = []
    self.roomType_dict  = {}
    self.segnet_class_dict = {}
    self.success_rate = []
    self.batch_size = 20
    self.batch_cur_num = 0
    self.batch_prev_num = 0
    self.batch_si = []
    self.batch_sobjT = []
    self.batch_a = []
    self.batch_reward = []

  def update(self, sess):
    self.process(sess)

  def is_done(self):
    return self.environment.is_all_scheduled_episodes_done()

  def choose_action(self, pi_values):
    return np.random.choice(range(len(pi_values)), p=pi_values)

  def process(self, sess):
    last_action = self.environment.last_action
    last_reward = self.environment.last_reward
    last_action_reward = ExperienceFrame.concat_action_and_reward(last_action, self.action_size,
                                                                  last_reward, self.environment.last_state)
    if random_policy:
      pi_values = [1/3.0, 1/3.0, 1/3.0]
      action = self.choose_action(pi_values)
      state, reward, terminal, pixel_change = self.environment.process(action)
      self.episode_reward[-1] += reward
    else:
      mode = "segnet" if flags.segnet >= 2 else ""
      segnet_preds = None
      if not flags.use_pixel_change:
        pi_values, v_value, segnet_preds = self.global_network.run_base_policy_and_value(sess,
                                                                           self.environment.last_state,
                                                                           last_action_reward, mode=mode)
      else:
        pi_values, v_value, pc_q = self.global_network.run_base_policy_value_pc_q(sess,
                                                                                  self.environment.last_state,
                                                                                  last_action_reward)

      if segnet_preds is not None:
          mask = self.environment.last_state.get('objectType', None)
          if mask is not None:
              new_classes = np.unique(mask)
              if segnet_preds.shape != mask.shape:
                  print("Predictions have shape {}, but groundtruth mask has shape {}".format(segnet_preds.shape, mask.shape))
              else:
                  similar = segnet_preds == mask
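                  # For every class present in the mask, record a pair of
                  # (correctly predicted pixels, total pixels) for later accuracy stats.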
                  for id_class in new_classes:
                      id_list = self.segnet_class_dict.get(id_class, None)
                      if id_list is None:
                          id_list = []
                      id_list += [[np.sum(similar[mask == id_class]), np.sum(mask == id_class)]]
                      self.segnet_class_dict[id_class] = id_list

      self.batch_cur_num += 1
      if flags.segnet == -1: #just not necessary
        if self.batch_cur_num != 0 and self.batch_cur_num - self.batch_prev_num >= self.batch_size:

          #print(np.unique(self.batch_sobjT))
          feed_dict = {self.global_network.base_input: self.batch_si,
                       self.global_network.base_segm_mask: self.batch_sobjT,
                       self.global_network.is_training: False}

          segm_loss, preds, confusion_mtx = sess.run([self.global_network.decoder_loss,
                                                    self.global_network.preds, self.global_network.update_evaluation_vars],
                                                   feed_dict=feed_dict)
          total_loss = 0
          self.total_loss += [total_loss]
          self.segm_loss += [segm_loss]  # TODO: do something with this, store it somewhere?

          # update everything else
          self.batch_prev_num = self.batch_cur_num
          self.batch_si = []
          self.batch_sobjT = []
          self.batch_a = []
        else:
          self.batch_si += [self.environment.last_state["image"]]
          self.batch_sobjT += [self.environment.last_state["objectType"]]
          self.batch_a += [self.environment.ACTION_LIST[self.environment.last_action]]

      action = self.choose_action(pi_values)
      state, reward, terminal, pixel_change = self.environment.process(action)
      self.episode_reward[-1] += reward

    if terminal:
      ep_info = self.environment._episode_info
      if ep_info['task'] == 'room_goal':
          one_hot_room = ep_info['goal']['roomTypeEncoded']
          room_type = ep_info['goal']['roomType']
          ind = np.where(one_hot_room)[0][0]
          self.roomType_dict[ind] = room_type
          self.episode_roomtype += [ind]
      self.success_rate += [int(self.environment._last_full_state["success"])]
      self.environment.reset()
      self.episode_reward += [0]