def _get_access_token(self) -> str:
    """Exchange the configured OAuth refresh token for an access token.

    Requires ``self.authenticator == 'oauth'`` and all of
    ``self.oauth_client_id``, ``self.oauth_client_secret`` and
    ``self.token`` (the refresh token) to be set.

    :returns: The ``access_token`` field of the token endpoint's JSON reply.
    :raises InternalException: If the authenticator is not ``'oauth'`` or a
        required credential is missing.
    :raises DatabaseException: If the token endpoint's reply is not JSON,
        or its JSON has no ``access_token`` key.
    """
    if self.authenticator != 'oauth':
        raise InternalException('Can only get access tokens for oauth')
    missing = any(
        x is None
        for x in (self.oauth_client_id, self.oauth_client_secret, self.token)
    )
    if missing:
        raise InternalException(
            'need a client ID, a client secret, and a refresh token to get '
            'an access token')

    # should the full url be a config item?
    token_url = _TOKEN_REQUEST_URL.format(self.account)
    # I think this is only used to redirect on success, which we ignore
    # (it does not have to match the integration's settings in snowflake)
    redirect_uri = 'http://localhost:9999'
    data = {
        'grant_type': 'refresh_token',
        'refresh_token': self.token,
        'redirect_uri': redirect_uri,
    }

    # HTTP Basic credentials: base64 of "client_id:client_secret".
    auth = base64.b64encode(
        f'{self.oauth_client_id}:{self.oauth_client_secret}'.encode('ascii')
    ).decode('ascii')
    headers = {
        'Authorization': f'Basic {auth}',
        'Content-type': 'application/x-www-form-urlencoded;charset=utf-8',
    }
    result = requests.post(token_url, headers=headers, data=data)
    try:
        result_json = result.json()
    except ValueError as exc:
        # requests raises a ValueError subclass when the body is not JSON
        # (e.g. an HTML error page); report it like a missing token rather
        # than letting a bare decode error escape.
        raise DatabaseException(
            f'Did not get a token: non-JSON response '
            f'(HTTP {result.status_code}): {result.text}'
        ) from exc
    if 'access_token' not in result_json:
        raise DatabaseException(f'Did not get a token: {result_json}')
    return result_json['access_token']
def process_freshness_result(
    result: FreshnessNodeResult
) -> FreshnessNodeOutput:
    """Convert a freshness node result into its serializable output form.

    A runtime-error status becomes a ``SourceFreshnessRuntimeError``. Any
    other status must come from a ``SourceFreshnessResult`` whose source has
    freshness criteria, and becomes a ``SourceFreshnessOutput``.

    :raises InternalException: If a non-error result is not a
        SourceFreshnessResult, or its source has no freshness criteria.
    """
    node_id = result.node.unique_id

    if result.status == FreshnessStatus.RuntimeErr:
        return SourceFreshnessRuntimeError(
            unique_id=node_id,
            error=result.message,
            status=FreshnessErrorEnum.runtime_error,
        )

    # Every non-error result is expected to be a SourceFreshnessResult.
    if not isinstance(result, SourceFreshnessResult):
        raise InternalException(
            'Got {} instead of a SourceFreshnessResult for a '
            'non-error result in freshness execution!'.format(type(result)))

    # Only sources with a freshness threshold should have been evaluated.
    freshness_criteria = result.node.freshness
    if freshness_criteria is None:
        raise InternalException(
            'Somehow evaluated a freshness result for a source '
            'that has no freshness criteria!')

    return SourceFreshnessOutput(
        unique_id=node_id,
        max_loaded_at=result.max_loaded_at,
        snapshotted_at=result.snapshotted_at,
        max_loaded_at_time_ago_in_s=result.age,
        status=result.status,
        criteria=freshness_criteria,
        adapter_response=result.adapter_response,
    )
def _runtime_initialize(self):
    """Initialize runtime state, then build the job queue and node list.

    Sets ``self.job_queue``, ``self._flattened_nodes`` (the selected nodes
    and sources, in no particular order) and ``self.num_nodes`` (how many
    of those are not ephemeral models).

    :raises InternalException: If the manifest or graph was not loaded, or
        selection returns an id that is neither a node nor a source.
    """
    super()._runtime_initialize()
    if self.manifest is None or self.graph is None:
        raise InternalException(
            '_runtime_initialize never loaded the manifest and graph!'
        )

    self.job_queue = self.get_graph_queue()

    # Used in several places later on; ordering is irrelevant.
    self._flattened_nodes = selected = []
    for unique_id in self.job_queue.get_selected_nodes():
        if unique_id in self.manifest.nodes:
            selected.append(self.manifest.nodes[unique_id])
        elif unique_id in self.manifest.sources:
            selected.append(self.manifest.sources[unique_id])
        else:
            raise InternalException(
                f'Node selection returned {unique_id}, expected a node or a '
                f'source'
            )

    self.num_nodes = sum(
        1 for n in self._flattened_nodes if not n.is_ephemeral_model
    )
