示例#1
0
    def get_transforms(stage: str = None,
                       mode: str = None,
                       image_size=224,
                       one_hot_classes=None):
        """Build the per-sample transform object for ``mode``.

        For ``"train"`` and ``"valid"`` the result is a ``MixinAdapter`` that
        combines a ``RotateMixin`` with separate pre- and post-transform
        ``Augmentor`` objects. For ``"infer"`` it is a single ``Augmentor``
        that runs the pre- and post-transforms as one composed pipeline.
        All ``Augmentor`` objects are keyed on ``"image"``.

        Args:
            stage: accepted but not used by this function.
            mode: ``"train"``, ``"valid"`` or ``"infer"``.
            image_size: forwarded to ``pre_transforms``.
            one_hot_classes: forwarded to ``RotateMixin``.

        Returns:
            A ``MixinAdapter`` (train/valid) or an ``Augmentor`` (infer).

        Raises:
            NotImplementedError: if ``mode`` is none of the values above.
        """
        # Built before the mode is validated, exactly like the first
        # statement of the function has always been.
        first_stage = pre_transforms(image_size=image_size)

        if mode == "train":
            # Training additionally runs hard_transform() ahead of the
            # post-transforms.
            second_stage = Compose([hard_transform(), post_transforms()])
        elif mode in ("valid", "infer"):
            second_stage = post_transforms()
        else:
            raise NotImplementedError()

        if mode == "infer":
            # No mixin at inference time: both stages run back to back.
            pipeline = Compose([first_stage, second_stage])

            def run_pipeline(image):
                return pipeline(image=image)["image"]

            return Augmentor(dict_key="image", augment_fn=run_pipeline)

        def run_first_stage(image):
            return first_stage(image=image)["image"]

        def run_second_stage(image):
            return second_stage(image=image)["image"]

        # Only "train" and "valid" can reach this point.
        rotate_mixin = RotateMixin(input_key="image",
                                   output_key="rotation_factor",
                                   targets_key="targets",
                                   one_hot_classes=one_hot_classes)
        before_mixin = Augmentor(dict_key="image", augment_fn=run_first_stage)
        after_mixin = Augmentor(dict_key="image", augment_fn=run_second_stage)
        return MixinAdapter(mixin=rotate_mixin,
                            pre_transforms=before_mixin,
                            post_transforms=after_mixin)
示例#2
0
    def get_transforms(stage: str = None,
                       mode: str = None,
                       input_size: int = 224):
        """Return the ``Augmentor`` applied to the ``"features"`` entry.

        Two pipelines are always built. The training one runs the random
        augmentations listed below followed by the shared
        resize / normalize / to-tensor steps; the validation one runs only
        the shared steps. In both cases the output of the ``GroupTransform``
        is passed through ``TorchStack()``.

        Args:
            stage: accepted but not used by this function.
            mode: ``"train"`` selects the training pipeline; every other
                value (including ``None``) selects the validation one.
            input_size: height and width given to ``Resize``.

        Returns:
            The ``Augmentor`` for the selected pipeline.
        """
        random_augmentations = [
            OpticalDistortion(distort_limit=0.3, p=0.3),
            JpegCompression(quality_lower=50, p=0.8),
            HorizontalFlip(p=0.5),
            MotionBlur(p=0.5),
            ShiftScaleRotate(
                shift_limit=0.1, scale_limit=0.2, rotate_limit=20, p=0.5),
            RandomBrightnessContrast(
                brightness_limit=0.3, contrast_limit=0.2, p=0.4),
            HueSaturationValue(
                hue_shift_limit=3, sat_shift_limit=20, val_shift_limit=30,
                p=0.4),
            CLAHE(clip_limit=2, p=0.3),
        ]
        shared_steps = [
            Resize(input_size, input_size),
            Normalize(),
            ToTorchTensor(p=1.0),
        ]
        stack = TorchStack()

        train_group = GroupTransform(
            transforms=random_augmentations + shared_steps)
        valid_group = GroupTransform(transforms=shared_steps)

        def stacked(group_fn):
            # Returns a callable that applies ``group_fn`` and then stacks
            # its output with the shared TorchStack instance.
            def augment(images):
                return stack(group_fn(images))

            return augment

        by_mode = {
            "train": Augmentor(dict_key="features",
                               augment_fn=stacked(train_group)),
            "valid": Augmentor(dict_key="features",
                               augment_fn=stacked(valid_group)),
        }
        return by_mode["train"] if mode == "train" else by_mode["valid"]
