def get_deploy_graph_list(quant_model, nndct_graph):
  """Produce the list of deployable sub-graphs for `nndct_graph`.

  Runs the dev-graph optimization pipeline (redundant-op stripping,
  op-attribute update, constant folding, layout transform) while keeping
  the torch module's blob data in sync, then partitions the optimized
  graph by quantization part.
  """
  optimizer = DevGraphOptimizer(nndct_graph)

  # Sync model data with the dev graph before optimization passes run.
  connect_module_with_graph(quant_model, optimizer.dev_graph, recover_param=False)
  update_nndct_blob_data(quant_model, optimizer.dev_graph)

  optimizer.strip_redundant_ops()
  optimizer.update_op_attrs()
  optimizer.constant_folding()
  optimizer.layout_tranform()  # NOTE(review): spelling follows the DevGraphOptimizer API

  # Re-sync blob data after the graph has been rewritten, then restore the
  # caller's original module/graph association.
  connect_module_with_graph(quant_model, optimizer.dev_graph, recover_param=False)
  update_nndct_blob_data(quant_model, optimizer.dev_graph)
  connect_module_with_graph(quant_model, nndct_graph, recover_param=False)

  if NndctOption.nndct_parse_debug.value >= 3:
    NndctDebugLogger.write(f"\nfrozen dev graph:\n{optimizer.dev_graph}")

  return optimizer.partition_by_quant_part()
def __call__(self, graph_name, module, input_args):
  """Parse a torch module into an nndct Graph.

  Traces `module` into a torch raw graph, converts each raw node (and its
  op) into nndct form, registers the graph's end tensors, then connects
  nodes and loads parameter data from `module`.
  """
  handler = TorchGraphHandler()
  torch_graph = handler.build_torch_graph(graph_name, module, input_args)

  self._nndct_graph = Graph(graph_name=torch_graph.name)
  convert_node = NodeConvertor()
  create_op = OpCreator()
  for torch_node in torch_graph.nodes:
    node = convert_node(self, torch_graph, torch_node)
    # NodeConvertor may return a falsy value for nodes that are dropped.
    if node:
      self._nndct_graph.add_node(node)
      node.op = create_op(self, torch_graph, torch_node)

  # Register every returned value of the raw graph as an end tensor.
  for ret_name in torch_graph.ret_values().keys():
    end_tensor = self._nndct_graph.tensor(
        get_full_name(self._nndct_graph.name, ret_name))
    self._nndct_graph.add_end_tensor(end_tensor)

  self._convert_blob_tensor_type()
  self._nndct_graph.connect_nodes()
  self._load_data(module)

  if NndctOption.nndct_parse_debug.value >= 2:
    NndctDebugLogger.write(f"nndct raw graph:\n{self._nndct_graph}")
  return self._nndct_graph
def get_frozen_graph(self):
  """Freeze and return the dev graph.

  Applies, in order: tensor-layout inference, redundant-op stripping and
  constant folding; optionally dumps the frozen graph to the debug log.
  """
  for freeze_pass in (self._infer_tensor_layout,
                      self._strip_redundant_ops,
                      self._constant_folding):
    freeze_pass()
  frozen = self._dev_graph
  if NndctOption.nndct_parse_debug.value >= 3:
    NndctDebugLogger.write(f"\nfrozen dev graph:\n{frozen}")
  return frozen
def build_torch_graph(self, graph_name, module, input_args, train=False):
  """Trace `module` and build an optimized torch raw graph.

  Args:
    graph_name: name for the resulting graph.
    module: torch module to trace.
    input_args: example inputs used for tracing.
    train: whether to trace in training mode (default False).

  Returns:
    Tuple of (optimized raw graph, raw parameters).
  """
  self._module = module
  NndctScreenLogger().info("Start to trace model...")
  fw_graph, params = self._trace_graph_from_model(input_args, train)
  NndctScreenLogger().info("Finish tracing.")

  # Collect the distinct jit node kinds (namespace prefix stripped).
  self._node_kinds = {
      node.kind().split(":")[-1] for node in fw_graph.nodes()
  }
  if NndctOption.nndct_parse_debug.value >= 1:
    NndctDebugLogger.write(f"jit graph:\n{fw_graph}")
    NndctDebugLogger.write(f"\nparsing nodes types:\n{self._node_kinds}\n")

  raw_graph, raw_params = self._build_raw_graph(graph_name, fw_graph, params)
  if NndctOption.nndct_parse_debug.value >= 2:
    NndctDebugLogger.write(f"\ntorch raw graph:\n{raw_graph}")

  opt_graph = self._opt_raw_graph(raw_graph)
  if NndctOption.nndct_parse_debug.value >= 2:
    # BUG FIX: previously logged `raw_graph` under the "torch opt graph"
    # label; log the optimized graph that is actually returned.
    NndctDebugLogger.write(f"\ntorch opt graph:\n{opt_graph}")

  if NndctOption.nndct_parse_debug.value >= 3:
    self._check_stub_topology(opt_graph)
  return opt_graph, raw_params
