예제 #1
0
def has_description(paths: Sequence[str], manifest: Dict[str, Any]) -> int:
    """Check that every given SQL model has a description.

    A model passes when a truthy ``description`` is found either on its
    manifest node or on its entry in one of the given ``.yml``/``.yaml``
    files. Every model that passes neither check is reported via ``print``.

    Args:
        paths: Paths to inspect; ``.sql`` entries are the models to check and
            ``.yml``/``.yaml`` entries are searched for model properties.
        manifest: Manifest dictionary handed to ``get_models``.

    Returns:
        ``1`` if at least one model has no description, otherwise ``0``.
    """
    sql_files = get_filenames(paths, [".sql"])
    yaml_files = get_filenames(paths, [".yml", ".yaml"])
    model_names = set(sql_files.keys())

    # Manifest nodes that correspond to the SQL files passed in.
    manifest_models = get_models(manifest, model_names)
    # Properties files are consulted too, so a description that was added to
    # a schema file but is not yet reflected in the manifest still counts.
    property_entries = get_model_schemas(list(yaml_files.values()), model_names)

    described_in_manifest = {
        entry.filename
        for entry in manifest_models
        if entry.node.get("description")
    }
    described_in_properties = {
        entry.model_name
        for entry in property_entries
        if entry.schema.get("description")
    }
    undescribed = model_names.difference(
        described_in_manifest,
        described_in_properties,
    )

    for name in undescribed:
        print(
            f"{sql_files.get(name)}: "
            f"does not have defined description or properties file is missing.",
        )
    return 1 if undescribed else 0
def has_meta_key(paths: Sequence[str], manifest: Dict[str, Any],
                 meta_keys: Sequence[str]) -> int:
    """Check that every given SQL model defines exactly the required meta keys.

    A model passes when the key set of its ``meta`` mapping equals
    ``set(meta_keys)``, either on its manifest node or on its entry in one of
    the given ``.yml``/``.yaml`` files. The comparison is an equality test,
    so additional keys fail just like absent ones.

    Args:
        paths: Paths to inspect; ``.sql`` entries are the models to check and
            ``.yml``/``.yaml`` entries are searched for model properties.
        manifest: Manifest dictionary handed to ``get_models``.
        meta_keys: Meta keys each model is expected to define.

    Returns:
        ``1`` if at least one model fails the check, otherwise ``0``.
    """

    def _meta_matches(properties: Any) -> bool:
        # Equality, not subset: the defined keys must match the required set.
        return set(properties.get("meta", {}).keys()) == set(meta_keys)

    status_code = 0
    sql_files = get_filenames(paths, [".sql"])
    yaml_files = get_filenames(paths, [".yml", ".yaml"])
    model_names = set(sql_files.keys())
    # Manifest nodes that correspond to the SQL files passed in.
    manifest_models = get_models(manifest, model_names)
    # Properties files are consulted too, in case the manifest is stale.
    property_entries = get_model_schemas(list(yaml_files.values()), model_names)

    ok_in_manifest = {
        entry.filename
        for entry in manifest_models
        if _meta_matches(entry.node)
    }
    ok_in_properties = {
        entry.model_name
        for entry in property_entries
        if _meta_matches(entry.schema)
    }
    failing = model_names.difference(ok_in_manifest, ok_in_properties)

    for name in failing:
        status_code = 1
        key_listing = "\n- ".join(list(meta_keys))  # pragma: no mutate
        print(
            f"{sql_files.get(name)}: "
            f"does not have some of the meta keys defined:\n- {key_listing}", )
    return status_code
