Exemplo n.º 1
0
def main(_):
    """Restore the latest HAN checkpoint and run both prediction passes.

    The single positional argument is ignored (the signature matches what a
    ``tf.app.run``-style launcher passes in).

    Raises:
        SystemExit: with status 1 when no ``checkpoint`` index file is found
            inside ``ckpt_path``.
    """
    # os.path.join works whether or not ckpt_path carries a trailing
    # separator; plain string concatenation silently looked for the wrong
    # file when it did not.
    if not os.path.exists(os.path.join(ckpt_path, 'checkpoint')):
        print('there is not saved model, please check the ckpt path')
        # A missing model is an error, so report a non-zero status instead of
        # relying on the interactive-only ``exit()`` helper (which exits 0).
        raise SystemExit(1)

    print('Loading model...')
    W_embedding = np.load(embedding_path)

    # Let TensorFlow grow its GPU memory usage on demand rather than
    # reserving everything up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        model = network.HAN(W_embedding, settings)
        model.saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
        print('Local predicting...')
        local_predict(sess, model)
        print('Test predicting...')
        predict(sess, model)
Exemplo n.º 2
0
def main(_):
    """Restore the latest HAN checkpoint and predict on validation and test data.

    The single positional argument is ignored (the signature matches what a
    ``tf.app.run``-style launcher passes in). A logger writing to
    ``<scores_path>/<settings.model_name>/predict.log`` is created and handed
    to both prediction helpers.

    Raises:
        SystemExit: with status 1 when no ``checkpoint`` index file is found
            inside ``ckpt_path``.
    """
    # os.path.join works whether or not ckpt_path carries a trailing
    # separator; plain string concatenation silently looked for the wrong
    # file when it did not.
    if not os.path.exists(os.path.join(ckpt_path, 'checkpoint')):
        print('there is not saved model, please check the ckpt path')
        # A missing model is an error, so report a non-zero status instead of
        # relying on the interactive-only ``exit()`` helper (which exits 0).
        raise SystemExit(1)

    print('Loading model...')
    W_embedding = np.load(embedding_path)

    # Let TensorFlow grow its GPU memory usage on demand rather than
    # reserving everything up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    logger = get_logger(
        os.path.join(scores_path, settings.model_name, 'predict.log'))

    with tf.Session(config=config) as sess:
        model = network.HAN(W_embedding, settings)
        # ckpt_path is the directory that holds the ``checkpoint`` index file,
        # e.g. /ckpt/wd_BiGRU/.
        model.saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))

        print('Valid predicting...')
        predict_valid(sess, model, logger)

        print('Test predicting...')
        predict(sess, model, logger)
