コード例 #1
0
ファイル: loader.py プロジェクト: chicm/salt
    def aug_image(self, img, mask=None):
        """Augment an image (and optionally its mask), then apply transforms.

        The image (and mask) are converted with ``from_pil`` for augmentation
        and back with ``to_pil`` before ``mask_transform`` / ``image_transform``
        are applied. ``image_augment``, ``mask_transform`` and
        ``image_transform`` are optional (skipped when None).

        Args:
            img: input image.
            mask: optional target mask, augmented together with ``img``.

        Returns:
            ``(Xi, Mi)`` when ``mask`` is given, otherwise ``Xi`` alone.
        """
        if mask is not None:
            Xi, Mi = from_pil(img, mask)
            # Image and mask go through augment_with_target together so that
            # any spatial change is applied identically to both.
            Xi, Mi = self.augment_with_target(Xi, Mi)
            if self.image_augment is not None:
                Xi = self.image_augment(Xi)
            Xi, Mi = to_pil(Xi, Mi)

            if self.mask_transform is not None:
                Mi = self.mask_transform(Mi)

            if self.image_transform is not None:
                Xi = self.image_transform(Xi)

            return Xi, Mi

        Xi = from_pil(img)
        # image_augment is optional here just as in the mask branch above.
        if self.image_augment is not None:
            Xi = self.image_augment(Xi)
        Xi = to_pil(Xi)

        if self.image_transform is not None:
            Xi = self.image_transform(Xi)
        return Xi
コード例 #2
0
ファイル: augmentation.py プロジェクト: chicm/salt
def test_augment():
    """Visual smoke test for the augmentation pipeline on one training sample.

    Loads ``003c477d7c.png`` from ``settings.TRAIN_IMG_DIR`` and
    ``settings.TRAIN_MASK_DIR``, thresholds the mask at 128, splits it into
    one image per class label (0 and 1), applies ``crop_seq`` jointly to the
    image and masks and ``brightness_seq`` to the image only, and finally
    opens the three results with ``Image.show``.
    """
    img_path = os.path.join(settings.TRAIN_IMG_DIR, '003c477d7c.png')
    mask_path = os.path.join(settings.TRAIN_MASK_DIR, '003c477d7c.png')

    # convert() returns a new in-memory image, so the source files can be
    # closed immediately instead of leaving their handles open.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    with Image.open(mask_path) as raw_mask:
        mask = raw_mask.convert('L').point(
            lambda x: 0 if x < 128 else 255, '1')
    print(type(mask))

    mask = from_pil(mask)
    # One binary image per class label.
    Mi = [to_pil(mask == class_nr) for class_nr in [0, 1]]
    img, *Mi = from_pil(img, *Mi)

    aug = ImgAug(
        crop_seq(crop_size=(settings.H, settings.W),
                 pad_size=(32, 32),
                 pad_method='reflect'))
    aug2 = ImgAug(brightness_seq)
    img, *Mi = aug(img, *Mi)
    img = aug2(img)

    # Scale the class masks up for display (presumably 0/1-valued here).
    img, *Mi = to_pil(img, Mi[0] * 255, Mi[1] * 255)
    img.show()
    Mi[0].show()
    Mi[1].show()
コード例 #3
0
    def __getitem__(self, index):
        """Load sample ``index`` and return it ready for the model.

        The image is read with ``load_from_memory`` or ``load_from_disk``
        depending on ``self.image_source``. When targets (``self.y``) are
        present, the per-class masks from ``load_target`` are augmented
        together with the image. Then the optional ``mask_transform`` and
        ``image_transform`` are applied. The transformed masks are
        concatenated along dim 0 with ``torch.cat``.

        Returns:
            ``(Xi, Mi)`` when ``self.y`` is set, otherwise ``Xi``.

        Raises:
            NotImplementedError: if ``self.image_source`` is neither
                ``'memory'`` nor ``'disk'``.
        """
        if self.image_source == 'memory':
            load_func = self.load_from_memory
        elif self.image_source == 'disk':
            load_func = self.load_from_disk
        else:
            raise NotImplementedError(
                "Possible loading options: 'memory' and 'disk'!")

        Xi = load_func(self.X, index, filetype='png', grayscale=False)

        if self.y is None:
            Xi = from_pil(Xi)
            if self.image_augment is not None:
                Xi = self.image_augment(Xi)
            Xi = to_pil(Xi)

            if self.image_transform is not None:
                Xi = self.image_transform(Xi)
            return Xi

        Mi = self.load_target(self.y, index, load_func)
        Xi, *Mi = from_pil(Xi, *Mi)
        if self.image_augment_with_target is not None:
            Xi, *Mi = self.image_augment_with_target(Xi, *Mi)
        if self.image_augment is not None:
            Xi = self.image_augment(Xi)
        Xi, *Mi = to_pil(Xi, *Mi)

        if self.mask_transform is not None:
            Mi = [self.mask_transform(m) for m in Mi]
            # torch.cat needs tensors; assumes mask_transform produces them
            # (TODO confirm). Without a mask_transform the raw to_pil outputs
            # are returned as a list instead of being passed to torch.cat.
            Mi = torch.cat(Mi, dim=0)

        if self.image_transform is not None:
            Xi = self.image_transform(Xi)

        return Xi, Mi
