def match_wnid_leaves(wnids, G, tree_name):
    """Strip and de-duplicate ``wnids`` and match them against the leaves of ``G``.

    :param wnids: iterable of ID strings (presumably WordNet IDs); each one is
        stripped of surrounding whitespace and duplicates collapse into a set.
    :param G: graph whose leaves (``get_leaves(G)``) are matched against.
    :param tree_name: accepted for interface compatibility; not used here.
    :return: ``(leaves_seen, wnid_set)`` where ``leaves_seen`` is whatever
        ``get_seen_wnids`` returns for the stripped IDs and the leaves of
        ``G``, and ``wnid_set`` is the set of stripped IDs.
    """
    stripped_ids = {raw_id.strip() for raw_id in wnids}
    return get_seen_wnids(stripped_ids, get_leaves(G)), stripped_ids


def get_color_info(G,
                   color,
                   color_leaves,
                   color_path_to=None,
                   color_nodes=(),
                   theme="regular"):
    """Build a mapping from node IDs to color information.

    Besides one entry per node of ``G``, the returned dict carries two
    theme-level entries: ``"bg"`` (background color) and ``"text_rect"``
    (fill color for text rectangles).

    :param G: graph to color; node attribute dicts may contain a ``"label"``.
    :param color: color given to highlighted nodes.
    :param color_leaves: if truthy, every leaf of ``G`` is highlighted.
    :param color_path_to: optional leaf ID or leaf label. When it matches a
        leaf, every node is first reset to ``"#cccccc"`` and only the nodes
        on the parent chain from that leaf up to the root are highlighted
        (this overrides ``color_leaves`` / ``color_nodes``).
    :param color_nodes: node IDs or node labels to highlight.
    :param theme: ``"minimal"`` or ``"dark"`` select dedicated colors; any
        other value uses the light defaults. Stored on every node entry.
    :return: dict with keys ``"bg"``, ``"text_rect"`` and one per node ID.
    """
    nodes = {}

    theme_to_bg = {"minimal": "#EEEEEE", "dark": "#111111"}
    nodes["bg"] = theme_to_bg.get(theme, "#FFFFFF")

    theme_to_text_rect = {
        "minimal": "rgba(0,0,0,0)",
        "dark": "rgba(17,17,17,0.8)",
    }
    nodes["text_rect"] = theme_to_text_rect.get(theme, "rgba(255,255,255,0.8)")

    leaves = list(get_leaves(G))
    # Leaf highlighting is decided in the same pass as ``color_nodes`` so the
    # gray default below can no longer overwrite highlighted leaves.
    leaves_to_color = set(leaves) if color_leaves else set()

    for node_id, node in G.nodes.items():
        if (node.get("label", "") in color_nodes
                or node_id in color_nodes
                or node_id in leaves_to_color):
            nodes[node_id] = {"color": color, "highlighted": True, "theme": theme}
        else:
            nodes[node_id] = {"color": "gray", "theme": theme}

    root = get_root(G)
    # First leaf whose label or ID equals ``color_path_to`` (None if none).
    target = next(
        (leaf for leaf in leaves
         if G.nodes[leaf].get("label", "") == color_path_to
         or leaf == color_path_to),
        None,
    )

    if target is not None:
        # Path mode: dim every node, then highlight the leaf-to-root chain.
        for node_id in G.nodes:
            nodes[node_id] = {
                "color": "#cccccc",
                "color_incident_edge": True,
                "highlighted": False,
                "theme": theme,
            }

        while target != root:
            nodes[target] = {
                "color": color,
                "color_incident_edge": True,
                "highlighted": True,
                "theme": theme,
            }
            # Step to the first predecessor. A node with no predecessor that
            # is not the root raises IndexError, as before.
            target = list(G.pred[target])[0]
        nodes[root] = {"color": color, "highlighted": True, "theme": theme}
    return nodes
예제 #3
0
    def build_class_mappings(self):
        """Map original class indices to this node's child indices and back.

        Child ``new_index`` is the child's position in ``self.succ``. Each child
        collects the original class indices (via ``self.wnid_to_class_index``)
        of the leaves that ``get_leaves(self.tree.G, child)`` yields. When
        ``self.has_other`` is set, every original class index in
        ``range(self.num_original_classes)`` not yet mapped is routed to one
        extra index, ``self.num_children`` (presumably an "other" bucket).

        :return: ``(old_to_new, new_to_old)``. Each maps an index to a list of
            indices, as ``defaultdict(list)``. For a leaf node, two empty plain
            dicts are returned.
        """
        if self.is_leaf():
            return {}, {}

        # ``defaultdict(list)`` behaves like ``defaultdict(lambda: [])`` but,
        # unlike a lambda factory, keeps the mappings picklable.
        old_to_new = defaultdict(list)
        new_to_old = defaultdict(list)
        for new_index, child in enumerate(self.succ):
            for leaf in get_leaves(self.tree.G, child):
                old_index = self.wnid_to_class_index(leaf)
                old_to_new[old_index].append(new_index)
                new_to_old[new_index].append(old_index)

        if not self.has_other:
            return old_to_new, new_to_old

        other_index = self.num_children
        for old_index in range(self.num_original_classes):
            # ``in`` does not trigger the defaultdict factory.
            if old_index not in old_to_new:
                old_to_new[old_index].append(other_index)
                new_to_old[other_index].append(old_index)
        return old_to_new, new_to_old


