Esempio n. 1
0
def test_hierarchy(args):
    """Sanity-check the hierarchy graph selected by ``args``.

    Resolves and reads the graph file and checks the dataset's WNIDs against
    the graph's leaves and then against its nodes, printing stats for each.
    It then counts the graph's roots. A green "all checks pass" banner is
    printed only when both matchers report an empty second result and the
    graph has exactly one root; otherwise a red failure banner is printed.
    Returns None.
    """
    dataset_wnids = get_wnids_from_dataset(args.dataset)
    graph_path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(graph_path))

    graph = read_graph(graph_path)
    graph_name = Path(graph_path).stem

    # Each matcher returns a pair; the second element is what decides pass or
    # fail below (presumably the WNIDs that could not be matched — confirm).
    leaf_seen, leaf_leftover = match_wnid_leaves(dataset_wnids, graph, graph_name)
    print_stats(leaf_seen, leaf_leftover, graph_name, 'leaves')

    node_seen, node_leftover = match_wnid_nodes(dataset_wnids, graph, graph_name)
    print_stats(node_seen, node_leftover, graph_name, 'nodes')

    root_count = sum(1 for _ in get_roots(graph))
    single_root = root_count == 1
    if single_root:
        Colors.green('Found just 1 root.')
    else:
        Colors.red(f'Found {root_count} roots. Should be only 1.')

    nothing_left_over = len(leaf_leftover) == 0 and len(node_leftover) == 0
    if nothing_left_over and single_root:
        Colors.green("==> All checks pass!")
    else:
        Colors.red('==> Test failed')
def generate_hierarchy_vis(args):
    """Render the hierarchy graph selected by ``args`` as an HTML tree.

    Resolves and reads the graph file and picks the node to draw from:
    ``args.vis_root`` if set, otherwise the first root yielded by
    ``get_roots``. It builds the tree description with the requested
    colouring, leaf-image and per-node-confidence options, reports the
    graph's root count, and writes the visualization through ``generate_vis``
    using the ``nbdt/templates/tree-template.html`` template.

    Raises:
        ValueError: if no root can be chosen (``args.vis_root`` is unset and
            the graph has no roots), or if the chosen root is not a node of
            the graph.
    """
    path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(path))

    G = read_graph(path)

    # Walk the roots once and reuse the list both for choosing the default
    # root and for the root-count report, instead of calling get_roots() twice.
    roots = list(get_roots(G))
    num_roots = len(roots)
    if args.vis_root:
        root = args.vis_root
    elif roots:
        root = roots[0]
    else:
        # Previously this case escaped as a bare StopIteration from next().
        raise ValueError(
            f'Graph {path} has no root node; set vis_root to choose one.')

    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let an invalid user-supplied root through.
    if root not in G:
        raise ValueError(
            f'Node {root} is not a valid node. Nodes: {G.nodes}')

    # The dataset (constructed with download=True) is only loaded when leaf
    # images are requested.
    dataset = None
    if args.dataset and args.vis_leaf_images:
        cls = getattr(data, args.dataset)
        dataset = cls(root='./data', train=False, download=True)

    color_info = get_color_info(G,
                                args.color,
                                color_leaves=not args.vis_no_color_leaves,
                                color_path_to=args.vis_color_path_to,
                                color_nodes=args.vis_color_nodes or ())

    node_to_conf = generate_node_conf(args.vis_node_conf)

    tree = build_tree(G,
                      root,
                      color_info=color_info,
                      force_labels_left=args.vis_force_labels_left or [],
                      dataset=dataset,
                      include_leaf_images=args.vis_leaf_images,
                      image_resize_factor=args.vis_image_resize_factor,
                      include_fake_sublabels=args.vis_fake_sublabels,
                      node_to_conf=node_to_conf)

    # Exactly one root is expected. Anything else is flagged, including zero
    # roots when vis_root was supplied explicitly.
    if num_roots == 1:
        print(f'Found just {num_roots} root.')
    else:
        Colors.red(f'Found {num_roots} roots! Should be only 1: {roots}')

    fname = generate_vis_fname(**vars(args))
    # The template path is resolved against the parent directory of fwd().
    parent = Path(fwd()).parent
    generate_vis(str(parent / 'nbdt/templates/tree-template.html'),
                 tree,
                 fname,
                 zoom=args.vis_zoom,
                 straight_lines=not args.vis_curved,
                 show_sublabels=args.vis_sublabels,
                 height=args.vis_height,
                 width=args.vis_width,
                 dark=args.vis_dark,
                 margin_top=args.vis_margin_top,
                 margin_left=args.vis_margin_left,
                 hide=args.vis_hide or [],
                 above_dy=args.vis_above_dy,
                 below_dy=args.vis_below_dy,
                 scale=args.vis_scale,
                 root_y=args.vis_root_y,
                 colormap=args.vis_colormap)
 def get_wnid_to_node(path_graph, path_wnids, classes):
     """Map each non-leaf WNID of the graph at ``path_graph`` to a ``Node``.

     Every ``Node`` is built with the same ``classes``, ``path_graph`` and
     ``path_wnids``. Returns a dict keyed by WNID, in the order yielded by
     ``get_non_leaves``.
     """
     graph = read_graph(path_graph)
     return {
         wnid: Node(wnid, classes, path_graph=path_graph, path_wnids=path_wnids)
         for wnid in get_non_leaves(graph)
     }
