def _exec_save(ckpt_file_name, data_list):
    """Write every parameter in ``data_list`` to the checkpoint file.

    Each value in ``data_list`` is indexed as ``[0]`` (dims), ``[1]`` (tensor
    type) and ``[2]`` (array). An array whose ``nbytes`` exceeds ``SLICE_SIZE``
    is split with ``np.array_split`` into ``ceil(nbytes / SLICE_SIZE)`` pieces;
    every piece is serialized as its own ``Checkpoint`` message and appended
    to the file. Finally the file mode is set to ``stat.S_IRUSR``.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to (re)create.
        data_list (dict): Mapping of parameter name to an indexable
            ``(dims, tensor_type, array)`` value.

    Raises:
        BaseException: Whatever error occurred while saving; it is logged and
            then re-raised unchanged.
    """
    try:
        with _ckpt_mutex:
            # Records are appended below, so drop any previous file first.
            if os.path.exists(ckpt_file_name):
                os.remove(ckpt_file_name)

            with open(ckpt_file_name, "ab") as ckpt_file:
                for tag, entry in data_list.items():
                    array = entry[2]
                    total_bytes = array.nbytes
                    if total_bytes <= SLICE_SIZE:
                        pieces = [array]
                    else:
                        piece_count = math.ceil(total_bytes / SLICE_SIZE)
                        pieces = np.array_split(array, piece_count)

                    for piece in pieces:
                        # One single-entry message per slice; dims and type
                        # are repeated on every slice of the same parameter.
                        message = Checkpoint()
                        record = message.value.add()
                        record.tag = tag
                        tensor = record.tensor
                        tensor.dims.extend(entry[0])
                        tensor.tensor_type = entry[1]
                        tensor.tensor_content = piece.tobytes()
                        ckpt_file.write(message.SerializeToString())

        os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise e
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
    """
    Saves checkpoint info to a specified file.

    All parameters are collected into one ``Checkpoint`` message, which is
    serialized and written to ``ckpt_file_name`` in a single write. The file
    mode is then set to ``stat.S_IRUSR``.

    Args:
        parameter_list (list): Parameters list, each element is a dict
            like {"name":xx, "type":xx, "shape":xx, "data":xx}. Only the
            "name" and "data" entries are read by this function.
        ckpt_file_name (str): Checkpoint file name.
        model_type (str): The name of model type. Default: "normal".

    Raises:
        RuntimeError: Failed to save the Checkpoint file. The original
            error is attached as ``__cause__``.
    """
    logger.info("Execute save checkpoint process.")
    checkpoint_list = Checkpoint()
    checkpoint_list.model_type = model_type

    try:
        for param in parameter_list:
            param_value = checkpoint_list.value.add()
            param_value.tag = param["name"]
            param_tensor = param_value.tensor

            data = param["data"]
            if isinstance(data, Parameter):
                data.init_data()

            # The content is stored flattened; the shape is kept in dims.
            param_data = data.asnumpy().reshape(-1)
            # tobytes() replaces the deprecated ndarray.tostring() alias.
            param_tensor.tensor_content = param_data.tobytes()
            param_tensor.tensor_type = str(data.dtype)

            if data.shape == ():
                # An empty shape is recorded as a single 0 dimension.
                param_tensor.dims.append(0)
            else:
                param_tensor.dims.extend(data.shape)

        with open(ckpt_file_name, "wb") as f:
            f.write(checkpoint_list.SerializeToString())
        os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(str(e)) from e
    logger.info("Save checkpoint process finish.")
def _exec_save(ckpt_file_name, data_list):
    """Execute save checkpoint into file process.

    Builds one ``Checkpoint`` message holding every entry of ``data_list``,
    writes its serialized form to ``ckpt_file_name`` and sets the file mode
    to ``stat.S_IRUSR``. The work is done while holding ``_ckpt_mutex``.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to write.
        data_list (dict): Mapping of parameter name to an indexable value
            used as ``[0]`` dims, ``[1]`` tensor type and ``[2]`` array.

    Raises:
        RuntimeError: Failed to save the checkpoint file. The original
            error is attached as ``__cause__``.
    """
    checkpoint_list = Checkpoint()

    try:
        with _ckpt_mutex:
            for name, value in data_list.items():
                param_value = checkpoint_list.value.add()
                param_value.tag = name
                param_tensor = param_value.tensor
                param_tensor.dims.extend(value[0])
                param_tensor.tensor_type = value[1]
                # tobytes() replaces the deprecated ndarray.tostring() alias.
                param_tensor.tensor_content = value[2].tobytes()

            with open(ckpt_file_name, "wb") as f:
                f.write(checkpoint_list.SerializeToString())
            os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(str(e)) from e