コード例 #1
0
 def apply(self, graph_model: GraphModule) -> GraphModule:
     """Run this rewriter over every node of *graph_model*.

     Nodes are visited in graph order; each one accepted by
     ``match_node`` is handed to ``rewrite_node``. The module is then
     recompiled and the graph linted so the mutations are materialized
     and validated before the model is returned.
     """
     for candidate in graph_model.graph.nodes:
         if not self.match_node(candidate):
             continue
         self.rewrite_node(candidate, graph_model)
     graph_model.recompile()
     graph_model.graph.lint()
     return graph_model
コード例 #2
0
 def apply(self, graph_model: GraphModule) -> GraphModule:
     """Insert a quant-identity ``call_module`` right after ``self.node``.

     A new node calling ``self.module_name`` on ``self.node`` is placed
     immediately after it; every other consumer of ``self.node`` is then
     re-pointed at the new node (the new node itself is excluded so it
     keeps reading the original). Finally the module is recompiled and
     the graph linted.
     """
     graph = graph_model.graph
     with graph.inserting_after(self.node):
         identity_node = graph.call_module(self.module_name, args=(self.node, ))
     replace_all_uses_except(self.node, identity_node, [identity_node])
     graph_model.recompile()
     graph.lint()
     return graph_model
コード例 #3
0
 def is_converged(self, graph_model: GraphModule) -> bool:
     """Split one BatchNorm that follows a channel-wise ``torch.cat``.

     Looks for the pattern ``cat(dim=1) -> BatchNorm1d/2d`` where every
     input to the cat is a ``call_module`` node with a single user and a
     module of an accepted type. The single BN is then replaced by one
     per-branch BN, each initialized with the matching slice of the
     original BN's statistics, and the original BN node/module is
     removed.

     Returns False as soon as one rewrite is performed (not converged,
     run again), True when no further match exists.
     """
     for cat_node in graph_model.graph.nodes:
         # Match: a torch.cat along dim=1 whose only user is a call_module.
         # NOTE(review): assumes 'dim' is always passed as a kwarg — a
         # positional dim would raise KeyError; confirm upstream tracing
         # guarantees this.
         if (cat_node.target is torch.cat and len(cat_node.users) == 1
                 and cat_node.kwargs['dim'] == 1
                 and list(cat_node.users)[0].op == 'call_module'):
             bn_node = list(cat_node.users)[0]
             module = get_module(graph_model, bn_node.target)
             if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                 inp_nodes = cat_node.all_input_nodes
                 # Every cat input must be a call_module consumed only by
                 # the cat, so re-pointing it at a per-branch BN is safe.
                 if all([
                         inp_node.op == 'call_module'
                         and len(inp_node.users) == 1
                         for inp_node in inp_nodes
                 ]):
                     before_mods = [
                         get_module(graph_model, inp_node.target)
                         for inp_node in inp_nodes
                     ]
                     if all(
                             isinstance(mod, self.before_modules_types)
                             for mod in before_mods):
                         # Sanity check: the cat's kwarg list and its
                         # resolved input nodes must line up, since both
                         # are indexed in parallel below.
                         assert inp_nodes == cat_node.kwargs['tensors']
                         num_features_list = [
                             get_output_channels(mod) for mod in before_mods
                         ]
                         # One fresh BN of the same class per branch,
                         # sized to that branch's channel count.
                         chunk_bn_list = [
                             type(module)(n) for n in num_features_list
                         ]
                         for i, chunk_bn in enumerate(chunk_bn_list):
                             chunk_bn_name = f'{bn_node.name}_{i}'
                             graph_model.add_module(chunk_bn_name, chunk_bn)
                             # Channel slice of the original BN that this
                             # branch contributed to the concatenation.
                             start = sum(num_features_list[:i])
                             end = sum(num_features_list[:i + 1])
                             chunk_bn.weight.data = module.weight.data[
                                 start:end]
                             chunk_bn.bias.data = module.bias.data[
                                 start:end]
                             chunk_bn.running_mean = module.running_mean.data[
                                 start:end]
                             chunk_bn.running_var = module.running_var.data[
                                 start:end]
                             inp_node = cat_node.kwargs['tensors'][i]
                             # Insert the per-branch BN right after its
                             # input and steal all of the input's users
                             # (except the new BN node itself).
                             with graph_model.graph.inserting_after(
                                     inp_node):
                                 chunk_bn_node = graph_model.graph.call_module(
                                     chunk_bn_name, args=(inp_node, ))
                             replace_all_uses_except(
                                 inp_node, chunk_bn_node, [chunk_bn_node])
                         # The original BN is now redundant: its users read
                         # straight from the cat, and the node/module go away.
                         bn_node.replace_all_uses_with(cat_node)
                         graph_model.graph.erase_node(bn_node)
                         del_module(graph_model, bn_node.target)
                         graph_model.graph.lint()
                         graph_model.recompile()
                         return False
     return True
