def test_RandomAffine(im_seg, transform):
    im, seg = im_seg
    metadata_in = [SampleMetadata({}) for _ in im] if isinstance(im, list) else SampleMetadata({})

    # Transform on Numpy
    do_im, metadata_do = transform(im.copy(), metadata_in)
    do_seg, metadata_do = transform(seg.copy(), metadata_do)

    if DEBUGGING and len(im[0].shape) == 2:
        plot_transformed_sample(im[0], do_im[0], ['raw', 'do'])
        plot_transformed_sample(seg[0], do_seg[0], ['raw', 'do'])

    # Undo transform on Numpy
    undo_im, _ = transform.undo_transform(do_im, metadata_do)
    undo_seg, _ = transform.undo_transform(do_seg, metadata_do)

    if DEBUGGING and len(im[0].shape) == 2:
        # TODO: ERROR for image but not for seg.....
        plot_transformed_sample(im[0], undo_im[0], ['raw', 'undo'])
        plot_transformed_sample(seg[0], undo_seg[0], ['raw', 'undo'])

    # Check data type and shape
    _check_dtype(im, [do_im, undo_im])
    _check_shape(im, [do_im, undo_im])
    _check_dtype(seg, [undo_seg, do_seg])
    _check_shape(seg, [undo_seg, do_seg])

    # Loop and check
    for idx, i in enumerate(im):
        # Data consistency
        assert dice_score(undo_seg[idx], seg[idx]) > 0.85
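
# The assertion above uses dice_score as imported by the test module (from
# ivadomed's metrics). For reference only, a minimal NumPy sketch of the Dice
# coefficient on binary masks could look like this; it is an illustrative
# stand-in, not the library's actual implementation.
import numpy as np

def _dice_score_sketch(mask1, mask2, empty_value=1.0):
    """Dice overlap between two binary NumPy arrays (hypothetical helper)."""
    m1, m2 = np.asarray(mask1).astype(bool), np.asarray(mask2).astype(bool)
    denom = m1.sum() + m2.sum()
    if denom == 0:
        # Both masks are empty: treat as perfect agreement (assumption).
        return empty_value
    return 2.0 * np.logical_and(m1, m2).sum() / denom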
def _test_Resample(im_seg, resample_transform, native_resolution, is_2D=False):
    im, seg = im_seg
    metadata_ = SampleMetadata({'zooms': native_resolution,
                                'data_shape': im[0].shape if len(im[0].shape) == 3 else list(im[0].shape) + [1],
                                'data_type': 'im'})
    metadata_in = [metadata_ for _ in im] if isinstance(im, list) else SampleMetadata({})

    # Resample input data
    do_im, do_metadata = resample_transform(sample=im, metadata=metadata_in)
    # Undo Resample on input data
    undo_im, _ = resample_transform.undo_transform(sample=do_im, metadata=do_metadata)

    # Resampler for label data
    resample_transform.interpolation_order = 0
    metadata_ = SampleMetadata({'zooms': native_resolution,
                                'data_shape': seg[0].shape if len(seg[0].shape) == 3 else list(seg[0].shape) + [1],
                                'data_type': 'gt'})
    metadata_in = [metadata_ for _ in seg] if isinstance(seg, list) else SampleMetadata({})

    # Resample label data
    do_seg, do_metadata = resample_transform(sample=seg, metadata=metadata_in)
    # Undo Resample on label data
    undo_seg, _ = resample_transform.undo_transform(sample=do_seg, metadata=do_metadata)

    # Check data type and shape
    _check_dtype(im, [undo_im])
    _check_shape(im, [undo_im])
    _check_dtype(seg, [undo_seg])
    _check_shape(seg, [undo_seg])

    # Check data content and data shape between input data and undo
    for idx, i in enumerate(im):
        # Plot for debugging
        if DEBUGGING and is_2D:
            plot_transformed_sample(im[idx], undo_im[idx], ['raw', 'undo'])
            plot_transformed_sample(seg[idx], undo_seg[idx], ['raw', 'undo'])
        # Data consistency
        assert dice_score(undo_seg[idx], seg[idx]) > 0.8
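
# The Resample transform above rescales the voxel grid according to the
# 'zooms' metadata. As a rough sketch of the arithmetic involved (the actual
# transform also handles interpolation order and the undo step), the expected
# output grid size for a given target spacing can be estimated as below; the
# helper name and the example spacings are illustrative, not library code.
def _expected_resampled_shape(data_shape, native_zooms, target_zooms):
    """Estimate the grid size after resampling (illustrative sketch)."""
    return tuple(int(round(dim * native / target))
                 for dim, native, target in zip(data_shape, native_zooms, target_zooms))

# Example: a 160x160 slice at 0.5x0.5 mm resampled to 1.0x1.0 mm
# should end up close to 80x80 voxels.
assert _expected_resampled_shape((160, 160), (0.5, 0.5), (1.0, 1.0)) == (80, 80)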
def run_eval(self):
    """Stores evaluation results in dictionary

    Returns:
        dict, ndarray: dictionary containing evaluation results, data with each object painted a different color
    """
    dct = {}
    data_gt = self.data_gt.copy()
    data_pred = self.data_pred.copy()
    for n in range(self.n_classes):
        self.data_pred = data_pred[..., n]
        self.data_gt = data_gt[..., n]
        dct['vol_pred_class' + str(n)] = self.get_vol(self.data_pred)
        dct['vol_gt_class' + str(n)] = self.get_vol(self.data_gt)
        dct['rvd_class' + str(n)], dct['avd_class' + str(n)] = self.get_rvd(), self.get_avd()
        dct['dice_class' + str(n)] = imed_metrics.dice_score(self.data_gt, self.data_pred)
        dct['recall_class' + str(n)] = imed_metrics.recall_score(self.data_pred, self.data_gt, err_value=np.nan)
        dct['precision_class' + str(n)] = imed_metrics.precision_score(self.data_pred, self.data_gt, err_value=np.nan)
        dct['specificity_class' + str(n)] = imed_metrics.specificity_score(self.data_pred, self.data_gt, err_value=np.nan)
        dct['n_pred_class' + str(n)], dct['n_gt_class' + str(n)] = self.n_pred[n], self.n_gt[n]
        dct['ltpr_class' + str(n)], _ = self.get_ltpr(class_idx=n)
        dct['lfdr_class' + str(n)] = self.get_lfdr(class_idx=n)
        dct['mse_class' + str(n)] = imed_metrics.mse(self.data_gt, self.data_pred)

        for lb_size, gt_pred in zip(self.label_size_lst[n][0], self.label_size_lst[n][1]):
            suffix = self.size_suffix_lst[int(lb_size) - 1]
            if gt_pred == 'gt':
                dct['ltpr' + suffix + "_class" + str(n)], dct['n' + suffix] = self.get_ltpr(label_size=lb_size,
                                                                                            class_idx=n)
            else:  # gt_pred == 'pred'
                dct['lfdr' + suffix + "_class" + str(n)] = self.get_lfdr(label_size=lb_size, class_idx=n)

    if self.n_classes == 1:
        self.data_painted = np.squeeze(self.data_painted, axis=-1)

    return dct, self.data_painted
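
# run_eval iterates over the last axis (one channel per class) and stores one
# dictionary entry per metric and per class ('dice_class0', 'vol_gt_class1',
# ...). A self-contained sketch of that per-class key/value pattern on dummy
# data, using a plain NumPy Dice for illustration (not the library's
# dice_score), could look like this:
import numpy as np

def _per_class_dice_sketch(gt, pred):
    """Build a {'dice_class<n>': value} dict for multi-class masks (sketch)."""
    dct = {}
    for n in range(gt.shape[-1]):  # last axis indexes the class
        g, p = gt[..., n].astype(bool), pred[..., n].astype(bool)
        denom = g.sum() + p.sum()
        dct['dice_class' + str(n)] = 2.0 * np.logical_and(g, p).sum() / denom if denom else 1.0
    return dct

# Toy two-class volume: identical inputs give a Dice of 1.0 for each class.
toy = np.stack([np.ones((4, 4, 4)), np.zeros((4, 4, 4))], axis=-1)
assert _per_class_dice_sketch(toy, toy) == {'dice_class0': 1.0, 'dice_class1': 1.0}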
def test_image_orientation(download_data_testing_test_files, loader_parameters):
    device = torch.device("cuda:" + str(GPU_ID) if torch.cuda.is_available() else "cpu")
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(device)
        logger.info(f"Using GPU ID {device}")

    bids_df = BidsDataframe(loader_parameters, __tmp_dir__, derivatives=True)

    contrast_params = loader_parameters["contrast_params"]
    target_suffix = loader_parameters["target_suffix"]
    roi_params = loader_parameters["roi_params"]

    train_lst = ['sub-unf01_T1w.nii.gz']

    training_transform_dict = {
        "Resample": {"wspace": 1.5, "hspace": 1, "dspace": 3},
        "CenterCrop": {"size": [176, 128, 160]},
        "NormalizeInstance": {"applied_to": ['im']}
    }

    tranform_lst, training_undo_transform = imed_transforms.prepare_transforms(training_transform_dict)

    model_params = {
        "name": "Modified3DUNet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "length_3D": [176, 128, 160],
        "stride_3D": [176, 128, 160],
        "attention": False,
        "n_filters": 8
    }

    for dim in ['2d', '3d']:
        for slice_axis in [0, 1, 2]:
            if dim == '2d':
                ds = BidsDataset(bids_df=bids_df,
                                 subject_file_lst=train_lst,
                                 target_suffix=target_suffix,
                                 contrast_params=contrast_params,
                                 model_params=model_params,
                                 metadata_choice=False,
                                 slice_axis=slice_axis,
                                 transform=tranform_lst,
                                 multichannel=False)
                ds.load_filenames()
            else:
                ds = Bids3DDataset(bids_df=bids_df,
                                   subject_file_lst=train_lst,
                                   target_suffix=target_suffix,
                                   model_params=model_params,
                                   contrast_params=contrast_params,
                                   metadata_choice=False,
                                   slice_axis=slice_axis,
                                   transform=tranform_lst,
                                   multichannel=False)

            loader = DataLoader(ds, batch_size=1,
                                shuffle=True, pin_memory=True,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=1)

            input_filename, gt_filename, roi_filename, metadata = ds.filename_pairs[0]
            segpair = SegmentationPair(input_filename, gt_filename, metadata=metadata, slice_axis=slice_axis)
            nib_original = nib.load(gt_filename[0])

            # Get image with original, ras and hwd orientations
            input_init = nib_original.get_fdata()
            input_ras = nib.as_closest_canonical(nib_original).get_fdata()
            img, gt = segpair.get_pair_data()
            input_hwd = gt[0]

            pred_tmp_lst, z_tmp_lst = [], []
            for i, batch in enumerate(loader):
                # batch["input_metadata"] = batch["input_metadata"][0]  # Take only metadata from one input
                # batch["gt_metadata"] = batch["gt_metadata"][0]  # Take only metadata from one label

                for smp_idx in range(len(batch['gt'])):
                    # undo transformations
                    if dim == '2d':
                        preds_idx_undo, metadata_idx = training_undo_transform(batch["gt"][smp_idx],
                                                                               batch["gt_metadata"][smp_idx],
                                                                               data_type='gt')

                        # add new sample to pred_tmp_lst
                        pred_tmp_lst.append(preds_idx_undo[0])
                        z_tmp_lst.append(int(batch['input_metadata'][smp_idx][0]['slice_index']))
                    else:
                        preds_idx_undo, metadata_idx = training_undo_transform(batch["gt"][smp_idx],
                                                                               batch["gt_metadata"][smp_idx],
                                                                               data_type='gt')

                    fname_ref = metadata_idx[0]['gt_filenames'][0]

                    if (pred_tmp_lst and i == len(loader) - 1) or dim == '3d':
                        # save the completely processed file as a nii
                        nib_ref = nib.load(fname_ref)
                        nib_ref_can = nib.as_closest_canonical(nib_ref)

                        if dim == '2d':
                            tmp_lst = []
                            for z in range(nib_ref_can.header.get_data_shape()[slice_axis]):
                                tmp_lst.append(pred_tmp_lst[z_tmp_lst.index(z)])
                            arr = np.stack(tmp_lst, axis=-1)
                        else:
                            arr = np.array(preds_idx_undo[0])

                        # Verify image after transform, undo transform and 3D reconstruction.
                        input_hwd_2 = imed_postpro.threshold_predictions(arr)
                        # Some differences are generated by the transform and undo transform
                        # (i.e. Resample interpolation).
                        assert imed_metrics.dice_score(input_hwd_2, input_hwd) >= 0.8
                        input_ras_2 = imed_loader_utils.orient_img_ras(input_hwd_2, slice_axis)
                        assert imed_metrics.dice_score(input_ras_2, input_ras) >= 0.8
                        input_init_2 = imed_loader_utils.reorient_image(input_hwd_2, slice_axis, nib_ref, nib_ref_can)
                        assert imed_metrics.dice_score(input_init_2, input_init) >= 0.8

                        # re-init pred_stack_lst
                        pred_tmp_lst, z_tmp_lst = [], []
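
# The orientation round trip above relies on nibabel's closest-canonical
# (RAS+) reorientation. A small self-contained sketch of that idea,
# independent of the ivadomed loader utilities, using a synthetic image with
# a flipped (LPS-like) affine:
import nibabel as nib
import numpy as np

data = np.random.rand(4, 5, 6)
lps_affine = np.diag([-1.0, -1.0, 1.0, 1.0])  # voxel axes point L, P, S
img_lps = nib.Nifti1Image(data, lps_affine)
img_ras = nib.as_closest_canonical(img_lps)

assert nib.aff2axcodes(img_lps.affine) == ('L', 'P', 'S')
assert nib.aff2axcodes(img_ras.affine) == ('R', 'A', 'S')
# The voxel data is flipped along the first two axes to match the new affine.
assert np.allclose(img_ras.get_fdata(), data[::-1, ::-1, :])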