Esempio n. 1
0
File: archive.py Progetto: yilab/dbt
    def compile(self):
        """Compile the project and log how many archives were compiled.

        Builds a ``Compiler`` from ``self.project`` and ``self.args``,
        initializes it, runs the compilation and reports the ``'archives'``
        entry of the result through ``logger``.
        """
        archive_compiler = Compiler(self.project, self.args)
        archive_compiler.initialize()

        # Only the 'archives' entry of the compile result is reported here.
        archive_total = archive_compiler.compile()['archives']
        logger.info("Compiled {} archives".format(archive_total))
Esempio n. 2
0
    def run(self):
        """Compile the project and log one "<count> <entity>" pair per entity.

        The result of ``Compiler.compile()`` is indexed by every member of
        ``CompilableEntities`` to build the summary line.
        """
        task_compiler = Compiler(self.project, self.args)
        task_compiler.initialize()
        entity_counts = task_compiler.compile()

        summary_parts = []
        for entity in CompilableEntities:
            summary_parts.append(
                "{} {}".format(entity_counts[entity], entity))

        logger.info("Compiled {}".format(", ".join(summary_parts)))
Esempio n. 3
0
    def compile(self):
        """Compile the project with ``BaseCreateTemplate`` and print a summary.

        Returns:
            The initialized ``Compiler`` instance that performed the
            compilation.
        """
        project_compiler = Compiler(self.project, BaseCreateTemplate)
        project_compiler.initialize()

        # compile() yields exactly three counts: models, tests, analyses.
        counts = project_compiler.compile()
        n_models, n_tests, n_analyses = counts

        summary = "Compiled {} models, {} tests, and {} analyses".format(
            n_models, n_tests, n_analyses)
        print(summary)

        return project_compiler
Esempio n. 4
0
    def compile(self):
        """Compile only the project's tests and log a per-entity summary.

        Returns:
            The initialized ``Compiler`` instance that performed the
            compilation.
        """
        test_compiler = Compiler(self.project, BaseCreateTemplate, self.args)
        test_compiler.initialize()

        # Compilation is restricted to tests, but the summary still reports
        # every entity in CompilableEntities.
        entity_counts = test_compiler.compile(limit_to=['tests'])

        summary_parts = []
        for entity in CompilableEntities:
            summary_parts.append(
                "{} {}".format(entity_counts[entity], entity))
        logger.info("Compiled {}".format(", ".join(summary_parts)))

        return test_compiler
Esempio n. 5
0
    def compile(self):
        """Compile the project with ``BaseCreateTemplate`` and print a summary.

        Returns:
            The initialized ``Compiler`` instance that performed the
            compilation.
        """
        project_compiler = Compiler(self.project, BaseCreateTemplate)
        project_compiler.initialize()
        entity_counts = project_compiler.compile()

        # One "<count> <entity>" fragment per compilable entity.
        summary_parts = []
        for entity in CompilableEntities:
            summary_parts.append(
                "{} {}".format(entity_counts[entity], entity))
        print("Compiled {}".format(", ".join(summary_parts)))

        return project_compiler
Esempio n. 6
0
File: run.py Progetto: menetrier/dbt
    def compile(self):
        """Compile the project, honouring the ``dry`` flag, and print counts.

        Returns:
            The ``label`` attribute of the create-template class that was
            selected for this run.
        """
        # A dry run uses the dry template; anything else uses the base one.
        if self.args.dry:
            create_template = DryCreateTemplate
        else:
            create_template = BaseCreateTemplate

        project_compiler = Compiler(self.project, create_template)
        project_compiler.initialize()

        counts = project_compiler.compile(self.args.dry)
        n_models, n_tests, n_analyses = counts

        summary = "Compiled {} models, {} tests, and {} analyses".format(
            n_models, n_tests, n_analyses)
        print(summary)

        return create_template.label
Esempio n. 7
0
    def compile(self):
        """Compile the project, honouring the ``dry`` flag, and print a summary.

        Returns:
            The ``label`` attribute of the create-template class that was
            selected for this run.
        """
        # A dry run uses the dry template; anything else uses the base one.
        if self.args.dry:
            create_template = DryCreateTemplate
        else:
            create_template = BaseCreateTemplate

        project_compiler = Compiler(self.project, create_template)
        project_compiler.initialize()
        entity_counts = project_compiler.compile(self.args.dry)

        summary_parts = []
        for entity in CompilableEntities:
            summary_parts.append(
                "{} {}".format(entity_counts[entity], entity))
        print("Compiled {}".format(", ".join(summary_parts)))

        return create_template.label
Esempio n. 8
0
    def compile(self):
        """Compile only the project's models and log a per-entity summary.

        Returns:
            The ``label`` attribute of the create-template class that was
            selected for this run.
        """
        # A dry run uses the dry template; anything else uses the base one.
        if self.args.dry:
            create_template = DryCreateTemplate
        else:
            create_template = BaseCreateTemplate

        model_compiler = Compiler(self.project, create_template, self.args)
        model_compiler.initialize()
        entity_counts = model_compiler.compile(limit_to=['models'])

        summary_parts = []
        for entity in CompilableEntities:
            summary_parts.append(
                "{} {}".format(entity_counts[entity], entity))
        logger.info("Compiled {}".format(", ".join(summary_parts)))

        return create_template.label
Esempio n. 9
0
    def run(self):
        """Compile the project, honouring the ``dry`` flag, and print counts."""
        # A dry run uses the dry template; anything else uses the base one.
        create_template = (
            DryCreateTemplate if self.args.dry else BaseCreateTemplate
        )

        project_compiler = Compiler(self.project, create_template)
        project_compiler.initialize()

        # compile() yields exactly three counts: models, tests, analyses.
        counts = project_compiler.compile(dry=self.args.dry)
        n_models, n_tests, n_analyses = counts

        summary = "Compiled {} models, {} tests and {} analyses".format(
            n_models, n_tests, n_analyses)
        print(summary)
Esempio n. 10
0
    def run(self):
        """Compile every compilable entity and log a per-entity summary."""
        # A dry run uses the dry template; anything else uses the base one.
        create_template = (
            DryCreateTemplate if self.args.dry else BaseCreateTemplate
        )

        project_compiler = Compiler(self.project, create_template, self.args)
        project_compiler.initialize()
        entity_counts = project_compiler.compile(limit_to=CompilableEntities)

        summary_parts = []
        for entity in CompilableEntities:
            summary_parts.append(
                "{} {}".format(entity_counts[entity], entity))
        logger.info("Compiled {}".format(", ".join(summary_parts)))
