Exemplo n.º 1
0
def attack_loader(args, net):
    """Build the torchattacks attack selected by ``args.attack``.

    Args:
        args: Namespace-like object. ``args.attack`` selects the attack;
            depending on the choice, ``args.eps``, ``args.steps`` and
            ``args.n_classes`` are also read.
        net: Model the attack is built against.

    Returns:
        A torchattacks attack instance bound to ``net``.

    Raises:
        ValueError: If ``args.attack`` is not one of the supported names, so
            a typo fails here instead of as a ``None`` call later on.
    """
    attack = args.attack

    if attack == "pgd":
        # PGD with a random start. The step size is chosen so that ``steps``
        # steps of size ``alpha`` add up to 2.3 * eps.
        return torchattacks.PGD(model=net,
                                eps=args.eps,
                                alpha=args.eps / args.steps * 2.3,
                                steps=args.steps,
                                random_start=True)

    if attack == "auto":
        return torchattacks.APGD(model=net, eps=args.eps)

    if attack == "fab":
        return torchattacks.FAB(model=net,
                                eps=args.eps,
                                n_classes=args.n_classes)

    if attack == "cw":
        # Fixed CW hyper-parameters; args.eps is not used for this attack.
        return torchattacks.CW(model=net, c=0.1, lr=0.1, steps=200)

    if attack == "fgsm":
        return torchattacks.FGSM(model=net, eps=args.eps)

    if attack == "bim":
        # Fixed per-step size of 1/255 regardless of eps.
        return torchattacks.BIM(model=net, eps=args.eps, alpha=1 / 255)

    if attack == "deepfool":
        return torchattacks.DeepFool(model=net, steps=10)

    if attack == "sparse":
        return torchattacks.SparseFool(model=net)

    if attack == "gn":
        # args.eps is reused as GN's ``sigma``.
        return torchattacks.GN(model=net, sigma=args.eps)

    raise ValueError(f"Unknown attack: {attack!r}")
Exemplo n.º 2
0
    def eval_(self):
        """Measure natural and adversarial accuracy on ``self.testloader``.

        The test-time attacker is selected by ``self.args.test_attacker``
        ("PGD", "PGD_mod", "PGDL2" or "CW"). For any other value the existing
        ``self.test_attacker`` attribute is used unchanged.

        Returns:
            ``(nat_acc, adv_acc)``: fraction of test samples classified
            correctly on the clean and on the attacked inputs.
        """
        self.model.eval()
        adv_correct = 0
        nat_correct = 0
        total = 0

        # Every branch keys off test_attacker (never train_attacker), so the
        # evaluation attack follows the test configuration.
        attacker_name = self.args.test_attacker
        if attacker_name == "PGD":
            self.test_attacker = PGD_attack(
                self.model,
                self.args.epsilon,
                self.args.alpha,
                self.args.attack_steps,
                random_start=self.args.random_start)
        elif attacker_name == "PGD_mod":
            self.test_attacker = GD(self.model,
                                    self.args.epsilon,
                                    self.args.alpha,
                                    self.args.attack_steps,
                                    random_start=self.args.random_start)
        elif attacker_name == "PGDL2":
            self.test_attacker = torchattacks.PGDL2(
                self.model,
                self.args.epsilon,
                self.args.alpha,
                self.args.attack_steps,
                random_start=self.args.random_start)
        elif attacker_name == "CW":
            self.test_attacker = torchattacks.CW(self.model)

        for image, label in self.testloader:
            if torch.cuda.is_available():
                image = image.cuda()
                label = label.cuda()

            # The attack needs input gradients, so it runs with autograd on.
            adv_image = self.test_attacker(image, label)

            # Scoring is pure inference; no autograd graph is needed.
            with torch.no_grad():
                nat_logits = self.model(image)
                adv_logits = self.model(adv_image)

            _, nat_pred = torch.max(nat_logits, dim=1)
            _, adv_pred = torch.max(adv_logits, dim=1)

            nat_correct += (nat_pred == label).sum()
            adv_correct += (adv_pred == label).sum()
            total += label.size(0)

        nat_acc = float(nat_correct) / total
        adv_acc = float(adv_correct) / total

        self._log('Natural Accuracy: {:.3f}'.format(nat_acc))
        self._log('Adv Accuracy: {:.3f}'.format(adv_acc))
        return nat_acc, adv_acc
