def InitEnv(env_proto):
    """Initialize the environment from an ``env_pb2.EnvProto``.

    The proto is serialized to protobuf text format and handed to
    ``oneflow_internal.InitEnv``; the text-format ErrorProto it returns is
    parsed and checked.

    Raises:
        AssertionError: if ``env_proto`` is not exactly an ``EnvProto``
            (check is skipped under ``python -O``).
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    assert type(env_proto) is env_pb2.EnvProto
    err_proto = text_format.Parse(
        oneflow_internal.InitEnv(text_format.MessageToString(env_proto)),
        error_util.ErrorProto(),
    )
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
def CurJobBuildAndInferCtx_HasJobConf():
    """Return the has-job-conf flag reported for the current job build context.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    has_conf, err_text = oneflow_internal.CurJobBuildAndInferCtx_HasJobConf()
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if not err_proto.HasField("error_type"):
        return has_conf
    raise JobBuildAndInferError(err_proto)
def CurJobBuildAndInferCtx_AddLossLogicalBlobName(lbn):
    """Register ``lbn`` (coerced to ``str``) as a loss logical blob name.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    blob_name = str(lbn)
    err_text = oneflow_internal.CurJobBuildAndInferCtx_AddLossLogicalBlobName(
        blob_name
    )
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
def CurJobBuildAndInferCtx_SetJobConf(job_config_proto):
    """Set the job config of the current build context.

    ``job_config_proto`` is serialized to protobuf text format before being
    passed to the backend.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    conf_text = str(text_format.MessageToString(job_config_proto))
    err_proto = text_format.Parse(
        oneflow_internal.CurJobBuildAndInferCtx_SetJobConf(conf_text),
        error_util.ErrorProto(),
    )
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
def DeviceType4DeviceTag(device_tag):
    """Return the device type the backend maps ``device_tag`` (as ``str``) to.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    tag = str(device_tag)
    dev_type, err_text = oneflow_internal.DeviceType4DeviceTag(tag)
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if not err_proto.HasField("error_type"):
        return dev_type
    raise JobBuildAndInferError(err_proto)
def GetUserOpAttrType(op_type_name, attr_name):
    """Return the attribute type of ``attr_name`` on user op ``op_type_name``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    result = oneflow_internal.GetUserOpAttrType(op_type_name, attr_name)
    attr_kind, err_text = result
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if not err_proto.HasField("error_type"):
        return attr_kind
    raise JobBuildAndInferError(err_proto)
def GetInterUserJobInfo():
    """Fetch and parse the serialized InterUserJobInfo.

    Returns:
        An ``InterUserJobInfo`` message parsed from the backend's text output.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    info_text, err_text = oneflow_internal.GetSerializedInterUserJobInfo()
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
    info = InterUserJobInfo()
    return text_format.Parse(info_text, info)
def JobBuildAndInferCtx_GetCurrentJobName():
    """Return the current job name reported by the backend.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    name, err_text = oneflow_internal.JobBuildAndInferCtx_GetCurrentJobName()
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if not err_proto.HasField("error_type"):
        return name
    raise JobBuildAndInferError(err_proto)
def CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair(lbi_and_uuid):
    """Register an lbi/diff-watcher-uuid pair with the current build context.

    ``lbi_and_uuid`` is serialized to protobuf text format first.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    pair_text = str(text_format.MessageToString(lbi_and_uuid))
    register = oneflow_internal.CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair
    err_proto = text_format.Parse(register(pair_text), error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
def InitLazyGlobalSession(config_proto):
    """Initialize the lazy global session from a ``job_set_pb.ConfigProto``.

    Raises:
        AssertionError: if ``config_proto`` is not exactly a ``ConfigProto``
            (check is skipped under ``python -O``).
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    assert type(config_proto) is job_set_pb.ConfigProto
    serialized_config = text_format.MessageToString(config_proto)
    err_proto = text_format.Parse(
        oneflow_internal.InitLazyGlobalSession(serialized_config),
        error_util.ErrorProto(),
    )
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
def RunLogicalInstruction(vm_instruction_list, eager_symbol_list):
    """Run logical VM instructions via ``oneflow_api.vm.RunLogicalInstruction``.

    ``eager_symbol_list`` is serialized to protobuf text format. Note that
    ``vm_instruction_list`` is passed through unchanged.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    symbol_text = str(text_format.MessageToString(eager_symbol_list))
    err_text = oneflow_api.vm.RunLogicalInstruction(vm_instruction_list, symbol_text)
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
def GetOpParallelSymbolId(op_conf_proto):
    """Return the parallel symbol id the backend reports for ``op_conf_proto``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    sym_id, err_text = oneflow_internal.GetOpParallelSymbolId(op_conf_text)
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if not err_proto.HasField("error_type"):
        return sym_id
    raise JobBuildAndInferError(err_proto)
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    """Add and infer a mirrored op in the current build context.

    Returns:
        The ``op_attribute_pb.OpAttribute`` parsed from the backend's output.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    attr_text, err_text = (
        oneflow_internal.CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_text)
    )
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
    op_attr = op_attribute_pb.OpAttribute()
    return text_format.Parse(attr_text, op_attr)
def CheckAndCompleteUserOpConf(op_conf_proto):
    """Have the backend check and complete a user op conf.

    Returns:
        The completed ``op_conf_util.OperatorConf`` parsed from the backend's
        text output.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    completed_text, err_text = oneflow_internal.CheckAndCompleteUserOpConf(
        op_conf_text
    )
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
    completed = op_conf_util.OperatorConf()
    return text_format.Parse(completed_text, completed)
