예제 #1
0
    def valid_snapshot_target(self, relation: BaseRelation) -> None:
        """Check that ``relation`` carries the columns of a snapshot table.

        The relation's column names are compared case-insensitively against
        ``dbt_scd_id``, ``dbt_valid_from`` and ``dbt_valid_to``.

        :param Relation relation: The relation to check
        :raises CompilationException: If the columns are
            incorrect.
        """
        if not isinstance(relation, self.Relation):
            invalid_type_error(method_name='valid_snapshot_target',
                               arg_name='relation',
                               got_value=relation,
                               expected_type=self.Relation)

        existing = {
            col.name.lower()
            for col in self.get_columns_in_relation(relation)
        }
        # Every required column is a legacy name with a "dbt_" prefix.
        legacy_keys = ('scd_id', 'valid_from', 'valid_to')
        missing = [
            'dbt_' + key for key in legacy_keys
            if 'dbt_' + key not in existing
        ]
        if not missing:
            return

        # An unprefixed column is only reported when its prefixed
        # counterpart is absent.
        extra = [
            key for key in legacy_keys
            if 'dbt_' + key not in existing and key in existing
        ]
        if extra:
            msg = (
                'Snapshot target has ("{}") but not ("{}") - is it an '
                'unmigrated previous version archive?'
            ).format('", "'.join(extra), '", "'.join(missing))
        else:
            msg = (
                'Snapshot target is not a snapshot table (missing "{}")'
            ).format('", "'.join(missing))
        raise_compiler_error(msg)
예제 #2
0
 def __call__(self, *args) -> str:
     """Record a ``source()`` call on the model and return an empty string.

     Exactly two positional arguments are required; any other count is
     reported through ``raise_compiler_error`` against ``self.model``.
     The arguments are appended, as one list, to ``self.model.sources``.
     """
     arg_count = len(args)
     if arg_count != 2:
         raise_compiler_error(
             f"source() takes exactly two arguments ({arg_count} given)",
             self.model)
     source_ref = list(args)
     self.model.sources.append(source_ref)
     return ''
예제 #3
0
    def __init__(
        self,
        test: Dict[str, Any],
        target: Target,
        package_name: str,
        render_ctx: Dict[str, Any],
        column_name: str = None,
    ) -> None:
        """Split a test definition into name, namespace, args and modifiers.

        :param test: The test definition, handed to ``extract_test_args``.
        :param target: Stored unchanged on ``self.target``.
        :param package_name: Package used unless the test name carries its
            own namespace.
        :param render_ctx: Context passed to ``get_rendered`` for string
            modifier values.
        :param column_name: Forwarded to ``extract_test_args``.
        """
        parsed_name, parsed_args = self.extract_test_args(test, column_name)
        self.args: Dict[str, Any] = parsed_args
        self.package_name: str = package_name
        self.target: Target = target

        name_match = self.TEST_NAME_PATTERN.match(parsed_name)
        if name_match is None:
            raise_compiler_error(
                'Test name string did not match expected pattern: {}'.format(
                    parsed_name))

        name_parts = name_match.groupdict()
        self.name: str = name_parts['test_name']
        self.namespace: str = name_parts['test_namespace']

        # Modifier keys are removed from the test args; string values are
        # rendered, everything else (including defaults) is kept as-is.
        self.modifiers: Dict[str, Any] = {}
        for key, default in self.MODIFIER_ARGS.items():
            raw = self.args.pop(key, default)
            self.modifiers[key] = (
                get_rendered(raw, render_ctx) if isinstance(raw, str) else raw
            )

        # A namespace in the test name takes precedence over the package.
        if self.namespace is not None:
            self.package_name = self.namespace

        compiled, fqn = self.get_test_name()
        self.compiled_name: str = compiled
        self.fqn_name: str = fqn
