Example 1
def _change_node_to_module(
    node: Node,
    name: str,
    base_module: Module,
    new_module: Module,
    args: tuple,
) -> None:
    """Helper function to change an existing node to a module.

    The new module is registered in the base_module as a submodule.
    The attribute name is based on name{int}.
    The attributes of the node are changed so that they point to the new module.

    Args:
        node: existing node
        name: proposed name; the actual name is name{int}
        base_module: the module that should get new_module as a child
        new_module: the new module to register on the node and base_module
        args: arguments of the new node
    """
    new_name = _get_free_name(base_module, name)
    node.op = "call_module"
    node.target = new_name
    node.args = args
    setattr(base_module, new_name, new_module)
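A rough usage sketch (not from the original source; the traced module and the attribute name 'relu0' below are assumptions) showing the same rewrite performed by hand on a torch.fx graph:

import torch
from torch import nn
from torch.fx import symbolic_trace


class Net(nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)


traced = symbolic_trace(Net())
for node in traced.graph.nodes:
    if node.op == 'call_function' and node.target is torch.relu:
        # Roughly what _change_node_to_module(node, 'relu', traced, nn.ReLU(), node.args)
        # would do, assuming 'relu0' is a free attribute name on `traced`:
        node.op = 'call_module'
        node.target = 'relu0'
        setattr(traced, 'relu0', nn.ReLU())

traced.graph.lint()
traced.recompile()
print(traced.code)  # the generated forward now dispatches to self.relu0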
Example 2
    def output(self, node: Node, target: Any, name: str):
        # Replace the output with dictionaries holding either node type-wise
        # or edge type-wise data.
        def _recurse(value: Any) -> Any:
            if isinstance(value, Node):
                return {
                    key: self.find_by_name(f'{value.name}__{key2str(key)}')
                    for key in self.metadata[int(self.is_edge_level(value))]
                }
            elif isinstance(value, dict):
                return {k: _recurse(v) for k, v in value.items()}
            elif isinstance(value, list):
                return [_recurse(v) for v in value]
            elif isinstance(value, tuple):
                return tuple(_recurse(v) for v in value)
            else:
                return value

        if node.type is not None and isinstance(node.args[0], Node):
            output = node.args[0]
            Type = EdgeType if self.is_edge_level(output) else NodeType
            node.type = Dict[Type, node.type]
        else:
            node.type = None

        node.args = (_recurse(node.args[0]), )
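For context, this hook is part of what makes a to_hetero-transformed model return a dictionary keyed by node type. A minimal usage sketch (the toy model, metadata, and tensor shapes below are made up for illustration):

import torch
from torch_geometric.nn import SAGEConv, to_hetero


class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = SAGEConv((-1, -1), 32)

    def forward(self, x, edge_index):
        return self.conv(x, edge_index)


metadata = (['paper', 'author'],
            [('author', 'writes', 'paper'), ('paper', 'rev_writes', 'author')])
model = to_hetero(GNN(), metadata)

x_dict = {'paper': torch.randn(4, 16), 'author': torch.randn(3, 8)}
edge_index_dict = {
    ('author', 'writes', 'paper'): torch.tensor([[0, 1], [2, 3]]),
    ('paper', 'rev_writes', 'author'): torch.tensor([[2, 3], [0, 1]]),
}
out_dict = model(x_dict, edge_index_dict)
# out_dict is {'paper': Tensor[4, 32], 'author': Tensor[3, 32]}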
Example 3
    def run_node(self, node: Node):
        self.node_counter += 1
        result = super().run_node(node)
        node.meta['fake_result'] = result
        node.meta['node_idx'] = self.node_counter

        # (1) Update metadata with the list of nodes that are used by this node
        # copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
        # We don't want to treat it as "being used as an input".
        node_args = node.args
        if node.target is torch.ops.aten.copy_.default:
            node_args = node_args[1:]

        # (2) Update metadata to track aliasing information about view tensor nodes.
        if node.op == 'call_function':
            view_type = _get_view_type(node.target)
            if view_type == _ViewType.SingleOutputView:
                assert isinstance(node.args[0], Node)
                node.meta['view_of'] = node.args[0]
            elif view_type == _ViewType.MultiOutputView:
                self.multi_output_view_nodes[node] = node.args[0]

            # Check if we returned a multi-output view,
            # and we're now grabbing the individual views from the output.
            #
            # For multi-output views, we want to map each output view to the base,
            # but this mapping involves two separate nodes in FX IR.
            # e.g. "a, b = x_1.split(...)" becomes:
            #    %split_tensor : [#users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
            #    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
            #    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
            # And we'd like to set:
            #    getitem.meta['view_of'] = x_1
            #    getitem_1.meta['view_of'] = x_1
            elif node.target is _operator.getitem:
                list_arg = node.args[0]
                maybe_base_of_view = self.multi_output_view_nodes.get(
                    list_arg, None)
                if maybe_base_of_view is not None:
                    # Note: we could also track indexing info here for multi-output views.
                    # I don't think this metadata is strictly needed for de-functionalization.
                    assert isinstance(maybe_base_of_view, Node)
                    node.meta['view_of'] = maybe_base_of_view

        if 'view_of' in node.meta:
            # We're linking the current node with its first argument as views.
            # Assert here that this is actually the case, and their storages are the same.
            assert isinstance(node.meta['fake_result'], FakeTensor)
            assert isinstance(node.meta['view_of'].meta['fake_result'],
                              FakeTensor)
            view_storage = StorageWeakRef(node.meta['fake_result'].storage())
            base_storage = StorageWeakRef(
                node.meta['view_of'].meta['fake_result'].storage())
            assert view_storage == base_storage
        return result
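The asserts at the end encode the invariant that a view node and its recorded base share the same storage. A quick standalone check of that invariant using ordinary tensors (the pass itself operates on FakeTensors, and untyped_storage() assumes a reasonably recent PyTorch):

import torch
from torch.multiprocessing.reductions import StorageWeakRef

base = torch.randn(4, 4)
view = base.view(16)   # single-output view op
a, b = base.split(2)   # multi-output view op: each chunk aliases `base`

# All four tensors alias a single underlying storage, so their StorageWeakRefs match.
refs = {StorageWeakRef(t.untyped_storage()) for t in (base, view, a, b)}
assert len(refs) == 1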
    def output(self, node: Node, target: Any, name: str):
        # Split the output into dictionaries holding either node type-wise
        # or edge type-wise data.
        def _recurse(value: Any) -> Any:
            if isinstance(value, Node) and self.is_edge_level(value):
                self.graph.inserting_before(node)
                return self.graph.create_node(
                    'call_function',
                    target=split_output,
                    args=(value, self.find_by_name('edge_offset_dict')),
                    name=f'{value.name}__split')

            elif isinstance(value, Node):
                self.graph.inserting_before(node)
                return self.graph.create_node(
                    'call_function',
                    target=split_output,
                    args=(value, self.find_by_name('node_offset_dict')),
                    name=f'{value.name}__split')

            elif isinstance(value, dict):
                return {k: _recurse(v) for k, v in value.items()}
            elif isinstance(value, list):
                return [_recurse(v) for v in value]
            elif isinstance(value, tuple):
                return tuple(_recurse(v) for v in value)
            else:
                return value

        if node.type is not None and isinstance(node.args[0], Node):
            output = node.args[0]
            Type = EdgeType if self.is_edge_level(output) else NodeType
            node.type = Dict[Type, node.type]
        else:
            node.type = None

        node.args = (_recurse(node.args[0]), )
    def placeholder(self, node: Node, target: Any, name: str):
        # Adds a `get` call to the input dictionary for every node-type or
        # edge-type.
        if node.type is not None:
            Type = EdgeType if self.is_edge_level(node) else NodeType
            node.type = Dict[Type, node.type]

        self.graph.inserting_after(node)
        for key in self.metadata[int(self.is_edge_level(node))]:
            out = self.graph.create_node('call_method',
                                         target='get',
                                         args=(node, key),
                                         name=f'{name}__{key2str(key)}')
            self.graph.inserting_after(out)
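Illustratively, for node types 'paper' and 'author' and a placeholder named x, the generated forward would start with per-type `get` calls roughly like the following (hypothetical names; the actual names come from key2str):

def forward(self, x: dict):
    x__paper = x.get('paper')
    x__author = x.get('author')
    ...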
    def call_message_passing_module(self, node: Node, target: Any, name: str):
        # Call the `HeteroBasisConv` wrapper instead of a single message
        # passing layer. We need to inject the `edge_type` as the first
        # argument in order to do so.
        node.args = (self.find_by_name('edge_type'), ) + node.args
    def placeholder(self, node: Node, target: Any, name: str):
        if node.type is not None:
            Type = EdgeType if self.is_edge_level(node) else NodeType
            node.type = Dict[Type, node.type]

        out = node

        # Create `node_offset_dict` and `edge_offset_dict` dictionaries in case
        # they are not yet initialized. These dictionaries hold the cumulated
        # sizes used to create a unified graph representation and to split the
        # output data.
        if self.is_edge_level(node) and not self._edge_offset_dict_initialized:
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_function',
                                         target=get_edge_offset_dict,
                                         args=(node, self.edge_type2id),
                                         name='edge_offset_dict')
            self._edge_offset_dict_initialized = True

        elif not self._node_offset_dict_initialized:
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_function',
                                         target=get_node_offset_dict,
                                         args=(node, self.node_type2id),
                                         name='node_offset_dict')
            self._node_offset_dict_initialized = True

        # Create an `edge_type` tensor used as input to `HeteroBasisConv`:
        if self.is_edge_level(node) and not self._edge_type_initialized:
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_function',
                                         target=get_edge_type,
                                         args=(node, self.edge_type2id),
                                         name='edge_type')
            self._edge_type_initialized = True

        # Add `Linear` operation to align features to the same dimensionality:
        if name in self.in_channels:
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_module',
                                         target=f'align_lin__{name}',
                                         args=(node, ),
                                         name=f'{name}__aligned')
            self._state[out.name] = self._state[name]

            lin = LinearAlign(self.metadata[int(self.is_edge_level(node))],
                              self.in_channels[name])
            setattr(self.module, f'align_lin__{name}', lin)

        # Perform grouping of type-wise values into a single tensor:
        if self.is_edge_level(node):
            self.graph.inserting_after(out)
            out = self.graph.create_node(
                'call_function',
                target=group_edge_placeholder,
                args=(out if name in self.in_channels else node,
                      self.edge_type2id,
                      self.find_by_name('node_offset_dict')),
                name=f'{name}__grouped')
            self._state[out.name] = 'edge'

        else:
            self.graph.inserting_after(out)
            out = self.graph.create_node(
                'call_function',
                target=group_node_placeholder,
                args=(out if name in self.in_channels else node,
                      self.node_type2id),
                name=f'{name}__grouped')
            self._state[out.name] = 'node'

        self.replace_all_uses_with(node, out)
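A self-contained sketch of the offset/grouping idea (get_node_offset_dict, group_node_placeholder, and split_output referenced above are PyG internals; the code below is only a conceptual illustration, not their implementation):

import torch

x_dict = {'paper': torch.randn(4, 16), 'author': torch.randn(3, 16)}

# Cumulated sizes per node type, analogous to `node_offset_dict`:
offset, node_offset_dict = 0, {}
for node_type, x in x_dict.items():
    node_offset_dict[node_type] = offset
    offset += x.size(0)
# -> {'paper': 0, 'author': 4}

# Group type-wise features into a single tensor, analogous to `group_node_placeholder`:
x = torch.cat(list(x_dict.values()), dim=0)  # shape [7, 16]

# Split the unified output back into a per-type dictionary, analogous to `split_output`:
sizes = [v.size(0) for v in x_dict.values()]
out_dict = dict(zip(x_dict.keys(), x.split(sizes, dim=0)))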
Example 8
    def run_node(self, n: Node):
        # Record the (fake) result produced for each node in its metadata.
        result = super().run_node(n)
        n.meta['fake_result'] = result
        return result
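This hook presumably lives on a torch.fx.Interpreter subclass. A minimal sketch of driving such a pass under FakeTensorMode (the class name FakeResultProp and the traced function are made up; the FakeTensorMode import path is a torch internal and may differ across versions):

import torch
from torch.fx import Interpreter, symbolic_trace
from torch._subclasses.fake_tensor import FakeTensorMode


class FakeResultProp(Interpreter):
    def run_node(self, n):
        result = super().run_node(n)
        n.meta['fake_result'] = result
        return result


def fn(x):
    return torch.relu(x) + 1


gm = symbolic_trace(fn)
inp = torch.randn(2, 3)
with FakeTensorMode() as mode:
    FakeResultProp(gm).run(mode.from_tensor(inp))

for node in gm.graph.nodes:
    print(node.name, type(node.meta.get('fake_result')))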