def __call__(self, graph_name, module, input_args):
  """Convert a torch module into an nndct Graph and load its data.

  Also records the target device info in the global map before op
  conversion.
  """
  torch_graph = TorchGraphHandler().build_torch_graph(
      graph_name, module, input_args)
  GLOBAL_MAP.set_map(NNDCT_KEYS.DEVICE,
                     self._get_device_info(module, input_args))

  NndctScreenLogger().info("Processing ops...")
  graph = self._convert_graph(torch_graph)
  unknown_op_type_check(graph)

  self._convert_blob_tensor_type(graph)
  self._load_data(graph, module)

  if NndctOption.nndct_parse_debug.value >= 2:
    NndctDebugLogger.write(f"nndct raw graph:\n{graph}")
  return graph
def _init_torch_nndct_module():
  """Consume nndct options from sys.argv and apply their side effects.

  Recognized nndct options are removed from sys.argv so downstream
  argument parsers never see them; option-dependent defaults (logging,
  quantization off, parameter correction, equalization) are then applied.
  """
  consumed = []
  argv_snapshot = _copy.copy(_sys.argv)
  for pos, opt in enumerate(_sys.argv[1:], 1):
    consumed.extend(
        option_util.add_valid_nndct_option(argv_snapshot, opt, pos, 'torch'))
  # Strip every recognized nndct option from the real argv.
  for opt in consumed:
    _sys.argv.remove(opt)

  if NndctOption.nndct_option_list.value:
    print("Usage: python file [option]")
    print("Nndct options:")
    for opt in option_util.get_all_options():
      print(opt)
    _sys.exit()

  if NndctOption.nndct_help.value:
    # TODO: register api info
    _sys.exit()

  # Any parse-debug level forces debug logging on.
  if NndctOption.nndct_parse_debug.value:
    option_util.set_option_value("nndct_logging_level", 1)
  if NndctOption.nndct_logging_level.value > 0:
    NndctDebugLogger("nndct_debug.log")

  # Quant-off disables every quantization-time optimization.
  if NndctOption.nndct_quant_off.value:
    option_util.set_option_value("nndct_quant_opt", 0)
    option_util.set_option_value("nndct_param_corr", False)
    option_util.set_option_value("nndct_equalization", False)
    option_util.set_option_value("nndct_wes_in_cle", False)

  # When the user neither set these options explicitly (no '_value') nor
  # requested any quant optimization level, default them off.
  if (not hasattr(NndctOption.nndct_param_corr, '_value')) and (
      NndctOption.nndct_quant_opt.value <= 0):
    option_util.set_option_value("nndct_param_corr", False)
  if (not hasattr(NndctOption.nndct_equalization, '_value')) and (
      NndctOption.nndct_quant_opt.value <= 0):
    option_util.set_option_value("nndct_equalization", False)
def get_node_schema(node):
  """Return the torch operator schema matching `node`, if any.

  Iterates every registered schema for the node's operator and returns
  the first one that matches; warns (for aten ops only) and implicitly
  returns None when no schema matches.
  """
  op_name = node_type(node)
  candidates = torch._C._jit_get_schemas_for_operator(op_name)
  for candidate in candidates:
    if not is_schema_matching(node, candidate):
      continue
    if NndctOption.nndct_parse_debug.value >= 1:
      NndctDebugLogger.write(
          f"%{get_node_outputs_name(node)[0]} signature: {parse_node_signature(node)}\n"
      )
      helper = SchemaHelper(candidate)
      NndctDebugLogger.write(
          f"matched schema: {helper.toString()}\n")
    return candidate
  if op_name.split("::")[0] == "aten":
    NndctScreenLogger().warning(
        f"Can't find schema for {node}.If you can get quantizable model successfully, please ignore it.\n"
    )
