def _change_node_to_module(
    node: Node,
    name: str,
    base_module: Module,
    new_module: Module,
    args: tuple,
) -> None:
    """Rewrite an existing graph node into a ``call_module`` node.

    ``new_module`` is attached to ``base_module`` as a submodule under a
    free attribute name derived from ``name`` (the real name is of the form
    ``name{int}``, as chosen by ``_get_free_name``). The node is then
    re-pointed so that it calls that submodule with ``args``.

    Args:
        node: the existing graph node to rewrite in place.
        name: proposed attribute name; the actual name is ``name{int}``.
        base_module: module that receives ``new_module`` as a child.
        new_module: module the node will invoke after the rewrite.
        args: positional arguments for the rewritten node.
    """
    attr_name = _get_free_name(base_module, name)
    # Re-point the node at the submodule (assigned left to right:
    # op, then target, then args).
    node.op, node.target, node.args = "call_module", attr_name, args
    # Register the submodule under the attribute name the node now targets.
    setattr(base_module, attr_name, new_module)
def output(self, node: Node, target: Any, name: str):
    """Replace the graph output with per-type dictionaries.

    Every ``Node`` inside the (possibly nested dict/list/tuple) output value
    is replaced by a dictionary. That dictionary maps each key of
    ``self.metadata[1]`` (edge-level values) or ``self.metadata[0]`` (other
    values) to the graph node named ``f'{value.name}__{key2str(key)}'``.
    Any other leaf value is left unchanged.

    If the output node carries a type annotation and its single argument is
    a ``Node``, the annotation becomes ``Dict[EdgeType, T]`` or
    ``Dict[NodeType, T]``. Otherwise the annotation is cleared to ``None``.

    Args:
        node: the graph's output node; it is modified in place.
        target: unused here.
        name: unused here.
    """
    def _expand(value: Any) -> Any:
        if isinstance(value, Node):
            # Index 1 for edge-level values, index 0 otherwise.
            type_keys = self.metadata[int(self.is_edge_level(value))]
            return {
                key: self.find_by_name(f'{value.name}__{key2str(key)}')
                for key in type_keys
            }
        if isinstance(value, dict):
            return {k: _expand(v) for k, v in value.items()}
        if isinstance(value, list):
            return [_expand(v) for v in value]
        if isinstance(value, tuple):
            return tuple(_expand(v) for v in value)
        return value

    if node.type is None or not isinstance(node.args[0], Node):
        node.type = None
    else:
        key_type = EdgeType if self.is_edge_level(node.args[0]) else NodeType
        node.type = Dict[key_type, node.type]

    node.args = (_expand(node.args[0]), )
def output(self, node: Node, target: Any, name: str):
    """Split the graph output into per-type dictionaries.

    Every ``Node`` inside the (possibly nested dict/list/tuple) output value
    is wrapped in a new ``call_function`` node named
    ``f'{value.name}__split'``. That node applies ``split_output`` to the
    value together with the graph node ``edge_offset_dict`` (edge-level
    values) or ``node_offset_dict`` (other values). Any other leaf value is
    left unchanged.

    If the output node carries a type annotation and its single argument is
    a ``Node``, the annotation becomes ``Dict[EdgeType, T]`` or
    ``Dict[NodeType, T]``. Otherwise the annotation is cleared to ``None``.

    Args:
        node: the graph's output node; it is modified in place.
        target: unused here.
        name: unused here.
    """
    def _recurse(value: Any) -> Any:
        if isinstance(value, Node):
            # Edge-level values are split by edge offsets, all others by
            # node offsets; the two cases differ only in this lookup name.
            offset_name = ('edge_offset_dict' if self.is_edge_level(value)
                           else 'node_offset_dict')
            # Called without `with`: torch.fx keeps this insertion point
            # after the call, so the split node lands before the output node.
            self.graph.inserting_before(node)
            return self.graph.create_node(
                'call_function', target=split_output,
                args=(value, self.find_by_name(offset_name)),
                name=f'{value.name}__split')
        if isinstance(value, dict):
            return {k: _recurse(v) for k, v in value.items()}
        if isinstance(value, list):
            return [_recurse(v) for v in value]
        if isinstance(value, tuple):
            return tuple(_recurse(v) for v in value)
        return value

    if node.type is not None and isinstance(node.args[0], Node):
        out_value = node.args[0]
        key_type = EdgeType if self.is_edge_level(out_value) else NodeType
        node.type = Dict[key_type, node.type]
    else:
        node.type = None

    node.args = (_recurse(node.args[0]), )
def call_message_passing_module(self, node: Node, target: Any, name: str):
    """Route a message passing call through the ``HeteroBasisConv`` wrapper.

    The wrapper expects the edge type as its first argument. This method
    therefore prepends the graph node found under ``'edge_type'`` (via
    ``self.find_by_name``) to the node's existing positional arguments.

    Args:
        node: the ``call_module`` node for the message passing layer;
            its ``args`` are modified in place.
        target: unused here.
        name: unused here.
    """
    edge_type_arg = self.find_by_name('edge_type')
    node.args = (edge_type_arg, ) + node.args