def check_column_desc(paths: Sequence[str],
                      manifest: Dict[str, Any]) -> Tuple[int, Dict[str, Any]]:
    """Report columns of the given SQL models that have no description.

    Manifest nodes are processed first, then entries from the given
    ``.yml``/``.yaml`` files. Per model, the undocumented column names of
    each item are merged into a running set; an item with no undocumented
    columns resets a non-empty running set to empty.

    Args:
        paths: Paths to inspect; ``.sql`` entries are the models to check and
            ``.yml``/``.yaml`` entries are searched for model properties.
        manifest: Manifest dictionary handed to ``get_models``.

    Returns:
        A tuple ``(status_code, missing)`` where ``status_code`` is ``1`` if
        any model ends up with a non-empty set of undocumented columns (else
        ``0``) and ``missing`` maps model name to that set.
    """
    sql_files = get_filenames(paths, [".sql"])
    yaml_files = get_filenames(paths, [".yml", ".yaml"])
    model_names = set(sql_files.keys())

    # Manifest nodes that correspond to the SQL files passed in.
    manifest_models = get_models(manifest, model_names)
    # Properties files are consulted too, in case the manifest is stale.
    property_entries = get_model_schemas(list(yaml_files.values()), model_names)
    missing: Dict[str, Set[str]] = {}

    for entry in itertools.chain(manifest_models, property_entries):
        if isinstance(entry, ModelSchema):
            # Properties file: columns are read as a list of mappings.
            name = entry.model_name
            undocumented = {
                column.get("name")
                for column in entry.schema.get("columns", [])
                if not column.get("description")
            }
        elif isinstance(entry, Model):
            # Manifest node: columns are read as a mapping keyed by name.
            name = entry.filename
            undocumented = {
                column_name
                for column_name, column in entry.node.get("columns", {}).items()
                if not column.get("description")
            }
        else:
            continue  # pragma: no cover, no mutate

        previous = missing.get(name)
        if not previous:
            # Nothing recorded yet (or an empty set): only store real gaps.
            if undocumented:
                missing[name] = undocumented
        elif undocumented:
            missing[name] = previous.union(undocumented)
        else:
            # A fully documented item clears what was recorded before it.
            missing[name] = set()  # pragma: no mutate

    status_code = 0
    for name, columns in missing.items():
        if not columns:
            continue
        status_code = 1
        listing = "\n- ".join(list(columns))  # pragma: no mutate
        print(
            f"{sql_files.get(name)}: "
            f"following columns are missing description:\n- {listing}", )
    return status_code, missing
예제 #4
0
def generate_properties_file(
    paths: Sequence[str],
    manifest: Dict[str, Any],
    catalog: Dict[str, Any],
    properties_file: str,
) -> int:
    """Write properties for each given SQL model that yields any.

    For every manifest model matching the given ``.sql`` paths,
    ``get_model_properties`` is asked for its properties; when the result is
    truthy it is passed to ``write_model_properties`` together with the
    model's non-empty ``database``/``schema``/``alias``/``name`` values.

    Args:
        paths: Paths to inspect; only ``.sql`` entries are used.
        manifest: Manifest dictionary handed to ``get_models``.
        catalog: Catalog dictionary; its ``"nodes"`` entry is passed to
            ``get_model_properties``.
        properties_file: Forwarded unchanged to ``write_model_properties``.

    Returns:
        ``1`` if properties were written for at least one model, else ``0``.
    """
    placeholder_keys = ("database", "schema", "alias", "name")

    sql_files = get_filenames(paths, [".sql"])
    model_names = set(sql_files.keys())

    # Manifest nodes that correspond to the SQL files passed in.
    manifest_models = get_models(manifest, model_names)
    catalog_nodes = catalog.get("nodes", {})

    status_code = 0
    for entry in manifest_models:
        properties = get_model_properties(entry, catalog_nodes)

        # Keep only the node attributes that actually carry a value.
        path_values = {}
        for key in placeholder_keys:
            value = entry.node.get(key)
            if value:
                path_values[key] = value

        if not properties:
            continue
        status_code = 1
        write_model_properties(properties_file, properties, path_values)
    return status_code
def check_test_cnt(paths: Sequence[str], manifest: Dict[str, Any],
                   required_tests: Dict[str, int]) -> int:
    """Check that each given SQL model has enough tests of each required name.

    Tests returned by ``get_tests`` for a model are counted by their
    ``test_name``; every required name whose count is below the required
    minimum is reported via ``print``.

    Args:
        paths: Paths to inspect; only ``.sql`` entries are used.
        manifest: Manifest dictionary handed to ``get_models`` and
            ``get_tests``.
        required_tests: Mapping of test name to the minimum number of tests
            with that name a model must have.

    Returns:
        ``1`` if any model falls short of any requirement, otherwise ``0``.
    """
    # Imported here so the block does not rely on a module-level import.
    from collections import Counter

    status_code = 0
    sqls = get_filenames(paths, [".sql"])
    filenames = set(sqls.keys())

    # Manifest nodes that correspond to the SQL files passed in.
    models = get_models(manifest, filenames)

    for model in models:
        # Only the per-name counts are needed, so count directly instead of
        # sorting and grouping the test objects.
        test_counts = Counter(
            test.test_name
            for test in get_tests(manifest=manifest, obj=model)
        )
        for required_test, required_cnt in required_tests.items():
            # Counter yields 0 for names that never occurred.
            test_cnt = test_counts[required_test]
            # Compare counts only: a requirement of 0 is satisfied by a
            # model that has none of these tests.
            if test_cnt < required_cnt:
                status_code = 1
                print(
                    f"{model.model_name}: "
                    f"has only {test_cnt} {required_test} tests, but "
                    f"{required_cnt} are required.", )
    return status_code
