def test_hierarchy(args):
    """Sanity-check a saved hierarchy against the dataset's WNIDs.

    Loads the graph located by ``get_graph_path_from_args(**vars(args))``.
    It matches the WNIDs of ``args.dataset`` first against the graph's
    leaves and then against its nodes, printing stats for each pass. It then
    counts the graph's roots.

    A green "All checks pass" line is printed only when both WNID sets
    returned by the matching passes are empty and the graph has exactly one
    root. Otherwise a red "Test failed" line is printed. Returns ``None``.
    """
    dataset_wnids = get_wnids_from_dataset(args.dataset)
    graph_path = get_graph_path_from_args(**vars(args))
    print(f"==> Reading from {graph_path}")
    graph = read_graph(graph_path)
    graph_name = Path(graph_path).stem

    # First pass: leaves only.
    seen, leaf_wnid_set = match_wnid_leaves(dataset_wnids, graph, graph_name)
    print_stats(seen, leaf_wnid_set, graph_name, "leaves")
    # Second pass: every node in the graph.
    seen, node_wnid_set = match_wnid_nodes(dataset_wnids, graph, graph_name)
    print_stats(seen, node_wnid_set, graph_name, "nodes")

    root_count = len(list(get_roots(graph)))
    has_single_root = root_count == 1
    if not has_single_root:
        Colors.red(f"Found {root_count} roots. Should be only 1.")
    else:
        Colors.green("Found just 1 root.")

    both_sets_empty = len(leaf_wnid_set) == 0 and len(node_wnid_set) == 0
    if not (both_sets_empty and has_single_root):
        Colors.red("==> Test failed")
    else:
        Colors.green("==> All checks pass!")
def generate_hierarchy_vis(args):
    """Render an HTML tree visualization of the hierarchy selected by ``args``.

    The graph is read from ``get_graph_path_from_args(**vars(args))``. The
    tree is rooted at ``args.vis_root`` when given, otherwise at the graph's
    first root. Output is written via ``generate_vis`` using the template
    ``nbdt/templates/tree-template.html``, located relative to
    ``Path(fwd()).parent``.

    The evaluation split of ``args.dataset`` (looked up on the ``data``
    module) is only instantiated when ``args.vis_leaf_images`` is set.

    Raises:
        ValueError: if the chosen root is not a node of the graph, or if the
            graph has no root and ``args.vis_root`` was not given.
    """
    path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(path))
    G = read_graph(path)

    # Materialize the roots once and reuse them. The original walked the graph
    # again with next(get_roots(G)), which also leaked a bare StopIteration
    # when the graph had no root.
    roots = list(get_roots(G))
    num_roots = len(roots)
    root = args.vis_root or (roots[0] if roots else None)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if root is None:
        raise ValueError(
            f'Graph at {path} has no root node; set vis_root explicitly.')
    if root not in G:
        raise ValueError(
            f'Node {root} is not a valid node. Nodes: {G.nodes}')

    dataset = None
    if args.dataset and args.vis_leaf_images:
        cls = getattr(data, args.dataset)
        dataset = cls(root='./data', train=False, download=True)

    color_info = get_color_info(
        G,
        args.color,
        color_leaves=not args.vis_no_color_leaves,
        color_path_to=args.vis_color_path_to,
        color_nodes=args.vis_color_nodes or ())
    node_to_conf = generate_node_conf(args.vis_node_conf)
    tree = build_tree(
        G,
        root,
        color_info=color_info,
        force_labels_left=args.vis_force_labels_left or [],
        dataset=dataset,
        include_leaf_images=args.vis_leaf_images,
        image_resize_factor=args.vis_image_resize_factor,
        include_fake_sublabels=args.vis_fake_sublabels,
        node_to_conf=node_to_conf)
    # NOTE(review): the result was previously bound to an unused local. The
    # call is kept in case build_graph has side effects; drop it if it is pure.
    build_graph(G)

    if num_roots > 1:
        Colors.red(f'Found {num_roots} roots! Should be only 1: {roots}')
    else:
        print(f'Found just {num_roots} root.')

    fname = generate_vis_fname(**vars(args))
    parent = Path(fwd()).parent
    generate_vis(
        str(parent / 'nbdt/templates/tree-template.html'),
        tree,
        fname,
        zoom=args.vis_zoom,
        straight_lines=not args.vis_curved,
        show_sublabels=args.vis_sublabels,
        height=args.vis_height,
        width=args.vis_width,
        dark=args.vis_dark,
        margin_top=args.vis_margin_top,
        margin_left=args.vis_margin_left,
        hide=args.vis_hide or [],
        above_dy=args.vis_above_dy,
        below_dy=args.vis_below_dy,
        scale=args.vis_scale,
        root_y=args.vis_root_y,
        colormap=args.vis_colormap)
