def __init__(self, parser: Parser):
    """Set up an Athena dialect.

    With a truthy *parser*, AWS credentials (via the parser's own helper),
    the S3 staging directory and the database name are read from it.
    Otherwise only the dialect type is registered with the base class.
    """
    super().__init__(ATHENA)
    if not parser:
        return
    required_env = parser.get_str_required_env
    self.aws_credentials = parser.get_aws_credentials_optional()
    self.athena_staging_dir = required_env('staging_dir')
    self.database = required_env('database')
def __init__(self, parser: Parser = None, type: str = MYSQL):
    """Set up a MySQL dialect.

    Args:
        parser: configuration source; when falsy, only the dialect type is
            registered and no connection attributes are set.
        type: dialect type passed to the base class (defaults to MYSQL).
    """
    super().__init__(type)
    if not parser:
        return
    optional_env = parser.get_str_optional_env
    required_env = parser.get_str_required_env
    self.host = optional_env('host', 'localhost')
    self.port = optional_env('port', '3306')
    self.username = required_env('username')
    self.password = parser.get_credential('password')
    self.database = required_env('database')
def __init__(self, parser: Parser):
    """Set up a Hive dialect from *parser*.

    Reads host, port, username, password, database (default 'default') and
    an optional ``configuration`` mapping. When *parser* is falsy only the
    dialect type is registered.
    """
    super().__init__(HIVE)
    if parser:
        self.host = parser.get_str_required('host')
        # The default for an int getter is an int, so `port` has the same type
        # whether or not it was configured (it used to default to '10000').
        self.port = parser.get_int_optional('port', 10000)
        self.username = parser.get_str_required_env('username')
        self.password = parser.get_str_required_env('password')
        self.database = parser.get_str_optional('database', 'default')
        self.configuration = parser.get_dict_optional('configuration')
def __init__(self, parser: Parser):
    """Set up an Athena dialect.

    With a truthy *parser*, reads optional AWS credentials (through
    ``AthenaDialect.get_aws_credentials_optional``), the required staging
    directory and database, and the optional catalog and work group.
    Otherwise only the dialect type is registered.
    """
    super().__init__(ATHENA)
    if not parser:
        return
    required_env = parser.get_str_required_env
    optional_env = parser.get_str_optional_env
    self.aws_credentials = AthenaDialect.get_aws_credentials_optional(parser)
    self.athena_staging_dir = required_env('staging_dir')
    self.database = required_env('database')
    self.catalog = optional_env('catalog')
    self.work_group = optional_env('work_group')
def __init__(self, parser: Parser = None, type: str = SQLSERVER):
    """Set up a SQL Server dialect.

    Args:
        parser: configuration source. Connection settings (host, port, ODBC
            driver, credentials, database, schema) and the boolean flags
            ``trusted_connection``, ``encrypt`` and
            ``trust_server_certificate`` (all default False) are read from it.
            When falsy, only the dialect type is registered.
        type: dialect type passed to the base class (defaults to SQLSERVER).
    """
    super().__init__(type)
    if not parser:
        return
    optional_env = parser.get_str_optional_env
    required_env = parser.get_str_required_env
    optional_bool = parser.get_bool_optional

    self.host = optional_env('host', 'localhost')
    self.port = optional_env('port', '1433')
    self.driver = optional_env('driver', 'ODBC Driver 17 for SQL Server')
    self.username = required_env('username')
    self.password = parser.get_credential('password')
    self.database = required_env('database')
    self.schema = required_env('schema')
    self.trusted_connection = optional_bool('trusted_connection', False)
    self.encrypt = optional_bool('encrypt', False)
    self.trust_server_certificate = optional_bool('trust_server_certificate', False)
def __init__(self, parser: Parser):
    """Set up a BigQuery dialect.

    With a truthy *parser*, the service-account JSON credential and the
    dataset name are read from it. ``self.client`` always starts as None.
    """
    super().__init__(BIGQUERY)
    if parser:
        self.account_info_dict = self.__parse_json_credential('account_info_json', parser)
        self.dataset_name = parser.get_str_required('dataset')
    # Initialised unconditionally so the attribute exists even when the
    # dialect is constructed without a parser.
    self.client = None
