def fix_and_parse_tags(config, collection, size):
    """Load language-specific anchor tries and link statistics, then
    re-tag a sample of Wikipedia documents with entity links.

    Parameters
    ----------
    config : configuration object providing ``language_path``,
        ``wikidata``, ``redirections``, ``wiki``, ``min_count`` and
        ``min_percent``.
    collection : TypeCollection; its blacklist is (re)loaded here and its
        cache is reset before returning.
    size : number of Wikipedia documents to load.

    Returns
    -------
    list of per-document tag sequences; only documents where at least one
    token received a non-None tag are kept.
    """
    # Anchor-string -> candidate entity indices (compressed offset array).
    trie_index2indices = OffsetArray.load(
        join(config.language_path, "trie_index2indices"),
        compress=True)
    # Per-candidate link counts, sharing the offsets of the index array.
    trie_index2indices_counts = OffsetArray(
        np.load(join(config.language_path, "trie_index2indices_counts.npy")),
        trie_index2indices.offsets)
    # Transition arrays are optional (may be absent for some languages).
    if exists(join(config.language_path,
                   "trie_index2indices_transition_values.npy")):
        trie_index2indices_transitions = OffsetArray(
            np.load(join(config.language_path,
                         "trie_index2indices_transition_values.npy")),
            np.load(join(config.language_path,
                         "trie_index2indices_transition_offsets.npy")),
        )
    else:
        trie_index2indices_transitions = None
    anchor_trie = marisa_trie.Trie().load(
        join(config.language_path, "trie.marisa"))
    wiki_trie = marisa_trie.RecordTrie('i').load(
        join(config.wikidata, "wikititle2wikidata.marisa"))
    prefix = get_prefix(config)
    redirections = load_redirections(config.redirections)
    docs = load_wikipedia_docs(config.wiki, size)
    # Retry until the blacklist parses: lets the user edit the file and
    # continue without redoing the (slow) loading above.  enter_or_quit
    # presumably pauses for user input or exits — confirm against helper.
    while True:
        try:
            collection.load_blacklist(join(SCRIPT_DIR, "blacklist.json"))
        except (ValueError, ) as e:
            print("issue reading blacklist, please fix.")
            print(str(e))
            enter_or_quit()
            continue
        break
    print("Load first_names")
    with open(join(PROJECT_DIR, "data", "first_names.txt"), "rt") as fin:
        first_names = set(fin.read().splitlines())
    all_tags = []
    for doc in get_progress_bar('fixing links', item='article')(docs):
        tags = obtain_tags(
            doc,
            wiki_trie=wiki_trie,
            anchor_trie=anchor_trie,
            trie_index2indices=trie_index2indices,
            trie_index2indices_counts=trie_index2indices_counts,
            trie_index2indices_transitions=trie_index2indices_transitions,
            redirections=redirections,
            prefix=prefix,
            first_names=first_names,
            collection=collection,
            # fix_destination is not defined in this view — presumably a
            # module-level callback; verify it exists at module scope.
            fix_destination=fix_destination,
            min_count=config.min_count,
            min_percent=config.min_percent)
        # Keep only documents where at least one token was linked.
        if any(x is not None for _, x in tags):
            all_tags.append(tags)
    collection.reset_cache()
    return all_tags
def main():
    """Entry point: evaluate entity-disambiguation oracles on a sample of
    Wikipedia documents.

    Loads the configuration and type collection, re-tags a sample of
    documents via ``fix_and_parse_tags``, then reports disambiguation
    quality — looping so the user can re-inspect when ``--interactive``
    is set.

    Raises
    ------
    ValueError
        If the configuration does not provide a ``wiki`` path.
    """
    args = parse_args()
    config = load_config(args.config,
                         ["wiki", "language_path", "wikidata",
                          "redirections", "classification", "path"],
                         defaults={"num_names_to_load": 0,
                                   "prefix": None,
                                   "sample_size": 100,
                                   "wiki": None,
                                   "min_count": 0,
                                   "min_percent": 0.0},
                         relative_to=args.relative_to)
    if config.wiki is None:
        raise ValueError("must provide path to 'wiki' in config.")
    prefix = get_prefix(config)
    print("Load type_collection")
    collection = TypeCollection(config.wikidata,
                                num_names_to_load=config.num_names_to_load,
                                prefix=prefix,
                                verbose=True)
    all_tags = fix_and_parse_tags(config, collection, config.sample_size)
    # NOTE(review): fix_and_parse_tags loads at most ``sample_size`` docs
    # (and filters some out), so ``train_tags`` below appears to always be
    # empty — confirm whether that is intended.
    test_tags = all_tags[:config.sample_size]
    train_tags = all_tags[config.sample_size:]
    oracles = [load_oracle_classification(classification)
               for classification in config.classification]

    def get_name(idx):
        # Human-readable name for entity ``idx``: use preloaded names when
        # available, otherwise fall back to a web lookup; append the raw id
        # for disambiguation.
        if idx < config.num_names_to_load:
            if idx in collection.known_names:
                return collection.known_names[idx] + " (%s)" % (
                    collection.ids[idx], )
            else:
                return collection.ids[idx]
        else:
            return maybe_web_get_name(
                collection.ids[idx]) + " (%s)" % (collection.ids[idx], )

    while True:
        total_report, ambiguous_tags = disambiguate_batch(
            test_tags, train_tags, oracles)
        summarize_disambiguation(total_report)
        if args.log is not None:
            with open(args.log, "at") as fout:
                summarize_disambiguation(total_report, file=fout)
        if args.verbose:
            try:
                summarize_ambiguities(ambiguous_tags, oracles, get_name)
            except KeyboardInterrupt:
                # Let the user cut the (potentially long) report short.
                pass
        if args.interactive:
            enter_or_quit()
        else:
            break
