Beispiel #1
0
def test_preannotate_passage(create, as_array, convert_and_back, partial, monkeypatch):
    """Check that annotation already stored on a passage is kept by `textutil.annotate`.

    Every member of `textutil.Attr` gets a fake value (consecutive integers starting at 10).
    The values are stored either as the layer-0 "doc" array or in each terminal's `extra`,
    then the passage is annotated and the stored values are compared with what was written.

    :param create: callable returning a fresh passage
    :param as_array: store/check the values as per-paragraph arrays rather than per-terminal extras
    :param convert_and_back: continue with the passage obtained from a to_standard/from_standard round trip
    :param partial: blank out the ENT_TYPE value so the pre-annotation is incomplete
    :param monkeypatch: pytest fixture used to replace `textutil.get_nlp`
    """
    if not partial:
        # With complete pre-annotation, get_nlp is swapped for the assert_spacy_not_loaded guard
        monkeypatch.setattr(textutil, "get_nlp", assert_spacy_not_loaded)
    passage = create()
    l0 = passage.layer(layer0.LAYER_ID)
    attr_values = [10 + offset for offset in range(len(textutil.Attr))]
    if partial:
        attr_values[textutil.Attr.ENT_TYPE.value] = ""
    if as_array:
        paragraphs = textutil.break2paragraphs(passage, return_terminals=True)
        # One row (the same shared list) per terminal of each paragraph
        l0.extra["doc"] = [[attr_values] * len(paragraph) for paragraph in paragraphs]
    else:
        # Falsy values (the blanked ENT_TYPE) are not written at all
        non_empty = [(attr, value) for attr, value in zip(textutil.Attr, attr_values) if value]
        for terminal in l0.all:
            for attr, value in non_empty:
                terminal.extra[attr.key] = value
    # The round trip is always performed; the flag only selects which passage is used afterwards
    round_tripped = convert.from_standard(convert.to_standard(passage))
    passage = round_tripped if convert_and_back else passage
    if not partial:
        assert textutil.is_annotated(passage, as_array=as_array), "Passage %s is not pre-annotated" % passage.ID
    textutil.annotate(passage, as_array=as_array)
    assert textutil.is_annotated(passage, as_array=as_array), "Passage %s is not annotated" % passage.ID
    # NOTE(review): l0 still belongs to the passage returned by create(), also when convert_and_back is set
    for terminal in l0.all:
        for index, (attr, expected) in enumerate(zip(textutil.Attr, attr_values)):
            if not expected:
                continue
            actual = terminal.tok[index] if as_array else terminal.extra.get(attr.key)
            assert actual == expected, \
                "Terminal %s has wrong %s" % (terminal, attr.name)
Beispiel #2
0
def extract_edges(passage, constructions=None, reference=None, verbose=False):
    """
    Collect, for each requested construction, the edges of a UCCA passage that match it.
    :param passage: Passage object to find constructions in
    :param constructions: names of constructions to include, resolved through get_by_names (None for all)
    :param reference: Passage object to get POS tags from (default: `passage');
                      its terminal IDs must equal those of `passage'
    :param verbose: whether to print tagged text
    :return: OrderedDict of Construction -> list of corresponding edges
    """
    constructions = get_by_names(constructions)
    if reference is not None:
        ids1, ids2 = terminal_ids(passage), terminal_ids(reference)
        assert ids1 == ids2, "Reference passage terminals do not match: %s (%d != %d)\nDifference:\n%s" % (
            reference.ID, len(terminal_ids(passage)), len(terminal_ids(reference)),
            "\n".join(map(str, diff_terminals(passage, reference))))
    # The passage is annotated only when at least one non-default construction was requested
    needs_annotation = any(not c.default for c in constructions)
    if needs_annotation:
        textutil.annotate(passage, verbose=verbose)
    extracted = OrderedDict()
    for construction in constructions:
        extracted[construction] = []
    layer = passage.layer(layer1.LAYER_ID)
    for node in layer.all:
        for edge in node:
            candidate = Candidate(edge, reference=reference)
            # An edge may satisfy several constructions; it is recorded under each of them
            matching = (c for c in constructions if c.criterion(candidate))
            for construction in matching:
                extracted[construction].append(edge)
    return extracted
