def match_wnid_leaves(wnids, G, tree_name):
    """Normalize ``wnids`` and check them against the leaves of ``G``.

    :param wnids: iterable of id strings; surrounding whitespace is stripped
        from each before use.
    :param G: graph whose leaves (from ``get_leaves(G)``) are compared.
    :param tree_name: not used by this function; kept for interface
        compatibility with existing callers.
    :return: ``(leaves_seen, wnid_set)`` -- ``wnid_set`` is the set of
        stripped ids, ``leaves_seen`` is the result of
        ``get_seen_wnids(wnid_set, get_leaves(G))``.
    """
    wnid_set = {raw_wnid.strip() for raw_wnid in wnids}
    return get_seen_wnids(wnid_set, get_leaves(G)), wnid_set
def get_color_info(G, color, color_leaves, color_path_to=None, color_nodes=()):
    """Build a mapping from node id to the color attributes used to draw it.

    Nodes are marked in three passes; a later pass overwrites an earlier
    entry for the same node:

    1. every leaf of ``G``, when ``color_leaves`` is truthy;
    2. every node whose ``'label'`` attribute or id is in ``color_nodes``;
    3. the path from the first leaf (in ``get_leaves`` order) whose label or
       id equals ``color_path_to`` up to the root, following the first
       predecessor of each node. Path nodes below the root also get
       ``'color_incident_edge': True``; the root gets only ``'color'``.

    :return: dict mapping node id to a fresh ``{'color': color, ...}`` dict.
    """
    leaves = list(get_leaves(G))
    color_map = {}

    if color_leaves:
        color_map.update({leaf: {'color': color} for leaf in leaves})

    for node_id, attrs in G.nodes.items():
        if attrs.get('label', '') in color_nodes or node_id in color_nodes:
            color_map[node_id] = {'color': color}

    root = get_root(G)
    current = next(
        (leaf for leaf in leaves
         if G.nodes[leaf].get('label', '') == color_path_to
         or leaf == color_path_to),
        None,
    )
    if current is None:
        return color_map

    # Climb towards the root through the first listed predecessor.
    while current != root:
        color_map[current] = {'color': color, 'color_incident_edge': True}
        current = list(G.pred[current])[0]
    color_map[root] = {'color': color}
    return color_map
def build_class_mappings(self):
    """Map original class indices to this node's child indices, and back.

    Each leaf under the ``i``-th child (from ``self.get_children()``) has its
    original class index (``self.wnid_to_class_index(leaf)``) mapped to ``i``.
    When ``self.has_other`` is truthy, every index in
    ``range(self.num_original_classes)`` not covered by any child is mapped to
    the extra index ``self.num_children``.

    :return: ``(old_to_new, new_to_old)``, both ``defaultdict(list)`` mapping
        an index to the list of indices it corresponds to.
    """
    # defaultdict(list) rather than defaultdict(lambda: []): same behavior,
    # idiomatic, and the result stays picklable.
    old_to_new = defaultdict(list)
    new_to_old = defaultdict(list)
    for new_index, child in enumerate(self.get_children()):
        for leaf in get_leaves(self.G, child):
            old_index = self.wnid_to_class_index(leaf)
            old_to_new[old_index].append(new_index)
            new_to_old[new_index].append(old_index)

    if not self.has_other:
        return old_to_new, new_to_old

    # Membership tests on a defaultdict do not create keys, so this only
    # picks up classes that no child claimed above.
    other_index = self.num_children
    for old_index in range(self.num_original_classes):
        if old_index not in old_to_new:
            old_to_new[old_index].append(other_index)
            new_to_old[other_index].append(old_index)
    return old_to_new, new_to_old
def get_leaves(self):
    """Return the leaves of ``self.G`` beneath this node (``self.wnid``).

    Delegates to the module-level ``get_leaves`` function; inside a class
    body the bare name resolves to the module global, not to this method.
    """
    graph, subtree_root = self.G, self.wnid
    return get_leaves(graph, subtree_root)
