def _exec_save(ckpt_file_name, data_list):
    """
    Execute the process of saving checkpoint into file.

    An existing file at `ckpt_file_name` is removed first, then every entry
    of `data_list` is appended to a fresh file. Data larger than `SLICE_SIZE`
    bytes is split with `np.array_split`; every slice is serialized as its own
    `Checkpoint` message that repeats the parameter's tag, dims and type.
    Finally the file is made read-only for its owner.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to write.
        data_list (dict): Maps a parameter name to an indexable value whose
            item 0 is the dims, item 1 the tensor type and item 2 the data
            array (it must provide `nbytes` and `tobytes()`).

    Raises:
        BaseException: Any error raised while saving is logged and re-raised
            unchanged.
    """
    try:
        with _ckpt_mutex:
            if os.path.exists(ckpt_file_name):
                try:
                    os.remove(ckpt_file_name)
                except PermissionError:
                    # A previous save leaves the file read-only (see the
                    # chmod below); on some platforms, e.g. Windows, that
                    # blocks deletion. Make it writable and retry once.
                    os.chmod(ckpt_file_name, stat.S_IWUSR)
                    os.remove(ckpt_file_name)

            with open(ckpt_file_name, "ab") as f:
                for name, value in data_list.items():
                    data_size = value[2].nbytes
                    if data_size > SLICE_SIZE:
                        slice_count = math.ceil(data_size / SLICE_SIZE)
                        param_slice_list = np.array_split(value[2], slice_count)
                    else:
                        param_slice_list = [value[2]]

                    for param_slice in param_slice_list:
                        # One message per slice keeps each serialized
                        # message small; all slices share the same tag.
                        checkpoint_list = Checkpoint()
                        param_value = checkpoint_list.value.add()
                        param_value.tag = name
                        param_tensor = param_value.tensor
                        param_tensor.dims.extend(value[0])
                        param_tensor.tensor_type = value[1]
                        param_tensor.tensor_content = param_slice.tobytes()
                        f.write(checkpoint_list.SerializeToString())

            os.chmod(ckpt_file_name, stat.S_IRUSR)
    except BaseException:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        # Bare raise re-raises the active exception with its traceback intact.
        raise
