def match_wnid_leaves(wnids, G, tree_name):
    """Normalize ``wnids`` and compare them against the leaves of ``G``.

    :param wnids: iterable of wnid strings; surrounding whitespace is stripped.
    :param G: graph whose leaves are obtained with ``get_leaves(G)``.
    :param tree_name: accepted for interface compatibility; not read here.
    :returns: ``(leaves_seen, wnid_set)`` where ``leaves_seen`` is whatever
        ``get_seen_wnids(wnid_set, leaves)`` returns and ``wnid_set`` is the
        set of stripped wnids.
    """
    stripped_wnids = {raw_wnid.strip() for raw_wnid in wnids}
    graph_leaves = get_leaves(G)
    seen = get_seen_wnids(stripped_wnids, graph_leaves)
    return seen, stripped_wnids
def get_color_info(
    G, color, color_leaves, color_path_to=None, color_nodes=(), theme="regular"
):
    """Mapping from node to color information.

    :param G: graph exposing ``nodes`` (ID -> attribute dict) and ``pred``
        (ID -> predecessors); presumably a networkx DiGraph -- confirm callers.
    :param color: color assigned to highlighted nodes.
    :param color_leaves: if truthy, every leaf from ``get_leaves(G)`` is
        highlighted with ``color``.
    :param color_path_to: label or ID of a leaf. When it matches a leaf, every
        node is greyed out and only the path from that leaf up to
        ``get_root(G)`` is highlighted (this overrides the other options).
    :param color_nodes: labels or IDs of nodes to highlight.
    :param theme: ``"minimal"`` or ``"dark"``; any other value selects the
        default light colors. Copied into every node entry.
    :returns: dict with ``"bg"`` and ``"text_rect"`` theme colors plus one
        entry per node ID.
    """
    nodes = {}

    theme_to_bg = {"minimal": "#EEEEEE", "dark": "#111111"}
    nodes["bg"] = theme_to_bg.get(theme, "#FFFFFF")
    theme_to_text_rect = {
        "minimal": "rgba(0,0,0,0)",
        "dark": "rgba(17,17,17,0.8)",
    }
    nodes["text_rect"] = theme_to_text_rect.get(theme, "rgba(255,255,255,0.8)")

    leaves = list(get_leaves(G))

    # Base pass: explicitly requested nodes are highlighted, all others gray.
    for node_id, node in G.nodes.items():
        if node.get("label", "") in color_nodes or node_id in color_nodes:
            nodes[node_id] = {"color": color, "highlighted": True, "theme": theme}
        else:
            nodes[node_id] = {"color": "gray", "theme": theme}

    # Must run AFTER the base pass: leaves are also in G.nodes, so running it
    # first (as before) let the base pass overwrite it and made color_leaves
    # a no-op.
    if color_leaves:
        for leaf in leaves:
            nodes[leaf] = {"color": color, "highlighted": True, "theme": theme}

    root = get_root(G)

    # Find the first leaf whose label or ID matches color_path_to.
    target = None
    for leaf in leaves:
        if G.nodes[leaf].get("label", "") == color_path_to or leaf == color_path_to:
            target = leaf
            break

    if target is not None:
        # Path mode: grey out everything, then highlight leaf -> root.
        for node_id in G.nodes:
            nodes[node_id] = {
                "color": "#cccccc",
                "color_incident_edge": True,
                "highlighted": False,
                "theme": theme,
            }
        while target != root:
            nodes[target] = {
                "color": color,
                "color_incident_edge": True,
                "highlighted": True,
                "theme": theme,
            }
            # Step to the first predecessor (a tree node has exactly one).
            target = list(G.pred[target])[0]
        nodes[root] = {"color": color, "highlighted": True, "theme": theme}

    return nodes
def build_class_mappings(self):
    """Map original class indices to this node's child indices and back.

    Each child index ``i`` of ``self.succ`` is associated with the class
    index (via ``self.wnid_to_class_index``) of every leaf under that child.

    :returns: ``(old_to_new, new_to_old)``, each mapping an index to a list of
        indices. A leaf node returns two empty dicts. When ``self.has_other``
        is truthy, every original class index in
        ``range(self.num_original_classes)`` not reached through any child is
        mapped to an extra index equal to ``self.num_children``.
    """
    if self.is_leaf():
        return {}, {}

    # defaultdict(list) rather than a lambda factory: same behavior, but the
    # returned mappings can be pickled.
    old_to_new = defaultdict(list)
    new_to_old = defaultdict(list)

    for new_index, child in enumerate(self.succ):
        for leaf in get_leaves(self.tree.G, child):
            old_index = self.wnid_to_class_index(leaf)
            old_to_new[old_index].append(new_index)
            new_to_old[new_index].append(old_index)

    if not self.has_other:
        return old_to_new, new_to_old

    # Every class not covered by a child goes to the "other" slot.
    new_index = self.num_children
    for old in range(self.num_original_classes):
        if old not in old_to_new:
            old_to_new[old].append(new_index)
            new_to_old[new_index].append(old)
    return old_to_new, new_to_old
