Example #1
0
def build_transforms(mean=(0.485, 0.456, 0.406),
                     std=(0.229, 0.224, 0.225),
                     divide_by=255.0,
                     scale_limit=0.0,
                     shear_limit=0,
                     rotate_limit=0,
                     brightness_limit=0.0,
                     contrast_limit=0.0,
                     clahe_p=0.0,
                     blur_p=0.0,
                     autoaugment=False,
                     ):
    """Create the ``(train_transform, eval_transform)`` pipelines.

    Both pipelines end with one shared ``Normalize`` instance built from
    ``mean``/``std``, with ``divide_by`` passed as ``max_pixel_value``.

    The training pipeline starts with one of two heads:

    * ``autoaugment=True``: a single ``torch_custom.AutoAugmentWrapper``;
    * otherwise: an ``IAAAffine`` (symmetric rotate/shear ranges),
      ``RandomBrightnessContrast``, ``MotionBlur(p=blur_p)`` and
      ``CLAHE(p=clahe_p)``.

    Either head is followed by
    ``torch_custom.RandomCropThenScaleToOriginalSize(limit=scale_limit)``
    (a project helper; presumably crops and resizes back — confirm) and the
    normalisation step. The evaluation pipeline only normalises.
    """
    normalize = Normalize(mean=mean, std=std, max_pixel_value=divide_by)

    if autoaugment:
        head = [torch_custom.AutoAugmentWrapper(p=1.0)]
    else:
        head = [
            IAAAffine(rotate=(-rotate_limit, rotate_limit),
                      shear=(-shear_limit, shear_limit),
                      mode='constant'),
            RandomBrightnessContrast(brightness_limit=brightness_limit,
                                     contrast_limit=contrast_limit),
            MotionBlur(p=blur_p),
            CLAHE(p=clahe_p),
        ]

    tail = [
        torch_custom.RandomCropThenScaleToOriginalSize(limit=scale_limit, p=1.0),
        normalize,
    ]
    train_transform = Compose(head + tail)
    eval_transform = Compose([normalize])
    return train_transform, eval_transform
Example #2
0
def empty_aug2():
    """Return a plain list of transforms (not wrapped in ``Compose``).

    Contains a horizontal flip with a tiny probability (0.001), an
    always-on ``IAAPiecewiseAffine``, and an ``IAAAffine`` wrapped in a
    ``OneOf`` whose own probability is 0.0, i.e. a placeholder that does not
    fire under normal probability sampling.
    """
    rare_flip = HorizontalFlip(p=0.001)
    piecewise = IAAPiecewiseAffine(p=1.0)
    placeholder = OneOf([IAAAffine(p=1.0)], p=0.0)
    return [rare_flip, piecewise, placeholder]
def shiftscalerotate_aug():
    """Build a pipeline with one optional geometric warp, then normalise.

    With probability 0.5 one of shift/scale/rotate (replicate border),
    shear (20) or perspective is chosen; the result is then normalised with
    ``Normalize()`` defaults and converted via ``ToTensorV2``.
    """
    geometric = OneOf(
        [
            ShiftScaleRotate(scale_limit=.15,
                             rotate_limit=15,
                             border_mode=cv2.BORDER_REPLICATE,
                             p=0.5),
            IAAAffine(shear=20, p=0.5),
            IAAPerspective(p=0.5),
        ],
        p=0.5,
    )
    return Compose([geometric, Normalize(), ToTensorV2()], p=1)
