Пример #1
0
def main():
    """Train ``nets[args.model]`` with SGD, attack it once, and log the results.

    Each epoch adjusts the learning rate, calls ``train`` and ``test``, and
    advances the module-level ``recorder``. After the last epoch, the first
    batch of ``testloader`` is attacked with AutoAttack (L-inf, eps=1/255,
    APGD-CE + APGD-T). The resulting robust accuracy is recorded, and the
    model's state dict is saved under ``model_dir``.

    After training, the function:

    * prints the wall-clock time;
    * pickles the recorder to ``log_dir``;
    * draws its curves;
    * passes the last clean test accuracy and the robust accuracy to
      ``write2text``.

    NOTE(review): ``recorder``, ``wd``, ``args``, ``device``, ``nets``, the
    loaders, ``name``, ``model_dir`` and ``log_dir`` are all read from
    module scope.
    """
    model = nets[args.model]().to(device)

    bind_lip(model, verbose=False)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=wd)

    # Bound up front so the final ``write2text`` call cannot raise NameError
    # when the final-epoch evaluation produced no value (empty ``testloader``).
    robust_accuracy = None

    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        print(f'Epoch: {epoch:3d}', end='  ')
        # adjust learning rate for SGD
        adjust_learning_rate(optimizer, epoch)

        train(args, model, device, train_loader, optimizer, epoch)

        test(model, test_loader)

        recorder.step()

        if epoch == args.epochs:
            adversary = AutoAttack(model, norm='Linf', eps=1/255, version='standard')
            adversary.attacks_to_run = ['apgd-ce', 'apgd-t']

            model.eval()
            # Only the first batch of ``testloader`` is attacked.
            first_batch = next(iter(testloader), None)
            if first_batch is None:
                print('testloader is empty; skipping AutoAttack evaluation')
            else:
                inputs, targets = first_batch
                inputs, targets = inputs.to(device), targets.to(device)
                with torch.no_grad():
                    _, robust_accuracy = adversary.run_standard_evaluation(inputs, targets, bs=128)
                recorder.record('robust_accuracy', robust_accuracy)
            torch.save(model.state_dict(),
                       os.path.join(model_dir, f'{name}.pt'))

    used_time = (time.time() - start_time) / 3600
    print(f'Used time: {used_time:.2f} h')

    with open(f'{log_dir}/{name}_record.pkl', 'wb') as file:
        pickle.dump(recorder, file)

    # recorder.draw('lip_sum')
    recorder.draw_many('lip')
    recorder.draw('train_acc')
    recorder.draw('test_acc')

    # presumably test_acc holds (epoch, accuracy) pairs — confirm against Recorder
    clean_acc = recorder.test_acc[-1][1]
    write2text(log_dir, 'log', name, clean_acc=clean_acc, robust_accuracy=robust_accuracy)
def main():
    """Train ``nets[args.model]`` adversarially, then evaluate it with AutoAttack.

    How the model is built depends on ``args.loss``:

    * ``'alp'``, ``'trades'`` or ``'pgd'``: plain construction;
    * anything else: constructed with ``use_FNandWN=True``.

    Each epoch adjusts the learning rate, calls ``train`` and ``eval_test``
    (both receive the recorder), and saves a per-epoch checkpoint when
    ``epoch >= args.start_freq`` and ``epoch % args.save_freq == 0``. It then
    advances the recorder.

    After training, the function:

    * saves the final weights;
    * pickles the recorder;
    * draws its curves;
    * attacks the first batch of ``testloader`` with AutoAttack (L-inf,
      eps=8/255, APGD-CE only);
    * records the robust accuracy (``None`` if ``testloader`` is empty) and
      pickles the recorder again.

    NOTE(review): ``args``, ``device``, ``nets``, the loaders, ``name``,
    ``model_dir`` and ``log_dir`` are read from module scope.
    """
    # init model, ResNet18() can be also used here for training
    if args.loss in ('alp', 'trades', 'pgd'):
        print("normalize False")
        model = nets[args.model]().to(device)
    else:
        print("normalize True")
        model = nets[args.model](use_FNandWN=True).to(device)

    bind_lip(model, norm='1-norm', verbose=False)

    recorder = Recorder(f'{name}')

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    for epoch in range(1, args.epochs + 1):
        # adjust learning rate for SGD
        print(f'Epoch: {epoch:3d}', end='  ')
        adjust_learning_rate(optimizer, epoch)

        # adversarial training
        train(args, model, device, train_loader, optimizer, recorder, epoch)

        # evaluation on natural examples
        # eval_train(model, device, train_loader, recorder)
        eval_test(model, device, test_loader, recorder)

        # save checkpoint
        if (epoch >= args.start_freq) and (epoch % args.save_freq == 0):
            torch.save(model.state_dict(),
                       os.path.join(model_dir, f'{name}-epoch{epoch}.pt'))

        recorder.step()

    torch.save(model.state_dict(), os.path.join(model_dir, f'{name}.pt'))

    # Persist the training record before running the attack.
    with open(f'{log_dir}/{name}_record.pkl', 'wb') as file:
        pickle.dump(recorder, file)

    recorder.draw('lip_sum')
    recorder.draw_many('lip')

    recorder.draw('train_acc')
    recorder.draw('test_acc')

    adversary = AutoAttack(model,
                           norm='Linf',
                           eps=8 / 255,
                           version='standard',
                           verbose=False)
    adversary.attacks_to_run = ['apgd-ce']

    model.eval()
    # Only the first batch of ``testloader`` is attacked. Binding the result
    # first avoids a NameError below when the loader yields nothing.
    robust_accuracy = None
    first_batch = next(iter(testloader), None)
    if first_batch is None:
        print('testloader is empty; skipping AutoAttack evaluation')
    else:
        inputs, targets = first_batch
        inputs, targets = inputs.to(device), targets.to(device)
        with torch.no_grad():
            _, robust_accuracy = adversary.run_standard_evaluation(inputs,
                                                                   targets,
                                                                   bs=128)
        print(f'robust_accuracy: {robust_accuracy}')

    recorder.record('robust_accuracy', robust_accuracy)

    with open(f'{log_dir}/{name}_record.pkl', 'wb') as file:
        pickle.dump(recorder, file)
