コード例 #1
0
    def fuse(self, model, inplace=False):
        """Return a GraphModule whose graph has matched patterns fused.

        Args:
            model: module exposing ``.graph`` and ``named_modules()``.
            inplace: when falsy, ``model`` is deep-copied first and the copy
                is used as the root of the returned GraphModule.

        Returns:
            A ``GraphModule`` built from the (possibly copied) model and the
            newly constructed ``self.fused_graph``.

        Side effects: sets ``self.modules`` and ``self.fused_graph``.
        """
        root = model if inplace else copy.deepcopy(model)
        source_graph = root.graph
        self.modules = dict(root.named_modules())

        # Maps node name -> (root node of the match, handler object).
        matches = self._find_matches(root, source_graph, get_fusion_patterns())

        self.fused_graph = Graph()
        # Original node name -> corresponding node in the fused graph.
        remapped = {}

        def load_arg(a):
            return map_arg(a, lambda n: remapped[n.name])

        for node in source_graph.nodes:
            match_root, handler = matches.get(node.name, (None, None))
            if match_root is None:
                # Not part of any match: carry the node over unchanged.
                remapped[node.name] = self.fused_graph.node_copy(node, load_arg)
            elif match_root is node:
                # Root of a match: let the handler emit the fused replacement.
                remapped[node.name] = handler.fuse(self, load_arg)
            # Matched nodes that are not the root are dropped here.

        self.fused_graph.output(load_arg(source_graph.result))
        return GraphModule(root, self.fused_graph)
コード例 #2
0
ファイル: test_fx.py プロジェクト: snuspl/nimble
 def transform(traced):
     """Copy ``traced.graph`` and append a ``neg`` method call as the output.

     The new call takes the last node of the copied graph as its only
     argument; the result is wrapped in a GraphModule rooted at ``traced``.
     """
     graph = torch._fx.Graph()
     graph.graph_copy(traced.graph)
     last_node = graph.nodes[-1]
     negated = graph.create_node(op='call_method',
                                 target='neg',
                                 args=(last_node, ),
                                 kwargs={})
     graph.output(negated)
     return GraphModule(traced, graph)
コード例 #3
0
ファイル: test_fx.py プロジェクト: snuspl/nimble
    def test_graph_edit_with_proxy(self):
        """Extend a copied graph through a Proxy and check the new output."""
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        module = M()
        traced_graph = symbolic_trace(module).graph
        edited = torch._fx.Graph()
        edited.graph_copy(traced_graph)
        # Proxy objects let us generate more graph code after the copy, for
        # things that do not need to work with modules.
        last = Proxy(edited.nodes[-1])
        doubled = last + last
        edited.output(doubled.node)
        gm = GraphModule(module, edited)
        gm.graph.lint(gm)
        # (3 + 4) + (3 + 4) == 14
        self.assertEqual(gm(3, 4), 14)
