def adv_train_process(self,delay=0.5):
        """Adversarially fine-tune the target classifier and save its weights.

        Clean training mini-batches are zipped with mini-batches of pre-generated
        adversarial examples loaded from ``adv_train_data_path`` (iteration stops
        at the shorter of the two loaders). While ``epoch + 1`` is below
        ``int((1 - delay) * opts['epochs'])`` only the clean cross-entropy loss
        is optimised; from then on the loss is the mean of the clean and the
        adversarial cross-entropy losses.

        The model's ``state_dict`` is written to ``self.model_path`` every 10
        epochs and after the final epoch. Nothing is returned.

        Args:
            delay: approximate fraction of the epochs, counted from the end,
                that include adversarial examples (0.5 -> roughly the second
                half of training).
        """
        # ---- resolve adversarial training data location ----
        if self.opts['Adversary'] == 'WT':
            # Walkie-Talkie defended traces live at a fixed location
            adv_train_data_path = '../data/WalkieTalkie/defended_csv/adv_train_WT.csv'
        else:
            adv_train_data_path = self.data_path + self.classifier_type + '/' + self.opts['adv_train_data_path']
        print('adv train data path: ',adv_train_data_path)

        target_model = None
        if self.mode in ['wf','wf_ow','wf_kf']:
            # ---- load data ----
            if self.mode != 'wf_kf':
                train_path = self.data_path + self.opts['train_data_path']
            else:
                # k-fold mode trains on a fixed undefended burst dataset
                train_path = '../data/wf/train_NoDef_burst.csv'
            train_data = utils_wf.load_data_main(train_path,self.opts['batch_size'],shuffle=True)
            print('train data path: ',train_path)

            adv_train_data = utils_wf.load_data_main(adv_train_data_path,self.opts['batch_size'],shuffle=True)

            # ---- build target model structure ----
            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'], self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)
            elif self.classifier_type == 'lstm':
                if self.mode == 'wf_ow':
                    params = utils_wf.params_lstm_ow(self.opts['num_class'], self.opts['input_size'],self.opts['batch_size'])
                else:
                    params = utils_wf.params_lstm(self.opts['num_class'], self.opts['input_size'],self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        elif self.mode == 'shs':
            # ---- load data ----
            # (the test split is not needed for training, so it is not loaded)
            train_data = utils_shs.load_data_main(self.data_path + self.opts['train_data_path'],self.opts['batch_size'])
            adv_train_data = utils_shs.load_data_main(adv_train_data_path, self.opts['batch_size'])

            # ---- build target model structure ----
            if self.classifier_type == 'cnn':
                params = utils_shs.params_cnn(self.opts['num_class'], self.opts['input_size'])
                target_model = models.cnn_noNorm(params).to(self.device)
            elif self.classifier_type == 'lstm':
                params = utils_shs.params_lstm(self.opts['num_class'], self.opts['input_size'], self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        else:
            print('mode not in ["wf","wf_ow","wf_kf","shs"], system will exit.')
            sys.exit()

        # Without this guard an unknown classifier type surfaced as a NameError.
        if target_model is None:
            print('classifier type should in [lstm,cnn]. System will exit!')
            sys.exit()
        target_model.train()

        loss_criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(target_model.parameters(), lr=1e-3)

        # ---- training loop ----
        num_epochs = self.opts['epochs']
        # adversarial batches are mixed in once (epoch + 1) reaches this value
        adv_start = int((1 - delay) * num_epochs)
        steps = 0
        start_time = datetime.now()
        for epoch in range(num_epochs):
            print('Starting epoch %d / %d' % (epoch + 1, num_epochs))

            use_adv = epoch + 1 >= adv_start
            if use_adv:
                print('{} based adversarial training...'.format(self.opts['Adversary']))

            for (x,y),(x_adv,y_adv) in zip(train_data,adv_train_data):
                steps += 1
                optimizer.zero_grad()

                x,y = x.to(self.device), y.to(self.device)
                loss = loss_criterion(target_model(x),y)

                if use_adv:
                    # average clean and adversarial losses
                    x_adv, y_adv = x_adv.to(self.device), y_adv.to(self.device)
                    loss_adv = loss_criterion(target_model(x_adv),y_adv)
                    loss = (loss + loss_adv) / 2

                # print progress every 100 steps
                if steps % 100 == 0:
                    # total_seconds() does not wrap after one day like .seconds
                    time_diff = int((datetime.now() - start_time).total_seconds())
                    minutes, seconds = divmod(time_diff, 60)
                    time_usage = '{:3}m{:3}s'.format(minutes, seconds)
                    msg = "Step {:5}, Loss:{:6.2f}, Time usage:{:9}."
                    print(msg.format(steps, loss.item(), time_usage))

                loss.backward()
                optimizer.step()

            # Save every 10 epochs and always after the last epoch. ``epoch`` is
            # 0-based, so the last epoch is num_epochs - 1 (comparing against
            # num_epochs itself never matched and the final model was lost).
            if (epoch != 0 and epoch % 10 == 0) or epoch == num_epochs - 1:
                if self.mode != 'wf_kf':
                    output_name = '/adv_target_model_%s.pth' % self.opts['Adversary']
                else:
                    output_name = '/adv_target_model_%s_%d.pth' % (self.opts['Adversary'], self.opts['id'])
                output_path = self.model_path + output_name
                torch.save(target_model.state_dict(), output_path)
    def test_peformance(self, threshold_val):
        """Evaluate the open-world ('wf_ow') target model at one confidence threshold.

        Every monitored and unmonitored test instance is classified. A prediction
        counts as "monitored" when its class occurs among the monitored test
        labels AND its max softmax probability is >= ``threshold_val``; a
        prediction whose class occurs among the unmonitored test labels counts
        as "unmonitored". TP/FP/TN/FN, TPR, FPR, precision and recall are printed
        and appended as one CSV line to ``self.log_file``.

        Ratios whose denominator is zero (e.g. no monitored predictions at a very
        high threshold) are reported as 0.0 instead of raising
        ``ZeroDivisionError``, so threshold sweeps run to completion.

        Args:
            threshold_val: minimum max-softmax probability for a prediction of a
                monitored class to be accepted as monitored.
        """
        if self.mode != 'wf_ow':
            print('mode not in ["wf_ow"], system will exit.')
            sys.exit()

        # ---- load test data (undefended or defended traces) ----
        if self.opts['test_data_type'] == 'NoDef':
            mon_path = self.opts['test_data_Mon_path']
            unmon_path = self.opts['test_data_UnMon_path']
        elif self.opts['test_data_type'] == 'Def':
            mon_path = self.opts['adv_test_data_Mon_path']
            unmon_path = self.opts['adv_test_data_UnMon_path']
        else:
            print('test_data_type not in ["NoDef","Def"], system will exit.')
            sys.exit()
        test_data_Mon = utils_wf.load_data_main(mon_path, self.opts['batch_size'])
        test_data_UnMon = utils_wf.load_data_main(unmon_path,
                                                  self.opts['batch_size'])

        # ---- build target model structure ----
        target_model = None
        if self.classifier_type == 'cnn':
            params = utils_wf.params_cnn(self.opts['num_class'],
                                         self.opts['input_size'])
            target_model = models.cnn_norm(params).to(self.device)
        elif self.classifier_type == 'lstm':
            params = utils_wf.params_lstm_ow_eval(self.opts['num_class'],
                                                  self.opts['input_size'],
                                                  self.opts['batch_size'])
            target_model = models.lstm(params).to(self.device)
        if target_model is None:
            print('classifier type should in [lstm,cnn]. System will exit!')
            sys.exit()

        # ---- load trained weights ----
        if self.model_type == 'adv_target_model':
            model_name = self.model_path + '/adv_target_model_' + self.opts[
                'Adversary'] + '.pth'
        elif self.model_type == 'target_model':
            model_name = self.model_path + '/target_model.pth'
        else:
            print(
                'target model type not in ["target_model","adv_target_model"], system will exit.'
            )
            sys.exit()
        target_model.load_state_dict(
            torch.load(model_name, map_location=self.device))
        target_model.eval()

        # report the paths that were actually loaded
        print('target model %s' % model_name)
        print('test data path unmonitored %s ' % unmon_path)
        print('test data path monitored %s ' % mon_path)

        # ---- monitored / unmonitored label values ----
        # Labels come from the undefended CSVs for both data types. They are
        # converted exactly as before (float tensor -> long) and flattened into
        # sets, giving O(1) membership instead of scanning a tensor per prediction.
        _, y_test_Mon = utils_wf.load_csv_data(self.opts['test_data_Mon_path'])
        _, y_test_UnMon = utils_wf.load_csv_data(
            self.opts['test_data_UnMon_path'])
        monitored_labels = set(
            torch.Tensor(y_test_Mon).long().reshape(-1).tolist())
        unmonitored_labels = set(
            torch.Tensor(y_test_UnMon).long().reshape(-1).tolist())

        TP = 0
        FP = 0
        TN = 0
        FN = 0
        with torch.no_grad():  # inference only: no autograd graph needed
            # -- monitored instances (true class is monitored) --
            for test_x, _ in test_data_Mon:
                max_probs, pred_labels = torch.max(
                    torch.softmax(target_model(test_x.to(self.device)), 1), 1)
                # compare on the tensor so the threshold is applied exactly as before
                confident = (max_probs >= threshold_val).tolist()
                for pred_label, is_confident in zip(pred_labels.tolist(),
                                                    confident):
                    if pred_label in monitored_labels:
                        if is_confident:
                            TP += 1
                        else:  # below threshold -> treated as unmonitored
                            FN += 1
                    elif pred_label in unmonitored_labels:
                        FN += 1

            # -- unmonitored instances (true class is unmonitored) --
            for test_x, _ in test_data_UnMon:
                max_probs, pred_labels = torch.max(
                    torch.softmax(target_model(test_x.to(self.device)), 1), 1)
                confident = (max_probs >= threshold_val).tolist()
                for pred_label, is_confident in zip(pred_labels.tolist(),
                                                    confident):
                    if pred_label in unmonitored_labels:
                        TN += 1
                    elif pred_label in monitored_labels:
                        if is_confident:
                            FP += 1
                        else:  # below threshold -> treated as unmonitored
                            TN += 1

        # ---- metrics (0.0 when a denominator is zero) ----
        TPR = float(TP) / (TP + FN) if TP + FN else 0.0
        FPR = float(FP) / (FP + TN) if FP + TN else 0.0
        Precision = float(TP) / (TP + FP) if TP + FP else 0.0
        Recall = TPR  # recall and TPR are the same ratio, TP / (TP + FN)

        print("TP : ", TP)
        print("FP : ", FP)
        print("TN : ", TN)
        print("FN : ", FN)
        print("Total  : ", TP + FP + TN + FN)
        print("TPR : ", TPR)
        print("FPR : ", FPR)
        print("Precision : ", Precision)
        print("Recall : ", Recall)
        print("\n")

        self.log_file.write(
            "%.6f,%d,%d,%d,%d,%.6f,%.6f,%.6f,%.6f\n" %
            (threshold_val, TP, FP, TN, FN, TPR, FPR, Precision, Recall))
    def train(self,train_dataloader):
        """Train the GAN generator/discriminator against a pretrained target model.

        The target classifier matching ``self.mode`` / ``self.classifier_type`` is
        rebuilt and its weights are loaded from ``self.model_path``. It is then
        passed to ``self.train_batch`` for every mini-batch of
        ``train_dataloader``. Per-epoch mean losses are printed. The generator's
        ``state_dict`` is saved every 10 epochs and after the final epoch.

        Args:
            train_dataloader: iterable of ``(x, y)`` mini-batches that supports
                ``len()`` (used to average the per-epoch losses).
        """
        # ---- build target model structure ----
        # The old check ``self.mode == 'wf' or 'wf_ow' or 'wf_kf'`` was always
        # truthy, so 'shs' silently got the wf model. 'wf_sg' (a mode accepted by
        # generate()) previously fell into the wf branch, so it is kept there.
        target_model = None
        if self.mode in ('wf', 'wf_ow', 'wf_kf', 'wf_sg'):
            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'], self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                # eval: lstm can't work in eval mode, so dropout is set to 0 to
                # make the train-mode model behave like an eval-mode model
                if self.mode == 'wf_ow':
                    params = utils_wf.params_lstm_ow_eval(self.opts['num_class'], self.opts['input_size'],self.opts['batch_size'])
                else:
                    params = utils_wf.params_lstm_eval(self.opts['num_class'], self.opts['input_size'], self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        elif self.mode == 'shs':
            if self.classifier_type == 'cnn':
                params = utils_shs.params_cnn(self.opts['num_class'], self.opts['input_size'])
                target_model = models.cnn_noNorm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                # same dropout-0 trick as above
                params = utils_shs.params_lstm_eval(self.opts['num_class'], self.opts['input_size'], self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        else:
            print('mode not in ["wf","shs","wf_ow","wf_kf","wf_sg"], system will exit.')
            sys.exit()

        if target_model is None:
            print('classifier type should in [lstm,cnn]. System will exit!')
            sys.exit()

        # ---- load pretrained weights ----
        if self.mode == 'wf_kf':
            model_name = '/target_model_%d.pth' % self.opts['id']
        else:
            model_name = '/target_model.pth'
        model_path = self.model_path + model_name
        target_model.load_state_dict(torch.load(model_path, map_location=self.device))

        # lstm: train mode after model_reset (the "equal_eval" setup); cnn: eval mode
        if self.classifier_type == 'lstm':
            target_model = self.model_reset(target_model)
            target_model.train()
        elif self.classifier_type == 'cnn':
            target_model.eval()

        target_model.to(self.device)

        num_epochs = self.opts['epochs']
        for epoch in range(1, num_epochs + 1):

            # Step-wise learning-rate decay at epochs 50 and 80.
            # NOTE(review): re-creating Adam also discards its moment
            # estimates; confirm that is intended rather than only changing lr.
            if epoch == 50:
                self.optimizer_G = torch.optim.Adam(self.generator.parameters(),
                                                    lr=0.0001)
                self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(),
                                                    lr=0.0001)
            if epoch == 80:
                self.optimizer_G = torch.optim.Adam(self.generator.parameters(),
                                                    lr=0.00001)
                self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(),
                                                    lr=0.00001)
            loss_D_sum = 0
            loss_G_fake_sum = 0
            loss_perturb_sum = 0
            loss_adv_sum = 0
            for train_x, train_y in train_dataloader:
                train_x, train_y = train_x.to(self.device), train_y.to(self.device)

                loss_D_batch, loss_G_fake_batch, loss_perturb_batch, loss_adv_batch = \
                    self.train_batch(train_x, train_y,target_model)
                loss_D_sum += loss_D_batch
                loss_G_fake_sum += loss_G_fake_batch
                loss_perturb_sum += loss_perturb_batch
                loss_adv_sum += loss_adv_batch

            # print per-epoch mean losses (max(..., 1) avoids dividing by zero
            # for an empty loader; the sums are 0 in that case)
            num_batch = max(len(train_dataloader), 1)
            print("epoch %d:\nloss_D: %.3f, loss_G_fake: %.3f,\
             \nloss_perturb: %.3f, loss_adv: %.3f, \n" %
                  (epoch, loss_D_sum / num_batch, loss_G_fake_sum / num_batch,
                   loss_perturb_sum / num_batch, loss_adv_sum / num_batch))

            # save generator every 10 epochs and after the final epoch
            if epoch % 10 == 0 or epoch == num_epochs:
                if self.mode == 'wf_kf':
                    out_model_name = '/adv_generator_%d.pth' % self.opts['id']
                else:
                    out_model_name = '/adv_generator.pth'
                torch.save(self.generator.state_dict(), self.model_path + out_model_name)
# Example No. 4 (original scraped label: "Ejemplo n.º 4")
# 0
    def test_model(self):
        """Measure target-model accuracy on clean and on adversarial test data.

        The clean test set and an adversarial test set are iterated in lockstep,
        stopping at the shorter of the two. The adversarial set's location is
        chosen from ``opts['cross_validation']`` / ``opts['Adversary']``. Clean
        accuracy, adversarial accuracy and the attack success rate
        (1 - adversarial accuracy) are printed.
        """
        # ---- resolve adversarial test data path ----
        if self.opts['cross_validation']:
            # classifier cross validation: the lstm is tested on data from the
            # cnn directory and vice versa
            if self.classifier_type == 'lstm':
                adv_test_data_path = '../data/' + self.mode + '/cnn/' + self.opts[
                    'adv_test_data_path']
            elif self.classifier_type == 'cnn':
                adv_test_data_path = '../data/' + self.mode + '/lstm/' + self.opts[
                    'adv_test_data_path']
            else:
                print(
                    'classifier type should in [lstm,cnn]. System will exit!')
                sys.exit()
        elif self.opts['Adversary'] == 'WT':
            # Walkie-Talkie defended traces live at a fixed location
            adv_test_data_path = '../data/WalkieTalkie/defended_csv/adv_test_WT.csv'
        else:
            adv_test_data_path = self.data_path + '/' + self.opts[
                'adv_test_data_path']

        print('adv test data path:  ', adv_test_data_path)

        # ---- load data and build target model structure ----
        # The old check ``self.mode == 'wf' or 'wf_ow' or 'wf_kf'`` was always
        # truthy, so the shs branch was unreachable. 'wf_sg' previously fell
        # into the wf branch and is kept there.
        target_model = None
        if self.mode in ('wf', 'wf_ow', 'wf_kf', 'wf_sg'):
            if self.mode == 'wf_kf':
                test_path = '../data/wf/cross_val/' + self.opts['test_data_path']
            else:
                test_path = '../data/wf/' + self.opts['test_data_path']
            test_data = utils_wf.load_data_main(test_path,
                                                self.opts['batch_size'])
            adv_test_data = utils_wf.load_data_main(adv_test_data_path,
                                                    self.opts['batch_size'])
            print('test data path: ', test_path)

            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                if self.mode == 'wf_ow':
                    params = utils_wf.params_lstm_ow_eval(
                        self.opts['num_class'], self.opts['input_size'],
                        self.opts['batch_size'])
                else:
                    params = utils_wf.params_lstm_eval(self.opts['num_class'],
                                                       self.opts['input_size'],
                                                       self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        elif self.mode == 'shs':
            test_data = utils_shs.load_data_main(
                '../data/' + self.mode + '/' + self.opts['test_data_path'],
                self.opts['batch_size'])
            adv_test_data = utils_shs.load_data_main(adv_test_data_path,
                                                     self.opts['batch_size'])

            if self.classifier_type == 'cnn':
                params = utils_shs.params_cnn(self.opts['num_class'],
                                              self.opts['input_size'])
                target_model = models.cnn_noNorm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                params = utils_shs.params_lstm_eval(self.opts['num_class'],
                                                    self.opts['input_size'],
                                                    self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        else:
            print('mode not in ["wf","wf_ow","wf_kf","wf_sg","shs"], system will exit.')
            sys.exit()

        if target_model is None:
            print('classifier type should in [lstm,cnn]. System will exit!')
            sys.exit()

        # ---- resolve and load checkpoint ----
        if self.model_type == 'adv_target_model':
            if self.mode == 'wf_kf':
                model_name = self.model_path + '/adv_target_model_%s_%d.pth' % (
                    self.opts['Adversary'], self.opts['id'])
            else:
                model_name = self.model_path + '/adv_target_model_' + self.opts[
                    'Adversary'] + '.pth'
        elif self.model_type == 'target_model':
            if self.mode == 'wf_kf':
                model_name = self.model_path + '/target_model_%d.pth' % self.opts[
                    'id']
            else:
                model_name = self.model_path + '/target_model.pth'
        else:
            print(
                'target model type not in ["target_model","adv_target_model"], system will exit.'
            )
            sys.exit()
        print('model path: ', model_name)

        target_model.load_state_dict(
            torch.load(model_name, map_location=self.device))
        target_model.eval()

        # ---- evaluate on clean and adversarial inputs ----
        correct_adv_x = 0
        correct_x = 0
        total_case = 0
        with torch.no_grad():  # inference only: no autograd graph needed
            for (x, y), (adv_x, adv_y) in zip(test_data, adv_test_data):
                x, y = x.to(self.device), y.to(self.device)
                adv_x, adv_y = adv_x.to(self.device), adv_y.to(self.device)

                # prediction on original input x
                _, pred_x = torch.max(target_model(x), 1)
                correct_x += (pred_x == y).sum().item()

                # prediction on adv_x
                _, pred_adv_x = torch.max(target_model(adv_x), 1)
                correct_adv_x += (pred_adv_x == adv_y).sum().item()

                total_case += len(y)

        if total_case == 0:
            # previously crashed with AttributeError / ZeroDivisionError
            print('no test instances found, system will exit.')
            sys.exit()

        acc_x = float(correct_x) / float(total_case)
        acc_adv = float(correct_adv_x) / float(total_case)

        print('*' * 30)
        print('"{}" with {} against {}.'.format(self.mode,
                                                self.opts['Adversary'],
                                                self.model_type))
        print('correct test after attack is {}'.format(correct_adv_x))
        print('total test instances is {}'.format(total_case))
        print(
            'accuracy of test after {} attack : correct/total= {:.5f}'.format(
                self.opts['Adversary'], acc_adv))
        print('success rate of the attack is : {}'.format(1 - acc_adv))
        print('accucary of the model without being attacked is {:.5f}'.format(
            acc_x))
        print('\n')
# Example No. 5 (original scraped label: "Ejemplo n.º 5")
# 0
    def ploting(self):
        """Plot clean vs. adversarial traffic and the added perturbation.

        For each of the first ``opts['num_figs']`` test batches and each of the
        adversaries FGSM, PGD, DeepFool and GAN, an adversarial version of the
        batch is produced against the pretrained target model. The results are
        passed to ``utils_shs.single_traffic_plot`` and ``utils_shs.noise_plot``.
        Only modes 'wf' and 'shs' are supported.
        """
        # ---- load data and build target model structure ----
        target_model = None
        if self.mode == 'wf':
            test_data = utils_wf.load_data_main(self.opts['test_data_path'],
                                                self.opts['batch_size'])

            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                params = utils_wf.params_lstm_eval(self.opts['num_class'],
                                                   self.opts['input_size'],
                                                   self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        elif self.mode == 'shs':
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            # Build the same structure as the shs paths of train/test_model
            # (previously a cnn was always built via ``utils_shs.params``,
            # ignoring classifier_type).
            if self.classifier_type == 'cnn':
                params = utils_shs.params_cnn(self.opts['num_class'],
                                              self.opts['input_size'])
                target_model = models.cnn_noNorm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                params = utils_shs.params_lstm_eval(self.opts['num_class'],
                                                    self.opts['input_size'],
                                                    self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        else:
            print('mode not in ["wf","shs"], system will exit.')
            sys.exit()

        if target_model is None:
            print('classifier type should in [lstm,cnn]. System will exit!')
            sys.exit()

        model_name = self.model_path + '/target_model.pth'
        target_model.load_state_dict(
            torch.load(model_name, map_location=self.device))

        # lstm: train mode after model_reset instead of eval, to avoid a backward
        # error; cnn: eval mode
        if self.classifier_type == 'lstm':
            target_model = self.model_reset(target_model)
            target_model.train()
        elif self.classifier_type == 'cnn':
            target_model.eval()

        target_model.to(self.device)

        # ---- attack hyper-parameters (loop invariant) ----
        fgsm_epsilon = 0.1
        pgd_a = 0.051 if self.mode == 'wf' else 0.01
        pretrained_G = None  # GAN generator, loaded from disk on first use only
        num_figs = self.opts['num_figs']

        for i, data in enumerate(test_data):
            if i >= num_figs:
                break  # nothing more to plot; no need to scan remaining batches

            for Adversary in ['FGSM', 'PGD', 'DeepFool', 'GAN']:
                # re-fetch x for every adversary (as before) so no attack sees
                # another attack's input tensor
                x, y = data
                x, y = x.to(self.device), y.to(self.device)

                # ---- generate adversarial examples ----
                if Adversary == 'GAN':
                    if pretrained_G is None:
                        pretrained_generator_path = self.model_path + '/adv_generator.pth'
                        pretrained_G = models.Generator(
                            self.gen_input_nc, self.input_nc).to(self.device)
                        pretrained_G.load_state_dict(
                            torch.load(pretrained_generator_path,
                                       map_location=self.device))
                        pretrained_G.eval()
                    pert = pretrained_G(x)
                    # compute adv_x according to mode wf/shs
                    adv_x = utils_gan.get_advX_gan(
                        x,
                        pert,
                        self.mode,
                        pert_box=self.opts['pert_box'],
                        x_box_min=self.opts['x_box_min'],
                        x_box_max=self.opts['x_box_max'],
                        alpha=self.opts['alpha'])
                else:
                    if Adversary == 'FGSM':
                        adversary = FGSM(self.mode,
                                         self.x_box_min,
                                         self.x_box_max,
                                         self.pert_box,
                                         epsilon=fgsm_epsilon)
                    elif Adversary == 'DeepFool':
                        adversary = DeepFool(self.mode,
                                             self.x_box_min,
                                             self.x_box_max,
                                             self.pert_box,
                                             num_classes=5)
                    else:  # 'PGD'
                        adversary = LinfPGDAttack(self.mode,
                                                  self.x_box_min,
                                                  self.x_box_max,
                                                  self.pert_box,
                                                  k=5,
                                                  a=pgd_a,
                                                  random_start=False)

                    _, y_pred = torch.max(target_model(x), 1)
                    adversary.model = target_model
                    # use the predicted label to prevent label leaking
                    _, adv_x = adversary.perturbation(
                        x, y_pred, self.opts['alpha'])

                # ---- convert to numpy for plotting ----
                adv_np = adv_x.data.cpu().numpy().squeeze()
                if self.mode == 'shs':
                    # NOTE(review): only adv_x is rescaled by 1500 here, not
                    # x_np, so the plotted perturbation mixes scales; confirm
                    # this is intended.
                    adv_np = (adv_np * 1500).round()

                x_np = x.data.cpu().numpy().squeeze()
                pert = adv_np - x_np

                # ---- plot ----
                utils_shs.single_traffic_plot(
                    self.classifier_type + '_' + str(i), x_np, adv_np,
                    Adversary)
                utils_shs.noise_plot(self.classifier_type + '_' + str(i),
                                     pert, Adversary)
    def generate(self):
        "generate adv_x given x, append with its original label y instead with y_pert "

        "load data and target model"
        if self.mode in ['wf', 'wf_ow', 'wf_kf', 'wf_sg']:

            "load data"
            if self.mode != 'wf_sg':
                train_data = utils_wf.load_data_main(
                    self.opts['train_data_path'], self.opts['batch_size'])
                test_data = utils_wf.load_data_main(
                    self.opts['test_data_path'], self.opts['batch_size'])
            else:
                input_data = utils_wf.load_data_main(
                    self.opts['input_data_path'], self.opts['batch_size'])
                train_data, test_data = []
            if self.mode == 'wf_ow':
                test_data_UnMon = utils_wf.load_data_main(
                    self.opts['test_data_path_UnMon'], self.opts['batch_size'])

            "load target model structure"
            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                if self.mode == 'wf_ow':
                    params = utils_wf.params_lstm_ow_eval(
                        self.opts['num_class'], self.opts['input_size'],
                        self.opts['batch_size'])
                else:
                    params = utils_wf.params_lstm_eval(self.opts['num_class'],
                                                       self.opts['input_size'],
                                                       self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        elif self.mode == 'shs':
            "load data"
            train_data = utils_shs.load_data_main(self.opts['train_data_path'],
                                                  self.opts['batch_size'])
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            "load target model structure"
            if self.classifier_type == 'cnn':
                params = utils_shs.params_cnn(self.opts['num_class'],
                                              self.opts['input_size'])
                target_model = models.cnn_noNorm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                params = utils_shs.params_lstm_eval(self.opts['num_class'],
                                                    self.opts['input_size'],
                                                    self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        else:
            print(
                'mode not in ["wf","shs","wf_ow","wf_kf"], system will exit.')
            sys.exit()

        # load trained model
        if self.mode == 'wf_kf':
            model_name = '/target_model_%d.pth' % self.opts['id']
        else:
            model_name = '/target_model.pth'
        model_path = self.model_path + model_name
        target_model.load_state_dict(
            torch.load(model_path, map_location=self.device))

        "set adversary"
        Adversary = self.opts['Adversary']

        if self.mode == 'wf' or 'wf_ow' or 'wf_kf':
            fgsm_epsilon = 0.1
            pgd_a = 0.051  # if data un-normalized
            # pgd_a = 0.01    # if data normalized
        else:
            fgsm_epsilon = 0.1
            pgd_a = 0.01

        if Adversary == 'GAN':
            if self.mode == 'wf_kf':
                g_name = '/adv_generator_%d.pth' % self.opts['id']
            else:
                g_name = '/adv_generator.pth'
            pretrained_generator_path = self.model_path + g_name
            adversary = models.Generator(self.gen_input_nc,
                                         self.input_nc).to(self.device)
            adversary.load_state_dict(
                torch.load(pretrained_generator_path,
                           map_location=self.device))
            adversary.eval()
        elif Adversary == 'FGSM':
            adversary = FGSM(self.mode,
                             self.x_box_min,
                             self.x_box_max,
                             self.pert_box,
                             epsilon=fgsm_epsilon)
        elif Adversary == 'DeepFool':
            adversary = DeepFool(self.mode,
                                 self.x_box_min,
                                 self.x_box_max,
                                 self.pert_box,
                                 num_classes=5)
        elif Adversary == 'PGD':
            adversary = LinfPGDAttack(self.mode,
                                      self.x_box_min,
                                      self.x_box_max,
                                      self.pert_box,
                                      k=5,
                                      a=pgd_a,
                                      random_start=False)

        "produce adv_x given differen input data"
        # test data Monitored
        data_type = 'test'
        self.get_adv_x(test_data, Adversary, adversary, target_model,
                       data_type)

        # test data Unmonitored in open world setting
        if self.mode == 'wf_ow':
            data_type = 'test_UnMon'
            self.get_adv_x(test_data_UnMon, Adversary, adversary, target_model,
                           data_type)

        #train data
        data_type = 'train'
        self.get_adv_x(train_data, Adversary, adversary, target_model,
                       data_type)
    def train_model(self):
        """Train the clean target classifier, checkpoint it and report test accuracy.

        The data loaders and the architecture are selected by ``self.mode``:

        * 'wf', 'detect', 'wf_ow', 'wf_kf': ``utils_wf`` loaders (both
          shuffled) and ``models.cnn_norm`` built from ``utils_wf.params_cnn``.
        * 'shs': ``utils_shs`` loaders and ``models.cnn_noNorm`` built from
          ``utils_shs.params_cnn``.

        Any other mode prints a message and calls ``sys.exit()``.

        Training uses Adam (lr=0.001) with cross-entropy loss for
        ``self.opts['epochs']`` epochs. The model ``state_dict`` is written to
        ``self.model_path + '/target_model.pth'`` (or
        ``'/target_model_<id>.pth'`` in 'wf_kf' mode, with ``<id>`` taken from
        ``self.opts['id']``). It is saved every 10 epochs and after the final
        epoch, so the file on disk matches the model evaluated at the end.
        Finally the accuracy on the test loader is printed.
        """
        if self.mode in ['wf', 'detect', 'wf_ow', 'wf_kf']:
            # load data
            train_data = utils_wf.load_data_main(self.opts['train_data_path'],
                                                 self.opts['batch_size'],
                                                 shuffle=True)
            test_data = utils_wf.load_data_main(self.opts['test_data_path'],
                                                self.opts['batch_size'],
                                                shuffle=True)

            # load target model structure
            params = utils_wf.params_cnn(self.opts['num_class'],
                                         self.opts['input_size'])
            target_model = models.cnn_norm(params).to(self.device)

        elif self.mode == 'shs':
            # load data
            train_data = utils_shs.load_data_main(self.opts['train_data_path'],
                                                  self.opts['batch_size'])
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            # load target model structure
            params = utils_shs.params_cnn(self.opts['num_class'],
                                          self.opts['input_size'])
            target_model = models.cnn_noNorm(params).to(self.device)

        else:
            print(
                'mode not in ["wf","shs","wf_ow","detect","wf_kf"], system will exit.'
            )
            sys.exit()

        target_model.train()

        # Checkpoint location; 'wf_kf' (k-fold) keeps one file per fold id.
        if self.mode == 'wf_kf':
            model_name = '/target_model_%d.pth' % self.opts['id']
        else:
            model_name = '/target_model.pth'
        targeted_model_path = self.model_path + model_name

        # train process
        num_epochs = self.opts['epochs']
        optimizer = torch.optim.Adam(target_model.parameters(), lr=0.001)
        for epoch in range(num_epochs):
            for i, (train_x, train_y) in enumerate(train_data):
                train_x = train_x.to(self.device)
                train_y = train_y.to(self.device)

                optimizer.zero_grad()
                logits_model = target_model(train_x)
                loss_model = F.cross_entropy(logits_model, train_y)
                loss_model.backward()
                optimizer.step()

                if i % 100 == 0:
                    _, predicted = torch.max(logits_model, 1)
                    # .item() yields plain Python numbers, so no tensor (or
                    # autograd graph) is kept alive by the logging code.
                    correct = (predicted == train_y).sum().item()
                    accuracy = correct / len(train_y)
                    msg = 'Epoch {:5}, Step {:5}, Loss: {:6.2f}, Accuracy:{:8.2%}.'
                    print(msg.format(epoch, i, loss_model.item(), accuracy))

            # Save every 10 epochs, and always after the last epoch so the
            # checkpoint matches the model tested below.
            is_last_epoch = epoch == num_epochs - 1
            if is_last_epoch or (epoch != 0 and epoch % 10 == 0):
                torch.save(target_model.state_dict(), targeted_model_path)

        # test target model
        target_model.eval()

        num_correct = 0
        total_instances = 0
        with torch.no_grad():  # inference only: no autograd bookkeeping
            for test_x, test_y in test_data:
                test_x, test_y = test_x.to(self.device), test_y.to(self.device)
                pred_lab = torch.argmax(target_model(test_x), 1)
                num_correct += (pred_lab == test_y).sum().item()
                total_instances += len(test_y)

        if total_instances == 0:
            print('test dataset is empty, accuracy of target model not computed.\n')
            return

        print('accuracy of target model against test dataset: %f\n' %
              (num_correct / total_instances))
    def test_peformance(self):

        "load data and target model"
        if self.mode == 'wf' or 'detect' or 'wf_ow' or 'wf_tf':
            "load data"
            test_data = utils_wf.load_data_main(self.opts['test_data_path'],
                                                self.opts['batch_size'])

            "load target model structure"
            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                if self.mode == 'wf_ow':
                    params = utils_wf.params_lstm_ow(self.opts['num_class'],
                                                     self.opts['input_size'],
                                                     self.opts['batch_size'])
                elif self.mode == 'wf_kf':
                    params = utils_wf.params_lstm(self.opts['num_class'],
                                                  self.opts['input_size'],
                                                  self.opts['batch_size'])
                else:
                    params = utils_wf.params_lstm(self.opts['num_class'],
                                                  self.opts['input_size'],
                                                  self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

        elif self.mode == 'shs':
            "load data"
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            "load target model structure"
            params = utils_shs.params(self.opts['num_class'],
                                      self.opts['input_size'])
            target_model = models.cnn_noNorm(params).to(self.device)

        else:
            print('mode not in ["wf","shs"], system will exit.')
            sys.exit()

        if self.mode == 'wf_kf':
            model_name = '/target_model_%d.pth' % self.opts['id']
        else:
            model_name = '/target_model.pth'
        model_path = self.model_path + model_name
        print(model_path)
        target_model.load_state_dict(
            torch.load(model_path, map_location=self.device))
        target_model.eval()

        "testing process..."

        num_correct = 0
        total_case = 0
        y_test = []
        y_pred = []
        for i, data in enumerate(test_data, 0):
            test_x, test_y = data
            test_x, test_y = test_x.to(self.device), test_y.to(self.device)
            "add softmax after fc of model to normalize output as positive values and sum =1"
            pred_lab = torch.argmax(torch.softmax(target_model(test_x), 1), 1)
            # pred_lab = torch.argmax(target_model(test_x), 1)

            num_correct += torch.sum(pred_lab == test_y, 0)
            total_case += len(test_y)

            "save result"
            y_test += (test_y.cpu().numpy().tolist())
            y_pred += (pred_lab.cpu().numpy().tolist())

        # print('accuracy in testing set: %f\n' % (num_correct.item() / total_case))
        # print(classification_report(y_test, y_pred))
        print('confusion matrix is {}'.format(confusion_matrix(y_test,
                                                               y_pred)))
        print('accuracy is {}'.format(metrics.accuracy_score(y_test, y_pred)))
    def test_model(self):
        """Report a target model's accuracy on clean and adversarial test inputs.

        ``self.mode`` ('wf' or 'shs') selects the test loader and architecture.
        For 'wf', ``self.classifier_type`` picks 'cnn' or 'lstm'; 'shs' always
        builds ``models.cnn_noNorm``. ``self.target_model_type`` picks the
        weights under ``self.model_path``: '/target_model.pth' or
        '/adv_target_model_<Adversary>.pth'.

        ``self.opts['Adversary']`` selects the attack:

        * 'GAN': pretrained generator loaded from '/adv_generator.pth'.
        * 'FGSM', 'DeepFool' or 'PGD'.

        Every test batch is perturbed with ``self.opts['alpha']``. The method
        then prints the accuracy on perturbed inputs, the attack success rate
        (1 - that accuracy) and the clean accuracy. Unsupported modes,
        classifier types, target model types or adversaries print a message
        and call ``sys.exit()``.
        """
        # load data and model
        if self.mode == 'wf':
            test_data = utils_wf.load_data_main(self.opts['test_data_path'],
                                                self.opts['batch_size'])

            if self.classifier_type == 'cnn':
                params = utils_wf.params_cnn(self.opts['num_class'],
                                             self.opts['input_size'])
                target_model = models.cnn_norm(params).to(self.device)

            elif self.classifier_type == 'lstm':
                params = utils_wf.params_lstm_eval(self.opts['num_class'],
                                                   self.opts['input_size'],
                                                   self.opts['batch_size'])
                target_model = models.lstm(params).to(self.device)

            else:
                print('classifier type not in ["cnn","lstm"], system will exit.')
                sys.exit()

        elif self.mode == 'shs':
            test_data = utils_shs.load_data_main(self.opts['test_data_path'],
                                                 self.opts['batch_size'])

            # NOTE(review): train_model() builds the 'shs' model from
            # utils_shs.params_cnn; confirm utils_shs.params describes the same
            # architecture so the saved state_dict loads.
            params = utils_shs.params(self.opts['num_class'],
                                      self.opts['input_size'])
            target_model = models.cnn_noNorm(params).to(self.device)

        else:
            print('mode not in ["wf","shs"], system will exit.')
            sys.exit()

        if self.target_model_type == 'adv_target_model':
            model_file = (self.model_path + '/adv_target_model_' +
                          self.opts['Adversary'] + '.pth')
        elif self.target_model_type == 'target_model':
            model_file = self.model_path + '/target_model.pth'
        else:
            print(
                'target model type not in ["target_model","adv_target_model"], system will exit.'
            )
            sys.exit()

        target_model.load_state_dict(
            torch.load(model_file, map_location=self.device))

        # The LSTM is evaluated in train() mode after model_reset (project
        # convention "equal_eval"); the CNN uses the normal eval() mode.
        if self.classifier_type == 'lstm':
            target_model = self.model_reset(target_model)
            target_model.train()
        elif self.classifier_type == 'cnn':
            target_model.eval()

        # model_reset may return a new module, so move it to the device again
        target_model.to(self.device)

        # set adversary; mode is guaranteed to be 'wf' or 'shs' here
        Adversary = self.opts['Adversary']
        fgsm_epsilon = 0.1
        pgd_a = 0.051 if self.mode == 'wf' else 0.01

        if Adversary == 'GAN':
            pretrained_generator_path = self.model_path + '/adv_generator.pth'
            pretrained_G = models.Generator(self.gen_input_nc,
                                            self.input_nc).to(self.device)
            pretrained_G.load_state_dict(
                torch.load(pretrained_generator_path,
                           map_location=self.device))
            pretrained_G.eval()
        elif Adversary == 'FGSM':
            adversary = FGSM(self.mode,
                             self.x_box_min,
                             self.x_box_max,
                             self.pert_box,
                             epsilon=fgsm_epsilon)
        elif Adversary == 'DeepFool':
            adversary = DeepFool(self.mode,
                                 self.x_box_min,
                                 self.x_box_max,
                                 self.pert_box,
                                 num_classes=5)
        elif Adversary == 'PGD':
            adversary = LinfPGDAttack(self.mode,
                                      self.x_box_min,
                                      self.x_box_max,
                                      self.pert_box,
                                      k=5,
                                      a=pgd_a,
                                      random_start=False)
        else:
            # Without this, `adversary`/`adv_y` would be undefined in the loop.
            print('Adversary not in ["GAN","FGSM","DeepFool","PGD"], system will exit.')
            sys.exit()

        # test on adversarial examples
        num_correct = 0
        correct_x = 0
        total_case = 0
        for test_x, test_y in test_data:
            test_x, test_y = test_x.to(self.device), test_y.to(self.device)

            # prediction on original input x
            _, pred_y = torch.max(target_model(test_x), 1)

            # prediction on adversarial x
            if Adversary == 'GAN':
                pert = pretrained_G(test_x)
                adv_x = utils_gan.get_advX_gan(test_x, pert, self.mode,
                                               self.pert_box, self.x_box_min,
                                               self.x_box_max,
                                               self.opts['alpha'])
                adv_y = torch.argmax(target_model(adv_x.to(self.device)), 1)
            else:
                adversary.model = target_model
                # attack the predicted label (not test_y) to prevent label leaking
                adv_y, adv_x = adversary.perturbation(test_x, pred_y,
                                                      self.opts['alpha'])

            num_correct += torch.sum(adv_y == test_y, 0)
            correct_x += (pred_y == test_y).sum()
            total_case += len(test_y)

        if total_case == 0:
            print('test dataset is empty, nothing to evaluate.')
            return

        # int() accepts both a 0-dim tensor and a plain int
        num_correct = int(num_correct)
        acc = float(num_correct) / float(total_case)

        print('*' * 30)
        print('"{}" with {} against {}.'.format(self.mode, Adversary,
                                                self.target_model_type))
        print('correct test after attack is {}'.format(num_correct))
        print('total test instances is {}'.format(total_case))
        print(
            'accuracy of test after {} attack : correct/total= {:.5f}'.format(
                Adversary, acc))
        print('success rate of the attack is : {}'.format(1 - acc))
        print('accucary of the model without being attacked is {:.5f}'.format(
            float(correct_x) / float(total_case)))
        print('\n')