def EagerOpKernelForward(add_and_infer, op_conf, opkernel_object):
    """Add ``op_conf`` to the job and execute it eagerly through an op-kernel object.

    Args:
        add_and_infer: callable invoked as
            ``add_and_infer(op_conf, opkernel_object.scope_symbol)``; its
            return value is used as the op attribute.
        op_conf: operator configuration handed to ``add_and_infer``.
        opkernel_object: object passed to ``op_executor.OpKernelCall``; its
            ``scope_symbol`` attribute supplies the scope for ``add_and_infer``.

    Returns:
        The op attribute produced by ``add_and_infer``.
    """
    attr = add_and_infer(op_conf, opkernel_object.scope_symbol)
    op_executor.OpKernelCall(opkernel_object, attr, blob_register)
    # Hand the forward register and the default backward register to
    # gradient_util so it can record blob objects the backward pass may
    # need (presumably, judging by the helper's name).
    gradient_util.TrySetBackwardUsedBlobObject(
        attr, blob_register, gradient_util.GetDefaultBackwardBlobRegister()
    )
    return attr
def EagerForward(add_and_infer, op_conf, scope_symbol=None):
    """Add ``op_conf`` to the job and interpret it eagerly in ``scope_symbol``.

    Args:
        add_and_infer: callable invoked as ``add_and_infer(op_conf, scope_symbol)``;
            its return value is used as the op attribute.
        op_conf: operator configuration handed to ``add_and_infer``.
        scope_symbol: scope whose ``device_parallel_desc_symbol.parallel_conf``
            selects where the op is interpreted. Although it defaults to
            ``None`` for signature compatibility, a value is required.

    Returns:
        The op attribute produced by ``add_and_infer``.

    Raises:
        ValueError: if ``scope_symbol`` is ``None``.
    """
    # The parallel conf is read from scope_symbol unconditionally. Check it up
    # front so a missing scope fails clearly *before* add_and_infer mutates the
    # job, instead of raising an opaque AttributeError afterwards.
    if scope_symbol is None:
        raise ValueError("EagerForward requires a scope_symbol, got None")
    op_attribute = add_and_infer(op_conf, scope_symbol)
    parallel_conf = scope_symbol.device_parallel_desc_symbol.parallel_conf
    op_executor.Interpret(op_attribute, parallel_conf, blob_register)
    bw_blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    gradient_util.TrySetBackwardUsedBlobObject(
        op_attribute, blob_register, bw_blob_register
    )
    return op_attribute
def get_eager_variable(
    name,
    shape=None,
    dtype=None,
    initializer=None,
    regularizer=None,
    trainable=None,
    model_name=None,
    random_seed=None,
    distribute=oneflow_api.distribute.broadcast(),
    reuse=True,
):
    """Get or create an eager consistent variable blob for the current job.

    The job-name-scope prefix of the current job is prepended to ``name``. If
    the session stash already holds a blob for this job, it is checked to be
    identical to the session-level blob and reused. Otherwise a variable op is
    added to the current job. If no blob exists yet, one is created and
    initialized. The blob is then stashed for the job. In every case the blob
    object is registered with the default backward blob register.

    Args:
        name: variable name (``str``) before the job-name-scope prefix.
        shape: required despite the ``None`` default; must be a list or tuple.
        dtype, initializer, regularizer, trainable, model_name, random_seed,
        distribute: forwarded to ``GenerateVariableOpConf``.
        reuse: if ``False``, it is an error for the variable to already exist
            in the current job.

    Returns:
        The ``remote_blob_util.EagerConsistentBlob`` for the variable.

    Raises:
        AssertionError: on a non-``str`` name, a non-list/tuple shape, or an
            existing variable when ``reuse is False``. Explicit raises are
            used so these checks survive ``python -O``.
    """
    if not isinstance(name, str):
        raise AssertionError("param name should be a str")
    if not isinstance(shape, (list, tuple)):
        raise AssertionError("param shape should be a list or tuple of dimension")

    job_name = oneflow_api.JobBuildAndInferCtx_GetCurrentJobName()
    name = name_scope.GetJobNameScopePrefix(job_name) + name
    sess = session_ctx.GetDefaultSession()
    var_blob, job_var_blob = sess.TryGetVariableBlobOfJobFromStash(job_name, name)

    if reuse is False and job_var_blob is not None:
        raise AssertionError(
            "variable '{}' already exists, "
            "getting the same variable is not allowed "
            "when reuse is False".format(name)
        )

    if job_var_blob is None:
        op_conf = GenerateVariableOpConf(
            name=name,
            shape=shape,
            dtype=dtype,
            initializer=initializer,
            regularizer=regularizer,
            trainable=trainable,
            model_name=model_name,
            random_seed=random_seed,
            distribute=distribute,
        )
        # The op is always added to the current job. var_blob may already exist
        # while job_var_blob is None (presumably the variable was created for
        # another job -- TODO confirm). In that case the existing blob is reused
        # rather than created and initialized again.
        op_attribute = compile_context.CurJobAddConsistentOp(op_conf)
        if var_blob is None:
            var_blob = CreateEagerVariableBlob(op_attribute)
            op_executor.EagerInitVariableBlob(sess, op_conf, var_blob)

        assert isinstance(var_blob, remote_blob_util.EagerConsistentBlob)
        sess.StashVariableBlob4Job(job_name, op_conf.name, var_blob)
    else:
        # Internal invariants: the per-job stash and the session-level blob
        # must refer to the same eager consistent blob.
        assert isinstance(job_var_blob, remote_blob_util.EagerConsistentBlob)
        assert isinstance(var_blob, remote_blob_util.EagerConsistentBlob)
        assert var_blob.IdenticalTo(job_var_blob)

    bw_blob_register = gradient_util.GetDefaultBackwardBlobRegister()
    bw_blob_register.TrySetObject4BlobName(
        var_blob.logical_blob_name, var_blob.blob_object
    )
    return var_blob
