def re_train():
    """Resume training a TextAttRNN from its latest checkpoint.

    Reads ``emergency_train.tsv`` through ``read_corpus`` with
    ``test_size=0.2``. It then restores the most recent checkpoint found
    under ``MODEL_PATH/<FLAGS.DEMO>/checkpoints`` into a new session and
    redirects the model's save location to ``MODEL_PATH/<FLAGS.DEMO>``.
    Finally it continues training with ``re_train=True``.

    Raises:
        ValueError: if the checkpoint directory contains no checkpoint.
    """
    train_data, dev_data = read_corpus(filename='emergency_train.tsv',
                                       test_size=0.2)

    model_path = os.path.join(MODEL_PATH, FLAGS.DEMO, 'checkpoints')
    ckpt_file = tf.train.latest_checkpoint(model_path)
    # latest_checkpoint returns None when no checkpoint exists. Fail fast with
    # a clear message instead of constructing the model with model_path=None
    # and hitting an opaque error later in saver.restore.
    if ckpt_file is None:
        raise ValueError(
            "no checkpoint found in {}; cannot re-train".format(model_path))

    logging.info("load pre-train model from %s", ckpt_file)
    # NOTE: 'eopches' (sic) is kept because it must match TextAttRNN's
    # keyword argument name.
    text_att_rnn = TextAttRNN(
        config=cfg(),
        model_path=ckpt_file,
        vocab=word2int,
        tag2label=tag2label,
        batch_size=FLAGS.batch_size,
        embed_size=FLAGS.embed_size,
        sequence_length=FLAGS.sequence_length,
        eopches=FLAGS.epoches,
    )

    # A Saver with no var_list covers the variables present in the graph when
    # it is created, so it must be built after the model. This presumes the
    # TextAttRNN constructor adds its variables to the default graph.
    saver = tf.compat.v1.train.Saver()

    with tf.compat.v1.Session(config=cfg()) as sess:
        saver.restore(sess, ckpt_file)
        # Save new checkpoints under the model root, not the checkpoint file.
        text_att_rnn.set_model_path(
            model_path=os.path.join(MODEL_PATH, FLAGS.DEMO))
        text_att_rnn.train(sess, train_data, dev_data, shuffle=True,
                           re_train=True)


def train_emergency():
    """Train a fresh TextAttRNN on ``emergency_train.tsv``.

    Reads the corpus with ``read_corpus`` using its default split. It then
    builds a model that saves under ``MODEL_PATH/<FLAGS.DEMO>`` and trains
    it from scratch with shuffling enabled.

    This function has a distinct name because a later ``train`` definition
    in this module would otherwise shadow it, leaving it unreachable. The
    module-level ``train`` name still refers to that later definition.
    """
    train_data, dev_data = read_corpus(filename='emergency_train.tsv')
    # NOTE: 'eopches' (sic) is kept because it must match TextAttRNN's
    # keyword argument name.
    text_att_rnn = TextAttRNN(config=cfg(),
                              model_path=os.path.join(MODEL_PATH, FLAGS.DEMO),
                              vocab=word2int,
                              tag2label=tag2label,
                              batch_size=FLAGS.batch_size,
                              embed_size=FLAGS.embed_size,
                              sequence_length=FLAGS.sequence_length,
                              eopches=FLAGS.epoches)

    with tf.compat.v1.Session(config=cfg()) as sess:
        text_att_rnn.train(sess, train_data, dev_data, shuffle=True)


def train():
    """Train a fresh TextAttRNN on the corpus selected by ``read_corpus``.

    Calls ``read_corpus`` with ``random_state=1234``, a tab separator,
    ``iter=-1`` and ``iter_size=20000``. The first presumably fixes the
    train/dev split and the last two presumably select which corpus chunk
    is read (-1 possibly meaning "all"); confirm both against
    ``read_corpus``. It then builds a model that saves under
    ``MODEL_PATH/<FLAGS.DEMO>`` and trains it with shuffling enabled.

    NOTE(review): unlike ``re_train``, this does not pass ``sequence_length``
    to TextAttRNN. Confirm that relying on the constructor's own handling of
    it is intended.
    """
    # Named to avoid shadowing the builtin ``iter``. The keyword passed to
    # read_corpus is still ``iter=`` because that is its parameter name.
    corpus_iter = -1
    iter_size = 20000
    train_data, dev_data = read_corpus(random_state=1234, separator='\t',
                                       iter=corpus_iter, iter_size=iter_size)
    # NOTE: 'eopches' (sic) is kept because it must match TextAttRNN's
    # keyword argument name.
    text_att_rnn = TextAttRNN(config=cfg(),
                              model_path=os.path.join(MODEL_PATH, FLAGS.DEMO),
                              vocab=word2int,
                              tag2label=tag2label,
                              batch_size=FLAGS.batch_size,
                              embed_size=FLAGS.embed_size,
                              eopches=FLAGS.epoches)

    with tf.compat.v1.Session(config=cfg()) as sess:
        text_att_rnn.train(sess, train_data, dev_data, shuffle=True)