def simClrAugmentations(dataset_config, given_shape):
    """
    Builds the augmentation pipeline for contrastive (SimCLR) training.

    Two views of every sample are produced. The first crop copies "image" into
    "image2"; the second crop then overwrites "image" in place. All subsequent
    operations act on a single view each:

        max_div_images --> random_crop_1 (image -> image2) --> random_flip_2 --> jitter_images_2 --> clip_2
                       |--> random_crop_2 (image -> image) --> random_flip_1 --> jitter_images_1 --> clip_1

    :param dataset_config: (Dictionary) The config of the dataset to train and val on. Must contain
        "dataShape"; its first three entries (after a leading 1) form the crop shape.
    :param given_shape: (Array) Shape of the data from the input pipeline, e.g. [None, None, 3].
        Not used by this function.
    :return: dataset_augmentor_train: (DatasetAugmentationBuilder) Handles the augmentation of the
        training dataset.
    :return: dataset_augmentor_val: (DatasetAugmentationBuilder) The very same builder object as
        dataset_augmentor_train.
    """
    data_shape = dataset_config["dataShape"]
    crop_shape = [1] + [data_shape[axis] for axis in (0, 1, 2)]

    normalizer = MaxDivNormalizer(255.0)

    # Order matters here: "image" must be cropped into "image2" before the
    # second crop overwrites "image" in place.
    crop_into_second_view = RandomCropWithResize(
        input_name="image", output_name="image2", crop_shape=crop_shape)
    crop_first_view = RandomCropWithResize(
        input_name="image", output_name="image", crop_shape=crop_shape)

    # Per-view operations read and write the same key; one instance per view,
    # created in the order "image" then "image2".
    view_keys = ("image", "image2")
    flips = [RandomHorizontalFlip(input_name=key, output_name=key) for key in view_keys]
    jitters = [RandomColorJitter(input_name=key, output_name=key) for key in view_keys]
    clips = [ClipByValue(input_name=key, output_name=key) for key in view_keys]

    augmentor = DatasetAugmentationBuilder(
        preprocessors=[normalizer],
        generators=[crop_into_second_view, crop_first_view] + flips + jitters,
        output_preprocessors=clips)

    # Validation reuses the training builder (same object).
    return augmentor, augmentor
def rotationCaeAugmentations(dataset_config, given_shape):
    """
    Applies augmentations for the training of rotation CAE models.

    Crop, flip and color jitter run as preprocessors. Random90xRotation is the only
    generator; presumably it produces the rotated samples the model is trained on
    (confirm against Random90xRotation). Values are clipped as the final output step.

    Augmentations:
        max_div_images --> RandomCropWithResize --> RandomHorizontalFlip --> RandomColorJitter
            --> Random90xRotation --> ClipByValue

    :param dataset_config: (Dictionary) The config of the dataset to train and val on. Must contain
        "dataShape"; its first three entries (after a leading 1) form the crop shape.
    :param given_shape: (Array) Shape of the data from the input pipeline, e.g. [None, None, 3].
        Not used by this function.
    :return: dataset_augmentor_train: (DatasetAugmentationBuilder) Handles the augmentation of the
        training dataset.
    :return: dataset_augmentor_val: (DatasetAugmentationBuilder) The very same builder object as
        dataset_augmentor_train, so validation data goes through the identical pipeline.
    """
    data_shape = dataset_config["dataShape"]
    crop_shape = [1, data_shape[0], data_shape[1], data_shape[2]]
    max_div_images = MaxDivNormalizer(255.0)

    dataset_augmentor_train = DatasetAugmentationBuilder(
        preprocessors=[
            max_div_images,
            RandomCropWithResize(crop_shape=crop_shape),
            RandomHorizontalFlip(),
            RandomColorJitter(),
        ],
        generators=[Random90xRotation()],
        output_preprocessors=[ClipByValue()])

    # Validation intentionally shares the training builder object.
    dataset_augmentor_val = dataset_augmentor_train
    return dataset_augmentor_train, dataset_augmentor_val
def targetTaskAugmentations(dataset_config, given_shape):
    """
    Applies base augmentations for the training of target models.

    Train augmentations:
        one_hot --> max_div_images --> RandomCropWithResize --> RandomHorizontalFlip
            --> RandomColorJitter --> ClipByValue
    Val augmentations:
        one_hot --> max_div_images [--> CenterCropWithResize --> ClipByValue]
        The bracketed steps are only added when list(given_shape) differs from
        list(dataset_config["dataShape"]), e.g. when given_shape contains None entries.

    The OneHot and MaxDivNormalizer instances are shared between the train and val builders.

    :param dataset_config: (Dictionary) The config of the dataset to train and val on. Must contain
        "dataShape" and "numClasses". The optional "labelName" (default "label") is the key the
        OneHot step reads from; its output is always written to "label".
    :param given_shape: (Array) Shape of the data from the input pipeline, e.g. [None, None, 3].
    :return: dataset_augmentor_train: (DatasetAugmentationBuilder) Handles the augmentation of the
        training dataset.
    :return: dataset_augmentor_val: (DatasetAugmentationBuilder) Handles the augmentation of the
        validation datasets.
    """
    data_shape = dataset_config["dataShape"]
    crop_shape = [1, data_shape[0], data_shape[1], data_shape[2]]
    label_name = dataset_config.get("labelName", "label")

    max_div_images = MaxDivNormalizer(255.0)
    one_hot = OneHot(dataset_config["numClasses"], input_name=label_name, output_name="label")

    dataset_augmentor_train = DatasetAugmentationBuilder(preprocessors=[
        one_hot,
        max_div_images,
        RandomCropWithResize(crop_shape=crop_shape),
        RandomHorizontalFlip(),
        RandomColorJitter(),
        ClipByValue(),
    ])

    val_preprocessors = [one_hot, max_div_images]
    # Only bring validation data to the target shape when the pipeline shape differs
    # from it; clipping follows the resize (presumably because resizing can leave the
    # normalized value range — confirm against CenterCropWithResize).
    if list(given_shape) != list(data_shape):
        val_preprocessors.append(CenterCropWithResize(crop_shape=crop_shape))
        val_preprocessors.append(ClipByValue())
    dataset_augmentor_val = DatasetAugmentationBuilder(preprocessors=val_preprocessors)

    return dataset_augmentor_train, dataset_augmentor_val