예제 #4
0
 def get_missing_var(self, var_name):
     """Report ``var_name`` as undefined via ``raise_compiler_error``.

     The message comes from ``self.UndefinedVarError`` and is filled with
     the variable name, ``self.node_name`` and a key-sorted, indented JSON
     dump of every entry in ``self.merged``.
     """
     available = {key: self.merged[key] for key in self.merged}
     available_json = json.dumps(available, sort_keys=True, indent=4)
     template = self.UndefinedVarError
     message = template.format(var_name, self.node_name, available_json)
     raise_compiler_error(message, self.node)
예제 #5
0
 def create_ephemeral_relation(self, target_model: NonSourceNode,
                               name: str) -> RelationProxy:
     """Always report an error: operations cannot ref() ephemeral nodes.

     :param target_model: The ephemeral node; its ``name`` goes into the
         error message.
     :param name: Unused.
     """
     # In operations, we can't ref() ephemeral nodes, because ParsedMacros
     # do not support set_cte
     message = (
         f'Operations can not ref() ephemeral nodes, but '
         f'{target_model.name} is ephemeral'
     )
     raise_compiler_error(message, self.model)
예제 #6
0
 def __call__(self, *args: str) -> RelationProxy:
     """Validate and resolve a two-argument ``source()`` call.

     Any argument count other than two is reported through
     ``raise_compiler_error`` against ``self.model``. Otherwise both
     arguments go to ``self.validate_args`` and then ``self.resolve``,
     whose result is returned.
     """
     arg_count = len(args)
     if arg_count != 2:
         raise_compiler_error(
             f"source() takes exactly two arguments ({arg_count} given)",
             self.model)
     first, second = args[0], args[1]
     self.validate_args(first, second)
     return self.resolve(first, second)
예제 #7
0
 def write(self, payload: str) -> str:
     """Write ``payload`` for the current model and record the build path.

     :param payload: Content passed to the model's ``write_node``.
     :return: Always the empty string.
     """
     # macros/source defs aren't 'writeable'.
     non_writeable = (ParsedMacro, ParsedSourceDefinition)
     if isinstance(self.model, non_writeable):
         raise_compiler_error('cannot "write" macros or sources')
     written_path = self.model.write_node(
         self.config.target_path, 'run', payload)
     self.model.build_path = written_path
     return ''
예제 #8
0
 def exception_handler(self) -> Iterator[None]:
     """Yield once, translating errors thrown in at the yield point.

     ``TypeError`` and jinja2 ``TemplateRuntimeError`` are re-reported
     through ``raise_compiler_error`` with ``self.macro`` attached. A
     ``CompilationException`` has ``self.macro`` appended to its ``stack``
     and is raised again.
     """
     try:
         yield
     except (TypeError, jinja2.exceptions.TemplateRuntimeError) as exc:
         message = str(exc)
         raise_compiler_error(message, self.macro)
     except CompilationException as exc:
         # Record this macro on the existing error before re-raising it.
         exc.stack.append(self.macro)
         raise exc
예제 #9
0
    def persist_column_docs(self) -> bool:
        """Return the ``columns`` entry of the ``persist_docs`` setting.

        ``persist_docs`` is read with ``self.get`` and defaults to an empty
        dict; a missing ``columns`` key yields ``False``. A non-dict value
        is reported through ``raise_compiler_error``.
        """
        docs_config = self.get('persist_docs', default={})
        if not isinstance(docs_config, dict):
            received = type(docs_config)
            raise_compiler_error(
                "Invalid value provided for 'persist_docs'. Expected dict "
                f"but received {received}")

        return docs_config.get('columns', False)
예제 #10
0
 def try_or_compiler_error(
     self, message_if_exception: str, func: Callable, *args, **kwargs
 ) -> Any:
     """Call ``func(*args, **kwargs)`` and return its result.

     :param message_if_exception: Message handed to
         ``raise_compiler_error`` (with ``self.model``) when ``func``
         raises any ``Exception``.
     :param func: The callable to invoke.
     :return: Whatever ``func`` returns.
     """
     try:
         outcome = func(*args, **kwargs)
     except Exception:
         raise_compiler_error(message_if_exception, self.model)
     else:
         return outcome
