Beispiel #1
0
def strong_aug(p=0.8):
    """Build a heavy albumentations pipeline: noise, blur/compression, warps, colour.

    Args:
        p: Probability that the whole pipeline is applied to a sample.

    Returns:
        An albumentations ``Compose``.

    Note:
        ``ElasticTransform`` receives scalar ``alpha``/``sigma`` values drawn with
        ``random.randint`` when this function is called, so they are fixed for the
        lifetime of the returned pipeline rather than re-drawn per image.
    """
    return Compose([
        # RandomRotate90(),
        # Flip(),
        # Transpose(),
        OneOf([
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.2),
        OneOf([  # blur / compression artefacts
            MotionBlur(p=0.5),
            MedianBlur(blur_limit=3, p=0.5),
            Blur(blur_limit=3, p=0.5),
            JpegCompression(p=1, quality_lower=7, quality_upper=40),
        ], p=1),
        # ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
        OneOf([
            IAAPiecewiseAffine(p=1, scale=(0.005, 0.01), nb_rows=4, nb_cols=4),
            # A (low, high) tuple lets imgaug sample the scale per image from
            # [0.01, 0.03]; drawing the two bounds with random.uniform here would
            # freeze a random sub-range once, at pipeline construction.
            IAAPerspective(p=1, scale=(0.01, 0.03)),
            # border_mode=3 is cv2.BORDER_WRAP
            ElasticTransform(p=1, alpha=random.randint(50, 100), sigma=random.randint(8, 13), alpha_affine=0, border_mode=3),
        ], p=0.2),
        OneOf([
            ElasticTransform(p=1, alpha=random.randint(50, 100), sigma=random.randint(8, 13), alpha_affine=0, border_mode=3),
        ], p=0.6),
        OneOf([
            OpticalDistortion(p=1, distort_limit=0.2, border_mode=3),
            # GridDistortion(p=1,distort_limit=0.1,border_mode=3),
        ], p=0.1),
        OneOf([
            CLAHE(clip_limit=2, p=0.5),
            # IAASharpen(),
            IAAEmboss(p=0.5),
            RandomBrightnessContrast(p=1),  # random brightness/contrast change
            HueSaturationValue(p=1),  # random hue/saturation/value shift
            RGBShift(p=0.5),  # random per-channel RGB shift
            ChannelShuffle(p=0.5),  # permute the colour channels
            InvertImg(p=0.1),  # 255 - pixel value, i.e. invert the image
        ], p=0.5),
    ], p=p)
Beispiel #2
0
    def load_image(self, image_index):
        """Load the image at ``image_index`` as RGB.

        When ``self.with_aug`` is truthy, a random noise/blur/warp/contrast/HSV
        pipeline is applied. When both ``self.resized_h`` and ``self.resized_w``
        are truthy, the result is resized to that size with ``INTER_AREA``.
        """
        bgr_image = cv2.imread(self.image_path(image_index))
        image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

        if self.with_aug:
            noise = OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2)
            blur = OneOf(
                [MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1)],
                p=0.2,
            )
            warp = OneOf(
                [OpticalDistortion(p=0.3), GridDistortion(p=0.1), IAAPiecewiseAffine(p=0.3)],
                p=0.2,
            )
            contrast = OneOf(
                [CLAHE(clip_limit=2), IAASharpen(), IAAEmboss(), RandomBrightnessContrast()],
                p=0.2,
            )
            pipeline = Compose([noise, blur, warp, contrast, HueSaturationValue(p=0.5)], p=0.7)
            image = pipeline(image=image)['image']

        target_h, target_w = self.resized_h, self.resized_w
        if target_h and target_w:
            image = cv2.resize(image, (target_w, target_h), interpolation=cv2.INTER_AREA)
        return image
def aug_train_heavy(resolution, p=1.0):
    """Heavy training pipeline for square images.

    Stages: resize, one flip/rotation, noise, blur, shift-scale-rotate,
    distortion, local contrast, HSV jitter, then ``Normalize()``.

    Args:
        resolution: Height and width passed to ``Resize``.
        p: Probability that the pipeline as a whole is applied.

    Returns:
        An albumentations ``Compose``.
    """
    resize = Resize(resolution, resolution)
    # group probability 1.0: one of these is picked whenever the pipeline runs
    orientation = OneOf(
        [RandomRotate90(), Flip(), Transpose(), HorizontalFlip(), VerticalFlip()],
        p=1.0,
    )
    noise = OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.5)
    blur = OneOf(
        [MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1)],
        p=0.2,
    )
    affine = ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2)
    distortion = OneOf(
        [OpticalDistortion(p=0.3), GridDistortion(p=0.1), IAAPiecewiseAffine(p=0.3)],
        p=0.1,
    )
    contrast = OneOf(
        [CLAHE(clip_limit=2), IAASharpen(), IAAEmboss(), RandomBrightnessContrast()],
        p=0.3,
    )
    stages = [
        resize, orientation, noise, blur, affine, distortion, contrast,
        HueSaturationValue(p=0.3), Normalize(),
    ]
    return Compose(stages, p=p)
Beispiel #4
0
def data_augmentation(image, roi):
    """Randomly augment ``image`` together with its ``roi`` mask.

    When ``to_apply(p=0.5)`` is truthy (presumably a 50% coin flip — confirm),
    the image first goes through ``aorta_boot``. Then one of four spatial warps
    (group p=0.3) and additive Gaussian noise (p=0.3) may be applied, with the
    whole pipeline gated at p=0.7. ``roi`` is passed as ``mask`` so spatial
    transforms are applied to it as well.

    Returns:
        ``(augmented_image, augmented_mask)``.
    """
    if to_apply(p=0.5):
        image = aorta_boot(image, roi, min_boot=0.25, max_boot=0.60)

    # interpolation=1 is cv2.INTER_LINEAR, border_mode=0 is cv2.BORDER_CONSTANT
    elastic = ElasticTransform(alpha=1., sigma=50, alpha_affine=10,
                               interpolation=1, border_mode=0, p=0.25)
    grid = GridDistortion(num_steps=50, distort_limit=0.2,
                          interpolation=1, border_mode=0, p=0.25)
    piecewise = IAAPiecewiseAffine(scale=(0.01, 0.03), nb_rows=3, nb_cols=3,
                                   order=1, cval=0, mode='constant', p=0.25)
    optical = OpticalDistortion(distort_limit=0, shift_limit=25,
                                interpolation=1, border_mode=0, p=0.25)
    warp = OneOf([elastic, grid, piecewise, optical], p=0.3)
    noise = IAAAdditiveGaussianNoise(loc=0, scale=(10, 25), per_channel=False, p=0.3)

    pipeline = Compose([warp, noise], p=0.7)
    result = pipeline(image=image, mask=roi)
    return result['image'], result['mask']
