def print_stats(leaves_seen, wnid_set, tree_name, node_type):
    """Report how many `node_type` entries were matched and how many WNIDs are missing.

    Prints a one-line summary and, when any WNIDs from wnid.txt are absent
    from the graph, emits a red warning via Colors.
    """
    num_missing = len(wnid_set)
    print(
        f"[{tree_name}] \t {node_type}: {len(leaves_seen)} \t "
        f"WNIDs missing from {node_type}: {num_missing}"
    )
    if num_missing:
        Colors.red(
            f"==> Warning: WNIDs in wnid.txt are missing from {tree_name} {node_type}"
        )
def test_hierarchy(args):
    """Run sanity checks on the hierarchy graph selected by `args`.

    Verifies that every dataset WNID appears among the graph's leaves and
    among its nodes, and that the graph has exactly one root; prints a
    green/red summary of the outcome.
    """
    wnids = get_wnids_from_dataset(args.dataset)
    path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(path))

    G = read_graph(path)
    G_name = Path(path).stem

    # Check leaf coverage, then node coverage, against the dataset WNIDs.
    leaves_seen, missing_from_leaves = match_wnid_leaves(wnids, G, G_name)
    print_stats(leaves_seen, missing_from_leaves, G_name, 'leaves')

    leaves_seen, missing_from_nodes = match_wnid_nodes(wnids, G, G_name)
    print_stats(leaves_seen, missing_from_nodes, G_name, 'nodes')

    num_roots = len(list(get_roots(G)))
    if num_roots == 1:
        Colors.green('Found just 1 root.')
    else:
        Colors.red(f'Found {num_roots} roots. Should be only 1.')

    all_ok = not missing_from_leaves and not missing_from_nodes and num_roots == 1
    if all_ok:
        Colors.green("==> All checks pass!")
    else:
        Colors.red('==> Test failed')
def generate_hierarchy_vis(args):
    """Build and write an HTML visualization of the hierarchy graph named by `args`.

    Reads the graph from disk, optionally loads the dataset (for leaf images),
    computes node colors, builds the displayable tree, and renders it through
    the bundled HTML template via `generate_vis`.

    Raises:
        ValueError: if the graph has no roots at all.
    """
    path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(path))
    G = read_graph(path)

    roots = list(get_roots(G))
    num_roots = len(roots)
    if not roots:
        # Fail fast with a clear message instead of the confusing
        # StopIteration the old `next(get_roots(G))` produced.
        raise ValueError(f'Graph at {path} has no roots.')
    # Reuse the already-materialized list rather than re-walking the graph.
    root = roots[0]

    dataset = None
    if args.dataset:
        cls = getattr(data, args.dataset)
        dataset = cls(root='./data', train=False, download=True)

    color_info = get_color_info(
        G,
        args.color,
        color_leaves=not args.vis_no_color_leaves,
        color_path_to=args.vis_color_path_to,
        color_nodes=args.vis_color_nodes or ())

    tree = build_tree(
        G,
        root,
        color_info=color_info,
        force_labels_left=args.vis_force_labels_left or [],
        dataset=dataset,
        include_leaf_images=args.vis_leaf_images,
        image_resize_factor=args.vis_image_resize_factor)
    # NOTE(review): result is unused here; kept in case build_graph has
    # side effects — confirm and drop if it is pure.
    graph = build_graph(G)

    if num_roots > 1:
        Colors.red(f'Found {num_roots} roots! Should be only 1: {roots}')
    else:
        print(f'Found just {num_roots} root.')

    fname = generate_vis_fname(**vars(args))
    parent = Path(fwd()).parent
    generate_vis(
        str(parent / 'nbdt/templates/tree-template.html'),
        tree,
        'tree',
        fname,
        zoom=args.vis_zoom,
        straight_lines=not args.vis_curved,
        show_sublabels=args.vis_sublabels,
        height=args.vis_height,
        dark=args.vis_dark)
def generate_vis(path_template, data, name, fname, zoom=2, straight_lines=True,
                 show_sublabels=False, height=750, dark=False):
    """Render the tree visualization HTML from a template and write it to out/.

    Args:
        path_template: Path to the HTML template containing CONFIG_* placeholders.
        data: Tree data serialized into the template as a one-element JSON list.
        name: Suffix used in the output filename (e.g. 'tree').
        fname: Base output filename; also substituted for CONFIG_TITLE.
        zoom: Initial zoom level for the visualization.
        straight_lines: Draw straight (True) or curved (False) edges.
        show_sublabels: Whether to display sublabels on nodes.
        height: Pixel height of the visualization.
        dark: Use a dark color scheme when True.

    Writes `out/{fname}-{name}.html`, creating the `out/` directory if needed.
    """
    # Substitutions are applied in insertion order, matching the original
    # chained .replace() sequence (CONFIG_TREE_DATA first).
    substitutions = {
        "CONFIG_TREE_DATA": json.dumps([data]),
        "CONFIG_ZOOM": str(zoom),
        "CONFIG_STRAIGHT_LINES": str(straight_lines).lower(),
        "CONFIG_SHOW_SUBLABELS": str(show_sublabels).lower(),
        "CONFIG_TITLE": fname,
        "CONFIG_VIS_HEIGHT": str(height),
        "CONFIG_BG_COLOR": "#111111" if dark else "#FFFFFF",
        "CONFIG_TEXT_COLOR": '#FFFFFF' if dark else '#000000',
        "CONFIG_TEXT_RECT_COLOR":
            "rgba(17,17,17,0.8)" if dark else "rgba(255,255,255,0.8)",
    }

    # Explicit encoding: the template is UTF-8 HTML; don't depend on the
    # platform's locale default.
    with open(path_template, encoding='utf-8') as f:
        html = f.read()
    for placeholder, value in substitutions.items():
        html = html.replace(placeholder, value)

    os.makedirs('out', exist_ok=True)
    path_html = f'out/{fname}-{name}.html'
    with open(path_html, 'w', encoding='utf-8') as f:
        f.write(html)
    Colors.green('==> Wrote HTML to {}'.format(path_html))