示例#3
0
    def _get_transform(**params) -> Callable:
        """Recursively build a transform from keyword parameters.

        Two layouts are handled:

        * ``_key_value`` truthy: every remaining keyword maps a key name to
          the parameters of a nested transform. Each nested transform is
          built recursively, wrapped in an ``Augmentor`` that uses that key
          as ``dict_key``, ``input_key`` and ``output_key``, and the
          resulting mapping is handed to ``AugmentorCompose``.
        * otherwise: if a ``"transforms"`` entry is present, each of its
          items is built recursively and the entry is replaced by the built
          objects; the parameters are then passed to
          ``TRANSFORMS.get_from_params``.

        Args:
            **params: transform description. ``_key_value`` is removed from
                it before anything else is read.

        Returns:
            The constructed transform.
        """
        if params.pop("_key_value", False):
            # First build every nested transform, then wrap them, keeping
            # the two passes separate.
            built_by_key = {}
            for name, nested_params in params.items():
                built_by_key[name] = ConfigExperiment._get_transform(  # noqa: WPS437
                    **nested_params)

            wrapped_by_key = {}
            for name, built in built_by_key.items():
                wrapped_by_key[name] = Augmentor(
                    dict_key=name,
                    augment_fn=built,
                    input_key=name,
                    output_key=name,
                )
            return AugmentorCompose(wrapped_by_key)

        if "transforms" in params:
            nested = []
            for nested_params in params["transforms"]:
                nested.append(
                    ConfigExperiment._get_transform(  # noqa: WPS437
                        **nested_params))
            params.update(transforms=nested)

        return TRANSFORMS.get_from_params(**params)
