예제 #1
0
def _exec_save(ckpt_file_name, data_list):
    """Execute the process of saving checkpoint into file.

    Every entry of ``data_list`` is appended to ``ckpt_file_name`` as one or
    more serialized ``Checkpoint`` messages sharing the entry's name as tag.
    When an entry's byte size exceeds ``SLICE_SIZE`` its data is split into
    ``ceil(size / SLICE_SIZE)`` consecutive pieces, one message per piece.
    Any existing file at ``ckpt_file_name`` is replaced. The finished file is
    left readable only by its owner.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to (re)create.
        data_list (dict): Maps a parameter name to an indexable
            ``(dims, tensor_type, data)`` where ``dims`` is an iterable of
            dimension sizes and ``data`` is a numpy array. ``tensor_type`` is
            presumably a dtype string; confirm against the caller.

    Raises:
        BaseException: Any error raised while writing is logged and
            re-raised unchanged.
    """

    try:
        with _ckpt_mutex:
            if os.path.exists(ckpt_file_name):
                # A previous save leaves the file read-only (see the chmod at
                # the end). Windows refuses to delete read-only files, so make
                # it writable before removing it.
                os.chmod(ckpt_file_name, stat.S_IWUSR)
                os.remove(ckpt_file_name)
            with open(ckpt_file_name, "ab") as f:
                for name, value in data_list.items():
                    # np.array_split cuts along axis 0 only. Without flattening,
                    # an array with a short leading dimension (e.g. shape (1, N))
                    # would produce one oversized slice plus empty ones.
                    # Flattening in C order keeps the concatenated slice bytes
                    # identical to value[2].tobytes().
                    flat_data = value[2].reshape(-1)
                    data_size = flat_data.nbytes
                    if data_size > SLICE_SIZE:
                        slice_count = math.ceil(data_size / SLICE_SIZE)
                        param_slice_list = np.array_split(flat_data, slice_count)
                    else:
                        param_slice_list = [flat_data]

                    for param_slice in param_slice_list:
                        # One self-contained message per slice; all slices of
                        # a parameter carry the same tag, dims and type.
                        checkpoint_list = Checkpoint()
                        param_value = checkpoint_list.value.add()
                        param_value.tag = name
                        param_tensor = param_value.tensor
                        param_tensor.dims.extend(value[0])
                        param_tensor.tensor_type = value[1]
                        param_tensor.tensor_content = param_slice.tobytes()

                        f.write(checkpoint_list.SerializeToString())

            # Restrict permissions after the file is closed, while still
            # holding the lock so a concurrent save cannot interleave.
            os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise
예제 #2
0
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
    """
    Saves checkpoint info to a specified file.

    All parameters are packed into a single ``Checkpoint`` message, which is
    serialized to ``ckpt_file_name``. Any existing file at that path is
    replaced. The written file is then made readable only by its owner.

    Args:
        parameter_list (list): Parameters list, each element is a dict
                               like {"name":xx, "type":xx, "shape":xx, "data":xx}.
                               Only the "name" and "data" keys are read here.
        ckpt_file_name (str): Checkpoint file name.
        model_type (str): The name of model type. Default: "normal".

    Raises:
        RuntimeError: Failed to save the Checkpoint file. The original
            exception is attached as ``__cause__``.
    """
    logger.info("Execute save checkpoint process.")
    checkpoint_list = Checkpoint()
    checkpoint_list.model_type = model_type

    try:
        for param in parameter_list:
            param_value = checkpoint_list.value.add()
            param_value.tag = param["name"]
            param_tensor = param_value.tensor
            data = param["data"]
            if isinstance(data, Parameter):
                # Presumably materializes lazily-initialized parameter data
                # before it is read; confirm against Parameter.init_data.
                data.init_data()
            param_data = data.asnumpy().reshape(-1)
            # tobytes() yields the same bytes as the deprecated tostring().
            param_tensor.tensor_content = param_data.tobytes()
            param_tensor.tensor_type = str(data.dtype)

            # Scalars (shape ()) are recorded with a single 0 dim.
            if data.shape == ():
                param_tensor.dims.append(0)
            else:
                param_tensor.dims.extend(data.shape)

        if os.path.exists(ckpt_file_name):
            # A previous save leaves the file read-only, and opening it with
            # "wb" would then fail. Make it writable (required on Windows
            # before deletion) and remove it.
            os.chmod(ckpt_file_name, stat.S_IWUSR)
            os.remove(ckpt_file_name)
        with open(ckpt_file_name, "wb") as f:
            f.write(checkpoint_list.SerializeToString())
        os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise RuntimeError(str(e)) from e
    logger.info("Save checkpoint process finish.")
예제 #3
0
def _exec_save(ckpt_file_name, data_list):
    """Execute save checkpoint into file process.

    All entries of ``data_list`` are packed into one ``Checkpoint`` message
    and written to ``ckpt_file_name``, replacing any existing file. The file
    is then made readable only by its owner.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to (re)create.
        data_list (dict): Maps a parameter name to an indexable
            ``(dims, tensor_type, data)``. ``dims`` is an iterable of
            dimension sizes and ``data`` is a numpy array whose raw bytes are
            stored. ``tensor_type`` is presumably a dtype string; confirm
            against the caller.

    Raises:
        RuntimeError: Failed to save the checkpoint file. The original
            exception is attached as ``__cause__``.
    """
    checkpoint_list = Checkpoint()

    try:
        with _ckpt_mutex:
            for name, value in data_list.items():
                param_value = checkpoint_list.value.add()
                param_value.tag = name
                param_tensor = param_value.tensor
                param_tensor.dims.extend(value[0])
                param_tensor.tensor_type = value[1]
                # tobytes() yields the same bytes as the deprecated tostring().
                param_tensor.tensor_content = value[2].tobytes()

            if os.path.exists(ckpt_file_name):
                # A previous save leaves the file read-only, and opening it
                # with "wb" would then fail. Make it writable (required on
                # Windows before deletion) and remove it.
                os.chmod(ckpt_file_name, stat.S_IWUSR)
                os.remove(ckpt_file_name)
            with open(ckpt_file_name, "wb") as f:
                f.write(checkpoint_list.SerializeToString())
            # Restrict permissions only after the file has been closed.
            os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise RuntimeError(str(e)) from e