def valid_snapshot_target(self, relation: BaseRelation) -> None:
    """Ensure that the target relation is a valid snapshot table by
    checking that it has the expected ``dbt_*`` columns.

    :param relation: The relation to check.
    :raises CompilationException: If the columns are incorrect, or the
        target looks like an unmigrated legacy archive table.
    """
    if not isinstance(relation, self.Relation):
        invalid_type_error(
            method_name='valid_snapshot_target',
            arg_name='relation',
            got_value=relation,
            expected_type=self.Relation,
        )

    present = {
        col.name.lower() for col in self.get_columns_in_relation(relation)
    }

    legacy_found = []
    required_missing = []
    for suffix in ('scd_id', 'valid_from', 'valid_to'):
        required = 'dbt_' + suffix
        if required not in present:
            required_missing.append(required)
        if suffix in present:
            # an archive-style column without the dbt_ prefix
            legacy_found.append(suffix)

    if not required_missing:
        return

    if legacy_found:
        msg = ('Snapshot target has ("{}") but not ("{}") - is it an '
               'unmigrated previous version archive?'.format(
                   '", "'.join(legacy_found),
                   '", "'.join(required_missing)))
    else:
        msg = ('Snapshot target is not a snapshot table (missing "{}")'
               .format('", "'.join(required_missing)))
    raise_compiler_error(msg)
def __call__(self, *args) -> str:
    """Record a parse-time source() call on the model.

    Requires exactly two positional arguments (the source name and the
    table name); renders as the empty string.
    """
    arg_count = len(args)
    if arg_count != 2:
        raise_compiler_error(
            f"source() takes exactly two arguments ({arg_count} given)",
            self.model)
    self.model.sources.append(list(args))
    return ''
def __init__(
    self,
    test: Dict[str, Any],
    target: Target,
    package_name: str,
    render_ctx: Dict[str, Any],
    column_name: str = None,
) -> None:
    """Parse a schema test definition into its name, args and modifiers.

    :param test: The test definition dict from the schema file.
    :param target: The node the test is attached to.
    :param package_name: The package the test is defined in (may be
        overridden by a namespaced test name).
    :param render_ctx: Context used to render string modifier values.
    :param column_name: Column the test applies to, if any.
    :raises CompilationException: If the test name does not match the
        expected pattern.
    """
    test_name, test_args = self.extract_test_args(test, column_name)
    self.args: Dict[str, Any] = test_args
    self.package_name: str = package_name
    self.target: Target = target

    match = self.TEST_NAME_PATTERN.match(test_name)
    if match is None:
        raise_compiler_error(
            'Test name string did not match expected pattern: {}'.format(
                test_name))

    groups = match.groupdict()
    self.name: str = groups['test_name']
    self.namespace: str = groups['test_namespace']

    # pull modifier keyword arguments out of the test args, rendering
    # any string values through jinja first
    self.modifiers: Dict[str, Any] = {}
    for mod_key, mod_default in self.MODIFIER_ARGS.items():
        mod_value = self.args.pop(mod_key, mod_default)
        if isinstance(mod_value, str):
            mod_value = get_rendered(mod_value, render_ctx)
        self.modifiers[mod_key] = mod_value

    # a namespaced test name ('pkg.test') overrides the package
    if self.namespace is not None:
        self.package_name = self.namespace

    compiled_name, fqn_name = self.get_test_name()
    self.compiled_name: str = compiled_name
    self.fqn_name: str = fqn_name
def get_missing_var(self, var_name):
    """Raise a compiler error reporting that ``var_name`` is undefined,
    listing the vars that are available on this node."""
    # pretty-print the merged vars so the user can see what IS defined
    pretty_vars = json.dumps(dict(self.merged), sort_keys=True, indent=4)
    msg = self.UndefinedVarError.format(
        var_name, self.node_name, pretty_vars
    )
    raise_compiler_error(msg, self.node)
def create_ephemeral_relation(self, target_model: NonSourceNode, name: str) -> RelationProxy:
    """Always raises: operations cannot ref() ephemeral nodes.

    ParsedMacros do not support set_cte, so there is no way to inline an
    ephemeral model into an operation.
    """
    message = (
        'Operations can not ref() ephemeral nodes, but {} is ephemeral'
        .format(target_model.name)
    )
    raise_compiler_error(message, self.model)
