Example 1
    def create_batches(self, input_tokens, input_label):
        self.data_label_pairs = [[input_tokens, [input_label]]]

        # Split batches
        if self.batch_size is None:
            self.batch_size = len(self.data_label_pairs)
        self.num_batch = len(self.data_label_pairs) // self.batch_size
        for _index in range(self.num_batch):
            item_data_label_pairs = self.data_label_pairs[_index*self.batch_size:(_index+1)*self.batch_size]
            item_sentences = [_i[0] for _i in item_data_label_pairs]
            item_labels = [_i[1] for _i in item_data_label_pairs]

            (batch_encoder_input, batch_decoder_input, batch_decoder_target,
             batch_encoder_length, batch_decoder_length) = pad_batch_sequences(
                 item_sentences, self.id_bos, self.id_eos, self.id_unk,
                 self.max_sequence_length, self.vocab_size)

            src = get_cuda(torch.tensor(batch_encoder_input, dtype=torch.long))
            tgt = get_cuda(torch.tensor(batch_decoder_input, dtype=torch.long))
            tgt_y = get_cuda(torch.tensor(batch_decoder_target, dtype=torch.long))

            src_mask = (src != 0).unsqueeze(-2)
            tgt_mask = self.make_std_mask(tgt, 0)
            ntokens = (tgt_y != 0).data.sum().float()

            self.sentences_batches.append(item_sentences)
            self.labels_batches.append(get_cuda(torch.tensor(item_labels, dtype=torch.float)))
            self.src_batches.append(src)
            self.tgt_batches.append(tgt)
            self.tgt_y_batches.append(tgt_y)
            self.src_mask_batches.append(src_mask)
            self.tgt_mask_batches.append(tgt_mask)
            self.ntokens_batches.append(ntokens)

        self.pointer = 0
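
Note: pad_batch_sequences is defined elsewhere in this repo. A minimal sketch of the contract the loader above assumes is shown below; the exact padding scheme is an illustrative assumption, not the repo's implementation.

# Hypothetical sketch of the pad_batch_sequences contract used above:
# out-of-vocabulary ids map to id_unk, encoder input is right-padded with 0,
# decoder input is prefixed with id_bos, decoder target is suffixed with id_eos.
def pad_batch_sequences(sentences, id_bos, id_eos, id_unk, max_len, vocab_size):
    enc_in, dec_in, dec_tgt, enc_len, dec_len = [], [], [], [], []
    for sent in sentences:
        ids = [t if t < vocab_size else id_unk for t in sent][:max_len - 1]
        enc_len.append(len(ids))
        dec_len.append(len(ids) + 1)
        pad = [0] * (max_len - len(ids))
        enc_in.append(ids + pad)
        dec_in.append([id_bos] + ids + pad[:-1])
        dec_tgt.append(ids + [id_eos] + pad[:-1])
    return enc_in, dec_in, dec_tgt, enc_len, dec_len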
def eval_iters(ae_model, dis_model):
    eval_data_loader = non_pair_data_loader(
        batch_size=1,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    eval_file_list = [
        args.data_path + 'sentiment.test.0',
        args.data_path + 'sentiment.test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]
    eval_data_loader.create_batches(eval_file_list,
                                    eval_label_list,
                                    if_shuffle=False)
    gold_ans = load_human_answer(args.data_path)
    assert len(gold_ans) == eval_data_loader.num_batch

    add_log("Start eval process.")
    ae_model.eval()
    dis_model.eval()
    for it in range(eval_data_loader.num_batch):
        batch_sentences, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()

        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
        print("origin_labels", tensor_labels)

        latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                       tensor_tgt_mask)
        generator_text = ae_model.greedy_decode(
            latent, max_len=args.max_sequence_length, start_id=args.id_bos)
        print(id2text_sentence(generator_text[0], args.id_to_word))

        # Define target label
        target = get_cuda(torch.tensor([[1.0]], dtype=torch.float))
        if tensor_labels[0].item() > 0.5:
            target = get_cuda(torch.tensor([[0.0]], dtype=torch.float))
        print("target_labels", target)

        modify_text = fgim_attack(dis_model, latent, target, ae_model,
                                  args.max_sequence_length, args.id_bos,
                                  id2text_sentence, args.id_to_word,
                                  gold_ans[it])
        add_output(modify_text)
        output_text = str(it) + ":\ngold: " + id2text_sentence(
            gold_ans[it], args.id_to_word) + "\nmodified: " + modify_text
        add_output(output_text)
        add_result(
            str(it) + ":\n" + str(
                calc_bleu(id2text_sentence(gold_ans[it], args.id_to_word),
                          modify_text)))
    return
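
calc_bleu is imported from the repo's utilities. A plausible NLTK-based stand-in is sketched below, assuming the plain-string inputs passed above; the 0-100 scaling and smoothing method are assumptions.

# Hypothetical stand-in for the calc_bleu helper called above.
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def calc_bleu(reference, hypothesis):
    return 100 * sentence_bleu([reference.split()], hypothesis.split(),
                               smoothing_function=SmoothingFunction().method1)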
Example 3
def get_models(args):
    ae_model = get_cuda(make_model(d_vocab=args.vocab_size,
                                   N=args.num_layers_AE,
                                   d_model=args.transformer_model_size,
                                   latent_size=args.latent_size,
                                   d_ff=args.transformer_ff_size))
    dis_model = get_cuda(Classifier(latent_size=args.latent_size, output_size=args.label_size))
    ae_model.load_state_dict(torch.load(args.current_save_path + 'ae_model_params.pkl', map_location=torch.device('cpu')))
    dis_model.load_state_dict(torch.load(args.current_save_path + 'dis_model_params.pkl', map_location=torch.device('cpu')))
    return ae_model, dis_model
def plot_tsne(ae_model, dis_model):
    epsilon = 2
    step = 1
    eval_data_loader = non_pair_data_loader(
        batch_size=1,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    eval_file_list = [
        args.data_path + 'sentiment.test.0',
        args.data_path + 'sentiment.test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]
    eval_data_loader.create_batches(eval_file_list,
                                    eval_label_list,
                                    if_shuffle=False)
    gold_ans = load_human_answer(args.data_path)
    assert len(gold_ans) == eval_data_loader.num_batch

    ae_model.eval()
    dis_model.eval()
    latents, labels = [], []
    for it in range(eval_data_loader.num_batch):
        batch_sentences, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()
        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
        print("origin_labels", tensor_labels.item())

        latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                       tensor_tgt_mask)

        # Define target label
        target = get_cuda(torch.tensor([[1.0]], dtype=torch.float))
        if tensor_labels[0].item() > 0.5:
            target = get_cuda(torch.tensor([[0.0]], dtype=torch.float))

        modified_latent, modified_text = fgim_step(
            dis_model, latent, target, ae_model, args.max_sequence_length,
            args.id_bos, id2text_sentence, args.id_to_word, gold_ans[it],
            epsilon, step)
        latents.append(modified_latent)
        labels.append(tensor_labels.item())

    latents = torch.cat(latents, dim=0).detach().cpu().numpy()
    labels = numpy.array(labels)

    tsne_plot_representation(latents, labels)
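
tsne_plot_representation is another external helper. A minimal sketch with scikit-learn and matplotlib, matching the two call signatures used in these examples, could look as follows; the body is an assumption.

# Hypothetical stand-in for the tsne_plot_representation helper.
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def tsne_plot_representation(latents, labels, name="tsne"):
    # latents: (n_samples, latent_size) array; labels: (n_samples,)
    points = TSNE(n_components=2, random_state=0).fit_transform(latents)
    plt.figure(figsize=(6, 6))
    for lab in set(labels.tolist()):
        mask = labels == lab
        plt.scatter(points[mask, 0], points[mask, 1], s=4, label=str(lab))
    plt.legend()
    plt.savefig(name + ".png", bbox_inches="tight")
    plt.close()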
    def getWeight(self, emb, style=None):
        emb_norm = torch.norm(emb, dim=-1)  # (batch, seq)

        if style is not None:
            style = style.unsqueeze(2)
            style = self.style_embed(style.long())
            style = style.reshape(style.size(0), style.size(1), style.size(-1))
        else:
            style = get_cuda(torch.tensor([[0], [1]]).long(), self.gpu)
            style = torch.cat(emb.size(0) * [style])  # (batch * 2, 1)
            style = style.reshape(emb.size(0), -1, 1)
            style = self.style_embed(style)  # (batch, style_num, 1, dim)
            style = style.squeeze()

        style_norm = torch.norm(style, dim=-1)  # (batch, style_num)

        dot = torch.bmm(emb, style.transpose(1, 2))  # (batch, seq, style_num)

        for i in range(dot.shape[-1]):  # class num
            # normalize by the sequence and style norms
            dot[:, :, i] = dot[:, :, i] / (emb_norm * style_norm[0, i].item())
        dot = dot.transpose(1, 2)  # (batch, style_num, seq)

        u = F.relu(self.conv(dot))  # (batch, channel num, seq) = (batch, dim, seq)
        pooling = self.pool(u.transpose(1, 2)).squeeze()  # (batch, seq, dim)
        b_score = self.softmax(pooling)  # (batch, seq, dim)
        return b_score.squeeze()
Example 6
    def greedy_decode(self, latent, max_len, start_id):
        '''
        latent: (batch_size, max_src_seq, d_model)
        Returns generated token ids of shape (batch_size, max_len - 1).
        '''
        batch_size = latent.size(0)

        ys = get_cuda(torch.ones(batch_size,
                                 1).fill_(start_id).long())  # (batch_size, 1)
        for i in range(max_len - 1):
            out = self.decode(latent.unsqueeze(1), to_var(ys),
                              to_var(subsequent_mask(ys.size(1)).long()))
            prob = self.generator(out[:, -1])  # (batch_size, vocab_size)
            _, next_word = torch.max(prob, dim=1)  # (batch_size,)
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        return ys[:, 1:]
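
greedy_decode relies on a subsequent_mask helper that is not shown in these snippets; it is assumed to follow the Annotated Transformer convention:

# Assumed definition of subsequent_mask (Annotated Transformer style):
# position i may attend only to positions <= i.
import torch

def subsequent_mask(size):
    upper = torch.triu(torch.ones(1, size, size, dtype=torch.uint8), diagonal=1)
    return upper == 0  # (1, size, size) lower-triangular boolean mask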
def plot_tsne(ae_model, dis_model, epsilon=2, step=0):
    eval_data_loader = non_pair_data_loader(
        batch_size=500,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    eval_file_list = [
        args.data_path + 'sentiment.test.0',
        args.data_path + 'sentiment.test.1',
    ]
    eval_label_list = [
        [0],
        [1],
    ]
    eval_data_loader.create_batches(eval_file_list,
                                    eval_label_list,
                                    if_shuffle=False)
    gold_ans = load_human_answer(args.data_path)

    ae_model.eval()
    dis_model.eval()
    latents, labels = [], []
    it = 0
    for _ in range(eval_data_loader.num_batch):
        batch_sentences, tensor_labels, \
        tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
        tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()
        print("------------%d------------" % it)
        print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
        print("origin_labels", tensor_labels[0].item())

        latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask,
                                       tensor_tgt_mask)

        # Define target label
        target = get_cuda(
            torch.ones((tensor_labels.size(0), 1), dtype=torch.float))
        target = target - tensor_labels

        if step > 0:
            latent, modified_text = fgim_step(dis_model, latent, target,
                                              ae_model,
                                              args.max_sequence_length,
                                              args.id_bos, id2text_sentence,
                                              args.id_to_word, gold_ans[it],
                                              epsilon, step)

        latents.append(latent)
        labels.append(tensor_labels)

        it += tensor_labels.size(0)

    latents = torch.cat(latents, dim=0).detach().cpu().numpy()
    labels = torch.cat(labels, dim=0).squeeze().detach().cpu().numpy()

    tsne_plot_representation(latents, labels, f"tsne_step{step}_eps{epsilon}")
    def getSim(self, latent, style=None):
        latent_clone = get_cuda(latent.clone(), self.gpu)
        if style is not None:
            style = style.unsqueeze(2)
            style = self.style_embed(style.long())
            style = style.reshape(style.size(0), style.size(1), style.size(-1))
        else:
            style = get_cuda(torch.tensor([[0], [1]]).long(), self.gpu)
            style = torch.cat(latent.size(0) * [style])  # (batch * 2, 1)
            style = style.reshape(latent_clone.size(0), -1, 1)
            style = self.style_embed(style)  # (batch, style_num, 1, dim)
            style = style.reshape(style.size(0), style.size(1), -1)

        dot = torch.bmm(style, latent_clone.unsqueeze(2))  #batch, style_num, 1
        dot = dot.reshape(dot.size(0), dot.size(1))
        return style, dot
    def infer_encode(self, src, src_mask, style):
        style_mod = 1 - style  # flip the label for style transfer; (batch, 1)

        emb = self.src_embed(src)  # (batch, seq, dim)
        score = self.getWeight(get_cuda(emb.clone().detach(), self.gpu),
                               style_mod)  # (batch, seq)
        weighted = score * emb
        return self.encoder(weighted, src_mask)
    def decode(self, memory, tgt, tgt_mask):
        # memory: (batch_size, 1, d_model)=latent
        src_mask = get_cuda(torch.ones(memory.size(0), 1, 1).long(), self.gpu)

        return self.decoder(
            self.tgt_embed(tgt),
            memory,
            src_mask,
            tgt_mask,
        )
Example 11
def predict(args, ae_model, dis_model, batch, epsilon):
    (batch_sentences, 
     tensor_labels, 
     tensor_src,
     tensor_src_mask,
     tensor_tgt,
     tensor_tgt_y,
     tensor_tgt_mask,
     tensor_ntokens) = batch
    
    ae_model.eval()
    dis_model.eval()
    latent, out = ae_model.forward(tensor_src, tensor_tgt, tensor_src_mask, tensor_tgt_mask)
    generator_text = ae_model.greedy_decode(latent,
                                            max_len=args.max_sequence_length,
                                            start_id=args.id_bos)
    print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
    print(id2text_sentence(generator_text[0], args.id_to_word))
    target = get_cuda(torch.tensor([[1.0]], dtype=torch.float))
    if tensor_labels[0].item() > 0.5:
        target = get_cuda(torch.tensor([[0.0]], dtype=torch.float))

    dis_criterion = nn.BCELoss(reduction='mean')

    data = to_var(latent.clone())  # (batch_size, seq_length, latent_size)
    data.requires_grad = True
    output = dis_model.forward(data)
    loss = dis_criterion(output, target)
    dis_model.zero_grad()
    loss.backward()
    data_grad = data.grad.data
    data = data - epsilon * data_grad

    generator_id = ae_model.greedy_decode(data,
                                          max_len=args.max_sequence_length,
                                          start_id=args.id_bos)
    return id2text_sentence(generator_id[0], args.id_to_word)
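
predict applies a single FGSM-style step to the latent code. The fgim_attack called in the other examples iterates this update; a hedged sketch of such a loop is below, where the weight schedule, iteration cap, and stopping threshold are assumptions rather than the repo's exact code.

# Hypothetical iterative FGIM built from the single step in predict:
# try increasingly large step sizes and stop once the classifier output
# is close enough to the target label.
import torch
import torch.nn as nn

def fgim_attack_sketch(dis_model, latent, target, ae_model, max_len, start_id,
                       weights=(2.0, 3.0, 4.0, 5.0, 6.0), max_iters=20,
                       threshold=0.001):
    dis_criterion = nn.BCELoss(reduction='mean')
    data = latent.clone()
    for w in weights:
        data = latent.clone()
        for _ in range(max_iters):
            data = data.clone().detach().requires_grad_(True)
            loss = dis_criterion(dis_model.forward(data), target)
            dis_model.zero_grad()
            loss.backward()
            data = data - w * data.grad.data
            if torch.abs(dis_model.forward(data) - target).mean() < threshold:
                return ae_model.greedy_decode(data, max_len=max_len,
                                              start_id=start_id)
    return ae_model.greedy_decode(data, max_len=max_len, start_id=start_id)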
Example 12
def generation(ae_model, sm, test_sentence, label, args):
    for it in range(len(test_sentence)):
        ####################
        #####load data######
        ####################
        (batch_encoder_input, batch_decoder_input, batch_decoder_target,
         batch_encoder_len, batch_decoder_len) = pad_batch_sequences(
             test_sentence, args.id_bos, args.id_eos,
             args.id_unk, args.max_sequence_length, args.vocab_size)

        tensor_src = get_cuda(
            torch.tensor(batch_encoder_input, dtype=torch.long), args.gpu)
        tensor_tgt_y = get_cuda(
            torch.tensor(batch_decoder_target, dtype=torch.long), args.gpu)
        tensor_src_mask = (tensor_src != 0).unsqueeze(-2)
        tensor_labels = get_cuda(torch.tensor(label, dtype=torch.long),
                                 args.gpu)

        latent = ae_model.getLatent(tensor_src, tensor_src_mask)
        style, similarity = ae_model.getSim(latent)
        sign = 2 * (tensor_labels.long()) - 1
        t_sign = 2 * (1 - tensor_labels.long()) - 1

        trans_emb = style.clone()[torch.arange(style.size(0)),
                                  (1 - tensor_labels).long().item()]
        own_emb = style.clone()[torch.arange(style.size(0)),
                                tensor_labels.long().item()]
        #batch, dim = 1,256
        w = args.weight
        out_1 = ae_model.beam_decode(latent + sign * w * (trans_emb + own_emb),
                                     args.beam_size, args.max_sequence_length,
                                     args.id_bos)

        print("-------------------------------")
        print('original:', sm.DecodeIds(tensor_tgt_y.tolist()[0]))
        print('transferred:', piece2text(out_1[1].tolist(), sm))
        print("-------------------------------")
Example 13
def main(args):
    preparation(args)

    ae_model = get_cuda(
        make_model(d_vocab=args.vocab_size,
                   N=args.num_layers_AE,
                   d_model=args.transformer_model_size,
                   latent_size=args.latent_size,
                   d_ff=args.transformer_ff_size,
                   h=args.n_heads,
                   dropout=args.attention_dropout), args)
    dis_model = get_cuda(
        Classifier(latent_size=args.latent_size, output_size=args.label_size),
        args)

    if args.task == "debias":
        load_db_from_ae_model = False
        if load_db_from_ae_model:
            deb_model = copy.deepcopy(ae_model.encoder)
        else:
            deb_model = get_cuda(
                make_deb(N=args.num_layers_AE,
                         d_model=args.transformer_model_size,
                         d_ff=args.transformer_ff_size,
                         h=args.n_heads,
                         dropout=args.attention_dropout), args)

    if os.path.exists(args.load_from_checkpoint):
        # Load models' params from checkpoint
        add_log(
            args, "Load pretrained weigths, pretrain : ae, dis %s ..." %
            args.load_from_checkpoint)
        try:
            ae_model.load_state_dict(
                torch.load(
                    os.path.join(args.load_from_checkpoint,
                                 'ae_model_params.pkl')))
            dis_model.load_state_dict(
                torch.load(
                    os.path.join(args.load_from_checkpoint,
                                 'dis_model_params.pkl')))
        except FileNotFoundError:
            assert args.task == "debias"
        f1 = os.path.join(args.load_from_checkpoint, 'ae_model_params_deb.pkl')
        f2 = os.path.join(args.load_from_checkpoint,
                          'dis_model_params_deb.pkl')
        if os.path.exists(f1):
            add_log(
                args, "Load pretrained weigths, debias : ae, dis %s ..." %
                args.current_save_path)
            ae_model.load_state_dict(torch.load(f1))
            dis_model.load_state_dict(torch.load(f2))
        f3 = os.path.join(args.load_from_checkpoint,
                          'deb_model_params_deb.pkl')
        if args.task == "debias" and os.path.exists(f1):
            add_log(
                args, "Load pretrained weigths, debias : deb %s ..." %
                args.current_save_path)
            deb_model.load_state_dict(torch.load(f3))

    if not args.eval_only:
        if args.task == "pretrain":
            stats, s_test = pretrain(args, ae_model, dis_model)
        if args.task == "debias":
            sedat_train(args, ae_model, f=dis_model, deb=deb_model)

    if os.path.exists(args.test_data_file):
        if args.task == "pretrain":
            fgim_algorithm(args, ae_model, dis_model)
        if args.task == "debias":
            sedat_eval(args, ae_model, f=dis_model, deb=deb_model)

    print("Done!")
Example 14
def fgim_algorithm(args, ae_model, dis_model):
    batch_size = 1
    test_data_loader = non_pair_data_loader(
        batch_size=batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.test_data_file]
    test_data_loader.create_batches(args,
                                    file_list,
                                    if_shuffle=False,
                                    n_samples=args.test_n_samples)
    if args.references_files:
        gold_ans = load_human_answer(args.references_files, args.text_column)
        assert len(gold_ans) == test_data_loader.num_batch
    else:
        gold_ans = [[None] * batch_size] * test_data_loader.num_batch

    add_log(args, "Start eval process.")
    ae_model.eval()
    dis_model.eval()

    fgim_our = True
    if fgim_our:
        # for FGIM
        z_prime, text_z_prime = fgim(test_data_loader,
                                     args,
                                     ae_model,
                                     dis_model,
                                     gold_ans=gold_ans)
        write_text_z_in_file(args, text_z_prime)
        add_log(
            args,
            "Saving model modify embedding %s ..." % args.current_save_path)
        torch.save(z_prime,
                   os.path.join(args.current_save_path, 'z_prime_fgim.pkl'))
    else:
        for it in range(test_data_loader.num_batch):
            batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = test_data_loader.next_batch()

            print("------------%d------------" % it)
            print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
            print("origin_labels", tensor_labels)

            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask,
                                           tensor_src_attn_mask,
                                           tensor_tgt_mask)
            generator_text = ae_model.greedy_decode(
                latent, max_len=args.max_sequence_length, start_id=args.id_bos)
            print(id2text_sentence(generator_text[0], args.id_to_word))

            # Define target label
            target = get_cuda(torch.tensor([[1.0]], dtype=torch.float), args)
            if tensor_labels[0].item() > 0.5:
                target = get_cuda(torch.tensor([[0.0]], dtype=torch.float),
                                  args)
            add_log(args, "target_labels : %s" % target)

            modify_text = fgim_attack(dis_model, latent, target, ae_model,
                                      args.max_sequence_length, args.id_bos,
                                      id2text_sentence, args.id_to_word,
                                      gold_ans[it])

            add_output(args, modify_text)
Example 15
    def decode(self, memory, tgt, tgt_mask):
        # memory: (batch_size, 1, d_model)
        src_mask = get_cuda(torch.ones(memory.size(0), 1, 1).long())
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)


if __name__ == '__main__':
    preparation()

    ae_model = get_cuda(
        make_model(
            d_vocab=args.vocab_size,
            N=args.num_layers_AE,
            d_model=args.transformer_model_size,
            latent_size=args.latent_size,
            d_ff=args.transformer_ff_size,
        ))
    dis_model = get_cuda(
        Classifier(latent_size=args.latent_size, output_size=args.label_size))

    if args.if_load_from_checkpoint:
        # Load models' params from checkpoint
        ae_model.load_state_dict(
            torch.load(args.current_save_path + 'ae_model_params.pkl'))
        dis_model.load_state_dict(
            torch.load(args.current_save_path + 'dis_model_params.pkl'))
    else:
        train_iters(ae_model, dis_model)
def train_iters(ae_model, dis_model):
    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    train_data_loader.create_batches(args.train_file_list,
                                     args.train_label_list,
                                     if_shuffle=True)
    add_log("Start train process.")
    ae_model.train()
    dis_model.train()

    ae_optimizer = NoamOpt(
        ae_model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(ae_model.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    dis_optimizer = torch.optim.Adam(dis_model.parameters(), lr=0.0001)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1))
    dis_criterion = nn.BCELoss(reduction='mean')

    for epoch in range(200):
        print('-' * 94)
        epoch_start_time = time.time()
        for it in range(train_data_loader.num_batch):
            batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = train_data_loader.next_batch()

            # Forward pass
            latent, out = ae_model.forward(tensor_src, tensor_tgt,
                                           tensor_src_mask, tensor_tgt_mask)
            # print(latent.size())  # (batch_size, max_src_seq, d_model)
            # print(out.size())  # (batch_size, max_tgt_seq, vocab_size)

            # Loss calculation
            loss_rec = ae_criterion(
                out.contiguous().view(-1, out.size(-1)),
                tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

            ae_optimizer.optimizer.zero_grad()
            loss_rec.backward()
            ae_optimizer.step()

            # Classifier
            dis_out = dis_model.forward(to_var(latent.clone()))

            loss_dis = dis_criterion(dis_out, tensor_labels)

            dis_optimizer.zero_grad()
            loss_dis.backward()
            dis_optimizer.step()

            if it % 200 == 0:
                add_log(
                    '| epoch {:3d} | {:5d}/{:5d} batches | rec loss {:5.4f} | dis loss {:5.4f} |'
                    .format(epoch, it, train_data_loader.num_batch,
                            loss_rec.item(), loss_dis.item()))

                print(id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                generator_text = ae_model.greedy_decode(
                    latent,
                    max_len=args.max_sequence_length,
                    start_id=args.id_bos)
                print(id2text_sentence(generator_text[0], args.id_to_word))

        add_log('| end of epoch {:3d} | time: {:5.2f}s |'.format(
            epoch, (time.time() - epoch_start_time)))
        # Save model
        torch.save(ae_model.state_dict(),
                   args.current_save_path + 'ae_model_params.pkl')
        torch.save(dis_model.state_dict(),
                   args.current_save_path + 'dis_model_params.pkl')
    return
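
NoamOpt wraps Adam with the warmup learning-rate schedule of the Annotated Transformer, lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5). A reference sketch, assumed to match that implementation:

# Assumed NoamOpt, per the Annotated Transformer.
class NoamOpt:
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self.model_size = model_size
        self.factor = factor
        self.warmup = warmup
        self._step = 0

    def rate(self, step=None):
        # lr = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
        step = self._step if step is None else step
        return self.factor * (self.model_size ** (-0.5) *
                              min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def step(self):
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self.optimizer.step()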
    def beam_decode(self, latent, beam_size, max_len, start_id):
        '''
        latent: (batch_size, max_src_seq, d_model)
        Returns beam_size candidate sequences sorted by accumulated score.
        '''
        memory_beam = latent.detach().repeat(beam_size, 1, 1)
        beam = Beam(beam_size=beam_size,
                    min_length=0,
                    n_top=beam_size,
                    ranker=None)
        batch_size = latent.size(0)
        candidate = get_cuda(torch.zeros(beam_size, batch_size, max_len),
                             self.gpu)
        global_scores = get_cuda(torch.zeros(beam_size), self.gpu)

        tmp_cand = get_cuda(torch.zeros(beam_size * beam_size), self.gpu)
        tmp_scores = get_cuda(torch.zeros(beam_size * beam_size), self.gpu)

        ys = get_cuda(
            torch.ones(batch_size, 1).fill_(start_id).long(),
            self.gpu)  # (batch_size, 1)
        candidate[:, :, 0] = ys.clone()
        #first
        out = self.decode(latent.unsqueeze(1), to_var(ys, self.gpu),
                          to_var(subsequent_mask(ys.size(1)).long(), self.gpu))
        prob = self.generator(out[:, -1])
        scores, ids = prob.topk(k=beam_size, dim=1)  # shape: (1, beam_size)
        global_scores = scores.view(-1)
        candidate[:, :, 1] = ids.transpose(0, 1)
        for i in range(1, max_len - 1):
            for j in range(beam_size):
                tmp = candidate[j, :, :i + 1].view(1, -1)
                tp, tc = self.recursive_beam(
                    beam_size, latent.unsqueeze(1),
                    to_var(tmp.long(), self.gpu),
                    to_var(subsequent_mask(tmp.size(1)).long(), self.gpu))
                tmp_cand[beam_size * j:beam_size * (j + 1)] = tc.view(-1)
                tmp_scores[beam_size * j:beam_size *
                           (j + 1)] = tp.view(-1) + global_scores[j]
            beam_head_scores, beam_head_ids = tmp_scores.topk(k=beam_size,
                                                              dim=0)
            global_scores = beam_head_scores
            can_list = []
            for bb in range(beam_size):
                can_list.append(
                    torch.cat([
                        candidate[beam_head_ids[bb].item() //
                                  beam_size, :, :i + 1].long(),
                        tmp_cand[beam_head_ids[bb]].long().unsqueeze(
                            0).unsqueeze(0)
                    ],
                              dim=1))
            for bb in range(beam_size):
                candidate[bb, :, :i + 2] = can_list[bb]
        top_s, top_i = global_scores.sort()
        candidate = candidate.view(beam_size, -1)
        candidate = candidate[:, 1:]
        sorted_candidate = candidate.clone()
        for bb in range(beam_size):
            sorted_candidate[bb] = candidate[top_i[bb]]
        return sorted_candidate.long().view(beam_size, -1)
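
A usage sketch for beam_decode, mirroring the positional call in generation above; the shapes are inferred from the code and should be treated as assumptions.

# Illustrative call; latent: (1, d_model), as produced by getLatent above.
candidates = ae_model.beam_decode(latent, beam_size=3,
                                  max_len=args.max_sequence_length,
                                  start_id=args.id_bos)
# candidates: (beam_size, max_len - 1) token ids, ordered by the accumulated
# scores sorted at the end of beam_decode.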
Example 19
if __name__ == '__main__':
    preparation()

    train_data_loader = non_pair_data_loader(args.batch_size)
    train_data_loader.create_batches(args.train_file_list,
                                     args.train_label_list,
                                     if_shuffle=True)

    # create models
    ae_model = get_cuda(
        EncoderDecoder(
            vocab_size=args.vocab_size,
            embedding_size=args.embedding_size,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers_AE,
            word_dropout=args.word_dropout,
            embedding_dropout=args.embedding_dropout,
            sos_idx=args.id_bos,
            eos_idx=args.id_eos,
            pad_idx=args.id_pad,
            unk_idx=args.id_unk,
            max_sequence_length=args.max_sequence_length,
            rnn_type=args.rnn_type,
            bidirectional=True,
        ))

    train_iters(ae_model, train_data_loader)

    print("Done!")
Example 20
def sedat_train(args, ae_model, f, deb):
    """
    Input: 
        Original latent representation z : (n_batch, batch_size, seq_length, latent_size)
    Output: 
        An optimal modified latent representation z'
    """
    # TODO: find a metric to control the evolution of training, mainly for the deb model
    lambda_ = args.sedat_threshold
    alpha, beta = [float(coef) for coef in args.sedat_alpha_beta.split(",")]
    # only on negative example
    only_on_negative_example = args.sedat_only_on_negative_example
    penalty = args.penalty
    type_penalty = args.type_penalty

    assert penalty in ["lasso", "ridge"]
    assert type_penalty in ["last", "group"]

    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    file_list = [args.train_data_file]
    if os.path.exists(args.val_data_file):
        file_list.append(args.val_data_file)
    train_data_loader.create_batches(args,
                                     file_list,
                                     if_shuffle=True,
                                     n_samples=args.train_n_samples)

    add_log(args, "Start train process.")

    #add_log("Start train process.")
    ae_model.train()
    f.train()
    deb.train()

    ae_optimizer = get_optimizer(parameters=ae_model.parameters(),
                                 s=args.ae_optimizer,
                                 noamopt=args.ae_noamopt)
    dis_optimizer = get_optimizer(parameters=f.parameters(),
                                  s=args.dis_optimizer)
    deb_optimizer = get_optimizer(parameters=deb.parameters(),
                                  s=args.dis_optimizer)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1), args)
    dis_criterion = nn.BCELoss(reduction='mean')
    deb_criterion = LossSedat(penalty=penalty)

    stats = []
    for epoch in range(args.max_epochs):
        print('-' * 94)
        epoch_start_time = time.time()

        loss_ae, n_words_ae, xe_loss_ae, n_valid_ae = 0, 0, 0, 0
        loss_clf, total_clf, n_valid_clf = 0, 0, 0
        for it in range(train_data_loader.num_batch):
            _, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_src_attn_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, _ = train_data_loader.next_batch()
            flag = True
            # only on negative example
            if only_on_negative_example:
                negative_examples = ~(tensor_labels.squeeze()
                                      == args.positive_label)
                tensor_labels = tensor_labels[negative_examples].squeeze(
                    0)  # .view(1, -1)
                tensor_src = tensor_src[negative_examples].squeeze(0)
                tensor_src_mask = tensor_src_mask[negative_examples].squeeze(0)
                tensor_src_attn_mask = tensor_src_attn_mask[
                    negative_examples].squeeze(0)
                tensor_tgt_y = tensor_tgt_y[negative_examples].squeeze(0)
                tensor_tgt = tensor_tgt[negative_examples].squeeze(0)
                tensor_tgt_mask = tensor_tgt_mask[negative_examples].squeeze(0)
                flag = negative_examples.any()
            if flag:
                # forward
                z, out, z_list = ae_model.forward(tensor_src,
                                                  tensor_tgt,
                                                  tensor_src_mask,
                                                  tensor_src_attn_mask,
                                                  tensor_tgt_mask,
                                                  return_intermediate=True)
                #y_hat = f.forward(to_var(z.clone()))
                y_hat = f.forward(z)

                loss_dis = dis_criterion(y_hat, tensor_labels)
                dis_optimizer.zero_grad()
                loss_dis.backward(retain_graph=True)
                dis_optimizer.step()

                dis_out = f.forward(z)
                t_c = tensor_labels.view(-1).size(0)
                n_v = (dis_out.round().int() == tensor_labels).sum().item()
                loss_clf += loss_dis.item()
                total_clf += t_c
                n_valid_clf += n_v
                clf_acc = 100. * n_v / (t_c + eps)
                avg_clf_acc = 100. * n_valid_clf / (total_clf + eps)
                avg_clf_loss = loss_clf / (it + 1)

                mask_deb = (y_hat.squeeze() >= lambda_
                            if args.positive_label == 0
                            else y_hat.squeeze() < lambda_)
                # if f(z) > lambda :
                if mask_deb.any():
                    y_hat_deb = y_hat[mask_deb]
                    if type_penalty == "last":
                        z_deb = z[mask_deb].squeeze(
                            0) if args.batch_size == 1 else z[mask_deb]
                    elif type_penalty == "group":
                        # TODO: unit test for batch_size = 1
                        z_deb = z_list[-1][mask_deb]
                    z_prime, z_prime_list = deb(z_deb,
                                                mask=None,
                                                return_intermediate=True)
                    if type_penalty == "last":
                        z_prime = torch.sum(ae_model.sigmoid(z_prime), dim=1)
                        loss_deb = alpha * deb_criterion(
                            z_deb, z_prime,
                            is_list=False) + beta * y_hat_deb.sum()
                    elif type_penalty == "group":
                        z_deb_list = [z_[mask_deb] for z_ in z_list]
                        #assert len(z_deb_list) == len(z_prime_list)
                        loss_deb = alpha * deb_criterion(
                            z_deb_list, z_prime_list,
                            is_list=True) + beta * y_hat_deb.sum()

                    deb_optimizer.zero_grad()
                    loss_deb.backward(retain_graph=True)
                    deb_optimizer.step()
                else:
                    loss_deb = torch.tensor(float("nan"))

                # else :
                if (~mask_deb).any():
                    out_ = out[~mask_deb]
                    tensor_tgt_y_ = tensor_tgt_y[~mask_deb]
                    tensor_ntokens = (tensor_tgt_y_ != 0).data.sum().float()
                    loss_rec = ae_criterion(
                        out_.contiguous().view(-1, out_.size(-1)),
                        tensor_tgt_y_.contiguous().view(-1)) / (
                            tensor_ntokens.data + eps)
                else:
                    loss_rec = torch.tensor(float("nan"))

                ae_optimizer.zero_grad()
                (loss_dis + loss_deb + loss_rec).backward()
                ae_optimizer.step()

                n_v, n_w = get_n_v_w(tensor_tgt_y, out)

                x_e = loss_rec.item() * n_w
                loss_ae += loss_rec.item()
                n_words_ae += n_w
                xe_loss_ae += x_e
                n_valid_ae += n_v
                ae_acc = 100. * n_v / (n_w + eps)
                avg_ae_acc = 100. * n_valid_ae / (n_words_ae + eps)
                avg_ae_loss = loss_ae / (it + 1)
                ae_ppl = np.exp(x_e / (n_w + eps))
                avg_ae_ppl = np.exp(xe_loss_ae / (n_words_ae + eps))

                if it % args.log_interval == 0:
                    add_log(args, "")
                    add_log(
                        args, 'epoch {:3d} | {:5d}/{:5d} batches |'.format(
                            epoch, it, train_data_loader.num_batch))
                    add_log(
                        args,
                        'Train : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                        .format(ae_acc, loss_rec.item(), ae_ppl, clf_acc,
                                loss_dis.item()))
                    add_log(
                        args,
                        'Train : avg : rec acc {:5.4f} | rec loss {:5.4f} | ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
                        .format(avg_ae_acc, avg_ae_loss, avg_ae_ppl,
                                avg_clf_acc, avg_clf_loss))

                    add_log(
                        args, "input : %s" %
                        id2text_sentence(tensor_tgt_y[0], args.id_to_word))
                    generator_text = ae_model.greedy_decode(
                        z,
                        max_len=args.max_sequence_length,
                        start_id=args.id_bos)
                    # batch_sentences
                    add_log(
                        args, "gen : %s" %
                        id2text_sentence(generator_text[0], args.id_to_word))
                    if mask_deb.any():
                        generator_text_prime = ae_model.greedy_decode(
                            z_prime,
                            max_len=args.max_sequence_length,
                            start_id=args.id_bos)

                        add_log(
                            args, "deb : %s" % id2text_sentence(
                                generator_text_prime[0], args.id_to_word))

        s = {}
        L = train_data_loader.num_batch + eps
        s["train_ae_loss"] = loss_ae / L
        s["train_ae_acc"] = 100. * n_valid_ae / (n_words_ae + eps)
        s["train_ae_ppl"] = np.exp(xe_loss_ae / (n_words_ae + eps))
        s["train_clf_loss"] = loss_clf / L
        s["train_clf_acc"] = 100. * n_valid_clf / (total_clf + eps)
        stats.append(s)

        add_log(args, "")
        add_log(
            args, '| end of epoch {:3d} | time: {:5.2f}s |'.format(
                epoch, (time.time() - epoch_start_time)))

        add_log(
            args,
            '| rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
            .format(s["train_ae_acc"], s["train_ae_loss"], s["train_ae_ppl"],
                    s["train_clf_acc"], s["train_clf_loss"]))

        # Save model
        torch.save(
            ae_model.state_dict(),
            os.path.join(args.current_save_path, 'ae_model_params_deb.pkl'))
        torch.save(
            f.state_dict(),
            os.path.join(args.current_save_path, 'dis_model_params_deb.pkl'))
        torch.save(
            deb.state_dict(),
            os.path.join(args.current_save_path, 'deb_model_params_deb.pkl'))

    add_log(args, "Saving training statistics %s ..." % args.current_save_path)
    torch.save(stats,
               os.path.join(args.current_save_path, 'stats_train_deb.pkl'))
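
LossSedat is defined elsewhere; below is a hedged sketch consistent with how it is called above, with penalty in {"lasso", "ridge"} and is_list distinguishing a single latent pair from a list of per-layer pairs. The exact form is an assumption.

# Hypothetical sketch of LossSedat: an L1 ("lasso") or L2 ("ridge") distance
# between the original and debiased latents, optionally summed over a list
# of intermediate representations.
import torch

class LossSedat:
    def __init__(self, penalty="lasso"):
        assert penalty in ["lasso", "ridge"]
        self.p = 1 if penalty == "lasso" else 2

    def __call__(self, z, z_prime, is_list=False):
        if is_list:
            return sum(torch.norm(a - b, p=self.p)
                       for a, b in zip(z, z_prime))
        return torch.norm(z - z_prime, p=self.p)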
Example 21
    def forward(self, input_sequence, seq_lengths):
        # input_sequence:(batch_size, xx)
        # seq_lengths: (batch_size)
        batch_size = input_sequence.size(0)
        sorted_lengths, sorted_idx = torch.sort(seq_lengths, descending=True)
        input_sequence = input_sequence[sorted_idx]

        # Encoder
        input_embedding = self.embedding(input_sequence)  # (batch_size, length, embedding_size)

        packed_input = rnn_utils.pack_padded_sequence(input_embedding,
                                                      sorted_lengths.data.tolist(), batch_first=True)
        _, hidden = self.encoder_rnn(packed_input)

        if self.bidirectional or self.num_layers > 1:
            # flatten hidden state
            latent = hidden.view(batch_size, self.hidden_size * self.hidden_factor)
        else:
            latent = hidden.squeeze()

        # z = self.hidden2latent(hidden)
        # hidden = self.latent2hidden(z)

        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state
            hidden = latent.view(self.hidden_factor, batch_size, self.hidden_size)
        else:
            hidden = latent.unsqueeze(0)

        # decoder input
        if self.word_dropout_rate > 0:
            # randomly replace decoder input with <unk>
            dropout_prob_mask = get_cuda(torch.rand(input_sequence.size()))
            # Don't replace the place that has sos or pad.
            dropout_prob_mask[(input_sequence.data - self.sos_idx) * (input_sequence.data - self.pad_idx) == 0] = 1
            decoder_input_sequence = input_sequence.clone()
            decoder_input_sequence[dropout_prob_mask < self.word_dropout_rate] = self.unk_idx
            input_embedding = self.embedding(decoder_input_sequence)
        input_embedding = self.embedding_dropout(input_embedding)
        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        # decoder forward pass
        outputs, _ = self.decoder_rnn(packed_input, hidden)

        # process outputs
        padded_outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
        padded_outputs = padded_outputs.contiguous()
        _, reversed_idx = torch.sort(sorted_idx)
        padded_outputs = padded_outputs[reversed_idx]
        b, s, _ = padded_outputs.size()

        # project outputs to vocab
        logp = nn.functional.log_softmax(self.outputs2vocab(padded_outputs.view(-1, padded_outputs.size(2))), dim=-1)
        logp = logp.view(-1, self.vocab_size)  # [b*len, vocab_size]

        # Restore original order
        latent = latent[reversed_idx]

        return logp, latent
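
The word-dropout trick in the decoder input above, isolated as a runnable snippet; the indices and dropout rate are illustrative.

# Standalone illustration of word dropout as used above: tokens are replaced
# by <unk> with probability rate, except at <sos> and <pad> positions.
import torch

sos_idx, pad_idx, unk_idx, rate = 1, 0, 3, 0.5
input_sequence = torch.tensor([[1, 7, 8, 9, 0, 0]])  # one padded sentence
mask = torch.rand(input_sequence.size())
mask[(input_sequence - sos_idx) * (input_sequence - pad_idx) == 0] = 1
dropped = input_sequence.clone()
dropped[mask < rate] = unk_idx
# the <sos> at position 0 and the <pad> at positions 4-5 never change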
def val(ae_model, dis_model, eval_data_loader, epoch, args):

    print("Transformer Validation process....")
    ae_model.eval()

    print('-' * 94)
    epoch_start = time.time()

    loss_ae = list()
    loss_dis = list()

    acc = list()

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1), args.gpu)
    dis_criterion = torch.nn.BCELoss(reduction='mean')
    for it in range(eval_data_loader.num_batch):
        ####################
        #####load data######
        ####################
        batch_sentences, tensor_labels, \
            tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
            tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()

        latent = ae_model.getLatent(tensor_src, tensor_src_mask)  #(128, 256)
        style, similarity = ae_model.getSim(
            latent)  #style (128, 2, 256), sim(128, 2)
        dis_out = dis_model.forward(similarity)
        one = get_cuda(torch.tensor(1), args.gpu)
        zero = get_cuda(torch.tensor(0), args.gpu)
        style_pred = torch.where(dis_out > 0.5, one, zero)
        style_pred = style_pred.reshape(style_pred.size(0))
        style_emb = get_cuda(style.clone()[torch.arange(style.size(0)),
                                           tensor_labels.squeeze().long()],
                             args.gpu)  #(128, 256)

        add_latent = latent + style_emb  #batch, dim
        out = ae_model.getOutput(add_latent, tensor_tgt, tensor_tgt_mask)
        loss_rec = ae_criterion(
            out.contiguous().view(-1, out.size(-1)),
            tensor_tgt_y.contiguous().view(-1)) / tensor_ntokens.data

        loss_style = dis_criterion(dis_out, tensor_labels)

        pred = style_pred.to('cpu').detach().tolist()
        true = tensor_labels.squeeze().to('cpu').tolist()

        dis_acc = accuracy_score(true, pred)
        acc.append(dis_acc)

        loss_ae.append(loss_rec.item())
        loss_dis.append(loss_style.item())

        if it % 200 == 0:
            print(
                '| epoch {:3d} | {:5d}/{:5d} batches |\n| rec loss {:5.4f} | dis loss {:5.4f} |\n'
                .format(epoch, it, eval_data_loader.num_batch, loss_rec.item(),
                        loss_style.item()))

    print('| end of epoch {:3d} | time: {:5.2f}s |'.format(
        epoch, (time.time() - epoch_start)))

    return np.mean(loss_ae), np.mean(loss_dis), np.mean(acc)
Example 23
        print("-------------------------------")
        print('original:', sm.DecodeIds(tensor_tgt_y.tolist()[0]))
        print('transferred:', piece2text(out_1[1].tolist(), sm))
        print("-------------------------------")


if __name__ == '__main__':
    if not os.path.exists('./generation/{}'.format(args.name)):
        os.makedirs('./generation/{}'.format(args.name))

    preparation(args)

    ae_model = get_cuda(
        make_model(d_vocab=args.vocab_size,
                   N=args.num_layers_AE,
                   d_model=args.transformer_model_size,
                   latent_size=args.latent_size,
                   gpu=args.gpu,
                   d_ff=args.transformer_ff_size), args.gpu)

    ae_model.load_state_dict(torch.load('./model.pkl', map_location=device))
    sm = spm.SentencePieceProcessor()
    sm.Load(args.sm_path + '%s.model' % args.name)
    print('Type the style of the original sentence')
    print('Negative: 0, Positive: 1')
    label = int(input())
    print('Type the sentence to be transferred')
    input_sentence = str(input())
    test_sentence = text2piece(input_sentence, sm)
    generation(ae_model, sm, test_sentence, label, args)
Example 24
def pretrain(args, ae_model, dis_model):
    train_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    train_data_loader.create_batches(args, [args.train_data_file],
                                     if_shuffle=True,
                                     n_samples=args.train_n_samples)

    val_data_loader = non_pair_data_loader(
        batch_size=args.batch_size,
        id_bos=args.id_bos,
        id_eos=args.id_eos,
        id_unk=args.id_unk,
        max_sequence_length=args.max_sequence_length,
        vocab_size=args.vocab_size)
    val_data_loader.create_batches(args, [args.val_data_file],
                                   if_shuffle=True,
                                   n_samples=args.valid_n_samples)

    ae_model.train()
    dis_model.train()

    ae_optimizer = get_optimizer(parameters=ae_model.parameters(),
                                 s=args.ae_optimizer,
                                 noamopt=args.ae_noamopt)
    dis_optimizer = get_optimizer(parameters=dis_model.parameters(),
                                  s=args.dis_optimizer)

    ae_criterion = get_cuda(
        LabelSmoothing(size=args.vocab_size,
                       padding_idx=args.id_pad,
                       smoothing=0.1), args)
    dis_criterion = nn.BCELoss(reduction='mean')

    possib = [
        "%s_%s" % (i, j) for i, j in itertools.product(
            ["train", "eval"],
            ["ae_loss", "ae_acc", "ae_ppl", "clf_loss", "clf_acc"])
    ]
    stopping_criterion, best_criterion, decrease_counts, decrease_counts_max = settings(
        args, possib)
    metric, biggest = stopping_criterion
    factor = 1 if biggest else -1

    stats = []

    add_log(args, "Start train process.")
    for epoch in range(args.max_epochs):
        print('-' * 94)
        add_log(args, "")
        s_train = train_step(args, train_data_loader, ae_model, dis_model,
                             ae_optimizer, dis_optimizer, ae_criterion,
                             dis_criterion, epoch)
        add_log(args, "")
        s_eval = eval_step(args, val_data_loader, ae_model, dis_model,
                           ae_criterion, dis_criterion)
        scores = {**s_train, **s_eval}
        stats.append(scores)
        add_log(args, "")
        if factor * scores[metric] > factor * best_criterion:
            best_criterion = scores[metric]
            add_log(args, "New best validation score: %f" % best_criterion)
            decrease_counts = 0
            # Save model
            add_log(args, "Saving model to %s ..." % args.current_save_path)
            torch.save(
                ae_model.state_dict(),
                os.path.join(args.current_save_path, 'ae_model_params.pkl'))
            torch.save(
                dis_model.state_dict(),
                os.path.join(args.current_save_path, 'dis_model_params.pkl'))
        else:
            add_log(
                args, "Not a better validation score (%i / %i)." %
                (decrease_counts, decrease_counts_max))
            decrease_counts += 1
        if decrease_counts > decrease_counts_max:
            add_log(
                args,
                "Stopping criterion has been below its best value for more "
                "than %i epochs. Ending the experiment..." %
                decrease_counts_max)
            #exit()
            break

    s_test = None
    if os.path.exists(args.test_data_file):
        add_log(args, "")
        test_data_loader = non_pair_data_loader(
            batch_size=args.batch_size,
            id_bos=args.id_bos,
            id_eos=args.id_eos,
            id_unk=args.id_unk,
            max_sequence_length=args.max_sequence_length,
            vocab_size=args.vocab_size)
        test_data_loader.create_batches(args, [args.test_data_file],
                                        if_shuffle=True,
                                        n_samples=args.test_n_samples)
        s = eval_step(args, test_data_loader, ae_model, dis_model,
                      ae_criterion, dis_criterion)
        add_log(
            args,
            'Test | rec acc {:5.4f} | rec loss {:5.4f} | rec ppl {:5.4f} | dis acc {:5.4f} | dis loss {:5.4f} |'
            .format(s["eval_ae_acc"], s["eval_ae_loss"], s["eval_ae_ppl"],
                    s["eval_clf_acc"], s["eval_clf_loss"]))
        s_test = s
    add_log(args, "")
    add_log(args, "Saving training statistics %s ..." % args.current_save_path)
    torch.save(stats,
               os.path.join(args.current_save_path, 'stats_train_eval.pkl'))
    if s_test is not None:
        torch.save(s_test, os.path.join(args.current_save_path,
                                        'stat_test.pkl'))
    return stats, s_test