def fbresnet_augmentor(isTrain):
    """
    Augmentor used in fb.resnet.torch, for BGR images in range [0,255].
    """
    if isTrain:
        augmentors = [
            GoogleNetResize(),
            imgaug.RandomOrderAug([
                JohnAug(),
                imgaug.BrightnessScale((0.6, 1.4), clip=False),
                imgaug.Contrast((0.6, 1.4), clip=False),
                imgaug.Saturation(0.4, rgb=False),
                # rgb-bgr conversion for the constants copied from fb.resnet.torch
                imgaug.Lighting(
                    0.1,
                    eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0,
                    eigvec=np.array([[-0.5675, 0.7192, 0.4009],
                                     [-0.5808, -0.0045, -0.8140],
                                     [-0.5836, -0.6948, 0.4203]],
                                    dtype='float32')[::-1, ::-1])
            ]),
            imgaug.Flip(horiz=True),
        ]
    else:
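        # round x up to the next power of two, e.g. 224 -> 256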
        round2pow2 = lambda x: 2**(x - 1).bit_length()

        augmentors = [
            imgaug.ResizeShortestEdge(round2pow2(IMAGE_SIZE), cv2.INTER_CUBIC),
            imgaug.CenterCrop((IMAGE_SIZE, IMAGE_SIZE)),
        ]
    return augmentors
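
A minimal usage sketch of how such an augmentor list is typically consumed with tensorpack's dataflow API (build_train_dataflow, train_dir and batch_size are illustrative names, and the helper augmentors above are assumed to be defined):

def build_train_dataflow(train_dir, batch_size=32):
    from tensorpack.dataflow import AugmentImageComponent, BatchData
    from tensorpack.dataflow.dataset import ILSVRC12
    ds = ILSVRC12(train_dir, 'train', shuffle=True)  # yields [image, label] datapoints
    augs = fbresnet_augmentor(isTrain=True)
    ds = AugmentImageComponent(ds, augs, index=0)    # augment the image component only
    ds = BatchData(ds, batch_size)
    return ds
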
Example #2
def fbresnet_augmentor(isTrain):
    """
    Augmentor used in fb.resnet.torch, for BGR images in range [0,255].  # image augmentor for ResNet
    """
    if isTrain:  # if this is training data
        augmentors = [
            GoogleNetResize(),  # crop_area_fraction and related parameters use their defaults
            imgaug.RandomOrderAug(  # Remove these augs if your CPU is not fast enough; imgaug is an image augmentation library
                [
                    imgaug.BrightnessScale((0.6, 1.4), clip=False),
                    imgaug.Contrast((0.6, 1.4), clip=False),
                    imgaug.Saturation(0.4, rgb=False),
                    # rgb-bgr conversion for the constants copied from fb.resnet.torch
                    imgaug.Lighting(
                        0.1,
                        eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) *
                        255.0,
                        eigvec=np.array([[-0.5675, 0.7192, 0.4009],
                                         [-0.5808, -0.0045, -0.8140],
                                         [-0.5836, -0.6948, 0.4203]],
                                        dtype='float32')[::-1, ::-1])
                ]),
            imgaug.Flip(horiz=True),
        ]
    else:  # validation / test data
        augmentors = [
            imgaug.ResizeShortestEdge(
                256, cv2.INTER_CUBIC),  # resize the shortest edge to 256 while keeping the aspect ratio
            imgaug.CenterCrop((224, 224)),  # crop the center of the image
        ]
    return augmentors
Example #3
def get_dataflow(path, is_train):
    ds = CocoPoseLMDB(path, is_train)  # read data from lmdb
    if is_train:
        ds = MapDataComponent(ds, pose_random_scale)
        ds = MapDataComponent(ds, pose_rotation)
        ds = MapDataComponent(ds, pose_flip)
        ds = MapDataComponent(ds, pose_resize_shortestedge_random)
        ds = MapDataComponent(ds, pose_crop_random)
        ds = MapData(ds, pose_to_img)
        augs = [
            imgaug.RandomApplyAug(
                imgaug.RandomChooseAug([
                    imgaug.BrightnessScale((0.6, 1.4), clip=False),
                    imgaug.Contrast((0.7, 1.4), clip=False),
                    imgaug.GaussianBlur(max_size=3)
                ]), 0.7),
        ]
        ds = AugmentImageComponent(ds, augs)
    else:
        ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)

    ds = PrefetchData(ds, 1000, multiprocessing.cpu_count())

    return ds
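
A minimal sketch of consuming the returned dataflow (the lmdb path follows the sample_augmentations example below; the three component names are illustrative):

ds = get_dataflow('/data/public/rw/coco-pose-estimation-lmdb/', is_train=True)
ds.reset_state()  # must be called once before iterating, especially after PrefetchData
for img, heatmap, vectmap in ds.get_data():
    pass  # feed the three components to the training step
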
Example #4
def fbresnet_augmentor(isTrain, input_size=224):
    """
    Augmentor used in fb.resnet.torch, for BGR images in range [0,255].
    """
    if isTrain:
        augmentors = [
            imgaug.GoogleNetRandomCropAndResize(interp=cv2.INTER_CUBIC),
            # It's OK to remove the following augs if your CPU is not fast enough.
            # Removing brightness/contrast/saturation does not have a significant effect on accuracy.
            # Removing lighting leads to a tiny drop in accuracy.
            imgaug.RandomOrderAug(
                [imgaug.BrightnessScale((0.6, 1.4), clip=False),
                 imgaug.Contrast((0.6, 1.4), clip=False),
                 imgaug.Saturation(0.4, rgb=False),
                 # rgb-bgr conversion for the constants copied from fb.resnet.torch
                 imgaug.Lighting(0.1,
                                 eigval=np.asarray(
                                     [0.2175, 0.0188, 0.0045][::-1]) * 255.0,
                                 eigvec=np.array(
                                     [[-0.5675, 0.7192, 0.4009],
                                      [-0.5808, -0.0045, -0.8140],
                                      [-0.5836, -0.6948, 0.4203]],
                                     dtype='float32')[::-1, ::-1]
                                 )]),
            imgaug.Flip(horiz=True),
        ]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(input_size + 32, cv2.INTER_CUBIC),
            imgaug.CenterCrop((input_size, input_size)),
        ]
    return augmentors
Example #5
def sample_augmentations():
    ds = CocoPoseLMDB('/data/public/rw/coco-pose-estimation-lmdb/',
                      is_train=False,
                      only_idx=0)
    ds = MapDataComponent(ds, pose_random_scale)
    ds = MapDataComponent(ds, pose_rotation)
    ds = MapDataComponent(ds, pose_flip)
    ds = MapDataComponent(ds, pose_resize_shortestedge_random)
    ds = MapDataComponent(ds, pose_crop_random)
    ds = MapData(ds, pose_to_img)
    augs = [
        imgaug.RandomApplyAug(
            imgaug.RandomChooseAug([
                imgaug.GaussianBlur(3),
                imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01),
                imgaug.RandomOrderAug([
                    imgaug.BrightnessScale((0.8, 1.2), clip=False),
                    imgaug.Contrast((0.8, 1.2), clip=False),
                    # imgaug.Saturation(0.4, rgb=True),
                ]),
            ]),
            0.7),
    ]
    ds = AugmentImageComponent(ds, augs)

    ds.reset_state()
    for l1, l2, l3 in ds.get_data():
        CocoPoseLMDB.display_image(l1, l2, l3)
Example #6
def get_dataflow(path, is_train):
    ds = CocoPoseLMDB(path, is_train)  # read data from lmdb
    if is_train:
        ds = MapDataComponent(ds, pose_random_scale)
        ds = MapDataComponent(ds, pose_rotation)
        ds = MapDataComponent(ds, pose_flip)
        ds = MapDataComponent(ds, pose_resize_shortestedge_random)
        ds = MapDataComponent(ds, pose_crop_random)
        ds = MapData(ds, pose_to_img)
        augs = [
            imgaug.RandomApplyAug(
                imgaug.RandomChooseAug([
                    imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01),
                    imgaug.RandomOrderAug([
                        imgaug.BrightnessScale((0.8, 1.2), clip=False),
                        imgaug.Contrast((0.8, 1.2), clip=False),
                        # imgaug.Saturation(0.4, rgb=True),
                    ]),
                ]),
                0.7),
        ]
        ds = AugmentImageComponent(ds, augs)
    else:
        ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)

    ds = PrefetchData(ds, 1000, multiprocessing.cpu_count())

    return ds
Example #7
def get_augmentations(is_train):
    if is_train:
        augmentors = [
            GoogleNetResize(crop_area_fraction=0.76,
                            target_shape=224),  # TODO : 76% or 49%?
            imgaug.RandomOrderAug([
                imgaug.BrightnessScale((0.6, 1.4), clip=True),
                imgaug.Contrast((0.6, 1.4), clip=True),
                imgaug.Saturation(0.4, rgb=False),
                # rgb-bgr conversion for the constants copied from fb.resnet.torch
                imgaug.Lighting(
                    0.1,
                    eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0,
                    eigvec=np.array([[-0.5675, 0.7192, 0.4009],
                                     [-0.5808, -0.0045, -0.8140],
                                     [-0.5836, -0.6948, 0.4203]],
                                    dtype='float32')[::-1, ::-1])
            ]),
            imgaug.Flip(horiz=True),
        ]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
            imgaug.CenterCrop((224, 224)),
        ]
    return augmentors
Example #8
def fbresnet_augmentor(isTrain):
    """
    Augmentor used in fb.resnet.torch, for BGR images in range [0,255].
    """
    if isTrain:
        """
        Sec 5.1:
        We use scale and aspect ratio data augmentation [35] as
        in [12]. The network input image is a 224×224 pixel random
        crop from an augmented image or its horizontal flip.
        """
        augmentors = [
            GoogleNetResize(),
            imgaug.RandomOrderAug([
                imgaug.BrightnessScale((0.6, 1.4), clip=False),
                imgaug.Contrast((0.6, 1.4), clip=False),
                imgaug.Saturation(0.4, rgb=False),
                # rgb-bgr conversion for the constants copied from fb.resnet.torch
                imgaug.Lighting(
                    0.1,
                    eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0,
                    eigvec=np.array([[-0.5675, 0.7192, 0.4009],
                                     [-0.5808, -0.0045, -0.8140],
                                     [-0.5836, -0.6948, 0.4203]],
                                    dtype='float32')[::-1, ::-1])
            ]),
            imgaug.Flip(horiz=True),
        ]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
            imgaug.CenterCrop((224, 224)),
        ]
    return augmentors
Example #9
def get_dataflow(is_train):
    ds = CocoPoseLMDB('/data/public/rw/coco-pose-estimation-lmdb/', is_train)
    if is_train:
        ds = MapDataComponent(ds, pose_rotation)
        ds = MapDataComponent(ds, pose_flip)
        ds = MapDataComponent(ds, pose_resize_shortestedge_random)
        ds = MapDataComponent(ds, pose_crop_random)
        ds = MapData(ds, pose_to_img)
        augs = [
            imgaug.RandomApplyAug(imgaug.RandomChooseAug([
                imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01),
                imgaug.RandomOrderAug([
                    imgaug.BrightnessScale((0.8, 1.2), clip=False),
                    imgaug.Contrast((0.8, 1.2), clip=False),
                    # imgaug.Saturation(0.4, rgb=True),
                ]),
            ]), 0.7),
        ]
        ds = AugmentImageComponent(ds, augs)
    else:
        ds = MapDataComponent(ds, pose_resize_shortestedge_fixed)
        ds = MapDataComponent(ds, pose_crop_center)
        ds = MapData(ds, pose_to_img)

    return ds
Example #10
def fbresnet_augmentor(isTrain):
    """
    Augmentor used in fb.resnet.torch, for BGR images.
    """
    if isTrain:
        augmentors = [
            GoogleNetResize(),
            imgaug.RandomOrderAug([
                imgaug.Brightness(30, clip=False),
                imgaug.Contrast((0.8, 1.2), clip=False),
                imgaug.Saturation(0.4, rgb=False),
                # rgb-bgr conversion
                imgaug.Lighting(0.1,
                                eigval=[0.2175, 0.0188, 0.0045][::-1],
                                eigvec=np.array([[-0.5675, 0.7192, 0.4009],
                                                 [-0.5808, -0.0045, -0.8140],
                                                 [-0.5836, -0.6948, 0.4203]],
                                                dtype='float32')[::-1, ::-1])
            ]),
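            # clip pixel values back to [0, 255] after the unclipped color augmentations above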
            imgaug.Clip(),
            imgaug.Flip(horiz=True),
        ]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
            imgaug.CenterCrop((224, 224)),
        ]
    return augmentors
Example #11
def fbresnet_augmentor(isTrain):
    """
    Augmentor used in fb.resnet.torch, for BGR images in range [0,255].
    """
    if isTrain:
        augmentors = [
            GoogleNetResize(),
            imgaug.RandomOrderAug(      # Remove these augs if your CPU is not fast enough
                [imgaug.BrightnessScale((0.6, 1.4), clip=False),
                 imgaug.Contrast((0.6, 1.4), clip=False),
                 imgaug.Saturation(0.4, rgb=False),
                 # rgb-bgr conversion for the constants copied from fb.resnet.torch
                 imgaug.Lighting(0.1,
                                 eigval=np.asarray(
                                     [0.2175, 0.0188, 0.0045][::-1]) * 255.0,
                                 eigvec=np.array(
                                     [[-0.5675, 0.7192, 0.4009],
                                      [-0.5808, -0.0045, -0.8140],
                                      [-0.5836, -0.6948, 0.4203]],
                                     dtype='float32')[::-1, ::-1]
                                 )]),
            imgaug.Flip(horiz=True),
        ]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
            imgaug.CenterCrop((224, 224)),
        ]
    return augmentors
Example #12
def get_data(image_ids, batch_size=1, is_train=False, shape=510):

    ds = ImageAndMaskFromFile(image_ids, channel=3, shuffle=True)

    if is_train:

        number_of_prefetch = 8

        augs_with_label = [
            imgaug.RandomCrop(shape),
            imgaug.Flip(horiz=True, prob=0.5),
            imgaug.Flip(vert=True, prob=0.5)
        ]

        augs_no_label = [
            imgaug.RandomOrderAug(
                [imgaug.Brightness(delta=20),
                 imgaug.Contrast((0.6, 1.4))])
        ]

        # augmentors = [
        #     GoogleNetResize(),
        #     # It's OK to remove the following augs if your CPU is not fast enough.
        #     # Removing brightness/contrast/saturation does not have a significant effect on accuracy.
        #     # Removing lighting leads to a tiny drop in accuracy.
        #     imgaug.RandomOrderAug(
        #         [imgaug.BrightnessScale((0.6, 1.4), clip=False),
        #          imgaug.Contrast((0.6, 1.4), clip=False),
        #          imgaug.Saturation(0.4, rgb=False),
        #          # rgb-bgr conversion for the constants copied from fb.resnet.torch
        #          imgaug.Lighting(0.1,
        #                          eigval=np.asarray(
        #                              [0.2175, 0.0188, 0.0045][::-1]) * 255.0,
        #                          eigvec=np.array(
        #                              [[-0.5675, 0.7192, 0.4009],
        #                               [-0.5808, -0.0045, -0.8140],
        #                               [-0.5836, -0.6948, 0.4203]],
        #                              dtype='float32')[::-1, ::-1]
        #                          )]),
        #     imgaug.Flip(horiz=True),
        # ]

    else:

        number_of_prefetch = 1

        augs_with_label = [imgaug.CenterCrop(shape)]
        augs_no_label = []

    ds = AugmentImageComponents(ds, augs_with_label, (0, 1))
    ds = AugmentImageComponents(ds, augs_no_label, [0])

    ds = BatchData(ds, batch_size)
    ds = PrefetchData(ds, 30, 1)  # note: number_of_prefetch computed above is not used here

    return ds
Example #13
    def get_train_augmentors(self, input_shape, output_shape, view=False):
        print(input_shape, output_shape)
        shape_augs = [
            imgaug.Affine(
                shear=5,  # in degree
                scale=(0.8, 1.2),
                rotate_max_deg=179,
                translate_frac=(0.01, 0.01),
                interp=cv2.INTER_NEAREST,
                border=cv2.BORDER_CONSTANT),
            imgaug.Flip(vert=True),
            imgaug.Flip(horiz=True),
            imgaug.CenterCrop(input_shape),
        ]

        input_augs = [
            imgaug.RandomApplyAug(
                imgaug.RandomChooseAug([
                    GaussianBlur(),
                    MedianBlur(),
                    imgaug.GaussianNoise(),
                ]), 0.5),
            # standard color augmentation
            imgaug.RandomOrderAug([
                imgaug.Hue((-8, 8), rgb=True),
                imgaug.Saturation(0.2, rgb=True),
                imgaug.Brightness(26, clip=True),
                imgaug.Contrast((0.75, 1.25), clip=True),
            ]),
            imgaug.ToUint8(),
        ]

        label_augs = []
        if self.model_type == 'unet' or self.model_type == 'micronet':
            label_augs = [GenInstanceUnetMap(crop_shape=output_shape)]
        if self.model_type == 'dcan':
            label_augs = [GenInstanceContourMap(crop_shape=output_shape)]
        if self.model_type == 'dist':
            label_augs = [
                GenInstanceDistance(crop_shape=output_shape, inst_norm=False)
            ]
        if self.model_type == 'np_hv':
            label_augs = [GenInstanceHV(crop_shape=output_shape)]
        if self.model_type == 'np_dist':
            label_augs = [
                GenInstanceDistance(crop_shape=output_shape, inst_norm=True)
            ]

        if not self.type_classification:
            label_augs.append(BinarizeLabel())

        if not view:
            label_augs.append(imgaug.CenterCrop(output_shape))

        return shape_augs, input_augs, label_augs
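
One plausible wiring of the three returned augmentor lists, using the tensorpack APIs seen elsewhere on this page (the loader instance, shapes, and the [image, label] datapoint layout are illustrative assumptions):

shape_augs, input_augs, label_augs = loader.get_train_augmentors(
    input_shape=(270, 270), output_shape=(80, 80))
ds = AugmentImageComponents(ds, shape_augs, (0, 1))  # geometric augs applied to image and label together
ds = AugmentImageComponent(ds, input_augs, 0)        # photometric augs applied to the image only
ds = AugmentImageComponent(ds, label_augs, 1)        # target-map generation applied to the label only
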
Example #14
    def get_train_augmentors(self, view=False):
        shape_augs = [
            imgaug.Affine(
                shear=5,  # in degree
                scale=(0.8, 1.2),
                rotate_max_deg=179,
                translate_frac=(0.01, 0.01),
                interp=cv2.INTER_NEAREST,
                border=cv2.BORDER_CONSTANT),
            imgaug.Flip(vert=True),
            imgaug.Flip(horiz=True),
            imgaug.CenterCrop(self.train_input_shape),
        ]

        input_augs = [
            imgaug.RandomApplyAug(
                imgaug.RandomChooseAug([
                    GaussianBlur(),
                    MedianBlur(),
                    imgaug.GaussianNoise(),
                ]), 0.5),
            # standard color augmentation
            imgaug.RandomOrderAug([
                imgaug.Hue((-8, 8), rgb=True),
                imgaug.Saturation(0.2, rgb=True),
                imgaug.Brightness(26, clip=True),
                imgaug.Contrast((0.75, 1.25), clip=True),
            ]),
            imgaug.ToUint8(),
        ]

        # default to 'xy'
        if self.model_mode != 'np+dst':
            label_augs = [GenInstanceXY(self.train_mask_shape)]
        else:
            label_augs = [GenInstanceDistance(self.train_mask_shape)]
        label_augs.append(BinarizeLabel())

        if not view:
            label_augs.append(imgaug.CenterCrop(self.train_mask_shape))

        return shape_augs, input_augs, label_augs
Example #15
def get_ilsvrc_data_alexnet(is_train, image_size, batchsize, directory):
    if is_train:
        if not directory.startswith('/'):
            ds = ILSVRCTTenthTrain(directory)
        else:
            ds = ILSVRC12(directory, 'train')
        augs = [
            imgaug.RandomApplyAug(imgaug.RandomResize((0.9, 1.2), (0.9, 1.2)),
                                  0.7),
            imgaug.RandomApplyAug(imgaug.RotationAndCropValid(15), 0.7),
            imgaug.RandomApplyAug(
                imgaug.RandomChooseAug([
                    imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01),
                    imgaug.RandomOrderAug([
                        imgaug.BrightnessScale((0.8, 1.2), clip=False),
                        imgaug.Contrast((0.8, 1.2), clip=False),
                        # imgaug.Saturation(0.4, rgb=True),
                    ]),
                ]),
                0.7),
            imgaug.Flip(horiz=True),
            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
            imgaug.RandomCrop((224, 224)),
        ]
        ds = AugmentImageComponent(ds, augs)
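        # prefetch augmented samples in parallel, batch them, then prefetch whole batches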
        ds = PrefetchData(ds, 1000, multiprocessing.cpu_count())
        ds = BatchData(ds, batchsize)
        ds = PrefetchData(ds, 10, 4)
    else:
        if not directory.startswith('/'):
            ds = ILSVRCTenthValid(directory)
        else:
            ds = ILSVRC12(directory, 'val')
        ds = AugmentImageComponent(ds, [
            imgaug.ResizeShortestEdge(224, cv2.INTER_CUBIC),
            imgaug.CenterCrop((224, 224)),
        ])
        ds = PrefetchData(ds, 100, multiprocessing.cpu_count())
        ds = BatchData(ds, batchsize)

    return ds
Example #16
def fbresnet_augmentor(isTrain):
    """
    Augmentor used in fb.resnet.torch, for BGR images in range [0,255].
    """
    interpolation = cv2.INTER_LINEAR
    if isTrain:
        """
        Sec 5.1:
        We use scale and aspect ratio data augmentation [35] as
        in [12]. The network input image is a 224×224 pixel random
        crop from an augmented image or its horizontal flip.
        """
        augmentors = [
            imgaug.GoogleNetRandomCropAndResize(interp=interpolation),
            # It's OK to remove the following augs if your CPU is not fast enough.
            # Removing brightness/contrast/saturation does not have a significant effect on accuracy.
            # Removing lighting leads to a tiny drop in accuracy.
            imgaug.RandomOrderAug(
                [imgaug.BrightnessScale((0.6, 1.4), clip=False),
                 imgaug.Contrast((0.6, 1.4), rgb=False, clip=False),
                 imgaug.Saturation(0.4, rgb=False),
                 # rgb-bgr conversion for the constants copied from fb.resnet.torch
                 imgaug.Lighting(0.1,
                                 eigval=np.asarray(
                                     [0.2175, 0.0188, 0.0045][::-1]) * 255.0,
                                 eigvec=np.array(
                                     [[-0.5675, 0.7192, 0.4009],
                                      [-0.5808, -0.0045, -0.8140],
                                      [-0.5836, -0.6948, 0.4203]],
                                     dtype='float32')[::-1, ::-1]
                                 )]),
            imgaug.Flip(horiz=True),
        ]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256, interp=interpolation),
            imgaug.CenterCrop((224, 224)),
        ]
    return augmentors
Example #17
    def __init__(self, verbose=True):
        self.model_config = os.environ[
            "H_PROFILE"] if "H_PROFILE" in os.environ else ""
        data_config = defaultdict(
            lambda: None,
            yaml.load(open("config.yml"),
                      Loader=yaml.FullLoader)[self.model_config],
        )

        # Validation
        assert data_config["input_prefix"] is not None
        assert data_config["output_prefix"] is not None

        # Load config yml file
        self.log_path = data_config["output_prefix"]  # log root path
        self.data_dir_root = os.path.join(
            data_config["input_prefix"],
            data_config["data_dir"])  # without modes

        self.extract_type = data_config["extract_type"]
        self.data_modes = data_config["data_modes"]
        self.win_size = data_config["win_size"]
        self.step_size = data_config["step_size"]
        self.img_ext = (".png" if data_config["img_ext"] is None else
                        data_config["img_ext"])

        for step in [
                "preproc", "extract", "train", "infer", "export", "process"
        ]:
            exec(
                f"self.out_{step}_root = os.path.join(data_config['output_prefix'], '{step}')"
            )
        # self.out_preproc_root = os.path.join(data_config['output_prefix'], 'preprocess')
        # self.out_extract_root = os.path.join(data_config['output_prefix'], 'extract')

        self.img_dirs = {
            k: v
            for k, v in zip(
                self.data_modes,
                [
                    os.path.join(self.data_dir_root, mode, "Images")
                    for mode in self.data_modes
                ],
            )
        }
        self.labels_dirs = {
            k: v
            for k, v in zip(
                self.data_modes,
                [
                    os.path.join(self.data_dir_root, mode, "Labels")
                    for mode in self.data_modes
                ],
            )
        }

        # normalized images
        self.out_preproc = None
        if data_config["include_preproc"]:
            self.out_preproc = {
                k: v
                for k, v in zip(
                    self.data_modes,
                    [
                        os.path.join(self.out_preproc_root, self.model_config,
                                     mode, "Images")
                        for mode in self.data_modes
                    ],
                )
            }

        if data_config["stain_norm"] is not None:
            # self.target_norm = f"{self._data_dir}/{self.data_modes[0]}/'Images'/{data_config['stain_norm']['target']}{self.img_ext}"
            self.norm_target = os.path.join(
                self.data_dir_root,
                data_config["stain_norm"]["mode"],
                "Images",
                f"{data_config['stain_norm']['image']}{self.img_ext}",
            )
            self.norm_brightness = data_config["stain_norm"]["norm_brightness"]

        self.normalized = (data_config["include_preproc"]) and (
            data_config["stain_norm"] is not None)
        win_code = "{}_{}x{}_{}x{}{}".format(
            self.model_config,
            self.win_size[0],
            self.win_size[1],
            self.step_size[0],
            self.step_size[1],
            "_stain_norm" if self.normalized else "",
        )
        self.out_extract = {
            k: v
            for k, v in zip(
                self.data_modes,
                [
                    os.path.join(self.out_extract_root, win_code, mode,
                                 "Annotations") for mode in self.data_modes
                ],
            )
        }

        # init model params
        self.seed = data_config["seed"]
        mode = data_config["mode"]
        self.model_type = data_config["model_type"]
        self.type_classification = data_config["type_classification"]

        # For some semantic segmentation networks (e.g. micronet), nr_types replaces nr_classes if type_classification=True
        self.nr_classes = 2  # Nuclei Pixels vs Background

        self.nuclei_type_dict = data_config["nuclei_types"]
        self.nr_types = len(
            self.nuclei_type_dict.values()) + 1  # plus background

        #### Dynamically copy the selected config dict into instance variables
        if mode == "hover":
            config_file = importlib.import_module("opt.hover")
        else:
            config_file = importlib.import_module("opt.other")
        config_dict = config_file.__getattribute__(self.model_type)

        for variable, value in config_dict.items():
            self.__setattr__(variable, value)

        # patches are stored as numpy arrays with N channels
        # ordering as [Image][Nuclei Pixels][Nuclei Type][Additional Map] - training data
        # Ex: with type_classification=True
        #     HoVer-Net: RGB - Nuclei Pixels - Type Map - Horizontal and Vertical Map
        # Ex: with type_classification=False
        #     Dist     : RGB - Nuclei Pixels - Distance Map

        self.color_palete = COLOR_PALETE

        # self.model_name = f"{self.model_config}-{self.model_type}-{data_config['input_augs']}-{data_config['exp_id']}"
        self.model_name = (
            f"{self.model_config}-{data_config['input_augs']}-{data_config['exp_id']}"
        )

        self.data_ext = (".npy" if data_config["data_ext"] is None else
                         data_config["data_ext"])
        # list of directories containing validation patches

        # self.train_dir = data_config['train_dir']
        # self.valid_dir = data_config['valid_dir']
        if data_config["include_extract"]:
            self.train_dir = [
                os.path.join(self.out_extract_root, win_code, x)
                for x in data_config["train_dir"]
            ]
            self.valid_dir = [
                os.path.join(self.out_extract_root, win_code, x)
                for x in data_config["valid_dir"]
            ]
        else:
            self.train_dir = [
                os.path.join(self.data_dir_root, x)
                for x in data_config["train_dir"]
            ]
            self.valid_dir = [
                os.path.join(self.data_dir_root, x)
                for x in data_config["valid_dir"]
            ]

        # number of processes for parallel input preprocessing
        self.nr_procs_train = (8 if data_config["nr_procs_train"] is None else
                               data_config["nr_procs_train"])
        self.nr_procs_valid = (4 if data_config["nr_procs_valid"] is None else
                               data_config["nr_procs_valid"])

        self.input_norm = data_config[
            "input_norm"]  # normalize RGB to 0-1 range

        # self.save_dir = os.path.join(data_config['output_prefix'], 'train', self.model_name)
        self.save_dir = os.path.join(self.out_train_root, self.model_name)

        #### Info for running inference
        self.inf_auto_find_chkpt = data_config["inf_auto_find_chkpt"]
        # path to the checkpoints used for inference; replace accordingly

        if self.inf_auto_find_chkpt:
            self.inf_model_path = os.path.join(self.save_dir)
        else:
            self.inf_model_path = os.path.join(data_config["input_prefix"],
                                               "models",
                                               data_config["inf_model"])
        # self.save_dir + '/model-19640.index'

        # output will have channel ordering as [Nuclei Type][Nuclei Pixels][Additional]
        # where [Nuclei Type] will be used for getting the type of each instance
        # while [Nuclei Pixels][Additional] will be used for extracting instances

        # TODO: encode the file extension for each folder?
        # list of [[root_dir1, codeX, subdirA, subdirB], [root_dir2, codeY, subdirC, subdirD] etc.]
        # code is used together with 'inf_output_dir' to make output dir for each set
        self.inf_imgs_ext = (".png" if data_config["inf_imgs_ext"] is None else
                             data_config["inf_imgs_ext"])

        # rootdir, outputdirname, subdir1, subdir2(opt) ...
        self.inf_data_list = [
            os.path.join(data_config["input_prefix"], x)
            for x in data_config["inf_data_list"]
        ]

        model_used = (
            self.model_name if self.inf_auto_find_chkpt else
            os.path.basename(f"{data_config['inf_model'].split('.')[0]}"))

        self.inf_auto_metric = data_config["inf_auto_metric"]
        self.inf_output_dir = os.path.join(
            self.out_infer_root,
            f"{model_used}.{''.join(data_config['inf_data_list']).replace('/', '_').rstrip('_')}.{self.inf_auto_metric}",
        )
        self.model_export_dir = os.path.join(self.out_export_root,
                                             self.model_name)
        self.remap_labels = data_config["remap_labels"]
        self.outline = data_config["outline"]
        self.skip_types = ([
            self.nuclei_type_dict[x.strip()] for x in data_config["skip_types"]
        ] if data_config["skip_types"] is not None else None)

        self.inf_auto_comparator = data_config["inf_auto_comparator"]
        self.inf_batch_size = data_config["inf_batch_size"]

        # For inference during evaluation mode, i.e. run by inferer.py
        self.eval_inf_input_tensor_names = ["images"]
        self.eval_inf_output_tensor_names = ["predmap-coded"]
        # For inference during training mode, i.e. run by trainer.py
        self.train_inf_output_tensor_names = ["predmap-coded", "truemap-coded"]

        assert data_config["input_augs"] != "" or data_config[
            "input_augs"] is not None

        #### Policies
        policies = {
            "p_standard": [
                imgaug.RandomApplyAug(
                    imgaug.RandomChooseAug([
                        GaussianBlur(),
                        MedianBlur(),
                        imgaug.GaussianNoise(),
                    ]),
                    0.5,
                ),
                imgaug.RandomOrderAug([
                    imgaug.Hue((-8, 8), rgb=True),
                    imgaug.Saturation(0.2, rgb=True),
                    imgaug.Brightness(26, clip=True),
                    imgaug.Contrast((0.75, 1.25), clip=True),
                ]),
                imgaug.ToUint8(),
            ],
            "p_hed_random": [
                imgaug.RandomApplyAug(
                    imgaug.RandomChooseAug([
                        GaussianBlur(),
                        MedianBlur(),
                        imgaug.GaussianNoise(),
                        #
                        imgaug.ColorSpace(cv2.COLOR_RGB2HSV),
                        imgaug.ColorSpace(cv2.COLOR_HSV2RGB),
                        #
                        eqRGB2HED(),
                    ]),
                    0.5,
                ),
                # standard color augmentation
                imgaug.RandomOrderAug([
                    imgaug.Hue((-8, 8), rgb=True),
                    imgaug.Saturation(0.2, rgb=True),
                    imgaug.Brightness(26, clip=True),
                    imgaug.Contrast((0.75, 1.25), clip=True),
                ]),
                imgaug.ToUint8(),
            ],
            "p_linear_1": [
                imgaug.RandomApplyAug(
                    imgaug.RandomChooseAug([
                        GaussianBlur(),
                        MedianBlur(),
                        imgaug.GaussianNoise(),
                    ]),
                    0.5,
                ),
                linearAugmentation(),
                imgaug.ToUint8(),
            ],
            "p_linear_2": [
                imgaug.RandomApplyAug(
                    imgaug.RandomChooseAug([
                        GaussianBlur(),
                        MedianBlur(),
                        imgaug.GaussianNoise(),
                    ]),
                    0.5,
                ),
                linearAugmentation(),
                imgaug.RandomOrderAug([
                    imgaug.Hue((-8, 8), rgb=True),
                    imgaug.Saturation(0.2, rgb=True),
                    imgaug.Brightness(26, clip=True),
                    imgaug.Contrast((0.8, 1.20), clip=True),  # 0.75, 1.25
                ]),
                imgaug.ToUint8(),
            ],
            "p_linear_3": [
                imgaug.RandomApplyAug(
                    imgaug.RandomChooseAug([
                        GaussianBlur(),
                        MedianBlur(),
                        imgaug.GaussianNoise(),
                    ]),
                    0.5,
                ),
                imgaug.RandomChooseAug([
                    linearAugmentation(),
                    imgaug.RandomOrderAug([
                        imgaug.Hue((-2, 2), rgb=True),
                        imgaug.Saturation(0.2, rgb=True),
                        imgaug.Brightness(26, clip=True),
                        imgaug.Contrast((0.9, 1.1), clip=True),  # 0.75, 1.25
                    ]),
                ]),
                imgaug.ToUint8(),
            ],
        }

        self.input_augs = policies[(data_config["input_augs"])]

        # Checks
        if verbose:
            print("--------")
            print("Config info:")
            print("--------")
            print(f"Log path: <{self.log_path}>")
            print(f"Extraction out dirs: <{self.out_extract}>")
            print("--------")
            print("Training")
            print(f"Model name: <{self.model_name}>")
            print(f"Input img dirs: <{self.img_dirs}>")
            print(f"Input labels dirs: <{self.labels_dirs}>")
            print(f"Train out dir: <{self.save_dir}>")
            print("--------")
            print("Inference")
            print(f"Auto-find trained model: <{self.inf_auto_find_chkpt}>")
            print(f"Inference model path dir: <{self.inf_model_path}>")
            print(f"Input inference path: <{self.inf_data_list}>")
            print(f"Output inference path: <{self.inf_output_dir}>")
            print(f"Model export out: <{self.model_export_dir}>")
            print("--------")
            print()
Example #18
def fbresnet_augmentor(isTrain, crop_method, color_augmentation):
    """
    Augmentor used in fb.resnet.torch, for BGR images in range [0,255].
    """
    execution_lst = []

    if isTrain:
        augmentors = [
            # 1. crop_method
            # a) GoogleNetResize
            GoogleNetResize(),
            # b) ShortestEdgeResize
            imgaug.ResizeShortestEdge(256),
            # c) GlobalWarp
            imgaug.Resize(226),  # NOTE: for CAM generation
            imgaug.RandomCrop((224, 224)),
            # d) CAMCrop
            # (when CAMCrop is set, the output from the original DataFlow has already been cropped)
            # 2. color_augmentation
            imgaug.RandomOrderAug([
                imgaug.BrightnessScale((0.6, 1.4), clip=False),
                imgaug.Contrast((0.6, 1.4), clip=False),
                imgaug.Saturation(0.4, rgb=False),
                # rgb-bgr conversion for the constants copied from fb.resnet.torch
                imgaug.Lighting(
                    0.1,
                    eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0,
                    eigvec=np.array([[-0.5675, 0.7192, 0.4009],
                                     [-0.5808, -0.0045, -0.8140],
                                     [-0.5836, -0.6948, 0.4203]],
                                    dtype='float32')[::-1, ::-1])
            ]),
            imgaug.Flip(horiz=True),
        ]

        #
        if crop_method == 'GoogleNetResize':
            print(
                '--> perform GoogleNetResize cropping method during the training pipeline'
            )
            execution_lst.extend([0])
        elif crop_method == 'ShortestEdgeResize':
            print(
                '--> perform ShortestEdgeResize cropping method during the training pipeline'
            )
            execution_lst.extend([1, 3])
        elif crop_method == 'GlobalWarp':
            print(
                '--> perform GlobalWarp cropping method during the training pipeline'
            )
            execution_lst.extend([2, 3])
        elif crop_method == 'CAMCrop':
            # enable CAMCrop @ 20171124
            print(
                '*** Perform CAMCrop to better the training dynamics and the results ***'
            )

        if color_augmentation:
            print(
                '--> perform color augmentation during the training pipeline')
            execution_lst.extend([4])
        else:
            print(
                '--> discard the color jittering process during the training pipeline'
            )

        # perform mirror reflection augmentation anyway
        execution_lst.extend([5])

    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
            imgaug.CenterCrop((224, 224)),
            imgaug.RandomCrop((224, 224)),
        ]

        if crop_method == 'RandomCrop':
            execution_lst.extend([0, 2])

        elif crop_method == 'CenterCrop':
            execution_lst.extend([0, 1])

    return [
        item_ for id_, item_ in enumerate(augmentors) if id_ in execution_lst
    ]
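
A usage sketch: the cropping strategy and color jitter are chosen per call, and only the augmentors whose indices were collected in execution_lst are returned (the variable names are illustrative):

train_augs = fbresnet_augmentor(True, crop_method='GoogleNetResize',
                                color_augmentation=True)   # keeps indices [0, 4, 5]
val_augs = fbresnet_augmentor(False, crop_method='CenterCrop',
                              color_augmentation=False)    # keeps indices [0, 1]
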
Example #19
def get_train_augmentors(self, input_shape, output_shape, view=False):
    print(input_shape, output_shape)
    if self.model_mode == "class_rotmnist":
        shape_augs = [
            imgaug.Affine(rotate_max_deg=359,
                          interp=cv2.INTER_NEAREST,
                          border=cv2.BORDER_CONSTANT),
        ]

        input_augs = []

    else:
        shape_augs = [
            imgaug.Affine(
                rotate_max_deg=359,
                translate_frac=(0.01, 0.01),
                interp=cv2.INTER_NEAREST,
                border=cv2.BORDER_REFLECT,
            ),
            imgaug.Flip(vert=True),
            imgaug.Flip(horiz=True),
            imgaug.CenterCrop(input_shape),
        ]

        input_augs = [
            imgaug.RandomApplyAug(
                imgaug.RandomChooseAug([
                    GaussianBlur(),
                    MedianBlur(),
                    imgaug.GaussianNoise(),
                ]),
                0.5,
            ),
            # Standard colour augmentation
            imgaug.RandomOrderAug([
                imgaug.Hue((-8, 8), rgb=True),
                imgaug.Saturation(0.2, rgb=True),
                imgaug.Brightness(26, clip=True),
                imgaug.Contrast((0.75, 1.25), clip=True),
            ]),
            imgaug.ToUint8(),
        ]

    if self.model_mode == "seg_gland":
        label_augs = []
        label_augs = [GenInstanceContourMap(mode=self.model_mode)]
        label_augs.append(BinarizeLabel())

        if not view:
            label_augs.append(imgaug.CenterCrop(output_shape))
        else:
            label_augs.append(imgaug.CenterCrop(input_shape))

        return shape_augs, input_augs, label_augs
    elif self.model_mode == "seg_nuc":
        label_augs = []
        label_augs = [GenInstanceMarkerMap()]
        label_augs.append(BinarizeLabel())
        if not view:
            label_augs.append(imgaug.CenterCrop(output_shape))
        else:
            label_augs.append(imgaug.CenterCrop(input_shape))

        return shape_augs, input_augs, label_augs

    else:
        return shape_augs, input_augs
Example #20
    def __init__(self, verbose=True):
        self.model_config = os.environ['H_PROFILE'] if 'H_PROFILE' in os.environ else ''
        data_config = defaultdict(lambda: None, yaml.load(open('config.yml'), Loader=yaml.FullLoader)[self.model_config])

        # Validation
        assert (data_config['input_prefix'] is not None)
        assert (data_config['output_prefix'] is not None)

        # Load config yml file
        self.log_path = data_config['output_prefix'] # log root path
        self.data_dir_root = os.path.join(data_config['input_prefix'], data_config['data_dir']) # without modes
        
        self.extract_type = data_config['extract_type']
        self.data_modes = data_config['data_modes']
        self.win_size = data_config['win_size']
        self.step_size = data_config['step_size']
        self.img_ext = '.png' if data_config['img_ext'] is None else data_config['img_ext']

        for step in ['preproc', 'extract', 'train', 'infer', 'export', 'process']:
            exec(f"self.out_{step}_root = os.path.join(data_config['output_prefix'], '{step}')")
        #self.out_preproc_root = os.path.join(data_config['output_prefix'], 'preprocess')
        #self.out_extract_root = os.path.join(data_config['output_prefix'], 'extract')

        self.img_dirs = {k: v for k, v in zip(self.data_modes, [os.path.join(self.data_dir_root, mode, 'Images') 
                for mode in self.data_modes])}
        self.labels_dirs = {k: v for k, v in zip(self.data_modes, [os.path.join(self.data_dir_root, mode, 'Labels') 
                for mode in self.data_modes])}

        # normalized images
        self.out_preproc = None
        if data_config['include_preproc']:
            self.out_preproc = {k: v for k, v in zip(self.data_modes, [os.path.join(self.out_preproc_root, self.model_config, mode, 'Images') 
                    for mode in self.data_modes])}
        
        if data_config['stain_norm'] is not None:
            # self.target_norm = f"{self._data_dir}/{self.data_modes[0]}/'Images'/{data_config['stain_norm']['target']}{self.img_ext}"
            self.norm_target = os.path.join(self.data_dir_root, data_config['stain_norm']['mode'], 'Images', f"{data_config['stain_norm']['image']}{self.img_ext}")
            self.norm_brightness = data_config['stain_norm']['norm_brightness']
        
        self.normalized = (data_config['include_preproc']) and (data_config['stain_norm'] is not None)
        win_code = '{}_{}x{}_{}x{}{}'.format(self.model_config, self.win_size[0], self.win_size[1], self.step_size[0], self.step_size[1], '_stain_norm' if self.normalized else '')
        self.out_extract = {k: v for k, v in zip(self.data_modes, [os.path.join(self.out_extract_root, win_code, mode, 'Annotations') 
            for mode in self.data_modes])}

        # init model params
        self.seed = data_config['seed']
        mode = data_config['mode']
        self.model_type = data_config['model_type']
        self.type_classification = data_config['type_classification']

        # For some semantic segmentation networks (e.g. micronet), nr_types replaces nr_classes if type_classification=True
        self.nr_classes = 2 # Nuclei Pixels vs Background
        
        self.nuclei_type_dict = data_config['nuclei_types']
        self.nr_types = len(self.nuclei_type_dict.values()) + 1 # plus background

        #### Dynamically copy the selected config dict into instance variables
        if mode == 'hover':
            config_file = importlib.import_module('opt.hover') # np_hv, np_dist
        else:
            config_file = importlib.import_module('opt.other') # fcn8, dcan, etc.
        config_dict = config_file.__getattribute__(self.model_type)

        for variable, value in config_dict.items():
            self.__setattr__(variable, value)

        # patches are stored as numpy arrays with N channels 
        # ordering as [Image][Nuclei Pixels][Nuclei Type][Additional Map] - training data
        # Ex: with type_classification=True
        #     HoVer-Net: RGB - Nuclei Pixels - Type Map - Horizontal and Vertical Map
        # Ex: with type_classification=False
        #     Dist     : RGB - Nuclei Pixels - Distance Map
        data_code_dict = {
            'unet'     : '536x536_84x84',
            'dist'     : '536x536_84x84',
            'fcn8'     : '512x512_256x256',
            'dcan'     : '512x512_256x256',
            'segnet'   : '512x512_256x256',
            'micronet' : '504x504_252x252',
            'np_hv'    : '540x540_80x80',
            'np_dist'  : '540x540_80x80',
        }

        self.color_palete = {
            'Inflammatory': [0.0, 255.0, 0.0],      # bright green
            'Dead cells': [32.0, 32.0, 32.0],       # black
            'Neoplastic cells': [0.0, 0.0, 255.0],  # dark blue      # aka Epithelial malignant
            'Epithelial': [255.0, 255.0, 0.0],      # bright yellow  # aka Epithelial healthy
            'Misc': [0.0, 0.0, 0.0],                # pure black
            'Spindle': [0.0, 255.0, 255.0],         # cyan           # Fibroblast, Muscle and Endothelial cells
            'Connective': [0.0, 220.0, 220.0],      # darker cyan    # plus Soft tissue cells
            'Background': [255.0, 0.0, 170.0],      # pink
            ###
            'light green': [170.0, 255.0, 0.0],
            'purple': [170.0, 0.0, 255.0],
            'orange': [255.0, 170.0, 0.0],
            'red': [255.0, 0.0, 0.0],
        }

        # self.model_name = f"{self.model_config}-{self.model_type}-{data_config['input_augs']}-{data_config['exp_id']}"
        self.model_name = f"{self.model_config}-{data_config['input_augs']}-{data_config['exp_id']}"

        self.data_ext = '.npy' if data_config['data_ext'] is None else data_config['data_ext']
        # list of directories containing validation patches

        # self.train_dir = data_config['train_dir']
        # self.valid_dir = data_config['valid_dir']
        if data_config['include_extract']:
            self.train_dir = [os.path.join(self.out_extract_root, win_code, x) for x in data_config['train_dir']]
            self.valid_dir = [os.path.join(self.out_extract_root, win_code, x) for x in data_config['valid_dir']]
        else:
            self.train_dir = [os.path.join(self.data_dir_root, x) for x in data_config['train_dir']]
            self.valid_dir = [os.path.join(self.data_dir_root, x) for x in data_config['valid_dir']]


        # number of processes for parallel input preprocessing
        self.nr_procs_train = 8 if data_config['nr_procs_train'] is None else data_config['nr_procs_train']
        self.nr_procs_valid = 4 if data_config['nr_procs_valid'] is None else data_config['nr_procs_valid']

        self.input_norm = data_config['input_norm'] # normalize RGB to 0-1 range

        #self.save_dir = os.path.join(data_config['output_prefix'], 'train', self.model_name)
        self.save_dir = os.path.join(self.out_train_root, self.model_name)

        #### Info for running inference
        self.inf_auto_find_chkpt = data_config['inf_auto_find_chkpt']
        # path to the checkpoints used for inference; replace accordingly
        
        if self.inf_auto_find_chkpt:
            self.inf_model_path = os.path.join(self.save_dir)
        else:
            self.inf_model_path = os.path.join(data_config['input_prefix'], 'models', data_config['inf_model'])
        #self.save_dir + '/model-19640.index'

        # output will have channel ordering as [Nuclei Type][Nuclei Pixels][Additional]
        # where [Nuclei Type] will be used for getting the type of each instance
        # while [Nuclei Pixels][Additional] will be used for extracting instances

        # TODO: encode the file extension for each folder?
        # list of [[root_dir1, codeX, subdirA, subdirB], [root_dir2, codeY, subdirC, subdirD] etc.]
        # code is used together with 'inf_output_dir' to make output dir for each set
        self.inf_imgs_ext = '.png' if data_config['inf_imgs_ext'] is None else data_config['inf_imgs_ext']

        # rootdir, outputdirname, subdir1, subdir2(opt) ...
        self.inf_data_list = [os.path.join(data_config['input_prefix'], x) for x in data_config['inf_data_list']]
        
        model_used = self.model_name if self.inf_auto_find_chkpt else f"{data_config['inf_model'].split('.')[0]}"

        self.inf_auto_metric = data_config['inf_auto_metric']
        self.inf_output_dir = os.path.join(self.out_infer_root, f"{model_used}.{''.join(data_config['inf_data_list']).replace('/', '_').rstrip('_')}.{self.inf_auto_metric}")
        self.model_export_dir = os.path.join(self.out_export_root, self.model_name)
        self.remap_labels = data_config['remap_labels']
        self.outline = data_config['outline']
        self.skip_types = [self.nuclei_type_dict[x.strip()] for x in data_config['skip_types']] if data_config['skip_types'] != [''] else []

        self.inf_auto_comparator = data_config['inf_auto_comparator']

        # For inference during evaluation mode, i.e. run by inferer.py
        self.eval_inf_input_tensor_names = ['images']
        self.eval_inf_output_tensor_names = ['predmap-coded']
        # For inference during training mode, i.e. run by trainer.py
        self.train_inf_output_tensor_names = ['predmap-coded', 'truemap-coded']

        assert data_config['input_augs'] != '' and data_config['input_augs'] is not None

        #### Policies
        policies = {
            'p_standard': [
                imgaug.RandomApplyAug(
                imgaug.RandomChooseAug([
                    GaussianBlur(),
                    MedianBlur(),
                    imgaug.GaussianNoise(),
                ]), 0.5
            ),
            imgaug.RandomOrderAug([
                imgaug.Hue((-8, 8), rgb=True), 
                imgaug.Saturation(0.2, rgb=True),
                imgaug.Brightness(26, clip=True),  
                imgaug.Contrast((0.75, 1.25), clip=True),
                ]),
            imgaug.ToUint8(),
            ], 
            'p_hed_random': [
                imgaug.RandomApplyAug(
                imgaug.RandomChooseAug([
                    GaussianBlur(),
                    MedianBlur(),
                    imgaug.GaussianNoise(),
                    #
                    imgaug.ColorSpace(cv2.COLOR_RGB2HSV),
                    imgaug.ColorSpace(cv2.COLOR_HSV2RGB),
                    #
                    eqRGB2HED(),
                ]), 0.5
            ),
            # standard color augmentation
            imgaug.RandomOrderAug([
                imgaug.Hue((-8, 8), rgb=True), 
                imgaug.Saturation(0.2, rgb=True),
                imgaug.Brightness(26, clip=True),  
                imgaug.Contrast((0.75, 1.25), clip=True),
                ]),
            imgaug.ToUint8(),
            ], 
            'p_linear_1': [
                imgaug.RandomApplyAug(
                    imgaug.RandomChooseAug([
                        GaussianBlur(),
                        MedianBlur(),
                        imgaug.GaussianNoise(),
                    ]), 0.5
                ),
                linearAugmentation(),
                imgaug.ToUint8(),
            ], 
            'p_linear_2': [
                imgaug.RandomApplyAug(
                    imgaug.RandomChooseAug([
                        GaussianBlur(),
                        MedianBlur(),
                        imgaug.GaussianNoise(),
                    ]), 0.5
                ),
                linearAugmentation(),
                imgaug.RandomOrderAug([
                    imgaug.Hue((-8, 8), rgb=True), 
                    imgaug.Saturation(0.2, rgb=True),
                    imgaug.Brightness(26, clip=True),  
                    imgaug.Contrast((0.8, 1.20), clip=True), # 0.75, 1.25
                ]),
            imgaug.ToUint8(),
            ],
            'p_linear_3': [
                imgaug.RandomApplyAug(
                    imgaug.RandomChooseAug([
                        GaussianBlur(),
                        MedianBlur(),
                        imgaug.GaussianNoise(),
                    ]), 0.5
                ),
                imgaug.RandomChooseAug([
                    linearAugmentation(),
                    imgaug.RandomOrderAug([
                        imgaug.Hue((-5, 5), rgb=True), 
                        imgaug.Saturation(0.2, rgb=True),
                        imgaug.Brightness(26, clip=True),  
                        imgaug.Contrast((0.5, 1.5), clip=True), # 0.75, 1.25
                    ])
                ]),
            imgaug.ToUint8(),
            ]
        }

        self.input_augs = policies[(data_config['input_augs'])]

        # Checks
        if verbose:
            print("--------")
            print("Config info:")
            print("--------")
            print(f"Log path: <{self.log_path}>")
            print(f"Extraction out dirs: <{self.out_extract}>")
            print("--------")
            print("Training")
            print(f"Model name: <{self.model_name}>")
            print(f"Input img dirs: <{self.img_dirs}>")
            print(f"Input labels dirs: <{self.labels_dirs}>")
            print(f"Train out dir: <{self.save_dir}>")
            print("--------")
            print("Inference")
            print(f"Auto-find trained model: <{self.inf_auto_find_chkpt}>")
            print(f"Inference model path dir: <{self.inf_model_path}>")
            print(f"Input inference path: <{self.inf_data_list}>")
            print(f"Output inference path: <{self.inf_output_dir}>")
            print(f"Model export out: <{self.model_export_dir}>")
            print("--------")
            print()
Example #21
def prepare_video(args):
    (data_root, video_path, video_width, video_height, video_length,
     video_downsample_ratio, video_index, batch_size, shared_mem_idx,
     is_training, is_ucf101, is_imagenet, is_zipped) = args

    augs = [
        imgaug.BrightnessScale((0.6, 1.4), clip=False),
        imgaug.Contrast((0.6, 1.4), clip=False),
        imgaug.Saturation(0.4, rgb=True),
        imgaug.Lighting(0.1,
                        eigval=np.asarray([0.2175, 0.0188, 0.0045]) * 255.0,
                        eigvec=np.array([[-0.5675, 0.7192, 0.4009],
                                         [-0.5808, -0.0045, -0.8140],
                                         [-0.5836, -0.6948, 0.4203]],
                                        dtype='float32')),
    ]
    random.shuffle(augs)

    video_mem = np.frombuffer(shared_mem[shared_mem_idx],
                              np.ctypeslib.ctypes.c_float)
    video_mem = video_mem.reshape(
        (batch_size, video_length, video_height, video_width, 3))

    pathgen = lambda x: os.path.join(data_root, str(video_path), x)

    frames = None
    if os.path.isdir(pathgen('')):
        frames = sorted(os.listdir(pathgen('')))
    else:
        frames = [os.path.join(data_root, str(video_path))]

    crop_frames = is_training
    flip_frames = bool(
        random.getrandbits(1)) and is_training and (is_ucf101 or is_imagenet)
    add_noise = bool(random.getrandbits(1)) and is_training and not is_training  # 'is_training and not is_training' is always False, so noise injection is effectively disabled

    # choose a random time to start the video
    num_frames = len(frames) // video_downsample_ratio
    t_offset = 0
    stride_offset = 0
    if is_training and num_frames > video_length:
        t_offset = random.choice(range(num_frames - video_length))
        stride_offset = random.choice(range(video_downsample_ratio))

    num_frames = min(len(frames) // video_downsample_ratio, video_length)
    assert num_frames != 0, 'num frames in video cannot be 0: {}'.format(
        video_path)

    round2pow2 = lambda x: 2**(x - 1).bit_length()
    pow2_width = round2pow2(video_width)
    pow2_height = round2pow2(video_height)
    crop_margin_x = pow2_width - video_width
    crop_margin_y = pow2_height - video_height

    x1 = random.choice(list(range(crop_margin_x)))
    y1 = random.choice(list(range(crop_margin_y)))
    x2 = pow2_width - (crop_margin_x - x1)
    y2 = pow2_height - (crop_margin_y - y1)

    rotation_angle = random.choice(list(range(-10, 10,
                                              1))) if is_training else 0

    video_mem[video_index, :, :, :] = 0
    for i in range(num_frames):
        image_idx = video_downsample_ratio * (i + t_offset)
        image_idx = min(image_idx + stride_offset, len(frames) - 1)  # clamp to the last valid frame index

        image = None
        if is_imagenet and is_zipped:
            fname = pathgen(frames[image_idx])

            jpeg_filename = os.path.basename(fname)
            jpeg_dirname = os.path.basename(os.path.dirname(fname))
            zip_filepath = os.path.dirname(fname) + '.zip'
            f = zipfile.ZipFile(zip_filepath, 'r')
            compress_jpeg = io.BytesIO(
                f.read(os.path.join(jpeg_dirname, jpeg_filename)))
            image = Image.open(compress_jpeg)

        else:
            image = Image.open(pathgen(
                frames[image_idx]))  # in RGB order by default

        image = image.convert('RGB')
        #image = image.convert('L') # convert to YUV and grab Y-component

        if crop_frames:
            image = image.resize((pow2_width, pow2_height), PIL.Image.BICUBIC)
            image = image.crop(box=(x1, y1, x2, y2))
        else:
            image = image.resize((video_width, video_height),
                                 PIL.Image.BICUBIC)

        if flip_frames:
            image = image.transpose(PIL.Image.FLIP_LEFT_RIGHT)

        if rotation_angle != 0:
            image = image.rotate(rotation_angle)

        image = np.asarray(image, dtype=np.uint8)
        # numpy converts PIL images to (height, width, channels) arrays
        assert image.shape == (
            video_height, video_width,
            3), 'cropped image must be {} but was {}'.format(
                (video_height, video_width, 3), image.shape)

        if is_imagenet:
            if is_training:
                for a in augs:
                    a.reset_state()
                    image = a.augment(image)

            image = np.asarray(image, dtype=np.float32)
            image = image * (1.0 / 255)
            mean = np.asarray([0.485, 0.456, 0.406])
            std = np.asarray([0.229, 0.224, 0.225])
            image = (image - mean) / std

        else:
            image = np.asarray(image, dtype=np.uint32)
            image = image - 116  # center on mean value of 116 (as computed in preprocessing step)
            image = np.clip(image, -128, 128)

        image = np.asarray(image, dtype=np.float32)

        if add_noise:
            noise = np.random.normal(loc=0,
                                     scale=5,
                                     size=(video_height, video_width, 3))
            image = image + noise
            image = np.clip(image, -128, 128)

        video_mem[video_index, i, :, :, :] = image

    return {'num_frames': num_frames, 'video_path': video_path}
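
prepare_video reads from a module-level shared_mem list that is not shown in this snippet. A minimal allocation sketch under that assumption (num_workers and the dimension variables are illustrative):

import multiprocessing
import numpy as np

# one shared float32 buffer per worker slot, sized to hold a full batch of frames
shared_mem = [
    multiprocessing.RawArray(
        np.ctypeslib.ctypes.c_float,
        batch_size * video_length * video_height * video_width * 3)
    for _ in range(num_workers)
]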