Example #4
0
def aug():
    """Build an albumentations pipeline approximating a Keras generator config.

    The inline comments record the Keras ``ImageDataGenerator`` setting each
    transform is meant to reproduce (horizontal/vertical flips, 0.1 shifts,
    zoom [0.9, 1.25], 20 degree rotation, brightness [0.4, 1.5], shear 0.01
    with reflect fill).
    """
    return Compose(
        [
            HorizontalFlip(p=0.5),  # horizontal_flip=True
            VerticalFlip(p=0.5),  # vertical_flip=True
            ShiftScaleRotate(
                # A scalar limit L samples shifts from (-L, L). The previous
                # (0.1, 0.1) tuple was a degenerate range: always +10%.
                shift_limit=0.1,  # width_shift_range=0.1, height_shift_range=0.1
                # albumentations biases scale_limit by +1, so (-0.1, 0.25)
                # samples scale factors in [0.9, 1.25]; the previous
                # (0.9, 1.25) sampled [1.9, 2.25].
                scale_limit=(-0.1, 0.25),  # zoom_range=[0.9, 1.25]
                rotate_limit=20,  # rotation_range=20
                p=0.5),
            # brightness_limit is an offset around 0, so the Keras factor
            # range [0.4, 1.5] maps to (-0.6, 0.5). The previous (0.4, 1.5)
            # only ever brightened, by 40-150% of full scale.
            # NOTE(review): albumentations brightness is additive while Keras'
            # is multiplicative, so this is an approximation — confirm intent.
            RandomBrightnessContrast(brightness_limit=(-0.6, 0.5),
                                     p=0.5),  # brightness_range=[0.4, 1.5]
            IAAAffine(shear=0.01, mode='reflect',
                      p=0.5),  # shear_range=0.01, fill_mode='reflect'
        ],
        p=1)
Example #5
0
    def __init__(self,
                 basic=True,
                 elastic_transform=False,
                 shift_scale_rotate=False):
        """Assemble the ordered transform list stored in ``self.transforms``.

        Args:
            basic: add a random flip and a rotation up to 45 degrees (each with
                p=0.5; border mode taken from ``Config.BORDER_MODE``).
            elastic_transform: add ``ElasticTransform(p=0.2)``.
            shift_scale_rotate: add ``ShiftScaleRotate(p=0.2)``.

        A constant-mode ``IAAAffine`` shear (0.2, p=1) is always appended last.
        """
        pipeline = []

        if basic:
            pipeline.extend([
                Flip(p=0.5),
                Rotate(p=0.5, border_mode=Config.BORDER_MODE, limit=45),
            ])
        if elastic_transform:
            pipeline.append(ElasticTransform(p=0.2))
        if shift_scale_rotate:
            pipeline.append(ShiftScaleRotate(p=0.2))

        # Unconditional: every configuration ends with the shear transform.
        pipeline.append(IAAAffine(p=1, shear=0.2, mode="constant"))
        self.transforms = pipeline
Example #6
0
def strong_aug(p=0.5):
    """Build a heavy augmentation pipeline, applied as a whole with probability ``p``.

    Steps, in order: 2x2 grid shuffle (p=0.75); one spatial transform; one
    noise/blur transform; one contrast/sharpen transform; then a ``Cutout``
    of 10 holes of at most 2x2 px filled with value 127.
    """
    grid_shuffle = RandomGridShuffle(grid=(2, 2), p=0.75)
    spatial = OneOf([
        ShiftScaleRotate(shift_limit=0.125),
        Transpose(),
        RandomRotate90(),
        VerticalFlip(),
        HorizontalFlip(),
        IAAAffine(shear=0.1),
    ])
    noise_or_blur = OneOf([
        GaussNoise(),
        GaussianBlur(),
        MedianBlur(),
        MotionBlur(),
    ])
    contrast = OneOf([
        RandomBrightnessContrast(),
        CLAHE(),
        IAASharpen(),
    ])
    holes = Cutout(num_holes=10, max_h_size=2, max_w_size=2, fill_value=127)
    return Compose([grid_shuffle, spatial, noise_or_blur, contrast, holes], p=p)
