コード例 #1
0
def test_hierarchy(args):
    """Sanity-check a saved hierarchy against the dataset's class wnids.

    Loads the graph at the path derived from ``args``. It then reports wnid
    matches against the graph's leaves and against all of its nodes, and
    counts the graph's roots. All outcomes are printed; nothing is returned.
    The overall check passes only when both match calls return empty
    collections and there is exactly one root.

    Args:
        args: parsed options. ``args.dataset`` selects the wnids, and the whole
            namespace is forwarded to ``get_graph_path_from_args``.
    """
    class_wnids = get_wnids_from_dataset(args.dataset)
    graph_path = get_graph_path_from_args(**vars(args))
    print("==> Reading from {}".format(graph_path))

    graph = read_graph(graph_path)
    graph_name = Path(graph_path).stem

    # NOTE(review): the second value returned by each matcher is presumably the
    # set of wnids left unmatched — the check below requires it to be empty.
    seen, leftover_leaves = match_wnid_leaves(class_wnids, graph, graph_name)
    print_stats(seen, leftover_leaves, graph_name, "leaves")

    seen, leftover_nodes = match_wnid_nodes(class_wnids, graph, graph_name)
    print_stats(seen, leftover_nodes, graph_name, "nodes")

    root_count = sum(1 for _ in get_roots(graph))
    has_single_root = root_count == 1
    if has_single_root:
        Colors.green("Found just 1 root.")
    else:
        Colors.red(f"Found {root_count} roots. Should be only 1.")

    everything_matched = len(leftover_leaves) == 0 and len(leftover_nodes) == 0
    if everything_matched and has_single_root:
        Colors.green("==> All checks pass!")
    else:
        Colors.red("==> Test failed")
コード例 #2
0
def generate_hierarchy_vis(args):
    """Render an HTML tree visualization of the hierarchy selected by ``args``.

    Steps:

    - Read the graph at the path derived from ``args``.
    - Choose the root to draw from: ``args.vis_root`` when set, otherwise the
      first root reported by ``get_roots``.
    - Build the tree payload with the requested coloring, node-confidence and
      leaf-image options.
    - Write it through ``nbdt/templates/tree-template.html`` to the name
      produced by ``generate_vis_fname``.

    The number of roots found is reported. Having more than one is flagged in
    red but does not stop rendering.

    Args:
        args: parsed options. ``dataset``, ``color`` and the ``vis_*``
            attributes read below must be present.

    Raises:
        ValueError: if no ``vis_root`` was given and the graph has no root,
            or if the chosen root is not a node of the graph.
    """
    path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(path))

    G = read_graph(path)

    # Compute the roots once and reuse them for root selection and for the
    # root-count report below, rather than traversing the graph twice.
    roots = list(get_roots(G))
    num_roots = len(roots)
    if args.vis_root:
        root = args.vis_root
    elif roots:
        root = roots[0]
    else:
        # Fail with a clear message instead of leaking a bare StopIteration.
        raise ValueError(f'Graph read from {path} has no root node.')

    # vis_root is caller-supplied. Validate explicitly, because an ``assert``
    # would be stripped under ``python -O``.
    if root not in G:
        raise ValueError(f'Node {root} is not a valid node. Nodes: {G.nodes}')

    # Only instantiate the dataset when leaf images are actually requested.
    dataset = None
    if args.dataset and args.vis_leaf_images:
        cls = getattr(data, args.dataset)
        dataset = cls(root='./data', train=False, download=True)

    color_info = get_color_info(G,
                                args.color,
                                color_leaves=not args.vis_no_color_leaves,
                                color_path_to=args.vis_color_path_to,
                                color_nodes=args.vis_color_nodes or ())

    node_to_conf = generate_node_conf(args.vis_node_conf)

    tree = build_tree(G,
                      root,
                      color_info=color_info,
                      force_labels_left=args.vis_force_labels_left or [],
                      dataset=dataset,
                      include_leaf_images=args.vis_leaf_images,
                      image_resize_factor=args.vis_image_resize_factor,
                      include_fake_sublabels=args.vis_fake_sublabels,
                      node_to_conf=node_to_conf)
    # NOTE(review): the value returned by build_graph is never used in this
    # function. The call is kept only in case it has side effects; confirm
    # and remove it.
    build_graph(G)

    if num_roots > 1:
        Colors.red(f'Found {num_roots} roots! Should be only 1: {roots}')
    else:
        print(f'Found just {num_roots} root.')

    fname = generate_vis_fname(**vars(args))
    parent = Path(fwd()).parent
    generate_vis(str(parent / 'nbdt/templates/tree-template.html'),
                 tree,
                 fname,
                 zoom=args.vis_zoom,
                 straight_lines=not args.vis_curved,
                 show_sublabels=args.vis_sublabels,
                 height=args.vis_height,
                 width=args.vis_width,
                 dark=args.vis_dark,
                 margin_top=args.vis_margin_top,
                 margin_left=args.vis_margin_left,
                 hide=args.vis_hide or [],
                 above_dy=args.vis_above_dy,
                 below_dy=args.vis_below_dy,
                 scale=args.vis_scale,
                 root_y=args.vis_root_y,
                 colormap=args.vis_colormap)
