# Example 1
def main():
    """Visualize PGD adversarial examples for a pretrained CIFAR-10 VGG16.

    Builds an L2 or Linf PGD attack (selected by --pgd_type), perturbs one
    test batch and, for each image, prints the prediction flip plus the
    L2/Linf perturbation distances, then shows a clean-vs-adversarial plot.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--pgd_type', default=None, type=str)
    args = parser.parse_args()

    # fix: validate up front — previously a missing/unknown --pgd_type left
    # `PGD` undefined and crashed with a NameError inside the loop below
    if args.pgd_type not in ('linf', 'l2'):
        print('--pgd_type [linf or l2]')
        exit()

    # settings
    device = args.device
    weight_path = './weights/vgg16_e086_90.62.pth'

    # classification model (frozen, evaluation mode)
    net = VGG('VGG16').to(device)
    state_dict = torch.load(weight_path, map_location=device)
    net.load_state_dict(state_dict)
    net.eval()

    # test dataset
    test_dataloader = get_test_loader(batch_size=8)

    # PGD instance (epsilon follows this repo's `eps * 4 / 255` convention)
    if args.pgd_type == 'linf':
        PGD = PGD_Linf(model=net, epsilon=8 * 4 / 255)
    else:
        PGD = PGD_L2(model=net, epsilon=40 * 4 / 255)

    # PGD examples: visualize only the first batch
    for images, labels in test_dataloader:

        images = images.to(device)
        labels = labels.to(device)
        images_adv = PGD.perturb(images, labels)

        outputs = net(images)
        outputs_adv = net(images_adv)

        for image, image_adv, output, output_adv in zip(
                images, images_adv, outputs, outputs_adv):

            img = recover_image(image)
            soft_label = F.softmax(output, dim=0).cpu().detach().numpy()

            img_adv = recover_image(image_adv)
            soft_label_adv = F.softmax(output_adv,
                                       dim=0).cpu().detach().numpy()

            # per-image perturbation size
            l2_dist = torch.norm(image - image_adv, 2).item()
            linf_dist = torch.norm(image - image_adv, float('inf')).item()
            print('%s -> %s' % (IND2CLASS[np.argmax(soft_label)],
                                IND2CLASS[np.argmax(soft_label_adv)]))
            print('l2   dist = %.4f' % l2_dist)
            print('linf dist = %.4f' % linf_dist)
            print()

            plot_comparison(img, img_adv, soft_label, soft_label_adv)
            plt.show()

        break
# Example 2
def main():
    """Adversarially train VGG16 on CIFAR-10 with an optional PGD/FGSM attack.

    --pgd_type selects the attack used during training (l2 / linf / fgsm,
    or none for clean training); --pgd_epsilon and --pgd_steps control its
    strength.  The best validation checkpoint is saved to a timestamped
    output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--lr_decay', default=20, type=int)
    parser.add_argument('--pgd_type', default=None, type=str)
    parser.add_argument('--pgd_epsilon', default=8, type=int)
    parser.add_argument('--pgd_steps', default=10, type=int)
    parser.add_argument('--pgd_label', default=0, type=int)
    args = parser.parse_args()

    config = dict()
    config['device'] = args.device
    config['num_epoch'] = args.epochs
    config['batch_size'] = args.batch_size
    config['learning_rate'] = args.lr
    config['lr_decay'] = args.lr_decay
    config['pgd_type'] = args.pgd_type
    config['pgd_epsilon'] = args.pgd_epsilon
    config['pgd_steps'] = args.pgd_steps
    config['pgd_label'] = args.pgd_label

    # CIFAR-10 dataset (40000 + 10000)
    train_loader, valid_loader = get_train_valid_loader(
        batch_size=config['batch_size'])

    # classification network
    net = VGG('VGG16').to(device=config['device'])

    # train settings
    learning_rate = config['learning_rate']
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=learning_rate,
                          momentum=0.9,
                          weight_decay=5e-4)
    best_valid_acc1 = 0

    output_path = './train_pgd_{:%Y-%m-%d-%H-%M-%S}/'.format(datetime.now())
    log_file = output_path + 'train_log.txt'
    os.mkdir(output_path)

    for epoch_idx in range(1, config['num_epoch'] + 1):

        # learning rate scheduling: halve the lr every `lr_decay` epochs.
        # NOTE(review): rebuilding the optimizer also resets SGD momentum;
        # kept as-is because the original did the same.
        if epoch_idx % config['lr_decay'] == 0:
            learning_rate *= 0.5
            optimizer = optim.SGD(net.parameters(),
                                  lr=learning_rate,
                                  momentum=0.9,
                                  weight_decay=5e-4)

        # Build the selected attack for this epoch; None means clean
        # training.  (De-duplicates four previously identical train() calls.)
        if config['pgd_type'] == 'l2':
            attack = PGD_L2(model=net,
                            epsilon=config['pgd_epsilon'] * 4 / 255,
                            num_steps=config['pgd_steps'])
        elif config['pgd_type'] == 'linf':
            attack = PGD_Linf(model=net,
                              epsilon=config['pgd_epsilon'] * 4 / 255,
                              num_steps=config['pgd_steps'])
        elif config['pgd_type'] == 'fgsm':
            attack = FGSM(model=net, num_steps=config['pgd_epsilon'])
        else:
            attack = None

        # train & valid
        _ = train(train_loader,
                  net,
                  criterion,
                  log_file,
                  optimizer,
                  epoch_idx,
                  PGD=attack,
                  config=config)
        valid_acc1 = valid(valid_loader,
                           net,
                           criterion,
                           log_file,
                           config=config)

        # save only checkpoints that improve validation accuracy
        if valid_acc1 > best_valid_acc1:
            best_valid_acc1 = valid_acc1
            file_name = output_path + 'vgg16_e%03d_%.2f.pth' % (
                epoch_idx, best_valid_acc1)
            torch.save(net.state_dict(), file_name)
            print('epoch=%003d, acc=%.4f saved.\n' %
                  (epoch_idx, best_valid_acc1))
