Exemplo n.º 1
0
def aggregate_scores(test_ref_pairs,
                     evaluator=NiftiEvaluator,
                     labels=None,
                     use_label=None,
                     nanmean=True,
                     json_output_file=None,
                     json_name="",
                     json_description="",
                     json_author="Fabian",
                     json_task="",
                     num_threads=2,
                     **metric_kwargs):
    """
    Evaluate (test, reference) pairs in a worker pool and aggregate the scores.

    test = predicted image

    :param test_ref_pairs: iterable of pairs; element 0 of each pair is the
        test (predicted) item, element 1 the reference. Consumed once.
    :param evaluator: evaluator class or instance. A class is instantiated
        without arguments before use.
    :param labels: must be a dict of int-> str or a list of int; forwarded to
        evaluator.set_labels when not None
    :param use_label: forwarded to evaluator.set_use_label when not None
    :param nanmean: if True the per-label scores are averaged with np.nanmean,
        otherwise with np.mean
    :param json_output_file: if not None, the results are written to this file
        via save_json
    :param json_name: stored under "name" in the json output
    :param json_description: stored under "description" in the json output
    :param json_author: stored under "author" in the json output
    :param json_task: stored under "task" in the json output
    :param num_threads: size of the Pool that runs run_evaluation
    :param metric_kwargs: passed along with every pair to run_evaluation
    :return: OrderedDict with the keys "all" (list of the per-pair results) and
        "mean" (label -> score name -> averaged float)
    """

    # isinstance also recognises classes created by a metaclass (e.g. ABCMeta),
    # which an exact `type(...) == type` comparison would miss.
    if isinstance(evaluator, type):
        evaluator = evaluator()

    if labels is not None:
        evaluator.set_labels(labels)

    if use_label is not None:
        evaluator.set_use_label(use_label)

    all_scores = OrderedDict()
    all_scores["all"] = []
    all_scores["mean"] = OrderedDict()

    # one job tuple per pair: (test, reference, evaluator, metric_kwargs)
    jobs = [(pair[0], pair[1], evaluator, metric_kwargs)
            for pair in test_ref_pairs]

    p = Pool(num_threads)
    try:
        all_res = p.map(run_evaluation, jobs)
    finally:
        # release the worker processes even if an evaluation raised
        p.close()
        p.join()

    for result in all_res:
        all_scores["all"].append(result)

        # collect the score lists that are averaged below
        for label, score_dict in result.items():
            if label in ("test", "reference"):
                continue
            label_scores = all_scores["mean"].setdefault(label, OrderedDict())
            for score, value in score_dict.items():
                label_scores.setdefault(score, []).append(value)

    # replace every collected list by its (nan)mean
    mean_fn = np.nanmean if nanmean else np.mean
    for label_scores in all_scores["mean"].values():
        for score, values in label_scores.items():
            label_scores[score] = float(mean_fn(values))

    # save to file if desired
    # we create a hopefully unique id by hashing the entire output dictionary
    if json_output_file is not None:
        json_dict = OrderedDict()
        json_dict["name"] = json_name
        json_dict["description"] = json_description
        timestamp = datetime.today()
        json_dict["timestamp"] = str(timestamp)
        json_dict["task"] = json_task
        json_dict["author"] = json_author
        json_dict["results"] = all_scores
        # the id is computed before it is inserted, so it hashes everything else
        json_dict["id"] = hashlib.md5(
            json.dumps(json_dict).encode("utf-8")).hexdigest()[:12]
        save_json(json_dict, json_output_file)

    return all_scores
