def test_hierarchy(args):
    """Sanity-check the hierarchy graph selected by ``args``.

    Matches the dataset's WordNet ids against the graph's leaves and then
    against all of its nodes, printing statistics for each pass through
    ``print_stats``. Then checks that the graph has exactly one root and
    prints a colored pass/fail summary. Nothing is returned.

    Args:
        args: Parsed CLI namespace. ``args.dataset`` selects the dataset and
            ``vars(args)`` is forwarded to ``get_graph_path_from_args``.
    """
    dataset_wnids = get_wnids_from_dataset(args.dataset)
    graph_path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(graph_path))
    G = read_graph(graph_path)
    G_name = Path(graph_path).stem

    # Leaf-level match first, then node-level match. The second value each
    # matcher returns must be empty for the overall check to pass.
    leftovers = []
    for matcher, scope in ((match_wnid_leaves, 'leaves'),
                           (match_wnid_nodes, 'nodes')):
        seen, unmatched = matcher(dataset_wnids, G, G_name)
        print_stats(seen, unmatched, G_name, scope)
        leftovers.append(unmatched)

    root_count = sum(1 for _ in get_roots(G))
    has_single_root = root_count == 1
    if has_single_root:
        Colors.green('Found just 1 root.')
    else:
        Colors.red(f'Found {root_count} roots. Should be only 1.')

    passed = all(len(extra) == 0 for extra in leftovers) and has_single_root
    if passed:
        Colors.green("==> All checks pass!")
    else:
        Colors.red('==> Test failed')
def generate_hierarchy_vis(args):
    """Render the hierarchy graph selected by ``args`` as an HTML tree.

    Reads the graph, picks the root (``args.vis_root`` if given, otherwise the
    first root yielded by ``get_roots``), and builds the tree structure. The
    tree is written through ``generate_vis`` using
    ``nbdt/templates/tree-template.html``, resolved next to the directory
    returned by ``fwd()``. Prints a warning when the graph has more than one
    root.

    Args:
        args: Parsed CLI namespace. ``vars(args)`` is forwarded to
            ``get_graph_path_from_args`` and ``generate_vis_fname``; the
            ``vis_*`` attributes tune the rendering.

    Raises:
        AssertionError: if the chosen root is not a node of the graph.
    """
    graph_path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(graph_path))
    G = read_graph(graph_path)

    roots = list(get_roots(G))
    root_count = len(roots)
    root = args.vis_root or next(get_roots(G))
    assert root in G, f'Node {root} is not a valid node. Nodes: {G.nodes}'

    # A dataset instance is only built when leaf images are requested.
    wants_images = args.dataset and args.vis_leaf_images
    dataset = (
        getattr(data, args.dataset)(root='./data', train=False, download=True)
        if wants_images else None
    )

    color_info = get_color_info(
        G,
        args.color,
        color_leaves=not args.vis_no_color_leaves,
        color_path_to=args.vis_color_path_to,
        color_nodes=args.vis_color_nodes or (),
    )
    node_to_conf = generate_node_conf(args.vis_node_conf)
    tree = build_tree(
        G,
        root,
        color_info=color_info,
        force_labels_left=args.vis_force_labels_left or [],
        dataset=dataset,
        include_leaf_images=args.vis_leaf_images,
        image_resize_factor=args.vis_image_resize_factor,
        include_fake_sublabels=args.vis_fake_sublabels,
        node_to_conf=node_to_conf,
    )
    # Result is not used below; the call is kept as in the original.
    build_graph(G)

    if root_count > 1:
        Colors.red(f'Found {root_count} roots! Should be only 1: {roots}')
    else:
        print(f'Found just {root_count} root.')

    fname = generate_vis_fname(**vars(args))
    template_path = Path(fwd()).parent / 'nbdt/templates/tree-template.html'
    generate_vis(
        str(template_path),
        tree,
        fname,
        zoom=args.vis_zoom,
        straight_lines=not args.vis_curved,
        show_sublabels=args.vis_sublabels,
        height=args.vis_height,
        width=args.vis_width,
        dark=args.vis_dark,
        margin_top=args.vis_margin_top,
        margin_left=args.vis_margin_left,
        hide=args.vis_hide or [],
        above_dy=args.vis_above_dy,
        below_dy=args.vis_below_dy,
        scale=args.vis_scale,
        root_y=args.vis_root_y,
        colormap=args.vis_colormap,
    )
def get_wnid_to_node(path_graph, path_wnids, classes):
    """Map each non-leaf WordNet id of the graph to a freshly built ``Node``.

    Args:
        path_graph: Path of the graph file; read here with ``read_graph`` and
            also forwarded to every ``Node``.
        path_wnids: Path of the wnids file, forwarded to every ``Node``.
        classes: Class names forwarded to every ``Node``.

    Returns:
        dict mapping wnid -> ``Node``, in the order ``get_non_leaves`` yields
        the wnids.
    """
    graph = read_graph(path_graph)
    return {
        wnid: Node(wnid, classes, path_graph=path_graph, path_wnids=path_wnids)
        for wnid in get_non_leaves(graph)
    }
