def _load_single_param(ckpt_file_name, param_name):
    """Load a single parameter from a checkpoint file.

    Consecutive checkpoint entries carrying ``param_name`` as their tag are
    concatenated along axis 0 and then converted into one Parameter.

    Args:
        ckpt_file_name (str): Path of the checkpoint file.
        param_name (str): Tag of the parameter to load.

    Returns:
        Parameter, the parameter whose name is ``param_name``.

    Raises:
        ValueError: The file cannot be read/parsed, or it contains no entry
            tagged ``param_name``.
        RuntimeError: The matching entries could not be converted into a Parameter.
    """
    logger.info("Execute the process of loading checkpoint files.")
    checkpoint_list = Checkpoint()
    try:
        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except Exception as e:
        # Exception (not BaseException): Ctrl-C / SystemExit must not be rewritten as ValueError.
        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
        raise ValueError(str(e)) from e

    parameter = None
    try:
        values = checkpoint_list.value
        param_data_list = []
        for element_id, element in enumerate(values):
            if element.tag != param_name:
                continue
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            param_data_list.append(np.frombuffer(element.tensor.tensor_content, np_type))

            # Keep collecting while the next entry carries the same tag; build the
            # Parameter only after the last consecutive entry for this tag.
            if element_id != len(values) - 1 and element.tag == values[element_id + 1].tag:
                continue

            param_data = np.concatenate(param_data_list, axis=0)
            param_data_list.clear()
            dims = element.tensor.dims
            if dims == [0]:
                # Scalar entry: unwrap to a Python float/int for Float*/Int* type names.
                if 'Float' in data_type:
                    param_data = float(param_data[0])
                elif 'Int' in data_type:
                    param_data = int(param_data[0])
                parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
            elif dims == [1]:
                parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
            else:
                param_value = param_data.reshape(list(dims))
                parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
            break
        logger.info("Loading checkpoint files process is finished.")
    except Exception as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(str(e)) from e

    if parameter is None:
        raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
                         f"please check parameter name or checkpoint file.")
    return parameter
def read_proto(file_name, proto_format="MINDIR", display_data=False):
    """
    Read protobuf file.

    Args:
        file_name (str): File name.
        proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR.
        display_data (bool): Whether to keep the raw tensor data. When False, the
            MINDIR parameter ``raw_data`` / CKPT ``tensor_content`` payloads are
            replaced by ``b'\\0'`` in the returned object. Default: False.

    Returns:
        Object, proto object.

    Raises:
        ValueError: Unsupported proto_format, or the file cannot be read/parsed.
    """
    if proto_format == "MINDIR":
        model = mindir_model()
    elif proto_format == "ANF":
        model = anf_model()
    elif proto_format == "CKPT":
        model = Checkpoint()
    else:
        raise ValueError("Unsupported proto format.")

    try:
        with open(file_name, "rb") as f:
            pb_content = f.read()
        model.ParseFromString(pb_content)
    except Exception as e:
        # Exception (not BaseException): Ctrl-C / SystemExit must not be rewritten as ValueError.
        logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name)
        raise ValueError(str(e)) from e

    if not display_data:
        # Strip bulky tensor payloads so the returned proto is cheap to print/inspect.
        if proto_format == "MINDIR":
            for param_proto in model.graph.parameter:
                param_proto.raw_data = b'\0'
        elif proto_format == "CKPT":
            for element in model.value:
                element.tensor.tensor_content = b'\0'
    return model
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
    """
    Loads checkpoint info from a specified file.

    Consecutive checkpoint entries sharing a tag are concatenated along axis 0
    into a single Parameter before being reshaped to the recorded dims.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        net (Cell): Cell network. Default: None
        strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
                           in the param_dict into net with the same suffix. Default: False
        filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix
            will not be loaded. Default: None.

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file is incorrect or unreadable, filter_prefix is an empty list/tuple,
            or no parameter remains after filtering.
        TypeError: filter_prefix, or one of its items, is not of a supported type.
        RuntimeError: A checkpoint entry could not be converted into a Parameter.

    Examples:
        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
        >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
    """
    if not isinstance(ckpt_file_name, str):
        raise ValueError("The ckpt_file_name must be string.")
    if not os.path.exists(ckpt_file_name):
        raise ValueError("The checkpoint file is not exist.")
    if not ckpt_file_name.endswith(".ckpt"):
        raise ValueError("Please input the correct checkpoint file name.")
    if os.path.getsize(ckpt_file_name) == 0:
        raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")

    if filter_prefix is not None:
        if not isinstance(filter_prefix, (str, list, tuple)):
            raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
                            f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
        if isinstance(filter_prefix, str):
            filter_prefix = (filter_prefix,)
        if not filter_prefix:
            raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
        for index, prefix in enumerate(filter_prefix):
            if not isinstance(prefix, str):
                raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
                                f"but got {str(type(prefix))} at index {index}.")

    logger.info("Execute the process of loading checkpoint files.")
    checkpoint_list = Checkpoint()
    try:
        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except Exception as e:
        # Exception (not BaseException): Ctrl-C / SystemExit must not be rewritten as ValueError.
        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
        raise ValueError(str(e)) from e

    parameter_dict = {}
    try:
        values = checkpoint_list.value
        param_data_list = []
        for element_id, element in enumerate(values):
            if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag):
                continue
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            param_data_list.append(np.frombuffer(element.tensor.tensor_content, np_type))

            # Keep collecting while the next entry carries the same tag; build the
            # Parameter only after the last consecutive entry for this tag.
            if element_id != len(values) - 1 and element.tag == values[element_id + 1].tag:
                continue

            param_data = np.concatenate(param_data_list, axis=0)
            param_data_list.clear()
            dims = element.tensor.dims
            if dims == [0]:
                # Scalar entry: unwrap to a Python float/int for Float*/Int* type names.
                if 'Float' in data_type:
                    param_data = float(param_data[0])
                elif 'Int' in data_type:
                    param_data = int(param_data[0])
                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
            elif dims == [1]:
                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
            else:
                param_value = param_data.reshape(list(dims))
                parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
        logger.info("Loading checkpoint files process is finished.")
    except Exception as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(str(e)) from e

    if not parameter_dict:
        raise ValueError("The loaded parameter dict is empty after filtering, please check filter_prefix.")

    if net is not None:
        load_param_into_net(net, parameter_dict, strict_load)

    return parameter_dict
