# Code example #1
def imagenet_augment_transform(
    size: int = 224,
    scale: Optional[float] = None,
    ratio: Optional[float] = None,
    interpolation: str = "random",
    hflip: Union[float, bool] = 0.5,
    vflip: Union[float, bool] = False,
    color_jitter: Union[Sequence, float] = 0.4,
    auto_augment: Optional[str] = None,
    mean: Optional[Sequence[float]] = IMAGENET_DEFAULT_MEAN,
) -> T.Compose:
    """
    The default image transform with data augmentation. It is often useful
    for training models on ImageNet.

    Args:
        size: Target crop size (or min edge when a tuple/list is given).
        scale: Random-resized-crop area range; defaults to (0.08, 1.0).
        ratio: Random-resized-crop aspect-ratio range; defaults to (3/4, 4/3).
        interpolation: Interpolation name, or "random" to pick per sample.
        hflip: Horizontal-flip probability; falsy (0 / False) disables it.
        vflip: Vertical-flip probability; falsy (0 / False) disables it.
        color_jitter: Scalar (duplicated for brightness/contrast/saturation)
            or a 3/4-length sequence; ignored when ``auto_augment`` is set.
        auto_augment: AutoAugment/RandAugment policy string (e.g. "rand-m9").
        mean: Per-channel mean in [0, 1], used as the AA fill color.

    Returns:
        A ``T.Compose`` applying the configured augmentations in order.

    Adapted from: https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/transforms_factory.py
    """
    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
    ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0))  # default imagenet ratio range

    transforms = [
        RandomResizedCropAndInterpolation(size, scale, ratio, interpolation),
    ]

    # hflip/vflip may be bool per the signature; coerce to float so the
    # probability `p` handed to torchvision is always numeric.
    if hflip and hflip > 0:
        transforms.append(T.RandomHorizontalFlip(p=float(hflip)))
    if vflip and vflip > 0.0:
        transforms.append(T.RandomVerticalFlip(p=float(vflip)))

    if auto_augment:
        assert isinstance(auto_augment, str)
        # AA policies size their translations from the shorter image edge.
        if isinstance(size, (tuple, list)):
            size_min = min(size)
        else:
            size_min = size

        aa_params = dict(
            translate_const=int(size_min * 0.45),
            # Fill color: mean scaled to 0-255, clamped after rounding.
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )

        if interpolation and interpolation != "random":
            aa_params["interpolation"] = _pil_interp(interpolation)
        if auto_augment.startswith("rand"):
            transforms += [rand_augment_transform(auto_augment, aa_params)]
        else:
            transforms += [auto_augment_transform(auto_augment, aa_params)]

    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter), ) * 3
        transforms += [T.ColorJitter(*color_jitter)]
    return T.Compose(transforms)
def instantiate_transforms(cfg: DictConfig,
                           global_config: Optional[DictConfig] = None):
    """Instantiate a single transformation from its config node.

    For ``cfg._target_ == "aa"`` an AutoAugment-family transform is built from
    ``global_config.input`` (``input_size``, ``mean``, ``interpolation``);
    every other target is delegated to Hydra's ``instantiate``.

    Args:
        cfg: Config node describing one transform (``_target_`` selects it).
        global_config: Full config; only ``global_config.input`` is read, and
            only when the AA branch is taken.

    Returns:
        The instantiated transform callable.
    """
    if cfg._target_ == "aa":
        img_size_min = global_config.input.input_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            # Fill color: per-channel mean scaled to 0-255, clamped.
            img_mean=tuple(
                [min(255, round(255 * x)) for x in global_config.input.mean]),
        )

        # "random" interpolation is resolved inside the AA transform itself.
        if (global_config.input.interpolation
                and global_config.input.interpolation != "random"):
            aa_params["interpolation"] = _pil_interp(
                global_config.input.interpolation)

        # Load autoaugment transformations
        if cfg.policy.startswith("rand"):
            return rand_augment_transform(cfg.policy, aa_params)
        elif cfg.policy.startswith("augmix"):
            aa_params["translate_pct"] = 0.3
            return augment_and_mix_transform(cfg.policy, aa_params)
        else:
            return auto_augment_transform(cfg.policy, aa_params)

    else:
        return instantiate(cfg)
