コード例 #1
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def InitEnv(env_proto):
    assert type(env_proto) is env_pb2.EnvProto
    env_proto_str = text_format.MessageToString(env_proto)
    error_str = oneflow_internal.InitEnv(env_proto_str)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #2
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def CurJobBuildAndInferCtx_HasJobConf():
    has_job_conf, error_str = oneflow_internal.CurJobBuildAndInferCtx_HasJobConf(
    )
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return has_job_conf
コード例 #3
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def CurJobBuildAndInferCtx_AddLossLogicalBlobName(lbn):
    lbn = str(lbn)
    error_str = oneflow_internal.CurJobBuildAndInferCtx_AddLossLogicalBlobName(
        lbn)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #4
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def CurJobBuildAndInferCtx_SetJobConf(job_config_proto):
    serialized_job_conf = str(text_format.MessageToString(job_config_proto))
    error_str = oneflow_internal.CurJobBuildAndInferCtx_SetJobConf(
        serialized_job_conf)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #5
0
def DeviceType4DeviceTag(device_tag):
    device_tag = str(device_tag)
    device_type, error_str = oneflow_internal.DeviceType4DeviceTag(device_tag)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return device_type
コード例 #6
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def GetUserOpAttrType(op_type_name, attr_name):
    attr_type, error_str = oneflow_internal.GetUserOpAttrType(
        op_type_name, attr_name)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return attr_type
コード例 #7
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def GetInterUserJobInfo():
    inter_user_job_info, error_str = oneflow_internal.GetSerializedInterUserJobInfo(
    )
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(inter_user_job_info, InterUserJobInfo())
コード例 #8
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def JobBuildAndInferCtx_GetCurrentJobName():
    job_name, error_str = oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName(
    )
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return job_name
コード例 #9
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair(lbi_and_uuid):
    serialized = str(text_format.MessageToString(lbi_and_uuid))
    error_str = oneflow_internal.CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair(
        serialized)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #10
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def InitLazyGlobalSession(config_proto):
    assert type(config_proto) is job_set_pb.ConfigProto
    config_proto_str = text_format.MessageToString(config_proto)
    error_str = oneflow_internal.InitLazyGlobalSession(config_proto_str)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #11
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def RunLogicalInstruction(vm_instruction_list, eager_symbol_list):
    symbols = str(text_format.MessageToString(eager_symbol_list))
    error_str = oneflow_api.vm.RunLogicalInstruction(vm_instruction_list,
                                                     symbols)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #12
0
def GetOpParallelSymbolId(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    symbol_id, error_str = oneflow_internal.GetOpParallelSymbolId(serialized_op_conf)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return symbol_id
コード例 #13
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    add_and_infer = oneflow_internal.CurJobBuildAndInferCtx_AddAndInferMirroredOp
    op_attribute_str, error_str = add_and_infer(serialized_op_conf)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
コード例 #14
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def CheckAndCompleteUserOpConf(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    new_op_conf, error_str = oneflow_internal.CheckAndCompleteUserOpConf(
        serialized_op_conf)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(new_op_conf, op_conf_util.OperatorConf())
コード例 #15
0
def JobBuildAndInferCtx_IsTensorList(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    ret, error_str = oneflow_internal.JobBuildAndInferCtx_IsTensorList(job_name, lbn)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return ret
コード例 #16
0
def JobBuildAndInferCtx_GetDataType(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    dtype, erro_str = oneflow_internal.JobBuildAndInferCtx_GetDataType(job_name, lbn)
    error = text_format.Parse(erro_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return int(dtype)
コード例 #17
0
def InferOpConf(op_conf_proto, upstream_signature):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    serialized_upstream_sig = str(text_format.MessageToString(upstream_signature))
    op_attribute_str, error_str = oneflow_internal.InferOpConf(
        serialized_op_conf, serialized_upstream_sig,
    )
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
コード例 #18
0
def JobBuildAndInferCtx_MirroredBlobDisableBoxing(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    # TODO(hanbinbin): this api dose not exist and will be del after confirmed
    ret, error_str = oneflow_internal.JobBuildAndInferCtx_MirroredBlobDisableBoxing(
        job_name, lbn)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return ret
コード例 #19
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def GetMachine2DeviceIdListOFRecordFromParallelConf(parallel_conf):
    serialized_parallel_conf = str(parallel_conf)
    (
        ofrecord,
        error_str,
    ) = oneflow_internal.GetMachine2DeviceIdListOFRecordFromParallelConf(
        serialized_parallel_conf)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(ofrecord, record_util.OFRecord())
コード例 #20
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    GetParallelConf = (
        oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    parallel_conf, error_str = GetParallelConf(job_name, lbn)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(parallel_conf, placement_pb.ParallelConf())
コード例 #21
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def JobBuildAndInferCtx_GetBatchAxis(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    batch_axis_str, error_str = oneflow_internal.JobBuildAndInferCtx_GetBatchAxis(
        job_name, lbn)
    batch_axis = text_format.Parse(batch_axis_str, dtype_util.OptInt64())
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    if batch_axis.HasField("value"):
        return batch_axis.value
    return None
コード例 #22
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def JobBuildAndInferCtx_MirroredBlobGetStaticShape(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    get_shape = (
        oneflow_internal.
        JobBuildAndInferCtx_MirroredBlobGetSerializedIdListAsStaticShape)
    axis_str, error_str = get_shape(job_name, lbn)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    int_list = text_format.Parse(axis_str, record_util.Int64List())
    return tuple(map(int, int_list.value))
コード例 #23
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def JobBuildAndInferCtx_MirroredBlobGetSubLbi(job_name, lbn, index):
    job_name = str(job_name)
    lbn = str(lbn)
    (
        ret,
        error_str,
    ) = oneflow_internal.JobBuildAndInferCtx_MirroredBlobGetSerializedSubLbi(
        job_name, lbn, index)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(ret, logical_blob_id_util.LogicalBlobId())
コード例 #24
0
def JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    (
        split_axis_str,
        error_str,
    ) = oneflow_internal.JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn)
    split_axis = text_format.Parse(split_axis_str, dtype_util.OptInt64())
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    if split_axis.HasField("value"):
        return split_axis.value
    return None
コード例 #25
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    job_name = str(job_name)
    lbn = str(lbn)
    GetParallelConf = (
        oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    parallel_conf, error_str = GetParallelConf(job_name, lbn)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    parallel_conf = text_format.Parse(parallel_conf, placement_pb.ParallelConf())
    # TODO(oyy) change temporary transformation after python code migrated into cpp code
    parallel_conf_cfg = placement_cfg.ParallelConf()
    parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
    for device_name in parallel_conf.device_name:
        parallel_conf_cfg.add_device_name(device_name)

    return parallel_conf_cfg
コード例 #26
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def DestroyEnv():
    error_str = oneflow_internal.DestroyEnv()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
コード例 #27
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def EnvResource():
    resource, error_str = oneflow_internal.EnvResource()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(resource, resource_util.Resource())
コード例 #28
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def GetStructureGraph():
    structure_graph, error_str = oneflow_internal.GetSerializedStructureGraph()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return structure_graph
コード例 #29
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def GetJobSet():
    job_set, error_str = oneflow_internal.GetSerializedJobSet()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(job_set, job_set_pb.JobSet())
コード例 #30
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def GetOpAttributes():
    op_attributes, error_str = oneflow_internal.GetSerializedOpAttributes()
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(op_attributes, op_attribute_pb.OpAttributeList())