def get_atk(model, atk_name, eps, steps):
    """Look up ``atk_name`` and return the matching torchattacks attack.

    ``eps`` is the perturbation budget (passed as ``sigma`` for ``'gn'``) and
    ``steps`` the iteration count. Iterative attacks use the step size
    ``eps / (steps * .5)``. An unrecognised name prints a message and exits
    the process with status -1.
    """

    def step_size():
        # Evaluated only by the attacks that need it, so single-step attacks
        # never divide by ``steps``.
        return eps / (steps * .5)

    # Checked in this order with ``==``, exactly like an if/elif chain.
    factories = {
        'fgsm': lambda: torchattacks.FGSM(model, eps=eps),
        'bim': lambda: torchattacks.BIM(model, eps=eps, steps=steps,
                                        alpha=step_size()),
        'deepfool': lambda: torchattacks.DeepFool(model, steps=steps),
        'cw': lambda: torchattacks.CW(model),
        'pgd': lambda: torchattacks.PGD(model, eps=eps, steps=steps,
                                        alpha=step_size()),
        'rfgsm': lambda: torchattacks.RFGSM(model, eps=eps, alpha=eps),
        'auto-attack': lambda: torchattacks.AutoAttack(model, eps=eps),
        'mifgsm': lambda: torchattacks.MIFGSM(model, eps=eps, steps=steps),
        'square': lambda: torchattacks.Square(model, eps=eps),
        'fab': lambda: torchattacks.FAB(model, eps=eps),
        'one-pixel': lambda: torchattacks.OnePixel(model),
        'gn': lambda: torchattacks.GN(model, sigma=eps),
        'apgd': lambda: torchattacks.APGD(model, eps=eps),
        'eotpgd': lambda: torchattacks.EOTPGD(model, eps=eps, steps=steps,
                                              alpha=step_size()),
        'pgddlr': lambda: torchattacks.PGDDLR(model, eps=eps, steps=steps,
                                              alpha=step_size()),
        'ffgsm': lambda: torchattacks.FFGSM(model, eps=eps, alpha=eps),
        'sparsefool': lambda: torchattacks.SparseFool(model),
    }

    for name, factory in factories.items():
        if atk_name == name:
            return factory()

    print("Attack not valid")
    sys.exit(-1)
Exemplo n.º 4
0
def load_attack(model, attack: str):
    """Build a fixed-configuration torchattacks attack by name.

    Args:
        model: Model to attack.
        attack: One of ``'PGD'``, ``'CW'``, ``'BIM'`` or ``'FGSM'``
            (case-sensitive).

    Returns:
        The configured torchattacks attack instance.

    Raises:
        NotImplementedError: If ``attack`` is not one of the names above; the
            message names the rejected value.
    """
    # torchattacks is imported only when an attack is actually requested.
    import torchattacks
    if attack == 'PGD':
        return torchattacks.PGD(model, eps=2 / 255, alpha=2 / 255, steps=7)
    if attack == 'CW':
        # NOTE(review): `targeted` is not a CW constructor argument in every
        # torchattacks release — confirm against the installed version.
        return torchattacks.CW(model,
                               targeted=False,
                               c=1,
                               kappa=0,
                               steps=1000,
                               lr=0.01)
    if attack == 'BIM':
        # steps=0 is passed through as-is. NOTE(review): confirm how the
        # installed BIM interprets steps=0 (some versions derive a step count
        # from eps when steps is 0).
        return torchattacks.BIM(model, eps=4 / 255, alpha=1 / 255, steps=0)
    if attack == 'FGSM':
        return torchattacks.FGSM(model, eps=1 / 255)
    raise NotImplementedError(
        f"Unsupported attack {attack!r}; expected one of 'PGD', 'CW', 'BIM', 'FGSM'")
