def print_stats(leaves_seen, wnid_set, tree_name, node_type):
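    """Print how many WNIDs matched the tree and warn if any from wnids.txt are missing."""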
    print(
        f"[{tree_name}] \t {node_type}: {len(leaves_seen)} \t WNIDs missing from {node_type}: {len(wnid_set)}"
    )
    if len(wnid_set):
        Colors.red(
            f"==> Warning: WNIDs in wnid.txt are missing from {tree_name} {node_type}"
        )
def test_hierarchy(args):
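    """Sanity-check a generated hierarchy: every dataset WNID should appear
    among the graph's leaves and nodes, and the graph must have exactly one root."""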
    wnids = get_wnids_from_dataset(args.dataset)
    path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(path))

    G = read_graph(path)

    G_name = Path(path).stem

    leaves_seen, wnid_set1 = match_wnid_leaves(wnids, G, G_name)
    print_stats(leaves_seen, wnid_set1, G_name, 'leaves')

    leaves_seen, wnid_set2 = match_wnid_nodes(wnids, G, G_name)
    print_stats(leaves_seen, wnid_set2, G_name, 'nodes')

    num_roots = len(list(get_roots(G)))
    if num_roots == 1:
        Colors.green('Found just 1 root.')
    else:
        Colors.red(f'Found {num_roots} roots. Should be only 1.')

    if len(wnid_set1) == len(wnid_set2) == 0 and num_roots == 1:
        Colors.green("==> All checks pass!")
    else:
        Colors.red('==> Test failed')
def generate_hierarchy_vis(args):
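    """Read a hierarchy from disk and render it as an interactive HTML tree."""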
    path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(path))

    G = read_graph(path)

    roots = list(get_roots(G))
    num_roots = len(roots)
    root = roots[0]

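    # Load the dataset (test split) when one is specified, so build_tree can
    # attach example images to the leaf nodes (see args.vis_leaf_images below).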
    dataset = None
    if args.dataset:
        cls = getattr(data, args.dataset)
        dataset = cls(root='./data', train=False, download=True)

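    # Work out which nodes and leaves should be colored/highlighted in the
    # rendered tree, based on the color-related args.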
    color_info = get_color_info(G,
                                args.color,
                                color_leaves=not args.vis_no_color_leaves,
                                color_path_to=args.vis_color_path_to,
                                color_nodes=args.vis_color_nodes or ())

    tree = build_tree(G,
                      root,
                      color_info=color_info,
                      force_labels_left=args.vis_force_labels_left or [],
                      dataset=dataset,
                      include_leaf_images=args.vis_leaf_images,
                      image_resize_factor=args.vis_image_resize_factor)
    graph = build_graph(G)

    if num_roots > 1:
        Colors.red(f'Found {num_roots} roots! Should be only 1: {roots}')
    else:
        print(f'Found just {num_roots} root.')

    fname = generate_vis_fname(**vars(args))
    parent = Path(fwd()).parent
    generate_vis(str(parent / 'nbdt/templates/tree-template.html'),
                 tree,
                 'tree',
                 fname,
                 zoom=args.vis_zoom,
                 straight_lines=not args.vis_curved,
                 show_sublabels=args.vis_sublabels,
                 height=args.vis_height,
                 dark=args.vis_dark)
def generate_vis(path_template,
                 data,
                 name,
                 fname,
                 zoom=2,
                 straight_lines=True,
                 show_sublabels=False,
                 height=750,
                 dark=False):
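    """Substitute the CONFIG_* placeholders in the HTML template and write the
    result to out/<fname>-<name>.html."""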
    with open(path_template) as f:
        html = (
            f.read()
            .replace("CONFIG_TREE_DATA", json.dumps([data]))
            .replace("CONFIG_ZOOM", str(zoom))
            .replace("CONFIG_STRAIGHT_LINES", str(straight_lines).lower())
            .replace("CONFIG_SHOW_SUBLABELS", str(show_sublabels).lower())
            .replace("CONFIG_TITLE", fname)
            .replace("CONFIG_VIS_HEIGHT", str(height))
            .replace("CONFIG_BG_COLOR", "#111111" if dark else "#FFFFFF")
            .replace("CONFIG_TEXT_COLOR", "#FFFFFF" if dark else "#000000")
            .replace(
                "CONFIG_TEXT_RECT_COLOR",
                "rgba(17,17,17,0.8)" if dark else "rgba(255,255,255,0.8)"))

    os.makedirs('out', exist_ok=True)
    path_html = f'out/{fname}-{name}.html'
    with open(path_html, 'w') as f:
        f.write(html)

    Colors.green('==> Wrote HTML to {}'.format(path_html))
def generate_hierarchy(dataset,
                       method,
                       seed=0,
                       branching_factor=2,
                       extra=0,
                       no_prune=False,
                       fname='',
                       single_path=False,
                       induced_linkage='ward',
                       induced_affinity='euclidean',
                       checkpoint=None,
                       arch=None,
                       model=None,
                       **kwargs):
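    """Build a hierarchy over the dataset's WNIDs with the given method and
    write the resulting graph to disk."""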
    wnids = get_wnids_from_dataset(dataset)

    if method == 'wordnet':
        # WordNet hierarchies are disabled here; fail loudly instead of
        # falling through to the code below with `G` undefined.
        raise NotImplementedError("the mnist dataset doesn't support wordnet")
        # G = build_minimal_wordnet_graph(wnids, single_path)
    # elif method == 'random':
    #     G = build_random_graph(wnids, seed=seed, branching_factor=branching_factor)
    elif method == 'induced':
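        # The induced hierarchy is built from the model's weights (taken from
        # `checkpoint` or the provided `model`'s state dict) via agglomerative
        # clustering with the given linkage and affinity.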
        G = build_induced_graph(
            wnids,
            dataset=dataset,
            checkpoint=checkpoint,
            model=arch,
            linkage=induced_linkage,
            affinity=induced_affinity,
            branching_factor=branching_factor,
            state_dict=model.state_dict() if model is not None else None)
    else:
        raise NotImplementedError(f'Method "{method}" not yet handled.')
    print_graph_stats(G, 'matched')
    assert_all_wnids_in_graph(G, wnids)

    if not no_prune:
        print("mnist graph does not support prune")
        # G = prune_single_successor_nodes(G)
        # print_graph_stats(G, 'pruned')
        # assert_all_wnids_in_graph(G, wnids)

    if extra > 0:
        print("mnist graph does not support augment")
        # G, n_extra, n_imaginary = augment_graph(G, extra, True)
        # print(f'[extra] \t Extras: {n_extra} \t Imaginary: {n_imaginary}')
        # print_graph_stats(G, 'extra')
        # assert_all_wnids_in_graph(G, wnids)

    path = get_graph_path_from_args(dataset=dataset,
                                    method=method,
                                    seed=seed,
                                    branching_factor=branching_factor,
                                    extra=extra,
                                    no_prune=no_prune,
                                    fname=fname,
                                    single_path=single_path,
                                    induced_linkage=induced_linkage,
                                    induced_affinity=induced_affinity,
                                    checkpoint=checkpoint,
                                    arch=arch)
    write_graph(G, path)

    Colors.green('==> Wrote tree to {}'.format(path))
    name=f'Dataset {args.dataset}',
    keys=data.custom.keys,
    globals=globals())

trainset = dataset(**dataset_kwargs, root='./data', train=True, download=True, transform=transform_train)
testset = dataset(**dataset_kwargs, root='./data', train=False, download=True, transform=transform_test)

assert trainset.classes == testset.classes, (trainset.classes, testset.classes)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
# trainloader = torch.utils.data.DataLoader(trainset)
# testloader = torch.utils.data.DataLoader(testset)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True, num_workers=2)


Colors.cyan(f'Training with dataset {args.dataset} and {len(trainset.classes)} classes')


# Model
print('==> Building model..')
# model = getattr(models, args.arch)
# model_kwargs = {'num_classes': len(trainset.classes) }
#
# if args.pretrained:
#     print('==> Loading pretrained model..')
#     try:
#         net = model(pretrained=True, dataset=args.dataset, **model_kwargs)
#     except TypeError as e:  # likely because `dataset` not allowed arg
#         print(e)
#
#         try:
                   transform=transform_)
testset = dataset(**dataset_kwargs,
                  root='./data',
                  train=False,
                  download=True,
                  transform=transform_)

assert trainset.classes == testset.classes, (trainset.classes, testset.classes)

trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=1)

Colors.cyan(
    f'Testing with dataset {args.dataset} and {len(testset.classes)} classes')


# Model
# TODO(alvin): fix checkpoint structure so that this isn't needed
def load_state_dict(net, state_dict):
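    """Load `state_dict` into `net`, retrying with the DataParallel 'module.'
    prefix stripped from the keys if they do not match."""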
    try:
        net.load_state_dict(state_dict)
    except RuntimeError as e:
        if 'Missing key(s) in state_dict:' not in str(e):
            raise
        # The checkpoint was saved from a DataParallel-wrapped model; strip
        # the leading 'module.' from every key and retry.
        net.load_state_dict({
            key.replace('module.', '', 1): value
            for key, value in state_dict.items()
        })
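# Usage sketch (the checkpoint path and the 'net' key are assumptions, not part
# of this snippet):
#   checkpoint = torch.load('./checkpoint/ckpt.pth', map_location='cpu')
#   load_state_dict(net, checkpoint['net'])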