def main():
    """Fine-tune a CIFAR-10 ResNet-110 with Prototype Conformity (PC) loss.

    Loads a softmax-pretrained checkpoint, then jointly optimizes
    cross-entropy plus Proximity / Con_Proximity losses on two feature
    layers (1024-d and 256-d), evaluating every ``args.eval_freq`` epochs
    and saving a combined model/optimizer checkpoint after each evaluation.

    Relies on module-level ``args`` plus the project helpers ``Logger``,
    ``resnet``, ``Proximity``, ``Con_Proximity``, ``train``, ``test`` and
    the ``adjust_learning_rate*`` schedulers.
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    # Tee stdout to a log file in the save directory.
    sys.stdout = Logger(
        osp.join(args.save_dir, 'log_' + 'CIFAR-10_PC_Loss' + '.txt'))

    if use_gpu:
        print("Currently using GPU: {}".format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU")

    # Data Load
    num_classes = 10
    print('==> Preparing dataset')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True,
                                            download=True,
                                            transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.train_batch,
                                              pin_memory=True, shuffle=True,
                                              num_workers=args.workers)
    testset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=False,
                                           download=True,
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.test_batch,
                                             pin_memory=True, shuffle=False,
                                             num_workers=args.workers)

    # Loading the Model
    model = resnet(num_classes=num_classes, depth=110)
    # FIX: original gated this with `if True:`, unconditionally calling
    # .cuda() and crashing on CPU-only hosts; respect the use_gpu flag.
    if use_gpu:
        model = nn.DataParallel(model).cuda()

    criterion_xent = nn.CrossEntropyLoss()
    # Proximity pulls features toward their class prototype; Con_Proximity
    # pushes them away from other classes' prototypes, at two feature depths.
    criterion_prox_1024 = Proximity(num_classes=num_classes, feat_dim=1024,
                                    use_gpu=use_gpu)
    criterion_prox_256 = Proximity(num_classes=num_classes, feat_dim=256,
                                   use_gpu=use_gpu)
    criterion_conprox_1024 = Con_Proximity(num_classes=num_classes,
                                           feat_dim=1024, use_gpu=use_gpu)
    criterion_conprox_256 = Con_Proximity(num_classes=num_classes,
                                          feat_dim=256, use_gpu=use_gpu)

    optimizer_model = torch.optim.SGD(model.parameters(), lr=args.lr_model,
                                      weight_decay=1e-04, momentum=0.9)
    optimizer_prox_1024 = torch.optim.SGD(criterion_prox_1024.parameters(),
                                          lr=args.lr_prox)
    optimizer_prox_256 = torch.optim.SGD(criterion_prox_256.parameters(),
                                         lr=args.lr_prox)
    optimizer_conprox_1024 = torch.optim.SGD(
        criterion_conprox_1024.parameters(), lr=args.lr_conprox)
    optimizer_conprox_256 = torch.optim.SGD(criterion_conprox_256.parameters(),
                                            lr=args.lr_conprox)

    # Warm-start from the softmax-pretrained checkpoint.
    filename = 'Models_Softmax/CIFAR10_Softmax.pth.tar'
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['state_dict'])
    # FIX: original *assigned* the checkpoint dict over the method
    # (`optimizer_model.load_state_dict = checkpoint[...]`), silently
    # discarding the saved optimizer state; call the method instead.
    optimizer_model.load_state_dict(checkpoint['optimizer_model'])

    start_time = time.time()
    for epoch in range(args.max_epoch):
        adjust_learning_rate(optimizer_model, epoch)
        adjust_learning_rate_prox(optimizer_prox_1024, epoch)
        adjust_learning_rate_prox(optimizer_prox_256, epoch)
        adjust_learning_rate_conprox(optimizer_conprox_1024, epoch)
        adjust_learning_rate_conprox(optimizer_conprox_256, epoch)

        print("==> Epoch {}/{}".format(epoch + 1, args.max_epoch))
        train(model, criterion_xent,
              criterion_prox_1024, criterion_prox_256,
              criterion_conprox_1024, criterion_conprox_256,
              optimizer_model, optimizer_prox_1024, optimizer_prox_256,
              optimizer_conprox_1024, optimizer_conprox_256,
              trainloader, use_gpu, num_classes, epoch)

        # Evaluate every eval_freq epochs and on the final epoch.
        if args.eval_freq > 0 and (epoch + 1) % args.eval_freq == 0 or (
                epoch + 1) == args.max_epoch:
            print("==> Test")
            acc, err = test(model, testloader, use_gpu, num_classes, epoch)
            print("Accuracy (%): {}\t Error rate (%): {}".format(acc, err))

            state_ = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer_model': optimizer_model.state_dict(),
                'optimizer_prox_1024': optimizer_prox_1024.state_dict(),
                'optimizer_prox_256': optimizer_prox_256.state_dict(),
                'optimizer_conprox_1024': optimizer_conprox_1024.state_dict(),
                'optimizer_conprox_256': optimizer_conprox_256.state_dict(),
            }
            torch.save(state_, 'Models_PCL/CIFAR10_PCL.pth.tar')

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
def main():
    """Adversarially train a classifier regularized with PC losses.

    Builds the network selected by ``args.network`` (unrecognized names are
    treated as VGG variants), optionally warm-starts from a saved epoch,
    then loops: adversarial training, white-box robustness evaluation,
    stats logging, and periodic checkpointing.

    Uses module-level ``args``, ``device``, ``train_loader``,
    ``test_loader``, ``stats_dir`` and ``model_dir``.

    NOTE(review): this file defines ``main()`` twice; this later definition
    shadows the earlier PC-loss trainer — confirm that is intended.
    """
    # Architecture dispatch; anything not in the table is a VGG config name.
    known_nets = {'smallCNN': SmallCNN, 'wideResNet': WideResNet,
                  'resnet': ResNet}
    arch = args.network
    if arch in known_nets:
        model = known_nets[arch]().to(device)
    else:
        model = VGG(arch, num_classes=10).to(device)

    # Tee stdout into the configured log file, then dump the architecture.
    sys.stdout = Logger(os.path.join(args.log_dir, args.log_file))
    print(model)

    criterion_prox = Proximity(10, args.feat_size, True)
    criterion_conprox = Con_Proximity(10, args.feat_size, True)
    optimizer_prox = optim.SGD(criterion_prox.parameters(), lr=args.lr_prox)
    optimizer_conprox = optim.SGD(criterion_conprox.parameters(),
                                  lr=args.lr_conprox)
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if args.fine_tune:
        # Resume both the network and its optimizer from a saved epoch.
        base_dir = args.base_dir
        saved_weights = torch.load(
            "{}/{}_ep{}.pt".format(base_dir, args.base_model, args.checkpoint))
        saved_opt = torch.load(
            "{}/opt-{}_ep{}.tar".format(base_dir, args.base_model,
                                        args.checkpoint))
        model.load_state_dict(saved_weights)
        optimizer.load_state_dict(saved_opt)

    # Per-epoch histories of natural / robust error counts.
    nat_history, rob_history = [], []

    for epoch in range(1, args.epochs + 1):
        # Step every learning-rate schedule for this epoch.
        for opt_ in (optimizer, optimizer_prox, optimizer_conprox):
            adjust_learning_rate(opt_, epoch)

        epoch_start = time.time()

        # One pass of adversarial training with the PC-loss criteria.
        train(model, device, train_loader, optimizer, criterion_prox,
              optimizer_prox, criterion_conprox, optimizer_conprox, epoch)

        print('================================================================')
        print("Current time: {}".format(
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))

        # White-box robustness evaluation on the test set.
        natural_err_total, robust_err_total = eval_adv_test_whitebox(
            model, device, test_loader)

        # Append this epoch's errors to the running text log.
        with open(os.path.join(stats_dir,
                               '{}.txt'.format(args.save_model)), "a") as f:
            f.write("{} {} {}\n".format(epoch, natural_err_total,
                                        robust_err_total))

        print('using time:',
              datetime.timedelta(seconds=round(time.time() - epoch_start)))

        nat_history.append(natural_err_total)
        rob_history.append(robust_err_total)

        # Persist the full histories so far as a (2, epoch) array.
        stats_path = os.path.join(
            stats_dir, '{}_stat{}.npy'.format(args.save_model, epoch))
        np.save(stats_path,
                np.stack((np.array(nat_history), np.array(rob_history))))

        # Periodic checkpoint of model weights and optimizer state.
        if epoch % args.save_freq == 0:
            torch.save(model.state_dict(),
                       os.path.join(model_dir,
                                    '{}_ep{}.pt'.format(args.save_model,
                                                        epoch)))
            torch.save(optimizer.state_dict(),
                       os.path.join(model_dir,
                                    'opt-{}_ep{}.tar'.format(args.save_model,
                                                             epoch)))
            print("Ep{}: Model saved as {}.".format(epoch, args.save_model))
        print('================================================================')