Exemplo n.º 5
0
def train_epoch(model, loader, optimizer):
    """Run one epoch that crafts adversarial images, saves them, and trains.

    When ``args.use_meta`` is false, each batch goes through the following:

    * Eight torchattacks attacks are built at two strengths each.
    * Each attack is applied with the label ``(target + 1) % 2``
      (presumably binary labels — confirm).
    * Every attacked sample is written with ``imsave`` under
      ``./confirm_attack2img/AE-classification/train/``.
    * The loss is computed on the logits of the last attack applied (CW at
      the second strength).

    When ``args.use_meta`` is true, the model is simply run on
    ``(data, meta)``.

    Relies on module-level names defined elsewhere: ``args``, ``device``,
    ``criterion``, ``tqdm``, ``torchattacks``, ``imsave`` and ``np``.
    Side effect: writes ``atk0``..``atk7``, ``data_atk0``..``data_atk7`` and
    ``im0``..``im7`` into this module's globals.

    Args:
        model: The model being attacked and trained.
        loader: Iterable of ``(data, target, face_name, df_method)`` batches.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        Mean of the per-batch training losses.
    """
    model.train()
    train_loss = []
    bar = tqdm(loader)
    # Gradients are cleared here and after every optimizer step, not at the
    # top of each iteration. Otherwise the gradient-accumulation branch below
    # would only ever apply the last batch's gradient.
    optimizer.zero_grad()
    for i, (data, target, face_name, df_method) in enumerate(bar):

        if args.use_meta:
            data, meta = data
            data, meta, target = data.to(device), meta.to(device), target.to(
                device)
            logits = model(data, meta)
        else:
            # Attack settings. For each key:
            # * The first two entries are the two strengths tried (eps in
            #   units of 1/255, or ``c`` for CW).
            # * For attacks that take a step size, the last entry is alpha,
            #   also in units of 1/255.
            method = {
                '0_PGD': [20, 70, 2],
                '1_APGD': [20, 70, 2],
                '2_FGSM': [2, 8],
                '3_FFGSM': [4, 7, 10],
                '4_MIFGSM': [3, 6],
                '5_RFGSM': [4, 7, 8],
                '6_BIM': [4, 10, 1],
                '7_CW': [1e-4, 2e-4]
            }

            # TODO: planned multi-scale variant (not implemented here):
            #   1. keep the original image img_o
            #   2. img_s = scaling(img_o, scaling_factor=0.5)
            #   out_attack = attack(img_s, target)
            #   img_gen = normalize(scaling(out_attack - img_s, 1 / scaling_factor) + img_o)

            # ``eps`` is an index (0 or 1) into each strength list.
            for eps in range(2):
                globals()['atk{}'.format(0)] = torchattacks.PGD(
                    model,
                    eps=method['0_PGD'][eps] / 255,
                    alpha=method['0_PGD'][-1] / 255,
                    steps=4)
                globals()['atk{}'.format(1)] = torchattacks.APGD(
                    model,
                    eps=method['1_APGD'][eps] / 255,
                    alpha=method['1_APGD'][-1] / 255,
                    steps=4)
                globals()['atk{}'.format(2)] = torchattacks.FGSM(
                    model, eps=method['2_FGSM'][eps] / 255)
                globals()['atk{}'.format(3)] = torchattacks.FFGSM(
                    model,
                    eps=method['3_FFGSM'][eps] / 255,
                    alpha=method['3_FFGSM'][-1] / 255)
                globals()['atk{}'.format(4)] = torchattacks.MIFGSM(
                    model, eps=method['4_MIFGSM'][eps] / 255, steps=4)
                globals()['atk{}'.format(5)] = torchattacks.RFGSM(
                    model,
                    eps=method['5_RFGSM'][eps] / 255,
                    alpha=method['5_RFGSM'][-1] / 255,
                    steps=4)
                globals()['atk{}'.format(6)] = torchattacks.BIM(
                    model,
                    eps=method['6_BIM'][eps] / 255,
                    alpha=method['6_BIM'][-1] / 255)
                globals()['atk{}'.format(7)] = torchattacks.CW(
                    model, c=method['7_CW'][eps], steps=10)

                for count in range(8):
                    globals()['data_atk{}'.format(count)] = globals()[
                        'atk{}'.format(count)](data, (target + 1) % 2)
                    globals()['data_atk{}'.format(count)], target = (
                        globals()['data_atk{}'.format(count)]
                    ).to(device), target.to(device)
                    # Only the logits from the last attack survive the loops
                    # and feed the loss below.
                    logits = model(globals()['data_atk{}'.format(count)])
                    globals()['data_atk{}'.format(count)] = (
                        globals()['data_atk{}'.format(count)]).cpu().numpy()

                method_keys = list(method.keys())

                # ``bat_size`` is the nominal batch size and is used only to
                # build a running sample index for the file name. The loop
                # covers the actual batch, so a smaller final batch does not
                # raise IndexError.
                bat_size = args.batch_size
                for j in range(len(data)):
                    for save_cnt in range(8):
                        globals()['im{}'.format(save_cnt)] = (globals()[
                            'data_atk{}'.format(save_cnt)])[j, :, :, :]
                        imsave(
                            f"./confirm_attack2img/AE-classification/train/"
                            f"{target[j]}_{face_name[j]}_{df_method[j]}_{i * bat_size + j}_{method_keys[save_cnt]}_eps{method[method_keys[save_cnt]][eps]}_wsbs.png",
                            np.transpose(globals()['im{}'.format(save_cnt)],
                                         (1, 2, 0)))

        loss = criterion(logits, target)

        # NOTE(review): with args.use_amp set, no backward pass of ``loss``
        # runs (the AMP path was left commented out in the original), yet the
        # optimizer still steps below — confirm intent.
        if not args.use_amp:
            loss.backward()

        if args.image_size in [896, 576]:
            # For these large input sizes, clip the gradient norm to 0.5
            # (max_grad_norm=0.5).
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

        # Gradient accumulation (for when memory is insufficient): step only
        # every ``args.accumulation_step`` batches. The loss is not rescaled,
        # so each update uses the summed gradients of those batches.
        if args.accumulation_step:
            if (i + 1) % args.accumulation_step == 0:
                optimizer.step()
                optimizer.zero_grad()
        else:
            optimizer.step()
            optimizer.zero_grad()

        loss_np = loss.detach().cpu().numpy()
        train_loss.append(loss_np)
        smooth_loss = sum(train_loss[-100:]) / min(len(train_loss), 100)
        bar.set_description('loss: %.5f, smooth_loss: %.5f' %
                            (loss_np, smooth_loss))

    train_loss = np.mean(train_loss)
    return train_loss
