def __init__(self, s3_staging_dir=None, region_name=None,
             schema_name='default', work_group=None, poll_interval=1,
             encryption_option=None, kms_key=None, profile_name=None,
             role_arn=None,
             # NOTE: this default is evaluated once, when the ``def`` runs, so
             # every call relying on it shares the same session name.
             role_session_name='PyAthena-session-{0}'.format(int(time.time())),
             duration_seconds=3600, converter=None, formatter=None,
             retry_config=None, cursor_class=Cursor, **kwargs):
    """Store connection settings and create the session and Athena client.

    Args:
        s3_staging_dir: Stored as ``self.s3_staging_dir``. When falsy, the
            environment variable named by ``self._ENV_S3_STAGING_DIR`` is
            read instead (``None`` if it is unset).
        region_name: Stored as ``self.region_name``. Passed to
            ``self._assume_role`` and as ``region_name`` to
            ``Session.client``.
        schema_name: Stored as ``self.schema_name``; must be truthy.
        work_group: Stored as ``self.work_group``. When falsy, the
            environment variable named by ``self._ENV_WORK_GROUP`` is read
            instead (``None`` if it is unset).
        poll_interval: Stored unchanged as ``self.poll_interval``.
        encryption_option: Stored unchanged as ``self.encryption_option``.
        kms_key: Stored unchanged as ``self.kms_key``.
        profile_name: Passed to ``Session``. It is replaced by ``None`` when
            ``role_arn`` is given.
        role_arn: If truthy, ``self._assume_role`` is called. The returned
            ``AccessKeyId`` / ``SecretAccessKey`` / ``SessionToken`` are then
            added to ``self._kwargs``.
        role_session_name: Forwarded to ``self._assume_role``.
        duration_seconds: Forwarded to ``self._assume_role``.
        converter: Stored as ``self._converter``.
        formatter: Stored as ``self._formatter``. Defaults to
            ``DefaultParameterFormatter()`` when falsy.
        retry_config: Stored as ``self._retry_config``. Defaults to
            ``RetryConfig()`` when falsy.
        cursor_class: Stored as ``self.cursor_class``.
        **kwargs: Kept in ``self._kwargs``. They are presumably split into
            session and client arguments by ``_session_kwargs`` /
            ``_client_kwargs`` (defined elsewhere; TODO confirm).

    Raises:
        AssertionError: If ``schema_name`` is empty, or if neither an S3
            staging directory nor a work group is available.
    """
    self._kwargs = kwargs
    if s3_staging_dir:
        self.s3_staging_dir = s3_staging_dir
    else:
        self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR, None)
    self.region_name = region_name
    self.schema_name = schema_name
    if work_group:
        self.work_group = work_group
    else:
        self.work_group = os.getenv(self._ENV_WORK_GROUP, None)
    self.poll_interval = poll_interval
    self.encryption_option = encryption_option
    self.kms_key = kms_key

    # Raise explicitly instead of using ``assert``: asserts are removed under
    # ``python -O``. AssertionError is kept for callers that already catch it.
    if not self.schema_name:
        raise AssertionError('Required argument `schema_name` not found.')
    if not (self.s3_staging_dir or self.work_group):
        raise AssertionError(
            'Required argument `s3_staging_dir` or `work_group` not found.')

    if role_arn:
        creds = self._assume_role(profile_name, region_name, role_arn,
                                  role_session_name, duration_seconds)
        # Drop the profile once temporary credentials are injected
        # (presumably so Session uses them instead of the profile).
        profile_name = None
        self._kwargs.update({
            'aws_access_key_id': creds['AccessKeyId'],
            'aws_secret_access_key': creds['SecretAccessKey'],
            'aws_session_token': creds['SessionToken'],
        })
    self._session = Session(profile_name=profile_name, **self._session_kwargs)
    self._client = self._session.client('athena', region_name=region_name,
                                        **self._client_kwargs)
    self._converter = converter
    self._formatter = formatter if formatter else DefaultParameterFormatter()
    self._retry_config = retry_config if retry_config else RetryConfig()
    self.cursor_class = cursor_class
