예제 #1
0
def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
    dtype: torch.dtype = torch.quint8,
) -> Node:
    """
    Inserts a ``torch.quantize_per_tensor`` call on ``prev_node_c`` into graph_c.

    ``scale`` and ``zero_point`` are stored as new attributes on ``gm_b``
    (named with ``node_a.name`` as prefix) and read back into graph_c through
    ``get_attr`` nodes, which are then passed to the quantize call.

    Args:
        prev_node_c: node in graph_c whose output gets quantized.
        node_a: node from graph A; only its name is used, as the prefix of the
            new scale / zero_point attribute names.
        gm_b: module that receives the scale / zero_point attributes.
        graph_c: graph in which the new nodes are created.
        scale: quantization scale.
        zero_point: quantization zero point.
        dtype_cast_name: name given to the created quantize_per_tensor node.
        dtype: quantized dtype to produce. Defaults to ``torch.quint8``, the
            value this function previously hard-coded.

    Returns:
        The newly created ``quantize_per_tensor`` node.
    """
    # copy scale
    scale_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_scale_')(gm_b)
    setattr(gm_b, scale_node_name, scale)
    scale_node = graph_c.create_node('get_attr', scale_node_name, (), {},
                                     scale_node_name)
    # copy zero_point
    zero_point_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_zero_point_')(gm_b)
    setattr(gm_b, zero_point_node_name, zero_point)
    zero_point_node = graph_c.create_node('get_attr', zero_point_node_name, (),
                                          {}, zero_point_node_name)
    # create the quantize_per_tensor call
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, dtype), {},
        dtype_cast_name)
