Exemplo n.º 1
0
    def __init__(self,
                 image_fname,
                 mask_fname,
                 tile_size,
                 tile_step=0,
                 image_margin=0,
                 transform=None,
                 target_shape=None,
                 keep_in_mem=False):
        """Build a tiled view over a single image/mask pair.

        :param image_fname: path to the RGB image on disk
        :param mask_fname: path to the corresponding mask on disk
        :param tile_size: side length of each square tile
        :param tile_step: stride between tiles; values <= 0 fall back to
            ``tile_size // 2`` (50% overlap)
        :param image_margin: margin forwarded to ``ImageSlicer``
        :param transform: optional callable applied to (image, mask) pairs
        :param target_shape: known image shape; when None the image is read
            from disk just to discover it
        :param keep_in_mem: cache the decoded image and mask on the instance
        :raises ValueError: if the image and mask spatial sizes differ
        """
        self.image_fname = image_fname
        self.mask_fname = mask_fname

        self.image = None
        self.mask = None

        # The files are only decoded here when we need the shape or a cache.
        if target_shape is None or keep_in_mem:
            image = read_rgb(image_fname)
            mask = read_mask(mask_fname)
            if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[1]:
                raise ValueError(
                    'Image size %s does not match mask size %s (%s / %s)' %
                    (image.shape[:2], mask.shape[:2], image_fname, mask_fname))

            target_shape = image.shape

            if keep_in_mem:
                self.image = image
                self.mask = mask

        if tile_step <= 0:
            tile_step = tile_size // 2

        self.slicer = ImageSlicer(target_shape, tile_size, tile_step,
                                  image_margin)
        self.transform = transform
Exemplo n.º 2
0
def main():
    """Cut every Inria train image/mask pair into overlapping 512px tiles on disk."""
    dataset_dir = 'e:\\datasets\\inria\\train\\'
    output_dir = 'e:\\datasets\\inria\\train_512'

    os.makedirs(os.path.join(output_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'gt'), exist_ok=True)

    # Sort both listings so image/mask files pair up deterministically,
    # matching what cut_dataset_in_patches does in this codebase.
    images = sorted(find_in_dir(os.path.join(dataset_dir, 'images')))
    targets = sorted(find_in_dir(os.path.join(dataset_dir, 'gt')))

    for x, y in tqdm(zip(images, targets), total=len(images)):
        image_name = splitext(os.path.basename(x))[0]
        mask_name = splitext(os.path.basename(y))[0]

        x = read_rgb(x)
        y = read_gray(y)

        # 512px tiles with a 256px step -> 50% overlap between neighbours
        slicer = ImageSlicer(x.shape, 512, 256)
        xs = slicer.split(x)
        ys = slicer.split(y)

        for i, patch in enumerate(xs):
            cv2.imwrite(
                os.path.join(output_dir, 'images',
                             '%s_%d.tif' % (image_name, i)), patch)

        for i, patch in enumerate(ys):
            cv2.imwrite(
                os.path.join(output_dir, 'gt', '%s_%d.tif' % (mask_name, i)),
                patch)
