Example #1
def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    # copy scale
    scale_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_scale_')(gm_b)
    setattr(gm_b, scale_node_name, scale)
    scale_node = graph_c.create_node('get_attr', scale_node_name, (), {},
                                     scale_node_name)
    # copy zero_point
    zero_point_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_zero_point_')(gm_b)
    setattr(gm_b, zero_point_node_name, zero_point)
    zero_point_node = graph_c.create_node('get_attr', zero_point_node_name, (),
                                          {}, zero_point_node_name)
    # create the quantize_per_tensor call
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, torch.quint8), {},
        dtype_cast_name)
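
Note: the examples in this collection rely on get_new_attr_name_with_prefix to mint collision-free attribute names on a module. A minimal sketch of the idea (an illustrative reimplementation that mirrors the inline get_new_observer_name in Example #18; the library's version may differ in details such as prefix sanitization):

def get_new_attr_name_with_prefix(prefix):
    # returns a function that, given a module, picks the first unused
    # attribute name of the form prefix + str(i)
    def get_new_attr_name(module):
        i = 0
        attr_name = prefix + str(i)
        while hasattr(module, attr_name):
            i += 1
            attr_name = prefix + str(i)
        return attr_name
    return get_new_attr_name
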
Example #2
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.
    """
    dtype_cast_op = None
    node_input_type_a = get_node_input_type(node_a, gm_a)
    node_input_type_c = get_node_input_type(node_c, gm_b)

    if node_input_type_a == NodeInputType.FP32 and node_input_type_c == NodeInputType.INT8:
        dtype_cast_op = torch.dequantize
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} to {node_input_type_a} needs to be implemented"
        )

    if isinstance(prev_node_c, Node):
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

        return graph_c.create_node('call_function', dtype_cast_op,
                                   (prev_node_c, ), {}, new_dtype_cast_name)
    elif isinstance(prev_node_c, list):
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = \
                get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

            new_dtype_cast_node = graph_c.create_node('call_function',
                                                      dtype_cast_op,
                                                      (prev_node_c_inner, ),
                                                      {}, new_dtype_cast_name)
            results.append(new_dtype_cast_node)
        return results
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
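
For reference, the cast this helper inserts behaves like the following eager-mode round trip (a minimal sketch; the scale and zero_point values are illustrative):

import torch

x = torch.randn(2, 2)
# an int8 quantized tensor, as produced by an int8 op in graph B
xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
# the inserted dtype cast: back to the fp32 domain expected by node_a
x_fp32 = torch.dequantize(xq)
assert x_fp32.dtype == torch.float32
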
Example #3
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Simple copy of node_a to graph_c.
    """
    if node_a.op == 'get_attr':
        node_a_copy_name = \
            get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
        node_a_obj = getattr_from_fqn(gm_a,
                                      node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(node_a_obj):
            node_a_obj = node_a_obj.detach()
        setattr(gm_b, node_a_copy_name, node_a_obj)
        node_a_copy = graph_c.create_node(node_a.op, node_a_copy_name, (), {},
                                          node_a_copy_name)
        return node_a_copy
    elif node_a.op == 'call_method':
        assert node_a.target in ('dequantize', 'to'), \
            f"target {node_a.target} is not implemented"
        if node_a.target == 'dequantize':
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b,
                graph_c)  # type: ignore[arg-type]
            node_a_copy_name = \
                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
            node_a_copy = graph_c.create_node(node_a.op, node_a.target,
                                              (arg_copy, ), {},
                                              node_a_copy_name)
            return node_a_copy
        else:  # to
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b,
                graph_c)  # type: ignore[arg-type]
            node_a_copy_name = \
                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
            node_a_copy = graph_c.create_node(
                node_a.op, node_a.target,
                (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)), {},
                node_a_copy_name)
            return node_a_copy

    else:
        raise AssertionError(
            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented"
        )
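
The get_attr branch above reduces to: read the value from gm_a, detach it if it is a tensor, register it on gm_b under a fresh name, and emit a get_attr node in graph_c. A self-contained sketch of that move using only public torch.fx APIs (the module classes and attribute name are illustrative):

import torch
import torch.fx

class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(3))

    def forward(self, x):
        return x + self.w

gm_a = torch.fx.symbolic_trace(A())
gm_b = torch.fx.symbolic_trace(torch.nn.Identity())
graph_c = torch.fx.Graph()

node_a = next(n for n in gm_a.graph.nodes if n.op == 'get_attr')
value = getattr(gm_a, node_a.target)
if torch.is_tensor(value):
    value = value.detach()
setattr(gm_b, 'w_shadow_copy_0', value)
node_a_copy = graph_c.create_node('get_attr', 'w_shadow_copy_0', (), {},
                                  'w_shadow_copy_0')
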
Example #4
def insert_observer(
        node: Node, observer: torch.quantization.ObserverBase,
        model: torch.nn.Module,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[Any, Any], observed_graph: Graph, load_arg: Callable,
        observed_node_names_set: Set[str]):
    """Insert observer for node by modifying the observed_graph and
       attaching the observer module to the model
       Args:
         node: Node
         observer: observer/fake_quantize module instance
    """
    # respect device affinity when adding observers
    model_device = assert_and_get_unique_device(model)
    if model_device:
        observer.to(model_device)
    # add observer module as attribute
    prefix = node.name + '_activation_post_process_'
    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
    observer_name = get_new_observer_name(model)
    setattr(model, observer_name, observer)
    # put observer instance in the activation_post_process map
    assert activation_post_process_map is not None
    activation_post_process_map[node.name] = observer
    # insert observer call
    env[node.name] = observed_graph.create_node(
        'call_module', observer_name, (load_arg(node),), {})
    observed_node_names_set.add(node.name)
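
The device-affinity check used here (and in the later variants) can be sketched as follows; this is an illustrative reimplementation, not the library's exact code:

import torch

def assert_and_get_unique_device(module: torch.nn.Module):
    # collect every device that parameters/buffers live on; the observer is
    # moved to that device so stats are gathered without device mismatches
    devices = {p.device for p in module.parameters()} | \
              {b.device for b in module.buffers()}
    assert len(devices) <= 1, (
        f"expected at most one device for the module, got {devices}")
    return next(iter(devices)) if len(devices) > 0 else None
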
Example #5
def insert_observer(
    node: Node,
    observed_op: Node,
    observer: ObserverBase,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    input_or_output: str,
) -> Node:
    """
    Attaches `observer` to `model`, and creates a node which calls
    `observer` on the output of `node`.
    """
    model_device = assert_and_get_unique_device(model)
    if model_device:
        observer.to(model_device)
    # add observer module as attribute
    # NOTE: We get the FQN of the module/op being observed here using the node_name_to_scope
    # Please don't change/update this behavior as it might impact how observer stats are transferred
    # from the train model to the inference model for some models.
    obs_name_prefix, _ = node_name_to_scope[observed_op.name]
    obs_name_prefix = node.name if obs_name_prefix == '' else obs_name_prefix
    if is_equalization_observer(observer):
        prefix = node.name + '_equalization_process_'
    else:
        prefix = obs_name_prefix + '_' + input_or_output + '_activation_post_process_'
    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
    observer_name = get_new_observer_name(model)
    setattr(model, observer_name, observer)
    modules[observer_name] = observer
    with graph.inserting_after(node):
        new_obs = graph.create_node(
            'call_module', observer_name, (node,), {})
    return new_obs
Example #6
def insert_observer(
    node: Node,
    observer: torch.quantization.ObserverBase,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attaches `observer` to `model`, and creates a node which calls
    `observer` on the output of `node`.
    """
    model_device = assert_and_get_unique_device(model)
    if model_device:
        observer.to(model_device)
    # add observer module as attribute
    if is_equalization_observer(observer):
        prefix = node.name + '_equalization_process_'
    else:
        prefix = node.name + '_activation_post_process_'
    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
    observer_name = get_new_observer_name(model)
    setattr(model, observer_name, observer)
    modules[observer_name] = observer
    with graph.inserting_after(node):
        new_obs = graph.create_node(
            'call_module', observer_name, (node,), {})
    return new_obs
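
A condensed, end-to-end version of the insertion pattern in Examples #5 and #6, using only public torch.fx and torch.quantization APIs (a sketch: the observer name is chosen by hand here rather than via get_new_attr_name_with_prefix):

import torch
import torch.fx
from torch.quantization import MinMaxObserver

gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4)))
linear_node = next(n for n in gm.graph.nodes if n.op == 'call_module')

observer_name = 'linear_activation_post_process_0'
setattr(gm, observer_name, MinMaxObserver())
with gm.graph.inserting_after(linear_node):
    obs_node = gm.graph.create_node(
        'call_module', observer_name, (linear_node,), {})
# reroute downstream users through the observer, then restore its own input
linear_node.replace_all_uses_with(obs_node)
obs_node.args = (linear_node,)
gm.recompile()
gm(torch.randn(2, 4))  # a forward pass now records min/max statistics
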
Example #7
def fold_weight(
        quantized: QuantizedGraphModule,
        node_name_to_scope: Dict[str, Tuple[str, type]]) -> QuantizedGraphModule:
    """
    Trace back from the weight node until we hit getattr, reconstruct the
    graph module with the traced nodes, and run the graph module to pack the
    weight, then replace the original chain of ops with the packed weight.
    """
    packed_weights = dict()
    # map from folded node name to the prepack node
    folded_nodes = dict()
    # get packed weights
    for node in quantized.graph.nodes:
        if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
            nodes_to_fold = collect_producer_nodes(node)
            if nodes_to_fold is not None:
                for node_to_fold in nodes_to_fold:
                    folded_nodes[node_to_fold.name] = node

                prepacking_module = graph_module_from_producer_nodes(
                    quantized, nodes_to_fold)
                packed_weight = prepacking_module()
                packed_weights[node.name] = packed_weight

    # remove folded nodes and replace the prepacking node with getattr
    folded_graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    quantized_root = quantized
    quantized_graph = quantized.graph

    for node in quantized_graph.nodes:
        prepack_node = folded_nodes.get(node.name, None)
        if prepack_node is node:
            packed_weight = packed_weights[node.name]
            # add a prepacked attribute to root
            op_node = list(prepack_node.users)[0]
            module_path, _ = node_name_to_scope[op_node.name]
            get_new_packed_weight_name = \
                get_new_attr_name_with_prefix(module_path + '_packed_weight_')
            packed_weight_name = get_new_packed_weight_name(quantized_root)
            setattr(quantized_root, packed_weight_name, packed_weight)
            # replace prepack node with a getattr node
            env[node.name] = folded_graph.create_node('get_attr',
                                                      packed_weight_name, (),
                                                      {})
        elif prepack_node is not None:
            # remove the folded node
            continue
        else:
            # copy other nodes
            env[node.name] = folded_graph.node_copy(node, load_arg)
    quantized = QuantizedGraphModule(quantized_root, folded_graph,
                                     quantized_root.preserved_attr_names)
    return quantized
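
fold_weight rebuilds the graph with the env/load_arg idiom that recurs throughout these examples: walk the old graph in order, copy (or replace) each node into a fresh graph, and remap every argument through an environment dict. The idiom in isolation (a minimal sketch):

import torch
import torch.fx

def f(x):
    return x.relu() + 1

gm = torch.fx.symbolic_trace(f)
folded_graph = torch.fx.Graph()
env = {}

for node in gm.graph.nodes:
    # a real pass special-cases some nodes here (e.g. emitting a get_attr
    # for a prepacked weight); this sketch copies everything verbatim
    env[node.name] = folded_graph.node_copy(node, lambda n: env[n.name])

new_gm = torch.fx.GraphModule(gm, folded_graph)
assert torch.equal(new_gm(torch.ones(2)), gm(torch.ones(2)))
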
Example #8
    def replace_observer_with_quantize_dequantize_node(
            model: torch.nn.Module, graph: Graph, node: Node,
            modules: Dict[str, torch.nn.Module],
            node_name_to_scope: Dict[str, Tuple[str, type]],
            qconfig_map: Dict[str, QConfigAny]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize nodes

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        module_path, prefix = get_module_path_and_prefix(
            node, node_name_to_scope, qconfig_map)
        observer_module = modules[node.target]
        maybe_quantize_node_info = get_quantize_node_info(observer_module)
        # Skip replacing observers to quant/dequant nodes if the qconfigs of all
        # consumers and producers of this observer are None
        skip_replacement = all([
            has_none_qconfig(n, qconfig_map)
            for n in list(node.args) + list(node.users.keys())
        ])
        if skip_replacement or maybe_quantize_node_info is None:
            # didn't find corresponding quantize op and info for the observer_module
            # so we just remove the observer
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        else:
            # otherwise, we can convert the observer module call to a quantize/dequantize node pair
            node_type, quantize_op, qparams = maybe_quantize_node_info
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    # TODO: we can add the information of whether a value needs to
                    # be registered as an attribute in qparams dict itself
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(
                            model, graph, module_path + prefix + key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op,
                                                   tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize",
                                                     args=(quantized_node, ))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)
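
The removal branch above ("we just remove the observer") must reroute users before erasing, because graph.erase_node raises if the node still has users. A minimal standalone demonstration of that idiom:

import torch
import torch.fx

def f(x):
    y = x + 0        # stand-in for an observer call we want to drop
    return y.relu()

gm = torch.fx.symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.op == 'call_function')
add_node.replace_all_uses_with(add_node.args[0])
gm.graph.erase_node(add_node)
gm.recompile()
assert torch.equal(gm(torch.tensor([-1.0, 2.0])), torch.tensor([0.0, 2.0]))
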
Example #9
def create_node_from_old_node_preserve_meta(
    quantized_graph: Graph,
    create_node_args: Tuple[Any, ...],
    old_node: Node,
) -> Node:
    """
    Creates `new_node` and copies the necessary metadata to it from `old_node`.
    """
    new_node = quantized_graph.create_node(*create_node_args)
    new_node.stack_trace = old_node.stack_trace
    return new_node
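
A quick usage check for the helper above (a hedged sketch; the stack trace string is illustrative):

import torch
import torch.fx

g_old, g_new = torch.fx.Graph(), torch.fx.Graph()
x_old = g_old.placeholder('x')
old_node = g_old.call_function(torch.relu, (x_old,))
old_node.stack_trace = 'File "model.py", line 10, in forward'

x_new = g_new.placeholder('x')
new_node = create_node_from_old_node_preserve_meta(
    g_new, ('call_function', torch.relu, (x_new,)), old_node)
assert new_node.stack_trace == old_node.stack_trace
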
Example #10
def create_getattr_from_value(module: GraphModule, graph: Graph, prefix: str, value: Any) -> Node:
    """
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer to the module.
    """
    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
    attr_name = get_new_attr_name(module)
    module.register_buffer(attr_name, torch.tensor(value))
    # Create get_attr with value
    attr_node = graph.create_node("get_attr", attr_name)
    return attr_node
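
Usage sketch for the helper above, assuming create_getattr_from_value and get_new_attr_name_with_prefix (sketched after Example #1) are in scope; in a real pass this is called inside an inserting_before/inserting_after context:

import torch
import torch.fx

gm = torch.fx.symbolic_trace(torch.nn.Identity())
scale_node = create_getattr_from_value(gm, gm.graph, 'scale_', 0.25)
# the value now lives as a buffer on the module, referenced via get_attr
assert scale_node.op == 'get_attr'
assert torch.equal(getattr(gm, scale_node.target), torch.tensor(0.25))
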
Example #11
def create_getattr_from_value(module: torch.nn.Module, graph: Graph,
                              prefix: str, value: Any) -> Node:
    """
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer to the module.
    """
    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
    attr_name = get_new_attr_name(module)
    device = assert_and_get_unique_device(module)
    new_value = value.clone().detach() if isinstance(value, torch.Tensor) \
        else torch.tensor(value, device=device)
    module.register_buffer(attr_name, new_value)
    # Create get_attr with value
    attr_node = graph.create_node("get_attr", attr_name)
    return attr_node
Example #12
    def replace_observer_with_quantize_dequantize_node(
            graph: Graph, node: Node,
            modules: Dict[str, torch.nn.Module]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize nodes

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        root_module = modules[""]
        if observer_module.dtype == torch.float32:
            # remove the node for now
            # TODO: support dynamic quant
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        elif observer_module.dtype in [
                torch.quint8, torch.qint8, torch.float16
        ]:
            node_type, quantize_op, qparams = get_quantize_node_info(
                observer_module)
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(
                            root_module, graph, key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op,
                                                   tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize",
                                                     args=(quantized_node, ))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)
Example #13
    def _fold_weight(self, quantized):
        packed_weights = dict()
        # map from folded node name to the prepack node
        folded_nodes = dict()
        # get packed weights
        for node in quantized.graph.nodes:
            if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
                nodes_to_fold = collect_producer_nodes(node)
                if nodes_to_fold is not None:
                    for node_to_fold in nodes_to_fold:
                        folded_nodes[node_to_fold.name] = node

                    prepacking_module = graph_module_from_producer_nodes(
                        quantized, nodes_to_fold)
                    packed_weight = prepacking_module()
                    packed_weights[node.name] = packed_weight

        # remove folded nodes and replace the prepacking node with getattr
        folded_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        get_new_packed_weight_name = get_new_attr_name_with_prefix(
            '_fx_pass_packed_weight_')
        quantized_root = quantized
        quantized_graph = quantized.graph
        for node in quantized_graph.nodes:
            prepack_node = folded_nodes.get(node.name, None)
            if prepack_node is node:
                packed_weight = packed_weights[node.name]
                # add a prepacked attribute to root
                packed_weight_name = get_new_packed_weight_name(quantized_root)
                setattr(quantized_root, packed_weight_name, packed_weight)
                # replace prepack node with a getattr node
                env[node.name] = folded_graph.create_node(
                    'get_attr', packed_weight_name, (), {})
            elif prepack_node is not None:
                # remove the folded node
                continue
            else:
                # copy other nodes
                env[node.name] = folded_graph.node_copy(node, load_arg)
        folded_graph.output(load_arg(quantized_graph.result))
        quantized = GraphModule(quantized_root, folded_graph)
        return quantized
Example #14
def replace_target_nodes_with(
    fx_module: GraphModule,
    old_op: str,
    old_target: Target,
    new_op: str,
    new_target: Target,
):
    """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
    and updates them to match the new op code and target"""
    new_graph = Graph()
    val_map : Dict[Node, Node] = {}
    for node in fx_module.graph.nodes:
        if node.op == old_op and node.target == old_target:
            args = map_arg(node.args, lambda n: val_map[n])
            kwargs = map_arg(node.kwargs, lambda n: val_map[n])
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            val_map[node] = new_graph.create_node(new_op, new_target, args, kwargs, node.name)
        else:
            val_map[node] = new_graph.node_copy(node, lambda n : val_map[n])
    fx_module.graph = new_graph
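
Hedged usage of replace_target_nodes_with, swapping every torch.add call for torch.mul in a traced function (the targets are illustrative):

import torch
import torch.fx

def f(x):
    return torch.add(x, x)

gm = torch.fx.symbolic_trace(f)
replace_target_nodes_with(gm, 'call_function', torch.add,
                          'call_function', torch.mul)
gm.recompile()
assert torch.equal(gm(torch.tensor([3.0])), torch.tensor([9.0]))
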
Example #15
    def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        # TODO: allow user specified patterns
        if self.is_dynamic_quant:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        propagate_qconfig_(model, qconfig_dict)
        if model.training:
            self._qat_swap_modules(model)

        self.modules = dict(model.named_modules())

        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph)

        # match the patterns that will get quantized
        matches = self._find_matches(model.graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuant object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()

        env = {}
        observed_graph = Graph()
        observed_node_names_set = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in model.graph.nodes:
            if node.name in observed_node_names_set:
                continue

            get_new_observer_name = get_new_attr_name_with_prefix('activation_post_process_')
            root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)

                def insert_observer(node, observer, device):
                    observer_name = get_new_observer_name(model)
                    setattr(model, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                    observed_node_names_set.add(node.name)
                    if device:
                        getattr(model, observer_name).to(device)

                # don't need to insert observer for output in dynamic quantization
                if self.is_dynamic_quant:
                    continue

                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))
                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed_node_names_set:
                        observed_node_names_set.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    insert_observer(node, new_observer, device)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed_node_names_set and node.name in quants:
                observer_name = get_new_observer_name(model)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    new_observer = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    if device:
                        new_observer.to(device)
                    self.activation_post_process_map[node.name] = new_observer
                    setattr(model, observer_name, self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
                    observed_node_names_set.add(node.name)
        observed_graph.output(load_arg(model.graph.result))

        model = GraphModule(model, observed_graph)
        self.save_state(model)
        return model
Example #16
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(
            node_a, gm_a, logger_cls, node_type_to_io_type_map)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(
            node_c, gm_b, logger_cls, node_type_to_io_type_map)

    if ((node_input_type_a == NodeInputOrOutputType.FP32
         and node_input_type_c == NodeInputOrOutputType.INT8)
            or (node_input_type_a == NodeInputOrOutputType.FP32
                and node_input_type_c == NodeInputOrOutputType.FP16) or
            # TODO(future PR): determine the actual dtype of node_c,
            # the current code only works because dequantize works with
            # multiple input dtypes.
        (node_input_type_a == NodeInputOrOutputType.FP32
         and node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)):
        dtype_cast_op = torch.dequantize
    elif (node_input_type_a == node_input_type_c
          and node_input_type_a != NodeInputOrOutputType.UNKNOWN):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (node_input_type_a == NodeInputOrOutputType.INT8
          and node_input_type_c == NodeInputOrOutputType.FP32):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map)
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
        )

    if isinstance(prev_node_c, Node):
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                return _insert_quantize_per_tensor_node(
                    prev_node_c, node_a, gm_b, graph_c, dtype_cast_scale,
                    dtype_cast_zero_point, new_dtype_cast_name)
            else:
                return graph_c.create_node('call_function', dtype_cast_op,
                                           (prev_node_c, ), {},
                                           new_dtype_cast_name)
        else:
            assert dtype_cast_mod_cls
            dtype_cast_mod = dtype_cast_mod_cls()
            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
            return graph_c.create_node('call_module', new_dtype_cast_name,
                                       (prev_node_c, ), {},
                                       new_dtype_cast_name)
    elif isinstance(prev_node_c, list):
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = \
                get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
            if dtype_cast_op:
                # TODO(future PR): add handling for quantize_per_tensor
                new_dtype_cast_node = graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c_inner, ), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
            else:
                assert dtype_cast_mod_cls
                dtype_cast_mod = dtype_cast_mod_cls()
                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
                new_dtype_cast_node = graph_c.create_node(
                    'call_module', new_dtype_cast_name, (prev_node_c_inner, ),
                    {}, new_dtype_cast_name)
                results.append(new_dtype_cast_node)
        return results
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
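
When the input types already match, Example #16 inserts torch.nn.Identity as a no-op cast; the call_module wiring in isolation looks like this (a minimal sketch, with an illustrative attribute name):

import torch
import torch.fx

gm = torch.fx.symbolic_trace(torch.nn.ReLU())
prev_node = next(n for n in gm.graph.nodes if n.op == 'placeholder')

dtype_cast_name = 'relu_dtype_cast_0'
setattr(gm, dtype_cast_name, torch.nn.Identity())
with gm.graph.inserting_after(prev_node):
    cast_node = gm.graph.create_node('call_module', dtype_cast_name,
                                     (prev_node,), {}, dtype_cast_name)
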
Example #17
    def _prepare(self, model, qconfig_dict, inplace,
                 prepare_custom_config_dict, is_standalone_module):
        """ standalone_module means the module is a submodule that is not
        inlined in the parent module, and will be quantized separately as
        one unit.

        When we are preparing a standalone module:
        the input of the module is observed in the parent module, and the
        output of the module is observed in the standalone module.
        Returns:
            model(GraphModule): prepared standalone module with the following attributes:
                _standalone_module_observed_input_idxs(List[Int]): a list of indexes for the graph inputs that
                                         need to be observed in the parent module
                _output_is_observed(Bool): a boolean indicating whether the output of the
                                   standalone module is observed or not
        """
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}
        if not inplace:
            model = copy.deepcopy(model)
        additional_quant_patterns = prepare_custom_config_dict.get(
            "additional_quant_pattern", {})
        self.patterns = get_default_quant_patterns().copy()
        for k, v in additional_quant_patterns.items():
            self.patterns[k] = v

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            additional_qat_module_mapping = prepare_custom_config_dict.get(
                "additional_qat_module_mapping", {})
            self._qat_swap_modules(model, additional_qat_module_mapping)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = prepare_custom_config_dict.get(
            "standalone_module_name", None)
        custom_module_class_mapping = prepare_custom_config_dict.get(
            "float_to_observed_custom_module_class", None)
        matches = self._find_matches(model.graph, self.modules, self.patterns,
                                     standalone_module_names,
                                     custom_module_class_mapping)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuant object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        env = {}
        observed_graph = Graph()
        observed_node_names_set = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indexes for the inputs that need to be observed
        standalone_module_observed_input_idxs = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')

        result_node: Optional[Node] = None
        for node in model.graph.nodes:
            if node.op == 'output':
                observed_graph.output(load_arg(node.args[0]))
                result_node = node
                continue
            if node.name in observed_node_names_set:
                continue

            prefix = node.name + '_activation_post_process_'
            root_node, _, obj, qconfig = matches.get(node.name,
                                                     (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                if qconfig is None:
                    continue

                def insert_observer(node, observer, device):
                    get_new_observer_name = get_new_attr_name_with_prefix(
                        prefix)
                    observer_name = get_new_observer_name(model)
                    setattr(model, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed_node_names_set.add(node.name)
                    if device:
                        getattr(model, observer_name).to(device)

                if isinstance(obj, CustomModuleQuantizeHandler):
                    custom_module = self.modules[node.target]
                    observed_custom_module_class = \
                        custom_module_class_mapping[type(custom_module)]
                    observed_custom_module = \
                        observed_custom_module_class.from_float(custom_module)
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_custom_module)

                # indexes for inputs of the standalone module that need to be observed in the parent
                standalone_module_input_idxs = None
                if isinstance(obj, StandaloneModuleQuantizeHandler):
                    # observe standalone module
                    standalone_module = self.modules[node.target]
                    prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
                    observed_standalone_module = prepare(
                        standalone_module, {'': qconfig})
                    observed_standalone_module.qconfig = qconfig
                    standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
                    observed_standalone_module = mark_observed_standalone_module(
                        observed_standalone_module)
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_standalone_module)
                    self.modules[node.target] = observed_standalone_module

                # don't need to insert observer for output if activation does not
                # need to be statically quantized
                if not activation_is_statically_quantized(qconfig):
                    continue

                # insert observers for the output of the observed module, or
                # mark the output as observed
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))

                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif (isinstance(obj, Add)
                      or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed_node_names_set:
                        observed_node_names_set.add(node.name)
                elif isinstance(obj, StandaloneModuleQuantizeHandler):
                    assert node.op == 'call_module'
                    output_is_observed = self.modules[
                        node.target]._output_is_observed
                    if output_is_observed:
                        observed_node_names_set.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    insert_observer(node, new_observer, device)

                # insert observer for input of standalone module
                if standalone_module_input_idxs is not None:
                    for idx in standalone_module_input_idxs:
                        if node.args[idx].name not in observed_node_names_set:
                            new_observer = qconfig.activation()
                            device = assert_and_get_unique_device(model)
                            insert_observer(node.args[idx], new_observer,
                                            device)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed_node_names_set and node.name in quants:
                if is_standalone_module and node.name in graph_inputs:
                    # we'll insert observer for input of standalone module
                    # in parent graph
                    standalone_module_observed_input_idxs.append(
                        graph_inputs.index(node.name))
                    continue
                get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                observer_name = get_new_observer_name(model)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    # TODO: use insert_observer
                    new_observer = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    if device:
                        new_observer.to(device)
                    self.activation_post_process_map[node.name] = new_observer
                    setattr(model, observer_name,
                            self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed_node_names_set.add(node.name)

        model = GraphModule(model, observed_graph)
        self.save_state(model)
        if is_standalone_module:
            assert result_node is not None
            assert isinstance(result_node.args[0], Node), \
                'standalone module returning dict is not yet supported'
            # indicator for whether the output is observed or not;
            # this is used to correctly quantize standalone modules
            output_is_observed = \
                result_node.args[0].name in observed_node_names_set
            model._standalone_module_observed_input_idxs = standalone_module_observed_input_idxs
            model._output_is_observed = output_is_observed
        return model
Example #18
class Quantizer:
    def __init__(self):
        # mapping from matched node to activation_post_process
        # must be filled before convert
        self.activation_post_process_map = None

    def _qat_swap_modules(self, root):
        convert(root,
                mapping=DEFAULT_QAT_MODULE_MAPPING,
                inplace=True,
                remove_qconfig=False)

    def _generate_qconfig_map(self, root, input_graph):
        def get_qconfig(module):
            return module.qconfig if hasattr(module, 'qconfig') else None

        self.qconfig_map = dict()
        for node in input_graph.nodes:
            if node.op == 'get_param':
                parent, _ = _parent_name(node.target)
                self.qconfig_map[node.name] = get_qconfig(self.modules[parent])
            elif node.op == 'call_function':
                self.qconfig_map[node.name] = get_qconfig(root)
            elif node.op == 'call_method':
                self_obj = node.args[0]
                # qconfig for call_method should be the same as the `self` object for the call
                self.qconfig_map[node.name] = self.qconfig_map[self_obj.name]
            elif node.op == 'call_module':
                self.qconfig_map[node.name] = get_qconfig(
                    self.modules[node.target])

    def _prepare(self, model, qconfig_dict, inplace, quant_type):
        input_root = model.root
        if not inplace:
            input_root = copy.deepcopy(input_root)

        input_graph = model.graph
        self.quant_type = quant_type
        # TODO: allow user specified patterns
        if self.quant_type == QuantType.DYNAMIC:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        propagate_qconfig_(input_root, qconfig_dict)
        if input_root.training:
            self._qat_swap_modules(input_root)

        self.modules = dict(input_root.named_modules())

        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(input_root, input_graph)

        # match the patterns that will get quantized
        matches = self._find_matches(input_graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuant object for each
        quants = self._find_quants(input_graph, matches)

        self.activation_post_process_map = dict()

        env = {}
        observed_graph = Graph()
        observed = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            if node.name in observed:
                continue

            def get_new_observer_name(parent_module):
                i = 0

                def get_observer_name(i):
                    return 'activation_post_process_' + str(i)

                observer_name = get_observer_name(i)
                while hasattr(parent_module, observer_name):
                    i += 1
                    observer_name = get_observer_name(i)
                return observer_name

            root_node, _, obj, qconfig = matches.get(node.name,
                                                     (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)

                def insert_observer(node, observer):
                    observer_name = get_new_observer_name(input_root)
                    setattr(input_root, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, [load_arg(node)], {})
                    observed.add(node.name)

                # don't need to insert observer for output in dynamic quantization
                if self.quant_type == QuantType.DYNAMIC:
                    continue

                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))

                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed.add(node.name)
                elif (isinstance(obj, Add)
                      or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed:
                        observed.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    insert_observer(node, qconfig.activation())
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed and node.name in quants:
                observer_name = get_new_observer_name(input_root)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    self.activation_post_process_map[node.name] = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    setattr(input_root, observer_name,
                            self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, [load_arg(node)], {})
                    observed.add(node.name)
        observed_graph.output(load_arg(input_graph.result))

        return GraphModule(input_root, observed_graph)

    def prepare(self, model, qconfig_dict, inplace=False):
        return self._prepare(model,
                             qconfig_dict,
                             inplace,
                             quant_type=QuantType.STATIC)

    def prepare_dynamic(self, model, qconfig_dict, inplace=False):
        return self._prepare(model,
                             qconfig_dict,
                             inplace,
                             quant_type=QuantType.DYNAMIC)

    def convert(self, observed, inplace=False, debug=False):
        assert self.activation_post_process_map is not None
        # move to cpu since we only have quantized cpu kernels
        observed.eval().cpu()
        observed_root = observed.root
        observed_graph = observed.graph
        if not inplace:
            observed_root = copy.deepcopy(observed_root)
        self.modules = dict(observed_root.named_modules())

        matches = self._find_matches(observed.graph, self.modules,
                                     self.patterns)
        quants = self._find_quants(observed.graph, matches)
        self.quantized_graph = Graph()
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized environment:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either of the environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            if quantized is a list, then arg should be a list and the args with corresponding
            indexes will be quantized
            if quantized is a boolean, then all args will be quantized/not quantized
            if quantized is None, then we'll load the node as long as it exists
            """
            assert quantized is None or isinstance(
                quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg):
                if quantized is None:
                    return map_arg(arg, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg, (tuple, list)), arg
                    loaded_arg = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg):
                        if i in quantized:
                            loaded_arg.append(map_arg(a, load_quantized))
                        else:
                            loaded_arg.append(map_arg(a, load_non_quantized))
                    return type(arg)(loaded_arg)

            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but quant_env will take
                # precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                quantized = map(is_quantized, node)
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception(
                        "partially quantized inputs in list not handled yet")

        for node in observed_graph.nodes:
            root_node, matched, obj, qconfig = matches.get(
                node.name, (None, None, None, None))
            if root_node is node:
                result = obj.convert(self, node, load_arg)
                quantized = True
                # Need to get correct quantized/non-quantized state for the output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])

                if self.quant_type == QuantType.DYNAMIC:
                    quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if node.target.split('.')[-1].startswith(
                        'activation_post_process_'):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    parent_name = ''

                    scale, zero_point = observer_module.calculate_qparams()
                    # TODO: per channel
                    scale = float(scale)
                    zero_point = int(zero_point)
                    dtype = observer_module.dtype
                    qparams = {
                        '_scale_': scale,
                        '_zero_point_': zero_point,
                        '_dtype_': dtype
                    }
                    i = 0

                    def noattr(module, qparams, i):
                        for name in qparams.keys():
                            if hasattr(module, name + str(i)):
                                return False
                        return True

                    def get_next_i(module, qparams):
                        i = 0
                        while not noattr(module, qparams, i):
                            i += 1
                        return i

                    parent_module = self.modules[parent_name]
                    i = get_next_i(parent_module, qparams)
                    inputs = [load_non_quantized(node.args[0])]
                    for key, value in qparams.items():
                        setattr(parent_module, key + str(i), value)
                        qparam_full_path = key + str(i)
                        if parent_name:
                            qparam_full_path = parent_name + '.' + qparam_full_path
                        inputs.append(
                            self.quantized_graph.create_node(
                                'get_param', qparam_full_path))
                    quant_env[node.name] = self.quantized_graph.create_node(
                        'call_function', torch.quantize_per_tensor, tuple(inputs), {})
                    continue
            # copy remaining nodes, dequantizing any inputs that are quantized
            env[node.name] = self.quantized_graph.node_copy(
                node, load_non_quantized)

        self.quantized_graph.output(load_non_quantized(observed_graph.result))

        to_be_removed = []
        for name, _ in observed_root.named_modules():
            if name.split('.')[-1].startswith('activation_post_process_'):
                to_be_removed.append(name)
        for n in to_be_removed:
            delattr(observed_root, n)
        return GraphModule(observed_root, self.quantized_graph)

    def _find_matches(self, graph, modules, patterns):
        match_map = {}  # node name -> (root_node, matched, quantize_handler, qconfig)
        all_matched = set()

        def record_match(pattern, node, matched):
            if isinstance(pattern, tuple):
                s, *args = pattern
                record_match(s, node, matched)
                if pattern[0] is not getattr:
                    for subpattern, arg in zip(args, node.args):
                        record_match(subpattern, arg, matched)
            else:
                matched.append(node)

        for node in reversed(graph.nodes):
            if node.name not in match_map and node.name not in all_matched:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        matched = []
                        record_match(pattern, node, matched)
                        for n in matched:
                            match_map[n.name] = (node, matched,
                                                 value(self, node),
                                                 self.qconfig_map[n.name])
                            all_matched.add(n.name)
                        # break after finding the first match
                        break
        return match_map

    def _find_quants(self, graph, matches):
        quants = {}

        def visit(node, qconfig):
            def visit_arg(arg):
                # note: we have to measure quantization information
                # even for nodes where we might not use it because it is already
                # quantized. This is because each match has the option to
                # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
                is_weight = False
                if isinstance(
                        node, Node
                ) and node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
                    for i, node_arg in enumerate(node.args):
                        if arg is node_arg and i in WEIGHT_INDEX_DICT[
                                node.target]:
                            is_weight = True
                if self.quant_type != QuantType.DYNAMIC or is_weight:
                    # overwrite previous quant config
                    quants[arg.name] = (DefaultQuant(self,
                                                     arg), qconfig, is_weight)

            return visit_arg

        for node in graph.nodes:
            if node.name in matches:
                root_node, matched, obj, qconfig = matches[node.name]
                # don't attach observer/fake_quant for CopyNode
                if isinstance(obj, CopyNode):
                    qconfig = None
                if root_node is node:
                    # matched[-1] is the first op in the sequence and
                    # matched[0] is the last op in the sequence
                    # inputs
                    map_arg(matched[-1].args, visit(matched[-1], qconfig))
                    map_arg(matched[-1].kwargs, visit(matched[-1], qconfig))
                    # output
                    map_arg(matched[0], visit(None, qconfig))
        return quants
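
# A minimal, torch-free sketch of the reverse-order pattern matching performed
# by _find_matches above. _Node, the patterns and all names here are
# hypothetical stand-ins for their torch.fx counterparts, and the getattr
# special case of record_match is omitted.

class _Node:
    def __init__(self, name, target, args=()):
        self.name, self.target, self.args = name, target, args

def _is_match(node, pattern):
    if isinstance(pattern, tuple):
        op, *arg_patterns = pattern
        return (node.target == op and all(
            isinstance(arg, _Node) and _is_match(arg, subpattern)
            for subpattern, arg in zip(arg_patterns, node.args)))
    return node.target == pattern

def _record_match(pattern, node, matched):
    if isinstance(pattern, tuple):
        s, *args = pattern
        _record_match(s, node, matched)
        for subpattern, arg in zip(args, node.args):
            _record_match(subpattern, arg, matched)
    else:
        matched.append(node)

def _find_matches_sketch(nodes, patterns):
    match_map, all_matched = {}, set()
    for node in reversed(nodes):  # walk from outputs back towards inputs
        if node.name in match_map or node.name in all_matched:
            continue
        for pattern in patterns:
            if _is_match(node, pattern):
                matched = []
                # matched[0] is the last op in the sequence, matched[-1] the first
                _record_match(pattern, node, matched)
                for n in matched:  # every node in the sequence shares one root
                    match_map[n.name] = node.name
                    all_matched.add(n.name)
                break  # stop after the first matching pattern
    return match_map

_x = _Node('x', 'placeholder')
_add = _Node('add', 'add', (_x, _x))
_relu = _Node('relu', 'relu', (_add,))
# the fused ('relu', 'add') pattern wins over the bare 'add' pattern,
# so both nodes map to the same root:
assert _find_matches_sketch([_x, _add, _relu],
                            [('relu', 'add'), 'add']) == \
    {'relu': 'relu', 'add': 'relu'}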
Example #19
0
class Quantizer:
    def __init__(self):
        # mapping from matched node to activation_post_process
        # must be filled before convert
        self.activation_post_process_map = None
        # mapping from node name to qconfig that should be used for that node
        # filled out for a model during _generate_qconfig_map
        self.qconfig_map = None
        # mapping from fully qualified module name to module instance
        # for example,
        # {
        #   '': Model(...),
        #   'linear': Linear(...),
        #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
        # }
        self.modules = None
        # mapping from a tuple of nodes in reverse order to uninitialized
        #   QuantizeHandler subclass. For example,
        # {
        #   # match a single node
        #   (<class 'torch.nn.modules.conv.Conv3d'>:
        #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
        #   # match multiple nodes in reverse order
        #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
        #     <class 'torch.quantization.fx.quantize.Add'>),
        # }
        self.patterns = None

    def _qat_swap_modules(self, root):
        convert(root,
                mapping=DEFAULT_QAT_MODULE_MAPPING,
                inplace=True,
                remove_qconfig=False)

    def _generate_qconfig_map(self, root, input_graph):
        def get_qconfig(module):
            return module.qconfig if hasattr(module, 'qconfig') else None

        self.qconfig_map = dict()
        for node in input_graph.nodes:
            if node.op == 'get_param':
                parent, _ = _parent_name(node.target)
                self.qconfig_map[node.name] = get_qconfig(self.modules[parent])
            elif node.op == 'call_function':
                self.qconfig_map[node.name] = get_qconfig(root)
            elif node.op == 'call_method':
                self_obj = node.args[0]
                # qconfig for call_method should be the same as the `self` object for the call
                self.qconfig_map[node.name] = self.qconfig_map[self_obj.name]
            elif node.op == 'call_module':
                self.qconfig_map[node.name] = get_qconfig(
                    self.modules[node.target])

    def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
        assert not inplace, 'inplace prepare is not supported yet'
        input_root = model
        if not inplace:
            input_root = copy.deepcopy(input_root)

        input_graph = model.graph
        self.is_dynamic_quant = is_dynamic_quant
        # TODO: allow user specified patterns
        if self.is_dynamic_quant:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        propagate_qconfig_(input_root, qconfig_dict)
        if input_root.training:
            self._qat_swap_modules(input_root)

        self.modules = dict(input_root.named_modules())

        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(input_root, input_graph)

        # match the patterns that will get quantized
        matches = self._find_matches(input_graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuant object for each
        quants = self._find_quants(input_graph, matches)

        self.activation_post_process_map = dict()

        env = {}
        observed_graph = Graph()
        observed = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            if node.name in observed:
                continue

            get_new_observer_name = get_new_attr_name_with_prefix(
                'activation_post_process_')
            root_node, _, obj, qconfig = matches.get(node.name,
                                                     (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)

                def insert_observer(node, observer, device):
                    observer_name = get_new_observer_name(input_root)
                    setattr(input_root, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed.add(node.name)
                    if device:
                        getattr(input_root, observer_name).to(device)

                # don't need to insert observer for output in dynamic quantization
                if self.is_dynamic_quant:
                    continue

                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))

                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed.add(node.name)
                elif (isinstance(obj, Add)
                      or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed:
                        observed.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(input_root)
                    insert_observer(node, new_observer, device)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed and node.name in quants:
                observer_name = get_new_observer_name(input_root)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    new_observer = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(input_root)
                    if device:
                        new_observer.to(device)
                    self.activation_post_process_map[node.name] = new_observer
                    setattr(input_root, observer_name,
                            self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed.add(node.name)
        observed_graph.output(load_arg(input_graph.result))

        observed = GraphModule(input_root, observed_graph)
        self.save_state(observed)
        return observed

    def save_state(self, observed):
        observed._activation_post_process_map = self.activation_post_process_map
        observed._patterns = self.patterns
        observed._qconfig_map = self.qconfig_map

    def restore_state(self, observed):
        err_msg = 'please make sure the model is produced by prepare'
        assert hasattr(observed, '_activation_post_process_map'), 'did not find ' + \
            '_activation_post_process_map attribute, ' + err_msg
        assert hasattr(observed, '_patterns'), 'did not find ' + \
            '_patterns attribute, ' + err_msg
        assert hasattr(observed, '_qconfig_map'), 'did not find ' + \
            '_qconfig_map attribute, ' + err_msg
        self.activation_post_process_map = observed._activation_post_process_map
        self.patterns = observed._patterns
        self.qconfig_map = observed._qconfig_map

    def prepare(self, model, qconfig_dict, inplace=False):
        return self._prepare(model,
                             qconfig_dict,
                             inplace,
                             is_dynamic_quant=False)

    def prepare_dynamic(self, model, qconfig_dict, inplace=False):
        return self._prepare(model,
                             qconfig_dict,
                             inplace,
                             is_dynamic_quant=True)

    def _run_weight_observers(self, observed):
        r''' Extract the subgraph that produces the weight for each dynamically
        quantized node and run that subgraph to observe the weight.
        Note that the observers of dynamically quantized modules are run during
        the conversion step.
        '''
        for node in observed.graph.nodes:
            if node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
                for i, node_arg in enumerate(node.args):
                    if i in WEIGHT_INDEX_DICT[node.target]:
                        # node_arg is weight
                        weight_observer_nodes = collect_producer_nodes(
                            node_arg)
                        if weight_observer_nodes is not None:
                            weight_observer_module = graph_module_from_producer_nodes(
                                observed, weight_observer_nodes)
                            # run the weight observer
                            weight_observer_module()
        return

    def _convert(self,
                 observed,
                 inplace=False,
                 debug=False,
                 is_dynamic_quant=False):
        assert not inplace, 'inplace convert is not supported yet'
        self.restore_state(observed)
        self.is_dynamic_quant = is_dynamic_quant
        # run weight observers before inserting quant dequant nodes
        # for dynamic quantization
        if self.is_dynamic_quant:
            self._run_weight_observers(observed)

        # move to cpu since we only have quantized cpu kernels
        observed.eval().cpu()
        observed_root = observed
        observed_graph = observed.graph
        if not inplace:
            observed_root = copy.deepcopy(observed_root)

        self.modules = dict(observed_root.named_modules())

        matches = self._find_matches(observed.graph, self.modules,
                                     self.patterns)
        quants = self._find_quants(observed.graph, matches)
        self.quantized_graph = Graph()
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized environment:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either of the environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and the args with corresponding
                indexes will be quantized
              - if quantized is a boolean, then all args will be quantized/not quantized
              - if quantized is None, then the arg is loaded from whichever
                environment contains it (quant_env takes precedence)

            Output: fn which takes arg_or_args, and loads them from the corresponding
              environment depending on the value of quantized.
            """
            assert quantized is None or isinstance(
                quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg_or_args,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)

            return load_arg_impl

        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but
                # quant_env will take precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                quantized = list(map(is_quantized, node))
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception(
                        "partially quantized inputs in list not handled yet")

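        # main conversion loop: matched root nodes are converted by their
        # quantize handler, activation_post_process modules are replaced
        # with quantize ops, and all remaining nodes are copied with
        # non-quantized inputs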
        for node in observed_graph.nodes:
            root_node, matched, obj, qconfig = matches.get(
                node.name, (None, None, None, None))
            if root_node is node:
                result = obj.convert(self, node, load_arg)
                quantized = True
                # Need to get correct quantized/non-quantized state for the output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])

                # output of dynamic quantization is not quantized
                if self.is_dynamic_quant:
                    quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if node.target.split('.')[-1].startswith(
                        'activation_post_process_'):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    parent_name = ''

                    scale, zero_point = observer_module.calculate_qparams()
                    dtype = observer_module.dtype

                    def is_per_channel(qscheme):
                        return qscheme == torch.per_channel_affine or \
                            qscheme == torch.per_channel_symmetric

                    if is_per_channel(observer_module.qscheme):
                        ch_axis = int(observer_module.ch_axis)
                        qparams = {
                            '_scale_': scale,
                            '_zero_point_': zero_point,
                            '_axis_': ch_axis,
                            '_dtype_': dtype
                        }
                        quantize_op = torch.quantize_per_channel
                    else:
                        scale = float(scale)
                        zero_point = int(zero_point)
                        qparams = {
                            '_scale_': scale,
                            '_zero_point_': zero_point,
                            '_dtype_': dtype
                        }
                        quantize_op = torch.quantize_per_tensor

                    def noattr(module, qparams, i):
                        for name in qparams.keys():
                            if hasattr(module, name + str(i)):
                                return False
                        return True

                    def get_next_i(module, qparams):
                        i = 0
                        while not noattr(module, qparams, i):
                            i += 1
                        return i

                    parent_module = self.modules[parent_name]
                    i = get_next_i(parent_module, qparams)
                    inputs = [load_non_quantized(node.args[0])]
                    for key, value in qparams.items():
                        setattr(parent_module, key + str(i), value)
                        qparam_full_path = key + str(i)
                        if parent_name:
                            qparam_full_path = parent_name + '.' + qparam_full_path
                        inputs.append(
                            self.quantized_graph.create_node(
                                'get_param', qparam_full_path))
                    quant_env[node.name] = self.quantized_graph.create_node(
                        'call_function', quantize_op, tuple(inputs), {})
                    continue
            # copy remaining nodes, dequantizing any inputs that are quantized
            env[node.name] = self.quantized_graph.node_copy(
                node, load_non_quantized)

        self.quantized_graph.output(load_non_quantized(observed_graph.result))

        to_be_removed = []
        for name, _ in observed_root.named_modules():
            if name.split('.')[-1].startswith('activation_post_process_'):
                to_be_removed.append(name)
        for n in to_be_removed:
            delattr(observed_root, n)
        return GraphModule(observed_root, self.quantized_graph)

    # Trace back from the weight node until we hit getattr, reconstruct the graph
    # module with the traced nodes, and run the graph module to pack the weight.
    # Then replace the original chain of ops with the packed weight.
    def _fold_weight(self, quantized):
        packed_weights = dict()
        # map from folded node name to the prepacked weight name
        folded_nodes = dict()
        # get packed weights
        for node in quantized.graph.nodes:
            if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
                nodes_to_fold = collect_producer_nodes(node)
                if nodes_to_fold is not None:
                    for node_to_fold in nodes_to_fold:
                        folded_nodes[node_to_fold.name] = node

                    prepacking_module = graph_module_from_producer_nodes(
                        quantized, nodes_to_fold)
                    packed_weight = prepacking_module()
                    packed_weights[node.name] = packed_weight

        # remove folded nodes and replace the prepacking node with getattr
        folded_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        get_new_packed_weight_name = get_new_attr_name_with_prefix(
            '_fx_pass_packed_weight_')
        quantized_root = quantized
        quantized_graph = quantized.graph
        for node in quantized_graph.nodes:
            prepack_node = folded_nodes.get(node.name, None)
            if prepack_node is node:
                packed_weight = packed_weights[node.name]
                # add a prepacked attribute to root
                packed_weight_name = get_new_packed_weight_name(quantized_root)
                setattr(quantized_root, packed_weight_name, packed_weight)
                # replace prepack node with a getattr node
                env[node.name] = folded_graph.create_node(
                    'get_param', packed_weight_name, (), {})
            elif prepack_node is not None:
                # remove the folded node
                continue
            else:
                # copy other nodes
                env[node.name] = folded_graph.node_copy(node, load_arg)
        folded_graph.output(load_arg(quantized_graph.result))
        return GraphModule(quantized_root, folded_graph)

    def convert(self, observed, inplace=False, debug=False, is_dynamic=False):
        quantized = self._convert(observed, inplace, debug, is_dynamic)
        if not debug:
            quantized = self._fold_weight(quantized)
        return quantized

    def _find_matches(self, graph, modules, patterns):
        match_map = {}  # node name -> (root_node, matched, quantize_handler, qconfig)
        all_matched = set()

        def record_match(pattern, node, matched):
            if isinstance(pattern, tuple):
                s, *args = pattern
                record_match(s, node, matched)
                if pattern[0] is not getattr:
                    for subpattern, arg in zip(args, node.args):
                        record_match(subpattern, arg, matched)
            else:
                matched.append(node)

        for node in reversed(graph.nodes):
            if node.name not in match_map and node.name not in all_matched:
                for pattern, value in patterns.items():
                    if is_match(modules, node, pattern):
                        matched = []
                        record_match(pattern, node, matched)
                        for n in matched:
                            match_map[n.name] = (node, matched,
                                                 value(self, node),
                                                 self.qconfig_map[n.name])
                            all_matched.add(n.name)
                        # break after finding the first match
                        break
        return match_map

    def _find_quants(self, graph, matches):
        quants = {}

        def visit(node, qconfig):
            def visit_arg(arg):
                # note: we have to measure quantization information
                # even for nodes where we might not use it because it is already
                # quantized. This is because each match has the option to
                # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
                is_weight = False
                if isinstance(
                        node, Node
                ) and node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
                    for i, node_arg in enumerate(node.args):
                        if arg is node_arg and i in WEIGHT_INDEX_DICT[
                                node.target]:
                            is_weight = True
                if (not self.is_dynamic_quant) or is_weight:
                    # overwrite previous quant config
                    quants[arg.name] = (DefaultQuant(self,
                                                     arg), qconfig, is_weight)

            return visit_arg

        for node in graph.nodes:
            if node.name in matches:
                root_node, matched, obj, qconfig = matches[node.name]
                # don't attach observer/fake_quant for CopyNode
                if isinstance(obj, CopyNode):
                    qconfig = None
                if root_node is node:
                    # matched[-1] is the first op in the sequence and
                    # matched[0] is the last op in the sequence
                    # inputs
                    map_arg(matched[-1].args, visit(matched[-1], qconfig))
                    map_arg(matched[-1].kwargs, visit(matched[-1], qconfig))
                    # output
                    map_arg(matched[0], visit(None, qconfig))
        return quants
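
# A small, self-contained sketch (all names hypothetical) of the
# numbered-attribute scheme used by noattr/get_next_i in _convert above:
# qparams are stored on the parent module as '_scale_0', '_zero_point_0', ...,
# bumping the suffix whenever any of the names is already taken.

class _Parent:
    pass

def _register_qparams(parent, qparams):
    i = 0
    # find the first suffix i such that no '<key><i>' attribute exists yet
    while any(hasattr(parent, key + str(i)) for key in qparams):
        i += 1
    paths = []
    for key, value in qparams.items():
        setattr(parent, key + str(i), value)
        paths.append(key + str(i))
    return paths

_p = _Parent()
assert _register_qparams(_p, {'_scale_': 0.1, '_zero_point_': 0}) == \
    ['_scale_0', '_zero_point_0']
assert _register_qparams(_p, {'_scale_': 0.2, '_zero_point_': 128}) == \
    ['_scale_1', '_zero_point_1']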
Example #20
0
def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(node_a, gm_a, logger_cls)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(node_c, gm_b, logger_cls)

    if ((node_input_type_a == NodeInputOrOutputType.FP32
         and node_input_type_c == NodeInputOrOutputType.INT8)
            or (node_input_type_a == NodeInputOrOutputType.FP32
                and node_input_type_c == NodeInputOrOutputType.FP16)):
        dtype_cast_op = torch.dequantize
    elif (node_input_type_a == NodeInputOrOutputType.FP32
          and node_input_type_c == NodeInputOrOutputType.FP32):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (node_input_type_a == NodeInputOrOutputType.INT8
          and node_input_type_c == NodeInputOrOutputType.INT8):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (node_input_type_a == NodeInputOrOutputType.FP32_OR_INT8
          and node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8):
        dtype_cast_mod_cls = torch.nn.Identity
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
        )

    if isinstance(prev_node_c, Node):
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            return graph_c.create_node('call_function', dtype_cast_op,
                                       (prev_node_c, ), {},
                                       new_dtype_cast_name)
        else:
            assert dtype_cast_mod_cls
            dtype_cast_mod = dtype_cast_mod_cls()
            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
            return graph_c.create_node('call_module', new_dtype_cast_name,
                                       (prev_node_c, ), {},
                                       new_dtype_cast_name)
    elif isinstance(prev_node_c, list):
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = \
                get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
            if dtype_cast_op:
                new_dtype_cast_node = graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c_inner, ), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
            else:
                assert dtype_cast_mod_cls
                dtype_cast_mod = dtype_cast_mod_cls()
                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
                new_dtype_cast_node = graph_c.create_node(
                    'call_module', new_dtype_cast_name, (prev_node_c_inner, ), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
        return results
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
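
# A short, standalone illustration of the two cast flavors chosen by
# _insert_dtype_cast_after_node above: torch.dequantize as a call_function
# cast (int8 -> fp32) and nn.Identity as a call_module no-op when the dtypes
# already agree. The scale/zero_point values here are arbitrary.

import torch

_x = torch.randn(2, 2)
_xq = torch.quantize_per_tensor(_x, scale=0.1, zero_point=0,
                                dtype=torch.quint8)
assert torch.dequantize(_xq).dtype == torch.float32  # int8 -> fp32 cast
assert torch.nn.Identity()(_x) is _x                 # fp32 -> fp32 no-op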
Example #21
0
    def _prepare(self, model, qconfig_dict, inplace, quant_type):
        input_root = model.root
        if not inplace:
            input_root = copy.deepcopy(input_root)

        input_graph = model.graph
        self.quant_type = quant_type
        # TODO: allow user specified patterns
        if self.quant_type == QuantType.DYNAMIC:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        propagate_qconfig_(input_root, qconfig_dict)
        if input_root.training:
            self._qat_swap_modules(input_root)

        self.modules = dict(input_root.named_modules())

        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(input_root, input_graph)

        # match the patterns that will get quantized
        matches = self._find_matches(input_graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuant object for each
        quants = self._find_quants(input_graph, matches)

        self.activation_post_process_map = dict()

        env = {}
        observed_graph = Graph()
        observed = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            if node.name in observed:
                continue

            def get_new_observer_name(parent_module):
                i = 0

                def get_observer_name(i):
                    return 'activation_post_process_' + str(i)

                observer_name = get_observer_name(i)
                while hasattr(parent_module, observer_name):
                    i += 1
                    observer_name = get_observer_name(i)
                return observer_name

            root_node, _, obj, qconfig = matches.get(node.name,
                                                     (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)

                def insert_observer(node, observer):
                    observer_name = get_new_observer_name(input_root)
                    setattr(input_root, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed.add(node.name)

                # don't need to insert observer for output in dynamic quantization
                if self.quant_type == QuantType.DYNAMIC:
                    continue

                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))

                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed.add(node.name)
                elif (isinstance(obj, Add)
                      or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed:
                        observed.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    insert_observer(node, qconfig.activation())
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed and node.name in quants:
                observer_name = get_new_observer_name(input_root)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    self.activation_post_process_map[
                        node.name] = qconfig.weight(
                        ) if is_weight else qconfig.activation()
                    setattr(input_root, observer_name,
                            self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed.add(node.name)
        observed_graph.output(load_arg(input_graph.result))

        return GraphModule(input_root, observed_graph)
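
# A minimal sketch of what an inserted activation_post_process module does:
# observe tensors during calibration, then supply the qparams for the
# quantize_per_tensor call created at convert time. MinMaxObserver is the
# stock PyTorch observer; the tensors here are arbitrary.

import torch
from torch.quantization import MinMaxObserver

_obs = MinMaxObserver(dtype=torch.quint8)
_obs(torch.randn(4, 4))  # calibration pass: records running min/max
_scale, _zero_point = _obs.calculate_qparams()
_xq = torch.quantize_per_tensor(torch.randn(4, 4), float(_scale),
                                int(_zero_point), torch.quint8)
assert _xq.dtype == torch.quint8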
Example #22
0
class PrimContext(torch.overrides.TorchFunctionMode):
    """
    The prototype prim tracing context.

    Example usage:

    import torch._prims.utils as utils
    from torch._prims.context import PrimContext
    from torch._prims.executor import execute
    from torch.overrides import push_torch_function_mode

    a = torch.randn((2, 2))
    b = torch.randn((2, 2))

    with push_torch_function_mode(PrimContext) as ctx:
      meta_a = ctx.placeholder(utils.TensorMeta(a))
      meta_b = ctx.placeholder(utils.TensorMeta(b))
      result = torch.add(meta_a, meta_b)
      ctx.output(result)

    exc_result = execute(ctx, a, b)

    Currently this only acquires a trace of prims, and
    it does not account for control flow. As such,
    execute must be called with tensors that have the
    same metadata (dtype, device, shape...) as
    the tensors used to trace the operations.

    The tracing context's FX graph can be acquired
    using its graph attribute.
    """
    def __init__(self):
        self.graph = Graph()

        # Private attributes for generating names
        self._tensor_name_counter = 0
        self._dim_name_counter = 0
        self._shape_name_counter = 0
        self._lowercase = tuple(string.ascii_lowercase)
        self._uppercase = tuple(string.ascii_uppercase)

    @staticmethod
    def _create_name(idx, chars):
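        # maps a counter to a short unique name: 0..25 -> 'a'..'z', then
        # repeated-letter names ('aa', 'bb', ..., 'zz', 'aaa', ...) above that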
        name = ""
        while idx >= len(chars):
            name = chars[idx % len(chars)] + name
            idx = idx - len(chars)
        name = chars[idx] + name

        return name

    def _tensor_name(self):
        idx = self._tensor_name_counter
        self._tensor_name_counter = self._tensor_name_counter + 1

        return self._create_name(idx, self._lowercase)

    def _add_user(self, tm: TensorMeta, node: Node) -> None:
        assert tm.node is not None
        tm.node.users[node] = None

    def placeholder(self, a: Any):
        name = self._tensor_name()
        node = self.graph.placeholder(name)

        if isinstance(a, TensorMeta):
            if a.node is not None:
                raise ValueError(
                    "Attempting to reuse a TensorMeta in a new trace!")
            a.tname = name
            a.node = node

        return a

    def output(self, tm: TensorMeta):
        # TODO: allow other output types
        assert isinstance(tm, TensorMeta)

        node = self.graph.output(tm)
        self._add_user(tm, node)

    def __torch_function__(
            self,
            func: Callable,
            types: Sequence,
            args: Sequence[Any] = (),
            kwargs: Dict = None,
    ):
        """
        Determines which function to call. The order of which
        function is called is determined by:

        - func's "meta" attribute, if it exists
        - if func is a torch operation, its corresponding reference
        - func
        """

        if kwargs is None:
            kwargs = {}

        if hasattr(func, "meta"):
            # TODO: add check that all args/kwargs are 'registered' properly
            # to this trace

            output = func.meta(*args, **kwargs)  # type: ignore[attr-defined]

            # Updates graph
            # TODO: handle outputs with multiple tensors
            # TODO: handle non-tensor outputs
            assert isinstance(output, TensorMeta)
            output_name = self._tensor_name()
            node = self.graph.create_node("call_function",
                                          func,
                                          name=output_name,
                                          args=args,
                                          kwargs=kwargs)
            output.tname = output_name
            output.node = node

            # Marks uses
            for x in (x for x in chain(args, kwargs.values())
                      if isinstance(x, TensorMeta)):
                self._add_user(x, node)

            return output

        # Remaps torch operations to their references
        if func in _torch_to_reference_map:
            fn = _torch_to_reference_map[func]
            with torch.overrides.enable_torch_function_mode(
                    self, replace=self.inner):
                return fn(*args, **kwargs)  # type: ignore[operator]

        return func(*args, **kwargs)
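
# A stripped-down cousin of PrimContext, for illustration only (hypothetical
# class, assuming the TorchFunctionMode context-manager API shown above is
# available): a mode that merely records which functions it intercepts,
# without building a graph.

import torch
from torch.overrides import TorchFunctionMode

class _CallLogger(TorchFunctionMode):
    def __init__(self):
        super().__init__()
        self.calls = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.calls.append(getattr(func, '__name__', str(func)))
        # the mode is excluded while its handler runs, so this does not recurse
        return func(*args, **kwargs)

with _CallLogger() as _logger:
    torch.add(torch.randn(2), torch.randn(2))
# factory functions are intercepted too, so expect ['randn', 'randn', 'add']
print(_logger.calls)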