def __init__(self, dataset, path_graph=None, path_wnids=None, classes=()):
        """Set up node lookups and counters for ``dataset``'s hierarchy.

        Any of ``path_graph``, ``path_wnids`` or ``classes`` that is falsy is
        replaced by the dataset-derived default before use.

        Args:
            dataset: dataset identifier used to look up the defaults.
            path_graph: path to the hierarchy graph; defaulted from ``dataset``.
            path_wnids: path to the wnids listing; defaulted from ``dataset``.
            classes: class names; defaulted to dummy classes for ``dataset``.

        Raises:
            AssertionError: if ``dataset`` or any resolved value is falsy
                (only while assertions are enabled).
        """
        # Fill in whatever the caller left out from the dataset's defaults.
        path_graph = path_graph or dataset_to_default_path_graph(dataset)
        path_wnids = path_wnids or dataset_to_default_path_wnids(dataset)
        classes = classes or dataset_to_dummy_classes(dataset)
        super().__init__()
        assert all((dataset, path_graph, path_wnids, classes))

        self.classes = classes

        hierarchy_nodes = Node.get_nodes(path_graph, path_wnids, classes)
        self.nodes = hierarchy_nodes
        # NOTE(review): presumably every node shares one graph ``G``; confirm.
        self.G = hierarchy_nodes[0].G
        self.wnid_to_node = {n.wnid: n for n in hierarchy_nodes}

        self.wnids = get_wnids(path_wnids)
        # zip() pairs wnids and classes positionally and stops at the shorter.
        self.wnid_to_class = dict(zip(self.wnids, self.classes))

        # Running tallies, both starting at zero.
        self.correct = self.total = 0

        # Identity matrix whose side is the number of classes.
        self.I = torch.eye(len(classes))
示例#2
0
    def __init__(self,
            dataset,
            model,
            arch=None,
            path_graph=None,
            path_wnids=None,
            classes=None,
            hierarchy=None,
            pretrained=None,
            **kwargs):
        """Resolve hierarchy paths for ``dataset`` and delegate to ``self.init``.

        When ``dataset`` is truthy and neither ``hierarchy`` nor ``path_graph``
        is supplied, the hierarchy name becomes ``f'induced-{arch}'``, so
        ``arch`` is then mandatory. Missing ``path_graph``, ``path_wnids`` and
        ``classes`` are filled from dataset defaults.

        Args:
            dataset: dataset identifier used to derive defaults.
            model: the network to wrap; a ``str`` is rejected (the error
                message asks for an ``nn.Module``).
            arch: architecture name, e.g. ResNet18.
            path_graph, path_wnids, classes, hierarchy: optional overrides.
            pretrained: when truthy, ``arch`` must also be given.
            **kwargs: forwarded unchanged to ``self.init``.

        Raises:
            AssertionError: if ``arch`` is needed to name the hierarchy but
                missing (only while assertions are enabled).
            UserWarning: raised as an exception when ``pretrained`` is truthy
                and ``arch`` is not.
            NotImplementedError: if ``model`` is a string.
        """
        super().__init__()

        if dataset:
            if not (hierarchy or path_graph):
                assert arch, 'Must specify `arch` if no `hierarchy` or `path_graph`'
                hierarchy = f'induced-{arch}'
            if hierarchy and not path_graph:
                path_graph = hierarchy_to_path_graph(dataset, hierarchy)
            # NOTE(review): at this point a hierarchy is always set, so this
            # fallback only fires if hierarchy_to_path_graph returned a falsy
            # value; confirm whether it is still needed.
            if not path_graph:
                path_graph = dataset_to_default_path_graph(dataset)
            if not path_wnids:
                path_wnids = dataset_to_default_path_wnids(dataset)
            if not classes:
                classes = dataset_to_dummy_classes(dataset)

        if pretrained and not arch:
            raise UserWarning(
                'To load a pretrained NBDT, you need to specify the `arch`. '
                '`arch` is the name of the architecture. e.g., ResNet18')
        if isinstance(model, str):
            raise NotImplementedError('Model must be nn.Module')

        self.init(dataset, model, path_graph, path_wnids, classes,
            arch=arch, pretrained=pretrained, hierarchy=hierarchy, **kwargs)
