Code Example #1
File: advGAN.py  Project: dssrgu/atk_def_gan
    def __init__(self,
                 device,
                 model,
                 model_num_labels,
                 image_nc,
                 box_min,
                 box_max,
                 eps,
                 pgd_iter,
                 models_path,
                 out_path,
                 model_name,
                 writer,
                 E_lr,
                 defG_lr):
        output_nc = image_nc
        self.device = device
        self.model_num_labels = model_num_labels
        self.model = model
        self.input_nc = image_nc
        self.output_nc = output_nc
        self.box_min = box_min
        self.box_max = box_max
        self.eps = eps
        self.pgd_iter = pgd_iter
        self.models_path = models_path
        self.out_path = out_path
        self.model_name = model_name
        self.writer = writer
        self.E_lr = E_lr
        self.defG_lr = defG_lr

        self.en_input_nc = image_nc
        self.E = models.Encoder(image_nc).to(device)
        self.defG = models.Generator(adv=False).to(device)
        self.pgd = PGD(self.model, self.E, self.defG, self.device, self.eps)

        # initialize all weights
        self.E.apply(weights_init)
        self.defG.apply(weights_init)

        # initialize optimizers
        self.optimizer_E = torch.optim.Adam(self.E.parameters(),
                                            lr=self.E_lr)
        self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                               lr=self.defG_lr)
Code Example #2
File: advGAN.py  Project: dssrgu/atk_def_gan
class AdvGAN_Attack:
    def __init__(self,
                 device,
                 model,
                 model_num_labels,
                 image_nc,
                 box_min,
                 box_max,
                 eps,
                 pgd_iter,
                 models_path,
                 out_path,
                 model_name,
                 writer,
                 E_lr,
                 defG_lr):
        output_nc = image_nc
        self.device = device
        self.model_num_labels = model_num_labels
        self.model = model
        self.input_nc = image_nc
        self.output_nc = output_nc
        self.box_min = box_min
        self.box_max = box_max
        self.eps = eps
        self.pgd_iter = pgd_iter
        self.models_path = models_path
        self.out_path = out_path
        self.model_name = model_name
        self.writer = writer
        self.E_lr = E_lr
        self.defG_lr = defG_lr

        self.en_input_nc = image_nc
        self.E = models.Encoder(image_nc).to(device)
        self.defG = models.Generator(adv=False).to(device)
        self.pgd = PGD(self.model, self.E, self.defG, self.device, self.eps)

        # initialize all weights
        self.E.apply(weights_init)
        self.defG.apply(weights_init)

        # initialize optimizers
        self.optimizer_E = torch.optim.Adam(self.E.parameters(),
                                            lr=self.E_lr)
        self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                               lr=self.defG_lr)

    # generate images for training
    def gen_images(self, x, labels):

        # pgd image
        pgd_images = self.pgd.perturb(x, labels, itr=1)

        # def(pgd) image
        def_pgd_images = self.defG(self.E(pgd_images)) + pgd_images
        def_pgd_images = torch.clamp(def_pgd_images, self.box_min, self.box_max)

        # make def(nat) image
        def_images = self.defG(self.E(x)) + x
        def_images = torch.clamp(def_images, self.box_min, self.box_max)

        return pgd_images, def_pgd_images, def_images

    # performance tester
    def test(self):

        self.E.eval()
        self.defG.eval()

        test_full(self.device, self.model, self.E, self.defG, self.eps,
                  self.out_path, self.model_name, label_count=True, save_img=True)

        self.E.train()
        self.defG.train()

    # train single batch
    def train_batch(self, x, labels):

        # optimize E, def
        for i in range(1):

            # clear grad
            self.optimizer_E.zero_grad()
            self.optimizer_defG.zero_grad()

            pgd_images, def_pgd_images, def_images = self.gen_images(x, labels)

            # def(pgd) loss
            logits_def_pgd = self.model(def_pgd_images)
            loss_def_pgd = F.cross_entropy(logits_def_pgd, labels)

            # def(nat) loss
            logits_def = self.model(def_images)
            loss_def = F.cross_entropy(logits_def, labels)

            # backprop
            loss = loss_def + loss_def_pgd

            loss.backward()

            self.optimizer_E.step()
            self.optimizer_defG.step()

        # pgd performance check

        self.E.eval()
        self.defG.eval()

        pgd_acc_li = []
        pgd_nat_acc_li = []

        pgd_nat_images = self.pgd.perturb(x, labels, itr=0)

        # obfuscated-gradient check
        pgd_nat_images = self.defG(self.E(pgd_nat_images)) + pgd_nat_images
        pgd_nat_images = torch.clamp(pgd_nat_images, self.box_min, self.box_max)

        pred = torch.argmax(self.model(pgd_images), 1)
        num_correct = torch.sum(pred == labels, 0)
        pgd_acc = num_correct.item()/len(labels)

        pgd_acc_li.append(pgd_acc)

        pred = torch.argmax(self.model(pgd_nat_images), 1)
        num_correct = torch.sum(pred == labels, 0)
        pgd_nat_acc = num_correct.item()/len(labels)

        pgd_nat_acc_li.append(pgd_nat_acc)

        self.E.train()
        self.defG.train()

        return pgd_acc_li, pgd_nat_acc_li, torch.sum(loss).item()

    # main training function
    def train(self, train_dataloader, epochs):
        for epoch in range(1, epochs+1):

            if epoch == 50:
                self.optimizer_E = torch.optim.Adam(self.E.parameters(),
                                                    lr=self.E_lr/10)
                self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                                       lr=self.defG_lr/10)
            if epoch == 80:
                self.optimizer_E = torch.optim.Adam(self.E.parameters(),
                                                    lr=self.E_lr/100)
                self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                                       lr=self.defG_lr/100)

            loss_sum = 0
            pgd_acc_li_sum = []
            pgd_nat_acc_li_sum = []

            for i, data in enumerate(train_dataloader, start=0):
                images, labels = data
                images, labels = images.to(self.device), labels.to(self.device)

                pgd_acc_li_batch, pgd_nat_acc_li_batch, loss_batch = \
                    self.train_batch(images, labels)
                loss_sum += loss_batch
                pgd_acc_li_sum.append(pgd_acc_li_batch)
                pgd_nat_acc_li_sum.append(pgd_nat_acc_li_batch)

            # print statistics
            num_batch = len(train_dataloader)
            print("epoch %d:\nloss_E: %.5f" %
                  (epoch, loss_sum/num_batch))

            pgd_acc_li_sum = np.mean(np.array(pgd_acc_li_sum), axis=0)
            for idx in range(len(self.pgd_iter)):
                print("pgd iter %d acc.: %.5f" % (self.pgd_iter[idx], pgd_acc_li_sum[idx]))
            
            pgd_nat_acc_li_sum = np.mean(np.array(pgd_nat_acc_li_sum), axis=0)
            for idx in range(len(self.pgd_iter)):
                print("pgd nat iter %d acc.: %.5f" % (self.pgd_iter[idx], pgd_nat_acc_li_sum[idx]))
            print()

            # write to tensorboard
            if self.writer:
                self.writer.add_scalar('loss', loss_sum/num_batch, epoch)
                for idx in range(len(self.pgd_iter)):
                    self.writer.add_scalar('pgd_acc_%d' % (self.pgd_iter[idx]), pgd_acc_li_sum[idx], epoch)
                    self.writer.add_scalar('pgd_nat_acc_%d' % (self.pgd_iter[idx]), pgd_nat_acc_li_sum[idx], epoch)

            # save generator
            if epoch%20==0:
                E_file_name = self.models_path + self.model_name + 'E_epoch_' + str(epoch) + '.pth'
                defG_file_name = self.models_path + self.model_name + 'defG_epoch_' + str(epoch) + '.pth'
                torch.save(self.E.state_dict(), E_file_name)
                torch.save(self.defG.state_dict(), defG_file_name)

        if self.writer:
            self.writer.close()

        # test performance
        self.test()
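
Both `__init__` snippets above apply a `weights_init` function that the listing does not show. Below is a minimal sketch of the usual DCGAN-style initializer that such repositories typically use; the project's actual function may differ.

import torch.nn as nn

def weights_init(m):
    """DCGAN-style init sketch: N(0, 0.02) for conv weights,
    N(1, 0.02) for batch-norm scales, zeros for batch-norm biases."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)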
Code Example #3
def tester(dataset,
           dataloader,
           device,
           target_model,
           E,
           defG,
           advG,
           eps,
           out_path,
           model_name,
           label_count=True,
           save_img=False):

    # load PGD
    pgd = PGD(target_model, E, defG, device, eps)

    num_correct_adv = 0
    num_correct_pgd = 0
    num_correct_def_adv = 0
    num_correct_def = 0
    num_correct_def_pgd = 0
    num_correct_def_pgd_nat = 0
    num_correct = 0

    test_img_full = []
    adv_img_full = []
    pgd_img_full = []
    pgd_nat_img_full = []
    def_img_full = []
    def_adv_img_full = []
    def_pgd_img_full = []
    def_pgd_nat_img_full = []

    pred_adv_full = []
    pred_pgd_full = []
    pred_def_pgd_full = []
    pred_def_pgd_nat_full = []

    for i, data in enumerate(dataloader, 0):

        # load images
        test_img, test_label = data
        test_img, test_label = test_img.to(device), test_label.to(device)

        target_labels = torch.randint_like(test_label, 0, 10)
        target_one_hot = torch.eye(10, device=device)[target_labels]
        target_one_hot = target_one_hot.view(-1, 10, 1, 1)

        # prep images
        x_encoded = E(test_img)

        adv_noise = advG(x_encoded, target_one_hot)
        adv_img = adv_noise * eps + test_img
        adv_img = torch.clamp(adv_img, 0, 1)

        def_adv_noise = defG(E(adv_img))
        def_adv_img = def_adv_noise + adv_img
        def_adv_img = torch.clamp(def_adv_img, 0, 1)

        def_noise = defG(E(test_img))
        def_img = def_noise + test_img
        def_img = torch.clamp(def_img, 0, 1)

        pgd_img = pgd.perturb(test_img, test_label, itr=1)

        def_pgd_noise = defG(E(pgd_img))
        def_pgd_img = def_pgd_noise + pgd_img
        def_pgd_img = torch.clamp(def_pgd_img, 0, 1)

        pgd_nat_img = pgd.perturb(test_img, test_label, itr=0)

        def_pgd_nat_noise = defG(E(pgd_nat_img))
        def_pgd_nat_img = def_pgd_nat_noise + pgd_nat_img
        def_pgd_nat_img = torch.clamp(def_pgd_nat_img, 0, 1)

        # calculate acc.
        pred = torch.argmax(normalized_eval(test_img, target_model), 1)
        pred_adv = torch.argmax(normalized_eval(adv_img, target_model), 1)
        pred_pgd = torch.argmax(normalized_eval(pgd_nat_img, target_model), 1)
        pred_def_adv = torch.argmax(normalized_eval(def_adv_img, target_model),
                                    1)
        pred_def = torch.argmax(normalized_eval(def_img, target_model), 1)
        pred_def_pgd = torch.argmax(normalized_eval(def_pgd_img, target_model),
                                    1)
        pred_def_pgd_nat = torch.argmax(
            normalized_eval(def_pgd_nat_img, target_model), 1)

        num_correct += torch.sum(pred == test_label, 0)
        num_correct_adv += torch.sum(pred_adv == test_label, 0)
        num_correct_pgd += torch.sum(pred_pgd == test_label, 0)
        num_correct_def_adv += torch.sum(pred_def_adv == test_label, 0)
        num_correct_def += torch.sum(pred_def == test_label, 0)
        num_correct_def_pgd += torch.sum(pred_def_pgd == test_label, 0)
        num_correct_def_pgd_nat += torch.sum(pred_def_pgd_nat == test_label, 0)
        '''
        l_one = np.mean(np.abs(adv_noise.cpu().detach().numpy()))
        np.save('./out/noise/adv_noise.npy', adv_noise.cpu().detach().numpy())
        print('l-one of adv noise:%f' % (l_one))
        l_one = np.mean(np.abs(def_noise.cpu().detach().numpy()))
        np.save('./out/noise/def_noise.npy', def_noise.cpu().detach().numpy())
        print('l-one of def noise:%f' % (l_one))
        l_one = np.mean(np.abs(def_adv_noise.cpu().detach().numpy()))
        np.save('./out/noise/def_adv_noise.npy', def_adv_noise.cpu().detach().numpy())
        print('l-one of def(adv) noise:%f' % (l_one))
        l_one = np.mean(np.abs(def_pgd_noise.cpu().detach().numpy()))
        np.save('./out/noise/def_pgd_noise.npy', def_pgd_noise.cpu().detach().numpy())
        print('l-one of def(pgd) noise:%f' % (l_one))
        l_one = np.mean(np.abs(def_pgd_nat_noise.cpu().detach().numpy()))
        np.save('./out/noise/def_pgd_nat_noise.npy', def_pgd_nat_noise.cpu().detach().numpy())
        print('l-one of def(pgd_nat) noise:%f' % (l_one))
        exit()
        '''

        if label_count:
            pred_adv_full.append(pred_adv)
            pred_pgd_full.append(pred_pgd)
            pred_def_pgd_full.append(pred_def_pgd)
            pred_def_pgd_nat_full.append(pred_def_pgd_nat)

        if save_img and i < 1:
            test_img_full.append(test_img)
            adv_img_full.append(adv_img)
            pgd_img_full.append(pgd_img)
            pgd_nat_img_full.append(pgd_nat_img)
            def_img_full.append(def_img)
            def_adv_img_full.append(def_adv_img)
            def_pgd_img_full.append(def_pgd_img)
            def_pgd_nat_img_full.append(def_pgd_nat_img)

    print('num_correct(nat): ', num_correct.item())
    print('num_correct(adv): ', num_correct_adv.item())
    print('num_correct(pgd): ', num_correct_pgd.item())
    print('num_correct(def(adv)): ', num_correct_def_adv.item())
    print('num_correct(def(nat)): ', num_correct_def.item())
    print('num_correct(def(pgd)): ', num_correct_def_pgd.item())
    print('num_correct(def(pgd_nat)): ', num_correct_def_pgd_nat.item())
    print()

    print('accuracy of nat imgs: %f' % (num_correct.item() / len(dataset)))
    print('accuracy of adv imgs: %f' % (num_correct_adv.item() / len(dataset)))
    print('accuracy of pgd imgs: %f' % (num_correct_pgd.item() / len(dataset)))
    print('accuracy of def(adv) imgs: %f' %
          (num_correct_def_adv.item() / len(dataset)))
    print('accuracy of def(nat) imgs: %f' %
          (num_correct_def.item() / len(dataset)))
    print('accuracy of def(pgd) imgs: %f' %
          (num_correct_def_pgd.item() / len(dataset)))
    print('accuracy of def(pgd_nat) imgs: %f' %
          (num_correct_def_pgd_nat.item() / len(dataset)))
    print()

    l_inf = np.amax(
        np.abs(adv_img.cpu().detach().numpy() -
               test_img.cpu().detach().numpy()))
    print('l-inf of adv imgs:%f' % (l_inf))
    l_inf = np.amax(
        np.abs(def_img.cpu().detach().numpy() -
               test_img.cpu().detach().numpy()))
    print('l-inf of def imgs:%f' % (l_inf))
    l_inf = np.amax(
        np.abs(def_adv_img.cpu().detach().numpy() -
               test_img.cpu().detach().numpy()))
    print('l-inf of def(adv) imgs:%f' % (l_inf))
    l_inf = np.amax(
        np.abs(def_pgd_img.cpu().detach().numpy() -
               test_img.cpu().detach().numpy()))
    print('l-inf of def(pgd) imgs:%f' % (l_inf))
    l_inf = np.amax(
        np.abs(def_pgd_nat_img.cpu().detach().numpy() -
               test_img.cpu().detach().numpy()))
    print('l-inf of def(pgd_nat) imgs:%f' % (l_inf))

    print()

    l_one = np.mean(np.abs(adv_noise.cpu().detach().numpy()))
    np.save('./out/noise/adv_noise.npy', adv_noise.cpu().detach().numpy())
    print('l-one of adv noise:%f' % (l_one))
    l_one = np.mean(np.abs(def_noise.cpu().detach().numpy()))
    np.save('./out/noise/def_noise.npy', def_noise.cpu().detach().numpy())
    print('l-one of def noise:%f' % (l_one))
    l_one = np.mean(np.abs(def_adv_noise.cpu().detach().numpy()))
    np.save('./out/noise/def_adv_noise.npy',
            def_adv_noise.cpu().detach().numpy())
    print('l-one of def(adv) noise:%f' % (l_one))
    l_one = np.mean(np.abs(def_pgd_noise.cpu().detach().numpy()))
    np.save('./out/noise/def_pgd_noise.npy',
            def_pgd_noise.cpu().detach().numpy())
    print('l-one of def(pgd) noise:%f' % (l_one))
    l_one = np.mean(np.abs(def_pgd_nat_noise.cpu().detach().numpy()))
    np.save('./out/noise/def_pgd_nat_noise.npy',
            def_pgd_nat_noise.cpu().detach().numpy())
    print('l-one of def(pgd_nat) noise:%f' % (l_one))

    print()

    if label_count:
        pred_adv_full = torch.cat(pred_adv_full)
        preds = pred_adv_full.cpu().detach().numpy()
        print('label counts in adv imgs:')
        print(np.unique(preds, return_counts=True))

        pred_pgd_full = torch.cat(pred_pgd_full)
        preds = pred_pgd_full.cpu().detach().numpy()
        print('label counts in pgd imgs:')
        print(np.unique(preds, return_counts=True))

        pred_def_pgd_full = torch.cat(pred_def_pgd_full)
        preds = pred_def_pgd_full.cpu().detach().numpy()
        print('label counts in def_pgd imgs:')
        print(np.unique(preds, return_counts=True))

        pred_def_pgd_nat_full = torch.cat(pred_def_pgd_nat_full)
        preds = pred_def_pgd_nat_full.cpu().detach().numpy()
        print('label counts in def_pgd_nat imgs:')
        print(np.unique(preds, return_counts=True))

        print()

    if save_img:
        test_img_full = torch.cat(test_img_full)
        adv_img_full = torch.cat(adv_img_full)
        pgd_img_full = torch.cat(pgd_img_full)
        pgd_nat_img_full = torch.cat(pgd_nat_img_full)
        def_img_full = torch.cat(def_img_full)
        def_adv_img_full = torch.cat(def_adv_img_full)
        def_pgd_img_full = torch.cat(def_pgd_img_full)
        def_pgd_nat_img_full = torch.cat(def_pgd_nat_img_full)

        test_grid = make_grid(test_img_full)
        adv_grid = make_grid(adv_img_full)
        pgd_grid = make_grid(pgd_img_full)
        pgd_nat_grid = make_grid(pgd_nat_img_full)
        def_grid = make_grid(def_img_full)
        def_adv_grid = make_grid(def_adv_img_full)
        def_pgd_grid = make_grid(def_pgd_img_full)
        def_pgd_nat_grid = make_grid(def_pgd_nat_img_full)

        if not os.path.exists(out_path + model_name):
            os.makedirs(out_path + model_name)

        save_image(test_grid, out_path + model_name + 'test_grid.png')
        save_image(adv_grid, out_path + model_name + 'adv_grid.png')
        save_image(pgd_grid, out_path + model_name + 'pgd_grid.png')
        save_image(pgd_nat_grid, out_path + model_name + 'pgd_nat_grid.png')
        save_image(def_grid, out_path + model_name + 'def_grid.png')
        save_image(def_adv_grid, out_path + model_name + 'def_adv_grid.png')
        save_image(def_pgd_grid, out_path + model_name + 'def_pgd_grid.png')
        save_image(def_pgd_nat_grid,
                   out_path + model_name + 'def_pgd_nat_grid.png')

        print('images saved')
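
This tester (and Code Example #5) wraps every forward pass in `normalized_eval`, which is not shown in the listing. Presumably it shifts the [0, 1] images to the statistics the target model was trained with before evaluating. A hedged sketch follows; the mean/std constants are placeholders and the repository's implementation may differ.

import torch

# Placeholder channel statistics (CIFAR-10-like); substitute the values the
# target model was actually trained with.
_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(1, -1, 1, 1)
_STD = torch.tensor([0.2471, 0.2435, 0.2616]).view(1, -1, 1, 1)

def normalized_eval(x, model):
    """Normalize [0, 1] images channel-wise, then run the target model."""
    x_norm = (x - _MEAN.to(x.device)) / _STD.to(x.device)
    return model(x_norm)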
Code Example #4
        print(arg, getattr(args, arg))

    # Define what device we are using
    print("CUDA Available: ", torch.cuda.is_available())
    device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

    mnist_dataset = torchvision.datasets.MNIST('./dataset', train=True, transform=transforms.ToTensor(), download=True)
    train_dataloader = DataLoader(mnist_dataset, batch_size=args.batch_size, shuffle=True, num_workers=1)

    # training the target model
    target_model = MNIST_target_net().to(device)
    target_model.train()
    opt_model = torch.optim.Adam(target_model.parameters(), lr=0.0001)
    epochs = args.epochs

    pgd = PGD(target_model, None, None, device, args.epsilon, 7, args.epsilon/4)
    for epoch in range(epochs):
        loss_epoch = 0
        if epoch == 20:
            opt_model = torch.optim.Adam(target_model.parameters(), lr=0.00001)

        num_corrects = 0
        adv_num_corrects = 0
        total = 0
        for i, data in enumerate(train_dataloader, 0):
            train_imgs, train_labels = data
            train_imgs, train_labels = train_imgs.to(device), train_labels.to(device)
            #print(torch.max(train_imgs))
            logits_model = target_model(train_imgs)
            adv_imgs = pgd.perturb(train_imgs, train_labels, itr=0)
            #print(torch.max(adv_imgs-train_imgs))
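
The snippet above is cut off inside the batch loop. The following is a hedged guess at how such a PGD adversarial-training step typically continues (joint cross-entropy on clean and adversarial logits, optimizer step, accuracy bookkeeping), assuming `torch.nn.functional as F` is imported elsewhere in the file; it is not the repository's verbatim code.

            # hedged continuation sketch, not the original source
            adv_logits = target_model(adv_imgs)

            # train on clean and PGD examples jointly
            loss_model = F.cross_entropy(logits_model, train_labels) \
                + F.cross_entropy(adv_logits, train_labels)
            opt_model.zero_grad()
            loss_model.backward()
            opt_model.step()

            loss_epoch += loss_model.item()
            num_corrects += (logits_model.argmax(1) == train_labels).sum().item()
            adv_num_corrects += (adv_logits.argmax(1) == train_labels).sum().item()
            total += train_labels.size(0)

        print('epoch %d: loss %.4f, clean acc %.4f, adv acc %.4f' %
              (epoch, loss_epoch / len(train_dataloader),
               num_corrects / total, adv_num_corrects / total))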
Code Example #5
class AdvGAN_Attack:
    def __init__(self, device, model, model_num_labels, image_nc, box_min,
                 box_max, eps, pgd_iter, models_path, out_path, model_name,
                 writer, E_lr, advG_lr, defG_lr):
        output_nc = image_nc
        self.device = device
        self.model_num_labels = model_num_labels
        self.model = model
        self.input_nc = image_nc
        self.output_nc = output_nc
        self.box_min = box_min
        self.box_max = box_max
        self.eps = eps
        self.pgd_iter = pgd_iter
        self.models_path = models_path
        self.out_path = out_path
        self.model_name = model_name
        self.writer = writer
        self.E_lr = E_lr
        self.advG_lr = advG_lr
        self.defG_lr = defG_lr

        self.en_input_nc = image_nc
        self.E = models.Encoder(image_nc).to(device)
        self.defG = models.Generator(adv=False).to(device)
        self.advG = models.Generator(y_dim=model_num_labels,
                                     adv=True).to(device)
        self.pgd = PGD(self.model,
                       self.E,
                       self.defG,
                       self.device,
                       self.eps,
                       step_size=self.eps / 4)

        # initialize all weights
        self.E.apply(weights_init)
        self.defG.apply(weights_init)
        self.advG.apply(weights_init)

        # initialize optimizers
        self.optimizer_E = torch.optim.Adam(self.E.parameters(), lr=self.E_lr)
        self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                               lr=self.defG_lr)
        self.optimizer_advG = torch.optim.Adam(self.advG.parameters(),
                                               lr=self.advG_lr)

    # generate images for training
    def gen_images(self, x, labels, adv=True, def_adv=True, def_nat=True):

        results = []

        # random target labels
        target_labels = torch.randint_like(labels, 0, self.model_num_labels)
        target_one_hot = torch.eye(self.model_num_labels,
                                   device=self.device)[target_labels]
        target_one_hot = target_one_hot.view(-1, self.model_num_labels, 1, 1)

        x_encoded = self.E(x)

        if adv or def_adv:
            # make adv image
            adv_images = self.advG(x_encoded, target_one_hot) * self.eps + x
            adv_images = torch.clamp(adv_images, self.box_min, self.box_max)

            results.append(adv_images)

        if def_adv:
            # make def(adv) image
            def_adv_images = self.defG(self.E(adv_images)) + adv_images
            def_adv_images = torch.clamp(def_adv_images, self.box_min,
                                         self.box_max)

            results.append(def_adv_images)

        if def_nat:
            # make def(nat) image
            def_images = self.defG(x_encoded) + x
            def_images = torch.clamp(def_images, self.box_min, self.box_max)

            results.append(def_images)

        results.append(target_labels)

        return results

    # performance tester
    def test(self):

        self.E.eval()
        self.defG.eval()

        test_full(self.device,
                  self.model,
                  self.E,
                  self.defG,
                  self.advG,
                  self.eps,
                  self.out_path,
                  self.model_name,
                  label_count=True,
                  save_img=True)

        self.E.train()
        self.defG.train()

    # train single batch
    def train_batch(self, x, labels, batch_num):

        # optimize E
        for i in range(1):
            # clear grad
            self.optimizer_E.zero_grad()

            adv_images, def_adv_images, def_images, target_labels = self.gen_images(
                x, labels)

            # adv loss
            logits_adv = normalized_eval(adv_images, self.model)
            loss_adv = F.cross_entropy(logits_adv, target_labels)

            # def(adv) loss
            logits_def_adv = normalized_eval(def_adv_images, self.model)
            loss_def_adv = F.cross_entropy(logits_def_adv, labels)

            # def(nat) loss
            logits_def = normalized_eval(def_images, self.model)
            loss_def = F.cross_entropy(logits_def, labels)

            # backprop
            loss_E = loss_adv + loss_def_adv + loss_def

            loss_E.backward()

            self.optimizer_E.step()

        # optimize G
        for i in range(1):
            # clear grad
            self.optimizer_advG.zero_grad()
            adv_images, def_adv_images, target_labels = self.gen_images(
                x, labels, def_nat=False)

            # adv loss
            logits_adv = normalized_eval(adv_images, self.model)
            loss_adv = F.cross_entropy(logits_adv, target_labels)

            # def(adv) loss
            logits_def_adv = normalized_eval(def_adv_images, self.model)
            loss_def_adv = F.cross_entropy(logits_def_adv, target_labels)

            # backprop
            loss_advG = loss_adv + loss_def_adv

            loss_advG.backward()

            self.optimizer_advG.step()

        # optimize defG
        for i in range(1):
            # clear grad
            self.optimizer_defG.zero_grad()

            _, def_adv_images, def_images, _ = self.gen_images(x,
                                                               labels,
                                                               adv=False)

            # def(adv) loss
            logits_def_adv = normalized_eval(def_adv_images, self.model)
            loss_def_adv = F.cross_entropy(logits_def_adv, labels)

            # def loss
            logits_def = normalized_eval(def_images, self.model)
            loss_def = F.cross_entropy(logits_def, labels)

            # backprop
            loss_defG = loss_def_adv + loss_def

            loss_defG.backward()

            self.optimizer_defG.step()

        if batch_num == 0:

            self.E.eval()
            self.advG.eval()
            self.defG.eval()

            # adv, def performance check
            adv_pred = torch.argmax(normalized_eval(adv_images, self.model), 1)
            adv_correct = torch.sum(adv_pred == labels, 0)
            adv_acc = adv_correct.item() / len(labels)

            def_adv_pred = torch.argmax(
                normalized_eval(def_adv_images, self.model), 1)
            def_adv_correct = torch.sum(def_adv_pred == labels, 0)
            def_adv_acc = def_adv_correct.item() / len(labels)

            def_pred = torch.argmax(normalized_eval(def_images, self.model), 1)
            def_correct = torch.sum(def_pred == labels, 0)
            def_acc = def_correct.item() / len(labels)

            nat_pred = torch.argmax(normalized_eval(x, self.model), 1)
            nat_correct = torch.sum(nat_pred == labels, 0)
            nat_acc = nat_correct.item() / len(labels)

            print('adv mean perturbation: %.5f' %
                  torch.abs(adv_images - x).mean().item())
            print('def_adv mean perturbation: %.5f' %
                  torch.abs(def_adv_images - x).mean().item())
            print('def mean perturbation: %.5f' %
                  torch.abs(def_images - x).mean().item())
            print()

            # pgd performance check

            pgd_acc_li = []
            pgd_nat_acc_li = []
            for itr in self.pgd_iter:

                pgd_img = self.pgd.perturb(x, labels, itr=itr)
                pgd_nat_img = self.pgd.perturb(x, labels, itr=0)

                for _ in range(itr):
                    pgd_img = self.defG(self.E(pgd_img)) + pgd_img
                    pgd_img = torch.clamp(pgd_img, self.box_min, self.box_max)

                    # obfuscated check
                    pgd_nat_img = self.defG(self.E(pgd_nat_img)) + pgd_nat_img
                    pgd_nat_img = torch.clamp(pgd_nat_img, self.box_min,
                                              self.box_max)

                pred = torch.argmax(normalized_eval(pgd_img, self.model), 1)
                num_correct = torch.sum(pred == labels, 0)
                pgd_acc = num_correct.item() / len(labels)

                pgd_acc_li.append(pgd_acc)

                pred = torch.argmax(normalized_eval(pgd_nat_img, self.model),
                                    1)
                num_correct = torch.sum(pred == labels, 0)
                pgd_nat_acc = num_correct.item() / len(labels)

                pgd_nat_acc_li.append(pgd_nat_acc)

            self.defG.train()
            self.advG.train()
            self.E.train()

        else:
            pgd_acc_li = None
            pgd_nat_acc_li = None

            adv_acc = None
            def_adv_acc = None
            def_acc = None
            nat_acc = None

        return pgd_acc_li, pgd_nat_acc_li, torch.sum(loss_E).item(), torch.sum(loss_advG).item(), \
               torch.sum(loss_defG).item(), \
               adv_acc, def_adv_acc, def_acc, nat_acc

    # main training function
    def train(self, train_dataloader, epochs):
        for epoch in range(1, epochs + 1):

            if epoch == 50:
                self.optimizer_E = torch.optim.Adam(self.E.parameters(),
                                                    lr=self.E_lr / 10)
                self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                                       lr=self.defG_lr / 10)
                self.optimizer_advG = torch.optim.Adam(self.advG.parameters(),
                                                       lr=self.advG_lr / 10)
            if epoch == 80:
                self.optimizer_E = torch.optim.Adam(self.E.parameters(),
                                                    lr=self.E_lr / 100)
                self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                                       lr=self.defG_lr / 100)
                self.optimizer_advG = torch.optim.Adam(self.advG.parameters(),
                                                       lr=self.advG_lr / 100)

            loss_E_sum = 0
            loss_defG_sum = 0
            loss_advG_sum = 0
            pgd_acc_li_sum = []
            pgd_nat_acc_li_sum = []
            nat_acc_sum = 0
            adv_acc_sum = 0
            def_adv_acc_sum = 0
            def_acc_sum = 0

            for i, data in enumerate(train_dataloader, start=0):

                images, labels = data
                images, labels = images.to(self.device), labels.to(self.device)

                pgd_acc_li_batch, pgd_nat_acc_li_batch, loss_E_batch, loss_advG_batch, loss_defG_batch, \
                adv_acc, def_adv_acc, def_acc, nat_acc = \
                    self.train_batch(images, labels, i)

                loss_E_sum += loss_E_batch
                loss_advG_sum += loss_advG_batch
                loss_defG_sum += loss_defG_batch

                if pgd_acc_li_batch:
                    pgd_acc_li_sum.append(pgd_acc_li_batch)
                    pgd_nat_acc_li_sum.append(pgd_nat_acc_li_batch)
                    nat_acc_sum += nat_acc
                    adv_acc_sum += adv_acc
                    def_adv_acc_sum += def_adv_acc
                    def_acc_sum += def_acc

            # print statistics
            num_batch = len(train_dataloader)
            print("epoch %d:\nloss_E: %.5f, loss_advG: %.5f, loss_defG: %.5f" %
                  (epoch, loss_E_sum / num_batch, loss_advG_sum / num_batch,
                   loss_defG_sum / num_batch))

            pgd_acc_li_sum = np.mean(np.array(pgd_acc_li_sum), axis=0)
            for idx in range(len(self.pgd_iter)):
                print("pgd iter %d acc.: %.5f" %
                      (self.pgd_iter[idx], pgd_acc_li_sum[idx]))

            pgd_nat_acc_li_sum = np.mean(np.array(pgd_nat_acc_li_sum), axis=0)
            for idx in range(len(self.pgd_iter)):
                print("pgd nat iter %d acc.: %.5f" %
                      (self.pgd_iter[idx], pgd_nat_acc_li_sum[idx]))

            print("nat acc.: %.5f" % (nat_acc_sum))
            print("adv acc.: %.5f" % (adv_acc_sum))
            print("def_adv acc.: %.5f" % (def_adv_acc_sum))
            print("def acc.: %.5f" % (def_acc_sum))

            print()

            # write to tensorboard
            if self.writer:
                self.writer.add_scalar('loss_E', loss_E_sum / num_batch, epoch)
                self.writer.add_scalar('loss_advG', loss_advG_sum / num_batch,
                                       epoch)
                self.writer.add_scalar('loss_defG', loss_defG_sum / num_batch,
                                       epoch)

                for idx in range(len(self.pgd_iter)):
                    self.writer.add_scalar('pgd_acc_%d' % (self.pgd_iter[idx]),
                                           pgd_acc_li_sum[idx], epoch)
                    self.writer.add_scalar(
                        'pgd_nat_acc_%d' % (self.pgd_iter[idx]),
                        pgd_nat_acc_li_sum[idx], epoch)

                self.writer.add_scalar('nat_acc', nat_acc_sum, epoch)
                self.writer.add_scalar('adv_acc', adv_acc_sum, epoch)
                self.writer.add_scalar('def_adv_acc', def_adv_acc_sum, epoch)
                self.writer.add_scalar('def_acc', def_acc_sum, epoch)

            # save generator
            if epoch % 20 == 0:
                E_file_name = self.models_path + self.model_name + 'E_epoch_' + str(
                    epoch) + '.pth'
                advG_file_name = self.models_path + self.model_name + 'advG_epoch_' + str(
                    epoch) + '.pth'
                defG_file_name = self.models_path + self.model_name + 'defG_epoch_' + str(
                    epoch) + '.pth'
                torch.save(self.E.state_dict(), E_file_name)
                torch.save(self.advG.state_dict(), advG_file_name)
                torch.save(self.defG.state_dict(), defG_file_name)

        if self.writer:
            self.writer.close()

        # test performance
        self.test()
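
Neither class snippet shows how the trainer is driven. Below is a minimal usage sketch, assuming the repo's `AdvGAN_Attack` (Code Example #5 signature) and the `MNIST_target_net` from Code Example #4 are importable; all hyperparameter values and paths are illustrative only.

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# AdvGAN_Attack and MNIST_target_net are assumed importable from the repo;
# their module paths are not shown in this listing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

target_model = MNIST_target_net().to(device)
# assume pretrained target weights are loaded here (training shown in Example #4)
target_model.eval()

dataset = torchvision.datasets.MNIST('./dataset', train=True,
                                     transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=1)

trainer = AdvGAN_Attack(device, target_model,
                        model_num_labels=10, image_nc=1,
                        box_min=0.0, box_max=1.0,
                        eps=0.3, pgd_iter=[1],
                        models_path='./models/', out_path='./out/',
                        model_name='mnist_', writer=SummaryWriter('./runs/'),
                        E_lr=1e-3, advG_lr=1e-3, defG_lr=1e-3)
trainer.train(dataloader, epochs=100)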
Code Example #6
def tester(dataset,
           dataloader,
           device,
           target_model,
           E,
           defG,
           eps,
           out_path,
           model_name,
           label_count=True,
           save_img=False):

    # load PGD
    pgd = PGD(target_model, E, defG, device)

    num_correct_pgd = 0
    num_correct_def = 0
    num_correct_def_pgd = 0
    num_correct_def_pgd_nat = 0
    num_correct = 0

    test_img_full = []
    pgd_img_full = []
    pgd_nat_img_full = []
    def_img_full = []
    def_pgd_img_full = []
    def_pgd_nat_img_full = []

    pred_pgd_full = []
    pred_def_pgd_full = []
    pred_def_pgd_nat_full = []

    for i, data in enumerate(dataloader, 0):
        # load images
        test_img, test_label = data
        test_img, test_label = test_img.to(device), test_label.to(device)

        def_noise = defG(E(test_img))
        def_img = def_noise + test_img
        def_img = torch.clamp(def_img, 0, 1)

        pgd_img = pgd.perturb(test_img, test_label, itr=1)

        def_pgd_noise = defG(E(pgd_img))
        def_pgd_img = def_pgd_noise + pgd_img
        def_pgd_img = torch.clamp(def_pgd_img, 0, 1)

        pgd_nat_img = pgd.perturb(test_img, test_label, itr=0)

        def_pgd_nat_noise = defG(E(pgd_nat_img))
        def_pgd_nat_img = def_pgd_nat_noise + pgd_nat_img
        def_pgd_nat_img = torch.clamp(def_pgd_nat_img, 0, 1)

        # calculate acc.
        pred = torch.argmax(target_model(test_img), 1)
        pred_pgd = torch.argmax(target_model(pgd_nat_img), 1)
        pred_def = torch.argmax(target_model(def_img), 1)
        pred_def_pgd = torch.argmax(target_model(def_pgd_img), 1)
        pred_def_pgd_nat = torch.argmax(target_model(def_pgd_nat_img), 1)

        num_correct += torch.sum(pred == test_label, 0)
        num_correct_pgd += torch.sum(pred_pgd == test_label, 0)
        num_correct_def += torch.sum(pred_def == test_label, 0)
        num_correct_def_pgd += torch.sum(pred_def_pgd == test_label, 0)
        num_correct_def_pgd_nat += torch.sum(pred_def_pgd_nat == test_label, 0)

        if label_count:
            pred_pgd_full.append(pred_pgd)
            pred_def_pgd_full.append(pred_def_pgd)
            pred_def_pgd_nat_full.append(pred_def_pgd_nat)

        if save_img and i < 1:
            test_img_full.append(test_img)
            pgd_img_full.append(pgd_img)
            pgd_nat_img_full.append(pgd_nat_img)
            def_img_full.append(def_img)
            def_pgd_img_full.append(def_pgd_img)
            def_pgd_nat_img_full.append(def_pgd_nat_img)

    print('num_correct(nat): ', num_correct.item())
    print('num_correct(pgd): ', num_correct_pgd.item())
    print('num_correct(def(nat)): ', num_correct_def.item())
    print('num_correct(def(pgd)): ', num_correct_def_pgd.item())
    print('num_correct(def(pgd_nat)): ', num_correct_def_pgd_nat.item())
    print()

    print('accuracy of nat imgs: %f' % (num_correct.item() / len(dataset)))
    print('accuracy of pgd imgs: %f' % (num_correct_pgd.item() / len(dataset)))
    print('accuracy of def(nat) imgs: %f' %
          (num_correct_def.item() / len(dataset)))
    print('accuracy of def(pgd) imgs: %f' %
          (num_correct_def_pgd.item() / len(dataset)))
    print('accuracy of def(pgd_nat) imgs: %f' %
          (num_correct_def_pgd_nat.item() / len(dataset)))
    print()

    l_inf = np.amax(
        np.abs(def_img.cpu().detach().numpy() -
               test_img.cpu().detach().numpy()))
    print('l-inf of def imgs:%f' % (l_inf))
    l_inf = np.amax(
        np.abs(def_pgd_img.cpu().detach().numpy() -
               test_img.cpu().detach().numpy()))
    print('l-inf of def(pgd) imgs:%f' % (l_inf))
    l_inf = np.amax(
        np.abs(def_pgd_nat_img.cpu().detach().numpy() -
               test_img.cpu().detach().numpy()))
    print('l-inf of def(pgd_nat) imgs:%f' % (l_inf))

    print()

    if label_count:
        pred_pgd_full = torch.cat(pred_pgd_full)
        preds = pred_pgd_full.cpu().detach().numpy()
        print('label counts in pgd imgs:')
        print(np.unique(preds, return_counts=True))

        pred_def_pgd_full = torch.cat(pred_def_pgd_full)
        preds = pred_def_pgd_full.cpu().detach().numpy()
        print('label counts in def_pgd imgs:')
        print(np.unique(preds, return_counts=True))

        pred_def_pgd_nat_full = torch.cat(pred_def_pgd_nat_full)
        preds = pred_def_pgd_nat_full.cpu().detach().numpy()
        print('label counts in def_pgd_nat imgs:')
        print(np.unique(preds, return_counts=True))

        print()

    if save_img:
        test_img_full = torch.cat(test_img_full)
        pgd_img_full = torch.cat(pgd_img_full)
        pgd_nat_img_full = torch.cat(pgd_nat_img_full)
        def_img_full = torch.cat(def_img_full)
        def_pgd_img_full = torch.cat(def_pgd_img_full)
        def_pgd_nat_img_full = torch.cat(def_pgd_nat_img_full)

        test_grid = make_grid(test_img_full)
        pgd_grid = make_grid(pgd_img_full)
        pgd_nat_grid = make_grid(pgd_nat_img_full)
        def_grid = make_grid(def_img_full)
        def_pgd_grid = make_grid(def_pgd_img_full)
        def_pgd_nat_grid = make_grid(def_pgd_nat_img_full)

        if not os.path.exists(out_path + model_name):
            os.makedirs(out_path + model_name)

        save_image(test_grid, out_path + model_name + 'test_grid.png')
        save_image(pgd_grid, out_path + model_name + 'pgd_grid.png')
        save_image(pgd_nat_grid, out_path + model_name + 'pgd_nat_grid.png')
        save_image(def_grid, out_path + model_name + 'def_grid.png')
        save_image(def_pgd_grid, out_path + model_name + 'def_pgd_grid.png')
        save_image(def_pgd_nat_grid,
                   out_path + model_name + 'def_pgd_nat_grid.png')

        print('images saved')
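
Code Examples #3 and #6 instantiate `PGD(target_model, E, defG, device, ...)`, but the attack class itself is not in the listing. The point of passing the encoder and defensive generator to the attack is to run the inner maximization through the purifier, `clamp(defG(E(x)) + x)`, so reported robustness is not inflated by obfuscated gradients. Below is a hedged, minimal l-infinity PGD sketch of that idea; the repository's class will differ in interface (e.g. the meaning of the `itr` argument) and in details.

import torch
import torch.nn.functional as F

def pgd_through_purifier(model, E, defG, x, labels, eps, step_size, num_steps,
                         box_min=0.0, box_max=1.0):
    """Sketch: l-inf PGD where each gradient step is taken through the
    purification pipeline clamp(defG(E(x_adv)) + x_adv) before the classifier."""
    # random start inside the eps-ball
    x_adv = x + torch.empty_like(x).uniform_(-eps, eps)
    x_adv = torch.clamp(x_adv, box_min, box_max).detach()

    for _ in range(num_steps):
        x_adv.requires_grad_(True)
        purified = torch.clamp(defG(E(x_adv)) + x_adv, box_min, box_max)
        loss = F.cross_entropy(model(purified), labels)
        grad, = torch.autograd.grad(loss, x_adv)

        # ascend the loss, then project back onto the l-inf ball and valid range
        x_adv = x_adv.detach() + step_size * grad.sign()
        x_adv = torch.min(torch.max(x_adv, x - eps), x + eps)
        x_adv = torch.clamp(x_adv, box_min, box_max).detach()

    return x_adv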
Code Example #7
File: train_ss.py  Project: nupurkmr9/ss-ood
trainset = torchvision.datasets.CIFAR10(root='./dataset/', train=True,
                                        download=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                           shuffle=True, num_workers=8)

testset = torchvision.datasets.CIFAR10(root='./dataset/', train=False,
                                       download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_bs,
                                          shuffle=False, num_workers=8)

# Create model
net = WideResNet(args.layers, 10, args.widen_factor, dropRate=args.droprate)
net.rot_pred = nn.Linear(128, 4)

# create PGD adversaries for training and evaluation
adversary = PGD(epsilon=8./255., num_steps=10, step_size=2./255.)
adversary_test = PGD(epsilon=8./255., num_steps=20, step_size=1./255., attack_rotations=False)


start_epoch = 0

# Restore model if desired
if args.load != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(args.load, 'cifar10' + '_' + 'wrn' +
                                  '_baseline_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
            break