Beispiel #1
0
def read_proto(file_name, proto_format="MINDIR"):
    """
    Read a serialized protobuf file and parse it into a proto object.

    Args:
        file_name (str): Path of the file to read; it is opened in binary mode.
        proto_format (str): Proto format, "MINDIR" or "ANF". Default: "MINDIR".

    Returns:
        Object, the proto object (created by `mindir_model()` or `anf_model()`)
        after `ParseFromString` has been called on it with the file content.

    Raises:
        ValueError: If `proto_format` is not supported, or if reading or
            parsing the file fails (the original error is chained as cause).
    """

    if proto_format == "MINDIR":
        model = mindir_model()
    elif proto_format == "ANF":
        # The original compared an undefined name here, which raised NameError
        # for every non-MINDIR format instead of selecting the ANF model.
        model = anf_model()
    else:
        raise ValueError("Unsupported proto format.")

    try:
        with open(file_name, "rb") as f:
            pb_content = f.read()
        model.ParseFromString(pb_content)
    except Exception as e:
        # Catch Exception rather than BaseException so that KeyboardInterrupt
        # and SystemExit are not disguised as a ValueError.
        logger.error(
            "Failed to read the file `%s`, please check the correct of the file.",
            file_name)
        raise ValueError(str(e)) from e
    return model
Beispiel #2
0
def read_proto(file_name, proto_format="MINDIR", display_data=False):
    """
    Read a serialized protobuf file and parse it into a proto object.

    Args:
        file_name (str): Path of the file to read; it is opened in binary mode.
        proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR.
        display_data (bool): Whether to keep the data payload in the returned
            object. When False, the `raw_data` of every MINDIR graph parameter
            and the `tensor.tensor_content` of every CKPT value element are
            overwritten with a single zero byte. Default: False.

    Returns:
        Object, the proto object (created by `mindir_model()`, `anf_model()`
        or `Checkpoint()`) after `ParseFromString` has been called on it with
        the file content.

    Raises:
        ValueError: If `proto_format` is not supported, or if reading or
            parsing the file fails (the original error is chained as cause).
    """

    if proto_format == "MINDIR":
        model = mindir_model()
    elif proto_format == "ANF":
        model = anf_model()
    elif proto_format == "CKPT":
        model = Checkpoint()
    else:
        raise ValueError("Unsupported proto format.")

    try:
        with open(file_name, "rb") as f:
            pb_content = f.read()
        model.ParseFromString(pb_content)
    except Exception as e:
        # Catch Exception rather than BaseException so that KeyboardInterrupt
        # and SystemExit are not disguised as a ValueError.
        logger.error(
            "Failed to read the file `%s`, please check the correct of the file.",
            file_name)
        raise ValueError(str(e)) from e

    if not display_data:
        # Replace the payloads with a one-byte placeholder; ANF objects are
        # returned untouched.
        if proto_format == "MINDIR":
            for param_proto in model.graph.parameter:
                param_proto.raw_data = b'\0'
        elif proto_format == "CKPT":
            for element in model.value:
                element.tensor.tensor_content = b'\0'

    return model