def __parse_json_credential(credential_name, parser: Parser):
    """Load a JSON credential as a dict.

    If ``account_info_json_path`` is configured, the file at that path is read
    and parsed. Otherwise the credential named *credential_name* is fetched
    from *parser* and parsed.

    Args:
        credential_name: name of the credential to fetch when no file path is
            configured; also used in the parse-error report.
        parser: configuration source.

    Returns:
        The parsed JSON value. Returns None when the credential is absent,
        when the file read yields None, or when parsing fails (the failure is
        reported through ``parser.error``).
    """
    account_info_path = parser.get_str_optional('account_info_json_path')
    try:
        if account_info_path:
            account_info = parser._read_file_as_string(account_info_path)
            # NOTE(review): if the file read returns None, this silently
            # returns None without reporting an error; confirm this is intended.
            if account_info is not None:
                return json.loads(account_info)
        else:
            cred = parser.get_credential(credential_name)
            # Prevent json load when the Dialect is init from create command
            if cred is not None:
                return json.loads(cred)
            else:
                logger.warning("Dialect initiated from the create command, cred is None.")
    except JSONDecodeError as e:
        # Malformed JSON is reported to the parser rather than raised.
        parser.error(f'Error parsing credential {credential_name}: {e}', credential_name)
def __init__(self, parser: Parser):
    """Set up a BigQuery dialect.

    With a truthy *parser*:
      * reads the dataset name and the OAuth scopes (defaulting to the
        BigQuery, Cloud Platform and Drive scopes);
      * if ``use_context_auth`` is set, ignores ``account_info_json`` and
        requires an explicit ``project_id``;
      * otherwise parses ``account_info_json`` and takes ``project_id`` from
        it (None when no account info is available).

    ``self.client`` always starts as None.
    """
    super().__init__(BIGQUERY)
    if parser:
        self.dataset_name = parser.get_str_required('dataset')
        default_auth_scopes = [
            'https://www.googleapis.com/auth/bigquery',
            'https://www.googleapis.com/auth/cloud-platform',
            'https://www.googleapis.com/auth/drive',
        ]
        self.auth_scopes = parser.get_list_optional('auth_scopes', default_auth_scopes)
        self.__context_auth = parser.get_bool_optional('use_context_auth', None)
        if self.__context_auth:
            self.account_info_dict = None
            self.project_id = parser.get_str_required('project_id')
            logger.info("Using context auth, account_info_json will be ignored.")
        else:
            self.account_info_dict = self.__parse_json_credential('account_info_json', parser)
            # Always define project_id. Credential parsing may yield None
            # (e.g. when the dialect is created without credentials), and a
            # missing attribute would otherwise surface as an AttributeError.
            self.project_id = (
                self.account_info_dict.get('project_id') if self.account_info_dict else None
            )
    # Initialised unconditionally so the attribute exists without a parser.
    self.client = None
def __init__(self, parser: Parser):
    """Set up a Snowflake dialect.

    With a truthy *parser*, reads account, warehouse, credentials, optional
    database, schema and the connection timeout (defaulting to
    DEFAULT_SOCKET_CONNECT_TIMEOUT). Otherwise only the dialect type is
    registered.
    """
    super().__init__(SNOWFLAKE)
    if not parser:
        return
    required_env = parser.get_str_required_env
    self.account = required_env('account')
    self.warehouse = required_env('warehouse')
    self.username = required_env('username')
    self.password = parser.get_credential('password')
    self.database = parser.get_str_optional_env('database')
    self.schema = required_env('schema')
    self.connection_timeout = parser.get_int_optional(
        KEY_CONNECTION_TIMEOUT, DEFAULT_SOCKET_CONNECT_TIMEOUT
    )