Beispiel #5
0
def _strong_aug(crop_size, p):
    """Build a strong augmentation pipeline for images of at least ``crop_size``.

    Args:
        crop_size: Minimum height/width; smaller images are padded up to it.
        p: Probability that the whole pipeline is applied.

    Returns:
        An albumentations ``Compose``.
    """
    return Compose(
        [
            # RandomCrop(crop_size, crop_size),
            PadIfNeeded(crop_size, crop_size),
            RandomRotate90(),
            Flip(),
            Transpose(),
            OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2),
            OneOf(
                [
                    MotionBlur(p=0.2),
                    MedianBlur(blur_limit=3, p=0.1),
                    Blur(blur_limit=3, p=0.1),
                ],
                p=0.2,
            ),
            ShiftScaleRotate(shift_limit=0.0625 * 2, scale_limit=0.3, rotate_limit=45, p=0.75),
            OneOf(
                [
                    OpticalDistortion(p=0.3),
                    GridDistortion(p=0.1),
                    IAAPiecewiseAffine(p=0.3),
                ],
                p=0.2,
            ),
            OneOf(
                [
                    CLAHE(clip_limit=2),
                    IAASharpen(),
                    IAAEmboss(),
                    # Must be an instance: OneOf reads each member's ``p`` and
                    # calls it on the data; the bare class has neither.
                    RandomBrightnessContrast(),
                ],
                p=0.3,
            ),
            HueSaturationValue(p=0.15),
        ],
        p=p,
    )
Beispiel #6
0
    def six_channel_transform(self, arr):
        """Apply a random flip/rotate, noise, blur, affine, warp and contrast pipeline.

        Args:
            arr: Array passed to albumentations as ``image``.

        Returns:
            The augmented array taken from the result's ``'image'`` entry.
        """
        pipeline = Compose(
            [
                # Resize(height=self.size, width=self.size),
                RandomRotate90(),
                Flip(),
                Transpose(),
                OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2),
                OneOf(
                    [MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1)],
                    p=0.2,
                ),
                ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
                OneOf(
                    [OpticalDistortion(p=0.3), GridDistortion(p=0.1), IAAPiecewiseAffine(p=0.3)],
                    p=0.2,
                ),
                OneOf([IAASharpen(), IAAEmboss(), RandomBrightnessContrast()], p=0.3),
            ],
            p=1,
        )
        augmented = pipeline(image=arr)
        return augmented['image']
Beispiel #7
0
def alb_transform_train(imsize=256, num_channels=4, p=1):
    """Training pipeline: resize, flips/rotations, noise, blur, affine, warp,
    contrast/brightness, then per-channel normalisation.

    Args:
        imsize: Height and width passed to ``Resize``.
        num_channels: How many leading entries of the 4-element mean/std
            lists are passed to ``Normalize``.
        p: Probability that the pipeline as a whole is applied.

    Returns:
        An albumentations ``Compose``.
    """
    # First three values match the common ImageNet RGB statistics; the fourth
    # repeats the second.
    channel_mean = [0.485, 0.456, 0.406, 0.456][:num_channels]
    channel_std = [0.229, 0.224, 0.225, 0.224][:num_channels]

    stages = [
        Resize(imsize, imsize),
        RandomRotate90(),
        Flip(),
        Transpose(),
        OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2),
        OneOf([MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1)], p=0.2),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
        OneOf([OpticalDistortion(p=0.3), GridDistortion(p=0.1), IAAPiecewiseAffine(p=0.3)], p=0.2),
        # CLAHE(clip_limit=2) intentionally left out of this group
        OneOf([IAASharpen(), IAAEmboss(), RandomContrast(), RandomBrightness()], p=0.3),
        Normalize(mean=channel_mean, std=channel_std),
    ]
    return Compose(stages, p=p)
def get_transforms(phase_config):
    """Build a pipeline from the boolean switches on ``phase_config``.

    Each of ``Noise``, ``Contrast``, ``Blur`` and ``Distort`` that is truthy adds
    a ``OneOf`` group (p=0.5), in that order. ``Normalize`` (with
    ``phase_config.mean``/``std``) and ``ToTensor`` are always appended.

    Returns:
        An albumentations ``Compose``.
    """
    optional_groups = (
        ('Noise', lambda: OneOf([GaussNoise(), IAAAdditiveGaussianNoise()], p=0.5)),
        ('Contrast', lambda: OneOf([RandomContrast(0.5), RandomGamma(), RandomBrightness()], p=0.5)),
        ('Blur', lambda: OneOf(
            [MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1)],
            p=0.5)),
        ('Distort', lambda: OneOf(
            [OpticalDistortion(p=0.3), GridDistortion(p=0.1), IAAPiecewiseAffine(p=0.3)],
            p=0.5)),
    )
    list_transforms = [build() for flag, build in optional_groups if getattr(phase_config, flag)]
    list_transforms += [
        Normalize(mean=phase_config.mean, std=phase_config.std, p=1),
        ToTensor(),
    ]
    return Compose(list_transforms)
