def infer_one(model, mask, tile_size=(512, 512), tile_step=(256, 256), weight='mean'):
    """Run ``model`` over ``mask`` tile by tile and return channel 0 of the merged output.

    The input is copied to the CPU, moved from channel-first to channel-last
    layout, cut into overlapping tiles by ``ImageSlicer``, predicted on the GPU in
    batches of 10 and blended back together with a single-channel ``CudaTileMerger``.

    Args:
        model: Callable applied to float CUDA tile batches; its output is merged
            into a 1-channel buffer.
        mask: Tensor assumed to be channel-first (it is moved to HWC with
            ``np.moveaxis(..., 0, -1)``) — TODO confirm with callers.
        tile_size: Tile size forwarded to ``ImageSlicer``.
        tile_step: Stride between tiles forwarded to ``ImageSlicer``.
        weight: Tile blending weight scheme forwarded to ``ImageSlicer``.

    Returns:
        ``np.ndarray`` holding channel 0 of the merged prediction, cropped back to
        the input's spatial size.
    """
    image = np.moveaxis(mask.cpu().numpy(), 0, -1)  # CHW -> HWC for ImageSlicer
    with torch.no_grad():
        # Lay tiles out for the real input size; a fixed (900, 900) shape broke
        # every input that was not exactly 900x900.
        tiler = ImageSlicer(image.shape, tile_size=tile_size, tile_step=tile_step, weight=weight)
        tiles = [np.moveaxis(tile, -1, 0) for tile in tiler.split(image)]  # HWC -> CHW
        merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)
        loader = DataLoader(list(zip(tiles, tiler.crops)), batch_size=10, pin_memory=False)
        for tiles_batch, coords_batch in loader:
            pred_batch = model(tiles_batch.float().cuda())
            merger.integrate_batch(pred_batch, coords_batch)
        merged_mask = np.moveaxis(to_numpy(merger.merge()), 0, -1)
        merged_mask = tiler.crop_to_orignal_size(merged_mask)
    return merged_mask[..., 0].copy()
