Example #1
0
    def test_resunet(self):
        """Compare the OpenVINO IR of the ResUNet against the PyTorch model.

        The bundled ``resunet_ma.pt`` checkpoint is loaded on the CPU, a
        PyTorch reference prediction is recorded for one random volume, the
        network is exported with ``mo_pytorch`` and the IR prediction for the
        same volume must stay within 5e-4 (max absolute error) of the
        reference.
        """
        import BrainMaGe
        from BrainMaGe.models.networks import fetch_model

        weights_file = Path(BrainMaGe.__file__).parent / 'weights' / 'resunet_ma.pt'
        model = fetch_model(modelname="resunet", num_channels=1,
                            num_classes=2, num_filters=16)
        state = torch.load(weights_file, map_location=torch.device('cpu'))
        model.load_state_dict(state["model_state_dict"])
        model.eval()

        # PyTorch reference prediction for a single random volume.
        volume = torch.randn([1, 1, 128, 128, 128])
        expected = model(volume).detach().numpy()

        # Push a couple of unrelated inputs through the model before export,
        # so a layer that kept state from the last input (the InstanceNorm
        # concern) would show up as a mismatch below.
        for _ in range(2):
            model(torch.randn(volume.shape))

        # Export to OpenVINO IR; the files are read back as model.xml/.bin.
        mo_pytorch.convert(model,
                           input_shape=list(volume.shape),
                           model_name='model')

        ir_net = self.ie.read_network('model.xml', 'model.bin')
        compiled = self.ie.load_network(ir_net, 'CPU')
        results = compiled.infer({'input': volume.detach().numpy()})
        actual = list(results.values())[0]

        max_abs_err = np.abs(actual - expected).max()
        self.assertLessEqual(max_abs_err, 5e-4)
 def __init__(self, params):
     """Keep the parameter mapping and build the network it describes.

     ``params`` must provide ``'model'`` (network name passed to
     ``fetch_model``) plus ``'num_modalities'``, ``'num_classes'`` and
     ``'base_filters'``, each of which is converted with ``int()``.
     """
     super().__init__()
     self.params = params
     model_name = params['model']
     n_channels = int(self.params['num_modalities'])
     n_classes = int(self.params['num_classes'])
     n_filters = int(self.params['base_filters'])
     self.model = fetch_model(model_name, n_channels, n_classes, n_filters)
def bench_pytorch_fp32(num_cores):
    """Benchmark FP32 PyTorch ResUNet inference over ``dataset_df``.

    Loads ``resunet_ma.pt`` from ``brainmage_root`` on the CPU.  It then runs
    every row of the module-level ``dataset_df`` through the model with
    ``num_cores`` intra-op threads and records, per subject, the Dice score
    against the reference mask plus the inference and post-processing
    wall-clock times.  Subjects that fail are reported and skipped.

    Parameters
    ----------
    num_cores : int or str
        Thread count passed (after ``int()``) to ``torch.set_num_threads``.

    Returns
    -------
    pandas.DataFrame
        One row per successfully processed subject, integer columns 0..4:
        row index, subject id, Dice score, inference time (s),
        post-processing time (s).  The same table is written without header
        to ``pt_fp32_stats_nc_<num_cores>_<timestamp>.csv``.
    """
    pytorch_model_path = brainmage_root / 'BrainMaGe/weights/resunet_ma.pt'

    ### Load PyTorch model
    pt_model = fetch_model(modelname="resunet", num_channels=1, num_classes=2, num_filters=16)
    checkpoint = torch.load(pytorch_model_path, map_location=torch.device('cpu'))
    pt_model.load_state_dict(checkpoint["model_state_dict"])

    ### Run PyTorch Inference
    print(f"\n Starting PyTorch inference with {pytorch_model_path} using {num_cores} cores...")

    pt_model.eval()
    torch.set_num_threads(int(num_cores))

    pt_stats = []
    with torch.no_grad():
        for i, row in tqdm(dataset_df.iterrows()):
            sub_id = row[sub_idx]
            input_path = row[input_path_idx]
            mask_path = row[mask_path_idx]

            try:
                mask_image = get_mask_image(mask_path)
                input_image, patient_nib = get_input_image(input_path)

                i_start = timer()
                pt_output = pt_model(input_image)
                i_end = timer()

                p_start = timer()
                pt_output = pt_output.cpu().numpy()[0][0]
                pt_to_save = postprocess_output(pt_output, patient_nib.shape)
                pt_dice_score = dice(pt_to_save, mask_image)
                p_end = timer()
            except Exception as err:
                # Best effort: skip this subject but say why. Unlike a bare
                # ``except:``, this lets KeyboardInterrupt / SystemExit through.
                print(f" Inference Failed: {sub_id} ({err!r})")
            else:
                pt_stats.append([i, sub_id, pt_dice_score, i_end - i_start, p_end - p_start])

    print(f"Done PyTorch inference with {pytorch_model_path} ...")
    # Fixed integer columns keep the layout stable even if every subject failed.
    pt_stats_df = pd.DataFrame(pt_stats, columns=range(5))

    date_time_str = datetime.now().strftime("%b-%d-%Y_%H-%M-%S")
    csv_name = f"pt_fp32_stats_nc_{num_cores}_{date_time_str}.csv"
    pt_stats_df.to_csv(csv_name, sep=',', header=False, index=False)
    print(f"Saved {csv_name} ...")

    if pt_stats_df.empty:
        print("\n PyTorch: no subject was processed successfully, no summary available.")
    else:
        print(f"\n PyTorch Dice Mean: {pt_stats_df[2].mean():.5f}")
        print(f"PyTorch Total Inf Time: {pt_stats_df[3].sum():.2f} sec, Mean: {pt_stats_df[3].mean():.2f} sec")

    return pt_stats_df