def _get_compiled_model(
    self,
    manifest: Manifest,
    cte_id: str,
    extra_context: Dict[str, Any],
) -> NonSourceCompiledNode:
    """Return the compiled node for the CTE reference ``cte_id``.

    A node that is already compiled is returned as-is. An uncompiled
    ephemeral model is compiled now, stored back via
    ``manifest.sync_update_node``, and returned.

    :raises InternalException: If ``cte_id`` is not in ``manifest.nodes``, or
        refers to an uncompiled node that is not an ephemeral model.
    """
    if cte_id not in manifest.nodes:
        raise InternalException(
            f'During compilation, found a cte reference that could not be '
            f'resolved: {cte_id}')
    cte_model = manifest.nodes[cte_id]

    if getattr(cte_model, 'compiled', False):
        assert isinstance(cte_model, tuple(COMPILED_TYPES.values()))
        return cast(NonSourceCompiledNode, cte_model)

    if not cte_model.is_ephemeral_model:
        raise InternalException(
            f'During compilation, found an uncompiled cte that '
            f'was not an ephemeral model: {cte_id}')

    # Not compiled yet, so this is a parsed node (never a parsed source
    # definition): it must be one of the keys of COMPILED_TYPES.
    assert isinstance(cte_model, tuple(COMPILED_TYPES))
    # Compile it and write the compiled node back into the manifest.
    compiled = self.compile_node(cte_model, manifest, extra_context)
    manifest.sync_update_node(compiled)
    return compiled
def _expect_row_value(key: str, row: agate.Row):
    """Return ``row[key]``, failing loudly if the row lacks that column.

    :raises InternalException: If ``key`` is not among ``row.keys()``.
    """
    columns = row.keys()
    if key in columns:
        return row[key]
    raise InternalException(
        'Got a row without "{}" column, columns: {}'
        .format(key, row.keys())
    )
def _get_list_dicts(dct: Dict[str, Any], key: str) -> List[RawDefinition]:
    """Validate and return the raw definitions listed under ``dct[key]``.

    ``dct[key]`` must be a list whose elements are each a ``str`` or a
    ``dict`` with only ``str`` keys. Valid elements are returned unchanged,
    in their original order.

    :raises InternalException: If ``key`` is not present in ``dct``.
    :raises ValidationException: If ``dct[key]`` is not a list, if an
        element is neither a dict nor a str, or if a dict element has a
        non-str key.
    """
    if key not in dct:
        raise InternalException(
            f'Expected to find key {key} in dict, only found {list(dct)}')

    values = dct[key]
    if not isinstance(values, list):
        raise ValidationException(
            f'Invalid value type {type(values)} in key "{key}" '
            f'(value "{values}")')

    definitions: List[RawDefinition] = []
    for entry in values:
        if isinstance(entry, str):
            definitions.append(entry)
            continue
        if not isinstance(entry, dict):
            raise ValidationException(
                f'Invalid value type {type(entry)} in key "{key}", expected '
                f'dict or str (value: {entry}).')
        # Every key of a dict entry must itself be a string.
        for entry_key in entry:
            if not isinstance(entry_key, str):
                raise ValidationException(
                    f'Expected all keys to "{key}" dict to be strings, '
                    f'but "{entry_key}" is a "{type(entry_key)}"')
        definitions.append(entry)
    return definitions
def compile_manifest(self):
    """Compile ``self.manifest`` with the adapter's compiler.

    The compiled result is stored in ``self.graph``.

    :raises InternalException: If no manifest has been loaded yet.
    """
    if self.manifest is None:
        raise InternalException(
            'compile_manifest called before manifest was loaded')
    compiler = get_adapter(self.config).get_compiler()
    self.graph = compiler.compile(self.manifest)