def __call__(self, *args: str) -> RelationProxy:
    """Resolve a source() call into a relation.

    Requires exactly two arguments: the source name and the table name.
    """
    if len(args) != 2:
        raise_compiler_error(
            f"source() takes exactly two arguments ({len(args)} given)",
            self.model)
    source_name, table_name = args
    self.validate_args(source_name, table_name)
    return self.resolve(source_name, table_name)
def write(self, payload: str) -> str:
    """Write ``payload`` to this model's target path and record the
    build path on the model; renders as the empty string.

    :raises CompilationException: for macros and source definitions,
        which are not writeable.
    """
    # macros/source defs aren't 'writeable'.
    if isinstance(self.model, (ParsedMacro, ParsedSourceDefinition)):
        raise_compiler_error('cannot "write" macros or sources')
    build_path = self.model.write_node(
        self.config.target_path, 'run', payload)
    self.model.build_path = build_path
    return ''
def exception_handler(self) -> Iterator[None]:
    """Context manager body that converts type/jinja runtime errors
    raised while this macro runs into compiler errors, and tags any
    CompilationException with this macro's stack frame."""
    try:
        yield
    except (TypeError, jinja2.exceptions.TemplateRuntimeError) as exc:
        raise_compiler_error(str(exc), self.macro)
    except CompilationException as exc:
        # annotate so the error shows which macro it passed through
        exc.stack.append(self.macro)
        raise exc
def persist_column_docs(self) -> bool:
    """Return whether column-level docs persistence is enabled via the
    'persist_docs' config entry (which must be a dict if provided)."""
    docs_config = self.get('persist_docs', default={})
    if not isinstance(docs_config, dict):
        raise_compiler_error(
            f"Invalid value provided for 'persist_docs'. Expected dict "
            f"but received {type(docs_config)}")
    return docs_config.get('columns', False)
def try_or_compiler_error(
    self, message_if_exception: str, func: Callable, *args, **kwargs
) -> Any:
    """Invoke ``func(*args, **kwargs)``, surfacing any exception as a
    compiler error that carries ``message_if_exception``."""
    try:
        return func(*args, **kwargs)
    except Exception:
        raise_compiler_error(message_if_exception, self.model)
def cache_added(self, relation: Optional[BaseRelation]) -> str:
    """Cache a new relation in dbt. It will show up in `list relations`."""
    if relation is None:
        raise_compiler_error(
            'Attempted to cache a null relation for {}'.format(
                self.nice_connection_name()))
    if dbt.flags.USE_CACHE:
        self.cache.add(relation)
    # so jinja doesn't render things
    return ''
def create_relation(self, target_model: ManifestNode, name: str) -> RelationProxy:
    """Create a relation for ``target_model``, rejecting ephemeral
    models: operations can't ref() ephemeral nodes because ParsedMacros
    do not support set_cte."""
    if not target_model.is_ephemeral_model:
        return super().create_relation(target_model, name)
    raise_compiler_error(
        'Operations can not ref() ephemeral nodes, but {} is ephemeral'
        .format(target_model.name), self.model)
def _transform_config(self, config):
    """Rewrite legacy underscore hook keys ('pre_hook'/'post_hook') to
    their hyphenated forms, erroring if both spellings are present."""
    for legacy_key in ('pre_hook', 'post_hook'):
        if legacy_key not in config:
            continue
        canonical_key = legacy_key.replace('_', '-')
        if canonical_key in config:
            raise_compiler_error(
                'Invalid config, has conflicting keys "{}" and "{}"'.
                format(legacy_key, canonical_key), self.model)
        config[canonical_key] = config.pop(legacy_key)
    return config
def get_from_package(self, package_name: Optional[str], name: str) -> Optional[MacroGenerator]:
    """Look up a macro by name within an optional package namespace.

    :raises CompilationException: If the named package is unknown.
    """
    if package_name is None:
        # unqualified: search this namespace directly
        return self.get(name)
    if package_name == GLOBAL_PROJECT_NAME:
        return self.global_project_namespace.get(name)
    if package_name in self.packages:
        return self.packages[package_name].get(name)
    raise_compiler_error(f"Could not find package '{package_name}'")
def packages_for_node(self) -> Iterable[Project]:
    """Yield the project(s) whose configs apply to this node: the node's
    own package (when it differs from the root project) followed by the
    root config itself."""
    dependencies = self.config.load_dependencies()
    package_name = self.node.package_name
    if package_name != self.config.project_name:
        if package_name not in dependencies:
            # I don't think this is actually reachable
            raise_compiler_error(
                f'Node package named {package_name} not found!', self.node)
        yield dependencies[package_name]
    yield self.config