def transform(image, mask, image_name, mask_name):
    """Expand one image/mask pair into dicts of named augmented copies.

    When ``random.uniform(0, 1) > 0.5``, or when the mask has no non-zero
    pixels, only the original pair is returned. Otherwise every transform from
    ``_glom_augmentations()`` is applied independently to the original pair.

    Keys:

    * the original pair is stored under ``f"{image_name}"`` / ``f"{mask_name}"``;
    * the i-th augmentation's output is stored under
      ``f"{image_name}_{smalllist[i]}"`` / ``f"{mask_name}_{smalllist[i]}"``.

    ``smalllist`` is a module-level name defined elsewhere. It must have at
    least as many entries as ``_glom_augmentations()`` returns.

    Args:
        image: image array; the code requires a 3-element ``shape``
            (presumably H, W, C — TODO confirm).
        mask: mask array aligned with ``image``.
        image_name: base key for the returned image dict.
        mask_name: base key for the returned mask dict.

    Returns:
        ``(imagedict, masksdict)``: name -> array mappings.
    """
    x, y = image, mask

    # Half of the samples are passed through untouched.
    if random.uniform(0, 1) > 0.5:
        return {f"{image_name}": x}, {f"{mask_name}": y}

    # Only images whose mask contains foreground (Gloms) are augmented.
    # Previously this case fell through and crashed on unbound result names.
    if np.count_nonzero(y) == 0:
        return {f"{image_name}": x}, {f"{mask_name}": y}

    # Unpacking doubles as a 3-D shape check.
    # NOTE(review): the fallback drops the last entry along the first axis and
    # retries; intent unclear — TODO confirm.
    try:
        _height, _width, _channels = x.shape
    except ValueError:
        x = image[:-1]
        _height, _width, _channels = x.shape

    imagedict = {f"{image_name}": x}
    masksdict = {f"{mask_name}": y}
    for idx, augmentation in enumerate(_glom_augmentations()):
        # Each transform sees the original pair, not the previous output.
        augmented = augmentation(image=x, mask=y)
        suffix = smalllist[idx]
        imagedict[f"{image_name}_{suffix}"] = augmented['image']
        masksdict[f"{mask_name}_{suffix}"] = augmented['mask']

    return imagedict, masksdict


def _glom_augmentations():
    """Return the fixed, ordered transforms that ``transform`` applies.

    Order matters: the i-th transform's outputs are named with
    ``smalllist[i]``.
    """
    return [
        Blur(p=1, blur_limit=3),
        HorizontalFlip(p=1),
        VerticalFlip(p=1),
        Transpose(p=1),
        RandomGamma(p=1),
        OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5),
        GridDistortion(p=1),
        RandomGridShuffle(p=1),
        HueSaturationValue(p=1),
        RGBShift(p=1),
        RandomBrightness(p=1),
        RandomContrast(p=1),
        MedianBlur(p=1, blur_limit=5),
        GaussianBlur(p=1, blur_limit=3),
        GaussNoise(p=1),
        GlassBlur(p=1),
        CLAHE(clip_limit=1.0,
              tile_grid_size=(8, 8),
              always_apply=False,
              p=1),
        ChannelShuffle(p=1),
        ToGray(p=1),
        ToSepia(p=1),
        JpegCompression(p=1),
        ImageCompression(p=1),
        Cutout(p=1),
        FromFloat(p=1),
        RandomBrightnessContrast(p=1),
        RandomSnow(p=1),
        RandomRain(p=1),
        RandomFog(p=1),
        RandomSunFlare(p=1),
        RandomShadow(p=1),
        Lambda(p=1),
        ChannelDropout(p=1),
        ISONoise(p=1),
        Solarize(p=1),
        Equalize(p=1),
        Posterize(p=1),
        Downscale(p=1),
        MultiplicativeNoise(p=1),
        FancyPCA(p=1),
        GridDropout(p=1),
        ColorJitter(p=1),
        ElasticTransform(p=1,
                         alpha=120,
                         sigma=512 * 0.05,
                         alpha_affine=512 * 0.03),
        CropNonEmptyMaskIfExists(p=1, height=22, width=32),
        IAAAffine(p=1),
        IAAFliplr(p=1),
        IAAFlipud(p=1),
        IAAPerspective(p=1),
        IAAPiecewiseAffine(p=1),
        LongestMaxSize(p=1),
        NoOp(p=1),
        RandomScale(p=1),
        Rotate(p=1),
        ShiftScaleRotate(p=1),
        SmallestMaxSize(p=1),
    ]
