def main():
    """Evaluate a pretrained model on the chosen dataset's validation split.

    Parses command-line options from the module-level ``parser``; requires
    ``args.pretrained`` to point at an existing checkpoint (expected keys:
    'state_dict', 'best_err1') and raises if it is missing. Prints the
    top-1 / top-5 validation error computed by ``validate``.
    """
    global args, best_err1, best_err5
    args = parser.parse_args()

    # --- validation data -------------------------------------------------
    if args.dataset.startswith('cifar'):
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])
        # NOTE: a training transform was previously built here too, but this
        # entry point only evaluates, so the dead local was removed.
        if args.dataset == 'cifar100':
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('../data', train=False,
                                  transform=transform_test),
                batch_size=args.batch_size,
                shuffle=False,  # fixed: shuffling a validation pass only
                                # makes evaluation order nondeterministic
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=False,
                                 transform=transform_test),
                batch_size=args.batch_size,
                shuffle=False,  # fixed: see above
                num_workers=args.workers,
                pin_memory=True)
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))
    elif args.dataset == 'imagenet':
        # os.path.join() with a single argument was a no-op; use the path.
        valdir = '/data_large/readonly/ImageNet-Fast/imagenet/val'
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, val_transform),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)
        numberofclass = 1000
    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    # --- model -----------------------------------------------------------
    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass,
                          args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                numberofclass, args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(
            args.net_type))
    model = torch.nn.DataParallel(model).cuda()

    # --- restore pretrained weights (mandatory for evaluation) -----------
    if os.path.isfile(args.pretrained):
        print("=> loading checkpoint '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}'(best err1: {}%)".format(
            args.pretrained, checkpoint['best_err1']))
    else:
        raise Exception("=> no checkpoint found at '{}'".format(
            args.pretrained))

    print('the number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    # evaluate on validation set
    err1, err5, val_loss = validate(val_loader, model, criterion)
    print('Accuracy (top-1 and 5 error):', err1, err5)
def main():
    """Train a model with optional checkpoint resume.

    Parses options from the module-level ``parser``, builds train/val
    loaders for CIFAR-10/100 or ImageNet, constructs the network, resumes
    from ``runs/<expname>/checkpoint.pth.tar`` if present, then runs the
    epoch loop (train -> validate -> save_checkpoint) and prints the best
    top-1/top-5 errors seen.
    """
    global args, best_err1, best_err5
    args = parser.parse_args()
    # Seed RNGs only when a non-negative seed was requested.
    if args.seed >= 0:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    # Save path: experiment name encodes method / transport / mixup prob /
    # clean-lambda / seed so runs with different settings do not collide.
    args.expname += args.method
    if args.transport:
        args.expname += '_tp'
    args.expname += '_prob_' + str(args.mixup_prob)
    if args.clean_lam > 0:
        args.expname += '_clean_' + str(args.clean_lam)
    if args.seed >= 0:
        args.expname += '_seed' + str(args.seed)
    print("Model is saved at {}".format(args.expname))

    # Dataset and loader
    if args.dataset.startswith('cifar'):
        # CIFAR channel statistics expressed in [0, 1] range.
        mean = [x / 255.0 for x in [125.3, 123.0, 113.9]]
        std = [x / 255.0 for x in [63.0, 62.1, 66.7]]
        normalize = transforms.Normalize(mean=mean, std=std)
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=args.padding),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([transforms.ToTensor(),
                                             normalize])
        if args.dataset == 'cifar100':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('~/Datasets/cifar100/', train=True,
                                  download=True, transform=transform_train),
                batch_size=args.batch_size, shuffle=True,
                num_workers=args.workers, pin_memory=True)
            # NOTE(review): val loader uses batch_size // 4 here but the
            # full batch size in the cifar10 branch, and shuffles the
            # validation set — presumably intentional; confirm.
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR100('~/Datasets/cifar100/', train=False,
                                  transform=transform_test),
                batch_size=args.batch_size // 4, shuffle=True,
                num_workers=args.workers, pin_memory=True)
            numberofclass = 100
        elif args.dataset == 'cifar10':
            train_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=True, download=True,
                                 transform=transform_train),
                batch_size=args.batch_size, shuffle=True,
                num_workers=args.workers, pin_memory=True)
            val_loader = torch.utils.data.DataLoader(
                datasets.CIFAR10('../data', train=False,
                                 transform=transform_test),
                batch_size=args.batch_size, shuffle=True,
                num_workers=args.workers, pin_memory=True)
            numberofclass = 10
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))
    elif args.dataset == 'imagenet':
        traindir = os.path.join('/data/readonly/ImageNet-Fast/imagenet/train')
        valdir = os.path.join('/data/readonly/ImageNet-Fast/imagenet/val')
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        normalize = transforms.Normalize(mean=mean, std=std)
        # Color jitter + PCA lighting augmentation (project-local utils).
        jittering = utils.ColorJitter(brightness=0.4, contrast=0.4,
                                      saturation=0.4)
        lighting = utils.Lighting(alphastd=0.1,
                                  eigval=[0.2175, 0.0188, 0.0045],
                                  eigvec=[[-0.5675, 0.7192, 0.4009],
                                          [-0.5808, -0.0045, -0.8140],
                                          [-0.5836, -0.6948, 0.4203]])
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                jittering,
                lighting,
                normalize,
            ]))
        # No distributed sampler: shuffle is therefore enabled.
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size,
            shuffle=(train_sampler is None), num_workers=args.workers,
            pin_memory=True, sampler=train_sampler)
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, val_transform),
            batch_size=args.batch_size // 4, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        numberofclass = 1000
        # Cap neighbourhood size on ImageNet (semantics defined elsewhere).
        args.neigh_size = min(args.neigh_size, 2)
    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    # Model
    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass,
                          args.bottleneck)  # for ResNet
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                numberofclass, args.bottleneck)
    else:
        raise Exception('unknown network architecture: {}'.format(
            args.net_type))

    # Resume from runs/<expname>/checkpoint.pth.tar when it exists.
    pretrained = "runs/{}/{}".format(args.expname, 'checkpoint.pth.tar')
    if os.path.isfile(pretrained):
        print("=> loading checkpoint '{}'".format(pretrained))
        checkpoint = torch.load(pretrained)
        # Strip the 'module.' prefix DataParallel adds, since the weights
        # are loaded before the model is wrapped below.
        checkpoint['state_dict'] = dict(
            (key[7:], value)
            for (key, value) in checkpoint['state_dict'].items())
        model.load_state_dict(checkpoint['state_dict'])
        cur_epoch = checkpoint['epoch'] + 1
        best_err1 = checkpoint['best_err1']
        print("=> loaded checkpoint '{}'(epoch: {}, best err1: {}%)".format(
            pretrained, cur_epoch, checkpoint['best_err1']))
    else:
        cur_epoch = 0
        # NOTE(review): in this branch best_err1 is never assigned here —
        # the epoch loop's `err1 <= best_err1` relies on a module-level
        # default for the global; verify one exists.
        print("=> no checkpoint found at '{}'".format(pretrained))

    model = torch.nn.DataParallel(model).cuda()
    print('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    # Per-sample (unreduced) loss, consumed by the training routine.
    criterion_batch = nn.CrossEntropyLoss(reduction='none').cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    # Restore optimizer state as well when resuming.
    if os.path.isfile(pretrained):
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("optimizer is loaded!")

    # Channel statistics as (1, 3, 1, 1) CUDA tensors for broadcasting
    # over image batches inside train().
    mean_torch = torch.tensor(mean, dtype=torch.float32).reshape(1, 3, 1,
                                                                 1).cuda()
    std_torch = torch.tensor(std, dtype=torch.float32).reshape(1, 3, 1,
                                                               1).cuda()
    # Optional multiprocessing pool handed to train() (None disables it).
    if args.mp > 0:
        mp = Pool(args.mp)
    else:
        mp = None

    # Start training and validation
    for epoch in range(cur_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss = train(train_loader, model, criterion, criterion_batch,
                           optimizer, epoch, mean_torch, std_torch, mp)

        # evaluate on validation set
        err1, err5, val_loss = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5
        print('Current best accuracy (top-1 and 5 error):', best_err1,
              best_err5)
        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.net_type,
                'state_dict': model.state_dict(),
                'best_err1': best_err1,
                'best_err5': best_err5,
                'optimizer': optimizer.state_dict(),
            }, is_best)

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
def __init__(self): super(Solver, self).__init__() global numberofclass #define the network if args.net_type == 'resnet': self.model = RN.ResNet(dataset=args.dataset, depth=args.depth, num_classes=numberofclass, bottleneck=args.bottleneck) elif args.net_type == 'pyramidnet': self.model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass, args.bottleneck) elif args.net_type == 'wideresnet': self.model = WR.WideResNet(depth=args.depth, num_classes=numberofclass, widen_factor=args.width) elif args.net_type == 'vggnet': self.model = VGG.vgg16(num_classes=numberofclass) elif args.net_type == 'mobilenet': self.model = MN.mobile_half(num_classes=numberofclass) elif args.net_type == 'shufflenet': self.model = SN.ShuffleV2(num_classes=numberofclass) elif args.net_type == 'densenet': self.model = DN.densenet_cifar(num_classes=numberofclass) elif args.net_type == 'resnext-2': self.model = ResNeXt29_2x64d(num_classes=numberofclass) elif args.net_type == 'resnext-4': self.model = ResNeXt29_4x64d(num_classes=numberofclass) elif args.net_type == 'resnext-32': self.model = ResNeXt29_32x4d(num_classes=numberofclass) elif args.net_type == 'imagenetresnet18': self.model = multi_resnet18_kd(num_classes=numberofclass) elif args.net_type == 'imagenetresnet34': self.model = multi_resnet34_kd(num_classes=numberofclass) elif args.net_type == 'imagenetresnet50': self.model = multi_resnet50_kd(num_classes=numberofclass) elif args.net_type == 'imagenetresnet101': self.model = multi_resnet101_kd(num_classes=numberofclass) elif args.net_type == 'imagenetresnet152': self.model = multi_resnet152_kd(num_classes=numberofclass) else: raise Exception('unknown network architecture: {}'.format( args.net_type)) self.optimizer = torch.optim.SGD(self.model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) self.loss_lams = torch.zeros(numberofclass, numberofclass, dtype=torch.float32).cuda() self.loss_lams.requires_grad = False #define the loss 
function if args.method == 'ce': self.criterion = nn.CrossEntropyLoss() elif args.method == 'sce': if args.dataset == 'cifar10': self.criterion = SCELoss(alpha=0.1, beta=1.0, num_classes=numberofclass) else: self.criterion = SCELoss(alpha=6.0, beta=0.1, num_classes=numberofclass) elif args.method == 'ls': self.criterion = label_smooth(num_classes=numberofclass) elif args.method == 'gce': self.criterion = generalized_cross_entropy( num_classes=numberofclass) elif args.method == 'jo': self.criterion = joint_optimization(num_classes=numberofclass) elif args.method == 'bootsoft': self.criterion = boot_soft(num_classes=numberofclass) elif args.method == 'boothard': self.criterion = boot_hard(num_classes=numberofclass) elif args.method == 'forward': self.criterion = Forward(num_classes=numberofclass) elif args.method == 'backward': self.criterion = Backward(num_classes=numberofclass) elif args.method == 'disturb': self.criterion = DisturbLabel(num_classes=numberofclass) elif args.method == 'ols': self.criterion = nn.CrossEntropyLoss() self.criterion = self.criterion.cuda()