def validation_step(x_validation, y_validation, y_validation_bind, writer=None):
    """Evaluates the model on a validation set."""
    batches_validation = data_helpers.batch_iter(
        list(zip(x_validation, y_validation, y_validation_bind)),
        FLAGS.batch_size, FLAGS.num_epochs)

    eval_loss, eval_rec, eval_acc, eval_counter = 0.0, 0.0, 0.0, 0
    for batch_validation in batches_validation:
        x_batch_validation, y_batch_validation, y_batch_validation_bind = zip(*batch_validation)
        feed_dict = {
            fasttext.input_x: x_batch_validation,
            fasttext.input_y: y_batch_validation,
            fasttext.dropout_keep_prob: 1.0,
            fasttext.is_training: False
        }
        step, summaries, logits, cur_loss = sess.run(
            [fasttext.global_step, validation_summary_op,
             fasttext.logits, fasttext.loss], feed_dict)

        # Pick the top-k predicted labels, optionally restricted to the
        # candidate classes bound to each sample.
        if FLAGS.use_classbind_or_not == 'Y':
            predicted_labels = data_helpers.get_label_using_logits_and_classbind(
                logits, y_batch_validation_bind, top_number=FLAGS.top_num)
        elif FLAGS.use_classbind_or_not == 'N':
            predicted_labels = data_helpers.get_label_using_logits(
                logits, top_number=FLAGS.top_num)

        # Average per-sample recall and accuracy over the batch.
        cur_rec, cur_acc = 0.0, 0.0
        for index, predicted_label in enumerate(predicted_labels):
            rec_inc, acc_inc = data_helpers.cal_rec_and_acc(
                predicted_label, y_batch_validation[index])
            cur_rec, cur_acc = cur_rec + rec_inc, cur_acc + acc_inc
        cur_rec = cur_rec / len(y_batch_validation)
        cur_acc = cur_acc / len(y_batch_validation)

        eval_loss, eval_rec, eval_acc, eval_counter = \
            eval_loss + cur_loss, eval_rec + cur_rec, \
            eval_acc + cur_acc, eval_counter + 1
        logger.info("✔︎ validation batch {} finished.".format(eval_counter))

        if writer:
            writer.add_summary(summaries, step)

    # Average the metrics over all validation batches.
    eval_loss = float(eval_loss / eval_counter)
    eval_rec = float(eval_rec / eval_counter)
    eval_acc = float(eval_acc / eval_counter)
    return eval_loss, eval_rec, eval_acc
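
# NOTE: validation_step() (and test_fasttext() below) lean on helpers from
# data_helpers whose implementations are not shown in this file. The sketches
# below are only assumptions about what they do, namely top-k selection over
# the logits, with or without a per-sample candidate ("class bind")
# restriction. They are not the repo's actual code; the `_sketch` names are
# hypothetical.
import numpy as np  # already used in this module; repeated so the sketch is self-contained


def get_label_using_logits_sketch(logits, top_number=1):
    """Assumed behavior: indices of the top-k logits per sample, best first."""
    return [np.argsort(sample)[-top_number:][::-1].tolist() for sample in logits]


def get_label_using_logits_and_classbind_sketch(logits, labels_bind, top_number=1):
    """Assumed behavior: top-k restricted to each sample's bound candidate classes.

    `labels_bind` is assumed to be a list of candidate class-index lists,
    one per sample.
    """
    predicted = []
    for sample, bind in zip(logits, labels_bind):
        # Rank only the bound candidates by their logit score.
        candidates = sorted(bind, key=lambda i: sample[i], reverse=True)
        predicted.append(candidates[:top_number])
    return predicted
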
def test_fasttext():
    """Tests the FastText model."""
    # Load data
    logger.info("✔ Loading data...")
    logger.info("Recommended padding sequence length is: {}".format(FLAGS.pad_seq_len))

    logger.info("✔︎ Test data processing...")
    test_data = data_helpers.load_data_and_labels(
        FLAGS.test_data_file, FLAGS.num_classes, FLAGS.embedding_dim)

    logger.info("✔︎ Test data padding...")
    x_test, y_test = data_helpers.pad_data(test_data, FLAGS.pad_seq_len)
    y_test_bind = test_data.labels_bind

    # Build vocabulary
    VOCAB_SIZE = data_helpers.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = data_helpers.load_word2vec_matrix(
        VOCAB_SIZE, FLAGS.embedding_dim)

    # Load the latest fasttext checkpoint
    logger.info("✔ Loading model...")
    checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            # input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]

            # Pre-trained word2vec embedding matrix
            pretrained_embedding = graph.get_operation_by_name(
                "embedding/embedding").outputs[0]

            # Tensors we want to evaluate
            logits = graph.get_operation_by_name("output/logits").outputs[0]

            # Generate batches for one epoch
            batches = data_helpers.batch_iter(
                list(zip(x_test, y_test, y_test_bind)),
                FLAGS.batch_size, 1, shuffle=False)

            # Collect the predictions here
            all_predictions = []

            eval_rec, eval_acc, eval_counter = 0.0, 0.0, 0
            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_bind = zip(*batch_test)
                feed_dict = {input_x: x_batch_test, dropout_keep_prob: 1.0}
                batch_logits = sess.run(logits, feed_dict)

                # Pick the top-k predicted labels, optionally restricted to the
                # candidate classes bound to each sample.
                if FLAGS.use_classbind_or_not == 'Y':
                    predicted_labels = data_helpers.get_label_using_logits_and_classbind(
                        batch_logits, y_batch_test_bind, top_number=FLAGS.top_num)
                elif FLAGS.use_classbind_or_not == 'N':
                    predicted_labels = data_helpers.get_label_using_logits(
                        batch_logits, top_number=FLAGS.top_num)
                all_predictions = np.append(all_predictions, predicted_labels)

                # Average per-sample recall and accuracy over the batch.
                cur_rec, cur_acc = 0.0, 0.0
                for index, predicted_label in enumerate(predicted_labels):
                    rec_inc, acc_inc = data_helpers.cal_rec_and_acc(
                        predicted_label, y_batch_test[index])
                    cur_rec, cur_acc = cur_rec + rec_inc, cur_acc + acc_inc
                cur_rec = cur_rec / len(y_batch_test)
                cur_acc = cur_acc / len(y_batch_test)

                eval_rec, eval_acc, eval_counter = \
                    eval_rec + cur_rec, eval_acc + cur_acc, eval_counter + 1
                logger.info("✔︎ test batch {} finished.".format(eval_counter))

            # Average the metrics over all test batches.
            eval_rec = float(eval_rec / eval_counter)
            eval_acc = float(eval_acc / eval_counter)
            logger.info("☛ Recall {:g}, Accuracy {:g}".format(eval_rec, eval_acc))

            # Save the predictions to file
            np.savetxt(SAVE_FILE, list(zip(all_predictions)), fmt='%s')

    logger.info("✔ Done.")
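
# NOTE: data_helpers.cal_rec_and_acc() is also external to this file. A
# plausible multi-label reading (an assumption, not the repo's code) treats
# the gold labels as a multi-hot vector and scores one sample at a time:


def cal_rec_and_acc_sketch(predicted_label, true_label):
    """Hypothetical sketch: per-sample recall and accuracy increments."""
    # Assumes true_label is a multi-hot vector; recover the gold indices.
    true_indices = {i for i, v in enumerate(true_label) if v}
    hits = len(set(predicted_label) & true_indices)
    # recall = hits / #gold labels; accuracy (precision) = hits / #predictions.
    rec_inc = float(hits) / len(true_indices) if true_indices else 0.0
    acc_inc = float(hits) / len(predicted_label) if predicted_label else 0.0
    return rec_inc, acc_inc
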