示例#1
0
    def setUp(self):
        """
        Build the fixtures for the consistency tests.

        Recreates the `tmp` schema, makes sure the `hello` schema exists and
        creates one booking table in each of them. The left table
        (`tmp.raw_booking`) receives four rows while the right table
        (`hello.booking`) receives only the first three of them.
        """
        self.left_table_name = "raw_booking"
        self.right_table_name = "booking"
        iso_stamp = FakedDatetime.now().isoformat()
        self.ts_nodash = iso_stamp.replace("-", "").replace(":", "")
        self.now = FakedDatetime.now()

        # Statements are executed one by one, in this exact order.
        statements = [
            "DROP SCHEMA if exists tmp CASCADE;",
            "CREATE SCHEMA IF NOT EXISTS tmp;",
            "CREATE SCHEMA IF NOT EXISTS hello;",
            f"""
                    CREATE TABLE IF NOT EXISTS tmp.{self.left_table_name}(
                      id SERIAL PRIMARY KEY,
                      src text,
                      dst text,
                      price int,
                      turnover_after_refunds double precision,
                      initial_price double precision,
                      created_at timestamptz
                    )
                """,
            f"""
                    CREATE TABLE IF NOT EXISTS hello.{self.right_table_name}(
                      id SERIAL PRIMARY KEY,
                      src text,
                      dst text,
                      price int,
                      turnover_after_refunds double precision,
                      initial_price double precision,
                      created_at timestamptz
                    )
                """,
            f"""
                INSERT INTO tmp.{self.left_table_name}
                    (src, dst, price, turnover_after_refunds, initial_price, created_at)
                VALUES
                    ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                    (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                    ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
                    ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
                """,
            f"""
                INSERT INTO hello.{self.right_table_name}
                    (src, dst, price, turnover_after_refunds, initial_price, created_at)
                VALUES
                    ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                    (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                    ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00')
            """,
        ]

        self.conn = Connector(TEST_DB_URI)
        for statement in statements:
            self.conn.execute(statement)

        self.consistency_checker = ConsistencyChecker(TEST_DB_URI)
示例#2
0
    def setUp(self):
        """
        Recreate the data-quality schema with a pre-populated migration table.

        Drops and recreates `DATA_QUALITY_SCHEMA`, creates the two quality
        tables and finally creates the migration table holding a single
        applied version, `0.1.4-hash`.
        """

        sql = [
            # The comma after the DROP statement was missing, which silently
            # glued DROP and CREATE into one list element.
            f"DROP SCHEMA IF EXISTS {DATA_QUALITY_SCHEMA} CASCADE;",
            f"CREATE SCHEMA IF NOT EXISTS {DATA_QUALITY_SCHEMA};",
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA,
                                              DATA_QUALITY_TABLE_1),
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA,
                                              DATA_QUALITY_TABLE_2),
            f"""
                    create table {DATA_QUALITY_SCHEMA}.{MIGRATION_TABLE}
                        (
                            version_num varchar(32) not null
                                constraint alembic_version_pkc
                                    primary key
                        );
                        INSERT INTO {DATA_QUALITY_SCHEMA}.{MIGRATION_TABLE} (version_num) VALUES ('0.1.4-hash');
                    """,
        ]
        self.conn = Connector(TEST_DB_URI)
        for s in sql:
            self.conn.execute(s)
示例#3
0
    def test_execute_consistency_diff(self):
        """
        DIFF check on two tables that hold the same row but declare their
        columns in a different order; the stored result must be `valid`.
        """
        self.inconsistent_colums_left = "user"
        self.inconsistent_colums_right = "user_inconsistent"
        self.conn = Connector(TEST_DB_URI)
        self.consistency_checker = ConsistencyChecker(TEST_DB_URI)

        # Left table is (id, name), right table is (name, id).
        statements = [
            f"""
                CREATE TABLE IF NOT EXISTS tmp.{self.inconsistent_colums_left}(
                  id SERIAL PRIMARY KEY,
                  name text
                )
            """,
            f"""
                CREATE TABLE IF NOT EXISTS tmp.{self.inconsistent_colums_right}(
                  name text,
                  id SERIAL PRIMARY KEY
                )
            """,
            f"""
                INSERT INTO tmp.{self.inconsistent_colums_left}
                VALUES
                    (1, 'John Doe')
                """,
            f"""
                INSERT INTO tmp.{self.inconsistent_colums_right}
                VALUES
                    ('John Doe', 1)
            """,
        ]
        for statement in statements:
            self.conn.execute(statement)

        left_table = {
            "schema_name": "tmp",
            "table_name": self.inconsistent_colums_left,
        }
        right_table = {
            "schema_name": "tmp",
            "table_name": self.inconsistent_colums_right,
        }
        # `result_table_name` and `now` are expected to be set on the test
        # case outside of this method.
        target_table = {
            "schema_name": "data_quality",
            "table_name": self.result_table_name,
        }
        self.consistency_checker.run(
            self.consistency_checker.DIFF,
            left_check_table=left_table,
            right_check_table=right_table,
            result_table=target_table,
            context={"task_ts": self.now},
        )

        query = f"""
                SELECT * from data_quality.consistency_check_{self.result_table_name}
                order by created_at
            """
        records = self.conn.get_records(query)
        first_record = records.fetchone()

        self.assertEqual(first_record["status"], "valid")
示例#4
0
 def __init__(self, left_conn_uri_or_engine, right_conn_uri_or_engine=None):
     """
     Set up the connectors for the left and the right side.

     :param left_conn_uri_or_engine: value handed to `Connector` for the left side
     :param right_conn_uri_or_engine: value handed to `Connector` for the right
     side; when omitted, the right side reuses the left connector object
     """
     self.left_conn_uri_or_engine = left_conn_uri_or_engine
     self.left_conn = Connector(left_conn_uri_or_engine)

     if right_conn_uri_or_engine is not None:
         # A dedicated connector for the right side.
         self.right_conn_uri_or_engine = right_conn_uri_or_engine
         self.right_conn = Connector(right_conn_uri_or_engine)
     else:
         # No right side given - share the left one.
         self.right_conn_uri_or_engine = self.left_conn_uri_or_engine
         self.right_conn = self.left_conn
示例#5
0
 def __init__(self, migrations_map, package_version, url, schema):
     """
     Remember the migration settings and open a connector to the database.

     :param migrations_map: dictionary mapping package versions to their
     migrations, e.g. {'0.1.4':'A', '0.1.5':'B'}
     :param package_version: version of the package that is about to be migrated
     :param url: url of the database where the Alembic migration table lives
     or is going to be created
     :param schema: schema of the database where the Alembic migration table
     lives or is going to be created
     """
     self.url = url
     self.schema = schema
     self.package_version = package_version
     self.versions_migrations = migrations_map
     self.conn = Connector(self.url)
示例#6
0
    def setUp(self):
        """
        Recreate the data-quality schema without any migration table.

        Drops and recreates `DATA_QUALITY_SCHEMA` and creates the two quality
        tables inside it.
        """

        sql = [
            # The comma after the DROP statement was missing, which silently
            # glued DROP and CREATE into one list element.
            f"DROP SCHEMA IF EXISTS {DATA_QUALITY_SCHEMA} CASCADE;",
            f"CREATE SCHEMA IF NOT EXISTS {DATA_QUALITY_SCHEMA};",
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA,
                                              DATA_QUALITY_TABLE_1),
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA,
                                              DATA_QUALITY_TABLE_2),
        ]
        self.conn = Connector(TEST_DB_URI)
        for s in sql:
            self.conn.execute(s)
示例#7
0
def test_set_medians(conn: Connector, monkeypatch):
    """
    `set_medians` must compute the medians of failed/passed counts.

    Five rows are inserted; the SQL marks the July one as "should not be
    taken". The expected medians (10.5 failed, 155 passed) are the medians of
    the four September rows.
    """
    DQBase.metadata.clear()
    result_table = ResultTable(schema_name="data_quality", table_name="t")
    check_cls = create_default_quality_check_class(result_table)
    check_cls.__table__.create(conn.engine)
    check = check_cls()

    seed_rows = """
        insert into data_quality.quality_check_t(failed, passed, task_ts)
        values
          (10, 200, '2018-09-11T13:00:00'),
          (3, 22, '2018-09-10T13:00:00'),
          (11, 110, '2018-09-09T13:00:00'),
          (55, 476, '2018-09-08T13:00:00'),
          (77, 309, '2018-07-12T13:00:00') -- should not be taken
    """
    conn.execute(seed_rows)

    # Swap the `datetime` name seen by the models module for the fake one.
    monkeypatch.setattr("contessa.models.datetime", FakedDatetime)
    check.set_medians(conn)

    assert check.median_30_day_failed == 10.5
    assert check.median_30_day_passed == 155
