Ejemplo n.º 1
0
    def __init__(self, dataset,
                 root_dir='~/data/ARVC/',
                 transform=None, limited_load=False,
                 rescale=True,
                 resample=False):
        """Load all labelled ED and ES volumes of one ARVC split into memory.

        For every file of the split whose ``ARVCImage`` has reference labels, the samples
        returned by ``img.ed()`` are stored first, followed by those of ``img.es()``. Every
        stored volume gets a running integer id; ``self._idcs`` holds one
        ``[volume_id, slice_idx]`` row per slice of every stored volume.

        :param dataset: split name; must be in ``get_config('ARVC').datasets`` and is used as
                        key into ``get_arvc_datasets()``.
        :param root_dir: data root (``~`` is expanded); only stored as ``self._root_dir`` —
                         the file names themselves come from ``get_arvc_datasets()``.
        :param transform: stored as ``self.transform``.
        :param limited_load: if True, only the subset returned by ``do_limit_load`` is loaded.
        :param rescale: forwarded to ``ARVCImage``.
        :param resample: forwarded to ``ARVCImage``.
        """
        self._root_dir = os.path.expanduser(root_dir)
        self.transform = transform
        self._resample = resample
        self._rescale = rescale
        self.dta_settings = get_config('ARVC')
        assert dataset in self.dta_settings.datasets
        files_to_load = get_arvc_datasets()[dataset]
        if limited_load:
            files_to_load = do_limit_load(files_to_load)
        images = list()
        references = list()
        ids = list()
        patient_data_idx = {}
        # Per-volume [volume_id, slice_idx] blocks; concatenated once after the loop instead
        # of re-allocating the complete index array with np.vstack for every volume.
        idx_chunks = [np.empty((0, 2), dtype=int)]

        idx = 0
        for filename in tqdm(files_to_load, desc="Load {} set".format(dataset)):

            filename_ref_labels = filename.replace(self.dta_settings.short_axis_dir, self.dta_settings.ref_label_dir)
            img = ARVCImage(filename, filename_ref_labels=filename_ref_labels, rescale=self._rescale,
                            resample=self._resample)
            if not img.has_labels:
                continue
            patient_data_idx[img.patient_id] = []
            # ED samples first, then ES samples. Bound methods are iterated so that img.es()
            # is only called after all ED samples have been consumed (same order as before).
            for get_phase_samples in (img.ed, img.es):
                for sample_img, sample_lbl in get_phase_samples():
                    images.append(sample_img)
                    references.append(sample_lbl)
                    ids.append(idx)
                    patient_data_idx[img.patient_id].append(idx)
                    num_slices = len(sample_lbl['labels'])
                    idx_chunks.append(np.column_stack((np.full(num_slices, idx), np.arange(num_slices))))
                    idx += 1
        self._idcs = np.concatenate(idx_chunks, axis=0).astype(int)
        self._images = images
        self._references = references
        self._ids = ids
        self._patient_data_idx = patient_data_idx
    def __init__(self,
                 src_data_path,
                 cardiac_phase,
                 num_of_thresholds=30,
                 verbose=False,
                 patients=None,
                 mc_dropout=False,
                 dataset="ACDC",
                 dt_config_id=None):
        """Initialise the evaluation state and call ``self._prepare()``.

        :param src_data_path: source data directory; its ``~``-expanded form is used as
                              ``self.save_output_dir``.
        :param cardiac_phase: stored as ``self.cardiac_phase``.
        :param num_of_thresholds: number of rejection steps; one extra measurement is added
                                  (``self.num_of_thresholds = num_of_thresholds + 1``).
        :param verbose: stored as ``self.verbose``.
        :param patients: stored as ``self.patients``.
        :param mc_dropout: selects the uncertainty map type: 'bmap' if True, else 'emap'.
        :param dataset: dataset name passed to ``get_config``.
        :param dt_config_id: stored as ``self.dt_config_id`` (see note below).
        """
        # caller supplied configuration
        self.src_data_path = src_data_path
        self.mc_dropout = mc_dropout
        if mc_dropout:
            self.type_of_map = 'bmap'
        else:
            self.type_of_map = "emap"
        # TODO IMPORTANT: when dt_config_id is not None the detection labels (filtered
        # segmentation errors) constrain which mis-predicted segmentation errors are
        # "corrected" based on the voxel uncertainties.
        self.dt_config_id = dt_config_id
        # number of measurements between rejecting 0 voxels and rejecting all voxels
        self.num_of_thresholds = num_of_thresholds + 1
        self.verbose = verbose
        self.patients = patients
        self.cardiac_phase = cardiac_phase
        # whether or not to crop the image to the ground truth region (with padding=10)
        self.use_cropped = False

        # result containers, populated later
        self.pred_labels = None
        self.umaps = None
        self.x_coverages = None
        self.patient_coverages = None
        self.mean_errors = None
        self.seg_errors = None
        self.patient_dsc = None
        self.patient_hd = None

        self.dta_settings = get_config(dataset)
        # per patient: the optimal coverage-risk curve that could have been achieved
        # (Geifman, "Boosting uncertainty estimation...")
        self.optimal_curve = []
        self.save_output_dir = os.path.expanduser(self.src_data_path)
        self._prepare()