def generate_hierarchy_vis(args):
    """Render an HTML tree visualization of the hierarchy selected by ``args``.

    The graph is read from ``get_graph_path_from_args(**vars(args))``. A tree
    is built rooted at the graph's first root and written via
    ``generate_vis`` using ``nbdt/templates/tree-template.html``, located
    relative to ``Path(fwd()).parent``.

    The evaluation split of ``args.dataset`` (looked up on the ``data``
    module, with ``download=True``) is only instantiated when
    ``args.vis_leaf_images`` is set. Previously it was downloaded whenever
    ``args.dataset`` was given, even if no leaf images were requested.
    NOTE(review): this assumes build_tree only reads ``dataset`` for leaf
    images; confirm.

    Raises:
        ValueError: if the graph has no root node.
    """
    path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(path))
    G = read_graph(path)

    # Materialize the roots once and take the first. This avoids a second
    # traversal and a bare StopIteration when the graph has no root.
    roots = list(get_roots(G))
    num_roots = len(roots)
    if not roots:
        raise ValueError(f'Graph at {path} has no root node.')
    root = roots[0]

    dataset = None
    if args.dataset and args.vis_leaf_images:
        cls = getattr(data, args.dataset)
        dataset = cls(root='./data', train=False, download=True)

    color_info = get_color_info(
        G,
        args.color,
        color_leaves=not args.vis_no_color_leaves,
        color_path_to=args.vis_color_path_to,
        color_nodes=args.vis_color_nodes or ())
    tree = build_tree(
        G,
        root,
        color_info=color_info,
        force_labels_left=args.vis_force_labels_left or [],
        dataset=dataset,
        include_leaf_images=args.vis_leaf_images,
        image_resize_factor=args.vis_image_resize_factor)
    # NOTE(review): the result was previously bound to an unused local. The
    # call is kept in case build_graph has side effects; drop it if it is pure.
    build_graph(G)

    if num_roots > 1:
        Colors.red(f'Found {num_roots} roots! Should be only 1: {roots}')
    else:
        print(f'Found just {num_roots} root.')

    fname = generate_vis_fname(**vars(args))
    parent = Path(fwd()).parent
    generate_vis(
        str(parent / 'nbdt/templates/tree-template.html'),
        tree,
        'tree',
        fname,
        zoom=args.vis_zoom,
        straight_lines=not args.vis_curved,
        show_sublabels=args.vis_sublabels,
        height=args.vis_height,
        dark=args.vis_dark)
