def generate_hierarchy_vis(args):
    """Render the hierarchy graph selected by *args* as an interactive HTML tree.

    The graph is read from the path derived from ``args``. The tree is drawn
    from ``args.vis_root`` when given, otherwise from the first root found.
    When both ``args.dataset`` and ``args.vis_leaf_images`` are set, the test
    split of that dataset is loaded so leaf images can be embedded. A warning
    is printed if the graph has more than one root.
    """
    graph_path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(graph_path))

    G = read_graph(graph_path)

    all_roots = list(get_roots(G))
    root_count = len(all_roots)
    # An explicit root wins; otherwise fall back to the first root yielded.
    root = args.vis_root if args.vis_root else next(get_roots(G))

    assert root in G, f'Node {root} is not a valid node. Nodes: {G.nodes}'

    dataset = None
    wants_leaf_images = bool(args.dataset and args.vis_leaf_images)
    if wants_leaf_images:
        dataset_cls = getattr(data, args.dataset)
        dataset = dataset_cls(root='./data', train=False, download=True)

    color_info = get_color_info(
        G,
        args.color,
        color_leaves=not args.vis_no_color_leaves,
        color_path_to=args.vis_color_path_to,
        color_nodes=args.vis_color_nodes or (),
    )

    node_to_conf = generate_node_conf(args.vis_node_conf)

    tree_options = dict(
        color_info=color_info,
        force_labels_left=args.vis_force_labels_left or [],
        dataset=dataset,
        include_leaf_images=args.vis_leaf_images,
        image_resize_factor=args.vis_image_resize_factor,
        include_fake_sublabels=args.vis_fake_sublabels,
        node_to_conf=node_to_conf,
    )
    tree = build_tree(G, root, **tree_options)
    # The returned graph is not used below; the call is kept as-is.
    build_graph(G)

    if root_count <= 1:
        print(f'Found just {root_count} root.')
    else:
        Colors.red(f'Found {root_count} roots! Should be only 1: {all_roots}')

    fname = generate_vis_fname(**vars(args))
    template_path = Path(fwd()).parent / 'nbdt/templates/tree-template.html'
    vis_options = dict(
        zoom=args.vis_zoom,
        straight_lines=not args.vis_curved,
        show_sublabels=args.vis_sublabels,
        height=args.vis_height,
        width=args.vis_width,
        dark=args.vis_dark,
        margin_top=args.vis_margin_top,
        margin_left=args.vis_margin_left,
        hide=args.vis_hide or [],
        above_dy=args.vis_above_dy,
        below_dy=args.vis_below_dy,
        scale=args.vis_scale,
        root_y=args.vis_root_y,
        colormap=args.vis_colormap,
    )
    generate_vis(str(template_path), tree, fname, **vis_options)
def print_stats(leaves_seen, wnid_set, tree_name, node_type):
    """Print how many *node_type* entries matched and how many WNIDs are missing.

    :param leaves_seen: matched entries (only their count is printed)
    :param wnid_set: WNIDs that were not found; a red warning is emitted when
        this is non-empty
    :param tree_name: label used as the line prefix
    :param node_type: human-readable kind of entry, e.g. "leaves" or "nodes"
    """
    missing_count = len(wnid_set)
    summary = (f"[{tree_name}] \t {node_type}: {len(leaves_seen)} \t "
               f"WNIDs missing from {node_type}: {missing_count}")
    print(summary)
    if missing_count:
        Colors.red(
            f"==> Warning: WNIDs in wnid.txt are missing from {tree_name} {node_type}"
        )
Example #3
0
    def f(**kwargs):
        """Call ``init`` with ``optional_kwargs`` merged in, retrying without them.

        If the first call raises ``TypeError``, the error is printed and
        ``init`` is called again with only ``kwargs``. This is typically
        because ``init`` does not accept one of the optional keyword arguments,
        e.g. ``dataset``. If the retry fails as well, the error is reported and
        the process terminates with a non-zero exit status.

        Returns whatever ``init`` returns.
        """
        try:
            net = init(**optional_kwargs, **kwargs)
        except TypeError as e:  # likely because `dataset` not allowed arg
            print(e)

            try:
                net = init(**kwargs)
            except Exception as e:
                Colors.red(f"Fatal error: {e}")
                # The `site`-provided exit() exits with status 0, which would
                # report a fatal error as success; fail explicitly instead.
                raise SystemExit(1) from e
        return net
def test_hierarchy(args):
    """Check that the hierarchy selected by *args* covers the dataset's WNIDs.

    Reports, for both leaves and all nodes, how many dataset WNIDs are
    matched or missing, then checks that the graph has exactly one root.
    Prints a green success message only when nothing is missing and there is
    a single root; otherwise prints a red failure message.
    """
    wnids = get_wnids_from_dataset(args.dataset)
    graph_path = get_graph_path_from_args(**vars(args))
    print("==> Reading from {}".format(graph_path))

    G = read_graph(graph_path)
    G_name = Path(graph_path).stem

    seen_leaves, missing_from_leaves = match_wnid_leaves(wnids, G, G_name)
    print_stats(seen_leaves, missing_from_leaves, G_name, "leaves")

    seen_nodes, missing_from_nodes = match_wnid_nodes(wnids, G, G_name)
    print_stats(seen_nodes, missing_from_nodes, G_name, "nodes")

    num_roots = sum(1 for _ in get_roots(G))
    has_single_root = num_roots == 1
    if has_single_root:
        Colors.green("Found just 1 root.")
    else:
        Colors.red(f"Found {num_roots} roots. Should be only 1.")

    nothing_missing = len(missing_from_leaves) == 0 and len(missing_from_nodes) == 0
    if nothing_missing and has_single_root:
        Colors.green("==> All checks pass!")
    else:
        Colors.red("==> Test failed")