示例#8
0
def test_set_medians(conn: Connector, monkeypatch):
    """
    `set_medians` must compute the medians of failed/passed counts.

    Five rows are inserted; the SQL marks the July one as "should not be
    taken". The expected medians (10.5 failed, 155 passed) are the medians of
    the four September rows.
    """
    DQBase.metadata.clear()
    result_table = ResultTable(schema_name="data_quality",
                               table_name="t",
                               model_cls=QualityCheck)
    check_cls = create_default_check_class(result_table)
    check_cls.__table__.create(conn.engine)
    check = check_cls()

    seed_rows = """
        insert into data_quality.quality_check_t(attribute, rule_name, rule_type, failed, passed, task_ts, time_filter)
        values
          ('a', 'b', 'not_null', 10, 200, '2018-09-11T13:00:00', 'not_set'),
          ('a', 'b', 'not_null', 3, 22, '2018-09-10T13:00:00', 'not_set'),
          ('a', 'b', 'not_null', 11, 110, '2018-09-09T13:00:00', 'not_set'),
          ('a', 'b', 'not_null', 55, 476, '2018-09-08T13:00:00', 'not_set'),
          ('a', 'b', 'not_null', 77, 309, '2018-07-12T13:00:00', 'not_set') -- should not be taken
    """
    conn.execute(seed_rows)

    # Swap the `datetime` name seen by the models module for the fake one.
    monkeypatch.setattr("contessa.models.datetime", FakedDatetime)
    check.set_medians(conn)

    assert check.median_30_day_failed == 10.5
    assert check.median_30_day_passed == 155
示例#9
0
    def apply(self, conn: Connector):
        """
        Run the rule's sql and turn its result into a boolean series.

        Every returned row has to consist of exactly one value that is a
        bool or None; None is converted to False in the returned series.
        :param conn: connector used to fetch the records
        :raises ValueError: when any row does not have that shape
        :return: pd.Series
        """
        # The records are materialized because they are iterated twice.
        rows = list(conn.get_records(self.sql_with_where))

        for row in rows:
            if len(row) != 1 or not isinstance(row[0], (bool, type(None))):
                raise ValueError(
                    f"Your query for rule `{self.name}` does not return list of booleans or Nones."
                )

        flags = [bool(row[0]) for row in rows]
        return pd.Series(flags)
示例#10
0
def test_upsert_qualitycheck(conn: Connector):
    """
    Upserting the same object twice must keep a single row and update it.

    The model has a unique constraint on `name`; the object is stored with
    price 13, then changed to 42 and stored again.
    """
    from sqlalchemy import Column, DateTime, text, UniqueConstraint
    from sqlalchemy.dialects.postgresql import TEXT, INTEGER, BIGINT

    class A(DQBase):
        __tablename__ = "my_table"
        __table_args__ = (
            UniqueConstraint("name", name="unique_constraint_test"),
        )

        id = Column(BIGINT, primary_key=True)
        name = Column(TEXT, nullable=False)
        price = Column(INTEGER)
        created_at = Column(
            DateTime(timezone=True),
            server_default=text("NOW()"),
            nullable=False,
            index=True,
        )

    conn.ensure_table(A.__table__)
    record = A(name="hello", price=13)
    session = conn.make_session()

    def read_all_rows():
        # Read the whole table, then detach everything and end the transaction.
        found = session.query(A.__table__).all()
        session.expunge_all()
        session.commit()
        return found

    # first write - plain insert
    conn.upsert(objs=[record])
    stored = read_all_rows()
    assert len(stored) == 1
    assert stored[0].price == 13

    # second write of the changed object - has to update, not insert
    record.price = 42
    conn.upsert(objs=[record])
    stored = read_all_rows()
    assert len(stored) == 1
    assert stored[0].price == 42

    session.close()
示例#11
0
def conn():
    """
    Yield a `Connector` to the test database with the working schemas ready.

    The schemas are created before the connector is handed out and dropped
    (with cascade) once the generator is resumed.
    """
    connector = Connector(TEST_DB_URI)
    schema_names = ("tmp", "temporary", "data_quality", "raw")

    for name in schema_names:
        connector.execute(f"create schema if not exists {name}")

    yield connector

    for name in schema_names:
        connector.execute(f"drop schema if exists {name} cascade")
示例#12
0
    def set_medians(self, conn: Connector, days=30):
        """
        Calculate medians of passed/failed quality checks from the last `days` days.

        Rows of this object's class are selected when their `task_ts` lies
        between today's date minus `days` and today's date (both bounds are
        passed to the query as date strings). The results are stored in
        `median_30_day_failed` and `median_30_day_passed`; these attribute
        names stay the same whatever `days` is. Both are set to None when no
        row matches.

        :param conn: connector used to create the session
        :param days: size of the look-back window in days
        """
        now = datetime.today().date()
        past = now - timedelta(days=days)
        cls = self.__class__

        session = conn.make_session()
        try:
            checks = (session.query(cls.failed, cls.passed).filter(
                and_(cls.task_ts <= str(now), cls.task_ts >= str(past))).all())
            session.expunge_all()
            session.commit()
        finally:
            # Close the session even when the query or commit fails, so it
            # is never left open.
            session.close()

        failed = [ch.failed for ch in checks]
        self.median_30_day_failed = median(failed) if failed else None

        passed = [ch.passed for ch in checks]
        self.median_30_day_passed = median(passed) if passed else None
示例#13
0
class TestReturnResults(unittest.TestCase):
    """Consistency check that returns its result instead of storing it."""

    def setUp(self):
        """
        Build the fixtures for the test.

        Recreates the `tmp` schema, makes sure the `hello` schema exists and
        creates one booking table in each of them. The left table
        (`tmp.raw_booking`) receives four rows while the right table
        (`hello.booking`) receives only the first three of them.
        """
        self.left_table_name = "raw_booking"
        self.right_table_name = "booking"
        iso_stamp = FakedDatetime.now().isoformat()
        self.ts_nodash = iso_stamp.replace("-", "").replace(":", "")
        self.now = FakedDatetime.now()

        # Statements are executed one by one, in this exact order.
        statements = [
            "DROP SCHEMA if exists tmp CASCADE;",
            "CREATE SCHEMA IF NOT EXISTS tmp;",
            "CREATE SCHEMA IF NOT EXISTS hello;",
            f"""
                    CREATE TABLE IF NOT EXISTS tmp.{self.left_table_name}(
                      id SERIAL PRIMARY KEY,
                      src text,
                      dst text,
                      price int,
                      turnover_after_refunds double precision,
                      initial_price double precision,
                      created_at timestamptz
                    )
                """,
            f"""
                    CREATE TABLE IF NOT EXISTS hello.{self.right_table_name}(
                      id SERIAL PRIMARY KEY,
                      src text,
                      dst text,
                      price int,
                      turnover_after_refunds double precision,
                      initial_price double precision,
                      created_at timestamptz
                    )
                """,
            f"""
                INSERT INTO tmp.{self.left_table_name}
                    (src, dst, price, turnover_after_refunds, initial_price, created_at)
                VALUES
                    ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                    (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                    ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
                    ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
                """,
            f"""
                INSERT INTO hello.{self.right_table_name}
                    (src, dst, price, turnover_after_refunds, initial_price, created_at)
                VALUES
                    ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                    (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                    ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00')
            """,
        ]

        self.conn = Connector(TEST_DB_URI)
        for statement in statements:
            self.conn.execute(statement)

        self.consistency_checker = ConsistencyChecker(TEST_DB_URI)

    def tearDown(self):
        """
        Drop both schemas used by the test and reset the model metadata.
        """
        for schema in ("tmp", "hello"):
            self.conn.execute(f"DROP schema {schema} CASCADE;")
        DQBase.metadata.clear()

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency(self):
        """
        COUNT check of the four-row left table against the three-row right
        table must come back as `invalid` and carry both full table names.
        """
        left_table = {
            "schema_name": "tmp",
            "table_name": self.left_table_name
        }
        right_table = {
            "schema_name": "hello",
            "table_name": self.right_table_name,
        }

        outcome = self.consistency_checker.run(
            self.consistency_checker.COUNT,
            left_check_table=left_table,
            right_check_table=right_table,
            context={"task_ts": self.now},
        )

        self.assertEqual(outcome.status, "invalid")
        self.assertEqual(outcome.context["left_table_name"], "tmp.raw_booking")
        self.assertEqual(outcome.context["right_table_name"], "hello.booking")
示例#14
0
class TestMigrationsResolverInit(unittest.TestCase):
    """
    MigrationsResolver behaviour when the schema holds quality tables but no
    migration table has been created yet.
    """

    def setUp(self):
        """
        Recreate the data-quality schema with the two quality tables only.
        """

        sql = [
            # The comma after the DROP statement was missing, which silently
            # glued DROP and CREATE into one list element.
            f"DROP SCHEMA IF EXISTS {DATA_QUALITY_SCHEMA} CASCADE;",
            f"CREATE SCHEMA IF NOT EXISTS {DATA_QUALITY_SCHEMA};",
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA,
                                              DATA_QUALITY_TABLE_1),
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA,
                                              DATA_QUALITY_TABLE_2),
        ]
        self.conn = Connector(TEST_DB_URI)
        for s in sql:
            self.conn.execute(s)

    def tearDown(self):
        """
        Drop all created tables.
        """
        self.conn.execute(f"DROP schema {DATA_QUALITY_SCHEMA} CASCADE;")
        DQBase.metadata.clear()

    def _make_resolver(self, versions_migrations, package_version):
        """Build a resolver for the test schema with the given versions map."""
        return MigrationsResolver(versions_migrations, package_version,
                                  SQLALCHEMY_URL, DATA_QUALITY_SCHEMA)

    def test_migration_table_exists_init(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}

        m = self._make_resolver(versions_migrations, "0.1.4")
        migration_table_exists = m.migrations_table_exists()

        assert migration_table_exists is False

    def test_get_current_migration_init(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}

        m = self._make_resolver(versions_migrations, "0.1.4")
        current = m.get_applied_migration()

        assert current is None

    def test_is_on_head_init(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}

        m = self._make_resolver(versions_migrations, "0.1.4")
        is_on_head = m.is_on_head()

        # The original line lacked `assert`, so the comparison result was
        # thrown away and the test could never fail.
        assert is_on_head is False

    def test_get_migrations_to_head__is_before_init(self):
        versions_migrations = {
            "0.1.2": "0.1.2-hash",
            "0.1.3": "0.1.3-hash",
            "0.1.4": "0.1.4-hash",
            "0.1.5": "0.1.5-hash",
            "0.1.6": "0.1.6-hash",
            "0.1.7": "0.1.7-hash",
        }

        m = self._make_resolver(versions_migrations, "0.1.7")
        migrations = m.get_migration_to_head()
        # Strings are compared by value; `is` checks object identity and
        # only passes when the interpreter happens to share the object.
        assert migrations[0] == "upgrade"
        assert migrations[1] == "0.1.7-hash"
