def test_pytorch_graph(self):
    """Compare the graph built for a tiny linear module with the stored expectation.

    The module is first written through a summary writer via ``add_graph``.
    A fresh instance is then converted with ``graph`` and the resulting
    proto is compared node by node (name, op, inputs, device and attribute
    keys) against the expected ``GraphDef`` parsed from text.
    """
    dummy_input = (torch.zeros(1, 3),)

    # NOTE(review): the class name, the attribute name ``l`` and the forward
    # body presumably feed into the node names stored in the expected file,
    # so they are intentionally left untouched.
    class myLinear(torch.nn.Module):
        def __init__(self):
            super(myLinear, self).__init__()
            self.l = torch.nn.Linear(3, 5)

        def forward(self, x):
            return self.l(x)

    with self.createSummaryWriter() as w:
        w.add_graph(myLinear(), dummy_input)

    actual_proto, _ = graph(myLinear(), dummy_input)

    expected_str = read_expected_content(self)
    expected_proto = GraphDef()
    text_format.Parse(expected_str, expected_proto)

    # ``assertEquals`` is a deprecated alias that was removed in Python 3.12;
    # ``assertEqual`` is the supported spelling.
    self.assertEqual(len(expected_proto.node), len(actual_proto.node))

    # The lengths were asserted equal above, so zip() visits every node pair.
    for expected_node, actual_node in zip(expected_proto.node, actual_proto.node):
        self.assertEqual(expected_node.name, actual_node.name)
        self.assertEqual(expected_node.op, actual_node.op)
        self.assertEqual(expected_node.input, actual_node.input)
        self.assertEqual(expected_node.device, actual_node.device)
        # Only the attribute keys are compared, not their values.
        self.assertEqual(
            sorted(expected_node.attr.keys()),
            sorted(actual_node.attr.keys()))
def add_graph(self, model, input_to_model=None, verbose=False):
    """Write the graph of ``model`` through this writer's file writer.

    Objects that expose a ``forward`` attribute are treated as PyTorch
    models: they are converted with ``graph(model, input_to_model, verbose)``
    and handed to the file writer's ``add_graph``. Anything else is treated
    as a Caffe2 model and converted to a graph definition that is written as
    an event.

    Args:
        model: A PyTorch module (anything with a ``forward`` attribute), a
            list whose first element is a ``caffe2.python.core.Net`` or a
            ``caffe2_pb2.NetDef``, or another Caffe2 model object accepted
            by ``model_to_graph_def``.
        input_to_model: Forwarded to ``graph`` on the PyTorch path; not used
            on the Caffe2 path.
        verbose: Forwarded to ``graph`` on the PyTorch path; not used on the
            Caffe2 path.

    Raises:
        ValueError: If ``model`` is an empty list.
        TypeError: If ``model`` is a list whose first element is neither a
            ``core.Net`` nor a ``caffe2_pb2.NetDef``.
    """
    # prohibit second call?
    # no, let tensorboard handle it and show its warning message.
    torch._C._log_api_usage_once("tensorboard.logging.add_graph")

    if hasattr(model, 'forward'):
        # A valid PyTorch model should have a 'forward' method
        self._get_file_writer().add_graph(
            graph(model, input_to_model, verbose))
        return

    # Caffe2 models do not have the 'forward' method
    from caffe2.proto import caffe2_pb2
    from caffe2.python import core
    from torch.utils.tensorboard._caffe2_graph import (
        model_to_graph_def, nets_to_graph_def, protos_to_graph_def)

    if isinstance(model, list):
        # The list flavour is decided by its first element, so an empty list
        # cannot be dispatched; fail with a clear message instead of an
        # IndexError.
        if not model:
            raise ValueError(
                "add_graph received an empty list; expected a list of "
                "caffe2 Net or NetDef objects")
        if isinstance(model[0], core.Net):
            current_graph = nets_to_graph_def(model)
        elif isinstance(model[0], caffe2_pb2.NetDef):
            current_graph = protos_to_graph_def(model)
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``current_graph``.
            raise TypeError(
                "add_graph received a list of unsupported element type "
                "{}; expected caffe2 Net or NetDef objects".format(
                    type(model[0]).__name__))
    else:
        # Handles cnn.CNNModelHelper, model_helper.ModelHelper
        current_graph = model_to_graph_def(model)

    event = event_pb2.Event(
        graph_def=current_graph.SerializeToString())
    self._get_file_writer().add_event(event)
