Exemplo n.º 1
0
    def try_inference(self, padding_mode):
        """Run grid-based inference on ``self.sample`` and check the result.

        For two patch depths (17 and 27 voxels along the last axis) the sample
        is split into overlapping patches with ``GridSampler``. Each batch of
        ``t1`` patches is passed through the ``model`` callable, which is
        resolved outside this method. The outputs are then stitched back
        together with a ``GridAggregator``. The aggregated tensor must equal
        -5 everywhere and have the same shape as ``self.sample.t1``.

        Args:
            padding_mode: Forwarded unchanged to ``GridSampler``.
        """
        overlap = 4, 6, 8
        for depth in (17, 27):
            sampler = GridSampler(
                self.sample,
                (10, 15, depth),
                overlap,
                padding_mode=padding_mode,
            )
            stitcher = GridAggregator(sampler)
            loader = DataLoader(sampler, batch_size=6)
            with torch.no_grad():
                for batch in tqdm(loader):
                    batch_input = batch['t1'][DATA]
                    batch_locations = batch[LOCATION]
                    prediction = model(batch_input)  # some model
                    stitcher.add_batch(prediction, batch_locations)

            stitched = stitcher.get_output_tensor()
            assert (stitched == -5).all()
            assert stitched.shape == self.sample.t1.shape
Exemplo n.º 2
0
def predict_validation_region(model, validation_dataset, validation_batch_size,
                              thresh_val, data_path, multilabel=None):
    """Predict a segmentation of the validation volume and write it to disk.

    The first subject of ``validation_dataset`` is split into overlapping
    patches (size ``PATCH_SIZE``, overlap 32, reflect padding). Each batch is
    run through ``model`` on ``DEVICE_NUM``, and the outputs are stitched
    back together with a ``GridAggregator``. The prediction is written to an
    HDF5 file in ``data_path``. For non-multilabel predictions, a thresholded
    copy is also written. Finally ``plot_validation_slices`` is called on the
    resulting volume.

    Args:
        model: U-net model to segment the data. If it has a non-None
            ``final_activation`` attribute, that is applied to the logits.
        validation_dataset (torchio.ImagesDataset): The dataset containing the
            validation volume. Only ``validation_dataset[0]`` is used, and its
            image is read from the ``'data'`` key.
        validation_batch_size (int): The batch size to use for prediction.
        thresh_val (float): Value to threshold binary prediction. Unused when
            the prediction is multilabel.
        data_path (pathlib.Path): Directory in which the output volumes are
            written.
        multilabel (bool, optional): Whether the model predicts several
            classes. When None (default), it is inferred from the number of
            output channels: more than one channel means multilabel.
    """
    sample = validation_dataset[0]
    patch_overlap = 32

    grid_sampler = GridSampler(
        sample,
        PATCH_SIZE,
        patch_overlap,
        padding_mode='reflect'
    )
    patch_loader = DataLoader(
        grid_sampler, batch_size=validation_batch_size)
    aggregator = torchio.data.inference.GridAggregator(grid_sampler)

    model.eval()
    with torch.no_grad():
        for patches_batch in patch_loader:
            inputs = patches_batch['data'][DATA].to(DEVICE_NUM)
            locations = patches_batch[torchio.LOCATION]
            logits = model(inputs)
            # Apply the model's own output activation when it defines one.
            if getattr(model, 'final_activation', None) is not None:
                logits = model.final_activation(logits)
            aggregator.add_batch(logits, locations)
    predicted_vol = aggregator.get_output_tensor()  # output is 4D
    print(f"Shape of the predicted Volume is: {predicted_vol.shape}")
    if multilabel is None:
        # Previously this relied on a module-level ``multilabel`` that is not
        # always defined. Infer it from the channel count instead (assumes
        # the 4D output is channels-first, as torchio's aggregator returns).
        multilabel = predicted_vol.shape[0] > 1
    predicted_vol = predicted_vol.numpy().squeeze()  # remove first dimension
    if multilabel:
        # Replace the channel axis by the index of the highest-scoring channel.
        predicted_vol = np.argmax(predicted_vol, axis=0)
    h5_out_path = data_path/f"{MODEL_OUT_FN[:-8]}_validation_vol_predicted.h5"
    print(f"Outputting prediction of the validation volume to {h5_out_path}")
    with h5.File(h5_out_path, 'w') as f:
        # NOTE(review): for non-multilabel models predicted_vol still holds
        # float scores here, so this uint8 cast truncates them (values below
        # 1 become 0). Confirm this is the intended output format.
        f['/data'] = predicted_vol.astype(np.uint8)
    # Threshold if binary
    if not multilabel:
        predicted_vol[predicted_vol >= thresh_val] = 1
        predicted_vol[predicted_vol < thresh_val] = 0
        predicted_vol = predicted_vol.astype(np.uint8)
        h5_out_path = data_path/f"{MODEL_OUT_FN[:-8]}_validation_vol_pred_thresh.h5"
        print(f"Outputting prediction of the thresholded validation volume "
            f"to {h5_out_path}")
        with h5.File(h5_out_path, 'w') as f:
            f['/data'] = predicted_vol
    plot_validation_slices(predicted_vol, data_path, MODEL_OUT_FN)
