def _EagerRunModelSave(var_blobs, snapshot_path):
    """Eagerly run a model-save op over ``var_blobs`` with ``snapshot_path``.

    Three instructions are issued, in order, under a scope built from
    ``_BuildNotMirroredScope``:

    1. the first builder returned by ``_MakeModelIOPathInputBuilds``,
    2. the second builder returned by ``_MakeModelIOPathInputBuilds``
       (which receives ``snapshot_path``),
    3. a model-save instruction whose inputs are the ``"out"`` blob of the
       path input (bound as ``"path"``) and each variable blob (bound as
       ``"in_0"``, ``"in_1"``, ...).

    Args:
        var_blobs: sequence of objects exposing a ``blob_object`` attribute.
        snapshot_path: forwarded to ``_MakeModelIOPathInputBuilds``.
    """
    path_input_op_conf, path_lbi = _GenModelIOPathInputOpConfAndRetLbi()
    path_input_blob_objects = (
        oneflow._oneflow_internal.deprecated.BnInOp2BlobObject()
    )
    path_builds = _MakeModelIOPathInputBuilds(
        path_input_op_conf, snapshot_path, path_input_blob_objects
    )
    BuildModelIOPathInputInstruction, BuildFeedPathInstruction = path_builds

    save_op_conf = _GenModelSaveOpConf(var_blobs, path_lbi)
    save_blob_objects = oneflow._oneflow_internal.deprecated.BnInOp2BlobObject()

    def BuildModelSaveInstruction(builder):
        # The path blob is only available after the path-input instruction
        # has run, so it is looked up lazily inside the builder.
        path_blob_object = path_input_blob_objects["out"]
        save_blob_objects["path"] = path_blob_object
        for idx, var_blob in enumerate(var_blobs):
            save_blob_objects["in_{}".format(idx)] = var_blob.blob_object

        inferred_attribute = op_infer_util.Infer(
            save_op_conf, ibn2blob_object=save_blob_objects
        )
        # The save op is placed wherever the path blob lives.
        placement_conf = path_blob_object.parallel_desc_symbol.parallel_conf
        deprecated_api = oneflow._oneflow_internal.deprecated
        cfg_op_attribute = deprecated_api.MakeOpAttributeByString(
            str(inferred_attribute)
        )
        builder.StatelessCall(
            cfg_op_attribute,
            placement_conf,
            save_blob_objects,
            boxing_util.BoxingTo,
        )

    # Result unused; the call is kept as in the original (presumably it checks
    # that a default session exists — confirm).
    session_ctx.GetDefaultSession()
    with scope_util.ScopeContext(scope_util.MakeScope(_BuildNotMirroredScope)):
        for build_instruction in (
            BuildModelIOPathInputInstruction,
            BuildFeedPathInstruction,
            BuildModelSaveInstruction,
        ):
            oneflow._oneflow_internal.deprecated.LogicalRun(build_instruction)
def _EagerRunModelInit(var_op_conf):
    """Eagerly run the model-init op generated from ``var_op_conf``.

    The instruction is issued under a scope built from
    ``_BuildNotMirroredScope``. Inside the builder, the op conf is tagged with
    the current scope's symbol id, its attribute is inferred with an empty
    upstream signature, and it is placed on the current scope's device
    parallel desc.

    Args:
        var_op_conf: forwarded to ``_GenModelInitOpConfAndRetLbi``.

    Returns:
        The blob object stored under ``"out_0"`` after the instruction runs.
    """
    init_op_conf, _ = _GenModelInitOpConfAndRetLbi(var_op_conf)
    blob_objects = oneflow._oneflow_internal.deprecated.BnInOp2BlobObject()

    def BuildModelInitInstruction(builder):
        empty_upstream = op_node_signature_pb.OpNodeSignature()
        init_op_conf.scope_symbol_id = oneflow.current_scope().symbol_id
        inferred_attribute = c_api_util.InferOpConf(init_op_conf, empty_upstream)
        device_desc = oneflow.current_scope().device_parallel_desc_symbol
        placement_conf = device_desc.parallel_conf
        deprecated_api = oneflow._oneflow_internal.deprecated
        cfg_op_attribute = deprecated_api.MakeOpAttributeByString(
            str(inferred_attribute)
        )
        builder.StatelessCall(
            cfg_op_attribute, placement_conf, blob_objects, boxing_util.BoxingTo
        )

    # Result unused; the call is kept as in the original (presumably it checks
    # that a default session exists — confirm).
    session_ctx.GetDefaultSession()
    not_mirrored_scope = scope_util.MakeScope(_BuildNotMirroredScope)
    with scope_util.ScopeContext(not_mirrored_scope):
        oneflow._oneflow_internal.deprecated.LogicalRun(BuildModelInitInstruction)
    return blob_objects["out_0"]
