Exemplo n.º 1
0
def main(_):
    """Train a CharRNN on the text stored in ``FLAGS.input_file``.

    Ensures the directory ``model/<FLAGS.name>`` exists, builds a
    ``TextConverter`` from the corpus and saves it there, then feeds a
    batch generator to ``CharRNN.train``.

    Args:
        _: Unused positional argument.
    """
    # Output directory: model/<name>
    model_path = os.path.join('model', FLAGS.name)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Read the whole corpus as UTF-8 text.
    with codecs.open(FLAGS.input_file, encoding='utf-8') as corpus_file:
        text = corpus_file.read()

    # Build the vocabulary (FLAGS.max_vocab is passed as the vocabulary cap,
    # per the original author's note) and serialize it next to the model.
    converter = TextConverter(text, FLAGS.max_vocab)
    converter_file = os.path.join(model_path, 'converter.pk1')
    converter.save_to_file(converter_file)

    # Encode the text as model input and obtain the batch generator.
    encoded = converter.text_to_data(text)
    batches = batch_generator(encoded, FLAGS.n_seqs, FLAGS.n_steps)
    print(converter.vocab_size)

    # Model hyper-parameters, all taken from the command-line flags.
    hyper_params = dict(
        n_seqs=FLAGS.n_seqs,
        n_steps=FLAGS.n_steps,
        state_size=FLAGS.state_size,
        n_layers=FLAGS.n_layers,
        learning_rate=FLAGS.learning_rate,
        train_keep_prob=FLAGS.train_keep_prob,
        use_embedding=FLAGS.use_embedding,
        embedding_size=FLAGS.embedding_size,
    )
    model = CharRNN(converter.vocab_size, **hyper_params)
    model.train(
        batches,
        FLAGS.max_steps,
        model_path,
        FLAGS.save_every_n,
        FLAGS.log_every_n,
    )
Exemplo n.º 2
0
def main(_):
    """Train a CharRNN on ``FLAGS.input_file`` and report the elapsed time.

    Ensures ``model/<FLAGS.name>`` exists, builds a ``TextConverter`` from the
    corpus and saves it there as ``converter.pkl``, trains the model, then
    prints the wall-clock time measured from the start of this function.

    Args:
        _: Unused positional argument.
    """
    # Local import so this function does not depend on the file's import
    # block; io.open accepts ``encoding`` on both Python 2 and Python 3.
    import io

    start_time = time.time()
    model_path = os.path.join('model', FLAGS.name)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Decode explicitly as UTF-8: relying on the platform default encoding
    # makes reading a non-ASCII corpus fail or mis-decode on systems whose
    # locale is not UTF-8.
    with io.open(FLAGS.input_file, 'r', encoding='utf-8') as f:
        text = f.read()
    converter = TextConverter(text, FLAGS.max_vocab)
    converter.save_to_file(os.path.join(model_path, 'converter.pkl'))

    arr = converter.text_to_arr(text)
    g = batch_generator(arr, FLAGS.num_seqs, FLAGS.num_steps)
    print(converter.vocab_size)
    model = CharRNN(converter.vocab_size,
                    num_seqs=FLAGS.num_seqs,
                    num_steps=FLAGS.num_steps,
                    lstm_size=FLAGS.lstm_size,
                    num_layers=FLAGS.num_layers,
                    learning_rate=FLAGS.learning_rate,
                    train_keep_prob=FLAGS.train_keep_prob,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)
    model.train(
        g,
        FLAGS.max_steps,
        model_path,
        FLAGS.save_every_n,
        FLAGS.log_every_n,
    )
    print("Timing cost is --- %s ---second(s)" % (time.time() - start_time))
Exemplo n.º 3
0
def composePotery():
    """Sample text from a trained CharRNN and pass it to ``selectPoetry``.

    Loads the converter from ``FLAGS.converter_path``, restores the model
    from ``FLAGS.checkpoint_path`` and samples up to ``FLAGS.max_length``
    items starting from an empty seed.

    Note:
        When ``FLAGS.checkpoint_path`` is a directory it is overwritten in
        place with the result of ``tf.train.latest_checkpoint``.

    Returns:
        Whatever ``selectPoetry`` returns for the sampled raw text.
    """
    converter = TextConverter(filename=FLAGS.converter_path)

    # Resolve a checkpoint directory to a concrete checkpoint and store the
    # result back on FLAGS.
    checkpoint_location = FLAGS.checkpoint_path
    if os.path.isdir(checkpoint_location):
        FLAGS.checkpoint_path = tf.train.latest_checkpoint(checkpoint_location)

    sampler_options = dict(
        sampling=True,
        lstm_size=FLAGS.lstm_size,
        num_layers=FLAGS.num_layers,
        use_embedding=FLAGS.use_embedding,
        embedding_size=FLAGS.embedding_size,
    )
    model = CharRNN(converter.vocab_size, **sampler_options)
    model.load(FLAGS.checkpoint_path)

    # Sampling starts from an empty seed sequence.
    sampled = model.sample(FLAGS.max_length, [], converter.vocab_size)
    return selectPoetry(converter.arr_to_text(sampled))
