Exemplo n.º 1
0
def infer_one(model,
              mask,
              tile_size=(512, 512),
              tile_step=(256, 256),
              weight='mean'):
    """Run tiled CUDA inference over one CHW tensor and return a 2D mask.

    Args:
        model: callable taking a CUDA float tensor batch, returning predictions.
        mask: torch tensor in CHW layout (despite the name, this is the input).
        tile_size: (h, w) of each tile.
        tile_step: (h, w) stride between tiles.
        weight: tile blending scheme passed to ImageSlicer.

    Returns:
        2D numpy array with the merged prediction (first channel only).
    """
    image = mask.cpu().numpy()
    image = np.moveaxis(image, 0, -1)  # CHW -> HWC for the slicer

    with torch.no_grad():
        # NOTE(review): slicer shape is hard-coded to (900, 900) rather than
        # derived from image.shape — confirm inputs are always 900x900.
        tiler = ImageSlicer((900, 900),
                            tile_size=tile_size,
                            tile_step=tile_step,
                            weight=weight)
        tiles = [np.moveaxis(tile, -1, 0) for tile in tiler.split(image)]
        merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)

        for tiles_batch, coords_batch in DataLoader(list(
                zip(tiles, tiler.crops)),
                                                    batch_size=10,
                                                    pin_memory=False):
            tiles_batch = tiles_batch.float().cuda()
            pred_batch = model(tiles_batch)
            # Removed a no-op `tiles_batch.cpu().detach()` here: its result
            # was discarded, so it neither freed GPU memory nor detached
            # anything from the graph.
            merger.integrate_batch(pred_batch, coords_batch)
    merged_mask = np.moveaxis(to_numpy(merger.merge()), 0, -1)
    merged_mask = tiler.crop_to_orignal_size(merged_mask)

    # Return the first channel as a standalone array.
    m = merged_mask[..., 0].copy()
    return m
def predict(model: nn.Module,
            image: np.ndarray,
            image_size,
            normalize=A.Normalize(),
            batch_size=1) -> np.ndarray:
    """Tiled CUDA inference over a large image.

    Args:
        model: segmentation network returning a single-channel mask batch.
        image: HWC input image.
        image_size: (h, w) tile size; the step is half of it per dimension.
        normalize: albumentations normalization transform.
        batch_size: number of tiles per inference batch.

    Returns:
        HWC numpy mask cropped back to the original image size.
    """
    tile_step = (image_size[0] // 2, image_size[1] // 2)

    tile_slicer = ImageSlicer(image.shape, image_size, tile_step)
    tile_merger = CudaTileMerger(tile_slicer.target_shape, 1,
                                 tile_slicer.weight)
    patches = tile_slicer.split(image)

    transform = A.Compose([normalize, A.Lambda(image=_tensor_from_rgb_image)])

    data = list({
        "image": patch,
        # BUG FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin int produces the same default integer dtype.
        "coords": np.array(coords, dtype=int)
    } for (patch, coords) in zip(patches, tile_slicer.crops))
    # Inference only: disable autograd to save memory, matching the other
    # predict helper in this file.
    with torch.no_grad():
        for batch in DataLoader(InMemoryDataset(data, transform),
                                pin_memory=True,
                                batch_size=batch_size):
            image = batch["image"].cuda(non_blocking=True)
            coords = batch["coords"]
            mask_batch = model(image)
            tile_merger.integrate_batch(mask_batch, coords)

    mask = tile_merger.merge()

    mask = np.moveaxis(to_numpy(mask), 0, -1)
    mask = tile_slicer.crop_to_orignal_size(mask)

    return mask
Exemplo n.º 3
0
def test_tiles_split_merge_cuda():
    """Round-trip: split -> per-tile max-channel model -> CUDA merge."""
    # Skip silently when no GPU is available.
    if not torch.cuda.is_available():
        return

    class MaxChannelIntensity(nn.Module):
        """Toy model returning the per-pixel maximum over input channels."""

        def __init__(self):
            super().__init__()

        def forward(self, input):
            max_channel, _ = torch.max(input, dim=1, keepdim=True)
            return max_channel

    # NOTE(review): random floats in [0, 1) cast to uint8 are all zeros, so
    # the assertion runs on a zero image — consider np.random.randint for
    # real coverage (may need rounding-tolerant comparison after merging).
    image = np.random.random((5000, 5000, 3)).astype(np.uint8)
    tiler = ImageSlicer(image.shape,
                        tile_size=(512, 512),
                        tile_step=(256, 256),
                        weight='pyramid')
    tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image)]

    model = MaxChannelIntensity().eval().cuda()

    merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)
    for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)),
                                                batch_size=8,
                                                pin_memory=True):
        tiles_batch = tiles_batch.float().cuda()
        pred_batch = model(tiles_batch)

        merger.integrate_batch(pred_batch, coords_batch)

    # CHW -> HWC, then crop away the padding added by the slicer.
    merged = np.moveaxis(to_numpy(merger.merge()), 0, -1).astype(np.uint8)
    merged = tiler.crop_to_orignal_size(merged)

    np.testing.assert_equal(merged, image.max(axis=2, keepdims=True))