Beispiel #3
0
 def to_format(self,
               passage,
               metadata=True,
               wikification=True,
               verbose=False,
               use_original=True,
               default_label=None,
               **kwargs):
     """Convert a passage to a list of output lines encoded with penman.

     :param passage: passage to convert
     :param metadata: whether to prepend the lines returned by self.header
     :param wikification: stored on self.wikification; when truthy, WIKIFIER.wikify_passage is applied
     :param verbose: whether to print progress messages
     :param use_original: return passage.extra["original"] unchanged when it is present and truthy
     :param default_label: passed on to self._to_triples
     :param kwargs: passed on to self.header
     :return: the stored original, or header lines (if requested) followed by the encoded graph lines
     """
     self.wikification = wikification
     original = passage.extra.get("original") if use_original else None
     if original:
         return original
     textutil.annotate(passage, as_array=True)
     if self.wikification:
         if verbose:
             print("Wikifying passage...")
         WIKIFIER.wikify_passage(passage)
     if verbose:
         print("Expanding names...")
     self._expand_names(passage.layer(layer1.LAYER_ID))
     triples = list(self._to_triples(passage, default_label=default_label))
     if not triples:
         # No triples were produced: encode a single placeholder instance instead
         triples = [("y", INSTANCE, "yes")]
     header_lines = self.header(passage, **kwargs) if metadata else []
     graph_lines = penman.encode(penman.Graph(triples)).split("\n")
     return header_lines + graph_lines
Beispiel #4
0
def test_preannotate_passage(create, as_array, convert_and_back, partial, monkeypatch):
    """Check that pre-existing annotation is kept by `textutil.annotate`.

    One fake value per member of `textutil.Attr` (10, 11, ...) is written either to the
    layer-0 "doc" array or to each terminal's `extra`. The passage is then annotated with
    `as_extra` set to the opposite of `as_array`, and the written values are checked.

    :param create: callable returning a fresh passage
    :param as_array: use the array representation; `as_extra` is passed as `not as_array`
    :param convert_and_back: continue with the passage obtained from a to_standard/from_standard round trip
    :param partial: blank out the ENT_TYPE value so the pre-annotation is incomplete
    :param monkeypatch: pytest fixture used to replace `textutil.get_nlp`
    """
    as_extra = not as_array
    if not partial:
        # With complete pre-annotation, get_nlp is swapped for the assert_spacy_not_loaded guard
        monkeypatch.setattr(textutil, "get_nlp", assert_spacy_not_loaded)
    passage = create()
    l0 = passage.layer(layer0.LAYER_ID)
    first_value = 10
    attr_values = list(range(first_value, first_value + len(textutil.Attr)))
    if partial:
        attr_values[textutil.Attr.ENT_TYPE.value] = ""
    if not as_array:
        for terminal in l0.all:
            for attr, value in zip(textutil.Attr, attr_values):
                if not value:
                    continue  # blanked values are left unset
                terminal.extra[attr.key] = value
    else:
        doc = []
        for paragraph in textutil.break2paragraphs(passage, return_terminals=True):
            # One row (the same shared list) per terminal of the paragraph
            doc.append([attr_values] * len(paragraph))
        l0.extra["doc"] = doc
    # The round trip is always performed; the flag only selects which passage is used afterwards
    candidates = (passage, convert.from_standard(convert.to_standard(passage)))
    passage = candidates[convert_and_back]
    if not partial:
        assert textutil.is_annotated(passage, as_array=as_array, as_extra=as_extra), \
            "Passage %s is not pre-annotated" % passage.ID
    textutil.annotate(passage, as_array=as_array, as_extra=as_extra)
    assert textutil.is_annotated(passage, as_array=as_array, as_extra=as_extra), \
        "Passage %s is not annotated" % passage.ID
    # NOTE(review): l0 still belongs to the passage returned by create(), also when convert_and_back is set
    for terminal in l0.all:
        for position, (attr, expected) in enumerate(zip(textutil.Attr, attr_values)):
            if expected:
                found = terminal.tok[position] if as_array else terminal.extra.get(attr.key)
                assert found == expected, \
                    "Terminal %s has wrong %s" % (terminal, attr.name)
