Ejemplo n.º 1
0
    def __init__(self, s3_staging_dir=None, region_name=None, schema_name='default',
                 work_group=None, poll_interval=1, encryption_option=None, kms_key=None,
                 profile_name=None, role_arn=None,
                 role_session_name='PyAthena-session-{0}'.format(int(time.time())),
                 duration_seconds=3600, converter=None, formatter=None,
                 retry_config=None, cursor_class=Cursor, **kwargs):
        """Store the connection settings and create the Athena client.

        ``s3_staging_dir`` and ``work_group`` fall back to the environment
        variables named by ``_ENV_S3_STAGING_DIR`` and ``_ENV_WORK_GROUP``
        whenever the corresponding argument is falsy.  If ``role_arn`` is
        given, the credentials returned by ``_assume_role`` are merged into
        the extra keyword arguments and no profile name is passed to the
        session.

        Note that the default ``role_session_name`` is computed once, when
        the function is defined, not on every call.

        Raises:
            AssertionError: if ``schema_name`` is falsy, or if neither a
                staging directory nor a work group could be resolved.
        """
        self._kwargs = kwargs
        # An explicit argument wins; otherwise consult the environment.
        self.s3_staging_dir = (
            s3_staging_dir or os.getenv(self._ENV_S3_STAGING_DIR, None))
        self.region_name = region_name
        self.schema_name = schema_name
        self.work_group = work_group or os.getenv(self._ENV_WORK_GROUP, None)
        self.poll_interval = poll_interval
        self.encryption_option = encryption_option
        self.kms_key = kms_key

        assert self.schema_name, 'Required argument `schema_name` not found.'
        assert self.s3_staging_dir or self.work_group,\
            'Required argument `s3_staging_dir` or `work_group` not found.'

        if role_arn:
            assumed = self._assume_role(profile_name, region_name, role_arn,
                                        role_session_name, duration_seconds)
            # The temporary credentials replace the named profile.
            profile_name = None
            self._kwargs.update(
                aws_access_key_id=assumed['AccessKeyId'],
                aws_secret_access_key=assumed['SecretAccessKey'],
                aws_session_token=assumed['SessionToken'],
            )
        # The session must be built only after ``_kwargs`` is final.
        self._session = Session(profile_name=profile_name,
                                **self._session_kwargs)
        self._client = self._session.client('athena', region_name=region_name,
                                            **self._client_kwargs)
        self._converter = converter
        self._formatter = formatter or DefaultParameterFormatter()
        self._retry_config = retry_config or RetryConfig()
        self.cursor_class = cursor_class
