Example #1
    def test_graph_fns(self):
        g = Graph()
        a = g.placeholder('a')
        b = g.call_module('linear', (a,))
        c = g.get_attr('bias')
        d = g.call_method('add', (b, c))
        e = g.call_function(torch.sin, (d,))
        g.output(e)
        mod = torch.nn.Module()
        mod.linear = torch.nn.Linear(3, 4)
        mod.bias = torch.rand(4)
        gm = GraphModule(mod, g)
        gm.graph.lint(gm)
        input = torch.rand(3)
        r = gm(input)
        ref = torch.sin(mod.linear(input) + mod.bias)
        self.assertEqual(r, ref)
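
The assembled GraphModule carries regenerated Python source for its forward; printing it is a quick way to sanity-check a hand-built graph (a one-line sketch using the `gm` from above):

# torch.fx regenerates Python source for the graph; `gm.code` is a str.
print(gm.code)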
Example #2
    def test_graph_unique_names(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use Proxy objects to generate more graph code
        # later, for things that do not need to work with modules
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        seen_names : Set[str] = set()
        for node in gm.graph.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)
Example #3
def _prepare_fx(model,
                qconfig_dict,
                inplace,
                prepare_custom_config_dict=None,
                is_standalone_module=False):
    r""" Internal helper function for prepare_fx
    Args:
      `model`, `qconfig_dict`, `inplace`, `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx`
      `is_standalone_module`: a boolean flag that indicates whether we are
      quantizing a standalone module. A standalone module is a submodule of
      the parent module that is not inlined in the forward graph of the
      parent module; the way we quantize a standalone module is described in
      :func:`~torch.quantization._prepare_standalone_module_fx`
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}

    skipped_module_names = prepare_custom_config_dict.get(
        "non_traceable_module_name", [])
    skipped_module_classes = prepare_custom_config_dict.get(
        "non_traceable_module_class", [])

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    # symbolically trace the model
    if not is_standalone_module:
        # standalone module and custom module config are applied in top level module
        standalone_module_names = prepare_custom_config_dict.get(
            'standalone_module_name', [])
        skipped_module_names += standalone_module_names
        custom_module_config = prepare_custom_config_dict.get(
            'float_to_observed_custom_module_class', {})
        custom_module_classes = list(custom_module_config.keys())
        skipped_module_classes += custom_module_classes
    tracer = CustomTracer(skipped_module_names, skipped_module_classes)
    graph_module = GraphModule(model, tracer.trace(model))
    graph_module = _fuse_fx(graph_module, inplace)
    quantizer = Quantizer()
    return quantizer.prepare(
        graph_module,
        qconfig_dict,
        inplace=True,
        prepare_custom_config_dict=prepare_custom_config_dict,
        is_standalone_module=is_standalone_module)
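
For reference, a hedged sketch of the `prepare_custom_config_dict` shape this helper reads — only the keys accessed above, with illustrative placeholder values; the full schema belongs to the matching PyTorch release:

# Illustrative only: keys read by _prepare_fx above; values are placeholders.
prepare_custom_config_dict = {
    # submodule names/classes that should not be symbolically traced
    "non_traceable_module_name": ["some.submodule"],   # hypothetical name
    "non_traceable_module_class": [],                  # nn.Module subclasses
    # submodules quantized separately as standalone units
    "standalone_module_name": ["another.submodule"],   # hypothetical name
    # float module class -> observed custom module class mapping
    "float_to_observed_custom_module_class": {},
}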
Example #4
    def quantize(self):
        self.quantized_graph = Graph()

        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(
                        env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(
                node, lambda n: load_arg(n, quantized=False))
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized, just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False))

            elif root_node is node:
                r = obj.quantize(
                    self, node, lambda a: map_arg(
                        a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # the quantizer chose not to quantize this node; take the entire match and just copy it over
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r

        return GraphModule(self.root, self.quantized_graph)
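
At runtime, the `Proxy(quant_env[n.name]).dequantize()` trick above simply emits a dequantize node into the new graph; in eager terms, the round trip it encodes looks like this (scale and zero_point are arbitrary illustrative values):

# Eager-mode view of what the emitted quantize/dequantize nodes compute.
import torch

x = torch.randn(4)
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)
assert torch.allclose(q.dequantize(), x, atol=0.1)  # lossy round trip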
Example #5
    def _fold_weight(self, quantized):
        packed_weights = dict()
        # map from folded node name to the prepacked weight name
        folded_nodes = dict()
        # get packed weights
        for node in quantized.graph.nodes:
            if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
                nodes_to_fold = collect_producer_nodes(node)
                if nodes_to_fold is not None:
                    for node_to_fold in nodes_to_fold:
                        folded_nodes[node_to_fold.name] = node

                    prepacking_module = graph_module_from_producer_nodes(
                        quantized, nodes_to_fold)
                    packed_weight = prepacking_module()
                    packed_weights[node.name] = packed_weight

        # remove folded nodes and replace the prepacking node with getattr
        folded_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        get_new_packed_weight_name = get_new_attr_name_with_prefix(
            '_fx_pass_packed_weight_')
        quantized_root = quantized
        quantized_graph = quantized.graph
        for node in quantized_graph.nodes:
            prepack_node = folded_nodes.get(node.name, None)
            if prepack_node is node:
                packed_weight = packed_weights[node.name]
                # add a prepacked attribute to root
                packed_weight_name = get_new_packed_weight_name(quantized_root)
                setattr(quantized_root, packed_weight_name, packed_weight)
                # replace prepack node with a getattr node
                env[node.name] = folded_graph.create_node(
                    'get_attr', packed_weight_name, (), {})
            elif prepack_node is not None:
                # remove the folded node
                continue
            else:
                # copy other nodes
                env[node.name] = folded_graph.node_copy(node, load_arg)
        quantized = GraphModule(quantized_root, folded_graph)
        return quantized
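
The fold-into-`get_attr` move above, in isolation: the toy sketch below (not the quantization pass itself; the `_folded_w_t` attribute name is made up) precomputes a constant subexpression, stores the result on the root, and rewires the graph to read it back.

# Toy sketch: precompute a constant subgraph, attach the result to the
# root module, and point a get_attr node at it.
import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.rand(4, 4))

    def forward(self, x):
        return x @ self.w.t()  # `self.w.t()` is constant w.r.t. the input

gm = symbolic_trace(M())
gm.register_buffer('_folded_w_t', gm.w.t().detach())  # hypothetical name
for node in list(gm.graph.nodes):
    if node.op == 'call_method' and node.target == 't':
        with gm.graph.inserting_before(node):
            folded = gm.graph.get_attr('_folded_w_t')
        node.replace_all_uses_with(folded)
        gm.graph.erase_node(node)
gm.graph.lint()
gm.recompile()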
Example #6
def _transform_remove_duplicates(module: GraphModule, debug: bool) -> GraphModule:
    """Removes duplicate modules by creating a copy of the module.

    This is necessary because BackPACK saves input/output which is overwritten
    if the module is called multiple times.

    Args:
        module: container module to transform
        debug: whether to print debug messages

    Returns:
        equivalent transformed module

    Raises:
        NotImplementedError: if a duplicate module has parameters
    """
    if debug:
        print("\tBegin transformation: remove duplicates")

    graph: Graph = BackpackTracer().trace(module)

    targets = [n.target for n in graph.nodes]
    duplicates = {t for t in targets if targets.count(t) > 1}
    nodes = [n for n in graph.nodes if n.target in duplicates]

    for node in nodes:
        target = node.target
        original_module = module.get_submodule(target)

        for _ in original_module.parameters():
            raise NotImplementedError(
                f"Cycle with parameters detected: module {original_module} with target"
                f" {target} has parameters and is used {targets.count(target)} times."
            )

        new_module = deepcopy(original_module)
        new_target = _get_free_name(module, target)
        module.add_submodule(new_target, new_module)
        node.target = new_target

    graph.lint()

    if debug:
        print(f"\tDuplicates removed: {len(nodes)}")

    return GraphModule(module, graph)
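
The duplicate detection above in its smallest form: calling one submodule twice yields two call_module nodes that share a target (plain torch.fx, no BackPACK required):

# Two calls to the same submodule -> two nodes with the same target.
import torch
from torch.fx import symbolic_trace

class Reuse(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.act = torch.nn.ReLU()

    def forward(self, x):
        return self.act(self.act(x))

targets = [n.target for n in symbolic_trace(Reuse()).graph.nodes]
assert targets.count('act') == 2  # 'act' would be flagged as a duplicate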
Example #7
    def test_type_check_conv2D_maxpool2d_flatten(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self):
                super(BasicBlock, self).__init__()

                self.conv1 = torch.nn.Conv2d(3, 6, 5)
                self.pool = torch.nn.MaxPool2d(2, 2)
                self.conv2 = torch.nn.Conv2d(6, 16, 5)
                self.fc1 = torch.nn.Linear(5, 120)
                self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7))

            def forward(self, x: TensorType((4, 3, 32, 32))):
                out = self.conv1(x)
                out = self.pool(out)
                out = self.conv2(out)
                out = self.pool(out)
                out = self.fc1(out)
                out = self.pool2(out)
                out = torch.flatten(out, 1)
                return out

        B = BasicBlock()
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        tc = GraphTypeChecker({}, traced)
        tc.type_check()

        expected_ph_types = [
            TensorType((4, 3, 32, 32)),
            TensorType((4, 6, 28, 28)),
            TensorType((4, 6, 14, 14)),
            TensorType((4, 16, 10, 10)),
            TensorType((4, 16, 5, 5)),
            TensorType((4, 16, 5, 120)),
            TensorType((4, 16, 6, 7)),
            TensorType((4, 672)),
            TensorType((4, 672))
        ]

        expected_iter = iter(expected_ph_types)
        traced.graph.eliminate_dead_code()

        for n in traced.graph.nodes:
            assert n.type == next(expected_iter)
Example #8
def remove_observers_add_loggers(
    gm: GraphModule,
    node_to_instrument_to_ref_node_name: Dict[Node, Optional[str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, removes all observers, and adds loggers to the
    output of each node in node_to_instrument_to_ref_node_name. Returns a
    GraphModule with the new graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if node.op == 'call_module' and is_activation_post_process(
                modules[node.target]):
            # remove activation post process node
            env[node.name] = env[node.args[0].name]

        elif node in node_to_instrument_to_ref_node_name:
            other_node_name = node_to_instrument_to_ref_node_name[node]
            # ensure env is populated with base node
            env[node.name] = new_graph.node_copy(node, load_arg)
            # add the logger after the base node
            env[node.name] = _insert_logger_after_node(env[node.name], gm,
                                                       logger_cls,
                                                       '_ns_logger_',
                                                       model_name,
                                                       other_node_name)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
Example #9
def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
                prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
                is_standalone_module: bool = False) -> GraphModule:
    r""" Internal helper function for prepare_fx
    Args:
      `model`, `qconfig_dict`, `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx`
      `is_standalone_module`: a boolean flag that indicates whether we are
      quantizing a standalone module. A standalone module is a submodule of
      the parent module that is not inlined in the forward graph of the
      parent module; the way we quantize a standalone module is described in
      :func:`~torch.quantization._prepare_standalone_module_fx`
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}

    skipped_module_names = prepare_custom_config_dict.get("non_traceable_module_name", [])
    skipped_module_classes = prepare_custom_config_dict.get("non_traceable_module_class", [])

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    # symbolically trace the model
    if not is_standalone_module:
        # standalone module and custom module config are applied in top level module
        standalone_module_name_configs = prepare_custom_config_dict.get("standalone_module_name", [])
        skipped_module_names += [config[0] for config in standalone_module_name_configs]

        standalone_module_class_configs = prepare_custom_config_dict.get("standalone_module_class", [])
        skipped_module_classes += [config[0] for config in standalone_module_class_configs]
        float_custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict, "float_to_observed_custom_module_class")
        skipped_module_classes += float_custom_module_classes
    tracer = CustomTracer(skipped_module_names, skipped_module_classes)
    graph_module = GraphModule(model, tracer.trace(model))
    graph_module = _fuse_fx(graph_module, prepare_custom_config_dict)
    quantizer = Quantizer()
    return quantizer.prepare(
        graph_module,
        qconfig_dict,
        prepare_custom_config_dict=prepare_custom_config_dict,
        is_standalone_module=is_standalone_module)
Example #10
    def test_type_check_maxpool2d_3dinput(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self):
                super(BasicBlock, self).__init__()
                self.pool = torch.nn.MaxPool2d(5, 8)

            def forward(self, x: TensorType((64, 8, 8))):
                out = self.pool(x)
                return out

        B = BasicBlock()
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        tc = GraphTypeChecker({}, traced)
        tc.type_check()

        for n in traced.graph.nodes:
            if n.target == 'output':
                assert n.type == TensorType((64, 1, 1))
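
The asserted output type follows from the pooling shape rule: floor((8 - 5) / 8) + 1 = 1 in each spatial dimension. A one-line eager check:

# Eager sanity check of the (64, 1, 1) output type asserted above.
import torch

out = torch.nn.MaxPool2d(5, 8)(torch.rand(64, 8, 8))
assert out.shape == (64, 1, 1)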
Example #11
    def test_type_check_batch_norm_2D_false(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self, inplanes, planes):
                super(BasicBlock, self).__init__()
                norm_layer = torch.nn.BatchNorm2d
                self.bn1 = norm_layer(planes)

            def forward(self, x: TensorType((2, 2, 5))):
                identity = x
                out: TensorType((2, 2, Dyn, 4)) = self.bn1(x)
                out += identity
                return out

        B = BasicBlock(2, 2)
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        tc = GraphTypeChecker({}, traced)
        with self.assertRaises(TypeError):
            tc.type_check()
Example #12
def _prepare_fx(model,
                qconfig_dict,
                inplace,
                prepare_custom_config_dict=None,
                is_standalone_module=False):
    r""" Internal helper function for prepare_fx
    Args:
      `model`, `qconfig_dict`, `inplace`, `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx`
      `is_standalone_module`: a boolean flag that indicates whether we are
      quantizing a standalone module. A standalone module is a submodule of
      the parent module that is not inlined in the forward graph of the
      parent module; the way we quantize a standalone module is described in
      :func:`~torch.quantization._prepare_standalone_module_fx`
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}

    # symbolically trace the model
    if is_standalone_module:
        # standalone module is traced before quantizing standalone modules
        graph_module = symbolic_trace(model)
    else:
        standalone_modules = prepare_custom_config_dict.get(
            'standalone_module_name', [])
        custom_module_config = qconfig_dict.get('custom_module_class', [])
        custom_module_classes = [config[0] for config in custom_module_config]
        # TODO: currently we are registering classes globally,
        # we want to make custom module class mapping local
        _register_custom_module_class(custom_module_config)
        # skipping tracing standalone modules when tracing top level module
        tracer = CustomTracer(standalone_modules, custom_module_classes)
        graph_module = GraphModule(model, tracer.trace(model))
    graph_module = _fuse_fx(graph_module, inplace)
    quantizer = Quantizer()
    return quantizer.prepare(
        graph_module,
        qconfig_dict,
        inplace=True,
        prepare_custom_config_dict=prepare_custom_config_dict,
        is_standalone_module=is_standalone_module)
Example #13
def _transform_lstm_rnn(module: Module, debug: bool) -> GraphModule:
    """Transforms multi-layer RNN/LSTM to Sequential of single-layer RNN/LSTM.

    Converts multi-layer RNN/LSTM to Sequential with single-layer RNN/LSTM.
    If dropout probability is nonzero, creates intermediate dropout layers.
    Finally, copies training mode.

    Args:
        module: container module to transform
        debug: whether to print debug messages

    Returns:
        equivalent transformed module

    Raises:
        NotImplementedError: if initial hidden state is used in forward pass
    """
    if debug:
        print("\tBegin transformation: LSTM, RNN")
    graph: Graph = BackpackTracer().trace(module)

    nodes = [
        n
        for n in graph.nodes
        if n.op == "call_module"
        and isinstance(module.get_submodule(n.target), (RNN, LSTM))
        and module.get_submodule(n.target).num_layers > 1
    ]
    for node in nodes:
        if len(node.args) > 1:
            raise NotImplementedError(
                "For conversion, LSTM/RNN input must not have hidden states."
            )
        lstm_module_replace = _make_rnn_backpack(module.get_submodule(node.target))
        module.add_module(node.target, lstm_module_replace)

    graph.lint()
    if debug:
        print(f"\tRNNs, LSTMs transformed: {len(nodes)}")
    return GraphModule(module, graph)
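
A hedged sketch of the described split; `split_lstm` and the output-only wrapper are illustrative stand-ins for BackPACK's `_make_rnn_backpack`, and weight copying plus the intermediate Dropout layers are omitted for brevity:

# Illustrative only: turn an n-layer LSTM into a Sequential of
# single-layer LSTMs. Real code must also copy the weights over and
# insert Dropout modules when lstm.dropout > 0.
import torch
from torch.nn import LSTM, Module, Sequential

class OutputOnly(Module):
    """Wraps an LSTM so Sequential can chain it (drops the hidden state)."""
    def __init__(self, lstm: LSTM):
        super().__init__()
        self.lstm = lstm

    def forward(self, x):
        out, _ = self.lstm(x)  # zero initial hidden state, as required above
        return out

def split_lstm(lstm: LSTM) -> Sequential:
    layers = []
    for i in range(lstm.num_layers):
        single = LSTM(
            input_size=lstm.input_size if i == 0 else lstm.hidden_size,
            hidden_size=lstm.hidden_size,
            num_layers=1,
            batch_first=lstm.batch_first,
        )
        layers.append(OutputOnly(single))
    return Sequential(*layers)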
Example #14
def graph_module_from_producer_nodes(root, producer_nodes):
    r''' Construct a graph module from the producer nodes extracted by the
    `collect_producer_nodes` function
    Args:
      root: the root module for the original graph
      producer_nodes: a list of nodes we use to construct the graph
    Return:
      A graph module constructed from the producer nodes
    '''
    assert len(producer_nodes) > 0, 'list of producer nodes cannot be empty'
    # since we traced back from the node to its getattr inputs
    producer_nodes.reverse()
    graph = Graph()
    env = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])
    for producer_node in producer_nodes:
        env[producer_node.name] = graph.node_copy(producer_node, load_arg)
    graph.output(load_arg(producer_nodes[-1]))
    graph_module = GraphModule(root, graph)
    return graph_module
Example #15
def _transform_transpose_to_module(module: Module, debug: bool) -> GraphModule:
    """Transforms transpose function or method to Permute module.

    The Permute module is initialized with transpose parameters and computes
    the permutation on its first forward pass.

    Args:
        module: container module to transform
        debug: whether to print debug messages

    Returns:
        equivalent transformed module
    """
    target_function = "<built-in method transpose"
    target_method = "transpose"
    if debug:
        print(f"\tBegin transformation: {target_method} -> Permute")
    graph: Graph = BackpackTracer().trace(module)

    nodes = [
        n
        for n in graph.nodes
        if (n.op == "call_function" and target_function in str(n.target))
        or (n.op == "call_method" and target_method == str(n.target))
    ]

    for node in nodes:
        _change_node_to_module(
            node,
            "permute",
            module,
            Permute(*node.args[1:], init_transpose=True),
            (node.args[0],),
        )

    graph.lint()
    if debug:
        print(f"\tPermute transformed: {len(nodes)}")
    return GraphModule(module, graph)
Example #16
def _transform_flatten_to_module(module: Module, debug: bool) -> GraphModule:
    """Transforms PyTorch's flatten method to the nn.Flatten module.

    Args:
        module: container module to transform
        debug: whether to print debug messages

    Returns:
        equivalent transformed module
    """
    target_function = "<built-in method flatten"
    target_method = "flatten"
    if debug:
        print(f"\tBegin transformation: {target_function} -> Flatten")

    graph: Graph = BackpackTracer().trace(module)
    nodes = [
        n
        for n in graph.nodes
        if (n.op == "call_function" and target_function in str(n.target))
        or (n.op == "call_method" and target_method == str(n.target))
    ]

    for node in nodes:
        start_dim = node.args[1] if len(node.args) > 1 else 0
        end_dim = node.args[2] if len(node.args) > 2 else -1
        _change_node_to_module(
            node, "flatten", module, Flatten(start_dim, end_dim), (node.args[0],)
        )

    graph.lint()

    if debug:
        print(f"\tFlatten functions transformed: {len(nodes)}")

    return GraphModule(module, graph)
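
One subtlety worth noting: `start_dim` defaults to 0 here because the traced call is `torch.flatten`, whose default differs from `nn.Flatten` (which starts at dim 1). The equivalence being relied on:

# torch.flatten defaults to start_dim=0, nn.Flatten to start_dim=1,
# hence the explicit defaults in the transform above.
import torch

x = torch.rand(2, 3, 4)
assert torch.equal(torch.flatten(x), torch.nn.Flatten(0, -1)(x))
assert torch.equal(torch.flatten(x, 1), torch.nn.Flatten()(x))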
Example #17
def _transform_permute_to_module(module: Module, debug: bool) -> GraphModule:
    """Transforms permute function or method to Permute module.

    Args:
        module: container module to transform
        debug: whether to print debug messages

    Returns:
        equivalent transformed module
    """
    target1 = "permute"
    target2 = "<built-in method permute"
    if debug:
        print(f"\tBegin transformation: {target1}|{target2} -> Permute")
    graph: Graph = BackpackTracer().trace(module)

    nodes = [
        n
        for n in graph.nodes
        if (n.op == "call_function" and target2 in str(n.target))
        or (n.op == "call_method" and target1 == str(n.target))
    ]

    for node in nodes:
        _change_node_to_module(
            node,
            "permute",
            module,
            Permute(*node.args[1]) if len(node.args) == 2 else Permute(*node.args[1:]),
            (node.args[0],),
        )

    graph.lint()
    if debug:
        print(f"\tPermute transformed: {len(nodes)}")
    return GraphModule(module, graph)
Example #18
    def test_type_check_conv2D_2_fully_static(self):
        annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                           (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)]
        input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                      (10, 15, 13, 14), (1, 2, 2, 3)]
        intermediate_types = [(1, Dyn, Dyn, 7), (2, Dyn, 4, 6),
                              (10, 15, Dyn, 5), (10, 15, 7, 7),
                              (1, Dyn, Dyn, Dyn)]
        in_planes_list = [2, 5, 15, 15, 2]
        stride_list = [1, 2, 3, 2, 2]
        out_planes_list = [2, 5, 15, 15, 2]
        groups_list = [1, 5, 5, 5, 2]
        dilation_list = [1, 2, 3, 3, 3]
        padding_list = [1, 2, 3, 3, 3]
        kernel_size_list = [1, 2, 3, 3, 3]
        output_types = [(1, 2, Dyn, 7), (2, 5, 4, 6), (10, 15, Dyn, 5),
                        (10, 15, 7, 7), (1, 2, Dyn, Dyn)]

        for i in range(5):
            annotation = annotation_list[i]
            input = input_list[i]
            in_planes = in_planes_list[i]
            stride = stride_list[i]
            out_planes = out_planes_list[i]
            groups = groups_list[i]
            dilation = dilation_list[i]
            padding = padding_list[i]
            kernel_size = kernel_size_list[i]
            intermediate_type = intermediate_types[i]

            class BasicBlock(torch.nn.Module):
                def __init__(self, in_planes, out_planes, kernel_size, stride,
                             padding, groups, dilation):
                    super(BasicBlock, self).__init__()
                    self.conv1 = torch.nn.Conv2d(in_channels=in_planes,
                                                 out_channels=out_planes,
                                                 kernel_size=kernel_size,
                                                 stride=stride,
                                                 padding=padding,
                                                 groups=groups,
                                                 bias=False,
                                                 dilation=dilation)

                def forward(self, x):
                    out = self.conv1(x)
                    return out

            B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding,
                           groups, dilation)
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # annotate our argument
            for n in graph.nodes:
                if n.op == 'placeholder':
                    n.type = TensorType(annotation)

            b = B.forward(torch.rand(input))
            tc = GraphTypeChecker({}, traced)
            tc.type_check()

            for n in graph.nodes:
                if n.op == 'output':
                    assert is_consistent(n.type, TensorType(b.size()))

            # test with intermediate annotations
            class BasicBlock(torch.nn.Module):
                def __init__(self, in_planes, out_planes, kernel_size, stride,
                             padding, groups, dilation):
                    super(BasicBlock, self).__init__()
                    self.conv1 = torch.nn.Conv2d(in_channels=in_planes,
                                                 out_channels=out_planes,
                                                 kernel_size=kernel_size,
                                                 stride=stride,
                                                 padding=padding,
                                                 groups=groups,
                                                 bias=False,
                                                 dilation=dilation)

                def forward(self, x):
                    out = self.conv1(x)
                    return out

            B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding,
                           groups, dilation)
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # populate our intermediate nodes
            for n in traced.graph.nodes:
                if n.op == 'call_module':
                    n.type = TensorType(intermediate_type)

            tc = GraphTypeChecker({}, traced)
            tc.type_check()

            for n in traced.graph.nodes:
                if n.op == 'output':
                    assert n.type == TensorType(output_types[i])
                    assert is_consistent(n.type, TensorType(b.size()))
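
The expected output types above are instances of the standard Conv2d shape rule; a small helper reproduces them:

# Conv2d spatial-size rule behind the expected output types above.
def conv2d_out(in_size, kernel, stride, padding, dilation):
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# First configuration: width 5, kernel 1, stride 1, padding 1 -> width 7.
assert conv2d_out(5, kernel=1, stride=1, padding=1, dilation=1) == 7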
Example #19
    def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False):
        self.restore_state(model)
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        # run weight observers before inserting quant dequant nodes
        # for dynamic quantization
        if self.is_dynamic_quant:
            self._run_weight_observers(model)

        # move to cpu since we only have quantized cpu kernels
        model.eval().cpu()
        self.modules = dict(model.named_modules())

        matches = self._find_matches(model.graph, self.modules, self.patterns)
        quants = self._find_quants(model.graph, matches)
        self.quantized_graph = Graph()
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized or non quantized environment, env: ' + str(env) + \
                    ' quant_env:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and the args with corresponding
                indexes will be quantized
              - if quantized is a boolean, then all args will be quantized/not quantized
              - if quantized is None, then we'll load the node as long as it exists

            Output: fn which takes arg_or_args, and loads them from the corresponding
              environment depending on the value of quantized.
            """
            assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(arg_or_args, load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)
            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but quant_env will take
                # precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                quantized = map(is_quantized, node)
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception("partially quantized inputs in list not handled yet")

        for node in model.graph.nodes:
            root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is node:
                if qconfig is None:
                    result = self.quantized_graph.node_copy(node, load_non_quantized)
                    quantized = False
                else:
                    result = obj.convert(self, node, load_arg)
                    # Need to get correct quantized/non-quantized state for the output of CopyNode
                    if isinstance(obj, CopyNode):
                        assert node.op in [
                            'call_module',
                            'call_function',
                            'call_method'], \
                            'CopyNode of type ' + node.op + ' is not handled'
                        quantized = is_quantized(node.args[0])
                    else:
                        quantized = True

                    # output of dynamic quantization is not quantized
                    if self.is_dynamic_quant:
                        quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if node.target.split('.')[-1].startswith('activation_post_process_'):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if observer_module.dtype == torch.float16:
                        # activations are not quantized for
                        # fp16 dynamic quantization
                        # copy the activation_post_process node here
                        # since we may need it when we insert prepack
                        # op for weight of linear, this will be removed
                        # later in a separate pass
                        env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
                        continue
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    root_module = self.modules['']
                    quant_env[node.name] = quantize_node(
                        root_module, self.quantized_graph,
                        load_non_quantized(node.args[0]), observer_module)
                    continue
            # copy all other nodes, dequantizing any inputs that are quantized
            env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
        self.quantized_graph.output(map_arg(model.graph.result, load_non_quantized))

        # remove activation post process
        act_post_process_removed_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])
        for node in self.quantized_graph.nodes:
            if node.op == 'call_module' and \
               node.target.split('.')[-1].startswith('activation_post_process_'):
                # remove activation post process
                env[node.name] = env[node.args[0].name]
            else:
                env[node.name] = act_post_process_removed_graph.node_copy(node, load_arg)
        act_post_process_removed_graph.output(map_arg(self.quantized_graph.result, load_arg))

        to_be_removed = []
        for name, _ in model.named_modules():
            if name.split('.')[-1].startswith('activation_post_process_'):
                to_be_removed.append(name)
        for n in to_be_removed:
            delattr(model, n)
        _remove_qconfig(model)
        model = GraphModule(model, act_post_process_removed_graph)
        return model
Example #20
    def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        # TODO: allow user specified patterns
        if self.is_dynamic_quant:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        propagate_qconfig_(model, qconfig_dict)
        if model.training:
            self._qat_swap_modules(model)

        self.modules = dict(model.named_modules())

        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph)

        # match the patterns that will get quantized
        matches = self._find_matches(model.graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuant object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()

        env = {}
        observed_graph = Graph()
        observed_node_names_set = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in model.graph.nodes:
            if node.name in observed_node_names_set:
                continue

            get_new_observer_name = get_new_attr_name_with_prefix('activation_post_process_')
            root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)

                def insert_observer(node, observer, device):
                    observer_name = get_new_observer_name(model)
                    setattr(model, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                    observed_node_names_set.add(node.name)
                    if device:
                        getattr(model, observer_name).to(device)

                # don't need to insert observer for output in dynamic quantization
                if self.is_dynamic_quant:
                    continue

                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))
                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed_node_names_set:
                        observed_node_names_set.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    insert_observer(node, new_observer, device)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed_node_names_set and node.name in quants:
                observer_name = get_new_observer_name(model)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    new_observer = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    if device:
                        new_observer.to(device)
                    self.activation_post_process_map[node.name] = new_observer
                    setattr(model, observer_name, self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                    observed_node_names_set.add(node.name)
        observed_graph.output(load_arg(model.graph.result))

        model = GraphModule(model, observed_graph)
        self.save_state(model)
        return model
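
Standalone view of what an inserted `activation_post_process_` module is: an observer that records statistics on forward (returning its input unchanged) and later yields quantization parameters. Shown here with `MinMaxObserver`, one common choice:

# What an inserted observer does at runtime, outside the graph pass.
import torch
from torch.quantization import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8)
y = obs(torch.randn(8, 8))            # records min/max, passes data through
scale, zero_point = obs.calculate_qparams()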
Example #21
    def _prepare(self, model, qconfig_dict, inplace, quant_type):
        input_root = model.root
        if not inplace:
            input_root = copy.deepcopy(input_root)

        input_graph = model.graph
        self.quant_type = quant_type
        # TODO: allow user specified patterns
        if self.quant_type == QuantType.DYNAMIC:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        propagate_qconfig_(input_root, qconfig_dict)
        if input_root.training:
            self._qat_swap_modules(input_root)

        self.modules = dict(input_root.named_modules())

        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(input_root, input_graph)

        # match the patterns that will get quantized
        matches = self._find_matches(input_graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuant object for each
        quants = self._find_quants(input_graph, matches)

        self.activation_post_process_map = dict()

        env = {}
        observed_graph = Graph()
        observed = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            if node.name in observed:
                continue

            def get_new_observer_name(parent_module):
                i = 0

                def get_observer_name(i):
                    return 'activation_post_process_' + str(i)

                observer_name = get_observer_name(i)
                while hasattr(parent_module, observer_name):
                    i += 1
                    observer_name = get_observer_name(i)
                return observer_name

            root_node, _, obj, qconfig = matches.get(node.name,
                                                     (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)

                def insert_observer(node, observer):
                    observer_name = get_new_observer_name(input_root)
                    setattr(input_root, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, [load_arg(node)], {})
                    observed.add(node.name)

                # don't need to insert observer for output in dynamic quantization
                if self.quant_type == QuantType.DYNAMIC:
                    continue

                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))

                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed.add(node.name)
                elif (isinstance(obj, Add)
                      or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed:
                        observed.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    insert_observer(node, qconfig.activation())
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed and node.name in quants:
                observer_name = get_new_observer_name(input_root)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    self.activation_post_process_map[
                        node.name] = qconfig.weight(
                        ) if is_weight else qconfig.activation()
                    setattr(input_root, observer_name,
                            self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, [load_arg(node)], {})
                    observed.add(node.name)
        observed_graph.output(load_arg(input_graph.result))

        return GraphModule(input_root, observed_graph)
Example #22
    def _prepare(self, model: GraphModule, qconfig_dict: Any,
                 prepare_custom_config_dict: Optional[Dict[str, Any]],
                 is_standalone_module: bool) -> GraphModule:
        """ standalone_module means it a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        When we are preparing a standalone module:
        both input and output are observed in prepared standalone module
        Returns:
            model(GraphModule): prepared standalone module
        """
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}
        self.prepare_custom_config_dict = prepare_custom_config_dict

        additional_quant_patterns = \
            prepare_custom_config_dict.get("additional_quant_pattern", {})
        self.patterns = get_combined_dict(
            get_default_quant_patterns(), additional_quant_patterns)

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            additional_qat_module_mapping = prepare_custom_config_dict.get(
                "additional_qat_module_mapping", {})
            self._qat_swap_modules(model, additional_qat_module_mapping)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = prepare_custom_config_dict.get(
            "standalone_module_name", None)
        standalone_module_classes = prepare_custom_config_dict.get(
            "standalone_module_class", None)
        custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict, "float_to_observed_custom_module_class")
        assert self.patterns is not None
        matches = self._find_matches(
            model.graph, self.modules, self.patterns, standalone_module_names,
            standalone_module_classes, custom_module_classes)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuantizeHandler object for each
        quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
            self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        env: Dict[Any, Any] = {}
        observed_graph = Graph()
        observed_node_names_set: Set[str] = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indexes for the inputs that need to be observed
        standalone_module_observed_input_idxs: List[int] = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')

        placeholder_node_seen_cnt = 0
        output_node_seen_cnt = 0
        input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "input_quantized_idxs", [])
        output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "output_quantized_idxs", [])

        result_node : Optional[Node] = None
        for node in model.graph.nodes:
            if node.op == 'output':
                # If this output is hardcoded to be quantized, insert an
                # observer on the previous node if it does not already
                # exist.
                cur_output_node_idx = output_node_seen_cnt
                output_node_seen_cnt += 1
                if cur_output_node_idx in output_quantized_idxs:
                    prev_node = node.args[0]
                    assert isinstance(prev_node, Node), \
                        ('hardcoding list/dict outputs to be quantized is ' +
                         'not supported')
                    if prev_node.name not in observed_node_names_set:
                        assert self.qconfig_map is not None
                        local_qconfig = self.qconfig_map[prev_node.name]
                        assert local_qconfig is not None, \
                            'qconfig of a node before a quantized output must exist'
                        insert_observer(
                            prev_node, local_qconfig.activation(),
                            model, self.activation_post_process_map,
                            env, observed_graph, load_arg, observed_node_names_set)

                observed_graph.output(load_arg(node.args[0]))
                result_node = node
                continue

            if node.name in observed_node_names_set:
                continue

            root_node, matched_nodes, pattern, obj, qconfig = matches.get(
                node.name, (None, None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                # index for input of custom module that needs to be observed in
                # parent
                if qconfig is not None:
                    assert obj is not None
                    insert_observer_for_special_module(
                        obj, self.modules, prepare_custom_config_dict, qconfig,
                        node)
                    insert_observer_for_output_of_the_node(
                        node, obj, qconfig, self.modules, model, pattern,
                        self.activation_post_process_map, env,
                        observed_graph, load_arg, observed_node_names_set,
                        matched_nodes)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.op == 'placeholder':
                # skip adding observers at the graph input if the input is
                # overridden to be quantized
                cur_placeholder_node_idx = placeholder_node_seen_cnt
                placeholder_node_seen_cnt += 1
                if cur_placeholder_node_idx in input_quantized_idxs:
                    observed_node_names_set.add(node.name)
                    continue

            insert_observer_for_input_arg_of_observed_node(
                node, observed_node_names_set, quants,
                model, self.activation_post_process_map, env,
                observed_graph, load_arg)


        model = GraphModule(model, observed_graph)
        self.save_state(model)
        model = mark_observed_module(model)
        return model
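
A hedged sketch of the two index-override keys consumed above, with illustrative values:

# The two override keys read by _prepare above (illustrative values).
prepare_custom_config_dict = {
    # placeholder inputs at these positions arrive already quantized,
    # so no observer is inserted for them
    "input_quantized_idxs": [0],
    # graph outputs at these positions are kept quantized
    "output_quantized_idxs": [0],
}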
Example #23
    def _convert(self,
                 model,
                 inplace=False,
                 debug=False,
                 convert_custom_config_dict=None,
                 is_standalone_module=False):
        """ standalone_module means it a submodule that is not inlined in parent module,
        and will be quantized separately as one unit.
        For standalone module: the inputs will be quantized by parent module,
        checks `_standalone_module_observed_input_idxs` of
        input observed model and will treat these inputs as quantized
        also will not dequantize the final output.
        Returns a quantized standalone module which accepts quantized input(if needed)
        and produces quantized output (if needed).
        """
        if convert_custom_config_dict is None:
            convert_custom_config_dict = {}
        self.restore_state(model)
        if not inplace:
            model = copy.deepcopy(model)
        # always run weight observers in the top level forward method
        # for dynamic quant ops or weight only quant ops
        self._run_weight_observers(model)

        # move to cpu since we only have quantized cpu kernels
        model.eval().cpu()
        self.modules = dict(model.named_modules())

        custom_module_class_mapping = convert_custom_config_dict.get(
            "observed_to_quantized_custom_module_class", None)
        matches = self._find_matches(
            model.graph,
            self.modules,
            self.patterns,
            custom_module_class_mapping=custom_module_class_mapping)

        quants = self._find_quants(model.graph, matches)

        self.quantized_graph = Graph()
        env = {}
        quant_env = {}

        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized or non quantized environment, env: ' + str(env) + \
                    ' quant_env:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and the args with corresponding
                indexes will be quantized
              - if quantized is a boolean, then all args will be quantized/not quantized
              - if quantized is None, then we'll load the node as long as it exists

            Output: fn which takes arg_or_args, and loads them from the corresponding
              environment depending on the value of quantized.
            """
            assert quantized is None or isinstance(
                quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg_or_args,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)

            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but quant_env will take
                # precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                quantized = list(map(is_quantized, node))
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception(
                        "partially quantized inputs in list not handled yet")

        for node in model.graph.nodes:
            if node.op == 'output':
                if is_standalone_module:
                    # results are kept quantized in the quantized standalone module
                    graph_output = map_arg(node.args[0], load_x)
                else:
                    graph_output = map_arg(node.args[0], load_non_quantized)
                self.quantized_graph.output(graph_output)
                continue
            root_node, matched, obj, qconfig = matches.get(
                node.name, (None, None, None, None))
            if root_node is node:
                if qconfig is None:
                    result = self.quantized_graph.node_copy(
                        node, load_non_quantized)
                    quantized = False
                else:
                    result = obj.convert(
                        self,
                        node,
                        load_arg,
                        debug=debug,
                        convert_custom_config_dict=convert_custom_config_dict)
                    if node.op == 'call_module' and is_observed_standalone_module(
                            self.modules[node.target]):
                        quantized = self.modules[
                            node.target]._output_is_observed
                    else:
                        quantized = True

                    # Need to get correct quantized/non-quantized state for the output of CopyNode
                    if isinstance(obj, CopyNode):
                        assert node.op in [
                            'call_module',
                            'call_function',
                            'call_method'], \
                            'CopyNode of type ' + node.op + ' is not handled'
                        quantized = is_quantized(node.args[0])

                    if not activation_is_statically_quantized(qconfig):
                        quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if is_activation_post_process(self.modules[node.target]):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if observer_module.dtype == torch.float16:
                        # activations are not quantized for
                        # fp16 dynamic quantization
                        # copy the activation_post_process node here
                        # since we may need it when we insert prepack
                        # op for weight of linear, this will be removed
                        # later in a separate pass
                        env[node.name] = self.quantized_graph.node_copy(
                            node, load_non_quantized)
                        continue
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    root_module = self.modules['']
                    quant_env[node.name] = quantize_node(
                        root_module, self.quantized_graph,
                        load_non_quantized(node.args[0]), observer_module)
                    continue

            if is_standalone_module and node.op == 'placeholder' and \
               graph_inputs.index(node.name) in model._standalone_module_observed_input_idxs:
                # the node is quantized in parent module
                quant_env[node.name] = self.quantized_graph.node_copy(
                    node, load_non_quantized)
            else:
                # dequantize inputs for the node that are not quantized
                env[node.name] = self.quantized_graph.node_copy(
                    node, load_non_quantized)

        # remove activation post process
        act_post_process_removed_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in self.quantized_graph.nodes:
            if node.op == 'output':
                act_post_process_removed_graph.output(
                    map_arg(node.args[0], load_arg))
                continue
            if node.op == 'call_module' and \
               is_activation_post_process(self.modules[node.target]):
                # remove activation post process node
                env[node.name] = env[node.args[0].name]
            else:
                env[node.name] = act_post_process_removed_graph.node_copy(
                    node, load_arg)

        module_dict = dict(model.named_modules())
        to_be_removed = []
        for name, module in model.named_modules():
            if is_activation_post_process(
                    module) and not is_submodule_of_fake_quant(
                        name, module, module_dict):
                to_be_removed.append(name)
        for n in to_be_removed:
            delattr(model, n)
        _remove_qconfig(model)
        model = GraphModule(model, act_post_process_removed_graph)
        return model
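
For context, a minimal sketch of how a conversion pass like the `_convert`
above is typically driven end to end. The entry points `prepare_fx` and
`convert_fx` and the two-argument `prepare_fx` call match the FX graph mode
quantization API of the PyTorch versions these snippets appear to come from;
treat the exact signatures as assumptions, not guarantees:

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class Small(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

model = Small().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(model, qconfig_dict)   # inserts observers
prepared(torch.randn(2, 4))                  # calibration forward pass
quantized = convert_fx(prepared)             # runs a _convert-style pass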
Example No. 24
    def test_typecheck_basicblock(self):
        class BasicBlock(torch.nn.Module):
            expansion = 1

            def __init__(self,
                         inplanes,
                         planes,
                         stride=1,
                         downsample=None,
                         groups=1,
                         base_width=64,
                         dilation=1,
                         norm_layer=None):
                super(BasicBlock, self).__init__()
                if norm_layer is None:
                    norm_layer = torch.nn.BatchNorm2d
                if groups != 1 or base_width != 64:
                    raise ValueError(
                        'BasicBlock only supports groups=1 and base_width=64')
                if dilation > 1:
                    raise NotImplementedError(
                        "Dilation > 1 not supported in BasicBlock")
                # Both self.conv1 and self.downsample layers downsample the input when stride != 1
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = norm_layer(planes)
                self.relu = torch.nn.ReLU(inplace=True)
                self.conv2 = conv3x3(planes, planes)
                self.bn2 = norm_layer(planes)
                self.downsample = downsample
                self.stride = stride

            def forward(self, x: TensorType((2, 2, 4, 5))):
                identity = x

                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)

                out = self.conv2(out)
                out = self.bn2(out)

                if self.downsample is not None:
                    identity = self.downsample(x)

                out += identity
                out = self.relu(out)

                return out

        B = BasicBlock(2, 2)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        tc = GraphTypeChecker({}, traced)
        tc.type_check()

        for n in traced.graph.nodes:
            if n.target == 'output':
                assert isinstance(n.type, TensorType)
                assert torch.Size(n.type.__args__) == B.forward(
                    torch.rand(2, 2, 4, 5)).size()
Example No. 25
        def lower_to_elementwise_interpreter(
                orig_mod: torch.nn.Module) -> torch.nn.Module:
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            #       interpreter's instruction format ======
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

            target_to_name = {operator.add: "add", operator.mul: "mul"}

            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    assert target in target_to_name, f"Unsupported call target {target}"
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.Tensor(
                                [arg] if isinstance(arg, numbers.Number) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append(
                        (target_to_name[target], arg_names, out_name))

                else:
                    raise RuntimeError('Unsupported opcode: ' + n.op)

            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            interpreter.set_output_name(mod.graph.result.name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value

            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch.fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = []
            for name in fn_input_names:
                placeholder_nodes.append(graph.create_node(
                    'placeholder', name))

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            output_node = graph.create_node(op='call_method',
                                            target='__call__',
                                            args=(interpreter_node,
                                                  placeholder_nodes))

            # Register output
            graph.output(output_node)

            graph.lint(wrapper)

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)
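
The node-walking loop in Stage 2 above generalizes beyond the TorchBind
interpreter; here is a self-contained sketch of extracting the same
(target, input names, output name) triples from any traced module, using
only core torch.fx:

import torch
import torch.fx

def extract_instructions(mod: torch.nn.Module):
    gm = torch.fx.symbolic_trace(mod)
    instructions = []
    for n in gm.graph.nodes:
        if n.op == 'call_function':
            # collect the same kind of triple the interpreter consumes
            instructions.append((n.target, [str(a) for a in n.args], n.name))
    return instructions

class AddMul(torch.nn.Module):
    def forward(self, x, y):
        return (x + y) * y

print(extract_instructions(AddMul()))
# e.g. [(operator.add, ['x', 'y'], 'add'), (operator.mul, ['add', 'y'], 'mul')]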
Example No. 26
def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
    node_type_to_io_type_map: Optional[Dict[str,
                                            Set[NSNodeTargetType]]] = None,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B.  For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    if node_type_to_io_type_map is None:
        node_type_to_io_type_map = get_node_type_to_io_type_map()

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(
                    subgraph_a.start_node, gm_a, logger_cls,
                    node_type_to_io_type_map)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(
                    node_b, gm_b, logger_cls,
                    node_type_to_io_type_map)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN
                and node_output_type_a != NodeInputOrOutputType.UNKNOWN
                and node_input_type_b != NodeInputOrOutputType.UNKNOWN
                and node_output_type_b != NodeInputOrOutputType.UNKNOWN)
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}'
                    +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}'
                    + ', unknown dtype cast')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # If we are shadowing from fp32 to int8, we need to insert
            # quantize_per_tensor call with qparams from the previous node.
            # Only do this if we are able to infer these qparams from the graph.
            if (node_input_type_a == NodeInputOrOutputType.INT8
                    and node_input_type_b == NodeInputOrOutputType.FP32):
                node_a_input_qparams = get_node_input_qparams(
                    subgraph_a.start_node, gm_a, node_type_to_io_type_map)
                if not node_a_input_qparams:
                    print(
                        f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}'
                        +
                        f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}'
                        + ', unknown input qparams')
                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                    continue

            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    if isinstance(node_b.args[0], Node):
                        prev_node_c = env_c[node_b.args[0].name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c,
                            gm_b,
                            logger_cls,
                            '_ns_logger_b_inp_',
                            node_b.name,
                            name_b,
                            ref_name,
                            ref_node_type_b,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0,
                            index_of_arg=0,
                            fqn=fqn_base_b)
                    elif isinstance(node_b.args[0], list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [
                            env_c[arg.name] for arg in node_b.args[0]
                        ]

                        for arg_idx, arg in enumerate(node_b.args[0]):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[
                                prev_node_c.name] = _insert_logger_after_node(
                                    prev_node_c,
                                    gm_b,
                                    logger_cls,
                                    '_ns_logger_b_inp_',
                                    node_b.name,
                                    name_b,
                                    ref_name,
                                    ref_node_type_b,
                                    NSSingleResultValuesType.NODE_INPUT.value,
                                    index_within_arg=arg_idx,
                                    index_of_arg=0,
                                    fqn=fqn_base_b)
                    else:
                        # logging of inputs which are not lists is not supported yet
                        raise AssertionError(
                            f"type {type(node_b.args[0])} is not handled yet")
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_start_node:

                # cast dtype from the dtype of node_c's input to the dtype of
                # node_a's input (dequant, etc)
                prev_node_c = node_c.args[0]
                if should_log_inputs:
                    # skip the input logger when inserting a dtype cast
                    if isinstance(prev_node_c, Node):
                        prev_node_c = prev_node_c.args[0]
                    elif isinstance(prev_node_c, list):
                        prev_node_c = [arg.args[0] for arg in prev_node_c]
                dtype_cast_node = _insert_dtype_cast_after_node(
                    subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b,
                    graph_c, node_b.name + '_dtype_cast_', logger_cls,
                    node_type_to_io_type_map)
                # note: not inserting to env_c because all nodes which use the dtype
                #   casts are copied from graph_a
                #
                # subgraph so far:
                #
                #           (dtype_cast_node)+
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                # if input logging is enabled, log the input to the subgraph
                if should_log_inputs:
                    # TODO: explain this
                    ref_node_name = ''
                    if isinstance(dtype_cast_node, Node):
                        dtype_cast_node = _insert_logger_after_node(
                            dtype_cast_node,
                            gm_b,
                            logger_cls,
                            '_ns_logger_a_inp_',
                            ref_node_name,
                            name_a,
                            ref_name,
                            ref_node_type_a,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0,
                            index_of_arg=0,
                            fqn=fqn_base_a)
                        input_logger: Union[Node, List[Node]] = dtype_cast_node
                    else:
                        assert isinstance(dtype_cast_node, list)
                        new_loggers = []
                        for dtype_cast_idx, dtype_cast_node_inner in enumerate(
                                dtype_cast_node):
                            dtype_cast_logger = _insert_logger_after_node(
                                dtype_cast_node_inner,
                                gm_b,
                                logger_cls,
                                '_ns_logger_a_inp_',
                                ref_node_name,
                                name_a,
                                ref_name,
                                ref_node_type_a,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=dtype_cast_idx,
                                index_of_arg=0,
                                fqn=fqn_base_a)
                            new_loggers.append(dtype_cast_logger)
                        dtype_cast_node = new_loggers
                        input_logger = dtype_cast_node
                    # subgraph so far:
                    #
                    #       (dtype_cast_node)+ -> (logger_a_input)?
                    #                  /
                    # prev_node_c -> (logger_c_input)? -> node_start_c

                # hook up the new mod_a copy to be in the graph, receiving the
                # same inputs as mod_b does, with dtype cast to match a
                # Some ops, such as LSTMs, have two non-param inputs. If we have
                # such an op, pass the second param as well. Note: dtype casting
                # for the second param is not implemented yet, it can be added
                # later if there is a use case.
                node_c_second_non_param_arg = None
                num_non_param_args_node_a = get_number_of_non_param_args(
                    subgraph_a.start_node, gm_a)
                if num_non_param_args_node_a == 2:
                    node_c_second_non_param_arg = node_c.args[1]
                node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                    dtype_cast_node, node_c_second_non_param_arg, subgraph_a,
                    gm_a, gm_b, node_c.name + '_shadow_copy_')
                env_c[node_a_shadows_c.name] = node_a_shadows_c
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if should_log_inputs:
                    # When we created the input logger, we left the ref_node_name
                    # as an empty string, because the subgraph copy did not exist
                    # yet. Now that the subgraph copy exists, we modify this name
                    # to its true value.
                    # Note: the alternative to this is to create the input logger
                    # after creating the subgraph, which is slightly more
                    # complicated. This is the lesser of two evils.
                    # input_logger = env_c[dtype_cast_node.name]
                    # Find the first node in the subgraph
                    cur_node = node_a_shadows_c
                    while cur_node.args[0] != input_logger:
                        cur_node = cur_node.args[0]  # type: ignore[assignment]
                    if isinstance(input_logger, Node):
                        input_logger_mod = getattr(gm_b, input_logger.name)
                        input_logger_mod.ref_node_name = cur_node.name
                    else:
                        assert isinstance(input_logger, list)
                        for input_logger_inner in input_logger:
                            input_logger_mod = getattr(gm_b,
                                                       input_logger_inner.name)
                            input_logger_mod.ref_node_name = cur_node.name

                # hook up a logger to the mod_a copy
                env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                    env_c[node_a_shadows_c.name],
                    gm_b,
                    logger_cls,
                    '_ns_logger_a_',
                    node_a_shadows_c.name,
                    name_a,
                    ref_name,
                    ref_node_type_a,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0,
                    index_of_arg=0,
                    fqn=fqn_base_a)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

            if node_b_is_end_node:

                # hook up a logger to the mod_b copy
                env_c[node_b.name] = _insert_logger_after_node(
                    env_c[node_b.name],
                    gm_b,
                    logger_cls,
                    '_ns_logger_b_',
                    node_b.name,
                    name_b,
                    ref_name,
                    ref_node_type_b,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0,
                    index_of_arg=0,
                    fqn=fqn_base_b)
                # subgraph so far:
                #
                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                #                  /
                # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                #
                # Note: node_start_c may be the same node as node_end_c, or they
                # may have nodes in between.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
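
A hedged sketch of driving create_a_shadows_b through the numeric suite's
public wrappers. The module path and exact signatures of add_shadow_loggers
and extract_shadow_logger_info are assumptions (they have moved between
torch.quantization and torch.ao.ns across releases), and model_fp32,
model_int8, and example_input are hypothetical placeholders:

from torch.ao.ns._numeric_suite_fx import (
    OutputLogger, add_shadow_loggers, extract_shadow_logger_info)

shadowed = add_shadow_loggers('fp32', model_fp32, 'int8', model_int8,
                              OutputLogger)
shadowed(example_input)  # one forward pass populates both sets of loggers
results = extract_shadow_logger_info(shadowed, OutputLogger, 'int8')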
Example No. 27
    def _convert(self, model: GraphModule, debug: bool = False,
                 convert_custom_config_dict: Optional[Dict[str, Any]] = None,
                 is_standalone_module: bool = False) -> GraphModule:
        """ standalone_module means it a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        Returns a quantized standalone module which accepts float input
        and produces float output.
        """
        if convert_custom_config_dict is None:
            convert_custom_config_dict = {}
        self.restore_state(model)
        # always run weight observers in the top level forward method
        # for dynamic quant ops or weight only quant ops
        self._run_weight_observers(model)

        # move to cpu since we only have quantized cpu kernels
        model.eval().cpu()
        self.modules = dict(model.named_modules())

        custom_module_classes = get_custom_module_class_keys(
            convert_custom_config_dict,
            "observed_to_quantized_custom_module_class")
        assert self.patterns is not None
        matches = self._find_matches(
            model.graph, self.modules, self.patterns,
            custom_module_classes=custom_module_classes)

        quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]] = \
            self._find_quants(model.graph, matches)

        self.quantized_graph = Graph()
        env: Dict[str, Node] = {}
        quant_env: Dict[str, Node] = {}

        graph_inputs: List[str] = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        def load_non_quantized(n: Node) -> Node:
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find ' + \
                    'node:' + n.name + \
                    ' in quantized or non quantized environment, env: ' + \
                    str(env) + ' quant_env:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n: Node) -> Node:
            assert n.name in quant_env, \
                'trying to load quantized node but did not find node:' + \
                n.name + ' in quant environment:' + str(quant_env)
            return quant_env[n.name]

        def load_x(n: Node) -> Node:
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized: Optional[Union[List[Any], bool, Tuple[Any, ...]]]
                     ) -> Callable[[Node], Argument]:
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and
                the args with corresponding indexes will be quantized
              - if quantized is a boolean, then all args will be
                quantized/not quantized
              - if quantized is None, then we'll load the node as long as it
                exists

            Output: fn which takes arg_or_args, and loads them from the
                corresponding environment depending on the value of quantized.
            """
            assert quantized is None or \
                isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg_or_args,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)
            return load_arg_impl

        def node_arg_is_quantized(node_arg: Any) -> bool:
            if isinstance(node_arg, Node):
                assert node_arg.name in env or node_arg.name in quant_env, \
                    'Expecting node_arg to be in the environment'
                # there might be nodes appearing in both environments, but
                # quant_env will take precedence
                if node_arg.name in quant_env:
                    return True
                else:
                    return False
            elif isinstance(node_arg, list):
                quantized = list(map(node_arg_is_quantized, node_arg))
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception(
                        "partially quantized inputs in list not handled yet")
            else:
                return False

        def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
            """ Check if output node is quantized or not """
            assert self.modules is not None
            # by default the output is expected to be quantized
            quantized = True

            # Need to get correct quantized/non-quantized state for the output
            # of CopyNode
            if type(obj) in [
                    CopyNode,
                    FixedQParamsOpQuantizeHandler
            ]:
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'
                quantized = node_arg_is_quantized(node.args[0])

            if not activation_is_statically_quantized(qconfig) or \
               not input_output_observed(obj):
                quantized = False

            return quantized

        def insert_quantize_node(node: Node) -> None:
            """ Given a activation_post_process module call node, insert a
            quantize node"""
            assert self.modules is not None
            assert isinstance(node.target, str)
            observer_module = self.modules[node.target]
            prev_node = node.args[0]
            if observer_module.dtype == torch.float16:
                # activations are not quantized for
                # fp16 dynamic quantization
                # copy the activation_post_process node here
                # since we may need it when we insert prepack
                # op for weight of linear, this will be removed
                # later in a separate pass
                env[node.name] = self.quantized_graph.node_copy(
                    node, load_non_quantized)
            elif isinstance(prev_node, Node) and prev_node.name in quant_env:
                # if previous node is already quantized, we'll just remove the
                # activation_post_process
                quant_env[node.name] = quant_env[prev_node.name]
            else:
                # replace activation post process with quantization ops
                root_module = self.modules[""]
                assert isinstance(node.args[0], Node)
                quant_env[node.name] = quantize_node(
                    root_module, self.quantized_graph,
                    load_non_quantized(node.args[0]), observer_module)

        # additional state to override inputs to be quantized, if specified
        # by the user
        placeholder_node_seen_cnt = 0
        output_node_seen_cnt = 0
        input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "input_quantized_idxs", [])
        output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "output_quantized_idxs", [])

        for node in model.graph.nodes:
            if node.op == 'output':
                cur_output_node_idx = output_node_seen_cnt
                output_node_seen_cnt += 1
                if cur_output_node_idx in output_quantized_idxs:
                    # Results are kept quantized if the user specified the
                    # output_quantized_idxs override.
                    graph_output = map_arg(node.args[0], load_x)
                else:
                    graph_output = map_arg(node.args[0], load_non_quantized)
                self.quantized_graph.output(graph_output)
                continue
            root_node, matched, matched_pattern, obj, qconfig = \
                matches.get(node.name, (None, None, None, None, None))
            if root_node is node:
                if qconfig is None:
                    result = self.quantized_graph.node_copy(
                        node, load_non_quantized)
                    quantized = False
                else:
                    assert obj is not None
                    is_standalone_module_node = (
                        node.op == 'call_module' and
                        is_observed_standalone_module(
                            self.modules[node.target])  # type: ignore
                    )
                    result = obj.convert(
                        self, node, load_arg, debug=debug,
                        convert_custom_config_dict=convert_custom_config_dict)
                    if is_standalone_module_node:
                        quantized = False
                    else:
                        quantized = is_output_quantized(node, obj)

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module' and \
                    is_activation_post_process(self.modules[node.target]):
                insert_quantize_node(node)
            elif node.op == 'placeholder':
                cur_placeholder_node_idx = placeholder_node_seen_cnt
                placeholder_node_seen_cnt += 1
                if cur_placeholder_node_idx in input_quantized_idxs:
                    quant_env[node.name] = \
                        self.quantized_graph.node_copy(node, load_non_quantized)
                else:
                    env[node.name] = \
                        self.quantized_graph.node_copy(node, load_non_quantized)
            else:
                # copy quantized or non-quantized node
                env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)

        # remove activation post process
        act_post_process_removed_graph = Graph()
        env = {}

        def load_arg_simple(a: Argument) -> Argument:
            return map_arg(a, lambda node: env[node.name])
        for node in self.quantized_graph.nodes:
            if node.op == 'output':
                act_post_process_removed_graph.output(
                    map_arg(node.args[0], load_arg_simple))
                continue
            if node.op == 'call_module' and \
               is_activation_post_process(self.modules[node.target]):
                # remove activation post process node
                env[node.name] = env[node.args[0].name]
            else:
                env[node.name] = act_post_process_removed_graph.node_copy(
                    node, load_arg_simple)

        # removes qconfig and activation_post_process modules
        _remove_qconfig(model)
        model = GraphModule(model, act_post_process_removed_graph)
        return model
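
The input_quantized_idxs/output_quantized_idxs overrides consumed above come
from the prepare-time custom config. A minimal sketch of requesting quantized
graph inputs and outputs; the dict keys are the ones read in the code above,
while model and qconfig_dict are hypothetical placeholders:

prepare_custom_config_dict = {
    # placeholder 0 arrives already quantized, so skip its quantize op
    "input_quantized_idxs": [0],
    # keep output 0 quantized, so skip the final dequantize
    "output_quantized_idxs": [0],
}
prepared = prepare_fx(model, qconfig_dict,
                      prepare_custom_config_dict=prepare_custom_config_dict)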
Example No. 28
def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, adds loggers to the output
    of each node in nodes_to_instrument. Returns a GraphModule with the new
    graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(node.args[0], load_arg))
            continue

        if ((node in node_to_instrument_inputs_to_ref_node_name)
                or (node in node_to_instrument_outputs_to_ref_node_name)):
            fqn = _maybe_get_fqn(node, gm)

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[
                    node]
                # Ops such as add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = node.args[node_arg_idx]
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node,
                            gm,
                            logger_cls,
                            '_ns_logger_',
                            node.name,
                            model_name,
                            ref_name,
                            ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0,
                            index_of_arg=node_arg_idx,
                            fqn=fqn)
                    elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):
                            prev_node = env[arg.name]
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node,
                                gm,
                                logger_cls,
                                '_ns_logger_',
                                node.name,
                                model_name,
                                ref_name,
                                ref_node_type,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx,
                                index_of_arg=node_arg_idx,
                                fqn=fqn)
                    else:
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[
                    node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name],
                    gm,
                    logger_cls,
                    '_ns_logger_',
                    node.name,
                    model_name,
                    ref_name,
                    ref_node_type,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0,
                    index_of_arg=0,
                    fqn=fqn)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
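
logger_cls above can be any nn.Module that acts as a recording identity. A
minimal hypothetical stand-in follows; the real logger classes also accept
identifying metadata (ref_node_name, model_name, etc.) in their constructor,
as the _insert_logger_after_node calls suggest:

import torch

class MinimalLogger(torch.nn.Module):
    """Pass-through module that records per-call min/max of its input."""

    def __init__(self, *args, **kwargs):
        # the real loggers store the metadata passed here; this sketch
        # simply ignores it
        super().__init__()
        self.stats = []

    def forward(self, x):
        self.stats.append((float(x.min()), float(x.max())))
        return x  # identity, so graph semantics are unchanged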
Example No. 29
    def convert(self, observed, inplace=False, debug=False):
        assert self.activation_post_process_map is not None
        # move to cpu since we only have quantized cpu kernels
        observed.eval().cpu()
        observed_root = observed.root
        observed_graph = observed.graph
        if not inplace:
            observed_root = copy.deepcopy(observed_root)
        self.modules = dict(observed_root.named_modules())

        matches = self._find_matches(observed.graph, self.modules,
                                     self.patterns)
        quants = self._find_quants(observed.graph, matches)
        self.quantized_graph = Graph()
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized environment:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            if quantized is a list, then arg should be a list and the args with corresponding
            indexes will be quantized
            if quantized is a boolean, then all args will be quantized/not quantized
            if quantized is None, then we'll load the node as long as it exists
            """
            assert quantized is None or isinstance(
                quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg):
                if quantized is None:
                    return map_arg(arg, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg, (tuple, list)), arg
                    loaded_arg = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg):
                        if i in quantized:
                            loaded_arg.append(map_arg(a, load_quantized))
                        else:
                            loaded_arg.append(map_arg(a, load_non_quantized))
                    return type(arg)(loaded_arg)

            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but quant_env will take
                # precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                quantized = list(map(is_quantized, node))
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception(
                        "partially quantized inputs in list not handled yet")

        for node in observed_graph.nodes:
            root_node, matched, obj, qconfig = matches.get(
                node.name, (None, None, None, None))
            if root_node is node:
                result = obj.convert(self, node, load_arg)
                quantized = True
                # Need to get correct quantized/non-quantized state for the output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])

                if self.quant_type == QuantType.DYNAMIC:
                    quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if node.target.split('.')[-1].startswith(
                        'activation_post_process_'):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    parent_name = ''

                    scale, zero_point = observer_module.calculate_qparams()
                    # TODO: per channel
                    scale = float(scale)
                    zero_point = int(zero_point)
                    dtype = observer_module.dtype
                    qparams = {
                        '_scale_': scale,
                        '_zero_point_': zero_point,
                        '_dtype_': dtype
                    }
                    i = 0

                    def noattr(module, qparams, i):
                        for name in qparams.keys():
                            if hasattr(module, name + str(i)):
                                return False
                        return True

                    def get_next_i(module, qparams):
                        i = 0
                        while not noattr(module, qparams, i):
                            i += 1
                        return i

                    parent_module = self.modules[parent_name]
                    i = get_next_i(parent_module, qparams)
                    inputs = [load_non_quantized(node.args[0])]
                    for key, value in qparams.items():
                        setattr(parent_module, key + str(i), value)
                        qparam_full_path = key + str(i)
                        if parent_name:
                            qparam_full_path = parent_name + '.' + qparam_full_path
                        inputs.append(
                            self.quantized_graph.create_node(
                                'get_param', qparam_full_path))
                    quant_env[node.name] = self.quantized_graph.create_node(
                        'call_function', torch.quantize_per_tensor, inputs, {})
                    continue
            # dequantize inputs for the node that are not quantized
            env[node.name] = self.quantized_graph.node_copy(
                node, load_non_quantized)

        self.quantized_graph.output(load_non_quantized(observed_graph.result))

        to_be_removed = []
        for name, _ in observed_root.named_modules():
            if name.split('.')[-1].startswith('activation_post_process_'):
                to_be_removed.append(name)
        for n in to_be_removed:
            delattr(observed_root, n)
        return GraphModule(observed_root, self.quantized_graph)
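
At runtime, the quantize node constructed above behaves like the following
eager call; the scale and zero_point values here are hypothetical, whereas
the real ones come from observer_module.calculate_qparams():

import torch

x = torch.randn(4)
xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0,
                               dtype=torch.quint8)
print(xq.dequantize())  # approximately x, up to quantization error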
Example No. 30
def _prepare_fx(model: torch.nn.Module,
                qconfig_dict: Any,
                prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
                equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
                backend_config_dict: Optional[Dict[str, Any]] = None,
                is_standalone_module: bool = False) -> ObservedGraphModule:
    r""" Internal helper function for prepare_fx
    Args:
      `model`, `qconfig_dict`, `prepare_custom_config_dict`, `equalization_qconfig_dict`:
      see docs for :func:`~torch.quantization.prepare_fx`
      `is_standalone_module`: a boolean flag indicating whether we are
      quantizing a standalone module; a standalone module is a submodule
      of the parent module that is not inlined in the forward graph of
      the parent module. The way we quantize a standalone module is
      described in:
      :func:`~torch.quantization._prepare_standalone_module_fx`
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    if equalization_qconfig_dict is None:
        equalization_qconfig_dict = {}

    check_is_valid_qconfig_dict(qconfig_dict)
    check_is_valid_prepare_custom_config_dict(prepare_custom_config_dict)
    check_is_valid_qconfig_dict(equalization_qconfig_dict)

    skipped_module_names = prepare_custom_config_dict.get(
        "non_traceable_module_name", [])
    skipped_module_classes = prepare_custom_config_dict.get(
        "non_traceable_module_class", [])

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    # symbolically trace the model
    if not is_standalone_module:
        # standalone module and custom module config are applied in top level module
        standalone_module_name_configs = prepare_custom_config_dict.get(
            "standalone_module_name", [])
        skipped_module_names += [
            config[0] for config in standalone_module_name_configs
        ]

        standalone_module_class_configs = prepare_custom_config_dict.get(
            "standalone_module_class", [])
        skipped_module_classes += [
            config[0] for config in standalone_module_class_configs
        ]
        float_custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict,
            "float_to_observed_custom_module_class")
        skipped_module_classes += float_custom_module_classes

    preserved_attributes = prepare_custom_config_dict.get(
        "preserved_attributes", [])
    tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
    graph_module = GraphModule(model, tracer.trace(model))
    for attr_name in preserved_attributes:
        setattr(graph_module, attr_name, getattr(model, attr_name))
    graph_module = _fuse_fx(graph_module, prepare_custom_config_dict)
    prepared = prepare(graph_module,
                       qconfig_dict,
                       tracer.node_name_to_scope,
                       prepare_custom_config_dict=prepare_custom_config_dict,
                       equalization_qconfig_dict=equalization_qconfig_dict,
                       is_standalone_module=is_standalone_module)

    for attr_name in preserved_attributes:
        setattr(prepared, attr_name, getattr(model, attr_name))
    return prepared
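
A sketch of the prepare_custom_config_dict keys consumed by _prepare_fx
above; CustomRNN, the module path, and the preserved attribute name are
hypothetical placeholders:

prepare_custom_config_dict = {
    # modules that symbolic tracing should not recurse into
    "non_traceable_module_name": ["submodule.path"],
    "non_traceable_module_class": [CustomRNN],
    # attributes to copy onto the traced GraphModule
    "preserved_attributes": ["version"],
}
prepared = _prepare_fx(model, qconfig_dict,
                       prepare_custom_config_dict=prepare_custom_config_dict)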