Ejemplo n.º 3
0
import numpy as np
import os
from datasets.ACDC.get_data import load_data
from datasets.data_config import get_config
from datasets.ACDC.detection.distance_transforms import generate_adjusted_dt_map, FilterSegErrors

# Module-level ACDC dataset configuration; get_config is evaluated once, at import time.
DTA_SETTINGS = get_config('ACDC')


def generate_detector_labels(dt_map, seg_errors, apex_base_slices, fname):
    """
    Compute the target areas of an automatic segmentation (test set) and save them to ``fname``.

    The segmentation errors are filtered by ``FilterSegErrors`` (using the distance-transform
    map and the apex/base slice positions). The resulting areas are the targets that the
    "detection model" has to find, i.e. they are only used for supervised learning.
    Presumably the result has shape [num_of_classes, w, h, #slices] per patient study —
    TODO confirm against FilterSegErrors.

    IMPORTANT: the areas differ between automatic segmentations produced by a single prediction
               (dropout=False) and by a Bayesian network using T (we used 10) samples
               (mc_dropout=True).

    :param dt_map: distance-transform map, passed to FilterSegErrors
    :param seg_errors: segmentation errors to be filtered, passed to FilterSegErrors
    :param apex_base_slices: dict 'A': scalar, 'B': scalar
    :param fname: absolute file name of the numpy archive; the labels are stored under key 'dt_ref_labels'
    :return: the filtered error labels that were saved
    """
    error_filter = FilterSegErrors(seg_errors, dt_map, apex_base_slices)
    filtered_labels = error_filter.get_filtered_errors()
    np.savez(fname, dt_ref_labels=filtered_labels)
    return filtered_labels
Ejemplo n.º 4
0
def main():
    """Train a segmentation network, periodically storing checkpoints and learning curves.

    Parses the command line arguments, seeds torch (CPU and CUDA), creates a numpy RandomState
    that is handed to the random augmentations, writes the settings to
    ``<output_directory>/settings.yaml`` and then runs ``args.max_iters`` iterations, each of
    which trains on one training mini-batch and evaluates one validation mini-batch. Learning
    curves and an example validation prediction are sent to a visdom ``Visualizer``. Model and
    losses are saved once more on exit, also after a KeyboardInterrupt.
    """
    # first we obtain the user arguments, set random seeds, make directories, and store the experiment settings.
    args = parse_args()
    # NOTE(review): the original note said "Set resample always to True for ACDC"; presumably
    # get_network_settings overwrites network specific args (e.g. resample) — confirm there.
    args = get_network_settings(args)
    # End - overwriting args
    args.patch_size = tuple(args.patch_size)
    torch.manual_seed(5431232439)
    torch.cuda.manual_seed(5431232439)
    # shared by training and validation augmentations below
    rs = np.random.RandomState(78346)
    os.makedirs(args.output_directory, exist_ok=True)
    saveExperimentSettings(args,
                           path.join(args.output_directory, 'settings.yaml'))
    print(args)
    dta_settings = get_config(args.dataset)

    # we create a trainer
    n_classes = len(dta_settings.tissue_structure_labels)
    n_channels_input = 1

    # pad: input padding required by the network (used for augmentations and visualization)
    trainer, pad = get_trainer(args, n_classes, n_channels_input)

    # we initialize datasets with augmentations.
    training_augmentations = get_train_augmentations(args, rs, pad)
    validation_augmentations = [
        datasets.augmentations.PadInput(pad, args.patch_size),
        datasets.augmentations.RandomCrop(args.patch_size,
                                          input_padding=pad,
                                          rs=rs),
        datasets.augmentations.BlurImage(sigma=0.9),
        datasets.augmentations.ToTensor()
    ]

    training_set, validation_set = get_datasets(
        args, dta_settings, transforms.Compose(training_augmentations),
        transforms.Compose(validation_augmentations))

    # now we create dataloaders
    # Both samplers draw batch_size * max_iters samples with replacement, so each loader
    # yields max_iters mini-batches and the zip() below runs for max_iters iterations.
    tra_sampler = RandomSampler(training_set,
                                replacement=True,
                                num_samples=args.batch_size * args.max_iters)
    val_sampler = RandomSampler(validation_set,
                                replacement=True,
                                num_samples=args.batch_size * args.max_iters)

    data_loader_training = torch.utils.data.DataLoader(
        training_set,
        batch_size=args.batch_size,
        sampler=tra_sampler,
        num_workers=args.number_of_workers,
        collate_fn=None)  # _utils.collate.default_collate

    data_loader_validation = torch.utils.data.DataLoader(
        validation_set,
        batch_size=args.batch_size,
        sampler=val_sampler,
        num_workers=args.number_of_workers,
        collate_fn=None)

    # and finally we initialize something for visualization in visdom
    env_suffix = "f" + str(args.fold) + args.output_directory.split("_")[-1]
    vis = Visualizer(
        'Segmentation{}-{}_{}'.format(args.dataset, args.network, env_suffix),
        args.port, 'Learning curves of fold {}'.format(args.fold),
        ['training', 'validation', 'aleatoric'])
    #
    try:
        for it, (training_batch, validation_batch) in tqdm(
                enumerate(zip(data_loader_training, data_loader_validation)),
                desc='Training',
                total=args.max_iters):

            # store model (whenever the trainer's iteration counter is a multiple of store_model_every)
            if not trainer._train_iter % args.store_model_every:
                trainer.save(args.output_directory)

            # store learning curves
            if not trainer._train_iter % args.store_curves_every:
                trainer.save_losses(args.output_directory)

                # visualize example from validation set (skipped during the first 20 iterations)
                if not trainer._train_iter % args.update_visualizer_every and trainer._train_iter > 20:
                    # first image of the validation batch, batch dim re-added with [None]
                    image = validation_batch['image'][0][None]
                    val_output = trainer.predict(image)
                    prediction = val_output['predictions']
                    reference = validation_batch['reference'][0]
                    val_patient_id = validation_batch['patient_id'][0]

                    image = image.detach().numpy()
                    prediction = prediction.detach().numpy().astype(
                        float)  # .transpose(1, 2, 0)
                    reference = reference.detach().numpy().astype(float)
                    # intensities are shown as square root (**.5), presumably to boost contrast
                    if pad > 0:
                        # Note: image has shape [batch, 1, x, y], we get rid off extra padding in last two dimensions
                        vis.image((image[0, 0, pad:-pad, pad:-pad]**.5),
                                  'padded image {}'.format(val_patient_id), 12)
                    else:
                        vis.image((image[0]**.5),
                                  'image {}'.format(val_patient_id), 11)
                    # label maps divided by 3, presumably to map label ids 0..3 onto [0, 1] — confirm for n_classes != 4
                    vis.image(reference / 3, 'reference', 13)
                    vis.image(prediction / 3, 'prediction',
                              14)  # used log_softmax values
                    if 'aleatoric' in val_output.keys():
                        vis.image(val_output['aleatoric'] / 0.9, 'aleatoric',
                                  15)  #
                    # vis.image((prediction >= 0.5).astype(float), 'binary prediction', 15)
                    # visualize learning curve
                    vis(trainer.current_training_loss,
                        trainer.current_validation_loss,
                        trainer.current_aleatoric_loss)  # plot learning curve

            # train on training mini-batch
            trainer.train(training_batch['image'].to(device),
                          training_batch['reference'].to(device),
                          ignore_label=None
                          if 'ignore_label' not in training_batch.keys() else
                          training_batch['ignore_label'])
            # evaluate on validation mini-batch
            trainer.evaluate(validation_batch['image'].to(device),
                             validation_batch['reference'].to(device),
                             ignore_label=None
                             if 'ignore_label' not in validation_batch.keys()
                             else validation_batch['ignore_label'])

    except KeyboardInterrupt:
        print('interrupted')

    finally:
        # always persist the final model and losses, also after an interrupt
        trainer.save(args.output_directory)
        trainer.save_losses(args.output_directory)
