shuffle=True, num_workers=args.prefetch, pin_memory=True) test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.test_bs, shuffle=False, num_workers=args.prefetch, pin_memory=True) # Init checkpoints if not os.path.isdir(args.save): os.makedirs(args.save) # Init model, criterion, and optimizer net = wrn.WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate) print(net) if args.ngpu > 1: net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu))) if args.ngpu > 0: net.cuda() torch.manual_seed(1) if args.ngpu > 0: torch.cuda.manual_seed(1) optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'],
def main():
    """Train a WideResNet on CIFAR-10/100 and track the best validation accuracy.

    Reads hyper-parameters from the module-level ``parser`` and updates the
    module-level ``best_prec1``. Optionally resumes from a checkpoint
    (``--resume``) and logs to tensorboard (``--tensorboard``).
    """
    global args, best_prec1
    args = parser.parse_args()
    if args.tensorboard:
        configure("runs/%s" % (args.name))

    # Data loading code.  Channel statistics are the standard CIFAR
    # mean/std, rescaled from [0, 255] to [0, 1].
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert (args.dataset == 'cifar10' or args.dataset == 'cifar100')
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data', train=True,
                                                download=True,
                                                transform=transform_train),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    # NOTE(review): shuffling the validation loader is unusual (metrics are
    # order-independent here, so it is harmless, but shuffle=False is typical).
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data', train=False,
                                                transform=transform_test),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    # create model (10 classes for cifar10, 100 for cifar100)
    model = wrn.WideResNet(args.layers,
                           10 if args.dataset == 'cifar10' else 100,
                           args.widen_factor, dropRate=args.droprate)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
    # Fixed: was a Python-2 print statement ("print 'Best accuracy: ', ..."),
    # which is a SyntaxError under Python 3 and inconsistent with the rest
    # of this file, which uses print() calls.
    print('Best accuracy: ', best_prec1)
