def InferOpConf(op_conf_proto, upstream_signature):
    """Infer an op attribute via ``oneflow._oneflow_internal.InferOpConf``.

    Both protobuf arguments are rendered to protobuf text format and passed
    to the internal binding; the text it returns is parsed into a fresh
    ``OpAttribute`` message.

    Args:
        op_conf_proto: operator configuration message.
        upstream_signature: upstream signature message sent alongside it.

    Returns:
        The binding's output parsed as ``op_attribute_pb.OpAttribute``.
    """
    op_conf_text = str(text_format.MessageToString(op_conf_proto))
    upstream_text = str(text_format.MessageToString(upstream_signature))
    infer = oneflow._oneflow_internal.InferOpConf
    attribute_text = infer(op_conf_text, upstream_text)
    return text_format.Parse(attribute_text, op_attribute_pb.OpAttribute())
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    """Send an op conf to the mirrored add-and-infer binding of ``oneflow_api``.

    The op conf is rendered to protobuf text format and handed to
    ``oneflow_api.CurJobBuildAndInferCtx_AddAndInferMirroredOp``, which
    returns the op-attribute text together with an error object.

    Args:
        op_conf_proto: operator configuration message.

    Returns:
        The op-attribute text parsed as ``op_attribute_pb.OpAttribute``.

    Raises:
        JobBuildAndInferCfgError: if the returned error reports an error type.
    """
    conf_text = str(text_format.MessageToString(op_conf_proto))
    binding = oneflow_api.CurJobBuildAndInferCtx_AddAndInferMirroredOp
    attribute_text, status = binding(conf_text)
    if not status.has_error_type():
        return text_format.Parse(attribute_text, op_attribute_pb.OpAttribute())
    raise JobBuildAndInferCfgError(status)
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    """Send an op conf to the mirrored add-and-infer binding of ``oneflow_internal``.

    The binding returns two protobuf text strings: the op attribute and an
    ``ErrorProto``. The error text is parsed first so a reported failure is
    raised before the attribute is decoded.

    Args:
        op_conf_proto: operator configuration message.

    Returns:
        The op-attribute text parsed as ``op_attribute_pb.OpAttribute``.

    Raises:
        JobBuildAndInferError: if the parsed error has ``error_type`` set.
    """
    conf_text = str(text_format.MessageToString(op_conf_proto))
    binding = oneflow_internal.CurJobBuildAndInferCtx_AddAndInferMirroredOp
    attribute_text, error_text = binding(conf_text)
    error_proto = text_format.Parse(error_text, error_util.ErrorProto())
    if not error_proto.HasField("error_type"):
        return text_format.Parse(attribute_text, op_attribute_pb.OpAttribute())
    raise JobBuildAndInferError(error_proto)
def InferOpConf(op_conf_proto, upstream_signature):
    """Infer an op attribute via ``oneflow_internal.InferOpConf``.

    Both messages are rendered to protobuf text format. The binding returns
    the op-attribute text plus an ``ErrorProto`` in text form; the error is
    checked before the attribute is parsed.

    Args:
        op_conf_proto: operator configuration message.
        upstream_signature: upstream signature message sent alongside it.

    Returns:
        The op-attribute text parsed as ``op_attribute_pb.OpAttribute``.

    Raises:
        JobBuildAndInferError: if the parsed error has ``error_type`` set.
    """
    conf_text = str(text_format.MessageToString(op_conf_proto))
    signature_text = str(text_format.MessageToString(upstream_signature))
    attribute_text, error_text = oneflow_internal.InferOpConf(
        conf_text,
        signature_text,
    )
    error_proto = text_format.Parse(error_text, error_util.ErrorProto())
    if not error_proto.HasField("error_type"):
        return text_format.Parse(attribute_text, op_attribute_pb.OpAttribute())
    raise JobBuildAndInferError(error_proto)
