def __init__(self, args, config, task_types: TaskTypes) -> None:
    """Set up task-manager state on ``self`` and start the first manifest load.

    NOTE(review): this module-level function is statement-for-statement the
    same as ``TaskManager.__init__``; it looks like a stray duplicate —
    confirm whether anything references it.

    :param args: parsed command-line arguments (``single_threaded`` is read
        later by the reload path).
    :param config: project config; ``reload_config`` re-derives it via
        ``from_args``.
    :param task_types: mapping of RPC method name -> task class.
    """
    # Caller-supplied inputs.
    self.args = args
    self.config = config
    self._task_types: TaskTypes = task_types

    # Nothing parsed yet: no manifest, parse state Init, no reloader thread.
    self.manifest: Optional[Manifest] = None
    self.last_parse: LastParse = LastParse(state=ManifestStatus.Init)
    self._reloader: Optional[ManifestReloader] = None

    # Lock created from dbt's configured multiprocessing context.
    self._lock: dbt.flags.MP_CONTEXT.Lock = dbt.flags.MP_CONTEXT.Lock()

    # The garbage collector is handed the very same dict used for tracking
    # active tasks, so it sees additions made through ``active_tasks``.
    tasks: TaskHandlerMap = {}
    self.active_tasks = tasks
    self.gc = GarbageCollector(active_tasks=tasks)

    # Kick off the initial load only after every attribute above exists.
    self.reload_manifest()
