コード例 #1
0
def GetInterfaceBlobConf(job_name, lbn, blob_conf=None):
    """Describe blob ``lbn`` of job ``job_name`` as an ``InterfaceBlobConf``.

    Reads the blob's static shape, data type, producer-view split axis,
    batch axis, dynamic flag and tensor-list flag from the job
    build-and-infer context and copies them into ``blob_conf``.

    Args:
        job_name (str): job whose build-and-infer context is queried.
        lbn (str): logical blob name to describe.
        blob_conf: ``InterfaceBlobConf`` to fill in place. When omitted, a
            new empty one is created.

    Returns:
        The filled ``InterfaceBlobConf`` (the same object when one was
        passed in).

    Note:
        Shape dims are appended with ``extend``. A ``blob_conf`` that already
        carries dims keeps them in front of the new ones.
    """
    assert isinstance(job_name, str)
    assert isinstance(lbn, str)
    conf_type = interface_blob_conf_pb.InterfaceBlobConf
    if blob_conf is not None:
        assert isinstance(blob_conf, conf_type)
    else:
        blob_conf = conf_type()

    # Every query runs before any write, so a failing lookup leaves a
    # caller-supplied blob_conf unmodified.
    static_shape = c_api_util.JobBuildAndInferCtx_GetStaticShape(job_name, lbn)
    data_type = c_api_util.JobBuildAndInferCtx_GetDataType(job_name, lbn)
    producer_split_axis = (
        c_api_util.JobBuildAndInferCtx_GetSplitAxisFromProducerView(
            job_name, lbn)
    )
    batch_axis = c_api_util.JobBuildAndInferCtx_GetBatchAxis(job_name, lbn)
    dynamic = c_api_util.JobBuildAndInferCtx_IsDynamic(job_name, lbn)
    tensor_list = c_api_util.JobBuildAndInferCtx_IsTensorList(job_name, lbn)

    blob_conf.shape.dim.extend(static_shape)
    blob_conf.data_type = data_type
    # Axis fields are left untouched when the lookup returns None.
    if producer_split_axis is not None:
        blob_conf.split_axis.value = producer_split_axis
    if batch_axis is not None:
        blob_conf.batch_axis.value = batch_axis
    blob_conf.is_dynamic = dynamic
    blob_conf.is_tensor_list = tensor_list
    return blob_conf
コード例 #2
0
ファイル: inference_session.py プロジェクト: strint/oneflow
    def _make_pull_job_cb(self, output_name, user_job_name, future):
        """Build the callback that delivers a pulled output blob to ``future``.

        Args:
            output_name: name of the op whose ``"out"`` blob is pulled.
            user_job_name: name of the job that owns that op.
            future: future resolved on ``self.event_loop_`` with the pulled
                ndarray.

        Returns:
            ``pull_fn(ofblob)``. It copies ``ofblob`` with
            ``CopyToNdarray()`` and resolves ``future`` with the result via
            ``call_soon_threadsafe``. If the copy raises, ``future`` is failed
            with the same exception, so an awaiting coroutine does not hang
            forever, and the exception is re-raised to ``pull_fn``'s caller.
        """
        output_lbn = oneflow_api.JobBuildAndInferCtx_GetOpBlobLbn(
            user_job_name, output_name, "out")
        # NOTE(review): the split axis is looked up but never used by pull_fn.
        # The call is kept in case its error checking is relied upon when the
        # callback is built. Confirm whether CopyToNdarray already yields the
        # full result or whether this should concatenate along the split axis
        # like the CopyToNdarrayLists-based variants.
        c_api_util.JobBuildAndInferCtx_GetSplitAxisFromProducerView(
            user_job_name, output_lbn)

        def pull_fn(ofblob):
            try:
                ndarray = ofblob.CopyToNdarray()
            except Exception as exc:
                # Without this the future is never resolved and its awaiter
                # blocks forever; propagate the error as before afterwards.
                self.event_loop_.call_soon_threadsafe(future.set_exception,
                                                      exc)
                raise
            self.event_loop_.call_soon_threadsafe(future.set_result, ndarray)

        return pull_fn
コード例 #3
0
    def _make_pull_job_cb(self, output_name, future):
        """Build the callback that delivers a pulled interface blob to ``future``.

        Args:
            output_name: interface name, looked up in
                ``self.inferface_name2lbn_`` and
                ``self.inferface_name2job_name_``.
            future: future resolved on the loop returned by
                ``asyncio.get_event_loop()`` at the time this method runs.

        Returns:
            ``pull_fn(ofblob)``. It copies ``ofblob`` with
            ``CopyToNdarrayLists()`` and expects exactly one list. It resolves
            ``future`` with that list's only array or, when there are several,
            with their concatenation along the producer-view split axis. Any
            error fails ``future`` (so its awaiter does not hang) and is then
            re-raised.
        """
        output_lbn = self.inferface_name2lbn_[output_name]
        job_name = self.inferface_name2job_name_[output_name]
        split_axis = c_api_util.JobBuildAndInferCtx_GetSplitAxisFromProducerView(
            job_name, output_lbn
        )
        loop = asyncio.get_event_loop()

        def pull_fn(ofblob):
            try:
                ndarray_lists = ofblob.CopyToNdarrayLists()
                # Explicit checks instead of `assert` so they survive
                # `python -O`; otherwise a None split axis would make
                # np.concatenate(axis=None) silently flatten the result.
                if len(ndarray_lists) != 1:
                    raise RuntimeError(
                        "expected exactly 1 ndarray list, got {}".format(
                            len(ndarray_lists)
                        )
                    )
                ndarray_list = ndarray_lists[0]
                if len(ndarray_list) == 1:
                    pull_result = ndarray_list[0]
                elif split_axis is None:
                    raise RuntimeError(
                        "cannot concatenate {} ndarrays: blob has no split "
                        "axis".format(len(ndarray_list))
                    )
                else:
                    pull_result = np.concatenate(ndarray_list, axis=split_axis)
            except Exception as exc:
                # Without this the future is never resolved and its awaiter
                # blocks forever; propagate the error as before afterwards.
                loop.call_soon_threadsafe(future.set_exception, exc)
                raise
            loop.call_soon_threadsafe(future.set_result, pull_result)

        return pull_fn