示例#4
0
    def prepare_transforms(*, mode, stage=None, **kwargs):
        """Return the ``Augmentor`` for the ``"image"`` entry in ``mode``.

        Args:
            mode: ``"train"`` uses ``train_transform``; ``"valid"`` and
                ``"infer"`` use ``valid_transform``. Both are called with
                ``image_size=IMG_SIZE``.
            stage: accepted but not used by this function.
            **kwargs: accepted but not used by this function.

        Returns:
            An ``Augmentor`` keyed on ``"image"``.

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        if mode == "train":
            build_pipeline = train_transform
        elif mode in ("valid", "infer"):
            build_pipeline = valid_transform
        else:
            raise NotImplementedError

        pipeline = build_pipeline(image_size=IMG_SIZE)

        def augment(image):
            # The pipeline takes the image as a keyword argument and
            # returns a mapping; only its "image" entry is kept.
            return pipeline(image=image)["image"]

        return Augmentor(dict_key="image", augment_fn=augment)
def get_transforms():
    """Build the train / valid data transforms and a preview pipeline.

    Returns:
        A 3-tuple ``(train_data_transforms, valid_data_transforms,
        show_transforms)``. The first two are ``Augmentor`` objects keyed on
        ``"features"``; the third is the bare composed pipeline of
        ``pre_transforms()`` and ``hard_transforms()`` (no
        ``post_transforms()``), returned without an ``Augmentor`` wrapper.
    """
    train_pipeline = compose(
        [pre_transforms(), hard_transforms(), post_transforms()])
    # Validation skips hard_transforms().
    valid_pipeline = compose([pre_transforms(), post_transforms()])
    show_transforms = compose([pre_transforms(), hard_transforms()])

    def on_features(pipeline):
        # Wrap ``pipeline`` so it is applied to the "features" entry: the
        # value is passed as ``image=`` and the "image" entry of the result
        # is kept.
        return Augmentor(
            dict_key="features",
            augment_fn=lambda image: pipeline(image=image)["image"])

    train_data_transforms = on_features(train_pipeline)
    valid_data_transforms = on_features(valid_pipeline)
    return train_data_transforms, valid_data_transforms, show_transforms
示例#6
0
def get_transform(phase, img_size):
    """Build the ``Augmentor`` for the ``"features"`` entry of a sample.

    Args:
        phase: ``'train'`` or ``'valid'``. Training additionally inserts
            ``hard_transforms()`` between the pre- and post-transforms.
        img_size: forwarded to ``pre_transforms`` as ``image_size`` in both
            phases.

    Returns:
        An ``Augmentor`` keyed on ``"features"``.

    Raises:
        AssertionError: if ``phase`` is not ``'train'`` or ``'valid'``.
    """
    assert phase in {'train', 'valid'}, f'invalid phase: {phase}'

    # The validation branch used to call pre_transforms() with no arguments,
    # silently ignoring ``img_size``; both phases now receive the same size.
    steps = [pre_transforms(image_size=img_size)]
    if phase == 'train':
        steps.append(hard_transforms())
    steps.append(post_transforms())
    pipeline = compose(steps)

    data_transforms = Augmentor(
        dict_key="features",
        augment_fn=lambda x: pipeline(image=x)["image"])
    return data_transforms
示例#7
0
    def get_transforms(stage: str = None,
                       mode: str = None,
                       image_size: int = 224,
                       one_hot_classes: int = None):
        """Return the ``Augmentor`` applied to the ``"image"`` entry.

        The pipeline is ``pre_transforms(image_size=...)`` followed by a
        mode-dependent second stage: ``hard_transform(image_size=...)`` plus
        ``post_transforms()`` for ``"train"``, or ``post_transforms()`` alone
        for ``"valid"`` and ``"infer"``.

        Args:
            stage: accepted but not used by this function.
            mode: ``"train"``, ``"valid"`` or ``"infer"``.
            image_size: forwarded to ``pre_transforms`` and, in training,
                to ``hard_transform``.
            one_hot_classes: accepted but not used by this function.

        Returns:
            An ``Augmentor`` keyed on ``"image"``.

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        # The first stage is built before the mode is checked.
        stages = [pre_transforms(image_size=image_size)]

        if mode == "train":
            stages.append(
                Compose([
                    hard_transform(image_size=image_size),
                    post_transforms(),
                ]))
        elif mode in ("valid", "infer"):
            stages.append(post_transforms())
        else:
            raise NotImplementedError()

        pipeline = Compose(stages)

        def augment(image):
            return pipeline(image=image)["image"]

        return Augmentor(dict_key="image", augment_fn=augment)
        pre_transforms(config.size),
        hard_transforms(),
        post_transforms()
    ])
    valid_transforms = plant.compose([
        pre_transforms(config.size),
        post_transforms()
    ])

    show_transforms = plant.compose([
        pre_transforms(config.size),
        hard_transforms()
    ])

    train_data_transforms = Augmentor(
        dict_key="features",
        augment_fn=lambda x: train_transforms(image=x)["image"]
    )
    valid_data_transforms = Augmentor(
        dict_key="features",
        augment_fn=lambda x: valid_transforms(image=x)["image"]
    )

    loaders = plant.get_loaders(
        train_transforms_fn=train_data_transforms,
        valid_transforms_fn=valid_data_transforms,
        batch_size=config.batch_size,
        config=config
    )

    model = plant.get_model(config.model_name, config.num_classes)
示例#9
0
# Pair every image with its mask so the two stay aligned when the data is
# split below.
data = list(zip(images, masks))

