Esempio n. 1
0
    def train_dataloader(self):
        """Build the training dataflow for the multi-label classifier.

        Pipeline: ``MultiLabelDataset`` (rows of ``covid_train_v5.csv``) ->
        per-image augmentation (component 0) -> batching -> optional debug
        truncation -> 4-process prefetching -> conversion to torch tensors.

        Returns:
            A tensorpack ``MapData`` dataflow yielding ``[images, labels]``:
            ``images`` has its axes reordered with ``(0, 3, 1, 2)`` and
            ``labels`` is converted to a float tensor.
        """
        # Cast once so the dataset resize and the final crop-and-resize use
        # the same integer edge length.
        shape = int(self.hparams.shape)

        ds_train = MultiLabelDataset(folder=self.hparams.data,
                                     is_train='train',
                                     fname='covid_train_v5.csv',
                                     types=self.hparams.types,
                                     pathology=self.hparams.pathology,
                                     resize=shape,
                                     balancing=None)
        ds_train.reset_state()

        # NOTE(review): tensorpack's imgaug.Albumentations wrapper appears to
        # call the transform's `apply` with `get_params()` directly, which
        # would bypass albumentations' `p`; if so, the `p=` values below have
        # no effect — confirm against the installed tensorpack version.
        ag_train = [
            # Augment in 3-channel space; converted back to one channel below.
            # NOTE(review): assumes the dataset yields single-channel images
            # — TODO confirm against MultiLabelDataset.
            imgaug.ColorSpace(mode=cv2.COLOR_GRAY2RGB),
            # Pick one of the blur variants per image.
            imgaug.RandomChooseAug([
                imgaug.Albumentations(AB.Blur(blur_limit=4, p=0.25)),
                imgaug.Albumentations(AB.MotionBlur(blur_limit=4, p=0.25)),
                imgaug.Albumentations(AB.MedianBlur(blur_limit=4, p=0.25)),
            ]),
            imgaug.Albumentations(AB.CLAHE(tile_grid_size=(32, 32), p=0.5)),
            # Shear, translation and scaling, applied in a random order.
            imgaug.RandomOrderAug([
                imgaug.Affine(shear=10,
                              border=cv2.BORDER_CONSTANT,
                              interp=cv2.INTER_AREA),
                imgaug.Affine(translate_frac=(0.01, 0.02),
                              border=cv2.BORDER_CONSTANT,
                              interp=cv2.INTER_AREA),
                imgaug.Affine(scale=(0.5, 1.0),
                              border=cv2.BORDER_CONSTANT,
                              interp=cv2.INTER_AREA),
            ]),
            imgaug.RotationAndCropValid(max_deg=10, interp=cv2.INTER_AREA),
            imgaug.GoogleNetRandomCropAndResize(
                crop_area_fraction=(0.8, 1.0),
                aspect_ratio_range=(0.8, 1.2),
                interp=cv2.INTER_AREA,
                target_shape=shape),
            imgaug.ColorSpace(mode=cv2.COLOR_RGB2GRAY),
            imgaug.ToFloat32(),
        ]
        # Only the image (component 0) is augmented; labels pass through.
        ds_train = AugmentImageComponent(ds_train, ag_train, 0)
        ds_train = BatchData(ds_train, self.hparams.batch, remainder=True)
        if self.hparams.debug:
            # Limit the dataflow to two batches for quick smoke tests.
            ds_train = FixedSizeData(ds_train, 2)
        ds_train = MultiProcessRunner(ds_train, num_proc=4, num_prefetch=16)
        ds_train = PrintData(ds_train)
        ds_train = MapData(
            ds_train, lambda dp: [
                # Reorder axes (0, 3, 1, 2): presumably batched NHWC images
                # to NCHW for torch — TODO confirm the incoming layout.
                torch.tensor(np.transpose(dp[0], (0, 3, 1, 2))),
                torch.tensor(dp[1]).float()
            ])
        return ds_train