Exemplo n.º 4
0
def inference_tiles(inference_model,
                    img_full,
                    device='cuda',
                    shape=(32, 1, 768, 448),
                    weight='mean',
                    mean=88.904434,
                    std=62.048634,
                    plot=False):
    """Tiled inference on a large grayscale image.

    Args:
        inference_model: model taking a normalized (B, C, H, W) batch.
        img_full: 2D grayscale image (converted to RGB for tiling).
        device: torch device used for inference.
        shape: (batch, channels, tile_h, tile_w).
        weight: tile blending scheme for ImageSlicer.
        mean, std: per-pixel normalization statistics.
        plot: when True, display every tile and its prediction.

    Returns:
        Squeezed float32 prediction merged over all tiles.
    """
    bs = shape[0]
    input_x = shape[2]
    input_y = shape[3]

    # Cut large image into overlapping tiles
    tiler = ImageSlicer(img_full.shape,
                        tile_size=(input_x, input_y),
                        tile_step=(input_x // 2, input_y // 2),
                        weight=weight)

    # HCW -> CHW. Optionally, do normalization here
    tiles = [
        tensor_from_rgb_image(tile)
        for tile in tiler.split(cv2.cvtColor(img_full, cv2.COLOR_GRAY2RGB))
    ]

    # Allocate a CUDA buffer for holding entire mask
    merger = CudaTileMerger(tiler.target_shape,
                            channels=1,
                            weight=tiler.weight)

    # Run predictions for tiles and accumulate them
    for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)),
                                                batch_size=bs,
                                                pin_memory=True):

        # Normalize and move the tile batch to the target device
        tiles_batch = ((tiles_batch.float() - mean) / std).to(device)

        # Predict
        pred_batch = inference_model(tiles_batch)

        # Merge on GPU
        merger.integrate_batch(pred_batch, coords_batch)

        if plot:
            # PERF FIX: transfer to CPU once per batch; the original called
            # .to('cpu').numpy() inside range() and on every loop iteration.
            tiles_np = tiles_batch.to('cpu').numpy()
            preds_np = pred_batch.to('cpu').numpy()
            for i in range(preds_np.shape[0]):
                plt.imshow(tiles_np[i, 0, :, :])
                plt.show()
                plt.imshow(preds_np[i, 0, :, :])
                plt.colorbar()
                plt.show()

    # Normalize accumulated mask and convert back to numpy
    merged_mask = np.moveaxis(to_numpy(merger.merge()), 0,
                              -1).astype('float32')
    merged_mask = tiler.crop_to_orignal_size(merged_mask)

    torch.cuda.empty_cache()

    return merged_mask.squeeze()
Exemplo n.º 5
0
def test_tiles_split_merge_non_dividable():
    """Split/merge must be lossless when the image size is not tile-divisible."""
    # NOTE(review): random floats in [0, 1) cast to uint8 yield all zeros,
    # so the equality check runs on a zero image — confirm this is intended.
    image = np.random.random((563, 512, 3)).astype(np.uint8)
    tiler = ImageSlicer(
        image.shape, tile_size=(128, 128), tile_step=(128, 128), weight="mean"
    )
    tiles = tiler.split(image)
    merged = tiler.merge(tiles, dtype=np.uint8)
    np.testing.assert_equal(merged, image)
Exemplo n.º 6
0
def predict_mask(image,
                 model,
                 dims=3,
                 size=394,
                 step=192,
                 batch_size=8,
                 plot_image=False,
                 dstdir=None,
                 img_name='image1.png'):
    """Predict a segmentation mask for `image` via tiled CUDA inference.

    The image is coerced to `dims` channels, split into overlapping tiles,
    each tile is classified per pixel (argmax over the class dimension),
    and tile predictions are blended back into a full-size uint8 mask.
    Optionally saves a side-by-side figure of image and mask to `dstdir`.
    """
    # Coerce the channel dimension to `dims`.
    if image.ndim == 2:
        image = image[:, :, None]
    if image.shape[-1] != dims:
        if image.shape[-1] == 1:
            image = np.repeat(image, 3, axis=2)
        elif image.shape[-1] == 3:
            image = image[:, :, 0][:, :, None]
    print(image.shape)

    # Overlapping tiles with pyramid blending weights.
    slicer = ImageSlicer(image.shape,
                         tile_size=(size, size),
                         tile_step=(step, step),
                         weight='pyramid')

    # HWC -> CHW tensors, one per tile.
    tile_tensors = [tensor_from_rgb_image(t) for t in slicer.split(image)]

    # Single-channel CUDA accumulator for the merged mask.
    accumulator = CudaTileMerger(slicer.target_shape, 1, slicer.weight)

    loader = DataLoader(list(zip(tile_tensors, slicer.crops)),
                        batch_size=batch_size,
                        pin_memory=True)
    with torch.no_grad():
        for batch, coords in loader:
            logits = model(batch.float().cuda())
            # Hard labels via argmax over the class dimension.
            labels = logits.max(dim=1)[1].float()
            accumulator.integrate_batch(labels, coords)

    # CHW -> HWC, cast, and crop away slicer padding.
    merged_mask = np.moveaxis(to_numpy(accumulator.merge()), 0,
                              -1).astype(np.uint8)
    merged_mask = slicer.crop_to_orignal_size(merged_mask)

    if plot_image:
        assert dstdir is not None, 'dstdir should be passed'
        fig, axes = plt.subplots(ncols=2, figsize=(20, 10))
        axes[0].imshow(image[:, :, 0], cmap='gray')
        axes[1].imshow(merged_mask[:, :, 0], alpha=0.3)
        out_path = osp.join(dstdir, img_name)
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
        print(out_path)
    return merged_mask