Example #5
0
def generate_hierarchy_vis(args):
    """Render the hierarchy graph selected by *args* as an HTML tree.

    The tree is drawn from the first root of the graph. When
    ``args.dataset`` is set, its test split is loaded and passed to
    ``build_tree``. The result is handed to ``generate_vis`` under the name
    ``'tree'``. A warning is printed if the graph has more than one root.
    """
    graph_path = get_graph_path_from_args(**vars(args))
    print('==> Reading from {}'.format(graph_path))

    G = read_graph(graph_path)

    all_roots = list(get_roots(G))
    root_count = len(all_roots)
    root = next(get_roots(G))

    dataset = None
    if args.dataset:
        dataset = getattr(data, args.dataset)(
            root='./data', train=False, download=True)

    color_info = get_color_info(
        G,
        args.color,
        color_leaves=not args.vis_no_color_leaves,
        color_path_to=args.vis_color_path_to,
        color_nodes=args.vis_color_nodes or (),
    )

    tree = build_tree(
        G,
        root,
        color_info=color_info,
        force_labels_left=args.vis_force_labels_left or [],
        dataset=dataset,
        include_leaf_images=args.vis_leaf_images,
        image_resize_factor=args.vis_image_resize_factor,
    )
    # The returned graph is not used below; the call is kept as-is.
    build_graph(G)

    if root_count <= 1:
        print(f'Found just {root_count} root.')
    else:
        Colors.red(f'Found {root_count} roots! Should be only 1: {all_roots}')

    fname = generate_vis_fname(**vars(args))
    template_path = Path(fwd()).parent / 'nbdt/templates/tree-template.html'
    generate_vis(
        str(template_path),
        tree,
        'tree',
        fname,
        zoom=args.vis_zoom,
        straight_lines=not args.vis_curved,
        show_sublabels=args.vis_sublabels,
        height=args.vis_height,
        dark=args.vis_dark,
    )
Example #6
0
def generate_vis(path_template,
                 data,
                 name,
                 fname,
                 zoom=2,
                 straight_lines=True,
                 show_sublabels=False,
                 height=750,
                 dark=False):
    """Fill the HTML tree template and write it to ``out/{fname}-{name}.html``.

    Each ``CONFIG_*`` placeholder in the template at *path_template* is
    replaced by the matching setting. Booleans are lower-cased and numbers
    are converted with ``str``. *data* is JSON-encoded inside a one-element
    list. *fname* is also used as the page title. *dark* switches the
    background, text and text-rectangle colours. The ``out`` directory is
    created (relative to the working directory) if needed.
    """
    # Ordered (placeholder, value) pairs. The tree JSON goes last so that
    # placeholder-like text inside the data is never rewritten by a later
    # substitution.
    replacements = (
        ("CONFIG_ZOOM", str(zoom)),
        ("CONFIG_STRAIGHT_LINES", str(straight_lines).lower()),
        ("CONFIG_SHOW_SUBLABELS", str(show_sublabels).lower()),
        ("CONFIG_TITLE", fname),
        ("CONFIG_VIS_HEIGHT", str(height)),
        ("CONFIG_BG_COLOR", "#111111" if dark else "#FFFFFF"),
        ("CONFIG_TEXT_COLOR", '#FFFFFF' if dark else '#000000'),
        ("CONFIG_TEXT_RECT_COLOR",
         "rgba(17,17,17,0.8)" if dark else "rgba(255,255,255,0.8)"),
        ("CONFIG_TREE_DATA", json.dumps([data])),
    )

    with open(path_template, encoding="utf-8") as f:
        html = f.read()
    for placeholder, value in replacements:
        html = html.replace(placeholder, value)

    os.makedirs('out', exist_ok=True)
    path_html = f'out/{fname}-{name}.html'
    with open(path_html, 'w', encoding="utf-8") as f:
        f.write(html)

    Colors.green('==> Wrote HTML to {}'.format(path_html))