def __init__(self, root='', train=True, meta=True, num_meta=1000,
             corruption_prob=0, corruption_type='unif', transform=None,
             target_transform=None, download=False, seed=1):
    """Load CIFAR-10/100, split off a clean meta set, and corrupt the rest.

    Args:
        root: directory holding the pickled CIFAR batches.
        train: load the training batches when True, the test batch otherwise.
        meta: when True keep only the (clean) meta subset; otherwise keep the
            remaining training subset and apply label corruption to it.
        num_meta: total number of meta examples (split evenly across classes).
        corruption_prob: probability mass moved off the true label.
        corruption_type: one of 'unif', 'flip', 'flip2', 'hierarchical'
            (CIFAR-100 only) or 'clabels' (relabel with a pretrained network).
        transform / target_transform: torchvision-style transforms.
        download: download the dataset archive if missing.
        seed: numpy RNG seed used for the label-corruption sampling.
    """
    self.root = root
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set
    self.meta = meta
    self.corruption_prob = corruption_prob
    self.num_meta = num_meta

    if download:
        self.download()

    if not self._check_integrity():
        raise RuntimeError('Dataset not found or corrupted.' +
                           ' You can use download=True to download it')

    # now load the picked numpy arrays
    if self.train:
        self.train_data = []
        self.train_labels = []
        self.train_coarse_labels = []
        for fentry in self.train_list:
            f = fentry[0]
            file = os.path.join(root, self.base_folder, f)
            fo = open(file, 'rb')
            if sys.version_info[0] == 2:
                entry = pickle.load(fo)
            else:
                entry = pickle.load(fo, encoding='latin1')
            self.train_data.append(entry['data'])
            # CIFAR-10 batches carry 'labels'; CIFAR-100 carries
            # 'fine_labels' plus 20 'coarse_labels' superclasses.
            if 'labels' in entry:
                self.train_labels += entry['labels']
                img_num_list = [int(self.num_meta / 10)] * 10
                num_classes = 10
            else:
                self.train_labels += entry['fine_labels']
                self.train_coarse_labels += entry['coarse_labels']
                img_num_list = [int(self.num_meta / 100)] * 100
                num_classes = 100
            fo.close()

        self.train_data = np.concatenate(self.train_data)
        self.train_data = self.train_data.reshape((50000, 3, 32, 32))
        self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC

        # Indices of every example, grouped by class label.
        data_list_val = {}
        for j in range(num_classes):
            data_list_val[j] = [i for i, label in enumerate(self.train_labels)
                                if label == j]

        # Per class: first img_num shuffled indices -> meta set, rest -> train.
        idx_to_meta = []
        idx_to_train = []
        print(img_num_list)
        for cls_idx, img_id_list in data_list_val.items():
            np.random.shuffle(img_id_list)
            img_num = img_num_list[int(cls_idx)]
            idx_to_meta.extend(img_id_list[:img_num])
            idx_to_train.extend(img_id_list[img_num:])

        if meta is True:
            self.train_data = self.train_data[idx_to_meta]
            self.train_labels = list(np.array(self.train_labels)[idx_to_meta])
        else:
            self.train_data = self.train_data[idx_to_train]
            self.train_labels = list(np.array(self.train_labels)[idx_to_train])
            if corruption_type == 'hierarchical':
                # Fixed: was idx_to_meta, which left the coarse labels both
                # misaligned with train_labels and shorter than it (num_meta
                # entries vs ~49000), so the hierarchical loop below would
                # raise IndexError.  The coarse labels must follow the same
                # subset as the data/labels they annotate.
                self.train_coarse_labels = list(
                    np.array(self.train_coarse_labels)[idx_to_train])

            if corruption_type == 'unif':
                C = uniform_mix_C(self.corruption_prob, num_classes)
                print(C)
                self.C = C
            elif corruption_type == 'flip':
                C = flip_labels_C(self.corruption_prob, num_classes)
                print(C)
                self.C = C
            elif corruption_type == 'flip2':
                C = flip_labels_C_two(self.corruption_prob, num_classes)
                print(C)
                self.C = C
            elif corruption_type == 'hierarchical':
                assert num_classes == 100, 'You must use CIFAR-100 with the hierarchical corruption.'
                # Recover which fine labels belong to each of the 20 coarse
                # superclasses, then spread corruption_prob uniformly over the
                # other fine labels within the same superclass.
                coarse_fine = []
                for i in range(20):
                    coarse_fine.append(set())
                for i in range(len(self.train_labels)):
                    coarse_fine[self.train_coarse_labels[i]].add(self.train_labels[i])
                for i in range(20):
                    coarse_fine[i] = list(coarse_fine[i])

                C = np.eye(num_classes) * (1 - corruption_prob)
                for i in range(20):
                    tmp = np.copy(coarse_fine[i])
                    for j in range(len(tmp)):
                        tmp2 = np.delete(np.copy(tmp), j)
                        C[tmp[j], tmp2] += corruption_prob * 1 / len(tmp2)
                self.C = C
                print(C)
            elif corruption_type == 'clabels':
                # Relabel using a pretrained WRN-40-2; sampling happens below.
                net = wrn.WideResNet(40, num_classes, 2, dropRate=0.3).cuda()
                model_name = './cifar{}_labeler'.format(num_classes)
                net.load_state_dict(torch.load(model_name))
                net.eval()
            else:
                # Fixed: the literal braces were previously unescaped, so
                # .format() raised KeyError instead of producing this message;
                # the list also omitted the supported 'flip2'/'clabels' types.
                assert False, "Invalid corruption type '{}' given. Must be in {{'unif', 'flip', 'flip2', 'hierarchical', 'clabels'}}".format(corruption_type)

            np.random.seed(seed)
            if corruption_type == 'clabels':
                mean = [x / 255 for x in [125.3, 123.0, 113.9]]
                std = [x / 255 for x in [63.0, 62.1, 66.7]]

                test_transform = transforms.Compose(
                    [transforms.ToTensor(), transforms.Normalize(mean, std)])

                # obtain sampling probabilities in minibatches of 64
                sampling_probs = []
                print('Starting labeling')
                for i in range((len(self.train_labels) // 64) + 1):
                    current = self.train_data[i * 64:(i + 1) * 64]
                    current = [Image.fromarray(current[i]) for i in range(len(current))]
                    current = torch.cat([test_transform(current[i]).unsqueeze(0)
                                         for i in range(len(current))], dim=0)

                    data = V(current).cuda()
                    logits = net(data)
                    # Fixed comment: the logits are divided by 5, i.e. a
                    # softmax temperature of 5, not 1.
                    smax = F.softmax(logits / 5)  # temperature of 5
                    sampling_probs.append(smax.data.cpu().numpy())

                sampling_probs = np.concatenate(sampling_probs, 0)
                print('Finished labeling 1')

                # Sample each new label from the network's softened posterior
                # and report how often it agrees with the original label.
                new_labeling_correct = 0
                argmax_labeling_correct = 0
                for i in range(len(self.train_labels)):
                    old_label = self.train_labels[i]
                    new_label = np.random.choice(num_classes, p=sampling_probs[i])
                    self.train_labels[i] = new_label
                    if old_label == new_label:
                        new_labeling_correct += 1
                    if old_label == np.argmax(sampling_probs[i]):
                        argmax_labeling_correct += 1
                print('Finished labeling 2')
                print('New labeling accuracy:', new_labeling_correct / len(self.train_labels))
                print('Argmax labeling accuracy:', argmax_labeling_correct / len(self.train_labels))
            else:
                # Corrupt each label by sampling from row C[label] of the
                # corruption matrix built above.
                for i in range(len(self.train_labels)):
                    self.train_labels[i] = np.random.choice(num_classes,
                                                            p=C[self.train_labels[i]])
                self.corruption_matrix = C
    else:
        # Test split: a single batch file, never corrupted.
        f = self.test_list[0][0]
        file = os.path.join(root, self.base_folder, f)
        fo = open(file, 'rb')
        if sys.version_info[0] == 2:
            entry = pickle.load(fo)
        else:
            entry = pickle.load(fo, encoding='latin1')
        self.test_data = entry['data']
        if 'labels' in entry:
            self.test_labels = entry['labels']
        else:
            self.test_labels = entry['fine_labels']
        fo.close()
        self.test_data = self.test_data.reshape((10000, 3, 32, 32))
        self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
def main():
    """Entry point: semi-supervised training on CIFAR/SVHN.

    Supports three training modes selected by ``--model``: 'baseline'
    (supervised only), 'pi' (Pi model consistency), and 'mt' (Mean Teacher
    with an EMA-style teacher copy).  Trains for ``--epochs`` epochs,
    validating/testing each epoch, and checkpoints the best model by
    validation precision.  Accumulates per-epoch metrics into the
    module-level history lists declared ``global`` below.
    """
    global args, best_prec1, best_test_prec1
    global acc1_tr, losses_tr
    global losses_cl_tr
    global acc1_val, losses_val, losses_et_val
    global acc1_test, losses_test, losses_et_test
    global weights_cl
    args = parser.parse_args()
    print(args)

    # NOTE(review): both branches assign identical values, so the svhn
    # special-case is currently vestigial.
    if args.dataset == 'svhn':
        drop_rate = 0.3
        widen_factor = 3
    else:
        drop_rate = 0.3
        widen_factor = 3

    # create model
    if args.arch == 'preresnet':
        print("Model: %s" % args.arch)
        model = preresnet_cifar.resnet(depth=32, num_classes=args.num_classes)
    elif args.arch == 'wideresnet':
        print("Model: %s" % args.arch)
        model = wideresnet.WideResNet(28, args.num_classes,
                                      widen_factor=widen_factor,
                                      dropRate=drop_rate, leakyRate=0.1)
    else:
        assert (False)

    # Mean Teacher keeps a second, deep-copied network as the teacher.
    if args.model == 'mt':
        import copy
        model_teacher = copy.deepcopy(model)
        model_teacher = torch.nn.DataParallel(model_teacher).cuda()

    model = torch.nn.DataParallel(model).cuda()
    print(model)

    # optionally resume from a checkpoint (teacher gets the same weights)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            if args.model == 'mt':
                model_teacher.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Fail fast on unsupported optimizers before any data loading happens.
    if args.optim == 'sgd' or args.optim == 'adam':
        pass
    else:
        print('Not Implemented Optimizer')
        assert (False)

    # Checkpoint directory encodes dataset/arch/model/optimizer/epochs.
    ckpt_dir = args.ckpt + '_' + args.dataset + '_' + args.arch + '_' + args.model + '_' + args.optim
    ckpt_dir = ckpt_dir + '_e%d' % (args.epochs)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    print(ckpt_dir)
    cudnn.benchmark = True

    # Data loading code
    # NOTE(review): there is no final `else`, so an unknown --dataset value
    # reaches the loader construction below with transform_train undefined
    # (NameError) — confirm argparse restricts the choices.
    if args.dataset == 'cifar10':
        dataloader = cifar.CIFAR10
        num_classes = 10
        data_dir = '/tmp/'

        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                         std=[0.2023, 0.1994, 0.2010])
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    elif args.dataset == 'cifar10_zca':
        dataloader = cifar_zca.CIFAR10
        num_classes = 10
        data_dir = 'cifar10_zca/cifar10_gcn_zca_v2.npz'

        # transform is implemented inside zca dataloader
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
        ])
    elif args.dataset == 'svhn':
        dataloader = svhn.SVHN
        num_classes = 10
        data_dir = '/tmp/'

        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                         std=[0.5, 0.5, 0.5])
        # No horizontal flip for SVHN (digits are not flip-invariant).
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=2),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    # Labeled and unlabeled splits share the training transform; each gets
    # half the batch unless we are running the purely supervised baseline.
    labelset = dataloader(root=data_dir, split='label', download=True,
                          transform=transform_train, boundary=args.boundary)
    unlabelset = dataloader(root=data_dir, split='unlabel', download=True,
                            transform=transform_train, boundary=args.boundary)
    batch_size_label = args.batch_size // 2
    batch_size_unlabel = args.batch_size // 2
    if args.model == 'baseline':
        batch_size_label = args.batch_size

    label_loader = data.DataLoader(labelset,
                                   batch_size=batch_size_label,
                                   shuffle=True,
                                   num_workers=args.workers,
                                   pin_memory=True)
    label_iter = iter(label_loader)

    unlabel_loader = data.DataLoader(unlabelset,
                                     batch_size=batch_size_unlabel,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True)
    unlabel_iter = iter(unlabel_loader)

    print("Batch size (label): ", batch_size_label)
    print("Batch size (unlabel): ", batch_size_unlabel)

    validset = dataloader(root=data_dir, split='valid', download=True,
                          transform=transform_test, boundary=args.boundary)
    val_loader = data.DataLoader(validset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=True)

    testset = dataloader(root=data_dir, split='test', download=True,
                         transform=transform_test)
    test_loader = data.DataLoader(testset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.workers,
                                  pin_memory=True)

    # define loss function (criterion) and optimizer
    # NOTE(review): size_average= is deprecated in modern PyTorch in favor of
    # reduction='sum'; kept as-is for compatibility with the pinned version.
    criterion = nn.CrossEntropyLoss(size_average=False).cuda()
    criterion_mse = nn.MSELoss(size_average=False).cuda()
    criterion_kl = nn.KLDivLoss(size_average=False).cuda()
    criterion_l1 = nn.L1Loss(size_average=False).cuda()

    criterions = (criterion, criterion_mse, criterion_kl, criterion_l1)

    if args.optim == 'adam':
        print('Using Adam optimizer')
        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                     betas=(0.9, 0.999),
                                     weight_decay=args.weight_decay)
    elif args.optim == 'sgd':
        print('Using SGD optimizer')
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        if args.optim == 'adam':
            print('Learning rate schedule for Adam')
            lr = adjust_learning_rate_adam(optimizer, epoch)
        elif args.optim == 'sgd':
            print('Learning rate schedule for SGD')
            lr = adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        if args.model == 'baseline':
            print('Supervised Training')
            for i in range(10):  # baseline repeat 10 times since small number of training set
                prec1_tr, loss_tr = train_sup(label_loader, model, criterions,
                                              optimizer, epoch, args)
            weight_cl = 0.0
        elif args.model == 'pi':
            print('Pi model')
            prec1_tr, loss_tr, loss_cl_tr, weight_cl = train_pi(
                label_loader, unlabel_loader, model, criterions,
                optimizer, epoch, args)
        elif args.model == 'mt':
            print('Mean Teacher model')
            prec1_tr, loss_tr, loss_cl_tr, prec1_t_tr, weight_cl = train_mt(
                label_loader, unlabel_loader, model, model_teacher,
                criterions, optimizer, epoch, args)
        else:
            print("Not Implemented ", args.model)
            assert (False)

        # evaluate on validation set (teacher evaluated separately for 'mt')
        prec1_val, loss_val = validate(val_loader, model, criterions, args, 'valid')
        prec1_test, loss_test = validate(test_loader, model, criterions, args, 'test')
        if args.model == 'mt':
            prec1_t_val, loss_t_val = validate(val_loader, model_teacher,
                                               criterions, args, 'valid')
            prec1_t_test, loss_t_test = validate(test_loader, model_teacher,
                                                 criterions, args, 'test')

        # append values
        # NOTE(review): acc1_t_tr/acc1_t_val/acc1_t_test and learning_rate are
        # mutated here but are not in the `global` declarations above —
        # presumably they exist as module-level lists; verify at module scope.
        acc1_tr.append(prec1_tr)
        losses_tr.append(loss_tr)
        acc1_val.append(prec1_val)
        losses_val.append(loss_val)
        acc1_test.append(prec1_test)
        losses_test.append(loss_test)
        if args.model != 'baseline':
            losses_cl_tr.append(loss_cl_tr)
        if args.model == 'mt':
            acc1_t_tr.append(prec1_t_tr)
            acc1_t_val.append(prec1_t_val)
            acc1_t_test.append(prec1_t_test)
        weights_cl.append(weight_cl)
        learning_rate.append(lr)

        # remember best prec@1 and save checkpoint; for 'mt' the teacher's
        # validation precision drives model selection.
        if args.model == 'mt':
            is_best = prec1_t_val > best_prec1
            if is_best:
                best_test_prec1_t = prec1_t_test
                best_test_prec1 = prec1_test
            # NOTE(review): best_test_prec1_t is only assigned when is_best —
            # if the first epoch is not an improvement (e.g. after resume),
            # this print and the dict below hit an UnboundLocalError; confirm.
            print("Best test precision: %.3f" % best_test_prec1_t)
            best_prec1 = max(prec1_t_val, best_prec1)
            dict_checkpoint = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'best_test_prec1': best_test_prec1,
                'acc1_tr': acc1_tr,
                'losses_tr': losses_tr,
                'losses_cl_tr': losses_cl_tr,
                'acc1_val': acc1_val,
                'losses_val': losses_val,
                'acc1_test': acc1_test,
                'losses_test': losses_test,
                'acc1_t_tr': acc1_t_tr,
                'acc1_t_val': acc1_t_val,
                'acc1_t_test': acc1_t_test,
                'state_dict_teacher': model_teacher.state_dict(),
                'best_test_prec1_t': best_test_prec1_t,
                'weights_cl': weights_cl,
                'learning_rate': learning_rate,
            }
        else:
            is_best = prec1_val > best_prec1
            if is_best:
                best_test_prec1 = prec1_test
            print("Best test precision: %.3f" % best_test_prec1)
            best_prec1 = max(prec1_val, best_prec1)
            dict_checkpoint = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'best_test_prec1': best_test_prec1,
                'acc1_tr': acc1_tr,
                'losses_tr': losses_tr,
                'losses_cl_tr': losses_cl_tr,
                'acc1_val': acc1_val,
                'losses_val': losses_val,
                'acc1_test': acc1_test,
                'losses_test': losses_test,
                'weights_cl': weights_cl,
                'learning_rate': learning_rate,
            }

        save_checkpoint(dict_checkpoint, is_best,
                        args.arch.lower() + str(args.boundary), dirname=ckpt_dir)