Code example #1
def main():
    args = get_args()
    log_file = args.log_file
    preprocessed_data_dir = args.preprocessed_data_dir
    output_folder = args.postprocessed_data_dir
    ground_truths = args.label_data_dir
    output_dtype = dtype_map[args.output_dtype]
    num_threads_nifti_save = args.num_threads_nifti_save
    all_in_gpu = "None"
    force_separate_z = None
    interp_order = 3
    interp_order_z = 0

    os.makedirs(output_folder, exist_ok=True)

    # Load necessary metadata.
    print("Loading necessary metadata...")
    with open(os.path.join(preprocessed_data_dir, "preprocessed_files.pkl"),
              "rb") as f:
        preprocessed_files = pickle.load(f)
    dictionaries = []
    for preprocessed_file in preprocessed_files:
        with open(
                os.path.join(preprocessed_data_dir,
                             preprocessed_file + ".pkl"), "rb") as f:
            dct = pickle.load(f)[1]
            dictionaries.append(dct)

    # Load predictions from loadgen accuracy log.
    print("Loading loadgen accuracy log...")
    predictions = load_loadgen_log(log_file, output_dtype, dictionaries)

    # Save predictions
    # This runs in multiprocess
    print("Running postprocessing with multiple threads...")
    save_predictions_MLPerf(predictions, output_folder, preprocessed_files,
                            dictionaries, num_threads_nifti_save, all_in_gpu,
                            force_separate_z, interp_order, interp_order_z)

    # Run evaluation
    print("Running evaluation...")
    evaluate_regions(output_folder, ground_truths, get_brats_regions())

    # Load evaluation summary
    print("Loading evaluation summary...")
    with open(os.path.join(output_folder, "summary.csv")) as f:
        for line in f:
            words = line.split(",")
            if words[0] == "mean":
                whole = float(words[1])
                core = float(words[2])
                enhancing = float(words[3])
                mean = (whole + core + enhancing) / 3
                print(
                    "Accuracy: mean = {:.5f}, whole tumor = {:.4f}, tumor core = {:.4f}, enhancing tumor = {:.4f}"
                    .format(mean, whole, core, enhancing))
                break

    print("Done!")
Code example #2
    def validate(self,
                 do_mirroring: bool = True,
                 use_sliding_window: bool = True,
                 step_size: float = 0.5,
                 save_softmax: bool = True,
                 use_gaussian: bool = True,
                 overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw',
                 debug: bool = False,
                 all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None,
                 run_postprocessing_on_folds: bool = True):
        super().validate(
            do_mirroring=do_mirroring,
            use_sliding_window=use_sliding_window,
            step_size=step_size,
            save_softmax=save_softmax,
            use_gaussian=use_gaussian,
            overwrite=overwrite,
            validation_folder_name=validation_folder_name,
            debug=debug,
            all_in_gpu=all_in_gpu,
            segmentation_export_kwargs=segmentation_export_kwargs,
            run_postprocessing_on_folds=run_postprocessing_on_folds)
        # run BraTS-specific region evaluation on top of the standard validation
        output_folder = join(self.output_folder, validation_folder_name)
        evaluate_regions(output_folder, self.gt_niftis_folder, self.regions)
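Both validate() overrides here, like code example #1, score predictions with nnU-Net's region-based evaluation. For reference, get_brats_regions() maps region names to the label values merged into each binary mask; the sketch below follows the BraTS label convention nnU-Net uses (stated as an assumption, not a verbatim copy of the library).

# Sketch of nnU-Net's BraTS regions: whole tumor = all tumor labels,
# tumor core = labels 2 and 3, enhancing tumor = label 3 only.
def get_brats_regions():
    return {
        "whole tumor": (1, 2, 3),
        "tumor core": (2, 3),
        "enhancing tumor": (3,),
    }

# evaluate_regions(pred_folder, gt_folder, get_brats_regions()) then writes
# per-case and aggregate Dice scores to pred_folder/summary.csv.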
Code example #3
    def validate(self,
                 do_mirroring: bool = True,
                 use_sliding_window: bool = True,
                 step_size: float = 0.5,
                 save_softmax: bool = True,
                 use_gaussian: bool = True,
                 overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw',
                 debug: bool = False,
                 all_in_gpu: bool = False,
                 force_separate_z: bool = None,
                 interpolation_order: int = 3,
                 interpolation_order_z=0):
        """
        disable nnunet postprocessing. this would just waste computation time and does not benefit brats

        !!!We run this with use_sliding_window=False per default (see on_epoch_end). This triggers fully convolutional
        inference. THIS ONLY MAKES SENSE WHEN TRAINING ON FULL IMAGES! Make sure use_sliding_window=True when running
        with default patch size (128x128x128)!!!

        per default this does not use test time data augmentation (mirroring). The reference implementation, however,
        does. I disabled it here because this eats up a lot of computation time

        """
        validation_start = time()

        current_mode = self.network.training
        self.network.eval()

        assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
        if self.dataset_val is None:
            self.load_dataset()
            self.do_split()

        # predictions as they come from the network go here
        output_folder = join(self.output_folder, validation_folder_name)
        maybe_mkdir_p(output_folder)

        # this is for debug purposes
        my_input_args = {
            'do_mirroring': do_mirroring,
            'use_sliding_window': use_sliding_window,
            'step_size': step_size,
            'save_softmax': save_softmax,
            'use_gaussian': use_gaussian,
            'overwrite': overwrite,
            'validation_folder_name': validation_folder_name,
            'debug': debug,
            'all_in_gpu': all_in_gpu,
            'force_separate_z': force_separate_z,
            'interpolation_order': interpolation_order,
            'interpolation_order_z': interpolation_order_z,
        }
        save_json(my_input_args, join(output_folder, "validation_args.json"))

        if do_mirroring:
            if not self.data_aug_params['do_mirror']:
                raise RuntimeError(
                    "We did not train with mirroring so you cannot do inference with mirroring enabled"
                )
            mirror_axes = self.data_aug_params['mirror_axes']
        else:
            mirror_axes = ()

        export_pool = Pool(default_num_threads)
        results = []

        for k in self.dataset_val.keys():
            properties = load_pickle(self.dataset[k]['properties_file'])
            # strip the "_0000.nii.gz" suffix (12 characters) to recover the case identifier
            fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
            if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
                    (save_softmax and not isfile(join(output_folder, fname + ".npz"))):
                data = np.load(self.dataset[k]['data_file'])['data']

                # the last channel of 'data' holds the ground-truth segmentation,
                # so only the image channels data[:-1] are fed to the network
                softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(
                    data[:-1],
                    do_mirroring=do_mirroring,
                    mirror_axes=mirror_axes,
                    use_sliding_window=use_sliding_window,
                    step_size=step_size,
                    use_gaussian=use_gaussian,
                    all_in_gpu=all_in_gpu,
                    verbose=False,
                    mixed_precision=self.fp16)[1]

                # this does not do anything in brats -> remove this line
                # softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward])

                if save_softmax:
                    softmax_fname = join(output_folder, fname + ".npz")
                else:
                    softmax_fname = None

                results.append(
                    export_pool.starmap_async(
                        save_segmentation_nifti_from_softmax,
                        ((softmax_pred, join(output_folder, fname + ".nii.gz"),
                          properties, interpolation_order, None, None, None,
                          softmax_fname, None, force_separate_z,
                          interpolation_order_z, False), )))

        _ = [i.get() for i in results]
        self.print_to_log_file("finished prediction")

        # evaluate raw predictions
        self.print_to_log_file("evaluation of raw predictions")

        # this writes a csv file into output_folder
        evaluate_regions(output_folder, self.gt_niftis_folder,
                         self.evaluation_regions)
        csv_file = np.loadtxt(join(output_folder, 'summary.csv'),
                              skiprows=1,
                              dtype=str,
                              delimiter=',')[:, 1:]

        # these are the values that are computed with np.nanmean aggregation
        whole, core, enhancing = csv_file[-4, :].astype(float)

        # do some cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.network.train(current_mode)
        validation_end = time()
        self.print_to_log_file('Running the validation took %f seconds' %
                               (validation_end - validation_start))
        self.print_to_log_file(
            '(the time needed for validation is included in the total epoch time!)'
        )

        return whole, core, enhancing
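Per region, evaluate_regions reduces a labeled segmentation to a binary mask and computes a Dice score per case; rows like the one parsed above aggregate these scores with np.nanmean. A minimal self-contained sketch of the per-case computation (not nnU-Net's exact implementation):

import numpy as np

def region_dice(pred, gt, region_labels):
    """Dice between the binary masks formed by merging all labels in region_labels."""
    pred_mask = np.isin(pred, region_labels)
    gt_mask = np.isin(gt, region_labels)
    denom = pred_mask.sum() + gt_mask.sum()
    if denom == 0:
        return float("nan")  # region absent in both; later aggregated with np.nanmean
    return 2.0 * np.logical_and(pred_mask, gt_mask).sum() / denom

# e.g. tumor core (labels 2 and 3) on two label volumes:
# region_dice(pred_seg, gt_seg, (2, 3))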
Code example #4
File: Task082_BraTS_2020.py Project: zzsnow/nnUNet
def collect_and_prepare(base_dir, num_processes=12, clean=False):
    """
    Collect all cv_niftis, compute BraTS metrics, compute enhancing-tumor thresholds and summarize in csv.
    :param base_dir: experiment root containing the nnUNetTrainer* result folders
    :return:
    """
    out = join(base_dir, 'cv_results')
    out_pp = join(base_dir, 'cv_results_pp')
    experiments = subfolders(base_dir, join=False, prefix='nnUNetTrainer')
    regions = get_brats_regions()
    gt_dir = join(base_dir, 'gt_niftis')
    replace_with = 2

    failed = []
    successful = []
    for e in experiments:
        print(e)
        try:
            o = join(out, e)
            o_p = join(out_pp, e)
            maybe_mkdir_p(o)
            maybe_mkdir_p(o_p)
            collect_cv_niftis(join(base_dir, e), o)
            if clean or not isfile(join(o, 'summary.csv')):
                evaluate_regions(o, gt_dir, regions, num_processes)
            if clean or not isfile(join(o_p, 'threshold.pkl')):
                determine_brats_postprocessing(o, gt_dir, o_p, num_processes, thresholds=list(np.arange(0, 760, 10)), replace_with=replace_with)
            if clean or not isfile(join(o_p, 'summary.csv')):
                evaluate_regions(o_p, gt_dir, regions, num_processes)
            successful.append(e)
        except Exception as ex:
            print("\nERROR\n", e, ex, "\n")
            failed.append(e)

    # we are interested in the mean (nan is 1) column
    with open(join(base_dir, 'cv_summary.csv'), 'w') as f:
        f.write('name,whole,core,enh,mean\n')
        for e in successful:
            expected_nopp = join(out, e, 'summary.csv')
            expected_pp = join(out_pp, e, 'summary.csv')
            if isfile(expected_nopp):
                res = np.loadtxt(expected_nopp, dtype=str, skiprows=0, delimiter=',')[-2]
                as_numeric = [float(i) for i in res[1:]]
                f.write(e + '_noPP,')
                f.write("%0.4f," % as_numeric[0])
                f.write("%0.4f," % as_numeric[1])
                f.write("%0.4f," % as_numeric[2])
                f.write("%0.4f\n" % np.mean(as_numeric))
            if isfile(expected_pp):
                res = np.loadtxt(expected_pp, dtype=str, skiprows=0, delimiter=',')[-2]
                as_numeric = [float(i) for i in res[1:]]
                f.write(e + '_PP,')
                f.write("%0.4f," % as_numeric[0])
                f.write("%0.4f," % as_numeric[1])
                f.write("%0.4f," % as_numeric[2])
                f.write("%0.4f\n" % np.mean(as_numeric))

    # this just crawls the folders and evaluates what it finds
    with open(join(base_dir, 'cv_summary2.csv'), 'w') as f:
        for folder in ['cv_results', 'cv_results_pp']:
            for ex in subdirs(join(base_dir, folder), join=False):
                print(folder, ex)
                expected = join(base_dir, folder, ex, 'summary.csv')
                if clean or not isfile(expected):
                    evaluate_regions(join(base_dir, folder, ex), gt_dir, regions, num_processes)
                if isfile(expected):
                    res = np.loadtxt(expected, dtype=str, skiprows=0, delimiter=',')[-2]
                    as_numeric = [float(i) for i in res[1:]]
                    f.write('%s__%s,' % (folder, ex))
                    f.write("%0.4f," % as_numeric[0])
                    f.write("%0.4f," % as_numeric[1])
                    f.write("%0.4f," % as_numeric[2])
                    f.write("%0.4f\n" % np.mean(as_numeric))

    # apply threshold to val set
    expected_num_cases = 125
    missing_valset = []
    has_val_pred = []
    for e in successful:
        if isdir(join(base_dir, 'predVal', e)):
            currdir = join(base_dir, 'predVal', e)
            files = subfiles(currdir, suffix='.nii.gz', join=False)
            if len(files) != expected_num_cases:
                print(e, 'prediction not done, found %d files, expected %d' % (len(files), expected_num_cases))
                continue
            output_folder = join(base_dir, 'predVal_PP', e)
            maybe_mkdir_p(output_folder)
            threshold = load_pickle(join(out_pp, e, 'threshold.pkl'))[2]
            if threshold > 1000: threshold = 750  # don't make it too big!
            apply_threshold_to_folder(currdir, output_folder, threshold, replace_with, num_processes)
            has_val_pred.append(e)
        else:
            print(e, 'has no valset predictions')
            missing_valset.append(e)

    # 'nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold' needs special treatment: its
    # threshold is taken from the corresponding bs5 run
    e = 'nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5'
    currdir = join(base_dir, 'predVal', 'nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold')
    output_folder = join(base_dir, 'predVal_PP', 'nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold')
    maybe_mkdir_p(output_folder)
    threshold = load_pickle(join(out_pp, e, 'threshold.pkl'))[2]
    if threshold > 1000: threshold = 750  # don't make it too big!
    apply_threshold_to_folder(currdir, output_folder, threshold, replace_with, num_processes)

    # 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold' needs the same special treatment
    e = 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5'
    currdir = join(base_dir, 'predVal', 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold')
    output_folder = join(base_dir, 'predVal_PP', 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold')
    maybe_mkdir_p(output_folder)
    threshold = load_pickle(join(out_pp, e, 'threshold.pkl'))[2]
    if threshold > 1000: threshold = 750  # don't make it too big!
    apply_threshold_to_folder(currdir, output_folder, threshold, replace_with, num_processes)

    # convert val set to brats labels for submission
    output_converted = join(base_dir, 'converted_valSet')

    for source in ['predVal', 'predVal_PP']:
        for e in has_val_pred + ['nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold', 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold']:
            expected_source_folder = join(base_dir, source, e)
            if not isdir(expected_source_folder):
                print(e, 'has no', source)
                raise RuntimeError()
            files = subfiles(expected_source_folder, suffix='.nii.gz', join=False)
            if len(files) != expected_num_cases:
                print(e, 'prediction not done, found %d files, expected %d' % (len(files), expected_num_cases))
                continue
            target_folder = join(output_converted, source, e)
            maybe_mkdir_p(target_folder)
            convert_labels_back_to_BraTS_2018_2019_convention(expected_source_folder, target_folder)

    summarize_validation_set_predictions(output_converted)
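determine_brats_postprocessing and apply_threshold_to_folder above implement the common BraTS trick: if a case's predicted enhancing tumor (label 3) covers fewer voxels than a tuned threshold, it is likely a false positive and is relabeled to replace_with (2, tumor core). A per-file sketch of that operation, assuming SimpleITK for NIfTI I/O (an assumption; not the project's exact helper):

import SimpleITK as sitk

def apply_threshold_to_file(in_file, out_file, threshold, replace_with=2):
    """Relabel enhancing tumor (3) to replace_with if it has fewer than threshold voxels."""
    img = sitk.ReadImage(in_file)
    seg = sitk.GetArrayFromImage(img)
    if (seg == 3).sum() < threshold:
        seg[seg == 3] = replace_with
    out = sitk.GetImageFromArray(seg)
    out.CopyInformation(img)  # preserve spacing, origin and direction
    sitk.WriteImage(out, out_file)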
Code example #5
File: run_accuracy.py Project: IntelAI/models
    def run(self):
        print("Run inference for accuracy")
        setup(self.args.data_location, self.args.input_graph)

        graph = tf.Graph()
        with graph.as_default():
            graph_def = tf.compat.v1.GraphDef()
            with open(self.args.input_graph, "rb") as f:
                graph_def.ParseFromString(f.read())
            output_graph = optimize_for_inference(
                graph_def, [INPUTS], [OUTPUTS],
                dtypes.float32.as_datatype_enum, False)
            tf.import_graph_def(output_graph, name="")

        input_tensor = graph.get_tensor_by_name('input:0')
        output_tensor = graph.get_tensor_by_name('Identity:0')

        config = tf.compat.v1.ConfigProto()
        config.intra_op_parallelism_threads = self.args.num_intra_threads
        config.inter_op_parallelism_threads = self.args.num_inter_threads
        config.graph_options.rewrite_options.auto_mixed_precision_mkl = rewriter_config_pb2.RewriterConfig.ON

        sess = tf.compat.v1.Session(graph=graph, config=config)
        if self.args.accuracy_only:
            print("Inference with real data")
            preprocessed_data_dir = "build/preprocessed_data"
            with open(
                    os.path.join(preprocessed_data_dir,
                                 "preprocessed_files.pkl"), "rb") as f:
                preprocessed_files = pickle.load(f)

            dictionaries = []
            for preprocessed_file in preprocessed_files:
                with open(
                        os.path.join(preprocessed_data_dir,
                                     preprocessed_file + ".pkl"), "rb") as f:
                    dct = pickle.load(f)[1]
                    dictionaries.append(dct)

            count = len(preprocessed_files)
            predictions = [None] * count
            validation_indices = list(range(0, count))
            print("Found {:d} preprocessed files".format(count))
            loaded_files = {}
            batch_size = self.args.batch_size
            # This loop runs one sample per step; the commented-out
            # fragments sketch the batched variant.
            steps = count  # batched: math.ceil(count / batch_size)
            for i in range(steps):
                print("Iteration {} ...".format(i))
                # batched: validation_indices[i * batch_size:(i + 1) * batch_size]
                test_data_index = validation_indices[i]
                file_name = preprocessed_files[test_data_index]
                with open(
                        os.path.join(preprocessed_data_dir,
                                     "{:}.pkl".format(file_name)), "rb") as f:
                    data = pickle.load(f)[0]
                predictions[i] = sess.run(
                    output_tensor,
                    feed_dict={input_tensor: data[np.newaxis,
                                                  ...]})[0].astype(np.float32)

            output_folder = "build/postprocessed_data"
            output_files = preprocessed_files
            # Post Process
            postprocess_output(predictions, dictionaries, validation_indices,
                               output_folder, output_files)

            ground_truths = "build/raw_data/nnUNet_raw_data/Task043_BraTS2019/labelsTr"
            # Run evaluation
            print("Running evaluation...")
            evaluate_regions(output_folder, ground_truths, get_brats_regions())
            # Load evaluation summary
            print("Loading evaluation summary...")
            with open(os.path.join(output_folder, "summary.csv")) as f:
                for line in f:
                    words = line.split(",")
                    if words[0] == "mean":
                        whole = float(words[1])
                        core = float(words[2])
                        enhancing = float(words[3])
                        mean = (whole + core + enhancing) / 3
                        print(
                            "Accuracy: mean = {:.5f}, whole tumor = {:.4f}, tumor core = {:.4f}, enhancing tumor = {:.4f}"
                            .format(mean, whole, core, enhancing))
                        break

        print("Done!")