def get_transforms(stage: str = None,
                   mode: str = None,
                   image_size=224,
                   one_hot_classes=None):
    """Build the sample transform for the given loader ``mode``.

    "train" and "valid" wrap the image augmentations into a ``MixinAdapter``
    together with a rotation mixin (which also emits ``rotation_factor``);
    "infer" returns a plain composed ``Augmentor``.

    Args:
        stage: experiment stage name (unused here).
        mode: one of "train", "valid", "infer".
        image_size: side length the pre-transforms resize to.
        one_hot_classes: forwarded to ``RotateMixin`` — presumably the number
            of target classes for one-hot encoding; TODO confirm.

    Raises:
        NotImplementedError: for any unrecognized ``mode``.
    """
    pre_fn = pre_transforms(image_size=image_size)

    # "train" stacks heavy augmentations before the shared post-processing;
    # "valid"/"infer" run post-processing only.
    if mode == "train":
        post_fn = Compose([hard_transform(), post_transforms()])
    elif mode in ["valid", "infer"]:
        post_fn = post_transforms()
    else:
        raise NotImplementedError()

    if mode in ["train", "valid"]:
        rotation_mixin = RotateMixin(
            input_key="image",
            output_key="rotation_factor",
            targets_key="targets",
            one_hot_classes=one_hot_classes,
        )
        result = MixinAdapter(
            mixin=rotation_mixin,
            pre_transforms=Augmentor(
                dict_key="image",
                augment_fn=lambda x: pre_fn(image=x)["image"]),
            post_transforms=Augmentor(
                dict_key="image",
                augment_fn=lambda x: post_fn(image=x)["image"]),
        )
    elif mode in ["infer"]:
        # No rotation mixin at inference time — just pre + post in sequence.
        composed_fn = Compose([pre_fn, post_fn])
        result = Augmentor(
            dict_key="image",
            augment_fn=lambda x: composed_fn(image=x)["image"])
    else:
        raise NotImplementedError()

    return result
def get_transforms(stage: str = None, mode: str = None, input_size: int = 224):
    """Return the group-image transform for the given loader ``mode``.

    Training applies heavy stochastic augmentations followed by the
    deterministic resize/normalize/to-tensor steps; any other mode applies
    only the deterministic steps. The per-image results are stacked into a
    single tensor.

    Args:
        stage: experiment stage name (unused here).
        mode: "train" selects the augmented pipeline; anything else gets
            the validation/inference pipeline.
        input_size: square side length images are resized to.
    """
    # Stochastic augmentations, training only.
    augmentations_train = [
        OpticalDistortion(distort_limit=0.3, p=0.3),
        JpegCompression(quality_lower=50, p=0.8),
        HorizontalFlip(p=0.5),
        MotionBlur(p=0.5),
        ShiftScaleRotate(
            shift_limit=0.1, scale_limit=0.2, rotate_limit=20, p=0.5),
        RandomBrightnessContrast(
            brightness_limit=0.3, contrast_limit=0.2, p=0.4),
        HueSaturationValue(
            hue_shift_limit=3, sat_shift_limit=20, val_shift_limit=30, p=0.4),
        CLAHE(clip_limit=2, p=0.3),
    ]
    # Deterministic steps shared by every mode.
    augmentations_infer = [
        Resize(input_size, input_size),
        Normalize(),
        ToTorchTensor(p=1.0),
    ]

    to_batch = TorchStack()
    group_train = GroupTransform(
        transforms=augmentations_train + augmentations_infer)
    group_valid = GroupTransform(transforms=augmentations_infer)

    def _apply(group_fn, images):
        # Transform every image in the group, then stack into one tensor.
        return to_batch(group_fn(images))

    train_transforms = Augmentor(
        dict_key="features",
        augment_fn=lambda images: _apply(group_train, images))
    valid_transforms = Augmentor(
        dict_key="features",
        augment_fn=lambda images: _apply(group_valid, images))

    return train_transforms if mode == "train" else valid_transforms
