def eval_iters(ae_model, dis_model):
    """Evaluate the autoencoder/discriminator pair on both sentiment test sets.

    For every test sentence: print the original and its greedy reconstruction,
    run an FGIM attack toward the opposite sentiment label, then log the
    modified text and its BLEU score against the human gold answer.
    """
    loader = non_pair_data_loader(
        batch_size=1,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    test_files = [
        args.data_path + 'sentiment.test.0',
        args.data_path + 'sentiment.test.1',
    ]
    test_labels = [[0], [1]]
    loader.create_batches(test_files, test_labels, if_shuffle=False)
    gold_ans = load_human_answer(args.data_path)
    assert len(gold_ans) == loader.num_batch

    add_log("Start eval process.")
    ae_model.eval()
    dis_model.eval()
    for batch_idx in range(loader.num_batch):
        (batch_sentences, tensor_labels, tensor_src, tensor_src_mask,
         tensor_tgt, tensor_tgt_y, tensor_tgt_mask,
         tensor_ntokens) = loader.next_batch()

        print("------------%d------------" % batch_idx)
        print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
        print("origin_labels", tensor_labels)

        latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                       tensor_src_mask, tensor_tgt_mask)
        decoded = ae_model.greedy_decode(latent,
                                         max_len=args.max_sequence_length,
                                         start_id=args.id_bos)
        print(id2text_sentence(decoded[0], args.id_to_word))

        # Attack target is the opposite of the example's original label.
        flip_to = 0.0 if tensor_labels[0].item() > 0.5 else 1.0
        target = get_cuda(torch.tensor([[flip_to]], dtype=torch.float))
        print("target_labels", target)

        modify_text = fgim_attack(dis_model, latent, target, ae_model,
                                  args.max_sequence_length, args.id_bos,
                                  id2text_sentence, args.id_to_word,
                                  gold_ans[batch_idx])
        add_output(modify_text)
        gold_text = id2text_sentence(gold_ans[batch_idx], args.id_to_word)
        add_output(str(batch_idx) + ":\ngold: " + gold_text +
                   "\nmodified: " + modify_text)
        add_result(str(batch_idx) + ":\n" +
                   str(calc_bleu(gold_text, modify_text)))
    return
def eval_iters(ae_model, dis_model):
    """Evaluate ae_model/dis_model on the test split with a BERT-family tokenizer.

    Decodes each test sentence, runs an FGIM attack toward the opposite style
    label, logs the modified text, and optionally pickles the latent vectors
    collected during the attack (controlled by ``args.save_latent``).

    Raises:
        ValueError: if no tokenizer flag is set, or ``args.save_latent`` is an
            unsupported value (other than -1/0/1/2/3).
    """
    # Pick the tokenizer that matches the encoder variant selected in args.
    # tokenizer = BertTokenizer.from_pretrained(args.PRETRAINED_MODEL_NAME, do_lower_case=True)
    if args.use_albert:
        tokenizer = BertTokenizer.from_pretrained("clue/albert_chinese_tiny",
                                                  do_lower_case=True)
    elif args.use_tiny_bert:
        tokenizer = AutoTokenizer.from_pretrained(
            "google/bert_uncased_L-2_H-256_A-4", do_lower_case=True)
    elif args.use_distil_bert:
        tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', do_lower_case=True)
    else:
        # Previously fell through with `tokenizer` undefined -> NameError below.
        raise ValueError(
            "No tokenizer selected: set one of args.use_albert, "
            "args.use_tiny_bert or args.use_distil_bert")
    tokenizer.add_tokens('[EOS]')
    bos_id = tokenizer.convert_tokens_to_ids(['[CLS]'])[0]
    # Embedding table must grow to cover the newly added [EOS] token.
    ae_model.bert_encoder.resize_token_embeddings(len(tokenizer))

    print("[CLS] ID: ", bos_id)

    # if args.task == 'news_china_taiwan':
    eval_file_list = [
        args.data_path + 'test.0',
        args.data_path + 'test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]

    # Evaluating the positive->negative direction: swap file/label order.
    if args.eval_positive:
        eval_file_list = eval_file_list[::-1]
        eval_label_list = eval_label_list[::-1]

    print("Load testData...")

    testData = TextDataset(batch_size=args.batch_size,
                           id_bos='[CLS]',
                           id_eos='[EOS]',
                           id_unk='[UNK]',
                           max_sequence_length=args.max_sequence_length,
                           vocab_size=0,
                           file_list=eval_file_list,
                           label_list=eval_label_list,
                           tokenizer=tokenizer)

    dataset = testData
    eval_data_loader = DataLoader(dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  collate_fn=dataset.collate_fn,
                                  num_workers=4)

    num_batch = len(eval_data_loader)
    # Fixed: the progress bar used to say 'Training' during evaluation.
    trange = tqdm(enumerate(eval_data_loader),
                  total=num_batch,
                  desc='Evaluating',
                  file=sys.stdout,
                  position=0,
                  leave=True)

    # No human references are available here; pass empty strings downstream.
    gold_ans = [''] * num_batch

    add_log("Start eval process.")
    ae_model.to(device)
    dis_model.to(device)
    ae_model.eval()
    dis_model.eval()

    total_latent_lst = []

    for it, data in trange:
        batch_sentences, tensor_labels, tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, tensor_tgt_mask, tensor_ntokens = data

        # Move every tensor in the batch onto the evaluation device.
        tensor_labels = tensor_labels.to(device)
        tensor_src = tensor_src.to(device)
        tensor_tgt = tensor_tgt.to(device)
        tensor_tgt_y = tensor_tgt_y.to(device)
        tensor_src_mask = tensor_src_mask.to(device)
        tensor_tgt_mask = tensor_tgt_mask.to(device)

        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], tokenizer, args.task))
        print("origin_labels", tensor_labels.cpu().detach().numpy()[0])

        latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                       tensor_tgt_mask)
        # Reconstruction sanity check: decode straight from the latent.
        generator_text = ae_model.greedy_decode(
            latent, max_len=args.max_sequence_length, start_id=bos_id)
        print(id2text_sentence(generator_text[0], tokenizer, args.task))

        # Attack target is the opposite of the example's original label.
        target = torch.FloatTensor([[1.0]]).to(device)
        if tensor_labels[0].item() > 0.5:
            target = torch.FloatTensor([[0.0]]).to(device)
        print("target_labels", target)

        modify_text, latent_lst = fgim_attack(dis_model,
                                              latent,
                                              target,
                                              ae_model,
                                              args.max_sequence_length,
                                              bos_id,
                                              id2text_sentence,
                                              None,
                                              gold_ans[it],
                                              tokenizer,
                                              device,
                                              task=args.task,
                                              save_latent=args.save_latent)
        if args.save_latent != -1:
            total_latent_lst.append(latent_lst)

        add_output(modify_text)

        # Stop early once enough batches have been collected.
        if it >= args.save_latent_num:
            break

    print("Save log in ", args.output_file)

    if args.save_latent == -1:
        return

    folder = './latent_{}/'.format(args.task)
    # makedirs(exist_ok=True) avoids the check-then-create race of os.mkdir.
    os.makedirs(folder, exist_ok=True)

    if args.save_latent == 0:  # full
        prefix = 'full'
    elif args.save_latent == 1:  # first 6 layer
        prefix = 'first_6'
    elif args.save_latent == 2:  # last 6 layer
        prefix = 'last_6'
    elif args.save_latent == 3:  # get second layer
        prefix = 'distill_2'
    else:
        # Previously an unexpected value crashed below with a NameError.
        raise ValueError("Unsupported args.save_latent value: %r" %
                         args.save_latent)

    total_latent_lst = np.asarray(total_latent_lst)
    # Record which style direction these latents came from.
    if args.eval_negative:
        save_label = 0
    else:
        save_label = 1
    out_path = folder + '{}_{}.pkl'.format(prefix, save_label)
    with open(out_path, 'wb') as f:
        pickle.dump(total_latent_lst, f)

    print("Save latent in ", out_path)
