Exemplo n.º 1
0
def GetInterfaceBlobConf(job_name, lbn, blob_conf=None):
    """Populate an ``InterfaceBlobConf`` from a job's build-and-infer context.

    Queries ``c_api_util`` for the static shape, data type, producer-view
    split axis, batch axis, dynamic flag and tensor-list flag of ``lbn`` in
    ``job_name``, then writes them into the conf.

    Args:
        job_name: name of the job whose context is queried (must be ``str``).
        lbn: blob name to look up in that context (must be ``str``).
        blob_conf: optional ``interface_blob_conf_pb.InterfaceBlobConf`` to
            fill in place; a new one is created when omitted.

    Returns:
        The populated ``InterfaceBlobConf`` (the same object when one was
        passed in).

    Note:
        Shape dims are appended with ``extend``, so a conf that already has
        dims ends up holding both the old and the new ones.
    """
    assert isinstance(job_name, str)
    assert isinstance(lbn, str)
    if blob_conf is not None:
        assert isinstance(blob_conf, interface_blob_conf_pb.InterfaceBlobConf)
    else:
        blob_conf = interface_blob_conf_pb.InterfaceBlobConf()

    # Query every property first, in a fixed order, before touching the conf.
    static_shape = c_api_util.JobBuildAndInferCtx_GetStaticShape(job_name, lbn)
    data_type = c_api_util.JobBuildAndInferCtx_GetDataType(job_name, lbn)
    producer_split_axis = (
        c_api_util.JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn)
    )
    batch_dim = c_api_util.JobBuildAndInferCtx_GetBatchAxis(job_name, lbn)
    dynamic = c_api_util.JobBuildAndInferCtx_IsDynamic(job_name, lbn)
    tensor_list = c_api_util.JobBuildAndInferCtx_IsTensorList(job_name, lbn)

    blob_conf.shape.dim.extend(static_shape)
    blob_conf.data_type = data_type
    # The axis fields are only written when the context returns a non-None
    # axis; otherwise they are left untouched.
    if producer_split_axis is not None:
        blob_conf.split_axis.value = producer_split_axis
    if batch_dim is not None:
        blob_conf.batch_axis.value = batch_dim
    blob_conf.is_dynamic = dynamic
    blob_conf.is_tensor_list = tensor_list
    return blob_conf
Exemplo n.º 2
0
 def shape(self):
     """Return the static shape of this blob from its job's build/infer context.

     When ``oneflow.scope.mirrored_view_enabled()`` is true, a warning and the
     formatted stack entry of the caller are printed first; the shape is
     returned either way.
     """
     in_mirrored_view = oneflow.scope.mirrored_view_enabled()
     if in_mirrored_view:
         warning_parts = (
             "WARNING:",
             "You access a consistent blob shape in mirrored view, there may be problems,",
             "you should add 'x = flow.cast_to_current_logical_view(x)'.",
         )
         print(*warning_parts, file=sys.stderr)
         # [-1] is this frame; [-2] is the code that asked for the shape.
         # Must be computed here (not in a helper) to keep that offset.
         caller_entry = traceback.format_stack()[-2]
         print(caller_entry)
     job, blob_lbn = self.job_name_, self.lbn_
     return c_api_util.JobBuildAndInferCtx_GetStaticShape(job, blob_lbn)
Exemplo n.º 3
0
    def _get_op_blob_info(self, job_name, op_name, blob_name):
        """Return (and cache) the static shape and dtype of an op's blob.

        Args:
            job_name: job to query; when falsy, ``self.cur_job_name_`` is used.
            op_name: name of the op that owns the blob.
            blob_name: blob name on that op, passed to
                ``oneflow_api.JobBuildAndInferCtx_GetOpBlobLbn``.

        Returns:
            dict with keys ``"shape"`` (static shape from the infer context)
            and ``"dtype"`` (converted via
            ``dtype_util.convert_proto_dtype_to_oneflow_dtype``).

        Raises:
            ValueError: if no job name is given and no current job is set.
        """
        self._check_status(self.SessionStatus.OPEN, self.SessionStatus.RUNNING)

        job_name = job_name or self.cur_job_name_
        if job_name is None:
            raise ValueError("please specify job_name")

        # Cache on the full (job, op, blob) triple. Keying on op_name alone
        # returned stale info for a second blob of the same op, or for an op
        # with the same name in a different job.
        cache_key = (job_name, op_name, blob_name)
        if cache_key in self.inferface_name2info_:
            return self.inferface_name2info_[cache_key]

        lbn = oneflow_api.JobBuildAndInferCtx_GetOpBlobLbn(job_name, op_name, blob_name)
        shape = c_api_util.JobBuildAndInferCtx_GetStaticShape(job_name, lbn)
        dtype = c_api_util.JobBuildAndInferCtx_GetDataType(job_name, lbn)
        dtype = dtype_util.convert_proto_dtype_to_oneflow_dtype(dtype)
        # TODO: other info
        info = dict(shape=shape, dtype=dtype)
        self.inferface_name2info_[cache_key] = info
        return info