예제 #11
0
 def cache_added(self, relation: Optional[BaseRelation]) -> str:
     """Cache a new relation in dbt. It will show up in `list relations`.

     :param relation: The relation to add; ``None`` is reported through
         ``raise_compiler_error``.
     :return: Always the empty string.
     """
     if relation is None:
         connection = self.nice_connection_name()
         raise_compiler_error(
             f'Attempted to cache a null relation for {connection}')
     cache_enabled = dbt.flags.USE_CACHE
     if cache_enabled:
         self.cache.add(relation)
     # so jinja doesn't render things
     return ''
예제 #12
0
 def create_relation(self, target_model: ManifestNode,
                     name: str) -> RelationProxy:
     """Create a relation for ``target_model`` unless it is ephemeral.

     Non-ephemeral models are delegated to the parent class. Ephemeral
     models are reported through ``raise_compiler_error`` against
     ``self.model``.
     """
     if not target_model.is_ephemeral_model:
         return super().create_relation(target_model, name)

     # In operations, we can't ref() ephemeral nodes, because
     # ParsedMacros do not support set_cte
     message = (
         'Operations can not ref() ephemeral nodes, but {} is ephemeral'
         .format(target_model.name)
     )
     raise_compiler_error(message, self.model)
예제 #13
0
 def _transform_config(self, config):
     """Rename ``pre_hook``/``post_hook`` keys to their hyphenated form.

     ``config`` is modified in place and returned. If both spellings of a
     key are present, the conflict is reported through
     ``raise_compiler_error`` against ``self.model``.
     """
     renames = {'pre_hook': 'pre-hook', 'post_hook': 'post-hook'}
     for legacy_key, hyphen_key in renames.items():
         if legacy_key not in config:
             continue
         if hyphen_key in config:
             raise_compiler_error(
                 'Invalid config, has conflicting keys "{}" and "{}"'
                 .format(legacy_key, hyphen_key), self.model)
         config[hyphen_key] = config.pop(legacy_key)
     return config
예제 #14
0
 def get_from_package(self, package_name: Optional[str],
                      name: str) -> Optional[MacroGenerator]:
     """Look up ``name`` in the namespace selected by ``package_name``.

     :param package_name: ``None`` searches this namespace itself,
         ``GLOBAL_PROJECT_NAME`` searches the global project namespace,
         and any other value must be a key of ``self.packages``.
     :param name: The name to look up.
     :return: The result of the selected namespace's ``get(name)``.
     """
     if package_name is None:
         return self.get(name)
     if package_name == GLOBAL_PROJECT_NAME:
         return self.global_project_namespace.get(name)
     if package_name in self.packages:
         return self.packages[package_name].get(name)
     raise_compiler_error(f"Could not find package '{package_name}'")
예제 #15
0
    def packages_for_node(self) -> Iterable[Project]:
        """Yield the project(s) relevant to ``self.node``.

        When the node's package differs from ``self.config.project_name``,
        that package's entry from ``load_dependencies()`` is yielded first.
        ``self.config`` is always yielded last.
        """
        deps = self.config.load_dependencies()
        node_package = self.node.package_name
        in_own_project = node_package == self.config.project_name

        if not in_own_project:
            if node_package not in deps:
                # I don't think this is actually reachable
                raise_compiler_error(
                    f'Node package named {node_package} not found!',
                    self.node)
            yield deps[node_package]
        yield self.config