Ejemplo n.º 2
0
    def __init__(self,
                 s3_staging_dir: Optional[str] = None,
                 region_name: Optional[str] = None,
                 schema_name: Optional[str] = "default",
                 catalog_name: Optional[str] = "awsdatacatalog",
                 work_group: Optional[str] = None,
                 poll_interval: float = 1,
                 encryption_option: Optional[str] = None,
                 kms_key: Optional[str] = None,
                 profile_name: Optional[str] = None,
                 role_arn: Optional[str] = None,
                 role_session_name: str = "PyAthena-session-{0}".format(
                     int(time.time())),
                 external_id: Optional[str] = None,
                 serial_number: Optional[str] = None,
                 duration_seconds: int = 3600,
                 converter: Optional[Converter] = None,
                 formatter: Optional[Formatter] = None,
                 retry_config: Optional[RetryConfig] = None,
                 cursor_class: Type[BaseCursor] = Cursor,
                 cursor_kwargs: Optional[Dict[str, Any]] = None,
                 kill_on_interrupt: bool = True,
                 session: Optional[Session] = None,
                 **kwargs) -> None:
        """Store the connection settings and create the Athena client.

        ``s3_staging_dir`` and ``work_group`` fall back to the environment
        variables named by ``_ENV_S3_STAGING_DIR`` and ``_ENV_WORK_GROUP``
        whenever the corresponding argument is falsy.

        A caller-supplied ``session`` is used as is.  Otherwise a new
        ``Session`` is created; before that, temporary credentials are
        obtained through ``_assume_role`` (when ``role_arn`` is set) or
        ``_get_session_token`` (when only ``serial_number`` is set) and
        merged into the stored keyword arguments, and the profile name is
        cleared.

        Note that the default ``role_session_name`` is computed once, when
        the function is defined, not on every call.

        Raises:
            AssertionError: if neither a staging directory nor a work group
                could be resolved.
        """
        # Keep the role-related arguments next to the caller's extras.
        stored_kwargs = dict(kwargs)
        stored_kwargs.update(
            role_arn=role_arn,
            role_session_name=role_session_name,
            external_id=external_id,
            serial_number=serial_number,
            duration_seconds=duration_seconds,
        )
        self._kwargs = stored_kwargs
        # An explicit argument wins; otherwise consult the environment.
        self.s3_staging_dir: Optional[str] = (
            s3_staging_dir or os.getenv(self._ENV_S3_STAGING_DIR, None))
        self.region_name = region_name
        self.schema_name = schema_name
        self.catalog_name = catalog_name
        self.work_group: Optional[str] = (
            work_group or os.getenv(self._ENV_WORK_GROUP, None))
        self.poll_interval = poll_interval
        self.encryption_option = encryption_option
        self.kms_key = kms_key
        self.profile_name = profile_name

        assert (
            self.s3_staging_dir or self.work_group
        ), "Required argument `s3_staging_dir` or `work_group` not found."

        if session:
            self._session = session
        else:
            temporary = None
            if role_arn:
                temporary = self._assume_role(
                    profile_name=self.profile_name,
                    region_name=self.region_name,
                    role_arn=role_arn,
                    role_session_name=role_session_name,
                    external_id=external_id,
                    serial_number=serial_number,
                    duration_seconds=duration_seconds,
                )
            elif serial_number:
                temporary = self._get_session_token(
                    profile_name=self.profile_name,
                    region_name=self.region_name,
                    serial_number=serial_number,
                    duration_seconds=duration_seconds,
                )
            if role_arn or serial_number:
                # The temporary credentials replace the named profile.
                self.profile_name = None
                self._kwargs.update(
                    aws_access_key_id=temporary["AccessKeyId"],
                    aws_secret_access_key=temporary["SecretAccessKey"],
                    aws_session_token=temporary["SessionToken"],
                )
            # The session must be built only after ``_kwargs`` is final.
            self._session = Session(region_name=self.region_name,
                                    profile_name=self.profile_name,
                                    **self._session_kwargs)
        self._client = self._session.client("athena",
                                            region_name=self.region_name,
                                            **self._client_kwargs)
        self._converter = converter
        self._formatter = formatter or DefaultParameterFormatter()
        self._retry_config = retry_config or RetryConfig()
        self.cursor_class = cursor_class
        self.cursor_kwargs = cursor_kwargs or dict()
        self.kill_on_interrupt = kill_on_interrupt
