Exemplo n.º 1
0
    def output(self, node: Node, target: Any, name: str):
        """Rewrite the output node so that it returns type-wise dictionaries.

        Every ``Node`` found in ``node.args[0]`` (searched recursively through
        dicts, lists and tuples) is swapped for a dictionary with one entry
        per key of ``self.metadata[0]`` or ``self.metadata[1]``; index ``1``
        is chosen when ``self.is_edge_level`` is truthy for that value. The
        entry for ``key`` is whatever ``self.find_by_name`` returns for the
        name ``'<value.name>__<key2str(key)>'``.

        Args:
            node: The output node; its ``type`` and ``args`` are overwritten.
            target: Not used by this method.
            name: Not used by this method.
        """
        def _expand(obj: Any) -> Any:
            # Nodes become one dictionary entry per node type or edge type.
            if isinstance(obj, Node):
                type_keys = self.metadata[int(self.is_edge_level(obj))]
                return {
                    key: self.find_by_name(f'{obj.name}__{key2str(key)}')
                    for key in type_keys
                }
            # Containers are rebuilt with the same container kind.
            if isinstance(obj, dict):
                return {k: _expand(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [_expand(item) for item in obj]
            if isinstance(obj, tuple):
                return tuple(_expand(item) for item in obj)
            # Anything else is passed through untouched.
            return obj

        # The annotation is only kept (wrapped in a `Dict`) when the output is
        # a single annotated node; in every other case it is dropped.
        if node.type is None or not isinstance(node.args[0], Node):
            node.type = None
        else:
            returned = node.args[0]
            key_type = EdgeType if self.is_edge_level(returned) else NodeType
            node.type = Dict[key_type, node.type]

        node.args = (_expand(node.args[0]), )
    def output(self, node: Node, target: Any, name: str):
        """Split the output into node type-wise or edge type-wise dictionaries.

        Every ``Node`` found in ``node.args[0]`` (searched recursively through
        dicts, lists and tuples) is replaced by a new ``call_function`` node
        named ``'<value.name>__split'`` that calls ``split_output`` with the
        value and an offset dictionary. The offset dictionary is looked up
        via ``self.find_by_name`` under ``'edge_offset_dict'`` when
        ``self.is_edge_level`` is truthy for the value, and under
        ``'node_offset_dict'`` otherwise.

        Args:
            node: The output node; its ``type`` and ``args`` are overwritten.
            target: Not used by this method.
            name: Not used by this method.
        """
        def _recurse(value: Any) -> Any:
            if isinstance(value, Node):
                # Edge-level and node-level values are split by the same
                # function; only the offset dictionary differs.
                offset_name = ('edge_offset_dict' if self.is_edge_level(value)
                               else 'node_offset_dict')
                # The split has to be placed in front of the output node.
                self.graph.inserting_before(node)
                return self.graph.create_node(
                    'call_function',
                    target=split_output,
                    args=(value, self.find_by_name(offset_name)),
                    name=f'{value.name}__split')
            elif isinstance(value, dict):
                return {k: _recurse(v) for k, v in value.items()}
            elif isinstance(value, list):
                return [_recurse(v) for v in value]
            elif isinstance(value, tuple):
                return tuple(_recurse(v) for v in value)
            else:
                return value

        # The annotation is only kept (wrapped in a `Dict`) when the output is
        # a single annotated node; in every other case it is dropped.
        if node.type is not None and isinstance(node.args[0], Node):
            output = node.args[0]
            Type = EdgeType if self.is_edge_level(output) else NodeType
            node.type = Dict[Type, node.type]
        else:
            node.type = None

        node.args = (_recurse(node.args[0]), )
    def placeholder(self, node: Node, target: Any, name: str):
        """Add a ``get`` call on the input dictionary for every type key.

        One ``call_method`` node with target ``'get'`` and arguments
        ``(node, key)`` is created per key of ``self.metadata[0]`` or
        ``self.metadata[1]``; index ``1`` is chosen when
        ``self.is_edge_level(node)`` is truthy. Each created node is named
        ``'<name>__<key2str(key)>'``.

        Args:
            node: The placeholder node; its ``type`` annotation, when set, is
                wrapped into a ``Dict`` keyed by ``EdgeType`` or ``NodeType``.
            target: Not used by this method.
            name: Prefix for the names of the created ``get`` nodes.
        """
        if node.type is not None:
            key_type = EdgeType if self.is_edge_level(node) else NodeType
            node.type = Dict[key_type, node.type]

        # The insert point is moved behind each freshly created node, so the
        # `get` calls are chained one after another starting at `node`.
        anchor = node
        self.graph.inserting_after(anchor)
        type_keys = self.metadata[int(self.is_edge_level(node))]
        for key in type_keys:
            anchor = self.graph.create_node(
                'call_method',
                target='get',
                args=(node, key),
                name=f'{name}__{key2str(key)}',
            )
            self.graph.inserting_after(anchor)
    def placeholder(self, node: Node, target: Any, name: str):
        """Group the type-wise input behind ``node`` into a single value.

        Up to four kinds of nodes are appended after ``node``, each one
        inserted directly behind the previously created one:

        1. ``edge_offset_dict`` or ``node_offset_dict`` (each at most once per
           transformer, guarded by the ``_*_offset_dict_initialized`` flags),
        2. ``edge_type`` (at most once, edge-level placeholders only),
        3. an ``align_lin__<name>`` module call when ``name`` is a key of
           ``self.in_channels``; the ``LinearAlign`` module is registered on
           ``self.module`` under that attribute name,
        4. a ``<name>__grouped`` call to ``group_edge_placeholder`` or
           ``group_node_placeholder``.

        Finally all uses of ``node`` are redirected to the last created node.

        Args:
            node: The placeholder node; its ``type`` annotation, when set, is
                wrapped into a ``Dict`` keyed by ``EdgeType`` or ``NodeType``.
            target: Not used by this method.
            name: Name of the placeholder, used to derive the new node names.
        """
        if node.type is not None:
            key_type = EdgeType if self.is_edge_level(node) else NodeType
            node.type = Dict[key_type, node.type]

        # `tail` always refers to the most recently appended node.
        tail = node

        # The offset dictionaries hold the cumulated sizes that are used to
        # create a unified graph representation and to split the output data.
        # They are created lazily, once each.
        # NOTE(review): the `elif` also runs for an edge-level placeholder
        # whose edge offsets already exist while the node offsets do not --
        # confirm that this is intended.
        if self.is_edge_level(node) and not self._edge_offset_dict_initialized:
            self.graph.inserting_after(tail)
            tail = self.graph.create_node(
                'call_function',
                target=get_edge_offset_dict,
                args=(node, self.edge_type2id),
                name='edge_offset_dict',
            )
            self._edge_offset_dict_initialized = True

        elif not self._node_offset_dict_initialized:
            self.graph.inserting_after(tail)
            tail = self.graph.create_node(
                'call_function',
                target=get_node_offset_dict,
                args=(node, self.node_type2id),
                name='node_offset_dict',
            )
            self._node_offset_dict_initialized = True

        # The `edge_type` tensor serves as input to `HeteroBasisConv`:
        if self.is_edge_level(node) and not self._edge_type_initialized:
            self.graph.inserting_after(tail)
            tail = self.graph.create_node(
                'call_function',
                target=get_edge_type,
                args=(node, self.edge_type2id),
                name='edge_type',
            )
            self._edge_type_initialized = True

        # A `Linear` operation aligns features to the same dimensionality:
        needs_alignment = name in self.in_channels
        if needs_alignment:
            module_name = f'align_lin__{name}'
            self.graph.inserting_after(tail)
            tail = self.graph.create_node(
                'call_module',
                target=module_name,
                args=(node, ),
                name=f'{name}__aligned',
            )
            # The aligned value inherits the state of the raw placeholder.
            self._state[tail.name] = self._state[name]

            type_keys = self.metadata[int(self.is_edge_level(node))]
            setattr(self.module, module_name,
                    LinearAlign(type_keys, self.in_channels[name]))

        # Grouping consumes the aligned value if there is one, otherwise the
        # raw placeholder.
        features = tail if needs_alignment else node

        # Group the type-wise values into a single tensor:
        edge_level = self.is_edge_level(node)
        self.graph.inserting_after(tail)
        if edge_level:
            group_fn = group_edge_placeholder
            group_args = (features, self.edge_type2id,
                          self.find_by_name('node_offset_dict'))
            level = 'edge'
        else:
            group_fn = group_node_placeholder
            group_args = (features, self.node_type2id)
            level = 'node'
        tail = self.graph.create_node(
            'call_function',
            target=group_fn,
            args=group_args,
            name=f'{name}__grouped',
        )
        self._state[tail.name] = level

        self.replace_all_uses_with(node, tail)