def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
    """Drop a relation in dbt. It will no longer show up in
    `list relations`, and any bound views will be dropped from the cache
    """
    if relation is None:
        raise_compiler_error(
            'Attempted to drop a null relation for {}'.format(
                self.nice_connection_name()))
    if dbt.flags.USE_CACHE:
        self.cache.drop(relation)
    return ''
def __eq__(self, other: object) -> bool:
    """Compare two materialization candidates.

    Two candidates with the same specificity and locality are ambiguous,
    which is a compiler error rather than a valid comparison result.
    """
    if not isinstance(other, MaterializationCandidate):
        return NotImplemented
    same_rank = (
        self.specificity == other.specificity
        and self.locality == other.locality
    )
    if same_rank:
        raise_compiler_error(
            'Found two materializations with the name {} (packages {} and '
            '{}). dbt cannot resolve this ambiguity'.format(
                self.macro.name, self.macro.package_name,
                other.macro.package_name))
    return same_rank
def tags(self) -> List[str]:
    """Return this test's tags as a list of strings, normalizing a bare
    string into a one-element list.

    :raises CompilationException: If tags is neither a str nor a list,
        or contains a non-string element.
    """
    raw = self.modifiers.get('tags', [])
    if isinstance(raw, str):
        raw = [raw]
    if not isinstance(raw, list):
        raise_compiler_error(
            f'got {raw} ({type(raw)}) for tags, expected a list of '
            f'strings')
    for tag in raw:
        if not isinstance(tag, str):
            raise_compiler_error(
                f'got {tag} ({type(tag)}) for tag, expected a str')
    # return a copy so callers cannot mutate the stored modifiers
    return list(raw)
def get_from_package(self, package_name: Optional[str], name: str) -> Optional[MacroGenerator]:
    """Resolve a macro by name (optionally package-qualified) and wrap
    it in a MacroGenerator bound to this context.

    :param package_name: None for an unqualified lookup, the global
        project name for internal macros, or a dependency package name.
    :param name: The macro name to resolve.
    :raises CompilationException: If the named package is unknown.
    """
    macro = None
    if package_name is None:
        macro = self.macro_resolver.macros_by_name.get(name)
    elif package_name == GLOBAL_PROJECT_NAME:
        macro = self.macro_resolver.internal_packages_namespace.get(name)
    # FIX: the membership test previously consulted `self.resolver.packages`
    # while the lookup below used `self.macro_resolver.packages`; the two
    # could disagree. Use the macro resolver consistently for both.
    elif package_name in self.macro_resolver.packages:
        macro = self.macro_resolver.packages[package_name].get(name)
    else:
        raise_compiler_error(f"Could not find package '{package_name}'")
    macro_func = MacroGenerator(macro, self.ctx, self.node, self.thread_ctx)
    return macro_func
def load_agate_table(self) -> agate.Table:
    """Load this seed node's CSV file into an agate table, applying any
    configured column type overrides.

    :raises CompilationException: If the model is not a seed, or the CSV
        cannot be parsed.
    """
    if not isinstance(self.model, (ParsedSeedNode, CompiledSeedNode)):
        raise_compiler_error(
            'can only load_agate_table for seeds (got a {})'.format(
                self.model.resource_type))
    path = os.path.join(
        self.model.root_path, self.model.original_file_path)
    try:
        table = agate_helper.from_csv(
            path, text_columns=self.model.config.column_types)
    except ValueError as exc:
        raise_compiler_error(str(exc))
    table.original_abspath = os.path.abspath(path)
    return table