def parse_module(module: torch.nn.Module,
                 input_args: Union[torch.Tensor, Sequence[Any]],
                 enable_opt: bool = True,
                 graph_name: Optional[str] = None) -> Graph:
  """Parse a torch module into an nndct Graph.

  Args:
    module: torch module to parse; ReLU6 layers may be replaced first
      when equalization is enabled.
    input_args: example inputs used for tracing.
    enable_opt: run the quantization optimizer on the parsed graph.
    graph_name: optional graph name; defaults to the module's name.

  Returns:
    The (optionally optimized) nndct Graph.
  """
  if NndctOption.nndct_equalization.value:
    replace_mode = NndctOption.nndct_relu6_replace.value
    if replace_mode == 'reluk':
      replace_relu6_with_reluk(module)
    elif replace_mode == 'relu':
      replace_relu6_with_relu(module)

  parser = TorchParser()
  graph = parser(module._get_name() if graph_name is None else graph_name,
                 module, input_args)

  if enable_opt:
    optimizer = QuantOptimizer()
    graph = optimizer(graph)

  if NndctOption.nndct_parse_debug.value >= 3:
    NndctDebugLogger.write(f"nndct quant graph:\n{graph}")
  return graph
def get_deploy_graph_list(quant_model, nndct_graph):
  """Produce deployable sub-graphs for `nndct_graph` (frozen-graph variant).

  Infers tensor layouts and strips redundant ops, syncs the torch
  module's blob data with the frozen dev graph, folds constants, then
  partitions the result by quantization part.
  """
  optimizer = DevGraphOptimizer(nndct_graph)
  optimizer.infer_tensor_layout()
  optimizer.strip_redundant_ops()

  # Sync model data with the frozen dev graph, then restore the caller's
  # original module/graph association.
  connect_module_with_graph(quant_model, optimizer.frozen_graph, recover_param=False)
  update_nndct_blob_data(quant_model, optimizer.frozen_graph)
  connect_module_with_graph(quant_model, nndct_graph, recover_param=False)

  optimizer.constant_folding()
  if NndctOption.nndct_parse_debug.value >= 3:
    NndctDebugLogger.write(
        f"\nfrozen dev graph:\n{optimizer.frozen_graph}")
  return optimizer.partition_by_quant_part()
def __call__(self, graph_name, module, input_args):
  """Parse `module` into an nndct Graph, handling params and sub-blocks.

  Builds the torch raw graph, converts parameters and ops, collects all
  block sub-graphs and reorders their nodes, then loads tensor data from
  the module.
  """
  handler = create_graph_handler(module)
  torch_graph, torch_params = handler.build_torch_graph(
      graph_name, module, input_args)
  self._convert_params(torch_params, torch_graph.name)

  NndctScreenLogger().info("Processing ops...")
  graph = self._convert_graph(
      torch_graph, self._get_device_info(module, input_args))
  unknown_op_type_check(graph)

  # Gather the top graph plus every nested block sub-graph, then put
  # their nodes into a consistent order.
  all_graphs = [graph]
  collect_all_blocks(graph, all_graphs)
  reorder_multi_subgraph_nodes(all_graphs)

  self._convert_blob_tensor_type(graph)
  self._load_data(graph, module)

  if NndctOption.nndct_parse_debug.value >= 2:
    NndctDebugLogger.write(f"nndct raw graph:\n{graph}")
  return graph