def bench_pytorch_fp32():
    """Benchmark FP32 PyTorch ResUNet inference over ``nfbs_dataset_df``.

    Loads the checkpoint at the module-level ``pytorch_model_path`` on the
    CPU.  For every row of ``nfbs_dataset_df`` it uses column 0 as the subject
    id, column 2 as the input image path and column 3 as the mask path.  It
    records the Dice score and the inference / post-processing times; subjects
    that fail are reported and skipped.

    Returns
    -------
    pandas.DataFrame
        One row per successfully processed subject, integer columns 0..4:
        row index, subject id, Dice score, inference time (s),
        post-processing time (s).  Also written (no header) to
        ``pt_stats.csv``.
    """
    ### Load PyTorch model
    pt_model = fetch_model(modelname="resunet",
                           num_channels=1,
                           num_classes=2,
                           num_filters=16)
    checkpoint = torch.load(pytorch_model_path,
                            map_location=torch.device('cpu'))
    pt_model.load_state_dict(checkpoint["model_state_dict"])

    ### Run PyTorch Inference
    print(f"Starting PyTorch inference with {pytorch_model_path} ...")

    pt_model.eval()

    pt_stats = []
    with torch.no_grad():
        for i, row in tqdm(nfbs_dataset_df.iterrows()):
            # Explicitly positional: plain ``row[0]`` on a label-indexed
            # Series relies on a deprecated positional fallback.
            sub_id = row.iloc[0]
            input_path = row.iloc[2]
            mask_path = row.iloc[3]

            try:
                mask_image = get_mask_image(mask_path)
                input_image, patient_nib = get_input_image(input_path)

                i_start = timer()
                pt_output = pt_model(input_image)
                i_end = timer()

                p_start = timer()
                pt_output = pt_output.cpu().numpy()[0][0]
                pt_to_save = postprocess_output(pt_output, patient_nib.shape)
                pt_dice_score = dice(pt_to_save, mask_image)
                p_end = timer()
            except Exception as err:
                # Best effort: skip this subject but report the cause, without
                # swallowing KeyboardInterrupt / SystemExit like a bare except.
                print(f" Inference Failed: {sub_id} ({err!r})")
            else:
                pt_stats.append([
                    i, sub_id, pt_dice_score, i_end - i_start, p_end - p_start
                ])

    print(f"Done PyTorch inference with {pytorch_model_path} ...")
    # Fixed integer columns keep the layout stable even if every subject failed.
    pt_stats_df = pd.DataFrame(pt_stats, columns=range(5))

    pt_stats_df.to_csv('pt_stats.csv', sep=',', header=False, index=False)
    print("Saved pt_stats.csv ...")

    return pt_stats_df
