Beispiel #1
0
def _exec_save(ckpt_file_name, data_list):
    """
    Execute the process of saving checkpoint into file.

    Every parameter is written as one or more standalone `Checkpoint` messages that are appended
    to the file one after another. A parameter whose data is larger than `SLICE_SIZE` bytes is
    split with `np.array_split` into several slices which all carry the same tag.

    Args:
        ckpt_file_name (str): Path of the checkpoint file. An existing file is removed first.
        data_list (dict): Maps a parameter name to a sequence read as `value[0]` (dims),
            `value[1]` (tensor type) and `value[2]` (array providing `nbytes` and `tobytes()`).

    Raises:
        BaseException: Any error raised while saving is logged and re-raised unchanged.
    """

    try:
        with _ckpt_mutex:
            if os.path.exists(ckpt_file_name):
                try:
                    os.remove(ckpt_file_name)
                except PermissionError:
                    # A previous save leaves the file read-only (see the chmod below), and a
                    # read-only file cannot be removed on Windows. Restore write permission and
                    # retry, but never touch the permissions of something that is not a file.
                    if not os.path.isfile(ckpt_file_name):
                        raise
                    os.chmod(ckpt_file_name, stat.S_IWUSR)
                    os.remove(ckpt_file_name)
            with open(ckpt_file_name, "ab") as f:
                for name, value in data_list.items():
                    data_size = value[2].nbytes
                    if data_size > SLICE_SIZE:
                        slice_count = math.ceil(data_size / SLICE_SIZE)
                        param_slice_list = np.array_split(
                            value[2], slice_count)
                    else:
                        param_slice_list = [value[2]]

                    for param_slice in param_slice_list:
                        # One message per slice keeps each serialized chunk bounded in size.
                        checkpoint_list = Checkpoint()
                        param_value = checkpoint_list.value.add()
                        param_value.tag = name
                        param_tensor = param_value.tensor
                        param_tensor.dims.extend(value[0])
                        param_tensor.tensor_type = value[1]
                        param_tensor.tensor_content = param_slice.tobytes()

                        f.write(checkpoint_list.SerializeToString())

        # Mark the finished checkpoint as read-only for its owner.
        os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        # Bare raise keeps the original exception and traceback intact.
        raise
Beispiel #2
0
def _load_single_param(ckpt_file_name, param_name):
    """
    Load a parameter from checkpoint.

    The checkpoint entries are scanned in order. Consecutive entries tagged `param_name` are
    treated as slices of one parameter: their contents are concatenated and then shaped
    according to the dims stored with the last slice.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        param_name (str): Tag of the parameter to load.

    Returns:
        Parameter, the parameter rebuilt from the checkpoint.

    Raises:
        ValueError: The file cannot be read or parsed, or it holds no parameter named `param_name`.
        RuntimeError: Rebuilding the parameter from the parsed data failed.
    """
    logger.info("Execute the process of loading checkpoint files.")
    checkpoint_list = Checkpoint()

    try:
        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except BaseException as e:
        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
        raise ValueError(e.__str__()) from e

    parameter = None
    try:
        param_data_list = []
        for element_id, element in enumerate(checkpoint_list.value):
            if element.tag != param_name:
                continue
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            element_data = np.frombuffer(data, np_type)
            param_data_list.append(element_data)
            # Keep collecting until the next entry belongs to another tag (or the list ends).
            if (element_id == len(checkpoint_list.value) - 1) or \
                    (element.tag != checkpoint_list.value[element_id + 1].tag):
                param_data = np.concatenate((param_data_list), axis=0)
                param_data_list.clear()
                dims = element.tensor.dims
                if dims == [0]:
                    # dims == [0] marks a scalar: unwrap it to a plain Python number.
                    if 'Float' in data_type:
                        param_data = float(param_data[0])
                    elif 'Int' in data_type:
                        param_data = int(param_data[0])
                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
                elif dims == [1]:
                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
                else:
                    param_value = param_data.reshape(list(dims))
                    parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
                # Stop only once the last slice has been consumed; breaking earlier would drop
                # every multi-slice parameter and report it as missing.
                break
        logger.info("Loading checkpoint files process is finished.")

    except BaseException as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(e.__str__()) from e

    if parameter is None:
        raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
                         f"please check parameter name or checkpoint file.")
    return parameter