Exemplo n.º 4
0
    def input_info(self, input_name):
        """Return (and cache) the static shape and dtype of a named input.

        The input's blob is looked up as ``"<input_name>/out"`` in the job
        registered for it in ``self.inferface_name2job_name_``.

        Args:
            input_name: name of the input op.

        Returns:
            dict with keys ``"shape"`` and ``"dtype"`` (dtype converted via
            ``dtype_util.convert_proto_dtype_to_oneflow_dtype``).

        Raises:
            ValueError: if no job is registered for ``input_name``.
        """
        self._check_status(self.SessionStatus.RUNNING)
        info_cache = self.inferface_name2info_
        if input_name in info_cache:
            return info_cache[input_name]

        blob_lbn = "{}/out".format(input_name)
        name_to_job = self.inferface_name2job_name_
        if input_name not in name_to_job:
            raise ValueError('can not find input with name "{}"'.format(input_name))
        owning_job = name_to_job[input_name]

        static_shape = c_api_util.JobBuildAndInferCtx_GetStaticShape(owning_job, blob_lbn)
        proto_dtype = c_api_util.JobBuildAndInferCtx_GetDataType(owning_job, blob_lbn)
        # TODO: other info
        info = {
            "shape": static_shape,
            "dtype": dtype_util.convert_proto_dtype_to_oneflow_dtype(proto_dtype),
        }
        info_cache[input_name] = info
        return info
Exemplo n.º 5
0
def GetInterfaceBlobConf(job_name, lbn, blob_conf=None):
    """Populate an ``InterfaceBlobConf`` from a job's build-and-infer context.

    Queries ``c_api_util`` for the static shape, data type, producer-view
    split axis and dynamic flag of ``lbn`` in ``job_name``, then writes them
    into the conf. A reported split axis is recorded as one split
    ``SbpParallel`` entry in ``parallel_distribution.sbp_parallel``.

    Args:
        job_name: name of the job whose context is queried (must be ``str``).
        lbn: blob name to look up in that context (must be ``str``).
        blob_conf: optional ``interface_blob_conf_pb.InterfaceBlobConf`` to
            fill in place; a new one is created when omitted.

    Returns:
        The populated ``InterfaceBlobConf`` (the same object when one was
        passed in).

    Note:
        Shape dims and the sbp entry are appended, so a pre-populated conf
        keeps its existing entries as well.
    """
    assert isinstance(job_name, str)
    assert isinstance(lbn, str)
    if blob_conf is not None:
        assert isinstance(blob_conf, interface_blob_conf_pb.InterfaceBlobConf)
    else:
        blob_conf = interface_blob_conf_pb.InterfaceBlobConf()

    # Query every property first, in a fixed order, before touching the conf.
    static_shape = c_api_util.JobBuildAndInferCtx_GetStaticShape(job_name, lbn)
    data_type = c_api_util.JobBuildAndInferCtx_GetDataType(job_name, lbn)
    producer_split_axis = (
        c_api_util.JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn)
    )
    dynamic = c_api_util.JobBuildAndInferCtx_IsDynamic(job_name, lbn)

    blob_conf.shape.dim.extend(static_shape)
    blob_conf.data_type = data_type
    if producer_split_axis is not None:
        # add() appends a new element to the repeated field and returns it.
        split_entry = blob_conf.parallel_distribution.sbp_parallel.add()
        split_entry.split_parallel.axis = producer_split_axis
    blob_conf.is_dynamic = dynamic
    return blob_conf