예제 #6
0
def check_parents_schema(
    paths: Sequence[str],
    manifest: Dict[str, Any],
    blacklist: Optional[Sequence[str]],
    whitelist: Optional[Sequence[str]],
) -> int:
    """Check the schemas of the parents of each given SQL model.

    A parent is rejected when a non-empty whitelist does not contain its
    schema, or when the blacklist contains it. Each rejected parent is
    reported via ``print``.

    Args:
        paths: Paths to inspect; only ``.sql`` entries are used.
        manifest: Manifest dictionary handed to ``get_models`` and
            ``get_parent_childs``.
        blacklist: Schemas a parent must not live in; ``None`` means none.
        whitelist: Schemas a parent must live in; ``None`` or empty disables
            the whitelist check.

    Returns:
        ``1`` if any parent was rejected, otherwise ``0``.
    """
    sql_files = get_filenames(paths, [".sql"])
    model_names = set(sql_files.keys())
    denied_schemas = blacklist or []
    allowed_schemas = whitelist or []

    status_code = 0
    # Manifest nodes that correspond to the SQL files passed in.
    for model in get_models(manifest, model_names):
        parents = list(
            get_parent_childs(
                manifest=manifest,
                obj=model,
                manifest_node="parent_map",
                node_types=["model", "source"],
            ))
        for parent in parents:
            parent_schema = parent.node.get("schema")
            # The whitelist only applies when it is non-empty.
            if allowed_schemas and parent_schema not in allowed_schemas:
                rejected = True
            else:
                rejected = parent_schema in denied_schemas
            if not rejected:
                continue
            status_code = 1
            print(
                f"{model.model_name}: "
                f"has parent {parent.node.get('name')} with invalid schema "
                f"{parent_schema}.", )
    return status_code
def check_refs_sources(
    paths: Sequence[str], manifest: Dict[str, Any]
) -> Tuple[int, Set[str], Dict[FrozenSet[str], Dict[str, str]]]:
    """Find ``ref``/``source`` calls in SQL files that the manifest lacks.

    Each given ``.sql`` file is scanned for ``{{ ref(...) }}`` and
    ``{{ source(...) }}`` expressions. Referenced model names are looked up
    among the ``name`` values of the manifest's ``"nodes"``; referenced
    sources are looked up among the manifest's ``"sources"`` by
    ``source_name`` and ``name``. Everything that is not found is reported
    via ``print``.

    Args:
        paths: Paths to inspect; only ``.sql`` entries are read.
        manifest: Manifest dictionary with optional ``"nodes"`` and
            ``"sources"`` mappings.

    Returns:
        A tuple ``(status_code, models, sources)``: ``status_code`` is ``1``
        if anything is missing (else ``0``), ``models`` is the set of
        unresolved ``ref`` arguments, and ``sources`` maps
        ``frozenset({source_name, table_name})`` to a dict with those two
        values for every unresolved source.
    """
    status_code = 0
    sqls = get_filenames(paths, [".sql"])

    # The argument group is non-greedy so that several calls on the same line
    # are matched one by one instead of being merged into a single match.
    call_pattern = re.compile(r"\{\{\s*(source|ref)\s*\((.*?)\)\s*\}\}")

    models: Set[str] = set()
    sources: Dict[FrozenSet[str], Dict[str, str]] = {}
    for file in sqls.values():
        full_script = file.read_text(encoding="utf-8")
        for macro, raw_args in call_pattern.findall(full_script):
            # Strip quoting so only the bare argument text remains.
            args = raw_args.replace("'", "").replace('"', "").strip()
            if macro == "ref":
                models.add(args)
                continue

            src_split = args.split(",")
            if len(src_split) < 2:
                # A source call needs a source name and a table name; one
                # without both cannot be looked up, so skip it rather than
                # fail with an IndexError.
                continue
            source_name = src_split[0].strip()
            table_name = src_split[1].strip()
            # NOTE: the frozenset key ignores argument order; it is kept
            # because it is part of the returned structure.
            sources[frozenset([source_name, table_name])] = {
                "source_name": source_name,
                "table_name": table_name,
            }

    if models:
        # Drop every referenced model whose name exists in the manifest.
        nodes = manifest.get("nodes", {})
        models.difference_update(
            value.get("name") for value in nodes.values()
        )

    if sources:
        # Drop every referenced source that exists in the manifest.
        srcs = manifest.get("sources", {})
        for value in srcs.values():
            source_set = frozenset(
                [value.get("source_name"),
                 value.get("name")])
            sources.pop(source_set, None)

    for src in sources.values():
        status_code = 1
        source_name = src.get("source_name")  # pragma: no mutate
        table_name = src.get("table_name")  # pragma: no mutate
        print(f"Missing source `{source_name}.{table_name}`")

    for missing_ref in models:
        status_code = 1
        print(f"Missing model (ref) {missing_ref}")

    return status_code, models, sources
