예제 #1
0
def test__cli__formatters__violation(tmpdir):
    """Check the rendering of one violation via OutputStreamFormatter.

    NB: the reported position is 1 + start_pos.
    """
    source = TemplatedFile.from_string("      \n\n  foobarbar")
    marker = PositionMarker(slice(10, 19), slice(10, 19), source)
    segment = RawSegment("foobarbar", marker)
    violation = SQLLintError(segment=segment, rule=RuleGhost("A", "DESC"))
    output = FileOutput(FluffConfig(require_dialect=False), str(tmpdir / "out.txt"))
    formatter = OutputStreamFormatter(output, False)
    rendered = formatter.format_violation(violation)
    # Expect line 3, position 3: "foobarbar" follows two newlines (so it sits
    # on the third line) and has two characters before it on that line.
    assert escape_ansi(rendered) == "L:   3 | P:   3 |    A | DESC"
예제 #2
0
    def _unsafe_process(self, fname, in_str=None, config=None):
        """Compile the dbt model at ``fname`` and wrap it in a TemplatedFile.

        ``in_str`` is accepted for interface compatibility but is not used;
        the raw SQL is taken from the matching dbt manifest node.

        Returns:
            A ``(TemplatedFile, violations)`` tuple; violations is always empty.

        Raises:
            ValueError: if ``config`` or ``fname`` is missing, or ``fname`` is
                ``"stdin"``.
            RuntimeError: if no manifest node matches ``fname``.
            SQLTemplaterError: if dbt produced empty compiled SQL.
        """
        if not config:
            raise ValueError(
                "For the dbt templater, the `process()` method requires a config object."
            )
        if not fname:
            raise ValueError(
                "For the dbt templater, the `process()` method requires a file name"
            )
        if fname == "stdin":
            raise ValueError(
                "The dbt templater does not support stdin input, provide a path instead"
            )
        self.sqlfluff_config = config

        # dbt selectors match on paths relative to the working directory.
        relative_path = os.path.relpath(fname, start=os.getcwd())
        matching_ids = self.dbt_selector_method.search(
            included_nodes=self.dbt_manifest.nodes,
            selector=relative_path,
        )
        matching_nodes = [self.dbt_manifest.expect(node_id) for node_id in matching_ids]
        if not matching_nodes:
            raise RuntimeError("File %s was not found in dbt project" % fname)

        compiled_node = self.dbt_compiler.compile_node(
            node=matching_nodes[0],
            manifest=self.dbt_manifest,
        )

        # Prefer injected SQL when present (e.g. with tests applied): it gives
        # a better picture of what actually reaches the database.
        compiled_sql = (
            compiled_node.injected_sql
            if hasattr(compiled_node, "injected_sql")
            else compiled_node.compiled_sql
        )
        if not compiled_sql:
            raise SQLTemplaterError(
                "dbt templater compilation failed silently, check your configuration "
                "by running `dbt compile` directly.")

        raw_sliced, sliced_file, templated_sql = self.slice_file(
            compiled_node.raw_sql, compiled_sql, config=config
        )
        templated_file = TemplatedFile(
            source_str=compiled_node.raw_sql,
            templated_str=templated_sql,
            fname=fname,
            sliced_file=sliced_file,
            raw_sliced=raw_sliced,
        )
        # This templater never reports violations through the return value.
        return templated_file, []
예제 #3
0
 def _raw_slices_from_templated_slices(
     templated_file: TemplatedFile,
     templated_slices: List[slice],
     file_end_slice: Optional[RawFileSlice] = None,
 ) -> Set[RawFileSlice]:
     """Return the set of raw slices that back the given templated slices.

     Each templated slice is converted to a source slice, and every raw slice
     spanning that source range is added to the result.

     Mapping can raise ``IndexError``/``ValueError`` (per the original notes,
     with "create_before" at the start of the file or "create_after" at the
     end). Such slices are ignored unless ``file_end_slice`` is given, in
     which case it is added as a placeholder/sentinel instead.
     """
     collected: Set[RawFileSlice] = set()
     for tslice in templated_slices:
         try:
             source_slice = templated_file.templated_slice_to_source_slice(tslice)
             spanning = templated_file.raw_slices_spanning_source_slice(source_slice)
             collected.update(spanning)
         except (IndexError, ValueError):
             if file_end_slice is not None:
                 collected.add(file_end_slice)
     return collected
