def run(self):
    make_empty_dir(self.results)
    print(f"Preprocessing {self.data_path}")
    # Use precomputed target spacing for this task if available, otherwise derive it from the data.
    try:
        self.target_spacing = spacings[self.task_code]
    except KeyError:
        self.collect_spacings()
    if self.verbose:
        print(f"Target spacing {self.target_spacing}")

    if self.modality == "CT":
        # Use precomputed CT intensity statistics if available, otherwise collect them from the data.
        try:
            self.ct_min = ct_min[self.task]
            self.ct_max = ct_max[self.task]
            self.ct_mean = ct_mean[self.task]
            self.ct_std = ct_std[self.task]
        except KeyError:
            self.collect_intensities()

        _mean = round(self.ct_mean, 2)
        _std = round(self.ct_std, 2)
        if self.verbose:
            print(f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}")

    self.run_parallel(self.preprocess_pair, self.args.exec_mode)

    # Persist the preprocessing configuration for downstream training/inference.
    with open(os.path.join(self.results, "config.pkl"), "wb") as f:
        pickle.dump(
            {
                "patch_size": self.patch_size,
                "spacings": self.target_spacing,
                "n_class": len(self.metadata["labels"]),
                "in_channels": len(self.metadata["modality"]) + int(self.args.ohe),
            },
            f,
        )
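# Illustrative only (not part of this repository): a minimal sketch of how a
# downstream consumer could read back the config.pkl written by run() above.
# The helper name load_preprocess_config and its results_dir argument are
# assumptions made for this example.
import os
import pickle

def load_preprocess_config(results_dir):
    with open(os.path.join(results_dir, "config.pkl"), "rb") as f:
        cfg = pickle.load(f)
    # Keys mirror the dict dumped in run(): patch_size, spacings, n_class, in_channels.
    return cfg["patch_size"], cfg["spacings"], cfg["n_class"], cfg["in_channels"]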
def create_idx_files(self, tfrecords, save_dir):
    make_empty_dir(save_dir)
    tfrecords_idx = []
    for tfrec in tfrecords:
        fname = os.path.basename(tfrec).split(".")[0]
        tfrecords_idx.append(os.path.join(save_dir, f"{fname}.idx"))

    Parallel(n_jobs=self.args.n_jobs)(
        delayed(self.create_idx)(tr, ti)
        for tr, ti in tqdm(zip(tfrecords, tfrecords_idx), total=len(tfrecords))
    )
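# self.create_idx is referenced above but not shown in this excerpt. A minimal
# sketch of what it might look like, assuming NVIDIA DALI's `tfrecord2idx`
# helper script is available on PATH (an assumption; the repository's actual
# implementation may differ):
import subprocess

def create_idx(tfrecord_path, idx_path):
    # Write a DALI index file for one TFRecord shard so the reader can seek into it.
    subprocess.run(["tfrecord2idx", tfrecord_path, idx_path], check=True)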
def prepare_data(self):
    if self.args.create_idx:
        tfrecords_train, tfrecords_val, tfrecords_test = self.load_tfrecords()
        make_empty_dir("train_idx")
        make_empty_dir("val_idx")
        make_empty_dir("test_idx")
        self.create_idx("train_idx", tfrecords_train)
        self.create_idx("val_idx", tfrecords_val)
        self.create_idx("test_idx", tfrecords_test)
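# make_empty_dir(...) is used throughout but defined elsewhere in the repository.
# A hypothetical stand-in with the behaviour the calls above rely on (recreate
# the directory empty); the actual utility may differ:
import os
import shutil

def make_empty_dir(path):
    shutil.rmtree(path, ignore_errors=True)  # drop any previous contents
    os.makedirs(path)                        # recreate the directory empty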
if args.benchmark:
    if args.exec_mode == "train":
        trainer.fit(model, train_dataloader=data_module.train_dataloader())
    else:
        # warmup
        trainer.test(model, test_dataloaders=data_module.test_dataloader())
        # benchmark run
        trainer.current_epoch = 1
        trainer.test(model, test_dataloaders=data_module.test_dataloader())
elif args.exec_mode == "train":
    trainer.fit(model, data_module)
    if is_main_process():
        logname = args.logname if args.logname is not None else "train_log.json"
        log(logname, torch.tensor(model.best_mean_dice), results=args.results)
elif args.exec_mode == "evaluate":
    model.args = args
    trainer.test(model, test_dataloaders=data_module.val_dataloader())
    if is_main_process():
        logname = args.logname if args.logname is not None else "eval_log.json"
        log(logname, model.eval_dice, results=args.results)
elif args.exec_mode == "predict":
    if args.save_preds:
        ckpt_name = "_".join(args.ckpt_path.split("/")[-1].split(".")[:-1])
        dir_name = f"predictions_{ckpt_name}"
        dir_name += f"_task={model.args.task}_fold={model.args.fold}"
        if args.tta:
            dir_name += "_tta"
        save_dir = os.path.join(args.results, dir_name)
        model.save_dir = save_dir
        make_empty_dir(save_dir)
    model.args = args
    trainer.test(model, test_dataloaders=data_module.test_dataloader())
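# is_main_process() above guards logging so that only one worker writes results.
# A minimal sketch assuming torch.distributed; the repository's own helper may
# be implemented differently:
import torch.distributed as dist

def is_main_process():
    # Treat non-distributed runs as the main process; otherwise check the global rank.
    return not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0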