예제 #16
0
 def cache_dropped(self, relation: Optional[BaseRelation]) -> str:
     """Drop a relation in dbt. It will no longer show up in
     `list relations`, and any bound views will be dropped from the cache

     :param relation: The relation to drop; ``None`` is reported through
         ``raise_compiler_error``.
     :return: Always the empty string.
     """
     if relation is None:
         connection = self.nice_connection_name()
         raise_compiler_error(
             f'Attempted to drop a null relation for {connection}')
     cache_enabled = dbt.flags.USE_CACHE
     if cache_enabled:
         self.cache.drop(relation)
     return ''
예제 #17
0
    def __eq__(self, other: object) -> bool:
        """Compare two candidates by ``specificity`` and ``locality``.

        A tie on both is never returned as ``True``: it is reported through
        ``raise_compiler_error`` as an ambiguity between the two macros'
        packages. Non-candidates yield ``NotImplemented``.
        """
        if not isinstance(other, MaterializationCandidate):
            return NotImplemented
        tied = (self.specificity == other.specificity
                and self.locality == other.locality)
        if tied:
            message = (
                'Found two materializations with the name {} (packages {} and '
                '{}). dbt cannot resolve this ambiguity'
            ).format(
                self.macro.name,
                self.macro.package_name,
                other.macro.package_name,
            )
            raise_compiler_error(message)

        return tied
예제 #18
0
 def tags(self) -> List[str]:
     """Return the ``tags`` modifier as a new list of strings.

     A single string is wrapped in a list and a missing entry gives an
     empty list. A non-list value or a non-string element is reported
     through ``raise_compiler_error``.
     """
     raw = self.modifiers.get('tags', [])
     values = [raw] if isinstance(raw, str) else raw
     if not isinstance(values, list):
         raise_compiler_error(
             f'got {values} ({type(values)}) for tags, expected a list of '
             f'strings')
     for item in values:
         if isinstance(item, str):
             continue
         raise_compiler_error(
             f'got {item} ({type(item)}) for tag, expected a str')
     # Return a copy so callers cannot mutate the stored modifier.
     return values[:]
예제 #19
0
 def get_from_package(self, package_name: Optional[str],
                      name: str) -> Optional[MacroGenerator]:
     """Look up macro ``name`` and wrap it in a ``MacroGenerator``.

     :param package_name: ``None`` searches ``macros_by_name``,
         ``GLOBAL_PROJECT_NAME`` searches the internal packages namespace,
         and any other value must be a key of
         ``self.macro_resolver.packages``; an unknown package is reported
         through ``raise_compiler_error``.
     :param name: The macro name to look up.
     :return: A ``MacroGenerator`` built from the macro and this object's
         ``ctx``, ``node`` and ``thread_ctx``, or ``None`` when the
         selected namespace has no macro called ``name``.
     """
     resolver = self.macro_resolver
     macro = None
     if package_name is None:
         macro = resolver.macros_by_name.get(name)
     elif package_name == GLOBAL_PROJECT_NAME:
         macro = resolver.internal_packages_namespace.get(name)
     elif package_name in resolver.packages:
         # The membership test and the lookup must use the same mapping;
         # the check previously read a different attribute (self.resolver).
         macro = resolver.packages[package_name].get(name)
     else:
         raise_compiler_error(f"Could not find package '{package_name}'")
     if macro is None:
         # Nothing found: return None rather than a generator wrapping None.
         return None
     return MacroGenerator(macro, self.ctx, self.node, self.thread_ctx)
예제 #20
0
 def load_agate_table(self) -> agate.Table:
     """Load the current seed model's CSV file as an agate table.

     The path is ``root_path`` joined with ``original_file_path``, and the
     model's configured ``column_types`` are passed as ``text_columns``.
     The absolute path is stored on the table as ``original_abspath``.
     Non-seed models and ``ValueError`` from the CSV loader are reported
     through ``raise_compiler_error``.
     """
     seed_types = (ParsedSeedNode, CompiledSeedNode)
     if not isinstance(self.model, seed_types):
         raise_compiler_error(
             'can only load_agate_table for seeds (got a {})'.format(
                 self.model.resource_type))
     csv_path = os.path.join(self.model.root_path,
                             self.model.original_file_path)
     text_columns = self.model.config.column_types
     try:
         table = agate_helper.from_csv(csv_path, text_columns=text_columns)
     except ValueError as exc:
         raise_compiler_error(str(exc))
     table.original_abspath = os.path.abspath(csv_path)
     return table
