예제 #1
0
    def test_dataloader(self):
        """Build the data flow that serves the ``'test'`` split.

        Components 0 and 1 of every datapoint are resized to
        ``self.hparams.shape`` with nearest-neighbour interpolation and cast
        to float32.  Datapoints are then batched (``self.hparams.batch`` per
        batch, the trailing partial batch is kept), run through
        ``MultiProcessRunner`` with 4 processes and a prefetch of 16, printed
        via ``PrintData`` and finally turned into torch tensors.

        Returns:
            The composed data flow.  Each datapoint it yields is a list of
            two float tensors, built from batched components 0 and 1 with a
            new axis inserted at position 1.
        """
        flow = CustomDataSet(folder=self.hparams.data,
                             train_or_valid='test',
                             size=np.inf,
                             hparams=self.hparams)
        flow.reset_state()

        resize_and_cast = [
            imgaug.Resize(self.hparams.shape, interp=cv2.INTER_NEAREST),
            imgaug.ToFloat32(),
        ]

        def to_tensors(dp):
            # Insert a singleton axis at position 1 of each batched component
            # (presumably the channel axis -- confirm against the model).
            return [
                torch.tensor(dp[idx][:, np.newaxis, :, :]).float()
                for idx in (0, 1)
            ]

        flow = AugmentImageComponents(flow, resize_and_cast, [0, 1])
        flow = BatchData(flow, self.hparams.batch, remainder=True)
        flow = MultiProcessRunner(flow, num_proc=4, num_prefetch=16)
        flow = PrintData(flow)
        return MapData(flow, to_tensors)
예제 #2
0
def read_and_augment_images(ds):
    """Load each datapoint's image from disk and augment it with its boxes.

    Args:
        ds: data flow whose datapoints are mutable lists with an image file
            name at index 0 and a ``np.float32`` array of boxes at index 1.

    Returns:
        A data flow whose datapoints carry the augmented float32 image at
        index 0, the boxes (augmented as coordinates alongside the image) at
        index 1, and the original file name appended as the last component.

    Raises:
        AssertionError: if an image cannot be read or the boxes are not
            ``np.float32``.
    """
    def mapf(dp):
        fname = dp[0]
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        # cv2.imread signals an unreadable file by returning None rather than
        # raising, so the check has to happen BEFORE any method is called on
        # the result; otherwise the failure is an opaque AttributeError.
        assert im is not None, fname
        dp[0] = im.astype('float32')

        # assume floatbox as input
        assert dp[1].dtype == np.float32
        # Boxes are converted to a point representation so they can be
        # augmented as coordinates (see coords_index below).
        dp[1] = box_to_point8(dp[1])

        dp.append(fname)
        return dp

    ds = MapData(ds, mapf)

    augs = [
        CustomResize(config.SHORT_EDGE_SIZE, config.MAX_SIZE),
        imgaug.Flip(horiz=True),
    ]
    ds = AugmentImageComponents(ds, augs, index=(0, ), coords_index=(1, ))

    # Convert the augmented points at component 1 back into boxes.
    ds = MapDataComponent(ds, point8_to_box, 1)
    return ds
