def _load_filenames(self):
    """Load preprocessed pair data (input, gt and ROI) in the HDF5 handler.

    For every ``(subject_id, input_filename, gt_filename, roi_filename, metadata)``
    entry of ``self.filename_pairs``, the volume is read slice by slice along
    ``self.slice_axis``, preprocessed with ``self.prepro_transforms`` and written
    to ``self.hdf5_file`` inside the subject's group as the datasets
    ``inputs/<contrast>``, ``gt/<contrast>`` and ``roi/<contrast>``.

    Side effects on the subject group:
      * ``attrs['slices']``: sorted, de-duplicated union of the indices already
        stored and the indices accepted by ``self.slice_filter_fn``.
      * ``attrs['contrast']`` of the group and of its ``inputs``/``gt``/``roi``
        sub-groups: ``<contrast>`` is appended.

    Subjects whose volume yields no slice are skipped (no dataset is written).
    """

    def _append_contrast(attrs, contrast):
        # Append `contrast` to the 'contrast' attribute list held by `attrs`.
        new_attr = list(attrs['contrast']) if 'contrast' in attrs else []
        new_attr.append(contrast)
        attrs.create('contrast', new_attr, dtype=self.dt)

    def _copy_metadata(dset_attrs, slice_metadata, with_bbox=True):
        # Copy data_type and the optional zooms/data_shape/bounding_box entries.
        dset_attrs['data_type'] = slice_metadata['data_type']
        if "zooms" in slice_metadata.keys():
            dset_attrs["zooms"] = slice_metadata['zooms']
        if "data_shape" in slice_metadata.keys():
            dset_attrs["data_shape"] = slice_metadata['data_shape']
        # Check presence first (indexing a missing key raises KeyError) and skip
        # None, which cannot be stored as an HDF5 attribute.
        if (with_bbox and "bounding_box" in slice_metadata.keys()
                and slice_metadata['bounding_box'] is not None):
            dset_attrs["bounding_box"] = slice_metadata['bounding_box']

    for subject_id, input_filename, gt_filename, roi_filename, metadata in self.filename_pairs:
        # Create the subject group, or reuse it when it already exists
        # (e.g. written for another contrast of the same subject).
        subject_key = str(subject_id)
        if subject_key in self.hdf5_file.keys():
            grp = self.hdf5_file[subject_key]
        else:
            grp = self.hdf5_file.create_group(subject_key)

        roi_pair = imed_loader.SegmentationPair(input_filename, roi_filename, metadata=metadata,
                                                slice_axis=self.slice_axis, cache=False,
                                                soft_gt=self.soft_gt)
        seg_pair = imed_loader.SegmentationPair(input_filename, gt_filename, metadata=metadata,
                                                slice_axis=self.slice_axis, cache=False,
                                                soft_gt=self.soft_gt)

        input_data_shape, _ = seg_pair.get_pair_shapes()

        useful_slices = []
        input_volumes = []
        gt_volume = []
        roi_volume = []

        for idx_pair_slice in range(input_data_shape[-1]):
            slice_seg_pair = seg_pair.get_pair_slice(idx_pair_slice)
            self.has_bounding_box = imed_obj_detect.verify_metadata(
                slice_seg_pair, self.has_bounding_box)

            if self.has_bounding_box:
                imed_obj_detect.adjust_transforms(self.prepro_transforms, slice_seg_pair)

            # Keep the index of slices accepted by the filter (e.g. slices with gt)
            if self.slice_filter_fn and self.slice_filter_fn(slice_seg_pair):
                useful_slices.append(idx_pair_slice)

            roi_pair_slice = roi_pair.get_pair_slice(idx_pair_slice)
            slice_seg_pair, roi_pair_slice = imed_transforms.apply_preprocessing_transforms(
                self.prepro_transforms, slice_seg_pair, roi_pair_slice)

            input_volumes.append(slice_seg_pair["input"][0])

            # Handle unlabeled data
            if not len(slice_seg_pair["gt"]):
                gt_volume = []
            else:
                gt_volume.append((slice_seg_pair["gt"][0] * 255).astype(np.uint8) / 255.)

            # Handle data with no ROI provided
            if not len(roi_pair_slice["gt"]):
                roi_volume = []
            else:
                roi_volume.append((roi_pair_slice["gt"][0] * 255).astype(np.uint8) / 255.)

        # Merge with slice indices recorded earlier for this subject. Cast to int
        # and sort: concatenating with an empty list would promote the indices to
        # float, and the iteration order of a bare set is not guaranteed.
        previous_slices = grp.attrs['slices'] if 'slices' in grp.attrs else []
        grp.attrs['slices'] = sorted({int(s) for s in previous_slices} | set(useful_slices))

        # No slice was read: the slice variables are unbound or left over from the
        # previous subject, so nothing below can be trusted.
        if not input_volumes:
            print("list empty")
            continue

        # Metadata is taken from the last slice of the volume
        input_metadata = slice_seg_pair['input_metadata'][0]
        gt_metadata = slice_seg_pair['gt_metadata'][0]
        roi_metadata = roi_pair_slice['input_metadata'][0]
        contrast = input_metadata['contrast']

        # Inputs
        key = "inputs/{}".format(contrast)
        grp.create_dataset(key, data=input_volumes)
        _append_contrast(grp['inputs'].attrs, contrast)
        grp[key].attrs['input_filenames'] = input_metadata['input_filenames']
        _copy_metadata(grp[key].attrs, input_metadata)

        # GT (the gt filenames are read from the input metadata)
        key = "gt/{}".format(contrast)
        grp.create_dataset(key, data=gt_volume)
        _append_contrast(grp['gt'].attrs, contrast)
        grp[key].attrs['gt_filenames'] = input_metadata['gt_filenames']
        _copy_metadata(grp[key].attrs, gt_metadata)

        # ROI (no bounding box is stored for the ROI)
        key = "roi/{}".format(contrast)
        grp.create_dataset(key, data=roi_volume)
        _append_contrast(grp['roi'].attrs, contrast)
        grp[key].attrs['roi_filename'] = roi_metadata['gt_filenames']
        _copy_metadata(grp[key].attrs, roi_metadata, with_bbox=False)

        # Adding contrast to group metadata
        _append_contrast(grp.attrs, contrast)
