Пример #1
0
    def _get_gt_map(self, full_map_size):
        """Compute the ground-truth map for the current scene, aligned to the agent.

        The map produced by ``HabitatMaps`` (in simulator coordinates) has its
        positive cells set to 1. It is centred in a square grid ``scale`` (2x)
        times its longest side, then translated and rotated by the agent's
        current pose (``get_grid`` + two ``F.grid_sample`` passes). Finally it
        is either centred in (padded) or centre-cropped to a
        ``full_map_size`` x ``full_map_size`` map.

        Side effects: sets ``self.scene_name`` and ``self.map_obj``.

        Args:
            full_map_size: side length, in cells, of the returned square map.

        Returns:
            A float32 numpy array of shape ``(full_map_size, full_map_size)``
            in which every cell that is > 0 after resampling is set to 1, or
            ``None`` if ``HabitatMaps`` produced an empty map.
        """
        self.scene_name = self.habitat_env.sim.config.SCENE
        # Progress message, not a failure: log at INFO rather than ERROR.
        logger.info('Computing map for %s', self.scene_name)

        # Get map in habitat simulator coordinates
        self.map_obj = HabitatMaps(self.habitat_env)
        if self.map_obj.size[0] < 1 or self.map_obj.size[1] < 1:
            logger.error("Invalid map: %s/%s", self.scene_name, self.episode_no)
            return None

        # Agent height scaled by 100 (presumably metres -> the map object's
        # units, cf. the "/ 100." conversions below) -- TODO confirm.
        # NOTE(review): reads self._env while the lines above use
        # self.habitat_env; confirm both refer to the same environment.
        agent_y = self._env.sim.get_agent_state().position.tolist()[1] * 100.
        sim_map = self.map_obj.get_map(agent_y, -50., 50.0)
        sim_map[sim_map > 0] = 1.

        # Agent pose relative to the map origin (origin/max divided by 100).
        min_x, min_y = self.map_obj.origin / 100.0
        x, y, o = self.get_sim_location()
        x, y = -x - min_x, -y - min_y
        range_x, range_y = self.map_obj.max / 100. - self.map_obj.origin / 100.

        # Centre the map in a square grid `scale` times its longest side
        # (presumably so the rotation below does not clip it).
        map_size = sim_map.shape
        scale = 2.
        grid_size = int(scale * max(map_size))
        grid_map = np.zeros((grid_size, grid_size))
        row0 = (grid_size - map_size[0]) // 2
        col0 = (grid_size - map_size[1]) // 2
        grid_map[row0:row0 + map_size[0], col0:col0 + map_size[1]] = sim_map

        # Pose as (tx, ty, theta_deg) for get_grid. Translations are
        # normalised by the padded grid extent; one component is additionally
        # scaled by the map's aspect ratio (shorter side / longer side).
        st_x = (x - range_x / 2.) * 2. / (range_x * scale)
        st_y = (y - range_y / 2.) * 2. / (range_y * scale)
        if map_size[0] > map_size[1]:
            st_x = st_x * map_size[1] * 1. / map_size[0]
        else:
            st_y = st_y * map_size[0] * 1. / map_size[1]
        st = torch.tensor([[st_x, st_y, 180.0 + np.rad2deg(o)]])

        rot_mat, trans_mat = get_grid(st, (1, 1, grid_size, grid_size),
                                      torch.device("cpu"))

        grid_map = torch.from_numpy(grid_map).float()
        grid_map = grid_map.unsqueeze(0).unsqueeze(0)
        translated = F.grid_sample(grid_map, trans_mat)
        rotated = F.grid_sample(translated, rot_mat)

        # Pad (centred) or centre-crop the aligned grid to full_map_size.
        if full_map_size > grid_size:
            start = (full_map_size - grid_size) // 2
            episode_map = torch.zeros((full_map_size, full_map_size)).float()
            episode_map[start:start + grid_size,
                        start:start + grid_size] = rotated[0, 0]
        else:
            start = (grid_size - full_map_size) // 2
            episode_map = rotated[0, 0, start:start + full_map_size,
                                  start:start + full_map_size]

        episode_map = episode_map.numpy()
        episode_map[episode_map > 0] = 1.

        return episode_map