Beispiel #9
0
def augment(img):
    """Augment an image and return a 224x224 ``uint8`` result.

    Args:
        img: Image array, assumed to hold values in [0, 1]. Values are scaled
            by 255 and clipped to the uint8 range before casting.

    Returns:
        The augmented ``uint8`` image. ``Resize(240, 240)`` followed by
        ``RandomCrop(224, 224)`` fixes the spatial size to 224x224.
    """
    # Clip before casting so values slightly outside [0, 1] saturate instead of
    # wrapping around in uint8. No debug printing in the data path.
    img = np.clip(np.asarray(img) * 255, 0, 255).astype(np.uint8)

    generator = Compose([
        Resize(240, 240),
        HorizontalFlip(),
        OneOf([
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.2),
        OneOf([
            MotionBlur(p=0.2),
            MedianBlur(blur_limit=3, p=0.1),
            Blur(blur_limit=3, p=0.1),
        ], p=0.2),
        ShiftScaleRotate(
            shift_limit=0.0625, scale_limit=0.2, rotate_limit=75, p=0.5),
        OneOf([
            OpticalDistortion(p=0.3),
            GridDistortion(p=0.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.2),
        OneOf([
            CLAHE(clip_limit=2),
            IAASharpen(),
            IAAEmboss(),
            RandomBrightnessContrast(),
        ], p=0.3),
        RandomCrop(224, 224),
        HueSaturationValue(p=0.3),
    ], p=1)
    return generator(image=img)['image']
def strong_aug(p=1):
    """Rotation-led augmentation pipeline with strong noise and blur.

    A photometric group (CLAHE, sharpen, emboss, brightness/contrast, JPEG,
    brightness) with group p=1 picks one member whenever the pipeline runs.

    Args:
        p: Probability that the pipeline as a whole is applied.

    Returns:
        An albumentations ``Compose``.
    """
    rotation = Rotate(limit=20, p=0.8)
    # RandomRotate90(), HorizontalFlip(p=0.4) and Transpose() were tried and disabled.
    noise = OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.7)
    blur = OneOf(
        [MotionBlur(p=0.7), MedianBlur(blur_limit=7, p=0.7), Blur(blur_limit=7, p=0.7)],
        p=0.4,
    )
    affine = ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=35, p=0.2)
    distortion = OneOf(
        [OpticalDistortion(p=0.3), GridDistortion(p=0.1), IAAPiecewiseAffine(p=0.3)],
        p=0.4,
    )
    photometric = OneOf(
        [
            CLAHE(),
            IAASharpen(),
            IAAEmboss(),
            RandomBrightnessContrast(0.3, 0.7),
            JpegCompression(70),
            RandomBrightness(-0.6),
        ],
        p=1,
    )
    return Compose(
        [rotation, noise, blur, affine, distortion, photometric, HueSaturationValue(p=0.3)],
        p=p,
    )
Beispiel #11
0
def strong_aug(p=0.5):
    """Flip/rotate + elastic + noise/blur/warp + contrast augmentation pipeline.

    Args:
        p: Probability that the pipeline as a whole is applied.

    Returns:
        An albumentations ``Compose``.
    """
    geometric = [RandomRotate90(), Flip(), Transpose(), ElasticTransform(p=1.0)]
    noise = OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2)
    blur = OneOf(
        [MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1)],
        p=0.2,
    )
    affine = ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2)
    distortion = OneOf(
        [OpticalDistortion(p=0.3), GridDistortion(p=0.1), IAAPiecewiseAffine(p=0.3)],
        p=0.2,
    )
    # CLAHE(clip_limit=2) and a trailing HueSaturationValue(p=0.3) are disabled.
    contrast = OneOf(
        [IAASharpen(), IAAEmboss(), RandomContrast(), RandomBrightness()],
        p=0.3,
    )
    return Compose(geometric + [noise, blur, affine, distortion, contrast], p=p)
Beispiel #12
0
 def StrongAug(self, image, p=1):
     """Augment ``image`` with a strong pipeline and return it as float32.

     Args:
         image: Image array; cast to ``uint8`` before augmentation.
         p: Probability that the pipeline as a whole is applied.

     Returns:
         The augmented image cast to ``np.float32``.
     """
     as_uint8 = image.astype(np.uint8)
     stages = [
         RandomRotate90(),
         Flip(),
         Transpose(),
         OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2),
         OneOf([MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1)], p=0.2),
         ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
         OneOf([OpticalDistortion(p=0.3), GridDistortion(p=0.1), IAAPiecewiseAffine(p=0.3)], p=0.2),
         OneOf(
             [CLAHE(clip_limit=2), IAASharpen(), IAAEmboss(), RandomContrast(), RandomBrightness()],
             p=0.3,
         ),
         HueSaturationValue(p=0.3),
     ]
     augmented = Compose(stages, p=p)(image=as_uint8)['image']
     return augmented.astype(np.float32)
Beispiel #13
0
def strong_aug(p=.5, config=None):
    """Flip/rotate + photometric augmentation pipeline ending in a 224x224 crop.

    Args:
        p: Probability that the stochastic augmentation stage is applied.
            ``RandomCrop`` sits outside that stage, so every sample is still
            cropped to 224x224 when the stage is skipped.
        config: Unused; kept for call-site compatibility.

    Returns:
        An albumentations ``Compose``.
    """
    augmentation = Compose([
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        RandomRotate90(p=0.5),
        Transpose(p=0.5),
        OneOf([
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.2),
        OneOf([
            MotionBlur(p=.2),
            MedianBlur(blur_limit=3, p=.1),
            Blur(blur_limit=3, p=.1),
        ], p=0.2),
        ShiftScaleRotate(shift_limit=0.001,
                         scale_limit=0.1,
                         rotate_limit=20,
                         p=.2),
        # NOTE(review): this is a Compose, so all three warps may fire in sequence;
        # similar pipelines typically use OneOf here — confirm which was intended.
        Compose([
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.2),
        OneOf([
            CLAHE(clip_limit=2),
            IAASharpen(),
            IAAEmboss(),
            RandomContrast(),
            RandomBrightness(),
        ], p=0.3),
        HueSaturationValue(p=0.3),
    ], p=p)
    return Compose([
        augmentation,
        # Kept outside the p-gated stage so the output size is always the same.
        RandomCrop(height=224, width=224, p=1.0),
    ])
Beispiel #14
0
    def __call__(self, original_image):
        """Augment, resize, normalise and tensorise ``original_image``.

        The pipeline is rebuilt on every call from the current ``self.height`` and
        ``self.width`` and stored on ``self.augmentation_pipeline``.

        Returns:
            The ``"image"`` entry of the pipeline output.
        """
        stages = [
            HorizontalFlip(p=0.5),
            ShiftScaleRotate(rotate_limit=25.0, p=0.7),
            OneOf([IAASharpen(p=1), Blur(p=1)], p=0.5),
            IAAPiecewiseAffine(p=0.5),
            Resize(self.height, self.width, always_apply=True),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], always_apply=True),
            ToTensor(),
        ]
        self.augmentation_pipeline = Compose(stages)
        return self.augmentation_pipeline(image=original_image)["image"]
Beispiel #15
0
def train_compose(config):
    """Training pipeline: upscale by 1.14x, augment, normalise, crop, tensorise.

    Args:
        config: Object exposing ``config.train.img_size``, ``.mean`` and ``.std``.

    Returns:
        An albumentations ``Compose`` whose final crop is ``img_size`` square.
    """
    size = config.train.img_size
    upscaled = int(1.14 * size)

    photometric = OneOf([RandomBrightnessContrast(brightness_limit=0.6), RandomGamma()], p=0.6)
    sharpness = OneOf(
        [CLAHE(p=0.5), GaussianBlur(3, p=0.3), IAASharpen(alpha=(0.2, 0.3), p=0.3)],
        p=0.8,
    )
    # geometric distortions
    distortion = OneOf(
        [OpticalDistortion(p=0.3), GridDistortion(p=0.2), IAAPiecewiseAffine(p=0.3)],
        p=0.2,
    )
    blur = OneOf(
        [MotionBlur(p=0.3), MedianBlur(blur_limit=3, p=0.3), Blur(blur_limit=3, p=0.3)],
        p=0.8,
    )
    return Compose([
        Resize(upscaled, upscaled),
        HorizontalFlip(),
        Transpose(),
        CoarseDropout(p=0.3),
        photometric,
        ShiftScaleRotate(rotate_limit=45),
        sharpness,
        distortion,
        blur,
        Normalize(mean=config.train.mean, std=config.train.std),
        RandomCrop(size, size),
        ToTensorV2(),
    ])
