def CurJobBuildAndInferCtx_SetTrainConf(train_config_proto):
    """Pass a training configuration to the native SetTrainConf call.

    ``train_config_proto`` is rendered to protobuf text format and handed to
    ``oneflow_internal.CurJobBuildAndInferCtx_SetTrainConf``.

    Raises:
        JobBuildAndInferError: if the ErrorProto returned by the native call
            has its ``error_type`` field set.
    """
    train_conf_text = str(text_format.MessageToString(train_config_proto))
    status_text = oneflow_internal.CurJobBuildAndInferCtx_SetTrainConf(
        train_conf_text
    )
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
def GetOpParallelSymbolId(op_conf_proto):
    """Return the symbol id reported by the native GetOpParallelSymbolId call.

    ``op_conf_proto`` is serialized to protobuf text format before being
    passed to ``oneflow_internal.GetOpParallelSymbolId``.

    Returns:
        The first element of the pair returned by the native call, unchanged.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    parallel_symbol_id, status_text = oneflow_internal.GetOpParallelSymbolId(
        op_conf_text
    )
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return parallel_symbol_id
def CheckAndCompleteUserOpConf(op_conf_proto):
    """Run the native CheckAndCompleteUserOpConf call on an operator conf.

    ``op_conf_proto`` is serialized to protobuf text format and passed to
    ``oneflow_internal.CheckAndCompleteUserOpConf``.

    Returns:
        An ``op_conf_util.OperatorConf`` parsed from the text returned by the
        native call.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    completed_text, status_text = oneflow_internal.CheckAndCompleteUserOpConf(
        op_conf_text
    )
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    completed_conf = op_conf_util.OperatorConf()
    return text_format.Parse(completed_text, completed_conf)
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    """Invoke the native AddAndInferMirroredOp call with an operator conf.

    ``op_conf_proto`` is serialized to protobuf text format and passed to
    ``oneflow_internal.CurJobBuildAndInferCtx_AddAndInferMirroredOp``.

    Returns:
        An ``op_attribute_pb.OpAttribute`` parsed from the text returned by
        the native call.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    (
        attribute_text,
        status_text,
    ) = oneflow_internal.CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_text)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return text_format.Parse(attribute_text, op_attribute_pb.OpAttribute())
def CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair(lbi_and_uuid):
    """Pass an lbi/uuid pair message to the matching native call.

    ``lbi_and_uuid`` is serialized to protobuf text format and handed to
    ``oneflow_internal.CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    pair_text = str(text_format.MessageToString(lbi_and_uuid))
    add_pair = oneflow_internal.CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair
    status_text = add_pair(pair_text)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
def JobBuildAndInferCtx_IsTensorList(job_name, lbn):
    """Return the result of the native IsTensorList query.

    Both arguments are converted with ``str`` before being passed to
    ``oneflow_internal.JobBuildAndInferCtx_IsTensorList``.

    Returns:
        The first element of the pair returned by the native call, unchanged.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    job_name_text = str(job_name)
    lbn_text = str(lbn)
    is_tensor_list, status_text = oneflow_internal.JobBuildAndInferCtx_IsTensorList(
        job_name_text, lbn_text
    )
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return is_tensor_list
def JobBuildAndInferCtx_GetDataType(job_name, lbn):
    """Return the data type reported by the native GetDataType call as an int.

    Both arguments are converted with ``str`` before being passed to
    ``oneflow_internal.JobBuildAndInferCtx_GetDataType``.

    Returns:
        ``int(dtype)`` where ``dtype`` is the first element of the pair
        returned by the native call.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    job_name_text = str(job_name)
    lbn_text = str(lbn)
    raw_dtype, status_text = oneflow_internal.JobBuildAndInferCtx_GetDataType(
        job_name_text, lbn_text
    )
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return int(raw_dtype)
def InferOpConf(op_conf_proto, upstream_signature):
    """Run the native InferOpConf call and return the parsed op attribute.

    Both protos are serialized to protobuf text format and passed, in order,
    to ``oneflow_internal.InferOpConf``.

    Returns:
        An ``op_attribute_pb.OpAttribute`` parsed from the text returned by
        the native call.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    upstream_sig_text = str(text_format.MessageToString(upstream_signature))
    attribute_text, status_text = oneflow_internal.InferOpConf(
        op_conf_text, upstream_sig_text
    )
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return text_format.Parse(attribute_text, op_attribute_pb.OpAttribute())
def JobBuildAndInferCtx_MirroredBlobDisableBoxing(job_name, lbn):
    """Return the result of the native MirroredBlobDisableBoxing query.

    Both arguments are converted with ``str`` before being passed to
    ``oneflow_internal.JobBuildAndInferCtx_MirroredBlobDisableBoxing``.

    Returns:
        The first element of the pair returned by the native call, unchanged.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    job_name_text = str(job_name)
    lbn_text = str(lbn)
    # TODO(hanbinbin): this API does not exist and will be deleted once confirmed
    disable_boxing = oneflow_internal.JobBuildAndInferCtx_MirroredBlobDisableBoxing
    result, status_text = disable_boxing(job_name_text, lbn_text)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return result