Exemplo n.º 4
0
def main(_):
    """Restore a BilstmNer model and print its tags for one demo sentence.

    Note:
        When ``FLAGS.checkpoint_path`` is a directory it is overwritten in
        place with the result of ``tf.train.latest_checkpoint``.

    Args:
        _: Unused positional argument.
    """
    converter = TextConverter(filename=FLAGS.converter_path)

    # Resolve a checkpoint directory to a concrete checkpoint and store the
    # result back on FLAGS.
    checkpoint_location = FLAGS.checkpoint_path
    if os.path.isdir(checkpoint_location):
        FLAGS.checkpoint_path = tf.train.latest_checkpoint(checkpoint_location)

    model = BilstmNer(
        converter.vocab_size,
        converter.num_classes,
        lstm_size=FLAGS.lstm_size,
        embedding_size=FLAGS.embedding_size,
    )
    # The message is printed before the weights are actually loaded.
    print("[*] Success to read {}".format(FLAGS.checkpoint_path))
    model.load(FLAGS.checkpoint_path)

    demo_sent = "京剧研究院就美日联合事件讨论"
    encoded_sent = converter.text_to_arr(demo_sent)
    # One zero per character is passed alongside the encoded sentence
    # (presumably placeholder labels; verify against BilstmNer.demo).
    zero_labels = [0] * len(demo_sent)
    tag = model.demo([(encoded_sent, zero_labels)])
    print(tag)
