def output(self, node: Node, target: Any, name: str):
    """Rewrite the graph output so each value becomes a per-type dictionary.

    Every ``Node`` found in the (possibly nested) output structure is
    replaced by a dictionary mapping each key of
    ``self.metadata[int(self.is_edge_level(value))]`` to the graph node
    named ``f'{value.name}__{key2str(key)}'``. Dicts, lists and tuples are
    walked recursively and rebuilt with the same container kind; any other
    value is passed through unchanged.

    If the output is a single annotated ``Node``, its annotation is wrapped
    as ``Dict[EdgeType, ...]`` or ``Dict[NodeType, ...]``. Otherwise the
    annotation is cleared.

    Args:
        node: The graph's output node; its ``type`` and ``args`` are
            modified in place.
        target: Not used by this hook.
        name: Not used by this hook.
    """
    def _to_type_dict(item: Any) -> Any:
        if isinstance(item, Node):
            level = int(self.is_edge_level(item))
            return {
                key: self.find_by_name(f'{item.name}__{key2str(key)}')
                for key in self.metadata[level]
            }
        if isinstance(item, dict):
            return {k: _to_type_dict(v) for k, v in item.items()}
        if isinstance(item, list):
            return [_to_type_dict(v) for v in item]
        if isinstance(item, tuple):
            return tuple(_to_type_dict(v) for v in item)
        return item

    # Only a bare, annotated Node output gets a Dict[...] annotation; any
    # other output shape has its annotation dropped.
    if node.type is None or not isinstance(node.args[0], Node):
        node.type = None
    else:
        if self.is_edge_level(node.args[0]):
            key_type = EdgeType
        else:
            key_type = NodeType
        node.type = Dict[key_type, node.type]

    node.args = (_to_type_dict(node.args[0]), )
def output(self, node: Node, target: Any, name: str):
    """Split the grouped output back into node type-wise or edge type-wise data.

    Every ``Node`` found in the (possibly nested) output structure is
    replaced by a new ``call_function`` node that calls
    ``split_output(value, offset_dict)``. The new node is inserted before
    the output node and named ``f'{value.name}__split'``. ``offset_dict`` is
    the graph node named ``'edge_offset_dict'`` for edge-level values and
    ``'node_offset_dict'`` otherwise. Dicts, lists and tuples are walked
    recursively; any other value is passed through unchanged.

    If the output is a single annotated ``Node``, its annotation is wrapped
    as ``Dict[EdgeType, ...]`` or ``Dict[NodeType, ...]``. Otherwise the
    annotation is cleared.

    Args:
        node: The graph's output node; its ``type`` and ``args`` are
            modified in place.
        target: Not used by this hook.
        name: Not used by this hook.
    """
    def _recurse(value: Any) -> Any:
        if isinstance(value, Node):
            # Edge-level values are split using the edge offsets; all other
            # values use the node offsets.
            if self.is_edge_level(value):
                offset_dict_name = 'edge_offset_dict'
            else:
                offset_dict_name = 'node_offset_dict'
            # NOTE(review): the insertion point is moved by calling
            # `inserting_before` without a `with` block, the same pattern the
            # placeholder hooks use with `inserting_after`. Confirm this stays
            # valid for the torch.fx version in use.
            self.graph.inserting_before(node)
            return self.graph.create_node(
                'call_function', target=split_output,
                args=(value, self.find_by_name(offset_dict_name)),
                name=f'{value.name}__split')
        elif isinstance(value, dict):
            return {k: _recurse(v) for k, v in value.items()}
        elif isinstance(value, list):
            return [_recurse(v) for v in value]
        elif isinstance(value, tuple):
            return tuple(_recurse(v) for v in value)
        else:
            return value

    if node.type is not None and isinstance(node.args[0], Node):
        output_node = node.args[0]
        key_type = EdgeType if self.is_edge_level(output_node) else NodeType
        node.type = Dict[key_type, node.type]
    else:
        node.type = None

    node.args = (_recurse(node.args[0]), )