def bonsai_parser(model, model_in):
    """
    Full parsing function, handling the route layers

    The model is first parsed with ``parse_simple_model``. Its graph is then
    obtained through ``pg.graph`` and every ``onnx::Concat`` / ``onnx::Add``
    node found there is inserted into the parsed model as a ``route`` /
    ``residual_add`` module, with a ``layers`` parameter listing the layers
    that feed it.

    Args:
        model: pytorch model to be processed
        model_in: model input; must provide ``.size()`` and is passed to
            ``pg.graph`` as the only element of ``args``
    Returns:
        A complete BonsaiParsedModel
    """
    # Simple parsing, without the routing layers
    bonsai_parsed_model = parse_simple_model(model, model_in.size())

    # Getting the graph that represents the underlying network connectivity
    # Only the first element of the value returned by pg.graph is used.
    gd = pg.graph(model, args=(model_in,))
    name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(gd[0])

    # Convert node numbers to their short weight name
    # (the resulting dict is keyed by get_node_name(full name) and holds the
    # value taken from name_to_seq_num)
    graph_layers_to_weights = {get_node_name(k):v for k,v in name_to_seq_num.items()}

    # Route layers
    # Keeps only Concat/Add nodes; the stored value is the op name without
    # the 'onnx::' prefix, i.e. 'Concat' or 'Add'.
    route_layers = {k: v.op.split('::')[1] for k,v in name_to_node.items() if v.op in ['onnx::Concat', 'onnx::Add']}
    route_weight_names = [get_node_name(x) for x in route_layers]

    # matching node (full name, node shortened name, and previous nodes connected by the graph)
    raw_predecessors = {(k, get_node_name(k)): [get_node_name(x) for x in in_names_list] for k, in_names_list in
                        name_to_input_name.items()}
    # removing duplicates and empty strings
    # The result is keyed by the shortened name only, so full names that
    # shorten to the same string overwrite one another here.
    predecessors = {weight: list(set([x for x in weight_list if len(x) > 0])) for (name, weight), weight_list in
                    raw_predecessors.items()}

    # removing nodes that are intermediate values, they dont correspond to graph layers
    real_nodes = list(set(bonsai_parsed_model.get_weight_names())) + route_weight_names
    # A copy is passed, so ``predecessors`` itself is not rebound by the helper.
    real_predecessors = get_real_predecessors(predecessors.copy(), real_nodes)

    # getting the relevant layers for the routing computation
    route_real_predecessors = {k:v for k,v in real_predecessors.items() if k in route_weight_names}
    route_predecessors_layers = {k:bonsai_parsed_model.get_layers_by_weights(v) for k,v in route_real_predecessors.items()}

    # computing the layer number of the generated route layer
    # here we set it to be 1 after the node that is previous to it in the GraphDef graph
    graph_connection = {graph_layers_to_weights[k]: [graph_layers_to_weights[x] for x in v] for k, v in
                        route_real_predecessors.items()}
    # list.index raises ValueError when ``int(k) - 1`` is not among the
    # predecessor sequence numbers of route node ``k``.
    # NOTE(review): assumes sequence numbers are int-convertible and that the
    # immediately preceding node is always a predecessor -- confirm.
    prev_index = {k: v.index(int(k)-1) for k, v in graph_connection.items()}
    res_layers = {v[prev_index[graph_layers_to_weights[k]]] + 1: v for k, v in route_predecessors_layers.items()}

    # keeping in mind that layer indices shift when we add new layers
    # Each referenced layer index is increased by the number of new layers
    # whose index is smaller than it; the keys are then shifted the same way.
    shifted_values = {k: [val + len([x for x in res_layers if int(x) < int(val)]) for val in v] for k, v in res_layers.items()}
    final_layers = {int(k) + len([x for x in shifted_values if int(x) < int(k)]): v for k, v in shifted_values.items()}

    # adding the layers to the model
    # zip pairs the entries positionally.
    # NOTE(review): this relies on final_layers and route_layers iterating in
    # matching order with no keys collapsed along the way -- confirm.
    for (k, v), operation in zip(final_layers.items(), route_layers.values()):
        if operation == 'Concat':
            bonsai_parsed_model.insert_module(int(k), 'route')
            bonsai_parsed_model.insert_param(int(k), 'layers', str(v))
        elif operation == 'Add':
            bonsai_parsed_model.insert_module(int(k), 'residual_add')
            bonsai_parsed_model.insert_param(int(k), 'layers', str(v))
    return bonsai_parsed_model