Esempio n. 11
0
 def get_compiler(self):
     """Build and return a dbt ``Compiler`` for ``self.config``."""
     # The import happens at call time rather than at module import time.
     from dbt.compilation import Compiler

     config_compiler = Compiler(self.config)
     return config_compiler
Esempio n. 12
0
def new_base_compiler():
    """Return a ``Compiler`` for a new project using ``BaseCreateTemplate``.

    The compiler is returned as constructed; ``initialize()`` is not called
    here.
    """
    fresh_project = new_project()
    return Compiler(fresh_project, BaseCreateTemplate)
Esempio n. 13
0
    def dbt_compiler(self):
        """Create the dbt compiler for ``self.dbt_config`` and return it.

        The new compiler is also stored on the instance under the
        ``dbt_compiler`` attribute name.
        """
        # Imported at call time rather than at module import time.
        from dbt.compilation import Compiler as DbtCompiler

        compiler_for_config = DbtCompiler(self.dbt_config)
        self.dbt_compiler = compiler_for_config
        return self.dbt_compiler
Esempio n. 14
0
 def compile(self):
     """Compile the project's archives and print how many were compiled."""
     archive_compiler = Compiler(self.project, self.create_template)
     archive_compiler.initialize()

     # The reported count is the length of compile_archives()'s result.
     archive_total = len(archive_compiler.compile_archives())
     print("Compiled {} archives".format(archive_total))