# Example 3
def _build_attack(net, config):
    """Return the attack instance selected by config['attack'] (None = clean)."""
    if config['attack'] == 'fgsm':
        return FGSM(model=net, num_steps=config['epsilon'])
    if config['attack'] == 'linf_pgd':
        return PGD_Linf(model=net,
                        epsilon=config['epsilon'] * 4 / 255,
                        num_steps=config['steps'])
    if config['attack'] == 'l2_pgd':
        return PGD_L2(model=net,
                      epsilon=config['epsilon'] * 4 / 255,
                      num_steps=config['steps'])
    return None


def main():
    """Measure VGG16 top-1 test accuracy on CIFAR-10, optionally under attack.

    --weight may be a single .pth file or a directory of checkpoints; in
    the latter case every checkpoint is evaluated in sorted order and the
    accuracy list is printed at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--weight', default=None, type=str)
    parser.add_argument('--attack', default=None, type=str)
    parser.add_argument('--epsilon', default=8, type=int)
    parser.add_argument('--steps', default=10, type=int)
    args = parser.parse_args()

    if args.attack not in [None, 'fgsm', 'l2_pgd', 'linf_pgd']:
        print('--attack [fgsm or l2_pgd or linf_pgd]')
        exit()

    config = dict()
    config['device'] = args.device
    config['weight'] = args.weight
    config['attack'] = args.attack
    config['epsilon'] = args.epsilon
    config['steps'] = args.steps

    # CIFAR-10 dataset (10000)
    test_loader = get_test_loader(batch_size=32)

    # classification network
    net = VGG('VGG16').to(device=config['device'])
    if '.pth' in config['weight']:
        # single checkpoint
        print(config['weight'])
        state_dict = torch.load(config['weight'],
                                map_location=config['device'])
        net.load_state_dict(state_dict)
        net.eval()
        _ = test_accuracy(test_loader, net, config,
                          attack=_build_attack(net, config))
    else:
        # directory of checkpoints: evaluate every .pth file in sorted order
        weights = sorted(os.listdir(config['weight']))
        acc1_list = []
        for weight in weights:
            if '.pth' not in weight:
                continue
            print(weight)
            weight_path = os.path.join(config['weight'], weight)
            state_dict = torch.load(weight_path, map_location=config['device'])
            net.load_state_dict(state_dict)
            net.eval()

            # rebuild the attack per checkpoint (it wraps the live net)
            acc1 = test_accuracy(test_loader, net, config,
                                 attack=_build_attack(net, config))
            acc1_list.append(acc1.cpu().item())
        print(acc1_list)
# Example 4
def main():
    """Evaluate VGG16 checkpoints against a P-ATN trained per checkpoint.

    For each .pth file in the --weight directory: load the classifier,
    train an ATN against it on the first --atn_sample fraction of the
    test batches, then record top-1 accuracy under the ATN attack.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--weight', default=None, type=str)
    parser.add_argument('--atn_sample', default=0.1, type=float)
    parser.add_argument('--atn_epoch', default=10, type=int)
    parser.add_argument('--atn_batch_size', default=32, type=int)
    parser.add_argument('--atn_weight', default=None, type=str)
    parser.add_argument('--atn_lr', default=1e-4, type=float)
    parser.add_argument('--atn_epsilon', default=8, type=int)
    args = parser.parse_args()

    config = dict()
    config['device'] = args.device
    config['weight'] = args.weight
    config['atn_sample'] = args.atn_sample
    config['atn_epoch'] = args.atn_epoch
    config['atn_batch_size'] = args.atn_batch_size
    config['atn_weight'] = args.atn_weight
    config['atn_lr'] = args.atn_lr
    config['atn_epsilon'] = args.atn_epsilon

    # CIFAR-10 test set (10000).  Built once; the original rebuilt it on
    # every loop iteration (and first built an unused batch_size=32 loader).
    test_loader = get_test_loader(batch_size=config['atn_batch_size'])

    # number of leading batches used for ATN training
    num_train_batches = int(len(test_loader) * config['atn_sample'])

    # classification network
    net = VGG('VGG16').to(device=config['device'])

    weights = sorted(os.listdir(config['weight']))
    acc1_list = []
    for weight in weights:
        if '.pth' not in weight:
            continue
        print(weight)
        weight_path = os.path.join(config['weight'], weight)
        state_dict = torch.load(weight_path, map_location=config['device'])
        net.load_state_dict(state_dict)
        net.eval()

        # train a fresh ATN against this checkpoint
        atn = P_ATN(model=net,
                    epsilon=config['atn_epsilon'] * 4 / 255,
                    weight=config['atn_weight'],
                    device=config['device'])

        for epoch_idx_atn in range(1, config['atn_epoch'] + 1):
            losses = []
            lossXs = []
            lossYs = []
            l2_lst = []
            for batch_idx, (images, labels) in enumerate(test_loader):
                if batch_idx == num_train_batches:
                    break
                loss, lossX, lossY, l2_dist = atn.train(
                    images, labels, learning_rate=config['atn_lr'])
                losses.append(loss)
                lossXs.append(lossX)
                lossYs.append(lossY)
                l2_lst.append(l2_dist)
            # fix: guard against an atn_sample so small that no batch was
            # used — the original crashed here with ZeroDivisionError
            if not losses:
                print('no ATN training batches (atn_sample too small)')
                break
            avg_loss = sum(losses) / len(losses)
            avg_lossX = sum(lossXs) / len(lossXs)
            avg_lossY = sum(lossYs) / len(lossYs)
            avg_l2 = sum(l2_lst) / len(l2_lst)
            print('[%3d / %3d] Avg.Loss: %.4f(%.4f, %.4f)\tAvg.L2-dist: %.4f' %
                  (epoch_idx_atn, config['atn_epoch'], avg_loss, avg_lossX,
                   avg_lossY, avg_l2))

        acc1 = test_accuracy(test_loader, net, config, attack=atn)
        acc1_list.append(acc1.cpu().item())

    print(acc1_list)
