Beispiel #1
0
def main(opts):
    """Train the adversarial GAN on the training set selected by ``opts``.

    Args:
        opts: option dict. Keys read here: ``'mode'`` (one of ``'wf'``,
            ``'wf_ow'``, ``'wf_kf'``, ``'shs'``), ``'train_data_path'``,
            ``'batch_size'``, ``'x_box_min'``, ``'x_box_max'`` and
            ``'pert_box'``. The whole dict is also forwarded to
            ``gan.advGan``.

    Prints the GAN training time in hours. Exits the process with status 1
    when ``mode`` is not recognised.
    """
    mode = opts['mode']

    # load data
    if mode in ('wf', 'wf_ow', 'wf_kf'):
        print('train data: %s' % opts['train_data_path'])
        train_data = utils_wf.load_data_main(opts['train_data_path'],
                                             opts['batch_size'],
                                             shuffle=True)
    elif mode == 'shs':
        train_data = utils_shs.load_data_main(opts['train_data_path'],
                                              opts['batch_size'])
    else:
        print('mode not in ["wf","wf_ow","wf_kf","shs"], system will exit.')
        # Non-zero status so calling scripts see the configuration error.
        sys.exit(1)

    start_time = time.time()
    adv_gan = gan.advGan(opts, opts['x_box_min'], opts['x_box_max'],
                         opts['pert_box'])
    adv_gan.train(train_data)
    end_time = time.time()
    print('training time of GAN is {} hours.'.format(
        (end_time - start_time) / 3600.0))
    def testing_process(self):

        print('testing mode...')

        if self.mode == 'wf':
            "load data"
            test_data = utils_wf.load_data_main(self.opts['test_data_path'],self.opts['batch_size'])

            "load target model structure"
            params = utils_wf.params(self.opts['num_class'],self.opts['input_size'])
            target_model = models.target_model_wf(params).to(self.device)

        elif self.mode == 'shs':
            "load data"
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],self.opts['batch_size'])

            "load target model structure"
            params = utils_shs.params(self.opts['num_class'],self.opts['input_size'])
            target_model = models.target_model_shs(params).to(self.device)

        else:
            print('mode not in ["wf","shs"], system will exit.')
            sys.exit()

        model_name = self.model_path + '/adv_target_model_' + self.opts['Adversary'] + '.pth'
        target_model.load_state_dict(torch.load(model_name, map_location=self.device))
        target_model.eval()
        loss_criterion = nn.CrossEntropyLoss()

        test_loss = 0.
        test_correct = 0
        total_case = 0
        start_time = datetime.now()
        i = 0
        for data, label in test_data:
            i += 1
            print('testing {}'.format(i))
            data, label = data.to(self.device),label.to(self.device)
            outputs = target_model(data)
            loss = loss_criterion(outputs, label)
            test_loss += loss * len(label)
            _, predicted = torch.max(outputs, 1)
            correct = int(sum(predicted == label))
            test_correct += correct
            total_case += len(label)

            # delete caches
            del data, label, outputs, loss
            torch.cuda.empty_cache()

        accuracy = test_correct / total_case
        loss = test_loss / total_case
        print("Test Loss: {:5.2f}, Accuracy: {:6.2%}".format(loss, accuracy))

        end_time = datetime.now()
        time_diff = (end_time - start_time).seconds
        print("Time Usage: {:5.2f} mins.".format(time_diff / 60.))
    def test_peformance(self, threshold_val):

        "load data and target model"
        if self.mode == 'wf_ow':
            "load data"
            if self.opts['test_data_type'] == 'NoDef':
                test_data_Mon = utils_wf.load_data_main(
                    self.opts['test_data_Mon_path'], self.opts['batch_size'])
                test_data_UnMon = utils_wf.load_data_main(
                    self.opts['test_data_UnMon_path'], self.opts['batch_size'])
            elif self.opts['test_data_type'] == 'Def':
                test_data_Mon = utils_wf.load_data_main(
                    self.opts['adv_test_data_Mon_path'],
                    self.opts['batch_size'])
                test_data_UnMon = utils_wf.load_data_main(
                    self.opts['adv_test_data_UnMon_path'],
                    self.opts['batch_size'])

            "load target model structure"
            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                params = utils_wf.params_lstm_ow_eval(self.opts['num_class'],
                                                      self.opts['input_size'],
                                                      self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        else:
            print('mode not in ["wf_ow"], system will exit.')
            sys.exit()

        if self.model_type == 'adv_target_model':
            model_name = self.model_path + '/adv_target_model_' + self.opts[
                'Adversary'] + '.pth'
        elif self.model_type == 'target_model':
            model_name = self.model_path + '/target_model.pth'
        else:
            print(
                'target model type not in ["target_model","adv_target_model"], system will exit.'
            )
            sys.exit()
        target_model.load_state_dict(
            torch.load(model_name, map_location=self.device))
        target_model.eval()

        print('target model %s' % model_name)
        print('test data path unmonitored %s ' %
              self.opts['adv_test_data_UnMon_path'])
        print('test data path monitored %s ' %
              self.opts['adv_test_data_Mon_path'])

        # ==============================================================
        "testing process..."
        TP = 0
        FP = 0
        TN = 0
        FN = 0

        "obtain monitored and Unmonitored lables"
        _, y_test_Mon = utils_wf.load_csv_data(self.opts['test_data_Mon_path'])
        _, y_test_UnMon = utils_wf.load_csv_data(
            self.opts['test_data_UnMon_path'])
        monitored_labels = torch.Tensor(y_test_Mon).long().to(self.device)
        unmonitored_labels = torch.Tensor(y_test_UnMon).long().to(self.device)

        # ==============================================================
        # Test with Monitored testing instances

        for i, data in enumerate(test_data_Mon, 0):
            test_x, test_y = data
            test_x, test_y = test_x.to(self.device), test_y.to(self.device)
            max_probs, pred_labels = torch.max(
                torch.softmax(target_model(test_x), 1), 1)

            for j, pred_label in enumerate(pred_labels):
                if pred_label in monitored_labels:  # predited as Monitored
                    if max_probs[
                            j] >= threshold_val:  # probability greater than the threshold
                        TP += 1
                    else:  # predicted as unmonitored and true lable is Monitored
                        FN += 1
                elif pred_label in unmonitored_labels:  # predicted as unmonitored and true lable is monitored
                    FN += 1

        # ==============================================================
        # Test with Unmonitored testing instances
        for i, data in enumerate(test_data_UnMon, 0):
            test_x, test_y = data
            test_x, test_y = test_x.to(self.device), test_y.to(self.device)
            max_probs, pred_labels = torch.max(
                torch.softmax(target_model(test_x), 1), 1)

            for j, pred_label in enumerate(pred_labels):
                if pred_label in unmonitored_labels:  # predited as unmonitored and true label is unmonitored
                    TN += 1
                elif pred_label in monitored_labels:  # predicted as Monitored
                    if max_probs[
                            j] >= threshold_val:  # predicted as monitored and true label is unmonitored
                        FP += 1
                    else:
                        TN += 1

        "print result"
        print("TP : ", TP)
        print("FP : ", FP)
        print("TN : ", TN)
        print("FN : ", FN)
        print("Total  : ", TP + FP + TN + FN)
        TPR = float(TP) / (TP + FN)
        print("TPR : ", TPR)
        FPR = float(FP) / (FP + TN)
        print("FPR : ", FPR)
        Precision = float(TP) / (TP + FP)
        print("Precision : ", Precision)
        Recall = float(TP) / (TP + FN)
        print("Recall : ", Recall)
        print("\n")

        self.log_file.writelines(
            "%.6f,%d,%d,%d,%d,%.6f,%.6f,%.6f,%.6f\n" %
            (threshold_val, TP, FP, TN, FN, TPR, FPR, Precision, Recall))
    def adv_train_process(self, delay=0.5):
        """Adversarially train a fresh target model on pre-generated adversarial data.

        Every step pairs one clean mini-batch with one mini-batch from a CSV
        of pre-computed adversarial examples. ``zip`` stops at the shorter of
        the two loaders. For epochs with
        ``epoch + 1 >= int((1 - delay) * epochs)`` the loss is the mean of
        the clean and the adversarial cross-entropy. Earlier epochs use the
        clean loss only.

        Reads ``self.opts`` (``Adversary``, ``adv_train_data_path``,
        ``train_data_path``, ``test_data_path``, ``batch_size``,
        ``num_class``, ``input_size``, ``epochs`` and, in ``'wf_kf'`` mode,
        ``id``), ``self.mode``, ``self.classifier_type``, ``self.data_path``,
        ``self.model_path`` and ``self.device``.

        The state dict is saved every 10 epochs (epoch 0 excluded) and after
        the final epoch. An invalid mode or classifier type exits with
        status 1.

        Args:
            delay: fraction of training (from the end) that mixes in the
                adversarial loss; 0.5 enables it for roughly the second half.
        """
        # Walkie-Talkie defended data lives at a fixed location.
        if self.opts['Adversary'] == 'WT':
            adv_train_data_path = '../data/WalkieTalkie/defended_csv/adv_train_WT.csv'
        else:
            adv_train_data_path = (self.data_path + self.classifier_type + '/' +
                                   self.opts['adv_train_data_path'])
        print('adv train data path: ', adv_train_data_path)

        if self.mode in ['wf', 'wf_ow', 'wf_kf']:
            # load data
            if self.mode != 'wf_kf':
                train_path = self.data_path + self.opts['train_data_path']
            else:
                # k-fold runs train on this fixed undefended training set.
                train_path = '../data/wf/train_NoDef_burst.csv'
            train_data = utils_wf.load_data_main(train_path,
                                                 self.opts['batch_size'],
                                                 shuffle=True)
            print('train data path: ', train_path)

            adv_train_data = utils_wf.load_data_main(adv_train_data_path,
                                                     self.opts['batch_size'],
                                                     shuffle=True)

            # load target model structure
            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)
            elif self.classifier_type == 'lstm':
                if self.mode == 'wf_ow':
                    params = utils_wf.params_lstm_ow(self.opts['num_class'],
                                                     self.opts['input_size'],
                                                     self.opts['batch_size'])
                else:
                    params = utils_wf.params_lstm(self.opts['num_class'],
                                                  self.opts['input_size'],
                                                  self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)
            else:
                print('classifier type not in ["cnn","lstm"], system will exit.')
                sys.exit(1)

        elif self.mode == 'shs':
            # load data
            train_data = utils_shs.load_data_main(
                self.data_path + self.opts['train_data_path'],
                self.opts['batch_size'])
            # NOTE(review): test_data is unused in the visible part of this
            # method. The trailing "test trained model" step appears cut off.
            test_data = utils_shs.load_data_main(
                self.data_path + self.opts['test_data_path'],
                self.opts['batch_size'])
            adv_train_data = utils_shs.load_data_main(adv_train_data_path,
                                                      self.opts['batch_size'])

            # load target model structure
            if self.classifier_type == 'cnn':
                params = utils_shs.params_cnn(self.opts['num_class'],
                                              self.opts['input_size'])
                target_model = models.cnn_noNorm(params).to(self.device)
            elif self.classifier_type == 'lstm':
                params = utils_shs.params_lstm(self.opts['num_class'],
                                               self.opts['input_size'],
                                               self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)
            else:
                print('classifier type not in ["cnn","lstm"], system will exit.')
                sys.exit(1)

        else:
            print('mode not in ["wf","wf_ow","wf_kf","shs"], system will exit.')
            sys.exit(1)

        target_model.train()
        loss_criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(target_model.parameters(), lr=1e-3)

        # start training
        num_epochs = self.opts['epochs']
        adv_start_epoch = int((1 - delay) * num_epochs)
        steps = 0
        flag = False
        start_time = datetime.now()
        for epoch in range(num_epochs):
            print('Starting epoch %d / %d' % (epoch + 1, num_epochs))

            if flag:
                print('{} based adversarial training...'.format(
                    self.opts['Adversary']))

            for (x, y), (x_adv, y_adv) in zip(train_data, adv_train_data):
                steps += 1
                optimizer.zero_grad()

                x, y = x.to(self.device), y.to(self.device)
                loss = loss_criterion(target_model(x), y)

                # adversarial training
                if epoch + 1 >= adv_start_epoch:
                    x_adv, y_adv = x_adv.to(self.device), y_adv.to(self.device)
                    flag = True
                    loss_adv = loss_criterion(target_model(x_adv), y_adv)
                    loss = (loss + loss_adv) / 2

                # print results every 100 steps
                if steps % 100 == 0:
                    time_diff = int((datetime.now() - start_time).total_seconds())
                    time_usage = '{:3}m{:3}s'.format(time_diff // 60,
                                                     time_diff % 60)
                    msg = "Step {:5}, Loss:{:6.2f}, Time usage:{:9}."
                    print(msg.format(steps, loss, time_usage))

                loss.backward()
                optimizer.step()

            # The old test `epoch == epochs` could never hold inside
            # range(epochs), so the final weights were often never saved.
            is_last_epoch = epoch + 1 == num_epochs
            if (epoch != 0 and epoch % 10 == 0) or is_last_epoch:
                if self.mode != 'wf_kf':
                    output_name = '/adv_target_model_%s.pth' % self.opts['Adversary']
                else:
                    output_name = '/adv_target_model_%s_%d.pth' % (
                        self.opts['Adversary'], self.opts['id'])
                torch.save(target_model.state_dict(),
                           self.model_path + output_name)

        # test trained model
        # NOTE(review): the evaluation step announced here is not present in
        # this method as shown; it appears to have been cut off.
