示例#1
0
def get_data(train_or_test, cifar_classnum):
    """Build the batched CIFAR dataflow for one split.

    Args:
        train_or_test: split name; 'train' enables random augmentation and
            ZMQ prefetching.
        cifar_classnum: 10 selects CIFAR-10; any other value selects CIFAR-100.

    Returns:
        A dataflow of batches of 128 images. The trailing partial batch is
        kept only for non-training splits.
    """
    is_train = train_or_test == 'train'
    cifar_cls = dataset.Cifar10 if cifar_classnum == 10 else dataset.Cifar100
    ds = cifar_cls(train_or_test)

    if not is_train:
        # Deterministic 30x30 center crop, then per-image normalization.
        augmentors = [
            imgaug.CenterCrop((30, 30)),
            imgaug.MeanVarianceNormalize(all_channel=True),
        ]
    else:
        deform_points = [(0.2, 0.2), (0.2, 0.8), (0.8, 0.8), (0.8, 0.2)]
        augmentors = [
            imgaug.RandomCrop((30, 30)),
            imgaug.Flip(horiz=True),
            imgaug.Brightness(63),
            imgaug.Contrast((0.2, 1.8)),
            imgaug.GaussianDeform(deform_points, (30, 30), 0.2, 3),
            imgaug.MeanVarianceNormalize(all_channel=True),
        ]

    batched = BatchData(AugmentImageComponent(ds, augmentors), 128,
                        remainder=not is_train)
    return PrefetchDataZMQ(batched, 5) if is_train else batched
示例#2
0
def get_data(train_or_test):
    """Build the batched CIFAR-100 dataflow with per-pixel mean subtraction.

    The mean image is computed from the 'train' split only. Training data is
    additionally padded to 40x40, randomly cropped back to 32x32 and flipped.

    Args:
        train_or_test: split name; 'train' enables augmentation.

    Returns:
        A dataflow of batches of ``BATCH_SIZE`` (module-level constant). The
        trailing partial batch is kept only for non-training splits.
    """
    is_train = train_or_test == 'train'
    ds = dataset.Cifar100(train_or_test)
    mean_image = ds.get_per_pixel_mean(('train', ))

    def subtract_mean(img):
        return img - mean_image

    if is_train:
        augmentors = [
            imgaug.CenterPaste((40, 40)),
            imgaug.RandomCrop((32, 32)),
            imgaug.Flip(horiz=True),
            imgaug.MapImage(subtract_mean),
        ]
    else:
        augmentors = [imgaug.MapImage(subtract_mean)]

    return BatchData(AugmentImageComponent(ds, augmentors), BATCH_SIZE,
                     remainder=not is_train)
