Beispiel #1
0
 def __init__(self) -> None:
     """Initialize a new, open session with empty bookkeeping state.

     A fresh session id is obtained from ``oneflow_api.NewSessionId()`` and
     the status is set to ``SessionStatus.OPEN``. Every registry created
     here starts empty (or ``None``); this constructor does not populate
     ``inter_user_job_info_``, ``config_proto_``, ``resource_`` or
     ``eager_config_proto_ctx_`` — presumably they are set elsewhere.

     Ordering matters: ``_UpdateFunctionFlagName2DefaultVal()`` and
     ``_UpdateScopeAttrName2DefaultVal()`` are each called right after the
     dict they are named after is created. Presumably they fill those
     dicts, so keep each call after its dict assignment (TODO confirm).
     """
     # Session identifier allocated through oneflow_api.
     self.id_ = oneflow_api.NewSessionId()
     self.job_name2function_desc_ = {}
     self.status_ = SessionStatus.OPEN
     # NOTE(review): cond_var_ presumably guards running_job_cnt_ (and
     # possibly status_); confirm against the methods that use them.
     self.cond_var_ = threading.Condition()
     self.running_job_cnt_ = 0
     self.inter_user_job_info_ = None
     self.uuid2watch_handler_ = {}
     self.config_proto_ = None
     self.resource_ = None
     self.is_mirrored_strategy_enabled_stack_ = []
     self.job_name2var_name2var_blob_ = {}
     self.job_name2module_name2module_ = {}
     self.existed_module_names_ = set()
     self.var_name2var_blob_ = {}
     # The parallel desc symbol id in an op attribute is not always correct
     # for lazy ops, because the parallel conf may be updated in some passes
     # (like optimizer_placement_optimization_pass).
     self.interface_op_name2op_attr_ = {}
     self.interface_op_name2job_name_ = {}
     self.lazy_interface_op_name2parallel_conf_ = {}
     self.op_name2lazy_blob_cache_ = {}
     self.job_name2name_scope_stack_ = {}
     self.eager_global_function_desc_stack_ = []
     # Must be created before the _Update... call that follows it.
     self.function_flag_name2default_val_ = {}
     self._UpdateFunctionFlagName2DefaultVal()
     # Must be created before the _Update... call that follows it.
     self.scope_attr_name2default_val_ = {}
     self._UpdateScopeAttrName2DefaultVal()
     # Empty eager instruction list and eager symbol list (cfg objects).
     self.instruction_list_ = instr_cfg.InstructionListProto()
     self.eager_symbol_list_ = eager_symbol_cfg.EagerSymbolList()
     # Blob register, presumably for blobs used by backward; the
     # TryDisableBlobCache function is passed uncalled to the constructor
     # (its role is defined outside this file view).
     self.backward_blob_register_ = oneflow_api.BlobRegister(
         blob_cache_util.TryDisableBlobCache)
     self.snapshot_mgr_ = SnapshotManager()
     self.eager_config_proto_ctx_ = None
Beispiel #2
0
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import

import oneflow.python.eager.blob_cache as blob_cache_util
import oneflow_api
from contextlib import contextmanager


def GetDefaultBlobRegister():
    """Return the module's default blob register.

    The returned object is whatever the module-level name
    ``default_blob_register_`` is bound to at call time. That name is
    assigned at import time by a module-level statement placed after this
    function.
    """
    register = default_blob_register_
    return register


@contextmanager
def BnInOp2BlobObjectScope(blob_register, op_attribute):
    """Expose a mutable bn_in_op -> blob object mapping for a single op.

    On entry, every name in ``op_attribute.input_bns`` is mapped through
    ``op_attribute.arg_signature.bn_in_op2lbi`` to a blob name of the form
    ``"<op_name>/<blob_name>"``. That blob name is resolved with
    ``blob_register.GetObject4BlobName``. The resulting dict is yielded, and
    the caller is expected to add an entry for every name in
    ``op_attribute.output_bns``. On normal exit each output entry is stored
    back with ``blob_register.SetObject4BlobName``. A missing output entry
    raises ``KeyError``. If the ``with`` body raises, the exception
    propagates and no outputs are registered.

    Args:
        blob_register: object providing ``GetObject4BlobName`` and
            ``SetObject4BlobName``.
        op_attribute: op attribute exposing ``input_bns``, ``output_bns``
            and ``arg_signature.bn_in_op2lbi``.

    Yields:
        dict mapping each input bn to its blob object; the caller extends
        it with the output bns.
    """

    def _blob_name_of(bn_in_op):
        # Logical blob id -> "op_name/blob_name" key used by the register.
        lbi = op_attribute.arg_signature.bn_in_op2lbi[bn_in_op]
        return "%s/%s" % (lbi.op_name, lbi.blob_name)

    bn_in_op2blob_object = {
        ibn: blob_register.GetObject4BlobName(_blob_name_of(ibn))
        for ibn in op_attribute.input_bns
    }
    yield bn_in_op2blob_object
    # Deliberately not in a finally: outputs are only registered on success.
    for obn in op_attribute.output_bns:
        blob_register.SetObject4BlobName(
            _blob_name_of(obn), bn_in_op2blob_object[obn])


# Module-level default blob register, created once when this module is
# imported and returned by GetDefaultBlobRegister(). TryDisableBlobCache is
# passed uncalled to the BlobRegister constructor — presumably as a callback;
# its exact role is defined in oneflow_api / blob_cache (not visible here).
default_blob_register_ = oneflow_api.BlobRegister(
    blob_cache_util.TryDisableBlobCache)