def build_tree(
    G,
    root,
    parent="null",
    color_info=(),
    force_labels_left=(),
    include_leaf_images=False,
    dataset=None,
    image_resize_factor=1,
    include_fake_sublabels=False,
    include_fake_labels=False,
    node_to_conf=None,
):
    """Recursively build a nested dict describing the subtree at ``root``.

    :param G: graph; ``G.succ[root]`` gives children and ``G.nodes[root]`` the
        node's attribute dict.
    :param root: ID of the node to build.
    :param parent: ID of the parent node (``"null"`` for the top-level call).
    :param color_info dict[str, dict]: mapping from node labels or IDs to
        color information. This is by default just a key called 'color'
    :param force_labels_left: labels whose node gets ``force_text_on_left``.
    :param include_leaf_images: if truthy, leaves get an ``"image"`` entry
        holding a base64 JPEG data URI looked up in ``dataset`` by label.
    :param dataset: passed to ``get_class_image_from_dataset``.
    :param image_resize_factor: multiplier for the leaf image width/height.
    :param include_fake_sublabels: keep the sublabel of IDs starting with "f".
    :param include_fake_labels: keep "("-prefixed labels of IDs starting
        with "f".
    :param node_to_conf: mapping from node ID to ``{key: value}`` overrides
        applied with ``set_dot_notation``. IDs missing from the mapping get no
        overrides (a ``defaultdict``'s factory is still honored). ``None``
        means no overrides at all.
    :returns: dict with ``sublabel``, ``label``, ``parent``, ``children``,
        ``alt`` and ``id`` keys plus any color, image and override entries.
    """
    # Previously defaulted to {} and indexed directly, so omitting it raised
    # KeyError for every node.
    if node_to_conf is None:
        node_to_conf = {}

    children = [
        build_tree(
            G,
            child,
            root,
            color_info=color_info,
            force_labels_left=force_labels_left,
            include_leaf_images=include_leaf_images,
            dataset=dataset,
            image_resize_factor=image_resize_factor,
            include_fake_sublabels=include_fake_sublabels,
            include_fake_labels=include_fake_labels,
            node_to_conf=node_to_conf,
        )
        for child in G.succ[root]
    ]

    _node = G.nodes[root]
    label = _node.get("label", "")
    sublabel = root

    # WARNING: hacky -- IDs starting with "f" are treated as fake wnids; this
    # heuristic will have to be changed.
    if root.startswith("f") and label.startswith("(") and not include_fake_labels:
        label = ""
    if root.startswith("f") and not include_fake_sublabels:
        sublabel = ""

    # Only walk the subtree's leaves when no explicit alt text exists;
    # computing it eagerly for every node made the recursion quadratic.
    if "alt" in _node:
        alt = _node["alt"]
    else:
        alt = ", ".join(map(wnid_to_name, get_leaves(G, root=root)))

    node = {
        "sublabel": sublabel,
        "label": label,
        "parent": parent,
        "children": children,
        "alt": alt,
        "id": root,
    }

    # ID-keyed color info is applied last so it wins over label-keyed info.
    if label in color_info:
        node.update(color_info[label])
    if root in color_info:
        node.update(color_info[root])

    if label in force_labels_left:
        node["force_text_on_left"] = True

    is_leaf = not children
    if include_leaf_images and is_leaf:
        image_info = _leaf_image_info(dataset, label, image_resize_factor)
        # A failed lookup no longer skips the node_to_conf overrides below.
        if image_info is not None:
            node["image"] = image_info

    try:
        conf = node_to_conf[root]
    except KeyError:
        conf = {}
    for key, value in conf.items():
        set_dot_notation(node, key, value)
    return node


def _leaf_image_info(dataset, label, image_resize_factor):
    """Build the ``node["image"]`` entry for ``label``; None if lookup fails."""
    try:
        image = get_class_image_from_dataset(dataset, label)
    except UserWarning as e:
        print(e)
        return None
    base64_encode = image_to_base64_encode(image, format="jpeg")
    image_href = f"data:image/jpeg;base64,{base64_encode.decode('utf-8')}"
    # Assumes a PIL image: Image.size is (width, height). The old code
    # unpacked it as (height, width), swapping the two dimensions.
    image_width, image_height = image.size
    return {
        "href": image_href,
        "width": image_width * image_resize_factor,
        "height": image_height * image_resize_factor,
    }
def get_leaves(self):
    """Return the leaves below this node's wnid in the owning tree's graph.

    Delegates to the module-level ``get_leaves`` with ``self.tree.G`` as the
    graph and ``self.wnid`` as the subtree root.
    """
    graph = self.tree.G
    return get_leaves(graph, root=self.wnid)