def __init__(self, parser: Parser = None, type: str = SQLSERVER):
    """Set up a SQL Server dialect.

    Args:
        parser: configuration source for host, port, ODBC driver, credentials,
            database and schema. When falsy, only the dialect type is
            registered.
        type: dialect type passed to the base class (defaults to SQLSERVER).
    """
    super().__init__(type)
    if not parser:
        return
    optional_env = parser.get_str_optional_env
    required_env = parser.get_str_required_env
    self.host = optional_env('host', 'localhost')
    self.port = optional_env('port', '1433')
    self.driver = optional_env('driver', 'ODBC Driver 17 for SQL Server')
    self.username = required_env('username')
    self.password = parser.get_credential('password')
    self.database = required_env('database')
    self.schema = required_env('schema')
def __init__(self, parser: Parser = None, type: str = POSTGRES):
    """Set up a Postgres dialect.

    Args:
        parser: configuration source for host, port, credentials, database,
            schema and an optional connection timeout. When falsy, only the
            dialect type is registered.
        type: dialect type passed to the base class (defaults to POSTGRES).
    """
    super().__init__(type)
    if not parser:
        return
    optional_env = parser.get_str_optional_env
    required_env = parser.get_str_required_env
    self.host = optional_env('host', 'localhost')
    self.port = optional_env('port', '5432')
    self.username = required_env('username')
    self.password = parser.get_credential('password')
    self.database = required_env('database')
    self.schema = required_env('schema')
    self.connection_timeout = parser.get_int_optional(KEY_CONNECTION_TIMEOUT)
def get_aws_credentials_optional(parser: Parser):
    """Build AwsCredentials from *parser* when an AWS identity is configured.

    Credentials are built if at least one of ``access_key_id``, ``role_arn``
    or ``profile_name`` is set. The secret key and session token are read as
    credentials, and the region defaults to 'eu-west-1'.

    Returns:
        An AwsCredentials instance, or None when none of the identity
        settings is present.
    """
    access_key_id = parser.get_str_optional_env('access_key_id')
    role_arn = parser.get_str_optional_env('role_arn')
    profile_name = parser.get_str_optional_env('profile_name')
    if not (access_key_id or role_arn or profile_name):
        return None
    # NOTE(review): profile_name enables this branch but is not passed to
    # AwsCredentials; confirm whether AwsCredentials should receive it.
    return AwsCredentials(
        access_key_id=access_key_id,
        secret_access_key=parser.get_credential('secret_access_key'),
        # Reuse the value read above rather than querying the parser again.
        role_arn=role_arn,
        session_token=parser.get_credential('session_token'),
        region_name=parser.get_str_optional_env('region', 'eu-west-1'))
def __init__(self, parser: Parser):
    """Set up a Snowflake dialect.

    With a truthy *parser*, reads account, warehouse, credentials, optional
    database and schema. When *parser* is falsy, only the dialect type is
    registered. This matches the other dialect constructors, which guard
    against a missing parser instead of raising AttributeError.
    """
    super().__init__(SNOWFLAKE)
    if parser:
        self.account = parser.get_str_required_env('account')
        self.warehouse = parser.get_str_required_env('warehouse')
        self.username = parser.get_str_required_env('username')
        self.password = parser.get_credential('password')
        self.database = parser.get_str_optional_env('database')
        self.schema = parser.get_str_required_env('schema')
def __init__(self, parser: Parser = None, type: str = POSTGRES):
    """Set up a Postgres dialect.

    Args:
        parser: configuration source for host, port, credentials, database
            and schema. It may now be omitted or falsy, in which case only the
            dialect type is registered (consistent with the other dialect
            constructors).
        type: dialect type passed to the base class (defaults to POSTGRES).
    """
    super().__init__(type)
    if parser:
        self.host = parser.get_str_optional_env('host', 'localhost')
        self.port = parser.get_str_optional_env('port', '5432')
        self.username = parser.get_str_required_env('username')
        self.password = parser.get_credential('password')
        self.database = parser.get_str_required_env('database')
        self.schema = parser.get_str_required_env('schema')