Esempio n. 15
0
class DbtTemplater(JinjaTemplater):
    """A templater using dbt.

    Files are compiled through dbt's own compiler (see ``_unsafe_process``)
    and the result is sliced by the parent class' ``slice_file()``. The dbt
    config, compiler, manifest and selector method are loaded lazily via
    cached properties.
    """

    name = "dbt"
    # Once more than this many dbt compilation errors occur in a row,
    # ``process()`` flags the resulting error as fatal.
    sequential_fail_limit = 3

    def __init__(self, **kwargs):
        """Set up empty templater state, then defer to the parent class."""
        self.sqlfluff_config = None
        self.formatter = None
        self.project_dir = None
        self.profiles_dir = None
        # Remembered so that process() can chdir back after compiling.
        self.working_dir = os.getcwd()
        self._sequential_fails = 0
        self.connection_acquired = False
        super().__init__(**kwargs)

    def config_pairs(self):  # pragma: no cover TODO?
        """Returns info about the given templater for output by the cli."""
        return [("templater", self.name), ("dbt", self.dbt_version)]

    @property
    def dbt_version(self):
        """Gets the dbt version."""
        return DBT_VERSION_STRING

    @property
    def dbt_version_tuple(self):
        """Gets the dbt version as a tuple on (major, minor)."""
        return DBT_VERSION_TUPLE

    @cached_property
    def dbt_config(self):
        """Loads the dbt config."""
        if self.dbt_version_tuple >= (1, 0):
            # Here, we read flags.PROFILE_DIR directly, prior to calling
            # set_from_args(). Apparently, set_from_args() sets PROFILES_DIR
            # to a lowercase version of the value, and the profile wouldn't be
            # found if the directory name contained uppercase letters. This fix
            # was suggested and described here:
            # https://github.com/sqlfluff/sqlfluff/issues/2253#issuecomment-1018722979
            user_config = read_user_config(flags.PROFILES_DIR)
            flags.set_from_args(
                DbtConfigArgs(
                    project_dir=self.project_dir,
                    profiles_dir=self.profiles_dir,
                    profile=self._get_profile(),
                    vars=self._get_cli_vars(),
                ),
                user_config,
            )
        self.dbt_config = DbtRuntimeConfig.from_args(
            DbtConfigArgs(
                project_dir=self.project_dir,
                profiles_dir=self.profiles_dir,
                profile=self._get_profile(),
                target=self._get_target(),
                vars=self._get_cli_vars(),
            )
        )
        register_adapter(self.dbt_config)
        return self.dbt_config

    @cached_property
    def dbt_compiler(self):
        """Loads the dbt compiler."""
        self.dbt_compiler = DbtCompiler(self.dbt_config)
        return self.dbt_compiler

    @cached_property
    def dbt_manifest(self):
        """Loads the dbt manifest."""
        # Identity function used for macro hooks
        def identity(x):
            return x

        # Set dbt not to run tracking. We don't load
        # a full project and so some tracking routines
        # may fail.
        from dbt.tracking import do_not_track

        do_not_track()

        # dbt 0.20.* and onward
        from dbt.parser.manifest import ManifestLoader

        projects = self.dbt_config.load_dependencies()
        loader = ManifestLoader(self.dbt_config, projects, macro_hook=identity)
        self.dbt_manifest = loader.load()

        return self.dbt_manifest

    @cached_property
    def dbt_selector_method(self):
        """Loads the dbt selector method."""
        if self.formatter:  # pragma: no cover TODO?
            self.formatter.dispatch_compilation_header(
                "dbt templater", "Compiling dbt project..."
            )

        from dbt.graph.selector_methods import (
            MethodManager as DbtSelectorMethodManager,
            MethodName as DbtMethodName,
        )

        selector_methods_manager = DbtSelectorMethodManager(
            self.dbt_manifest, previous_state=None
        )
        self.dbt_selector_method = selector_methods_manager.get_method(
            DbtMethodName.Path, method_arguments=[]
        )

        if self.formatter:  # pragma: no cover TODO?
            self.formatter.dispatch_compilation_header(
                "dbt templater", "Project Compiled."
            )

        return self.dbt_selector_method

    def _get_profiles_dir(self):
        """Get the dbt profiles directory from the configuration.

        The default is `~/.dbt` in 0.17 but we use the
        PROFILES_DIR variable from the dbt library to
        support a change of default in the future, as well
        as to support the same overwriting mechanism as
        dbt (currently an environment variable).

        A missing directory is only logged as an error; the path is
        returned regardless.
        """
        dbt_profiles_dir = os.path.abspath(
            os.path.expanduser(
                self.sqlfluff_config.get_section(
                    (self.templater_selector, self.name, "profiles_dir")
                )
                or PROFILES_DIR
            )
        )

        if not os.path.exists(dbt_profiles_dir):
            templater_logger.error(
                f"dbt_profiles_dir: {dbt_profiles_dir} could not be accessed. "
                "Check it exists."
            )

        return dbt_profiles_dir

    def _get_project_dir(self):
        """Get the dbt project directory from the configuration.

        Defaults to the working directory.

        A missing directory is only logged as an error; the path is
        returned regardless.
        """
        dbt_project_dir = os.path.abspath(
            os.path.expanduser(
                self.sqlfluff_config.get_section(
                    (self.templater_selector, self.name, "project_dir")
                )
                or os.getcwd()
            )
        )
        if not os.path.exists(dbt_project_dir):
            templater_logger.error(
                f"dbt_project_dir: {dbt_project_dir} could not be accessed. "
                "Check it exists."
            )

        return dbt_project_dir

    def _get_profile(self):
        """Get a dbt profile name from the configuration."""
        return self.sqlfluff_config.get_section(
            (self.templater_selector, self.name, "profile")
        )

    def _get_target(self):
        """Get a dbt target name from the configuration."""
        return self.sqlfluff_config.get_section(
            (self.templater_selector, self.name, "target")
        )

    def _get_cli_vars(self) -> str:
        """Get the templater "context" config section as a string.

        Returns:
            ``str()`` of the configured section, or ``"{}"`` when the section
            is missing or empty.
        """
        cli_vars = self.sqlfluff_config.get_section(
            (self.templater_selector, self.name, "context")
        )

        return str(cli_vars) if cli_vars else "{}"

    def sequence_files(
        self, fnames: List[str], config=None, formatter=None
    ) -> Iterator[str]:
        """Reorder fnames to process dependent files first.

        This avoids errors when an ephemeral model is processed before use.

        Selected ephemeral models are yielded first, each one only after the
        ephemeral models it depends on; every remaining file name is then
        yielded in its original order, without duplicates.
        """
        if formatter:  # pragma: no cover
            formatter.dispatch_compilation_header("dbt templater", "Sorting Nodes...")

        # Initialise config if not already done
        self.sqlfluff_config = config
        if not self.project_dir:
            self.project_dir = self._get_project_dir()
        if not self.profiles_dir:
            self.profiles_dir = self._get_profiles_dir()

        # Populate full paths for selected files
        full_paths: Dict[str, str] = {}
        selected_files = set()
        for fname in fnames:
            fpath = os.path.join(self.working_dir, fname)
            full_paths[fpath] = fname
            selected_files.add(fpath)

        ephemeral_nodes: Dict[str, Tuple[str, Any]] = {}

        # Extract the ephemeral models
        for key, node in self.dbt_manifest.nodes.items():
            if node.config.materialized == "ephemeral":
                # The key is the manifest node key.
                # The value is a tuple of the node's full filepath and the
                # list of node keys it depends on.
                ephemeral_nodes[key] = (
                    os.path.join(self.project_dir, node.original_file_path),
                    node.depends_on.nodes,
                )

        # Yield ephemeral nodes first. We use a deque for efficient re-queuing.
        # We iterate through the deque, yielding any nodes without dependents,
        # or where those dependents have already yielded, first. The original
        # mapping is still used to hold the metadata on each key.
        already_yielded = set()
        ephemeral_buffer: Deque[str] = deque(ephemeral_nodes.keys())
        while ephemeral_buffer:
            key = ephemeral_buffer.popleft()
            fpath, dependents = ephemeral_nodes[key]

            # If it's not in our selection, skip it
            if fpath not in selected_files:
                templater_logger.debug("- Purging unselected ephemeral: %r", fpath)
            # If there are dependent nodes in the set, don't process it yet.
            elif any(
                dependent in ephemeral_buffer for dependent in dependents
            ):  # pragma: no cover
                templater_logger.debug(
                    "- Requeuing ephemeral with dependents: %r", fpath
                )
                # Requeue it for later
                ephemeral_buffer.append(key)
            # Otherwise yield it.
            else:
                templater_logger.debug("- Yielding Ephemeral: %r", fpath)
                yield full_paths[fpath]
                already_yielded.add(full_paths[fpath])

        for fname in fnames:
            if fname not in already_yielded:
                yield fname
                # Dedupe here so we don't yield twice
                already_yielded.add(fname)
            else:
                templater_logger.debug(
                    "- Skipping yield of previously sequenced file: %r", fname
                )

    @large_file_check
    def process(self, *, fname, in_str=None, config=None, formatter=None):
        """Compile a dbt model and return the compiled SQL.

        Args:
            fname (:obj:`str`): Path to dbt model(s)
            in_str (:obj:`str`, optional): This is ignored for dbt
            config (:obj:`FluffConfig`, optional): A specific config to use for this
                templating operation. Only necessary for some templaters.
            formatter (:obj:`CallbackFormatter`): Optional object for output.

        Returns:
            The ``(TemplatedFile, [])`` tuple produced by
            ``_unsafe_process()`` on success, or ``(None, [error])`` with a
            single ``SQLTemplaterError`` when compilation or the database
            connection fails.
        """
        # Stash the formatter if provided to use in cached methods.
        self.formatter = formatter
        self.sqlfluff_config = config
        self.project_dir = self._get_project_dir()
        self.profiles_dir = self._get_profiles_dir()
        # Resolve before the chdir below changes what a relative path means.
        fname_absolute_path = os.path.abspath(fname)

        try:
            os.chdir(self.project_dir)
            processed_result = self._unsafe_process(fname_absolute_path, in_str, config)
            # Reset the fail counter
            self._sequential_fails = 0
            return processed_result
        except DbtCompilationException as e:
            # Increment the counter
            self._sequential_fails += 1
            if e.node:
                return None, [
                    SQLTemplaterError(
                        f"dbt compilation error on file '{e.node.original_file_path}', "
                        f"{e.msg}",
                        # It's fatal if we're over the limit
                        fatal=self._sequential_fails > self.sequential_fail_limit,
                    )
                ]
            else:
                raise  # pragma: no cover
        except DbtFailedToConnectException as e:
            return None, [
                SQLTemplaterError(
                    "dbt tried to connect to the database and failed: you could use "
                    # The trailing space keeps "See" apart from the URL once
                    # the adjacent string literals are concatenated.
                    "'execute' to skip the database calls. See "
                    "https://docs.getdbt.com/reference/dbt-jinja-functions/execute/ "
                    f"Error: {e.msg}",
                    fatal=True,
                )
            ]
        # If a SQLFluff error is raised, just pass it through
        except SQLTemplaterError as e:  # pragma: no cover
            return None, [e]
        finally:
            # Always restore the original working directory.
            os.chdir(self.working_dir)

    def _find_node(self, fname, config=None):
        """Find the dbt manifest node for the file at ``fname``.

        Args:
            fname: Path of the file to look up. It is converted to a path
                relative to the current working directory for the selector.
            config: Config object; only checked for being provided.

        Returns:
            The first manifest node matched by the path selector.

        Raises:
            ValueError: If ``config`` or ``fname`` is missing, or ``fname``
                is ``"stdin"``.
            SQLFluffSkipFile: If no node matches the path.
        """
        if not config:  # pragma: no cover
            raise ValueError(
                "For the dbt templater, the `process()` method "
                "requires a config object."
            )
        if not fname:  # pragma: no cover
            raise ValueError(
                "For the dbt templater, the `process()` method requires a file name"
            )
        elif fname == "stdin":  # pragma: no cover
            raise ValueError(
                "The dbt templater does not support stdin input, provide a path instead"
            )
        selected = self.dbt_selector_method.search(
            included_nodes=self.dbt_manifest.nodes,
            # Selector needs to be a relative path
            selector=os.path.relpath(fname, start=os.getcwd()),
        )
        results = [self.dbt_manifest.expect(uid) for uid in selected]

        if not results:
            skip_reason = self._find_skip_reason(fname)
            if skip_reason:
                raise SQLFluffSkipFile(
                    f"Skipped file {fname} because it is {skip_reason}"
                )
            raise SQLFluffSkipFile(
                "File %s was not found in dbt project" % fname
            )  # pragma: no cover
        return results[0]

    def _find_skip_reason(self, fname) -> Optional[str]:
        """Return string reason if model okay to skip, otherwise None."""
        # Scan macros.
        abspath = os.path.abspath(fname)
        for macro in self.dbt_manifest.macros.values():
            if os.path.abspath(macro.original_file_path) == abspath:
                return "a macro"

        if DBT_VERSION_TUPLE >= (1, 0):
            # Scan disabled nodes.
            for nodes in self.dbt_manifest.disabled.values():
                for node in nodes:
                    if os.path.abspath(node.original_file_path) == abspath:
                        return "disabled"
        else:
            model_name = os.path.splitext(os.path.basename(fname))[0]
            if self.dbt_manifest.find_disabled_by_name(name=model_name):
                return "disabled"
        return None

    def _unsafe_process(self, fname, in_str=None, config=None):
        """Compile the dbt node for ``fname`` and slice the result.

        Exceptions are not converted to templater errors here; ``process()``
        wraps this call and handles them.

        Args:
            fname: Path of the model file; ``process()`` passes an absolute
                path.
            in_str: Unused here; the source is read from ``fname``.
            config: Config object, passed to ``_find_node()`` and
                ``slice_file()``.

        Returns:
            A tuple of the ``TemplatedFile`` and an empty list of violations.

        Raises:
            SQLFluffSkipFile: If dbt raises any exception while compiling
                the node, or no node is found for ``fname``.
            SQLTemplaterError: If dbt produced no compiled SQL.
        """
        original_file_path = os.path.relpath(fname, start=os.getcwd())

        # Below, we monkeypatch Environment.from_string() to intercept when dbt
        # compiles (i.e. runs Jinja) to expand the "node" corresponding to fname.
        # We do this to capture the Jinja context at the time of compilation, i.e.:
        # - Jinja Environment object
        # - Jinja "globals" dictionary
        #
        # This info is captured by the "make_template()" function, which in
        # turn is used by our parent class' (JinjaTemplater) slice_file()
        # function.
        old_from_string = Environment.from_string
        make_template = None

        def from_string(*args, **kwargs):
            """Replaces (via monkeypatch) the jinja2.Environment function."""
            nonlocal make_template
            # Is it processing the node corresponding to fname?
            template_globals = kwargs.get("globals")
            if template_globals:
                model = template_globals.get("model")
                if model:
                    if model.get("original_file_path") == original_file_path:
                        # Yes. Capture the important arguments and create
                        # a make_template() function.
                        env = args[0]
                        template_globals = (
                            args[2] if len(args) >= 3 else kwargs["globals"]
                        )

                        def make_template(in_str):
                            env.add_extension(SnapshotExtension)
                            return env.from_string(in_str, globals=template_globals)

            return old_from_string(*args, **kwargs)

        node = self._find_node(fname, config)
        templater_logger.debug(
            "_find_node for path %r returned object of type %s.", fname, type(node)
        )

        # Snapshot the not-yet-compiled ephemeral nodes so they can be
        # restored after compilation (see the :HACK: note further down).
        save_ephemeral_nodes = {
            k: v
            for k, v in self.dbt_manifest.nodes.items()
            if v.config.materialized == "ephemeral"
            and not getattr(v, "compiled", False)
        }
        with self.connection():
            # Apply the monkeypatch.
            Environment.from_string = from_string
            try:
                node = self.dbt_compiler.compile_node(
                    node=node,
                    manifest=self.dbt_manifest,
                )
            except Exception as err:
                templater_logger.exception(
                    "Fatal dbt compilation error on %s. This occurs most often "
                    "during incorrect sorting of ephemeral models before linting. "
                    "Please report this error on github at "
                    "https://github.com/sqlfluff/sqlfluff/issues, including "
                    "both the raw and compiled sql for the model affected.",
                    fname,
                )
                # Additional error logging in case we get a fatal dbt error.
                raise SQLFluffSkipFile(  # pragma: no cover
                    f"Skipped file {fname} because dbt raised a fatal "
                    f"exception during compilation: {err!s}"
                ) from err
            finally:
                # Undo the monkeypatch.
                Environment.from_string = old_from_string

            if hasattr(node, "injected_sql"):
                # If injected SQL is present, it contains a better picture
                # of what will actually hit the database (e.g. with tests).
                # However it's not always present.
                compiled_sql = node.injected_sql
            else:
                compiled_sql = getattr(node, COMPILED_SQL_ATTRIBUTE)

            raw_sql = getattr(node, RAW_SQL_ATTRIBUTE)

            if not compiled_sql:  # pragma: no cover
                raise SQLTemplaterError(
                    "dbt templater compilation failed silently, check your "
                    "configuration by running `dbt compile` directly."
                )

            with open(fname) as source_dbt_model:
                source_dbt_sql = source_dbt_model.read()

            if not source_dbt_sql.rstrip().endswith("-%}"):
                n_trailing_newlines = len(source_dbt_sql) - len(
                    source_dbt_sql.rstrip("\n")
                )
            else:
                # Source file ends with right whitespace stripping, so there's
                # no need to preserve/restore trailing newlines, as they would
                # have been removed regardless of dbt's
                # keep_trailing_newlines=False behavior.
                n_trailing_newlines = 0

            templater_logger.debug(
                "    Trailing newline count in source dbt model: %r",
                n_trailing_newlines,
            )
            templater_logger.debug("    Raw SQL before compile: %r", source_dbt_sql)
            templater_logger.debug("    Node raw SQL: %r", raw_sql)
            templater_logger.debug("    Node compiled SQL: %r", compiled_sql)

            # When using dbt-templater, trailing newlines are ALWAYS REMOVED during
            # compiling. Unless fixed (like below), this will cause:
            #    1. Assertion errors in TemplatedFile, when it sanity checks the
            #       contents of the sliced_file array.
            #    2. L009 linting errors when running "sqlfluff lint foo_bar.sql"
            #       since the linter will use the compiled code with the newlines
            #       removed.
            #    3. "No newline at end of file" warnings in Git/GitHub since
            #       sqlfluff uses the compiled SQL to write fixes back to the
            #       source SQL in the dbt model.
            #
            # The solution is (note that both the raw and compiled files have
            # had trailing newline(s) removed by the dbt-templater):
            #    1. Check for trailing newlines before compiling by looking at the
            #       raw SQL in the source dbt file. Remember the count of trailing
            #       newlines.
            #    2. Set node.raw_sql/node.raw_code to the original source file contents.
            #    3. Append the count from #1 above to compiled_sql. (In
            #       production, slice_file() does not usually use this string,
            #       but some test scenarios do.)
            setattr(node, RAW_SQL_ATTRIBUTE, source_dbt_sql)
            compiled_sql = compiled_sql + "\n" * n_trailing_newlines

            # TRICKY: dbt configures Jinja2 with keep_trailing_newline=False.
            # As documented (https://jinja.palletsprojects.com/en/3.0.x/api/),
            # this flag's behavior is: "Preserve the trailing newline when
            # rendering templates. The default is False, which causes a single
            # newline, if present, to be stripped from the end of the template."
            #
            # Below, we use "append_to_templated" to effectively "undo" this.
            raw_sliced, sliced_file, templated_sql = self.slice_file(
                source_dbt_sql,
                compiled_sql,
                config=config,
                make_template=make_template,
                append_to_templated="\n" if n_trailing_newlines else "",
            )
        # :HACK: If calling compile_node() compiled any ephemeral nodes,
        # restore them to their earlier state. This prevents a runtime error
        # in the dbt "_inject_ctes_into_sql()" function that occurs with
        # 2nd-level ephemeral model dependencies (e.g. A -> B -> C, where
        # both B and C are ephemeral). Perhaps there is a better way to do
        # this, but this seems good enough for now.
        for k, v in save_ephemeral_nodes.items():
            if getattr(self.dbt_manifest.nodes[k], "compiled", False):
                self.dbt_manifest.nodes[k] = v
        return (
            TemplatedFile(
                source_str=source_dbt_sql,
                templated_str=templated_sql,
                fname=fname,
                sliced_file=sliced_file,
                raw_sliced=raw_sliced,
            ),
            # No violations returned in this way.
            [],
        )

    @contextmanager
    def connection(self):
        """Context manager that manages a dbt connection, if needed."""
        # We have to register the connection in dbt >= 1.0.0 ourselves
        # In previous versions, we relied on the functionality removed in
        # https://github.com/dbt-labs/dbt-core/pull/4062.
        if DBT_VERSION_TUPLE >= (1, 0):
            if not self.connection_acquired:
                adapter = get_adapter(self.dbt_config)
                adapter.acquire_connection("master")
                adapter.set_relations_cache(self.dbt_manifest)
                self.connection_acquired = True
            yield
            # :TRICKY: Once connected, we never disconnect. Making multiple
            # connections during linting has proven to cause major performance
            # issues.
        else:
            yield
