def main():
    """Fine-tune a CIFAR-10 ResNet-110 with Prototype Conformity (PC) losses.

    Loads a softmax-pretrained checkpoint, then jointly optimizes the model
    with cross-entropy plus proximity / contrastive-proximity losses computed
    on two feature levels (1024-d and 256-d).  Evaluates every
    ``args.eval_freq`` epochs (and at the final epoch) and saves a combined
    checkpoint to ``Models_PCL/CIFAR10_PCL.pth.tar``.

    Relies on module-level names: ``args``, ``Logger``, ``resnet``,
    ``Proximity``, ``Con_Proximity``, ``train``, ``test`` and the
    ``adjust_learning_rate*`` helpers.
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    # Mirror stdout into a log file as well as the console.
    sys.stdout = Logger(
        osp.join(args.save_dir, 'log_' + 'CIFAR-10_PC_Loss' + '.txt'))

    if use_gpu:
        print("Currently using GPU: {}".format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU")

    # Data load: standard CIFAR-10 augmentation for training, plain
    # normalization for test (per-channel CIFAR-10 mean/std).
    num_classes = 10
    print('==> Preparing dataset')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data/cifar10',
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.train_batch,
                                              pin_memory=True,
                                              shuffle=True,
                                              num_workers=args.workers)

    testset = torchvision.datasets.CIFAR10(root='./data/cifar10',
                                           train=False,
                                           download=True,
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.test_batch,
                                             pin_memory=True,
                                             shuffle=False,
                                             num_workers=args.workers)

    # Loading the Model
    model = resnet(num_classes=num_classes, depth=110)

    # Always wrap in DataParallel so the checkpoint's 'module.'-prefixed keys
    # match (DataParallel is a pass-through when no GPUs are visible), but
    # only move to CUDA when a GPU is actually in use.
    # BUG FIX: original used `if True:` and called .cuda() unconditionally,
    # which crashes on CPU-only runs.
    model = nn.DataParallel(model)
    if use_gpu:
        model = model.cuda()

    # One proximity (pull to own class center) and one contrastive-proximity
    # (push from other centers) loss per feature level; each holds learnable
    # centers, hence the dedicated optimizers below.
    criterion_xent = nn.CrossEntropyLoss()
    criterion_prox_1024 = Proximity(num_classes=num_classes,
                                    feat_dim=1024,
                                    use_gpu=use_gpu)
    criterion_prox_256 = Proximity(num_classes=num_classes,
                                   feat_dim=256,
                                   use_gpu=use_gpu)

    criterion_conprox_1024 = Con_Proximity(num_classes=num_classes,
                                           feat_dim=1024,
                                           use_gpu=use_gpu)
    criterion_conprox_256 = Con_Proximity(num_classes=num_classes,
                                          feat_dim=256,
                                          use_gpu=use_gpu)

    optimizer_model = torch.optim.SGD(model.parameters(),
                                      lr=args.lr_model,
                                      weight_decay=1e-04,
                                      momentum=0.9)

    optimizer_prox_1024 = torch.optim.SGD(criterion_prox_1024.parameters(),
                                          lr=args.lr_prox)
    optimizer_prox_256 = torch.optim.SGD(criterion_prox_256.parameters(),
                                         lr=args.lr_prox)

    optimizer_conprox_1024 = torch.optim.SGD(
        criterion_conprox_1024.parameters(), lr=args.lr_conprox)
    optimizer_conprox_256 = torch.optim.SGD(criterion_conprox_256.parameters(),
                                            lr=args.lr_conprox)

    # Resume from the softmax-pretrained checkpoint; map_location lets a
    # GPU-trained checkpoint load on a CPU-only machine.
    filename = 'Models_Softmax/CIFAR10_Softmax.pth.tar'
    checkpoint = torch.load(filename,
                            map_location=None if use_gpu else 'cpu')

    model.load_state_dict(checkpoint['state_dict'])
    # BUG FIX: original did `optimizer_model.load_state_dict = ...`, which
    # overwrote the bound method instead of calling it, silently discarding
    # the saved optimizer state (momentum buffers, lr schedule position).
    optimizer_model.load_state_dict(checkpoint['optimizer_model'])

    start_time = time.time()

    for epoch in range(args.max_epoch):

        adjust_learning_rate(optimizer_model, epoch)
        adjust_learning_rate_prox(optimizer_prox_1024, epoch)
        adjust_learning_rate_prox(optimizer_prox_256, epoch)

        adjust_learning_rate_conprox(optimizer_conprox_1024, epoch)
        adjust_learning_rate_conprox(optimizer_conprox_256, epoch)

        print("==> Epoch {}/{}".format(epoch + 1, args.max_epoch))
        train(model, criterion_xent, criterion_prox_1024, criterion_prox_256,
              criterion_conprox_1024, criterion_conprox_256, optimizer_model,
              optimizer_prox_1024, optimizer_prox_256, optimizer_conprox_1024,
              optimizer_conprox_256, trainloader, use_gpu, num_classes, epoch)

        # Evaluate periodically, and always at the final epoch.  Parentheses
        # make the original and/or precedence explicit (behavior unchanged).
        if (args.eval_freq > 0 and (epoch + 1) % args.eval_freq == 0) or (
                epoch + 1) == args.max_epoch:
            print("==> Test")  #Tests after every 10 epochs
            acc, err = test(model, testloader, use_gpu, num_classes, epoch)
            print("Accuracy (%): {}\t Error rate (%): {}".format(acc, err))

            # Bundle model plus every optimizer so training can be resumed
            # exactly, including the loss-center optimizers.
            state_ = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer_model': optimizer_model.state_dict(),
                'optimizer_prox_1024': optimizer_prox_1024.state_dict(),
                'optimizer_prox_256': optimizer_prox_256.state_dict(),
                'optimizer_conprox_1024': optimizer_conprox_1024.state_dict(),
                'optimizer_conprox_256': optimizer_conprox_256.state_dict(),
            }

            # Ensure the target directory exists before saving.
            os.makedirs('Models_PCL', exist_ok=True)
            torch.save(state_, 'Models_PCL/CIFAR10_PCL.pth.tar')

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
# ===== Example 2 =====
def main():
    """Adversarially train a classifier with Prototype Conformity losses.

    Builds the network selected by ``args.network``, attaches proximity and
    contrastive-proximity losses (10 classes, ``args.feat_size`` features),
    optionally resumes from a checkpoint, then trains for ``args.epochs``
    epochs, logging white-box robustness stats and saving periodic
    checkpoints.

    NOTE(review): depends on module-level globals not visible in this block —
    ``args``, ``device``, ``train_loader``, ``test_loader``, ``stats_dir``,
    ``model_dir`` — confirm they are defined before this runs.
    """
    # init model, ResNet18() can be also used here for training
    # model = WideResNet().to(device)
    if args.network == 'smallCNN':
        model = SmallCNN().to(device)
    elif args.network == 'wideResNet':
        model = WideResNet().to(device)
    elif args.network == 'resnet':
        model = ResNet().to(device)
    else:
        # Fallback: any other value is treated as a VGG variant name.
        model = VGG(args.network, num_classes=10).to(device)
    # Mirror stdout into the log file as well as the console.
    sys.stdout = Logger(os.path.join(args.log_dir, args.log_file))
    print(model)
    # PC-loss terms: Proximity pulls features toward their own class center,
    # Con_Proximity pushes them away from other centers.  Each holds its own
    # learnable centers and therefore gets a dedicated SGD optimizer.
    criterion_prox = Proximity(10, args.feat_size, True)
    criterion_conprox = Con_Proximity(10, args.feat_size, True)
    optimizer_prox = optim.SGD(criterion_prox.parameters(), lr=args.lr_prox)
    optimizer_conprox = optim.SGD(criterion_conprox.parameters(), lr=args.lr_conprox)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    if args.fine_tune:
        # Resume model weights and the model optimizer from a prior run.
        # NOTE(review): the prox/conprox optimizers and their centers are NOT
        # restored here — confirm this is intentional when fine-tuning.
        base_dir = args.base_dir
        state_dict = torch.load("{}/{}_ep{}.pt".format(base_dir, args.base_model, args.checkpoint))
        opt = torch.load("{}/opt-{}_ep{}.tar".format(base_dir, args.base_model, args.checkpoint))
        model.load_state_dict(state_dict)
        optimizer.load_state_dict(opt)


    # Despite the names, these accumulate the *error* totals returned by
    # eval_adv_test_whitebox below (natural_err_total / robust_err_total).
    natural_acc = []
    robust_acc = []

    for epoch in range(1, args.epochs + 1):
        # adjust learning rate for SGD
        adjust_learning_rate(optimizer, epoch)
        adjust_learning_rate(optimizer_prox, epoch)
        adjust_learning_rate(optimizer_conprox, epoch)

        start_time = time.time()

        # adversarial training
        train(model, device, train_loader, optimizer,
              criterion_prox, optimizer_prox,
              criterion_conprox, optimizer_conprox, epoch)

        # evaluation on natural examples
        print('================================================================')
        print("Current time: {}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
        # eval_train(model, device, train_loader)
        # eval_test(model, device, test_loader)
        natural_err_total, robust_err_total = eval_adv_test_whitebox(model, device, test_loader)
        # Append per-epoch error totals to a plain-text running log.
        with open(os.path.join(stats_dir, '{}.txt'.format(args.save_model)), "a") as f:
            f.write("{} {} {}\n".format(epoch, natural_err_total, robust_err_total))

        print('using time:', datetime.timedelta(seconds=round(time.time() - start_time)))

        natural_acc.append(natural_err_total)
        robust_acc.append(robust_err_total)


        # Re-save the full history each epoch as a (2, epoch) array; the file
        # name embeds the epoch, so earlier snapshots are kept on disk.
        file_name = os.path.join(stats_dir, '{}_stat{}.npy'.format(args.save_model, epoch))
        # np.save(file_name, np.stack((np.array(self.train_loss), np.array(self.test_loss),
        #                              np.array(self.train_acc), np.array(self.test_acc),
        #                              np.array(self.elasticity), np.array(self.x_grads),
        #                              np.array(self.fgsms), np.array(self.pgds),
        #                              np.array(self.cws))))
        np.save(file_name, np.stack((np.array(natural_acc), np.array(robust_acc))))

        # save checkpoint
        if epoch % args.save_freq == 0:
            torch.save(model.state_dict(),
                       os.path.join(model_dir, '{}_ep{}.pt'.format(args.save_model, epoch)))
            torch.save(optimizer.state_dict(),
                       os.path.join(model_dir, 'opt-{}_ep{}.tar'.format(args.save_model, epoch)))
            print("Ep{}: Model saved as {}.".format(epoch, args.save_model))
        print('================================================================')