def _load_filenames(self):
    """Load preprocessed pair data (input, gt and ROI) in the HDF5 handler.

    Opens ``self.path_hdf5`` in append mode for the duration of the call. For every
    ``(subject_id, input_filename, gt_filename, roi_filename, metadata)`` entry
    of ``self.filename_pairs``, the volume is processed slice by slice along
    ``self.slice_axis`` (via ``self._slice_seg_pair``). The result is written to
    the subject's group as the datasets ``inputs/<contrast>``,
    ``gt/<contrast>`` and ``roi/<contrast>``.

    Side effects on the subject group:
      * ``attrs['slices']``: sorted, de-duplicated union of the indices already
        stored and the indices collected in ``useful_slices``.
      * contrast metadata of the group and its sub-groups, through
        ``create_subgrp_metadata`` / ``add_grp_contrast``.

    Subjects for which no input slice was collected are skipped (no dataset is
    written).
    """
    with h5py.File(self.path_hdf5, "a") as hdf5_file:
        for subject_id, input_filename, gt_filename, roi_filename, metadata in self.filename_pairs:
            # Create the subject group, or reuse it when it already exists
            # (e.g. written for another contrast of the same subject).
            subject_key = str(subject_id)
            if subject_key in hdf5_file.keys():
                grp = hdf5_file[subject_key]
            else:
                grp = hdf5_file.create_group(subject_key)

            roi_pair = imed_loader.SegmentationPair(
                input_filename, roi_filename, metadata=metadata, slice_axis=self.slice_axis,
                cache=False, soft_gt=self.soft_gt)
            seg_pair = imed_loader.SegmentationPair(
                input_filename, gt_filename, metadata=metadata, slice_axis=self.slice_axis,
                cache=False, soft_gt=self.soft_gt)

            input_data_shape, _ = seg_pair.get_pair_shapes()

            useful_slices = []
            input_volumes = []
            gt_volume = []
            roi_volume = []

            for idx_pair_slice in range(input_data_shape[-1]):
                # NOTE(review): assumes _slice_seg_pair (defined elsewhere) fills the
                # four lists in place and returns the preprocessed slice pairs.
                slice_seg_pair, roi_pair_slice = self._slice_seg_pair(
                    idx_pair_slice, seg_pair, roi_pair, useful_slices, input_volumes,
                    gt_volume, roi_volume)

            # Merge with slice indices recorded earlier for this subject. Cast to
            # int and sort: concatenating with an empty list would promote the
            # indices to float, and the iteration order of a bare set is not
            # guaranteed.
            previous_slices = grp.attrs['slices'] if 'slices' in grp.attrs else []
            grp.attrs['slices'] = sorted({int(s) for s in previous_slices} | set(useful_slices))

            # Nothing collected: the slice variables may be unbound or left over
            # from the previous subject, so skip before reading their metadata.
            if not input_volumes:
                print("list empty")
                continue

            # Metadata is taken from the last slice of the volume
            input_metadata = slice_seg_pair['input_metadata'][0]
            gt_metadata = slice_seg_pair['gt_metadata'][0]
            roi_metadata = roi_pair_slice['input_metadata'][0]
            contrast = input_metadata['contrast']

            # Inputs
            key = "inputs/{}".format(contrast)
            grp.create_dataset(key, data=input_volumes)
            self.create_subgrp_metadata('inputs', grp, contrast)
            grp[key].attrs['input_filenames'] = input_metadata['input_filenames']
            self.create_metadata(grp, key, input_metadata)

            # GT (the gt filenames are read from the input metadata)
            key = "gt/{}".format(contrast)
            grp.create_dataset(key, data=gt_volume)
            self.create_subgrp_metadata('gt', grp, contrast)
            grp[key].attrs['gt_filenames'] = input_metadata['gt_filenames']
            self.create_metadata(grp, key, gt_metadata)

            # ROI
            key = "roi/{}".format(contrast)
            grp.create_dataset(key, data=roi_volume)
            self.create_subgrp_metadata('roi', grp, contrast)
            grp[key].attrs['roi_filename'] = roi_metadata['gt_filenames']
            self.create_metadata(grp, key, roi_metadata)

            # Adding contrast to group metadata
            self.add_grp_contrast(grp, contrast)