def GetMachine2DeviceIdListOFRecordFromParallelConf(parallel_conf):
    """Return the OFRecord produced by the matching native call.

    ``parallel_conf`` is converted with ``str`` and passed to
    ``oneflow_internal.GetMachine2DeviceIdListOFRecordFromParallelConf``.

    Returns:
        A ``record_util.OFRecord`` parsed from the text returned by the
        native call.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    parallel_conf_text = str(parallel_conf)
    get_ofrecord = oneflow_internal.GetMachine2DeviceIdListOFRecordFromParallelConf
    ofrecord_text, status_text = get_ofrecord(parallel_conf_text)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return text_format.Parse(ofrecord_text, record_util.OFRecord())
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the ParallelConf proto reported by the native producer-view call.

    Both arguments are converted with ``str`` and passed to
    ``oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView``.

    NOTE(review): a function with this exact name is defined again later in
    this module; that later definition rebinds the name at import time, so
    callers do not reach this version. Confirm which one is intended.

    Returns:
        A ``placement_pb.ParallelConf`` parsed from the text returned by the
        native call.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    job_name_text = str(job_name)
    lbn_text = str(lbn)
    fetch_parallel_conf = (
        oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    parallel_conf_text, status_text = fetch_parallel_conf(job_name_text, lbn_text)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return text_format.Parse(parallel_conf_text, placement_pb.ParallelConf())
def JobBuildAndInferCtx_GetBatchAxis(job_name, lbn):
    """Return the batch axis reported by the native GetBatchAxis call.

    Both arguments are converted with ``str`` and passed to
    ``oneflow_internal.JobBuildAndInferCtx_GetBatchAxis``.

    Returns:
        The ``value`` of the returned ``dtype_util.OptInt64`` when that field
        is set, otherwise ``None``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    batch_axis_str, error_str = oneflow_internal.JobBuildAndInferCtx_GetBatchAxis(
        job_name, lbn
    )
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    # Parse the payload only after the error check, as the sibling wrappers
    # do: otherwise a payload that fails to parse would raise a text_format
    # error and hide the error reported by the native call.
    batch_axis = text_format.Parse(batch_axis_str, dtype_util.OptInt64())
    if batch_axis.HasField("value"):
        return batch_axis.value
    return None