예제 #21
0
    def extract_test_args(test, name=None) -> Tuple[str, Dict[str, Any]]:
        """Split a single-key test definition into its name and arguments.

        :param test: A dict with exactly one entry mapping the test name
            (a str) to its arguments (a dict).
        :param name: When not ``None``, stored in the returned arguments
            under ``column_name``.
        :return: ``(test_name, arguments)``; the arguments are a deep copy,
            so ``test`` itself is not modified.
        """
        if not isinstance(test, dict):
            raise_compiler_error(
                'test must be dict or str, got {} (value {})'.format(
                    type(test), test))

        entries = list(test.items())
        entry_count = len(entries)
        if entry_count != 1:
            raise_compiler_error(
                'test definition dictionary must have exactly one key, got'
                ' {} instead ({} keys)'.format(entries, entry_count))
        test_name, raw_args = entries[0]

        if not isinstance(raw_args, dict):
            raise_compiler_error(
                'test arguments must be dict, got {} (value {})'.format(
                    type(raw_args), raw_args))
        if not isinstance(test_name, str):
            raise_compiler_error(
                'test name must be a str, got {} (value {})'.format(
                    type(test_name), test_name))

        arguments = deepcopy(raw_args)
        if name is not None:
            arguments['column_name'] = name
        return test_name, arguments
예제 #22
0
 def load_dependencies(self) -> Mapping[str, 'RuntimeConfig']:
     """Return the mapping of project name to project, loading it once.

     The result is cached on ``self.dependencies``. On first use it is
     seeded with this project and filled from ``load_projects`` over the
     ``PACKAGES`` paths followed by ``_get_project_directories()``. A
     repeated project name is reported through ``raise_compiler_error``.
     """
     if self.dependencies is not None:
         return self.dependencies

     projects = {self.project_name: self}
     search_paths = itertools.chain(
         map(Path, PACKAGES.values()),
         self._get_project_directories(),
     )
     for name, project in self.load_projects(search_paths):
         if name in projects:
             raise_compiler_error(
                 f'dbt found more than one package with the name '
                 f'"{name}" included in this project. Package '
                 f'names must be unique in a project. Please rename '
                 f'one of these packages.')
         projects[name] = project
     self.dependencies = projects
     return self.dependencies
예제 #23
0
    def execute_schema_test(self, test: CompiledSchemaTestNode):
        """Run the test's compiled SQL and return its single result value.

        The query is executed through ``self.adapter.execute`` with
        ``auto_begin`` and ``fetch`` enabled. A result with any row count
        other than one is reported through ``raise_compiler_error``;
        otherwise the first cell of the first row is returned.
        """
        _, table = self.adapter.execute(
            test.compiled_sql,
            auto_begin=True,
            fetch=True,
        )

        row_count = len(table.rows)
        if row_count != 1:
            col_count = len(table.columns)
            raise_compiler_error(
                f"Bad test {test.test_metadata.name}: "
                f"Returned {row_count} rows and {col_count} cols, but expected "
                f"1 row and 1 column")
        return table[0][0]
예제 #24
0
def parse_cli_vars(var_string: str) -> Dict[str, Any]:
    """Parse the ``--vars`` argument as YAML and return the resulting dict.

    :param var_string: YAML text handed to ``yaml_helper.load_yaml_text``.
    :return: The parsed value when its type is exactly ``dict``; any other
        type is reported through ``raise_compiler_error``.
    :raises ValidationException: Re-raised after an error is logged.
    """
    try:
        parsed = yaml_helper.load_yaml_text(var_string)
        parsed_type = type(parsed)
        if parsed_type is dict:
            return parsed
        raise_compiler_error(
            "The --vars argument must be a YAML dictionary, but was "
            "of type '{}'".format(parsed_type.__name__))
    except ValidationException:
        logger.error(
            "The YAML provided in the --vars argument is not valid.\n"
        )
        raise
