def eval_conll(self,
               dataset_type='test',
               print_results=True,
               short_report=True):
    """Evaluate the tagger on one corpus split and dump per-token results.

    Every batch of the requested split is run through ``self.predict``; the
    gold tag indices are converted to tag strings with padding filtered out.
    Gold and predicted tags of all sequences are flattened into two parallel
    lists, with an ``'O'`` tag appended after each sequence (presumably so an
    entity cannot span a sequence boundary when scored).

    As a side effect a report is written to ``'result' + dataset_type +
    '.txt'`` in the current working directory: one ``"<token> ? <gold>
    <predicted>"`` line per token and an empty line after each sequence.

    Args:
        dataset_type: Name of the split handed to the corpus batch generator;
            also embedded in the report file name.
        print_results: Forwarded to ``precision_recall_f1``.
        short_report: Forwarded to ``precision_recall_f1``.

    Returns:
        Whatever ``precision_recall_f1`` returns for the flattened gold and
        predicted tag lists.
    """
    y_true_list = []
    y_pred_list = []
    # The context manager guarantees the report is flushed and closed even
    # when prediction or tag decoding raises part-way through the split.
    with open('result' + dataset_type + '.txt', 'w') as report:
        print('Eval on {}:'.format(dataset_type))
        for x, y_gt, token in self.corpus.batch_generator(
                batch_size=10, dataset_type=dataset_type):
            y_pred = self.predict(x)
            y_gt = self.corpus.tag_dict.batch_idxs2batch_toks(
                y_gt, filter_paddings=True)
            for tags_pred, tags_gt in zip(y_pred, y_gt):
                for tag_predicted, tag_ground_truth in zip(tags_pred, tags_gt):
                    y_true_list.append(tag_ground_truth)
                    y_pred_list.append(tag_predicted)
                y_true_list.append('O')
                y_pred_list.append('O')
            for tok, y_t, y_p in zip(token, y_gt, y_pred):
                if len(tok) != len(y_p):
                    # Diagnostic only: show the misaligned sequence and go on.
                    print(tok)
                    print(y_p)
                # zip() stops at the shortest sequence, so a length mismatch
                # no longer raises IndexError right after being reported; the
                # scoring loop above truncates in the same way.
                for word, tag_true, tag_pred in zip(tok, y_t, y_p):
                    report.write("%s ? %s %s\n" % (word, tag_true, tag_pred))
                report.write('\n')
    return precision_recall_f1(y_true_list, y_pred_list, print_results,
                               short_report)
Beispiel #2
0
def eval_conll(model, session, tokens, tags, short_report=True):
    """Compute NER quality measures using the CoNLL shared task script.

    Sentences are fed one per batch through ``predict_tags``. Positions whose
    token is ``'<PAD>'`` are dropped, gold tag indices are mapped through
    ``idx2tag``, and everything is flattened into two parallel tag lists that
    are handed to ``precision_recall_f1`` (which is asked to print its
    results).

    Args:
        model: Model object forwarded to ``predict_tags``.
        session: Session object forwarded to ``predict_tags``.
        tokens: Token sequences forwarded to ``batches_generator``.
        tags: Tag sequences forwarded to ``batches_generator``.
        short_report: Forwarded to ``precision_recall_f1``.

    Returns:
        The value returned by ``precision_recall_f1``.

    Raises:
        Exception: If the first prediction of a batch differs in length from
            the first input of that batch.
    """
    gold_stream = []
    predicted_stream = []

    for x_batch, y_batch, lengths in batches_generator(1, tokens, tags):
        tags_batch, tokens_batch = predict_tags(model, session, x_batch,
                                                lengths)

        expected_len = len(x_batch[0])
        actual_len = len(tags_batch[0])
        if expected_len != actual_len:
            raise Exception("Incorrect length of prediction for the input, "
                            "expected length: %i, got: %i" %
                            (expected_len, actual_len))

        # (gold tag, predicted tag) for every non-padding position of the
        # single sentence in this batch.
        aligned = [
            (idx2tag[gt_tag_idx], pred_tag)
            for gt_tag_idx, pred_tag, token in zip(y_batch[0], tags_batch[0],
                                                   tokens_batch[0])
            if token != '<PAD>'
        ]

        gold_stream.extend(gold for gold, _ in aligned)
        predicted_stream.extend(predicted for _, predicted in aligned)

        # Every sequence is followed by an 'O' tag to indicate a possible end
        # of entity.
        gold_stream.append('O')
        predicted_stream.append('O')

    return precision_recall_f1(gold_stream,
                               predicted_stream,
                               print_results=True,
                               short_report=short_report)
