Example #1
0
def main():
    """Evaluate a pretrained network on the validation split of args.dataset.

    Builds the validation loader for CIFAR-10/100 or ImageNet, recreates the
    network described by the command-line arguments, restores the checkpoint
    given in ``args.pretrained`` and reports top-1 / top-5 error via
    ``validate()``.  Raises if the dataset, architecture or checkpoint path
    is unknown.
    """
    global args, best_err1, best_err5
    args = parser.parse_args()

    if args.dataset.startswith('cifar'):
        # CIFAR per-channel statistics are published in 0-255 space; rescale
        # them to the [0, 1] range produced by ToTensor().
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        # Evaluation only: no augmentation transform is needed here.
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        # Both CIFAR variants share the loader configuration; only the
        # dataset class and the number of classes differ.
        if args.dataset == 'cifar100':
            dataset_cls, numberofclass = datasets.CIFAR100, 100
        elif args.dataset == 'cifar10':
            dataset_cls, numberofclass = datasets.CIFAR10, 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))

        val_loader = torch.utils.data.DataLoader(
            dataset_cls('../data', train=False, transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,  # order is irrelevant for aggregate error metrics
            num_workers=args.workers,
            pin_memory=True)

    elif args.dataset == 'imagenet':
        # Fixed read-only location of the ImageNet validation images.
        valdir = '/data_large/readonly/ImageNet-Fast/imagenet/val'
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, val_transform),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)
        numberofclass = 1000

    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass,
                          args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                numberofclass, args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(
            args.net_type))

    model = torch.nn.DataParallel(model).cuda()

    # The checkpoint is loaded into the DataParallel wrapper, so its
    # state-dict keys are expected to carry the 'module.' prefix.
    if os.path.isfile(args.pretrained):
        print("=> loading checkpoint '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}'(best err1: {}%)".format(
            args.pretrained, checkpoint['best_err1']))
    else:
        raise Exception("=> no checkpoint found at '{}'".format(
            args.pretrained))

    print('the number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    # evaluate on validation set
    err1, err5, val_loss = validate(val_loader, model, criterion)
    print('Accuracy (top-1 and 5 error):', err1, err5)
Example #2
0
def main():
    """Train a network on CIFAR-10/100 or ImageNet and track the best errors.

    Configuration comes from the module-level ``parser``.  When a checkpoint
    exists under ``runs/<expname>/`` the model, optimizer and best errors are
    restored and training resumes from the saved epoch; otherwise training
    starts from scratch with the error trackers at their worst value (100%).
    """
    global args, best_err1, best_err5
    args = parser.parse_args()

    # Reproducibility: only seed the RNGs when the user asked for it.
    if args.seed >= 0:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    # Save path: fold the run configuration into the experiment name.
    args.expname += args.method
    if args.transport:
        args.expname += '_tp'
    args.expname += '_prob_' + str(args.mixup_prob)
    if args.clean_lam > 0:
        args.expname += '_clean_' + str(args.clean_lam)
    if args.seed >= 0:
        args.expname += '_seed' + str(args.seed)
    print("Model is saved at {}".format(args.expname))

    # Dataset and loader
    if args.dataset.startswith('cifar'):
        # CIFAR per-channel statistics are published in 0-255 space.
        mean = [x / 255.0 for x in [125.3, 123.0, 113.9]]
        std = [x / 255.0 for x in [63.0, 62.1, 66.7]]
        normalize = transforms.Normalize(mean=mean, std=std)

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=args.padding),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        if args.dataset == 'cifar100':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('~/Datasets/cifar100/',
                                  train=True,
                                  download=True,
                                  transform=transform_train),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            # NOTE: validation runs at a quarter of the training batch size.
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('~/Datasets/cifar100/',
                                  train=False,
                                  transform=transform_test),
                batch_size=args.batch_size // 4,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data',
                                 train=True,
                                 download=True,
                                 transform=transform_train),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data',
                                 train=False,
                                 transform=transform_test),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))

    elif args.dataset == 'imagenet':
        # Fixed read-only locations of the ImageNet splits.
        traindir = '/data/readonly/ImageNet-Fast/imagenet/train'
        valdir = '/data/readonly/ImageNet-Fast/imagenet/val'
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        normalize = transforms.Normalize(mean=mean, std=std)
        # Standard ImageNet augmentation: color jitter plus PCA lighting noise.
        jittering = utils.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)
        lighting = utils.Lighting(alphastd=0.1,
                                  eigval=[0.2175, 0.0188, 0.0045],
                                  eigvec=[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140],
                                          [-0.5836, -0.6948, 0.4203]])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                jittering,
                lighting,
                normalize,
            ]))
        train_sampler = None  # placeholder for a distributed sampler

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=(train_sampler is None),
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   sampler=train_sampler)
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, val_transform),
            batch_size=args.batch_size // 4,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)
        numberofclass = 1000
        # Large images: cap the neighbourhood size for the mixup transport.
        args.neigh_size = min(args.neigh_size, 2)

    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    # Model
    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass, args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass,
                                args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(args.net_type))

    pretrained = "runs/{}/{}".format(args.expname, 'checkpoint.pth.tar')
    if os.path.isfile(pretrained):
        print("=> loading checkpoint '{}'".format(pretrained))
        checkpoint = torch.load(pretrained)
        # Checkpoints were saved from a DataParallel model: strip the leading
        # 'module.' (7 chars) from every state-dict key before loading into
        # the bare model.
        checkpoint['state_dict'] = dict(
            (key[7:], value) for (key, value) in checkpoint['state_dict'].items())
        model.load_state_dict(checkpoint['state_dict'])
        cur_epoch = checkpoint['epoch'] + 1
        best_err1 = checkpoint['best_err1']
        # FIX: also restore best_err5 so the per-epoch summary never reads an
        # unbound global when resuming; fall back for older checkpoints that
        # did not save it.
        best_err5 = checkpoint.get('best_err5', 100.0)
        print("=> loaded checkpoint '{}'(epoch: {}, best err1: {}%)".format(
            pretrained, cur_epoch, checkpoint['best_err1']))
    else:
        cur_epoch = 0
        # FIX: initialise the best-error trackers (worst possible error is
        # 100%).  Previously they were left unbound on a fresh run, so the
        # `err1 <= best_err1` comparison below raised a NameError.
        best_err1 = 100.0
        best_err5 = 100.0
        print("=> no checkpoint found at '{}'".format(pretrained))

    model = torch.nn.DataParallel(model).cuda()
    print('the number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    criterion_batch = nn.CrossEntropyLoss(reduction='none').cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    if os.path.isfile(pretrained):
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("optimizer is loaded!")

    # Normalisation constants as broadcastable (1, 3, 1, 1) GPU tensors.
    mean_torch = torch.tensor(mean, dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
    std_torch = torch.tensor(std, dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
    if args.mp > 0:
        mp = Pool(args.mp)
    else:
        mp = None

    # Start training and validation
    for epoch in range(cur_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss = train(train_loader, model, criterion, criterion_batch, optimizer, epoch,
                           mean_torch, std_torch, mp)
        # evaluate on validation set
        err1, err5, val_loss = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print('Current best accuracy (top-1 and 5 error):', best_err1, best_err5)
        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.net_type,
                'state_dict': model.state_dict(),
                'best_err1': best_err1,
                'best_err5': best_err5,
                'optimizer': optimizer.state_dict(),
            }, is_best)

    # FIX: release the worker pool instead of leaking its processes.
    if mp is not None:
        mp.close()
        mp.join()

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
    def __init__(self):
        """Build the network, optimizer and training criterion selected by
        the global ``args``, using the global ``numberofclass``.

        Raises:
            Exception: if ``args.net_type`` or ``args.method`` is unknown.
        """
        super(Solver, self).__init__()
        global numberofclass

        # define the network
        if args.net_type == 'resnet':
            self.model = RN.ResNet(dataset=args.dataset,
                                   depth=args.depth,
                                   num_classes=numberofclass,
                                   bottleneck=args.bottleneck)

        elif args.net_type == 'pyramidnet':
            self.model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                         numberofclass, args.bottleneck)

        elif args.net_type == 'wideresnet':
            self.model = WR.WideResNet(depth=args.depth,
                                       num_classes=numberofclass,
                                       widen_factor=args.width)

        elif args.net_type == 'vggnet':
            self.model = VGG.vgg16(num_classes=numberofclass)

        elif args.net_type == 'mobilenet':
            self.model = MN.mobile_half(num_classes=numberofclass)

        elif args.net_type == 'shufflenet':
            self.model = SN.ShuffleV2(num_classes=numberofclass)

        elif args.net_type == 'densenet':
            self.model = DN.densenet_cifar(num_classes=numberofclass)

        elif args.net_type == 'resnext-2':
            self.model = ResNeXt29_2x64d(num_classes=numberofclass)
        elif args.net_type == 'resnext-4':
            self.model = ResNeXt29_4x64d(num_classes=numberofclass)
        elif args.net_type == 'resnext-32':
            self.model = ResNeXt29_32x4d(num_classes=numberofclass)

        elif args.net_type == 'imagenetresnet18':
            self.model = multi_resnet18_kd(num_classes=numberofclass)
        elif args.net_type == 'imagenetresnet34':
            self.model = multi_resnet34_kd(num_classes=numberofclass)
        elif args.net_type == 'imagenetresnet50':
            self.model = multi_resnet50_kd(num_classes=numberofclass)
        elif args.net_type == 'imagenetresnet101':
            self.model = multi_resnet101_kd(num_classes=numberofclass)
        elif args.net_type == 'imagenetresnet152':
            self.model = multi_resnet152_kd(num_classes=numberofclass)
        else:
            raise Exception('unknown network architecture: {}'.format(
                args.net_type))

        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay,
                                         nesterov=True)
        # Per-class-pair mixing weights; a fixed buffer, never trained.
        self.loss_lams = torch.zeros(numberofclass,
                                     numberofclass,
                                     dtype=torch.float32).cuda()
        self.loss_lams.requires_grad = False
        # define the loss function
        if args.method == 'ce':
            self.criterion = nn.CrossEntropyLoss()
        elif args.method == 'sce':
            # SCE hyper-parameters follow the original paper's per-dataset
            # settings (CIFAR-10 vs. the rest).
            if args.dataset == 'cifar10':
                self.criterion = SCELoss(alpha=0.1,
                                         beta=1.0,
                                         num_classes=numberofclass)
            else:
                self.criterion = SCELoss(alpha=6.0,
                                         beta=0.1,
                                         num_classes=numberofclass)
        elif args.method == 'ls':
            self.criterion = label_smooth(num_classes=numberofclass)
        elif args.method == 'gce':
            self.criterion = generalized_cross_entropy(
                num_classes=numberofclass)
        elif args.method == 'jo':
            self.criterion = joint_optimization(num_classes=numberofclass)
        elif args.method == 'bootsoft':
            self.criterion = boot_soft(num_classes=numberofclass)
        elif args.method == 'boothard':
            self.criterion = boot_hard(num_classes=numberofclass)
        elif args.method == 'forward':
            self.criterion = Forward(num_classes=numberofclass)
        elif args.method == 'backward':
            self.criterion = Backward(num_classes=numberofclass)
        elif args.method == 'disturb':
            self.criterion = DisturbLabel(num_classes=numberofclass)
        elif args.method == 'ols':
            # online label smoothing starts from plain cross-entropy
            self.criterion = nn.CrossEntropyLoss()
        else:
            # FIX: fail loudly with a clear message (matching the net_type
            # branch above) instead of dying below with an AttributeError on
            # the missing self.criterion.
            raise Exception('unknown method: {}'.format(args.method))
        self.criterion = self.criterion.cuda()