Ejemplo n.º 3
0
    def __init__(
        self,
        s3_staging_dir=None,
        region_name=None,
        schema_name="default",
        work_group=None,
        poll_interval=1,
        encryption_option=None,
        kms_key=None,
        profile_name=None,
        role_arn=None,
        role_session_name="PyAthena-session-{0}".format(int(time.time())),
        duration_seconds=3600,
        converter=None,
        formatter=None,
        retry_config=None,
        cursor_class=Cursor,
        kill_on_interrupt=True,
        **kwargs
    ):
        """Store the connection settings and create the Athena client.

        ``s3_staging_dir`` and ``work_group`` fall back to the environment
        variables named by ``_ENV_S3_STAGING_DIR`` and ``_ENV_WORK_GROUP``
        whenever the corresponding argument is falsy.  If ``role_arn`` is
        given, the credentials returned by ``_assume_role`` are merged into
        the extra keyword arguments and the stored profile name is cleared.

        Note that the default ``role_session_name`` is computed once, when
        the function is defined, not on every call.

        Raises:
            AssertionError: if ``schema_name`` is falsy, or if neither a
                staging directory nor a work group could be resolved.
        """
        self._kwargs = kwargs
        # An explicit argument wins; otherwise consult the environment.
        self.s3_staging_dir = s3_staging_dir or os.getenv(
            self._ENV_S3_STAGING_DIR, None
        )
        self.region_name = region_name
        self.schema_name = schema_name
        self.work_group = work_group or os.getenv(self._ENV_WORK_GROUP, None)
        self.poll_interval = poll_interval
        self.encryption_option = encryption_option
        self.kms_key = kms_key
        self.profile_name = profile_name

        assert self.schema_name, "Required argument `schema_name` not found."
        assert (
            self.s3_staging_dir or self.work_group
        ), "Required argument `s3_staging_dir` or `work_group` not found."

        if role_arn:
            assumed = self._assume_role(
                self.profile_name,
                self.region_name,
                role_arn,
                role_session_name,
                duration_seconds,
            )
            # The temporary credentials replace the named profile.
            self.profile_name = None
            self._kwargs.update(
                aws_access_key_id=assumed["AccessKeyId"],
                aws_secret_access_key=assumed["SecretAccessKey"],
                aws_session_token=assumed["SessionToken"],
            )
        # The session must be built only after ``_kwargs`` is final.
        self._session = Session(profile_name=self.profile_name, **self._session_kwargs)
        self._client = self._session.client(
            "athena", region_name=self.region_name, **self._client_kwargs
        )
        self._converter = converter
        self._formatter = formatter or DefaultParameterFormatter()
        self._retry_config = retry_config or RetryConfig()
        self.cursor_class = cursor_class
        self.kill_on_interrupt = kill_on_interrupt
Ejemplo n.º 4
0
 def setUp(self):
     """Bind a new ``DefaultParameterFormatter`` to ``self.formatter``."""
     fresh_formatter = DefaultParameterFormatter()
     self.formatter = fresh_formatter
