# 예제 #1 (Example #1)
# --- Graph inputs for example #1 ---
# NOTE(review): `tf`, `Transformer`, and `hp` are only defined further down in
# this file (L33-L41), so this section cannot run as ordered — verify intent.
x_1 = tf.placeholder(tf.int32, [1, 20], name='input')   # token ids, fixed length 20
x_2 = tf.placeholder(tf.int32, (), name='input')        # scalar int input
x_3 = tf.placeholder(tf.string, (), name='input')       # fixed typo: was name='inptut'

y_1 = tf.placeholder(tf.int32, [1, 1], name='output')
y_2 = tf.placeholder(tf.int32, [1, 19], name='output')
y_3 = tf.placeholder(tf.int32, (), name='output')
y_4 = tf.placeholder(tf.string, (), name='output')

x_input = (x_1, x_2, x_3)
y_input = (y_1, y_2, y_3, y_4)

# Build the inference graph and restore the most recent checkpoint in ./model.
sess = tf.Session()
m = Transformer(hp)
y_hat = m.infer(x_input, y_input)

new_saver = tf.train.Saver()
new_saver.restore(sess, tf.train.latest_checkpoint('./model'))


def generate_input(query):
    """Convert a raw query string into a padded id sequence of length hp.maxlen.

    Tokenizes *query* with jieba, maps each token to its id (unknown words
    map to id 1), appends the '<S>' end token, then truncates or pads to
    hp.maxlen, and returns the sequence wrapped in a batch of size 1.
    """
    query_id = []
    for word in jieba.cut(query):
        query_id.append(word2idx.get(word, 1))
    query_id.append(word2idx.get('<S>'))
    if len(query_id) >= hp.maxlen:
        # Fixed: truncation was hard-coded to [:20]; use hp.maxlen so the
        # truncation length always matches the padding length below.
        query_id = query_id[:hp.maxlen]
    else:
        query_id = pad(query_id, hp.maxlen, vocab_path)
    query_input = [query_id]
    # Bug fix: the original built query_input but never returned it
    # (the function silently returned None).
    return query_input
from model import Transformer
from hparams import Hparams
import logging

# Configure root logging, then parse hyper-parameters from the command line.
logging.basicConfig(level=logging.INFO)
logging.info("# hparams")
hparams = Hparams()
parser = hparams.parser  # argparse parser built by the Hparams helper
hp = parser.parse_args()  # reads sys.argv; hp carries all model hyper-parameters
# load_hparams(hp, hp.ckpt)

# --- Inference graph for example #2: one variable-length token sequence ---
input_tokens = tf.placeholder(tf.int32, shape=(1, None))  # batch of 1, any length
# NOTE(review): infer() is called with a single (tokens, None, None) tuple here,
# but with two tuples at L17 — confirm which signature Transformer.infer expects.
xs = (input_tokens, None, None)
logging.info("# Load model")
m = Transformer(hp)
y_hat = m.infer(xs)

logging.info("# Session")
sess = tf.Session()
# ckpt_ = tf.train.latest_checkpoint(hp.ckpt)
# ckpt = hp.ckpt if ckpt_ is None else ckpt_ # None: ckpt is a file. otherwise dir.
saver = tf.train.Saver()
# Restore weights from a fixed checkpoint file (not the latest in a directory).
saver.restore(sess, './data/translation_model.ckpt')


def divide_long(text):
    """Split *text* into sentences on the ideographic full stop '。',
    re-appending the terminator to each sentence.

    Bug fix: empty segments are now skipped, so text that already ends
    with '。' no longer yields a stray trailing '。' element (previously
    ``'a。'`` returned ``['a。', '。']``).
    """
    return [sentence + '。' for sentence in text.split('。') if sentence]


def trans(text):
    # x = encode(text, "x", m.token2idx)