Exemplo n.º 3
0
    def predict(self, image):
        """Run sliding-window detection on a full-size image.

        The image is normalized, split into overlapping 512x512 tiles,
        each tile batch is passed through the network, and per-tile boxes
        are translated back into whole-image coordinates via the tile
        offsets before being decoded together.

        :param image: HWC RGB image (numpy array) — assumed; confirm with callers
        :return: (boxes, labels, scores) as numpy arrays
        """
        import albumentations as A
        self.eval()

        normalize = A.Normalize()
        image = normalize(image=image)['image']

        # 512px tiles with a 256px step (50% overlap); edges padded with a
        # constant border so every tile is full-size.
        slicer = ImageSlicer(image.shape, 512, 512 // 2)
        patches = [
            tensor_from_rgb_image(patch)
            for patch in slicer.split(image, borderType=cv2.BORDER_CONSTANT)
        ]
        # One (x, y, x, y) offset per tile, used by decode_multi to shift
        # tile-local box coordinates into whole-image coordinates.
        offsets = torch.tensor([[crop[0], crop[1], crop[0], crop[1]]
                                for crop in slicer.bbox_crops],
                               dtype=torch.float32)

        all_bboxes = []
        all_labels = []

        with torch.set_grad_enabled(False):
            # patch_loc is carried by the loader but unused here; decoding
            # uses the full `offsets` tensor instead.
            for patch, patch_loc in DataLoader(list(zip(patches, offsets)),
                                               batch_size=8,
                                               pin_memory=True):
                # Run on whatever device the backbone weights live on.
                patch = patch.to(self.fpn.conv1.weight.device)
                bboxes, labels = self(patch)

                all_bboxes.extend(bboxes.cpu())
                all_labels.extend(labels.cpu())

        boxes, labels, scores = self.box_coder.decode_multi(
            all_bboxes, all_labels, offsets)
        return to_numpy(boxes), to_numpy(labels), to_numpy(scores)
def get_dataset(dataset_name, dataset_dir, grayscale, patch_size):
    """Load a named dataset and return it as arrays of image and mask patches.

    Only 'dsb2018' is recognised; any other name raises ValueError.
    Masks are binarised to float32 {0, 1} with a trailing channel axis.
    """
    dataset_name = dataset_name.lower()

    if dataset_name != 'dsb2018':
        raise ValueError(dataset_name)

    read_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    image_files = find_in_dir(os.path.join(dataset_dir, dataset_name, 'images'))
    images = [normalize_image(cv2.imread(fname, read_flag))
              for fname in image_files]
    if grayscale:
        # Restore the channel axis dropped by grayscale decoding.
        images = [np.expand_dims(img, axis=-1) for img in images]

    mask_files = find_in_dir(os.path.join(dataset_dir, dataset_name, 'masks'))
    masks = [
        np.float32(np.expand_dims(cv2.imread(fname, cv2.IMREAD_GRAYSCALE),
                                  axis=-1) > 0)
        for fname in mask_files
    ]

    patch_images = []
    patch_masks = []
    for img, msk in zip(images, masks):
        # Overlapping patches: step is half the patch size.
        slicer = ImageSlicer(img.shape, patch_size, patch_size // 2)
        patch_images.extend(slicer.split(img))
        patch_masks.extend(slicer.split(msk))

    return np.array(patch_images), np.array(patch_masks)
Exemplo n.º 5
0
def predict_tiled(image, model, test_transform, patch_size, batch_size):
    """Predict a full-image mask via overlapping tiles and D4 test-time augmentation.

    The image is preprocessed, split into overlapping patches, each patch is
    expanded into its 8 dihedral (D4) variants, predictions are mapped back
    through the inverse transforms, and the patch predictions are blended
    into one mask using the slicer's pyramid weighting.

    :param image: input image; must be compatible with test_transform
    :param model: CUDA model producing per-pixel logits — assumes a GPU is
        available (x.cuda below); confirm for CPU-only deployments
    :param test_transform: callable returning (image, mask) — mask is discarded
    :param patch_size: side length of each square patch
    :param batch_size: patches per forward pass
    :return: float32 mask of sigmoid probabilities, merged to full image size
    """
    image, _ = test_transform(image)

    # 'pyramid' weighting down-weights patch borders when merging overlaps.
    slicer = ImageSlicer(image.shape,
                         patch_size,
                         patch_size // 2,
                         weight='pyramid')
    patches = slicer.split(image)

    # tta_d4_aug multiplies the patch list by the 8 D4 symmetries;
    # tta_d4_deaug below must see predictions in the same order to undo them.
    patches = aug.tta_d4_aug(patches)
    testset = InMemoryDataset(patches, None)
    trainloader = DataLoader(testset,
                             batch_size=batch_size,
                             shuffle=False,  # order must survive for deaug/merge
                             pin_memory=True,
                             drop_last=False)

    patches_pred = []
    for batch_index, x in enumerate(trainloader):
        x = x.cuda(non_blocking=True)
        y = model(x)
        y = torch.sigmoid(y).cpu().numpy()
        # NCHW -> NHWC for the slicer's merge step.
        y = np.moveaxis(y, 1, -1)
        patches_pred.extend(y)

    patches_pred = aug.tta_d4_deaug(patches_pred)
    mask = slicer.merge(patches_pred, dtype=np.float32)
    return mask
Exemplo n.º 6
0
def cut_dataset_in_patches(data_dir, output_dir, patch_size):
    """Slice every image/mask pair under data_dir into square patches on disk.

    Patches overlap by half the patch size and are written as TIFF files
    named '<source-stem>_<index>.tif' under output_dir/images and output_dir/gt.
    """
    image_files = sorted(find_in_dir(os.path.join(data_dir, 'images')))
    mask_files = sorted(find_in_dir(os.path.join(data_dir, 'gt')))

    images_out = os.path.join(output_dir, 'images')
    masks_out = os.path.join(output_dir, 'gt')
    os.makedirs(images_out, exist_ok=True)
    os.makedirs(masks_out, exist_ok=True)

    # NOTE(review): the slicer shape is fixed, so this assumes every source
    # image is 5000x5000 — confirm against the dataset.
    slicer = ImageSlicer((5000, 5000), patch_size, patch_size // 2)

    for image_fname, mask_fname in tqdm(zip(image_files, mask_files),
                                        total=len(image_files)):
        image = read_rgb(image_fname)
        mask = read_mask(mask_fname)

        stem = os.path.splitext(os.path.basename(image_fname))[0]

        for index, patch in enumerate(slicer.split(image)):
            cv2.imwrite(os.path.join(images_out, '%s_%d.tif' % (stem, index)),
                        patch)

        for index, patch in enumerate(slicer.split(mask)):
            cv2.imwrite(os.path.join(masks_out, '%s_%d.tif' % (stem, index)),
                        patch)
Exemplo n.º 7
0
class TiledImageDataset(Dataset):
    """Dataset serving square tiles cut from one image/mask pair.

    When ``keep_in_mem`` is False, the full image and mask are re-read from
    disk on every ``__getitem__`` call — this trades speed for memory on
    very large source images.
    """

    def __init__(self,
                 image_fname,
                 mask_fname,
                 tile_size,
                 tile_step=0,
                 image_margin=0,
                 transform=None,
                 target_shape=None,
                 keep_in_mem=False):
        """
        :param image_fname: path to the RGB image on disk
        :param mask_fname: path to the corresponding mask on disk
        :param tile_size: side length of each square tile
        :param tile_step: stride between tiles; values <= 0 fall back to
            ``tile_size // 2`` (50% overlap)
        :param image_margin: margin forwarded to ``ImageSlicer``
        :param transform: optional callable applied to (image, mask) pairs
        :param target_shape: known image shape; when None the image is read
            from disk just to discover it
        :param keep_in_mem: cache the decoded image and mask on the instance
        :raises ValueError: if the image and mask spatial sizes differ
        """
        self.image_fname = image_fname
        self.mask_fname = mask_fname

        self.image = None
        self.mask = None

        # The files are only decoded here when we need the shape or a cache.
        if target_shape is None or keep_in_mem:
            image = read_rgb(image_fname)
            mask = read_mask(mask_fname)
            if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[1]:
                raise ValueError(
                    'Image size %s does not match mask size %s (%s / %s)' %
                    (image.shape[:2], mask.shape[:2], image_fname, mask_fname))

            target_shape = image.shape

            if keep_in_mem:
                self.image = image
                self.mask = mask

        if tile_step <= 0:
            tile_step = tile_size // 2

        self.slicer = ImageSlicer(target_shape, tile_size, tile_step,
                                  image_margin)
        self.transform = transform

    def __len__(self):
        """Number of tiles the slicer produces for this image."""
        return len(self.slicer.crops)

    def __getitem__(self, index):
        """Return the (image, mask) tile at ``index`` as torch tensors.

        Image comes back as a float CHW tensor, mask as a long 1xHxW tensor.
        """
        # Fall back to disk when the pair was not cached by keep_in_mem.
        image = self.image if self.image is not None else read_rgb(
            self.image_fname)
        mask = self.mask if self.mask is not None else read_mask(
            self.mask_fname)

        image = self.slicer.cut_patch(image, index).copy()
        mask = self.slicer.cut_patch(mask, index).copy()

        if self.transform is not None:
            image, mask = self.transform(image, mask)

        # HWC -> CHW for the image; add a channel axis to the mask.
        image = torch.from_numpy(np.moveaxis(image, -1, 0).copy()).float()
        mask = torch.from_numpy(np.expand_dims(mask, 0)).long()
        return image, mask
Exemplo n.º 8
0
def DSB2018Sliced(dataset_dir, grayscale, patch_size):
    """Build sliced train and test datasets for DSB2018.

    Every image/mask pair is cut into overlapping patches, and the patches
    are split 90/10 into train/test, stratified by source-image id so each
    image contributes patches to both sides.

    :param dataset_dir: root directory holding 'images' and 'masks' subfolders
    :param grayscale: accepted but unused here
    :param patch_size: side length of each square patch
    :return: (train_dataset, test_dataset, num_classes)
    """

    images = [read_rgb(f) for f in find_in_dir(os.path.join(dataset_dir, 'images'))]
    masks = [read_mask(f) for f in find_in_dir(os.path.join(dataset_dir, 'masks'))]

    image_ids = []
    patch_images = []
    patch_masks = []

    for image_id, (image, mask) in enumerate(zip(images, masks)):
        # Overlapping patches: step is half the patch size.
        slicer = ImageSlicer(image.shape, patch_size, patch_size // 2)
        patch_images += slicer.split(image)
        patch_masks += slicer.split(mask)
        image_ids += [image_id] * len(slicer.crops)

    x_train, x_test, y_train, y_test = train_test_split(
        patch_images, patch_masks,
        random_state=1234, test_size=0.1, stratify=image_ids)

    train_transform = aug.Sequential([
        aug.ImageOnly(aug.NormalizeImage()),
        aug.RandomRotate90(),
        aug.VerticalFlip(),
        aug.HorizontalFlip(),
        aug.ShiftScaleRotate(rotate_limit=15),
        aug.MaskOnly(aug.MakeBinary())
    ])

    test_transform = aug.Sequential([
        aug.ImageOnly(aug.NormalizeImage()),
        aug.MaskOnly(aug.MakeBinary())
    ])

    num_classes = 1
    return (InMemoryDataset(x_train, y_train, transform=train_transform),
            InMemoryDataset(x_test, y_test, transform=test_transform),
            num_classes)