def graph(self, inputs):
    """
    Build and return the supernet graph of the model.

    Parameters
    ----------
    inputs : tuple of tensor
        Inputs fed into the network while the graph is traced.

    Returns
    -------
    dict
        The traced graph converted from a Tensorboard ``GraphDef`` message
        into a dict (nodes under ``node``). An additional key ``mutable``
        maps each mutable key to a list of module paths; see the loop below
        for the path format.
    """
    if not torch.__version__.startswith("1.4"):
        logger.warning(
            "Graph is only tested with PyTorch 1.4. Other versions might not work."
        )

    # Imported lazily so the heavy tracing / protobuf dependencies are only
    # needed when a graph is actually requested.
    from nni._graph_utils import build_graph
    from google.protobuf import json_format  # protobuf should be installed as long as tensorboard is installed

    # `_connect_all` is switched on only for the duration of tracing and is
    # always switched back off, even if tracing raises.
    try:
        self._connect_all = True
        traced_def, _ = build_graph(self.model, inputs, verbose=False)
        graph_dict = json_format.MessageToDict(traced_def)
    finally:
        self._connect_all = False

    # Each module path is a list like
    #   [{"type": "Net", "name": ""}, {"type": "Cell", "name": "cell1"}, {"type": "Conv2d", "name": "conv"}]
    # which the frontend concatenates into Net/Cell[cell1]/Conv2d[conv],
    # mirroring the scope names produced by the JIT tracer.
    # One key may be shared by several modules, so traverse without
    # deduplication to collect every one of them.
    mutable_paths = defaultdict(list)
    root_type = self.model.__class__.__name__
    for mutable in self.mutables.traverse(deduplicate=False):
        current = self.model
        path = [{"type": root_type, "name": ""}]
        for attr_name in mutable.name.split("."):
            current = getattr(current, attr_name)
            path.append({"type": current.__class__.__name__, "name": attr_name})
        mutable_paths[mutable.key].append(path)
    graph_dict["mutable"] = mutable_paths
    return graph_dict
def _test_graph(self, model, dummy_input, expected_file):
    """
    Trace ``model`` and compare the resulting graph with a golden file.

    Parameters
    ----------
    model
        Model passed to ``build_graph`` for tracing.
    dummy_input
        Input(s) passed to ``build_graph`` alongside ``model``.
    expected_file : str
        Path to a protobuf text-format ``GraphDef`` holding the expected graph.

    The node count must match. For every node, in order, the following are
    compared: name, op, inputs, device, and the set of attribute keys.
    Attribute values are not compared.
    """
    actual_proto, _ = build_graph(model, dummy_input)

    # Use a unittest assertion rather than a bare `assert`, which is stripped under -O.
    self.assertTrue(os.path.exists(expected_file), expected_file)
    with open(expected_file, "r", encoding="utf-8") as f:
        expected_str = f.read()
    expected_proto = GraphDef()
    text_format.Parse(expected_str, expected_proto)

    self.assertEqual(len(expected_proto.node), len(actual_proto.node))
    # Lengths are equal at this point, so zip() cannot silently drop nodes.
    for i, (expected_node, actual_node) in enumerate(
        zip(expected_proto.node, actual_proto.node)
    ):
        # Identify the failing node in assertion messages.
        where = "node {} ({})".format(i, expected_node.name)
        self.assertEqual(expected_node.name, actual_node.name, where)
        self.assertEqual(expected_node.op, actual_node.op, where)
        self.assertEqual(expected_node.input, actual_node.input, where)
        self.assertEqual(expected_node.device, actual_node.device, where)
        self.assertEqual(
            sorted(expected_node.attr.keys()),
            sorted(actual_node.attr.keys()),
            where,
        )