Ejemplo n.º 1
0
def build_minimal_wordnet_graph(wnids, multi_path=False):
    """Build a directed graph linking each WordNet ID to its hypernym ancestors.

    Every ID in ``wnids`` is added as a labelled node. Starting from its
    synset, hypernyms are walked breadth-first and an edge is added from each
    hypernym to the synset it was reached from.

    Args:
        wnids: Iterable of WordNet ID strings; each is expected to end up as
            a leaf of the graph.
        multi_path: When False, only the first hypernym of every visited
            synset is followed. When True, all hypernyms are followed.

    Returns:
        The populated ``nx.DiGraph``.

    Raises:
        AssertionError: If an ID from ``wnids`` has outgoing edges right
            after its own ancestors have been added.
    """
    graph = nx.DiGraph()

    for wnid in wnids:
        graph.add_node(wnid)
        synset = wnid_to_synset(wnid)
        set_node_label(graph, synset)

        # Hardcode 'girl' to not be child of 'woman'.
        if wnid == "n10129825":
            if not multi_path:
                # child of 'male' (sibling to 'male_child')
                graph.add_edge("n09624168", "n10129825")
            # child of 'female'
            # NOTE(review): this edge is added for both values of
            # `multi_path`, so single-path mode gives 'girl' two parents --
            # confirm this is intended.
            graph.add_edge("n09619168", "n10129825")
            continue

        # Breadth-first walk up the hypernym hierarchy, one level at a time.
        frontier = [synset]
        while frontier:
            upcoming = []
            for current in frontier:
                set_node_label(graph, current)
                for hypernym in current.hypernyms():
                    graph.add_edge(
                        synset_to_wnid(hypernym), synset_to_wnid(current))
                    upcoming.append(hypernym)

                    # Single-path mode keeps only the first hypernym.
                    if not multi_path:
                        break
            frontier = upcoming

        children = [
            (key, wnid_to_synset(key).name()) for key in graph.succ[wnid]
        ]
        assert (
            not children
        ), f"Node {wnid} ({synset.name()}) is not a leaf. Children: {children}"
    return graph
Ejemplo n.º 2
0
def pick_unseen_hypernym(G, common_hypernyms):
    """Pick a common hypernym whose WordNet ID is not yet a node of ``G``.

    Candidates are tried in the order chosen by ``deepest_synset``: the
    current pick is discarded whenever its ID is already present in
    ``G.nodes`` and the next pick is taken from what remains.

    Args:
        G: Graph whose ``nodes`` are checked for membership.
        common_hypernyms: Non-empty collection of candidate synsets. It is
            not modified; the search runs on a private copy.

    Returns:
        The first candidate whose ID is absent from ``G``, or ``None`` when
        every candidate is already in the graph.

    Raises:
        AssertionError: If ``common_hypernyms`` is empty.
    """
    assert len(common_hypernyms) > 0

    # Copy first: in-place set subtraction on the argument would silently
    # shrink the caller's set.
    remaining = set(common_hypernyms)
    while remaining:
        candidate = deepest_synset(remaining)
        if synset_to_wnid(candidate) not in G.nodes:
            return candidate
        remaining.discard(candidate)
    return None
Ejemplo n.º 3
0
def add_node_to_graph(G, candidate, children):
    """Insert ``candidate`` into ``G`` between the root and ``children``.

    The candidate synset becomes a labelled node with an edge to every entry
    of ``children`` and an incoming edge from the graph's root.

    Args:
        G: Graph to modify in place.
        candidate: Synset to insert as a new node.
        children: Iterable of node IDs that become children of the new node.
    """
    # Look the root up before inserting anything; presumably the new node
    # would otherwise be parentless itself -- verify against get_root.
    current_root = get_root(G)

    node_id = synset_to_wnid(candidate)
    G.add_node(node_id)
    set_node_label(G, candidate)

    G.add_edges_from((node_id, child) for child in children)
    G.add_edge(current_root, node_id)
