def flexml_act_handler(model):
    """Rewrite float activations into their FlexML quantized counterparts."""
    rewriters = []
    for node in model.graph.nodes:
        if node.op == 'call_module':
            module = get_module(model, node.target)
            if isinstance(module, nn.ReLU):
                rewriter = ModuleToModuleByInstance(
                    module,
                    qnn.QuantReLU,
                    act_quant=Uint8ActPerTensorFixedPoint,
                    return_quant_tensor=True)
                rewriters.append(rewriter)
            elif isinstance(module, nn.ReLU6):
                rewriter = ModuleToModuleByInstance(
                    module,
                    qnn.QuantReLU,
                    act_quant=Uint8ActPerTensorFixedPointMaxInit,
                    max_val=6.,
                    return_quant_tensor=True)
                rewriters.append(rewriter)
            elif isinstance(module, nn.LeakyReLU):
                rewriter = ModuleToModuleByInstance(module, FlexMLQuantLeakyReLU)
                rewriters.append(rewriter)
    for rewriter in rewriters:
        model = rewriter.apply(model)
    return model
def are_inputs_unsigned(model, node, is_unsigned_list):
    """Recursively check whether all the inputs feeding `node` are known to be unsigned."""
    for inp_node in node.all_input_nodes:
        if inp_node.op == 'call_module':
            inp_module = get_module(model, inp_node.target)
            if isinstance(inp_module, (nn.ReLU, nn.ReLU6)):
                is_unsigned_list.append(True)
            elif isinstance(inp_module, tuple(SIGN_PRESERVING_MODULES)):
                are_inputs_unsigned(model, inp_node, is_unsigned_list)
            elif isinstance(inp_module, (qnn.QuantReLU, qnn.QuantIdentity, qnn.QuantHardTanh)):
                is_unsigned_list.append(not inp_module.is_quant_act_signed)
            else:
                is_unsigned_list.append(False)
        elif inp_node.op == 'call_function':
            if inp_node.target in [torch.reshape, torch.flatten, torch.transpose, torch.cat] + ADD_FNS:
                are_inputs_unsigned(model, inp_node, is_unsigned_list)
            else:
                is_unsigned_list.append(False)
        elif inp_node.op == 'call_method':
            if inp_node.target in ['view', 'reshape', 'flatten', 't', 'permute'] + ADD_METHODS:
                are_inputs_unsigned(model, inp_node, is_unsigned_list)
            else:
                is_unsigned_list.append(False)
    return all(is_unsigned_list)
def add_input_handler(model, node, quant_identity_name, quant_identity, rewriters):
    for inp_node in node.all_input_nodes:
        if inp_node.op == 'call_module':
            module = get_module(model, inp_node.target)
            if isinstance(module, tuple(SIGN_PRESERVING_MODULES)):
                add_input_handler(
                    model, inp_node, quant_identity_name, quant_identity, rewriters)
            elif isinstance(module, qnn.QuantReLU):
                rewriter = ModuleToModuleByInstance(
                    module,
                    qnn.QuantReLU,
                    act_quant=Uint8ActPerTensorFixedPoint,
                    scaling_impl=quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl,
                    int_scaling_impl=quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.int_scaling_impl,
                    return_quant_tensor=True)
                rewriters.append(rewriter)
            elif isinstance(module, qnn.QuantIdentity):
                if module.is_quant_act_signed == quant_identity.is_quant_act_signed:
                    rewriters.append(ModuleInstanceToModuleInstance(module, quant_identity))
                else:
                    assert not module.is_quant_act_signed and quant_identity.is_quant_act_signed
                    rewriter = ModuleToModuleByInstance(
                        module,
                        qnn.QuantIdentity,
                        act_quant=Uint8ActPerTensorFixedPoint,
                        scaling_impl=quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl,
                        int_scaling_impl=quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant.int_scaling_impl,
                        return_quant_tensor=True)
                    rewriters.append(rewriter)
            elif isinstance(module, FlexMLQuantLeakyReLU):
                rewriter = ModuleToModuleByInstance(
                    module, FlexMLQuantLeakyReLU, output_quant=quant_identity)
                rewriters.append(rewriter)
            else:
                rewriters.append(InsertModuleCallAfter(quant_identity_name, inp_node))
        elif inp_node.op == 'call_function' and inp_node.target in [
                torch.flatten, torch.reshape, torch.transpose]:
            add_input_handler(
                model, inp_node, quant_identity_name, quant_identity, rewriters)
        elif inp_node.op == 'call_function' and inp_node.target is torch.cat:
            cat_input_handler(
                model, inp_node, quant_identity_name, quant_identity, rewriters)
        elif inp_node.op == 'call_method' and inp_node.target in [
                'view', 'reshape', 'flatten', 'transpose']:
            add_input_handler(
                model, inp_node, quant_identity_name, quant_identity, rewriters)
        else:
            rewriters.append(InsertModuleCallAfter(quant_identity_name, inp_node))
def extract_groups(self, graph_model: GraphModule):
    groups = []
    for node in graph_model.graph.nodes:
        if (_is_supported_module(graph_model, node) and node.next.op == 'call_module' and
                isinstance(get_module(graph_model, node.next.target), _batch_norm)):
            node_next = node.next.next
            while _is_scale_invariant_module(graph_model, node_next) or _is_reshaping_op(node_next):
                node_next = node_next.next
            if _is_supported_module(graph_model, node_next):
                group = (
                    get_module(graph_model, node.target),
                    get_module(graph_model, node.next.target),
                    (node_next.target, get_module(graph_model, node_next.target)))
                groups.append(group)
    return groups
def are_inputs_quantized(model, node, quantized_modules_list, same_sign):
    for inp_node in node.all_input_nodes:
        if inp_node.op == 'call_module':
            inp_module = get_module(model, inp_node.target)
            if isinstance(inp_module, (nn.ReLU, nn.ReLU6)):
                quantized_modules_list.append(None)
            elif isinstance(inp_module, tuple(SIGN_PRESERVING_MODULES)):
                are_inputs_quantized(model, inp_node, quantized_modules_list, same_sign)
            elif isinstance(inp_module, (qnn.QuantReLU, qnn.QuantIdentity, qnn.QuantHardTanh)):
                tq = inp_module.act_quant.fused_activation_quant_proxy.tensor_quant
                if _tensor_quant_in_list(tq, quantized_modules_list, same_sign):
                    continue
                quantized_modules_list.append(tq)
            elif isinstance(inp_module, FlexMLQuantLeakyReLU):
                tq = inp_module.output_quant.act_quant.fused_activation_quant_proxy.tensor_quant
                if _tensor_quant_in_list(tq, quantized_modules_list, same_sign):
                    continue
                quantized_modules_list.append(tq)
            else:
                quantized_modules_list.append(None)
        elif inp_node.op == 'call_function':
            if inp_node.target in [torch.reshape, torch.flatten, torch.transpose]:
                are_inputs_quantized(model, inp_node, quantized_modules_list, same_sign)
            elif inp_node.target is torch.cat:
                are_inputs_quantized(model, inp_node, quantized_modules_list, True)
            elif inp_node.target in ADD_FNS:
                are_inputs_quantized(model, inp_node, quantized_modules_list, False)
            else:
                quantized_modules_list.append(None)
        elif inp_node.op == 'call_method':
            if inp_node.target in ['view', 'reshape', 'flatten', 't', 'permute']:
                are_inputs_quantized(model, inp_node, quantized_modules_list, same_sign)
            elif inp_node.target in ADD_METHODS:
                are_inputs_quantized(model, inp_node, quantized_modules_list, False)
            else:
                quantized_modules_list.append(None)
    if None in quantized_modules_list:
        return False
    elif len(quantized_modules_list) > 1:
        return False
    else:
        return True
def cat_input_handler(model, node, quant_identity_name, quant_identity, rewriters):
    for inp_node in node.all_input_nodes:
        if inp_node.op == 'call_module':
            module = get_module(model, inp_node.target)
            if isinstance(module, tuple(SIGN_PRESERVING_MODULES)):
                cat_input_handler(
                    model, inp_node, quant_identity_name, quant_identity, rewriters)
            elif isinstance(module, qnn.QuantReLU):
                rewriter = ModuleToModuleByInstance(
                    module,
                    qnn.QuantReLU,
                    # WORKAROUND
                    # TODO act_quant=quant_identity.act_quant is currently broken
                    # because it overrides act_impl even though it shouldn't
                    signed=quant_identity.act_quant.is_signed,
                    narrow_range=quant_identity.act_quant.is_narrow_range,
                    tensor_quant=quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant)
                rewriters.append(rewriter)
            elif isinstance(module, qnn.QuantIdentity):
                rewriter = ModuleInstanceToModuleInstance(module, quant_identity)
                rewriters.append(rewriter)
            elif isinstance(module, FlexMLQuantLeakyReLU):
                rewriter = ModuleToModuleByInstance(
                    module, FlexMLQuantLeakyReLU, output_quant=quant_identity)
                rewriters.append(rewriter)
            else:
                rewriters.append(InsertModuleCallAfter(quant_identity_name, inp_node))
        elif inp_node.op == 'call_function' and inp_node.target in [
                torch.flatten, torch.reshape, torch.transpose]:
            cat_input_handler(
                model, inp_node, quant_identity_name, quant_identity, rewriters)
        elif inp_node.op == 'call_function' and inp_node.target is torch.cat:
            cat_input_handler(
                model, inp_node, quant_identity_name, quant_identity, rewriters)
        elif inp_node.op == 'call_method' and inp_node.target in [
                'view', 'reshape', 'flatten', 'transpose']:
            cat_input_handler(
                model, inp_node, quant_identity_name, quant_identity, rewriters)
        else:
            rewriters.append(InsertModuleCallAfter(quant_identity_name, inp_node))
def flexml_wbiol_handler(model):
    """Rewrite weight/bias/input/output layers (WBIOL) into their quantized counterparts."""
    rewriters = []
    for node in model.graph.nodes:
        if node.op == 'call_module':
            module = get_module(model, node.target)
            if isinstance(module, tuple(QUANT_WBIOL_MAP.keys())):
                output_quant_handler(model, node, rewriters, is_sign_preserving=False)
                rewriter = ModuleToModuleByInstance(
                    module,
                    QUANT_WBIOL_MAP[type(module)],
                    weight_quant=Int8WeightPerTensorFixedPoint,
                    bias_quant=Int16Bias,
                    return_quant_tensor=True)
                rewriters.append(rewriter)
    for rewriter in rewriters:
        model = rewriter.apply(model)
    return model
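# Illustrative usage sketch (not part of the library code above): it shows one plausible
# way to chain flexml_act_handler and flexml_wbiol_handler on a torch.fx-traced model.
# `_ExampleNet` and `_example_flexml_quantize` are hypothetical names introduced purely
# for illustration; the sketch assumes the module-level names referenced above
# (get_module, QUANT_WBIOL_MAP, the quantizers, the rewriter classes, ...) are available,
# as they are in the surrounding file.
def _example_flexml_quantize():
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class _ExampleNet(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3)
            self.act = nn.ReLU()

        def forward(self, x):
            return self.act(self.conv(x))

    # Trace to a GraphModule so the handlers can walk model.graph.nodes
    model = symbolic_trace(_ExampleNet())
    # Activations first: once nn.ReLU has been rewritten to qnn.QuantReLU,
    # output_quant_handler sees a quantized user of the conv and skips inserting
    # a redundant output QuantIdentity.
    model = flexml_act_handler(model)
    model = flexml_wbiol_handler(model)
    return model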
def _extract_regions(graph_model: GraphModule):
    regions = set()
    for node in graph_model.graph.nodes:
        if node.op == 'call_module':
            module = get_module(graph_model, node.target)
            if isinstance(module, _supported_layers):
                srcs, sinks = {node.target}, set()
                walk_region(graph_model, node, set(), srcs, sinks, walk_forward=True)
                if sinks:
                    # each region should appear only once, so to make it hashable
                    # we convert srcs and sinks to ordered lists first, and then to tuples
                    regions.add((tuple(sorted(srcs)), tuple(sorted(sinks))))
    # for clarity, sort by the name of the first source
    regions = sorted(regions, key=lambda region: region[0][0])
    return regions
def output_quant_handler(model, node, rewriters, is_sign_preserving):
    """Insert a QuantIdentity after `node` when at least one of its users is not already quantized."""
    quant_module = None
    quant_module_name = None
    for user in node.users:
        output_quant = True
        if user.op == 'call_module':
            user_module = get_module(model, user.target)
            if isinstance(user_module, (qnn.QuantReLU, qnn.QuantIdentity, FlexMLQuantLeakyReLU)):
                output_quant = False
        if output_quant:
            if quant_module_name is None and quant_module is None:
                if is_sign_preserving and are_inputs_unsigned(model, node, []):
                    quant_module = qnn.QuantIdentity(
                        act_quant=Uint8ActPerTensorFixedPoint, return_quant_tensor=True)
                else:
                    quant_module = qnn.QuantIdentity(
                        act_quant=Int8ActPerTensorFixedPoint, return_quant_tensor=True)
                quant_module_name = node.name + '_output_quant'
                model.add_module(quant_module_name, quant_module)
                rewriters.append(InsertModuleCallAfter(quant_module_name, node))
def _is_scale_invariant_module(graph_model, node):
    return node.op == 'call_module' and isinstance(
        get_module(graph_model, node.target), _scale_invariant_layers)
def _is_supported_module(graph_model, node):
    return node.op == 'call_module' and isinstance(
        get_module(graph_model, node.target), _supported_layers)
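# Illustrative usage sketch (not part of the library code above): `_example_extract_regions`
# is a hypothetical helper showing how _extract_regions could be exercised on a traced model.
# It assumes _supported_layers, _scale_invariant_layers and walk_region are defined at module
# level, as they are in the surrounding file, and makes no claim about the exact regions returned.
def _example_extract_regions():
    import torch.nn as nn
    from torch.fx import symbolic_trace

    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3),
        nn.ReLU(),
        nn.Conv2d(8, 8, kernel_size=3))
    graph_model = symbolic_trace(model)
    regions = _extract_regions(graph_model)
    # Each region is a (sources, sinks) pair of module names, e.g. (('0',), ('2',))
    # if ReLU is listed among the scale-invariant layers.
    for srcs, sinks in regions:
        print('sources:', srcs, 'sinks:', sinks)
    return regions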