示例#1
0
def phase_pb_file(file_path: str) -> Union[MSGraph, None]:
    """
    Parse pb file to graph.

    Reads the file through `FileHandler`, deserializes the bytes into an
    `anf_ir_pb2.ModelProto` and builds an `MSGraph` from its `graph` field.

    Args:
        file_path (str): The file path of pb file.

    Returns:
        MSGraph, if load pb file and build graph success, will return the graph,
        else (the content is not a valid pb file) return None.

    Raises:
        UnknownError: If building the graph from the parsed proto fails. The
            original exception is chained as the cause.
    """
    # Binary protobuf parsing signals malformed input with DecodeError, which is
    # a different class from the ParseError caught below; catch both so an
    # invalid file yields None as documented. protobuf is already a dependency.
    from google.protobuf.message import DecodeError

    if not CONFIG.VERBOSE:
        # NOTE(review): this lowers the shared logger's level for the rest of
        # the process, not only for this call -- confirm that is intended.
        logger.setLevel(logging.ERROR)
    logger.info("Start to load graph from pb file, file path: %s.", file_path)
    model_proto = anf_ir_pb2.ModelProto()
    try:
        model_proto.ParseFromString(FileHandler(file_path).read())
    except (ParseError, DecodeError):
        logger.warning("The given file is not a valid pb file, file path: %s.",
                       file_path)
        return None

    graph = MSGraph()

    try:
        graph.build_graph(model_proto.graph)
    except Exception as ex:
        logger.error("Build graph failed, file path: %s.", file_path)
        logger.exception(ex)
        # Chain the cause so the real traceback is not lost behind UnknownError.
        raise UnknownError(str(ex)) from ex

    logger.info("Build graph success, file path: %s.", file_path)
    return graph
示例#2
0
    def _parse_pb_file(self, filename):
        """
        Parse pb file and write content to `EventsData`.

        The graph built from the file is wrapped in a `TensorEvent` (step 0,
        tagged with `filename`, graph plugin) and added to `self._events_data`.
        A file whose content cannot be parsed as a pb file is skipped with a
        warning.

        Args:
            filename (str): The name of the pb file, relative to
                `self._summary_dir`.
        """
        # Binary protobuf parsing signals malformed input with DecodeError,
        # which is a different class from ParseError; catch both so an invalid
        # file is skipped rather than aborting the load.
        from google.protobuf.message import DecodeError

        file_path = FileHandler.join(self._summary_dir, filename)
        logger.info("Start to load graph from pb file, file path: %s.",
                    file_path)
        filehandler = FileHandler(file_path)
        model_proto = anf_ir_pb2.ModelProto()
        try:
            model_proto.ParseFromString(filehandler.read())
        except (ParseError, DecodeError):
            logger.warning(
                "The given file is not a valid pb file, file path: %s.",
                file_path)
            return

        graph = MSGraph()
        graph.build_graph(model_proto.graph)
        tensor_event = TensorEvent(
            # file_stat returns a stat record, not a timestamp; the event's
            # wall time is the file's modification time.
            wall_time=FileHandler.file_stat(file_path).mtime,
            step=0,
            tag=filename,
            plugin_name=PluginNameEnum.GRAPH.value,
            value=graph)
        self._events_data.add_tensor_event(tensor_event)
def create_graph_pb_file(output_dir='./', filename='ms_output.pb'):
    """
    Create graph pb file, and return file path.

    Loads the graph definition from ``graph_base.json`` located next to this
    module, converts ``{"graph": <that data>}`` into an
    ``anf_ir_pb2.ModelProto`` and writes the serialized bytes to
    ``output_dir/filename``.

    Args:
        output_dir (str): Directory to write the pb file into. Default: './'.
        filename (str): Name of the pb file. Default: 'ms_output.pb'.

    Returns:
        str, the ``os.path.realpath`` of the written file.
    """
    graph_base = os.path.join(os.path.dirname(__file__), "graph_base.json")
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(graph_base, 'r', encoding='utf-8') as fp:
        data = json.load(fp)
    model_def = dict(graph=data)
    model_proto = json_format.Parse(json.dumps(model_def),
                                    anf_ir_pb2.ModelProto())
    msg = model_proto.SerializeToString()
    output_path = os.path.realpath(os.path.join(output_dir, filename))
    with open(output_path, 'wb') as fp:
        fp.write(msg)

    return output_path
示例#4
0
    def _parse_pb_file(summary_dir, filename):
        """
        Parse pb file and build a graph `TensorEvent` from it.

        Args:
            summary_dir (str): The directory that contains the pb file.
            filename (str): The name of the pb file, relative to `summary_dir`.

        Returns:
            TensorEvent, if load pb file and build graph success, will return tensor event,
            else (the content is not a valid pb file) return None.

        Raises:
            UnknownError: If building the graph from the parsed proto fails. The original
                exception is chained as the cause.
        """
        # Binary protobuf parsing signals malformed input with DecodeError, which is a
        # different class from ParseError; catch both so an invalid file yields None.
        from google.protobuf.message import DecodeError

        file_path = FileHandler.join(summary_dir, filename)
        logger.info("Start to load graph from pb file, file path: %s.",
                    file_path)
        filehandler = FileHandler(file_path)
        model_proto = anf_ir_pb2.ModelProto()
        try:
            model_proto.ParseFromString(filehandler.read())
        except (ParseError, DecodeError):
            logger.warning(
                "The given file is not a valid pb file, file path: %s.",
                file_path)
            return None

        graph = MSGraph()

        try:
            graph.build_graph(model_proto.graph)
        except Exception as ex:
            # Normally, there are no exceptions, and it is only possible for users on the MindSpore side
            # to dump other non-default graphs.
            logger.error("Build graph failed, file path: %s.", file_path)
            logger.exception(ex)
            # Chain the cause so the real traceback is not lost behind UnknownError.
            raise UnknownError(str(ex)) from ex

        tensor_event = TensorEvent(
            wall_time=FileHandler.file_stat(file_path).mtime,
            step=0,
            tag=filename,
            plugin_name=PluginNameEnum.GRAPH.value,
            value=graph,
            filename=filename)

        logger.info("Build graph success, file path: %s.", file_path)
        return tensor_event