Example #7
0
    def test(epoch, checkpoint=True):
        """Run one evaluation pass over ``testloader`` and maybe save a checkpoint.

        Args:
            epoch: epoch number; stored in the saved checkpoint state.
            checkpoint: when False, no checkpoint is written regardless of
                accuracy.

        Rebinds the enclosing ``best_acc`` whenever a checkpoint is saved.
        """
        nonlocal best_acc
        net.eval()
        test_loss = 0
        metric.clear()
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)

                # Loss and metric are skipped when test-time evaluation is
                # disabled (the test set's labels may not match the model's).
                if not args.disable_test_eval:
                    loss = criterion(outputs, targets)
                    test_loss += loss.item()
                    metric.forward(outputs, targets)
                # NOTE(review): rebuilt on every batch; presumably undoes the
                # validation transform so the analyzer sees un-normalized
                # inputs. Could be hoisted out of the loop if it is stateless.
                transform = testset.transform_val_inverse().to(device)
                stat = analyzer.update_batch(outputs, targets,
                                             transform(inputs))

                progress_bar(
                    batch_idx,
                    len(testloader),
                    "Loss: %.3f | Acc: %.3f%% (%d/%d) %s" % (
                        test_loss / (batch_idx + 1),
                        100.0 * metric.report(),
                        metric.correct,
                        metric.total,
                        f"| {analyzer.name}: {stat}" if stat else "",
                    ),
                )

        # Save checkpoint.
        acc = 100.0 * metric.report()
        print("Accuracy: {}, {}/{} | Best Accurracy: {}".format(
            acc, metric.correct, metric.total, best_acc))
        if acc > best_acc and checkpoint:
            Colors.green(f"Saving to {checkpoint_fname} ({acc})..")
            # Checkpoint payload: model weights plus the accuracy and epoch
            # at which they were obtained.
            state = {
                "net": net.state_dict(),
                "acc": acc,
                "epoch": epoch,
            }
            os.makedirs("checkpoint", exist_ok=True)
            torch.save(state, f"./checkpoint/{checkpoint_fname}.pth")
            best_acc = acc
Example #8
0
def load_state_dict_from_key(
    keys,
    model_urls,
    pretrained=False,
    progress=True,
    root=".cache/torch/checkpoints",
    device="cpu",
):
    """Download and return the state dict for the last of *keys* found in *model_urls*.

    :param keys: candidate model names, in increasing order of preference
    :param model_urls: mapping from model name to checkpoint URL
    :param pretrained: accepted for API compatibility; unused
    :param progress: forwarded to ``load_state_dict_from_url``
    :param root: cache directory, relative to the user's home directory
    :param device: device name used as ``map_location``
    :raises UserWarning: if none of *keys* appears in *model_urls*
    """
    candidates = [name for name in keys if name in model_urls]
    if len(candidates) == 0:
        raise UserWarning(f"None of the keys {keys} correspond to a pretrained model.")

    selected = candidates[-1]
    checkpoint_url = model_urls[selected]
    Colors.green(f"Loading pretrained model {selected} from {checkpoint_url}")

    cache_dir = Path.home() / root
    return load_state_dict_from_url(
        checkpoint_url,
        cache_dir,
        progress=progress,
        check_hash=False,
        map_location=torch.device(device),
    )
    def __init__(
        self,
        *args,
        superclass_wnids,
        dataset_test=None,
        Rules=SoftRules,
        metric=None,
        **kwargs,
    ):
        """Prepare superclass classification over *superclass_wnids*.

        Builds a second ``Rules`` object for ``dataset_test``, plus mappings
        from both the test rules' and this object's leaf WNIDs onto the
        superclasses.

        Assumes each wnid's index equals its index in the ``rules.wnids``
        list. This agreed with ``Node.wnid_to_class_index`` when written,
        since ``rules.wnids = get_wnids(...)``.
        """
        # TODO: `metric` is accepted but not used yet.
        super().__init__(*args, **kwargs)

        # Test-side rules get the same arguments, except that `dataset` is
        # forced to `dataset_test` and the explicit graph/wnid paths are dropped.
        test_rules_kwargs = dict(kwargs)
        test_rules_kwargs["dataset"] = dataset_test
        for dropped_key in ("path_graph", "path_wnids"):
            test_rules_kwargs.pop(dropped_key, None)
        self.rules_test = Rules(*args, **test_rules_kwargs)

        self.superclass_wnids = superclass_wnids
        self.total = 0
        self.correct = 0

        target_build = Superclass.build_mapping(
            self.rules_test.tree.wnids_leaves, superclass_wnids)
        self.mapping_target, self.new_to_old_classes_target = target_build
        pred_build = Superclass.build_mapping(
            self.rules.tree.wnids_leaves, superclass_wnids)
        self.mapping_pred, self.new_to_old_classes_pred = pred_build

        has_superclass = self.mapping_target >= 0
        mapped_classes = [self.classes[idx] for idx in has_superclass.nonzero()]
        Colors.cyan(
            f"==> Mapped {len(mapped_classes)} classes to your superclasses: "
            f"{mapped_classes}")
