Example No. 1
def get_deploy_graph_list(quant_model, nndct_graph):
    g_optmizer = DevGraphOptimizer(nndct_graph)
    # sync model data with dev graph
    connect_module_with_graph(quant_model,
                              g_optmizer.dev_graph,
                              recover_param=False)
    update_nndct_blob_data(quant_model, g_optmizer.dev_graph)
    g_optmizer.strip_redundant_ops()
    g_optmizer.update_op_attrs()
    g_optmizer.constant_folding()
    g_optmizer.layout_tranform()
    connect_module_with_graph(quant_model,
                              g_optmizer.dev_graph,
                              recover_param=False)
    update_nndct_blob_data(quant_model, g_optmizer.dev_graph)
    connect_module_with_graph(quant_model, nndct_graph, recover_param=False)
    # for node in g_optmizer._dev_graph.nodes:
    #   print(f"{node.name}, {node.op.type}, {node.out_tensors[0].layout}")

    if NndctOption.nndct_parse_debug.value >= 3:
        NndctDebugLogger.write(f"\nfrozen dev graph:\n{g_optmizer.dev_graph}")

    deploy_graphs = g_optmizer.partition_by_quant_part()

    return deploy_graphs
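
Usage sketch only: get_deploy_graph_list expects the quantized module together with the NNDCT graph parsed from it. The names quant_model and nndct_graph below are placeholders for objects produced earlier in a quantization flow, not definitions taken from this excerpt.

# quant_model: the quantized torch.nn.Module from the surrounding flow (placeholder)
# nndct_graph: the NNDCT Graph parsed from that module (placeholder)
deploy_graphs = get_deploy_graph_list(quant_model, nndct_graph)
for dev_graph in deploy_graphs:
    print(dev_graph)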
Example No. 2
    def __call__(self, graph_name, module, input_args):
        torch_graph_handler = TorchGraphHandler()
        raw_graph = torch_graph_handler.build_torch_graph(
            graph_name, module, input_args)
        self._nndct_graph = Graph(graph_name=raw_graph.name)
        node_convertor = NodeConvertor()
        op_creator = OpCreator()
        for raw_node in raw_graph.nodes:
            nndct_node = node_convertor(self, raw_graph, raw_node)
            if nndct_node:
                self._nndct_graph.add_node(nndct_node)
                nndct_node.op = op_creator(self, raw_graph, raw_node)

        for ret_value_name in raw_graph.ret_values().keys():
            end_tensor = self._nndct_graph.tensor(
                get_full_name(self._nndct_graph.name, ret_value_name))
            self._nndct_graph.add_end_tensor(end_tensor)

        self._convert_blob_tensor_type()
        self._nndct_graph.connect_nodes()
        self._load_data(module)
        if NndctOption.nndct_parse_debug.value >= 2:
            NndctDebugLogger.write(f"nndct raw graph:\n{self._nndct_graph}")
        # print(f"nndct graph:{self._nndct_graph}")
        return self._nndct_graph
Example No. 3
 def get_frozen_graph(self):
     self._infer_tensor_layout()
     self._strip_redundant_ops()
     self._constant_folding()
     if NndctOption.nndct_parse_debug.value >= 3:
         NndctDebugLogger.write(f"\nfrozen dev graph:\n{self._dev_graph}")
     return self._dev_graph
Example No. 4
    def build_torch_graph(self, graph_name, module, input_args, train=False):
        self._module = module
        NndctScreenLogger().info("Start to trace model...")
        fw_graph, params = self._trace_graph_from_model(input_args, train)
        NndctScreenLogger().info("Finish tracing.")

        self._node_kinds = {
            node.kind().split(":")[-1]
            for node in fw_graph.nodes()
        }
        if NndctOption.nndct_parse_debug.value >= 1:
            NndctDebugLogger.write(f"jit graph:\n{fw_graph}")
            NndctDebugLogger.write(
                f"\nparsing nodes types:\n{self._node_kinds}\n")

        raw_graph, raw_params = self._build_raw_graph(graph_name, fw_graph,
                                                      params)
        if NndctOption.nndct_parse_debug.value >= 2:
            NndctDebugLogger.write(f"\ntorch raw graph:\n{raw_graph}")
        opt_graph = self._opt_raw_graph(raw_graph)
        if NndctOption.nndct_parse_debug.value >= 2:
            NndctDebugLogger.write(f"\ntorch opt graph:\n{raw_graph}")

        if NndctOption.nndct_parse_debug.value >= 3:
            self._check_stub_topology(opt_graph)

        return opt_graph, raw_params
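
A minimal usage sketch of the tracing entry point above, assuming a toy model and input shape chosen purely for illustration; only the build_torch_graph signature comes from the excerpt.

import torch

handler = TorchGraphHandler()
# Toy model and dummy input are assumptions for illustration.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
dummy_input = torch.randn(1, 3, 32, 32)
raw_graph, raw_params = handler.build_torch_graph("toy_net", model, (dummy_input,))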
Example No. 5
 def __call__(self, graph_name, module, input_args):
     graph_handler = TorchGraphHandler()
     # graph_handler = create_graph_handler(module)
     raw_graph = graph_handler.build_torch_graph(graph_name, module,
                                                 input_args)
     GLOBAL_MAP.set_map(NNDCT_KEYS.DEVICE,
                        self._get_device_info(module, input_args))
     NndctScreenLogger().info("Processing ops...")
     nndct_graph = self._convert_graph(raw_graph)
     unknown_op_type_check(nndct_graph)
     self._convert_blob_tensor_type(nndct_graph)
     self._load_data(nndct_graph, module)
     if NndctOption.nndct_parse_debug.value >= 2:
         NndctDebugLogger.write(f"nndct raw graph:\n{nndct_graph}")
     return nndct_graph
Example No. 6
def _init_torch_nndct_module():
    # Collect the NNDCT-specific options found on the command line; they are
    # removed from sys.argv below so the user's script never sees them.
    _pop_out_argv = []
    argv = _copy.copy(_sys.argv)
    for cmd_pos, option in enumerate(_sys.argv[1:], 1):
        _pop_out_argv.extend(
            option_util.add_valid_nndct_option(argv, option, cmd_pos, 'torch'))

    for item in _pop_out_argv:
        _sys.argv.remove(item)

    if NndctOption.nndct_option_list.value:
        print("Usage: python file [option]")
        print("Nndct options:")
        for option in option_util.get_all_options():
            print(option)
        _sys.exit()

    if NndctOption.nndct_help.value:
        # TODO: register api info
        pass
        # print("Nndct API Description:")
        # if "__all__" in api.__dict__:
        #   for key in api.__dict__["__all__"]:
        #     item = api.__dict__[key]
        #     if inspect.isclass(item):
        #       print(f"\nclass {key}:\n{item.__doc__}")
        #       for method_name, method in item.__dict__.items():
        #         if (not (method_name.startswith("_") or
        #                  method_name.startswith("__")) and
        #             inspect.isfunction(method) and method.__doc__ is not None):
        #           print(
        #               f"\n def {method_name}{inspect.signature(method)}:\n {method.__doc__}"
        #           )

        #     elif inspect.isfunction(item):
        #       print(f"\ndef {key}{inspect.signature(item)}:\n{item.__doc__}")
        _sys.exit()

    if NndctOption.nndct_parse_debug.value:
        option_util.set_option_value("nndct_logging_level", 1)

    if NndctOption.nndct_logging_level.value > 0:
        NndctDebugLogger("nndct_debug.log")

    if NndctOption.nndct_quant_off.value:
        option_util.set_option_value("nndct_quant_opt", 0)
        option_util.set_option_value("nndct_param_corr", False)
        option_util.set_option_value("nndct_equalization", False)
        # option_util.set_option_value("nndct_wes", False)
        option_util.set_option_value("nndct_wes_in_cle", False)

    #if NndctOption.nndct_quant_opt.value > 2:
    if (not hasattr(NndctOption.nndct_param_corr,
                    '_value')) and (NndctOption.nndct_quant_opt.value <= 0):
        option_util.set_option_value("nndct_param_corr", False)

    if (not hasattr(NndctOption.nndct_equalization,
                    '_value')) and (NndctOption.nndct_quant_opt.value <= 0):
        option_util.set_option_value("nndct_equalization", False)
Example No. 7
def get_node_schema(node):
    schema_op = node_type(node)
    schemas = torch._C._jit_get_schemas_for_operator(schema_op)
    for schema in schemas:
        if is_schema_matching(node, schema):
            if NndctOption.nndct_parse_debug.value >= 1:
                NndctDebugLogger.write(
                    f"%{get_node_outputs_name(node)[0]} signature: {parse_node_signature(node)}\n"
                )
                schema_handler = SchemaHelper(schema)
                NndctDebugLogger.write(
                    f"matched schema: {schema_handler.toString()}\n")
            return schema
    if schema_op.split("::")[0] == "aten":
        # assert False
        NndctScreenLogger().warning(
            f"Can't find schema for {node}. If the quantizable model can still be generated successfully, please ignore this warning.\n"
        )
Example No. 8
def parse_module(module: torch.nn.Module,
                 input_args: Union[torch.Tensor, Sequence[Any]],
                 enable_opt: bool = True,
                 graph_name: Optional[str] = None) -> Graph:

    if NndctOption.nndct_equalization.value:
        if NndctOption.nndct_relu6_replace.value == 'reluk':
            replace_relu6_with_reluk(module)
        elif NndctOption.nndct_relu6_replace.value == 'relu':
            replace_relu6_with_relu(module)
    parser = TorchParser()
    graph = parser(module._get_name() if graph_name is None else graph_name,
                   module, input_args)
    if enable_opt:
        optimizer = QuantOptimizer()
        graph = optimizer(graph)
    if NndctOption.nndct_parse_debug.value >= 3:
        NndctDebugLogger.write(f"nndct quant graph:\n{graph}")
    return graph
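
A hedged usage sketch of parse_module; the toy model and dummy input are assumptions for illustration, while the signature (module, input_args, enable_opt, graph_name) comes from the excerpt above.

import torch

# Toy model and input are illustrative assumptions.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
dummy_input = torch.randn(1, 3, 32, 32)
graph = parse_module(model, dummy_input, enable_opt=True, graph_name="toy_net")
print(graph.name)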
Example No. 9
def get_deploy_graph_list(quant_model, nndct_graph):
    g_optmizer = DevGraphOptimizer(nndct_graph)
    g_optmizer.infer_tensor_layout()
    g_optmizer.strip_redundant_ops()

    # sync model data with dev graph
    connect_module_with_graph(quant_model,
                              g_optmizer.frozen_graph,
                              recover_param=False)
    update_nndct_blob_data(quant_model, g_optmizer.frozen_graph)
    connect_module_with_graph(quant_model, nndct_graph, recover_param=False)

    g_optmizer.constant_folding()
    if NndctOption.nndct_parse_debug.value >= 3:
        NndctDebugLogger.write(
            f"\nfrozen dev graph:\n{g_optmizer.frozen_graph}")

    deploy_graphs = g_optmizer.partition_by_quant_part()

    return deploy_graphs
Example No. 10
    def __call__(self, graph_name, module, input_args):
        # torch_graph_handler = TorchGraphHandler()
        graph_handler = create_graph_handler(module)
        raw_graph, raw_params = graph_handler.build_torch_graph(
            graph_name, module, input_args)

        self._convert_params(raw_params, raw_graph.name)
        NndctScreenLogger().info("Processing ops...")
        nndct_graph = self._convert_graph(
            raw_graph, self._get_device_info(module, input_args))
        unknown_op_type_check(nndct_graph)
        graphs = [nndct_graph]
        #graphs.extend(list(nndct_graph.block_subgraphs()))
        collect_all_blocks(nndct_graph, graphs)
        reorder_multi_subgraph_nodes(graphs)
        self._convert_blob_tensor_type(nndct_graph)
        self._load_data(nndct_graph, module)
        if NndctOption.nndct_parse_debug.value >= 2:
            NndctDebugLogger.write(f"nndct raw graph:\n{nndct_graph}")
        # print(f"nndct graph:{self._nndct_graph}")
        return nndct_graph
Example No. 11
    def build_torch_graph(self, graph_name, module, input_args, train=False):
        self._module = module
        if get_torch_version() >= 190 and isinstance(self._module, torch.jit._trace.TracedModule):
            fw_graph = self._get_graph_from_trace_script(input_args)
        elif get_torch_version() >= 190 and isinstance(self._module, torch.jit.ScriptModule):
            fw_graph = self._get_graph_from_script(input_args)
        elif isinstance(self._module, torch.nn.Module):
            if NndctOption.nndct_jit_script.value and get_torch_version() >= 190:
                NndctScreenLogger().info("Start to convert model to jit script...")
                try:
                    self._module = torch.jit.script(self._module.eval())
                except Exception as e:
                    NndctScreenLogger().error(str(e))
                    sys.exit("Failed to convert nn.module to jit script.")
                fw_graph = self._get_graph_from_script(input_args)
                NndctScreenLogger().info("Finish converting.")
            elif NndctOption.nndct_jit_trace.value and get_torch_version() >= 190:
                NndctScreenLogger().info("Start to trace model...")
                try:
                    self._module = torch.jit.trace(self._module.eval(), input_args)
                except Exception as e:
                    NndctScreenLogger().error(str(e))
                    sys.exit("Failed to trace nn.module to jit script.")
                fw_graph = self._get_graph_from_trace_script(input_args)
                NndctScreenLogger().info("Finish tracing.")
            else:
                NndctScreenLogger().info("Start to trace model...")
                fw_graph = self._trace_graph_from_model(input_args, train)
                NndctScreenLogger().info("Finish tracing.")

        self._node_kinds = {node.kind().split(":")[-1] for node in fw_graph.nodes()}
        if NndctOption.nndct_parse_debug.value >= 1:
            NndctDebugLogger.write(f"jit graph:\n{fw_graph}")
            NndctDebugLogger.write(f"\nparsing nodes types:\n{self._node_kinds}\n")

        raw_graph = self._create_raw_graph(graph_name, fw_graph)
        if NndctOption.nndct_parse_debug.value >= 2:
            NndctDebugLogger.write(f"\ntorch raw graph:\n{raw_graph}")
        self._opt_raw_graph(raw_graph)
        if NndctOption.nndct_parse_debug.value >= 2:
            NndctDebugLogger.write(f"\ntorch opt graph:\n{raw_graph}")

        if NndctOption.nndct_parse_debug.value >= 3:
            self._check_stub_topology(raw_graph)

        return raw_graph
Example No. 12
def parse_module(module: Union[torch.nn.Module, torch.jit.ScriptModule],
                 input_args: Union[torch.Tensor, Sequence[Any]],
                 enable_opt: bool = True,
                 graph_name: Optional[str] = None) -> Graph:

    if NndctOption.nndct_equalization.value:
        if NndctOption.nndct_relu6_replace.value == 'reluk':
            replace_relu6_with_reluk(module)
        elif NndctOption.nndct_relu6_replace.value == 'relu':
            replace_relu6_with_relu(module)

    # if NndctOption.nndct_wes.value:
    #   insert_scale_after_conv2d(module)

    parser = TorchParser()
    graph = parser(
        _get_module_name(module) if graph_name is None else graph_name, module,
        input_args)
    if enable_opt:
        graph = quant_optimize(graph)

    if NndctOption.nndct_parse_debug.value >= 3:
        NndctDebugLogger.write(f"nndct quant graph:\n{graph}")
    return graph
Example No. 13
    def build_torch_graph(self, graph_name, module, input_args, train=False):
        self._module = module
        fw_graph, params = self._trace_graph_from_model(input_args, train)

        self._node_kinds = {
            node.kind().split(":")[-1]
            for node in fw_graph.nodes()
        }
        if NndctOption.nndct_parse_debug.value >= 1:
            NndctDebugLogger.write(f"jit graph:\n{fw_graph}")
            NndctDebugLogger.write(
                f"\nparsing nodes types:\n{self._node_kinds}")

        raw_graph = self._build_raw_graph(graph_name, fw_graph, params)
        if NndctOption.nndct_parse_debug.value >= 2:
            NndctDebugLogger.write(f"\ntorch raw graph:\n{raw_graph}")
        opt_graph = self._opt_raw_graph(raw_graph)
        if NndctOption.nndct_parse_debug.value >= 2:
            NndctDebugLogger.write(f"\ntorch opt graph:\n{raw_graph}")
        return opt_graph