Exemplo n.º 6
0
def train_epoch(model, loader, optimizer):
    """Work-in-progress variant of the attack-and-save training epoch.

    Per the TODO below, the intent is as follows:

    * Attack a half-size copy of an image.
    * Upscale the perturbation and add it back onto the original image.
    * Save the results under ./confirm_attack2img/AE-real_fake/train/.

    As written, several lines look broken; see the NOTE(review) comments.

    Relies on module-level names defined elsewhere: args, device, criterion,
    tqdm, torchattacks, cv2, imsave, np and torch. Writes atk1..atk8,
    data_atk1..data_atk8 and im1..im8 into this module's globals.

    NOTE(review): this reuses the name `train_epoch`; if it shares a module
    with another definition of that name, the later one wins.

    Returns the mean of the per-batch training losses.
    """
    model.train()
    train_loss = []
    bar = tqdm(loader)
    for i, (data, target, face_name) in enumerate(bar):

        # NOTE(review): zeroing gradients at the top of every iteration means
        # the gradient-accumulation branch below never accumulates across
        # batches.
        optimizer.zero_grad()

        if args.use_meta:
            data, meta = data
            data, meta, target = data.to(device), meta.to(device), target.to(device)
            logits = model(data, meta)
        else:

            # Attack settings. For each key:
            # * The first two entries are the two strengths (eps in units of
            #   1/255, or c for CW).
            # * For attacks that take a step size, the last entry is alpha,
            #   also in units of 1/255.
            method = {
                '1_PGD': [20,70,2],
                '2_APGD':[20,70,2],
                '3_FGSM': [2,8],
                '4_FFGSM': [4,7,10],
                '5_MIFGSM': [3,6],
                '6_RFGSM': [4,7,8],
                '7_BIM':[4,10,1],
                '8_CW':[1e-4, 2e-4]}

            # TODO: produce both the original image and the attacked image for the dataset.
            # 1. original data save         img_o
            # 2. small sized data           img_s = scaling(image_o, scaling_factor=0.5)
            # out_attack = attack(small_data, target~~~)
            # img_gen = normalize ( scaling ((out_attack - small_data), 1/scaling_factor) + img_o)


            scaling_factor = 0.5
            # NOTE(review): `i` is the batch counter from enumerate(bar), so this
            # takes sample i of the current batch and raises IndexError once
            # i >= batch size. It probably meant to loop over every sample.
            img_origin = np.transpose(data.cpu().numpy()[i, :, :, :], (1, 2, 0))
            img_small = cv2.resize(img_origin, dsize=(0, 0),
                                    fx=scaling_factor,
                                    fy=scaling_factor)   #,interpolation=cv2.INTER_AREA


            # `eps` is an index (0 or 1) into each strength list.
            for eps in range(2):
                globals()['atk{}'.format(1)] = torchattacks.PGD(model, eps=method['1_PGD'][eps] / 255, alpha=method['1_PGD'][-1] / 255, steps=4)
                globals()['atk{}'.format(2)] = torchattacks.APGD(model, eps=method['2_APGD'][eps] / 255, alpha=method['2_APGD'][-1] / 255, steps=4)
                globals()['atk{}'.format(3)] = torchattacks.FGSM(model, eps=method['3_FGSM'][eps] / 255)
                globals()['atk{}'.format(4)] = torchattacks.FFGSM(model, eps=method['4_FFGSM'][eps] / 255, alpha=method['4_FFGSM'][-1] / 255)
                globals()['atk{}'.format(5)] = torchattacks.MIFGSM(model, eps=method['5_MIFGSM'][eps] / 255, steps=4)
                globals()['atk{}'.format(6)] = torchattacks.RFGSM(model, eps=method['6_RFGSM'][eps] / 255, alpha=method['6_RFGSM'][-1] / 255, steps=4)
                globals()['atk{}'.format(7)] = torchattacks.BIM(model, eps=method['7_BIM'][eps] / 255, alpha=method['7_BIM'][-1] / 255)
                globals()['atk{}'.format(8)] = torchattacks.CW(model, c= method['8_CW'][eps], steps=10)

                for count in range(1,9):
                    # regularization
                    # torch.clamp(images + delta, min=0, max=1).detach()
                    # torch.from_numpy(img_small)
                    # NOTE(review): the attack receives one channel-last image with
                    # no batch axis, unlike the batched `data` the model sees
                    # elsewhere — confirm.
                    out_attack = globals()['atk{}'.format(count)](torch.from_numpy(img_small), (target + 1) % 2)
                    # NOTE(review): cv2.resize is given a tensor-minus-ndarray
                    # difference, and torch.clamp is given the ndarray that
                    # cv2.resize returns (torch.clamp expects a Tensor). This line
                    # is likely to fail at runtime. The clamp also bounds the
                    # upscaled perturbation rather than the final image, as the
                    # `torch.clamp(images + delta, ...)` hint above suggests.
                    img_gen = torch.clamp(cv2.resize(out_attack-img_small,dsize=(0,0),fx=1/scaling_factor, fy=1/scaling_factor), min=0, max=1).detach() + img_origin

                    globals()['data_atk{}'.format(count)] = torch.from_numpy(img_gen)
                    globals()['data_atk{}'.format(count)], target = (globals()['data_atk{}'.format(count)]).to(device), target.to(device)
                    # Only the logits of the last attack survive the loops and feed the loss.
                    logits = model(globals()['data_atk{}'.format(count)])
                    globals()['data_atk{}'.format(count)] = (globals()['data_atk{}'.format(count)]).cpu().numpy()


                method_keys = list(method.keys())

                bat_size = args.batch_size
                # NOTE(review): iterates the nominal batch size, so a smaller final
                # batch would raise IndexError.
                for j in range(bat_size):
                    # save original image
                    # im0 = data.cpu().numpy()[j, :, :, :]
                    # imsave(
                    #     f"./confirm_attack2img/AE-real_fake/0_original/"
                    #     f"{target[j]}_{face_name[j]}_{i * bat_size + j}_0_wsbs.png",
                    #     np.transpose(im0, (1, 2, 0)))

                    # save attacked image
                    # NOTE(review): method_keys has 8 entries (indices 0-7) but
                    # save_cnt runs 1..8. Each image is therefore labelled with the
                    # next attack's key, and save_cnt == 8 raises IndexError.
                    for save_cnt in range(1,9):
                        globals()['im{}'.format(save_cnt)] = (globals()['data_atk{}'.format(save_cnt)])[j, :, :, :]
                        # imsave(
                        #     f"./confirm_attack2img/AE-real_fake/{method_keys[save_cnt]}/"
                        #     f"{target[j]}_{face_name[j]}_{i * bat_size + j}_{method_keys[save_cnt]}_eps{method[method_keys[save_cnt]][eps]}_wsbs.png",
                        #     np.transpose(globals()['im{}'.format(save_cnt)], (1, 2, 0)))
                        imsave(
                            f"./confirm_attack2img/AE-real_fake/train/"
                            f"{target[j]}_{face_name[j]}_{i * bat_size + j}_{method_keys[save_cnt]}_eps{method[method_keys[save_cnt]][eps]}_wsbs.png",
                            np.transpose(globals()['im{}'.format(save_cnt)], (1, 2, 0)))


        loss = criterion(logits, target)

        # NOTE(review): with args.use_amp set, no backward pass runs (the AMP
        # path below is commented out).
        if not args.use_amp:
            loss.backward()
        # else:
        #     with amp.scale_loss(loss, optimizer) as scaled_loss:
        #         scaled_loss.backward()

        if args.image_size in [896, 576]:
            # For these large input sizes, clip the gradient norm to 0.5 (max_grad_norm=0.5).
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

        # Gradient accumulation (for when memory is insufficient).
        if args.accumulation_step:
            if (i + 1) % args.accumulation_step == 0:
                optimizer.step()
                # optimizer.zero_grad()
        else:
            optimizer.step()
            # optimizer.zero_grad()

        loss_np = loss.detach().cpu().numpy()
        train_loss.append(loss_np)
        smooth_loss = sum(train_loss[-100:]) / min(len(train_loss), 100)
        bar.set_description('loss: %.5f, smooth_loss: %.5f' % (loss_np, smooth_loss))

    train_loss = np.mean(train_loss)
    return train_loss