def get_cifar_augmented_data(subset,
                             options,
                             do_multiprocess=True,
                             do_validation=False,
                             shuffle=None):
    """Build a batched CIFAR-10/100 dataflow, with cut-out augmentation for training.

    Args:
        subset: split name forwarded to the CIFAR dataset constructor.
        options: config object; reads ``num_classes``, ``ds_name``,
            ``batch_size`` and ``nr_gpu``.
        do_multiprocess: when true, wrap the result in ``PrefetchData``.
            Training-time augmentation is also enabled only when this is true.
        do_validation: forwarded to the dataset constructor.
        shuffle: forwarded to the dataset; ``None`` means "shuffle exactly
            when training augmentation is enabled".

    Returns:
        A dataflow of batches of size ``options.batch_size // options.nr_gpu``.
        The trailing partial batch is kept only when augmentation is off.

    Raises:
        ValueError: if ``(options.num_classes, options.ds_name)`` is neither
            ``(10, 'cifar10')`` nor ``(100, 'cifar100')``.
    """
    # Augmentation is tied to do_multiprocess: a 'train' subset read without
    # multiprocessing gets the same (non-augmented) pipeline as eval data.
    isTrain = subset == 'train' and do_multiprocess
    shuffle = shuffle if shuffle is not None else isTrain
    if options.num_classes == 10 and options.ds_name == 'cifar10':
        ds = dataset.Cifar10(subset,
                             shuffle=shuffle,
                             do_validation=do_validation)
        cutout_length = 16
        n_holes = 1
    elif options.num_classes == 100 and options.ds_name == 'cifar100':
        ds = dataset.Cifar100(subset,
                              shuffle=shuffle,
                              do_validation=do_validation)
        cutout_length = 8
        n_holes = 1
    else:
        # Both fields must agree; report the values actually received so a
        # ds_name/num_classes mismatch is not misreported as a class-count error.
        raise ValueError(
            'Number of classes must be set to 10(default) or 100 for CIFAR, '
            'with ds_name "cifar10" or "cifar100" respectively; '
            'got num_classes={}, ds_name={!r}'.format(
                options.num_classes, options.ds_name))
    logger.info('{} set has n_samples: {}'.format(subset, len(ds.data)))
    # NOTE(review): confirm which splits the default get_per_pixel_mean()
    # averages over for this dataset class; including held-out data in the
    # mean would leak evaluation statistics into training.
    pp_mean = ds.get_per_pixel_mean()
    if isTrain:
        logger.info('Will do cut-out with length={} n_holes={}'.format(
            cutout_length, n_holes))
        augmentors = [
            imgaug.CenterPaste((40, 40)),
            imgaug.RandomCrop((32, 32)),
            imgaug.Flip(horiz=True),
            imgaug.MapImage(lambda x: (x - pp_mean) / 128.0),
            Cutout(length=cutout_length, n_holes=n_holes),
        ]
    else:
        augmentors = [imgaug.MapImage(lambda x: (x - pp_mean) / 128.0)]
    ds = AugmentImageComponent(ds, augmentors)
    ds = BatchData(ds,
                   options.batch_size // options.nr_gpu,
                   remainder=not isTrain)
    if do_multiprocess:
        ds = PrefetchData(ds, 3, 2)
    return ds
示例#4
0
文件: test.py 项目: lilujunai/role-kd
def get_data_cifar(train_or_test, cifar_classnum, image_shape, batch_size):
    """Build a batched CIFAR dataflow using ImageNet-style augmentation.

    Args:
        train_or_test: split name; 'train' enables random augmentation and
            ZMQ prefetching.
        cifar_classnum: 10 selects CIFAR-10; any other value selects CIFAR-100.
        image_shape: side length of the square output image (int).
        batch_size: number of images per batch.

    Returns:
        A dataflow of batches. The trailing partial batch is kept only for
        non-training splits.
    """
    from tensorpack.dataflow import dataset
    isTrain = train_or_test == 'train'
    if cifar_classnum == 10:
        ds = dataset.Cifar10(train_or_test)
    else:
        ds = dataset.Cifar100(train_or_test)

    if isTrain:
        import numpy as np
        # NOTE(review): Saturation(rgb=False) and the reversed Lighting
        # constants assume BGR channel order; confirm the CIFAR dataflow
        # yields BGR rather than RGB images.
        augmentors = [
            GoogleNetResize(target_shape=image_shape),
            imgaug.RandomOrderAug([
                imgaug.BrightnessScale((0.6, 1.4), clip=False),
                imgaug.Contrast((0.6, 1.4), clip=False),
                imgaug.Saturation(0.4, rgb=False),
                # rgb-bgr conversion for the constants copied from fb.resnet.torch
                imgaug.Lighting(
                    0.1,
                    eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0,
                    eigvec=np.array([[-0.5675, 0.7192, 0.4009],
                                     [-0.5808, -0.0045, -0.8140],
                                     [-0.5836, -0.6948, 0.4203]],
                                    dtype='float32')[::-1, ::-1])
            ]),
            imgaug.Flip(horiz=True),
        ]
    else:
        import cv2  # only needed here, for the interpolation flag
        # A 32x32 target resizes the short edge to 40 before center-cropping;
        # every other target size uses 256.
        re_size = 40 if image_shape == 32 else 256
        augmentors = [
            imgaug.ResizeShortestEdge(re_size, cv2.INTER_CUBIC),
            imgaug.CenterCrop((image_shape, image_shape)),
        ]
    ds = AugmentImageComponent(ds, augmentors)
    ds = BatchData(ds, batch_size, remainder=not isTrain)
    if isTrain:
        ds = PrefetchDataZMQ(ds, 5)
    return ds