Beispiel #5
0
def _test_features(config, feature_extractor_creator, filename, write_features):
    """Run the oracle over a passage and compare the extracted features with a stored file.

    :param config: configuration object; provides vocab() and set_format()
    :param feature_extractor_creator: callable building the feature extractor from `config`;
                                      its `annotated` attribute is passed to load_passage
    :param filename: passage file to load
    :param write_features: if true, (re)write the comparison file before comparing against it
    """
    extractor = feature_extractor_creator(config)
    passage = load_passage(filename, annotate=feature_extractor_creator.annotated)
    textutil.annotate(passage, as_array=True, as_extra=False, vocab=config.vocab())
    config.set_format(passage.extra.get("format") or "ucca")
    oracle = Oracle(passage)
    state = State(passage)
    actions = Actions()
    # Non-numeric parameters get dropout disabled and are (re)initialized
    for key, param in extractor.params.items():
        if param.numeric:
            continue
        param.dropout = 0
        extractor.init_param(key)
    features = [extractor.init_features(state)]
    done = False
    while not done:
        extract_features(extractor, state, features)
        # Among the actions offered by the oracle, take the one with the smallest string form
        action = min(oracle.get_actions(state, actions).values(), key=str)
        state.transition(action)
        if state.need_label:
            extract_features(extractor, state, features)
            label, _ = oracle.get_label(state, action)
            state.label_node(label)
        done = state.finished
    # One "key value" line per feature, with a " " separator line closing each non-empty feature dict
    lines = []
    for feature_dict in features:
        if not feature_dict:
            continue
        for item in sorted(feature_dict.items()) + [("", "")]:
            lines.append("%s %s\n" % item)
    features = lines
    stem = "-".join((basename(filename), str(feature_extractor_creator)))
    compare_file = os.path.join("test_files", "features", stem + ".txt")
    if write_features:
        with open(compare_file, "w", encoding="utf-8") as f:
            f.writelines(features)
    with open(compare_file, encoding="utf-8") as f:
        assert f.readlines() == features, compare_file
Beispiel #6
0
 def _annotate(self, attr=None):
     """Annotate the root passage once and optionally return the cached values of `attr`.

     :param attr: annotation attribute to look up; when falsy nothing is returned
     :return: the value cached in self.extra[attr], computed on first use as the set of
              get_annotation(attr, as_array=True) over self.terminals; None when `attr` is falsy
     """
     root = self.edge.parent.root
     if not root.extra.get("annotated"):
         textutil.annotate(root, as_array=True, verbose=self.verbose)
         # Flag stored on the passage so later calls skip annotation
         root.extra["annotated"] = True
     if not attr:
         return None
     cached = self.extra.get(attr)
     if cached is not None:
         return cached
     values = set()
     for terminal in self.terminals:
         values.add(terminal.get_annotation(attr, as_array=True))
     self.extra[attr] = values
     return values
Beispiel #7
0
 def _annotate(self, attr=None):
     """Make sure the root passage is annotated; look up `attr` over the terminals if given.

     :param attr: annotation attribute to collect, or a falsy value to only trigger annotation
     :return: None when `attr` is falsy; otherwise the entry of self.extra for `attr`, which is
              filled on first use with the set of each terminal's get_annotation(attr, as_array=True)
     """
     passage = self.edge.parent.root
     needs_annotation = not passage.extra.get("annotated")
     if needs_annotation:
         textutil.annotate(passage, as_array=True, verbose=self.verbose)
         passage.extra["annotated"] = True  # remembered on the passage itself
     if attr:
         result = self.extra.get(attr)
         if result is None:
             result = {terminal.get_annotation(attr, as_array=True) for terminal in self.terminals}
             self.extra[attr] = result
         return result
     return None
Beispiel #8
0
def test_annotate_passage(create, as_array):
    """Annotate a passage and check the annotation on it and on its round-tripped copy.

    :param create: callable returning a fresh passage
    :param as_array: check `terminal.tok` arrays when true, per-attribute `terminal.extra` keys otherwise
    """
    passage = create()
    textutil.annotate(passage, as_array=as_array)
    round_tripped = convert.from_standard(convert.to_standard(passage))
    for candidate in (passage, round_tripped):
        assert textutil.is_annotated(candidate, as_array=as_array), "Passage %s is not annotated" % passage.ID
        for terminal in candidate.layer(layer0.LAYER_ID).all:
            if not as_array:
                # Every attribute must have its own key in the terminal's extra dict
                for attr in textutil.Attr:
                    assert attr.key in terminal.extra, "Terminal %s has no %s" % (terminal, attr.name)
                continue
            assert terminal.tok is not None, "Terminal %s has no annotation" % terminal
            # The array holds one entry per attribute
            assert len(terminal.tok) == len(textutil.Attr)
Beispiel #9
0
def test_annotate_passage(create, as_array):
    """Verify `textutil.annotate` output before and after a to_standard/from_standard round trip.

    :param create: callable returning a fresh passage
    :param as_array: whether annotation is requested and checked in array form (`terminal.tok`)
                     instead of as keys in `terminal.extra`
    """
    source = create()
    textutil.annotate(source, as_array=as_array)
    expected_length = None
    passages = [source, convert.from_standard(convert.to_standard(source))]
    for p in passages:
        assert textutil.is_annotated(p, as_array=as_array), "Passage %s is not annotated" % source.ID
        layer = p.layer(layer0.LAYER_ID)
        for terminal in layer.all:
            if as_array:
                assert terminal.tok is not None, "Terminal %s has no annotation" % terminal
                expected_length = len(textutil.Attr)  # one array entry per attribute
                assert len(terminal.tok) == expected_length
            else:
                for attr in textutil.Attr:
                    assert attr.key in terminal.extra, "Terminal %s has no %s" % (terminal, attr.name)