コード例 #3
0
def generate_hierarchy_vis(args):
    """Render an HTML tree visualization of the hierarchy selected by ``args``.

    Steps:

    - Read the graph at the path derived from ``args``.
    - Draw from the first root reported by ``get_roots``.
    - Build the tree payload with the requested coloring and leaf-image
      options.
    - Write it through ``nbdt/templates/tree-template.html`` to the name
      produced by ``generate_vis_fname``.

    Having more than one root is flagged in red but does not stop rendering.

    Args:
        args: parsed options. ``dataset``, ``color`` and the ``vis_*``
            attributes read below must be present.

    Raises:
        ValueError: if the graph has no root node.
    """
    path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(path))

    G = read_graph(path)

    # Compute the roots once and reuse them, rather than traversing the graph
    # a second time just to pick the first one.
    roots = list(get_roots(G))
    num_roots = len(roots)
    if not roots:
        # Fail with a clear message instead of leaking a bare StopIteration.
        raise ValueError(f'Graph read from {path} has no root node.')
    root = roots[0]

    # Only instantiate the dataset when leaf images are actually requested.
    dataset = None
    if args.dataset and args.vis_leaf_images:
        cls = getattr(data, args.dataset)
        dataset = cls(root='./data', train=False, download=True)

    color_info = get_color_info(G,
                                args.color,
                                color_leaves=not args.vis_no_color_leaves,
                                color_path_to=args.vis_color_path_to,
                                color_nodes=args.vis_color_nodes or ())

    tree = build_tree(G,
                      root,
                      color_info=color_info,
                      force_labels_left=args.vis_force_labels_left or [],
                      dataset=dataset,
                      include_leaf_images=args.vis_leaf_images,
                      image_resize_factor=args.vis_image_resize_factor)
    # NOTE(review): the value returned by build_graph is never used in this
    # function. The call is kept only in case it has side effects; confirm
    # and remove it.
    build_graph(G)

    if num_roots > 1:
        Colors.red(f'Found {num_roots} roots! Should be only 1: {roots}')
    else:
        print(f'Found just {num_roots} root.')

    fname = generate_vis_fname(**vars(args))
    parent = Path(fwd()).parent
    # NOTE(review): this generate_vis variant takes an extra positional
    # argument ('tree') before fname; confirm against its signature.
    generate_vis(str(parent / 'nbdt/templates/tree-template.html'),
                 tree,
                 'tree',
                 fname,
                 zoom=args.vis_zoom,
                 straight_lines=not args.vis_curved,
                 show_sublabels=args.vis_sublabels,
                 height=args.vis_height,
                 dark=args.vis_dark)