コード例 #4
0
    def __getitem__(self, index):
        """Return in-memory sample ``index`` with its three targets.

        Images are read from ``self.X[0]`` and the targets ``Mi``, ``CTi`` and
        ``CRi`` from ``self.y[0]``, ``self.y[1]`` and ``self.y[2]``.
        Augmentation happens only in training mode with
        ``image_augment_with_target`` configured. In that case the image and
        all three targets are augmented together, and the image additionally
        gets ``image_augment`` when that is set.

        Returns:
            ``(Xi, Mi, CTi, CRi)`` when ``self.y`` is set, otherwise ``Xi``.
        """
        Xi = self.X[0][index]

        if self.y is None:
            if self.image_transform is not None:
                Xi = self.image_transform(Xi)
            return Xi

        Mi = self.y[0][index]
        CTi = self.y[1][index]
        CRi = self.y[2][index]

        if self.train_mode and self.image_augment_with_target is not None:
            Xi, Mi, CTi, CRi = from_pil(Xi, Mi, CTi, CRi)
            Xi, Mi, CTi, CRi = self.image_augment_with_target(Xi, Mi, CTi, CRi)
            # image_augment is optional; calling None would raise TypeError.
            if self.image_augment is not None:
                Xi = self.image_augment(Xi)
            Xi, Mi, CTi, CRi = to_pil(Xi, Mi, CTi, CRi)

        if self.mask_transform is not None:
            Mi = self.mask_transform(Mi)
            CTi = self.mask_transform(CTi)
            CRi = self.mask_transform(CRi)

        if self.image_transform is not None:
            Xi = self.image_transform(Xi)

        return Xi, Mi, CTi, CRi
コード例 #5
0
    def __getitem__(self, index):
        """Load image ``index`` and its mask, contour and center targets.

        ``self.X[index]`` is the image path. ``self.y[index, 0]``,
        ``self.y[index, 1]`` and ``self.y[index, 2]`` are the mask, contour and
        center paths. All are read with ``self.load_image``. Joint augmentation
        runs only in training mode with ``image_augment_with_target`` set. The
        image additionally gets ``image_augment`` when that is set.

        Returns:
            ``(Xi, Mi, CTi, CRi)`` when ``self.y`` is set, otherwise ``Xi``.
        """
        img_filepath = self.X[index]
        Xi = self.load_image(img_filepath)

        if self.y is None:
            if self.image_transform is not None:
                Xi = self.image_transform(Xi)
            return Xi

        mask_filepath = self.y[index, 0]
        contour_filepath = self.y[index, 1]
        center_filepath = self.y[index, 2]

        Mi = self.load_image(mask_filepath)
        CTi = self.load_image(contour_filepath)
        CRi = self.load_image(center_filepath)

        if self.train_mode and self.image_augment_with_target is not None:
            Xi, Mi, CTi, CRi = from_pil(Xi, Mi, CTi, CRi)
            Xi, Mi, CTi, CRi = self.image_augment_with_target(Xi, Mi, CTi, CRi)
            # image_augment is optional; calling None would raise TypeError.
            if self.image_augment is not None:
                Xi = self.image_augment(Xi)
            Xi, Mi, CTi, CRi = to_pil(Xi, Mi, CTi, CRi)

        if self.image_transform is not None:
            Xi = self.image_transform(Xi)

        if self.mask_transform is not None:
            Mi = self.mask_transform(Mi)
            CTi = self.mask_transform(CTi)
            CRi = self.mask_transform(CRi)
        return Xi, Mi, CTi, CRi