def execute(self, model, manifest):
    """Run ``model``'s materialization macro and build its run result.

    The macro call is wrapped by the adapter's ``pre_model_hook`` /
    ``post_model_hook`` (the post hook runs even if the macro raises). Each
    relation reported by the materialization is added to the adapter cache
    with ``dbt_created=True``.

    :raises InternalException: If the generated context has no ``config``.
    """
    context = dbt.context.runtime.generate(model, self.config, manifest)

    macro = manifest.get_materialization_macro(
        self.config.project_name,
        model.get_materialization(),
        self.adapter.type(),
    )
    if macro is None:
        missing_materialization(model, self.adapter.type())

    if 'config' not in context:
        raise InternalException(
            'Invalid materialization context generated, missing config: {}'
            .format(context))
    model_config = context['config']

    hook_state = self.adapter.pre_model_hook(model_config)
    try:
        macro_result = macro.generator(context)()
    finally:
        self.adapter.post_model_hook(model_config, hook_state)

    for created in self._materialization_relations(macro_result, model):
        self.adapter.cache_added(created.incorporate(dbt_created=True))

    return self._build_run_model_result(model, context)
def track_model_run(index, num_nodes, run_model_result):
    """Send a model-run tracking event describing ``run_model_result``.

    :param index: Sent as the event's ``index`` field.
    :param num_nodes: Sent as the event's ``total`` field.
    :raises InternalException: If there is no active tracking user.
    """
    active_user = tracking.active_user
    if active_user is None:
        raise InternalException('cannot track model run with no active user')

    node = run_model_result.node
    event = {
        "invocation_id": active_user.invocation_id,
        "index": index,
        "total": num_nodes,
        "execution_time": run_model_result.execution_time,
        "run_status": run_model_result.status,
        "run_skipped": run_model_result.skip,
        "run_error": None,
        "model_materialization": node.get_materialization(),
        "model_id": utils.get_hash(node),
        "hashed_contents": utils.get_hashed_contents(node),
        "timing": [t.to_dict() for t in run_model_result.timing],
    }
    tracking.track_model_run(event)
def get_result(self) -> RemoteResult:
    """Wait for the task's result and attach the captured logs to it.

    Must be called after ``handle()`` has set ``self.process``.

    :returns: The result of ``self._wait_for_results()`` with ``logs`` set
        to a copy of ``self.logs``.
    :raises InternalException: If called before ``handle()``.
    :raises RPCException: Re-raised after ``logs`` (as dicts) and ``tags``
        are attached to it.
    """
    if self.process is None:
        raise InternalException('get_result() called before handle()')

    flags = self.task.get_flags()

    # If we blocked the manifest tasks, we need to un-set them on exit.
    # threaded mode handles this on its own.
    with get_results_context(flags, self.manager, lambda: self.logs):
        try:
            # NOTE(review): list_handler presumably collects log records
            # into self.logs while waiting — confirm.
            with list_handler(self.logs):
                try:
                    result = self._wait_for_results()
                finally:
                    # Join the worker process unless running
                    # single-threaded (runs even if waiting raised).
                    if not self._single_threaded:
                        self.process.join()
        except RPCException as exc:
            # RPC Exceptions come already preserialized for the jsonrpc
            # framework
            exc.logs = [log.to_dict() for log in self.logs]
            exc.tags = self.tags
            raise

        # results get real logs
        result.logs = self.logs[:]
        return result
def _wait_for_results(self) -> RemoteResult:
    """Block until the subscriber delivers a message, then act on it.

    Errors raised while dispatching, and error messages from the queue, are
    turned into RPC exceptions. On a timeout the worker process is terminated
    (unless running single-threaded) before the timeout error is raised.
    Joining the process is left to the caller.

    :returns: The result carried by a ``QueueResultMessage``.
    :raises InternalException: If called before ``handle()``, or the message
        has an unexpected type.
    """
    not_started = (
        self.subscriber is None
        or self.started is None
        or self.process is None
    )
    if not_started:
        raise InternalException(
            '_wait_for_results() called before handle()')

    try:
        message = self.subscriber.dispatch_until_exit(
            started=self.started,
            timeout=self.timeout,
        )
    except dbt.exceptions.Exception as exc:
        raise dbt_error(exc)
    except Exception as exc:
        raise server_error(exc)

    if isinstance(message, QueueErrorMessage):
        raise RPCException.from_error(message.error)

    if isinstance(message, QueueTimeoutMessage):
        if not self._single_threaded:
            self.process.terminate()
        raise timeout_error(self.timeout)

    if isinstance(message, QueueResultMessage):
        return message.result

    raise dbt.exceptions.InternalException(
        f'Invalid message type {message.message_type} ({message})')
