コード例 #1
0
ファイル: check_point_v2.py プロジェクト: zzk0/oneflow
def _SaveVarDict(
    path: str,
    var_dict: Optional[Dict[str, Union[FileBackendVariableBlob,
                                       EagerBlobTrait]]] = None,
) -> None:
    """Save each variable of ``var_dict`` into its own directory under ``path``.

    Files written:

    * ``path/<name>/DATA_FILENAME``: concatenation of ``.tobytes()`` of the
      third item of every tuple yielded by ``_ReadSlice(var)``;
    * ``path/<name>/META_INFO_FILENAME``: text-format ``VariableMetaInfo``
      holding the variable's shape and proto dtype;
    * ``path/snapshot_done``: empty marker file, written after all variables.

    Args:
        path: Target directory. It must not be an existing file or a
            non-empty directory; it is created if missing.
        var_dict: Mapping from variable name to variable. ``None`` means
            ``GetAllVariables()``.

    Raises:
        AssertionError: if ``path`` is a file or a non-empty directory.
    """
    if var_dict is None:
        var_dict = GetAllVariables()

    def IsFileOrNonEmptyDir(candidate):
        if os.path.isfile(candidate):
            return True
        return os.path.isdir(candidate) and len(os.listdir(candidate)) != 0

    # Explicit check instead of a bare ``assert``: the guard must still run
    # under ``python -O``. AssertionError is kept so callers catching it are
    # unaffected.
    if IsFileOrNonEmptyDir(path):
        raise AssertionError(
            "{} is a file or non-empty directory! Note that flow.save is different from torch.save. It saves each weight as a separated file so that a directory instead of a file should be given.".format(
                path))
    os.makedirs(path, exist_ok=True)
    for name, var in var_dict.items():
        meta_info = variable_meta_info_pb.VariableMetaInfo()
        meta_info.shape.dim[:] = var.shape
        meta_info.data_type = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
            var.dtype)
        var_dir = os.path.join(path, name)
        param_path = os.path.join(var_dir, DATA_FILENAME)
        # No exist_ok: os.makedirs raises FileExistsError if the directory
        # is already there.
        os.makedirs(os.path.dirname(param_path))
        with open(param_path, "wb") as f:
            for _, _, data_slice in _ReadSlice(var):
                f.write(data_slice.tobytes())
        with open(os.path.join(var_dir, META_INFO_FILENAME), "w") as f:
            f.write(text_format.MessageToString(meta_info))
    # Empty marker file, written last, indicating the save finished normally.
    with open(os.path.join(path, "snapshot_done"), "w"):
        pass
コード例 #2
0
def _save_tensor_to_disk(tensor: "oneflow.Tensor",
                         dir_name: Union[str, Path]) -> None:
    """Persist ``tensor`` as a data file plus a meta-info file in ``dir_name``.

    ``dir_name`` is created if needed; an existing directory is reused.
    Inside it two files are written:

    * ``DATA_FILENAME``: the bytes of ``tensor.numpy()``;
    * ``META_INFO_FILENAME``: a text-format ``VariableMetaInfo`` message
      holding the tensor's shape and proto dtype.
    """
    os.makedirs(dir_name, exist_ok=True)

    info = variable_meta_info_pb.VariableMetaInfo()
    info.shape.dim[:] = tensor.shape
    to_proto_dtype = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype
    info.data_type = to_proto_dtype(tensor.dtype)

    target = os.path.join(dir_name, DATA_FILENAME)
    # The file is opened before the tensor is converted, as in the original
    # ordering, so a failing conversion leaves the same on-disk state.
    with open(target, "wb") as data_file:
        data_file.write(tensor.numpy().tobytes())

    info_target = os.path.join(dir_name, META_INFO_FILENAME)
    with open(info_target, "w") as info_file:
        info_file.write(text_format.MessageToString(info))