Exemplo n.º 5
0
def main():
    '''Command-line entry point: train a CharRNN or sample text from it.

    ``--state train`` builds a DataLoader over ``--txt`` and runs ``train``;
    ``--state eval`` calls ``sample``, prints the generated text and appends
    it to ``./generate.txt``. Any other state only prints an error message.
    '''
    # (flag, add_argument keyword arguments), registered in this order.
    arg_specs = (
        ('--state', dict(required=True, help='训练还是预测, train or eval')),
        ('--txt', dict(required=True, help='进行训练的txt文件')),
        ('--batch', dict(default=128, type=int, help='训练的batch size')),
        ('--epoch', dict(default=5000, type=int, help='跑多少个epoch')),
        ('--len', dict(default=100, type=int, help='输入模型的序列长度')),
        ('--max_vocab', dict(default=5000, type=int, help='最多存储的字符数目')),
        ('--embed', dict(default=512, type=int, help='词向量的维度')),
        ('--hidden', dict(default=512, type=int, help='RNN的输出维度')),
        ('--n_layer', dict(default=2, type=int, help='RNN的层数')),
        ('--dropout', dict(default=0.5, type=float, help='RNN中drop的概率')),
        ('--begin', dict(default='我', help='给出生成文本的开始')),
        ('--pred_len', dict(default=20, type=int, help='生成文本的长度')),
        ('--checkpoint', dict(help='载入模型的位置')),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in arg_specs:
        parser.add_argument(flag, **spec)
    opt = parser.parse_args()
    print(opt)

    convert = TextConverter(opt.txt, max_vocab=opt.max_vocab)
    model = CharRNN(convert.vocab_size, opt.embed, opt.hidden, opt.n_layer,
                    opt.dropout)

    # ``ctx`` is not defined in this function; it comes from the enclosing
    # module.
    model.initialize(ctx=ctx)

    if opt.state == 'train':
        dataset = TextData(opt.txt, opt.len, convert.text_to_arr)
        dataloader = g.data.DataLoader(dataset, opt.batch, shuffle=True)
        # Learning rate is multiplied by 0.1 every 1000 * len(dataloader)
        # scheduler steps.
        decay_step = int(1000 * len(dataloader))
        lr_sch = mx.lr_scheduler.FactorScheduler(decay_step, factor=0.1)
        trainer_options = {
            'learning_rate': 1e-3,
            'clip_gradient': 3,
            'lr_scheduler': lr_sch,
        }
        optimizer = g.Trainer(model.collect_params(), 'adam', trainer_options)
        loss_fn = g.loss.SoftmaxCrossEntropyLoss()
        train(opt.epoch, model, dataloader, optimizer, loss_fn)
        return

    if opt.state != 'eval':
        print('Error state, must choose from train and eval!')
        return

    pred_text = sample(model, opt.checkpoint, convert.word_to_int,
                       convert.arr_to_text, opt.begin, opt.pred_len)
    print(pred_text)
    # Append so earlier generations are kept.
    with open('./generate.txt', 'a') as out_file:
        out_file.write(pred_text)
        out_file.write('\n')
Exemplo n.º 6
0
def main(_):
    """Sample text from a trained CharRNN and print it.

    Loads the converter from ``FLAGS.converter_path``, restores the model
    from ``FLAGS.checkpoint_path``, seeds sampling with the encoded
    ``FLAGS.start_string`` and prints the decoded result.

    Note:
        When ``FLAGS.checkpoint_path`` is a directory it is overwritten in
        place with the result of ``tf.train.latest_checkpoint``.

    Args:
        _: Unused positional argument.
    """
    # The former ``FLAGS.start_string = FLAGS.start_string`` self-assignment
    # was a no-op and has been dropped; the flag value is used as given.
    converter = TextConverter(filename=FLAGS.converter_path)
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path = \
            tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    model = CharRNN(converter.vocab_size,
                    sampling=True,
                    state_size=FLAGS.state_size,
                    n_layers=FLAGS.n_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)
    model.load(FLAGS.checkpoint_path)

    # Encode the seed text, sample, then decode back to text.
    start = converter.text_to_data(FLAGS.start_string)
    data = model.sample(FLAGS.max_length, start, converter.vocab_size)
    print(converter.data_to_text(data))
Exemplo n.º 7
0
def main(_):
    """Train a BilstmNer model on the corpus in ``FLAGS.input_file``.

    Ensures ``model/<FLAGS.name>`` exists, builds a ``TextConverter`` from
    the corpus and saves it there as ``converter.pkl``, then trains the
    model on batches of ``FLAGS.batch_size``.

    Args:
        _: Unused positional argument.
    """
    model_path = os.path.join('model', FLAGS.name)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    corpus = read_corpus(FLAGS.input_file)
    converter = TextConverter(corpus, FLAGS.max_vocab)
    converter_file = os.path.join(model_path, 'converter.pkl')
    converter.save_to_file(converter_file)

    batches = batch_generator(corpus, FLAGS.batch_size)
    print(converter.vocab_size)

    # Model hyper-parameters, all taken from the command-line flags.
    hyper_params = dict(
        lstm_size=FLAGS.lstm_size,
        learning_rate=FLAGS.learning_rate,
        train_keep_prob=FLAGS.train_keep_prob,
        embedding_size=FLAGS.embedding_size,
    )
    model = BilstmNer(converter.vocab_size, converter.num_classes,
                      **hyper_params)
    model.train(batches, FLAGS.max_steps, model_path, FLAGS.save_every_n,
                FLAGS.log_every_n)
Exemplo n.º 8
0
def main():
    '''Command-line entry point: train a CharRNN or sample text from it.

    ``--state train`` builds a DataLoader over ``--txt`` and runs ``train``;
    ``--state eval`` calls ``sample``, prints the generated text and appends
    it to ``./generate.txt``. Any other state only prints an error message.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--state', required=True, help='训练还是预测, train or eval')
    parser.add_argument('--txt', required=True, help='进行训练的txt文件')
    parser.add_argument('--batch', default=128, type=int, help='训练的batch size')
    parser.add_argument('--epoch', default=5000, type=int, help='跑多少个epoch')
    parser.add_argument('--len', default=100, type=int, help='输入模型的序列长度')
    parser.add_argument('--max_vocab',
                        default=5000,
                        type=int,
                        help='最多存储的字符数目')
    parser.add_argument('--embed', default=512, type=int, help='词向量的维度')
    parser.add_argument('--hidden', default=512, type=int, help='RNN的输出维度')
    parser.add_argument('--n_layer', default=2, type=int, help='RNN的层数')
    # type=float is required: without it a value given on the command line
    # arrives as a string, while the default stays a float.
    parser.add_argument('--dropout',
                        default=0.5,
                        type=float,
                        help='RNN中drop的概率')
    parser.add_argument('--begin', default='我', help='给出生成文本的开始')
    parser.add_argument('--pred_len', default=20, type=int, help='生成文本的长度')
    parser.add_argument('--checkpoint', help='载入模型的位置')
    opt = parser.parse_args()
    print(opt)

    convert = TextConverter(opt.txt, max_vocab=opt.max_vocab)
    model = CharRNN(convert.vocab_size, opt.embed, opt.hidden, opt.n_layer,
                    opt.dropout)
    # NOTE(review): weights are always loaded from this hard-coded path, in
    # both train and eval modes and independently of --checkpoint — confirm
    # this is intended.
    model.load_state_dict(torch.load('./poetry_checkpoints/model_300.pth'))
    # ``use_gpu`` is not defined in this function; it comes from the
    # enclosing module.
    if use_gpu:
        model = model.cuda()

    if opt.state == 'train':
        model.train()
        dataset = TextData(opt.txt, opt.len, convert.text_to_arr)
        dataloader = data.DataLoader(dataset,
                                     opt.batch,
                                     shuffle=True,
                                     num_workers=4)
        optimizer = optim.Adam(model.parameters(), lr=1e-4)
        criterion = nn.CrossEntropyLoss(size_average=False)
        train(opt.epoch, model, dataloader, optimizer, criterion)

    elif opt.state == 'eval':
        pred_text = sample(model, opt.checkpoint, convert.word_to_int,
                           convert.arr_to_text, opt.begin, opt.pred_len)
        print(pred_text)
        # Append so earlier generations are kept.
        with open('./generate.txt', 'a') as f:
            f.write(pred_text)
            f.write('\n')
    else:
        print('Error state, must choose from train and eval!')