コード例 #6
0
    def __getitem__(self, index):
        """Return sample ``index`` as ``[image, mask_0, mask_1, ...]``.

        ``self.X[index]`` is the image path. ``self.y`` is indexed as
        ``self.y[index, i]``, giving one target-mask path per column ``i``.
        Every file is read with ``self.load_image``. Joint augmentation runs
        only in training mode with ``image_augment_with_target`` set. The image
        additionally gets ``image_augment`` when that is set.

        Returns:
            ``data`` — the image followed by its target masks — when
            ``self.y`` is set, otherwise the (optionally transformed) image.
        """
        img_filepath = self.X[index]
        Xi = self.load_image(img_filepath)

        if self.y is None:
            if self.image_transform is not None:
                Xi = self.image_transform(Xi)
            return Xi

        # One mask file per column of self.y for this sample.
        target_masks = [self.load_image(self.y[index, i])
                        for i in range(self.y.shape[1])]
        data = [Xi] + target_masks

        if self.train_mode and self.image_augment_with_target is not None:
            data = from_pil(*data)
            data = self.image_augment_with_target(*data)
            # image_augment is optional; calling None would raise TypeError.
            if self.image_augment is not None:
                data[0] = self.image_augment(data[0])
            data = to_pil(*data)

        if self.mask_transform is not None:
            data[1:] = [self.mask_transform(mask) for mask in data[1:]]

        if self.image_transform is not None:
            data[0] = self.image_transform(data[0])

        return data
コード例 #7
0
ファイル: loaders.py プロジェクト: vivian-wong/CS230UNet
    def __getitem__(self, index):
        """Return sample ``index`` as ``(image, mask)`` or just ``image``.

        The image path comes from ``self.X[index]`` and the mask path from
        ``self.y[index]``; both are read with ``self.load_image``. Joint
        augmentation runs only in training mode when
        ``image_augment_with_target`` is configured, followed by the optional
        image-only ``image_augment``.
        """
        image = self.load_image(self.X[index])

        if self.y is None:
            if self.image_transform is not None:
                image = self.image_transform(image)
            return image

        mask = self.load_image(self.y[index])

        augment_jointly = (self.train_mode
                           and self.image_augment_with_target is not None)
        if augment_jointly:
            image, mask = from_pil(image, mask)
            image, mask = self.image_augment_with_target(image, mask)
            if self.image_augment is not None:
                image = self.image_augment(image)
            image, mask = to_pil(image, mask)

        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
        if self.image_transform is not None:
            image = self.image_transform(image)
        return image, mask
コード例 #8
0
ファイル: loaders.py プロジェクト: vivian-wong/CS230UNet
    def get_patches(self, X):
        """Cut every image in ``X[0]`` into patches and their TTA variants.

        Patches are produced by ``generate_patches`` (called with
        ``dataset_params.h`` and ``dataset_params.patching_stride``). Each
        patch is expanded by ``test_time_augmentation`` into
        ``(rotation_angle, patch)`` pairs.

        Returns:
            ``([patches], meta)``. ``patches`` is a flat list of ``to_pil``
            outputs. ``meta`` is a DataFrame with one row per patch: source
            image index, TTA angle, patch y/x coordinates, and source image
            height/width.
        """
        patches = []
        meta = {
            'patch_ids': [],
            'tta_angles': [],
            'y_coordinates': [],
            'x_coordinates': [],
            'image_h': [],
            'image_w': [],
        }
        for image_idx, pil_image in enumerate(X[0]):
            array = from_pil(pil_image)
            height, width = array.shape[:2]
            patch_iter = generate_patches(array, self.dataset_params.h,
                                          self.dataset_params.patching_stride)
            for y_coord, x_coord, patch in patch_iter:
                for angle, rotated in test_time_augmentation(patch):
                    patches.append(to_pil(rotated))
                    meta['patch_ids'].append(image_idx)
                    meta['tta_angles'].append(angle)
                    meta['y_coordinates'].append(y_coord)
                    meta['x_coordinates'].append(x_coord)
                    meta['image_h'].append(height)
                    meta['image_w'].append(width)

        return [patches], pd.DataFrame(meta)