예제 #4
0
    def process(self,
                *,
                in_str: str,
                fname: Optional[str] = None,
                config=None,
                formatter=None) -> Tuple[Optional[TemplatedFile], list]:
        """Process a string and return a TemplatedFile.

        Note that the arguments are enforced as keywords
        because Templaters can have differences in their
        `process` method signature.
        A Templater that only supports reading from a file
        would need the following signature:
            process(*, fname, in_str=None, config=None)
        (arguments are swapped)

        Args:
            in_str (:obj:`str`): The input string.
            fname (:obj:`str`, optional): The filename of this string. This is
                mostly for loading config files at runtime.
            config (:obj:`FluffConfig`): A specific config to use for this
                templating operation. Only necessary for some templaters.
            formatter (:obj:`CallbackFormatter`): Optional object for output.

        Returns:
            A ``(TemplatedFile, violations)`` tuple. ``violations`` is always
            empty for this templater.

        Raises:
            SQLTemplaterError: if ``in_str`` references a name that is missing
                from the templating context (a ``KeyError`` from
                ``str.format``); the original error is chained as the cause.
        """
        live_context = self.get_context(fname=fname, config=config)
        try:
            new_str = in_str.format(**live_context)
        except KeyError as err:
            # TODO: Add a url here so people can get more help.
            raise SQLTemplaterError(
                "Failure in Python templating: {0}. Have you configured your variables?"
                .format(err)) from err
        raw_sliced, sliced_file, new_str = self.slice_file(in_str,
                                                           new_str,
                                                           config=config)
        return (
            TemplatedFile(
                source_str=in_str,
                templated_str=new_str,
                fname=fname,
                sliced_file=sliced_file,
                raw_sliced=raw_sliced,
            ),
            [],
        )
예제 #5
0
def test__cli__formatters__violation():
    """Check the plain-text rendering of a single lint violation.

    NB: the reported position is 1 + start_pos.
    """
    templated = TemplatedFile.from_string("      \n\n  foobarbar")
    segment = RawSegment(
        "foobarbar", PositionMarker(slice(10, 19), slice(10, 19), templated)
    )
    violation = SQLLintError(segment=segment, rule=RuleGhost("A", "DESC"))
    rendered = format_violation(violation)
    # "foobarbar" is on line 3 (two newlines precede it) at position 3
    # (two characters precede it on that line).
    assert escape_ansi(rendered) == "L:   3 | P:   3 |    A | DESC"
예제 #6
0
def test__templater_jinja_slice_file(raw_file, override_context, result, caplog):
    """Test slice_file against the expected list of (type, source, templated)."""
    templater = JinjaTemplater(override_context=override_context)
    _env, _live_context, make_template = templater.template_builder(
        config=FluffConfig.from_path(
            "test/fixtures/templater/jinja_slice_template_macros"
        )
    )

    rendered = make_template(raw_file).render()
    with caplog.at_level(logging.DEBUG, logger="sqlfluff.templater"):
        raw_sliced, sliced_file, templated_str = templater.slice_file(
            raw_file, rendered, make_template=make_template
        )
    # Constructing a TemplatedFile runs its built-in sanity checks on the
    # slices; the object itself is not needed.
    _ = TemplatedFile(raw_file, "<<DUMMY>>", templated_str, sliced_file, raw_sliced)

    # The templated slices must be contiguous.
    print(sliced_file)
    previous = None
    for file_slice in sliced_file:
        print(file_slice)
        if previous:
            assert file_slice.templated_slice.start == previous.stop
        previous = file_slice.templated_slice

    # Every literal slice must map back to a source slice.
    for file_slice in sliced_file:
        if file_slice.slice_type == "literal":
            assert file_slice.source_slice is not None

    actual = [
        (fs.slice_type, fs.source_slice, fs.templated_slice)
        for fs in sliced_file
    ]
    assert actual == result
예제 #7
0
        )

    e.match("Rule classes must be named in the format of")


def test_rule_set_return_informative_error_when_rule_not_registered():
    """Assert that requesting an unregistered rule code raises an informative error.

    ``get_rule_from_set`` should raise ``ValueError`` whose message names the
    missing rule code.
    """
    cfg = FluffConfig(overrides={"dialect": "ansi"})
    with pytest.raises(ValueError) as e:
        get_rule_from_set("L000", config=cfg)

    e.match("'L000' not in")


# A single-space whitespace segment at the start of a one-character file,
# shared as a fixture by the parametrized LintResult cases below.
seg = WhitespaceSegment(
    pos_marker=PositionMarker(
        slice(0, 1),
        slice(0, 1),
        TemplatedFile(" ", fname="<str>"),
    )
)


