Ejemplo n.º 1
0
def train(train_path='',
          test_path='',
          save_vocab_path='',
          attn_model_path='',
          batch_size=64,
          epochs=100,
          maxlen=400,
          hidden_dim=128,
          dropout=0.2,
          vocab_max_size=50000,
          vocab_min_count=5,
          gpu_id=0):
    """Train the attention seq2seq model on the given corpus.

    Builds (or reloads) the vocabulary, constructs the model, and fits it
    with early stopping plus a custom evaluation callback.
    """
    src_texts, tgt_texts = build_dataset(train_path)
    val_src_texts, val_tgt_texts = build_dataset(test_path)

    # Reuse a previously saved vocabulary when available; otherwise build
    # one from the training corpus and persist it for later runs.
    if os.path.exists(save_vocab_path):
        word2id = load_word_dict(save_vocab_path)
    else:
        print('Training data...')
        word2id = read_vocab(src_texts + tgt_texts,
                             max_size=vocab_max_size,
                             min_count=vocab_min_count)
        encoder_token_count = len(word2id)
        longest_input_len = max(len(text) for text in src_texts)

        print('input_texts:', src_texts[0])
        print('target_texts:', tgt_texts[0])
        print('num of samples:', len(src_texts))
        print('num of unique input tokens:', encoder_token_count)
        print('max sequence length for inputs:', longest_input_len)
        save_word_dict(word2id, save_vocab_path)

    # Inverse mapping (id -> token) used by the evaluation callback.
    id2word = {int(idx): token for token, idx in word2id.items()}
    print('The vocabulary file:%s, size: %s' %
          (save_vocab_path, len(word2id)))

    model = Seq2seqAttnModel(len(word2id),
                             attn_model_path=attn_model_path,
                             hidden_dim=hidden_dim,
                             dropout=dropout,
                             gpu_id=gpu_id).build_model()

    evaluator = Evaluate(model, attn_model_path, word2id, id2word, maxlen)
    earlystop = EarlyStopping(monitor='val_loss',
                              patience=3,
                              verbose=1,
                              mode='auto')

    # Ceiling division so a final partial batch still counts as a step.
    steps = (len(src_texts) + batch_size - 1) // batch_size
    train_gen = data_generator(src_texts, tgt_texts, word2id, batch_size,
                               maxlen)
    val_data = get_validation_data(val_src_texts, val_tgt_texts, word2id,
                                   maxlen)
    model.fit_generator(train_gen,
                        steps_per_epoch=steps,
                        epochs=epochs,
                        validation_data=val_data,
                        callbacks=[evaluator, earlystop])
Ejemplo n.º 2
0
 def __init__(self,
              save_vocab_path='',
              attn_model_path='',
              maxlen=400,
              gpu_id=0):
     """Load the vocabulary and build the attention seq2seq model.

     Args:
         save_vocab_path: path of the saved word-dict file (required).
         attn_model_path: path of the trained attention model weights.
         maxlen: maximum sequence length used at inference time.
         gpu_id: GPU device id passed through to the model builder.

     Raises:
         FileNotFoundError: if ``save_vocab_path`` does not exist.
     """
     # Fail fast with a clear error instead of printing and then crashing
     # with an opaque AttributeError on self.vocab2id further below.
     if not os.path.exists(save_vocab_path):
         raise FileNotFoundError(
             'not exist vocab path: %s' % save_vocab_path)
     self.vocab2id = load_word_dict(save_vocab_path)
     # Inverse mapping (id -> token) for decoding model outputs.
     self.id2vocab = {int(j): i for i, j in self.vocab2id.items()}
     # dropout=0.0 at inference; hidden_dim must match the trained model.
     self.model = Seq2seqAttnModel(len(self.vocab2id),
                                   attn_model_path=attn_model_path,
                                   hidden_dim=128,
                                   dropout=0.0,
                                   gpu_id=gpu_id).build_model()
     self.maxlen = maxlen
Ejemplo n.º 3
0
                           result.split(' '), attention_image_path)
        except Exception as e:
            print(e)
            pass


if __name__ == "__main__":
    # Select the configured GPU, or force CPU when gpu_id is negative.
    if config.gpu_id > -1:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
    # Sample (space-tokenized) sentences to run error correction on.
    inputs = [
        '以 前 , 包 括 中 国 , 我 国 也 是 。', '我 现 在 好 得 多 了 。', '这几年前时间,',
        '歌曲使人的感到快乐,', '会能够大幅减少互相抱怨的情况。'
    ]
    source_word2id = load_word_dict(config.save_src_vocab_path)
    target_word2id = load_word_dict(config.save_trg_vocab_path)
    model = Seq2SeqModel(source_word2id,
                         target_word2id,
                         embedding_dim=config.embedding_dim,
                         hidden_dim=config.hidden_dim,
                         batch_size=config.batch_size,
                         maxlen=config.maxlen,
                         checkpoint_path=config.model_dir,
                         gpu_id=config.gpu_id)
    # Use `idx` rather than `id` to avoid shadowing the builtin id().
    for idx, text in enumerate(inputs):
        img_path = os.path.join(config.output_dir, str(idx) + ".png")
        infer(model, text, img_path)

# result:
# input:由我起开始做。