def test():
    """Restore the latest checkpoint and print predictions for fixed samples.

    Looks up the newest checkpoint under
    ``MODEL_PATH/FLAGS.DEMO/checkpoints``, builds a ``TextAttRNN`` on it,
    restores the weights into a fresh session and, for each hard-coded
    sample sentence, prints the value reported for every class followed by
    the predicted tag.

    Raises:
        ValueError: if no checkpoint can be found in the checkpoint
            directory.
    """
    model_path = os.path.join(MODEL_PATH, FLAGS.DEMO, 'checkpoints')
    ckpt_file = tf.train.latest_checkpoint(model_path)
    if ckpt_file is None:
        # latest_checkpoint() returns None when nothing is found; fail here
        # with a clear message instead of deep inside Saver.restore().
        raise ValueError("no checkpoint found in {}".format(model_path))
    logging.info("load model from {}".format(ckpt_file))

    # The model must be built before the Saver so its variables exist.
    model = TextAttRNN(
        config=cfg(),
        model_path=ckpt_file,
        vocab=word2int,
        tag2label=tag2label,
        eopches=FLAGS.epoches,  # keyword spelling is part of TextAttRNN's API
    )
    saver = tf.compat.v1.train.Saver()

    with tf.compat.v1.Session(config=cfg()) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)

        inps = [
            '卡布奇诺瑞纳冰已经胜过我爱的星爸爸',
            '非常好,店长给我耐心的推荐不同口味,skr',
            '比app通知的时间晚太多做好',
            '可以,还可以更好!',
            '巧克力有点腻',
            '太冰了没有加热',
            '商品未送到提前点击已送达,且无电话通知。',
        ]
        results = model.predict(sess, inps)
        probs = model.predict_prob(sess, inps)

        for sentence, label, class_probs in zip(inps, results, probs):
            print("\n{}".format(sentence))
            for idx, p in enumerate(class_probs):
                print("\t{} -> {}".format(int2tag[idx], p))
            print("\tTag: {}".format(int2tag[label]))
def demo():
    """Interactively classify sentences typed on stdin.

    Restores the newest checkpoint under
    ``MODEL_PATH/FLAGS.DEMO/checkpoints`` into a ``TextAttRNN`` and then
    loops: read one line, print the value reported for every class and the
    predicted tag. An empty or whitespace-only line, or end of input, ends
    the loop.

    Raises:
        ValueError: if no checkpoint can be found in the checkpoint
            directory.
    """
    model_path = os.path.join(MODEL_PATH, FLAGS.DEMO, 'checkpoints')
    ckpt_file = tf.train.latest_checkpoint(model_path)
    if ckpt_file is None:
        # latest_checkpoint() returns None when nothing is found; fail here
        # with a clear message instead of deep inside Saver.restore().
        raise ValueError("no checkpoint found in {}".format(model_path))
    logging.info("load model from {}".format(ckpt_file))

    # The model must be built before the Saver so its variables exist.
    model = TextAttRNN(
        config=cfg(),
        model_path=ckpt_file,
        vocab=word2int,
        tag2label=tag2label,
        eopches=FLAGS.epoches,  # keyword spelling is part of TextAttRNN's API
    )
    saver = tf.compat.v1.train.Saver()

    with tf.compat.v1.Session(config=cfg()) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)

        while True:
            print('Please input your sentence:')
            try:
                raw = input()
            except EOFError:
                # Closed/piped stdin: treat it like a blank line (quit).
                raw = ''

            sentence = raw.strip()
            if not sentence:
                print('See you next time!')
                break

            inps = [sentence]
            pred = model.predict(sess, inps)[0]
            probs = model.predict_prob(sess, inps)[0]

            print("\n{}".format(inps))
            for idx, prob in enumerate(probs):
                print("\t{} -> {}".format(int2tag[idx], prob))
            print("\tTag: {}".format(int2tag[pred]))