Beispiel #3
0
    def _eval(self, sess, x, y):
        """Predict tags for ``x`` and report precision/recall/F1 against ``y``.

        ``x`` is turned into a single batch with ``create_single_batch``
        (using ``self.pad_token_index``). Each row's predictions are cut at
        the first position holding the pad index, converted to tag strings
        via ``self.idx2tag`` and concatenated into one flat list.

        Args:
            sess: Session object forwarded to ``self._predict_for_batch``.
            x: Input sequences forwarded to ``create_single_batch``.
            y: Flat sequence of gold tags; must have as many entries as there
                are unpadded predictions.

        Raises:
            AssertionError: If the flattened predictions and ``y`` differ in
                length.
        """
        batch, lengths = create_single_batch(x, self.pad_token_index)
        batch_predictions = self._predict_for_batch(batch, lengths, sess)

        flat_pred_tags = []
        for row, row_pred in zip(batch, batch_predictions):
            pad_positions = np.where(row == self.pad_token_index)[0]
            # Everything from the first pad position onwards is discarded.
            if pad_positions.size == 0:
                kept = row_pred
            else:
                kept = row_pred[:pad_positions[0]]
            flat_pred_tags += [self.idx2tag[tag_index] for tag_index in kept]

        assert len(flat_pred_tags) == len(y)
        precision_recall_f1(y, flat_pred_tags, short_report=True)
Beispiel #4
0
def eval_conll(model, session, tokens, tags, short_report=True):
    """Compute NER quality measures using the CoNLL shared task script.

    Sentences are processed one per batch. The gold tag indices of the
    sentence are mapped through ``idx2tag`` and the predicted tags are taken
    from ``predict_tags`` as-is; no padding filtering is done here.

    Args:
        model: Model object forwarded to ``predict_tags``.
        session: Session object forwarded to ``predict_tags``.
        tokens: Token sequences forwarded to ``batches_generator``.
        tags: Tag sequences forwarded to ``batches_generator``.
        short_report: Forwarded to ``precision_recall_f1``.

    Returns:
        The value returned by ``precision_recall_f1`` (called with
        ``print_results=True``).
    """
    gold, predicted = [], []

    for x_batch, y_batch, lengths in batches_generator(1, tokens, tags):
        tags_batch, _ = predict_tags(model, session, x_batch, lengths)

        sentence_gold = []
        for tag_idx in y_batch[0]:
            sentence_gold.append(idx2tag[tag_idx])

        # Each sequence is closed with an 'O' tag to indicate a possible end
        # of entity.
        sentence_gold.append('O')
        gold += sentence_gold
        predicted += tags_batch[0] + ['O']

    return precision_recall_f1(gold,
                               predicted,
                               print_results=True,
                               short_report=short_report)