def JobBuildAndInferCtx_IsTensorList(job_name, lbn):
    """Return the backend's is-tensor-list result for blob ``lbn`` in ``job_name``.

    Both arguments are coerced to ``str``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    job, blob = str(job_name), str(lbn)
    is_list, err_text = oneflow_internal.JobBuildAndInferCtx_IsTensorList(job, blob)
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if not err_proto.HasField("error_type"):
        return is_list
    raise JobBuildAndInferError(err_proto)
def JobBuildAndInferCtx_GetDataType(job_name, lbn):
    """Return the data type of blob ``lbn`` in job ``job_name`` as an ``int``.

    Both arguments are coerced to ``str``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    job, blob = str(job_name), str(lbn)
    data_type, err_text = oneflow_internal.JobBuildAndInferCtx_GetDataType(job, blob)
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if not err_proto.HasField("error_type"):
        return int(data_type)
    raise JobBuildAndInferError(err_proto)
def InferOpConf(op_conf_proto, upstream_signature):
    """Infer an op from its conf and upstream signature.

    Both protos are serialized to protobuf text format, with the op conf
    serialized first.

    Returns:
        The ``op_attribute_pb.OpAttribute`` parsed from the backend's output.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    upstream_text = str(text_format.MessageToString(upstream_signature))
    attr_text, err_text = oneflow_internal.InferOpConf(op_conf_text, upstream_text)
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
    op_attr = op_attribute_pb.OpAttribute()
    return text_format.Parse(attr_text, op_attr)
def JobBuildAndInferCtx_MirroredBlobDisableBoxing(job_name, lbn):
    """Return the backend's disable-boxing result for mirrored blob ``lbn``.

    Both arguments are coerced to ``str``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    job, blob = str(job_name), str(lbn)
    # TODO(hanbinbin): this api does not exist and will be deleted after confirmed
    disable_boxing = oneflow_internal.JobBuildAndInferCtx_MirroredBlobDisableBoxing
    outcome, err_text = disable_boxing(job, blob)
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if not err_proto.HasField("error_type"):
        return outcome
    raise JobBuildAndInferError(err_proto)
def GetMachine2DeviceIdListOFRecordFromParallelConf(parallel_conf):
    """Convert a parallel conf into a machine-to-device-id-list OFRecord.

    ``parallel_conf`` is stringified with ``str()`` before being sent.

    Returns:
        A ``record_util.OFRecord`` parsed from the backend's text output.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    conf_text = str(parallel_conf)
    convert = oneflow_internal.GetMachine2DeviceIdListOFRecordFromParallelConf
    record_text, err_text = convert(conf_text)
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
    record = record_util.OFRecord()
    return text_format.Parse(record_text, record)
# NOTE: an earlier definition of JobBuildAndInferCtx_GetParallelConfFromProducerView
# (returning a raw placement_pb.ParallelConf) used to live here. It was always
# shadowed by the later definition of the same name in this module, which returns
# a placement_cfg.ParallelConf, so it was dead code and has been removed.
def JobBuildAndInferCtx_GetBatchAxis(job_name, lbn):
    """Return the batch axis of blob ``lbn`` in job ``job_name``.

    Both arguments are coerced to ``str``.

    Returns:
        The ``value`` of the ``OptInt64`` reported by the backend, or ``None``
        when that field is unset.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    batch_axis_str, error_str = oneflow_internal.JobBuildAndInferCtx_GetBatchAxis(
        job_name, lbn
    )
    # Check the error first: on failure the payload may be empty or malformed,
    # and parsing it before this check could raise a ParseError that hides the
    # real backend error.
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    batch_axis = text_format.Parse(batch_axis_str, dtype_util.OptInt64())
    return batch_axis.value if batch_axis.HasField("value") else None
