Пример #1
0
def GetInterfaceBlobConf(job_name, lbn, blob_conf=None):
    """Populate an InterfaceBlobConf from the JobBuildAndInferCtx queries.

    Args:
        job_name: Job name; must be a ``str``.
        lbn: Blob name handed to every query together with ``job_name``;
            must be a ``str``.
        blob_conf: Optional InterfaceBlobConf to fill in. A fresh message is
            created when it is ``None``.

    Returns:
        The filled InterfaceBlobConf (the very object passed as ``blob_conf``
        when one was supplied).

    Raises:
        AssertionError: If an argument has the wrong type.
    """
    assert isinstance(job_name, str)
    assert isinstance(lbn, str)
    if blob_conf is not None:
        assert isinstance(blob_conf, interface_blob_conf_pb.InterfaceBlobConf)
    else:
        blob_conf = interface_blob_conf_pb.InterfaceBlobConf()

    # All queries run before the message is touched, so a failing query
    # leaves a caller-supplied blob_conf unmodified.
    query_key = (job_name, lbn)
    static_shape = c_api_util.JobBuildAndInferCtx_GetStaticShape(*query_key)
    data_type = c_api_util.JobBuildAndInferCtx_GetDataType(*query_key)
    producer_split_axis = (
        c_api_util.JobBuildAndInferCtx_GetSplitAxisFromProducerView(*query_key)
    )
    blob_batch_axis = c_api_util.JobBuildAndInferCtx_GetBatchAxis(*query_key)
    dynamic_flag = c_api_util.JobBuildAndInferCtx_IsDynamic(*query_key)
    tensor_list_flag = c_api_util.JobBuildAndInferCtx_IsTensorList(*query_key)

    blob_conf.shape.dim.extend(static_shape)
    blob_conf.data_type = data_type
    # Both axes are optional: they are only written when the query returned one.
    if producer_split_axis is not None:
        blob_conf.split_axis.value = producer_split_axis
    if blob_batch_axis is not None:
        blob_conf.batch_axis.value = blob_batch_axis
    blob_conf.is_dynamic = dynamic_flag
    blob_conf.is_tensor_list = tensor_list_flag
    return blob_conf
Пример #2
0
 def ToInterfaceBlobConf(self):
     """Return a new InterfaceBlobConf filled from this object's attributes.

     Shape, data type, the dynamic flag and the tensor-list flag are copied
     from ``self``; the batch/split axes are filled in by
     ``self.SetBatchAxisAndSplitAxis``.
     """
     conf = inter_face_blob_conf_util.InterfaceBlobConf()
     conf.shape.dim.extend(self.shape_)
     conf.data_type = self.dtype_.oneflow_proto_dtype
     conf.is_dynamic = self.is_dynamic
     conf.is_tensor_list = self.is_tensor_list
     # Axis handling is delegated to the instance.
     self.SetBatchAxisAndSplitAxis(conf)
     return conf
Пример #3
0
 def ToInterfaceBlobConf(self):
     """Return a new InterfaceBlobConf filled from this object's attributes.

     Shape and the dynamic flag come from ``self``; the data type is
     ``self.dtype_`` converted by ``GetProtoDtype4OfDtype``. The split axis
     is always set to 0.
     """
     conf = inter_face_blob_conf_util.InterfaceBlobConf()
     conf.shape.dim.extend(self.shape_)
     proto_dtype = oneflow_api.deprecated.GetProtoDtype4OfDtype(self.dtype_)
     conf.data_type = proto_dtype
     conf.is_dynamic = self.is_dynamic
     # NOTE(chengcheng): batch_axis was removed, so split_axis is always set
     #     to 0 to be safe. Setting sbp will be supported in the future, or
     #     this will be deleted in multi-client.
     conf.split_axis.value = 0
     return conf
Пример #4
0
 def ToInterfaceBlobConf(self):
     """Return a new InterfaceBlobConf filled from this object's attributes.

     Shape, data type, the dynamic flag and the tensor-list flag are copied
     from ``self``. The split axis is always set to 0.
     """
     conf = inter_face_blob_conf_util.InterfaceBlobConf()
     conf.shape.dim.extend(self.shape_)
     conf.data_type = self.dtype_.oneflow_proto_dtype
     conf.is_dynamic = self.is_dynamic
     conf.is_tensor_list = self.is_tensor_list
     # NOTE(chengcheng): batch_axis was removed, so split_axis is always set
     #     to 0 to be safe. Setting sbp will be supported in the future, or
     #     this will be deleted in multi-client.
     conf.split_axis.value = 0
     return conf
Пример #5
0
 def ToInterfaceBlobConf(self):
     """Return a new InterfaceBlobConf filled from this object's attributes.

     Shape and the dynamic flag come from ``self``; the data type is
     ``self.dtype_`` converted by ``GetProtoDtype4OfDtype``. A single
     SbpParallel entry with split axis 0 is appended to ``nd_sbp``.
     """
     conf = inter_face_blob_conf_util.InterfaceBlobConf()
     conf.shape.dim.extend(self.shape_)
     proto_dtype = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
         self.dtype_
     )
     conf.data_type = proto_dtype
     conf.is_dynamic = self.is_dynamic
     # The only sbp entry is a split along axis 0.
     split_on_axis0 = sbp_parallel_pb.SbpParallel()
     split_on_axis0.split_parallel.axis = 0
     conf.nd_sbp.sbp_parallel.extend([split_on_axis0])
     return conf