Example #5
0
def infer_single_multi_4(input_paths,
                         output_path,
                         weights,
                         mask_path=None,
                         device="cpu"):
    """
    Inference using the multi-modality (4-channel) network on one subject.

    Parameters
    ----------
    input_paths : sequence of str
        Paths to the four input images, in the order
        T1_path, T2_path, T1ce_path, Flair_path.
    output_path : str
        Path of the mask to be generated (prediction).
    weights : str or path-like
        Path to the checkpoint of the model used.
    mask_path : str, optional
        Not supported yet. If given, NotImplementedError is raised after the
        mask has been written to ``output_path``.
    device : int/str
        Device to run on; anything other than "cpu" moves the model and the
        input to the default CUDA device.

    Returns
    -------
    None.

    Raises
    ------
    ValueError
        If ``input_paths`` does not contain exactly four paths.
    FileNotFoundError
        If any of the input paths does not exist.
    NotImplementedError
        If ``mask_path`` is given.
    """
    # Materialise once: the paths are checked and then iterated again below.
    input_paths = list(input_paths)
    # Explicit checks instead of ``assert`` (which is stripped under -O).
    # Fewer than four paths would otherwise silently leave zero channels.
    if len(input_paths) != 4:
        raise ValueError(
            "Expected 4 input paths (T1, T2, T1ce, Flair), got %d" %
            len(input_paths))
    missing = [p for p in input_paths if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError("Input image(s) not found: %s" %
                                ", ".join(map(str, missing)))

    start = time.asctime()
    startstamp = time.time()
    print("\nHostname   :" + str(os.getenv("HOSTNAME")))
    print("\nStart Time :" + str(start))
    print("\nStart Stamp:" + str(startstamp))
    sys.stdout.flush()

    # default config for multi-4 as from config/test_params_multi_4.cfg
    model = fetch_model(
        modelname="resunet",
        num_channels=4,
        num_classes=2,
        num_filters=16,
    )

    checkpoint = torch.load(str(weights), map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["model_state_dict"])

    if device != "cpu":
        model.cuda()
    model.eval()

    # One channel per modality; each preprocessed volume must fit a 128^3 slot.
    stack = np.zeros([4, 128, 128, 128], dtype=np.float32)
    for i, image_path in enumerate(input_paths):
        patient_nib = nib.load(image_path)
        stack[i] = preprocess_image(patient_nib)
    image = torch.FloatTensor(stack[np.newaxis, ...])

    if device != "cpu":
        image = image.cuda()

    with torch.no_grad():
        output = model(image)
        output = output.cpu().numpy()[0][0]
        to_save = interpolate_image(output, (240, 240, 160))
        to_save = unpad_image(to_save)
        to_save[to_save >= 0.9] = 1
        to_save[to_save < 0.9] = 0
        # Fill holes slice by slice along the third axis (non-empty slices only).
        for i in range(to_save.shape[2]):
            if np.any(to_save[:, :, i]):
                to_save[:, :, i] = binary_fill_holes(to_save[:, :, i])
        to_save = postprocess_prediction(to_save).astype(np.uint8)
        # The affine of the last loaded modality is reused for the mask
        # (presumably all modalities are co-registered -- TODO confirm).
        to_save_nib = nib.Nifti1Image(to_save, patient_nib.affine)
        nib.save(to_save_nib, output_path)

    print("Done with running the model.")

    if mask_path is not None:
        raise NotImplementedError('Sorry, masking is not implemented (yet).')

    print("Final output stored in : %s" % (output_path))
    print("Thank you for using BrainMaGe")
    print("*" * 60)
Example #6
0
def infer_single_ma(input_path,
                    output_path,
                    weights,
                    mask_path=None,
                    device="cpu"):
    """
    Inference using the single-modality network on one image.

    Parameters
    ----------
    input_path : str
        Path of the input image.
    output_path : str
        Path of the mask to be generated (prediction).
    weights : str or path-like
        Path to the checkpoint of the model used.
    mask_path : str, optional
        If given, the input image with every voxel outside the predicted
        mask set to 0 is saved to this path.
    device : int/str
        Device to run on; anything other than "cpu" moves the model and the
        input to the default CUDA device.

    Returns
    -------
    None.
    """
    start = time.asctime()
    startstamp = time.time()
    print("\nHostname   :" + str(os.getenv("HOSTNAME")))
    print("\nStart Time :" + str(start))
    print("\nStart Stamp:" + str(startstamp))
    sys.stdout.flush()
    print("Generating Test csv")

    model = fetch_model(modelname="resunet",
                        num_channels=1,
                        num_classes=2,
                        num_filters=16)

    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["model_state_dict"])

    if device != "cpu":
        model.cuda()
    model.eval()

    patient_nib = nib.load(input_path)
    image_data = patient_nib.get_fdata()
    image = process_image(image_data)
    # The network works on a fixed 128^3 grid.
    image = resize(image, (128, 128, 128),
                   order=3,
                   mode="edge",
                   cval=0,
                   anti_aliasing=False)
    image = image[np.newaxis, np.newaxis, ...]  # add batch and channel axes
    image = torch.FloatTensor(image)
    if device != "cpu":
        image = image.cuda()
    with torch.no_grad():
        output = model(image)
        output = output.cpu().numpy()[0][0]
        # Back to the input's shape, then binarise at 0.9.
        to_save = interpolate_image(output, patient_nib.shape)
        to_save[to_save >= 0.9] = 1
        to_save[to_save < 0.9] = 0
        # Store the binary mask compactly as uint8 instead of float64.
        to_save = postprocess_prediction(to_save).astype(np.uint8)
        to_save_nib = nib.Nifti1Image(to_save, patient_nib.affine)
        nib.save(to_save_nib, output_path)

    print("Done with running the model.")

    if mask_path is not None:
        print(
            "You chose to save the brain. We are now saving it with the masks."
        )
        # Reload from disk rather than reusing ``image_data``, in case
        # process_image modified it in place (not visible here).
        patient_nib_write = nib.load(input_path)
        image_data_write = patient_nib_write.get_fdata()
        image_data_write[to_save == 0] = 0
        to_save_brain = nib.Nifti1Image(image_data_write,
                                        patient_nib_write.affine)
        nib.save(to_save_brain, mask_path)

    print("Thank you for using BrainMaGe")
    print("*" * 60)
