def __init__(self, s3_staging_dir=None, region_name=None, schema_name='default', poll_interval=1, encryption_option=None, kms_key=None, profile_name=None, role_arn=None, role_session_name='PyAthena-session-{0}'.format( int(time.time())), duration_seconds=3600, converter=None, formatter=None, retry_exceptions=('ThrottlingException', 'TooManyRequestsException'), retry_attempt=5, retry_multiplier=1, retry_max_delay=1800, retry_exponential_base=2, cursor_class=Cursor, **kwargs): if s3_staging_dir: self.s3_staging_dir = s3_staging_dir else: self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR, None) assert self.s3_staging_dir, 'Required argument `s3_staging_dir` not found.' assert schema_name, 'Required argument `schema_name` not found.' self.region_name = region_name self.schema_name = schema_name self.poll_interval = poll_interval self.encryption_option = encryption_option self.kms_key = kms_key if role_arn: creds = self._assume_role(profile_name, region_name, role_arn, role_session_name, duration_seconds, **kwargs) profile_name = None kwargs.update({ 'aws_access_key_id': creds['AccessKeyId'], 'aws_secret_access_key': creds['SecretAccessKey'], 'aws_session_token': creds['SessionToken'], }) self._session = Session(profile_name=profile_name, **kwargs) self._client = self._session.client('athena', region_name=region_name, **kwargs) self._converter = converter if converter else TypeConverter() self._formatter = formatter if formatter else ParameterFormatter() self.retry_exceptions = retry_exceptions self.retry_attempt = retry_attempt self.retry_multiplier = retry_multiplier self.retry_max_delay = retry_max_delay self.retry_exponential_base = retry_exponential_base self.cursor_class = cursor_class self._kwargs = kwargs
def __init__(self, s3_staging_dir=None, region_name=None, schema_name='default', poll_interval=1, encryption_option=None, kms_key=None, profile_name=None, converter=None, formatter=None, retry_exceptions=('ThrottlingException', 'TooManyRequestsException'), retry_attempt=5, retry_multiplier=1, retry_max_delay=1800, retry_exponential_base=2, cursor_class=Cursor, **kwargs): if s3_staging_dir: self.s3_staging_dir = s3_staging_dir else: self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR, None) assert self.s3_staging_dir, 'Required argument `s3_staging_dir` not found.' assert schema_name, 'Required argument `schema_name` not found.' self.region_name = region_name self.schema_name = schema_name self.poll_interval = poll_interval self.encryption_option = encryption_option self.kms_key = kms_key if profile_name: session = Session(profile_name=profile_name, **kwargs) self._client = session.client('athena', region_name=region_name, **kwargs) else: self._client = boto3.client('athena', region_name=region_name, **kwargs) self._converter = converter if converter else TypeConverter() self._formatter = formatter if formatter else ParameterFormatter() self.retry_exceptions = retry_exceptions self.retry_attempt = retry_attempt self.retry_multiplier = retry_multiplier self.retry_max_delay = retry_max_delay self.retry_exponential_base = retry_exponential_base self.cursor_class = cursor_class
def __init__(self, s3_staging_dir=None, region_name=None, schema_name='default', work_group=None, poll_interval=1, encryption_option=None, kms_key=None, profile_name=None, role_arn=None, role_session_name='PyAthena-session-{0}'.format( int(time.time())), duration_seconds=3600, converter=None, formatter=None, retry_config=None, cursor_class=Cursor, **kwargs): self._kwargs = kwargs if s3_staging_dir: self.s3_staging_dir = s3_staging_dir else: self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR, None) assert self.s3_staging_dir, 'Required argument `s3_staging_dir` not found.' assert schema_name, 'Required argument `schema_name` not found.' self.region_name = region_name self.schema_name = schema_name self.work_group = work_group self.poll_interval = poll_interval self.encryption_option = encryption_option self.kms_key = kms_key if role_arn: creds = self._assume_role(profile_name, region_name, role_arn, role_session_name, duration_seconds) profile_name = None self._kwargs.update({ 'aws_access_key_id': creds['AccessKeyId'], 'aws_secret_access_key': creds['SecretAccessKey'], 'aws_session_token': creds['SessionToken'], }) self._session = Session(profile_name=profile_name, **self._session_kwargs) self._client = self._session.client('athena', region_name=region_name, **self._client_kwargs) self._converter = converter if converter else TypeConverter() self._formatter = formatter if formatter else ParameterFormatter() self._retry_config = retry_config if retry_config else RetryConfig() self.cursor_class = cursor_class
class TestParameterFormatter(unittest.TestCase): # TODO More DDL statement test case & Complex parameter format test case FORMATTER = ParameterFormatter() def format(self, operation, parameters=None): return self.FORMATTER.format(operation, parameters) def test_add_partition(self): expected = """ ALTER TABLE test_table ADD PARTITION (dt='2017-01-01', hour=1) """.strip() actual = self.format( """ ALTER TABLE test_table ADD PARTITION (dt=%(dt)s, hour=%(hour)d) """, { 'dt': date(2017, 1, 1), 'hour': 1 }) self.assertEqual(actual, expected) def test_drop_partition(self): expected = """ ALTER TABLE test_table DROP PARTITION (dt='2017-01-01', hour=1) """.strip() actual = self.format( """ ALTER TABLE test_table DROP PARTITION (dt=%(dt)s, hour=%(hour)d) """, { 'dt': date(2017, 1, 1), 'hour': 1 }) self.assertEqual(actual, expected) def test_format_none(self): expected = """ SELECT * FROM test_table WHERE col is null """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col is %(param)s """, {'param': None}) self.assertEqual(actual, expected) def test_format_datetime(self): expected = """ SELECT * FROM test_table WHERE col_timestamp >= timestamp'2017-01-01 12:00:00.000' AND col_timestamp <= timestamp'2017-01-02 06:00:00.000' """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_timestamp >= %(start)s AND col_timestamp <= %(end)s """, { 'start': datetime(2017, 1, 1, 12, 0, 0), 'end': datetime(2017, 1, 2, 6, 0, 0) }) self.assertEqual(actual, expected) def test_format_date(self): expected = """ SELECT * FROM test_table WHERE col_date between date'2017-01-01' and date'2017-01-02' """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_date between %(start)s and %(end)s """, { 'start': date(2017, 1, 1), 'end': date(2017, 1, 2) }) self.assertEqual(actual, expected) def test_format_int(self): expected = """ SELECT * FROM test_table WHERE col_int = 1 """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_int = %(param)s """, {'param': 1}) self.assertEqual(actual, expected) def test_format_float(self): expected = """ SELECT * FROM test_table WHERE col_float >= 0.1 """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_float >= %(param).1f """, {'param': 0.1}) self.assertEqual(actual, expected) def test_format_decimal(self): expected = """ SELECT * FROM test_table WHERE col_decimal <= 0.0000000001 """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_decimal <= %(param).10f """, {'param': Decimal('0.0000000001')}) self.assertEqual(actual, expected) def test_format_bool(self): expected = """ SELECT * FROM test_table WHERE col_boolean = True """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_boolean = %(param)s """, {'param': True}) self.assertEqual(actual, expected) def test_format_str(self): expected = """ SELECT * FROM test_table WHERE col_string = 'amazon athena' """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_string = %(param)s """, {'param': 'amazon athena'}) self.assertEqual(actual, expected) def test_format_unicode(self): expected = """ SELECT * FROM test_table WHERE col_string = '密林 女神' """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_string = %(param)s """, {'param': '密林 女神'}) self.assertEqual(actual, expected) def test_format_none_list(self): expected = """ SELECT * FROM test_table WHERE col IN (null,null) """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col IN %(param)s """, {'param': [None, None]}) self.assertEqual(actual, expected) def test_format_datetime_list(self): expected = """ SELECT * FROM test_table WHERE col_timestamp IN (timestamp'2017-01-01 12:00:00.000',timestamp'2017-01-02 06:00:00.000') """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_timestamp IN %(param)s """, { 'param': [ datetime(2017, 1, 1, 12, 0, 0), datetime(2017, 1, 2, 6, 0, 0) ] }) self.assertEqual(actual, expected) def test_format_date_list(self): expected = """ SELECT * FROM test_table WHERE col_date IN (date'2017-01-01',date'2017-01-02') """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_date IN %(param)s """, {'param': [date(2017, 1, 1), date(2017, 1, 2)]}) self.assertEqual(actual, expected) def test_format_int_list(self): expected = """ SELECT * FROM test_table WHERE col_int IN (1,2) """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_int IN %(param)s """, {'param': [1, 2]}) self.assertEqual(actual, expected) def test_format_float_list(self): # default precision is 6 expected = """ SELECT * FROM test_table WHERE col_float IN (0.100000,0.200000) """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_float IN %(param)s """, {'param': [0.1, 0.2]}) self.assertEqual(actual, expected) def test_format_decimal_list(self): expected = """ SELECT * FROM test_table WHERE col_decimal IN (0.0000000001,99.9999999999) """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_decimal IN %(param)s """, {'param': [Decimal('0.0000000001'), Decimal('99.9999999999')]}) self.assertEqual(actual, expected) def test_format_bool_list(self): expected = """ SELECT * FROM test_table WHERE col_boolean IN (True,False) """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_boolean IN %(param)s """, {'param': [True, False]}) self.assertEqual(actual, expected) def test_format_str_list(self): expected = """ SELECT * FROM test_table WHERE col_string IN ('amazon','athena') """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_string IN %(param)s """, {'param': ['amazon', 'athena']}) self.assertEqual(actual, expected) def test_format_unicode_list(self): expected = """ SELECT * FROM test_table WHERE col_string IN ('密林','女神') """.strip() actual = self.format( """ SELECT * FROM test_table WHERE col_string IN %(param)s """, {'param': ['密林', '女神']}) self.assertEqual(actual, expected) def test_format_bad_parameter(self): self.assertRaises( ProgrammingError, lambda: self.format( """ SELECT * FROM test_table where col_int = $(param)d """.strip(), 1)) self.assertRaises( ProgrammingError, lambda: self.format( """ SELECT * FROM test_table where col_string = $(param)s """.strip(), 'a string')) self.assertRaises( ProgrammingError, lambda: self.format( """ SELECT * FROM test_table where col_string in $(param)s """.strip(), ['a string']))