Example 1
def test_graph(summary):
    graph = Graph()
    node_a = Node('a', 'Node A', size=(4, ))
    node_b = Node('b', 'Node B', size=(16, ))
    graph.add_node(node_a)
    graph.add_node(node_b)
    graph.add_edge(node_a, node_b)
    summary["graph"] = graph
    graph = disk_summary(summary)["graph"]
    path = graph["path"]
    with open(os.path.join(summary._run.dir, path)) as f:
        graph_data = json.load(f)
    assert graph_data == {
        'edges': [['a', 'b']],
        'format': 'keras',
        'nodes': [
            {'id': 'a', 'name': 'Node A', 'size': [4]},
            {'id': 'b', 'name': 'Node B', 'size': [16]},
        ],
    }
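For reference, the round trip this test performs (serialize the graph into the run directory, read it back, parse it) can be sketched with nothing but the standard library. This is a minimal sketch of the on-disk format only; the dict literal stands in for whatever `Graph`'s serializer emits, and `graph.json` is a hypothetical file name:

import json
import os
import tempfile

# Hand-built stand-in for the payload asserted in the test above.
payload = {
    'edges': [['a', 'b']],
    'format': 'keras',
    'nodes': [
        {'id': 'a', 'name': 'Node A', 'size': [4]},
        {'id': 'b', 'name': 'Node B', 'size': [16]},
    ],
}

run_dir = tempfile.mkdtemp()
path = 'graph.json'  # hypothetical; the real path comes from the summary
with open(os.path.join(run_dir, path), 'w') as f:
    json.dump(payload, f)
with open(os.path.join(run_dir, path)) as f:
    assert json.load(f) == payload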
Example 2
        def after_forward_hook(module, input, output):
            if id(module) in modules:
                # a node for this module has already been recorded
                return
            modules.add(id(module))
            if not isinstance(output, tuple):
                output = (output, )
            parameters = [(pname, list(param.size()))
                          for pname, param in module.named_parameters()]

            node = Node(id=id(module),
                        name=name,
                        class_name=str(module),
                        output_shape=nested_shape(output),
                        parameters=parameters,
                        num_parameters=[
                            reduce(mul, size, 1)
                            for (pname, size) in parameters
                        ])
            graph.nodes_by_id[id(module)] = node
            for param in module.parameters():
                graph.nodes_by_id[id(param)] = node
            graph.add_node(node)
            if not graph.criterion_passed:
                if hasattr(output[0], 'grad_fn'):
                    graph.criterion = output[0].grad_fn
            elif (isinstance(output[0], list) and output[0]
                  and hasattr(output[0][0], 'grad_fn')):
                # guard against an empty list before indexing into it
                graph.criterion = output[0][0].grad_fn
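The hook's `(module, input, output)` signature matches what `torch.nn.Module.register_forward_hook` expects. A self-contained sketch of how such a hook gets attached; `demo_hook` is a hypothetical stand-in that only prints output shapes, without the `Graph` bookkeeping the real hook does:

import torch
import torch.nn as nn

def demo_hook(module, inputs, output):
    # analogous to after_forward_hook above, minus the graph bookkeeping
    print(type(module).__name__, tuple(output.shape))

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
handles = [
    m.register_forward_hook(demo_hook)
    for m in model.modules()
    if not isinstance(m, nn.Sequential)
]

model(torch.randn(1, 4))  # the forward pass fires each hook once

for h in handles:
    h.remove()            # detach once the graph has been captured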
Example 3
        def after_forward_hook(module, input, output):
            if id(module) not in self._graph_hooks:
                # hook already processed -> noop
                return
            if not isinstance(output, tuple):
                output = (output,)
            parameters = [
                (pname, list(param.size()))
                for pname, param in module.named_parameters()
            ]

            node = Node(
                id=id(module),
                name=name,
                class_name=str(module),
                output_shape=nested_shape(output),
                parameters=parameters,
                num_parameters=[reduce(mul, size, 1) for (pname, size) in parameters],
            )
            graph.nodes_by_id[id(module)] = node
            for param in module.parameters():
                graph.nodes_by_id[id(param)] = node
            graph.add_node(node)
            if not graph.criterion_passed:
                if hasattr(output[0], "grad_fn"):
                    graph.criterion = output[0].grad_fn
                elif (
                    isinstance(output[0], list)
                    and output[0]
                    and hasattr(output[0][0], "grad_fn")
                ):
                    graph.criterion = output[0][0].grad_fn

            # hook has been processed
            self._graph_hooks -= {id(module)}

            if not self._graph_hooks:
                # we went through the entire graph
                wandb.run.summary["graph_%i" % graph_idx] = self
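The `criterion` bookkeeping above hinges on `grad_fn`, which autograd attaches only to tensors produced by tracked operations; the hook uses it to find the tail of the autograd graph. A quick illustration:

import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()
print(y.grad_fn)  # e.g. <SumBackward0 ...>: tail of the autograd graph
print(x.grad_fn)  # None: leaf tensors have no grad_fn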
Example 4
        def after_forward_hook(module, input, output):
            if id(module) not in self._graph_hooks:
                # should not happen: the hook fired for an untracked module
                return
            if not isinstance(output, tuple):
                output = (output, )
            parameters = [(pname, list(param.size()))
                          for pname, param in module.named_parameters()]

            node = Node(
                id=id(module),
                name=name,
                class_name=str(module),
                output_shape=nested_shape(output),
                parameters=parameters,
                num_parameters=[
                    reduce(mul, size, 1) for (pname, size) in parameters
                ],
            )
            graph.nodes_by_id[id(module)] = node
            for param in module.parameters():
                graph.nodes_by_id[id(param)] = node
            graph.add_node(node)
            if not graph.criterion_passed:
                if hasattr(output[0], "grad_fn"):
                    graph.criterion = output[0].grad_fn
            elif (isinstance(output[0], list) and output[0]
                    and hasattr(output[0][0], "grad_fn")):
                # guard against an empty list before indexing into it
                graph.criterion = output[0][0].grad_fn

            # log graph and remove hook
            hook = self._graph_hooks.pop(id(module), None)
            if hook is not None:
                hook.remove()

            if not self._graph_hooks:
                # we went through the entire graph
                wandb.run.summary["graph_%i" % graph_idx] = self
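Unlike Example 3, which only tracks a set of pending ids, this variant stores the actual hook handles in `self._graph_hooks`, keyed by module id, so each hook can deregister itself after its first call. A hypothetical, self-contained version of that one-shot pattern:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
_graph_hooks = {}

def one_shot_hook(module, inputs, output):
    # pop and remove our own handle so this hook never fires again
    handle = _graph_hooks.pop(id(module), None)
    if handle is not None:
        handle.remove()
    print("captured", type(module).__name__)

for m in model.modules():
    if not isinstance(m, nn.Sequential):
        _graph_hooks[id(m)] = m.register_forward_hook(one_shot_hook)

model(torch.randn(1, 4))  # prints "captured Linear" twice
model(torch.randn(1, 4))  # prints nothing: every hook removed itself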
Example 5
    def from_torch_layers(cls, module_graph, variable):
        """Recover something like neural net layers from PyTorch Module's and the
        compute graph from a Variable.

        Example output for a multi-layer RNN. We confusingly assign shared embedding values
        to the encoder, but ordered next to the decoder.

        rnns.0.linear.module.weight_raw rnns.0
        rnns.0.linear.module.bias rnns.0
        rnns.1.linear.module.weight_raw rnns.1
        rnns.1.linear.module.bias rnns.1
        rnns.2.linear.module.weight_raw rnns.2
        rnns.2.linear.module.bias rnns.2
        rnns.3.linear.module.weight_raw rnns.3
        rnns.3.linear.module.bias rnns.3
        decoder.weight encoder
        decoder.bias decoder
        """
        # TODO: We're currently not using this, but I left it here in case we want to resurrect it! - CVP
        torch = util.get_module("torch", "Could not import torch")

        module_nodes_by_hash = {id(n): n for n in module_graph.nodes}
        module_parameter_nodes = [
            n for n in module_graph.nodes if isinstance(n.obj, torch.nn.Parameter)
        ]

        names_by_pid = {id(n.obj): n.name for n in module_parameter_nodes}

        reachable_param_nodes = module_graph[0].reachable_descendents()
        module_reachable_params = {}
        names = {}
        for pid, reachable_nodes in reachable_param_nodes.items():
            node = module_nodes_by_hash[pid]
            if not isinstance(node.obj, torch.nn.Module):
                continue
            module = node.obj
            reachable_params = {}  # by object id
            module_reachable_params[id(module)] = reachable_params
            names[node.name] = set()
            for reachable_hash in reachable_nodes:
                reachable = module_nodes_by_hash[reachable_hash]
                if isinstance(reachable.obj, torch.nn.Parameter):
                    param = reachable.obj
                    reachable_params[id(param)] = param
                    names[node.name].add(names_by_pid[id(param)])

        # we look for correspondences between sets of parameters used in subtrees of the
        # computation graph and sets of parameters contained in subtrees of the module
        # graph
        node_depths = {id(n): d for n, d in module_graph[0].descendent_bfs()}
        parameter_module_names = {}
        parameter_modules = {}
        for param_node in (
            n for n in module_graph.nodes if isinstance(n.obj, torch.nn.Parameter)
        ):
            pid = id(param_node.obj)
            best_node = None
            best_depth = None
            best_reachable_params = None
            for node in module_graph.nodes:
                if not isinstance(node.obj, torch.nn.Module):
                    continue
                module = node.obj
                reachable_params = module_reachable_params[id(module)]
                if pid in reachable_params:
                    depth = node_depths[id(node)]
                    if best_node is None or (len(reachable_params), depth) <= (
                        len(best_reachable_params),
                        best_depth,
                    ):
                        best_node = node
                        best_depth = depth
                        best_reachable_params = reachable_params

            parameter_modules[pid] = best_node
            parameter_module_names[param_node.name] = best_node.name

        # contains all parameters but only a minimal set of modules necessary
        # to contain them (and which ideally correspond to conceptual layers)
        reduced_module_graph = cls()
        rmg_ids = itertools.count()
        rmg_root = Node(id=next(rmg_ids), node=module_graph[0])
        reduced_module_graph.add_node(rmg_root)
        reduced_module_graph.root = rmg_root
        rmg_nodes_by_pid = {}

        module_nodes_by_pid = {id(n.obj): n for n in module_graph.nodes}

        compute_graph, compute_node_vars = cls.from_torch_compute_graph(variable)
        for node, _ in reversed(list(compute_graph[0].ancestor_bfs())):
            param = compute_node_vars.get(node.id)
            pid = id(param)
            if not isinstance(param, torch.nn.Parameter):
                continue
            if pid not in module_nodes_by_pid:
                # not all Parameters that occur in the compute graph come from the Module graph
                continue

            # add the nodes in the order we want to display them on the frontend
            mid = id(parameter_modules[pid].obj)
            if mid in rmg_nodes_by_pid:
                rmg_module = rmg_nodes_by_pid[mid]
            else:
                rmg_module = rmg_nodes_by_pid[mid] = Node(
                    id=next(rmg_ids), node=module_nodes_by_pid[mid]
                )
                reduced_module_graph.add_node(rmg_module)
                reduced_module_graph.add_edge(rmg_root, rmg_module)

            rmg_param = Node(id=next(rmg_ids), node=module_nodes_by_pid[pid])
            rmg_nodes_by_pid[pid] = rmg_param
            reduced_module_graph.add_node(rmg_param)

            reduced_module_graph.add_edge(rmg_module, rmg_param)
        return reduced_module_graph
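The `best_node` search above is effectively a lexicographic minimum over `(number of reachable parameters, depth)`: the module with the fewest reachable parameters wins (the tightest subtree containing the parameter), with depth breaking ties. A toy equivalent using `min`, with made-up candidate values:

# (module name, reachable parameter count, depth) -- illustrative values only
candidates = [
    ('rnns', 8, 1),
    ('rnns.0', 2, 2),
    ('rnns.0.linear', 2, 3),
]
best = min(candidates, key=lambda c: (c[1], c[2]))
print(best[0])  # -> 'rnns.0': smallest module that still contains the parameter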