def extract_test_args(test, name=None) -> Tuple[str, Dict[str, Any]]:
    """Split a schema test definition dict into (test_name, test_args).

    The input must be a dict with exactly one key mapping the test name
    to a dict of arguments. If ``name`` is given it is injected into the
    returned args as 'column_name'. The caller's args dict is never
    mutated.

    :raises CompilationException: If the definition is malformed.
    """
    if not isinstance(test, dict):
        raise_compiler_error(
            'test must be dict or str, got {} (value {})'.format(
                type(test), test))

    items = list(test.items())
    if len(items) != 1:
        raise_compiler_error(
            'test definition dictionary must have exactly one key, got'
            ' {} instead ({} keys)'.format(items, len(items)))
    test_name, test_args = items[0]

    if not isinstance(test_args, dict):
        raise_compiler_error(
            'test arguments must be dict, got {} (value {})'.format(
                type(test_args), test_args))
    if not isinstance(test_name, str):
        raise_compiler_error(
            'test name must be a str, got {} (value {})'.format(
                type(test_name), test_name))

    # deep-copy so the column_name injection never leaks into the input
    test_args = deepcopy(test_args)
    if name is not None:
        test_args['column_name'] = name
    return test_name, test_args
def load_dependencies(self) -> Mapping[str, 'RuntimeConfig']:
    """Load (and memoize) every package this project depends on, keyed
    by package name and including this project itself.

    :raises CompilationException: If two packages share a name.
    """
    if self.dependencies is not None:
        return self.dependencies

    all_projects = {self.project_name: self}
    project_paths = itertools.chain(
        map(Path, PACKAGES.values()),
        self._get_project_directories(),
    )
    for project_name, project in self.load_projects(project_paths):
        if project_name in all_projects:
            raise_compiler_error(
                f'dbt found more than one package with the name '
                f'"{project_name}" included in this project. Package '
                f'names must be unique in a project. Please rename '
                f'one of these packages.')
        all_projects[project_name] = project
    self.dependencies = all_projects
    return self.dependencies
def execute_schema_test(self, test: CompiledSchemaTestNode):
    """Run a compiled schema test and return its single result value.

    The test query must return exactly one row with exactly one column.
    FIX: the original only validated the row count even though its error
    message promised "1 row and 1 column" — a 1-row/N-column result
    slipped through; both dimensions are now enforced.

    :raises CompilationException: If the result shape is not 1x1.
    """
    _, table = self.adapter.execute(
        test.compiled_sql,
        auto_begin=True,
        fetch=True,
    )
    num_rows = len(table.rows)
    num_cols = len(table.columns)
    if num_rows != 1 or num_cols != 1:
        raise_compiler_error(
            f"Bad test {test.test_metadata.name}: "
            f"Returned {num_rows} rows and {num_cols} cols, but expected "
            f"1 row and 1 column")
    return table[0][0]
def parse_cli_vars(var_string: str) -> Dict[str, Any]:
    """Parse the --vars command line argument (YAML text) into a dict.

    :raises CompilationException: If the YAML parses to a non-dict.
    :raises ValidationException: Re-raised (after logging) when the YAML
        itself is invalid.
    """
    try:
        cli_vars = yaml_helper.load_yaml_text(var_string)
        # note: deliberately an exact `dict` check, not isinstance
        if type(cli_vars) is dict:
            return cli_vars
        raise_compiler_error(
            "The --vars argument must be a YAML dictionary, but was "
            "of type '{}'".format(type(cli_vars).__name__))
    except ValidationException:
        logger.error(
            "The YAML provided in the --vars argument is not valid.\n"
        )
        raise
def quote_seed_column(self, column: str, quote_config: Optional[bool]) -> str:
    """Return ``column``, quoted if the seed's quote_columns config asks
    for it. An unset config warns a deprecation and defaults to False.

    :raises CompilationException: If quote_config is not a bool or None.
    """
    if quote_config is None:
        deprecations.warn('column-quoting-unset')
        # this is the default for now
        quote_columns = False
    elif isinstance(quote_config, bool):
        quote_columns = quote_config
    else:
        raise_compiler_error(
            f'The seed configuration value of "quote_columns" has an '
            f'invalid type {type(quote_config)}')
    if quote_columns:
        return self.quote(column)
    return column
def __call__(self, *args, **kwargs):
    """Apply an inline model config() call.

    Accepts either a single dict positional argument or keyword
    arguments, never both; renders as the empty string.
    """
    if len(args) == 1 and not kwargs:
        opts = args[0]
    elif not args and kwargs:
        opts = kwargs
    else:
        raise_compiler_error("Invalid inline model config", self.model)
    opts = self._transform_config(opts)
    # it's ok to have a parse context with no context config, but you must
    # not call it!
    if self.context_config is None:
        raise RuntimeException(
            'At parse time, did not receive a context config')
    self.context_config.update_in_model_config(opts)
    return ''
