Code Example #1
File: dagger_trainer.py Project: roy860328/VLN-CE
    def train(self) -> None:
        r"""Main method for training DAgger.

        Returns:
            None
        """
        os.makedirs(self.lmdb_features_dir, exist_ok=True)
        os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True)

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
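            # Reuse a previously collected feature database; just verify it opens.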
            try:
                lmdb.open(self.lmdb_features_dir, readonly=True)
            except lmdb.Error as err:
                logger.error(
                    "Cannot open database for teacher forcing preload.")
                raise err
        else:
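            # Otherwise start from an empty database: create it (or open the
            # existing one) and drop all entries before collecting new features.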
            with lmdb.open(self.lmdb_features_dir,
                           map_size=int(self.config.DAGGER.LMDB_MAP_SIZE)
                           ) as lmdb_env, lmdb_env.begin(write=True) as txn:
                txn.drop(lmdb_env.open_db())

        split = self.config.TASK_CONFIG.DATASET.SPLIT
        self.config.defrost()
        self.config.TASK_CONFIG.TASK.NDTW.SPLIT = split
        self.config.TASK_CONFIG.TASK.SDTW.SPLIT = split

        # if doing teacher forcing, don't switch the scene until it is complete
        if self.config.DAGGER.P == 1.0:
            self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = (
                -1)
        self.config.freeze()

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
            # When preloading features, it's quicker to just load one env since we
            # only need the observation space from it.
            single_proc_config = self.config.clone()
            single_proc_config.defrost()
            single_proc_config.NUM_PROCESSES = 1
            single_proc_config.freeze()
            self.envs = construct_envs(single_proc_config,
                                       get_env_class(self.config.ENV_NAME))
        else:
            self.envs = construct_envs(self.config,
                                       get_env_class(self.config.ENV_NAME))

        self._setup_actor_critic_agent(
            self.config.MODEL,
            self.config.DAGGER.LOAD_FROM_CKPT,
            self.config.DAGGER.CKPT_TO_LOAD,
        )

        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.actor_critic.parameters())))
        logger.info("agent number of trainable parameters: {}".format(
            sum(p.numel() for p in self.actor_critic.parameters()
                if p.requires_grad)))

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
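            # The envs were only needed for their observation space; free them
            # before training on the preloaded features.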
            self.envs.close()
            del self.envs
            self.envs = None

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs,
                               purge_step=0) as writer:
            for dagger_it in range(self.config.DAGGER.ITERATIONS):
                step_id = 0
                if not self.config.DAGGER.PRELOAD_LMDB_FEATURES:
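                    # Collect a new round of trajectories into the LMDB feature
                    # store for this iteration (offset by one when resuming from
                    # a checkpoint).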
                    self._update_dataset(dagger_it + (
                        1 if self.config.DAGGER.LOAD_FROM_CKPT else 0))

                if torch.cuda.is_available():
                    with torch.cuda.device(self.device):
                        torch.cuda.empty_cache()
                gc.collect()

                dataset = IWTrajectoryDataset(
                    self.lmdb_features_dir,
                    self.config.DAGGER.USE_IW,
                    inflection_weight_coef=self.config.MODEL.inflection_weight_coef,
                    lmdb_map_size=self.config.DAGGER.LMDB_MAP_SIZE,
                    batch_size=self.config.DAGGER.BATCH_SIZE,
                )

                AuxLosses.activate()
                for epoch in tqdm.trange(self.config.DAGGER.EPOCHS):
                    diter = torch.utils.data.DataLoader(
                        dataset,
                        batch_size=self.config.DAGGER.BATCH_SIZE,
                        shuffle=False,
                        collate_fn=collate_fn,
                        pin_memory=False,
                        drop_last=True,  # drop last batch if smaller
                        num_workers=0,
                    )
                    for batch in tqdm.tqdm(diter,
                                           total=dataset.length //
                                           dataset.batch_size,
                                           leave=False):
                        (
                            observations_batch,
                            prev_actions_batch,
                            not_done_masks,
                            corrected_actions_batch,
                            weights_batch,
                        ) = batch
                        observations_batch = {
                            k: v.to(device=self.device, non_blocking=True)
                            for k, v in observations_batch.items()
                        }
                        try:
                            loss, action_loss, aux_loss = self._update_agent(
                                observations_batch,
                                prev_actions_batch.to(device=self.device,
                                                      non_blocking=True),
                                not_done_masks.to(device=self.device,
                                                  non_blocking=True),
                                corrected_actions_batch.to(device=self.device,
                                                           non_blocking=True),
                                weights_batch.to(device=self.device,
                                                 non_blocking=True),
                            )
                        except Exception:
                            logger.error(
                                "Failed to update agent on the full batch. "
                                "Updating agent with batch size of 1.")
                            loss, action_loss, aux_loss = 0, 0, 0
                            prev_actions_batch = prev_actions_batch.cpu()
                            not_done_masks = not_done_masks.cpu()
                            corrected_actions_batch = corrected_actions_batch.cpu()
                            weights_batch = weights_batch.cpu()
                            observations_batch = {
                                k: v.cpu()
                                for k, v in observations_batch.items()
                            }
                            for i in range(not_done_masks.size(0)):
                                output = self._update_agent(
                                    {
                                        k: v[i].to(device=self.device,
                                                   non_blocking=True)
                                        for k, v in observations_batch.items()
                                    },
                                    prev_actions_batch[i].to(
                                        device=self.device, non_blocking=True),
                                    not_done_masks[i].to(device=self.device,
                                                         non_blocking=True),
                                    corrected_actions_batch[i].to(
                                        device=self.device, non_blocking=True),
                                    weights_batch[i].to(device=self.device,
                                                        non_blocking=True),
                                )
                                loss += output[0]
                                action_loss += output[1]
                                aux_loss += output[2]

                        logger.info(f"train_loss: {loss}")
                        logger.info(f"train_action_loss: {action_loss}")
                        logger.info(f"train_aux_loss: {aux_loss}")
                        logger.info(f"Batches processed: {step_id}.")
                        logger.info(
                            f"On DAgger iter {dagger_it}, Epoch {epoch}.")
                        writer.add_scalar(f"train_loss_iter_{dagger_it}", loss,
                                          step_id)
                        writer.add_scalar(
                            f"train_action_loss_iter_{dagger_it}", action_loss,
                            step_id)
                        writer.add_scalar(f"train_aux_loss_iter_{dagger_it}",
                                          aux_loss, step_id)
                        step_id += 1

                    self.save_checkpoint(
                        f"ckpt.{dagger_it * self.config.DAGGER.EPOCHS + epoch}.pth"
                    )
                AuxLosses.deactivate()
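
For context, a trainer like this is usually not instantiated by hand: in habitat-baselines-style projects it is registered under a name and driven by an experiment config. The sketch below shows one plausible way to launch it; the import of get_config, the config path, and the TRAINER_NAME key are assumptions rather than details taken from the example above.

# Hypothetical launcher sketch; module paths and the config file are assumptions.
from habitat_baselines.common.baseline_registry import baseline_registry

from vlnce_baselines.config.default import get_config  # assumed config helper


def main():
    # Load an experiment config defining DAGGER.*, MODEL.*, CHECKPOINT_FOLDER, etc.
    config = get_config("vlnce_baselines/config/dagger.yaml")  # hypothetical path
    # Look up the registered trainer class by name and run DAgger training.
    trainer_cls = baseline_registry.get_trainer(config.TRAINER_NAME)
    trainer = trainer_cls(config)
    trainer.train()


if __name__ == "__main__":
    main()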
Code Example #2
    config.MODEL.TORCH_GPU_ID = config.TORCH_GPU_ID
    config.freeze()

    # VLN-CE uses a 4-way discrete action space (STOP, MOVE_FORWARD, TURN_LEFT, TURN_RIGHT).
    action_space = spaces.Discrete(4)

    policy = CMAPolicy(observation_space, action_space, config.MODEL).to(device)

    dummy_instruction = torch.randint(1, 4, size=(4 * 2, 8), device=device)
    dummy_instruction[:, 5:] = 0
    dummy_instruction[0, 2:] = 0
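    # Zeroed trailing tokens mimic padded, variable-length instructions
    # (token index 0 is assumed to be the padding id).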

    obs = dict(
        rgb=torch.randn(4 * 2, 224, 224, 3, device=device),
        depth=torch.randn(4 * 2, 256, 256, 1, device=device),
        instruction=dummy_instruction,
        progress=torch.randn(4 * 2, 1, device=device),
    )

    hidden_states = torch.randn(
        policy.net.state_encoder.num_recurrent_layers,
        2,
        policy.net._hidden_size,
        device=device,
    )
    prev_actions = torch.randint(0, 3, size=(4 * 2, 1), device=device)
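    # All-ones masks indicate no episode has just reset, so the recurrent state
    # is carried over at every step.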
    masks = torch.ones(4 * 2, 1, device=device)

    AuxLosses.activate()

    policy.evaluate_actions(obs, hidden_states, prev_actions, masks, prev_actions)
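
Example #1 pairs AuxLosses.activate() with a matching AuxLosses.deactivate() once training ends; a test snippet like this would normally do the same after the forward pass. A minimal follow-up sketch, using only the AuxLosses API already shown above:

    # Clear the auxiliary-loss registry again, mirroring the pattern in Example #1.
    AuxLosses.deactivate()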