Ejemplo n.º 5
0
def main():
    """Segment complete 4D (cine) volumes patient by patient and optionally save them.

    Loads the trained model ``<experiment_directory>/<checkpoint>.model``, predicts every frame
    yielded by the test-set generator (``args.samples`` times; with more than one sample MC
    dropout is enabled), post-processes each frame with ``VolumeEvaluation.post_processing_only``
    and stacks the frames of one patient into a 4D array. When the frame with
    ``frame_id + 1 == num_of_frames`` has been processed and ``args.save_output`` is set, the 4D
    segmentation is written with ``save_segmentations``.
    """
    # first we obtain the user arguments, set random seeds, make directories, and store the experiment settings.
    args = parse_args()
    use_mc = args.samples > 1

    os.makedirs(args.output_directory, exist_ok=True)
    experiment_settings = loadExperimentSettings(
        path.join(args.experiment_directory, 'settings.yaml'))
    # if we pass args.dataset (to evaluate on dataset different than trained on) use it to set input dir for images to be segmented
    dta_settings = get_config(dataset=args.dataset if args.dataset is not None
                              else experiment_settings.dataset)

    model_file = path.join(args.experiment_directory,
                           str(args.checkpoint) + '.model')
    output_dirs = make_dirs(args.output_directory, use_mc)
    # we create a trainer; the number of classes follows the dataset the model was trained on
    if experiment_settings.dataset != dta_settings.dataset:
        n_classes = len(
            get_config(experiment_settings.dataset).tissue_structure_labels)
    else:
        n_classes = len(dta_settings.tissue_structure_labels)
    n_channels_input = 1

    trainer, pad = get_trainer(experiment_settings,
                               n_classes,
                               n_channels_input,
                               model_file=model_file)
    trainer.evaluate_with_dropout = use_mc
    network = experiment_settings.network
    print("WARNING - Rescaling intensities is set to {}".format(args.rescale))
    if args.dataset is None:
        # TODO TEMPORARY !!!! hard-coded root dir instead of dta_settings.short_axis_dir
        root_dir = os.path.expanduser("~/data/ACDC_SR/")
        testset_generator = acdc_validation_fold_image4d(
            experiment_settings.fold,
            root_dir=root_dir,
            file_suffix="4d_acai.nii.gz",
            resample=experiment_settings.resample,
            patid=args.patid,
            rescale=args.rescale)
    else:
        print("INFO - You passed following arguments")
        print(args)
        # NOTE(review): resample/rescale are fixed here, args.rescale is not used — confirm intended
        testset_generator = get_4dimages_nifti(dta_settings.short_axis_dir,
                                               resample=False,
                                               patid=args.patid,
                                               rescale=True)
    pat_id_saved = None

    for sample in tqdm(testset_generator,
                       desc="Generating 4d segmentation volumes"):
        image, spacing, reference = sample['image'], sample['spacing'], sample['reference']
        pat_id, phase_id, frame_id = sample['patient_id'], sample['cardiac_phase'], sample['frame_id']
        num_of_frames = sample['num_of_frames']
        original_spacing = sample['original_spacing']
        shape_changed = False
        if pat_id_saved is None or pat_id != pat_id_saved:
            # new patient: start an empty 4D volume [frames, z, y, x] with the unpadded image shape
            segmentation4d = np.empty(
                (0, image.shape[0], image.shape[1], image.shape[2]))

        # Adjust the input to the network: U-Net / DRN inputs are fitted to the number of
        # down-sampling steps (shape restored after prediction); other networks with pad > 0
        # get edge padding. image has shape [z, y, x] so only the last two dimensions change.
        if network.startswith("unet"):
            image, shape_changed, yx_padding = fit_unet_image(
                image, num_downsamplings=4)
        elif network.startswith("drn"):
            image, shape_changed, yx_padding = fit_unet_image(
                image, num_downsamplings=3)
        elif pad > 0:
            image = np.pad(
                image, ((0, 0), (pad, pad), (pad, pad)),
                mode="edge")  # "'constant', constant_values=(0,))

        image = image[:, None]  # add extra dim of size 1
        image = torch.from_numpy(image)
        pat_predictions = Predictions()
        with torch.set_grad_enabled(False):
            for _ in range(args.samples):
                output = trainer.predict(image)
                if shape_changed:
                    output = restore_original_size(output, yx_padding)
                soft_probs = output['softmax'].detach().numpy()
                pat_predictions(soft_probs,
                                cardiac_phase_tag=phase_id,
                                pred_logits=None)

        # if mc_dropout is true we compute the Bayesian maps (stddev) otherwise Entropy. Method makes sure
        # that in case we're sampling that pred_probs are averaged over samples.
        # 14-8-2019 IMPORTANT: We are computing Bayesian maps with MEAN stddev over classes! Before MAX
        # (the uncertainty maps are not used further in this script)
        pred_probs, _ = pat_predictions.get_predictions(
            compute_uncertainty=True, mc_dropout=use_mc, agg_func="mean")
        segmentation = np.argmax(pred_probs, axis=1)
        eval_obj = VolumeEvaluation(pat_id,
                                    segmentation,
                                    reference,
                                    voxel_spacing=spacing,
                                    num_of_classes=n_classes,
                                    mc_dropout=use_mc,
                                    cardiac_phase=phase_id)

        # presumably the same post-processing fast_evaluate applies (keep largest connected
        # components) — confirm in VolumeEvaluation
        eval_obj.post_processing_only()
        segmentation = eval_obj.pred_labels
        segmentation4d = np.vstack((segmentation4d, segmentation[None]))
        del output

        # the last frame of the patient has been processed: write the complete 4D volume
        if args.save_output and num_of_frames == frame_id + 1:
            do_resample = bool(experiment_settings.resample or original_spacing[-1] < 1.)
            # IMPORTANT: if frame_id is None (e.g. when processing 4D data) then filename is without suffix frame_id
            save_segmentations(segmentation4d,
                               pat_id,
                               output_dirs,
                               spacing,
                               do_resample,
                               new_spacing=original_spacing,
                               frame_id=None)

        pat_id_saved = pat_id