def __init__(
    self,
    s3_staging_dir: Optional[str] = None,
    region_name: Optional[str] = None,
    schema_name: Optional[str] = "default",
    catalog_name: Optional[str] = "awsdatacatalog",
    work_group: Optional[str] = None,
    poll_interval: float = 1,
    encryption_option: Optional[str] = None,
    kms_key: Optional[str] = None,
    profile_name: Optional[str] = None,
    role_arn: Optional[str] = None,
    # NOTE: this default is evaluated once, when the ``def`` runs, so every
    # call relying on it shares the same session name.
    role_session_name: str = "PyAthena-session-{0}".format(int(time.time())),
    external_id: Optional[str] = None,
    serial_number: Optional[str] = None,
    duration_seconds: int = 3600,
    converter: Optional[Converter] = None,
    formatter: Optional[Formatter] = None,
    retry_config: Optional[RetryConfig] = None,
    cursor_class: Type[BaseCursor] = Cursor,
    cursor_kwargs: Optional[Dict[str, Any]] = None,
    kill_on_interrupt: bool = True,
    session: Optional[Session] = None,
    **kwargs
) -> None:
    """Store connection settings and create (or adopt) the session and client.

    Args:
        s3_staging_dir: Stored as ``self.s3_staging_dir``. When falsy, the
            environment variable named by ``self._ENV_S3_STAGING_DIR`` is
            read instead (``None`` if it is unset).
        region_name: Stored as ``self.region_name``. Passed to the credential
            helpers, to ``Session`` and to ``Session.client``.
        schema_name: Stored as ``self.schema_name``.
        catalog_name: Stored as ``self.catalog_name``.
        work_group: Stored as ``self.work_group``. When falsy, the
            environment variable named by ``self._ENV_WORK_GROUP`` is read
            instead (``None`` if it is unset).
        poll_interval: Stored unchanged as ``self.poll_interval``.
        encryption_option: Stored unchanged as ``self.encryption_option``.
        kms_key: Stored unchanged as ``self.kms_key``.
        profile_name: Stored as ``self.profile_name`` and passed to
            ``Session``. It is reset to ``None`` when temporary credentials
            are fetched via ``role_arn`` or ``serial_number``.
        role_arn: If truthy and no ``session`` is given,
            ``self._assume_role`` is called for temporary credentials.
        role_session_name: Recorded in ``self._kwargs`` and forwarded to
            ``self._assume_role``.
        external_id: Recorded in ``self._kwargs`` and forwarded to
            ``self._assume_role``.
        serial_number: Recorded in ``self._kwargs`` and forwarded to
            ``self._assume_role``. If truthy, no ``role_arn`` is given and no
            ``session`` is given, ``self._get_session_token`` is called
            instead.
        duration_seconds: Recorded in ``self._kwargs`` and forwarded to
            whichever credential helper is called.
        converter: Stored as ``self._converter``.
        formatter: Stored as ``self._formatter``. Defaults to
            ``DefaultParameterFormatter()`` when falsy.
        retry_config: Stored as ``self._retry_config``. Defaults to
            ``RetryConfig()`` when falsy.
        cursor_class: Stored as ``self.cursor_class``.
        cursor_kwargs: Stored as ``self.cursor_kwargs``. Defaults to an empty
            dict when falsy.
        kill_on_interrupt: Stored as ``self.kill_on_interrupt``.
        session: If truthy, used as ``self._session`` directly. No credential
            helper is called and no new ``Session`` is built.
        **kwargs: Merged into ``self._kwargs``. They are presumably filtered
            by ``_session_kwargs`` / ``_client_kwargs`` (defined elsewhere;
            TODO confirm).

    Raises:
        AssertionError: If neither an S3 staging directory nor a work group
            is available.
    """
    self._kwargs = {
        **kwargs,
        "role_arn": role_arn,
        "role_session_name": role_session_name,
        "external_id": external_id,
        "serial_number": serial_number,
        "duration_seconds": duration_seconds,
    }
    if s3_staging_dir:
        self.s3_staging_dir: Optional[str] = s3_staging_dir
    else:
        self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR, None)
    self.region_name = region_name
    self.schema_name = schema_name
    self.catalog_name = catalog_name
    if work_group:
        self.work_group: Optional[str] = work_group
    else:
        self.work_group = os.getenv(self._ENV_WORK_GROUP, None)
    self.poll_interval = poll_interval
    self.encryption_option = encryption_option
    self.kms_key = kms_key
    self.profile_name = profile_name

    # Raise explicitly instead of using ``assert``: asserts are removed under
    # ``python -O``. AssertionError is kept for callers that already catch it.
    if not (self.s3_staging_dir or self.work_group):
        raise AssertionError(
            "Required argument `s3_staging_dir` or `work_group` not found."
        )

    if session:
        self._session = session
    else:
        if role_arn or serial_number:
            # role_arn takes precedence over serial_number.
            if role_arn:
                creds = self._assume_role(
                    profile_name=self.profile_name,
                    region_name=self.region_name,
                    role_arn=role_arn,
                    role_session_name=role_session_name,
                    external_id=external_id,
                    serial_number=serial_number,
                    duration_seconds=duration_seconds,
                )
            else:
                creds = self._get_session_token(
                    profile_name=self.profile_name,
                    region_name=self.region_name,
                    serial_number=serial_number,
                    duration_seconds=duration_seconds,
                )
            # Both branches inject the temporary credentials the same way and
            # drop the profile (presumably so Session uses these credentials).
            self.profile_name = None
            self._kwargs.update(
                {
                    "aws_access_key_id": creds["AccessKeyId"],
                    "aws_secret_access_key": creds["SecretAccessKey"],
                    "aws_session_token": creds["SessionToken"],
                }
            )
        self._session = Session(
            region_name=self.region_name,
            profile_name=self.profile_name,
            **self._session_kwargs
        )
    self._client = self._session.client(
        "athena", region_name=self.region_name, **self._client_kwargs
    )
    self._converter = converter
    self._formatter = formatter if formatter else DefaultParameterFormatter()
    self._retry_config = retry_config if retry_config else RetryConfig()
    self.cursor_class = cursor_class
    self.cursor_kwargs = cursor_kwargs if cursor_kwargs else dict()
    self.kill_on_interrupt = kill_on_interrupt