예제 #2
0
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    If ``prev_node_c`` is a list, one cast node is created per element and
    the list of cast nodes is returned.

    Raises:
        AssertionError: if the (node_a, node_c) input type pair is not
            FP32 <- INT8, or if ``prev_node_c`` is neither a Node nor a list.
    """
    node_input_type_a = get_node_input_type(node_a, gm_a)
    node_input_type_c = get_node_input_type(node_c, gm_b)

    # Only FP32 (node_a) <- INT8 (node_c) is supported, which is a dequantize.
    if not (node_input_type_a == NodeInputType.FP32
            and node_input_type_c == NodeInputType.INT8):
        raise AssertionError(
            f"dtype cast from {node_input_type_c} to {node_input_type_a} needs to be implemented"
        )
    dtype_cast_op = torch.dequantize

    def _create_cast_node(input_node: Node) -> Node:
        # Inserts one dtype cast node consuming ``input_node`` into graph_c.
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        return graph_c.create_node('call_function', dtype_cast_op,
                                   (input_node, ), {}, new_dtype_cast_name)

    if isinstance(prev_node_c, Node):
        return _create_cast_node(prev_node_c)
    elif isinstance(prev_node_c, list):
        return [_create_cast_node(inner) for inner in prev_node_c]
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
예제 #3
0
def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    results_type: str,
    index_within_arg: int,
    index_of_arg: int,
) -> Node:
    """
    Creates a ``logger_cls`` instance that consumes the output of ``node``.

    Conceptually, starting from

    prev_node -> node -> next_node

    this produces

    prev_node -> node -> logger_obj -> next_node

    Steps:
      1. pick a fresh attribute name on ``gm`` from ``node.name`` plus
         ``logger_node_name_suffix``;
      2. build the logger from the reference/model names, the node's target
         type string, ``results_type`` and the two index arguments;
      3. register the logger on ``gm`` under that name;
      4. add a ``call_module`` node taking ``node`` as its only input to
         ``node``'s graph.

    Only the logger node is created here; rewiring ``next_node`` onto it is
    presumably the caller's job (NOTE(review): confirm at call sites).

    Returns:
        The newly created ``call_module`` logger node.
    """
    attr_name = get_new_attr_name_with_prefix(
        node.name + logger_node_name_suffix)(gm)
    target_type_str = get_target_type_str(node, gm)
    logger = logger_cls(
        ref_node_name,
        node.name,
        model_name,
        ref_name,
        target_type_str,
        results_type,
        index_within_arg,
        index_of_arg,
    )
    setattr(gm, attr_name, logger)
    return node.graph.create_node('call_module', attr_name, (node, ), {})
예제 #4
0
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Copies ``node_a`` (from graph A) into graph_c.

    Supported cases:
      * ``get_attr``: the referenced object is fetched from ``gm_a``
        (detached first if it is a tensor). It is stored on ``gm_b`` under a
        new ``<node_a.name>_shadow_copy_`` attribute and read back with a new
        ``get_attr`` node.
      * ``call_method`` with target ``dequantize`` or ``to``: the first
        argument is copied recursively, then the method call is recreated in
        graph_c. For ``to``, only the second positional argument is carried
        over alongside the copied input.

    Returns:
        The copy of ``node_a`` in graph_c.

    Raises:
        AssertionError: for any other op, or any other call_method target.
    """
    shadow_prefix = node_a.name + '_shadow_copy_'

    if node_a.op == 'get_attr':
        attr_copy_name = get_new_attr_name_with_prefix(shadow_prefix)(gm_b)
        attr_obj = getattr_from_fqn(gm_a,
                                    node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(attr_obj):
            attr_obj = attr_obj.detach()
        setattr(gm_b, attr_copy_name, attr_obj)
        return graph_c.create_node(node_a.op, attr_copy_name, (), {},
                                   attr_copy_name)

    if node_a.op == 'call_method':
        assert node_a.target in ('dequantize', 'to'), \
            f"target {node_a.target} is not implemented"
        # The input is copied first so its nodes precede the method call.
        input_copy = _copy_node_from_a_to_c(
            node_a.args[0], gm_a, gm_b, graph_c)  # type: ignore[arg-type]
        method_copy_name = get_new_attr_name_with_prefix(shadow_prefix)(gm_b)
        if node_a.target == 'dequantize':
            method_args = (input_copy, )
        else:
            # 'to': keep its second positional argument as-is.
            method_args = (input_copy, node_a.args[1])
        return graph_c.create_node(node_a.op, node_a.target, method_args, {},
                                   method_copy_name)

    raise AssertionError(
        f"handling of node with op {node_a.op} is not implemented")
예제 #5
0
def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    results_type: str,
    index_within_arg: int,
) -> Node:
    """
    Creates a ``logger_cls`` instance that consumes the output of ``node``.

    Conceptually, starting from

    prev_node -> node -> next_node

    this produces

    prev_node -> node -> logger_obj -> next_node

    The logger is registered on ``gm`` under a fresh attribute name derived
    from ``node.name`` and ``logger_node_name_suffix``. A ``call_module`` node
    with ``node`` as its only input is then created in ``node``'s graph.

    The target type string passed to the logger is:
      * ``str(node.target)`` for ``call_function`` nodes,
      * ``str(type(module))`` for ``call_module`` nodes,
      * ``''`` for any other op.

    Returns:
        The newly created ``call_module`` logger node.
    """
    attr_name = get_new_attr_name_with_prefix(
        node.name + logger_node_name_suffix)(gm)

    if node.op == 'call_function':
        target_type_str = str(node.target)
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        target_type_str = str(type(getattr_from_fqn(gm, node.target)))
    else:
        target_type_str = ''

    logger = logger_cls(
        ref_node_name,
        node.name,
        model_name,
        ref_name,
        target_type_str,
        results_type,
        index_within_arg,
    )
    setattr(gm, attr_name, logger)
    return node.graph.create_node('call_module', attr_name, (node,), {})