예제 #25
0
    def quote_seed_column(self, column: str,
                          quote_config: Optional[bool]) -> str:
        """Return ``column``, quoted via ``self.quote`` when configured.

        :param column: The seed column name.
        :param quote_config: ``True`` quotes and ``False`` does not.
            ``None`` issues the ``column-quoting-unset`` deprecation warning
            and does not quote. Any other type is reported through
            ``raise_compiler_error``.
        """
        if isinstance(quote_config, bool):
            should_quote = quote_config
        else:
            # this is the default for now
            should_quote = False
            if quote_config is None:
                deprecations.warn('column-quoting-unset')
            else:
                raise_compiler_error(
                    f'The seed configuration value of "quote_columns" has an '
                    f'invalid type {type(quote_config)}')

        return self.quote(column) if should_quote else column
예제 #26
0
    def __call__(self, *args, **kwargs):
        """Apply an inline model config and return an empty string.

        Accepts either one positional argument or one or more keyword
        arguments; any other combination is reported through
        ``raise_compiler_error``. The options pass through
        ``_transform_config`` and then into
        ``self.context_config.update_in_model_config``.

        :raises RuntimeException: If ``self.context_config`` is ``None``.
        """
        positional, named = len(args), len(kwargs)
        if positional == 1 and named == 0:
            opts = args[0]
        elif positional == 0 and named > 0:
            opts = kwargs
        else:
            raise_compiler_error("Invalid inline model config", self.model)

        opts = self._transform_config(opts)

        # it's ok to have a parse context with no context config, but you must
        # not call it!
        context_config = self.context_config
        if context_config is None:
            raise RuntimeException(
                'At parse time, did not receive a context config')
        context_config.update_in_model_config(opts)
        return ''
예제 #27
0
    def cache_renamed(
        self,
        from_relation: Optional[BaseRelation],
        to_relation: Optional[BaseRelation],
    ) -> str:
        """Rename a relation in dbt. It will show up with a new name in
        `list_relations`, but bound views will remain bound.

        :param from_relation: The relation being renamed.
        :param to_relation: The relation it is renamed to.
        :return: Always the empty string. If either relation is ``None``
            the call is reported through ``raise_compiler_error``.
        """
        if any(rel is None for rel in (from_relation, to_relation)):
            connection = self.nice_connection_name()
            source_label = _relation_name(from_relation)
            target_label = _relation_name(to_relation)
            raise_compiler_error(
                f'Attempted to rename {source_label} to {target_label} '
                f'for {connection}')

        cache_enabled = dbt.flags.USE_CACHE
        if cache_enabled:
            self.cache.rename(from_relation, to_relation)
        return ''