# Example 5
def main():
    """Sweep P-ATN epsilon values against a fixed pretrained VGG16.

    For each epsilon in [2, 16], trains a fresh ATN on a subsample of the
    CIFAR-10 train split, then counts clean vs. adversarial correct
    predictions over the full train split.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--atn_epoch', default=10, type=int)
    parser.add_argument('--atn_batch_size', default=32, type=int)
    parser.add_argument('--atn_sample', default=0.1, type=float)
    parser.add_argument('--atn_weight', default=None, type=str)
    parser.add_argument('--atn_lr', default=1e-4, type=float)
    args = parser.parse_args()

    # settings
    config = dict()
    config['device'] = args.device
    config['atn_epoch'] = args.atn_epoch
    config['atn_batch_size'] = args.atn_batch_size
    config['atn_sample'] = args.atn_sample
    config['atn_weight'] = args.atn_weight
    config['atn_lr'] = args.atn_lr
    weight_path = './weights/vgg16_e086_90.62.pth'

    # classification model
    net = VGG('VGG16').to(config['device'])
    state_dict = torch.load(weight_path, map_location=config['device'])
    net.load_state_dict(state_dict)
    # fix: evaluation mode was missing here — every sibling script calls
    # eval() after loading; without it BatchNorm runs in training mode
    net.eval()

    # train dataloader for testing
    atn_train_loader, _ = get_train_valid_loader(
        batch_size=config['atn_batch_size'],
        atn=int(config['atn_sample'] * 40000))
    train_loader, _ = get_train_valid_loader(
        batch_size=config['atn_batch_size'])

    # train ATN (from scratch or not), one per epsilon in the sweep
    for eps in range(2, 17, 1):
        print('epsilon = %d' % (eps))
        atn = P_ATN(model=net,
                    epsilon=eps * 4 / 255,
                    weight=config['atn_weight'],
                    device=config['device'])

        for epoch_idx in range(1, config['atn_epoch'] + 1):
            losses = []
            lossXs = []
            lossYs = []
            l2_lst = []
            for images, labels in atn_train_loader:
                loss, lossX, lossY, l2_dist = atn.train(
                    images, labels, learning_rate=config['atn_lr'])
                losses.append(loss)
                lossXs.append(lossX)
                lossYs.append(lossY)
                l2_lst.append(l2_dist)
            avg_loss = sum(losses) / len(losses)
            avg_lossX = sum(lossXs) / len(lossXs)
            avg_lossY = sum(lossYs) / len(lossYs)
            avg_l2 = sum(l2_lst) / len(l2_lst)
            # progress line restored (was commented out) for consistency
            # with the other ATN training scripts
            print('[%3d / %3d] Avg.Loss: %.4f(%.4f, %.4f)\tAvg.L2-dist: %.4f' %
                  (epoch_idx, config['atn_epoch'], avg_loss, avg_lossX,
                   avg_lossY, avg_l2))

        # ATN examples: clean vs adversarial accuracy over the train split
        corr = 0
        corr_adv = 0

        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(train_loader, start=1):
                images = images.to(config['device'])
                labels = labels.to(config['device'])
                images_adv = atn.perturb(images)
                outputs = net(images)
                outputs_adv = net(images_adv)
                _, preds = outputs.max(1)
                _, preds_adv = outputs_adv.max(1)
                corr += preds.eq(labels).sum().item()
                corr_adv += preds_adv.eq(labels).sum().item()
            # summary after the full pass (batch_idx is the last batch index)
            print('[%5d/%5d] corr:%5d\tcorr_adv:%5d' %
                  (batch_idx, len(train_loader), corr, corr_adv))
# Example 6
def main():
    """Quick P-ATN demo against a pretrained CIFAR-10 VGG16.

    Briefly trains the ATN on a handful of test batches, then shows one
    batch of adversarial examples with prediction flips and L2/Linf
    perturbation distances.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--device', default='cuda', type=str)
    opts = arg_parser.parse_args()

    # settings
    device = opts.device
    weight_path = './weights/vgg16_e086_90.62.pth'

    # classification model (frozen, evaluation mode)
    net = VGG('VGG16').to(device)
    net.load_state_dict(torch.load(weight_path, map_location=device))
    net.eval()

    # test dataset
    loader = get_test_loader(batch_size=8)

    # train ATN: three passes over batches 1..8 (batch 0 is skipped)
    atn = P_ATN(model=net, epsilon=8 * 4 / 255, weight=None, device=device)
    for epoch in range(3):
        print(epoch)
        for step, (batch_images, batch_labels) in enumerate(loader):
            if step == 9:
                break
            if step != 0:
                _ = atn.train(batch_images, batch_labels)

    # ATN examples: visualize only the first batch
    images, labels = next(iter(loader))
    images = images.to(device)
    images_adv = atn.perturb(images)

    outputs = net(images)
    outputs_adv = net(images_adv)

    for image, image_adv, output, output_adv in zip(images, images_adv,
                                                    outputs, outputs_adv):
        img = recover_image(image)
        soft_label = F.softmax(output, dim=0).cpu().detach().numpy()

        img_adv = recover_image(image_adv)
        soft_label_adv = F.softmax(output_adv, dim=0).cpu().detach().numpy()

        print('%s -> %s' % (IND2CLASS[np.argmax(soft_label)],
                            IND2CLASS[np.argmax(soft_label_adv)]))
        print('l2   dist = %.4f' % torch.norm(image - image_adv, 2).item())
        print('linf dist = %.4f' % torch.norm(image - image_adv,
                                              float('inf')).item())
        print()

        plot_comparison(img, img_adv, soft_label, soft_label_adv)
        plt.show()
