def test_mann():
    """Test MANN model."""

    # Load data
    logger.info("✔︎ Loading data...")
    logger.info("Recommended padding Sequence length is: {0}".format(FLAGS.pad_seq_len))

    logger.info("✔︎ Test data processing...")
    test_data = dh.load_data_and_labels(FLAGS.test_data_file, FLAGS.num_classes, FLAGS.embedding_dim)

    logger.info("✔︎ Test data padding...")
    x_test, y_test = dh.pad_data(test_data, FLAGS.pad_seq_len)
    y_test_labels = test_data.labels

    # Load MANN model
    logger.info("✔︎ Loading model...")
    checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Split the output nodes name by '|' if you have several output nodes
            output_node_names = "output/logits|output/scores"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def, "graph", "graph-mann-{0}.pb".format(MODEL),
                                 as_text=False)

            # Generate batches for one epoch
            batches = dh.batch_iter(list(zip(x_test, y_test, y_test_labels)),
                                    FLAGS.batch_size, 1, shuffle=False)

            # Collect the predictions here
            all_labels = []
            all_predicted_labels = []
            all_predicted_values = []

            # Calculate the metric
            test_counter, test_loss, test_rec, test_pre, test_F = 0, 0.0, 0.0, 0.0, 0.0

            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_labels = zip(*batch_test)
                feed_dict = {
                    input_x: x_batch_test,
                    input_y: y_batch_test,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                batch_scores, cur_loss = sess.run([scores, loss], feed_dict)

                # Predict by threshold
                predicted_labels_threshold, predicted_values_threshold = \
                    dh.get_label_using_scores_by_threshold(scores=batch_scores,
                                                           threshold=FLAGS.threshold)

                cur_rec, cur_pre, cur_F = 0.0, 0.0, 0.0

                for index, predicted_label_threshold in enumerate(predicted_labels_threshold):
                    rec_inc, pre_inc = dh.cal_metric(predicted_label_threshold, y_batch_test[index])
                    cur_rec, cur_pre = cur_rec + rec_inc, cur_pre + pre_inc

                cur_rec = cur_rec / len(y_batch_test)
                cur_pre = cur_pre / len(y_batch_test)

                test_rec, test_pre = test_rec + cur_rec, test_pre + cur_pre

                # Add results to collection
                for item in y_batch_test_labels:
                    all_labels.append(item)
                for item in predicted_labels_threshold:
                    all_predicted_labels.append(item)
                for item in predicted_values_threshold:
                    all_predicted_values.append(item)

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

            test_loss = float(test_loss / test_counter)
            test_rec = float(test_rec / test_counter)
            test_pre = float(test_pre / test_counter)
            test_F = dh.cal_F(test_rec, test_pre)

            logger.info("☛ All Test Dataset: Loss {0:g}".format(test_loss))

            # Predict by threshold
            logger.info("☛ Predict by threshold: Recall {0:g}, Precision {1:g}, F {2:g}"
                        .format(test_rec, test_pre, test_F))

            # Save the prediction result
            if not os.path.exists(SAVE_DIR):
                os.makedirs(SAVE_DIR)
            dh.create_prediction_file(output_file=SAVE_DIR + "/predictions.json",
                                      data_id=test_data.testid,
                                      all_labels=all_labels,
                                      all_predict_labels=all_predicted_labels,
                                      all_predict_values=all_predicted_values)

    logger.info("✔︎ Done.")
def validation_step(x_validation, y_validation, writer=None):
    """Evaluates model on a validation set.

    Relies on `fasttext`, `sess` and `validation_summary_op` from the
    enclosing training scope.
    """
    batches_validation = dh.batch_iter(
        list(zip(x_validation, y_validation)), FLAGS.batch_size, 1)

    # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
    eval_counter, eval_loss, eval_rec_ts, eval_pre_ts, eval_F_ts = 0, 0.0, 0.0, 0.0, 0.0
    eval_rec_tk = [0.0] * FLAGS.top_num
    eval_pre_tk = [0.0] * FLAGS.top_num
    eval_F_tk = [0.0] * FLAGS.top_num

    for batch_validation in batches_validation:
        x_batch_validation, y_batch_validation = zip(*batch_validation)
        feed_dict = {
            fasttext.input_x: x_batch_validation,
            fasttext.input_y: y_batch_validation,
            fasttext.dropout_keep_prob: 1.0,
            fasttext.is_training: False
        }
        step, summaries, scores, cur_loss = sess.run(
            [fasttext.global_step, validation_summary_op, fasttext.scores, fasttext.loss],
            feed_dict)

        # Predict by threshold
        predicted_labels_threshold, predicted_values_threshold = \
            dh.get_label_using_scores_by_threshold(scores=scores, threshold=FLAGS.threshold)

        cur_rec_ts, cur_pre_ts, cur_F_ts = 0.0, 0.0, 0.0

        for index, predicted_label_threshold in enumerate(predicted_labels_threshold):
            rec_inc_ts, pre_inc_ts = dh.cal_metric(predicted_label_threshold,
                                                   y_batch_validation[index])
            cur_rec_ts, cur_pre_ts = cur_rec_ts + rec_inc_ts, cur_pre_ts + pre_inc_ts

        cur_rec_ts = cur_rec_ts / len(y_batch_validation)
        cur_pre_ts = cur_pre_ts / len(y_batch_validation)
        cur_F_ts = dh.cal_F(cur_rec_ts, cur_pre_ts)

        eval_rec_ts, eval_pre_ts = eval_rec_ts + cur_rec_ts, eval_pre_ts + cur_pre_ts

        # Predict by topK
        topK_predicted_labels = []
        for top_num in range(FLAGS.top_num):
            predicted_labels_topk, predicted_values_topk = \
                dh.get_label_using_scores_by_topk(scores=scores, top_num=top_num + 1)
            topK_predicted_labels.append(predicted_labels_topk)

        cur_rec_tk = [0.0] * FLAGS.top_num
        cur_pre_tk = [0.0] * FLAGS.top_num
        cur_F_tk = [0.0] * FLAGS.top_num

        for top_num, predicted_labels_topK in enumerate(topK_predicted_labels):
            for index, predicted_label_topK in enumerate(predicted_labels_topK):
                rec_inc_tk, pre_inc_tk = dh.cal_metric(predicted_label_topK,
                                                       y_batch_validation[index])
                cur_rec_tk[top_num], cur_pre_tk[top_num] = \
                    cur_rec_tk[top_num] + rec_inc_tk, cur_pre_tk[top_num] + pre_inc_tk

            cur_rec_tk[top_num] = cur_rec_tk[top_num] / len(y_batch_validation)
            cur_pre_tk[top_num] = cur_pre_tk[top_num] / len(y_batch_validation)
            cur_F_tk[top_num] = dh.cal_F(cur_rec_tk[top_num], cur_pre_tk[top_num])

            eval_rec_tk[top_num], eval_pre_tk[top_num] = \
                eval_rec_tk[top_num] + cur_rec_tk[top_num], eval_pre_tk[top_num] + cur_pre_tk[top_num]

        eval_loss = eval_loss + cur_loss
        eval_counter = eval_counter + 1

        if writer:
            writer.add_summary(summaries, step)

    eval_loss = float(eval_loss / eval_counter)
    eval_rec_ts = float(eval_rec_ts / eval_counter)
    eval_pre_ts = float(eval_pre_ts / eval_counter)
    eval_F_ts = dh.cal_F(eval_rec_ts, eval_pre_ts)

    for top_num in range(FLAGS.top_num):
        eval_rec_tk[top_num] = float(eval_rec_tk[top_num] / eval_counter)
        eval_pre_tk[top_num] = float(eval_pre_tk[top_num] / eval_counter)
        eval_F_tk[top_num] = dh.cal_F(eval_rec_tk[top_num], eval_pre_tk[top_num])

    return eval_loss, eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk
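
# validation_step() above reports both threshold-based and top-K metrics but
# delegates the selection to the data helpers. The sketch below shows only the
# assumed semantics of the top-K selection (not the project's dh implementation):
# for each row of scores, keep the indices and values of the top_num highest scores.
import numpy as np  # normally imported at the top of the file


def _topk_labels_sketch(scores, top_num):
    """Illustrative top-K label selection over a [batch, num_classes] score array."""
    scores = np.asarray(scores)
    topk_indices = np.argsort(-scores, axis=1)[:, :top_num]          # highest scores first
    topk_values = np.take_along_axis(scores, topk_indices, axis=1)   # matching score values
    return topk_indices.tolist(), topk_values.tolist()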
def test_lmlp():
    """Test LMLP model."""

    # Load data
    logger.info("✔︎ Loading data...")
    logger.info("Recommended padding Sequence length is: {0}".format(FLAGS.pad_seq_len))

    logger.info("✔︎ Test data processing...")
    test_data = dh.load_data_and_labels(FLAGS.test_data_file, FLAGS.num_classes_list,
                                        FLAGS.embedding_dim, data_aug_flag=False)

    logger.info("✔︎ Test data padding...")
    x_test, y_test = dh.pad_data(test_data, FLAGS.pad_seq_len)
    y_test_labels = test_data.labels

    # Load LMLP model
    BEST_OR_LATEST = input("☛ Load Best or Latest Model?(B/L): ")

    while not (BEST_OR_LATEST.isalpha() and BEST_OR_LATEST.upper() in ['B', 'L']):
        BEST_OR_LATEST = input("✘ The format of your input is illegal, please re-input: ")
    if BEST_OR_LATEST.upper() == 'B':
        logger.info("✔︎ Loading best model...")
        checkpoint_file = cm.get_best_checkpoint(FLAGS.best_checkpoint_dir, select_maximum_value=True)
    else:
        logger.info("✔︎ Loading latest model...")
        checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y_first = graph.get_operation_by_name("input_y_first").outputs[0]
            input_y_second = graph.get_operation_by_name("input_y_second").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Split the output nodes name by '|' if you have several output nodes
            output_node_names = "output/logits|output/scores"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def, "graph", "graph-lmlp-{0}.pb".format(MODEL),
                                 as_text=False)

            # Generate batches for one epoch
            batches = dh.batch_iter(list(zip(x_test, y_test, y_test_labels)),
                                    FLAGS.batch_size, 1, shuffle=False)

            # Collect the predictions here
            all_labels = []
            all_predicted_labels = []
            all_predicted_values = []

            # Calculate the metric
            test_counter, test_loss, test_rec, test_pre, test_F = 0, 0.0, 0.0, 0.0, 0.0

            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_labels = zip(*batch_test)
                y_batch_test_first = [i[0] for i in y_batch_test]
                y_batch_test_second = [j[1] for j in y_batch_test]
                y_batch_test_third = [k[2] for k in y_batch_test]

                feed_dict = {
                    input_x: x_batch_test,
                    input_y_first: y_batch_test_first,
                    input_y_second: y_batch_test_second,
                    input_y: y_batch_test_third,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                batch_scores, cur_loss = sess.run([scores, loss], feed_dict)

                # Predict by threshold
                predicted_labels_threshold, predicted_values_threshold = \
                    dh.get_label_using_scores_by_threshold(scores=batch_scores,
                                                           threshold=FLAGS.threshold)

                cur_rec, cur_pre, cur_F = 0.0, 0.0, 0.0

                for index, predicted_label_threshold in enumerate(predicted_labels_threshold):
                    rec_inc, pre_inc = dh.cal_metric(predicted_label_threshold,
                                                     y_batch_test_third[index])
                    cur_rec, cur_pre = cur_rec + rec_inc, cur_pre + pre_inc

                cur_rec = cur_rec / len(y_batch_test_third)
                cur_pre = cur_pre / len(y_batch_test_third)

                test_rec, test_pre = test_rec + cur_rec, test_pre + cur_pre

                # Add results to collection
                for item in y_batch_test_labels:
                    all_labels.append(item)
                for item in predicted_labels_threshold:
                    all_predicted_labels.append(item)
                for item in predicted_values_threshold:
                    all_predicted_values.append(item)

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

            test_loss = float(test_loss / test_counter)
            test_rec = float(test_rec / test_counter)
            test_pre = float(test_pre / test_counter)
            test_F = dh.cal_F(test_rec, test_pre)

            logger.info("☛ All Test Dataset: Loss {0:g}".format(test_loss))

            # Predict by threshold
            logger.info("☛ Predict by threshold: Recall {0:g}, Precision {1:g}, F {2:g}"
                        .format(test_rec, test_pre, test_F))

            # Save the prediction result
            if not os.path.exists(SAVE_DIR):
                os.makedirs(SAVE_DIR)
            dh.create_prediction_file(output_file=SAVE_DIR + "/predictions.json",
                                      data_id=test_data.patent_id,
                                      all_labels=all_labels,
                                      all_predict_labels=all_predicted_labels,
                                      all_predict_values=all_predicted_values)

    logger.info("✔︎ Done.")
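
# Both test functions above average per-sample recall/precision increments from
# dh.cal_metric and derive F from dh.cal_F. The helpers below are an assumption
# about those semantics, not the project's actual data_helpers code: recall is
# hits over true labels, precision is hits over predicted labels, and F is the
# usual F1 of the averaged recall and precision.
def _cal_metric_sketch(predicted_labels, true_onehot):
    """Per-sample recall/precision increments (illustrative only)."""
    true_set = {i for i, v in enumerate(true_onehot) if int(v) == 1}
    pred_set = set(predicted_labels)
    hits = len(true_set & pred_set)
    rec = hits / len(true_set) if true_set else 0.0
    pre = hits / len(pred_set) if pred_set else 0.0
    return rec, pre


def _cal_F_sketch(rec, pre):
    """Standard F1 from (already averaged) recall and precision."""
    return 2 * rec * pre / (rec + pre) if (rec + pre) > 0 else 0.0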