def add_classes_in_slice_info(self):
        """
        This speeds up oversampling of foreground classes during training.
        :return:
        """
        p = Pool(default_num_threads)

        # if there is more than one data_identifier (different branches) then this code will run for all of them if
        # they start with the same string. Not problematic, but not pretty
        stages = [
            join(self.preprocessed_output_folder,
                 self.data_identifier + "_stage%d" % i)
            for i in range(len(self.plans_per_stage))
        ]

        for s in stages:
            print(s.split("/")[-1])
            list_of_npz_files = subfiles(s, True, None, ".npz", True)
            list_of_pkl_files = [i[:-4] + ".pkl" for i in list_of_npz_files]
            all_classes = []
            for pk in list_of_pkl_files:
                props = load_pickle(pk)
                all_classes_tmp = np.array(props['classes'])
                all_classes.append(all_classes_tmp[all_classes_tmp >= 0])
            p.map(add_classes_in_slice_info,
                  zip(list_of_npz_files, list_of_pkl_files, all_classes))
        p.close()
        p.join()
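
        # Note: the add_classes_in_slice_info handed to p.map above is the module-level helper of the
        # same name, not this method. A rough sketch of what such a per-case worker could look like
        # (hypothetical, not the actual nnU-Net implementation): it records which classes occur in each
        # case so that foreground oversampling during training can pick suitable samples without
        # rescanning the data:
        #
        #   def worker(args):
        #       npz_file, pkl_file, classes = args
        #       seg = np.load(npz_file)['data'][-1]   # last channel holds the segmentation
        #       props = load_pickle(pkl_file)
        #       props['classes_present'] = [c for c in classes if np.any(seg == c)]  # hypothetical key
        #       save_pickle(props, pkl_file)
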
    def load_my_plans(self):
        self.plans = load_pickle(self.plans_fname)

        self.plans_per_stage = self.plans['plans_per_stage']
        self.dataset_properties = self.plans['dataset_properties']
        self.transpose_forward = self.plans['transpose_forward']
        self.transpose_backward = self.plans['transpose_backward']

    def load_pretrained_plans(self):
        # keep this dataset's number of classes; everything else is taken over from the pretrained model's plans
        classes = self.plans['num_classes']
        self.plans = load_pickle(self.pretrained_model_plans_file)
        self.plans['num_classes'] = classes
        self.transpose_forward = self.plans['transpose_forward']
        self.preprocessor_name = self.plans['preprocessor_name']
        self.plans_per_stage = self.plans['plans_per_stage']
        self.plans['data_identifier'] = self.data_identifier
        self.save_my_plans()
        print(self.plans['plans_per_stage'])
Example #4
def restore_model(pkl_file, checkpoint=None, train=False):
    """
    This is a utility function to load any nnUNet trainer from a pkl file. It will recursively search
    nnunet.training.network_training for the module that contains the trainer and instantiate it with the
    arguments saved in the pkl file. If checkpoint is specified, it will furthermore load the checkpoint
    file in train/test mode (as specified by train).
    The pkl file required here is the one that is saved automatically when calling nnUNetTrainer.save_checkpoint.
    :param pkl_file:
    :param checkpoint:
    :param train:
    :return:
    """
    info = load_pickle(pkl_file)
    init = info['init']
    name = info['name']
    search_in = os.path.join(nnunet.__path__[0], "training", "network_training")
    tr = recursive_find_trainer([search_in], name, current_module="nnunet.training.network_training")

    if tr is None:
        """
        Fabian only. This will trigger searching for trainer classes in other repositories as well
        """
        try:
            import meddec
            search_in = os.path.join(meddec.__path__[0], "model_training")
            tr = recursive_find_trainer([search_in], name, current_module="meddec.model_training")
        except ImportError:
            pass

    if tr is None:
        raise RuntimeError("Could not find the model trainer specified in checkpoint in nnunet.trainig.network_training. If it "
                           "is not located there, please move it or change the code of restore_model. Your model "
                           "trainer can be located in any directory within nnunet.trainig.network_training (search is recursive)."
                           "\nDebug info: \ncheckpoint file: %s\nName of trainer: %s " % (checkpoint, name))
    assert issubclass(tr, nnUNetTrainer), "The network trainer was found but is not a subclass of nnUNetTrainer. " \
                                          "Please make it so!"

    if len(init) == 7:
        print("warning: this model seems to have been saved with a previous version of nnUNet. Attempting to load it "
              "anyways. Expect the unexpected.")
        print("manually editing init args...")
        init = [init[i] for i in range(len(init)) if i != 2]

    # init[0] is the plans file. This argument needs to be replaced because the original plans file may not exist
    # anymore.
    trainer = tr(*init)
    trainer.process_plans(info['plans'])
    if checkpoint is not None:
        trainer.load_checkpoint(checkpoint, train)
    return trainer
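

# Minimal usage sketch for restore_model (illustration only; the directory below is a hypothetical
# placeholder). The .pkl written by nnUNetTrainer.save_checkpoint sits next to the .model checkpoint.
if __name__ == "__main__":
    model_pkl = "/path/to/fold_0/model_final_checkpoint.model.pkl"  # hypothetical path
    model_ckpt = "/path/to/fold_0/model_final_checkpoint.model"     # hypothetical path
    trainer = restore_model(model_pkl, checkpoint=model_ckpt, train=False)
    print("restored trainer:", type(trainer).__name__)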
Example #5
    def do_split(self):
        """
        This is a suggested split routine for the case where your dataset is a dictionary (my personal standard)
        :return:
        """
        splits_file = os.path.join(self.dataset_directory, 'splits_final.pkl')
        if not os.path.isfile(splits_file):
            self.print_to_log_file("Creating new split...")
            splits = []
            all_keys_sorted = np.sort(list(self.dataset.keys()))
            kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
            for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
                train_keys = np.array(all_keys_sorted)[train_idx]
                test_keys = np.array(all_keys_sorted)[test_idx]
                splits.append(OrderedDict())
                splits[-1]['train'] = train_keys
                splits[-1]['val'] = test_keys
            save_pickle(splits, splits_file)

        splits = load_pickle(splits_file)

        if self.fold == "all":
            tr_keys = val_keys = list(self.dataset.keys())
        else:
            tr_keys = splits[self.fold]['train']
            val_keys = splits[self.fold]['val']

        tr_keys.sort()
        val_keys.sort()

        self.dataset_tr = OrderedDict()
        for i in tr_keys:
            self.dataset_tr[i] = self.dataset[i]

        self.dataset_val = OrderedDict()
        for i in val_keys:
            self.dataset_val[i] = self.dataset[i]
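
        # For reference, the structure written to splits_final.pkl above is a list with one entry per
        # fold, each an OrderedDict of case identifiers, e.g. (illustrative values only):
        #
        #   [OrderedDict([('train', array(['case_00', 'case_02', ...])),
        #                 ('val',   array(['case_01', ...]))]),
        #    ...]  # five folds in total with the KFold settings used above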
Example #6
    def validate(self,
                 do_mirroring: bool = True,
                 use_sliding_window: bool = True,
                 step_size: float = 0.5,
                 save_softmax: bool = True,
                 use_gaussian: bool = True,
                 overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw',
                 debug: bool = False,
                 all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None,
                 run_postprocessing_on_folds: bool = True):
        if isinstance(self.network, DDP):
            net = self.network.module
        else:
            net = self.network
        ds = net.do_ds
        net.do_ds = False

        current_mode = self.network.training
        self.network.eval()

        assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
        if self.dataset_val is None:
            self.load_dataset()
            self.do_split()

        if segmentation_export_kwargs is None:
            if 'segmentation_export_params' in self.plans.keys():
                force_separate_z = self.plans['segmentation_export_params'][
                    'force_separate_z']
                interpolation_order = self.plans['segmentation_export_params'][
                    'interpolation_order']
                interpolation_order_z = self.plans[
                    'segmentation_export_params']['interpolation_order_z']
            else:
                force_separate_z = None
                interpolation_order = 1
                interpolation_order_z = 0
        else:
            force_separate_z = segmentation_export_kwargs['force_separate_z']
            interpolation_order = segmentation_export_kwargs[
                'interpolation_order']
            interpolation_order_z = segmentation_export_kwargs[
                'interpolation_order_z']

        # predictions as they come from the network go here
        output_folder = join(self.output_folder, validation_folder_name)
        maybe_mkdir_p(output_folder)
        # this is for debug purposes
        my_input_args = {
            'do_mirroring': do_mirroring,
            'use_sliding_window': use_sliding_window,
            'step_size': step_size,
            'save_softmax': save_softmax,
            'use_gaussian': use_gaussian,
            'overwrite': overwrite,
            'validation_folder_name': validation_folder_name,
            'debug': debug,
            'all_in_gpu': all_in_gpu,
            'segmentation_export_kwargs': segmentation_export_kwargs,
        }
        save_json(my_input_args, join(output_folder, "validation_args.json"))

        if do_mirroring:
            if not self.data_aug_params['do_mirror']:
                raise RuntimeError(
                    "We did not train with mirroring so you cannot do inference with mirroring enabled"
                )
            mirror_axes = self.data_aug_params['mirror_axes']
        else:
            mirror_axes = ()

        pred_gt_tuples = []

        export_pool = Pool(default_num_threads)
        results = []

        all_keys = list(self.dataset_val.keys())
        my_keys = all_keys[self.local_rank::dist.get_world_size()]
        # we iterate over all_keys (not just my_keys) because local rank 0 needs pred_gt_tuples for
        # all cases when it runs the evaluation; each process only predicts the cases in its my_keys
        for k in all_keys:
            properties = load_pickle(self.dataset[k]['properties_file'])
            fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
            pred_gt_tuples.append([
                join(output_folder, fname + ".nii.gz"),
                join(self.gt_niftis_folder, fname + ".nii.gz")
            ])
            if k in my_keys:
                if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
                        (save_softmax and not isfile(join(output_folder, fname + ".npz"))):
                    data = np.load(self.dataset[k]['data_file'])['data']

                    print(k, data.shape)
                    data[-1][data[-1] == -1] = 0

                    softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(
                        data[:-1],
                        do_mirroring=do_mirroring,
                        mirror_axes=mirror_axes,
                        use_sliding_window=use_sliding_window,
                        step_size=step_size,
                        use_gaussian=use_gaussian,
                        all_in_gpu=all_in_gpu,
                        mixed_precision=self.fp16)[1]

                    softmax_pred = softmax_pred.transpose(
                        [0] + [i + 1 for i in self.transpose_backward])

                    if save_softmax:
                        softmax_fname = join(output_folder, fname + ".npz")
                    else:
                        softmax_fname = None
                    """There is a problem with python process communication that prevents us from communicating obejcts
                    larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
                    communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long
                    enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
                    patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
                    then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
                    filename or np.ndarray and will handle this automatically"""
                    if np.prod(softmax_pred.shape) > (
                            2e9 / 4 * 0.85):  # *0.85 just to be save
                        np.save(join(output_folder, fname + ".npy"),
                                softmax_pred)
                        softmax_pred = join(output_folder, fname + ".npy")

                    results.append(
                        export_pool.starmap_async(
                            save_segmentation_nifti_from_softmax,
                            ((softmax_pred,
                              join(output_folder,
                                   fname + ".nii.gz"), properties,
                              interpolation_order, self.regions_class_order,
                              None, None, softmax_fname, None,
                              force_separate_z, interpolation_order_z), )))

        _ = [i.get() for i in results]
        self.print_to_log_file("finished prediction")

        distributed.barrier()

        if self.local_rank == 0:
            # evaluate raw predictions
            self.print_to_log_file("evaluation of raw predictions")
            task = self.dataset_directory.split("/")[-1]
            job_name = self.experiment_name
            _ = aggregate_scores(pred_gt_tuples,
                                 labels=list(range(self.num_classes)),
                                 json_output_file=join(output_folder,
                                                       "summary.json"),
                                 json_name=job_name + " val tiled %s" %
                                 (str(use_sliding_window)),
                                 json_author="Fabian",
                                 json_task=task,
                                 num_threads=default_num_threads)

            if run_postprocessing_on_folds:
                # in the old nnunet we would stop here. Now we add a postprocessing. This postprocessing can remove everything
                # except the largest connected component for each class. To see if this improves results, we do this for all
                # classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
                # have this applied during inference as well
                self.print_to_log_file("determining postprocessing")
                determine_postprocessing(
                    self.output_folder,
                    self.gt_niftis_folder,
                    validation_folder_name,
                    final_subf_name=validation_folder_name + "_postprocessed",
                    debug=debug)
                # after this the final predictions for the validation set can be found in validation_folder_name_base + "_postprocessed"
                # They are always in that folder, even if no postprocessing was applied!

            # determining postprocessing on a per-fold basis may be OK for this fold but what if another fold finds another
            # postprocessing to be better? In that case we need to consolidate. At the time the consolidation is done
            # we won't know what self.gt_niftis_folder was, so we copy all the niftis into a separate folder now to
            # be used later
            gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
            maybe_mkdir_p(gt_nifti_folder)
            for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
                success = False
                attempts = 0
                e = None
                while not success and attempts < 10:
                    try:
                        shutil.copy(f, gt_nifti_folder)
                        success = True
                    except OSError as err:
                        # keep a reference to the error: Python unbinds the name from "except ... as e" once the block exits
                        e = err
                        attempts += 1
                        sleep(1)
                if not success:
                    print("Could not copy gt nifti file %s into folder %s" %
                          (f, gt_nifti_folder))
                    if e is not None:
                        raise e

        self.network.train(current_mode)
        net.do_ds = ds
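

# Illustrative helper (not part of nnU-Net) capturing the >2 GB workaround used in validate() above:
# arrays too large to pickle safely across multiprocessing pipes are saved to .npy and the file path
# is handed to the worker instead, which loads (and later deletes) the file itself. The function name
# and the element threshold are assumptions made for this sketch.
import numpy as np

def maybe_offload_softmax(softmax, npy_path, max_elements=int(2e9 / 4 * 0.85)):
    # return the array itself if it is small enough to send directly, otherwise save it and return the path
    if np.prod(softmax.shape) > max_elements:
        np.save(npy_path, softmax)
        return npy_path
    return softmax
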
from pathlib import Path

import numpy as np

from batchgenerators.utilities.file_and_folder_operations import load_pickle, save_pickle

# Paths to the plans file to read and to write. Note that both point to the same file here, so the
# edited plans overwrite the original in place; point new_plans_file elsewhere to keep the original.
plans_file = Path.home() / "Pictures/nnUNet/nnUNet_preprocessed/KiTS2019/nnUNetPlans_plans_3D.pkl"
new_plans_file = Path.home() / "Pictures/nnUNet/nnUNet_preprocessed/KiTS2019/nnUNetPlans_plans_3D.pkl"

plans = load_pickle(plans_file)

# manually override the patch size of stage 0; the commented-out lines show how other plan entries
# (e.g. the class setup) could be edited in the same way
plans['plans_per_stage'][0]['patch_size'] = np.array([128, 128, 80])
plans['plans_per_stage'][0]['patch_size_org'] = np.array([160, 128, 80])
#plans['num_classes'] = 1
#plans['all_classes'] = np.array([2])

save_pickle(plans, new_plans_file)