Beispiel #10
0
 def to_format(self, passage, metadata=True, wikification=True, verbose=False, use_original=True):
     """Convert a passage to a list of output lines encoded with penman.

     :param passage: passage to convert
     :param metadata: whether to prepend the lines returned by self.header
     :param wikification: whether to apply WIKIFIER.wikify_passage to the passage
     :param verbose: whether to print progress messages
     :param use_original: return passage.extra["original"] unchanged when it is present and truthy
     :return: the stored original, or header lines (if requested) followed by the encoded graph lines;
              when there is nothing to encode, the placeholder line "(y / yes)" is used instead
     """
     if use_original:
         original = passage.extra.get("original")
         if original:
             return original
     textutil.annotate(passage, as_array=True)
     lines = self.header(passage) if metadata else []
     if wikification:
         if verbose:
             print("Wikifying passage...")
         WIKIFIER.wikify_passage(passage)
     if verbose:
         print("Expanding names...")
     self._expand_names(passage.layer(layer1.LAYER_ID))
     triples = list(self._to_triples(passage))
     # str.split("\n") never returns an empty list ("" splits to [""]), so a fallback written as
     # `encoded.split("\n") or [...]` can never fire. Decide on the placeholder before splitting.
     encoded = penman.encode(penman.Graph(triples)) if triples else ""
     return lines + (encoded.split("\n") if encoded else ["(y / yes)"])
Beispiel #11
0
 def test_annotate_passage(self):
     """Annotate a passage in both forms and check it and its round-tripped copy.

     The passage loaded from test_files/standard3.xml is annotated once with the default
     arguments and once with as_array=True; both the passage and the result of a
     to_standard/from_standard round trip must then report as annotated in both forms.
     """
     passage = convert.from_standard(
         TestUtil.load_xml("test_files/standard3.xml"))
     textutil.annotate(passage)
     textutil.annotate(passage, as_array=True)
     round_tripped = convert.from_standard(convert.to_standard(passage))
     num_attrs = None
     for p in (passage, round_tripped):
         for array_form in (True, False):
             self.assertTrue(is_annotated(p, as_array=array_form),
                             "Passage %s is not annotated" % passage.ID)
         for terminal in p.layer(layer0.LAYER_ID).all:
             # Per-attribute keys in extra
             for attr in textutil.Attr:
                 self.assertIn(
                     attr.key, terminal.extra,
                     "Terminal %s has no %s" % (terminal, attr.name))
             # Array form: one entry per attribute
             self.assertIsNotNone(
                 terminal.tok, "Terminal %s has no annotation" % terminal)
             num_attrs = len(textutil.Attr)
             self.assertEqual(len(terminal.tok), num_attrs)
Beispiel #12
0
def main():
    """Annotate every passage file matching the given patterns and write it back in place.

    Each command-line argument is expanded with glob; a pattern with no match raises IOError.
    A file is written as binary unless its name ends with "xml". Exits with status 0 when done.
    """
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument("filenames", nargs="+", help="passage file names to annotate")
    parser.add_argument("-v", "--verbose", action="store_true", help="print tagged text for each passage")
    args = parser.parse_args()

    for pattern in args.filenames:
        matches = glob.glob(pattern)
        if not matches:
            raise IOError("Not found: " + pattern)
        for path in matches:
            passage = file2passage(path)
            annotate(passage, verbose=args.verbose, replace=True)
            sys.stderr.write("Writing '%s'...\n" % path)
            is_xml = path.endswith("xml")
            passage2file(passage, path, binary=not is_xml)

    sys.exit(0)