예제 #8
0
def check_test_cnt(paths: Sequence[str], manifest: Dict[str, Any],
                   test_cnt: int) -> int:
    """Check that each given SQL model has at least ``test_cnt`` tests.

    Every model with fewer tests (as returned by ``get_tests``) is reported
    via ``print``.

    Args:
        paths: Paths to inspect; only ``.sql`` entries are used.
        manifest: Manifest dictionary handed to ``get_models`` and
            ``get_tests``.
        test_cnt: Minimum number of tests a model must have.

    Returns:
        ``1`` if any model has too few tests, otherwise ``0``.
    """
    sql_files = get_filenames(paths, [".sql"])
    model_names = set(sql_files.keys())

    exit_code = 0
    # Manifest nodes that correspond to the SQL files passed in.
    for entry in get_models(manifest, model_names):
        found = len(list(get_tests(manifest=manifest, obj=entry)))
        if found >= test_cnt:
            continue
        exit_code = 1
        print(
            f"{entry.model_name}: "
            f"has only {found} tests, but {test_cnt} are required.",
        )
    return exit_code
예제 #9
0
def validate_tags(paths: Sequence[str], manifest: Dict[str, Any],
                  tags: Sequence[str]) -> int:
    """Check that the given SQL models only use allowed tags.

    For each model, the ``tags`` of its manifest node are compared with the
    allowed ``tags``; any tag outside the allowed set is reported via
    ``print``.

    Args:
        paths: Paths to inspect; only ``.sql`` entries are used.
        manifest: Manifest dictionary handed to ``get_models``.
        tags: Tags a model is allowed to carry.

    Returns:
        ``1`` if any model carries a tag that is not allowed, otherwise ``0``.
    """
    sql_files = get_filenames(paths, [".sql"])
    model_names = set(sql_files.keys())

    exit_code = 0
    # Manifest nodes that correspond to the SQL files passed in.
    for entry in get_models(manifest, model_names):
        # Tags are read from the manifest node only.
        used = set(entry.node.get("tags", []))
        allowed = set(tags)
        unexpected = used - allowed
        if not unexpected:
            continue
        exit_code = 1
        listing = "\n- ".join(list(unexpected))  # pragma: no mutate
        print(
            f"{entry.node.get('original_file_path', entry.filename)}: "
            f"has invalid tags:\n- {listing}", )
    return exit_code
def has_properties_file(
    paths: Sequence[str], manifest: Dict[str, Any]
) -> Tuple[int, Set[str]]:
    """Check that every given SQL model has a ``patch_path`` in the manifest.

    A model passes when its manifest node has a truthy ``patch_path``.
    Models that are absent from the manifest or lack that value are reported
    via ``print``.

    Args:
        paths: Paths to inspect; only ``.sql`` entries are used.
        manifest: Manifest dictionary handed to ``get_models``.

    Returns:
        A tuple ``(status_code, missing)`` where ``status_code`` is ``1`` if
        any model fails (else ``0``) and ``missing`` is the set of failing
        model names.
    """
    sql_files = get_filenames(paths, [".sql"])
    model_names = set(sql_files.keys())

    # Manifest nodes that correspond to the SQL files passed in.
    manifest_models = get_models(manifest, model_names)
    with_properties = {
        entry.filename
        for entry in manifest_models
        if entry.node.get("patch_path")
    }
    missing = model_names - with_properties

    for name in missing:
        print(
            f"{sql_files.get(name)}: "
            f"does not have model properties defined in any .yml file.",
        )
    return (1 if missing else 0), missing
예제 #11
0
def test_get_filenames():
    """Two ``.sql`` paths are expected back keyed by ``cc`` and ``ee``."""
    expected = {"cc": Path("aa/bb/cc.sql"), "ee": Path("bb/ee.sql")}
    assert get_filenames(["aa/bb/cc.sql", "bb/ee.sql"]) == expected