Esempio n. 16
0
 def dbt_compiler(self):
     """Create the dbt compiler for ``self.dbt_config`` and return it.

     The new compiler is also stored on the instance under the
     ``dbt_compiler`` attribute name.
     """
     compiler_for_config = DbtCompiler(self.dbt_config)
     self.dbt_compiler = compiler_for_config
     return self.dbt_compiler
Esempio n. 17
0
 def compile(self):
     """Compile the project's archives and log how many were compiled."""
     archive_compiler = Compiler(
         self.project, self.create_template, self.args)
     archive_compiler.initialize()

     # The reported count is the length of compile_archives()'s result.
     archive_total = len(archive_compiler.compile_archives())
     logger.info("Compiled {} archives".format(archive_total))
Esempio n. 18
0
class DbtTemplater(JinjaTemplater):
    """A templater using dbt.

    A model is compiled with dbt's own compiler; the raw and compiled SQL are
    then passed to ``slice_file`` (not defined in this class, presumably
    inherited from ``JinjaTemplater``) to build a ``TemplatedFile``.

    The dbt objects (version, config, compiler, manifest, selector) are
    exposed as cached properties, so each is only loaded on first access.
    """

    name = "dbt"

    def __init__(self, **kwargs):
        # Set by `_unsafe_process()`; the `_get_*` helpers read the templater
        # settings from it.
        self.sqlfluff_config = None
        super().__init__(**kwargs)

    @cached_property
    def dbt_version(self):
        """Gets the dbt version as a version string."""
        from dbt.version import get_installed_version

        return get_installed_version().to_version_string()

    @cached_property
    def dbt_config(self):
        """Loads the dbt config and registers the adapter for it."""
        from dbt.config.runtime import RuntimeConfig as DbtRuntimeConfig
        from dbt.adapters.factory import register_adapter

        config = DbtRuntimeConfig.from_args(
            DbtConfigArgs(
                project_dir=self._get_project_dir(),
                profiles_dir=self._get_profiles_dir(),
                profile=self._get_profile(),
            ))
        register_adapter(config)
        return config

    @cached_property
    def dbt_compiler(self):
        """Loads the dbt compiler."""
        from dbt.compilation import Compiler as DbtCompiler

        return DbtCompiler(self.dbt_config)

    @cached_property
    def dbt_manifest(self):
        """Loads the dbt manifest."""

        # Identity function used for macro hooks
        def identity(x):
            return x

        # Set dbt not to run tracking. We don't load
        # a full project and so some tracking routines
        # may fail.
        from dbt.tracking import do_not_track

        do_not_track()

        if "0.17" in self.dbt_version:
            # In 0.17 the macro manifest loader has a different name and is
            # called without a macro hook.
            from dbt.parser.manifest import (
                load_internal_manifest as load_macro_manifest,
                load_manifest,
            )
        else:
            from dbt.parser.manifest import (
                load_macro_manifest,
                load_manifest,
            )

            load_macro_manifest = partial(load_macro_manifest,
                                          macro_hook=identity)

        dbt_macros_manifest = load_macro_manifest(self.dbt_config)
        return load_manifest(self.dbt_config,
                             dbt_macros_manifest,
                             macro_hook=identity)

    @cached_property
    def dbt_selector_method(self):
        """Loads the dbt selector method used to find nodes by path."""
        if "0.17" in self.dbt_version:
            from dbt.graph.selector import PathSelector

            return PathSelector(self.dbt_manifest)

        from dbt.graph.selector_methods import (
            MethodManager as DbtSelectorMethodManager,
            MethodName as DbtMethodName,
        )

        selector_methods_manager = DbtSelectorMethodManager(
            self.dbt_manifest, previous_state=None)
        return selector_methods_manager.get_method(
            DbtMethodName.Path, method_arguments=[])

    def _get_profiles_dir(self):
        """Get the dbt profiles directory from the configuration.

        The default is `~/.dbt` in 0.17 but we use the
        PROFILES_DIR variable from the dbt library to
        support a change of default in the future, as well
        as to support the same overwriting mechanism as
        dbt (currently an environment variable).
        """
        from dbt.config.profile import PROFILES_DIR

        return os.path.expanduser(
            self.sqlfluff_config.get_section(
                (self.templater_selector, self.name, "profiles_dir"))
            or PROFILES_DIR)

    def _get_project_dir(self):
        """Get the dbt project directory from the configuration.

        Defaults to the working directory.
        """
        return os.path.expanduser(
            self.sqlfluff_config.get_section(
                (self.templater_selector, self.name, "project_dir"))
            or os.getcwd())

    def _get_profile(self):
        """Get a dbt profile name from the configuration."""
        return self.sqlfluff_config.get_section(
            (self.templater_selector, self.name, "profile"))

    @staticmethod
    def _check_dbt_installed():
        """Raise a helpful ``ModuleNotFoundError`` if dbt cannot be imported."""
        try:
            import dbt  # noqa: F401
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Module dbt was not found while trying to use dbt templating, "
                "please install dbt dependencies through `pip install sqlfluff[dbt]`"
            ) from e

    def process(self, *, fname, in_str=None, config=None):
        """Compile a dbt model and return the compiled SQL.

        Args:
            fname (:obj:`str`): Path to dbt model(s)
            in_str (:obj:`str`, optional): This is ignored for dbt
            config (:obj:`FluffConfig`, optional): A specific config to use for this
                templating operation. Only necessary for some templaters.

        Returns:
            A ``(templated_file, violations)`` tuple. On success the list of
            violations is empty; when dbt fails to compile or to connect, or
            a ``SQLTemplaterError`` is raised, the file is ``None`` and the
            list holds a single ``SQLTemplaterError``.

        Raises:
            ModuleNotFoundError: If dbt is not installed.
            ValueError, RuntimeError: Propagated from ``_unsafe_process``.
        """
        self._check_dbt_installed()
        from dbt.exceptions import (
            CompilationException as DbtCompilationException,
            FailedToConnectException as DbtFailedToConnectException,
        )

        try:
            return self._unsafe_process(fname, in_str, config)
        except DbtCompilationException as e:
            # dbt does not always attach a node to the exception; fall back
            # to the requested file rather than failing on a `None` node.
            failed_node = getattr(e, "node", None)
            failed_path = (
                failed_node.original_file_path
                if failed_node is not None
                else fname
            )
            return None, [
                SQLTemplaterError(
                    f"dbt compilation error on file '{failed_path}', {e.msg}"
                )
            ]
        except DbtFailedToConnectException as e:
            return None, [
                SQLTemplaterError(
                    "dbt tried to connect to the database and failed: "
                    "you could use 'execute' https://docs.getdbt.com/reference/dbt-jinja-functions/execute/ "
                    f"to skip the database calls. Error: {e.msg}")
            ]
        # If a SQLFluff error is raised, just pass it through
        except SQLTemplaterError as e:
            return None, [e]

    def _unsafe_process(self, fname, in_str=None, config=None):
        """Compile the model at `fname` without translating dbt exceptions.

        Args:
            fname (:obj:`str`): Path to the dbt model; "stdin" is rejected.
            in_str (:obj:`str`, optional): Unused.
            config (:obj:`FluffConfig`): Required config object.

        Returns:
            A ``(TemplatedFile, [])`` tuple.

        Raises:
            ValueError: If `config` or `fname` is missing, or `fname` is
                "stdin".
            RuntimeError: If no node in the dbt manifest matches `fname`.
            SQLTemplaterError: If dbt produced no compiled SQL.
        """
        if not config:
            raise ValueError(
                "For the dbt templater, the `process()` method requires a config object."
            )
        if not fname:
            raise ValueError(
                "For the dbt templater, the `process()` method requires a file name"
            )
        elif fname == "stdin":
            raise ValueError(
                "The dbt templater does not support stdin input, provide a path instead"
            )
        # Must be set before the cached dbt properties are first accessed,
        # because they read their settings from it.
        self.sqlfluff_config = config

        selected = self.dbt_selector_method.search(
            included_nodes=self.dbt_manifest.nodes,
            # Selector needs to be a relative path
            selector=os.path.relpath(fname, start=os.getcwd()),
        )
        results = [self.dbt_manifest.expect(uid) for uid in selected]

        if not results:
            raise RuntimeError("File %s was not found in dbt project" % fname)

        # Only the first matching node is compiled.
        node = self.dbt_compiler.compile_node(
            node=results[0],
            manifest=self.dbt_manifest,
        )

        if hasattr(node, "injected_sql"):
            # If injected SQL is present, it contains a better picture
            # of what will actually hit the database (e.g. with tests).
            # However it's not always present.
            compiled_sql = node.injected_sql
        else:
            compiled_sql = node.compiled_sql

        if not compiled_sql:
            raise SQLTemplaterError(
                "dbt templater compilation failed silently, check your configuration "
                "by running `dbt compile` directly.")

        raw_sliced, sliced_file, templated_sql = self.slice_file(node.raw_sql,
                                                                 compiled_sql,
                                                                 config=config)
        return (
            TemplatedFile(
                source_str=node.raw_sql,
                templated_str=templated_sql,
                fname=fname,
                sliced_file=sliced_file,
                raw_sliced=raw_sliced,
            ),
            # No violations returned in this way.
            [],
        )