# Everything before index n_images // 2 is used for training, the rest for
# validation (assumes n_images == len(data) -- TODO confirm where it is set).
train_data = data[:n_images // 2]
valid_data = data[n_images // 2:]

# Data loaders
# Augmentations; they are only included in the training transform below.
augmentations = [
    # TODO specify augmentations (e.g. histogram normalization)
    SegAugmentWrapper(RandomFlip(0.5),
                      image_key='features',
                      mask_key='targets')
]

# Conversions shared by both splits:
#   "features": copy, cast to float32, divide by 256, wrap as a tensor and
#               insert a new leading dimension.
#   "targets":  copy, wrap as a tensor, insert a new leading dimension and
#               cast to float.
# NOTE(review): the divisor is 256, not 255 -- confirm this is intentional.
transformations = [
    Augmentor(
        dict_key="features",
        augment_fn=lambda x: \
            torch.from_numpy(x.copy().astype(np.float32) / 256.).unsqueeze_(0).float()),
    Augmentor(
        dict_key="targets",
        augment_fn=lambda x: \
            torch.from_numpy(x.copy()).unsqueeze_(0).float()),
]

# Training runs the augmentations first and then the shared conversions;
# validation runs the shared conversions only.
train_data_transform = transforms.Compose(augmentations + transformations)
valid_data_transform = transforms.Compose(transformations)

# Turns one (image, mask) pair into the dict layout used by the transforms.
open_fn = lambda x: {"features": x[0], "targets": x[1]}

train_loader = UtilsFactory.create_loader(train_data,
                                          open_fn=open_fn,
                                          dict_transform=train_data_transform,
示例#10
0
            RandomContrast(),
            RandomBrightness(),
        ], p=0.5),
        HueSaturationValue(p=0.5),
    ], p=p)


# Training augmentation pipeline, built by strong_aug with p=0.75.
AUG_TRAIN = strong_aug(p=0.75)
# Inference pipeline: resize to IMG_SIZE x IMG_SIZE, then Normalize().
AUG_INFER = Compose([
    Resize(IMG_SIZE, IMG_SIZE),
    Normalize(),
])

# Transform list for training: applies AUG_TRAIN to the "image" entry.
# NOTE(review): unlike INFER_TRANSFORM_FN there is no tensor conversion /
# permute step here -- confirm it is handled elsewhere.
TRAIN_TRANSFORM_FN = [
    Augmentor(
        dict_key="image",
        augment_fn=lambda x: AUG_TRAIN(image=x)["image"]),
]

# Transform list for inference: applies AUG_INFER to the "image" entry, then
# converts the result to a tensor and moves its last axis to the front
# (presumably HWC -> CHW; verify against the model's expected layout).
INFER_TRANSFORM_FN = [
    Augmentor(
        dict_key="image",
        augment_fn=lambda x: AUG_INFER(image=x)["image"]),
    Augmentor(
        dict_key="image",
        augment_fn=lambda x: torch.tensor(x).permute(2, 0, 1)),
]


# ---- Data ----
示例#11
0
# In[ ]:

import collections
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from catalyst.data.augmentor import Augmentor
from catalyst.dl.utils import UtilsFactory

# Loader settings (presumably the batch size and the number of DataLoader
# workers; their use is outside this snippet).
bs = 1
num_workers = 0

# Per-sample transform applied to the dict produced by open_fn.
data_transform = transforms.Compose([
    # "features": copy, cast to float32, divide by 255, wrap as a tensor and
    # insert a new leading dimension.
    Augmentor(
        dict_key="features",
        augment_fn=lambda x: \
            torch.from_numpy(x.copy().astype(np.float32) / 255.).unsqueeze_(0)),
    # "features": torchvision Normalize with mean 0.5 and std 0.5.
    Augmentor(
        dict_key="features",
        augment_fn=transforms.Normalize(
            (0.5, ),
            (0.5, ))),
    # "targets": same conversion as the first "features" step, without the
    # Normalize step.
    Augmentor(
        dict_key="targets",
        augment_fn=lambda x: \
            torch.from_numpy(x.copy().astype(np.float32) / 255.).unsqueeze_(0))
])

