Example #1
# NumPy, marisa_trie and os.path are the only standard imports this snippet
# needs; OffsetArray, get_prefix, load_redirections, load_wikipedia_docs,
# obtain_tags, get_progress_bar, enter_or_quit, fix_destination, SCRIPT_DIR
# and PROJECT_DIR are project-local.
import numpy as np
import marisa_trie
from os.path import join, exists

def fix_and_parse_tags(config, collection, size):
    # Map each anchor-trie index to its candidate entity ids and their link
    # counts, stored as flat value arrays sharing one offsets array.
    trie_index2indices = OffsetArray.load(join(config.language_path,
                                               "trie_index2indices"),
                                          compress=True)
    trie_index2indices_counts = OffsetArray(
        np.load(join(config.language_path, "trie_index2indices_counts.npy")),
        trie_index2indices.offsets)
    # The transition arrays are optional; fall back to None when absent.
    if exists(
            join(config.language_path,
                 "trie_index2indices_transition_values.npy")):
        trie_index2indices_transitions = OffsetArray(
            np.load(
                join(config.language_path,
                     "trie_index2indices_transition_values.npy")),
            np.load(
                join(config.language_path,
                     "trie_index2indices_transition_offsets.npy")),
        )
    else:
        trie_index2indices_transitions = None

    # Anchor-text trie, wikipedia-title record trie, redirect map and docs.
    anchor_trie = marisa_trie.Trie().load(
        join(config.language_path, "trie.marisa"))
    wiki_trie = marisa_trie.RecordTrie('i').load(
        join(config.wikidata, "wikititle2wikidata.marisa"))
    prefix = get_prefix(config)
    redirections = load_redirections(config.redirections)
    docs = load_wikipedia_docs(config.wiki, size)

    # Retry until the blacklist parses: the user can fix the file and press
    # enter to try again, or quit.
    while True:
        try:
            collection.load_blacklist(join(SCRIPT_DIR, "blacklist.json"))
        except ValueError as e:
            print("Issue reading blacklist, please fix.")
            print(str(e))
            enter_or_quit()
            continue
        break

    print("Load first_names")
    with open(join(PROJECT_DIR, "data", "first_names.txt"), "rt") as fin:
        first_names = set(fin.read().splitlines())

    all_tags = []
    # Tag every article, keeping only those with at least one resolved link.
    for doc in get_progress_bar('fixing links', item='article')(docs):
        tags = obtain_tags(
            doc,
            wiki_trie=wiki_trie,
            anchor_trie=anchor_trie,
            trie_index2indices=trie_index2indices,
            trie_index2indices_counts=trie_index2indices_counts,
            trie_index2indices_transitions=trie_index2indices_transitions,
            redirections=redirections,
            prefix=prefix,
            first_names=first_names,
            collection=collection,
            fix_destination=fix_destination,
            min_count=config.min_count,
            min_percent=config.min_percent)
        if any(x is not None for _, x in tags):
            all_tags.append(tags)
    collection.reset_cache()
    return all_tags
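The OffsetArray class itself is not shown in this example. Here is a minimal sketch of the pattern its usage suggests: a jagged list of rows flattened into one values array plus an offsets array, CSR-style. The class name comes from the snippet, but this body is an assumption rather than the project's implementation:

import numpy as np

class OffsetArraySketch:
    # Hypothetical stand-in: all rows concatenated into `values`, with
    # `offsets[i]` marking where row i ends in `values`.
    def __init__(self, values, offsets):
        self.values = values
        self.offsets = offsets

    def __getitem__(self, i):
        start = self.offsets[i - 1] if i > 0 else 0
        return self.values[start:self.offsets[i]]

rows = OffsetArraySketch(np.array([4, 7, 7, 9]), np.array([1, 3, 4]))
print(rows[1])  # -> [7 7]

This layout would explain why trie_index2indices_counts above can reuse trie_index2indices.offsets: both hold one variable-length row per anchor string, so a single offsets array can index both value arrays.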
Example #2
# parse_args, load_config, get_prefix, TypeCollection, fix_and_parse_tags,
# load_oracle_classification, maybe_web_get_name, disambiguate_batch,
# summarize_disambiguation, summarize_ambiguities and enter_or_quit are
# project-local helpers.
def main():
    args = parse_args()
    config = load_config(
        args.config,
        ["wiki", "language_path", "wikidata", "redirections",
         "classification", "path"],
        defaults={
            "num_names_to_load": 0,
            "prefix": None,
            "sample_size": 100,
            "wiki": None,
            "min_count": 0,
            "min_percent": 0.0
        },
        relative_to=args.relative_to)
    if config.wiki is None:
        raise ValueError("must provide path to 'wiki' in config.")
    prefix = get_prefix(config)

    print("Load type_collection")
    collection = TypeCollection(config.wikidata,
                                num_names_to_load=config.num_names_to_load,
                                prefix=prefix,
                                verbose=True)

    all_tags = fix_and_parse_tags(config, collection, config.sample_size)
    # Hold out the first sample_size tagged articles for evaluation.
    test_tags = all_tags[:config.sample_size]
    train_tags = all_tags[config.sample_size:]

    oracles = [
        load_oracle_classification(classification)
        for classification in config.classification
    ]

    def get_name(idx):
        # Use the preloaded name when available, otherwise fall back to a
        # web lookup; append the wikidata id for disambiguation.
        if idx < config.num_names_to_load:
            if idx in collection.known_names:
                return collection.known_names[idx] + " (%s)" % (
                    collection.ids[idx], )
            else:
                return collection.ids[idx]
        else:
            return maybe_web_get_name(
                collection.ids[idx]) + " (%s)" % (collection.ids[idx], )

    # Report disambiguation quality; in interactive mode, repeat until the
    # user quits.
    while True:
        total_report, ambiguous_tags = disambiguate_batch(
            test_tags, train_tags, oracles)
        summarize_disambiguation(total_report)
        if args.log is not None:
            with open(args.log, "at") as fout:
                summarize_disambiguation(total_report, file=fout)
        if args.verbose:
            try:
                summarize_ambiguities(ambiguous_tags, oracles, get_name)
            except KeyboardInterrupt:
                pass
        if args.interactive:
            enter_or_quit()
        else:
            break
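load_config is project-local, and the call above only pins down its interface. Below is a minimal sketch of one plausible behavior, assuming a JSON config file, a list of recognized keys, defaults for absent ones, and attribute-style access; every name is hypothetical apart from the keyword arguments, which mirror the call site:

import json
from os.path import isabs, join
from types import SimpleNamespace

def load_config_sketch(path, required_keys, defaults=None, relative_to="."):
    # Hypothetical stand-in for the project's load_config.
    with open(path, "rt") as fin:
        raw = dict(defaults or {}, **json.load(fin))
    missing = [key for key in required_keys if key not in raw]
    if missing:
        raise ValueError("config is missing required keys: %r" % (missing,))
    # Resolve relative path-like string values against `relative_to`.
    for key, value in raw.items():
        if isinstance(value, str) and not isabs(value):
            raw[key] = join(relative_to, value)
    return SimpleNamespace(**raw)

Because defaults are merged in before the check, a key such as "wiki" passes as None even though it is listed as required, which would be why main() validates config.wiki separately.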
Example #3
# Standard imports this snippet needs; parse_args, TypeCollection,
# reload_module, get_other_class, export_classification, enter_or_quit,
# wprop, ALLOWED_IMPORT_ERRORS, ALLOWED_RUNTIME_ERRORS and SCRIPT_DIR are
# project-local.
import json
import sys
import time
import traceback
import numpy as np
from os.path import join

def main():
    args = parse_args()
    should_export = args.export_classification is not None
    if should_export and len(args.export_classification) != len(
            args.classifiers):
        raise ValueError("Must have as many export filenames as classifiers.")
    collection = TypeCollection(args.wikidata,
                                num_names_to_load=args.num_names_to_load,
                                language_path=args.language_path,
                                cache=args.use_cache)
    if args.interactive:
        alert_failure = enter_or_quit
    else:
        alert_failure = lambda: sys.exit(1)

    while True:
        try:
            collection.load_blacklist(join(SCRIPT_DIR, "blacklist.json"))
        except ValueError as e:
            print("Issue reading blacklist, please fix.")
            print(str(e))
            alert_failure()
            continue

        classifications = []
        for class_idx, classifier_fname in enumerate(args.classifiers):
            while True:
                try:
                    classifier = reload_module(classifier_fname)
                except ALLOWED_IMPORT_ERRORS as e:
                    print("issue reading %r, please fix." %
                          (classifier_fname, ))
                    print(str(e))
                    traceback.print_exc(file=sys.stdout)
                    alert_failure()
                    continue

                try:
                    t0 = time.time()
                    classification = classifier.classify(collection)
                    classifications.append(classification)
                    if class_idx == len(args.classifiers) - 1:
                        # After the last classifier, free the cache.
                        collection.reset_cache()
                    t1 = time.time()
                    print("classification took %.3fs" % (t1 - t0, ))
                except ALLOWED_RUNTIME_ERRORS as e:
                    print("issue running %r, please fix." %
                          (classifier_fname, ))
                    print(str(e))
                    traceback.print_exc(file=sys.stdout)
                    alert_failure()
                    continue
                break
        try:
            # show cardinality for each truth table:
            if args.interactive:
                mega_other_class = None
                for classification in classifications:
                    for classname in sorted(classification.keys()):
                        print(
                            "%r: %d members" %
                            (classname, int(classification[classname].sum())))
                    print("")
                    summary = {}
                    for classname, truth_table in classification.items():
                        (members, ) = np.where(truth_table)
                        summary[classname] = [
                            collection.get_name(int(member))
                            for member in members[:20]
                        ]
                    print(json.dumps(summary, indent=4))

                    other_class = get_other_class(classification)
                    if other_class.sum() > 0:
                        # there are missing items:
                        to_report = getattr(classifier, "class_report", [
                            wprop.SUBCLASS_OF, wprop.INSTANCE_OF,
                            wprop.OCCUPATION, wprop.CATEGORY_LINK
                        ])
                        collection.class_report(to_report,
                                                other_class,
                                                name="Other")
                        if mega_other_class is None:
                            mega_other_class = other_class
                        else:
                            # Intersect: an item stays in the combined
                            # "Other" only if every classification missed it.
                            mega_other_class = np.logical_and(
                                mega_other_class, other_class)
                if len(classifications) > 1:
                    # Guard against None: mega_other_class is only set when
                    # some classification had missing items.
                    if (mega_other_class is not None
                            and mega_other_class.sum() > 0):
                        # there are missing items:
                        to_report = [
                            wprop.SUBCLASS_OF, wprop.INSTANCE_OF,
                            wprop.OCCUPATION, wprop.CATEGORY_LINK
                        ]
                        collection.class_report(to_report,
                                                mega_other_class,
                                                name="Other-combined")
            if should_export:
                assert len(classifications) == len(
                    args.export_classification
                ), "classification outputs missing for export."
                for classification, savename in zip(
                        classifications, args.export_classification):
                    export_classification(classification, savename)
        except KeyboardInterrupt:
            pass

        if args.interactive:
            enter_or_quit()
        else:
            break
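get_other_class is project-local, but the surrounding code pins down its contract: given a classification mapping class names to boolean truth tables over all items, it returns the mask of items that no class claims. A minimal sketch under that assumption (the body below is hypothetical):

import numpy as np

def get_other_class_sketch(classification):
    # An item falls in "Other" when no truth table marks it True.
    truth_tables = np.stack(list(classification.values()))
    return np.logical_not(truth_tables.any(axis=0))

classification = {"human": np.array([True, False, False]),
                  "city": np.array([False, True, False])}
print(get_other_class_sketch(classification))  # -> [False False  True]

The combined "Other" in main() then intersects these masks with np.logical_and, so an item is only reported as missing overall when every classification left it unclaimed.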