Exemplo n.º 3
0
    def test_inference(self):
        """Smoke-test patch-based inference with the array-based sampler API.

        A random 10x20x30 volume is sampled into overlapping 10x15x27 patches
        in batches of 6. Each batch goes through a 1-in/1-out ``nn.Conv3d``,
        and the argmax over the channel axis is handed to the aggregator.
        Nothing is asserted: the test only checks that no call raises.
        """
        conv = nn.Conv3d(1, 1, 3)
        overlap = 4, 5, 8
        channel_axis = 1

        # Let's create a dummy volume
        volume = torch.rand((10, 20, 30)).numpy()
        sampler = GridSampler(volume, (10, 15, 27), overlap)
        loader = DataLoader(sampler, batch_size=6)
        stitcher = GridAggregator(volume, overlap)

        with torch.no_grad():
            for batch in tqdm(loader):
                batch_input = batch[IMAGE]
                batch_locations = batch[LOCATION]
                scores = conv(batch_input)  # some model
                stitcher.add_batch(
                    scores.argmax(dim=channel_axis, keepdim=True),
                    batch_locations,
                )
Exemplo n.º 4
0
    def test_inference(self):
        """Smoke-test patch-based inference on the ``t1`` image of ``self.sample``.

        The sample is sampled into overlapping 10x15x27 patches in batches of
        6. Each batch of ``t1`` data goes through a 1-in/1-out ``nn.Conv3d``,
        the argmax over the channel axis is added to the aggregator, and the
        output tensor is finally requested. Nothing is asserted beyond the
        calls not raising.
        """
        conv = nn.Conv3d(1, 1, 3)
        overlap = 4, 5, 8
        channel_axis = 1

        sampler = GridSampler(self.sample, (10, 15, 27), overlap)
        loader = DataLoader(sampler, batch_size=6)
        stitcher = GridAggregator(self.sample, overlap)

        with torch.no_grad():
            for batch in tqdm(loader):
                batch_input = batch['t1'][DATA]
                batch_locations = batch[LOCATION]
                scores = conv(batch_input)  # some model
                stitcher.add_batch(
                    scores.argmax(dim=channel_axis, keepdim=True),
                    batch_locations,
                )

        stitcher.get_output_tensor()
Exemplo n.º 5
0
def predict_agg_3d(
    input_array,
    model3d,
    patch_size=(128, 224, 224),
    patch_overlap=(12, 12, 12),
    nb=True,
    device=0,
    debug_verbose=False,
    fpn=False,
    overlap_mode="crop",
):
    """Run ``model3d`` patch-wise over a 3D volume and aggregate the outputs.

    The volume is wrapped in a single-subject torchio dataset and split into
    overlapping patches with ``GridSampler``. Each patch is passed through
    ``model3d`` one at a time (batch size 1). Only the first output channel
    is kept and added to a ``GridAggregator``.

    Args:
        input_array: Array-like volume. ``input_array[:]`` is converted to a
            float tensor and given a leading channel axis.
        model3d: Callable model applied to each patch after it is moved to
            ``device``.
        patch_size: Patch size passed to ``GridSampler``.
        patch_overlap: Patch overlap passed to ``GridSampler``.
        nb: If True, use ``tqdm.notebook`` for the progress bar; otherwise
            plain ``tqdm``.
        device: Device the patches are moved to before inference.
        debug_verbose: If True, print input/output shapes for every patch.
        fpn: If True, the model's return value is indexed and its first
            element is used as the prediction.
        overlap_mode: Forwarded to ``GridAggregator``.

    Returns:
        The populated ``GridAggregator``. Call ``get_output_tensor()`` on it
        to obtain the stitched volume.
    """
    import torchio as tio
    from torchio import LOCATION
    from torchio.data.inference import GridAggregator, GridSampler

    if nb:
        from tqdm.notebook import tqdm
    else:
        from tqdm import tqdm

    print(input_array.shape)
    img_tens = torch.FloatTensor(input_array[:]).unsqueeze(0)
    print(f"Predict and aggregate on volume of {img_tens.shape}")

    # NOTE(review): the same tensor is registered twice ("img" and "label"),
    # but only "img" is read below. Recent torchio versions also take the
    # image kind via ``type=`` rather than ``label=``. Confirm both.
    one_subject = tio.Subject(
        img=tio.Image(tensor=img_tens, label=tio.INTENSITY),
        label=tio.Image(tensor=img_tens, label=tio.LABEL),
    )
    img_sample = tio.SubjectsDataset([one_subject])[-1]

    batch_size = 1

    grid_sampler = GridSampler(img_sample, patch_size, patch_overlap)
    patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
    aggregator1 = GridAggregator(grid_sampler, overlap_mode=overlap_mode)

    with torch.no_grad():
        for patches_batch in tqdm(patch_loader):
            locations = patches_batch[LOCATION]
            inputs_t = patches_batch["img"]["data"].to(device)

            outputs = model3d(inputs_t)
            if fpn:
                outputs = outputs[0]
            if debug_verbose:
                print(f"inputs_t: {inputs_t.shape}")
                print(f"outputs: {outputs.shape}")

            # Keep only channel 0; the 0:1 slice preserves the channel axis.
            # Raw model outputs are aggregated (no activation applied here).
            output = outputs[:, 0:1, :]

            aggregator1.add_batch(output, locations)

    return aggregator1
