Example #1
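These snippets reference imports and module-level names (ckpt_path, embedding_path, scores_path, settings, get_logger, the network module) that are defined elsewhere in the project. A minimal preamble they appear to assume; the module names and path values below are placeholders, not taken from the source:

import os
import shutil

import numpy as np
import tensorflow as tf

import network                      # project module providing Atten_TextCNN
from my_logging import get_logger   # hypothetical helper module

ckpt_path = 'ckpt/'                 # placeholder paths
embedding_path = 'data/word_embedding.npy'
scores_path = 'scores/'
# a settings object (model hyperparameters) is also expected at module scope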
def main(_):
    if not os.path.exists(ckpt_path + 'checkpoint'):
        print('there is no saved model; please check the ckpt path')
        exit()
    print('Loading model...')
    W_embedding = np.load(embedding_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    log_path = scores_path + settings.model_name + '/'
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = get_logger(log_path + 'predict.log')
    with tf.Session(config=config) as sess:
        model = network.Atten_TextCNN(W_embedding, settings)
        model.saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
        # Save a binary (frozen) model:
        # output_graph_def = tf.graph_util.convert_variables_to_constants(
        #     sess, sess.graph_def, output_node_names=['logits'])
        # with tf.gfile.FastGFile('attention_cnn.pb', mode='wb') as f:
        #     f.write(output_graph_def.SerializeToString())
        # exit(0)
        print('Valid predicting...')
        predict_valid(sess, model, logger)

        print('Test predicting...')
        predict(sess, model, logger)
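The commented-out block above freezes the graph into a single attention_cnn.pb file. For reference, a minimal sketch of loading such a frozen graph back for inference with the TF 1.x API; the 'logits:0' tensor name is an assumption based on the exported node name:

def load_frozen_graph(pb_path='attention_cnn.pb'):
    # Parse the serialized GraphDef and import it into a fresh graph.
    with tf.gfile.FastGFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph

# Usage sketch:
# graph = load_frozen_graph()
# logits = graph.get_tensor_by_name('logits:0')
# with tf.Session(graph=graph) as sess:
#     scores = sess.run(logits, feed_dict={...})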
Example #2
def main(_):
    if not os.path.exists(ckpt_path + 'checkpoint'):
        print('there is no saved model; please check the ckpt path')
        exit()
    print('Loading model...')
    W_embedding = np.load(embedding_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    log_path = scores_path + settings.model_name + '/'
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = get_logger(log_path + 'predict.log')
    with tf.Session(config=config) as sess:
        model = network.Atten_TextCNN(W_embedding, settings)
        model.saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))

        print('Test predicting...')
        predict(sess, model, logger)
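Both examples detect a saved model by testing for the literal 'checkpoint' file. A slightly more robust variant, sketched here, uses tf.train.get_checkpoint_state, which parses that file and returns None when no usable checkpoint is recorded:

def has_checkpoint(ckpt_dir):
    # Returns True only if the checkpoint record points at a model file.
    state = tf.train.get_checkpoint_state(ckpt_dir)
    return state is not None and bool(state.model_checkpoint_path)

# if not has_checkpoint(ckpt_path):
#     exit('no saved model found in %s' % ckpt_path)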
Example #3
def main(_):
    global ckpt_path
    global last_score12
    if not os.path.exists(ckpt_path):
        os.makedirs(ckpt_path)
    if not os.path.exists(summary_path):
        os.makedirs(summary_path)
    elif not FLAGS.is_retrain:
        # Clear stale summaries unless resuming a previous run.
        shutil.rmtree(summary_path)
        os.makedirs(summary_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)

    print('1.Loading data...')
    W_embedding = np.load(embedding_path)
    print('training sample_num = %d' % n_tr_batches)
    print('valid sample_num = %d' % n_va_batches)
    logger = get_logger(log_path + FLAGS.log_file_train)

    print('2.Building model...')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = network.Atten_TextCNN(W_embedding, settings)
        with tf.variable_scope('training_ops') as vs:
            learning_rate = tf.train.exponential_decay(FLAGS.lr,
                                                       model.global_step,
                                                       FLAGS.decay_step,
                                                       FLAGS.decay_rate,
                                                       staircase=True)
            with tf.variable_scope('Optimizer1'):
                tvars1 = tf.trainable_variables()
                grads1 = tf.gradients(model.loss, tvars1)
                optimizer1 = tf.train.AdamOptimizer(
                    learning_rate=learning_rate)
                train_op1 = optimizer1.apply_gradients(
                    zip(grads1, tvars1), global_step=model.global_step)
            with tf.variable_scope('Optimizer2'):
                # Optimizer2 excludes the embedding matrix, so the
                # pretrained embeddings stay frozen while it is in use.
                tvars2 = [
                    tvar for tvar in tvars1 if 'embedding' not in tvar.name
                ]
                grads2 = tf.gradients(model.loss, tvars2)
                optimizer2 = tf.train.AdamOptimizer(
                    learning_rate=learning_rate)
                train_op2 = optimizer2.apply_gradients(
                    zip(grads2, tvars2), global_step=model.global_step)
            update_op = tf.group(*model.update_emas)
            merged = tf.summary.merge_all()  # summary
            train_writer = tf.summary.FileWriter(summary_path + 'train',
                                                 sess.graph)
            test_writer = tf.summary.FileWriter(summary_path + 'test')
            training_ops = [
                v for v in tf.global_variables()
                if v.name.startswith(vs.name + '/')
            ]
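            # Note: the two FileWriters above write TF event files under
            # summary_path; inspect them with: tensorboard --logdir <summary_path>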

        if os.path.exists(ckpt_path + "checkpoint"):
            print("Restoring Variables from Checkpoint...")
            model.saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
            f1_micro, f1_macro, score12 = valid_epoch(data_valid_path, sess,
                                                      model)
            print('f1_micro=%g, f1_macro=%g, score12=%g' %
                  (f1_micro, f1_macro, score12))
            # Warm start: re-initialize optimizer slot variables and train
            # all variables (embeddings included) from the first epoch.
            sess.run(tf.variables_initializer(training_ops))
            train_op2 = train_op1
        else:
            print('Initializing Variables...')
            sess.run(tf.global_variables_initializer())

        print('3.Begin training...')
        print('max_epoch=%d, max_max_epoch=%d' %
              (FLAGS.max_epoch, FLAGS.max_max_epoch))
        logger.info('max_epoch={}, max_max_epoch={}'.format(
            FLAGS.max_epoch, FLAGS.max_max_epoch))
        train_op = train_op2
        for epoch in range(FLAGS.max_max_epoch):
            print('\nepoch: ', epoch)
            logger.info('epoch:{}'.format(epoch))
            global_step = sess.run(model.global_step)
            print('Global step %d, lr=%g' %
                  (global_step, sess.run(learning_rate)))
            if epoch == FLAGS.max_epoch:
                # Switch to Optimizer1: unfreeze the embeddings for fine-tuning.
                train_op = train_op1

            train_fetches = [merged, model.loss, train_op, update_op]
            valid_fetches = [merged, model.loss]
            train_epoch(data_train_path, sess, model, train_fetches,
                        valid_fetches, train_writer, test_writer, logger)
        # Run one final validation pass.
        f1_micro, f1_macro, score12 = valid_epoch(data_valid_path, sess, model)
        print('END:Global_step=%d: f1_micro=%g, f1_macro=%g, score12=%g' %
              (sess.run(model.global_step), f1_micro, f1_macro, score12))
        logger.info(
            'END:Global_step={}: f1_micro={}, f1_macro={}, score12={}'.format(
                sess.run(model.global_step), f1_micro, f1_macro, score12))
        if score12 > last_score12:
            saving_path = model.saver.save(sess, model_path,
                                           sess.run(model.global_step) + 1)
            print('saved new model to %s ' % saving_path)
            logger.info('saved new model to {}'.format(saving_path))
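Example #3 trains in two phases: train_op2 (embeddings excluded) drives the first max_epoch epochs, then train_op1 fine-tunes everything, embeddings included, up to max_max_epoch. The learning rate follows a staircase exponential decay; with staircase=True, tf.train.exponential_decay is equivalent to the sketch below (the sample values are illustrative, not from the source):

def decayed_lr(base_lr, step, decay_step, decay_rate):
    # Mirrors tf.train.exponential_decay(..., staircase=True):
    # the rate drops by a factor of decay_rate every decay_step steps.
    return base_lr * decay_rate ** (step // decay_step)

# decayed_lr(1e-3, 0,    1000, 0.9)  -> 0.001
# decayed_lr(1e-3, 2500, 1000, 0.9)  -> 0.00081  (two full decay periods)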