Example #1
def preprocess_MLPerf(model, checkpoint_name, folds, fp16, list_of_lists, output_filenames, preprocessing_folder, num_threads_preprocessing):
    assert len(list_of_lists) == len(output_filenames)

    print("loading parameters for folds", folds)
    trainer, params = load_model_and_checkpoint_files(model, folds, fp16=fp16, checkpoint_name=checkpoint_name)

    print("starting preprocessing generator")
    preprocessing = preprocess_multithreaded(trainer, list_of_lists, output_filenames, num_threads_preprocessing, None)
    print("Preprocessing images...")
    all_output_files = []

    for preprocessed in preprocessing:
        output_filename, (d, dct) = preprocessed

        all_output_files.append(output_filename)
        if isinstance(d, str):
            data = np.load(d)
            os.remove(d)
            d = data

        # Pad to the desired full volume
        d = pad_nd_image(d, trainer.patch_size, "constant", None, False, None)

        with open(os.path.join(preprocessing_folder, output_filename + ".pkl"), "wb") as f:
            pickle.dump([d, dct], f)

    return all_output_files
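
A hypothetical invocation of the function above; the model path, case lists, and folder names are placeholders, not values from the original project:

# Hedged usage sketch for preprocess_MLPerf (all paths below are hypothetical).
model_dir = "./nnUNet/3d_fullres/Task043_BraTS2019/nnUNetTrainerV2__nnUNetPlansv2.1"
list_of_lists = [["case0_0000.nii.gz", "case0_0001.nii.gz"]]  # one inner list of modality files per case
output_filenames = ["case0"]                                  # one base name per case
preprocessed_files = preprocess_MLPerf(
    model_dir,
    "model_final_checkpoint",   # checkpoint_name
    1,                          # folds
    False,                      # fp16
    list_of_lists,
    output_filenames,
    "./build/preprocessed",     # preprocessing_folder: receives one .pkl per case
    num_threads_preprocessing=4)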
Example #2
def load_model(self):
    folds = None
    print("loading parameters for folds", folds)
    trainer, params = load_model_and_checkpoint_files(
        self.model,
        folds,
        fp16=False,
        checkpoint_name="model_final_checkpoint")
    return trainer, params
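
The returned pair is consumed the same way throughout this page: one set of fold weights is loaded into the trainer's network before inference. A minimal sketch, assuming loader is an instance of the surrounding class (the pattern comes from the examples that follow):

# Hedged usage sketch; "loader" is a hypothetical instance of the class above.
trainer, params = loader.load_model()
trainer.load_checkpoint_ram(params[0], False)  # load the first fold's weights; False = not in training mode
network = trainer.network                      # the underlying PyTorch module, ready for eval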
Example #3
File: run.py Project: mbasnet1/lpot
def main():
    class CalibrationDL():
        def __init__(self):
            path = os.path.abspath(
                os.path.expanduser('./brats_cal_images_list.txt'))
            with open(path, 'r') as f:
                self.preprocess_files = [line.rstrip() for line in f]

            self.loaded_files = {}
            self.batch_size = 1

        def __getitem__(self, sample_id):
            file_name = self.preprocess_files[sample_id]
            print("Loading file {:}".format(file_name))
            with open(
                    os.path.join('build/calib_preprocess/',
                                 "{:}.pkl".format(file_name)), "rb") as f:
                self.loaded_files[sample_id] = pickle.load(f)[0]
            return torch.from_numpy(
                self.loaded_files[sample_id][np.newaxis, ...]).float(), None

        def __len__(self):
            self.count = len(self.preprocess_files)
            return self.count

    args = get_args()
    assert args.backend == "pytorch"
    model_path = os.path.join(args.model_dir, "plans.pkl")
    assert os.path.isfile(
        model_path), "Cannot find the model file {:}!".format(model_path)
    trainer, params = load_model_and_checkpoint_files(
        args.model_dir,
        folds=1,
        fp16=False,
        checkpoint_name='model_final_checkpoint')
    trainer.load_checkpoint_ram(params[0], False)
    model = trainer.network

    if args.tune:
        quantizer = Quantization('conf.yaml')
        quantizer.model = common.Model(model)
        quantizer.eval_func = eval_func
        calib_dl = CalibrationDL()
        quantizer.calib_dataloader = calib_dl
        q_model = quantizer()
        q_model.save('./lpot_workspace')
        exit(0)

    if args.benchmark:
        model.eval()
        if args.int8:
            from lpot.utils.pytorch import load
            new_model = load(
                os.path.abspath(os.path.expanduser('./lpot_workspace')), model)
        else:
            new_model = model
        eval_func(new_model)
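
A quick sanity check of the CalibrationDL loader above, assuming the class were lifted to module scope; the expected tensor shape assumes BraTS-style preprocessed volumes with four modalities, which the snippet itself does not state:

# Hedged sanity check: fetch one calibration sample.
calib_dl = CalibrationDL()
tensor, label = calib_dl[0]          # label is always None in this loader
print(tensor.shape, tensor.dtype)    # e.g. torch.Size([1, 4, H, W, D]) torch.float32 for BraTS data
print(len(calib_dl), "calibration cases listed")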
Example #4
    def __init__(self, model_dir, preprocessed_data_dir, performance_count, folds, checkpoint_name):

        print("Loading PyTorch model...")
        model_path = os.path.join(model_dir, "plans.pkl")
        assert os.path.isfile(model_path), "Cannot find the model file {:}!".format(model_path)
        self.trainer, params = load_model_and_checkpoint_files(model_dir, folds, fp16=False, checkpoint_name=checkpoint_name)
        self.trainer.load_checkpoint_ram(params[0], False)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        print("Constructing SUT...")
        self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries, self.process_latencies)
        print("Finished constructing SUT.")
        self.qsl = get_brats_QSL(preprocessed_data_dir, performance_count)
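
Once the SUT and QSL handles exist, a LoadGen run is typically driven from them. A hedged sketch using the mlperf_loadgen API, assuming "runner" is an instance of the class above and that the wrapper returned by get_brats_QSL exposes the raw loadgen handle as .qsl (scenario and mode values are illustrative):

# Hedged sketch: drive the constructed SUT with MLPerf LoadGen.
settings = lg.TestSettings()
settings.scenario = lg.TestScenario.Offline
settings.mode = lg.TestMode.PerformanceOnly
lg.StartTest(runner.sut, runner.qsl.qsl, settings)
lg.DestroySUT(runner.sut)
lg.DestroyQSL(runner.qsl.qsl)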
Example #5
def main():
    args = get_args()

    print("Converting PyTorch model to ONNX...")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    output_path = "./{}/{}".format(args.output_dir, args.output_name)
    dynamic_bs_output_path = "./{}/{}".format(args.output_dir,
                                              args.dynamic_bs_output_name)

    print("Loading Pytorch model...")
    checkpoint_name = "model_final_checkpoint"
    folds = 1
    trainer, params = load_model_and_checkpoint_files(
        args.model_dir, folds, fp16=False, checkpoint_name=checkpoint_name)
    trainer.load_checkpoint_ram(params[0], False)
    height = 224
    width = 224
    depth = 160
    channels = 4
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dummy_input = torch.rand([1, channels, height, width,
                              depth]).float().to(device)
    torch.onnx.export(trainer.network,
                      dummy_input,
                      output_path,
                      opset_version=11,
                      input_names=['input'],
                      output_names=['output'])
    torch.onnx.export(trainer.network,
                      dummy_input,
                      dynamic_bs_output_path,
                      opset_version=11,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={
                          "input": {0: "batch_size"},
                          "output": {0: "batch_size"}
                      })

    print("Successfully exported model {} and {}".format(
        output_path, dynamic_bs_output_path))
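
To confirm that the second export really accepts a different batch size, the model can be loaded back with onnxruntime. This check is not part of the original script and assumes onnxruntime is installed:

# Hedged verification sketch: run the dynamic-batch ONNX model with batch size 2.
import onnxruntime as ort

# Depending on the onnxruntime version, a providers argument may be required here.
sess = ort.InferenceSession(dynamic_bs_output_path)
batch = torch.rand([2, channels, height, width, depth]).numpy()  # float32 input
outputs = sess.run(None, {"input": batch})
print(outputs[0].shape)  # the first dimension should be 2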
Example #6
def predict_cases_fastest(model,
                          list_of_lists,
                          output_filenames,
                          folds,
                          num_threads_preprocessing,
                          num_threads_nifti_save,
                          segs_from_prev_stage=None,
                          do_tta=True,
                          mixed_precision=True,
                          overwrite_existing=False,
                          all_in_gpu=True,
                          step_size=0.5,
                          checkpoint_name="model_final_checkpoint"):
    assert len(list_of_lists) == len(output_filenames)
    if segs_from_prev_stage is not None:
        assert len(segs_from_prev_stage) == len(output_filenames)

    pool = Pool(num_threads_nifti_save)
    results = []

    cleaned_output_files = []
    for o in output_filenames:
        dr, f = os.path.split(o)
        if len(dr) > 0:
            maybe_mkdir_p(dr)
        if not f.endswith(".nii.gz"):
            f, _ = os.path.splitext(f)
            f = f + ".nii.gz"
        cleaned_output_files.append(join(dr, f))

    if not overwrite_existing:
        print("number of cases:", len(list_of_lists))
        not_done_idx = [
            i for i, j in enumerate(cleaned_output_files) if not isfile(j)
        ]

        cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
        list_of_lists = [list_of_lists[i] for i in not_done_idx]
        if segs_from_prev_stage is not None:
            segs_from_prev_stage = [
                segs_from_prev_stage[i] for i in not_done_idx
            ]

        print("number of cases that still need to be predicted:",
              len(cleaned_output_files))

    print("emptying cuda cache")
    torch.cuda.empty_cache()

    print("loading parameters for folds,", folds)
    trainer, params = load_model_and_checkpoint_files(
        model,
        folds,
        mixed_precision=mixed_precision,
        checkpoint_name=checkpoint_name)

    print("starting preprocessing generator")
    preprocessing = preprocess_multithreaded(trainer, list_of_lists,
                                             cleaned_output_files,
                                             num_threads_preprocessing,
                                             segs_from_prev_stage)

    print("starting prediction...")
    for preprocessed in preprocessing:
        print("getting data from preprocessor")
        output_filename, (d, dct) = preprocessed
        print("got something")
        if isinstance(d, str):
            print("what I got is a string, so I need to load a file")
            data = np.load(d)
            os.remove(d)
            d = data

        # preallocate the output arrays
        # same dtype as the return value in predict_preprocessed_data_return_seg_and_softmax (saves time)
        all_softmax_outputs = np.zeros(
            (len(params), trainer.num_classes, *d.shape[1:]), dtype=np.float16)
        all_seg_outputs = np.zeros((len(params), *d.shape[1:]), dtype=int)
        print("predicting", output_filename)

        for i, p in enumerate(params):
            trainer.load_checkpoint_ram(p, False)
            res = trainer.predict_preprocessed_data_return_seg_and_softmax(
                d,
                do_mirroring=do_tta,
                mirror_axes=trainer.data_aug_params['mirror_axes'],
                use_sliding_window=True,
                step_size=step_size,
                use_gaussian=True,
                all_in_gpu=all_in_gpu,
                mixed_precision=mixed_precision)
            if len(params) > 1:
                # otherwise we don't need this and we can save ourselves the time it takes to copy that
                all_softmax_outputs[i] = res[1]
            all_seg_outputs[i] = res[0]

        print("aggregating predictions")
        if len(params) > 1:
            softmax_mean = np.mean(all_softmax_outputs, 0)
            seg = softmax_mean.argmax(0)
        else:
            seg = all_seg_outputs[0]

        print("applying transpose_backward")
        transpose_forward = trainer.plans.get('transpose_forward')
        if transpose_forward is not None:
            transpose_backward = trainer.plans.get('transpose_backward')
            seg = seg.transpose([i for i in transpose_backward])

        print("initializing segmentation export")
        results.append(
            pool.starmap_async(save_segmentation_nifti,
                               ((seg, output_filename, dct, 0, None), )))
        print("done")

    print(
        "inference done. Now waiting for the segmentation export to finish...")
    _ = [i.get() for i in results]
    # now apply postprocessing
    # first load the postprocessing properties if they are present. Else raise a well visible warning
    results = []
    pp_file = join(model, "postprocessing.json")
    if isfile(pp_file):
        print("postprocessing...")
        shutil.copy(pp_file, os.path.dirname(output_filenames[0]))
        # for_which_classes stores for which of the classes everything but the largest connected component needs to be
        # removed
        for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
        results.append(
            pool.starmap_async(
                load_remove_save,
                zip(output_filenames, output_filenames,
                    [for_which_classes] * len(output_filenames),
                    [min_valid_obj_size] * len(output_filenames))))
        _ = [i.get() for i in results]
    else:
        print(
            "WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run "
            "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is "
            "%s" % model)

    pool.close()
    pool.join()
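
A hypothetical call to predict_cases_fastest; all paths are placeholders, and folds=(0,) restricts inference to a single fold, as described in the docstring of Example #7 below:

# Hedged invocation sketch (paths are placeholders).
predict_cases_fastest(
    model="./nnUNet/3d_fullres/Task043_BraTS2019/nnUNetTrainerV2__nnUNetPlansv2.1",
    list_of_lists=[["case0_0000.nii.gz", "case0_0001.nii.gz"]],
    output_filenames=["./predictions/case0.nii.gz"],
    folds=(0,),
    num_threads_preprocessing=2,
    num_threads_nifti_save=2)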
Example #7
def predict_cases(model,
                  list_of_lists,
                  output_filenames,
                  folds,
                  save_npz,
                  num_threads_preprocessing,
                  num_threads_nifti_save,
                  segs_from_prev_stage=None,
                  do_tta=True,
                  mixed_precision=True,
                  overwrite_existing=False,
                  all_in_gpu=False,
                  step_size=0.5,
                  checkpoint_name="model_final_checkpoint",
                  segmentation_export_kwargs: dict = None,
                  disable_sliding_window: bool = False):
    """
    :param segmentation_export_kwargs:
    :param model: folder where the model is saved, must contain fold_x subfolders
    :param list_of_lists: [[case0_0000.nii.gz, case0_0001.nii.gz], [case1_0000.nii.gz, case1_0001.nii.gz], ...]
    :param output_filenames: [output_file_case0.nii.gz, output_file_case1.nii.gz, ...]
    :param folds: default: (0, 1, 2, 3, 4), but can also be 'all' or a subset of the five folds; for example, use (0,)
    to run only fold_0
    :param save_npz: default: False
    :param num_threads_preprocessing:
    :param num_threads_nifti_save:
    :param segs_from_prev_stage:
    :param do_tta: default: True; can be set to False for an 8x speedup at the cost of reduced segmentation quality
    :param overwrite_existing: default: False
    :param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init
    :return:
    """
    assert len(list_of_lists) == len(output_filenames)
    if segs_from_prev_stage is not None:
        assert len(segs_from_prev_stage) == len(output_filenames)

    pool = Pool(num_threads_nifti_save)
    results = []

    cleaned_output_files = []
    for o in output_filenames:
        dr, f = os.path.split(o)
        if len(dr) > 0:
            maybe_mkdir_p(dr)
        if not f.endswith(".nii.gz"):
            f, _ = os.path.splitext(f)
            f = f + ".nii.gz"
        cleaned_output_files.append(join(dr, f))

    if not overwrite_existing:
        print("number of cases:", len(list_of_lists))
        # if save_npz=True then we should also check for missing npz files
        not_done_idx = [
            i for i, j in enumerate(cleaned_output_files)
            if (not isfile(j)) or (save_npz and not isfile(j[:-7] + '.npz'))
        ]

        cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
        list_of_lists = [list_of_lists[i] for i in not_done_idx]
        if segs_from_prev_stage is not None:
            segs_from_prev_stage = [
                segs_from_prev_stage[i] for i in not_done_idx
            ]

        print("number of cases that still need to be predicted:",
              len(cleaned_output_files))

    print("emptying cuda cache")
    torch.cuda.empty_cache()

    print("loading parameters for folds,", folds)
    trainer, params = load_model_and_checkpoint_files(
        model,
        folds,
        mixed_precision=mixed_precision,
        checkpoint_name=checkpoint_name)

    if segmentation_export_kwargs is None:
        if 'segmentation_export_params' in trainer.plans.keys():
            force_separate_z = trainer.plans['segmentation_export_params'][
                'force_separate_z']
            interpolation_order = trainer.plans['segmentation_export_params'][
                'interpolation_order']
            interpolation_order_z = trainer.plans[
                'segmentation_export_params']['interpolation_order_z']
        else:
            force_separate_z = None
            interpolation_order = 1
            interpolation_order_z = 0
    else:
        force_separate_z = segmentation_export_kwargs['force_separate_z']
        interpolation_order = segmentation_export_kwargs['interpolation_order']
        interpolation_order_z = segmentation_export_kwargs[
            'interpolation_order_z']

    print("starting preprocessing generator")
    preprocessing = preprocess_multithreaded(trainer, list_of_lists,
                                             cleaned_output_files,
                                             num_threads_preprocessing,
                                             segs_from_prev_stage)
    print("starting prediction...")
    all_output_files = []
    for preprocessed in preprocessing:
        output_filename, (d, dct) = preprocessed
        all_output_files.append(output_filename)
        if isinstance(d, str):
            data = np.load(d)
            os.remove(d)
            d = data

        print("predicting", output_filename)
        softmax = []
        for p in params:
            trainer.load_checkpoint_ram(p, False)
            softmax.append(
                trainer.predict_preprocessed_data_return_seg_and_softmax(
                    d,
                    do_mirroring=do_tta,
                    mirror_axes=trainer.data_aug_params['mirror_axes'],
                    use_sliding_window=not disable_sliding_window,
                    step_size=step_size,
                    use_gaussian=True,
                    all_in_gpu=all_in_gpu,
                    mixed_precision=mixed_precision)[1][None])

        softmax = np.vstack(softmax)
        softmax_mean = np.mean(softmax, 0)

        transpose_forward = trainer.plans.get('transpose_forward')
        if transpose_forward is not None:
            transpose_backward = trainer.plans.get('transpose_backward')
            softmax_mean = softmax_mean.transpose(
                [0] + [i + 1 for i in transpose_backward])

        if save_npz:
            npz_file = output_filename[:-7] + ".npz"
        else:
            npz_file = None

        if hasattr(trainer, 'regions_class_order'):
            region_class_order = trainer.regions_class_order
        else:
            region_class_order = None
        """There is a problem with python process communication that prevents us from communicating obejcts 
        larger than 2 GB between processes (basically when the length of the pickle string that will be sent is 
        communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long 
        enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually 
        patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will 
        then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either 
        filename or np.ndarray and will handle this automatically"""
        bytes_per_voxel = 4
        if all_in_gpu:
            bytes_per_voxel = 2  # if all_in_gpu then the return value is half (float16)
        if np.prod(softmax_mean.shape) > (2e9 / bytes_per_voxel * 0.85):  # * 0.85 just to be safe
            print(
                "This output is too large for python process-process communication. Saving output temporarily to disk"
            )
            np.save(output_filename[:-7] + ".npy", softmax_mean)
            softmax_mean = output_filename[:-7] + ".npy"

        results.append(
            pool.starmap_async(
                save_segmentation_nifti_from_softmax,
                ((softmax_mean, output_filename, dct, interpolation_order,
                  region_class_order, None, None, npz_file, None,
                  force_separate_z, interpolation_order_z), )))

    print(
        "inference done. Now waiting for the segmentation export to finish...")
    _ = [i.get() for i in results]
    # now apply postprocessing
    # first load the postprocessing properties if they are present. Else raise a well visible warning
    results = []
    pp_file = join(model, "postprocessing.json")
    if isfile(pp_file):
        print("postprocessing...")
        shutil.copy(pp_file,
                    os.path.abspath(os.path.dirname(output_filenames[0])))
        # for_which_classes stores for which of the classes everything but the largest connected component needs to be
        # removed
        for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
        results.append(
            pool.starmap_async(
                load_remove_save,
                zip(output_filenames, output_filenames,
                    [for_which_classes] * len(output_filenames),
                    [min_valid_obj_size] * len(output_filenames))))
        _ = [i.get() for i in results]
    else:
        print(
            "WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run "
            "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is "
            "%s" % model)

    pool.close()
    pool.join()
Example #8
def predict_cases(model,
                  list_of_lists,
                  output_filenames,
                  folds,
                  save_npz,
                  num_threads_preprocessing,
                  num_threads_nifti_save,
                  segs_from_prev_stage=None,
                  do_tta=True,
                  overwrite_existing=False):

    assert len(list_of_lists) == len(output_filenames)
    if segs_from_prev_stage is not None:
        assert len(segs_from_prev_stage) == len(output_filenames)

    prman = Pool(num_threads_nifti_save)
    results = []

    cleaned_output_files = []
    for o in output_filenames:
        dr, f = os.path.split(o)
        if len(dr) > 0:
            maybe_mkdir_p(dr)
        if not f.endswith(".nii.gz"):
            f, _ = os.path.splitext(f)
            f = f + ".nii.gz"
        cleaned_output_files.append(join(dr, f))

    if not overwrite_existing:
        print("number of cases:", len(list_of_lists))
        not_done_idx = [
            i for i, j in enumerate(cleaned_output_files) if not isfile(j)
        ]

        cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
        list_of_lists = [list_of_lists[i] for i in not_done_idx]
        if segs_from_prev_stage is not None:
            segs_from_prev_stage = [
                segs_from_prev_stage[i] for i in not_done_idx
            ]

        print("number of cases that still need to be predicted:",
              len(cleaned_output_files))

    print("emptying cuda cache")
    torch.cuda.empty_cache()

    print("loading parameters for folds,", folds)
    trainer, params = load_model_and_checkpoint_files(model, folds)

    print("starting preprocessing generator")
    preprocessing = preprocess_multithreaded(trainer, list_of_lists,
                                             cleaned_output_files,
                                             num_threads_preprocessing,
                                             segs_from_prev_stage)
    print("starting prediction...")
    for preprocessed in preprocessing:
        output_filename, (d, dct) = preprocessed
        if isinstance(d, str):
            data = np.load(d)
            os.remove(d)
            d = data

        print("predicting", output_filename)

        softmax = []
        for p in params:
            trainer.load_checkpoint_ram(p, False)
            softmax.append(
                trainer.predict_preprocessed_data_return_softmax(
                    d, do_tta, 1, False, 1,
                    trainer.data_aug_params['mirror_axes'], True, True, 2,
                    trainer.patch_size, True)[None])

        softmax = np.vstack(softmax)
        softmax_mean = np.mean(softmax, 0)

        transpose_forward = trainer.plans.get('transpose_forward')
        if transpose_forward is not None:
            transpose_backward = trainer.plans.get('transpose_backward')
            softmax_mean = softmax_mean.transpose(
                [0] + [i + 1 for i in transpose_backward])

        if save_npz:
            npz_file = output_filename[:-7] + ".npz"
        else:
            npz_file = None
        """There is a problem with python process communication that prevents us from communicating obejcts 
        larger than 2 GB between processes (basically when the length of the pickle string that will be sent is 
        communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long 
        enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually 
        patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will 
        then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either 
        filename or np.ndarray and will handle this automatically"""
        if np.prod(softmax_mean.shape) > (2e9 / 4 * 0.9):  # * 0.9 just to be safe
            print(
                "This output is too large for python process-process communication. Saving output temporarily to disk"
            )
            np.save(output_filename[:-7] + ".npy", softmax_mean)
            softmax_mean = output_filename[:-7] + ".npy"

        results.append(
            prman.starmap_async(save_segmentation_nifti_from_softmax,
                                ((softmax_mean, output_filename, dct, 1, None,
                                  None, None, npz_file), )))

    _ = [i.get() for i in results]
Example #9
def predict_cases_fastest(model,
                          list_of_lists,
                          output_filenames,
                          folds,
                          num_threads_preprocessing,
                          num_threads_nifti_save,
                          segs_from_prev_stage=None,
                          do_tta=True,
                          overwrite_existing=False,
                          all_in_gpu=True):
    assert len(list_of_lists) == len(output_filenames)
    if segs_from_prev_stage is not None:
        assert len(segs_from_prev_stage) == len(output_filenames)

    prman = Pool(num_threads_nifti_save)
    results = []

    cleaned_output_files = []
    for o in output_filenames:
        dr, f = os.path.split(o)
        if len(dr) > 0:
            maybe_mkdir_p(dr)
        if not f.endswith(".nii.gz"):
            f, _ = os.path.splitext(f)
            f = f + ".nii.gz"
        cleaned_output_files.append(join(dr, f))

    if not overwrite_existing:
        print("number of cases:", len(list_of_lists))
        not_done_idx = [
            i for i, j in enumerate(cleaned_output_files) if not isfile(j)
        ]

        cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
        list_of_lists = [list_of_lists[i] for i in not_done_idx]
        if segs_from_prev_stage is not None:
            segs_from_prev_stage = [
                segs_from_prev_stage[i] for i in not_done_idx
            ]

        print("number of cases that still need to be predicted:",
              len(cleaned_output_files))

    print("emptying cuda cache")
    torch.cuda.empty_cache()

    print("loading parameters for folds,", folds)
    trainer, params = load_model_and_checkpoint_files(model, folds)

    print("starting preprocessing generator")
    preprocessing = preprocess_multithreaded(trainer, list_of_lists,
                                             cleaned_output_files,
                                             num_threads_preprocessing,
                                             segs_from_prev_stage)
    print("starting prediction...")
    for preprocessed in preprocessing:
        output_filename, (d, dct) = preprocessed
        if isinstance(d, str):
            data = np.load(d)
            os.remove(d)
            d = data

        print("predicting", output_filename)

        softmax = []
        segs = []
        for p in params:
            trainer.load_checkpoint_ram(p, False)
            res = trainer.predict_preprocessed_data_return_softmax_and_seg(
                d,
                do_tta,
                1,
                False,
                1,
                trainer.data_aug_params['mirror_axes'],
                True,
                True,
                2,
                trainer.patch_size,
                True,
                all_in_gpu=all_in_gpu)
            softmax.append(res[1][None])
            segs.append(res[0])

        if len(softmax) > 1:
            softmax = np.vstack(softmax)
            softmax_mean = np.mean(softmax, 0)
            seg = softmax_mean.argmax(0)
        else:
            seg = segs[0]

        transpose_forward = trainer.plans.get('transpose_forward')
        if transpose_forward is not None:
            transpose_backward = trainer.plans.get('transpose_backward')
            seg = seg.transpose([i for i in transpose_backward])

        results.append(
            prman.starmap_async(save_segmentation_nifti,
                                ((seg, output_filename, dct, 0, None), )))

    print(
        "inference done. Now waiting for the segmentation export to finish...")
    _ = [i.get() for i in results]
Example #10
import sys
sys.path.append('..')
import argparse
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.training.model_restore import load_model_and_checkpoint_files
from fvcore.nn.flop_count import FlopCountAnalysis
import numpy as np
import torch
import os
join = os.path.join

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', help="2d, 3d_lowres, 3d_fullres or 3d_cascade_fullres. Default: 3d_fullres", default="3d_fullres", required=False)
args = parser.parse_args()
model = args.model

model_path = join('./data/RESULTS_FOLDER/nnUNet', model, 'Task000_FLARE21Baseline/nnUNetTrainerV2__nnUNetPlansv2.1')
trainer, params = load_model_and_checkpoint_files(model_path, folds='all', checkpoint_name='model_final_checkpoint')
pkl_file = join(model_path, "all/model_final_checkpoint.model.pkl")
info = load_pickle(pkl_file)
if model == '2d' or model == '3d_lowres':
    patch_size = info['plans']['plans_per_stage'][0]['patch_size']
else:
    patch_size = info['plans']['plans_per_stage'][1]['patch_size']
patch_size = np.append(np.array(1), patch_size)  # prepend the channel dimension (single-channel CT input)

inputs = (torch.randn(tuple(np.append(np.array(1), patch_size))).cuda(),)  # prepend the batch dimension
flops = FlopCountAnalysis(trainer.network, inputs)
print('Total FLOPs:', flops.total())
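
A parameter count is often reported next to FLOPs; the two lines below are a small addition using plain PyTorch, not part of the original script:

# Hedged addition: report the trainable parameter count alongside FLOPs.
n_params = sum(p.numel() for p in trainer.network.parameters() if p.requires_grad)
print('Trainable parameters:', n_params)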