Exemplo n.º 7
0
def main(args):
    """Evaluate a CIFAR-10 classifier under attack and a Tucker-decomposition defense.

    The procedure:

    * Load ``./saved/resnet50_cifar10.pth`` and wrap it with ``blackbox``.
    * Attack every test batch with the attack named by ``args.attack``.
    * Replace each adversarial image by its Tucker reconstruction with ranks
      ``[3, args.tucker_rank, args.tucker_rank]``.
    * Report clean, adversarial and reconstructed-image accuracy, and append
      a summary to ``./result/20200706.txt``.

    Args:
        args: Namespace with the following fields:

            * ``attack``: one of 'FGSM', 'PGD', 'DeepFool', 'CW', 'EOT'.
            * ``eps``: a string parsable by ``fractions.Fraction``.
            * ``cp_rank``, ``tucker_rank``: ints or int-like strings.

    Raises:
        ValueError: If ``args.attack`` is not one of the supported names.
    """
    classifier = torch.load('./saved/resnet50_cifar10.pth')
    model = blackbox(classifier)
    train_set, test_set = datasets.CIFAR10(root='/home/taejoon/data/CIFAR10',
                                           normalize=False)
    batch_size = 64

    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move the model before building the attack. Some torchattacks releases
    # record the model's device when the attack is constructed.
    model = model.to(device)
    model.eval()

    if args.attack == 'FGSM':
        attack = torchattacks.FGSM(model, eps=float(Fraction(args.eps)))
    elif args.attack == 'PGD':
        attack = torchattacks.PGD(model, eps=float(Fraction(args.eps)))
    elif args.attack == 'DeepFool':
        attack = torchattacks.DeepFool(model)
    elif args.attack == 'CW':
        attack = torchattacks.CW(model, kappa=50)
    elif args.attack == 'EOT':
        # NOTE(review): 'EOT' builds APGD, not torchattacks.EOTPGD — confirm.
        attack = torchattacks.APGD(model, eps=float(Fraction(args.eps)))
    else:
        raise ValueError(f"Unsupported attack: {args.attack!r}")

    # Reported in the summary but never accumulated in this function.
    total_loss = 0
    clean_acc = 0
    adv_acc = 0
    tucker_acc = 0
    cp_rank = int(args.cp_rank)
    tucker_rank = [3, int(args.tucker_rank), int(args.tucker_rank)]

    for i, (images, labels) in enumerate(test_loader):

        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            _, prediction = model(images).max(dim=1, keepdim=False)
        clean_acc += (prediction == labels).sum()

        # The attack needs input gradients, so it runs outside no_grad.
        adversarial_images = attack(images, labels)

        with torch.no_grad():
            _, prediction = model(adversarial_images).max(dim=1, keepdim=False)
        adv_acc += (prediction == labels).sum()

        # Defense: replace each adversarial image by its Tucker reconstruction.
        adversarial_images = adversarial_images.detach().cpu().numpy()
        tucker_reconstructions = np.zeros_like(adversarial_images)
        for j, adv in enumerate(adversarial_images):
            core, tucker_factors = tucker(adv,
                                          ranks=tucker_rank,
                                          init='random',
                                          tol=1e-4,
                                          random_state=np.random.RandomState())
            tucker_reconstructions[j] = tl.tucker_to_tensor(
                (core, tucker_factors))

        tucker_reconstructions = torch.from_numpy(tucker_reconstructions).to(
            device, dtype=torch.float)
        with torch.no_grad():
            _, prediction = model(tucker_reconstructions).max(dim=1,
                                                              keepdim=False)
        tucker_acc += (prediction == labels).sum()

        if i % 5 == 0:
            print("# data %f, clean_acc %f, adv_acc %f, tucker_acc %f" %
                  ((i + 1) * batch_size, clean_acc, adv_acc, tucker_acc))

    clean_acc = float(clean_acc) / len(test_set)
    adv_acc = float(adv_acc) / len(test_set)
    tucker_acc = float(tucker_acc) / len(test_set)

    print("loss", total_loss, "adv acc", adv_acc, "tucker acc", tucker_acc)

    data = "Attack Method : " + args.attack + "\n"
    data += "epsilon : " + args.eps + "\n"
    data += "cp_rank : " + str(cp_rank) + ", tucker_rank : " + str(
        tucker_rank) + "\n"
    # CP decomposition is not evaluated, so its accuracy is logged as "None".
    data += "clean : " + str(clean_acc) + " adv_acc : " + str(
        adv_acc) + " cp_acc : " + "None" + " tucker_acc : " + str(
            tucker_acc) + "\n"
    data += "=" * 50
    # Context manager so the result file is always closed.
    with open("./result/20200706.txt", 'a') as f:
        f.write(data + '\n')