Exemplo n.º 7
0
def test_tiles_split_merge():
    """Split/merge with scalar tile size/step must reconstruct the image."""
    # NOTE(review): random floats in [0, 1) cast to uint8 are all zeros, so
    # this asserts on a zero image — consider randint for real coverage.
    image = np.random.random((5000, 5000, 3)).astype(np.uint8)
    tiler = ImageSlicer(image.shape,
                        tile_size=512,
                        tile_step=256,
                        weight='mean')
    tiles = tiler.split(image)
    merged = tiler.merge(tiles, dtype=np.uint8)
    np.testing.assert_equal(merged, image)
Exemplo n.º 8
0
def test_tiles_split_merge_2():
    """Pyramid-weighted split/merge must reconstruct the image exactly."""
    # NOTE(review): random floats in [0, 1) cast to uint8 are all zeros, so
    # the reconstruction check runs on a zero image — confirm intended.
    image = np.random.random((5000, 5000, 3)).astype(np.uint8)
    tiler = ImageSlicer(
        image.shape, tile_size=(512, 512), tile_step=(256, 256), weight="pyramid"
    )

    # The pyramid weight matrix must be symmetric.
    np.testing.assert_allclose(tiler.weight, tiler.weight.T)

    tiles = tiler.split(image)
    merged = tiler.merge(tiles, dtype=np.uint8)
    np.testing.assert_equal(merged, image)
