def train_epoch(self, epoch) -> None:
        """Train the model for one epoch on the labeled 'train' split.

        For every batch: zero gradients, forward pass, guard against NaN
        reconstructions (diverging training), compute and accumulate the
        loss, backpropagate and step the optimizer.
        """
        self.model.train()

        train_loader = self.task.labeled_dataloader.dataloaders['train']
        meta_info = DataloaderMetaInfo(train_loader)
        with self.task.loss.new_epoch(epoch, "train", dataloader_meta_info=meta_info):
            bar = Bar(f"Train epoch {epoch} of task {self.task.uid}",
                      max=len(train_loader))
            for batch in train_loader:
                self.optimizer.zero_grad()

                batch = self.dict_to_device(batch)
                num_samples = batch[ChannelEnum.GT_DEM].size(0)

                model_output = self.model(batch)

                # Abort immediately if the forward pass produced any NaNs.
                if torch.isnan(model_output[ChannelEnum.REC_DEM]).sum() > 0:
                    raise RuntimeError("We detected NaNs in the model outputs which means "
                                       "that the training is diverging")

                loss_dict = self.model.loss_function(
                    loss_config=self.task.config["loss"],
                    output=model_output,
                    data=batch,
                    dataloader_meta_info=meta_info,
                    feature_extractor=self.feature_extractor)
                total_loss = loss_dict[LossEnum.LOSS]
                self.task.loss(batch_size=num_samples, loss_dict=loss_dict)

                total_loss.backward()
                self.optimizer.step()
                bar.next()
            bar.finish()
    def validate_epoch(self, epoch: int):
        """Evaluate the model for one epoch on the labeled 'val' split.

        Runs forward passes under ``torch.no_grad()``, accumulates the
        epoch loss, and finally hands the epoch loss together with the
        current model weights to the controller.
        """
        self.model.eval()

        val_loader = self.task.labeled_dataloader.dataloaders['val']
        meta_info = DataloaderMetaInfo(val_loader)
        with self.task.loss.new_epoch(
                epoch, "val", dataloader_meta_info=meta_info), torch.no_grad():
            bar = Bar(
                f"Validate epoch {epoch} of task {self.task.uid}",
                max=len(val_loader))
            for batch in val_loader:
                batch = self.dict_to_device(batch)
                num_samples = batch[ChannelEnum.GT_DEM].size(0)

                model_output = self.model.forward_pass(batch)

                loss_dict = self.model.loss_function(
                    loss_config=self.task.config["loss"],
                    output=model_output,
                    data=batch,
                    dataloader_meta_info=meta_info)
                self.task.loss(batch_size=num_samples, loss_dict=loss_dict)
                bar.next()
            bar.finish()

        # Record the validation result and a weight snapshot for this epoch.
        self.controller.add_state(epoch, self.task.loss.get_epoch_loss(),
                                  self.model.state_dict())
    def infer(self):
        """Run inference on the 'test' split and store the results in HDF5.

        Optionally splits each DEM grid into subgrids before the forward
        pass and stitches the outputs back together afterwards. Unless the
        model is the LsqPlaneFitBaseline (whose scipy lsq solver conflicts
        with the torch profiler), the loop runs under the autograd profiler
        and the CPU-time table is appended to ``inference_cputime.txt``.

        Raises:
            NotImplementedError: if the task type is neither supervised
                learning nor inference.
        """
        hdf5_group_prefix = f"/task_{self.task.uid}/inference"
        data_hdf5_group = self.results_hdf5_file.create_group(
            f"/{hdf5_group_prefix}/data")

        self.model.eval()
        if self.task.type in [
                TaskTypeEnum.SUPERVISED_LEARNING, TaskTypeEnum.INFERENCE
        ]:
            dataloader = self.task.labeled_dataloader.dataloaders['test']
        else:
            raise NotImplementedError(
                f"The following task type is not implemented: {self.task.type}"
            )

        dataloader_meta_info = DataloaderMetaInfo(dataloader)

        subgrid_size = self.task.labeled_dataloader.config.get("subgrid_size")

        prof = None
        if not isinstance(self.model, LsqPlaneFitBaseline):
            prof = profiler.profile()
            prof.__enter__()

        start_idx = 0
        progress_bar = Bar(f"Inference for task {self.task.uid}",
                           max=len(dataloader))
        for batch_idx, data in enumerate(dataloader):
            data = self.dict_to_device(data)
            batch_size = data[ChannelEnum.OCC_DEM].size(0)
            grid_size = list(data[ChannelEnum.OCC_DEM].size()[1:3])

            # Keep the full-grid batch so it can be restored after subgrid
            # processing.
            grid_data = data
            if subgrid_size is not None:
                data = self.split_subgrids(subgrid_size, data)

            if isinstance(self.model, LsqPlaneFitBaseline):
                # the profiler somehow has issues with the scipy lsq solver
                output = self.model.forward_pass(data)
            else:
                with profiler.record_function("model_inference"):
                    output = self.model.forward_pass(data)

            if subgrid_size is not None:
                # max occlusion ratio threshold for COMP_DEM where we accept
                # reconstruction instead of just taking all OCC_DEM
                subgrid_max_occ_ratio_thresh = self.task.config.get(
                    "subgrid_max_occ_ratio_thresh", 1.0)
                if subgrid_max_occ_ratio_thresh < 1.0:
                    occ_dem = data[ChannelEnum.OCC_DEM]
                    occ_ratio = torch.isnan(occ_dem).sum(
                        dim=(1, 2)) / (occ_dem.size(1) * occ_dem.size(2))
                    occ_ratio_selector = occ_ratio > subgrid_max_occ_ratio_thresh

                    # For subgrids that are too occluded, fall back to the raw
                    # occluded DEM instead of the reconstruction.
                    comp_dem = output[ChannelEnum.COMP_DEM]
                    comp_dem[occ_ratio_selector, :, :] = occ_dem[
                        occ_ratio_selector, :, :]
                    output[ChannelEnum.COMP_DEM] = comp_dem

                    if ChannelEnum.OCC_DATA_UM in data and ChannelEnum.COMP_DATA_UM in output:
                        # BUGFIX: read OCC_DATA_UM from `data` — that is what
                        # the guard above checks (it was read from `output`).
                        occ_data_um = data[ChannelEnum.OCC_DATA_UM]
                        comp_data_um = output[ChannelEnum.COMP_DATA_UM]
                        comp_data_um[occ_ratio_selector, :, :] = occ_data_um[
                            occ_ratio_selector, :, :]
                        # BUGFIX: previously assigned `comp_dem` here, which
                        # clobbered the uncertainty map with the DEM tensor.
                        output[ChannelEnum.COMP_DATA_UM] = comp_data_um

                output = self.unsplit_subgrids(grid_size, output)
                data = grid_data

            self.add_batch_data_to_hdf5_results(data_hdf5_group, data,
                                                start_idx,
                                                dataloader_meta_info.length)
            self.add_batch_data_to_hdf5_results(data_hdf5_group, output,
                                                start_idx,
                                                dataloader_meta_info.length)

            start_idx += batch_size
            progress_bar.next()
        progress_bar.finish()

        if not isinstance(self.model, LsqPlaneFitBaseline):
            # Clean exit: pass None (not 0) as exc_type per the context
            # manager protocol.
            prof.__exit__(None, None, None)
            with open(str(self.task.logdir / "inference_cputime.txt"),
                      "a") as f:
                f.write(prof.key_averages().table(sort_by="cpu_time_total",
                                                  row_limit=20))
            if self.task.config.get("profiler_export_chrome_trace", False):
                prof.export_chrome_trace(
                    str(self.task.logdir /
                        "inference_cputime_chrome_trace.json"))
    def test(self):
        """Evaluate the model on the 'test' split and persist the results.

        Writes per-sample inputs/outputs and per-sample losses to the
        results HDF5 file, optionally runs a traversability assessment on
        each batch, and (except for the LsqPlaneFitBaseline) profiles the
        inference loop, appending the CPU-time table to ``test_cputime.txt``.

        Raises:
            NotImplementedError: if the task type is not supervised learning.
        """
        hdf5_group_prefix = f"/task_{self.task.uid}/test"
        test_data_hdf5_group = self.results_hdf5_file.create_group(
            f"/{hdf5_group_prefix}/data")
        test_loss_hdf5_group = self.results_hdf5_file.create_group(
            f"/{hdf5_group_prefix}/loss")

        # Optional traversability assessment, enabled via the task config.
        traversability_assessment = None
        if self.task.config.get("traversability_assessment",
                                {}).get("active", False):
            traversability_config = self.task.config.get(
                "traversability_assessment", {})
            traversability_assessment = TraversabilityAssessment(
                **traversability_config)

        self.model.eval()

        if self.task.type == TaskTypeEnum.SUPERVISED_LEARNING:
            dataloader = self.task.labeled_dataloader.dataloaders['test']
        else:
            raise NotImplementedError(
                f"The following task type is not implemented: {self.task.type}"
            )

        dataloader_meta_info = DataloaderMetaInfo(dataloader)
        # Test metrics are recorded as epoch 0 of the "test" phase.
        with self.task.loss.new_epoch(
                0, "test",
                dataloader_meta_info=dataloader_meta_info), torch.no_grad():
            # NOTE: the profiler is entered inside the loss/no_grad context
            # but exited after it (see the block below) — keep this ordering.
            prof = None
            if not isinstance(self.model, LsqPlaneFitBaseline):
                prof = profiler.profile()
                prof.__enter__()

            start_idx = 0
            progress_bar = Bar(f"Test inference for task {self.task.uid}",
                               max=len(dataloader))
            for batch_idx, data in enumerate(dataloader):
                data = self.dict_to_device(data)
                batch_size = data[ChannelEnum.GT_DEM].size(0)

                if isinstance(self.model, LsqPlaneFitBaseline):
                    # the profiler somehow has issues with the scipy lsq solver
                    output = self.model.forward_pass(data)
                else:
                    with profiler.record_function("model_inference"):
                        output = self.model.forward_pass(data)

                if traversability_assessment is not None:
                    output = traversability_assessment(output=output,
                                                       data=data)

                # Persist both the raw batch data and the model output.
                self.add_batch_data_to_hdf5_results(
                    test_data_hdf5_group, data, start_idx,
                    dataloader_meta_info.length)
                self.add_batch_data_to_hdf5_results(
                    test_data_hdf5_group, output, start_idx,
                    dataloader_meta_info.length)

                # Per-sample losses go to HDF5; their aggregated means feed
                # the epoch-loss bookkeeping.
                loss_dict = self.model.loss_function(
                    loss_config=self.task.config["loss"],
                    output=output,
                    data=data,
                    dataloader_meta_info=dataloader_meta_info,
                    reduction="mean_per_sample")
                aggregated_loss_dict = self.task.loss.aggregate_mean_loss_dict(
                    loss_dict)
                self.task.loss(batch_size=batch_size,
                               loss_dict=aggregated_loss_dict)
                self.add_batch_data_to_hdf5_results(
                    test_loss_hdf5_group, loss_dict, start_idx,
                    dataloader_meta_info.length)

                start_idx += batch_size
                progress_bar.next()

            progress_bar.finish()

        if not isinstance(self.model, LsqPlaneFitBaseline):
            # NOTE(review): __exit__ receives 0 as exc_type; None would be the
            # conventional value for a clean exit — confirm intent.
            prof.__exit__(0, None, None)
            with open(str(self.task.logdir / "test_cputime.txt"), "a") as f:
                f.write(prof.key_averages().table(sort_by="cpu_time_total",
                                                  row_limit=20))
            if self.task.config.get("profiler_export_chrome_trace", False):
                prof.export_chrome_trace(
                    str(self.task.logdir / "test_cputime_chrome_trace.json"))