Example #1
0
def get_color_info(G, color, color_leaves, color_path_to=None, color_nodes=()):
    """Mapping from node to color information.

    Builds a dict keyed by node id whose values are style dicts. Every value
    has a ``'color'`` key. Nodes on the highlighted path other than the root
    additionally get ``'color_incident_edge': True``. Later rules overwrite
    earlier ones for the same node, in this order: leaves, ``color_nodes``,
    then the path.

    Args:
        G: Directed graph exposing ``G.nodes`` (node id -> attribute dict,
            with ``.items()``) and ``G.pred`` (node id -> predecessors).
            Presumably a networkx ``DiGraph``; confirm against callers.
        color: Value stored under ``'color'`` for every styled node.
        color_leaves: If truthy, every node returned by ``get_leaves(G)`` is
            colored.
        color_path_to: Leaf id or leaf ``'label'`` whose path up to the root
            is colored. The default ``None`` colors no path. Only the first
            matching leaf, in ``get_leaves`` order, is used.
        color_nodes: Collection of node ids and/or node labels to color.

    Returns:
        dict mapping node id -> style dict.

    Raises:
        ValueError: If the upward walk from the matched leaf reaches a node
            with no predecessor before reaching the root.
    """
    nodes = {}
    leaves = list(get_leaves(G))
    if color_leaves:
        for leaf in leaves:
            nodes[leaf] = {'color': color}

    for node_id, node in G.nodes.items():
        if node.get('label', '') in color_nodes or node_id in color_nodes:
            nodes[node_id] = {'color': color}

    root = get_root(G)
    target = None
    for leaf in leaves:
        node = G.nodes[leaf]
        if node.get('label', '') == color_path_to or leaf == color_path_to:
            target = leaf
            break

    if target is not None:
        # Walk parent links from the leaf up to the root. Each node is marked
        # so that the edge leading into it is drawn in ``color`` as well.
        while target != root:
            nodes[target] = {'color': color, 'color_incident_edge': True}
            parents = G.pred[target]
            if not parents:
                raise ValueError(
                    f'node {target!r} has no predecessor; '
                    f'cannot reach root {root!r}')
            # Follow the first predecessor; a tree node has exactly one.
            target = next(iter(parents))
        nodes[root] = {'color': color}
    return nodes
    def traverse_tree(cls, wnid_to_outputs, nodes, wnid_to_class, classes):
        """Convert node outputs to final prediction.

        Note that the prediction output for this function can NOT be trained
        on. The outputs have been detached from the computation graph.

        For each sample the tree is walked from the root. At each node the
        walk follows the child selected by that node's ``'preds'`` entry. It
        stops when the chosen child wnid has no node in ``nodes``, or when
        the current node has no entry in ``wnid_to_outputs``.

        Args:
            wnid_to_outputs: dict mapping wnid -> dict with ``'logits'``,
                ``'preds'`` and ``'probs'`` tensors, indexed by sample first
                (assumed batch-first; confirm with the caller). **Mutated in
                place**: ``'preds'`` becomes a list of ints and ``'probs'``
                becomes a detached CPU tensor.
            nodes: Sequence of node objects exposing ``.wnid``, ``.G`` and
                ``.children`` (indexable by child position). ``nodes[0].wnid``
                must be a key of ``wnid_to_outputs``.
            wnid_to_class: dict mapping wnid -> class.
            classes: Sequence of classes. A prediction is
                ``classes.index(cls)``.

        Returns:
            Tuple ``(preds, decisions)``:

            - ``preds`` is a long tensor of predicted class indices, on the
              same device as the input logits. An entry is ``-1`` when the
              walk ends on a wnid with no class.
            - ``decisions`` holds one list per sample of
              ``{'node', 'name', 'prob'}`` dicts describing the path taken.

        Raises:
            ValueError: From ``classes.index`` if a predicted class is not in
                ``classes``.
        """
        # move all to CPU, detach from computation graph
        example = wnid_to_outputs[nodes[0].wnid]
        n_samples = int(example['logits'].size(0))
        device = example['logits'].device

        # Only the inner dicts are modified, so iterating the outer dict's
        # values directly is safe.
        for outputs in wnid_to_outputs.values():
            outputs['preds'] = list(map(int, outputs['preds'].cpu()))
            outputs['probs'] = outputs['probs'].detach().cpu()

        wnid_to_node = {node.wnid: node for node in nodes}
        wnid_root = get_root(nodes[0].G)
        node_root = wnid_to_node[wnid_root]

        decisions = []
        preds = []
        for index in range(n_samples):
            decision = [{'node': node_root, 'name': 'root', 'prob': 1}]
            wnid, node = wnid_root, node_root
            while node is not None:
                if node.wnid not in wnid_to_outputs:
                    # No outputs for this node: abandon the walk. wnid becomes
                    # None, so the class lookup below normally yields -1.
                    wnid = node = None
                    break
                outputs = wnid_to_outputs[node.wnid]
                index_child = outputs['preds'][index]
                prob_child = float(outputs['probs'][index][index_child])
                wnid = node.children[index_child]
                node = wnid_to_node.get(wnid, None)
                decision.append({
                    'node': node,
                    'name': wnid_to_name(wnid),
                    'prob': prob_child
                })
            # Use a separate name so the ``cls`` parameter is not rebound.
            target_class = wnid_to_class.get(wnid, None)
            if target_class is None:
                preds.append(-1)
            else:
                preds.append(classes.index(target_class))
            decisions.append(decision)
        # Build the long tensor directly on the target device instead of
        # going through the default float dtype and copying afterwards.
        return torch.tensor(preds, dtype=torch.long, device=device), decisions