Exemplo n.º 9
0
def predict(model: nn.Module,
            image: np.ndarray,
            image_size,
            tta=None,
            normalize=A.Normalize(),
            batch_size=1,
            activation='sigmoid') -> np.ndarray:
    """Tiled CUDA inference with optional TTA and output activation.

    Args:
        model: segmentation network (switched to eval mode).
        image: HWC input image.
        image_size: (h, w) tile size; step is half of it per dimension.
        tta: None, 'fliplr' or 'd4' test-time augmentation wrapper.
        normalize: albumentations normalization transform.
        batch_size: tiles per inference batch.
        activation: 'sigmoid' to apply a sigmoid, or a float threshold that
            is subtracted and clipped at zero via ReLU.

    Returns:
        HWC numpy mask cropped back to the original image size.
    """
    model.eval()
    tile_step = (image_size[0] // 2, image_size[1] // 2)

    tile_slicer = ImageSlicer(image.shape,
                              image_size,
                              tile_step,
                              weight='pyramid')
    tile_merger = CudaTileMerger(tile_slicer.target_shape, 1,
                                 tile_slicer.weight)
    patches = tile_slicer.split(image)

    transform = A.Compose([normalize, A.Lambda(image=_tensor_from_rgb_image)])

    if tta == 'fliplr':
        model = TTAWrapperFlipLR(model)
        print('Using FlipLR TTA')

    if tta == 'd4':
        model = TTAWrapperD4(model)
        print('Using D4 TTA')

    with torch.no_grad():
        data = list({
            'image': patch,
            # BUG FIX: np.int was deprecated in NumPy 1.20 and removed in
            # 1.24; the builtin int gives the same default integer dtype.
            'coords': np.array(coords, dtype=int)
        } for (patch, coords) in zip(patches, tile_slicer.crops))
        for batch in DataLoader(InMemoryDataset(data, transform),
                                pin_memory=True,
                                batch_size=batch_size):
            image = batch['image'].cuda(non_blocking=True)
            coords = batch['coords']
            mask_batch = model(image)
            tile_merger.integrate_batch(mask_batch, coords)

    mask = tile_merger.merge()
    if activation == 'sigmoid':
        mask = mask.sigmoid()

    if isinstance(activation, float):
        # BUG FIX: the original thresholded `mask_batch` — only the last tile
        # batch — instead of the fully merged mask.
        mask = F.relu(mask - activation, inplace=True)

    mask = np.moveaxis(to_numpy(mask), 0, -1)
    mask = tile_slicer.crop_to_orignal_size(mask)

    return mask
Exemplo n.º 10
0
def inference(inference_model, img_full, device='cuda'):
    """Tiled inference on a full HWC image, using crop size from globals.

    Relies on module-level globals: `config` (training crop size) and
    `args` (weight scheme, batch size, plot flag).

    Returns:
        Squeezed float32 prediction merged over all tiles.
    """
    x, y, ch = img_full.shape  # also asserts a 3D HWC input

    input_x = config['training']['crop_size'][0]
    input_y = config['training']['crop_size'][1]

    # Cut large image into overlapping tiles
    tiler = ImageSlicer(img_full.shape, tile_size=(input_x, input_y),
                        tile_step=(input_x // 2, input_y // 2), weight=args.weight)

    # HCW -> CHW. Optionally, do normalization here
    tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(img_full)]

    # Allocate a CUDA buffer for holding entire mask
    merger = CudaTileMerger(tiler.target_shape, channels=1, weight=tiler.weight)

    # Run predictions for tiles and accumulate them
    for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=args.bs, pin_memory=True):
        # Move tile to GPU, scaled to [0, 1]
        tiles_batch = (tiles_batch.float() / 255.).to(device)
        # Predict
        pred_batch = inference_model(tiles_batch)

        # Merge on GPU
        merger.integrate_batch(pred_batch, coords_batch)

        # Plot
        if args.plot:
            # BUG FIX: iterate the actual batch size — the last batch can be
            # smaller than args.bs, which previously raised IndexError. Also
            # move the prediction to CPU once instead of per image.
            pred_np = pred_batch.cpu().detach().numpy().astype('float32').squeeze()
            n = pred_batch.shape[0]
            if n != 1:
                for i in range(n):
                    plt.imshow(pred_np[i, :, :])
                    plt.show()
            else:
                plt.imshow(pred_np)
                plt.show()

    # Normalize accumulated mask and convert back to numpy
    merged_mask = np.moveaxis(to_numpy(merger.merge()), 0, -1).astype('float32')
    merged_mask = tiler.crop_to_orignal_size(merged_mask)
    # Plot
    if args.plot:
        # BUG FIX: the merged mask is a single image; the original loop
        # re-plotted the identical figure args.bs times.
        plt.imshow(merged_mask.squeeze())
        plt.show()

    torch.cuda.empty_cache()
    gc.collect()

    return merged_mask.squeeze()
def get_slice_heatmap(slide_id,
                      image_dir,
                      model,
                      crop_size=512,
                      tile_step=512,
                      bsize=16,
                      num_workers=4,
                      normalize=True,
                      *args,
                      **kwargs):
    """Build a coarse class heatmap for one whole-slide image.

    Tiles the slide at level 0, classifies each tile with `model`, paints
    each tile uniformly with its argmax class id, merges the tiles on CPU,
    and resizes the result to the slide's smallest pyramid level.

    Args:
        slide_id: slide basename; `{slide_id}.tiff` is opened in image_dir.
        image_dir: directory containing the slide file.
        model: CUDA classifier; its output is passed through Softmax(dim=1).
        crop_size: tile side length in pixels.
        tile_step: stride between tiles.
        bsize: DataLoader batch size.
        num_workers: DataLoader worker count.
        normalize: forwarded to CpuTileMerger.
        *args, **kwargs: kwargs are forwarded to InferenceSingleImage.

    Returns:
        uint8 heatmap resized to the smallest slide level dimensions.
    """
    image = openslide.OpenSlide(os.path.join(image_dir, f'{slide_id}.tiff'))
    # level_dimensions is (width, height); reversed to (rows, cols).
    x, y = image.level_dimensions[0][::-1]
    tiler = ImageSlicer((x, y),
                        tile_size=(crop_size, crop_size),
                        tile_step=(tile_step, tile_step),
                        weight='mean')
    merger = CpuTileMerger(tiler.target_shape,
                           1,
                           tiler.weight,
                           normalize=normalize)
    dataset = InferenceSingleImage(tiler.crops,
                                   slide_id,
                                   image_dir,
                                   crop_size=crop_size,
                                   **kwargs)
    dataloader = DataLoader(dataset,
                            batch_size=bsize,
                            pin_memory=True,
                            num_workers=num_workers)
    # Constant tile, scaled by the predicted class id below.
    pseudo_mask = np.ones((1, crop_size, crop_size))
    with torch.no_grad():
        for data_b in tqdm(dataloader, total=len(dataloader)):
            # Boolean selector of tiles worth classifying, supplied by the
            # dataset (presumably non-background tiles — verify in dataset).
            need_processing = data_b['need_to_process']
            if need_processing.sum() > 0:
                pred_batch = model(data_b['image'][need_processing].cuda())
                pred_batch = torch.nn.Softmax(dim=1)(pred_batch).cpu().numpy()
                # NOTE(review): np.argmax over pred_batch[idx] flattens the
                # array — this assumes a 1D per-tile class vector; confirm
                # the model output shape.
                pred_batch = np.stack([
                    pseudo_mask * np.argmax(pred_batch[idx])
                    for idx in range(pred_batch.shape[0])
                ])
                merger.integrate_batch(
                    pred_batch,
                    data_b['coords'][need_processing].cpu().numpy())
    merged_mask = np.moveaxis(to_numpy(merger.merge()), 0, -1).astype(np.uint8)
    merged_mask = tiler.crop_to_orignal_size(merged_mask)
    # Downscale to the smallest pyramid level for a compact heatmap.
    merged_mask = cv2.resize(merged_mask,
                             image.level_dimensions[-1],
                             interpolation=cv2.INTER_AREA)
    return (merged_mask)
Exemplo n.º 12
0
def split_image(image_fname, output_dir, tile_size, tile_step, image_margin):
    """Split one image into tiles and write each tile as a PNG.

    Tiles are named ``{image_id}_tile_{i}.png`` inside `output_dir`, which
    is created if missing.

    Returns:
        List of the written tile file paths, in tile order.
    """
    os.makedirs(output_dir, exist_ok=True)
    image = read_image_as_is(image_fname)
    image_id = id_from_fname(image_fname)

    slicer = ImageSlicer(image.shape, tile_size, tile_step, image_margin)

    fnames = []
    for i, tile in enumerate(slicer.split(image)):
        out_path = os.path.join(output_dir, f"{image_id}_tile_{i}.png")
        cv2.imwrite(out_path, tile)
        fnames.append(out_path)

    return fnames
Exemplo n.º 13
0
    def __init__(
        self,
        image_fname: str,
        mask_fname: str,
        image_loader: Callable,
        target_loader: Callable,
        tile_size,
        tile_step,
        image_margin=0,
        transform=None,
        target_shape=None,
        keep_in_mem=False,
    ):
        """Set up tiling for one image/mask file pair.

        Loads both files when `target_shape` is None (or when caching) to
        validate that their spatial sizes match and to derive the slicing
        shape.

        Raises:
            ValueError: if image and mask width/height differ.
        """
        self.image_fname = image_fname
        self.mask_fname = mask_fname
        self.image_loader = image_loader
        self.mask_loader = target_loader
        self.image = None
        self.mask = None

        if target_shape is None or keep_in_mem:
            image = image_loader(image_fname)
            mask = target_loader(mask_fname)
            if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[
                    1]:
                # BUG FIX: the message previously interpolated image.shape
                # twice instead of reporting the mask's shape.
                raise ValueError(
                    f"Image size {image.shape} and mask shape {mask.shape} must have equal width and height"
                )

            target_shape = image.shape

        self.slicer = ImageSlicer(target_shape, tile_size, tile_step,
                                  image_margin)

        if keep_in_mem:
            # Pre-cut all tiles once and keep them in memory.
            self.images = self.slicer.split(image)
            self.masks = self.slicer.split(mask)
        else:
            self.images = None
            self.masks = None

        self.transform = transform
        # One id per tile: filename plus the crop rectangle.
        self.image_ids = [
            id_from_fname(image_fname) +
            f" [{crop[0]};{crop[1]};{crop[2]};{crop[3]};]"
            for crop in self.slicer.crops
        ]
Exemplo n.º 14
0
    def __init__(
        self,
        image_fname: str,
        mask_fname: str,
        image_loader: Callable,
        target_loader: Callable,
        tile_size,
        tile_step,
        image_margin=0,
        transform=None,
        target_shape=None,
        need_weight_mask=False,
        keep_in_mem=False,
        make_mask_target_fn: Callable = mask_to_bce_target,
    ):
        """Set up tiling for one image/mask file pair.

        Loads both files when `target_shape` is None (or keep_in_mem) to
        validate that their spatial sizes match and derive the slicing
        shape.

        Raises:
            ValueError: if image and mask width/height differ.
        """
        self.image_fname = image_fname
        self.mask_fname = mask_fname
        self.image_loader = image_loader
        self.mask_loader = target_loader
        self.image = None
        self.mask = None
        self.need_weight_mask = need_weight_mask

        if target_shape is None or keep_in_mem:
            image = image_loader(image_fname)
            mask = target_loader(mask_fname)
            if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[
                    1]:
                # BUG FIX: the message previously interpolated image.shape
                # twice instead of reporting the mask's shape.
                raise ValueError(
                    f"Image size {image.shape} and mask shape {mask.shape} must have equal width and height"
                )

            target_shape = image.shape

        self.slicer = ImageSlicer(target_shape, tile_size, tile_step,
                                  image_margin)

        self.transform = transform
        # Same id for every tile; crop rectangles are kept separately.
        self.image_ids = [fs.id_from_fname(image_fname)] * len(
            self.slicer.crops)
        self.crop_coords_str = [
            f"[{crop[0]};{crop[1]};{crop[2]};{crop[3]};]"
            for crop in self.slicer.crops
        ]
        self.make_mask_target_fn = make_mask_target_fn
def run_validation(data_df, model, data_folder, augmentation, tiles=False):
    """Compute Dice scores over a validation dataframe.

    For each row, reads the image, applies `augmentation`, predicts a
    4-channel sigmoid mask (optionally via overlapping 224x224 tiles), and
    scores each defect channel against the ground-truth masks.

    Returns:
        Tuple of (mean Dice over all image/defect pairs,
                  list of per-image mean Dice values).
    """
    total_dice_coeffs = []
    mean_dice_per_image = []
    sigmoid = torch.nn.Sigmoid()
    for image_n in tqdm(range(data_df.shape[0])):
        img_path = os.path.join(data_folder, data_df.index.values[image_n])
        image_processed = augmentation(image=cv2.imread(img_path))['image']
        if tiles:
            # Tiled path: overlapping 224x224 crops merged on GPU.
            slicer = ImageSlicer(image_processed.shape[:2],
                                 tile_size=(224, 224),
                                 tile_step=(56, 56),
                                 weight='mean')
            accumulator = CudaTileMerger(slicer.target_shape, 4, slicer.weight)
            tile_tensors = [
                tensor_from_rgb_image(t) for t in slicer.split(image_processed)
            ]
            loader = DataLoader(list(zip(tile_tensors, slicer.crops)),
                                batch_size=16,
                                pin_memory=True)
            for batch, coords in loader:
                preds = sigmoid(model(batch.float().cuda()))
                accumulator.integrate_batch(preds, coords)
            predictions = np.moveaxis(to_numpy(accumulator.merge()), 0, -1)
            predictions = slicer.crop_to_orignal_size(predictions)
        else:
            # Whole-image path: single forward pass.
            tensor = torch.from_numpy(
                np.expand_dims(image_processed.transpose((2, 0, 1)),
                               0)).float()
            predictions = sigmoid(model(
                tensor.cuda())[0]).detach().cpu().numpy()
            predictions = np.moveaxis(predictions, 0, -1)
        predictions_bin = (predictions > 0.5).astype(int)
        fname, masks = make_mask(image_n, data_df)
        dices_image = [
            dice(masks[:, :, defect_type],
                 predictions_bin[:, :, defect_type])
            for defect_type in range(4)
        ]
        total_dice_coeffs.extend(dices_image)
        mean_dice_per_image.append(np.mean(dices_image))
    return np.mean(total_dice_coeffs), mean_dice_per_image
Exemplo n.º 16
0
def predict_on_zslice_tiles(model,
                            zimage,
                            tile_size=(512, 512),
                            tile_step=(256, 256)):
    """Run tiled CUDA inference on the first z-slice of a stack.

    Args:
        model: CUDA model mapping a tile batch to a single-channel output.
        zimage: 4D stack; only zimage[0, 0, :, :] is processed.
        tile_size: (h, w) of each tile.
        tile_step: (h, w) stride between tiles.

    Returns:
        HWC numpy array with the merged prediction for the slice.
    """
    image = zimage[0, 0, :, :]
    print(f'Stack shape:{zimage.shape}')
    print(f'Slice shape:{image.shape}')

    # Cut large image into overlapping tiles.
    # BUG FIX: tile_size/tile_step were previously hard-coded to
    # (512, 512)/(256, 256), silently ignoring the parameters; defaults
    # are unchanged, so existing callers behave identically.
    tiler = ImageSlicer(image.shape,
                        tile_size=tile_size,
                        tile_step=tile_step)

    print(tiler.crops)

    # HCW -> CHW. Optionally, do normalization here
    tiles = [tensor_from_mask_image(tile) for tile in tiler.split(image)]

    # Allocate a CUDA buffer for holding entire mask
    merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)

    # Run predictions for tiles and accumulate them
    for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)),
                                                batch_size=1,
                                                pin_memory=True):
        tiles_batch = tiles_batch.float().cuda()
        pred_batch = model(tiles_batch)

        merger.integrate_batch(pred_batch, coords_batch)

    # CHW -> HWC and crop back to the original slice size.
    merged_mask = np.moveaxis(to_numpy(merger.merge()), 0, -1)
    merged_mask = tiler.crop_to_orignal_size(merged_mask)

    return merged_mask