def open(self, job_name, signature, batch_size=None):
    """Open a job build-and-infer context for ``job_name`` and yield ``self``.

    This is a generator meant to be driven as a context manager. The
    decorator (presumably ``contextlib.contextmanager``) is outside this view.
    While the context is open, the job signature and, when given as an
    ``int``, the batch size are set. The job conf is installed, and the body
    runs in global runtime mode inside an initial scope built from the
    default machine/device ids.

    Args:
        job_name: name of the job whose build context is opened.
        signature: forwarded to ``self.set_job_signature``.
        batch_size: applied via ``self.set_job_batch_size`` only when it is
            an ``int``; otherwise ignored.

    Yields:
        ``self``.
    """
    self._check_status(self.SessionStatus.OPEN)
    c_api_util.JobBuildAndInferCtx_Open(job_name)
    try:
        self.set_job_signature(job_name, signature)
        if isinstance(batch_size, int):
            self.set_job_batch_size(job_name, batch_size)
        job_conf = self._get_job_conf(job_name)
        c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)
        tag_and_dev_ids = placement_util.GetDefaultMachineDeviceIds(
            self.config_proto_.resource
        )
        scope = scope_util.MakeInitialScope(
            job_conf, *tag_and_dev_ids, self.is_mirrored_
        )
        with runtime_mode.ModeScope(runtime_mode.GLOBAL_MODE):
            with scope_util.ScopeContext(scope):
                yield self
    finally:
        # Close on every exit path. Previously an exception raised during
        # setup or inside the caller's ``with`` body skipped this call and
        # left the job build context open.
        oneflow_api.JobBuildAndInferCtx_Close()
def InterpretScope(session, function_desc, config_proto):
    """Open the job build context for ``function_desc`` and yield once inside it.

    The job is named after ``function_desc.job_func``. Placement comes from
    the function's default placement scope, which must be an
    ``EmptyPlacementScope`` when set. Otherwise the session's default machine
    and device ids are used. The distribute strategy defaults to
    ``DistributeConsistentStrategy``. Mirroring is enabled only for a
    ``DistributeMirroredStrategy``.

    Args:
        session: object providing ``resource`` for default placement.
        function_desc: provides ``job_config_proto``, ``job_func`` and
            ``function_attribute``.
        config_proto: accepted but not used by this function.
    """
    job_conf = function_desc.job_config_proto
    job_conf.job_name = function_desc.job_func.__name__

    placement_scope = function_desc.function_attribute.default_placement_scope
    if placement_scope is not None:
        assert isinstance(placement_scope, placement_ctx.EmptyPlacementScope)
        tag_and_dev_ids = (
            placement_scope.device_tag,
            placement_scope.machine_device_ids,
        )
    else:
        tag_and_dev_ids = placement_util.GetDefaultMachineDeviceIds(
            session.resource
        )

    strategy = function_desc.function_attribute.default_distribute_strategy
    if strategy is None:
        strategy = distribute_util.DistributeConsistentStrategy()
    mirrored = isinstance(strategy, distribute_util.DistributeMirroredStrategy)

    initial_scope = scope_util.MakeInitialScope(job_conf, *tag_and_dev_ids, mirrored)
    with _JobBuildAndInferCtx(job_conf.job_name), strategy:
        c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)
        with runtime_mode.ModeScope(
            runtime_mode.GLOBAL_MODE
        ), scope_util.ScopeContext(initial_scope):
            yield
def GetNormalModePlacementScope(device_tag, machine_device_ids):
    """Return a ``ScopeContext`` for a scope placed on the given devices.

    The new scope is derived from the current one by calling
    ``BuildWithNewParallelDesc`` with ``device_tag`` and
    ``machine_device_ids``.
    """
    # Result unused; the call is kept as in the original (presumably it checks
    # that a default session exists — confirm).
    session_ctx.GetDefaultSession()

    def _with_new_parallel_desc(old_scope, builder):
        return old_scope.BuildWithNewParallelDesc(
            builder, device_tag, machine_device_ids
        )

    placed_scope = scope_util.MakeScope(_with_new_parallel_desc)
    return scope_util.ScopeContext(placed_scope)