def get_color_info(G, color, color_leaves, color_path_to=None, color_nodes=()):
    """Return a mapping from node id to the color attributes for that node.

    Coloring is applied in three passes, and later passes overwrite earlier
    entries for the same node:

    1. every leaf, when ``color_leaves`` is truthy;
    2. every node whose ``label`` attribute (default ``''``) or whose id is
       contained in ``color_nodes``;
    3. if some leaf matches ``color_path_to`` (by label or by id), the chain
       from that leaf up to the root. Path nodes other than the root also get
       ``'color_incident_edge': True``.

    Returns:
        dict mapping node id -> attribute dict (``{'color': color, ...}``).
    """
    leaves = list(get_leaves(G))
    colored = {}

    if color_leaves:
        colored.update((leaf, {'color': color}) for leaf in leaves)

    for node_id, attrs in G.nodes.items():
        if attrs.get('label', '') in color_nodes or node_id in color_nodes:
            colored[node_id] = {'color': color}

    root = get_root(G)
    # First leaf whose label or id equals ``color_path_to``, if any.
    target = next(
        (
            leaf
            for leaf in leaves
            if G.nodes[leaf].get('label', '') == color_path_to
            or leaf == color_path_to
        ),
        None,
    )
    if target is None:
        return colored

    current = target
    while current != root:
        colored[current] = {'color': color, 'color_incident_edge': True}
        # Follow the first predecessor of the current node towards the root.
        current = list(G.pred[current])[0]
    colored[root] = {'color': color}
    return colored
def traverse_tree(cls, wnid_to_outputs, nodes, wnid_to_class, classes):
    """Convert node outputs to final prediction.

    Note that the prediction output for this function can NOT be trained on.
    The outputs have been detached from the computation graph.

    Each sample starts at the root node and repeatedly moves to the child
    selected by that node's ``preds`` entry. It stops when the chosen child
    wnid has no entry in the ``nodes`` list (presumably a leaf), or when a
    node has no entry in ``wnid_to_outputs``.

    Args:
        cls: Not used by the body; kept for interface compatibility (the name
            suggests this is bound as a classmethod -- confirm at call site).
        wnid_to_outputs: Maps a node wnid to a dict containing ``'logits'``,
            ``'preds'`` and ``'probs'``. Mutated in place: ``'preds'`` becomes
            a list of ints and ``'probs'`` a detached CPU tensor. Entries that
            were already converted by a previous call are left as they are.
        nodes: Node objects exposing ``wnid``, ``G`` and ``children``;
            ``nodes[0].wnid`` must be a key of ``wnid_to_outputs``.
        wnid_to_class: Maps a final wnid to a class name.
        classes: Sequence of class names; a prediction is the position of the
            reached class name in this sequence (``classes.index``).

    Returns:
        ``(preds, decisions)``, where:

        - ``preds`` is a 1-D long tensor, on the same device as the example
          logits, holding one class index per sample. The value is ``-1``
          when the traversal ends on a wnid without a class.
        - ``decisions`` holds, per sample, the list of visited steps as
          ``{'node', 'name', 'prob'}`` dicts, starting with the root.
    """
    # Move all to CPU, detach from computation graph.
    example = wnid_to_outputs[nodes[0].wnid]
    n_samples = int(example['logits'].size(0))
    device = example['logits'].device
    for outputs in wnid_to_outputs.values():
        # The conversion happens in place; skip preds that a previous call
        # already turned into a list (calling ``.cpu()`` on a list would fail).
        if torch.is_tensor(outputs['preds']):
            outputs['preds'] = list(map(int, outputs['preds'].cpu()))
        outputs['probs'] = outputs['probs'].detach().cpu()

    wnid_to_node = {node.wnid: node for node in nodes}
    wnid_root = get_root(nodes[0].G)
    node_root = wnid_to_node[wnid_root]

    decisions = []
    preds = []
    for index in range(n_samples):
        decision = [{'node': node_root, 'name': 'root', 'prob': 1}]
        wnid, node = wnid_root, node_root
        while node is not None:
            if node.wnid not in wnid_to_outputs:
                # No outputs for this node: the sample cannot be routed on.
                wnid = node = None
                break
            outputs = wnid_to_outputs[node.wnid]
            index_child = outputs['preds'][index]
            prob_child = float(outputs['probs'][index][index_child])
            wnid = node.children[index_child]
            node = wnid_to_node.get(wnid, None)
            decision.append({
                'node': node,
                'name': wnid_to_name(wnid),
                'prob': prob_child,
            })
        # Use a dedicated name instead of rebinding the ``cls`` parameter.
        class_name = wnid_to_class.get(wnid, None)
        preds.append(-1 if class_name is None else classes.index(class_name))
        decisions.append(decision)
    # Build the long tensor directly: the float32 ``torch.Tensor`` constructor
    # cannot represent every integer above 2**24 exactly.
    return torch.tensor(preds, dtype=torch.long, device=device), decisions
