コード例 #1
0
def JobBuildAndInferCtx_IsTensorList(job_name, lbn):
    """Ask the backend whether blob ``lbn`` of job ``job_name`` is a tensor list.

    Both arguments are coerced to ``str`` before calling
    ``oneflow_api.JobBuildAndInferCtx_IsTensorList``.

    Returns:
        The value produced by the backend call, unchanged.

    Raises:
        JobBuildAndInferCfgError: if the returned error object reports an
            error type.
    """
    name, blob = str(job_name), str(lbn)
    is_tensor_list, err = oneflow_api.JobBuildAndInferCtx_IsTensorList(
        name, blob)
    if not err.has_error_type():
        return is_tensor_list
    raise JobBuildAndInferCfgError(err)
コード例 #2
0
def JobBuildAndInferCtx_DisableBoxing(job_name, lbn):
    """Call ``oneflow_api.JobBuildAndInferCtx_DisableBoxing`` for one blob.

    ``job_name`` and ``lbn`` are coerced to ``str`` first. Whether the backend
    call queries or changes boxing state is not visible from here.

    Returns:
        The value produced by the backend call, unchanged.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    name = str(job_name)
    blob = str(lbn)
    query = oneflow_api.JobBuildAndInferCtx_DisableBoxing
    outcome, err = query(name, blob)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return outcome
コード例 #3
0
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    """Hand a mirrored op conf to the current job context for add-and-infer.

    ``op_conf_proto`` is serialized with ``text_format.MessageToString`` and
    passed to ``oneflow_api.CurJobBuildAndInferCtx_AddAndInferMirroredOp``.

    Returns:
        An ``op_attribute_pb.OpAttribute`` parsed from the backend's
        text-format output.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    attr_text, err = (
        oneflow_api.CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_text))
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return text_format.Parse(attr_text, op_attribute_pb.OpAttribute())
コード例 #4
0
def CheckAndCompleteUserOpConf(op_conf_proto):
    """Send a user op conf to the backend for checking and completion.

    ``op_conf_proto`` is only read: it is serialized to text format and a new
    message is built from the backend's reply.

    Returns:
        A fresh ``op_conf_util.OperatorConf`` parsed from the backend output.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    request_text = str(text_format.MessageToString(op_conf_proto))
    completed_text, err = oneflow_api.CheckAndCompleteUserOpConf(request_text)
    if not err.has_error_type():
        return text_format.Parse(completed_text, op_conf_util.OperatorConf())
    raise JobBuildAndInferCfgError(err)
コード例 #5
0
def JobBuildAndInferCtx_GetDataType(job_name, lbn):
    """Return the backend-reported data type of blob ``lbn`` as an ``int``.

    ``job_name`` and ``lbn`` are coerced to ``str`` before the call.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    name, blob = str(job_name), str(lbn)
    raw_dtype, err = oneflow_api.JobBuildAndInferCtx_GetDataType(name, blob)
    if not err.has_error_type():
        return int(raw_dtype)
    raise JobBuildAndInferCfgError(err)
コード例 #6
0
def CurJobBuildAndInferCtx_SetTrainConf(train_config_proto):
    """Pass a train config to the current job build-and-infer context.

    The proto is serialized with ``text_format.MessageToString`` and handed to
    ``oneflow_api.CurJobBuildAndInferCtx_SetTrainConf``. Returns ``None``.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    train_conf_text = str(text_format.MessageToString(train_config_proto))
    err = oneflow_api.CurJobBuildAndInferCtx_SetTrainConf(train_conf_text)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
コード例 #7
0
def JobBuildAndInferCtx_MirroredBlobIsDynamic(job_name, lbn):
    """Ask the backend whether mirrored blob ``lbn`` of ``job_name`` is dynamic.

    Both arguments are coerced to ``str`` before the call.

    Returns:
        The value produced by
        ``oneflow_api.JobBuildAndInferCtx_MirroredBlobIsDynamic``, unchanged.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    name, blob = str(job_name), str(lbn)
    query = oneflow_api.JobBuildAndInferCtx_MirroredBlobIsDynamic
    is_dynamic, err = query(name, blob)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return is_dynamic