class TaskManager:
    """Hold the RPC server's manifest, parse state, and in-flight tasks.

    Reads and transitions of parse state in ``ready``/``set_parsing``, task
    lookup in ``rpc_task``, the task table, and garbage collection are done
    while holding ``self._lock``.
    """

    def __init__(self, args, config, task_types: TaskTypes) -> None:
        """Initialise state and start the first manifest load.

        :param args: parsed command-line arguments.
        :param config: project config; replaced by ``reload_config``.
        :param task_types: mapping of RPC method name -> task class.
        """
        self.args = args
        self.config = config
        self.manifest: Optional[Manifest] = None
        self._task_types: TaskTypes = task_types
        self.active_tasks: TaskHandlerMap = {}
        self.gc = GarbageCollector(active_tasks=self.active_tasks)
        self.last_parse: LastParse = LastParse(state=ManifestStatus.Init)
        self._lock: dbt.flags.MP_CONTEXT.Lock = dbt.flags.MP_CONTEXT.Lock()
        self._reloader: Optional[ManifestReloader] = None
        self.reload_manifest()

    def single_threaded(self):
        """Return truthy if reloads must run in the foreground.

        True when either the module-level ``SINGLE_THREADED_WEBSERVER`` flag
        or the ``single_threaded`` argument is set.
        """
        return SINGLE_THREADED_WEBSERVER or self.args.single_threaded

    def _reload_task_manager_thread(self, reloader: ManifestReloader):
        """Start ``reloader`` in the background and remember it.

        This function can only be running once at a time, as it runs in the
        signal handler we replace.
        """
        # compile in a thread that will fix up the task manager when it's done
        reloader.start()
        # only assign to _reloader here, to avoid calling join() before start()
        self._reloader = reloader

    def _reload_task_manager_fg(self, reloader: ManifestReloader):
        """Override for single-threaded mode to run in the foreground."""
        # just reload directly
        reloader.reload_manifest()

    def reload_manifest(self) -> bool:
        """Reload the manifest using a manifest reloader.

        Returns False if the reload was not started because it was already
        running (i.e. ``set_parsing`` refused the transition), else True.
        """
        if not self.set_parsing():
            return False

        if self._reloader is not None:
            # join() the existing reloader
            self._reloader.join()

        # perform the reload
        reloader = ManifestReloader(self)
        if self.single_threaded():
            self._reload_task_manager_fg(reloader)
        else:
            self._reload_task_manager_thread(reloader)
        return True

    def reload_config(self):
        """Rebuild the config from ``self.args``, store it, and return it."""
        config = self.config.from_args(self.args)
        self.config = config
        return config

    def add_request(self, request_handler: TaskHandlerProtocol):
        """Register ``request_handler`` under its ``task_id``."""
        self.active_tasks[request_handler.task_id] = request_handler

    def get_request(self, task_id: TaskID) -> TaskHandlerProtocol:
        """Look up an active task.

        :raises dbt.exceptions.UnknownAsyncIDException: if ``task_id`` is not
            registered.
        """
        try:
            return self.active_tasks[task_id]
        except KeyError:
            # We don't recognize that ID.
            raise dbt.exceptions.UnknownAsyncIDException(task_id) from None

    def _get_manifest_callable(
        self, task: Type[RemoteManifestMethod]
    ) -> Union[UnconditionalError, RemoteManifestMethod]:
        """Build a manifest-backed task, or an error stand-in.

        Returns ``CurrentlyCompiling()`` while compiling and ``ParseError``
        after a failed parse; otherwise instantiates ``task``.

        :raises dbt.exceptions.InternalException: if the state says a manifest
            should exist but ``self.manifest`` is None.
        """
        state = self.last_parse.state
        if state == ManifestStatus.Compiling:
            return CurrentlyCompiling()
        elif state == ManifestStatus.Error:
            return ParseError(self.last_parse.error)
        else:
            if self.manifest is None:
                raise dbt.exceptions.InternalException(
                    f'Manifest should not be None if the last parse state is '
                    f'{state}'
                )
            return task(self.args, self.config, self.manifest)

    def rpc_task(
        self, method_name: str
    ) -> Union[UnconditionalError, RemoteMethod]:
        """Instantiate the task registered for ``method_name``.

        Dispatches on the task class's base: builtin methods receive the
        task manager, manifest methods go through ``_get_manifest_callable``,
        and plain remote methods receive ``(args, config)``.

        :raises KeyError: if ``method_name`` is not registered.
        :raises dbt.exceptions.InternalException: for an unsupported class.
        """
        with self._lock:
            task = self._task_types[method_name]
            if issubclass(task, RemoteBuiltinMethod):
                return task(self)
            elif issubclass(task, RemoteManifestMethod):
                return self._get_manifest_callable(task)
            elif issubclass(task, RemoteMethod):
                return task(self.args, self.config)
            else:
                raise dbt.exceptions.InternalException(
                    f'Got a task with an invalid type! {task} with method '
                    f'name {method_name} has a type of {task.__class__}, '
                    f'should be a RemoteMethod'
                )

    def ready(self) -> bool:
        """Return True if the last parse finished successfully."""
        with self._lock:
            return self.last_parse.state == ManifestStatus.Ready

    def set_parsing(self) -> bool:
        """Atomically move to ``Compiling``.

        Returns False (and changes nothing) if already compiling.
        """
        with self._lock:
            if self.last_parse.state == ManifestStatus.Compiling:
                return False
            self.last_parse = LastParse(state=ManifestStatus.Compiling)
        return True

    def parse_manifest(self) -> None:
        """Parse the full manifest from the current config."""
        self.manifest = get_full_manifest(self.config)

    def set_compile_exception(
        self, exc, logs: Optional[List[LogMessage]] = None
    ) -> None:
        """Record a failed parse (``Compiling`` -> ``Error``).

        :param exc: the exception; only ``str(exc)`` is stored.
        :param logs: log messages captured during the parse. Previously this
            parameter was mistakenly declared ``logs=List[LogMessage]``,
            which made the typing object itself the default value.
            NOTE(review): assumes ``LastParse`` accepts ``logs=None``
            (``__init__`` builds one without logs) — confirm.
        """
        assert self.last_parse.state == ManifestStatus.Compiling, \
            f'invalid state {self.last_parse.state}'
        self.last_parse = LastParse(
            error={'message': str(exc)},
            state=ManifestStatus.Error,
            logs=logs
        )

    def set_ready(self, logs: Optional[List[LogMessage]] = None) -> None:
        """Record a successful parse (``Compiling`` -> ``Ready``).

        :param logs: log messages captured during the parse. Previously
            mis-declared as ``logs=List[LogMessage]`` (a default value, not an
            annotation). NOTE(review): assumes ``LastParse`` accepts
            ``logs=None`` — confirm.
        """
        assert self.last_parse.state == ManifestStatus.Compiling, \
            f'invalid state {self.last_parse.state}'
        self.last_parse = LastParse(state=ManifestStatus.Ready, logs=logs)

    def methods(self) -> Set[str]:
        """Return the names of all registered RPC methods."""
        with self._lock:
            return set(self._task_types)

    def currently_compiling(self, *args, **kwargs):
        """Raise an RPC exception to trigger the error handler."""
        raise dbt_error(dbt.exceptions.RPCCompiling('compile in progress'))

    def compilation_error(self, *args, **kwargs):
        """Raise an RPC exception to trigger the error handler."""
        raise dbt_error(
            dbt.exceptions.RPCLoadException(self.last_parse.error)
        )

    def get_handler(
        self, method, http_request, json_rpc_request
    ) -> Optional[Union[WrappedHandler, RemoteMethod]]:
        """Return the task for ``method``, or None if it is unknown.

        ``http_request`` and ``json_rpc_request`` are accepted but unused here.
        """
        # get_handler triggers a GC check. TODO: does this go somewhere else?
        self.gc_as_required()
        if method not in self._task_types:
            return None

        task = self.rpc_task(method)

        return task

    def task_table(self) -> List[TaskRow]:
        """Snapshot one row per active task, stamped with the current UTC time."""
        rows: List[TaskRow] = []
        now = datetime.utcnow()
        with self._lock:
            for task in self.active_tasks.values():
                rows.append(task.make_task_row(now))
        return rows

    def gc_as_required(self) -> None:
        """Run the garbage collector's automatic collection under the lock."""
        with self._lock:
            return self.gc.collect_as_required()

    def gc_safe(
        self,
        task_ids: Optional[List[uuid.UUID]] = None,
        before: Optional[datetime] = None,
        settings: Optional[GCSettings] = None,
    ) -> GCResult:
        """Collect the selected tasks under the lock and return the result."""
        with self._lock:
            return self.gc.collect_selected(
                task_ids=task_ids, before=before, settings=settings,
            )