transforms.RandomAffine(2, shear=2),
    transforms.RandomAffine(2, shear=2),
    transforms.RandomAffine(2, shear=2),
    #transforms.RandomCrop(224),
    transforms.RandomResizedCrop(150,
                                 scale=(0.25, 1.0),
                                 ratio=(0.9, 1.11),
                                 interpolation=2),
    transforms.ColorJitter(brightness=0.7,
                           contrast=0.7,
                           saturation=0.3,
                           hue=0.02),
    transforms.RandomGrayscale(p=0.75),
    QLoss(min_q=2, max_q=60),
    transforms.RandomChoice([
        transforms.RandomApply((Otsu(), ), p=0.1),
        transforms.RandomApply((Sauvola(2, 8), ), p=0.05)
    ]),
    transforms.ToTensor()
])

# Training split: images under 'lines/training', passed through the `trans`
# pipeline. The target transform is built from the dataset's own
# `class_to_idx` mapping, so it can only be attached after the ImageFolder
# object exists.
training = ImageFolder('lines/training', transform=trans)
training.target_transform = tgc.classMap.get_target_transform(
    training.class_to_idx)

# Validation split: no image transform is set here.
# NOTE(review): with transform=None the samples stay PIL images, which the
# default DataLoader collate cannot batch -- confirm a transform is assigned
# elsewhere before this split is iterated.
validation = ImageFolder('lines/validation', transform=None)
validation.target_transform = tgc.classMap.get_target_transform(
    validation.class_to_idx)
# Presumably the best validation score seen so far; starts at 0.
best_validation = 0

data_loader = torch.utils.data.DataLoader(training,
                                          batch_size=64,
Esempio n. 2
0
import warnings
from scipy import ndimage
import cv2
from driver import std, mean
from os.path import join
from commons.utils import visualize

# Colour-jitter candidates for RandomChoice: four jitters that each touch a
# single property, followed by three jitters that perturb brightness,
# contrast, saturation and hue by the same strength.
# NOTE(review): this RandomChoice is not assigned to any name, so the
# statement has no lasting effect -- confirm whether an assignment was lost.
transforms.RandomChoice(
    [
        transforms.ColorJitter(brightness=0.1),
        transforms.ColorJitter(contrast=0.2),
        transforms.ColorJitter(saturation=0.1),
        transforms.ColorJitter(hue=0.15),
    ] + [
        transforms.ColorJitter(brightness=strength,
                               contrast=strength,
                               saturation=strength,
                               hue=strength)
        for strength in (0.1, 0.15, 0.2)
    ])


def rescale_crop(image, scale, num, mode, output_size=224):
    image_list = []
    #TODO  for cls, it should be h,w.
    #TODO but for seg, it could be w,h
    h, w = image.size
Esempio n. 3
0
               depth=depth,
               latent_size=128,
               gpu=0,
               seed=27)

# Data pipeline: resize to 2**depth, random horizontal/vertical flips, and
# one fixed rotation per sample picked from 0, 90 or 270 degrees.
dataloader = ProGANDataLoader(
    data_path='maua/datasets/flower_pix2pix/B',
    prescaled_data=True,
    prescaled_data_path='maua/datasets/flowerGAN_prescaled',
    transforms=tn.Compose([
        tn.Resize(2**depth),
        tn.RandomHorizontalFlip(),
        tn.RandomVerticalFlip(),
        tn.RandomChoice([tn.RandomRotation([angle, angle])
                         for angle in (0, 90, 270)]),
    ]),
)

model.train(
    dataloader=dataloader,
    fade_in=0.75,
    save_freq=25,
    log_freq=5,
    loss="r1-reg",
    num_epochs=75,
)

# Generate one image from a random 1 x 128 latent vector scaled by 3;
# messing with the latent vector has a big effect on the output image.
result = model(3 * th.randn(1, 128))
save_image(result, 'maua/output/progan_flower1.png')
Esempio n. 4
0
import torchvision
from torchvision import transforms

