Example #1
import cv2
import numpy as np

from tensorpack.dataflow import imgaug


def fbresnet_augmentor(isTrain):
    """
    Augmentor used in fb.resnet.torch, for BGR images in range [0,255].
    """
    if isTrain:
        augmentors = [
            imgaug.GoogleNetRandomCropAndResize(interp=cv2.INTER_CUBIC),
            # It's OK to remove the following augs if your CPU is not fast enough.
            # Removing brightness/contrast/saturation does not have a significant effect on accuracy.
            # Removing lighting leads to a tiny drop in accuracy.
            imgaug.RandomOrderAug([
                imgaug.BrightnessScale((0.6, 1.4), clip=False),
                imgaug.Contrast((0.6, 1.4), clip=False),
                imgaug.Saturation(0.4, rgb=False),
                # rgb-bgr conversion for the constants copied from fb.resnet.torch
                imgaug.Lighting(
                    0.1,
                    eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0,
                    eigvec=np.array([[-0.5675, 0.7192, 0.4009],
                                     [-0.5808, -0.0045, -0.8140],
                                     [-0.5836, -0.6948, 0.4203]],
                                    dtype='float32')[::-1, ::-1])
            ]),
            imgaug.Flip(horiz=True),
        ]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
            imgaug.CenterCrop((224, 224)),
        ]
    return augmentors
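
A minimal single-image usage sketch (not part of the original example; it assumes the imports above and tensorpack's imgaug.AugmentorList wrapper): the returned list can be composed and applied directly to one BGR image in the [0, 255] range.

augmentor = imgaug.AugmentorList(fbresnet_augmentor(isTrain=True))
image = (np.random.rand(256, 256, 3) * 255).astype('uint8')  # placeholder BGR image
augmented = augmentor.augment(image)  # 224x224 crop with color jitter and random flip
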
def get_tp_augmentor(isTrain):
    """
    Augmentor used in fb.resnet.torch, for BGR images in range [0,255].
    """
    interpolation = cv2.INTER_CUBIC
    # Linear interpolation seems to give more stable performance,
    # but we keep cubic for compatibility with old models.
    if isTrain:
        augmentors = [
            imgaug.GoogleNetRandomCropAndResize(interp=interpolation),
            # It's OK to remove the following augs if your CPU is not fast enough.
            # Removing brightness/contrast/saturation does not have a significant effect on accuracy.
            # Removing lighting leads to a tiny drop in accuracy.
            # imgaug.RandomOrderAug(
            #     [imgaug.BrightnessScale((0.6, 1.4), clip=False),
            #      imgaug.Contrast((0.6, 1.4), rgb=False, clip=False),
            #      imgaug.Saturation(0.4, rgb=False),
            #      # rgb-bgr conversion for the constants copied from fb.resnet.torch
            #      imgaug.Lighting(0.1,
            #                      eigval=np.asarray(
            #                          [0.2175, 0.0188, 0.0045][::-1]) * 255.0,
            #                      eigvec=np.array(
            #                          [[-0.5675, 0.7192, 0.4009],
            #                           [-0.5808, -0.0045, -0.8140],
            #                           [-0.5836, -0.6948, 0.4203]],
            #                          dtype='float32')[::-1, ::-1]
            #                      )]),
            imgaug.Flip(horiz=True),
        ]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256, interp=interpolation),
            imgaug.CenterCrop((224, 224)),
        ]
    return augmentors
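
A sketch of how these augmentor lists are typically consumed (not part of the original example; the ILSVRC12 reader, directory layout, and function name are assumptions): the augmentors are applied to the image component of a DataFlow and the result is batched.

from tensorpack.dataflow import AugmentImageComponent, BatchData, dataset

def get_imagenet_dataflow(datadir, name, batch_size):
    # name is 'train' or 'val'; datadir points at an ILSVRC12 directory (assumption).
    isTrain = (name == 'train')
    ds = dataset.ILSVRC12(datadir, name, shuffle=isTrain)
    ds = AugmentImageComponent(ds, get_tp_augmentor(isTrain), copy=False)  # augment the image component
    ds = BatchData(ds, batch_size, remainder=not isTrain)
    return ds
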
Example #3
    def train_dataloader(self):
        # REQUIRED by PyTorch Lightning.
        # Assumes module-level imports of cv2, numpy as np, torch, skimage.segmentation,
        # albumentations as AB, and the tensorpack dataflow utilities used below
        # (imgaug, AugmentImageComponent, AugmentImageComponents, MapData, BatchData,
        # PrintData, MultiProcessRunner). PretrainedSNEMI and DEBUG are project-specific.
        ds_train = PretrainedSNEMI(folder=self.hparams.data_path,
                                   train_or_valid='train',
                                   size=40000,
                                   # resize=int(self.hparams.shape),
                                   debug=DEBUG)

        ag_train = [
            imgaug.RotationAndCropValid(max_deg=180, interp=cv2.INTER_NEAREST),
            imgaug.Flip(horiz=True, vert=False),
            imgaug.Flip(horiz=False, vert=True),
            imgaug.Transpose(),
            imgaug.GoogleNetRandomCropAndResize(crop_area_fraction=(0.2, 0.5),
                                                aspect_ratio_range=(0.5, 2.0),
                                                interp=cv2.INTER_NEAREST,
                                                target_shape=self.hparams.shape),
            imgaug.ToFloat32(),
        ]

        ds_train.reset_state()
        # ds_train = AugmentImageComponent(ds_train, ag_train, 0)
        # Photometric augmentation on the image only (component 0).
        ds_train = AugmentImageComponent(
            ds_train, [imgaug.Albumentations(AB.RandomBrightnessContrast())], 0)
        # Geometric augmentations applied jointly to image and label (components 0 and 1).
        ds_train = AugmentImageComponents(ds_train, ag_train, (0, 1))
        # Turn the label map into a 0/255 foreground mask and zero out inner instance
        # boundaries so that touching instances stay separated.
        ds_train = MapData(ds_train, lambda dp: [
            dp[0],
            255.0 * (dp[1] > 0) * (1 - skimage.segmentation.find_boundaries(dp[1], mode='inner'))])
        ds_train = BatchData(ds_train, self.hparams.batch)
        ds_train = PrintData(ds_train)
        # Add a channel dimension (N, 1, H, W) and convert the numpy batches to torch tensors.
        ds_train = MapData(ds_train, lambda dp: [
            torch.tensor(dp[0][:, np.newaxis, :, :]),
            torch.tensor(dp[1][:, np.newaxis, :, :]).float()])
        ds_train = MultiProcessRunner(ds_train, num_proc=32, num_prefetch=8)
        return ds_train
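
A small, self-contained illustration (not part of the original example) of the label-to-mask MapData step above: pixels with a nonzero label become foreground, and inner instance boundaries found by skimage.segmentation.find_boundaries are zeroed out, so touching instances end up separated in the 0/255 mask.

import numpy as np
import skimage.segmentation

label = np.array([[1, 1, 2, 2],
                  [1, 1, 2, 2],
                  [0, 0, 3, 3],
                  [0, 0, 3, 3]])
mask = 255.0 * (label > 0) * (1 - skimage.segmentation.find_boundaries(label, mode='inner'))
print(mask)  # pixels on instance boundaries are 0, remaining foreground is 255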