Example #1
    def _prepare(self, model, qconfig_dict, inplace,
                 prepare_custom_config_dict, is_standalone_module):
        """ standalone_module means it a submodule that is not inlined in parent module,
        and will be quantized separately as one unit.

        When we are preparing a standalone module:
        the input of the module is observed in the parent module, and the output
        of the module is observed in the standalone module.
        Returns:
            model(GraphModule): prepared standalone module with following attributes:
                _standalone_module_observed_input_idxs(List[Int]): a list of indexes for the graph inputs that
                                         need to be observed in the parent module
                _output_is_observed(Bool): a boolean variable indicating whether the output of the
                                   custom module is observed or not
        """
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}
        if not inplace:
            model = copy.deepcopy(model)
        additional_quant_patterns = prepare_custom_config_dict.get(
            "additional_quant_pattern", {})
        self.patterns = get_default_quant_patterns().copy()
        for k, v in additional_quant_patterns.items():
            self.patterns[k] = v

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            additional_qat_module_mapping = prepare_custom_config_dict.get(
                "additioanl_qat_module_mapping", {})
            self._qat_swap_modules(model, additional_qat_module_mapping)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = prepare_custom_config_dict.get(
            "standalone_module_name", None)
        custom_module_class_mapping = prepare_custom_config_dict.get(
            "float_to_observed_custom_module_class", None)
        matches = self._find_matches(model.graph, self.modules, self.patterns,
                                     standalone_module_names,
                                     custom_module_class_mapping)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuant object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
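        # env maps original node names to their copies in the observed graph
        # built below; observed_node_names_set records which node outputs are
        # already covered by an observer (or otherwise known to be observed)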
        env = {}
        observed_graph = Graph()
        observed_node_names_set = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indexes for the inputs that need to be observed
        standalone_module_observed_input_idxs = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')

        result_node: Optional[Node] = None
        for node in model.graph.nodes:
            if node.op == 'output':
                observed_graph.output(load_arg(node.args[0]))
                result_node = node
                continue
            if node.name in observed_node_names_set:
                continue

            prefix = node.name + '_activation_post_process_'
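            # a matched node maps to (root node of the pattern, list of matched
            # nodes, QuantizeHandler instance, qconfig); nodes that did not
            # match any pattern are simply copied into the observed graph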
            root_node, _, obj, qconfig = matches.get(node.name,
                                                     (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                if qconfig is None:
                    continue

                def insert_observer(node, observer, device):
                    get_new_observer_name = get_new_attr_name_with_prefix(
                        prefix)
                    observer_name = get_new_observer_name(model)
                    setattr(model, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed_node_names_set.add(node.name)
                    if device:
                        getattr(model, observer_name).to(device)

                if isinstance(obj, CustomModuleQuantizeHandler):
                    custom_module = self.modules[node.target]
                    observed_custom_module_class = \
                        custom_module_class_mapping[type(custom_module)]
                    observed_custom_module = \
                        observed_custom_module_class.from_float(custom_module)
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_custom_module)

                # indexes for inputs of the standalone module that need to be
                # observed in the parent
                standalone_module_input_idxs = None
                if isinstance(obj, StandaloneModuleQuantizeHandler):
                    # observe standalone module
                    standalone_module = self.modules[node.target]
                    prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
                    observed_standalone_module = prepare(
                        standalone_module, {'': qconfig})
                    observed_standalone_module.qconfig = qconfig
                    standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
                    observed_standalone_module = mark_observed_standalone_module(
                        observed_standalone_module)
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_standalone_module)
                    self.modules[node.target] = observed_standalone_module

                # don't need to insert observer for output if activation does not
                # need to be statically quantized
                if not activation_is_statically_quantized(qconfig):
                    continue

                # insert observers for the output of the observed module, or
                # mark the output as observed
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))

                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif (isinstance(obj, Add)
                      or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed_node_names_set:
                        observed_node_names_set.add(node.name)
                elif isinstance(obj, StandaloneModuleQuantizeHandler):
                    assert node.op == 'call_module'
                    output_is_observed = self.modules[
                        node.target]._output_is_observed
                    if output_is_observed:
                        observed_node_names_set.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    insert_observer(node, new_observer, device)

                # insert observer for input of standalone module
                if standalone_module_input_idxs is not None:
                    for idx in standalone_module_input_idxs:
                        if node.args[idx].name not in observed_node_names_set:
                            new_observer = qconfig.activation()
                            device = assert_and_get_unique_device(model)
                            insert_observer(node.args[idx], new_observer,
                                            device)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed_node_names_set and node.name in quants:
                if is_standalone_module and node.name in graph_inputs:
                    # we'll insert observer for input of standalone module
                    # in parent graph
                    standalone_module_observed_input_idxs.append(
                        graph_inputs.index(node.name))
                    continue
                get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                observer_name = get_new_observer_name(model)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    # TODO: use insert_observer
                    new_observer = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    if device:
                        new_observer.to(device)
                    self.activation_post_process_map[node.name] = new_observer
                    setattr(model, observer_name,
                            self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed_node_names_set.add(node.name)

        model = GraphModule(model, observed_graph)
        self.save_state(model)
        if is_standalone_module:
            assert result_node is not None
            assert isinstance(result_node.args[0], Node), \
                'standalone module returning dict is not yet supported'
            # indicator for whether the output is observed or not.
            # This is used to correctly quantize standalone modules
            output_is_observed = result_node.args[
                0].name in observed_node_names_set
            model._standalone_module_observed_input_idxs = standalone_module_observed_input_idxs
            model._output_is_observed = output_is_observed
        return model
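A minimal usage sketch for the listing above, assuming the public
torch.quantization.quantize_fx.prepare_fx entry point (which routes into
Quantizer._prepare); the toy model, qconfig choice and calibration input are
illustrative, not part of the listing:

import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

# an empty key applies the qconfig to the whole model
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(Model().eval(), qconfig_dict)
prepared(torch.randn(2, 4))  # calibration pass populates the inserted observers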
Example #2
    def _prepare(self, model, qconfig_dict, prepare_custom_config_dict,
                 is_standalone_module):
        """ standalone_module means it a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        When we are preparing a standalone module:
        the input of the module is observed in the parent module, and the
        output of the module is observed in the standalone module.
        Returns:
            model(GraphModule): prepared standalone module with following
            attributes:
                _standalone_module_observed_input_idxs(List[Int]): a list of
                    indexes for the graph inputs that need to be observed in
                    the parent module
                _output_is_observed(Bool): a boolean variable indicating whether
                    the output of the custom module is observed or not
        """
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}

        additional_quant_patterns = \
            prepare_custom_config_dict.get("additional_quant_pattern", {})
        self.patterns = get_combined_dict(get_default_quant_patterns(),
                                          additional_quant_patterns)

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            additional_qat_module_mapping = prepare_custom_config_dict.get(
                "additional_qat_module_mapping", {})
            self._qat_swap_modules(model, additional_qat_module_mapping)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = prepare_custom_config_dict.get(
            "standalone_module_name", None)
        standalone_module_classes = prepare_custom_config_dict.get(
            "standalone_module_class", None)
        custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict,
            "float_to_observed_custom_module_class")
        matches = self._find_matches(model.graph, self.modules, self.patterns,
                                     standalone_module_names,
                                     standalone_module_classes,
                                     custom_module_classes)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuantizeHandler object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        env: Dict[Any, Any] = {}
        observed_graph = Graph()
        observed_node_names_set = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indexes for the inputs that need to be observed
        standalone_module_observed_input_idxs = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')
        model_device = assert_and_get_unique_device(model)
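        # the nested helpers below close over model, env and observed_graph:
        # insert_observer attaches one observer module and records its call,
        # insert_observer_for_special_module swaps custom/standalone modules
        # for their observed versions, and the remaining helpers decide which
        # outputs and inputs need an observer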

        def insert_observer(node, observer):
            """Insert observer for node by modifying the observed_graph and
               attach observer module to the model
               Args:
                 node: Node
                 observer: observer/fake_quantize module instance
            """
            # respect device affinity when adding observers
            if model_device:
                observer.to(model_device)
            # add observer module as attribute
            prefix = node.name + '_activation_post_process_'
            get_new_observer_name = get_new_attr_name_with_prefix(prefix)
            observer_name = get_new_observer_name(model)
            setattr(model, observer_name, observer)
            # put the observer instance in the activation_post_process map
            assert self.activation_post_process_map is not None
            self.activation_post_process_map[node.name] = observer
            # insert observer call
            env[node.name] = observed_graph.create_node(
                'call_module', observer_name, (load_arg(node), ), {})
            observed_node_names_set.add(node.name)

        def insert_observer_for_special_module(quantize_handler):
            """ Insert observer for custom module and standalone module
              Returns: standalone_module_input_idxs: the indexs for inputs that
              needs to be observed by parent module
            """
            standalone_module_input_idxs = None
            assert self.modules is not None
            if isinstance(quantize_handler, CustomModuleQuantizeHandler):
                custom_module = self.modules[node.target]
                custom_module_class_mapping = prepare_custom_config_dict.get(
                    "float_to_observed_custom_module_class", {})
                observed_custom_module_class = \
                    get_swapped_custom_module_class(
                        custom_module, custom_module_class_mapping, qconfig)
                observed_custom_module = \
                    observed_custom_module_class.from_float(custom_module)
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name,
                        observed_custom_module)
            elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
                # observe standalone module
                standalone_module = self.modules[node.target]
                prepare = \
                    torch.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore
                observed_standalone_module = \
                    prepare(standalone_module, {"": qconfig})
                observed_standalone_module.qconfig = qconfig
                standalone_module_input_idxs = observed_standalone_module.\
                    _standalone_module_observed_input_idxs
                observed_standalone_module = mark_observed_standalone_module(
                    observed_standalone_module)
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name,
                        observed_standalone_module)
                self.modules[node.target] = observed_standalone_module
            return standalone_module_input_idxs

        def insert_observer_for_output_of_the_node(
                node, quantize_handler, qconfig, standalone_module_input_idxs):
            """ Insert observer/fake_quantize module for output of the observed
            module if needed
            """
            # don't need to insert observer for output if activation does not
            # need to be statically quantized
            assert self.modules is not None
            if activation_is_statically_quantized(qconfig):
                if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \
                        and model.training:
                    # we only insert fake quantize module in qat
                    assert pattern is not None
                    activation_post_process_ctr = \
                        get_default_output_activation_post_process_map().get(
                            pattern, None)
                    assert activation_post_process_ctr is not None, \
                        "activation_post_process constructor not provided " + \
                        "for pattern:" + str(pattern)
                    insert_observer(node, activation_post_process_ctr())
                elif (isinstance(quantize_handler,
                                 FixedQParamsOpQuantizeHandler) and
                      not model.training) or \
                        isinstance(quantize_handler, CopyNode):
                    # insert observers for the output of the observed module,
                    # or mark the output as observed
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))

                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif ((isinstance(quantize_handler, Add)
                       or isinstance(quantize_handler, Mul))
                      and quantize_handler.num_node_args == 1):
                    assert matched_nodes is not None
                    # first node in the sequence
                    input_node = matched_nodes[-1]

                    def input_is_observed(arg):
                        return (isinstance(arg, Node)
                                and arg.name in observed_node_names_set)

                    # check whether one of the arguments of add/mul is an
                    # observed node; if both inputs are plain numbers, the
                    # output is not considered observed
                    if (input_is_observed(input_node.args[0])
                            or input_is_observed(input_node.args[1])):
                        observed_node_names_set.add(node.name)
                elif isinstance(quantize_handler,
                                StandaloneModuleQuantizeHandler):
                    assert node.op == 'call_module'
                    output_is_observed = \
                        self.modules[node.target]._output_is_observed
                    if output_is_observed:
                        observed_node_names_set.add(node.name)
                elif (quantize_handler.all_node_args
                      and input_output_observed(quantize_handler)):
                    # observer for outputs
                    new_observer = qconfig.activation()
                    insert_observer(node, new_observer)

            # insert observer for input of standalone module
            if standalone_module_input_idxs is not None:
                for idx in standalone_module_input_idxs:
                    if node.args[idx].name not in observed_node_names_set:
                        new_observer = qconfig.activation()
                        insert_observer(node.args[idx], new_observer)

        def insert_observer_for_input_arg_of_observed_node(arg):
            """
               Input:
                 arg: input arg node for another observed node, e.g.
                 the input activation for a functional linear node
            """
            if node.name not in observed_node_names_set and node.name in quants:
                if is_standalone_module and node.name in graph_inputs:
                    # we'll insert observer for input of standalone module
                    # in parent graph
                    standalone_module_observed_input_idxs.append(
                        graph_inputs.index(node.name))
                    return
                _, activation_post_process_ctr = quants[node.name]
                if activation_post_process_ctr is not None:
                    insert_observer(node, activation_post_process_ctr())

        result_node: Optional[Node] = None
        for node in model.graph.nodes:
            if node.op == 'output':
                observed_graph.output(load_arg(node.args[0]))
                result_node = node
                continue
            if node.name in observed_node_names_set:
                continue

            root_node, matched_nodes, pattern, obj, qconfig = matches.get(
                node.name, (None, None, None, None, None))
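            # three cases: the node is not part of any match (plain copy), the
            # node is the root of a match (copy it, then let its QuantizeHandler
            # drive observer insertion), or the node is interior to a match
            # (plain copy; the root node handles observation)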
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                # indexes for inputs of the standalone module that need to be
                # observed in the parent
                if qconfig is not None:
                    standalone_module_input_idxs = \
                        insert_observer_for_special_module(obj)
                    insert_observer_for_output_of_the_node(
                        node, obj, qconfig, standalone_module_input_idxs)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            insert_observer_for_input_arg_of_observed_node(node)

        model = GraphModule(model, observed_graph)
        self.save_state(model)
        model = mark_observed_module(model)
        if is_standalone_module:
            assert result_node is not None
            assert isinstance(result_node.args[0], Node), \
                'standalone module returning dict is not yet supported'
            # indicator for whether the output is observed or not.
            # This is used to correctly quantize standalone modules
            output_is_observed = \
                result_node.args[0].name in observed_node_names_set
            model._standalone_module_observed_input_idxs = \
                standalone_module_observed_input_idxs
            model._output_is_observed = output_is_observed
        return model
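A minimal usage sketch for this second listing, again assuming the public
prepare_fx entry point; the "standalone_module_name" entry mirrors the key
read above, and the toy modules and calibration input are illustrative:

import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

class Sub(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x)

qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepare_custom_config_dict = {
    # quantize `sub` separately, as one unit: its inputs are observed in the
    # parent graph, its output inside the standalone module itself
    "standalone_module_name": ["sub"],
}
prepared = prepare_fx(Model().eval(), qconfig_dict, prepare_custom_config_dict)
prepared(torch.randn(2, 4))  # calibration pass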