示例#3
0
    def __init__(self,
                 dataset,
                 criterion,
                 path_graph=None,
                 path_wnids=None,
                 classes=None,
                 hierarchy=None,
                 Rules=HardEmbeddedDecisionRules,
                 **kwargs):
        """Resolve hierarchy paths for ``dataset`` and delegate to ``self.init``.

        Args:
            dataset: dataset identifier; when falsy, no defaults are derived.
            criterion: base criterion, forwarded to ``self.init``.
            path_graph: graph path; derived from ``hierarchy`` when that is
                given, otherwise from the dataset default.
            path_wnids: wnids path; defaulted from ``dataset``.
            classes: class names; defaulted to dummy classes for ``dataset``.
            hierarchy: hierarchy name used to locate ``path_graph``.
            Rules: decision-rules class forwarded to ``self.init``.
            **kwargs: forwarded unchanged to ``self.init``.
        """
        super().__init__()

        if dataset:
            if hierarchy and not path_graph:
                path_graph = hierarchy_to_path_graph(dataset, hierarchy)
            path_graph = path_graph or dataset_to_default_path_graph(dataset)
            path_wnids = path_wnids or dataset_to_default_path_wnids(dataset)
            classes = classes or dataset_to_dummy_classes(dataset)

        self.init(dataset, criterion, path_graph, path_wnids, classes,
                  Rules=Rules, **kwargs)
def set_default_values(args):
    """Fill in hierarchy-related path defaults on ``args`` in place.

    ``args.path_graph`` is derived from ``args.hierarchy`` when only the
    hierarchy is given, or from the dataset default when neither is given.
    ``args.path_wnids`` falls back to the dataset default.

    Args:
        args: object exposing ``dataset``, ``hierarchy``, ``path_graph`` and
            ``path_wnids`` attributes (presumably an ``argparse.Namespace`` --
            confirm against the caller).

    Raises:
        ValueError: if both ``args.hierarchy`` and ``args.path_graph`` are set.
    """
    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if args.hierarchy and args.path_graph:
        raise ValueError(
            "Only one, between --hierarchy and --path-graph can be provided.")
    if args.hierarchy:
        # The check above guarantees path_graph is unset here.
        args.path_graph = hierarchy_to_path_graph(args.dataset, args.hierarchy)
    if not args.path_graph:
        args.path_graph = dataset_to_default_path_graph(args.dataset)
    if not args.path_wnids:
        args.path_wnids = dataset_to_default_path_wnids(args.dataset)
def set_default_values(args):
    """Fill in hierarchy-related path defaults on ``args`` for tree losses.

    The hierarchy/path-graph conflict is checked first. Defaults are then
    filled only when ``'TreeSupLoss'`` is in ``args.loss``; otherwise ``args``
    is left unchanged. ``args`` is modified in place.

    Args:
        args: object exposing ``dataset``, ``hierarchy``, ``path_graph``,
            ``path_wnids`` and ``loss`` attributes (presumably an
            ``argparse.Namespace`` -- confirm against the caller).

    Raises:
        ValueError: if both ``args.hierarchy`` and ``args.path_graph`` are set.
    """
    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if args.hierarchy and args.path_graph:
        raise ValueError(
            'Only one, between --hierarchy and --path-graph can be provided.')
    if 'TreeSupLoss' not in args.loss:
        return
    if args.hierarchy:
        # The check above guarantees path_graph is unset here.
        args.path_graph = hierarchy_to_path_graph(args.dataset, args.hierarchy)
    if not args.path_graph:
        args.path_graph = dataset_to_default_path_graph(args.dataset)
    if not args.path_wnids:
        args.path_wnids = dataset_to_default_path_wnids(args.dataset)
示例#6
0
    def __init__(self,
                 dataset,
                 path_graph=None,
                 path_wnids=None,
                 classes=None,
                 hierarchy=None):
        """Resolve hierarchy paths for ``dataset`` and load them.

        Args:
            dataset: dataset identifier; when falsy, no defaults are derived.
            path_graph: graph path; derived from ``hierarchy`` when that is
                given, otherwise from the dataset default.
            path_wnids: wnids path; defaulted from ``dataset``.
            classes: class names; defaulted to dummy classes for ``dataset``.
            hierarchy: hierarchy name used to locate ``path_graph``.
        """
        if dataset:
            if hierarchy and not path_graph:
                path_graph = hierarchy_to_path_graph(dataset, hierarchy)
            path_graph = path_graph or dataset_to_default_path_graph(dataset)
            path_wnids = path_wnids or dataset_to_default_path_wnids(dataset)
            classes = classes or dataset_to_dummy_classes(dataset)

        self.load_hierarchy(dataset, path_graph, path_wnids, classes)