Exemplo n.º 8
0
    def train(self):
        """Adversarially train ``self.model`` until ``self.args.epochs``.

        Each epoch does the following:

        * Set the learning rate from a step schedule: 0.1 before epoch 80,
          0.01 for epochs 80-119, and 0.001 from epoch 120.
        * Build the attacker named by ``self.args.train_attacker`` and train
          on adversarial examples only.
        * Log accuracies and save ``checkpoint.pth``.
        * Evaluate with ``self.eval_()``, and save ``best_checkpoint.pth``
          when natural + adversarial test accuracy beats the best so far.
        """
        adv_losses = AverageMeter()
        nat_losses = AverageMeter()

        best_adv_acc = 0
        best_nat_acc = 0

        while self.epoch < self.args.epochs:
            # The step schedule is applied at the start of every epoch, so
            # the decays at epochs 80 and 120 take effect during this run.
            if self.epoch < 80:
                lr = 0.1
            elif self.epoch < 120:
                lr = 0.01
            else:
                lr = 0.001
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            self.model.train()
            if self.args.train_attacker == "PGD":
                self.train_attacker = PGD_attack(
                    self.model,
                    self.args.epsilon,
                    self.args.alpha,
                    self.args.attack_steps,
                    random_start=self.args.random_start)
            elif self.args.train_attacker == "PGD_mod":
                self.train_attacker = GD(self.model,
                                         self.args.epsilon,
                                         self.args.alpha,
                                         self.args.attack_steps,
                                         random_start=self.args.random_start)

            elif self.args.train_attacker == "PGDL2":
                # Build the PGDL2 attack this branch is named for, with the
                # same positional settings as the PGD branches.
                self.train_attacker = torchattacks.PGDL2(
                    self.model,
                    self.args.epsilon,
                    self.args.alpha,
                    self.args.attack_steps,
                    random_start=self.args.random_start)

            elif self.args.train_attacker == "CW":
                self.train_attacker = torchattacks.CW(self.model)

            total = 0
            adv_correct = 0
            nat_correct = 0

            for image, label in self.trainloader:
                if torch.cuda.is_available():
                    image = image.cuda()
                    label = label.cuda()

                self.optimizer.zero_grad()
                x_adv = self.train_attacker(image, label)

                # compute output
                adv_logits = self.model(x_adv)
                # The natural pass is for monitoring only; no graph is kept.
                with torch.no_grad():
                    nat_logits = self.model(image)
                adv_loss = self.criterion(adv_logits, label)
                nat_loss = self.criterion(nat_logits, label)
                # Only the adversarial loss is optimized.
                loss = adv_loss
                loss.backward()
                self.optimizer.step()

                # checking for attack-success acc
                _, adv_pred = torch.max(adv_logits, dim=1)
                adv_correct += (adv_pred == label).sum()
                total += label.size(0)

                # checking for natural-success acc
                _, nat_pred = torch.max(nat_logits, dim=1)
                nat_correct += (nat_pred == label).sum()

                adv_losses.update(adv_loss.data.item(), x_adv.size(0))
                nat_losses.update(nat_loss.data.item(), image.size(0))

            self.epoch += 1

            nat_acc = float(nat_correct) / total
            adv_acc = float(adv_correct) / total
            # ``loss`` here is the last batch's adversarial loss.
            mess = "{}th Epoch, nat Acc: {:.3f}, adv Acc: {:.3f}, Loss: {:.3f}".format(
                self.epoch, nat_acc, adv_acc, loss.item())
            self._log(mess)
            self._save_checkpoint('checkpoint.pth')

            # Evaluation
            nat_acc, adv_acc = self.eval_()

            if nat_acc + adv_acc > best_adv_acc + best_nat_acc:
                best_adv_acc = adv_acc
                best_nat_acc = nat_acc
                self._save_checkpoint('best_checkpoint.pth')
                self._log('Best Test Accuracy: {:.3f}/{:.3f}'.format(
                    best_adv_acc, best_nat_acc))
        self._log('=======Best Test Accuracy: {:.3f}/{:.3f}======'.format(
            best_adv_acc, best_nat_acc))