Пример #6
0
 def ToInterfaceBlobConf(self):
     """Return a new InterfaceBlobConf filled from this object's attributes.

     Shape and the dynamic flag come from ``self``; the data type is
     ``self.dtype_`` converted by ``GetProtoDtype4OfDtype``. A single
     SbpParallel entry with split axis 0 is appended to
     ``parallel_distribution``.
     """
     conf = inter_face_blob_conf_util.InterfaceBlobConf()
     conf.shape.dim.extend(self.shape_)
     proto_dtype = oneflow_api.deprecated.GetProtoDtype4OfDtype(self.dtype_)
     conf.data_type = proto_dtype
     conf.is_dynamic = self.is_dynamic
     # NOTE(chengcheng): batch_axis was removed, so the split axis is always
     #     set to 0 to be safe. Setting sbp will be supported in the future,
     #     or this will be deleted in multi-client.
     split_on_axis0 = sbp_parallel_pb.SbpParallel()
     split_on_axis0.split_parallel.axis = 0
     conf.parallel_distribution.sbp_parallel.extend([split_on_axis0])
     return conf
Пример #7
0
def _GenModelIOPathInputOpConfAndRetLbi():
    """Create the ``model_io_path_input`` input op and the id of its output.

    Returns:
        A ``(op_conf, lbi)`` tuple. ``op_conf`` is an OperatorConf on the
        ``cpu`` device tag whose input blob is a dynamic int8 blob with one
        dimension of 65536. ``lbi`` is a LogicalBlobId naming that op and its
        ``out`` blob.
    """
    input_op = op_conf_util.OperatorConf()
    input_op.name = "model_io_path_input"
    input_op.device_tag = "cpu"
    input_op.input_conf.out = "out"

    # Describe the blob separately, then copy it into the op's input_conf.
    path_blob = inter_face_blob_conf_util.InterfaceBlobConf()
    path_blob.shape.dim.append(65536)
    int8_proto = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
        flow.int8
    )
    path_blob.data_type = int8_proto
    path_blob.is_dynamic = True
    input_op.input_conf.blob_conf.CopyFrom(path_blob)

    out_lbi = logical_blob_id_util.LogicalBlobId()
    out_lbi.op_name = input_op.name
    out_lbi.blob_name = input_op.input_conf.out
    return input_op, out_lbi
Пример #8
0
def _GenModelIOPathInputOpConfAndRetLbi():
    """Create the ``model_io_path_input`` input op and the id of its output.

    Returns:
        A ``(op_conf, lbi)`` tuple. ``op_conf`` is an OperatorConf on the
        ``cpu`` device tag whose input blob is a dynamic int8 blob with one
        dimension of 65536 and batch axis 0. ``lbi`` is a LogicalBlobId
        naming that op and its ``out`` blob.
    """
    input_op = op_conf_util.OperatorConf()
    input_op.name = "model_io_path_input"
    input_op.device_tag = "cpu"
    input_op.input_conf.out = "out"

    # Describe the blob separately, then copy it into the op's input_conf.
    path_blob = inter_face_blob_conf_util.InterfaceBlobConf()
    path_blob.shape.dim.append(65536)
    int8_proto = dtype_util.int8.oneflow_proto_dtype
    path_blob.data_type = int8_proto
    path_blob.batch_axis.value = 0
    path_blob.is_dynamic = True
    input_op.input_conf.blob_conf.CopyFrom(path_blob)

    out_lbi = logical_blob_id_util.LogicalBlobId()
    out_lbi.op_name = input_op.name
    out_lbi.blob_name = input_op.input_conf.out
    return (input_op, out_lbi)
Пример #9
0
def GetInterfaceBlobConf(job_name, lbn, blob_conf=None):
    """Populate an InterfaceBlobConf from the JobBuildAndInferCtx queries.

    Args:
        job_name: Job name; must be a ``str``.
        lbn: Blob name handed to every query together with ``job_name``;
            must be a ``str``.
        blob_conf: Optional InterfaceBlobConf to fill in. A fresh message is
            created when it is ``None``.

    Returns:
        The filled InterfaceBlobConf (the very object passed as ``blob_conf``
        when one was supplied).

    Raises:
        AssertionError: If an argument has the wrong type.
    """
    assert isinstance(job_name, str)
    assert isinstance(lbn, str)
    if blob_conf is not None:
        assert isinstance(blob_conf, interface_blob_conf_pb.InterfaceBlobConf)
    else:
        blob_conf = interface_blob_conf_pb.InterfaceBlobConf()

    # All queries run before the message is touched, so a failing query
    # leaves a caller-supplied blob_conf unmodified.
    query_key = (job_name, lbn)
    static_shape = c_api_util.JobBuildAndInferCtx_GetStaticShape(*query_key)
    data_type = c_api_util.JobBuildAndInferCtx_GetDataType(*query_key)
    producer_split_axis = (
        c_api_util.JobBuildAndInferCtx_GetSplitAxisFromProducerView(*query_key)
    )
    dynamic_flag = c_api_util.JobBuildAndInferCtx_IsDynamic(*query_key)

    blob_conf.shape.dim.extend(static_shape)
    blob_conf.data_type = data_type
    # An sbp entry is only appended when the query returned a split axis.
    if producer_split_axis is not None:
        split_entry = sbp_parallel_pb.SbpParallel()
        split_entry.split_parallel.axis = producer_split_axis
        blob_conf.nd_sbp.sbp_parallel.extend([split_entry])
    blob_conf.is_dynamic = dynamic_flag
    return blob_conf