def run():
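    # Word-level LSTM attack: load the preprocessed dataset, GloVe embedding
    # matrix and word-distance matrix from aux_files/, check the clean test
    # accuracy of the trained SentimentAnalysis model, then run the genetic
    # attack on the fixed list of test indices stored in seq_number.pkl.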
    args = parser.parse_args()
    nlayer = args.nlayer
    bidirection = args.bidirection
    file_path = args.file_path  #'/content/drive/My Drive/Master_Final_Project/Genetic_attack/Code/nlp_adversarial_example_master_pytorch/glove.840B.300d.txt'#'/lustre/scratch/scratch/ucabdc3/lstm_attack'
    save_path = os.path.join(file_path, 'results')
    MAX_VOCAB_SIZE = 50000
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #    with open(os.path.join(file_path, 'dataset_%d.pkl' %MAX_VOCAB_SIZE), 'rb') as f:
    #        dataset = pickle.load(f)

    with open('aux_files/dataset_%d.pkl' % MAX_VOCAB_SIZE, 'rb') as f:
        dataset = pickle.load(f)


#    skip_list = np.load('aux_files/missed_embeddings_counter_%d.npy' %MAX_VOCAB_SIZE)
    embedding_matrix = np.load('aux_files/embeddings_glove_%d.npy' %
                               (MAX_VOCAB_SIZE))
    embedding_matrix = torch.tensor(embedding_matrix.T).to(device)
    dist = np.load(('aux_files/dist_counter_%d.npy' % (MAX_VOCAB_SIZE)))

    #    goog_lm = LM()

    # pytorch
    max_len = args.max_len
    #    padded_train_raw = pad_sequences(dataset.train_seqs2, maxlen = max_len, padding = 'post')
    padded_test_raw = pad_sequences(dataset.test_seqs2,
                                    maxlen=max_len,
                                    padding='post')
    #    # TrainSet
    #    data_set = Data_infor(padded_train_raw, dataset.train_y)
    #    num_train = len(data_set)
    #    indx = list(range(num_train))
    #    train_set = Subset(data_set, indx)

    # TestSet
    batch_size = 1
    SAMPLE_SIZE = args.sample_size
    data_set = Data_infor(padded_test_raw, dataset.test_y)
    num_test = len(data_set)
    indx = list(range(num_test))

    all_test_set = Subset(data_set, indx)
    #    indx = random.sample(indx, SAMPLE_SIZE)
    with open('seq_number.pkl', 'rb') as f:
        indx = pickle.load(f)

    test_set = Subset(data_set, indx)
    test_loader = DataLoader(test_set,
                             batch_size=batch_size,
                             shuffle=False,
                             pin_memory=True)
    all_test_loader = DataLoader(all_test_set, batch_size=128, shuffle=True)

    lstm_size = 128
    rnn_state_save = os.path.join(file_path, 'best_lstm_0.7_0.001_300')

    model = SentimentAnalysis(batch_size=batch_size,
                              embedding_matrix=embedding_matrix,
                              hidden_size=lstm_size,
                              kept_prob=0.7,
                              num_layers=nlayer,
                              bidirection=bidirection)

    model.load_state_dict(torch.load(rnn_state_save))
    model = model.to(device)

    model.eval()
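    # Sanity check: clean accuracy of the loaded model on the full test set
    # before attacking anything.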
    test_pred = torch.tensor([])
    test_targets = torch.tensor([])

    with torch.no_grad():
        for batch_index, (seqs, length, target) in enumerate(all_test_loader):
            seqs = seqs.type(torch.LongTensor)
            len_order = torch.argsort(length, descending=True)
            length = length[len_order]
            seqs = seqs[len_order]
            target = target[len_order]
            seqs, target, length = seqs.to(device), target.to(
                device), length.to(device)

            output, pred_out = model.pred(seqs, length, False)
            test_pred = torch.cat((test_pred, pred_out.cpu()), dim=0)
            test_targets = torch.cat(
                (test_targets, target.type(torch.float).cpu()))

        accuracy = model.evaluate_accuracy(test_pred.numpy(),
                                           test_targets.numpy())
    print('Test Accuracy:{:.4f}.'.format(accuracy))

    n1 = 8
    n2 = 4
    pop_size = 60
    max_iters = 20
    n_prefix = 5
    n_suffix = 5
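    # Genetic-attack hyper-parameters above; below, two extra copies of the
    # classifier are loaded with the same weights: batch_model sized for a
    # whole population of candidates and neighbour_model with batch size n1
    # (presumably the number of substitutions scored per position).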
    batch_model = SentimentAnalysis(batch_size=pop_size,
                                    embedding_matrix=embedding_matrix,
                                    hidden_size=lstm_size,
                                    kept_prob=0.7,
                                    num_layers=nlayer,
                                    bidirection=bidirection)

    batch_model.eval()
    batch_model.load_state_dict(torch.load(rnn_state_save))
    batch_model.to(device)

    neighbour_model = SentimentAnalysis(batch_size=n1,
                                        embedding_matrix=embedding_matrix,
                                        hidden_size=lstm_size,
                                        kept_prob=0.7,
                                        num_layers=nlayer,
                                        bidirection=bidirection)

    neighbour_model.eval()
    neighbour_model.load_state_dict(torch.load(rnn_state_save))
    neighbour_model.to(device)
    lm_model = gpt_2_get_words_probs()
    use_lm = args.use_lm
    ga_attack = GeneticAttack_pytorch(model,
                                      batch_model,
                                      neighbour_model,
                                      dist,
                                      lm_model,
                                      max_iters=max_iters,
                                      dataset=dataset,
                                      pop_size=pop_size,
                                      n1=n1,
                                      n2=n2,
                                      n_prefix=n_prefix,
                                      n_suffix=n_suffix,
                                      use_lm=use_lm,
                                      use_suffix=True)

    TEST_SIZE = args.test_size
    order_pre = 0
    n = 0
    seq_success = []
    seq_orig = []
    seq_orig_label = []
    word_varied = []
    orig_list = []
    adv_list = []
    dist_list = []
    l_list = []
    label_list = []

    # if order_pre != 0:
    #   seq_success = np.load(os.path.join(save_path,'seq_success.npy'), allow_pickle = True).tolist()
    #   seq_orig = np.load(os.path.join(save_path,'seq_orig.npy')).tolist()
    #   seq_orig_label = np.load(os.path.join(save_path,'seq_orig_label.npy')).tolist()
    #   word_varied = np.load(os.path.join(save_path,'word_varied.npy'), allow_pickle = True).tolist()
    #   n = len(seq_success)

    for order, (seq, l, target) in enumerate(test_loader):

        if order >= order_pre:

            seq_len = np.sum(np.sign(seq.numpy()))
            seq = seq.type(torch.LongTensor)
            seq, l = seq.to(device), l.to(device)
            model.eval()
            with torch.no_grad():
                preds = model.pred(seq, l, False)[1]
                orig_pred = np.argmax(preds.cpu().detach().numpy())
            if orig_pred != target.numpy()[0]:
                #print('Wrong original prediction')
                #print('----------------------')
                continue
            if seq_len > args.max_len:
                #print('Sequence is too long')
                #print('----------------------')
                continue
            print('Sequence number:{}'.format(order))
            print('Length of sentence: {}, Number of samples:{}'.format(
                l.item(), n + 1))
            print(preds)
            seq_orig.append(seq[0].cpu().numpy())
            seq_orig_label.append(target.numpy()[0])
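            # Attack toward the opposite of the (correctly predicted) original label.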
            target = 1 - target.numpy()[0]
            # seq_success.append(ga_attack.attack(seq, target, l))

            # if None not in np.array(seq_success[n]):
            #   w_be = [dataset.inv_dict[seq_orig[n][i]] for i in list(np.where(seq_success[n] != seq_orig[n])[0])]
            #   w_to = [dataset.inv_dict[seq_success[n][i]] for i in list(np.where(seq_success[n] != seq_orig[n])[0])]
            #   for i in range(len(w_be)):
            #     print('{} ----> {}'.format(w_be[i], w_to[i]))
            #   word_varied.append([w_be]+[w_to])
            # else:
            #   print('Fail')
            # print('----------------------')
            # n += 1

            # np.save(os.path.join(save_path,'seq_success_1000.npy'), np.array(seq_success))
            # np.save(os.path.join(save_path,'seq_orig_1000.npy'), np.array(seq_orig))
            # np.save(os.path.join(save_path,'seq_orig_label_1000.npy'), np.array(seq_orig_label))
            # np.save(os.path.join(save_path,'word_varied_1000.npy'), np.array(word_varied))

            # if n>TEST_SIZE:
            #   break

            x_adv = ga_attack.attack(seq, target, l)
            orig_list.append(seq[0].cpu().numpy())
            adv_list.append(x_adv)
            label_list.append(1 - target)
            l_list.append(l.cpu().numpy()[0])
            if x_adv is None:
                print('%d failed' % (order))
                dist_list.append(100000)
            else:
                num_changes = np.sum(seq[0].cpu().numpy() != x_adv)
                print('%d - %d changed.' % (order, num_changes))
                dist_list.append(num_changes)
                # display_utils.visualize_attack(sess, model, dataset, x_orig, x_adv)
            print('--------------------------')

            n += 1
            #if n>TEST_SIZE:
            #  break
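            # Running statistics: an attack counts as successful when the
            # fraction of changed words is below the threshold; failed attacks
            # were given a huge distance above, so they never count.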
            orig_len = [np.sum(np.sign(x)) for x in orig_list]
            normalized_dist_list = [
                dist_list[i] / orig_len[i] for i in range(len(orig_list))
            ]
            SUCCESS_THRESHOLD = 0.25
            successful_attacks = [
                x < SUCCESS_THRESHOLD for x in normalized_dist_list
            ]
            print('Attack success rate : {:.2f}%'.format(
                np.mean(successful_attacks) * 100))
            SUCCESS_THRESHOLD = 0.2
            successful_attacks = [
                x < SUCCESS_THRESHOLD for x in normalized_dist_list
            ]
            print('Attack success rate : {:.2f}%'.format(
                np.mean(successful_attacks) * 100))
    output_path = 'attack_results_final_300_AL_' + str(max_len) + '.pkl'
    with open(output_path, 'wb') as f:
        pickle.dump(
            (orig_list, adv_list, l_list, label_list, normalized_dist_list), f)
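
# NOTE: `parser` is an argparse.ArgumentParser assumed to be defined elsewhere
# in this module. A minimal sketch of the arguments the run() functions rely on
# (names inferred from the attributes accessed; defaults are illustrative
# assumptions only):
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--nlayer', type=int, default=2)
#     parser.add_argument('--bidirection', action='store_true')
#     parser.add_argument('--file_path', type=str, default='.')
#     parser.add_argument('--max_len', type=int, default=100)
#     parser.add_argument('--sample_size', type=int, default=1000)
#     parser.add_argument('--test_size', type=int, default=100)
#     parser.add_argument('--use_lm', action='store_true')
#     parser.add_argument('--data', type=str, default='imdb')
#     parser.add_argument('--tokenizer', type=str, default='bert')
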
def run():
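    # BERT variant of the attack: the classifier is a bert_lstm model over BERT
    # token ids, the raw aclImdb text is re-tokenized below, and word distances
    # come from compute_dis rather than a precomputed matrix.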
    args = parser.parse_args()
    data = args.data
    nlayer = args.nlayer
    file_path = args.file_path  #'/content/drive/My Drive/Master_Final_Project/Genetic_attack/Code/nlp_adversarial_example_master_pytorch/glove.840B.300d.txt'#'/lustre/scratch/scratch/ucabdc3/lstm_attack'
    save_path = os.path.join(file_path, 'model_params')
    MAX_VOCAB_SIZE = 50000

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # with open(os.path.join(file_path, 'dataset_%d.pkl' %MAX_VOCAB_SIZE), 'rb') as f:
    #     dataset = pickle.load(f)
    with open('aux_files/dataset_%d.pkl' % MAX_VOCAB_SIZE, 'rb') as f:
        dataset = pickle.load(f)

    #    skip_list = np.load('aux_files/missed_embeddings_counter_%d.npy' %MAX_VOCAB_SIZE)
    embedding_matrix = np.load('aux_files/embeddings_glove_%d.npy' %
                               (MAX_VOCAB_SIZE))
    embedding_matrix = torch.tensor(embedding_matrix.T).to(device)
    # dist = np.load(('aux_files/dist_counter_%d.npy' %(MAX_VOCAB_SIZE)))
    # dist[0,:] = 100000
    # dist[:,0] = 100000
    #    goog_lm = LM()

    # pytorch
    max_len = 100
    #    padded_train_raw = pad_sequences(dataset.train_seqs2, maxlen = max_len, padding = 'post')
    #    padded_test_raw = pad_sequences(dataset.test_seqs2, maxlen = max_len, padding = 'post')
    #    # TrainSet
    #    data_set = Data_infor(padded_train_raw, dataset.train_y)
    #    num_train = len(data_set)
    #    indx = list(range(num_train))
    #    train_set = Subset(data_set, indx)
    if data.lower() == 'imdb':
        data_path = 'aclImdb'

    bert = BertModel.from_pretrained('bert-base-uncased')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    data_processed = pre_processing(data_path, MAX_VOCAB_SIZE, max_len)
    tokenizer_select = args.tokenizer
    tokenizer_selection = tokenizer_select
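    # Either use the project's own preprocessing or BERT's tokenizer; both
    # paths are then mapped to BERT token ids via bert_indx() and turned into
    # padded numerical sequences.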
    if tokenizer_selection.lower() != 'bert':
        data_processed.processing()
        train_sequences, test_sequences = data_processed.bert_indx(tokenizer)
        print('Self preprocessing')
    else:
        data_processed.bert_tokenize(tokenizer)
        train_sequences, test_sequences = data_processed.bert_indx(tokenizer)
        print('BERT tokenizer')
    train_text_init, test_text_init = data_processed.numerical(
        tokenizer, train_sequences, test_sequences)

    #    train_text = pad_sequences(train_text_init, maxlen = max_len, padding = 'post')
    test_text = pad_sequences(test_text_init, maxlen=max_len, padding='post')
    # orig_test_text = pad_sequences(dataset.test_seqs2, maxlen = max_len, padding = 'post')
    #    train_target = data_processed.all_train_labels
    test_target = data_processed.all_test_labels
    SAMPLE_SIZE = args.sample_size
    test_data, all_test_data = data_loading(test_text, test_target,
                                            SAMPLE_SIZE)
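    # data_loading presumably returns a SAMPLE_SIZE subset for attacking plus
    # the full test set used for the accuracy check below.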

    # TestSet
    batch_size = 1

    #    data_set = Data_infor(padded_test_raw, dataset.test_y)
    #    num_test = len(data_set)
    #    indx = list(range(num_test))
    #
    ##    all_test_set  = Subset(data_set, indx)
    #    indx = random.sample(indx, SAMPLE_SIZE)
    #    test_set = Subset(data_set, indx)
    #
    #    test_loader = DataLoader(test_set, batch_size = batch_size, shuffle = False)

    test_loader_bert = DataLoader(test_data,
                                  batch_size=batch_size,
                                  shuffle=False)
    all_test_loader_bert = DataLoader(all_test_data,
                                      batch_size=128,
                                      shuffle=True)

    lstm_size = 128
    rnn_state_save = os.path.join(save_path, 'best_bert_0.7_0.001_bert_200')
    model = bert_lstm(
        bert, 2, False, nlayer, lstm_size, True, 0.7
    )  # batch_size=batch_size, embedding_matrix = embedding_matrix, hidden_size = lstm_size, kept_prob = 0.73, num_layers=2, bidirection=True)
    model.eval()
    model.load_state_dict(torch.load(rnn_state_save))
    model = model.to(device)

    model.eval()
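    # Clean accuracy of the BERT+LSTM model on the full test set before attacking.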
    test_pred = torch.tensor([])
    test_targets = torch.tensor([])

    with torch.no_grad():
        for batch_index, (seqs, length,
                          target) in enumerate(all_test_loader_bert):
            seqs = seqs.type(torch.LongTensor)
            len_order = torch.argsort(length, descending=True)
            length = length[len_order]
            seqs = seqs[len_order]
            target = target[len_order]
            seqs, target, length = seqs.to(device), target.to(
                device), length.to(device)

            output, pred_out = model.pred(seqs, length, False)
            test_pred = torch.cat((test_pred, pred_out.cpu()), dim=0)
            test_targets = torch.cat(
                (test_targets, target.type(torch.float).cpu()))

        accuracy = model.evaluate_accuracy(test_pred.numpy(),
                                           test_targets.numpy())
    print('Test Accuracy:{:.4f}.'.format(accuracy))
    # np.save(os.path.join(save_path,'accuracy.npy'), np.array(accuracy))
    print('\n')
    n1 = 8
    n2 = 4
    pop_size = 60
    max_iters = 20
    n_prefix = 6
    n_suffix = 6
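    # Same genetic-attack setup as the GloVe attack above, but with prefix/suffix
    # windows of 6 and bert_lstm copies for population and neighbour scoring.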
    batch_model = bert_lstm(
        bert, 2, False, nlayer, lstm_size, True, 0.7
    )  #SentimentAnalysis(batch_size=pop_size, embedding_matrix = embedding_matrix, hidden_size = lstm_size, kept_prob = 0.73, num_layers=2, bidirection=True)

    batch_model.eval()
    batch_model.load_state_dict(torch.load(rnn_state_save))
    batch_model.to(device)

    neighbour_model = bert_lstm(
        bert, 2, False, nlayer, lstm_size, True, 0.7
    )  #SentimentAnalysis(batch_size=batch_size, embedding_matrix = embedding_matrix, hidden_size = lstm_size, kept_prob = 0.73, num_layers=2, bidirection=True)

    neighbour_model.eval()
    neighbour_model.load_state_dict(torch.load(rnn_state_save))
    neighbour_model.to(device)
    lm_model = gpt_2_get_words_probs()
    ga_attack = GeneticAttack_pytorch(model,
                                      batch_model,
                                      neighbour_model,
                                      compute_dis,
                                      lm_model,
                                      tokenizer=tokenizer,
                                      max_iters=max_iters,
                                      dataset=dataset,
                                      pop_size=pop_size,
                                      n1=n1,
                                      n2=n2,
                                      n_prefix=n_prefix,
                                      n_suffix=n_suffix,
                                      use_lm=True,
                                      use_suffix=True)
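    # Unlike the GloVe attack, the distance argument passed above is compute_dis
    # (presumably a distance function defined elsewhere), and the BERT tokenizer
    # is supplied so candidate sequences can be mapped back to tokens.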

    TEST_SIZE = args.test_size
    order_pre = 0
    n = 0
    seq_success = []
    seq_orig = []
    seq_orig_label = []
    word_varied = []
    orig_list = []
    adv_list = []
    dist_list = []

    # seq_success_path = os.path.join(save_path,'seq_success_perplexity_bert.npy')
    # seq_orig_path = os.path.join(save_path,'seq_orig_perplexity_bert.npy')
    # seq_orig_label_path = os.path.join(save_path,'seq_orig_label_perplexity_bert.npy')
    # word_varied_path = os.path.join(save_path,'word_varied_perplexity_bert.npy')

    # if order_pre != 0:
    #   seq_success = np.load(seq_success_path, allow_pickle = True).tolist()
    #   seq_orig = np.load(seq_orig_path).tolist()
    #   seq_orig_label = np.load(seq_orig_label_path).tolist()
    #   word_varied = np.load(word_varied_path, allow_pickle = True).tolist()
    #   n = len(seq_success)

    for order, (seq, l, target) in enumerate(test_loader_bert):

        if order >= order_pre:
            seq_len = np.sum(np.sign(seq.numpy()))
            seq = seq.type(torch.LongTensor)
            seq, l = seq.to(device), l.to(device)
            model.eval()
            with torch.no_grad():
                prediction = model.pred(seq, l,
                                        False)[1].cpu().detach().numpy()
                orig_pred = np.argmax(prediction)
            if orig_pred != target:
                # print('Wrong original prediction')
                # print('----------------------')
                continue
            if seq_len > 100:
                # print('Sequence is too long')
                # print('----------------------')
                continue

            print('Sequence number:{}'.format(order))
            print('Predicted value:{}'.format(prediction))
            print('Length of sentence: {}, Number of samples:{}'.format(
                l.item(), n + 1))
            # seq_orig.append(seq[0].cpu().detach().numpy())
            # seq_orig_label.append(target.numpy()[0])
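            # Attack toward the opposite label.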
            target = int(1 - target)
            # seq_success.append(ga_attack.attack(seq, target, l.type(torch.LongTensor)))

            # if None not in np.array(seq_success[n]):
            #   w_be = [dataset.inv_dict[seq_orig[n][i]] for i in list(np.where(seq_success[n] != seq_orig[n])[0])]
            #   w_to = [dataset.inv_dict[seq_success[n][i]] for i in list(np.where(seq_success[n] != seq_orig[n])[0])]
            #   for i in range(len(w_be)):
            #     print('{} ----> {}'.format(w_be[i], w_to[i]))
            #   word_varied.append([w_be]+[w_to])
            # else:
            #   print('Fail')
            # print('----------------------')
            # n += 1

            # np.save(seq_success_path, np.array(seq_success))
            # np.save(seq_orig_path, np.array(seq_orig))
            # np.save(seq_orig_label_path, np.array(seq_orig_label))
            # np.save(word_varied_path, np.array(word_varied, dtype=object))

            # if n>TEST_SIZE:
            #   break
            # orig_list.append(seq[0].cpu().detach().numpy())

            x_adv, seq_out = ga_attack.attack(seq, target,
                                              l.type(torch.LongTensor))
            orig_list.append(seq)
            adv_list.append(x_adv)
            if x_adv is None:
                print('%d failed' % (order))
                dist_list.append(100000)
            else:
                num_changes = np.sum(np.array(seq_out) != np.array(x_adv))
                print('%d - %d changed.' % (order, num_changes))
                dist_list.append(num_changes)
                # display_utils.visualize_attack(sess, model, dataset, x_orig, x_adv)
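                # Report each changed token: original ----> adversarial.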
                w_be = [
                    seq_out[i] for i in list(
                        np.where(np.array(seq_out) != np.array(x_adv))[0])
                ]
                w_to = [
                    x_adv[i] for i in list(
                        np.where(np.array(seq_out) != np.array(x_adv))[0])
                ]
                for i in range(len(w_be)):
                    print('{} ----> {}'.format(w_be[i], w_to[i]))
            print('--------------------------')

            n += 1
            if n > TEST_SIZE:
                break
            orig_len = [x.shape[1] for x in orig_list]
            normalized_dist_list = [
                dist_list[i] / orig_len[i] for i in range(len(orig_list))
            ]
            SUCCESS_THRESHOLD = 0.25
            successful_attacks = [
                x <= SUCCESS_THRESHOLD for x in normalized_dist_list
            ]
            print('Attack success rate : {:.2f}%'.format(
                np.mean(successful_attacks) * 100))
            SUCCESS_THRESHOLD = 0.2
            successful_attacks = [
                x <= SUCCESS_THRESHOLD for x in normalized_dist_list
            ]
            print('Attack success rate : {:.2f}%'.format(
                np.mean(successful_attacks) * 100))
def run():
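    # Word-level LSTM attack on a random sample of SAMPLE_SIZE test sequences;
    # original/adversarial sequences and word substitutions are checkpointed to
    # .npy files after every example so a run can be resumed via order_pre.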
    args = parser.parse_args()
    file_path = args.file_path#'/content/drive/My Drive/Master_Final_Project/Genetic_attack/Code/nlp_adversarial_example_master_pytorch/glove.840B.300d.txt'#'/lustre/scratch/scratch/ucabdc3/lstm_attack'
    save_path = os.path.join(file_path, 'results')
    bidirection = args.bidirection
    nlayer = args.nlayer
    MAX_VOCAB_SIZE = 50000
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#    with open(os.path.join(file_path, 'dataset_%d.pkl' %MAX_VOCAB_SIZE), 'rb') as f:
#        dataset = pickle.load(f)
        
    with open('aux_files/dataset_%d.pkl' %MAX_VOCAB_SIZE, 'rb') as f:
        dataset = pickle.load(f)

    
#    skip_list = np.load('aux_files/missed_embeddings_counter_%d.npy' %MAX_VOCAB_SIZE)
    embedding_matrix = np.load('aux_files/embeddings_glove_%d.npy' %(MAX_VOCAB_SIZE))
    embedding_matrix = torch.tensor(embedding_matrix.T).to(device)
    dist = np.load(('aux_files/dist_counter_%d.npy' %(MAX_VOCAB_SIZE)))
    dist[0,:] = 100000
    dist[:,0] = 100000
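    # Index 0 is presumably the padding/null token: give it a huge distance so
    # it is never chosen as a substitution.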
#    goog_lm = LM()
    
    # pytorch
    max_len = 250
#    padded_train_raw = pad_sequences(dataset.train_seqs2, maxlen = max_len, padding = 'post')
    padded_test_raw = pad_sequences(dataset.test_seqs2, maxlen = max_len, padding = 'post')
#    # TrainSet
#    data_set = Data_infor(padded_train_raw, dataset.train_y)
#    num_train = len(data_set)
#    indx = list(range(num_train))
#    train_set = Subset(data_set, indx)
    
    # TestSet
    batch_size = 1
    SAMPLE_SIZE = args.sample_size
    data_set = Data_infor(padded_test_raw, dataset.test_y)
    num_test = len(data_set)
    indx = list(range(num_test))

    all_test_set  = Subset(data_set, indx)
    indx = random.sample(indx, SAMPLE_SIZE)
    test_set = Subset(data_set, indx)
    test_loader = DataLoader(test_set, batch_size = batch_size, shuffle = False)
    all_test_loader  = DataLoader(all_test_set, batch_size = 128, shuffle = True)
    
    lstm_size = 128
    rnn_state_save = os.path.join(file_path,'best_lstm_0.73_0.0005')
    model = SentimentAnalysis(batch_size=batch_size, embedding_matrix = embedding_matrix, hidden_size = lstm_size, kept_prob = 0.73, num_layers=nlayer, bidirection=bidirection)
    model.eval()
    model.load_state_dict(torch.load(rnn_state_save))
    model = model.to(device)

    model.eval()
    test_pred = torch.tensor([])
    test_targets = torch.tensor([])

    with torch.no_grad():
      for batch_index, (seqs, length, target) in enumerate(all_test_loader):
        seqs = seqs.type(torch.LongTensor)
        len_order = torch.argsort(length, descending = True)
        length = length[len_order]
        seqs = seqs[len_order]
        target = target[len_order]
        seqs, target, length = seqs.to(device), target.to(device), length.to(device)

        output, pred_out = model.pred(seqs, length)
        test_pred = torch.cat((test_pred, pred_out.cpu()), dim = 0)
        test_targets = torch.cat((test_targets, target.type(torch.float).cpu()))

      accuracy = model.evaluate_accuracy(test_pred.numpy(), test_targets.numpy())
    print('Test Accuracy:{:.4f}.'.format(accuracy))
    np.save(os.path.join(save_path,'accuracy.npy'), np.array(accuracy))
    print('\n')
    n1 = 12
    n2 = 6
    pop_size = 80
    max_iters = 20
    n_prefix = 6
    n_suffix = 6
    batch_model = SentimentAnalysis(batch_size=pop_size, embedding_matrix = embedding_matrix, hidden_size = lstm_size, kept_prob = 0.73, num_layers=nlayer, bidirection=bidirection)
    
    batch_model.eval()
    batch_model.load_state_dict(torch.load(rnn_state_save))
    batch_model.to(device)
    
    neighbour_model = SentimentAnalysis(batch_size=batch_size, embedding_matrix = embedding_matrix, hidden_size = lstm_size, kept_prob = 0.73, num_layers=nlayer, bidirection=bidirection)
    
    neighbour_model.eval()
    neighbour_model.load_state_dict(torch.load(rnn_state_save))
    neighbour_model.to(device)
    lm_model = gpt_2_get_words_probs()
    ga_attack = GeneticAttack_pytorch(model, batch_model, neighbour_model, dist,
               lm_model, max_iters = max_iters, dataset = dataset,
               pop_size = pop_size, n1 = n1, n2 = n2,n_prefix = n_prefix, 
               n_suffix = n_suffix, use_lm = True, use_suffix = True)
    
    
    TEST_SIZE = args.test_size
    order_pre = 0
    n = 0
    seq_success = []
    seq_orig = []
    seq_orig_label = []
    word_varied = []

    seq_success_path = os.path.join(save_path,'seq_success_perplexity_g.npy')
    seq_orig_path = os.path.join(save_path,'seq_orig_perplexity_g.npy')
    seq_orig_label_path = os.path.join(save_path,'seq_orig_label_perplexity_g.npy')
    word_varied_path = os.path.join(save_path,'word_varied_perplexity_g.npy')
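    # These arrays are re-saved after every attacked example; setting
    # order_pre > 0 reloads them and resumes from that sequence number.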
    
    if order_pre != 0:
      seq_success = np.load(seq_success_path, allow_pickle = True).tolist()
      seq_orig = np.load(seq_orig_path).tolist()
      seq_orig_label = np.load(seq_orig_label_path).tolist()
      word_varied = np.load(word_varied_path, allow_pickle = True).tolist()
      n = len(seq_success)
    
    for order, (seq, l, target) in enumerate(test_loader):
    
      if order>=order_pre:
        print('Sequence number:{}'.format(order))
        seq_len = np.sum(np.sign(seq.numpy()))
        seq = seq.type(torch.LongTensor)
        seq, l = seq.to(device), l.to(device)
        model.eval()
        with torch.no_grad():
          orig_pred = np.argmax(model.pred(seq, l)[1].cpu().detach().numpy())
        if orig_pred != target.numpy()[0]:
          print('Wrong original prediction')
          print('----------------------')
          continue
        if seq_len > 100:
          print('Sequence is too long')
          print('----------------------')
          continue
    
        print('Length of sentence: {}, Number of samples:{}'.format(l.item(), n+1))
        seq_orig.append(seq[0].cpu().numpy())
        seq_orig_label.append(target.numpy()[0])
        target = 1-target.numpy()[0]
        seq_success.append(ga_attack.attack(seq, target, l.type(torch.LongTensor)))
        
        if None not in np.array(seq_success[n]):
          w_be = [dataset.inv_dict[seq_orig[n][i]] for i in list(np.where(seq_success[n] != seq_orig[n])[0])]
          w_to = [dataset.inv_dict[seq_success[n][i]] for i in list(np.where(seq_success[n] != seq_orig[n])[0])]
          for i in range(len(w_be)):
            print('{} ----> {}'.format(w_be[i], w_to[i]))
          word_varied.append([w_be]+[w_to])
        else:
          print('Fail')
        print('----------------------')
        n += 1
        
        np.save(seq_success_path, np.array(seq_success))
        np.save(seq_orig_path, np.array(seq_orig))
        np.save(seq_orig_label_path, np.array(seq_orig_label))
        np.save(word_varied_path, np.array(word_varied, dtype=object))
        
        if n>TEST_SIZE:
          break