示例#1
0
    def __init__(self, root, args, transform=None, target_transform=None,
                 loader=default_loader, db_path='./data_split/labeled_images_0.10.pth', is_unlabeled=False):
        """Image-folder-style dataset whose file list comes from a saved split db.

        Args:
            root: class-folder root used only to discover class names.
            args: namespace providing batch_size / max_iter (and, when
                is_unlabeled, batch_size_unlabeled / unlabeled_iter).
            transform: optional callable applied to each image.
            target_transform: optional callable applied to each target.
            loader: image-loading function.
            db_path: path to the serialized (image, label) split file.
            is_unlabeled: selects which batch-size/iteration budget is used.

        Raises:
            RuntimeError: if the split db yields no images.
        """
        classes, class_to_idx = find_classes(root)
        imgs = load_db(db_path, class_to_idx)
        if not imgs:
            # raise is a statement; no call parentheses around the exception
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.is_unlabeled = is_unlabeled
        self.autoaugment = ImageNetPolicy()

        # pre-shuffled traversal order over the whole split
        self.indices = list(range(len(imgs)))
        random.shuffle(self.indices)
        # total number of samples drawn over the entire training run
        if self.is_unlabeled:
            self.total_train_count = args.batch_size_unlabeled * args.max_iter * args.unlabeled_iter
        else:
            self.total_train_count = args.batch_size * args.max_iter

        print("sample count {}".format(len(self.indices)))
        print("total sample count {}".format(self.total_train_count))
示例#2
0
def inception_autoaugment_preproccess(input_size, normalize=_IMAGENET_STATS):
    """Inception-style training preprocessing with AutoAugment.

    Random-resized crop to `input_size`, horizontal flip, the ImageNet
    AutoAugment policy, then tensor conversion and normalization with
    the stats in `normalize` (keyword dict for transforms.Normalize).
    """
    steps = [
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        ImageNetPolicy(fillcolor=(128, 128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ]
    return transforms.Compose(steps)
示例#3
0
 def imagenet(self):
     """Install an ImageNet-style augmentation pipeline on self.transform."""
     mean = (0.485, 0.456, 0.406)
     std = (0.229, 0.224, 0.225)
     steps = [
         transforms.CenterCrop(self.size),
         transforms.RandomHorizontalFlip(p=0.5),
         ImageNetPolicy(),
         transforms.ToTensor(),
         transforms.Normalize(mean, std),
     ]
     self.transform = transforms.Compose(steps)
示例#4
0
def load_data(image_size, category_filter, train_test_split, batch_size=64, num_workers=32):
    """Build (train, val, test) DataLoaders for iNaturalist-2019.

    Args:
        image_size: target size for transforms.Resize in both pipelines.
        category_filter: forwarded to InaturalistDataset.
        train_test_split: a float fraction for a random split, or a dict
            forwarded to InaturalistDataset.sample for an index-based split.
        batch_size: loader batch size (previously hard-coded to 64).
        num_workers: loader worker count (previously hard-coded to 32).

    Returns:
        (train_loader, val_loader, test_loader) tuple.

    Raises:
        TypeError: if train_test_split is neither float nor dict (the
            original fell through silently and crashed later with NameError).
    """
    train_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(30, translate=(0.3, 0.3), scale=(1.0, 1.5)),
        ImageNetPolicy(),
        transforms.ToTensor()
    ])
    test_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor()
    ])

    root_dir = '/train-data/inaturalist-2019/'
    train_test_set = InaturalistDataset(root_dir + 'train2019.json',
                                        root_dir,
                                        train_transform,
                                        category_filter=category_filter)
    val_set = InaturalistDataset(root_dir + 'val2019.json',
                                 root_dir,
                                 test_transform,
                                 category_filter=category_filter)

    if isinstance(train_test_split, float):
        train_size = int(len(train_test_set) * train_test_split)
        test_size = len(train_test_set) - train_size
        train_set, test_set = torch.utils.data.random_split(
            train_test_set, [train_size, test_size])
    elif isinstance(train_test_split, dict):
        train_indices, test_indices = train_test_set.sample(train_test_split)
        train_set = torch.utils.data.Subset(train_test_set, train_indices)
        test_set = torch.utils.data.Subset(train_test_set, test_indices)
    else:
        raise TypeError(
            'train_test_split must be a float or a dict, got {!r}'.format(
                type(train_test_split)))

    # NOTE(review): train_set and test_set wrap the SAME underlying dataset,
    # so this also switches the train subset to test_transform — confirm
    # that is intended before changing it.
    test_set.dataset.transform = test_transform

    print(
        f'train: {len(train_set)}, val: {len(val_set)}, test: {len(test_set)}')

    loader_kwargs = dict(batch_size=batch_size,
                         shuffle=True,
                         num_workers=num_workers)
    train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_set, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set, **loader_kwargs)

    return (train_loader, val_loader, test_loader)
示例#5
0
def load_train_validate_data_2(csv_file,
                               root_dir,
                               batch_size,
                               valid_size=100,
                               extra=True):
    """Build train/validation DataLoaders from an image dir + label csv.

    :param csv_file:        string
    :param root_dir:        string                  (directory with the directory of images)
    :param batch_size:      int
    :param valid_size:      int                     (amount of images in validation set)
    :return train_loader:   pytorch dataloader
            val_loader:     pytorch dataloader
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    # augmented pipeline for training samples
    train_tfms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        ImageNetPolicy(),
        transforms.ToTensor(),
        normalize,
    ])

    # deterministic pipeline for validation samples
    val_tfms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    val_dataset = DatasetTorch(csv_file=csv_file,
                               root_dir=root_dir,
                               transform=val_tfms)
    train_dataset = DatasetTorch(csv_file=csv_file,
                                 root_dir=root_dir,
                                 transform=train_tfms)

    # same seeded shuffle over both datasets; first `cut` indices go to
    # validation, the rest to training
    cut = int(np.floor(valid_size))
    train_order = shuffle(list(range(len(train_dataset))), random_state=0)
    val_order = shuffle(list(range(len(val_dataset))), random_state=0)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=SubsetRandomSampler(train_order[cut:]),
        batch_size=batch_size)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=SubsetRandomSampler(val_order[:cut]),
        batch_size=batch_size)
    return train_loader, val_loader
示例#6
0
def main():
    """Demo: apply AutoAugment's ImageNet policy to sample images."""
    # Targeted use of sample_pairing
    img = Image.open("./imgs/time1.jpg")
    img2 = Image.open("./imgs/time2.jpg")
    #    img.show()
    #    samp= SubPolicy(1,"sample_paring", 9,0,"rotate",9)
    #    samp(img,img2).show()
    # Randomly invoke one augmentation policy and display the result
    impolicy = ImageNetPolicy()
    impolicy(img).show()
    # NOTE(review): this nested function is defined but never called inside
    # main(), and it expects a `self` carrying im_file/count/data_save_dir —
    # it looks misplaced from a class body. Confirm intent.
    def IMAGENET_policy(self):
        """generates IMAGENET Policy """

        # source image path; basename is recovered via the regex below
        imag_tag = self.im_file
        policy = ImageNetPolicy()
        image = Image.open(imag_tag)
        tag = re.findall('([-\w]+\.(?:jpg|gif|png|JPG|JPEG|jpeg))', imag_tag)
        # save `self.count` randomly-augmented copies, numbered by index
        for t in range(self.count):
            polic = policy(image)
            polic.save(self.data_save_dir + str(t) + tag[0], format="JPEG")
示例#8
0
def get_train_transform(aug=None):
    """Build the training transform pipeline.

    Args:
        aug: optional container/string of augmentation flags; when it
            contains 'autoaug', AutoAugment's ImageNet policy is appended.

    Returns:
        A Compose of RandomCrop(320, pad_if_needed), horizontal flip and,
        optionally, the ImageNet AutoAugment policy.
    """
    # use a local name that does not shadow the `transforms` module
    tfms = [RandomCrop(320, pad_if_needed=True), RandomHorizontalFlip()]

    # guard: the default aug=None made `'autoaug' in aug` raise TypeError
    if aug is not None and 'autoaug' in aug:
        print('=> using auto augmentation.')
        tfms.append(ImageNetPolicy(fillcolor=(128, 128, 128)))

    return Compose(tfms)