# Turns one (features, targets) pair into the dict layout used above.
open_fn = lambda x: {"features": x[0], "targets": x[1]}

# Insertion-ordered mapping for the loaders; it is created empty here.
loaders = collections.OrderedDict()
示例#12
0
def create_dataloders(
    train_file: str,
    valid_file: str,
    root_folder: str,
    meta_info_file: str,
    num_classes: int,
    one_hot_encoding: bool,
    bs: int,
    num_workers: int,
    augmenters: Dict = None,
):
    """Create the ``"train"`` and ``"valid"`` loaders.

    Args:
        train_file: passed to ``_prepare`` to obtain the training data.
        valid_file: passed to ``_prepare`` to obtain the validation data.
        root_folder: passed to ``_prepare`` together with each file.
        meta_info_file: accepted but not used by this function.
        num_classes: ``one_hot_classes`` of the extra ``ScalarReader`` that
            is added when ``one_hot_encoding`` is true.
        one_hot_encoding: when true, the ``"label"`` field is additionally
            read into ``"targets_one_hot"``.
        bs: batch size of both loaders.
        num_workers: worker count of the training loader.
        augmenters: mapping with ``'train'`` and ``'valid'`` callables, each
            invoked as ``augmenter(samples=..., sample_rate=16000)`` on the
            ``"features"`` entry. ``None`` (the default) leaves the features
            unchanged; previously the default raised ``TypeError``.

    Returns:
        ``OrderedDict`` with the keys ``"train"`` and ``"valid"``, in that
        order.

    Raises:
        KeyError: if ``augmenters`` is given but lacks ``'train'`` or
            ``'valid'``.
    """

    def _passthrough(samples, sample_rate):
        # Stand-in augmenter used when no augmenters are supplied.
        return samples

    def _features_transform(augmenter):
        # Wrap an augmenter into the dict transform applied to "features".
        def augment(samples):
            return augmenter(samples=samples, sample_rate=16000)

        return transforms.Compose(
            [Augmentor(dict_key="features", augment_fn=augment)])

    train_data = _prepare(train_file, root_folder)
    valid_data = _prepare(valid_file, root_folder)

    if augmenters is None:
        train_augmenter = valid_augmenter = _passthrough
    else:
        train_augmenter = augmenters['train']
        valid_augmenter = augmenters['valid']

    train_transforms_fn = _features_transform(train_augmenter)
    # The validation split goes through its own augmenter in the same way.
    valid_transforms_fn = _features_transform(valid_augmenter)

    readers = [
        AudioReader(
            input_key="filepath",
            output_key="features",
        ),
        ScalarReader(
            input_key="label",
            output_key="targets",
            default_value=-1,
            dtype=np.int64,
        ),
    ]

    if one_hot_encoding:
        readers.append(
            ScalarReader(
                input_key="label",
                output_key="targets_one_hot",
                default_value=-1,
                dtype=np.int64,
                one_hot_classes=num_classes,
            ))

    open_fn = ReaderCompose(readers)

    train_loader = catalyst_utils.get_loader(
        train_data,
        open_fn=open_fn,
        dict_transform=train_transforms_fn,
        batch_size=bs,
        num_workers=num_workers,
        # shuffle data only if Sampler is not specified (PyTorch requirement)
        shuffle=True,
    )

    valid_loader = catalyst_utils.get_loader(
        valid_data,
        open_fn=open_fn,
        dict_transform=valid_transforms_fn,
        batch_size=bs,
        # NOTE(review): fixed at 1 rather than ``num_workers`` -- kept as is;
        # confirm whether this is intentional.
        num_workers=1,
        shuffle=False,
    )

    loaders = OrderedDict()
    loaders["train"] = train_loader
    loaders["valid"] = valid_loader

    return loaders