def JobBuildAndInferCtx_IsTensorList(job_name, lbn):
    """Ask the backend whether blob ``lbn`` of job ``job_name`` is a tensor list.

    Args:
        job_name: Job name; coerced to ``str`` before the call.
        lbn: Logical blob name; coerced to ``str`` before the call.

    Returns:
        The value returned by ``oneflow_api.JobBuildAndInferCtx_IsTensorList``.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    result, err = oneflow_api.JobBuildAndInferCtx_IsTensorList(
        str(job_name), str(lbn))
    if not err.has_error_type():
        return result
    raise JobBuildAndInferCfgError(err)
def JobBuildAndInferCtx_DisableBoxing(job_name, lbn):
    """Query ``oneflow_api.JobBuildAndInferCtx_DisableBoxing`` for a blob.

    Args:
        job_name: Job name; coerced to ``str`` before the call.
        lbn: Logical blob name; coerced to ``str`` before the call.

    Returns:
        The value the backend returns alongside its error object.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    name, blob = str(job_name), str(lbn)
    disable_boxing, err = oneflow_api.JobBuildAndInferCtx_DisableBoxing(
        name, blob)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return disable_boxing
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    """Add a mirrored op to the current job build context and infer it.

    Args:
        op_conf_proto: Operator configuration message, sent to the backend
            in protobuf text format.

    Returns:
        An ``OpAttribute`` parsed from the backend's text-format reply.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    attr_text, err = (
        oneflow_api.CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_text))
    if not err.has_error_type():
        return text_format.Parse(attr_text, op_attribute_pb.OpAttribute())
    raise JobBuildAndInferCfgError(err)
def CheckAndCompleteUserOpConf(op_conf_proto):
    """Have the backend validate and complete a user op configuration.

    Args:
        op_conf_proto: Operator configuration message, sent to the backend
            in protobuf text format.

    Returns:
        A new ``OperatorConf`` parsed from the backend's text-format reply.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    completed_text, err = oneflow_api.CheckAndCompleteUserOpConf(op_conf_text)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return text_format.Parse(completed_text, op_conf_util.OperatorConf())
def JobBuildAndInferCtx_GetDataType(job_name, lbn):
    """Return the backend's data-type value for blob ``lbn``, as an ``int``.

    Args:
        job_name: Job name; coerced to ``str`` before the call.
        lbn: Logical blob name; coerced to ``str`` before the call.

    Returns:
        The value from ``oneflow_api.JobBuildAndInferCtx_GetDataType``
        converted with ``int()``.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    dtype_value, err = oneflow_api.JobBuildAndInferCtx_GetDataType(
        str(job_name), str(lbn))
    if not err.has_error_type():
        return int(dtype_value)
    raise JobBuildAndInferCfgError(err)
def CurJobBuildAndInferCtx_SetTrainConf(train_config_proto):
    """Set the training configuration on the current job build context.

    Args:
        train_config_proto: Training configuration message, sent to the
            backend in protobuf text format.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    train_conf_text = str(text_format.MessageToString(train_config_proto))
    err = oneflow_api.CurJobBuildAndInferCtx_SetTrainConf(train_conf_text)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
def JobBuildAndInferCtx_MirroredBlobIsDynamic(job_name, lbn):
    """Ask the backend whether mirrored blob ``lbn`` of job ``job_name`` is dynamic.

    Args:
        job_name: Job name; coerced to ``str`` before the call.
        lbn: Logical blob name; coerced to ``str`` before the call.

    Returns:
        The value returned by
        ``oneflow_api.JobBuildAndInferCtx_MirroredBlobIsDynamic``.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    query = oneflow_api.JobBuildAndInferCtx_MirroredBlobIsDynamic
    is_dynamic, err = query(str(job_name), str(lbn))
    if not err.has_error_type():
        return is_dynamic
    raise JobBuildAndInferCfgError(err)
def IsInterfaceOpConf(op_conf):
    """Ask the backend whether ``op_conf``'s op type is an interface op.

    The op type is the field currently set in the ``op_type`` oneof. Its
    field number in the ``OperatorConf`` descriptor is what gets sent.

    Args:
        op_conf: An ``OperatorConf``-like message with an ``op_type`` oneof.

    Returns:
        The value returned by ``oneflow_api.IsInterfaceOpTypeCase``.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    oneof_name = op_conf.WhichOneof("op_type")
    descriptor_fields = op_conf_util.OperatorConf.DESCRIPTOR.fields_by_name
    case_number = descriptor_fields[oneof_name].number
    is_interface, err = oneflow_api.IsInterfaceOpTypeCase(case_number)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return is_interface