示例#5
0
def get_data(train_or_test, classnum):
    """Return batched raw CIFAR data for one split.

    No augmentation or normalization is applied; images are batched exactly
    as the dataset yields them.

    Args:
        train_or_test: split name; for 'train' the trailing partial batch is
            dropped.
        classnum: 10 selects CIFAR-10; any other value selects CIFAR-100.

    Returns:
        A dataflow of batches of 256 images.
    """
    is_train = train_or_test == 'train'
    cifar_cls = dataset.Cifar10 if classnum == 10 else dataset.Cifar100
    raw = cifar_cls(train_or_test)
    return BatchData(raw, 256, remainder=not is_train)
示例#6
0
 def get_data(self, train_or_test):
     """Build the batched CIFAR-100 dataflow, reading the dataset from '.'.

     Every split has the per-pixel mean subtracted. Training data is also
     padded to 40x40, randomly cropped to 32x32, flipped, and prefetched.

     Args:
         train_or_test: split name; 'train' enables augmentation.

     Returns:
         A dataflow of batches of ``self.batch_size``. The trailing partial
         batch is kept only for non-training splits.
     """
     is_train = train_or_test == 'train'
     ds = dataset.Cifar100(train_or_test, dir='.')
     pp_mean = ds.get_per_pixel_mean()

     def center(img):
         return img - pp_mean

     if not is_train:
         augmentors = [imgaug.MapImage(center)]
     else:
         augmentors = [
             imgaug.CenterPaste((40, 40)),
             imgaug.RandomCrop((32, 32)),
             imgaug.Flip(horiz=True),
             imgaug.MapImage(center),
         ]

     ds = BatchData(AugmentImageComponent(ds, augmentors), self.batch_size,
                    remainder=not is_train)
     if is_train:
         ds = PrefetchData(ds, 3, 2)
     return ds
示例#7
0
def get_data(train_or_test, args):
    """Build the batched CIFAR dataflow and report the raw dataset size.

    Args:
        train_or_test: split name; 'train' enables augmentation and ZMQ
            prefetching.
        args: config object; reads ``classnum`` (10 selects CIFAR-10, anything
            else CIFAR-100), ``model`` ('resnet' selects mean-subtraction
            preprocessing) and ``batch_size``.

    Returns:
        ``(ds, data_size)``: the batched dataflow and ``size()`` of the
        underlying un-augmented dataset.
    """
    is_train = train_or_test == 'train'
    cifar_cls = dataset.Cifar10 if args.classnum == 10 else dataset.Cifar100
    ds = cifar_cls(train_or_test)
    data_size = ds.size()
    pp_mean = ds.get_per_pixel_mean()

    if args.model == 'resnet':
        def subtract_mean(img):
            return img - pp_mean

        if is_train:
            augmentors = [
                imgaug.CenterPaste((40, 40)),
                imgaug.RandomCrop((32, 32)),
                imgaug.Flip(horiz=True),
                imgaug.MapImage(subtract_mean),
            ]
        else:
            augmentors = [imgaug.MapImage(subtract_mean)]
    elif is_train:
        augmentors = [
            imgaug.RandomCrop((30, 30)),
            imgaug.Flip(horiz=True),
            imgaug.Brightness(63),
            imgaug.Contrast((0.2, 1.8)),
            imgaug.MeanVarianceNormalize(all_channel=True),
        ]
    else:
        augmentors = [
            imgaug.CenterCrop((30, 30)),
            imgaug.MeanVarianceNormalize(all_channel=True),
        ]

    ds = BatchData(AugmentImageComponent(ds, augmentors), args.batch_size,
                   remainder=not is_train)
    if is_train:
        ds = PrefetchDataZMQ(ds, 5)
    return ds, data_size