def train():
    """Train a ``TextAttRNN`` from scratch on ``emergency_train.tsv``.

    Reads the train/dev split via ``read_corpus``, builds the model with the
    hyper-parameters taken from ``FLAGS`` and runs its ``train`` method with
    shuffling enabled inside a fresh session.

    NOTE(review): a second ``train`` function is defined later in this file;
    if both stay in the same module, the later definition replaces this one.
    """
    train_set, dev_set = read_corpus(filename='emergency_train.tsv')

    model_kwargs = dict(
        config=cfg(),
        model_path=os.path.join(MODEL_PATH, FLAGS.DEMO),
        vocab=word2int,
        tag2label=tag2label,
        batch_size=FLAGS.batch_size,
        embed_size=FLAGS.embed_size,
        sequence_length=FLAGS.sequence_length,
        eopches=FLAGS.epoches,  # keyword spelling is part of TextAttRNN's API
    )
    model = TextAttRNN(**model_kwargs)

    with tf.compat.v1.Session(config=cfg()) as sess:
        model.train(sess, train_set, dev_set, shuffle=True)
def train():
    """Train a ``TextAttRNN`` from scratch on the default corpus.

    Calls ``read_corpus`` with a fixed random state (1234), tab separator,
    ``iter=-1`` and ``iter_size=20000``, builds the model from ``FLAGS``
    (no ``sequence_length`` is passed here) and runs its ``train`` method
    with shuffling enabled inside a fresh session.

    NOTE(review): an earlier ``train`` function is defined above in this
    file; if both stay in the same module, this definition replaces it.
    """
    # Local names avoid shadowing the builtin ``iter`` and this function.
    corpus_options = {
        'random_state': 1234,
        'separator': '\t',
        'iter': -1,
        'iter_size': 20000,
    }
    train_set, dev_set = read_corpus(**corpus_options)

    model = TextAttRNN(
        config=cfg(),
        model_path=os.path.join(MODEL_PATH, FLAGS.DEMO),
        vocab=word2int,
        tag2label=tag2label,
        batch_size=FLAGS.batch_size,
        embed_size=FLAGS.embed_size,
        eopches=FLAGS.epoches,  # keyword spelling is part of TextAttRNN's API
    )

    with tf.compat.v1.Session(config=cfg()) as sess:
        model.train(sess, train_set, dev_set, shuffle=True)
def re_train():
    """Continue training a ``TextAttRNN`` from its latest checkpoint.

    Reads ``emergency_train.tsv`` with ``test_size=0.2``, restores the
    newest checkpoint under ``MODEL_PATH/FLAGS.DEMO/checkpoints`` into the
    model, points the model path back at ``MODEL_PATH/FLAGS.DEMO`` and runs
    the model's ``train`` method with ``re_train=True``.

    Raises:
        ValueError: if no checkpoint can be found in the checkpoint
            directory.
    """
    train_set, dev_set = read_corpus(
        filename='emergency_train.tsv', test_size=0.2)

    model_path = os.path.join(MODEL_PATH, FLAGS.DEMO, 'checkpoints')
    ckpt_file = tf.train.latest_checkpoint(model_path)
    if ckpt_file is None:
        # latest_checkpoint() returns None when nothing is found; fail here
        # with a clear message instead of deep inside Saver.restore().
        raise ValueError("no checkpoint found in {}".format(model_path))
    logging.info("load pre-train model from {}".format(ckpt_file))

    # The model must be built before the Saver so its variables exist.
    model = TextAttRNN(
        config=cfg(),
        model_path=ckpt_file,
        vocab=word2int,
        tag2label=tag2label,
        batch_size=FLAGS.batch_size,
        embed_size=FLAGS.embed_size,
        sequence_length=FLAGS.sequence_length,
        eopches=FLAGS.epoches,  # keyword spelling is part of TextAttRNN's API
    )
    saver = tf.compat.v1.train.Saver()

    with tf.compat.v1.Session(config=cfg()) as sess:
        saver.restore(sess, ckpt_file)
        # The model was constructed with the checkpoint file as its path;
        # switch to the run directory before training resumes.
        model.set_model_path(
            model_path=os.path.join(MODEL_PATH, FLAGS.DEMO))
        model.train(sess, train_set, dev_set, shuffle=True, re_train=True)