Exemplo n.º 17
0
def test_tiles_split_merge_non_dividable_cuda():
    """CUDA round-trip: integrate tiles one at a time and recover the image."""
    # CONSISTENCY FIX: skip when no GPU is present, like the other CUDA test
    # in this file; previously this crashed on CPU-only machines.
    if not torch.cuda.is_available():
        return

    # NOTE(review): random floats in [0, 1) cast to uint8 are all zeros, so
    # the assertion runs on a zero image — consider randint for coverage.
    image = np.random.random((5632, 5120, 3)).astype(np.uint8)
    tiler = ImageSlicer(image.shape,
                        tile_size=(1280, 1280),
                        tile_step=(1280, 1280),
                        weight='mean')
    tiles = tiler.split(image)

    merger = CudaTileMerger(tiler.target_shape,
                            channels=image.shape[2],
                            weight=tiler.weight)
    for tile, coordinates in zip(tiles, tiler.crops):
        # Integrate as batch of size 1
        merger.integrate_batch(
            tensor_from_rgb_image(tile).unsqueeze(0).float().cuda(),
            [coordinates])

    merged = merger.merge()
    merged = rgb_image_from_tensor(merged, mean=0, std=1, max_pixel_value=1)
    merged = tiler.crop_to_orignal_size(merged)

    np.testing.assert_equal(merged, image)
