def _change_node_to_module(
    node: Node,
    name: str,
    base_module: Module,
    new_module: Module,
    args: tuple,
) -> None:
    """Rewrite ``node`` in-place so that it calls ``new_module``.

    ``new_module`` is registered on ``base_module`` as a submodule under a
    fresh attribute name derived from ``name`` (``name{int}``), and the node
    is turned into a ``call_module`` op targeting that attribute.

    Args:
        node: the existing node to rewrite
        name: proposed attribute name; the actual one is ``name{int}``
        base_module: the module that receives ``new_module`` as a child
        new_module: the module the rewritten node should call
        args: arguments to install on the rewritten node
    """
    attr_name = _get_free_name(base_module, name)
    setattr(base_module, attr_name, new_module)
    node.op = 'call_module'
    node.target = attr_name
    node.args = args
def output(self, node: Node, target: Any, name: str):
    """Rewrite the graph output so that every returned node becomes a
    dictionary keyed by node type or edge type (depending on whether the
    value is edge-level), while preserving arbitrary nesting of dicts,
    lists and tuples around it."""
    def _map(obj: Any) -> Any:
        if isinstance(obj, Node):
            # Index 0 of `self.metadata` presumably holds node types and
            # index 1 edge types — selected by the edge-level predicate.
            keys = self.metadata[int(self.is_edge_level(obj))]
            return {
                key: self.find_by_name(f'{obj.name}__{key2str(key)}')
                for key in keys
            }
        if isinstance(obj, dict):
            return {key: _map(val) for key, val in obj.items()}
        if isinstance(obj, (list, tuple)):
            mapped = [_map(item) for item in obj]
            return mapped if isinstance(obj, list) else tuple(mapped)
        return obj

    # Adjust the output type annotation to the new dictionary shape:
    if node.type is None or not isinstance(node.args[0], Node):
        node.type = None
    else:
        out_node = node.args[0]
        key_type = EdgeType if self.is_edge_level(out_node) else NodeType
        node.type = Dict[key_type, node.type]

    node.args = (_map(node.args[0]), )
def run_node(self, node: Node):
    """Execute `node` through the interpreter and attach functionalization
    metadata to it: the (fake) execution result, the node's topological
    index, and — for view ops — which node it is a view of ('view_of').
    """
    self.node_counter += 1
    result = super().run_node(node)
    node.meta['fake_result'] = result
    node.meta['node_idx'] = self.node_counter

    # (1) Update metadata with the list of nodes that are used by this node
    # copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
    # We don't want to treat it as "being used as an input".
    # NOTE(review): `node_args` is not consumed anywhere in the visible body
    # — presumably vestigial or used by code outside this view; verify.
    node_args = node.args
    if node.target is torch.ops.aten.copy_.default:
        node_args = node_args[1:]

    # (2) Update metadata to track aliasing information about view tensor nodes.
    if node.op == 'call_function':
        view_type = _get_view_type(node.target)
        if view_type == _ViewType.SingleOutputView:
            # Single-output views alias their first argument directly.
            assert isinstance(node.args[0], Node)
            node.meta['view_of'] = node.args[0]
        elif view_type == _ViewType.MultiOutputView:
            # Remember the base tensor; the individual views are produced
            # later by `operator.getitem` nodes (handled below).
            self.multi_output_view_nodes[node] = node.args[0]
        # Check if we returned a multi-output view,
        # and we're now grabbing the individual views from the output.
        #
        # For multi-output views, we want to map each output view to the base,
        # but this mapping involves two separate nodes in FX IR.
        # e.g. "a, b = x_1.split(...)" becomes:
        # %split_tensor : [#users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
        # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
        # %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
        # And we'd like to set:
        # getitem1.meta['view_of'] = x_1
        elif node.target is _operator.getitem:
            list_arg = node.args[0]
            maybe_base_of_view = self.multi_output_view_nodes.get(
                list_arg, None)
            if maybe_base_of_view is not None:
                # Note: we could also track indexing info here for multi-output views.
                # I don't think this metadata is strictly needed for de-functionalization.
                assert isinstance(maybe_base_of_view, Node)
                node.meta['view_of'] = maybe_base_of_view

        if 'view_of' in node.meta:
            # We're linking the current node with its first argument as views.
            # Assert here that this is actually the case, and their storages are the same.
            assert isinstance(node.meta['fake_result'], FakeTensor)
            assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
            view_storage = StorageWeakRef(node.meta['fake_result'].storage())
            base_storage = StorageWeakRef(
                node.meta['view_of'].meta['fake_result'].storage())
            assert view_storage == base_storage
    return result