# Example 7
def main():
    """Train a P-ATN against a fixed pretrained VGG16 and report its effect.

    After training the ATN on a subsample of the CIFAR-10 train split,
    iterates the same loader once more to count clean vs. adversarial
    correct predictions and average L2/Linf perturbation sizes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--atn_epoch', default=10, type=int)
    parser.add_argument('--atn_batch_size', default=32, type=int)
    parser.add_argument('--atn_sample', default=0.1, type=float)
    parser.add_argument('--atn_epsilon', default=8, type=int)
    parser.add_argument('--atn_weight', default=None, type=str)
    parser.add_argument('--atn_lr', default=1e-4, type=float)
    args = parser.parse_args()

    # settings: collect all CLI arguments into a single config dict
    config = dict()
    config['device'] = args.device
    config['atn_epoch'] = args.atn_epoch
    config['atn_batch_size'] = args.atn_batch_size
    config['atn_sample'] = args.atn_sample
    config['atn_epsilon'] = args.atn_epsilon
    config['atn_weight'] = args.atn_weight
    config['atn_lr'] = args.atn_lr
    weight_path = './weights/vgg16_e086_90.62.pth'

    # classification model (frozen target, evaluation mode)
    net = VGG('VGG16').to(config['device'])
    state_dict = torch.load(weight_path, map_location=config['device'])
    net.load_state_dict(state_dict)
    net.eval()

    # train dataloader for testing — `atn=` presumably limits the split to
    # atn_sample * 40000 samples; TODO confirm against get_train_valid_loader
    atn_train_loader, _ = get_train_valid_loader(
        batch_size=config['atn_batch_size'],
        atn=int(config['atn_sample'] * 40000))

    # train ATN (from scratch, or resumed when --atn_weight is given)
    atn = P_ATN(model=net,
                epsilon=config['atn_epsilon'] * 4 / 255,
                weight=config['atn_weight'],
                device=config['device'])

    for epoch_idx in range(1, config['atn_epoch'] + 1):
        losses = []
        lossXs = []
        lossYs = []
        l2_lst = []
        for batch_idx, (images, labels) in enumerate(atn_train_loader):
            loss, lossX, lossY, l2_dist = atn.train(
                images, labels, learning_rate=config['atn_lr'])
            losses.append(loss)
            lossXs.append(lossX)
            lossYs.append(lossY)
            l2_lst.append(l2_dist)
        # per-epoch averages of the ATN's total/X/Y losses and L2 distance
        avg_loss = sum(losses) / len(losses)
        avg_lossX = sum(lossXs) / len(lossXs)
        avg_lossY = sum(lossYs) / len(lossYs)
        avg_l2 = sum(l2_lst) / len(l2_lst)
        print('[%3d / %3d] Avg.Loss: %.4f(%.4f, %.4f)\tAvg.L2-dist: %.4f' %
              (epoch_idx, config['atn_epoch'], avg_loss, avg_lossX, avg_lossY,
               avg_l2))

    # ATN examples: count correct predictions before/after the attack
    corr = 0
    corr_adv = 0
    l2_lst = []
    linf_lst = []
    for batch_idx, (images, labels) in enumerate(atn_train_loader, start=1):

        images = images.to(config['device'])
        images_adv = atn.perturb(images)

        outputs = net(images)
        outputs_adv = net(images_adv)

        # per-image comparison of clean vs. adversarial predictions
        for image, image_adv, output, output_adv, label in zip(
                images, images_adv, outputs, outputs_adv, labels):

            soft_label = F.softmax(output, dim=0).cpu().detach().numpy()
            soft_label_adv = F.softmax(output_adv,
                                       dim=0).cpu().detach().numpy()

            label = label.item()
            pred = np.argmax(soft_label)
            pred_adv = np.argmax(soft_label_adv)

            if label == pred:
                corr += 1

            if label == pred_adv:
                corr_adv += 1

            # per-image perturbation size
            l2_dist = torch.norm(image - image_adv, 2).item()
            linf_dist = torch.norm(image - image_adv, float('inf')).item()

            l2_lst.append(l2_dist)
            linf_lst.append(linf_dist)

    # summary over the whole pass (batch_idx holds the last batch index)
    a = sum(l2_lst) / len(l2_lst)
    b = sum(linf_lst) / len(linf_lst)
    print('[%5d/%5d] corr:%5d\tcorr_adv:%5d\tavg.l2:%.4f\tavg.linf:%.4f' %
          (batch_idx, len(atn_train_loader), corr, corr_adv, a, b))
# Example 8
def main():
    """Adversarially train VGG16 on CIFAR-10 against a per-epoch retrained P-ATN.

    Each epoch a fresh ATN is trained on a subsample of the train split
    and then passed to train() via the ATN keyword, so the classifier is
    exposed to its adversarial examples.  The best validation checkpoint
    is written to a timestamped output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--lr_decay', default=20, type=int)
    parser.add_argument('--atn_sample', default=0.1, type=float)
    parser.add_argument('--atn_epoch', default=10, type=int)
    parser.add_argument('--atn_batch_size', default=32, type=int)
    parser.add_argument('--atn_weight', default=None, type=str)
    parser.add_argument('--atn_lr', default=1e-4, type=float)
    parser.add_argument('--atn_epsilon', default=8, type=int)
    parser.add_argument('--atn_debug', default=1, type=int)
    args = parser.parse_args()

    # collect all CLI arguments into a single config dict
    config = dict()
    config['device'] = args.device
    config['num_epoch'] = args.epochs
    config['batch_size'] = args.batch_size
    config['learning_rate'] = args.lr
    config['lr_decay'] = args.lr_decay
    config['atn_sample'] = args.atn_sample
    config['atn_epoch'] = args.atn_epoch
    config['atn_batch_size'] = args.atn_batch_size
    config['atn_weight'] = args.atn_weight
    config['atn_lr'] = args.atn_lr
    config['atn_epsilon'] = args.atn_epsilon
    config['atn_debug'] = args.atn_debug

    # CIFAR-10 dataset (40000 + 10000)
    train_loader, valid_loader = get_train_valid_loader(
        batch_size=config['batch_size'])

    # classification network
    net = VGG('VGG16').to(device=config['device'])

    # train settings
    learning_rate = config['learning_rate']
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=learning_rate,
                          momentum=0.9,
                          weight_decay=5e-4)
    best_valid_acc1 = 0

    output_path = './train_atn_{:%Y-%m-%d-%H-%M-%S}/'.format(datetime.now())
    log_file = output_path + 'train_log.txt'
    os.mkdir(output_path)

    for epoch_idx in range(1, config['num_epoch'] + 1):

        # learning rate scheduling: halve the lr every `lr_decay` epochs.
        # NOTE(review): rebuilding the optimizer also resets SGD momentum
        if epoch_idx % config['lr_decay'] == 0:
            learning_rate *= 0.5
            optimizer = optim.SGD(net.parameters(),
                                  lr=learning_rate,
                                  momentum=0.9,
                                  weight_decay=5e-4)

        # train a fresh ATN against the current net on a subsample of the
        # train split (`atn=` presumably caps the sample count — TODO confirm)
        atn_train_loader, _ = get_train_valid_loader(
            batch_size=config['atn_batch_size'],
            atn=int(config['atn_sample'] * 40000))
        atn = P_ATN(model=net,
                    epsilon=config['atn_epsilon'] * 4 / 255,
                    weight=config['atn_weight'],
                    device=config['device'])

        for epoch_idx_atn in range(1, config['atn_epoch'] + 1):
            losses = []
            lossXs = []
            lossYs = []
            l2_lst = []
            for batch_idx, (images, labels) in enumerate(atn_train_loader):
                loss, lossX, lossY, l2_dist = atn.train(
                    images, labels, learning_rate=config['atn_lr'])
                losses.append(loss)
                lossXs.append(lossX)
                lossYs.append(lossY)
                l2_lst.append(l2_dist)
            # per-epoch averages of the ATN's total/X/Y losses and L2 distance
            avg_loss = sum(losses) / len(losses)
            avg_lossX = sum(lossXs) / len(lossXs)
            avg_lossY = sum(lossYs) / len(lossYs)
            avg_l2 = sum(l2_lst) / len(l2_lst)
            print('[%3d / %3d] Avg.Loss: %.4f(%.4f, %.4f)\tAvg.L2-dist: %.4f' %
                  (epoch_idx_atn, config['atn_epoch'], avg_loss, avg_lossX,
                   avg_lossY, avg_l2))

        # DEBUG: clean vs. adversarial accuracy of the current net on the
        # validation split, without tracking gradients
        if config['atn_debug']:
            with torch.no_grad():
                corr = 0
                corr_adv = 0
                for batch_idx, (images, labels) in enumerate(valid_loader,
                                                             start=1):
                    images = images.to(config['device'])
                    labels = labels.to(config['device'])
                    images_adv = atn.perturb(images)
                    outputs = net(images)
                    outputs_adv = net(images_adv)
                    _, preds = outputs.max(1)
                    _, preds_adv = outputs_adv.max(1)
                    corr += preds.eq(labels).sum().item()
                    corr_adv += preds_adv.eq(labels).sum().item()
                print('[%5d/%5d] corr:%5d\tcorr_adv:%5d' %
                      (batch_idx, len(valid_loader), corr, corr_adv))

        # train & valid: the ATN provides adversarial examples this epoch
        _ = train(train_loader,
                  net,
                  criterion,
                  log_file,
                  optimizer,
                  epoch_idx,
                  ATN=atn,
                  config=config)
        valid_acc1 = valid(valid_loader,
                           net,
                           criterion,
                           log_file,
                           config=config)

        # save only checkpoints that improve validation accuracy
        if valid_acc1 > best_valid_acc1:
            best_valid_acc1 = valid_acc1
            file_name = output_path + 'vgg16_e%03d_%.2f.pth' % (
                epoch_idx, best_valid_acc1)
            torch.save(net.state_dict(), file_name)
            print('epoch=%003d, acc=%.4f saved.\n' %
                  (epoch_idx, best_valid_acc1))