def cache_renamed(
    self,
    from_relation: Optional[BaseRelation],
    to_relation: Optional[BaseRelation],
) -> str:
    """Rename a relation in dbt. It will show up with a new name in
    `list_relations`, but bound views will remain bound.
    """
    if from_relation is None or to_relation is None:
        raise_compiler_error('Attempted to rename {} to {} for {}'.format(
            _relation_name(from_relation),
            _relation_name(to_relation),
            self.nice_connection_name()))
    if dbt.flags.USE_CACHE:
        self.cache.rename(from_relation, to_relation)
    return ''
def calculate_freshness(
    self,
    source: BaseRelation,
    loaded_at_field: str,
    filter: Optional[str],
    manifest: Optional[Manifest] = None,
) -> Dict[str, Any]:
    """Calculate the freshness of sources in dbt, and return it"""
    macro_kwargs: Dict[str, Any] = {
        'source': source,
        'loaded_at_field': loaded_at_field,
        'filter': filter,
    }

    # run the freshness macro; it must return one (max_loaded_at,
    # snapshotted_at) row
    table = self.execute_macro(
        FRESHNESS_MACRO_NAME,
        kwargs=macro_kwargs,
        release=True,
        manifest=manifest,
    )
    if len(table) != 1 or len(table[0]) != 2:
        raise_compiler_error(
            'Got an invalid result from "{}" macro: {}'.format(
                FRESHNESS_MACRO_NAME, [tuple(r) for r in table]
            )
        )

    if table[0][0] is None:
        # no records in the table, so really the max_loaded_at was
        # infinitely long ago. Just call it 0:00 January 1 year UTC
        max_loaded_at = datetime(1, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)
    else:
        max_loaded_at = _utc(table[0][0], source, loaded_at_field)

    snapshotted_at = _utc(table[0][1], source, loaded_at_field)
    freshness_age = (snapshotted_at - max_loaded_at).total_seconds()
    return {
        'max_loaded_at': max_loaded_at,
        'snapshotted_at': snapshotted_at,
        'age': freshness_age,
    }
def __init__(
    self,
    test: Dict[str, Any],
    target: Testable,
    package_name: str,
    render_ctx: Dict[str, Any],
    column_name: str = None,
) -> None:
    """Parse a schema test definition for ``target`` into its name, args
    and modifiers. 'model' is a reserved argument that is injected
    automatically via build_model_str().

    :raises CompilationException: If the args contain 'model', or the
        test name does not match the expected pattern.
    """
    test_name, test_args = self.extract_test_args(test, column_name)
    self.args: Dict[str, Any] = test_args
    if 'model' in self.args:
        raise_compiler_error(
            'Test arguments include "model", which is a reserved argument',
        )
    self.package_name: str = package_name
    self.target: Testable = target

    self.args['model'] = self.build_model_str()

    match = self.TEST_NAME_PATTERN.match(test_name)
    if match is None:
        raise_compiler_error(
            'Test name string did not match expected pattern: {}'.format(
                test_name))

    groups = match.groupdict()
    self.name: str = groups['test_name']
    self.namespace: str = groups['test_namespace']

    # extract modifier args, rendering strings natively and keeping only
    # the values that were actually provided
    self.modifiers: Dict[str, Any] = {}
    for mod_key in self.MODIFIER_ARGS:
        mod_value = self.args.pop(mod_key, None)
        if isinstance(mod_value, str):
            mod_value = get_rendered(mod_value, render_ctx, native=True)
        if mod_value is not None:
            self.modifiers[mod_key] = mod_value

    # a namespaced test name ('pkg.test') overrides the package
    if self.namespace is not None:
        self.package_name = self.namespace

    compiled_name, fqn_name = self.get_test_name()
    self.compiled_name: str = compiled_name
    self.fqn_name: str = fqn_name
def flatten(self):
    """Return a flattened copy of this collection: one entry per
    (information_schema, schema) pair, with database quoting and
    inclusion disabled.

    :raises CompilationException: If entries span multiple databases.
    """
    flat = self.__class__()

    # make sure we don't have duplicates
    databases = {r.database.lower() for r in self if r.database}
    if len(databases) > 1:
        raise_compiler_error(str(databases))

    for information_schema_name, schema in self.search():
        flat.add(information_schema_name.incorporate(
            path={
                'database': information_schema_name.database,
                'schema': schema,
            },
            quote_policy={'database': False},
            include_policy={'database': False},
        ))
    return flat