Example 1
def flexml_act_handler(model):
    # Swap float ReLU / ReLU6 / LeakyReLU activations for quantized counterparts
    rewriters = []
    for node in model.graph.nodes:
        if node.op == 'call_module':
            module = get_module(model, node.target)
            if isinstance(module, nn.ReLU):
                rewriter = ModuleToModuleByInstance(
                    module,
                    qnn.QuantReLU,
                    act_quant=Uint8ActPerTensorFixedPoint,
                    return_quant_tensor=True)
                rewriters.append(rewriter)
            elif isinstance(module, nn.ReLU6):
                rewriter = ModuleToModuleByInstance(
                    module,
                    qnn.QuantReLU,
                    act_quant=Uint8ActPerTensorFixedPointMaxInit,
                    max_val=6.,
                    return_quant_tensor=True)
                rewriters.append(rewriter)
            elif isinstance(module, nn.LeakyReLU):
                rewriter = ModuleToModuleByInstance(module,
                                                    FlexMLQuantLeakyReLU)
                rewriters.append(rewriter)
    for rewriter in rewriters:
        model = rewriter.apply(model)
    return model
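A minimal usage sketch, assuming the handler receives an FX GraphModule; torch.fx.symbolic_trace is used here only for illustration, while the original flow may rely on Brevitas' own tracer:

import torch.nn as nn
from torch.fx import symbolic_trace

float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.LeakyReLU())
graph_model = symbolic_trace(float_model)
# ReLU / ReLU6 / LeakyReLU call_module nodes are rewritten on the traced graph
graph_model = flexml_act_handler(graph_model)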
Example 2
def are_inputs_unsigned(model, node, is_unsigned_list):
    # Recursively check whether every input feeding `node` is unsigned,
    # looking through sign-preserving modules and shape-only ops
    for inp_node in node.all_input_nodes:
        if inp_node.op == 'call_module':
            inp_module = get_module(model, inp_node.target)
            if isinstance(inp_module, (nn.ReLU, nn.ReLU6)):
                is_unsigned_list.append(True)
            elif isinstance(inp_module, tuple(SIGN_PRESERVING_MODULES)):
                are_inputs_unsigned(model, inp_node, is_unsigned_list)
            elif isinstance(
                    inp_module,
                (qnn.QuantReLU, qnn.QuantIdentity, qnn.QuantHardTanh)):
                is_unsigned_list.append(not inp_module.is_quant_act_signed)
            else:
                is_unsigned_list.append(False)
        elif inp_node.op == 'call_function':
            if inp_node.target in [
                    torch.reshape, torch.flatten, torch.transpose, torch.cat
            ] + ADD_FNS:
                are_inputs_unsigned(model, inp_node, is_unsigned_list)
            else:
                is_unsigned_list.append(False)
        elif inp_node.op == 'call_method':
            if inp_node.target in [
                    'view', 'reshape', 'flatten', 't', 'permute'
            ] + ADD_METHODS:
                are_inputs_unsigned(model, inp_node, is_unsigned_list)
            else:
                is_unsigned_list.append(False)
    return all(is_unsigned_list)
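This helper is used as a predicate (see Example 9, where it is called with an empty accumulator). A hedged call sketch, assuming graph_model is the traced GraphModule from Example 1's sketch:

for node in graph_model.graph.nodes:
    if node.op == 'call_module':
        # the accumulator list must start empty on every top-level call
        print(node.name, are_inputs_unsigned(graph_model, node, []))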
Example 3
def add_input_handler(model, node, quant_identity_name, quant_identity,
                      rewriters):
    # Align the inputs of `node` with `quant_identity`: re-use or rewrite their
    # activation quantizers, or insert the shared quantizer after unquantized inputs
    for inp_node in node.all_input_nodes:
        if inp_node.op == 'call_module':
            module = get_module(model, inp_node.target)
            if isinstance(module, tuple(SIGN_PRESERVING_MODULES)):
                add_input_handler(model, inp_node, quant_identity_name,
                                  quant_identity, rewriters)
            elif isinstance(module, qnn.QuantReLU):
                rewriter = ModuleToModuleByInstance(
                    module,
                    qnn.QuantReLU,
                    act_quant=Uint8ActPerTensorFixedPoint,
                    scaling_impl=quant_identity.act_quant.
                    fused_activation_quant_proxy.tensor_quant.scaling_impl,
                    int_scaling_impl=quant_identity.act_quant.
                    fused_activation_quant_proxy.tensor_quant.int_scaling_impl,
                    return_quant_tensor=True)
                rewriters.append(rewriter)
            elif isinstance(module, qnn.QuantIdentity):
                if module.is_quant_act_signed == quant_identity.is_quant_act_signed:
                    rewriters.append(
                        ModuleInstanceToModuleInstance(module, quant_identity))
                else:
                    assert not module.is_quant_act_signed and quant_identity.is_quant_act_signed
                    rewriter = ModuleToModuleByInstance(
                        module,
                        qnn.QuantIdentity,
                        act_quant=Uint8ActPerTensorFixedPoint,
                        scaling_impl=quant_identity.act_quant.
                        fused_activation_quant_proxy.tensor_quant.scaling_impl,
                        int_scaling_impl=quant_identity.act_quant.
                        fused_activation_quant_proxy.tensor_quant.
                        int_scaling_impl,
                        return_quant_tensor=True)
                    rewriters.append(rewriter)
            elif isinstance(module, FlexMLQuantLeakyReLU):
                rewriter = ModuleToModuleByInstance(
                    module, FlexMLQuantLeakyReLU, output_quant=quant_identity)
                rewriters.append(rewriter)
            else:
                rewriters.append(
                    InsertModuleCallAfter(quant_identity_name, inp_node))
        elif inp_node.op == 'call_function' and inp_node.target in [
                torch.flatten, torch.reshape, torch.transpose
        ]:
            add_input_handler(model, inp_node, quant_identity_name,
                              quant_identity, rewriters)
        elif inp_node.op == 'call_function' and inp_node.target is torch.cat:
            cat_input_handler(model, inp_node, quant_identity_name,
                              quant_identity, rewriters)
        elif inp_node.op == 'call_method' and inp_node.target in [
                'view', 'reshape', 'flatten', 'transpose'
        ]:
            add_input_handler(model, inp_node, quant_identity_name,
                              quant_identity, rewriters)
        else:
            rewriters.append(
                InsertModuleCallAfter(quant_identity_name, inp_node))
Example 4
def extract_groups(self, graph_model: GraphModule):
    groups = []
    for node in graph_model.graph.nodes:
        if (_is_supported_module(graph_model, node)
                and node.next.op == 'call_module'
                and isinstance(get_module(graph_model, node.next.target),
                               _batch_norm)):
            node_next = node.next.next
            while _is_scale_invariant_module(
                    graph_model, node_next) or _is_reshaping_op(node_next):
                node_next = node_next.next
            if _is_supported_module(graph_model, node_next):
                group = (get_module(graph_model, node.target),
                         get_module(graph_model, node.next.target),
                         (node_next.target,
                          get_module(graph_model, node_next.target)))
                groups.append(group)
    return groups
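Each group is a (source module, following batch-norm module, (sink target, sink module)) triple. A hedged inspection sketch, where pass_instance stands for whatever object this method belongs to (hypothetical name) and graph_model is an FX-traced model:

for src, bn, (sink_name, sink_module) in pass_instance.extract_groups(graph_model):
    # print the module types along each source -> batch norm -> sink chain
    print(type(src).__name__, '->', type(bn).__name__, '->', sink_name,
          type(sink_module).__name__)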
Example 5
def are_inputs_quantized(model, node, quantized_modules_list, same_sign):
    # Collect the tensor quantizers feeding `node`; True only when every input is
    # quantized and all inputs share a single tensor quantizer (None marks an
    # unquantized input)
    for inp_node in node.all_input_nodes:
        if inp_node.op == 'call_module':
            inp_module = get_module(model, inp_node.target)
            if isinstance(inp_module, (nn.ReLU, nn.ReLU6)):
                quantized_modules_list.append(None)
            elif isinstance(inp_module, tuple(SIGN_PRESERVING_MODULES)):
                are_inputs_quantized(model, inp_node, quantized_modules_list,
                                     same_sign)
            elif isinstance(
                    inp_module,
                (qnn.QuantReLU, qnn.QuantIdentity, qnn.QuantHardTanh)):
                tq = inp_module.act_quant.fused_activation_quant_proxy.tensor_quant
                if _tensor_quant_in_list(tq, quantized_modules_list,
                                         same_sign):
                    continue
                quantized_modules_list.append(tq)
            elif isinstance(inp_module, FlexMLQuantLeakyReLU):
                tq = inp_module.output_quant.act_quant.fused_activation_quant_proxy.tensor_quant
                if _tensor_quant_in_list(tq, quantized_modules_list,
                                         same_sign):
                    continue
                quantized_modules_list.append(tq)
            else:
                quantized_modules_list.append(None)
        elif inp_node.op == 'call_function':
            if inp_node.target in [
                    torch.reshape, torch.flatten, torch.transpose
            ]:
                are_inputs_quantized(model, inp_node, quantized_modules_list,
                                     same_sign)
            elif inp_node.target is torch.cat:
                are_inputs_quantized(model, inp_node, quantized_modules_list,
                                     True)
            elif inp_node.target in ADD_FNS:
                are_inputs_quantized(model, inp_node, quantized_modules_list,
                                     False)
            else:
                quantized_modules_list.append(None)
        elif inp_node.op == 'call_method':
            if inp_node.target in [
                    'view', 'reshape', 'flatten', 't', 'permute'
            ]:
                are_inputs_quantized(model, inp_node, quantized_modules_list,
                                     same_sign)
            elif inp_node.target in ADD_METHODS:
                are_inputs_quantized(model, inp_node, quantized_modules_list,
                                     False)
            else:
                quantized_modules_list.append(None)
    if None in quantized_modules_list:
        return False
    elif len(quantized_modules_list) > 1:
        return False
    else:
        return True
Example 6
def cat_input_handler(model, node, quant_identity_name, quant_identity,
                      rewriters):
    # Align every input of a concatenation with `quant_identity` so that all
    # concatenated branches share the same quantizer
    for inp_node in node.all_input_nodes:
        if inp_node.op == 'call_module':
            module = get_module(model, inp_node.target)
            if isinstance(module, tuple(SIGN_PRESERVING_MODULES)):
                cat_input_handler(model, inp_node, quant_identity_name,
                                  quant_identity, rewriters)
            elif isinstance(module, qnn.QuantReLU):
                rewriter = ModuleToModuleByInstance(
                    module,
                    qnn.QuantReLU,
                    # WORKAROUND
                    # TODO act_quant=quant_identity.act_quant is currently broken
                    # because it overrides act_impl even though it shouldn't
                    signed=quant_identity.act_quant.is_signed,
                    narrow_range=quant_identity.act_quant.is_narrow_range,
                    tensor_quant=quant_identity.act_quant.
                    fused_activation_quant_proxy.tensor_quant)
                rewriters.append(rewriter)
            elif isinstance(module, qnn.QuantIdentity):
                rewriter = ModuleInstanceToModuleInstance(
                    module, quant_identity)
                rewriters.append(rewriter)
            elif isinstance(module, FlexMLQuantLeakyReLU):
                rewriter = ModuleToModuleByInstance(
                    module, FlexMLQuantLeakyReLU, output_quant=quant_identity)
                rewriters.append(rewriter)
            else:
                rewriters.append(
                    InsertModuleCallAfter(quant_identity_name, inp_node))
        elif inp_node.op == 'call_function' and inp_node.target in [
                torch.flatten, torch.reshape, torch.transpose
        ]:
            cat_input_handler(model, inp_node, quant_identity_name,
                              quant_identity, rewriters)
        elif inp_node.op == 'call_function' and inp_node.target is torch.cat:
            cat_input_handler(model, inp_node, quant_identity_name,
                              quant_identity, rewriters)
        elif inp_node.op == 'call_method' and inp_node.target in [
                'view', 'reshape', 'flatten', 'transpose'
        ]:
            cat_input_handler(model, inp_node, quant_identity_name,
                              quant_identity, rewriters)
        else:
            rewriters.append(
                InsertModuleCallAfter(quant_identity_name, inp_node))
Example 7
def flexml_wbiol_handler(model):
    # Quantize weight/bias/input/output layers: insert output quantizers where
    # needed and swap each module for its quantized counterpart
    rewriters = []
    for node in model.graph.nodes:
        if node.op == 'call_module':
            module = get_module(model, node.target)
            if isinstance(module, tuple(QUANT_WBIOL_MAP.keys())):
                output_quant_handler(model,
                                     node,
                                     rewriters,
                                     is_sign_preserving=False)
                rewriter = ModuleToModuleByInstance(
                    module,
                    QUANT_WBIOL_MAP[type(module)],
                    weight_quant=Int8WeightPerTensorFixedPoint,
                    bias_quant=Int16Bias,
                    return_quant_tensor=True)
                rewriters.append(rewriter)
    for rewriter in rewriters:
        model = rewriter.apply(model)
    return model
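A hedged sketch of how the activation handler (Example 1) and this handler might compose into a single FlexML-style quantization pass; the handler names come from the snippets above, everything else is illustrative:

import torch.nn as nn
from torch.fx import symbolic_trace

float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
graph_model = symbolic_trace(float_model)
graph_model = flexml_act_handler(graph_model)    # Example 1: quantize activations
graph_model = flexml_wbiol_handler(graph_model)  # this example: quantize Conv/Linear weights and biases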
Example 8
def _extract_regions(graph_model: GraphModule):
    # Collect unique (sources, sinks) regions by walking forward from each supported layer
    regions = set()
    for node in graph_model.graph.nodes:
        if node.op == 'call_module':
            module = get_module(graph_model, node.target)
            if isinstance(module, _supported_layers):
                srcs, sinks = {node.target}, set()
                walk_region(graph_model,
                            node,
                            set(),
                            srcs,
                            sinks,
                            walk_forward=True)
                if sinks:
                    # each region should appear only once, so to make it hashable
                    # we convert srcs and sinks to ordered lists first, and then to tuples
                    regions.add((tuple(sorted(srcs)), tuple(sorted(sinks))))
    # for clarity, sort by the name of the first source
    regions = sorted(regions, key=lambda region: region[0][0])
    return regions
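A hedged inspection sketch, assuming graph_model is an FX-traced model: each region is a pair of sorted tuples of module targets.

for srcs, sinks in _extract_regions(graph_model):
    print('sources:', srcs, '-> sinks:', sinks)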
Example 9
def output_quant_handler(model, node, rewriters, is_sign_preserving):
    # Insert a quantizer on the output of `node` when at least one consumer does not
    # already re-quantize it; unsigned if the node preserves sign and its inputs are unsigned
    quant_module = None
    quant_module_name = None
    for user in node.users:
        output_quant = True
        if user.op == 'call_module':
            user_module = get_module(model, user.target)
            if isinstance(
                    user_module,
                (qnn.QuantReLU, qnn.QuantIdentity, FlexMLQuantLeakyReLU)):
                output_quant = False
        if output_quant:
            if quant_module_name is None and quant_module is None:
                if is_sign_preserving and are_inputs_unsigned(model, node, []):
                    quant_module = qnn.QuantIdentity(
                        act_quant=Uint8ActPerTensorFixedPoint,
                        return_quant_tensor=True)
                else:
                    quant_module = qnn.QuantIdentity(
                        act_quant=Int8ActPerTensorFixedPoint,
                        return_quant_tensor=True)
                quant_module_name = node.name + '_output_quant'
                model.add_module(quant_module_name, quant_module)
            rewriters.append(InsertModuleCallAfter(quant_module_name, node))
Example 10
def _is_scale_invariant_module(graph_model, node):
    return node.op == 'call_module' and isinstance(
        get_module(graph_model, node.target), _scale_invariant_layers)
Example 11
def _is_supported_module(graph_model, node):
    return node.op == 'call_module' and isinstance(
        get_module(graph_model, node.target), _supported_layers)
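The snippets above share a common set of imports. A plausible header is sketched below; the exact paths may differ across Brevitas versions, and names such as FlexMLQuantLeakyReLU, SIGN_PRESERVING_MODULES, ADD_FNS, ADD_METHODS, QUANT_WBIOL_MAP, Uint8ActPerTensorFixedPointMaxInit, walk_region, _tensor_quant_in_list, _is_reshaping_op, _supported_layers, _scale_invariant_layers and _batch_norm are defined in the module the snippets were excerpted from.

# Assumed imports; verify the paths against the Brevitas version in use
import torch
import torch.nn as nn
from torch.fx import GraphModule

import brevitas.nn as qnn
from brevitas.graph.utils import get_module
from brevitas.graph.base import (
    InsertModuleCallAfter, ModuleInstanceToModuleInstance, ModuleToModuleByInstance)
from brevitas.quant.fixed_point import (
    Int8ActPerTensorFixedPoint, Int8WeightPerTensorFixedPoint, Uint8ActPerTensorFixedPoint)
from brevitas.quant.scaled_int import Int16Bias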