Exemplo n.º 18
0
class TiledSingleImageDataset(Dataset):
    """Dataset of overlapping tiles cut from a single image/mask pair.

    Tiles are either pre-cut and cached (keep_in_mem=True) or re-loaded
    and cut lazily on every access.
    """

    def __init__(self,
                 image_fname: str,
                 mask_fname: str,
                 image_loader: Callable,
                 target_loader: Callable,
                 tile_size,
                 tile_step,
                 image_margin=0,
                 transform=None,
                 target_shape=None,
                 keep_in_mem=False):
        """Validate the pair, build the slicer, optionally cache tiles.

        Raises:
            ValueError: if image and mask width/height differ.
        """
        self.image_fname = image_fname
        self.mask_fname = mask_fname
        self.image_loader = image_loader
        self.mask_loader = target_loader
        self.image = None
        self.mask = None

        if target_shape is None or keep_in_mem:
            image = image_loader(image_fname)
            mask = target_loader(mask_fname)
            if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[
                    1]:
                # BUG FIX: the message previously interpolated image.shape
                # twice instead of reporting the mask's shape.
                raise ValueError(
                    f"Image size {image.shape} and mask shape {mask.shape} must have equal width and height"
                )

            target_shape = image.shape

        self.slicer = ImageSlicer(target_shape, tile_size, tile_step,
                                  image_margin)

        if keep_in_mem:
            # Pre-cut all tiles once and keep them in memory.
            self.images = self.slicer.split(image)
            self.masks = self.slicer.split(mask)
        else:
            self.images = None
            self.masks = None

        self.transform = transform
        # One id per tile: filename plus the crop rectangle.
        self.image_ids = [
            id_from_fname(image_fname) +
            f' [{crop[0]};{crop[1]};{crop[2]};{crop[3]};]'
            for crop in self.slicer.crops
        ]

    def _get_image(self, index):
        """Return image tile `index`, loading lazily when not cached."""
        if self.images is None:
            image = self.image_loader(self.image_fname)
            image = self.slicer.cut_patch(image, index)
        else:
            image = self.images[index]
        return image

    def _get_mask(self, index):
        """Return mask tile `index`, loading lazily when not cached."""
        if self.masks is None:
            mask = self.mask_loader(self.mask_fname)
            mask = self.slicer.cut_patch(mask, index)
        else:
            mask = self.masks[index]
        return mask

    def __len__(self):
        return len(self.slicer.crops)

    def __getitem__(self, index):
        image = self._get_image(index)
        mask = self._get_mask(index)
        data = self.transform(image=image, mask=mask)

        return {
            'features': tensor_from_rgb_image(data['image']),
            'targets': tensor_from_mask_image(data['mask']).float(),
            'image_id': self.image_ids[index]
        }
        input_x = args_experiment.crop_size[0]
        input_y = args_experiment.crop_size[1]
    except AttributeError:
        input_x = config['training']['crop_size'][0]
        input_y = config['training']['crop_size'][1]
    for file in tqdm(files, desc='Running inference'):

        img_full = cv2.imread(file)
        #img_full = np.flip(img_full, axis=0)

        x, y, ch = img_full.shape
        mask_full = np.zeros((x, y))

        # Cut large image into overlapping tiles
        tiler = ImageSlicer(img_full.shape,
                            tile_size=(input_x, input_y),
                            tile_step=(input_x // 2, input_y // 2),
                            weight=args.weight)

        # HCW -> CHW. Optionally, do normalization here
        tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(img_full)]

        # Allocate a CUDA buffer for holding entire mask
        merger = CudaTileMerger(tiler.target_shape,
                                channels=1,
                                weight=tiler.weight)

        # Loop evaluating inference on every fold
        masks = []
        for fold in range(len(models)):

            # Run predictions for tiles and accumulate them