def placeholder(self, node: Node, target: Any, name: str):
    """Expand a dict-valued placeholder into one ``get`` call per type.

    For every key in ``self.metadata[int(self.is_edge_level(node))]``, a
    ``call_method`` node ``node.get(key)`` named
    ``f'{name}__{key2str(key)}'`` is inserted into the graph. Each new node
    is placed right after the previous one, starting directly after the
    placeholder itself.

    If the placeholder is annotated, its annotation is wrapped as
    ``Dict[EdgeType, ...]`` or ``Dict[NodeType, ...]``.

    Args:
        node: The placeholder node being expanded.
        target: Not used by this hook.
        name: Placeholder name, used as the prefix of the new node names.
    """
    if node.type is not None:
        if self.is_edge_level(node):
            key_type = EdgeType
        else:
            key_type = NodeType
        node.type = Dict[key_type, node.type]

    # Chain the insertion point so the `get` calls appear in metadata order.
    self.graph.inserting_after(node)
    type_keys = self.metadata[int(self.is_edge_level(node))]
    for type_key in type_keys:
        getter = self.graph.create_node(
            'call_method', target='get', args=(node, type_key),
            name=f'{name}__{key2str(type_key)}')
        self.graph.inserting_after(getter)
def placeholder(self, node: Node, target: Any, name: str):
    """Replace a dict-valued placeholder by a single grouped value.

    New nodes are inserted after the placeholder, each one right after the
    previous one:

    1. Once per graph: a ``get_edge_offset_dict`` node (from the first
       edge-level placeholder) or a ``get_node_offset_dict`` node (from the
       first node-level placeholder).
    2. Once per graph: a ``get_edge_type`` node (from the first edge-level
       placeholder).
    3. If ``name`` is in ``self.in_channels``: a ``LinearAlign`` module call
       named ``f'{name}__aligned'``. The module is registered on
       ``self.module`` as ``align_lin__{name}``.
    4. A ``group_edge_placeholder`` / ``group_node_placeholder`` call named
       ``f'{name}__grouped'``, fed with the aligned value when step 3 ran and
       with the raw placeholder otherwise.

    Finally, all later uses of the placeholder are redirected to the last
    node created.

    Args:
        node: The placeholder node being grouped.
        target: Not used by this hook.
        name: Placeholder name, used for lookups and new node names.
    """
    is_edge = self.is_edge_level(node)

    if node.type is not None:
        key_type = EdgeType if is_edge else NodeType
        node.type = Dict[key_type, node.type]

    out = node

    # Create `node_offset_dict` and `edge_offset_dict` dictionaries in case
    # they are not yet initialized. These dictionaries hold the cumulated
    # sizes used to create a unified graph representation and to split the
    # output data.
    if is_edge and not self._edge_offset_dict_initialized:
        self.graph.inserting_after(out)
        out = self.graph.create_node('call_function',
                                     target=get_edge_offset_dict,
                                     args=(node, self.edge_type2id),
                                     name='edge_offset_dict')
        self._edge_offset_dict_initialized = True
    elif not is_edge and not self._node_offset_dict_initialized:
        # Only a node-level placeholder may seed `node_offset_dict`. Without
        # the `not is_edge` guard, a second edge-level placeholder would
        # reach this branch and pass edge-keyed data alongside `node_type2id`.
        self.graph.inserting_after(out)
        out = self.graph.create_node('call_function',
                                     target=get_node_offset_dict,
                                     args=(node, self.node_type2id),
                                     name='node_offset_dict')
        self._node_offset_dict_initialized = True

    # Create a `edge_type` tensor used as input to `HeteroBasisConv`:
    if is_edge and not self._edge_type_initialized:
        self.graph.inserting_after(out)
        out = self.graph.create_node('call_function', target=get_edge_type,
                                     args=(node, self.edge_type2id),
                                     name='edge_type')
        self._edge_type_initialized = True

    # Add `Linear` operation to align features to the same dimensionality:
    needs_alignment = name in self.in_channels
    if needs_alignment:
        self.graph.inserting_after(out)
        out = self.graph.create_node('call_module',
                                     target=f'align_lin__{name}',
                                     args=(node, ), name=f'{name}__aligned')
        self._state[out.name] = self._state[name]

        lin = LinearAlign(self.metadata[int(is_edge)],
                          self.in_channels[name])
        setattr(self.module, f'align_lin__{name}', lin)

    # Perform grouping of type-wise values into a single tensor. Group the
    # aligned value if one was created, otherwise the raw placeholder (not
    # the offset-dict / edge-type helper nodes that may now be `out`).
    group_input = out if needs_alignment else node
    self.graph.inserting_after(out)
    if is_edge:
        out = self.graph.create_node(
            'call_function', target=group_edge_placeholder,
            args=(group_input, self.edge_type2id,
                  self.find_by_name('node_offset_dict')),
            name=f'{name}__grouped')
        self._state[out.name] = 'edge'
    else:
        out = self.graph.create_node(
            'call_function', target=group_node_placeholder,
            args=(group_input, self.node_type2id),
            name=f'{name}__grouped')
        self._state[out.name] = 'node'

    self.replace_all_uses_with(node, out)