def train_epoch(self, epoch) -> None:
    """Run one training epoch over the labeled 'train' dataloader.

    Executes the standard optimization loop (zero grad -> forward -> loss ->
    backward -> step) and raises RuntimeError as soon as the reconstructed
    DEM contains NaNs, i.e. when training is diverging.
    """
    self.model.train()

    train_loader = self.task.labeled_dataloader.dataloaders['train']
    meta_info = DataloaderMetaInfo(train_loader)
    with self.task.loss.new_epoch(epoch, "train", dataloader_meta_info=meta_info):
        bar = Bar(f"Train epoch {epoch} of task {self.task.uid}", max=len(train_loader))
        for batch in train_loader:
            self.optimizer.zero_grad()

            batch = self.dict_to_device(batch)
            num_samples = batch[ChannelEnum.GT_DEM].size(0)

            prediction = self.model(batch)
            # Abort early on divergence: any NaN in the reconstruction.
            if torch.isnan(prediction[ChannelEnum.REC_DEM]).any():
                raise RuntimeError("We detected NaNs in the model outputs which means "
                                   "that the training is diverging")

            loss_dict = self.model.loss_function(loss_config=self.task.config["loss"],
                                                 output=prediction, data=batch,
                                                 dataloader_meta_info=meta_info,
                                                 feature_extractor=self.feature_extractor)
            total_loss = loss_dict[LossEnum.LOSS]
            self.task.loss(batch_size=num_samples, loss_dict=loss_dict)

            total_loss.backward()
            self.optimizer.step()
            bar.next()
        bar.finish()
def validate_epoch(self, epoch: int):
    """Evaluate the model on the 'val' split and record the epoch state.

    Runs without gradient tracking, accumulates per-batch losses into the
    task loss, and finally hands the finished epoch loss together with the
    current model weights to the controller.
    """
    self.model.eval()

    val_loader = self.task.labeled_dataloader.dataloaders['val']
    meta_info = DataloaderMetaInfo(val_loader)
    with self.task.loss.new_epoch(epoch, "val", dataloader_meta_info=meta_info), torch.no_grad():
        bar = Bar(f"Validate epoch {epoch} of task {self.task.uid}", max=len(val_loader))
        for batch in val_loader:
            batch = self.dict_to_device(batch)
            num_samples = batch[ChannelEnum.GT_DEM].size(0)

            prediction = self.model.forward_pass(batch)
            loss_dict = self.model.loss_function(loss_config=self.task.config["loss"],
                                                 output=prediction, data=batch,
                                                 dataloader_meta_info=meta_info)
            self.task.loss(batch_size=num_samples, loss_dict=loss_dict)
            bar.next()
        bar.finish()

    # Epoch loss is read after the new_epoch context has closed (and finalized it).
    self.controller.add_state(epoch, self.task.loss.get_epoch_loss(), self.model.state_dict())
def infer(self):
    """Run inference on the test split and append inputs and outputs to the results HDF5 file.

    Forward passes are wrapped in the torch profiler — except for the
    scipy-based ``LsqPlaneFitBaseline``, with which the profiler has issues —
    and the CPU-time summary is appended to ``inference_cputime.txt`` in the
    task logdir. When ``subgrid_size`` is configured, each sample is split
    into subgrids before the forward pass and re-assembled afterwards;
    subgrids whose occlusion ratio exceeds ``subgrid_max_occ_ratio_thresh``
    keep the raw occluded input instead of the reconstruction.

    Fix: the subgrid uncertainty-map override previously stored ``comp_dem``
    (the elevation map) under ``ChannelEnum.COMP_DATA_UM``; it now stores the
    patched ``comp_data_um``.
    """
    hdf5_group_prefix = f"/task_{self.task.uid}/inference"
    data_hdf5_group = self.results_hdf5_file.create_group(f"/{hdf5_group_prefix}/data")

    self.model.eval()
    if self.task.type in [TaskTypeEnum.SUPERVISED_LEARNING, TaskTypeEnum.INFERENCE]:
        dataloader = self.task.labeled_dataloader.dataloaders['test']
    else:
        raise NotImplementedError(f"The following task type is not implemented: {self.task.type}")

    dataloader_meta_info = DataloaderMetaInfo(dataloader)
    subgrid_size = self.task.labeled_dataloader.config.get("subgrid_size")

    # NOTE(review): unlike test(), this loop is not wrapped in torch.no_grad();
    # presumably acceptable for eval-mode inference — confirm memory use.
    prof = None
    if not isinstance(self.model, LsqPlaneFitBaseline):
        prof = profiler.profile()
        prof.__enter__()

    start_idx = 0
    progress_bar = Bar(f"Inference for task {self.task.uid}", max=len(dataloader))
    for batch_idx, data in enumerate(dataloader):
        data = self.dict_to_device(data)
        batch_size = data[ChannelEnum.OCC_DEM].size(0)
        grid_size = list(data[ChannelEnum.OCC_DEM].size()[1:3])

        grid_data = data  # keep the full-grid batch so it can be restored after subgrid processing
        if subgrid_size is not None:
            data = self.split_subgrids(subgrid_size, data)

        if isinstance(self.model, LsqPlaneFitBaseline):
            # the profiler somehow has issues with the scipy lsq solver
            output = self.model.forward_pass(data)
        else:
            with profiler.record_function("model_inference"):
                output = self.model.forward_pass(data)

        if subgrid_size is not None:
            # max occlusion ratio threshold for COMP_DEM where we accept reconstruction
            # instead of just taking all OCC_DEM
            subgrid_max_occ_ratio_thresh = self.task.config.get("subgrid_max_occ_ratio_thresh", 1.0)
            if subgrid_max_occ_ratio_thresh < 1.0:
                occ_dem = data[ChannelEnum.OCC_DEM]
                # per-subgrid fraction of NaN (occluded) cells
                occ_ratio = torch.isnan(occ_dem).sum(dim=(1, 2)) / (occ_dem.size(1) * occ_dem.size(2))
                occ_ratio_selector = occ_ratio > subgrid_max_occ_ratio_thresh

                # heavily-occluded subgrids fall back to the raw occluded input
                comp_dem = output[ChannelEnum.COMP_DEM]
                comp_dem[occ_ratio_selector, :, :] = occ_dem[occ_ratio_selector, :, :]
                output[ChannelEnum.COMP_DEM] = comp_dem

                if ChannelEnum.OCC_DATA_UM in data and ChannelEnum.COMP_DATA_UM in output:
                    # NOTE(review): the guard checks `data` for OCC_DATA_UM but reads it
                    # from `output` — confirm forward_pass passes the channel through.
                    occ_data_um = output[ChannelEnum.OCC_DATA_UM]
                    comp_data_um = output[ChannelEnum.COMP_DATA_UM]
                    comp_data_um[occ_ratio_selector, :, :] = occ_data_um[occ_ratio_selector, :, :]
                    # FIX: previously this assigned comp_dem (elevation) by mistake.
                    output[ChannelEnum.COMP_DATA_UM] = comp_data_um

            output = self.unsplit_subgrids(grid_size, output)
            data = grid_data

        self.add_batch_data_to_hdf5_results(data_hdf5_group, data, start_idx,
                                            dataloader_meta_info.length)
        self.add_batch_data_to_hdf5_results(data_hdf5_group, output, start_idx,
                                            dataloader_meta_info.length)

        start_idx += batch_size
        progress_bar.next()
    progress_bar.finish()

    if not isinstance(self.model, LsqPlaneFitBaseline):
        # Clean exit of the profiler context: exc_type must be None, not 0.
        prof.__exit__(None, None, None)
        with open(str(self.task.logdir / "inference_cputime.txt"), "a") as f:
            f.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
        if self.task.config.get("profiler_export_chrome_trace", False):
            prof.export_chrome_trace(
                str(self.task.logdir / "inference_cputime_chrome_trace.json"))