def get_default_albumentations():
    """Albumentations counterpart of ``get_default_imgaug``.

    Uses albumentations instead of imgaug; see
    https://albumentations.readthedocs.io/en/latest/examples.html

    Returns:
        An albumentations ``Compose`` applied with probability 0.9.
    """
    # RandomRotate90(), Flip() and Transpose() are disabled for this pipeline.
    noise = OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2)
    blur = OneOf(
        [MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1)],
        p=0.2,
    )
    affine = ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2)
    distortion = OneOf(
        [OpticalDistortion(p=0.3), GridDistortion(p=0.1), IAAPiecewiseAffine(p=0.3)],
        p=0.2,
    )
    contrast = OneOf(
        [CLAHE(clip_limit=2), IAASharpen(), IAAEmboss(), RandomBrightnessContrast()],
        p=0.3,
    )
    return Compose(
        [noise, blur, affine, distortion, contrast, HueSaturationValue(p=0.3)],
        p=0.9,
    )
Beispiel #17
0
def get_train_transforms():
    """Return a callable that augments a PIL/array image for training.

    The callable converts its input with ``np.array`` and returns the raw
    albumentations result dict (tensor under ``'image'``).
    """
    stages = [
        Resize(236, 236),
        HorizontalFlip(),
        OneOf([IAAAdditiveGaussianNoise(p=0.5), GaussNoise(p=0.4)], p=0.6),
        OneOf([MotionBlur(p=0.6), Blur(blur_limit=3, p=0.2)], p=0.6),
        ShiftScaleRotate(shift_limit=0.0725, scale_limit=0.2, rotate_limit=45, p=0.6),
        OneOf([OpticalDistortion(p=0.3), GridDistortion(p=0.4), IAAPiecewiseAffine(p=0.2)], p=0.6),
        OneOf([CLAHE(clip_limit=2), IAASharpen(), IAAEmboss(), RandomBrightnessContrast()], p=0.45),
        HueSaturationValue(p=0.3),
        CenterCrop(224, 224),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensor(),
    ]
    augmentations = Compose(stages)

    def _apply(img):
        return augmentations(image=np.array(img))

    return _apply
Beispiel #18
0
 def __getitem__(self, idx):
     """Load row ``idx``, apply a fixed chain of random transforms, return (image, label).

     ``self.df[idx][0]`` is the image path and ``self.df[idx][1]`` the label
     (converted with ``int``). Each transform keeps its own default probability.
     The result is converted to a PIL image and passed through
     ``self.transform`` when that is set.
     """
     record = self.df[idx]
     label1 = int(record[1])
     image = cv2.cvtColor(cv2.imread(str(record[0])), cv2.COLOR_BGR2RGB)

     # Factories, so each transform is constructed right before it is applied.
     chain = (
         RandomRotate90,
         Flip,
         lambda: JpegCompression(quality_lower=9, quality_upper=10),
         Transpose,
         Downscale,
         IAAAdditiveGaussianNoise,
         lambda: Blur(blur_limit=7),
         lambda: ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45),
         IAAPiecewiseAffine,
         RGBShift,
         RandomBrightnessContrast,
         HueSaturationValue,
     )
     for make_step in chain:
         image = make_step()(image=image)['image']

     image = transforms.ToPILImage()(image)
     if self.transform:
         image = self.transform(image)
     return image, label1
def get_aug(p=1.0):
    """Flip/rotate, reflected shift-scale-rotate, distortion and colour pipeline.

    Args:
        p: Probability that the pipeline as a whole is applied.

    Returns:
        An albumentations ``Compose``.
    """
    flips = [HorizontalFlip(), VerticalFlip(), RandomRotate90()]
    affine = ShiftScaleRotate(
        shift_limit=0.0625,
        scale_limit=0.2,
        rotate_limit=15,
        p=0.9,
        border_mode=cv2.BORDER_REFLECT,
    )
    distortion = OneOf(
        [OpticalDistortion(p=0.3), GridDistortion(p=0.1), IAAPiecewiseAffine(p=0.3)],
        p=0.3,
    )
    # HueSaturationValue positional args: hue, saturation, value shift limits
    colour = OneOf(
        [HueSaturationValue(10, 15, 10), CLAHE(clip_limit=2), RandomBrightnessContrast()],
        p=0.3,
    )
    return Compose(flips + [affine, distortion, colour], p=p)
Beispiel #20
0
def aug_mega_hardcore(p=.95):
    """Very aggressive augmentation pipeline (contrast, noise, geometry, warps, HSV).

    Args:
        p: Probability that the pipeline as a whole is applied.

    Returns:
        An albumentations ``Compose``.
    """
    local_contrast = OneOf([CLAHE(clip_limit=2), IAASharpen(p=0.25), IAAEmboss(p=0.25)], p=0.35)
    noise = OneOf([IAAAdditiveGaussianNoise(p=0.3), GaussNoise(p=0.7)], p=0.5)
    orientation = [RandomRotate90(), Flip(), Transpose()]
    blur = OneOf(
        [MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.3), Blur(blur_limit=3, p=0.5)],
        p=0.4,
    )
    brightness = OneOf([RandomContrast(p=0.5), RandomBrightness(p=0.5)], p=0.4)
    affine = ShiftScaleRotate(shift_limit=0.0, scale_limit=0.45, rotate_limit=45, p=0.7)
    warp = OneOf(
        [
            OpticalDistortion(p=0.3),
            GridDistortion(p=0.2),
            ElasticTransform(p=0.2),
            IAAPerspective(p=0.2),
            IAAPiecewiseAffine(p=0.3),
        ],
        p=0.6,
    )
    stages = [local_contrast, noise, *orientation, blur, brightness, affine, warp,
              HueSaturationValue(p=0.5)]
    return Compose(stages, p=p)
