Code example #1
    def forward(self, observations, rnn_hidden_states, prev_actions, masks):
        r"""
        instruction_embedding: [batch_size x INSTRUCTION_ENCODER.output_size]
        depth_embedding: [batch_size x DEPTH_ENCODER.output_size]
        rgb_embedding: [batch_size x RGB_ENCODER.output_size]
        """
        ### instruction
        # use a BERT-based embedding in place of the default instruction encoder
        # (previously: instruction_embedding = self.instruction_encoder(observations))
        instruction_embedding = self._get_bert_embedding(observations)

        depth_embedding = self.depth_encoder(observations)
        rgb_embedding = self.rgb_encoder(observations)
        # print("depth_embedding: ", depth_embedding)
        # print("depth_embedding: ", depth_embedding.size())

        if self.model_config.ablate_instruction:
            instruction_embedding = instruction_embedding * 0
        if self.model_config.ablate_depth:
            depth_embedding = depth_embedding * 0
        if self.model_config.ablate_rgb:
            rgb_embedding = rgb_embedding * 0

        x = torch.cat([instruction_embedding, depth_embedding, rgb_embedding], dim=1)

        if self.model_config.SEQ2SEQ.use_prev_action:
            prev_actions_embedding = self.prev_action_embedding(
                ((prev_actions.float() + 1) * masks).long().view(-1)
            )
            x = torch.cat([x, prev_actions_embedding], dim=1)

        x, rnn_hidden_states = self.state_encoder(x, rnn_hidden_states, masks)

        if self.model_config.PROGRESS_MONITOR.use and AuxLosses.is_active():
            progress_hat = torch.tanh(self.progress_monitor(x))
            progress_loss = F.mse_loss(
                progress_hat.squeeze(1), observations["progress"], reduction="none"
            )
            AuxLosses.register_loss(
                "progress_monitor",
                progress_loss,
                self.model_config.PROGRESS_MONITOR.alpha,
            )

        return x, rnn_hidden_states
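
The forward pass above registers an auxiliary progress-monitor loss through a global AuxLosses registry, which the trainer later reduces (see code example #2). The sketch below is a minimal stand-in written for illustration only; the class name and internals are assumptions, not the project's actual implementation, but the calls used in these listings (activate, deactivate, is_active, clear, register_loss, reduce) can be read against it.

import torch


class AuxLossesSketch:
    """Minimal illustrative stand-in for the AuxLosses registry used in these examples."""

    _losses = {}
    _alphas = {}
    _active = False

    @classmethod
    def activate(cls):
        cls._active = True

    @classmethod
    def deactivate(cls):
        cls._active = False

    @classmethod
    def is_active(cls):
        return cls._active

    @classmethod
    def clear(cls):
        cls._losses.clear()
        cls._alphas.clear()

    @classmethod
    def register_loss(cls, name, loss, alpha=1.0):
        # keep the unreduced per-step loss and its weighting coefficient
        cls._losses[name] = loss
        cls._alphas[name] = alpha

    @classmethod
    def reduce(cls, mask):
        # mean of each registered loss over the valid entries, scaled by its alpha
        total = torch.zeros(())
        for name, loss in cls._losses.items():
            valid = loss.view(-1)[mask.view(-1)]
            total = total + cls._alphas[name] * valid.mean()
        return total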
Code example #2
    def _update_agent(
        self, observations, prev_actions, not_done_masks, corrected_actions, weights
    ):
        T, N = corrected_actions.size()
        self.optimizer.zero_grad()

        recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            N,
            self.config.MODEL.STATE_ENCODER.hidden_size,
            device=self.device,
        )

        AuxLosses.clear()

        distribution = self.actor_critic.build_distribution(
            observations, recurrent_hidden_states, prev_actions, not_done_masks
        )

        logits = distribution.logits
        logits = logits.view(T, N, -1)

        action_loss = F.cross_entropy(
            logits.permute(0, 2, 1), corrected_actions, reduction="none"
        )
        action_loss = ((weights * action_loss).sum(0) / weights.sum(0)).mean()

        aux_mask = (weights > 0).view(-1)
        aux_loss = AuxLosses.reduce(aux_mask)

        loss = action_loss + aux_loss
        loss.backward()

        self.optimizer.step()

        if isinstance(aux_loss, torch.Tensor):
            return loss.item(), action_loss.item(), aux_loss.item()
        else:
            return loss.item(), action_loss.item(), aux_loss
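
The action loss in _update_agent is a time-major cross entropy combined with per-step inflection weights. The following self-contained sketch reproduces the same shape handling on dummy tensors (all sizes here are arbitrary and only for illustration):

import torch
import torch.nn.functional as F

T, N, num_actions = 5, 3, 4            # time steps, batch size, action classes

logits = torch.randn(T, N, num_actions)
corrected_actions = torch.randint(0, num_actions, (T, N))
weights = torch.ones(T, N)             # inflection weights, larger where the action changes

# cross_entropy expects the class dimension second, hence the permute to (T, C, N);
# with reduction="none" the per-step losses keep the (T, N) layout
per_step = F.cross_entropy(
    logits.permute(0, 2, 1), corrected_actions, reduction="none"
)

# weighted mean over time for each trajectory, then mean over the batch
action_loss = ((weights * per_step).sum(0) / weights.sum(0)).mean()
print(per_step.shape, action_loss.item())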
Code example #3
File: dagger_trainer.py  Project: roy860328/VLN-CE
    def train(self) -> None:
        r"""Main method for training DAgger.

        Returns:
            None
        """
        os.makedirs(self.lmdb_features_dir, exist_ok=True)
        os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True)

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
            try:
                lmdb.open(self.lmdb_features_dir, readonly=True)
            except lmdb.Error as err:
                logger.error(
                    "Cannot open database for teacher forcing preload.")
                raise err
        else:
            with lmdb.open(self.lmdb_features_dir,
                           map_size=int(self.config.DAGGER.LMDB_MAP_SIZE)
                           ) as lmdb_env, lmdb_env.begin(write=True) as txn:
                txn.drop(lmdb_env.open_db())

        split = self.config.TASK_CONFIG.DATASET.SPLIT
        self.config.defrost()
        self.config.TASK_CONFIG.TASK.NDTW.SPLIT = split
        self.config.TASK_CONFIG.TASK.SDTW.SPLIT = split

        # if doing teacher forcing, don't switch the scene until it is complete
        if self.config.DAGGER.P == 1.0:
            self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
        self.config.freeze()

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
            # when preloading features, it's quicker to load a single env since we only
            # need its observation space
            single_proc_config = self.config.clone()
            single_proc_config.defrost()
            single_proc_config.NUM_PROCESSES = 1
            single_proc_config.freeze()
            self.envs = construct_envs(single_proc_config,
                                       get_env_class(self.config.ENV_NAME))
        else:
            self.envs = construct_envs(self.config,
                                       get_env_class(self.config.ENV_NAME))

        self._setup_actor_critic_agent(
            self.config.MODEL,
            self.config.DAGGER.LOAD_FROM_CKPT,
            self.config.DAGGER.CKPT_TO_LOAD,
        )

        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.actor_critic.parameters())))
        logger.info("agent number of trainable parameters: {}".format(
            sum(p.numel() for p in self.actor_critic.parameters()
                if p.requires_grad)))

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
            self.envs.close()
            del self.envs
            self.envs = None

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs,
                               purge_step=0) as writer:
            for dagger_it in range(self.config.DAGGER.ITERATIONS):
                step_id = 0
                if not self.config.DAGGER.PRELOAD_LMDB_FEATURES:
                    self._update_dataset(dagger_it + (
                        1 if self.config.DAGGER.LOAD_FROM_CKPT else 0))

                if torch.cuda.is_available():
                    with torch.cuda.device(self.device):
                        torch.cuda.empty_cache()
                gc.collect()

                dataset = IWTrajectoryDataset(
                    self.lmdb_features_dir,
                    self.config.DAGGER.USE_IW,
                    inflection_weight_coef=self.config.MODEL.inflection_weight_coef,
                    lmdb_map_size=self.config.DAGGER.LMDB_MAP_SIZE,
                    batch_size=self.config.DAGGER.BATCH_SIZE,
                )

                AuxLosses.activate()
                for epoch in tqdm.trange(self.config.DAGGER.EPOCHS):
                    diter = torch.utils.data.DataLoader(
                        dataset,
                        batch_size=self.config.DAGGER.BATCH_SIZE,
                        shuffle=False,
                        collate_fn=collate_fn,
                        pin_memory=False,
                        drop_last=True,  # drop last batch if smaller
                        num_workers=0,
                    )
                    for batch in tqdm.tqdm(diter,
                                           total=dataset.length //
                                           dataset.batch_size,
                                           leave=False):
                        (
                            observations_batch,
                            prev_actions_batch,
                            not_done_masks,
                            corrected_actions_batch,
                            weights_batch,
                        ) = batch
                        observations_batch = {
                            k: v.to(device=self.device, non_blocking=True)
                            for k, v in observations_batch.items()
                        }
                        try:
                            loss, action_loss, aux_loss = self._update_agent(
                                observations_batch,
                                prev_actions_batch.to(device=self.device,
                                                      non_blocking=True),
                                not_done_masks.to(device=self.device,
                                                  non_blocking=True),
                                corrected_actions_batch.to(device=self.device,
                                                           non_blocking=True),
                                weights_batch.to(device=self.device,
                                                 non_blocking=True),
                            )
                        except Exception:
                            # fall back to per-sample updates when the full batch fails
                            logger.info(
                                "ERROR: failed to update agent. Updating agent with batch size of 1."
                            )
                            loss, action_loss, aux_loss = 0, 0, 0
                            prev_actions_batch = prev_actions_batch.cpu()
                            not_done_masks = not_done_masks.cpu()
                            corrected_actions_batch = corrected_actions_batch.cpu()
                            weights_batch = weights_batch.cpu()
                            observations_batch = {
                                k: v.cpu()
                                for k, v in observations_batch.items()
                            }
                            for i in range(not_done_masks.size(0)):
                                output = self._update_agent(
                                    {
                                        k: v[i].to(device=self.device,
                                                   non_blocking=True)
                                        for k, v in observations_batch.items()
                                    },
                                    prev_actions_batch[i].to(
                                        device=self.device, non_blocking=True),
                                    not_done_masks[i].to(device=self.device,
                                                         non_blocking=True),
                                    corrected_actions_batch[i].to(
                                        device=self.device, non_blocking=True),
                                    weights_batch[i].to(device=self.device,
                                                        non_blocking=True),
                                )
                                loss += output[0]
                                action_loss += output[1]
                                aux_loss += output[2]

                        logger.info(f"train_loss: {loss}")
                        logger.info(f"train_action_loss: {action_loss}")
                        logger.info(f"train_aux_loss: {aux_loss}")
                        logger.info(f"Batches processed: {step_id}.")
                        logger.info(
                            f"On DAgger iter {dagger_it}, Epoch {epoch}.")
                        writer.add_scalar(f"train_loss_iter_{dagger_it}", loss,
                                          step_id)
                        writer.add_scalar(
                            f"train_action_loss_iter_{dagger_it}", action_loss,
                            step_id)
                        writer.add_scalar(f"train_aux_loss_iter_{dagger_it}",
                                          aux_loss, step_id)
                        step_id += 1

                    self.save_checkpoint(
                        f"ckpt.{dagger_it * self.config.DAGGER.EPOCHS + epoch}.pth"
                    )
                AuxLosses.deactivate()
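
IWTrajectoryDataset supplies the inflection weights (weights_batch) consumed by _update_agent in code example #2. The helper below is a plausible sketch of how such per-step weights can be derived from a ground-truth action sequence; the function name, the treatment of the first step, and the coefficient value are assumptions for illustration and may differ from the project's actual logic.

import torch


def inflection_weights(actions: torch.Tensor, coef: float) -> torch.Tensor:
    """Upweight timesteps where the expert action differs from the previous one.

    actions: (T,) integer action ids for one trajectory. Illustrative only.
    """
    weights = torch.ones_like(actions, dtype=torch.float)
    changed = actions[1:] != actions[:-1]   # inflection points
    weights[1:][changed] = coef
    weights[0] = coef                       # this sketch also treats the first step as an inflection
    return weights


print(inflection_weights(torch.tensor([0, 0, 1, 1, 2]), coef=3.0))
# tensor([3., 1., 3., 1., 3.])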
Code example #4
    def forward(self, observations, rnn_hidden_states, prev_actions, masks):
        r"""
        instruction_embedding: [batch_size x INSTRUCTION_ENCODER.output_size]
        depth_embedding: [batch_size x DEPTH_ENCODER.output_size]
        rgb_embedding: [batch_size x RGB_ENCODER.output_size]
        """
        instruction_embedding = self.instruction_encoder(observations)
        depth_embedding = self.depth_encoder(observations)
        depth_embedding = torch.flatten(depth_embedding, 2)

        rgb_embedding = self.rgb_encoder(observations)
        rgb_embedding = torch.flatten(rgb_embedding, 2)

        prev_actions = self.prev_action_embedding(
            ((prev_actions.float() + 1) * masks).long().view(-1)
        )

        if self.model_config.ablate_instruction:
            instruction_embedding = instruction_embedding * 0
        if self.model_config.ablate_depth:
            depth_embedding = depth_embedding * 0
        if self.model_config.ablate_rgb:
            rgb_embedding = rgb_embedding * 0

        if self.rcm_state_encoder:
            (
                state,
                rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers],
            ) = self.state_encoder(
                rgb_embedding,
                depth_embedding,
                prev_actions,
                rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers],
                masks,
            )
        else:
            rgb_in = self.rgb_linear(rgb_embedding)
            depth_in = self.depth_linear(depth_embedding)

            state_in = torch.cat([rgb_in, depth_in, prev_actions], dim=1)
            (
                state,
                rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers],
            ) = self.state_encoder(
                state_in,
                rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers],
                masks,
            )

        text_state_q = self.state_q(state)
        text_state_k = self.text_k(instruction_embedding)
        text_mask = (instruction_embedding == 0.0).all(dim=1)
        text_embedding = self._attn(
            text_state_q, text_state_k, instruction_embedding, text_mask
        )

        rgb_k, rgb_v = torch.split(
            self.rgb_kv(rgb_embedding), self._hidden_size // 2, dim=1
        )
        depth_k, depth_v = torch.split(
            self.depth_kv(depth_embedding), self._hidden_size // 2, dim=1
        )

        text_q = self.text_q(text_embedding)
        rgb_embedding = self._attn(text_q, rgb_k, rgb_v)
        depth_embedding = self._attn(text_q, depth_k, depth_v)

        x = torch.cat(
            [state, text_embedding, rgb_embedding, depth_embedding, prev_actions], dim=1
        )
        x = self.second_state_compress(x)
        (
            x,
            rnn_hidden_states[self.state_encoder.num_recurrent_layers :],
        ) = self.second_state_encoder(
            x, rnn_hidden_states[self.state_encoder.num_recurrent_layers :], masks
        )

        if self.model_config.PROGRESS_MONITOR.use and AuxLosses.is_active():
            progress_hat = torch.tanh(self.progress_monitor(x))
            progress_loss = F.mse_loss(
                progress_hat.squeeze(1), observations["progress"], reduction="none"
            )
            AuxLosses.register_loss(
                "progress_monitor",
                progress_loss,
                self.model_config.PROGRESS_MONITOR.alpha,
            )

        return x, rnn_hidden_states
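
The _attn helper itself is not part of this listing. Given the shapes above (queries of size [batch x hidden], keys and values flattened to [batch x dim x locations], an optional boolean mask per location), one plausible reading is a single-query scaled dot-product attention along the lines below; the function name and every detail of the implementation are hypothetical, not the project's code.

import torch


def attn_sketch(q, k, v, mask=None):
    """Single-query attention over L positions (illustrative stand-in for _attn).

    q: (B, D)      one query vector per batch element
    k: (B, D, L)   keys over L text/spatial positions
    v: (B, Dv, L)  values over the same positions
    mask: (B, L)   optional boolean mask, True at positions to ignore
    """
    logits = torch.einsum("bd,bdl->bl", q, k) / (k.size(1) ** 0.5)
    if mask is not None:
        logits = logits.masked_fill(mask, float("-inf"))
    weights = torch.softmax(logits, dim=1)
    # attention-weighted sum of the values -> (B, Dv)
    return torch.einsum("bl,bvl->bv", weights, v)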
Code example #5
    config.MODEL.TORCH_GPU_ID = config.TORCH_GPU_ID
    config.freeze()

    action_space = spaces.Discrete(4)

    policy = CMAPolicy(observation_space, action_space, config.MODEL).to(device)

    dummy_instruction = torch.randint(1, 4, size=(4 * 2, 8), device=device)
    dummy_instruction[:, 5:] = 0
    dummy_instruction[0, 2:] = 0

    obs = dict(
        rgb=torch.randn(4 * 2, 224, 224, 3, device=device),
        depth=torch.randn(4 * 2, 256, 256, 1, device=device),
        instruction=dummy_instruction,
        progress=torch.randn(4 * 2, 1, device=device),
    )

    hidden_states = torch.randn(
        policy.net.state_encoder.num_recurrent_layers,
        2,
        policy.net._hidden_size,
        device=device,
    )
    prev_actions = torch.randint(0, 3, size=(4 * 2, 1), device=device)
    masks = torch.ones(4 * 2, 1, device=device)

    AuxLosses.activate()

    policy.evaluate_actions(obs, hidden_states, prev_actions, masks, prev_actions)
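
Because AuxLosses was activated before the call, the progress-monitor loss registered inside the CMA forward pass is available afterwards. A possible continuation of this smoke test is shown below; the all-ones mask is illustrative and simply marks every batch entry as valid.

    # reduce the registered auxiliary losses over the (here, all-valid) entries
    aux_mask = torch.ones(4 * 2, dtype=torch.bool, device=device)
    aux_loss = AuxLosses.reduce(aux_mask)
    AuxLosses.deactivate()
    print("progress monitor aux loss:", aux_loss)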