def test_ann(word2vec_path):
    """Evaluate a saved ANN checkpoint on the test set and save predictions.

    Steps, in order:
      1. Load and pad the test data named by ``FLAGS.test_data_file``.
      2. Ask on stdin whether to restore the best ('B') or latest ('L')
         checkpoint and restore it into a fresh graph.
      3. Export a frozen copy of the graph to ``graph/graph-ann-<MODEL>.pb``.
      4. Run one un-shuffled pass over the test set, collecting scores.
      5. Log loss, AUC, AUPRC and micro-averaged precision / recall / F1
         (by threshold and for every top-K up to ``FLAGS.top_num``).
      6. Write ``predictions.json`` under ``SAVE_DIR``.

    ``logger``, ``FLAGS``, ``MODEL`` and ``SAVE_DIR`` are not defined here;
    they are expected to be provided by the surrounding module.

    Args:
        word2vec_path: Forwarded to ``feed.load_data_and_labels``.

    Raises:
        FileNotFoundError: If no checkpoint could be located in the selected
            checkpoint directory.
    """
    # Load data
    logger.info("✔︎ Loading data...")
    logger.info("Recommended padding Sequence length is: {0}".format(
        FLAGS.pad_seq_len))

    logger.info("✔︎ Test data processing...")
    test_data = feed.load_data_and_labels(FLAGS.test_data_file,
                                          FLAGS.num_classes,
                                          FLAGS.embedding_dim,
                                          data_aug_flag=False,
                                          word2vec_path=word2vec_path)

    logger.info("✔︎ Test data padding...")
    x_test, y_test = feed.pad_data(test_data, FLAGS.pad_seq_len)
    y_test_labels = test_data.labels

    # Load ann model
    BEST_OR_LATEST = input("☛ Load Best or Latest Model?(B/L): ")

    while not (BEST_OR_LATEST.isalpha()
               and BEST_OR_LATEST.upper() in ['B', 'L']):
        BEST_OR_LATEST = \
            input("✘ The format of your input is illegal, please re-input: ")
    if BEST_OR_LATEST.upper() == 'B':
        logger.info("✔︎ Loading best model...")
        checkpoint_source = FLAGS.best_checkpoint_dir
        checkpoint_file = checkpoints.get_best_checkpoint(
            checkpoint_source, select_maximum_value=True)
    else:
        logger.info("✔︎ Loading latest model...")
        checkpoint_source = FLAGS.checkpoint_dir
        checkpoint_file = tf.train.latest_checkpoint(checkpoint_source)
    logger.info(checkpoint_file)

    # Without this guard a missing checkpoint only shows up later as a
    # confusing failure to open "None.meta".
    if not checkpoint_file:
        raise FileNotFoundError(
            "No checkpoint found in {0}".format(checkpoint_source))

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        # The context manager installs the session as default AND closes it
        # on exit, so device resources are released even on error.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name(
                "is_training").outputs[0]

            # Tensors we want to evaluate
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Split the output nodes name by '|' if you have several output
            # nodes
            output_node_names = "output/scores"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def,
                                 "graph",
                                 "graph-ann-{0}.pb".format(MODEL),
                                 as_text=False)

            # Generate batches for one epoch
            batches = feed.batch_iter(
                list(zip(x_test, y_test, y_test_labels)),
                FLAGS.batch_size,
                1,
                shuffle=False)

            test_counter, test_loss = 0, 0.0

            test_pre_tk = [0.0] * FLAGS.top_num
            test_rec_tk = [0.0] * FLAGS.top_num
            test_F_tk = [0.0] * FLAGS.top_num

            # Collect the predictions here
            true_labels = []
            predicted_labels = []
            predicted_scores = []

            # Collect for calculating metrics
            true_onehot_labels = []
            predicted_onehot_scores = []
            predicted_onehot_labels_ts = []
            predicted_onehot_labels_tk = [[] for _ in range(FLAGS.top_num)]

            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_labels = zip(
                    *batch_test)
                print("x_batch_test", x_batch_test)
                print("y_batch_test", y_batch_test)
                feed_dict = {
                    input_x: x_batch_test,
                    input_y: y_batch_test,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                batch_scores, cur_loss = sess.run([scores, loss], feed_dict)

                # Prepare for calculating metrics
                true_onehot_labels.extend(y_batch_test)
                predicted_onehot_scores.extend(batch_scores)

                # Get the predicted labels by threshold
                batch_predicted_labels_ts, batch_predicted_scores_ts = \
                    feed.get_label_threshold(scores=batch_scores,
                                             threshold=FLAGS.threshold)

                # Add results to collection
                true_labels.extend(y_batch_test_labels)
                predicted_labels.extend(batch_predicted_labels_ts)
                predicted_scores.extend(batch_predicted_scores_ts)

                # Get onehot predictions by threshold
                predicted_onehot_labels_ts.extend(
                    feed.get_onehot_label_threshold(
                        scores=batch_scores, threshold=FLAGS.threshold))

                # Get onehot predictions by topK
                for top_num in range(FLAGS.top_num):
                    predicted_onehot_labels_tk[top_num].extend(
                        feed.get_onehot_label_topk(scores=batch_scores,
                                                   top_num=top_num + 1))

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

    # Everything below is pure NumPy / scikit-learn work, so it runs after
    # the session has been released.  Convert each collection once instead
    # of once per metric call.
    y_true = np.array(true_onehot_labels)
    y_score = np.array(predicted_onehot_scores)
    y_pred_ts = np.array(predicted_onehot_labels_ts)

    # Calculate Precision & Recall & F1 (threshold & topK)
    test_pre_ts = precision_score(
        y_true=y_true, y_pred=y_pred_ts, average='micro')
    test_rec_ts = recall_score(
        y_true=y_true, y_pred=y_pred_ts, average='micro')
    test_F_ts = f1_score(y_true=y_true, y_pred=y_pred_ts, average='micro')

    for top_num in range(FLAGS.top_num):
        y_pred_tk = np.array(predicted_onehot_labels_tk[top_num])
        test_pre_tk[top_num] = precision_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')
        test_rec_tk[top_num] = recall_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')
        test_F_tk[top_num] = f1_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')

    # Calculate the average AUC
    test_auc = roc_auc_score(y_true=y_true, y_score=y_score, average='micro')
    # Calculate the average PR
    test_prc = average_precision_score(
        y_true=y_true, y_score=y_score, average="micro")
    test_loss = float(test_loss / test_counter)

    logger.info(
        "☛ All Test Dataset: Loss {0:g} | AUC {1:g} | AUPRC {2:g}".format(
            test_loss, test_auc, test_prc))

    # Predict by threshold
    logger.info(
        "☛ Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
        .format(test_pre_ts, test_rec_ts, test_F_ts))

    # Predict by topK
    logger.info("☛ Predict by topK:")
    for top_num in range(FLAGS.top_num):
        logger.info(
            "Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}".format(
                top_num + 1, test_pre_tk[top_num], test_rec_tk[top_num],
                test_F_tk[top_num]))

    # Save the prediction result
    os.makedirs(SAVE_DIR, exist_ok=True)
    feed.create_prediction_file(output_file=SAVE_DIR + "/predictions.json",
                                data_id=test_data.testid,
                                all_labels=true_labels,
                                all_predict_labels=predicted_labels,
                                all_predict_scores=predicted_scores)

    logger.info("✔︎ Done.")
def validation_step(_x_val, _y_val, writer=None):
    """Evaluate the ANN model on one pass over a validation set.

    Uses ``sess``, ``ann``, ``validation_summary_op``, ``feed`` and ``FLAGS``
    from the surrounding scope.  Dropout is disabled (keep probability 1.0)
    and ``is_training`` is fed as ``False``.

    Suffixes: ``ts`` = predicted by threshold, ``tk`` = predicted by top-K.

    Args:
        _x_val: Validation inputs, zipped element-wise with ``_y_val``.
        _y_val: Validation one-hot labels.
        writer: Optional summary writer; when truthy, every batch's summaries
            are added to it at the current global step.

    Returns:
        A 9-tuple ``(loss, auc, prc, rec_ts, pre_ts, F_ts, rec_tk, pre_tk,
        F_tk)`` where ``loss`` is the mean per-batch loss, the ``*_ts``
        values are micro-averaged scalars and the ``*_tk`` values are lists
        of length ``FLAGS.top_num`` (index ``k`` holds the top-``k+1``
        metric).

    Raises:
        ZeroDivisionError: If the batch iterator yields no batches.
    """
    batches_validation = feed.batch_iter(list(zip(_x_val, _y_val)),
                                         FLAGS.batch_size, 1)

    _eval_counter, _eval_loss = 0, 0.0

    _eval_pre_tk = [0.0] * FLAGS.top_num
    _eval_rec_tk = [0.0] * FLAGS.top_num
    _eval_F_tk = [0.0] * FLAGS.top_num

    true_onehot_labels = []
    predicted_onehot_scores = []
    predicted_onehot_labels_ts = []
    predicted_onehot_labels_tk = [[] for _ in range(FLAGS.top_num)]

    for batch_validation in batches_validation:
        x_batch_val, y_batch_val = zip(*batch_validation)
        feed_dict = {
            ann.input_x: x_batch_val,
            ann.input_y: y_batch_val,
            ann.dropout_keep_prob: 1.0,
            ann.is_training: False
        }
        step, summaries, scores, cur_loss = sess.run([
            ann.global_step, validation_summary_op, ann.scores, ann.loss
        ], feed_dict)

        # Prepare for calculating metrics
        true_onehot_labels.extend(y_batch_val)
        predicted_onehot_scores.extend(scores)

        # Predict by threshold
        predicted_onehot_labels_ts.extend(
            feed.get_onehot_label_threshold(scores=scores,
                                            threshold=FLAGS.threshold))

        # Predict by topK
        for _top_num in range(FLAGS.top_num):
            predicted_onehot_labels_tk[_top_num].extend(
                feed.get_onehot_label_topk(scores=scores,
                                           top_num=_top_num + 1))

        _eval_loss = _eval_loss + cur_loss
        _eval_counter = _eval_counter + 1

        if writer:
            writer.add_summary(summaries, step)

    _eval_loss = float(_eval_loss / _eval_counter)

    # Convert each collection to an ndarray once; the original rebuilt the
    # same arrays for every single metric call.
    y_true = np.array(true_onehot_labels)
    y_score = np.array(predicted_onehot_scores)
    y_pred_ts = np.array(predicted_onehot_labels_ts)

    # Calculate Precision & Recall & F1 (threshold & topK)
    _eval_pre_ts = precision_score(
        y_true=y_true, y_pred=y_pred_ts, average='micro')
    _eval_rec_ts = recall_score(
        y_true=y_true, y_pred=y_pred_ts, average='micro')
    _eval_F_ts = f1_score(y_true=y_true, y_pred=y_pred_ts, average='micro')

    for _top_num in range(FLAGS.top_num):
        y_pred_tk = np.array(predicted_onehot_labels_tk[_top_num])
        _eval_pre_tk[_top_num] = precision_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')
        _eval_rec_tk[_top_num] = recall_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')
        _eval_F_tk[_top_num] = f1_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')

    # Calculate the average AUC
    _eval_auc = roc_auc_score(
        y_true=y_true, y_score=y_score, average='micro')
    # Calculate the average PR
    _eval_prc = average_precision_score(
        y_true=y_true, y_score=y_score, average='micro')

    return _eval_loss, _eval_auc, _eval_prc, _eval_rec_ts, \
        _eval_pre_ts, _eval_F_ts, _eval_rec_tk, _eval_pre_tk, \
        _eval_F_tk
def validation_step(_x_val_gov, _x_val_art, _y_val, writer=None):
    """Evaluate the two-input CNN model on one pass over a validation set.

    Uses ``sess``, ``cnn``, ``validation_summary_op``, ``feed``, ``FLAGS``,
    ``logger`` and ``count_correct_pred`` from the surrounding scope.
    Batches are taken in order (``shuffle=False``); dropout is disabled and
    ``is_training`` is fed as ``False``.

    Besides the metrics it returns, the function prints the per-batch
    counters produced by ``count_correct_pred`` and logs their running
    totals at the end.
    NOTE(review): the counters look like "label is one / zero" totals and
    the number predicted correctly for each — confirm against
    ``count_correct_pred``.

    Args:
        _x_val_gov: First validation input, fed to ``cnn.input_x_gov``.
        _x_val_art: Second validation input, fed to ``cnn.input_x_art``.
        _y_val: Validation one-hot labels.
        writer: Optional summary writer; when truthy, every batch's summaries
            are added to it at the current global step.

    Returns:
        A 9-tuple ``(loss, auc, prc, rec_ts, pre_ts, F_ts, rec_tk, pre_tk,
        F_tk)``; ``ts`` = by threshold (scalars), ``tk`` = by top-K (lists of
        length ``FLAGS.top_num``).  ``loss`` is the mean per-batch loss.

    Raises:
        ZeroDivisionError: If the batch iterator yields no batches.
    """
    print("_x_val_gov: ", len(_x_val_gov))
    print("_x_val_art: ", len(_x_val_art))

    batches_validation = \
        feed.batch_iter(
            list(zip(_x_val_gov, _x_val_art, _y_val)),
            FLAGS.batch_size,
            num_epochs=1,
            shuffle=False)

    _eval_counter, _eval_loss = 0, 0.0

    _eval_pre_tk = [0.0] * FLAGS.top_num
    _eval_rec_tk = [0.0] * FLAGS.top_num
    _eval_F_tk = [0.0] * FLAGS.top_num

    true_onehot_labels = []
    predicted_onehot_scores = []
    predicted_onehot_labels_ts = []
    predicted_onehot_labels_tk = [[] for _ in range(FLAGS.top_num)]

    valid_count_correct_one = 0
    valid_count_label_one = 0
    valid_count_correct_zero = 0
    valid_count_label_zero = 0

    for batch_validation in batches_validation:
        x_batch_val_gov, x_batch_val_art, y_batch_val = \
            zip(*batch_validation)
        feed_dict = {
            cnn.input_x_gov: x_batch_val_gov,
            cnn.input_x_art: x_batch_val_art,
            cnn.input_y: y_batch_val,
            cnn.dropout_keep_prob: 1.0,
            cnn.is_training: False
        }
        step, summaries, scores, cur_loss, input_y = sess.run(
            [cnn.global_step, validation_summary_op,
             cnn.scores, cnn.loss, cnn.input_y],
            feed_dict)

        count_label_one, \
            count_label_zero, \
            count_correct_one, \
            count_correct_zero = count_correct_pred(scores, input_y)

        valid_count_correct_one += count_correct_one
        valid_count_label_one += count_label_one
        valid_count_correct_zero += count_correct_zero
        valid_count_label_zero += count_label_zero

        print("[VALID] num_correct_answer is {} out of {}".format(
            count_correct_one, count_label_one))
        print("[VALID] num_correct_answer is {} out of {}".format(
            count_correct_zero, count_label_zero))

        # Prepare for calculating metrics
        true_onehot_labels.extend(y_batch_val)
        predicted_onehot_scores.extend(scores)

        # Predict by threshold
        predicted_onehot_labels_ts.extend(
            feed.get_onehot_label_threshold(scores=scores,
                                            threshold=FLAGS.threshold))

        # Predict by topK
        for _top_num in range(FLAGS.top_num):
            predicted_onehot_labels_tk[_top_num].extend(
                feed.get_onehot_label_topk(scores=scores,
                                           top_num=_top_num + 1))

        _eval_loss = _eval_loss + cur_loss
        _eval_counter = _eval_counter + 1

        if writer:
            writer.add_summary(summaries, step)

    logger.info("[VALID_FINAL] Total Correct One Answer is {} out "
                "of {}".format(valid_count_correct_one,
                               valid_count_label_one))
    logger.info("[VALID_FINAL] Total Correct Zero Answer is {} "
                "out of {}".format(valid_count_correct_zero,
                                   valid_count_label_zero))

    _eval_loss = float(_eval_loss / _eval_counter)

    # Convert each collection to an ndarray once; the original rebuilt the
    # same arrays for every single metric call.
    y_true = np.array(true_onehot_labels)
    y_score = np.array(predicted_onehot_scores)
    y_pred_ts = np.array(predicted_onehot_labels_ts)

    # Calculate Precision & Recall & F1 (threshold & topK)
    _eval_pre_ts = precision_score(
        y_true=y_true, y_pred=y_pred_ts, average='micro')
    _eval_rec_ts = recall_score(
        y_true=y_true, y_pred=y_pred_ts, average='micro')
    _eval_F_ts = f1_score(y_true=y_true, y_pred=y_pred_ts, average='micro')

    for _top_num in range(FLAGS.top_num):
        y_pred_tk = np.array(predicted_onehot_labels_tk[_top_num])
        _eval_pre_tk[_top_num] = precision_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')
        _eval_rec_tk[_top_num] = recall_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')
        _eval_F_tk[_top_num] = f1_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')

    # Calculate the average AUC
    _eval_auc = roc_auc_score(
        y_true=y_true, y_score=y_score, average='micro')
    # Calculate the average PR
    _eval_prc = average_precision_score(
        y_true=y_true, y_score=y_score, average='micro')

    return _eval_loss, _eval_auc, _eval_prc, _eval_rec_ts, \
        _eval_pre_ts, _eval_F_ts, _eval_rec_tk, _eval_pre_tk, \
        _eval_F_tk
def train_ann(word2vec_path):
    """Train the ANN model, validating and checkpointing periodically.

    Loads and pads the training and validation data, builds a ``TextANN``
    graph with an Adam optimizer (staircase exponential learning-rate decay,
    global-norm gradient clipping), and then either restores the latest
    checkpoint of an existing run (``FLAGS.train_or_restore == 'R'``; the run
    id is read from stdin) or initializes variables for a new run under
    ``runs/<timestamp>``.

    During training:
      * every ``FLAGS.evaluate_every`` steps the validation set is evaluated
        and the AUPRC is handed to the best-checkpoint saver;
      * every ``FLAGS.checkpoint_every`` steps a regular checkpoint is saved;
      * the end of every epoch is logged.

    ``logger`` and ``FLAGS`` are expected to be provided by the surrounding
    module.

    Args:
        word2vec_path: Forwarded to the ``feed`` loaders for data, vocabulary
            size and the pretrained embedding matrix.

    Raises:
        FileNotFoundError: In restore mode, if the chosen run directory
            contains no checkpoint.
    """
    # Load sentences, labels, and training parameters
    logger.info("✔︎ Loading data...")

    logger.info("✔︎ Training data processing...")
    train_data = feed.load_data_and_labels(FLAGS.training_data_file,
                                           FLAGS.num_classes,
                                           FLAGS.embedding_dim,
                                           data_aug_flag=False,
                                           word2vec_path=word2vec_path)

    logger.info("✔︎ Validation data processing...")
    val_data = feed.load_data_and_labels(FLAGS.validation_data_file,
                                         FLAGS.num_classes,
                                         FLAGS.embedding_dim,
                                         data_aug_flag=False,
                                         word2vec_path=word2vec_path)

    logger.info("Recommended padding Sequence length is: {0}".format(
        FLAGS.pad_seq_len))

    logger.info("✔︎ Training data padding...")
    x_train, y_train = feed.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info("✔︎ Validation data padding...")
    x_val, y_val = feed.pad_data(val_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE = feed.load_vocab_size(FLAGS.embedding_dim,
                                      word2vec_path=word2vec_path)
    # Use pretrained W2V
    pretrained_word2vec_matrix = feed.load_word2vec_matrix(
        VOCAB_SIZE, FLAGS.embedding_dim, word2vec_path=word2vec_path)

    # Build a graph and ann object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        # The context manager installs the session as default AND closes it
        # on exit, so device resources are released even if training fails.
        with tf.Session(config=session_conf) as sess:
            ann = TextANN(sequence_length=FLAGS.pad_seq_len,
                          num_classes=FLAGS.num_classes,
                          vocab_size=VOCAB_SIZE,
                          fc_hidden_size=FLAGS.fc_hidden_size,
                          embedding_size=FLAGS.embedding_dim,
                          embedding_type=FLAGS.embedding_type,
                          l2_reg_lambda=FLAGS.l2_reg_lambda,
                          pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(
                    learning_rate=FLAGS.learning_rate,
                    global_step=ann.global_step,
                    decay_steps=FLAGS.decay_steps,
                    decay_rate=FLAGS.decay_rate,
                    staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, variables = zip(
                    *optimizer.compute_gradients(ann.loss))
                grads, _ = tf.clip_by_global_norm(
                    grads, clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(
                    zip(grads, variables),
                    global_step=ann.global_step,
                    name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, variables):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram(
                        "{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar(
                        "{0}/grad/sparsity".format(v.name),
                        tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                # The model you want to restore
                MODEL = input(
                    "☛ Please input the checkpoints model you want to "
                    "restore, it should be like(1490175368): ")

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input("✘ The format of your input is illegal, "
                                  "please re-input: ")
                logger.info("✔︎ The format of your input is legal, "
                            "now loading to next step...")
                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", ann.loss)

            # Train summaries
            train_summary_op = tf.summary.merge(
                [loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries",
                                                  "validation")
            validation_summary_writer = tf.summary.FileWriter(
                validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=FLAGS.num_checkpoints)
            best_saver = checkpoints.BestCheckpointSaver(
                save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if FLAGS.train_or_restore == 'R':
                # Load ann model
                logger.info("✔︎ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Without this guard a missing checkpoint only shows up as a
                # confusing failure to open "None.meta".
                if checkpoint_file is None:
                    raise FileNotFoundError(
                        "No checkpoint found in {0}".format(checkpoint_dir))

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph(
                    "{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                os.makedirs(checkpoint_dir, exist_ok=True)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = FLAGS.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer,
                                               config)

                # Save the embedding visualization
                saver.save(
                    sess,
                    os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(ann.global_step)
            print("current_step: ", current_step)

            def train_step(x_batch, y_batch):
                """Run one optimizer step and record its training summary."""
                feed_dict = {
                    ann.input_x: x_batch,
                    ann.input_y: y_batch,
                    ann.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    ann.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, ann.global_step, train_summary_op, ann.loss],
                    feed_dict)

                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(_x_val, _y_val, writer=None):
                """Evaluate the model on one pass over a validation set.

                Returns ``(loss, auc, prc, rec_ts, pre_ts, F_ts, rec_tk,
                pre_tk, F_tk)``; ``ts`` = by threshold, ``tk`` = by top-K.
                """
                batches_validation = feed.batch_iter(
                    list(zip(_x_val, _y_val)), FLAGS.batch_size, 1)

                _eval_counter, _eval_loss = 0, 0.0

                _eval_pre_tk = [0.0] * FLAGS.top_num
                _eval_rec_tk = [0.0] * FLAGS.top_num
                _eval_F_tk = [0.0] * FLAGS.top_num

                true_onehot_labels = []
                predicted_onehot_scores = []
                predicted_onehot_labels_ts = []
                predicted_onehot_labels_tk = [
                    [] for _ in range(FLAGS.top_num)]

                for batch_validation in batches_validation:
                    x_batch_val, y_batch_val = zip(*batch_validation)
                    feed_dict = {
                        ann.input_x: x_batch_val,
                        ann.input_y: y_batch_val,
                        ann.dropout_keep_prob: 1.0,
                        ann.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run([
                        ann.global_step, validation_summary_op,
                        ann.scores, ann.loss
                    ], feed_dict)

                    # Prepare for calculating metrics
                    true_onehot_labels.extend(y_batch_val)
                    predicted_onehot_scores.extend(scores)

                    # Predict by threshold
                    predicted_onehot_labels_ts.extend(
                        feed.get_onehot_label_threshold(
                            scores=scores, threshold=FLAGS.threshold))

                    # Predict by topK
                    for _top_num in range(FLAGS.top_num):
                        predicted_onehot_labels_tk[_top_num].extend(
                            feed.get_onehot_label_topk(
                                scores=scores, top_num=_top_num + 1))

                    _eval_loss = _eval_loss + cur_loss
                    _eval_counter = _eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                _eval_loss = float(_eval_loss / _eval_counter)

                # Convert each collection once instead of once per metric.
                y_true = np.array(true_onehot_labels)
                y_score = np.array(predicted_onehot_scores)
                y_pred_ts = np.array(predicted_onehot_labels_ts)

                # Calculate Precision & Recall & F1 (threshold & topK)
                _eval_pre_ts = precision_score(
                    y_true=y_true, y_pred=y_pred_ts, average='micro')
                _eval_rec_ts = recall_score(
                    y_true=y_true, y_pred=y_pred_ts, average='micro')
                _eval_F_ts = f1_score(
                    y_true=y_true, y_pred=y_pred_ts, average='micro')

                for _top_num in range(FLAGS.top_num):
                    y_pred_tk = np.array(
                        predicted_onehot_labels_tk[_top_num])
                    _eval_pre_tk[_top_num] = precision_score(
                        y_true=y_true, y_pred=y_pred_tk, average='micro')
                    _eval_rec_tk[_top_num] = recall_score(
                        y_true=y_true, y_pred=y_pred_tk, average='micro')
                    _eval_F_tk[_top_num] = f1_score(
                        y_true=y_true, y_pred=y_pred_tk, average='micro')

                # Calculate the average AUC
                _eval_auc = roc_auc_score(
                    y_true=y_true, y_score=y_score, average='micro')
                # Calculate the average PR
                _eval_prc = average_precision_score(
                    y_true=y_true, y_score=y_score, average='micro')

                return _eval_loss, _eval_auc, _eval_prc, _eval_rec_ts, \
                    _eval_pre_ts, _eval_F_ts, _eval_rec_tk, _eval_pre_tk, \
                    _eval_F_tk

            # Generate batches
            batches_train = feed.batch_iter(list(zip(x_train, y_train)),
                                            FLAGS.batch_size,
                                            FLAGS.num_epochs)

            # Ceiling division: number of batches needed to cover x_train.
            num_batches_per_epoch = int(
                (len(x_train) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, ann.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_auc, eval_prc, \
                        eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, \
                        eval_pre_tk, eval_F_tk = \
                        validation_step(x_val, y_val,
                                        writer=validation_summary_writer)

                    logger.info(
                        "All Validation set: Loss {0:g} | AUC {1:g} | "
                        "AUPRC {2:g}".format(eval_loss, eval_auc, eval_prc))

                    # Predict by threshold
                    logger.info("☛ Predict by threshold: Precision {0:g}, "
                                "Recall {1:g}, F {2:g}".format(
                                    eval_pre_ts, eval_rec_ts, eval_F_ts))

                    # Predict by topK
                    logger.info("☛ Predict by topK:")
                    for top_num in range(FLAGS.top_num):
                        logger.info(
                            "Top{0}: Precision {1:g}, Recall {2:g}, "
                            "F {3:g}".format(top_num + 1,
                                             eval_pre_tk[top_num],
                                             eval_rec_tk[top_num],
                                             eval_F_tk[top_num]))
                    # The best-checkpoint saver is driven by AUPRC.
                    best_saver.handle(eval_prc, sess, current_step)
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess,
                                      checkpoint_prefix,
                                      global_step=current_step)
                    logger.info(
                        "✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info(
                        "✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
def test_ann(word2vec_path, model_number):
    """Evaluate the latest checkpoint of run ``model_number`` on the test set.

    The function configures itself: it creates its own logger, derives all
    data / checkpoint / result paths from ``model_number`` and the current
    working directory, defines the ``tf.flags`` it needs and parses
    ``sys.argv``.  It then restores the latest checkpoint, exports a frozen
    ``.pb`` graph, runs one un-shuffled pass over the test data, logs loss,
    AUC, AUPRC and micro-averaged precision / recall / F1 (by threshold and
    for every top-K), and writes ``results/<model>/predictions.json``.

    NOTE(review): flags are defined on every call; flag libraries usually
    reject re-definition, so calling this twice in one process may fail —
    confirm before reusing it from a long-lived process.
    NOTE(review): ``MODEL_DIR`` is rooted at ``<cwd>/web/runs`` while
    ``BEST_MODEL_DIR`` is the relative ``runs/...`` — confirm the mismatch is
    intended (the best-checkpoint branch is currently unreachable).

    Args:
        word2vec_path: Forwarded to ``feed.load_data_and_labels``.
        model_number: Run identifier; its string form must be 10 digits,
            otherwise a replacement is requested on stdin.

    Raises:
        FileNotFoundError: If no checkpoint could be located in the selected
            checkpoint directory.
    """
    # Parameters
    # ==========================================================================
    logger = feed.logger_fn("tflog",
                            "logs/test-{0}.log".format(time.asctime()))

    MODEL = str(model_number)

    while not (MODEL.isdigit() and len(MODEL) == 10):
        MODEL = input("✘ The format of your input is illegal, "
                      "it should be like(1490175368), please re-input: ")
    logger.info("✔︎ The format of your input is legal, "
                "now loading to next step...")

    TRAININGSET_DIR = 'models/citability/data/Train.json'
    VALIDATIONSET_DIR = 'models/citability/data/Validation.json'
    cwd = os.getcwd()
    TEST_DIR = os.path.join(cwd, 'web/test_data.json')
    MODEL_DIR = os.path.join(cwd, 'web/runs/' + MODEL + '/checkpoints/')
    print(MODEL_DIR)
    BEST_MODEL_DIR = 'runs/' + MODEL + '/bestcheckpoints/'
    SAVE_DIR = 'results/' + MODEL

    # Data Parameters
    tf.flags.DEFINE_string("training_data_file", TRAININGSET_DIR,
                           "Data source for the training data.")
    tf.flags.DEFINE_string("validation_data_file", VALIDATIONSET_DIR,
                           "Data source for the validation data")
    tf.flags.DEFINE_string("test_data_file", TEST_DIR,
                           "Data source for the test data")
    tf.flags.DEFINE_string("checkpoint_dir", MODEL_DIR,
                           "Checkpoint directory from training run")
    tf.flags.DEFINE_string("best_checkpoint_dir", BEST_MODEL_DIR,
                           "Best checkpoint directory from training run")

    # Model Hyperparameters
    tf.flags.DEFINE_integer(
        "pad_seq_len", 35842, "Recommended padding Sequence length of data "
        "(depends on the data)")
    tf.flags.DEFINE_integer(
        "embedding_dim", 300, "Dimensionality of character embedding "
        "(default: 128)")
    tf.flags.DEFINE_integer("embedding_type", 1,
                            "The embedding type (default: 1)")
    tf.flags.DEFINE_integer(
        "fc_hidden_size", 1024, "Hidden size for fully connected layer "
        "(default: 1024)")
    tf.flags.DEFINE_float("dropout_keep_prob", 0.5,
                          "Dropout keep probability (default: 0.5)")
    tf.flags.DEFINE_float("l2_reg_lambda", 0.0,
                          "L2 regularization lambda (default: 0.0)")
    tf.flags.DEFINE_integer("num_classes", 80,
                            "Number of labels (depends on the task)")
    tf.flags.DEFINE_integer("top_num", 80,
                            "Number of top K prediction classes (default: 5)")
    tf.flags.DEFINE_float("threshold", 0.5,
                          "Threshold for prediction classes (default: 0.5)")

    # Test Parameters
    tf.flags.DEFINE_integer("batch_size", 1, "Batch Size (default: 1)")

    # Misc Parameters
    tf.flags.DEFINE_boolean("allow_soft_placement", True,
                            "Allow device soft device placement")
    tf.flags.DEFINE_boolean("log_device_placement", False,
                            "Log placement of ops on devices")
    tf.flags.DEFINE_boolean("gpu_options_allow_growth", True,
                            "Allow gpu options growth")

    FLAGS = tf.flags.FLAGS
    FLAGS(sys.argv)
    # Log every flag and its value between two separator rules.
    dilim = '-' * 100
    logger.info('\n'.join([
        dilim, *[
            '{0:>50}|{1:<50}'.format(attr.upper(), FLAGS.__getattr__(attr))
            for attr in sorted(FLAGS.__dict__['__wrapped'])
        ], dilim
    ]))

    # Load data
    logger.info("✔︎ Loading data...")
    logger.info("Recommended padding Sequence length is: {0}".format(
        FLAGS.pad_seq_len))

    logger.info("✔︎ Test data processing...")
    test_data = feed.load_data_and_labels(FLAGS.test_data_file,
                                          FLAGS.num_classes,
                                          FLAGS.embedding_dim,
                                          data_aug_flag=False,
                                          word2vec_path=word2vec_path)

    logger.info("✔︎ Test data padding...")
    x_test, y_test = feed.pad_data(test_data, FLAGS.pad_seq_len)
    y_test_labels = test_data.labels

    # Load ann model.  The choice is fixed to the latest checkpoint; the
    # validation loop and the 'B' branch are kept so the selection can be
    # made interactive again by changing this one assignment.
    BEST_OR_LATEST = 'L'

    while not (BEST_OR_LATEST.isalpha()
               and BEST_OR_LATEST.upper() in ['B', 'L']):
        BEST_OR_LATEST = \
            input("✘ The format of your input is illegal, please re-input: ")
    if BEST_OR_LATEST.upper() == 'B':
        logger.info("✔︎ Loading best model...")
        checkpoint_source = FLAGS.best_checkpoint_dir
        checkpoint_file = checkpoints.get_best_checkpoint(
            checkpoint_source, select_maximum_value=True)
    else:
        logger.info("✔︎ Loading latest model...")
        checkpoint_source = FLAGS.checkpoint_dir
        checkpoint_file = tf.train.latest_checkpoint(checkpoint_source)
    logger.info(checkpoint_file)

    # Without this guard a missing checkpoint only shows up later as a
    # confusing failure to open "None.meta".
    if not checkpoint_file:
        raise FileNotFoundError(
            "No checkpoint found in {0}".format(checkpoint_source))

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        # The context manager installs the session as default AND closes it
        # on exit, so device resources are released even on error.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name(
                "is_training").outputs[0]

            # Tensors we want to evaluate
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Split the output nodes name by '|' if you have several output
            # nodes
            output_node_names = "output/scores"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def,
                                 "graph",
                                 "graph-ann-{0}.pb".format(MODEL),
                                 as_text=False)

            # Generate batches for one epoch
            batches = feed.batch_iter(
                list(zip(x_test, y_test, y_test_labels)),
                FLAGS.batch_size,
                1,
                shuffle=False)

            test_counter, test_loss = 0, 0.0

            test_pre_tk = [0.0] * FLAGS.top_num
            test_rec_tk = [0.0] * FLAGS.top_num
            test_F_tk = [0.0] * FLAGS.top_num

            # Collect the predictions here
            true_labels = []
            predicted_labels = []
            predicted_scores = []

            # Collect for calculating metrics
            true_onehot_labels = []
            predicted_onehot_scores = []
            predicted_onehot_labels_ts = []
            predicted_onehot_labels_tk = [[] for _ in range(FLAGS.top_num)]

            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_labels = zip(
                    *batch_test)
                print("x_batch_test", x_batch_test)
                print("y_batch_test", y_batch_test)
                feed_dict = {
                    input_x: x_batch_test,
                    input_y: y_batch_test,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                batch_scores, cur_loss = sess.run([scores, loss], feed_dict)

                # Prepare for calculating metrics
                true_onehot_labels.extend(y_batch_test)
                predicted_onehot_scores.extend(batch_scores)

                # Get the predicted labels by threshold
                batch_predicted_labels_ts, batch_predicted_scores_ts = \
                    feed.get_label_threshold(scores=batch_scores,
                                             threshold=FLAGS.threshold)

                # Add results to collection
                true_labels.extend(y_batch_test_labels)
                predicted_labels.extend(batch_predicted_labels_ts)
                predicted_scores.extend(batch_predicted_scores_ts)

                # Get onehot predictions by threshold
                predicted_onehot_labels_ts.extend(
                    feed.get_onehot_label_threshold(
                        scores=batch_scores, threshold=FLAGS.threshold))

                # Get onehot predictions by topK
                for top_num in range(FLAGS.top_num):
                    predicted_onehot_labels_tk[top_num].extend(
                        feed.get_onehot_label_topk(scores=batch_scores,
                                                   top_num=top_num + 1))

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

    # Everything below is pure NumPy / scikit-learn work, so it runs after
    # the session has been released.  Convert each collection once instead
    # of once per metric call.
    y_true = np.array(true_onehot_labels)
    y_score = np.array(predicted_onehot_scores)
    y_pred_ts = np.array(predicted_onehot_labels_ts)

    # Calculate Precision & Recall & F1 (threshold & topK)
    test_pre_ts = precision_score(
        y_true=y_true, y_pred=y_pred_ts, average='micro')
    test_rec_ts = recall_score(
        y_true=y_true, y_pred=y_pred_ts, average='micro')
    test_F_ts = f1_score(y_true=y_true, y_pred=y_pred_ts, average='micro')

    for top_num in range(FLAGS.top_num):
        y_pred_tk = np.array(predicted_onehot_labels_tk[top_num])
        test_pre_tk[top_num] = precision_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')
        test_rec_tk[top_num] = recall_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')
        test_F_tk[top_num] = f1_score(
            y_true=y_true, y_pred=y_pred_tk, average='micro')

    # Calculate the average AUC
    test_auc = roc_auc_score(y_true=y_true, y_score=y_score, average='micro')
    # Calculate the average PR
    test_prc = average_precision_score(
        y_true=y_true, y_score=y_score, average="micro")
    test_loss = float(test_loss / test_counter)

    logger.info(
        "☛ All Test Dataset: Loss {0:g} | AUC {1:g} | AUPRC {2:g}".format(
            test_loss, test_auc, test_prc))

    # Predict by threshold
    logger.info(
        "☛ Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
        .format(test_pre_ts, test_rec_ts, test_F_ts))

    # Predict by topK
    logger.info("☛ Predict by topK:")
    for top_num in range(FLAGS.top_num):
        logger.info(
            "Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}".format(
                top_num + 1, test_pre_tk[top_num], test_rec_tk[top_num],
                test_F_tk[top_num]))

    # Save the prediction result
    os.makedirs(SAVE_DIR, exist_ok=True)
    feed.create_prediction_file(output_file=SAVE_DIR + "/predictions.json",
                                data_id=test_data.testid,
                                all_labels=true_labels,
                                all_predict_labels=predicted_labels,
                                all_predict_scores=predicted_scores)

    logger.info("✔︎ Done.")