Exemplo n.º 3
0
def main(_):
    """Train the HAN model, resuming from the latest checkpoint if one exists.

    Steps:
      1. Make sure the checkpoint and summary directories exist; unless
         ``FLAGS.is_retrain`` is set, an existing summary directory is wiped.
      2. Build the model and two Adam training ops: ``train_op1`` updates every
         trainable variable, ``train_op2`` skips variables whose name contains
         ``'embedding'``.
      3. Restore the latest checkpoint (and evaluate it to obtain the F1 score
         to beat) or initialise all variables from scratch.
      4. Train for ``FLAGS.max_max_epoch`` epochs, switching to ``train_op1``
         once ``FLAGS.max_epoch`` epochs have run.
      5. Validate one final time and save the model if its F1 improved on
         ``last_f1``.

    The single positional argument is ignored (the signature matches what a
    ``tf.app.run``-style launcher passes in).
    """
    # Rebound below when a checkpoint is restored; otherwise the module-level
    # value is used as the baseline for the final comparison.
    global last_f1

    if not os.path.exists(ckpt_path):
        os.makedirs(ckpt_path)
    if not os.path.exists(summary_path):
        os.makedirs(summary_path)
    elif not FLAGS.is_retrain:
        # Start from a clean slate: delete the summaries of earlier runs.
        shutil.rmtree(summary_path)
        os.makedirs(summary_path)

    print('1.Loading data...')
    W_embedding = np.load(embedding_path)
    print('training sample_num = %d' % n_tr_batches)
    print('valid sample_num = %d' % n_va_batches)

    # Initialise or restore the model.
    print('2.Building model...')
    # Let TensorFlow grow its GPU memory usage on demand rather than
    # reserving everything up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = network.HAN(W_embedding, settings)
        with tf.variable_scope('training_ops') as vs:
            learning_rate = tf.train.exponential_decay(FLAGS.lr,
                                                       model.global_step,
                                                       FLAGS.decay_step,
                                                       FLAGS.decay_rate,
                                                       staircase=True)
            # Two optimizers: op1 updates the embedding, op2 does not.
            with tf.variable_scope('Optimizer1'):
                tvars1 = tf.trainable_variables()
                grads1 = tf.gradients(model.loss, tvars1)
                optimizer1 = tf.train.AdamOptimizer(
                    learning_rate=learning_rate)
                train_op1 = optimizer1.apply_gradients(
                    zip(grads1, tvars1), global_step=model.global_step)
            with tf.variable_scope('Optimizer2'):
                tvars2 = [
                    tvar for tvar in tvars1 if 'embedding' not in tvar.name
                ]
                grads2 = tf.gradients(model.loss, tvars2)
                optimizer2 = tf.train.AdamOptimizer(
                    learning_rate=learning_rate)
                train_op2 = optimizer2.apply_gradients(
                    zip(grads2, tvars2), global_step=model.global_step)
            update_op = tf.group(*model.update_emas)
            merged = tf.summary.merge_all()  # summary
            # os.path.join keeps the writers inside summary_path even when it
            # lacks a trailing separator.
            train_writer = tf.summary.FileWriter(
                os.path.join(summary_path, 'train'), sess.graph)
            test_writer = tf.summary.FileWriter(
                os.path.join(summary_path, 'test'))
            # Variables created inside this scope (e.g. optimizer state);
            # they are initialised separately after a checkpoint restore.
            training_ops = [
                v for v in tf.global_variables()
                if v.name.startswith(vs.name + '/')
            ]

        try:
            # If a model was saved before, load it and continue from there.
            if os.path.exists(os.path.join(ckpt_path, "checkpoint")):
                print("Restoring Variables from Checkpoint...")
                model.saver.restore(sess,
                                    tf.train.latest_checkpoint(ckpt_path))
                last_valid_cost, precision, recall, last_f1 = valid_epoch(
                    data_valid_path, sess, model)
                print(' valid cost=%g; p=%g, r=%g, f1=%g' %
                      (last_valid_cost, precision, recall, last_f1))
                # NOTE(review): presumably the training-scope variables are
                # not part of the checkpoint, hence the explicit init here.
                sess.run(tf.variables_initializer(training_ops))
                # A restored model trains the embedding from the first epoch.
                train_op2 = train_op1
            else:
                print('Initializing Variables...')
                sess.run(tf.global_variables_initializer())

            print('3.Begin training...')
            print('max_epoch=%d, max_max_epoch=%d' %
                  (FLAGS.max_epoch, FLAGS.max_max_epoch))
            train_op = train_op2
            # ``range`` instead of ``xrange`` so the loop also runs on
            # Python 3.
            for epoch in range(FLAGS.max_max_epoch):
                global_step = sess.run(model.global_step)
                print('Global step %d, lr=%g' %
                      (global_step, sess.run(learning_rate)))
                if epoch == FLAGS.max_epoch:  # update the embedding
                    train_op = train_op1
                train_fetches = [merged, model.loss, train_op, update_op]
                valid_fetches = [merged, model.loss]
                train_epoch(data_train_path, sess, model, train_fetches,
                            valid_fetches, train_writer, test_writer)

            # Run one last validation pass after training finishes.
            valid_cost, precision, recall, f1 = valid_epoch(
                data_valid_path, sess, model)
            print('END.Global_step=%d: valid cost=%g; p=%g, r=%g, f1=%g' %
                  (sess.run(model.global_step), valid_cost, precision, recall,
                   f1))
            if f1 > last_f1:  # save the better model
                saving_path = model.saver.save(sess, model_path,
                                               sess.run(model.global_step) + 1)
                print('saved new model to %s ' % saving_path)
        finally:
            # Flush and release the event files even if training aborts.
            train_writer.close()
            test_writer.close()