def _glom_augmentations():
    """Return zero-arg factories for every augmentation ``transform`` applies, in output order."""
    return [
        lambda: Blur(p=1, blur_limit=3),
        lambda: HorizontalFlip(p=1),
        lambda: VerticalFlip(p=1),
        lambda: Transpose(p=1),
        lambda: RandomGamma(p=1),
        lambda: OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5),
        lambda: GridDistortion(p=1),
        lambda: RandomGridShuffle(p=1),
        lambda: HueSaturationValue(p=1),
        lambda: RGBShift(p=1),
        lambda: RandomBrightness(p=1),
        lambda: RandomContrast(p=1),
        lambda: MedianBlur(p=1, blur_limit=5),
        lambda: GaussianBlur(p=1, blur_limit=3),
        lambda: GaussNoise(p=1),
        lambda: GlassBlur(p=1),
        lambda: CLAHE(clip_limit=1.0, tile_grid_size=(8, 8), always_apply=False, p=1),
        lambda: ChannelShuffle(p=1),
        lambda: ToGray(p=1),
        lambda: ToSepia(p=1),
        lambda: JpegCompression(p=1),
        lambda: ImageCompression(p=1),
        lambda: Cutout(p=1),
        lambda: FromFloat(p=1),
        lambda: RandomBrightnessContrast(p=1),
        lambda: RandomSnow(p=1),
        lambda: RandomRain(p=1),
        lambda: RandomFog(p=1),
        lambda: RandomSunFlare(p=1),
        lambda: RandomShadow(p=1),
        lambda: Lambda(p=1),
        lambda: ChannelDropout(p=1),
        lambda: ISONoise(p=1),
        lambda: Solarize(p=1),
        lambda: Equalize(p=1),
        lambda: Posterize(p=1),
        lambda: Downscale(p=1),
        lambda: MultiplicativeNoise(p=1),
        lambda: FancyPCA(p=1),
        lambda: GridDropout(p=1),
        lambda: ColorJitter(p=1),
        lambda: ElasticTransform(p=1, alpha=120, sigma=512 * 0.05, alpha_affine=512 * 0.03),
        lambda: CropNonEmptyMaskIfExists(p=1, height=22, width=32),
        lambda: IAAAffine(p=1),
        lambda: IAAFliplr(p=1),
        lambda: IAAFlipud(p=1),
        lambda: IAAPerspective(p=1),
        lambda: IAAPiecewiseAffine(p=1),
        lambda: LongestMaxSize(p=1),
        lambda: NoOp(p=1),
        lambda: RandomScale(p=1),
        lambda: Rotate(p=1),
        lambda: ShiftScaleRotate(p=1),
        lambda: SmallestMaxSize(p=1),
    ]


def transform(image, mask, image_name, mask_name):
    """Expand one image/mask pair into a dict of augmented copies.

    Only the original pair is returned in two cases:
    * when ``random.uniform(0, 1) > 0.5``;
    * when the mask has no non-zero pixels (only images containing gloms are
      augmented).

    Otherwise every transform from ``_glom_augmentations`` is applied (each with
    ``p=1``) to the original pair, and all results are returned.

    Args:
        image: Image array, expected to have a 3-element shape. If unpacking the
            shape fails, the last entry along axis 0 is dropped and the check is
            repeated; a ValueError propagates if it still fails.
        mask: Mask array aligned with ``image``.
        image_name: Base key for the image dict.
        mask_name: Base key for the mask dict.

    Returns:
        ``(imagedict, masksdict)``. The original pair is keyed by ``image_name`` /
        ``mask_name``. The i-th augmented copy is keyed by
        ``f"{image_name}_{smalllist[i]}"`` / ``f"{mask_name}_{smalllist[i]}"``,
        where ``smalllist`` is a module-level suffix sequence defined elsewhere.
        It must hold at least one entry per augmentation.
    """
    x, y = image, mask

    if random.uniform(0, 1) > 0.5:
        return {f"{image_name}": x}, {f"{mask_name}": y}

    ## Augmenting only images with Gloms; empty masks are returned untouched.
    if np.count_nonzero(y) == 0:
        return {f"{image_name}": x}, {f"{mask_name}": y}

    try:
        _h, _w, _c = x.shape  # shape check only
    except ValueError:
        image = image[:-1]
        x, y = image, mask
        _h, _w, _c = x.shape

    images_aug = [x]
    masks_aug = [y]
    for make_aug in _glom_augmentations():
        augmented = make_aug()(image=x, mask=y)
        images_aug.append(augmented['image'])
        masks_aug.append(augmented['mask'])

    n_augmented = len(images_aug) - 1
    # append whole names (extending a list with a str would add single characters)
    images_name = [f"{image_name}"] + [
        f"{image_name}_{smalllist[i]}" for i in range(n_augmented)]
    masks_name = [f"{mask_name}"] + [
        f"{mask_name}_{smalllist[i]}" for i in range(n_augmented)]

    imagedict = dict(zip(images_name, images_aug))
    masksdict = dict(zip(masks_name, masks_aug))
    return imagedict, masksdict
