Пример #1
0
def graph_based_converter_tf_to_ms(graph_path: str,
                                   input_nodes: dict,
                                   output_nodes: List[str],
                                   output_folder: str,
                                   report_folder: str = None):
    """
    Tensorflow to MindSpore based on Graph.

    Initializes a graph object from the model file, maps its nodes through
    ``ONNXToMindSporeMapper``, generates the code fragments and hands them to
    ``save_code_file_and_report``. The global context is released when the
    function exits, whether conversion succeeded or raised.

    Args:
        graph_path (str): Graph file path.
        input_nodes (dict): Input node(s) of the model.
        output_nodes (list[str]): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path. Defaults to None.
    """
    # Close unnecessary log. Note that this changes the environment of the
    # whole process, not only of this call.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    try:
        graph_obj = GraphFactory.init(graph_path,
                                      input_nodes=input_nodes,
                                      output_nodes=output_nodes)
        generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
        model_name = _extract_model_name(graph_path)
        code_fragments = generator_inst.generate()
        save_code_file_and_report(model_name, code_fragments, output_folder,
                                  report_folder)
    finally:
        # Release global context even when a step above raises, so a failed
        # conversion does not leave stale state behind for the next one.
        GlobalContext.release()
Пример #2
0
def graph_based_converter_tf_to_ms(graph_path: str,
                                   input_nodes: dict, output_nodes: List[str],
                                   output_folder: str, report_folder: str = None,
                                   query_result_folder: str = None):
    """
    Tensorflow to MindSpore based on Graph.

    When ``query_result_folder`` is given, only the intermediate graph is saved
    and no code is generated. Otherwise the graph is built, its nodes are mapped
    through ``ONNXToMindSporeMapper`` and the generated code fragments are handed
    to ``save_code_file_and_report``. The global context is released when the
    function exits, whether it returned early, finished, or raised.

    Args:
        graph_path (str): Graph file path.
        input_nodes (dict): Input node(s) of the model.
        output_nodes (list[str]): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path. Defaults to None.
        query_result_folder (str): Save the optimized graph and its topological order to disk.
            Defaults to None.
    """
    # Close unnecessary log. Note that this changes the environment of the
    # whole process, not only of this call.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    try:
        graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
        if query_result_folder:
            # Query mode: dump the intermediate graph and stop before code generation.
            save_intermediate_graph(graph_obj.dataloader, query_result_folder)
            return
        graph_obj.build()
        generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
        model_name = _extract_model_name(graph_path)
        log_console.info("Code saving begins.")
        code_fragments = generator_inst.generate()
        save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
        log_console.info("Code saving is finished.")
    finally:
        # Release global context on every exit path (early return, success, or
        # exception), so a failed conversion does not leave stale state behind.
        GlobalContext.release()
Пример #3
0
def graph_based_converter_pytorch_to_ms(graph_path: str,
                                        input_nodes: dict,
                                        output_nodes: List[str],
                                        output_folder: str,
                                        report_folder: str = None):
    """
    PyTorch to MindSpore based on Graph.

    Initializes a graph object from the model file, maps its nodes through
    ``ONNXToMindSporeMapper``, generates the code fragments and hands them to
    ``save_code_file_and_report``. The global context is released when the
    function exits, whether conversion succeeded or raised.

    Args:
        graph_path (str): Graph file path.
        input_nodes (dict): Input node(s) of the model.
        output_nodes (list[str]): Output node(s) of the model.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path. Defaults to None.
    """
    try:
        graph_obj = GraphFactory.init(graph_path,
                                      input_nodes=input_nodes,
                                      output_nodes=output_nodes)
        generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
        model_name = _extract_model_name(graph_path)
        code_fragments = generator_inst.generate()
        save_code_file_and_report(model_name, code_fragments, output_folder,
                                  report_folder)
    finally:
        # Release global context even when a step above raises, so a failed
        # conversion does not leave stale state behind for the next one.
        GlobalContext.release()
Пример #4
0
def convert_according_to_user_selections(graph_obj, output_folder: str, report_folder: str = None,
                                         user_operations: Mapping[str, Dict] = None):
    """
    ONNX to MindSpore based on Graph.

    Generates scope names on the graph from the recorded user operations, builds
    the graph, maps its nodes through ``ONNXToMindSporeMapper`` and hands the
    generated code fragments to ``save_code_file_and_report``. The global context
    is released when the function exits, whether conversion succeeded or raised.

    Args:
        graph_obj (OnnxGraph): Onnx graph object.
        output_folder (str): Output folder.
        report_folder (str): Report output folder path. Defaults to None.
        user_operations (dict): Record user's operations. Defaults to None.
    """
    try:
        graph_obj.generate_scope_name(user_operations)
        graph_obj.build()
        generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
        model_name = _extract_model_name(graph_obj.model_path)
        log_console.info("Code saving begins.")
        code_fragments = generator_inst.generate()
        save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
        log_console.info("Code saving is finished.")
    finally:
        # Release global context even when a step above raises, so a failed
        # conversion does not leave stale state behind for the next one.
        GlobalContext.release()