Ejemplo n.º 5
0
class TestDefaultParameterFormatter(unittest.TestCase):
    """Compare ``DefaultParameterFormatter.format`` output with literal SQL.

    Every case renders an operation containing ``%(name)...`` placeholders
    with a parameter mapping and checks the resulting text verbatim.
    """

    # TODO More DDL statement test case & Complex parameter format test case

    def setUp(self):
        """Create a new formatter before each test."""
        self.formatter = DefaultParameterFormatter()

    def format(self, operation, parameters=None):
        """Render ``operation`` with ``parameters`` via the formatter under test."""
        rendered = self.formatter.format(operation, parameters)
        return rendered

    def test_add_partition(self):
        """A date and an int are rendered into an ADD PARTITION statement."""
        rendered = self.format(
            textwrap.dedent("""
                ALTER TABLE test_table
                ADD PARTITION (dt=%(dt)s, hour=%(hour)d)
                """).strip(),
            {"dt": date(2017, 1, 1), "hour": 1},
        )
        wanted = textwrap.dedent("""
            ALTER TABLE test_table
            ADD PARTITION (dt=DATE '2017-01-01', hour=1)
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_drop_partition(self):
        """A date and an int are rendered into a DROP PARTITION statement."""
        rendered = self.format(
            textwrap.dedent("""
                ALTER TABLE test_table
                DROP PARTITION (dt=%(dt)s, hour=%(hour)d)
                """).strip(),
            {"dt": date(2017, 1, 1), "hour": 1},
        )
        wanted = textwrap.dedent("""
            ALTER TABLE test_table
            DROP PARTITION (dt=DATE '2017-01-01', hour=1)
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_none(self):
        """``None`` is rendered as ``null``."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col is %(param)s
                """).strip(),
            {"param": None},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col is null
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_datetime(self):
        """``datetime`` values are rendered as TIMESTAMP literals."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_timestamp >= %(start)s
                AND col_timestamp <= %(end)s
                """).strip(),
            {
                "start": datetime(2017, 1, 1, 12, 0, 0),
                "end": datetime(2017, 1, 2, 6, 0, 0),
            },
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_timestamp >= TIMESTAMP '2017-01-01 12:00:00.000'
            AND col_timestamp <= TIMESTAMP '2017-01-02 06:00:00.000'
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_date(self):
        """``date`` values are rendered as DATE literals."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_date between %(start)s and %(end)s
                """).strip(),
            {"start": date(2017, 1, 1), "end": date(2017, 1, 2)},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_date between DATE '2017-01-01' and DATE '2017-01-02'
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_int(self):
        """An int bound to ``%(param)s`` is rendered as a bare number."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_int = %(param)s
                """).strip(),
            {"param": 1},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_int = 1
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_float(self):
        """A float bound to ``%(param).1f`` honours the given precision."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_float >= %(param).1f
                """).strip(),
            {"param": 0.1},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_float >= 0.1
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_decimal(self):
        """``Decimal`` values are rendered as DECIMAL literals."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_decimal <= %(param)s
                """).strip(),
            {"param": Decimal("0.0000000001")},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_decimal <= DECIMAL '0.0000000001'
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_bool(self):
        """A bool is rendered as ``True``."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_boolean = %(param)s
                """).strip(),
            {"param": True},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_boolean = True
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_str(self):
        """A string is rendered inside single quotes."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_string = %(param)s
                """).strip(),
            {"param": "amazon athena"},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_string = 'amazon athena'
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_unicode(self):
        """A non-ASCII string is rendered inside single quotes unchanged."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_string = %(param)s
                """).strip(),
            {"param": "密林 女神"},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_string = '密林 女神'
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_none_list(self):
        """A list of ``None`` is rendered as a parenthesised list of nulls."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col IN %(param)s
                """).strip(),
            {"param": [None, None]},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col IN (null, null)
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_datetime_list(self):
        """A list of ``datetime`` is rendered as TIMESTAMP literals."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_timestamp IN
                %(param)s
                """).strip(),
            {
                "param": [
                    datetime(2017, 1, 1, 12, 0, 0),
                    datetime(2017, 1, 2, 6, 0, 0),
                ]
            },
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_timestamp IN
            (TIMESTAMP '2017-01-01 12:00:00.000', TIMESTAMP '2017-01-02 06:00:00.000')
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_date_list(self):
        """A list of ``date`` is rendered as DATE literals."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_date IN %(param)s
                """).strip(),
            {"param": [date(2017, 1, 1), date(2017, 1, 2)]},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_date IN (DATE '2017-01-01', DATE '2017-01-02')
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_int_list(self):
        """A list of ints is rendered as a parenthesised number list."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_int IN %(param)s
                """).strip(),
            {"param": [1, 2]},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_int IN (1, 2)
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_float_list(self):
        """A list of floats is rendered with six decimal places."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_float IN %(param)s
                """).strip(),
            {"param": [0.1, 0.2]},
        )
        # default precision is 6
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_float IN (0.100000, 0.200000)
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_decimal_list(self):
        """A list of ``Decimal`` is rendered as DECIMAL literals."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_decimal IN %(param)s
                """).strip(),
            {"param": [Decimal("0.0000000001"), Decimal("99.9999999999")]},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_decimal IN (DECIMAL '0.0000000001', DECIMAL '99.9999999999')
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_bool_list(self):
        """A list of bools is rendered as ``(True, False)``."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_boolean IN %(param)s
                """).strip(),
            {"param": [True, False]},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_boolean IN (True, False)
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_str_list(self):
        """A list of strings is rendered as quoted, comma-separated items."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_string IN %(param)s
                """).strip(),
            {"param": ["amazon", "athena"]},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_string IN ('amazon', 'athena')
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_unicode_list(self):
        """A list of non-ASCII strings is rendered as quoted items."""
        rendered = self.format(
            textwrap.dedent("""
                SELECT *
                FROM test_table
                WHERE col_string IN %(param)s
                """).strip(),
            {"param": ["密林", "女神"]},
        )
        wanted = textwrap.dedent("""
            SELECT *
            FROM test_table
            WHERE col_string IN ('密林', '女神')
            """).strip()
        self.assertEqual(rendered, wanted)

    def test_format_bad_parameter(self):
        """Non-mapping parameters (int, str, list) raise ``ProgrammingError``."""
        int_operation = """
                SELECT *
                FROM test_table
                where col_int = $(param)d
                """.strip()
        with self.assertRaises(ProgrammingError):
            self.format(int_operation, 1)

        str_operation = """
                SELECT *
                FROM test_table
                where col_string = $(param)s
                """.strip()
        with self.assertRaises(ProgrammingError):
            self.format(str_operation, "a string")

        list_operation = """
                SELECT *
                FROM test_table
                where col_string in $(param)s
                """.strip()
        with self.assertRaises(ProgrammingError):
            self.format(list_operation, ["a string"])
