Exemple #1
0
def predict_central_focus(image, out_file):
    """Segment a NIfTI volume slice-by-slice with a 2D U-Net, skipping edges.

    Only the central 90% of slices (indices in the open interval
    5%..95% of the slice count) are run through the network; the outer
    slices are filled with zero labels.

    Args:
        image: path to the input NIfTI volume.
        out_file: path where the predicted label volume is written (NIfTI).
    """
    in_channels = 3
    out_channels = 1
    init_features = 32
    image_size = 224
    model_path = "./pytorch_unet_models/unet_tumors_only2.pt"
    preprocess = [
        resample, stack_channels_valid, normalization, pad_and_resize
    ]

    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    unet = UNet(in_channels=in_channels,
                out_channels=out_channels,
                init_features=init_features)
    unet.to(device)
    # map_location so a checkpoint saved on GPU still loads on a CPU-only host.
    unet.load_state_dict(torch.load(model_path, map_location=device))
    unet.eval()

    epi_image = nib.load(image)
    epi_image_data = epi_image.get_fdata()
    dummy = epi_image_data.copy()
    h, w = epi_image_data.shape[0], epi_image_data.shape[1]

    # NOTE(review): the same header entry is used for both tuple slots;
    # presumably the preprocessing expects per-axis spacings -- confirm.
    pixdims = (epi_image.header["pixdim"], epi_image.header["pixdim"])
    for preprocess_ in preprocess:
        if preprocess_ == pad_and_resize:
            epi_image_data, dummy = preprocess_(epi_image_data,
                                                dummy,
                                                image_size=image_size)
        elif preprocess_ == resample:
            epi_image_data, dummy = preprocess_(epi_image_data,
                                                dummy,
                                                pixdims=pixdims)
        else:
            epi_image_data, dummy = preprocess_(epi_image_data, dummy)

    epi_label_pred = []
    n_slices = epi_image_data.shape[-1]
    for n_slice in trange(n_slices):

        if n_slices * 0.05 < n_slice < n_slices * 0.95:

            input_image = epi_image_data[..., n_slice]

            # transpose(1, 0, 2): swap H/W so the PIL image matches the
            # orientation the model was trained on; undone after inference.
            x = transform(Image.fromarray(input_image.transpose(1, 0, 2)))

            x = torch.unsqueeze(x, 0)
            with torch.no_grad():  # inference only; skip autograd bookkeeping
                y_pred = unet(x.to(device))
            y_pred_np = torch.squeeze(y_pred).detach().cpu().numpy()
            y_pred_np = y_pred_np.transpose(1, 0)

            # cv2.resize takes dsize as (width, height), so (w, h) yields an
            # (h, w) array matching the zero slices below.  The interpolation
            # flag must be passed by keyword: the third positional parameter
            # of cv2.resize is dst, not interpolation.
            y_pred_np = cv2.resize(y_pred_np, (w, h),
                                   interpolation=cv2.INTER_CUBIC)

        else:

            y_pred_np = np.zeros((h, w), dtype=np.uint8)

        epi_label_pred.append(np.expand_dims(y_pred_np, axis=-1))

    epi_label_pred = np.concatenate(epi_label_pred, axis=-1)
    epi_label_pred = np.round(epi_label_pred)

    img = nib.Nifti1Image(epi_label_pred, affine=None)
    img.to_filename(out_file)
             "2-stage-224-2d-unet32": 0.5475687 * 0.2,
             "2-stage-256-2d-unet64": 0.5316030 * 0.2,
             "2-stage-256-2d-rc-unet64": 0.4642530 * 0.2,
             "2-stage-256-2d-n4rc-unet64": 0.4782162 * 0.2,
             "2-stage-224-2d-n4-unet32": 0.4700288 * 0.2,
             
             "128-3d-unet16": 0.5915776 * 0.5,
             "128-3d-unet16-pretrained": 0.6051299 * 0.5
            }
    
    model_list = os.listdir(feature_dir)
    
    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
    
    in_channels = len(model_list)
    unet = UNet(in_channels=in_channels, out_channels=out_channels, init_features=init_features)
    unet.load_state_dict(torch.load(weights))
    unet.to(device, dtype=torch.float)
    unet.eval()
    
    filenames = glob.glob(os.path.join(feature_dir, model_list[0], "*.nii.gz"))
    filenames = [os.path.basename(filename) for filename in filenames]
    
    for filename in tqdm(filenames):
        
        feature_paths = {}
        for model in model_list:
            feature_paths[model] = os.path.join(feature_dir, model, filename)

        features = {key: nib.load(value).get_fdata() for key, value in feature_paths.items()}
        