image_transforms = {
    'train':
    transforms.Compose([
        transforms.RandomPerspective(distortion_scale=0.2,
                                     p=0.1,
                                     interpolation=3),
        transforms.RandomChoice([
            transforms.CenterCrop(180),
            transforms.CenterCrop(160),
            transforms.CenterCrop(140),
            transforms.CenterCrop(120),
            transforms.Compose(
                [transforms.CenterCrop(280),
                 transforms.Grayscale(3)]),
            transforms.Compose([
                transforms.CenterCrop(200),
                transforms.Grayscale(3),
            ]),
        ]),
        transforms.Resize((224, 224)),
        transforms.ColorJitter(hue=(0.1, 0.2)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid':
    transforms.Compose([
        transforms.RandomPerspective(distortion_scale=0.2,
                                     p=0.1,
Esempio n. 5
0
                if (len(Frames) == self.n):
                    break
        if self.transform is not None:
            for i in range(len(Frames)):
                Frames[i] = self.transform(Image.fromarray(Frames[i]))
        return torch.cat(Frames)


if __name__ == '__main__':

    opt = TrainOptions().parse()
    tfms = transforms.Compose([
        transforms.RandomChoice([
            transforms.Resize(opt.loadSize, interpolation=1),
            transforms.Resize(opt.loadSize, interpolation=2),
            transforms.Resize(opt.loadSize, interpolation=3),
            transforms.Resize((opt.loadSize, opt.loadSize), interpolation=1),
            transforms.Resize((opt.loadSize, opt.loadSize), interpolation=2),
            transforms.Resize((opt.loadSize, opt.loadSize), interpolation=3)
        ]),
        transforms.RandomChoice([
            transforms.RandomResizedCrop(opt.fineSize, interpolation=1),
            transforms.RandomResizedCrop(opt.fineSize, interpolation=2),
            transforms.RandomResizedCrop(opt.fineSize, interpolation=3)
        ]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    if ("video" not in opt.name and "synthetic" not in opt.name):
        opt.dataroot = './dataset/ilsvrc2012/%s/' % opt.phase
        dataset = torchvision.datasets.ImageFolder(opt.dataroot,
def main_worker(gpu, ngpus_per_node, args):
    """Train, validate and test a ``pretrainedmodels`` architecture on TinyImageNet.

    Every epoch runs one training pass over an augmented copy of the training
    set, then one over the un-augmented training set, then validates and runs
    the test loop. The best validation top-1 accuracy is kept in the
    module-level ``best_acc1``.

    Args:
        gpu: GPU index to train on, or ``None`` to wrap the model in
            ``torch.nn.DataParallel`` instead.
        ngpus_per_node: Not used in this function (presumably kept so the
            signature matches a per-process launcher).
        args: Parsed options; reads ``arch``, ``lr``, ``momentum``,
            ``weight_decay``, ``batch_size``, ``workers``, ``start_epoch`` and
            ``epochs`` (and sets ``args.gpu``).
    """
    global best_acc1

    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = pretrainedmodels.__dict__[args.arch](pretrained=False)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    elif args.arch.startswith(('alexnet', 'vgg')):
        # DataParallel divides each batch across the available GPUs; for these
        # architectures only the `features` sub-module is wrapped.
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # accelerate
    cudnn.benchmark = True

    # Per-channel mean/std passed to Normalize for every split.
    tiny_mean = [0.4802, 0.4481, 0.3975]
    tiny_std = [0.2302, 0.2265, 0.2262]

    def plain_transform():
        # Tensor conversion and normalisation only, no augmentation.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(tiny_mean, tiny_std),
        ])

    augment_transform = transforms.Compose([
        # RandomApply takes a *sequence* of transforms; passing the bare
        # RandomRotation (as before) fails when the pipeline is built or run.
        transforms.RandomApply([transforms.RandomRotation(20)]),
        transforms.RandomChoice([
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomVerticalFlip(0.5),
            transforms.RandomRotation(20)
        ]),
        transforms.RandomChoice([
            transforms.ColorJitter(brightness=0.1),
            transforms.ColorJitter(contrast=0.1),
            transforms.ColorJitter(saturation=0.1)
        ]),
        transforms.RandomGrayscale(0.5),
        transforms.ToTensor(),
        # RandomErasing operates on tensors, hence its place after ToTensor.
        transforms.RandomErasing(0.5),
        transforms.Normalize(tiny_mean, tiny_std)
    ])

    # set dataset
    train_dataset_transformed = TinyImageNetDataset(
        './TinyImageNet', './TinyImageNet/train.txt',
        transform=augment_transform)
    train_dataset_original = TinyImageNetDataset(
        './TinyImageNet', './TinyImageNet/train.txt',
        transform=plain_transform())
    val_dataset = TinyImageNetDataset(
        './TinyImageNet', './TinyImageNet/val.txt',
        transform=plain_transform())
    test_dataset = TinyImageNetDataset(
        './TinyImageNet', './TinyImageNet/test.txt',
        transform=plain_transform())

    train_sampler = None

    def make_loader(dataset, shuffle, sampler=None):
        # Common DataLoader settings shared by every split.
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=shuffle,
                                           num_workers=args.workers,
                                           pin_memory=True,
                                           sampler=sampler)

    train_loader_transformed = make_loader(train_dataset_transformed,
                                           shuffle=(train_sampler is None),
                                           sampler=train_sampler)
    train_loader_original = make_loader(train_dataset_original,
                                        shuffle=(train_sampler is None),
                                        sampler=train_sampler)
    val_loader = make_loader(val_dataset, shuffle=False)
    test_loader = make_loader(test_dataset, shuffle=False)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch: augmented pass, then un-augmented pass
        train(train_loader_transformed, model, criterion, optimizer, epoch,
              args)
        train(train_loader_original, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1
        best_acc1 = max(acc1, best_acc1)

        test(test_loader, model, args)
# Loss: label-smoothed cross-entropy (10 classes, epsilon 0.1) when the
# `using_label_smooth` flag equals the string 'True', otherwise plain
# cross-entropy.
criterion = (loss.CrossEntropyLabelSmooth(10, epsilon=0.1)
             if using_label_smooth == 'True' else nn.CrossEntropyLoss())

# Three metrics tracked during training: the best validation accuracy, the
# validation loss at that best accuracy, and the lowest validation loss.
metrics = dict(best_acc=10, best_acc_loss=100, best_loss=100)

train_transform = transforms.Compose([
    # Resize to 32 px larger than the target so the crops below have room.
    transforms.Resize((size + 32, size + 32)),
    # Per image, either a padded random crop or a random scale/aspect crop.
    transforms.RandomChoice([
        transforms.RandomCrop(size, padding=1, pad_if_needed=True,
                              padding_mode='edge'),
        transforms.RandomResizedCrop(size, scale=(resize_scale, 1.0),
                                     ratio=(0.8, 1.2)),
    ]),
    transforms.RandomHorizontalFlip(),
    auto_augment.AutoAugment(dataset='CIFAR'),
    transforms.ToTensor(),
    # RandomErasing works on tensors, so it follows ToTensor.
    transforms.RandomErasing(p=erasing_prob),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Validation: deterministic resize, tensor conversion and normalisation.
val_transform = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
Esempio n. 8
0
def return_data(args):
    """Build train/test DataLoaders over two ImageFolder directories.

    Both splits are resized to ``image_size`` x ``image_size``, converted to a
    single grayscale channel when ``args.channel == 1``, passed through a
    TimeWindow-style step, converted to tensors and normalised with mean 0.5 /
    std 0.5 per channel. The random steps selected by
    ``trivial_augmentation`` and ``sliding_augmentation`` are applied to the
    training split only; the test split always uses the deterministic
    ``TimeWindow`` so evaluation is not randomised.

    Args:
        args: Namespace providing ``train_dset_dir``, ``test_dset_dir``,
            ``batch_size``, ``num_workers``, ``image_size``, ``time_window``,
            ``trivial_augmentation``, ``sliding_augmentation`` and ``channel``.

    Returns:
        dict with ``'train'`` and ``'test'`` DataLoaders.
    """
    # TODO: cnn_datasets return_data
    batch_size = args.batch_size
    num_workers = args.num_workers
    image_size = args.image_size
    time_window = args.time_window
    channel = args.channel
    trivial_augmentation = bool(args.trivial_augmentation)
    sliding_augmentation = bool(args.sliding_augmentation)

    def build_transform_list(augment):
        # Transform steps for one split; `augment` enables the random steps.
        steps = [transforms.Resize((image_size, image_size))]
        if channel == 1:
            steps.append(transforms.Grayscale(num_output_channels=1))
        if augment and trivial_augmentation:
            steps.append(transforms.RandomChoice([
                transforms.ColorJitter(brightness=0.4,
                                       contrast=0.4,
                                       saturation=0.4,
                                       hue=0.1),
                transforms.RandomResizedCrop(image_size,
                                             scale=(0.8, 1.0),
                                             ratio=(1, 1)),
                RandomNoise(mean=0, std=10),
            ]))
        if augment and sliding_augmentation:
            steps.append(RandomTimeWindow(time_window=time_window))
        else:
            steps.append(TimeWindow(time_window=time_window))
        steps.append(transforms.ToTensor())
        # [0.5] * 1 == [0.5], so one expression covers every channel count.
        steps.append(transforms.Normalize([0.5] * channel, [0.5] * channel))
        return steps

    train_transform_list = build_transform_list(augment=True)
    print(train_transform_list)
    train_transform = transforms.Compose(train_transform_list)
    test_transform = transforms.Compose(build_transform_list(augment=False))

    train_data = ImageFolder(root=Path(args.train_dset_dir),
                             transform=train_transform)
    test_data = ImageFolder(root=Path(args.test_dset_dir),
                            transform=test_transform)

    train_loader = DataLoader(train_data,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True,
                              drop_last=True)
    # NOTE(review): drop_last=True discards the final partial test batch (up
    # to 251 images are never evaluated); kept as-is in case downstream code
    # needs fixed-size batches -- confirm.
    test_loader = DataLoader(test_data,
                             batch_size=252,
                             shuffle=True,
                             num_workers=num_workers,
                             pin_memory=True,
                             drop_last=True)

    return {'train': train_loader, 'test': test_loader}
Esempio n. 9
0
        else:
            img = Image.open(self.df[index]).convert('RGB')
            img = self.transform(img)
            return img, torch.from_numpy(np.array(0))

    def __len__(self):
        """Return the number of samples: the length of ``self.df``."""
        sample_count = len(self.df)
        return sample_count


# Per-channel mean/std normalisation (the commonly used ImageNet values),
# shared by the train and validation pipelines below.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# Training pipeline: resize to 299x299, random rotation within +/-15 degrees,
# then either a full resize or a centre crop to 256x256, brightness/contrast
# jitter, random horizontal flip, tensor conversion and normalisation.
# NOTE(review): training images come out 256x256 while val_transform resizes
# to config.img_size -- confirm the two sizes are meant to agree.
train_transform = transforms.Compose([
    transforms.Resize([299, 299]),
    transforms.RandomRotation(15),
    transforms.RandomChoice(
        [transforms.Resize([256, 256]),
         transforms.CenterCrop([256, 256])]),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
    normalize,
])

# Validation pipeline: deterministic square resize to config.img_size,
# tensor conversion and normalisation; no random augmentation.
val_transform = transforms.Compose([
    transforms.Resize([config.img_size, config.img_size]),
    transforms.ToTensor(),
    normalize,
])

test_transform = transforms.Compose([
    transforms.Resize([config.img_size, config.img_size]),
Esempio n. 10
0
    def __getitem__(self, index):
        """Load one facial-part image together with its label masks.

        Each entry of ``self.path_list`` is served five times: ``index // 5``
        selects the entry and ``index % 5`` selects which augmentation list in
        ``self.choice.augmentation`` feeds ``RandomChoice``.

        The part name picks which ``<path>_lblNN.png`` masks are read (as
        grayscale). Parts ``eye2`` and ``eyebrow2`` are flipped left-to-right,
        image and masks alike. The masks are stacked under a background
        channel computed as ``255 - sum(masks)`` and cast to uint8.

        Returns:
            The sample dict (``'image'``, ``'label'``, ``'index'``, ``'part'``
            where ``'part'`` is ``self.face_part``). When ``self.choice`` is set
            the dict returned by the transform pipeline is used instead, with
            label channel 0 recomputed as ``1 - sum`` of the other channels.
        """
        # Label ids to load for each part, and whether the part is mirrored.
        part_labels = {
            "eye1": ((4,), False),
            "eye2": ((5,), True),
            "eyebrow1": ((2,), False),
            "eyebrow2": ((3,), True),
            "nose": ((6,), False),
            "mouth": ((7, 8, 9), False),
        }

        entry = index // 5
        path = self.path_list[entry][0]
        part = self.path_list[entry][1]
        image_file = facail_part_path + '/' + part + '/images/' + path + ".jpg"
        if self.choice is not None:
            image_trans = transforms.Compose([
                ToPILImage(),
                transforms.RandomChoice(self.choice.augmentation[index % 5]),
                Resize(size=(self.choicesize, self.choicesize),
                       interpolation=Image.NEAREST),
                ToTensor(),
                LabelUpdata(self.choicesize)
            ])
        else:
            image_trans = None

        # Unknown parts load no masks (only the background channel remains).
        label_ids, mirrored = part_labels.get(part, ((), False))

        image = Image.open(image_file)
        if mirrored:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        label_dir = facail_part_path + '/' + part + '/labels/' + path + '/'
        label = []
        for label_id in label_ids:
            mask = Image.open(
                f"{label_dir}{path}_lbl{label_id:02d}.png").convert('L')
            if mirrored:
                mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
            label.append(np.array(mask))

        bg = 255 - np.sum(label, axis=0, keepdims=True)
        labels = np.uint8(np.concatenate([bg, label], axis=0))  # [L + 1, 64, 64]

        sample = {
            'image': image,
            'label': labels,
            'index': index,
            'part': self.face_part
        }
        if image_trans is not None:
            sample = image_trans(sample)
            # Rebuild the background channel from the transformed masks.
            sample['label'][0] = 1 - torch.sum(
                sample['label'][1:], dim=0, keepdim=True)
        return sample
Esempio n. 11
0
def get_data(dataset='MNIST', train_augment_angle='0'):
    if train_augment_angle == 'all':
        transform_augment = [
            transforms.RandomAffine(degrees=360, translate=None)
        ]
    elif train_augment_angle == '0':
        transform_augment = [
            transforms.RandomAffine(degrees=0, translate=None)
        ]
    else:
        augment_angle = int(train_augment_angle)
        transform_augment = [
            transforms.RandomChoice([
                transforms.RandomAffine(degrees=(augment_angle * i,
                                                 augment_angle * i),
                                        translate=None)
                for i in range(int(360 / augment_angle))
            ])
        ]

    if dataset == 'MNIST':
        transform_augment = [*transform_augment]
        transform_normalize = [transforms.Normalize((0.1307, ), (0.3081, ))]
        transform_train = transforms.Compose(
            [*transform_augment,
             transforms.ToTensor(), *transform_normalize])
        transform_val = transforms.Compose(
            [transforms.ToTensor(), *transform_normalize])

        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('./../data',
                           train=True,
                           download=True,
                           transform=transform_train),
            batch_size=BATCH_SIZE,
            shuffle=True)  # , num_workers=4)
        validation_loader = torch.utils.data.DataLoader(datasets.MNIST(
            './../data', train=False, transform=transform_val),
                                                        batch_size=BATCH_SIZE,
                                                        shuffle=False)

    elif dataset == 'CIFAR10':
        transform_augment = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(), *transform_augment
        ]
        transform_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                   (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose(
            [*transform_augment,
             transforms.ToTensor(), transform_normalize])
        transform_val = transforms.Compose(
            [transforms.ToTensor(), transform_normalize])

        train_loader = torch.utils.data.DataLoader(
            datasets.cifar.CIFAR10('./../data',
                                   train=True,
                                   download=True,
                                   transform=transform_train),
            batch_size=BATCH_SIZE,
            shuffle=True)  # , num_workers=4)
        validation_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './../data', train=False, transform=transform_val),
                                                        batch_size=BATCH_SIZE,
                                                        shuffle=False)
    else:
        raise NotImplementedError
    return train_loader, validation_loader
Esempio n. 12
0
    def __getitem__(self, index):
        """Load one facial-part image and its stacked label masks.

        The part name picks which ``<path>_lblNN.png`` masks are read (as
        grayscale, without mirroring). They are stacked under a background
        channel ``255 - sum(masks)`` and cast to uint8.

        When ``self.choice`` is set, the sample is run through a pipeline that
        applies one transform from ``self.choice``, resizes to
        ``self.choicesize`` and converts to tensors; label channel 0 is then
        recomputed as ``1 - sum`` of the other channels.

        Returns:
            ``(image, label)`` taken from the (possibly transformed) sample.
            Without ``self.choice`` these are the opened image and the uint8
            numpy mask stack.
        """
        # Label ids to load for each facial part.
        part_label_ids = {
            "eye1": (4,),
            "eye2": (5,),
            "eyebrow1": (2,),
            "eyebrow2": (3,),
            "nose": (6,),
            "mouth": (7, 8, 9),
        }

        path = self.path_list[index][0]
        part = self.path_list[index][1]
        image = Image.open(facail_part_path + '/' + part + '/images/' + path +
                           ".jpg")
        label_dir = facail_part_path + '/' + part + '/labels/' + path + '/'

        # Previously `image_trans` was only bound inside this `if`, so a
        # dataset without `choice` raised UnboundLocalError below.
        if self.choice is not None:
            image_trans = transforms.Compose([
                ToPILImage(),
                transforms.RandomChoice(self.choice),
                Resize(size=(self.choicesize, self.choicesize),
                       interpolation=Image.NEAREST),
                ToTensor(),
                LabelUpdata(self.choicesize)
            ])
        else:
            image_trans = None

        label = [
            np.array(
                Image.open(f"{label_dir}{path}_lbl{label_id:02d}.png").convert(
                    'L')) for label_id in part_label_ids.get(part, ())
        ]

        bg = 255 - np.sum(label, axis=0, keepdims=True)
        labels = np.uint8(np.concatenate([bg, label], axis=0))  # [L + 1, 64, 64]

        sample = {
            'image': image,
            'label': labels,
            'index': index,
            'part': self.face_part
        }

        if image_trans is not None:
            sample = image_trans(sample)
            # Rebuild the background channel from the transformed masks.
            sample['label'][0] = 1 - torch.sum(
                sample['label'][1:], dim=0, keepdim=True)
        return sample['image'], sample['label']
Esempio n. 13
0
    def __getitem__(self, index):
        """Load one face image and its label masks (lbl02 .. lbl09).

        Each entry of ``self.path_list`` is served five times: ``index // 5``
        selects the entry and ``index % 5`` selects the augmentation list in
        ``self.choice.augmentation``.

        The eight masks are read as grayscale, stacked under a background
        channel ``255 - sum(masks)``, cast to uint8 and converted to a list of
        PIL images. When ``self.choice`` is set, the transformed masks get
        their background channel recomputed as ``1 - sum`` of the others and
        are then hardened to one-hot: per pixel, the arg-max channel gets 255
        and every other channel 0.

        Returns:
            The sample dict (``'image'``, ``'label'``, ``'index'``), or the
            dict produced by the transform pipeline when ``self.choice`` is
            set.
        """
        if self.choice is not None:
            image_trans = transforms.Compose([
                transforms.RandomChoice(self.choice.augmentation[index % 5]),
                Resize(size=(self.choicesize, self.choicesize),
                       interpolation=Image.NEAREST),
                ToTensor(),
                LabelUpdata(self.choicesize)
            ])
        else:
            image_trans = None

        path = self.path_list[index // 5]
        image = Image.open(image_path + '/' + path + ".jpg")

        label_dir = label_path + '/' + path + '/'
        label = [
            np.array(
                Image.open(f"{label_dir}{path}_lbl{label_id:02d}.png").convert(
                    'L')) for label_id in range(2, 10)
        ]

        bg = 255 - np.sum(label, axis=0, keepdims=True)
        labels = np.uint8(np.concatenate([bg, label], axis=0))  # [L + 1, 64, 64]
        labels = [TF.to_pil_image(channel) for channel in labels]
        sample = {"image": image, "label": labels, "index": index}

        # The one-hot step needs tensor labels sized self.choicesize; without a
        # transform the labels are still PIL images and torch.argmax would
        # raise, so it only runs on transformed samples.
        if image_trans is not None:
            sample = image_trans(sample)
            sample['label'][0] = 1 - torch.sum(
                sample['label'][1:], dim=0, keepdim=True)
            lnum = len(sample['label'])
            winner = torch.argmax(sample['label'], dim=0,
                                  keepdim=False).unsqueeze(dim=0)  # [1,64,64]
            sample['label'] = torch.zeros(lnum, self.choicesize,
                                          self.choicesize).scatter_(
                                              0, winner, 255)  # [L,64,64]
        return sample
Esempio n. 14
0
    def train(self, params):
        """Train the classifier with cross-validation and return best losses.

        For each of the ``self.cross_validation`` folds yielded by
        ``validation_spliter``, the model is re-initialised with
        ``model.random_init(self.init_func)`` and trained for at most
        ``self.epochs`` epochs. Training stops early after 200 consecutive
        validation checks without improvement.

        Args:
            params: dict of hyper-parameters. Keys read here: 'mom_range',
                'n_res', 'niter', 'scheduler' ('ReduceLROnPlateau',
                'CycleScheduler', anything else = no scheduler), 'optimizer'
                ('adamw', 'sgd' or 'rmsprop'), 'momentum', 'learning_rate'
                and 'weight_decay'.

        Returns:
            list: the best validation loss of each fold (-1 when a fold never
            recorded a best epoch).
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        mom_range = params['mom_range']
        n_res = params['n_res']
        niter = params['niter']
        scheduler = params['scheduler']
        optimizer_type = params['optimizer']
        momentum = params['momentum']
        learning_rate = params['learning_rate'].__format__('e')
        weight_decay = params['weight_decay'].__format__('e')

        # Keep only the leading mantissa digit and the exponent of the
        # scientific notation (e.g. '3.162278e-04' -> '3e-04').
        # Assumes a two-digit exponent.
        weight_decay = float(str(weight_decay)[:1] + str(weight_decay)[-4:])
        learning_rate = float(str(learning_rate)[:1] + str(learning_rate)[-4:])
        if self.verbose > 1:
            print(
                "Parameters: \n\t",
                'zdim: ' + str(self.n_classes) + "\n\t",
                'mom_range: ' + str(mom_range) + "\n\t",
                'niter: ' + str(niter) + "\n\t",
                'nres: ' + str(n_res) + "\n\t",
                'learning_rate: ' + learning_rate.__format__('e') + "\n\t",
                'momentum: ' + str(momentum) + "\n\t",
                'weight_decay: ' + weight_decay.__format__('e') + "\n\t",
                'optimizer_type: ' + optimizer_type + "\n\t",
            )

        self.modelname = "classif_3dcnn_" \
                         + '_bn' + str(self.batchnorm) \
                         + '_niter' + str(niter) \
                         + '_nres' + str(n_res) \
                         + '_momrange' + str(mom_range) \
                         + '_momentum' + str(momentum) \
                         + '_' + str(optimizer_type) \
                         + "_nclasses" + str(self.n_classes) \
                         + '_initlr' + learning_rate.__format__('e') \
                         + '_wd' + weight_decay.__format__('e') \
                         + '_size' + str(self.size)
        # NOTE(review): n_neurons is not defined in this method; it must come
        # from an enclosing/module scope - confirm.
        model = MLP(n_neurons=n_neurons,
                    n_classes=self.n_classes,
                    activation=torch.nn.ReLU).to(device)
        criterion = nn.CrossEntropyLoss()
        if optimizer_type == 'adamw':
            optimizer = torch.optim.AdamW(params=model.parameters(),
                                          lr=learning_rate,
                                          weight_decay=weight_decay,
                                          amsgrad=True)
        elif optimizer_type == 'sgd':
            optimizer = torch.optim.SGD(params=model.parameters(),
                                        lr=learning_rate,
                                        weight_decay=weight_decay,
                                        momentum=momentum)
        elif optimizer_type == 'rmsprop':
            optimizer = torch.optim.RMSprop(params=model.parameters(),
                                            lr=learning_rate,
                                            weight_decay=weight_decay,
                                            momentum=momentum)
        else:
            exit('error: no such optimizer type available')

        train_transform = transforms.Compose([
            transforms.RandomChoice([XFlip(), YFlip(),
                                     ZFlip()]),
            transforms.RandomChoice([Flip90(), Flip180(),
                                     Flip270()]),
            ColorJitter3D(.2, .2, .2, .2),
            transforms.RandomChoice([
                RandomRotation3D(25, 0),
                RandomRotation3D(25, 1),
                RandomRotation3D(25, 2)
            ]),
            torchvision.transforms.Normalize(mean=(self.mean), std=(self.std)),
            Normalize()
        ])
        all_set = MRIDatasetClassifier(self.path,
                                       transform=train_transform,
                                       size=self.size)
        spliter = validation_spliter(all_set, cv=self.cross_validation)

        print("Training Started on device:", device)
        best_losses = []
        for cv in range(self.cross_validation):
            # NOTE(review): the model weights are re-initialised per fold but
            # the optimizer (and its state) is shared across folds.
            model.random_init(self.init_func)
            best_loss = -1
            valid_set, train_set = next(spliter)
            valid_set.transform = False
            train_loader = DataLoader(train_set,
                                      num_workers=0,
                                      shuffle=True,
                                      batch_size=self.batch_size,
                                      pin_memory=False,
                                      drop_last=True)
            valid_loader = DataLoader(valid_set,
                                      num_workers=0,
                                      shuffle=True,
                                      batch_size=2,
                                      pin_memory=False,
                                      drop_last=True)

            # Get shared output_directory ready
            logger = SummaryWriter('logs')

            if scheduler == 'ReduceLROnPlateau':
                lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer,
                    factor=0.1,
                    cooldown=50,
                    patience=50,
                    verbose=True,
                    min_lr=1e-15)
            elif scheduler == 'CycleScheduler':
                lr_schedule = CycleScheduler(optimizer,
                                             learning_rate,
                                             n_iter=niter * len(train_loader),
                                             momentum=[
                                                 max(0.0,
                                                     momentum - mom_range),
                                                 min(1.0,
                                                     momentum + mom_range),
                                             ])

            losses = {
                "train": [],
                "valid": [],
            }
            accuracies = {
                "train": [],
                "valid": [],
            }
            shapes = {
                "train": len(train_set),
                "valid": len(valid_set),
            }
            early_stop_counter = 0
            print("\n\n\nCV:", cv, "/", self.cross_validation,
                  "\nTrain samples:", len(train_set), "\nValid samples:",
                  len(valid_set), "\n\n\n")
            train_losses = []
            train_accuracy = []
            valid_losses = []
            valid_accuracy = []
            for epoch in range(self.epochs):
                if early_stop_counter == 200:
                    if self.verbose > 0:
                        print('EARLY STOPPING.')
                    break
                model.train()

                for i, batch in enumerate(train_loader):
                    model.zero_grad()
                    images, targets = batch

                    images = images.to(device)
                    targets = targets.to(device)

                    preds = model(images)

                    loss = criterion(preds, targets)

                    loss.backward()

                    accuracy = sum([
                        1 if torch.argmax(pred) == target else 0
                        for (pred, target) in zip(preds, targets)
                    ]) / len(targets)
                    train_accuracy += [accuracy]

                    train_losses += [loss.item()]

                    optimizer.step()
                    if scheduler == "CycleScheduler":
                        lr_schedule.step()
                    logger.add_scalar('training_loss', loss.item(),
                                      i + len(train_loader) * epoch)
                    del loss

                if epoch % self.epochs_per_print == 0:
                    losses["train"] += [np.mean(train_losses)]
                    accuracies["train"] += [np.mean(train_accuracy)]
                    if self.verbose > 1:
                        print("Epoch: {}:\t"
                              "Train Loss: {:.5f} , "
                              "Accuracy: {:.3f} , ".format(
                                  epoch, losses["train"][-1],
                                  accuracies["train"][-1]))
                    train_losses = []
                    train_accuracy = []

                model.eval()
                # No gradients are needed for validation; this avoids building
                # autograd graphs for every validation batch.
                with torch.no_grad():
                    for i, batch in enumerate(valid_loader):
                        images, targets = batch
                        images = images.to(device)
                        targets = targets.to(device)
                        preds = model(images)

                        loss = criterion(preds, targets)
                        valid_losses += [loss.item()]
                        accuracy = sum([
                            1 if torch.argmax(pred) == target else 0
                            for (pred, target) in zip(preds, targets)
                        ]) / len(targets)
                        valid_accuracy += [accuracy]
                        logger.add_scalar('validation_loss_log2',
                                          np.log2(loss.item()),
                                          i + len(valid_loader) * epoch)

                # Aggregate this epoch's validation metrics BEFORE the
                # scheduler step and best-loss check, so both act on the
                # current validation loss instead of the previous one.
                if epoch % self.epochs_per_print == 0:
                    losses["valid"] += [np.mean(valid_losses)]
                    accuracies["valid"] += [np.mean(valid_accuracy)]
                    if self.verbose > 0:
                        print("Epoch: {}:\t"
                              "Valid Loss: {:.5f} , "
                              "Accuracy: {:.3f} ".format(
                                  epoch,
                                  losses["valid"][-1],
                                  accuracies["valid"][-1],
                              ))
                    if self.verbose > 1:
                        print("Current LR:", optimizer.param_groups[0]['lr'])
                    if 'momentum' in optimizer.param_groups[0].keys():
                        print("Current Momentum:",
                              optimizer.param_groups[0]['momentum'])
                    valid_losses = []
                    valid_accuracy = []

                if scheduler == "ReduceLROnPlateau":
                    if epoch > 25:
                        lr_schedule.step(losses["valid"][-1])
                mode = 'valid'
                if epoch > 10 and epoch % self.epochs_per_print == 0:
                    if (losses[mode][-1] < best_loss or best_loss == -1) \
                            and not np.isnan(losses[mode][-1]):
                        if self.verbose > 1:
                            print('BEST EPOCH!', losses[mode][-1],
                                  accuracies[mode][-1])
                        early_stop_counter = 0
                        best_loss = losses[mode][-1]
                    else:
                        early_stop_counter += 1

                if self.plot_perform:
                    plot_performance(loss_total=losses,
                                     losses_recon=None,
                                     accuracies=accuracies,
                                     kl_divs=None,
                                     shapes=shapes,
                                     results_path="../figures",
                                     filename="training_loss_trace_" +
                                     self.modelname + '.jpg')
            # Release the event-file handle opened for this fold.
            logger.close()
            if self.verbose > 0:
                print('BEST LOSS :', best_loss)
            best_losses += [best_loss]
        return best_losses
Esempio n. 15
0
])

# Colour augmentation: a single ColorJitter followed by ToTensor.
# NOTE(review): brightness/contrast/saturation of 10 request very wide jitter
# ranges - confirm this strength is intended.
brightness = contrast = saturation = 10
hue = 0.25
color_transform = transforms.Compose([
    transforms.ColorJitter(brightness, contrast, saturation, hue),
    transforms.ToTensor(),
])

# Each sample goes through exactly one of these pipelines, picked at random.
transform_list = transforms.RandomChoice([
    rotation_transform,
    hoz_transform,
    vert_transform,
    color_transform,
])

# ## Loading the data

# Custom COFGA dataset; every sample is passed through `transform_list`.
dataset = CofgaDataset(
    csv_file='dataset/train_preprocessed.csv',
    root_dir='dataset/root/train/resized/',
    transform=transform_list,
)

print("Total number of images: ", len(dataset))

# Label names are the CSV column headers, minus the first column.
COFGA_headers = pd.read_csv('dataset/train_preprocessed.csv')
COFGA_labels = COFGA_headers.columns.tolist()
del COFGA_labels[0]
# Data Upload
print('\n[Phase 1] : Data Preparation')
torch.manual_seed(2809)


def _spatial_gaussian_blur(sigma):
    """Return a transform that Gaussian-blurs a (C, H, W) image tensor.

    The blur uses sigma 0 on the channel axis, so only the two spatial axes
    are smoothed and the colour channels are not mixed together. A plain
    ``sigma=s`` would also blur across channels. The result is converted back
    to a ``torch.Tensor``, so the noisy pipeline yields the same type as the
    clean one.
    """
    def _blur(x):
        return torch.as_tensor(
            ndimage.gaussian_filter(x, sigma=(0, sigma, sigma)))
    return transforms.Lambda(_blur)


# One of these is chosen at random per sample; sigma 0 leaves the image as is.
gaussian_transforms = [_spatial_gaussian_blur(s) for s in (0, 1, 2, 5, 10)]

transform_train_noise = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cf.mean['cifar100'], cf.std['cifar100']),
    transforms.RandomChoice(gaussian_transforms),
])

transform_train_clean = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cf.mean['cifar100'], cf.std['cifar100']),
])  # meanstd transformation

transform_test_noise = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cf.mean['cifar100'], cf.std['cifar100']),
    transforms.RandomChoice(gaussian_transforms),
    #transforms.ToTensor()
Esempio n. 17
0
def train(args):
    """Train the FER2013 emotion model and checkpoint it after every epoch.

    Args:
        args: parsed command-line namespace. Attributes read directly here:
            wandb, data_path, train_split, batch_size, model_config, augment,
            epochs (``get_loss`` may read more).

    Every epoch calls ``save_checkpoint`` with the path
    ``./saved_models/<config name>.pth.tar``. The flag passed with it is True
    when the epoch's validation loss equals the best seen so far. After
    training, the train/validation reports are written to
    ``./saved_models/<config name>_trainreport.csv`` and ``..._valreport.csv``.
    """
    # Get hardware device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Check if weights and biases integration is enabled.
    if args.wandb == 1:
        import wandb
        wandb.init(entity='surajpai',
                   project='FacialEmotionRecognition',
                   config=vars(args))

    # Get the dataset with "Training" usage.
    dataset = FER2013Dataset(args.data_path, "Training")

    # Randomly split the dataset into train and validation based on the specified train_split argument
    n_train = int(len(dataset) * args.train_split)
    train_dataset, validation_dataset = torch.utils.data.random_split(
        dataset, [n_train, len(dataset) - n_train])

    logging.info(
        'Samples in the training set: {}\n Samples in the validation set: {} \n\n'
        .format(len(train_dataset), len(validation_dataset)))

    # Get class weights as inverse of frequencies from class occurences in the dataset.
    dataset_summary = dataset.get_summary_statistics()
    class_weights = (1 / dataset_summary["class_occurences"])
    class_weights = torch.Tensor(class_weights /
                                 np.sum(class_weights)).to(device)

    # Train loader and validation loader initialized with batch_size as specified and randomly shuffled
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              pin_memory=True)
    val_loader = DataLoader(validation_dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            pin_memory=True)

    # Model initialization
    model = torch.nn.DataParallel(Model(args.model_config)).to(device)

    # Set torch optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # Get loss for training the network from the utils get_loss function
    criterion = get_loss(args, class_weights)
    # Start at +inf so the first validation loss always counts as the best.
    # The former -1000 made min() a no-op, so no epoch was ever flagged best.
    bestLoss = float("inf")

    # Create metric logger object
    metrics = Metrics(upload=args.wandb)

    # Checkpoint/report paths are computed once, so they exist even when
    # args.epochs == 0. The report names are built from the prefix instead of
    # using str.rstrip(".pth.tar"), which strips a character *set* and would
    # eat trailing letters of the model name.
    model_name = args.model_config.split('/')[-1].split('.')[0]
    report_prefix = "./saved_models/{}".format(model_name)
    save_path = "{}.pth.tar".format(report_prefix)

    # Define augmentation transforms, if --augment is enabled
    if args.augment == 1:
        transform = transforms.RandomChoice([
            transforms.RandomHorizontalFlip(p=0.75),
            transforms.RandomAffine(15,
                                    translate=(0.1, 0.1),
                                    scale=(1.2, 1.2),
                                    shear=15),
            transforms.ColorJitter()
        ])

    # Start iterating over the total number of epochs set by epochs argument
    for n_epoch in range(args.epochs):

        # Reset running metrics at the beginning of each epoch.
        metrics.reset()

        # Utils logger
        logging.info(' Starting Epoch: {}/{} \n'.format(n_epoch, args.epochs))

        # ---------------------------- TRAINING ----------------------------

        # Model in train mode for batch-norm and dropout related ops.
        model.train()

        # Iterate over each batch in the train loader
        for idx, batch in enumerate(tqdm(train_loader)):

            # Reset gradients
            optimizer.zero_grad()

            # Augmentation (when enabled) is only applied on even epochs.
            if args.augment == 1 and n_epoch % 2 == 0:
                batch = apply_transforms(batch, transform)

            # Move the batch to the device, needed explicitly if GPU is present
            image, target = batch["image"].to(device), batch["emotion"].to(
                device)

            # Run a forward pass over images from the batch
            out = model(image)

            # Calculate loss based on the criterion set
            loss = criterion(out, target)

            # Backward pass from the final loss
            loss.backward()

            # Update the optimizer
            optimizer.step()

            # Update metrics for this batch
            metrics.update_train({
                "loss": loss.item(),
                "predicted": out,
                "ground_truth": target
            })

        # --------------------------- VALIDATION ---------------------------

        logging.info(' Validating on the validation split ... \n \n')

        # Model in eval mode.
        model.eval()

        # Set no grad to disable gradient saving.
        with torch.no_grad():

            # Iterate over each batch in the val loader
            for idx, batch in enumerate(val_loader):

                # Move the batch to the device, needed explicitly if GPU is present
                image, target = batch["image"].to(device), batch["emotion"].to(
                    device)

                # Forward pass
                out = model(image)

                # Calculate loss based on the criterion set
                loss = criterion(out, target)

                # Metrics and sample predictions updated for validation batch
                metrics.update_val({
                    "loss": loss.item(),
                    "predicted": out,
                    "ground_truth": target,
                    "image": image,
                    "class_mapping": dataset.get_class_mapping()
                })

        # Display metrics at the end of each epoch
        metrics.display()

        # Weight Checkpointing to save the best model on validation loss
        val_loss = metrics.metric_dict["loss@val"]
        bestLoss = min(bestLoss, val_loss)
        is_best = (bestLoss == val_loss)
        save_checkpoint(
            {
                'epoch': n_epoch,
                'state_dict': model.state_dict(),
                'bestLoss': bestLoss,
                'optimizer': optimizer.state_dict(),
            }, is_best, save_path)

    # After training is completed, if weights and biases is enabled, visualize filters and upload final model.
    if args.wandb == 1:
        visualize_filters(model.modules())
        wandb.save(save_path)

    # Get report from the metrics logger
    train_report, val_report = metrics.get_report()

    # Save the report to csv files
    train_report.to_csv("{}_trainreport.csv".format(report_prefix))
    val_report.to_csv("{}_valreport.csv".format(report_prefix))
Esempio n. 18
0
# Candidate training augmentations; RandomChoice below picks one per call.
# Previously tried (disabled):
#   transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
#   transforms.RandomRotation(degrees=[-30, 30])
#   transforms.GaussianBlur(kernel_size=5)
transform_options = [transforms.RandomAffine(0, shear=20)]

# 'train' only augments and has no ToTensor, so its output is still an image
# object that can be shown/saved below. 'val' resizes, converts to a tensor
# and normalizes with fixed per-channel mean/std values.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomApply(
            [transforms.RandomChoice(transform_options)], p=1.0),
    ]),
    val=transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

# Preview: augment one training image, display it and save it to disk.
img_path = "data/training_data/" + "000001.jpg"
img = data_transforms["train"](Image.open(img_path).convert("RGB"))
img.show()
img.save("RandomAffine.png")
Esempio n. 19
0
def load_transforms(name):
    """Build the data-transformation pipeline registered under ``name``.

    The lookup is case-insensitive. ``"fashionmnist"`` and ``"fmnist"`` select
    the same pipeline.

    Note:
        - ``GaussianBlur`` is defined at the bottom of this file.

    Raises:
        NameError: if ``name`` matches no registered pipeline.
    """
    def _affine_choice():
        # Exactly one of four single-parameter affine jitters per sample.
        return transforms.RandomChoice([
            transforms.RandomAffine((-90, 90)),
            transforms.RandomAffine(0, translate=(0.2, 0.4)),
            transforms.RandomAffine(0, scale=(0.8, 1.1)),
            transforms.RandomAffine(0, shear=(-20, 20)),
        ])

    def _fashion_mnist():
        return transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation((-90, 90)),
            _affine_choice(),
            GaussianBlur(kernel_size=3),
            transforms.ToTensor(),
        ])

    # Builders are only invoked for the selected name, so no transform
    # objects are created for the other pipelines.
    builders = {
        "default": lambda: transforms.Compose([
            transforms.RandomCrop(32, padding=8),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]),
        "transfer": lambda: transforms.Compose([
            transforms.CenterCrop(32),
            transforms.ToTensor(),
        ]),
        "cifar": lambda: transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ]),
        "mnist": lambda: transforms.Compose([
            _affine_choice(),
            GaussianBlur(kernel_size=3),
            transforms.ToTensor(),
        ]),
        "stl10": lambda: transforms.Compose([
            transforms.RandomResizedCrop(96),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply(
                [transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(kernel_size=9),
            transforms.ToTensor(),
        ]),
        "fashionmnist": _fashion_mnist,
        "fmnist": _fashion_mnist,
        "deca_train": lambda: transforms.Compose([
            transforms.Resize(72),
            transforms.RandomCrop(64),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]),
        "deca_test": lambda: transforms.Compose([
            transforms.Resize(72),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
        ]),
        # Not wrapped in Compose: the bare ToTensor transform is returned.
        "test": lambda: transforms.ToTensor(),
    }

    build = builders.get(name.lower())
    if build is None:
        raise NameError("{} not found in transform loader".format(name))
    return build()
########## 001 Data Transforms #####################

image_trans = {
    'train':
    transforms.Compose([
        # transfer the input image into gray scale
        #transforms.Grayscale(num_output_channels=1),
        # resize the input image into the predefined scale
        #transforms.Resize((img_h, img_w), interpolation=PIL.Image.BICUBIC), # (h, w)
        # random choose one of the predefined transforms (in the list) when performing the training process
        #transforms.Lambda(lambda img : head_center(img)),
        transforms.Lambda(lambda img: pad_img(img, img_w)),
        transforms.RandomChoice([
            transforms.RandomHorizontalFlip(),
            #transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
            #transforms.Lambda(lambda img : centralize(img,0.4,0.4,0.5,0.5)),
            #transforms.RandomRotation(30, resample=False, expand=False, center=None)
        ]),

        #transforms.Lambda(lambda img : verticalize(img)),
        # transfer the type of input image into tensor style
        transforms.ToTensor(),
    ]),
    'valid':
    transforms.Compose([
        #transforms.Grayscale(num_output_channels=1),
        #transforms.Resize((img_h, img_w), interpolation=PIL.Image.BICUBIC),
        #transforms.Lambda(lambda img : centralize(img,0.4,0.4,0.4,0.3)),
        #transforms.Lambda(lambda img : head_center(img)),
        transforms.Lambda(lambda img: pad_img(img, img_w)),
        transforms.RandomChoice([
Esempio n. 21
0
"""**Define Data Preprocessing**"""

# Define transforms for the training phase.
# torchvision's AlexNet needs a 224x224 input, so any transformation added
# here must still end at that size: resize the short side to 256, then take
# the central 224x224 patch.
train_transform = tr.Compose([
    tr.Resize(256),
    tr.CenterCrop(224),
    # 4.A: Data Augmentation - exactly one of these is applied per image.
    tr.RandomChoice([
        tr.RandomHorizontalFlip(),
        tr.RandomPerspective(distortion_scale=0.2),
        tr.RandomRotation(degrees=10),
    ]),
    tr.ToTensor(),  # PIL Image -> torch.Tensor
    # Per-channel normalization. Up to step 3.A this was
    # tr.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)); from 3.B on:
    tr.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# Define transforms for the evaluation phase
eval_transform = tr.Compose([tr.Resize(256),
Esempio n. 22
0
 def __init__(self, mdlParams, indSet):
     """
     Args:
         mdlParams (dict): Configuration for loading
         indSet (string): Indicates train, val, test
     """
     # Mdlparams
     self.mdlParams = mdlParams
     # Number of classes
     self.numClasses = mdlParams['numClasses']
     # Model input size
     self.input_size = (np.int32(mdlParams['input_size'][0]),np.int32(mdlParams['input_size'][1]))      
     # Whether or not to use ordered cropping 
     self.orderedCrop = mdlParams['orderedCrop']   
     # Number of crops for multi crop eval
     self.multiCropEval = mdlParams['multiCropEval']   
     # Whether during training same-sized crops should be used
     self.same_sized_crop = mdlParams['same_sized_crops']    
     # Only downsample
     self.only_downsmaple = mdlParams.get('only_downsmaple',False)   
     # Potential class balancing option 
     self.balancing = mdlParams['balance_classes']
     # Whether data should be preloaded
     self.preload = mdlParams['preload']
     # Potentially subtract a mean
     self.subtract_set_mean = mdlParams['subtract_set_mean']
     # Potential switch for evaluation on the training set
     self.train_eval_state = mdlParams['trainSetState']   
     # Potential setMean to deduce from channels
     self.setMean = mdlParams['setMean'].astype(np.float32)
     # Current indSet = 'trainInd'/'valInd'/'testInd'
     self.indices = mdlParams[indSet]  
     self.indSet = indSet
     # feature scaling for meta
     if mdlParams.get('meta_features',None) is not None and mdlParams['scale_features']:
         self.feature_scaler = mdlParams['feature_scaler_meta']
     if self.balancing == 3 and indSet == 'trainInd':
         # Sample classes equally for each batch
         # First, split set by classes
         not_one_hot = np.argmax(mdlParams['labels_array'],1)
         self.class_indices = []
         for i in range(mdlParams['numClasses']):
             self.class_indices.append(np.where(not_one_hot==i)[0])
             # Kick out non-trainind indices
             self.class_indices[i] = np.setdiff1d(self.class_indices[i],mdlParams['valInd'])
             # And test indices
             if 'testInd' in mdlParams:
                 self.class_indices[i] = np.setdiff1d(self.class_indices[i],mdlParams['testInd'])
         # Now sample indices equally for each batch by repeating all of them to have the same amount as the max number
         indices = []
         max_num = np.max([len(x) for x in self.class_indices])
         # Go thourgh all classes
         for i in range(mdlParams['numClasses']):
             count = 0
             class_count = 0
             max_num_curr_class = len(self.class_indices[i])
             # Add examples until we reach the maximum
             while(count < max_num):
                 # Start at the beginning, if we are through all available examples
                 if class_count == max_num_curr_class:
                     class_count = 0
                 indices.append(self.class_indices[i][class_count])
                 count += 1
                 class_count += 1
         print("Largest class",max_num,"Indices len",len(indices))
         print("Intersect val",np.intersect1d(indices,mdlParams['valInd']),"Intersect Testind",np.intersect1d(indices,mdlParams['testInd']))
         # Set labels/inputs
         self.labels = mdlParams['labels_array'][indices,:]
         self.im_paths = np.array(mdlParams['im_paths'])[indices].tolist()     
         # Normal train proc
         if self.same_sized_crop:
             cropping = transforms.RandomCrop(self.input_size)
         elif self.only_downsmaple:
             cropping = transforms.Resize(self.input_size)
         else:
             cropping = transforms.RandomResizedCrop(self.input_size[0])
         # All transforms
         self.composed = transforms.Compose([
                 cropping,
                 transforms.RandomHorizontalFlip(),
                 transforms.RandomVerticalFlip(),
                 transforms.ColorJitter(brightness=32. / 255.,saturation=0.5),
                 transforms.ToTensor(),
                 transforms.Normalize(torch.from_numpy(self.setMean).float(),torch.from_numpy(np.array([1.,1.,1.])).float())
                 ])                                
     elif self.orderedCrop and (indSet == 'valInd' or self.train_eval_state  == 'eval' or indSet == 'testInd'):
         # Also flip on top            
         if mdlParams.get('eval_flipping',0) > 1:
             # Complete labels array, only for current indSet, repeat for multiordercrop
             inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval']*mdlParams['eval_flipping'])
             self.labels = mdlParams['labels_array'][inds_rep,:]
             # meta
             if mdlParams.get('meta_features',None) is not None:
                 self.meta_data = mdlParams['meta_array'][inds_rep,:]    
             # Path to images for loading, only for current indSet, repeat for multiordercrop
             self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
             print("len im path",len(self.im_paths))                
             if self.mdlParams.get('var_im_size',False):
                 self.cropPositions = np.tile(mdlParams['cropPositions'][mdlParams[indSet],:,:],(1,mdlParams['eval_flipping'],1))
                 self.cropPositions = np.reshape(self.cropPositions,[mdlParams['multiCropEval']*mdlParams['eval_flipping']*mdlParams[indSet].shape[0],2])
                 #self.cropPositions = np.repeat(self.cropPositions, (mdlParams['eval_flipping'],1))
                 #print("CP examples",self.cropPositions[:50,:])
             else:
                 self.cropPositions = np.tile(mdlParams['cropPositions'], (mdlParams['eval_flipping']*mdlParams[indSet].shape[0],1))
             # Flip states
             if mdlParams['eval_flipping'] == 2:
                 self.flipPositions = np.array([0,1])
             elif mdlParams['eval_flipping'] == 3:
                 self.flipPositions = np.array([0,1,2])
             elif mdlParams['eval_flipping'] == 4:
                 self.flipPositions = np.array([0,1,2,3])                    
             self.flipPositions = np.repeat(self.flipPositions, mdlParams['multiCropEval'])
             self.flipPositions = np.tile(self.flipPositions, mdlParams[indSet].shape[0])
             print("Crop positions shape",self.cropPositions.shape,"flip pos shape",self.flipPositions.shape)
             print("Flip example",self.flipPositions[:30])
         else:
             # Complete labels array, only for current indSet, repeat for multiordercrop
             inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval'])
             self.labels = mdlParams['labels_array'][inds_rep,:]
             # meta
             if mdlParams.get('meta_features',None) is not None:
                 self.meta_data = mdlParams['meta_array'][inds_rep,:]                 
             # Path to images for loading, only for current indSet, repeat for multiordercrop
             self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
             print("len im path",len(self.im_paths))
             # Set up crop positions for every sample                
             if self.mdlParams.get('var_im_size',False):
                 self.cropPositions = np.reshape(mdlParams['cropPositions'][mdlParams[indSet],:,:],[mdlParams['multiCropEval']*mdlParams[indSet].shape[0],2])
                 #print("CP examples",self.cropPositions[:50,:])
             else:
                 self.cropPositions = np.tile(mdlParams['cropPositions'], (mdlParams[indSet].shape[0],1))
             print("CP",self.cropPositions.shape)
         #print("CP Example",self.cropPositions[0:len(mdlParams['cropPositions']),:])          
         # Set up transforms
         self.norm = transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd']))
         self.trans = transforms.ToTensor()
     elif indSet == 'valInd' or indSet == 'testInd':
         if self.multiCropEval == 0:
             if self.only_downsmaple:
                 self.cropping = transforms.Resize(self.input_size)
             else:
                 self.cropping = transforms.Compose([transforms.CenterCrop(np.int32(self.input_size[0]*1.5)),transforms.Resize(self.input_size)])
             # Complete labels array, only for current indSet
             self.labels = mdlParams['labels_array'][mdlParams[indSet],:]
             # meta
             if mdlParams.get('meta_features',None) is not None:
                 self.meta_data = mdlParams['meta_array'][mdlParams[indSet],:]                 
             # Path to images for loading, only for current indSet
             self.im_paths = np.array(mdlParams['im_paths'])[mdlParams[indSet]].tolist()                   
         else:
             # Deterministic processing
             if self.mdlParams.get('deterministic_eval',False):
                 total_len_per_im = mdlParams['numCropPositions']*len(mdlParams['cropScales'])*mdlParams['cropFlipping']                    
                 # Actual transforms are functionally applied at forward pass
                 self.cropPositions = np.zeros([total_len_per_im,3])
                 ind = 0
                 for i in range(mdlParams['numCropPositions']):
                     for j in range(len(mdlParams['cropScales'])):
                         for k in range(mdlParams['cropFlipping']):
                             self.cropPositions[ind,0] = i
                             self.cropPositions[ind,1] = mdlParams['cropScales'][j]
                             self.cropPositions[ind,2] = k
                             ind += 1
                 # Complete labels array, only for current indSet, repeat for multiordercrop
                 print("crops per image",total_len_per_im)
                 self.cropPositions = np.tile(self.cropPositions, (mdlParams[indSet].shape[0],1))
                 inds_rep = np.repeat(mdlParams[indSet], total_len_per_im)
                 self.labels = mdlParams['labels_array'][inds_rep,:]
                 # meta
                 if mdlParams.get('meta_features',None) is not None:
                     self.meta_data = mdlParams['meta_array'][inds_rep,:]                     
                 # Path to images for loading, only for current indSet, repeat for multiordercrop
                 self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
             else:
                 self.cropping = transforms.RandomResizedCrop(self.input_size[0],scale=(mdlParams.get('scale_min',0.08),1.0))
                 # Complete labels array, only for current indSet, repeat for multiordercrop
                 inds_rep = np.repeat(mdlParams[indSet], mdlParams['multiCropEval'])
                 self.labels = mdlParams['labels_array'][inds_rep,:]
                 # meta
                 if mdlParams.get('meta_features',None) is not None:
                     self.meta_data = mdlParams['meta_array'][inds_rep,:]                    
                 # Path to images for loading, only for current indSet, repeat for multiordercrop
                 self.im_paths = np.array(mdlParams['im_paths'])[inds_rep].tolist()
         print(len(self.im_paths))  
         # Set up transforms
         self.norm = transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd']))
         self.trans = transforms.ToTensor()                   
     else:
         all_transforms = []
         # Normal train proc
         if self.same_sized_crop:
             all_transforms.append(transforms.RandomCrop(self.input_size))
         elif self.only_downsmaple:
             all_transforms.append(transforms.Resize(self.input_size))
         else:
             all_transforms.append(transforms.RandomResizedCrop(self.input_size[0],scale=(mdlParams.get('scale_min',0.08),1.0)))
         if mdlParams.get('flip_lr_ud',False):
             all_transforms.append(transforms.RandomHorizontalFlip())
             all_transforms.append(transforms.RandomVerticalFlip())
         # Full rot
         if mdlParams.get('full_rot',0) > 0:
             if mdlParams.get('scale',False):
                 all_transforms.append(transforms.RandomChoice([transforms.RandomAffine(mdlParams['full_rot'], scale=mdlParams['scale'], shear=mdlParams.get('shear',0), resample=Image.NEAREST),
                                                             transforms.RandomAffine(mdlParams['full_rot'],scale=mdlParams['scale'],shear=mdlParams.get('shear',0), resample=Image.BICUBIC),
                                                             transforms.RandomAffine(mdlParams['full_rot'],scale=mdlParams['scale'],shear=mdlParams.get('shear',0), resample=Image.BILINEAR)])) 
             else:
                 all_transforms.append(transforms.RandomChoice([transforms.RandomRotation(mdlParams['full_rot'], resample=Image.NEAREST),
                                                             transforms.RandomRotation(mdlParams['full_rot'], resample=Image.BICUBIC),
                                                             transforms.RandomRotation(mdlParams['full_rot'], resample=Image.BILINEAR)]))    
         # Color distortion
         if mdlParams.get('full_color_distort') is not None:
             all_transforms.append(transforms.ColorJitter(brightness=mdlParams.get('brightness_aug',32. / 255.),saturation=mdlParams.get('saturation_aug',0.5), contrast = mdlParams.get('contrast_aug',0.5), hue = mdlParams.get('hue_aug',0.2)))
         else:
             all_transforms.append(transforms.ColorJitter(brightness=32. / 255.,saturation=0.5))   
         # Autoaugment
         if self.mdlParams.get('autoaugment',False):
             all_transforms.append(AutoAugment())             
         # Cutout
         if self.mdlParams.get('cutout',0) > 0:
             all_transforms.append(Cutout_v0(n_holes=1,length=self.mdlParams['cutout']))                             
         # Normalize
         all_transforms.append(transforms.ToTensor())
         all_transforms.append(transforms.Normalize(np.float32(self.mdlParams['setMean']),np.float32(self.mdlParams['setStd'])))            
         # All transforms
         self.composed = transforms.Compose(all_transforms)                  
         # Complete labels array, only for current indSet
         self.labels = mdlParams['labels_array'][mdlParams[indSet],:]
         # meta
         if mdlParams.get('meta_features',None) is not None:
             self.meta_data = mdlParams['meta_array'][mdlParams[indSet],:]            
         # Path to images for loading, only for current indSet
         self.im_paths = np.array(mdlParams['im_paths'])[mdlParams[indSet]].tolist()
     # Potentially preload
     if self.preload:
         self.im_list = []
         for i in range(len(self.im_paths)):
             self.im_list.append(Image.open(self.im_paths[i]))
Esempio n. 23
0

# Candidate augmentations for the 48x48 training pipeline. RandomChoice below
# applies exactly one entry per sample. The first entry is a RandomApply with
# p=0, i.e. an explicit "leave the image unchanged" option.
random_transforms = [
    transforms.RandomApply([], p=0),
    transforms.CenterCrop(48),
    # One single-property ColorJitter each for brightness, saturation,
    # contrast and hue (in that order).
    # NOTE(review): per the torchvision ColorJitter docs a strength of 5 draws
    # the factor from [0, 6], which can blank the image entirely - confirm
    # these strengths are intended.
    *[
        transforms.ColorJitter(**{attribute: strength})
        for attribute, strength in (
            ("brightness", 5),
            ("saturation", 5),
            ("contrast", 5),
            ("hue", 0.4),
        )
    ],
    transforms.RandomRotation(15),
    # Pure translation, pure shear and pure scaling (rotation range fixed at
    # 0), each resampled bicubically.
    *[
        transforms.RandomAffine(0, resample=Image.BICUBIC, **affine_kwargs)
        for affine_kwargs in (
            {"translate": (0.2, 0.2)},
            {"shear": 20},
            {"scale": (0.8, 1.2)},
        )
    ],
]

# Training: upscale to 64x64 so the chosen augmentation has margin to work
# with, bring back to 48x48, apply CLAHE (project transform), then tensorise
# and normalise with fixed per-channel statistics.
train_data_transforms = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.RandomChoice(random_transforms),
    transforms.Resize(size=(48, 48)),
    CLAHE(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.3337, 0.3064, 0.3171),
                         std=(0.2672, 0.2564, 0.2629)),
])

# Evaluation: deterministic resize + normalisation only.
# NOTE(review): CLAHE is applied at training time but not here, so eval
# inputs are preprocessed differently from training inputs - confirm this is
# deliberate.
data_transforms = transforms.Compose([
    transforms.Resize(size=(48, 48)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.3337, 0.3064, 0.3171),
                         std=(0.2672, 0.2564, 0.2629)),
])
Esempio n. 24
0
        if self.use_gpu: targets = targets.cuda()
        targets = (1 -
                   self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss


# Label-smoothing cross-entropy criterion configured for 2 classes.
# NOTE(review): CrossEntropyLabelSmooth is defined elsewhere in this file; the
# forward body just above appears to belong to it and blends targets as
# (1 - epsilon) * targets + epsilon / num_classes - confirm.
xent = CrossEntropyLabelSmooth(num_classes=2)
# Median-blur augmentation constructed with argument 3 (MedianBlur is defined
# elsewhere - presumably a 3x3 median filter; confirm). The 'train' pipeline
# below picks either it or a random horizontal flip via RandomChoice.
medianBlur = MedianBlur(3)
data_transforms = {
    'train':
    transforms.Compose([
        transforms.Scale(299),
        transforms.RandomRotation(15),
        transforms.RandomChoice(
            [transforms.RandomResizedCrop(224),
             transforms.CenterCrop(224)]),
        #transforms.RandomResizedCrop(224),
        #transforms.Resize([ 299, 299]),
        #transforms.CenterCrop(224),
        transforms.RandomChoice(
            [transforms.RandomHorizontalFlip(), medianBlur]),
        transforms.ColorJitter(brightness=0.2,
                               contrast=0.2,
                               saturation=0.2,
                               hue=0.2),
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'validation':
    transforms.Compose([
Esempio n. 25
0
import torchvision.transforms as transforms

from util import util

if __name__ == '__main__':
    opt = TrainOptions().parse()

    dataset = torchvision.datasets.ImageFolder(
        opt.dataroot,
        transform=transforms.Compose([
            transforms.RandomChoice([
                transforms.Resize(opt.loadSize, interpolation=1),
                transforms.Resize(opt.loadSize, interpolation=2),
                transforms.Resize(opt.loadSize, interpolation=3),
                transforms.Resize((opt.loadSize, opt.loadSize),
                                  interpolation=1),
                transforms.Resize((opt.loadSize, opt.loadSize),
                                  interpolation=2),
                transforms.Resize((opt.loadSize, opt.loadSize),
                                  interpolation=3)
            ]),
            transforms.RandomChoice([
                transforms.RandomResizedCrop(opt.fineSize, interpolation=1),
                transforms.RandomResizedCrop(opt.fineSize, interpolation=2),
                transforms.RandomResizedCrop(opt.fineSize, interpolation=3)
            ]),
            transforms.RandomChoice([
                transforms.ColorJitter(brightness=.05,
                                       contrast=.05,
                                       saturation=.05,
                                       hue=.05),
Esempio n. 26
0
def _load_checkpoint(args, model, optimizer):
    """Restore ``model`` and ``optimizer`` state in place from ``./ckpt``.

    The file name is ``resnet152[-poison]-<args.ckpt>.pkl``; the loaded object
    is indexed with ``'model'`` and ``'optim'`` to get the two state dicts.
    """
    ckpt_name = 'resnet152'
    if args.poison:
        ckpt_name += '-poison'
    ckpt_name += '-' + str(args.ckpt) + '.pkl'
    ckpt_path = os.path.join('./ckpt', ckpt_name)
    print('Loading checkpoint from {}'.format(ckpt_path))
    # Without CUDA, map every tensor to CPU so checkpoints saved from a GPU
    # still load on CPU-only hosts; with CUDA keep torch.load's default.
    dct = torch.load(ckpt_path, map_location=None if args.cuda else 'cpu')
    model.load_state_dict(dct['model'])
    optimizer.load_state_dict(dct['optim'])


def _build_transforms():
    """Return ``(eval_transform, aug_transform)`` for 64x64 inputs.

    Both end with ToTensor -> Resize((64, 64)) -> Normalize(mean=0.5, std=0.5
    per channel). ``aug_transform`` first applies exactly one randomly chosen
    augmentation: identity, a horizontal flip, RandomResizedCrop(64) or
    RandomRotation(30).
    """
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
        transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
    ])
    aug_transform = transforms.Compose([
        transforms.RandomChoice([
            # do nothing
            transforms.Compose([]),
            # horizontal flip (p=1: always flips when this option is chosen)
            transforms.RandomHorizontalFlip(1.),
            # random crop
            transforms.RandomResizedCrop(64),
            # rotate
            transforms.RandomRotation(30)
        ]),
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
        transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
    ])
    return eval_transform, aug_transform


def main(args):
    """Train or evaluate a ResNet on ``args.task``.

    Args:
        args: parsed command-line namespace. Attributes read here: ``task``
            (only ``"cifar10"`` is supported), ``cuda`` (move the model to
            GPU), ``lr`` / ``wd`` (Adam learning rate / weight decay),
            ``ckpt`` (resume from ``./ckpt`` when > 0), ``poison`` (select
            the ``-poison`` checkpoint variant), ``batch_size`` and ``run``
            (``"train"`` or ``"test"``; any other value does nothing).

    Raises:
        ValueError: if ``args.task`` is not a supported task.
    """
    # Resolve the loader classes before building anything, so an unsupported
    # task fails fast with a clear error instead of leaving them unbound.
    if args.task == "cifar10":
        loader_cls = CIFAR10Loader
        poisoned_loader_cls = PoisonedCIFAR10Loader
    else:
        raise ValueError('Unsupported task: {!r}'.format(args.task))

    # Model settings
    model = ResNet()
    if args.cuda:
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.wd)
    if args.ckpt > 0:
        _load_checkpoint(args, model, optimizer)

    # Data loader settings: augmented transform for training, plain one for
    # the poisoned validation split.
    transform, aug_transform = _build_transforms()
    task_dir = '/data/csnova1/benchmarks/%s' % args.task
    poison_dir = '/data/csnova1/poison'
    poison_config = get_poison_config()
    train_loader = loader_cls(root=task_dir,
                              batch_size=args.batch_size,
                              split='train',
                              transform=aug_transform)
    test_loader = poisoned_loader_cls(root=task_dir,
                                      poison_root=poison_dir,
                                      poison_config=poison_config,
                                      poison_num=6,
                                      batch_size=args.batch_size,
                                      split="val",
                                      transform=transform)

    # Start
    if args.run == "train":
        train(args, train_loader, model, optimizer)
    elif args.run == "test":
        evaluate(args, test_loader, model)
Esempio n. 27
0
                                                                T_0=5,
                                                                T_mult=2)
 normMean = [0.5964188, 0.4566936, 0.3908954]
 normStd = [0.2590655, 0.2314241, 0.2269535]
 train_transformer = transforms.Compose([
     transforms.RandomChoice([
         transforms.ColorJitter(brightness=[0.8, 1.2],
                                contrast=[0.8, 1.2],
                                saturation=[0.8, 1.2],
                                hue=[-0.2, 0.2]),
         transforms.Compose(
             [transforms.Resize((150, 150)),
              transforms.Resize((200, 200))]),
         transforms.Compose([
             transforms.RandomResizedCrop(150,
                                          scale=(0.78, 1),
                                          ratio=(0.90, 1.10),
                                          interpolation=2),
             transforms.Resize((200, 200))
         ]),
         transforms.RandomHorizontalFlip(p=0.5),
         transforms.RandomRotation((-20, 20),
                                   resample=False,
                                   expand=False,
                                   center=None),
     ]),
     # transforms.RandomHorizontalFlip(p=0.5),
     # transforms.RandomRotation(degrees=(5, 10)),
     transforms.ToTensor(),
     transforms.Normalize(normMean, normStd)
 ])
Esempio n. 28
0
    batch_size = args.batch_size
    num_epoch = args.epochs
    lr = args.learning_rate
    wd = args.weight_decay
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    summary = SummaryWriter('./runs/')

    train_transforms = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomChoice([
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
            transforms.RandomResizedCrop(224),
            transforms.RandomAffine(degrees=15,
                                    translate=(0.2, 0.2),
                                    scale=(0.8, 1.2),
                                    shear=15,
                                    resample=Image.BILINEAR)
        ]),
        transforms.ToTensor(),
        transforms.Normalize((0.4452, 0.4457, 0.4464),
                             (0.2592, 0.2596, 0.2600)),
    ])

    test_transforms = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.4452, 0.4457, 0.4464),
                             (0.2592, 0.2596, 0.2600)),
    ])
Esempio n. 29
0
from yolo.config import IOU_THRESHOLD, TENSORBOARD_PATH
from tensorboardX import SummaryWriter
from datetime import datetime
import time
from torch.optim import SGD, RMSprop, Adam
from torch.optim.lr_scheduler import StepLR
from yolo.utils.evaluate.metter import AverageMeter

# Joint (image + annotation) training pipeline: resize to 448x448, flip
# horizontally with probability 0.3, then XyToCenter.
# NOTE(review): Compose/Resize/RandomHorizontalFlip/XyToCenter here are the
# project's joint transforms (not the torchvision.transforms namespace);
# XyToCenter presumably converts box corners to centre format - confirm.
general_transform = Compose([
    Resize((448, 448)),
    RandomHorizontalFlip(0.3),
    XyToCenter(),
])

# Image-only photometric augmentation: RandomChoice applies exactly one of a
# colour jitter (saturation 0.2, hue 0.3) or RandomGrayscale(p=0.3), after
# which the image is converted to a tensor.
transform = transforms.Compose([
    transforms.RandomChoice([
        transforms.ColorJitter(saturation=.2, hue=.3),
        transforms.RandomGrayscale(p=0.3),
    ]),
    transforms.ToTensor(),
])

# Validation: same resize + XyToCenter, no flipping or photometric changes.
val_general_transform = Compose([
    Resize((448, 448)),
    XyToCenter(),
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
])

# +
batch_size = 32

# Training split of the OpenImage dataset rooted at /data/data/OpenImage/.
ds = OpenImage(
    '/data/data/OpenImage/',
    'train',
    general_transform=general_transform,
    transform=transform,
)
def ContrastivePredictiveCodingAugmentations(img):
    """Apply a randomised chain of augmentations to a PIL image.

    Steps:
      1. Three operations are drawn independently (with replacement) from a
         pool of geometric, histogram/colour and erasing transforms, and are
         applied one after another. The operation set follows
         https://arxiv.org/pdf/1805.09501.pdf.
      2. With probability 0.2 an elastic deformation (sigma=10) is applied.
         Following https://www.nature.com/articles/s41591-018-0107-6, only
         elastic deformation is used here, because shearing and histogram
         changes are already in the pool.
      3. (Covered by the pool.)
      4. With probability 0.25 the image becomes 3-channel grayscale.

    Returns the augmented image.
    """

    def enhance_randomly(enhancer):
        # Wrap a PIL ImageEnhance class; each application draws a fresh
        # enhancement factor from np.random.uniform(), i.e. in [0, 1).
        return transforms.Lambda(
            lambda x: enhancer(x).enhance(np.random.uniform()))

    pool = [
        # Rotation
        transforms.RandomRotation(30, resample=False, expand=False,
                                  center=None, fill=None),
        # Shearing
        transforms.RandomAffine(0, translate=None, scale=None, shear=30,
                                resample=False, fillcolor=0),
        # Translate
        transforms.RandomAffine(0, translate=(0.3, 0.3), scale=None,
                                shear=None, resample=False, fillcolor=0),
        transforms.Lambda(lambda x: imo.autocontrast(x)),
        transforms.Lambda(lambda x: imo.invert(x)),
        transforms.Lambda(lambda x: imo.equalize(x)),
        transforms.Lambda(lambda x: imo.solarize(x)),
        # Posterize: np.random.randint(4, 8) is in 4..7, so bits is in 5..8.
        transforms.Lambda(lambda x: imo.posterize(
            x, bits=int(np.random.randint(4, 8) + 1))),
        enhance_randomly(ime.Color),
        enhance_randomly(ime.Brightness),
        enhance_randomly(ime.Contrast),
        enhance_randomly(ime.Sharpness),
        # Set black: erase a random rectangle with probability 1.0, working in
        # tensor space and converting back to PIL.
        transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomErasing(1.0),
            transforms.ToPILImage(),
        ]),
    ]

    # 1. RandomChoice re-draws on every call, so applying one chooser three
    #    times gives three independent picks in application order.
    choose_one = transforms.RandomChoice(pool)
    for _ in range(3):
        img = choose_one(img)

    # 2. Elastic deformation. The draw for step 4 must come after this step,
    #    because elastic_transform may itself consume random numbers.
    if np.random.uniform() < 0.2:
        img = elastic_transform(img, sigma=10)

    # 4. Grayscale, keeping three output channels.
    if np.random.uniform() < 0.25:
        img = transforms.functional.to_grayscale(img, num_output_channels=3)

    return img