def test_pytorch_graph(self):
    """Graph of a single-layer linear module must match the stored proto.

    The module is logged once through a summary writer, then a second, fresh
    instance is converted with ``graph`` and checked with ``compare_proto``.
    """

    # NOTE(review): class name, attribute name and forward body are kept
    # exactly as they were; they presumably show up in the expected graph.
    class myLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(3, 5)

        def forward(self, x):
            return self.l(x)

    build_model = myLinear
    example_inputs = (torch.zeros(1, 3), )

    with self.createSummaryWriter() as writer:
        writer.add_graph(build_model(), example_inputs)

    # Only the first element of the returned pair is compared.
    traced_graph_def = graph(build_model(), example_inputs)[0]
    matches_expected = compare_proto(traced_graph_def, self)
    self.assertTrue(matches_expected)
def test_nested_nn_squential(self):
    """Graph of doubly nested ``nn.Sequential`` containers must match the expectation.

    An outer ``Sequential`` of ``depth`` inner modules (each wrapping its own
    two-layer ``Sequential``) is logged through a summary writer; a fresh
    instance is then converted with ``graph`` and compared node by node with
    the expected ``GraphDef`` parsed from text.
    """
    dummy_input = torch.randn(2, 3)

    # NOTE(review): class names, attribute names and forward bodies are left
    # as they were; they presumably determine node names in the expected file.
    class InnerNNSquential(torch.nn.Module):
        def __init__(self, dim1, dim2):
            super().__init__()
            self.inner_nn_squential = torch.nn.Sequential(
                torch.nn.Linear(dim1, dim2),
                torch.nn.Linear(dim2, dim1),
            )

        def forward(self, x):
            x = self.inner_nn_squential(x)
            return x

    class OuterNNSquential(torch.nn.Module):
        def __init__(self, dim1=3, dim2=4, depth=2):
            super().__init__()
            stacked = [InnerNNSquential(dim1, dim2) for _ in range(depth)]
            self.outer_nn_squential = torch.nn.Sequential(*stacked)

        def forward(self, x):
            x = self.outer_nn_squential(x)
            return x

    with self.createSummaryWriter() as writer:
        writer.add_graph(OuterNNSquential(), dummy_input)

    actual_proto = graph(OuterNNSquential(), dummy_input)[0]

    expected_proto = GraphDef()
    text_format.Parse(read_expected_content(self), expected_proto)

    expected_nodes = expected_proto.node
    actual_nodes = actual_proto.node
    self.assertEqual(len(expected_nodes), len(actual_nodes))

    # Equal lengths were asserted above, so zip() walks every node pair.
    compared_fields = ("name", "op", "input", "device")
    for want, got in zip(expected_nodes, actual_nodes):
        for field in compared_fields:
            self.assertEqual(getattr(want, field), getattr(got, field))
        # Attributes are compared by key set only.
        self.assertEqual(sorted(want.attr.keys()), sorted(got.attr.keys()))