# Graph inputs for inference. Every placeholder has a fixed leading
# dimension of 1, so the model is fed a batch of exactly one query.
# TensorFlow makes duplicate op names unique ('input', 'input_1', ...).
# NOTE(review): the 20 in x_1's shape presumably equals hp.maxlen - confirm.
x_1 = tf.placeholder(tf.int32, [1, 20], name='input')
x_2 = tf.placeholder(tf.int32, (), name='input')
# NOTE(review): 'inptut' looks like a typo for 'input'; left unchanged in
# case other code looks this tensor up by name.
x_3 = tf.placeholder(tf.string, (), name='inptut')
y_1 = tf.placeholder(tf.int32, [1, 1], name='output')
y_2 = tf.placeholder(tf.int32, [1, 19], name='output')
y_3 = tf.placeholder(tf.int32, (), name='output')
y_4 = tf.placeholder(tf.string, (), name='output')
x_input = (x_1, x_2, x_3)
y_input = (y_1, y_2, y_3, y_4)

sess = tf.Session()
m = Transformer(hp)
y_hat = m.infer(x_input, y_input)
new_saver = tf.train.Saver()
# latest_checkpoint returns None when ./model holds no checkpoint; fail with
# an explicit message instead of restore()'s generic "save_path is None".
checkpoint_path = tf.train.latest_checkpoint('./model')
if checkpoint_path is None:
    raise ValueError("No checkpoint found in './model'")
new_saver.restore(sess, checkpoint_path)


def generate_input(query):
    """Encode a raw query string as a one-row batch of token ids.

    The query is segmented with ``jieba.cut``. Each word is mapped through
    ``word2idx``, with words not in ``word2idx`` mapped to id 1. The id of
    ``'<S>'`` is then appended. Sequences of length >= ``hp.maxlen`` are
    truncated to ``hp.maxlen``. Shorter ones are passed through
    ``pad(..., hp.maxlen, vocab_path)``, presumably padding them up to
    ``hp.maxlen`` (confirm against ``pad``).

    Args:
        query: the text to encode.

    Returns:
        ``[query_id]``: a single-element list wrapping the id sequence.
    """
    query_id = [word2idx.get(word, 1) for word in jieba.cut(query)]
    query_id.append(word2idx.get('<S>'))
    if len(query_id) >= hp.maxlen:
        # Truncate with the same limit the check above uses (previously a
        # hard-coded 20 that only agreed with hp.maxlen when maxlen == 20).
        # NOTE(review): truncation can drop the '<S>' id appended above.
        query_id = query_id[:hp.maxlen]
    else:
        query_id = pad(query_id, hp.maxlen, vocab_path)
    query_input = [query_id]
    # Previously computed and discarded; return it so callers can feed it.
    return query_input
from model import Transformer from hparams import Hparams import logging logging.basicConfig(level=logging.INFO) logging.info("# hparams") hparams = Hparams() parser = hparams.parser hp = parser.parse_args() # load_hparams(hp, hp.ckpt) input_tokens = tf.placeholder(tf.int32, shape=(1, None)) xs = (input_tokens, None, None) logging.info("# Load model") m = Transformer(hp) y_hat = m.infer(xs) logging.info("# Session") sess = tf.Session() # ckpt_ = tf.train.latest_checkpoint(hp.ckpt) # ckpt = hp.ckpt if ckpt_ is None else ckpt_ # None: ckpt is a file. otherwise dir. saver = tf.train.Saver() saver.restore(sess, './data/translation_model.ckpt') def divide_long(text): return [i + '。' for i in text.split('。')] def trans(text): # x = encode(text, "x", m.token2idx)