예제 #1
0
def main(_):
    """Train a CharRNN model on the text corpus named by ``FLAGS.input_file``.

    Steps:
      1. Ensure the output directory ``model/<FLAGS.name>`` exists.
      2. Read the whole corpus into memory and build a ``TextConverter`` over
         it, with vocabulary size bounded by ``FLAGS.max_vocab`` (how the
         vocabulary is chosen is up to TextConverter — confirm there).
      3. Save the converter to ``converter.pkl`` in the model directory.
      4. Encode the corpus, wrap it in a batch generator, build the model
         from the FLAGS hyper-parameters, and train it.
      5. Print the total elapsed wall-clock time.

    Args:
        _: Ignored positional argument (presumably the argv list supplied by
            an app runner such as ``tf.app.run`` — confirm at the call site).
    """
    start_time = time.time()
    model_path = os.path.join('model', FLAGS.name)
    # exist_ok=True replaces the separate exists() check and avoids the
    # check-then-create race if the directory appears in between.
    os.makedirs(model_path, exist_ok=True)
    # Decode explicitly instead of relying on the platform's locale default
    # (which differs e.g. on Windows); assumes the corpus is UTF-8 encoded.
    with open(FLAGS.input_file, 'r', encoding='utf-8') as f:
        text = f.read()
    converter = TextConverter(text, FLAGS.max_vocab)
    # Persist the vocabulary mapping next to the checkpoints so the same
    # converter can be reloaded later.
    converter.save_to_file(os.path.join(model_path, 'converter.pkl'))

    arr = converter.text_to_arr(text)
    g = batch_generator(arr, FLAGS.num_seqs, FLAGS.num_steps)
    print(converter.vocab_size)
    model = CharRNN(converter.vocab_size,
                    num_seqs=FLAGS.num_seqs,
                    num_steps=FLAGS.num_steps,
                    lstm_size=FLAGS.lstm_size,
                    num_layers=FLAGS.num_layers,
                    learning_rate=FLAGS.learning_rate,
                    train_keep_prob=FLAGS.train_keep_prob,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)
    model.train(
        g,
        FLAGS.max_steps,
        model_path,
        FLAGS.save_every_n,
        FLAGS.log_every_n,
    )
    print("Timing cost is --- %s ---second(s)" % (time.time() - start_time))
예제 #2
0
def main(_):
    """Load a trained BilstmNer checkpoint and tag one fixed demo sentence.

    ``FLAGS.checkpoint_path`` may name either a checkpoint directly or a
    directory. In the directory case it is resolved (in place on FLAGS) to
    the latest checkpoint in that directory.

    Args:
        _: Ignored positional argument (presumably the argv list supplied by
            an app runner such as ``tf.app.run`` — confirm at the call site).

    Raises:
        ValueError: If ``FLAGS.checkpoint_path`` is a directory that contains
            no checkpoint.
    """
    converter = TextConverter(filename=FLAGS.converter_path)
    if os.path.isdir(FLAGS.checkpoint_path):
        checkpoint_dir = FLAGS.checkpoint_path
        FLAGS.checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
        # latest_checkpoint returns None when nothing is found; fail with a
        # clear message instead of handing None to model.load.
        if FLAGS.checkpoint_path is None:
            raise ValueError(
                "No checkpoint found in directory {}".format(checkpoint_dir))
    model = BilstmNer(converter.vocab_size,
                      converter.num_classes,
                      lstm_size=FLAGS.lstm_size,
                      embedding_size=FLAGS.embedding_size)
    model.load(FLAGS.checkpoint_path)
    # Report success only after the load has actually completed.
    print("[*] Success to read {}".format(FLAGS.checkpoint_path))

    # Fixed Chinese sample sentence used for the demo.
    demo_sent = "京剧研究院就美日联合事件讨论"
    # model.demo receives a list of (encoded_sentence, tags) pairs. The tags
    # are a placeholder of zeros, one per character of the sentence
    # (presumably ignored at inference time — confirm in BilstmNer.demo).
    placeholder_tags = [0] * len(demo_sent)
    tag = model.demo([(converter.text_to_arr(demo_sent), placeholder_tags)])
    print(tag)