예제 #3
0
    def val_dataloader(self):
        """Build the data flow that serves the ``'valid'`` split.

        Component 0 always goes through CLAHE (``p=1``).  Depending on
        ``self.hparams.types`` a set of components is then resized to
        ``self.hparams.shape`` (nearest-neighbour) and cast to float32:
        components 0..6 when ``types == 6``, components 0..1 when
        ``types == 1``.  For any other value neither the resize/cast step nor
        the final tensor conversion is applied.

        Datapoints are batched with ``self.hparams.batch`` (partial last
        batch kept), passed through ``MultiProcessRunner`` (4 processes,
        prefetch 16) and ``PrintData``.

        Returns:
            The composed data flow.  For ``types`` 6 or 1 each datapoint is a
            list of float tensors, one per selected component, each with a
            new axis inserted at position 1.
        """
        flow = CustomDataSet(folder=self.hparams.data,
                             train_or_valid='valid',
                             size=np.inf,
                             hparams=self.hparams)
        flow.reset_state()

        resize_and_cast = [
            imgaug.Resize(self.hparams.shape, interp=cv2.INTER_NEAREST),
            imgaug.ToFloat32(),
        ]
        clahe_only = [imgaug.Albumentations(AB.CLAHE(p=1))]
        flow = AugmentImageComponent(flow, clahe_only, 0)

        # Which datapoint components are processed for each supported mode.
        if self.hparams.types == 6:
            selected = (0, 1, 2, 3, 4, 5, 6)
        elif self.hparams.types == 1:
            selected = (0, 1)
        else:
            selected = None

        if selected is not None:
            flow = AugmentImageComponents(flow, resize_and_cast,
                                          list(selected))

        flow = BatchData(flow, self.hparams.batch, remainder=True)
        flow = MultiProcessRunner(flow, num_proc=4, num_prefetch=16)
        flow = PrintData(flow)

        if selected is not None:
            def to_tensors(dp):
                # One float tensor per selected component, with a singleton
                # axis inserted at position 1.
                return [
                    torch.tensor(dp[idx][:, np.newaxis, :, :]).float()
                    for idx in selected
                ]

            flow = MapData(flow, to_tensors)
        return flow
예제 #4
0
파일: loader.py 프로젝트: haelkemary/xy_net
def train_generator(ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=16, nr_procs=8):
    """Wrap ``ds`` with optional augmentation, shape-based batching and prefetch.

    Args:
        ds: source data flow; index 0 of each datapoint is the input and
            index 1 the label.
        shape_aug: augmentors applied to both input and label (components
            0 and 1), or ``None`` to skip.
        input_aug: augmentors applied to the input only (component 0), or
            ``None`` to skip.
        label_aug: augmentors applied to the label only (component 1), or
            ``None`` to skip.
        batch_size: batch size handed to ``BatchDataByShape``.
        nr_procs: process count handed to ``PrefetchDataZMQ``.

    Returns:
        The data flow after batching by the shape of component 0 and
        wrapping in ``PrefetchDataZMQ``.
    """
    if shape_aug is not None:
        # Input and label go through the same augmentors.
        ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=True)
    if input_aug is not None:
        # Input only.
        ds = AugmentImageComponent(ds, input_aug, index=0, copy=False)
    if label_aug is not None:
        # Label only.
        ds = AugmentImageComponent(ds, label_aug, index=1, copy=True)

    batched = BatchDataByShape(ds, batch_size, idx=0)
    return PrefetchDataZMQ(batched, nr_procs)
예제 #5
0
파일: loader.py 프로젝트: haelkemary/xy_net
def valid_generator(ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=16, nr_procs=1):
    """Wrap ``ds`` with optional augmentation, batching and caching.

    Args:
        ds: source data flow; index 0 of each datapoint is the input and
            index 1 the label.
        shape_aug: augmentors applied to both input and label (components
            0 and 1), or ``None`` to skip.
        input_aug: augmentors applied to the input only (component 0), or
            ``None`` to skip.
        label_aug: augmentors applied to the label only (component 1), or
            ``None`` to skip.
        batch_size: batch size handed to ``BatchData``; the trailing partial
            batch is kept.
        nr_procs: accepted for signature parity with ``train_generator``;
            not used by this function.

    Returns:
        The batched data flow wrapped in ``CacheData`` so the inference
        batches are cached.
    """
    if shape_aug is not None:
        # Input and label go through the same augmentors.
        ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=True)
    if input_aug is not None:
        # Input only.
        ds = AugmentImageComponent(ds, input_aug, index=0, copy=False)
    if label_aug is not None:
        # Label only.
        ds = AugmentImageComponent(ds, label_aug, index=1, copy=True)

    batched = BatchData(ds, batch_size, remainder=True)
    # Cache all inference images.
    return CacheData(batched)