Exemplo n.º 2
0
    def validate(self,
                 do_mirroring: bool = True,
                 use_sliding_window: bool = True,
                 step_size: float = 0.5,
                 save_softmax: bool = True,
                 use_gaussian: bool = True,
                 overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw',
                 debug: bool = False,
                 all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None,
                 run_postprocessing_on_folds: bool = True):
        """
        Predict this rank's share of the validation cases, export them and let
        local rank 0 evaluate the exported predictions.

        The validation keys are split across ranks with a stride of
        dist.get_world_size(). After a barrier, local rank 0 calls
        aggregate_scores, optionally determine_postprocessing, and copies the
        ground truth niftis into join(self.output_folder_base, "gt_niftis").
        net.do_ds and the network's training mode are restored at the end.

        :param do_mirroring: forwarded to the prediction call; raises if
            self.data_aug_params['do_mirror'] is falsy
        :param use_sliding_window: forwarded to the prediction call
        :param step_size: forwarded to the prediction call
        :param save_softmax: if True a "<case>.npz" filename is handed to the
            export function
        :param use_gaussian: forwarded to the prediction call
        :param overwrite: if False, cases whose output files already exist are
            not predicted again
        :param validation_folder_name: subfolder of self.output_folder that
            receives the predictions
        :param debug: forwarded to determine_postprocessing
        :param all_in_gpu: forwarded to the prediction call
        :param segmentation_export_kwargs: dict with 'force_separate_z',
            'interpolation_order' and 'interpolation_order_z'. If None the
            values come from self.plans['segmentation_export_params'] or, when
            that key is absent, default to None / 1 / 0
        :param run_postprocessing_on_folds: if True, local rank 0 runs
            determine_postprocessing after the evaluation
        :raises RuntimeError: if do_mirroring is set but training did not use
            mirroring
        :raises OSError: if a ground truth nifti could not be copied after 10
            attempts
        """
        # Check the precondition before touching the network so that a failure
        # does not leave do_ds / eval mode modified.
        assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"

        if isinstance(self.network, DDP):
            net = self.network.module
        else:
            net = self.network
        ds = net.do_ds
        net.do_ds = False

        current_mode = self.network.training
        self.network.eval()

        if self.dataset_val is None:
            self.load_dataset()
            self.do_split()

        if segmentation_export_kwargs is None:
            if 'segmentation_export_params' in self.plans.keys():
                export_params = self.plans['segmentation_export_params']
                force_separate_z = export_params['force_separate_z']
                interpolation_order = export_params['interpolation_order']
                interpolation_order_z = export_params['interpolation_order_z']
            else:
                force_separate_z = None
                interpolation_order = 1
                interpolation_order_z = 0
        else:
            force_separate_z = segmentation_export_kwargs['force_separate_z']
            interpolation_order = segmentation_export_kwargs[
                'interpolation_order']
            interpolation_order_z = segmentation_export_kwargs[
                'interpolation_order_z']

        # predictions as they come from the network go here
        output_folder = join(self.output_folder, validation_folder_name)
        maybe_mkdir_p(output_folder)
        # this is for debug purposes
        my_input_args = {
            'do_mirroring': do_mirroring,
            'use_sliding_window': use_sliding_window,
            'step_size': step_size,
            'save_softmax': save_softmax,
            'use_gaussian': use_gaussian,
            'overwrite': overwrite,
            'validation_folder_name': validation_folder_name,
            'debug': debug,
            'all_in_gpu': all_in_gpu,
            'segmentation_export_kwargs': segmentation_export_kwargs,
        }
        save_json(my_input_args, join(output_folder, "validation_args.json"))

        if do_mirroring:
            if not self.data_aug_params['do_mirror']:
                raise RuntimeError(
                    "We did not train with mirroring so you cannot do inference with mirroring enabled"
                )
            mirror_axes = self.data_aug_params['mirror_axes']
        else:
            mirror_axes = ()

        pred_gt_tuples = []

        export_pool = Pool(default_num_threads)
        results = []

        all_keys = list(self.dataset_val.keys())
        # the cases this rank has to predict
        my_keys = set(all_keys[self.local_rank::dist.get_world_size()])
        try:
            # We iterate over all_keys (not just my_keys) because the
            # evaluation, which is done by local rank 0, needs the
            # pred_gt_tuples of all cases. Prediction is restricted to my_keys.
            for k in all_keys:
                properties = load_pickle(self.dataset[k]['properties_file'])
                # drops the last 12 characters of the file name
                # (presumably a "_0000.nii.gz"-style suffix - TODO confirm)
                fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
                nifti_out = join(output_folder, fname + ".nii.gz")
                pred_gt_tuples.append(
                    [nifti_out,
                     join(self.gt_niftis_folder, fname + ".nii.gz")])

                if k not in my_keys:
                    continue

                needs_prediction = overwrite or (not isfile(nifti_out)) or \
                    (save_softmax and not isfile(join(output_folder, fname + ".npz")))
                if not needs_prediction:
                    continue

                data = np.load(self.dataset[k]['data_file'])['data']

                print(k, data.shape)
                # the last channel is not fed to the network (see data[:-1]
                # below); its -1 entries are set to 0
                data[-1][data[-1] == -1] = 0

                softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(
                    data[:-1],
                    do_mirroring=do_mirroring,
                    mirror_axes=mirror_axes,
                    use_sliding_window=use_sliding_window,
                    step_size=step_size,
                    use_gaussian=use_gaussian,
                    all_in_gpu=all_in_gpu,
                    mixed_precision=self.fp16)[1]

                softmax_pred = softmax_pred.transpose(
                    [0] + [i + 1 for i in self.transpose_backward])

                if save_softmax:
                    softmax_fname = join(output_folder, fname + ".npz")
                else:
                    softmax_fname = None

                # There is a problem with python process communication that
                # prevents us from communicating objects larger than 2 GB
                # between processes (basically when the length of the pickle
                # string that will be sent is communicated by the
                # multiprocessing.Pipe object then the placeholder (%i I think)
                # does not allow for long enough strings (lol). This could be
                # fixed by changing i to l (for long) but that would require
                # manually patching system python code. We circumvent that
                # problem here by saving softmax_pred to a npy file that will
                # then be read (and finally deleted) by the Process.
                # save_segmentation_nifti_from_softmax can take either
                # filename or np.ndarray and will handle this automatically
                if np.prod(softmax_pred.shape) > (
                        2e9 / 4 * 0.85):  # *0.85 just to be safe
                    np.save(join(output_folder, fname + ".npy"),
                            softmax_pred)
                    softmax_pred = join(output_folder, fname + ".npy")

                results.append(
                    export_pool.starmap_async(
                        save_segmentation_nifti_from_softmax,
                        ((softmax_pred, nifti_out, properties,
                          interpolation_order, self.regions_class_order,
                          None, None, softmax_fname, None,
                          force_separate_z, interpolation_order_z), )))

            # wait for all exports of this rank to finish
            _ = [i.get() for i in results]
        finally:
            # do not leak the export worker processes
            export_pool.close()
            export_pool.join()
        self.print_to_log_file("finished prediction")

        # all ranks must have exported their cases before rank 0 evaluates
        distributed.barrier()

        if self.local_rank == 0:
            # evaluate raw predictions
            self.print_to_log_file("evaluation of raw predictions")
            task = self.dataset_directory.split("/")[-1]
            job_name = self.experiment_name
            _ = aggregate_scores(pred_gt_tuples,
                                 labels=list(range(self.num_classes)),
                                 json_output_file=join(output_folder,
                                                       "summary.json"),
                                 json_name=job_name + " val tiled %s" %
                                 (str(use_sliding_window)),
                                 json_author="Fabian",
                                 json_task=task,
                                 num_threads=default_num_threads)

            if run_postprocessing_on_folds:
                # in the old nnunet we would stop here. Now we add a postprocessing. This postprocessing can remove everything
                # except the largest connected component for each class. To see if this improves results, we do this for all
                # classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
                # have this applied during inference as well
                self.print_to_log_file("determining postprocessing")
                determine_postprocessing(
                    self.output_folder,
                    self.gt_niftis_folder,
                    validation_folder_name,
                    final_subf_name=validation_folder_name + "_postprocessed",
                    debug=debug)
                # after this the final predictions for the validation set can be found in validation_folder_name_base + "_postprocessed"
                # They are always in that folder, even if no postprocessing was applied!

            # determining postprocessing on a per-fold basis may be OK for this fold but what if another fold finds another
            # postprocessing to be better? In this case we need to consolidate. At the time the consolidation is going to be
            # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
            # be used later
            gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
            maybe_mkdir_p(gt_nifti_folder)
            for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
                # The name bound by `except ... as e` is unbound again when the
                # handler exits, so the error is kept in a separate variable.
                last_error = None
                for _attempt in range(10):
                    try:
                        shutil.copy(f, gt_nifti_folder)
                    except OSError as e:
                        last_error = e
                        sleep(1)
                    else:
                        last_error = None
                        break
                if last_error is not None:
                    print("Could not copy gt nifti file %s into folder %s" %
                          (f, gt_nifti_folder))
                    raise last_error

        self.network.train(current_mode)
        net.do_ds = ds