def main(_):
    """Restore the latest Atten_TextCNN checkpoint and predict valid + test sets.

    Loads the embedding matrix from ``embedding_path`` and builds
    ``network.Atten_TextCNN``. Restores the newest checkpoint found under
    ``ckpt_path``, then runs ``predict_valid`` followed by ``predict``.
    Progress is logged to ``<scores_path><settings.model_name>/predict.log``.

    Args:
        _: unused; presumably the argv list handed over by ``tf.app.run``.

    Raises:
        SystemExit: with status 1 when no ``checkpoint`` file exists under
            ``ckpt_path``.
    """
    # os.path.join works whether or not ckpt_path ends with a separator.
    if not os.path.exists(os.path.join(ckpt_path, 'checkpoint')):
        print('there is not saved model, please check the ckpt path')
        # Exit with a non-zero status so callers see the failure. The builtin
        # exit() comes from the site module, exits with status 0 and also
        # closes sys.stdin; it is intended for the interactive shell only.
        raise SystemExit(1)
    print('Loading model...')
    W_embedding = np.load(embedding_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    log_path = scores_path + settings.model_name + '/'
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = get_logger(log_path + 'predict.log')
    with tf.Session(config=config) as sess:
        model = network.Atten_TextCNN(W_embedding, settings)
        model.saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
        # To export a frozen binary model instead, convert the variables to
        # constants with tf.graph_util.convert_variables_to_constants(
        # sess, sess.graph_def, output_node_names=['logits']) and write the
        # serialized GraphDef to a .pb file (e.g. 'attention_cnn.pb').
        print('Valid predicting...')
        predict_valid(sess, model, logger)
        print('Test predicting...')
        predict(sess, model, logger)
def main(_):
    """Restore the latest Atten_TextCNN checkpoint and predict the test set.

    Loads the embedding matrix from ``embedding_path`` and builds
    ``network.Atten_TextCNN``. Restores the newest checkpoint found under
    ``ckpt_path``, then runs ``predict``. Progress is logged to
    ``<scores_path><settings.model_name>/predict.log``.

    Args:
        _: unused; presumably the argv list handed over by ``tf.app.run``.

    Raises:
        SystemExit: with status 1 when no ``checkpoint`` file exists under
            ``ckpt_path``.
    """
    # os.path.join works whether or not ckpt_path ends with a separator.
    if not os.path.exists(os.path.join(ckpt_path, 'checkpoint')):
        print('there is not saved model, please check the ckpt path')
        # Exit with a non-zero status so callers see the failure. The builtin
        # exit() comes from the site module, exits with status 0 and also
        # closes sys.stdin; it is intended for the interactive shell only.
        raise SystemExit(1)
    print('Loading model...')
    W_embedding = np.load(embedding_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    log_path = scores_path + settings.model_name + '/'
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = get_logger(log_path + 'predict.log')
    with tf.Session(config=config) as sess:
        model = network.Atten_TextCNN(W_embedding, settings)
        model.saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
        print('Test predicting...')
        predict(sess, model, logger)
def main(_):
    """Train Atten_TextCNN, resuming from the latest checkpoint if one exists.

    Steps:
      1. Create the checkpoint, summary and log directories. Old summaries
         are wiped unless ``FLAGS.is_retrain`` is set.
      2. Build two Adam training ops that share an exponentially decayed
         learning rate:
           * ``train_op1`` updates every trainable variable;
           * ``train_op2`` skips variables whose name contains 'embedding'.
      3. Restore the latest checkpoint, or initialize all variables.
      4. Run ``FLAGS.max_max_epoch`` epochs of ``train_epoch``. Training uses
         ``train_op2`` until epoch ``FLAGS.max_epoch`` and ``train_op1`` from
         then on.
      5. Validate once more and save the model when its score12 beats
         ``last_score12``.

    Args:
        _: unused; presumably the argv list handed over by ``tf.app.run``.
    """
    # last_score12 is the module-level "best score so far"; it is assigned
    # below when resuming, so it must be declared global.
    global last_score12
    if not os.path.exists(ckpt_path):
        os.makedirs(ckpt_path)
    if not os.path.exists(summary_path):
        os.makedirs(summary_path)
    elif not FLAGS.is_retrain:
        # Starting a fresh run: discard summaries left by a previous run.
        shutil.rmtree(summary_path)
        os.makedirs(summary_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)

    print('1.Loading data...')
    W_embedding = np.load(embedding_path)
    print('training sample_num = %d' % n_tr_batches)
    print('valid sample_num = %d' % n_va_batches)
    logger = get_logger(log_path + FLAGS.log_file_train)

    print('2.Building model...')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = network.Atten_TextCNN(W_embedding, settings)
        with tf.variable_scope('training_ops') as vs:
            learning_rate = tf.train.exponential_decay(
                FLAGS.lr, model.global_step, FLAGS.decay_step,
                FLAGS.decay_rate, staircase=True)
            # Optimizer1: updates all trainable variables, embedding included.
            with tf.variable_scope('Optimizer1'):
                tvars1 = tf.trainable_variables()
                grads1 = tf.gradients(model.loss, tvars1)
                optimizer1 = tf.train.AdamOptimizer(
                    learning_rate=learning_rate)
                train_op1 = optimizer1.apply_gradients(
                    zip(grads1, tvars1), global_step=model.global_step)
            # Optimizer2: same, but keeps the embedding variables frozen.
            with tf.variable_scope('Optimizer2'):
                tvars2 = [
                    tvar for tvar in tvars1 if 'embedding' not in tvar.name
                ]
                grads2 = tf.gradients(model.loss, tvars2)
                optimizer2 = tf.train.AdamOptimizer(
                    learning_rate=learning_rate)
                train_op2 = optimizer2.apply_gradients(
                    zip(grads2, tvars2), global_step=model.global_step)
            update_op = tf.group(*model.update_emas)
            merged = tf.summary.merge_all()  # summary
            train_writer = tf.summary.FileWriter(summary_path + 'train',
                                                 sess.graph)
            test_writer = tf.summary.FileWriter(summary_path + 'test')
            # Global variables created under this scope; they are
            # re-initialized after a checkpoint restore below.
            training_ops = [
                v for v in tf.global_variables()
                if v.name.startswith(vs.name + '/')
            ]

        # os.path.join works whether or not ckpt_path ends with a separator.
        if os.path.exists(os.path.join(ckpt_path, 'checkpoint')):
            print("Restoring Variables from Checkpoint...")
            model.saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
            f1_micro, f1_macro, score12 = valid_epoch(data_valid_path, sess,
                                                      model)
            print('f1_micro=%g, f1_macro=%g, score12=%g' %
                  (f1_micro, f1_macro, score12))
            # Use the restored model's score as the baseline. Without this,
            # the final comparison below uses the stale module-level value,
            # and a retrained model worse than the restored one gets saved.
            last_score12 = score12
            sess.run(tf.variables_initializer(training_ops))
            # When resuming, train the embedding from the first epoch.
            train_op2 = train_op1
        else:
            print('Initializing Variables...')
            sess.run(tf.global_variables_initializer())

        print('3.Begin training...')
        print('max_epoch=%d, max_max_epoch=%d' %
              (FLAGS.max_epoch, FLAGS.max_max_epoch))
        logger.info('max_epoch={}, max_max_epoch={}'.format(
            FLAGS.max_epoch, FLAGS.max_max_epoch))
        train_op = train_op2
        for epoch in range(FLAGS.max_max_epoch):
            print('\nepoch: ', epoch)
            logger.info('epoch:{}'.format(epoch))
            global_step = sess.run(model.global_step)
            print('Global step %d, lr=%g' %
                  (global_step, sess.run(learning_rate)))
            if epoch == FLAGS.max_epoch:
                # From here on, also fine-tune the embedding.
                train_op = train_op1
            train_fetches = [merged, model.loss, train_op, update_op]
            valid_fetches = [merged, model.loss]
            train_epoch(data_train_path, sess, model, train_fetches,
                        valid_fetches, train_writer, test_writer, logger)

        # Validate once more at the end of training.
        f1_micro, f1_macro, score12 = valid_epoch(data_valid_path, sess, model)
        final_step = sess.run(model.global_step)
        print('END:Global_step=%d: f1_micro=%g, f1_macro=%g, score12=%g' %
              (final_step, f1_micro, f1_macro, score12))
        logger.info(
            'END:Global_step={}: f1_micro={}, f1_macro={}, score12={}'.format(
                final_step, f1_micro, f1_macro, score12))
        if score12 > last_score12:
            saving_path = model.saver.save(sess, model_path, final_step + 1)
            print('saved new model to %s ' % saving_path)
            logger.info('saved new model to {}'.format(saving_path))