Пример #2
0
    def train(self) -> None:
        r"""Main method for training DAgger.

        Steps:
          1. Create the LMDB feature and checkpoint directories. With
             ``DAGGER.PRELOAD_LMDB_FEATURES`` the existing feature database
             must be openable read-only. Otherwise its default database is
             cleared with ``txn.drop`` before collection.
          2. Copy ``TASK_CONFIG.DATASET.SPLIT`` into the NDTW/SDTW task
             configs. Set ``MAX_SCENE_REPEAT_STEPS = -1`` when
             ``DAGGER.P == 1.0`` (teacher forcing) so the scene is not
             switched before it is complete.
          3. Build the environments and the actor-critic agent. When
             preloading, use a single process, since only the observation
             space is needed, and close the environments right after.
          4. For each of ``DAGGER.ITERATIONS`` iterations:
             - optionally collect new data with ``_update_dataset``;
             - train for ``DAGGER.EPOCHS`` epochs over the LMDB dataset;
             - log losses to TensorBoard;
             - save a checkpoint after every epoch.

        Returns:
            None

        Raises:
            lmdb.Error: if preloading is enabled and the feature database
                cannot be opened.
        """
        os.makedirs(self.lmdb_features_dir, exist_ok=True)
        os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True)

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
            # Only verify the pre-collected database can be opened; close the
            # handle right away instead of leaking it for the whole run.
            try:
                with lmdb.open(self.lmdb_features_dir, readonly=True):
                    pass
            except lmdb.Error:
                logger.error(
                    "Cannot open database for teacher forcing preload.")
                raise
        else:
            with lmdb.open(self.lmdb_features_dir,
                           map_size=int(self.config.DAGGER.LMDB_MAP_SIZE)
                           ) as lmdb_env, lmdb_env.begin(write=True) as txn:
                txn.drop(lmdb_env.open_db())

        split = self.config.TASK_CONFIG.DATASET.SPLIT
        self.config.defrost()
        self.config.TASK_CONFIG.TASK.NDTW.SPLIT = split
        self.config.TASK_CONFIG.TASK.SDTW.SPLIT = split

        # if doing teacher forcing, don't switch the scene until it is complete
        if self.config.DAGGER.P == 1.0:
            self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = (
                -1)
        self.config.freeze()

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
            # When preloading features it is quicker to load a single env:
            # only its observation space is needed.
            single_proc_config = self.config.clone()
            single_proc_config.defrost()
            single_proc_config.NUM_PROCESSES = 1
            single_proc_config.freeze()
            self.envs = construct_envs(single_proc_config,
                                       get_env_class(self.config.ENV_NAME))
        else:
            self.envs = construct_envs(self.config,
                                       get_env_class(self.config.ENV_NAME))

        self._setup_actor_critic_agent(
            self.config.MODEL,
            self.config.DAGGER.LOAD_FROM_CKPT,
            self.config.DAGGER.CKPT_TO_LOAD,
        )

        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.actor_critic.parameters())))
        logger.info("agent number of trainable parameters: {}".format(
            sum(p.numel() for p in self.actor_critic.parameters()
                if p.requires_grad)))

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
            self.envs.close()
            del self.envs
            self.envs = None

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs,
                               purge_step=0) as writer:
            for dagger_it in range(self.config.DAGGER.ITERATIONS):
                step_id = 0
                if not self.config.DAGGER.PRELOAD_LMDB_FEATURES:
                    # Iteration index passed on is offset by one when the
                    # agent was loaded from a checkpoint.
                    self._update_dataset(dagger_it + (
                        1 if self.config.DAGGER.LOAD_FROM_CKPT else 0))

                if torch.cuda.is_available():
                    with torch.cuda.device(self.device):
                        torch.cuda.empty_cache()
                gc.collect()

                dataset = IWTrajectoryDataset(
                    self.lmdb_features_dir,
                    self.config.DAGGER.USE_IW,
                    inflection_weight_coef=self.config.MODEL.
                    inflection_weight_coef,
                    lmdb_map_size=self.config.DAGGER.LMDB_MAP_SIZE,
                    batch_size=self.config.DAGGER.BATCH_SIZE,
                )

                AuxLosses.activate()
                try:
                    for epoch in tqdm.trange(self.config.DAGGER.EPOCHS):
                        diter = torch.utils.data.DataLoader(
                            dataset,
                            batch_size=self.config.DAGGER.BATCH_SIZE,
                            shuffle=False,
                            collate_fn=collate_fn,
                            pin_memory=False,
                            drop_last=True,  # drop last batch if smaller
                            num_workers=0,
                        )
                        for batch in tqdm.tqdm(diter,
                                               total=dataset.length //
                                               dataset.batch_size,
                                               leave=False):
                            (
                                observations_batch,
                                prev_actions_batch,
                                not_done_masks,
                                corrected_actions_batch,
                                weights_batch,
                            ) = batch
                            observations_batch = {
                                k: v.to(device=self.device, non_blocking=True)
                                for k, v in observations_batch.items()
                            }
                            try:
                                loss, action_loss, aux_loss = self._update_agent(
                                    observations_batch,
                                    prev_actions_batch.to(device=self.device,
                                                          non_blocking=True),
                                    not_done_masks.to(device=self.device,
                                                      non_blocking=True),
                                    corrected_actions_batch.to(
                                        device=self.device, non_blocking=True),
                                    weights_batch.to(device=self.device,
                                                     non_blocking=True),
                                )
                            except Exception:
                                # Was a bare `except:` that also swallowed
                                # KeyboardInterrupt/SystemExit and hid the cause.
                                logger.warning(
                                    "ERROR: failed to update agent. Updating agent with batch size of 1.",
                                    exc_info=True,
                                )
                                loss, action_loss, aux_loss = (
                                    self._update_agent_per_sample(
                                        observations_batch,
                                        prev_actions_batch,
                                        not_done_masks,
                                        corrected_actions_batch,
                                        weights_batch,
                                    ))

                            logger.info(f"train_loss: {loss}")
                            logger.info(f"train_action_loss: {action_loss}")
                            logger.info(f"train_aux_loss: {aux_loss}")
                            logger.info(f"Batches processed: {step_id}.")
                            logger.info(
                                f"On DAgger iter {dagger_it}, Epoch {epoch}.")
                            writer.add_scalar(f"train_loss_iter_{dagger_it}",
                                              loss, step_id)
                            writer.add_scalar(
                                f"train_action_loss_iter_{dagger_it}",
                                action_loss, step_id)
                            writer.add_scalar(
                                f"train_aux_loss_iter_{dagger_it}", aux_loss,
                                step_id)
                            step_id += 1

                        self.save_checkpoint(
                            f"ckpt.{dagger_it * self.config.DAGGER.EPOCHS + epoch}.pth"
                        )
                finally:
                    # Always pair activate() with deactivate(), even if
                    # training raises.
                    AuxLosses.deactivate()

    def _update_agent_per_sample(self, observations_batch, prev_actions_batch,
                                 not_done_masks, corrected_actions_batch,
                                 weights_batch):
        """Fallback used by ``train``: call ``_update_agent`` once per index of dim 0.

        Used after a full-batch update raised. The failure is presumably a
        CUDA out-of-memory error, but that is not checked. All tensors are
        first moved back to CPU, and each indexed slice is then moved to
        ``self.device`` on its own.

        NOTE(review): ``v[i]`` drops the leading dimension; confirm
        ``_update_agent`` accepts un-batched inputs.

        Returns:
            Tuple ``(loss, action_loss, aux_loss)`` holding the sums of the
            per-index outputs of ``_update_agent``.
        """
        prev_actions_batch = prev_actions_batch.cpu()
        not_done_masks = not_done_masks.cpu()
        corrected_actions_batch = corrected_actions_batch.cpu()
        weights_batch = weights_batch.cpu()
        observations_batch = {
            k: v.cpu()
            for k, v in observations_batch.items()
        }

        loss, action_loss, aux_loss = 0, 0, 0
        for i in range(not_done_masks.size(0)):
            output = self._update_agent(
                {
                    k: v[i].to(device=self.device, non_blocking=True)
                    for k, v in observations_batch.items()
                },
                prev_actions_batch[i].to(device=self.device,
                                         non_blocking=True),
                not_done_masks[i].to(device=self.device, non_blocking=True),
                corrected_actions_batch[i].to(device=self.device,
                                              non_blocking=True),
                weights_batch[i].to(device=self.device, non_blocking=True),
            )
            loss += output[0]
            action_loss += output[1]
            aux_loss += output[2]
        return loss, action_loss, aux_loss
    def __init__(self, config, mode="train"):
        """Open the LMDB frame cache for one split, building it first if missing.

        If ``cache_exists()`` is false:
          - a ``habitat.Env`` is created from ``config.TASK_CONFIG``;
          - its episodes are grouped by scene;
          - each scene is loaded with ``load_scene``;
          - for each of the scene's episodes, 9 positions sampled from
            ``episode.shortest_paths[0]`` are passed to ``save_frames``.

        Episodes without ``shortest_paths`` are logged and skipped. If the
        cache already exists, it is opened read-only.

        In both cases ``self.dataset_length`` is set to the number of LMDB
        entries divided by 3 (presumably one rgb/seg/depth triple per frame).
        The LMDB environment is then closed and ``self.lmdb_env`` is reset to
        ``None``.

        Args:
            config: Config. ``TASK_CONFIG`` and ``DATASET_PATH`` are read from
                it, and ``DATASET_PATH`` is formatted with ``split=mode``.
            mode: 'train'/'val'
        """
        self.config = config.TASK_CONFIG
        self.dataset_path = config.DATASET_PATH.format(split=mode)

        if not self.cache_exists():
            # For each scene: load the scene in memory, then save frames for
            # every episode belonging to that scene.
            self.env = habitat.Env(config=self.config)
            self.episodes = self.env._dataset.episodes

            logger.info(
                "Dataset cache not found. Saving rgb, seg, depth scene images")
            logger.info("Number of {} episodes: {}".format(
                mode, len(self.episodes)))

            # Scene ids in first-seen order, and the episodes of each scene.
            # Membership is checked on the dict (O(1)) rather than the list.
            self.scene_ids = []
            self.scene_episode_dict = {}
            for episode in self.episodes:
                if episode.scene_id not in self.scene_episode_dict:
                    self.scene_ids.append(episode.scene_id)
                    self.scene_episode_dict[episode.scene_id] = []
                self.scene_episode_dict[episode.scene_id].append(episode)

            self.lmdb_env = lmdb.open(
                self.dataset_path,
                map_size=int(1e11),
                writemap=True,
            )

            self.count = 0

            for scene in tqdm(self.scene_ids):
                self.load_scene(scene)
                for episode in tqdm(self.scene_episode_dict[scene]):
                    try:
                        # TODO: Consider alternative for shortest_paths
                        pos_queue = episode.shortest_paths[0]
                    except AttributeError as e:
                        logger.error(e)
                        # Skip the episode: falling through would use an
                        # unbound or stale `pos_queue` from a prior episode.
                        continue

                    random_pos = random.sample(pos_queue, 9)
                    self.save_frames(random_pos)

            logger.info("EQA-CNN-PRETRAIN database ready!")
            self.env.close()

        else:
            logger.info("Dataset cache found.")
            self.lmdb_env = lmdb.open(
                self.dataset_path,
                readonly=True,
                lock=False,
            )

        # Use the read transaction as a context manager so it is released.
        with self.lmdb_env.begin() as txn:
            self.dataset_length = txn.stat()["entries"] // 3
        self.lmdb_env.close()
        self.lmdb_env = None