예제 #6
0
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Inserts a dtype cast consuming ``prev_node_c`` so that its value matches
    the dtype expected by ``node_a``.

    Starting from graph C (derived from graph B):

    ... -> prev_node_c -> node_c -> ...

    the result is:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    Only the case where node_a is FP32 and node_c is INT8 is supported; the
    inserted cast is then ``torch.dequantize``. The cast node is created in
    ``prev_node_c``'s graph.

    Raises:
        AssertionError: for any other combination of IO types.
    """
    io_type_a = get_node_io_type(node_a, gm_a)
    io_type_c = get_node_io_type(node_c, gm_b)

    fp32_from_int8 = (io_type_a == NodeIOType.FP32
                      and io_type_c == NodeIOType.INT8)
    if not fp32_from_int8:
        raise AssertionError(
            f"dtype cast from {io_type_c} to {io_type_a} needs to be implemented"
        )

    cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
    return prev_node_c.graph.create_node(
        'call_function', torch.dequantize, (prev_node_c, ), {}, cast_name)
예제 #7
0
def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
      args (input, (input2)?, arg1, ...), and
      kwargs {kw0: kwarg0, ...}

    Note: input2 is optional. If it is None, we assume that the op
    has a single non-param input. If it is specified (anything other than
    None), we assume that the op has two non-param inputs.

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg is an observer we navigate up until we find a non-observer parent.
    Non-Node args that are None, int, float, or a list/tuple without Nodes
    are passed through unchanged, as are all non-Node kwargs.

    If node_a is a call_module, points the module pointed to by node_a to gm_b.

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input -------------> node_a
                         / / /
    (input_2)?----------/ / /
                         / /
    weight -> weight_obs  /
                         /
    bias ----------------

    graph C (derived from B):
    =========================

    input_node_c --> node_a_copy
                     / / /
    (input_node_c_2)? / /
                     / /
    weight_copy ----/ /
                     /
    bias_copy ------/

    Raises:
        AssertionError: for an unsupported arg type, a Node nested inside a
            list/tuple arg, or an op other than call_module / call_function /
            call_method.
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    # Use one "is None" test both for skipping the non-param args of node_a
    # and for building the input args of the copy. Mixing "is None" with
    # truthiness made them disagree for a falsy input_node_c_2 (e.g. []).
    has_second_input = input_node_c_2 is not None

    # generically handle all args and kwargs except for the input
    # Note: this hasn't been tested with many ops, logic may change.
    new_args: List[Any] = []
    num_non_param_args = 2 if has_second_input else 1
    for node_a_arg in node_a.args[num_non_param_args:]:
        if isinstance(node_a_arg, Node):
            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
            node_a_arg_copy = _copy_node_from_a_to_c(arg_a, gm_a, gm_b,
                                                     graph_c)
            new_args.append(node_a_arg_copy)
        elif node_a_arg is None or isinstance(node_a_arg, (int, float)):
            # Plain values, including an absent optional arg such as a None
            # bias, are passed through verbatim like non-Node kwargs are.
            new_args.append(node_a_arg)
        elif isinstance(node_a_arg, (list, tuple)):
            for el in node_a_arg:
                assert not isinstance(el, Node), \
                    "handling of Node inside list is not implemented"
            new_args.append(node_a_arg)
        else:
            raise AssertionError(
                f"handling for arg of type {type(node_a_arg)} is not implemented"
            )

    new_kwargs: Dict[str, Any] = {}
    for node_a_k, node_a_kwarg in node_a.kwargs.items():
        if isinstance(node_a_kwarg, Node):
            kwarg_a = return_first_non_observer_node(node_a_kwarg, gm_a)
            node_a_kwarg_copy = _copy_node_from_a_to_c(kwarg_a, gm_a, gm_b,
                                                       graph_c)
            new_kwargs[node_a_k] = node_a_kwarg_copy
        else:
            new_kwargs[node_a_k] = node_a_kwarg

    node_a_shadows_c_name = \
        get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if has_second_input:
        input_node_c_args = [input_node_c, input_node_c_2]
    else:
        input_node_c_args = [input_node_c]

    target: Any
    if node_a.op == 'call_module':
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        target = new_mod_copy_name
    else:
        assert node_a.op in ('call_function', 'call_method')
        target = node_a.target

    return graph_c.create_node(node_a.op, target,
                               (*input_node_c_args, *new_args),
                               new_kwargs,
                               node_a_shadows_c_name)
예제 #8
0
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    Cast chosen from the first input types of node_a / node_c:
      * FP32 <- INT8, FP16 or FP32_OR_INT8: ``torch.dequantize``
      * identical, known types: a ``torch.nn.Identity`` module
      * INT8 <- FP32: ``torch.quantize_per_tensor`` with node_a's input
        qparams

    If ``prev_node_c`` is a list, one cast node is created per element and
    the list of cast nodes is returned.

    Raises:
        AssertionError: for an unsupported type pair, when the INT8 <- FP32
            case has no input qparams for node_a, or when ``prev_node_c`` is
            neither a Node nor a list.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(
            node_a, gm_a, logger_cls, node_type_to_io_type_map)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(
            node_c, gm_b, logger_cls, node_type_to_io_type_map)

    if ((node_input_type_a == NodeInputOrOutputType.FP32
         and node_input_type_c == NodeInputOrOutputType.INT8)
            or (node_input_type_a == NodeInputOrOutputType.FP32
                and node_input_type_c == NodeInputOrOutputType.FP16) or
            # TODO(future PR): determine the actual dtype of node_c,
            # the current code only works because dequantize works with
            # multiple input dtypes.
        (node_input_type_a == NodeInputOrOutputType.FP32
         and node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)):
        dtype_cast_op = torch.dequantize
    elif (node_input_type_a == node_input_type_c
          and node_input_type_a != NodeInputOrOutputType.UNKNOWN):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (node_input_type_a == NodeInputOrOutputType.INT8
          and node_input_type_c == NodeInputOrOutputType.FP32):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map)
        if node_a_input_qparams is None:
            # Previously this fell through to a bare `assert dtype_cast_mod_cls`
            # below; fail here with an explanation instead.
            raise AssertionError(
                f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
                f"{node_input_type_a} {node_a.format_node()} requires input " +
                "qparams for node_a, but none were found"
            )
        dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
        dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
        )

    def _cast_one(input_node: Node) -> Node:
        # Inserts one dtype cast node consuming ``input_node`` into graph_c.
        # Shared by the single-node and the list case so both get identical
        # handling (including quantize_per_tensor qparams).
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                return _insert_quantize_per_tensor_node(
                    input_node, node_a, gm_b, graph_c, dtype_cast_scale,
                    dtype_cast_zero_point, new_dtype_cast_name)
            return graph_c.create_node('call_function', dtype_cast_op,
                                       (input_node, ), {},
                                       new_dtype_cast_name)
        assert dtype_cast_mod_cls
        dtype_cast_mod = dtype_cast_mod_cls()
        setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
        return graph_c.create_node('call_module', new_dtype_cast_name,
                                   (input_node, ), {},
                                   new_dtype_cast_name)

    if isinstance(prev_node_c, Node):
        return _cast_one(prev_node_c)
    elif isinstance(prev_node_c, list):
        return [_cast_one(prev_node_c_inner) for prev_node_c_inner in prev_node_c]
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
예제 #9
0
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.

    Cast chosen from the first input types of node_a / node_c:
      * FP32 <- INT8 or FP16: ``torch.dequantize``
      * FP32 <- FP32, INT8 <- INT8, FP32_OR_INT8 <- FP32_OR_INT8:
        a ``torch.nn.Identity`` module registered on gm_b

    If ``prev_node_c`` is a list, one cast node is created per element and
    the list of cast nodes is returned.

    Raises:
        AssertionError: for an unsupported type pair, or when ``prev_node_c``
            is neither a Node nor a list.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(node_a, gm_a, logger_cls)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(node_c, gm_b, logger_cls)

    if ((node_input_type_a == NodeInputOrOutputType.FP32
         and node_input_type_c == NodeInputOrOutputType.INT8)
            or (node_input_type_a == NodeInputOrOutputType.FP32
                and node_input_type_c == NodeInputOrOutputType.FP16)):
        dtype_cast_op = torch.dequantize
    elif (node_input_type_a == node_input_type_c
          and node_input_type_a in (NodeInputOrOutputType.FP32,
                                    NodeInputOrOutputType.INT8,
                                    NodeInputOrOutputType.FP32_OR_INT8)):
        # Same dtype on both sides: no conversion needed.
        dtype_cast_mod_cls = torch.nn.Identity
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
        )

    def _cast_one(input_node: Node) -> Node:
        # Inserts one dtype cast node consuming ``input_node`` into graph_c.
        # Shared by both the single-node and list cases; previously the list
        # case wired the whole list (not the element) into the Identity node.
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            return graph_c.create_node('call_function', dtype_cast_op,
                                       (input_node, ), {},
                                       new_dtype_cast_name)
        assert dtype_cast_mod_cls
        dtype_cast_mod = dtype_cast_mod_cls()
        setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
        return graph_c.create_node('call_module', new_dtype_cast_name,
                                   (input_node, ), {},
                                   new_dtype_cast_name)

    if isinstance(prev_node_c, Node):
        return _cast_one(prev_node_c)
    elif isinstance(prev_node_c, list):
        return [_cast_one(prev_node_c_inner) for prev_node_c_inner in prev_node_c]
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
예제 #10
0
def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Node,
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
      args (input, arg1, ...), and
      kwargs {kw0: kwarg0, ...}

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg or kwarg is an observer we navigate up until we find a
    non-observer parent. Non-Node kwargs are passed through unchanged.

    If node_a is a call_module, points the module pointed to by node_a to gm_b.

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input -------------> node_a
                         / /
    weight -> weight_obs  /
                         /
    bias ----------------

    graph C (derived from B):
    =========================

    input_node_c --> node_a_copy
                     / /
    weight_copy ----/ /
                     /
    bias_copy ------/

    Raises:
        AssertionError: if a non-input positional arg is not a Node, or if
            node_a.op is not call_module / call_function.
    """
    graph_c = input_node_c.graph

    def _copy_param_node(node_a_param: Node) -> Node:
        # Copies the (detached) object behind a graph A parameter node into
        # gm_b and returns a get_attr node in graph_c that reads the copy.
        # Observers are skipped for args and kwargs alike, as documented.
        param_a = return_first_non_observer_node(node_a_param, gm_a)
        param_copy_name = \
            get_new_attr_name_with_prefix(param_a.name + '_shadow_copy_')(gm_b)  # type: ignore
        param_obj = getattr_from_fqn(gm_a, param_a.target)  # type: ignore
        setattr(gm_b, param_copy_name, param_obj.detach())
        return graph_c.create_node('get_attr', param_copy_name, (), {},
                                   param_copy_name)

    # generically handle all args and kwargs except for the input
    # Note: this hasn't been tested with many ops, logic may change.
    new_args = []
    # assumes that the first arg is the input
    for node_a_arg in node_a.args[1:]:
        if not isinstance(node_a_arg, Node):
            raise AssertionError(
                f"handling for arg of type {type(node_a_arg)} is not implemented"
            )
        new_args.append(_copy_param_node(node_a_arg))

    new_kwargs: Dict[str, Any] = {}
    for node_a_k, node_a_kwarg in node_a.kwargs.items():
        if isinstance(node_a_kwarg, Node):
            new_kwargs[node_a_k] = _copy_param_node(node_a_kwarg)
        else:
            new_kwargs[node_a_k] = node_a_kwarg

    node_a_shadows_c_name = \
        get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if node_a.op == 'call_module':
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        node_a_shadows_c = graph_c.create_node(
            node_a.op, new_mod_copy_name, (input_node_c, *new_args),
            new_kwargs, node_a_shadows_c_name)  # type: ignore
        return node_a_shadows_c
    else:
        assert node_a.op == 'call_function'
        node_a_shadows_c = graph_c.create_node(
            node_a.op, node_a.target, (input_node_c, *new_args), new_kwargs,
            node_a_shadows_c_name)  # type: ignore
        return node_a_shadows_c