Example #7
0
def infer_ma(hparams):
    """Prepare multi-process preprocessing and the inference data loader.

    Settings are read from ``hparams`` by attribute.  The function:

    - creates (or reads) the test CSV;
    - loads the network and its checkpoint;
    - preprocesses every patient listed in the CSV with a process pool;
    - builds a ``DataLoader`` over the preprocessed volumes.

    The module-level globals ``patients``, ``patients_path``, ``n_processes``,
    ``model_dir``, ``preprocessed_output_dir`` and ``temp_output_dir`` are
    assigned here.  Presumably this is so that ``preprocess_batch_works``,
    running in the pool workers, can read them (verify against that helper).
    They must therefore be set before the pool is created.

    NOTE(review): as shown, ``infer_loader`` is built but never consumed;
    the rest of this function appears to be missing from this view.
    """
    global patients, patients_path, n_processes, model_dir, preprocessed_output_dir, temp_output_dir
    model_dir = hparams.model_dir
    training_start_time = time.asctime()
    startstamp = time.time()
    print("\nHostname   :" + str(os.getenv("HOSTNAME")))
    print("\nStart Time :" + str(training_start_time))
    print("\nStart Stamp:" + str(startstamp))
    sys.stdout.flush()
    # Parsing the number of CPU's used
    n_processes = int(hparams.threads)
    print("Number of CPU's used : ", n_processes)

    # PRINT PARSED HPARAMS
    print("\n\n")
    print("Model Dir               :", hparams.model_dir)
    print("Test CSV                :", hparams.test_csv)
    print("Number of channels      :", hparams.num_channels)
    print("Model Name              :", hparams.model)
    print("Modalities              :", hparams.modalities)
    print("Number of classes       :", hparams.num_classes)
    print("Base Filters            :", hparams.base_filters)
    print("Load Weights            :", hparams.weights)

    print("Generating Test csv")
    # Attribute access like everywhere else in this function
    # (``hparams["results_dir"]`` fails on e.g. an argparse Namespace).
    os.makedirs(hparams.results_dir, exist_ok=True)
    if not hparams.csv_provided == "True":
        print("Since CSV were not provided, we are gonna create for you")
        csv_creator_adv.generate_csv(
            hparams.test_dir,
            to_save=hparams.results_dir,
            mode=hparams.mode,
            ftype="test",
            modalities=hparams.modalities,
        )
        test_csv = os.path.join(hparams.results_dir, "test.csv")
    else:
        test_csv = hparams.test_csv

    model = fetch_model(
        hparams.model,
        int(hparams.num_modalities),
        int(hparams.num_classes),
        int(hparams.base_filters),
    )
    # The checkpoint is a mapping (attribute access fails). Load it onto the
    # CPU first so GPU-saved checkpoints also work on CPU-only hosts.
    checkpoint = torch.load(str(hparams.weights),
                            map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["model_state_dict"])

    if hparams.device != "cpu":
        model.cuda()
    model.eval()

    test_df = pd.read_csv(test_csv)
    preprocessed_output_dir = os.path.join(hparams.model_dir, "preprocessed")
    os.makedirs(preprocessed_output_dir, exist_ok=True)
    patients = test_df.iloc[:, 0].astype(str)
    patients_path = test_df.iloc[:, 1]
    if len(patients) < n_processes:
        print("\n*********** WARNING ***********")
        print(
            "You are processing less number of patients as compared to the\n"
            "threads provided, which means you are asking for more resources than \n"
            "necessary which is not a great practice. Anyway, we have accounted for that \n"
            "and reduced the number of threads to the maximum number of patients for \n"
            "better resource management!\n")
        n_processes = len(patients)
    print("*" * 80)
    print("Intializing preprocessing")
    print("*" * 80)
    print("Initiating the CPU workload on %d threads.\n\n" % n_processes)
    print("Currently processing the following patients : ")
    START = time.time()
    # The context manager shuts the worker processes down once map() returns.
    with Pool(processes=n_processes) as pool:
        pool.map(preprocess_batch_works, range(n_processes))
    END = time.time()
    print("\n\n Preprocessing time taken : {} seconds".format(END - START))

    # Load the preprocessed patients to the dataloader
    print("*" * 80)
    print("Intializing Deep Neural Network")
    print("*" * 80)
    START = time.time()
    print("Initiating the GPU workload on CUDA threads.\n\n")
    print("Currently processing the following patients : ")
    temp_output_dir = os.path.join(hparams.model_dir, "temp_output")
    os.makedirs(temp_output_dir, exist_ok=True)
    dataset_infer = VolSegDatasetInfer(preprocessed_output_dir)
    infer_loader = DataLoader(
        dataset_infer,
        batch_size=int(hparams.batch_size),
        shuffle=False,
        num_workers=int(hparams.threads),
        pin_memory=False,
    )