Ejemplo n.º 6
0
class TestDefaultParameterFormatter(unittest.TestCase):
    """Compare ``DefaultParameterFormatter.format`` output with literal SQL.

    Every case renders an operation containing ``%(name)...`` placeholders
    with a parameter mapping and checks the resulting text verbatim.  The
    interior indentation of the SQL literals is part of the compared text.
    """

    # TODO More DDL statement test case & Complex parameter format test case

    FORMATTER = DefaultParameterFormatter()

    def format(self, operation, parameters=None):
        """Render ``operation`` with ``parameters`` via the shared formatter."""
        rendered = self.FORMATTER.format(operation, parameters)
        return rendered

    def test_add_partition(self):
        """A date and an int are rendered into an ADD PARTITION statement."""
        operation = """
        ALTER TABLE test_table
        ADD PARTITION (dt=%(dt)s, hour=%(hour)d)
        """
        parameters = {'dt': date(2017, 1, 1), 'hour': 1}
        wanted = """
        ALTER TABLE test_table
        ADD PARTITION (dt='2017-01-01', hour=1)
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_drop_partition(self):
        """A date and an int are rendered into a DROP PARTITION statement."""
        operation = """
        ALTER TABLE test_table
        DROP PARTITION (dt=%(dt)s, hour=%(hour)d)
        """
        parameters = {'dt': date(2017, 1, 1), 'hour': 1}
        wanted = """
        ALTER TABLE test_table
        DROP PARTITION (dt='2017-01-01', hour=1)
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_none(self):
        """``None`` is rendered as ``null``."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col is %(param)s
        """
        parameters = {'param': None}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col is null
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_datetime(self):
        """``datetime`` values are rendered as ``timestamp'...'`` literals."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_timestamp >= %(start)s
          AND col_timestamp <= %(end)s
        """
        parameters = {
            'start': datetime(2017, 1, 1, 12, 0, 0),
            'end': datetime(2017, 1, 2, 6, 0, 0),
        }
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_timestamp >= timestamp'2017-01-01 12:00:00.000'
          AND col_timestamp <= timestamp'2017-01-02 06:00:00.000'
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_date(self):
        """``date`` values are rendered as ``date'...'`` literals."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_date between %(start)s and %(end)s
        """
        parameters = {'start': date(2017, 1, 1), 'end': date(2017, 1, 2)}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_date between date'2017-01-01' and date'2017-01-02'
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_int(self):
        """An int bound to ``%(param)s`` is rendered as a bare number."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_int = %(param)s
        """
        parameters = {'param': 1}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_int = 1
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_float(self):
        """A float bound to ``%(param).1f`` honours the given precision."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_float >= %(param).1f
        """
        parameters = {'param': 0.1}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_float >= 0.1
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_decimal(self):
        """``Decimal`` values are rendered as DECIMAL literals."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_decimal <= %(param)s
        """
        parameters = {'param': Decimal('0.0000000001')}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_decimal <= DECIMAL '0.0000000001'
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_bool(self):
        """A bool is rendered as ``True``."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_boolean = %(param)s
        """
        parameters = {'param': True}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_boolean = True
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_str(self):
        """A string is rendered inside single quotes."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_string = %(param)s
        """
        parameters = {'param': 'amazon athena'}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_string = 'amazon athena'
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_unicode(self):
        """A non-ASCII string is rendered inside single quotes unchanged."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_string = %(param)s
        """
        parameters = {'param': '密林 女神'}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_string = '密林 女神'
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_none_list(self):
        """A list of ``None`` is rendered as a parenthesised list of nulls."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col IN %(param)s
        """
        parameters = {'param': [None, None]}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col IN (null, null)
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_datetime_list(self):
        """A list of ``datetime`` is rendered as ``timestamp'...'`` literals."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_timestamp IN
        %(param)s
        """
        parameters = {
            'param': [
                datetime(2017, 1, 1, 12, 0, 0),
                datetime(2017, 1, 2, 6, 0, 0),
            ]
        }
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_timestamp IN
        (timestamp'2017-01-01 12:00:00.000', timestamp'2017-01-02 06:00:00.000')
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_date_list(self):
        """A list of ``date`` is rendered as ``date'...'`` literals."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_date IN %(param)s
        """
        parameters = {'param': [date(2017, 1, 1), date(2017, 1, 2)]}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_date IN (date'2017-01-01', date'2017-01-02')
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_int_list(self):
        """A list of ints is rendered as a parenthesised number list."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_int IN %(param)s
        """
        parameters = {'param': [1, 2]}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_int IN (1, 2)
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_float_list(self):
        """A list of floats is rendered with six decimal places."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_float IN %(param)s
        """
        parameters = {'param': [0.1, 0.2]}
        # default precision is 6
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_float IN (0.100000, 0.200000)
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_decimal_list(self):
        """A list of ``Decimal`` is rendered as DECIMAL literals."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_decimal IN %(param)s
        """
        parameters = {
            'param': [Decimal('0.0000000001'), Decimal('99.9999999999')]
        }
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_decimal IN (DECIMAL '0.0000000001', DECIMAL '99.9999999999')
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_bool_list(self):
        """A list of bools is rendered as ``(True, False)``."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_boolean IN %(param)s
        """
        parameters = {'param': [True, False]}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_boolean IN (True, False)
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_str_list(self):
        """A list of strings is rendered as quoted, comma-separated items."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_string IN %(param)s
        """
        parameters = {'param': ['amazon', 'athena']}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_string IN ('amazon', 'athena')
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_unicode_list(self):
        """A list of non-ASCII strings is rendered as quoted items."""
        operation = """
        SELECT *
        FROM test_table
        WHERE col_string IN %(param)s
        """
        parameters = {'param': ['密林', '女神']}
        wanted = """
        SELECT *
        FROM test_table
        WHERE col_string IN ('密林', '女神')
        """.strip()
        self.assertEqual(self.format(operation, parameters), wanted)

    def test_format_bad_parameter(self):
        """Non-mapping parameters (int, str, list) raise ``ProgrammingError``."""
        int_operation = """
        SELECT *
        FROM test_table
        where col_int = $(param)d
        """.strip()
        with self.assertRaises(ProgrammingError):
            self.format(int_operation, 1)

        str_operation = """
        SELECT *
        FROM test_table
        where col_string = $(param)s
        """.strip()
        with self.assertRaises(ProgrammingError):
            self.format(str_operation, 'a string')

        list_operation = """
        SELECT *
        FROM test_table
        where col_string in $(param)s
        """.strip()
        with self.assertRaises(ProgrammingError):
            self.format(list_operation, ['a string'])
Ejemplo n.º 7
0
def formatter():
    """Return a new ``DefaultParameterFormatter`` instance.

    ``pyathena.formatter`` is imported inside the function, so the import
    happens when the function is called rather than at module import time.
    """
    from pyathena.formatter import DefaultParameterFormatter

    instance = DefaultParameterFormatter()
    return instance