Ejemplo n.º 4
0
def set_node_label(G, synset):
    """Store the synset's name as the ``"label"`` attribute of its node.

    Args:
        G: Graph holding the node, keyed by the synset's WordNet ID.
        synset: Synset that supplies both the node ID and the label text.
    """
    node_id = synset_to_wnid(synset)
    labels = {node_id: synset_to_name(synset)}
    nx.set_node_attributes(G, labels, name="label")
Ejemplo n.º 5
0
def build_induced_graph(
    wnids,
    checkpoint,
    model=None,
    linkage="ward",
    affinity="euclidean",
    branching_factor=2,
    dataset="CIFAR10",
    state_dict=None,
):
    """Build a hierarchy over ``wnids`` by clustering per-class centers.

    One center per class is loaded from ``state_dict``, ``checkpoint`` or
    ``model`` (first one that is truthy, in that order). The centers are
    clustered agglomeratively, and every merge reported by the clustering
    becomes an inner node whose synset is chosen by ``get_wordnet_meaning``.

    Args:
        wnids: Sequence of WordNet IDs, one per class; these are the leaves.
        checkpoint: Checkpoint passed to ``get_centers_from_checkpoint``.
        model: Model passed to ``get_centers_from_model`` when neither
            ``state_dict`` nor ``checkpoint`` is given.
        linkage: Linkage criterion forwarded to ``AgglomerativeClustering``.
        affinity: Distance metric forwarded to ``AgglomerativeClustering``.
        branching_factor: Forwarded as ``n_clusters``.
        dataset: Dataset name, used for ``get_centers_from_model`` and in the
            class-count error message.
        state_dict: State dict passed to ``get_centers_from_state_dict``.

    Returns:
        An ``nx.DiGraph`` with edges from each parent to its children.

    Raises:
        AssertionError: If no source for the centers is given, if the number
            of centers differs from ``len(wnids)``, or if the resulting graph
            does not have exactly one root.
    """
    num_classes = len(wnids)
    assert (
        checkpoint or model or state_dict
    ), "Need to specify either `checkpoint` or `model` or `state_dict`."
    if state_dict:
        centers = get_centers_from_state_dict(state_dict)
    elif checkpoint:
        centers = get_centers_from_checkpoint(checkpoint)
    else:
        centers = get_centers_from_model(model, num_classes, dataset)
    assert num_classes == centers.size(0), (
        f"The model FC supports {centers.size(0)} classes. However, the dataset"
        f" {dataset} features {num_classes} classes. Try passing the "
        "`--dataset` with the right number of classes.")

    # The clustering below runs on the CPU copy of the centers.
    if centers.is_cuda:
        centers = centers.cpu()

    G = nx.DiGraph()

    # add leaves
    for wnid in wnids:
        G.add_node(wnid)
        set_node_label(G, wnid_to_synset(wnid))

    # add rest of tree
    clustering_kwargs = {"linkage": linkage, "n_clusters": branching_factor}
    try:
        clusterer = AgglomerativeClustering(
            affinity=affinity, **clustering_kwargs)
    except TypeError:
        # Newer scikit-learn releases renamed `affinity` to `metric` and
        # reject the old keyword; retry with the new name.
        clusterer = AgglomerativeClustering(
            metric=affinity, **clustering_kwargs)
    children = clusterer.fit(centers).children_

    # Maps the index of a merge to the wnid of the inner node created for it.
    index_to_wnid = {}

    for index, pair in enumerate(map(tuple, children)):
        child_wnids = []
        child_synsets = []
        for child in pair:
            # Values below num_classes index the leaves directly; larger
            # values refer to the node created by an earlier merge.
            if child < num_classes:
                child_wnid = wnids[child]
            else:
                child_wnid = index_to_wnid[child - num_classes]
            child_wnids.append(child_wnid)
            child_synsets.append(wnid_to_synset(child_wnid))

        parent = get_wordnet_meaning(G, child_synsets)
        parent_wnid = synset_to_wnid(parent)
        G.add_node(parent_wnid)
        set_node_label(G, parent)
        index_to_wnid[index] = parent_wnid

        for child_wnid in child_wnids:
            G.add_edge(parent_wnid, child_wnid)

    roots = list(get_roots(G))
    assert len(roots) == 1, roots
    return G