def get_teacher(p):
    # Get backbone
    if p['backbone'] == 'resnet18':
        if p['train_db_name'] in [
                'cifar-10', 'cifar-10-d', 'cifar-10-f', 'cifar-20',
                'cifar-20-d', 'cifar-20-f'
        ]:
            from models.resnet_cifar import resnet18
            backbone = resnet18()

        elif p['train_db_name'] in ['stl-10', 'stl-10-d', 'stl-10-f']:
            from models.resnet_stl import resnet18
            backbone = resnet18()

        elif p['train_db_name'] == 'pascal-pretrained':
            from models.resnet_pascal import resnet18
            backbone = resnet18()

        else:
            raise NotImplementedError
    elif p['backbone'] == 'resnet50':
        if 'imagenet' in p['train_db_name']:
            from models.resnet_wider import resnet50x1
            backbone = resnet50x1()
        else:
            raise NotImplementedError
    else:
        raise ValueError('Invalid backbone {}'.format(p['backbone']))

    from models.models import ClusteringModel
    if p['teacher'] == 'selflabel':
        assert (p['num_heads'] == 1)
    teacher = ClusteringModel(backbone, p['num_classes'], p['num_heads'])

    return teacher

def get_model(p, pretrain_path=None):
    # Get backbone
    if p['backbone'] == 'resnet18':
        if p['train_db_name'] in ['cifar-10', 'cifar-20']:
            from models.resnet_cifar import resnet18
            backbone = resnet18()

        elif p['train_db_name'] == 'stl-10':
            from models.resnet_stl import resnet18
            backbone = resnet18()

        elif p['train_db_name'] in ['sewer']:
            from models.resnet_cifar import resnet18
            backbone = resnet18()

        else:
            raise NotImplementedError

    elif p['backbone'] == 'resnet50':
        if 'imagenet' in p['train_db_name']:
            from models.resnet import resnet50
            backbone = resnet50()

        else:
            raise NotImplementedError

    else:
        raise ValueError('Invalid backbone {}'.format(p['backbone']))

    # Setup
    if p['setup'] in ['simclr', 'moco']:
        from models.models import ContrastiveModel
        model = ContrastiveModel(backbone, **p['model_kwargs'])

    elif p['setup'] in ['scan', 'selflabel']:
        from models.models import ClusteringModel
        if p['setup'] == 'selflabel':
            assert (p['num_heads'] == 1)
        model = ClusteringModel(backbone, p['num_classes'], p['num_heads'])

    else:
        raise ValueError('Invalid setup {}'.format(p['setup']))

    # Load pretrained weights
    if pretrain_path is not None and os.path.exists(pretrain_path):
        state = torch.load(pretrain_path, map_location='cpu')

        if p['setup'] == 'scan':  # Weights are supposed to be transferred from contrastive training
            missing = model.load_state_dict(state, strict=False)
            assert (set(missing[1]) == {
                'contrastive_head.0.weight', 'contrastive_head.0.bias',
                'contrastive_head.2.weight', 'contrastive_head.2.bias'
            } or set(missing[1])
                    == {'contrastive_head.weight', 'contrastive_head.bias'})

        elif p['setup'] == 'selflabel':  # Weights are supposed to be transferred from scan
            # We only continue with the best head (pop all heads first, then copy back the best head)
            model_state = state['model']
            all_heads = [k for k in model_state.keys() if 'cluster_head' in k]
            best_head_weight = model_state['cluster_head.%d.weight' %
                                           (state['head'])]
            best_head_bias = model_state['cluster_head.%d.bias' %
                                         (state['head'])]
            for k in all_heads:
                model_state.pop(k)

            model_state['cluster_head.0.weight'] = best_head_weight
            model_state['cluster_head.0.bias'] = best_head_bias
            missing = model.load_state_dict(model_state, strict=True)

        else:
            raise NotImplementedError

    elif pretrain_path is not None and not os.path.exists(pretrain_path):
        raise ValueError(
            'Path with pre-trained weights does not exist {}'.format(
                pretrain_path))

    else:
        pass

    return model
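
# Minimal usage sketch for get_model (hypothetical config values; the keys are
# exactly the ones the function reads, and the models.* modules must be importable):
def _example_get_model():
    p = {'backbone': 'resnet18', 'train_db_name': 'cifar-10',
         'setup': 'scan', 'num_classes': 10, 'num_heads': 1}
    return get_model(p, pretrain_path=None)  # pass a checkpoint path to transfer weights
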
# Example 3 (fragment; assumes args, transform_test, and train_datasets were
# defined earlier in the original script)
train_loader = torch.utils.data.DataLoader(dataset=train_datasets,
                                           batch_size=args.batch_size,
                                           shuffle=True)

test_datasets = torchvision.datasets.CIFAR10(root=args.input_path,
                                             transform=transform_test,
                                             download=True,
                                             train=False)

test_loader = torch.utils.data.DataLoader(dataset=test_datasets,
                                          batch_size=args.batch_size,
                                          shuffle=True)

# Define Network
model = resnet_cifar.resnet18(pretrained=False)

# Load pre-trained weights
model_dict = model.state_dict()
pretrained_ae_model = torch.load(args.checkpoint_path)
model_key = []
model_value = []
pretrained_ae_key = []
pretrained_ae_value = []
for k, v in model_dict.items():
    print(k)
    model_key.append(k)
    model_value.append(v)
for k, v in pretrained_ae_model.items():
    print(k)
    pretrained_ae_key.append(k)
    pretrained_ae_value.append(v)
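
# The example is cut off here. A typical continuation (an assumption, mirroring
# the positional-copy pattern used in reTrain/adversarial_learning below) would
# pair model keys with pretrained values by position and load the result:
#
#     new_dict = {k: v for k, v in zip(model_key, pretrained_ae_value)}
#     model_dict.update(new_dict)
#     model.load_state_dict(model_dict)
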
def train(opt):
    # set device to cpu/gpu
    if opt.use_gpu:
        device = torch.device("cuda", opt.gpu_id)
    else:
        device = torch.device("cpu")

    # Data transformations for data augmentation
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.RandomErasing(),
    ])
    transform_val = transforms.Compose([
        transforms.ToTensor(),
    ])

    # get CIFAR10/CIFAR100 train/val set
    if opt.dataset == "CIFAR10":
        alp_lambda = 0.5
        margin = 0.03
        lambda_loss = [2, 0.001]
        train_set = CIFAR10(root="./data",
                            train=True,
                            download=True,
                            transform=transform_train)
        val_set = CIFAR10(root="./data",
                          train=True,
                          download=True,
                          transform=transform_val)
    else:
        alp_lambda = 0.5
        margin = 0.03
        lambda_loss = [2, 0.001]
        train_set = CIFAR100(root="./data",
                             train=True,
                             download=True,
                             transform=transform_train)
        val_set = CIFAR100(root="./data",
                           train=True,
                           download=True,
                           transform=transform_val)
    num_classes = np.unique(train_set.targets).shape[0]

    # set stratified train/val split
    idx = list(range(len(train_set.targets)))
    train_idx, val_idx, _, _ = train_test_split(idx,
                                                train_set.targets,
                                                test_size=opt.val_split,
                                                random_state=42)

    # get train/val samplers
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    # get train/val dataloaders
    train_loader = DataLoader(train_set,
                              sampler=train_sampler,
                              batch_size=opt.batch_size,
                              num_workers=opt.num_workers)
    val_loader = DataLoader(val_set,
                            sampler=val_sampler,
                            batch_size=opt.batch_size,
                            num_workers=opt.num_workers)

    data_loaders = {"train": train_loader, "val": val_loader}

    print(
        "Dataset -- {}, Metric -- {}, Train Mode -- {}, Backbone -- {}".format(
            opt.dataset, opt.metric, opt.train_mode, opt.backbone))
    print("Train iteration batch size: {}".format(opt.batch_size))
    print("Train iterations per epoch: {}".format(len(train_loader)))

    # get backbone model
    if opt.backbone == "resnet18":
        model = resnet18(pretrained=False)
    else:
        model = resnet34(pretrained=False)

    # replace the final fc layer with the metric head (a custom Softmax module)
    model.fc = Softmax(model.fc.in_features, num_classes)

    model.to(device)
    if opt.use_gpu:
        model = DataParallel(model).to(device)

    criterion = CrossEntropyLoss()
    mse_criterion = MSELoss()

    # set optimizer and LR scheduler
    if opt.optimizer == "sgd":
        optimizer = SGD([{
            "params": model.parameters()
        }],
                        lr=opt.lr,
                        weight_decay=opt.weight_decay,
                        momentum=0.9)
    else:
        optimizer = Adam([{
            "params": model.parameters()
        }],
                         lr=opt.lr,
                         weight_decay=opt.weight_decay)
    if opt.scheduler == "decay":
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=opt.lr_step,
                                        gamma=opt.lr_decay)
    else:
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   factor=0.1,
                                                   patience=10)

    # train/val loop
    for epoch in range(opt.epoch):
        for phase in ["train", "val"]:
            total_examples, total_correct, total_loss = 0, 0, 0

            if phase == "train":
                model.train()
            else:
                model.eval()

            start_time = time.time()
            for ii, data in enumerate(data_loaders[phase]):
                # load data batch to device
                images, labels = data
                images = images.to(device)
                labels = labels.to(device).long()

                # perform adversarial attack update to images
                if opt.train_mode == "at" or opt.train_mode == "alp":
                    adv_images = pgd(model, images, labels, 8. / 255, 2. / 255,
                                     7)
                else:
                    pass

                # at train mode
                if opt.train_mode == "at":
                    # get feature embedding from resnet
                    features, _ = model(images, labels)
                    adv_features, adv_predictions = model(adv_images, labels)

                    # get triplet loss (margin 0.03 CIFAR10/100, 0.01 TinyImageNet from paper)
                    tpl_loss, _, mask = batch_all_triplet_loss(
                        labels, features, margin, adv_embeddings=adv_features)
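                    # `mask` is assumed to be the (anchor, positive, negative)
                    # validity mask from batch_all_triplet_loss, with numpy-style
                    # nonzero(): mask.nonzero()[0] are anchor indices,
                    # [1] positive indices, [2] negative indices.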

                    # get adv anchor and clean pos/neg feature norm loss
                    norm = features.mm(features.t()).diag()
                    adv_norm = adv_features.mm(adv_features.t()).diag()
                    norm_loss = adv_norm[mask.nonzero()[0]] + norm[
                        mask.nonzero()[1]] + norm[mask.nonzero()[2]]
                    norm_loss = torch.sum(norm_loss) / \
                        (mask.nonzero()[0].shape[0] + mask.nonzero()
                         [1].shape[0] + mask.nonzero()[2].shape[0])
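                    # Note: features.mm(features.t()).diag() is each row's
                    # squared L2 norm; (features ** 2).sum(dim=1) gives the same
                    # values without the O(batch^2) matrix product.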

                    # get cross-entropy loss (only considering adv anchor examples)
                    adv_anchor_predictions = adv_predictions[np.unique(
                        mask.nonzero()[0])]
                    anchor_labels = labels[np.unique(mask.nonzero()[0])]
                    ce_loss = criterion(adv_anchor_predictions, anchor_labels)

                    # combine cross-entropy loss, triplet loss and feature norm loss using lambda weights
                    loss = ce_loss + lambda_loss[0] * \
                        tpl_loss + lambda_loss[1] * norm_loss
                    optimizer.zero_grad()

                    # for result accumulation
                    predictions = adv_anchor_predictions
                    labels = anchor_labels

                # alp train mode
                elif opt.train_mode == "alp":
                    # get feature embedding from resnet
                    features, predictions = model(images, labels)
                    adv_features, adv_predictions = model(adv_images, labels)

                    # get triplet loss (margin 0.03 CIFAR10/100, 0.01 TinyImageNet from paper)
                    tpl_loss, _, mask = batch_all_triplet_loss(
                        labels, features, margin, adv_embeddings=adv_features)

                    # get adv anchor and clean pos/neg feature norm loss
                    norm = features.mm(features.t()).diag()
                    adv_norm = adv_features.mm(adv_features.t()).diag()
                    norm_loss = adv_norm[mask.nonzero()[0]] + norm[
                        mask.nonzero()[1]] + norm[mask.nonzero()[2]]
                    norm_loss = torch.sum(norm_loss) / \
                        (mask.nonzero()[0].shape[0] + mask.nonzero()
                         [1].shape[0] + mask.nonzero()[2].shape[0])

                    # get cross-entropy loss (only considering adv anchor examples)
                    anchor_predictions = predictions[np.unique(
                        mask.nonzero()[0])]
                    adv_anchor_predictions = adv_predictions[np.unique(
                        mask.nonzero()[0])]
                    anchor_labels = labels[np.unique(mask.nonzero()[0])]
                    ce_loss = criterion(adv_anchor_predictions, anchor_labels)

                    # get ALP (adversarial logit pairing) loss: MSE between adversarial and clean logits
                    alp_loss = mse_criterion(adv_anchor_predictions,
                                             anchor_predictions)

                    # combine cross-entropy loss, triplet loss and feature norm loss using lambda weights
                    loss = ce_loss + lambda_loss[0] * \
                        tpl_loss + lambda_loss[1] * norm_loss
                    # combine loss with alp loss
                    loss = loss + alp_lambda * alp_loss
                    optimizer.zero_grad()

                    # for result accumulation
                    predictions = adv_anchor_predictions
                    labels = anchor_labels

                # clean train mode
                else:
                    # get feature embedding from resnet
                    features, predictions = model(images, labels)

                    # get triplet loss (margin 0.03 CIFAR10/100, 0.01 TinyImageNet from paper)
                    tpl_loss, _, mask = batch_all_triplet_loss(
                        labels, features, margin)

                    # get feature norm loss
                    norm = features.mm(features.t()).diag()
                    norm_loss = norm[mask.nonzero()[0]] + \
                        norm[mask.nonzero()[1]] + norm[mask.nonzero()[2]]
                    norm_loss = torch.sum(norm_loss) / \
                        (mask.nonzero()[0].shape[0] + mask.nonzero()
                         [1].shape[0] + mask.nonzero()[2].shape[0])

                    # get cross-entropy loss (only considering anchor examples)
                    anchor_predictions = predictions[np.unique(
                        mask.nonzero()[0])]
                    anchor_labels = labels[np.unique(mask.nonzero()[0])]
                    ce_loss = criterion(anchor_predictions, anchor_labels)

                    # combine cross-entropy loss, triplet loss and feature norm loss using lambda weights
                    loss = ce_loss + lambda_loss[0] * \
                        tpl_loss + lambda_loss[1] * norm_loss
                    optimizer.zero_grad()

                    # for result accumulation
                    predictions = anchor_predictions
                    labels = anchor_labels

                # only take step if in train phase
                if phase == "train":
                    loss.backward()
                    optimizer.step()

                # accumulate train or val results
                predictions = torch.argmax(predictions, 1)
                total_examples += predictions.size(0)
                total_correct += predictions.eq(labels).sum().item()
                total_loss += loss.item()

                # print accumulated train/val results at end of epoch
                if ii == len(data_loaders[phase]) - 1:
                    end_time = time.time()
                    acc = total_correct / total_examples
                    loss = total_loss / len(data_loaders[phase])
                    print(
                        "{}: Epoch -- {} Loss -- {:.6f} Acc -- {:.6f} Time -- {:.6f}sec"
                        .format(phase, epoch, loss, acc,
                                end_time - start_time))

                    if phase == "train":
                        loss = total_loss / len(data_loaders[phase])
                        scheduler.step(loss)
                    else:
                        print("")

    # save model after training for opt.epoch
    save_model(model, opt.dataset, opt.metric, opt.train_mode, opt.backbone)
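
# `pgd` is imported from elsewhere in this repository; nothing here defines it.
# A minimal sketch (an assumption) of an untargeted L-infinity PGD attack that
# matches the call above (eps=8/255, step=2/255, 7 iterations), for a model that
# returns (features, logits); torch and torch.nn.functional as F are assumed to
# be imported as in the rest of this file:
def pgd_sketch(model, images, labels, eps, step, iters):
    adv = (images + torch.empty_like(images).uniform_(-eps, eps)).clamp(0, 1).detach()
    for _ in range(iters):
        adv.requires_grad_(True)
        _, logits = model(adv, labels)
        grad = torch.autograd.grad(F.cross_entropy(logits, labels), adv)[0]
        with torch.no_grad():
            adv = adv + step * grad.sign()                    # ascend the loss
            adv = images + (adv - images).clamp(-eps, eps)    # project back into the eps-ball
            adv = adv.clamp(0, 1).detach()                    # stay a valid image
    return adv
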
# Example 5 (fragment; the start of the argparse setup was cut off)
                    type=int,
                    metavar='N',
                    help='mini-batch size (default: 128)')
parser.add_argument('--lr',
                    '--learning_rate',
                    default=0.01,
                    type=float,
                    metavar='LR',
                    help='initial learning rate')
parser.add_argument('--resume',
                    action='store_true',
                    default=True,
                    help='resume training')

args = parser.parse_args()
net = resnet18()

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
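# (0.4914, 0.4822, 0.4465) and (0.2023, 0.1994, 0.2010) are the commonly used
# per-channel CIFAR-10 mean/std statistics, so this matches the standard
# CIFAR-10 normalization setup.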
# Example 6
def adversarial_learning(best_cla_model_path, adv_example_path):
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # print(device)

    parser = argparse.ArgumentParser("Image classifical!")
    parser.add_argument("--epochs", type=int, default=200, help="Epoch default:50.")
    parser.add_argument("--image_size", type=int, default=32, help="Image Size default:28.")
    parser.add_argument("--batch_size", type=int, default=200, help="Batch_size default:256.")
    parser.add_argument("--lr", type=float, default=0.01, help="learing_rate. Default=0.01")
    parser.add_argument("--num_classes", type=int, default=10, help="num classes")
    args = parser.parse_args()

    # Load model
    model = resnet_cifar.resnet18(pretrained=False)
    model.to(device)
    # summary(model,(3,32,32))
    # print(model)

    # Load pre-trained weights
    # image_mean = torch.tensor([0.491, 0.482, 0.447]).view(1, 3, 1, 1)
    # image_std = torch.tensor([0.247, 0.243, 0.262]).view(1, 3, 1, 1)
    pretrained_cla_dict = torch.load(best_cla_model_path)
    model_dict = model.state_dict()

    pretrained_cla_weight_key = []
    pretrained_cla_weight_value = []
    model_weight_key = []
    model_weight_value = []

    for k, v in pretrained_cla_dict.items():
        # print(k)
        pretrained_cla_weight_key.append(k)
        pretrained_cla_weight_value.append(v)

    for k, v in model_dict.items():
        # print(k)
        model_weight_key.append(k)
        model_weight_value.append(v)

    new_dict = {}
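    # Positional copy below: the +2 offset is an assumption that the pretrained
    # checkpoint stores two extra leading entries (e.g. input-normalization
    # buffers) before the backbone weights, so model key i lines up with
    # pretrained value i + 2. The copy relies purely on matching key order.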
    for i in range(len(model_dict)):
        new_dict[model_weight_key[i]] = pretrained_cla_weight_value[i + 2]

    model_dict.update(new_dict)
    model.load_state_dict(model_dict)
    # model = NormalizedModel(model=model, mean=image_mean, std=image_std)

    # model.load_state_dict(torch.load(best_cla_model_path))
    model.to(device)

    # criterion
    criterion = nn.CrossEntropyLoss().to(device)

    # batch_shape
    batchShape_adv = [args.batch_size, 3, args.image_size, args.image_size]
    batchShape_clean = [args.batch_size, 3, args.image_size, args.image_size]

    print("Waiting for Testing!")
    with torch.no_grad():
        # evaluate the clean and adversarial test sets
        total = 0
        correct_clean = 0
        correct_adv = 0
        for batchSize, images_clean, images_adv, labels in load_images(adv_example_path,
                                                                       batchShape_clean,
                                                                       batchShape_adv):
            model.eval()
            # print(labels[0:20])
            total += len(images_clean)
            images_clean = torch.from_numpy(images_clean).type(torch.FloatTensor).to(device)
            images_adv = torch.from_numpy(images_adv).type(torch.FloatTensor).to(device)
            labels = torch.from_numpy(labels).type(torch.LongTensor).to(device)
            model.to(device)
            # evaluate on the clean test set
            outputs_clean = model(images_clean)
            _, predicted_clean = torch.max(outputs_clean.data, 1)  # take the highest-scoring class (index into outputs.data)
            correct_clean += (predicted_clean == labels).sum().item()
            # evaluate on the adversarial test set
            outputs_adv = model(images_adv)
            _, predicted_adv = torch.max(outputs_adv.data, 1)  # take the highest-scoring class (index into outputs.data)
            correct_adv += (predicted_adv == labels).sum().item()
        # print(total)
        # print(correct_clean)
        # print(total_train)
        acc_clean = correct_clean / total * 100
        acc_adv = correct_adv / total * 100
        print("Clean Test Set Accuracy:%.2f%%" % acc_clean)
        print("Adv Test Set Accuracy:%.2f%%" % acc_adv)
def test_cifar(best_cla_model_path, best_com_model_path, adv_example_path):
    # Device configuration
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # print(device)

    parser = argparse.ArgumentParser("Image classifical!")
    parser.add_argument("--epochs", type=int, default=200, help="Epoch default:50.")
    parser.add_argument("--image_size", type=int, default=32, help="Image Size default:28.")
    parser.add_argument("--batch_size", type=int, default=200, help="Batch_size default:256.")
    parser.add_argument("--lr", type=float, default=0.01, help="learing_rate. Default=0.01")
    parser.add_argument("--num_classes", type=int, default=10, help="num classes")
    args = parser.parse_args()

    # Load model
    cla_model = resnet_cifar.resnet18(pretrained=False)
    com_model = ComDefend.ComDefend()

    # Load pre-trained weights
    cla_model.load_state_dict(torch.load(best_cla_model_path))
    com_model.load_state_dict(torch.load(best_com_model_path))

    cla_model.to(device)
    com_model.to(device)

    # batch_shape
    batchShape_adv = [args.batch_size, 3, args.image_size, args.image_size]
    batchShape_clean = [args.batch_size, 3, args.image_size, args.image_size]

    print("Waiting for Testing!")
    with torch.no_grad():
        # evaluate the clean and adversarial test sets
        total = 0
        correct_clean = 0
        correct_adv = 0
        for batchSize, images_clean, images_adv, labels in load_images_2(adv_example_path, batchShape_clean, batchShape_adv):
            cla_model.eval()
            com_model.eval()

            total += len(images_clean)
            images_clean = torch.from_numpy(images_clean).type(torch.FloatTensor).to(device)
            images_adv = torch.from_numpy(images_adv).type(torch.FloatTensor).to(device)

            com_images_clean = com_model(images_clean)
            com_images_adv = com_model(images_adv)
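            # ComDefend is a compression-reconstruction defense: it purifies each
            # input, and the classifier below sees the reconstruction rather than
            # the raw (possibly adversarial) image.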

            labels = torch.from_numpy(labels).type(torch.LongTensor).to(device)
            cla_model.to(device)
            com_model.to(device)

            # evaluate on the clean test set
            outputs_clean = cla_model(com_images_clean)
            _, predicted_clean = torch.max(outputs_clean.data, 1)  # take the highest-scoring class (index into outputs.data)
            correct_clean += (predicted_clean == labels).sum().item()

            # evaluate on the adversarial test set
            outputs_adv = cla_model(com_images_adv)
            _, predicted_adv = torch.max(outputs_adv.data, 1)  # take the highest-scoring class (index into outputs.data)
            correct_adv += (predicted_adv == labels).sum().item()

        acc_clean = correct_clean / total * 100
        acc_adv = correct_adv / total * 100
        print("Clean Test Set Accuracy:%.2f%%" % acc_clean)
        print("Adv Test Set Accuracy:%.2f%%" % acc_adv)
def train(opt):
    # set device to cpu/gpu
    if opt.use_gpu:
        device = torch.device("cuda", opt.gpu_id)
    else:
        device = torch.device("cpu")

    # Data transformations for data augmentation
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.RandomErasing(),
    ])
    transform_val = transforms.Compose([
        transforms.ToTensor(),
    ])

    # get CIFAR10/CIFAR100 train/val set
    if opt.dataset == "CIFAR10":
        alp_lambda = 0.5
        train_set = CIFAR10(root="./data",
                            train=True,
                            download=True,
                            transform=transform_train)
        val_set = CIFAR10(root="./data",
                          train=True,
                          download=True,
                          transform=transform_val)
    else:
        alp_lambda = 0.5
        train_set = CIFAR100(root="./data",
                             train=True,
                             download=True,
                             transform=transform_train)
        val_set = CIFAR100(root="./data",
                           train=True,
                           download=True,
                           transform=transform_val)
    num_classes = np.unique(train_set.targets).shape[0]

    # set stratified train/val split
    idx = list(range(len(train_set.targets)))
    train_idx, val_idx, _, _ = train_test_split(idx,
                                                train_set.targets,
                                                test_size=opt.val_split,
                                                random_state=42)

    # get train/val samplers
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    # get train/val dataloaders
    train_loader = DataLoader(train_set,
                              sampler=train_sampler,
                              batch_size=opt.batch_size,
                              num_workers=opt.num_workers)
    val_loader = DataLoader(val_set,
                            sampler=val_sampler,
                            batch_size=opt.batch_size,
                            num_workers=opt.num_workers)

    data_loaders = {"train": train_loader, "val": val_loader}

    print(
        "Dataset -- {}, Metric -- {}, Train Mode -- {}, Backbone -- {}".format(
            opt.dataset, opt.metric, opt.train_mode, opt.backbone))
    print("Train iteration batch size: {}".format(opt.batch_size))
    print("Train iterations per epoch: {}".format(len(train_loader)))

    # get backbone model
    if opt.backbone == "resnet18":
        model = resnet18(pretrained=False)
    else:
        model = resnet34(pretrained=False)

    model.fc = Softmax(model.fc.in_features, num_classes)

    model.to(device)
    if opt.use_gpu:
        model = DataParallel(model).to(device)

    criterion = CrossEntropyLoss()
    mse_criterion = MSELoss()

    # set optimizer and LR scheduler
    if opt.optimizer == "sgd":
        optimizer = SGD([{
            "params": model.parameters()
        }],
                        lr=opt.lr,
                        weight_decay=opt.weight_decay,
                        momentum=0.9)
    else:
        optimizer = Adam([{
            "params": model.parameters()
        }],
                         lr=opt.lr,
                         weight_decay=opt.weight_decay)
    if opt.scheduler == "decay":
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=opt.lr_step,
                                        gamma=opt.lr_decay)
    else:
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   factor=0.1,
                                                   patience=10)

    # train/val loop
    best_acc = 0
    for epoch in range(opt.epoch):
        for phase in ["train", "val"]:
            total_examples, total_correct, total_loss = 0, 0, 0

            if phase == "train":
                model.train()
            else:
                model.eval()

            start_time = time.time()
            for ii, data in enumerate(data_loaders[phase]):
                # load data batch to device
                images, labels = data
                images = images.to(device)
                labels = labels.to(device).long()

                # perform adversarial attack update to images
                if opt.train_mode == "at" or opt.train_mode == "alp":
                    adv_images = pgd(model, images, labels, 8. / 255, 2. / 255,
                                     7)
                else:
                    pass

                # at train mode prediction
                if opt.train_mode == "at":
                    # logits for adversarial images
                    _, adv_predictions = model(adv_images, labels)

                    # get loss
                    loss = criterion(adv_predictions, labels)
                    optimizer.zero_grad()

                    # for result accumulation
                    predictions = adv_predictions

                # alp train mode prediction
                elif opt.train_mode == "alp":
                    # logits for clean and adversarial images
                    _, predictions = model(images, labels)
                    _, adv_predictions = model(adv_images, labels)

                    # get ce loss
                    ce_loss = criterion(adv_predictions, labels)

                    # get ALP (adversarial logit pairing) loss: MSE between adversarial and clean logits
                    alp_loss = mse_criterion(adv_predictions, predictions)

                    # get overall loss
                    loss = ce_loss + alp_lambda * alp_loss
                    optimizer.zero_grad()

                    # for result accumulation
                    predictions = adv_predictions

                # clean train mode prediction
                else:
                    # logits for clean images
                    _, predictions = model(images, labels)

                    # get loss
                    loss = criterion(predictions, labels)
                    optimizer.zero_grad()

                # only take step if in train phase
                if phase == "train":
                    loss.backward()
                    optimizer.step()

                # accumulate train or val results
                predictions = torch.argmax(predictions, 1)
                total_examples += predictions.size(0)
                total_correct += predictions.eq(labels).sum().item()
                total_loss += loss.item()

                # print accumulated train/val results at end of epoch
                if ii == len(data_loaders[phase]) - 1:
                    end_time = time.time()
                    acc = total_correct / total_examples
                    loss = total_loss / len(data_loaders[phase])
                    print(
                        "{}: Epoch -- {} Loss -- {:.6f} Acc -- {:.6f} Time -- {:.6f}sec"
                        .format(phase, epoch, loss, acc,
                                end_time - start_time))

                    if phase == "train":
                        loss = total_loss / len(data_loaders[phase])
                        scheduler.step(loss)
                    else:
                        if acc > best_acc:
                            print("Accuracy improved. Saving model")
                            best_acc = acc
                            save_model(model, opt.dataset, opt.metric,
                                       opt.train_mode, opt.backbone)
                        print("")

    # after training for opt.epoch, optionally save the model (the "bb" variant)
    if opt.test_bb:
        save_model(model, opt.dataset, "bb", "", opt.backbone)
def reTrain(best_cla_model_path, best_ae_epoch, round, block, ae_training_set,
            device_used):
    # Device configuration
    device = torch.device(device_used if torch.cuda.is_available() else "cpu")

    parser = argparse.ArgumentParser("Adversarial Examples")
    # parser.add_argument("--input_path", type=str, default="D:/python_workplace/resnet-AE/inputData/cifar/cifar10/cifar-10-batches-py/",
    #                     help="image dir path default: ../inputData/mnist/.")
    parser.add_argument(
        "--input_path",
        type=str,
        default="C:/Users/WenqingLiu/cifar/cifar10/cifar-10-batches-py/",
        help="data set dir path")
    parser.add_argument(
        "--checkpoint_path_ae",
        type=str,
        default="H:/python_workplace/resnet-AE/checkpoint/" + ae_training_set +
        "/Autoencoder/ResNet18/cifar10/block_" + str(block) + "/round_" +
        str(round) + "/model/model_" + str(best_ae_epoch) + ".pth",
        help="Path to checkpoint for ae network.")
    parser.add_argument(
        "--input_dir_trainSet",
        type=str,
        default="H:/python_workplace/resnet-AE/outputData/FGSM/cifar10/" +
        ae_training_set + "/block_" + str(block) + "/round_" + str(round) +
        "/train/train.pkl",
        help="data set dir path")
    parser.add_argument(
        "--input_dir_testSet",
        type=str,
        default="H:/python_workplace/resnet-AE/outputData/FGSM/cifar10/" +
        ae_training_set + "/block_" + str(block) + "/round_" + str(round) +
        "/test/test.pkl",
        help="data set dir path")
    parser.add_argument("--num_classes",
                        type=int,
                        default=10,
                        help="num classes")
    parser.add_argument("--image_size",
                        type=int,
                        default=32,
                        help="Size of each input images.")
    parser.add_argument("--epochs",
                        type=int,
                        default=50,
                        help="Epoch default:50.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=128,
                        help="Batch_size default:256.")
    parser.add_argument("--lr",
                        type=float,
                        default=0.01,
                        help="learing_rate. Default=0.0001")
    parser.add_argument(
        "--model_path",
        type=str,
        default="H:/python_workplace/resnet-AE/checkpoint/" + ae_training_set +
        "/RetrainClassification/ResNet18/cifar10/block_" + str(block) +
        "/round_" + str(round) + "/model/",
        help="Save model path")
    parser.add_argument(
        "--acc_file_path",
        type=str,
        default="H:/python_workplace/resnet-AE/checkpoint/" + ae_training_set +
        "/RetrainClassification/ResNet18/cifar10/block_" + str(block) +
        "/round_" + str(round) + "/acc.txt",
        help="Save accuracy file")
    parser.add_argument(
        "--best_acc_file_path",
        type=str,
        default="H:/python_workplace/resnet-AE/checkpoint/" + ae_training_set +
        "/RetrainClassification/ResNet18/cifar10/block_" + str(block) +
        "/round_" + str(round) + "/best_acc.txt",
        help="Save best accuracy file")
    parser.add_argument(
        "--log_file_path",
        type=str,
        default="H:/python_workplace/resnet-AE/checkpoint/" + ae_training_set +
        "/RetrainClassification/ResNet18/cifar10/block_" + str(block) +
        "/round_" + str(round) + "/log.txt",
        help="Save log file")

    args = parser.parse_args()

    # prepare the datasets and preprocessing
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # zero-pad the borders, then randomly crop back to 32x32
        transforms.ColorJitter(brightness=1, contrast=2, saturation=3,
                               hue=0),  # add random photometric jitter
        transforms.RandomHorizontalFlip(),  # flip horizontally with probability 0.5
        transforms.ToTensor(),  # convert to a Tensor
        # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # per-channel R,G,B mean/std
    ])

    transform_test = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # zero-pad the borders, then randomly crop back to 32x32
        transforms.ColorJitter(brightness=1, contrast=2, saturation=3,
                               hue=0),  # add random photometric jitter
        transforms.RandomHorizontalFlip(),  # flip horizontally with probability 0.5
        transforms.ToTensor(),
        # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Load data
    train_datasets = torchvision.datasets.CIFAR10(root=args.input_path,
                                                  transform=transform_train,
                                                  download=True,
                                                  train=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_datasets,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    test_datasets = torchvision.datasets.CIFAR10(root=args.input_path,
                                                 transform=transform_test,
                                                 download=True,
                                                 train=False)

    test_loader = torch.utils.data.DataLoader(dataset=test_datasets,
                                              batch_size=args.batch_size,
                                              shuffle=True)

    model = resnet_cifar.resnet18(pretrained=False)
    model.to(device)

    pretrained_ae_dict = torch.load(args.checkpoint_path_ae)
    pretrained_cla_dict = torch.load(best_cla_model_path)
    model_dict = model.state_dict()

    pretrained_cla_weight_key = []
    pretrained_cla_weight_value = []
    pretrained_ae_weight_key = []
    pretrained_ae_weight_value = []
    model_weight_key = []
    model_weight_value = []
    for k, v in pretrained_cla_dict.items():
        # print(k)
        pretrained_cla_weight_key.append(k)
        pretrained_cla_weight_value.append(v)

    for k, v in pretrained_ae_dict.items():
        # print(k)
        pretrained_ae_weight_key.append(k)
        pretrained_ae_weight_value.append(v)

    for k, v in model_dict.items():
        # print(k)
        model_weight_key.append(k)
        model_weight_value.append(v)

    new_dict = {}
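    # Positional split below (assumption): the index bound
    # 6 + 4*6*block + 6*(block - 1) apparently counts the state-dict entries of
    # the stem plus the first `block` ResNet-18 stages, so those entries come
    # from the autoencoder checkpoint and the rest keep the classifier weights.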
    for i in range(len(model_weight_key)):
        if i < (6 + 4 * 6 * block + 6 * (block - 1)):
            new_dict[model_weight_key[i]] = pretrained_ae_weight_value[i]
        else:
            new_dict[model_weight_key[i]] = pretrained_cla_weight_value[i]

    model_dict.update(new_dict)
    model.load_state_dict(model_dict)
    model.to(device)

    batchShape_clean = [args.batch_size, 3, args.image_size, args.image_size]
    batchShape_adv = [args.batch_size, 3, args.image_size, args.image_size]

    print(f"Train numbers:{len(train_datasets)}")
    print(f"Test numbers:{len(test_datasets)}")

    # criterion
    criterion = nn.CrossEntropyLoss().to(device)

    # length_train = len(train_loader)
    best_acc_test_clean = 0
    best_acc_test_adv = 0
    best_acc_train_clean = 0
    best_acc_train_adv = 0
    best_epoch = 1
    flag_test = 0
    flag_train = 0
    print("Start Training Resnet-18 After AutoEncoder!")
    with open(args.acc_file_path, "w") as f1:
        with open(args.log_file_path, "w") as f2:
            for epoch in range(0, args.epochs):
                if epoch + 1 <= 20:
                    args.lr = 0.01
                elif 20 < epoch + 1 <= 40:
                    args.lr = 0.001
                else:
                    args.lr = 0.0001

                # Optimization
                optimizer = optim.SGD(model.parameters(),
                                      lr=args.lr,
                                      momentum=0.9,
                                      weight_decay=5e-4)
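                # Note: rebuilding the optimizer every epoch resets its momentum
                # buffers; keeping a single optimizer plus a MultiStepLR scheduler
                # (milestones [20, 40], gamma 0.1) would be the idiomatic way to
                # express this manual schedule.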

                # evaluate accuracy before each epoch
                print("Waiting for Testing of Test Set!")
                with torch.no_grad():
                    correct_clean_test = 0
                    correct_adv_test = 0
                    total_test = 0
                    for batchSize, images_clean, images_adv, labels in load_images(
                            args.input_dir_testSet, batchShape_clean,
                            batchShape_adv):
                        model.eval()
                        images_clean = torch.from_numpy(images_clean).type(
                            torch.FloatTensor).to(device)
                        images_adv = torch.from_numpy(images_adv).type(
                            torch.FloatTensor).to(device)
                        labels = torch.from_numpy(labels).type(
                            torch.LongTensor).to(device)
                        model.to(device)
                        total_test += batchSize
                        # evaluate on the clean test set
                        outputs_clean = model(images_clean)
                        _, predicted_clean = torch.max(
                            outputs_clean.data,
                            1)  # take the highest-scoring class (index into outputs.data)
                        correct_clean_test += (
                            predicted_clean == labels).sum().item()
                        # evaluate on the adversarial test set
                        outputs_adv = model(images_adv)
                        _, predicted_adv = torch.max(
                            outputs_adv.data,
                            1)  # take the highest-scoring class (index into outputs.data)
                        correct_adv_test += (
                            predicted_adv == labels).sum().item()
                    # print(total_test)
                    acc_clean_test = correct_clean_test / total_test * 100
                    acc_adv_test = correct_adv_test / total_test * 100
                    print("Clean Test Set Accuracy:%.2f%%" % acc_clean_test)
                    print("Adv Test Set Accuracy:%.2f%%" % acc_adv_test)
                    # write the test-set accuracies to acc.txt
                    f1.write(
                        "Epoch=%03d,Clean Test Set Accuracy= %.2f%%,Adv Test Set Accuracy = %.2f%%"
                        % (epoch, acc_clean_test, acc_adv_test))
                    f1.write("\n")
                    f1.flush()
                    # track the best test-set accuracies (best_acc.txt is written and the model saved further below)
                    if acc_clean_test > best_acc_test_clean:
                        best_acc_test_clean = acc_clean_test
                    if acc_adv_test > best_acc_test_adv:
                        flag_test = 1
                        best_acc_test_adv = acc_adv_test

                print("Waiting for Testing of Train Set!")
                with torch.no_grad():
                    correct_clean_train = 0
                    correct_adv_train = 0
                    total_train = 0
                    for batchSize, images_clean, images_adv, labels in load_images(
                            args.input_dir_trainSet, batchShape_clean,
                            batchShape_adv):
                        model.eval()
                        images_clean = torch.from_numpy(images_clean).type(
                            torch.FloatTensor).to(device)
                        images_adv = torch.from_numpy(images_adv).type(
                            torch.FloatTensor).to(device)
                        labels = torch.from_numpy(labels).type(
                            torch.LongTensor).to(device)
                        model.to(device)
                        total_train += batchSize
                        # evaluate on the clean train set
                        outputs_clean = model(images_clean)
                        _, predicted_clean = torch.max(
                            outputs_clean.data,
                            1)  # take the highest-scoring class (index into outputs.data)
                        correct_clean_train += (
                            predicted_clean == labels).sum().item()
                        # evaluate on the adversarial train set
                        outputs_adv = model(images_adv)
                        _, predicted_adv = torch.max(
                            outputs_adv.data,
                            1)  # take the highest-scoring class (index into outputs.data)
                        correct_adv_train += (
                            predicted_adv == labels).sum().item()
                    # print(total_train)
                    acc_clean_train = correct_clean_train / total_train * 100
                    acc_adv_train = correct_adv_train / total_train * 100
                    print("Clean Train Set Accuracy:%.2f%%" % acc_clean_train)
                    print("Adv Train Set Accuracy:%.2f%%" % acc_adv_train)
                    # write the train-set accuracies to acc.txt
                    f1.write(
                        "Epoch=%03d,Clean Train Set Accuracy= %.2f%%,Adv Train Set Accuracy = %.2f%%"
                        % (epoch, acc_clean_train, acc_adv_train))
                    f1.write("\n")
                    f1.flush()
                    # track the best train-set accuracies (used to decide whether to save the model below)
                    if acc_clean_train > best_acc_train_clean:
                        best_acc_train_clean = acc_clean_train
                    if acc_adv_train > best_acc_train_adv:
                        flag_train = 1
                        best_acc_train_adv = acc_adv_train

                if flag_train == 1 and flag_test == 1:
                    if epoch != 0:
                        os.remove(args.model_path + "model_" +
                                  str(best_epoch) + ".pth")
                    f3 = open(args.best_acc_file_path, "w")
                    f3.write(
                        "Epoch=%03d,Clean Test Set Accuracy= %.2f%%,Adv Test Set Accuracy = %.2f%%,"
                        "Clean Train Set Accuracy= %.2f%%,Adv Train Set Accuracy = %.2f%%"
                        % (epoch, acc_clean_test, acc_adv_test,
                           acc_clean_train, acc_adv_train))
                    f3.close()

                    print("Saving model!")
                    torch.save(model.state_dict(),
                               "%s/model_%d.pth" % (args.model_path, epoch))
                    print("Model saved!")
                    best_epoch = epoch

                flag_test = 0
                flag_train = 0

                print("Epoch: %d" % (epoch + 1))
                sum_loss = 0.0
                correct = 0.0
                total = 0.0
                start = time.time()
                batch = 1
                len_batch = len(train_loader)
                for i, data in enumerate(train_loader, 0):
                    # prepare the batch
                    s = time.time()
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)
                    model.to(device)
                    model.train()
                    optimizer.zero_grad()

                    # forward + backward
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    # print loss and accuracy after every training batch
                    sum_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += predicted.eq(labels.data).cpu().sum().item()
                    e = time.time()
                    # print(100.* correct / total)
                    print(
                        "[Epoch:%d/%d] | [Batch:%d/%d] | Loss: %.03f | Acc: %.2f%% | Lr: %.04f | Time: %.03fs"
                        % (epoch + 1, args.epochs, batch, len_batch,
                           sum_loss / batch, correct / total * 100, args.lr,
                           (e - s)))
                    batch += 1
                end = time.time()

                print(
                    "[Epoch:%d/%d] | Loss: %.03f | Test Acc of Clean Test Set: %.2f%% | "
                    "Test Acc of Clean Train Set: %.2f%% | Test Acc of Adv Test Set: %.2f%% | "
                    "Test Acc of Adv Train Set: %.2f%% | Train Acc: %.2f%% | Lr: %.04f | Time: %.03fs"
                    % (epoch + 1, args.epochs, sum_loss /
                       (i + 1), acc_clean_test, acc_clean_train, acc_adv_test,
                       acc_adv_train, correct / total * 100, args.lr,
                       (end - start)))
                f2.write(
                    "[Epoch:%d/%d] | Loss: %.03f | Test Acc of Clean Test Set: %.2f%% | "
                    "Test Acc of Clean Train Set: %.2f%% | Test Acc of Adv Test Set: %.2f%% | "
                    "Test Acc of Adv Train Set: %.2f%% | Train Acc: %.2f%% | Lr: %.04f | Time: %.03fs"
                    % (epoch + 1, args.epochs, sum_loss /
                       (i + 1), acc_clean_test, acc_clean_train, acc_adv_test,
                       acc_adv_train, correct / total * 100, args.lr,
                       (end - start)))
                f2.write("\n")
                f2.flush()
    return best_epoch
def adversarial_learning(best_cla_model_path):
    # Device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # print(device)

    parser = argparse.ArgumentParser("Image classifical!")
    parser.add_argument('--input_dir_trainSet', type=str,
                        default='D:/python_workplace/resnet-AE/checkpoint/Joint_Training/ResNet18/cifar10/train/train.pkl',
                        help='data set dir path')
    parser.add_argument('--input_dir_testSet', type=str,
                        default='D:/python_workplace/resnet-AE/checkpoint/Joint_Training/ResNet18/cifar10/test/test.pkl',
                        help='data set dir path')
    parser.add_argument('--epochs', type=int, default=300, help='Epochs. Default: 300.')
    parser.add_argument('--image_size', type=int, default=32, help='Image size. Default: 32.')
    parser.add_argument('--batch_size', type=int, default=512, help='Batch size. Default: 512.')
    parser.add_argument('--lr', type=float, default=0.01, help='Learning rate. Default: 0.01.')
    parser.add_argument('--num_classes', type=int, default=10, help='Number of classes.')
    parser.add_argument('--model_path', type=str,
                        default='D:/python_workplace/resnet-AE/checkpoint/AdversarialLearning/ResNet18/cifar10/model/',
                        help='Save model path')
    parser.add_argument('--acc_file_path', type=str,
                        default='D:/python_workplace/resnet-AE/checkpoint/AdversarialLearning/ResNet18/cifar10/acc.txt',
                        help='Save accuracy file')
    parser.add_argument('--best_acc_file_path', type=str,
                        default='D:/python_workplace/resnet-AE/checkpoint/'
                                'AdversarialLearning/ResNet18/cifar10/best_acc.txt',
                        help='Save best accuracy file')
    parser.add_argument('--log_file_path', type=str,
                        default='D:/python_workplace/resnet-AE/checkpoint/AdversarialLearning/ResNet18/cifar10/log.txt',
                        help='Save log file')

    args = parser.parse_args()

    # Load model
    model = resnet_cifar.resnet18(pretrained=False)
    model.to(device)
    # summary(model,(3,32,32))
    # print(model)

    # Load pre-trained weights
    model.load_state_dict(torch.load(best_cla_model_path))
    model.to(device)

    # criterion
    criterion = nn.CrossEntropyLoss().to(device)

    # batch_shape
    batch_shape = [args.batch_size, 3, args.image_size, args.image_size]

    best_acc_clean = 0  # best clean test set accuracy
    best_acc_adv = 0  # best adv test set accuracy
    best_epoch = 0  # best epoch
    time_k = time.time()
    print("Start Adversarial Training, Resnet-18!")
    with open(args.acc_file_path, "w") as f1:
        with open(args.log_file_path, "w") as f2:
            for epoch in range(0, args.epochs):
                if epoch + 1 <= 100:
                    args.lr = 0.1
                elif 100 < epoch + 1 <= 200:
                    args.lr = 0.01
                elif 200 < epoch + 1 <= 250:
                    args.lr = 0.001
                else:
                    args.lr = 0.0001

                # Optimization
                optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

                print('Epoch: %d' % (epoch + 1))
                sum_loss = 0.0
                correct = 0.0
                total = 0.0
                batchId = 1
                for batchSize, images_train, labels_train in load_train_set(args.input_dir_trainSet, batch_shape):
                    start = time.time()

                    # data prepare
                    images_train = torch.from_numpy(images_train).type(torch.FloatTensor).to(device)
                    labels_train = torch.from_numpy(labels_train).type(torch.LongTensor).to(device)

                    model.to(device)
                    model.train()
                    optimizer.zero_grad()

                    # forward + backward
                    outputs = model(images_train)
                    loss = criterion(outputs, labels_train)
                    loss.backward()
                    optimizer.step()

                    # print loss and accuracy after every training batch
                    sum_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels_train.size(0)
                    correct += predicted.eq(labels_train.data).cpu().sum().item()
                    # print(100.* correct / total)

                    end = time.time()

                    print('[Epoch:%d/%d] | [Batch:%d/%d] | Loss: %.03f | Acc: %.2f%% | Lr: %.04f | Time: %.03fs'
                          % (epoch + 1, args.epochs, batchId, (100000 / args.batch_size) + 1, sum_loss / batchId,
                             correct / total * 100, args.lr, (end - start)))
                    f2.write('[Epoch:%d/%d] | [Batch:%d/%d] | Loss: %.03f | Acc: %.2f%% | Lr: %.4f | Time: %.3fs'
                          % (epoch + 1, args.epochs, batchId, (100000 / args.batch_size) + 1, sum_loss / batchId,
                             correct / total * 100, args.lr, (end - start)))
                    f2.write('\n')
                    f2.flush()
                    batchId += 1

                # evaluate accuracy every 50 epochs
                if (epoch + 1) % 50 == 0:
                    print("Waiting for Testing!")
                    with torch.no_grad():
                        # evaluate the clean test set
                        correct_clean = 0
                        total_clean = 0
                        for batchSize, images_test_clean, labels_test_clean in load_test_set_clean(args.input_dir_testSet,
                                                                                                   batch_shape):
                            model.eval()

                            # data prepare
                            images_test_clean = torch.from_numpy(images_test_clean).type(torch.FloatTensor).to(device)
                            labels_test_clean = torch.from_numpy(labels_test_clean).type(torch.LongTensor).to(device)

                            model.to(device)

                            outputs = model(images_test_clean)
                            # take the highest-scoring class (index into outputs.data)
                            _, predicted = torch.max(outputs.data, 1)
                            total_clean += labels_test_clean.size(0)
                            correct_clean += (predicted == labels_test_clean).sum().item()
                        print('Clean Test Set Accuracy:%.2f%%' % (correct_clean / total_clean * 100))
                        acc_clean = correct_clean / total_clean * 100

                        # evaluate the adv test set
                        correct_adv = 0
                        total_adv = 0
                        for batchSize, images_test_adv, labels_test_adv in load_test_set_adv(args.input_dir_testSet,
                                                                                             batch_shape):
                            model.eval()

                            # data prepare
                            images_test_adv = torch.from_numpy(images_test_adv).type(torch.FloatTensor).to(device)
                            labels_test_adv = torch.from_numpy(labels_test_adv).type(torch.LongTensor).to(device)

                            model.to(device)

                            outputs = model(images_test_adv)
                            # take the highest-scoring class (index into outputs.data)
                            _, predicted = torch.max(outputs.data, 1)
                            total_adv += labels_test_adv.size(0)
                            correct_adv += (predicted == labels_test_adv).sum().item()
                        print('Adv Test Set Accuracy:%.2f%%' % (correct_adv / total_adv * 100))
                        acc_adv = correct_adv / total_adv * 100

                        # write the test-set accuracies to acc.txt
                        f1.write("Epoch=%03d,Clean Test Set Accuracy= %.2f%%" % (epoch + 1, acc_clean))
                        f1.write('\n')
                        f1.write("Epoch=%03d,Adv Test Set Accuracy= %.2f%%" % (epoch + 1, acc_adv))
                        f1.write('\n')
                        f1.flush()
                        # track the best accuracies, write them to best_acc.txt, and save the qualifying model
                        if acc_clean > best_acc_clean and acc_adv > best_acc_adv:
                            if epoch != 49:
                                os.remove(args.model_path + "model_" + str(best_epoch) + ".pth")
                            best_acc_clean = acc_clean
                            best_acc_adv = acc_adv
                            print('Saving model!')
                            torch.save(model.state_dict(), '%s/model_%d.pth' % (args.model_path, epoch + 1))
                            print('Model saved!')
                            f3 = open(args.best_acc_file_path, "w")
                            f3.write("Epoch=%d,Best Accuracy of Clean Set = %.2f%%,Best Accuracy of Adv Set = %.2f%%"
                                     % (epoch + 1, best_acc_clean, best_acc_adv))
                            f3.close()
                            best_epoch = epoch + 1
            time_j = time.time()
            print("Training Finished, Total Epoch = %d, Best Epoch = %d, Best Accuracy of Clean Set = %.2f%%, "
                  "Best Accuracy of Adv Set = %.2f%%, Total Time = %.2f" % (args.epochs, best_epoch, best_acc_clean,
                                                                            best_acc_adv, (time_j - time_k)/3600))