def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = True if args.loss_type == 'LDAM' else False
    model = models.__dict__[args.arch](num_classes=num_classes, use_norm=use_norm)

    # create two optimizers - one for the feature extractor and one for the classifier
    feat_params = []
    feat_params_names = []
    cls_params = []
    cls_params_names = []
    learnable_epsilons = torch.nn.Parameter(torch.ones(num_classes))
    for name, params in model.named_parameters():
        if params.requires_grad:
            if "linear" in name:
                cls_params_names += [name]
                cls_params += [params]
            else:
                feat_params_names += [name]
                feat_params += [params]

    print("Create Feat Optimizer")
    print(f"\tRequires Grad:{feat_params_names}")
    feat_optim = torch.optim.SGD(feat_params + [learnable_epsilons],
                                 args.feat_lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)

    print("Create Cls Optimizer")
    print(f"\tRequires Grad:{cls_params_names}")
    cls_optim = torch.optim.SGD(cls_params,
                                args.cls_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume or args.evaluation:
        curr_store_name = args.store_name
        if not args.evaluation and args.pretrained:
            curr_store_name = os.path.join(curr_store_name, os.path.pardir)
        filename = '%s/%s/ckpt.best.pth.tar' % (args.root_model, curr_store_name)
        if os.path.isfile(filename):
            print("=> loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=f"cuda:{args.gpu}")
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(filename, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(filename))

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    cudnn.benchmark = True

    # Data loading code
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # fill parameter needs torchvision installed from source
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        Cutout(n_holes=1, length=16),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        original_train_dataset = IMBALANCECIFAR10(root='./data',
                                                  imb_type=args.imb_type,
                                                  imb_factor=args.imb_factor,
                                                  rand_number=args.rand_number,
                                                  train=True,
                                                  download=True,
                                                  transform=transform_val)
        augmented_train_dataset = IMBALANCECIFAR10(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_train if not args.evaluation else transform_val)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        original_train_dataset = IMBALANCECIFAR100(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_val)
        augmented_train_dataset = IMBALANCECIFAR100(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_train if not args.evaluation else transform_val)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return

    cls_num_list = augmented_train_dataset.get_cls_num_list()
    args.cls_num_list = cls_num_list
    train_labels = np.array(augmented_train_dataset.get_targets()).astype(int)

    # calculate balanced weights
    balanced_weights = torch.tensor(class_weight.compute_class_weight('balanced',
                                                                      np.unique(train_labels),
                                                                      train_labels),
                                    dtype=torch.float).cuda(args.gpu)
    lt_weights = torch.tensor(cls_num_list).float() / max(cls_num_list)

    def create_sampler(args_str):
        if args_str is not None and "resample" in args_str:
            sampler_type, n_resample = args_str.split(",")
            return ClassAwareSampler(train_labels, num_samples_cls=int(n_resample))
        return None

    original_train_loader = torch.utils.data.DataLoader(original_train_dataset,
                                                        batch_size=args.batch_size,
                                                        shuffle=False,
                                                        num_workers=args.workers,
                                                        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # feature extractor dataloader
    feat_sampler = create_sampler(args.feat_sampler)
    feat_train_loader = torch.utils.data.DataLoader(augmented_train_dataset,
                                                    batch_size=args.batch_size,
                                                    shuffle=(feat_sampler is None),
                                                    num_workers=args.workers,
                                                    pin_memory=True,
                                                    sampler=feat_sampler)

    if args.evaluation:
        # evaluate on the validation set
        # calculate centroids on the train set
        _, train_features, train_targets, _ = validate(original_train_loader, model,
                                                       0, args, train_labels,
                                                       flag="train", save_out=True)
        # validate
        validate(val_loader, model, 0, args, train_labels, flag="val",
                 save_out=True, base_features=train_features,
                 base_targets=train_targets)
        quit()

    # create losses
    def create_loss_list(args_str):
        loss_ls = []
        loss_str_ls = args_str.split(",")
        for loss_str in loss_str_ls:
            c_weights = None
            prefix = ""
            if "_bal" in loss_str:
                c_weights = balanced_weights
                prefix = "Balanced "
                loss_str = loss_str.split("_bal")[0]
            if "_lt" in loss_str:
                c_weights = lt_weights
                prefix = "Longtailed "
                loss_str = loss_str.split("_")[0]
            if loss_str == "ce":
                print(f"{prefix}CE", end=",")
                loss_ls += [nn.CrossEntropyLoss(weight=c_weights).cuda(args.gpu)]
            elif loss_str == "robust_loss":
                print(f"{prefix}Robust Loss", end=",")
                loss_ls += [DROLoss(temperature=args.temperature,
                                    base_temperature=args.temperature,
                                    class_weights=c_weights,
                                    epsilons=learnable_epsilons)]
        print()
        return loss_ls

    feat_losses = create_loss_list(args.feat_loss)
    cls_losses = create_loss_list(args.cls_loss)

    # init logs for training
    if not args.evaluation:
        log_training = open(os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
        log_testing = open(os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
        with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
            f.write(str(args))
        tf_writer = None

    best_acc1 = 0
    best_acc_contrastive = 0

    for epoch in range(args.start_epoch, args.epochs):
        print("=============== Extract Train Centroids ===============")
        _, train_features, train_targets, _ = validate(feat_train_loader, model,
                                                       epoch, args, train_labels,
                                                       log_training, tf_writer,
                                                       flag="train", verbose=True)

        if epoch < args.epochs - args.balanced_clf_nepochs:
            print("=============== Train Feature Extractor ===============")
            freeze_layers(model, fe_bool=True, cls_bool=False)
            train(feat_train_loader, model, feat_losses, epoch, feat_optim, args,
                  train_features, train_targets)
        else:
            if epoch == args.epochs - args.balanced_clf_nepochs:
                print("================ Loading Best Feature Extractor =================")
                # load best model
                curr_store_name = args.store_name
                filename = '%s/%s/ckpt.best.pth.tar' % (args.root_model, curr_store_name)
                checkpoint = torch.load(filename, map_location=f"cuda:{args.gpu}")['state_dict']
                model.load_state_dict(checkpoint)

            print("=============== Train Classifier ===============")
            freeze_layers(model, fe_bool=False, cls_bool=True)
            train(feat_train_loader, model, cls_losses, epoch, cls_optim, args)

        print("=============== Extract Train Centroids ===============")
        _, train_features, train_targets, _ = validate(original_train_loader, model,
                                                       epoch, args, train_labels,
                                                       log_training, tf_writer,
                                                       flag="train", verbose=False)

        print("=============== Validate ===============")
        acc1, _, _, contrastive_acc = validate(val_loader, model, epoch, args,
                                               train_labels, log_testing, tf_writer,
                                               flag="val",
                                               base_features=train_features,
                                               base_targets=train_targets)

        if epoch < args.epochs - args.balanced_clf_nepochs:
            is_best = contrastive_acc > best_acc_contrastive
            best_acc_contrastive = max(contrastive_acc, best_acc_contrastive)
        else:
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

        print(f"Best Contrastive Acc: {best_acc_contrastive}, Best Cls Acc: {best_acc1}")
        log_testing.write(f"Best Contrastive Acc: {best_acc_contrastive}, Best Cls Acc: {best_acc1}")
        log_testing.flush()

        save_checkpoint(args, {
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1
        }, is_best)
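
# NOTE: freeze_layers() is called in main_worker() but is not defined in this listing.
# A minimal sketch of what it might look like, assuming (as in the optimizer split
# above) that the classifier head is the set of parameters whose names contain "linear";
# this is an illustrative reconstruction, not the repository's actual helper:
def freeze_layers(model, fe_bool, cls_bool):
    """Enable/disable gradients for the feature extractor and the classifier head."""
    for name, params in model.named_parameters():
        if "linear" in name:
            params.requires_grad = cls_bool   # classifier head
        else:
            params.requires_grad = fe_bool    # feature extractor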
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)  # , eta_min=1e-8

print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.Resize((new_image_size, new_image_size)),
    transforms.RandomCrop(new_image_size, padding=4),  # resolution
    transforms.RandomHorizontalFlip(),
    CIFAR10Policy(),
    transforms.ToTensor(),
    Cutout(n_holes=1, length=16),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.Resize((new_image_size, new_image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
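
# The snippet above stops after building the datasets. A minimal sketch of how these
# pieces are typically wired together with the cosine scheduler; the loader settings
# and this loop are assumptions for illustration, not taken from the original file:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

for epoch in range(200):  # matches T_max=200 above
    model.train()
    for inputs, targets in trainloader:
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
    scheduler.step()  # cosine annealing is advanced once per epoch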
def main():
    print(args)
    if not osp.exists(args.dir):
        os.makedirs(args.dir)

    if args.use_gpu:
        torch.cuda.set_device(args.gpu)
        cudnn.enabled = True
        cudnn.benchmark = True

    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)

    labeled_size = args.label_num + args.val_num
    num_classes = 10
    data_dir = '../cifar10_data/'

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2470, 0.2435, 0.2616])
    # transform is implemented inside zca dataloader
    dataloader = cifar.CIFAR10

    if args.auto:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4, fill=128),  # fill parameter needs torchvision installed from source
            transforms.RandomHorizontalFlip(),
            CIFAR10Policy(),
            transforms.ToTensor(),
            Cutout(n_holes=1, length=16),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
            normalize
        ])
    else:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4, fill=128),  # fill parameter needs torchvision installed from source
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    base_dataset = datasets.CIFAR10(data_dir, train=True, download=True)
    train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(
        base_dataset.targets, int(args.label_num / 10))

    labelset = CIFAR10_labeled(data_dir, train_labeled_idxs, train=True, transform=transform_train)
    labelset2 = CIFAR10_labeled(data_dir, train_labeled_idxs, train=True, transform=transform_test)
    unlabelset = CIFAR10_labeled(data_dir, train_unlabeled_idxs, train=True, transform=transform_train)
    unlabelset2 = CIFAR10_labeled(data_dir, train_unlabeled_idxs, train=True, transform=transform_test)
    validset = CIFAR10_labeled(data_dir, val_idxs, train=True, transform=transform_test)
    testset = CIFAR10_labeled(data_dir, train=False, transform=transform_test)

    label_y = np.array(labelset.targets).astype(np.int32)
    unlabel_y = np.array(unlabelset.targets).astype(np.int32)
    unlabel_num = unlabel_y.shape[0]

    label_loader = torch.utils.data.DataLoader(labelset,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    label_loader2 = torch.utils.data.DataLoader(labelset2,
                                                batch_size=args.eval_batch_size,
                                                num_workers=args.num_workers,
                                                pin_memory=True)
    unlabel_loader = torch.utils.data.DataLoader(unlabelset,
                                                 batch_size=args.eval_batch_size,
                                                 num_workers=args.num_workers,
                                                 pin_memory=True)
    unlabel_loader2 = torch.utils.data.DataLoader(unlabelset2,
                                                  batch_size=args.eval_batch_size,
                                                  num_workers=args.num_workers,
                                                  pin_memory=True)
    validloader = torch.utils.data.DataLoader(validset,
                                              batch_size=args.eval_batch_size,
                                              num_workers=args.num_workers,
                                              pin_memory=True)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.eval_batch_size,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    # initialize models
    model1 = create_model(args.num_classes, args.model)
    model2 = create_model(args.num_classes, args.model)
    ema_model = create_model(args.num_classes, args.model)

    if args.use_gpu:
        model1 = model1.cuda()
        model2 = model2.cuda()
        ema_model = ema_model.cuda()

    for param in ema_model.parameters():
        param.detach_()

    df = pd.DataFrame()
    stats_path = osp.join(args.dir, 'stats.txt')

    '''if prop > args.scale:
        prop = args.scale'''

    optimizer1 = AdamW(model1.parameters(), lr=args.lr)
    if args.init1 and osp.exists(args.init1):
        model1.load_state_dict(torch.load(args.init1, map_location='cuda:{}'.format(args.gpu)))
    ema_optimizer = WeightEMA(model1, ema_model, alpha=args.ema_decay)

    if args.init and osp.exists(args.init):
        model1.load_state_dict(torch.load(args.init, map_location='cuda:{}'.format(args.gpu)))

    _, best_acc = evaluate(validloader, ema_model, prefix='val')

    best_ema_path = osp.join(args.dir, 'best_ema.pth')
    best_model1_path = osp.join(args.dir, 'best_model1.pth')
    best_model2_path = osp.join(args.dir, 'best_model2.pth')
    init_path = osp.join(args.dir, 'init_ema.pth')
    init_path1 = osp.join(args.dir, 'init1.pth')
    init_path2 = osp.join(args.dir, 'init2.pth')

    torch.save(ema_model.state_dict(), init_path)
    torch.save(model1.state_dict(), init_path1)
    torch.save(model2.state_dict(), init_path2)
    torch.save(ema_model.state_dict(), best_ema_path)
    torch.save(model1.state_dict(), best_model1_path)

    skip_model2 = False
    end_iter = False

    confident_indices = np.array([], dtype=np.int64)
    all_indices = np.arange(unlabel_num).astype(np.int64)
    # no_help_indices = np.array([]).astype(np.int64)
    pseudo_labels = np.zeros(all_indices.shape, dtype=np.int32)

    steps_per_epoch = len(iter(label_loader))
    max_epoch = args.steps // steps_per_epoch

    logger = logging.getLogger('init')
    file_handler = logging.FileHandler(osp.join(args.dir, 'init.txt'))
    logger.addHandler(file_handler)
    logger.setLevel(logging.INFO)

    for epoch in range(max_epoch * 4 // 5):
        if args.mix:
            train_init_mix(label_loader, model1, optimizer1, ema_optimizer,
                           steps_per_epoch, epoch, logger=logger)
        else:
            train_init(label_loader, model1, optimizer1, ema_optimizer,
                       steps_per_epoch, epoch, logger=logger)
        if epoch % 10 == 0:
            val_loss, val_acc = evaluate(validloader, ema_model, logger, 'valid')
            if val_acc >= best_acc:
                best_acc = val_acc
                evaluate(testloader, ema_model, logger, 'test')
                torch.save(ema_model.state_dict(), init_path)
                torch.save(model1.state_dict(), init_path1)

    adjust_learning_rate_adam(optimizer1, args.lr * 0.2)

    for epoch in range(max_epoch // 5):
        if args.mix:
            train_init_mix(label_loader, model1, optimizer1, ema_optimizer,
                           steps_per_epoch, epoch, logger=logger)
        else:
            train_init(label_loader, model1, optimizer1, ema_optimizer,
                       steps_per_epoch, epoch, logger=logger)
        if epoch % 10 == 0:
            val_loss, val_acc = evaluate(validloader, ema_model, logger, 'valid')
            if val_acc >= best_acc:
                best_acc = val_acc
                evaluate(testloader, ema_model, logger, 'test')
                torch.save(ema_model.state_dict(), init_path)
                torch.save(model1.state_dict(), init_path1)

    ema_model.load_state_dict(torch.load(init_path))
    model1.load_state_dict(torch.load(init_path1))
    logger.info('init train finished')
    evaluate(validloader, ema_model, logger, 'valid')
    evaluate(testloader, ema_model, logger, 'test')

    for i_round in range(args.round):
        mask = np.zeros(all_indices.shape, dtype=bool)
        mask[confident_indices] = True
        other_indices = all_indices[~mask]

        optimizer2 = AdamW(model2.parameters(), lr=args.lr)

        logger = logging.getLogger('model2_round_{}'.format(i_round))
        file_handler = logging.FileHandler(osp.join(args.dir, 'model2_round_{}.txt'.format(i_round)))
        logger.addHandler(file_handler)
        logger.setLevel(logging.INFO)

        if args.auto:
            probs = predict_probs(ema_model, unlabel_loader2)
        else:
            probs = np.zeros((unlabel_num, args.num_classes))
            for i in range(args.K):
                probs += predict_probs(ema_model, unlabel_loader)
            probs /= args.K

        pseudo_labels[other_indices] = probs.argmax(axis=1).astype(np.int32)[other_indices]
        # pseudo_labels = probs.argmax(axis=1).astype(np.int32)

        df2 = create_basic_stats_dataframe()
        df2['iter'] = i_round
        df2['train_acc'] = accuracy_score(unlabel_y, pseudo_labels)
        df = df.append(df2, ignore_index=True)
        df.to_csv(stats_path, index=False)
        # phase 2: train model2
        unlabelset.targets = pseudo_labels.copy()
        trainset = ConcatDataset([labelset, unlabelset])
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=args.batch_size2,
                                                  num_workers=args.num_workers,
                                                  pin_memory=True,
                                                  shuffle=True)

        model2.load_state_dict(torch.load(init_path2))
        best_val_epoch = 0
        best_model2_acc = 0

        steps_per_epoch = len(iter(trainloader))
        max_epoch2 = args.steps2 // steps_per_epoch
        for epoch in range(max_epoch2):
            train_model2(trainloader, model2, optimizer2, epoch, logger=logger)
            val_loss, val_acc = evaluate(validloader, model2, logger, 'val')
            if val_acc >= best_model2_acc:
                best_model2_acc = val_acc
                best_val_epoch = epoch
                torch.save(model2.state_dict(), best_model2_path)
                evaluate(testloader, model2, logger, 'test')
            if (epoch - best_val_epoch) * steps_per_epoch > args.stop_steps2:
                break

        df.loc[df['iter'] == i_round, 'valid_acc'] = best_model2_acc
        df.loc[df['iter'] == i_round, 'valid_epoch'] = best_val_epoch
        df.to_csv(stats_path, index=False)

        model2.load_state_dict(torch.load(best_model2_path))
        logger.info('model2 train finished')
        evaluate(trainloader, model2, logger, 'train')
        evaluate(validloader, model2, logger, 'val')
        evaluate(label_loader2, model2, logger, 'reward')
        evaluate(testloader, model2, logger, 'test')

        # phase 3: get confidence of unlabeled data by labeled data, split confident and unconfident data
        '''if args.auto:
            probs = predict_probs(model2, unlabel_loader2)
        else:
            probs = np.zeros((unlabel_num, args.num_classes))
            for i in range(args.K):
                probs += predict_probs(model2, unlabel_loader)
            probs /= args.K'''
        probs = predict_probs(model2, unlabel_loader2)
        new_pseudo_labels = probs.argmax(axis=1)
        confidences = probs[all_indices, pseudo_labels]

        if args.schedule == 'exp':
            confident_num = int((len(confident_indices) + args.label_num) * (1 + args.scale)) - args.label_num
        elif args.schedule == 'linear':
            confident_num = len(confident_indices) + int(unlabel_num * args.scale)

        old_confident_indices = confident_indices.copy()
        confident_indices = np.array([], dtype=np.int64)
        for j in range(args.num_classes):
            j_cands = (pseudo_labels == j)
            k_size = int(min(confident_num // args.num_classes, j_cands.sum()))
            logger.info('class: {}, confident size: {}'.format(j, k_size))
            if k_size > 0:
                j_idx_top = all_indices[j_cands][confidences[j_cands].argsort()[-k_size:]]
                confident_indices = np.concatenate((confident_indices, all_indices[j_idx_top]))

        '''new_confident_indices = np.intersect1d(new_confident_indices, np.setdiff1d(new_confident_indices, no_help_indices))
        new_confident_indices = new_confident_indices[(-confidences[new_confident_indices]).argsort()]
        confident_indices = np.concatenate((old_confident_indices, new_confident_indices))'''

        acc = accuracy_score(unlabel_y[confident_indices], pseudo_labels[confident_indices])
        logger.info('confident data num:{}, prop: {:4f}, acc: {:4f}'.format(
            len(confident_indices), len(confident_indices) / len(unlabel_y), acc))

        '''if len(old_confident_indices) > 0:
            acc = accuracy_score(unlabel_y[old_confident_indices], pseudo_labels[old_confident_indices])
            logger.info('old confident data prop: {:4f}, acc: {:4f}'.format(len(old_confident_indices) / len(unlabel_y), acc))
        acc = accuracy_score(unlabel_y[new_confident_indices], pseudo_labels[new_confident_indices])
        logger.info('new confident data prop: {:4f}, acc: {:4f}'.format(len(new_confident_indices) / len(unlabel_y), acc))'''

        # unlabelset.train_labels_ul = pseudo_labels.copy()
        confident_dataset = torch.utils.data.Subset(unlabelset, confident_indices)

        # phase 4: refine model1 by confident data and reward data
        # train_dataset = torch.utils.data.ConcatDataset([confident_dataset, labelset])
        logger = logging.getLogger('model1_round_{}'.format(i_round))
        file_handler = logging.FileHandler(osp.join(args.dir, 'model1_round_{}.txt'.format(i_round)))
        logger.addHandler(file_handler)
        logger.setLevel(logging.INFO)

        best_val_epoch = 0
        evaluate(validloader, ema_model, logger, 'valid')
        evaluate(testloader, ema_model, logger, 'test')

        optimizer1 = AdamW(model1.parameters(), lr=args.lr)

        confident_dataset = torch.utils.data.Subset(unlabelset, confident_indices)
        trainloader = torch.utils.data.DataLoader(confident_dataset,
                                                  batch_size=args.batch_size,
                                                  num_workers=args.num_workers,
                                                  shuffle=True,
                                                  drop_last=True)

        # steps_per_epoch = len(iter(trainloader))
        steps_per_epoch = 200
        max_epoch1 = args.steps1 // steps_per_epoch
        for epoch in range(max_epoch1):
            '''current_num = int(cal_consistency_weight((epoch + 1) * steps_per_epoch,
                                                     init_ep=0, end_ep=args.stop_steps1 // 2,
                                                     init_w=start_num, end_w=end_num))
            current_confident_indices = confident_indices[:current_num]
            logger.info('current num: {}'.format(current_num))'''
            if args.mix:
                train_model1_mix(label_loader, trainloader, model1, optimizer1,
                                 ema_model, ema_optimizer, steps_per_epoch, epoch,
                                 logger=logger)
            else:
                train_model1(label_loader, trainloader, model1, optimizer1,
                             ema_model, ema_optimizer, steps_per_epoch, epoch,
                             logger=logger)
            val_loss, val_acc = evaluate(validloader, ema_model, logger, 'valid')
            if val_acc >= best_acc:
                best_acc = val_acc
                best_val_epoch = epoch
                evaluate(testloader, ema_model, logger, 'test')
                torch.save(model1.state_dict(), best_model1_path)
                torch.save(ema_model.state_dict(), best_ema_path)
            if (epoch - best_val_epoch) * steps_per_epoch > args.stop_steps1:
                break

        ema_model.load_state_dict(torch.load(best_ema_path))
        model1.load_state_dict(torch.load(best_model1_path))
        logger.info('model1 train finished')
        evaluate(validloader, ema_model, logger, 'valid')
        evaluate(testloader, ema_model, logger, 'test')

        '''no_help_indices = np.concatenate((no_help_indices, confident_indices[current_num:]))
        confident_indices = confident_indices[:current_num]'''

        if len(confident_indices) >= len(all_indices):
            break
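
# NOTE: predict_probs() is used throughout main() but is not defined in this listing.
# Judging from how its output is consumed (an (unlabel_num, num_classes) array of
# softmax probabilities, optionally averaged over args.K augmented passes), a minimal
# sketch could look like the following; it assumes the usual
# `import torch.nn.functional as F` and that the loader yields (inputs, targets) pairs:
def predict_probs(model, loader):
    """Return softmax probabilities for every sample in `loader`, in loader order."""
    model.eval()
    probs = []
    with torch.no_grad():
        for inputs, _ in loader:
            if args.use_gpu:
                inputs = inputs.cuda()
            probs.append(F.softmax(model(inputs), dim=1).cpu().numpy())
    return np.concatenate(probs, axis=0)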