def build_torch_graph(self, graph_name, module, input_args, train=False):
  """Build an optimized torch raw graph from `module`.

  Chooses the graph-extraction path based on the module's type and the
  installed torch version: traced/scripted modules reuse their existing
  jit graph; plain nn.Modules are scripted, traced, or plainly traced
  depending on the nndct_jit_script / nndct_jit_trace options.

  Note: `_opt_raw_graph` here is called for its side effect — the raw
  graph is optimized in place and then returned.
  """
  self._module = module
  # TracedModule / ScriptModule already carry a jit graph (torch >= 1.9).
  if get_torch_version() >= 190 and isinstance(self._module, torch.jit._trace.TracedModule):
    fw_graph = self._get_graph_from_trace_script(input_args)
  elif get_torch_version() >= 190 and isinstance(self._module, torch.jit.ScriptModule):
    fw_graph = self._get_graph_from_script(input_args)
  elif isinstance(self._module, torch.nn.Module):
    if NndctOption.nndct_jit_script.value and get_torch_version() >= 190:
      NndctScreenLogger().info("Start to convert model to jit script...")
      try:
        self._module = torch.jit.script(self._module.eval())
      except Exception as e:
        NndctScreenLogger().error(str(e))
        sys.exit("Failed to convert nn.module to jit script.")
      fw_graph = self._get_graph_from_script(input_args)
      NndctScreenLogger().info("Finish converting.")
    elif NndctOption.nndct_jit_trace.value and get_torch_version() >= 190:
      NndctScreenLogger().info("Start to trace model...")
      try:
        self._module = torch.jit.trace(self._module.eval(), input_args)
      except Exception as e:
        NndctScreenLogger().error(str(e))
        sys.exit("Failed to trace nn.module to jit script.")
      fw_graph = self._get_graph_from_trace_script(input_args)
      NndctScreenLogger().info("Finish tracing.")
    else:
      # Default path: trace directly from the model without jit wrapping.
      NndctScreenLogger().info("Start to trace model...")
      fw_graph = self._trace_graph_from_model(input_args, train)
      NndctScreenLogger().info("Finish tracing.")

  # Distinct jit node kinds with their namespace prefix stripped.
  self._node_kinds = {node.kind().split(":")[-1] for node in fw_graph.nodes()}
  if NndctOption.nndct_parse_debug.value >= 1:
    NndctDebugLogger.write(f"jit graph:\n{fw_graph}")
    NndctDebugLogger.write(f"\nparsing nodes types:\n{self._node_kinds}\n")
  raw_graph = self._create_raw_graph(graph_name, fw_graph)
  if NndctOption.nndct_parse_debug.value >= 2:
    NndctDebugLogger.write(f"\ntorch raw graph:\n{raw_graph}")
  # In-place optimization of raw_graph.
  self._opt_raw_graph(raw_graph)
  if NndctOption.nndct_parse_debug.value >= 2:
    NndctDebugLogger.write(f"\ntorch opt graph:\n{raw_graph}")
  if NndctOption.nndct_parse_debug.value >= 3:
    self._check_stub_topology(raw_graph)
  return raw_graph
def parse_module(module: Union[torch.nn.Module, torch.jit.ScriptModule],
                 input_args: Union[torch.Tensor, Sequence[Any]],
                 enable_opt: bool = True,
                 graph_name: Optional[str] = None) -> Graph:
  """Parse a torch module or ScriptModule into an nndct Graph.

  Args:
    module: module to parse; ReLU6 layers may be replaced first when
      equalization is enabled.
    input_args: example inputs used for tracing.
    enable_opt: run `quant_optimize` on the parsed graph.
    graph_name: optional graph name; defaults to the module's name.

  Returns:
    The (optionally optimized) nndct Graph.
  """
  if NndctOption.nndct_equalization.value:
    mode = NndctOption.nndct_relu6_replace.value
    if mode == 'reluk':
      replace_relu6_with_reluk(module)
    elif mode == 'relu':
      replace_relu6_with_relu(module)

  parser = TorchParser()
  graph = parser(
      _get_module_name(module) if graph_name is None else graph_name,
      module, input_args)

  if enable_opt:
    graph = quant_optimize(graph)

  if NndctOption.nndct_parse_debug.value >= 3:
    NndctDebugLogger.write(f"nndct quant graph:\n{graph}")
  return graph
def build_torch_graph(self, graph_name, module, input_args, train=False):
  """Trace `module` and build an optimized torch raw graph.

  Args:
    graph_name: name for the resulting graph.
    module: torch module to trace.
    input_args: example inputs used for tracing.
    train: whether to trace in training mode (default False).

  Returns:
    The optimized raw graph.
  """
  self._module = module
  fw_graph, params = self._trace_graph_from_model(input_args, train)

  # Collect the distinct jit node kinds (namespace prefix stripped).
  self._node_kinds = {
      node.kind().split(":")[-1] for node in fw_graph.nodes()
  }
  if NndctOption.nndct_parse_debug.value >= 1:
    NndctDebugLogger.write(f"jit graph:\n{fw_graph}")
    NndctDebugLogger.write(
        f"\nparsing nodes types:\n{self._node_kinds}")

  raw_graph = self._build_raw_graph(graph_name, fw_graph, params)
  if NndctOption.nndct_parse_debug.value >= 2:
    NndctDebugLogger.write(f"\ntorch raw graph:\n{raw_graph}")

  opt_graph = self._opt_raw_graph(raw_graph)
  if NndctOption.nndct_parse_debug.value >= 2:
    # BUG FIX: previously logged `raw_graph` under the "torch opt graph"
    # label; log the optimized graph that is actually returned.
    NndctDebugLogger.write(f"\ntorch opt graph:\n{opt_graph}")
  return opt_graph