def generate_hierarchy(
    dataset,
    method,
    seed=0,
    branching_factor=2,
    extra=0,
    no_prune=False,
    fname="",
    path="",
    single_path=False,
    induced_linkage="ward",
    induced_affinity="euclidean",
    checkpoint=None,
    arch=None,
    model=None,
    **kwargs,
):
    """Build a hierarchy over *dataset*'s WNIDs, post-process it and write it out.

    ``method`` selects the builder: ``"wordnet"``, ``"random"`` or
    ``"induced"``. The induced builder is given ``arch``, ``checkpoint`` and
    the state dict of ``model`` when one is passed. Single-successor nodes are
    pruned unless ``no_prune`` is set, and the graph is augmented when
    ``extra > 0``. After each stage the graph stats are printed and every
    WNID is asserted to still be present. Extra keyword arguments are
    ignored.

    Returns the path the graph was written to.

    Raises:
        NotImplementedError: if ``method`` is not one of the handled builders.
    """
    wnids = get_wnids_from_dataset(dataset)

    def report(graph, stage):
        # Print stats for this stage and make sure no WNID was lost.
        print_graph_stats(graph, stage)
        assert_all_wnids_in_graph(graph, wnids)

    if method == "wordnet":
        G = build_minimal_wordnet_graph(wnids, single_path)
    elif method == "random":
        G = build_random_graph(wnids, seed=seed, branching_factor=branching_factor)
    elif method == "induced":
        state_dict = None if model is None else model.state_dict()
        G = build_induced_graph(
            wnids,
            dataset=dataset,
            checkpoint=checkpoint,
            model=arch,
            linkage=induced_linkage,
            affinity=induced_affinity,
            branching_factor=branching_factor,
            state_dict=state_dict,
        )
    else:
        raise NotImplementedError(f'Method "{method}" not yet handled.')
    report(G, "matched")

    if not no_prune:
        G = prune_single_successor_nodes(G)
        report(G, "pruned")

    if extra > 0:
        G, n_extra, n_imaginary = augment_graph(G, extra, True)
        print(f"[extra] \t Extras: {n_extra} \t Imaginary: {n_imaginary}")
        report(G, "extra")

    path_args = dict(
        dataset=dataset,
        method=method,
        seed=seed,
        branching_factor=branching_factor,
        extra=extra,
        no_prune=no_prune,
        fname=fname,
        path=path,
        single_path=single_path,
        induced_linkage=induced_linkage,
        induced_affinity=induced_affinity,
        checkpoint=checkpoint,
        arch=arch,
    )
    out_path = get_graph_path_from_args(**path_args)
    write_graph(G, out_path)

    Colors.green("==> Wrote tree to {}".format(out_path))
    return out_path
def generate_hierarchy_vis_from(G,
                                dataset,
                                path_html,
                                color="blue",
                                vis_root=None,
                                vis_no_color_leaves=False,
                                vis_color_path_to=None,
                                vis_color_nodes=(),
                                vis_theme="regular",
                                vis_force_labels_left=(),
                                vis_leaf_images=False,
                                vis_image_resize_factor=1,
                                vis_fake_sublabels=False,
                                vis_zoom=2,
                                vis_curved=False,
                                vis_sublabels=False,
                                vis_height=750,
                                vis_width=1000,
                                vis_margin_top=20,
                                vis_margin_left=250,
                                vis_hide=(),
                                vis_above_dy=325,
                                vis_below_dy=475,
                                vis_scale=1,
                                vis_root_y="null",
                                vis_colormap="colormap_annotated.png",
                                vis_node_conf=(),
                                verbose=False,
                                **kwargs):
    """Render hierarchy graph *G* to an interactive HTML tree.

    :param G: hierarchy graph; drawn from ``vis_root`` if given, otherwise
        from the first root found
    :param dataset: forwarded to ``build_tree`` (presumably used for leaf
        images when ``vis_leaf_images`` is set)
    :param path_html: Where to write final hierarchy
    :param vis_hide: node identifiers forwarded to the template as ``hide``;
        tuples are converted to lists so they render with list syntax

    The remaining ``vis_*`` options are forwarded to ``get_color_info``,
    ``build_tree`` and ``generate_vis``. Extra keyword arguments are ignored.
    A red warning is printed when *G* has more than one root.
    """
    roots = list(get_roots(G))
    num_roots = len(roots)
    root = vis_root or next(get_roots(G))

    assert root in G, f"Node {root} is not a valid node. Nodes: {G.nodes}"

    color_info = get_color_info(
        G,
        color,
        color_leaves=not vis_no_color_leaves,
        color_path_to=vis_color_path_to,
        color_nodes=vis_color_nodes or (),
        theme=vis_theme,
    )

    node_to_conf = generate_node_conf(vis_node_conf)

    tree = build_tree(
        G,
        root,
        color_info=color_info,
        force_labels_left=vis_force_labels_left or [],
        dataset=dataset,
        include_leaf_images=vis_leaf_images,
        image_resize_factor=vis_image_resize_factor,
        include_fake_sublabels=vis_fake_sublabels,
        node_to_conf=node_to_conf,
    )
    # NOTE(review): the result is unused here; kept in case build_graph has
    # side effects.
    graph = build_graph(G)

    if num_roots > 1:
        Colors.red(f"Found {num_roots} roots! Should be only 1: {roots}")
    elif verbose:
        print(f"Found just {num_roots} root.")

    # The template receives str(hide); a non-empty tuple would render as
    # "('a',)" rather than the list syntax the empty default produces.
    if isinstance(vis_hide, tuple):
        hide = list(vis_hide)
    else:
        hide = vis_hide or []

    parent = Path(fwd()).parent
    generate_vis(
        str(parent / "nbdt/templates/tree-template.html"),
        tree,
        path_html,
        zoom=vis_zoom,
        straight_lines=not vis_curved,
        show_sublabels=vis_sublabels,
        height=vis_height,
        bg=color_info["bg"],
        text_rect=color_info["text_rect"],
        width=vis_width,
        margin_top=vis_margin_top,
        margin_left=vis_margin_left,
        hide=hide,
        above_dy=vis_above_dy,
        below_dy=vis_below_dy,
        scale=vis_scale,
        root_y=vis_root_y,
        colormap=vis_colormap,
        verbose=verbose,
    )