Example #8
0
def compose_augmentations(img_height,
                          img_width,
                          flip_p=0.5,
                          translate_p=0.5,
                          distort_p=0.5,
                          color_p=0.5,
                          overlays_p=0.15,
                          blur_p=0.25,
                          noise_p=0.25):
    """Build the training augmentation pipeline for a given output size.

    Stages, in order:

    * resize to ``(img_height, img_width)``, skipped when ``img_height`` is 1024
      (presumably the native size — confirm);
    * one flip / transpose / 90-degree rotation (``flip_p``);
    * one of rotate, random sized crop, or affine-vs-perspective
      (``translate_p``);
    * one elastic / grid / optical / piecewise-affine distortion
      (``distort_p``);
    * a colour group (``color_p``) or an overlay group (``overlays_p``),
      chosen by ``OneOrOther``;
    * a blur group (``blur_p``) or a noise group (``noise_p``), chosen by
      ``OneOrOther``;
    * ImageNet normalisation and tensor conversion.

    Raises:
        NotImplementedError: if ``img_height`` is not 1024, 512 or 256.
    """
    resize_p = 0 if img_height == 1024 else 1

    # Crop-height range used by RandomSizedCrop for each supported height.
    crop_height_ranges = {
        1024: (896, 960),
        512: (448, 480),
        256: (224, 240),
    }
    if img_height not in crop_height_ranges:
        raise NotImplementedError
    min_max_height = crop_height_ranges[img_height]

    resize = Resize(p=resize_p, height=img_height, width=img_width)
    flips = OneOf(
        [
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            Transpose(p=0.5),
            RandomRotate90(p=0.5),
        ],
        p=flip_p,
    )
    geometry = OneOf(
        [
            Rotate(p=0.25, limit=10),
            RandomSizedCrop(p=0.5,
                            min_max_height=min_max_height,
                            height=img_height,
                            width=img_width),
            OneOrOther(IAAAffine(p=0.1, translate_percent=0.05),
                       IAAPerspective(p=0.1)),
        ],
        p=translate_p,
    )
    distortion = OneOf(
        [
            ElasticTransform(p=0.5,
                             alpha=10,
                             sigma=img_height * 0.05,
                             alpha_affine=img_height * 0.03,
                             approximate=True),
            GridDistortion(p=0.5),
            OpticalDistortion(p=0.5),
            IAAPiecewiseAffine(p=0.25, scale=(0.01, 0.03)),
        ],
        p=distort_p,
    )
    color_or_overlay = OneOrOther(
        OneOf(
            [
                CLAHE(p=0.5),
                RandomGamma(p=0.5),
                RandomContrast(p=0.5),
                RandomBrightness(p=0.5),
                RandomBrightnessContrast(p=0.5),
            ],
            p=color_p,
        ),
        OneOf(
            [
                IAAEmboss(p=0.1),
                IAASharpen(p=0.1),
                IAASuperpixels(p=0),
            ],
            p=overlays_p,
        ),
    )
    blur_or_noise = OneOrOther(
        OneOf(
            [
                Blur(p=0.2),
                MedianBlur(p=0.1),
                MotionBlur(p=0.1),
                GaussianBlur(p=0.1),
            ],
            p=blur_p,
        ),
        OneOf(
            [
                GaussNoise(p=0.2),
                IAAAdditiveGaussianNoise(p=0.1),
            ],
            p=noise_p,
        ),
    )
    return Compose([
        resize,
        flips,
        geometry,
        distortion,
        color_or_overlay,
        blur_or_noise,
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensor(sigmoid=False),
    ])