def traverse_tree(cls, wnid_to_outputs, nodes, wnid_to_class, classes):
    """Convert node outputs to final prediction.

    Note that the prediction output for this function can NOT be trained on.
    The outputs have been detached from the computation graph.

    Each sample starts at the root node and repeatedly moves to the child
    selected by that node's ``preds`` entry. It stops when the chosen child
    wnid has no entry in the ``nodes`` list (presumably a leaf), or when a
    node has no entry in ``wnid_to_outputs``. For every sample,
    ``retprob(nm, tr, tr_wn, wnid_to_outputs, wnid_to_node, root_wnid, index)``
    is also called to fill three fresh per-sample lists. ``retprob`` is
    defined elsewhere; ``tr_wn`` is used here as the wnid ordering that path
    indices refer to.

    Args:
        cls: Not used by the body; kept for interface compatibility (the name
            suggests this is bound as a classmethod -- confirm at call site).
        wnid_to_outputs: Maps a node wnid to a dict containing ``'logits'``,
            ``'preds'`` and ``'probs'``. Mutated in place: ``'preds'`` becomes
            a list of ints and ``'probs'`` a detached CPU tensor. Entries that
            were already converted by a previous call are left as they are.
        nodes: Node objects exposing ``wnid``, ``G`` and ``children``;
            ``nodes[0].wnid`` must be a key of ``wnid_to_outputs``.
        wnid_to_class: Maps a final wnid to a class name.
        classes: Sequence of class names; a prediction is the position of the
            reached class name in this sequence (``classes.index``).

    Returns:
        ``(preds, decisions, tree, names, path_inds)``, where:

        - ``preds`` is a 1-D long tensor on the CPU with one class index per
          sample, or ``-1`` when the traversal ends on a wnid without a class.
        - ``decisions`` holds the per-sample list of visited
          ``{'node', 'name', 'prob'}`` steps, starting with the root.
        - ``tree`` and ``names`` hold the per-sample ``tr`` and ``nm`` lists
          filled by ``retprob``.
        - ``path_inds`` holds, per sample, the position in ``tr_wn`` of every
          wnid visited after the root.

    Raises:
        ValueError: if a visited wnid does not appear in the ``tr_wn`` list
            produced by ``retprob`` (from ``list.index``).
    """
    # Move all to CPU, detach from computation graph.
    example = wnid_to_outputs[nodes[0].wnid]
    n_samples = int(example['logits'].size(0))
    for outputs in wnid_to_outputs.values():
        # The conversion happens in place; skip preds that a previous call
        # already turned into a list (calling ``.cpu()`` on a list would fail).
        if torch.is_tensor(outputs['preds']):
            outputs['preds'] = list(map(int, outputs['preds'].cpu()))
        outputs['probs'] = outputs['probs'].detach().cpu()

    wnid_to_node = {node.wnid: node for node in nodes}
    wnid_root = get_root(nodes[0].G)
    node_root = wnid_to_node[wnid_root]

    decisions = []
    preds = []
    names = []
    tree = []
    path_inds = []
    for index in range(n_samples):
        decision = [{'node': node_root, 'name': 'root', 'prob': 1}]
        wnid, node = wnid_root, node_root

        # retprob fills these three lists in place for this sample.
        tr = []
        nm = []
        tr_wn = []
        retprob(nm, tr, tr_wn, wnid_to_outputs, wnid_to_node, wnid, index)
        names.append(nm)
        tree.append(tr)

        path_wnids = []
        while node is not None:
            if node.wnid not in wnid_to_outputs:
                # No outputs for this node: the sample cannot be routed on.
                wnid = node = None
                break
            outputs = wnid_to_outputs[node.wnid]
            index_child = outputs['preds'][index]
            prob_child = float(outputs['probs'][index][index_child])
            wnid = node.children[index_child]
            node = wnid_to_node.get(wnid, None)
            decision.append({
                'node': node,
                'name': wnid_to_name(wnid),
                'prob': prob_child,
            })
            path_wnids.append(wnid)

        # Use a dedicated name instead of rebinding the ``cls`` parameter.
        class_name = wnid_to_class.get(wnid, None)
        preds.append(-1 if class_name is None else classes.index(class_name))
        decisions.append(decision)
        path_inds.append([tr_wn.index(element) for element in path_wnids])

    # Build the long tensor directly: the float32 ``torch.Tensor`` constructor
    # cannot represent every integer above 2**24 exactly.
    return torch.tensor(preds, dtype=torch.long), decisions, tree, names, path_inds