def __init__(
    self,
    s3_staging_dir=None,
    region_name=None,
    schema_name="default",
    work_group=None,
    poll_interval=1,
    encryption_option=None,
    kms_key=None,
    profile_name=None,
    role_arn=None,
    # NOTE: this default is evaluated once, when the ``def`` runs, so every
    # call relying on it shares the same session name.
    role_session_name="PyAthena-session-{0}".format(int(time.time())),
    duration_seconds=3600,
    converter=None,
    formatter=None,
    retry_config=None,
    cursor_class=Cursor,
    kill_on_interrupt=True,
    **kwargs
):
    """Store connection settings and create the session and Athena client.

    Args:
        s3_staging_dir: Stored as ``self.s3_staging_dir``. When falsy, the
            environment variable named by ``self._ENV_S3_STAGING_DIR`` is
            read instead (``None`` if it is unset).
        region_name: Stored as ``self.region_name``. Passed to
            ``self._assume_role`` and to ``Session.client``.
        schema_name: Stored as ``self.schema_name``; must be truthy.
        work_group: Stored as ``self.work_group``. When falsy, the
            environment variable named by ``self._ENV_WORK_GROUP`` is read
            instead (``None`` if it is unset).
        poll_interval: Stored unchanged as ``self.poll_interval``.
        encryption_option: Stored unchanged as ``self.encryption_option``.
        kms_key: Stored unchanged as ``self.kms_key``.
        profile_name: Stored as ``self.profile_name`` and passed to
            ``Session``. It is reset to ``None`` when ``role_arn`` is given.
        role_arn: If truthy, ``self._assume_role`` is called. The returned
            ``AccessKeyId`` / ``SecretAccessKey`` / ``SessionToken`` are then
            added to ``self._kwargs``.
        role_session_name: Forwarded to ``self._assume_role``.
        duration_seconds: Forwarded to ``self._assume_role``.
        converter: Stored as ``self._converter``.
        formatter: Stored as ``self._formatter``. Defaults to
            ``DefaultParameterFormatter()`` when falsy.
        retry_config: Stored as ``self._retry_config``. Defaults to
            ``RetryConfig()`` when falsy.
        cursor_class: Stored as ``self.cursor_class``.
        kill_on_interrupt: Stored as ``self.kill_on_interrupt``.
        **kwargs: Kept in ``self._kwargs``. They are presumably split into
            session and client arguments by ``_session_kwargs`` /
            ``_client_kwargs`` (defined elsewhere; TODO confirm).

    Raises:
        AssertionError: If ``schema_name`` is empty, or if neither an S3
            staging directory nor a work group is available.
    """
    self._kwargs = kwargs
    if s3_staging_dir:
        self.s3_staging_dir = s3_staging_dir
    else:
        self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR, None)
    self.region_name = region_name
    self.schema_name = schema_name
    if work_group:
        self.work_group = work_group
    else:
        self.work_group = os.getenv(self._ENV_WORK_GROUP, None)
    self.poll_interval = poll_interval
    self.encryption_option = encryption_option
    self.kms_key = kms_key
    self.profile_name = profile_name

    # Raise explicitly instead of using ``assert``: asserts are removed under
    # ``python -O``. AssertionError is kept for callers that already catch it.
    if not self.schema_name:
        raise AssertionError("Required argument `schema_name` not found.")
    if not (self.s3_staging_dir or self.work_group):
        raise AssertionError(
            "Required argument `s3_staging_dir` or `work_group` not found."
        )

    if role_arn:
        creds = self._assume_role(
            self.profile_name,
            self.region_name,
            role_arn,
            role_session_name,
            duration_seconds,
        )
        # Drop the profile once temporary credentials are injected
        # (presumably so Session uses them instead of the profile).
        self.profile_name = None
        self._kwargs.update(
            {
                "aws_access_key_id": creds["AccessKeyId"],
                "aws_secret_access_key": creds["SecretAccessKey"],
                "aws_session_token": creds["SessionToken"],
            }
        )
    self._session = Session(profile_name=self.profile_name, **self._session_kwargs)
    self._client = self._session.client(
        "athena", region_name=self.region_name, **self._client_kwargs
    )
    self._converter = converter
    self._formatter = formatter if formatter else DefaultParameterFormatter()
    self._retry_config = retry_config if retry_config else RetryConfig()
    self.cursor_class = cursor_class
    self.kill_on_interrupt = kill_on_interrupt