Beispiel #5
0
    def test_model(self):

        "define test data type/path"
        # classifier cross validation:
        if self.opts['cross_validation']:
            if self.classifier_type == 'lstm':
                adv_test_data_path = '../data/' + self.mode + '/cnn/' + self.opts[
                    'adv_test_data_path']
            elif self.classifier_type == 'cnn':
                adv_test_data_path = '../data/' + self.mode + '/lstm/' + self.opts[
                    'adv_test_data_path']
            else:
                print(
                    'classifier type should in [lstm,cnn]. System will exit!')
                sys.exit()
        # test on Walkie-Talkie defended data
        elif self.opts['Adversary'] == 'WT':
            adv_test_data_path = '../data/WalkieTalkie/defended_csv/adv_test_WT.csv'
        else:
            adv_test_data_path = self.data_path + '/' + self.opts[
                'adv_test_data_path']

        print('adv test data path:  ', adv_test_data_path)

        "load data and model"
        if self.mode == 'wf' or 'wf_ow' or 'wf_kf':
            "load data"
            if self.mode == 'wf_kf':
                test_path = '../data/wf/cross_val/' + self.opts[
                    'test_data_path']
                test_data = utils_wf.load_data_main(test_path,
                                                    self.opts['batch_size'])
            else:
                test_path = '../data/wf/' + self.opts['test_data_path']
                test_data = utils_wf.load_data_main(test_path,
                                                    self.opts['batch_size'])
            adv_test_data = utils_wf.load_data_main(adv_test_data_path,
                                                    self.opts['batch_size'])
            print('test data path: ', test_path)

            "load target model structure"
            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                if self.mode == 'wf_ow':
                    params = utils_wf.params_lstm_ow_eval(
                        self.opts['num_class'], self.opts['input_size'],
                        self.opts['batch_size'])
                else:
                    params = utils_wf.params_lstm_eval(self.opts['num_class'],
                                                       self.opts['input_size'],
                                                       self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        elif self.mode == 'shs':
            "load data"
            test_data = utils_shs.load_data_main(
                '../data/' + self.mode + '/' + self.opts['test_data_path'],
                self.opts['batch_size'])
            adv_test_data = utils_shs.load_data_main(adv_test_data_path,
                                                     self.opts['batch_size'])

            "load target model structure"
            if self.classifier_type == 'cnn':
                params = utils_shs.params_cnn(self.opts['num_class'],
                                              self.opts['input_size'])
                target_model = models.cnn_noNorm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                params = utils_shs.params_lstm_eval(self.opts['num_class'],
                                                    self.opts['input_size'],
                                                    self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        else:
            print('mode not in ["wf","shs"], system will exit.')
            sys.exit()

        if self.model_type == 'adv_target_model':
            if self.mode == 'wf_kf':
                model_name = self.model_path + '/adv_target_model_%s_%d.pth' % (
                    self.opts['Adversary'], self.opts['id'])
            else:
                model_name = self.model_path + '/adv_target_model_' + self.opts[
                    'Adversary'] + '.pth'
        elif self.model_type == 'target_model':
            if self.mode == 'wf_kf':
                model_name = self.model_path + '/target_model_%d.pth' % self.opts[
                    'id']
            else:
                model_name = self.model_path + '/target_model.pth'
        else:
            print(
                'target model type not in ["target_model","adv_target_model"], system will exit.'
            )
            sys.exit()
        print('model path: ', model_name)

        target_model.load_state_dict(
            torch.load(model_name, map_location=self.device))
        target_model.eval()

        "test on adversarial examples"
        correct_adv_x = 0
        correct_x = 0
        total_case = 0
        for (x, y), (adv_x, adv_y) in zip(test_data, adv_test_data):

            x, y = x.to(self.device), y.to(self.device)
            adv_x, adv_y = adv_x.to(self.device), adv_y.to(self.device)

            "prediction on original input x"
            pred_x = target_model(x)
            _, pred_x = torch.max(pred_x, 1)
            correct_x += (pred_x == y).sum()

            "predition on adv_x"
            pred_adv_x = target_model(adv_x)
            _, pred_adv_x = torch.max(pred_adv_x, 1)
            correct_adv_x += (pred_adv_x == adv_y).sum()

            total_case += len(y)

        acc_x = float(correct_x.item()) / float(total_case)
        acc_adv = float(correct_adv_x.item()) / float(total_case)

        print('*' * 30)
        print('"{}" with {} against {}.'.format(self.mode,
                                                self.opts['Adversary'],
                                                self.model_type))
        print('correct test after attack is {}'.format(correct_adv_x.item()))
        print('total test instances is {}'.format(total_case))
        print(
            'accuracy of test after {} attack : correct/total= {:.5f}'.format(
                self.opts['Adversary'], acc_adv))
        print('success rate of the attack is : {}'.format(1 - acc_adv))
        print('accucary of the model without being attacked is {:.5f}'.format(
            acc_x))
        print('\n')
    def adv_train_process(self, delay=0.5):
        """Adversarially train a fresh target model, crafting attacks on the fly.

        Reads ``self.mode`` (``'wf'``/``'shs'``), ``self.opts`` (data paths,
        ``batch_size``, ``num_class``, ``input_size``, ``epochs``,
        ``Adversary``, ``x_box_min``/``x_box_max``/``pert_box``, ``alpha``),
        ``self.device`` and ``self.model_path``. ``opts['Adversary']``
        selects FGSM, DeepFool, PGD (``LinfPGDAttack``) or a pre-trained GAN
        generator loaded from ``<model_path>/adv_generator.pth``.

        For epochs with ``epoch + 1 >= int((1 - delay) * epochs)`` the loss is
        the mean of the clean loss and the loss on adversarial versions of
        the current batch, scored against the true labels ``y``. Earlier
        epochs train on clean data only.

        The state dict is saved to
        ``<model_path>/adv_target_model_<Adversary>.pth`` every 10 epochs
        (starting at epoch 0) and after the final epoch. An unknown mode or
        adversary exits the process with status 1.

        Args:
            delay: fraction of training (from the end) that uses adversarial
                examples; 0.5 enables them for roughly the second half.
        """
        if self.mode == 'wf':
            # load data
            train_data = utils_wf.load_data_main(self.opts['train_data_path'],
                                                 self.opts['batch_size'])
            # NOTE(review): test_data is unused in the visible part of this
            # method. The trailing "test trained model" step appears cut off.
            test_data = utils_wf.load_data_main(self.opts['test_data_path'],
                                                self.opts['batch_size'])

            # load target model structure
            params = utils_wf.params(self.opts['num_class'],
                                     self.opts['input_size'])
            target_model = models.target_model_wf(params).to(self.device)

        elif self.mode == 'shs':
            # load data
            train_data = utils_shs.load_data_main(self.opts['train_data_path'],
                                                  self.opts['batch_size'])
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            # load target model structure
            params = utils_shs.params(self.opts['num_class'],
                                      self.opts['input_size'])
            target_model = models.target_model_shs(params).to(self.device)

        else:
            print('mode not in ["wf","shs"], system will exit.')
            sys.exit(1)

        target_model.train()
        loss_criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(target_model.parameters(), lr=1e-3)

        # set adversary
        Adversary = self.opts['Adversary']
        if Adversary not in ('FGSM', 'DeepFool', 'PGD', 'GAN'):
            # Fail fast instead of after all the clean-training epochs.
            print('No Adversary found! System will exit.')
            sys.exit(1)

        # attack step sizes differ between the wf and shs data
        if self.mode == 'wf':
            fgsm_epsilon = 0.2
            pgd_a = 0.051
        else:
            fgsm_epsilon = 0.1
            pgd_a = 0.01

        if Adversary == 'FGSM':
            print('adv training with FGSM')
            adversary = FGSM(self.mode,
                             x_box_min=self.opts['x_box_min'],
                             x_box_max=self.opts['x_box_max'],
                             pert_box=self.opts['pert_box'],
                             epsilon=fgsm_epsilon)
        elif Adversary == 'DeepFool':
            print('adv training with DeepFool')
            # NOTE(review): num_classes is hard-coded to 5 rather than taken
            # from opts['num_class']; confirm it matches the dataset.
            adversary = DeepFool(self.mode,
                                 x_box_min=self.opts['x_box_min'],
                                 x_box_max=self.opts['x_box_max'],
                                 pert_box=self.opts['pert_box'],
                                 num_classes=5)
        elif Adversary == 'PGD':
            print('adv training with PGD')
            adversary = LinfPGDAttack(self.mode,
                                      x_box_min=self.opts['x_box_min'],
                                      x_box_max=self.opts['x_box_max'],
                                      pert_box=self.opts['pert_box'],
                                      k=5,
                                      a=pgd_a,
                                      random_start=False)
        else:  # 'GAN'
            generator = models.Generator(self.gen_input_nc,
                                         self.input_nc).to(self.device)
            generator.load_state_dict(
                torch.load(self.model_path + '/adv_generator.pth',
                           map_location=self.device))
            generator.eval()

        # start training process
        num_epochs = self.opts['epochs']
        adv_start_epoch = int((1 - delay) * num_epochs)
        steps = 0
        flag = False
        start_time = datetime.now()

        for epoch in range(num_epochs):
            print('Starting epoch %d / %d' % (epoch + 1, num_epochs))

            if flag:
                print('{} based adversarial training...'.format(Adversary))

            for x, y in train_data:
                steps += 1
                optimizer.zero_grad()
                x, y = x.to(self.device), y.to(self.device)
                loss = loss_criterion(target_model(x), y)

                # adversarial training
                if epoch + 1 >= adv_start_epoch:
                    flag = True

                    if Adversary == 'GAN':
                        # The generator is frozen; don't build its graph.
                        with torch.no_grad():
                            pert = generator(x)
                        # cal adv_x given different mode wf/shs
                        # (previously referenced an undefined name `mode`)
                        adv_x = utils_gan.get_advX_gan(
                            x,
                            pert,
                            self.mode,
                            pert_box=self.opts['pert_box'],
                            x_box_min=self.opts['x_box_min'],
                            x_box_max=self.opts['x_box_max'],
                            alpha=self.opts['alpha'])
                    else:
                        _, y_pred = torch.max(target_model(x), 1)
                        # the mode of the adversary was set above
                        adv_x = self.x_adv_gen(x, y_pred, target_model,
                                               adversary)

                    loss_adv = loss_criterion(
                        target_model(adv_x.to(self.device)), y)
                    loss = (loss + loss_adv) / 2

                # print results every 100 steps
                if steps % 100 == 0:
                    time_diff = int((datetime.now() - start_time).total_seconds())
                    time_usage = '{:3}m{:3}s'.format(time_diff // 60,
                                                     time_diff % 60)
                    msg = "Step {:5}, Loss:{:6.2f}, Time usage:{:9}."
                    print(msg.format(steps, loss, time_usage))

                loss.backward()
                optimizer.step()

            # Also save after the last epoch so the final weights are kept.
            if epoch % 10 == 0 or epoch + 1 == num_epochs:
                torch.save(
                    target_model.state_dict(), self.model_path +
                    '/adv_target_model_' + Adversary + '.pth')

        # ****************************
        # test trained model
Beispiel #7
0
    def train_model(self):
        """Train the target LSTM, checkpoint it, and report its test accuracy.

        Builds ``models.lstm(self.params)`` for mode ``'wf'`` or ``'shs'`` and
        reads ``self.opts`` (``train_data_path``, ``test_data_path``,
        ``batch_size``, ``lr``, ``epochs``). Training accuracy every 100 steps
        and the final test accuracy are reported to NNI. The state dict is
        written to ``<model_path>/target_model.pth`` every 10 epochs (starting
        at epoch 0) and after the final epoch. An unknown mode exits the
        process with status 1.
        """
        if self.mode == 'wf':
            # load data
            train_data = utils_wf.load_data_main(self.opts['train_data_path'],
                                                 self.opts['batch_size'])
            test_data = utils_wf.load_data_main(self.opts['test_data_path'],
                                                self.opts['batch_size'])
        elif self.mode == 'shs':
            train_data = utils_shs.load_data_main(self.opts['train_data_path'],
                                                  self.opts['batch_size'])
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])
        else:
            print('mode not in ["wf","shs"], system will exit.')
            sys.exit(1)

        # load target model structure (same LSTM for both modes)
        target_model = models.lstm(self.params).to(self.device)
        target_model.train()

        # train process
        optimizer = torch.optim.Adam(target_model.parameters(),
                                     lr=self.opts['lr'])
        num_epochs = self.opts['epochs']
        targeted_model_path = self.model_path + '/target_model.pth'

        for epoch in range(num_epochs):
            # The unused running `loss_epoch` tensor sum was removed: it kept
            # every batch's autograd graph alive for the whole epoch.
            for i, (train_x, train_y) in enumerate(train_data):
                train_x, train_y = train_x.to(self.device), train_y.to(self.device)

                # batch_first = False
                # NOTE(review): this tests opts['batch_size'], presumably a
                # positive int, so the transpose never runs. It may have been
                # meant to test a batch_first option; confirm before changing.
                if not self.opts['batch_size']:
                    train_x = train_x.transpose(0, 1)

                optimizer.zero_grad()
                logits_model = target_model(train_x)
                loss_model = F.cross_entropy(logits_model, train_y)

                # NOTE(review): retain_graph=True is kept in case models.lstm
                # carries state across batches; confirm and drop if not needed.
                loss_model.backward(retain_graph=True)
                optimizer.step()

                if i % 100 == 0:
                    _, predicted = torch.max(logits_model, 1)
                    accuracy = (predicted == train_y).sum().item() / len(train_y)
                    msg = 'Epoch {:5}, Step {:5}, Loss: {:6.2f}, Accuracy:{:8.2%}.'
                    print(msg.format(epoch, i, loss_model, accuracy))

                    # report intermediate result
                    nni.report_intermediate_result(accuracy)

            # save model every 10 epochs and after the final one
            if epoch % 10 == 0 or epoch + 1 == num_epochs:
                torch.save(target_model.state_dict(), targeted_model_path)

        # test target model (inference only)
        target_model.eval()

        num_correct = 0
        total_instances = 0
        with torch.no_grad():
            for test_x, test_y in test_data:
                test_x, test_y = test_x.to(self.device), test_y.to(self.device)
                pred_lab = torch.argmax(target_model(test_x), 1)
                num_correct += (pred_lab == test_y).sum().item()
                total_instances += len(test_y)

        if total_instances == 0:
            print('test set is empty; reporting accuracy 0.')
            acc = 0.0
        else:
            acc = num_correct / total_instances
        print('accuracy of target model against test dataset: %f\n' % (acc))

        # report final result
        nni.report_final_result(acc)
    def test_model(self):

        "load data and target model"
        if self.mode == 'wf' or 'detect':
            "load data"
            if mode == 'wf':
                # train_data = utils_wf.load_data_main(self.opts['train_data_path'], self.opts['batch_size'])
                test_data = utils_wf.load_data_main(
                    self.opts['test_data_path'], self.opts['batch_size'])
            elif mode == 'detect':
                # train_data = utils_wf.load_NormData_main(self.opts['train_data_path'], self.opts['batch_size'])
                test_data = utils_wf.load_NormData_main(
                    self.opts['test_data_path'], self.opts['batch_size'])

            "load target model structure"
            if self.opts['classifier'] == 'lstm':
                params = utils_wf.params_lstm_detect(self.opts['num_class'],
                                                     self.opts['input_size'],
                                                     self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)
            elif self.opts['classifier'] == 'rnn':
                params = utils_wf.params_rnn(self.opts['num_class'],
                                             self.opts['input_size'],
                                             self.opts['batch_size'])
                target_model = models.rnn(params).to(self.device)
            elif self.opts['classifier'] == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn(params).to(self.device)
            elif self.opts['classifier'] == 'fcnn':
                params = utils_wf.params_fcnn(self.opts['num_class'],
                                              self.opts['input_size'])
                target_model = models.fcnn(params).to(self.device)
            target_model.eval()

        elif self.mode == 'shs':
            "load data"
            # train_data = utils_shs.load_data_main(self.opts['train_data_path'], self.opts['batch_size'])
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            "load target model structure"
            params = utils_shs.params_lstm(self.opts['num_class'],
                                           self.opts['input_size'],
                                           self.opts['batch_size'])
            target_model = models.lstm(params).to(self.device)
            target_model.train()

        else:
            print('mode not in ["wf","shs"], system will exit.')
            sys.exit()

        model_name = self.model_path + '/target_model.pth'

        model_weights = torch.load(model_name, map_location=self.device)
        for k in model_weights:
            print(k)
        # for k in model_weights['shared_layers']: print("Shared layer", k)

        target_model.load_state_dict(
            torch.load(model_name, map_location=self.device))
        target_model.eval()

        num_correct = 0
        total_instances = 0
        y_test = []
        y_pred = []
        for i, data in enumerate(test_data, 0):
            test_x, test_y = data
            test_x, test_y = test_x.to(self.device), test_y.to(self.device)
            pred_lab = torch.argmax(target_model(test_x), 1)
            num_correct += torch.sum(pred_lab == test_y, 0)
            total_instances += len(test_y)

            "save result"
            y_test += (test_y.cpu().numpy().tolist())
            y_pred += (pred_lab.cpu().numpy().tolist())

        print('{} with {}'.format(self.opts['mode'], self.opts['classifier']))

        print(classification_report(y_test, y_pred))
        print('confusion matrix is {}'.format(confusion_matrix(y_test,
                                                               y_pred)))
        print('accuracy of target model against test dataset: %f\n' %
              (num_correct.item() / total_instances))
        print('accuracy is {}'.format(metrics.accuracy_score(y_test, y_pred)))

        # plot confusion matrix
        cm = confusion_matrix(y_test, y_pred)
        heat_map(cm, self.opts['classifier'], namorlize=True)
    def train_model(self):
        """Train the target classifier for the current mode and save its weights.

        Modes:
          * 'wf'     -- training data from ``utils_wf.load_data_main``.
          * 'detect' -- training data from ``utils_wf.load_NormData_main``.
          * 'shs'    -- training data from ``utils_shs.load_data_main``; always an lstm.
        For 'wf'/'detect' the architecture is chosen by ``opts['classifier']``
        ('lstm', 'rnn', 'cnn' or 'fcnn').

        The model is optimised with Adam (``opts['lr']``) on cross-entropy for
        ``opts['epochs']`` epochs. Weights are written to
        ``<model_path>/target_model.pth`` every 10 epochs and once more after
        the final epoch.

        Exits the process if the mode or classifier is not supported.
        """
        if self.mode in ('wf', 'detect'):
            "load data"
            if self.mode == 'wf':
                train_data = utils_wf.load_data_main(
                    self.opts['train_data_path'], self.opts['batch_size'])
            else:
                train_data = utils_wf.load_NormData_main(
                    self.opts['train_data_path'], self.opts['batch_size'])

            "load target model structure"
            classifier = self.opts['classifier']
            if classifier == 'lstm':
                params = utils_wf.params_lstm_detect(self.opts['num_class'],
                                                     self.opts['input_size'],
                                                     self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)
            elif classifier == 'rnn':
                params = utils_wf.params_rnn(self.opts['num_class'],
                                             self.opts['input_size'],
                                             self.opts['batch_size'])
                target_model = models.rnn(params).to(self.device)
            elif classifier == 'cnn':
                params = utils_wf.params_cnn_detect(self.opts['num_class'],
                                                    self.opts['input_size'])
                target_model = models.cnn(params).to(self.device)
            elif classifier == 'fcnn':
                params = utils_wf.params_fcnn(self.opts['num_class'],
                                              self.opts['input_size'])
                target_model = models.fcnn(params).to(self.device)
            else:
                # Without this branch target_model would be unbound below.
                print('classifier not in ["lstm","rnn","cnn","fcnn"], '
                      'system will exit.')
                sys.exit()
            target_model.train()

        elif self.mode == 'shs':
            "load data"
            train_data = utils_shs.load_data_main(self.opts['train_data_path'],
                                                  self.opts['batch_size'])

            "load target model structure"
            params = utils_shs.params_lstm(self.opts['num_class'],
                                           self.opts['input_size'],
                                           self.opts['batch_size'])
            target_model = models.lstm(params).to(self.device)
            target_model.train()

        else:
            print('mode not in ["wf","detect","shs"], system will exit.')
            sys.exit()

        targeted_model_path = self.model_path + '/target_model.pth'

        "train process"
        optimizer = torch.optim.Adam(target_model.parameters(),
                                     lr=self.opts['lr'])

        for epoch in range(self.opts['epochs']):
            loss_epoch = 0.0
            for i, (train_x, train_y) in enumerate(train_data):
                train_x = train_x.to(self.device)
                train_y = train_y.to(self.device)

                "batch_first = False"
                # NOTE(review): this checks opts['batch_size'], not a batch_first
                # flag. If batch_size is always a positive int, the transpose
                # never runs -- confirm which option was intended.
                if not self.opts['batch_size']:
                    train_x = train_x.transpose(0, 1)

                optimizer.zero_grad()
                logits_model = target_model(train_x)
                loss_model = F.cross_entropy(logits_model, train_y)
                # Accumulate a Python float so the running sum does not keep
                # every batch's autograd graph alive.
                loss_epoch += loss_model.item()

                # retain_graph kept as in the original implementation.
                # TODO(review): confirm whether the model actually needs it.
                loss_model.backward(retain_graph=True)
                optimizer.step()

                if i % 100 == 0:
                    _, predicted = torch.max(logits_model, 1)
                    correct = (predicted == train_y).sum().item()
                    accuracy = correct / len(train_y)
                    msg = 'Epoch {:5}, Step {:5}, Loss: {:6.2f}, Accuracy:{:8.2%}.'
                    print(msg.format(epoch, i, loss_model.item(), accuracy))

                "empty cache"
                del train_x, train_y
                torch.cuda.empty_cache()

            "save model every 10 epochs"
            if epoch % 10 == 0:
                torch.save(target_model.state_dict(), targeted_model_path)

        # The periodic save above skips the last epochs whenever the epoch
        # count is not 1 more than a multiple of 10, so always persist the
        # final weights.
        torch.save(target_model.state_dict(), targeted_model_path)