def create(cls, parser: Parser):
    """Instantiate the dialect matching the configured warehouse type.

    The dialect module is imported only for the selected type. Supported
    types are POSTGRES, SNOWFLAKE, REDSHIFT, BIGQUERY and ATHENA; any other
    value yields None.
    """
    warehouse_type = parser.get_str_optional(KEY_WAREHOUSE_TYPE)
    dialect = None
    if warehouse_type == POSTGRES:
        from sodasql.dialects.postgres_dialect import PostgresDialect
        dialect = PostgresDialect(parser)
    elif warehouse_type == SNOWFLAKE:
        from sodasql.dialects.snowflake_dialect import SnowflakeDialect
        dialect = SnowflakeDialect(parser)
    elif warehouse_type == REDSHIFT:
        from sodasql.dialects.redshift_dialect import RedshiftDialect
        dialect = RedshiftDialect(parser)
    elif warehouse_type == BIGQUERY:
        from sodasql.dialects.bigquery_dialect import BigQueryDialect
        dialect = BigQueryDialect(parser)
    elif warehouse_type == ATHENA:
        from sodasql.dialects.athena_dialect import AthenaDialect
        dialect = AthenaDialect(parser)
    return dialect
def create(cls, parser: Parser) -> Optional[Dialect]:
    """Instantiate the dialect class matching the configured warehouse type.

    The dialect class is loaded with ``Dialect._import_class`` only for the
    selected type.

    Returns:
        A dialect instance built from *parser*, or None (after logging an
        error) when the warehouse type is invalid or has no registered
        dialect class.
    """
    # (module path, class name) for each supported warehouse type.
    dialect_locations = {
        ATHENA: ('sodasql.dialects.athena_dialect', 'AthenaDialect'),
        BIGQUERY: ('sodasql.dialects.bigquery_dialect', 'BigQueryDialect'),
        HIVE: ('sodasql.dialects.hive_dialect', 'HiveDialect'),
        POSTGRES: ('sodasql.dialects.postgres_dialect', 'PostgresDialect'),
        MYSQL: ('sodasql.dialects.mysql_dialect', 'MySQLDialect'),
        REDSHIFT: ('sodasql.dialects.redshift_dialect', 'RedshiftDialect'),
        SNOWFLAKE: ('sodasql.dialects.snowflake_dialect', 'SnowflakeDialect'),
        SQLSERVER: ('sodasql.dialects.sqlserver_dialect', 'SQLServerDialect'),
        SPARK: ('sodasql.dialects.spark_dialect', 'SparkDialect'),
        TRINO: ('sodasql.dialects.trino_dialect', 'TrinoDialect'),
    }
    warehouse_type = parser.get_str_optional(KEY_WAREHOUSE_TYPE)
    if warehouse_type not in ALL_WAREHOUSE_TYPES:
        logger.error(
            f'Invalid warehouse type: {warehouse_type}, it must be one of {", ".join(ALL_WAREHOUSE_TYPES)}'
        )
        # Previously execution fell through to `None(parser)` -> TypeError.
        return None
    location = dialect_locations.get(warehouse_type)
    if location is None:
        # A type listed in ALL_WAREHOUSE_TYPES but without a mapping above.
        logger.error(f'No dialect implementation registered for warehouse type: {warehouse_type}')
        return None
    module_name, class_name = location
    warehouse_class = Dialect._import_class(module_name, class_name)
    return warehouse_class(parser)
def __init__(self, parser: Parser):
    """Set up a Redshift dialect.

    The parent initializer (presumably the Postgres dialect, given the
    ``(parser, type)`` call) reads the shared connection settings. This
    initializer then overrides the port (default '5439') and reads optional
    AWS credentials via the parser's own helper.
    """
    super().__init__(parser, REDSHIFT)
    redshift_port = parser.get_str_optional('port', '5439')
    self.port = redshift_port
    self.aws_credentials = parser.get_aws_credentials_optional()
