Example #1
def get_transforms(auto_augment, input_sizes, m, mean, n, std):
    if auto_augment:
        # AutoAugment + Cutout
        train_transforms = Compose([
            RandomCrop(size=input_sizes, padding=4, fill=128),
            RandomHorizontalFlip(p=0.5),
            CIFAR10Policy(),
            ToTensor(),
            Normalize(mean=mean, std=std),
            Cutout(n_holes=1, length=16)
        ])
    else:
        # RandAugment + Cutout
        train_transforms = Compose([
            RandomCrop(size=input_sizes, padding=4, fill=128),
            RandomHorizontalFlip(p=0.5),
            RandomRandAugment(n=n, m_max=m),  # This version includes cutout
            ToTensor(),
            Normalize(mean=mean, std=std)
        ])
    test_transforms = Compose([
        ToTensor(),
        Normalize(mean=mean, std=std)
    ])

    return test_transforms, train_transforms
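
Every example on this page composes or appends a `Cutout` transform without showing its definition. A minimal sketch of the widely used implementation from DeVries & Taylor (arXiv:1708.04552) is given below; it zeroes square patches of a CHW tensor, which is why the examples apply it after `ToTensor()`/`Normalize`. The exact class in each repository may differ (Example #9, for instance, takes `n_masks` rather than `n_holes`).

import numpy as np
import torch

class Cutout:
    """Zero out `n_holes` square patches of side `length` in a CHW tensor."""

    def __init__(self, n_holes=1, length=16):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        for _ in range(self.n_holes):
            # pick a random center; patches may be clipped at the borders
            y = np.random.randint(h)
            x = np.random.randint(w)
            y1, y2 = np.clip([y - self.length // 2, y + self.length // 2], 0, h)
            x1, x2 = np.clip([x - self.length // 2, x + self.length // 2], 0, w)
            mask[y1:y2, x1:x2] = 0.0
        return img * torch.from_numpy(mask).expand_as(img)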
Example #2
    def __init__(self, cfg):
        super(DAGDataset, self).__init__()

        self.template_size = cfg.DAG.TRAIN.TEMPLATE_SIZE
        self.search_size = cfg.DAG.TRAIN.SEARCH_SIZE

        self.size = 25
        self.stride = cfg.DAG.TRAIN.STRIDE

        self.color = cfg.DAG.DATASET.COLOR
        self.flip = cfg.DAG.DATASET.FLIP
        self.rotation = cfg.DAG.DATASET.ROTATION
        self.blur = cfg.DAG.DATASET.BLUR
        self.shift = cfg.DAG.DATASET.SHIFT
        self.scale = cfg.DAG.DATASET.SCALE
        self.gray = cfg.DAG.DATASET.GRAY
        self.label_smooth = cfg.DAG.DATASET.LABELSMOOTH
        self.mixup = cfg.DAG.DATASET.MIXUP
        self.cutout = cfg.DAG.DATASET.CUTOUT

        self.shift_s = cfg.DAG.DATASET.SHIFTs
        self.scale_s = cfg.DAG.DATASET.SCALEs
        self.grids()
        self.neg_num = cfg.DAG.TRAIN.NEG_NUM
        self.pos_num = cfg.DAG.TRAIN.POS_NUM
        self.total_num = cfg.DAG.TRAIN.TOTAL_NUM
        self.neg = cfg.DAG.DATASET.NEG
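        # NOTE: each `> random.random()` check below runs once, here in
        # __init__, so the augmentation list is fixed for the lifetime of
        # the dataset rather than re-sampled per item.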
        self.transform_extra = transforms.Compose([
            transforms.ToPILImage(),
        ] + ([
            transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
        ] if self.color > random.random() else []) + ([
            transforms.RandomHorizontalFlip(),
        ] if self.flip > random.random() else []) + ([
            transforms.RandomRotation(degrees=10),
        ] if self.rotation > random.random() else []) + ([
            transforms.Grayscale(num_output_channels=3),
        ] if self.gray > random.random() else []) + (
            [Cutout(n_holes=1, length=16)]
            if self.cutout > random.random() else []))

        print('train data: {}'.format(cfg.DAG.TRAIN.WHICH_USE))
        self.train_datas = []
        start = 0
        self.num = 0
        for data_name in cfg.DAG.TRAIN.WHICH_USE:
            dataset = subData(cfg, data_name, start)
            self.train_datas.append(dataset)
            start += dataset.num
            self.num += dataset.num_use

        self._shuffle()
        print(cfg)
Example #3
def main2(args):
    best_prec1 = 0.0

    torch.backends.cudnn.deterministic = not args.cudaNoise

    # manual_seed expects an int; seeding from the wall clock makes runs
    # non-reproducible by design
    torch.manual_seed(int(time.time()))

    if args.init != "None":
        args.name = "lrnet_%s" % args.init

    if args.tensorboard:
        configure(f"runs/{args.name}")

    dstype = nondigits(args.dataset)
    if dstype == "cifar":
        means = [125.3, 123.0, 113.9]
        stds = [63.0, 62.1, 66.7]
    elif dstype == "imgnet":
        means = [123.3, 118.1, 108.0]
        stds = [54.1, 52.6, 53.2]

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in means],
        std=[x / 255.0 for x in stds],
    )

    writer = SummaryWriter(log_dir="runs/%s" % args.name, comment=str(args))
    args.classes = onlydigits(args.dataset)

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                              mode="reflect").squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose(
            [transforms.ToTensor(), normalize])

    if args.cutout:
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length))

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}

    assert dstype in ["cifar", "cinic", "imgnet"]

    if dstype == "cifar":
        train_loader = torch.utils.data.DataLoader(
            datasets.__dict__[args.dataset.upper()]("../data",
                                                    train=True,
                                                    download=True,
                                                    transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            datasets.__dict__[args.dataset.upper()]("../data",
                                                    train=False,
                                                    transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    elif dstype == "cinic":
        cinic_directory = "%s/cinic10" % args.dir
        cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        cinic_std = [0.24205776, 0.23828046, 0.25874835]
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(cinic_directory + '/train',
                                             transform=transforms.Compose([
                                                 transforms.ToTensor(),
                                                 transforms.Normalize(
                                                     mean=cinic_mean,
                                                     std=cinic_std)
                                             ])),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        print("Using CINIC10 dataset")
        val_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(cinic_directory + '/valid',
                                             transform=transforms.Compose([
                                                 transforms.ToTensor(),
                                                 transforms.Normalize(
                                                     mean=cinic_mean,
                                                     std=cinic_std)
                                             ])),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    elif dstype == "imgnet":
        print("Using converted imagenet")
        train_loader = torch.utils.data.DataLoader(
            IMGNET("%s" % args.dir,
                   train=True,
                   transform=transform_train,
                   target_transform=None,
                   classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            IMGNET("%s" % args.dir,
                   train=False,
                   transform=transform_test,
                   target_transform=None,
                   classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    else:
        print("Error matching dataset %s" % dstype)

    ##print("main bn:")
    ##print(args.batchnorm)
    ##print("main fixup:")
    ##print(args.fixup)

    if args.prune:
        pruner_state = getPruneMask(args)
        if pruner_state is None:
            print("Failed to prune network, aborting")
            return None

    if args.arch.lower() == "constnet":
        model = WideResNet(
            args.layers,
            args.classes,
            args.widen_factor,
            droprate=args.droprate,
            use_bn=args.batchnorm,
            use_fixup=args.fixup,
            varnet=args.varnet,
            noise=args.noise,
            lrelu=args.lrelu,
            sigmaW=args.sigmaW,
            init=args.init,
            dropl1=args.dropl1,
        )
    elif args.arch.lower() == "leakynet":
        model = LRNet(
            args.layers,
            args.classes,
            args.widen_factor,
            droprate=args.droprate,
            use_bn=args.batchnorm,
            use_fixup=args.fixup,
            varnet=args.varnet,
            noise=args.noise,
            lrelu=args.lrelu,
            sigmaW=args.sigmaW,
            init=args.init,
        )
    else:
        print("arch %s is not supported" % args.arch)
        return None

    ## draw(args, model)  # disabled: requires a complex installation

    param_num = sum([p.data.nelement() for p in model.parameters()])

    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:

        # assumes args.device is a range string like "0-3": the characters
        # at positions 0 and 2 give the first and last GPU indices
        start = int(args.device[0])
        end = int(args.device[2]) + 1
        torch.cuda.set_device(start)
        dev_list = []
        for i in range(start, end):
            dev_list.append("cuda:%d" % i)
        model = torch.nn.DataParallel(model, device_ids=dev_list)

    model = model.cuda()

    if args.freeze > 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['scale'], name.split('.')):
                cnt = cnt + 1
                if cnt == args.freeze:
                    break

            if cnt >= args.freeze_start:
                ##                if intersection(['conv','conv1'],name.split('.')):
                ##                    print("Freezing Block: %s" % name.split('.')[1:3]  )
                if not intersection(['conv_res', 'fc'], name.split('.')):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    elif args.freeze < 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['scale'], name.split('.')):
                cnt = cnt + 1

            if cnt > args.layers - 3 + args.freeze - 1:
                ##                if intersection(['conv','conv1'],name.split('.')):
                ##                    print("Freezing Block: %s" % name  )

                if not intersection(['conv_res', 'fc'], name.split('.')):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    if args.res_freeze > 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > args.res_freeze_start:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)
                if cnt >= args.res_freeze:
                    break
    elif args.res_freeze < 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > 3 + args.res_freeze:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    if args.prune:
        if args.prune_epoch >= 100:
            weightsFile = "runs/%s-net/checkpoint.pth.tar" % args.prune
        else:
            weightsFile = "runs/%s-net/model_epoch_%d.pth.tar" % (
                args.prune, args.prune_epoch)

        if os.path.isfile(weightsFile):
            print(f"=> loading checkpoint {weightsFile}")
            checkpoint = torch.load(weightsFile)
            model.load_state_dict(checkpoint["state_dict"])
            print(
                f"=> loaded checkpoint '{weightsFile}' (epoch {checkpoint['epoch']})"
            )
        else:
            if args.prune_epoch == 0:
                print(f"=> No source data, Restarting network from scratch")
            else:
                print(f"=> no checkpoint found at {weightsFile}, aborting...")
                return None

    else:
        if args.resume:
            tarfile = "runs/%s-net/checkpoint.pth.tar" % args.resume
            if os.path.isfile(tarfile):
                print(f"=> loading checkpoint {args.resume}")
                checkpoint = torch.load(tarfile)
                args.start_epoch = checkpoint["epoch"]
                best_prec1 = checkpoint["best_prec1"]
                model.load_state_dict(checkpoint["state_dict"])
                print(
                    f"=> loaded checkpoint '{tarfile}' (epoch {checkpoint['epoch']})"
                )
            else:
                print(f"=> no checkpoint found at {tarfile}, aborting...")
                return None

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()

    if args.optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            nesterov=args.nesterov,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay)
    else:
        raise ValueError("unsupported optimizer: %s" % args.optimizer)

    if args.prune and pruner_state is not None:
        cutoff_retrain = prunhild.cutoff.LocalRatioCutoff(args.cutoff)
        params_retrain = get_params_for_pruning(args, model)
        pruner_retrain = prunhild.pruner.CutoffPruner(params_retrain,
                                                      cutoff_retrain)
        pruner_retrain.load_state_dict(pruner_state)
        pruner_retrain.prune(update_state=False)
        pruned_weights_count = count_pruned_weights(params_retrain,
                                                    args.cutoff)
        params_left = param_num - pruned_weights_count
        print("Pruned %d weights, New model size:  %d/%d (%d%%)" %
              (pruned_weights_count, params_left, param_num,
               int(100 * params_left / param_num)))

    else:
        pruner_retrain = None

    if args.eval:
        best_prec1 = validate(args, val_loader, model, criterion, 0, None)
    else:

        if args.varnet:
            save_checkpoint(
                args,
                {
                    "epoch": 0,
                    "state_dict": model.state_dict(),
                    "best_prec1": 0.0,
                },
                True,
            )
            best_prec1 = 0.0

        turns_above_50 = 0

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(args, optimizer, epoch + 1)
            train(args, train_loader, model, criterion, optimizer, epoch,
                  pruner_retrain, writer)

            prec1 = validate(args, val_loader, model, criterion, epoch, writer)
            correlation.measure_correlation(model, epoch, writer=writer)

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            if args.savenet:
                save_checkpoint(
                    args,
                    {
                        "epoch": epoch + 1,
                        "state_dict": model.state_dict(),
                        "best_prec1": best_prec1,
                    },
                    is_best,
                )
            if args.symmetry_break:
                if prec1 > 50.0:
                    turns_above_50 += 1
                    if turns_above_50 > 3:
                        return epoch

    writer.close()

    print("Best accuracy: ", best_prec1)
    return best_prec1
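
Example #3 leans on small helpers (`nondigits`, `onlydigits`, `intersection`) that are defined elsewhere in its repository. Their behavior can be inferred from the call sites; the sketches below are assumptions consistent with that usage, not the original definitions.

def nondigits(s):
    # "cifar10" -> "cifar": strip the class-count suffix from a dataset name
    return ''.join(c for c in s if not c.isdigit())

def onlydigits(s):
    # "cifar10" -> 10: keep only the class-count suffix
    return int(''.join(c for c in s if c.isdigit()))

def intersection(keys, parts):
    # true when any of `keys` appears among the dotted parameter name's parts
    return bool(set(keys) & set(parts))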
Example #4
def load_datasets():
    """Create data loaders for the CIFAR-10 dataset.

    Returns: Dict containing data loaders.
    """
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize])

    if args.cutout > 0:
        train_transform.transforms.append(Cutout(length=args.cutout))

    # note: the validation split reuses the train-style random augmentation
    # (same crop + flip as train_transform)
    valid_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize])

    train_dataset = datasets.CIFAR10(root=args.data_path,
                                     train=True,
                                     transform=train_transform,
                                     download=True)

    valid_dataset = datasets.CIFAR10(root=args.data_path,
                                     train=True,
                                     transform=valid_transform,
                                     download=True)

    test_dataset = datasets.CIFAR10(root=args.data_path,
                                    train=False,
                                    transform=test_transform,
                                    download=True)

    train_indices = list(range(0, 45000))
    valid_indices = list(range(45000, 50000))
    train_subset = Subset(train_dataset, train_indices)
    valid_subset = Subset(valid_dataset, valid_indices)

    data_loaders = {}
    data_loaders['train_subset'] = torch.utils.data.DataLoader(dataset=train_subset,
                                                               batch_size=args.batch_size,
                                                               shuffle=True,
                                                               pin_memory=True,
                                                               num_workers=2)

    data_loaders['valid_subset'] = torch.utils.data.DataLoader(dataset=valid_subset,
                                                               batch_size=args.batch_size,
                                                               shuffle=True,
                                                               pin_memory=True,
                                                               num_workers=2,
                                                               drop_last=True)

    data_loaders['train_dataset'] = torch.utils.data.DataLoader(dataset=train_dataset,
                                                                batch_size=args.batch_size,
                                                                shuffle=True,
                                                                pin_memory=True,
                                                                num_workers=2)

    data_loaders['test_dataset'] = torch.utils.data.DataLoader(dataset=test_dataset,
                                                               batch_size=args.batch_size,
                                                               shuffle=False,
                                                               pin_memory=True,
                                                               num_workers=2)

    return data_loaders
Example #5
def main():

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    #data_folder = os.path.join(args.data_folder, 'train')
    data_folder = '/home/C2L/CXR/'

    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomRotation(10),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        raise NotImplementedError('augmentation not supported: {}'.format(args.aug))

    train_transform.transforms.append(Cutout(n_holes=3, length=32))
    train_dataset = ImageFolderInstance(data_folder,
                                        transform=train_transform,
                                        two_crop=args.c2l)
    print(len(train_dataset))
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # create model and optimizer
    n_data = len(train_dataset)

    if args.model == 'resnet50':
        model = InsResNet50()
        if args.c2l:
            model_ema = InsResNet50()
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        if args.c2l:
            model_ema = InsResNet50(width=2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        if args.c2l:
            model_ema = InsResNet50(width=4)
    elif args.model == 'resnet18':
        model = InsResNet18(width=1)
        if args.c2l:
            model_ema = InsResNet18(width=1)
    elif args.model == 'densenet121':
        model = DenseNet121(isTrained=False)
        if args.c2l:
            model_ema = DenseNet121(isTrained=False)
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # copy weights from `model' to `model_ema'
    if args.c2l:
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if args.c2l:
        contrast = MemoryC2L(128, n_data, args.nce_k, args.nce_t,
                             args.softmax).cuda(args.gpu)
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    criterion = NCESoftmaxLoss() if args.softmax else NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if args.c2l:
        model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        if args.c2l:
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema,
                                                      optimizer_ema,
                                                      opt_level=args.opt_level)

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            if args.c2l:
                model_ema.load_state_dict(checkpoint['model_ema'])

            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    #logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        if args.c2l:
            loss, prob = train_C2L(epoch, train_loader, model, model_ema,
                                   contrast, criterion, optimizer, args)
        else:
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        #logger.log_value('ins_loss', loss, epoch)
        #logger.log_value('ins_prob', prob, epoch)
        #logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        # save model: write a rolling 'current.pth' every epoch and a
        # numbered checkpoint every `save_freq` epochs
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.c2l:
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()
        save_file = os.path.join(args.model_folder, 'current.pth')
        torch.save(state, save_file)
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
Example #6
def init(batch_size,
         state,
         mean,
         std,
         input_sizes,
         base,
         num_workers,
         train_set,
         val_set,
         rand_augment=True,
         n=1,
         m=1,
         dataset='cifar10'):
    # # Original transforms
    # train_transforms = Compose([
    #     Pad(padding=4, padding_mode='reflect'),
    #     RandomHorizontalFlip(p=0.5),
    #     RandomCrop(size=input_sizes),
    #     ToTensor(),
    #     Normalize(mean=mean, std=std)
    # ])

    if rand_augment:
        # RandAugment + Cutout
        train_transforms = Compose([
            RandomCrop(size=input_sizes, padding=4, fill=128),
            RandomHorizontalFlip(p=0.5),
            RandomRandAugment(n=n, m_max=m),  # This version includes cutout
            ToTensor(),
            Normalize(mean=mean, std=std)
        ])
        test_transforms = Compose([ToTensor(), Normalize(mean=mean, std=std)])
    else:
        # AutoAugment + Cutout
        train_transforms = Compose([
            RandomCrop(size=input_sizes, padding=4, fill=128),
            RandomHorizontalFlip(p=0.5),
            CIFAR10Policy(),
            ToTensor(),
            Normalize(mean=mean, std=std),
            Cutout(n_holes=1, length=16)
        ])
        test_transforms = Compose([ToTensor(), Normalize(mean=mean, std=std)])

    # Data sets
    if dataset == 'cifar10':
        if state == 1:
            train_set = CIFAR10(root=base,
                                set_name=train_set,
                                transform=train_transforms,
                                label=True)
        test_set = CIFAR10(root=base,
                           set_name=val_set,
                           transform=test_transforms,
                           label=True)
    else:
        raise NotImplementedError

    # Data loaders
    if state == 1:
        train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                   batch_size=batch_size,
                                                   num_workers=num_workers,
                                                   shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=batch_size,
                                              num_workers=num_workers * 2,
                                              shuffle=False)
    if state == 1:
        return train_loader, test_loader
    else:
        return test_loader
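
A hypothetical call to this `init` for training (state=1 returns both loaders). The CIFAR-10 statistics, directory, and set names below are assumptions for illustration, not values from the source:

mean = (0.4914, 0.4822, 0.4465)   # standard CIFAR-10 channel means
std = (0.2470, 0.2435, 0.2616)    # standard CIFAR-10 channel stds
train_loader, test_loader = init(batch_size=128,
                                 state=1,
                                 mean=mean,
                                 std=std,
                                 input_sizes=32,
                                 base='./data',
                                 num_workers=4,
                                 train_set='train',
                                 val_set='val',
                                 rand_augment=True,
                                 n=2,
                                 m=10,
                                 dataset='cifar10')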
Example #7
def main():
    global args, best_prec1
    args = parser.parse_args()

    if args.tensorboard:
        configure(f"runs/{args.name}")

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
    )

    if args.augment:
        transform_train = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Lambda(
                    lambda x: F.pad(
                        x.unsqueeze(0), (4, 4, 4, 4), mode="reflect"
                    ).squeeze()
                ),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
    else:
        transform_train = transforms.Compose([transforms.ToTensor(), normalize])

    if args.cutout:
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length)
        )

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}
    assert args.dataset == "cifar10" or args.dataset == "cifar100"

    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()](
            "../data", train=True, download=True, transform=transform_train
        ),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()](
            "../data", train=False, transform=transform_test
        ),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )

    model = WideResNet(
        args.layers,
        args.dataset == "cifar10" and 10 or 100,
        args.widen_factor,
        droprate=args.droprate,
        use_bn=args.batchnorm,
        use_fixup=args.fixup,
    )

    param_num = sum([p.data.nelement() for p in model.parameters()])
    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model = model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(f"=> loading checkpoint {args.resume}")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
        else:
            print(f"=> no checkpoint found at {args.resume}")

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        nesterov=args.nesterov,
        weight_decay=args.weight_decay,
    )

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)
        train(train_loader, model, criterion, optimizer, epoch)

        prec1 = validate(val_loader, model, criterion, epoch)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
            },
            is_best,
        )

    print("Best accuracy: ", best_prec1)
Example #8
data_transforms = {
    'train':
    transforms.Compose([
        # (the opening of this snippet is truncated in the source; a
        #  train-time crop such as RandomResizedCrop(224) likely preceded)
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean_vec, std_vec)
    ]),
    'val':
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean_vec, std_vec)
    ]),
}

if args.cutout:
    data_transforms['train'].transforms.append(
        Cutout(n_holes=1, length=args.cutout_size))

train_dir = os.path.join(DATA_DIR, 'train')
val_dir = os.path.join(DATA_DIR, 'val')

train_dataset = datasets.ImageFolder(train_dir, data_transforms['train'])
val_dataset = datasets.ImageFolder(val_dir, data_transforms['val'])

print('train_dataset.size', len(train_dataset.samples))
print('val_dataset.size', len(val_dataset.samples))
image_datasets = {'train': train_dataset, 'val': val_dataset}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print('class_names:', class_names)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True)
# (the snippet ends mid-call in the source; batch_size/shuffle above are a
#  plausible completion, not the original arguments)
Example #9
def main(args):

    global best_prec1

    # CIFAR-10 Training & Test Transformation
    print(
        '. . . . . . . . . . . . . . . .PREPROCESSING DATA . . . . . . . . . . . . . . . .'
    )
    TRAIN_transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.cutout:
        TRAIN_transform.transforms.append(
            Cutout(n_masks=args.n_masks, length=args.length))

    VAL_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    # CIFAR-10 dataset
    train_dataset = torchvision.datasets.CIFAR10(root='../data/',
                                                 train=True,
                                                 transform=TRAIN_transform,
                                                 download=True)
    val_dataset = torchvision.datasets.CIFAR10(root='../data/',
                                               train=False,
                                               transform=VAL_transform,
                                               download=True)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               pin_memory=True,
                                               drop_last=True,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             pin_memory=True,
                                             batch_size=args.batch_size,
                                             shuffle=False)

    # Device Config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.normalize == 'groupnorm':
        model = SEresnet_gn()

    elif args.normalize == 'groupnorm+ws':
        model = SEresnet_gn_ws()
    else:
        model = SEresnet()

    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          momentum=args.momentum)
    lr_schedule = lr_scheduler.MultiStepLR(optimizer,
                                           milestones=[250, 375],
                                           gamma=0.1)

    if args.evaluate:
        model.load_state_dict(torch.load('./save_model/model.th'))
        model.to(device)
        validation(args, val_loader, model, criterion)
        return  # evaluation only: skip the training loop below

    #  Epoch = args.Epoch
    for epoch_ in range(0, args.Epoch):
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train_one_epoch(args, train_loader, model, criterion, optimizer,
                        epoch_)
        lr_schedule.step()

        prec1 = validation(args, val_loader, model, criterion)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch_ > 0 and epoch_ % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch_ + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.pt'))

        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.pt'))

        print('THE BEST MODEL prec@1 : {best_prec1:.3f} saved. '.format(
            best_prec1=best_prec1))