def _load_single_param(ckpt_file_name, param_name):
    """
    Load a parameter from checkpoint.

    The checkpoint file is parsed completely, then its entries are scanned
    for `param_name`. Consecutive entries carrying that tag are treated as
    slices of one parameter and concatenated along axis 0.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        param_name (str): Tag of the parameter to load.

    Returns:
        Parameter, the parameter rebuilt from the checkpoint.

    Raises:
        ValueError: The file cannot be read or parsed, or it contains no
            entry tagged `param_name`.
        RuntimeError: Rebuilding the parameter from its entries failed.
    """
    logger.info("Execute the process of loading checkpoint files.")
    ckpt_proto = Checkpoint()
    try:
        with open(ckpt_file_name, "rb") as ckpt_file:
            ckpt_proto.ParseFromString(ckpt_file.read())
    except BaseException as err:
        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
        raise ValueError(err.__str__())

    parameter = None
    try:
        slices = []
        entries = ckpt_proto.value
        last_index = len(entries) - 1
        for index, entry in enumerate(entries):
            if entry.tag != param_name:
                continue

            tensor = entry.tensor
            type_name = tensor.tensor_type
            np_type = tensor_to_np_type[type_name]
            ms_type = tensor_to_ms_type[type_name]
            slices.append(np.frombuffer(tensor.tensor_content, np_type))

            # More slices of this parameter follow while the next entry
            # still carries the same tag.
            if index != last_index and entry.tag == entries[index + 1].tag:
                continue

            merged = np.concatenate(slices, axis=0)
            slices.clear()
            shape = list(tensor.dims)
            if shape == [0]:
                # dims of [0] marks a scalar; unwrap it to a Python number.
                if 'Float' in type_name:
                    merged = float(merged[0])
                elif 'Int' in type_name:
                    merged = int(merged[0])
            elif shape != [1]:
                merged = merged.reshape(shape)
            parameter = Parameter(Tensor(merged, ms_type), name=entry.tag)
            break
        logger.info("Loading checkpoint files process is finished.")
    except BaseException as err:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(err.__str__())

    if parameter is None:
        raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
                         f"please check parameter name or checkpoint file.")
    return parameter
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
    """
    Saves checkpoint info to a specified file.

    Every parameter is flattened and stored as raw bytes together with its
    dtype string and dims; a scalar (shape `()`) is recorded with dims `[0]`.
    After writing, the file is made read-only for its owner.

    Args:
        parameter_list (list): Parameters list, each element is a dict like
            {"name":xx, "type":xx, "shape":xx, "data":xx}. Only "name" and
            "data" are read here; "data" must provide `asnumpy()`, `dtype`
            and `shape`.
        ckpt_file_name (str): Checkpoint file name.
        model_type (str): The name of model type. Default: "normal".

    Raises:
        RuntimeError: Failed to save the Checkpoint file.
    """
    logger.info("Execute save checkpoint process.")
    checkpoint_list = Checkpoint()
    checkpoint_list.model_type = model_type

    try:
        for param in parameter_list:
            param_value = checkpoint_list.value.add()
            param_value.tag = param["name"]
            param_tensor = param_value.tensor

            data = param["data"]
            if isinstance(data, Parameter):
                data.init_data()
            flat_data = data.asnumpy().reshape(-1)
            # tobytes() replaces the deprecated tostring() alias, which newer
            # NumPy releases no longer provide; the produced bytes are the same.
            param_tensor.tensor_content = flat_data.tobytes()
            param_tensor.tensor_type = str(data.dtype)

            if data.shape == ():
                # A scalar has no dims; [0] is the marker written for it.
                param_tensor.dims.append(0)
            else:
                param_tensor.dims.extend(data.shape)

        with open(ckpt_file_name, "wb") as f:
            f.write(checkpoint_list.SerializeToString())
        os.chmod(ckpt_file_name, stat.S_IRUSR)
    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        # Chain the original error so its traceback is not lost.
        raise RuntimeError(str(e)) from e

    logger.info("Save checkpoint process finish.")
def read_proto(file_name, proto_format="MINDIR", display_data=False):
    """
    Read protobuf file.

    Unless `display_data` is set, the binary payload of a MINDIR model
    (`raw_data` of each graph parameter) or of a CKPT model
    (`tensor_content` of each value) is replaced by a single zero byte.

    Args:
        file_name (str): File name.
        proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR.
        display_data (bool): Whether display data. Default: False.

    Returns:
        Object, proto object.

    Raises:
        ValueError: `proto_format` is not supported, or the file cannot be
            read or parsed.
    """
    is_mindir = proto_format == "MINDIR"
    is_ckpt = proto_format == "CKPT"

    if is_mindir:
        model = mindir_model()
    elif is_ckpt:
        model = Checkpoint()
    elif proto_format == "ANF":
        model = anf_model()
    else:
        raise ValueError("Unsupported proto format.")

    try:
        with open(file_name, "rb") as proto_file:
            model.ParseFromString(proto_file.read())
    except BaseException as err:
        logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name)
        raise ValueError(err.__str__())

    if display_data:
        return model

    # Blank out the payload so only the structure of the model remains.
    if is_mindir:
        for graph_param in model.graph.parameter:
            graph_param.raw_data = b'\0'
    elif is_ckpt:
        for entry in model.value:
            entry.tensor.tensor_content = b'\0'
    return model
