def build_minimal_wordnet_graph(wnids, multi_path=False):
    """Build a WordNet hypernym graph whose leaves are ``wnids``.

    For every wnid the synset's hypernym chain is walked breadth-first and an
    edge ``hypernym -> hyponym`` is added for each step, labelling every
    visited node. With ``multi_path=False`` only the first entry of
    ``synset.hypernyms()`` is followed at each level; with ``multi_path=True``
    all hypernyms are followed, so a node may get several parents.

    The wnid ``n10129825`` ('girl') is special-cased: it is attached directly
    to hard-coded parent wnids instead of walking its WordNet hypernyms.

    Args:
        wnids: Iterable of WordNet ids that must all end up as leaves.
        multi_path: Follow every hypernym instead of only the first one.

    Returns:
        The populated ``nx.DiGraph``.

    Raises:
        AssertionError: If any requested wnid has children in the final
            graph, i.e. one requested wnid is a hypernym of another.
    """
    from collections import deque

    wnids = list(wnids)  # iterated twice: once to build, once to validate
    G = nx.DiGraph()

    for wnid in wnids:
        G.add_node(wnid)
        synset = wnid_to_synset(wnid)
        set_node_label(G, synset)

        if wnid == "n10129825":  # hardcode 'girl' to not be child of 'woman'
            if not multi_path:
                # child of 'male' (sibling to 'male_child')
                G.add_edge("n09624168", "n10129825")
            # child of 'female'.
            # NOTE(review): this edge is added in both modes, so in single-path
            # mode 'girl' ends up with two parents ('male' and 'female'), and
            # neither parent node is labelled or linked upward here. Behavior
            # kept as-is; confirm intent.
            G.add_edge("n09619168", "n10129825")
            continue

        queue = deque([synset])
        while queue:
            current = queue.popleft()
            set_node_label(G, current)
            for hypernym in current.hypernyms():
                G.add_edge(synset_to_wnid(hypernym), synset_to_wnid(current))
                queue.append(hypernym)
                if not multi_path:
                    break  # single-path mode: follow only the first hypernym

    # Validate only once the whole graph is built: a wnid processed early can
    # gain children from a wnid processed later, which a per-iteration check
    # would miss depending on the order of ``wnids``.
    for wnid in wnids:
        children = [(key, wnid_to_synset(key).name()) for key in G.succ[wnid]]
        assert not children, (
            f"Node {wnid} ({wnid_to_synset(wnid).name()}) is not a leaf. "
            f"Children: {children}"
        )
    return G
def pick_unseen_hypernym(G, common_hypernyms):
    """Return the deepest candidate synset whose wnid is not yet a node of ``G``.

    Candidates are tried in the order produced by repeatedly calling
    ``deepest_synset`` on the not-yet-rejected candidates. Any candidate whose
    wnid is already in ``G`` is skipped.

    Args:
        G: Graph whose existing nodes must not be reused.
        common_hypernyms: Non-empty collection of candidate synsets. It is not
            modified.

    Returns:
        The chosen synset, or ``None`` if every candidate is already in ``G``.
    """
    assert len(common_hypernyms) > 0
    # Work on a private copy: rejecting candidates must not shrink the
    # caller's collection (the in-place ``-=`` used to do exactly that).
    remaining = set(common_hypernyms)
    while remaining:
        candidate = deepest_synset(remaining)
        if synset_to_wnid(candidate) not in G.nodes:
            return candidate
        remaining.discard(candidate)
    return None
def add_node_to_graph(G, candidate, children):
    """Insert ``candidate`` as a new parent of ``children`` under the root of ``G``.

    The root is looked up before ``G`` is modified. The new node is labelled
    via ``set_node_label``, linked to every wnid in ``children``, and finally
    hung below that root.

    Args:
        G: Graph, modified in place.
        candidate: Synset for the new node.
        children: Iterable of wnids that become children of the new node.
    """
    root_wnid = get_root(G)
    parent_wnid = synset_to_wnid(candidate)

    G.add_node(parent_wnid)
    set_node_label(G, candidate)
    G.add_edges_from((parent_wnid, child) for child in children)
    G.add_edge(root_wnid, parent_wnid)