コード例 #4
0
ファイル: quantization.py プロジェクト: snuspl/nimble
    def quantize(self):
        """Build ``self.quantized_graph`` from ``self.graph`` and wrap it.

        Reads ``self.graph``, ``self.matches`` (node name -> ``(root_node,
        obj)``), ``self.quants`` and ``self.root``.

        Returns:
            ``GraphModule(self.root, self.quantized_graph)``.
        """
        self.quantized_graph = Graph()

        # Both environments are keyed by the *original* node name.
        # ``env`` holds non-quantized values, ``quant_env`` quantized ones.
        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            # Fetch ``n`` in the requested form, converting lazily (and
            # caching the converted node) when only the other form exists.
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(
                        env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            # Copy ``node`` into the quantized graph as a float node. Inputs
            # that were never emitted (interior nodes of a match are skipped
            # by the main loop below) are copied on demand, so the whole
            # matched subgraph is carried over.
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(
                node, load_or_emit)
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized, just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False))

            elif root_node is node:
                r = obj.quantize(
                    self, node, lambda a: map_arg(
                        a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # the quantizer chose not to quantize the node: take the
                    # entire match and just copy it over
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r
            # matched nodes that are not the root of their match are skipped

        self.quantized_graph.output(
            load_arg(self.graph.result, quantized=False))
        return GraphModule(self.root, self.quantized_graph)
コード例 #5
0
ファイル: test_fx.py プロジェクト: snuspl/nimble
 def test_graph_fns(self):
     """Build a graph with the Graph helper methods and compare to eager."""
     # Graph computing sin(linear(a) + bias).
     graph = Graph()
     x = graph.placeholder('a')
     linear_out = graph.call_module('linear', (x, ))
     bias = graph.get_attr('bias')
     summed = graph.call_method('add', (linear_out, bias))
     result = graph.call_function(torch.sin, (summed, ))
     graph.output(result)

     # Root module supplying the 'linear' submodule and 'bias' attribute.
     root = torch.nn.Module()
     root.linear = torch.nn.Linear(3, 4)
     root.bias = torch.rand(4)
     gm = GraphModule(root, graph)
     gm.graph.lint(gm)

     sample = torch.rand(3)
     actual = gm(sample)
     expected = torch.sin(root.linear(sample) + root.bias)
     self.assertEqual(actual, expected)
コード例 #6
0
ファイル: test_fx.py プロジェクト: snuspl/nimble
    def test_graph_unique_names(self):
        """Node names stay unique after copying and extending a graph."""
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        module = M()
        traced_graph = symbolic_trace(module).graph
        rebuilt = torch._fx.Graph()
        rebuilt.graph_copy(traced_graph)
        # Use a proxy on the last copied node to generate more graph code
        # for things that do not need to work with modules.
        last = Proxy(rebuilt.nodes[-1])
        rebuilt.output((last + last).node)
        gm = GraphModule(module, rebuilt)

        names_so_far: Set[str] = set()
        for graph_node in gm.graph.nodes:
            assert graph_node.name not in names_so_far
            names_so_far.add(graph_node.name)
コード例 #7
0
    def _fold_weight(self, quantized):
        """Replace weight-prepack subgraphs with precomputed attributes.

        For every ``call_function`` node whose target is in
        ``WEIGHT_PREPACK_OPS`` and whose producer chain can be collected, the
        chain is run once, its result is stored as a new attribute on
        ``quantized``, and the chain is replaced by a ``get_attr`` node.

        Args:
            quantized: graph module to fold; new attributes are set on it.

        Returns:
            A ``GraphModule`` rooted at ``quantized`` with the folded graph.
        """
        # prepack node name -> value returned by running its producer graph
        packed_weights = {}
        # name of each node to fold -> the prepack node it belongs to
        folded_nodes = {}
        for candidate in quantized.graph.nodes:
            if not (candidate.op == 'call_function'
                    and candidate.target in WEIGHT_PREPACK_OPS):
                continue
            producers = collect_producer_nodes(candidate)
            if producers is None:
                continue
            for producer in producers:
                folded_nodes[producer.name] = candidate
            prepack_module = graph_module_from_producer_nodes(
                quantized, producers)
            packed_weights[candidate.name] = prepack_module()

        # Rebuild the graph without the folded nodes; each prepack node
        # becomes a get_attr on the stored packed weight.
        folded_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda n: env[n.name])

        get_new_packed_weight_name = get_new_attr_name_with_prefix('_fx_pass_packed_weight_')
        source_graph = quantized.graph
        for node in source_graph.nodes:
            owner = folded_nodes.get(node.name)
            if owner is None:
                # untouched node: copy it over
                env[node.name] = folded_graph.node_copy(node, load_arg)
            elif owner is node:
                # the prepack node itself: store the packed weight on the
                # root and read it back through a get_attr node
                packed = packed_weights[node.name]
                attr_name = get_new_packed_weight_name(quantized)
                setattr(quantized, attr_name, packed)
                env[node.name] = folded_graph.create_node(
                    'get_attr', attr_name, (), {})
            # any other folded node is dropped from the new graph
        folded_graph.output(load_arg(source_graph.result))
        return GraphModule(quantized, folded_graph)
コード例 #8
0
def graph_module_from_producer_nodes(root, producer_nodes):
    r''' Construct a graph module from extracted producer nodes
    from `collect_producer_nodes` function

    The input list is not modified, so the same list can be passed again.

    Args:
      root: the root module for the original graph
      producer_nodes: a list of nodes we use to construct the graph, ordered
        from the consuming node back to its producers
    Return:
      A graph module constructed from the producer nodes
    '''
    assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'
    # The nodes were collected by tracing back from the consumer to the
    # getattr nodes, so copy them in the opposite order. Reverse a copy
    # rather than the caller's list: reversing in place would flip the order
    # again on a second call with the same list.
    ordered_nodes = list(reversed(producer_nodes))
    graph = Graph()
    # original node -> its copy in the new graph
    env = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node])
    for producer_node in ordered_nodes:
        env[producer_node] = graph.node_copy(producer_node, load_arg)
    graph.output(load_arg(ordered_nodes[-1]))
    graph_module = GraphModule(root, graph)
    return graph_module