def execute(self, test: CompiledTestNode, manifest: Manifest):
    """Run a compiled data or schema test and return its RunResult.

    The status is Pass when no rows failed. Otherwise it is Fail when the
    test's severity is ``ERROR`` (compared case-insensitively) or
    ``flags.WARN_ERROR`` is set, and Warn in all other cases. The failure
    count is reported as the result's ``message``.

    :raises InternalException: If ``test`` is neither a
        CompiledDataTestNode nor a CompiledSchemaTestNode.
    """
    if isinstance(test, CompiledDataTestNode):
        failures = self.execute_data_test(test)
    elif isinstance(test, CompiledSchemaTestNode):
        failures = self.execute_schema_test(test)
    else:
        raise InternalException(
            f'Expected compiled schema test or compiled data test, got '
            f'{type(test)}')

    severity = test.config.severity.upper()
    current_thread = threading.current_thread().name

    if failures == 0:
        outcome = TestStatus.Pass
    elif severity == 'ERROR' or flags.WARN_ERROR:
        outcome = TestStatus.Fail
    else:
        outcome = TestStatus.Warn

    return RunResult(
        node=test,
        status=outcome,
        timing=[],
        thread_id=current_thread,
        execution_time=0,
        message=int(failures),
        adapter_response={},
    )
def _get_exec_node(self):
    """Parse the request's SQL (and any inline macros) into an RPC node.

    Inline macros are parsed and merged into ``self.manifest.macros``. The new
    node's refs are added to the manifest, the manifest is compiled without
    being written to disk, and the node's ancestors are compiled.

    :returns: The parsed RPC node.
    :raises InternalException: If ``self.manifest`` is not set.
    """
    if self.manifest is None:
        raise InternalException('manifest not set in _get_exec_node')

    results = ParseResult.rpc()
    macro_overrides = {}
    # Both the SQL and the macros come from _extract_request_data; the old
    # dead read of ``self.args.macros`` (immediately overwritten) is gone.
    sql, macros = self._extract_request_data(self.args.sql)

    if macros:
        macro_parser = RPCMacroParser(results, self.config)
        for node in macro_parser.parse_remote(macros):
            macro_overrides[node.unique_id] = node

    self.manifest.macros.update(macro_overrides)
    rpc_parser = RPCCallParser(
        results=results,
        project=self.config,
        root_project=self.config,
        macro_manifest=self.manifest,
    )
    rpc_node = rpc_parser.parse_remote(sql, self.args.name)
    add_new_refs(
        manifest=self.manifest,
        config=self.config,
        node=rpc_node,
        macros=macro_overrides
    )

    # don't write our new, weird manifest!
    self.linker = compile_manifest(self.config, self.manifest, write=False)
    self._compile_ancestors(rpc_node.unique_id)
    return rpc_node
def get_adapter_plugins(self, name: Optional[str]) -> List[AdapterPlugin]:
    """Return the known adapter plugins.

    If ``name`` is None, return every registered plugin. Otherwise return the
    named plugin followed by its transitive dependencies in breadth-first
    order, each plugin appearing exactly once.

    :raises InternalException: If the named plugin or one of its
        dependencies is not registered.
    """
    from collections import deque

    if name is None:
        return list(self.plugins.values())

    plugins: List[AdapterPlugin] = []
    # Mark names as seen when they are queued, not when processed, so a
    # dependency shared by several plugins is only returned once.
    seen: Set[str] = {name}
    pending = deque([name])
    while pending:
        plugin_name = pending.popleft()
        try:
            plugin = self.plugins[plugin_name]
        except KeyError:
            raise InternalException(
                f'No plugin found for {plugin_name}') from None
        plugins.append(plugin)
        for dep in plugin.dependencies or ():
            if dep not in seen:
                seen.add(dep)
                pending.append(dep)
    return plugins