Beispiel #22
0
def do_augmentation(dataset_dir, output_dir, file_ext, strAugs):
    """Write augmented copies of VOC-annotated images, then copy the originals.

    For every directory in ``dataset_dir``, every suffix ``augstr`` in
    ``strAugs`` and every image selected by ``filterImages(files, file_ext)``,
    the image and its ``<base_name>.xml`` annotation (``xmin``/``ymin``/
    ``xmax``/``ymax``/``name`` elements) are run through one random
    ``get_aug`` pipeline. Results are written to
    ``output_dir/<base_name><augstr><file_ext>`` and
    ``output_dir/<base_name><augstr>.xml``. Afterwards every regular file of
    every input directory is copied unchanged into ``output_dir``.

    Args:
        dataset_dir: iterable of input directory paths.
        output_dir: directory that receives augmented and copied files.
        file_ext: image extension used both to select inputs and to name
            outputs (presumably including the leading dot, e.g. ".jpg").
        strAugs: iterable of suffixes; one augmented copy per suffix.

    Images that fail (unreadable file, bad XML, boxes dropped by the
    augmentation) are reported on stdout and skipped; this is a best-effort
    batch job.
    """
    # Built once: the pipeline draws new random parameters on every call.
    aug = get_aug([
        HorizontalFlip(p=0.5),
        RandomBrightnessContrast(brightness_limit=0.1,
                                 contrast_limit=0.1,
                                 p=0.3),
        RandomSunFlare(p=0.3, src_radius=50),
        IAAPerspective(scale=(0.1, 0.1), p=0.4),
        IAAPiecewiseAffine(scale=(0.01, 0.01), p=0.4),
    ])

    for indDataset in dataset_dir:
        files_list = os.listdir(indDataset)
        imagesList = filterImages(files_list, file_ext)

        for augstr in strAugs:
            for image_name in imagesList:
                try:
                    base_name = os.path.splitext(image_name)[0]
                    image = cv2.imread(os.path.join(indDataset, image_name))
                    if image is None:
                        raise IOError(f"could not read image {image_name}")
                    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

                    # Re-parsed for every suffix because the tree is edited in place.
                    tree = ET.parse(os.path.join(indDataset, base_name + ".xml"))
                    root = tree.getroot()

                    # Keep the XML elements so the augmented coordinates can be
                    # written back into the same nodes.
                    xminxml = list(root.iter('xmin'))
                    yminxml = list(root.iter('ymin'))
                    xmaxxml = list(root.iter('xmax'))
                    ymaxxml = list(root.iter('ymax'))
                    if not (len(xminxml) == len(yminxml) == len(xmaxxml) == len(ymaxxml)):
                        raise ValueError(f"inconsistent box coordinates in {base_name}.xml")
                    box_elems = list(zip(xminxml, yminxml, xmaxxml, ymaxxml))

                    bbox = [[int(float(e.text)) for e in corners] for corners in box_elems]
                    label = [name.text for name in root.iter('name')]

                    annotations = {
                        'image': image.copy(),
                        'bboxes': bbox,
                        'category_id': label
                    }
                    augmented = aug(**annotations)
                    new_boxes = augmented["bboxes"]

                    if len(new_boxes) == 0:
                        continue  # no box survived the augmentation
                    if len(new_boxes) != len(box_elems):
                        # Some boxes were dropped; they can no longer be matched
                        # to their XML objects, so skip instead of writing a
                        # half-updated annotation.
                        print(f"skipping {image_name} ({augstr}): "
                              f"{len(box_elems) - len(new_boxes)} box(es) dropped")
                        continue

                    for box, corners in zip(new_boxes, box_elems):
                        # Only the first four entries are coordinates.
                        for value, elem in zip(box[:4], corners):
                            elem.text = str(int(value))

                    out_stem = base_name + augstr
                    for fileName in root.iter('filename'):
                        fileName.text = out_stem + file_ext

                    tree.write(os.path.join(output_dir, out_stem + ".xml"))
                    image = cv2.cvtColor(augmented['image'], cv2.COLOR_RGB2BGR)
                    cv2.imwrite(os.path.join(output_dir, out_stem + file_ext), image)
                    print("saved image ", out_stem + file_ext)

                except Exception as e:
                    # Best-effort: report which input failed and continue.
                    print(f"Exception while augmenting {image_name} ({augstr}): {e}")

    print("copying original")
    # Copy the original files next to the augmented ones.
    for folder in dataset_dir:
        for name in os.listdir(folder):
            src = os.path.join(folder, name)
            if not os.path.isfile(src):
                continue  # shutil.copy cannot copy directories
            shutil.copy(src, os.path.join(output_dir, name))
            print("moved: " + name)
    print("Done")
Beispiel #23
0
             "IAAPiecewiseAffine": dict(p=.3),
             "IAASharpen": dict(),
             "IAAEmboss": dict(),
             "RandomContrast": dict(),
             "RandomBrightness": dict(),
             "HueSaturationValue": dict(p=0.3),
             "ShiftScaleRotate": dict(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=.2),
             "CLAHE": dict(clip_limit=2),
             "Noise_of": dict(transforms=[IAAAdditiveGaussianNoise(),
                                          GaussNoise()], p=0.2),
             "Blur_of": dict(transforms=[MotionBlur(p=.2),
                                         MedianBlur(blur_limit=3, p=.1),
                                         Blur(blur_limit=3, p=.1)], p=0.2),
             "Optical_of": dict(transforms=[OpticalDistortion(p=0.3),
                                            GridDistortion(p=.1),
                                            IAAPiecewiseAffine(p=0.3)], p=0.2),
             "Color_of": dict(transforms=[IAASharpen(),
                                          # CLAHE(clip_limit=2),
                                          IAAEmboss(),
                                          RandomContrast(),
                                          RandomBrightness()], p=0.3)}


def str2class(classname):
    """Look up ``classname`` as an attribute of the module defining this function.

    Raises:
        AttributeError: if the module has no attribute of that name.
    """
    this_module = sys.modules[__name__]
    return getattr(this_module, classname)


class Auger:
    """Stores the two configuration objects it is constructed with."""

    def __init__(self, config, local_config):
        # Kept as given: no copy and no validation.
        self.config, self.local_config = config, local_config