def main():
    """Train ``nets[args.model]``, evaluate it with AutoAttack, and log results.

    Each epoch adjusts the learning rate, calls ``train`` and ``eval_test``
    (both receive the recorder), and advances the recorder.

    After training, the function:

    * saves the final weights and pickles the recorder;
    * draws its curves;
    * attacks the first batch of ``testloader`` with AutoAttack (L-inf,
      eps=8/255, APGD-CE + APGD-T);
    * records the robust accuracy (``None`` if ``testloader`` is empty) and
      pickles the recorder again;
    * passes the last clean test accuracy and the robust accuracy to
      ``write2text``.

    NOTE(review): ``args``, ``device``, ``nets``, the loaders, ``name``,
    ``img_dir``, ``model_dir`` and ``log_dir`` are read from module scope.
    """
    model = nets[args.model]().to(device)

    bind_lip(model, norm='1-norm', verbose=False)

    recorder = Recorder(f'{name}', img_dir)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    for epoch in range(1, args.epochs + 1):
        # adjust learning rate for SGD
        print(f'Epoch: {epoch:3d}', end='  ')
        adjust_learning_rate(optimizer, epoch)

        train(args, model, device, train_loader, optimizer, recorder, epoch)

        eval_test(model, device, test_loader, recorder)

        # Per-epoch checkpointing is currently disabled:
        # if (epoch >= args.start_freq) and (epoch % args.save_freq == 0):
        #     torch.save(model.state_dict(),
        #                os.path.join(model_dir, f'{name}-epoch{epoch}.pt'))

        recorder.step()

    torch.save(model.state_dict(), os.path.join(model_dir, f'{name}.pt'))

    # Persist the training record before running the attack.
    with open(f'{log_dir}/{name}_record.pkl', 'wb') as file:
        pickle.dump(recorder, file)

    # recorder.draw('lip_sum')
    recorder.draw_many('lip')

    recorder.draw('train_acc')
    recorder.draw('test_acc')

    adversary = AutoAttack(model,
                           norm='Linf',
                           eps=8 / 255,
                           version='standard',
                           verbose=True)
    adversary.attacks_to_run = ['apgd-ce', 'apgd-t']

    model.eval()
    # Only the first batch of ``testloader`` is attacked. Binding the result
    # first avoids a NameError below when the loader yields nothing.
    robust_accuracy = None
    first_batch = next(iter(testloader), None)
    if first_batch is None:
        print('testloader is empty; skipping AutoAttack evaluation')
    else:
        inputs, targets = first_batch
        inputs, targets = inputs.to(device), targets.to(device)
        with torch.no_grad():
            _, robust_accuracy = adversary.run_standard_evaluation(inputs,
                                                                   targets,
                                                                   bs=128)
        print(f'robust_accuracy: {robust_accuracy}')

    recorder.record('robust_accuracy', robust_accuracy)

    with open(f'{log_dir}/{name}_record.pkl', 'wb') as file:
        pickle.dump(recorder, file)

    # presumably test_acc holds (epoch, accuracy) pairs — confirm against Recorder
    clean_acc = recorder.test_acc[-1][1]
    write2text(log_dir,
               'log',
               name,
               clean_acc=clean_acc,
               robust_accuracy=robust_accuracy)