示例#1
0
def _change_node_to_module(
    node: Node,
    name: str,
    base_module: Module,
    new_module: Module,
    args: tuple,
) -> None:
    """Rewrite ``node`` in place so that it calls ``new_module``.

    A target name is requested from ``_get_free_name(base_module, name)``
    (intended to be of the form ``name{int}``). The node is turned into a
    ``call_module`` node pointing at that name with ``args`` as its
    arguments, and ``new_module`` is then attached to ``base_module`` under
    the very same name.

    Args:
        node: the existing node that gets rewritten.
        name: proposed attribute name; the name actually used is whatever
            ``_get_free_name`` returns for it.
        base_module: the module that receives ``new_module`` as an attribute.
        new_module: the module the rewritten node should call.
        args: the arguments assigned to the rewritten node.
    """
    attr_name = _get_free_name(base_module, name)

    # Repoint the node first (op, then target, then args) ...
    node.op, node.target = "call_module", attr_name
    node.args = args

    # ... and only afterwards make the target resolvable on the base module.
    setattr(base_module, attr_name, new_module)
示例#2
0
    def output(self, node: Node, target: Any, name: str):
        """Rewrite the output node so that it returns type-wise dictionaries.

        Every `Node` found in the (possibly nested) output value is replaced
        by a dictionary with one entry per key of `self.metadata[0]`
        (value is not edge-level) or `self.metadata[1]` (value is
        edge-level). Each entry is the node looked up via
        `self.find_by_name` under `'{value.name}__{key2str(key)}'`.
        Dicts, lists and tuples are rebuilt recursively; any other value is
        passed through unchanged.

        The type annotation of `node` is wrapped into a `Dict[...]` keyed by
        `EdgeType` or `NodeType` when it is set and the output is a single
        `Node`; otherwise it is reset to `None`.
        """
        def _expand(obj: Any) -> Any:
            if isinstance(obj, Node):
                # `False` -> index 0, `True` -> index 1 of the metadata.
                level = int(self.is_edge_level(obj))
                per_type = {}
                for key in self.metadata[level]:
                    per_type[key] = self.find_by_name(
                        f'{obj.name}__{key2str(key)}')
                return per_type
            if isinstance(obj, dict):
                return {k: _expand(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [_expand(item) for item in obj]
            if isinstance(obj, tuple):
                return tuple(_expand(item) for item in obj)
            return obj

        returned = node.args[0]

        if node.type is None or not isinstance(returned, Node):
            node.type = None
        else:
            key_type = EdgeType if self.is_edge_level(returned) else NodeType
            node.type = Dict[key_type, node.type]

        node.args = (_expand(returned), )
    def output(self, node: Node, target: Any, name: str):
        """Split the output into node type-wise or edge type-wise data.

        Every `Node` found in the (possibly nested) output value is replaced
        by a newly created `call_function` node that calls `split_output`
        with the value and the node found under `'edge_offset_dict'`
        (edge-level value) or `'node_offset_dict'` (any other value). The new
        node is named `'{value.name}__split'`. Dicts, lists and tuples are
        rebuilt recursively; any other value is passed through unchanged.

        The type annotation of `node` is wrapped into a `Dict[...]` keyed by
        `EdgeType` or `NodeType` when it is set and the output is a single
        `Node`; otherwise it is reset to `None`.
        """
        def _recurse(value: Any) -> Any:
            if isinstance(value, Node):
                # Edge-level and node-level values only differ in which
                # offset dictionary is handed to `split_output`.
                if self.is_edge_level(value):
                    offset_name = 'edge_offset_dict'
                else:
                    offset_name = 'node_offset_dict'

                # NOTE(review): `inserting_before` is not used as a `with`
                # context manager here, so the insert point is presumably
                # not restored afterwards -- confirm this is intended.
                self.graph.inserting_before(node)
                return self.graph.create_node(
                    'call_function',
                    target=split_output,
                    args=(value, self.find_by_name(offset_name)),
                    name=f'{value.name}__split')
            elif isinstance(value, dict):
                return {k: _recurse(v) for k, v in value.items()}
            elif isinstance(value, list):
                return [_recurse(v) for v in value]
            elif isinstance(value, tuple):
                return tuple(_recurse(v) for v in value)
            else:
                return value

        if node.type is not None and isinstance(node.args[0], Node):
            output = node.args[0]
            Type = EdgeType if self.is_edge_level(output) else NodeType
            node.type = Dict[Type, node.type]
        else:
            node.type = None

        node.args = (_recurse(node.args[0]), )
 def call_message_passing_module(self, node: Node, target: Any, name: str):
     """Prepend the `edge_type` node to the arguments of `node`.

     The call is meant to go to the `HeteroBasisConv` wrapper instead of a
     single message passing layer, and that wrapper needs `edge_type`
     injected as its first argument.
     """
     leading = (self.find_by_name('edge_type'), )
     node.args = leading + node.args