def apply(self, conn: Connector):
    """
    Execute the formatted sql and check that it returns a list of booleans
    (or NULLs), which is what a quality check needs. If so, return a pd.Series.
    :return: pd.Series
    """
    sql = self.sql_with_where
    # `get_records` returns a generator, so materialize it first
    results = [r for r in conn.get_records(sql)]
    is_list_of_bool = all(
        len(r) == 1 and isinstance(r[0], (bool, type(None))) for r in results
    )
    if not is_list_of_bool:
        raise ValueError(
            f"Your query for rule `{self.name}` does not return a list of booleans or Nones."
        )
    # NULL results count as failed checks, since bool(None) is False
    return pd.Series([bool(r[0]) for r in results])
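
# --- Illustration (not part of the library) --------------------------------
# A minimal sketch of how `apply` coerces single-column query results into
# the returned pd.Series; the rows below are invented for this example.
# Note that bool(None) is False, so a NULL result counts as a failed check.
import pandas as pd

_example_rows = [(True,), (None,), (False,), (True,)]
_example_series = pd.Series([bool(r[0]) for r in _example_rows])
assert _example_series.tolist() == [True, False, False, True]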
class TestDataQualityOperator(unittest.TestCase):
    def setUp(self):
        """
        Init a temporary table with some data.
        """
        self.table_name = "booking_all_v2"
        self.ts_nodash = (
            FakedDatetime.now().isoformat().replace("-", "").replace(":", "")
        )
        self.tmp_table_name = f"{self.table_name}_{self.ts_nodash}"
        self.now = FakedDatetime.now()
        sql = [
            "DROP SCHEMA if exists tmp CASCADE;",
            "DROP SCHEMA if exists data_quality CASCADE;",
            "CREATE SCHEMA IF NOT EXISTS tmp;",
            "CREATE SCHEMA IF NOT EXISTS data_quality;",
            "CREATE SCHEMA IF NOT EXISTS hello;",
            f"""
            CREATE TABLE IF NOT EXISTS tmp.{self.table_name}(
                id SERIAL PRIMARY KEY,
                src text,
                dst text,
                price int,
                turnover_after_refunds double precision,
                initial_price double precision,
                created_at timestamptz
            )
            """,
            f"""
            CREATE TABLE IF NOT EXISTS tmp.{self.tmp_table_name}(
                id SERIAL PRIMARY KEY,
                src text,
                dst text,
                price int,
                turnover_after_refunds double precision,
                initial_price double precision,
                created_at timestamptz
            )
            """,
            f"""
            INSERT INTO tmp.{self.tmp_table_name}
                (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
                ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                -- this is older than 30 days,
                -- not in stats when time_filter = `created_at`
                (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
                ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
            """,
            f"""
            INSERT INTO tmp.{self.table_name}
                (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
                ('BTS', NULL, 1, 100, 11, '2018-09-12T13:00:00'),
                -- this is older than 30 days,
                -- not in stats when time_filter = `created_at`
                (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T13:00:00'),
                ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T13:00:00'),
                ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T13:00:00')
            """,
        ]
        self.conn = Connector(TEST_DB_URI)
        for s in sql:
            self.conn.execute(s)
        self.contessa_runner = ContessaRunner(TEST_DB_URI)

    def tearDown(self):
        """
        Drop all created tables.
""" self.conn.execute(f"DROP schema tmp CASCADE;") self.conn.execute(f"DROP schema data_quality CASCADE;") self.conn.execute(f"DROP schema hello CASCADE;") DQBase.metadata.clear() @mock.patch("contessa.executor.datetime", FakedDatetime) def test_execute_tmp(self): sql = """ SELECT CASE WHEN src = 'BTS' and dst is null THEN false ELSE true END as res from {{ table_fullname }} """ rules = [ { "name": "not_null_name", "type": "not_null", "column": "dst", "time_filter": "created_at", }, { "name": "gt_name", "type": "gt", "column": "price", "value": 10, "time_filter": "created_at", }, { "name": "sql_name", "type": "sql", "sql": sql, "column": "src_dst", "description": "test sql rule", }, { "name": "not_name", "type": "not", "column": "src", "value": "dst" }, ] self.contessa_runner.run( check_table={ "schema_name": "tmp", "table_name": self.tmp_table_name }, result_table={ "schema_name": "data_quality", "table_name": self.table_name }, raw_rules=rules, context={"task_ts": self.now}, ) rows = self.conn.get_records(f""" SELECT * from data_quality.quality_check_{self.table_name} order by created_at """).fetchall() self.assertEqual(len(rows), 4) notnull_rule = rows[0] self.assertEqual(notnull_rule["failed"], 1) self.assertEqual(notnull_rule["passed"], 2) self.assertEqual(notnull_rule["attribute"], "dst") self.assertEqual(notnull_rule["task_ts"].timestamp(), self.now.timestamp()) gt_rule = rows[1] self.assertEqual(gt_rule["failed"], 3) self.assertEqual(gt_rule["passed"], 0) self.assertEqual(gt_rule["attribute"], "price") sql_rule = rows[2] self.assertEqual(sql_rule["failed"], 1) self.assertEqual(sql_rule["passed"], 3) self.assertEqual(sql_rule["attribute"], "src_dst") not_column_rule = rows[3] self.assertEqual(not_column_rule["failed"], 1) self.assertEqual(not_column_rule["passed"], 3) self.assertEqual(not_column_rule["attribute"], "src") @mock.patch("contessa.executor.datetime", FakedDatetime) def test_execute_dst(self): sql = """ SELECT CASE WHEN src = 'BTS' AND dst is null THEN false ELSE true END as res FROM {{ table_fullname }} WHERE created_at BETWEEN timestamptz '{{task_ts}}' - INTERVAL '1 hour' AND timestamptz '{{task_ts}}' + INTERVAL '1 hour' """ rules = [ { "name": "not_null_name", "type": "not_null", "column": "dst", "time_filter": "created_at", }, { "name": "sql_name", "type": "sql", "sql": sql, "column": "src_dst", "description": "test sql rule", }, ] self.contessa_runner.run( check_table={ "schema_name": "tmp", "table_name": self.tmp_table_name }, result_table={ "schema_name": "data_quality", "table_name": self.table_name }, raw_rules=rules, context={"task_ts": self.now}, ) rows = self.conn.get_records(f""" SELECT * from data_quality.quality_check_{self.table_name} order by created_at """).fetchall() self.assertEqual(len(rows), 2) notnull_rule = rows[0] self.assertEqual(notnull_rule["failed"], 1) self.assertEqual(notnull_rule["passed"], 2) self.assertEqual(notnull_rule["attribute"], "dst") self.assertEqual(notnull_rule["task_ts"].timestamp(), self.now.timestamp()) sql_rule = rows[1] self.assertEqual(sql_rule["failed"], 1) self.assertEqual(sql_rule["passed"], 0) self.assertEqual(sql_rule["attribute"], "src_dst") def test_different_schema(self): rules = [{ "name": "not_nul_name", "type": "not_null", "column": "dst", "time_filter": "created_at", }] self.contessa_runner.run( check_table={ "schema_name": "tmp", "table_name": self.tmp_table_name }, result_table={ "schema_name": "hello", "table_name": "abcde" }, raw_rules=rules, context={"task_ts": self.now}, ) rows = 
        rows = self.conn.get_records(
            """
            SELECT 1 from hello.quality_check_abcde
            """
        ).fetchall()
        self.assertEqual(len(rows), 1)
class ConsistencyChecker:
    """
    Checks consistency of the sync between two tables.
    """

    model_cls = ConsistencyCheck

    COUNT = "count"
    DIFF = "difference"

    def __init__(self, left_conn_uri_or_engine, right_conn_uri_or_engine=None):
        self.left_conn_uri_or_engine = left_conn_uri_or_engine
        self.left_conn = Connector(left_conn_uri_or_engine)
        if right_conn_uri_or_engine is None:
            self.right_conn_uri_or_engine = self.left_conn_uri_or_engine
            self.right_conn = self.left_conn
        else:
            self.right_conn_uri_or_engine = right_conn_uri_or_engine
            self.right_conn = Connector(right_conn_uri_or_engine)

    def run(
        self,
        method: str,
        left_check_table: Dict,
        right_check_table: Dict,
        result_table: Optional[Dict] = None,
        columns: Optional[List[str]] = None,
        time_filter: Optional[Union[str, List[Dict], TimeFilter]] = None,
        left_custom_sql: Optional[str] = None,
        right_custom_sql: Optional[str] = None,
        context: Optional[Dict] = None,
        example_selector: ExampleSelector = default_example_selector,
    ) -> Union[CheckResult, ConsistencyCheck]:
        if left_custom_sql and right_custom_sql:
            if columns or time_filter:
                raise ValueError(
                    "When using custom sqls you cannot set the 'columns' or 'time_filter' attributes"
                )
        time_filter = parse_time_filter(time_filter)
        left_check_table = Table(**left_check_table)
        right_check_table = Table(**right_check_table)
        context = self.get_context(left_check_table, right_check_table, context)

        result = self.do_consistency_check(
            method,
            columns,
            time_filter,
            left_check_table,
            right_check_table,
            left_custom_sql,
            right_custom_sql,
            context,
            example_selector,
        )

        if result_table:
            result_table = ResultTable(**result_table, model_cls=self.model_cls)
            quality_check_class = create_default_check_class(result_table)
            self.right_conn.ensure_table(quality_check_class.__table__)
            self.upsert(quality_check_class, result)
            return result

        obj = CheckResult()
        obj.init_row_consistency(**result)
        return obj

    @staticmethod
    def get_context(
        left_check_table: Table,
        right_check_table: Table,
        context: Optional[Dict] = None,
    ) -> Dict:
        """
        Construct the context to pass to executors. User context overrides defaults.
        """
        ctx_defaults = {
            "left_table_fullname": left_check_table.fullname,
            "right_table_fullname": right_check_table.fullname,
            "task_ts": datetime.now(),
        }
        if context:
            ctx_defaults.update(context)
        return ctx_defaults

    def do_consistency_check(
        self,
        method: str,
        columns: Optional[List[str]],
        time_filter: Optional[TimeFilter],
        left_check_table: Table,
        right_check_table: Table,
        left_sql: Optional[str] = None,
        right_sql: Optional[str] = None,
        context: Dict = None,
        example_selector: ExampleSelector = default_example_selector,
    ):
        """
        Run the consistency check between the left and right tables and build
        the result row that will be inserted afterwards.
        """
        if not left_sql or not right_sql:
            if method == self.COUNT:
                if columns:
                    column = f"count({', '.join(columns)})"
                else:
                    column = "count(*)"
            elif method == self.DIFF:
                if columns:
                    column = ", ".join(columns)
                else:
                    # List the columns explicitly in case the column order of
                    # the compared tables is not the same.
column = ", ".join( sorted( self.right_conn.get_column_names( right_check_table.fullname))) else: raise NotImplementedError(f"Method {method} not implemented") if not left_sql: left_sql = self.construct_default_query(left_check_table.fullname, column, time_filter, context) left_result = self.run_query(self.left_conn, left_sql, context) if not right_sql: right_sql = self.construct_default_query( right_check_table.fullname, column, time_filter, context) right_result = self.run_query(self.right_conn, right_sql, context) results = self.compare_results(left_result, right_result, method, example_selector) return { "check": { "type": method, "description": "", "name": "consistency", }, "results": results, "left_table_name": left_check_table.fullname, "right_table_name": right_check_table.fullname, "time_filter": time_filter, "context": context, } def compare_results(self, left_result, right_result, method, example_selector): if method == self.COUNT: left_count = left_result[0][0] right_count = right_result[0][0] passed = min(left_count, right_count) failed = (left_count - passed) - (right_count - passed) return AggregatedResult( total_records=max(left_count, right_count), failed=failed, passed=passed, ) elif method == self.DIFF: left_set = set(left_result) right_set = set(right_result) common = left_set.intersection(right_set) passed = len(common) failed = (len(left_set) - len(common)) + (len(right_set) - len(common)) failed_examples = example_selector.select_examples( left_set.symmetric_difference(right_set)) return AggregatedResult( total_records=failed + passed, failed=failed, passed=passed, failed_example=list(failed_examples), ) else: raise NotImplementedError(f"Method {method} not implemented") def construct_default_query( self, table_name: str, column: str, time_filter: Optional[TimeFilter], context: Dict, ): if time_filter: if context.get("task_ts"): time_filter.now = context["task_ts"] time_filter = time_filter.sql query = f""" SELECT {column} FROM {table_name} {f'WHERE {time_filter}' if time_filter else ''} """ return query def render_sql(self, sql, context): """ Replace some parameters in query. :return str, formatted sql """ rendered = render_jinja_sql(sql, context) return rendered def run_query(self, conn: Connector, query: str, context): query = self.render_sql(query, context) logging.debug(query) result = [tuple(r.values()) for r in conn.get_records(query)] return result def upsert(self, dc_cls, result): obj = dc_cls() obj.init_row(**result) self.right_conn.upsert([ obj, ]) def construct_automatic_time_filter( self, left_check_table: Dict, created_at_column=None, updated_at_column=None, ) -> TimeFilter: left_check_table = Table(**left_check_table) if created_at_column is None and updated_at_column is None: raise ValueError( "Automatic time filter need at least one time column") since_column = updated_at_column or created_at_column since_sql = f"SELECT min({since_column}) FROM {left_check_table.fullname}" logging.info(since_sql) since = self.left_conn.get_records(since_sql).scalar() return TimeFilter( columns=[ TimeFilterColumn(since_column, since=since), ], conjunction=TimeFilterConjunction.AND, )
class TestConsistencyChecker(unittest.TestCase):
    def setUp(self):
        """
        Init temporary tables with some data.
        """
        self.left_table_name = "raw_booking"
        self.right_table_name = "booking"
        self.result_table_name = "booking"
        self.ts_nodash = (
            FakedDatetime.now().isoformat().replace("-", "").replace(":", "")
        )
        self.now = FakedDatetime.now()
        sql = [
            "DROP SCHEMA if exists tmp CASCADE;",
            "DROP SCHEMA if exists data_quality CASCADE;",
            "CREATE SCHEMA IF NOT EXISTS tmp;",
            "CREATE SCHEMA IF NOT EXISTS data_quality;",
            "CREATE SCHEMA IF NOT EXISTS hello;",
            f"""
            CREATE TABLE IF NOT EXISTS tmp.{self.left_table_name}(
                id SERIAL PRIMARY KEY,
                src text,
                dst text,
                price int,
                turnover_after_refunds double precision,
                initial_price double precision,
                created_at timestamptz
            )
            """,
            f"""
            CREATE TABLE IF NOT EXISTS hello.{self.right_table_name}(
                id SERIAL PRIMARY KEY,
                src text,
                dst text,
                price int,
                turnover_after_refunds double precision,
                initial_price double precision,
                created_at timestamptz
            )
            """,
            f"""
            INSERT INTO tmp.{self.left_table_name}
                (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
                ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00'),
                ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
            """,
            f"""
            INSERT INTO hello.{self.right_table_name}
                (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
                ('BTS', NULL, 1, 100, 11, '2018-09-12T11:50:00'),
                (NULL, 'PEK', 33, 1.1, 13, '2018-01-12T15:50:00'),
                ('VIE', 'JFK', 4, 5.5, 23.4, '2018-09-11T11:50:00')
            """,
        ]
        self.conn = Connector(TEST_DB_URI)
        for s in sql:
            self.conn.execute(s)
        self.consistency_checker = ConsistencyChecker(TEST_DB_URI)

    def tearDown(self):
        """
        Drop all created tables.
        """
        self.conn.execute("DROP schema tmp CASCADE;")
        self.conn.execute("DROP schema data_quality CASCADE;")
        self.conn.execute("DROP schema hello CASCADE;")
        DQBase.metadata.clear()

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_false(self):
        self.consistency_checker.run(
            self.consistency_checker.COUNT,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            result_table={
                "schema_name": "data_quality",
                "table_name": self.result_table_name,
            },
            context={"task_ts": self.now},
        )
        rows = self.conn.get_records(
            f"""
            SELECT * from data_quality.consistency_check_{self.result_table_name}
            order by created_at
            """
        )
        self.assertEqual(rows.fetchone()["status"], "invalid")

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_true(self):
        # add the missing record to the right table
        self.conn.execute(
            f"""
            INSERT INTO hello.{self.right_table_name}
                (src, dst, price, turnover_after_refunds, initial_price, created_at)
            VALUES
                ('VIE', 'VIE', 4, 0.0, 0.0, '2018-09-11T11:50:00')
            """
        )
        self.consistency_checker.run(
            self.consistency_checker.COUNT,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            result_table={
                "schema_name": "data_quality",
                "table_name": self.result_table_name,
            },
            context={"task_ts": self.now},
        )
        rows = self.conn.get_records(
            f"""
            SELECT * from data_quality.consistency_check_{self.result_table_name}
            order by created_at
            """
        )
        self.assertEqual(rows.fetchone()["status"], "valid")

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_diff(self):
        """
        Test a table diff on tables with different column order.
        """
        self.inconsistent_columns_left = "user"
        self.inconsistent_columns_right = "user_inconsistent"
        self.conn = Connector(TEST_DB_URI)
        self.consistency_checker = ConsistencyChecker(TEST_DB_URI)
        sql = [
            f"""
            CREATE TABLE IF NOT EXISTS tmp.{self.inconsistent_columns_left}(
                id SERIAL PRIMARY KEY,
                name text
            )
            """,
            f"""
            CREATE TABLE IF NOT EXISTS tmp.{self.inconsistent_columns_right}(
                name text,
                id SERIAL PRIMARY KEY
            )
            """,
            f"""
            INSERT INTO tmp.{self.inconsistent_columns_left} VALUES (1, 'John Doe')
            """,
            f"""
            INSERT INTO tmp.{self.inconsistent_columns_right} VALUES ('John Doe', 1)
            """,
        ]
        for s in sql:
            self.conn.execute(s)

        self.consistency_checker.run(
            self.consistency_checker.DIFF,
            left_check_table={
                "schema_name": "tmp",
                "table_name": self.inconsistent_columns_left,
            },
            right_check_table={
                "schema_name": "tmp",
                "table_name": self.inconsistent_columns_right,
            },
            result_table={
                "schema_name": "data_quality",
                "table_name": self.result_table_name,
            },
            context={"task_ts": self.now},
        )
        rows = self.conn.get_records(
            f"""
            SELECT * from data_quality.consistency_check_{self.result_table_name}
            order by created_at
            """
        )
        self.assertEqual(rows.fetchone()["status"], "valid")

    @mock.patch("contessa.executor.datetime", FakedDatetime)
    def test_execute_consistency_sqls(self):
        result = self.consistency_checker.run(
            self.consistency_checker.DIFF,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            left_custom_sql="SELECT 16349;",
            right_custom_sql="SELECT 16349;",
            context={"task_ts": self.now},
        )
        self.assertEqual("valid", result.status)

        result = self.consistency_checker.run(
            self.consistency_checker.DIFF,
            left_check_table={"schema_name": "tmp", "table_name": self.left_table_name},
            right_check_table={
                "schema_name": "hello",
                "table_name": self.right_table_name,
            },
            left_custom_sql="SELECT 42;",
            right_custom_sql="SELECT 16349;",
            context={"task_ts": self.now},
        )
        self.assertEqual("invalid", result.status)
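
# --- Illustration (not part of the library) --------------------------------
# The DIFF tests above rely on the set arithmetic in compare_results: rows
# found in both tables count as passed, rows present in only one of them as
# failed. A minimal standalone sketch with invented rows:
_left = {(1, "John Doe"), (2, "Jane Doe")}
_right = {(1, "John Doe"), (3, "Jack Doe")}
_passed = len(_left & _right)                                 # 1 common row
_failed = (len(_left) - _passed) + (len(_right) - _passed)    # 2 mismatches
assert (_passed, _failed) == (1, 2)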
def run_query(self, conn: Connector, query: str, context):
    query = self.render_sql(query, context)
    logging.info(query)
    result = [r._row for r in conn.get_records(query)]
    return result
class MigrationsResolver:
    """
    Helper class for resolving Contessa migrations.
    """

    def __init__(self, migrations_map, package_version, url, schema):
        """
        :param migrations_map: map of package versions and their migrations,
            in the form of a dictionary, e.g. {'0.1.4': 'A', '0.1.5': 'B'}
        :param package_version: the version of the package planned to be migrated
        :param url: the database url where the Alembic migration table is present
            or planned to be created
        :param schema: the database schema where the Alembic migration table is
            present or planned to be created
        """
        self.versions_migrations = migrations_map
        self.package_version = package_version
        self.url = url
        self.schema = schema
        self.conn = Connector(self.url)

    def schema_exists(self):
        """
        Check if the schema with the Alembic migration table exists.
        :return: True if the schema exists.
        """
        result = self.conn.get_records(
            f"""
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.schemata
                WHERE schema_name = '{self.schema}'
            );
            """
        )
        return result.first()[0]

    def migrations_table_exists(self):
        """
        Check if the Alembic versions table exists.
        """
        result = self.conn.get_records(
            f"""
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.tables
                WHERE table_schema = '{self.schema}'
                AND table_name = '{MIGRATION_TABLE}'
            );
            """
        )
        return result.first()[0]

    def get_applied_migration(self):
        """
        Get the currently applied migration in the target schema.
        """
        if not self.migrations_table_exists():
            return None
        version = self.conn.get_records(
            f"select * from {self.schema}.{MIGRATION_TABLE}"
        )
        return version.first()[0]

    def is_on_head(self):
        """
        Check if the currently applied migration is the one valid for this
        Contessa version.
        """
        if not self.migrations_table_exists():
            return False
        current = self.get_applied_migration()
        fallback_package_version = self.get_fallback_version()
        return self.versions_migrations[fallback_package_version] == current

    def get_fallback_version(self):
        """
        Get the fallback version in case there is no migration for the current
        Contessa package version. Returns the last package version that has a
        migration.

        E.g. with this versions-to-migrations map

            versions_migrations = {
                "0.1.4": "54f8985b0ee5",
                "0.1.5": "480e6618700d",
                "0.1.8": "3w4er8y50yyd",
                "0.1.9": "034hfa8943hr",
            }

        the fallback version of the current version 0.1.7 is 0.1.5, because it
        is the last version with a migration before ours.
        """
        # assumes the map keys are in ascending version order
        keys = list(self.versions_migrations.keys())
        if self.package_version in self.versions_migrations:
            return self.package_version
        if pv(self.package_version) < pv(keys[0]):
            return keys[0]
        if pv(self.package_version) > pv(keys[-1]):
            return keys[-1]
        result = keys[0]
        for k in keys[1:]:
            if pv(k) <= pv(self.package_version):
                result = k
            else:
                break
        return result

    def get_migration_to_head(self):
        """
        Get the migration command for Alembic. A migration command is a tuple
        of the migration type and the migration hash. E.g.
        ('upgrade', 'dfgdfg5b0ee5') or ('downgrade', 'dfgdfg5b0ee5')
        """
        if self.is_on_head():
            return None

        fallback_version = self.get_fallback_version()

        if not self.migrations_table_exists():
            return "upgrade", self.versions_migrations[fallback_version]

        migrations_versions = dict(map(reversed, self.versions_migrations.items()))
        applied_migration = self.get_applied_migration()
        applied_package = migrations_versions[applied_migration]

        if pv(applied_package) < pv(fallback_version):
            return "upgrade", self.versions_migrations[fallback_version]
        if pv(applied_package) > pv(fallback_version):
            return "downgrade", self.versions_migrations[fallback_version]
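
# --- Illustration (not part of the library) --------------------------------
# A standalone re-statement of the fallback logic in get_fallback_version,
# assuming `pv` is packaging.version.parse; the map and hashes come from the
# docstring example above.
from packaging.version import parse as pv

_versions_migrations = {
    "0.1.4": "54f8985b0ee5",
    "0.1.5": "480e6618700d",
    "0.1.8": "3w4er8y50yyd",
    "0.1.9": "034hfa8943hr",
}

def _fallback_version(package_version, versions_migrations):
    keys = list(versions_migrations)
    if package_version in versions_migrations:
        return package_version
    if pv(package_version) < pv(keys[0]):
        return keys[0]
    if pv(package_version) > pv(keys[-1]):
        return keys[-1]
    result = keys[0]
    for k in keys[1:]:
        if pv(k) <= pv(package_version):
            result = k
        else:
            break
    return result

# 0.1.7 has no migration of its own, so it falls back to 0.1.5.
assert _fallback_version("0.1.7", _versions_migrations) == "0.1.5"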