コード例 #9
0
ファイル: test_fx.py プロジェクト: snuspl/nimble
        def lower_to_elementwise_interpreter(
                orig_mod: torch.nn.Module) -> torch.nn.Module:
            """Lower ``orig_mod`` to a GraphModule that calls the interpreter.

            The module is symbolically traced, its graph is turned into
            ``(instruction_name, input_names, output_name)`` triples for the
            ``_ElementwiseInterpreter`` TorchBind class, and a small wrapper
            graph is built that forwards its inputs to that interpreter.

            Only ``placeholder`` nodes and ``call_function`` nodes targeting
            ``operator.add`` / ``operator.mul`` without kwargs are supported.

            Raises:
                AssertionError: on kwargs or an unsupported call target.
                RuntimeError: on any other opcode.
            """
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            #       interpreter's instruction format ======
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

            target_to_name = {operator.add: "add", operator.mul: "mul"}

            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    # ``target`` is a callable here, so it must be formatted
                    # rather than concatenated onto the message string.
                    assert target in target_to_name, \
                        f"Unsupported call target {target}"
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.Tensor(
                                [arg] if isinstance(arg, numbers.Number
                                                    ) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append(
                        (target_to_name[target], arg_names, out_name))

                else:
                    raise RuntimeError(f'Unsupported opcode {n.op}')

            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter(
            )
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            interpreter.set_output_name(mod.graph.result.name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value

            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch._fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = []
            for name in fn_input_names:
                placeholder_nodes.append(graph.create_node(
                    'placeholder', name))

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            output_node = graph.create_node(op='call_method',
                                            target='__call__',
                                            args=(interpreter_node,
                                                  placeholder_nodes))

            # Register output
            graph.output(output_node)

            graph.lint(wrapper)

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)
コード例 #10
0
    def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False):
        """Convert an observed model into a quantized GraphModule.

        The method runs in three steps:
          1. Build ``self.quantized_graph`` from ``model.graph``: matched
             patterns are converted by their handler object, activation post
             process (observer) calls become quantize ops, and everything
             else is copied with non-quantized inputs.
          2. Copy that graph once more, dropping any remaining activation
             post process ``call_module`` nodes.
          3. Delete the observer attributes and qconfig from the model and
             wrap the final graph in a ``GraphModule``.

        Args:
            model: the observed module; deep-copied unless ``inplace``.
            inplace: when falsy, work on a deep copy of ``model``.
            debug: not used in this method body.
            is_dynamic_quant: when truthy, weight observers are run first and
                outputs of converted nodes are treated as non-quantized.

        Returns:
            The converted ``GraphModule``.
        """
        self.restore_state(model)
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        # run weight observers before inserting quant dequant nodes
        # for dynamic quantization
        if self.is_dynamic_quant:
            self._run_weight_observers(model)

        # move to cpu since we only have quantized cpu kernels
        model.eval().cpu()
        self.modules = dict(model.named_modules())

        matches = self._find_matches(model.graph, self.modules, self.patterns)

        quants = self._find_quants(model.graph, matches)
        self.quantized_graph = Graph()
        # Both environments are keyed by original node name:
        # ``env`` holds float nodes, ``quant_env`` holds quantized nodes.
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            # Return the float version of ``n``, inserting a dequantize call
            # (and caching it) when only the quantized version exists.
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized or non quantized environment, env: ' + str(env) + \
                    ' quant_env:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            # Return the quantized version of ``n``, converting the float
            # node through its quant object when needed.
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            # Return whichever version of ``n`` exists, preferring quantized.
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and the args with corresponding
                indexes will be quantized
              - if quantized is a boolean, then all args will be quantized/not quantized
              - if quantized is None, then we'll load the node as long as it exists

            Output: fn which takes arg_or_args, and loads them from the corresponding
              environment depending on the value of quantized.
            """
            assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(arg_or_args, load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)
            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but quant_env will take
                # precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                # Materialise the results: a bare ``map`` object is a one-shot
                # iterator, so ``all`` would consume part of it and ``any``
                # would then only see the leftover elements.
                quantized = [is_quantized(item) for item in node]
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception("partially quantized inputs in list not handled yet")

        for node in model.graph.nodes:
            root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is node:
                if qconfig is None:
                    result = self.quantized_graph.node_copy(node, load_non_quantized)
                    quantized = False
                else:
                    result = obj.convert(self, node, load_arg)
                    # Need to get correct quantized/non-quantized state for the output of CopyNode
                    if isinstance(obj, CopyNode):
                        assert node.op in [
                            'call_module',
                            'call_function',
                            'call_method'], \
                            'CopyNode of type ' + node.op + ' is not handled'
                        quantized = is_quantized(node.args[0])
                    else:
                        quantized = True

                    # output of dynamic quantization is not quantized
                    if self.is_dynamic_quant:
                        quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                # matched node that is not the root of its match: skip it
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if is_activation_post_process(self.modules[node.target]):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if observer_module.dtype == torch.float16:
                        # activations are not quantized for
                        # fp16 dynamic quantization
                        # copy the activation_post_process node here
                        # since we may need it when we insert prepack
                        # op for weight of linear, this will be removed
                        # later in a separate pass
                        env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
                        continue
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    root_module = self.modules['']
                    quant_env[node.name] = quantize_node(
                        root_module, self.quantized_graph,
                        load_non_quantized(node.args[0]), observer_module)
                    continue
            # dequantize inputs for the node that are not quantized
            env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
        self.quantized_graph.output(map_arg(model.graph.result, load_non_quantized))

        # remove activation post process
        act_post_process_removed_graph = Graph()
        # From here on ``env`` and ``load_arg`` are rebound for the second
        # copy pass; the helpers defined above are no longer used.
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])
        for node in self.quantized_graph.nodes:
            if node.op == 'call_module' and \
               is_activation_post_process(self.modules[node.target]):
                # remove activation post process
                env[node.name] = env[node.args[0].name]
            else:
                env[node.name] = act_post_process_removed_graph.node_copy(node, load_arg)
        act_post_process_removed_graph.output(map_arg(self.quantized_graph.result, load_arg))

        module_dict = dict(model.named_modules())
        to_be_removed = []
        for name, module in model.named_modules():
            if is_activation_post_process(module) and not is_submodule_of_fake_quant(name, module, module_dict):
                to_be_removed.append(name)
        for n in to_be_removed:
            delattr(model, n)
        _remove_qconfig(model)
        model = GraphModule(model, act_post_process_removed_graph)
        return model
コード例 #11
0
    def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
        """Insert observers into ``model``'s graph and return the observed module.

        The method copies ``model.graph`` node by node into a new graph and
        adds ``call_module`` nodes for freshly created observer modules, which
        are attached to ``model`` as attributes.

        Args:
            model: module exposing ``.graph``; deep-copied unless ``inplace``.
            qconfig_dict: quantization configuration; it is passed to
                ``convert_dict_to_ordered_dict`` and to the qconfig helpers.
            inplace: when falsy, work on a deep copy of ``model``.
            is_dynamic_quant: selects the dynamic quantization patterns and
                skips output observers for matched nodes.

        Returns:
            A ``GraphModule`` built from the model and the observed graph.

        Side effects: sets ``self.is_dynamic_quant``, ``self.patterns``,
        ``self.modules`` and ``self.activation_post_process_map``, and calls
        ``self.save_state`` on the returned module.
        """
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        if self.is_dynamic_quant:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            self._qat_swap_modules(model)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        matches = self._find_matches(model.graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not quantized, these
        # have to be quantized, which requires measuring stats,
        # initialize an DefaultQuant object for each
        quants = self._find_quants(model.graph, matches)

        # node name -> observer module inserted for that node
        self.activation_post_process_map = dict()

        # original node name -> node in observed_graph that represents it
        env = {}
        observed_graph = Graph()
        # names of nodes whose output is already observed
        observed_node_names_set = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in model.graph.nodes:
            if node.name in observed_node_names_set:
                continue

            # attribute-name prefix for observers created for this node
            prefix = node.name + '_activation_post_process_'
            root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is None:
                # node is not part of any match: plain copy
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                if qconfig is None:
                    continue

                def insert_observer(node, observer, device):
                    # Attach ``observer`` to the model under a fresh name and
                    # route ``node``'s output through it in observed_graph.
                    # Reads ``prefix`` from the current loop iteration.
                    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                    observer_name = get_new_observer_name(model)
                    setattr(model, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                    observed_node_names_set.add(node.name)
                    if device:
                        getattr(model, observer_name).to(device)

                if isinstance(obj, CustomModuleQuantizeHandler):
                    # swap the custom module for its observed counterpart on
                    # the parent module
                    custom_module = self.modules[node.target]
                    observed_custom_module_class = \
                        get_observed_custom_module_class(type(custom_module))
                    observed_custom_module = \
                        observed_custom_module_class.from_float(custom_module)
                    mark_observed_custom_module(observed_custom_module, type(custom_module))
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name, observed_custom_module)

                # don't need to insert observer for output in dynamic quantization
                if self.is_dynamic_quant:
                    continue

                # inserting observers for output of observed module, or mark the output
                # as observed
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        # True when the node (or every element of the list) is
                        # already observed; falls through to None for any
                        # other argument type
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))
                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                    # propagate the observed property from the first argument
                    if node.args[0].name in observed_node_names_set:
                        observed_node_names_set.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    insert_observer(node, new_observer, device)
            else:
                # matched node that is not the root of its match: plain copy
                env[node.name] = observed_graph.node_copy(node, load_arg)

            # observe nodes listed in ``quants`` that are not observed yet
            if node.name not in observed_node_names_set and node.name in quants:
                get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                observer_name = get_new_observer_name(model)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    new_observer = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    if device:
                        new_observer.to(device)
                    self.activation_post_process_map[node.name] = new_observer
                    setattr(model, observer_name, self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                    observed_node_names_set.add(node.name)
        observed_graph.output(load_arg(model.graph.result))

        model = GraphModule(model, observed_graph)
        self.save_state(model)
        return model