コード例 #4
0
    def _make_pull_job_cb(self, output_name, user_job_name, future):
        """Build the callback that delivers a pulled output blob to ``future``.

        Args:
            output_name: name of the op whose ``"out"`` blob is pulled.
            user_job_name: name of the job that owns that op.
            future: future resolved on ``self.event_loop_``.

        Returns:
            ``pull_fn(ofblob)``. It copies ``ofblob`` with
            ``CopyToNdarrayLists()`` and expects exactly one list. It resolves
            ``future`` with that list's only array or, when there are several,
            with their concatenation along the producer-view split axis. Any
            error fails ``future`` (so its awaiter does not hang) and is then
            re-raised.
        """
        output_lbn = oneflow_api.JobBuildAndInferCtx_GetOpBlobLbn(
            user_job_name, output_name, "out")
        split_axis = c_api_util.JobBuildAndInferCtx_GetSplitAxisFromProducerView(
            user_job_name, output_lbn)

        def pull_fn(ofblob):
            try:
                ndarray_lists = ofblob.CopyToNdarrayLists()
                # Explicit checks instead of `assert` so they survive
                # `python -O`; otherwise a None split axis would make
                # np.concatenate(axis=None) silently flatten the result.
                if len(ndarray_lists) != 1:
                    raise RuntimeError(
                        "expected exactly 1 ndarray list, got {}".format(
                            len(ndarray_lists)))
                ndarray_list = ndarray_lists[0]
                if len(ndarray_list) == 1:
                    pull_result = ndarray_list[0]
                elif split_axis is None:
                    raise RuntimeError(
                        "cannot concatenate {} ndarrays: blob has no split "
                        "axis".format(len(ndarray_list)))
                else:
                    pull_result = np.concatenate(ndarray_list, axis=split_axis)
            except Exception as exc:
                # Without this the future is never resolved and its awaiter
                # blocks forever; propagate the error as before afterwards.
                self.event_loop_.call_soon_threadsafe(future.set_exception,
                                                      exc)
                raise
            self.event_loop_.call_soon_threadsafe(future.set_result,
                                                  pull_result)

        return pull_fn
コード例 #5
0
def GetInterfaceBlobConf(job_name, lbn, blob_conf=None):
    """Describe blob ``lbn`` of job ``job_name`` as an ``InterfaceBlobConf``.

    Reads the blob's static shape, data type, producer-view split axis and
    dynamic flag from the job build-and-infer context and copies them into
    ``blob_conf``. A known split axis is recorded as one ``SbpParallel``
    entry with ``split_parallel.axis`` set, appended to
    ``blob_conf.parallel_distribution.sbp_parallel``.

    Args:
        job_name (str): job whose build-and-infer context is queried.
        lbn (str): logical blob name to describe.
        blob_conf: ``InterfaceBlobConf`` to fill in place. When omitted, a
            new empty one is created.

    Returns:
        The filled ``InterfaceBlobConf`` (the same object when one was
        passed in).

    Note:
        Shape dims and sbp entries are appended with ``extend``. Content
        already present in a passed-in ``blob_conf`` is kept ahead of them.
    """
    assert isinstance(job_name, str)
    assert isinstance(lbn, str)
    conf_type = interface_blob_conf_pb.InterfaceBlobConf
    if blob_conf is not None:
        assert isinstance(blob_conf, conf_type)
    else:
        blob_conf = conf_type()

    # Every query runs before any write, so a failing lookup leaves a
    # caller-supplied blob_conf unmodified.
    static_shape = c_api_util.JobBuildAndInferCtx_GetStaticShape(job_name, lbn)
    data_type = c_api_util.JobBuildAndInferCtx_GetDataType(job_name, lbn)
    producer_split_axis = (
        c_api_util.JobBuildAndInferCtx_GetSplitAxisFromProducerView(
            job_name, lbn)
    )
    dynamic = c_api_util.JobBuildAndInferCtx_IsDynamic(job_name, lbn)

    blob_conf.shape.dim.extend(static_shape)
    blob_conf.data_type = data_type
    if producer_split_axis is not None:
        # Build the entry standalone so a failed assignment cannot leave a
        # half-initialized element inside blob_conf.
        split_sbp = sbp_parallel_pb.SbpParallel()
        split_sbp.split_parallel.axis = producer_split_axis
        blob_conf.parallel_distribution.sbp_parallel.extend([split_sbp])
    blob_conf.is_dynamic = dynamic
    return blob_conf
コード例 #6
0
ファイル: remote_blob.py プロジェクト: Sodu-Qinming/Oneflow
 def split_axis(self):
     """Return this blob's split axis as seen from its producer op.

     Queries the build-and-infer context of ``self.job_name_`` for
     ``self.lbn_`` and returns the C API's answer unchanged.
     """
     lookup = c_api_util.JobBuildAndInferCtx_GetSplitAxisFromProducerView
     return lookup(self.job_name_, self.lbn_)