示例#9
0
    def __init__(self,
                 image_paths,
                 labels=None,
                 train=True,
                 test=False,
                 aug=None,
                 use_onehot=False):
        """Dataset with configurable train-time augmentation pipelines.

        Args:
            image_paths: list of image file paths.
            labels: per-image labels; only stored when test is False.
            train: stored flag (not otherwise used in __init__).
            test: when True, labels are not stored on the instance.
            aug: optional collection of augmentation flags; recognized
                values are 'autoaug', 'albu' and 'cj'.
            use_onehot: stored flag for downstream label encoding.
        """
        self.paths = image_paths
        self.test = test
        self.use_onehot = use_onehot

        if not self.test:  # labels only exist for train/val data
            self.labels = labels
        self.train = train

        # train-time pipeline: resize -> flip -> crop -> optional extras
        aug_steps = [
            T.Resize((448, 448), interpolation=Image.LANCZOS),
            T.RandomHorizontalFlip(),
            T.RandomCrop((384, 384)),
        ]
        if aug is not None and 'autoaug' in aug:
            print('=> using auto augmentation.')
            aug_steps.append(ImageNetPolicy(fillcolor=(128, 128, 128)))
        if aug is not None and 'albu' in aug:
            print('=> using albu augmentation.')
            aug_steps.append(Rank1Aug(p=0.5))
        if aug is not None and 'cj' in aug:
            print('=> using color jittering.')
            aug_steps.append(
                T.ColorJitter(brightness=0.02,
                              contrast=0.02,
                              saturation=0.02,
                              hue=0.01))
        aug_steps.append(T.ToTensor())
        aug_steps.append(
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
        # random erasing runs on the tensor, after normalization
        aug_steps.append(RandomErasing())
        self.transform = T.Compose(aug_steps)

        # deterministic eval-time pipeline
        self.default_transform = T.Compose([
            T.Resize((448, 448), interpolation=Image.LANCZOS),
            T.CenterCrop((384, 384)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
def transform_training():
    """Return the train-time augmentation pipeline (ImageNet mean/std)."""
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    return transforms.Compose([
        transforms.RandomResizedCrop(cf.resize),
        transforms.RandomHorizontalFlip(),
        ImageNetPolicy(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
示例#11
0
def read_train_data(data_dir, flag, n_aug=7):
    """Read the training dataset, optionally expanding it with AutoAugment.

    INPUT: data_dir -- Training dataset directory.
           flag     -- 0 for no augmentation; any other value adds n_aug
                       AutoAugment variants per image.
           n_aug    -- number of augmented copies per image (was hard-coded
                       to 7; default preserves the original behavior).
    OUTPUT: fpaths  -- a list of paths of the files in the dataset.
            datas   -- training data, normalized to [0, 1].
            labels  -- corresponding labels for the training data.
    """
    datas = []   # a list of images
    labels = []  # a list of labels
    fpaths = []  # a list of image paths

    # One policy object is enough; each call draws a random sub-policy.
    policy = ImageNetPolicy()

    for fname in os.listdir(data_dir):
        fpath = os.path.join(data_dir, fname)
        fpaths.append(fpath)

        # label is encoded as the prefix of the file name, e.g. "3_xxx.jpg"
        label = int(fname.split("_")[0])

        # read the image and resize it to 256x256 RGB
        image = Image.open(fpath)
        image = image.resize((256, 256), Image.BILINEAR)
        image = image.convert("RGB")

        if flag == 0:
            # no augmentation: just normalize
            datas.append(np.array(image) / 255.0)
            labels.append(label)
        else:
            # original image plus n_aug randomly augmented variants,
            # all sharing the same label
            variants = [image] + [policy(image) for _ in range(n_aug)]
            labels.extend([label] * len(variants))
            datas.extend(np.array(im) / 255.0 for im in variants)

    datas = np.array(datas)
    labels = np.array(labels)

    print("Shape of training datas: {}\tshape of labels: {}".format(datas.shape, labels.shape))
    return fpaths, datas, labels
示例#12
0
def data_auto_augmentation(scale=2):
    """ Auto Augmentation on train dataset.

        Writes `scale` augmented copies of every image in the y/ and n/
        train folders back into the same folders, numbered sequentially.

        Example:
        >>> data_auto_augmentation(scale=2)

    """
    print('#' * 20)
    print("Start Image augmentation...")
    print('#' * 20)
    path_y = DATA_PATH_TRAIN + 'y/'
    save_dir_y = DATA_PATH_TRAIN + 'y/'
    path_n = DATA_PATH_TRAIN + 'n/'
    save_dir_n = DATA_PATH_TRAIN + 'n/'

    # One policy instance suffices; every call applies a random sub-policy.
    # (Previously re-constructed inside the innermost loop.)
    policy = ImageNetPolicy()

    # add augmentation on both class folders with a shared running counter
    j = 0
    for src_dir, dst_dir in ((path_y, save_dir_y), (path_n, save_dir_n)):
        for i in tqdm.tqdm(os.listdir(src_dir)):
            try:
                for _ in range(scale):
                    img = PIL.Image.open(src_dir + i)
                    img1 = policy(img)
                    img1.save(dst_dir + '{}.jpg'.format(j))
                    j += 1
            except Exception as e:
                # best-effort: skip unreadable/broken files, but do not
                # silently swallow (the bare `except: pass` also ate
                # KeyboardInterrupt)
                print('skip {}: {}'.format(i, e))
    print("Done.....")
示例#13
0
    def __init__(self, root, ann_file, is_train=True, size=299):
        """Dataset built from a COCO-style annotation json.

        Stores filenames/ids/classes from the annotation file and prepares
        the augmentation operators used at sampling time.
        """
        # load annotations
        print('Loading annotations from: ' + os.path.basename(ann_file))
        with open(ann_file) as data_file:
            ann_data = json.load(data_file)

        # one filename and id per image entry
        self.imgs = [entry['file_name'] for entry in ann_data['images']]
        self.ids = [entry['id'] for entry in ann_data['images']]

        # fall back to a single dummy class when there are no labels
        if 'annotations' in ann_data:
            self.classes = [
                entry['category_id'] for entry in ann_data['annotations']
            ]
        else:
            self.classes = [0] * len(self.imgs)

        # print out some stats
        print(f'\t{len(self.imgs)} images')
        print(f'\t{len(set(self.classes))} classes')

        self.root = root
        self.is_train = is_train
        self.loader = default_loader

        # augmentation hyper-parameters
        self.im_size = [size, size]
        self.mu_data = [0.485, 0.456, 0.406]
        self.std_data = [0.229, 0.224, 0.225]
        self.brightness = 0.4
        self.contrast = 0.4
        self.saturation = 0.4
        self.hue = 0.25

        # augmentation operators
        self.center_crop = transforms.CenterCrop(tuple(self.im_size))
        self.scale_aug = transforms.RandomResizedCrop(size=self.im_size[0])
        self.flip_aug = transforms.RandomHorizontalFlip()
        self.color_aug = transforms.ColorJitter(self.brightness, self.contrast,
                                                self.saturation, self.hue)
        self.tensor_aug = transforms.ToTensor()
        self.norm_aug = transforms.Normalize(mean=self.mu_data,
                                             std=self.std_data)

        self.autoaugment = ImageNetPolicy()
示例#14
0
def get_loaders(traindir,
                valdir,
                sz,
                bs,
                val_bs=None,
                workers=8,
                use_ar=False,
                min_scale=0.08,
                distributed=False,
                autoaugment=False):
    """Build ImageNet-style train/val loaders (optionally distributed).

    Returns (train_loader, val_loader, train_sampler, val_sampler).
    """
    val_bs = val_bs or bs  # default validation batch size to bs

    # train-time augmentation; AutoAugment is appended on demand
    aug_steps = [
        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ]
    if autoaugment:
        aug_steps.append(ImageNetPolicy())

    train_dataset = datasets.ImageFolder(traindir,
                                         transforms.Compose(aug_steps))

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=get_world_size(), rank=get_rank())
    else:
        train_sampler = None

    # shuffle only when no sampler shards the data for us
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=bs,
        shuffle=(train_sampler is None),
        num_workers=workers,
        pin_memory=True,
        collate_fn=fast_collate,
        sampler=train_sampler)

    val_dataset, val_sampler = create_validation_set(
        valdir, val_bs, sz, use_ar=use_ar, distributed=distributed)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        num_workers=workers,
        pin_memory=True,
        collate_fn=fast_collate,
        batch_sampler=val_sampler)

    return train_loader, val_loader, train_sampler, val_sampler
示例#15
0
def main():
    """Training/evaluation entry point driven by argparse + a YAML config.

    Merges the config's ``common`` section into the module-level ``args``,
    builds the model (optionally distributed / finetuned), the data
    loaders, optimizer and LR scheduler, then runs the epoch loop with
    per-epoch validation and best-checkpoint saving (rank 0 only).
    Mutates module globals ``args`` and ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()

    # safe_load avoids constructing arbitrary Python objects from the config
    # file; yaml.load without an explicit Loader is deprecated since PyYAML 5
    # and removed in PyYAML 6.
    with open(args.config) as f:
        config = yaml.safe_load(f)

    # flatten the 'common' section of the config onto args
    for k, v in config['common'].items():
        setattr(args, k, v)

    gpu_num = torch.cuda.device_count()

    if args.distributed:
        args.rank, args.size = init_processes(args.dist_addr, args.dist_port,
                                              gpu_num, args.dist_backend)
        print("=> using {} GPUS for distributed training".format(args.size))
    else:
        args.rank = 0
        print("=> using {} GPUS for training".format(gpu_num))

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = model_zoo[args.arch](num_classes=args.num_classes)

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            # alexnet/vgg: parallelize only the conv features, matching the
            # classic PyTorch ImageNet example
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        if args.finetune:
            # frozen backbone params never receive gradients, so DDP must
            # tolerate unused parameters
            model = torch.nn.parallel.DistributedDataParallel(
                model, [args.rank], find_unused_parameters=True)
        else:
            model = torch.nn.parallel.DistributedDataParallel(
                model, [args.rank])
        print('create DistributedDataParallel model successfully', args.rank)

    # only rank 0 writes tensorboard events, logs and checkpoints
    if args.rank == 0:
        mkdir_if_no_exist(args.save_path,
                          subdirs=['events/', 'logs/', 'checkpoints/'])
        tb_logger = SummaryWriter('{}/events'.format(args.save_path))
        logger = create_logger('global_logger',
                               '{}/logs/log.txt'.format(args.save_path))
        logger.debug(args)  # log args only to file
    else:
        tb_logger = None
        logger = None
    checkpoint_path = os.path.join(args.save_path, 'checkpoints/ckpt')

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # warm-start from a pretrained backbone, dropping the classifier head
    if args.pretrained:
        load_ckpt(args.pretrained,
                  model,
                  ignores=['module.fc.weight', 'module.fc.bias'],
                  strict=False)
    if args.evaluate:
        load_ckpt(args.pretrained, model)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            args.start_epoch, best_prec1 = load_ckpt(args.resume,
                                                     model,
                                                     optimizer=optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.finetune:
        # freeze everything except the final fully-connected layer
        for param in model.parameters():
            param.requires_grad = False
        model.module.fc.weight.requires_grad = True
        model.module.fc.bias.requires_grad = True

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = FileListDataset(
        args.train_list,
        args.train_root,
        transforms.Compose([
            transforms.RandomResizedCrop(args.input_size),
            transforms.RandomHorizontalFlip(),
            ImageNetPolicy(),
            transforms.ToTensor(),
            ColorAugmentation(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # shuffle only when no distributed sampler shards the data
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(FileListDataset(
        args.val_list, args.val_root,
        transforms.Compose([
            transforms.Resize(args.image_size),
            transforms.CenterCrop(args.input_size),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, logger, args.print_freq,
                 args.rank)
        return

    # every LR decay step must land inside the training run
    assert max(args.lr_steps) < args.epochs
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, args.lr_steps, args.gamma)

    # fast-forward the scheduler when resuming mid-run
    for _ in range(args.start_epoch):
        lr_scheduler.step()

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # reshuffle the shards deterministically per epoch
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, logger,
              tb_logger, args.print_freq, args.rank)
        lr_scheduler.step()

        # evaluate on validation set
        prec1, loss = validate(val_loader,
                               model,
                               criterion,
                               logger,
                               args.print_freq,
                               args.rank,
                               epoch=epoch + 1,
                               save_path=args.save_path)

        # remember best prec@1 and save checkpoint (rank 0 only)
        if args.rank == 0:
            tb_logger.add_scalar('test_acc', prec1, epoch)
            tb_logger.add_scalar('test_loss', loss, epoch)
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_ckpt(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, checkpoint_path, epoch + 1, is_best)
示例#16
0
from autoaugment import ImageNetPolicy
import os
import PIL
import matplotlib.pyplot as plt

# Preview script: display the AutoAugment ImageNet-policy transform for
# every .jpg under dataset_path.
dataset_path = '/media/mxq/data/competition/HuaWei/train_data'
images = [f for f in os.listdir(dataset_path) if f.endswith('.jpg')]
for image in images:
    path = os.path.join(dataset_path, image)
    image = PIL.Image.open(path)
    policy = ImageNetPolicy()
    transformed = policy(image)
    plt.imshow(transformed)
    plt.show()
    # NOTE(review): the statement below references transform_train_list,
    # RandomErasing and opt, none of which are defined above — it appears to
    # be spliced in from a different script and would raise NameError inside
    # this loop. Confirm before relying on this file.
    transform_train_list = transform_train_list + [
        RandomErasing(probability=opt.erasing_p, mean=[0.0, 0.0, 0.0])
    ]

# NOTE(review): from here on the code depends on opt / transform_val_list /
# transform_satellite_list / data_dir defined elsewhere (not visible in this
# chunk).
if opt.color_jitter:
    # prepend mild color jitter to both train and satellite pipelines
    transform_train_list = [
        transforms.ColorJitter(
            brightness=0.1, contrast=0.1, saturation=0.1, hue=0)
    ] + transform_train_list
    transform_satellite_list = [
        transforms.ColorJitter(
            brightness=0.1, contrast=0.1, saturation=0.1, hue=0)
    ] + transform_satellite_list

if opt.DA:
    # data augmentation flag: prepend the AutoAugment ImageNet policy
    transform_train_list = [ImageNetPolicy()] + transform_train_list

print(transform_train_list)
data_transforms = {
    'train': transforms.Compose(transform_train_list),
    'val': transforms.Compose(transform_val_list),
    'satellite': transforms.Compose(transform_satellite_list)
}

train_all = ''
if opt.train_all:
    train_all = '_all'

image_datasets = {}
image_datasets['satellite'] = datasets.ImageFolder(
    os.path.join(data_dir, 'satellite'), data_transforms['satellite'])
示例#18
0
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

# create model
#####################################################################################

    if args.pretrained:
        if args.arch.startswith('efficientnet-b'):
            print('=> using pre-trained {}'.format(args.arch))
            model = EfficientNet.from_pretrained(args.arch,
                                                 advprop=args.advprop)

        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True)
    else:
        if args.arch.startswith('efficientnet-b'):
            print("=> creating model {}".format(args.arch))
            model = EfficientNet.from_name(args.arch)
        elif args.arch.startswith('Dense'):
            print("=> creating model {}".format(args.arch))
            model = DenseNet40()
        else:
            print("=> creating model '{}'".format(args.arch))
            model = models.__dict__[args.arch]()

    # create teacher model
    if args.kd:
        print('=> loading teacher model')
        if args.teacher_arch.startswith('efficientnet-b'):
            teacher = EfficientNet.from_pretrained(args.teacher_arch)
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))

        elif args.teacher_arch.startswith('resnext101_32'):
            teacher = torch.hub.load('facebookresearch/WSL-Images',
                                     '{}_wsl'.format(args.teacher_arch))
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))
        elif args.overhaul:
            teacher = resnet.resnet152(pretrained=True)
        else:
            teacher = models.__dict__[args.teacher_arch](pretrained=True)
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))

        if args.overhaul:
            print('=> using overhaul distillation')
            d_net = Distiller(teacher, model)

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
            if args.kd:
                teacher = torch.nn.DataParallel(teacher).cuda()
                if args.overhaul:
                    d_net = torch.nn.DataParallel(d_net).cuda()

    if args.pretrained:
        if args.arch.startswith('efficientnet-b'):
            loc = 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(args.pth_path, map_location=loc)
            model.load_state_dict(checkpoint['state_dict'])
#####################################################################################

# define loss function (criterion) and optimizer, scheduler
#####################################################################################
    if args.kd:
        criterion = kd_criterion
        if args.overhaul:
            criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    else:
        criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    if args.overhaul:
        optimizer = torch.optim.SGD(list(model.parameters()) +
                                    list(d_net.module.Connectors.parameters()),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)  # nesterov
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=args.lr,
                                      betas=(0.9, 0.999),
                                      eps=1e-08,
                                      weight_decay=args.weight_decay,
                                      amsgrad=False)
        scheduler = CosineAnnealingLR(optimizer,
                                      T_max=args.epochs *
                                      int(1281167 / args.batch_size),
                                      eta_min=0,
                                      last_epoch=-1)
        args.lr = 0.048
        args.batch_size = 384
        parameters = add_weight_decay(model, 1e-5)
        optimizer = RMSpropTF(parameters,
                              lr=0.048,
                              alpha=0.9,
                              eps=0.001,
                              momentum=args.momentum,
                              weight_decay=1e-5)

        scheduler = StepLRScheduler(
            optimizer,
            decay_t=2.4,
            decay_rate=0.97,
            warmup_lr_init=1e-6,
            warmup_t=3,
            noise_range_t=None,
            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
            noise_std=getattr(args, 'lr_noise_std', 1.),
            noise_seed=getattr(args, 'seed', 42),
        )
    # scheduler = MultiStepLR(optimizer, milestones=args.schedule, gamma=args.gamma)
    # milestone = np.ceil(np.arange(0,300,2.4))

    # scheduler = MultiStepLR(optimizer, milestones=[30,60,90,120,150,180,210,240,270], gamma=0.1)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    #####################################################################################

    # Data loading code
    #####################################################################################
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.advprop:
        normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    train_dataset = ImageFolder_iid(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            ImageNetPolicy(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(ImageFolder_iid(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    #####################################################################################

    if args.evaluate:
        validate(val_loader, model, criterion, args)

# Start training
#####################################################################################
    best_acc1 = 0
    teacher_name = ''
    student_name = ''

    ema = EMA(model, 0.9999)
    ema.register()
    for epoch in range(args.start_epoch, args.epochs):

        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        if args.kd:
            if args.overhaul:
                train_with_overhaul(train_loader, d_net, optimizer, criterion,
                                    epoch, args)
                acc1 = validate_overhaul(val_loader, model, criterion, epoch,
                                         args)
            else:
                train_kd(train_loader, teacher, model, criterion, optimizer,
                         epoch, args)
                acc1 = validate_kd(val_loader, teacher, model, criterion, args)

                teacher_name = teacher.module.__class__.__name__

        else:
            student_name = model.module.__class__.__name__
            train(train_loader, model, criterion, optimizer, epoch, args, ema)
            acc1 = validate(val_loader, model, criterion, args, ema)

        # remember best acc@1 and save checkpoint
        #writer.add_scalars('acc1', acc1, epoch)
        is_best = acc1 > best_acc1
        if acc1 < 65:
            print(colored('not saving... accuracy smaller than 65', 'green'))
            is_best = False
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                teacher_name=teacher_name,
                student_name=student_name,
                save_path=args.save_path,
                acc=acc1)

        scheduler.step(epoch)
示例#19
0
def prepare_dataset(
    data_path,
    image_names_reg=None,
    image_classes_rule=None,
    batch_size=128,
    random_status=2,
    cache=True,
    shuffle_buffer_size=None,
    is_train=True,
):
    """Build a ``tf.data.Dataset`` of (image, one-hot label) batches.

    Images are listed via ``pre_process_folder`` and fed through a Keras
    ``ImageDataGenerator``:

    * ``is_train`` and ``random_status != -1``: random flip / rotation /
      brightness / shear / zoom, each scaled by ``random_status``.
    * ``is_train`` and ``random_status == -1``: AutoAugment ImageNet policy.
    * evaluation (``is_train`` False): rescaling only.

    Pixel values are finally mapped from [0, 1] to [-1, 1].

    Args:
        data_path: root folder containing the images.
        image_names_reg: forwarded to ``pre_process_folder`` (name parsing).
        image_classes_rule: forwarded to ``pre_process_folder``.
        batch_size: samples per batch.
        random_status: augmentation strength; ``-1`` selects AutoAugment.
        cache: currently unused; kept for backward compatibility.
        shuffle_buffer_size: currently unused; kept for backward compatibility.
        is_train: enables augmentation and infinite repetition.

    Returns:
        ``(dataset, steps_per_epoch, num_classes)``, or ``(None, 0, 0)`` when
        no image is found under *data_path*.
    """
    image_names, image_classes, classes = pre_process_folder(
        data_path, image_names_reg, image_classes_rule)
    if len(image_names) == 0:
        # Bug fix: the success path returns a 3-tuple, so the empty case must
        # as well — callers unpacking three values would otherwise raise
        # ValueError on an empty folder.
        return None, 0, 0
    print(len(image_names), len(image_classes), classes)

    data_df = pd.DataFrame({
        "image_names": image_names,
        "image_classes": image_classes
    })
    # flow_from_dataframe expects string labels for categorical class_mode.
    data_df.image_classes = data_df.image_classes.map(str)

    if is_train:
        if random_status != -1:
            image_gen = ImageDataGenerator(
                rescale=1.0 / 255,
                horizontal_flip=True,
                rotation_range=random_status * 5,
                brightness_range=(1.0 - random_status * 0.1,
                                  1.0 + random_status * 0.1),
                shear_range=random_status * 5,
                zoom_range=random_status * 0.15,
                fill_mode="constant",
                cval=0,
                preprocessing_function=lambda img: image_aug_random(
                    img, random_status),
            )
        else:
            from autoaugment import ImageNetPolicy

            policy = ImageNetPolicy()
            # AutoAugment operates on PIL images, so round-trip each array
            # through a PIL image before converting back to float32.
            policy_func = lambda img: np.array(policy(
                tf.keras.preprocessing.image.array_to_img(img)),
                                               dtype=np.float32)
            image_gen = ImageDataGenerator(rescale=1.0 / 255,
                                           horizontal_flip=True,
                                           preprocessing_function=policy_func)
    else:
        image_gen = ImageDataGenerator(rescale=1.0 / 255)

    train_data_gen = image_gen.flow_from_dataframe(
        data_df,
        directory=None,
        x_col="image_names",
        y_col="image_classes",
        class_mode="categorical",
        target_size=(112, 112),
        batch_size=batch_size,
        validate_filenames=False,
    )
    classes = data_df.image_classes.unique().shape[0]
    steps_per_epoch = np.ceil(data_df.shape[0] / batch_size)

    # Wrap the Keras generator as a tf.data pipeline.
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    train_ds = tf.data.Dataset.from_generator(
        lambda: train_data_gen,
        output_types=(tf.float32, tf.int32),
        output_shapes=([None, 112, 112, 3], [None, classes]))
    if is_train:
        train_ds = train_ds.repeat()
    # Map pixel range [0, 1] -> [-1, 1].
    train_ds = train_ds.map(lambda xx, yy:
                            ((tf.clip_by_value(xx, 0.0, 1.0) - 0.5) * 2, yy),
                            num_parallel_calls=AUTOTUNE)
    train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)

    return train_ds, steps_per_epoch, classes
示例#20
0
        transforms.ToTensor(),
        transforms.Normalize((0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)),
    ]
)


def transform_strong_c1m_c10(x):
    """Apply the module-level strong augmentation pipeline
    ``transform_strong_c1m_c10_compose`` to image *x* and return the result."""
    return transform_strong_c1m_c10_compose(x)


# Strong augmentation for 224x224 (ImageNet-sized) inputs: resize, random
# crop, horizontal flip, AutoAugment ImageNet policy, then tensor conversion
# and channel-wise normalization with dataset statistics.
_strong_c1m_in_steps = [
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    ImageNetPolicy(),
    transforms.ToTensor(),
    transforms.Normalize((0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)),
]
transform_strong_c1m_in_compose = transforms.Compose(_strong_c1m_in_steps)


def transform_strong_c1m_in(x):
    """Apply the module-level strong augmentation pipeline
    ``transform_strong_c1m_in_compose`` to image *x* and return the result."""
    return transform_strong_c1m_in_compose(x)


class clothing_dataset(Dataset):
    def __init__(
        self,
        root,
        transform,
示例#21
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process training worker: builds the (student) model, an optional
    teacher for knowledge distillation, the optimizer/scheduler, the ImageNet
    data loaders, then runs the epoch loop with checkpointing.

    Args:
        gpu: local GPU index for this worker (or None for DataParallel mode).
        ngpus_per_node: number of GPUs on this node (used for rank math and
            per-process batch-size splitting).
        args: parsed CLI namespace; read fields include kd, overhaul,
            distributed, pretrained, arch, teacher_arch, lr, batch_size, etc.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

# create model
#####################################################################################

    if args.pretrained:
        if args.arch.startswith('efficientnet-b'):
            print('=> using pre-trained {}'.format(args.arch))
            model = EfficientNet.from_pretrained(args.arch, advprop=args.advprop)

        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True)
    else:
        if args.arch.startswith('efficientnet-b'):
            print("=> creating model {}".format(args.arch))
            model = EfficientNet.from_name(args.arch)
        elif args.arch.startswith('Dense'):
            print("=> creating model {}".format(args.arch))
            model = DenseNet40()
        else:
            print("=> creating model '{}'".format(args.arch))
            model = models.__dict__[args.arch]()

    # create teacher model
    if args.kd:
        print('=> loading teacher model')
        if args.teacher_arch.startswith('efficientnet-b'):
            teacher = EfficientNet.from_pretrained(args.teacher_arch)
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))

        elif args.teacher_arch.startswith('resnext101_32'):
            teacher = torch.hub.load('facebookresearch/WSL-Images', '{}_wsl'.format(args.teacher_arch))
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))
        elif args.overhaul:
            teacher = resnet.resnet152(pretrained=True)
        else:
            teacher = models.__dict__[args.teacher_arch](pretrained=True)
            teacher.eval()
            print('=> {} loaded'.format(args.teacher_arch))

        if args.overhaul:
            print('=> using overhaul distillation')
            d_net = Distiller(teacher, model)

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # Split the global batch size / worker count across this node's
            # processes (one process per GPU).
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            # NOTE(review): model.cuda() is called twice here; the second call
            # is a redundant no-op duplicate.
            model.cuda()
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
            if args.kd:
                teacher = torch.nn.DataParallel(teacher).cuda()
                if args.overhaul:
                    d_net = torch.nn.DataParallel(d_net).cuda()

    if args.pretrained:
        if args.arch.startswith('efficientnet-b'):
            loc = 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(args.pth_path, map_location=loc)
            model.load_state_dict(checkpoint['state_dict'])
#####################################################################################


# define loss function (criterion) and optimizer, scheduler
#####################################################################################
    if args.kd:
        criterion = kd_criterion
        if args.overhaul:
            criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    else:
        criterion = nn.CrossEntropyLoss().cuda(args.gpu)


    # NOTE(review): args.overhaul is used here without checking args.kd, but
    # d_net only exists when args.kd is set — overhaul without kd would raise
    # NameError. Presumably the flags are always used together; confirm.
    if args.overhaul:
        optimizer = torch.optim.SGD(list(model.parameters()) + list(d_net.module.Connectors.parameters()), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)  # nesterov
    else:
        # NOTE(review): each of the next three optimizer assignments overwrites
        # the previous one — the SGD and AdamW instances below are dead code;
        # only the RMSprop optimizer further down is actually used.
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=args.weight_decay, amsgrad=False)
        # NOTE(review): this CosineAnnealingLR is likewise overwritten by the
        # StepLRScheduler constructed at the end of this branch, and it holds a
        # reference to the soon-to-be-discarded AdamW optimizer.
        scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs * int(1281167 / args.batch_size), eta_min=0,
                                      last_epoch=-1)
        # Hard-coded EfficientNet-style hyperparameters override the CLI values.
        args.lr = 0.048
        # NOTE(review): args.bs is set here but the rest of this function reads
        # args.batch_size — possibly a typo / leftover; confirm.
        args.bs = 384
        optimizer = torch.optim.RMSprop(
            model.parameters(), lr=args.lr, alpha=0.9, eps=.001,
            momentum=0.9, weight_decay=args.weight_decay)

        from typing import Dict, Any
        # NOTE(review): the two classes below appear to be inlined copies of
        # the timm library's Scheduler / StepLRScheduler — consider importing
        # them instead of duplicating them inside this function; verify.
        class Scheduler:
            """ Parameter Scheduler Base Class
            A scheduler base class that can be used to schedule any optimizer parameter groups.
            Unlike the builtin PyTorch schedulers, this is intended to be consistently called
            * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
            * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
            The schedulers built on this should try to remain as stateless as possible (for simplicity).
            This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
            and -1 values for special behaviour. All epoch and update counts must be tracked in the training
            code and explicitly passed in to the schedulers on the corresponding step or step_update call.
            Based on ideas from:
             * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
             * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
            """

            def __init__(self,
                         optimizer: torch.optim.Optimizer,
                         param_group_field: str,
                         noise_range_t=None,
                         noise_type='normal',
                         noise_pct=0.67,
                         noise_std=1.0,
                         noise_seed=None,
                         initialize: bool = True) -> None:
                self.optimizer = optimizer
                self.param_group_field = param_group_field
                self._initial_param_group_field = f"initial_{param_group_field}"
                if initialize:
                    # Stash each group's starting value so schedules can be
                    # recomputed statelessly from the base values.
                    for i, group in enumerate(self.optimizer.param_groups):
                        if param_group_field not in group:
                            raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
                        group.setdefault(self._initial_param_group_field, group[param_group_field])
                else:
                    for i, group in enumerate(self.optimizer.param_groups):
                        if self._initial_param_group_field not in group:
                            raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
                self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
                self.metric = None  # any point to having this for all?
                self.noise_range_t = noise_range_t
                self.noise_pct = noise_pct
                self.noise_type = noise_type
                self.noise_std = noise_std
                self.noise_seed = noise_seed if noise_seed is not None else 42
                self.update_groups(self.base_values)

            def state_dict(self) -> Dict[str, Any]:
                # Exclude the optimizer reference from serialized state.
                return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

            def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
                self.__dict__.update(state_dict)

            def get_epoch_values(self, epoch: int):
                # Subclasses return per-group values for epoch-based schedules.
                return None

            def get_update_values(self, num_updates: int):
                # Subclasses return per-group values for step-based schedules.
                return None

            def step(self, epoch: int, metric: float = None) -> None:
                self.metric = metric
                values = self.get_epoch_values(epoch)
                if values is not None:
                    values = self._add_noise(values, epoch)
                    self.update_groups(values)

            def step_update(self, num_updates: int, metric: float = None):
                self.metric = metric
                values = self.get_update_values(num_updates)
                if values is not None:
                    values = self._add_noise(values, num_updates)
                    self.update_groups(values)

            def update_groups(self, values):
                # Broadcast a scalar value to every param group.
                if not isinstance(values, (list, tuple)):
                    values = [values] * len(self.optimizer.param_groups)
                for param_group, value in zip(self.optimizer.param_groups, values):
                    param_group[self.param_group_field] = value

            def _add_noise(self, lrs, t):
                # Optionally perturb the LR within +/- noise_pct, seeded by t
                # so the noise is reproducible per step/epoch.
                if self.noise_range_t is not None:
                    if isinstance(self.noise_range_t, (list, tuple)):
                        apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
                    else:
                        apply_noise = t >= self.noise_range_t
                    if apply_noise:
                        g = torch.Generator()
                        g.manual_seed(self.noise_seed + t)
                        if self.noise_type == 'normal':
                            while True:
                                # resample if noise out of percent limit, brute force but shouldn't spin much
                                noise = torch.randn(1, generator=g).item()
                                if abs(noise) < self.noise_pct:
                                    break
                        else:
                            noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
                        lrs = [v + v * noise for v in lrs]
                return lrs
        class StepLRScheduler(Scheduler):
            """Step decay with linear warmup: LR is multiplied by decay_rate
            every decay_t epochs (or updates when t_in_epochs is False).
            """

            def __init__(self,
                         optimizer: torch.optim.Optimizer,
                         decay_t: float,
                         decay_rate: float = 1.,
                         warmup_t=0,
                         warmup_lr_init=0,
                         t_in_epochs=True,
                         noise_range_t=None,
                         noise_pct=0.67,
                         noise_std=1.0,
                         noise_seed=42,
                         initialize=True,
                         ) -> None:
                super().__init__(
                    optimizer, param_group_field="lr",
                    noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
                    initialize=initialize)

                self.decay_t = decay_t
                self.decay_rate = decay_rate
                self.warmup_t = warmup_t
                self.warmup_lr_init = warmup_lr_init
                self.t_in_epochs = t_in_epochs
                if self.warmup_t:
                    # Linear ramp from warmup_lr_init to each group's base LR.
                    self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
                    super().update_groups(self.warmup_lr_init)
                else:
                    self.warmup_steps = [1 for _ in self.base_values]

            def _get_lr(self, t):
                if t < self.warmup_t:
                    lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
                else:
                    lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
                return lrs

            def get_epoch_values(self, epoch: int):
                if self.t_in_epochs:
                    return self._get_lr(epoch)
                else:
                    return None

            def get_update_values(self, num_updates: int):
                if not self.t_in_epochs:
                    return self._get_lr(num_updates)
                else:
                    return None

        scheduler = StepLRScheduler(
            optimizer,
            decay_t=2.4,
            decay_rate=0.97,
            warmup_lr_init=1e-6,
            warmup_t=3,
            noise_range_t=None,
            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
            noise_std=getattr(args, 'lr_noise_std', 1.),
            noise_seed=getattr(args, 'seed', 42),
        )
    # scheduler = MultiStepLR(optimizer, milestones=args.schedule, gamma=args.gamma)
    # milestone = np.ceil(np.arange(0,300,2.4))

    # scheduler = MultiStepLR(optimizer, milestones=[30,60,90,120,150,180,210,240,270], gamma=0.1)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
#####################################################################################


# Data loading code
#####################################################################################
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.advprop:
        # AdvProp-pretrained EfficientNets expect inputs in [-1, 1].
        normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])


    train_dataset = ImageFolder_iid(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            ImageNetPolicy(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        ImageFolder_iid(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
#####################################################################################

    if args.evaluate:
        validate(val_loader, model, criterion, args)

# Start training
#####################################################################################
    # NOTE(review): this reset discards the best_acc1 value restored from the
    # checkpoint in the resume branch above — resumed runs will re-save worse
    # checkpoints as "best"; confirm whether this is intended.
    best_acc1 = 0
    teacher_name = ''
    student_name = ''
    for epoch in range(args.start_epoch, args.epochs):

        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        if args.kd:
            if args.overhaul:
                train_with_overhaul(train_loader, d_net, optimizer, criterion, epoch, args)
                acc1 = validate_overhaul(val_loader, model, criterion, epoch, args)
            else:
                train_kd(train_loader, teacher, model, criterion, optimizer, epoch, args)
                acc1 = validate_kd(val_loader, teacher, model, criterion, args)

                teacher_name = teacher.module.__class__.__name__

        else:
            student_name = model.module.__class__.__name__
            train(train_loader, model, criterion, optimizer, epoch, args)
            acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        #writer.add_scalars('acc1', acc1, epoch)
        is_best = acc1 > best_acc1
        # Quality gate: never mark a checkpoint "best" below 65% top-1.
        if acc1 < 65:
            print(colored('not saving... accuracy smaller than 65','green'))
            is_best = False
        best_acc1 = max(acc1, best_acc1)


        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best, teacher_name=teacher_name, student_name=student_name, save_path=args.save_path, acc=acc1)

        # NOTE(review): scheduler is only assigned in the non-overhaul branch
        # above — with args.overhaul set this line raises NameError; confirm.
        scheduler.step(epoch)
示例#22
0
    def __init__(self,
                 images_path,
                 labels_path,
                 mode=False,
                 setname='train',
                 way=5,
                 shot=1,
                 query=15,
                 augmentation=False,
                 augment_rate=0.5):
        """Load a few-shot image dataset described by ``<setname>.csv``.

        Args:
            images_path: directory containing the image files.
            labels_path: directory containing ``<setname>.csv`` with
                ``filename,label`` rows (the first line is a header).
            mode: stored as-is on the instance; semantics depend on the
                caller — TODO confirm.
            setname: which CSV split to load ('train', 'val', ...).
            way: classes per episode (N-way).
            shot: support samples per class (K-shot).
            query: query samples per class.
            augmentation: whether the AutoAugment transform is available.
            augment_rate: per-sample probability of applying augmentation.
        """
        assert os.path.exists(images_path), "there is no directory {}".format(
            images_path)
        assert os.path.exists(labels_path), "there is no directory {}".format(
            labels_path)

        self.mode = mode
        self.way = way
        self.shot = shot
        self.query = query
        self.augmentation = augmentation
        self.augment_rate = augment_rate

        # static settings
        self.channel = 3
        self.size = 84

        self.datas = []
        # Map each label string to a class id in order of first appearance.
        label_to_id = {}
        label_ids = []
        with open(os.path.join(labels_path, setname + ".csv")) as f:
            # skip the header line
            for line in f.readlines()[1:]:
                # Strip the trailing newline so the last row (which may lack
                # '\n') cannot be miscounted as a distinct label.
                filename, label = line.rstrip('\n').split(',')
                self.datas.append(os.path.join(images_path, filename))
                if label not in label_to_id:
                    label_to_id[label] = len(label_to_id)
                # Bug fix: record this row's own class id. The original code
                # appended the running class counter (the id of the most
                # recently *new* label), which is wrong whenever labels are
                # not grouped contiguously in the CSV.
                label_ids.append(label_to_id[label])
        # float64 keeps the dtype of the original np.append-based build-up.
        self.labels = torch.from_numpy(np.array(label_ids, dtype=np.float64))
        self.num_classes = len(label_to_id)

        # default transform
        self.transform = transforms.Compose([
            transforms.Resize(self.size),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # autoaugmentation transform
        self.transform_autoaugment = transforms.Compose([
            transforms.Resize(self.size),
            transforms.CenterCrop(self.size),
            ImageNetPolicy(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
示例#23
0
def main():
    """Entry point: unsupervised instance-discrimination training on ImageNet.

    Parses CLI args, builds the model (DataParallel or DistributedDataParallel),
    the train/val loaders (optionally with AutoAugment), the non-parametric
    memory bank ("lemniscate") and NCE/CE criterion, then runs the
    train / nearest-neighbour-validate loop and a final kNN evaluation.

    Relies on module-level globals: `args`, `best_prec1`, `parser`, `models`,
    `datasets`, and the NCE/kNN helpers.
    """
    global args, best_prec1
    args = parser.parse_args()

    # more than one participating node => distributed training
    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)

    if not args.distributed:
        # alexnet/vgg: only the conv features are data-parallelized (their
        # large classifier heads stay on a single GPU); other archs are
        # wrapped whole in DataParallel
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if not args.auto_aug:
        # standard instance-discrimination augmentation: aggressive random
        # crop, random grayscale, color jitter, horizontal flip
        train_dataset = datasets.ImageFolderInstance(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
                transforms.RandomGrayscale(p=0.2),
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
    else:
        # NOTE(review): hard-coded per-user checkout path for AutoAugment —
        # consider making this configurable
        sys.path.append('/home/chengxuz/AutoAugment')
        from autoaugment import ImageNetPolicy
        train_dataset = datasets.ImageFolderInstance(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
                transforms.RandomHorizontalFlip(),
                ImageNetPolicy(),
                transforms.ToTensor(),
                normalize,
            ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # shuffle only when no sampler is used (a sampler owns the ordering)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolderInstance(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define lemniscate and loss function (criterion)
    # nce_k > 0 => NCE with k noise samples; otherwise full softmax over
    # the linear memory bank with plain cross-entropy
    ndata = train_dataset.__len__()
    if args.nce_k > 0:
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t,
                                args.nce_m).cuda()
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint (restores epoch, best accuracy,
    # model weights, the memory bank, and optimizer state)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        # evaluation-only mode: kNN with 200 neighbours, then exit
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # re-seed the sampler so each epoch gets a new shard shuffle
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)

        # remember best prec@1 and save checkpoint
        # NOTE(review): assumes best_prec1 is initialized at module level —
        # verify, otherwise this raises NameError on a fresh (non-resumed) run
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
    # evaluate KNN after last epoch
    kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
示例#24
0
def train():
    """Fine-tune a pretrained GoogLeNet on the training set.

    Augments with AutoAugment's ImageNet policy, trains with SGD
    (lr=0.001, momentum=0.9) for 200 epochs, evaluates on the test set
    after every epoch, and finally saves the model with the best test
    accuracy to ``model-<acc>-best_train_acc.pth``.

    Relies on module-level globals: ``IMAGE_Dataset``, ``DATASET_ROOT_train``,
    ``CUDA_DEVICES`` and the ``test`` function.
    """
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        ImageNetPolicy(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_set = IMAGE_Dataset(Path(DATASET_ROOT_train), data_transform)
    data_loader = DataLoader(dataset=train_set,
                             batch_size=16,
                             shuffle=True,
                             num_workers=1)
    model = torch.hub.load('pytorch/vision:v0.6.0',
                           'googlenet',
                           pretrained=True)
    model = model.cuda(CUDA_DEVICES)
    model.train()

    best_model_params = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    num_epochs = 200
    criterion = nn.CrossEntropyLoss()

    # BUGFIX: the optimizer used to be re-created inside the epoch loop
    # (a leftover from a commented-out cyclical-LR experiment), which
    # silently reset the SGD momentum buffers every epoch. Create it once
    # so momentum state accumulates across the whole run.
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=0.001,
                                momentum=0.9)

    for epoch in range(num_epochs):
        print(f'Epoch: {epoch + 1}/{num_epochs}')
        print('-' * len(f'Epoch: {epoch + 1}/{num_epochs}'))

        training_loss = 0.0
        training_corrects = 0

        for i, (inputs, labels) in enumerate(data_loader):
            # `Variable` is a deprecated no-op wrapper; moving tensors to
            # the GPU directly is equivalent
            inputs = inputs.cuda(CUDA_DEVICES)
            labels = labels.cuda(CUDA_DEVICES)

            optimizer.zero_grad()

            outputs = model(inputs)

            _, preds = torch.max(outputs.data, 1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # loss.item() is the mean per-sample loss; re-weight by batch
            # size so the epoch average is exact even for a ragged last batch
            training_loss += loss.item() * inputs.size(0)
            training_corrects += torch.sum(preds == labels.data)

        training_loss = training_loss / len(train_set)
        training_acc = training_corrects.double() / len(train_set)
        print(
            f'Training loss: {training_loss:.4f}\taccuracy: {training_acc:.4f}\n'
        )

        test_acc = test(model)

        # keep a snapshot of the weights with the best test accuracy
        if test_acc > best_acc:
            best_acc = test_acc
            best_model_params = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_params)
    torch.save(model, f'model-{best_acc:.02f}-best_train_acc.pth')
示例#25
0
def main():
    """Entry point: train a WSDAN fine-grained classifier on Stanford Cars.

    Builds the chosen backbone (inception / resnetcbam / efficientnetb3),
    wraps it in WSDAN with 64 attention maps, optionally resumes from a
    checkpoint, and runs the train/validate loop with ReduceLROnPlateau
    scheduling, tracking the best validation accuracy.

    Relies on module-level globals: `device`, `image_size`, and the
    `train` / `validate` helpers.
    """
    options, args = parse_args()
    logging.basicConfig(
        format=
        '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # Initialize model
    num_classes = 196       # Stanford Cars has 196 classes
    num_attentions = 64     # number of WSDAN attention maps
    start_epoch = 0
    if options.model == 'resnetcbam':
        feature_net = resnet152_cbam(pretrained=True)
    elif options.model == 'efficientnetb3':
        feature_net = EfficientNet.from_pretrained('efficientnet-b3')
    elif options.model == 'inception':
        feature_net = inception_v3(pretrained=True)
    else:
        raise NotImplementedError(
            f'Invalid model name {options.model}, acceptable values are \
                                    inception/resnetcbam/efficientnetb3/efficientnetb4'
        )
    net = WSDAN(num_classes=num_classes, M=num_attentions, net=feature_net)

    # feature_center: size of (#classes, #attention_maps, #channel_features)
    feature_center = torch.zeros(num_classes, num_attentions,
                                 net.num_features * net.expansion).to(
                                     torch.device(device))

    if options.ckpt:
        ckpt = options.ckpt
        # the checkpoint filename encodes the epoch number, e.g. "12.ckpt"
        start_epoch = int((ckpt.split('/')[-1]).split('.')[0])

        # Load ckpt and get state_dict
        checkpoint = torch.load(ckpt)
        state_dict = checkpoint['state_dict']

        # Load weights
        net.load_state_dict(state_dict)
        logging.info('Network loaded from {}'.format(options.ckpt))

        # load feature center
        if 'feature_center' in checkpoint:
            feature_center = checkpoint['feature_center'].to(
                torch.device(device))
            logging.info('feature_center loaded from {}'.format(options.ckpt))

    # Initialize saving directory
    save_dir = options.save_dir
    os.makedirs(save_dir, exist_ok=True)

    # Use cuda
    cudnn.benchmark = True
    net.to(torch.device(device))
    net = nn.DataParallel(net)

    # Load dataset (default location relative to the working directory)
    cwd = Path.cwd()
    if not options.data_dir:
        data_dir = cwd.parent / 'data' / 'stanford-car-dataset-by-classes-folder' / 'car_data_new_data_in_train_v2'
    else:
        data_dir = options.data_dir

    # training images get AutoAugment; validation is resize-only
    preprocess_with_augment = transforms.Compose([
        transforms.Resize(size=(image_size[0], image_size[1]),
                          interpolation=Image.LANCZOS),
        ImageNetPolicy(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    preprocess = transforms.Compose([
        transforms.Resize(size=(image_size[0], image_size[1]),
                          interpolation=Image.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dataset = ImageFolder(str(data_dir / 'train'),
                                transform=preprocess_with_augment)
    validate_dataset = ImageFolder(str(data_dir / 'test'),
                                   transform=preprocess)

    train_loader = DataLoader(train_dataset,
                              batch_size=options.batch_size,
                              shuffle=True,
                              num_workers=options.workers,
                              pin_memory=True)
    # validation uses larger batches since no gradients are kept
    validate_loader = DataLoader(validate_dataset,
                                 batch_size=options.batch_size * 4,
                                 shuffle=False,
                                 num_workers=options.workers,
                                 pin_memory=True)

    # Optimizer and loss
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=options.lr,
                                momentum=0.9,
                                weight_decay=0.00001)
    loss = nn.CrossEntropyLoss()

    # Learning rate scheduling: decay by 10x after 10 epochs without
    # validation-loss improvement
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.1,
                                                           patience=10,
                                                           verbose=True,
                                                           threshold=0.00001)

    # Training
    logging.info('')
    logging.info((
        f'Start training: Total epochs: {options.epochs}, Batch size: {options.batch_size}, '
        f'Training size: {len(train_dataset)}, Validation size: {len(validate_dataset)}'
    ))
    best_val_acc = 0
    best_val_epoch = 0
    for epoch in range(start_epoch, options.epochs):
        train(epoch=epoch,
              data_loader=train_loader,
              net=net,
              feature_center=feature_center,
              loss=loss,
              optimizer=optimizer,
              save_freq=options.save_freq,
              save_dir=options.save_dir,
              verbose=options.verbose)
        val_loss, val_acc = validate(data_loader=validate_loader,
                                     net=net,
                                     loss=loss,
                                     verbose=options.verbose)
        # scheduler steps on validation loss (mode='min')
        scheduler.step(val_loss)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_epoch = epoch
        logging.info(
            f'Best Validation Accuracy: {best_val_acc}, Epoch: {best_val_epoch + 1}'
        )
示例#26
0
def get_input_visualwakewords():
    """Build train/test dataflows for the Visual Wake Words dataset.

    Training images get a GoogleNet-style random crop/resize, a horizontal
    flip, and AutoAugment's ImageNet policy; test images get a deterministic
    resize + center crop. Integer labels are converted to one-hot vectors.

    Returns:
        (train_df, test_df, shapes) where shapes is
        ((target_side, target_side, 3), (2,)) — the image shape (uint8,
        unnormalized) and the one-hot label shape.

    Relies on the module-level ``VISUALWAKEWORDS_CONFIG`` dict for dataset
    paths.
    """
    target_side_size = 128
    train = VisualWakeWords(VISUALWAKEWORDS_CONFIG["train_instances"],
                            VISUALWAKEWORDS_CONFIG["train_images"],
                            shuffle=True)
    test = VisualWakeWords(VISUALWAKEWORDS_CONFIG["val_instances"],
                           VISUALWAKEWORDS_CONFIG["val_images"],
                           filter_list=VISUALWAKEWORDS_CONFIG["filter_list"],
                           shuffle=False)

    train_augmentors = imgaug.AugmentorList([
        imgaug.GoogleNetRandomCropAndResize(crop_area_fraction=(0.25, 1.),
                                            target_shape=target_side_size,
                                            interp=cv2.INTER_CUBIC),
        imgaug.Flip(horiz=True),
    ])

    autoaugment_policy = ImageNetPolicy()

    test_augmentors = imgaug.AugmentorList([
        imgaug.ResizeShortestEdge(int(target_side_size * 1.2),
                                  interp=cv2.INTER_CUBIC),
        imgaug.CenterCrop((target_side_size, target_side_size)),
    ])

    # Renamed the parameter from `train` to `is_train` — it previously
    # shadowed the `train` dataflow variable above, which was confusing.
    def preprocess(is_train):
        """Return a map function applying the phase-appropriate pipeline."""

        def apply(x):
            image, label = x
            onehot = np.zeros(2)
            onehot[label] = 1.0
            augmentors = train_augmentors if is_train else test_augmentors
            image = augmentors.augment(image)
            if is_train:
                # AutoAugment operates on PIL images; round-trip via PIL
                image = np.array(autoaugment_policy(Image.fromarray(image)))
            return image, onehot
            # NOTE(review): per-channel normalization is intentionally
            # disabled here (presumably done downstream — verify):
            # mean = [0.4767, 0.4488, 0.4074]
            # std = [0.2363, 0.2313, 0.2330]
            # return (image / 255.0 - mean) / std, onehot

        return apply

    # cap worker processes; cpu_count()//2 assumes hyperthreading
    parallel = min(18, multiprocessing.cpu_count() // 2)
    train = MapData(train, preprocess(is_train=True))
    train = PrefetchDataZMQ(train, parallel)

    # test mapping must be strict so every sample is evaluated exactly once
    test = MultiThreadMapData(test,
                              parallel,
                              preprocess(is_train=False),
                              strict=True)
    test = PrefetchDataZMQ(test, 1)

    return train, test, ((target_side_size, target_side_size, 3), (2, ))
示例#27
0
    def __init__(self,
                 model,
                 frames_path,
                 labels_path,
                 frame_size,
                 sequence_length,
                 setname='train',
                 random_pad_sample=True,
                 pad_option='default',
                 uniform_frame_sample=True,
                 random_start_position=True,
                 max_interval=7,
                 random_interval=True):
        """Dataset of fixed-length frame sequences from pre-extracted videos.

        Reads ``<labels_path>/<setname>.csv`` (lines of "<label>,<folder>")
        and builds per-sample frame-directory paths, class ids, and an
        action-name -> class-id mapping.

        Args:
            model: backbone name selecting normalization stats
                ('resnet' or 'r2plus1d').
            frames_path: root directory containing one folder of frames
                per video.
            labels_path: directory containing the split csv files.
            frame_size: output spatial size of each frame.
            sequence_length: number of frames per sample.
            setname: 'train' or 'test'.
            random_pad_sample: pad-option flag used by the _add_pads function.
            pad_option: 'default' or 'autoaugment' padding strategy.
            uniform_frame_sample, random_start_position, max_interval,
            random_interval: options used by the _frame_sampler function.

        Raises:
            AssertionError: on invalid `pad_option` or `setname`.
            ValueError: on unknown `model`.
        """
        self.sequence_length = sequence_length

        # pad option => used by the _add_pads function
        self.random_pad_sample = random_pad_sample
        assert pad_option in [
            'default', 'autoaugment'
        ], "'{}' is not valid pad option.".format(pad_option)
        self.pad_option = pad_option

        # frame sampler options => used by the _frame_sampler function
        self.uniform_frame_sample = uniform_frame_sample
        self.random_start_position = random_start_position
        self.max_interval = max_interval
        self.random_interval = random_interval

        # read the csv file already separated by splitter.py; the assert
        # makes the two previous per-setname open() calls collapse into one
        assert setname in ['train', 'test'
                           ], "'{}' is not valid setname.".format(setname)
        self.data_paths = []
        # class id per sample => used by the CategoriesSampler class
        self.classes = []  # ex. [1, 1, 1, ..., 61, 61, 61]
        # action name -> class id, ex. {HulaHoop: 1, ..., Hammering: 61}
        self.labels = {}
        # context manager replaces the manual open()/close() pair
        with open(os.path.join(labels_path, setname + '.csv')) as csv_file:
            for line in csv_file:
                label, folder_name = line.rstrip().split(',')
                # folder_name looks like "<id>_<Action>_..."; field 1 is
                # the action name
                action = folder_name.split('_')[1]
                self.data_paths.append(os.path.join(frames_path, folder_name))
                self.classes.append(int(label))
                self.labels[action] = int(label)

        self.num_classes = len(self.labels)

        # select normalization statistics matching the backbone
        if model == "resnet":
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        elif model == "r2plus1d":
            normalize = transforms.Normalize(mean=[0.43216, 0.394666, 0.37645],
                                             std=[0.22803, 0.22145, 0.216989])
        else:
            # fail fast: previously an unknown model left `normalize`
            # unbound and raised a confusing NameError below
            raise ValueError("'{}' is not valid model.".format(model))

        # transformer in training phase: loose resize + center crop,
        # color jitter and PCA lighting noise
        if setname == 'train':
            self.transform = transforms.Compose([
                transforms.Resize((frame_size + 16, frame_size + 48)),
                transforms.CenterCrop((frame_size, frame_size)),
                transforms.ToTensor(),
                transforms.ColorJitter(brightness=0.4,
                                       contrast=0.4,
                                       saturation=0.4),
                Lighting(alphastd=0.1,
                         eigval=[0.2175, 0.0188, 0.0045],
                         eigvec=[[-0.5675, 0.7192, 0.4009],
                                 [-0.5808, -0.0045, -0.8140],
                                 [-0.5836, -0.6948, 0.4203]]),
                normalize,
            ])
        else:
            # transformer in validation or testing phase: deterministic
            self.transform = transforms.Compose([
                transforms.Resize((frame_size, frame_size)),
                transforms.ToTensor(),
                normalize,
            ])

        # autoaugment transformer for insufficient frames in training phase
        self.transform_autoaugment = transforms.Compose([
            transforms.Resize((frame_size + 16, frame_size + 48)),
            transforms.CenterCrop((frame_size, frame_size)),
            ImageNetPolicy(),
            transforms.ToTensor(),
            normalize,
        ])
    def __init__(self,
                 model,
                 frames_path,
                 labels_path,
                 list_number,
                 frame_size,
                 sequence_length,
                 train=True,
                 random_pad_sample=True,
                 pad_option='default',
                 uniform_frame_sample=True,
                 random_start_position=True,
                 max_interval=7,
                 random_interval=True):
        """UCF101-style dataset of fixed-length frame sequences.

        Reads ``trainlist0<N>.txt`` / ``testlist0<N>.txt`` for sample paths
        and ``classInd.txt`` for the action-name -> class-id mapping
        (converted from 1-based to 0-based).

        Args:
            model: backbone name selecting normalization stats
                ('resnet' or 'r2plus1d').
            frames_path: root directory with one folder of frames per video.
            labels_path: directory containing the UCF101 split/label files.
            list_number: which official split to use (1, 2, or 3).
            frame_size: output spatial size of each frame.
            sequence_length: number of frames per sample.
            train: True for the train split, False for the test split.
            random_pad_sample: pad-option flag used by the _add_pads function.
            pad_option: 'default' or 'autoaugment' padding strategy.
            uniform_frame_sample, random_start_position, max_interval,
            random_interval: options used by the _frame_sampler function.

        Raises:
            AssertionError: on invalid `pad_option` or `list_number`.
            ValueError: on unknown `model`.
        """
        self.sequence_length = sequence_length

        # pad option => used by the _add_pads function
        self.random_pad_sample = random_pad_sample
        # BUGFIX: the message previously contained a bare '{}' placeholder
        # with no .format() call, so failures printed a useless message
        assert pad_option in ['default', 'autoaugment'], \
            "The pad option '{}' is not valid, you can try 'default' or 'autoaugment' pad option".format(pad_option)
        self.pad_option = pad_option

        # frame sampler options => used by the _frame_sampler function
        self.uniform_frame_sample = uniform_frame_sample
        self.random_start_position = random_start_position
        self.max_interval = max_interval
        self.random_interval = random_interval

        # make path list of data from the official UCF101 split file;
        # each line looks like "<Action>/<video_name>.avi ..."
        self.data_paths = []
        assert list_number in [1, 2, 3], "list_number need to be one of 1, 2, 3"
        split = "train" if train else "test"
        listfilename = "{}list0{}.txt".format(split, list_number)
        with open(os.path.join(labels_path, listfilename)) as f:
            for line in f:
                frame_name = line.split("/")[1].split(".avi")[0]
                self.data_paths.append(os.path.join(frames_path, frame_name))

        # labels dictionary from classInd.txt: "<number> <Action>" per line,
        # converted to 0-based class ids
        self.labels = {}
        with open(os.path.join(labels_path, "classInd.txt")) as f:
            for line in f:
                number, action = line.split()
                self.labels[action] = int(number) - 1

        self.num_classes = len(self.labels)

        # select normalization statistics matching the backbone
        if model == "resnet":
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        elif model == "r2plus1d":
            normalize = transforms.Normalize(mean=[0.43216, 0.394666, 0.37645],
                                             std=[0.22803, 0.22145, 0.216989])
        else:
            # fail fast: previously an unknown model left `normalize`
            # unbound and raised a confusing NameError below
            raise ValueError("'{}' is not valid model.".format(model))

        # transformer in training phase: loose resize + center crop,
        # color jitter and PCA lighting noise
        if train:
            self.transform = transforms.Compose([
                transforms.Resize((frame_size + 16, frame_size + 48)),
                transforms.CenterCrop((frame_size, frame_size)),
                transforms.ToTensor(),
                transforms.ColorJitter(brightness=0.4,
                                       contrast=0.4,
                                       saturation=0.4),
                Lighting(alphastd=0.1,
                         eigval=[0.2175, 0.0188, 0.0045],
                         eigvec=[[-0.5675, 0.7192, 0.4009],
                                 [-0.5808, -0.0045, -0.8140],
                                 [-0.5836, -0.6948, 0.4203]]),
                normalize,
            ])
        else:
            # transformer in validation or testing phase: deterministic
            self.transform = transforms.Compose([
                transforms.Resize((frame_size, frame_size)),
                transforms.ToTensor(),
                normalize,
            ])

        # autoaugment transformer for insufficient frames in training phase
        self.transform_autoaugment = transforms.Compose([
            transforms.Resize((frame_size + 16, frame_size + 48)),
            transforms.CenterCrop((frame_size, frame_size)),
            ImageNetPolicy(),
            transforms.ToTensor(),
            normalize,
        ])