Ejemplo n.º 6
0
def get_arvc_datasets(split=(0.50, 0.25, 0.25), rs=None,
                      ids_only=False) -> dict:
    """
    Return the absolute short-axis MRI file names of the ARVC training, validation and test sets.

    IMPORTANT: if the split file ``dta_settings.split_file`` already exists it is loaded and
               ``split`` / ``rs`` are ignored; otherwise a new random split is created and written
               to that file (e.g. in data root dir ~/data/ARVC/).

    :param split: fractions (training, validation, test); must sum to 1 (compared with a
                  floating-point tolerance). Training and validation get ``int(fraction * n)``
                  files, the test set gets the remaining files.
    :param rs: ``np.random.RandomState`` used to permute the files; defaults to seed 78346.
    :param ids_only: currently unused.
    :return: dict with keys 'training', 'validation' and 'test', each a list of absolute file
             names. NOTE: the validation list is truncated to its first 25 entries.
    :raises ValueError: if a new split must be created and ``split`` does not sum to 1.
    """
    def create_absolute_file_names(rel_file_list, src_path) -> list:
        return [
            os.path.join(src_path, rel_fname) for rel_fname in rel_file_list
        ]

    def get_dataset_files(all_files, file_ids) -> list:
        return [all_files[fid] for fid in file_ids]

    dta_settings = get_config('ARVC')
    if os.path.isfile(dta_settings.split_file):
        # load existing splits; the file only contains lists of file names, so safe_load suffices
        with open(dta_settings.split_file, 'r') as fp:
            split_config = yaml.safe_load(fp)
        training_ids = split_config['training']
        validation_ids = split_config['validation']
        test_ids = split_config['test']
        print("INFO - Load existing split file {}".format(
            dta_settings.split_file))
    else:
        # create new split. Exact float equality would reject valid splits such as
        # (0.7, 0.2, 0.1), whose float sum is 0.9999999999999999.
        if not np.isclose(sum(split), 1.):
            raise ValueError("ERROR - split fractions must sum to 1, got {}".format(split))
        # get a list with the short-axis image files that we have in total (e.g. in ~/data/ARVC/images/*.nii.gz)
        search_suffix = "*" + dta_settings.img_file_ext
        search_mask_img = os.path.expanduser(
            os.path.join(dta_settings.short_axis_dir, search_suffix))
        # we make a list of relative file names (root data dir is omitted)
        files_to_load = [
            os.path.basename(abs_fname)
            for abs_fname in get_image_file_list(search_mask_img)
        ]
        num_of_patients = len(files_to_load)
        # permute the list of all files, we will separate the permuted list into train, validation and test sets
        if rs is None:
            rs = np.random.RandomState(78346)
        ids = rs.permutation(num_of_patients)
        # create three lists of files
        training_ids = get_dataset_files(files_to_load,
                                         ids[:int(split[0] * num_of_patients)])
        c_size = len(training_ids)
        validation_ids = get_dataset_files(
            files_to_load,
            ids[c_size:c_size + int(split[1] * num_of_patients)])
        c_size += len(validation_ids)
        test_ids = get_dataset_files(files_to_load, ids[c_size:])

        # write split configuration
        split_config = {
            'training': training_ids,
            'validation': validation_ids,
            'test': test_ids
        }
        print("INFO - Write split file {}".format(dta_settings.split_file))
        with open(dta_settings.split_file, 'w') as fp:
            yaml.dump(split_config, fp)

    return {
        'training':
        create_absolute_file_names(training_ids, dta_settings.short_axis_dir),
        # NOTE: only the first 25 validation files are returned
        'validation':
        create_absolute_file_names(validation_ids,
                                   dta_settings.short_axis_dir)[:25],
        'test':
        create_absolute_file_names(test_ids, dta_settings.short_axis_dir)
    }