コード例 #3
0
ファイル: check_point_v2.py プロジェクト: liudyboy/oneflow
    def __init__(
        self,
        var_dir: str,
        dtype: Optional[dtype_util.dtype] = None,
        shape: Optional[Sequence[int]] = None,
    ):
        """Describe a variable stored on disk in ``var_dir``.

        ``var_dir`` must contain ``DATA_FILENAME``. Shape and dtype are read
        from ``META_INFO_FILENAME`` when that file exists. Otherwise they may
        be supplied through ``shape`` and ``dtype`` (both or neither). Whenever
        shape and dtype are known, the data file size is checked against
        ``prod(shape) * itemsize``.

        Args:
            var_dir: Directory holding the variable's files.
            dtype: OneFlow dtype; only allowed when there is no meta-info file.
            shape: Variable shape; only allowed when there is no meta-info file.

        Raises:
            AssertionError: if the data file is missing, if ``shape``/``dtype``
                are passed although a meta-info file exists, or if the data
                file size does not match the shape and dtype.
            RuntimeError: if exactly one of ``shape`` and ``dtype`` is given
                and there is no meta-info file.
        """
        data_path = os.path.join(var_dir, DATA_FILENAME)
        # Explicit checks (not bare ``assert``) keep validation active under
        # ``python -O``; AssertionError is kept for existing callers.
        if not os.path.isfile(data_path):
            raise AssertionError("{} is not a file".format(data_path))
        self.var_dir_ = var_dir
        # Always define these attributes so that a blob without meta info
        # holds ``None`` instead of lacking the attributes entirely.
        self.shape_ = None
        self.dtype_ = None

        meta_info_path = os.path.join(self.var_dir_, META_INFO_FILENAME)
        if os.path.exists(meta_info_path):
            if dtype is not None or shape is not None:
                raise AssertionError(
                    "shape and dtype must not be given because {} exists".format(
                        meta_info_path))
            meta_info = variable_meta_info_pb.VariableMetaInfo()
            with open(meta_info_path) as f:
                text_format.Parse(f.read(), meta_info)
            self.shape_ = tuple(meta_info.shape.dim)
            self.dtype_ = dtype_util.convert_proto_dtype_to_oneflow_dtype(
                meta_info.data_type)
            self.has_meta_info_ = True
        elif shape is not None and dtype is not None:
            self.shape_ = shape
            self.dtype_ = dtype
            self.has_meta_info_ = True
        elif shape is not None or dtype is not None:
            raise RuntimeError(
                "both or neither of shape and dtype should be None")
        else:
            self.has_meta_info_ = False

        if self.has_meta_info_:
            itemsize = np.dtype(
                dtype_util.convert_oneflow_dtype_to_numpy_dtype(
                    self.dtype_)).itemsize
            expected_size = np.prod(self.shape).item() * itemsize
            actual_size = os.path.getsize(data_path)
            if actual_size != expected_size:
                raise AssertionError(
                    "size of {} is {} bytes, expected {} bytes".format(
                        data_path, actual_size, expected_size))
コード例 #4
0
def SaveVarDict(
    path: str,
    var_dict: Optional[Dict[str, Union[FileBackendVariableBlob,
                                       EagerBlobTrait]]] = None,
) -> None:
    """
    Save `var_dict` to `path`

    Calls ``oneflow.sync_default_session()`` first (presumably so that
    variable values are up to date — TODO confirm). Then writes one
    sub-directory per variable under ``path``:

    * ``path/<name>/DATA_FILENAME``: concatenation of ``.tobytes()`` of the
      third item of every tuple yielded by ``_ReadSlice(var)``;
    * ``path/<name>/META_INFO_FILENAME``: text-format ``VariableMetaInfo``
      holding the variable's shape and proto dtype;
    * ``path/snapshot_done``: empty marker file, written after all variables.

    Args:
        path: Target directory. It must not be an existing file or a
            non-empty directory; it is created if missing.
        var_dict: Mapping from variable name to variable. ``None`` means
            ``GetAllVariables()``.

    Raises:
        AssertionError: if ``path`` is a file or a non-empty directory.
    """
    oneflow.sync_default_session()

    if var_dict is None:
        var_dict = GetAllVariables()

    def IsFileOrNonEmptyDir(candidate):
        if os.path.isfile(candidate):
            return True
        return os.path.isdir(candidate) and len(os.listdir(candidate)) != 0

    # Explicit check instead of a bare ``assert``: the guard must still run
    # under ``python -O``. AssertionError is kept for existing callers. The
    # message covers both rejected cases (file and non-empty directory).
    if IsFileOrNonEmptyDir(path):
        raise AssertionError(
            "{} is a file or non-empty directory!".format(path))
    os.makedirs(path, exist_ok=True)
    for name, var in var_dict.items():
        meta_info = variable_meta_info_pb.VariableMetaInfo()
        meta_info.shape.dim[:] = var.shape
        meta_info.data_type = oneflow_api.deprecated.GetProtoDtype4OfDtype(
            var.dtype)
        var_dir = os.path.join(path, name)
        param_path = os.path.join(var_dir, DATA_FILENAME)
        # No exist_ok: os.makedirs raises FileExistsError if the directory
        # is already there.
        os.makedirs(os.path.dirname(param_path))
        with open(param_path, "wb") as f:
            for _, _, data_slice in _ReadSlice(var):
                f.write(data_slice.tobytes())
        with open(os.path.join(var_dir, META_INFO_FILENAME), "w") as f:
            f.write(text_format.MessageToString(meta_info))
    # write an empty file 'snapshot_done', indicating that
    # the save process finished normally
    with open(os.path.join(path, "snapshot_done"), "w"):
        pass