def _get_op_blob_info(self, job_name, op_name, blob_name):
    """Return (and memoize) the static shape and dtype of one op's blob.

    The session must be OPEN or RUNNING. Results are cached in
    ``self.inferface_name2info_`` under ``op_name``. When ``job_name`` is
    falsy, ``self.cur_job_name_`` is used instead.

    NOTE(review): the cache key is ``op_name`` alone. A later call for the
    same op with a different ``blob_name`` or ``job_name`` therefore returns
    the first cached entry. Confirm that callers only ever query one blob
    per op name.

    Returns:
        dict with keys ``"shape"`` and ``"dtype"``.

    Raises:
        ValueError: if no job name is given and no current job is set.
    """
    self._check_status(self.SessionStatus.OPEN, self.SessionStatus.RUNNING)
    if op_name in self.inferface_name2info_:
        return self.inferface_name2info_[op_name]

    resolved_job = job_name or self.cur_job_name_
    if resolved_job is None:
        raise ValueError("please specify job_name")

    blob_lbn = oneflow_api.JobBuildAndInferCtx_GetOpBlobLbn(
        resolved_job, op_name, blob_name
    )
    static_shape = c_api_util.JobBuildAndInferCtx_GetStaticShape(resolved_job, blob_lbn)
    proto_dtype = c_api_util.JobBuildAndInferCtx_GetDataType(resolved_job, blob_lbn)
    # TODO: other info
    blob_info = {
        "shape": static_shape,
        "dtype": dtype_util.convert_proto_dtype_to_oneflow_dtype(proto_dtype),
    }
    self.inferface_name2info_[op_name] = blob_info
    return blob_info
Beispiel #2
0
    def input_info(self, input_name):
        """Return the shape/dtype info of the interface input ``input_name``.

        The session must be RUNNING. The result is memoized in
        ``self.inferface_name2info_``. On a cache miss it is computed from the
        ``"<input_name>/out"`` blob of the job that
        ``self.inferface_name2job_name_`` records for this input.

        Returns:
            dict with keys ``"shape"`` and ``"dtype"``.

        Raises:
            ValueError: if ``input_name`` is not a known interface input.
        """
        self._check_status(self.SessionStatus.RUNNING)
        if input_name in self.inferface_name2info_:
            return self.inferface_name2info_[input_name]

        if input_name not in self.inferface_name2job_name_:
            raise ValueError('can not find input with name "{}"'.format(input_name))
        owner_job = self.inferface_name2job_name_[input_name]

        out_lbn = "{}/out".format(input_name)
        static_shape = c_api_util.JobBuildAndInferCtx_GetStaticShape(owner_job, out_lbn)
        proto_dtype = c_api_util.JobBuildAndInferCtx_GetDataType(owner_job, out_lbn)
        # TODO: other info
        info = {
            "shape": static_shape,
            "dtype": dtype_util.convert_proto_dtype_to_oneflow_dtype(proto_dtype),
        }
        self.inferface_name2info_[input_name] = info
        return info
Beispiel #3
0
    def __init__(
        self,
        var_dir: str,
        dtype: Optional[dtype_util.dtype] = None,
        shape: Optional[Sequence[int]] = None,
    ):
        """Open a file-backed variable stored under ``var_dir``.

        ``var_dir`` must contain the data file ``DATA_FILENAME``. Metadata is
        resolved in one of three ways:

        - If a ``META_INFO_FILENAME`` file exists, it is parsed as a
          text-format ``VariableMetaInfo`` message. Shape and dtype come from
          it, and ``dtype``/``shape`` must both be ``None``.
        - Otherwise, if both ``dtype`` and ``shape`` are given, they are used
          as the metadata.
        - If neither is given, no metadata is recorded (``shape_``/``dtype_``
          stay unset).

        When metadata is known, the data file's size is checked against
        ``prod(shape) * itemsize``.

        Args:
            var_dir: Directory holding the variable's files.
            dtype: Element dtype, used only when there is no meta-info file.
            shape: Tensor shape, used only when there is no meta-info file.

        Raises:
            RuntimeError: if exactly one of ``shape``/``dtype`` is given and
                there is no meta-info file.
            AssertionError: if the data file is missing, if ``shape``/``dtype``
                are passed alongside a meta-info file, or if the data file
                size does not match the metadata.
        """
        data_path = os.path.join(var_dir, DATA_FILENAME)
        assert os.path.isfile(data_path)
        self.var_dir_ = var_dir
        meta_info_path = os.path.join(self.var_dir_, META_INFO_FILENAME)
        if os.path.exists(meta_info_path):
            meta_info = variable_meta_info_pb.VariableMetaInfo()
            with open(meta_info_path) as f:
                text_format.Parse(f.read(), meta_info)
            self.has_meta_info_ = True
        else:
            self.has_meta_info_ = False

        if self.has_meta_info_:
            # On-disk metadata is authoritative; explicit overrides are rejected.
            assert dtype is None and shape is None
            self.shape_ = tuple(meta_info.shape.dim)
            self.dtype_ = dtype_util.convert_proto_dtype_to_oneflow_dtype(
                meta_info.data_type
            )
        elif shape is not None and dtype is not None:
            self.shape_ = shape
            self.dtype_ = dtype
            # Caller-supplied metadata counts as known metadata from here on.
            self.has_meta_info_ = True
        elif shape is not None or dtype is not None:
            raise RuntimeError("both or neither of shape and dtype should be None")
        # Otherwise: no metadata at all; shape_/dtype_ are intentionally unset.

        if self.has_meta_info_:
            itemsize = np.dtype(
                dtype_util.convert_oneflow_dtype_to_numpy_dtype(self.dtype_)
            ).itemsize
            # Read the attribute assigned above (not a ``shape`` property that
            # may not exist on the class). Force an int64 element count so it
            # cannot overflow where NumPy's default integer is 32-bit.
            num_elements = np.prod(self.shape_, dtype=np.int64).item()
            assert os.path.getsize(data_path) == num_elements * itemsize