Beispiel #3
0
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
    """
    Saves checkpoint info to a specified file.

    All parameters are collected into a single `Checkpoint` message, which is written to
    `ckpt_file_name`; afterwards the file is made read-only for its owner.

    Args:
        parameter_list (list): Parameters list, each element is a dict
                               like {"name":xx, "type":xx, "shape":xx, "data":xx}.
        ckpt_file_name (str): Checkpoint file name.
        model_type (str): The name of model type. Default: "normal".

    Raises:
        RuntimeError: Failed to save the Checkpoint file.
    """
    logger.info("Execute save checkpoint process.")
    checkpoint_list = Checkpoint()
    checkpoint_list.model_type = model_type

    try:
        for param in parameter_list:
            param_value = checkpoint_list.value.add()
            param_value.tag = param["name"]
            param_tensor = param_value.tensor
            data = param["data"]
            if isinstance(data, Parameter):
                data.init_data()
            param_data = data.asnumpy().reshape(-1)
            # tobytes() replaces the deprecated ndarray.tostring(); the produced bytes are the same.
            param_tensor.tensor_content = param_data.tobytes()
            param_tensor.tensor_type = str(data.dtype)

            if data.shape == ():
                # A scalar is recorded with the single dim 0.
                param_tensor.dims.append(0)
            else:
                for dim in data.shape:
                    param_tensor.dims.append(dim)

        # A previous save leaves the file read-only (see the chmod below), which would make
        # opening it for writing fail; restore write permission before overwriting.
        if os.path.isfile(ckpt_file_name) and not os.access(ckpt_file_name, os.W_OK):
            os.chmod(ckpt_file_name, stat.S_IWUSR)
        with open(ckpt_file_name, "wb") as f:
            f.write(checkpoint_list.SerializeToString())
        os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise RuntimeError(e.__str__()) from e
    logger.info("Save checkpoint process finish.")
Beispiel #4
0
def read_proto(file_name, proto_format="MINDIR", display_data=False):
    """
    Parse a protobuf file into the proto object matching `proto_format`.

    Unless `display_data` is set, the bulk payload of the parsed object is overwritten with a
    single zero byte: `raw_data` of every graph parameter for MINDIR, `tensor_content` of every
    value for CKPT.

    Args:
        file_name (str): File name.
        proto_format (str): Proto format {MINDIR, ANF, CKPT}.  Default: MINDIR.
        display_data (bool): Whether display data. Default: False.

    Returns:
        Object, proto object.

    Raises:
        ValueError: `proto_format` is not supported, or the file cannot be read or parsed.
    """
    if proto_format == "CKPT":
        proto_obj = Checkpoint()
    elif proto_format == "ANF":
        proto_obj = anf_model()
    elif proto_format == "MINDIR":
        proto_obj = mindir_model()
    else:
        raise ValueError("Unsupported proto format.")

    try:
        with open(file_name, "rb") as proto_file:
            proto_obj.ParseFromString(proto_file.read())
    except BaseException as err:
        logger.error(
            "Failed to read the file `%s`, please check the correct of the file.",
            file_name)
        raise ValueError(err.__str__())

    # Caller asked for the full content: hand the parsed object back untouched.
    if display_data:
        return proto_obj

    if proto_format == "MINDIR":
        for graph_param in proto_obj.graph.parameter:
            graph_param.raw_data = b'\0'
    elif proto_format == "CKPT":
        for ckpt_value in proto_obj.value:
            ckpt_value.tensor.tensor_content = b'\0'

    return proto_obj
Beispiel #5
0
def _exec_save(ckpt_file_name, data_list):
    """
    Execute save checkpoint into file process.

    All entries of `data_list` are packed into one `Checkpoint` message, which is written to
    `ckpt_file_name`; afterwards the file is made read-only for its owner.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to write.
        data_list (dict): Maps a parameter name to a sequence read as `value[0]` (dims),
            `value[1]` (tensor type) and `value[2]` (array providing `tobytes()`).

    Raises:
        RuntimeError: Failed to save the checkpoint file.
    """
    checkpoint_list = Checkpoint()

    try:
        with _ckpt_mutex:
            for name, value in data_list.items():
                param_value = checkpoint_list.value.add()
                param_value.tag = name
                param_tensor = param_value.tensor
                param_tensor.dims.extend(value[0])
                param_tensor.tensor_type = value[1]
                # tobytes() replaces the deprecated ndarray.tostring(); same bytes.
                param_tensor.tensor_content = value[2].tobytes()

            # A previous save leaves the file read-only (see the chmod below), which would make
            # opening it for writing fail; restore write permission before overwriting.
            if os.path.isfile(ckpt_file_name) and not os.access(ckpt_file_name, os.W_OK):
                os.chmod(ckpt_file_name, stat.S_IWUSR)
            with open(ckpt_file_name, "wb") as f:
                f.write(checkpoint_list.SerializeToString())
            # Change the mode only after the file has been flushed and closed.
            os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise RuntimeError(e.__str__()) from e