def generate_hierarchy(dataset, method, seed=0, branching_factor=2, extra=0,
                       no_prune=False, fname='', single_path=False,
                       induced_linkage='ward', induced_affinity='euclidean',
                       checkpoint=None, arch=None, model=None, **kwargs):
    """Generate a hierarchy graph for `dataset` and write it to disk.

    Only the 'induced' method is supported for this (MNIST) dataset; pruning
    and graph augmentation are intentionally disabled.

    Args:
        dataset: Dataset name used to look up WNIDs and build the output path.
        method: Hierarchy construction method; only 'induced' is supported.
        seed, branching_factor, extra, no_prune, fname, single_path,
        induced_linkage, induced_affinity, checkpoint, arch: Forwarded to
            graph construction and output-path generation.
        model: Optional model whose state_dict seeds the induced hierarchy.
        **kwargs: Ignored; accepted for call-site compatibility.

    Raises:
        NotImplementedError: for any `method` other than 'induced'.
    """
    wnids = get_wnids_from_dataset(dataset)

    if method == 'induced':
        G = build_induced_graph(
            wnids,
            dataset=dataset,
            checkpoint=checkpoint,
            model=arch,
            linkage=induced_linkage,
            affinity=induced_affinity,
            branching_factor=branching_factor,
            state_dict=model.state_dict() if model is not None else None)
    elif method == 'wordnet':
        # Bug fix: this branch used to only print a warning and fall through
        # with G unbound, crashing with NameError below. Fail fast instead,
        # consistent with the unknown-method branch.
        raise NotImplementedError("the mnist dataset doesn't support wordnet")
    else:
        raise NotImplementedError(f'Method "{method}" not yet handled.')

    print_graph_stats(G, 'matched')
    assert_all_wnids_in_graph(G, wnids)

    if not no_prune:
        # Pruning is intentionally a no-op for the MNIST graph.
        print("mnist graph does not support prune")
    if extra > 0:
        # Augmentation is intentionally a no-op for the MNIST graph.
        print("mnist graph does not support augment")

    path = get_graph_path_from_args(
        dataset=dataset, method=method, seed=seed,
        branching_factor=branching_factor, extra=extra, no_prune=no_prune,
        fname=fname, single_path=single_path, induced_linkage=induced_linkage,
        induced_affinity=induced_affinity, checkpoint=checkpoint, arch=arch)
    write_graph(G, path)
    Colors.green('==> Wrote tree to {}'.format(path))
name=f'Dataset {args.dataset}', keys=data.custom.keys, globals=globals()) trainset = dataset(**dataset_kwargs, root='./data', train=True, download=True, transform=transform_train) testset = dataset(**dataset_kwargs, root='./data', train=False, download=True, transform=transform_test) assert trainset.classes == testset.classes, (trainset.classes, testset.classes) trainloader = torch.utils.data.DataLoader(trainset,batch_size=args.batch_size, shuffle=True, num_workers=2) # trainloader = torch.utils.data.DataLoader(trainset) # testloader = torch.utils.data.DataLoader(testset) testloader = torch.utils.data.DataLoader(testset,batch_size=100, shuffle=True, num_workers=2) Colors.cyan(f'Training with dataset {args.dataset} and {len(trainset.classes)} classes') # Model print('==> Building model..') # model = getattr(models, args.arch) # model_kwargs = {'num_classes': len(trainset.classes) } # # if args.pretrained: # print('==> Loading pretrained model..') # try: # net = model(pretrained=True, dataset=args.dataset, **model_kwargs) # except TypeError as e: # likely because `dataset` not allowed arg # print(e) # # try:
transform=transform_) testset = dataset(**dataset_kwargs, root='./data', train=False, download=True, transform=transform_) assert trainset.classes == testset.classes, (trainset.classes, testset.classes) trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2) testloader = torch.utils.data.DataLoader(testset, batch_size=1) Colors.cyan( f'Testing with dataset {args.dataset} and {len(testset.classes)} classes') # Model # TODO(alvin): fix checkpoint structure so that this isn't neededd def load_state_dict(net, state_dict): try: net.load_state_dict(state_dict) except RuntimeError as e: if 'Missing key(s) in state_dict:' in str(e): net.load_state_dict({ key.replace('module.', '', 1): value for key, value in state_dict.items() })