def GetMachine2DeviceIdListOFRecordFromParallelConf(parallel_conf):
    """Convert a parallel conf into an ``OFRecord`` using the backend.

    Args:
        parallel_conf: Parallel configuration; serialized with ``str()``
            before being passed to the backend.

    Returns:
        An ``OFRecord`` parsed from the backend's text-format reply.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    conf_text = str(parallel_conf)
    convert = oneflow_api.GetMachine2DeviceIdListOFRecordFromParallelConf
    ofrecord_text, err = convert(conf_text)
    if not err.has_error_type():
        return text_format.Parse(ofrecord_text, record_util.OFRecord())
    raise JobBuildAndInferCfgError(err)
def JobBuildAndInferCtx_MirroredBlobGetSubLbi(job_name, lbn, index):
    """Return the sub logical blob id of a mirrored blob selected by ``index``.

    Args:
        job_name: Job name; coerced to ``str`` before the call.
        lbn: Logical blob name; coerced to ``str`` before the call.
        index: Passed to the backend unchanged.

    Returns:
        A ``LogicalBlobId`` parsed from the backend's text-format reply.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    get_sub_lbi = (
        oneflow_api.JobBuildAndInferCtx_MirroredBlobGetSerializedSubLbi)
    lbi_text, err = get_sub_lbi(str(job_name), str(lbn), index)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return text_format.Parse(lbi_text, logical_blob_id_util.LogicalBlobId())
def JobBuildAndInferCtx_GetBatchAxis(job_name, lbn):
    """Return the batch axis of blob ``lbn`` in job ``job_name``, or ``None``.

    The backend replies with an ``OptInt64`` in protobuf text format. If its
    ``value`` field is unset, the blob has no batch axis and ``None`` is
    returned.

    Args:
        job_name: Job name; coerced to ``str`` before the call.
        lbn: Logical blob name; coerced to ``str`` before the call.

    Returns:
        The batch axis value, or ``None`` if the ``value`` field is unset.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    batch_axis_str, error = oneflow_api.JobBuildAndInferCtx_GetBatchAxis(
        job_name, lbn)
    # Check the backend error BEFORE parsing. On failure the reply string is
    # not guaranteed to be a valid OptInt64, and a text_format.ParseError
    # would otherwise hide the real backend error from the caller.
    if error.has_error_type():
        raise JobBuildAndInferCfgError(error)
    batch_axis = text_format.Parse(batch_axis_str, dtype_util.OptInt64())
    if batch_axis.HasField("value"):
        return batch_axis.value
    return None
def InferOpConf(op_conf_proto, upstream_signature):
    """Infer an op's attribute from its conf and its upstream signature.

    Args:
        op_conf_proto: Operator configuration message, sent in text format.
        upstream_signature: Upstream signature message, sent in text format.

    Returns:
        An ``OpAttribute`` parsed from the backend's text-format reply.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    upstream_text = str(text_format.MessageToString(upstream_signature))
    attr_text, err = oneflow_api.InferOpConf(op_conf_text, upstream_text)
    if not err.has_error_type():
        return text_format.Parse(attr_text, op_attribute_pb.OpAttribute())
    raise JobBuildAndInferCfgError(err)