def build_tree(
    G,
    root,
    parent="null",
    color_info=(),
    force_labels_left=(),
    include_leaf_images=False,
    dataset=None,
    image_resize_factor=1,
    include_fake_sublabels=False,
    include_fake_labels=False,
    node_to_conf=None,
):
    """Recursively convert the subtree of ``G`` rooted at ``root`` into dicts.

    Each returned dict has the keys ``sublabel``, ``label``, ``parent``,
    ``children`` (dicts built the same way for ``G.succ[root]``), ``alt`` and
    ``id``, plus whatever ``color_info`` / ``node_to_conf`` add.

    :param G: graph exposing ``G.succ`` and ``G.nodes`` (networkx-style).
    :param root: ID of the node to convert.
    :param parent: value stored under ``"parent"`` for ``root``. Children
        receive ``root`` as their parent.
    :param color_info dict[str, dict]: mapping from node labels or IDs to color
                                       information. This is by default just a
                                       key called 'color'. ID entries are
                                       applied after label entries.
    :param force_labels_left: labels whose nodes get ``force_text_on_left``.
    :param include_leaf_images: attach a base64 JPEG of the class image
        (looked up in ``dataset`` by label) to leaf nodes.
    :param image_resize_factor: multiplier for the image's width and height.
    :param include_fake_sublabels: keep the sublabel of IDs starting with "f".
    :param include_fake_labels: keep "("-prefixed labels of IDs starting
        with "f".
    :param node_to_conf: optional mapping from node ID to a dict of
        ``key -> value`` pairs written into the node via ``set_dot_notation``.
        IDs missing from the mapping get no extra configuration. ``None``
        (the default) means no configuration for any node.
    :return: dict describing ``root`` and its descendants.
    """
    # ``None`` default instead of a shared ``{}``. The old default crashed
    # with KeyError on ``node_to_conf[root]``.
    if node_to_conf is None:
        node_to_conf = {}

    children = [
        build_tree(
            G,
            child,
            root,
            color_info=color_info,
            force_labels_left=force_labels_left,
            include_leaf_images=include_leaf_images,
            dataset=dataset,
            image_resize_factor=image_resize_factor,
            include_fake_sublabels=include_fake_sublabels,
            include_fake_labels=include_fake_labels,
            node_to_conf=node_to_conf,
        ) for child in G.succ[root]
    ]
    _node = G.nodes[root]
    label = _node.get("label", "")
    sublabel = root

    # WARNING: hacky -- IDs starting with "f" are treated as fake wnids.
    if root.startswith("f") and label.startswith("(") and not include_fake_labels:
        label = ""
    if root.startswith("f") and not include_fake_sublabels:
        sublabel = ""

    # Build the fallback alt text only when needed: it walks every leaf under
    # ``root``, which is wasted work at every recursion level otherwise.
    if "alt" in _node:
        alt = _node["alt"]
    else:
        alt = ", ".join(map(wnid_to_name, get_leaves(G, root=root)))

    node = {
        "sublabel": sublabel,
        "label": label,
        "parent": parent,
        "children": children,
        "alt": alt,
        "id": root,
    }

    if label in color_info:
        node.update(color_info[label])

    if root in color_info:
        node.update(color_info[root])

    if label in force_labels_left:
        node["force_text_on_left"] = True

    is_leaf = not children
    if include_leaf_images and is_leaf:
        try:
            image = get_class_image_from_dataset(dataset, label)
        except UserWarning as e:
            # No image available: report it and keep the node image-less.
            # Per-node configuration below is still applied.
            print(e)
        else:
            base64_encode = image_to_base64_encode(image, format="jpeg")
            image_href = f"data:image/jpeg;base64,{base64_encode.decode('utf-8')}"
            # Assumes a PIL image, whose ``size`` is ``(width, height)``.
            image_width, image_height = image.size
            node["image"] = {
                "href": image_href,
                "width": image_width * image_resize_factor,
                "height": image_height * image_resize_factor,
            }

    # Indexing (not ``.get``) so a caller's ``defaultdict`` factory still
    # applies. A plain dict without the ID yields no configuration.
    try:
        conf = node_to_conf[root]
    except KeyError:
        conf = {}
    for key, value in conf.items():
        set_dot_notation(node, key, value)
    return node
예제 #5
0
 def get_leaves(self):
     """Return the leaves of ``self.tree.G`` under this node.

     Delegates to the module-level ``get_leaves(G, root)`` helper, using
     ``self.wnid`` as the root. Inside the method body that name resolves to
     the global function, not to this method.
     """
     graph, start_wnid = self.tree.G, self.wnid
     return get_leaves(graph, start_wnid)