Example #9
0
def aug_test():
    """Render each augmentation once on a sample image, for visual inspection.

    For every entry of ``augs``:

    1. Draw a 2-px white rectangle outline (``p1``-``p2``) on an empty mask.
    2. Apply the transform jointly to a copy of the image and that mask.
    3. Recover the outline's bounding box from the transformed mask.
    4. Draw the box in green, label the image with the augmentation key, and
       save it under ``../DATA/aug_test/aug/``.
    """

    def get_bb_points(msk):
        """Return ``((x0, y0), (x1, y1))`` bounding the pixels of ``msk`` above 200.

        ``x0``/``x1`` are the first/last columns, and ``y0``/``y1`` the
        first/last rows, containing such a pixel. If none exist the box falls
        back to ``(0, 0)`` and ``(width, height)``.
        """
        height, width = msk.shape
        hot_cols = np.flatnonzero(msk.max(axis=0) > 200)
        hot_rows = np.flatnonzero(msk.max(axis=1) > 200)
        if hot_cols.size:
            x0, x1 = int(hot_cols[0]), int(hot_cols[-1])
        else:
            x0, x1 = 0, width
        if hot_rows.size:
            y0, y1 = int(hot_rows[0]), int(hot_rows[-1])
        else:
            y0, y1 = 0, height
        return (x0, y0), (x1, y1)

    image_name = '7aea0b3e2.jpg'
    p1, p2 = (12, 84), (391, 248)
    img = imread(f'../DATA/aug_test/src/{image_name}')

    # Elastic-transform parameters are derived from a reference size of 300 px.
    ref_size = 300
    alpha = ref_size * 2
    sigma = ref_size * 0.08
    alpha_affine = ref_size * 0.08

    augs = {
        '1_IAAAdditiveGaussianNoise':
        IAAAdditiveGaussianNoise(scale=(0.01 * 255, 0.05 * 255), p=1.0),
        '1_GaussNoise':
        GaussNoise(var_limit=(20, 120), p=1.0),
        '1_RandomGamma':
        RandomGamma(gamma_limit=(80, 120), p=1.0),
        '2_RandomBrightnessContrast':
        RandomBrightnessContrast(p=1.0),
        '2_MotionBlur':
        MotionBlur(p=1.0),
        '2_MedianBlur':
        MedianBlur(blur_limit=6, p=1.0),
        '2_Blur':
        Blur(blur_limit=9, p=1.0),
        '2_IAASharpen':
        IAASharpen(p=1.0),
        '2_IAAEmboss':
        IAAEmboss(p=1.0),
        '2_IAASuperpixels':
        IAASuperpixels(n_segments=50, p_replace=0.05, p=1.0),
        '3_CLAHE':
        CLAHE(clip_limit=8, p=1.0),
        '3_RGBShift':
        RGBShift(p=1.0),
        '3_ChannelShuffle':
        ChannelShuffle(p=1.0),
        '3_HueSaturationValue':
        HueSaturationValue(p=1.0),
        '3_ToGray':
        ToGray(p=1.0),
        '4_OpticalDistortion':
        OpticalDistortion(border_mode=cv2.BORDER_CONSTANT, p=1.0),
        '4_GridDistortion':
        GridDistortion(border_mode=cv2.BORDER_CONSTANT, p=1.0),
        '4_IAAPiecewiseAffine':
        IAAPiecewiseAffine(nb_rows=4, nb_cols=4, p=1.0),
        '4_IAAPerspective':
        IAAPerspective(p=1.0),
        '4_IAAAffine':
        IAAAffine(mode='constant', p=1.0),
        '4_ElasticTransform':
        ElasticTransform(alpha=alpha,
                         sigma=sigma,
                         alpha_affine=alpha_affine,
                         border_mode=cv2.BORDER_CONSTANT,
                         p=1.0)
    }

    for aug_name, augmentation in augs.items():
        # A fresh mask per augmentation, so no transform sees another's output.
        mask = np.zeros(img.shape[:2], dtype=np.uint8)
        cv2.rectangle(mask, p1, p2, 255, 2)
        augmented = augmentation(image=img.copy(), mask=mask)
        augimg = augmented['image']
        draw_shadow_text(augimg, f'{aug_name}', (5, 15), 0.5, (255, 255, 255), 1)
        top_left, bottom_right = get_bb_points(augmented['mask'])
        cv2.rectangle(augimg, top_left, bottom_right, (0, 255, 0), 2)
        imsave(f'../DATA/aug_test/aug/{aug_name}-{image_name}', augimg)