Beispiel #6
0
def load_checkpoint(ckpt_file_name,
                    net=None,
                    strict_load=False,
                    filter_prefix=None):
    """
    Read a checkpoint file and rebuild its parameters.

    Consecutive entries sharing one tag are treated as slices of a single parameter and are
    concatenated before the parameter is created.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        net (Cell): Cell network. Default: None
        strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
                           in the param_dict into net with the same suffix. Default: False
        filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix
            will not be loaded. Default: None.

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file is incorrect, or nothing is left to load.
        TypeError: `filter_prefix` has an unsupported type.
        RuntimeError: Rebuilding the parameters from the parsed file failed.

    Examples:
        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
        >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
    """
    # --- argument validation -------------------------------------------------
    if not isinstance(ckpt_file_name, str):
        raise ValueError("The ckpt_file_name must be string.")
    if not os.path.exists(ckpt_file_name):
        raise ValueError("The checkpoint file is not exist.")
    if not ckpt_file_name.endswith(".ckpt"):
        raise ValueError("Please input the correct checkpoint file name.")
    if not os.path.getsize(ckpt_file_name):
        raise ValueError(
            "The checkpoint file may be empty, please make sure enter the correct file name."
        )

    if filter_prefix is not None:
        if isinstance(filter_prefix, str):
            # A single prefix is handled as a one-element tuple from here on.
            filter_prefix = (filter_prefix, )
        elif not isinstance(filter_prefix, (list, tuple)):
            raise TypeError(
                f"The type of filter_prefix must be str, list[str] or tuple[str] "
                f"when filter_prefix is not None, but got {str(type(filter_prefix))}."
            )
        if not filter_prefix:
            raise ValueError(
                "The filter_prefix can't be empty when filter_prefix is list or tuple."
            )
        for index, prefix in enumerate(filter_prefix):
            if not isinstance(prefix, str):
                raise TypeError(
                    f"The type of filter_prefix must be str, list[str] or tuple[str], "
                    f"but got {str(type(prefix))} at index {index}.")

    # --- parse the protobuf file ---------------------------------------------
    logger.info("Execute the process of loading checkpoint files.")
    checkpoint_list = Checkpoint()

    try:
        with open(ckpt_file_name, "rb") as ckpt_file:
            raw_content = ckpt_file.read()
        checkpoint_list.ParseFromString(raw_content)
    except BaseException as e:
        logger.error(
            "Failed to read the checkpoint file `%s`, please check the correct of the file.",
            ckpt_file_name)
        raise ValueError(e.__str__())

    # --- rebuild the parameters ----------------------------------------------
    parameter_dict = {}
    try:
        pending_slices = []
        elements = checkpoint_list.value
        last_id = len(elements) - 1
        for element_id, element in enumerate(elements):
            tag = element.tag
            if filter_prefix is not None and _check_param_prefix(filter_prefix, tag):
                continue

            tensor_proto = element.tensor
            data_type = tensor_proto.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            pending_slices.append(np.frombuffer(tensor_proto.tensor_content, np_type))

            # More slices of the same parameter follow: keep collecting.
            is_last_slice = element_id == last_id or tag != elements[element_id + 1].tag
            if not is_last_slice:
                continue

            param_data = np.concatenate(pending_slices, axis=0)
            pending_slices = []
            dims = tensor_proto.dims

            if dims == [0]:
                # dims == [0] marks a scalar: unwrap it to a plain Python number.
                if 'Float' in data_type:
                    param_data = float(param_data[0])
                elif 'Int' in data_type:
                    param_data = int(param_data[0])
                tensor_input = param_data
            elif dims == [1]:
                tensor_input = param_data
            else:
                tensor_input = param_data.reshape(list(dims))

            parameter_dict[tag] = Parameter(Tensor(tensor_input, ms_type), name=tag)

        logger.info("Loading checkpoint files process is finished.")

    except BaseException as e:
        logger.error("Failed to load the checkpoint file `%s`.",
                     ckpt_file_name)
        raise RuntimeError(e.__str__())

    if not parameter_dict:
        raise ValueError(
            f"The loaded parameter dict is empty after filtering, please check filter_prefix."
        )

    if net is not None:
        load_param_into_net(net, parameter_dict, strict_load)

    return parameter_dict