コード例 #4
0
def generate_hierarchy_vis(args):
    """Render an HTML visualization of the hierarchy selected by ``args``.

    Reads the graph at the path derived from ``args``. When both
    ``args.dataset`` and ``args.vis_leaf_images`` are set, it also
    instantiates the test split of that dataset. It then delegates to
    ``generate_hierarchy_vis_from``, targeting
    ``./<generate_vis_fname(...)>.html``.

    ``args`` itself is left unmodified.

    Args:
        args: parsed options. All attributes except ``dataset`` and ``fname``
            are forwarded as keyword arguments to
            ``generate_hierarchy_vis_from``.

    Returns:
        Whatever ``generate_hierarchy_vis_from`` returns.
    """
    path_hie = get_graph_path_from_args(**vars(args))
    print("==> Reading from {}".format(path_hie))
    G = read_graph(path_hie)

    path_html = f"./{generate_vis_fname(**vars(args))}.html"
    # vars() returns the namespace's own __dict__, so popping keys from it
    # would delete attributes from the caller's ``args``. Work on a copy.
    kwargs = dict(vars(args))

    dataset = None
    if args.dataset and args.vis_leaf_images:
        cls = getattr(data, args.dataset)
        dataset = cls(root="./data", train=False, download=True)

    # Dataset is passed positionally below. NOTE(review): both keys are
    # presumably dropped to avoid clashing with the callee's own parameters.
    kwargs.pop('dataset', None)
    kwargs.pop('fname', None)
    return generate_hierarchy_vis_from(G,
                                       dataset,
                                       path_html,
                                       verbose=True,
                                       **kwargs)
コード例 #5
0
def generate_hierarchy(
    dataset,
    method,
    seed=0,
    branching_factor=2,
    extra=0,
    no_prune=False,
    fname="",
    path="",
    single_path=False,
    induced_linkage="ward",
    induced_affinity="euclidean",
    checkpoint=None,
    arch=None,
    model=None,
    **kwargs,
):
    """Build a class hierarchy for ``dataset``, post-process it, and save it.

    The initial graph comes from the strategy named by ``method``:
    "wordnet", "random" or "induced". For the induced strategy, the weights of
    ``model`` are forwarded when a model is given.

    Post-processing then runs in this order:

    1. Single-successor nodes are pruned, unless ``no_prune`` is set.
    2. When ``extra`` > 0, the graph is augmented.

    After every stage, graph stats are printed and all dataset wnids are
    asserted to still be present. Extra keyword arguments are accepted and
    ignored.

    Returns:
        The path the graph was written to, as built by
        ``get_graph_path_from_args`` from the generation options.

    Raises:
        NotImplementedError: if ``method`` names no known strategy.
    """
    class_wnids = get_wnids_from_dataset(dataset)

    def _check_stage(current, stage):
        # Report the graph at this stage and make sure no class was dropped.
        print_graph_stats(current, stage)
        assert_all_wnids_in_graph(current, class_wnids)

    if method == "wordnet":
        graph = build_minimal_wordnet_graph(class_wnids, single_path)
    elif method == "random":
        graph = build_random_graph(
            class_wnids, seed=seed, branching_factor=branching_factor
        )
    elif method == "induced":
        weights = None if model is None else model.state_dict()
        graph = build_induced_graph(
            class_wnids,
            dataset=dataset,
            checkpoint=checkpoint,
            model=arch,
            linkage=induced_linkage,
            affinity=induced_affinity,
            branching_factor=branching_factor,
            state_dict=weights,
        )
    else:
        raise NotImplementedError(f'Method "{method}" not yet handled.')
    _check_stage(graph, "matched")

    if not no_prune:
        graph = prune_single_successor_nodes(graph)
        _check_stage(graph, "pruned")

    if extra > 0:
        graph, num_extra, num_imaginary = augment_graph(graph, extra, True)
        print(f"[extra] \t Extras: {num_extra} \t Imaginary: {num_imaginary}")
        _check_stage(graph, "extra")

    path_options = {
        "dataset": dataset,
        "method": method,
        "seed": seed,
        "branching_factor": branching_factor,
        "extra": extra,
        "no_prune": no_prune,
        "fname": fname,
        "path": path,
        "single_path": single_path,
        "induced_linkage": induced_linkage,
        "induced_affinity": induced_affinity,
        "checkpoint": checkpoint,
        "arch": arch,
    }
    out_path = get_graph_path_from_args(**path_options)
    write_graph(graph, out_path)

    Colors.green("==> Wrote tree to {}".format(out_path))
    return out_path