Esempio n. 19
0
class DbtTemplater(JinjaTemplater):
    """A templater using dbt.

    A model is compiled with dbt's own compiler; the raw and compiled SQL are
    then passed to ``slice_file`` (not defined in this class, presumably
    inherited from ``JinjaTemplater``) to build a ``TemplatedFile``.

    The dbt objects (version, config, compiler, manifest, selector) are
    exposed as cached properties, so each is only loaded on first access.
    """

    name = "dbt"
    # A compilation error is reported as fatal once the number of
    # consecutive failures exceeds this limit.
    sequential_fail_limit = 3

    def __init__(self, **kwargs):
        # The following four are (re)assigned on every `process()` call.
        self.sqlfluff_config = None
        self.formatter = None
        self.project_dir = None
        self.profiles_dir = None
        # Directory restored after `process()` has changed into the project.
        self.working_dir = os.getcwd()
        # Consecutive dbt compilation failures; reset on success.
        self._sequential_fails = 0
        super().__init__(**kwargs)

    def config_pairs(self):
        """Returns info about the given templater for output by the cli."""
        return [("templater", self.name), ("dbt", self.dbt_version)]

    @cached_property
    def dbt_version(self):
        """Gets the dbt version as a version string."""
        from dbt.version import get_installed_version

        return get_installed_version().to_version_string()

    @cached_property
    def dbt_config(self):
        """Loads the dbt config and registers the adapter for it."""
        from dbt.config.runtime import RuntimeConfig as DbtRuntimeConfig
        from dbt.adapters.factory import register_adapter

        config = DbtRuntimeConfig.from_args(
            DbtConfigArgs(
                project_dir=self.project_dir,
                profiles_dir=self.profiles_dir,
                profile=self._get_profile(),
            )
        )
        register_adapter(config)
        return config

    @cached_property
    def dbt_compiler(self):
        """Loads the dbt compiler."""
        from dbt.compilation import Compiler as DbtCompiler

        return DbtCompiler(self.dbt_config)

    @cached_property
    def dbt_manifest(self):
        """Loads the dbt manifest."""
        # Identity function used for macro hooks
        def identity(x):
            return x

        # Set dbt not to run tracking. We don't load
        # a full project and so some tracking routines
        # may fail.
        from dbt.tracking import do_not_track

        do_not_track()

        if "0.17" in self.dbt_version:
            # In 0.17 the macro manifest loader has a different name and is
            # called without a macro hook.
            from dbt.parser.manifest import (
                load_internal_manifest as load_macro_manifest,
                load_manifest,
            )
        else:
            from dbt.parser.manifest import (
                load_macro_manifest,
                load_manifest,
            )

            load_macro_manifest = partial(load_macro_manifest, macro_hook=identity)

        dbt_macros_manifest = load_macro_manifest(self.dbt_config)
        return load_manifest(
            self.dbt_config, dbt_macros_manifest, macro_hook=identity
        )

    @cached_property
    def dbt_selector_method(self):
        """Loads the dbt selector method used to find nodes by path.

        When a formatter is available, a header is dispatched before and
        after loading, since building the selector loads the manifest.
        """
        if self.formatter:
            self.formatter.dispatch_compilation_header(
                "dbt templater", "Compiling dbt project..."
            )

        if "0.17" in self.dbt_version:
            from dbt.graph.selector import PathSelector

            selector_method = PathSelector(self.dbt_manifest)
        else:
            from dbt.graph.selector_methods import (
                MethodManager as DbtSelectorMethodManager,
                MethodName as DbtMethodName,
            )

            selector_methods_manager = DbtSelectorMethodManager(
                self.dbt_manifest, previous_state=None
            )
            selector_method = selector_methods_manager.get_method(
                DbtMethodName.Path, method_arguments=[]
            )

        if self.formatter:
            self.formatter.dispatch_compilation_header(
                "dbt templater", "Project Compiled."
            )

        return selector_method

    def _get_profiles_dir(self):
        """Get the dbt profiles directory from the configuration.

        The default is `~/.dbt` in 0.17 but we use the
        PROFILES_DIR variable from the dbt library to
        support a change of default in the future, as well
        as to support the same overwriting mechanism as
        dbt (currently an environment variable).

        A missing directory is only logged as an error; the absolute path is
        returned either way.
        """
        from dbt.config.profile import PROFILES_DIR

        dbt_profiles_dir = os.path.abspath(
            os.path.expanduser(
                self.sqlfluff_config.get_section(
                    (self.templater_selector, self.name, "profiles_dir")
                )
                or PROFILES_DIR
            )
        )

        if not os.path.exists(dbt_profiles_dir):
            templater_logger.error(
                f"dbt_profiles_dir: {dbt_profiles_dir} could not be accessed. Check it exists."
            )

        return dbt_profiles_dir

    def _get_project_dir(self):
        """Get the dbt project directory from the configuration.

        Defaults to the working directory. A missing directory is only
        logged as an error; the absolute path is returned either way.
        """
        dbt_project_dir = os.path.abspath(
            os.path.expanduser(
                self.sqlfluff_config.get_section(
                    (self.templater_selector, self.name, "project_dir")
                )
                or os.getcwd()
            )
        )
        if not os.path.exists(dbt_project_dir):
            templater_logger.error(
                f"dbt_project_dir: {dbt_project_dir} could not be accessed. Check it exists."
            )

        return dbt_project_dir

    def _get_profile(self):
        """Get a dbt profile name from the configuration."""
        return self.sqlfluff_config.get_section(
            (self.templater_selector, self.name, "profile")
        )

    @staticmethod
    def _check_dbt_installed():
        """Raise a helpful ``ModuleNotFoundError`` if dbt cannot be imported."""
        try:
            import dbt  # noqa: F401
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Module dbt was not found while trying to use dbt templating, "
                "please install dbt dependencies through `pip install sqlfluff[dbt]`"
            ) from e

    @staticmethod
    def _validate_process_args(fname, config):
        """Raise ``ValueError`` unless a config and a real file path are given."""
        if not config:
            raise ValueError(
                "For the dbt templater, the `process()` method requires a config object."
            )
        if not fname:
            raise ValueError(
                "For the dbt templater, the `process()` method requires a file name"
            )
        elif fname == "stdin":
            raise ValueError(
                "The dbt templater does not support stdin input, provide a path instead"
            )

    def process(self, *, fname, in_str=None, config=None, formatter=None):
        """Compile a dbt model and return the compiled SQL.

        Args:
            fname (:obj:`str`): Path to dbt model(s)
            in_str (:obj:`str`, optional): This is ignored for dbt
            config (:obj:`FluffConfig`, optional): A specific config to use for this
                templating operation. Only necessary for some templaters.
            formatter (:obj:`CallbackFormatter`): Optional object for output.

        Returns:
            A ``(templated_file, violations)`` tuple. On success the list of
            violations is empty; when dbt fails to compile or to connect, or
            a ``SQLTemplaterError`` is raised, the file is ``None`` and the
            list holds a single ``SQLTemplaterError``.

        Raises:
            ModuleNotFoundError: If dbt is not installed.
            ValueError: If `config` or `fname` is missing, or `fname` is
                "stdin".
            RuntimeError: Propagated from ``_unsafe_process``.
        """
        # Stash the formatter if provided to use in cached methods.
        self.formatter = formatter

        self._check_dbt_installed()
        from dbt.exceptions import (
            CompilationException as DbtCompilationException,
            FailedToConnectException as DbtFailedToConnectException,
        )

        # Validate before touching the arguments: the directory lookups below
        # need a config, and making the name absolute would otherwise hide
        # the "stdin" marker from the later check.
        self._validate_process_args(fname, config)

        self.sqlfluff_config = config
        self.project_dir = self._get_project_dir()
        self.profiles_dir = self._get_profiles_dir()
        # Resolved before changing directory so that a relative `fname` is
        # interpreted against the caller's working directory.
        fname_absolute_path = os.path.abspath(fname)

        try:
            os.chdir(self.project_dir)
            processed_result = self._unsafe_process(fname_absolute_path, in_str, config)
            # Reset the fail counter
            self._sequential_fails = 0
            return processed_result
        except DbtCompilationException as e:
            # Increment the counter
            self._sequential_fails += 1
            # dbt does not always attach a node to the exception; fall back
            # to the requested file rather than failing on a `None` node.
            failed_node = getattr(e, "node", None)
            failed_path = (
                failed_node.original_file_path
                if failed_node is not None
                else fname
            )
            return None, [
                SQLTemplaterError(
                    f"dbt compilation error on file '{failed_path}', {e.msg}",
                    # It's fatal if we're over the limit
                    fatal=self._sequential_fails > self.sequential_fail_limit,
                )
            ]
        except DbtFailedToConnectException as e:
            return None, [
                SQLTemplaterError(
                    "dbt tried to connect to the database and failed: "
                    "you could use 'execute' https://docs.getdbt.com/reference/dbt-jinja-functions/execute/ "
                    f"to skip the database calls. Error: {e.msg}",
                    fatal=True,
                )
            ]
        # If a SQLFluff error is raised, just pass it through
        except SQLTemplaterError as e:
            return None, [e]
        finally:
            # Always go back to the directory recorded at construction time.
            os.chdir(self.working_dir)

    def _unsafe_process(self, fname, in_str=None, config=None):
        """Compile the model at `fname` without translating dbt exceptions.

        Args:
            fname (:obj:`str`): Path to the dbt model; "stdin" is rejected.
            in_str (:obj:`str`, optional): Unused.
            config (:obj:`FluffConfig`): Required config object.

        Returns:
            A ``(TemplatedFile, [])`` tuple.

        Raises:
            ValueError: If `config` or `fname` is missing, or `fname` is
                "stdin".
            SQLTemplaterSkipFile: If the file belongs to a disabled model.
            RuntimeError: If no node in the dbt manifest matches `fname`.
            SQLTemplaterError: If dbt produced no compiled SQL.
        """
        self._validate_process_args(fname, config)

        selected = self.dbt_selector_method.search(
            included_nodes=self.dbt_manifest.nodes,
            # Selector needs to be a relative path
            selector=os.path.relpath(fname, start=os.getcwd()),
        )
        results = [self.dbt_manifest.expect(uid) for uid in selected]

        if not results:
            # Tell a disabled model (skipped) apart from a file that is
            # simply not part of the project (error).
            model_name = os.path.splitext(os.path.basename(fname))[0]
            disabled_model = self.dbt_manifest.find_disabled_by_name(name=model_name)
            if disabled_model and os.path.abspath(
                disabled_model.original_file_path
            ) == os.path.abspath(fname):
                raise SQLTemplaterSkipFile(
                    f"Skipped file {fname} because the model was disabled"
                )
            raise RuntimeError("File %s was not found in dbt project" % fname)

        # Only the first matching node is compiled.
        node = self.dbt_compiler.compile_node(
            node=results[0],
            manifest=self.dbt_manifest,
        )

        if hasattr(node, "injected_sql"):
            # If injected SQL is present, it contains a better picture
            # of what will actually hit the database (e.g. with tests).
            # However it's not always present.
            compiled_sql = node.injected_sql
        else:
            compiled_sql = node.compiled_sql

        if not compiled_sql:
            raise SQLTemplaterError(
                "dbt templater compilation failed silently, check your configuration "
                "by running `dbt compile` directly."
            )

        with open(fname) as source_dbt_model:
            source_dbt_sql = source_dbt_model.read()

        n_trailing_newlines = len(source_dbt_sql) - len(source_dbt_sql.rstrip("\n"))

        templater_logger.debug(
            "    Trailing newline count in source dbt model: %r", n_trailing_newlines
        )
        templater_logger.debug("    Raw SQL before compile: %r", source_dbt_sql)
        templater_logger.debug("    Node raw SQL: %r", node.raw_sql)
        templater_logger.debug("    Node compiled SQL: %r", compiled_sql)

        # When using dbt-templater, trailing newlines are ALWAYS REMOVED during
        # compiling. Unless fixed (like below), this will cause:
        #    1. L009 linting errors when running "sqlfluff lint foo_bar.sql"
        #       since the linter will use the compiled code with the newlines
        #       removed.
        #    2. "No newline at end of file" warnings in Git/GitHub since
        #       sqlfluff uses the compiled SQL to write fixes back to the
        #       source SQL in the dbt model.
        # The solution is:
        #    1. Check for trailing newlines before compiling by looking at the
        #       raw SQL in the source dbt file, store the count of trailing newlines.
        #    2. Append the count from #1 above to the node.raw_sql and
        #       compiled_sql objects, both of which have had the trailing
        #       newlines removed by the dbt-templater.
        node.raw_sql = node.raw_sql + "\n" * n_trailing_newlines
        compiled_sql = compiled_sql + "\n" * n_trailing_newlines

        raw_sliced, sliced_file, templated_sql = self.slice_file(
            node.raw_sql,
            compiled_sql,
            config=config,
        )

        return (
            TemplatedFile(
                source_str=node.raw_sql,
                templated_str=templated_sql,
                fname=fname,
                sliced_file=sliced_file,
                raw_sliced=raw_sliced,
            ),
            # No violations returned in this way.
            [],
        )
Esempio n. 20
0
    def run(self):
        """Compile the project and print the model and analysis counts."""
        project_compiler = Compiler(self.project, BaseCreateTemplate)
        project_compiler.initialize()
        # The compiler is expected to return exactly two counts here.
        n_models, n_analyses = project_compiler.compile()

        summary = "Compiled {} models and {} analyses".format(n_models, n_analyses)
        print(summary)
Esempio n. 21
0
    def compile(self):
        """Compile the project with TestCreateTemplate and return the compiler.

        The result of the compile call itself is discarded; callers get the
        initialised compiler object back.
        """
        test_compiler = Compiler(self.project, TestCreateTemplate)
        test_compiler.initialize()
        test_compiler.compile()

        return test_compiler
Esempio n. 22
0
def new_test_compiler():
    """Return a Compiler for a new project, using DryCreateTemplate.

    The compiler is returned as constructed; it is not initialised here.
    """
    return Compiler(new_project(), DryCreateTemplate)