Example #1
File: main.py Project: eculicny/snowshu
def init(path: click.Path) -> None:
    """generates sample replica.yml and credentials.yml files in the current
    directory.

    Args:
        path: The full or relative path to where the files should be generated, defaults to current dir.
    """

    logger = Logger().logger
    templates = os.path.join(Path(__file__).parent.parent, 'templates')

    def destination(filename):
        return os.path.join(path, filename)

    def source(filename):
        return os.path.join(templates, filename)

    CREDENTIALS = 'credentials.yml'     # noqa pep8: disable=N806
    REPLICA = 'replica.yml'     # noqa pep8: disable=N806

    if os.path.isfile(destination(CREDENTIALS)) or os.path.isfile(
            destination(REPLICA)):
        message = "cannot generate sample files, already exist in current directory."
        logger.error(message)
        raise ValueError(message)
    try:
        copyfile(source(REPLICA), destination(REPLICA))
        copyfile(source(CREDENTIALS), destination(CREDENTIALS))
        logger.info(
            f"sample files created in directory {os.path.abspath(path)}")
    except Exception as exc:
        logger.error(f"failed to generate sample files: {exc}")
        raise exc
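
A minimal usage sketch, assuming `init` is registered as a subcommand of the `cli` group shown in Example #3 and that `path` defaults to the current directory (the click decorators are not shown in this snippet):

from click.testing import CliRunner
from snowshu import main

runner = CliRunner()
with runner.isolated_filesystem():
    result = runner.invoke(main.cli, ('init',))
    # on success, replica.yml and credentials.yml now exist in the cwd
    assert result.exit_code == 0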
Example #2
def test_sample_args_valid(run, replica):
    runner = CliRunner()
    with runner.isolated_filesystem():
        logger = Logger().logger
        tempfile = Path('./test-file.yml')
        tempfile.touch()
        EXPECTED_REPLICA_FILE = tempfile.absolute()
        EXPECTED_TAG = rand_string(10)      # defined but unused in this excerpt
        EXPECTED_DEBUG = True               # defined but unused in this excerpt
        result = runner.invoke(main.cli, (
            '--debug',
            'create',
            '--replica-file',
            EXPECTED_REPLICA_FILE,
        ))
        replica.assert_called_once_with(EXPECTED_REPLICA_FILE)
        assert logger.getEffectiveLevel() == DEBUG
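
`rand_string` comes from `tests.common`, which is not shown here; a plausible stand-in if you want to run these test snippets in isolation:

import random
import string

def rand_string(length: int) -> str:
    """Return a random lowercase ASCII string of the given length."""
    return ''.join(random.choices(string.ascii_lowercase, k=length))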
Example #3
File: main.py Project: eculicny/snowshu
def cli(debug: bool):
    """SnowShu is a sampling engine designed to support testing in data development."""
    log_level = logging.DEBUG if debug else logging.INFO
    log_engine = Logger()
    log_engine.initialize_logger()
    log_engine.set_log_level(log_level)
    logger = log_engine.logger
    if not which('docker') and not IS_IN_DOCKER:
        logger.warning(NO_DOCKER)
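
The decorators are stripped from this snippet. A sketch of the conventional click wiring it implies (inferred from the `--debug` flag used in Example #2, not taken verbatim from the source):

import click

@click.group()
@click.option('--debug', is_flag=True, default=False,
              help='Enable debug-level logging.')
def cli(debug: bool):
    ...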
Example #4
def test_logger_always_logs_debug_to_file(temp_log):
    levels = ('WARNING', 'DEBUG', 'INFO', 'CRITICAL',)
    log_engine = Logger()
    log_engine.initialize_logger(log_file_location=temp_log.strpath)
    for level in levels:
        log_engine.set_log_level(level)
        logger = log_engine.logger
        message = rand_string(10)
        logger.debug(message)
        with open(temp_log) as tmp:
            line = tmp.readlines()[-1]
            assert 'DEBUG' in line
            assert message in line
Example #5
def test_logger_debug_log_level(temp_log):
    log_engine = Logger()
    log_engine.initialize_logger(log_file_location=temp_log.strpath)
    log_engine.set_log_level(logging.DEBUG)
    logger = log_engine.logger
    with LogCapture() as capture:
        ERROR = rand_string(10)
        INFO = rand_string(10)
        DEBUG = rand_string(10)
        WARNING = rand_string(10)
        logger.warning(WARNING)
        logger.error(ERROR)
        logger.info(INFO)
        logger.debug(DEBUG)
        capture.check(
            ('snowshu', 'WARNING', WARNING),
            ('snowshu', 'ERROR', ERROR),
            ('snowshu', 'INFO', INFO),
            ('snowshu', 'DEBUG', DEBUG),
        )
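
`LogCapture` here is `testfixtures.LogCapture`, and `capture.check` matches `(logger_name, level_name, message)` tuples in emission order. The import block for these test snippets is not shown; they presumably rely on something like:

import logging
from logging import DEBUG
from pathlib import Path

from click.testing import CliRunner
from testfixtures import LogCapture

from snowshu import main
from snowshu.logger import Logger
from tests.common import rand_string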
Example #6
import time
from pathlib import Path
from typing import Optional, TextIO, Union

from snowshu.core.configuration_parser import (Configuration,
                                               ConfigurationParser)
from snowshu.core.graph import SnowShuGraph
from snowshu.core.graph_set_runner import GraphSetRunner
from snowshu.core.printable_result import (graph_to_result_list,
                                           printable_result)
from snowshu.logger import Logger, duration

logger = Logger().logger


class ReplicaFactory:

    def __init__(self):
        self._credentials = dict()
        self.config: Optional[Configuration] = None
        self.run_analyze: Optional[bool] = None

    def create(self,
               name: Union[str, None],
               barf: bool) -> None:
        self.run_analyze = False
        return self._execute(name=name, barf=barf)

    def analyze(self, barf: bool) -> None:
        self.run_analyze = True
        return self._execute(barf=barf)
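
A hedged usage sketch. How `config` is populated lies outside this excerpt, so the parser call below is illustrative only:

factory = ReplicaFactory()
# assumption: ConfigurationParser exposes a loader for replica.yml;
# `from_file_or_path` is a hypothetical method name
factory.config = ConfigurationParser().from_file_or_path('replica.yml')
factory.create(name='development-replica', barf=False)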
Example #7
import os
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

from jsonschema.exceptions import ValidationError

from snowshu.configs import (DEFAULT_MAX_NUMBER_OF_OUTLIERS,
                             DEFAULT_PRESERVE_CASE, DEFAULT_THREAD_COUNT)
from snowshu.core.models import Credentials
from snowshu.core.samplings.utils import get_sampling_from_partial
from snowshu.core.utils import correct_case, fetch_adapter
from snowshu.logger import Logger

if TYPE_CHECKING:
    from io import StringIO
    from snowshu.core.samplings.bases.base_sampling import BaseSampling
    from snowshu.adapters.source_adapters.base_source_adapter import BaseSourceAdapter
    from snowshu.adapters.target_adapters.base_target_adapter import BaseTargetAdapter

logger = Logger().logger

TEMPLATES_PATH = Path(os.path.dirname(__file__)).parent / 'templates'
REPLICA_JSON_SCHEMA = TEMPLATES_PATH / 'replica_schema.json'
CREDENTIALS_JSON_SCHEMA = TEMPLATES_PATH / 'credentials_schema.json'


@dataclass
class MatchPattern:
    @dataclass
    class RelationPattern:
        relation_pattern: str

    @dataclass
    class SchemaPattern:
        schema_pattern: str
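
A sketch of how the schema paths above are presumably consumed, using the standard `jsonschema` API; the helper below is hypothetical, and only the `ValidationError` import above implies this pattern:

import json

import jsonschema

def validate_replica_config(config_dict: dict) -> None:
    """Raises jsonschema's ValidationError if config_dict violates the replica schema."""
    with open(REPLICA_JSON_SCHEMA) as schema_file:
        schema = json.load(schema_file)
    jsonschema.validate(instance=config_dict, schema=schema)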
Example #8
import time

import docker
import pytest
from sqlalchemy import create_engine

from snowshu.adapters.target_adapters import PostgresAdapter
from snowshu.core.docker import SnowShuDocker
from snowshu.logger import Logger
from tests.common import rand_string
from tests.integration.snowflake.test_end_to_end import DOCKER_SPIN_UP_TIMEOUT

Logger().set_log_level(0)

TEST_NAME, TEST_TABLE = [rand_string(10) for _ in range(2)]


def test_creates_replica(docker_flush):
    # build image
    # load it up with some data
    # convert it to a replica
    # spin it all down
    # start the replica
    # query it and confirm that the data is in there

    shdocker = SnowShuDocker()
    target_adapter = PostgresAdapter()
    target_container = shdocker.startup(target_adapter.DOCKER_IMAGE,
                                        target_adapter.DOCKER_START_COMMAND,
                                        9999, target_adapter,
                                        'SnowflakeAdapter', [
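
The snippet cuts off mid-call. A hedged sketch of the final "query it and confirm" step from the comment outline (connection details are placeholders; port 9999 comes from the `startup` call above):

from sqlalchemy import text

engine = create_engine('postgresql://snowshu:snowshu@localhost:9999/snowshu')
with engine.connect() as conn:
    count = conn.execute(text(f'SELECT COUNT(*) FROM {TEST_TABLE}')).scalar()
assert count > 0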
Example #9
class SnowflakeAdapter(BaseSourceAdapter):
    """The Snowflake Data Warehouse source adapter.

    Args:
        preserve_case: By default the adapter folds case-insensitive strings to lowercase.
                       If preserve_case is True, SnowShu will __not__ alter cases (dangerous!).
    """

    name = 'snowflake'
    SUPPORTS_CROSS_DATABASE = True
    SUPPORTED_FUNCTIONS = set(['ANY_VALUE', 'RLIKE', 'UUID_STRING'])
    SUPPORTED_SAMPLE_METHODS = (BernoulliSampleMethod, )
    REQUIRED_CREDENTIALS = (
        USER,
        PASSWORD,
        ACCOUNT,
        DATABASE,
    )
    ALLOWED_CREDENTIALS = (
        SCHEMA,
        WAREHOUSE,
        ROLE,
    )
    # snowflake in-db is UPPER, but connector is actually lower :(
    DEFAULT_CASE = 'lower'

    DATA_TYPE_MAPPINGS = {
        "array": dtypes.JSON,
        "bigint": dtypes.BIGINT,
        "binary": dtypes.BINARY,
        "boolean": dtypes.BOOLEAN,
        "char": dtypes.CHAR,
        "character": dtypes.CHAR,
        "date": dtypes.DATE,
        "datetime": dtypes.DATETIME,
        "decimal": dtypes.DECIMAL,
        "double": dtypes.FLOAT,
        "double precision": dtypes.FLOAT,
        "float": dtypes.FLOAT,
        "float4": dtypes.FLOAT,
        "float8": dtypes.FLOAT,
        "int": dtypes.BIGINT,
        "integer": dtypes.BIGINT,
        "number": dtypes.BIGINT,
        "numeric": dtypes.NUMERIC,
        "object": dtypes.JSON,
        "real": dtypes.FLOAT,
        "smallint": dtypes.BIGINT,
        "string": dtypes.VARCHAR,
        "text": dtypes.VARCHAR,
        "time": dtypes.TIME,
        "timestamp": dtypes.TIMESTAMP_NTZ,
        "timestamp_ntz": dtypes.TIMESTAMP_NTZ,
        "timestamp_ltz": dtypes.TIMESTAMP_TZ,
        "timestamp_tz": dtypes.TIMESTAMP_TZ,
        "varbinary": dtypes.BINARY,
        "varchar": dtypes.VARCHAR,
        "variant": dtypes.JSON
    }

    MATERIALIZATION_MAPPINGS = {"BASE TABLE": mz.TABLE, "VIEW": mz.VIEW}

    @overrides
    def _get_all_databases(self) -> List[str]:
        """ Use the SHOW api to get all the available db structures."""
        logger.debug('Collecting databases from snowflake...')
        show_result = tuple(
            self._safe_query("SHOW TERSE DATABASES")['name'].tolist())
        databases = list(set(show_result))
        logger.debug(f'Done. Found {len(databases)} databases.')
        return databases

    @overrides
    def _get_all_schemas(
            self,
            database: str,
            exclude_defaults: Optional[bool] = False) -> List[str]:
        logger.debug(f'Collecting schemas from {database} in snowflake...')
        show_result = self._safe_query(
            f'SHOW TERSE SCHEMAS IN DATABASE {database}')['name'].tolist()
        schemas = list(set(show_result))
        logger.debug(
            f'Done. Found {len(schemas)} schemas in {database} database.')
        return schemas

    @staticmethod
    def population_count_statement(relation: Relation) -> str:
        """creates the count * statement for a relation

        Args:
            relation: the :class:`Relation <snowshu.core.models.relation.Relation>` to create the statement for.
        Returns:
            a query that results in a single row, single column, integer value of the unsampled relation population size
        """
        return f"SELECT COUNT(*) FROM {relation.quoted_dot_notation}"

    @staticmethod
    def view_creation_statement(relation: Relation) -> str:
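        # GET_DDL returns the full "CREATE ... VIEW ... AS <select body>" text;
        # POSITION locates the " AS " separator (case-folded via UPPER) and
        # SUBSTRING keeps everything after the AS keyword, i.e. the view's
        # defining SELECT statement.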
        return f"""
SELECT
SUBSTRING(GET_DDL('view','{relation.quoted_dot_notation}'),
POSITION(' AS ' IN UPPER(GET_DDL('view','{relation.quoted_dot_notation}')))+3)
"""

    @staticmethod
    def unsampled_statement(relation: Relation) -> str:
        return f"""
SELECT
    *
FROM
    {relation.quoted_dot_notation}
"""

    def directionally_wrap_statement(
            self, sql: str, relation: Relation,
            sample_type: Optional['BaseSampleMethod']) -> str:
        if sample_type is None:
            return sql

        return f"""
WITH
{relation.scoped_cte('SNOWSHU_FINAL_SAMPLE')} AS (
{sql}
)
,{relation.scoped_cte('SNOWSHU_DIRECTIONAL_SAMPLE')} AS (
SELECT
    *
FROM
{relation.scoped_cte('SNOWSHU_FINAL_SAMPLE')}
{self._sample_type_to_query_sql(sample_type)}
)
SELECT
    *
FROM
{relation.scoped_cte('SNOWSHU_DIRECTIONAL_SAMPLE')}
"""

    @staticmethod
    def analyze_wrap_statement(sql: str, relation: Relation) -> str:
        return f"""
WITH
    {relation.scoped_cte('SNOWSHU_COUNT_POPULATION')} AS (
SELECT
    COUNT(*) AS population_size
FROM
    {relation.quoted_dot_notation}
)
,{relation.scoped_cte('SNOWSHU_CORE_SAMPLE')} AS (
{sql}
)
,{relation.scoped_cte('SNOWSHU_CORE_SAMPLE_COUNT')} AS (
SELECT
    COUNT(*) AS sample_size
FROM
    {relation.scoped_cte('SNOWSHU_CORE_SAMPLE')}
)
SELECT
    s.sample_size AS sample_size
    ,p.population_size AS population_size
FROM
    {relation.scoped_cte('SNOWSHU_CORE_SAMPLE_COUNT')} s
INNER JOIN
    {relation.scoped_cte('SNOWSHU_COUNT_POPULATION')} p
ON
    1=1
LIMIT 1
"""

    def sample_statement_from_relation(
            self, relation: Relation, sample_type: Union['BaseSampleMethod',
                                                         None]) -> str:
        """builds the base sample statment for a given relation."""
        query = f"""
SELECT
    *
FROM
    {relation.quoted_dot_notation}
"""
        if sample_type is not None:
            query += f"{self._sample_type_to_query_sql(sample_type)}"
        return query

    @staticmethod
    def union_constraint_statement(subject: Relation, constraint: Relation,
                                   subject_key: str, constraint_key: str,
                                   max_number_of_outliers: int) -> str:
        """ Union statements to select outliers. This does not pull in NULL values. """
        return f"""
(SELECT
    *
FROM
{subject.quoted_dot_notation}
WHERE
    {subject_key}
NOT IN
(SELECT
    {constraint_key}
FROM
{constraint.quoted_dot_notation})
LIMIT {max_number_of_outliers})
"""

    @staticmethod
    def upstream_constraint_statement(relation: Relation, local_key: str,
                                      remote_key: str) -> str:
        """ builds upstream where constraints against downstream full population"""
        return f" {local_key} in (SELECT {remote_key} FROM {relation.quoted_dot_notation})"

    @staticmethod
    def predicate_constraint_statement(relation: Relation, analyze: bool,
                                       local_key: str, remote_key: str) -> str:
        """builds 'where' strings"""
        constraint_sql = str()
        if analyze:
            constraint_sql = f" SELECT {remote_key} AS {local_key} FROM ({relation.core_query})"
        else:

            def quoted(val: Any) -> str:
                requires_quotes = relation.lookup_attribute(remote_key).data_type.requires_quotes
                return f"'{val}'" if requires_quotes else str(val)

            try:
                constraint_set = [
                    quoted(val) for val in relation.data[remote_key].unique()
                ]
                constraint_sql = ','.join(constraint_set)
            except KeyError as err:
                logger.critical(
                    f'failed to build predicates for {relation.dot_notation}: '
                    f'remote key {remote_key} not in dataframe columns ({relation.data.columns})'
                )
                raise err

        return f"{local_key} IN ({constraint_sql}) "

    @staticmethod
    def polymorphic_constraint_statement(
            relation: Relation,  # noqa pylint: disable=too-many-arguments
            analyze: bool,
            local_key: str,
            remote_key: str,
            local_type: str,
            local_type_match_val: str = None) -> str:
        predicate = SnowflakeAdapter.predicate_constraint_statement(
            relation, analyze, local_key, remote_key)
        if local_type_match_val:
            type_match_val = local_type_match_val
        else:
            type_match_val = (relation.name[:-1]
                              if relation.name[-1].lower() == 's'
                              else relation.name)
        return f" ({predicate} AND LOWER({local_type}) = LOWER('{type_match_val}') ) "

    @staticmethod
    def _sample_type_to_query_sql(sample_type: 'BaseSampleMethod') -> str:
        if sample_type.name == 'BERNOULLI':
            qualifier = sample_type.probability if sample_type.probability\
                else str(sample_type.rows) + ' ROWS'
            return f"SAMPLE BERNOULLI ({qualifier})"
        if sample_type.name == 'SYSTEM':
            return f"SAMPLE SYSTEM ({sample_type.probability})"

        message = f"{sample_type.name} is not supported for SnowflakeAdapter"
        logger.error(message)
        raise NotImplementedError(message)
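    # Rendered examples for the supported branches (attribute values are
    # illustrative; BernoulliSampleMethod construction is not shown here):
    #   name='BERNOULLI', probability=10             -> "SAMPLE BERNOULLI (10)"
    #   name='BERNOULLI', probability=None, rows=100 -> "SAMPLE BERNOULLI (100 ROWS)"
    #   name='SYSTEM',    probability=25             -> "SAMPLE SYSTEM (25)"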

    # TODO: change arg name in parent to fix the issue here
    @overrides
    def _build_conn_string(self, overrides: Optional[dict] = None) -> str:  # noqa pylint: disable=redefined-outer-name
        """overrides the base conn string."""
        conn_parts = [
            f"snowflake://{self.credentials.user}:{self.credentials.password}"
            f"@{self.credentials.account}/{self.credentials.database}/"
        ]
        conn_parts.append(self.credentials.schema
                          if self.credentials.schema is not None else '')
        get_args = list()
        for arg in (
                'warehouse',
                'role',
        ):
            if self.credentials.__dict__[arg] is not None:
                get_args.append(f"{arg}={self.credentials.__dict__[arg]}")

        get_string = "?" + "&".join(get_args)
        return (''.join(conn_parts)) + get_string
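    # Rendered shape (placeholder values, not from the source):
    #   snowflake://USER:PASSWORD@ACCOUNT/DATABASE/SCHEMA?warehouse=WH&role=ROLE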

    @overrides
    def _get_relations_from_database(
            self,
            schema_obj: BaseSourceAdapter._DatabaseObject) -> List[Relation]:
        quoted_database = schema_obj.full_relation.quoted(
            schema_obj.full_relation.database)  # quoted db name
        relation_database = schema_obj.full_relation.database  # case corrected db name
        case_sensitive_schema = schema_obj.case_sensitive_name  # case-sensitive schema name
        relations_sql = f"""
                                 SELECT
                                    m.table_schema AS schema,
                                    m.table_name AS relation,
                                    m.table_type AS materialization,
                                    c.column_name AS attribute,
                                    c.ordinal_position AS ordinal,
                                    c.data_type AS data_type
                                 FROM
                                    {quoted_database}.INFORMATION_SCHEMA.TABLES m
                                 INNER JOIN
                                    {quoted_database}.INFORMATION_SCHEMA.COLUMNS c
                                 ON
                                    c.table_schema = m.table_schema
                                 AND
                                    c.table_name = m.table_name
                                 WHERE
                                    m.table_schema = '{case_sensitive_schema}'
                                    AND m.table_schema <> 'INFORMATION_SCHEMA'
                              """

        logger.debug(
            f'Collecting detailed relations from database {quoted_database}...'
        )
        relations_frame = self._safe_query(relations_sql)
        unique_relations = (relations_frame['schema'] + '.' +
                            relations_frame['relation']).unique().tolist()
        logger.debug(
            f'Done collecting relations. Found a total of {len(unique_relations)} '
            f'unique relations in database {quoted_database}')
        relations = list()
        for relation in unique_relations:
            logger.debug(
                f'Building relation {quoted_database}.{relation}...')
            attributes = list()

            for attribute in relations_frame.loc[(
                    relations_frame['schema'] + '.' +
                    relations_frame['relation']) == relation].itertuples():
                logger.debug(
                    f'adding attribute {attribute.attribute} to relation...')
                attributes.append(
                    Attribute(self._correct_case(attribute.attribute),
                              self._get_data_type(attribute.data_type)))

            relation = Relation(
                relation_database,
                self._correct_case(attribute.schema),  # noqa pylint: disable=undefined-loop-variable
                self._correct_case(attribute.relation),  # noqa pylint: disable=undefined-loop-variable
                self.MATERIALIZATION_MAPPINGS[attribute.materialization],  # noqa pylint: disable=undefined-loop-variable
                attributes)
            logger.debug(f'Added relation {relation.dot_notation} to pool.')
            relations.append(relation)

        logger.debug(
            f'Acquired {len(relations)} total relations from database {quoted_database}.'
        )
        return relations

    @overrides
    def _count_query(self, query: str) -> int:
        count_sql = f"WITH __SNOWSHU__COUNTABLE__QUERY as ({query}) \
                    SELECT COUNT(*) AS count FROM __SNOWSHU__COUNTABLE__QUERY"

        count = int(self._safe_query(count_sql).iloc[0]['count'])
        return count

    @tenacity.retry(wait=wait_exponential(),
                    stop=stop_after_attempt(4),
                    before_sleep=Logger().log_retries,
                    reraise=True)
    @overrides
    def check_count_and_query(self, query: str, max_count: int,
                              unsampled: bool) -> pd.DataFrame:
        """checks the count, if count passes returns results as a dataframe."""
        try:
            logger.debug('Checking count for query...')
            start_time = time.time()
            count = self._count_query(query)
            if unsampled and count > max_count:
                warn_msg = (
                    f'Unsampled relation has {count} rows which is over '
                    f'the max allowed rows for this type of query ({max_count}). '
                    f'All records will be loaded into replica.')
                logger.warning(warn_msg)
            else:
                assert count <= max_count
            logger.debug(
                f'Query count safe at {count} rows in {time.time()-start_time} seconds.'
            )
        except AssertionError:
            message = (
                f'failed to execute query, result would have returned {count} rows '
                f'but the max allowed rows for this type of query is {max_count}.'
            )
            logger.error(message)
            logger.debug(f'failed sql: {query}')
            raise TooManyRecords(message)
        response = self._safe_query(query)
        return response
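    # tenacity wraps the whole check: up to 4 attempts with exponential
    # backoff, logging each retry via Logger().log_retries and re-raising
    # the final error (reraise=True).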

    @overrides
    def get_connection(
        self,
        database_override: Optional[str] = None,
        schema_override: Optional[str] = None
    ) -> sqlalchemy.engine.base.Engine:
        """Creates a connection engine without transactions.

        By default uses the instance credentials unless database or
        schema override are provided.
        """
        if not self._credentials:
            raise KeyError(
                'Adapter.get_connection called before setting Adapter.credentials'
            )

        logger.debug(f'Acquiring {self.CLASSNAME} connection...')
        overrides = dict(  # noqa pylint: disable=redefined-outer-name
            (k, v) for (k, v) in dict(database=database_override,
                                      schema=schema_override).items()
            if v is not None)

        engine = sqlalchemy.create_engine(self._build_conn_string(overrides),
                                          poolclass=NullPool)
        logger.debug(f'engine acquired. Conn string: {repr(engine.url)}')
        return engine