def setUp(self):
    """Attach a freshly constructed DefaultParameterFormatter to the test."""
    fresh_formatter = DefaultParameterFormatter()
    self.formatter = fresh_formatter
class TestDefaultParameterFormatter(unittest.TestCase):
    """Tests for DefaultParameterFormatter.format on dedented SQL templates."""

    # TODO More DDL statement test case & Complex parameter format test case

    def setUp(self):
        self.formatter = DefaultParameterFormatter()

    def format(self, operation, parameters=None):
        return self.formatter.format(operation, parameters)

    def _assert_formatted(self, template, parameters, expected):
        """Dedent and strip both texts, format the template, and compare."""
        operation = textwrap.dedent(template).strip()
        wanted = textwrap.dedent(expected).strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_add_partition(self):
        self._assert_formatted(
            """
            ALTER TABLE test_table
            ADD PARTITION (dt=%(dt)s, hour=%(hour)d)
            """,
            {"dt": date(2017, 1, 1), "hour": 1},
            """
            ALTER TABLE test_table
            ADD PARTITION (dt=DATE '2017-01-01', hour=1)
            """,
        )

    def test_drop_partition(self):
        self._assert_formatted(
            """
            ALTER TABLE test_table
            DROP PARTITION (dt=%(dt)s, hour=%(hour)d)
            """,
            {"dt": date(2017, 1, 1), "hour": 1},
            """
            ALTER TABLE test_table
            DROP PARTITION (dt=DATE '2017-01-01', hour=1)
            """,
        )

    def test_format_none(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col is %(param)s
            """,
            {"param": None},
            """
            SELECT *
            FROM test_table
            WHERE col is null
            """,
        )

    def test_format_datetime(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_timestamp >= %(start)s
            AND col_timestamp <= %(end)s
            """,
            {
                "start": datetime(2017, 1, 1, 12, 0, 0),
                "end": datetime(2017, 1, 2, 6, 0, 0),
            },
            """
            SELECT *
            FROM test_table
            WHERE col_timestamp >= TIMESTAMP '2017-01-01 12:00:00.000'
            AND col_timestamp <= TIMESTAMP '2017-01-02 06:00:00.000'
            """,
        )

    def test_format_date(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_date between %(start)s and %(end)s
            """,
            {"start": date(2017, 1, 1), "end": date(2017, 1, 2)},
            """
            SELECT *
            FROM test_table
            WHERE col_date between DATE '2017-01-01' and DATE '2017-01-02'
            """,
        )

    def test_format_int(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_int = %(param)s
            """,
            {"param": 1},
            """
            SELECT *
            FROM test_table
            WHERE col_int = 1
            """,
        )

    def test_format_float(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_float >= %(param).1f
            """,
            {"param": 0.1},
            """
            SELECT *
            FROM test_table
            WHERE col_float >= 0.1
            """,
        )

    def test_format_decimal(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_decimal <= %(param)s
            """,
            {"param": Decimal("0.0000000001")},
            """
            SELECT *
            FROM test_table
            WHERE col_decimal <= DECIMAL '0.0000000001'
            """,
        )

    def test_format_bool(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_boolean = %(param)s
            """,
            {"param": True},
            """
            SELECT *
            FROM test_table
            WHERE col_boolean = True
            """,
        )

    def test_format_str(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_string = %(param)s
            """,
            {"param": "amazon athena"},
            """
            SELECT *
            FROM test_table
            WHERE col_string = 'amazon athena'
            """,
        )

    def test_format_unicode(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_string = %(param)s
            """,
            {"param": "密林 女神"},
            """
            SELECT *
            FROM test_table
            WHERE col_string = '密林 女神'
            """,
        )

    def test_format_none_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col IN %(param)s
            """,
            {"param": [None, None]},
            """
            SELECT *
            FROM test_table
            WHERE col IN (null, null)
            """,
        )

    def test_format_datetime_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_timestamp IN %(param)s
            """,
            {
                "param": [
                    datetime(2017, 1, 1, 12, 0, 0),
                    datetime(2017, 1, 2, 6, 0, 0),
                ]
            },
            """
            SELECT *
            FROM test_table
            WHERE col_timestamp IN
            (TIMESTAMP '2017-01-01 12:00:00.000', TIMESTAMP '2017-01-02 06:00:00.000')
            """.replace("IN\n", "IN ").replace("IN \n", "IN "),
        )

    def test_format_date_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_date IN %(param)s
            """,
            {"param": [date(2017, 1, 1), date(2017, 1, 2)]},
            """
            SELECT *
            FROM test_table
            WHERE col_date IN (DATE '2017-01-01', DATE '2017-01-02')
            """,
        )

    def test_format_int_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_int IN %(param)s
            """,
            {"param": [1, 2]},
            """
            SELECT *
            FROM test_table
            WHERE col_int IN (1, 2)
            """,
        )

    def test_format_float_list(self):
        # default precision is 6
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_float IN %(param)s
            """,
            {"param": [0.1, 0.2]},
            """
            SELECT *
            FROM test_table
            WHERE col_float IN (0.100000, 0.200000)
            """,
        )

    def test_format_decimal_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_decimal IN %(param)s
            """,
            {"param": [Decimal("0.0000000001"), Decimal("99.9999999999")]},
            """
            SELECT *
            FROM test_table
            WHERE col_decimal IN (DECIMAL '0.0000000001', DECIMAL '99.9999999999')
            """,
        )

    def test_format_bool_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_boolean IN %(param)s
            """,
            {"param": [True, False]},
            """
            SELECT *
            FROM test_table
            WHERE col_boolean IN (True, False)
            """,
        )

    def test_format_str_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_string IN %(param)s
            """,
            {"param": ["amazon", "athena"]},
            """
            SELECT *
            FROM test_table
            WHERE col_string IN ('amazon', 'athena')
            """,
        )

    def test_format_unicode_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_string IN %(param)s
            """,
            {"param": ["密林", "女神"]},
            """
            SELECT *
            FROM test_table
            WHERE col_string IN ('密林', '女神')
            """,
        )

    def test_format_bad_parameter(self):
        # Each case passes a non-mapping parameter (int, str, list) and
        # expects ProgrammingError.
        cases = (
            ("SELECT * FROM test_table where col_int = $(param)d", 1),
            ("SELECT * FROM test_table where col_string = $(param)s", "a string"),
            ("SELECT * FROM test_table where col_string in $(param)s", ["a string"]),
        )
        for operation, parameters in cases:
            with self.assertRaises(ProgrammingError):
                self.format(operation, parameters)
