def eval_iters(ae_model, dis_model):
    """Run FGIM attribute transfer over the sentiment test split.

    Every test sentence (batch size 1) is encoded with ``ae_model`` and greedily
    decoded for reference. It is then edited in latent space by ``fgim_attack``
    towards the opposite label, as judged by ``dis_model``. The edited text is
    written with ``add_output``.

    Args:
        ae_model: Autoencoder exposing ``forward`` and ``greedy_decode``.
        dis_model: Latent classifier handed to ``fgim_attack``.
    """
    loader = non_pair_data_loader(
        batch_size=1,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    test_files = [args.data_path + 'sentiment.test.%d' % label
                  for label in (0, 1)]
    test_labels = [[0], [1]]
    loader.create_batches(test_files, test_labels, if_shuffle=False)

    # One human reference per test sentence, aligned with the (unshuffled) batches.
    references = load_human_answer(args.data_path)
    assert len(references) == loader.num_batch

    add_log("Start eval process.")
    for model in (ae_model, dis_model):
        model.eval()

    for idx in range(loader.num_batch):
        (_, labels, src, src_mask, tgt, tgt_y, tgt_mask,
         _) = loader.next_batch()

        print("------------%d------------" % idx)
        print(id2text_sentence(tgt_y[0], args.id_to_word))
        print("origin_labels", labels)

        latent, _ = ae_model.forward(src, tgt, src_mask, tgt_mask)
        decoded = ae_model.greedy_decode(latent,
                                         max_len=args.max_sequence_length,
                                         start_id=args.id_bos)
        print(id2text_sentence(decoded[0], args.id_to_word))

        # Aim for the opposite label: source > 0.5 -> 0.0, otherwise -> 1.0.
        target_value = 0.0 if labels[0].item() > 0.5 else 1.0
        target = get_cuda(torch.tensor([[target_value]], dtype=torch.float))
        print("target_labels", target)

        edited = fgim_attack(dis_model, latent, target, ae_model,
                             args.max_sequence_length, args.id_bos,
                             id2text_sentence, args.id_to_word,
                             references[idx])
        add_output(edited)
def train_iters(ae_model, dis_model, num_epochs=200, log_interval=200):
    """Train the autoencoder and the latent classifier.

    Each batch does two updates:

    1. ``ae_model`` is updated with a label-smoothed reconstruction loss,
       normalised by the number of target tokens.
    2. ``dis_model`` is updated with a BCE loss computed on a copy of the latent
       produced by the same forward pass.

    Both models are saved to ``args.current_save_path`` at the end of every
    epoch, overwriting the previous checkpoint.

    Args:
        ae_model: Autoencoder exposing ``forward``/``greedy_decode`` and
            ``src_embed[0].d_model`` (used for the Noam schedule).
        dis_model: Classifier applied to the latent.
        num_epochs: Number of passes over the training data (default 200,
            the previously hard-coded value).
        log_interval: Log the losses and a sample reconstruction every this
            many batches (default 200, the previously hard-coded value).
    """
    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    train_data_loader.create_batches(args.train_file_list,
                                     args.train_label_list,
                                     if_shuffle=True)
    add_log("Start train process.")
    ae_model.train()
    dis_model.train()

    # NoamOpt(d_model, factor=1, warmup=2000). The inner Adam starts at lr=0,
    # presumably because NoamOpt sets the rate on every step.
    ae_optimizer = NoamOpt(
        ae_model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1))
    # `size_average=True` is deprecated; reduction="mean" is its exact equivalent.
    dis_criterion = nn.BCELoss(reduction="mean")

    for epoch in range(num_epochs):
        print('-' * 94)
        epoch_start_time = time.time()
        for it in range(train_data_loader.num_batch):
            (_, tensor_labels, tensor_src, tensor_src_mask, tensor_tgt,
             tensor_tgt_y, tensor_tgt_mask,
             tensor_ntokens) = train_data_loader.next_batch()

            # Autoencoder step.
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            # Classifier step.
            # NOTE(review): assumes `to_var` returns a graph-detached copy; the
            # autoencoder graph was already freed by loss_rec.backward().
            dis_lop = dis_model.forward(to_var(latent.clone()))
            loss_dis = dis_criterion(dis_lop, tensor_labels)

            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            if it % log_interval == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, train_data_loader.num_batch,
                            loss_rec.item(), loss_dis.item()))

                print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                generator_text = ae_model.greedy_decode(
                    latent,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                print(id2text_sentence(generator_text[0], args.id_to_word))

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save model (overwritten each epoch).
        torch.save(ae_model.state_dict(),
                   args.current_save_path + 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path + 'dis_model_params.pkl')
    return
Esempio n. 3
0
            # backward + optimization
            ae_optimizer.zero_grad()
            Loss_rec.backward()
            ae_optimizer.step()

        print('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))

    return


if __name__ == '__main__':
    preparation()

    train_data_loader = non_pair_data_loader(args.batch_size)
    train_data_loader.create_batches(args.train_file_list,
                                     args.train_label_list,
                                     if_shuffle=True)

    # create models
    ae_model = get_cuda(
        EncoderDecoder(
            vocab_size=args.vocab_size,
            embedding_size=args.embedding_size,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers_AE,
            word_dropout=args.word_dropout,
            embedding_dropout=args.embedding_dropout,
            sos_idx=args.id_bos,
            eos_idx=args.id_eos,
    dis_model=get_cuda(Classifier(1, args),args.gpu)
    ae_optimizer = NoamOpt(ae_model.src_embed[0].d_model, 1, 2000,
                           torch.optim.Adam(ae_model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    dis_optimizer=torch.optim.Adam(dis_model.parameters(), lr=0.0001)
    if args.load_model:
        # Load models' params from checkpoint
        ae_model.load_state_dict(torch.load(args.current_save_path + '/{}_ae_model_params.pkl'.format(args.load_iter), map_location=device))
        dis_model.load_state_dict(torch.load(args.current_save_path + '/{}_dis_model_params.pkl'.format(args.load_iter), map_location=device))
       
        start=args.load_iter+1 
    else:
        start=0        
    
    train_data_loader=non_pair_data_loader(
        batch_size=args.batch_size, id_bos=args.id_bos,
        id_eos=args.id_eos, id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length, vocab_size=args.vocab_size,
        gpu=args.gpu
    )

    train_data_loader.create_batches(args.train_file_list, args.train_label_list, if_shuffle=True)
    
    eval_data_loader=non_pair_data_loader(
#        batch_size=args.batch_size, id_bos=args.id_bos,
        batch_size=200, id_bos=args.id_bos,
        id_eos=args.id_eos, id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length, vocab_size=args.vocab_size,
        gpu=args.gpu
    )

    eval_file_list=[
        args.data_path+'sentiment.dev.0',
Esempio n. 5
0
def generation(ae_model, sm, args):
    """Decode style-shifted variants of every test sentence.

    For each sentence (batch size 1), the latent from ``ae_model.getLatent`` is
    shifted by ``sign * args.weight`` times one of two vectors:

    - "add": the other-label style embedding plus the own-label style embedding;
    - "own": the own-label style embedding alone.

    ``sign`` is ``+1`` for label 1 and ``-1`` for label 0. Both decodes are
    appended to files under ``./generation/<args.name>/``. The "own" variant is
    also printed next to the original sentence.

    Args:
        ae_model: Model exposing ``getLatent``, ``getSim`` and ``greedy_decode``.
        sm: Tokenizer exposing ``DecodeIds``.
        args: Run configuration (data path, vocab ids, ``weight``, ``name``, ``gpu``).
    """
    loader = non_pair_data_loader(
        batch_size=1,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size,
        gpu=args.gpu)
    loader.create_batches(
        [args.data_path + 'sentiment.test.0',
         args.data_path + 'sentiment.test.1'],
        [[0], [1]],
        if_shuffle=False)
    ae_model.eval()

    label_names = {0: 'chatbot', 1: 'bible'}
    weight = args.weight
    add_path = './generation/{}/sign_{}_add.txt'.format(args.name, args.weight)
    own_path = './generation/{}/sign_{}_own.txt'.format(args.name, args.weight)

    for step in range(loader.num_batch):
        _, labels, src, src_mask, _, tgt_y, _, _ = loader.next_batch()

        latent = ae_model.getLatent(src, src_mask)
        style, _ = ae_model.getSim(latent)
        own_label = labels.long()
        sign = 2 * own_label - 1
        t_sign = 2 * (1 - own_label) - 1

        # `.item()` on the label index requires a single example per batch.
        rows = torch.arange(style.size(0))
        trans_emb = style.clone()[rows, (1 - labels).long().item()]
        own_emb = style.clone()[rows, own_label.item()]

        shifted_add = latent + sign * weight * (trans_emb + own_emb)
        decoded_add = ae_model.greedy_decode(shifted_add,
                                             args.max_sequence_length,
                                             args.id_bos)
        text_add = sm.DecodeIds(id2list(decoded_add[0].tolist()))
        add_output(text_add, add_path)

        shifted_own = latent + sign * weight * own_emb
        decoded_own = ae_model.greedy_decode(shifted_own,
                                             args.max_sequence_length,
                                             args.id_bos)
        text_own = sm.DecodeIds(id2list(decoded_own[0].tolist()))
        add_output(text_own, own_path)

        source_name = label_names[labels.item()]
        target_name = label_names[1 - labels.item()]
        print("------------%d------------" % step)
        print('original {}:'.format(source_name),
              sm.DecodeIds(tgt_y[0].tolist()))
        print('s:{} w:{} {}:'.format(t_sign.item(), args.weight, target_name),
              text_own)
Esempio n. 6
0
def sedat_eval(args, ae_model, f, deb):
    """Apply the trained SEDAT debiaser to the test split and save the results.

    From each batch, only examples whose label differs from
    ``args.positive_label`` are kept. For every batch with at least one such
    example, four things are collected:

    - the source text;
    - the plain autoencoder reconstruction ("before");
    - the reconstruction from the debiased latent ("after");
    - the classifier's rounded prediction on that latent.

    The pass runs under ``torch.no_grad()``. Nothing is trained here, and
    without it each latent appended to ``z_prime`` would keep its autograd
    graph (through ``deb``) alive, so memory grew with every batch.

    Args:
        args: Run configuration (data files, vocabulary ids, batch limit, ...).
        ae_model: Autoencoder exposing ``forward``, ``greedy_decode`` and ``sigmoid``.
        f: Latent classifier.
        deb: Debiasing network applied to the autoencoder latent.

    Returns:
        ``(z_prime, text_z_prime)``:

        - ``z_prime``: the list of per-batch modified latents, also saved to
          ``z_prime_sedat.pkl``.
        - ``text_z_prime``: a dict of per-batch text/label lists, also written
          with ``write_text_z_in_file``.
    """
    max_sequence_length = args.max_sequence_length
    id_bos = args.id_bos
    id_to_word = args.id_to_word
    limit_batches = args.limit_batches

    eval_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.test_data_file]
    eval_data_loader.create_batches(args,
                                    file_list,
                                    if_shuffle=False,
                                    n_samples=args.test_n_samples)
    if args.references_files:
        gold_ans = load_human_answer(args.references_files, args.text_column)
        assert len(gold_ans) == eval_data_loader.num_batch
    else:
        gold_ans = None

    add_log(args, "Start eval process.")
    ae_model.eval()
    f.eval()
    deb.eval()

    text_z_prime = {
        "source": [],
        "origin_labels": [],
        "before": [],
        "after": [],
        "change": [],
        "pred_label": []
    }
    if gold_ans is not None:
        text_z_prime["gold_ans"] = []
    z_prime = []
    n_batches = 0
    with torch.no_grad():
        for it in tqdm.tqdm(list(range(eval_data_loader.num_batch)),
                            desc="SEDAT"):
            (_, tensor_labels, tensor_src, tensor_src_mask,
             tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, tensor_tgt_mask,
             _) = eval_data_loader.next_batch()

            # Only negative examples (label != positive_label) are processed.
            negative_examples = ~(tensor_labels.squeeze() == args.positive_label)
            if not negative_examples.any():
                continue
            # `.squeeze(0)` drops the batch dim when a single row survives.
            tensor_labels = tensor_labels[negative_examples].squeeze(0)
            tensor_src = tensor_src[negative_examples].squeeze(0)
            tensor_src_mask = tensor_src_mask[negative_examples].squeeze(0)
            tensor_src_attn_mask = tensor_src_attn_mask[
                negative_examples].squeeze(0)
            tensor_tgt_y = tensor_tgt_y[negative_examples].squeeze(0)
            tensor_tgt = tensor_tgt[negative_examples].squeeze(0)
            tensor_tgt_mask = tensor_tgt_mask[negative_examples].squeeze(0)

            if gold_ans is not None:
                text_z_prime["gold_ans"].append(gold_ans[it])

            text_z_prime["source"].append(
                [id2text_sentence(t, id_to_word) for t in tensor_tgt_y])
            text_z_prime["origin_labels"].append(tensor_labels.cpu().numpy())

            origin_data, _ = ae_model.forward(tensor_src, tensor_tgt,
                                              tensor_src_mask,
                                              tensor_src_attn_mask,
                                              tensor_tgt_mask)

            generator_id = ae_model.greedy_decode(origin_data,
                                                  max_len=max_sequence_length,
                                                  start_id=id_bos)
            text_z_prime["before"].append(
                [id2text_sentence(gid, id_to_word) for gid in generator_id])

            # Debiased latent: deb output squashed by sigmoid and summed over dim 1.
            data = deb(origin_data, mask=None)
            data = torch.sum(ae_model.sigmoid(data), dim=1)
            y_hat = f.forward(data).round().int()
            z_prime.append(data)

            generator_id = ae_model.greedy_decode(data,
                                                  max_len=max_sequence_length,
                                                  start_id=id_bos)
            text_z_prime["after"].append(
                [id2text_sentence(gid, id_to_word) for gid in generator_id])
            text_z_prime["change"].append([True] * len(y_hat))
            text_z_prime["pred_label"].append([y_.item() for y_ in y_hat])

            n_batches += 1
            # NOTE(review): `>` lets limit_batches + 1 batches through; confirm
            # whether `>=` was intended.
            if n_batches > limit_batches:
                break
    write_text_z_in_file(args, text_z_prime)
    add_log(args, "")
    add_log(args,
            "Saving model modify embedding %s ..." % args.current_save_path)
    torch.save(z_prime,
               os.path.join(args.current_save_path, 'z_prime_sedat.pkl'))
    return z_prime, text_z_prime
Esempio n. 7
0
def sedat_train(args, ae_model, f, deb):
    """Jointly train the autoencoder, the latent classifier ``f`` and the debiaser ``deb``.

    Each batch is optionally restricted to examples whose label differs from
    ``args.positive_label``. It then goes through three steps:

    1. ``f`` is updated with a BCE loss on the autoencoder latent ``z``.
    2. Rows selected by the threshold test are routed to ``deb``. The test is
       ``f(z) >= args.sedat_threshold`` when ``positive_label == 0``, and
       ``f(z) < threshold`` otherwise. ``deb`` is updated with
       ``alpha * LossSedat(z, z') + beta * sum(f(z))``.
    3. The remaining rows give a label-smoothed reconstruction loss, and
       ``ae_model`` is updated with the sum of the three losses.

    The three models are saved after every epoch. The per-epoch statistics are
    saved to ``stats_train_deb.pkl`` at the end.

    Args:
        args: Run configuration (data files, SEDAT coefficients, optimizers, ...).
        ae_model: Autoencoder exposing ``forward(..., return_intermediate=True)``,
            ``greedy_decode`` and ``sigmoid``.
        f: Latent classifier.
        deb: Debiasing network.
    """
    # TODO: find a metric to control the evolution of training, mainly for the deb model.
    lambda_ = args.sedat_threshold
    alpha, beta = [float(coef) for coef in args.sedat_alpha_beta.split(",")]
    only_on_negative_example = args.sedat_only_on_negative_example
    penalty = args.penalty
    type_penalty = args.type_penalty

    assert penalty in ["lasso", "ridge"]
    assert type_penalty in ["last", "group"]

    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.train_data_file]
    # The validation file, when present, is folded into the training data.
    if os.path.exists(args.val_data_file):
        file_list.append(args.val_data_file)
    train_data_loader.create_batches(args,
                                     file_list,
                                     if_shuffle=True,
                                     n_samples=args.train_n_samples)

    add_log(args, "Start train process.")
    ae_model.train()
    f.train()
    deb.train()

    ae_optimizer = get_optimizer(parameters=ae_model.parameters(),
                                 s=args.ae_optimizer,
                                 noamopt=args.ae_noamopt)
    dis_optimizer = get_optimizer(parameters=f.parameters(),
                                  s=args.dis_optimizer)
    # The debiaser reuses the classifier's optimizer setting.
    deb_optimizer = get_optimizer(parameters=deb.parameters(),
                                  s=args.dis_optimizer)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1), args)
    # `size_average=True` is deprecated; reduction="mean" is its exact equivalent.
    dis_criterion = nn.BCELoss(reduction="mean")
    deb_criterion = LossSedat(penalty=penalty)

    stats = []
    for epoch in range(args.max_epochs):
        print('-' * 94)
        epoch_start_time = time.time()

        # Epoch accumulators. `xe_words_ae` counts only the words behind
        # `xe_loss_ae`, i.e. batches that actually produced a reconstruction loss.
        loss_ae, n_words_ae, xe_loss_ae, n_valid_ae, xe_words_ae = 0, 0, 0, 0, 0
        loss_clf, total_clf, n_valid_clf = 0, 0, 0
        for it in range(train_data_loader.num_batch):
            (_, tensor_labels, tensor_src, tensor_src_mask,
             tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, tensor_tgt_mask,
             _) = train_data_loader.next_batch()

            if only_on_negative_example:
                # Keep only rows whose label differs from positive_label.
                negative_examples = ~(tensor_labels.squeeze()
                                      == args.positive_label)
                if not negative_examples.any():
                    continue
                # `.squeeze(0)` drops the batch dim when a single row survives.
                tensor_labels = tensor_labels[negative_examples].squeeze(0)
                tensor_src = tensor_src[negative_examples].squeeze(0)
                tensor_src_mask = tensor_src_mask[negative_examples].squeeze(0)
                tensor_src_attn_mask = tensor_src_attn_mask[
                    negative_examples].squeeze(0)
                tensor_tgt_y = tensor_tgt_y[negative_examples].squeeze(0)
                tensor_tgt = tensor_tgt[negative_examples].squeeze(0)
                tensor_tgt_mask = tensor_tgt_mask[negative_examples].squeeze(0)

            z, out, z_list = ae_model.forward(tensor_src,
                                              tensor_tgt,
                                              tensor_src_mask,
                                              tensor_src_attn_mask,
                                              tensor_tgt_mask,
                                              return_intermediate=True)
            y_hat = f.forward(z)

            # 1) Classifier step. The graph is retained because loss_dis is
            # reused in the joint autoencoder loss below.
            # NOTE(review): dis/deb optimizer steps update parameters in place
            # while their graphs are still reused by the final backward. Recent
            # PyTorch versions may raise "modified by an inplace operation";
            # confirm against the PyTorch version in use.
            loss_dis = dis_criterion(y_hat, tensor_labels)
            dis_optimizer.zero_grad()
            loss_dis.backward(retain_graph=True)
            dis_optimizer.step()

            dis_lop = f.forward(z)
            t_c = tensor_labels.view(-1).size(0)
            n_v = (dis_lop.round().int() == tensor_labels).sum().item()
            loss_clf += loss_dis.item()
            total_clf += t_c
            n_valid_clf += n_v
            clf_acc = 100. * n_v / (t_c + eps)
            avg_clf_acc = 100. * n_valid_clf / (total_clf + eps)
            avg_clf_loss = loss_clf / (it + 1)

            # 2) Debiaser step on the rows selected by the threshold test.
            mask_deb = (y_hat.squeeze() >= lambda_ if args.positive_label == 0
                        else y_hat.squeeze() < lambda_)
            if mask_deb.any():
                y_hat_deb = y_hat[mask_deb]
                if type_penalty == "last":
                    z_deb = (z[mask_deb].squeeze(0)
                             if args.batch_size == 1 else z[mask_deb])
                elif type_penalty == "group":
                    # TODO: unit test for batch_size == 1
                    z_deb = z_list[-1][mask_deb]
                z_prime, z_prime_list = deb(z_deb,
                                            mask=None,
                                            return_intermediate=True)
                if type_penalty == "last":
                    z_prime = torch.sum(ae_model.sigmoid(z_prime), dim=1)
                    loss_deb = alpha * deb_criterion(
                        z_deb, z_prime, is_list=False) + beta * y_hat_deb.sum()
                elif type_penalty == "group":
                    z_deb_list = [z_[mask_deb] for z_ in z_list]
                    loss_deb = alpha * deb_criterion(
                        z_deb_list, z_prime_list,
                        is_list=True) + beta * y_hat_deb.sum()

                deb_optimizer.zero_grad()
                loss_deb.backward(retain_graph=True)
                deb_optimizer.step()
            else:
                loss_deb = torch.tensor(float("nan"))

            # 3) Reconstruction loss on the rows not sent to the debiaser.
            keep_rec = ~mask_deb
            if keep_rec.any():
                out_ = out[keep_rec]
                tensor_tgt_y_ = tensor_tgt_y[keep_rec]
                tensor_ntokens = (tensor_tgt_y_ != 0).data.sum().float()
                loss_rec = ae_criterion(
                    out_.contiguous().view(-1, out_.size(-1)),
                    tensor_tgt_y_.contiguous().view(-1)) / (
                        tensor_ntokens.data + eps)
            else:
                loss_rec = torch.tensor(float("nan"))

            ae_optimizer.zero_grad()
            (loss_dis + loss_deb + loss_rec).backward()
            ae_optimizer.step()

            # Autoencoder statistics, accumulated exactly once per batch.
            n_v, n_w = get_n_v_w(tensor_tgt_y, out)
            rec_loss_value = loss_rec.item()
            x_e = rec_loss_value * n_w
            n_words_ae += n_w
            n_valid_ae += n_v
            # A NaN placeholder (no reconstruction rows) must not poison the
            # epoch loss / perplexity.
            if not np.isnan(rec_loss_value):
                loss_ae += rec_loss_value
                xe_loss_ae += x_e
                xe_words_ae += n_w
            ae_acc = 100. * n_v / (n_w + eps)
            avg_ae_acc = 100. * n_valid_ae / (n_words_ae + eps)
            avg_ae_loss = loss_ae / (it + 1)
            ae_ppl = np.exp(x_e / (n_w + eps))
            avg_ae_ppl = np.exp(xe_loss_ae / (xe_words_ae + eps))

            if it % args.log_interval == 0:
                add_log(args, "")
                add_log(
                    args, 'epoch {:3d} | {:5d}/{:5d} batches |'.format(
                        epoch, it, train_data_loader.num_batch))
                add_log(
                    args,
                    'Train : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                    .format(ae_acc, rec_loss_value, ae_ppl, clf_acc,
                            loss_dis.item()))
                add_log(
                    args,
                    'Train : avg : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} |  dis acc {:5.4f} | diss loss {:5.4f} |'
                    .format(avg_ae_acc, avg_ae_loss, avg_ae_ppl,
                            avg_clf_acc, avg_clf_loss))

                add_log(
                    args, "input : %s" %
                    id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                generator_text = ae_model.greedy_decode(
                    z,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                add_log(
                    args, "gen : %s" %
                    id2text_sentence(generator_text[0], args.id_to_word))
                if mask_deb.any():
                    generator_text_prime = ae_model.greedy_decode(
                        z_prime,
                        max_len=args.max_sequence_length,
                        start_id=args.id_bos)
                    add_log(
                        args, "deb : %s" % id2text_sentence(
                            generator_text_prime[0], args.id_to_word))

        s = {}
        L = train_data_loader.num_batch + eps
        s["train_ae_loss"] = loss_ae / L
        s["train_ae_acc"] = 100. * n_valid_ae / (n_words_ae + eps)
        s["train_ae_ppl"] = np.exp(xe_loss_ae / (xe_words_ae + eps))
        s["train_clf_loss"] = loss_clf / L
        s["train_clf_acc"] = 100. * n_valid_clf / (total_clf + eps)
        stats.append(s)

        add_log(args, "")
        add_log(
            args, '| end of epoch {:3d} | time: {:5.2f}s |'.format(
                epoch, (time.time() - epoch_start_time)))

        add_log(
            args,
            '| rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
            .format(s["train_ae_acc"], s["train_ae_loss"], s["train_ae_ppl"],
                    s["train_clf_acc"], s["train_clf_loss"]))

        # Save models (overwritten each epoch).
        torch.save(
            ae_model.state_dict(),
            os.path.join(args.current_save_path, 'ae_model_params_deb.pkl'))
        torch.save(
            f.state_dict(),
            os.path.join(args.current_save_path, 'dis_model_params_deb.pkl'))
        torch.save(
            deb.state_dict(),
            os.path.join(args.current_save_path, 'deb_model_params_deb.pkl'))

    add_log(args, "Saving training statistics %s ..." % args.current_save_path)
    torch.save(stats,
               os.path.join(args.current_save_path, 'stats_train_deb.pkl'))
Esempio n. 8
0
def fgim_algorithm(args, ae_model, dis_model):
    """Run FGIM latent editing over the test split and save the results.

    With the default path (``fgim_our``), the whole loader is handed to
    ``fgim``. The decoded texts are written with ``write_text_z_in_file`` and
    the modified latents are saved to ``z_prime_fgim.pkl``. The disabled legacy
    path instead calls ``fgim_attack`` sentence by sentence.

    Args:
        args: Run configuration (test file, vocabulary ids, references, ...).
        ae_model: Autoencoder exposing ``forward`` and ``greedy_decode``.
        dis_model: Latent classifier guiding the edit.
    """
    batch_size = 1
    test_data_loader = non_pair_data_loader(
        batch_size=batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.test_data_file]
    test_data_loader.create_batches(args,
                                    file_list,
                                    if_shuffle=False,
                                    n_samples=args.test_n_samples)
    if args.references_files:
        gold_ans = load_human_answer(args.references_files, args.text_column)
        assert len(gold_ans) == test_data_loader.num_batch
    else:
        # One independent placeholder list per batch; `[[None] * b] * n`
        # would make every entry alias the same inner list.
        gold_ans = [[None] * batch_size
                    for _ in range(test_data_loader.num_batch)]

    add_log(args, "Start eval process.")
    ae_model.eval()
    dis_model.eval()

    fgim_our = True
    if fgim_our:
        z_prime, text_z_prime = fgim(test_data_loader,
                                     args,
                                     ae_model,
                                     dis_model,
                                     gold_ans=gold_ans)
        write_text_z_in_file(args, text_z_prime)
        add_log(
            args,
            "Saving model modify embedding %s ..." % args.current_save_path)
        torch.save(z_prime,
                   os.path.join(args.current_save_path, 'z_prime_fgim.pkl'))
    else:
        # Legacy per-sentence path (currently disabled by `fgim_our`).
        for it in range(test_data_loader.num_batch):
            (_, tensor_labels, tensor_src, tensor_src_mask,
             tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, tensor_tgt_mask,
             _) = test_data_loader.next_batch()

            print("------------%d------------" % it)
            print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
            print("origin_labels", tensor_labels)

            latent, _ = ae_model.forward(tensor_src, tensor_tgt,
                                         tensor_src_mask,
                                         tensor_src_attn_mask,
                                         tensor_tgt_mask)
            generator_text = ae_model.greedy_decode(
                latent, max_len=args.max_sequence_length, start_id=args.id_bos)
            print(id2text_sentence(generator_text[0], args.id_to_word))

            # Target the opposite label: source > 0.5 -> 0.0, otherwise -> 1.0.
            target_value = 0.0 if tensor_labels[0].item() > 0.5 else 1.0
            target = get_cuda(
                torch.tensor([[target_value]], dtype=torch.float), args)
            add_log(args, "target_labels : %s" % target)

            modify_text = fgim_attack(dis_model, latent, target, ae_model,
                                      args.max_sequence_length, args.id_bos,
                                      id2text_sentence, args.id_to_word,
                                      gold_ans[it])

            add_output(args, modify_text)
Esempio n. 9
0
def pretrain(args, ae_model, dis_model):
    """Pre-train the autoencoder and classifier with early stopping.

    After each epoch, ``train_step`` and ``eval_step`` are run. The models are
    saved whenever the configured validation metric improves. Training stops
    once it has not improved for more than ``decrease_counts_max`` epochs.

    After training, the best saved checkpoint is restored (if one was saved).
    This makes the optional test score, and the models left in memory for the
    caller, match the weights on disk rather than the last epoch's.

    Args:
        args: Run configuration (data files, optimizers, stopping criterion, ...).
        ae_model: Autoencoder.
        dis_model: Latent classifier.

    Returns:
        ``(stats, s_test)``:

        - ``stats``: per-epoch train/eval score dicts.
        - ``s_test``: test scores, or ``None`` when ``args.test_data_file``
          does not exist.
    """
    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    train_data_loader.create_batches(args, [args.train_data_file],
                                     if_shuffle=True,
                                     n_samples=args.train_n_samples)

    val_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    val_data_loader.create_batches(args, [args.val_data_file],
                                   if_shuffle=True,
                                   n_samples=args.valid_n_samples)

    ae_model.train()
    dis_model.train()

    ae_optimizer = get_optimizer(parameters=ae_model.parameters(),
                                 s=args.ae_optimizer,
                                 noamopt=args.ae_noamopt)
    dis_optimizer = get_optimizer(parameters=dis_model.parameters(),
                                  s=args.dis_optimizer)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1), args)
    # `size_average=True` is deprecated; reduction="mean" is its exact equivalent.
    dis_criterion = nn.BCELoss(reduction="mean")

    # Metric names the stopping criterion may refer to, e.g. "eval_ae_loss".
    possib = [
        "%s_%s" % (i, j) for i, j in itertools.product(
            ["train", "eval"],
            ["ae_loss", "ae_acc", "ae_ppl", "clf_loss", "clf_acc"])
    ]
    stopping_criterion, best_criterion, decrease_counts, decrease_counts_max = settings(
        args, possib)
    metric, biggest = stopping_criterion
    # Flip the sign so that "larger is better" holds for both directions.
    factor = 1 if biggest else -1

    ae_path = os.path.join(args.current_save_path, 'ae_model_params.pkl')
    dis_path = os.path.join(args.current_save_path, 'dis_model_params.pkl')
    saved_best = False

    stats = []

    add_log(args, "Start train process.")
    for epoch in range(args.max_epochs):
        print('-' * 94)
        add_log(args, "")
        s_train = train_step(args, train_data_loader, ae_model, dis_model,
                             ae_optimizer, dis_optimizer, ae_criterion,
                             dis_criterion, epoch)
        add_log(args, "")
        s_eval = eval_step(args, val_data_loader, ae_model, dis_model,
                           ae_criterion, dis_criterion)
        scores = {**s_train, **s_eval}
        stats.append(scores)
        add_log(args, "")
        if factor * scores[metric] > factor * best_criterion:
            best_criterion = scores[metric]
            add_log(args, "New best validation score: %f" % best_criterion)
            decrease_counts = 0
            add_log(args, "Saving model to %s ..." % args.current_save_path)
            torch.save(ae_model.state_dict(), ae_path)
            torch.save(dis_model.state_dict(), dis_path)
            saved_best = True
        else:
            add_log(
                args, "Not a better validation score (%i / %i)." %
                (decrease_counts, decrease_counts_max))
            decrease_counts += 1
        if decrease_counts > decrease_counts_max:
            add_log(
                args,
                "Stopping criterion has been below its best value for more "
                "than %i epochs. Ending the experiment..." %
                decrease_counts_max)
            break

    if saved_best:
        # Restore the best checkpoint; the last epoch may be worse than it.
        # map_location="cpu" is safe: load_state_dict copies into the
        # parameters on whatever device they already live on.
        add_log(args, "Restoring best model from %s ..." % args.current_save_path)
        ae_model.load_state_dict(torch.load(ae_path, map_location="cpu"))
        dis_model.load_state_dict(torch.load(dis_path, map_location="cpu"))

    s_test = None
    if os.path.exists(args.test_data_file):
        add_log(args, "")
        test_data_loader = non_pair_data_loader(
            batch_size=args.batch_size,
            id_bos=args.id_bos,
            id_eos=args.id_eos,
            id_unk=args.id_unk,
            max_sequence_length=args.max_sequence_length,
            vocab_size=args.vocab_size)
        test_data_loader.create_batches(args, [args.test_data_file],
                                        if_shuffle=True,
                                        n_samples=args.test_n_samples)
        s = eval_step(args, test_data_loader, ae_model, dis_model,
                      ae_criterion, dis_criterion)
        add_log(
            args,
            'Test | rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
            .format(s["eval_ae_acc"], s["eval_ae_loss"], s["eval_ae_ppl"],
                    s["eval_clf_acc"], s["eval_clf_loss"]))
        s_test = s
    add_log(args, "")
    add_log(args, "Saving training statistics %s ..." % args.current_save_path)
    torch.save(stats,
               os.path.join(args.current_save_path, 'stats_train_eval.pkl'))
    if s_test is not None:
        torch.save(s_test, os.path.join(args.current_save_path,
                                        'stat_test.pkl'))
    return stats, s_test