def __init__(self) -> None:
    """Set up the bookkeeping state of a newly created session.

    A new session id is obtained from ``oneflow_api.NewSessionId()`` and the
    session starts in ``SessionStatus.OPEN``. Most registries start as empty
    dicts/lists/sets to be filled in later; ``inter_user_job_info_``,
    ``config_proto_``, ``resource_`` and ``eager_config_proto_ctx_`` start as
    ``None``.
    """
    self.id_ = oneflow_api.NewSessionId()
    self.job_name2function_desc_ = {}
    self.status_ = SessionStatus.OPEN
    # NOTE(review): presumably used with running_job_cnt_ to wait for
    # in-flight jobs — confirm against the methods that use it.
    self.cond_var_ = threading.Condition()
    self.running_job_cnt_ = 0
    self.inter_user_job_info_ = None
    self.uuid2watch_handler_ = {}
    self.config_proto_ = None
    self.resource_ = None
    self.is_mirrored_strategy_enabled_stack_ = []
    self.job_name2var_name2var_blob_ = {}
    self.job_name2module_name2module_ = {}
    self.existed_module_names_ = set()
    self.var_name2var_blob_ = {}
    # The parallel desc symbol id in an op attribute is not always correct
    # for lazy ops, because the parallel conf may be updated by some passes
    # (e.g. optimizer_placement_optimization_pass).
    self.interface_op_name2op_attr_ = {}
    self.interface_op_name2job_name_ = {}
    # presumably tracks the up-to-date parallel conf of lazy interface ops
    # for the reason above — confirm
    self.lazy_interface_op_name2parallel_conf_ = {}
    self.op_name2lazy_blob_cache_ = {}
    self.job_name2name_scope_stack_ = {}
    self.eager_global_function_desc_stack_ = []
    # Each _Update* call below must run after its dict has been created;
    # presumably it fills that dict with default values — confirm.
    self.function_flag_name2default_val_ = {}
    self._UpdateFunctionFlagName2DefaultVal()
    self.scope_attr_name2default_val_ = {}
    self._UpdateScopeAttrName2DefaultVal()
    # Instruction list and eager symbol list, built from the instr_cfg and
    # eager_symbol_cfg modules.
    self.instruction_list_ = instr_cfg.InstructionListProto()
    self.eager_symbol_list_ = eager_symbol_cfg.EagerSymbolList()
    # A BlobRegister of its own, constructed with
    # blob_cache_util.TryDisableBlobCache.
    self.backward_blob_register_ = oneflow_api.BlobRegister(
        blob_cache_util.TryDisableBlobCache
    )
    self.snapshot_mgr_ = SnapshotManager()
    self.eager_config_proto_ctx_ = None
See the License for the specific language governing permissions and limitations under the License. """ from __future__ import absolute_import import oneflow.python.eager.blob_cache as blob_cache_util import oneflow_api from contextlib import contextmanager def GetDefaultBlobRegister(): return default_blob_register_ @contextmanager def BnInOp2BlobObjectScope(blob_register, op_attribute): bn_in_op2blob_object = {} for ibn in op_attribute.input_bns: lbi = op_attribute.arg_signature.bn_in_op2lbi[ibn] bn_in_op2blob_object[ibn] = blob_register.GetObject4BlobName( "%s/%s" % (lbi.op_name, lbi.blob_name)) yield bn_in_op2blob_object for obn in op_attribute.output_bns: lbi = op_attribute.arg_signature.bn_in_op2lbi[obn] blob_register.SetObject4BlobName( "%s/%s" % (lbi.op_name, lbi.blob_name), bn_in_op2blob_object[obn]) default_blob_register_ = oneflow_api.BlobRegister( blob_cache_util.TryDisableBlobCache)