Example #8
0
def infer_multi_4(cfg, device, save_brain, weights):
    """
    Inference using multi modality network

    Parameters
    ----------
    cfg : string
        Location of the config file (``param_name = param_value`` lines,
        ``#`` starts a comment).
    device : int/str
        Device to run on; anything other than "cpu" moves the model and the
        inputs to the default CUDA device.
    save_brain : int
        Whether to also save every modality masked by the predicted mask as
        ``<image_name>_brain.nii.gz``.
    weights : str or None
        Path to the model checkpoint. It takes precedence over a ``weights``
        entry in the config file; if None, the config entry (if any) is used.

    Returns
    -------
    None.

    """
    cfg = os.path.abspath(cfg)

    if os.path.isfile(cfg):
        params_df = pd.read_csv(
            cfg,
            sep=" = ",
            names=["param_name", "param_value"],
            comment="#",
            skip_blank_lines=True,
            engine="python",
        ).fillna(" ")
    else:
        print("Missing test_params.cfg file? Please give one!")
        # Non-zero status: a missing config is an error, not a success.
        sys.exit(1)
    params = dict(zip(params_df.iloc[:, 0], params_df.iloc[:, 1]))
    # The explicit argument wins over the config file's entry.
    if weights is not None or "weights" not in params:
        params["weights"] = weights
    start = time.asctime()
    startstamp = time.time()
    print("\nHostname   :" + str(os.getenv("HOSTNAME")))
    print("\nStart Time :" + str(start))
    print("\nStart Stamp:" + str(startstamp))
    sys.stdout.flush()

    print("Generating Test csv")
    os.makedirs(params["results_dir"], exist_ok=True)
    if not params["csv_provided"] == "True":
        print("Since CSV were not provided, we are gonna create for you")
        csv_creator_adv.generate_csv(
            params["test_dir"],
            to_save=params["results_dir"],
            mode=params["mode"],
            ftype="test",
            modalities=params["modalities"],
        )
        test_csv = os.path.join(params["results_dir"], "test.csv")
    else:
        test_csv = params["test_csv"]

    test_df = pd.read_csv(test_csv)

    model = fetch_model(
        params["model"],
        int(params["num_modalities"]),
        int(params["num_classes"]),
        int(params["base_filters"]),
    )
    if device != "cpu":
        model.cuda()

    checkpoint = torch.load(str(params["weights"]),
                            map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    nmods = int(params["num_modalities"])
    for patient in tqdm.tqdm(test_df.values):
        # IDs may be parsed as numbers; paths below need a string.
        patient_id = str(patient[0])
        patient_dir = os.path.join(params["results_dir"], patient_id)
        os.makedirs(patient_dir, exist_ok=True)
        # Columns 1..nmods of the CSV row hold the modality image paths.
        stack = np.zeros([nmods, 128, 128, 128], dtype=np.float32)
        for i in range(nmods):
            patient_nib = nib.load(patient[i + 1])
            stack[i] = preprocess_image(patient_nib)
        # Feed the stacked modalities (batch of one) on every device;
        # the tensor only moves to the GPU when one is requested.
        image = torch.FloatTensor(stack[np.newaxis, ...])
        if device != "cpu":
            image = image.cuda()
        with torch.no_grad():
            output = model(image)
            output = output.cpu().numpy()[0][0]
            to_save = interpolate_image(output, (240, 240, 160))
            to_save = unpad_image(to_save)
            to_save[to_save >= 0.9] = 1
            to_save[to_save < 0.9] = 0
            for i in range(to_save.shape[2]):
                if np.any(to_save[:, :, i]):
                    to_save[:, :, i] = binary_fill_holes(to_save[:, :, i])
            to_save = postprocess_prediction(to_save).astype(np.uint8)
            # Affine of the last loaded modality (assumes co-registered inputs).
            to_save_mask = nib.Nifti1Image(to_save, patient_nib.affine)
            nib.save(
                to_save_mask,
                os.path.join(patient_dir, patient_id + "_mask.nii.gz"),
            )
    print("Done with running the model.")
    if save_brain:
        print(
            "You chose to save the brain. We are now saving it with the masks."
        )
        for patient in tqdm.tqdm(test_df.values):
            patient_id = str(patient[0])
            patient_dir = os.path.join(params["results_dir"], patient_id)
            mask_nib = nib.load(
                os.path.join(patient_dir, patient_id + "_mask.nii.gz"))
            mask_data = mask_nib.get_fdata().astype(np.int8)
            for i in range(nmods):
                image_path = patient[i + 1]
                # Remove the real extension: str.strip(".nii.gz") would strip
                # a *set of characters* from both ends ("brain.nii.gz" -> "bra").
                image_name = os.path.basename(image_path)
                for ext in (".nii.gz", ".nii"):
                    if image_name.endswith(ext):
                        image_name = image_name[:-len(ext)]
                        break
                patient_nib = nib.load(image_path)
                image_data = patient_nib.get_fdata()
                image_data[mask_data == 0] = 0
                to_save_image = nib.Nifti1Image(image_data, patient_nib.affine)
                nib.save(
                    to_save_image,
                    os.path.join(patient_dir, image_name + "_brain.nii.gz"),
                )

    print("Final output stored in : %s" % (params["results_dir"]))
    print("Thank you for using BrainMaGe")
    print("*" * 60)
Example #9
0
def infer_ma(cfg, device, save_brain, weights):
    """
    Inference using the single-modality network on every subject of a test CSV.

    The pipeline per subject:

    - resample the input image to 1 mm isotropic spacing;
    - pad it with ``pad_image`` and resize it to 128^3 for the network;
    - map the predicted mask back (``padder_and_cropper``, then a resize) to
      the original shape and affine.

    Intermediate volumes are written under ``<results_dir>/Temp/<ID>/`` and
    the final mask to ``<results_dir>/<ID>/<ID>_mask.nii.gz``.

    Parameters
    ----------
    cfg : str
        Location of the config file (``param_name = param_value`` lines,
        ``#`` starts a comment).
    device : int/str
        Device to run on; anything other than "cpu" moves the model and the
        inputs to the default CUDA device.
    save_brain : int/bool
        Whether to also save the input image masked by the prediction as
        ``<ID>_brain.nii.gz``.
    weights : str
        Path to the model checkpoint; overrides any ``weights`` entry in cfg.

    Returns
    -------
    None.
    """
    cfg = os.path.abspath(cfg)

    if os.path.isfile(cfg):
        params_df = pd.read_csv(cfg,
                                sep=' = ',
                                names=['param_name', 'param_value'],
                                comment='#',
                                skip_blank_lines=True,
                                engine='python').fillna(' ')
    else:
        print('Missing test_params.cfg file? Please give one!')
        # Non-zero status: a missing config is an error, not a success.
        sys.exit(1)
    params = dict(zip(params_df.iloc[:, 0], params_df.iloc[:, 1]))
    params['weights'] = weights
    start = time.asctime()
    startstamp = time.time()
    print("\nHostname   :" + str(os.getenv("HOSTNAME")))
    print("\nStart Time :" + str(start))
    print("\nStart Stamp:" + str(startstamp))
    sys.stdout.flush()

    print("Generating Test csv")
    os.makedirs(params['results_dir'], exist_ok=True)
    if not params['csv_provided'] == 'True':
        print('Since CSV were not provided, we are gonna create for you')
        csv_creator_adv.generate_csv(params['test_dir'],
                                     to_save=params['results_dir'],
                                     mode=params['mode'],
                                     ftype='test',
                                     modalities=params['modalities'])
        test_csv = os.path.join(params['results_dir'], 'test.csv')
    else:
        test_csv = params['test_csv']

    test_df = pd.read_csv(test_csv)
    test_df['ID'] = test_df['ID'].astype(str)
    temp_dir = os.path.join(params['results_dir'], 'Temp')
    os.makedirs(temp_dir, exist_ok=True)

    new_spacing = (1, 1, 1)
    patients_dict = {}

    print("Resampling the images to isotropic resolution of 1mm x 1mm x 1mm")
    print("Also Converting the images to RAI and brats for smarter use.")
    for patient in tqdm.tqdm(test_df.values):
        patient_id = patient[0]
        patient_temp_dir = os.path.join(temp_dir, patient_id)
        os.makedirs(patient_temp_dir, exist_ok=True)
        image = nib.load(patient[1])
        old_spacing = image.header.get_zooms()
        old_affine = image.affine
        old_shape = image.header.get_data_shape()
        # Scale each spatial axis so voxels become new_spacing wide.
        new_shape = tuple(
            int(np.round(old_spacing[k] / new_spacing[k] *
                         float(image.shape[k]))) for k in range(3))
        new_image = resize(image.get_fdata(),
                           new_shape,
                           order=3,
                           mode='edge',
                           cval=0,
                           anti_aliasing=False)
        # Keep the translation column, but replace every entry of the 3x3
        # direction/zoom block by its sign (+1 / 0 / -1), i.e. unit spacing.
        new_affine = np.array(old_affine)
        new_affine[:3, :3] = np.sign(old_affine[:3, :3])
        resampled_path = os.path.join(patient_temp_dir,
                                      patient_id + '_resamp111.nii.gz')
        nib.save(nib.Nifti1Image(new_image, new_affine), resampled_path)

        temp_dict = {
            'name': patient_id,
            'old_spacing': old_spacing,
            'old_affine': old_affine,
            'old_shape': old_shape,
            'new_spacing': new_spacing,
            'new_affine': new_affine,
            'new_shape': new_shape,
        }

        # Reload the resampled volume from disk and pad it; pad_info is kept
        # so padder_and_cropper can undo the padding after inference.
        patient_nib = nib.load(resampled_path)
        patient_data, pad_info = pad_image(patient_nib.get_fdata())
        nib.save(
            nib.Nifti1Image(patient_data, patient_nib.affine),
            os.path.join(patient_temp_dir, patient_id + '_bratsized.nii.gz'))
        temp_dict['pad_info'] = pad_info
        patients_dict[patient_id] = temp_dict

    model = fetch_model(params['model'], int(params['num_modalities']),
                        int(params['num_classes']),
                        int(params['base_filters']))
    # Load onto the CPU first so GPU-saved checkpoints also work on CPU hosts.
    checkpoint = torch.load(str(params['weights']),
                            map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])

    if device != 'cpu':
        model.cuda()
    model.eval()

    print("Done Resampling the Data.\n")
    print("--" * 30)
    print("Running the model on the subjects")
    for patient in tqdm.tqdm(test_df.values):
        patient_id = patient[0]
        patient_temp_dir = os.path.join(temp_dir, patient_id)
        patient_nib = nib.load(
            os.path.join(patient_temp_dir, patient_id + '_bratsized.nii.gz'))
        image = process_image(patient_nib.get_fdata())
        image = resize(image, (128, 128, 128),
                       order=3,
                       mode='edge',
                       cval=0,
                       anti_aliasing=False)
        image = torch.FloatTensor(image[np.newaxis, np.newaxis, ...])
        if device != 'cpu':
            image = image.cuda()
        with torch.no_grad():
            output = model(image)
            output = output.cpu().numpy()[0][0]
            to_save = interpolate_image(output, patient_nib.shape)
            to_save[to_save >= 0.9] = 1
            to_save[to_save < 0.9] = 0
            nib.save(
                nib.Nifti1Image(to_save, patient_nib.affine),
                os.path.join(patient_temp_dir,
                             patient_id + '_bratsized_mask.nii.gz'))
            current_patient_dict = patients_dict[patient_id]
            new_image = padder_and_cropper(to_save,
                                           current_patient_dict['pad_info'])
            nib.save(
                nib.Nifti1Image(new_image, patient_nib.affine),
                os.path.join(patient_temp_dir,
                             patient_id + '_resample111_mask.nii.gz'))
            to_save_final = resize(new_image,
                                   current_patient_dict['old_shape'],
                                   order=3,
                                   mode='edge',
                                   cval=0)
            # ">=" as in the threshold above; with ">" voxels equal to 0.9
            # were left unbinarised.
            to_save_final[to_save_final >= 0.9] = 1
            to_save_final[to_save_final < 0.9] = 0
            for i in range(to_save_final.shape[2]):
                if np.any(to_save_final[:, :, i]):
                    to_save_final[:, :, i] = binary_fill_holes(
                        to_save_final[:, :, i])
            to_save_final = postprocess_prediction(to_save_final).astype(
                np.uint8)
            to_save_final_nib = nib.Nifti1Image(
                to_save_final, current_patient_dict['old_affine'])

            patient_result_dir = os.path.join(params['results_dir'],
                                              patient_id)
            os.makedirs(patient_result_dir, exist_ok=True)

            nib.save(
                to_save_final_nib,
                os.path.join(patient_result_dir, patient_id + '_mask.nii.gz'))

    print("Done with running the model.")
    if save_brain:
        print(
            "You chose to save the brain. We are now saving it with the masks."
        )
        for patient in tqdm.tqdm(test_df.values):
            patient_id = patient[0]
            patient_result_dir = os.path.join(params['results_dir'],
                                              patient_id)
            image = nib.load(patient[1])
            image_data = image.get_fdata()
            mask = nib.load(
                os.path.join(patient_result_dir, patient_id + '_mask.nii.gz'))
            mask_data = mask.get_fdata().astype(np.int8)
            image_data[mask_data == 0] = 0
            to_save_brain = nib.Nifti1Image(image_data, image.affine)
            nib.save(
                to_save_brain,
                os.path.join(patient_result_dir, patient_id + '_brain.nii.gz'))

    print("Please check the %s folder for the intermediate outputs if you "
          "would like to see some intermediate steps." % temp_dir)
    print("Final output stored in : %s" % (params['results_dir']))
    print("Thank you for using BrainMaGe")
    print('*' * 60)
Example #10
0
import torch
from BrainMaGe.models.networks import fetch_model

import mo_pytorch
from openvino.inference_engine import IECore

# Names used below that the visible imports above do not provide.
import os
from pathlib import Path

import numpy as np

# Paths are relative to the current working directory.
brainmage_root = Path('../')
pytorch_model_path = brainmage_root / 'BrainMaGe/weights/resunet_ma.pt'
# Output directory for the OpenVINO IR. An alternative directory used
# previously was 'BrainMaGe/weights/ov/fp16'.
ov_model_dir = brainmage_root / 'BrainMaGe/weights/ov/fp32'

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(ov_model_dir, exist_ok=True)

pt_model = fetch_model(modelname="resunet",
                       num_channels=1,
                       num_classes=2,
                       num_filters=16)
checkpoint = torch.load(pytorch_model_path, map_location=torch.device('cpu'))
pt_model.load_state_dict(checkpoint["model_state_dict"])

pt_model.eval()

# Test Accuracy

# Create reproducible dummy data (fixed seed).
np.random.seed(123)
input_image = torch.Tensor(
    np.random.standard_normal([1, 1, 128, 128, 128]).astype(np.float32))

print("Running PyTorch Inference on random data...")
# Run Pytorch Inference