class TestDefaultParameterFormatter(unittest.TestCase):
    """Tests for DefaultParameterFormatter.format using one shared formatter."""

    # TODO More DDL statement test case & Complex parameter format test case

    FORMATTER = DefaultParameterFormatter()

    def format(self, operation, parameters=None):
        return self.FORMATTER.format(operation, parameters)

    def _assert_formatted(self, operation, parameters, expected):
        """Format *operation* as written and compare with stripped *expected*.

        Only *expected* is stripped here; *operation* is handed over with its
        surrounding whitespace intact. Both literals must therefore be written
        at the same indentation so their inner lines line up.
        """
        self.assertEqual(self.format(operation, parameters), expected.strip())

    def test_add_partition(self):
        self._assert_formatted(
            """
            ALTER TABLE test_table
            ADD PARTITION (dt=%(dt)s, hour=%(hour)d)
            """,
            {'dt': date(2017, 1, 1), 'hour': 1},
            """
            ALTER TABLE test_table
            ADD PARTITION (dt='2017-01-01', hour=1)
            """,
        )

    def test_drop_partition(self):
        self._assert_formatted(
            """
            ALTER TABLE test_table
            DROP PARTITION (dt=%(dt)s, hour=%(hour)d)
            """,
            {'dt': date(2017, 1, 1), 'hour': 1},
            """
            ALTER TABLE test_table
            DROP PARTITION (dt='2017-01-01', hour=1)
            """,
        )

    def test_format_none(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col is %(param)s
            """,
            {'param': None},
            """
            SELECT *
            FROM test_table
            WHERE col is null
            """,
        )

    def test_format_datetime(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_timestamp >= %(start)s
            AND col_timestamp <= %(end)s
            """,
            {
                'start': datetime(2017, 1, 1, 12, 0, 0),
                'end': datetime(2017, 1, 2, 6, 0, 0),
            },
            """
            SELECT *
            FROM test_table
            WHERE col_timestamp >= timestamp'2017-01-01 12:00:00.000'
            AND col_timestamp <= timestamp'2017-01-02 06:00:00.000'
            """,
        )

    def test_format_date(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_date between %(start)s and %(end)s
            """,
            {'start': date(2017, 1, 1), 'end': date(2017, 1, 2)},
            """
            SELECT *
            FROM test_table
            WHERE col_date between date'2017-01-01' and date'2017-01-02'
            """,
        )

    def test_format_int(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_int = %(param)s
            """,
            {'param': 1},
            """
            SELECT *
            FROM test_table
            WHERE col_int = 1
            """,
        )

    def test_format_float(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_float >= %(param).1f
            """,
            {'param': 0.1},
            """
            SELECT *
            FROM test_table
            WHERE col_float >= 0.1
            """,
        )

    def test_format_decimal(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_decimal <= %(param)s
            """,
            {'param': Decimal('0.0000000001')},
            """
            SELECT *
            FROM test_table
            WHERE col_decimal <= DECIMAL '0.0000000001'
            """,
        )

    def test_format_bool(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_boolean = %(param)s
            """,
            {'param': True},
            """
            SELECT *
            FROM test_table
            WHERE col_boolean = True
            """,
        )

    def test_format_str(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_string = %(param)s
            """,
            {'param': 'amazon athena'},
            """
            SELECT *
            FROM test_table
            WHERE col_string = 'amazon athena'
            """,
        )

    def test_format_unicode(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_string = %(param)s
            """,
            {'param': '密林 女神'},
            """
            SELECT *
            FROM test_table
            WHERE col_string = '密林 女神'
            """,
        )

    def test_format_none_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col IN %(param)s
            """,
            {'param': [None, None]},
            """
            SELECT *
            FROM test_table
            WHERE col IN (null, null)
            """,
        )

    def test_format_datetime_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_timestamp IN %(param)s
            """,
            {
                'param': [
                    datetime(2017, 1, 1, 12, 0, 0),
                    datetime(2017, 1, 2, 6, 0, 0),
                ]
            },
            """
            SELECT *
            FROM test_table
            WHERE col_timestamp IN (timestamp'2017-01-01 12:00:00.000', timestamp'2017-01-02 06:00:00.000')
            """,
        )

    def test_format_date_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_date IN %(param)s
            """,
            {'param': [date(2017, 1, 1), date(2017, 1, 2)]},
            """
            SELECT *
            FROM test_table
            WHERE col_date IN (date'2017-01-01', date'2017-01-02')
            """,
        )

    def test_format_int_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_int IN %(param)s
            """,
            {'param': [1, 2]},
            """
            SELECT *
            FROM test_table
            WHERE col_int IN (1, 2)
            """,
        )

    def test_format_float_list(self):
        # default precision is 6
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_float IN %(param)s
            """,
            {'param': [0.1, 0.2]},
            """
            SELECT *
            FROM test_table
            WHERE col_float IN (0.100000, 0.200000)
            """,
        )

    def test_format_decimal_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_decimal IN %(param)s
            """,
            {'param': [Decimal('0.0000000001'), Decimal('99.9999999999')]},
            """
            SELECT *
            FROM test_table
            WHERE col_decimal IN (DECIMAL '0.0000000001', DECIMAL '99.9999999999')
            """,
        )

    def test_format_bool_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_boolean IN %(param)s
            """,
            {'param': [True, False]},
            """
            SELECT *
            FROM test_table
            WHERE col_boolean IN (True, False)
            """,
        )

    def test_format_str_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_string IN %(param)s
            """,
            {'param': ['amazon', 'athena']},
            """
            SELECT *
            FROM test_table
            WHERE col_string IN ('amazon', 'athena')
            """,
        )

    def test_format_unicode_list(self):
        self._assert_formatted(
            """
            SELECT *
            FROM test_table
            WHERE col_string IN %(param)s
            """,
            {'param': ['密林', '女神']},
            """
            SELECT *
            FROM test_table
            WHERE col_string IN ('密林', '女神')
            """,
        )

    def test_format_bad_parameter(self):
        # Each case passes a non-mapping parameter (int, str, list) and
        # expects ProgrammingError.
        cases = (
            ('SELECT * FROM test_table where col_int = $(param)d', 1),
            ('SELECT * FROM test_table where col_string = $(param)s', 'a string'),
            ('SELECT * FROM test_table where col_string in $(param)s', ['a string']),
        )
        for operation, parameters in cases:
            with self.assertRaises(ProgrammingError):
                self.format(operation, parameters)
def formatter():
    """Build and return a new DefaultParameterFormatter.

    The pyathena module is imported inside the function body, so the import
    runs when this function is called rather than when the file is loaded.
    """
    import pyathena.formatter as formatter_module

    return formatter_module.DefaultParameterFormatter()