Example #1
import os

from pycorrector.seq2seq_attention.corpus_reader import CGEDReader, load_word_dict, save_word_dict

# Seq2seqAttnModel, Evaluate, data_generator and get_validation_data are
# assumed to come from sibling modules in pycorrector.seq2seq_attention
def train(train_path='', test_path='', save_vocab_path='', attn_model_path='',
          batch_size=64, epochs=100, maxlen=400, hidden_dim=128, use_gpu=False):
    data_reader = CGEDReader(train_path)
    input_texts, target_texts = data_reader.build_dataset(train_path)
    test_input_texts, test_target_texts = data_reader.build_dataset(test_path)

    # load or save word dict
    if os.path.exists(save_vocab_path):
        char2id = load_word_dict(save_vocab_path)
        id2char = {int(j): i for i, j in char2id.items()}
        chars = set(char2id.keys())
    else:
        print('Training data...')
        print('input_texts:', input_texts[0])
        print('target_texts:', target_texts[0])
        max_input_texts_len = max([len(text) for text in input_texts])

        print('num of samples:', len(input_texts))
        print('max sequence length for inputs:', max_input_texts_len)

        chars = data_reader.read_vocab(input_texts + target_texts)
        id2char = {i: j for i, j in enumerate(chars)}
        char2id = {j: i for i, j in id2char.items()}
        save_word_dict(char2id, save_vocab_path)

    model = Seq2seqAttnModel(chars,
                             attn_model_path=attn_model_path,
                             hidden_dim=hidden_dim,
                             use_gpu=use_gpu).build_model()
    evaluator = Evaluate(model, attn_model_path, char2id, id2char, maxlen)
    model.fit_generator(data_generator(input_texts, target_texts, char2id, batch_size, maxlen),
                        steps_per_epoch=(len(input_texts) + batch_size - 1) // batch_size,
                        epochs=epochs,
                        validation_data=get_validation_data(test_input_texts, test_target_texts, char2id, maxlen),
                        callbacks=[evaluator])
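
A minimal invocation sketch, reusing the config module seen in the other examples; config.train_path and config.attn_model_path are assumed fields, while the remaining ones appear verbatim in Examples #3 and #4:

from pycorrector.seq2seq_attention import config

train(train_path=config.train_path,            # assumed field
      test_path=config.test_path,
      save_vocab_path=config.save_vocab_path,
      attn_model_path=config.attn_model_path,  # assumed field
      batch_size=config.batch_size,
      hidden_dim=config.rnn_hidden_dim,
      use_gpu=config.use_gpu)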
Example #2
def __init__(self,
             save_vocab_path='',
             save_pinyin_path='',
             attn_model_path='',
             maxlen=400):
    if os.path.exists(save_vocab_path):
        self.char2id = load_word_dict(save_vocab_path)
        self.pinyins = load_word_dict(save_pinyin_path)
        self.id2char = {int(j): i for i, j in self.char2id.items()}
        self.id2pinyin = {int(j): i for i, j in self.pinyins.items()}
        self.chars = set(self.char2id.keys())
        # despite the name, this is the set of pinyin tokens,
        # mirroring the chars set above
        self.pinyin2id = set(self.pinyins.keys())
    else:
        # fail fast; the attributes used below would otherwise be undefined
        raise FileNotFoundError('vocab file not found: %s' % save_vocab_path)
    seq2seq_attn_model = Seq2seqAttn_multiembedding(
        self.chars, self.pinyin2id, attn_model_path=attn_model_path)
    self.model = seq2seq_attn_model.build_model()
    self.maxlen = maxlen
Example #3
def __init__(self, save_vocab_path='', attn_model_path='', maxlen=400):
    if os.path.exists(save_vocab_path):
        self.char2id = load_word_dict(save_vocab_path)
        self.id2char = {int(j): i for i, j in self.char2id.items()}
        self.chars = set(self.char2id.keys())
    else:
        # fail fast; self.chars would otherwise be undefined below
        raise FileNotFoundError('vocab file not found: %s' % save_vocab_path)
    seq2seq_attn_model = Seq2seqAttnModel(self.chars,
                                          attn_model_path=attn_model_path,
                                          hidden_dim=config.rnn_hidden_dim,
                                          use_gpu=config.use_gpu)
    self.model = seq2seq_attn_model.build_model()
    self.maxlen = maxlen
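
A minimal instantiation sketch; Inference is a hypothetical name for the class this constructor belongs to, and config.attn_model_path is an assumed field:

from pycorrector.seq2seq_attention import config

# Inference is a placeholder class name for the __init__ above
inferencer = Inference(save_vocab_path=config.save_vocab_path,
                       attn_model_path=config.attn_model_path,  # assumed field
                       maxlen=400)
inferencer.model.summary()  # the built model behaves as a standard Keras model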
Example #4
import numpy as np
from pycorrector.seq2seq_attention import config
from pycorrector.seq2seq_attention.corpus_reader import CGEDReader, str2id, padding, load_word_dict, save_word_dict


def data_generator(input_texts, target_texts, char2id, batch_size, maxlen=400):
    # data generator: yields padded (X, Y) id batches indefinitely,
    # as Keras fit_generator expects
    while True:
        X, Y = [], []
        for i in range(len(input_texts)):
            X.append(str2id(input_texts[i], char2id, maxlen))
            Y.append(str2id(target_texts[i], char2id, maxlen))
            if len(X) == batch_size:
                X = np.array(padding(X, char2id))
                Y = np.array(padding(Y, char2id))
                # inputs are [X, Y]; the target is None, the loss
                # presumably being computed inside the model itself
                yield [X, Y], None
                X, Y = [], []


if __name__ == '__main__':
    train_path = config.test_path
    data_reader = CGEDReader(train_path)
    input_texts, target_texts = data_reader.build_dataset(
        train_path)  # original texts, corrected texts
    print("input_text:", input_texts)
    save_vocab_path = config.save_vocab_path
    char2id = load_word_dict(save_vocab_path)
    batch_size = config.batch_size
    # pull one batch to sanity-check shapes (merely calling the generator
    # function would not execute its body)
    [X, Y], _ = next(data_generator(input_texts, target_texts, char2id, batch_size))
    print("X shape:", X.shape, "Y shape:", Y.shape)
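
Because the generator loops forever, the callers in Examples #1 and #5 must bound each epoch themselves; both do so with steps_per_epoch, computed by ceiling division over the dataset size:

steps_per_epoch = (len(input_texts) + batch_size - 1) // batch_size  # ceil(len / batch_size)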
Example #5
def train(train_path='',
          test_path='',
          save_vocab_path='',
          save_pinyin_path='',
          attn_model_path='',
          batch_size=64,
          epochs=100,
          maxlen=400,
          hidden_dim=128,
          dropout=0.2,
          use_gpu=False):
    data_reader = CGEDReader(train_path)
    input_texts, target_texts = data_reader.build_dataset(
        train_path)  # original texts, corrected texts
    # build the corresponding pinyin sequences
    input_pinyins = to_pinyin(input_texts)
    output_pinyins = to_pinyin(target_texts)

    test_input_texts, test_target_texts = data_reader.build_dataset(test_path)
    test_input_pinyins = to_pinyin(test_input_texts)
    test_output_pinyins = to_pinyin(test_target_texts)

    # load or save word dict
    if os.path.exists(save_vocab_path):
        char2id = load_word_dict(save_vocab_path)
        id2char = {int(j): i for i, j in char2id.items()}  # e.g. {0: '"', 1: '%', 2: '('}
        chars = set(char2id.keys())  # e.g. {'+', '%', ')', '(', '"'}
    else:
        print('Training data...')
        print('input_texts:', input_texts[0])
        print('target_texts:', target_texts[0])
        max_input_texts_len = max([len(text) for text in input_texts])

        print('num of samples:', len(input_texts))
        print('max sequence length for inputs:', max_input_texts_len)
        # chars: all characters from the corpus, collected into a list
        chars = data_reader.read_vocab(input_texts + target_texts)
        id2char = {i: j for i, j in enumerate(chars)}
        char2id = {j: i for i, j in id2char.items()}
        save_word_dict(char2id, save_vocab_path)
    print("chars:", chars)

    # load or save the pinyin vocab dict
    if os.path.exists(save_pinyin_path):
        pinyin2id = load_word_dict(save_pinyin_path)
        id2pinyin = {int(j): i for i, j in pinyin2id.items()}
        pinyins = set(pinyin2id.keys())
    else:
        pinyins = data_reader.read_vocab(input_pinyins + output_pinyins)
        id2pinyin = {i: j for i, j in enumerate(pinyins)}
        pinyin2id = {j: i for i, j in id2pinyin.items()}
        save_word_dict(pinyin2id, save_pinyin_path)
    print("pinyins:", pinyins)

    model = Seq2seqAttn_multiembedding(chars,
                                       pinyins,
                                       attn_model_path=attn_model_path,
                                       hidden_dim=hidden_dim,
                                       use_gpu=use_gpu,
                                       dropout=dropout).build_model()
    # an Evaluate callback could be attached here via callbacks=[evaluator],
    # as in Example #1
    model.fit_generator(
        data_generator(input_texts, target_texts, char2id, input_pinyins,
                       output_pinyins, pinyin2id, batch_size, maxlen),
        steps_per_epoch=(len(input_texts) + batch_size - 1) // batch_size,
        epochs=epochs,
        validation_data=get_validation_data(test_input_texts,
                                            test_target_texts, char2id,
                                            test_input_pinyins,
                                            test_output_pinyins, pinyin2id,
                                            maxlen))
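
Example #5 relies on a to_pinyin helper that is not shown. A plausible sketch using pypinyin's lazy_pinyin (an assumption; the package's own helper may tokenize differently):

from pypinyin import lazy_pinyin

def to_pinyin(texts):
    # map each text to a space-separated pinyin sequence
    # (assumed output format; the real helper may differ)
    return [' '.join(lazy_pinyin(text)) for text in texts]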