def JobBuildAndInferCtx_MirroredBlobGetStaticShape(job_name, lbn):
    """Return the static shape reported by the native call as a tuple of ints.

    Both arguments are converted with ``str`` and passed to
    ``oneflow_internal.JobBuildAndInferCtx_MirroredBlobGetSerializedIdListAsStaticShape``.

    Returns:
        A tuple holding ``int(v)`` for every entry of the ``value`` field of
        the ``record_util.Int64List`` parsed from the native result.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    job_name_text = str(job_name)
    lbn_text = str(lbn)
    fetch_shape = (
        oneflow_internal.JobBuildAndInferCtx_MirroredBlobGetSerializedIdListAsStaticShape
    )
    shape_text, status_text = fetch_shape(job_name_text, lbn_text)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    dims = text_format.Parse(shape_text, record_util.Int64List())
    return tuple(int(dim) for dim in dims.value)
def JobBuildAndInferCtx_MirroredBlobGetSubLbi(job_name, lbn, index):
    """Return the LogicalBlobId reported by the native sub-lbi call.

    ``job_name`` and ``lbn`` are converted with ``str``; ``index`` is passed
    through unchanged to
    ``oneflow_internal.JobBuildAndInferCtx_MirroredBlobGetSerializedSubLbi``.

    Returns:
        A ``logical_blob_id_util.LogicalBlobId`` parsed from the text returned
        by the native call.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    job_name_text = str(job_name)
    lbn_text = str(lbn)
    fetch_sub_lbi = oneflow_internal.JobBuildAndInferCtx_MirroredBlobGetSerializedSubLbi
    sub_lbi_text, status_text = fetch_sub_lbi(job_name_text, lbn_text, index)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return text_format.Parse(sub_lbi_text, logical_blob_id_util.LogicalBlobId())
def JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn):
    """Return the split axis reported by the native producer-view call.

    Both arguments are converted with ``str`` and passed to
    ``oneflow_internal.JobBuildAndInferCtx_GetSplitAxisFromProducerView``.

    Returns:
        The ``value`` of the returned ``dtype_util.OptInt64`` when that field
        is set, otherwise ``None``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    (
        split_axis_str,
        error_str,
    ) = oneflow_internal.JobBuildAndInferCtx_GetSplitAxisFromProducerView(
        job_name, lbn
    )
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    # Parse the payload only after the error check, as the sibling wrappers
    # do: otherwise a payload that fails to parse would raise a text_format
    # error and hide the error reported by the native call.
    split_axis = text_format.Parse(split_axis_str, dtype_util.OptInt64())
    if split_axis.HasField("value"):
        return split_axis.value
    return None
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-view parallel conf as a ``placement_cfg`` object.

    Both arguments are converted with ``str`` and passed to
    ``oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView``.
    The returned text is parsed into a ``placement_pb.ParallelConf`` whose
    ``device_tag`` and ``device_name`` entries are then copied into a new
    ``placement_cfg.ParallelConf``.

    Returns:
        The populated ``placement_cfg.ParallelConf``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    job_name_text = str(job_name)
    lbn_text = str(lbn)
    fetch_parallel_conf = (
        oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    parallel_conf_text, status_text = fetch_parallel_conf(job_name_text, lbn_text)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    proto_conf = text_format.Parse(parallel_conf_text, placement_pb.ParallelConf())
    # TODO(oyy) change temporary transformation after python code migrated into cpp code
    cfg_conf = placement_cfg.ParallelConf()
    cfg_conf.set_device_tag(proto_conf.device_tag)
    for name in proto_conf.device_name:
        cfg_conf.add_device_name(name)
    return cfg_conf
def IsOpTypeNameCpuSupportOnly(op_type_name):
    """Return the result of the native IsOpTypeNameCpuSupportOnly query.

    ``op_type_name`` is passed unchanged to
    ``oneflow_internal.IsOpTypeNameCpuSupportOnly``.

    Returns:
        The first element of the pair returned by the native call, unchanged.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    cpu_only, status_text = oneflow_internal.IsOpTypeNameCpuSupportOnly(op_type_name)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return cpu_only
def RegisterWatcherOnlyOnce(watcher):
    """Pass ``watcher`` to ``oneflow_internal.RegisterWatcherOnlyOnce``.

    Raises:
        JobBuildAndInferError: if the ErrorProto returned by the native call
            has ``error_type`` set.
    """
    status_text = oneflow_internal.RegisterWatcherOnlyOnce(watcher)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
