Ejemplo n.º 1
0
def test():
    """Restore the latest TextAttRNN checkpoint and print predictions for sample sentences.

    For each hard-coded sentence, prints the per-tag value returned by
    ``predict_prob`` and the tag index returned by ``predict`` (both mapped
    to tag names through ``int2tag``).

    Raises:
        ValueError: if no checkpoint is found under
            ``MODEL_PATH/FLAGS.DEMO/checkpoints``.
    """
    model_path = os.path.join(MODEL_PATH, FLAGS.DEMO, 'checkpoints')
    ckpt_file = tf.train.latest_checkpoint(model_path)
    # latest_checkpoint returns None when the directory has no checkpoint;
    # fail here with a clear message instead of deep inside the model/restore.
    if ckpt_file is None:
        raise ValueError("no checkpoint found in {}".format(model_path))
    logging.info("load model from {}".format(ckpt_file))

    text_att_rnn = TextAttRNN(config=cfg(),
                              model_path=ckpt_file,
                              vocab=word2int,
                              tag2label=tag2label,
                              eopches=FLAGS.epoches)

    # The Saver must be created after the model graph is built so it
    # captures the model's variables.
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session(config=cfg()) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)

        inps = ['卡布奇诺瑞纳冰已经胜过我爱的星爸爸',
                '非常好,店长给我耐心的推荐不同口味,skr', '比app通知的时间晚太多做好',
                '可以,还可以更好!', '巧克力有点腻', '太冰了没有加热', '商品未送到提前点击已送达,且无电话通知。']
        results = text_att_rnn.predict(sess, inps)
        probs = text_att_rnn.predict_prob(sess, inps)
        for inp, r, prob in zip(inps, results, probs):
            print("\n{}".format(inp))
            for idx, p in enumerate(prob):
                print("\t{} -> {}".format(int2tag[idx], p))
            print("\tTag: {}".format(int2tag[r]))
Ejemplo n.º 2
0
def demo():
    """Interactive prediction loop over the latest TextAttRNN checkpoint.

    Reads sentences from stdin until an empty/whitespace-only line or
    end-of-input, and for each one prints the per-tag value from
    ``predict_prob`` and the tag chosen by ``predict``.

    Raises:
        ValueError: if no checkpoint is found under
            ``MODEL_PATH/FLAGS.DEMO/checkpoints``.
    """
    model_path = os.path.join(MODEL_PATH, FLAGS.DEMO, 'checkpoints')
    ckpt_file = tf.train.latest_checkpoint(model_path)
    # latest_checkpoint returns None when no checkpoint exists; fail early.
    if ckpt_file is None:
        raise ValueError("no checkpoint found in {}".format(model_path))
    logging.info("load model from {}".format(ckpt_file))

    text_att_rnn = TextAttRNN(config=cfg(),
                              model_path=ckpt_file,
                              vocab=word2int,
                              tag2label=tag2label,
                              eopches=FLAGS.epoches)

    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session(config=cfg()) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while True:
            print('Please input your sentence:')
            try:
                inp = input()
            except EOFError:
                # Ctrl-D / closed stdin ends the session like an empty line
                # instead of crashing with a traceback.
                inp = ''
            sentence = inp.strip()
            if not sentence:
                print('See you next time!')
                break

            inps = [sentence]
            pred = text_att_rnn.predict(sess, inps)[0]
            probs = text_att_rnn.predict_prob(sess, inps)[0]

            print("\n{}".format(sentence))
            for idx, prob in enumerate(probs):
                print("\t{} -> {}".format(int2tag[idx], prob))
            print("\tTag: {}".format(int2tag[pred]))
Ejemplo n.º 3
0
def train():
    """Train a fresh TextAttRNN on ``emergency_train.tsv``.

    The corpus is split by ``read_corpus`` into training and dev sets. The
    model is configured from ``FLAGS`` and writes its output under
    ``MODEL_PATH/FLAGS.DEMO``.
    """
    train_set, dev_set = read_corpus(filename='emergency_train.tsv')
    output_dir = os.path.join(MODEL_PATH, FLAGS.DEMO)

    model = TextAttRNN(
        config=cfg(),
        model_path=output_dir,
        vocab=word2int,
        tag2label=tag2label,
        batch_size=FLAGS.batch_size,
        embed_size=FLAGS.embed_size,
        sequence_length=FLAGS.sequence_length,
        eopches=FLAGS.epoches,
    )

    with tf.compat.v1.Session(config=cfg()) as session:
        model.train(session, train_set, dev_set, shuffle=True)
Ejemplo n.º 4
0
def train():
    """Train a fresh TextAttRNN on the corpus returned by ``read_corpus``.

    Uses a fixed ``random_state`` for a reproducible split and tab-separated
    input. The ``iter``/``iter_size`` arguments are forwarded to
    ``read_corpus`` unchanged (-1 / 20000); their exact meaning is defined
    by ``read_corpus``.
    """
    # Renamed from ``iter`` to avoid shadowing the builtin; the keyword
    # passed to read_corpus is unchanged.
    corpus_iter = -1
    iter_size = 20000
    train_set, dev_set = read_corpus(random_state=1234, separator='\t',
                                     iter=corpus_iter, iter_size=iter_size)
    text_att_rnn = TextAttRNN(config=cfg(),
                              model_path=os.path.join(MODEL_PATH, FLAGS.DEMO),
                              vocab=word2int,
                              tag2label=tag2label,
                              batch_size=FLAGS.batch_size,
                              embed_size=FLAGS.embed_size,
                              eopches=FLAGS.epoches)

    with tf.compat.v1.Session(config=cfg()) as sess:
        text_att_rnn.train(sess, train_set, dev_set, shuffle=True)
Ejemplo n.º 5
0
def re_train():
    """Continue training TextAttRNN from the latest saved checkpoint.

    Loads ``emergency_train.tsv`` with a 20% dev split, restores weights from
    the newest checkpoint in ``MODEL_PATH/FLAGS.DEMO/checkpoints``, then
    points the model's output path back at ``MODEL_PATH/FLAGS.DEMO`` and
    trains with ``re_train=True``.

    Raises:
        ValueError: if no checkpoint is found to resume from.
    """
    train, dev = read_corpus(filename='emergency_train.tsv', test_size=0.2)

    model_path = os.path.join(MODEL_PATH, FLAGS.DEMO, 'checkpoints')
    ckpt_file = tf.train.latest_checkpoint(model_path)
    # latest_checkpoint returns None if nothing has been saved yet; resuming
    # is impossible, so say so clearly rather than failing inside restore.
    if ckpt_file is None:
        raise ValueError("no pre-train checkpoint found in {}".format(model_path))

    logging.info("load pre-train model from {}".format(ckpt_file))
    textAttRNN = TextAttRNN(
        config=cfg(),
        model_path=ckpt_file,
        vocab=word2int,
        tag2label=tag2label,
        batch_size=FLAGS.batch_size,
        embed_size=FLAGS.embed_size,
        sequence_length=FLAGS.sequence_length,
        eopches=FLAGS.epoches,
    )

    # Created after the graph is built so it covers the model's variables.
    saver = tf.compat.v1.train.Saver()

    with tf.compat.v1.Session(config=cfg()) as sess:
        saver.restore(sess, ckpt_file)
        # Redirect saving to the run directory (not the checkpoint file path
        # the model was constructed with).
        textAttRNN.set_model_path(
            model_path=os.path.join(MODEL_PATH, FLAGS.DEMO))
        textAttRNN.train(sess, train, dev, shuffle=True, re_train=True)