Beispiel #7
0
def load_checkpoint(ckpoint_file_name, net=None):
    """
    Loads checkpoint info from a specified file.

    Args:
        ckpoint_file_name (str): Checkpoint file name.
        net (Cell): Cell network. Default: None

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file is incorrect.
        RuntimeError: Rebuilding the parameters from the parsed file failed.
    """
    if not isinstance(ckpoint_file_name, str):
        raise ValueError("The ckpoint_file_name must be String.")

    if not os.path.exists(ckpoint_file_name) or not ckpoint_file_name.endswith(".ckpt"):
        raise ValueError("Please input the correct checkpoint file name.")

    if os.path.getsize(ckpoint_file_name) == 0:
        raise ValueError(
            "The checkpoint file may be empty, please make sure enter the correct file name."
        )

    logger.info("Execute load checkpoint process.")
    checkpoint_list = Checkpoint()

    try:
        with open(ckpoint_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except BaseException as e:
        logger.error(
            "Failed to read the checkpoint file %s, please check the correct of the file.",
            ckpoint_file_name)
        raise ValueError(e.__str__()) from e

    parameter_dict = {}

    try:
        for element in checkpoint_list.value:
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            # np.fromstring in binary mode is deprecated; np.frombuffer decodes the same bytes.
            param_data = np.frombuffer(data, np_type)
            dims = element.tensor.dims

            if dims in [[0], [1]]:
                # Both dims [0] and dims [1] are loaded from the first element only.
                parameter_dict[element.tag] = Parameter(param_data[0],
                                                        name=element.tag)
            else:
                param_value = param_data.reshape(list(dims))
                parameter_dict[element.tag] = Parameter(Tensor(
                    param_value, ms_type),
                                                        name=element.tag)

        logger.info("Load checkpoint process finish.")

    except BaseException as e:
        logger.error("Failed to load the checkpoint file %s.",
                     ckpoint_file_name)
        raise RuntimeError(e.__str__()) from e

    if net:
        load_param_into_net(net, parameter_dict)

    return parameter_dict
Beispiel #8
0
def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
    """
    Loads checkpoint info from a specified file.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
        net (Cell): Cell network. Default: None

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file is incorrect, or `model_type` is not in `ModelType`.
        KeyError: The model type stored in the file differs from `model_type`.
        RuntimeError: Rebuilding the parameters from the parsed file failed.
    """
    if not isinstance(ckpt_file_name, str):
        raise ValueError("The ckpt_file_name must be string.")

    if model_type not in ModelType:
        raise ValueError(f"The model_type is not in {ModelType}.")

    if not os.path.exists(ckpt_file_name) or not ckpt_file_name.endswith(".ckpt"):
        raise ValueError("Please input the correct checkpoint file name.")

    if os.path.getsize(ckpt_file_name) == 0:
        raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")

    logger.info("Execute load checkpoint process.")
    checkpoint_list = Checkpoint()

    try:
        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except BaseException as e:
        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
        raise ValueError(e.__str__()) from e

    parameter_dict = {}
    # The model type is only compared when the file actually records one.
    if checkpoint_list.model_type:
        if model_type != checkpoint_list.model_type:
            raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
                checkpoint_list.model_type, model_type))
    try:
        for element in checkpoint_list.value:
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            # np.fromstring in binary mode is deprecated; np.frombuffer decodes the same bytes.
            param_data = np.frombuffer(data, np_type)
            dims = element.tensor.dims

            if dims == [0]:
                # dims == [0] marks a scalar: unwrap it to a plain Python number.
                if 'Float' in data_type:
                    param_data = float(param_data[0])
                elif 'Int' in data_type:
                    param_data = int(param_data[0])
                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
            elif dims == [1]:
                parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
            else:
                param_value = param_data.reshape(list(dims))
                parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)

        logger.info("Load checkpoint process finish.")

    except BaseException as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(e.__str__()) from e

    if net:
        load_param_into_net(net, parameter_dict)

    return parameter_dict