Beispiel #5
0
def eval_conll(model, session, tokens, tags, short_report=True):
    """Evaluate NER predictions sentence by sentence and report P/R/F1.

    For each single-sentence batch, the predicted tags are obtained from
    ``predict_tags`` and checked against the input length. Positions whose
    token is ``'<PAD>'`` are skipped; remaining gold indices are converted
    with ``idx2tag``. An ``'O'`` tag is appended after each sentence in both
    the gold and the predicted list.

    Args:
        model: Model object forwarded to ``predict_tags``.
        session: Session object forwarded to ``predict_tags``.
        tokens: Token sequences forwarded to ``batches_generator``.
        tags: Tag sequences forwarded to ``batches_generator``.
        short_report: Forwarded to ``precision_recall_f1``.

    Returns:
        The value returned by ``precision_recall_f1`` (called with
        ``print_results=True``).

    Raises:
        Exception: If the prediction length does not match the input length.
    """
    y_true = []
    y_pred = []

    for x_batch, y_batch, lengths in batches_generator(1, tokens, tags):
        tags_batch, tokens_batch = predict_tags(model, session, x_batch, lengths)

        n_inputs, n_predicted = len(x_batch[0]), len(tags_batch[0])
        if n_inputs != n_predicted:
            raise Exception("Incorrect length of prediction for the input, "
                            "expected length: %i, got: %i" % (n_inputs, n_predicted))

        sentence_true = []
        sentence_pred = []
        sentence = zip(y_batch[0], tags_batch[0], tokens_batch[0])
        for gt_tag_idx, pred_tag, token in sentence:
            if token == '<PAD>':
                continue
            sentence_true.append(idx2tag[gt_tag_idx])
            sentence_pred.append(pred_tag)

        # Sentence terminator shared by both streams.
        sentence_true.append('O')
        sentence_pred.append('O')
        y_true += sentence_true
        y_pred += sentence_pred

    return precision_recall_f1(y_true, y_pred, print_results=True, short_report=short_report)
Beispiel #6
0
 def eval_conll(self,
                dataset_type='test',
                print_results=True,
                short_report=True):
     """Score the tagger on one corpus split.

     Batches of the requested split are run through ``self.predict``; gold
     tag indices are converted to tag strings with padding filtered out.
     Gold and predicted tags are paired position by position, flattened
     across all sequences, and an ``'O'`` tag is appended after each
     sequence in both lists.

     Args:
         dataset_type: Name of the split handed to the corpus batch
             generator.
         print_results: Forwarded to ``precision_recall_f1``.
         short_report: Forwarded to ``precision_recall_f1``.

     Returns:
         Whatever ``precision_recall_f1`` returns for the flattened lists.
     """
     gold_tags = []
     predicted_tags = []
     print('Eval on {}:'.format(dataset_type))

     batches = self.corpus.batch_generator(batch_size=32,
                                           dataset_type=dataset_type)
     for x, y_gt in batches:
         y_pred = self.predict(x)
         decoded_gt = self.corpus.tag_dict.batch_idxs2batch_toks(
             y_gt, filter_paddings=True)
         for tags_pred, tags_gt in zip(y_pred, decoded_gt):
             # Pair up positions first so both lists stay the same length.
             pairs = list(zip(tags_pred, tags_gt))
             predicted_tags.extend(predicted for predicted, _ in pairs)
             gold_tags.extend(gold for _, gold in pairs)
             gold_tags.append('O')
             predicted_tags.append('O')

     return precision_recall_f1(gold_tags, predicted_tags, print_results,
                                short_report)
  def test_precision_recall_f1_integration_test(self):
    """Check that labels normalized to nothing do not hurt precision.

    Two sequences both have true label CLASS_1; the first one additionally
    predicts CLASS_0, which the normalizing dict maps to an empty list.
    Precision, recall and F1 are all expected to be exactly 1.
    """
    predictions = pd.DataFrame({
        "sequence_name": ["seq1", "seq2"],
        "predicted_label": [{"CLASS_0", "CLASS_1"}, {"CLASS_1"}],
        "true_label": [{"CLASS_1"}, {"CLASS_1"}],
    })

    # CLASS_0 maps to an empty list, so it is not included in the output.
    normalizing_dict = {"CLASS_0": [], "CLASS_1": ["CLASS_1"]}

    precision, recall, f1 = evaluation.precision_recall_f1(
        predictions, normalizing_dict)

    # Precision is 1 because CLASS_0 is normalized out; recall is 1 because
    # both sequences have their true label recalled.
    self.assertEqual(precision, 1.)
    self.assertEqual(recall, 1.)
    self.assertEqual(f1, 1.)