def get_transforms(phase_config):
    """Build an albumentations ``Compose`` from the switches in ``phase_config``.

    ``phase_config`` is read through attribute access (presumably an
    EasyDict/namespace-like config — confirm with callers). The flags read are
    ``Resize``, ``HorizontalFlip``, ``VerticalFlip``, ``RandomCropScale``,
    ``ShiftScaleRotate``, ``RandomCrop``, ``Noise``, ``Contrast``, ``Blur``,
    ``Distort`` and ``Cutout``, plus ``mean``/``std`` for normalisation.
    The pipeline always ends with ``Normalize`` followed by ``ToTensor``.

    When ``RandomCropScale`` is on and no resize is configured, the module-level
    ``HEIGHT``/``WIDTH`` are used as the crop target.
    """
    list_transforms = []
    if phase_config.Resize.p > 0:
        list_transforms.append(
            Resize(phase_config.Resize.height, phase_config.Resize.width, p=1))
    if phase_config.HorizontalFlip:
        list_transforms.append(HorizontalFlip())
    if phase_config.VerticalFlip:
        list_transforms.append(VerticalFlip())
    if phase_config.RandomCropScale:
        if phase_config.Resize.p > 0:
            height = phase_config.Resize.height
            width = phase_config.Resize.width
        else:
            height = HEIGHT
            width = WIDTH
        # Crop 90-100% of the height and scale back to (height, width).
        list_transforms.append(
            RandomSizedCrop(min_max_height=(int(height * 0.90), height),
                            height=height,
                            width=width,
                            w2h_ratio=width / height))
    if phase_config.ShiftScaleRotate:
        list_transforms.append(ShiftScaleRotate(p=1))

    if phase_config.RandomCrop.p > 0:
        list_transforms.append(
            RandomCrop(phase_config.RandomCrop.height,
                       phase_config.RandomCrop.width,
                       p=1))
    if phase_config.Noise:
        list_transforms.append(
            OneOf([
                GaussNoise(),
                IAAAdditiveGaussianNoise(),
            ], p=0.5))
    if phase_config.Contrast:
        list_transforms.append(
            OneOf([
                RandomContrast(0.5),
                RandomGamma(),
                RandomBrightness(),
            ], p=0.5))
    if phase_config.Blur:
        list_transforms.append(
            OneOf([
                MotionBlur(p=.2),
                MedianBlur(blur_limit=3, p=0.1),
                Blur(blur_limit=3, p=0.1),
            ], p=0.5))
    if phase_config.Distort:
        list_transforms.append(
            OneOf([
                OpticalDistortion(p=0.3),
                GridDistortion(p=.1),
                IAAPiecewiseAffine(p=0.3),
            ], p=0.5))

    if phase_config.Cutout.num_holes > 0:
        hole_size = phase_config.Cutout.hole_size
        # Cutout's second positional parameter is only max_h_size; set both
        # dimensions explicitly so the holes are hole_size x hole_size.
        list_transforms.append(
            Cutout(num_holes=phase_config.Cutout.num_holes,
                   max_h_size=hole_size,
                   max_w_size=hole_size))

    list_transforms.extend([
        Normalize(mean=phase_config.mean, std=phase_config.std, p=1),
        ToTensor(),
    ])

    return Compose(list_transforms)
    def __init__(self, config):
        """Set up the appearance and shape augmentation pipelines.

        ``config`` is read with ``.get``; the flags ``data_augment_appearance``
        and ``data_augment_shape`` default to False. ``self.n_images`` is
        expected to be set by the parent constructor.
        """
        super(AugmentedPair2, self).__init__(config)
        self.use_appearance_augmentation = config.get("data_augment_appearance", False)
        self.use_shape_augmentation = config.get("data_augment_shape", False)

        # Extra keys "image1" .. "image{n_images-1}" are registered as
        # image targets alongside "image".
        extra_targets = {f"image{k}": "image" for k in range(1, self.n_images)}

        def color_jitter():
            # A fresh colour group each call: brightness/contrast, RGB or HSV shift.
            return OneOf(
                [
                    RandomBrightnessContrast(p=0.3),
                    RGBShift(p=0.3),
                    HueSaturationValue(p=0.3),
                ],
                p=0.8,
            )

        blur = OneOf(
            [MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1)],
            p=0.5,
        )
        # The colour group is deliberately stacked three times.
        self.appearance_augmentation = Compose(
            [
                blur,
                color_jitter(),
                color_jitter(),
                color_jitter(),
                ToGray(p=0.1),
                ChannelShuffle(p=0.3),
            ],
            p=0.9,
            additional_targets=extra_targets,
        )

        self.shape_augmentation = Compose(
            [
                HorizontalFlip(p=0.3),
                ShiftScaleRotate(
                    shift_limit=0.0625,
                    scale_limit=0.25,
                    rotate_limit=25,
                    p=0.3,
                    border_mode=cv2.BORDER_REPLICATE,
                ),
                OneOf(
                    [
                        IAAPiecewiseAffine(p=0.5),
                        ElasticTransform(p=0.5, border_mode=cv2.BORDER_REPLICATE),
                    ],
                    p=0.3,
                ),
            ],
            p=0.9,
            additional_targets=extra_targets,
        )
Beispiel #26
0
# Module-level training augmentation pipeline; the outer Compose has p=1.
augmentation = Compose([
    HorizontalFlip(),
    # Noise group (p=0.2).
    OneOf([
        IAAAdditiveGaussianNoise(),
        GaussNoise(),
    ], p=0.2),
    # Blur group (p=0.2).
    OneOf([
        MotionBlur(p=0.2),
        MedianBlur(blur_limit=3, p=0.1),
        Blur(blur_limit=3, p=0.1),
    ], p=0.2),
    # Shift/scale/rotate with p=1, rotation limited to 15 degrees.
    ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=1),
    # Geometric distortion group (p=0.2).
    OneOf([
        OpticalDistortion(p=0.3),
        GridDistortion(p=0.1),
        IAAPiecewiseAffine(p=0.3),
    ], p=0.2),
    # Contrast / sharpness group (p=0.3).
    OneOf([
        CLAHE(clip_limit=2),
        IAASharpen(),
        IAAEmboss(),
        RandomBrightnessContrast(),
    ], p=0.3),
    HueSaturationValue(p=0.3),
], p=1)

# Example: draw nine random augmentations of the first training image in a 3x3 grid.
plt.figure(figsize=(12, 8))
for i in range(9):
    img = augmentation(image=images_train[0])['image']
    plt.subplot(3, 3, i + 1)
    plt.imshow(img)  # draw the augmented sample in its grid cell
    plt.axis('off')
# Train/validation split: fold 5 is held out for validation, folds 0-4 train.
df_train = pd.read_csv('./input/train.csv')
df_train['fold'] = pd.read_csv('./input/df_folds.csv')['fold']

trn_fold = [i for i in range(6) if i != 5]
vid_fold = [5]

# These are index *labels* (from .loc[...].index), so rows are selected
# with .loc below; .iloc would only agree with a default RangeIndex.
trn_idx = df_train.loc[df_train['fold'].isin(trn_fold)].index
vid_idx = df_train.loc[df_train['fold'].isin(vid_fold)].index

# Candidate operations handed to AugMix.
augs = [
    HorizontalFlip(),
    GaussNoise(),
    ShiftScaleRotate(),
    RandomBrightnessContrast(),
    CoarseDropout(),
    IAAPiecewiseAffine()
]

transforms_train = albumentations.Compose(
    [AugMix(width=3, depth=2, alpha=.2, p=.5, augmentations=augs)])

trn_dataset = BengaliDataset(csv=df_train.loc[trn_idx],
                             img_height=HEIGHT,
                             img_width=WIDTH,
                             transform=transforms_train)
vid_dataset = BengaliDataset(csv=df_train.loc[vid_idx],
                             img_height=HEIGHT,
                             img_width=WIDTH)

