コード例 #1
0
 def _get_op_blob_info(self, job_name, op_name, blob_name):
     """Return (and memoize) the static shape and dtype of an op's blob.

     The session must be OPEN or RUNNING. Results are cached in
     ``self.inferface_name2info_`` keyed by ``op_name`` alone. A cache hit is
     returned before ``job_name`` or ``blob_name`` are consulted.

     Args:
         job_name: Job that owns the op. When falsy, ``self.cur_job_name_``
             is used instead.
         op_name: Name of the op whose blob is queried (also the cache key).
         blob_name: Name of the blob on that op.

     Returns:
         dict with keys ``"shape"`` (static shape reported by the job-build
         context) and ``"dtype"`` (converted to a OneFlow dtype).

     Raises:
         ValueError: if no job name is given and no current job is set.
     """
     self._check_status(self.SessionStatus.OPEN, self.SessionStatus.RUNNING)

     cache = self.inferface_name2info_
     # NOTE(review): the key ignores job_name/blob_name; presumably only one
     # blob per interface op is ever queried -- confirm against callers.
     if op_name in cache:
         return cache[op_name]

     resolved_job = job_name or self.cur_job_name_
     if resolved_job is None:
         raise ValueError("please specify job_name")

     lbn = oneflow._oneflow_internal.JobBuildAndInferCtx_GetOpBlobLbn(
         resolved_job, op_name, blob_name
     )
     info = {
         "shape": c_api_util.JobBuildAndInferCtx_GetStaticShape(resolved_job, lbn),
         "dtype": dtype_util.convert_proto_dtype_to_oneflow_dtype(
             c_api_util.JobBuildAndInferCtx_GetDataType(resolved_job, lbn)
         ),
     }
     cache[op_name] = info
     return info
コード例 #2
0
 def __init__(
     self,
     var_dir: str,
     dtype: Optional[oneflow.dtype] = None,
     shape: Optional[Sequence[int]] = None,
 ):
     """Open a variable stored on disk under ``var_dir``.

     The raw data is expected at ``os.path.join(var_dir, DATA_FILENAME)``.
     Shape and dtype come from one of two sources:
     - a text-format ``VariableMetaInfo`` file at
       ``os.path.join(var_dir, META_INFO_FILENAME)``, if it exists; or
     - the ``shape`` and ``dtype`` arguments, which must be passed together.

     When neither source is available, ``has_meta_info_`` is False and
     ``shape_``/``dtype_`` are ``None``.

     Args:
         var_dir: Directory containing the variable files.
         dtype: OneFlow dtype of the data. Only allowed when no meta-info
             file exists, and only together with ``shape``.
         shape: Shape of the data. Only allowed when no meta-info file
             exists, and only together with ``dtype``. Stored as a tuple.

     Raises:
         FileNotFoundError: if the data file does not exist.
         RuntimeError: if exactly one of ``shape``/``dtype`` is given, if
             either is given although a meta-info file exists, or if the
             data file size does not equal ``prod(shape) * itemsize``.
     """
     data_path = os.path.join(var_dir, DATA_FILENAME)
     if not os.path.isfile(data_path):
         raise FileNotFoundError(f"variable data file not found: {data_path}")
     self.var_dir_ = var_dir
     # Defaults for the "no meta info" case, so the attributes always exist.
     self.shape_ = None
     self.dtype_ = None

     meta_info_path = os.path.join(self.var_dir_, META_INFO_FILENAME)
     self.has_meta_info_ = os.path.exists(meta_info_path)
     if self.has_meta_info_:
         # An explicit raise instead of `assert`, so the check survives `python -O`.
         if dtype is not None or shape is not None:
             raise RuntimeError(
                 "shape and dtype must not be given when meta info file "
                 f"{meta_info_path} exists")
         meta_info = variable_meta_info_pb.VariableMetaInfo()
         with open(meta_info_path, encoding="utf-8") as f:
             text_format.Parse(f.read(), meta_info)
         self.shape_ = tuple(meta_info.shape.dim)
         self.dtype_ = dtype_util.convert_proto_dtype_to_oneflow_dtype(
             meta_info.data_type)
     elif shape is not None and dtype is not None:
         # Normalize to a tuple to match the meta-info code path.
         self.shape_ = tuple(shape)
         self.dtype_ = dtype
         # The shape/dtype are now known, so treat the blob as described.
         self.has_meta_info_ = True
     elif shape is not None or dtype is not None:
         raise RuntimeError(
             "both or neither of shape and dtype should be None")

     if self.has_meta_info_:
         # Sanity-check that the raw file holds exactly prod(shape) elements.
         itemsize = np.dtype(
             dtype_util.convert_oneflow_dtype_to_numpy_dtype(
                 self.dtype_)).itemsize
         # int() also covers np.prod(()) == 1.0 for scalar shapes.
         expected_size = int(np.prod(self.shape_)) * itemsize
         actual_size = os.path.getsize(data_path)
         if actual_size != expected_size:
             raise RuntimeError(
                 f"size of {data_path} is {actual_size} bytes, expected "
                 f"{expected_size} bytes for shape {self.shape_}")
コード例 #3
0
 def dtype(self):
     """Return the OneFlow dtype of the blob referenced by ``of_blob_ptr_``.

     The raw data-type value is fetched through
     ``oneflow._oneflow_internal.Ofblob_GetDataType``. It is then mapped to a
     OneFlow dtype with ``convert_proto_dtype_to_oneflow_dtype``.
     """
     proto_dtype = oneflow._oneflow_internal.Ofblob_GetDataType(
         self.of_blob_ptr_
     )
     return convert_proto_dtype_to_oneflow_dtype(proto_dtype)