コード例 #1
0
def InferOpConf(op_conf_proto, upstream_signature):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    serialized_upstream_sig = str(
        text_format.MessageToString(upstream_signature))
    op_attribute_str = oneflow._oneflow_internal.InferOpConf(
        serialized_op_conf, serialized_upstream_sig)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
コード例 #2
0
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    add_and_infer = oneflow_api.CurJobBuildAndInferCtx_AddAndInferMirroredOp
    op_attribute_str, error = add_and_infer(serialized_op_conf)
    if error.has_error_type():
        raise JobBuildAndInferCfgError(error)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
コード例 #3
0
ファイル: c_api_util.py プロジェクト: ncnnnnn/oneflow
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    add_and_infer = oneflow_internal.CurJobBuildAndInferCtx_AddAndInferMirroredOp
    op_attribute_str, error_str = add_and_infer(serialized_op_conf)
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
コード例 #4
0
def InferOpConf(op_conf_proto, upstream_signature):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    serialized_upstream_sig = str(text_format.MessageToString(upstream_signature))
    op_attribute_str, error_str = oneflow_internal.InferOpConf(
        serialized_op_conf, serialized_upstream_sig,
    )
    error = text_format.Parse(error_str, error_util.ErrorProto())
    if error.HasField("error_type"):
        raise JobBuildAndInferError(error)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
コード例 #5
0
def MirroredCast(op_attribute_str, parallel_conf):
    op_attribute = text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
    blob_register = oneflow_api.GetDefaultBlobRegister()
    is_cast_to_mirrored = op_attribute.op_conf.HasField("cast_to_mirrored_conf")
    is_cast_from_mirrored = op_attribute.op_conf.HasField("cast_from_mirrored_conf")
    assert is_cast_to_mirrored or is_cast_from_mirrored
    _MirroredCastAndAddOutputBlobReleaser(op_attribute, blob_register)
    bw_blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    gradient_util.TrySetBackwardUsedBlobObject(
        op_attribute, blob_register, bw_blob_register
    )
コード例 #6
0
def InferOpConf(op_conf_proto, upstream_signature):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    serialized_upstream_sig = str(
        text_format.MessageToString(upstream_signature))
    op_attribute_str, error = oneflow_api.InferOpConf(
        serialized_op_conf,
        serialized_upstream_sig,
    )
    if error.has_error_type():
        raise JobBuildAndInferCfgError(error)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
コード例 #7
0
def InterpretCompletedOp(op_attribute_str, parallel_conf):
    op_attribute = text_format.Parse(op_attribute_str,
                                     op_attribute_pb.OpAttribute())
    blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    _InterpretCompletedOp(op_attribute, parallel_conf, blob_register)
    gradient_util.ReleaseUnusedBlobObject(op_attribute, blob_register)
コード例 #8
0
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    add_and_infer = (
        oneflow._oneflow_internal.CurJobBuildAndInferCtx_AddAndInferMirroredOp)
    op_attribute_str = add_and_infer(serialized_op_conf)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
コード例 #9
0
ファイル: c_api_util.py プロジェクト: liudyboy/oneflow
def CurJobBuildAndInferCtx_AddAndInferConsistentOp(op_conf_proto):
    serialized_op_conf = str(text_format.MessageToString(op_conf_proto))
    add_and_infer = oneflow_api.CurJobBuildAndInferCtx_AddAndInferConsistentOp
    op_attribute_str = add_and_infer(serialized_op_conf)
    return text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())