def run(self) -> CatalogResults:
    """Build the docs catalog, write it to the target path, and return it.

    If ``self.args.compile`` is set, the project is compiled first via
    ``CompileTask.run``. When any compile result has an error, an empty
    CatalogResults carrying those compile results is returned and no
    catalog is built. If ``self.args.compile`` is not set, the manifest is
    loaded with ``get_full_manifest``.

    Side effects: copies ``DOCS_INDEX_FILE_PATH`` to ``index.html`` in the
    target path, writes the catalog to ``CATALOG_FILENAME`` there, and (with
    ``self.args.compile``) writes the manifest.

    :returns: The CatalogResults that were written.
    :raises InternalException: If no manifest is available after the
        compile/load step.
    """
    compile_results = None
    if self.args.compile:
        compile_results = CompileTask.run(self)
        if any(r.error is not None for r in compile_results):
            dbt.ui.printer.print_timestamped_line(
                'compile failed, cannot generate docs')
            return CatalogResults(nodes={}, sources={},
                                  generated_at=datetime.utcnow(),
                                  errors=None,
                                  _compile_results=compile_results)
    else:
        self.manifest = get_full_manifest(self.config)

    shutil.copyfile(DOCS_INDEX_FILE_PATH,
                    os.path.join(self.config.target_path, 'index.html'))

    # NOTE(review): on the compile path this presumably relies on
    # CompileTask.run having set self.manifest — confirm.
    if self.manifest is None:
        raise InternalException('self.manifest was None in run!')

    adapter = get_adapter(self.config)
    with adapter.connection_named('generate_catalog'):
        dbt.ui.printer.print_timestamped_line("Building catalog")
        # get_catalog returns the table plus a list of exceptions; those
        # are collected into `errors` and logged below rather than raised.
        catalog_table, exceptions = adapter.get_catalog(self.manifest)

    # One dict per catalog row, keyed by column name, values passed
    # through _coerce_decimal.
    catalog_data: List[PrimitiveDict] = [
        dict(zip(catalog_table.column_names, map(_coerce_decimal, row)))
        for row in catalog_table
    ]

    catalog = Catalog(catalog_data)

    errors: Optional[List[str]] = None
    if exceptions:
        errors = [str(e) for e in exceptions]

    nodes, sources = catalog.make_unique_id_map(self.manifest)
    results = self.get_catalog_results(
        nodes=nodes,
        sources=sources,
        generated_at=datetime.utcnow(),
        compile_results=compile_results,
        errors=errors,
    )

    path = os.path.join(self.config.target_path, CATALOG_FILENAME)
    results.write(path)
    if self.args.compile:
        write_manifest(self.config, self.manifest)

    if exceptions:
        # `(n != 1) * 's'` pluralizes "failure" when there is not exactly one.
        logger.error(
            'dbt encountered {} failure{} while writing the catalog'.
            format(len(exceptions), (len(exceptions) != 1) * 's'))

    dbt.ui.printer.print_timestamped_line('Catalog written to {}'.format(
        os.path.abspath(path)))

    return results
def _get_exec_node(self):
    """Parse the request's SQL (and any inline macros) into an RPC node.

    Inline macros are parsed and merged into ``self.manifest.macros``. The new
    node's refs are added to the manifest, and the manifest is compiled with
    the adapter's compiler (``write=False``) into ``self.graph``.

    :returns: The parsed RPC node.
    :raises InternalException: If ``self.manifest`` is not set.
    """
    if self.manifest is None:
        raise InternalException('manifest not set in _get_exec_node')

    results = ParseResult.rpc()
    macro_overrides = {}
    # Both the SQL and the macros come from _extract_request_data; the old
    # dead read of ``self.args.macros`` (immediately overwritten) is gone.
    sql, macros = self._extract_request_data(self.args.sql)

    if macros:
        macro_parser = RPCMacroParser(results, self.config)
        for node in macro_parser.parse_remote(macros):
            macro_overrides[node.unique_id] = node

    self.manifest.macros.update(macro_overrides)
    rpc_parser = RPCCallParser(
        results=results,
        project=self.config,
        root_project=self.config,
        macro_manifest=self.manifest,
    )
    rpc_node = rpc_parser.parse_remote(sql, self.args.name)
    add_new_refs(
        manifest=self.manifest,
        config=self.config,
        node=rpc_node,
        macros=macro_overrides
    )

    # don't write our new, weird manifest!
    adapter = get_adapter(self.config)
    compiler = adapter.get_compiler()
    self.graph = compiler.compile(self.manifest, write=False)
    # previously, this compiled the ancestors, but they are compiled at
    # runtime now.
    return rpc_node
