def eval_iters(ae_model, dis_model):
    eval_data_loader = non_pair_data_loader(
        batch_size=1,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    eval_file_list = [
        args.data_path + 'sentiment.test.0',
        args.data_path + 'sentiment.test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]
    eval_data_loader.create_batches(eval_file_list,
                                    eval_label_list,
                                    if_shuffle=False)
    gold_ans = load_human_answer(args.data_path)
    assert len(gold_ans) == eval_data_loader.num_batch

    add_log("Start eval process.")
    ae_model.eval()
    dis_model.eval()
    for it in range(eval_data_loader.num_batch):
        batch_sentences, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()

        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
        print("origin_labels", tensor_labels)

        latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                       tensor_tgt_mask)
        generator_text = ae_model.greedy_decode(
            latent, max_len=args.max_sequence_length, start_id=args.id_bos)
        print(id2text_sentence(generator_text[0], args.id_to_word))

        # Define target label
        target = get_cuda(torch.tensor([[1.0]], dtype=torch.float))
        if tensor_labels[0].item() > 0.5:
            target = get_cuda(torch.tensor([[0.0]], dtype=torch.float))
        print("target_labels", target)

        modify_text = fgim_attack(dis_model, latent, target, ae_model,
                                  args.max_sequence_length, args.id_bos,
                                  id2text_sentence, args.id_to_word,
                                  gold_ans[it])
        add_output(modify_text)
        output_text = str(it) + ":\ngold: " + id2text_sentence(
            gold_ans[it], args.id_to_word) + "\nmodified: " + modify_text
        add_output(output_text)
        add_result(
            str(it) + ":\n" + str(
                calc_bleu(id2text_sentence(gold_ans[it], args.id_to_word),
                          modify_text)))
    return
def fgim_step(model, origin_data, target, ae_model, max_sequence_length,
              id_bos, id2text_sentence, id_to_word, gold_ans, epsilon, step):
    """Fast Gradient Iterative Methods"""

    dis_criterion = nn.BCELoss(reduction='mean')

    it = 0
    data = origin_data
    while it < step:

        data = to_var(data.clone())  # (batch_size, seq_length, latent_size)
        # Set requires_grad attribute of tensor. Important for Attack
        data.requires_grad = True
        output = model.forward(data)
        loss = dis_criterion(output, target)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.data
        data = data - epsilon * data_grad
        it += 1
        epsilon = epsilon * 0.9

        generator_id = ae_model.greedy_decode(data,
                                              max_len=max_sequence_length,
                                              start_id=id_bos)
        generator_text = id2text_sentence(generator_id[0], id_to_word)
        print("| It {:2d} | dis model pred {:5.4f} |".format(
            it, output[0].item()))
        print(generator_text)

    return data, generator_id
def plot_tsne(ae_model, dis_model, epsilon=2, step=0):
    eval_data_loader = non_pair_data_loader(
        batch_size=500,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    eval_file_list = [
        args.data_path + 'sentiment.test.0',
        args.data_path + 'sentiment.test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]
    eval_data_loader.create_batches(eval_file_list,
                                    eval_label_list,
                                    if_shuffle=False)
    gold_ans = load_human_answer(args.data_path)

    ae_model.eval()
    dis_model.eval()
    latents, labels = [], []
    it = 0
    for _ in range(eval_data_loader.num_batch):
        batch_sentences, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()
        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
        print("origin_labels", tensor_labels[0].item())

        latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                       tensor_tgt_mask)

        # Define target label
        target = get_cuda(
            torch.ones((tensor_labels.size(0), 1), dtype=torch.float))
        target = target - tensor_labels

        if step > 0:
            latent, modified_text = fgim_step(dis_model, latent, target,
                                              ae_model,
                                              args.max_sequence_length,
                                              args.id_bos, id2text_sentence,
                                              args.id_to_word, gold_ans[it],
                                              epsilon, step)

        latents.append(latent)
        labels.append(tensor_labels)

        it += tensor_labels.size(0)

    latents = torch.cat(latents, dim=0).detach().cpu().numpy()
    labels = torch.cat(labels, dim=0).squeeze().detach().cpu().numpy()

    tsne_plot_representation(latents, labels, f"tsne_step{step}_eps{epsilon}")
def plot_tsne(ae_model, dis_model):
    epsilon = 2
    step = 1
    eval_data_loader = non_pair_data_loader(
        batch_size=1,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    eval_file_list = [
        args.data_path + 'sentiment.test.0',
        args.data_path + 'sentiment.test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]
    eval_data_loader.create_batches(eval_file_list,
                                    eval_label_list,
                                    if_shuffle=False)
    gold_ans = load_human_answer(args.data_path)
    assert len(gold_ans) == eval_data_loader.num_batch

    ae_model.eval()
    dis_model.eval()
    latents, labels = [], []
    for it in range(eval_data_loader.num_batch):
        batch_sentences, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()
        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
        print("origin_labels", tensor_labels.item())

        latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                       tensor_tgt_mask)

        # Define target label
        target = get_cuda(torch.tensor([[1.0]], dtype=torch.float))
        if tensor_labels[0].item() > 0.5:
            target = get_cuda(torch.tensor([[0.0]], dtype=torch.float))

        modified_latent, modified_text = fgim_step(
            dis_model, latent, target, ae_model, args.max_sequence_length,
            args.id_bos, id2text_sentence, args.id_to_word, gold_ans[it],
            epsilon, step)
        latents.append(modified_latent)
        labels.append(tensor_labels.item())

    latents = torch.cat(latents, dim=0).detach().cpu().numpy()
    labels = numpy.array(labels)

    tsne_plot_representation(latents, labels)
Example #5
def predict(args, ae_model, dis_model, batch, epsilon):
    (batch_sentences, 
     tensor_labels, 
     tensor_src,
     tensor_src_mask,
     tensor_tgt,
     tensor_tgt_y,
     tensor_tgt_mask,
     tensor_ntokens) = batch
    
    ae_model.eval()
    dis_model.eval()
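    # One-step variant of FGIM: encode the batch, pick the opposite attribute
    # as target (1.0 unless the original label is positive), take a single
    # gradient step of size `epsilon` on the latent against the discriminator's
    # BCE loss, and greedy-decode the perturbed latent back into text.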
    latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask, tensor_tgt_mask)
    generator_text = ae_model.greedy_decode(latent,
                                            max_len=args.max_sequence_length,
                                            start_id=args.id_bos)
    print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
    print(id2text_sentence(generator_text[0], args.id_to_word))
    target = get_cuda(torch.tensor([[1.0]], dtype=torch.float))
    if tensor_labels[0].item() > 0.5:
        target = get_cuda(torch.tensor([[0.0]], dtype=torch.float))

    dis_criterion = nn.BCELoss(reduction='mean')

    data = to_var(latent.clone())  # (batch_size, seq_length, latent_size)
    data.requires_grad = True
    output = dis_model.forward(data)
    loss = dis_criterion(output, target)
    dis_model.zero_grad()
    loss.backward()
    data_grad = data.grad.data
    data = data - epsilon * data_grad

    generator_id = ae_model.greedy_decode(data,
                                          max_len=args.max_sequence_length,
                                          start_id=args.id_bos)
    return id2text_sentence(generator_id[0], args.id_to_word)
def train_iters(ae_model, dis_model):
    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    train_data_loader.create_batches(args.train_file_list,
                                     args.train_label_list,
                                     if_shuffle=True)
    add_log("Start train process.")
    ae_model.train()
    dis_model.train()

    ae_optimizer = NoamOpt(
        ae_model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1))
    dis_criterion = nn.BCELoss(reduction='mean')

    for epoch in range(200):
        print('-' * 94)
        epoch_start_time = time.time()
        for it in range(train_data_loader.num_batch):
            batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = train_data_loader.next_batch()

            # For debug
            # print(batch_sentences[0])
            # print(tensor_src[0])
            # print(tensor_src_mask[0])
            # print("tensor_src_mask", tensor_src_mask.size())
            # print(tensor_tgt[0])
            # print(tensor_tgt_y[0])
            # print(tensor_tgt_mask[0])
            # print(batch_ntokens)

            # Forward pass
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)
            # print(latent.size())  # (batch_size, max_src_seq, d_model)
            # print(out.size())  # (batch_size, max_tgt_seq, vocab_size)

            # Loss calculation
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            # loss_all = loss_rec + loss_dis

            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            # Classifier
            dis_lop = dis_model.forward(to_var(latent.clone()))

            loss_dis = dis_criterion(dis_lop, tensor_labels)

            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            if it % 200 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, train_data_loader.num_batch, loss_rec,
                            loss_dis))

                print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                generator_text = ae_model.greedy_decode(
                    latent,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                print(id2text_sentence(generator_text[0], args.id_to_word))

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save model
        torch.save(ae_model.state_dict(),
                   args.current_save_path + 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path + 'dis_model_params.pkl')
    return
Example #7
def fgim(data_loader, args, ae_model, c_theta, gold_ans=None):
    """
    Input: 
        Original latent representation z : (n_batch, batch_size, seq_length, latent_size)
        Well-trained attribute classifier C_θ
        Target attribute y
        A set of weights w = {w_i}
        Decay coefficient λ 
        Threshold t
        
    Output: An optimal modified latent representation z'
    """
    w = args.w
    lambda_ = args.lambda_
    t = args.threshold
    max_iter_per_epsilon = args.max_iter_per_epsilon
    max_sequence_length = args.max_sequence_length
    id_bos = args.id_bos
    id_to_word = args.id_to_word
    limit_batches = args.limit_batches
        
    text_z_prime = {
        "source": [],
        "origin_labels": [],
        "before": [],
        "after": [],
        "change": [],
        "pred_label": [],
    }
    if gold_ans is not None :
        text_z_prime["gold_ans"] = []
    z_prime = []
    
    dis_criterion = nn.BCELoss(reduction='mean')
    n_batches = 0
    for it in tqdm.tqdm(list(range(data_loader.num_batch)), desc="FGIM"):
        if gold_ans is not None :
            text_z_prime["gold_ans"].append(gold_ans[it])
        
        _, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, _ = data_loader.next_batch()
        # only on negative example
        negative_examples = ~(tensor_labels.squeeze()==args.positive_label)
        tensor_labels = tensor_labels[negative_examples].squeeze(0) # .view(1, -1)
        tensor_src = tensor_src[negative_examples].squeeze(0) 
        tensor_src_attn_mask = tensor_src_attn_mask[negative_examples].squeeze(0)
        tensor_src_mask = tensor_src_mask[negative_examples].squeeze(0)  
        tensor_tgt_y = tensor_tgt_y[negative_examples].squeeze(0) 
        tensor_tgt = tensor_tgt[negative_examples].squeeze(0) 
        tensor_tgt_mask = tensor_tgt_mask[negative_examples].squeeze(0) 
        #if gold_ans is not None :
        #    text_z_prime["gold_ans"][-1] = text_z_prime["gold_ans"][-1][negative_examples]

        #print("------------%d------------" % it)
        if negative_examples.any():
            text_z_prime["source"].append([id2text_sentence(t, args.id_to_word) for t in tensor_tgt_y])
            text_z_prime["origin_labels"].append(tensor_labels.cpu().numpy())

            origin_data, _ = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask, tensor_src_attn_mask, tensor_tgt_mask)

            # Define target label
            y_prime = 1.0 - (tensor_labels > 0.5).float()

            ############################### FGIM ######################################################
            generator_id = ae_model.greedy_decode(origin_data, max_len=max_sequence_length, start_id=id_bos)
            generator_text = [id2text_sentence(gid, id_to_word) for gid in generator_id]
            text_z_prime["before"].append(generator_text)
            
            flag = False
            for w_i in w:
                #print("---------- w_i:", w_i)
                data = to_var(origin_data.clone())  # (batch_size, seq_length, latent_size)
                # Initial gradient step on the latent before the refinement loop
                data.requires_grad = True
                output = c_theta.forward(data)
                loss = dis_criterion(output, y_prime)
                c_theta.zero_grad()
                loss.backward()
                data = data - w_i * data.grad.data
                    
                n_iter = 0  # inner iteration counter; keep the batch index `it` intact
                while True:
                    #if torch.cdist(output, y_prime) < t :
                    #if torch.sum((output - y_prime)**2, dim=1).sqrt().mean() < t :
                    if torch.sum((output - y_prime).abs(), dim=1).mean() < t :
                        flag = True
                        break
            
                    data = to_var(data.clone())  # (batch_size, seq_length, latent_size)
                    # Set requires_grad attribute of tensor. Important for Attack
                    data.requires_grad = True
                    output = c_theta.forward(data)
                    # Calculate gradients of model in backward pass
                    loss = dis_criterion(output, y_prime)
                    c_theta.zero_grad()
                    # dis_optimizer.zero_grad()
                    loss.backward()
                    data = data - w_i * data.grad.data
                    n_iter += 1
                    # Decay the step size after each update
                    w_i = lambda_ * w_i
                    if n_iter > max_iter_per_epsilon:
                        break
                
                if flag :    
                    z_prime.append(data)
                    generator_id = ae_model.greedy_decode(data, max_len=max_sequence_length, start_id=id_bos)
                    generator_text = [id2text_sentence(gid, id_to_word) for gid in generator_id]
                    text_z_prime["after"].append(generator_text)
                    text_z_prime["change"].append([True]*len(output))
                    text_z_prime["pred_label"].append([o.item() for o in output])
                    break
            
            if not flag:  # debiasing failed: keep the original latent and text
                z_prime.append(origin_data)
                text_z_prime["after"].append(text_z_prime["before"][-1])
                text_z_prime["change"].append([False]*len(y_prime))
                text_z_prime["pred_label"].append([o.item() for o in y_prime])
            
            n_batches += 1
            if n_batches > limit_batches:
                break        
    return z_prime, text_z_prime
Example #8
def fgim_attack(model, origin_data, target, ae_model, max_sequence_length, id_bos,
                id2text_sentence, id_to_word, gold_ans = None):
    """Fast Gradient Iterative Methods"""

    dis_criterion = nn.BCELoss(reduction='mean')
    t = 0.001  # threshold
    lambda_ = 0.9  # decay coefficient
    max_iter_per_epsilon = 20
    
    if gold_ans is not None :
        gold_text = id2text_sentence(gold_ans, id_to_word)
        print("gold:", gold_text)
        
    flag = False
    for epsilon in [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]:
        print("---------- epsilon:", epsilon)
        generator_id = ae_model.greedy_decode(origin_data, max_len=max_sequence_length, start_id=id_bos)
        generator_text = id2text_sentence(generator_id[0], id_to_word)
        print("z:", generator_text)
        
        data = to_var(origin_data.clone())  # (batch_size, seq_length, latent_size)
        # Initial gradient step on the latent before the iterative refinement
        data.requires_grad = True
        output = model.forward(data)
        loss = dis_criterion(output, target)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.data
        data = data - epsilon * data_grad
            
        it = 0 
        while True:
            if torch.cdist(output, target) < t :
                flag = True
                break
    
            data = to_var(data.clone())  # (batch_size, seq_length, latent_size)
            # Set requires_grad attribute of tensor. Important for Attack
            data.requires_grad = True
            output = model.forward(data)
            # Calculate gradients of model in backward pass
            loss = dis_criterion(output, target)
            model.zero_grad()
            # dis_optimizer.zero_grad()
            loss.backward()
            data_grad = data.grad.data
            data = data - epsilon * data_grad
            it += 1
            # Decay the step size after each update
            epsilon = epsilon * lambda_
            if it > max_iter_per_epsilon:
                break
        
        generator_id = ae_model.greedy_decode(data, max_len=max_sequence_length, start_id=id_bos)
        generator_text = id2text_sentence(generator_id[0], id_to_word)
        print("|dis model pred {:5.4f} |".format(output[0].item()))
        print("z*", generator_text)
        print()
        if flag:
            return generator_text

    # The threshold was never reached: return the last decoded text so the
    # caller always gets a string back.
    return generator_text
def eval_iters(ae_model, dis_model):
    # tokenizer = BertTokenizer.from_pretrained(args.PRETRAINED_MODEL_NAME, do_lower_case=True)
    if args.use_albert:
        tokenizer = BertTokenizer.from_pretrained("clue/albert_chinese_tiny",
                                                  do_lower_case=True)
    elif args.use_tiny_bert:
        tokenizer = AutoTokenizer.from_pretrained(
            "google/bert_uncased_L-2_H-256_A-4", do_lower_case=True)
    elif args.use_distil_bert:
        tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', do_lower_case=True)
    else:
        raise ValueError("One of args.use_albert, args.use_tiny_bert or "
                         "args.use_distil_bert must be set.")
    tokenizer.add_tokens('[EOS]')
    bos_id = tokenizer.convert_tokens_to_ids(['[CLS]'])[0]
    ae_model.bert_encoder.resize_token_embeddings(len(tokenizer))

    print("[CLS] ID: ", bos_id)

    # if args.task == 'news_china_taiwan':
    eval_file_list = [
        args.data_path + 'test.0',
        args.data_path + 'test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]

    if args.eval_positive:
        eval_file_list = eval_file_list[::-1]
        eval_label_list = eval_label_list[::-1]

    print("Load testData...")

    testData = TextDataset(batch_size=args.batch_size,
                           id_bos='[CLS]',
                           id_eos='[EOS]',
                           id_unk='[UNK]',
                           max_sequence_length=args.max_sequence_length,
                           vocab_size=0,
                           file_list=eval_file_list,
                           label_list=eval_label_list,
                           tokenizer=tokenizer)

    dataset = testData
    eval_data_loader = DataLoader(dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  collate_fn=dataset.collate_fn,
                                  num_workers=4)

    num_batch = len(eval_data_loader)
    trange = tqdm(enumerate(eval_data_loader),
                  total=num_batch,
                  desc='Evaluating',
                  file=sys.stdout,
                  position=0,
                  leave=True)

    gold_ans = [''] * num_batch

    add_log("Start eval process.")
    ae_model.to(device)
    dis_model.to(device)
    ae_model.eval()
    dis_model.eval()

    total_latent_lst = []

    for it, data in trange:
        batch_sentences, tensor_labels, tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, tensor_tgt_mask, tensor_ntokens = data

        tensor_labels = tensor_labels.to(device)
        tensor_src = tensor_src.to(device)
        tensor_tgt = tensor_tgt.to(device)
        tensor_tgt_y = tensor_tgt_y.to(device)
        tensor_src_mask = tensor_src_mask.to(device)
        tensor_tgt_mask = tensor_tgt_mask.to(device)

        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], tokenizer, args.task))
        print("origin_labels", tensor_labels.cpu().detach().numpy()[0])

        latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                       tensor_tgt_mask)
        generator_text = ae_model.greedy_decode(
            latent, max_len=args.max_sequence_length, start_id=bos_id)
        print(id2text_sentence(generator_text[0], tokenizer, args.task))

        # Define target label
        target = torch.FloatTensor([[1.0]]).to(device)
        if tensor_labels[0].item() > 0.5:
            target = torch.FloatTensor([[0.0]]).to(device)
        print("target_labels", target)

        modify_text, latent_lst = fgim_attack(dis_model,
                                              latent,
                                              target,
                                              ae_model,
                                              args.max_sequence_length,
                                              bos_id,
                                              id2text_sentence,
                                              None,
                                              gold_ans[it],
                                              tokenizer,
                                              device,
                                              task=args.task,
                                              save_latent=args.save_latent)
        if args.save_latent != -1:
            total_latent_lst.append(latent_lst)

        add_output(modify_text)

        if it >= args.save_latent_num:
            break

    print("Save log in ", args.output_file)

    if args.save_latent == -1:
        return

    folder = './latent_{}/'.format(args.task)
    if not os.path.exists(folder):
        os.mkdir(folder)

    if args.save_latent == 0:  # full
        prefix = 'full'
    elif args.save_latent == 1:  # first 6 layer
        prefix = 'first_6'
    elif args.save_latent == 2:  # last 6 layer
        prefix = 'last_6'
    elif args.save_latent == 3:  # get second layer
        prefix = 'distill_2'

    total_latent_lst = np.asarray(total_latent_lst)
    if args.eval_negative:
        save_label = 0
    else:
        save_label = 1
    with open(folder + '{}_{}.pkl'.format(prefix, save_label), 'wb') as f:
        pickle.dump(total_latent_lst, f)

    print("Save laten in ", folder + '{}_{}.pkl'.format(prefix, save_label))
def train_iters(ae_model, dis_model):
    if args.use_albert:
        tokenizer = BertTokenizer.from_pretrained("clue/albert_chinese_tiny",
                                                  do_lower_case=True)
    elif args.use_tiny_bert:
        tokenizer = AutoTokenizer.from_pretrained(
            "google/bert_uncased_L-2_H-256_A-4", do_lower_case=True)
    elif args.use_distil_bert:
        tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', do_lower_case=True)
    else:
        raise ValueError("One of args.use_albert, args.use_tiny_bert or "
                         "args.use_distil_bert must be set.")
    # tokenizer = BertTokenizer.from_pretrained(args.PRETRAINED_MODEL_NAME, do_lower_case=True)
    tokenizer.add_tokens('[EOS]')
    bos_id = tokenizer.convert_tokens_to_ids(['[CLS]'])[0]

    ae_model.bert_encoder.resize_token_embeddings(len(tokenizer))
    #print("[CLS] ID: ", bos_id)

    print("Load trainData...")
    if args.load_trainData and os.path.exists('./{}_trainData.pkl'.format(
            args.task)):
        with open('./{}_trainData.pkl'.format(args.task), 'rb') as f:
            trainData = pickle.load(f)
    else:
        trainData = TextDataset(batch_size=args.batch_size,
                                id_bos='[CLS]',
                                id_eos='[EOS]',
                                id_unk='[UNK]',
                                max_sequence_length=args.max_sequence_length,
                                vocab_size=0,
                                file_list=args.train_file_list,
                                label_list=args.train_label_list,
                                tokenizer=tokenizer)
        with open('./{}_trainData.pkl'.format(args.task), 'wb') as f:
            pickle.dump(trainData, f)

    add_log("Start train process.")

    ae_model.train()
    dis_model.train()
    ae_model.to(device)
    dis_model.to(device)
    '''
    Fixing or distilling BERT encoder
    '''
    if args.fix_first_6:
        print("Try fixing first 6 bertlayers")
        for layer in range(6):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters(
            ):
                param.requires_grad = False
    elif args.fix_last_6:
        print("Try fixing last 6 bertlayers")
        for layer in range(6, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters(
            ):
                param.requires_grad = False

    if args.distill_2:
        print("Get result from layer 2")
        for layer in range(2, 12):
            for param in ae_model.bert_encoder.encoder.layer[layer].parameters(
            ):
                param.requires_grad = False

    ae_optimizer = NoamOpt(
        ae_model.d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    #ae_criterion = get_cuda(LabelSmoothing(size=args.vocab_size, padding_idx=args.id_pad, smoothing=0.1))
    ae_criterion = LabelSmoothing(size=ae_model.bert_encoder.config.vocab_size,
                                  padding_idx=0,
                                  smoothing=0.1).to(device)
    dis_criterion = nn.BCELoss(reduction='mean')

    history = {'train': []}

    for epoch in range(args.epochs):
        print('-' * 94)
        epoch_start_time = time.time()
        total_rec_loss = 0
        total_dis_loss = 0

        train_data_loader = DataLoader(trainData,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       collate_fn=trainData.collate_fn,
                                       num_workers=4)
        num_batch = len(train_data_loader)
        trange = tqdm(enumerate(train_data_loader),
                      total=num_batch,
                      desc='Training',
                      file=sys.stdout,
                      position=0,
                      leave=True)

        for it, data in trange:
            batch_sentences, tensor_labels, tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, tensor_tgt_mask, tensor_ntokens = data

            tensor_labels = tensor_labels.to(device)
            tensor_src = tensor_src.to(device)
            tensor_tgt = tensor_tgt.to(device)
            tensor_tgt_y = tensor_tgt_y.to(device)
            tensor_src_mask = tensor_src_mask.to(device)
            tensor_tgt_mask = tensor_tgt_mask.to(device)

            # Forward pass
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)

            # Loss calculation
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            latent = latent.detach()
            next_latent = latent.to(device)

            # Classifier
            dis_lop = dis_model.forward(next_latent)
            loss_dis = dis_criterion(dis_lop, tensor_labels)

            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            total_rec_loss += loss_rec.item()
            total_dis_loss += loss_dis.item()

            trange.set_postfix(total_rec_loss=total_rec_loss / (it + 1),
                               total_dis_loss=total_dis_loss / (it + 1))

            if it % 100 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, num_batch, loss_rec, loss_dis))

                print(id2text_sentence(tensor_tgt_y[0], tokenizer, args.task))
                generator_text = ae_model.greedy_decode(
                    latent, max_len=args.max_sequence_length, start_id=bos_id)
                print(id2text_sentence(generator_text[0], tokenizer,
                                       args.task))

                # Save model
                #torch.save(ae_model.state_dict(), args.current_save_path / 'ae_model_params.pkl')
                #torch.save(dis_model.state_dict(), args.current_save_path / 'dis_model_params.pkl')

        history['train'].append({
            'epoch': epoch,
            'total_rec_loss': total_rec_loss / len(trange),
            'total_dis_loss': total_dis_loss / len(trange)
        })

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save model
        torch.save(ae_model.state_dict(),
                   args.current_save_path / 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path / 'dis_model_params.pkl')

    print("Save in ", args.current_save_path)
    return
Example #11
def sedat_eval(args, ae_model, f, deb):
    """
    Input: 
        Original latent representation z : (n_batch, batch_size, seq_length, latent_size)
    Output: 
        An optimal modified latent representation z'
    """
    max_sequence_length = args.max_sequence_length
    id_bos = args.id_bos
    id_to_word = args.id_to_word
    limit_batches = args.limit_batches

    eval_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.test_data_file]
    eval_data_loader.create_batches(args,
                                    file_list,
                                    if_shuffle=False,
                                    n_samples=args.test_n_samples)
    if args.references_files:
        gold_ans = load_human_answer(args.references_files, args.text_column)
        assert len(gold_ans) == eval_data_loader.num_batch
    else:
        gold_ans = None

    add_log(args, "Start eval process.")
    ae_model.eval()
    f.eval()
    deb.eval()

    text_z_prime = {
        "source": [],
        "origin_labels": [],
        "before": [],
        "after": [],
        "change": [],
        "pred_label": []
    }
    if gold_ans is not None:
        text_z_prime["gold_ans"] = []
    z_prime = []
    n_batches = 0
    for it in tqdm.tqdm(list(range(eval_data_loader.num_batch)), desc="SEDAT"):

        _, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, _ = eval_data_loader.next_batch()
        # only on negative example
        negative_examples = ~(tensor_labels.squeeze() == args.positive_label)
        tensor_labels = tensor_labels[negative_examples].squeeze(
            0)  # .view(1, -1)
        tensor_src = tensor_src[negative_examples].squeeze(0)
        tensor_src_mask = tensor_src_mask[negative_examples].squeeze(0)
        tensor_src_attn_mask = tensor_src_attn_mask[negative_examples].squeeze(
            0)
        tensor_tgt_y = tensor_tgt_y[negative_examples].squeeze(0)
        tensor_tgt = tensor_tgt[negative_examples].squeeze(0)
        tensor_tgt_mask = tensor_tgt_mask[negative_examples].squeeze(0)
        if negative_examples.any():
            if gold_ans is not None:
                text_z_prime["gold_ans"].append(gold_ans[it])

            text_z_prime["source"].append(
                [id2text_sentence(t, args.id_to_word) for t in tensor_tgt_y])
            text_z_prime["origin_labels"].append(tensor_labels.cpu().numpy())

            origin_data, _ = ae_model.forward(tensor_src, tensor_tgt,
                                              tensor_src_mask,
                                              tensor_src_attn_mask,
                                              tensor_tgt_mask)

            generator_id = ae_model.greedy_decode(origin_data,
                                                  max_len=max_sequence_length,
                                                  start_id=id_bos)
            generator_text = [
                id2text_sentence(gid, id_to_word) for gid in generator_id
            ]
            text_z_prime["before"].append(generator_text)

            data = deb(origin_data, mask=None)
            data = torch.sum(ae_model.sigmoid(data),
                             dim=1)  # (batch_size, d_model)
            #logit = ae_model.decode(data.unsqueeze(1), tensor_tgt, tensor_tgt_mask)  # (batch_size, max_tgt_seq, d_model)
            #output = ae_model.generator(logit)  # (batch_size, max_seq, vocab_size)
            y_hat = f.forward(data)
            y_hat = y_hat.round().int()
            z_prime.append(data)
            generator_id = ae_model.greedy_decode(data,
                                                  max_len=max_sequence_length,
                                                  start_id=id_bos)
            generator_text = [
                id2text_sentence(gid, id_to_word) for gid in generator_id
            ]
            text_z_prime["after"].append(generator_text)
            text_z_prime["change"].append([True] * len(y_hat))
            text_z_prime["pred_label"].append([y_.item() for y_ in y_hat])

            n_batches += 1
            if n_batches > limit_batches:
                break
    write_text_z_in_file(args, text_z_prime)
    add_log(args, "")
    add_log(args,
            "Saving modified latent embeddings to %s ..." % args.current_save_path)
    torch.save(z_prime,
               os.path.join(args.current_save_path, 'z_prime_sedat.pkl'))
    return z_prime, text_z_prime
Example #12
def train_step(args, data_loader, ae_model, dis_model, ae_optimizer,
               dis_optimizer, ae_criterion, dis_criterion, epoch):
    ae_model.train()
    dis_model.train()

    loss_ae, n_words_ae, xe_loss_ae, n_valid_ae = 0, 0, 0, 0
    loss_clf, total_clf, n_valid_clf = 0, 0, 0
    epoch_start_time = time.time()
    for it in range(data_loader.num_batch):
        flag_rec = True
        batch_sentences, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, tensor_ntokens = data_loader.next_batch()

        # Forward pass
        latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                       tensor_src_attn_mask, tensor_tgt_mask)

        # Loss calculation
        if not args.sedat:
            loss_rec = ae_criterion(out.contiguous().view(-1, out.size(-1)),
                                    tensor_tgt_y.contiguous().view(-1)) / (
                                        tensor_ntokens.data + eps)

        else:
            # only on positive example
            positive_examples = tensor_labels.squeeze() == args.positive_label
            out = out[positive_examples]  # or out[positive_examples,:,:]
            tensor_tgt_y = tensor_tgt_y[
                positive_examples]  # or tensor_tgt_y[positive_examples,:]
            tensor_ntokens = (tensor_tgt_y != 0).data.sum().float()
            loss_rec = ae_criterion(out.contiguous().view(-1, out.size(-1)),
                                    tensor_tgt_y.contiguous().view(-1)) / (
                                        tensor_ntokens.data + eps)
            flag_rec = positive_examples.any()
            out = out.squeeze(0)
            tensor_tgt_y = tensor_tgt_y.squeeze(0)

        if flag_rec:
            n_v, n_w = get_n_v_w(tensor_tgt_y, out)
        else:
            n_w = float("nan")
            n_v = float("nan")

        x_e = loss_rec.item() * n_w
        loss_ae += loss_rec.item()
        n_words_ae += n_w
        xe_loss_ae += x_e
        n_valid_ae += n_v
        ae_acc = 100. * n_v / (n_w + eps)
        avg_ae_acc = 100. * n_valid_ae / (n_words_ae + eps)
        avg_ae_loss = loss_ae / (it + 1)
        ae_ppl = np.exp(x_e / (n_w + eps))
        avg_ae_ppl = np.exp(xe_loss_ae / (n_words_ae + eps))

        ae_optimizer.zero_grad()
        loss_rec.backward(retain_graph=not args.detach_classif)
        ae_optimizer.step()

        # Classifier
        if args.detach_classif:
            dis_lop = dis_model.forward(to_var(latent.clone()))
        else:
            dis_lop = dis_model.forward(latent)

        loss_dis = dis_criterion(dis_lop, tensor_labels)

        dis_optimizer.zero_grad()
        loss_dis.backward()
        dis_optimizer.step()

        t_c = tensor_labels.view(-1).size(0)
        n_v = (dis_lop.round().int() == tensor_labels).sum().item()
        loss_clf += loss_dis.item()
        total_clf += t_c
        n_valid_clf += n_v
        clf_acc = 100. * n_v / (t_c + eps)
        avg_clf_acc = 100. * n_valid_clf / (total_clf + eps)
        avg_clf_loss = loss_clf / (it + 1)
        if it % args.log_interval == 0:
            add_log(
                args, 'epoch {:3d} | {:5d}/{:5d} batches |'.format(
                    epoch, it, data_loader.num_batch))
            add_log(
                args,
                'Train : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                .format(ae_acc, loss_rec.item(), ae_ppl, clf_acc,
                        loss_dis.item()))
            add_log(
                args,
                'Train, avg : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} |  dis acc {:5.4f} | dis loss {:5.4f} |'
                .format(avg_ae_acc, avg_ae_loss, avg_ae_ppl, avg_clf_acc,
                        avg_clf_loss))
            if flag_rec:
                i = random.randint(0, len(tensor_tgt_y) - 1)
                reference = id2text_sentence(tensor_tgt_y[i], args.id_to_word)
                add_log(args, "input : %s" % reference)
                generator_text = ae_model.greedy_decode(
                    latent,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                # batch_sentences
                hypothesis = id2text_sentence(generator_text[i],
                                              args.id_to_word)
                add_log(args, "gen : %s" % hypothesis)
                add_log(
                    args, "bleu : %s" %
                    calc_bleu(reference.split(" "), hypothesis.split(" ")))

    s = {}
    L = data_loader.num_batch + eps
    s["train_ae_loss"] = loss_ae / L
    s["train_ae_acc"] = 100. * n_valid_ae / (n_words_ae + eps)
    s["train_ae_ppl"] = np.exp(xe_loss_ae / (n_words_ae + eps))
    s["train_clf_loss"] = loss_clf / L
    s["train_clf_acc"] = 100. * n_valid_clf / (total_clf + eps)

    add_log(
        args, '| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))

    add_log(
        args,
        '| rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
        .format(s["train_ae_acc"], s["train_ae_loss"], s["train_ae_ppl"],
                s["train_clf_acc"], s["train_clf_loss"]))
    return s
Example #13
def sedat_train(args, ae_model, f, deb):
    """
    Input: 
        Original latent representation z : (n_batch, batch_size, seq_length, latent_size)
    Output: 
        An optimal modified latent representation z'
    """
    # TODO: find a metric to monitor the evolution of training, mainly for the deb model
    lambda_ = args.sedat_threshold
    alpha, beta = [float(coef) for coef in args.sedat_alpha_beta.split(",")]
    # only on negative example
    only_on_negative_example = args.sedat_only_on_negative_example
    penalty = args.penalty
    type_penalty = args.type_penalty

    assert penalty in ["lasso", "ridge"]
    assert type_penalty in ["last", "group"]

    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.train_data_file]
    if os.path.exists(args.val_data_file):
        file_list.append(args.val_data_file)
    train_data_loader.create_batches(args,
                                     file_list,
                                     if_shuffle=True,
                                     n_samples=args.train_n_samples)

    add_log(args, "Start train process.")

    #add_log("Start train process.")
    ae_model.train()
    f.train()
    deb.train()

    ae_optimizer = get_optimizer(parameters=ae_model.parameters(),
                                 s=args.ae_optimizer,
                                 noamopt=args.ae_noamopt)
    dis_optimizer = get_optimizer(parameters=f.parameters(),
                                  s=args.dis_optimizer)
    deb_optimizer = get_optimizer(parameters=deb.parameters(),
                                  s=args.dis_optimizer)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1), args)
    dis_criterion = nn.BCELoss(reduction='mean')
    deb_criterion = LossSedat(penalty=penalty)

    stats = []
    for epoch in range(args.max_epochs):
        print('-' * 94)
        epoch_start_time = time.time()

        loss_ae, n_words_ae, xe_loss_ae, n_valid_ae = 0, 0, 0, 0
        loss_clf, total_clf, n_valid_clf = 0, 0, 0
        for it in range(train_data_loader.num_batch):
            _, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, _ = train_data_loader.next_batch()
            flag = True
            # only on negative example
            if only_on_negative_example:
                negative_examples = ~(tensor_labels.squeeze()
                                      == args.positive_label)
                tensor_labels = tensor_labels[negative_examples].squeeze(
                    0)  # .view(1, -1)
                tensor_src = tensor_src[negative_examples].squeeze(0)
                tensor_src_mask = tensor_src_mask[negative_examples].squeeze(0)
                tensor_src_attn_mask = tensor_src_attn_mask[
                    negative_examples].squeeze(0)
                tensor_tgt_y = tensor_tgt_y[negative_examples].squeeze(0)
                tensor_tgt = tensor_tgt[negative_examples].squeeze(0)
                tensor_tgt_mask = tensor_tgt_mask[negative_examples].squeeze(0)
                flag = negative_examples.any()
            if flag:
                # forward
                z, out, z_list = ae_model.forward(tensor_src,
                                                  tensor_tgt,
                                                  tensor_src_mask,
                                                  tensor_src_attn_mask,
                                                  tensor_tgt_mask,
                                                  return_intermediate=True)
                #y_hat = f.forward(to_var(z.clone()))
                y_hat = f.forward(z)

                loss_dis = dis_criterion(y_hat, tensor_labels)
                dis_optimizer.zero_grad()
                loss_dis.backward(retain_graph=True)
                dis_optimizer.step()

                dis_lop = f.forward(z)
                t_c = tensor_labels.view(-1).size(0)
                n_v = (dis_lop.round().int() == tensor_labels).sum().item()
                loss_clf += loss_dis.item()
                total_clf += t_c
                n_valid_clf += n_v
                clf_acc = 100. * n_v / (t_c + eps)
                avg_clf_acc = 100. * n_valid_clf / (total_clf + eps)
                avg_clf_loss = loss_clf / (it + 1)

                mask_deb = (y_hat.squeeze() >= lambda_
                            if args.positive_label == 0
                            else y_hat.squeeze() < lambda_)
                # if f(z) > lambda :
                if mask_deb.any():
                    y_hat_deb = y_hat[mask_deb]
                    if type_penalty == "last":
                        z_deb = z[mask_deb].squeeze(
                            0) if args.batch_size == 1 else z[mask_deb]
                    elif type_penalty == "group":
                        # TODO : unit test for bach_size = 1
                        z_deb = z_list[-1][mask_deb]
                    z_prime, z_prime_list = deb(z_deb,
                                                mask=None,
                                                return_intermediate=True)
                    if type_penalty == "last":
                        z_prime = torch.sum(ae_model.sigmoid(z_prime), dim=1)
                        loss_deb = alpha * deb_criterion(
                            z_deb, z_prime,
                            is_list=False) + beta * y_hat_deb.sum()
                    elif type_penalty == "group":
                        z_deb_list = [z_[mask_deb] for z_ in z_list]
                        #assert len(z_deb_list) == len(z_prime_list)
                        loss_deb = alpha * deb_criterion(
                            z_deb_list, z_prime_list,
                            is_list=True) + beta * y_hat_deb.sum()

                    deb_optimizer.zero_grad()
                    loss_deb.backward(retain_graph=True)
                    deb_optimizer.step()
                else:
                    loss_deb = torch.tensor(float("nan"))

                # else :
                if (~mask_deb).any():
                    out_ = out[~mask_deb]
                    tensor_tgt_y_ = tensor_tgt_y[~mask_deb]
                    tensor_ntokens = (tensor_tgt_y_ != 0).data.sum().float()
                    loss_rec = ae_criterion(
                        out_.contiguous().view(-1, out_.size(-1)),
                        tensor_tgt_y_.contiguous().view(-1)) / (
                            tensor_ntokens.data + eps)
                else:
                    loss_rec = torch.tensor(float("nan"))

                ae_optimizer.zero_grad()
                (loss_dis + loss_deb + loss_rec).backward()
                ae_optimizer.step()

                if True:
                    n_v, n_w = get_n_v_w(tensor_tgt_y, out)
                else:
                    n_w = float("nan")
                    n_v = float("nan")

                x_e = loss_rec.item() * n_w
                loss_ae += loss_rec.item()
                n_words_ae += n_w
                xe_loss_ae += x_e
                n_valid_ae += n_v
                ae_acc = 100. * n_v / (n_w + eps)
                avg_ae_acc = 100. * n_valid_ae / (n_words_ae + eps)
                avg_ae_loss = loss_ae / (it + 1)
                ae_ppl = np.exp(x_e / (n_w + eps))
                avg_ae_ppl = np.exp(xe_loss_ae / (n_words_ae + eps))

                if it % args.log_interval == 0:
                    add_log(args, "")
                    add_log(
                        args, 'epoch {:3d} | {:5d}/{:5d} batches |'.format(
                            epoch, it, train_data_loader.num_batch))
                    add_log(
                        args,
                        'Train : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                        .format(ae_acc, loss_rec.item(), ae_ppl, clf_acc,
                                loss_dis.item()))
                    add_log(
                        args,
                        'Train, avg : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                        .format(avg_ae_acc, avg_ae_loss, avg_ae_ppl,
                                avg_clf_acc, avg_clf_loss))

                    add_log(
                        args, "input : %s" %
                        id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                    generator_text = ae_model.greedy_decode(
                        z,
                        max_len=args.max_sequence_length,
                        start_id=args.id_bos)
                    # batch_sentences
                    add_log(
                        args, "gen : %s" %
                        id2text_sentence(generator_text[0], args.id_to_word))
                    if mask_deb.any():
                        generator_text_prime = ae_model.greedy_decode(
                            z_prime,
                            max_len=args.max_sequence_length,
                            start_id=args.id_bos)

                        add_log(
                            args, "deb : %s" % id2text_sentence(
                                generator_text_prime[0], args.id_to_word))

        s = {}
        L = train_data_loader.num_batch + eps
        s["train_ae_loss"] = loss_ae / L
        s["train_ae_acc"] = 100. * n_valid_ae / (n_words_ae + eps)
        s["train_ae_ppl"] = np.exp(xe_loss_ae / (n_words_ae + eps))
        s["train_clf_loss"] = loss_clf / L
        s["train_clf_acc"] = 100. * n_valid_clf / (total_clf + eps)
        stats.append(s)

        add_log(args, "")
        add_log(
            args, '| end of epoch {:3d} | time: {:5.2f}s |'.format(
                epoch, (time.time() - epoch_start_time)))

        add_log(
            args,
            '| rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
            .format(s["train_ae_acc"], s["train_ae_loss"], s["train_ae_ppl"],
                    s["train_clf_acc"], s["train_clf_loss"]))

        # Save model
        torch.save(
            ae_model.state_dict(),
            os.path.join(args.current_save_path, 'ae_model_params_deb.pkl'))
        torch.save(
            f.state_dict(),
            os.path.join(args.current_save_path, 'dis_model_params_deb.pkl'))
        torch.save(
            deb.state_dict(),
            os.path.join(args.current_save_path, 'deb_model_params_deb.pkl'))

    add_log(args, "Saving training statistics %s ..." % args.current_save_path)
    torch.save(stats,
               os.path.join(args.current_save_path, 'stats_train_deb.pkl'))
Example #14
def fgim_algorithm(args, ae_model, dis_model):
    batch_size = 1
    test_data_loader = non_pair_data_loader(
        batch_size=batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.test_data_file]
    test_data_loader.create_batches(args,
                                    file_list,
                                    if_shuffle=False,
                                    n_samples=args.test_n_samples)
    if args.references_files:
        gold_ans = load_human_answer(args.references_files, args.text_column)
        assert len(gold_ans) == test_data_loader.num_batch
    else:
        gold_ans = [[None] * batch_size] * test_data_loader.num_batch

    add_log(args, "Start eval process.")
    ae_model.eval()
    dis_model.eval()

    fgim_our = True
    if fgim_our:
        # for FGIM
        z_prime, text_z_prime = fgim(test_data_loader,
                                     args,
                                     ae_model,
                                     dis_model,
                                     gold_ans=gold_ans)
        write_text_z_in_file(args, text_z_prime)
        add_log(
            args,
            "Saving modified latent embeddings to %s ..." % args.current_save_path)
        torch.save(z_prime,
                   os.path.join(args.current_save_path, 'z_prime_fgim.pkl'))
    else:
        for it in range(test_data_loader.num_batch):
            batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = test_data_loader.next_batch()

            print("------------%d------------" % it)
            print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
            print("origin_labels", tensor_labels)

            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask,
                                           tensor_src_attn_mask,
                                           tensor_tgt_mask)
            generator_text = ae_model.greedy_decode(
                latent, max_len=args.max_sequence_length, start_id=args.id_bos)
            print(id2text_sentence(generator_text[0], args.id_to_word))

            # Define target label
            target = get_cuda(torch.tensor([[1.0]], dtype=torch.float), args)
            if tensor_labels[0].item() > 0.5:
                target = get_cuda(torch.tensor([[0.0]], dtype=torch.float),
                                  args)
            add_log(args, "target_labels : %s" % target)

            modify_text = fgim_attack(dis_model, latent, target, ae_model,
                                      args.max_sequence_length, args.id_bos,
                                      id2text_sentence, args.id_to_word,
                                      gold_ans[it])

            add_output(args, modify_text)