def create_transform(input_size,
                     is_training=False,
                     use_prefetcher=False,
                     color_jitter=0.4,
                     auto_augment=None,
                     interpolation='bilinear',
                     mean=IMAGENET_DEFAULT_MEAN,
                     std=IMAGENET_DEFAULT_STD,
                     re_prob=0.,
                     re_mode='const',
                     re_count=1,
                     re_num_splits=0,
                     crop_pct=None,
                     tf_preprocessing=False,
                     separate=False,
                     squish=False,
                     min_crop_factor=0.08,
                     do_8_rotations=False):
    # input_size may be a (chans, height, width) tuple or list; keep spatial dims only
    if isinstance(input_size, (tuple, list)):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if tf_preprocessing and use_prefetcher:
        assert not separate, "Separate transforms not supported for TF preprocessing"
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(is_training=is_training,
                                          size=img_size,
                                          interpolation=interpolation)
    else:
        if is_training:
            transform = transforms_imagenet_train(
                img_size,
                scale=(min_crop_factor, 1),
                color_jitter=color_jitter,
                auto_augment=auto_augment,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std,
                re_prob=re_prob,
                re_mode=re_mode,
                re_count=re_count,
                re_num_splits=re_num_splits,
                separate=separate,
                squish=squish,
                do_8_rotations=do_8_rotations)
        else:
            assert not separate, "Separate transforms not supported for validation preprocessing"
            transform = transforms_imagenet_eval(img_size,
                                                 interpolation=interpolation,
                                                 use_prefetcher=use_prefetcher,
                                                 mean=mean,
                                                 std=std,
                                                 crop_pct=crop_pct,
                                                 squish=squish)

    return transform
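
# Usage sketch for the variant defined directly above. It assumes the
# module-level helpers transforms_imagenet_train / transforms_imagenet_eval
# and the IMAGENET_DEFAULT_MEAN / IMAGENET_DEFAULT_STD constants exist in the
# same module (as in a fork of timm's transforms_factory.py); the parameter
# values below are illustrative only.
def _example_train_transform():
    # Build a training pipeline using this variant's extra options:
    # squish resizing, a raised minimum crop area, and 8-rotation augmentation.
    return create_transform(
        input_size=(3, 224, 224),
        is_training=True,
        re_prob=0.25,            # RandomErasing probability
        min_crop_factor=0.2,     # lower bound of the random-crop scale range
        squish=True,             # resize to target size, ignoring aspect ratio
        do_8_rotations=True)     # add the 8 discrete rotation/flip augmentations
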
def create_transform(
        input_size,
        is_training=False,
        use_prefetcher=False,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        crop_pct=None,
        tf_preprocessing=False,
        use_aug=False):

    if isinstance(input_size, (tuple, list)):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if tf_preprocessing and use_prefetcher:
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(
            is_training=is_training, size=img_size, interpolation=interpolation)
    else:
        if is_training:
            transform = transforms_imagenet_train(
                img_size,
                color_jitter=color_jitter,
                auto_augment=auto_augment,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std,
                use_aug=use_aug)
        else:
            transform = transforms_imagenet_eval(
                img_size,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std,
                crop_pct=crop_pct)

    return transform
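
# Usage sketch for the slimmer variant above, which only exposes color jitter,
# an auto-augment policy string, and a use_aug switch forwarded to
# transforms_imagenet_train (same module-level helper assumption as above);
# values are illustrative.
def _example_simple_transforms():
    train_tfm = create_transform(
        input_size=224,
        is_training=True,
        auto_augment='rand-m9-mstd0.5',  # RandAugment policy string
        use_aug=True)                    # enable the extra augmentation path
    eval_tfm = create_transform(
        input_size=224,
        is_training=False,
        crop_pct=0.875)                  # center-crop fraction at evaluation
    return train_tfm, eval_tfm
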
def create_loader(
    dataset,
    input_size,
    batch_size,
    is_training=False,
    use_prefetcher=True,
    rand_erase_prob=0.,
    rand_erase_mode='const',
    interpolation='bilinear',
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    num_workers=1,
    distributed=False,
    crop_pct=None,
    collate_fn=None,
    tf_preprocessing=False,
):
    if isinstance(input_size, (tuple, list)):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if tf_preprocessing and use_prefetcher:
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(is_training=is_training,
                                          size=img_size)
    else:
        if is_training:
            transform = transforms_imagenet_train(
                img_size,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std)
        else:
            transform = transforms_imagenet_eval(img_size,
                                                 interpolation=interpolation,
                                                 use_prefetcher=use_prefetcher,
                                                 mean=mean,
                                                 std=std,
                                                 crop_pct=crop_pct)

    dataset.transform = transform

    sampler = None
    if distributed:
        if is_training:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        drop_last=is_training,
    )
    if use_prefetcher:
        loader = PrefetchLoader(
            loader,
            rand_erase_prob=rand_erase_prob if is_training else 0.,
            rand_erase_mode=rand_erase_mode,
            mean=mean,
            std=std)

    return loader
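
# Usage sketch for create_loader (assumes torch, fast_collate, PrefetchLoader
# and OrderedDistributedSampler are importable in this module, as in timm's
# data/loader.py). The dataset can be any torch dataset whose .transform
# attribute is applied when samples are loaded; values are illustrative.
def _example_train_loader(dataset):
    return create_loader(
        dataset,
        input_size=(3, 224, 224),
        batch_size=128,
        is_training=True,
        use_prefetcher=True,     # normalization and random erasing run in PrefetchLoader
        rand_erase_prob=0.25,    # only applied on the training path
        num_workers=4,
        distributed=False)
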
def create_transform(input_size,
                     is_training=False,
                     use_prefetcher=False,
                     no_aug=False,
                     scale=None,
                     ratio=None,
                     hflip=0.5,
                     vflip=0.,
                     color_jitter=0.4,
                     auto_augment=None,
                     interpolation='bilinear',
                     mean=IMAGENET_DEFAULT_MEAN,
                     std=IMAGENET_DEFAULT_STD,
                     re_prob=0.,
                     re_mode='const',
                     re_count=1,
                     re_num_splits=0,
                     crop_pct=None,
                     tf_preprocessing=False,
                     separate=False):

    if isinstance(input_size, (tuple, list)):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if tf_preprocessing and use_prefetcher:
        assert not separate, "Separate transforms not supported for TF preprocessing"
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(is_training=is_training,
                                          size=img_size,
                                          interpolation=interpolation)
    else:
        if is_training and no_aug:
            assert not separate, "Cannot perform split augmentation with no_aug"
            transform = transforms_noaug_train(img_size,
                                               interpolation=interpolation,
                                               use_prefetcher=use_prefetcher,
                                               mean=mean,
                                               std=std)
        elif is_training:
            transform = transforms_imagenet_train(
                img_size,
                scale=scale,
                ratio=ratio,
                hflip=hflip,
                vflip=vflip,
                color_jitter=color_jitter,
                auto_augment=auto_augment,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std,
                re_prob=re_prob,
                re_mode=re_mode,
                re_count=re_count,
                re_num_splits=re_num_splits,
                separate=separate)
        else:
            assert not separate, "Separate transforms not supported for validation preprocessing"
            transform = transforms_imagenet_eval(img_size,
                                                 interpolation=interpolation,
                                                 use_prefetcher=use_prefetcher,
                                                 mean=mean,
                                                 std=std,
                                                 crop_pct=crop_pct)

    return transform
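
# Usage sketch for the full-signature variant above (same assumption about the
# module-level transform helpers). It shows the options the earlier variants do
# not expose: disabling augmentation, explicit crop scale/ratio ranges, and flip
# probabilities; values are illustrative.
def _example_full_transform():
    return create_transform(
        input_size=(3, 224, 224),
        is_training=True,
        no_aug=False,                  # set True to train with resize+crop only
        scale=(0.08, 1.0),             # RandomResizedCrop area range
        ratio=(3. / 4., 4. / 3.),      # RandomResizedCrop aspect-ratio range
        hflip=0.5,                     # horizontal flip probability
        vflip=0.0,                     # vertical flip probability
        re_prob=0.25,                  # RandomErasing probability
        re_mode='pixel')               # per-pixel random values for erasing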