def test_image_orientation():
    """Check that transform + undo transform preserves the image orientation.

    The ground truth of subject ``sub-unf01`` is loaded through the 2D and the
    3D BIDS datasets, along every slice axis. Each sample goes through the
    training transforms and then their undo transforms; in 2D the slices are
    restacked in their original order. The reconstruction is compared (Dice
    >= 0.8) with the ground truth in three orientations: the loader's HWD
    orientation, the closest canonical (RAS) orientation and the orientation
    of the original file.
    """
    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda:" + str(GPU_NUMBER) if cuda_available else "cpu")
    if cuda_available:
        torch.cuda.set_device(device)
        print("Using GPU number {}".format(device))

    train_lst = ['sub-unf01']

    # Key order matters: it is the order in which the transforms are built.
    training_transform_dict = {
        "Resample": {"wspace": 1.5, "hspace": 1, "dspace": 3},
        "CenterCrop": {"size": [176, 128, 160]},
        "NumpyToTensor": {},
        "NormalizeInstance": {"applied_to": ['im']},
    }
    transform_lst, training_undo_transform = imed_transforms.prepare_transforms(
        training_transform_dict)

    model_params = {
        "name": "Modified3DUNet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "length_3D": [176, 128, 160],
        "stride_3D": [176, 128, 160],
        "attention": False,
        "n_filters": 8,
    }
    contrast_params = {"contrast_lst": ['T1w'], "balance": {}}

    for dim in ['2d', '3d']:
        for slice_axis in [0, 1, 2]:
            dataset_kwargs = dict(subject_lst=train_lst,
                                  target_suffix=["_seg-manual"],
                                  contrast_params=contrast_params,
                                  metadata_choice=False,
                                  slice_axis=slice_axis,
                                  transform=transform_lst,
                                  multichannel=False)
            if dim == '2d':
                dataset = imed_loader.BidsDataset(PATH_BIDS, **dataset_kwargs)
                dataset.load_filenames()
            else:
                dataset = imed_loader.Bids3DDataset(PATH_BIDS, model_params=model_params,
                                                    **dataset_kwargs)

            loader = DataLoader(dataset, batch_size=1, shuffle=True, pin_memory=True,
                                collate_fn=imed_loader_utils.imed_collate, num_workers=1)

            input_filename, gt_filename, roi_filename, metadata = dataset.filename_pairs[0]
            seg_pair = imed_loader.SegmentationPair(input_filename, gt_filename,
                                                    metadata=metadata, slice_axis=slice_axis)

            # Reference ground truth in the original, RAS and HWD orientations
            nib_original = nib.load(gt_filename[0])
            input_init = nib_original.get_fdata()
            input_ras = nib.as_closest_canonical(nib_original).get_fdata()
            _, gt = seg_pair.get_pair_data()
            input_hwd = gt[0]

            pred_tmp_lst, z_tmp_lst = [], []
            for batch_idx, batch in enumerate(loader):
                for smp_idx in range(len(batch['gt'])):
                    # Undo the transformations on this sample
                    preds_idx_undo, metadata_idx = training_undo_transform(
                        batch["gt"][smp_idx], batch["gt_metadata"][smp_idx], data_type='gt')

                    if dim == '2d':
                        # Collect the 2D slice and its index for later restacking
                        pred_tmp_lst.append(preds_idx_undo[0])
                        z_tmp_lst.append(int(batch['input_metadata'][smp_idx][0]['slice_index']))

                    fname_ref = metadata_idx[0]['gt_filenames'][0]

                    # 2D: rebuild the volume once, on the last batch. 3D: every sample.
                    last_2d_batch = bool(pred_tmp_lst) and batch_idx == len(loader) - 1
                    if not (last_2d_batch or dim == '3d'):
                        continue

                    nib_ref = nib.load(fname_ref)
                    nib_ref_can = nib.as_closest_canonical(nib_ref)

                    if dim == '2d':
                        n_slices = nib_ref_can.header.get_data_shape()[slice_axis]
                        ordered_slices = [pred_tmp_lst[z_tmp_lst.index(z)] for z in range(n_slices)]
                        arr = np.stack(ordered_slices, axis=-1)
                    else:
                        arr = np.array(preds_idx_undo[0])

                    # Transform + undo transform (e.g. Resample interpolation) is not
                    # lossless, hence the 0.8 Dice tolerance on each orientation.
                    input_hwd_2 = imed_postpro.threshold_predictions(arr)
                    assert imed_metrics.dice_score(input_hwd_2, input_hwd) >= 0.8

                    input_ras_2 = imed_loader_utils.orient_img_ras(input_hwd_2, slice_axis)
                    assert imed_metrics.dice_score(input_ras_2, input_ras) >= 0.8

                    input_init_2 = imed_loader_utils.reorient_image(input_hwd_2, slice_axis,
                                                                    nib_ref, nib_ref_can)
                    assert imed_metrics.dice_score(input_init_2, input_init) >= 0.8

                    # Start a fresh 2D stack
                    pred_tmp_lst, z_tmp_lst = [], []