def set_node_label(G, synset):
    """Store the readable name of ``synset`` as the ``"label"`` attribute of its node.

    The node id is ``synset_to_wnid(synset)`` and the value is
    ``synset_to_name(synset)``. Per the NetworkX documentation,
    ``set_node_attributes`` skips keys that are not nodes of ``G``.
    """
    wnid = synset_to_wnid(synset)
    label = synset_to_name(synset)
    nx.set_node_attributes(G, {wnid: label}, name="label")
def _make_agglomerative_clustering(linkage, n_clusters, metric):
    """Create an ``AgglomerativeClustering`` that works across scikit-learn versions.

    scikit-learn 1.2 renamed ``affinity`` to ``metric`` and 1.4 removed
    ``affinity``; releases before 1.2 only accept ``affinity``.
    """
    try:
        return AgglomerativeClustering(
            linkage=linkage, n_clusters=n_clusters, metric=metric
        )
    except TypeError:  # scikit-learn < 1.2: no ``metric`` keyword
        return AgglomerativeClustering(
            linkage=linkage, n_clusters=n_clusters, affinity=metric
        )


def build_induced_graph(
    wnids,
    checkpoint,
    model=None,
    linkage="ward",
    affinity="euclidean",
    branching_factor=2,
    dataset="CIFAR10",
    state_dict=None,
):
    """Build a class hierarchy by agglomeratively clustering class centers.

    Centers (one row per class) are taken from ``state_dict``, else from
    ``checkpoint``, else from ``model``. Each merge in the clustering becomes
    an inner node whose synset is chosen by ``get_wordnet_meaning``. The
    leaves are ``wnids``.

    Args:
        wnids: Class WordNet ids, in the same order as the rows of the centers
            (row ``i`` of the clustering input maps to ``wnids[i]``).
        checkpoint: Source of the centers when ``state_dict`` is falsy.
        model: Source of the centers when both ``state_dict`` and
            ``checkpoint`` are falsy.
        linkage: Linkage criterion for ``AgglomerativeClustering``.
        affinity: Distance metric for ``AgglomerativeClustering`` (passed as
            ``metric`` on newer scikit-learn).
        branching_factor: Passed as ``n_clusters`` to ``AgglomerativeClustering``.
        dataset: Dataset name; used when loading centers from ``model`` and in
            the size-mismatch message.
        state_dict: Preferred source of the centers.

    Returns:
        An ``nx.DiGraph`` whose leaves are ``wnids``.

    Raises:
        AssertionError: If no center source is given, if the number of centers
            differs from ``len(wnids)``, or if the graph does not end up with
            exactly one root.
    """
    num_classes = len(wnids)
    assert (
        checkpoint or model is not None or state_dict
    ), "Need to specify either `checkpoint` or `model` or `state_dict`."

    if state_dict:
        centers = get_centers_from_state_dict(state_dict)
    elif checkpoint:
        centers = get_centers_from_checkpoint(checkpoint)
    else:
        centers = get_centers_from_model(model, num_classes, dataset)

    assert num_classes == centers.size(0), (
        f"The model FC supports {centers.size(0)} classes. However, the dataset"
        f" {dataset} features {num_classes} classes. Try passing the "
        "`--dataset` with the right number of classes."
    )
    if centers.is_cuda:
        centers = centers.cpu()  # scikit-learn needs host memory

    G = nx.DiGraph()

    # add leaves
    for wnid in wnids:
        G.add_node(wnid)
        set_node_label(G, wnid_to_synset(wnid))

    # add rest of tree
    clustering = _make_agglomerative_clustering(
        linkage, branching_factor, affinity
    ).fit(centers)

    # children_[i] holds the two nodes merged at step i. Ids below num_classes
    # are original samples (leaves); id num_classes + j is the node created
    # at merge step j.
    index_to_wnid = {}
    for index, pair in enumerate(clustering.children_):
        child_wnids = [
            wnids[child] if child < num_classes else index_to_wnid[child - num_classes]
            for child in pair
        ]
        child_synsets = [wnid_to_synset(child_wnid) for child_wnid in child_wnids]

        parent = get_wordnet_meaning(G, child_synsets)
        parent_wnid = synset_to_wnid(parent)
        G.add_node(parent_wnid)
        set_node_label(G, parent)
        index_to_wnid[index] = parent_wnid

        for child_wnid in child_wnids:
            G.add_edge(parent_wnid, child_wnid)

    assert len(list(get_roots(G))) == 1, list(get_roots(G))
    return G