Esempio n. 1
0
def test_upsert_qualitycheck(conn: Connector):
    """
    Check that `conn.upsert` inserts a new object and, when the same object is
    upserted again with a changed value, leaves a single row carrying the new value.

    :param conn: Connector used to create the table, upsert and open a session
    """
    from sqlalchemy import Column, DateTime, text, UniqueConstraint
    from sqlalchemy.dialects.postgresql import TEXT, INTEGER, BIGINT

    class A(DQBase):
        id = Column(BIGINT, primary_key=True)
        name = Column(TEXT, nullable=False)
        price = Column(INTEGER)
        created_at = Column(
            DateTime(timezone=True),
            server_default=text("NOW()"),
            nullable=False,
            index=True,
        )

        __tablename__ = "my_table"
        # NOTE(review): presumably this unique constraint is what the upsert
        # resolves conflicts on - confirm against Connector.upsert.
        __table_args__ = (UniqueConstraint("name", name="unique_constraint_test"),)

    conn.ensure_table(A.__table__)
    instance = A(name="hello", price=13)

    conn.upsert(objs=[instance])

    s = conn.make_session()
    # Close the session even when an assertion fails, so a failing test does not
    # leave a connection open.
    try:
        # check if inserted
        rows = s.query(A.__table__).all()
        s.expunge_all()
        s.commit()
        assert len(rows) == 1
        assert rows[0].price == 13

        # change data and insert again - should upsert
        instance.price = 42
        conn.upsert(objs=[instance])

        rows = s.query(A.__table__).all()
        s.expunge_all()
        s.commit()
        assert len(rows) == 1
        assert rows[0].price == 42
    finally:
        s.close()