trn_loader = torch.utils.data.DataLoader(dataset=trn_dataset,
                                         batch_size=BATCH_SIZE,
Beispiel #28
0
def compose_augmentations(img_height,
                          img_width,
                          flip_p=0.5,
                          translate_p=0.5,
                          distort_p=0.5,
                          color_p=0.5,
                          overlays_p=0.15,
                          blur_p=0.25,
                          noise_p=0.25):
    """Build the training augmentation pipeline for ``img_height`` x ``img_width`` outputs.

    Stages, in order:
      * resize to the target size (skipped when ``img_height`` is 1024);
      * one flip/transpose/rot90 with group probability ``flip_p``;
      * one rotate / random-sized-crop / affine-or-perspective (``translate_p``);
      * one elastic/grid/optical/piecewise distortion (``distort_p``);
      * either a colour change (``color_p``) or an emboss/sharpen overlay
        (``overlays_p``);
      * either a blur (``blur_p``) or additive noise (``noise_p``);
      * normalisation with the ImageNet mean/std, then ``ToTensor``.

    The random-sized-crop height range is 7/8 to 15/16 of ``img_height``
    (e.g. 896-960 px for 1024, 448-480 for 512, 224-240 for 256).

    Returns:
        An albumentations ``Compose``.

    Raises:
        ValueError: if ``img_height`` is not positive.
    """
    if img_height <= 0:
        raise ValueError(f"img_height must be positive, got {img_height}")

    # Resize only when the target height differs from 1024
    # (NOTE(review): presumably the native input size; the width is not checked).
    resize_p = 1 if img_height != 1024 else 0

    # Random sized crop keeps between 7/8 and 15/16 of the height.
    min_max_height = (img_height * 7 // 8, img_height * 15 // 16)

    return Compose([
        Resize(p=resize_p, height=img_height, width=img_width),
        OneOf([
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            Transpose(p=0.5),
            RandomRotate90(p=0.5),
        ],
              p=flip_p),
        OneOf([
            Rotate(p=0.25, limit=10),
            RandomSizedCrop(p=0.5,
                            min_max_height=min_max_height,
                            height=img_height,
                            width=img_width),
            OneOrOther(IAAAffine(p=0.1, translate_percent=0.05),
                       IAAPerspective(p=0.1)),
        ],
              p=translate_p),
        OneOf([
            ElasticTransform(p=0.5,
                             alpha=10,
                             sigma=img_height * 0.05,
                             alpha_affine=img_height * 0.03,
                             approximate=True),
            GridDistortion(p=0.5),
            OpticalDistortion(p=0.5),
            IAAPiecewiseAffine(p=0.25, scale=(0.01, 0.03)),
        ],
              p=distort_p),
        OneOrOther(
            OneOf([
                CLAHE(p=0.5),
                RandomGamma(p=0.5),
                RandomContrast(p=0.5),
                RandomBrightness(p=0.5),
                RandomBrightnessContrast(p=0.5),
            ],
                  p=color_p),
            OneOf([IAAEmboss(p=0.1),
                   IAASharpen(p=0.1),
                   IAASuperpixels(p=0)],
                  p=overlays_p)),
        OneOrOther(
            OneOf([
                Blur(p=0.2),
                MedianBlur(p=0.1),
                MotionBlur(p=0.1),
                GaussianBlur(p=0.1),
            ],
                  p=blur_p),
            OneOf([GaussNoise(p=0.2),
                   IAAAdditiveGaussianNoise(p=0.1)],
                  p=noise_p)),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensor(sigmoid=False),
    ])
def _get_data_loader(imgs, trn_df, vld_df):
    """Create the (distributed) training loader and the validation loader.

    Args:
        imgs: image store passed through to ``BangaliDataset``.
        trn_df: label frame for the training split.
        vld_df: label frame for the validation split.

    Returns:
        ``(trn_loader, vld_loader)``. The training loader shards its data with
        a ``DistributedSampler`` over ``dist.get_world_size()`` replicas
        (``dist`` presumably ``torch.distributed``, already initialised).
        The validation loader is not sharded.
    """
    import albumentations as A
    from albumentations import (
        CLAHE, Blur, GaussNoise, GridDistortion, HueSaturationValue,
        IAAAdditiveGaussianNoise, IAAEmboss, IAAPiecewiseAffine, IAASharpen,
        MedianBlur, MotionBlur, OneOf, OpticalDistortion,
        RandomBrightnessContrast, Rotate, ShiftScaleRotate)
    from albumentations.pytorch import ToTensor
    from torch.utils.data import DataLoader

    train_transforms = A.Compose([
        Rotate(20),
        OneOf([
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.2),
        OneOf([
            MotionBlur(p=.2),
            MedianBlur(blur_limit=3, p=0.1),
            Blur(blur_limit=3, p=0.1),
        ], p=0.2),
        ShiftScaleRotate(
            shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
        OneOf([
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.2),
        OneOf([
            CLAHE(clip_limit=2),
            IAASharpen(),
            IAAEmboss(),
            RandomBrightnessContrast(),
        ], p=0.3),
        HueSaturationValue(p=0.3),
        ToTensor()
    ], p=1.0)

    valid_transforms = A.Compose([ToTensor()])

    trn_dataset = BangaliDataset(imgs=imgs,
                                 label_df=trn_df,
                                 transform=train_transforms)
    vld_dataset = BangaliDataset(imgs=imgs,
                                 label_df=vld_df,
                                 transform=valid_transforms)

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Split the training set into world_size shards, one per process.
    trn_sampler = torch.utils.data.distributed.DistributedSampler(
        trn_dataset,
        num_replicas=world_size,
        rank=rank)

    # shuffle must stay False when a sampler is supplied; the sampler
    # controls the order.
    # NOTE(review): num_workers is hard-coded to 8 here but NUM_WORKERS is
    # used for validation — confirm this difference is intended.
    trn_loader = DataLoader(trn_dataset,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            batch_size=BATCH_SIZE,
                            sampler=trn_sampler)

    vld_loader = DataLoader(vld_dataset,
                            shuffle=False,
                            num_workers=NUM_WORKERS,
                            batch_size=BATCH_SIZE)
    return trn_loader, vld_loader
# Train/validation split: fold 5 is held out for validation, folds 0-4 train.
df_train['fold'] = pd.read_csv('./input/df_folds.csv')['fold']

trn_fold = [i for i in range(6) if i != 5]
vid_fold = [5]

# These are index *labels* (from .loc[...].index), so rows are selected
# with .loc below; .iloc would only agree with a default RangeIndex.
trn_idx = df_train.loc[df_train['fold'].isin(trn_fold)].index
vid_idx = df_train.loc[df_train['fold'].isin(vid_fold)].index

# Candidate operations handed to AugMix, each constructed with always_apply=True.
augs = [
    HorizontalFlip(always_apply=True),
    GaussNoise(always_apply=True),
    ShiftScaleRotate(rotate_limit=20, always_apply=True),
    RandomContrast(always_apply=True),
    RandomBrightness(always_apply=True),
    CoarseDropout(always_apply=True),
    IAAPiecewiseAffine(always_apply=True)
]

transforms_train = albumentations.Compose(
    [AugMix(width=3, depth=2, alpha=.2, p=.5, augmentations=augs)])

trn_dataset = BengaliDataset(csv=df_train.loc[trn_idx],
                             img_height=HEIGHT,
                             img_width=WIDTH,
                             transform=transforms_train)
vid_dataset = BengaliDataset(csv=df_train.loc[vid_idx],
                             img_height=HEIGHT,
                             img_width=WIDTH)

trn_loader = torch.utils.data.DataLoader(dataset=trn_dataset,
                                         batch_size=BATCH_SIZE,