@pytest.mark.parametrize(
    "lint_result, expected",
    [
        (LintResult(), "LintResult(<empty>)"),
        (LintResult(seg),
         "LintResult(<WhitespaceSegment: ([L:  1, P:  1]) ' '>)"),
        (
            LintResult(seg, description="foo"),
            "LintResult(foo: <WhitespaceSegment: ([L:  1, P:  1]) ' '>)",
        ),
        (
            LintResult(
                seg,
예제 #8
0
    def _unsafe_process(self, fname, in_str=None, config=None):
        """Compile the dbt model at ``fname`` and wrap the result in a TemplatedFile.

        ``in_str`` is accepted for interface compatibility but is not used:
        the raw SQL comes from the dbt manifest node, and the file on disk is
        read only to count its trailing newlines.

        Returns:
            A ``(TemplatedFile, violations)`` tuple; violations is always empty.

        Raises:
            ValueError: if ``config`` or ``fname`` is missing, or ``fname`` is
                ``"stdin"``.
            SQLTemplaterSkipFile: if ``fname`` belongs to a disabled model.
            RuntimeError: if ``fname`` does not match any model in the project.
            SQLTemplaterError: if dbt returned empty compiled SQL.
        """
        if not config:
            raise ValueError(
                "For the dbt templater, the `process()` method requires a config object."
            )
        if not fname:
            raise ValueError(
                "For the dbt templater, the `process()` method requires a file name"
            )
        elif fname == "stdin":
            raise ValueError(
                "The dbt templater does not support stdin input, provide a path instead"
            )
        selected = self.dbt_selector_method.search(
            included_nodes=self.dbt_manifest.nodes,
            # Selector needs to be a relative path
            selector=os.path.relpath(fname, start=os.getcwd()),
        )
        results = [self.dbt_manifest.expect(uid) for uid in selected]

        if not results:
            # Distinguish "model is disabled" (skip quietly) from "file is
            # not part of the project at all" (hard error).
            model_name = os.path.splitext(os.path.basename(fname))[0]
            disabled_model = self.dbt_manifest.find_disabled_by_name(name=model_name)
            if disabled_model and os.path.abspath(
                disabled_model.original_file_path
            ) == os.path.abspath(fname):
                raise SQLTemplaterSkipFile(
                    f"Skipped file {fname} because the model was disabled"
                )
            raise RuntimeError("File %s was not found in dbt project" % fname)

        node = self.dbt_compiler.compile_node(
            node=results[0],
            manifest=self.dbt_manifest,
        )

        if hasattr(node, "injected_sql"):
            # If injected SQL is present, it contains a better picture
            # of what will actually hit the database (e.g. with tests).
            # However it's not always present.
            compiled_sql = node.injected_sql
        else:
            compiled_sql = node.compiled_sql

        if not compiled_sql:
            raise SQLTemplaterError(
                "dbt templater compilation failed silently, check your configuration "
                "by running `dbt compile` directly."
            )

        with open(fname) as source_dbt_model:
            source_dbt_sql = source_dbt_model.read()

        if source_dbt_sql.rstrip().endswith("-%}"):
            # The file ends with a right-whitespace-stripping tag, so Jinja
            # removes any trailing newlines regardless of dbt's behaviour.
            # Restoring them would make the raw and compiled SQL disagree.
            n_trailing_newlines = 0
        else:
            n_trailing_newlines = len(source_dbt_sql) - len(source_dbt_sql.rstrip("\n"))

        templater_logger.debug(
            "    Trailing newline count in source dbt model: %r", n_trailing_newlines
        )
        templater_logger.debug("    Raw SQL before compile: %r", source_dbt_sql)
        templater_logger.debug("    Node raw SQL: %r", node.raw_sql)
        templater_logger.debug("    Node compiled SQL: %r", compiled_sql)

        # When using dbt-templater, trailing newlines are ALWAYS REMOVED during
        # compiling. Unless fixed (like below), this will cause:
        #    1. L009 linting errors when running "sqlfluff lint foo_bar.sql"
        #       since the linter will use the compiled code with the newlines
        #       removed.
        #    2. "No newline at end of file" warnings in Git/GitHub since
        #       sqlfluff uses the compiled SQL to write fixes back to the
        #       source SQL in the dbt model.
        # The solution is:
        #    1. Check for trailing newlines before compiling by looking at the
        #       raw SQL in the source dbt file, store the count of trailing newlines.
        #    2. Append that count of newlines to local copies of the raw and
        #       compiled SQL (both stripped by dbt). Locals are used so the
        #       compiled node object itself is not mutated.
        raw_sql = node.raw_sql + "\n" * n_trailing_newlines
        compiled_sql = compiled_sql + "\n" * n_trailing_newlines

        raw_sliced, sliced_file, templated_sql = self.slice_file(
            raw_sql,
            compiled_sql,
            config=config,
        )

        return (
            TemplatedFile(
                source_str=raw_sql,
                templated_str=templated_sql,
                fname=fname,
                sliced_file=sliced_file,
                raw_sliced=raw_sliced,
            ),
            # No violations returned in this way.
            [],
        )
예제 #9
0
파일: jinja.py 프로젝트: sti0/sqlfluff
    def process(self,
                *,
                in_str: str,
                fname: str,
                config=None,
                formatter=None) -> Tuple[Optional[TemplatedFile], list]:
        """Render a Jinja template string and return a TemplatedFile plus violations.

        Note that the arguments are enforced as keywords
        because Templaters can have differences in their
        `process` method signature.
        A Templater that only supports reading from a file
        would need the following signature:
            process(*, fname, in_str=None, config=None)
        (arguments are swapped)

        Args:
            in_str (:obj:`str`): The input string.
            fname (:obj:`str`, optional): The filename of this string. This is
                mostly for loading config files at runtime.
            config (:obj:`FluffConfig`): A specific config to use for this
                templating operation. Only necessary for some templaters.
            formatter (:obj:`CallbackFormatter`): Optional object for output.

        Returns:
            A ``(templated_file, violations)`` tuple.
            - If the template fails to parse, ``templated_file`` wraps the
              untemplated source and ``violations`` holds the parse error.
            - If rendering fails unrecoverably, ``templated_file`` is ``None``.
            - Otherwise ``violations`` lists any undefined-variable problems.

        Raises:
            ValueError: if no ``config`` is supplied.
            SQLTemplaterError: if identifying undeclared variables fails; the
                underlying error is chained as the cause.
        """
        if not config:  # pragma: no cover
            raise ValueError(
                "For the jinja templater, the `process()` method requires a config "
                "object.")

        env, live_context, make_template = self.template_builder(fname=fname,
                                                                 config=config)

        # Load the template, passing the global context.
        try:
            template = make_template(in_str)
        except TemplateSyntaxError as err:
            # Something in the template didn't parse, return the original
            # and a violation around what happened.
            return (
                TemplatedFile(source_str=in_str, fname=fname),
                [
                    SQLTemplaterError(
                        f"Failure to parse jinja template: {err}.",
                        line_no=err.lineno,
                    )
                ],
            )

        violations = []

        # Attempt to identify any undeclared variables. The majority
        # will be found during the _crawl_tree step rather than this
        # first Exception which serves only to catch catastrophic errors.
        try:
            syntax_tree = env.parse(in_str)
            undefined_variables = meta.find_undeclared_variables(syntax_tree)
        except Exception as err:  # pragma: no cover
            # TODO: Add a url here so people can get more help.
            raise SQLTemplaterError(
                f"Failure in identifying Jinja variables: {err}.") from err

        # Get rid of any that *are* actually defined.
        undefined_variables.difference_update(live_context)

        if undefined_variables:
            # Lets go through and find out where they are:
            violations.extend(
                self._crawl_tree(syntax_tree, undefined_variables, in_str))

        try:
            # NB: Passing no context. Everything is loaded when the template is loaded.
            out_str = template.render()
            # Slice the file once rendered.
            raw_sliced, sliced_file, out_str = self.slice_file(
                in_str,
                out_str,
                config=config,
                make_template=make_template,
            )
            return (
                TemplatedFile(
                    source_str=in_str,
                    templated_str=out_str,
                    fname=fname,
                    sliced_file=sliced_file,
                    raw_sliced=raw_sliced,
                ),
                violations,
            )
        except (TemplateError, TypeError) as err:
            templater_logger.info("Unrecoverable Jinja Error: %s", err)
            violations.append(
                SQLTemplaterError(
                    ("Unrecoverable failure in Jinja templating: {}. Have you "
                     "configured your variables? "
                     "https://docs.sqlfluff.com/en/latest/configuration.html"
                     ).format(err),
                    # We don't have actual line number information, but specify
                    # line 1 so users can ignore with "noqa" if they want. (The
                    # default is line 0, which can't be ignored because it's not
                    # a valid line number.)
                    line_no=1,
                    line_pos=1,
                ))
            return None, violations
예제 #10
0
    def process(
        self, *, in_str: str, fname: Optional[str] = None, config=None
    ) -> Tuple[Optional[TemplatedFile], list]:
        """Render a Jinja template string and return a TemplatedFile plus violations.

        Note that the arguments are enforced as keywords
        because Templaters can have differences in their
        `process` method signature.
        A Templater that only supports reading from a file
        would need the following signature:
            process(*, fname, in_str=None, config=None)
        (arguments are swapped)

        Args:
            in_str (:obj:`str`): The input string.
            fname (:obj:`str`, optional): The filename of this string. This is
                mostly for loading config files at runtime.
            config (:obj:`FluffConfig`): A specific config to use for this
                templating operation. Only necessary for some templaters.

        Returns:
            A ``(templated_file, violations)`` tuple.
            - If the template fails to parse, ``templated_file`` wraps the
              untemplated source and ``violations`` holds the parse error.
            - If rendering fails unrecoverably, ``templated_file`` is ``None``.
            - Otherwise ``violations`` lists any undefined-variable problems.

        Raises:
            ValueError: if no ``config`` is supplied.
            SQLTemplaterError: if identifying undeclared variables fails; the
                underlying error is chained as the cause.
        """
        if not config:
            raise ValueError(
                "For the jinja templater, the `process()` method requires a config object."
            )

        # Load the context
        live_context = self.get_context(fname=fname, config=config)
        # Apply dbt builtin functions if we're allowed.
        apply_dbt_builtins = config.get_section(
            (self.templater_selector, self.name, "apply_dbt_builtins")
        )
        if apply_dbt_builtins:
            # This feels a bit wrong defining these here, they should probably
            # be configurable somewhere sensible. But for now they're not.
            # TODO: Come up with a better solution.
            dbt_builtins = self._generate_dbt_builtins()
            for name in dbt_builtins:
                # Only apply if it hasn't already been set at this stage.
                if name not in live_context:
                    live_context[name] = dbt_builtins[name]

        env = self._get_jinja_env()

        # Load macros from path (if applicable)
        macros_path = config.get_section(
            (self.templater_selector, self.name, "load_macros_from_path")
        )
        if macros_path:
            live_context.update(
                self._extract_macros_from_path(macros_path, env=env, ctx=live_context)
            )

        # Load config macros, these will take precedence over macros from the path
        live_context.update(
            self._extract_macros_from_config(config=config, env=env, ctx=live_context)
        )

        live_context.update(self._extract_libraries_from_config(config=config))

        # Load the template, passing the global context.
        try:
            template = env.from_string(in_str, globals=live_context)
        except TemplateSyntaxError as err:
            # Something in the template didn't parse, return the original
            # and a violation around what happened.
            return (
                TemplatedFile(source_str=in_str, fname=fname),
                [
                    SQLTemplaterError(
                        "Failure to parse jinja template: {0}.".format(err),
                        pos=FilePositionMarker(
                            None,
                            err.lineno,
                            None,
                            # Calculate the charpos for sorting: the offset of
                            # the start of line ``err.lineno``. ``split`` drops
                            # the "\n" separators, so add 1 per preceding line.
                            sum(
                                len(line) + 1
                                for line in in_str.split("\n")[: err.lineno - 1]
                            ),
                        ),
                    )
                ],
            )

        violations = []

        # Attempt to identify any undeclared variables. The majority
        # will be found during the _crawl_tree step rather than this
        # first Exception which serves only to catch catastrophic errors.
        try:
            syntax_tree = env.parse(in_str)
            undefined_variables = meta.find_undeclared_variables(syntax_tree)
        except Exception as err:
            # TODO: Add a url here so people can get more help.
            raise SQLTemplaterError(
                "Failure in identifying Jinja variables: {0}.".format(err)
            ) from err

        # Get rid of any that *are* actually defined.
        undefined_variables.difference_update(live_context)

        if undefined_variables:
            # Lets go through and find out where they are:
            violations.extend(
                self._crawl_tree(syntax_tree, undefined_variables, in_str)
            )

        try:
            # NB: Passing no context. Everything is loaded when the template is loaded.
            out_str = template.render()
            # Slice the file once rendered.
            raw_sliced, sliced_file, out_str = self.slice_file(
                in_str, out_str, config=config
            )
            return (
                TemplatedFile(
                    source_str=in_str,
                    templated_str=out_str,
                    fname=fname,
                    sliced_file=sliced_file,
                    raw_sliced=raw_sliced,
                ),
                violations,
            )
        except (TemplateError, TypeError) as err:
            templater_logger.info("Unrecoverable Jinja Error: %s", err)
            violations.append(
                SQLTemplaterError(
                    (
                        "Unrecoverable failure in Jinja templating: {0}. Have you configured "
                        "your variables? https://docs.sqlfluff.com/en/latest/configuration.html"
                    ).format(err)
                )
            )
            return None, violations
예제 #11
0
    def process(self,
                *,
                in_str: str,
                fname: str,
                config=None,
                formatter=None) -> Tuple[Optional[TemplatedFile], list]:
        """Substitute bind parameters in a string and return a TemplatedFile.

        Note that the arguments are enforced as keywords
        because Templaters can have differences in their
        `process` method signature.
        A Templater that only supports reading from a file
        would need the following signature:
            process(*, fname, in_str=None, config=None)
        (arguments are swapped)

        Args:
            in_str (:obj:`str`): The input string.
            fname (:obj:`str`, optional): The filename of this string. This is
                mostly for loading config files at runtime.
            config (:obj:`FluffConfig`): A specific config to use for this
                templating operation. Only necessary for some templaters.
            formatter (:obj:`CallbackFormatter`): Optional object for output.

        Returns:
            A ``(TemplatedFile, violations)`` tuple; violations is always empty.
            The file alternates "literal" slices (copied text) and "templated"
            slices (one per regex match, replaced by the context value).

        Raises:
            SQLTemplaterError: if a matched parameter has no value in the
                context; the ``KeyError`` is chained as the cause.
        """
        context = self.get_context(config)
        template_slices = []
        raw_slices = []
        last_pos_raw, last_pos_templated = 0, 0
        # Pieces of the output string, joined once at the end.
        out_parts = []

        regex = context["__bind_param_regex"]
        # when the param has no name, use a 1-based index
        param_counter = 1
        for found_param in regex.finditer(in_str):
            span = found_param.span()
            if "param_name" not in found_param.groupdict():
                param_name = str(param_counter)
                param_counter += 1
            else:
                param_name = found_param["param_name"]
            literal = in_str[last_pos_raw:span[0]]
            last_literal_length = span[0] - last_pos_raw
            try:
                replacement = str(context[param_name])
            except KeyError as err:
                # TODO: Add a url here so people can get more help.
                raise SQLTemplaterError(
                    "Failure in placeholder templating: {}. Have you configured your "
                    "variables?".format(err)) from err
            # add the literal to the slices
            template_slices.append(
                TemplatedFileSlice(
                    slice_type="literal",
                    source_slice=slice(last_pos_raw, span[0], None),
                    templated_slice=slice(
                        last_pos_templated,
                        last_pos_templated + last_literal_length,
                        None,
                    ),
                ))
            raw_slices.append(
                RawFileSlice(
                    raw=literal,
                    slice_type="literal",
                    source_idx=last_pos_raw,
                ))
            out_parts.append(literal)
            # add the current replaced element
            start_template_pos = last_pos_templated + last_literal_length
            template_slices.append(
                TemplatedFileSlice(
                    slice_type="templated",
                    source_slice=slice(span[0], span[1], None),
                    templated_slice=slice(
                        start_template_pos,
                        start_template_pos + len(replacement), None),
                ))
            raw_slices.append(
                RawFileSlice(
                    raw=in_str[span[0]:span[1]],
                    slice_type="templated",
                    source_idx=span[0],
                ))
            out_parts.append(replacement)
            # update the indexes
            last_pos_raw = span[1]
            last_pos_templated = start_template_pos + len(replacement)
        # add the last literal, if any
        if len(in_str) > last_pos_raw:
            tail = in_str[last_pos_raw:]
            template_slices.append(
                TemplatedFileSlice(
                    slice_type="literal",
                    source_slice=slice(last_pos_raw, len(in_str), None),
                    templated_slice=slice(
                        last_pos_templated,
                        last_pos_templated + (len(in_str) - last_pos_raw),
                        None,
                    ),
                ))
            raw_slices.append(
                RawFileSlice(
                    raw=tail,
                    slice_type="literal",
                    source_idx=last_pos_raw,
                ))
            out_parts.append(tail)
        out_str = "".join(out_parts)
        return (
            TemplatedFile(
                # original string
                source_str=in_str,
                # string after all replacements
                templated_str=out_str,
                # filename
                fname=fname,
                # list of TemplatedFileSlice
                sliced_file=template_slices,
                # list of RawFileSlice, same size
                raw_sliced=raw_slices,
            ),
            [],  # violations, always empty
        )
예제 #12
0
    def _unsafe_process(self, fname, in_str=None, config=None):
        original_file_path = os.path.relpath(fname, start=os.getcwd())

        # Below, we monkeypatch Environment.from_string() to intercept when dbt
        # compiles (i.e. runs Jinja) to expand the "node" corresponding to fname.
        # We do this to capture the Jinja context at the time of compilation, i.e.:
        # - Jinja Environment object
        # - Jinja "globals" dictionary
        #
        # This info is captured by the "make_template()" function, which in
        # turn is used by our parent class' (JinjaTemplater) slice_file()
        # function.
        old_from_string = Environment.from_string
        make_template = None

        def from_string(*args, **kwargs):
            """Replaces (via monkeypatch) the jinja2.Environment function."""
            nonlocal make_template
            # Is it processing the node corresponding to fname?
            globals = kwargs.get("globals")
            if globals:
                model = globals.get("model")
                if model:
                    if model.get("original_file_path") == original_file_path:
                        # Yes. Capture the important arguments and create
                        # a make_template() function.
                        env = args[0]
                        globals = args[2] if len(args) >= 3 else kwargs["globals"]

                        def make_template(in_str):
                            env.add_extension(SnapshotExtension)
                            return env.from_string(in_str, globals=globals)

            return old_from_string(*args, **kwargs)

        node = self._find_node(fname, config)
        templater_logger.debug(
            "_find_node for path %r returned object of type %s.", fname, type(node)
        )

        save_ephemeral_nodes = dict(
            (k, v)
            for k, v in self.dbt_manifest.nodes.items()
            if v.config.materialized == "ephemeral"
            and not getattr(v, "compiled", False)
        )
        with self.connection():
            # Apply the monkeypatch.
            Environment.from_string = from_string
            try:
                node = self.dbt_compiler.compile_node(
                    node=node,
                    manifest=self.dbt_manifest,
                )
            except Exception as err:
                templater_logger.exception(
                    "Fatal dbt compilation error on %s. This occurs most often "
                    "during incorrect sorting of ephemeral models before linting. "
                    "Please report this error on github at "
                    "https://github.com/sqlfluff/sqlfluff/issues, including "
                    "both the raw and compiled sql for the model affected.",
                    fname,
                )
                # Additional error logging in case we get a fatal dbt error.
                raise SQLFluffSkipFile(  # pragma: no cover
                    f"Skipped file {fname} because dbt raised a fatal "
                    f"exception during compilation: {err!s}"
                ) from err
            finally:
                # Undo the monkeypatch.
                Environment.from_string = old_from_string

            if hasattr(node, "injected_sql"):
                # If injected SQL is present, it contains a better picture
                # of what will actually hit the database (e.g. with tests).
                # However it's not always present.
                compiled_sql = node.injected_sql
            else:
                compiled_sql = getattr(node, COMPILED_SQL_ATTRIBUTE)

            raw_sql = getattr(node, RAW_SQL_ATTRIBUTE)

            if not compiled_sql:  # pragma: no cover
                raise SQLTemplaterError(
                    "dbt templater compilation failed silently, check your "
                    "configuration by running `dbt compile` directly."
                )

            with open(fname) as source_dbt_model:
                source_dbt_sql = source_dbt_model.read()

            if not source_dbt_sql.rstrip().endswith("-%}"):
                n_trailing_newlines = len(source_dbt_sql) - len(
                    source_dbt_sql.rstrip("\n")
                )
            else:
                # Source file ends with right whitespace stripping, so there's
                # no need to preserve/restore trailing newlines, as they would
                # have been removed regardless of dbt's
                # keep_trailing_newlines=False behavior.
                n_trailing_newlines = 0

            templater_logger.debug(
                "    Trailing newline count in source dbt model: %r",
                n_trailing_newlines,
            )
            templater_logger.debug("    Raw SQL before compile: %r", source_dbt_sql)
            templater_logger.debug("    Node raw SQL: %r", raw_sql)
            templater_logger.debug("    Node compiled SQL: %r", compiled_sql)

            # When using dbt-templater, trailing newlines are ALWAYS REMOVED during
            # compiling. Unless fixed (like below), this will cause:
            #    1. Assertion errors in TemplatedFile, when it sanity checks the
            #       contents of the sliced_file array.
            #    2. L009 linting errors when running "sqlfluff lint foo_bar.sql"
            #       since the linter will use the compiled code with the newlines
            #       removed.
            #    3. "No newline at end of file" warnings in Git/GitHub since
            #       sqlfluff uses the compiled SQL to write fixes back to the
            #       source SQL in the dbt model.
            #
            # The solution is (note that both the raw and compiled files have
            # had trailing newline(s) removed by the dbt-templater.
            #    1. Check for trailing newlines before compiling by looking at the
            #       raw SQL in the source dbt file. Remember the count of trailing
            #       newlines.
            #    2. Set node.raw_sql/node.raw_code to the original source file contents.
            #    3. Append the count from #1 above to compiled_sql. (In
            #       production, slice_file() does not usually use this string,
            #       but some test scenarios do.
            setattr(node, RAW_SQL_ATTRIBUTE, source_dbt_sql)
            compiled_sql = compiled_sql + "\n" * n_trailing_newlines

            # TRICKY: dbt configures Jinja2 with keep_trailing_newline=False.
            # As documented (https://jinja.palletsprojects.com/en/3.0.x/api/),
            # this flag's behavior is: "Preserve the trailing newline when
            # rendering templates. The default is False, which causes a single
            # newline, if present, to be stripped from the end of the template."
            #
            # Below, we use "append_to_templated" to effectively "undo" this.
            raw_sliced, sliced_file, templated_sql = self.slice_file(
                source_dbt_sql,
                compiled_sql,
                config=config,
                make_template=make_template,
                append_to_templated="\n" if n_trailing_newlines else "",
            )
        # :HACK: If calling compile_node() compiled any ephemeral nodes,
        # restore them to their earlier state. This prevents a runtime error
        # in the dbt "_inject_ctes_into_sql()" function that occurs with
        # 2nd-level ephemeral model dependencies (e.g. A -> B -> C, where
        # both B and C are ephemeral). Perhaps there is a better way to do
        # this, but this seems good enough for now.
        for k, v in save_ephemeral_nodes.items():
            if getattr(self.dbt_manifest.nodes[k], "compiled", False):
                self.dbt_manifest.nodes[k] = v
        return (
            TemplatedFile(
                source_str=source_dbt_sql,
                templated_str=templated_sql,
                fname=fname,
                sliced_file=sliced_file,
                raw_sliced=raw_sliced,
            ),
            # No violations returned in this way.
            [],
        )
예제 #13
0
    def _unsafe_process(self, fname, in_str=None, config=None):
        """Compile the dbt model at ``fname`` and slice it against its source.

        The model node is located with ``self._find_node()``, compiled with
        ``self.dbt_compiler.compile_node()``, and the compiled SQL is mapped
        back onto the source file's contents with ``self.slice_file()``.

        Args:
            fname: Path of the dbt model file. It is also converted to a path
                relative to the current working directory, which is compared
                with the ``original_file_path`` of the model dbt renders, to
                recognise the node being compiled.
            in_str: Not used by this method; the source is re-read from
                ``fname``.
            config: Passed through to ``self._find_node()`` and
                ``self.slice_file()``.

        Returns:
            A ``(TemplatedFile, [])`` tuple. No violations are returned this
            way.

        Raises:
            SQLTemplaterError: If dbt produced empty compiled SQL.
            Any exception raised by ``compile_node()`` propagates unchanged,
            after the ``Environment.from_string`` monkeypatch has been undone.
        """
        original_file_path = os.path.relpath(fname, start=os.getcwd())

        # Below, we monkeypatch Environment.from_string() to intercept when dbt
        # compiles (i.e. runs Jinja) to expand the "node" corresponding to fname.
        # We do this to capture the Jinja context at the time of compilation, i.e.:
        # - Jinja Environment object
        # - Jinja "globals" dictionary
        #
        # This info is captured by the "make_template()" function, which in
        # turn is used by our parent class' (JinjaTemplater) slice_file()
        # function.
        old_from_string = Environment.from_string
        make_template = None

        def from_string(*args, **kwargs):
            """Replaces (via monkeypatch) the jinja2.Environment function."""
            nonlocal make_template
            # The globals may be passed positionally (third argument, after
            # the environment and the source string) or by keyword; check both.
            jinja_globals = args[2] if len(args) >= 3 else kwargs.get("globals")
            model = jinja_globals.get("model") if jinja_globals else None
            # Is it processing the node corresponding to fname?
            if model and model.get("original_file_path") == original_file_path:
                # Yes. Capture the important arguments and create a
                # make_template() function.
                env = args[0]

                def make_template(in_str):
                    env.add_extension(SnapshotExtension)
                    return env.from_string(in_str, globals=jinja_globals)

            return old_from_string(*args, **kwargs)

        node = self._find_node(fname, config)

        # Remember ephemeral nodes that have not been compiled yet, so they can
        # be put back after compile_node() (see the :HACK: note below).
        save_ephemeral_nodes = {
            k: v
            for k, v in self.dbt_manifest.nodes.items()
            if v.config.materialized == "ephemeral"
            and not getattr(v, "compiled", False)
        }
        with self.connection():
            # Apply the monkeypatch only around compilation and ALWAYS undo
            # it, even if dbt raises; otherwise jinja2.Environment would stay
            # patched (with this call's closure) for every later file.
            Environment.from_string = from_string
            try:
                node = self.dbt_compiler.compile_node(
                    node=node,
                    manifest=self.dbt_manifest,
                )
            finally:
                Environment.from_string = old_from_string

            if hasattr(node, "injected_sql"):
                # If injected SQL is present, it contains a better picture
                # of what will actually hit the database (e.g. with tests).
                # However it's not always present.
                compiled_sql = node.injected_sql
            else:
                compiled_sql = node.compiled_sql

            if not compiled_sql:  # pragma: no cover
                raise SQLTemplaterError(
                    "dbt templater compilation failed silently, check your "
                    "configuration by running `dbt compile` directly."
                )

            with open(fname) as source_dbt_model:
                source_dbt_sql = source_dbt_model.read()

            # Number of "\n" characters at the very end of the source file.
            n_trailing_newlines = len(source_dbt_sql) - len(
                source_dbt_sql.rstrip("\n")
            )

            templater_logger.debug(
                "    Trailing newline count in source dbt model: %r",
                n_trailing_newlines,
            )
            templater_logger.debug("    Raw SQL before compile: %r",
                                   source_dbt_sql)
            templater_logger.debug("    Node raw SQL: %r", node.raw_sql)
            templater_logger.debug("    Node compiled SQL: %r", compiled_sql)

            # When using dbt-templater, trailing newlines are ALWAYS REMOVED during
            # compiling. Unless fixed (like below), this will cause:
            #    1. L009 linting errors when running "sqlfluff lint foo_bar.sql"
            #       since the linter will use the compiled code with the newlines
            #       removed.
            #    2. "No newline at end of file" warnings in Git/GitHub since
            #       sqlfluff uses the compiled SQL to write fixes back to the
            #       source SQL in the dbt model.
            # The solution is:
            #    1. Check for trailing newlines before compiling by looking at the
            #       raw SQL in the source dbt file, store the count of trailing
            #       newlines.
            #    2. Append the count from #1 above to the node.raw_sql and
            #       compiled_sql objects, both of which have had the trailing
            #       newlines removed by the dbt-templater.
            node.raw_sql = node.raw_sql + "\n" * n_trailing_newlines
            compiled_sql = compiled_sql + "\n" * n_trailing_newlines

            raw_sliced, sliced_file, templated_sql = self.slice_file(
                source_dbt_sql,
                compiled_sql,
                config=config,
                make_template=make_template,
            )
        # :HACK: If calling compile_node() compiled any ephemeral nodes,
        # restore them to their earlier state. This prevents a runtime error
        # in the dbt "_inject_ctes_into_sql()" function that occurs with
        # 2nd-level ephemeral model dependencies (e.g. A -> B -> C, where
        # both B and C are ephemeral). Perhaps there is a better way to do
        # this, but this seems good enough for now.
        for k, v in save_ephemeral_nodes.items():
            if getattr(self.dbt_manifest.nodes[k], "compiled", False):
                self.dbt_manifest.nodes[k] = v

        if make_template and n_trailing_newlines:
            # Update templated_sql as we updated the other strings above. Update
            # sliced_file to reflect the mapping of the added character(s) back
            # to the raw SQL.
            templated_sql = templated_sql + "\n" * n_trailing_newlines
            sliced_file.append(
                TemplatedFileSlice(
                    slice_type="literal",
                    source_slice=slice(
                        len(source_dbt_sql) - n_trailing_newlines,
                        len(source_dbt_sql)),
                    templated_slice=slice(
                        len(templated_sql) - n_trailing_newlines,
                        len(templated_sql)),
                ))
        return (
            TemplatedFile(
                source_str=source_dbt_sql,
                templated_str=templated_sql,
                fname=fname,
                sliced_file=sliced_file,
                raw_sliced=raw_sliced,
            ),
            # No violations returned in this way.
            [],
        )