Example no. 1
def main():
    args = parse_args()
    collection = TypeCollection(args.wikidata, num_names_to_load=0)
    collection.load_blacklist(
        join(dirname(SCRIPT_DIR), "extraction", "blacklist.json"))
    lines_arr, mask = generate_training_data(collection, args.dataset)
    article_ids = np.array(list(set(lines_arr[:, 1])), dtype=np.int32)
    proposal_sets = get_proposal_sets(collection, article_ids, args.seed)
    report = []
    total = sum(len(topfields) for topfields, _ in proposal_sets)
    seen = 0
    t0 = time.time()
    data_source = generate_truth_tables(collection, lines_arr, proposal_sets,
                                        args.simultaneous_fields)

    for topfields, relation_name, truth_tables, qids, id2pos in prefetch_generator(
            data_source):
        # for each of these properties under the given relation,
        # construct a truth table over the items and measure the
        # property's 'learnability':
        seen += len(topfields)
        field_auc_scores = learnability(collection,
                                        lines_arr,
                                        mask,
                                        qids=qids,
                                        truth_tables=truth_tables,
                                        id2pos=id2pos,
                                        batch_size=args.batch_size,
                                        epochs=args.max_epochs,
                                        input_size=args.input_size,
                                        window_size=args.window_size,
                                        max_vocab_size=args.max_vocab_size,
                                        verbose=True)
        for qid, auc, average_precision_score, correct, pos, neg in field_auc_scores:
            report.append({
                "qid": collection.ids[qid],
                "auc": auc,
                "average_precision_score": average_precision_score,
                "correct": correct,
                "relation": relation_name,
                "positive": pos,
                "negative": neg
            })
        with open(args.out, "wt") as fout:
            json.dump(report, fout)
        t1 = time.time()
        speed = seen / (t1 - t0)
        print("AUC obtained for %d / %d items (%.3f items/s)" %
              (seen, total, speed),
              flush=True)
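parse_args is not shown on this listing; a minimal argparse sketch consistent with the attributes the example reads (only the attribute names come from the code above, while flag placement, types, and defaults are guesses) might look like:

import argparse

def parse_args():
    # Hypothetical reconstruction: attribute names match the usage in
    # Example no. 1 above; argument types and defaults are assumptions.
    parser = argparse.ArgumentParser()
    parser.add_argument("wikidata", help="path to the wikidata type collection")
    parser.add_argument("dataset", help="path to the training data")
    parser.add_argument("out", help="where to write the JSON report")
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--simultaneous_fields", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--max_epochs", type=int, default=2)
    parser.add_argument("--input_size", type=int, default=5)
    parser.add_argument("--window_size", type=int, default=10)
    parser.add_argument("--max_vocab_size", type=int, default=10000)
    return parser.parse_args()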
Example no. 2
def main():
    args = parse_args()
    config = load_config(args.config, [
        "wiki", "language_path", "wikidata", "redirections", "classification",
        "path"
    ],
                         defaults={
                             "num_names_to_load": 0,
                             "prefix": None,
                             "sample_size": 100,
                             "wiki": None,
                             "min_count": 0,
                             "min_percent": 0.0
                         },
                         relative_to=args.relative_to)
    if config.wiki is None:
        raise ValueError("must provide path to 'wiki' in config.")
    prefix = get_prefix(config)

    print("Load type_collection")
    collection = TypeCollection(config.wikidata,
                                num_names_to_load=config.num_names_to_load,
                                prefix=prefix,
                                verbose=True)

    fname = config.wiki
    all_tags = fix_and_parse_tags(config, collection, config.sample_size)
    test_tags = all_tags[:config.sample_size]
    train_tags = all_tags[config.sample_size:]

    oracles = [
        load_oracle_classification(classification)
        for classification in config.classification
    ]

    def get_name(idx):
        if idx < config.num_names_to_load:
            if idx in collection.known_names:
                return collection.known_names[idx] + " (%s)" % (
                    collection.ids[idx], )
            else:
                return collection.ids[idx]
        else:
            return maybe_web_get_name(
                collection.ids[idx]) + " (%s)" % (collection.ids[idx], )

    while True:
        total_report, ambiguous_tags = disambiguate_batch(
            test_tags, train_tags, oracles)
        summarize_disambiguation(total_report)
        if args.log is not None:
            with open(args.log, "at") as fout:
                summarize_disambiguation(total_report, file=fout)
        if args.verbose:
            try:
                summarize_ambiguities(ambiguous_tags, oracles, get_name)
            except KeyboardInterrupt as e:
                pass
        if args.interactive:
            enter_or_quit()
        else:
            break
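load_config reads a configuration file whose required keys are listed in the call above. A hypothetical configuration for this example (the key names come from the load_config call; the on-disk file format, paths, and values are placeholders):

example_config = {
    "wiki": "enwiki_sample.tsv",          # placeholder path
    "language_path": "en_trie/",          # placeholder path
    "wikidata": "wikidata/",              # placeholder path
    "redirections": "redirections.tsv",   # placeholder path
    "classification": ["classifications/type_classification"],  # one entry per oracle
    "path": "report/",                    # placeholder path
    "num_names_to_load": 4000000,         # optional, defaults to 0
    "sample_size": 100,                   # optional, defaults to 100
}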
Example no. 3
def main():
    args = parse_args()
    config = load_config(args.config,
                         ["wiki", "language_path", "wikidata", "redirections"],
                         defaults={
                             "num_names_to_load": 0,
                             "prefix": None,
                             "sample_size": 100
                         },
                         relative_to=args.relative_to)
    prefix = config.prefix or induce_wikipedia_prefix(config.wiki)

    collection = TypeCollection(config.wikidata, num_names_to_load=0)
    collection.load_blacklist(join(SCRIPT_DIR, "blacklist.json"))

    trie_index2indices = OffsetArray.load(join(config.language_path,
                                               "trie_index2indices"),
                                          compress=True)
    trie_index2indices_counts = OffsetArray(
        np.load(join(config.language_path, "trie_index2indices_counts.npy")),
        trie_index2indices.offsets)
    if exists(
            join(config.language_path,
                 "trie_index2indices_transition_values.npy")):
        trie_index2indices_transitions = OffsetArray(
            np.load(
                join(config.language_path,
                     "trie_index2indices_transition_values.npy")),
            np.load(
                join(config.language_path,
                     "trie_index2indices_transition_offsets.npy")),
        )
    else:
        trie_index2indices_transitions = None

    anchor_trie = marisa_trie.Trie().load(
        join(config.language_path, "trie.marisa"))
    wiki_trie = marisa_trie.RecordTrie('i').load(
        join(config.wikidata, "wikititle2wikidata.marisa"))
    redirections = load_redirections(config.redirections)

    seen = 0
    with open(args.out, "wt") as fout:
        try:
            for i, (article_name,
                    article) in tqdm(enumerate(iterate_articles(config.wiki))):
                if i == 5409:
                    continue
                fixed_article, article_qid = convert(
                    article_name,
                    article,
                    collection=collection,
                    anchor_trie=anchor_trie,
                    wiki_trie=wiki_trie,
                    trie_index2indices=trie_index2indices,
                    trie_index2indices_counts=trie_index2indices_counts,
                    trie_index2indices_transitions=
                    trie_index2indices_transitions,
                    redirections=redirections,
                    prefix=prefix)
                if fixed_article is False:
                    continue
                for paragraph in fixed_article:
                    for word, qids in paragraph:
                        if len(qids) > 0:
                            fout.write(word.rstrip() + "\t" +
                                       "\t".join(qids + [article_qid]) + "\n")
                        else:
                            fout.write(word.rstrip() + "\n")
                    fout.write("\n")
                seen += 1
                if seen >= config.sample_size:
                    break
        finally:
            fout.flush()
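The trie_index2indices arrays loaded above store one variable-length list of candidate entity ids per anchor as a flat values array plus an offsets array. A toy illustration of that ragged layout (this is not the project's OffsetArray implementation, whose exact offset convention is not shown on this page):

import numpy as np

# flat storage of three rows: [7, 9], [13], [5, 5, 5]
values = np.array([7, 9, 13, 5, 5, 5], dtype=np.int32)
offsets = np.array([2, 3, 6], dtype=np.int32)  # cumulative end index of each row

def row(index):
    start = 0 if index == 0 else int(offsets[index - 1])
    return values[start:int(offsets[index])]

print(row(0))  # [7 9]
print(row(2))  # [5 5 5]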
Example no. 4
def main():
    args = parse_args()
    config = load_config(args.config, [
        "wiki", "language_path", "wikidata", "redirections", "classification"
    ],
                         defaults={
                             "num_names_to_load": 0,
                             "prefix": None,
                             "sample_size": 100,
                             "wiki": None,
                             "fix_links": False,
                             "min_count": 0,
                             "min_percent": 0.0
                         },
                         relative_to=args.relative_to)
    if config.wiki is None:
        raise ValueError("must provide path to 'wiki' in config.")
    prefix = get_prefix(config)
    collection = TypeCollection(config.wikidata,
                                num_names_to_load=config.num_names_to_load,
                                prefix=prefix,
                                verbose=True)
    collection.load_blacklist(join(SCRIPT_DIR, "blacklist.json"))

    fname = config.wiki
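    # The remainder of this example: parse the sample articles into tags,
    # load per-relation AUC scores, gather the candidate entity ids that occur
    # in ambiguous tags and compact them into contiguous positions (id2pos),
    # keep only relations whose AUC exceeds 0.5, precompute the
    # relation-membership table (cached_satisfy), and then select a subset of
    # relations with the requested search strategy (greedy, beam, cem, or ga)
    # before dumping the picks to args.out.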
    test_tags = fix_and_parse_tags(config, collection, config.sample_size)
    aucs = load_aucs()
    ids = sorted(
        set([
            idx for doc_tags in test_tags for _, tag in doc_tags
            if tag is not None for idx in tag[2] if len(tag[2]) > 1
        ]))
    id2pos = {idx: k for k, idx in enumerate(ids)}
    # use reduced identity system:
    remapped_tags = []
    for doc_tags in test_tags:
        for text, tag in doc_tags:
            if tag is not None:
                remapped_tags.append(
                    (id2pos[tag[1]] if len(tag[2]) > 1 else tag[1],
                     np.array([id2pos[idx] for idx in tag[2]])
                     if len(tag[2]) > 1 else tag[2], tag[3]))
    test_tags = remapped_tags

    aucs = {key: value for key, value in aucs.items() if value > 0.5}
    print("%d relations to pick from with %d ids." % (len(aucs), len(ids)),
          flush=True)
    cached_satisfy = get_cached_satisfy(collection,
                                        aucs,
                                        ids,
                                        mmap=args.method == "greedy")
    del collection
    key2row = {key: k for k, key in enumerate(sorted(aucs.keys()))}

    if args.method == "greedy":
        picks, _ = beam_project(cached_satisfy,
                                key2row,
                                remapped_tags,
                                aucs,
                                ids,
                                beam_width=1,
                                penalty=args.penalty,
                                log=args.log)
    elif args.method == "beam":
        picks, _ = beam_project(cached_satisfy,
                                key2row,
                                remapped_tags,
                                aucs,
                                ids,
                                beam_width=args.beam_width,
                                penalty=args.penalty,
                                log=args.log)
    elif args.method == "cem":
        picks, _ = cem_project(cached_satisfy,
                               key2row,
                               remapped_tags,
                               aucs,
                               ids,
                               n_samples=args.samples,
                               penalty=args.penalty,
                               log=args.log)
    elif args.method == "ga":
        picks, _ = ga_project(cached_satisfy,
                              key2row,
                              remapped_tags,
                              aucs,
                              ids,
                              ngen=args.ngen,
                              n_samples=args.samples,
                              penalty=args.penalty,
                              log=args.log)
    else:
        raise ValueError("unknown method %r." % (args.method, ))
    with open(args.out, "wt") as fout:
        json.dump(picks, fout)
Example no. 5
def main():
    args = parse_args()
    if args.new_language_path == args.language_path:
        raise ValueError("new_language_path and language_path must be "
                         "different: cannot generate a fixed trie in "
                         "the same directory as the original trie.")

    c = TypeCollection(args.wikidata, num_names_to_load=0)
    c.load_blacklist(join(SCRIPT_DIR, "blacklist.json"))
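    # Overall flow: load the original anchor trie and its values/offsets/counts
    # arrays, repeatedly run fix() to re-point anchor links (tracking the extra
    # indirection through location_shift), report sample remappings and
    # deletions, prune anchors that no longer link to anything, remap every
    # offset array onto the pruned trie, and save the results together with an
    # old-to-new transition table under args.new_language_path.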
    original_values = np.load(
        join(args.language_path, "trie_index2indices_values.npy"))
    original_offsets = np.load(
        join(args.language_path, "trie_index2indices_offsets.npy"))
    original_counts = np.load(
        join(args.language_path, "trie_index2indices_counts.npy"))
    original_trie_path = join(args.language_path, 'trie.marisa')
    trie = marisa_trie.Trie().load(original_trie_path)
    initialize_globals(c)
    t0 = time.time()

    old_location_shift = None
    values, offsets, counts = original_values, original_offsets, original_counts
    for step in range(args.steps):
        anchor_length = get_trie_properties(trie, offsets, values)
        (offsets, values,
         counts), location_shift = fix(collection=c,
                                       offsets=offsets,
                                       values=values,
                                       counts=counts,
                                       anchor_length=anchor_length,
                                       num_category_link=8)
        if old_location_shift is not None:
            # see where newly shifted values are now pointing
            # to (extra indirection level):
            location_shift = location_shift[old_location_shift]
            location_shift[old_location_shift == -1] = -1
        old_location_shift = location_shift
        pre_reduced_values = values[location_shift]
        pre_reduced_values[location_shift == -1] = -1
        num_changes = int((pre_reduced_values != original_values).sum())
        change_volume = int(
            (original_counts[pre_reduced_values != original_values].sum()))
        print("step %d with %d changes, %d total links" %
              (step, num_changes, change_volume))
    pre_reduced_values = values[location_shift]
    pre_reduced_values[location_shift == -1] = -1
    t1 = time.time()
    num_changes = int((pre_reduced_values != original_values).sum())
    print("Done with link fixing in %.3fs, with %d changes." %
          (t1 - t0, num_changes))

    # show some remappings:
    np.random.seed(1234)
    num_samples = 10
    samples = np.random.choice(np.where(
        np.logical_and(
            np.logical_and(pre_reduced_values != original_values,
                           pre_reduced_values != -1),
            original_values != -1))[0],
                               size=num_samples,
                               replace=False)
    print("Sample fixes:")
    for index in samples:
        print("   %r (%d) -> %r (%d)" %
              (c.get_name(int(
                  original_values[index])), int(original_values[index]),
               c.get_name(int(pre_reduced_values[index])),
               int(pre_reduced_values[index])))
    print("")

    samples = np.random.choice(
        np.where(OffsetArray(values, offsets).edges() == 0)[0],
        size=num_samples,
        replace=False)
    print("Sample deletions:")
    for index in samples:
        print("   %r" % (trie.restore_key(int(index))))

    # prune out anchors where there are no more linked items:
    print("Removing empty anchors from trie...")
    t0 = time.time()
    non_empty_offsets = np.where(OffsetArray(values, offsets).edges() != 0)[0]
    fixed_trie = filter_trie(trie, non_empty_offsets)

    contexts_found = true_exists(
        join(args.language_path, "trie_index2contexts_values.npy"))
    if contexts_found:
        contexts_values = np.load(
            join(args.language_path, "trie_index2contexts_values.npy"))
        contexts_offsets = np.load(
            join(args.language_path, "trie_index2contexts_offsets.npy"))
        contexts_counts = np.load(
            join(args.language_path, "trie_index2contexts_counts.npy"))

    to_port = [(offsets, values, counts),
               (original_offsets, pre_reduced_values, original_values)]
    if contexts_found:
        to_port.append((contexts_offsets, contexts_values, contexts_counts))

    ported = remap_trie_offset_array(trie, fixed_trie, to_port)
    offsets, values, counts = ported[0]
    original_offsets, pre_reduced_values, original_values = ported[1]
    t1 = time.time()
    print("Removed %d empty anchors from trie in %.3fs" % (
        len(trie) - len(fixed_trie),
        t1 - t0,
    ))

    print("Saving...")
    makedirs(args.new_language_path, exist_ok=True)

    np.save(join(args.new_language_path, "trie_index2indices_values.npy"),
            values)
    np.save(join(args.new_language_path, "trie_index2indices_offsets.npy"),
            offsets)
    np.save(join(args.new_language_path, "trie_index2indices_counts.npy"),
            counts)
    if contexts_found:
        contexts_offsets, contexts_values, contexts_counts = ported[2]
        np.save(join(args.new_language_path, "trie_index2contexts_values.npy"),
                contexts_values)
        np.save(
            join(args.new_language_path, "trie_index2contexts_offsets.npy"),
            contexts_offsets)
        np.save(join(args.new_language_path, "trie_index2contexts_counts.npy"),
                contexts_counts)
    new_trie_path = join(args.new_language_path, 'trie.marisa')
    fixed_trie.save(new_trie_path)

    transition = np.vstack([original_values, pre_reduced_values]).T
    np.save(
        join(args.new_language_path,
             "trie_index2indices_transition_values.npy"), transition)
    np.save(
        join(args.new_language_path,
             "trie_index2indices_transition_offsets.npy"), original_offsets)
    print("Done.")
Example no. 6
def main():
    args = parse_args()
    should_export = args.export_classification is not None
    if should_export and len(args.export_classification) != len(
            args.classifiers):
        raise ValueError("Must have as many export filenames as classifiers.")
    collection = TypeCollection(args.wikidata,
                                num_names_to_load=args.num_names_to_load,
                                language_path=args.language_path,
                                cache=args.use_cache)
    if args.interactive:
        alert_failure = enter_or_quit
    else:
        def alert_failure():
            sys.exit(1)

    while True:
        try:
            collection.load_blacklist(join(SCRIPT_DIR, "blacklist.json"))
        except (ValueError, ) as e:
            print("Issue reading blacklist, please fix.")
            print(str(e))
            alert_failure()
            continue

        classifications = []
        for class_idx, classifier_fname in enumerate(args.classifiers):
            while True:
                try:
                    classifier = reload_module(classifier_fname)
                except ALLOWED_IMPORT_ERRORS as e:
                    print("issue reading %r, please fix." %
                          (classifier_fname, ))
                    print(str(e))
                    traceback.print_exc(file=sys.stdout)
                    alert_failure()
                    continue

                try:
                    t0 = time.time()
                    classification = classifier.classify(collection)
                    classifications.append(classification)
                    if class_idx == len(args.classifiers) - 1:
                        collection.reset_cache()
                    t1 = time.time()
                    print("classification took %.3fs" % (t1 - t0, ))
                except ALLOWED_RUNTIME_ERRORS as e:
                    print("issue running %r, please fix." %
                          (classifier_fname, ))
                    print(str(e))
                    traceback.print_exc(file=sys.stdout)
                    alert_failure()
                    continue
                break
        try:
            # show cardinality for each truth table:
            if args.interactive:
                mega_other_class = None
                for classification in classifications:
                    for classname in sorted(classification.keys()):
                        print(
                            "%r: %d members" %
                            (classname, int(classification[classname].sum())))
                    print("")
                    summary = {}
                    for classname, truth_table in classification.items():
                        (members, ) = np.where(truth_table)
                        summary[classname] = [
                            collection.get_name(int(member))
                            for member in members[:20]
                        ]
                    print(json.dumps(summary, indent=4))

                    other_class = get_other_class(classification)
                    if other_class.sum() > 0:
                        # there are missing items:
                        to_report = (classifier.class_report if hasattr(
                            classifier, "class_report") else [
                                wprop.SUBCLASS_OF, wprop.INSTANCE_OF,
                                wprop.OCCUPATION, wprop.CATEGORY_LINK
                            ])
                        collection.class_report(to_report,
                                                other_class,
                                                name="Other")
                        if mega_other_class is None:
                            mega_other_class = other_class
                        else:
                            mega_other_class = np.logical_and(
                                mega_other_class, other_class)
                if len(classifications) > 1:
                    if mega_other_class.sum() > 0:
                        # there are missing items:
                        to_report = [
                            wprop.SUBCLASS_OF, wprop.INSTANCE_OF,
                            wprop.OCCUPATION, wprop.CATEGORY_LINK
                        ]
                        collection.class_report(to_report,
                                                mega_other_class,
                                                name="Other-combined")
            if should_export:
                assert (len(classifications) == len(
                    args.export_classification)), (
                        "classification outputs missing for export.")
                for classification, savename in zip(
                        classifications, args.export_classification):
                    export_classification(classification, savename)
        except KeyboardInterrupt as e:
            pass

        if args.interactive:
            enter_or_quit()
        else:
            break
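Each entry in args.classifiers is loaded as a Python module and must expose classify(collection); from the way this example consumes its output, that function returns a dict mapping a class name to a boolean membership array over the collection, and may optionally define a class_report list of properties. A minimal hypothetical classifier module along those lines (the collection attributes used inside classify are placeholders):

import numpy as np

# hypothetical my_classifier.py, reconstructed from how Example no. 6
# consumes classifier modules; not an actual module from the project.

def classify(collection):
    num_items = len(collection.ids)  # placeholder way to size the truth tables
    is_event = np.zeros(num_items, dtype=np.bool_)
    # ... mark members of the class using the collection's relations ...
    return {"event": is_event, "not_event": ~is_event}

# optional: properties to report for unclassified ("Other") items, e.g.
# class_report = [wprop.INSTANCE_OF, wprop.CATEGORY_LINK]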