示例#15
0
class ContessaRunner:
    """
    Runs quality-check rules against a table and, when a result table is
    given, stores the outcome in it.
    """

    # Model class passed to `ResultTable` when results are stored.
    model_cls = QualityCheck

    def __init__(self, conn_uri_or_engine, special_qc_map=None):
        """
        :param conn_uri_or_engine: value handed to `Connector`
        :param special_qc_map: optional mapping of result table fullname to
        the quality check class that should be used for it
        """
        self.conn_uri_or_engine = conn_uri_or_engine
        self.conn = Connector(conn_uri_or_engine)

        # todo - allow cfg
        self.special_qc_map = special_qc_map or {}

    def run(
        self,
        raw_rules: List[Dict[str, str]],
        check_table: Dict,
        result_table: Optional[
            Dict] = None,  # todo - docs for quality name, maybe defaults..
        context: Optional[Dict] = None,
    ) -> List[Union[CheckResult, QualityCheck]]:
        """
        Normalize and build the rules, execute them and return one result
        object per rule.

        :param raw_rules: rule definitions as given by the user
        :param check_table: keyword arguments for `Table` describing the checked table
        :param result_table: keyword arguments for `ResultTable`; when given,
        the result table is ensured to exist and the results are upserted into it
        :param context: user context that overrides the default context
        :return: list of result objects - `CheckResult` instances when no
        result table is given, otherwise instances of the quality check class
        """
        check_table = Table(**check_table)
        context = self.get_context(check_table, context)

        normalized_rules = self.normalize_rules(raw_rules)
        refresh_executors(check_table, self.conn, context)

        if result_table:
            result_table = ResultTable(**result_table,
                                       model_cls=self.model_cls)
            quality_check_class = self.get_quality_check_class(result_table)
            self.conn.ensure_table(quality_check_class.__table__)
        else:
            quality_check_class = CheckResult

        rules = self.build_rules(normalized_rules)
        objs = self.do_quality_checks(quality_check_class, rules, context)

        if result_table:
            self.conn.upsert(objs)
        return objs

    @staticmethod
    def get_context(check_table: Table,
                    context: Optional[Dict] = None) -> Dict:
        """
        Construct context to pass to executors. User context overrides defaults.
        """
        ctx_defaults = {
            "table_fullname": check_table.fullname,
            "task_ts": datetime.now(),  # todo - is now() ok ?
        }
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def normalize_rules(self, raw_rules):
        """Delegate normalization of the raw rule definitions to `RuleNormalizer`."""
        return RuleNormalizer.normalize(raw_rules)

    def do_quality_checks(self,
                          dq_cls,
                          rules: List[Rule],
                          context: Optional[Dict] = None):
        """
        Run quality check for all rules. Use `dq_cls` to construct objects that will be inserted
        afterwards.
        :return: list with one `dq_cls` object per rule, in the order of `rules`
        """
        return [self.apply_rule(context, dq_cls, rule) for rule in rules]

    def apply_rule(self, context, dq_cls, rule):
        """
        Execute a single rule and wrap its results in a fresh `dq_cls` object.
        """
        e = get_executor(rule)
        logging.info(f"Executing rule `{rule}`.")
        results = e.execute(rule)
        obj = dq_cls()
        obj.init_row(rule, results, self.conn, context)
        return obj

    @staticmethod
    def build_rules(normalized_rules):
        """
        Construct rules classes from user definition that are dicts.
        Raises if there are bad arguments for a certain rule.
        :return: list of Rule objects
        """
        ret = []
        for rule_def in normalized_rules:
            rule_cls = ContessaRunner.pick_rule_cls(rule_def)
            try:
                r = rule_cls(**rule_def)
            except Exception as e:
                # An exception may carry no args; indexing `e.args[0]`
                # blindly would raise IndexError here and hide the real error.
                detail = e.args[0] if e.args else e
                logging.error(f"For rule `{rule_cls.__name__}`. {detail}")
                raise
            else:
                ret.append(r)
        return ret

    @staticmethod
    def pick_rule_cls(rule_def):
        """
        Get rule class based on its type that was input by user.
        :param rule_def: dict
        :return: Rule class
        """
        return get_rule_cls(rule_def["type"])

    def get_quality_check_class(self, result_table: ResultTable):
        """
        QualityCheck can be different, e.g. `special_table` has specific quality_check.
        Or kind of generic one that computes number of passed/failed objects etc.
        So determine if is special or not and return the class.
        :return: QualityCheck cls
        """
        if result_table.fullname in self.special_qc_map:
            quality_check_class = self.special_qc_map[result_table.fullname]
            logging.info(
                f"Using {quality_check_class.__name__} as quality check class."
            )
        else:
            quality_check_class = create_default_check_class(result_table)
            logging.info("Using default QualityCheck class.")
        return quality_check_class
