def transforms_imagenet_eval(img_size=224,
                             crop_pct=None,
                             interpolation='bilinear',
                             use_prefetcher=False,
                             mean=IMAGENET_DEFAULT_MEAN,
                             std=IMAGENET_DEFAULT_STD):
    crop_pct = crop_pct or DEFAULT_CROP_PCT

    if isinstance(img_size, (tuple, list)):
        assert len(img_size) == 2
        if img_size[-1] == img_size[-2]:
            # fall-back to older behaviour so Resize scales to shortest edge if target is square
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple([int(x / crop_pct) for x in img_size])
    else:
        scale_size = int(math.floor(img_size / crop_pct))

    tfl = [
        transforms.Resize(scale_size, _pil_interp(interpolation)),
        transforms.CenterCrop(img_size),
    ]
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(mean=torch.tensor(mean),
                                 std=torch.tensor(std))
        ]

    return transforms.Compose(tfl)
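

# Usage sketch (illustrative, not part of the factory): assuming the usual
# DEFAULT_CROP_PCT of 0.875, a 224px eval pipeline first resizes the short
# edge to int(224 / 0.875) = 256, then center-crops to 224:
#
#   eval_tf = transforms_imagenet_eval(img_size=224)
#   x = eval_tf(pil_img)  # normalized float tensor of shape (3, 224, 224)
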
def transforms_noaug_train(
    img_size=224,
    interpolation='bilinear',
    use_prefetcher=False,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
):
    if interpolation == 'random':
        # random interpolation not supported with no-aug
        interpolation = 'bilinear'
    tfl = [
        transforms.Resize(img_size,
                          interpolation=str_to_interp_mode(interpolation)),
        transforms.CenterCrop(img_size)
    ]
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(mean=torch.tensor(mean),
                                 std=torch.tensor(std))
        ]
    return transforms.Compose(tfl)
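

# Usage sketch (illustrative): the no-aug pipeline is deterministic (resize +
# center crop only), e.g. for training runs that deliberately skip
# augmentation:
#
#   noaug_tf = transforms_noaug_train(img_size=224)
#   x = noaug_tf(pil_img)  # short edge resized to 224, center-cropped to 224
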
def transforms_imagenet_train(
    img_size=224,
    scale=(0.08, 1.0),
    color_jitter=0.4,
    auto_augment=None,
    interpolation='random',
    use_prefetcher=False,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    re_prob=0.,
    re_mode='const',
    re_count=1,
    re_num_splits=0,
    separate=False,
    squish=False,
    do_8_rotations=False,
):
    """
    If separate==True, the transforms are returned as a tuple of 3 separate transforms
    for use in a mixing dataset that passes
     * all data through the first (primary) transform, called the 'clean' data
     * a portion of the data through the secondary transform
     * normalizes and converts the branches above with the third, final transform
    """
    if squish:
        # squish: resize both edges directly to the target size, ignoring
        # the original aspect ratio
        if not isinstance(img_size, tuple):
            img_size = (img_size, img_size)
        resize = transforms.Resize(img_size, _pil_interp('bilinear'))
    else:
        resize = RandomResizedCropAndInterpolation(img_size,
                                                   scale=scale,
                                                   interpolation=interpolation)

    if do_8_rotations:
        # RandomRotation is assumed to be a project-local transform that samples
        # one of 8 discrete orientations (torchvision's RandomRotation requires
        # a degrees argument)
        primary_tfl = [resize, RandomRotation()]
    else:
        primary_tfl = [resize, transforms.RandomHorizontalFlip()]

    secondary_tfl = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, tuple):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        if auto_augment.startswith('rand'):
            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            secondary_tfl += [
                augment_and_mix_transform(auto_augment, aa_params)
            ]
        else:
            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter), ) * 3
        secondary_tfl += [transforms.ColorJitter(*color_jitter)]

    final_tfl = []
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        final_tfl += [ToNumpy()]
    else:
        final_tfl += [
            transforms.ToTensor(),
            transforms.Normalize(mean=torch.tensor(mean),
                                 std=torch.tensor(std))
        ]
        if re_prob > 0.:
            final_tfl.append(
                RandomErasing(re_prob,
                              mode=re_mode,
                              max_count=re_count,
                              num_splits=re_num_splits,
                              device='cpu'))

    if separate:
        return (transforms.Compose(primary_tfl),
                transforms.Compose(secondary_tfl),
                transforms.Compose(final_tfl))
    else:
        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
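

if __name__ == '__main__':
    # Smoke test (a minimal sketch, not part of the factory API): exercises the
    # three factories on a random dummy image. numpy and PIL are assumed
    # available; the RandAugment spec string is one common timm-style example.
    import numpy as np
    from PIL import Image

    img = Image.fromarray(
        np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))

    eval_tf = transforms_imagenet_eval(img_size=224, crop_pct=0.875)
    noaug_tf = transforms_noaug_train(img_size=224)
    train_tf = transforms_imagenet_train(
        img_size=224, auto_augment='rand-m9-mstd0.5')

    for name, tf in [('eval', eval_tf), ('noaug', noaug_tf),
                     ('train', train_tf)]:
        x = tf(img)
        print(name, tuple(x.shape))  # expect (3, 224, 224) for each

    # separate=True returns (primary, secondary, final) for mixing datasets:
    # the 'clean' branch skips the secondary (augmentation) stage
    primary, secondary, final = transforms_imagenet_train(
        img_size=224, separate=True)
    clean = final(primary(img))                 # 'clean' branch
    augmented = final(secondary(primary(img)))  # augmented branch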