Ejemplo n.º 1
0
def JobBuildAndInferCtx_GetBatchAxis(job_name, lbn):
    """Return the batch axis recorded for ``lbn`` in job ``job_name``.

    Both arguments are coerced to ``str`` before being handed to
    ``oneflow_api``. The backend's text-format reply is parsed into an
    ``OptInt64`` message and unwrapped.

    Returns:
        The ``value`` field of the parsed message when it is set, else ``None``.
    """
    job = str(job_name)
    blob = str(lbn)
    serialized = oneflow_api.JobBuildAndInferCtx_GetBatchAxis(job, blob)
    opt_axis = text_format.Parse(serialized, dtype_util.OptInt64())
    # OptInt64 models an optional integer: an unset "value" means "no axis".
    return opt_axis.value if opt_axis.HasField("value") else None
Ejemplo n.º 2
0
 def to_split_axis(dist):
     """Translate a distribute descriptor into an ``OptInt64`` split axis.

     A ``SplitDistribute`` yields a message whose ``value`` is ``dist.axis``.
     A ``BroadcastDistribute`` yields a message with ``value`` cleared.
     Dispatch compares exact types (``type(dist) is ...``), so subclasses of
     either class are treated as unsupported.

     Raises:
         NotImplementedError: if ``dist`` is of any other type.
     """
     result = data_type_util.OptInt64()
     kind = type(dist)
     if kind is distribute_util.SplitDistribute:
         result.value = dist.axis
         return result
     if kind is distribute_util.BroadcastDistribute:
         # Broadcast has no split axis: leave the optional field unset.
         result.ClearField("value")
         return result
     raise NotImplementedError
Ejemplo n.º 3
0
def JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn):
    """Return the producer-view split axis of ``lbn`` in job ``job_name``.

    Both arguments are coerced to ``str``. The text-format reply from
    ``oneflow._oneflow_internal`` is parsed into an ``OptInt64`` message.

    Returns:
        The split axis when the message's ``value`` field is set, else ``None``.
    """
    job = str(job_name)
    blob = str(lbn)
    serialized = oneflow._oneflow_internal.JobBuildAndInferCtx_GetSplitAxisFromProducerView(
        job, blob
    )
    opt_axis = text_format.Parse(serialized, dtype_util.OptInt64())
    if not opt_axis.HasField("value"):
        return None
    return opt_axis.value
Ejemplo n.º 4
0
def JobBuildAndInferCtx_GetBatchAxis(job_name, lbn):
    """Return the batch axis recorded for ``lbn`` in job ``job_name``.

    ``oneflow_internal`` returns a pair of text-format protobuf strings: the
    axis as an ``OptInt64`` and an ``ErrorProto``. The error is inspected
    before the axis payload is parsed. A failed query therefore surfaces as
    ``JobBuildAndInferError`` instead of a ``text_format.ParseError`` raised
    on whatever payload accompanied the error.

    Returns:
        The batch axis when the ``value`` field is set, otherwise ``None``.

    Raises:
        JobBuildAndInferError: if the returned ``ErrorProto`` has ``error_type`` set.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    batch_axis_str, error_str = oneflow_internal.JobBuildAndInferCtx_GetBatchAxis(
        job_name, lbn
    )
    # Check the error first: the axis payload is only used on success, and
    # parsing it beforehand could raise and mask the backend's real error.
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    batch_axis = text_format.Parse(batch_axis_str, dtype_util.OptInt64())
    if batch_axis.HasField("value"):
        return batch_axis.value
    return None
Ejemplo n.º 5
0
def JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn):
    """Return the producer-view split axis of ``lbn`` in job ``job_name``.

    ``oneflow_internal`` returns a pair of text-format protobuf strings: the
    split axis as an ``OptInt64`` and an ``ErrorProto``. The error is checked
    before the axis payload is parsed, so a failure is reported as
    ``JobBuildAndInferError`` rather than as a parse error on the payload.

    Returns:
        The split axis when the ``value`` field is set, otherwise ``None``.

    Raises:
        JobBuildAndInferError: if the returned ``ErrorProto`` has ``error_type`` set.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    (
        split_axis_str,
        error_str,
    ) = oneflow_internal.JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn)
    # Check the error first: the axis payload is only used on success, and
    # parsing it beforehand could raise and mask the backend's real error.
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    split_axis = text_format.Parse(split_axis_str, dtype_util.OptInt64())
    if split_axis.HasField("value"):
        return split_axis.value
    return None
Ejemplo n.º 6
0
def JobBuildAndInferCtx_MirroredBlobGetBatchAxis(job_name, lbn):
    """Return the batch axis of mirrored blob ``lbn`` in job ``job_name``.

    ``oneflow_api`` returns the axis as a text-format ``OptInt64`` string
    together with an error object exposing ``has_error_type()``. The error is
    checked before the axis payload is parsed, so a failure is reported as
    ``JobBuildAndInferCfgError`` rather than as a parse error on the payload.

    Returns:
        The batch axis when the ``value`` field is set, otherwise ``None``.

    Raises:
        JobBuildAndInferCfgError: if ``error.has_error_type()`` is true.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    (
        batch_axis_str,
        error,
    ) = oneflow_api.JobBuildAndInferCtx_MirroredBlobGetBatchAxis(
        job_name, lbn)
    # Check the error first: the axis payload is only used on success, and
    # parsing it beforehand could raise and mask the backend's real error.
    if error.has_error_type():
        raise JobBuildAndInferCfgError(error)
    batch_axis = text_format.Parse(batch_axis_str, dtype_util.OptInt64())
    if batch_axis.HasField("value"):
        return batch_axis.value
    return None
Ejemplo n.º 7
0
def JobBuildAndInferCtx_MirroredBlobGetSplitAxisFromProducerView(
        job_name, lbn):
    """Return the producer-view split axis of mirrored blob ``lbn``.

    ``oneflow_api`` returns the axis as a text-format ``OptInt64`` string
    together with an error object exposing ``has_error_type()``. The error is
    checked before the axis payload is parsed, so a failure is reported as
    ``JobBuildAndInferCfgError`` rather than as a parse error on the payload.

    Returns:
        The split axis when the ``value`` field is set, otherwise ``None``.

    Raises:
        JobBuildAndInferCfgError: if ``error.has_error_type()`` is true.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    (
        split_axis_str,
        error,
    ) = oneflow_api.JobBuildAndInferCtx_MirroredBlobGetSplitAxisFromProducerView(
        job_name, lbn)
    # Check the error first: the axis payload is only used on success, and
    # parsing it beforehand could raise and mask the backend's real error.
    if error.has_error_type():
        raise JobBuildAndInferCfgError(error)
    split_axis = text_format.Parse(split_axis_str, dtype_util.OptInt64())
    if split_axis.HasField("value"):
        return split_axis.value
    return None