Example #1
    def hir_relabel(self, non_null_future_idx, episode_experience, current_t,
                    replay_buffer, env):
        """Relabeling trajectories.

    Args:
      non_null_future_idx: list of time step where something happens
      episode_experience: the RL environment
      current_t: time time step at which the experience is relabeled
      replay_buffer:  the experience replay buffer
      env: the RL environment

    Returns:
      the reset state of the environment
    """
        ep_len = len(episode_experience)
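        # Unpack the transition at the current time step; ag holds the list of
        # instructions achieved at this step (empty if nothing was achieved).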
        s, a, _, s_tp1, _, ag = episode_experience[current_t]
        if ag:
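            # Immediate relabeling: pick up to k_immediate achieved instructions
            # and store the transition with the full environment reward.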
            # TODO(ydjiang): k_immediate logic needs improvement
            for _ in range(self.cfg.k_immediate):
                ag_text_single = random.choice(ag)
                g_type = instruction_type(ag_text_single)
                if self.cfg.paraphrase and g_type != 'unary':
                    ag_text_single = paraphrase_sentence(
                        ag_text_single,
                        delete_color=self.cfg.diverse_scene_content)
                replay_buffer.add((s, a, env.reward_scale, s_tp1,
                                   self.encode_fn(ag_text_single)))
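                # For unary instructions, optionally store the negated sentence
                # as a zero-reward negative example.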
                if g_type == 'unary' and self.cfg.negate_unary:
                    negative_ag = negate_unary_sentence(ag_text_single)
                    if negative_ag:
                        replay_buffer.add(
                            (s, a, 0.0, s_tp1, self.encode_fn(negative_ag)))
        # TODO(ydjiang): repeat logit needs improvement
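        # Future relabeling: sample instructions achieved at later time steps and
        # store this transition with a reward discounted by the time gap.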
        goal_count, repeat = 0, 0
        while goal_count < self.cfg.future_k and repeat < (ep_len -
                                                           current_t) * 4:
            repeat += 1
            future = np.random.randint(current_t, ep_len)
            _, _, _, _, _, ag_future = episode_experience[future]
            if not ag_future:
                continue
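            # Shuffle and use the first non-unary instruction achieved at the
            # sampled future step.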
            random.shuffle(ag_future)
            for single_g in ag_future:
                if instruction_type(single_g) != 'unary':
                    discount = self.cfg.discount**(future - current_t)
                    if self.cfg.paraphrase:
                        single_g = paraphrase_sentence(
                            single_g,
                            delete_color=self.cfg.diverse_scene_content)
                    replay_buffer.add(
                        (s, a, discount * env.reward_scale, s_tp1,
                         self.encode_fn(single_g)))
                    goal_count += 1
                    break

Example #2

    def hir_relabel(self, non_null_future_idx, episode_experience, current_t,
                    replay_buffer, env):
        """Relabeling trajectories.

    Args:
      non_null_future_idx: list of time step where something happens
      episode_experience: the RL environment
      current_t: time time step at which the experience is relabeled
      replay_buffer:  the experience replay buffer
      env: the RL environment

    Returns:
      the reset state of the environment
    """
        s, a, _, s_tp1, g, ag = episode_experience[current_t]
        if not ag:
            return
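        # Relabel with up to k_immediate achieved instructions, each stored with
        # the full environment reward (and the achieved-goal list appended).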
        for _ in range(min(self.cfg.k_immediate, len(ag) + 1)):
            ag_text_single = random.choice(ag)
            g_type = instruction_type(ag_text_single)
            if self.cfg.paraphrase and g_type != 'unary':
                ag_text_single = paraphrase_sentence(
                    ag_text_single,
                    delete_color=self.cfg.diverse_scene_content)
            replay_buffer.add((s, a, env.reward_scale, s_tp1,
                               self.encode_fn(ag_text_single), ag))
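
Both examples are methods of a larger class (called a learner here for illustration): they rely on self.cfg (fields such as k_immediate, future_k, discount, paraphrase, negate_unary, diverse_scene_content), self.encode_fn, the module-level helpers instruction_type, paraphrase_sentence and negate_unary_sentence, and the random/numpy imports of the enclosing module. As a rough sketch, a relabeling step like this is typically driven once per time step after an episode finishes; the relabel_episode helper and the learner, replay_buffer and env names below are illustrative assumptions, not part of the examples:

def relabel_episode(learner, episode_experience, replay_buffer, env):
    """Run hindsight relabeling over every time step of a finished episode."""
    # Time steps where at least one instruction was achieved (ag is non-empty).
    non_null_future_idx = [
        t for t, transition in enumerate(episode_experience) if transition[-1]
    ]
    for t in range(len(episode_experience)):
        learner.hir_relabel(non_null_future_idx, episode_experience, t,
                            replay_buffer, env)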