Code Example #1
def main(args):
    """Main train driver."""
    print("Start Training")

    parameters, copy_player = setupconfig(args)

    # For Training
    train_data, train_lang = loaddata(file_loc,
                                      'train',
                                      copy_player=copy_player)
    if MAX_TRAIN_NUM is not None:
        train_data = train_data[:MAX_TRAIN_NUM]

    train_data = data2index(train_data,
                            train_lang,
                            max_sentences=parameters['max_sentence'])

    encoder, decoder = train(train_data, train_lang, **parameters)

    # For evaluation
    valid_data, _ = loaddata(file_loc, 'valid', copy_player=copy_player)

    valid_data = data2index(valid_data, train_lang)
    evaluate(encoder, decoder, valid_data, train_lang['summary'],
             parameters['embedding_size'])
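
The driver above expects a pre-parsed args object, which setupconfig() turns into the parameters dict. A minimal invocation sketch, assuming an argparse-based CLI; the flag names below are illustrative guesses, not the project's actual options:

import argparse


def parse_args():
    # Hypothetical CLI; the real flags consumed by setupconfig() are not
    # shown in the example above.
    parser = argparse.ArgumentParser(description='Train the data-to-text model')
    parser.add_argument('--embedding_size', type=int, default=600)
    parser.add_argument('--copy_player', action='store_true')
    return parser.parse_args()


if __name__ == '__main__':
    main(parse_args())
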
Code Example #2
def main():
    # Display Configuration
    showconfig()

    # Default parameters
    embedding_size = EMBEDDING_SIZE
    learning_rate = LR
    train_iter_time = ITER_TIME
    batch_size = BATCH_SIZE

    # For Training
    train_data, train_lang = loaddata(file_loc, 'train')
    train_data = data2index(train_data, train_lang)
    encoder, decoder = train(train_data,
                             train_lang,
                             embedding_size=embedding_size,
                             learning_rate=learning_rate,
                             iter_time=train_iter_time,
                             batch_size=batch_size)

    # For evaluation
    valid_data, _ = loaddata(file_loc, 'valid')
    valid_data = data2index(valid_data, train_lang)
    evaluate(encoder, decoder, valid_data, train_lang['summary'],
             embedding_size)
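
The upper-case constants (and file_loc) presumably come from the project's settings module, which code example #7 also imports from. A sketch of what such a module could look like; apart from the embedding size of 600, which matches examples #6 and #7, the values are placeholders:

# settings.py -- illustrative values only
file_loc = './rotowire/'   # directory holding the train/valid JSON files
EMBEDDING_SIZE = 600       # dimensionality of the record embeddings
LR = 0.01                  # learning rate
ITER_TIME = 10000          # number of training iterations
BATCH_SIZE = 32            # examples per batch
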
Code Example #3
def main():
    print("Start Training")
    # Display Configuration
    showconfig()
    # Default parameters
    embedding_size = EMBEDDING_SIZE
    learning_rate = LR
    train_iter_time = ITER_TIME
    batch_size = BATCH_SIZE

    # For Training
    train_data, train_lang = loaddata(file_loc, 'train')
    if MAX_TRAIN_NUM is not None:
        train_data = train_data[:MAX_TRAIN_NUM]
    train_data = data2index(train_data, train_lang)
    encoder, decoder = train(train_data,
                             train_lang,
                             embedding_size=embedding_size,
                             learning_rate=learning_rate,
                             iter_time=train_iter_time,
                             batch_size=batch_size,
                             use_model=USE_MODEL)

    # For evaluation
    valid_data, _ = loaddata(file_loc, 'valid')
    valid_data = data2index(valid_data, train_lang)
    evaluate(encoder, decoder, valid_data, train_lang['summary'],
             embedding_size)
Code Example #4
def main(_):
    model_path = os.path.join('model', FLAGS.name)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    train_data, train_lang = loaddata(path, FLAGS.num_steps)
    vocab_size = train_lang.vocab_size

    converter = TextConverter(lang=train_lang, max_vocab=FLAGS.max_vocab)
    converter.save_lang(filename=FLAGS.name + '_converter.pkl')

    g = batch_generator(train_data, FLAGS.batch_size, FLAGS.max_steps)

    model = Seq2Seq('train',
                    vocab_size,
                    batch_size=FLAGS.batch_size,
                    num_steps=FLAGS.num_steps,
                    max_steps=FLAGS.max_steps,
                    lstm_size=FLAGS.lstm_size,
                    num_layers=FLAGS.num_layers,
                    learning_rate=FLAGS.learning_rate,
                    train_keep_prob=FLAGS.train_keep_prob,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size,
                    max_iters=FLAGS.max_iters,
                    bidirectional=FLAGS.bidirectional,
                    beam_search=False)
    model.train(
        g,
        converter,
        FLAGS.max_steps,
        model_path,
        FLAGS.save_every_n,
        FLAGS.log_every_n,
    )
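
main(_) takes the unused argv argument typical of TensorFlow 1.x scripts driven by tf.app.run(), with hyperparameters read from FLAGS. A sketch of the presumed entry point; only two of the flags referenced above are shown, and their defaults are guesses:

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

# Two of the flags used above, defined the TensorFlow 1.x way.
tf.app.flags.DEFINE_string('name', 'default', 'model name and checkpoint folder')
tf.app.flags.DEFINE_integer('batch_size', 32, 'sequences per batch')


if __name__ == '__main__':
    tf.app.run()   # parses the flags, then calls main()
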
Code Example #5
File: evaluate.py Project: rchanda/Data2Doc
def generate_text(model, data_file, output):
    encoder_src = model['encoder_path']
    decoder_src = model['decoder_path']
    encoder_style = None

    # Choose model architecture
    if 'RNN' in encoder_src:
        encoder = EncoderRNN(embedding_size, emb)
        encoder_style = 'RNN'
    elif 'LSTM' in encoder_src:
        encoder = EncoderBiLSTM(embedding_size, emb)
        encoder_style = 'LSTM'
    else:
        encoder = EncoderLIN(embedding_size, emb)
        encoder_style = 'LIN'

    decoder = AttnDecoderRNN(embedding_size, langs['summary'].n_words)
    encoder = load_model(encoder, encoder_src)
    decoder = load_model(decoder, decoder_src)
    data_path = os.path.join(data_file['data_dir'],
                             data_file['data_name'] + '.json')
    with open(data_path) as f:
        valuation_data = json.load(f)
    assert valuation_data is not None

    valid_data, _ = loaddata(data_file['data_dir'], data_file['data_name'])
    data_length = len(valid_data)
    valid_data = data2index(valid_data, train_lang)
    text_generator = evaluate(encoder, decoder, valid_data,
                              train_lang['summary'], embedding_size,
                              encoder_style=encoder_style, iter_time=data_length,
                              beam_size=1, verbose=False)
    print('The text generation begins\n', flush=True)
    with open(output, 'w') as f:
        for idx, line in enumerate(text_generator):
            print('Summary generated, No{}'.format(idx + 1))
            f.write(line + '\n')
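
A hedged usage sketch for generate_text(): the dictionary keys are the ones the function reads, the checkpoint paths are borrowed from code example #7, and the output filename is a placeholder:

model = {
    # No 'RNN' or 'LSTM' in this filename, so the LIN encoder would be chosen.
    'encoder_path': './models/long4_encoder_2120',
    'decoder_path': './models/long4_decoder_2120',
}
data_file = {'data_dir': file_loc, 'data_name': 'valid'}
generate_text(model, data_file, output='generated_summaries.txt')
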
Code Example #6
File: evaluate.py Project: rchanda/Data2Doc
"""Evaluate the model."""
from dataprepare import loaddata, data2index
from train import evaluate
from model import (AttnDecoderRNN, EncoderBiLSTM, EncoderRNN, EncoderLIN,
                   docEmbedding)
from settings import file_loc
from util import load_model

import json
import os
import configparser
import argparse

config = configparser.ConfigParser()

train_data, train_lang = loaddata(file_loc, 'train')

embedding_size = 600
langs = train_lang
emb = docEmbedding(langs['rt'].n_words, langs['re'].n_words,
                   langs['rm'].n_words, embedding_size)
emb.init_weights()

encoder = EncoderLIN(embedding_size, emb)


def generate_text(model, data_file, output):
    encoder_src = model['encoder_path']
    decoder_src = model['decoder_path']
    encoder_style = None

    # Choose model architecture
Code Example #7
"""Evaluate the model."""
from dataprepare import loaddata, data2index
from train import evaluate
from model import (AttnDecoderRNN, EncoderRNN, EncoderLIN, docEmbedding,
                   EncoderBiLSTM)
from settings import file_loc, ENCODER_STYLE
from util import load_model

train_data, train_lang = loaddata(file_loc, 'train')

embedding_size = 600
langs = train_lang
emb = docEmbedding(langs['rt'].n_words, langs['re'].n_words,
                   langs['rm'].n_words, embedding_size)
emb.init_weights()

# Paths to the saved encoder/decoder checkpoints to evaluate.
encoder_src = './models/long4_encoder_2120'
decoder_src = './models/long4_decoder_2120'

encoder_style = None

if ENCODER_STYLE == 'RNN':
    encoder = EncoderRNN(embedding_size, emb)
    encoder_style = 'RNN'
elif ENCODER_STYLE == 'LSTM':
    encoder = EncoderBiLSTM(embedding_size, emb)
    encoder_style = 'BiLSTM'
else:
    encoder = EncoderLIN(embedding_size, emb)
    encoder_style = 'LIN'

decoder = AttnDecoderRNN(embedding_size, langs['summary'].n_words)
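
The snippet stops after constructing the decoder. Judging from code example #5 (the same evaluate.py), the module presumably goes on to restore the saved weights and run the evaluator, roughly as follows:

# Presumed continuation, modelled on code example #5.
encoder = load_model(encoder, encoder_src)
decoder = load_model(decoder, decoder_src)

valid_data, _ = loaddata(file_loc, 'valid')
valid_data = data2index(valid_data, train_lang)
evaluate(encoder, decoder, valid_data, train_lang['summary'],
         embedding_size, encoder_style=encoder_style)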