def predict(image, out_file):
    """Two-stage brain segmentation of a NIfTI volume.

    Stage one: an EfficientNet classifier labels every axial slice as
    "brain" or "nobrain"; the labels are then smoothed by expanding the
    brain region outward from the central slice in both directions.
    Stage two: a five-crop U-Net segments only the slices classified as
    brain; "nobrain" slices are filled with zero labels.

    Args:
        image: path to the input NIfTI volume.
        out_file: path where the predicted label volume is written (NIfTI).
    """
    in_channels = 3
    out_channels = 1
    init_features = 64
    origin_image_size = 512
    image_size = 256
    labels_map = {"brain": 0, "nobrain": 1}
    unet_model_path = "./pytorch_unet_models/rc-unet64.pt"
    efficientnet_model_path = "./pytorch_efficientnet_models/nobrainer.pt"
    model_name = "efficientnet-b0"
    preprocess = [
        resample, stack_channels_valid, normalization, pad_and_resize
    ]

    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

    nobrainer_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    # Five image_size crops (corners + center); the U-Net segments each crop
    # and reconstruct_label stitches them back together.
    unet_transform = transforms.Compose([
        transforms.FiveCrop(image_size),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops]))
    ])

    efn = EfficientNet.from_name(
        model_name, override_params={'num_classes': len(labels_map)})
    # map_location so a checkpoint saved on GPU still loads on a CPU-only host.
    efn.load_state_dict(
        torch.load(efficientnet_model_path, map_location=device))
    efn.to(device)
    # Inference mode: without this, dropout/batch-norm run with training
    # behavior and the slice classification is nondeterministic/incorrect.
    efn.eval()

    unet = UNet(in_channels=in_channels,
                out_channels=out_channels,
                init_features=init_features)
    unet.to(device)
    unet.load_state_dict(torch.load(unet_model_path, map_location=device))
    unet.eval()

    epi_image = nib.load(image)
    epi_image_data = epi_image.get_fdata()
    dummy = epi_image_data.copy()
    h, w = epi_image_data.shape[0], epi_image_data.shape[1]

    # NOTE(review): the same header entry is used for both tuple slots;
    # presumably the preprocessing expects per-axis spacings -- confirm.
    pixdims = (epi_image.header["pixdim"], epi_image.header["pixdim"])
    for preprocess_ in preprocess:
        if preprocess_ == pad_and_resize:
            epi_image_data, dummy = preprocess_(epi_image_data,
                                                dummy,
                                                image_size=origin_image_size)
        elif preprocess_ == resample:
            epi_image_data, dummy = preprocess_(epi_image_data,
                                                dummy,
                                                pixdims=pixdims)
        else:
            epi_image_data, dummy = preprocess_(epi_image_data, dummy)

    epi_label_pred = []
    no_brain = []

    # stage one: per-slice brain / no-brain classification
    for n_slice in trange(epi_image_data.shape[-1]):
        input_image = epi_image_data[..., n_slice]

        x = nobrainer_transform(Image.fromarray(input_image.transpose(1, 0,
                                                                      2)))

        x = torch.unsqueeze(x, 0)

        with torch.no_grad():  # inference only; skip autograd bookkeeping
            y_nobrain = efn(x.to(device))
        y_nobrain_np = torch.squeeze(y_nobrain).detach().cpu().numpy()

        no_brain.append(np.argmax(softmax(y_nobrain_np)))

    # Smooth the labels by expanding from the center slice: everything
    # between the first brain slice and the center is forced to "brain" ...
    n_center_slice = int(epi_image_data.shape[-1] / 2)
    for n_slice in range(1, n_center_slice):
        prev, current = no_brain[n_slice - 1], no_brain[n_slice]
        if current == labels_map["brain"] and prev == labels_map["nobrain"]:
            no_brain[n_slice:n_center_slice] = [labels_map["brain"]
                                                ] * (n_center_slice - n_slice)
            break

    # ... and everything after the first nobrain slice past the center is
    # forced to "nobrain".
    for n_slice in range(n_center_slice, epi_image_data.shape[-1]):
        prev, current = no_brain[n_slice - 1], no_brain[n_slice]
        if current == labels_map["nobrain"] and prev == labels_map["brain"]:
            no_brain[n_slice:epi_image_data.shape[-1]] = [
                labels_map["nobrain"]
            ] * (epi_image_data.shape[-1] - n_slice)
            break

    # stage two: segment only the brain slices
    for n_slice in trange(epi_image_data.shape[-1]):

        if no_brain[n_slice] == labels_map["brain"]:

            input_image = epi_image_data[..., n_slice]

            x = unet_transform(Image.fromarray(input_image.transpose(1, 0, 2)))

            with torch.no_grad():
                y_pred = unet(x.to(device))
            y_pred_np = torch.squeeze(y_pred).detach().cpu().numpy()

            y_pred_np = reconstruct_label(y_pred_np, h, w)
            y_pred_np = y_pred_np.transpose(1, 0)

        else:
            # "nobrain" (the only other label argmax can produce); a plain
            # else avoids ever reusing a stale y_pred_np from the previous
            # iteration.
            y_pred_np = np.zeros((h, w), dtype=np.uint8)

        epi_label_pred.append(np.expand_dims(y_pred_np, axis=-1))

    epi_label_pred = np.concatenate(epi_label_pred, axis=-1)
    epi_label_pred = np.round(epi_label_pred)

    print("nobrain count: {:d}".format(np.sum(no_brain)))
    img = nib.Nifti1Image(epi_label_pred, affine=None)
    img.to_filename(out_file)