コード例 #9
0
ファイル: loaders.py プロジェクト: vivian-wong/CS230UNet
    def __getitem__(self, index):
        """Return in-memory sample ``index`` as ``[image, *targets]``.

        The image is ``self.X[0][index]``. ``self.y`` is iterated and each
        target collection in it is indexed with ``index``. Joint augmentation
        runs only in training mode with ``image_augment_with_target`` set,
        followed by the optional image-only ``image_augment``.

        Returns:
            The image followed by its targets when ``self.y`` is set,
            otherwise the (optionally transformed) image.
        """
        image = self.X[0][index]

        if self.y is None:
            if self.image_transform is not None:
                image = self.image_transform(image)
            return image

        sample = [image]
        sample.extend(target[index] for target in self.y)

        if self.train_mode and self.image_augment_with_target is not None:
            sample = self.image_augment_with_target(*from_pil(*sample))
            if self.image_augment is not None:
                sample[0] = self.image_augment(sample[0])
            sample = to_pil(*sample)

        if self.mask_transform is not None:
            sample[1:] = [self.mask_transform(mask) for mask in sample[1:]]
        if self.image_transform is not None:
            sample[0] = self.image_transform(sample[0])
        return sample
コード例 #10
0
    def __getitem__(self, index):
        """Return ``(image, targets)`` for sample ``index``, or the image alone.

        Besides the mask at ``self.y[index]``, a distance map and a size map
        are loaded with ``load_joblib``. Their paths are derived from the mask
        path: "/masks/" becomes "/distances/", the extension is stripped, and
        "/distances/" then becomes "/sizes/". The size map is square-rooted.
        When ``mask_transform`` is set, the three transformed targets are
        concatenated along dim 0 with ``torch.cat``. Otherwise only the mask
        is returned as the target.

        NOTE(review): unlike the loaders in this project that check
        ``self.train_mode``, augmentation here is applied unconditionally --
        confirm this is intended.
        """
        image = self.load_image(self.X[index])

        if self.y is None:
            if self.image_transform is not None:
                image = self.image_transform(image)
            return image

        mask_path = self.y[index]
        mask = self.load_image(mask_path)
        distance_path = os.path.splitext(
            mask_path.replace("/masks/", "/distances/"))[0]
        size_path = distance_path.replace("/distances/", "/sizes/")
        distances = self.load_joblib(distance_path).astype(np.uint16)
        raw_sizes = self.load_joblib(size_path).astype(np.uint16)
        sizes = np.sqrt(raw_sizes).astype(np.uint16)

        image, mask = from_pil(image, mask)
        if self.image_augment_with_target is not None:
            image, mask, distances, sizes = self.image_augment_with_target(
                image, mask, distances, sizes)
        if self.image_augment is not None:
            image = self.image_augment(image)
        image, mask, distances, sizes = to_pil(image, mask, distances, sizes)

        if self.mask_transform is not None:
            mask = torch.cat(
                [self.mask_transform(target)
                 for target in (mask, distances, sizes)],
                dim=0)

        if self.image_transform is not None:
            image = self.image_transform(image)
        return image, mask
コード例 #11
0
    def visualise_dataset(self):
        """Sanity check: dump the first sample of each active split.

        In ``'train'`` mode the train and valid loaders are inspected;
        otherwise only the test loader is. For each split, the item at
        ``loader.sampler.indices[0]`` is fetched. Its ``y`` value and
        attention target are printed, and the image is saved as
        ``<split>_0.png``.
        """
        if self.config.mode == 'train':
            splits = [("train", self.train_loader),
                      ("valid", self.valid_loader)]
        else:
            splits = [("test", self.test_loader)]

        for split_name, split_loader in splits:
            print(split_name)
            first_index = split_loader.sampler.indices[0]
            image, label, has_target, attention = \
                split_loader.dataset[first_index]
            print("y:", label.item())
            attention_repr = attention.numpy() if has_target else "none"
            print("attention target:", attention_repr)
            to_pil(image).save(f"{split_name}_0.png")