Ejemplo n.º 7
0
def arvc_get_evaluate_set(dataset, limited_load=False, resample=False, rescale=True, patid=None, all_frames=False):
    """
    Yield complete ARVC volumes one at a time for validation and testing.

    Unlike the ARVCDataSet object, which returns slices used for training (and hence transformed,
    e.g. rotation, mirroring etc.), this generator yields untransformed volumes as dicts with the
    keys 'image', 'reference', 'spacing', 'direction', 'origin', 'frame_id', 'cardiac_phase',
    'structures', 'ignore_label', 'original_spacing', 'patient_id' and 'number_of_frames'.

    :param dataset: split name, must be in ``get_config('ARVC').datasets``
    :param limited_load: if True, only the subset returned by ``do_limit_load`` is processed
    :param resample: forwarded to ARVCImage
    :param rescale: forwarded to ARVCImage
    :param patid: one patient ID to process (string); every file whose path contains it is used
    :param all_frames: boolean, if TRUE, we get all time frames of a patient ignoring reference labels;
                       otherwise only the ED and ES samples of volumes that have labels are yielded
    :return: generator of sample dicts
    :raises ValueError: if ``patid`` matches none of the files of ``dataset``
    """
    def to_sample(sample_img, sample_lbl):
        # image geometry/meta data comes from the image dict, annotation data from the label dict
        return {'image': sample_img['image'],
                'reference': sample_lbl['labels'],
                'spacing': sample_img['spacing'],
                'direction': sample_img['direction'],
                'origin': sample_img['origin'],
                'frame_id': sample_img['frame_id'],
                'cardiac_phase': sample_img['cardiac_phase'],
                'structures': sample_lbl['structures'],
                'ignore_label': sample_lbl['ignore_label'],
                'original_spacing': sample_lbl['original_spacing'],
                'patient_id': sample_lbl['patient_id'],
                'number_of_frames': sample_img['number_of_frames']}

    dta_settings = get_config('ARVC')
    assert dataset in dta_settings.datasets
    files_to_load = get_arvc_datasets()[dataset]
    if patid is not None:
        # NOTE(review): substring match on the full path, e.g. "X_1" also matches "X_10" — confirm acceptable
        files_to_load = [fname for fname in files_to_load if patid in fname]
        if not files_to_load:
            raise ValueError("ERROR - {} is not a valid patient id".format(patid))
    if limited_load:
        files_to_load = do_limit_load(files_to_load)

    for filename in tqdm(files_to_load, desc="Load {} set".format(dataset)):
        filename_ref_labels = filename.replace(dta_settings.short_axis_dir, dta_settings.ref_label_dir)
        img = ARVCImage(filename, filename_ref_labels=filename_ref_labels, rescale=rescale,
                        resample=resample)
        # Bound methods are collected (not called) so every sample source is only invoked
        # once the previous one has been exhausted: all frames, or ED followed by ES.
        if all_frames:
            sample_sources = (img.all,)
        elif img.has_labels:
            sample_sources = (img.ed, img.es)
        else:
            sample_sources = ()
        for get_samples in sample_sources:
            for sample_img, sample_lbl in get_samples():
                yield to_sample(sample_img, sample_lbl)
