class TiledSingleImageDataset(Dataset):
    """Dataset of tiles cut from a single image and its matching mask.

    The full image is divided into crops by ``ImageSlicer``; item ``i`` is
    crop ``i`` of the image together with crop ``i`` of the mask.

    Args:
        image_fname: Path of the image, passed to ``image_loader``.
        mask_fname: Path of the mask, passed to ``target_loader``.
        image_loader: Callable mapping a filename to an image array (an
            object with ``.shape`` that ``ImageSlicer`` can cut).
        target_loader: Callable mapping a filename to a mask array.
        tile_size: Tile size, forwarded to ``ImageSlicer``.
        tile_step: Step between tiles, forwarded to ``ImageSlicer``.
        image_margin: Margin forwarded to ``ImageSlicer``.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that must return a mapping with
            ``'image'`` and ``'mask'`` keys. If None, tiles are used unchanged.
        target_shape: Shape of the full image. If None, the image is loaded
            once in the constructor to determine it.
        keep_in_mem: If True, the image and mask are split into tiles once in
            the constructor and those tiles are kept; otherwise both files are
            re-read through the loaders for every item.

    Raises:
        ValueError: If the loaded image and mask differ in their first two
            dimensions.
    """

    def __init__(self,
                 image_fname: str,
                 mask_fname: str,
                 image_loader: Callable,
                 target_loader: Callable,
                 tile_size,
                 tile_step,
                 image_margin=0,
                 transform=None,
                 target_shape=None,
                 keep_in_mem=False):
        self.image_fname = image_fname
        self.mask_fname = mask_fname
        self.image_loader = image_loader
        self.mask_loader = target_loader
        # Kept for backward compatibility; cached tiles live in
        # ``self.images`` / ``self.masks``.
        self.image = None
        self.mask = None

        if target_shape is None or keep_in_mem:
            image = image_loader(image_fname)
            mask = target_loader(mask_fname)
            # Only the first two dimensions are compared, so image and mask
            # may differ in any trailing dimension.
            if image.shape[:2] != mask.shape[:2]:
                raise ValueError(
                    f"Image size {image.shape} and mask shape {mask.shape} must have equal width and height"
                )

            target_shape = image.shape

        self.slicer = ImageSlicer(target_shape, tile_size, tile_step,
                                  image_margin)

        if keep_in_mem:
            # ``image``/``mask`` are guaranteed to be loaded above because
            # ``keep_in_mem`` is part of the loading condition.
            self.images = self.slicer.split(image)
            self.masks = self.slicer.split(mask)
        else:
            self.images = None
            self.masks = None

        self.transform = transform
        # One id per tile: the image id plus the tile's crop coordinates.
        self.image_ids = [
            id_from_fname(image_fname) +
            f' [{crop[0]};{crop[1]};{crop[2]};{crop[3]};]'
            for crop in self.slicer.crops
        ]

    def _get_image(self, index):
        """Return image tile ``index`` from the cache or by re-reading the file."""
        if self.images is None:
            image = self.image_loader(self.image_fname)
            image = self.slicer.cut_patch(image, index)
        else:
            image = self.images[index]
        return image

    def _get_mask(self, index):
        """Return mask tile ``index`` from the cache or by re-reading the file."""
        if self.masks is None:
            mask = self.mask_loader(self.mask_fname)
            mask = self.slicer.cut_patch(mask, index)
        else:
            mask = self.masks[index]
        return mask

    def __len__(self):
        return len(self.slicer.crops)

    def __getitem__(self, index):
        """Return a dict with ``'features'``, ``'targets'`` and ``'image_id'``."""
        image = self._get_image(index)
        mask = self._get_mask(index)
        # ``transform`` defaults to None; without this guard the default
        # configuration would raise TypeError here.
        if self.transform is not None:
            data = self.transform(image=image, mask=mask)
            image = data['image']
            mask = data['mask']

        return {
            'features': tensor_from_rgb_image(image),
            'targets': tensor_from_mask_image(mask).float(),
            'image_id': self.image_ids[index]
        }