Exemple #4
0
def predict(image, out_file):
    """Segment every slice of a NIfTI volume with a five-crop 2D U-Net.

    Single-stage variant: unlike the two-stage pipeline, every axial
    slice is segmented (no brain/nobrain pre-classification).

    Args:
        image: path to the input NIfTI volume.
        out_file: path where the predicted label volume is written (NIfTI).
    """
    in_channels = 3
    out_channels = 1
    init_features = 64
    origin_image_size = 512
    image_size = 256
    unet_model_path = "../weights/unet64/unet_last.pt"

    preprocess = [
        resample, stack_channels_valid, normalization, pad_and_resize
    ]

    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

    # Five image_size crops (corners + center); the U-Net segments each crop
    # and reconstruct_label stitches them back together.
    unet_transform = transforms.Compose([
        transforms.FiveCrop(image_size),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops]))
    ])

    unet = UNet(in_channels=in_channels,
                out_channels=out_channels,
                init_features=init_features)
    unet.to(device)
    # map_location so a checkpoint saved on GPU still loads on a CPU-only host.
    unet.load_state_dict(torch.load(unet_model_path, map_location=device))
    unet.eval()

    epi_image = nib.load(image)
    epi_image_data = epi_image.get_fdata()
    dummy = epi_image_data.copy()
    h, w = epi_image_data.shape[0], epi_image_data.shape[1]

    # NOTE(review): the same header entry is used for both tuple slots;
    # presumably the preprocessing expects per-axis spacings -- confirm.
    pixdims = (epi_image.header["pixdim"], epi_image.header["pixdim"])
    for preprocess_ in preprocess:
        if preprocess_ == pad_and_resize:
            epi_image_data, dummy = preprocess_(epi_image_data,
                                                dummy,
                                                image_size=origin_image_size)
        elif preprocess_ == resample:
            epi_image_data, dummy = preprocess_(epi_image_data,
                                                dummy,
                                                pixdims=pixdims)
        else:
            epi_image_data, dummy = preprocess_(epi_image_data, dummy)

    epi_label_pred = []

    for n_slice in trange(epi_image_data.shape[-1]):

        input_image = epi_image_data[..., n_slice]

        x = unet_transform(Image.fromarray(input_image.transpose(1, 0, 2)))

        with torch.no_grad():  # inference only; skip autograd bookkeeping
            y_pred = unet(x.to(device))
        y_pred_np = torch.squeeze(y_pred).detach().cpu().numpy()

        y_pred_np = reconstruct_label(y_pred_np, h, w)
        y_pred_np = y_pred_np.transpose(1, 0)

        y_pred_np = np.round(y_pred_np).astype(np.uint8)

        epi_label_pred.append(np.expand_dims(y_pred_np, axis=-1))

    epi_label_pred = np.concatenate(epi_label_pred, axis=-1)

    img = nib.Nifti1Image(epi_label_pred, affine=None)
    img.to_filename(out_file)
 fig_dir = "./fig/demo"
 
 if not os.path.exists(fig_dir):
     os.makedirs(fig_dir)
 
 image_size = 224
 
 device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
 
 transforms = transforms.Compose([
                                 transforms.Resize(image_size),
                                 transforms.ToTensor(),
                                 ])
 valid_dataset = SegmentationDataset(valid_folder_path, transform=transforms)
 
 unet = UNet(in_channels=3, out_channels=1)
 unet.to(device)
 
 unet.load_state_dict(torch.load("unet.pt"))
 unet.eval()
 
 count = 0
 #for x,y in tqdm(valid_dataset):
 for x,y in valid_dataset:
     x = torch.unsqueeze(x, 0)
     y_pred = unet(x.to(device))
     y_pred_np = torch.squeeze(y_pred).detach().cpu().numpy()
     y_true_np = torch.squeeze(y).detach().cpu().numpy()
     
     x_np = torch.squeeze(x).detach().cpu().numpy()
     #print(np.unique(y_true_np), (np.min(y_pred_np), np.max(y_pred_np)))
    #device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
    device = torch.device("cpu")

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    efn = EfficientNet.from_name(
        model_name, override_params={'num_classes': len(labels_map)})
    efn.load_state_dict(torch.load(efficientnet_model_path))
    efn.to(device)

    unet = UNet(in_channels=in_channels,
                out_channels=out_channels,
                init_features=init_features)
    unet.to(device)
    unet.load_state_dict(torch.load(unet_model_path))
    unet.eval()

    epi_image = nib.load(image)
    epi_image_data = epi_image.get_fdata()
    dummy = epi_image_data.copy()
    h, w = epi_image_data.shape[0], epi_image_data.shape[1]

    pixdims = (epi_image.header["pixdim"], epi_image.header["pixdim"])
    for preprocess_ in preprocess:
        if preprocess_ == pad_and_resize:
            epi_image_data, dummy = preprocess_(epi_image_data,
                                                dummy,