def generate_hierarchy_vis(args):
    """Visualize the hierarchy selected by ``args`` as an HTML file.

    The graph is read from ``get_graph_path_from_args(**vars(args))``. The
    output path is ``./<generate_vis_fname(**vars(args))>.html``. All
    remaining attributes of ``args`` are forwarded to
    ``generate_hierarchy_vis_from`` as keyword arguments, except ``dataset``
    and ``fname``.

    The evaluation split of ``args.dataset`` (looked up on the ``data``
    module) is only instantiated when ``args.vis_leaf_images`` is set;
    otherwise ``None`` is passed.

    ``args`` itself is left unmodified.

    Returns:
        Whatever ``generate_hierarchy_vis_from`` returns.
    """
    path_hie = get_graph_path_from_args(**vars(args))
    print("==> Reading from {}".format(path_hie))
    G = read_graph(path_hie)
    path_html = f"./{generate_vis_fname(**vars(args))}.html"

    # vars() returns the object's own __dict__. Copy it, or the pops below
    # would delete ``dataset``/``fname`` from the caller's ``args``.
    kwargs = dict(vars(args))
    dataset = None
    if args.dataset and args.vis_leaf_images:
        cls = getattr(data, args.dataset)
        dataset = cls(root="./data", train=False, download=True)

    # ``dataset`` is passed positionally (as the loaded object), and the output
    # path is passed as ``path_html``. Keep both names out of **kwargs;
    # presumably this avoids clashing with generate_hierarchy_vis_from's own
    # parameters (confirm).
    kwargs.pop('dataset', '')
    kwargs.pop('fname', '')
    return generate_hierarchy_vis_from(
        G, dataset, path_html, verbose=True, **kwargs)
def generate_hierarchy(
    dataset,
    method,
    seed=0,
    branching_factor=2,
    extra=0,
    no_prune=False,
    fname="",
    path="",
    single_path=False,
    induced_linkage="ward",
    induced_affinity="euclidean",
    checkpoint=None,
    arch=None,
    model=None,
    **kwargs,
):
    """Build a class hierarchy for ``dataset``, write it to disk, and return its path.

    ``method`` selects the builder:

    - ``"wordnet"``: ``build_minimal_wordnet_graph``.
    - ``"random"``: ``build_random_graph``, using ``seed`` and ``branching_factor``.
    - ``"induced"``: ``build_induced_graph``, using ``checkpoint``, ``arch``,
      the linkage/affinity settings, and ``model.state_dict()`` when a
      ``model`` is given.

    After building, single-successor nodes are pruned unless ``no_prune``
    is set. When ``extra > 0``, the graph is augmented via ``augment_graph``.
    After each stage, graph stats are printed and
    ``assert_all_wnids_in_graph`` is called.

    The output location comes from ``get_graph_path_from_args`` using the
    build settings (``model`` and ``**kwargs`` are not part of it). Extra
    keyword arguments are accepted and ignored.

    Raises:
        NotImplementedError: for an unrecognized ``method``.
    """
    wnids = get_wnids_from_dataset(dataset)

    def check_stage(graph, stage):
        # Report this stage's stats, then verify every dataset WNID is still present.
        print_graph_stats(graph, stage)
        assert_all_wnids_in_graph(graph, wnids)

    if method == "wordnet":
        G = build_minimal_wordnet_graph(wnids, single_path)
    elif method == "random":
        G = build_random_graph(wnids, seed=seed, branching_factor=branching_factor)
    elif method == "induced":
        # The model is only consulted for the induced method.
        weights = None if model is None else model.state_dict()
        G = build_induced_graph(
            wnids,
            dataset=dataset,
            checkpoint=checkpoint,
            model=arch,
            linkage=induced_linkage,
            affinity=induced_affinity,
            branching_factor=branching_factor,
            state_dict=weights,
        )
    else:
        raise NotImplementedError(f'Method "{method}" not yet handled.')
    check_stage(G, "matched")

    if not no_prune:
        G = prune_single_successor_nodes(G)
        check_stage(G, "pruned")

    if extra > 0:
        G, n_extra, n_imaginary = augment_graph(G, extra, True)
        print(f"[extra] \t Extras: {n_extra} \t Imaginary: {n_imaginary}")
        check_stage(G, "extra")

    path_settings = dict(
        dataset=dataset,
        method=method,
        seed=seed,
        branching_factor=branching_factor,
        extra=extra,
        no_prune=no_prune,
        fname=fname,
        path=path,
        single_path=single_path,
        induced_linkage=induced_linkage,
        induced_affinity=induced_affinity,
        checkpoint=checkpoint,
        arch=arch,
    )
    out_path = get_graph_path_from_args(**path_settings)
    write_graph(G, out_path)
    Colors.green(f"==> Wrote tree to {out_path}")
    return out_path