def _get_transform(**params) -> Callable:
    """Recursively instantiate a transform from its config ``params``.

    Two config shapes are supported:

    * ``_key_value: true`` — every remaining entry is a per-key sub-config;
      each is built recursively and wired through an ``Augmentor`` that
      reads and writes that same dict key, all wrapped in an
      ``AugmentorCompose``.
    * otherwise — an optional ``transforms`` list of child configs is
      instantiated first, then the whole spec is passed to the
      ``TRANSFORMS`` registry factory.

    Returns:
        Callable: the instantiated transform.
    """
    if params.pop("_key_value", False):
        keyed_transforms = {
            config_key: ConfigExperiment._get_transform(  # noqa: WPS437
                **config_params)
            for config_key, config_params in params.items()
        }
        return AugmentorCompose({
            key: Augmentor(
                dict_key=key,
                augment_fn=fn,
                input_key=key,
                output_key=key,
            )
            for key, fn in keyed_transforms.items()
        })

    # Composite transform: build nested children before handing the spec
    # to the registry.
    if "transforms" in params:
        children = [
            ConfigExperiment._get_transform(  # noqa: WPS437
                **child_params)
            for child_params in params["transforms"]
        ]
        params.update(transforms=children)

    return TRANSFORMS.get_from_params(**params)
def prepare_transforms(*, mode, stage=None, **kwargs):
    """Wrap the mode-specific image pipeline into an ``Augmentor``.

    Args:
        mode: "train" selects ``train_transform``; "valid"/"infer" select
            ``valid_transform``.
        stage: experiment stage name (unused here).
        **kwargs: ignored; accepted for call-site compatibility.

    Raises:
        NotImplementedError: for any unrecognized ``mode``.
    """
    if mode == "train":
        pipeline = train_transform(image_size=IMG_SIZE)
    elif mode in ["valid", "infer"]:
        pipeline = valid_transform(image_size=IMG_SIZE)
    else:
        raise NotImplementedError

    def _augment(image):
        # The pipeline takes and returns dicts keyed by "image".
        return pipeline(image=image)["image"]

    return Augmentor(dict_key="image", augment_fn=_augment)
def get_transforms():
    """Compose the train, valid and show transform pipelines.

    Returns:
        tuple: ``(train_data_transforms, valid_data_transforms,
        show_transforms)`` — the first two are ``Augmentor`` wrappers that
        operate on the "features" key of a sample dict; the last is the raw
        composed pipeline (no normalization/ToTensor) for visualization.
    """
    transforms_train = compose(
        [pre_transforms(), hard_transforms(), post_transforms()])
    transforms_valid = compose([pre_transforms(), post_transforms()])
    show_transforms = compose([pre_transforms(), hard_transforms()])

    def _wrap(pipeline):
        # Takes an image from the input dictionary by `dict_key`,
        # runs the pipeline on it, and puts the result back.
        return Augmentor(
            dict_key="features",
            augment_fn=lambda x: pipeline(image=x)["image"])

    # Train: resize + heavy augmentation + normalization/ToTensor.
    train_data_transforms = _wrap(transforms_train)
    # Valid: resize + normalization/ToTensor only.
    valid_data_transforms = _wrap(transforms_valid)

    return train_data_transforms, valid_data_transforms, show_transforms
def get_transform(phase, img_size):
    """Return the ``Augmentor`` for the given dataset ``phase``.

    Args:
        phase: "train" (resize + heavy augmentation + post-processing) or
            "valid" (resize + post-processing only).
        img_size: square side length the pre-transforms resize to.

    Returns:
        Augmentor operating on the "features" key of a sample dict.

    Raises:
        AssertionError: if ``phase`` is not "train" or "valid".
    """
    assert phase in {'train', 'valid'}, f'invalid phase: {phase}'
    if phase == 'train':
        transforms = compose([
            pre_transforms(image_size=img_size),
            hard_transforms(),
            post_transforms()
        ])
    else:
        # BUGFIX: previously called pre_transforms() with no image_size, so
        # validation images were resized to the helper's default size rather
        # than the requested img_size, diverging from the train pipeline.
        transforms = compose([
            pre_transforms(image_size=img_size),
            post_transforms()
        ])
    data_transforms = Augmentor(
        dict_key="features",
        augment_fn=lambda x: transforms(image=x)["image"])
    return data_transforms