def main():
    """Entry point: run each classifier module over the type collection,
    interactively report class membership, and optionally export the
    resulting classifications.
    """
    args = parse_args()
    should_export = args.export_classification is not None
    if should_export and len(args.export_classification) != len(
            args.classifiers):
        raise ValueError("Must have as many export filenames as classifiers.")
    collection = TypeCollection(args.wikidata,
                                num_names_to_load=args.num_names_to_load,
                                language_path=args.language_path,
                                cache=args.use_cache)
    # Interactive runs pause for the user on failure; batch runs exit(1).
    if args.interactive:
        alert_failure = enter_or_quit
    else:
        alert_failure = lambda: sys.exit(1)
    while True:
        try:
            collection.load_blacklist(join(SCRIPT_DIR, "blacklist.json"))
        except (ValueError, ) as e:
            print("Issue reading blacklist, please fix.")
            print(str(e))
            alert_failure()
            continue
        classifications = []
        for class_idx, classifier_fname in enumerate(args.classifiers):
            # Retry loop: lets the user edit a broken classifier module and
            # re-run it without restarting the whole program.
            while True:
                try:
                    classifier = reload_module(classifier_fname)
                except ALLOWED_IMPORT_ERRORS as e:
                    print("issue reading %r, please fix." % (classifier_fname, ))
                    print(str(e))
                    traceback.print_exc(file=sys.stdout)
                    alert_failure()
                    continue
                try:
                    t0 = time.time()
                    classification = classifier.classify(collection)
                    classifications.append(classification)
                    # Drop cached relation data after the last classifier.
                    if class_idx == len(args.classifiers) - 1:
                        collection.reset_cache()
                    t1 = time.time()
                    print("classification took %.3fs" % (t1 - t0, ))
                except ALLOWED_RUNTIME_ERRORS as e:
                    print("issue running %r, please fix."
                          % (classifier_fname, ))
                    print(str(e))
                    traceback.print_exc(file=sys.stdout)
                    alert_failure()
                    continue
                break
        try:
            # show cardinality for each truth table:
            if args.interactive:
                mega_other_class = None
                for classification in classifications:
                    for classname in sorted(classification.keys()):
                        print(
                            "%r: %d members"
                            % (classname, int(classification[classname].sum())))
                    print("")
                    # Show up to 20 example members per class.
                    summary = {}
                    for classname, truth_table in classification.items():
                        (members, ) = np.where(truth_table)
                        summary[classname] = [
                            collection.get_name(int(member))
                            for member in members[:20]
                        ]
                    print(json.dumps(summary, indent=4))
                    other_class = get_other_class(classification)
                    if other_class.sum() > 0:
                        # there are missing items:
                        to_report = (classifier.class_report if hasattr(
                            classifier, "class_report") else [
                                wprop.SUBCLASS_OF, wprop.INSTANCE_OF,
                                wprop.OCCUPATION, wprop.CATEGORY_LINK
                            ])
                        collection.class_report(to_report,
                                                other_class,
                                                name="Other")
                    # Intersect the "unclassified" masks across classifiers.
                    if mega_other_class is None:
                        mega_other_class = other_class
                    else:
                        mega_other_class = np.logical_and(
                            mega_other_class, other_class)
                if len(classifications) > 1:
                    if mega_other_class.sum() > 0:
                        # there are missing items:
                        to_report = [
                            wprop.SUBCLASS_OF, wprop.INSTANCE_OF,
                            wprop.OCCUPATION, wprop.CATEGORY_LINK
                        ]
                        collection.class_report(to_report,
                                                mega_other_class,
                                                name="Other-combined")
            if should_export:
                # NOTE(review): validated with ``assert`` — stripped under
                # ``python -O``; consider raising instead.
                assert (len(classifications) == len(
                    args.export_classification)), (
                        "classification outputs missing for export.")
                for classification, savename in zip(
                        classifications, args.export_classification):
                    export_classification(classification, savename)
        except KeyboardInterrupt as e:
            # Allow cutting the report short without aborting the loop.
            pass
        if args.interactive:
            enter_or_quit()
        else:
            break