    def test_span_metrics_are_computed_correctly(self):
        batch_verb_indices = [2]
        batch_sentences = [["The", "cat", "loves", "hats", "."]]
        batch_bio_predicted_tags = [["B-ARG0", "B-ARG1", "B-V", "B-ARG1", "O"]]
        batch_conll_predicted_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_predicted_tags
        ]
        batch_bio_gold_tags = [["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "O"]]
        batch_conll_gold_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_gold_tags
        ]

        srl_scorer = SrlEvalScorer()
        srl_scorer(batch_verb_indices, batch_sentences,
                   batch_conll_predicted_tags, batch_conll_gold_tags)
        metrics = srl_scorer.get_metric()
        assert len(metrics) == 12
        assert_allclose(metrics['precision-ARG0'], 0.0)
        assert_allclose(metrics['recall-ARG0'], 0.0)
        assert_allclose(metrics['f1-measure-ARG0'], 0.0)
        assert_allclose(metrics['precision-ARG1'], 0.5)
        assert_allclose(metrics['recall-ARG1'], 1.0)
        assert_allclose(metrics['f1-measure-ARG1'], 2 / 3)
        assert_allclose(metrics['precision-V'], 1.0)
        assert_allclose(metrics['recall-V'], 1.0)
        assert_allclose(metrics['f1-measure-V'], 1.0)
        assert_allclose(metrics['precision-overall'], 1 / 2)
        assert_allclose(metrics['recall-overall'], 2 / 3)
        assert_allclose(metrics['f1-measure-overall'],
                        (2 * (2 / 3) * (1 / 2)) / ((2 / 3) + (1 / 2)))
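The overall numbers asserted above follow directly from span counts: the predicted ARG0 span covers only "The" while the gold span covers "The cat", so it counts as both a false positive and a false negative. Below is a hand check in plain Python (no srl-eval.pl call), assuming V spans are counted in the overall tally, as the assertions imply.

# Hand check of the overall metrics asserted above. Spans are
# (label, start, end) over token indices; V is included in the overall tally.
predicted = {("ARG0", 0, 0), ("ARG1", 1, 1), ("V", 2, 2), ("ARG1", 3, 3)}
gold = {("ARG0", 0, 1), ("V", 2, 2), ("ARG1", 3, 3)}
true_positives = len(predicted & gold)        # 2: the V span and ARG1 on "hats"
precision = true_positives / len(predicted)   # 2 / 4 = 0.5
recall = true_positives / len(gold)           # 2 / 3
f1 = 2 * precision * recall / (precision + recall)
assert (precision, recall) == (0.5, 2 / 3)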
Example #2
    def _reader(self):
        """
        line_obj:
            * metadata:             Dict
            * predicted_log_probs:  List[ float ]           #beam
            * predictions:          List[ List(ids) ]       #beam
            * predicted_tokens:     List[ List(tokens) ]    #beam
        ---------------------------
        ## metadata
            * source_tokens:                ['<EN-SRL>', 'Some', ..., 'i.']
            * verb:                         'installed'
            * src_lang:                     '<EN>'
            * tgt_lang:                     '<EN-SRL>'
            * original_BIO:                 ['B-A1', 'I-A1', ..., 'O']
            * original_predicate_senses:    []
            * predicate_senses:             [5, 'installed', '-', 'VBN']
            * original_target:              ['(#', 'Some', ..., '.']
        """
        count_different_seq_len, count_inappropriate_bracket, count_without_pred = 0, 0, 0
        sys.stdout.write('READ <- {}\n'.format(self._predictor_output))

        for instance_no, line in enumerate(open(self._predictor_output),
                                           start=1):
            line_obj = json.loads(line.strip())
            metadata = line_obj['metadata']

            sentence = metadata['source_tokens'][1:]  # drop the leading '<EN-SRL>' language tag
            verb = tuple(metadata['predicate_senses'][:2])  # (v_idx, verb)

            # Top beam hypothesis; with DEBUG_MATCH the gold target is used instead.
            predicted_target = line_obj['predicted_tokens'][0] if not DEBUG_MATCH else \
                               metadata['original_target']

            bio_gold = metadata['original_BIO']
            bio_pred, invalid_bracket = self.create_predicted_BIO(
                predicted_target, metadata)

            conll_formatted_gold_tag = convert_bio_tags_to_conll_format(
                bio_gold)
            conll_formatted_predicted_tag = convert_bio_tags_to_conll_format(
                bio_pred)

            ### counter ###
            if len(bio_gold) != len(bio_pred):  # predicted and gold sequences differ in length
                count_different_seq_len += 1
            if invalid_bracket:  # bracketing is malformed
                count_inappropriate_bracket += 1
            if verb[0] == -1:  # no target predicate exists
                count_without_pred += 1

            yield verb, sentence, conll_formatted_predicted_tag, conll_formatted_gold_tag, line_obj

        sys.stdout.write("COUNT:\n")
        sys.stdout.write("\tinstances: {}\n".format(instance_no))
        sys.stdout.write(
            "\tdifferent_seq_len: {}\n".format(count_different_seq_len))
        sys.stdout.write("\tinappropriate_bracket: {}\n".format(
            count_inappropriate_bracket))
        sys.stdout.write("\twithout_pred: {}\n".format(count_without_pred))
Example #3
    def test_bio_tags_correctly_convert_to_conll_format(self):
        bio_tags = ["B-ARG-1", "I-ARG-1", "O", "B-V", "B-ARGM-ADJ", "O"]
        conll_tags = convert_bio_tags_to_conll_format(bio_tags)
        assert conll_tags == ["(ARG-1*", "*)", "*", "(V*)", "(ARGM-ADJ*)", "*"]
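For intuition, here is a minimal re-implementation sketch of the BIO-to-CoNLL bracket conversion that reproduces the asserted output. It is not the library's own code, only an illustration of the bracketing rule: a B- tag opens a span, and a span is closed on its last token, i.e. whenever the following tag is not an I- continuation.

def bio_to_conll_sketch(bio_tags):
    # Illustrative sketch, not allennlp's implementation.
    conll = []
    for i, tag in enumerate(bio_tags):
        token = "(" + tag[2:] + "*" if tag.startswith("B-") else "*"
        next_tag = bio_tags[i + 1] if i + 1 < len(bio_tags) else "O"
        # Close the span when the next tag does not continue it with I-.
        if tag != "O" and not next_tag.startswith("I-"):
            token += ")"
        conll.append(token)
    return conll

assert bio_to_conll_sketch(["B-ARG-1", "I-ARG-1", "O", "B-V", "B-ARGM-ADJ", "O"]) == \
       ["(ARG-1*", "*)", "*", "(V*)", "(ARGM-ADJ*)", "*"]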
    def test_srl_eval_correctly_scores_identical_tags(self):
        batch_verb_indices = [3, 8, 2]
        batch_sentences = [
            ["Mali", "government", "officials", "say", "the", "woman", "'s",
             "confession", "was", "forced", "."],
            ["Mali", "government", "officials", "say", "the", "woman", "'s",
             "confession", "was", "forced", "."],
            ['The', 'prosecution', 'rested', 'its', 'case', 'last', 'month',
             'after', 'four', 'months', 'of', 'hearings', '.'],
        ]
        batch_bio_predicted_tags = [
            ['B-ARG0', 'I-ARG0', 'I-ARG0', 'B-V', 'B-ARG1', 'I-ARG1',
             'I-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'O'],
            ['O', 'O', 'O', 'O', 'B-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1',
             'B-V', 'B-ARG2', 'O'],
            ['B-ARG0', 'I-ARG0', 'B-V', 'B-ARG1', 'I-ARG1', 'B-ARGM-TMP',
             'I-ARGM-TMP', 'B-ARGM-TMP', 'I-ARGM-TMP', 'I-ARGM-TMP',
             'I-ARGM-TMP', 'I-ARGM-TMP', 'O'],
        ]
        batch_conll_predicted_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_predicted_tags
        ]
        batch_bio_gold_tags = [
            ['B-ARG0', 'I-ARG0', 'I-ARG0', 'B-V', 'B-ARG1', 'I-ARG1',
             'I-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'O'],
            ['O', 'O', 'O', 'O', 'B-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1',
             'B-V', 'B-ARG2', 'O'],
            ['B-ARG0', 'I-ARG0', 'B-V', 'B-ARG1', 'I-ARG1', 'B-ARGM-TMP',
             'I-ARGM-TMP', 'B-ARGM-TMP', 'I-ARGM-TMP', 'I-ARGM-TMP',
             'I-ARGM-TMP', 'I-ARGM-TMP', 'O'],
        ]
        batch_conll_gold_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_gold_tags
        ]

        srl_scorer = SrlEvalScorer()
        srl_scorer(batch_verb_indices, batch_sentences,
                   batch_conll_predicted_tags, batch_conll_gold_tags)
        metrics = srl_scorer.get_metric()
        assert len(metrics) == 18
        assert_allclose(metrics['precision-ARG0'], 1.0)
        assert_allclose(metrics['recall-ARG0'], 1.0)
        assert_allclose(metrics['f1-measure-ARG0'], 1.0)
        assert_allclose(metrics['precision-ARG1'], 1.0)
        assert_allclose(metrics['recall-ARG1'], 1.0)
        assert_allclose(metrics['f1-measure-ARG1'], 1.0)
        assert_allclose(metrics['precision-ARG2'], 1.0)
        assert_allclose(metrics['recall-ARG2'], 1.0)
        assert_allclose(metrics['f1-measure-ARG2'], 1.0)
        assert_allclose(metrics['precision-V'], 1.0)
        assert_allclose(metrics['recall-V'], 1.0)
        assert_allclose(metrics['f1-measure-V'], 1.0)
        assert_allclose(metrics['precision-ARGM-TMP'], 1.0)
        assert_allclose(metrics['recall-ARGM-TMP'], 1.0)
        assert_allclose(metrics['f1-measure-ARGM-TMP'], 1.0)
        assert_allclose(metrics['precision-overall'], 1.0)
        assert_allclose(metrics['recall-overall'], 1.0)
        assert_allclose(metrics['f1-measure-overall'], 1.0)
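The 18 keys asserted above are simply three metrics (precision, recall, f1-measure) for each of the five span labels present in this batch, plus the three overall entries; the same accounting gives the 12 keys in the first test (three labels plus overall). A trivial check:

# Key count asserted above: 3 metrics per span label + 3 overall entries.
labels = ["ARG0", "ARG1", "ARG2", "ARGM-TMP", "V"]
assert 3 * len(labels) + 3 == 18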