def get_transforms(stage: str = None,
                   mode: str = None,
                   image_size: int = 224,
                   one_hot_classes: int = None):
    """Build a composed image ``Augmentor`` for the given loader ``mode``.

    Args:
        stage: experiment stage name (unused here).
        mode: "train" adds heavy augmentation on top of the shared
            post-processing; "valid"/"infer" use post-processing only.
        image_size: square side length used by pre/hard transforms.
        one_hot_classes: unused here; kept for call-site compatibility.

    Raises:
        NotImplementedError: for any unrecognized ``mode``.
    """
    pre_fn = pre_transforms(image_size=image_size)

    if mode == "train":
        post_fn = Compose(
            [hard_transform(image_size=image_size), post_transforms()])
    elif mode in ["valid", "infer"]:
        post_fn = post_transforms()
    else:
        raise NotImplementedError()

    pipeline = Compose([pre_fn, post_fn])
    return Augmentor(
        dict_key="image",
        augment_fn=lambda x: pipeline(image=x)["image"])
pre_transforms(config.size), hard_transforms(), post_transforms() ]) valid_transforms = plant.compose([ pre_transforms(config.size), post_transforms() ]) show_transforms = plant.compose([ pre_transforms(config.size), hard_transforms() ]) train_data_transforms = Augmentor( dict_key="features", augment_fn=lambda x: train_transforms(image=x)["image"] ) valid_data_transforms = Augmentor( dict_key="features", augment_fn=lambda x: valid_transforms(image=x)["image"] ) loaders = plant.get_loaders( train_transforms_fn=train_data_transforms, valid_transforms_fn=valid_data_transforms, batch_size=config.batch_size, config=config ) model = plant.get_model(config.model_name, config.num_classes)
data = list(zip(images, masks)) train_data = data[:n_images // 2] valid_data = data[n_images // 2:] # Data loaders augmentations = [ # TODO specify augmentations (e.g. histogram normalization) SegAugmentWrapper(RandomFlip(0.5), image_key='features', mask_key='targets') ] transformations = [ Augmentor( dict_key="features", augment_fn=lambda x: \ torch.from_numpy(x.copy().astype(np.float32) / 256.).unsqueeze_(0).float()), Augmentor( dict_key="targets", augment_fn=lambda x: \ torch.from_numpy(x.copy()).unsqueeze_(0).float()), ] train_data_transform = transforms.Compose(augmentations + transformations) valid_data_transform = transforms.Compose(transformations) open_fn = lambda x: {"features": x[0], "targets": x[1]} train_loader = UtilsFactory.create_loader(train_data, open_fn=open_fn, dict_transform=train_data_transform,
RandomContrast(), RandomBrightness(), ], p=0.5), HueSaturationValue(p=0.5), ], p=p) AUG_TRAIN = strong_aug(p=0.75) AUG_INFER = Compose([ Resize(IMG_SIZE, IMG_SIZE), Normalize(), ]) TRAIN_TRANSFORM_FN = [ Augmentor( dict_key="image", augment_fn=lambda x: AUG_TRAIN(image=x)["image"]), ] INFER_TRANSFORM_FN = [ Augmentor( dict_key="image", augment_fn=lambda x: AUG_INFER(image=x)["image"]), Augmentor( dict_key="image", augment_fn=lambda x: torch.tensor(x).permute(2, 0, 1)), ] # ---- Data ----
# In[ ]: import collections import numpy as np import torch import torchvision import torchvision.transforms as transforms from catalyst.data.augmentor import Augmentor from catalyst.dl.utils import UtilsFactory bs = 1 num_workers = 0 data_transform = transforms.Compose([ Augmentor( dict_key="features", augment_fn=lambda x: \ torch.from_numpy(x.copy().astype(np.float32) / 255.).unsqueeze_(0)), Augmentor( dict_key="features", augment_fn=transforms.Normalize( (0.5, ), (0.5, ))), Augmentor( dict_key="targets", augment_fn=lambda x: \ torch.from_numpy(x.copy().astype(np.float32) / 255.).unsqueeze_(0)) ]) open_fn = lambda x: {"features": x[0], "targets": x[1]} loaders = collections.OrderedDict()
def create_dataloders(
    train_file: str,
    valid_file: str,
    root_folder: str,
    meta_info_file: str,
    num_classes: int,
    one_hot_encoding: bool,
    bs: int,
    num_workers: int,
    augmenters: Dict = None,
):
    """Build an ordered dict with "train" and "valid" audio dataloaders.

    Args:
        train_file: manifest with the training split, resolved via ``_prepare``.
        valid_file: manifest with the validation split.
        root_folder: root directory for the files in the manifests.
        meta_info_file: extra metadata file — currently unused in this
            function; kept for call-site compatibility. TODO confirm whether
            it should feed ``_prepare``.
        num_classes: number of target classes (used for one-hot targets).
        one_hot_encoding: if True, additionally emit a "targets_one_hot" key.
        bs: train-loader batch size.
        num_workers: train-loader worker count (the valid loader
            deliberately keeps the original single worker).
        augmenters: dict with "train" and "valid" callables of the form
            ``fn(samples=..., sample_rate=...)``.

    Returns:
        OrderedDict with "train" and "valid" dataloaders.

    Raises:
        ValueError: if ``augmenters`` is missing or lacks the required keys.
    """
    # BUGFIX: the default ``augmenters=None`` used to be subscripted
    # immediately, crashing with an opaque TypeError — fail fast with a
    # clear message instead.
    if augmenters is None or not {"train", "valid"} <= set(augmenters):
        raise ValueError(
            "augmenters must be a dict with 'train' and 'valid' callables")

    train_data = _prepare(train_file, root_folder)
    valid_data = _prepare(valid_file, root_folder)

    train_augmenter = augmenters['train']
    valid_augmenter = augmenters['valid']

    # Both pipelines operate on the raw waveform under "features";
    # sample rate is fixed at 16 kHz.
    train_transforms_fn = transforms.Compose([
        Augmentor(
            dict_key="features",
            augment_fn=lambda x: train_augmenter(samples=x, sample_rate=16000))
    ])
    # Validation applies only its (typically lighter/deterministic) augmenter.
    valid_transforms_fn = transforms.Compose([
        Augmentor(
            dict_key="features",
            augment_fn=lambda x: valid_augmenter(samples=x, sample_rate=16000))
    ])

    # Readers turn a manifest row into the sample dict consumed above.
    compose = [
        AudioReader(
            input_key="filepath",
            output_key="features",
        ),
        ScalarReader(
            input_key="label",
            output_key="targets",
            default_value=-1,
            dtype=np.int64,
        ),
    ]
    if one_hot_encoding:
        compose.append(
            ScalarReader(
                input_key="label",
                output_key="targets_one_hot",
                default_value=-1,
                dtype=np.int64,
                one_hot_classes=num_classes,
            ))
    open_fn = ReaderCompose(compose)

    train_loader = catalyst_utils.get_loader(
        train_data,
        open_fn=open_fn,
        dict_transform=train_transforms_fn,
        batch_size=bs,
        num_workers=num_workers,
        # Shuffle only when no Sampler is specified (PyTorch requirement).
        shuffle=True,
    )
    valid_loader = catalyst_utils.get_loader(
        valid_data,
        open_fn=open_fn,
        dict_transform=valid_transforms_fn,
        batch_size=bs,
        num_workers=1,
        shuffle=False,
    )

    loaders = OrderedDict()
    loaders["train"] = train_loader
    loaders["valid"] = valid_loader
    return loaders