Ejemplo n.º 8
0
import copy
import numpy as np
import os
import glob
import yaml
import shutil

from tqdm import tqdm
from torch.utils.data import Dataset
from datasets.data_config import get_config
from datasets.common import read_nifty, apply_2d_zoom_3d, get_arvc_datasets

# Module-level ARVC dataset configuration, evaluated once at import time (read e.g. by do_limit_load).
arvc_data_settings = get_config('ARVC')


def do_limit_load(p_dataset_filelist):
    """Return the leading ``arvc_data_settings.limited_load_max`` entries of the file list."""
    max_files = arvc_data_settings.limited_load_max
    limited_list = p_dataset_filelist[:max_files]
    return limited_list


def load_data(data_config, limited_load=False, search_mask="*.nii.gz", **kwargs):
    '''
    Collect the sorted short-axis image files of an ARVC data configuration.

    :param data_config: configuration object with attribute ``short_axis_dir`` (``~`` is expanded)
    :param limited_load: not used in the visible part of this function
    :param search_mask: glob pattern joined to ``data_config.short_axis_dir``
    :param kwargs: optional ``load_ref_labels`` (default False) and ``resample`` (default False)
    :raises ValueError: if no file matches the expanded search path
    :return: None in the visible code (see NOTE at the end)
    '''
    load_ref_labels = kwargs.get('load_ref_labels', False)
    resample = kwargs.get('resample', False)
    search_mask_img = os.path.expanduser(os.path.join(data_config.short_axis_dir, search_mask))
    files_to_load = sorted(glob.glob(search_mask_img))
    if not files_to_load:
        # report the complete path that was searched, not only the file mask
        raise ValueError("ERROR - ARVC data - Can't find any files to load in {}".format(search_mask_img))
    # NOTE(review): the body appears to end here although load_ref_labels/resample/files_to_load
    # are unused and nothing is returned — this looks truncated; confirm against the full source.