コード例 #12
0
ファイル: loaders.py プロジェクト: vivian-wong/CS230UNet
    def __getitem__(self, index):
        """Return ``(image, mask+distance)`` for sample ``index``, or the image.

        The distance target is loaded with ``load_distances``. Its path is the
        mask path with "/masks/" replaced by "/distances/" and the last four
        characters dropped (presumably a 3-letter extension plus dot). The
        result is summed over axis 2 and cast to uint8. When
        ``mask_transform`` is set, mask and distance are concatenated along
        dim 0 with ``torch.cat``.

        Returns:
            ``(Xi, Mi)`` when ``self.y`` is set, otherwise ``Xi``.
        """
        img_filepath = self.X[index]
        Xi = self.load_image(img_filepath)

        if self.y is None:
            if self.image_transform is not None:
                Xi = self.image_transform(Xi)
            return Xi

        mask_filepath = self.y[index]
        Mi = self.load_image(mask_filepath)
        distance_filepath = mask_filepath.replace("/masks/",
                                                  "/distances/")[:-4]
        Di = self.load_distances(distance_filepath)
        # TODO: remove it when Di will be sum of distances to 2 closest objects
        Di = np.sum(Di, axis=2).astype(np.uint8)

        if self.train_mode and self.image_augment_with_target is not None:
            Xi, Mi = from_pil(Xi, Mi)
            Xi, Mi, Di = self.image_augment_with_target(Xi, Mi, Di)
            if self.image_augment is not None:
                Xi = self.image_augment(Xi)
            Xi, Mi, Di = to_pil(Xi, Mi, Di)
        else:
            # Di is still the raw summed array here. Convert it with to_pil,
            # as the augmentation branch does, so mask_transform always gets a
            # to_pil result. This covers both non-training mode and training
            # mode without an augmenter.
            Di = to_pil(Di)

        if self.mask_transform is not None:
            Mi = self.mask_transform(Mi)
            Di = self.mask_transform(Di)
            Mi = torch.cat((Mi, Di), dim=0)

        if self.image_transform is not None:
            Xi = self.image_transform(Xi)
        return Xi, Mi
コード例 #13
0
    def __getitem__(self, index):
        """Load image ``index`` and apply TTA, augmentation and transform.

        Order of operations:
        1. ``test_time_augmentation_transform`` with ``self.tta_params[index]``,
           when ``tta_params`` is set.
        2. The optional ``image_augment``.
        3. ``to_pil``.
        4. The optional ``image_transform``.
        """
        array = from_pil(self.load_image(self.X[index]))

        if self.tta_params is not None:
            array = test_time_augmentation_transform(
                array, self.tta_params[index])
        if self.image_augment is not None:
            array = self.image_augment(array)

        image = to_pil(array)
        if self.image_transform is None:
            return image
        return self.image_transform(image)
コード例 #14
0
    def __getitem__(self, index):
        """Load image ``index``, augment it, then apply TTA and the transform.

        Note the order: ``image_augment`` runs before ``self.tta_transform``,
        which is applied with ``self.tta_params[index]`` when ``tta_params``
        is set.

        Raises:
            NotImplementedError: if ``self.image_source`` is neither
                ``'memory'`` nor ``'disk'``.
        """
        source = self.image_source
        if source not in ('memory', 'disk'):
            raise NotImplementedError(
                "Possible loading options: 'memory' and 'disk'!")
        load_func = (self.load_from_memory if source == 'memory'
                     else self.load_from_disk)

        array = from_pil(
            load_func(self.X, index, filetype='png', grayscale=False))

        if self.image_augment is not None:
            array = self.image_augment(array)
        if self.tta_params is not None:
            array = self.tta_transform(array, self.tta_params[index])

        image = to_pil(array)
        if self.image_transform is None:
            return image
        return self.image_transform(image)
コード例 #15
0
 def read_json(self, path):
     """Read a JSON list of RLE-encoded masks from ``path``.

     Each entry is decoded with ``binary_from_rle`` and converted with
     ``to_pil``.

     Returns:
         A list with one converted mask per JSON entry, in file order.
     """
     # JSON text is UTF-8 (RFC 8259). Decode it explicitly instead of relying
     # on the platform's locale-dependent default encoding.
     with open(path, 'r', encoding='utf-8') as file:
         data = json.load(file)
     return [to_pil(binary_from_rle(rle)) for rle in data]
コード例 #16
0
 def load_target(self, data_source, index, load_func):
     """Load mask ``index`` and split it into one binary image per class.

     The mask is fetched with ``load_func`` (``filetype='png'``,
     ``grayscale=True``), converted with ``from_pil``, and compared against
     class labels 0 and 1.

     Returns:
         ``[to_pil(mask == 0), to_pil(mask == 1)]``.
     """
     raw = from_pil(
         load_func(data_source, index, filetype='png', grayscale=True))
     return [to_pil(raw == label) for label in (0, 1)]