def generate_vis(
    path_template,
    data,
    path_html,
    zoom=2,
    straight_lines=True,
    show_sublabels=False,
    height=750,
    margin_top=20,
    above_dy=325,
    y_node_sep=170,
    hide=(),
    _print=False,
    out_dir=".",
    scale=1,
    colormap="colormap_annotated.png",
    below_dy=475,
    root_y="null",
    width=1000,
    margin_left=250,
    bg="#FFFFFF",
    text_rect="rgba(255,255,255,0.8)",
    stroke_width=0.45,
    verbose=False,
):
    """Fill the HTML tree template with *data* and settings, writing it to *path_html*.

    Each ``CONFIG_*`` placeholder in the template at *path_template* is
    replaced by the matching argument. Booleans are lower-cased and numbers
    are converted with ``str``. *data* is JSON-encoded inside a one-element
    list. The page title is the stem of *path_html*, and the parent directory
    of *path_html* is created if needed.

    :param hide: sequence substituted for ``CONFIG_HIDE``; tuples are
        rendered with list syntax, like the default
    :param colormap: image path shown on the page; embedded only when it is a
        string naming an existing file
    :param out_dir: ignored; the output directory is always the parent of
        *path_html*. Kept for backward compatibility.
    :param verbose: print the output path when done
    """
    fname = Path(path_html).stem
    out_dir = Path(path_html).parent

    hide_value = str(list(hide)) if isinstance(hide, tuple) else str(hide)

    colormap_html = ""
    if isinstance(colormap, str) and os.path.exists(colormap):
        colormap_html = (
            f'<img src="{colormap}" style="\n'
            "        position: absolute;\n"
            "        top: 40px;\n"
            "        left: 80px;\n"
            "        height: 250px;\n"
            '        border: 4px solid #ccc;">'
        )

    # Ordered (placeholder, value) pairs. The tree JSON goes last so that
    # placeholder-like text inside the data is never rewritten by a later
    # substitution.
    replacements = (
        ("CONFIG_MARGIN_LEFT", str(margin_left)),
        ("CONFIG_VIS_WIDTH", str(width)),
        ("CONFIG_SCALE", str(scale)),
        ("CONFIG_PRINT", str(_print).lower()),
        ("CONFIG_HIDE", hide_value),
        ("CONFIG_Y_NODE_SEP", str(y_node_sep)),
        ("CONFIG_ABOVE_DY", str(above_dy)),
        ("CONFIG_BELOW_DY", str(below_dy)),
        ("CONFIG_ZOOM", str(zoom)),
        ("CONFIG_STRAIGHT_LINES", str(straight_lines).lower()),
        ("CONFIG_SHOW_SUBLABELS", str(show_sublabels).lower()),
        ("CONFIG_TITLE", fname),
        ("CONFIG_VIS_HEIGHT", str(height)),
        ("CONFIG_BG_COLOR", bg),
        ("CONFIG_TEXT_RECT_COLOR", text_rect),
        ("CONFIG_STROKE_WIDTH", str(stroke_width)),
        ("CONFIG_MARGIN_TOP", str(margin_top)),
        ("CONFIG_ROOT_Y", str(root_y)),
        ("CONFIG_COLORMAP", colormap_html),
        ("CONFIG_TREE_DATA", json.dumps([data])),
    )

    with open(path_template, encoding="utf-8") as f:
        html = f.read()
    for placeholder, value in replacements:
        html = html.replace(placeholder, value)

    os.makedirs(out_dir, exist_ok=True)
    with open(path_html, "w", encoding="utf-8") as f:
        f.write(html)

    if verbose:
        Colors.green("==> Wrote HTML to {}".format(path_html))
                  download=True,
                  transform=transform_test)

assert trainset.classes == testset.classes, (trainset.classes, testset.classes)

trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=2)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=100,
                                         shuffle=False,
                                         num_workers=2)

Colors.cyan(
    f'Training with dataset {args.dataset} and {len(trainset.classes)} classes'
)

# Model
print('==> Building model..')
model = getattr(models, args.arch)
model_kwargs = {'num_classes': len(trainset.classes)}

if args.pretrained:
    print('==> Loading pretrained model..')
    try:
        net = model(pretrained=True, dataset=args.dataset, **model_kwargs)
    except TypeError as e:  # likely because `dataset` not allowed arg
        print(e)

        try:
Example #14
0
# Data: build train/test splits of the selected dataset and their loaders.
dataset = getattr(data, args.dataset)

