Example #1
import time

import numpy as np
import tensorflow as tf

# `config`, `Seq2seq`, `get_word_to_id` and `make_batches` are project helpers
# that this snippet assumes are already in scope.


# Despite its name, this function runs the training loop and saves a checkpoint.
def inference():

    word_2_id = get_word_to_id(config.VOCAB_PATH)
    id_2_word = {i: w for (w, i) in word_2_id.items()}  # invert word->id into id->word
    initializer = tf.random_uniform_initializer(-0.05, 0.05)
    model = Seq2seq(w2i_target=word_2_id, initializer=initializer)
    source_batch, source_lens, target_batch, target_lens = make_batches()

    print_every = 100
    batches = 10400

    with tf.Session() as sess:
        tf.summary.FileWriter('graph', sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())

        losses = []
        total_loss = 0
        try:
            for batch in range(batches):

                feed_dict = {
                    model.seq_inputs: source_batch[batch],
                    model.seq_inputs_length: source_lens[batch],
                    model.seq_targets: target_batch[batch],
                    model.seq_targets_length: target_lens[batch]
                }

                loss, _ = sess.run([model.loss, model.train_op],
                                   feed_dict=feed_dict)
                total_loss += loss

                if batch % print_every == 0 and batch > 0:
                    print_loss = total_loss / print_every  # batch > 0 here, so just average
                    losses.append(print_loss)
                    total_loss = 0
                    print("-----------------------------")
                    print("batch:", batch, "/", batches)
                    print(
                        "time:",
                        time.strftime('%Y-%m-%d %H:%M:%S',
                                      time.localtime(time.time())))
                    print("loss:", print_loss)

            print(losses)
            print(saver.save(sess, "checkpoint/model.ckpt"))
        except Exception as e:
            # dump the batch that triggered the failure, then the error itself
            print(source_batch[batch], np.shape(source_batch[batch]))
            print(source_lens[batch], np.shape(source_lens[batch]))
            print(e)
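
Example #1 only saves a checkpoint; the sketch below is a minimal, hypothetical way to load it back for prediction. The decode tensor name `model.infer_outputs` is an assumption, since the Seq2seq class used in these examples does not show its inference output here.

def predict(sentence_ids, sentence_len):
    # Rebuild the same graph, then restore the weights saved by inference().
    word_2_id = get_word_to_id(config.VOCAB_PATH)
    initializer = tf.random_uniform_initializer(-0.05, 0.05)
    model = Seq2seq(w2i_target=word_2_id, initializer=initializer)

    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, "checkpoint/model.ckpt")
        feed_dict = {
            model.seq_inputs: [sentence_ids],
            model.seq_inputs_length: [sentence_len],
        }
        # `model.infer_outputs` is hypothetical; use whatever decode tensor
        # your Seq2seq implementation actually exposes.
        return sess.run(model.infer_outputs, feed_dict=feed_dict)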
Example #2

if __name__ == "__main__":

    print("(1)load data......")
    docs_source, docs_target = load_data("")
    w2i_source, i2w_source = make_vocab(docs_source)
    w2i_target, i2w_target = make_vocab(docs_target)

    print("(2) build model......")
    config = Config()
    config.source_vocab_size = len(w2i_source)
    config.target_vocab_size = len(w2i_target)
    model = Seq2seq(config=config,
                    w2i_target=w2i_target,
                    useTeacherForcing=True,
                    useAttention=True,
                    useBeamSearch=1)

    print("(3) run model......")
    batches = 3000
    print_every = 100

    # `tf_config` is not defined in the original snippet; a typical value is assumed here.
    tf_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=tf_config) as sess:
        tf.summary.FileWriter('graph', sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())

        losses = []
        total_loss = 0
        for batch in range(batches):
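            # --- continuation sketch: the original example is truncated here ---
            # The loop body below mirrors Example #1. The helper get_batch(...) and
            # its signature are assumptions; the original snippet does not show how
            # batches are built from docs_source/docs_target.
            source_batch, source_lens, target_batch, target_lens = get_batch(
                docs_source, w2i_source, docs_target, w2i_target, config.batch_size)

            feed_dict = {
                model.seq_inputs: source_batch,
                model.seq_inputs_length: source_lens,
                model.seq_targets: target_batch,
                model.seq_targets_length: target_lens
            }
            loss, _ = sess.run([model.loss, model.train_op], feed_dict=feed_dict)
            total_loss += loss

            if batch % print_every == 0 and batch > 0:
                print_loss = total_loss / print_every
                losses.append(print_loss)
                total_loss = 0
                print("batch:", batch, "/", batches, "loss:", print_loss)

        saver.save(sess, "checkpoint/model.ckpt")  # persist the trained weights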
Example #3
if __name__ == "__main__":

    print("(1)load data......")
    docs_source, docs_target = load_data(
        'seq2seq_dataset/train_data_question_pair')
    eval_doc_source, eval_doc_target = load_data(
        'seq2seq_dataset/dev_data_question_pair')

    print("(2) build model......")
    config = Config()
    config.source_vocab_size = len(tokenizer.vocab)
    config.target_vocab_size = 10000  #len(tokenizer.vocab)

    model = Seq2seq(config=config,
                    tokenizer=tokenizer,
                    useTeacherForcing=True,
                    useAttention=True,
                    useBeamSearch=1)

    print("(3) run model......")
    print_every = 200
    epoches = 3
    train_data_size = len(docs_target)
    batch_size = 32
    batches = train_data_size // batch_size

    acc = 0.0
    with tf.Session(config=tf_config) as sess:
        tf.summary.FileWriter('graph', sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
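        # --- continuation sketch: the original example is truncated here ---
        # A plausible epoch/batch loop given the variables defined above. The helper
        # get_batch(...) is an assumption, and the per-epoch evaluation pass over
        # eval_doc_source/eval_doc_target (which would update `acc`) is omitted.
        losses = []
        total_loss = 0
        for epoch in range(epoches):
            for batch in range(batches):
                source_batch, source_lens, target_batch, target_lens = get_batch(
                    docs_source, docs_target, tokenizer, batch_size)

                feed_dict = {
                    model.seq_inputs: source_batch,
                    model.seq_inputs_length: source_lens,
                    model.seq_targets: target_batch,
                    model.seq_targets_length: target_lens
                }
                loss, _ = sess.run([model.loss, model.train_op], feed_dict=feed_dict)
                total_loss += loss

                if batch % print_every == 0 and batch > 0:
                    print("epoch:", epoch, "batch:", batch, "/", batches,
                          "loss:", total_loss / print_every)
                    total_loss = 0

        saver.save(sess, "checkpoint/model.ckpt")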
Example #4
    percent = 0.8  # use 80% of the data for training

    config = Config()

    print("(1)load data......")
    with open(input_name, 'rb') as f:
        data = f.read().decode().split('\n')
    data = data[:-1]  # drop the trailing empty element from the final newline
    sequence = []
    for line in data:  # use `line` instead of shadowing the built-in name `str`
        sequence.append(list(map(float, line.split(','))))
    x_train, y_train = load_train_data(sequence, config.input_length,
                                       config.predict_length, percent)

    print("(2) build model......")
    model = Seq2seq(config=config, useTeacherForcing=True, useAttention=True)

    print("(3) run model......")
    batches = 3000
    print_every = 500

    with tf.Session(config=tf_config) as sess:
        tf.summary.FileWriter('graph', sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())

        losses = []
        total_loss = 0
        for batch in range(batches):
            source_batch, source_lens, target_batch, target_time, target_lens = get_batch(
                x_train, y_train, config.batch_size)
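            # --- continuation sketch: the original example is truncated here ---
            # Feeds mirror Example #1; how `target_time` is consumed by this
            # Seq2seq variant is not shown, so it is left out of the feed below.
            feed_dict = {
                model.seq_inputs: source_batch,
                model.seq_inputs_length: source_lens,
                model.seq_targets: target_batch,
                model.seq_targets_length: target_lens
            }
            loss, _ = sess.run([model.loss, model.train_op], feed_dict=feed_dict)
            total_loss += loss

            if batch % print_every == 0 and batch > 0:
                print_loss = total_loss / print_every
                losses.append(print_loss)
                total_loss = 0
                print("batch:", batch, "/", batches, "loss:", print_loss)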