def execute(self, compiled_node, manifest):
    """Compute a source's freshness and wrap it in a SourceFreshnessResult.

    This runner is only reached for sources with freshness configured, so
    ``compiled_node.loaded_at_field`` must be set.

    :raises InternalException: If ``compiled_node.loaded_at_field`` is None.
    """
    loaded_at_field = compiled_node.loaded_at_field
    if loaded_at_field is None:
        raise InternalException(
            'Got to execute for source freshness of a source that has no '
            'loaded_at_field!')

    relation = self.adapter.Relation.create_from_source(compiled_node)
    # Ask the adapter to calculate the source's freshness.
    with self.adapter.connection_for(compiled_node):
        self.adapter.clear_transaction()
        freshness = self.adapter.calculate_freshness(
            relation,
            loaded_at_field,
            compiled_node.freshness.filter,
            manifest=manifest,
        )

    status = compiled_node.freshness.status(freshness['age'])

    return SourceFreshnessResult(
        node=compiled_node,
        status=status,
        thread_id=threading.current_thread().name,
        timing=[],
        execution_time=0,
        message=None,
        adapter_response={},
        **freshness,
    )
def on_skip(self):
    """Report that this node was skipped and return a skipped result.

    Ephemeral models print nothing. Other nodes print either a plain skip
    line or, when the skip was caused by an upstream ephemeral model's
    failure, an error-skip line. In the latter case the result carries an
    error message so dbt exits with an error code.

    :raises InternalException: If the skip is attributed to an ephemeral
        failure but ``self.skip_cause`` is not set.
    """
    schema_name = self.node.schema
    node_name = self.node.name
    error = None

    if not self.node.is_ephemeral_model:
        if not self._skip_caused_by_ephemeral_failure():
            print_skip_line(self.node, schema_name, node_name,
                            self.node_index, self.num_nodes)
        else:
            print_skip_caused_by_error(self.node, schema_name, node_name,
                                       self.node_index, self.num_nodes,
                                       self.skip_cause)
            if self.skip_cause is None:  # mypy appeasement
                raise InternalException(
                    'Skip cause not set but skip was somehow caused by '
                    'an ephemeral failure')
            # Attach an error so the run exits with an error code.
            error = (
                'Compilation Error in {}, caused by compilation error '
                'in referenced ephemeral model {}'.format(
                    self.node.unique_id, self.skip_cause.node.unique_id))

    return RunModelResult(self.node, skip=True, error=error)
def _handle_result(self, result):
    """Mark the result as completed, insert the `CompileResultNode` into the
    manifest, and mark any descendants (potentially with a 'cause' if the
    result was an ephemeral model) as skipped.

    Non-ephemeral results are also appended to ``self.node_results``.

    :raises InternalException: If ``self.manifest`` is None.
    """
    node = result.node
    is_ephemeral = node.is_ephemeral_model
    if not is_ephemeral:
        self.node_results.append(result)

    if self.manifest is None:
        raise InternalException('manifest was None in _handle_result')

    # Sources and other nodes are stored in different manifest sections.
    if isinstance(node, ParsedSourceDefinition):
        self.manifest.update_source(node)
    else:
        self.manifest.update_node(node)

    if result.status == NodeStatus.Error:
        # Only an ephemeral failure is passed on as the skip cause.
        cause = result if is_ephemeral else None
        self._mark_dependent_errors(node.unique_id, result, cause)
def run(self):
    """Run dbt for the query, based on the graph.

    Returns an empty result (after a warning) when nothing was selected.
    Otherwise executes the selected nodes with hooks, optionally writes the
    manifest and results (``flags.WRITE_JSON``), and prints end-of-task
    messages.

    :raises InternalException: If ``_runtime_initialize`` did not populate
        ``self._flattened_nodes``.
    """
    self._runtime_initialize()

    if self._flattened_nodes is None:
        raise InternalException(
            'after _runtime_initialize, _flattened_nodes was still None'
        )

    if not self._flattened_nodes:
        logger.warning("WARNING: Nothing to do. Try checking your model "
                       "configs and model specification args")
        return self.get_result(
            results=[],
            generated_at=datetime.utcnow(),
            elapsed_time=0.0,
        )

    with TextOnly():
        logger.info("")

    selected_uids = frozenset(
        node.unique_id for node in self._flattened_nodes
    )
    run_result = self.execute_with_hooks(selected_uids)

    if flags.WRITE_JSON:
        self.write_manifest()
        self.write_result(run_result)

    self.task_end_messages(run_result.results)
    return run_result
