Code Example #1
File: train_moex.py Project: zhang405744522/MoEx
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()

    if args.dataset.startswith('cifar'):
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

        if args.dataset == 'cifar100':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data', train=True, download=True, transform=transform_train),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data', train=False, transform=transform_test),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=True, download=True, transform=transform_train),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=False, transform=transform_test),
                batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))

    elif args.dataset == 'imagenet':
        traindir = os.path.join('/scratch/imagenet/train')
        valdir = os.path.join('/scratch/imagenet/val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        jittering = utils.ColorJitter(brightness=0.4, contrast=0.4,
                                      saturation=0.4)
        lighting = utils.Lighting(alphastd=0.1,
                                  eigval=[0.2175, 0.0188, 0.0045],
                                  eigvec=[[-0.5675, 0.7192, 0.4009],
                                          [-0.5808, -0.0045, -0.8140],
                                          [-0.5836, -0.6948, 0.4203]])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                jittering,
                lighting,
                normalize,
            ]))

        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        numberofclass = 1000

    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'pyramidnet_moex':
        model = PYRM_MOEX.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass,
                                args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(args.net_type))

    model = torch.nn.DataParallel(model).cuda()

    print(model)
    print('the number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay, nesterov=True)

    cudnn.benchmark = True

    for epoch in range(0, args.epochs):

        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        err1, err5, val_loss = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print('Current best accuracy (top-1 and 5 error):', best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'arch': args.net_type,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer': optimizer.state_dict(),
        }, is_best)

    with open('train_moex.txt', 'a+') as f:
        f.write('lam = %s: Best accuracy (top-1 and 5 error): %s, %s\n' %
                (args.lam, best_err1, best_err5))
    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
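Both helpers invoked above, adjust_learning_rate and save_checkpoint, are defined elsewhere in the project. As a minimal sketch of what such helpers conventionally look like in these CutMix-derived training scripts (the step schedule and the runs/<expname> output directory are assumptions, not the verbatim project code):

import os
import shutil
import torch

def adjust_learning_rate(optimizer, epoch):
    # Assumed schedule: decay the base LR by 10x at 50% and 75% of training.
    lr = args.lr * (0.1 ** (epoch // (args.epochs * 0.5))) \
                 * (0.1 ** (epoch // (args.epochs * 0.75)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Write the latest checkpoint; copy it aside when it is the best so far.
    directory = 'runs/%s/' % args.expname
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(directory, 'model_best.pth.tar'))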
Code Example #2
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()

    if args.dataset.startswith('cifar'):
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        autoaug = args.autoaug
        if autoaug:
            print('augmentation: %s' % autoaug)
            if autoaug == 'fa_reduced_cifar10':
                transform_train.transforms.insert(
                    0, Augmentation(fa_reduced_cifar10()))
            elif autoaug == 'fa_reduced_imagenet':
                transform_train.transforms.insert(
                    0, Augmentation(fa_reduced_imagenet()))
            elif autoaug == 'autoaug_cifar10':
                transform_train.transforms.insert(
                    0, Augmentation(autoaug_paper_cifar10()))
            elif autoaug == 'autoaug_extend':
                transform_train.transforms.insert(
                    0, Augmentation(autoaug_policy()))
            elif autoaug in ['default', 'inception', 'inception320']:
                pass
            else:
                raise ValueError('unknown augmentation: %s' % autoaug)

        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        if args.dataset == 'cifar100':
            ds_train = datasets.CIFAR100(args.cifarpath,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
            if args.cv >= 0:
                sss = StratifiedShuffleSplit(n_splits=5,
                                             test_size=0.2,
                                             random_state=0)
                sss = sss.split(list(range(len(ds_train))), ds_train.targets)
                for _ in range(args.cv + 1):
                    train_idx, valid_idx = next(sss)
                ds_valid = Subset(ds_train, valid_idx)
                ds_train = Subset(ds_train, train_idx)
            else:
                ds_valid = Subset(ds_train, [])
            ds_test = datasets.CIFAR100(args.cifarpath,
                                        train=False,
                                        transform=transform_test)

            train_loader = torch.utils.data.DataLoader(
                CutMix(ds_train,
                       100,
                       beta=args.cutmix_beta,
                       prob=args.cutmix_prob,
                       num_mix=args.cutmix_num),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            tval_loader = torch.utils.data.DataLoader(
                ds_valid,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                ds_test,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            ds_train = datasets.CIFAR10(args.cifarpath,
                                        train=True,
                                        download=True,
                                        transform=transform_train)
            if args.cv >= 0:
                sss = StratifiedShuffleSplit(n_splits=5,
                                             test_size=0.2,
                                             random_state=0)
                sss = sss.split(list(range(len(ds_train))), ds_train.targets)
                for _ in range(args.cv + 1):
                    train_idx, valid_idx = next(sss)
                ds_valid = Subset(ds_train, valid_idx)
                ds_train = Subset(ds_train, train_idx)
            else:
                ds_valid = Subset(ds_train, [])

            train_loader = torch.utils.data.DataLoader(
                CutMix(ds_train,
                       10,
                       beta=args.cutmix_beta,
                       prob=args.cutmix_prob,
                       num_mix=args.cutmix_num),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            tval_loader = torch.utils.data.DataLoader(
                ds_valid,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10(args.cifarpath,
                                 train=False,
                                 transform=transform_test),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))

    elif args.dataset == 'imagenet':
        traindir = os.path.join(args.imagenetpath, 'train')
        valdir = os.path.join(args.imagenetpath, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        jittering = utils.ColorJitter(brightness=0.4,
                                      contrast=0.4,
                                      saturation=0.4)
        lighting = utils.Lighting(alphastd=0.1,
                                  eigval=[0.2175, 0.0188, 0.0045],
                                  eigvec=[[-0.5675, 0.7192, 0.4009],
                                          [-0.5808, -0.0045, -0.8140],
                                          [-0.5836, -0.6948, 0.4203]])

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            jittering,
            lighting,
            normalize,
        ])

        autoaug = args.autoaug
        if autoaug:
            print('augmentation: %s' % autoaug)
            if autoaug == 'fa_reduced_cifar10':
                transform_train.transforms.insert(
                    0, Augmentation(fa_reduced_cifar10()))
            elif autoaug == 'fa_reduced_imagenet':
                transform_train.transforms.insert(
                    0, Augmentation(fa_reduced_imagenet()))

            elif autoaug == 'autoaug_cifar10':
                transform_train.transforms.insert(
                    0, Augmentation(autoaug_paper_cifar10()))
            elif autoaug == 'autoaug_extend':
                transform_train.transforms.insert(
                    0, Augmentation(autoaug_policy()))
            elif autoaug in ['default', 'inception', 'inception320']:
                pass
            else:
                raise ValueError('unknown augmentation: %s' % autoaug)

        train_dataset = datasets.ImageFolder(traindir, transform_train)
        if args.cv >= 0:
            sss = StratifiedShuffleSplit(n_splits=5,
                                         test_size=0.2,
                                         random_state=0)
            sss = sss.split(list(range(len(train_dataset))),
                            train_dataset.targets)
            for _ in range(args.cv + 1):
                train_idx, valid_idx = next(sss)
            valid_dataset = Subset(train_dataset, valid_idx)
            train_dataset = Subset(train_dataset, train_idx)
        else:
            valid_dataset = Subset(train_dataset, [])

        train_dataset = CutMix(train_dataset,
                               1000,
                               beta=args.cutmix_beta,
                               prob=args.cutmix_prob,
                               num_mix=args.cutmix_num)
        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)
        tval_loader = torch.utils.data.DataLoader(valid_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True)
        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        numberofclass = 1000
    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass, True)
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                numberofclass, True)
    elif 'wresnet' in args.net_type:
        model = WRN(args.depth,
                    args.alpha,
                    dropout_rate=0.0,
                    num_classes=numberofclass)
    else:
        raise ValueError('unknown network architecture: {}'.format(
            args.net_type))

    model = torch.nn.DataParallel(model).cuda()
    print('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    criterion = CutMixCrossEntropyLoss(True)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=1e-4,
                                nesterov=True)
    cudnn.benchmark = True

    for epoch in range(0, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        model.train()
        err1, err5, train_loss = run_epoch(train_loader, model, criterion,
                                           optimizer, epoch, 'train')
        train_err1 = err1
        err1, err5, train_loss = run_epoch(tval_loader, model, criterion, None,
                                           epoch, 'train-val')

        # evaluate on validation set
        model.eval()
        err1, err5, val_loss = run_epoch(val_loader, model, criterion, None,
                                         epoch, 'valid')

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5
            print('Current Best (top-1 and 5 error):', best_err1, best_err5)

        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.net_type,
                'state_dict': model.state_dict(),
                'best_err1': best_err1,
                'best_err5': best_err5,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            filename='checkpoint_e%d_top1_%.3f_%.3f.pth' %
            (epoch, train_err1, err1))

    print('Best (top-1 and 5 error):', best_err1, best_err5)
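Unlike Example #1, this variant mixes inside the data pipeline: the CutMix wrapper returns soft labels, which CutMixCrossEntropyLoss consumes, while tval_loader and val_loader keep ordinary integer targets. The core of such a wrapper is sampling a patch whose area matches a Beta-drawn mixing ratio lam; a sketch of that box sampling (the real CutMix and CutMixCrossEntropyLoss classes come from the project's cutmix package):

import numpy as np

def rand_bbox(size, lam):
    # size is (N, C, H, W); the patch covers a (1 - lam) fraction of the area.
    W, H = size[2], size[3]
    cut_rat = np.sqrt(1.0 - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)
    # Place the patch center uniformly at random, then clip the box.
    cx, cy = np.random.randint(W), np.random.randint(H)
    bbx1, bby1 = np.clip(cx - cut_w // 2, 0, W), np.clip(cy - cut_h // 2, 0, H)
    bbx2, bby2 = np.clip(cx + cut_w // 2, 0, W), np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2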
Code Example #3
File: train.py Project: snu-mllab/PuzzleMix
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()

    if args.seed >= 0:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    # Save path
    args.expname += args.method
    if args.transport:
        args.expname += '_tp'
    args.expname += '_prob_' + str(args.mixup_prob)
    if args.clean_lam > 0:
        args.expname += '_clean_' + str(args.clean_lam)
    if args.seed >= 0:
        args.expname += '_seed' + str(args.seed)
    print("Model is saved at {}".format(args.expname))

    # Dataset and loader
    if args.dataset.startswith('cifar'):
        mean = [x / 255.0 for x in [125.3, 123.0, 113.9]]
        std = [x / 255.0 for x in [63.0, 62.1, 66.7]]
        normalize = transforms.Normalize(mean=mean, std=std)

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=args.padding),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        if args.dataset == 'cifar100':
            train_loader = torch.utils.data.DataLoader(datasets.CIFAR100('~/Datasets/cifar100/',
                                                                         train=True,
                                                                         download=True,
                                                                         transform=transform_train),
                                                       batch_size=args.batch_size,
                                                       shuffle=True,
                                                       num_workers=args.workers,
                                                       pin_memory=True)
            val_loader = torch.utils.data.DataLoader(datasets.CIFAR100('~/Datasets/cifar100/',
                                                                       train=False,
                                                                       transform=transform_test),
                                                     batch_size=args.batch_size // 4,
                                                     shuffle=True,
                                                     num_workers=args.workers,
                                                     pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            train_loader = torch.utils.data.DataLoader(datasets.CIFAR10('../data',
                                                                        train=True,
                                                                        download=True,
                                                                        transform=transform_train),
                                                       batch_size=args.batch_size,
                                                       shuffle=True,
                                                       num_workers=args.workers,
                                                       pin_memory=True)
            val_loader = torch.utils.data.DataLoader(datasets.CIFAR10('../data',
                                                                      train=False,
                                                                      transform=transform_test),
                                                     batch_size=args.batch_size,
                                                     shuffle=True,
                                                     num_workers=args.workers,
                                                     pin_memory=True)
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))

    elif args.dataset == 'imagenet':
        traindir = os.path.join('/data/readonly/ImageNet-Fast/imagenet/train')
        valdir = os.path.join('/data/readonly/ImageNet-Fast/imagenet/val')
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        normalize = transforms.Normalize(mean=mean, std=std)
        jittering = utils.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)
        lighting = utils.Lighting(alphastd=0.1,
                                  eigval=[0.2175, 0.0188, 0.0045],
                                  eigvec=[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140],
                                          [-0.5836, -0.6948, 0.4203]])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                jittering,
                lighting,
                normalize,
            ]))
        train_sampler = None

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=(train_sampler is None),
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=train_sampler)
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(valdir, val_transform),
                                                 batch_size=args.batch_size // 4,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        numberofclass = 1000
        args.neigh_size = min(args.neigh_size, 2)

    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    # Model
    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass, args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass,
                                args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(args.net_type))

    pretrained = "runs/{}/{}".format(args.expname, 'checkpoint.pth.tar')
    if os.path.isfile(pretrained):
        print("=> loading checkpoint '{}'".format(pretrained))
        checkpoint = torch.load(pretrained)
        checkpoint['state_dict'] = dict(
            (key[7:], value) for (key, value) in checkpoint['state_dict'].items())
        model.load_state_dict(checkpoint['state_dict'])
        cur_epoch = checkpoint['epoch'] + 1
        best_err1 = checkpoint['best_err1']
        print("=> loaded checkpoint '{}'(epoch: {}, best err1: {}%)".format(
            pretrained, cur_epoch, checkpoint['best_err1']))
    else:
        cur_epoch = 0
        print("=> no checkpoint found at '{}'".format(pretrained))

    model = torch.nn.DataParallel(model).cuda()
    print('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    criterion_batch = nn.CrossEntropyLoss(reduction='none').cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    if os.path.isfile(pretrained):
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("optimizer is loaded!")

    mean_torch = torch.tensor(mean, dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
    std_torch = torch.tensor(std, dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
    if args.mp > 0:
        mp = Pool(args.mp)
    else:
        mp = None

    # Start training and validation
    for epoch in range(cur_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss = train(train_loader, model, criterion, criterion_batch, optimizer, epoch,
                           mean_torch, std_torch, mp)
        # evaluate on validation set
        err1, err5, val_loss = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print('Current best accuracy (top-1 and 5 error):', best_err1, best_err5)
        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.net_type,
                'state_dict': model.state_dict(),
                'best_err1': best_err1,
                'best_err5': best_err5,
                'optimizer': optimizer.state_dict(),
            }, is_best)

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
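A detail worth noting in the resume logic above: the checkpoint was saved from a torch.nn.DataParallel model, so every state-dict key carries a 'module.' prefix, and key[7:] strips it (len('module.') == 7) because the weights are loaded into the bare model before it is wrapped again. A guarded version of the same idea (the helper name is ours):

def strip_data_parallel_prefix(state_dict):
    # nn.DataParallel prefixes every parameter name with 'module.';
    # drop it so the weights load into the unwrapped model.
    prefix = 'module.'
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}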
Code Example #4
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()

    if args.dataset == 'imagenet':
        traindir = os.path.join('~/dataset/tiny-imagenet/train')
        valdir = os.path.join('~/dataset/tiny-imagenet/val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        jittering = utils.ColorJitter(brightness=0.4,
                                      contrast=0.4,
                                      saturation=0.4)
        lighting = utils.Lighting(alphastd=0.1,
                                  eigval=[0.2175, 0.0188, 0.0045],
                                  eigvec=[[-0.5675, 0.7192, 0.4009],
                                          [-0.5808, -0.0045, -0.8140],
                                          [-0.5836, -0.6948, 0.4203]])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                jittering,
                lighting,
                normalize,
            ]))

        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        numberofclass = 200
    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass,
                          args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                numberofclass, args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(
            args.net_type))

    model = torch.nn.DataParallel(model).cuda()

    print(model)
    print('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    cudnn.benchmark = True

    for epoch in range(0, args.epochs):

        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        err1, err5, val_loss = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print('Current best accuracy (top-1 and 5 error):', best_err1,
              best_err5)
        f = open("best_accuracy.txt", "a+")
        f.write('best acc - top1: %.4f, top5: %.4f at iteration: %d\r\n' %
                (best_err1, best_err5, epoch))
        f.close()
        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.net_type,
                'state_dict': model.state_dict(),
                'best_err1': best_err1,
                'best_err5': best_err5,
                'optimizer': optimizer.state_dict(),
            }, is_best)

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)

    f = open("best_accuracy.txt", "a+")
    f.write('Final best accuracy - top1: %.4f, top5: %.4f\r\n' %
            (best_err1, best_err5))
    f.close()
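Every ImageNet branch in these examples applies utils.ColorJitter and utils.Lighting after ToTensor, always with the same PCA eigenvalues and eigenvectors. utils.Lighting is AlexNet-style PCA color noise; a common implementation, ported from fb.resnet.torch, is sketched here under the assumption that these repositories vendor something equivalent:

import torch

class Lighting(object):
    """AlexNet-style PCA lighting noise on a CxHxW tensor image."""

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = torch.tensor(eigval)
        self.eigvec = torch.tensor(eigvec)

    def __call__(self, img):
        if self.alphastd == 0:
            return img
        # One Gaussian coefficient per principal component...
        alpha = img.new_empty(3).normal_(0, self.alphastd)
        # ...combined into a single RGB offset added to every pixel.
        rgb = (self.eigvec.type_as(img) * alpha.view(1, 3)
               * self.eigval.type_as(img).view(1, 3)).sum(1)
        return img + rgb.view(3, 1, 1)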
Code Example #5
def main():
    global args, best_err1, best_err5, numberofclass
    args = parser.parse_args()

    assert args.method in ['ce', 'ols', 'sce', 'ls', 'gce', 'jo', 'bootsoft', 'boothard', 'forward', 'backward', 'disturb'], \
        "method must be one of 'ce', 'ols', 'sce', 'ls', 'gce', 'jo', 'bootsoft', 'boothard', 'forward', 'backward', 'disturb'"

    args.gpu = 0
    args.world_size = 1

    print(args)
    log_dir = '%s/runs/record_dir/%s/' % (args.save_dir, args.expname)
    writer = SummaryWriter(log_dir=log_dir)

    if args.seed is not None:
        print('set the same seed for all.....')
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    if args.dataset.startswith('cifar'):
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        if args.dataset == 'cifar100':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('./data',
                                  train=True,
                                  download=True,
                                  transform=transform_train),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=False)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('./data',
                                  train=False,
                                  transform=transform_test),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=False)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('./data',
                                 train=True,
                                 download=True,
                                 transform=transform_train),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=False)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('./data',
                                 train=False,
                                 transform=transform_test),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=False)
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))

    elif args.dataset == 'imagenet':
        traindir = os.path.join('./data/ILSVRC1/train')
        valdir = os.path.join('./data/ILSVRC1/val1')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        jittering = utils.ColorJitter(brightness=0.4,
                                      contrast=0.4,
                                      saturation=0.4)
        lighting = utils.Lighting(alphastd=0.1,
                                  eigval=[0.2175, 0.0188, 0.0045],
                                  eigvec=[[-0.5675, 0.7192, 0.4009],
                                          [-0.5808, -0.0045, -0.8140],
                                          [-0.5836, -0.6948, 0.4203]])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                jittering,
                lighting,
                normalize,
            ]))

        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=False,
            sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=False)
        numberofclass = 1000

    print("=> creating model '{}'".format(args.net_type))
    # define loss function (criterion) and optimizer
    solver = Solver()

    solver.model = solver.model.cuda()
    print('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in solver.model.parameters()])))
    cudnn.benchmark = True

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_err1 = checkpoint['best_err1']
            solver.model.load_state_dict(checkpoint['state_dict'])
            solver.optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

    for epoch in range(args.start_epoch, args.epochs):
        print('current os time = ',
              time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        adjust_learning_rate(solver.optimizer, epoch)
        # train for one epoch
        train_loss = solver.train(train_loader, epoch)
        # evaluate on validation set
        err1, err5, val_loss = solver.validate(val_loader, epoch)

        writer.add_scalar('training loss', train_loss, epoch)
        writer.add_scalar('testing loss', val_loss, epoch)
        writer.add_scalar('top1 error', err1, epoch)
        writer.add_scalar('top5 error', err5, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print('Current best accuracy (top-1 and 5 error):', best_err1,
              best_err5)
        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.net_type,
                'state_dict': solver.model.state_dict(),
                'best_err1': best_err1,
                'best_err5': best_err5,
                'optimizer': solver.optimizer.state_dict(),
            }, is_best)

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
    print('method = {}, expname = {}'.format(args.method, args.expname))
    loss_dir = "%s/runs/record_dir/%s/" % (args.save_dir, args.expname)
    writer.export_scalars_to_json(loss_dir + 'loss.json')
    writer.close()
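This example delegates the model, optimizer, and the method-specific loss (chosen via args.method) to a Solver object that is never shown. Reconstructed from the call sites above, its assumed interface is roughly the following; build_model is a hypothetical placeholder for whatever instantiates args.net_type:

class Solver:
    # Interface sketch reconstructed from usage; not the project's code.
    def __init__(self):
        self.model = build_model(args)  # hypothetical model factory
        self.optimizer = torch.optim.SGD(self.model.parameters(), args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay,
                                         nesterov=True)

    def train(self, loader, epoch):
        """Run one training epoch; return the mean training loss."""

    def validate(self, loader, epoch):
        """Return (top-1 error, top-5 error, mean validation loss)."""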
Code Example #6
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()

    if args.dataset.startswith('cifar'):
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        if args.dataset == 'cifar100':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data',
                                  train=True,
                                  download=True,
                                  transform=transform_train),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data',
                                  train=False,
                                  transform=transform_test),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data',
                                 train=True,
                                 download=True,
                                 transform=transform_train),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data',
                                 train=False,
                                 transform=transform_test),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))

    elif args.dataset == 'stl10':
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            transforms.RandomCrop(96),
            transforms.Resize(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose(
            [transforms.Resize(224),
             transforms.ToTensor(), normalize])
        train_loader = torch.utils.data.DataLoader(
            myDataSet('../data',
                      split='train',
                      download=True,
                      transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True)
        val_loader = torch.utils.data.DataLoader(myDataSet(
            '../data', split='test', transform=transform_test),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        numberofclass = 10

    elif args.dataset == 'caltech101':
        image_transforms = {
            # Train transform; the heavier augmentations are commented out below
            'train':
            transforms.Compose([
                #transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
                #transforms.RandomRotation(degrees=15),
                #transforms.ColorJitter(),
                transforms.CenterCrop(size=224),  # Image net standards
                transforms.ToTensor(),
                transforms.Normalize(
                    [0.485, 0.456, 0.406],
                    [0.229, 0.224, 0.225])  # Imagenet standards
            ]),
            # Validation does not use augmentation
            'valid':
            transforms.Compose([
                #transforms.Resize(size=256),
                transforms.CenterCrop(size=224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
            ]),
        }

        # Dataloader iterators, make sure to shuffle
        data = datasets.ImageFolder(root='101_ObjectCategories',
                                    transform=image_transforms['train'])
        train_set, val_set = torch.utils.data.random_split(data, [7000, 2144])
        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=True)
        numberofclass = 102

    elif args.dataset == 'imagenet':
        traindir = os.path.join('/home/data/ILSVRC/train')
        valdir = os.path.join('/home/data/ILSVRC/val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        jittering = utils.ColorJitter(brightness=0.4,
                                      contrast=0.4,
                                      saturation=0.4)
        lighting = utils.Lighting(alphastd=0.1,
                                  eigval=[0.2175, 0.0188, 0.0045],
                                  eigvec=[[-0.5675, 0.7192, 0.4009],
                                          [-0.5808, -0.0045, -0.8140],
                                          [-0.5836, -0.6948, 0.4203]])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                jittering,
                lighting,
                normalize,
            ]))

        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
        numberofclass = 1000

    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass,
                          args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                numberofclass, args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(
            args.net_type))

    model = torch.nn.DataParallel(model).cuda()

    print(model)
    print('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    cudnn.benchmark = True

    for epoch in range(0, args.epochs):

        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        err1, err5, val_loss = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print('Current best accuracy (top-1 and 5 error):', best_err1,
              best_err5)
        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.net_type,
                'state_dict': model.state_dict(),
                'best_err1': best_err1,
                'best_err5': best_err5,
                'optimizer': optimizer.state_dict(),
            }, is_best)

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
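All six examples report top-1 and top-5 error rates (err1, err5), i.e. 100 minus the corresponding accuracy. A standard way to compute them from logits, in the spirit of the accuracy helper these repositories share (the function name here is ours):

def error_rates(output, target, topk=(1, 5)):
    # output: (N, num_classes) logits; target: (N,) class indices.
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()                                   # shape (maxk, N)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    errs = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        errs.append(100.0 * (1.0 - correct_k.item() / batch_size))
    return errs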