def MirroredCast(op_attribute_str, parallel_conf):
    """Run a mirrored cast op described by a text-format ``OpAttribute``.

    The op attribute is parsed and checked to carry either a
    ``cast_to_mirrored_conf`` or a ``cast_from_mirrored_conf``. The cast is
    then performed against the default blob register, and backward-used blob
    objects are recorded in the default backward blob register.

    Args:
        op_attribute_str: protobuf text of an ``OpAttribute``.
        parallel_conf: not read by this function.

    Raises:
        AssertionError: if the op conf is neither kind of mirrored cast. It is
            raised explicitly instead of via ``assert``, so the check is not
            stripped when Python runs with ``-O``.
    """
    op_attribute = text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
    blob_register = oneflow_api.GetDefaultBlobRegister()
    op_conf = op_attribute.op_conf
    is_cast_to_mirrored = op_conf.HasField("cast_to_mirrored_conf")
    is_cast_from_mirrored = op_conf.HasField("cast_from_mirrored_conf")
    if not (is_cast_to_mirrored or is_cast_from_mirrored):
        raise AssertionError(
            "MirroredCast expects an op conf with cast_to_mirrored_conf "
            "or cast_from_mirrored_conf"
        )
    _MirroredCastAndAddOutputBlobReleaser(op_attribute, blob_register)
    bw_blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    gradient_util.TrySetBackwardUsedBlobObject(
        op_attribute, blob_register, bw_blob_register
    )
def InferOpConf(op_conf_proto, upstream_signature):
    """Infer an op attribute via ``oneflow_api.InferOpConf``.

    Both messages are rendered to protobuf text format. The binding returns
    the op-attribute text plus an error object, which is checked before the
    attribute is parsed.

    Args:
        op_conf_proto: operator configuration message.
        upstream_signature: upstream signature message sent alongside it.

    Returns:
        The op-attribute text parsed as ``op_attribute_pb.OpAttribute``.

    Raises:
        JobBuildAndInferCfgError: if the returned error reports an error type.
    """
    conf_text = str(text_format.MessageToString(op_conf_proto))
    signature_text = str(text_format.MessageToString(upstream_signature))
    attribute_text, status = oneflow_api.InferOpConf(
        conf_text,
        signature_text,
    )
    if not status.has_error_type():
        return text_format.Parse(attribute_text, op_attribute_pb.OpAttribute())
    raise JobBuildAndInferCfgError(status)
def InterpretCompletedOp(op_attribute_str, parallel_conf):
    """Interpret an op from its text-format ``OpAttribute``, then release blobs.

    The default *backward* blob register is used both for interpretation
    (via ``_InterpretCompletedOp``) and for the subsequent
    ``ReleaseUnusedBlobObject`` call.

    Args:
        op_attribute_str: protobuf text of an ``OpAttribute``.
        parallel_conf: forwarded unchanged to ``_InterpretCompletedOp``.
    """
    attribute = text_format.Parse(op_attribute_str, op_attribute_pb.OpAttribute())
    register = gradient_util.GetDefaultBackwardBlobRegister()
    _InterpretCompletedOp(attribute, parallel_conf, register)
    gradient_util.ReleaseUnusedBlobObject(attribute, register)
def CurJobBuildAndInferCtx_AddAndInferMirroredOp(op_conf_proto):
    """Send an op conf to the mirrored add-and-infer binding of ``_oneflow_internal``.

    Unlike variants that return an error object, this binding returns only
    the op-attribute text, which is parsed directly.

    Args:
        op_conf_proto: operator configuration message.

    Returns:
        The binding's output parsed as ``op_attribute_pb.OpAttribute``.
    """
    conf_text = str(text_format.MessageToString(op_conf_proto))
    internal = oneflow._oneflow_internal
    attribute_text = internal.CurJobBuildAndInferCtx_AddAndInferMirroredOp(conf_text)
    return text_format.Parse(attribute_text, op_attribute_pb.OpAttribute())
def CurJobBuildAndInferCtx_AddAndInferConsistentOp(op_conf_proto):
    """Send an op conf to the consistent add-and-infer binding of ``oneflow_api``.

    The op conf is rendered to protobuf text format; the binding's text
    result is parsed into a fresh ``OpAttribute`` message.

    Args:
        op_conf_proto: operator configuration message.

    Returns:
        The binding's output parsed as ``op_attribute_pb.OpAttribute``.
    """
    conf_text = str(text_format.MessageToString(op_conf_proto))
    binding = oneflow_api.CurJobBuildAndInferCtx_AddAndInferConsistentOp
    attribute_text = binding(conf_text)
    return text_format.Parse(attribute_text, op_attribute_pb.OpAttribute())