コード例 #8
0
def IsInterfaceOpConf(op_conf):
    """Ask the backend whether ``op_conf``'s active ``op_type`` is an interface op.

    The name of the set ``op_type`` oneof field is mapped to its field number
    in ``OperatorConf``'s descriptor, and that number is passed to
    ``oneflow_api.IsInterfaceOpTypeCase``.

    NOTE(review): if no ``op_type`` field is set, ``WhichOneof`` returns
    ``None`` and the descriptor lookup raises ``KeyError``.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    active_field = op_conf.WhichOneof("op_type")
    fields = op_conf_util.OperatorConf.DESCRIPTOR.fields_by_name
    case_number = fields[active_field].number
    is_interface, err = oneflow_api.IsInterfaceOpTypeCase(case_number)
    if not err.has_error_type():
        return is_interface
    raise JobBuildAndInferCfgError(err)
コード例 #9
0
def GetMachine2DeviceIdListOFRecordFromParallelConf(parallel_conf):
    """Build the machine-to-device-id-list ``OFRecord`` for a parallel conf.

    ``parallel_conf`` is converted with ``str()`` before being sent to the
    backend. Presumably ``str()`` yields its text-format form; verify this
    for the types callers pass.

    Returns:
        A ``record_util.OFRecord`` parsed from the backend's text output.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    conf_text = str(parallel_conf)
    convert = oneflow_api.GetMachine2DeviceIdListOFRecordFromParallelConf
    record_text, err = convert(conf_text)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return text_format.Parse(record_text, record_util.OFRecord())
コード例 #10
0
def JobBuildAndInferCtx_MirroredBlobGetSubLbi(job_name, lbn, index):
    """Return the ``index``-th sub logical blob id of mirrored blob ``lbn``.

    ``job_name`` and ``lbn`` are coerced to ``str``; ``index`` is passed
    through unchanged.

    Returns:
        A ``logical_blob_id_util.LogicalBlobId`` parsed from the backend's
        text output.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    name, blob = str(job_name), str(lbn)
    fetch = oneflow_api.JobBuildAndInferCtx_MirroredBlobGetSerializedSubLbi
    lbi_text, err = fetch(name, blob, index)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return text_format.Parse(lbi_text, logical_blob_id_util.LogicalBlobId())
コード例 #11
0
def JobBuildAndInferCtx_GetBatchAxis(job_name, lbn):
    """Return the batch axis of blob ``lbn`` in job ``job_name``, or ``None``.

    The backend returns a text-format ``dtype_util.OptInt64``. Its ``value``
    is returned when that field is set; otherwise ``None`` is returned.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    batch_axis_str, error = oneflow_api.JobBuildAndInferCtx_GetBatchAxis(
        job_name, lbn)
    # Check the error before parsing. Nothing guarantees the payload is valid
    # text-format OptInt64 when an error is reported, and a parse failure
    # would otherwise mask the real JobBuildAndInferCfgError.
    if error.has_error_type():
        raise JobBuildAndInferCfgError(error)
    batch_axis = text_format.Parse(batch_axis_str, dtype_util.OptInt64())
    return batch_axis.value if batch_axis.HasField("value") else None
コード例 #12
0
def InferOpConf(op_conf_proto, upstream_signature):
    """Infer an op's attribute from its conf and upstream signature.

    Both protos are serialized with ``text_format.MessageToString``, in that
    order, and passed to ``oneflow_api.InferOpConf``.

    Returns:
        An ``op_attribute_pb.OpAttribute`` parsed from the backend's output.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    upstream_text = str(text_format.MessageToString(upstream_signature))
    attr_text, err = oneflow_api.InferOpConf(op_conf_text, upstream_text)
    if not err.has_error_type():
        return text_format.Parse(attr_text, op_attribute_pb.OpAttribute())
    raise JobBuildAndInferCfgError(err)
コード例 #13
0
def JobBuildAndInferCtx_MirroredBlobGetStaticShape(job_name, lbn):
    """Return the static shape of mirrored blob ``lbn`` as a tuple of ints.

    The backend returns a text-format ``record_util.Int64List``; each entry
    of its ``value`` field becomes one tuple element.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    name, blob = str(job_name), str(lbn)
    fetch = (oneflow_api.
             JobBuildAndInferCtx_MirroredBlobGetSerializedIdListAsStaticShape)
    dims_text, err = fetch(name, blob)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    dims = text_format.Parse(dims_text, record_util.Int64List()).value
    return tuple(int(dim) for dim in dims)