dataset_kwargs = generate_kwargs(args, dataset,
    name=f'Dataset {args.dataset}',
    keys=data.custom.keys,
    globals=globals())

trainset = dataset(**dataset_kwargs, root='./data', train=True, download=True, transform=transform_train)
testset = dataset(**dataset_kwargs, root='./data', train=False, download=True, transform=transform_test)

assert trainset.classes == testset.classes, (trainset.classes, testset.classes)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

Colors.cyan(f'Training with dataset {args.dataset} and {len(trainset.classes)} classes')

# Model
print('==> Building model..')
model = getattr(models, args.arch)
model_kwargs = {'num_classes': len(trainset.classes) }

if args.pretrained:
    print('==> Loading pretrained model..')
    net = model(pretrained=True, dataset=args.dataset, **model_kwargs)
else:
    print('==> Loading NBDT model..')
    # Use the architecture and dataset chosen on the command line. A
    # hard-coded CIFAR10 ResNet18 would silently mismatch other
    # --dataset/--arch choices (e.g. the number of classes).
    net = model(**model_kwargs)
    net = SoftNBDT(pretrained=True, dataset=args.dataset, arch=args.arch, model=net)

#Analyzer
def generate_vis(path_template,
                 data,
                 fname,
                 zoom=2,
                 straight_lines=True,
                 show_sublabels=False,
                 height=750,
                 dark=False,
                 margin_top=20,
                 above_dy=325,
                 y_node_sep=170,
                 hide=(),
                 _print=False,
                 out_dir='.',
                 scale=1,
                 colormap='colormap_annotated.png',
                 below_dy=475,
                 root_y='null',
                 width=1000,
                 margin_left=250):
    """Fill the HTML tree template and write it to ``{out_dir}/{fname}.html``.

    Each ``CONFIG_*`` placeholder in the template at *path_template* is
    replaced by the matching argument. Booleans are lower-cased and numbers
    are converted with ``str``. *data* is JSON-encoded inside a one-element
    list, and *fname* doubles as the page title. *dark* switches the
    background, text and text-rectangle colours. *out_dir* is created if
    needed.

    :param hide: sequence substituted for ``CONFIG_HIDE``; tuples are
        rendered with list syntax, like the default
    :param colormap: image path shown on the page; embedded only when it is a
        string naming an existing file
    """
    hide_value = str(list(hide)) if isinstance(hide, tuple) else str(hide)

    colormap_html = ''
    if isinstance(colormap, str) and os.path.exists(colormap):
        colormap_html = (
            f'<img src="{colormap}" style="\n'
            '        position: absolute;\n'
            '        top: 40px;\n'
            '        left: 80px;\n'
            '        height: 250px;\n'
            '        border: 4px solid #ccc;">'
        )

    # Ordered (placeholder, value) pairs. The tree JSON goes last so that
    # placeholder-like text inside the data is never rewritten by a later
    # substitution.
    replacements = (
        ("CONFIG_MARGIN_LEFT", str(margin_left)),
        ("CONFIG_VIS_WIDTH", str(width)),
        ("CONFIG_SCALE", str(scale)),
        ("CONFIG_PRINT", str(_print).lower()),
        ("CONFIG_HIDE", hide_value),
        ("CONFIG_Y_NODE_SEP", str(y_node_sep)),
        ("CONFIG_ABOVE_DY", str(above_dy)),
        ("CONFIG_BELOW_DY", str(below_dy)),
        ("CONFIG_ZOOM", str(zoom)),
        ("CONFIG_STRAIGHT_LINES", str(straight_lines).lower()),
        ("CONFIG_SHOW_SUBLABELS", str(show_sublabels).lower()),
        ("CONFIG_TITLE", fname),
        ("CONFIG_VIS_HEIGHT", str(height)),
        ("CONFIG_BG_COLOR", "#111111" if dark else "#FFFFFF"),
        ("CONFIG_TEXT_COLOR", '#FFFFFF' if dark else '#000000'),
        ("CONFIG_TEXT_RECT_COLOR",
         "rgba(17,17,17,0.8)" if dark else "rgba(255,255,255,1)"),
        ("CONFIG_MARGIN_TOP", str(margin_top)),
        ("CONFIG_ROOT_Y", str(root_y)),
        ("CONFIG_COLORMAP", colormap_html),
        ("CONFIG_TREE_DATA", json.dumps([data])),
    )

    with open(path_template, encoding="utf-8") as f:
        html = f.read()
    for placeholder, value in replacements:
        html = html.replace(placeholder, value)

    os.makedirs(out_dir, exist_ok=True)
    path_html = f'{out_dir}/{fname}.html'
    with open(path_html, 'w', encoding="utf-8") as f:
        f.write(html)

    Colors.green('==> Wrote HTML to {}'.format(path_html))
