def read_proto(file_name, proto_format="MINDIR"):
    """
    Read protobuf file.

    Args:
        file_name (str): File name.
        proto_format (str): Proto format, "MINDIR" or "ANF". Default: "MINDIR".

    Returns:
        Object, proto object parsed from the file.

    Raises:
        ValueError: If `proto_format` is not supported, or if the file cannot
            be read or parsed (the original error is chained as the cause).
    """
    if proto_format == "MINDIR":
        model = mindir_model()
    elif proto_format == "ANF":
        # Previously compared the undefined name `model_format`, which raised
        # NameError for every non-MINDIR format.
        model = anf_model()
    else:
        raise ValueError("Unsupported proto format.")

    try:
        with open(file_name, "rb") as f:
            pb_content = f.read()
        model.ParseFromString(pb_content)
    except Exception as e:
        # Catch Exception (not BaseException) so KeyboardInterrupt/SystemExit
        # are not disguised as ValueError.
        logger.error(
            "Failed to read the file `%s`, please check the correct of the file.", file_name)
        raise ValueError(str(e)) from e
    return model
def read_proto(file_name, proto_format="MINDIR", display_data=False):
    """
    Read protobuf file.

    Args:
        file_name (str): File name.
        proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR.
        display_data (bool): Whether to keep the stored tensor data. When False,
            each MINDIR graph parameter's `raw_data` and each CKPT element's
            `tensor.tensor_content` is replaced with a single null-byte
            placeholder. ANF models are returned unmodified. Default: False.

    Returns:
        Object, proto object parsed from the file.

    Raises:
        ValueError: If `proto_format` is not supported, or if the file cannot
            be read or parsed (the original error is chained as the cause).
    """
    if proto_format == "MINDIR":
        model = mindir_model()
    elif proto_format == "ANF":
        model = anf_model()
    elif proto_format == "CKPT":
        model = Checkpoint()
    else:
        raise ValueError("Unsupported proto format.")

    try:
        with open(file_name, "rb") as f:
            pb_content = f.read()
        model.ParseFromString(pb_content)
    except Exception as e:
        # Catch Exception (not BaseException) so KeyboardInterrupt/SystemExit
        # are not disguised as ValueError.
        logger.error(
            "Failed to read the file `%s`, please check the correct of the file.", file_name)
        raise ValueError(str(e)) from e

    if not display_data:
        # Overwrite the bulky payloads with a placeholder byte when data
        # display is not requested.
        if proto_format == "MINDIR":
            for param_proto in model.graph.parameter:
                param_proto.raw_data = b'\0'
        elif proto_format == "CKPT":
            for element in model.value:
                element.tensor.tensor_content = b'\0'
    return model