Exemplo n.º 6
0
def gridsampler_pipeline(
        input_array,
        entity_pts,
        patch_size=(64, 64, 64),
        patch_overlap=(0, 0, 0),
        batch_size=1,
):
    """Run the annotation ``Pipeline`` on each grid patch of a volume.

    ``input_array`` is wrapped in a single-subject torchio dataset and split
    into patches with ``GridSampler``. For every patch:

    * the volume and ``entity_pts`` are cropped around the patch location and
      the crop is plotted;
    * a ``Patch`` payload is pushed through the pipeline;
    * its ``"total_mask"`` and ``"prediction"`` annotation layers are added
      to two separate ``GridAggregator`` instances.

    Args:
        input_array: Volume, converted with ``torch.FloatTensor``.
        entity_pts: Point array. It is cast to ``np.int32`` before cropping.
        patch_size: Patch size for ``GridSampler`` and for the crop.
        patch_overlap: Patch overlap for ``GridSampler``.
        batch_size: Batch size for the patch ``DataLoader``.

    Returns:
        A tuple ``([output_tensor1, output_tensor2], payloads)``. The tensors
        are the aggregated ``"total_mask"`` and ``"prediction"`` volumes, and
        ``payloads`` is the list of processed ``Patch`` objects.
    """
    import torchio as tio
    from torchio import LOCATION
    from torchio.data.inference import GridAggregator, GridSampler

    logger.debug("Starting up gridsampler pipeline...")

    entity_pts = entity_pts.astype(np.int32)
    img_tens = torch.FloatTensor(input_array)

    one_subject = tio.Subject(
        img=tio.Image(tensor=img_tens, label=tio.INTENSITY),
        label=tio.Image(tensor=img_tens, label=tio.LABEL),
    )

    img_dataset = tio.ImagesDataset([
        one_subject,
    ])
    img_sample = img_dataset[-1]
    grid_sampler = GridSampler(img_sample, patch_size, patch_overlap)
    patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
    aggregator1 = GridAggregator(grid_sampler)  # collects "total_mask"
    aggregator2 = GridAggregator(grid_sampler)  # collects "prediction"
    # Annotation layer -> aggregator, in the order they are added per patch.
    layer_aggregators = (
        ("total_mask", aggregator1),
        ("prediction", aggregator2),
    )

    pipeline = Pipeline({
        "p":
        1,
        "ordered_ops": [
            make_masks,
            make_features,
            make_sr,
            make_seg_sr,
            make_seg_cnn,
        ],
    })

    payloads = []

    with torch.no_grad():
        for patches_batch in patch_loader:
            locations = patches_batch[LOCATION]

            # NOTE(review): only the first location of the batch is used for
            # cropping, while all of ``locations`` is passed to the
            # aggregators with a single-patch output. This looks consistent
            # only for batch_size == 1. Confirm.
            loc_arr = np.array(locations[0])
            loc = (loc_arr[0], loc_arr[1], loc_arr[2])
            logger.debug(f"Location: {loc}")

            # Prepare region data (IMG (Float Volume) AND GEOMETRY (3d Point))
            cropped_vol, offset_pts = crop_vol_and_pts_centered(
                input_array,
                entity_pts,
                location=loc,
                patch_size=patch_size,
                offset=True,
                debug_verbose=True,
            )

            plt.figure(figsize=(12, 12))
            plt.imshow(cropped_vol[cropped_vol.shape[0] // 2, :], cmap="gray")
            plt.scatter(offset_pts[:, 1], offset_pts[:, 2])

            logger.debug(f"Number of offset_pts: {offset_pts.shape}")
            logger.debug(
                f"Allocating memory for no. voxels: {cropped_vol.shape[0] * cropped_vol.shape[1] * cropped_vol.shape[2]}"
            )

            # NOTE(review): the payload is filled with random placeholder
            # arrays rather than built from ``cropped_vol``/``offset_pts``.
            # Replace with the real region data once the pipeline supports it.
            payload = Patch(
                {"total_mask": np.random.random((4, 4), )},
                {"total_anno": np.random.random((4, 4), )},
                {"points": np.random.random((4, 3), )},
            )
            pipeline.init_payload(payload)

            for step in pipeline:
                logger.debug(step)

            # Aggregation (Output: large volume aggregated from many smaller volumes)
            for layer_name, aggregator in layer_aggregators:
                output_tensor = (torch.FloatTensor(
                    payload.annotation_layers[layer_name]).unsqueeze(
                        0).unsqueeze(1))
                logger.debug(
                    f"Aggregating output tensor of shape: {output_tensor.shape}")
                aggregator.add_batch(output_tensor, locations)
            payloads.append(payload)

    output_tensor1 = aggregator1.get_output_tensor()
    logger.debug(output_tensor1.shape)

    output_tensor2 = aggregator2.get_output_tensor()
    logger.debug(output_tensor2.shape)

    return [output_tensor1, output_tensor2], payloads
Exemplo n.º 7
0
    batch_size = 2  # Set to 2 for 32Gb Card
print(f"Patch size is {PATCH_SIZE}")
print(f"Free GPU memory is {free_gpu_mem:0.2f} GB. Batch size will be "
      f"{batch_size}.")

# Load model
print(f"Loading model from {MODEL_FILE}")
# NOTE: torch.load unpickles the file; only load model files from a trusted
# source.
model_dict = torch.load(MODEL_FILE, map_location='cpu')
unet = create_unet_on_device(DEVICE_NUM, model_dict['model_struc_dict'])
unet.load_state_dict(model_dict['model_state_dict'])
# A model with more than one output channel is treated as multilabel.
# Previously the flag was only assigned in the True case, which could leave
# ``multilabel`` undefined when it is passed to predict_volume below.
# Assigning it unconditionally keeps it always defined.
multilabel = model_dict['model_struc_dict']['out_channels'] > 1
# Load the data and create a sampler
print(f"Loading data from {DATA_FILE}")
data_tens = tensor_from_hdf5(DATA_FILE, HDF5_PATH)
data_subject = torchio.Subject(
    data=torchio.Image(tensor=data_tens, label=torchio.INTENSITY))
print(f"Setting up grid sampler with overlap {PATCH_OVERLAP} and padding "
      f"mode: {PADDING_MODE}")
grid_sampler = GridSampler(data_subject,
                           PATCH_SIZE,
                           PATCH_OVERLAP,
                           padding_mode=PADDING_MODE)

pred_vol = predict_volume(unet, grid_sampler, batch_size, DATA_OUT_FN,
                          multilabel)
fig_out_dir = DATA_OUT_DIR / f'{date.today()}_3d_prediction_figs'
print(f"Creating directory for figures: {fig_out_dir}")
os.makedirs(fig_out_dir, exist_ok=True)
plot_predict_figure(pred_vol, data_tens, fig_out_dir)
Exemplo n.º 8
0
# Example: patch-based inference on a dummy volume with the (older,
# array-based) torchio GridSampler / GridAggregator API.
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchio import IMAGE, LOCATION
from torchio.data.inference import GridSampler, GridAggregator

patch_size = 128, 128, 128  # or just 128
patch_overlap = 4, 4, 4  # or just 4
batch_size = 6
CHANNELS_DIMENSION = 1

# Let's create a dummy volume
input_array = torch.rand((193, 229, 193)).numpy()

# More info about patch-based inference in NiftyNet docs:
# https://niftynet.readthedocs.io/en/dev/window_sizes.html
grid_sampler = GridSampler(input_array, patch_size, patch_overlap)
patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
aggregator = GridAggregator(input_array, patch_overlap)

# Placeholder: a bare nn.Module has no forward() and raises when called.
# Replace it with a real PyTorch model before running this example.
model = nn.Module()  # some Pytorch model

with torch.no_grad():
    for patches_batch in tqdm(patch_loader):
        input_tensor = patches_batch[IMAGE]
        locations = patches_batch[LOCATION]
        logits = model(input_tensor)
        # Per-voxel index of the highest-scoring channel, keeping the axis.
        labels = logits.argmax(dim=CHANNELS_DIMENSION, keepdim=True)
        outputs = labels
        aggregator.add_batch(outputs, locations)

output_tensor = aggregator.get_output_tensor()