def __init__(self, parser: Parser):
    """Set up a Hive dialect from *parser*.

    Reads connection settings (host, port, scheme, username, optional
    password, database), the authentication method and Kerberos service name,
    TLS-related options (``check_hostname``, ``ssl_cert``), the thrift
    transport and a ``configuration`` mapping (default empty). When *parser*
    is falsy, only the dialect type is registered.
    """
    super().__init__(HIVE)
    if parser:
        self.host = parser.get_str_required('host')
        # The default for an int getter is an int, so `port` has the same type
        # whether or not it was configured (it used to default to '10000').
        self.port = parser.get_int_optional('port', 10000)
        self.scheme = parser.get_str_optional('scheme', None)
        self.username = parser.get_str_required_env('username')
        self.database = parser.get_str_optional('database', 'default')
        self.auth_method = parser.get_str_optional('authentication', None)
        self.configuration = parser.get_dict_optional('configuration', {})
        self.kerberos_service_name = parser.get_str_optional('kerberos_service_name', None)
        self.password = parser.get_str_optional_env('password')
        self.check_hostname = parser.get_bool_optional('check_hostname', None)
        self.ssl_cert = parser.get_str_optional('ssl_cert', None)
        self.thrift_transport = parser.get_str_optional('thrift_transport', None)
def __init__(self, parser: Parser):
    """Set up a Snowflake dialect.

    With a truthy *parser*, reads:
      * account, warehouse, username/password, optional database, schema and role;
      * ``passcode_in_password`` (default False);
      * private-key settings: passphrase, inline key and key file path;
      * client options: prefetch threads (default 4) and session keep-alive
        (default False);
      * the authenticator (default 'snowflake');
      * optional session parameters;
      * the connection timeout (default DEFAULT_SOCKET_CONNECT_TIMEOUT).

    When *parser* is falsy, only the dialect type is registered.
    """
    super().__init__(SNOWFLAKE)
    if not parser:
        return
    required_env = parser.get_str_required_env
    optional_str = parser.get_str_optional
    optional_bool = parser.get_bool_optional
    optional_int = parser.get_int_optional

    self.account = required_env('account')
    self.warehouse = required_env('warehouse')
    self.username = required_env('username')
    self.password = parser.get_credential('password')
    self.database = parser.get_str_optional_env('database')
    self.schema = required_env('schema')
    self.role = optional_str('role')
    self.passcode_in_password = optional_bool('passcode_in_password', False)
    self.private_key_passphrase = optional_str('private_key_passphrase')
    self.private_key = optional_str('private_key')
    self.private_key_path = optional_str('private_key_path')
    self.client_prefetch_threads = optional_int('client_prefetch_threads', 4)
    self.client_session_keep_alive = optional_bool('client_session_keep_alive', False)
    self.authenticator = optional_str('authenticator', 'snowflake')
    self.session_params = parser.get_dict_optional('session_parameters', None)
    self.connection_timeout = optional_int(KEY_CONNECTION_TIMEOUT, DEFAULT_SOCKET_CONNECT_TIMEOUT)
def __init__(self, parser: Parser):
    """Set up a Spark dialect from *parser*.

    Reads the connection method (default 'hive'), host, port, credentials and
    token, database, authentication method, ``configuration`` mapping, ODBC
    driver, organization, cluster, and ``server_side_parameters``. The latter
    are re-keyed as ``SSP_<name>`` with each value wrapped in braces. When
    *parser* is falsy, only the dialect type is registered.
    """
    super().__init__(SPARK)
    if parser:
        self.method = parser.get_str_optional('method', 'hive')
        self.host = parser.get_str_required('host')
        # The default for an int getter is an int, so `port` has the same type
        # whether or not it was configured (it used to default to '10000').
        self.port = parser.get_int_optional('port', 10000)
        self.username = parser.get_credential('username')
        self.password = parser.get_credential('password')
        self.database = parser.get_str_optional('database')
        self.auth_method = parser.get_str_optional('authentication', None)
        self.configuration = parser.get_dict_optional('configuration', {})
        self.driver = parser.get_str_optional('driver', None)
        self.token = parser.get_credential('token')
        self.organization = parser.get_str_optional('organization', None)
        self.cluster = parser.get_str_optional('cluster', None)
        # Iterate key/value pairs via .items(). Iterating the dict itself
        # yields only keys, which cannot be unpacked into `k, v`.
        server_side_parameters = parser.get_dict_optional("server_side_parameters", {}) or {}
        self.server_side_parameters = {
            f"SSP_{k}": f"{{{v}}}" for k, v in server_side_parameters.items()
        }