コード例 #4
0
 def is_converged(self, graph_model: GraphModule):
     """Fuse BatchNorm nodes matching ``self.patterns`` into their producer.

     NOTE(review): despite its name this method returns the mutated
     GraphModule, not a bool (unlike sibling ``is_converged`` methods
     that return True/False) — confirm callers only rely on truthiness.
     """
     # Snapshot of name -> module taken once up front.
     # NOTE(review): the dict is not refreshed after del_module below, so
     # it can hold stale entries if patterns overlap — verify patterns
     # cannot match a module that was already removed in this pass.
     named_modules = dict(graph_model.named_modules())
     for node in graph_model.graph.nodes:
         for pattern in self.patterns:
             if matches_module_pattern(pattern, node, named_modules):
                 # Skip when the producer feeds more than one consumer:
                 # folding BN into it would change the other users' values.
                 if len(node.args[0].users) > 1:
                     continue
                 layer = named_modules[node.args[0].target]
                 bn = named_modules[node.target]
                 merge_bn(layer, bn, get_output_channel_dim(layer))
                 # Re-point every user of the BN node at its input, then
                 # drop the node and its backing module.
                 node.replace_all_uses_with(node.args[0])
                 graph_model.graph.erase_node(node)
                 del_module(graph_model, node.target)
     graph_model.recompile()
     graph_model.graph.lint()
     return graph_model
コード例 #5
0
 def apply(self, model: GraphModule) -> GraphModule:
     """Swap ``self.old_module_instance`` for ``self.new_module_instance``.

     The model's submodules are scanned for the exact instance (identity
     comparison, not equality); the first occurrence is replaced and the
     scan stops. The model is returned whether or not a match was found.
     """
     target = self.old_module_instance
     for candidate in model.modules():
         if candidate is not target:
             continue
         # init the new module based on the old one
         replace_module(model, candidate, self.new_module_instance)
         break
     return model
コード例 #6
0
 def apply(self, model: GraphModule) -> GraphModule:
     """Replace the submodule registered under ``self.old_module_name``.

     The first module whose qualified name equals the target name is fed
     to ``_init_new_module`` and swapped in via ``_replace_old_module``;
     the scan stops after that. The model is returned either way.
     """
     for qualified_name, module in model.named_modules():
         if qualified_name != self.old_module_name:
             continue
         # init the new module based on the old one
         replacement = self._init_new_module(module)
         self._replace_old_module(model, module, replacement)
         break
     return model
コード例 #7
0
 def rewrite_node(self, node: Node, graph_model: GraphModule):
     """Turn *node* into a ``call_module`` on a freshly constructed module.

     The node's positional args are folded into kwargs, the kwargs are
     split between the node itself and the new module's constructor, and
     the constructed module is registered on *graph_model* under the
     node's name, which becomes the node's target.
     """
     target_name = node.name
     # The node's name must not collide with an existing submodule.
     assert target_name not in dict(graph_model.named_modules()).keys()
     self.move_node_args_to_kwargs(node)
     remaining_kwargs, ctor_kwargs = self.split_kwargs(node)
     new_module = self.new_module_class(**ctor_kwargs,
                                        **self.new_module_kwargs)
     node.op = 'call_module'
     node.target = target_name
     node.kwargs = immutable_dict(remaining_kwargs)
     set_module(graph_model, new_module, target_name)
コード例 #8
0
ファイル: standardize.py プロジェクト: fpjentzsch/brevitas
 def apply(self, graph_model: GraphModule):
     """Give every reuse of a stateless submodule its own deep copy.

     Tracing collapses repeated calls to the same module onto a single
     target string. For modules with no parameters and no buffers, this
     pass keeps the first call on the original module and rebinds each
     further call to a freshly deep-copied module registered under
     ``'<name>_<i>'``.
     """
     # duplicates are returned only once by named_modules()
     modules_by_name = dict(graph_model.named_modules())
     seen_count: Dict[str, int] = {}
     for name, mod in modules_by_name.items():
         # A module is stateful if it owns any parameter or buffer.
         is_stateful = list(mod.parameters(recurse=True)) or list(
             mod.buffers(recurse=True))
         if is_stateful:
             continue
         for node in list(graph_model.graph.nodes):
             # duplicates are collapsed under the same target str during tracing
             if not (isinstance(node.target, str) and node.target == name):
                 continue
             if name not in seen_count.keys():
                 # first occurrence keeps the original module
                 seen_count[name] = 0
             else:
                 seen_count[name] += 1
                 fresh_name = f'{name}_{seen_count[name]}'
                 set_module(graph_model, deepcopy(mod), fresh_name)
                 node.target = fresh_name
     graph_model.recompile()
     graph_model.graph.lint()
     return graph_model
コード例 #9
0
 def apply(self, model: GraphModule) -> GraphModule:
     """Replace every module whose type is exactly ``self.old_module_class``.

     Matching uses exact type equality, deliberately excluding
     subclasses. All replacement pairs are collected first and swapped
     afterwards, so the substitutions never disturb the module traversal
     that discovers them.
     """
     # check for equality, not inheritance; init each new module from the old
     replacements = {
         old: self._init_new_module(old)
         for old in model.modules()
         if type(old) == self.old_module_class}
     # replace all pairs registered
     for old, new in replacements.items():
         self._replace_old_module(model, old, new)
     return model