Example #3
0
    def traverse_tree(cls, wnid_to_outputs, nodes, wnid_to_class, classes):
        """Convert node outputs to final prediction.

        Note that the prediction output for this function can NOT be trained
        on. The outputs have been detached from the computation graph.

        For each sample the tree is walked from the root. At each node the
        walk follows the child selected by that node's ``'preds'`` entry. It
        stops when the chosen child wnid has no node in ``nodes``, or when
        the current node has no entry in ``wnid_to_outputs``.

        Before each walk, ``retprob`` is called with three fresh per-sample
        lists (names, tree, wnids). The code relies on ``retprob`` to fill
        them; confirm against its definition. For each step of the decision
        path, the function also records that wnid's position in the wnid
        list.

        Args:
            wnid_to_outputs: dict mapping wnid -> dict with ``'logits'``,
                ``'preds'`` and ``'probs'`` tensors, indexed by sample first
                (assumed batch-first; confirm with the caller). **Mutated in
                place**: ``'preds'`` becomes a list of ints and ``'probs'``
                becomes a detached CPU tensor.
            nodes: Sequence of node objects exposing ``.wnid``, ``.G`` and
                ``.children`` (indexable by child position). ``nodes[0].wnid``
                must be a key of ``wnid_to_outputs``.
            wnid_to_class: dict mapping wnid -> class.
            classes: Sequence of classes. A prediction is
                ``classes.index(cls)``.

        Returns:
            Tuple ``(preds, decisions, tree, names, path_inds)``:

            - ``preds`` is a long tensor of predicted class indices, created
              on torch's default device. An entry is ``-1`` when the walk
              ends on a wnid with no class.
            - ``decisions`` holds one list per sample of
              ``{'node', 'name', 'prob'}`` dicts describing the path taken.
            - ``tree`` holds, per sample, the list passed to ``retprob`` as
              its 2nd argument.
            - ``names`` holds, per sample, the list passed to ``retprob`` as
              its 1st argument.
            - ``path_inds`` holds, per sample, the index of each decision-path
              wnid within the list passed to ``retprob`` as its 3rd argument.

        Raises:
            ValueError: From ``list.index`` if a predicted class is not in
                ``classes``, or if a decision-path wnid is missing from the
                wnid list filled by ``retprob``.
        """
        # move all to CPU, detach from computation graph
        example = wnid_to_outputs[nodes[0].wnid]
        n_samples = int(example['logits'].size(0))

        # Only the inner dicts are modified, so iterating the outer dict's
        # values directly is safe.
        for outputs in wnid_to_outputs.values():
            outputs['preds'] = list(map(int, outputs['preds'].cpu()))
            outputs['probs'] = outputs['probs'].detach().cpu()

        wnid_to_node = {node.wnid: node for node in nodes}
        wnid_root = get_root(nodes[0].G)
        node_root = wnid_to_node[wnid_root]

        decisions = []
        preds = []
        names = []
        tree = []
        path_inds = []

        for index in range(n_samples):
            decision = [{'node': node_root, 'name': 'root', 'prob': 1}]
            wnid, node = wnid_root, node_root

            # Per-sample lists handed to ``retprob`` to fill in.
            tr = []
            nm = []
            tr_wn = []
            retprob(nm, tr, tr_wn, wnid_to_outputs, wnid_to_node, wnid, index)
            names.append(nm)
            tree.append(tr)

            path_wnids = []
            while node is not None:
                if node.wnid not in wnid_to_outputs:
                    # No outputs for this node: abandon the walk. wnid becomes
                    # None, so the class lookup below normally yields -1.
                    wnid = node = None
                    break
                outputs = wnid_to_outputs[node.wnid]
                index_child = outputs['preds'][index]
                prob_child = float(outputs['probs'][index][index_child])
                wnid = node.children[index_child]
                node = wnid_to_node.get(wnid, None)
                decision.append({
                    'node': node,
                    'name': wnid_to_name(wnid),
                    'prob': prob_child
                })
                path_wnids.append(wnid)

            # Use a separate name so the ``cls`` parameter is not rebound.
            target_class = wnid_to_class.get(wnid, None)
            if target_class is None:
                preds.append(-1)
            else:
                preds.append(classes.index(target_class))
            decisions.append(decision)

            path_inds.append([tr_wn.index(step) for step in path_wnids])

        return (torch.tensor(preds, dtype=torch.long), decisions, tree, names,
                path_inds)