示例#16
0
class TestMigrationsResolver(unittest.TestCase):
    """
    MigrationsResolver behaviour when the migration table exists and records
    `0.1.4-hash` as the applied migration.
    """

    def setUp(self):
        """
        Recreate the data-quality schema with the two quality tables and a
        migration table holding the single version `0.1.4-hash`.
        """

        sql = [
            # The comma after the DROP statement was missing, which silently
            # glued DROP and CREATE into one list element.
            f"DROP SCHEMA IF EXISTS {DATA_QUALITY_SCHEMA} CASCADE;",
            f"CREATE SCHEMA IF NOT EXISTS {DATA_QUALITY_SCHEMA};",
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA,
                                              DATA_QUALITY_TABLE_1),
            get_quality_table_creation_script(DATA_QUALITY_SCHEMA,
                                              DATA_QUALITY_TABLE_2),
            f"""
                    create table {DATA_QUALITY_SCHEMA}.{MIGRATION_TABLE}
                        (
                            version_num varchar(32) not null
                                constraint alembic_version_pkc
                                    primary key
                        );
                        INSERT INTO {DATA_QUALITY_SCHEMA}.{MIGRATION_TABLE} (version_num) VALUES ('0.1.4-hash');
                    """,
        ]
        self.conn = Connector(TEST_DB_URI)
        for s in sql:
            self.conn.execute(s)

    def tearDown(self):
        """
        Reset the model metadata. The schema itself is not dropped here;
        setUp drops and recreates it before every test.
        """
        DQBase.metadata.clear()

    def _make_resolver(self, versions_migrations, package_version,
                       schema=None):
        """
        Build a resolver with the given versions map; `schema` defaults to
        the test schema.
        """
        if schema is None:
            schema = DATA_QUALITY_SCHEMA
        return MigrationsResolver(versions_migrations, package_version,
                                  SQLALCHEMY_URL, schema)

    def test_schema_exists(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}

        m = self._make_resolver(versions_migrations, "0.1.4")
        schema_exists = m.schema_exists()

        assert schema_exists

        m = self._make_resolver(versions_migrations, "0.1.4",
                                schema="not_existing_schema")
        schema_exists = m.schema_exists()

        assert schema_exists is False

    def test_migration_table_exists(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}

        m = self._make_resolver(versions_migrations, "0.1.4")
        migration_table_exists = m.migrations_table_exists()

        assert migration_table_exists

    def test_get_current_migration(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}

        m = self._make_resolver(versions_migrations, "0.1.4")
        current = m.get_applied_migration()

        assert current == "0.1.4-hash"

    def test_is_on_head(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}

        m = self._make_resolver(versions_migrations, "0.1.4")
        is_on_head = m.is_on_head()

        assert is_on_head

    def test_is_on_head_no_on_head(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}

        m = self._make_resolver(versions_migrations, "0.1.5")
        is_on_head = m.is_on_head()

        assert is_on_head is False

    def test_is_on_head_with_fallback(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.6": "0.1.6-hash"}

        m = self._make_resolver(versions_migrations, "0.1.5")
        is_on_head = m.is_on_head()

        assert is_on_head

    def test_get_migrations_to_head__already_on_head(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}

        m = self._make_resolver(versions_migrations, "0.1.4")
        migrations = m.get_migration_to_head()
        assert migrations is None

    # In the tests below strings are compared with `==`; the original `is`
    # checked object identity, which only passes when the interpreter happens
    # to share the string object.

    def test_get_migrations_to_head__package_greather_than_map_max(self):
        versions_migrations = {"0.1.4": "0.1.4-hash", "0.1.5": "0.1.5-hash"}

        m = self._make_resolver(versions_migrations, "0.1.6")
        migrations = m.get_migration_to_head()
        assert migrations[0] == "upgrade"
        assert migrations[1] == "0.1.5-hash"

    def test_get_migrations_to_head__is_down_from_head(self):
        versions_migrations = {
            "0.1.2": "0.1.2-hash",
            "0.1.3": "0.1.3-hash",
            "0.1.4": "0.1.4-hash",
            "0.1.5": "0.1.5-hash",
            "0.1.6": "0.1.6-hash",
            "0.1.7": "0.1.7-hash",
        }

        m = self._make_resolver(versions_migrations, "0.1.7")
        migrations = m.get_migration_to_head()
        assert migrations[0] == "upgrade"
        assert migrations[1] == "0.1.7-hash"

    def test_get_migrations_to_head__is_down_from_head_with_fallback(self):
        versions_migrations = {
            "0.1.2": "0.1.2-hash",
            "0.1.3": "0.1.3-hash",
            "0.1.4": "0.1.4-hash",
            "0.1.5": "0.1.5-hash",
            "0.1.8": "0.1.8-hash",
            "0.1.9": "0.1.9-hash",
        }

        m = self._make_resolver(versions_migrations, "0.1.7")
        migrations = m.get_migration_to_head()
        assert migrations[0] == "upgrade"
        assert migrations[1] == "0.1.5-hash"

    def test_get_migrations_to_head__is_up_from_head(self):
        versions_migrations = {
            "0.1.2": "0.1.2-hash",
            "0.1.3": "0.1.3-hash",
            "0.1.4": "0.1.4-hash",
            "0.1.5": "0.1.5-hash",
            "0.1.6": "0.1.6-hash",
            "0.1.7": "0.1.7-hash",
        }

        m = self._make_resolver(versions_migrations, "0.1.2")
        migrations = m.get_migration_to_head()
        assert migrations[0] == "downgrade"
        assert migrations[1] == "0.1.2-hash"

    def test_get_migrations_to_head__is_up_from_head_with_fallback(self):
        versions_migrations = {
            "0.1.1": "0.1.1-hash",
            "0.1.3": "0.1.3-hash",
            "0.1.4": "0.1.4-hash",
            "0.1.5": "0.1.5-hash",
            "0.1.6": "0.1.6-hash",
            "0.1.7": "0.1.7-hash",
        }

        m = self._make_resolver(versions_migrations, "0.1.2")
        migrations = m.get_migration_to_head()
        assert migrations[0] == "downgrade"
        assert migrations[1] == "0.1.1-hash"
示例#17
0
    def setUp(self):
        """
        Build the fixtures for the runner tests.

        Recreates the `tmp` and `data_quality` schemas, makes sure the `hello`
        schema exists, then creates the checked table and its timestamped
        twin in `tmp` and fills each of them with four rows.
        """
        self.table_name = "booking_all_v2"
        iso_stamp = FakedDatetime.now().isoformat()
        self.ts_nodash = iso_stamp.replace("-", "").replace(":", "")
        self.tmp_table_name = f"{self.table_name}_{self.ts_nodash}"
        self.now = FakedDatetime.now()

        # Statements are executed one by one, in this exact order.
        statements = [
            "DROP SCHEMA if exists tmp CASCADE;",
            "DROP SCHEMA if exists data_quality CASCADE;",
            "CREATE SCHEMA IF NOT EXISTS tmp;",
            "CREATE SCHEMA IF NOT EXISTS data_quality;",
            "CREATE SCHEMA IF NOT EXISTS hello;",
            f"""
                    CREATE TABLE IF NOT EXISTS tmp.{self.table_name}(
                      id SERIAL PRIMARY KEY,
                      src text,
                      dst text,
                      price int,
                      turnover_after_refunds double precision,
                      initial_price double precision,
                      created_at timestamptz
                    )
                """,
            f"""
                    CREATE TABLE IF NOT EXISTS tmp.{self.tmp_table_name}(
                      id SERIAL PRIMARY KEY,
                      src text,
                      dst text,
                      price int,
                      turnover_after_refunds double precision,
                      initial_price double precision,
                      created_at timestamptz
                    )
                """,
            f"""
                    INSERT INTO tmp.{self.tmp_table_name}
                        (src, dst, price, turnover_after_refunds, initial_price, created_at)
                    VALUES
                        ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                        -- this is older than 30 days.
                        -- not in stats when time_filter = `created_at`
                        (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                        ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
                        ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
                """,
            f"""
                INSERT INTO tmp.{self.table_name}
                    (src, dst, price, turnover_after_refunds, initial_price, created_at)
                VALUES
                    ('BTS', NULL, 1, 100, 11, '2018-09-12T13:00:00'),
                    -- this is older than 30 days.
                    -- not in stats when time_filter = `created_at`
                    (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T13:00:00'),
                    ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T13:00:00'),
                    ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T13:00:00')
            """,
        ]

        self.conn = Connector(TEST_DB_URI)
        for statement in statements:
            self.conn.execute(statement)

        self.contessa_runner = ContessaRunner(TEST_DB_URI)
示例#18
0
class TestDataQualityOperator(unittest.TestCase):
    """
    Integration tests running `ContessaRunner` against the test database.

    Every test applies a list of rules to the table `tmp.<tmp_table_name>` filled
    in `setUp` and then inspects the rows written to the result table.
    """

    def setUp(self):
        """
        Init a temporary table with some data.

        Every schema the tests write to (`tmp`, `data_quality`, `hello`) is dropped
        first, so leftovers of an aborted run cannot leak into this one.
        """
        self.table_name = "booking_all_v2"
        self.ts_nodash = (
            FakedDatetime.now().isoformat().replace("-", "").replace(":", "")
        )
        self.tmp_table_name = f"{self.table_name}_{self.ts_nodash}"
        self.now = FakedDatetime.now()

        sql = [
            "DROP SCHEMA if exists tmp CASCADE;",
            "DROP SCHEMA if exists data_quality CASCADE;",
            # `hello` is written to by test_different_schema, reset it as well
            "DROP SCHEMA if exists hello CASCADE;",
            "CREATE SCHEMA IF NOT EXISTS tmp;",
            "CREATE SCHEMA IF NOT EXISTS data_quality;",
            "CREATE SCHEMA IF NOT EXISTS hello;",
            f"""
                    CREATE TABLE IF NOT EXISTS tmp.{self.table_name}(
                      id SERIAL PRIMARY KEY,
                      src text,
                      dst text,
                      price int,
                      turnover_after_refunds double precision,
                      initial_price double precision,
                      created_at timestamptz
                    )
                """,
            f"""
                    CREATE TABLE IF NOT EXISTS tmp.{self.tmp_table_name}(
                      id SERIAL PRIMARY KEY,
                      src text,
                      dst text,
                      price int,
                      turnover_after_refunds double precision,
                      initial_price double precision,
                      created_at timestamptz
                    )
                """,
            f"""
                    INSERT INTO tmp.{self.tmp_table_name}
                        (src, dst, price, turnover_after_refunds, initial_price, created_at)
                    VALUES
                        ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                        -- this is older than 30 days.
                        -- not in stats when time_filter = `created_at`
                        (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                        ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
                        ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
                """,
            f"""
                INSERT INTO tmp.{self.table_name}
                    (src, dst, price, turnover_after_refunds, initial_price, created_at)
                VALUES
                    ('BTS', NULL, 1, 100, 11, '2018-09-12T13:00:00'),
                    -- this is older than 30 days.
                    -- not in stats when time_filter = `created_at`
                    (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T13:00:00'),
                    ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T13:00:00'),
                    ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T13:00:00')
            """,
        ]
        self.conn = Connector(TEST_DB_URI)
        for s in sql:
            self.conn.execute(s)

        self.contessa_runner = ContessaRunner(TEST_DB_URI)

    def tearDown(self):
        """
        Drop all created tables.
        """
        # IF EXISTS keeps the cleanup from failing when a schema is already gone.
        for schema in ("tmp", "data_quality", "hello"):
            self.conn.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE;")
        DQBase.metadata.clear()

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_tmp(self):
        """
        Run four different rule types and check the stored passed/failed counts.
        """
        sql = """
            SELECT
              CASE WHEN src = 'BTS' and dst is null THEN false ELSE true END as res
            from {{ table_fullname }}
        """
        rules = [
            {
                "name": "not_null_name",
                "type": "not_null",
                "column": "dst",
                "time_filter": "created_at",
            },
            {
                "name": "gt_name",
                "type": "gt",
                "column": "price",
                "value": 10,
                "time_filter": "created_at",
            },
            {
                "name": "sql_name",
                "type": "sql",
                "sql": sql,
                "column": "src_dst",
                "description": "test sql rule",
            },
            {
                "name": "not_name",
                "type": "not",
                "column": "src",
                "value": "dst"
            },
        ]
        self.contessa_runner.run(
            check_table={
                "schema_name": "tmp",
                "table_name": self.tmp_table_name
            },
            result_table={
                "schema_name": "data_quality",
                "table_name": self.table_name
            },
            raw_rules=rules,
            context={"task_ts": self.now},
        )

        rows = self.conn.get_records(f"""
            SELECT * from data_quality.quality_check_{self.table_name}
            order by created_at
        """).fetchall()
        # one result row per rule
        self.assertEqual(len(rows), 4)

        # The January row is excluded by the time filter (see setUp); of the
        # remaining three rows only the 'BTS' one has a NULL dst.
        notnull_rule = rows[0]
        self.assertEqual(notnull_rule["failed"], 1)
        self.assertEqual(notnull_rule["passed"], 2)
        self.assertEqual(notnull_rule["attribute"], "dst")
        self.assertEqual(notnull_rule["task_ts"].timestamp(),
                         self.now.timestamp())

        # None of the time-filtered prices (1, 4, 4) is greater than 10.
        gt_rule = rows[1]
        self.assertEqual(gt_rule["failed"], 3)
        self.assertEqual(gt_rule["passed"], 0)
        self.assertEqual(gt_rule["attribute"], "price")

        # The custom SQL has no time condition, so all four rows are evaluated.
        sql_rule = rows[2]
        self.assertEqual(sql_rule["failed"], 1)
        self.assertEqual(sql_rule["passed"], 3)
        self.assertEqual(sql_rule["attribute"], "src_dst")

        # Only the 'VIE'/'VIE' row is expected to break "src is not dst".
        not_column_rule = rows[3]
        self.assertEqual(not_column_rule["failed"], 1)
        self.assertEqual(not_column_rule["passed"], 3)
        self.assertEqual(not_column_rule["attribute"], "src")

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_dst(self):
        """
        Check that `task_ts` from the context can be used inside a custom SQL rule.
        """
        sql = """
            SELECT
              CASE WHEN src = 'BTS' AND dst is null THEN false ELSE true END as res
            FROM {{ table_fullname }}
            WHERE created_at BETWEEN
                timestamptz '{{task_ts}}' - INTERVAL '1 hour'
                AND
                timestamptz '{{task_ts}}' + INTERVAL '1 hour'
        """
        rules = [
            {
                "name": "not_null_name",
                "type": "not_null",
                "column": "dst",
                "time_filter": "created_at",
            },
            {
                "name": "sql_name",
                "type": "sql",
                "sql": sql,
                "column": "src_dst",
                "description": "test sql rule",
            },
        ]
        self.contessa_runner.run(
            check_table={
                "schema_name": "tmp",
                "table_name": self.tmp_table_name
            },
            result_table={
                "schema_name": "data_quality",
                "table_name": self.table_name
            },
            raw_rules=rules,
            context={"task_ts": self.now},
        )

        rows = self.conn.get_records(f"""
            SELECT * from data_quality.quality_check_{self.table_name}
            order by created_at
        """).fetchall()
        self.assertEqual(len(rows), 2)

        notnull_rule = rows[0]
        self.assertEqual(notnull_rule["failed"], 1)
        self.assertEqual(notnull_rule["passed"], 2)
        self.assertEqual(notnull_rule["attribute"], "dst")
        self.assertEqual(notnull_rule["task_ts"].timestamp(),
                         self.now.timestamp())

        # The +/- 1 hour window is expected to contain the failing 'BTS' row only.
        sql_rule = rows[1]
        self.assertEqual(sql_rule["failed"], 1)
        self.assertEqual(sql_rule["passed"], 0)
        self.assertEqual(sql_rule["attribute"], "src_dst")

    def test_different_schema(self):
        """
        Results can be stored in a schema other than `data_quality`.
        """
        rules = [{
            "name": "not_nul_name",
            "type": "not_null",
            "column": "dst",
            "time_filter": "created_at",
        }]
        self.contessa_runner.run(
            check_table={
                "schema_name": "tmp",
                "table_name": self.tmp_table_name
            },
            result_table={
                "schema_name": "hello",
                "table_name": "abcde"
            },
            raw_rules=rules,
            context={"task_ts": self.now},
        )
        rows = self.conn.get_records(f"""
                SELECT 1 from hello.quality_check_abcde
            """).fetchall()
        self.assertEqual(len(rows), 1)
示例#19
0
class ConsistencyChecker:
    """
    Checks consistency of the sync between two tables.

    Two methods are available: `COUNT` compares the number of rows returned for
    each side, `DIFF` compares the sets of rows returned for each side.
    """

    model_cls = ConsistencyCheck

    COUNT = "count"
    DIFF = "difference"

    def __init__(self, left_conn_uri_or_engine, right_conn_uri_or_engine=None):
        """
        :param left_conn_uri_or_engine: handed to `Connector` for the left table
        :param right_conn_uri_or_engine: handed to `Connector` for the right table,
        when omitted the left connection is reused for the right side
        """
        self.left_conn_uri_or_engine = left_conn_uri_or_engine
        self.left_conn = Connector(left_conn_uri_or_engine)
        if right_conn_uri_or_engine is None:
            self.right_conn_uri_or_engine = self.left_conn_uri_or_engine
            self.right_conn = self.left_conn
        else:
            self.right_conn_uri_or_engine = right_conn_uri_or_engine
            self.right_conn = Connector(right_conn_uri_or_engine)

    def run(
        self,
        method: str,
        left_check_table: Dict,
        right_check_table: Dict,
        result_table: Optional[Dict] = None,
        columns: Optional[List[str]] = None,
        time_filter: str = TIME_FILTER_DEFAULT,
        left_custom_sql: Optional[str] = None,
        right_custom_sql: Optional[str] = None,
        context: Optional[Dict] = None,
    ) -> Union[CheckResult, ConsistencyCheck]:
        """
        Compare two tables and optionally store the outcome.

        :param method: `COUNT` or `DIFF`
        :param left_check_table: keyword arguments for `Table` of the left side
        :param right_check_table: keyword arguments for `Table` of the right side
        :param result_table: keyword arguments for `ResultTable`; when given, the
        result is upserted through the right connection
        :param columns: columns to compare instead of the default selection
        :param time_filter: passed to `compose_where_time_filter` for default queries
        :param left_custom_sql: query to run on the left side instead of the default
        :param right_custom_sql: query to run on the right side instead of the default
        :param context: values overriding the default context
        :return: the result dict when `result_table` is given (NOTE: this differs
        from the declared return annotation), otherwise a `CheckResult`
        initialised from it
        :raises ValueError: if both custom sqls are combined with `columns` or
        `time_filter`
        """
        if left_custom_sql and right_custom_sql:
            if columns or time_filter:
                raise ValueError(
                    "When using custom sqls you cannot change 'columns' or 'time_filter' attribute"
                )

        left_check_table = Table(**left_check_table)
        right_check_table = Table(**right_check_table)
        context = self.get_context(left_check_table, right_check_table,
                                   context)

        result = self.do_consistency_check(
            method,
            columns,
            time_filter,
            left_check_table,
            right_check_table,
            left_custom_sql,
            right_custom_sql,
            context,
        )

        if result_table:
            result_table = ResultTable(**result_table,
                                       model_cls=self.model_cls)
            quality_check_class = create_default_check_class(result_table)
            self.right_conn.ensure_table(quality_check_class.__table__)
            self.upsert(quality_check_class, result)
            return result

        obj = CheckResult()
        obj.init_row_consistency(**result)
        return obj

    @staticmethod
    def get_context(
        left_check_table: Table,
        right_check_table: Table,
        context: Optional[Dict] = None,
    ) -> Dict:
        """
        Construct context to pass to executors. User context overrides defaults.
        """
        ctx_defaults = {
            "left_table_fullname": left_check_table.fullname,
            "right_table_fullname": right_check_table.fullname,
            "task_ts": datetime.now(),
        }
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def do_consistency_check(
        self,
        method: str,
        columns: Optional[List[str]],
        time_filter: str,
        left_check_table: Table,
        right_check_table: Table,
        left_sql: Optional[str] = None,
        right_sql: Optional[str] = None,
        context: Optional[Dict] = None,
    ):
        """
        Query both sides, compare the results and describe the outcome.

        A default query is built for every side that has no custom sql.

        :return: dict with the keys `check`, `status`, `left_table_name`,
        `right_table_name`, `time_filter` and `context`
        :raises NotImplementedError: for an unknown `method`
        """
        # The selected expression is only needed when a default query is built.
        if not left_sql or not right_sql:
            if method == self.COUNT:
                if columns:
                    column = f"count({', '.join(columns)})"
                else:
                    column = "count(*)"
            elif method == self.DIFF:
                if columns:
                    column = ", ".join(columns)
                else:
                    # List the columns explicitly in case column order of compared tables is not the same.
                    column = ", ".join(
                        sorted(
                            self.right_conn.get_column_names(
                                right_check_table.fullname)))
            else:
                raise NotImplementedError(f"Method {method} not implemented")

        if not left_sql:
            left_sql = self.construct_default_query(left_check_table.fullname,
                                                    column, time_filter,
                                                    context)
        left_result = self.run_query(self.left_conn, left_sql, context)
        if not right_sql:
            right_sql = self.construct_default_query(
                right_check_table.fullname, column, time_filter, context)
        right_result = self.run_query(self.right_conn, right_sql, context)

        valid, passed, failed = self.compare_results(left_result, right_result,
                                                     method)

        return {
            "check": {
                "type": method,
                "description": "",
                "name": "consistency",
                "passed": passed,
                "failed": failed,
            },
            "status": "valid" if valid else "invalid",
            "left_table_name": left_check_table.fullname,
            "right_table_name": right_check_table.fullname,
            "time_filter": time_filter,
            "context": context,
        }

    def compare_results(self, left_result, right_result, method):
        """
        Compare the query results of both sides.

        :param left_result: rows returned for the left side
        :param right_result: rows returned for the right side
        :param method: `COUNT` or `DIFF`
        :return: tuple `(valid, passed, failed)`; `valid` is True when nothing failed
        :raises NotImplementedError: for an unknown `method`
        """
        if method == self.COUNT:
            left_count = left_result[0][0]
            right_count = right_result[0][0]
            passed = min(left_count, right_count)
            # Rows present on one side only, whichever side is the bigger one.
            # A plain subtraction would go negative when the right side is bigger.
            failed = abs(left_count - right_count)
            return failed == 0, passed, failed
        elif method == self.DIFF:
            left_set = set(left_result)
            right_set = set(right_result)
            common = left_set.intersection(right_set)
            passed = len(common)
            failed = (len(left_set) - len(common)) + (len(right_set) -
                                                      len(common))
            return failed == 0, passed, failed
        else:
            raise NotImplementedError(f"Method {method} not implemented")

    def construct_default_query(self, table_name, column, time_filter,
                                context):
        """
        Build `SELECT <column> FROM <table_name>` with an optional time condition.

        The WHERE clause is added only when `compose_where_time_filter` returns
        something truthy for `time_filter` and `context["task_ts"]`.
        """
        time_filter = compose_where_time_filter(time_filter,
                                                context["task_ts"])
        query = f"""
            SELECT {column}
            FROM {table_name}
            {f'WHERE {time_filter}' if time_filter else ''}
        """
        return query

    def render_sql(self, sql, context):
        """
        Replace some parameters in query.
        :return str, formatted sql
        """
        t = jinja2.Template(sql)
        rendered = t.render(**context)
        return rendered

    def run_query(self, conn: Connector, query: str, context):
        """
        Render `query` with `context`, log it and run it on `conn`.

        :return: list with the `_row` value of every returned record
        """
        query = self.render_sql(query, context)
        logging.info(query)
        # NOTE(review): `_row` is a private attribute of the record - confirm it
        # is available in the database library version in use.
        result = [r._row for r in conn.get_records(query)]
        return result

    def upsert(self, dc_cls, result):
        """
        Create a `dc_cls` object from `result` and upsert it via the right connection.
        """
        obj = dc_cls()
        obj.init_row(**result)
        self.right_conn.upsert([
            obj,
        ])
示例#20
0
class MigrationsResolver:
    """
    Migrations helper class for the Contessa migrations.
    """
    def __init__(self, migrations_map, package_version, url, schema):
        """
        :param migrations_map: map of package versions and their migrations.
        In form of dictionary {'0.1.4':'A', '0.1.5':'B'}
        :param package_version: the version of the package planned to be migrated
        :param url: the database url where the Alembic migration table is present or planned to be created
        :param schema: the database schema where the Alembic migration table is present or planned to be created
        """
        self.versions_migrations = migrations_map
        self.package_version = package_version
        self.url = url
        self.schema = schema
        self.conn = Connector(self.url)

    def schema_exists(self):
        """
        Check if schema with the Alembic migration table exists.
        :return: Return true if schema with the Alembic migration exists.
        """
        result = self.conn.get_records(f"""
            SELECT EXISTS (
               SELECT 1
               FROM   information_schema.schemata
               WHERE  schema_name = '{self.schema}'
            );
            """)
        return result.first()[0]

    def migrations_table_exists(self):
        """
        Check if the Alembic versions table exists.
        :return: Return true if the table is listed in `information_schema.tables`.
        """
        result = self.conn.get_records(f"""
            SELECT EXISTS (
               SELECT 1
               FROM   information_schema.tables
               WHERE  table_schema = '{self.schema}'
               AND    table_name = '{MIGRATION_TABLE}'
           );
            """)
        return result.first()[0]

    def get_applied_migration(self):
        """
        Get the current applied migration in the target schema.
        :return: first column of the first row of the migration table, or None
        if the table does not exist or holds no row.
        """
        if not self.migrations_table_exists():
            return None
        row = self.conn.get_records(
            f"select * from {self.schema}.{MIGRATION_TABLE}").first()
        # presumably `first()` gives None for an empty table (e.g. nothing was
        # applied yet) - report that the same way as a missing table.
        if row is None:
            return None
        return row[0]

    def is_on_head(self):
        """
        Check if the current applied migration is valid for the Contessa version.
        """
        current = self.get_applied_migration()
        if current is None:
            # no migration table or no applied migration
            return False

        fallback_package_version = self.get_fallback_version()
        return self.versions_migrations[fallback_package_version] == current

    def get_fallback_version(self):
        """
        Get fallback version in the case of non existing migration for the Contessa package version.
        Returns the last package version containing the migration.

        In the case we have this migrations versions map

        versions_migrations = {
            "0.1.4": "54f8985b0ee5",
            "0.1.5": "480e6618700d",
            "0.1.8": "3w4er8y50yyd",
            "0.1.9": "034hfa8943hr",
        }

        and we ask for the fallback version of our current version e.g. 0.1.7., we get 0.1.5. as a result because it is
        the last version with specified migration before ours.

        The keys of the map are expected to be ordered from the oldest version
        to the newest one.
        """
        if self.package_version in self.versions_migrations:
            return self.package_version

        keys = list(self.versions_migrations)
        wanted = pv(self.package_version)
        if wanted < pv(keys[0]):
            return keys[0]
        if wanted > pv(keys[-1]):
            return keys[-1]

        result = keys[0]
        for k in keys[1:]:
            if pv(k) > wanted:
                break
            result = k
        # Also reached when no key is newer than the wanted version, e.g. when
        # it is spelled differently from an equal key. Never fall through to None.
        return result

    def get_migration_to_head(self):
        """
        Get the migration command for alembic. Migration command is a tupple of type of migration and migration hash.
        E.g. ('upgrade', 'dfgdfg5b0ee5') or ('downgrade', 'dfgdfg5b0ee5')

        :return: the command, or None when no migration is needed.
        :raises KeyError: if the applied migration is not part of the migrations map.
        """
        if self.is_on_head():
            return None

        fallback_version = self.get_fallback_version()
        target_migration = self.versions_migrations[fallback_version]

        applied_migration = self.get_applied_migration()
        if applied_migration is None:
            # nothing applied yet, everything has to be created
            return "upgrade", target_migration

        # reversed map: migration hash -> package version
        migrations_versions = {
            migration: version
            for version, migration in self.versions_migrations.items()
        }
        applied_package = migrations_versions[applied_migration]

        if pv(applied_package) < pv(fallback_version):
            return "upgrade", target_migration
        if pv(applied_package) > pv(fallback_version):
            return "downgrade", target_migration
        return None
示例#21
0
 def run_query(self, conn: Connector, query: str, context):
     """
     Render `query` with `context`, log the rendered text and run it on `conn`.

     :return: list with the `_row` value of every record returned by `conn`
     """
     rendered = self.render_sql(query, context)
     logging.info(rendered)
     records = conn.get_records(rendered)
     return [record._row for record in records]
示例#22
0
class ContessaRunner:
    """
    Runs quality rules against a table and stores one result object per rule.
    """
    def __init__(self, conn_uri_or_engine, special_qc_map=None):
        """
        :param conn_uri_or_engine: handed to `Connector`
        :param special_qc_map: optional map of result table fullname to the quality
        check class that should be used for it instead of the default one
        """
        self.conn_uri_or_engine = conn_uri_or_engine
        self.conn = Connector(conn_uri_or_engine)

        # todo - allow cfg
        self.special_qc_map = special_qc_map or {}

    def run(
        self,
        raw_rules: List[Dict[str, str]],
        check_table: Dict,
        result_table: Dict,  # todo - docs for quality name, maybe defaults..
        context: Optional[Dict] = None,
    ):
        """
        Apply `raw_rules` to `check_table` and insert the results.

        :param raw_rules: rule definitions as dicts
        :param check_table: keyword arguments for `Table`, the table to check
        :param result_table: keyword arguments for `ResultTable`, where results go
        :param context: values overriding the default context
        """
        check_table = Table(**check_table)
        result_table = ResultTable(**result_table)
        context = self.get_context(check_table, context)

        normalized_rules = self.normalize_rules(raw_rules)
        refresh_executors(check_table, self.conn, context)
        quality_check_class = self.get_quality_check_class(result_table)
        self.ensure_table(quality_check_class)

        rules = self.build_rules(normalized_rules)
        objs = self.do_quality_checks(quality_check_class, rules, context)

        self.insert(objs)

    @staticmethod
    def get_context(check_table: Table,
                    context: Optional[Dict] = None) -> Dict:
        """
        Construct context to pass to executors. User context overrides defaults.
        """
        ctx_defaults = {
            "table_fullname": check_table.fullname,
            "task_ts": datetime.now(),  # todo - is now() ok ?
        }
        # `context` is optional - dict.update(None) would raise TypeError
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def normalize_rules(self, raw_rules):
        """
        Delegate normalization of the raw rule definitions to `RuleNormalizer`.
        """
        return RuleNormalizer.normalize(raw_rules)

    def do_quality_checks(self,
                          dq_cls,
                          rules: List[Rule],
                          context: Optional[Dict] = None):
        """
        Run quality check for all rules. Use `dq_cls` to construct objects that will be inserted
        afterwards.
        :return: list of `dq_cls` objects, one per rule, in the order of `rules`
        """
        return [self.apply_rule(context, dq_cls, rule) for rule in rules]

    def apply_rule(self, context, dq_cls, rule):
        """
        Execute a single rule and wrap its results in a new `dq_cls` object.
        """
        e = get_executor(rule)
        logging.info(f"Executing rule `{rule}`.")
        results = e.execute(rule)
        obj = dq_cls()
        obj.init_row(rule, results, self.conn, context)
        return obj

    def insert(self, objs):
        """
        Insert QualityCheck objects using sqlalchemy. If there is integrity error, skip it.
        """
        logging.info(f"Inserting {len(objs)} results.")
        session = self.conn.make_session()
        try:
            session.add_all(objs)
            session.commit()
        except sqlalchemy.exc.IntegrityError:
            ts = objs[0].task_ts
            logging.info(
                f"This quality check ({ts}) was already done. Skipping it this time."
            )
            session.rollback()
        finally:
            session.close()

    def ensure_table(self, qc_cls):
        """
        Create table for QualityCheck class if it doesn't exists. E.g. quality_check_
        """
        try:
            qc_cls.__table__.create(bind=self.conn.engine)
            logging.info(f"Created table {qc_cls.__tablename__}.")
        except sqlalchemy.exc.ProgrammingError:
            # a failed CREATE is taken as "table is already there"
            logging.info(
                f"Table {qc_cls.__tablename__} already exists. Skipping creation."
            )

    def build_rules(self, normalized_rules):
        """
        Construct rules classes from user definition that are dicts.
        Raises if there are bad arguments for a certain rule.
        :return: list of Rule objects
        """
        ret = []
        for rule_def in normalized_rules:
            rule_cls = self.pick_rule_cls(rule_def)
            try:
                r = rule_cls(**rule_def)
            except Exception as e:
                # An exception raised without arguments has empty `args`; do not
                # let the logging itself fail and hide the original error.
                reason = e.args[0] if e.args else e
                logging.error(f"For rule `{rule_cls.__name__}`. {reason}")
                raise
            else:
                ret.append(r)
        return ret

    def pick_rule_cls(self, rule_def):
        """
        Get rule class based on its name that was input by user.
        :param rule_def: dict
        :return: Rule class
        """
        return get_rule_cls(rule_def["name"])

    def get_quality_check_class(self, result_table: ResultTable):
        """
        QualityCheck can be different, e.g. `special_table` has specific quality_check.
        Or kind of generic one that computes number of passed/failed objects etc.
        So determine if is special or not and return the class.
        :return: QualityCheck cls
        """
        if result_table.fullname in self.special_qc_map:
            quality_check_class = self.special_qc_map[result_table.fullname]
            logging.info(
                f"Using {quality_check_class.__name__} as quality check class."
            )
        else:
            quality_check_class = create_default_quality_check_class(
                result_table)
            logging.info("Using default QualityCheck class.")
        return quality_check_class
示例#23
0
class TestConsistencyChecker(unittest.TestCase):
    """
    Integration tests running `ConsistencyChecker` against the test database.
    """

    def setUp(self):
        """
        Init a temporary table with some data.

        The left table gets four rows, the right table only the first three of
        them. Every schema used by the tests is dropped first, so leftovers of an
        aborted run cannot change the row counts.
        """
        self.left_table_name = "raw_booking"
        self.right_table_name = "booking"
        self.result_table_name = "booking"
        self.ts_nodash = (
            FakedDatetime.now().isoformat().replace("-", "").replace(":", "")
        )
        self.now = FakedDatetime.now()

        sql = [
            "DROP SCHEMA if exists tmp CASCADE;",
            "DROP SCHEMA if exists data_quality CASCADE;",
            # the right table lives in `hello`, stale rows there would break counts
            "DROP SCHEMA if exists hello CASCADE;",
            "CREATE SCHEMA IF NOT EXISTS tmp;",
            "CREATE SCHEMA IF NOT EXISTS data_quality;",
            "CREATE SCHEMA IF NOT EXISTS hello;",
            f"""
                    CREATE TABLE IF NOT EXISTS tmp.{self.left_table_name}(
                      id SERIAL PRIMARY KEY,
                      src text,
                      dst text,
                      price int,
                      turnover_after_refunds double precision,
                      initial_price double precision,
                      created_at timestamptz
                    )
                """,
            f"""
                    CREATE TABLE IF NOT EXISTS hello.{self.right_table_name}(
                      id SERIAL PRIMARY KEY,
                      src text,
                      dst text,
                      price int,
                      turnover_after_refunds double precision,
                      initial_price double precision,
                      created_at timestamptz
                    )
                """,
            f"""
                INSERT INTO tmp.{self.left_table_name}
                    (src, dst, price, turnover_after_refunds, initial_price, created_at)
                VALUES
                    ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                    (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                    ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
                    ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
                """,
            f"""
                INSERT INTO hello.{self.right_table_name}
                    (src, dst, price, turnover_after_refunds, initial_price, created_at)
                VALUES
                    ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                    (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                    ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00')
            """,
        ]
        self.conn = Connector(TEST_DB_URI)
        for s in sql:
            self.conn.execute(s)

        self.consistency_checker = ConsistencyChecker(TEST_DB_URI)

    def tearDown(self):
        """
        Drop all created tables.
        """
        # IF EXISTS keeps the cleanup from failing when a schema is already gone.
        for schema in ("tmp", "data_quality", "hello"):
            self.conn.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE;")
        DQBase.metadata.clear()

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_false(self):
        """ Row counts differ (4 vs 3 inserted rows), the check must be invalid."""
        self.consistency_checker.run(
            self.consistency_checker.COUNT,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            result_table={
                "schema_name": "data_quality",
                "table_name": self.result_table_name,
            },
            context={"task_ts": self.now},
        )
        rows = self.conn.get_records(
            f"""
            SELECT * from data_quality.consistency_check_{self.result_table_name}
            order by created_at
        """
        )
        self.assertEqual(rows.fetchone()["status"], "invalid")

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_true(self):
        """ After adding the missing row the counts match and the check is valid."""
        # add missing record to the right table
        self.conn.execute(
            f"""
            INSERT INTO hello.{self.right_table_name}
                (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
                ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
        """
        )

        self.consistency_checker.run(
            self.consistency_checker.COUNT,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            result_table={
                "schema_name": "data_quality",
                "table_name": self.result_table_name,
            },
            context={"task_ts": self.now},
        )

        rows = self.conn.get_records(
            f"""
            SELECT * from data_quality.consistency_check_{self.result_table_name}
            order by created_at
        """
        )
        self.assertEqual(rows.fetchone()["status"], "valid")

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_diff(self):
        """ Test scenario where table diff is done on tables with different column order."""
        # The connector and checker created in setUp are reused here.
        left_table = "user"
        right_table = "user_inconsistent"

        sql = [
            f"""
                CREATE TABLE IF NOT EXISTS tmp.{left_table}(
                  id SERIAL PRIMARY KEY,
                  name text
                )
            """,
            f"""
                CREATE TABLE IF NOT EXISTS tmp.{right_table}(
                  name text,
                  id SERIAL PRIMARY KEY
                )
            """,
            f"""
                INSERT INTO tmp.{left_table}
                VALUES
                    (1, 'John Doe')
                """,
            f"""
                INSERT INTO tmp.{right_table}
                VALUES
                    ('John Doe', 1)
            """,
        ]

        for s in sql:
            self.conn.execute(s)

        self.consistency_checker.run(
            self.consistency_checker.DIFF,
            left_check_table={
                "schema_name": "tmp",
                "table_name": left_table,
            },
            right_check_table={
                "schema_name": "tmp",
                "table_name": right_table,
            },
            result_table={
                "schema_name": "data_quality",
                "table_name": self.result_table_name,
            },
            context={"task_ts": self.now},
        )

        rows = self.conn.get_records(
            f"""
                SELECT * from data_quality.consistency_check_{self.result_table_name}
                order by created_at
            """
        )

        self.assertEqual(rows.fetchone()["status"], "valid")

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_sqls(self):
        """ Custom sqls on both sides: equal results are valid, different ones invalid."""
        result = self.consistency_checker.run(
            self.consistency_checker.DIFF,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            left_custom_sql="SELECT 16349;",
            right_custom_sql="SELECT 16349;",
            context={"task_ts": self.now},
        )

        self.assertEqual("valid", result.status)

        result = self.consistency_checker.run(
            self.consistency_checker.DIFF,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            left_custom_sql="SELECT 42;",
            right_custom_sql="SELECT 16349;",
            context={"task_ts": self.now},
        )

        self.assertEqual("invalid", result.status)
示例#24
0
    def setUpClass(cls) -> None:
        """
        Create the shared connector, load the Alembic config and build the
        `Table` describing the migration version table.
        """
        cls.conn = Connector(TEST_DB_URI)

        config = Config(ALEMBIC_INI_PATH)
        cls.alembic_cfg = config

        # name of the version table comes from the `version_table` main option
        cls.migration_table = Table(
            DATA_QUALITY_SCHEMA,
            config.get_main_option("version_table"),
        )
示例#25
0
 def run_query(self, conn: Connector, query: str, context):
     """
     Render `query` with `context`, log the rendered text on debug level and
     run it on `conn`.

     :return: list of tuples, one per record, built from the record's `values()`
     """
     rendered = self.render_sql(query, context)
     logging.debug(rendered)
     rows = []
     for record in conn.get_records(rendered):
         rows.append(tuple(record.values()))
     return rows
示例#26
0
    def __init__(self, conn_uri_or_engine, special_qc_map=None):
        """
        :param conn_uri_or_engine: stored as is and handed to `Connector`
        :param special_qc_map: optional map of special quality check classes,
        an empty dict is used when nothing (or something falsy) is given
        """
        self.conn_uri_or_engine = conn_uri_or_engine
        self.conn = Connector(conn_uri_or_engine)

        # todo - allow cfg
        self.special_qc_map = special_qc_map if special_qc_map else {}
示例#27
0
class ConsistencyChecker:
    """
    Checks consistency of the sync between two tables.

    Two comparison methods are available:

    - ``COUNT`` compares the single count value returned for each side.
    - ``DIFF`` compares the returned rows of both sides as sets of tuples.
    """

    model_cls = ConsistencyCheck

    COUNT = "count"
    DIFF = "difference"

    def __init__(self, left_conn_uri_or_engine, right_conn_uri_or_engine=None):
        """
        Create the connectors for both sides of the comparison.

        :param left_conn_uri_or_engine: value passed to `Connector` for the left side
        :param right_conn_uri_or_engine: value passed to `Connector` for the right
            side; when None, the left connector is reused for the right side
        """
        self.left_conn_uri_or_engine = left_conn_uri_or_engine
        self.left_conn = Connector(left_conn_uri_or_engine)
        if right_conn_uri_or_engine is None:
            # Both tables are reachable through one connection - share it.
            self.right_conn_uri_or_engine = self.left_conn_uri_or_engine
            self.right_conn = self.left_conn
        else:
            self.right_conn_uri_or_engine = right_conn_uri_or_engine
            self.right_conn = Connector(right_conn_uri_or_engine)

    def run(
        self,
        method: str,
        left_check_table: Dict,
        right_check_table: Dict,
        result_table: Optional[Dict] = None,
        columns: Optional[List[str]] = None,
        time_filter: Optional[Union[str, List[Dict], TimeFilter]] = None,
        left_custom_sql: Optional[str] = None,
        right_custom_sql: Optional[str] = None,
        context: Optional[Dict] = None,
        example_selector: ExampleSelector = default_example_selector,
    ) -> Union[CheckResult, ConsistencyCheck]:
        """
        Run one consistency check between the left and the right table.

        :param method: `COUNT` or `DIFF`
        :param left_check_table: keyword arguments for `Table` (left side)
        :param right_check_table: keyword arguments for `Table` (right side)
        :param result_table: keyword arguments for `ResultTable`; when given,
            the result is also upserted through the right connection
        :param columns: columns used by the default queries
        :param time_filter: time filter definition, parsed by `parse_time_filter`
        :param left_custom_sql: query used instead of the default left query
        :param right_custom_sql: query used instead of the default right query
        :param context: user context, overrides the default context values
        :param example_selector: selects failed examples for the `DIFF` method
        :return: when `result_table` is given, the result mapping produced by
            `do_consistency_check`; otherwise a `CheckResult` initialised from it
        :raises ValueError: when both custom sqls are combined with `columns`
            or `time_filter`
        """
        if left_custom_sql and right_custom_sql and (columns or time_filter):
            # With both queries supplied by the caller no default query is
            # built, so these two arguments would be silently ignored.
            raise ValueError(
                "When using custom sqls you cannot change 'columns' or 'time_filter' attribute"
            )

        time_filter = parse_time_filter(time_filter)

        left_check_table = Table(**left_check_table)
        right_check_table = Table(**right_check_table)
        context = self.get_context(left_check_table, right_check_table, context)

        result = self.do_consistency_check(
            method,
            columns,
            time_filter,
            left_check_table,
            right_check_table,
            left_custom_sql,
            right_custom_sql,
            context,
            example_selector,
        )

        if result_table:
            result_table = ResultTable(**result_table, model_cls=self.model_cls)
            quality_check_class = create_default_check_class(result_table)
            # Results are stored through the right connection.
            self.right_conn.ensure_table(quality_check_class.__table__)
            self.upsert(quality_check_class, result)
            return result

        obj = CheckResult()
        obj.init_row_consistency(**result)
        return obj

    @staticmethod
    def get_context(
        left_check_table: Table,
        right_check_table: Table,
        context: Optional[Dict] = None,
    ) -> Dict:
        """
        Construct context to pass to executors. User context overrides defaults.

        :return dict with `left_table_fullname`, `right_table_fullname` and
            `task_ts` defaults, updated with the user `context` when given
        """
        ctx_defaults = {
            "left_table_fullname": left_check_table.fullname,
            "right_table_fullname": right_check_table.fullname,
            "task_ts": datetime.now(),
        }
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def do_consistency_check(
        self,
        method: str,
        columns: Optional[List[str]],
        time_filter: Optional[TimeFilter],
        left_check_table: Table,
        right_check_table: Table,
        left_sql: Optional[str] = None,
        right_sql: Optional[str] = None,
        context: Optional[Dict] = None,
        example_selector: ExampleSelector = default_example_selector,
    ):
        """
        Query both sides and compare what they return.

        A default query is built for every side without a custom sql. The
        left query runs on the left connection, the right query on the right one.

        :return dict with the check description, the compared `results`, both
            table fullnames, the time filter and the context
        :raises NotImplementedError: for a `method` other than `COUNT` or `DIFF`
        """
        if not left_sql or not right_sql:
            # At least one default query is needed, so work out the select
            # expression used by the default queries.
            if method == self.COUNT:
                if columns:
                    column = f"count({', '.join(columns)})"
                else:
                    column = "count(*)"
            elif method == self.DIFF:
                if columns:
                    column = ", ".join(columns)
                else:
                    # List the columns explicitly in case column order of compared tables is not the same.
                    column_names = self.right_conn.get_column_names(
                        right_check_table.fullname
                    )
                    column = ", ".join(sorted(column_names))
            else:
                raise NotImplementedError(f"Method {method} not implemented")

        if not left_sql:
            left_sql = self.construct_default_query(
                left_check_table.fullname, column, time_filter, context
            )
        left_result = self.run_query(self.left_conn, left_sql, context)

        if not right_sql:
            right_sql = self.construct_default_query(
                right_check_table.fullname, column, time_filter, context
            )
        right_result = self.run_query(self.right_conn, right_sql, context)

        results = self.compare_results(
            left_result, right_result, method, example_selector
        )

        return {
            "check": {
                "type": method,
                "description": "",
                "name": "consistency",
            },
            "results": results,
            "left_table_name": left_check_table.fullname,
            "right_table_name": right_check_table.fullname,
            "time_filter": time_filter,
            "context": context,
        }

    def compare_results(self, left_result, right_result, method, example_selector):
        """
        Aggregate the comparison of the two query results.

        :param left_result: list of row tuples returned by the left query
        :param right_result: list of row tuples returned by the right query
        :param method: `COUNT` or `DIFF`
        :param example_selector: picks failed examples (used by `DIFF` only)
        :return AggregatedResult
        :raises NotImplementedError: for a `method` other than `COUNT` or `DIFF`
        """
        if method == self.COUNT:
            # The count is the first value of the first returned row.
            left_count = left_result[0][0]
            right_count = right_result[0][0]
            passed = min(left_count, right_count)
            # The surplus of the bigger side. Taking the absolute difference
            # keeps the number non-negative when the right side is bigger and
            # keeps total_records == passed + failed.
            failed = abs(left_count - right_count)
            return AggregatedResult(
                total_records=max(left_count, right_count),
                failed=failed,
                passed=passed,
            )

        elif method == self.DIFF:
            left_set = set(left_result)
            right_set = set(right_result)
            # Rows found on both sides pass, rows found on one side only fail.
            passed = len(left_set.intersection(right_set))
            mismatched = left_set.symmetric_difference(right_set)
            failed = len(mismatched)
            failed_examples = example_selector.select_examples(mismatched)
            return AggregatedResult(
                total_records=failed + passed,
                failed=failed,
                passed=passed,
                failed_example=list(failed_examples),
            )

        else:
            raise NotImplementedError(f"Method {method} not implemented")

    def construct_default_query(
        self,
        table_name: str,
        column: str,
        time_filter: Optional[TimeFilter],
        context: Dict,
    ):
        """
        Build the default SELECT for one side of the check.

        When a time filter is given and the context has a truthy `task_ts`,
        that value is assigned to `time_filter.now` before the filter sql is read.

        :return str, sql selecting `column` from `table_name`, with a WHERE
            clause when the time filter produces a non-empty sql
        """
        where_clause = ""
        if time_filter:
            if context.get("task_ts"):
                time_filter.now = context["task_ts"]
            time_filter_sql = time_filter.sql
            if time_filter_sql:
                where_clause = f"WHERE {time_filter_sql}"
        query = f"""
            SELECT {column}
            FROM {table_name}
            {where_clause}
        """
        return query

    def render_sql(self, sql, context):
        """
        Replace some parameters in query.
        :return str, formatted sql
        """
        rendered = render_jinja_sql(sql, context)
        return rendered

    def run_query(self, conn: Connector, query: str, context):
        """
        Render `query` with `context`, log it at debug level and run it on `conn`.

        :return list of tuples, one tuple of column values per returned record
        """
        query = self.render_sql(query, context)
        logging.debug(query)
        result = [tuple(r.values()) for r in conn.get_records(query)]
        return result

    def upsert(self, dc_cls, result):
        """
        Create a `dc_cls` row initialised from `result` and upsert it through
        the right connection.
        """
        obj = dc_cls()
        obj.init_row(**result)
        self.right_conn.upsert([obj])

    def construct_automatic_time_filter(
        self,
        left_check_table: Dict,
        created_at_column=None,
        updated_at_column=None,
    ) -> TimeFilter:
        """
        Build a time filter that starts at the minimum value of a time column
        of the left table.

        :param left_check_table: keyword arguments for `Table`
        :param created_at_column: time column used when no `updated_at_column`
            is given
        :param updated_at_column: time column that takes precedence when given
        :return TimeFilter with a single column and the AND conjunction
        :raises ValueError: when neither time column is given
        """
        left_check_table = Table(**left_check_table)

        if created_at_column is None and updated_at_column is None:
            raise ValueError("Automatic time filter need at least one time column")

        since_column = updated_at_column or created_at_column
        since_sql = f"SELECT min({since_column}) FROM {left_check_table.fullname}"
        logging.info(since_sql)
        since = self.left_conn.get_records(since_sql).scalar()

        return TimeFilter(
            columns=[
                TimeFilterColumn(since_column, since=since),
            ],
            conjunction=TimeFilterConjunction.AND,
        )