コード例 #1
0
    def graph(self, inputs):
        """
        Return model supernet graph.

        Parameters
        ----------
        inputs: tuple of tensor
            Inputs that will be fed into the network.

        Returns
        -------
        dict
            Containing ``node``, in Tensorboard GraphDef format.
            Additional key ``mutable`` maps each mutable key to a list of module
            paths; every path is a list of ``{"type": ..., "name": ...}`` dicts
            running from the root model down to the mutable.
        """
        if not torch.__version__.startswith("1.4"):
            logger.warning(
                "Graph is only tested with PyTorch 1.4. Other versions might not work."
            )
        # protobuf should be installed as long as tensorboard is installed
        from nni._graph_utils import build_graph
        from google.protobuf import json_format

        try:
            # The flag is raised only for the duration of the trace and is always
            # lowered again, even when tracing or conversion fails.
            # NOTE(review): presumably it makes every candidate path take part in
            # the forward pass so the whole supernet is traced -- confirm.
            self._connect_all = True
            traced_graph, _ = build_graph(self.model, inputs, verbose=False)
            result = json_format.MessageToDict(traced_graph)
        finally:
            self._connect_all = False

        # Map every mutable key to the module paths that carry it. One key may be
        # shared by several modules, hence the traversal with `deduplicate=False`.
        key_to_paths = defaultdict(list)
        result["mutable"] = key_to_paths
        for mutable in self.mutables.traverse(deduplicate=False):
            # A module is represented in the format of
            # [{"type": "Net", "name": ""}, {"type": "Cell", "name": "cell1"}, {"type": "Conv2d", "name": "conv"}]
            # which will be concatenated into Net/Cell[cell1]/Conv2d[conv] in frontend.
            # This format is aligned with the scope name jit gives.
            attribute_chain = mutable.name.split(".")
            scope = [{"type": self.model.__class__.__name__, "name": ""}]
            current = self.model
            # Walk the dotted name one attribute at a time, recording each hop.
            for attribute in attribute_chain:
                current = getattr(current, attribute)
                scope.append({"type": current.__class__.__name__, "name": attribute})
            key_to_paths[mutable.key].append(scope)
        return result
コード例 #2
0
    def _test_graph(self, model, dummy_input, expected_file):
        """
        Build the graph of ``model`` and compare it with a stored reference.

        Parameters
        ----------
        model
            Model handed to ``build_graph``.
        dummy_input
            Input handed to ``build_graph`` together with the model.
        expected_file: str
            Path to a text-format ``GraphDef`` file holding the expected graph.

        Raises
        ------
        AssertionError
            If the reference file is missing, the node counts differ, or any
            node differs in name, op, inputs, device or set of attribute keys.
        """
        actual_proto, _ = build_graph(model, dummy_input)

        assert os.path.exists(expected_file), expected_file
        with open(expected_file, "r") as reference:
            reference_text = reference.read()

        expected_proto = GraphDef()
        text_format.Parse(reference_text, expected_proto)

        # Lengths are checked first, so pairing the nodes with zip below cannot
        # silently drop trailing nodes from either graph.
        self.assertEqual(len(expected_proto.node), len(actual_proto.node))
        for expected_node, actual_node in zip(expected_proto.node, actual_proto.node):
            for field in ("name", "op", "input", "device"):
                self.assertEqual(getattr(expected_node, field),
                                 getattr(actual_node, field))
            # Only the attribute keys are compared, not the attribute values.
            self.assertEqual(sorted(expected_node.attr.keys()),
                             sorted(actual_node.attr.keys()))