# Example #2
class _InrialTiledImageMaskDataset(Dataset):
    """Dataset of tiles cut from a single image and its matching mask.

    The full image is divided into crops by ``ImageSlicer``; item ``i`` is
    crop ``i`` of the image and of the mask, returned as a dict keyed by the
    module's ``INPUT_*_KEY`` constants.

    Args:
        image_fname: Path of the image, passed to ``image_loader``.
        mask_fname: Path of the mask, passed to ``target_loader``.
        image_loader: Callable mapping a filename to an image array (an
            object with ``.shape`` that ``ImageSlicer`` can cut).
        target_loader: Callable mapping a filename to a mask array.
        tile_size: Tile size, forwarded to ``ImageSlicer``.
        tile_step: Step between tiles, forwarded to ``ImageSlicer``.
        image_margin: Margin forwarded to ``ImageSlicer``.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that must return a mapping with
            ``"image"`` and ``"mask"`` keys. If None, tiles are used unchanged.
        target_shape: Shape of the full image. If None, the image is loaded
            once in the constructor to determine it.
        need_weight_mask: If True, each item also carries a weight mask
            computed from the (transformed) mask tile.
        keep_in_mem: If True, all image and mask tiles are cut once in the
            constructor and kept; otherwise both files are re-read through
            the loaders for every item.
        make_mask_target_fn: Callable converting the (transformed) mask tile
            into the target stored under ``INPUT_MASK_KEY``.

    Raises:
        ValueError: If the loaded image and mask differ in their first two
            dimensions.
    """

    def __init__(
        self,
        image_fname: str,
        mask_fname: str,
        image_loader: Callable,
        target_loader: Callable,
        tile_size,
        tile_step,
        image_margin=0,
        transform=None,
        target_shape=None,
        need_weight_mask=False,
        keep_in_mem=False,
        make_mask_target_fn: Callable = mask_to_bce_target,
    ):
        self.image_fname = image_fname
        self.mask_fname = mask_fname
        self.image_loader = image_loader
        self.mask_loader = target_loader
        # Kept for backward compatibility; cached tiles live in
        # ``self.images`` / ``self.masks``.
        self.image = None
        self.mask = None
        self.need_weight_mask = need_weight_mask

        if target_shape is None or keep_in_mem:
            image = image_loader(image_fname)
            mask = target_loader(mask_fname)
            # Only the first two dimensions are compared, so image and mask
            # may differ in any trailing dimension.
            if image.shape[:2] != mask.shape[:2]:
                raise ValueError(
                    f"Image size {image.shape} and mask shape {mask.shape} must have equal width and height"
                )

            target_shape = image.shape

        self.slicer = ImageSlicer(target_shape, tile_size, tile_step,
                                  image_margin)

        if keep_in_mem:
            # Honour ``keep_in_mem`` (previously accepted but ignored). The
            # cache is built with the same ``cut_patch`` calls that the lazy
            # path uses, so items are identical either way.
            num_tiles = len(self.slicer.crops)
            self.images = [
                self.slicer.cut_patch(image, i) for i in range(num_tiles)
            ]
            self.masks = [
                self.slicer.cut_patch(mask, i) for i in range(num_tiles)
            ]
        else:
            self.images = None
            self.masks = None

        self.transform = transform
        # Every tile shares the image id; crop coordinates are kept separately.
        self.image_ids = [fs.id_from_fname(image_fname)] * len(
            self.slicer.crops)
        self.crop_coords_str = [
            f"[{crop[0]};{crop[1]};{crop[2]};{crop[3]};]"
            for crop in self.slicer.crops
        ]
        self.make_mask_target_fn = make_mask_target_fn

    def _get_image(self, index):
        """Return image tile ``index`` from the cache or by re-reading the file."""
        if self.images is not None:
            return self.images[index]
        image = self.image_loader(self.image_fname)
        return self.slicer.cut_patch(image, index)

    def _get_mask(self, index):
        """Return mask tile ``index`` from the cache or by re-reading the file."""
        if self.masks is not None:
            return self.masks[index]
        mask = self.mask_loader(self.mask_fname)
        return self.slicer.cut_patch(mask, index)

    def __len__(self):
        return len(self.slicer.crops)

    def __getitem__(self, index):
        """Return the tile's image tensor, mask target, id and crop coords.

        When ``need_weight_mask`` is set, the dict also holds a float weight
        mask under ``INPUT_MASK_WEIGHT_KEY``.
        """
        image = self._get_image(index)
        mask = self._get_mask(index)
        # ``transform`` defaults to None; without this guard the default
        # configuration would raise TypeError here.
        if self.transform is not None:
            data = self.transform(image=image, mask=mask)
            image = data["image"]
            mask = data["mask"]

        data = {
            INPUT_IMAGE_KEY: image_to_tensor(image),
            INPUT_MASK_KEY: self.make_mask_target_fn(mask),
            INPUT_IMAGE_ID_KEY: self.image_ids[index],
            "crop_coords": self.crop_coords_str[index],
        }

        if self.need_weight_mask:
            data[INPUT_MASK_WEIGHT_KEY] = tensor_from_mask_image(
                compute_weight_mask(mask)).float()

        return data