def _exec_save(ckpt_file_name, data_list):
    """
    Execute save checkpoint into file process.

    All entries of `data_list` are collected into one `Checkpoint` message
    that is serialized to `ckpt_file_name`; the file is then made read-only
    for its owner. The work is done while holding `_ckpt_mutex`.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to write.
        data_list (dict): Maps a parameter name to an indexable value whose
            item 0 is the dims, item 1 the tensor type and item 2 the data
            (presumably a numpy array; it must provide `tobytes()`).

    Raises:
        RuntimeError: Failed to save the checkpoint file.
    """
    checkpoint_list = Checkpoint()

    try:
        with _ckpt_mutex:
            for name, value in data_list.items():
                param_value = checkpoint_list.value.add()
                param_value.tag = name
                param_tensor = param_value.tensor
                param_tensor.dims.extend(value[0])
                param_tensor.tensor_type = value[1]
                # tobytes() replaces the deprecated tostring() alias, which
                # newer NumPy releases no longer provide; same bytes.
                param_tensor.tensor_content = value[2].tobytes()

            with open(ckpt_file_name, "wb") as f:
                f.write(checkpoint_list.SerializeToString())
            os.chmod(ckpt_file_name, stat.S_IRUSR)
    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        # Chain the original error so its traceback is not lost.
        raise RuntimeError(str(e)) from e
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
    """
    Loads checkpoint info from a specified file.

    Consecutive entries of the file that share a tag are treated as slices
    of one parameter and concatenated along axis 0 before the parameter is
    rebuilt.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        net (Cell): Cell network. Default: None
        strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
                           in the param_dict into net with the same suffix. Default: False
        filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix
            will not be loaded. Default: None.

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file is incorrect.
        TypeError: `filter_prefix` or one of its items has an invalid type.
        RuntimeError: Rebuilding the parameters from the file failed.

    Examples:
        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
        >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
    """
    # ---- validate the file name -------------------------------------------
    if not isinstance(ckpt_file_name, str):
        raise ValueError("The ckpt_file_name must be string.")
    if not os.path.exists(ckpt_file_name):
        raise ValueError("The checkpoint file is not exist.")
    if not ckpt_file_name.endswith(".ckpt"):
        raise ValueError("Please input the correct checkpoint file name.")
    if os.path.getsize(ckpt_file_name) == 0:
        raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")

    # ---- validate and normalise the filter --------------------------------
    if filter_prefix is not None:
        if not isinstance(filter_prefix, (str, list, tuple)):
            raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
                            f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
        if isinstance(filter_prefix, str):
            filter_prefix = (filter_prefix,)
        if not filter_prefix:
            raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
        for index, prefix in enumerate(filter_prefix):
            if not isinstance(prefix, str):
                raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
                                f"but got {str(type(prefix))} at index {index}.")

    # ---- parse the file ---------------------------------------------------
    logger.info("Execute the process of loading checkpoint files.")
    ckpt_proto = Checkpoint()
    try:
        with open(ckpt_file_name, "rb") as ckpt_file:
            ckpt_proto.ParseFromString(ckpt_file.read())
    except BaseException as err:
        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
        raise ValueError(err.__str__())

    # ---- rebuild the parameters -------------------------------------------
    parameter_dict = {}
    try:
        slices = []
        entries = ckpt_proto.value
        last_index = len(entries) - 1
        for index, entry in enumerate(entries):
            if filter_prefix is not None and _check_param_prefix(filter_prefix, entry.tag):
                continue

            tensor = entry.tensor
            type_name = tensor.tensor_type
            np_type = tensor_to_np_type[type_name]
            ms_type = tensor_to_ms_type[type_name]
            slices.append(np.frombuffer(tensor.tensor_content, np_type))

            # More slices of this parameter follow while the next entry
            # still carries the same tag.
            if index != last_index and entry.tag == entries[index + 1].tag:
                continue

            merged = np.concatenate(slices, axis=0)
            slices.clear()
            shape = list(tensor.dims)
            if shape == [0]:
                # dims of [0] marks a scalar; unwrap it to a Python number.
                if 'Float' in type_name:
                    merged = float(merged[0])
                elif 'Int' in type_name:
                    merged = int(merged[0])
            elif shape != [1]:
                merged = merged.reshape(shape)
            parameter_dict[entry.tag] = Parameter(Tensor(merged, ms_type), name=entry.tag)
        logger.info("Loading checkpoint files process is finished.")
    except BaseException as err:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(err.__str__())

    if not parameter_dict:
        raise ValueError("The loaded parameter dict is empty after filtering, please check filter_prefix.")

    if net is not None:
        load_param_into_net(net, parameter_dict, strict_load)
    return parameter_dict
