def save_metric_output(scores, curves, base_dir, name):
    """
    Save metric scores as human-readable JSON and dump scores together
    with their curves as a pickle file.
    """
    scores_string = {str(key): str(item) for key, item in scores.items()}
    save_json(scores_string, base_dir / f"{name}.json")
    save_pickle({"scores": scores, "curves": curves}, base_dir / f"{name}.pkl")
def reformat_labels(target: Path):
    for p in subfiles(target, identifier="*json", join=True):
        label = load_json(Path(p))
        mal_labels = label["scores"]
        instance_classes = {key: int(item >= 3) for key, item in mal_labels.items()}
        save_json({"instances": instance_classes, "scores": mal_labels}, Path(p))
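# Illustrative sketch of what `reformat_labels` does to a single label file
# (the field values are made up): the mean malignancy scores are kept under
# "scores" and binarized into "instances" with the >= 3 cutoff used above.
#
# before: {"scores": {"1": 4.0, "2": 2.5}}
# after:  {"instances": {"1": 1, "2": 0}, "scores": {"1": 4.0, "2": 2.5}}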
def _create_mask(source, target, centers, rads):
    try:
        logger.info(f"Processing {source.stem}")
        data = sitk.ReadImage(str(source))
        mask = create_circle_mask_itk(data, centers, rads, ndim=3)
        sitk.WriteImage(mask, str(target / f"{source.stem.replace('.', '_')}.nii.gz"))
        save_json({"instances": {str(k + 1): 0 for k in range(len(centers))}},
                  target / f"{source.stem.replace('.', '_')}.json")
    except Exception as e:
        logger.error(f"Case {source.stem} failed with {e} and {traceback.format_exc()}")
def create_masks(source: Path, target: Path, df: pd.DataFrame, num_processes: int):
    files = []
    split = {}
    for i in range(10):
        subset_dir = source / f"subset{i}"
        if not subset_dir.is_dir():
            logger.error(f"{subset_dir} is not a valid subset directory!")
            continue
        tmp = list(subset_dir.glob('*.mhd'))
        files.extend(tmp)
        for t in tmp:
            split[t.stem.replace('.', '_')] = i
    save_json(split, target.parent.parent / "splits.json")

    centers = []
    rads = []
    for f in files:
        c = []
        r = []
        try:
            # list indexer so a single match still returns a DataFrame
            series_df = df.loc[[f.name.rsplit('.', 1)[0]]]
        except KeyError:
            pass
        else:
            for _, row in series_df.iterrows():
                c.append((float(row['coordX']), float(row['coordY']), float(row['coordZ'])))
                r.append(float(row['diameter_mm']) / 2)
        centers.append(c)
        rads.append(r)
    assert len(files) == len(centers) == len(rads)

    with Pool(processes=num_processes) as p:
        p.starmap(_create_mask, zip(files, repeat(target), centers, rads))
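# Usage sketch (hypothetical paths): `df` is expected to be the LUNA16
# annotations table indexed by 'seriesuid', matching how it is loaded in the
# preparation code further below (pd.read_csv(csv, index_col='seriesuid')).
#
# df = pd.read_csv("annotations.csv", index_col="seriesuid")
# create_masks(
#     source=Path("/data/Task016_Luna/raw"),  # contains subset0 ... subset9
#     target=Path("/data/Task016_Luna/raw_splitted/labelsTr"),
#     df=df,
#     num_processes=4,
# )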
def import_nnunet_boxes(
        # settings
        nnunet_prediction_dir: Pathlike,
        save_dir: Pathlike,
        boxes_gt_dir: Pathlike,
        classes: Sequence[str],
        stuff: Optional[Sequence[int]] = None,
        num_workers: int = 6,
):
    nnunet_prediction_dir = Path(nnunet_prediction_dir)  # Pathlike may be a str
    assert nnunet_prediction_dir.is_dir(), f"{nnunet_prediction_dir} is not a dir"
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    summary = []

    # create sweep dir
    sweep_dir = nnunet_prediction_dir
    postprocessing_settings = {}

    # optimize min num voxels
    logger.info("Looking for optimal min voxel size")
    min_num_voxel_settings = [0, 5, 10, 15, 20]
    scores = []
    for min_num_voxel in min_num_voxel_settings:
        # create temp dir
        sweep_prediction = sweep_dir / f"sweep_min_voxel{min_num_voxel}"
        sweep_prediction.mkdir(parents=True)
        # import with settings
        import_dir(
            nnunet_prediction_dir=nnunet_prediction_dir,
            target_dir=sweep_prediction,
            min_num_voxel=min_num_voxel,
            save_seg=False,
            save_iseg=False,
            stuff=stuff,
            num_workers=num_workers,
        )
        # evaluate
        _scores, _ = evaluate_box_dir(
            pred_dir=sweep_prediction,
            gt_dir=boxes_gt_dir,
            classes=classes,
            save_dir=None,
        )
        scores.append(_scores[TARGET_METRIC])
        summary.append({f"Min voxel {min_num_voxel}": _scores[TARGET_METRIC]})
        logger.info(f"Min voxel {min_num_voxel} :: {_scores[TARGET_METRIC]}")
        shutil.rmtree(sweep_prediction)
    idx = int(np.argmax(scores))
    postprocessing_settings["min_num_voxel"] = min_num_voxel_settings[idx]
    logger.info(
        f"Found min num voxel {min_num_voxel_settings[idx]} with score {scores[idx]}"
    )

    # optimize score threshold
    logger.info("Looking for optimal min probability threshold")
    min_threshold_settings = [None, 0.1, 0.2, 0.3, 0.4, 0.5]
    scores = []
    for min_threshold in min_threshold_settings:
        # create temp dir
        sweep_prediction = sweep_dir / f"sweep_min_threshold_{min_threshold}"
        sweep_prediction.mkdir(parents=True)
        # import with settings
        import_dir(
            nnunet_prediction_dir=nnunet_prediction_dir,
            target_dir=sweep_prediction,
            min_threshold=min_threshold,
            save_seg=False,
            save_iseg=False,
            stuff=stuff,
            num_workers=num_workers,
            **postprocessing_settings,
        )
        # evaluate
        _scores, _ = evaluate_box_dir(
            pred_dir=sweep_prediction,
            gt_dir=boxes_gt_dir,
            classes=classes,
            save_dir=None,
        )
        scores.append(_scores[TARGET_METRIC])
        summary.append({f"Min score {min_threshold}": _scores[TARGET_METRIC]})
        logger.info(f"Min score {min_threshold} :: {_scores[TARGET_METRIC]}")
        shutil.rmtree(sweep_prediction)
    idx = int(np.argmax(scores))
    postprocessing_settings["min_threshold"] = min_threshold_settings[idx]
    logger.info(
        f"Found min threshold {min_threshold_settings[idx]} with score {scores[idx]}"
    )

    # optimize probability aggregation
    logger.info("Looking for best probability aggregation")
    aggregation_settings = ["max", "median", "mean", "percentile95"]
    scores = []
    for aggregation in aggregation_settings:
        # create temp dir
        sweep_prediction = sweep_dir / f"sweep_aggregation_{aggregation}"
        sweep_prediction.mkdir(parents=True)
        # import with settings
        import_dir(
            nnunet_prediction_dir=nnunet_prediction_dir,
            target_dir=sweep_prediction,
            aggregation=aggregation,
            save_seg=False,
            save_iseg=False,
            stuff=stuff,
            num_workers=num_workers,
            **postprocessing_settings,
        )
        # evaluate
        _scores, _ = evaluate_box_dir(
            pred_dir=sweep_prediction,
            gt_dir=boxes_gt_dir,
            classes=classes,
            save_dir=None,
        )
        scores.append(_scores[TARGET_METRIC])
        summary.append({f"Aggregation {aggregation}": _scores[TARGET_METRIC]})
        logger.info(f"Aggregation {aggregation} :: {_scores[TARGET_METRIC]}")
        shutil.rmtree(sweep_prediction)
    idx = int(np.argmax(scores))
    postprocessing_settings["aggregation"] = aggregation_settings[idx]
    logger.info(
        f"Found aggregation {aggregation_settings[idx]} with score {scores[idx]}"
    )

    save_pickle(postprocessing_settings, save_dir / "postprocessing.pkl")
    save_json(summary, save_dir / "summary.json")
    return postprocessing_settings
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'tasks',
        type=str,
        nargs='+',
        help="One or multiple of: Task003_Liver, Task007_Pancreas, "
             "Task008_HepaticVessel, Task010_Colon",
    )
    args = parser.parse_args()
    tasks = args.tasks

    decathlon_props = {
        "Task003_Liver": {
            "seg2det_stuff": [1],  # liver
            "seg2det_things": [2],  # cancer
            "min_size": 3.,
            "labels": {"0": "cancer"},
            "labels_stuff": {"1": "liver"},
        },
        "Task007_Pancreas": {
            "seg2det_stuff": [1],  # pancreas
            "seg2det_things": [2],
            "min_size": 3.,
            "labels": {"0": "cancer"},
            "labels_stuff": {"1": "pancreas"},
        },
        "Task008_HepaticVessel": {
            "seg2det_stuff": [1],  # vessel
            "seg2det_things": [2],
            "min_size": 3.,
            "labels": {"0": "tumour"},
            "labels_stuff": {"1": "vessel"},
        },
        "Task010_Colon": {
            "seg2det_stuff": [],
            "seg2det_things": [1],
            "min_size": 3.,
            "labels": {"0": "cancer"},
            "labels_stuff": {},
        },
    }

    basedir = Path(os.getenv('det_data'))
    for task in tasks:
        task_data_dir = basedir / task
        logger.remove()
        logger.add(sys.stdout, level="INFO")
        logger.add(task_data_dir / "prepare.log", level="DEBUG")
        logger.info(f"Preparing task: {task}")

        source_raw_dir = task_data_dir / "raw"
        source_data_dir = source_raw_dir / "imagesTr"
        source_labels_dir = source_raw_dir / "labelsTr"
        splitted_dir = task_data_dir / "raw_splitted"

        if not source_data_dir.is_dir():
            raise ValueError(f"Expected training images at {source_data_dir}")
        if not source_labels_dir.is_dir():
            raise ValueError(f"Expected training labels at {source_labels_dir}")
        if not (p := source_raw_dir / "dataset.json").is_file():
            raise ValueError(f"Expected dataset json to be located at {p}")

        target_data_dir = splitted_dir / "imagesTr"
        target_label_dir = splitted_dir / "labelsTr"
        target_data_dir.mkdir(parents=True, exist_ok=True)
        target_label_dir.mkdir(parents=True, exist_ok=True)

        # prepare meta
        original_meta = load_json(source_raw_dir / "dataset.json")
        dataset_info = {
            "task": task,
            "name": original_meta["name"],
            "target_class": None,
            "test_labels": True,
            "modalities": original_meta["modality"],
            "dim": 3,
            "info": {
                "original_labels": original_meta["labels"],
                "original_numTraining": original_meta["numTraining"],
            },
        }
        dataset_info.update(decathlon_props[task])
        save_json(dataset_info, task_data_dir / "dataset.json")

        # prepare data and labels
        case_ids = get_case_ids_from_dir(source_data_dir, remove_modality=False)
        case_ids = sorted([c for c in case_ids if c])
        logger.info(f"Found {len(case_ids)} cases for preparation.")

        for cid in maybe_verbose_iterable(case_ids):
            process_case(
                cid,
                source_data_dir,
                source_labels_dir,
                target_data_dir,
                target_label_dir,
            )
        # with Pool(processes=6) as p:
        #     p.starmap(process_case, zip(case_ids,
        #                                 repeat(source_data_dir),
        #                                 repeat(source_labels_dir),
        #                                 repeat(target_data_dir),
        #                                 repeat(target_label_dir),
        #                                 ))

        # create an artificial test split
        create_test_split(
            splitted_dir=splitted_dir,
            num_modalities=1,
            test_size=0.3,
            random_state=0,
            shuffle=True,
        )
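# Invocation sketch (script name is hypothetical): prepares one or more
# Decathlon tasks, assuming the raw data lives under $det_data/<task>/raw.
#
#   python prepare.py Task003_Liver Task010_Colon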
def instances_from_segmentation(source_file: Path,
                                output_folder: Path,
                                rm_classes: Sequence[int] = None,
                                ro_classes: Dict[int, int] = None,
                                subtract_one_of_classes: bool = True,
                                fg_vs_bg: bool = False,
                                file_name: Optional[str] = None,
                                ):
    """
    1. Optionally removes classes from the segmentation
       (e.g. organ segmentations which are not useful for detection)
    2. Optionally reorders the segmentation indices
    3. Converts the semantic segmentation to an instance segmentation
       via connected components

    Args:
        source_file: path to semantic segmentation file
        output_folder: folder where the processed file will be saved
        rm_classes: classes to remove from the semantic segmentation
        ro_classes: reorder classes before instances are generated
        subtract_one_of_classes: subtracts one from the classes in the
            instance mapping (detection networks assume that classes
            start from 0)
        fg_vs_bg: map all foreground classes to a single class to run a
            foreground vs background detection task
        file_name: name of saved file (without file type!)
    """
    if subtract_one_of_classes and fg_vs_bg:
        logger.info("subtract_one_of_classes will be ignored because fg_vs_bg is "
                    "active and all foreground classes will be mapped to 0")

    seg_itk = sitk.ReadImage(str(source_file))
    seg_npy = sitk.GetArrayFromImage(seg_itk)

    if rm_classes is not None:
        seg_npy = remove_classes(seg_npy, rm_classes)
    if ro_classes is not None:
        seg_npy = reorder_classes(seg_npy, ro_classes)

    instances, instance_classes = seg2instances(seg_npy)

    if fg_vs_bg:
        num_instances_check = len(instance_classes)
        seg_npy[seg_npy > 0] = 1
        instances, instance_classes = seg2instances(seg_npy)
        num_instances = len(instance_classes)
        if num_instances != num_instances_check:
            logger.warning(f"Lost instance: Found {num_instances_check} instances before "
                           f"fg_vs_bg but {num_instances} instances after it")

    if subtract_one_of_classes:
        for key in instance_classes.keys():
            instance_classes[key] -= 1
    if fg_vs_bg:
        for key in instance_classes.keys():
            instance_classes[key] = 0

    seg_itk_new = sitk.GetImageFromArray(instances)
    seg_itk_new = sitk_copy_metadata(seg_itk, seg_itk_new)

    if file_name is None:
        suffix_length = sum(map(len, source_file.suffixes))
        file_name = source_file.name[:-suffix_length]
    save_json({"instances": instance_classes}, output_folder / f"{file_name}.json")
    sitk.WriteImage(seg_itk_new, str(output_folder / f"{file_name}.nii.gz"))
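# Usage sketch (hypothetical paths and class ids): drop an organ ("stuff")
# class that is irrelevant for detection, then split the remaining tumour
# class into connected-component instances with labels starting at 0.
#
# instances_from_segmentation(
#     source_file=Path("labelsTr/case_000.nii.gz"),
#     output_folder=Path("labelsTr_instances"),
#     rm_classes=[1],                # e.g. remove the organ segmentation
#     ro_classes={2: 1},             # tumour becomes class 1 ...
#     subtract_one_of_classes=True,  # ... and is exported as detection class 0
# )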
def run_analysis_suite(prediction_dir: Path, gt_dir: Path, save_dir: Path):
    for iou, score in maybe_verbose_iterable(list(product([0.1, 0.5], [0.1, 0.5]))):
        _save_dir = save_dir / f"iou_{iou}_score_{score}"
        _save_dir.mkdir(parents=True, exist_ok=True)

        found_predictions = list(prediction_dir.glob("*_boxes.pkl"))
        logger.info(f"Found {len(found_predictions)} predictions for analysis")

        df, analysis_ids = collect_overview(
            prediction_dir,
            gt_dir,
            iou=iou,
            score=score,
            max_num_fp_per_image=5,
            top_n=10,
        )
        df.to_json(_save_dir / "analysis.json", indent=4, orient='index')
        df.to_csv(_save_dir / "analysis.csv")
        save_json(analysis_ids, _save_dir / "analysis_ids.json")

        all_pred, all_target, all_pred_ious, all_pred_scores = collect_score_iou(
            prediction_dir, gt_dir, iou=iou, score=score)
        confusion_ax = plot_confusion_matrix(all_pred, all_target, iou=iou, score=score)
        plt.savefig(_save_dir / "confusion_matrix.png")
        plt.close()

        iou_score_ax = plot_joint_iou_score(all_pred_ious, all_pred_scores)
        plt.savefig(_save_dir / "joint_iou_score.png")
        plt.close()

        all_pred, all_target, all_boxes = collect_boxes(prediction_dir, gt_dir,
                                                        iou=iou, score=score)
        sizes_fig, sizes_ax = plot_sizes(all_pred, all_target, all_boxes,
                                         iou=iou, score=score)
        plt.savefig(_save_dir / "sizes.png")
        with open(str(_save_dir / 'sizes.pkl'), "wb") as fp:
            pickle.dump(sizes_fig, fp, protocol=4)
        plt.close()

        sizes_fig, sizes_ax = plot_sizes_bar(all_pred, all_target, all_boxes,
                                             iou=iou, score=score)
        plt.savefig(_save_dir / "sizes_bar.png")
        with open(str(_save_dir / 'sizes_bar.pkl'), "wb") as fp:
            pickle.dump(sizes_fig, fp, protocol=4)
        plt.close()

        sizes_fig, sizes_ax = plot_sizes_bar(all_pred, all_target, all_boxes,
                                             iou=iou, score=score, max_bin=100)
        plt.savefig(_save_dir / "sizes_bar_100.png")
        with open(str(_save_dir / 'sizes_bar_100.pkl'), "wb") as fp:
            pickle.dump(sizes_fig, fp, protocol=4)
        plt.close()
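# Usage sketch (hypothetical paths): runs the analysis grid over all four
# IoU/score combinations from (0.1, 0.5) x (0.1, 0.5) and writes one
# subdirectory of overview tables and plots per combination.
#
# run_analysis_suite(
#     prediction_dir=Path("/results/val_predictions"),
#     gt_dir=Path("/data/boxes_gt"),
#     save_dir=Path("/results/analysis"),
# )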
def _train(
        task: str,
        ov: List[str],
        do_sweep: bool,
):
    """
    Run training

    Args:
        task: task to run training for
        ov: overwrites for config manager
        do_sweep: determine best empirical parameters for run
    """
    print(f"Overwrites: {ov}")
    initialize_config_module(config_module="nndet.conf")
    cfg = compose(task, "config.yaml", overrides=ov if ov is not None else [])

    assert cfg.host.parent_data is not None, 'Parent data cannot be None'
    assert cfg.host.parent_results is not None, 'Output dir cannot be None'

    train_dir = init_train_dir(cfg)

    pl_logger = MLFlowLogger(
        experiment_name=cfg["task"],
        tags={
            "host": socket.gethostname(),
            "fold": cfg["exp"]["fold"],
            "task": cfg["task"],
            "job_id": os.getenv('LSB_JOBID', 'no_id'),
            "mlflow.runName": cfg["exp"]["id"],
        },
        save_dir=os.getenv("MLFLOW_TRACKING_URI", "./mlruns"),
    )
    pl_logger.log_hyperparams(
        flatten_mapping({"model": OmegaConf.to_container(cfg["model_cfg"], resolve=True)}))
    pl_logger.log_hyperparams(
        flatten_mapping({"trainer": OmegaConf.to_container(cfg["trainer_cfg"], resolve=True)}))

    logger.remove()
    logger.add(sys.stdout, format="{level} {message}", level="INFO")
    log_file = Path(os.getcwd()) / "train.log"
    logger.add(log_file, level="INFO")
    logger.info(f"Log file at {log_file}")

    meta_data = {}
    meta_data["torch_version"] = str(torch.__version__)
    meta_data["date"] = str(datetime.now())
    meta_data["git"] = log_git(nndet.__path__[0], repo_name="nndet")
    save_json(meta_data, "./meta.json")
    try:
        write_requirements_to_file("requirements.txt")
    except Exception as e:
        logger.error(f"Could not log req: {e}")

    plan_path = Path(str(cfg.host["plan_path"]))
    plan = load_pickle(plan_path)
    save_json(create_debug_plan(plan), "./plan_debug.json")

    data_dir = Path(cfg.host["preprocessed_output_dir"]) / plan["data_identifier"] / "imagesTr"

    datamodule = Datamodule(
        augment_cfg=OmegaConf.to_container(cfg["augment_cfg"], resolve=True),
        plan=plan,
        data_dir=data_dir,
        fold=cfg["exp"]["fold"],
    )
    module = MODULE_REGISTRY[cfg["module"]](
        model_cfg=OmegaConf.to_container(cfg["model_cfg"], resolve=True),
        trainer_cfg=OmegaConf.to_container(cfg["trainer_cfg"], resolve=True),
        plan=plan,
    )

    callbacks = []
    checkpoint_cb = ModelCheckpoint(
        dirpath=train_dir,
        filename='model_best',
        save_last=True,
        save_top_k=1,
        monitor=cfg["trainer_cfg"]["monitor_key"],
        mode=cfg["trainer_cfg"]["monitor_mode"],
    )
    checkpoint_cb.CHECKPOINT_NAME_LAST = 'model_last'
    callbacks.append(checkpoint_cb)
    callbacks.append(LearningRateMonitor(logging_interval="epoch"))

    OmegaConf.save(cfg, str(Path(os.getcwd()) / "config.yaml"))
    OmegaConf.save(cfg, str(Path(os.getcwd()) / "config_resolved.yaml"), resolve=True)
    save_pickle(plan, train_dir / "plan.pkl")  # backup plan
    splits = load_pickle(Path(cfg.host.preprocessed_output_dir) / datamodule.splits_file)
    save_pickle(splits, train_dir / "splits.pkl")

    trainer_kwargs = {}
    if cfg["train"]["mode"].lower() == "resume":
        trainer_kwargs["resume_from_checkpoint"] = train_dir / "model_last.ckpt"

    num_gpus = cfg["trainer_cfg"]["gpus"]
    logger.info(f"Using {num_gpus} GPUs for training")
    plugins = cfg["trainer_cfg"].get("plugins", None)
    logger.info(f"Using {plugins} plugins for training")

    trainer = pl.Trainer(
        gpus=list(range(num_gpus)) if num_gpus > 1 else num_gpus,
        accelerator=cfg["trainer_cfg"]["accelerator"],
        precision=cfg["trainer_cfg"]["precision"],
        amp_backend=cfg["trainer_cfg"]["amp_backend"],
        amp_level=cfg["trainer_cfg"]["amp_level"],
        benchmark=cfg["trainer_cfg"]["benchmark"],
        deterministic=cfg["trainer_cfg"]["deterministic"],
        callbacks=callbacks,
        logger=pl_logger,
        max_epochs=module.max_epochs,
        progress_bar_refresh_rate=None if bool(int(os.getenv("det_verbose", 1))) else 0,
        reload_dataloaders_every_epoch=False,
        num_sanity_val_steps=10,
        weights_summary='full',
        plugins=plugins,
        terminate_on_nan=True,  # TODO: make modular
        move_metrics_to_cpu=True,
        **trainer_kwargs,
    )
    trainer.fit(module, datamodule=datamodule)

    if do_sweep:
        case_ids = splits[cfg["exp"]["fold"]]["val"]
        if "debug" in cfg and "num_cases_val" in cfg["debug"]:
            case_ids = case_ids[:cfg["debug"]["num_cases_val"]]

        inference_plan = module.sweep(
            cfg=OmegaConf.to_container(cfg, resolve=True),
            save_dir=train_dir,
            train_data_dir=data_dir,
            case_ids=case_ids,
            run_prediction=True,
        )

        plan["inference_plan"] = inference_plan
        save_pickle(plan, train_dir / "plan_inference.pkl")

        ensembler_cls = module.get_ensembler_cls(
            key="boxes", dim=plan["network_dim"])  # TODO: make this configurable
        for restore in [True, False]:
            target_dir = train_dir / "val_predictions" if restore else \
                train_dir / "val_predictions_preprocessed"
            extract_results(
                source_dir=train_dir / "sweep_predictions",
                target_dir=target_dir,
                ensembler_cls=ensembler_cls,
                restore=restore,
                **inference_plan,
            )

        _evaluate(
            task=cfg["task"],
            model=cfg["exp"]["id"],
            fold=cfg["exp"]["fold"],
            test=False,
            do_boxes_eval=True,  # TODO: make this configurable
            do_analyze_boxes=True,  # TODO: make this configurable
        )
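# Usage sketch (hypothetical override values): `ov` takes config overrides in
# the dotted key=value style forwarded to `compose` above; `do_sweep=True`
# additionally determines the empirical inference parameters on the
# validation fold after training finishes.
#
# _train(task="Task016_Luna", ov=["exp.fold=0"], do_sweep=True)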
def prepare_case(case_dir: Path, target_dir: Path, df: pd.DataFrame):
    target_data_dir = target_dir / "imagesTr"
    target_label_dir = target_dir / "labelsTr"

    case_id = case_dir.name
    logger.info(f"Processing case {case_id}")
    df = df[df.PatientID == case_id]

    # process data
    img = sitk.ReadImage(str(case_dir / f"{case_id}_ct_scan.nrrd"))
    sitk.WriteImage(img, str(target_data_dir / f"{case_id}.nii.gz"))
    img_arr = sitk.GetArrayFromImage(img)

    # process mask
    final_rois = np.zeros_like(img_arr, dtype=np.uint8)
    mal_labels = {}
    roi_ids = set([ii.split('.')[0].split('_')[-1]
                   for ii in os.listdir(case_dir) if '.nii.gz' in ii])

    rix = 1
    for rid in roi_ids:
        roi_id_paths = [ii for ii in os.listdir(case_dir) if '{}.nii'.format(rid) in ii]
        nodule_ids = [ii.split('_')[2].lstrip("0") for ii in roi_id_paths]
        rater_labels = [df[df.NoduleID == int(ii)].Malignancy.values[0]
                        for ii in nodule_ids]
        # pad to four raters; missing raters count as 0
        rater_labels.extend([0] * (4 - len(rater_labels)))
        mal_label = np.mean([ii for ii in rater_labels if ii > -1])

        roi_rater_list = []
        for rp in roi_id_paths:
            roi = sitk.ReadImage(str(case_dir / rp))
            roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8)
            assert roi_arr.shape == img_arr.shape, [roi_arr.shape, img_arr.shape,
                                                    case_id, roi.GetSpacing()]
            for ix in range(len(img_arr.shape)):
                npt.assert_almost_equal(roi.GetSpacing()[ix], img.GetSpacing()[ix])
            roi_rater_list.append(roi_arr)
        roi_rater_list.extend([np.zeros_like(roi_rater_list[-1])] * (4 - len(roi_id_paths)))

        # majority vote over the (up to) four rater masks
        roi_raters = np.array(roi_rater_list)
        roi_raters = np.mean(roi_raters, axis=0)
        roi_raters[roi_raters < 0.5] = 0
        if np.sum(roi_raters) > 0:
            mal_labels[rix] = mal_label
            final_rois[roi_raters >= 0.5] = rix
            rix += 1
        else:
            # indicate rois suppressed by majority voting of raters
            logger.warning(f'suppressed roi! {roi_id_paths}')

    mask_itk = sitk.GetImageFromArray(final_rois)
    sitk.WriteImage(mask_itk, str(target_label_dir / f"{case_id}.nii.gz"))

    instance_classes = {key: int(item >= 3) for key, item in mal_labels.items()}
    save_json({"instances": instance_classes, "scores": mal_labels},
              target_label_dir / f"{case_id}.json")
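# Worked example of the rater aggregation above (values are made up): with
# rater masks [1, 1, 0, 0] for a voxel, the mean is 0.5, so the voxel survives
# the `>= 0.5` majority vote; with [1, 0, 0, 0] the mean is 0.25 and the voxel
# is zeroed out. A ROI whose voxels are all suppressed this way is dropped
# entirely and logged as a "suppressed roi".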
logger.add(task_data_dir / "prepare.log", level="DEBUG")

meta = {
    "name": "Luna",
    "task": "Task016_Luna",
    "target_class": None,
    "test_labels": False,
    "labels": {"0": "lesion"},
    "modalities": {"0": "CT"},
    "dim": 3,
}
save_json(meta, task_data_dir / "dataset.json")

# prepare data and labels
csv = source_data_dir / "annotations.csv"
convert_data(source_data_dir, target_data_dir, num_processes=num_processes)
df = pd.read_csv(csv, index_col='seriesuid')
create_masks(source_data_dir, target_label_dir, df, num_processes=num_processes)

# generate split
logger.info("Generating luna splits... ")
saved_original_splits = load_json(task_data_dir / "splits.json")
logger.info(
def export_dataset_info(self):
    """
    Export dataset settings (dataset.json for nnunet)
    """
    self.target_dir.mkdir(exist_ok=True, parents=True)

    dataset_info = {}
    dataset_info["name"] = self.data_info.get("name", "unknown")
    dataset_info["description"] = self.data_info.get("description", "unknown")
    dataset_info["reference"] = self.data_info.get("reference", "unknown")
    dataset_info["licence"] = self.data_info.get("licence", "unknown")
    dataset_info["release"] = self.data_info.get("release", "unknown")

    min_size = self.data_info.get("min_size", 0)
    min_vol = self.data_info.get("min_vol", 0)
    dataset_info["prep_info"] = f"min size: {min_size} ; min vol {min_vol}"

    dataset_info["tensorImageSize"] = f"{self.data_info.get('dim', 3)}D"
    # dataset_info["tensorImageSize"] = self.data_info.get("tensorImageSize", "4D")
    dataset_info["modality"] = self.data_info.get("modalities", {})
    if not dataset_info["modality"]:
        logger.error("Did not find any modalities for dataset")

    # +1 for seg classes because of background
    dataset_info["labels"] = {"0": "background"}
    instance_classes = self.data_info.get("labels", {})
    if not instance_classes:
        logger.error("Did not find any labels of dataset")
    for _id, _class in instance_classes.items():
        seg_id = int(_id) + 1
        dataset_info["labels"][str(seg_id)] = _class

    if self.export_stuff:
        stuff_classes = self.data_info.get("labels_stuff", {})
        num_instance_classes = len(instance_classes)
        # copy stuff classes into nnunet dataset.json
        stuff_classes = {
            str(int(key) + num_instance_classes): item
            for key, item in stuff_classes.items() if int(key) > 0
        }
        dataset_info["labels_stuff"] = stuff_classes
        dataset_info["labels"].update(stuff_classes)

    _case_ids = get_case_ids_from_dir(self.label_dir, remove_modality=False)
    case_ids_tr = get_case_ids_from_dir(self.tr_image_dir, remove_modality=True)
    assert len(set(_case_ids).union(case_ids_tr)) == len(_case_ids), \
        "All training images need a label"
    dataset_info["numTraining"] = len(case_ids_tr)

    dataset_info["training"] = [{
        "image": f"./imagesTr/{cid}.nii.gz",
        "label": f"./labelsTr/{cid}.nii.gz",
    } for cid in case_ids_tr]

    if self.ts_image_dir is not None:
        case_ids_ts = get_case_ids_from_dir(self.ts_image_dir, remove_modality=True)
        dataset_info["numTest"] = len(case_ids_ts)
        dataset_info["test"] = [f"./imagesTs/{cid}.nii.gz" for cid in case_ids_ts]
    else:
        dataset_info["numTest"] = 0
        dataset_info["test"] = []

    save_json(dataset_info, self.target_dir / "dataset.json")
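# Illustrative sketch of the exported dataset.json (field values are made up):
# note the +1 label shift so "0" stays reserved for background.
#
# {
#     "name": "Example",
#     "modality": {"0": "CT"},
#     "labels": {"0": "background", "1": "cancer"},
#     "numTraining": 2,
#     "training": [
#         {"image": "./imagesTr/case_000.nii.gz", "label": "./labelsTr/case_000.nii.gz"},
#         ...
#     ],
#     "numTest": 0,
#     "test": []
# }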