Example #16
0
def main():
    """Command-line entry point: train or evaluate a classifier.

    Parses arguments and builds the train/test datasets and loaders. Builds
    the model, optionally pretrained or resumed from a checkpoint, then the
    loss chain, SGD optimizer, multi-step LR schedule, analyzer and metric.
    It then either runs a single evaluation pass (``--eval``) or trains for
    ``--epochs`` epochs, saving a checkpoint whenever test accuracy improves.

    NOTE: several ``generate_kwargs`` calls receive ``globals=locals()``.
    They presumably look up constructor arguments among this function's
    local variables by name, so keep local names stable.
    """
    maybe_install_wordnet()
    datasets = data.cifar.names + data.imagenet.names + data.custom.names
    parser = argparse.ArgumentParser(description="PyTorch CIFAR Training")
    parser.add_argument("--batch-size",
                        default=512,
                        type=int,
                        help="Batch size used for training")
    parser.add_argument(
        "--epochs",
        "-e",
        default=200,
        type=int,
        help="By default, lr schedule is scaled accordingly",
    )
    parser.add_argument("--dataset", default="CIFAR10", choices=datasets)
    parser.add_argument("--arch",
                        default="ResNet18",
                        choices=list(models.get_model_choices()))
    parser.add_argument("--lr", default=0.1, type=float, help="learning rate")
    parser.add_argument("--resume",
                        "-r",
                        action="store_true",
                        help="resume from checkpoint")

    # extra general options for main script
    parser.add_argument("--path-resume",
                        default="",
                        help="Overrides checkpoint path generation")
    parser.add_argument(
        "--name",
        default="",
        help="Name of experiment. Used for checkpoint filename")
    parser.add_argument(
        "--pretrained",
        action="store_true",
        help="Download pretrained model. Not all models support this.",
    )
    parser.add_argument("--eval", help="eval only", action="store_true")
    parser.add_argument(
        "--dataset-test",
        choices=datasets,
        help="If not set, automatically set to train dataset",
    )
    parser.add_argument(
        "--disable-test-eval",
        help="Allows you to run model inference on a test dataset "
        "different from train dataset. Use an analyzer to define "
        "a metric.",
        action="store_true",
    )

    # options specific to this project and its dataloaders
    parser.add_argument("--loss",
                        choices=loss.names,
                        default=["CrossEntropyLoss"],
                        nargs="+")
    parser.add_argument("--metric", choices=metrics.names, default="top1")
    parser.add_argument("--analysis",
                        choices=analysis.names,
                        help="Run analysis after each epoch")

    # other dataset, loss or analysis specific options
    data.custom.add_arguments(parser)
    T.add_arguments(parser)
    loss.add_arguments(parser)
    analysis.add_arguments(parser)

    args = parser.parse_args()
    loss.set_default_values(args)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Data
    print("==> Preparing data..")
    dataset_train = getattr(data, args.dataset)
    dataset_test = getattr(data, args.dataset_test or args.dataset)

    transform_train = dataset_train.transform_train()
    transform_test = dataset_test.transform_val()

    # dataset_train / dataset_test are the dataset classes themselves (they
    # are instantiated below), so their own __name__ is the dataset name;
    # `.__class__.__name__` would name their metaclass instead.
    dataset_train_kwargs = generate_kwargs(
        args,
        dataset_train,
        name=f"Dataset {dataset_train.__name__}",
        globals=locals(),
    )
    dataset_test_kwargs = generate_kwargs(
        args,
        dataset_test,
        name=f"Dataset {dataset_test.__name__}",
        globals=locals(),
    )
    trainset = dataset_train(
        **dataset_train_kwargs,
        root="./data",
        train=True,
        download=True,
        transform=transform_train,
    )
    testset = dataset_test(
        **dataset_test_kwargs,
        root="./data",
        train=False,
        download=True,
        transform=transform_test,
    )

    # Class lists may only differ when test-time evaluation is disabled.
    assert trainset.classes == testset.classes or args.disable_test_eval, (
        trainset.classes,
        testset.classes,
    )

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=2)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=2)

    Colors.cyan(
        f"Training with dataset {args.dataset} and {len(trainset.classes)} classes"
    )
    Colors.cyan(
        f"Testing with dataset {args.dataset_test or args.dataset} and {len(testset.classes)} classes"
    )

    # Model
    print("==> Building model..")
    model = getattr(models, args.arch)

    if args.pretrained:
        print("==> Loading pretrained model..")
        model = make_kwarg_optional(model, dataset=args.dataset)
        net = model(pretrained=True, num_classes=len(trainset.classes))
    else:
        net = model(num_classes=len(trainset.classes))

    net = net.to(device)
    if device == "cuda":
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    checkpoint_fname = generate_checkpoint_fname(**vars(args))
    checkpoint_path = "./checkpoint/{}.pth".format(checkpoint_fname)
    print(f"==> Checkpoints will be saved to: {checkpoint_path}")

    resume_path = args.path_resume or checkpoint_path
    if args.resume:
        # Load checkpoint.
        print("==> Resuming from checkpoint..")
        assert os.path.isdir(
            "checkpoint"), "Error: no checkpoint directory found!"
        if not os.path.exists(resume_path):
            print("==> No checkpoint found. Skipping...")
        else:
            checkpoint = torch.load(resume_path,
                                    map_location=torch.device(device))

            # Checkpoints written by `test` below wrap the weights together
            # with accuracy/epoch; anything else is treated as a bare state dict.
            if "net" in checkpoint:
                load_state_dict(net, checkpoint["net"])
                best_acc = checkpoint["acc"]
                start_epoch = checkpoint["epoch"]
                Colors.cyan(
                    f"==> Checkpoint found for epoch {start_epoch} with accuracy "
                    f"{best_acc} at {resume_path}")
            else:
                load_state_dict(net, checkpoint)
                Colors.cyan(f"==> Checkpoint found at {resume_path}")

    # hierarchy (not referenced directly below, but visible to the
    # generate_kwargs calls through locals())
    tree = Tree.create_from_args(args, classes=trainset.classes)

    # loss
    # Losses are built in order; when the first one is not a torch.nn loss,
    # a plain CrossEntropyLoss is seeded as `criterion`, presumably so
    # project losses can wrap the previous criterion (visible via locals()).
    criterion = None
    for _loss in args.loss:
        if criterion is None and not hasattr(nn, _loss):
            criterion = nn.CrossEntropyLoss()
        class_criterion = getattr(loss, _loss)
        loss_kwargs = generate_kwargs(
            args,
            class_criterion,
            name=f"Loss {_loss}",
            globals=locals(),
        )
        criterion = class_criterion(**loss_kwargs)

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    # LR drops at 3/7 and 5/7 of the total number of epochs.
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[int(3 / 7.0 * args.epochs),
                    int(5 / 7.0 * args.epochs)])

    class_analysis = getattr(analysis, args.analysis or "Noop")
    analyzer_kwargs = generate_kwargs(
        args,
        class_analysis,
        name=f"Analyzer {args.analysis}",
        globals=locals(),
    )
    analyzer = class_analysis(**analyzer_kwargs)

    metric = getattr(metrics, args.metric)()

    # Training
    @analyzer.train_function
    def train(epoch):
        """Run one training epoch over ``trainloader`` and step the LR schedule."""
        if hasattr(criterion, "set_epoch"):
            criterion.set_epoch(epoch, args.epochs)

        print("\nEpoch: %d / LR: %.04f" % (epoch, scheduler.get_last_lr()[0]))
        net.train()
        train_loss = 0
        metric.clear()
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            metric.forward(outputs, targets)
            transform = trainset.transform_val_inverse().to(device)
            stat = analyzer.update_batch(outputs, targets, transform(inputs))

            progress_bar(
                batch_idx,
                len(trainloader),
                "Loss: %.3f | Acc: %.3f%% (%d/%d) %s" % (
                    train_loss / (batch_idx + 1),
                    100.0 * metric.report(),
                    metric.correct,
                    metric.total,
                    f"| {analyzer.name}: {stat}" if stat else "",
                ),
            )
        scheduler.step()

    @analyzer.test_function
    def test(epoch, checkpoint=True):
        """Evaluate on ``testloader``; save a checkpoint on a new best accuracy.

        When ``checkpoint`` is False nothing is written. Rebinds the
        enclosing ``best_acc`` whenever a checkpoint is saved.
        """
        nonlocal best_acc
        net.eval()
        test_loss = 0
        metric.clear()
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)

                # Skip loss/metric when the test labels are not comparable.
                if not args.disable_test_eval:
                    loss = criterion(outputs, targets)
                    test_loss += loss.item()
                    metric.forward(outputs, targets)
                transform = testset.transform_val_inverse().to(device)
                stat = analyzer.update_batch(outputs, targets,
                                             transform(inputs))

                progress_bar(
                    batch_idx,
                    len(testloader),
                    "Loss: %.3f | Acc: %.3f%% (%d/%d) %s" % (
                        test_loss / (batch_idx + 1),
                        100.0 * metric.report(),
                        metric.correct,
                        metric.total,
                        f"| {analyzer.name}: {stat}" if stat else "",
                    ),
                )

        # Save checkpoint.
        acc = 100.0 * metric.report()
        print("Accuracy: {}, {}/{} | Best Accurracy: {}".format(
            acc, metric.correct, metric.total, best_acc))
        if acc > best_acc and checkpoint:
            Colors.green(f"Saving to {checkpoint_fname} ({acc})..")
            state = {
                "net": net.state_dict(),
                "acc": acc,
                "epoch": epoch,
            }
            os.makedirs("checkpoint", exist_ok=True)
            torch.save(state, f"./checkpoint/{checkpoint_fname}.pth")
            best_acc = acc

    if args.disable_test_eval and (not args.analysis
                                   or args.analysis == "Noop"):
        Colors.red(
            " * Warning: `disable_test_eval` is used but no custom metric "
            "`--analysis` is supplied. I suggest supplying an analysis to perform "
            "custom loss and accuracy calculation.")

    if args.eval:
        if not args.resume and not args.pretrained:
            Colors.red(" * Warning: Model is not loaded from checkpoint. "
                       "Use --resume or --pretrained (if supported)")
        with analyzer.epoch_context(0):
            test(0, checkpoint=False)
    else:
        for epoch in range(start_epoch, args.epochs):
            with analyzer.epoch_context(epoch):
                train(epoch)
                test(epoch)

    print(f"Best accuracy: {best_acc} // Checkpoint name: {checkpoint_fname}")