Beispiel #13
0
    def parse(self, passages, mode=ParseMode.test, evaluate=False):
        """
        Parse given passages
        :param passages: iterable of passages to parse, or a single passage
        :param mode: ParseMode value.
                     If train, use oracle to train on given passages.
                     Otherwise, just parse with classifier.
        :param evaluate: whether to evaluate parsed passages with respect to given ones.
                         Only possible when given passages are annotated.
        :return: generator of parsed passages (or in train mode without verification, the original ones),
                 or, if evaluate=True, of pairs of (Passage, Scores).
        """
        assert mode in ParseMode, "Invalid parse mode: %s" % mode
        train = (mode is ParseMode.train)
        if not train and not self.trained:
            list(self.train())
        passage_word = "sentence" if self.args.sentences else \
                       "paragraph" if self.args.paragraphs else \
                       "passage"
        # All four running totals are reset together, so the action and label percentages
        # in the final summary always refer to the same set of passages
        self.total_actions = 0
        self.total_correct_actions = 0
        self.total_labels = 0
        self.total_correct_labels = 0
        total_duration = 0
        total_tokens = 0
        # Counted explicitly: a generator is always truthy, so `passages` itself cannot tell
        # whether anything was actually parsed
        num_passages = 0
        if not hasattr(passages, "__iter__"):  # Single passage given
            passages = (passages, )
        for passage_index, passage in enumerate(passages):
            num_passages = passage_index + 1
            labeled = any(n.outgoing or n.attrib.get(LABEL_ATTRIB)
                          for n in passage.layer(layer1.LAYER_ID).all)
            assert not train or labeled, "Cannot train on unannotated passage: %s" % passage.ID
            assert not evaluate or labeled, "Cannot evaluate on unannotated passage: %s" % passage.ID
            print("%s %-7s" % (passage_word, passage.ID),
                  end=Config().line_end,
                  flush=True)
            started = time.time()
            self.action_count = self.correct_action_count = self.label_count = self.correct_label_count = 0
            textutil.annotate(passage, verbose=self.args.verbose >
                              1)  # tag POS and parse dependencies
            self.state = State(passage)
            self.state_hash_history = set()
            # `and` binds tighter than `or`: an oracle is created when training, when verifying,
            # or when the passage is labeled and verbose output or gold node labels are requested
            self.oracle = Oracle(passage) if train or (
                self.args.verbose or Config().args.use_gold_node_labels
            ) and labeled or self.args.verify else None
            failed = False
            if ClassifierProperty.require_init_features in self.model.model.get_classifier_properties(
            ):
                self.model.init_features(self.state, train)
            try:
                self.parse_passage(
                    train)  # This is where the actual parsing takes place
            except ParserException as e:
                if train:
                    raise
                # Outside training a failed passage is logged and reported, not fatal
                Config().log("%s %s: %s" % (passage_word, passage.ID, e))
                failed = True
            guessed = self.state.create_passage(
                verify=self.args.verify
            ) if not train or self.args.verify else passage
            duration = time.time() - started
            total_duration += duration
            num_tokens = len(
                set(self.state.terminals).difference(self.state.buffer))
            total_tokens += num_tokens
            if self.oracle:  # We have an oracle to verify by
                if not failed and self.args.verify:
                    self.verify_passage(guessed, passage, train)
                if self.action_count:
                    accuracy_str = "%d%% (%d/%d)" % (
                        100 * self.correct_action_count / self.action_count,
                        self.correct_action_count, self.action_count)
                    if self.label_count:
                        accuracy_str += " %d%% (%d/%d)" % (
                            100 * self.correct_label_count / self.label_count,
                            self.correct_label_count, self.label_count)
                    print("%-30s" % accuracy_str, end=Config().line_end)
            print("%0.3fs" % duration, end="")
            # A zero elapsed time (coarse clock) must not crash the progress report;
            # the rate is shown as 0 in that case
            print("%-15s" % (" (failed)" if failed else " (%d tokens/s)" %
                             (num_tokens / duration if duration else 0)),
                  end="")
            print(Config().line_end, end="")
            if self.oracle:
                print(Config().line_end, flush=True)
            self.model.model.finished_item(train)
            self.total_correct_actions += self.correct_action_count
            self.total_actions += self.action_count
            self.total_correct_labels += self.correct_label_count
            self.total_labels += self.label_count
            if train and self.args.save_every and (
                    passage_index + 1) % self.args.save_every == 0:
                self.eval_and_save()
                self.eval_index += 1
            yield (guessed, self.evaluate_passage(
                guessed, passage)) if evaluate else guessed

        if num_passages:
            print("Parsed %d %ss" % (num_passages, passage_word))
            if self.oracle and self.total_actions:
                accuracy_str = "%d%% correct actions (%d/%d)" % (
                    100 * self.total_correct_actions / self.total_actions,
                    self.total_correct_actions, self.total_actions)
                if self.total_labels:
                    accuracy_str += ", %d%% correct labels (%d/%d)" % (
                        100 * self.total_correct_labels / self.total_labels,
                        self.total_correct_labels, self.total_labels)
                print("Overall %s on %s" % (accuracy_str, mode.name))
            if total_duration:
                print(
                    "Total time: %.3fs (average time/%s: %.3fs, average tokens/s: %d)"
                    % (total_duration, passage_word, total_duration /
                       num_passages, total_tokens / total_duration),
                    flush=True)