Beispiel #10
0
    def ploting(self):
        """Plot clean vs. adversarial traffic for the first few test batches.

        Loads the target model from ``<model_path>/target_model.pth`` and the
        GAN generator from ``<model_path>/adv_generator.pth``. For each of the
        first ``opts['num_figs']`` test batches, it builds adversarial examples
        with FGSM, PGD, DeepFool and the GAN. For every attack it draws:
          * the clean/adversarial pair, via ``utils_shs.single_traffic_plot``;
          * the perturbation, via ``utils_shs.noise_plot``.

        Exits the process if ``self.mode`` or ``self.classifier_type`` is not
        supported.
        """
        "load data and model"
        if self.mode == 'wf':
            "load data"
            test_data = utils_wf.load_data_main(self.opts['test_data_path'],
                                                self.opts['batch_size'])

            "load target model structure"
            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)
            elif self.classifier_type == 'lstm':
                params = utils_wf.params_lstm_eval(self.opts['num_class'],
                                                   self.opts['input_size'],
                                                   self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)
            else:
                # Without this branch target_model would be unbound below.
                print('classifier not in ["cnn","lstm"], system will exit.')
                sys.exit()

        elif self.mode == 'shs':
            "load data"
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            "load target model structure"
            params = utils_shs.params(self.opts['num_class'],
                                      self.opts['input_size'])
            target_model = models.cnn_noNorm(params).to(self.device)

        else:
            print('mode not in ["wf","shs"], system will exit.')
            sys.exit()

        model_name = self.model_path + '/target_model.pth'
        target_model.load_state_dict(
            torch.load(model_name, map_location=self.device))

        "set equal_eval mode of train instead eval for lstm to aviod backward error"
        if self.classifier_type == 'lstm':
            target_model = self.model_reset(target_model)
            target_model.train()
        elif self.classifier_type == 'cnn':
            target_model.eval()

        target_model.to(self.device)

        num_figs = self.opts['num_figs']
        if num_figs <= 0:
            return

        "set adversary"
        fgsm_epsilon = 0.1
        pgd_a = 0.051 if self.mode == 'wf' else 0.01

        # Build the attacks and load the generator once, rather than once per
        # batch and attack.
        pretrained_generator_path = self.model_path + '/adv_generator.pth'
        pretrained_G = models.Generator(self.gen_input_nc,
                                        self.input_nc).to(self.device)
        pretrained_G.load_state_dict(
            torch.load(pretrained_generator_path, map_location=self.device))
        pretrained_G.eval()

        adversaries = {
            'FGSM': FGSM(self.mode,
                         self.x_box_min,
                         self.x_box_max,
                         self.pert_box,
                         epsilon=fgsm_epsilon),
            'PGD': LinfPGDAttack(self.mode,
                                 self.x_box_min,
                                 self.x_box_max,
                                 self.pert_box,
                                 k=5,
                                 a=pgd_a,
                                 random_start=False),
            'DeepFool': DeepFool(self.mode,
                                 self.x_box_min,
                                 self.x_box_max,
                                 self.pert_box,
                                 num_classes=5),
        }

        for i, data in enumerate(test_data):
            # Only the first num_figs batches are plotted; stop reading the
            # dataset once they are done.
            if i >= num_figs:
                break

            for Adversary in ['FGSM', 'PGD', 'DeepFool', 'GAN']:
                x, _ = data
                x = x.to(self.device)

                "generate adversarial examples"
                if Adversary == 'GAN':
                    pert = pretrained_G(x)
                    "cal adv_x given different mode wf/shs"
                    adv_x = utils_gan.get_advX_gan(
                        x,
                        pert,
                        self.mode,
                        pert_box=self.opts['pert_box'],
                        x_box_min=self.opts['x_box_min'],
                        x_box_max=self.opts['x_box_max'],
                        alpha=self.opts['alpha'])
                else:
                    _, y_pred = torch.max(target_model(x), 1)
                    adversary = adversaries[Adversary]
                    adversary.model = target_model
                    "use predicted label to prevent label leaking"
                    _, adv_x = adversary.perturbation(x, y_pred,
                                                      self.opts['alpha'])

                adv_x = adv_x.data.cpu().numpy().squeeze()
                if self.mode == 'shs':
                    # The shs data is rescaled by 1500 and rounded for
                    # plotting. NOTE(review): presumably this undoes a
                    # normalisation done in the loader -- confirm.
                    adv_x = (adv_x * 1500).round()
                # For wf: if the data was L2-normalised it would need
                # inverting here (see utils_wf.normalizer).

                x_np = x.data.cpu().numpy().squeeze()
                pert = adv_x - x_np

                "plot"
                fig_name = self.classifier_type + '_' + str(i)
                utils_shs.single_traffic_plot(fig_name, x_np, adv_x, Adversary)
                utils_shs.noise_plot(fig_name, pert, Adversary)
    def generate(self):
        """Generate adversarial examples for the test and train splits.

        Each adversarial example is paired with its original label y (not the
        perturbed prediction). The per-split work is delegated to
        ``self.get_adv_x``, called in this order:
          * 'test';
          * 'test_UnMon' (open-world 'wf_ow' mode only);
          * 'train'.

        Exits the process if the mode, classifier type or
        ``opts['Adversary']`` is not supported.
        """
        "load data and target model"
        if self.mode in ['wf', 'wf_ow', 'wf_kf', 'wf_sg']:

            "load data"
            if self.mode != 'wf_sg':
                train_data = utils_wf.load_data_main(
                    self.opts['train_data_path'], self.opts['batch_size'])
                test_data = utils_wf.load_data_main(
                    self.opts['test_data_path'], self.opts['batch_size'])
            else:
                input_data = utils_wf.load_data_main(
                    self.opts['input_data_path'], self.opts['batch_size'])
                # Previously `train_data, test_data = []`, which raised
                # ValueError.
                # NOTE(review): input_data is never used below, so 'wf_sg'
                # currently generates nothing -- confirm the intended use.
                train_data, test_data = [], []
            if self.mode == 'wf_ow':
                test_data_UnMon = utils_wf.load_data_main(
                    self.opts['test_data_path_UnMon'], self.opts['batch_size'])

            "load target model structure"
            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                if self.mode == 'wf_ow':
                    params = utils_wf.params_lstm_ow_eval(
                        self.opts['num_class'], self.opts['input_size'],
                        self.opts['batch_size'])
                else:
                    params = utils_wf.params_lstm_eval(self.opts['num_class'],
                                                       self.opts['input_size'],
                                                       self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)
            else:
                print('classifier not in ["cnn","lstm"], system will exit.')
                sys.exit()

        elif self.mode == 'shs':
            "load data"
            train_data = utils_shs.load_data_main(self.opts['train_data_path'],
                                                  self.opts['batch_size'])
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            "load target model structure"
            if self.classifier_type == 'cnn':
                params = utils_shs.params_cnn(self.opts['num_class'],
                                              self.opts['input_size'])
                target_model = models.cnn_noNorm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                params = utils_shs.params_lstm_eval(self.opts['num_class'],
                                                    self.opts['input_size'],
                                                    self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)
            else:
                print('classifier not in ["cnn","lstm"], system will exit.')
                sys.exit()

        else:
            print(
                'mode not in ["wf","shs","wf_ow","wf_kf"], system will exit.')
            sys.exit()

        # load trained model
        if self.mode == 'wf_kf':
            model_name = '/target_model_%d.pth' % self.opts['id']
        else:
            model_name = '/target_model.pth'
        model_path = self.model_path + model_name
        target_model.load_state_dict(
            torch.load(model_path, map_location=self.device))

        "set adversary"
        Adversary = self.opts['Adversary']

        # The old test `self.mode == 'wf' or 'wf_ow' or 'wf_kf'` was always
        # true, so shs wrongly used the wf step size.
        fgsm_epsilon = 0.1
        if self.mode in ('wf', 'wf_ow', 'wf_kf', 'wf_sg'):
            pgd_a = 0.051  # if data un-normalized
            # pgd_a = 0.01    # if data normalized
        else:
            pgd_a = 0.01

        if Adversary == 'GAN':
            if self.mode == 'wf_kf':
                g_name = '/adv_generator_%d.pth' % self.opts['id']
            else:
                g_name = '/adv_generator.pth'
            pretrained_generator_path = self.model_path + g_name
            adversary = models.Generator(self.gen_input_nc,
                                         self.input_nc).to(self.device)
            adversary.load_state_dict(
                torch.load(pretrained_generator_path,
                           map_location=self.device))
            adversary.eval()
        elif Adversary == 'FGSM':
            adversary = FGSM(self.mode,
                             self.x_box_min,
                             self.x_box_max,
                             self.pert_box,
                             epsilon=fgsm_epsilon)
        elif Adversary == 'DeepFool':
            adversary = DeepFool(self.mode,
                                 self.x_box_min,
                                 self.x_box_max,
                                 self.pert_box,
                                 num_classes=5)
        elif Adversary == 'PGD':
            adversary = LinfPGDAttack(self.mode,
                                      self.x_box_min,
                                      self.x_box_max,
                                      self.pert_box,
                                      k=5,
                                      a=pgd_a,
                                      random_start=False)
        else:
            # Without this branch `adversary` would be unbound below.
            print('No Adversary found! System will exit.')
            sys.exit()

        "produce adv_x given differen input data"
        # test data Monitored
        self.get_adv_x(test_data, Adversary, adversary, target_model, 'test')

        # test data Unmonitored in open world setting
        if self.mode == 'wf_ow':
            self.get_adv_x(test_data_UnMon, Adversary, adversary, target_model,
                           'test_UnMon')

        # train data
        self.get_adv_x(train_data, Adversary, adversary, target_model, 'train')
    def train_model(self):
        """Train the CNN target model, save it, and report test accuracy.

        Modes:
          * 'wf', 'detect', 'wf_ow', 'wf_kf' -- use the wf loaders (shuffled)
            and ``models.cnn_norm``.
          * 'shs' -- uses the shs loaders and ``models.cnn_noNorm``.

        Training runs Adam (lr=0.001) on cross-entropy for ``opts['epochs']``
        epochs. Weights are written to ``<model_path>/target_model.pth``
        (``target_model_<id>.pth`` in 'wf_kf' mode) every 10 epochs after the
        first, and once more after training. Test-set accuracy is then
        printed.

        Exits the process if the mode is not supported.
        """
        if self.mode in ['wf', 'detect', 'wf_ow', 'wf_kf']:
            "load data"
            train_data = utils_wf.load_data_main(self.opts['train_data_path'],
                                                 self.opts['batch_size'],
                                                 shuffle=True)
            test_data = utils_wf.load_data_main(self.opts['test_data_path'],
                                                self.opts['batch_size'],
                                                shuffle=True)

            "load target model structure"
            params = utils_wf.params_cnn(self.opts['num_class'],
                                         self.opts['input_size'])
            target_model = models.cnn_norm(params).to(self.device)
            target_model.train()

        elif self.mode == 'shs':
            "load data"
            train_data = utils_shs.load_data_main(self.opts['train_data_path'],
                                                  self.opts['batch_size'])
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            "load target model structure"
            params = utils_shs.params_cnn(self.opts['num_class'],
                                          self.opts['input_size'])
            target_model = models.cnn_noNorm(params).to(self.device)
            target_model.train()

        else:
            print(
                'mode not in ["wf","shs","wf_ow","detect","wf_kf"], system will exit.'
            )
            sys.exit()

        if self.mode == 'wf_kf':
            model_name = '/target_model_%d.pth' % self.opts['id']
        else:
            model_name = '/target_model.pth'
        targeted_model_path = self.model_path + model_name

        "train process"
        optimizer = torch.optim.Adam(target_model.parameters(), lr=0.001)
        for epoch in range(self.opts['epochs']):
            loss_epoch = 0.0
            for i, (train_x, train_y) in enumerate(train_data):
                train_x = train_x.to(self.device)
                train_y = train_y.to(self.device)
                optimizer.zero_grad()
                logits_model = target_model(train_x)
                loss_model = F.cross_entropy(logits_model, train_y)
                # Use a Python float so the graph of every batch is not
                # kept alive.
                loss_epoch += loss_model.item()

                loss_model.backward()
                optimizer.step()

                if i % 100 == 0:
                    _, predicted = torch.max(logits_model, 1)
                    correct = (predicted == train_y).sum().item()
                    accuracy = correct / len(train_y)
                    msg = 'Epoch {:5}, Step {:5}, Loss: {:6.2f}, Accuracy:{:8.2%}.'
                    print(msg.format(epoch, i, loss_model.item(), accuracy))

            "save model every 10 epochs"
            if epoch != 0 and epoch % 10 == 0:
                torch.save(target_model.state_dict(), targeted_model_path)

        # With only the periodic rule, runs of <= 10 epochs never saved a
        # model and longer runs lost their last epochs; always save the
        # final weights.
        torch.save(target_model.state_dict(), targeted_model_path)

        "test target model"
        target_model.eval()

        num_correct = 0
        total_instances = 0
        with torch.no_grad():
            for test_x, test_y in test_data:
                test_x = test_x.to(self.device)
                test_y = test_y.to(self.device)
                pred_lab = torch.argmax(target_model(test_x), 1)
                num_correct += (pred_lab == test_y).sum().item()
                total_instances += len(test_y)

        # Guard against an empty test set (previously ZeroDivisionError).
        accuracy = num_correct / total_instances if total_instances else 0.0
        print('accuracy of target model against test dataset: %f\n' % accuracy)
    def test_peformance(self):
        """Evaluate the saved target model on clean test data.

        Modes:
          * 'wf', 'detect', 'wf_ow', 'wf_kf' -- use the wf test loader; the
            architecture comes from ``self.classifier_type`` ('cnn' or
            'lstm').
          * 'shs' -- uses ``models.cnn_noNorm``.

        Weights are read from ``<model_path>/target_model.pth``
        (``target_model_<id>.pth`` in 'wf_kf' mode). The method prints the
        confusion matrix and the accuracy.

        Exits the process if the mode or classifier type is not supported.
        """
        "load data and target model"
        # The old test `self.mode == 'wf' or 'detect' or ...` was always true
        # (making 'shs' unreachable) and misspelled 'wf_kf' as 'wf_tf'.
        if self.mode in ('wf', 'detect', 'wf_ow', 'wf_kf'):
            "load data"
            test_data = utils_wf.load_data_main(self.opts['test_data_path'],
                                                self.opts['batch_size'])

            "load target model structure"
            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                if self.mode == 'wf_ow':
                    params = utils_wf.params_lstm_ow(self.opts['num_class'],
                                                     self.opts['input_size'],
                                                     self.opts['batch_size'])
                else:
                    params = utils_wf.params_lstm(self.opts['num_class'],
                                                  self.opts['input_size'],
                                                  self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)
            else:
                print('classifier not in ["cnn","lstm"], system will exit.')
                sys.exit()

        elif self.mode == 'shs':
            "load data"
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            "load target model structure"
            params = utils_shs.params(self.opts['num_class'],
                                      self.opts['input_size'])
            target_model = models.cnn_noNorm(params).to(self.device)

        else:
            print('mode not in ["wf","shs"], system will exit.')
            sys.exit()

        if self.mode == 'wf_kf':
            model_name = '/target_model_%d.pth' % self.opts['id']
        else:
            model_name = '/target_model.pth'
        model_path = self.model_path + model_name
        print(model_path)
        target_model.load_state_dict(
            torch.load(model_path, map_location=self.device))
        target_model.eval()

        "testing process..."
        y_test = []
        y_pred = []
        with torch.no_grad():
            for test_x, test_y in test_data:
                test_x = test_x.to(self.device)
                test_y = test_y.to(self.device)
                "add softmax after fc of model to normalize output as positive values and sum =1"
                pred_lab = torch.argmax(torch.softmax(target_model(test_x), 1), 1)

                "save result"
                y_test.extend(test_y.cpu().numpy().tolist())
                y_pred.extend(pred_lab.cpu().numpy().tolist())

        print('confusion matrix is {}'.format(confusion_matrix(y_test,
                                                               y_pred)))
        print('accuracy is {}'.format(metrics.accuracy_score(y_test, y_pred)))
    def test_model(self):
        """Report the target model's accuracy before and after an attack.

        Models and weights:
          * ``self.target_model_type`` chooses between
            ``<model_path>/target_model.pth`` and
            ``<model_path>/adv_target_model_<Adversary>.pth``.
          * ``opts['Adversary']`` ('FGSM', 'DeepFool', 'PGD' or 'GAN') is the
            attack applied to each test batch.
          * The GAN generator is read from ``<model_path>/adv_generator.pth``.

        Gradient-based attacks are given the model's predicted labels rather
        than the ground truth, to avoid label leaking.

        The method prints:
          * the number of correct predictions after the attack;
          * the total number of test instances;
          * the accuracy after the attack and the attack success rate;
          * the clean accuracy.

        Exits the process if the mode, classifier type, target model type or
        adversary is not supported.
        """
        "load data and model"
        if self.mode == 'wf':
            "load data"
            test_data = utils_wf.load_data_main(self.opts['test_data_path'],
                                                self.opts['batch_size'])

            "load target model structure"
            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                params = utils_wf.params_lstm_eval(self.opts['num_class'],
                                                   self.opts['input_size'],
                                                   self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)
            else:
                # Without this branch target_model would be unbound below.
                print('classifier not in ["cnn","lstm"], system will exit.')
                sys.exit()

        elif self.mode == 'shs':
            "load data"
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            "load target model structure"
            params = utils_shs.params(self.opts['num_class'],
                                      self.opts['input_size'])
            target_model = models.cnn_noNorm(params).to(self.device)

        else:
            print('mode not in ["wf","shs"], system will exit.')
            sys.exit()

        if self.target_model_type == 'adv_target_model':
            model_name = self.model_path + '/adv_target_model_' + self.opts[
                'Adversary'] + '.pth'
        elif self.target_model_type == 'target_model':
            model_name = self.model_path + '/target_model.pth'
        else:
            print(
                'target model type not in ["target_model","adv_target_model"], system will exit.'
            )
            sys.exit()

        target_model.load_state_dict(
            torch.load(model_name, map_location=self.device))

        "set equal_eval mode of train instead eval for lstm"
        if self.classifier_type == 'lstm':
            target_model = self.model_reset(target_model)
            target_model.train()
        elif self.classifier_type == 'cnn':
            target_model.eval()

        target_model.to(self.device)

        "set adversary"
        Adversary = self.opts['Adversary']

        # mode is guaranteed to be 'wf' or 'shs' at this point.
        fgsm_epsilon = 0.1
        pgd_a = 0.051 if self.mode == 'wf' else 0.01

        if Adversary == 'GAN':
            pretrained_generator_path = self.model_path + '/adv_generator.pth'
            pretrained_G = models.Generator(self.gen_input_nc,
                                            self.input_nc).to(self.device)
            pretrained_G.load_state_dict(
                torch.load(pretrained_generator_path,
                           map_location=self.device))
            pretrained_G.eval()
        elif Adversary == 'FGSM':
            adversary = FGSM(self.mode,
                             self.x_box_min,
                             self.x_box_max,
                             self.pert_box,
                             epsilon=fgsm_epsilon)
        elif Adversary == 'DeepFool':
            adversary = DeepFool(self.mode,
                                 self.x_box_min,
                                 self.x_box_max,
                                 self.pert_box,
                                 num_classes=5)
        elif Adversary == 'PGD':
            adversary = LinfPGDAttack(self.mode,
                                      self.x_box_min,
                                      self.x_box_max,
                                      self.pert_box,
                                      k=5,
                                      a=pgd_a,
                                      random_start=False)
        else:
            # Previously an unknown adversary surfaced as a NameError inside
            # the test loop.
            print('No Adversary found! System will exit.')
            sys.exit()

        "test on adversarial examples"
        num_correct = 0
        correct_x = 0
        total_case = 0
        for test_x, test_y in test_data:
            test_x = test_x.to(self.device)
            test_y = test_y.to(self.device)

            "prediction on original input x"
            _, pred_y = torch.max(target_model(test_x), 1)

            "prediction on adversarial x"
            if Adversary == 'GAN':
                pert = pretrained_G(test_x)
                adv_x = utils_gan.get_advX_gan(test_x, pert, self.mode,
                                               self.pert_box, self.x_box_min,
                                               self.x_box_max,
                                               self.opts['alpha'])
                adv_y = torch.argmax(target_model(adv_x.to(self.device)), 1)
            else:
                adversary.model = target_model
                "use predicted label to prevent label leaking"
                adv_y, adv_x = adversary.perturbation(test_x, pred_y,
                                                      self.opts['alpha'])

            num_correct += (adv_y == test_y).sum().item()
            correct_x += (pred_y == test_y).sum().item()
            total_case += len(test_y)

        if total_case == 0:
            # Avoid ZeroDivisionError on an empty test set.
            print('test dataset is empty, nothing to evaluate.')
            return

        acc = float(num_correct) / float(total_case)

        print('*' * 30)
        print('"{}" with {} against {}.'.format(self.mode, Adversary,
                                                self.target_model_type))
        print('correct test after attack is {}'.format(num_correct))
        print('total test instances is {}'.format(total_case))
        print(
            'accuracy of test after {} attack : correct/total= {:.5f}'.format(
                Adversary, acc))
        print('success rate of the attack is : {}'.format(1 - acc))
        print('accucary of the model without being attacked is {:.5f}'.format(
            float(correct_x) / float(total_case)))
        print('\n')