Esempio n. 2
0
    def train_dataloader(self):
        """Build the training dataflow for paired image/target samples.

        Pipeline: ``CustomDataSet`` -> photometric augmentation on the image
        (component 0) only -> geometric augmentation applied to both
        components 0 and 1 (presumably image and mask) -> batching ->
        optional debug truncation -> 4-process prefetching -> conversion to
        float torch tensors.

        Returns:
            A tensorpack ``MapData`` dataflow yielding two float tensors, each
            with a new axis inserted at position 1 of the batched array.
        """
        ds_train = CustomDataSet(folder=self.hparams.data,
                                 train_or_valid='train',
                                 size=np.inf,
                                 hparams=self.hparams)
        ds_train.reset_state()

        # Geometric augmentors, applied to components 0 and 1 together.
        # Nearest-neighbour interpolation everywhere, presumably so target
        # values are never blended — TODO confirm intent for the image.
        ag_train = [
            imgaug.Affine(shear=10, interp=cv2.INTER_NEAREST),
            imgaug.Affine(translate_frac=(0.01, 0.02),
                          interp=cv2.INTER_NEAREST),
            imgaug.Affine(scale=(0.25, 1.0), interp=cv2.INTER_NEAREST),
            imgaug.RotationAndCropValid(max_deg=10, interp=cv2.INTER_NEAREST),
            imgaug.GoogleNetRandomCropAndResize(
                crop_area_fraction=(0.8, 1.0),
                aspect_ratio_range=(0.8, 1.2),
                interp=cv2.INTER_NEAREST,
                target_shape=self.hparams.shape),
            # Presumably a safeguard: the crop-and-resize above already
            # targets self.hparams.shape.
            imgaug.Resize(self.hparams.shape, interp=cv2.INTER_NEAREST),
            imgaug.Flip(horiz=True, vert=False, prob=0.5),
            imgaug.Flip(horiz=False, vert=True, prob=0.5),
            imgaug.Transpose(prob=0.5),
            imgaug.Albumentations(AB.RandomRotate90(p=1)),
            imgaug.ToFloat32(),
        ]

        # Photometric augmentors, applied to the image (component 0) only.
        # NOTE(review): tensorpack's imgaug.Albumentations wrapper appears to
        # call the transform's `apply` with `get_params()` directly, which
        # would bypass albumentations' `p`; if so, the `p=` values below have
        # no effect — confirm against the installed tensorpack version.
        ag_image = [
            # Pick one of the blur variants per image.
            imgaug.RandomChooseAug([
                imgaug.Albumentations(AB.Blur(blur_limit=4, p=0.25)),
                imgaug.Albumentations(AB.MotionBlur(blur_limit=4, p=0.25)),
                imgaug.Albumentations(AB.MedianBlur(blur_limit=4, p=0.25)),
            ]),
            imgaug.Albumentations(AB.RandomBrightnessContrast(p=0.5)),
            # Convert before CLAHE, presumably because CLAHE expects uint8.
            imgaug.ToUint8(),
            imgaug.Albumentations(AB.CLAHE(tile_grid_size=(32, 32),
                                           p=0.5)),
        ]
        ds_train = AugmentImageComponent(ds_train, ag_image, 0)
        ds_train = AugmentImageComponents(ds_train, ag_train, [0, 1])

        ds_train = BatchData(ds_train, self.hparams.batch, remainder=True)
        if self.hparams.debug:
            # Limit the dataflow to two batches for quick smoke tests.
            ds_train = FixedSizeData(ds_train, 2)
        ds_train = MultiProcessRunner(ds_train, num_proc=4, num_prefetch=16)
        ds_train = PrintData(ds_train)

        ds_train = MapData(
            ds_train, lambda dp: [
                # Insert a singleton axis at position 1; assumes batched
                # arrays are (B, H, W) -> (B, 1, H, W) — TODO confirm.
                torch.tensor(dp[0][:, np.newaxis, :, :]).float(),
                torch.tensor(dp[1][:, np.newaxis, :, :]).float(),
            ])
        return ds_train