def build_tree(G, root, parent='null', color_info=(), force_labels_left=(),
               include_leaf_images=False, dataset=None, image_resize_factor=1,
               include_fake_sublabels=False, include_fake_labels=False,
               node_to_conf=None):
    """Recursively convert the subtree of ``G`` rooted at ``root`` to dicts.

    :param G: graph providing ``succ`` (children) and ``nodes`` (attributes).
    :param root: id of the node to convert.
    :param parent: id of the parent node, stored under ``'parent'``
        (``'null'`` for the top-level call).
    :param color_info dict[str, dict]: mapping from node labels or IDs to
        color information. This is by default just a key called 'color'.
        Label matches are applied first, then id matches.
    :param force_labels_left: labels whose node gets ``force_text_on_left``.
    :param include_leaf_images: attach a base64 JPEG data URI to leaf nodes,
        looked up by label via ``get_class_image_from_dataset(dataset, ...)``.
        A ``UserWarning`` from that lookup is printed and the image skipped.
    :param image_resize_factor: multiplier for the reported image width/height.
    :param include_fake_sublabels: keep the sublabel of ids starting with 'f'.
    :param include_fake_labels: keep '('-prefixed labels of ids starting
        with 'f'.
    :param node_to_conf: mapping from node id to ``{dotted_key: value}``
        overrides applied with ``set_dot_notation``. Ids absent from a plain
        dict get no overrides; a defaultdict's factory is still honored.
        Defaults to no overrides.
    :return: dict with keys 'sublabel', 'label', 'parent', 'children', 'alt',
        'id', plus any color keys, 'force_text_on_left' and 'image'.
    """
    if node_to_conf is None:
        node_to_conf = {}

    children = [
        build_tree(
            G, child, root,
            color_info=color_info,
            force_labels_left=force_labels_left,
            include_leaf_images=include_leaf_images,
            dataset=dataset,
            image_resize_factor=image_resize_factor,
            include_fake_sublabels=include_fake_sublabels,
            include_fake_labels=include_fake_labels,
            node_to_conf=node_to_conf,
        )
        for child in G.succ[root]
    ]

    _node = G.nodes[root]
    label = _node.get('label', '')
    sublabel = root

    # WARNING: hacky -- ids starting with 'f' are treated as fake wnids;
    # this heuristic will have to be changed.
    is_fake = root.startswith('f')
    if is_fake and label.startswith('(') and not include_fake_labels:
        label = ''
    if is_fake and not include_fake_sublabels:
        sublabel = ''

    if 'alt' in _node:
        alt = _node['alt']
    else:
        # Only walk the subtree's leaves when no explicit 'alt' text exists.
        alt = ', '.join(map(wnid_to_name, get_leaves(G, root=root)))

    node = {
        'sublabel': sublabel,
        'label': label,
        'parent': parent,
        'children': children,
        'alt': alt,
        'id': root,
    }
    if label in color_info:
        node.update(color_info[label])
    if root in color_info:
        node.update(color_info[root])
    if label in force_labels_left:
        node['force_text_on_left'] = True

    is_leaf = not children
    if include_leaf_images and is_leaf:
        try:
            image = get_class_image_from_dataset(dataset, label)
        except UserWarning as e:
            # Best-effort: report and skip the image, but still apply the
            # per-node overrides below (previously an early return skipped them).
            print(e)
        else:
            base64_encode = image_to_base64_encode(image, format="jpeg")
            image_href = f"data:image/jpeg;base64,{base64_encode.decode('utf-8')}"
            # Unpacked as (width, height), the order of PIL's Image.size.
            # NOTE(review): assumes the dataset returns PIL images -- confirm.
            image_width, image_height = image.size
            node['image'] = {
                'href': image_href,
                'width': image_width * image_resize_factor,
                'height': image_height * image_resize_factor,
            }

    # __getitem__ first so a caller-supplied defaultdict keeps its factory;
    # a plain dict (including the default) may simply omit this node.
    try:
        conf = node_to_conf[root]
    except KeyError:
        conf = {}
    for key, value in conf.items():
        set_dot_notation(node, key, value)
    return node