コード例 #1
0
ファイル: c_api_util.py プロジェクト: zhouyuegit/oneflow
def CurJobBuildAndInferCtx_SetTrainConf(train_config_proto):
    serialized_train_conf = str(
        text_format.MessageToString(train_config_proto))
    error_str = oneflow_internal.CurJobBuildAndInferCtx_SetTrainConf(
        serialized_train_conf)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #2
0
ファイル: c_api_util.py プロジェクト: zhouyuegit/oneflow
def GetOpParallelSymbolId(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    symbol_id, error_str = oneflow_internal.GetOpParallelSymbolId(
        serialized_op_conf)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return symbol_id
コード例 #3
0
ファイル: c_api_util.py プロジェクト: qianrenjian/oneflow
def CheckAndCompleteUserOpConf(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    new_op_conf, error_str = oneflow_internal.CheckAndCompleteUserOpConf(
        serialized_op_conf)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(new_op_conf, op_conf_util.OperatorConf())
コード例 #4
0
ファイル: c_api_util.py プロジェクト: qianrenjian/oneflow
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    add_and_infer = oneflow_internal.CurJobBuildAndInferCtx_AddAndInferMirroredOp
    op_attribute_str, error_str = add_and_infer(serialized_op_conf)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
コード例 #5
0
def CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair(lbi_and_uuid):
    serialized = str(text_format.MessageToString(lbi_and_uuid))
    error_str = oneflow_internal.CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair(
        serialized
    )
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #6
0
def JobBuildAndInferCtx_IsTensorList(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    ret, error_str = oneflow_internal.JobBuildAndInferCtx_IsTensorList(job_name, lbn)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return ret
コード例 #7
0
def JobBuildAndInferCtx_GetDataType(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    dtype, erro_str = oneflow_internal.JobBuildAndInferCtx_GetDataType(job_name, lbn)
    error = text_format.Parse(erro_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return int(dtype)
コード例 #8
0
def InferOpConf(op_conf_proto, upstream_signature):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    serialized_upstream_sig = str(text_format.MessageToString(upstream_signature))
    op_attribute_str, error_str = oneflow_internal.InferOpConf(
        serialized_op_conf, serialized_upstream_sig,
    )
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
コード例 #9
0
def JobBuildAndInferCtx_MirroredBlobDisableBoxing(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    # TODO(hanbinbin): this api dose not exist and will be del after confirmed
    ret, error_str = oneflow_internal.JobBuildAndInferCtx_MirroredBlobDisableBoxing(
        job_name, lbn)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return ret
コード例 #10
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def GetMachine2DeviceIdListOFRecordFromParallelConf(parallel_conf):
    serialized_parallel_conf = str(parallel_conf)
    (
        ofrecord,
        error_str,
    ) = oneflow_internal.GetMachine2DeviceIdListOFRecordFromParallelConf(
        serialized_parallel_conf)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(ofrecord, record_util.OFRecord())
コード例 #11
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    GetParallelConf = (
        oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    parallel_conf, error_str = GetParallelConf(job_name, lbn)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(parallel_conf, placement_pb.ParallelConf())
コード例 #12
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def JobBuildAndInferCtx_GetBatchAxis(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    batch_axis_str, error_str = oneflow_internal.JobBuildAndInferCtx_GetBatchAxis(
        job_name, lbn)
    batch_axis = text_format.Parse(batch_axis_str, dtype_util.OptInt64())
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    if batch_axis.HasField("value"):
        return batch_axis.value
    return None
コード例 #13
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def JobBuildAndInferCtx_MirroredBlobGetStaticShape(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    get_shape = (
        oneflow_internal.
        JobBuildAndInferCtx_MirroredBlobGetSerializedIdListAsStaticShape)
    axis_str, error_str = get_shape(job_name, lbn)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    int_list = text_format.Parse(axis_str, record_util.Int64List())
    return tuple(map(int, int_list.value))
コード例 #14
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def JobBuildAndInferCtx_MirroredBlobGetSubLbi(job_name, lbn, index):
    job_name = str(job_name)
    lbn = str(lbn)
    (
        ret,
        error_str,
    ) = oneflow_internal.JobBuildAndInferCtx_MirroredBlobGetSerializedSubLbi(
        job_name, lbn, index)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(ret, logical_blob_id_util.LogicalBlobId())
コード例 #15
0
def JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    (
        split_axis_str,
        error_str,
    ) = oneflow_internal.JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn)
    split_axis = text_format.Parse(split_axis_str, dtype_util.OptInt64())
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    if split_axis.HasField("value"):
        return split_axis.value
    return None
コード例 #16
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    GetParallelConf = (
        oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    parallel_conf, error_str = GetParallelConf(job_name, lbn)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    parallel_conf = text_format.Parse(parallel_conf, placement_pb.ParallelConf())
    # TODO(oyy) change temporary transformation after python code migrated into cpp code
    parallel_conf_cfg = placement_cfg.ParallelConf()
    parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
    for device_name in parallel_conf.device_name:
        parallel_conf_cfg.add_device_name(device_name)

    return parallel_conf_cfg
コード例 #17
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def IsOpTypeNameCpuSupportOnly(op_type_name):
    ret, error_str = oneflow_internal.IsOpTypeNameCpuSupportOnly(op_type_name)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return ret
コード例 #18
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def RegisterWatcherOnlyOnce(watcher):
    error_str = oneflow_internal.RegisterWatcherOnlyOnce(watcher)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #19
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def EnvResource():
    resource, error_str = oneflow_internal.EnvResource()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(resource, resource_util.Resource())
コード例 #20
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def GetJobSet():
    job_set, error_str = oneflow_internal.GetSerializedJobSet()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(job_set, job_set_pb.JobSet())
コード例 #21
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def NewPhysicalSymbolId():
    object_id, error_str = oneflow_internal.NewPhysicalSymbolId()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return object_id
コード例 #22
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def GetScopeConfigDef():
    scope_config_def, error_str = oneflow_internal.GetScopeConfigDef()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(scope_config_def, ConfigDef())
コード例 #23
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def CurrentMachineId():
    machine_id, error_str = oneflow_internal.CurrentMachineId()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return machine_id
コード例 #24
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def StopLazyGlobalSession():
    error_str = oneflow_internal.StopLazyGlobalSession()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #25
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def GetOpAttributes():
    op_attributes, error_str = oneflow_internal.GetSerializedOpAttributes()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(op_attributes, op_attribute_pb.OpAttributeList())
コード例 #26
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def LaunchJob(job_instance):
    error_str = oneflow_internal.LaunchJob(job_instance)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #27
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def GetStructureGraph():
    structure_graph, error_str = oneflow_internal.GetSerializedStructureGraph()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return structure_graph
コード例 #28
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def JobBuildAndInferCtx_Open(job_name):
    job_name = str(job_name)
    error_str = oneflow_internal.JobBuildAndInferCtx_Open(job_name)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #29
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def DestroyEnv():
    error_str = oneflow_internal.DestroyEnv()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #30
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def CurJobBuildAndInferCtx_CheckJob():
    error_str = oneflow_internal.CurJobBuildAndInferCtx_CheckJob()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)