Esempio n. 2
0
class ConsistencyChecker:
    """
    Checks consistency of the sync between two tables.

    The "left" and "right" table can be reached through different connections.
    When only one connection is given, it is used for both sides. Results are
    persisted through the right connection.
    """

    # Model class passed to `ResultTable` when the result is persisted.
    model_cls = ConsistencyCheck

    # Supported comparison methods.
    COUNT = "count"
    DIFF = "difference"

    def __init__(self, left_conn_uri_or_engine, right_conn_uri_or_engine=None):
        """
        :param left_conn_uri_or_engine: passed to `Connector` for the left side
        :param right_conn_uri_or_engine: passed to `Connector` for the right side;
            when None, the left connector is reused for the right side
        """
        self.left_conn_uri_or_engine = left_conn_uri_or_engine
        self.left_conn = Connector(left_conn_uri_or_engine)
        if right_conn_uri_or_engine is None:
            self.right_conn_uri_or_engine = self.left_conn_uri_or_engine
            self.right_conn = self.left_conn
        else:
            self.right_conn_uri_or_engine = right_conn_uri_or_engine
            self.right_conn = Connector(right_conn_uri_or_engine)

    def run(
        self,
        method: str,
        left_check_table: Dict,
        right_check_table: Dict,
        result_table: Optional[Dict] = None,
        columns: Optional[List[str]] = None,
        time_filter: Optional[Union[str, List[Dict], TimeFilter]] = None,
        left_custom_sql: Optional[str] = None,
        right_custom_sql: Optional[str] = None,
        context: Optional[Dict] = None,
        example_selector: ExampleSelector = default_example_selector,
    ) -> Union[CheckResult, ConsistencyCheck]:
        """
        Run one consistency check and optionally persist its result.

        :param method: `COUNT` or `DIFF`
        :param left_check_table: keyword arguments for `Table` (left side)
        :param right_check_table: keyword arguments for `Table` (right side)
        :param result_table: keyword arguments for `ResultTable`; when given, the
            result is upserted through the right connection
        :param columns: columns to compare in the default queries
        :param time_filter: anything `parse_time_filter` accepts
        :param left_custom_sql: query to run on the left side instead of the default
        :param right_custom_sql: query to run on the right side instead of the default
        :param context: user context, overrides the defaults from `get_context`
        :param example_selector: picks failed examples for the `DIFF` method
        :return: the raw result dict when `result_table` is given, otherwise a
            `CheckResult` initialised from it
        :raises ValueError: both custom sqls are given together with `columns`
            or `time_filter`
        """
        if left_custom_sql and right_custom_sql:
            if columns or time_filter:
                raise ValueError(
                    "When using custom sqls you cannot change 'columns' or 'time_filter' attribute"
                )

        time_filter = parse_time_filter(time_filter)

        left_check_table = Table(**left_check_table)
        right_check_table = Table(**right_check_table)
        context = self.get_context(left_check_table, right_check_table,
                                   context)

        result = self.do_consistency_check(
            method,
            columns,
            time_filter,
            left_check_table,
            right_check_table,
            left_custom_sql,
            right_custom_sql,
            context,
            example_selector,
        )

        if result_table:
            result_table = ResultTable(**result_table,
                                       model_cls=self.model_cls)
            quality_check_class = create_default_check_class(result_table)
            self.right_conn.ensure_table(quality_check_class.__table__)
            self.upsert(quality_check_class, result)
            return result

        obj = CheckResult()
        obj.init_row_consistency(**result)
        return obj

    @staticmethod
    def get_context(
        left_check_table: Table,
        right_check_table: Table,
        context: Optional[Dict] = None,
    ) -> Dict:
        """
        Construct context to pass to executors. User context overrides defaults.

        Defaults are the full names of both tables and `task_ts` set to the
        current time.
        """
        ctx_defaults = {
            "left_table_fullname": left_check_table.fullname,
            "right_table_fullname": right_check_table.fullname,
            "task_ts": datetime.now(),
        }
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def do_consistency_check(
        self,
        method: str,
        columns: Optional[List[str]],
        time_filter: Optional[TimeFilter],
        left_check_table: Table,
        right_check_table: Table,
        left_sql: Optional[str] = None,
        right_sql: Optional[str] = None,
        context: Optional[Dict] = None,
        example_selector: ExampleSelector = default_example_selector,
    ):
        """
        Query both sides, compare the results and describe the outcome.

        A default query is built for every side that has no custom sql.

        :return: dict with the check description, the compared results, both
            table names, the time filter and the context
        :raises NotImplementedError: `method` is neither `COUNT` nor `DIFF`
        """
        # The select expression is only needed when at least one default query
        # has to be built.
        if not left_sql or not right_sql:
            if method == self.COUNT:
                if columns:
                    column = f"count({', '.join(columns)})"
                else:
                    column = "count(*)"
            elif method == self.DIFF:
                if columns:
                    column = ", ".join(columns)
                else:
                    # List the columns explicitly in case column order of compared tables is not the same.
                    # The names are read from the right table and used for both sides.
                    column = ", ".join(
                        sorted(
                            self.right_conn.get_column_names(
                                right_check_table.fullname)))
            else:
                raise NotImplementedError(f"Method {method} not implemented")

        if not left_sql:
            left_sql = self.construct_default_query(left_check_table.fullname,
                                                    column, time_filter,
                                                    context)
        left_result = self.run_query(self.left_conn, left_sql, context)
        if not right_sql:
            right_sql = self.construct_default_query(
                right_check_table.fullname, column, time_filter, context)
        right_result = self.run_query(self.right_conn, right_sql, context)

        results = self.compare_results(left_result, right_result, method,
                                       example_selector)

        return {
            "check": {
                "type": method,
                "description": "",
                "name": "consistency",
            },
            "results": results,
            "left_table_name": left_check_table.fullname,
            "right_table_name": right_check_table.fullname,
            "time_filter": time_filter,
            "context": context,
        }

    def compare_results(self, left_result, right_result, method,
                        example_selector):
        """
        Compare the query results of both sides.

        :param left_result: rows (tuples) returned for the left side
        :param right_result: rows (tuples) returned for the right side
        :param method: `COUNT` compares the first value of the first row of each
            side, `DIFF` compares the sets of rows
        :param example_selector: selects failed examples from the rows present
            on one side only (`DIFF` only)
        :return: AggregatedResult
        :raises NotImplementedError: `method` is neither `COUNT` nor `DIFF`
        """
        if method == self.COUNT:
            left_count = left_result[0][0]
            right_count = right_result[0][0]
            passed = min(left_count, right_count)
            # The surplus of whichever side is bigger. Using the absolute
            # difference keeps the number non-negative when the right side
            # has more rows, so passed + failed == total_records.
            failed = abs(left_count - right_count)
            return AggregatedResult(
                total_records=max(left_count, right_count),
                failed=failed,
                passed=passed,
            )

        elif method == self.DIFF:
            left_set = set(left_result)
            right_set = set(right_result)
            common = left_set.intersection(right_set)
            passed = len(common)
            # Rows missing on the right plus rows missing on the left.
            failed = (len(left_set) - len(common)) + (len(right_set) -
                                                      len(common))
            failed_examples = example_selector.select_examples(
                left_set.symmetric_difference(right_set))
            return AggregatedResult(
                total_records=failed + passed,
                failed=failed,
                passed=passed,
                failed_example=list(failed_examples),
            )

        else:
            raise NotImplementedError(f"Method {method} not implemented")

    def construct_default_query(
        self,
        table_name: str,
        column: str,
        time_filter: Optional[TimeFilter],
        context: Dict,
    ):
        """
        Build `SELECT column FROM table_name`, with a WHERE clause when a time
        filter is given.

        Note that `time_filter.now` is overwritten with `context["task_ts"]`
        when the context provides it.

        :return: str, the query
        """
        if time_filter:
            # `context` may be None when `do_consistency_check` is called directly.
            if context and context.get("task_ts"):
                time_filter.now = context["task_ts"]
            time_filter = time_filter.sql
        query = f"""
            SELECT {column}
            FROM {table_name}
            {f'WHERE {time_filter}' if time_filter else ''}
        """
        return query

    def render_sql(self, sql, context):
        """
        Replace some parameters in query.
        :return str, formatted sql
        """
        rendered = render_jinja_sql(sql, context)
        return rendered

    def run_query(self, conn: Connector, query: str, context):
        """
        Render the query with the context, run it on `conn` and return the
        records as a list of tuples.
        """
        query = self.render_sql(query, context)
        logging.debug(query)
        # Tuples are hashable, which `compare_results` needs to build sets.
        result = [tuple(r.values()) for r in conn.get_records(query)]
        return result

    def upsert(self, dc_cls, result):
        """
        Create a `dc_cls` object from the result dict and upsert it through the
        right connection.
        """
        obj = dc_cls()
        obj.init_row(**result)
        self.right_conn.upsert([
            obj,
        ])

    def construct_automatic_time_filter(
        self,
        left_check_table: Dict,
        created_at_column=None,
        updated_at_column=None,
    ) -> TimeFilter:
        """
        Build a time filter that starts at the lowest value of a time column in
        the left table.

        :param left_check_table: keyword arguments for `Table`
        :param created_at_column: used when `updated_at_column` is not given
        :param updated_at_column: preferred over `created_at_column`
        :raises ValueError: neither column is given
        """
        left_check_table = Table(**left_check_table)

        if created_at_column is None and updated_at_column is None:
            raise ValueError(
                "Automatic time filter need at least one time column")

        since_column = updated_at_column or created_at_column
        since_sql = f"SELECT min({since_column}) FROM {left_check_table.fullname}"
        logging.info(since_sql)
        since = self.left_conn.get_records(since_sql).scalar()

        return TimeFilter(
            columns=[
                TimeFilterColumn(since_column, since=since),
            ],
            conjunction=TimeFilterConjunction.AND,
        )