Example #10
0
def build_databunch(data_dir, img_sz, batch_sz):
    """Build COCO-format train/valid datasets and loaders wrapped in a DataBunch.

    Annotations are the ``*.json`` files in ``<data_dir>/train`` and
    ``<data_dir>/valid``. Label names are read from the first training
    annotation file.

    One sub-policy of ``policy_v3`` is drawn at random each time this function
    is called and is used for the whole training set, followed by resizing to
    ``img_sz`` x ``img_sz`` and tensor conversion. The validation set only gets
    the resize and tensor conversion. Both pipelines use COCO-format boxes and
    drop boxes with less than 20% visibility.

    Args:
        data_dir: root directory containing ``train`` and ``valid``.
        img_sz: side length passed to ``Resize``.
        batch_sz: batch size for both loaders.

    Returns:
        ``DataBunch(train_ds, train_dl, valid_ds, valid_dl, label_names)``.
    """
    # TODO This is to avoid freezing in the middle of the first epoch. Would be nice
    # to fix this.
    num_workers = 0

    train_dir = join(data_dir, 'train')
    train_anns = glob.glob(join(train_dir, '*.json'))
    valid_dir = join(data_dir, 'valid')
    valid_anns = glob.glob(join(valid_dir, '*.json'))

    label_names = get_label_names(train_anns[0])

    # Augmentations: AutoAugment-style sub-policies of two transforms each.
    # NOTE(review): RandomCropNearBBox expects a `cropping_bbox` target; confirm
    # that CocoDataset supplies it when this sub-policy is selected.
    policy_v3 = [
        [
            Posterize(p=0.8),
            IAAAffine(translate_px=(10, 20), p=1.0),
        ],
        [
            RandomCropNearBBox(p=0.2),
            IAASharpen(p=0.5),
        ],
        [
            Rotate(p=0.6),
            Rotate(p=0.8),
        ],
        [
            Equalize(p=0.8),
            RandomContrast(p=0.2),
        ],
        [
            Solarize(p=0.2),
            IAAAffine(translate_px=(10, 20), p=0.2),
        ],
        [
            IAASharpen(p=0.0),
            ToGray(p=0.4),
        ],
        [
            Equalize(p=1.0),
            IAAAffine(translate_px=(10, 20), p=1.0),
        ],
        [
            Posterize(p=0.8),
            Rotate(p=0.0),
        ],
        [
            RandomContrast(p=0.6),
            Rotate(p=1.0),
        ],
        [
            Equalize(p=0.0),
            Cutout(p=0.8),
        ],
        [
            RandomBrightness(p=1.0),
            IAAAffine(translate_px=(10, 20), p=1.0),
        ],
        [
            RandomContrast(p=0.0),
            IAAAffine(shear=60.0, p=0.8),
        ],
        [
            RandomContrast(p=0.8),
            RandomContrast(p=0.2),
        ],
        [
            Rotate(p=1.0),
            Cutout(p=1.0),
        ],
        [
            Solarize(p=0.8),
            Equalize(p=0.8),
        ],
    ]
    # randint(n) yields a scalar index directly. The old
    # int(np.random.randint(0, 15, 1)) converted a size-1 array to a scalar
    # (deprecated since NumPy 1.25) and hard-coded the number of sub-policies.
    selected_subpolicy = np.random.randint(len(policy_v3))
    standard_transforms = [Resize(img_sz, img_sz), ToTensor()]
    # Concatenate into a new list rather than extending the sub-policy in place.
    aug_transforms = policy_v3[selected_subpolicy] + standard_transforms

    bbox_params = BboxParams(format='coco',
                             min_area=0.,
                             min_visibility=0.2,
                             label_fields=['labels'])
    aug_transforms = Compose(aug_transforms, bbox_params=bbox_params)
    standard_transforms = Compose(standard_transforms, bbox_params=bbox_params)

    train_ds = CocoDataset(train_dir, train_anns, transforms=aug_transforms)
    valid_ds = CocoDataset(valid_dir,
                           valid_anns,
                           transforms=standard_transforms)
    train_ds.label_names = label_names
    valid_ds.label_names = label_names

    train_dl = DataLoader(train_ds,
                          collate_fn=collate_fn,
                          shuffle=True,
                          batch_size=batch_sz,
                          num_workers=num_workers,
                          pin_memory=True)
    valid_dl = DataLoader(valid_ds,
                          collate_fn=collate_fn,
                          batch_size=batch_sz,
                          num_workers=num_workers,
                          pin_memory=True)
    return DataBunch(train_ds, train_dl, valid_ds, valid_dl, label_names)