Example #1
File: train.py  Project: Choitsugun/HSATA
                self.lr = noam_scheme(hp.lr, self.global_step, hp.warmup_steps)
                self.optimizer = tf.train.AdamOptimizer(self.lr)
                self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)
            else:
                # inference
                self.prob_c = tf.nn.softmax(self.logits_c)  # (N, T_q, vocab_size)
                self.prob_t = tf.nn.softmax(self.logits_t)  # (N, T_q, tw_vocab_size)
                self.prob_t = tf.einsum('nlt,tv->nlv', self.prob_t, self.tw_vocab_overlap)  # (N, T_q, vocab_size)
                self.prob = self.prob_c + self.prob_t * hp.penalty # (N, T_q, vocab_size)
                self.preds = tf.to_int32(tf.argmax(self.prob, axis=-1))  # (N, T_q)



if __name__ == '__main__':
    # Load vocabulary
    token2idx, idx2token = load_de_en_vocab()
    tw2idx, idx2tw = load_tw_vocab()
    token2idx_len = len(token2idx)
    tw2idx_len = len(tw2idx)

    X, X_length, Y, YTWD, Y_DI, TW, num_batch = get_batch_data()

    # Construct graph
    g = Graph(True, token2idx_len, tw2idx_len, None)
    print("Graph loaded")

    # Start session
    sv = tf.train.Supervisor(graph=g.graph, 
                             logdir=hp.logdir,
                             save_model_secs=0)
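
The excerpt above calls noam_scheme() without defining it. As a rough sketch only (the function body is not part of this excerpt, so the definition below is an assumption), the standard Transformer warmup schedule that fits the call would be:

import tensorflow as tf  # TF 1.x, as in the excerpts

def noam_scheme(init_lr, global_step, warmup_steps=4000.):
    # Linear warmup for the first warmup_steps, then decay proportional
    # to the inverse square root of the step number.
    step = tf.cast(global_step + 1, dtype=tf.float32)
    return init_lr * warmup_steps ** 0.5 * tf.minimum(step * warmup_steps ** -1.5, step ** -0.5)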
Example #2
def eval():
    # Load vocabulary
    token2idx, idx2token = load_de_en_vocab()
    tw2idx, idx2tw = load_tw_vocab()
    token2idx_len = len(token2idx)
    tw2idx_len = len(tw2idx)

    # Load vocab_overlap
    token_idx_list = []
    con_list = np.zeros([4, token2idx_len], dtype='float32')
    for i in range(4, tw2idx_len):
        tw = idx2tw[i]
        token_idx_list.append(token2idx[tw])

    vocab_overlap = np.append(con_list, np.eye(token2idx_len, dtype='float32')[token_idx_list], axis=0)

    # Load graph
    g = Graph(False, token2idx_len, tw2idx_len, vocab_overlap)
    print("Graph loaded")
    
    # Load data
    X, X_length, Y, TW, Sources, Targets = load_test_data()
     
    # Start session         
    with g.graph.as_default():    
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            ## Restore parameters
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            print("Restored!")
              
            ## Get model name
            mname = open(hp.logdir + '/checkpoint', 'r').read().split('"')[1] # model name
            #fftmp=open("tmp.txt","w")

            ## Inference
            if not os.path.exists('results'): os.mkdir('results')

            with codecs.open("results/" + mname, "w", "utf-8") as fout:
                list_of_refs, hypotheses = [], []
                for i in range(len(X) // hp.batch_size):
                    ### Get mini-batches
                    x =              X[i * hp.batch_size: (i + 1) * hp.batch_size]
                    x_length= X_length[i * hp.batch_size: (i + 1) * hp.batch_size]
                    y =              Y[i * hp.batch_size: (i + 1) * hp.batch_size]
                    y_tw =          TW[i * hp.batch_size: (i + 1) * hp.batch_size]
                    sources =  Sources[i * hp.batch_size: (i + 1) * hp.batch_size]
                    targets =  Targets[i * hp.batch_size: (i + 1) * hp.batch_size]
                    #fftmp.write("%s\n"%(" ".join(str(w) for w in x[0][0]).encode("utf-8")))
                    #fftmp.write("%s\n"%(sources[0].encode("utf-8")))
                    #fftmp.write("%s\n"%(' '.join(str(w) for w in x_length)))
                    #print (sources)
                    #print (targets) 
                    ### Autoregressive inference
                    preds = np.zeros((hp.batch_size, hp.maxlen), np.int32)
                    ppls = np.zeros((hp.batch_size, hp.maxlen), np.float32)  # float so per-step perplexities are not truncated
                    for j in range(hp.maxlen):
                        _preds, ppl_step, att_w, att_u, att_v = sess.run([g.preds, g.ppl_step, g.att_w, g.att_u, g.att_v],
                                                    {g.x:x, g.x_length:x_length, g.y:y, g.y_tw:y_tw, g.y_decoder_input:preds})
                        preds[:, j] = _preds[:, j]
                        ppls[:, j] = ppl_step[:, j]
                     
                    ### Write to file
                    for source, target, pred, ppl in zip(sources, targets, preds, ppls): # sentence-wise
                        got = " ".join(idx2token[idx] for idx in pred).split("</S>")[0].strip()
                        ppl_score = " ".join('%s' %score for score in ppl).strip()
                        fout.write("- source: " + source +"\n")
                        fout.write("- expected: " + target + "\n")
                        fout.write("- got: " + got + "\n\n")
                        fout.write("- ppl_score: " + ppl_score + "\n\n")
                        fout.flush()
                          
                        # bleu score
                        ref = target.split()
                        hypothesis = got.split()
                        if len(ref) > 3 and len(hypothesis) > 3:
                            list_of_refs.append([ref])
                            hypotheses.append(hypothesis)

                    ## Calculate attention
                    #fout.write("- att_w: " + str(att_w) + "\n")
                    #fout.write("- att_u: " + str(att_u) + "\n")
                    #fout.write("- att_v: " + str(att_v) + "\n")
                    #fout.flush()

                ## Calculate bleu score
                score = corpus_bleu(list_of_refs, hypotheses)
                fout.write("Bleu Score = " + str(100*score))
Example #3
File: eval.py  Project: Choitsugun/HSATA
def eval():
    # Load vocabulary
    token2idx, idx2token = load_de_en_vocab()
    tw2idx, idx2tw = load_tw_vocab()
    token2idx_len = len(token2idx)
    tw2idx_len = len(tw2idx)

    # Load vocab_overlap
    token_idx_list = []
    con_list = np.zeros([4, token2idx_len], dtype='float32')
    for i in range(4, tw2idx_len):
        tw = idx2tw[i]
        token_idx_list.append(token2idx[tw])

    vocab_overlap = np.append(con_list,
                              np.eye(token2idx_len,
                                     dtype='float32')[token_idx_list],
                              axis=0)

    # Load graph
    g = Graph(False, token2idx_len, tw2idx_len, vocab_overlap)
    print("Graph loaded")

    # Load data
    X, X_length, Y, TW, Sources, Targets = load_test_data()

    # Start session
    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            ## Restore parameters
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            print("Restored!")

            ## Get model name
            mname = open(hp.logdir + '/checkpoint',
                         'r').read().split('"')[1]  # model name

            ## Inference
            if not os.path.exists('results'): os.mkdir('results')

            with codecs.open("results/" + mname, "w", "utf-8") as fout:
                args = parse_args()
                if args.a:
                    att_f = codecs.open("results/" + "attention_vis", "w",
                                        "utf-8")

                for i in range(len(X) // hp.batch_size):
                    ### Get mini-batches
                    x = X[i * hp.batch_size:(i + 1) * hp.batch_size]
                    x_length = X_length[i * hp.batch_size:(i + 1) *
                                        hp.batch_size]
                    y = Y[i * hp.batch_size:(i + 1) * hp.batch_size]
                    y_tw = TW[i * hp.batch_size:(i + 1) * hp.batch_size]
                    sources = Sources[i * hp.batch_size:(i + 1) *
                                      hp.batch_size]
                    targets = Targets[i * hp.batch_size:(i + 1) *
                                      hp.batch_size]

                    preds = np.zeros((hp.batch_size, hp.maxlen), np.int32)
                    ppls = np.zeros((hp.batch_size, hp.maxlen), np.float32)  # float so per-step perplexities are not truncated
                    for j in range(hp.maxlen):
                        _preds, ppl_step, att_ws, att_us, att_v = sess.run(
                            [g.preds, g.ppl, g.att_w, g.att_u, g.att_v], {
                                g.x: x,
                                g.x_length: x_length,
                                g.y: y,
                                g.y_tw: y_tw,
                                g.y_decoder_input: preds
                            })
                        preds[:, j] = _preds[:, j]
                        ppls[:, j] = ppl_step[:, j]

                    if args.a:
                        att_ws = np.mean(np.split(att_ws, hp.num_heads,
                                                  axis=0),
                                         axis=0)  # (N, L, L)
                        att_us = np.mean(np.split(att_us, hp.num_heads,
                                                  axis=0),
                                         axis=0)  # (N, T, T)
                        att_ws = np.reshape(
                            att_ws,
                            [hp.batch_size, hp.max_turn, hp.maxlen, hp.maxlen])
                        att_ws = np.mean(att_ws, axis=2)  # N, T, L
                        att_ws = np.reshape(
                            att_ws, [hp.batch_size, hp.max_turn * hp.maxlen])
                        att_us = np.sum(att_us, axis=1)  # N, T

                    ### Write to file
                    for source, target, pred, ppl, att_w, att_u in zip(
                            sources, targets, preds, ppls, att_ws,
                            att_us):  # sentence-wise
                        got = " ".join(
                            idx2token[idx]
                            for idx in pred).split("</S>")[0].strip()
                        if len(got.split()) > hp.gener_maxlen:
                            pred = pred.tolist()
                            pred_final = list(set(pred))
                            pred_final.sort(key=pred.index)
                        else:
                            pred_final = pred

                        got = " ".join(
                            idx2token[idx]
                            for idx in pred_final).split("</S>")[0].strip()
                        fout.write("- source: " + source + "\n")
                        fout.write("- expected: " + target + "\n")
                        fout.write("- got: " + got + "\n\n")
                        fout.write("- ppl_score: " +
                                   " ".join('%s' % np.mean(ppl)) + "\n\n")
                        if args.a:
                            att_f.write("- att_w: " + str(att_w) + "\n")
                            att_f.write("- att_u: " + str(att_u) +
                                        "\n\n\n\n\n")
                        fout.flush()
                        if args.a:
                            att_f.flush()
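
parse_args() is used above but not shown in this excerpt. A minimal parser consistent with the args.a check (the flag and its help text are assumptions, not the project's actual definition) might look like this:

import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    # -a toggles dumping the averaged attention weights to results/attention_vis
    parser.add_argument("-a", action="store_true", help="write attention weights for visualization")
    return parser.parse_args()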
Example #4
import os
import time
import codecs
import argparse
import numpy as np
from tqdm import tqdm
from modules import *
import tensorflow as tf
from hyperparams import Hyperparams as hp
from data_load import load_de_en_vocab, load_train_data, load_test_data

parser = argparse.ArgumentParser()
parser.add_argument("mode", help="train or eval")
args = parser.parse_args()

en2idx, idx2en = load_de_en_vocab('processed-data/en.vocab.tsv')
de2idx, idx2de = load_de_en_vocab('processed-data/zh.vocab.tsv')
print("读取en,zh字典")

# load train data
en_npy_path = "./processed-data/train_en.npy"
zh_npy_path = "./processed-data/train_zh.npy"
if os.path.exists(en_npy_path) and os.path.exists(zh_npy_path):
    print("load training data")
    X = np.load(en_npy_path)
    Y = np.load(zh_npy_path)
else:
    X, Y = load_train_data(de2idx, en2idx)
    np.save(en_npy_path, X)
    np.save(zh_npy_path, Y)
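
The load-or-build pattern at the end of this excerpt can be captured in a small helper. This is only an illustration of the same idea (the helper name is invented here, and load_train_data returns both arrays at once, so the lambdas below index into its result):

import os
import numpy as np

def load_or_build(npy_path, builder):
    # Reuse a cached .npy array if it exists, otherwise build it and cache it.
    if os.path.exists(npy_path):
        return np.load(npy_path)
    arr = builder()
    np.save(npy_path, arr)
    return arr

# Usage mirroring the excerpt (illustrative only):
# X = load_or_build("./processed-data/train_en.npy", lambda: load_train_data(de2idx, en2idx)[0])
# Y = load_or_build("./processed-data/train_zh.npy", lambda: load_train_data(de2idx, en2idx)[1])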