def ta_cw(x, y, model, c=0.001, kappa=0, steps=1000, lr=0.01):
    """Craft Carlini-Wagner adversarial examples for ``(x, y)``.

    Echoes ``c`` to stdout. It then builds ``torchattacks.CW`` with the given
    ``c``, ``kappa``, ``steps`` and ``lr`` and applies it to ``x``/``y``.

    Returns:
        Whatever the CW attack returns for ``(x, y)``.
    """
    print(c)
    cw_settings = {'c': c, 'kappa': kappa, 'steps': steps, 'lr': lr}
    return torchattacks.CW(model, **cw_settings)(x, y)
Exemplo n.º 10
0
val_loader = create_test_dataset(args.batch_size)

# Configure adversarial training. A training-time IPGD attack is built only
# when `args.at` is the boolean True (identity check, not truthiness);
# otherwise train_attack is None.
if args.at is True:
    train_attack = IPGD(net, eps=8 / 255.0, sigma=2 / 255.0, nb_iter=10, norm=np.inf, device=device,
                          drop_prob=args.drop_prob)
else:
    train_attack = None

# Configure the validation attack. 'fgsm' and 'cw' use torchattacks with only
# the model given (library-default parameters). Any other value falls back to
# a 20-iteration IPGD attack with drop_prob=0, unlike the training attack.
if args.val_method == 'fgsm':
    print('use FGSM attack')
    val_attack = torchattacks.FGSM(net)
elif args.val_method == 'cw':
    print('use cw attack')
    val_attack = torchattacks.CW(net)
else:
    print('use pgd attack')
    val_attack = IPGD(net, eps=8 / 255.0, sigma=2 / 255.0, nb_iter=20, norm=np.inf, device=device, drop_prob=0)

# Create the terminal (console) logger.
logger = create_logger('./log', 'train', 'info')

# Current epoch; replaced by the value load_checkpoint returns when resuming.
now_epoch = 0

# Resume from a checkpoint if args.resume names an existing file.
# if args.auto_continue:
#     args.resume = os.path.join(args.model_dir, 'last.checkpoint')
if args.resume is not None and os.path.isfile(args.resume):
    now_epoch = load_checkpoint(args.resume, net, optimizer, lr_scheduler)