def _runtime_initialize():
    """Load the saved graph and manifest into ``task`` and select its nodes.

    Reads the pickled graph from ``graph_path`` and the JSON manifest from
    ``manifest_path``, then mirrors the rest of
    ``GraphRunnableTask._runtime_initialize``: sets ``task.job_queue``,
    ``task._flattened_nodes`` and ``task.num_nodes`` (the count of selected
    nodes that are not ephemeral models).

    :raises InternalException: If selection returns an id that is neither a
        node nor a source in the manifest.
    """
    # NOTE(review): pickle.load can execute arbitrary code; graph_path must
    # point at a trusted, locally produced artifact.
    with open(graph_path, "rb") as f:
        task.graph = Graph(graph=pickle.load(f))

    with open(manifest_path) as f:
        loaded_manifest = json.load(f)
    # Manifest.from_dict rebuilds the manifest from its JSON form; it was
    # previously also called once more with its result discarded.
    task.manifest = Manifest.from_dict(loaded_manifest)

    # What follows is the remaining _runtime_initialize method of
    # GraphRunnableTask.
    task.job_queue = task.get_graph_queue()
    task._flattened_nodes = []
    for uid in task.job_queue.get_selected_nodes():
        if uid in task.manifest.nodes:
            task._flattened_nodes.append(task.manifest.nodes[uid])
        elif uid in task.manifest.sources:
            task._flattened_nodes.append(task.manifest.sources[uid])
        else:
            raise InternalException(
                f"Node selection returned {uid}, expected a node or a "
                f"source"
            )
    task.num_nodes = len(
        [n for n in task._flattened_nodes if not n.is_ephemeral_model]
    )
def parse_source(
    self, target: UnpatchedSourceDefinition
) -> ParsedSourceDefinition:
    """Build a ParsedSourceDefinition from an unpatched source/table pair.

    Table values take precedence where the code falls back between the
    two: ``loaded_at_field`` uses the table's value, else the source's.
    Freshness and quoting are combined via ``merge_freshness`` and
    ``source.quoting.merged``. Tags are the sorted, de-duplicated union of
    source and table tags. Missing database/schema/identifier fall back to
    the root project's credentials database, the source name, and the table
    name respectively.

    :raises InternalException: If the calculated config is not a
        SourceConfig.
    """
    source = target.source
    table = target.table
    refs = ParserRef.from_target(table)
    unique_id = target.unique_id

    # Optional text and mapping fields default to empty values.
    description = table.description or ''
    meta = table.meta or {}
    source_description = source.description or ''

    loaded_at_field = table.loaded_at_field or source.loaded_at_field
    freshness = merge_freshness(source.freshness, table.freshness)
    quoting = source.quoting.merged(table.quoting)
    source_meta = source.meta or {}

    # Union of source and table tags without duplicates, sorted.
    tags = sorted({*source.tags, *table.tags})

    config = self.config_generator.calculate_node_config(
        config_calls=[],
        fqn=target.fqn,
        resource_type=NodeType.Source,
        project_name=self.project.project_name,
        base=False,
    )
    if not isinstance(config, SourceConfig):
        raise InternalException(
            f'Calculated a {type(config)} for a source, but expected '
            f'a SourceConfig')

    default_database = self.root_project.credentials.database

    return ParsedSourceDefinition(
        package_name=target.package_name,
        database=(source.database or default_database),
        schema=(source.schema or source.name),
        identifier=(table.identifier or table.name),
        root_path=target.root_path,
        path=target.path,
        original_file_path=target.original_file_path,
        columns=refs.column_info,
        unique_id=unique_id,
        name=table.name,
        description=description,
        external=table.external,
        source_name=source.name,
        source_description=source_description,
        source_meta=source_meta,
        meta=meta,
        loader=source.loader,
        loaded_at_field=loaded_at_field,
        freshness=freshness,
        quoting=quoting,
        resource_type=NodeType.Source,
        fqn=target.fqn,
        tags=tags,
        config=config,
    )
def get_compiled_path(cls, block: FileBlock):
    """Return the compiled path ``rpc/<block.name>`` for an RPC block.

    :raises InternalException: If ``block`` is a regular file block rather
        than an RPCBlock.
    """
    # The isinstance narrowing is what lets mypy accept ``block.name``.
    if isinstance(block, RPCBlock):
        return os.path.join('rpc', block.name)
    raise InternalException(
        'While parsing RPC calls, got an actual file block instead of '
        'an RPC block: {}'.format(block))
def handle_request(self) -> Result:
    """Delegate to ``self.real_task.handle_request()``.

    The arguments were already parsed from the CLI, so nothing else needs
    preparing before delegating.

    :raises InternalException: If ``self.real_task`` is not set.
    """
    delegate = self.real_task
    if delegate is None:
        raise InternalException(
            'CLI task is in a bad state: handle_request called with no '
            'real_task set!'
        )
    return delegate.handle_request()