Ejemplo n.º 9
0
def main():
    """Evaluate a trained segmentation model on the ARVC dataset, patient by patient.

    Loads ``<experiment_directory>/<checkpoint>.model``; if the model was trained on another
    dataset (transfer learning) the number of classes of that dataset is used and references /
    ignore labels are translated with ``prepare_transfer_learning``. Every volume yielded by the
    test-set generator is predicted ``args.samples`` times (MC dropout if > 1) and post-processed.
    Without ``args.all_frames`` the volume is also evaluated (``fast_evaluate`` with Hausdorff
    distance) and the metrics are collected in an ``ARVCTestResult``. Results per patient are
    accumulated in ``result_obj`` and passed to ``check_save`` when the patient changes and once
    more after the loop. With ``args.save_results`` the metrics are saved to
    ``results_f<fold>_<num_patients>[_mc].npz`` in the output directory.
    """
    # first we obtain the user arguments, set random seeds, make directories, and store the experiment settings.
    args = parse_args()
    if args.samples > 1:
        use_mc = True
        type_of_map = "bmap"
        res_suffix = "_mc.npz"
    else:
        use_mc = False
        type_of_map = "emap"
        res_suffix = ".npz"

    os.makedirs(args.output_directory, exist_ok=True)
    experiment_settings = loadExperimentSettings(path.join(args.experiment_directory, 'settings.yaml'))
    dta_settings = get_config("ARVC")

    model_file = path.join(args.experiment_directory, str(args.checkpoint) + '.model')
    output_dirs = make_dirs(args.output_directory, use_mc)
    # we create a trainer
    n_classes = len(dta_settings.tissue_structure_labels)
    n_channels_input = 1
    transfer_learn = False
    # IMPORTANT!!! if necessary enable transfer learning settings
    if experiment_settings.dataset != dta_settings.dataset:
        n_classes = len(get_config(experiment_settings.dataset).tissue_structure_labels)
        transfer_learn = True
        print("INFO - transfer learning: trained on nclasses {}".format(n_classes))

    trainer, pad = get_trainer(experiment_settings, n_classes, n_channels_input, model_file=model_file)
    trainer.evaluate_with_dropout = use_mc
    test_results = ARVCTestResult()
    # if patid is not None e.g. "NL256100_1" then we do a single evaluation
    testset_generator = get_test_set_generator(args, experiment_settings, dta_settings, patid=args.patid, all_frames=args.all_frames)
    # we're evaluating patient by patient. one patient can max have 4 volumes if all tissue structures are
    # annotated in separate phases.
    pat_saved, c_phase_saved, sample_saved, phase_results, result_obj = None, None, None, None, None
    for sample in testset_generator:
        image, reference = sample['image'], sample['reference']
        pat_id, phase_id, frame_id = sample['patient_id'], sample['cardiac_phase'], sample['frame_id']
        spacing, original_spacing, direction = sample['spacing'], sample['original_spacing'], sample['direction']

        # a new patient starts: hand the accumulated results of the previous patient to check_save
        if pat_saved != pat_id:
            check_save(pat_saved, sample_saved, result_obj, args, experiment_settings, output_dirs, type_of_map)
        result_obj = prepare_result_obj(pat_saved, sample, result_obj, experiment_settings)
        # get ignore_labels (numpy array of shape [n_classes]. Add batch dim in front with None
        ignore_labels, merge_results, phase_results = prepare_evaluation(sample, pat_saved, c_phase_saved,
                                                                         phase_results)
        if transfer_learn:
            reference, ignore_labels = prepare_transfer_learning(n_classes, reference, dta_settings.cls_translate,
                                                                 ignore_labels)
        if pad > 0:
            # image has shape [z, y, x] so pad last two dimensions
            image = np.pad(image, ((0,0), (pad, pad), (pad, pad)), mode="edge") # "'constant', constant_values=(0,))

        image = image[:, None]  # add extra dim of size 1
        image = torch.from_numpy(image)
        pat_predictions = Predictions()
        with torch.set_grad_enabled(False):
            for s in range(args.samples):
                output = trainer.predict(image)
                if ignore_labels is not None:
                    # ignore_labels if not None is a binary vector of size n_classes (target dataset)
                    pred_mask = get_loss_mask(output['softmax'], ignore_labels[None])
                    soft_probs = output['softmax'].detach().numpy() * pred_mask.detach().numpy()
                else:
                    soft_probs = output['softmax'].detach().numpy()
                # only the aleatoric map of the last sample is kept
                aleatoric = None if 'aleatoric' not in output.keys() else np.squeeze(output['aleatoric'])
                pat_predictions(soft_probs, cardiac_phase_tag=phase_id, pred_logits=None)
                # torch.cuda.empty_cache()

        pred_probs, uncertainties = pat_predictions.get_predictions(compute_uncertainty=True, mc_dropout=use_mc)
        segmentation = np.argmax(pred_probs, axis=1)

        eval_obj = VolumeEvaluation(pat_id, segmentation, reference,
                                    voxel_spacing=spacing, num_of_classes=n_classes,
                                    mc_dropout=use_mc, cardiac_phase=phase_id, ignore_labels=ignore_labels)
        # all frames: no (complete) references available, so only post-process the prediction
        if args.all_frames:
            eval_obj.post_processing_only()
        else:
            eval_obj.fast_evaluate(compute_hd=True)
            phase_results = process_volume_results(test_results, pat_id, phase_id, merge_results,
                                                   phase_results)
        # IMPORTANT: in fast_evaluate we post process the predictions only keeping the largest connected components
        segmentation = eval_obj.pred_labels

        result_obj = process_volume(result_obj, sample, segmentation, uncertainties, aleatoric=aleatoric)
        if transfer_learn and not args.all_frames:
            print("{}: RV/LV {:.2f} {:.2f}".format(eval_obj.patient_id, eval_obj.dice[1], eval_obj.dice[3]))
        del output
        # save patient and phase we just processed, we need this in order to know whether or not to merge the results
        pat_saved, c_phase_saved, sample_saved = pat_id, phase_id, sample

    if not args.all_frames:
        test_results.show_results(transfer_learning=transfer_learn)
    # flush the results of the last patient
    check_save(pat_saved, sample_saved, result_obj, args, experiment_settings, output_dirs, type_of_map)

    if args.save_results:
        fname = path.join(args.output_directory, "results_f" + str(args.fold) + "_{}".format(len(test_results.pat_ids)) +
                          res_suffix)
        test_results.save(filename=fname)
        print("INFO - performance results saved to {}".format(fname))
Ejemplo n.º 10
0
def _fit_image_to_network(image, pad, network):
    """Adapt a [z, y, x] image to the input requirements of ``network``.

    Networks whose name starts with "unet" are passed through
    ``fit_unet_image`` with 4 downsamplings, "drn" networks with 3
    (presumably resizing y/x so they survive that many downsamplings --
    see ``fit_unet_image``). Any other network gets ``pad`` voxels of edge
    padding on both sides of the last two dimensions when ``pad > 0``.

    Returns:
        (image, shape_changed, yx_padding). When ``shape_changed`` is True the
        caller must undo the change on the network output with
        ``restore_original_size(output, yx_padding)``; otherwise
        ``yx_padding`` is None.
    """
    if network.startswith("unet"):
        return fit_unet_image(image, num_downsamplings=4)
    if network.startswith("drn"):
        return fit_unet_image(image, num_downsamplings=3)
    if pad > 0:
        # image has shape [z, y, x] so pad last two dimensions
        image = np.pad(image, ((0, 0), (pad, pad), (pad, pad)), mode="edge")
    return image, False, None


