Beispiel #1
0
 def move_args_to_kwargs(self, graph_model):
     """Rewrite ``torch.cat`` calls in ``graph_model`` to pass arguments by keyword.

     For every ``call_function`` node whose target is ``torch.cat``:

     * a trailing ``int`` positional argument is moved to ``kwargs['dim']``;
     * then a leading ``tuple``/``list`` positional argument is moved to
       ``kwargs['tensors']``.

     Nodes are modified in place; nothing is returned.
     """
     for node in graph_model.graph.nodes:
         if node.op != 'call_function' or node.target is not torch.cat:
             continue
         if node.args and isinstance(node.args[-1], int):
             kwargs = dict(node.kwargs)
             kwargs['dim'] = node.args[-1]
             node.kwargs = immutable_dict(kwargs)
             node.args = node.args[:-1]
         # Re-check emptiness: moving ``dim`` may have removed the only
         # positional argument, and indexing ``args[0]`` would then raise.
         if node.args and isinstance(node.args[0], (tuple, list)):
             kwargs = dict(node.kwargs)
             kwargs['tensors'] = node.args[0]
             node.kwargs = immutable_dict(kwargs)
             node.args = node.args[1:]
Beispiel #2
0
 def rewrite_node(self, node: Node, graph_model: GraphModule):
     """Turn ``node`` into a ``call_module`` node backed by a newly built module.

     The node's arguments are first normalised with
     ``self.move_node_args_to_kwargs`` and split by ``self.split_kwargs``
     into call-side and constructor-side kwargs. An instance of
     ``self.new_module_class`` is built from the constructor-side kwargs
     plus ``self.new_module_kwargs``. The node is retargeted to
     ``node.name`` with only the call-side kwargs. Finally,
     ``set_module(graph_model, <module>, node.name)`` is called.
     """
     target_name = node.name
     existing_names = {name for name, _ in graph_model.named_modules()}
     # The new submodule's name must not collide with an existing one.
     assert target_name not in existing_names
     self.move_node_args_to_kwargs(node)
     call_kwargs, ctor_kwargs = self.split_kwargs(node)
     replacement = self.new_module_class(
         **ctor_kwargs, **self.new_module_kwargs)
     node.target = target_name
     node.op = 'call_module'
     node.kwargs = immutable_dict(call_kwargs)
     set_module(graph_model, replacement, target_name)
Beispiel #3
0
 def move_node_args_to_kwargs(self, node: Node):
     """Move non Node args to kwargs, as long as we can resolve the fn signature somehow.

     The target is first mapped through ``_TORCH_TESTING_DICT`` when it has
     an entry there. All arguments are then bound with ``getcallargs``
     (defaults included). Walking the positional parameters in signature
     order, the leading run whose bound values are ``Node`` objects stays in
     ``node.args``. Every other bound argument goes to ``node.kwargs``.

     If the signature cannot be resolved or the arguments do not bind,
     ``TypeError`` is caught and ``node`` is left unchanged.
     """
     from inspect import getfullargspec

     fn = node.target
     if fn in _TORCH_TESTING_DICT:
         fn = _TORCH_TESTING_DICT[fn]
     try:
         fn_kwargs = getcallargs(fn, *node.args, **node.kwargs)
         positional_params = getfullargspec(fn).args
     except TypeError:
         # Best effort: targets without a resolvable signature are skipped.
         return
     # ``getcallargs`` builds its dict positional-first, with keyword
     # arguments in call order and defaults last -- not in signature order.
     # Walking the signature's positional parameters prevents keyword-passed
     # or keyword-only Nodes from landing in the wrong positional slot.
     fn_args = []
     for name in positional_params:
         if name not in fn_kwargs or not isinstance(fn_kwargs[name], Node):
             break
         fn_args.append(fn_kwargs.pop(name))
     node.args = tuple(fn_args)
     node.kwargs = immutable_dict(fn_kwargs)
Beispiel #4
0
 def move_node_args_to_kwargs(self, node: Node):
     """Move a ``'self'`` keyword argument to the front of ``node.args``.

     Nodes whose kwargs have no ``'self'`` key are left untouched.
     """
     if 'self' not in node.kwargs:
         return
     remaining = dict(node.kwargs)
     receiver = remaining.pop('self')
     node.kwargs = immutable_dict(remaining)
     node.args = (receiver, *node.args)
Beispiel #5
0
 def move_node_args_to_kwargs(self, node: Node):
     """Remove the ``dim`` argument from ``node``.

     If ``dim`` is present in ``node.kwargs`` it is dropped from the kwargs.
     Otherwise, every positional argument equal to ``(2, 3)`` or ``[2, 3]``
     is dropped from ``node.args``. All other arguments are kept unchanged.
     """
     if 'dim' in node.kwargs:
         # Pop from a copy and keep the copy: ``dict.pop`` returns the
         # removed value, not the remaining mapping.
         kwargs = dict(node.kwargs)
         kwargs.pop('dim')
         node.kwargs = immutable_dict(kwargs)
     elif (2, 3) in node.args or [2, 3] in node.args:
         node.args = tuple(
             a for a in node.args if a != (2, 3) and a != [2, 3])