예제 #28
0
    def calculate_freshness(
        self,
        source: BaseRelation,
        loaded_at_field: str,
        filter: Optional[str],
        manifest: Optional[Manifest] = None
    ) -> Dict[str, Any]:
        """Calculate the freshness of sources in dbt, and return it

        :param source: Passed to the freshness macro as ``source``.
        :param loaded_at_field: Passed to the macro as ``loaded_at_field``.
        :param filter: Passed to the macro as ``filter``.
        :param manifest: Forwarded to ``execute_macro``.
        :return: A dict with ``max_loaded_at``, ``snapshotted_at`` and
            ``age`` (their difference in seconds).
        """
        macro_kwargs: Dict[str, Any] = dict(
            source=source,
            loaded_at_field=loaded_at_field,
            filter=filter,
        )

        # run the macro
        table = self.execute_macro(
            FRESHNESS_MACRO_NAME,
            kwargs=macro_kwargs,
            release=True,
            manifest=manifest
        )
        # The macro must return exactly one row of two values: the maximum
        # `loaded_at_field` value and the current time according to the db.
        if len(table) != 1 or len(table[0]) != 2:
            rows = [tuple(r) for r in table]
            raise_compiler_error(
                'Got an invalid result from "{}" macro: {}'.format(
                    FRESHNESS_MACRO_NAME, rows
                )
            )

        row = table[0]
        if row[0] is None:
            # no records in the table, so really the max_loaded_at was
            # infinitely long ago. Just call it 0:00 January 1 year UTC
            max_loaded_at = datetime(1, 1, 1, 0, 0, 0, tzinfo=pytz.UTC)
        else:
            max_loaded_at = _utc(row[0], source, loaded_at_field)
        snapshotted_at = _utc(row[1], source, loaded_at_field)

        elapsed = snapshotted_at - max_loaded_at
        return {
            'max_loaded_at': max_loaded_at,
            'snapshotted_at': snapshotted_at,
            'age': elapsed.total_seconds(),
        }
예제 #29
0
    def __init__(
        self,
        test: Dict[str, Any],
        target: Testable,
        package_name: str,
        render_ctx: Dict[str, Any],
        column_name: str = None,
    ) -> None:
        """Split a test definition into name, namespace, args and modifiers.

        :param test: The test definition, handed to ``extract_test_args``.
        :param target: Stored unchanged on ``self.target``.
        :param package_name: Package used unless the test name carries its
            own namespace.
        :param render_ctx: Context passed to ``get_rendered`` for string
            modifier values.
        :param column_name: Forwarded to ``extract_test_args``.
        """
        parsed_name, parsed_args = self.extract_test_args(test, column_name)
        self.args: Dict[str, Any] = parsed_args
        # "model" is filled in below, so users may not supply it.
        if 'model' in self.args:
            raise_compiler_error(
                'Test arguments include "model", which is a reserved argument',
            )
        self.package_name: str = package_name
        self.target: Testable = target

        self.args['model'] = self.build_model_str()

        name_match = self.TEST_NAME_PATTERN.match(parsed_name)
        if name_match is None:
            raise_compiler_error(
                'Test name string did not match expected pattern: {}'.format(
                    parsed_name))

        name_parts = name_match.groupdict()
        self.name: str = name_parts['test_name']
        self.namespace: str = name_parts['test_namespace']

        # Modifier keys are removed from the test args; string values are
        # rendered natively, and only non-None results are kept.
        self.modifiers: Dict[str, Any] = {}
        for key in self.MODIFIER_ARGS:
            raw = self.args.pop(key, None)
            rendered = (
                get_rendered(raw, render_ctx, native=True)
                if isinstance(raw, str) else raw
            )
            if rendered is None:
                continue
            self.modifiers[key] = rendered

        # A namespace in the test name takes precedence over the package.
        if self.namespace is not None:
            self.package_name = self.namespace

        compiled, fqn = self.get_test_name()
        self.compiled_name: str = compiled
        self.fqn_name: str = fqn
예제 #30
0
    def flatten(self):
        """Return a new instance whose entries exclude the database.

        Every ``(information schema, schema)`` pair from ``self.search()``
        is re-added with the database removed from its quote and include
        policies. More than one distinct (lowercased) database among the
        entries is reported through ``raise_compiler_error``.
        """
        flattened = self.__class__()

        # All entries must share at most one database.
        databases = set()
        for rel in self:
            if rel.database:
                databases.add(rel.database.lower())
        if len(databases) > 1:
            raise_compiler_error(str(databases))

        for info_schema, schema in self.search():
            location = {
                'database': info_schema.database,
                'schema': schema
            }
            entry = info_schema.incorporate(
                path=location,
                quote_policy={'database': False},
                include_policy={'database': False},
            )
            flattened.add(entry)

        return flattened