Beispiel #8
0
    def evaluate(self):
        """Restore the latest checkpoint and score it on the test data.

        The newest checkpoint under ``self.config.checkpoint_path`` is loaded
        into a fresh graph/session from its ``.meta`` file. The input
        placeholders and the prediction tensor are looked up by name, every
        batch produced by ``self.data_helper.batches_generator`` for
        ``self.test_token`` / ``self.test_tags`` is run through the graph,
        and the resulting gold/predicted tag lists are passed to
        ``precision_recall_f1`` with the full (non-short) report printed.

        Side effects: sets ``self.checkpoint_file``.
        """
        self.checkpoint_file = tf.train.latest_checkpoint(
            self.config.checkpoint_path)

        # Session options: log device placement, allow soft placement, and
        # enable GPU "allow_growth".
        sess_config = tf.ConfigProto(log_device_placement=True,
                                     allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        with tf.Graph().as_default() as graph:
            with tf.Session(config=sess_config) as sess:
                # Rebuild the graph from the checkpoint's meta file, then load
                # the saved variable values into this session.
                saver = tf.train.import_meta_graph('{}.meta'.format(
                    self.checkpoint_file))
                saver.restore(sess, self.checkpoint_file)

                # Input tensors, addressed by the names they have in the
                # restored graph.
                word_batch = graph.get_operation_by_name(
                    'model/placeholders/word_batch').outputs[0]
                cap_feat_batch = graph.get_operation_by_name(
                    'model/placeholders/cap_feat_batch').outputs[0]
                lengths_batch = graph.get_operation_by_name(
                    'model/placeholders/lengths').outputs[0]

                # Output tensor holding the predicted tag indices.
                predictions = graph.get_operation_by_name(
                    'model/output/predictions').outputs[0]

                y_true, y_pred = [], []
                # NOTE(review): the leading 1 is presumably the batch size;
                # only element [0] of each batch is scored below — confirm
                # against data_helper.batches_generator.
                for x_batch, c_batch, y_batch, l_batch in self.data_helper.batches_generator(
                        1, self.test_token, self.test_tags):
                    tags_idxs_batch = sess.run(predictions,
                                               feed_dict={
                                                   word_batch: x_batch,
                                                   cap_feat_batch: c_batch,
                                                   lengths_batch: l_batch
                                               })

                    # Convert predicted tag indices and input token indices
                    # back to their string forms, sequence by sequence.
                    tags_batch, tokens_batch = [], []
                    for tag_idxs, token_idxs in zip(tags_idxs_batch, x_batch):
                        tags, tokens = [], []
                        for tag_idx, token_idx in zip(tag_idxs, token_idxs):
                            tags.append(self.data_helper.idx2tag[tag_idx])
                            tokens.append(
                                self.data_helper.idx2token[token_idx])

                        tags_batch.append(tags)
                        tokens_batch.append(tokens)

                    predicted_tags = []
                    ground_truth_tags = []

                    # Keep only positions whose token is not the padding
                    # token; gold indices are mapped to tag strings here.
                    for gt_tag_idx, pred_tag, token in zip(
                            y_batch[0], tags_batch[0], tokens_batch[0]):
                        if token != '<PAD>':
                            ground_truth_tags.append(
                                self.data_helper.idx2tag[gt_tag_idx])
                            predicted_tags.append(pred_tag)

                    # We extend every prediction and ground truth sequence with 'O' tag
                    # to indicate a possible end of entity.
                    y_true.extend(ground_truth_tags + ['O'])
                    y_pred.extend(predicted_tags + ['O'])

                # Score the flattened tag lists; the report is printed in its
                # long form (short_report=False).
                results = precision_recall_f1(y_true,
                                              y_pred,
                                              print_results=True,
                                              short_report=False)