# Example 3
def fgim_algorithm(args, ae_model, dis_model):
    """Run the FGIM (fast-gradient-iterative-modification) attack on the test set.

    Loads the test data batch-by-batch, then either runs the vectorized `fgim`
    routine (current hard-coded default) or a per-batch `fgim_attack` fallback,
    logging the style-transferred output.

    Args:
        args: namespace with data paths, special-token ids and save locations.
        ae_model: autoencoder (encoder/decoder) model; used in eval mode.
        dis_model: style discriminator guiding the attack; used in eval mode.
    """
    batch_size = 1
    test_data_loader = non_pair_data_loader(
        batch_size=batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.test_data_file]
    test_data_loader.create_batches(args,
                                    file_list,
                                    if_shuffle=False,
                                    n_samples=args.test_n_samples)
    if args.references_files:
        gold_ans = load_human_answer(args.references_files, args.text_column)
        assert len(gold_ans) == test_data_loader.num_batch
    else:
        # No references available: placeholder answers, one list per batch.
        gold_ans = [[None] * batch_size] * test_data_loader.num_batch

    add_log(args, "Start eval process.")
    ae_model.eval()
    dis_model.eval()

    # Hard-coded switch: True selects the batched `fgim` path; the else branch
    # below is currently dead code kept as the per-batch fallback.
    fgim_our = True
    if fgim_our:
        # for FGIM
        z_prime, text_z_prime = fgim(test_data_loader,
                                     args,
                                     ae_model,
                                     dis_model,
                                     gold_ans=gold_ans)
        write_text_z_in_file(args, text_z_prime)
        add_log(
            args,
            "Saving model modify embedding %s ..." % args.current_save_path)
        torch.save(z_prime,
                   os.path.join(args.current_save_path, 'z_prime_fgim.pkl'))
    else:
        for it in range(test_data_loader.num_batch):
            batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = test_data_loader.next_batch()

            print("------------%d------------" % it)
            print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
            print("origin_labels", tensor_labels)

            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask,
                                           tensor_src_attn_mask,
                                           tensor_tgt_mask)
            generator_text = ae_model.greedy_decode(
                latent, max_len=args.max_sequence_length, start_id=args.id_bos)
            print(id2text_sentence(generator_text[0], args.id_to_word))

            # Define target label: flip toward the opposite style class.
            target = get_cuda(torch.tensor([[1.0]], dtype=torch.float), args)
            if tensor_labels[0].item() > 0.5:
                target = get_cuda(torch.tensor([[0.0]], dtype=torch.float),
                                  args)
            add_log(args, "target_labels : %s" % target)

            modify_text = fgim_attack(dis_model, latent, target, ae_model,
                                      args.max_sequence_length, args.id_bos,
                                      id2text_sentence, args.id_to_word,
                                      gold_ans[it])

            add_output(args, modify_text)