Esempio n. 4
0
def generate_hierarchy_vis(args):
    """Render the hierarchy graph selected by ``args`` as an HTML tree.

    Resolves and reads the graph file and draws from the first root yielded
    by ``get_roots``. It builds the tree description with the requested
    colouring and leaf-image options, reports the graph's root count, and
    writes the visualization through ``generate_vis`` using the
    ``nbdt/templates/tree-template.html`` template.

    Raises:
        ValueError: if the graph has no root node.
    """
    path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(path))

    G = read_graph(path)

    # Walk the roots once and reuse the list both for choosing the root and
    # for the root-count report, instead of calling get_roots() twice.
    roots = list(get_roots(G))
    num_roots = len(roots)
    if not roots:
        # Previously this case escaped as a bare StopIteration from next().
        raise ValueError(f'Graph {path} has no root node.')
    root = roots[0]

    # Only load the dataset (constructed with download=True) when leaf images
    # are requested. Presumably build_tree only uses it for those images,
    # since it receives dataset=None whenever args.dataset is unset — confirm.
    dataset = None
    if args.dataset and args.vis_leaf_images:
        cls = getattr(data, args.dataset)
        dataset = cls(root='./data', train=False, download=True)

    color_info = get_color_info(G,
                                args.color,
                                color_leaves=not args.vis_no_color_leaves,
                                color_path_to=args.vis_color_path_to,
                                color_nodes=args.vis_color_nodes or ())

    tree = build_tree(G,
                      root,
                      color_info=color_info,
                      force_labels_left=args.vis_force_labels_left or [],
                      dataset=dataset,
                      include_leaf_images=args.vis_leaf_images,
                      image_resize_factor=args.vis_image_resize_factor)

    if num_roots > 1:
        Colors.red(f'Found {num_roots} roots! Should be only 1: {roots}')
    else:
        print(f'Found just {num_roots} root.')

    fname = generate_vis_fname(**vars(args))
    # The template path is resolved against the parent directory of fwd().
    parent = Path(fwd()).parent
    generate_vis(str(parent / 'nbdt/templates/tree-template.html'),
                 tree,
                 'tree',
                 fname,
                 zoom=args.vis_zoom,
                 straight_lines=not args.vis_curved,
                 show_sublabels=args.vis_sublabels,
                 height=args.vis_height,
                 dark=args.vis_dark)
    def __init__(self,
                 wnid,
                 classes,
                 path_graph,
                 path_wnids,
                 other_class=False):
        """Build the per-node view of the hierarchy for the node ``wnid``.

        Args:
            wnid: WNID of this node; it must not be a leaf of the graph.
            classes: original class names, stored as ``original_classes``.
            path_graph: path of the graph file, loaded with ``read_graph``.
            path_wnids: path of the WNID list, loaded with ``get_wnids``.
            other_class: when true and this node is neither the root nor a
                leaf, one extra output class is reserved (``has_other``).

        Raises:
            AssertionError: if ``wnid`` is a leaf, or if ``build_classes``
                returns a different number of names than ``num_classes``.
        """
        self.path_graph = path_graph
        self.path_wnids = path_wnids

        self.wnid = wnid
        self.wnids = get_wnids(path_wnids)
        # NOTE(review): the graph file is re-read for every instance.
        self.G = read_graph(path_graph)
        self.synset = wnid_to_synset(wnid)

        self.original_classes = classes
        # Counted from the WNID file, not from `classes`; presumably the two
        # have the same length — confirm.
        self.num_original_classes = len(self.wnids)

        # is_leaf()/is_root()/get_children() below presumably read self.G and
        # self.wnid, so they must stay after those assignments.
        assert not self.is_leaf(), 'Cannot build dataset for leaf'
        self.has_other = other_class and not (self.is_root() or self.is_leaf())
        # One output class per child, plus one "other" class when enabled.
        self.num_children = len(self.get_children())
        self.num_classes = self.num_children + int(self.has_other)

        self.old_to_new_classes, self.new_to_old_classes = \
            self.build_class_mappings()
        self.classes = self.build_classes()

        assert len(self.classes) == self.num_classes, (
            f'Number of classes {self.num_classes} does not equal number of '
            f'class names found ({len(self.classes)}): {self.classes}')

        # get_children() is called again here and materialised as a list.
        self.children = list(self.get_children())
        self.leaves = list(self.get_leaves())
        self.num_leaves = len(self.leaves)

        # Start unset; presumably filled in lazily elsewhere in the class.
        self._probabilities = None
        self._class_weights = None