def test(self):
    """Evaluate the model on the test split, writing outputs and losses to HDF5.

    Optionally applies a traversability assessment to the model output when
    enabled in the task config, profiles the forward passes (except for the
    scipy-based LSQ baseline), and records per-sample losses plus their
    aggregated means in the task loss.
    """
    group_prefix = f"/task_{self.task.uid}/test"
    data_group = self.results_hdf5_file.create_group(f"/{group_prefix}/data")
    loss_group = self.results_hdf5_file.create_group(f"/{group_prefix}/loss")

    trav_config = self.task.config.get("traversability_assessment", {})
    traversability_assessment = None
    if trav_config.get("active", False):
        traversability_assessment = TraversabilityAssessment(**trav_config)

    self.model.eval()
    if self.task.type != TaskTypeEnum.SUPERVISED_LEARNING:
        raise NotImplementedError(f"The following task type is not implemented: {self.task.type}")
    test_loader = self.task.labeled_dataloader.dataloaders['test']

    meta_info = DataloaderMetaInfo(test_loader)
    with self.task.loss.new_epoch(0, "test", dataloader_meta_info=meta_info), torch.no_grad():
        use_profiler = not isinstance(self.model, LsqPlaneFitBaseline)
        prof = None
        if use_profiler:
            prof = profiler.profile()
            prof.__enter__()

        sample_offset = 0
        bar = Bar(f"Test inference for task {self.task.uid}", max=len(test_loader))
        for batch in test_loader:
            batch = self.dict_to_device(batch)
            num_samples = batch[ChannelEnum.GT_DEM].size(0)

            if use_profiler:
                with profiler.record_function("model_inference"):
                    model_output = self.model.forward_pass(batch)
            else:
                # the profiler somehow has issues with the scipy lsq solver
                model_output = self.model.forward_pass(batch)

            if traversability_assessment is not None:
                model_output = traversability_assessment(output=model_output, data=batch)

            self.add_batch_data_to_hdf5_results(data_group, batch, sample_offset,
                                                meta_info.length)
            self.add_batch_data_to_hdf5_results(data_group, model_output, sample_offset,
                                                meta_info.length)

            loss_dict = self.model.loss_function(loss_config=self.task.config["loss"],
                                                 output=model_output, data=batch,
                                                 dataloader_meta_info=meta_info,
                                                 reduction="mean_per_sample")
            aggregated = self.task.loss.aggregate_mean_loss_dict(loss_dict)
            self.task.loss(batch_size=num_samples, loss_dict=aggregated)
            self.add_batch_data_to_hdf5_results(loss_group, loss_dict, sample_offset,
                                                meta_info.length)

            sample_offset += num_samples
            bar.next()
        bar.finish()

        if use_profiler:
            prof.__exit__(0, None, None)
            with open(str(self.task.logdir / "test_cputime.txt"), "a") as f:
                f.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
            if self.task.config.get("profiler_export_chrome_trace", False):
                prof.export_chrome_trace(
                    str(self.task.logdir / "test_cputime_chrome_trace.json"))