def _needs_resample(experiment_settings, original_spacing):
    """Return True when saved volumes must be resampled back to ``original_spacing``.

    This is the case when the experiment resampled its inputs, or when the
    last spacing component of the original volume is below 1.
    """
    return bool(experiment_settings.resample or original_spacing[-1] < 1.)


def main():
    """Evaluate a trained segmentation model on a test set.

    Parses the CLI arguments, loads the experiment settings and the model
    checkpoint, then for every test volume predicts (``args.samples`` passes;
    more than one pass means MC dropout), evaluates against the reference,
    optionally saves segmentations / uncertainty maps / probabilities, and
    finally prints (and optionally saves) the aggregated results.
    """
    args = parse_args()
    # More than one sample => MC dropout: Bayesian uncertainty maps ("bmap")
    # instead of entropy maps ("emap"), and a distinct results-file suffix.
    use_mc = args.samples > 1
    type_of_map = "bmap" if use_mc else "emap"
    res_suffix = "_mc.npz" if use_mc else ".npz"

    print("INFO - Evaluating with super resolution = {}".format(args.super_resolution))
    os.makedirs(args.output_directory, exist_ok=True)
    experiment_settings = loadExperimentSettings(
        path.join(args.experiment_directory, 'settings.yaml'))
    dta_settings = get_config(experiment_settings.dataset)

    model_file = path.join(args.experiment_directory, str(args.checkpoint) + '.model')
    output_dirs = make_dirs(args.output_directory, use_mc)
    n_classes = len(dta_settings.tissue_structure_labels)
    n_channels_input = 1

    trainer, pad = get_trainer(experiment_settings, n_classes, n_channels_input,
                               model_file=model_file)
    trainer.evaluate_with_dropout = use_mc
    test_results = ACDCTestResult()
    testset_generator = get_test_set_generator(args, experiment_settings, dta_settings,
                                               patid=args.patid)

    for sample in testset_generator:
        image, spacing, reference = sample['image'], sample['spacing'], sample['reference']
        pat_id, phase_id = sample['patient_id'], sample['cardiac_phase']
        original_spacing = sample['original_spacing']

        image, shape_changed, yx_padding = _fit_image_to_network(
            image, pad, experiment_settings.network)
        image = torch.from_numpy(image[:, None])  # add extra dim of size 1

        pat_predictions = Predictions()
        with torch.set_grad_enabled(False):
            for _ in range(args.samples):
                output = trainer.predict(image)
                if shape_changed:
                    # undo the padding/cropping applied by fit_unet_image
                    output = restore_original_size(output, yx_padding)
                soft_probs = output['softmax'].detach().numpy()
                # Aleatoric maps are only kept for single-pass (non-MC) inference.
                if use_mc or 'aleatoric' not in output.keys():
                    aleatoric = None
                else:
                    aleatoric = np.squeeze(output['aleatoric'])
                pat_predictions(soft_probs, cardiac_phase_tag=phase_id, pred_logits=None)

        # With mc_dropout the Bayesian maps (stddev) are computed, otherwise entropy;
        # when sampling, pred_probs are averaged over the samples.
        # 14-8-2019 IMPORTANT: Bayesian maps use the MEAN stddev over classes (before: MAX).
        pred_probs, uncertainties = pat_predictions.get_predictions(
            compute_uncertainty=True, mc_dropout=use_mc, agg_func="mean")
        segmentation = np.argmax(pred_probs, axis=1)
        eval_obj = VolumeEvaluation(pat_id, segmentation, reference,
                                    voxel_spacing=spacing, num_of_classes=n_classes,
                                    mc_dropout=use_mc, cardiac_phase=phase_id)
        eval_obj.fast_evaluate(compute_hd=True)
        # IMPORTANT: fast_evaluate post-processes the predictions (keeping only the
        # largest connected components), so its labels are used from here on.
        segmentation = eval_obj.pred_labels
        test_results(eval_obj.dice, hd=eval_obj.hd, cardiac_phase_tag=phase_id,
                     pat_id=pat_id, hd95=eval_obj.hd95, assd=eval_obj.assd)
        eval_obj.show_results()

        if args.save_output or args.save_probs:
            do_resample = _needs_resample(experiment_settings, original_spacing)
            if args.save_output:
                save_pat_objects(pat_id, phase_id, segmentation, None, uncertainties,
                                 aleatoric, type_of_map, spacing, output_dirs,
                                 new_spacing=original_spacing, do_resample=do_resample,
                                 pred_probs=pred_probs if args.save_probs else None)
            else:
                # Work-around to save predicted probabilities only
                save_pred_probs(pat_id, phase_id, pred_probs, spacing, output_dirs,
                                new_spacing=original_spacing, do_resample=do_resample,
                                direction=None, origin=None)

        del output

    test_results.show_results()
    test_results.excel_string()

    if args.save_results:
        fname = path.join(
            args.output_directory,
            "results_f" + str(experiment_settings.fold) +
            "_{}".format(len(test_results.pat_ids)) + res_suffix)
        test_results.save(filename=fname)
        print("INFO - performance results saved to {}".format(fname))