def test_token_err(self):
  ''' test token error rate '''
  config = utils.load_config(self.token_conf_file)
  metrics1 = metrics.get_metrics(
      config, y_true=self.token_true_label, y_pred=self.token_pred1)
  self.assertEqual(0.0, metrics1['TokenErrCal'])
  metrics2 = metrics.get_metrics(
      config, y_true=self.token_true_label, y_pred=self.token_pred2)
  self.assertEqual(0.75, metrics2['TokenErrCal'])
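# A minimal sketch of what a token error rate typically computes: the edit
# (Levenshtein) distance between predicted and reference token sequences,
# normalized by the reference length. The helper below is hypothetical and
# only illustrates the 0.0 / 0.75 values asserted above; it is not the
# actual TokenErrCal implementation.
def token_error_rate_sketch(y_true, y_pred):
  """Edit distance between token sequences / number of reference tokens."""
  m, n = len(y_true), len(y_pred)
  dist = [[0] * (n + 1) for _ in range(m + 1)]
  for i in range(m + 1):
    dist[i][0] = i
  for j in range(n + 1):
    dist[0][j] = j
  for i in range(1, m + 1):
    for j in range(1, n + 1):
      cost = 0 if y_true[i - 1] == y_pred[j - 1] else 1
      dist[i][j] = min(dist[i - 1][j] + 1,         # deletion
                       dist[i][j - 1] + 1,         # insertion
                       dist[i - 1][j - 1] + cost)  # substitution
  return dist[m][n] / max(m, 1)

# e.g. token_error_rate_sketch(['a', 'b', 'c', 'd'], ['a', 'x', 'y', 'z'])
# -> 3 substitutions over 4 reference tokens = 0.75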
def test_crf_metrics(self):
  ''' test crf metrics '''
  config = utils.load_config(self.config_file_crf)
  metrics3 = metrics.get_metrics(
      config, y_true=[self.true_label], y_pred=[self.pred1])
  # metrics3['CrfCal'] is one string: a text summary of the precision,
  # recall and F1 score for each class. The slice below picks a
  # fixed-width score field out of that report.
  self.assertEqual('1.0000', metrics3['CrfCal'][67:73])
  metrics4 = metrics.get_metrics(
      config, y_true=[self.true_label], y_pred=[self.pred2])
  self.assertEqual('0.0000', metrics4['CrfCal'][67:73])
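# Slicing the report at fixed offsets ([67:73]) breaks as soon as the report
# layout changes. A sketch of a more robust extraction, assuming the report
# embeds scores as four-decimal floats; the helper name and the regex are
# assumptions, not part of the actual CrfCal format:
import re

def first_score_in_report(report):
  """Return the first four-decimal score found in a text report, or None."""
  match = re.search(r'\d\.\d{4}', report)
  return match.group(0) if match else None

# usage: self.assertEqual('1.0000', first_score_in_report(metrics3['CrfCal']))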
def test_metric(self):
  ''' test get_metrics function '''
  config = utils.load_config(self.conf_file)
  metrics1 = metrics.get_metrics(
      config, y_true=self.true_label, y_pred=self.pred1)
  self.assertEqual(1.0, metrics1['AccuracyCal'])
  self.assertEqual(1.0, metrics1['PrecisionCal'])
  self.assertEqual(1.0, metrics1['RecallCal'])
  self.assertEqual(1.0, metrics1['F1ScoreCal'])
  metrics2 = metrics.get_metrics(
      config, y_true=self.true_label, y_pred=self.pred2)
  self.assertEqual(0.0, metrics2['AccuracyCal'])
  self.assertEqual(0.0, metrics2['PrecisionCal'])
  self.assertEqual(0.0, metrics2['RecallCal'])
  self.assertEqual(0.0, metrics2['F1ScoreCal'])
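# The four assertions above correspond to the standard classification
# metrics. A minimal scikit-learn sketch of the same all-correct /
# all-wrong cases (whether the *Cal entries wrap these exact computations
# is an assumption; the data below is illustrative only):
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

y_true = [1, 0, 1, 0]
pred_all_right = [1, 0, 1, 0]  # -> accuracy/precision/recall/F1 all 1.0
pred_all_wrong = [0, 1, 0, 1]  # -> all 0.0

for pred in (pred_all_right, pred_all_wrong):
  print(accuracy_score(y_true, pred),
        precision_score(y_true, pred),
        recall_score(y_true, pred),
        f1_score(y_true, pred))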
def run_metrics(config, y_preds, y_ground_truth, mode):
  """Run metrics for one output."""
  metcs = metrics.get_metrics(
      config=config, y_pred=y_preds, y_true=y_ground_truth)
  logging.info("Evaluation on %s:" % mode)
  # get_metrics may return a list of per-output dicts or a single dict;
  # sort the keys either way so the log order is deterministic.
  if isinstance(metcs, list):
    for one_metcs in metcs:
      for key in sorted(one_metcs.keys()):
        logging.info(key + ":" + str(one_metcs[key]))
  else:
    for key in sorted(metcs.keys()):
      logging.info(key + ":" + str(metcs[key]))
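# A hypothetical invocation, assuming a loaded config and aligned
# prediction/label lists (the names below are illustrative only):
#
#   config = utils.load_config("conf/cls.yml")
#   run_metrics(config, y_preds=all_preds,
#               y_ground_truth=all_labels, mode=utils.EVAL)
#
# which logs each metric as "MetricName:value", one per line.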
def eval_or_infer_core(self, model, mode):  # pylint: disable=too-many-locals, too-many-branches
  """The core part of evaluation."""
  model_path = self.get_model_path(mode)
  if model_path is None:
    logging.warning("model_path is None!")
    return

  with model.sess.graph.as_default():
    model.saver.restore(model.sess, save_path=model_path)
    if self.first_eval:
      model.sess.run(tf.tables_initializer())
      self.first_eval = False
    model.sess.run(model.iterator.initializer)

  # Evaluating loop.
  total_loss = 0.0
  data_size = self.config["data"]['{}_data_size'.format(mode)]
  num_batch_every_epoch = int(math.ceil(data_size / self.batch_size))

  y_ground_truth = []
  y_preds = []
  for i in range(num_batch_every_epoch):
    if mode == utils.EVAL:
      loss_val, batch_preds, batch_y_ground_truth = model.sess.run(
          [model.loss, model.preds, model.y_ground_truth])
    elif not self.infer_no_label:
      batch_preds, batch_y_ground_truth = model.sess.run(
          [model.preds, model.y_ground_truth])
    else:
      batch_preds = model.sess.run([model.preds])
      batch_preds = batch_preds[0]

    if mode == utils.EVAL:
      total_loss += loss_val
      y_preds.append([preds for preds in batch_preds])
    else:
      # The last batch may be padded; trim predictions (and labels)
      # back to the true data size.
      end_id = (i + 1) * self.batch_size
      if data_size < end_id:
        act_end_id = self.batch_size - end_id + data_size
        batch_preds = batch_preds[:act_end_id]
        if not self.infer_no_label:
          batch_y_ground_truth = batch_y_ground_truth[:act_end_id]
      y_preds.extend([preds for preds in batch_preds])
      if not self.infer_no_label:
        y_ground_truth.extend(
            [ground_truth for ground_truth in batch_y_ground_truth])

    if i % 10 == 0 or i == num_batch_every_epoch - 1:
      # Guard the denominator for the single-batch case, where
      # num_batch_every_epoch - 1 would be zero.
      logging.info("Evaluation rate of progress: [ {:.2%} ]".format(
          i / max(num_batch_every_epoch - 1, 1)))

  if mode == utils.EVAL:
    logging.info("Evaluation Average Loss: {:.6}".format(
        total_loss / len(y_preds)))
  else:
    predictions = {"preds": y_preds}
    self.postproc_fn()(predictions, log_verbose=False)

  if not self.infer_no_label:
    metcs = metrics.get_metrics(
        config=self.config, y_pred=y_preds, y_true=y_ground_truth)
    logging.info("Evaluation on %s:" % mode)
    # Sort the keys so the sequence of metrics is identical across runs.
    for key in sorted(metcs.keys()):
      logging.info(key + ":" + str(metcs[key]))
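# Worked example of the last-batch trim above: with data_size = 25 and
# batch_size = 10, the loop runs ceil(25 / 10) = 3 batches. On the last
# batch, end_id = 3 * 10 = 30 > 25, so
#   act_end_id = batch_size - end_id + data_size = 10 - 30 + 25 = 5
# and only the first 5 of the 10 padded predictions are kept, giving
# exactly 25 predictions overall.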