예제 #6
0
    def train_dataloader(self):
        """Build the augmented data flow that serves the ``'train'`` split.

        Two augmentation stages are applied:

        1. Component 0 only: a random choice among three blur variants, a
           random brightness/contrast change, a cast to uint8 and CLAHE with
           a 32x32 tile grid at ``p=0.5``.
        2. Components 0 and 1 together: shear, translation, scaling,
           rotation, random crop-and-resize, resize to
           ``self.hparams.shape``, flips, transpose, 90-degree rotation and a
           final cast to float32.  All interpolating steps use
           nearest-neighbour interpolation.

        Datapoints are then batched with ``self.hparams.batch`` (partial last
        batch kept); when ``self.hparams.debug`` is truthy the flow is
        limited to 2 batches via ``FixedSizeData``.  The result goes through
        ``MultiProcessRunner`` (4 processes, prefetch 16) and ``PrintData``.

        Returns:
            The composed data flow.  Each datapoint is a list of two float
            tensors built from batched components 0 and 1 with a new axis
            inserted at position 1.
        """
        flow = CustomDataSet(folder=self.hparams.data,
                             train_or_valid='train',
                             size=np.inf,
                             hparams=self.hparams)
        flow.reset_state()

        nearest = cv2.INTER_NEAREST
        target_shape = self.hparams.shape

        # Stage 2 augmentors, shared by components 0 and 1.
        paired_augs = [
            imgaug.Affine(shear=10, interp=nearest),
            imgaug.Affine(translate_frac=(0.01, 0.02), interp=nearest),
            imgaug.Affine(scale=(0.25, 1.0), interp=nearest),
            imgaug.RotationAndCropValid(max_deg=10, interp=nearest),
            imgaug.GoogleNetRandomCropAndResize(
                crop_area_fraction=(0.8, 1.0),
                aspect_ratio_range=(0.8, 1.2),
                interp=nearest,
                target_shape=target_shape),
            imgaug.Resize(target_shape, interp=nearest),
            imgaug.Flip(horiz=True, vert=False, prob=0.5),
            imgaug.Flip(horiz=False, vert=True, prob=0.5),
            imgaug.Transpose(prob=0.5),
            imgaug.Albumentations(AB.RandomRotate90(p=1)),
            imgaug.ToFloat32(),
        ]

        # Stage 1 augmentors, applied to component 0 only.
        random_blur = imgaug.RandomChooseAug([
            imgaug.Albumentations(AB.Blur(blur_limit=4, p=0.25)),
            imgaug.Albumentations(AB.MotionBlur(blur_limit=4, p=0.25)),
            imgaug.Albumentations(AB.MedianBlur(blur_limit=4, p=0.25)),
        ])
        random_contrast = imgaug.RandomChooseAug([
            imgaug.Albumentations(AB.RandomBrightnessContrast(p=0.5)),
        ])
        first_component_augs = [
            random_blur,
            random_contrast,
            imgaug.ToUint8(),
            imgaug.Albumentations(AB.CLAHE(tile_grid_size=(32, 32), p=0.5)),
        ]

        flow = AugmentImageComponent(flow, first_component_augs, 0)
        flow = AugmentImageComponents(flow, paired_augs, [0, 1])

        flow = BatchData(flow, self.hparams.batch, remainder=True)
        if self.hparams.debug:
            # Keep debug runs short: only two batches per pass.
            flow = FixedSizeData(flow, 2)
        flow = MultiProcessRunner(flow, num_proc=4, num_prefetch=16)
        flow = PrintData(flow)

        def to_tensors(dp):
            # Insert a singleton axis at position 1 of each batched component.
            return [
                torch.tensor(dp[idx][:, np.newaxis, :, :]).float()
                for idx in (0, 1)
            ]

        return MapData(flow, to_tensors)