def main(args, dst_folder):
    # best_ac records the best top-1 accuracy on the validation set.
    best_ac = 0.0

    # Initialize seeds and prepare the GPU
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed)  # CPU seed
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)  # GPU seed
    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)

    if args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]

    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.RandomCrop(32),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomCrop(32),
            # SVHNPolicy(),
            # AutoAugment(),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Cutout(n_holes=1, length=20),
            transforms.Normalize(mean, std),
        ])
    else:
        print("Wrong value for --DA argument.")

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, train_noisy_indexes = data_config(args, transform_train, transform_test, dst_folder)

    if args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes=args.num_classes, dropRatio=args.dropout).to(device)
    elif args.network == "WRN28_2_wn":
        print("Loading WRN28_2...")
        model = WRN28_2_wn(num_classes=args.num_classes, dropout=args.dropout).to(device)
    elif args.network == "PreactResNet18_WNdrop":
        print("Loading preActResNet18_WNdrop...")
        model = PreactResNet18_WNdrop(drop_val=args.dropout, num_classes=args.num_classes).to(device)

    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    milestones = args.M

    if args.swa == 'True':
        # to install torchcontrib:
        #   pip3 install torchcontrib
        # or:
        #   git clone https://github.com/pytorch/contrib.git
        #   cd contrib
        #   sudo python3 setup.py install
        from torchcontrib.optim import SWA
        # base_optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        base_optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)
        optimizer = SWA(base_optimizer, swa_lr=args.swa_lr)
    else:
        # optimizer = RAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []
    new_labels = []

    exp_path = os.path.join('./', 'noise_models_{0}'.format(args.experiment_name), str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name), str(args.labeled_samples))

    if not os.path.isdir(res_path):
        os.makedirs(res_path)
    if not os.path.isdir(exp_path):
        os.makedirs(exp_path)

    cont = 0
    load = False
    save = True

    if args.initial_epoch != 0:
        initial_epoch = args.initial_epoch
        load = True
        save = False

    if args.dataset_type == 'sym_noise_warmUp':
        load = False
        save = True

    if load:
        if args.loss_term == 'Reg_ep':
            train_type = 'C'
        if args.loss_term == 'MixUp_ep':
            train_type = 'M'
        if args.dropout > 0.0:
            train_type = train_type + 'drop' + str(int(10 * args.dropout))
        if args.beta == 0.0:
            train_type = train_type + 'noReg'
        path = './checkpoints/warmUp_{6}_{5}_{0}_{1}_{2}_{3}_S{4}.hdf5'.format(initial_epoch,
                                                                               args.dataset,
                                                                               args.labeled_samples,
                                                                               args.network,
                                                                               args.seed,
                                                                               args.Mixup_Alpha,
                                                                               train_type)

        checkpoint = torch.load(path)
        print("Loading model from epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])

        print("Relabeling the unlabeled samples...")
        model.eval()
        initial_rand_relab = args.label_noise
        results = np.zeros((len(train_loader.dataset), args.num_classes), dtype=np.float32)

        for images, images_pslab, labels, soft_labels, index in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            soft_labels = soft_labels.to(device)

            outputs = model(images)
            prob, loss = loss_soft_reg_ep(outputs, labels, soft_labels, device, args)
            results[index.detach().numpy().tolist()] = prob.cpu().detach().numpy().tolist()

        train_loader.dataset.update_labels_randRelab(results, train_noisy_indexes, initial_rand_relab)

    print("Start training...")
    for epoch in range(1, args.epoch + 1):
        st = time.time()
        scheduler.step()
        # train for one epoch
        print(args.experiment_name, args.labeled_samples)

        loss_per_epoch, top_5_train_ac, top1_train_acc_original_labels, \
            top1_train_ac, train_time = train_CrossEntropy_partialRelab(args, model, device,
                                                                        train_loader, optimizer,
                                                                        epoch, train_noisy_indexes)

        loss_train_epoch += [loss_per_epoch]

        # test
        if args.validation_exp == "True":
            loss_per_epoch, acc_val_per_epoch_i = validating(args, model, device, test_loader)
        else:
            loss_per_epoch, acc_val_per_epoch_i = testing(args, model, device, test_loader)

        loss_val_epoch += loss_per_epoch
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i

        ####################################################################################################
        #############################            SAVING MODELS                #############################
        ####################################################################################################
        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')

        if epoch == 1:
            best_acc_val = acc_val_per_epoch_i[-1]
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
        else:
            if acc_val_per_epoch_i[-1] > best_acc_val:
                best_acc_val = acc_val_per_epoch_i[-1]
                if cont > 0:
                    try:
                        os.remove(os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
                        os.remove(os.path.join(exp_path, snapBest + '.pth'))
                    except OSError:
                        pass
                snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestAccVal_%.5f' % (
                    epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
                torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
                torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))

        cont += 1

        if epoch == args.epoch:
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

        #### Save models for ensembles:
        if (epoch >= 150) and (epoch % 2 == 0) and (args.save_checkpoint == "True"):
            print("Saving model ...")
            out_path = './checkpoints/ENS_{0}_{1}'.format(args.experiment_name, args.labeled_samples)
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            torch.save(model.state_dict(), out_path + "/epoch_{0}.pth".format(epoch))

        ### Saving model to load it again
        # cond = epoch % 1 == 0
        if args.dataset_type == 'sym_noise_warmUp':
            if args.loss_term == 'Reg_ep':
                train_type = 'C'
            if args.loss_term == 'MixUp_ep':
                train_type = 'M'
            if args.dropout > 0.0:
                train_type = train_type + 'drop' + str(int(10 * args.dropout))
            if args.beta == 0.0:
                train_type = train_type + 'noReg'
            cond = (epoch == args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True
        else:
            cond = (epoch == args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True

        if cond and save:
            print("Saving models...")
            path = './checkpoints/{0}_{1}_{2}_{3}_{4}_S{5}.hdf5'.format(name, epoch, args.dataset,
                                                                        args.labeled_samples,
                                                                        args.network, args.seed)
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_train_epoch': np.asarray(loss_train_epoch),
                'loss_val_epoch': np.asarray(loss_val_epoch),
                'acc_train_per_epoch': np.asarray(acc_train_per_epoch),
                'acc_val_per_epoch': np.asarray(acc_val_per_epoch),
                'labels': np.asarray(train_loader.dataset.soft_labels),
            }, filename=path)

        ####################################################################################################
        ############################             SAVING METRICS               #############################
        ####################################################################################################

        # Save losses:
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_train.npy', np.asarray(loss_train_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_val.npy', np.asarray(loss_val_epoch))

        # Save accuracies:
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_train.npy',
                np.asarray(acc_train_per_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_val.npy',
                np.asarray(acc_val_per_epoch))

        # Save the new labels:
        new_labels.append(train_loader.dataset.labels)
        np.save(res_path + '/' + str(args.labeled_samples) + '_new_labels.npy', np.asarray(new_labels))

        # logging.info('Epoch: [{}|{}], train_loss: {:.3f}, top1_train_ac: {:.3f}, top1_val_ac: {:.3f}, '
        #              'train_time: {:.3f}'.format(epoch, args.epoch, loss_per_epoch[-1], top1_train_ac,
        #                                          acc_val_per_epoch_i[-1], time.time() - st))

    # applying swa
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        if args.validation_exp == "True":
            loss_swa, acc_val_swa = validating(args, model, device, test_loader)
        else:
            loss_swa, acc_val_swa = testing(args, model, device, test_loader)

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_noise_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val, acc_val_swa[0])
        torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

    # save_fig(dst_folder)
    best_ac = best_acc_val  # keep the best validation accuracy for the final report
    print('Best acc: %f' % best_acc_val)
    record_result(dst_folder, best_ac)
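
# ----------------------------------------------------------------------------------------------------
# Illustrative note (a sketch, not part of the original training logic): in both main() variants in
# this file, the warm-up checkpoint written at the final epoch and the path reconstructed when
# reloading follow the same naming pattern:
#
#     warmUp_{train_type}_{Mixup_Alpha}_{epoch}_{dataset}_{labeled_samples}_{network}_S{seed}.hdf5
#
# The values below are hypothetical and only show how the pieces compose; keeping the save and load
# format strings in sync is what lets a later run find the warm-up weights.
_EXAMPLE_WARMUP_PATH = './checkpoints/warmUp_{0}_{1}_{2}_{3}_{4}_{5}_S{6}.hdf5'.format(
    'M',        # train_type: 'M' for MixUp_ep, 'C' for Reg_ep (plus optional 'drop*'/'noReg' suffixes)
    1.0,        # args.Mixup_Alpha (hypothetical)
    150,        # warm-up epoch (hypothetical)
    'cifar10',  # args.dataset (hypothetical)
    500,        # args.labeled_samples (hypothetical)
    'MT_Net',   # args.network (hypothetical)
    42)         # args.seed (hypothetical)
# -> ./checkpoints/warmUp_M_1.0_150_cifar10_500_MT_Net_S42.hdf5
# ----------------------------------------------------------------------------------------------------
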
def main(args):
    best_ac = 0.0

    #####################
    # Initialize seeds and prepare the GPU
    if args.cuda_dev == 1:
        torch.cuda.set_device(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.deterministic = True  # fix the GPU to deterministic mode
    torch.manual_seed(args.seed)  # CPU seed
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)  # GPU seed
    random.seed(args.seed)  # python seed for image transformation
    np.random.seed(args.seed)
    #####################

    if args.dataset == 'cifar10':
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
    elif args.dataset == 'cifar100':
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]

    if args.DA == "standard":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    elif args.DA == "jitter":
        transform_train = transforms.Compose([
            transforms.Pad(2, padding_mode='reflect'),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    else:
        print("Wrong value for --DA argument.")

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # data loader
    train_loader, test_loader, unlabeled_indexes = data_config(args, transform_train, transform_test)

    if args.network == "MT_Net":
        print("Loading MT_Net...")
        model = MT_Net(num_classes=args.num_classes, dropRatio=args.dropout).to(device)
    elif args.network == "WRN28_2_wn":
        print("Loading WRN28_2...")
        model = WRN28_2_wn(num_classes=args.num_classes, dropout=args.dropout).to(device)
    elif args.network == "PreactResNet18_WNdrop":
        print("Loading preActResNet18_WNdrop...")
        model = PreactResNet18_WNdrop(drop_val=args.dropout, num_classes=args.num_classes).to(device)

    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    milestones = args.M

    if args.swa == 'True':
        # to install torchcontrib:
        #   pip3 install torchcontrib
        # or:
        #   git clone https://github.com/pytorch/contrib.git
        #   cd contrib
        #   sudo python3 setup.py install
        from torchcontrib.optim import SWA
        base_optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
        optimizer = SWA(base_optimizer, swa_lr=args.swa_lr)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    loss_train_epoch = []
    loss_val_epoch = []
    acc_train_per_epoch = []
    acc_val_per_epoch = []

    exp_path = os.path.join('./', 'ssl_models_{0}'.format(args.experiment_name), str(args.labeled_samples))
    res_path = os.path.join('./', 'metrics_{0}'.format(args.experiment_name), str(args.labeled_samples))

    if not os.path.isdir(res_path):
        os.makedirs(res_path)
    if not os.path.isdir(exp_path):
        os.makedirs(exp_path)

    cont = 0
    load = False
    save = True

    if args.load_epoch != 0:
        load_epoch = args.load_epoch
        load = True
        save = False

    if args.dataset_type == 'ssl_warmUp':
        load = False
        save = True

    if load:
        if args.loss_term == 'Reg_ep':
            train_type = 'C'
        if args.loss_term == 'MixUp_ep':
            train_type = 'M'
        if args.dropout > 0.0:
            train_type = train_type + 'drop' + str(int(10 * args.dropout))
        path = './checkpoints/warmUp_{0}_{1}_{2}_{3}_{4}_{5}_S{6}.hdf5'.format(train_type,
                                                                               args.Mixup_Alpha,
                                                                               load_epoch,
                                                                               args.dataset,
                                                                               args.labeled_samples,
                                                                               args.network,
                                                                               args.seed)

        checkpoint = torch.load(path)
        print("Loading model from epoch " + str(checkpoint['epoch']))
        print("Path loaded: ", path)
        model.load_state_dict(checkpoint['state_dict'])

        print("Relabeling the unlabeled samples...")
        model.eval()
        results = np.zeros((len(train_loader.dataset), args.num_classes), dtype=np.float32)

        for images, images_pslab, labels, soft_labels, index in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            soft_labels = soft_labels.to(device)

            outputs = model(images)
            prob, loss = loss_soft_reg_ep(outputs, labels, soft_labels, device, args)
            results[index.detach().numpy().tolist()] = prob.cpu().detach().numpy().tolist()

        train_loader.dataset.update_labels(results, unlabeled_indexes)

    print("Start training...")

    ####################################################################################################
    ###############################               TRAINING               ##############################
    ####################################################################################################
    for epoch in range(1, args.epoch + 1):
        st = time.time()
        scheduler.step()
        # train for one epoch
        print(args.experiment_name, args.labeled_samples)

        loss_per_epoch_train, top_5_train_ac, top1_train_ac, \
            train_time = train_CrossEntropy(args, model, device,
                                            train_loader, optimizer,
                                            epoch, unlabeled_indexes)

        loss_train_epoch += [loss_per_epoch_train]

        # test
        if args.validation_exp == "True":
            loss_per_epoch_test, acc_val_per_epoch_i = validating(args, model, device, test_loader)
        else:
            loss_per_epoch_test, acc_val_per_epoch_i = testing(args, model, device, test_loader)

        loss_val_epoch += loss_per_epoch_test
        acc_train_per_epoch += [top1_train_ac]
        acc_val_per_epoch += acc_val_per_epoch_i

        ####################################################################################################
        #############################            SAVING MODELS               ##############################
        ####################################################################################################
        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')

        if epoch == 1:
            best_acc_val = acc_val_per_epoch_i[-1]
            snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestAccVal_%.5f' % (
                epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
        else:
            if acc_val_per_epoch_i[-1] > best_acc_val:
                best_acc_val = acc_val_per_epoch_i[-1]
                if cont > 0:
                    try:
                        os.remove(os.path.join(exp_path, 'opt_' + snapBest + '.pth'))
                        os.remove(os.path.join(exp_path, snapBest + '.pth'))
                    except OSError:
                        pass
                snapBest = 'best_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestAccVal_%.5f' % (
                    epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
                torch.save(model.state_dict(), os.path.join(exp_path, snapBest + '.pth'))
                torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapBest + '.pth'))

        cont += 1

        if epoch == args.epoch:
            snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestValLoss_%.5f' % (
                epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val)
            torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

        ### Saving model to load it again
        # cond = epoch % 1 == 0
        if args.dataset_type == 'ssl_warmUp':
            if args.loss_term == 'Reg_ep':
                train_type = 'C'
            if args.loss_term == 'MixUp_ep':
                train_type = 'M'
            if args.dropout > 0.0:
                train_type = train_type + 'drop' + str(int(10 * args.dropout))
            cond = (epoch == args.epoch)
            name = 'warmUp_{1}_{0}'.format(args.Mixup_Alpha, train_type)
            save = True
        else:
            cond = False

        if cond and save:
            print("Saving models...")
            path = './checkpoints/{0}_{1}_{2}_{3}_{4}_S{5}.hdf5'.format(name, epoch, args.dataset,
                                                                        args.labeled_samples,
                                                                        args.network, args.seed)
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_train_epoch': np.asarray(loss_train_epoch),
                'loss_val_epoch': np.asarray(loss_val_epoch),
                'acc_train_per_epoch': np.asarray(acc_train_per_epoch),
                'acc_val_per_epoch': np.asarray(acc_val_per_epoch),
                'labels': np.asarray(train_loader.dataset.soft_labels),
            }, filename=path)

        ####################################################################################################
        ############################             SAVING METRICS              ##############################
        ####################################################################################################

        # Save losses:
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_train.npy', np.asarray(loss_train_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_LOSS_epoch_val.npy', np.asarray(loss_val_epoch))

        # Save accuracies:
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_train.npy',
                np.asarray(acc_train_per_epoch))
        np.save(res_path + '/' + str(args.labeled_samples) + '_accuracy_per_epoch_val.npy',
                np.asarray(acc_val_per_epoch))

    # applying swa
    if args.swa == 'True':
        optimizer.swap_swa_sgd()
        optimizer.bn_update(train_loader, model, device)
        if args.validation_exp == "True":
            loss_swa, acc_val_swa = validating(args, model, device, test_loader)
        else:
            loss_swa, acc_val_swa = testing(args, model, device, test_loader)

        snapLast = 'last_epoch_%d_valLoss_%.5f_valAcc_%.5f_labels_%d_bestValLoss_%.5f_swaAcc_%.5f' % (
            epoch, loss_per_epoch_test[-1], acc_val_per_epoch_i[-1], args.labeled_samples, best_acc_val, acc_val_swa[0])
        torch.save(model.state_dict(), os.path.join(exp_path, snapLast + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(exp_path, 'opt_' + snapLast + '.pth'))

    print('Best acc: %f' % best_acc_val)
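
# ----------------------------------------------------------------------------------------------------
# Hypothetical entry point (a minimal sketch, not the original CLI): both main() variants expect an
# argparse namespace with the flags referenced above (--dataset, --network, --seed, --lr, --M, --swa,
# --DA, ...). Only a representative subset is shown, and the defaults here are illustrative
# assumptions; since the SSL main(args) is defined last, it is the one invoked below.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar10', help="'cifar10' or 'cifar100'")
    parser.add_argument('--network', type=str, default='MT_Net',
                        help="'MT_Net', 'WRN28_2_wn' or 'PreactResNet18_WNdrop'")
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--wd', type=float, default=1e-4)
    parser.add_argument('--M', type=int, nargs='+', default=[250, 350], help='LR-decay milestones')
    parser.add_argument('--epoch', type=int, default=400)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--swa', type=str, default='False')
    parser.add_argument('--swa_lr', type=float, default=0.01)
    parser.add_argument('--DA', type=str, default='standard', help="'standard' or 'jitter'")
    parser.add_argument('--labeled_samples', type=int, default=4000)
    parser.add_argument('--experiment_name', type=str, default='test')
    parser.add_argument('--dataset_type', type=str, default='ssl', help="'ssl' or 'ssl_warmUp'")
    parser.add_argument('--load_epoch', type=int, default=0)
    parser.add_argument('--loss_term', type=str, default='MixUp_ep', help="'Reg_ep' or 'MixUp_ep'")
    parser.add_argument('--Mixup_Alpha', type=float, default=1.0)
    parser.add_argument('--validation_exp', type=str, default='False')
    parser.add_argument('--cuda_dev', type=int, default=0)
    args = parser.parse_args()
    main(args)
# ----------------------------------------------------------------------------------------------------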