def generate_hierarchy_vis(args):
    """Render the hierarchy graph selected by ``args`` as an HTML tree.

    Reads the graph, builds a tree rooted at the first root reported by
    ``get_roots``, and warns when the graph has more than one root. The tree
    is written through ``generate_vis`` using
    ``nbdt/templates/tree-template.html``, resolved next to the directory
    returned by ``fwd()``.

    Args:
        args: Parsed CLI namespace. ``vars(args)`` is forwarded to
            ``get_graph_path_from_args`` and ``generate_vis_fname``; the
            ``vis_*`` attributes tune the rendering.

    Raises:
        ValueError: if the graph has no root nodes.
    """
    # NOTE(review): this module defines ``generate_hierarchy_vis`` more than
    # once; at import time the last definition wins. Confirm which is intended.
    path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(path))
    G = read_graph(path)

    # Enumerate roots once and reuse the list instead of calling get_roots
    # a second time.
    roots = list(get_roots(G))
    num_roots = len(roots)
    if not roots:
        # Previously surfaced as a bare StopIteration from next(...).
        raise ValueError(f'Graph at {path} has no root nodes.')
    root = roots[0]

    # The dataset is constructed with download=True, so only build it when
    # leaf images are actually requested (same guard as the other
    # definition of this function in this module).
    dataset = None
    if args.dataset and args.vis_leaf_images:
        cls = getattr(data, args.dataset)
        dataset = cls(root='./data', train=False, download=True)

    color_info = get_color_info(
        G,
        args.color,
        color_leaves=not args.vis_no_color_leaves,
        color_path_to=args.vis_color_path_to,
        color_nodes=args.vis_color_nodes or (),
    )
    tree = build_tree(
        G,
        root,
        color_info=color_info,
        force_labels_left=args.vis_force_labels_left or [],
        dataset=dataset,
        include_leaf_images=args.vis_leaf_images,
        image_resize_factor=args.vis_image_resize_factor,
    )
    # NOTE(review): the result is unused; the call is kept in case
    # build_graph has side effects. Confirm, and drop it if it has none.
    build_graph(G)

    if num_roots > 1:
        Colors.red(f'Found {num_roots} roots! Should be only 1: {roots}')
    else:
        print(f'Found just {num_roots} root.')

    fname = generate_vis_fname(**vars(args))
    parent = Path(fwd()).parent
    # NOTE(review): 'tree' is passed positionally before fname here; confirm
    # this matches the generate_vis signature this definition targets.
    generate_vis(
        str(parent / 'nbdt/templates/tree-template.html'),
        tree,
        'tree',
        fname,
        zoom=args.vis_zoom,
        straight_lines=not args.vis_curved,
        show_sublabels=args.vis_sublabels,
        height=args.vis_height,
        dark=args.vis_dark,
    )
def __init__(self, wnid, classes, path_graph, path_wnids, other_class=False):
    """Wrap one non-leaf WordNet id of the hierarchy graph.

    Args:
        wnid: WordNet id of this node; must not be a leaf of the graph.
        classes: Original class names; stored as ``original_classes``.
        path_graph: Path of the graph file, read with ``read_graph``.
        path_wnids: Path of the wnids file, read with ``get_wnids``.
        other_class: If true, and this node is neither the root nor a leaf,
            reserve one extra class slot (``has_other``) on top of one class
            per child.

    Raises:
        AssertionError: if ``wnid`` is a leaf, or if ``build_classes`` yields
            a number of names different from ``num_classes`` (both checks
            are ``assert`` statements, so they are skipped under ``python -O``).
    """
    self.path_graph = path_graph
    self.path_wnids = path_wnids

    self.wnid = wnid
    self.wnids = get_wnids(path_wnids)
    self.G = read_graph(path_graph)
    self.synset = wnid_to_synset(wnid)

    self.original_classes = classes
    # NOTE(review): counted from the wnids file, not from len(classes);
    # presumably the two agree. Confirm.
    self.num_original_classes = len(self.wnids)

    assert not self.is_leaf(), 'Cannot build dataset for leaf'
    # After the assert above, is_leaf() is only True here when asserts are
    # disabled; the check is kept as a guard in that case.
    self.has_other = other_class and not (self.is_root() or self.is_leaf())
    self.num_children = len(self.get_children())
    # One class per child, plus one for the optional "other" bucket.
    self.num_classes = self.num_children + int(self.has_other)

    # Order-sensitive: these helpers run before self.children / self.leaves
    # are assigned below. Presumably they rely only on attributes already
    # set (e.g. num_children, has_other). Keep this ordering.
    self.old_to_new_classes, self.new_to_old_classes = \
        self.build_class_mappings()
    self.classes = self.build_classes()

    assert len(self.classes) == self.num_classes, (
        f'Number of classes {self.num_classes} does not equal number of '
        f'class names found ({len(self.classes)}): {self.classes}'
    )

    self.children = list(self.get_children())
    self.leaves = list(self.get_leaves())
    self.num_leaves = len(self.leaves)

    # Initialized to None here; presumably populated lazily by other methods
    # of this class (not visible in this chunk).
    self._probabilities = None
    self._class_weights = None