def load_checkpoint(ckpoint_file_name, net=None):
    """
    Loads checkpoint info from a specified file.

    Entries whose dims are `[0]` or `[1]` are stored as a Parameter built
    directly from the first data element; every other entry is reshaped to
    its dims and wrapped in a Tensor first.

    Args:
        ckpoint_file_name (str): Checkpoint file name.
        net (Cell): Cell network. Default: None

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file is incorrect.
        RuntimeError: Rebuilding the parameters from the file failed.
    """
    if not isinstance(ckpoint_file_name, str):
        raise ValueError("The ckpoint_file_name must be String.")

    if not os.path.exists(ckpoint_file_name) or ckpoint_file_name[-5:] != ".ckpt":
        raise ValueError("Please input the correct checkpoint file name.")

    if os.path.getsize(ckpoint_file_name) == 0:
        raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")

    logger.info("Execute load checkpoint process.")
    checkpoint_list = Checkpoint()

    try:
        with open(ckpoint_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except BaseException as e:
        logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name)
        raise ValueError(str(e)) from e

    parameter_dict = {}

    try:
        for element in checkpoint_list.value:
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            # The binary mode of np.fromstring is deprecated. frombuffer reads
            # the same bytes; copy() keeps the array writable and owning its
            # data, as fromstring's result was.
            param_data = np.frombuffer(data, np_type).copy()
            dims = element.tensor.dims

            if dims in [[0], [1]]:
                parameter_dict[element.tag] = Parameter(param_data[0], name=element.tag)
            else:
                param_value = param_data.reshape(list(dims))
                parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)

        logger.info("Load checkpoint process finish.")
    except BaseException as e:
        logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name)
        # Chain the original error so its traceback is not lost.
        raise RuntimeError(str(e)) from e

    if net:
        load_param_into_net(net, parameter_dict)

    return parameter_dict
def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
    """
    Loads checkpoint info from a specified file.

    If the file records a model type, it must equal `model_type`. An entry
    whose dims are `[0]` is loaded as a scalar; any other entry is loaded as
    an array (reshaped to its dims unless they are `[1]`).

    Args:
        ckpt_file_name (str): Checkpoint file name.
        model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
        net (Cell): Cell network. Default: None

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file is incorrect.
        KeyError: The model type stored in the file differs from `model_type`.
        RuntimeError: Rebuilding the parameters from the file failed.
    """
    if not isinstance(ckpt_file_name, str):
        raise ValueError("The ckpt_file_name must be string.")

    if model_type not in ModelType:
        raise ValueError(f"The model_type is not in {ModelType}.")

    if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
        raise ValueError("Please input the correct checkpoint file name.")

    if os.path.getsize(ckpt_file_name) == 0:
        raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")

    logger.info("Execute load checkpoint process.")
    checkpoint_list = Checkpoint()

    try:
        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except BaseException as e:
        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
        raise ValueError(str(e)) from e

    parameter_dict = {}
    # A file without a recorded model type is accepted for any model_type.
    if checkpoint_list.model_type:
        if model_type != checkpoint_list.model_type:
            raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
                checkpoint_list.model_type, model_type))

    try:
        for element in checkpoint_list.value:
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            # The binary mode of np.fromstring is deprecated. frombuffer reads
            # the same bytes; copy() keeps the array writable and owning its
            # data, as fromstring's result was.
            param_data = np.frombuffer(data, np_type).copy()
            dims = element.tensor.dims

            if dims == [0]:
                # dims of [0] marks a scalar; unwrap it to a Python number.
                if 'Float' in data_type:
                    param_data = float(param_data[0])
                elif 'Int' in data_type:
                    param_data = int(param_data[0])
                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
            elif dims == [1]:
                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
            else:
                param_value = param_data.reshape(list(dims))
                parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)

        logger.info("Load checkpoint process finish.")
    except BaseException as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        # Chain the original error so its traceback is not lost.
        raise RuntimeError(str(e)) from e

    if net:
        load_param_into_net(net, parameter_dict)

    return parameter_dict