# Code example #3
    def build_train_transform(self,
                              image_size=None,
                              print_log=True,
                              auto_augment='rand-m9-mstd0.5'):
        """Build the training transform: random-resized crop, horizontal
        flip, RandAugment, tensor conversion, and normalization.

        Args:
            image_size: Target size; a list enables ``MyRandomResizedCrop``
                (elastic/multi-size training). Defaults to ``self.image_size``.
            print_log: Unused; kept for backward compatibility with callers.
            auto_augment: RandAugment policy string passed to
                ``rand_augment_transform``.

        Returns:
            A ``transforms.Compose`` pipeline for training images.
        """
        if image_size is None:
            image_size = self.image_size

        if isinstance(image_size, list):
            # Multi-size training: crop class samples a candidate size.
            resize_transform_class = MyRandomResizedCrop
            print(
                'Use MyRandomResizedCrop: %s, \t %s' %
                MyRandomResizedCrop.get_candidate_image_size(),
                'sync=%s, continuous=%s' %
                (MyRandomResizedCrop.SYNC_DISTRIBUTED,
                 MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        # NOTE(review): hard-coded per-channel mean used only for the
        # RandAugment fill color -- presumably this dataset's statistics;
        # confirm it matches `self.normalize` below.
        dataset_mean = [
            0.48933587508932375, 0.5183537408957618, 0.5387914411673883
        ]
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in dataset_mean]),
        )
        aa_params['interpolation'] = _pil_interp('bicubic')
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]

        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms
# Code example #4
def transforms_imagenet_train(
    img_size=224,
    scale=(0.08, 1.0),
    color_jitter=0.4,
    auto_augment=None,
    interpolation='random',
    use_prefetcher=False,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    re_prob=0.,
    re_mode='const',
    re_count=1,
    re_num_splits=0,
    separate=False,
    squish=False,
    do_8_rotations=False,
):
    """Build the ImageNet training transform pipeline.

    If separate==True, the transforms are returned as a tuple of 3 separate transforms
    for use in a mixing dataset that passes
     * all data through the first (primary) transform, called the 'clean' data
     * a portion of the data through the secondary transform
     * normalizes and converts the branches above with the third, final transform
    """
    # -- primary: geometric resize/crop plus flip or rotation ---------------
    if squish:
        # Fixed (h, w) resize that ignores aspect ratio.
        if not isinstance(img_size, tuple):
            img_size = (img_size, img_size)
        crop = transforms.Resize(img_size, _pil_interp('bilinear'))
    else:
        crop = RandomResizedCropAndInterpolation(img_size,
                                                 scale=scale,
                                                 interpolation=interpolation)

    spatial_aug = (RandomRotation()
                   if do_8_rotations else transforms.RandomHorizontalFlip())
    primary_tfl = [crop, spatial_aug]

    # -- secondary: AutoAugment family, or plain color jitter ---------------
    secondary_tfl = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        min_side = min(img_size) if isinstance(img_size, tuple) else img_size
        aa_params = {
            'translate_const': int(min_side * 0.45),
            'img_mean': tuple(min(255, round(255 * c)) for c in mean),
        }
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        if auto_augment.startswith('rand'):
            secondary_tfl.append(rand_augment_transform(auto_augment, aa_params))
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            secondary_tfl.append(
                augment_and_mix_transform(auto_augment, aa_params))
        else:
            secondary_tfl.append(auto_augment_transform(auto_augment, aa_params))
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter), ) * 3
        secondary_tfl.append(transforms.ColorJitter(*color_jitter))

    # -- final: tensor conversion, normalization, optional random erasing ---
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        final_tfl = [ToNumpy()]
    else:
        final_tfl = [
            transforms.ToTensor(),
            transforms.Normalize(mean=torch.tensor(mean),
                                 std=torch.tensor(std)),
        ]
        if re_prob > 0.:
            final_tfl.append(
                RandomErasing(re_prob,
                              mode=re_mode,
                              max_count=re_count,
                              num_splits=re_num_splits,
                              device='cpu'))

    if separate:
        return (transforms.Compose(primary_tfl),
                transforms.Compose(secondary_tfl),
                transforms.Compose(final_tfl))
    return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)