コード例 #14
0
def JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn):
    """Return the producer-view split axis of blob ``lbn``, or ``None``.

    The backend returns a text-format ``dtype_util.OptInt64``. Its ``value``
    is returned when that field is set; otherwise ``None`` is returned.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    (
        split_axis_str,
        error,
    ) = oneflow_api.JobBuildAndInferCtx_GetSplitAxisFromProducerView(
        job_name, lbn)
    # Check the error before parsing. Nothing guarantees the payload is valid
    # text-format OptInt64 when an error is reported, and a parse failure
    # would otherwise mask the real JobBuildAndInferCfgError.
    if error.has_error_type():
        raise JobBuildAndInferCfgError(error)
    split_axis = text_format.Parse(split_axis_str, dtype_util.OptInt64())
    return split_axis.value if split_axis.HasField("value") else None
コード例 #15
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-view parallel conf of blob ``lbn`` as a cfg object.

    The backend's text output is parsed into a ``placement_pb.ParallelConf``.
    Its ``device_tag`` and every ``device_name`` entry are then copied into a
    new ``placement_cfg.ParallelConf``, which is returned.

    NOTE(review): only ``device_tag`` and ``device_name`` are copied; any
    other ``ParallelConf`` fields are dropped. Confirm none are needed.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    name, blob = str(job_name), str(lbn)
    fetch = (oneflow_api.
             JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView)
    conf_text, err = fetch(name, blob)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    conf_proto = text_format.Parse(conf_text, placement_pb.ParallelConf())
    # TODO(oyy) change temporary transformation after python code migrated into cpp code
    conf_cfg = placement_cfg.ParallelConf()
    conf_cfg.set_device_tag(conf_proto.device_tag)
    for device in conf_proto.device_name:
        conf_cfg.add_device_name(device)
    return conf_cfg
コード例 #16
0
def GetJobSet():
    """Fetch the serialized job set from the backend and parse it.

    Returns:
        A ``job_set_pb.JobSet`` parsed from the backend's text output.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    job_set_text, err = oneflow_api.GetSerializedJobSet()
    if not err.has_error_type():
        return text_format.Parse(job_set_text, job_set_pb.JobSet())
    raise JobBuildAndInferCfgError(err)
コード例 #17
0
def GetOpAttributes():
    """Fetch the serialized op attributes from the backend and parse them.

    Returns:
        An ``op_attribute_pb.OpAttributeList`` parsed from the backend output.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    attrs_text, err = oneflow_api.GetSerializedOpAttributes()
    if not err.has_error_type():
        return text_format.Parse(attrs_text, op_attribute_pb.OpAttributeList())
    raise JobBuildAndInferCfgError(err)
コード例 #18
0
def IsOpTypeNameCpuSupportOnly(op_type_name):
    """Ask the backend whether op type ``op_type_name`` is CPU-only.

    ``op_type_name`` is passed through unchanged (no ``str()`` coercion).

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    cpu_only, err = oneflow_api.IsOpTypeNameCpuSupportOnly(op_type_name)
    if not err.has_error_type():
        return cpu_only
    raise JobBuildAndInferCfgError(err)
コード例 #19
0
def NewPhysicalSymbolId():
    """Request a new physical symbol id from the backend.

    Returns:
        The id produced by ``oneflow_api.NewPhysicalSymbolId``, unchanged.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    symbol_id, err = oneflow_api.NewPhysicalSymbolId()
    if not err.has_error_type():
        return symbol_id
    raise JobBuildAndInferCfgError(err)
コード例 #20
0
def NewLogicalObjectId():
    """Request a new logical object id from the backend.

    Returns:
        The id produced by ``oneflow_api.NewLogicalObjectId``, unchanged.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    logical_id, err = oneflow_api.NewLogicalObjectId()
    if not err.has_error_type():
        return logical_id
    raise JobBuildAndInferCfgError(err)