def predict(model: nn.Module, image: np.ndarray, image_size, normalize=A.Normalize(), batch_size=1) -> np.ndarray:
    """Predict a single-channel mask for ``image`` with half-overlapping tiles.

    Args:
        model: Network applied to each normalized tile batch on the GPU.
        image: Input image; its shape is passed to ``ImageSlicer``.
        image_size: ``(tile_height, tile_width)``; the step is half of each.
        normalize: Albumentations transform applied to each tile before it is
            converted to a tensor.
        batch_size: Number of tiles per forward pass.

    Returns:
        Merged prediction moved to channel-last layout and cropped to the input size.
    """
    import torch

    tile_step = (image_size[0] // 2, image_size[1] // 2)
    tile_slicer = ImageSlicer(image.shape, image_size, tile_step)
    tile_merger = CudaTileMerger(tile_slicer.target_shape, 1, tile_slicer.weight)
    patches = tile_slicer.split(image)

    transform = A.Compose([normalize, A.Lambda(image=_tensor_from_rgb_image)])
    # ``np.int`` was removed in NumPy 1.24; it was an alias of the builtin ``int``.
    data = [
        {"image": patch, "coords": np.array(coords, dtype=int)}
        for patch, coords in zip(patches, tile_slicer.crops)
    ]

    # Inference only: without no_grad every tile's autograd graph stays alive.
    with torch.no_grad():
        for batch in DataLoader(InMemoryDataset(data, transform), pin_memory=True, batch_size=batch_size):
            tiles = batch["image"].cuda(non_blocking=True)
            mask_batch = model(tiles)
            tile_merger.integrate_batch(mask_batch, batch["coords"])
        mask = tile_merger.merge()

    mask = np.moveaxis(to_numpy(mask), 0, -1)
    return tile_slicer.crop_to_orignal_size(mask)
def test_tiles_split_merge_cuda():
    """Tile an image, run a per-pixel max-over-channels model on CUDA, merge, and compare.

    Merging the tile predictions must reproduce ``image.max(axis=2)``. Returns
    early (does nothing) when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    class MaxChannelIntensity(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, input):
            max_channel, _ = torch.max(input, dim=1, keepdim=True)
            return max_channel

    # ``np.random.random(...).astype(np.uint8)`` is all zeros (values lie in [0, 1)),
    # which made the comparison vacuous; draw the full uint8 range instead.
    image = np.random.randint(0, 256, size=(5000, 5000, 3), dtype=np.uint8)
    tiler = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256), weight='pyramid')
    tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image)]

    model = MaxChannelIntensity().eval().cuda()
    merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)
    for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=8, pin_memory=True):
        pred_batch = model(tiles_batch.float().cuda())
        merger.integrate_batch(pred_batch, coords_batch)

    merged = np.moveaxis(to_numpy(merger.merge()), 0, -1)
    merged = tiler.crop_to_orignal_size(merged)
    # Pyramid-weighted blending is done in float32, so round before the uint8
    # cast rather than truncating e.g. 254.99998 down to 254.
    merged = np.rint(merged).astype(np.uint8)
    np.testing.assert_equal(merged, image.max(axis=2, keepdims=True))
def inference_tiles(inference_model, img_full, device='cuda', shape=(32, 1, 768, 448), weight='mean', mean=88.904434, std=62.048634, plot=False):
    """Predict a single-channel mask for a grayscale image using half-overlapping tiles.

    Args:
        inference_model: Model applied to each normalized tile batch.
        img_full: Grayscale image; converted with ``cv2.COLOR_GRAY2RGB`` before slicing.
        device: Device the normalized tile batch is moved to.
        shape: ``shape[0]`` is the batch size, ``shape[2]``/``shape[3]`` the tile
            size; ``shape[1]`` is not used here.
        weight: Tile blending weight scheme forwarded to ``ImageSlicer``.
        mean: Subtracted from every tile before prediction.
        std: Divides every tile after mean subtraction.
        plot: If True, show channel 0 of every normalized tile and of its prediction.

    Returns:
        Merged float32 prediction cropped to the input size, with singleton
        dimensions squeezed out.
    """
    bs = shape[0]
    input_x, input_y = shape[2], shape[3]

    # Cut the large image into tiles overlapping by half a tile in each direction
    tiler = ImageSlicer(img_full.shape, tile_size=(input_x, input_y), tile_step=(input_x // 2, input_y // 2), weight=weight)
    # HWC -> CHW tensors
    tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(cv2.cvtColor(img_full, cv2.COLOR_GRAY2RGB))]
    # Accumulation buffer for the whole mask lives on the GPU
    merger = CudaTileMerger(tiler.target_shape, channels=1, weight=tiler.weight)

    # no_grad: avoids holding autograd graphs and lets predictions be converted to
    # numpy for plotting (``.numpy()`` fails on tensors that require grad).
    with torch.no_grad():
        for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=bs, pin_memory=True):
            tiles_batch = ((tiles_batch.float() - mean) / std).to(device)
            pred_batch = inference_model(tiles_batch)
            merger.integrate_batch(pred_batch, coords_batch)
            if plot:
                # Copy each batch to the host once instead of once per tile
                tiles_np = tiles_batch.cpu().numpy()
                preds_np = pred_batch.cpu().numpy()
                for tile_np, pred_np in zip(tiles_np, preds_np):
                    plt.imshow(tile_np[0, :, :])
                    plt.show()
                    plt.imshow(pred_np[0, :, :])
                    plt.colorbar()
                    plt.show()
        merged = merger.merge()

    merged_mask = np.moveaxis(to_numpy(merged), 0, -1).astype('float32')
    merged_mask = tiler.crop_to_orignal_size(merged_mask)
    torch.cuda.empty_cache()
    return merged_mask.squeeze()
def test_tiles_split_merge_non_dividable():
    """Split/merge round trip must be lossless when the image size is not a multiple of the tile size."""
    # ``np.random.random(...).astype(np.uint8)`` is all zeros; use the full uint8 range
    # so the comparison actually checks pixel placement.
    image = np.random.randint(0, 256, size=(563, 512, 3), dtype=np.uint8)
    tiler = ImageSlicer(
        image.shape, tile_size=(128, 128), tile_step=(128, 128), weight="mean"
    )
    tiles = tiler.split(image)
    merged = tiler.merge(tiles, dtype=np.uint8)
    np.testing.assert_equal(merged, image)
def predict_mask(image, model, dims=3, size=394, step=192, batch_size=8, plot_image=False, dstdir=None, img_name='image1.png'):
    """Predict a per-pixel class-index mask for a large image using overlapping tiles.

    Args:
        image: 2-D or 3-D image array. A 2-D image gets a trailing channel axis.
            If the channel count differs from ``dims``, a 1-channel image is
            repeated to ``dims`` channels and a 3-channel image is reduced to its
            first channel.
        model: Segmentation model; the argmax over its dim 1 is merged.
        dims: Channel count the model expects.
        size: Square tile size.
        step: Stride between tiles.
        batch_size: Tiles per forward pass.
        plot_image: If True, save a side-by-side figure of image and mask.
        dstdir: Output directory for the figure (required when ``plot_image``).
        img_name: File name of the saved figure.

    Returns:
        ``np.uint8`` array ``(H, W, 1)``: pyramid-weighted blend of the per-tile
        class indices, cast to uint8 and cropped to the input size.
    """
    if image.ndim == 2:
        image = np.expand_dims(image, 2)
    if image.shape[-1] != dims:
        if image.shape[-1] == 1:
            # Replicate the single channel up to the requested channel count
            image = np.repeat(image, dims, axis=2)
        elif image.shape[-1] == 3:
            image = np.expand_dims(image[:, :, 0], 2)
    print(image.shape)

    # Cut the large image into overlapping tiles
    tiler = ImageSlicer(image.shape, tile_size=(size, size), tile_step=(step, step), weight='pyramid')
    # HWC -> CHW tensors
    tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image)]
    # Accumulation buffer for the whole mask lives on the GPU
    merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)

    with torch.no_grad():
        for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=batch_size, pin_memory=True):
            pred_batch = model(tiles_batch.float().cuda())
            # Index of the highest-scoring class along dim 1, per pixel
            pred_mask = pred_batch.max(dim=1)[1].float()
            merger.integrate_batch(pred_mask, coords_batch)

    merged_mask = np.moveaxis(to_numpy(merger.merge()), 0, -1).astype(np.uint8)
    merged_mask = tiler.crop_to_orignal_size(merged_mask)

    if plot_image:
        assert dstdir is not None, 'dstdir should be passed'
        fig, ax = plt.subplots(ncols=2, figsize=(20, 10))
        ax[0].imshow(image[:, :, 0], cmap='gray')
        ax[1].imshow(merged_mask[:, :, 0], alpha=0.3)
        fig.savefig(osp.join(dstdir, img_name), bbox_inches='tight', pad_inches=0)
        # Release the figure; otherwise every call leaks one open figure
        plt.close(fig)
        print(osp.join(dstdir, img_name))
    return merged_mask
def test_tiles_split_merge():
    """Split/merge round trip with half-overlapping tiles and mean weights must be lossless."""
    # ``np.random.random(...).astype(np.uint8)`` is all zeros; use the full uint8 range.
    image = np.random.randint(0, 256, size=(5000, 5000, 3), dtype=np.uint8)
    tiler = ImageSlicer(image.shape, tile_size=512, tile_step=256, weight='mean')
    tiles = tiler.split(image)
    merged = tiler.merge(tiles, dtype=np.uint8)
    np.testing.assert_equal(merged, image)
def test_tiles_split_merge_2():
    """Split/merge round trip with pyramid weights must reproduce the image (after rounding)."""
    # ``np.random.random(...).astype(np.uint8)`` is all zeros; use the full uint8 range.
    image = np.random.randint(0, 256, size=(5000, 5000, 3), dtype=np.uint8)
    tiler = ImageSlicer(
        image.shape, tile_size=(512, 512), tile_step=(256, 256), weight="pyramid"
    )
    np.testing.assert_allclose(tiler.weight, tiler.weight.T)
    tiles = tiler.split(image)
    # Pyramid weights make the blend a non-trivial weighted mean; merge in float
    # and round so float error cannot truncate a value to the next lower integer.
    merged = tiler.merge(tiles, dtype=np.float64)
    np.testing.assert_equal(np.rint(merged).astype(np.uint8), image)
def predict(model: nn.Module, image: np.ndarray, image_size, tta=None, normalize=A.Normalize(), batch_size=1, activation='sigmoid') -> np.ndarray:
    """Predict a single-channel mask for ``image`` with half-overlapping, pyramid-weighted tiles.

    Args:
        model: Network applied to each normalized tile batch; switched to eval mode.
        image: Input image; its shape is passed to ``ImageSlicer``.
        image_size: ``(tile_height, tile_width)``; the step is half of each.
        tta: ``'fliplr'`` or ``'d4'`` wraps the model in the matching TTA wrapper;
            any other value disables TTA.
        normalize: Albumentations transform applied to each tile.
        batch_size: Number of tiles per forward pass.
        activation: ``'sigmoid'`` applies a sigmoid to the merged mask; a float
            ``t`` replaces the merged mask with ``relu(mask - t)``.

    Returns:
        Merged mask moved to channel-last layout and cropped to the input size.
    """
    model.eval()
    tile_step = (image_size[0] // 2, image_size[1] // 2)
    tile_slicer = ImageSlicer(image.shape, image_size, tile_step, weight='pyramid')
    tile_merger = CudaTileMerger(tile_slicer.target_shape, 1, tile_slicer.weight)
    patches = tile_slicer.split(image)
    transform = A.Compose([normalize, A.Lambda(image=_tensor_from_rgb_image)])

    if tta == 'fliplr':
        model = TTAWrapperFlipLR(model)
        print('Using FlipLR TTA')
    if tta == 'd4':
        model = TTAWrapperD4(model)
        print('Using D4 TTA')

    with torch.no_grad():
        # ``np.int`` was removed in NumPy 1.24; it was an alias of the builtin ``int``.
        data = [
            {'image': patch, 'coords': np.array(coords, dtype=int)}
            for patch, coords in zip(patches, tile_slicer.crops)
        ]
        for batch in DataLoader(InMemoryDataset(data, transform), pin_memory=True, batch_size=batch_size):
            tiles = batch['image'].cuda(non_blocking=True)
            mask_batch = model(tiles)
            tile_merger.integrate_batch(mask_batch, batch['coords'])
        mask = tile_merger.merge()

    if activation == 'sigmoid':
        mask = mask.sigmoid()
    if isinstance(activation, float):
        # Threshold the merged mask; using the last batch's raw output here
        # returned a wrongly shaped array unrelated to the full image.
        mask = F.relu(mask - activation, inplace=True)

    mask = np.moveaxis(to_numpy(mask), 0, -1)
    return tile_slicer.crop_to_orignal_size(mask)
def inference(inference_model, img_full, device='cuda'):
    """Predict a single-channel mask for ``img_full`` using half-overlapping tiles.

    Reads the tile size from the module-level ``config['training']['crop_size']``
    and ``weight``, ``bs`` and ``plot`` from the module-level ``args``.

    Args:
        inference_model: Model applied to tile batches scaled to [0, 1] by ``/ 255``.
        img_full: Image to slice; tiles are converted with ``tensor_from_rgb_image``.
        device: Device each tile batch is moved to.

    Returns:
        Merged float32 prediction cropped to the input size, with singleton
        dimensions squeezed out.
    """
    input_x = config['training']['crop_size'][0]
    input_y = config['training']['crop_size'][1]

    # Cut the large image into tiles overlapping by half a tile in each direction
    tiler = ImageSlicer(img_full.shape, tile_size=(input_x, input_y), tile_step=(input_x // 2, input_y // 2), weight=args.weight)
    # HWC -> CHW tensors
    tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(img_full)]
    # Accumulation buffer for the whole mask lives on the GPU
    merger = CudaTileMerger(tiler.target_shape, channels=1, weight=tiler.weight)

    with torch.no_grad():
        for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=args.bs, pin_memory=True):
            tiles_batch = (tiles_batch.float() / 255.).to(device)
            pred_batch = inference_model(tiles_batch)
            merger.integrate_batch(pred_batch, coords_batch)
            if args.plot:
                # Iterate over the predictions actually present: the last batch
                # may hold fewer than args.bs tiles.
                for pred in pred_batch.cpu().numpy().astype('float32'):
                    plt.imshow(pred.squeeze())
                    plt.show()
        merged = merger.merge()

    merged_mask = np.moveaxis(to_numpy(merged), 0, -1).astype('float32')
    merged_mask = tiler.crop_to_orignal_size(merged_mask)

    if args.plot:
        # Show the merged mask once; imshow needs 2-D data, not (H, W, 1)
        plt.imshow(merged_mask.squeeze())
        plt.show()

    torch.cuda.empty_cache()
    gc.collect()
    return merged_mask.squeeze()
def get_slice_heatmap(slide_id, image_dir, model, crop_size=512, tile_step=512, bsize=16, num_workers=4, normalize=True, *args, **kwargs):
    """Build a per-tile class heat map for a slide and resize it to the slide's last level.

    Each processed tile is filled with its argmax class over ``dim=1`` of the
    softmaxed model output. Tiles are blended by ``CpuTileMerger`` and the result
    is resized to ``level_dimensions[-1]`` of the slide.

    Args:
        slide_id: Slide identifier; the file read is ``<image_dir>/<slide_id>.tiff``.
        image_dir: Directory containing the slide.
        model: Classifier applied to batches of tiles on the GPU.
        crop_size: Square tile size.
        tile_step: Stride between tiles.
        bsize: DataLoader batch size.
        num_workers: DataLoader worker count.
        normalize: Forwarded to ``CpuTileMerger``.
        *args: Accepted but unused.
        **kwargs: Forwarded to ``InferenceSingleImage``.

    Returns:
        ``np.uint8`` heat map resized to the slide's last level dimensions.
    """
    slide = openslide.OpenSlide(os.path.join(image_dir, f'{slide_id}.tiff'))
    try:
        # level_dimensions entries are (width, height); ImageSlicer wants (height, width)
        full_width, full_height = slide.level_dimensions[0]
        thumb_size = slide.level_dimensions[-1]
    finally:
        # Only the dimensions are read from this handle (the dataset receives
        # slide_id/image_dir, not the handle), so release it right away.
        slide.close()

    tiler = ImageSlicer((full_height, full_width), tile_size=(crop_size, crop_size), tile_step=(tile_step, tile_step), weight='mean')
    merger = CpuTileMerger(tiler.target_shape, 1, tiler.weight, normalize=normalize)
    dataset = InferenceSingleImage(tiler.crops, slide_id, image_dir, crop_size=crop_size, **kwargs)
    dataloader = DataLoader(dataset, batch_size=bsize, pin_memory=True, num_workers=num_workers)

    # Constant tile that gets multiplied by the tile's predicted class index
    pseudo_mask = np.ones((1, crop_size, crop_size))
    with torch.no_grad():
        for data_b in tqdm(dataloader, total=len(dataloader)):
            need_processing = data_b['need_to_process']
            if need_processing.sum() > 0:
                pred_batch = model(data_b['image'][need_processing].cuda())
                pred_batch = torch.nn.Softmax(dim=1)(pred_batch).cpu().numpy()
                pred_batch = np.stack([pseudo_mask * np.argmax(probs) for probs in pred_batch])
                merger.integrate_batch(pred_batch, data_b['coords'][need_processing].cpu().numpy())

    merged_mask = np.moveaxis(to_numpy(merger.merge()), 0, -1).astype(np.uint8)
    merged_mask = tiler.crop_to_orignal_size(merged_mask)
    return cv2.resize(merged_mask, thumb_size, interpolation=cv2.INTER_AREA)
def split_image(image_fname, output_dir, tile_size, tile_step, image_margin):
    """Cut an image file into tiles and write each tile as a PNG.

    Args:
        image_fname: Image to read with ``read_image_as_is``.
        output_dir: Directory for the tiles; created if missing.
        tile_size: Forwarded to ``ImageSlicer``.
        tile_step: Forwarded to ``ImageSlicer``.
        image_margin: Forwarded to ``ImageSlicer``.

    Returns:
        List of written file names, ``<image_id>_tile_<i>.png`` in tile order.

    Raises:
        OSError: If ``cv2.imwrite`` reports that a tile could not be written.
    """
    os.makedirs(output_dir, exist_ok=True)
    image = read_image_as_is(image_fname)
    image_id = id_from_fname(image_fname)

    slicer = ImageSlicer(image.shape, tile_size, tile_step, image_margin)
    tiles = slicer.split(image)

    fnames = []
    for i, tile in enumerate(tiles):
        output_fname = os.path.join(output_dir, f"{image_id}_tile_{i}.png")
        # cv2.imwrite signals failure by returning False rather than raising
        if not cv2.imwrite(output_fname, tile):
            raise OSError(f"Failed to write tile {output_fname}")
        fnames.append(output_fname)
    return fnames
def __init__(
    self,
    image_fname: str,
    mask_fname: str,
    image_loader: Callable,
    target_loader: Callable,
    tile_size,
    tile_step,
    image_margin=0,
    transform=None,
    target_shape=None,
    keep_in_mem=False,
):
    """Lay out tiles over one image/mask pair, optionally pre-cutting them into memory.

    Args:
        image_fname: Path of the image, passed to ``image_loader``.
        mask_fname: Path of the mask, passed to ``target_loader``.
        image_loader: Callable returning the image array for a file name.
        target_loader: Callable returning the mask array for a file name.
        tile_size: Forwarded to ``ImageSlicer``.
        tile_step: Forwarded to ``ImageSlicer``.
        image_margin: Forwarded to ``ImageSlicer``.
        transform: Stored as ``self.transform``.
        target_shape: Shape used to lay out tiles. When None (or when
            ``keep_in_mem``), both files are loaded and the image's shape is used.
        keep_in_mem: Pre-cut all image and mask tiles into ``self.images`` / ``self.masks``.

    Raises:
        ValueError: If the loaded image and mask differ in height or width.
    """
    self.image_fname = image_fname
    self.mask_fname = mask_fname
    self.image_loader = image_loader
    self.mask_loader = target_loader
    self.image = None
    self.mask = None

    if target_shape is None or keep_in_mem:
        image = image_loader(image_fname)
        mask = target_loader(mask_fname)
        if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[1]:
            # Report the mask's own shape (the message previously printed image.shape twice)
            raise ValueError(
                f"Image size {image.shape} and mask shape {mask.shape} must have equal width and height"
            )
        target_shape = image.shape

    self.slicer = ImageSlicer(target_shape, tile_size, tile_step, image_margin)

    if keep_in_mem:
        self.images = self.slicer.split(image)
        self.masks = self.slicer.split(mask)
    else:
        self.images = None
        self.masks = None

    self.transform = transform
    # One id per tile: file id followed by the tile's crop coordinates
    self.image_ids = [
        id_from_fname(image_fname) + f" [{crop[0]};{crop[1]};{crop[2]};{crop[3]};]"
        for crop in self.slicer.crops
    ]
def __init__(
    self,
    image_fname: str,
    mask_fname: str,
    image_loader: Callable,
    target_loader: Callable,
    tile_size,
    tile_step,
    image_margin=0,
    transform=None,
    target_shape=None,
    need_weight_mask=False,
    keep_in_mem=False,
    make_mask_target_fn: Callable = mask_to_bce_target,
):
    """Lay out tiles over one image/mask pair.

    Args:
        image_fname: Path of the image, passed to ``image_loader``.
        mask_fname: Path of the mask, passed to ``target_loader``.
        image_loader: Callable returning the image array for a file name.
        target_loader: Callable returning the mask array for a file name.
        tile_size: Forwarded to ``ImageSlicer``.
        tile_step: Forwarded to ``ImageSlicer``.
        image_margin: Forwarded to ``ImageSlicer``.
        transform: Stored as ``self.transform``.
        target_shape: Shape used to lay out tiles. When None (or when
            ``keep_in_mem``), both files are loaded and the image's shape is used.
        need_weight_mask: Stored as ``self.need_weight_mask``.
        keep_in_mem: Only forces loading both files here; no tiles are cached.
        make_mask_target_fn: Stored as ``self.make_mask_target_fn``.

    Raises:
        ValueError: If the loaded image and mask differ in height or width.
    """
    self.image_fname = image_fname
    self.mask_fname = mask_fname
    self.image_loader = image_loader
    self.mask_loader = target_loader
    self.image = None
    self.mask = None
    self.need_weight_mask = need_weight_mask

    if target_shape is None or keep_in_mem:
        image = image_loader(image_fname)
        mask = target_loader(mask_fname)
        if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[1]:
            # Report the mask's own shape (the message previously printed image.shape twice)
            raise ValueError(
                f"Image size {image.shape} and mask shape {mask.shape} must have equal width and height"
            )
        target_shape = image.shape

    self.slicer = ImageSlicer(target_shape, tile_size, tile_step, image_margin)

    self.transform = transform
    # Same file id for every tile; tiles are told apart by crop_coords_str
    self.image_ids = [fs.id_from_fname(image_fname)] * len(self.slicer.crops)
    self.crop_coords_str = [
        f"[{crop[0]};{crop[1]};{crop[2]};{crop[3]};]"
        for crop in self.slicer.crops
    ]
    self.make_mask_target_fn = make_mask_target_fn
def run_validation(data_df, model, data_folder, augmentation, tiles=False):
    """Compute Dice scores of ``model`` over the images listed in ``data_df``.

    Each image is read from ``data_folder`` (file name = row index), passed
    through ``augmentation``, predicted (whole image or tiled), passed through a
    sigmoid and thresholded at 0.5. Dice is computed for each of the 4 mask
    channels returned by ``make_mask``.

    Args:
        data_df: DataFrame whose index holds the image file names.
        model: Network producing 4-channel predictions.
        data_folder: Directory containing the images.
        augmentation: Callable invoked as ``augmentation(image=...)``, returning a dict with ``'image'``.
        tiles: If True, predict with 224px tiles (56px step) merged on the GPU.

    Returns:
        ``(mean Dice over all image/channel pairs, list of per-image mean Dice)``.
    """
    total_dice_coeffs = []
    mean_dice_per_image = []
    sigmoid = torch.nn.Sigmoid()

    for image_n in tqdm(range(data_df.shape[0])):
        image = cv2.imread(os.path.join(data_folder, data_df.index.values[image_n]))
        image_processed = augmentation(image=image)['image']

        # Validation only: keep autograd from recording graphs for every tile
        with torch.no_grad():
            if tiles:
                tiler = ImageSlicer(image_processed.shape[:2], tile_size=(224, 224), tile_step=(56, 56), weight='mean')
                merger = CudaTileMerger(tiler.target_shape, 4, tiler.weight)
                # Separate name: ``tiles`` is this function's boolean flag
                tile_tensors = [tensor_from_rgb_image(tile) for tile in tiler.split(image_processed)]
                for tiles_batch, coords_batch in DataLoader(list(zip(tile_tensors, tiler.crops)), batch_size=16, pin_memory=True):
                    pred_batch = sigmoid(model(tiles_batch.float().cuda()))
                    merger.integrate_batch(pred_batch, coords_batch)
                predictions = np.moveaxis(to_numpy(merger.merge()), 0, -1)
                predictions = tiler.crop_to_orignal_size(predictions)
            else:
                batch = torch.from_numpy(np.expand_dims(image_processed.transpose((2, 0, 1)), 0)).float()
                predictions = sigmoid(model(batch.cuda())[0]).detach().cpu().numpy()
                predictions = np.moveaxis(predictions, 0, -1)

        predictions_bin = (predictions > 0.5).astype(int)
        _, masks = make_mask(image_n, data_df)

        dices_image = []
        for defect_type in range(4):
            computed_dice = dice(masks[:, :, defect_type], predictions_bin[:, :, defect_type])
            total_dice_coeffs.append(computed_dice)
            dices_image.append(computed_dice)
        mean_dice_per_image.append(np.mean(dices_image))

    return np.mean(total_dice_coeffs), mean_dice_per_image
def predict_on_zslice_tiles(model, zimage, tile_size=(512, 512), tile_step=(256, 256)):
    """Predict a single-channel mask for the first slice ``zimage[0, 0]`` using tiles.

    Args:
        model: Network applied to each tile (batch size 1) on the GPU.
        zimage: Array indexed as ``zimage[0, 0, :, :]`` to obtain the 2-D slice.
        tile_size: Tile size forwarded to ``ImageSlicer``.
        tile_step: Stride between tiles forwarded to ``ImageSlicer``.

    Returns:
        Merged prediction in channel-last layout, cropped to the slice size.
    """
    import torch

    image = zimage[0, 0, :, :]
    print(f'Stack shape:{zimage.shape}')
    print(f'Slice shape:{image.shape}')

    # Honour the caller's tiling parameters (they used to be ignored)
    tiler = ImageSlicer(image.shape, tile_size=tile_size, tile_step=tile_step)
    print(tiler.crops)

    tiles = [tensor_from_mask_image(tile) for tile in tiler.split(image)]
    # Accumulation buffer for the whole mask lives on the GPU
    merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)

    with torch.no_grad():
        for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=1, pin_memory=True):
            pred_batch = model(tiles_batch.float().cuda())
            merger.integrate_batch(pred_batch, coords_batch)
        merged = merger.merge()

    merged_mask = np.moveaxis(to_numpy(merged), 0, -1)
    return tiler.crop_to_orignal_size(merged_mask)
def test_tiles_split_merge_non_dividable_cuda():
    """Non-overlapping CUDA split/merge must be lossless for a non-tile-multiple image.

    Returns early (does nothing) when CUDA is not available.
    """
    import torch

    if not torch.cuda.is_available():
        return

    # ``np.random.random(...).astype(np.uint8)`` is all zeros; use the full uint8 range.
    image = np.random.randint(0, 256, size=(5632, 5120, 3), dtype=np.uint8)
    tiler = ImageSlicer(image.shape, tile_size=(1280, 1280), tile_step=(1280, 1280), weight='mean')
    tiles = tiler.split(image)

    merger = CudaTileMerger(tiler.target_shape, channels=image.shape[2], weight=tiler.weight)
    for tile, coordinates in zip(tiles, tiler.crops):
        # Integrate as a batch of size 1
        merger.integrate_batch(tensor_from_rgb_image(tile).unsqueeze(0).float().cuda(), [coordinates])

    merged = merger.merge()
    merged = rgb_image_from_tensor(merged, mean=0, std=1, max_pixel_value=1)
    merged = tiler.crop_to_orignal_size(merged)
    np.testing.assert_equal(merged, image)
class TiledSingleImageDataset(Dataset):
    """Dataset of aligned (image, mask) tiles cut from one image/mask file pair.

    Tiles are laid out by ``ImageSlicer``. With ``keep_in_mem=True`` every tile is
    cut once in ``__init__``; otherwise the whole image/mask is re-loaded with the
    loader callables and the requested tile is cut out on each access.
    """

    def __init__(self, image_fname: str, mask_fname: str, image_loader: Callable, target_loader: Callable,
                 tile_size, tile_step, image_margin=0, transform=None, target_shape=None, keep_in_mem=False):
        """
        Args:
            image_fname: Path of the image, passed to ``image_loader``.
            mask_fname: Path of the mask, passed to ``target_loader``.
            image_loader: Callable returning the image array for a file name.
            target_loader: Callable returning the mask array for a file name.
            tile_size: Forwarded to ``ImageSlicer``.
            tile_step: Forwarded to ``ImageSlicer``.
            image_margin: Forwarded to ``ImageSlicer``.
            transform: Optional callable invoked as ``transform(image=..., mask=...)``
                and returning a dict with ``'image'`` and ``'mask'``; skipped when None.
            target_shape: Shape used to lay out tiles. When None (or when
                ``keep_in_mem``), both files are loaded and the image's shape is used.
            keep_in_mem: Pre-cut all tiles and keep them in memory.

        Raises:
            ValueError: If the loaded image and mask differ in height or width.
        """
        self.image_fname = image_fname
        self.mask_fname = mask_fname
        self.image_loader = image_loader
        self.mask_loader = target_loader
        self.image = None
        self.mask = None

        if target_shape is None or keep_in_mem:
            image = image_loader(image_fname)
            mask = target_loader(mask_fname)
            if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[1]:
                # Report the mask's own shape (the message previously printed image.shape twice)
                raise ValueError(
                    f"Image size {image.shape} and mask shape {mask.shape} must have equal width and height"
                )
            target_shape = image.shape

        self.slicer = ImageSlicer(target_shape, tile_size, tile_step, image_margin)

        if keep_in_mem:
            self.images = self.slicer.split(image)
            self.masks = self.slicer.split(mask)
        else:
            self.images = None
            self.masks = None

        self.transform = transform
        # One id per tile: file id followed by the tile's crop coordinates
        self.image_ids = [
            id_from_fname(image_fname) + f' [{crop[0]};{crop[1]};{crop[2]};{crop[3]};]'
            for crop in self.slicer.crops
        ]

    def _get_image(self, index):
        """Return image tile ``index``, from memory or by re-loading and cutting the file."""
        if self.images is None:
            image = self.image_loader(self.image_fname)
            return self.slicer.cut_patch(image, index)
        return self.images[index]

    def _get_mask(self, index):
        """Return mask tile ``index``, from memory or by re-loading and cutting the file."""
        if self.masks is None:
            mask = self.mask_loader(self.mask_fname)
            return self.slicer.cut_patch(mask, index)
        return self.masks[index]

    def __len__(self):
        return len(self.slicer.crops)

    def __getitem__(self, index):
        image = self._get_image(index)
        mask = self._get_mask(index)
        # transform defaults to None; calling it unconditionally raised TypeError
        if self.transform is not None:
            data = self.transform(image=image, mask=mask)
            image, mask = data['image'], data['mask']
        return {
            'features': tensor_from_rgb_image(image),
            'targets': tensor_from_mask_image(mask).float(),
            'image_id': self.image_ids[index]
        }
input_x = args_experiment.crop_size[0] input_y = args_experiment.crop_size[1] except AttributeError: input_x = config['training']['crop_size'][0] input_y = config['training']['crop_size'][1] for file in tqdm(files, desc='Running inference'): img_full = cv2.imread(file) #img_full = np.flip(img_full, axis=0) x, y, ch = img_full.shape mask_full = np.zeros((x, y)) # Cut large image into overlapping tiles tiler = ImageSlicer(img_full.shape, tile_size=(input_x, input_y), tile_step=(input_x // 2, input_y // 2), weight=args.weight) # HCW -> CHW. Optionally, do normalization here tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(img_full)] # Allocate a CUDA buffer for holding entire mask merger = CudaTileMerger(tiler.target_shape, channels=1, weight=tiler.weight) # Loop evaluating inference on every fold masks = [] for fold in range(len(models)): # Run predictions for tiles and accumulate them
def predict_gradcam_mask(image, model, dims=3, size=(150, 300), step=(150, 300), batch_size=8, grad_thr=0.6, weight_type='mean', plot_image=False, dstdir=None, img_name='image1.png'):
    """Build a Grad-CAM heat map over a large image by classifying and explaining tiles.

    Each tile is classified (argmax over dim 1 of the softmaxed output). Tiles
    predicted as class 1 get a Grad-CAM mask from ``show_gradcam``, with values
    below ``grad_thr`` zeroed; all other tiles contribute zeros. Tile masks are
    blended by ``CudaTileMerger``.

    Args:
        image: Input image, passed through ``scale_img`` and converted to float32.
            A 2-D image gets a trailing channel axis. If the channel count
            differs from ``dims``, a 1-channel image is repeated to ``dims``
            channels and a 3-channel image is reduced to its first channel.
            Intensities are then rescaled to [-1, 1] via min / peak-to-peak.
        model: Classifier used both for tile prediction and for Grad-CAM.
        dims: Channel count the model expects.
        size: Tile size forwarded to ``ImageSlicer``.
        step: Stride between tiles forwarded to ``ImageSlicer``.
        batch_size: Tiles per forward pass.
        grad_thr: Grad-CAM values below this are set to 0.
        weight_type: Tile blending weight scheme forwarded to ``ImageSlicer``.
        plot_image: If True, save a side-by-side figure of image and heat map.
        dstdir: Output directory for the figure (required when ``plot_image``).
        img_name: File name of the saved figure.

    Returns:
        Float64 heat map ``(H, W, 1)`` cropped to the input size.
    """
    image = scale_img(image).astype(np.float32)
    if image.ndim == 2:
        image = np.expand_dims(image, 2)
    if image.shape[-1] != dims:
        if image.shape[-1] == 1:
            # Replicate the single channel up to the requested channel count
            image = np.repeat(image, dims, axis=2)
        elif image.shape[-1] == 3:
            image = np.expand_dims(image[:, :, 0], 2)
    # Rescale intensities to [-1, 1]
    image = (image - np.min(image)) / (0.5 * np.ptp(image)) - 1

    # Cut the large image into (possibly overlapping) tiles
    tiler = ImageSlicer(image.shape, tile_size=size, tile_step=step, weight=weight_type)
    # HWC -> CHW tensors
    tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image)]
    # Accumulation buffer for the whole heat map lives on the GPU
    merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)

    for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=batch_size, pin_memory=True):
        tiles_batch = tiles_batch.float().cuda()
        with torch.no_grad():
            pred_batch = torch.max(F.softmax(model(tiles_batch), dim=1), dim=1)[1].detach().cpu().numpy()

        tile_masks = []
        # Grad-CAM needs gradients, so it runs outside no_grad and only for class-1 tiles
        for tile_idx, has_target in enumerate(pred_batch == 1):
            if has_target:
                _, cam = show_gradcam(tiles_batch[tile_idx].unsqueeze(0), model)
                cam = cam[:, :, 0]
                cam[cam < grad_thr] = 0
                tile_masks.append(torch.Tensor(cam).unsqueeze(0).unsqueeze(0))
            else:
                tile_masks.append(torch.zeros_like(tiles_batch[tile_idx, 0]).unsqueeze(0).unsqueeze(0))
        merger.integrate_batch(torch.cat([m.cuda() for m in tile_masks], dim=0), coords_batch)

    # Keep the blended heat map in floating point: casting to uint8 (even after
    # scaling) truncates fractions and wraps every value above 255.
    merged_mask = np.moveaxis(to_numpy(merger.merge()), 0, -1).astype(np.float64)
    merged_mask = tiler.crop_to_orignal_size(merged_mask)

    if plot_image:
        assert dstdir is not None, 'dstdir should be passed'
        fig, ax = plt.subplots(ncols=2, figsize=(20, 10))
        ax[0].imshow(image[:, :, 0], cmap='gray')
        ax[1].imshow(merged_mask[:, :, 0], alpha=0.3)
        fig.savefig(osp.join(dstdir, img_name), bbox_inches='tight', pad_inches=0)
        # Release the figure; otherwise every call leaks one open figure
        plt.close(fig)
        print(osp.join(dstdir, img_name))
    return merged_mask
class _InrialTiledImageMaskDataset(Dataset):
    """Dataset of aligned (image, mask) tiles cut on demand from one image/mask file pair.

    Every access re-loads the full image and mask with the loader callables and
    cuts out the requested tile laid out by ``ImageSlicer``.
    """

    def __init__(
        self,
        image_fname: str,
        mask_fname: str,
        image_loader: Callable,
        target_loader: Callable,
        tile_size,
        tile_step,
        image_margin=0,
        transform=None,
        target_shape=None,
        need_weight_mask=False,
        keep_in_mem=False,
        make_mask_target_fn: Callable = mask_to_bce_target,
    ):
        """
        Args:
            image_fname: Path of the image, passed to ``image_loader``.
            mask_fname: Path of the mask, passed to ``target_loader``.
            image_loader: Callable returning the image array for a file name.
            target_loader: Callable returning the mask array for a file name.
            tile_size: Forwarded to ``ImageSlicer``.
            tile_step: Forwarded to ``ImageSlicer``.
            image_margin: Forwarded to ``ImageSlicer``.
            transform: Optional callable invoked as ``transform(image=..., mask=...)``
                and returning a dict with ``"image"`` and ``"mask"``; skipped when None.
            target_shape: Shape used to lay out tiles. When None (or when
                ``keep_in_mem``), both files are loaded and the image's shape is used.
            need_weight_mask: Also return a weight mask computed from the tile mask.
            keep_in_mem: Only forces loading both files here; no tiles are cached.
            make_mask_target_fn: Converts the (transformed) mask into the target tensor.

        Raises:
            ValueError: If the loaded image and mask differ in height or width.
        """
        self.image_fname = image_fname
        self.mask_fname = mask_fname
        self.image_loader = image_loader
        self.mask_loader = target_loader
        self.image = None
        self.mask = None
        self.need_weight_mask = need_weight_mask

        if target_shape is None or keep_in_mem:
            image = image_loader(image_fname)
            mask = target_loader(mask_fname)
            if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[1]:
                # Report the mask's own shape (the message previously printed image.shape twice)
                raise ValueError(
                    f"Image size {image.shape} and mask shape {mask.shape} must have equal width and height"
                )
            target_shape = image.shape

        self.slicer = ImageSlicer(target_shape, tile_size, tile_step, image_margin)

        self.transform = transform
        # Same file id for every tile; tiles are told apart by crop_coords_str
        self.image_ids = [fs.id_from_fname(image_fname)] * len(self.slicer.crops)
        self.crop_coords_str = [
            f"[{crop[0]};{crop[1]};{crop[2]};{crop[3]};]"
            for crop in self.slicer.crops
        ]
        self.make_mask_target_fn = make_mask_target_fn

    def _get_image(self, index):
        """Load the full image and cut out tile ``index``."""
        image = self.image_loader(self.image_fname)
        return self.slicer.cut_patch(image, index)

    def _get_mask(self, index):
        """Load the full mask and cut out tile ``index``."""
        mask = self.mask_loader(self.mask_fname)
        return self.slicer.cut_patch(mask, index)

    def __len__(self):
        return len(self.slicer.crops)

    def __getitem__(self, index):
        image = self._get_image(index)
        mask = self._get_mask(index)

        # transform defaults to None; calling it unconditionally raised TypeError
        if self.transform is not None:
            data = self.transform(image=image, mask=mask)
            image = data["image"]
            mask = data["mask"]

        data = {
            INPUT_IMAGE_KEY: image_to_tensor(image),
            INPUT_MASK_KEY: self.make_mask_target_fn(mask),
            INPUT_IMAGE_ID_KEY: self.image_ids[index],
            "crop_coords": self.crop_coords_str[index],
        }
        if self.need_weight_mask:
            data[INPUT_MASK_WEIGHT_KEY] = tensor_from_mask_image(compute_weight_mask(mask)).float()
        return data