def JobBuildAndInferCtx_GetBatchAxis(job_name, lbn):
    """Look up the batch axis of blob ``lbn`` within job ``job_name``.

    ``oneflow_api.JobBuildAndInferCtx_GetBatchAxis`` returns a text-format
    ``OptInt64`` message; this decodes it.

    Args:
        job_name: job identifier, coerced to ``str``.
        lbn: blob name, coerced to ``str``.

    Returns:
        The integer in the message's ``value`` field, or ``None`` when that
        field is not set.
    """
    job_name, lbn = str(job_name), str(lbn)
    serialized = oneflow_api.JobBuildAndInferCtx_GetBatchAxis(job_name, lbn)
    opt_axis = text_format.Parse(serialized, dtype_util.OptInt64())
    return opt_axis.value if opt_axis.HasField("value") else None
def to_split_axis(dist):
    """Encode a distribute descriptor as an ``OptInt64`` split-axis message.

    Args:
        dist: a ``distribute_util.SplitDistribute`` (its ``axis`` becomes the
            message ``value``) or a ``distribute_util.BroadcastDistribute``
            (the message ``value`` is left unset). Matching is on the exact
            type, so subclasses of either are not accepted.

    Returns:
        A ``data_type_util.OptInt64`` message.

    Raises:
        NotImplementedError: for any other descriptor type.
    """
    opt_axis = data_type_util.OptInt64()
    kind = type(dist)
    if kind is distribute_util.SplitDistribute:
        opt_axis.value = dist.axis
        return opt_axis
    if kind is not distribute_util.BroadcastDistribute:
        raise NotImplementedError
    # Broadcast means "no split axis": make sure `value` is unset.
    opt_axis.ClearField("value")
    return opt_axis
def JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn):
    """Look up the producer-side split axis of blob ``lbn`` in job ``job_name``.

    The backend returns a text-format ``OptInt64`` message, which is decoded
    here.

    Args:
        job_name: job identifier, coerced to ``str``.
        lbn: blob name, coerced to ``str``.

    Returns:
        The integer in the message's ``value`` field, or ``None`` when that
        field is not set.
    """
    job_name, lbn = str(job_name), str(lbn)
    serialized = oneflow._oneflow_internal.JobBuildAndInferCtx_GetSplitAxisFromProducerView(
        job_name, lbn
    )
    opt_axis = text_format.Parse(serialized, dtype_util.OptInt64())
    if not opt_axis.HasField("value"):
        return None
    return opt_axis.value
def JobBuildAndInferCtx_GetBatchAxis(job_name, lbn):
    """Return the batch axis of blob ``lbn`` in job ``job_name``, or ``None``.

    ``oneflow_internal.JobBuildAndInferCtx_GetBatchAxis`` returns a pair of
    text-format protobufs: the axis (decoded as ``OptInt64``) and an error
    (decoded as ``ErrorProto``).

    Args:
        job_name: job identifier, coerced to ``str``.
        lbn: blob name, coerced to ``str``.

    Returns:
        The axis ``value`` if set, otherwise ``None``.

    Raises:
        JobBuildAndInferError: if the returned error has ``error_type`` set.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    batch_axis_str, error_str = oneflow_internal.JobBuildAndInferCtx_GetBatchAxis(
        job_name, lbn
    )
    # Check the error before touching the payload. A failed call then always
    # surfaces as JobBuildAndInferError, even if the accompanying axis string
    # is not a parseable OptInt64.
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    batch_axis = text_format.Parse(batch_axis_str, dtype_util.OptInt64())
    if batch_axis.HasField("value"):
        return batch_axis.value
    return None
def JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn):
    """Return the producer-side split axis of blob ``lbn`` in job ``job_name``.

    The backend call returns a pair of text-format protobufs: the split axis
    (decoded as ``OptInt64``) and an error (decoded as ``ErrorProto``).

    Args:
        job_name: job identifier, coerced to ``str``.
        lbn: blob name, coerced to ``str``.

    Returns:
        The axis ``value`` if set, otherwise ``None``.

    Raises:
        JobBuildAndInferError: if the returned error has ``error_type`` set.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    (
        split_axis_str,
        error_str,
    ) = oneflow_internal.JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn)
    # Check the error first so a failed call is never masked by a parse
    # failure of the accompanying payload.
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    split_axis = text_format.Parse(split_axis_str, dtype_util.OptInt64())
    if split_axis.HasField("value"):
        return split_axis.value
    return None
def JobBuildAndInferCtx_MirroredBlobGetBatchAxis(job_name, lbn):
    """Return the batch axis of mirrored blob ``lbn`` in job ``job_name``.

    ``oneflow_api.JobBuildAndInferCtx_MirroredBlobGetBatchAxis`` returns a
    text-format ``OptInt64`` string together with an error object that exposes
    ``has_error_type()``.

    Args:
        job_name: job identifier, coerced to ``str``.
        lbn: blob name, coerced to ``str``.

    Returns:
        The axis ``value`` if set, otherwise ``None``.

    Raises:
        JobBuildAndInferCfgError: if the returned error object reports an
            error type.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    (
        batch_axis_str,
        error,
    ) = oneflow_api.JobBuildAndInferCtx_MirroredBlobGetBatchAxis(job_name, lbn)
    # Report backend failures before parsing the payload, so a malformed axis
    # string on the error path cannot hide the real error.
    if error.has_error_type():
        raise JobBuildAndInferCfgError(error)
    batch_axis = text_format.Parse(batch_axis_str, dtype_util.OptInt64())
    if batch_axis.HasField("value"):
        return batch_axis.value
    return None
def JobBuildAndInferCtx_MirroredBlobGetSplitAxisFromProducerView(job_name, lbn):
    """Return the producer-side split axis of mirrored blob ``lbn``.

    The backend returns a text-format ``OptInt64`` string together with an
    error object that exposes ``has_error_type()``.

    Args:
        job_name: job identifier, coerced to ``str``.
        lbn: blob name, coerced to ``str``.

    Returns:
        The axis ``value`` if set, otherwise ``None``.

    Raises:
        JobBuildAndInferCfgError: if the returned error object reports an
            error type.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    (
        split_axis_str,
        error,
    ) = oneflow_api.JobBuildAndInferCtx_MirroredBlobGetSplitAxisFromProducerView(
        job_name, lbn
    )
    # Report backend failures before parsing the payload, so a malformed axis
    # string on the error path cannot hide the real error.
    if error.has_error_type():
        raise JobBuildAndInferCfgError(error)
    split_axis = text_format.Parse(split_axis_str, dtype_util.OptInt64())
    if split_axis.HasField("value"):
        return split_axis.value
    return None