def output(self, node: Node, target: Any, name: str):
    """Rewrite the graph output so that unified tensors are split back into
    per-type dictionaries (node type-wise or edge type-wise), preserving
    any nesting of dicts, lists and tuples around them.

    Fixes: removed an unreachable `pass` after a `return`, and merged the
    two `Node` branches, which differed only in the offset-dict name.
    """
    def _recurse(value: Any) -> Any:
        if isinstance(value, Node):
            # Edge-level outputs split against cumulated edge offsets,
            # node-level outputs against cumulated node offsets:
            offset_name = ('edge_offset_dict' if self.is_edge_level(value)
                           else 'node_offset_dict')
            self.graph.inserting_before(node)
            return self.graph.create_node(
                'call_function', target=split_output,
                args=(value, self.find_by_name(offset_name)),
                name=f'{value.name}__split')
        elif isinstance(value, dict):
            return {k: _recurse(v) for k, v in value.items()}
        elif isinstance(value, list):
            return [_recurse(v) for v in value]
        elif isinstance(value, tuple):
            return tuple(_recurse(v) for v in value)
        else:
            return value

    # Adjust the output type annotation to the new dictionary shape:
    if node.type is not None and isinstance(node.args[0], Node):
        out = node.args[0]
        Type = EdgeType if self.is_edge_level(out) else NodeType
        node.type = Dict[Type, node.type]
    else:
        node.type = None

    node.args = (_recurse(node.args[0]), )
def placeholder(self, node: Node, target: Any, name: str):
    """For every node type or edge type, insert a dictionary `.get(key)`
    lookup on the placeholder's input dict, and widen the placeholder's
    type annotation to the corresponding `Dict` type."""
    if node.type is not None:
        key_type = EdgeType if self.is_edge_level(node) else NodeType
        node.type = Dict[key_type, node.type]

    # Each `get` node is chained after the previous one so the per-type
    # lookups appear in order right after the placeholder:
    self.graph.inserting_after(node)
    for key in self.metadata[int(self.is_edge_level(node))]:
        get_node = self.graph.create_node(
            'call_method', target='get', args=(node, key),
            name=f'{name}__{key2str(key)}')
        self.graph.inserting_after(get_node)
def call_message_passing_module(self, node: Node, target: Any, name: str):
    """Route the call through the `HeteroBasisConv` wrapper rather than a
    single message passing layer, which requires injecting the shared
    `edge_type` tensor as the first argument."""
    edge_type_node = self.find_by_name('edge_type')
    node.args = (edge_type_node, ) + node.args
def placeholder(self, node: Node, target: Any, name: str):
    """Transform an input placeholder for the unified (typed-bases) graph:
    lazily create the offset dictionaries and `edge_type` tensor, align
    input features via a per-input `Linear`, and group the type-wise
    values into a single tensor that replaces all uses of the original
    placeholder.
    """
    if node.type is not None:
        Type = EdgeType if self.is_edge_level(node) else NodeType
        node.type = Dict[Type, node.type]

    # `out` tracks the last node inserted so later insertions chain after it:
    out = node

    # Create `node_offset_dict` and `edge_offset_dict` dictionaries in case
    # they are not yet initialized. These dictionaries hold the cumulated
    # sizes used to create a unified graph representation and to split the
    # output data.
    if self.is_edge_level(node) and not self._edge_offset_dict_initialized:
        self.graph.inserting_after(out)
        out = self.graph.create_node('call_function',
                                     target=get_edge_offset_dict,
                                     args=(node, self.edge_type2id),
                                     name='edge_offset_dict')
        self._edge_offset_dict_initialized = True

    elif not self._node_offset_dict_initialized:
        self.graph.inserting_after(out)
        out = self.graph.create_node('call_function',
                                     target=get_node_offset_dict,
                                     args=(node, self.node_type2id),
                                     name='node_offset_dict')
        self._node_offset_dict_initialized = True

    # Create a `edge_type` tensor used as input to `HeteroBasisConv`:
    if self.is_edge_level(node) and not self._edge_type_initialized:
        self.graph.inserting_after(out)
        out = self.graph.create_node('call_function', target=get_edge_type,
                                     args=(node, self.edge_type2id),
                                     name='edge_type')
        self._edge_type_initialized = True

    # Add `Linear` operation to align features to the same dimensionality:
    if name in self.in_channels:
        self.graph.inserting_after(out)
        out = self.graph.create_node('call_module',
                                     target=f'align_lin__{name}',
                                     args=(node, ),
                                     name=f'{name}__aligned')
        # Propagate node/edge-level state to the aligned node:
        self._state[out.name] = self._state[name]

        lin = LinearAlign(self.metadata[int(self.is_edge_level(node))],
                          self.in_channels[name])
        setattr(self.module, f'align_lin__{name}', lin)

    # Perform grouping of type-wise values into a single tensor:
    if self.is_edge_level(node):
        self.graph.inserting_after(out)
        out = self.graph.create_node(
            'call_function', target=group_edge_placeholder,
            # Group the aligned node if one was created, else the raw input:
            args=(out if name in self.in_channels else node,
                  self.edge_type2id,
                  self.find_by_name('node_offset_dict')),
            name=f'{name}__grouped')
        self._state[out.name] = 'edge'
    else:
        self.graph.inserting_after(out)
        out = self.graph.create_node(
            'call_function', target=group_node_placeholder,
            args=(out if name in self.in_channels else node,
                  self.node_type2id),
            name=f'{name}__grouped')
        self._state[out.name] = 'node'

    self.replace_all_uses_with(node, out)
def run_node(self, n: Node):
    """Execute `n` via the parent interpreter, cache the (fake) result on
    the node's metadata, and return it unchanged."""
    out = super().run_node(n)
    n.meta['fake_result'] = out
    return out