コード例 #21
0
def EnvResource():
    """Fetch the environment resource description from the backend.

    Returns:
        A ``resource_util.Resource`` parsed from the backend's text output.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    resource_text, err = oneflow_api.EnvResource()
    if not err.has_error_type():
        return text_format.Parse(resource_text, resource_util.Resource())
    raise JobBuildAndInferCfgError(err)
コード例 #22
0
def GetScopeConfigDef():
    """Fetch the scope config definition from the backend.

    Returns:
        A ``ConfigDef`` parsed from the backend's text output.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    config_def_text, err = oneflow_api.GetScopeConfigDef()
    if not err.has_error_type():
        return text_format.Parse(config_def_text, ConfigDef())
    raise JobBuildAndInferCfgError(err)
コード例 #23
0
def IsEnvInited():
    """Ask the backend whether the environment has been initialized.

    Returns:
        The value produced by ``oneflow_api.IsEnvInited``, unchanged.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    inited, err = oneflow_api.IsEnvInited()
    if not err.has_error_type():
        return inited
    raise JobBuildAndInferCfgError(err)
コード例 #24
0
def GetStructureGraph():
    """Return the backend's serialized structure graph as-is (not parsed).

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    graph_payload, err = oneflow_api.GetSerializedStructureGraph()
    if not err.has_error_type():
        return graph_payload
    raise JobBuildAndInferCfgError(err)
コード例 #25
0
def LoadLibraryNow(lib_path):
    """Ask the backend to load the library at ``lib_path``.

    ``lib_path`` is passed through unchanged. Returns ``None``.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    err = oneflow_api.LoadLibraryNow(lib_path)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
コード例 #26
0
def RunPhysicalInstruction(vm_instruction_list, eager_symbol_list):
    """Run physical VM instructions through ``oneflow_api.vm``.

    ``eager_symbol_list`` is serialized with ``text_format.MessageToString``.
    ``vm_instruction_list`` is passed through unchanged. Returns ``None``.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    symbols_text = str(text_format.MessageToString(eager_symbol_list))
    run = oneflow_api.vm.RunPhysicalInstruction
    err = run(vm_instruction_list, symbols_text)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
コード例 #27
0
def EnableEagerEnvironment(enable_eager_execution):
    """Forward ``enable_eager_execution`` to ``oneflow_api.EnableEagerEnvironment``.

    Returns ``None``.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    err = oneflow_api.EnableEagerEnvironment(enable_eager_execution)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
コード例 #28
0
def CurrentMachineId():
    """Return the current machine id as reported by the backend.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    this_machine, err = oneflow_api.CurrentMachineId()
    if not err.has_error_type():
        return this_machine
    raise JobBuildAndInferCfgError(err)
コード例 #29
0
def InitEnv(env_proto):
    """Initialize the environment from an ``env_pb2.EnvProto``.

    The proto is serialized with ``text_format.MessageToString`` and passed
    to ``oneflow_api.InitEnv``. Returns ``None``.

    NOTE(review): the exact-type check is an ``assert``, so it is skipped
    when Python runs with ``-O``.

    Raises:
        AssertionError: if ``env_proto`` is not exactly ``env_pb2.EnvProto``
            (assertions enabled).
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    assert type(env_proto) is env_pb2.EnvProto
    env_text = text_format.MessageToString(env_proto)
    err = oneflow_api.InitEnv(env_text)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
コード例 #30
0
def GetFunctionConfigDef():
    """Fetch the function config definition from the backend.

    Returns:
        A ``ConfigDef`` parsed from the backend's text output.

    Raises:
        JobBuildAndInferCfgError: if the backend reports an error.
    """
    config_def_text, err = oneflow_api.GetFunctionConfigDef()
    if not err.has_error_type():
        return text_format.Parse(config_def_text, ConfigDef())
    raise JobBuildAndInferCfgError(err)