Esempio n. 3
0
class ConsistencyChecker:
    """
    Checks consistency of the sync between two tables.

    The "left" and "right" table can be reached through different connections.
    When only one connection is given, it is used for both sides. Results are
    persisted through the right connection.
    """

    # Model class passed to `ResultTable` when the result is persisted.
    model_cls = ConsistencyCheck

    # Supported comparison methods.
    COUNT = "count"
    DIFF = "difference"

    def __init__(self, left_conn_uri_or_engine, right_conn_uri_or_engine=None):
        """
        :param left_conn_uri_or_engine: passed to `Connector` for the left side
        :param right_conn_uri_or_engine: passed to `Connector` for the right side;
            when None, the left connector is reused for the right side
        """
        self.left_conn_uri_or_engine = left_conn_uri_or_engine
        self.left_conn = Connector(left_conn_uri_or_engine)
        if right_conn_uri_or_engine is None:
            self.right_conn_uri_or_engine = self.left_conn_uri_or_engine
            self.right_conn = self.left_conn
        else:
            self.right_conn_uri_or_engine = right_conn_uri_or_engine
            self.right_conn = Connector(right_conn_uri_or_engine)

    def run(
        self,
        method: str,
        left_check_table: Dict,
        right_check_table: Dict,
        result_table: Optional[Dict] = None,
        columns: Optional[List[str]] = None,
        time_filter: str = TIME_FILTER_DEFAULT,
        left_custom_sql: Optional[str] = None,
        right_custom_sql: Optional[str] = None,
        context: Optional[Dict] = None,
    ) -> Union[CheckResult, ConsistencyCheck]:
        """
        Run one consistency check and optionally persist its result.

        :param method: `COUNT` or `DIFF`
        :param left_check_table: keyword arguments for `Table` (left side)
        :param right_check_table: keyword arguments for `Table` (right side)
        :param result_table: keyword arguments for `ResultTable`; when given, the
            result is upserted through the right connection
        :param columns: columns to compare in the default queries
        :param time_filter: time filter for the default queries
        :param left_custom_sql: query to run on the left side instead of the default
        :param right_custom_sql: query to run on the right side instead of the default
        :param context: user context, overrides the defaults from `get_context`
        :return: the raw result dict when `result_table` is given, otherwise a
            `CheckResult` initialised from it
        :raises ValueError: both custom sqls are given together with `columns`
            or `time_filter`
        """
        if left_custom_sql and right_custom_sql:
            # NOTE(review): if TIME_FILTER_DEFAULT is truthy, this raises for every
            # call with two custom sqls unless the caller passes a falsy
            # time_filter - confirm that is intended.
            if columns or time_filter:
                raise ValueError(
                    "When using custom sqls you cannot change 'columns' or 'time_filter' attribute"
                )

        left_check_table = Table(**left_check_table)
        right_check_table = Table(**right_check_table)
        context = self.get_context(left_check_table, right_check_table,
                                   context)

        result = self.do_consistency_check(
            method,
            columns,
            time_filter,
            left_check_table,
            right_check_table,
            left_custom_sql,
            right_custom_sql,
            context,
        )

        if result_table:
            result_table = ResultTable(**result_table,
                                       model_cls=self.model_cls)
            quality_check_class = create_default_check_class(result_table)
            self.right_conn.ensure_table(quality_check_class.__table__)
            self.upsert(quality_check_class, result)
            return result

        obj = CheckResult()
        obj.init_row_consistency(**result)
        return obj

    @staticmethod
    def get_context(
        left_check_table: Table,
        right_check_table: Table,
        context: Optional[Dict] = None,
    ) -> Dict:
        """
        Construct context to pass to executors. User context overrides defaults.

        Defaults are the full names of both tables and `task_ts` set to the
        current time.
        """
        ctx_defaults = {
            "left_table_fullname": left_check_table.fullname,
            "right_table_fullname": right_check_table.fullname,
            "task_ts": datetime.now(),
        }
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def do_consistency_check(
        self,
        method: str,
        columns: Optional[List[str]],
        time_filter: str,
        left_check_table: Table,
        right_check_table: Table,
        left_sql: Optional[str] = None,
        right_sql: Optional[str] = None,
        context: Optional[Dict] = None,
    ):
        """
        Query both sides, compare the results and describe the outcome.

        A default query is built for every side that has no custom sql.

        :return: dict with the check description (including passed/failed
            counts), the status, both table names, the time filter and the context
        :raises NotImplementedError: `method` is neither `COUNT` nor `DIFF`
        """
        # The select expression is only needed when at least one default query
        # has to be built.
        if not left_sql or not right_sql:
            if method == self.COUNT:
                if columns:
                    column = f"count({', '.join(columns)})"
                else:
                    column = "count(*)"
            elif method == self.DIFF:
                if columns:
                    column = ", ".join(columns)
                else:
                    # List the columns explicitly in case column order of compared tables is not the same.
                    # The names are read from the right table and used for both sides.
                    column = ", ".join(
                        sorted(
                            self.right_conn.get_column_names(
                                right_check_table.fullname)))
            else:
                raise NotImplementedError(f"Method {method} not implemented")

        if not left_sql:
            left_sql = self.construct_default_query(left_check_table.fullname,
                                                    column, time_filter,
                                                    context)
        left_result = self.run_query(self.left_conn, left_sql, context)
        if not right_sql:
            right_sql = self.construct_default_query(
                right_check_table.fullname, column, time_filter, context)
        right_result = self.run_query(self.right_conn, right_sql, context)

        valid, passed, failed = self.compare_results(left_result, right_result,
                                                     method)

        return {
            "check": {
                "type": method,
                "description": "",
                "name": "consistency",
                "passed": passed,
                "failed": failed,
            },
            "status": "valid" if valid else "invalid",
            "left_table_name": left_check_table.fullname,
            "right_table_name": right_check_table.fullname,
            "time_filter": time_filter,
            "context": context,
        }

    def compare_results(self, left_result, right_result, method):
        """
        Compare the query results of both sides.

        :param left_result: rows returned for the left side
        :param right_result: rows returned for the right side
        :param method: `COUNT` compares the first value of the first row of each
            side, `DIFF` compares the sets of rows
        :return: tuple (valid, passed, failed); valid is True when nothing failed
        :raises NotImplementedError: `method` is neither `COUNT` nor `DIFF`
        """
        if method == self.COUNT:
            left_count = left_result[0][0]
            right_count = right_result[0][0]
            passed = min(left_count, right_count)
            # The surplus of whichever side is bigger. Using the absolute
            # difference keeps the number non-negative when the right side
            # has more rows.
            failed = abs(left_count - right_count)
            return failed == 0, passed, failed
        elif method == self.DIFF:
            left_set = set(left_result)
            right_set = set(right_result)
            common = left_set.intersection(right_set)
            passed = len(common)
            # Rows missing on the right plus rows missing on the left.
            failed = (len(left_set) - len(common)) + (len(right_set) -
                                                      len(common))
            return failed == 0, passed, failed
        else:
            raise NotImplementedError(f"Method {method} not implemented")

    def construct_default_query(self, table_name, column, time_filter,
                                context):
        """
        Build `SELECT column FROM table_name`, with a WHERE clause when
        `compose_where_time_filter` returns a non-empty condition.

        :param context: must contain `task_ts`
        :return: str, the query
        """
        time_filter = compose_where_time_filter(time_filter,
                                                context["task_ts"])
        query = f"""
            SELECT {column}
            FROM {table_name}
            {f'WHERE {time_filter}' if time_filter else ''}
        """
        return query

    def render_sql(self, sql, context):
        """
        Replace some parameters in query.
        :return str, formatted sql
        """
        t = jinja2.Template(sql)
        rendered = t.render(**context)
        return rendered

    def run_query(self, conn: Connector, query: str, context):
        """
        Render the query with the context, run it on `conn` and return the
        `_row` value of every record as a list.
        """
        query = self.render_sql(query, context)
        logging.info(query)
        # NOTE(review): `_row` is a private attribute of the record object;
        # presumably a hashable tuple, which `compare_results` needs - confirm.
        result = [r._row for r in conn.get_records(query)]
        return result

    def upsert(self, dc_cls, result):
        """
        Create a `dc_cls` object from the result dict and upsert it through the
        right connection.
        """
        obj = dc_cls()
        obj.init_row(**result)
        self.right_conn.upsert([
            obj,
        ])
