Code example #1
File: trainer_td.py Project: nirbhayjm/temp-1
def repalce_mission_with_omega(obs_input: DictObs,
                               omega: torch.Tensor) -> DictObs:
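    """Return a copy of `obs_input` with any 'mission' entry removed and a
    clone of `omega` stored under the 'omega' key."""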
    # obs = copy.deepcopy(obs_input)
    obs = obs_input.copy()
    if 'mission' in obs.keys():
        obs.pop('mission')
    obs.update({'omega': omega.clone()})
    return obs
Code example #2
File: trainer_td.py Project: nirbhayjm/temp-1
def replace_goal_vector_with_z(obs_input: DictObs,
                               z_latent: torch.Tensor) -> DictObs:
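    """Return a copy of `obs_input` with any 'goal_vector' entry removed and a
    clone of `z_latent` stored under the 'z_latent' key."""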
    # obs = copy.deepcopy(obs_input)
    obs = obs_input.copy()
    if 'goal_vector' in obs.keys():
        obs.pop('goal_vector')
    obs.update({'z_latent': z_latent.clone()})
    return obs
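
Both helpers follow the same pattern: shallow-copy the dict-like observation, drop one key, and add a tensor under a new key, leaving the input untouched. A minimal sketch of that pattern, assuming the helpers above are in scope and that DictObs behaves like a plain dict (a dict stands in for it here); the keys and tensor shapes are illustrative only:

import torch

# Stand-in observation; keys and shapes are hypothetical, not from the project.
obs = {'image': torch.zeros(1, 7, 7, 3), 'mission': torch.zeros(1, 8)}
omega = torch.randn(1, 4)  # hypothetical option vector

obs_omega = repalce_mission_with_omega(obs, omega)
print(sorted(obs_omega.keys()))  # ['image', 'omega']
print('mission' in obs)          # True: the input observation is not mutated
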
Code example #3
File: trainer_td.py Project: nirbhayjm/temp-1
    def eval_success_td(self):
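        """Roll out the current policy on the validation envs until
        `num_eval_episodes` episodes have finished; returns per-episode
        rewards, success flags (reward > 0) and max room ids."""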
        with torch.no_grad():
            self.val_envs.reset_config_rng()
            assert self.val_envs.get_attr('reset_on_done')[0]
            self.actor_critic.train()
            reset_output = self.val_envs.reset()

            obs = reset_output[:, 0]
            info = reset_output[:, 1]
            obs = dict_stack_helper(obs)
            obs = DictObs({key:torch.from_numpy(obs_i).to(self.device) \
                for key, obs_i in obs.items()})

            episode_counter = 0
            episode_rewards = np.zeros((self.args.num_eval_episodes))
            episode_mrids = np.zeros((self.args.num_eval_episodes))
            masks = torch.ones(self.num_processes_eff,
                               1).float().to(self.device)
            recurrent_hidden_states = torch.zeros(
                self.args.num_processes,
                self.actor_critic.recurrent_hidden_state_size).to(self.device)

            eval_done = False
            while not eval_done:
                z_latent, z_gauss_dist, value, action, \
                action_log_prob, recurrent_hidden_states = \
                    self.actor_critic.act(
                        inputs=obs,
                        rnn_hxs=recurrent_hidden_states,
                        masks=masks,
                        do_z_sampling=False)

                cpu_actions = action.view(-1).cpu().numpy()
                obs, _, done, info = self.val_envs.step(cpu_actions)
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done]).to(self.device)

                obs = dict_stack_helper(obs)
                obs = DictObs({key:torch.from_numpy(obs_i).to(self.device) \
                    for key, obs_i in obs.items()})

                for batch_idx, info_item in enumerate(info):
                    if 'prev_episode' in info_item.keys():
                        episode_rewards[episode_counter] = \
                            info_item['prev_episode']['info']['episode_reward']
                        episode_mrids[episode_counter] = \
                            info_item['prev_episode']['info']['max_room_id']

                        episode_counter += 1
                        if episode_counter >= self.args.num_eval_episodes:
                            eval_done = True
                            break

            episode_success = 1.0 * (episode_rewards > 0)
            return episode_rewards, episode_success, episode_mrids
Code example #4
File: trainer_td.py Project: nirbhayjm/temp-1
    def train_infobot_supervised(self, total_training_steps, start_iter):
        """Train loop"""

        print("=" * 36)
        print("Trainer initialized! Training information:")
        print("\t# of total_training_steps: {}".format(total_training_steps))
        # print("\t# of train envs: {}".format(len(self.train_envs)))
        print("\tnum_processes: {}".format(self.args.num_processes))
        print("\tnum_agents: {}".format(self.args.num_agents))
        # print("\tIterations per epoch: {}".format(self.num_batches_per_epoch))
        print("=" * 36)

        self.save_checkpoint(0)

        if self.args.model == 'hier':
            self.do_sampling = True
        elif self.args.model == 'cond':
            self.do_sampling = False
        next_save_on = 1 * self.args.save_interval

        self.actor_critic.train()
        # self.agent_pos = np.zeros(
        #     [self.args.num_steps + 1, self.num_processes_eff, 2], dtype='int')
        # self.visit_count = [np.ones(self.num_processes_eff)]
        # self.visit_count = np.ones(
        #     [self.args.num_steps, self.num_processes_eff], dtype='int')
        # self.heuristic_ds = np.zeros(
        #     [self.args.num_steps, self.num_processes_eff], dtype='int')

        reset_output = self.train_envs.reset()

        obs = reset_output[:, 0]
        info = reset_output[:, 1]
        obs = dict_stack_helper(obs)
        # info = dict_stack_helper(info)
        # curr_pos = np.stack([item['agent_pos'] for item in info], 0)
        # self.agent_pos[0] = curr_pos
        # [obs] = flatten_batch_dims(obs)
        obs = DictObs({key:torch.from_numpy(obs_i).to(self.device) \
            for key, obs_i in obs.items()})

        self.rollouts.obs[0].copy_(obs)
        self.rollouts.to(self.device)

        # time_steps = torch.zeros(self.num_processes_eff, 1).long().to(self.device)
        # episode_rewards = torch.zeros(self.num_processes_eff, 1).to(self.device)
        episode_counter = 0
        episode_rewards = deque(maxlen=300)
        episode_mrids = deque(maxlen=300)
        ep_len = deque(maxlen=300)

        zz_kld = deque(maxlen=self.args.log_interval)
        zz_kl_loss = deque(maxlen=self.args.log_interval)
        effective_return = deque(maxlen=self.args.log_interval)

        masks = torch.ones(self.num_processes_eff, 1).float().to(self.device)
        recurrent_hidden_states = torch.zeros(
            self.args.num_steps + 1, self.args.num_processes,
            self.actor_critic.recurrent_hidden_state_size)

        num_updates = int(total_training_steps) // \
            (self.num_processes_eff * self.args.num_steps)

        def batch_iterator(start_idx):
            idx = start_idx
            for _ in range(start_idx, num_updates + self.args.log_interval):
                yield idx
                idx += 1

        start = time.time()

        for iter_id in batch_iterator(start_iter):
            self.actor_critic.train()
            self.rollouts.prev_final_mask.fill_(0)

            for step in range(self.args.num_steps):
                with torch.no_grad():
                    z_latent, z_gauss_dist, value, action, \
                    action_log_prob, recurrent_hidden_states = \
                        self.actor_critic.act(
                            inputs=obs,
                            rnn_hxs=self.rollouts.recurrent_hidden_states[step],
                            masks=self.rollouts.masks[step],
                            do_z_sampling=self.args.z_stochastic)
                    z_eps = (z_latent - z_gauss_dist.loc) / z_gauss_dist.scale

                cpu_actions = action.view(-1).cpu().numpy()

                obs, reward, done, info = self.train_envs.step(cpu_actions)

                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done]).to(self.device)
                episode_counter += done.sum()

                obs = dict_stack_helper(obs)
                obs = DictObs({key:torch.from_numpy(obs_i).to(self.device) \
                    for key, obs_i in obs.items()})

                # curr_pos = np.stack([item['agent_pos'] for item in info], 0)
                # curr_dir = np.stack([item['agent_dir'] for item in info], 0)
                # visit_count = np.stack([item['visit_count'] for item in info], 0)

                # self.agent_pos[step + 1] = curr_pos
                # self.visit_count[step] = visit_count
                # if 'is_heuristic_ds' in info[0].keys():
                #     is_heuristic_ds = np.stack(
                #         [item['is_heuristic_ds'] for item in info], 0)
                #     self.heuristic_ds[step] = is_heuristic_ds

                for batch_idx, info_item in enumerate(info):
                    if 'prev_episode' in info_item.keys():
                        # prev_final_obs = info_item['prev_episode']['obs']
                        # prev_final_obs = DictObs(
                        #     {key:torch.from_numpy(obs_i).to(self.device) \
                        #         for key, obs_i in prev_final_obs.items()})
                        # self.rollouts.prev_final_mask[step, batch_idx] = 1
                        # self.rollouts.prev_final_visit_count[step, batch_idx] = \
                        #     info_item['visit_count']
                        # self.rollouts.prev_final_heur_ds[step, batch_idx] = \
                        #     float(info_item['is_heuristic_ds'])
                        # self.rollouts.prev_final_obs[step, batch_idx].copy_(
                        #     prev_final_obs)
                        episode_rewards.append(info_item['prev_episode']
                                               ['info']['episode_reward'])
                        episode_mrids.append(
                            info_item['prev_episode']['info']['max_room_id'])
                        ep_len.append(
                            info_item['prev_episode']['info']['step_count'])

                reward = torch.from_numpy(reward[:, np.newaxis]).float()
                reward = reward.to(self.device)

                self.rollouts.insert(
                    obs=obs,
                    recurrent_hidden_states=recurrent_hidden_states,
                    actions=action,
                    action_log_probs=action_log_prob,
                    value_preds=value,
                    z_eps=z_eps,
                    rewards=reward,
                    masks=masks,
                )

            with torch.no_grad():
                next_value = self.actor_critic.get_value(
                    inputs=self.rollouts.obs[-1],
                    rnn_hxs=self.rollouts.recurrent_hidden_states[-1],
                    masks=self.rollouts.masks[-1],
                ).detach()


            total_num_steps = (iter_id + 1) * \
                self.num_processes_eff * self.args.num_steps

            anneal_coeff = utils.kl_coefficient_curriculum(
                iter_id=total_num_steps,
                iters_per_epoch=1,
                start_after_epochs=self.args.kl_anneal_start_epochs,
                linear_growth_epochs=self.args.kl_anneal_growth_epochs,
            )

            q_start_flag = utils.q_start_curriculum(
                iter_id=total_num_steps,
                iters_per_epoch=1,
                start_after_epochs=self.args.q_start_epochs,
            )

            if not self.args.z_stochastic:
                infobot_coeff = 0
            else:
                infobot_coeff = utils.kl_coefficient_curriculum(
                    iter_id=total_num_steps,
                    iters_per_epoch=1,
                    start_after_epochs=self.args.infobot_kl_start,
                    linear_growth_epochs=self.args.infobot_kl_growth,
                )

            min_ib_coeff = min(self.args.infobot_beta_min,
                               self.args.infobot_beta)
            if self.args.infobot_beta > 0:
                infobot_coeff = max(infobot_coeff,
                                    min_ib_coeff / self.args.infobot_beta)
                if not self.args.z_stochastic:
                    infobot_coeff = 0

            if self.args.algo == 'a2c' or self.args.algo == 'acktr':
                # Conditional model
                value_loss, action_loss, dist_entropy,\
                action_log_probs_mean, ic_info = \
                    self.agent.update_infobot_supervised(
                        rollouts=self.rollouts,
                        infobot_beta=self.args.infobot_beta,
                        next_value=next_value,
                        anneal_coeff=infobot_coeff,
                    )

                ic_info.update({
                    'anneal_coeff': infobot_coeff,
                    'q_start_flag': q_start_flag,
                })
                zz_kld.append(ic_info['zz_kld'])
                zz_kl_loss.append(ic_info['zz_kl_loss'])
                effective_return.append(ic_info['effective_return'])

            else:
                raise ValueError("Unknown algo: {}".format(self.args.algo))

            self.rollouts.after_update()

            if iter_id % self.args.log_interval == 0:
                if len(episode_rewards) > 1:
                    # cpu_rewards = episode_rewards.cpu().numpy()
                    cpu_rewards = episode_rewards
                    mrids = episode_mrids
                    episode_length = ep_len
                else:
                    cpu_rewards = np.array([0])
                    mrids = np.array([0])
                    episode_length = np.array([0])
                end = time.time()
                FPS = int(total_num_steps / (end - start))

                print(
                    f"Updates {iter_id}, num timesteps {total_num_steps}, FPS {FPS}, episodes: {episode_counter} \n Last {len(cpu_rewards)} training episodes: mean/median reward {np.mean(cpu_rewards):.1f}/{np.median(cpu_rewards):.1f}, min/max reward {np.min(cpu_rewards):.1f}/{np.max(cpu_rewards):.1f}"
                )

                print(
                    f" Max room id mean/median: {np.mean(mrids):.1f}/{np.median(mrids):.1f}, min/max: {np.min(mrids)}/{np.max(mrids)}"
                )

                train_success = 1.0 * (np.array(cpu_rewards) > 0)

                self.logger.plot_success(
                    prefix="train_",
                    total_num_steps=total_num_steps,
                    rewards=cpu_rewards,
                    success=train_success,
                    mrids=mrids,
                )
                self.logger.viz.line(total_num_steps,
                                     FPS,
                                     "FPS",
                                     "FPS",
                                     xlabel="time_steps")
                self.logger.plot_quad_stats(x_val=total_num_steps,
                                            array=episode_length,
                                            plot_title="episode_length")
                self.logger.viz.line(total_num_steps,
                                     np.mean(effective_return),
                                     "effective_return",
                                     "mean",
                                     xlabel="time_steps")
                self.logger.viz.line(total_num_steps,
                                     np.mean(zz_kld),
                                     "zz_kl",
                                     "zz_kld",
                                     xlabel="time_steps")
                self.logger.viz.line(total_num_steps,
                                     np.mean(zz_kl_loss),
                                     "zz_kl",
                                     "zz_kl_loss",
                                     xlabel="time_steps")
                self.logger.viz.line(total_num_steps,
                                     infobot_coeff,
                                     "zz_kl",
                                     "anneal_coeff",
                                     xlabel="time_steps")
                self.logger.viz.line(total_num_steps,
                                     np.mean(dist_entropy),
                                     "policy_entropy",
                                     "entropy",
                                     xlabel="time_steps")

            if total_num_steps > self.next_val_after:
                print(f"Evaluating success at {total_num_steps} steps")
                self.next_val_after += self.args.val_interval
                val_rewards, val_success, val_mrids = self.eval_success_td()
                best_success_achieved = self.logger.plot_success(
                    prefix="val_",
                    total_num_steps=total_num_steps,
                    rewards=val_rewards,
                    success=val_success,
                    mrids=val_mrids,
                    track_best=True,
                )
                self.save_checkpoint(total_num_steps,
                                     fname="best_val_success.vd")

            if total_num_steps > next_save_on:
                next_save_on += self.args.save_interval
                self.save_checkpoint(total_num_steps)
Code example #5
File: trainer_td.py Project: nirbhayjm/temp-1
    def train(self, total_training_steps, start_iter):
        """Train loop"""

        print("=" * 36)
        print("Trainer initialized! Training information:")
        print("\t# of total_training_steps: {}".format(total_training_steps))
        # print("\t# of train envs: {}".format(len(self.train_envs)))
        print("\tnum_processes: {}".format(self.args.num_processes))
        print("\tnum_agents: {}".format(self.args.num_agents))
        # print("\tIterations per epoch: {}".format(self.num_batches_per_epoch))
        print("=" * 36)

        self.save_checkpoint(0)

        if self.args.model == 'hier':
            self.do_sampling = True
        elif self.args.model == 'cond':
            self.do_sampling = False

        self.actor_critic.train()
        self.agent_pos = np.zeros(
            [self.args.num_steps + 1, self.num_processes_eff, 2], dtype='int')
        # self.visit_count = [np.ones(self.num_processes_eff)]
        self.visit_count = np.ones(
            [self.args.num_steps, self.num_processes_eff], dtype='int')
        self.heuristic_ds = np.zeros(
            [self.args.num_steps, self.num_processes_eff], dtype='int')

        reset_output = self.train_envs.reset()

        obs = reset_output[:, 0]
        info = reset_output[:, 1]
        obs = dict_stack_helper(obs)
        # info = dict_stack_helper(info)
        curr_pos = np.stack([item['agent_pos'] for item in info], 0)
        self.agent_pos[0] = curr_pos
        # [obs] = flatten_batch_dims(obs)
        obs = DictObs({key:torch.from_numpy(obs_i).to(self.device) \
            for key, obs_i in obs.items()})

        self.rollouts.obs[0].copy_(obs)
        self.rollouts.to(self.device)

        # time_steps = torch.zeros(self.num_processes_eff, 1).long().to(self.device)
        # episode_rewards = torch.zeros(self.num_processes_eff, 1).to(self.device)
        episode_counter = 0
        episode_rewards = deque(maxlen=300)
        episode_mrids = deque(maxlen=300)
        masks = torch.ones(self.num_processes_eff, 1).float().to(self.device)
        recurrent_hidden_states = torch.zeros(
            self.args.num_steps + 1, self.args.num_processes,
            self.actor_critic.recurrent_hidden_state_size)

        num_updates = int(total_training_steps) // \
            (self.num_processes_eff * self.args.num_steps)

        def batch_iterator(start_idx):
            idx = start_idx
            for _ in range(start_idx, num_updates + self.args.log_interval):
                yield idx
                idx += 1

        start = time.time()

        for iter_id in batch_iterator(start_iter):
            self.actor_critic.train()
            self.rollouts.prev_final_mask.fill_(0)

            for step in range(self.args.num_steps):
                with torch.no_grad():
                    value, action, action_log_prob, recurrent_hidden_states = \
                        self.actor_critic.act(
                            inputs=obs,
                            rnn_hxs=self.rollouts.recurrent_hidden_states[step],
                            masks=self.rollouts.masks[step])

                cpu_actions = action.view(-1).cpu().numpy()

                obs, reward, done, info = self.train_envs.step(cpu_actions)

                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done]).to(self.device)
                episode_counter += done.sum()

                obs = dict_stack_helper(obs)
                obs = DictObs({key:torch.from_numpy(obs_i).to(self.device) \
                    for key, obs_i in obs.items()})

                curr_pos = np.stack([item['agent_pos'] for item in info], 0)
                curr_dir = np.stack([item['agent_dir'] for item in info], 0)
                visit_count = np.stack([item['visit_count'] for item in info],
                                       0)

                self.agent_pos[step + 1] = curr_pos
                self.visit_count[step] = visit_count
                if 'is_heuristic_ds' in info[0].keys():
                    is_heuristic_ds = np.stack(
                        [item['is_heuristic_ds'] for item in info], 0)
                    self.heuristic_ds[step] = is_heuristic_ds

                for batch_idx, info_item in enumerate(info):
                    if 'prev_episode' in info_item.keys():
                        prev_final_obs = info_item['prev_episode']['obs']
                        prev_final_obs = DictObs(
                            {key:torch.from_numpy(obs_i).to(self.device) \
                                for key, obs_i in prev_final_obs.items()})
                        self.rollouts.prev_final_mask[step, batch_idx] = 1
                        self.rollouts.prev_final_visit_count[step, batch_idx] = \
                            info_item['visit_count']
                        self.rollouts.prev_final_heur_ds[step, batch_idx] = \
                            float(info_item['is_heuristic_ds'])
                        self.rollouts.prev_final_obs[step, batch_idx].copy_(
                            prev_final_obs)
                        episode_rewards.append(info_item['prev_episode']
                                               ['info']['episode_reward'])
                        episode_mrids.append(
                            info_item['prev_episode']['info']['max_room_id'])

                reward = torch.from_numpy(reward[:, np.newaxis]).float()
                reward = reward.to(self.device)
                # episode_rewards += reward
                # reward = torch.from_numpy(reward).float()

                # not_done = np.logical_not(done)

                # masks = torch.from_numpy(not_done.astype('float32')).unsqueeze(1)
                # masks = masks.to(self.device)

                self.rollouts.insert(
                    obs=obs,
                    recurrent_hidden_states=recurrent_hidden_states,
                    actions=action,
                    action_log_probs=action_log_prob,
                    value_preds=value,
                    rewards=reward,
                    masks=masks,
                )

            with torch.no_grad():
                next_value = self.actor_critic.get_value(
                    inputs=self.rollouts.obs[-1],
                    rnn_hxs=self.rollouts.recurrent_hidden_states[-1],
                    masks=self.rollouts.masks[-1],
                ).detach()

            anneal_coeff = utils.kl_coefficient_curriculum(
                iter_id=iter_id,
                iters_per_epoch=self.num_batches_per_epoch,
                start_after_epochs=self.args.kl_anneal_start_epochs,
                linear_growth_epochs=self.args.kl_anneal_growth_epochs,
            )

            q_start_flag = utils.q_start_curriculum(
                iter_id=iter_id,
                iters_per_epoch=self.num_batches_per_epoch,
                start_after_epochs=self.args.q_start_epochs,
            )

            if self.args.algo == 'a2c' or self.args.algo == 'acktr':
                # Conditional model
                value_loss, action_loss, dist_entropy,\
                action_log_probs_mean, ic_info, option_info = \
                    self.agent.update(
                        rollouts=self.rollouts,
                        hier_mode=self.args.hier_mode,
                        use_intrinsic_control=False,
                        next_value=next_value,
                        option_space=self.args.option_space,
                        use_ib=self.args.use_infobot,
                        agent_pos=self.agent_pos,
                        bonus_z_encoder=self.z_encoder,
                        b_args=self.b_args,
                        bonus_type=self.args.bonus_type,
                        bonus_normalization=self.args.bonus_normalization,
                        heuristic_ds=self.heuristic_ds,
                        heuristic_coeff=self.args.bonus_heuristic_beta,
                        visit_count=self.visit_count,
                    )

                ic_info.update({
                    'anneal_coeff': anneal_coeff,
                    # 'infobot_coeff': infobot_coeff,
                    'q_start_flag': q_start_flag,
                })

                # if 'traj_ce_loss' in ic_info:
                #     traj_ce_loss.extend(ic_info['traj_ce_loss'])
            else:
                raise ValueError("Unknown algo: {}".format(self.args.algo))

            self.rollouts.after_update()

            total_num_steps = (iter_id + 1) * \
                self.num_processes_eff * self.args.num_steps
            if iter_id % self.args.log_interval == 0:
                if len(episode_rewards) > 1:
                    # cpu_rewards = episode_rewards.cpu().numpy()
                    cpu_rewards = episode_rewards
                    mrids = episode_mrids
                else:
                    cpu_rewards = np.array([0])
                    mrids = np.array([-1])
                end = time.time()
                FPS = int(total_num_steps / (end - start))

                print(
                    f"Updates {iter_id}, num timesteps {total_num_steps}, FPS {FPS}, episodes: {episode_counter} \n Last {len(cpu_rewards)} training episodes: mean/median reward {np.mean(cpu_rewards):.1f}/{np.median(cpu_rewards):.1f}, min/max reward {np.min(cpu_rewards):.1f}/{np.max(cpu_rewards):.1f}"
                )

                print(
                    f" Max room id mean/median: {np.mean(mrids):.1f}/{np.median(mrids):.1f}, min/max: {np.min(mrids)}/{np.max(mrids)}"
                )

                train_success = 1.0 * (np.array(cpu_rewards) > 0)

                self.logger.plot_success(
                    prefix="train_",
                    total_num_steps=total_num_steps,
                    rewards=cpu_rewards,
                    success=train_success,
                    mrids=mrids,
                )
                self.logger.viz.line(total_num_steps,
                                     FPS,
                                     "FPS",
                                     "FPS",
                                     xlabel="time_steps")

            if total_num_steps > self.next_val_after:
                print(f"Evaluating success at {total_num_steps} steps")
                self.next_val_after += self.args.val_interval
                val_rewards, val_success, val_mrids = self.eval_success()
                self.logger.plot_success(
                    prefix="val_",
                    total_num_steps=total_num_steps,
                    rewards=val_rewards,
                    success=val_success,
                    mrids=val_mrids,
                    track_best=True,
                )
Code example #6
File: trainer.py Project: nirbhayjm/temp-1
    def forward_step(self, step, omega_option, obs_base, ib_rnn_hxs,
                     options_rhx):
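        """Advance the training envs by one step: in 'transfer' mode resample
        the option every `num_option_steps` steps, sample an action from the
        policy, step the envs, rebuild `obs_base` (optionally through the
        infobot encoder), record positions/visit counts, and insert the
        transition into the rollout storage."""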
        # Sample options if applicable
        if self.args.hier_mode == 'transfer':
            with torch.no_grad():
                if step % self.args.num_option_steps == 0:
                    omega_option = None
                    previous_options_rhx = options_rhx
                    option_value, omega_option, option_log_probs, options_rhx = \
                        self.options_policy.act(
                            inputs=obs_base,
                            rnn_hxs=options_rhx,
                            masks=self.rollouts.masks[step])
                    if self.args.option_space == 'discrete':
                        omega_option = omega_option.squeeze(-1)
                        omega_option = torch.eye(self.args.omega_option_dims)\
                            .to(self.device)[omega_option]
                    self.rollouts.insert_option_t(
                        step=step,
                        omega_option_t=omega_option,
                        option_log_probs=option_log_probs,
                        option_value=option_value,
                        options_rhx=previous_options_rhx)
                obs_base = repalce_mission_with_omega(obs_base, omega_option)

        # Sample actions
        with torch.no_grad():
            value, action, action_log_prob, recurrent_hidden_states = \
                self.actor_critic.act(
                    inputs=obs_base,
                    rnn_hxs=self.rollouts.recurrent_hidden_states[step],
                    masks=self.rollouts.masks[step])

        # Take actions, observe reward and next obs
        # cpu_actions = action.view(
        #     (self.args.num_processes, self.args.num_agents)).cpu().numpy()
        cpu_actions = action.view(-1).cpu().numpy()

        # obs, reward, _, info = self.train_envs.step(cpu_actions + 1)
        obs, reward, _, info = self.train_envs.step(cpu_actions)

        obs = dict_stack_helper(obs)
        obs = DictObs({key:torch.from_numpy(obs_i).to(self.device) \
            for key, obs_i in obs.items()})

        if self.args.hier_mode == 'transfer' or self.args.model == 'cond':
            obs_base = obs
        else:
            if self.args.use_infobot:
                if self.args.hier_mode == 'infobot-supervised':
                    z_latent, z_log_prob, z_dist, ib_rnn_hxs = \
                        self.actor_critic.encoder_forward(
                            obs=obs,
                            rnn_hxs=ib_rnn_hxs,
                            masks=self.rollouts.masks[step],
                            do_z_sampling=True)

                    obs_base = replace_goal_vector_with_z(obs, z_latent)
                else:
                    # Sample next z_t
                    obs_omega = repalce_mission_with_omega(obs, omega_option)
                    z_latent, z_log_prob, z_dist, ib_rnn_hxs = \
                        self.actor_critic.encoder_forward(
                            obs=obs_omega,
                            rnn_hxs=ib_rnn_hxs,
                            masks=self.rollouts.masks[step],
                            do_z_sampling=self.do_z_sampling)

                    obs_base = repalce_omega_with_z(obs, z_latent)
                self.rollouts.insert_z_latent(z_latent=z_latent,
                                              z_logprobs=z_log_prob,
                                              z_dist=z_dist,
                                              ib_enc_hidden_states=ib_rnn_hxs)
            else:
                obs_base = repalce_mission_with_omega(obs, omega_option)

        done = np.stack([item['done'] for item in info], 0)
        if 'is_heuristic_ds' in info[0].keys():
            is_heuristic_ds = np.stack(
                [item['is_heuristic_ds'] for item in info], 0)
            self.heuristic_ds[step + 1] = is_heuristic_ds

        if not self.continuous_state_space:
            curr_pos = np.stack([item['agent_pos'] for item in info], 0)
            curr_dir = np.stack([item['agent_dir'] for item in info], 0)
            visit_count = np.stack([item['visit_count'] for item in info], 0)
            self.agent_pos[step + 1] = curr_pos

            # if 'current_room' in info[0]:
            #     self.current_room = np.stack(
            #         [item['current_room'] for item in info], 0)
            self.visit_count.append(visit_count)
            pos_velocity = None
        else:
            curr_pos = None
            curr_dir = None
            if self.args.env_name == 'mountain-car':
                pos_velocity = obs['pos-velocity']
            else:
                pos_velocity = np.zeros((self.num_processes_eff, 2))

        # [obs, reward] = utils.flatten_batch_dims(obs,reward)
        # print(step, done)
        # Extract the done flag from the info
        # done = np.concatenate([info_['done'] for info_ in info],0)

        # if step == self.args.num_steps - 1:
        #     s_extract = lambda key_: np.array(
        #         [item[key_] for item in info])
        #     success_train = s_extract('success').astype('float')
        #     goal_index = s_extract('goal_index')
        #     success_0 = success_train[goal_index == 0]
        #     success_1 = success_train[goal_index == 1]
        #     # spl_train = s_extract('spl_values')
        #     # Shape Assertions

        reward = torch.from_numpy(reward[:, np.newaxis]).float()
        # episode_rewards += reward
        cpu_reward = reward
        reward = reward.to(self.device)
        # reward = torch.from_numpy(reward).float()

        not_done = np.logical_not(done)
        self.total_time_steps += not_done.sum()

        masks = torch.from_numpy(not_done.astype('float32')).unsqueeze(1)
        masks = masks.to(self.device)

        for key in obs.keys():
            if obs[key].dim() == 5:
                obs[key] *= masks.type_as(obs[key])\
                    .unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            elif obs[key].dim() == 4:
                obs[key] *= masks.type_as(obs[key]).unsqueeze(-1).unsqueeze(-1)
            elif obs[key].dim() == 1:
                obs[key] *= masks.type_as(obs[key]).squeeze(1)
            else:
                obs[key] *= masks.type_as(obs[key])

        self.rollouts.insert(obs, recurrent_hidden_states, action,
                             action_log_prob, value, reward, masks)

        return obs_base, omega_option, ib_rnn_hxs, options_rhx, cpu_reward, \
            curr_pos, curr_dir, pos_velocity, not_done
Code example #7
File: trainer.py Project: nirbhayjm/temp-1
    def on_episode_start(self):
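        """Reset the training envs, set up per-episode bookkeeping (agent
        positions, visit counts, heuristic flags), sample the initial option
        and/or infobot latent according to the configured mode, and return the
        initial option, observation, reference option distribution, recurrent
        states and reset info."""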
        self.actor_critic.train()

        # obs = train_envs.reset()
        reset_output = self.train_envs.reset()
        obs = reset_output[:, 0]
        info = reset_output[:, 1]
        obs = dict_stack_helper(obs)
        info = dict_stack_helper(info)

        if not self.continuous_state_space:
            self.visit_count = [np.ones(self.num_processes_eff)]
            self.agent_pos = np.zeros(
                [self.args.num_steps + 1, self.num_processes_eff, 2],
                dtype='int')
            self.agent_pos[0] = info['agent_pos']
            info['pos_velocity'] = None
        else:
            self.agent_pos = None
            if self.args.env_name == 'mountain-car':
                info['pos_velocity'] = obs['pos-velocity']
            else:
                info['pos_velocity'] = np.zeros((self.num_processes_eff, 2))

        self.heuristic_ds = np.zeros(
            [self.args.num_steps + 1, self.num_processes_eff], dtype='int')

        # [obs] = flatten_batch_dims(obs)
        obs = DictObs({key:torch.from_numpy(obs_i).to(self.device) \
            for key, obs_i in obs.items()})

        if self.args.model == 'cond':
            omega_option = None
            q_dist_ref = None
            options_rhx = None
            ib_rnn_hx = None
            self.rollouts.obs[0].copy_(obs)
            return omega_option, obs, q_dist_ref, ib_rnn_hx, \
                options_rhx, info

        if self.args.use_infobot:
            ib_rnn_hx = self.rollouts.recurrent_hidden_states.new_zeros(
                self.num_processes_eff,
                self.actor_critic.encoder_recurrent_hidden_state_size)

            if self.args.hier_mode == 'infobot-supervised':
                omega_option = None
                q_dist_ref = None
                options_rhx = None
                self.rollouts.obs[0].copy_(obs)

                z_latent, z_log_prob, z_dist, ib_rnn_hx = \
                self.actor_critic.encoder_forward(
                    obs=obs,
                    rnn_hxs=ib_rnn_hx,
                    masks=self.rollouts.masks[0],
                    do_z_sampling=True,
                )
                self.rollouts.insert_z_latent(z_latent, z_log_prob, z_dist,
                                              ib_rnn_hx)

                obs = replace_goal_vector_with_z(obs, z_latent)

                return omega_option, obs, q_dist_ref, ib_rnn_hx, \
                    options_rhx, info
        else:
            ib_rnn_hx = None

        if self.args.option_space == 'continuous':
            option_log_probs = None
            if self.args.hier_mode == 'default':
                omega_option, q_dist, _, ldj = self.options_decoder(
                    obs, do_sampling=self.do_sampling)

            elif self.args.hier_mode == 'vic':
                ldj = 0.0
                _shape = (self.num_processes_eff, self.args.omega_option_dims)
                _loc = torch.zeros(*_shape).to(self.device)
                _scale = torch.ones(*_shape).to(self.device)
                # if self.omega_dim_current < self.args.omega_option_dims:
                #     _scale[:, self.omega_dim_current:].fill_(1e-3)
                q_dist = ds.normal.Normal(loc=_loc, scale=_scale)
                if self.do_sampling:
                    omega_option = q_dist.rsample()
                else:
                    omega_option = q_dist.mean

                if self.args.ic_mode == 'diyan':
                    _shape_t = (self.args.num_steps + 1, *_shape)
                    _loc_t = torch.zeros(*_shape_t).to(self.device)
                    _scale_t = torch.ones(*_shape_t).to(self.device)
                    # if self.omega_dim_current < self.args.omega_option_dims:
                    #     _scale_t[:, :, self.omega_dim_current:].fill_(1e-3)
                    q_dist_ref = ds.normal.Normal(loc=_loc_t, scale=_scale_t)
                else:
                    q_dist_ref = q_dist

                if self.args.use_infobot:
                    obs_omega = repalce_mission_with_omega(obs, omega_option)
                    z_latent, z_log_prob, z_dist, ib_rnn_hx = \
                        self.actor_critic.encoder_forward(
                            obs=obs_omega,
                            rnn_hxs=ib_rnn_hx,
                            masks=self.rollouts.masks[0],
                            do_z_sampling=self.do_z_sampling)

            elif self.args.hier_mode == 'transfer':
                # omega_option, q_dist, _, ldj = self.options_policy(
                #     obs, do_sampling=self.do_sampling)
                omega_option = None
                q_dist_ref = None
            else:
                raise ValueError

        else:
            ldj = 0.0
            if self.args.hier_mode == 'default':
                with torch.no_grad():
                    option_discrete, q_dist, option_log_probs = self.options_decoder(
                        obs, do_sampling=self.do_sampling)

                    if self.args.use_infobot:
                        raise NotImplementedError

            elif self.args.hier_mode == 'vic':
                with torch.no_grad():
                    _shape = (self.num_processes_eff,
                              self.args.omega_option_dims)
                    uniform_probs = torch.ones(*_shape).to(self.device)
                    if self.omega_dim_current < self.args.omega_option_dims:
                        uniform_probs[:, self.omega_dim_current:].fill_(0)
                    uniform_probs = uniform_probs / uniform_probs.sum(
                        -1, keepdim=True)
                    q_dist = distributions.FixedCategorical(
                        probs=uniform_probs)
                    option_discrete = q_dist.sample()
                    # option_log_probs = q_dist.log_probs(option_discrete)

                    if self.args.ic_mode == 'diyan':
                        _shape_t = (self.args.num_steps + 1, *_shape)
                        uniform_probs = torch.ones(*_shape_t).to(self.device)
                        if self.omega_dim_current < self.args.omega_option_dims:
                            uniform_probs[:, :,
                                          self.omega_dim_current:].fill_(0)
                        uniform_probs = uniform_probs / uniform_probs.sum(
                            -1, keepdim=True)
                        q_dist_ref = distributions.FixedCategorical(
                            probs=uniform_probs)
                    else:
                        q_dist_ref = q_dist

                    if self.args.use_infobot:
                        omega_one_hot = torch.eye(
                            self.args.omega_option_dims)[option_discrete]
                        omega_one_hot = omega_one_hot.float().to(self.device)
                        obs_omega = repalce_mission_with_omega(
                            obs, omega_one_hot)
                        z_latent, z_log_prob, z_dist, ib_rnn_hx = \
                            self.actor_critic.encoder_forward(
                                obs=obs_omega,
                                rnn_hxs=ib_rnn_hx,
                                masks=self.rollouts.masks[0],
                                do_z_sampling=self.do_z_sampling)

            elif self.args.hier_mode in ['transfer', 'bonus']:
                omega_option = None
                q_dist_ref = None

            else:
                raise ValueError

            if self.args.hier_mode != 'transfer':
                option_np = option_discrete.squeeze(-1).cpu().numpy()
                option_one_hot = np.eye(self.args.omega_option_dims)[option_np]
                omega_option = torch.from_numpy(option_one_hot).float().to(
                    self.device)

        if self.args.hier_mode == 'transfer':
            obs_base = obs
            if self.args.use_infobot:
                raise NotImplementedError
            else:
                pass
        else:
            if self.args.use_infobot:
                obs_base = repalce_omega_with_z(obs, z_latent)
                self.rollouts.insert_option(omega_option)
                self.rollouts.insert_z_latent(z_latent, z_log_prob, z_dist,
                                              ib_rnn_hx)
            else:
                obs_base = repalce_mission_with_omega(obs, omega_option)
                self.rollouts.insert_option(omega_option)

        if self.args.hier_mode == 'transfer':
            options_rhx = torch.zeros(
                self.num_processes_eff,
                self.options_policy.recurrent_hidden_state_size).to(
                    self.device)
        else:
            options_rhx = None

        # self.omega_option = omega_option
        # self.obs_base = obs_base
        self.rollouts.obs[0].copy_(obs)

        return omega_option, obs_base, q_dist_ref, ib_rnn_hx, \
            options_rhx, info
Code example #8
def eval_success_simple(
    num_processes,
    num_steps,
    val_envs,
    actor_critic,
    device,
    num_episodes,
):
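    """Run `num_episodes` episodes (in chunks of `num_processes`) with a
    greedy policy and return per-episode success flags, max room ids and
    returns."""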
    ARGMAX_POLICY = True
    episode_count = 0
    return_list = []
    all_max_room = []
    val_envs.modify_attr('render_rgb', [False] * num_processes)
    val_envs.reset_config_rng()

    while episode_count < num_episodes:
        reward_list = []
        reset_output = val_envs.reset()

        obs = reset_output[:, 0]
        info = reset_output[:, 1]
        obs = dict_stack_helper(obs)
        info = dict_stack_helper(info)
        obs = DictObs({key:torch.from_numpy(obs_i).to(device) \
            for key, obs_i in obs.items()})

        recurrent_hidden_states = torch.zeros(
            num_processes, actor_critic.recurrent_hidden_state_size).to(device)
        masks = torch.ones(num_processes, 1).to(device)

        for step in range(num_steps):
            _, action, _, recurrent_hidden_states = \
                actor_critic.act(
                    inputs=obs,
                    rnn_hxs=recurrent_hidden_states,
                    masks=masks,
                    deterministic=bool(ARGMAX_POLICY))

            cpu_actions = action.view(-1).cpu().numpy()

            obs, reward, _, info = val_envs.step(cpu_actions)
            reward_list.append(reward)

            obs = dict_stack_helper(obs)
            obs = DictObs({key:torch.from_numpy(obs_i).to(device) \
                for key, obs_i in obs.items()})

            done = np.stack([item['done'] for item in info], 0)

            if 'max_room_id' in info[0]:
                max_room = np.stack([item['max_room_id'] for item in info], 0)
            else:
                max_room = np.ones((num_processes)) * -1

        all_max_room.append(max_room)

        episodic_return = np.stack(reward_list, 0).sum(0)
        return_list.append(episodic_return)
        episode_count += num_processes

    all_return = np.concatenate(return_list, 0)
    success = (all_return > 0).astype('float')
    all_max_room = np.concatenate(all_max_room, 0)

    return success, all_max_room, all_return
Code example #9
def eval_ib_kl(args,
               vis_env,
               actor_critic,
               device,
               omega_dim_current,
               num_samples=10):
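    """Enumerate all states of `vis_env` and compute, per state, the KL
    between the infobot encoder's z-distribution and a standard-normal prior
    (per option in the discrete case), plus per-option and marginal action
    distributions and their KLs in the discrete case; returns the grids as
    numpy arrays."""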
    assert args.use_infobot != 0
    action_dims = vis_env.action_space.n
    if hasattr(vis_env.actions, 'forward'):
        action_space_type = 'pov'
    elif hasattr(vis_env.actions, 'up'):
        action_space_type = 'cardinal'

    vis_obs, vis_info = vis_env.reset()
    assert 'rgb_grid' in vis_info
    env_rgb_img = vis_info['rgb_grid'].transpose([2, 0, 1])
    env_rgb_img = np.flip(env_rgb_img, 1)

    all_obs = vis_env.enumerate_states()

    _rhx = torch.zeros(num_samples,
                       actor_critic.recurrent_hidden_state_size).to(device)
    _masks = torch.ones(1, num_samples, 1).to(device)

    def repeat_dict_obs(dict_obs, batch_size):
        out = {}
        for key, value in dict_obs.items():
            out[key] = np.broadcast_to(value[np.newaxis, :],
                                       (batch_size, *value.shape))
        return out

    grid_shape = (vis_env.width, vis_env.height)
    kl_zz_grid = torch.zeros(*grid_shape).to(device)
    kl_zz_opt_grid = [torch.zeros(*grid_shape).to(device) \
        for _ in range(omega_dim_current)]
    kl_pi_def_grid = torch.zeros(*grid_shape).to(device)
    kl_pi_opt_grid = [torch.zeros(*grid_shape).to(device) \
        for _ in range(omega_dim_current)]
    pi_def_grid = torch.zeros((action_dims, *grid_shape)).to(device)
    pi_opt_grid = [torch.zeros((action_dims, *grid_shape)).to(device) \
        for _ in range(omega_dim_current)]

    if args.option_space == 'continuous':
        _shape = (num_samples, args.omega_option_dims)
        _loc = torch.zeros(*_shape).to(device)
        _scale = torch.ones(*_shape).to(device)
        omega_prior = ds.normal.Normal(loc=_loc, scale=_scale)

        _z_shape = (num_samples, args.z_latent_dims)
        _z_loc = torch.zeros(*_z_shape).to(device)
        _z_scale = torch.ones(*_z_shape).to(device)
        z_prior = ds.normal.Normal(loc=_z_loc, scale=_z_scale)

        for key, obs in all_obs.items():
            obs = repeat_dict_obs(obs, num_samples)
            omega_option = omega_prior.rsample()

            obs = DictObs({key:torch.from_numpy(obs_i).to(device) \
                for key, obs_i in obs.items()})

            if 'mission' in obs.keys():
                obs.pop('mission')
            obs.update({'omega': omega_option})

            z_latent, z_log_prob, z_dist, _ = \
                actor_critic.encoder_forward(
                    obs=obs,
                    rnn_hxs=_rhx,
                    masks=_masks,
                    do_z_sampling=True)

            kld_zz = ds.kl.kl_divergence(z_dist, z_prior)
            # kld_zz = kld_zz.view(
            #     num_steps + 1, num_processes, z_latent_dims)
            kld_zz = torch.sum(kld_zz, 1).mean()

            kl_zz_grid[key.x, key.y] = kld_zz

    else:
        # _shape = (omega_dim_current, args.omega_option_dims)
        # uniform_probs = torch.ones(*_shape).to(device)

        _z_shape = (omega_dim_current * num_samples, args.z_latent_dims)
        _z_loc = torch.zeros(*_z_shape).to(device)
        _z_scale = torch.ones(*_z_shape).to(device)
        z_prior = ds.normal.Normal(loc=_z_loc, scale=_z_scale)

        # if omega_dim_current < args.omega_option_dims:
        #     uniform_probs[:, omega_dim_current:].fill_(0)
        # uniform_probs = uniform_probs / uniform_probs.sum(-1, keepdim=True)
        # omega_prior = FixedCategorical(probs=uniform_probs)
        # # option_discrete = omega_prior.sample()

        omega_option = torch.eye(omega_dim_current).to(device)
        if omega_dim_current < args.omega_option_dims:
            _diff = args.omega_option_dims - omega_dim_current
            _pad = omega_option.new_zeros(omega_dim_current, _diff)
            omega_option = torch.cat([omega_option, _pad], 1)
        omega_option = omega_option.unsqueeze(0).repeat(num_samples, 1, 1)
        omega_option = omega_option.view(-1, *omega_option.shape[2:])

        for key, obs in all_obs.items():
            obs = repeat_dict_obs(obs, omega_option.shape[0])
            # omega_option = omega_prior.rsample()

            obs = DictObs({key:torch.from_numpy(obs_i).to(device) \
                for key, obs_i in obs.items()})

            if 'mission' in obs.keys():
                obs.pop('mission')
            obs.update({'omega': omega_option})

            z_latent, z_log_prob, z_dist, _ = \
                actor_critic.encoder_forward(
                    obs=obs,
                    rnn_hxs=_rhx,
                    masks=_masks,
                    do_z_sampling=True)

            kld_zz = ds.kl.kl_divergence(z_dist, z_prior)
            kld_zz = kld_zz.view(num_samples, omega_dim_current,
                                 *kld_zz.shape[1:])
            kld_zz = kld_zz.sum(-1).mean(0)
            for opt_idx in range(omega_dim_current):
                kl_zz_opt_grid[opt_idx][key.x, key.y] = kld_zz[opt_idx]

            obs.pop('omega')
            obs.update({'z_latent': z_latent})

            _, action_dist, _, _ = \
                actor_critic.get_action_dist(
                    inputs=obs, rnn_hxs=_rhx, masks=_masks)

            action_probs = action_dist.probs
            action_probs = action_probs.view(num_samples, omega_dim_current,
                                             *action_probs.shape[1:]).mean(0)

            pi_opt, pi_kl = {}, {}
            for opt_idx in range(omega_dim_current):
                pi_opt[opt_idx] = FixedCategorical(probs=action_probs[opt_idx])
                pi_opt_grid[opt_idx][:, key.x, key.y] = pi_opt[opt_idx].probs
            pi_def = FixedCategorical(probs=action_probs.mean(0))
            pi_def_grid[:, key.x, key.y] = pi_def.probs

            for opt_idx in range(omega_dim_current):
                pi_kl[opt_idx] = ds.kl.kl_divergence(pi_opt[opt_idx], pi_def)
                kl_pi_opt_grid[opt_idx][key.x, key.y] = pi_kl[opt_idx]

            pi_kl_avg = torch.stack(tuple(pi_kl.values()), 0).mean(0)
            kl_pi_def_grid[key.x, key.y] = pi_kl_avg

    pi_opt_grid = torch.stack(pi_opt_grid, 0)
    kl_pi_opt_grid = torch.stack(kl_pi_opt_grid, 0)
    kl_zz_opt_grid = torch.stack(kl_zz_opt_grid, 0)
    if args.option_space != 'continuous':
        # Discrete options: the per-state KL grid is the mean over options.
        kl_zz_grid = kl_zz_opt_grid.mean(0)

    return_dict = {
        'env_rgb_img': env_rgb_img,
        'pi_opt_grid': pi_opt_grid.cpu().numpy().transpose([0, 1, 3, 2]),
        'pi_def_grid': pi_def_grid.cpu().numpy().transpose([0, 2, 1]),
        'kl_zz_grid': kl_zz_grid.cpu().numpy().T,
        'kl_zz_opt_grid': kl_zz_opt_grid.cpu().numpy().transpose([0, 2, 1]),
        'kl_pi_def_grid': kl_pi_def_grid.cpu().numpy().T,
        'kl_pi_opt_grid': kl_pi_opt_grid.cpu().numpy().transpose([0, 2, 1]),
    }
    return return_dict
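
The z KL above uses torch.distributions' analytic KL between Normal distributions, summed over latent dimensions. A minimal standalone check of that call (values are arbitrary):

import torch
import torch.distributions as ds

q = ds.normal.Normal(loc=torch.tensor([0.5, -0.2]), scale=torch.tensor([1.0, 0.5]))
p = ds.normal.Normal(loc=torch.zeros(2), scale=torch.ones(2))
kld = ds.kl.kl_divergence(q, p).sum()  # summed over latent dims, as for kld_zz
print(kld.item())
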
Code example #10
def eval_success(
    args,
    val_envs,
    vis_env,
    actor_critic,
    b_args,
    bonus_type,
    bonus_z_encoder,
    bonus_beta,
    bonus_normalization,
    device,
    num_episodes,
):
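    """Run greedy evaluation episodes while tracking agent positions, visit
    counts and a KL-based bonus for one visualized episode; returns
    per-episode success flags, max room ids and a dict of visualization
    arrays."""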
    ARGMAX_POLICY = True
    episode_count = 0
    return_list = []
    all_max_room = []
    val_envs.modify_attr('render_rgb', [True] * args.num_processes)
    val_envs.reset_config_rng()
    # vis_env.reset_config_rng()
    grid_shape = (vis_env.width, vis_env.height)

    kl_grid = torch.zeros(*grid_shape).to(device)
    bonus_grid = torch.zeros(*grid_shape).to(device)

    while episode_count < num_episodes:
        reward_list = []
        reset_output = val_envs.reset()

        obs = reset_output[:, 0]
        info = reset_output[:, 1]
        obs = dict_stack_helper(obs)
        info = dict_stack_helper(info)
        obs = DictObs({key:torch.from_numpy(obs_i).to(device) \
            for key, obs_i in obs.items()})

        rgb_grid = info['rgb_grid']

        recurrent_hidden_states = torch.zeros(
            args.num_processes,
            actor_critic.recurrent_hidden_state_size).to(device)
        masks = torch.ones(args.num_processes, 1).to(device)
        agent_pos = [val_envs.get_attr('agent_pos')]
        all_masks = [np.array([True] * args.num_processes)]
        all_obs = [obs]
        all_vc = [np.ones(args.num_processes)]

        for step in range(args.num_steps):
            _, action, _, recurrent_hidden_states = \
                actor_critic.act(
                    inputs=obs,
                    rnn_hxs=recurrent_hidden_states,
                    masks=masks,
                    deterministic=bool(ARGMAX_POLICY))

            cpu_actions = action.view(-1).cpu().numpy()

            obs, reward, _, info = val_envs.step(cpu_actions)
            reward_list.append(reward)

            obs = dict_stack_helper(obs)
            obs = DictObs({key:torch.from_numpy(obs_i).to(device) \
                for key, obs_i in obs.items()})
            all_obs.append(obs)

            done = np.stack([item['done'] for item in info], 0)
            curr_pos = np.stack([item['agent_pos'] for item in info], 0)
            # curr_dir = np.stack([item['agent_dir'] for item in info], 0)
            visit_count = np.stack([item['visit_count'] for item in info], 0)
            all_vc.append(visit_count)
            agent_pos.append(curr_pos)
            all_masks.append(done == False)

            if 'max_room_id' in info[0]:
                max_room = np.stack([item['max_room_id'] for item in info], 0)
            else:
                max_room = np.ones((args.num_processes)) * -1

        agent_pos = np.stack(agent_pos, 0)
        all_masks = np.stack(all_masks, 0)
        all_max_room.append(max_room)

        stacked_obs = {}
        for key in all_obs[0].keys():
            stacked_obs[key] = torch.stack([_obs[key] for _obs in all_obs], 0)
        stacked_obs = DictObs(stacked_obs)
        stacked_masks = np.stack(all_masks, 0).astype('float32')
        stacked_masks = torch.from_numpy(stacked_masks).to(device)

        stacked_visit_count = np.stack(all_vc, 0)

        if bonus_type != 'count':
            bonus_kld = bonus_kl_forward(
                bonus_type=bonus_type,
                obs=stacked_obs,
                b_args=b_args,
                bonus_z_encoder=bonus_z_encoder,
                masks=stacked_masks,
                bonus_normalization=bonus_normalization,
            )
        else:
            bonus_kld = stacked_masks.clone() * 0

        episodic_return = np.stack(reward_list, 0).sum(0)
        return_list.append(episodic_return)
        episode_count += args.num_processes

    VIS_COUNT = 1
    VIS_IDX = 0
    agent_pos = agent_pos[:, VIS_IDX]
    episode_length = all_masks[:, VIS_IDX].sum()
    rgb_env_image = rgb_grid[VIS_IDX]
    bonus_kld = bonus_kld[:, VIS_IDX]
    visit_count = stacked_visit_count[:, VIS_IDX]
    rgb_env_image = np.flip(rgb_env_image.transpose([2, 0, 1]), 1)

    vis_info = make_bonus_grid(
        bonus_beta=bonus_beta,
        agent_pos=agent_pos,
        kl_values=bonus_kld.squeeze(-1).cpu().numpy(),
        visit_count=visit_count,
        episode_length=episode_length,
        grid_shape=grid_shape,
    )
    vis_info['rgb_env_image'] = rgb_env_image

    all_return = np.concatenate(return_list, 0)
    success = (all_return > 0).astype('float')
    all_max_room = np.concatenate(all_max_room, 0)

    return success, all_max_room, vis_info