Exemplo n.º 20
0
def predict_gradcam_mask(image,
                         model,
                         dims=3,
                         size=(150, 300),
                         step=(150, 300),
                         batch_size=8,
                         grad_thr=0.6,
                         weight_type='mean',
                         plot_image=False,
                         dstdir=None,
                         img_name='image1.png'):
    """Build a weak-supervision mask from Grad-CAM heatmaps of a classifier.

    Tiles the image, classifies each tile, and for tiles predicted as the
    positive class (label 1) computes a Grad-CAM heatmap thresholded at
    `grad_thr`; the per-tile heatmaps are merged into a full-size mask.

    Args:
        image: 2D or 3D input image.
        model: CUDA classifier, used both for prediction and for Grad-CAM.
        dims: desired channel count of the model input.
        size: tile size (h, w).
        step: tile step (h, w).
        batch_size: tiles per inference batch.
        grad_thr: Grad-CAM values below this are zeroed.
        weight_type: tile blending scheme for ImageSlicer.
        plot_image: when True, save a side-by-side figure to dstdir.
        dstdir: output directory for the figure (required if plot_image).
        img_name: filename of the saved figure.

    Returns:
        HWC mask after the uint8 cast and /1000 rescaling.
    """
    image = scale_img(image).astype(np.float32)

    # Coerce the channel dimension to `dims`.
    if image.ndim == 2:
        image = np.expand_dims(image, 2)
    if image.shape[-1] != dims:
        if image.shape[-1] == 1:
            image = np.repeat(image, 3, axis=2)
        elif image.shape[-1] == 3:
            image = np.expand_dims(image[:, :, 0], 2)

    # Rescale intensities to [-1, 1].
    image = (image - np.min(image)) / (0.5 * np.ptp(image)) - 1

    # Cut large image into overlapping tiles
    tiler = ImageSlicer(image.shape,
                        tile_size=size,
                        tile_step=step,
                        weight=weight_type)

    # HCW -> CHW. Optionally, do normalization here
    tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image)]

    # Allocate a CUDA buffer for holding entire mask
    merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)

    # Run predictions for tiles and accumulate them
    for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)),
                                                batch_size=batch_size,
                                                pin_memory=True):
        tiles_batch = tiles_batch.float().cuda()
        # Classification only — no gradients needed here; Grad-CAM below
        # runs outside this no_grad block.
        with torch.no_grad():
            pred_batch = torch.max(F.softmax(model(tiles_batch), dim=1),
                                   dim=1)[1].detach().cpu().numpy()
        # Tiles whose predicted class is 1 get a Grad-CAM heatmap.
        image_needed_classes = pred_batch == 1
        masks = []
        for tile_idx, has_target in enumerate(image_needed_classes):
            if has_target:
                tile = tiles_batch[tile_idx].unsqueeze(0)
                heatmap, mask = show_gradcam(tile, model)
                mask = mask[:, :, 0]
                # Suppress weak activations below the threshold.
                mask[mask < grad_thr] = 0
                masks.append(torch.Tensor(mask).unsqueeze(0).unsqueeze(0))
            else:
                # Negative tiles contribute an all-zero heatmap.
                masks.append(
                    torch.zeros_like(tiles_batch[tile_idx,
                                                 0]).unsqueeze(0).unsqueeze(0))
        # Scale by 1000 so fractional heatmap values survive the uint8 cast
        # after merging; undone by the /1000 below.
        masks = torch.cat([mask.cuda() for mask in masks], dim=0) * 1000

        merger.integrate_batch(masks, coords_batch)

    # Normalize accumulated mask and convert back to numpy.
    # NOTE(review): merged values can reach ~1000, which wraps when cast to
    # uint8 (max 255) — confirm the intended value range here.
    merged_mask = np.moveaxis(to_numpy(merger.merge()), 0, -1).astype(np.uint8)
    merged_mask = tiler.crop_to_orignal_size(merged_mask) / 1000

    if plot_image:
        assert dstdir is not None, 'dstdir should be passed'
        fig, ax = plt.subplots(ncols=2, figsize=(20, 10))
        ax[0].imshow(image[:, :, 0], cmap='gray')
        ax[1].imshow(merged_mask[:, :, 0], alpha=0.3)
        fig.savefig(osp.join(dstdir, img_name),
                    bbox_inches='tight',
                    pad_inches=0)
        print(osp.join(dstdir, img_name))
    return merged_mask
