def get_count_sql(table_name):
    return '''
        SELECT COUNT(*) FROM {}.{}.{}
    '''.format(
        sql.identifier(CONFIG['snowflake_database']),
        sql.identifier(CONFIG['snowflake_schema']),
        sql.identifier(table_name))
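# Illustrative usage of get_count_sql (mirrors how the tests below call it;
# assumes the same CONFIG/TEST_DB fixtures):
def example_count_cats():
    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            cur.execute(get_count_sql('CATS'))
            return cur.fetchone()[0]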
def add_table(self, cur, path, name, metadata):
    sql.valid_identifier(name)

    cur.execute('''
        CREATE TABLE {}.{}.{} ({} {})
    '''.format(
        sql.identifier(self.connection.configured_database),
        sql.identifier(self.connection.configured_schema),
        sql.identifier(name),
        # Snowflake does not allow for creation of tables with no columns
        sql.identifier(self.CREATE_TABLE_INITIAL_COLUMN),
        self.CREATE_TABLE_INITIAL_COLUMN_TYPE))

    self._set_table_metadata(cur, name, {
        'path': path,
        'version': metadata.get('version', None),
        'schema_version': metadata['schema_version'],
        'mappings': {}
    })

    self.add_column_mapping(
        cur,
        name,
        (self.CREATE_TABLE_INITIAL_COLUMN,),
        self.CREATE_TABLE_INITIAL_COLUMN,
        json_schema.make_nullable({'type': json_schema.BOOLEAN}))
def write_table_batch(self, cur, table_batch, metadata):
    remote_schema = table_batch['remote_schema']

    ## Create temp table to upload new data to
    target_table_name = self.canonicalize_identifier('tmp_' + str(uuid.uuid4()))
    cur.execute('''
        CREATE TABLE {db}.{schema}.{temp_table}
        LIKE {db}.{schema}.{table}
    '''.format(
        db=sql.identifier(self.connection.configured_database),
        schema=sql.identifier(self.connection.configured_schema),
        temp_table=sql.identifier(target_table_name),
        table=sql.identifier(remote_schema['name'])))

    ## Make streamable CSV records
    csv_headers = list(remote_schema['schema']['properties'].keys())
    rows_iter = iter(table_batch['records'])

    def transform():
        try:
            row = next(rows_iter)

            with io.StringIO() as out:
                writer = csv.DictWriter(out, csv_headers)
                writer.writerow(row)
                return out.getvalue()
        except StopIteration:
            return ''

    csv_rows = TransformStream(transform)

    ## Persist csv rows
    self.persist_csv_rows(
        cur,
        remote_schema,
        target_table_name,
        csv_headers,
        csv_rows)

    return len(table_batch['records'])
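# TransformStream is defined elsewhere in this repo; the following is only a
# hypothetical sketch of the contract write_table_batch relies on: a file-like
# object whose read() defers to the supplied callback, with the empty string
# returned on StopIteration acting as EOF. The real implementation may differ.
class _TransformStreamSketch:
    def __init__(self, fun):
        self.fun = fun

    def read(self, *args, **kwargs):
        # Each call yields the next CSV-serialized row; '' signals exhaustion
        return self.fun()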
def test_loading__single_char_columns(db_prep):
    stream_count = 50
    main(CONFIG, input_stream=SingleCharStream(stream_count))

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(cur,
                                 'ROOT',
                                 {
                                     ('_SDC_PRIMARY_KEY', 'TEXT', 'NO'),
                                     ('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                                     ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                                     ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN', 'YES'),
                                     ('X', 'NUMBER', 'YES')
                                 })

            cur.execute('''
                SELECT {} FROM {}.{}.{}
            '''.format(
                sql.identifier('X'),
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('ROOT')))

            persisted_records = cur.fetchall()

            ## Assert that the column has migrated data
            assert stream_count == len(persisted_records)
            assert stream_count == len([x for x in persisted_records
                                        if isinstance(x[0], float)])
def test_deduplication_older_rows(db_prep):
    stream = CatStream(100,
                       nested_count=2,
                       duplicates=2,
                       duplicate_sequence_delta=-100)
    main(CONFIG, input_stream=stream)

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            cur.execute(get_count_sql('CATS'))
            table_count = cur.fetchone()[0]

            cur.execute(get_count_sql('CATS__ADOPTION__IMMUNIZATIONS'))
            nested_table_count = cur.fetchone()[0]

            cur.execute('''
                SELECT "_SDC_SEQUENCE"
                FROM {}.{}.{}
                WHERE "ID" IN ({})
            '''.format(
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('CATS'),
                ','.join(["'{}'".format(x) for x in stream.duplicate_pks_used])))

            dup_cat_records = cur.fetchall()

            assert stream.record_message_count == 102
            assert table_count == 100
            assert nested_table_count == 200

            for record in dup_cat_records:
                assert record[0] == stream.sequence
def _get_table_metadata(self, cur, table_name):
    cur.execute('''
        SHOW TABLES LIKE '{}' IN SCHEMA {}.{}
    '''.format(
        table_name,
        sql.identifier(self.connection.configured_database),
        sql.identifier(self.connection.configured_schema)))

    tables = cur.fetchall()

    if not tables:
        return None

    if len(tables) != 1:
        raise SnowflakeError(
            '{} tables returned while searching for: {}.{}.{}'.format(
                len(tables),
                self.connection.configured_database,
                self.connection.configured_schema,
                table_name))

    comment = tables[0][5]

    if comment:
        try:
            comment_meta = json.loads(comment)
        except Exception:
            self.LOGGER.exception('Could not load table comment metadata')
            raise
    else:
        comment_meta = None

    return comment_meta
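# For orientation: judging from add_table/_set_table_metadata in this section,
# the JSON stored in the table comment (tables[0][5]) looks roughly like the
# following (field values are illustrative, not literal output):
#
#     {"path": ["cats", "adoption", "immunizations"],
#      "version": 1,
#      "schema_version": 1,
#      "mappings": {}}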
def test_deduplication_existing_new_rows(db_prep):
    stream = CatStream(100, nested_count=2)
    main(CONFIG, input_stream=stream)

    original_sequence = stream.sequence

    stream = CatStream(100,
                       nested_count=2,
                       sequence=original_sequence - 20)
    main(CONFIG, input_stream=stream)

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            cur.execute(get_count_sql('CATS'))
            table_count = cur.fetchone()[0]

            cur.execute(get_count_sql('CATS__ADOPTION__IMMUNIZATIONS'))
            nested_table_count = cur.fetchone()[0]

            cur.execute('''
                SELECT DISTINCT "_SDC_SEQUENCE"
                FROM {}.{}.{}
            '''.format(
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('CATS')))

            sequences = cur.fetchall()

            assert table_count == 100
            assert nested_table_count == 200
            assert len(sequences) == 1
            assert sequences[0][0] == original_sequence
def is_table_empty(self, cur, table_name):
    cur.execute('''
        SELECT COUNT(1) FROM {}.{}.{}
    '''.format(
        sql.identifier(self.connection.configured_database),
        sql.identifier(self.connection.configured_schema),
        sql.identifier(table_name)))

    return cur.fetchone()[0] == 0
def drop_column(self, cur, table_name, column_name):
    cur.execute('''
        ALTER TABLE {database}.{table_schema}.{table_name}
        DROP COLUMN {column_name}
    '''.format(
        database=sql.identifier(self.connection.configured_database),
        table_schema=sql.identifier(self.connection.configured_schema),
        table_name=sql.identifier(table_name),
        column_name=sql.identifier(column_name)))
def make_column_nullable(self, cur, table_name, column_name):
    cur.execute('''
        ALTER TABLE {database}.{table_schema}.{table_name}
        ALTER COLUMN {column_name} DROP NOT NULL
    '''.format(
        database=sql.identifier(self.connection.configured_database),
        table_schema=sql.identifier(self.connection.configured_schema),
        table_name=sql.identifier(table_name),
        column_name=sql.identifier(column_name)))
def migrate_column(self, cur, table_name, from_column, to_column):
    cur.execute('''
        UPDATE {database}.{table_schema}.{table_name}
        SET {to_column} = {from_column}
    '''.format(
        database=sql.identifier(self.connection.configured_database),
        table_schema=sql.identifier(self.connection.configured_schema),
        table_name=sql.identifier(table_name),
        to_column=sql.identifier(to_column),
        from_column=sql.identifier(from_column)))
def test_loading__new_non_null_column(db_prep):
    cat_count = 50
    main(CONFIG, input_stream=CatStream(cat_count))

    class NonNullStream(CatStream):
        def generate_record(self):
            record = CatStream.generate_record(self)
            record['id'] = record['id'] + cat_count
            return record

    non_null_stream = NonNullStream(cat_count)
    non_null_stream.schema = deepcopy(non_null_stream.schema)
    non_null_stream.schema['schema']['properties']['paw_toe_count'] = {
        'type': 'integer',
        'default': 5
    }

    main(CONFIG, input_stream=non_null_stream)

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(cur,
                                 'CATS',
                                 {
                                     ('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                                     ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                                     ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN', 'YES'),
                                     ('ADOPTION__ADOPTED_ON', 'TIMESTAMP_TZ', 'YES'),
                                     ('ADOPTION__WAS_FOSTER', 'BOOLEAN', 'YES'),
                                     ('AGE', 'NUMBER', 'YES'),
                                     ('ID', 'NUMBER', 'NO'),
                                     ('NAME', 'TEXT', 'NO'),
                                     ('PAW_SIZE', 'NUMBER', 'NO'),
                                     ('PAW_COLOUR', 'TEXT', 'NO'),
                                     ('PAW_TOE_COUNT', 'NUMBER', 'YES'),
                                     ('FLEA_CHECK_COMPLETE', 'BOOLEAN', 'NO'),
                                     ('PATTERN', 'TEXT', 'YES')
                                 })

            cur.execute('''
                SELECT {}, {} FROM {}.{}.{}
            '''.format(
                sql.identifier('ID'),
                sql.identifier('PAW_TOE_COUNT'),
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('CATS')))

            persisted_records = cur.fetchall()

            ## Assert that the column is null for records loaded before, and
            ## populated for records loaded after, the new non-null column appeared
            assert 2 * cat_count == len(persisted_records)
            assert cat_count == len([x for x in persisted_records if x[1] is None])
            assert cat_count == len([x for x in persisted_records if x[1] is not None])
def add_column(self, cur, table_name, column_name, column_schema):
    cur.execute('''
        ALTER TABLE {database}.{table_schema}.{table_name}
        ADD COLUMN {column_name} {data_type}
    '''.format(
        database=sql.identifier(self.connection.configured_database),
        table_schema=sql.identifier(self.connection.configured_schema),
        table_name=sql.identifier(table_name),
        column_name=sql.identifier(column_name),
        data_type=self.json_schema_to_sql_type(column_schema)))
def _set_table_metadata(self, cur, table_name, metadata):
    """
    Given a metadata dict, set it as the comment on the given table.

    :param self: Snowflake
    :param cur: Cursor
    :param table_name: String
    :param metadata: Metadata Dict
    :return: None
    """
    # NOTE: the serialized metadata is interpolated into a single-quoted SQL
    # literal, so it is assumed to contain no unescaped single quotes.
    cur.execute('''
        COMMENT ON TABLE {}.{}.{} IS '{}'
    '''.format(
        sql.identifier(self.connection.configured_database),
        sql.identifier(self.connection.configured_schema),
        sql.identifier(table_name),
        json.dumps(metadata)))
def test_loading__multi_types_columns(db_prep):
    stream_count = 50
    main(CONFIG, input_stream=MultiTypeStream(stream_count))

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(cur,
                                 'ROOT',
                                 {
                                     ('_SDC_PRIMARY_KEY', 'TEXT', 'NO'),
                                     ('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                                     ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                                     ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN', 'YES'),
                                     ('EVERY_TYPE__I', 'NUMBER', 'YES'),
                                     ('EVERY_TYPE__F', 'FLOAT', 'YES'),
                                     ('EVERY_TYPE__B', 'BOOLEAN', 'YES'),
                                     ('EVERY_TYPE__T', 'TIMESTAMP_TZ', 'YES'),
                                     ('EVERY_TYPE__I__1', 'NUMBER', 'YES'),
                                     ('EVERY_TYPE__F__1', 'FLOAT', 'YES'),
                                     ('EVERY_TYPE__B__1', 'BOOLEAN', 'YES'),
                                     ('NUMBER_WHICH_ONLY_COMES_AS_INTEGER', 'FLOAT', 'NO')
                                 })

            assert_columns_equal(cur,
                                 'ROOT__EVERY_TYPE',
                                 {
                                     ('_SDC_SOURCE_KEY__SDC_PRIMARY_KEY', 'TEXT', 'NO'),
                                     ('_SDC_LEVEL_0_ID', 'NUMBER', 'NO'),
                                     ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                                     ('_SDC_VALUE', 'NUMBER', 'NO'),
                                     ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN', 'YES')
                                 })

            cur.execute('''
                SELECT {} FROM {}.{}.{}
            '''.format(
                sql.identifier('NUMBER_WHICH_ONLY_COMES_AS_INTEGER'),
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('ROOT')))

            persisted_records = cur.fetchall()

            ## Assert that the column has migrated data
            assert stream_count == len(persisted_records)
            assert stream_count == len([x for x in persisted_records
                                        if isinstance(x[0], float)])
def setup_table_mapping_cache(self, cur):
    self.table_mapping_cache = {}

    cur.execute('''
        SHOW TABLES IN SCHEMA {}.{}
    '''.format(
        sql.identifier(self.connection.configured_database),
        sql.identifier(self.connection.configured_schema)))

    for row in cur.fetchall():
        # SHOW TABLES row layout: index 1 is the table name, index 5 the comment
        mapped_name = row[1]
        raw_json = row[5]

        table_path = None
        if raw_json:
            table_path = json.loads(raw_json).get('path', None)

        self.LOGGER.info('Mapping: {} to {}'.format(mapped_name, table_path))

        if table_path:
            self.table_mapping_cache[tuple(table_path)] = mapped_name
def assert_columns_equal(cursor, table_name, expected_column_tuples):
    cursor.execute('''
        SELECT column_name, data_type, is_nullable
        FROM {}.information_schema.columns
        WHERE table_schema = '{}' AND table_name = '{}'
    '''.format(
        sql.identifier(CONFIG['snowflake_database']),
        CONFIG['snowflake_schema'],
        table_name))

    columns = []
    for column in cursor.fetchall():
        columns.append((column[0], column[1], column[2]))

    assert set(columns) == expected_column_tuples
def get_table_schema(self, cur, name):
    metadata = self._get_table_metadata(cur, name)
    if not metadata:
        return None

    cur.execute('''
        SELECT column_name, data_type, is_nullable
        FROM {}.information_schema.columns
        WHERE table_schema = '{}' AND table_name = '{}'
    '''.format(
        sql.identifier(self.connection.configured_database),
        self.connection.configured_schema,
        name))

    properties = {}
    for column in cur.fetchall():
        properties[column[0]] = self.sql_type_to_json_schema(
            column[1], column[2] == 'YES')

    metadata['name'] = name
    metadata['type'] = 'TABLE_SCHEMA'
    metadata['schema'] = {'properties': properties}

    return metadata
def activate_version(self, stream_buffer, version):
    with self.connection.cursor() as cur:
        try:
            self.setup_table_mapping_cache(cur)

            root_table_name = self.add_table_mapping(
                cur, (stream_buffer.stream,), {})
            current_table_schema = self.get_table_schema(cur, root_table_name)

            if not current_table_schema:
                self.LOGGER.error(
                    '{} - Table for stream does not exist'.format(
                        stream_buffer.stream))
            elif current_table_schema.get('version') is not None \
                    and current_table_schema.get('version') >= version:
                self.LOGGER.warning(
                    '{} - Table version {} already active'.format(
                        stream_buffer.stream, version))
            else:
                versioned_root_table = root_table_name + SEPARATOR + str(version)

                names_to_paths = dict(
                    [(v, k) for k, v in self.table_mapping_cache.items()])

                cur.execute('''
                    SHOW TABLES LIKE '{}%' IN SCHEMA {}.{}
                '''.format(
                    versioned_root_table,
                    sql.identifier(self.connection.configured_database),
                    sql.identifier(self.connection.configured_schema)))

                for versioned_table_name in [x[1] for x in cur.fetchall()]:
                    table_name = root_table_name + \
                        versioned_table_name[len(versioned_root_table):]
                    table_path = names_to_paths[table_name]

                    args = {
                        'db_schema': '{}.{}'.format(
                            sql.identifier(self.connection.configured_database),
                            sql.identifier(self.connection.configured_schema)),
                        'stream_table_old': sql.identifier(
                            table_name + SEPARATOR + 'OLD'),
                        'stream_table': sql.identifier(table_name),
                        'version_table': sql.identifier(versioned_table_name)
                    }

                    cur.execute('''
                        ALTER TABLE {db_schema}.{stream_table}
                        RENAME TO {db_schema}.{stream_table_old}
                    '''.format(**args))

                    cur.execute('''
                        ALTER TABLE {db_schema}.{version_table}
                        RENAME TO {db_schema}.{stream_table}
                    '''.format(**args))

                    cur.execute('''
                        DROP TABLE {db_schema}.{stream_table_old}
                    '''.format(**args))

                    self.connection.commit()

                    metadata = self._get_table_metadata(cur, table_name)

                    self.LOGGER.info(
                        'Activated {}, setting path to {}'.format(
                            metadata, table_path))

                    metadata['path'] = table_path
                    self._set_table_metadata(cur, table_name, metadata)
        except Exception as ex:
            self.connection.rollback()
            message = '{} - Exception activating table version {}'.format(
                stream_buffer.stream, version)
            self.LOGGER.exception(message)
            raise SnowflakeError(message, ex)
def perform_update(self, cur, target_table_name, temp_table_name,
                   key_properties, columns, subkeys):
    full_table_name = '{}.{}.{}'.format(
        sql.identifier(self.connection.configured_database),
        sql.identifier(self.connection.configured_schema),
        sql.identifier(target_table_name))

    full_temp_table_name = '{}.{}.{}'.format(
        sql.identifier(self.connection.configured_database),
        sql.identifier(self.connection.configured_schema),
        sql.identifier(temp_table_name))

    pk_temp_select_list = []
    pk_where_list = []
    pk_null_list = []
    cxt_where_list = []
    for pk in key_properties:
        pk_identifier = sql.identifier(pk)
        pk_temp_select_list.append('{}.{}'.format(
            full_temp_table_name, pk_identifier))

        pk_where_list.append('{table}.{pk} = "dedupped".{pk}'.format(
            table=full_table_name,
            pk=pk_identifier))

        pk_null_list.append('{table}.{pk} IS NULL'.format(
            table=full_table_name,
            pk=pk_identifier))

        cxt_where_list.append('{table}.{pk} = "pks".{pk}'.format(
            table=full_table_name,
            pk=pk_identifier))

    pk_temp_select = ', '.join(pk_temp_select_list)
    pk_where = ' AND '.join(pk_where_list)
    pk_null = ' AND '.join(pk_null_list)
    cxt_where = ' AND '.join(cxt_where_list)

    sequence_identifier = sql.identifier(
        self.canonicalize_identifier(SINGER_SEQUENCE))
    sequence_join = ' AND "dedupped".{} >= {}.{}'.format(
        sequence_identifier, full_table_name, sequence_identifier)

    distinct_order_by = ' ORDER BY {}, {}.{} DESC'.format(
        pk_temp_select, full_temp_table_name, sequence_identifier)

    if len(subkeys) > 0:
        pk_temp_subkey_select_list = []
        for pk in (key_properties + subkeys):
            pk_temp_subkey_select_list.append('{}.{}'.format(
                full_temp_table_name, sql.identifier(pk)))
        insert_distinct_on = ', '.join(pk_temp_subkey_select_list)

        insert_distinct_order_by = ' ORDER BY {}, {}.{} DESC'.format(
            insert_distinct_on, full_temp_table_name, sequence_identifier)
    else:
        insert_distinct_on = pk_temp_select
        insert_distinct_order_by = distinct_order_by

    insert_columns_list = []
    dedupped_columns_list = []
    for column in columns:
        insert_columns_list.append(sql.identifier(column))
        dedupped_columns_list.append('{}.{}'.format(
            sql.identifier('dedupped'), sql.identifier(column)))

    insert_columns = ', '.join(insert_columns_list)
    dedupped_columns = ', '.join(dedupped_columns_list)

    # Delete target rows that are superseded by newer rows in the temp table
    cur.execute('''
        DELETE FROM {table} USING (
            SELECT "dedupped".*
            FROM (
                SELECT *,
                       ROW_NUMBER() OVER (PARTITION BY {pk_temp_select}
                                          {distinct_order_by}) AS "_sdc_pk_ranked"
                FROM {temp_table}
                {distinct_order_by}) AS "dedupped"
            JOIN {table} ON {pk_where}{sequence_join}
            WHERE "_sdc_pk_ranked" = 1
        ) AS "pks" WHERE {cxt_where};
    '''.format(
        table=full_table_name,
        temp_table=full_temp_table_name,
        pk_temp_select=pk_temp_select,
        pk_where=pk_where,
        cxt_where=cxt_where,
        sequence_join=sequence_join,
        distinct_order_by=distinct_order_by))

    # Insert deduplicated temp-table rows that do not already exist in the target
    cur.execute('''
        INSERT INTO {table}({insert_columns}) (
            SELECT {dedupped_columns}
            FROM (
                SELECT *,
                       ROW_NUMBER() OVER (PARTITION BY {insert_distinct_on}
                                          {insert_distinct_order_by}) AS "_sdc_pk_ranked"
                FROM {temp_table}
                {insert_distinct_order_by}) AS "dedupped"
            LEFT JOIN {table} ON {pk_where}
            WHERE "_sdc_pk_ranked" = 1 AND {pk_null}
        );
    '''.format(
        table=full_table_name,
        temp_table=full_temp_table_name,
        pk_where=pk_where,
        pk_null=pk_null,
        insert_distinct_on=insert_distinct_on,
        insert_distinct_order_by=insert_distinct_order_by,
        insert_columns=insert_columns,
        dedupped_columns=dedupped_columns))

    if not self.s3:
        # Clear out the associated stage for the table
        cur.execute('''
            REMOVE @{db}.{schema}.%{temp_table}
        '''.format(
            db=sql.identifier(self.connection.configured_database),
            schema=sql.identifier(self.connection.configured_schema),
            temp_table=sql.identifier(temp_table_name)))

    # Drop the tmp table
    cur.execute('''
        DROP TABLE {temp_table};
    '''.format(temp_table=full_temp_table_name))
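# Plain-Python illustration (not part of the target) of the dedup rule that the
# DELETE/INSERT pair in perform_update implements: for each primary key, only
# the row with the highest _sdc_sequence survives, and existing target rows
# lose to temp-table rows with an equal or greater sequence.
def _dedup_example():
    rows = [
        {'id': 1, '_sdc_sequence': 100, 'name': 'old'},
        {'id': 1, '_sdc_sequence': 200, 'name': 'new'},
        {'id': 2, '_sdc_sequence': 100, 'name': 'only'},
    ]
    winners = {}
    for row in sorted(rows, key=lambda r: r['_sdc_sequence']):
        winners[row['id']] = row  # later (higher-sequence) rows win
    assert winners[1]['name'] == 'new' and winners[2]['name'] == 'only'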
def test_loading__column_type_change(db_prep):
    cat_count = 20
    main(CONFIG, input_stream=CatStream(cat_count))

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(cur,
                                 'CATS',
                                 {
                                     ('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                                     ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                                     ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN', 'YES'),
                                     ('ADOPTION__ADOPTED_ON', 'TIMESTAMP_TZ', 'YES'),
                                     ('ADOPTION__WAS_FOSTER', 'BOOLEAN', 'YES'),
                                     ('AGE', 'NUMBER', 'YES'),
                                     ('ID', 'NUMBER', 'NO'),
                                     ('NAME', 'TEXT', 'NO'),
                                     ('PAW_SIZE', 'NUMBER', 'NO'),
                                     ('PAW_COLOUR', 'TEXT', 'NO'),
                                     ('FLEA_CHECK_COMPLETE', 'BOOLEAN', 'NO'),
                                     ('PATTERN', 'TEXT', 'YES')
                                 })

            cur.execute('''
                SELECT {} FROM {}.{}.{}
            '''.format(
                sql.identifier('NAME'),
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('CATS')))

            persisted_records = cur.fetchall()

            ## Assert that the original data is present
            assert cat_count == len(persisted_records)
            assert cat_count == len([x for x in persisted_records if x[0] is not None])

    class NameBooleanCatStream(CatStream):
        def generate_record(self):
            record = CatStream.generate_record(self)
            record['id'] = record['id'] + cat_count
            record['name'] = False
            return record

    stream = NameBooleanCatStream(cat_count)
    stream.schema = deepcopy(stream.schema)
    stream.schema['schema']['properties']['name'] = {'type': 'boolean'}

    main(CONFIG, input_stream=stream)

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(cur,
                                 'CATS',
                                 {
                                     ('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                                     ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                                     ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN', 'YES'),
                                     ('ADOPTION__ADOPTED_ON', 'TIMESTAMP_TZ', 'YES'),
                                     ('ADOPTION__WAS_FOSTER', 'BOOLEAN', 'YES'),
                                     ('AGE', 'NUMBER', 'YES'),
                                     ('ID', 'NUMBER', 'NO'),
                                     ('NAME__S', 'TEXT', 'YES'),
                                     ('NAME__B', 'BOOLEAN', 'YES'),
                                     ('PAW_SIZE', 'NUMBER', 'NO'),
                                     ('PAW_COLOUR', 'TEXT', 'NO'),
                                     ('FLEA_CHECK_COMPLETE', 'BOOLEAN', 'NO'),
                                     ('PATTERN', 'TEXT', 'YES')
                                 })

            cur.execute('''
                SELECT {}, {} FROM {}.{}.{}
            '''.format(
                sql.identifier('NAME__S'),
                sql.identifier('NAME__B'),
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('CATS')))

            persisted_records = cur.fetchall()

            ## Assert that the split columns migrated data and persisted new data
            assert 2 * cat_count == len(persisted_records)
            assert cat_count == len([x for x in persisted_records if x[0] is not None])
            assert cat_count == len([x for x in persisted_records if x[1] is not None])
            assert 0 == len([x for x in persisted_records
                             if x[0] is not None and x[1] is not None])

    class NameIntegerCatStream(CatStream):
        def generate_record(self):
            record = CatStream.generate_record(self)
            record['id'] = record['id'] + (2 * cat_count)
            record['name'] = 314
            return record

    stream = NameIntegerCatStream(cat_count)
    stream.schema = deepcopy(stream.schema)
    stream.schema['schema']['properties']['name'] = {'type': 'integer'}

    main(CONFIG, input_stream=stream)

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(cur,
                                 'CATS',
                                 {
                                     ('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                                     ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                                     ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN', 'YES'),
                                     ('ADOPTION__ADOPTED_ON', 'TIMESTAMP_TZ', 'YES'),
                                     ('ADOPTION__WAS_FOSTER', 'BOOLEAN', 'YES'),
                                     ('AGE', 'NUMBER', 'YES'),
                                     ('ID', 'NUMBER', 'NO'),
                                     ('NAME__S', 'TEXT', 'YES'),
                                     ('NAME__B', 'BOOLEAN', 'YES'),
                                     ('NAME__I', 'NUMBER', 'YES'),
                                     ('PAW_SIZE', 'NUMBER', 'NO'),
                                     ('PAW_COLOUR', 'TEXT', 'NO'),
                                     ('FLEA_CHECK_COMPLETE', 'BOOLEAN', 'NO'),
                                     ('PATTERN', 'TEXT', 'YES')
                                 })

            cur.execute('''
                SELECT {}, {}, {} FROM {}.{}.{}
            '''.format(
                sql.identifier('NAME__S'),
                sql.identifier('NAME__B'),
                sql.identifier('NAME__I'),
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('CATS')))

            persisted_records = cur.fetchall()

            ## Assert that the split columns migrated data and persisted new data
            assert 3 * cat_count == len(persisted_records)
            assert cat_count == len([x for x in persisted_records if x[0] is not None])
            assert cat_count == len([x for x in persisted_records if x[1] is not None])
            assert cat_count == len([x for x in persisted_records if x[2] is not None])
            assert 0 == len([x for x in persisted_records
                             if x[0] is not None and x[1] is not None and x[2] is not None])
            assert 0 == len([x for x in persisted_records
                             if x[0] is None and x[1] is None and x[2] is None])
def assert_records(conn, records, table_name, pks, match_pks=False):
    if not isinstance(pks, list):
        pks = [pks]

    with conn.cursor(True) as cur:
        cur.execute("set timezone='UTC';")

        cur.execute('''
            SELECT * FROM {}.{}.{}
        '''.format(
            sql.identifier(CONFIG['snowflake_database']),
            sql.identifier(CONFIG['snowflake_schema']),
            sql.identifier(table_name)))
        persisted_records_raw = cur.fetchall()

        persisted_records = {}
        for persisted_record in persisted_records_raw:
            pk = get_pk_key(pks, persisted_record)
            persisted_records[pk] = persisted_record

        subtables = {}
        records_pks = []
        pre_canonicalized_pks = [x.lower() for x in pks]
        for record in records:
            pk = get_pk_key(pre_canonicalized_pks, record)
            records_pks.append(pk)
            persisted_record = persisted_records[pk.upper()]
            subpks = {}
            for pk in pks:
                subpks[singer_stream.SINGER_SOURCE_PK_PREFIX + pk] = persisted_record[pk]
            assert_record(record, persisted_record, subtables, subpks)

        if match_pks:
            assert sorted(list(persisted_records.keys())) == sorted(records_pks)

        sub_pks = list(map(
            lambda pk: singer_stream.SINGER_SOURCE_PK_PREFIX.upper() + pk,
            pks))
        for subtable_name, items in subtables.items():
            cur.execute('''
                SELECT * FROM {}.{}.{}
            '''.format(
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier(table_name + '__' + subtable_name.upper())))
            persisted_records_raw = cur.fetchall()

            persisted_records = {}
            for persisted_record in persisted_records_raw:
                pk = get_pk_key(sub_pks, persisted_record, subrecord=True)
                persisted_records[pk] = persisted_record

            subtables = {}
            records_pks = []
            pre_canonicalized_sub_pks = [x.lower() for x in sub_pks]
            for record in items:
                pk = get_pk_key(pre_canonicalized_sub_pks, record, subrecord=True)
                records_pks.append(pk)
                persisted_record = persisted_records[pk]
                assert_record(record, persisted_record, subtables, subpks)
            assert len(subtables.values()) == 0

            if match_pks:
                assert sorted(list(persisted_records.keys())) == sorted(records_pks)
def persist_csv_rows(self, cur, remote_schema, temp_table_name, columns, csv_rows):
    params = []
    if self.s3:
        bucket, key = self.s3.persist(
            csv_rows,
            key_prefix=temp_table_name + SEPARATOR)

        stage_location = "'s3://{bucket}/{key}' credentials=(AWS_KEY_ID=%s AWS_SECRET_KEY=%s)".format(
            bucket=bucket,
            key=key)
        params = [
            self.s3.credentials()['aws_access_key_id'],
            self.s3.credentials()['aws_secret_access_key']
        ]
    else:
        stage_location = '@{db}.{schema}.%{table}'.format(
            db=sql.identifier(self.connection.configured_database),
            schema=sql.identifier(self.connection.configured_schema),
            table=sql.identifier(temp_table_name))

        rel_path = '/tmp/target-snowflake/'
        file_name = str(uuid.uuid4()).replace('-', '_')

        # Make tmp folder to hold data file
        os.makedirs(rel_path, exist_ok=True)

        # Write readable csv_rows to file
        with open(rel_path + file_name, 'wb') as file:
            line = csv_rows.read()
            while line:
                file.write(line.encode('utf-8'))
                line = csv_rows.read()

        # Upload to internal table stage
        cur.execute('''
            PUT file://{rel_path}{file_name} {stage_location}
        '''.format(
            rel_path=rel_path,
            file_name=file_name,
            stage_location=stage_location))

        # Tidy up and remove tmp staging file
        os.remove(rel_path + file_name)

        stage_location += '/{}'.format(file_name)

    cur.execute('''
        COPY INTO {db}.{schema}.{table} ({cols})
        FROM {stage_location}
        FILE_FORMAT = (TYPE = CSV EMPTY_FIELD_AS_NULL = FALSE)
    '''.format(
        db=sql.identifier(self.connection.configured_database),
        schema=sql.identifier(self.connection.configured_schema),
        table=sql.identifier(temp_table_name),
        cols=','.join([sql.identifier(x) for x in columns]),
        stage_location=stage_location),
        params=params)

    pattern = re.compile(SINGER_LEVEL.upper().format('[0-9]+'))
    subkeys = list(filter(
        lambda header: re.match(pattern, header) is not None,
        columns))

    canonicalized_key_properties = [
        self.fetch_column_from_path((key_property,), remote_schema)[0]
        for key_property in remote_schema['key_properties']
    ]

    self.perform_update(
        cur,
        remote_schema['name'],
        temp_table_name,
        canonicalized_key_properties,
        columns,
        subkeys)
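# Rough shape of the SQL that persist_csv_rows issues on the internal-stage
# (non-S3) path; the names below are illustrative, not literal output:
#
#     PUT file:///tmp/target-snowflake/<file> @DB.SCHEMA.%TMP_TABLE
#     COPY INTO DB.SCHEMA.TMP_TABLE ("COL_A","COL_B")
#       FROM @DB.SCHEMA.%TMP_TABLE/<file>
#       FILE_FORMAT = (TYPE = CSV EMPTY_FIELD_AS_NULL = FALSE)
#
# perform_update then merges the temp table into the target and drops it.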