def JobBuildAndInferCtx_MirroredBlobGetStaticShape(job_name, lbn):
    """Return the static shape of mirrored blob ``lbn`` as a tuple of ints.

    The backend returns a text-format ``Int64List``. Each element is
    converted with ``int``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    job, blob = str(job_name), str(lbn)
    fetch_shape = (
        oneflow_internal.JobBuildAndInferCtx_MirroredBlobGetSerializedIdListAsStaticShape
    )
    dims_text, err_text = fetch_shape(job, blob)
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
    dims = text_format.Parse(dims_text, record_util.Int64List())
    return tuple(int(dim) for dim in dims.value)
def JobBuildAndInferCtx_MirroredBlobGetSubLbi(job_name, lbn, index):
    """Return the sub-LogicalBlobId at ``index`` of mirrored blob ``lbn``.

    ``job_name`` and ``lbn`` are coerced to ``str``. ``index`` is passed through
    unchanged.

    Returns:
        A ``logical_blob_id_util.LogicalBlobId`` parsed from the backend's output.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    job, blob = str(job_name), str(lbn)
    get_sub_lbi = oneflow_internal.JobBuildAndInferCtx_MirroredBlobGetSerializedSubLbi
    lbi_text, err_text = get_sub_lbi(job, blob, index)
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
    sub_lbi = logical_blob_id_util.LogicalBlobId()
    return text_format.Parse(lbi_text, sub_lbi)
def JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn):
    """Return the split axis of blob ``lbn`` from the producer's view.

    Both arguments are coerced to ``str``.

    Returns:
        The ``value`` of the ``OptInt64`` reported by the backend, or ``None``
        when that field is unset.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    (
        split_axis_str,
        error_str,
    ) = oneflow_internal.JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn)
    # Check the error first: on failure the payload may be empty or malformed,
    # and parsing it before this check could raise a ParseError that hides the
    # real backend error.
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    split_axis = text_format.Parse(split_axis_str, dtype_util.OptInt64())
    return split_axis.value if split_axis.HasField("value") else None
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-view parallel conf of blob ``lbn`` as a cfg object.

    The backend returns a text-format ``placement_pb.ParallelConf``. Its
    ``device_tag`` and every ``device_name`` are copied into a new
    ``placement_cfg.ParallelConf``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    job, blob = str(job_name), str(lbn)
    fetch_conf = (
        oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    conf_text, err_text = fetch_conf(job, blob)
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
    proto_conf = text_format.Parse(conf_text, placement_pb.ParallelConf())
    # TODO(oyy) change temporary transformation after python code migrated into cpp code
    cfg_conf = placement_cfg.ParallelConf()
    cfg_conf.set_device_tag(proto_conf.device_tag)
    for name in proto_conf.device_name:
        cfg_conf.add_device_name(name)
    return cfg_conf
def DestroyEnv():
    """Call ``oneflow_internal.DestroyEnv`` and check the returned ErrorProto.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    err_proto = text_format.Parse(
        oneflow_internal.DestroyEnv(), error_util.ErrorProto()
    )
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
def EnvResource():
    """Return the environment resource as a ``resource_util.Resource`` message.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    resource_text, err_text = oneflow_internal.EnvResource()
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
    resource = resource_util.Resource()
    return text_format.Parse(resource_text, resource)
def GetStructureGraph():
    """Return the serialized structure graph exactly as the backend provides it.

    The value is returned without parsing.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    graph, err_text = oneflow_internal.GetSerializedStructureGraph()
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if not err_proto.HasField("error_type"):
        return graph
    raise JobBuildAndInferError(err_proto)
def GetJobSet():
    """Return the job set parsed into a ``job_set_pb.JobSet`` message.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    job_set_text, err_text = oneflow_internal.GetSerializedJobSet()
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
    job_set = job_set_pb.JobSet()
    return text_format.Parse(job_set_text, job_set)
def GetOpAttributes():
    """Return op attributes parsed into an ``op_attribute_pb.OpAttributeList``.

    Raises:
        JobBuildAndInferError: if the returned ErrorProto has ``error_type`` set.
    """
    attrs_text, err_text = oneflow_internal.GetSerializedOpAttributes()
    err_proto = text_format.Parse(err_text, error_util.ErrorProto())
    if err_proto.HasField("error_type"):
        raise JobBuildAndInferError(err_proto)
    attr_list = op_attribute_pb.OpAttributeList()
    return text_format.Parse(attrs_text, attr_list)