def get_method(
    self, method: MethodName, method_arguments: List[str]
) -> SelectorMethod:
    """Instantiate the selector method registered for ``method``.

    The class is constructed with ``self.manifest``, ``self.previous_state``
    and ``method_arguments``.

    :raises InternalException: If ``method`` has no entry in
        ``self.SELECTOR_METHODS``.
    """
    registry = self.SELECTOR_METHODS
    if method not in registry:
        raise InternalException(
            f'Method name "{method}" is a valid node selection '
            f'method name, but it is not handled')
    method_cls: Type[SelectorMethod] = registry[method]
    return method_cls(self.manifest, self.previous_state, method_arguments)
def get_node_selector(self):
    """Build a FreshnessSelector over this task's graph and manifest.

    :raises InternalException: If the manifest or graph is not set.
    """
    missing_state = self.manifest is None or self.graph is None
    if missing_state:
        raise InternalException(
            'manifest and graph must be set to get perform node selection')
    return FreshnessSelector(
        graph=self.graph,
        manifest=self.manifest,
        previous_state=self.previous_state,
    )
def create_from(
    cls: Type[Self],
    config: HasQuoting,
    node: Union[CompiledNode, ParsedNode, ParsedSourceDefinition],
    **kwargs: Any,
) -> Self:
    """Create a relation for ``node``, dispatching on its resource type.

    Source nodes go through ``cls.create_from_source``. Every other node goes
    through ``cls.create_from_node`` together with ``config``. Extra keyword
    arguments are forwarded unchanged.

    :raises InternalException: If the node's Python type does not match its
        ``resource_type``.
    """
    is_source = node.resource_type == NodeType.Source

    if not is_source:
        if not isinstance(node, (ParsedNode, CompiledNode)):
            raise InternalException(
                'type mismatch, expected ParsedNode or CompiledNode but '
                'got {}'.format(type(node)))
        return cls.create_from_node(config, node, **kwargs)

    if not isinstance(node, ParsedSourceDefinition):
        raise InternalException(
            'type mismatch, expected ParsedSourceDefinition but got {}'
            .format(type(node)))
    return cls.create_from_source(node, **kwargs)
def get_node_project(self, project_name: str):
    """Return the project named ``project_name``.

    The active project is returned directly when its name matches. Otherwise
    the name is looked up in ``self._active_project.load_dependencies()``.

    :raises InternalException: If no dependency has that name.
    """
    active = self._active_project
    if project_name == active.project_name:
        return active

    dependencies = active.load_dependencies()
    if project_name in dependencies:
        return dependencies[project_name]
    raise InternalException(
        f'Project name {project_name} not found in dependencies '
        f'(found {list(dependencies)})')
def get_parameters(cls) -> Type[Parameters]:
    """Return the parameters class annotated on ``set_args``'s ``params``.

    The annotation must be a proper subclass of RPCParameters (not
    RPCParameters itself).

    :raises InternalException: If ``params`` has no annotation, or the
        annotation is not a class, not an RPCParameters subclass, or is
        RPCParameters itself.
    """
    argspec = inspect.getfullargspec(cls.set_args)
    annotations = argspec.annotations
    if 'params' not in annotations:
        raise InternalException(
            'set_args must have parameter named params with a valid '
            'RPCParameters type definition (no params annotation found)')
    params_type = annotations['params']
    # issubclass() raises TypeError when its first argument is not a class
    # (e.g. a string annotation or a typing construct like Optional[...]);
    # report that like any other invalid annotation.
    try:
        is_params_subclass = issubclass(params_type, RPCParameters)
    except TypeError:
        is_params_subclass = False
    if not is_params_subclass:
        raise InternalException(
            'set_args must have parameter named params with a valid '
            'RPCParameters type definition (got {}, expected '
            'RPCParameters subclass)'.format(params_type))
    if params_type is RPCParameters:
        raise InternalException(
            'set_args must have parameter named params with a valid '
            'RPCParameters type definition (got RPCParameters itself!)')
    return params_type
def descendants(self, node, max_depth: Optional[int] = None) -> Set[str]:
    """Return every node reachable from ``node`` in the graph, excluding it.

    :param max_depth: If given, only nodes at most this many edges away are
        included (passed to networkx as ``cutoff``).
    :raises InternalException: If ``node`` is not in the graph.
    """
    if not self.graph.has_node(node):
        raise InternalException(f'Node {node} not found in the graph!')
    distances = nx.single_source_shortest_path_length(
        G=self.graph, source=node, cutoff=max_depth
    )
    return set(distances) - {node}