def JobBuildAndInferCtx_MirroredBlobGetStaticShape(job_name, lbn):
    """Return the static shape of a mirrored blob as a tuple of ints.

    The backend replies with an ``Int64List`` in protobuf text format. Each
    entry of its ``value`` field becomes one element of the result.

    Args:
        job_name: Job name; coerced to ``str`` before the call.
        lbn: Logical blob name; coerced to ``str`` before the call.

    Returns:
        A ``tuple`` of ``int``.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    query = (oneflow_api.
             JobBuildAndInferCtx_MirroredBlobGetSerializedIdListAsStaticShape)
    dims_text, err = query(str(job_name), str(lbn))
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    dims = text_format.Parse(dims_text, record_util.Int64List()).value
    return tuple(int(dim) for dim in dims)
def JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn):
    """Return the producer-view split axis of blob ``lbn``, or ``None``.

    The backend replies with an ``OptInt64`` in protobuf text format. If its
    ``value`` field is unset, ``None`` is returned.

    Args:
        job_name: Job name; coerced to ``str`` before the call.
        lbn: Logical blob name; coerced to ``str`` before the call.

    Returns:
        The split axis value, or ``None`` if the ``value`` field is unset.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    (
        split_axis_str,
        error,
    ) = oneflow_api.JobBuildAndInferCtx_GetSplitAxisFromProducerView(
        job_name, lbn)
    # Check the backend error BEFORE parsing. On failure the reply string is
    # not guaranteed to be a valid OptInt64, and a text_format.ParseError
    # would otherwise hide the real backend error from the caller.
    if error.has_error_type():
        raise JobBuildAndInferCfgError(error)
    split_axis = text_format.Parse(split_axis_str, dtype_util.OptInt64())
    if split_axis.HasField("value"):
        return split_axis.value
    return None
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-view parallel conf of blob ``lbn`` as a cfg object.

    The backend replies with a ``ParallelConf`` in protobuf text format. Only
    its ``device_tag`` and ``device_name`` entries are copied into a
    ``placement_cfg.ParallelConf``.

    Args:
        job_name: Job name; coerced to ``str`` before the call.
        lbn: Logical blob name; coerced to ``str`` before the call.

    Returns:
        A ``placement_cfg.ParallelConf``.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    query = (oneflow_api.
             JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView)
    conf_text, err = query(str(job_name), str(lbn))
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    conf_proto = text_format.Parse(conf_text, placement_pb.ParallelConf())
    # TODO(oyy) change temporary transformation after python code migrated into cpp code
    conf_cfg = placement_cfg.ParallelConf()
    conf_cfg.set_device_tag(conf_proto.device_tag)
    for name in conf_proto.device_name:
        conf_cfg.add_device_name(name)
    return conf_cfg
def GetJobSet():
    """Fetch the serialized job set from the backend and parse it.

    Returns:
        A ``JobSet`` parsed from the backend's text-format reply.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    job_set_text, err = oneflow_api.GetSerializedJobSet()
    if not err.has_error_type():
        return text_format.Parse(job_set_text, job_set_pb.JobSet())
    raise JobBuildAndInferCfgError(err)
def GetOpAttributes():
    """Fetch the serialized op attributes from the backend and parse them.

    Returns:
        An ``OpAttributeList`` parsed from the backend's text-format reply.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    attrs_text, err = oneflow_api.GetSerializedOpAttributes()
    if not err.has_error_type():
        return text_format.Parse(attrs_text,
                                 op_attribute_pb.OpAttributeList())
    raise JobBuildAndInferCfgError(err)
def IsOpTypeNameCpuSupportOnly(op_type_name):
    """Ask the backend whether op type ``op_type_name`` supports only CPU.

    Args:
        op_type_name: Op type name, passed to the backend unchanged.

    Returns:
        The value returned by ``oneflow_api.IsOpTypeNameCpuSupportOnly``.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    cpu_only, err = oneflow_api.IsOpTypeNameCpuSupportOnly(op_type_name)
    if not err.has_error_type():
        return cpu_only
    raise JobBuildAndInferCfgError(err)
def NewPhysicalSymbolId():
    """Request a new physical symbol id from the backend.

    Returns:
        The id returned by ``oneflow_api.NewPhysicalSymbolId``.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    symbol_id, err = oneflow_api.NewPhysicalSymbolId()
    if not err.has_error_type():
        return symbol_id
    raise JobBuildAndInferCfgError(err)