Esempio n. 4
0
class ContessaRunner:
    """
    Runs quality check rules against a table and optionally stores the results.
    """

    # Model class passed to `ResultTable` when results are persisted.
    model_cls = QualityCheck

    def __init__(self, conn_uri_or_engine, special_qc_map=None):
        """
        :param conn_uri_or_engine: passed to `Connector`
        :param special_qc_map: dict mapping a result table full name to the
            quality check class to use for it; defaults to an empty dict
        """
        self.conn_uri_or_engine = conn_uri_or_engine
        self.conn = Connector(conn_uri_or_engine)

        # todo - allow cfg
        self.special_qc_map = special_qc_map or {}

    def run(
        self,
        raw_rules: List[Dict[str, str]],
        check_table: Dict,
        result_table: Optional[
            Dict] = None,  # todo - docs for quality name, maybe defaults..
        context: Optional[Dict] = None,
    ) -> Union[CheckResult, QualityCheck]:
        """
        Execute all rules against `check_table`.

        :param raw_rules: rule definitions as given by the user
        :param check_table: keyword arguments for `Table`
        :param result_table: keyword arguments for `ResultTable`; when given, the
            result objects are upserted, otherwise `CheckResult` objects are
            only returned
        :param context: user context, overrides the defaults from `get_context`
        :return: list with one result object per rule
        """
        check_table = Table(**check_table)
        context = self.get_context(check_table, context)

        normalized_rules = self.normalize_rules(raw_rules)
        refresh_executors(check_table, self.conn, context)

        if result_table:
            result_table = ResultTable(**result_table,
                                       model_cls=self.model_cls)
            quality_check_class = self.get_quality_check_class(result_table)
            self.conn.ensure_table(quality_check_class.__table__)
        else:
            quality_check_class = CheckResult

        rules = self.build_rules(normalized_rules)
        objs = self.do_quality_checks(quality_check_class, rules, context)

        if result_table:
            self.conn.upsert(objs)
        return objs

    @staticmethod
    def get_context(check_table: Table,
                    context: Optional[Dict] = None) -> Dict:
        """
        Construct context to pass to executors. User context overrides defaults.

        Defaults are the table full name and `task_ts` set to the current time.
        """
        ctx_defaults = {
            "table_fullname": check_table.fullname,
            "task_ts": datetime.now(),  # todo - is now() ok ?
        }
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def normalize_rules(self, raw_rules):
        """
        Normalize user rule definitions with `RuleNormalizer`.
        """
        return RuleNormalizer.normalize(raw_rules)

    def do_quality_checks(self,
                          dq_cls,
                          rules: List[Rule],
                          context: Optional[Dict] = None):
        """
        Run quality check for all rules. Use `dq_cls` to construct objects that will be inserted
        afterwards.

        :return: list of `dq_cls` objects, one per rule, in the order of `rules`
        """
        return [self.apply_rule(context, dq_cls, rule) for rule in rules]

    def apply_rule(self, context, dq_cls, rule):
        """
        Execute a single rule with its executor and wrap the results in a new
        `dq_cls` object.
        """
        e = get_executor(rule)
        logging.info(f"Executing rule `{rule}`.")
        results = e.execute(rule)
        obj = dq_cls()
        obj.init_row(rule, results, self.conn, context)
        return obj

    @staticmethod
    def build_rules(normalized_rules):
        """
        Construct rules classes from user definition that are dicts.
        Raises if there are bad arguments for a certain rule.
        :return: list of Rule objects
        """
        ret = []
        for rule_def in normalized_rules:
            rule_cls = ContessaRunner.pick_rule_cls(rule_def)
            try:
                r = rule_cls(**rule_def)
            except Exception as e:
                # An exception raised without arguments has empty `args`; indexing
                # it here would raise IndexError and hide the original error.
                detail = e.args[0] if e.args else e
                logging.error(f"For rule `{rule_cls.__name__}`. {detail}")
                raise
            else:
                ret.append(r)
        return ret

    @staticmethod
    def pick_rule_cls(rule_def):
        """
        Get rule class based on its type that was input by user.
        :param rule_def: dict
        :return: Rule class
        """
        return get_rule_cls(rule_def["type"])

    def get_quality_check_class(self, result_table: ResultTable):
        """
        QualityCheck can be different, e.g. `special_table` has specific quality_check.
        Or kind of generic one that computes number of passed/failed objects etc.
        So determine if is special or not and return the class.
        :return: QualityCheck cls
        """
        if result_table.fullname in self.special_qc_map:
            quality_check_class = self.special_qc_map[result_table.fullname]
            logging.info(
                f"Using {quality_check_class.__name__} as quality check class."
            )
        else:
            quality_check_class = create_default_check_class(result_table)
            logging.info("Using default QualityCheck class.")
        return quality_check_class