def EnvResource():
    """Return the Resource proto reported by ``oneflow_internal.EnvResource``.

    Returns:
        A ``resource_util.Resource`` parsed from the text returned by the
        native call.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    resource_text, status_text = oneflow_internal.EnvResource()
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return text_format.Parse(resource_text, resource_util.Resource())
def GetJobSet():
    """Return the JobSet reported by ``oneflow_internal.GetSerializedJobSet``.

    Returns:
        A ``job_set_pb.JobSet`` parsed from the text returned by the native
        call.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    job_set_text, status_text = oneflow_internal.GetSerializedJobSet()
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return text_format.Parse(job_set_text, job_set_pb.JobSet())
def NewPhysicalSymbolId():
    """Return the id produced by ``oneflow_internal.NewPhysicalSymbolId``.

    Returns:
        The first element of the pair returned by the native call, unchanged.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    symbol_id, status_text = oneflow_internal.NewPhysicalSymbolId()
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return symbol_id
def GetScopeConfigDef():
    """Return the ConfigDef reported by ``oneflow_internal.GetScopeConfigDef``.

    Returns:
        A ``ConfigDef`` message parsed from the text returned by the native
        call.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    config_def_text, status_text = oneflow_internal.GetScopeConfigDef()
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return text_format.Parse(config_def_text, ConfigDef())
def CurrentMachineId():
    """Return the id produced by ``oneflow_internal.CurrentMachineId``.

    Returns:
        The first element of the pair returned by the native call, unchanged.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    current_id, status_text = oneflow_internal.CurrentMachineId()
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return current_id
def StopLazyGlobalSession():
    """Call ``oneflow_internal.StopLazyGlobalSession`` and check its status.

    Raises:
        JobBuildAndInferError: if the ErrorProto returned by the native call
            has ``error_type`` set.
    """
    status_text = oneflow_internal.StopLazyGlobalSession()
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
def GetOpAttributes():
    """Return the list reported by ``oneflow_internal.GetSerializedOpAttributes``.

    Returns:
        An ``op_attribute_pb.OpAttributeList`` parsed from the text returned
        by the native call.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    attributes_text, status_text = oneflow_internal.GetSerializedOpAttributes()
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return text_format.Parse(attributes_text, op_attribute_pb.OpAttributeList())
def LaunchJob(job_instance):
    """Pass ``job_instance`` to ``oneflow_internal.LaunchJob``.

    Raises:
        JobBuildAndInferError: if the ErrorProto returned by the native call
            has ``error_type`` set.
    """
    status_text = oneflow_internal.LaunchJob(job_instance)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
def GetStructureGraph():
    """Return the value from ``oneflow_internal.GetSerializedStructureGraph``.

    Returns:
        The first element of the pair returned by the native call, as-is
        (it is not parsed here).

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type``
            set.
    """
    graph_text, status_text = oneflow_internal.GetSerializedStructureGraph()
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
    return graph_text
def JobBuildAndInferCtx_Open(job_name):
    """Pass ``str(job_name)`` to ``oneflow_internal.JobBuildAndInferCtx_Open``.

    Raises:
        JobBuildAndInferError: if the ErrorProto returned by the native call
            has ``error_type`` set.
    """
    job_name_text = str(job_name)
    status_text = oneflow_internal.JobBuildAndInferCtx_Open(job_name_text)
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
def DestroyEnv():
    """Call ``oneflow_internal.DestroyEnv`` and check its status.

    Raises:
        JobBuildAndInferError: if the ErrorProto returned by the native call
            has ``error_type`` set.
    """
    status_text = oneflow_internal.DestroyEnv()
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)
def CurJobBuildAndInferCtx_CheckJob():
    """Call ``oneflow_internal.CurJobBuildAndInferCtx_CheckJob`` and check its status.

    Raises:
        JobBuildAndInferError: if the ErrorProto returned by the native call
            has ``error_type`` set.
    """
    status_text = oneflow_internal.CurJobBuildAndInferCtx_CheckJob()
    status = text_format.Parse(status_text, error_util.ErrorProto())
    if status.HasField("error_type"):
        raise JobBuildAndInferError(status)