def load_checkpoint(ckpoint_file_name, net=None):
    """
    Loads checkpoint info from a specified file.

    Args:
        ckpoint_file_name (str): Checkpoint file name.
        net (Cell): Cell network. Default: None

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file is incorrect or cannot be read/parsed.
        RuntimeError: A checkpoint entry could not be converted into a Parameter.
    """
    if not isinstance(ckpoint_file_name, str):
        raise ValueError("The ckpoint_file_name must be String.")
    if not os.path.exists(ckpoint_file_name) or not ckpoint_file_name.endswith(".ckpt"):
        raise ValueError("Please input the correct checkpoint file name.")
    if os.path.getsize(ckpoint_file_name) == 0:
        raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")

    logger.info("Execute load checkpoint process.")
    checkpoint_list = Checkpoint()
    try:
        with open(ckpoint_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except Exception as e:
        # Exception (not BaseException): Ctrl-C / SystemExit must not be rewritten as ValueError.
        logger.error("Failed to read the checkpoint file %s, please check the correct of the file.",
                     ckpoint_file_name)
        raise ValueError(str(e)) from e

    parameter_dict = {}
    try:
        for element in checkpoint_list.value:
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            # The binary mode of np.fromstring is deprecated; frombuffer + copy gives the
            # same independent, writable array that fromstring used to return.
            param_data = np.frombuffer(element.tensor.tensor_content, np_type).copy()
            dims = element.tensor.dims
            if dims in [[0], [1]]:
                # Scalar / single-element entry: the first value is wrapped directly.
                parameter_dict[element.tag] = Parameter(param_data[0], name=element.tag)
            else:
                param_value = param_data.reshape(list(dims))
                parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
        logger.info("Load checkpoint process finish.")
    except Exception as e:
        logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name)
        raise RuntimeError(str(e)) from e

    if net is not None:
        load_param_into_net(net, parameter_dict)

    return parameter_dict
def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
    """
    Loads checkpoint info from a specified file.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
        net (Cell): Cell network. Default: None

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file is incorrect or unreadable, or model_type is not in ModelType.
        KeyError: The checkpoint records a model type that differs from model_type.
        RuntimeError: A checkpoint entry could not be converted into a Parameter.
    """
    if not isinstance(ckpt_file_name, str):
        raise ValueError("The ckpt_file_name must be string.")
    if model_type not in ModelType:
        raise ValueError(f"The model_type is not in {ModelType}.")
    if not os.path.exists(ckpt_file_name) or not ckpt_file_name.endswith(".ckpt"):
        raise ValueError("Please input the correct checkpoint file name.")
    if os.path.getsize(ckpt_file_name) == 0:
        raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")

    logger.info("Execute load checkpoint process.")
    checkpoint_list = Checkpoint()
    try:
        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except Exception as e:
        # Exception (not BaseException): Ctrl-C / SystemExit must not be rewritten as ValueError.
        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
        raise ValueError(str(e)) from e

    parameter_dict = {}
    # Only enforce the model type when the checkpoint actually records one.
    if checkpoint_list.model_type and model_type != checkpoint_list.model_type:
        raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
            checkpoint_list.model_type, model_type))

    try:
        for element in checkpoint_list.value:
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            # The binary mode of np.fromstring is deprecated; frombuffer + copy gives the
            # same independent, writable array that fromstring used to return.
            param_data = np.frombuffer(element.tensor.tensor_content, np_type).copy()
            dims = element.tensor.dims
            if dims == [0]:
                # Scalar entry: unwrap to a Python float/int for Float*/Int* type names.
                if 'Float' in data_type:
                    param_data = float(param_data[0])
                elif 'Int' in data_type:
                    param_data = int(param_data[0])
                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
            elif dims == [1]:
                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
            else:
                param_value = param_data.reshape(list(dims))
                parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
        logger.info("Load checkpoint process finish.")
    except Exception as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(str(e)) from e

    if net is not None:
        load_param_into_net(net, parameter_dict)

    return parameter_dict