# Example #1
def test(params, infos):
    """Run one inference pass of the anti-language-model chatbot.

    Loads the vocabulary and a pre-trained Seq2Seq checkpoint, encodes
    ``infos`` as a single lowercased character sequence, and returns the
    decoded reply with the ``</S>`` end-of-sequence token dropped.

    Args:
        params: extra keyword arguments forwarded to the ``Seq2Seq``
            constructor (model hyper-parameters).
        infos: the user's input sentence (string).

    Returns:
        The predicted reply as one concatenated string.
    """
    from seq_to_seq import Seq2Seq
    from data_utils import batch_flow

    x_data, _ = pickle.load(open('chatbot.pkl', 'rb'))
    ws = pickle.load(open('wx.pkl', 'rb'))

    # Print a few training inputs as a sanity check of the loaded data.
    for x in x_data[:5]:
        print(' '.join(x))

    # CPU-only inference: a single sample per step does not need a GPU.
    config = tf.ConfigProto(device_count={
        'CPU': 1,
        'GPU': 0
    },
                            allow_soft_placement=True,
                            log_device_placement=False)
    # Checkpoint produced by the anti-language-model training (train_anti.py).
    save_path = './model3/s2ss_chatbot_anti.ckpt'

    tf.reset_default_graph()
    model_pred = Seq2Seq(input_vocab_size=len(ws),
                         target_vocab_size=len(ws),
                         batch_size=1,
                         mode='decode',
                         beam_width=0,
                         **params)

    init = tf.global_variables_initializer()

    with tf.Session(config=config) as sess:
        sess.run(init)
        model_pred.load(sess, save_path)

        # NOTE: the original wrapped the remainder in ``while True`` but
        # unconditionally returned at the end of the first iteration, so
        # the loop was dead code and has been removed.
        x_test = [list(infos.lower())]
        bar = batch_flow([x_test], ws, 1)
        x, xl = next(bar)
        # Reverse each sequence along the time axis; the training code
        # (see train) feeds the encoder reversed input the same way.
        x = np.flip(x, axis=1)

        pred = model_pred.predict(sess, np.array(x), np.array(xl))

        # Concatenate predicted tokens, skipping the end-of-sequence marker.
        total_word = ""
        for p in pred:
            ans = ws.inverse_transform(p)
            if ans[0] != '</S>':
                total_word += ans[0]
                print(ans)
        return total_word
# Example #2
def test(params):
    """Train the chatbot Seq2Seq model, save a checkpoint, and smoke-test it.

    Trains for ``n_epoch`` epochs on the pickled question/answer pairs,
    saves the weights, then reloads them in decode mode with beam search
    and prints the predictions for the first three samples.

    Args:
        params: extra keyword arguments forwarded to the ``Seq2Seq``
            constructor (model hyper-parameters).
    """
    from seq_to_seq import Seq2Seq
    from data_utils import batch_flow_bucket as batch_flow
    from thread_generator import ThreadedGenerator

    x_data, y_data = pickle.load(open('chatbot.pkl', 'rb'))
    ws = pickle.load(open('wx.pkl', 'rb'))

    # Number of passes over the data; larger values overfit more easily.
    n_epoch = 1
    batch_size = 128

    # Ceiling division: just enough steps to cover every sample. The
    # previous ``int(len(x_data) / batch_size) + 1`` ran one extra step
    # whenever len(x_data) was an exact multiple of batch_size.
    steps = (len(x_data) + batch_size - 1) // batch_size

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)

    save_path = './model/s2ss_chatbot.ckpt'

    tf.reset_default_graph()

    with tf.Graph().as_default():
        # Seed every RNG so the training run is reproducible.
        random.seed(0)
        np.random.seed(0)
        tf.set_random_seed(0)

        with tf.Session(config=config) as sess:
            # Build the training-mode model.
            model = Seq2Seq(input_vocab_size=len(ws),
                            target_vocab_size=len(ws),
                            batch_size=batch_size,
                            **params)

            init = tf.global_variables_initializer()
            sess.run(init)

            # Assemble batches on a background thread so training is not
            # stalled waiting for the next batch.
            flow = ThreadedGenerator(batch_flow([x_data, y_data],
                                                ws,
                                                batch_size,
                                                add_end=[False, True]),
                                     queue_maxsize=30)

            for epoch in range(1, n_epoch + 1):
                costs = []
                bar = tqdm(range(steps),
                           total=steps,
                           desc='epoch {}, loss=0.000000'.format(epoch))
                for _ in bar:
                    x, xl, y, yl = next(flow)
                    # Reverse each encoder input along the time axis,
                    # e.g. [[1,2],[3,4]] -> [[2,1],[4,3]].  (The original
                    # comment showed row order being swapped, which would
                    # be axis=0 — np.flip with axis=1 reverses within rows.)
                    x = np.flip(x, axis=1)
                    cost, lr = model.train(sess, x, xl, y, yl, return_lr=True)
                    costs.append(cost)
                    bar.set_description(
                        'epoch {} loss={:.6f} lr={:.6f}'.format(
                            epoch, np.mean(costs), lr))

            model.save(sess, save_path=save_path)

    # Smoke test: reload the checkpoint in decode mode (beam search) and
    # print model output next to the input/target for three samples.
    tf.reset_default_graph()
    model_pred = Seq2Seq(input_vocab_size=len(ws),
                         target_vocab_size=len(ws),
                         batch_size=1,
                         mode='decode',
                         beam_width=12,
                         parallel_iterations=1,
                         **params)
    init = tf.global_variables_initializer()
    with tf.Session(config=config) as sess:
        sess.run(init)
        model_pred.load(sess, save_path)

        bar = batch_flow([x_data, y_data], ws, 1, add_end=False)
        for t, (x, xl, y, yl) in enumerate(bar):
            x = np.flip(x, axis=1)
            pred = model_pred.predict(sess, np.array(x), np.array(xl))
            print(ws.inverse_transform(x[0]))
            print(ws.inverse_transform(y[0]))
            print(ws.inverse_transform(pred[0]))
            if t >= 2:  # stop after three samples, as before
                break