Exemplo n.º 21
0
class _InrialTiledImageMaskDataset(Dataset):
    """Tiled image/mask dataset for a single Inria image pair.

    Tiles are always re-loaded and cut lazily on access; `keep_in_mem` only
    forces an eager load for shape validation (tiles are not cached).
    """

    def __init__(
        self,
        image_fname: str,
        mask_fname: str,
        image_loader: Callable,
        target_loader: Callable,
        tile_size,
        tile_step,
        image_margin=0,
        transform=None,
        target_shape=None,
        need_weight_mask=False,
        keep_in_mem=False,
        make_mask_target_fn: Callable = mask_to_bce_target,
    ):
        """Validate the pair and build the tile slicer.

        Raises:
            ValueError: if image and mask width/height differ.
        """
        self.image_fname = image_fname
        self.mask_fname = mask_fname
        self.image_loader = image_loader
        self.mask_loader = target_loader
        self.image = None
        self.mask = None
        self.need_weight_mask = need_weight_mask

        if target_shape is None or keep_in_mem:
            image = image_loader(image_fname)
            mask = target_loader(mask_fname)
            if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[
                    1]:
                # BUG FIX: the message previously interpolated image.shape
                # twice instead of reporting the mask's shape.
                raise ValueError(
                    f"Image size {image.shape} and mask shape {mask.shape} must have equal width and height"
                )

            target_shape = image.shape

        self.slicer = ImageSlicer(target_shape, tile_size, tile_step,
                                  image_margin)

        self.transform = transform
        # Same id for every tile; crop rectangles are kept separately.
        self.image_ids = [fs.id_from_fname(image_fname)] * len(
            self.slicer.crops)
        self.crop_coords_str = [
            f"[{crop[0]};{crop[1]};{crop[2]};{crop[3]};]"
            for crop in self.slicer.crops
        ]
        self.make_mask_target_fn = make_mask_target_fn

    def _get_image(self, index):
        """Load the image and cut tile `index` (no caching)."""
        image = self.image_loader(self.image_fname)
        image = self.slicer.cut_patch(image, index)
        return image

    def _get_mask(self, index):
        """Load the mask and cut tile `index` (no caching)."""
        mask = self.mask_loader(self.mask_fname)
        mask = self.slicer.cut_patch(mask, index)
        return mask

    def __len__(self):
        return len(self.slicer.crops)

    def __getitem__(self, index):
        image = self._get_image(index)
        mask = self._get_mask(index)
        data = self.transform(image=image, mask=mask)

        image = data["image"]
        mask = data["mask"]

        data = {
            INPUT_IMAGE_KEY: image_to_tensor(image),
            INPUT_MASK_KEY: self.make_mask_target_fn(mask),
            INPUT_IMAGE_ID_KEY: self.image_ids[index],
            "crop_coords": self.crop_coords_str[index],
        }

        if self.need_weight_mask:
            data[INPUT_MASK_WEIGHT_KEY] = tensor_from_mask_image(
                compute_weight_mask(mask)).float()

        return data