def MirroredCast(op_attribute_str, parallel_conf):
    """Eagerly run a cast-to-mirrored or cast-from-mirrored op.

    Args:
        op_attribute_str: text-format serialized ``OpAttribute`` proto.
        parallel_conf: kept for interface compatibility; not read here.

    Raises:
        AssertionError: if the op conf has neither ``cast_to_mirrored_conf``
            nor ``cast_from_mirrored_conf`` set.
    """
    parsed_attribute = text_format.Parse(
        op_attribute_str, op_attribute_pb.OpAttribute()
    )
    # Default blob register (as opposed to the backward register used below).
    fwd_register = oneflow_api.GetDefaultBlobRegister()

    cast_fields = ("cast_to_mirrored_conf", "cast_from_mirrored_conf")
    assert any(parsed_attribute.op_conf.HasField(f) for f in cast_fields)

    _MirroredCastAndAddOutputBlobReleaser(parsed_attribute, fwd_register)
    gradient_util.TrySetBackwardUsedBlobObject(
        parsed_attribute,
        fwd_register,
        gradient_util.GetDefaultBackwardBlobRegister(),
    )
def RemoteBlobList(self):
    """Create remote blobs for every declared output of the user op.

    Every output key declared in ``self.op_conf_.user_conf.output`` must also
    appear in ``self.output_arg_key_list_``, and vice versa. For each output
    argument, one blob is made per entry of its ``s`` list, named
    ``"<arg>_<index>"``. When eager execution is enabled, each blob object is
    also registered with the default backward blob register.

    Returns:
        tuple of the objects returned by ``self.MakeRemoteBlob``, ordered by
        ``self.output_arg_key_list_`` and then by index.

    Raises:
        ValueError: if a declared output was never set through the builder.
        AssertionError: if a key in ``output_arg_key_list_`` is not a declared
            output.
    """
    declared = self.op_conf_.user_conf.output

    # Every declared output must have been set via the python op builder.
    for key in declared:
        if key in self.output_arg_key_list_:
            continue
        raise ValueError(
            "output_arg_name {} of {} op is not set in python op builder".format(
                key, self.op_conf_.name
            )
        )

    blobs = []
    for arg_name in self.output_arg_key_list_:
        assert arg_name in declared
        count = len(declared[arg_name].s)
        for index in range(count):
            lbi = logical_blob_id_util.LogicalBlobId()
            lbi.op_name = self.op_conf_.name
            lbi.blob_name = f"{arg_name}_{index}"
            blob = self.MakeRemoteBlob(lbi)
            blobs.append(blob)
            if flow.eager_execution_enabled():
                bw_register = gradient_util.GetDefaultBackwardBlobRegister()
                bw_register.TrySetObject4BlobName(
                    blob.logical_blob_name, blob.blob_object
                )
    return tuple(blobs)
def InterpretCompletedOp(op_attribute_str, parallel_conf):
    """Interpret a completed op, then release its unused blob objects.

    Args:
        op_attribute_str: text-format serialized ``OpAttribute`` proto.
        parallel_conf: forwarded unchanged to ``_InterpretCompletedOp``.
    """
    attribute = text_format.Parse(
        op_attribute_str, op_attribute_pb.OpAttribute()
    )
    # NOTE: the default *backward* register is used both for interpretation
    # and for the release step. Presumably these are ops from the backward
    # pass -- TODO confirm against callers.
    backward_register = gradient_util.GetDefaultBackwardBlobRegister()
    _InterpretCompletedOp(attribute, parallel_conf, backward_register)
    gradient_util.ReleaseUnusedBlobObject(attribute, backward_register)
def eager_add_loss(loss):
    """Register ``loss`` as a loss of the current job (eager mode).

    The loss's ``unique_name`` is added to the current job build/infer
    context. Its blob object is then recorded in the default backward blob
    register under its logical blob name.

    Args:
        loss: blob exposing ``unique_name``, ``logical_blob_name`` and
            ``blob_object``.
    """
    c_api_util.CurJobBuildAndInferCtx_AddLossLogicalBlobName(loss.unique_name)
    register = gradient_util.GetDefaultBackwardBlobRegister()
    register.TrySetObject4BlobName(loss.logical_blob_name, loss.blob_object)