Beispiel #4
0
 def dtype(self):
     """Return the dtype of this blob's body.

     The value is ``self.blob_desc_.body.data_type`` passed through
     ``convert_proto_dtype_to_oneflow_dtype``.
     """
     body_desc = self.blob_desc_.body
     return convert_proto_dtype_to_oneflow_dtype(body_desc.data_type)
Beispiel #5
0
def dtype(self):
    """Return ``self.get_dtype()`` converted by ``convert_proto_dtype_to_oneflow_dtype``."""
    proto_dtype = self.get_dtype()
    return convert_proto_dtype_to_oneflow_dtype(proto_dtype)
Beispiel #6
0
 def dtype(self):
     """Return the dtype of the underlying OfBlob.

     The raw data type comes from ``oneflow_api.Ofblob_GetDataType`` on
     ``self.of_blob_ptr_``. It is then converted with
     ``convert_proto_dtype_to_oneflow_dtype``.
     """
     raw_data_type = oneflow_api.Ofblob_GetDataType(self.of_blob_ptr_)
     return convert_proto_dtype_to_oneflow_dtype(raw_data_type)
Beispiel #7
0
 def dtype(self):
     """Return the dtype of this mirrored blob.

     The data type is looked up with
     ``c_api_util.JobBuildAndInferCtx_MirroredBlobGetDataType`` for
     ``(self.job_name_, self.lbn_)``. It is then converted with
     ``convert_proto_dtype_to_oneflow_dtype``.
     """
     job_name, lbn = self.job_name_, self.lbn_
     proto_dtype = c_api_util.JobBuildAndInferCtx_MirroredBlobGetDataType(
         job_name, lbn
     )
     return convert_proto_dtype_to_oneflow_dtype(proto_dtype)
Beispiel #8
0
def dtype(self):
    """Return ``self.get_dtype()`` converted by ``convert_proto_dtype_to_oneflow_dtype``.

    The converted value is asserted to be a subclass of ``dtype_util.dtype``.
    """
    converted = convert_proto_dtype_to_oneflow_dtype(self.get_dtype())
    # Sanity check on the converter's result type (skipped under ``python -O``).
    assert issubclass(converted, dtype_util.dtype)
    return converted
Beispiel #9
0
def dtype(self):
    """Return ``self.get_dtype()`` converted by ``convert_proto_dtype_to_oneflow_dtype``.

    The converted value is asserted to be an instance of ``oneflow.dtype``.
    """
    converted = convert_proto_dtype_to_oneflow_dtype(self.get_dtype())
    # Sanity check on the converter's result type (skipped under ``python -O``).
    assert isinstance(converted, oneflow.dtype)
    return converted
Beispiel #10
0
 def dtype(self):
     """Return the dtype of this logical blob.

     The data type is looked up with
     ``c_api_util.JobBuildAndInferCtx_GetDataType`` for
     ``(self.job_name_, self.logical_blob_name)``. It is then converted with
     ``convert_proto_dtype_to_oneflow_dtype``.
     """
     proto_dtype = c_api_util.JobBuildAndInferCtx_GetDataType(
         self.job_name_, self.logical_blob_name
     )
     return convert_proto_dtype_to_oneflow_dtype(proto_dtype)