def get_count_sql(table_name):
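    """Return a COUNT(*) query for `table_name` in the configured Snowflake database/schema."""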
    return '''
        SELECT COUNT(*) FROM {}.{}.{}
    '''.format(
        sql.identifier(CONFIG['snowflake_database']),
        sql.identifier(CONFIG['snowflake_schema']),
        sql.identifier(table_name))
    def add_table(self, cur, path, name, metadata):
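        """
        Create `name` with only the boolean placeholder column (Snowflake will
        not create a table with zero columns), store path/version metadata in
        the table comment, and register the placeholder's column mapping.
        """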
        sql.valid_identifier(name)

        cur.execute('''
            CREATE TABLE {}.{}.{} ({} {})
            '''.format(
            sql.identifier(self.connection.configured_database),
            sql.identifier(self.connection.configured_schema),
            sql.identifier(name),
            # Snowflake does not allow for creation of tables with no columns
            sql.identifier(self.CREATE_TABLE_INITIAL_COLUMN),
            self.CREATE_TABLE_INITIAL_COLUMN_TYPE))

        self._set_table_metadata(
            cur, name, {
                'path': path,
                'version': metadata.get('version', None),
                'schema_version': metadata['schema_version'],
                'mappings': {}
            })

        self.add_column_mapping(
            cur, name, (self.CREATE_TABLE_INITIAL_COLUMN, ),
            self.CREATE_TABLE_INITIAL_COLUMN,
            json_schema.make_nullable({'type': json_schema.BOOLEAN}))
    def write_table_batch(self, cur, table_batch, metadata):
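        """
        Load one batch: create a uniquely named tmp_ table shaped like the
        target table, stream the batch's records into it as CSV, and return
        the number of records written.
        """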
        remote_schema = table_batch['remote_schema']

        ## Create temp table to upload new data to
        target_table_name = self.canonicalize_identifier('tmp_' +
                                                         str(uuid.uuid4()))
        cur.execute('''
            CREATE TABLE {db}.{schema}.{temp_table} LIKE {db}.{schema}.{table}
        '''.format(db=sql.identifier(self.connection.configured_database),
                   schema=sql.identifier(self.connection.configured_schema),
                   temp_table=sql.identifier(target_table_name),
                   table=sql.identifier(remote_schema['name'])))

        ## Make streamable CSV records
        csv_headers = list(remote_schema['schema']['properties'].keys())
        rows_iter = iter(table_batch['records'])

        def transform():
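            # Serialize one record per call as a single CSV line; return ''
            # once the records are exhausted so TransformStream stops reading.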
            try:
                row = next(rows_iter)

                with io.StringIO() as out:
                    writer = csv.DictWriter(out, csv_headers)
                    writer.writerow(row)
                    return out.getvalue()
            except StopIteration:
                return ''

        csv_rows = TransformStream(transform)

        ## Persist csv rows
        self.persist_csv_rows(cur, remote_schema, target_table_name,
                              csv_headers, csv_rows)

        return len(table_batch['records'])
def test_loading__single_char_columns(db_prep):
    stream_count = 50
    main(CONFIG, input_stream=SingleCharStream(stream_count))

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(cur,
                                 'ROOT',
                                 {
                                     ('_SDC_PRIMARY_KEY', 'TEXT', 'NO'),
                                     ('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                                     ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                                     ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN', 'YES'),
                                     ('X', 'NUMBER', 'YES')
                                 })

            cur.execute('''
                SELECT {} FROM {}.{}.{}
            '''.format(
                sql.identifier('X'),
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('ROOT')
            ))
            persisted_records = cur.fetchall()

            ## Assert that the column has migrated data
            assert stream_count == len(persisted_records)
            assert stream_count == len([x for x in persisted_records if isinstance(x[0], float)])
def test_deduplication_older_rows(db_prep):
    stream = CatStream(100,
                       nested_count=2,
                       duplicates=2,
                       duplicate_sequence_delta=-100)
    main(CONFIG, input_stream=stream)

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            cur.execute(get_count_sql('CATS'))
            table_count = cur.fetchone()[0]
            cur.execute(get_count_sql('CATS__ADOPTION__IMMUNIZATIONS'))
            nested_table_count = cur.fetchone()[0]

            cur.execute('''
                SELECT "_SDC_SEQUENCE"
                FROM {}.{}.{}
                WHERE "ID" in ({})
            '''.format(
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('CATS'),
                ','.join(["'{}'".format(x)
                          for x in stream.duplicate_pks_used])))
            dup_cat_records = cur.fetchall()

    assert stream.record_message_count == 102
    assert table_count == 100
    assert nested_table_count == 200

    for record in dup_cat_records:
        assert record[0] == stream.sequence
    def _get_table_metadata(self, cur, table_name):
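        """
        Return the metadata dict stored as JSON in the table's comment, or
        None if the table does not exist or carries no comment.
        """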
        cur.execute('''
            SHOW TABLES LIKE '{}' IN SCHEMA {}.{}
            '''.format(
            table_name,
            sql.identifier(self.connection.configured_database),
            sql.identifier(self.connection.configured_schema),
        ))
        tables = cur.fetchall()

        if not tables:
            return None

        if len(tables) != 1:
            raise SnowflakeError(
                '{} tables returned while searching for: {}.{}.{}'.format(
                    len(tables), self.connection.configured_database,
                    self.connection.configured_schema, table_name))

        comment = tables[0][5]

        if comment:
            try:
                comment_meta = json.loads(comment)
            except Exception:
                self.LOGGER.exception('Could not load table comment metadata')
                raise
        else:
            comment_meta = None

        return comment_meta
def test_deduplication_existing_new_rows(db_prep):
    stream = CatStream(100, nested_count=2)
    main(CONFIG, input_stream=stream)

    original_sequence = stream.sequence

    stream = CatStream(100, nested_count=2, sequence=original_sequence - 20)
    main(CONFIG, input_stream=stream)

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            cur.execute(get_count_sql('CATS'))
            table_count = cur.fetchone()[0]
            cur.execute(get_count_sql('CATS__ADOPTION__IMMUNIZATIONS'))
            nested_table_count = cur.fetchone()[0]

            cur.execute('''
                SELECT DISTINCT "_SDC_SEQUENCE"
                FROM {}.{}.{}
            '''.format(sql.identifier(CONFIG['snowflake_database']),
                       sql.identifier(CONFIG['snowflake_schema']),
                       sql.identifier('CATS')))
            sequences = cur.fetchall()

    assert table_count == 100
    assert nested_table_count == 200

    assert len(sequences) == 1
    assert sequences[0][0] == original_sequence
    def is_table_empty(self, cur, table_name):
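        """Return True if the given table currently has no rows."""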
        cur.execute('''
            SELECT COUNT(1) FROM {}.{}.{}
            '''.format(sql.identifier(self.connection.configured_database),
                       sql.identifier(self.connection.configured_schema),
                       sql.identifier(table_name)))

        return cur.fetchone()[0] == 0
    def drop_column(self, cur, table_name, column_name):
        cur.execute('''
            ALTER TABLE {database}.{table_schema}.{table_name}
            DROP COLUMN {column_name}
            '''.format(
            database=sql.identifier(self.connection.configured_database),
            table_schema=sql.identifier(self.connection.configured_schema),
            table_name=sql.identifier(table_name),
            column_name=sql.identifier(column_name)))
    def make_column_nullable(self, cur, table_name, column_name):
        cur.execute('''
            ALTER TABLE {database}.{table_schema}.{table_name}
            ALTER COLUMN {column_name} DROP NOT NULL
            '''.format(
            database=sql.identifier(self.connection.configured_database),
            table_schema=sql.identifier(self.connection.configured_schema),
            table_name=sql.identifier(table_name),
            column_name=sql.identifier(column_name)))
    def migrate_column(self, cur, table_name, from_column, to_column):
        cur.execute('''
            UPDATE {database}.{table_schema}.{table_name}
            SET {to_column} = {from_column}
            '''.format(
            database=sql.identifier(self.connection.configured_database),
            table_schema=sql.identifier(self.connection.configured_schema),
            table_name=sql.identifier(table_name),
            to_column=sql.identifier(to_column),
            from_column=sql.identifier(from_column)))
def test_loading__new_non_null_column(db_prep):
    cat_count = 50
    main(CONFIG, input_stream=CatStream(cat_count))

    class NonNullStream(CatStream):
        def generate_record(self):
            record = CatStream.generate_record(self)
            record['id'] = record['id'] + cat_count
            return record

    non_null_stream = NonNullStream(cat_count)
    non_null_stream.schema = deepcopy(non_null_stream.schema)
    non_null_stream.schema['schema']['properties']['paw_toe_count'] = {'type': 'integer',
                                                                       'default': 5}

    main(CONFIG, input_stream=non_null_stream)

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(cur,
                                 'CATS',
                                 {
                                     ('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                                     ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                                     ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN', 'YES'),
                                     ('ADOPTION__ADOPTED_ON', 'TIMESTAMP_TZ', 'YES'),
                                     ('ADOPTION__WAS_FOSTER', 'BOOLEAN', 'YES'),
                                     ('AGE', 'NUMBER', 'YES'),
                                     ('ID', 'NUMBER', 'NO'),
                                     ('NAME', 'TEXT', 'NO'),
                                     ('PAW_SIZE', 'NUMBER', 'NO'),
                                     ('PAW_COLOUR', 'TEXT', 'NO'),
                                     ('PAW_TOE_COUNT', 'NUMBER', 'YES'),
                                     ('FLEA_CHECK_COMPLETE', 'BOOLEAN', 'NO'),
                                     ('PATTERN', 'TEXT', 'YES')
                                 })

            cur.execute('''
                SELECT {}, {} FROM {}.{}.{}
            '''.format(
                sql.identifier('ID'),
                sql.identifier('PAW_TOE_COUNT'),
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('CATS')
            ))

            persisted_records = cur.fetchall()

            ## Assert that rows loaded before the new column have NULL values
            ## and rows loaded afterwards have data
            assert 2 * cat_count == len(persisted_records)
            assert cat_count == len([x for x in persisted_records if x[1] is None])
            assert cat_count == len([x for x in persisted_records if x[1] is not None])
    def add_column(self, cur, table_name, column_name, column_schema):
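        """Add `column_name` to the table using the SQL type derived from `column_schema`."""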

        cur.execute('''
            ALTER TABLE {database}.{table_schema}.{table_name}
            ADD COLUMN {column_name} {data_type}
            '''.format(
            database=sql.identifier(self.connection.configured_database),
            table_schema=sql.identifier(self.connection.configured_schema),
            table_name=sql.identifier(table_name),
            column_name=sql.identifier(column_name),
            data_type=self.json_schema_to_sql_type(column_schema)))
    def _set_table_metadata(self, cur, table_name, metadata):
        """
        Given a Metadata dict, set it as the comment on the given table.
        :param self: Snowflake
        :param cur: Cursor
        :param table_name: String
        :param metadata: Metadata Dict
        :return: None
        """
        cur.execute('''
            COMMENT ON TABLE {}.{}.{} IS '{}'
            '''.format(sql.identifier(self.connection.configured_database),
                       sql.identifier(self.connection.configured_schema),
                       sql.identifier(table_name), json.dumps(metadata)))
def test_loading__multi_types_columns(db_prep):
    stream_count = 50
    main(CONFIG, input_stream=MultiTypeStream(stream_count))

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(cur,
                                 'ROOT',
                                 {
                                     ('_SDC_PRIMARY_KEY', 'TEXT', 'NO'),
                                     ('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                                     ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                                     ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                                     ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN', 'YES'),
                                     ('EVERY_TYPE__I', 'NUMBER', 'YES'),
                                     ('EVERY_TYPE__F', 'FLOAT', 'YES'),
                                     ('EVERY_TYPE__B', 'BOOLEAN', 'YES'),
                                     ('EVERY_TYPE__T', 'TIMESTAMP_TZ', 'YES'),
                                     ('EVERY_TYPE__I__1', 'NUMBER', 'YES'),
                                     ('EVERY_TYPE__F__1', 'FLOAT', 'YES'),
                                     ('EVERY_TYPE__B__1', 'BOOLEAN', 'YES'),
                                     ('NUMBER_WHICH_ONLY_COMES_AS_INTEGER', 'FLOAT', 'NO')
                                 })

            assert_columns_equal(cur,
                                 'ROOT__EVERY_TYPE',
                                 {
                                     ('_SDC_SOURCE_KEY__SDC_PRIMARY_KEY', 'TEXT', 'NO'),
                                     ('_SDC_LEVEL_0_ID', 'NUMBER', 'NO'),
                                     ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                                     ('_SDC_VALUE', 'NUMBER', 'NO'),
                                     ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN', 'YES'),
                                 })

            cur.execute('''
                SELECT {} FROM {}.{}.{}
            '''.format(
                sql.identifier('NUMBER_WHICH_ONLY_COMES_AS_INTEGER'),
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier('ROOT')
            ))
            persisted_records = cur.fetchall()

            ## Assert that the column has migrated data
            assert stream_count == len(persisted_records)
            assert stream_count == len([x for x in persisted_records if isinstance(x[0], float)])
    def setup_table_mapping_cache(self, cur):
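        """
        Populate self.table_mapping_cache from the comments of every table in
        the configured schema, mapping each stored `path` tuple to the
        canonicalized table name.
        """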
        self.table_mapping_cache = {}

        cur.execute('''
            SHOW TABLES IN SCHEMA {}.{}
            '''.format(sql.identifier(self.connection.configured_database),
                       sql.identifier(self.connection.configured_schema)))

        for row in cur.fetchall():
            mapped_name = row[1]
            raw_json = row[5]

            table_path = None
            if raw_json:
                table_path = json.loads(raw_json).get('path', None)
            self.LOGGER.info("Mapping: {} to {}".format(
                mapped_name, table_path))
            if table_path:
                self.table_mapping_cache[tuple(table_path)] = mapped_name
def assert_columns_equal(cursor, table_name, expected_column_tuples):
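    """Assert the table's set of (column_name, data_type, is_nullable) tuples matches exactly."""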
    cursor.execute('''
        SELECT column_name, data_type, is_nullable
        FROM {}.information_schema.columns
        WHERE table_schema = '{}' AND table_name = '{}'
    '''.format(sql.identifier(CONFIG['snowflake_database']),
               CONFIG['snowflake_schema'], table_name))

    columns = []
    for column in cursor.fetchall():
        columns.append((column[0], column[1], column[2]))

    assert set(columns) == expected_column_tuples
    def get_table_schema(self, cur, name):
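        """
        Rebuild the table's remote schema: start from the metadata stored in
        the table comment, then read column names, types, and nullability from
        information_schema. Returns None if the table has no stored metadata.
        """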
        metadata = self._get_table_metadata(cur, name)

        if not metadata:
            return None

        cur.execute('''
            SELECT column_name, data_type, is_nullable
            FROM {}.information_schema.columns
            WHERE table_schema = '{}' AND table_name = '{}'
            '''.format(sql.identifier(self.connection.configured_database),
                       self.connection.configured_schema, name))

        properties = {}
        for column in cur.fetchall():
            properties[column[0]] = self.sql_type_to_json_schema(
                column[1], column[2] == 'YES')

        metadata['name'] = name
        metadata['type'] = 'TABLE_SCHEMA'
        metadata['schema'] = {'properties': properties}

        return metadata
    def activate_version(self, stream_buffer, version):
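        """
        Handle a Singer ACTIVATE_VERSION message: for each table created under
        the requested version, rename the live table aside, rename the
        versioned table into its place, drop the old copy, and repoint the
        comment metadata at the original path. Rolls back and raises
        SnowflakeError on failure.
        """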
        with self.connection.cursor() as cur:
            try:
                self.setup_table_mapping_cache(cur)
                root_table_name = self.add_table_mapping(
                    cur, (stream_buffer.stream, ), {})
                current_table_schema = self.get_table_schema(
                    cur, root_table_name)

                if not current_table_schema:
                    self.LOGGER.error(
                        '{} - Table for stream does not exist'.format(
                            stream_buffer.stream))
                elif current_table_schema.get(
                        'version') is not None and current_table_schema.get(
                            'version') >= version:
                    self.LOGGER.warning(
                        '{} - Table version {} already active'.format(
                            stream_buffer.stream, version))
                else:
                    versioned_root_table = root_table_name + SEPARATOR + str(
                        version)

                    names_to_paths = dict([
                        (v, k) for k, v in self.table_mapping_cache.items()
                    ])

                    cur.execute('''
                        SHOW TABLES LIKE '{}%' IN SCHEMA {}.{}
                        '''.format(
                        versioned_root_table,
                        sql.identifier(self.connection.configured_database),
                        sql.identifier(self.connection.configured_schema)))

                    for versioned_table_name in [x[1] for x in cur.fetchall()]:
                        table_name = root_table_name + versioned_table_name[
                            len(versioned_root_table):]
                        table_path = names_to_paths[table_name]

                        args = {
                            'db_schema':
                            '{}.{}'.format(
                                sql.identifier(
                                    self.connection.configured_database),
                                sql.identifier(
                                    self.connection.configured_schema)),
                            'stream_table_old':
                            sql.identifier(table_name + SEPARATOR + 'OLD'),
                            'stream_table':
                            sql.identifier(table_name),
                            'version_table':
                            sql.identifier(versioned_table_name)
                        }

                        cur.execute('''
                            ALTER TABLE {db_schema}.{stream_table} RENAME TO {db_schema}.{stream_table_old}
                            '''.format(**args))

                        cur.execute('''
                            ALTER TABLE {db_schema}.{version_table} RENAME TO {db_schema}.{stream_table}
                            '''.format(**args))

                        cur.execute('''
                            DROP TABLE {db_schema}.{stream_table_old}
                            '''.format(**args))

                        self.connection.commit()

                        metadata = self._get_table_metadata(cur, table_name)

                        self.LOGGER.info(
                            'Activated {}, setting path to {}'.format(
                                metadata, table_path))

                        metadata['path'] = table_path
                        self._set_table_metadata(cur, table_name, metadata)
            except Exception as ex:
                self.connection.rollback()
                message = '{} - Exception activating table version {}'.format(
                    stream_buffer.stream, version)
                self.LOGGER.exception(message)
                raise SnowflakeError(message, ex)
    def perform_update(self, cur, target_table_name, temp_table_name,
                       key_properties, columns, subkeys):
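        """
        Merge the staged temp table into the target table: delete existing
        rows superseded by an incoming row with an equal or newer sequence,
        insert the deduplicated incoming rows whose keys are no longer
        present, clear the internal stage when S3 is not in use, and drop the
        temp table.
        """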
        full_table_name = '{}.{}.{}'.format(
            sql.identifier(self.connection.configured_database),
            sql.identifier(self.connection.configured_schema),
            sql.identifier(target_table_name))

        full_temp_table_name = '{}.{}.{}'.format(
            sql.identifier(self.connection.configured_database),
            sql.identifier(self.connection.configured_schema),
            sql.identifier(temp_table_name))

        pk_temp_select_list = []
        pk_where_list = []
        pk_null_list = []
        cxt_where_list = []
        for pk in key_properties:
            pk_identifier = sql.identifier(pk)
            pk_temp_select_list.append('{}.{}'.format(full_temp_table_name,
                                                      pk_identifier))

            pk_where_list.append('{table}.{pk} = "dedupped".{pk}'.format(
                table=full_table_name,
                pk=pk_identifier))

            pk_null_list.append('{table}.{pk} IS NULL'.format(
                table=full_table_name, pk=pk_identifier))

            cxt_where_list.append('{table}.{pk} = "pks".{pk}'.format(
                table=full_table_name, pk=pk_identifier))
        pk_temp_select = ', '.join(pk_temp_select_list)
        pk_where = ' AND '.join(pk_where_list)
        pk_null = ' AND '.join(pk_null_list)
        cxt_where = ' AND '.join(cxt_where_list)

        sequence_identifier = sql.identifier(
            self.canonicalize_identifier(SINGER_SEQUENCE))

        # An incoming row only supersedes an existing row when its sequence is
        # greater than or equal to the existing row's, so older duplicates are
        # never applied.
        sequence_join = ' AND "dedupped".{} >= {}.{}'.format(
            sequence_identifier, full_table_name, sequence_identifier)

        distinct_order_by = ' ORDER BY {}, {}.{} DESC'.format(
            pk_temp_select, full_temp_table_name, sequence_identifier)

        if len(subkeys) > 0:
            pk_temp_subkey_select_list = []
            for pk in (key_properties + subkeys):
                pk_temp_subkey_select_list.append('{}.{}'.format(
                    full_temp_table_name, sql.identifier(pk)))
            insert_distinct_on = ', '.join(pk_temp_subkey_select_list)

            insert_distinct_order_by = ' ORDER BY {}, {}.{} DESC'.format(
                insert_distinct_on, full_temp_table_name, sequence_identifier)
        else:
            insert_distinct_on = pk_temp_select
            insert_distinct_order_by = distinct_order_by

        insert_columns_list = []
        dedupped_columns_list = []
        for column in columns:
            insert_columns_list.append(sql.identifier(column))
            dedupped_columns_list.append('{}.{}'.format(
                sql.identifier('dedupped'), sql.identifier(column)))
        insert_columns = ', '.join(insert_columns_list)
        dedupped_columns = ', '.join(dedupped_columns_list)

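        # Remove existing rows whose primary key matches the newest incoming
        # row and whose sequence is not newer than that incoming row's.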
        cur.execute('''
            DELETE FROM {table} USING (
                    SELECT "dedupped".*
                    FROM (
                        SELECT *,
                               ROW_NUMBER() OVER (PARTITION BY {pk_temp_select}
                                                  {distinct_order_by}) AS "_sdc_pk_ranked"
                        FROM {temp_table}
                        {distinct_order_by}) AS "dedupped"
                    JOIN {table} ON {pk_where}{sequence_join}
                    WHERE "_sdc_pk_ranked" = 1
                ) AS "pks" WHERE {cxt_where};
            '''.format(table=full_table_name,
                       temp_table=full_temp_table_name,
                       pk_temp_select=pk_temp_select,
                       pk_where=pk_where,
                       cxt_where=cxt_where,
                       sequence_join=sequence_join,
                       distinct_order_by=distinct_order_by))

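        # Insert the deduplicated incoming rows, keeping only the newest row
        # per key (plus subkeys) and skipping keys still present in the target
        # (i.e. rows the DELETE above chose not to replace).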
        cur.execute('''
            INSERT INTO {table}({insert_columns}) (
                SELECT {dedupped_columns}
                FROM (
                    SELECT *,
                           ROW_NUMBER() OVER (PARTITION BY {insert_distinct_on}
                                              {insert_distinct_order_by}) AS "_sdc_pk_ranked"
                    FROM {temp_table}
                    {insert_distinct_order_by}) AS "dedupped"
                LEFT JOIN {table} ON {pk_where}
                WHERE "_sdc_pk_ranked" = 1 AND {pk_null}
            );
            '''.format(table=full_table_name,
                       temp_table=full_temp_table_name,
                       pk_where=pk_where,
                       pk_null=pk_null,
                       insert_distinct_on=insert_distinct_on,
                       insert_distinct_order_by=insert_distinct_order_by,
                       insert_columns=insert_columns,
                       dedupped_columns=dedupped_columns))

        if not self.s3:
            # Clear out the associated stage for the table
            cur.execute('''
                REMOVE @{db}.{schema}.%{temp_table}
            '''.format(db=sql.identifier(self.connection.configured_database),
                       schema=sql.identifier(
                           self.connection.configured_schema),
                       temp_table=sql.identifier(temp_table_name)))

        # Drop the tmp table
        cur.execute('''
            DROP TABLE {temp_table};
            '''.format(temp_table=full_temp_table_name))
def test_loading__column_type_change(db_prep):
    cat_count = 20
    main(CONFIG, input_stream=CatStream(cat_count))

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(
                cur, 'CATS',
                {('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                 ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                 ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                 ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                 ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN',
                  'YES'), ('ADOPTION__ADOPTED_ON', 'TIMESTAMP_TZ', 'YES'),
                 ('ADOPTION__WAS_FOSTER', 'BOOLEAN', 'YES'),
                 ('AGE', 'NUMBER', 'YES'), ('ID', 'NUMBER', 'NO'),
                 ('NAME', 'TEXT', 'NO'), ('PAW_SIZE', 'NUMBER', 'NO'),
                 ('PAW_COLOUR', 'TEXT', 'NO'),
                 ('FLEA_CHECK_COMPLETE', 'BOOLEAN', 'NO'),
                 ('PATTERN', 'TEXT', 'YES')})

            cur.execute('''
                SELECT {} FROM {}.{}.{}
            '''.format(sql.identifier('NAME'),
                       sql.identifier(CONFIG['snowflake_database']),
                       sql.identifier(CONFIG['snowflake_schema']),
                       sql.identifier('CATS')))
            persisted_records = cur.fetchall()

            ## Assert that the original data is present
            assert cat_count == len(persisted_records)
            assert cat_count == len(
                [x for x in persisted_records if x[0] is not None])

    class NameBooleanCatStream(CatStream):
        def generate_record(self):
            record = CatStream.generate_record(self)
            record['id'] = record['id'] + cat_count
            record['name'] = False
            return record

    stream = NameBooleanCatStream(cat_count)
    stream.schema = deepcopy(stream.schema)
    stream.schema['schema']['properties']['name'] = {'type': 'boolean'}

    main(CONFIG, input_stream=stream)

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(
                cur, 'CATS',
                {('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                 ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                 ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                 ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                 ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN',
                  'YES'), ('ADOPTION__ADOPTED_ON', 'TIMESTAMP_TZ', 'YES'),
                 ('ADOPTION__WAS_FOSTER', 'BOOLEAN', 'YES'),
                 ('AGE', 'NUMBER', 'YES'), ('ID', 'NUMBER', 'NO'),
                 ('NAME__S', 'TEXT', 'YES'), ('NAME__B', 'BOOLEAN', 'YES'),
                 ('PAW_SIZE', 'NUMBER', 'NO'), ('PAW_COLOUR', 'TEXT', 'NO'),
                 ('FLEA_CHECK_COMPLETE', 'BOOLEAN', 'NO'),
                 ('PATTERN', 'TEXT', 'YES')})

            cur.execute('''
                SELECT {}, {} FROM {}.{}.{}
            '''.format(sql.identifier('NAME__S'), sql.identifier('NAME__B'),
                       sql.identifier(CONFIG['snowflake_database']),
                       sql.identifier(CONFIG['snowflake_schema']),
                       sql.identifier('CATS')))
            persisted_records = cur.fetchall()

            ## Assert that the split columns migrated data/persisted new data
            assert 2 * cat_count == len(persisted_records)
            assert cat_count == len(
                [x for x in persisted_records if x[0] is not None])
            assert cat_count == len(
                [x for x in persisted_records if x[1] is not None])
            assert 0 == len([
                x for x in persisted_records
                if x[0] is not None and x[1] is not None
            ])

    class NameIntegerCatStream(CatStream):
        def generate_record(self):
            record = CatStream.generate_record(self)
            record['id'] = record['id'] + (2 * cat_count)
            record['name'] = 314
            return record

    stream = NameIntegerCatStream(cat_count)
    stream.schema = deepcopy(stream.schema)
    stream.schema['schema']['properties']['name'] = {'type': 'integer'}

    main(CONFIG, input_stream=stream)

    with connect(**TEST_DB) as conn:
        with conn.cursor() as cur:
            assert_columns_equal(
                cur, 'CATS',
                {('_SDC_BATCHED_AT', 'TIMESTAMP_TZ', 'YES'),
                 ('_SDC_RECEIVED_AT', 'TIMESTAMP_TZ', 'YES'),
                 ('_SDC_SEQUENCE', 'NUMBER', 'YES'),
                 ('_SDC_TABLE_VERSION', 'NUMBER', 'YES'),
                 ('_SDC_TARGET_SNOWFLAKE_CREATE_TABLE_PLACEHOLDER', 'BOOLEAN',
                  'YES'), ('ADOPTION__ADOPTED_ON', 'TIMESTAMP_TZ', 'YES'),
                 ('ADOPTION__WAS_FOSTER', 'BOOLEAN', 'YES'),
                 ('AGE', 'NUMBER', 'YES'), ('ID', 'NUMBER', 'NO'),
                 ('NAME__S', 'TEXT', 'YES'), ('NAME__B', 'BOOLEAN', 'YES'),
                 ('NAME__I', 'NUMBER', 'YES'), ('PAW_SIZE', 'NUMBER', 'NO'),
                 ('PAW_COLOUR', 'TEXT', 'NO'),
                 ('FLEA_CHECK_COMPLETE', 'BOOLEAN', 'NO'),
                 ('PATTERN', 'TEXT', 'YES')})

            cur.execute('''
                SELECT {}, {}, {} FROM {}.{}.{}
            '''.format(sql.identifier('NAME__S'), sql.identifier('NAME__B'),
                       sql.identifier('NAME__I'),
                       sql.identifier(CONFIG['snowflake_database']),
                       sql.identifier(CONFIG['snowflake_schema']),
                       sql.identifier('CATS')))
            persisted_records = cur.fetchall()

            ## Assert that the split columns migrated data/persisted new data
            assert 3 * cat_count == len(persisted_records)
            assert cat_count == len(
                [x for x in persisted_records if x[0] is not None])
            assert cat_count == len(
                [x for x in persisted_records if x[1] is not None])
            assert cat_count == len(
                [x for x in persisted_records if x[2] is not None])
            assert 0 == len([
                x for x in persisted_records
                if x[0] is not None and x[1] is not None and x[2] is not None
            ])
            assert 0 == len([
                x for x in persisted_records
                if x[0] is None and x[1] is None and x[2] is None
            ])
def assert_records(conn, records, table_name, pks, match_pks=False):
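    """
    Assert every expected record (and its nested subtable records) was
    persisted to `table_name`, keyed by `pks`; with match_pks=True also assert
    no extra rows exist.
    """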
    if not isinstance(pks, list):
        pks = [pks]

    with conn.cursor(True) as cur:
        cur.execute("set timezone='UTC';")

        cur.execute('''
            SELECT * FROM {}.{}.{}
        '''.format(sql.identifier(CONFIG['snowflake_database']),
                   sql.identifier(CONFIG['snowflake_schema']),
                   sql.identifier(table_name)))
        persisted_records_raw = cur.fetchall()

        persisted_records = {}
        for persisted_record in persisted_records_raw:
            pk = get_pk_key(pks, persisted_record)
            persisted_records[pk] = persisted_record

        subtables = {}
        records_pks = []
        pre_canonicalized_pks = [x.lower() for x in pks]
        for record in records:
            pk = get_pk_key(pre_canonicalized_pks, record)
            records_pks.append(pk)
            persisted_record = persisted_records[pk.upper()]
            subpks = {}
            for pk in pks:
                subpks[singer_stream.SINGER_SOURCE_PK_PREFIX +
                       pk] = persisted_record[pk]
            assert_record(record, persisted_record, subtables, subpks)

        if match_pks:
            assert sorted(list(
                persisted_records.keys())) == sorted(records_pks)

        sub_pks = list(
            map(lambda pk: singer_stream.SINGER_SOURCE_PK_PREFIX.upper() + pk,
                pks))
        for subtable_name, items in subtables.items():
            cur.execute('''
                SELECT * FROM {}.{}.{}
            '''.format(
                sql.identifier(CONFIG['snowflake_database']),
                sql.identifier(CONFIG['snowflake_schema']),
                sql.identifier(table_name + '__' + subtable_name.upper())))
            persisted_records_raw = cur.fetchall()

            persisted_records = {}
            for persisted_record in persisted_records_raw:
                pk = get_pk_key(sub_pks, persisted_record, subrecord=True)
                persisted_records[pk] = persisted_record

            subtables = {}
            records_pks = []
            pre_canonicalized_sub_pks = [x.lower() for x in sub_pks]
            for record in items:
                pk = get_pk_key(pre_canonicalized_sub_pks,
                                record,
                                subrecord=True)
                records_pks.append(pk)
                persisted_record = persisted_records[pk]
                assert_record(record, persisted_record, subtables, subpks)
            assert len(subtables.values()) == 0

            if match_pks:
                assert sorted(list(
                    persisted_records.keys())) == sorted(records_pks)
    def persist_csv_rows(self, cur, remote_schema, temp_table_name, columns,
                         csv_rows):
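        """
        Stage the CSV rows in S3 (when configured) or in the temp table's
        internal stage via PUT, COPY them into the temp table, then merge into
        the target table with perform_update.
        """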
        params = []

        if self.s3:
            bucket, key = self.s3.persist(csv_rows,
                                          key_prefix=temp_table_name +
                                          SEPARATOR)
            stage_location = "'s3://{bucket}/{key}' credentials=(AWS_KEY_ID=%s AWS_SECRET_KEY=%s)".format(
                bucket=bucket, key=key)
            params = [
                self.s3.credentials()['aws_access_key_id'],
                self.s3.credentials()['aws_secret_access_key']
            ]
        else:
            stage_location = '@{db}.{schema}.%{table}'.format(
                db=sql.identifier(self.connection.configured_database),
                schema=sql.identifier(self.connection.configured_schema),
                table=sql.identifier(temp_table_name))

            rel_path = '/tmp/target-snowflake/'
            file_name = str(uuid.uuid4()).replace('-', '_')

            # Make tmp folder to hold data file
            os.makedirs(rel_path, exist_ok=True)

            # Write readable csv_rows to file
            with open(rel_path + file_name, 'wb') as file:
                line = csv_rows.read()
                while line:
                    file.write(line.encode('utf-8'))
                    line = csv_rows.read()

            # Upload to internal table stage
            cur.execute('''
                PUT file://{rel_path}{file_name} {stage_location}
            '''.format(rel_path=rel_path,
                       file_name=file_name,
                       stage_location=stage_location))

            # Tidy up and remove tmp staging file
            os.remove(rel_path + file_name)

            stage_location += '/{}'.format(file_name)

        cur.execute('''
            COPY INTO {db}.{schema}.{table} ({cols})
            FROM {stage_location}
            FILE_FORMAT = (TYPE = CSV EMPTY_FIELD_AS_NULL = FALSE)
        '''.format(db=sql.identifier(self.connection.configured_database),
                   schema=sql.identifier(self.connection.configured_schema),
                   table=sql.identifier(temp_table_name),
                   cols=','.join([sql.identifier(x) for x in columns]),
                   stage_location=stage_location),
                    params=params)

        pattern = re.compile(SINGER_LEVEL.upper().format('[0-9]+'))
        subkeys = list(
            filter(lambda header: re.match(pattern, header) is not None,
                   columns))

        canonicalized_key_properties = [
            self.fetch_column_from_path((key_property, ), remote_schema)[0]
            for key_property in remote_schema['key_properties']
        ]

        self.perform_update(cur, remote_schema['name'], temp_table_name,
                            canonicalized_key_properties, columns, subkeys)