def NewLogicalObjectId():
    """Request a new logical object id from the backend.

    Returns:
        The id returned by ``oneflow_api.NewLogicalObjectId``.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    logical_id, err = oneflow_api.NewLogicalObjectId()
    if not err.has_error_type():
        return logical_id
    raise JobBuildAndInferCfgError(err)
def EnvResource():
    """Fetch the environment resource description from the backend.

    Returns:
        A ``Resource`` parsed from the backend's text-format reply.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    resource_text, err = oneflow_api.EnvResource()
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return text_format.Parse(resource_text, resource_util.Resource())
def GetScopeConfigDef():
    """Fetch the scope config definition from the backend.

    Returns:
        A ``ConfigDef`` parsed from the backend's text-format reply.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    config_def_text, err = oneflow_api.GetScopeConfigDef()
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return text_format.Parse(config_def_text, ConfigDef())
def IsEnvInited():
    """Ask the backend whether the environment has been initialized.

    Returns:
        The value returned by ``oneflow_api.IsEnvInited``.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    inited, err = oneflow_api.IsEnvInited()
    if not err.has_error_type():
        return inited
    raise JobBuildAndInferCfgError(err)
def GetStructureGraph():
    """Fetch the serialized structure graph from the backend.

    Returns:
        The backend's reply from ``oneflow_api.GetSerializedStructureGraph``,
        unparsed.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    graph, err = oneflow_api.GetSerializedStructureGraph()
    if not err.has_error_type():
        return graph
    raise JobBuildAndInferCfgError(err)
def LoadLibraryNow(lib_path):
    """Ask the backend to load the library at ``lib_path``.

    Args:
        lib_path: Library path, passed to the backend unchanged.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    err = oneflow_api.LoadLibraryNow(lib_path)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
def RunPhysicalInstruction(vm_instruction_list, eager_symbol_list):
    """Run physical VM instructions through ``oneflow_api.vm``.

    Args:
        vm_instruction_list: Passed to the backend unchanged.
        eager_symbol_list: Protobuf message, sent in text format.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    symbols_text = str(text_format.MessageToString(eager_symbol_list))
    err = oneflow_api.vm.RunPhysicalInstruction(vm_instruction_list,
                                                symbols_text)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
def EnableEagerEnvironment(enable_eager_execution):
    """Forward the eager-execution switch to the backend.

    Args:
        enable_eager_execution: Passed to the backend unchanged.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    err = oneflow_api.EnableEagerEnvironment(enable_eager_execution)
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
def CurrentMachineId():
    """Return the current machine id reported by the backend.

    Returns:
        The value returned by ``oneflow_api.CurrentMachineId``.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    this_machine, err = oneflow_api.CurrentMachineId()
    if not err.has_error_type():
        return this_machine
    raise JobBuildAndInferCfgError(err)
def InitEnv(env_proto):
    """Initialize the backend environment from an ``EnvProto``.

    Args:
        env_proto: Must be exactly an ``env_pb2.EnvProto`` (not a subclass).
            It is sent to the backend in protobuf text format.

    Raises:
        AssertionError: If ``env_proto`` has the wrong type. NOTE: this check
            is an ``assert`` and is skipped when Python runs with ``-O``.
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    assert type(env_proto) is env_pb2.EnvProto
    err = oneflow_api.InitEnv(text_format.MessageToString(env_proto))
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
def GetFunctionConfigDef():
    """Fetch the function config definition from the backend.

    Returns:
        A ``ConfigDef`` parsed from the backend's text-format reply.

    Raises:
        JobBuildAndInferCfgError: If the backend reports an error.
    """
    config_def_text, err = oneflow_api.GetFunctionConfigDef()
    if err.has_error_type():
        raise JobBuildAndInferCfgError(err)
    return text_format.Parse(config_def_text, ConfigDef())