Ejemplo n.º 1
0
    def __init__(self,
                 s3_staging_dir=None,
                 region_name=None,
                 schema_name='default',
                 work_group=None,
                 poll_interval=1,
                 encryption_option=None,
                 kms_key=None,
                 profile_name=None,
                 role_arn=None,
                 role_session_name='PyAthena-session-{0}'.format(
                     int(time.time())),
                 duration_seconds=3600,
                 converter=None,
                 formatter=None,
                 retry_config=None,
                 cursor_class=Cursor,
                 **kwargs):
        """Store connection settings and create the boto3 session and Athena client.

        Args:
            s3_staging_dir: S3 location for query output. When falsy, the value
                of the environment variable named by ``self._ENV_S3_STAGING_DIR``
                is used instead.
            region_name: Region passed to the Athena client.
            schema_name: Default schema; must be non-empty.
            work_group: Athena work group, stored on the instance.
            poll_interval: Stored on the instance.
            encryption_option: Stored on the instance.
            kms_key: Stored on the instance.
            profile_name: Profile for the boto3 ``Session``. It is replaced by
                ``None`` when ``role_arn`` is given.
            role_arn: When given, ``self._assume_role`` is called and the
                returned temporary credentials are added to ``self._kwargs``.
            role_session_name: Session name passed to ``_assume_role``. NOTE:
                the default is evaluated once, when this function is defined,
                so every connection that omits it shares the same name.
            duration_seconds: Passed to ``_assume_role``.
            converter: Defaults to ``TypeConverter()`` when falsy.
            formatter: Defaults to ``ParameterFormatter()`` when falsy.
            retry_config: Defaults to ``RetryConfig()`` when falsy.
            cursor_class: Stored on the instance as ``cursor_class``.
            **kwargs: Extra options kept in ``self._kwargs``. They are
                presumably split into Session/client options by the
                ``_session_kwargs`` / ``_client_kwargs`` properties (defined
                outside this view).

        Raises:
            AssertionError: If no staging dir is found, or ``schema_name`` is
                empty.
        """
        self._kwargs = kwargs
        if s3_staging_dir:
            self.s3_staging_dir = s3_staging_dir
        else:
            self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR, None)
        # Explicit raises instead of `assert`: asserts are removed under
        # `python -O`, which would silently skip this validation. The
        # AssertionError type is kept for callers that already catch it.
        if not self.s3_staging_dir:
            raise AssertionError('Required argument `s3_staging_dir` not found.')
        if not schema_name:
            raise AssertionError('Required argument `schema_name` not found.')
        self.region_name = region_name
        self.schema_name = schema_name
        self.work_group = work_group
        self.poll_interval = poll_interval
        self.encryption_option = encryption_option
        self.kms_key = kms_key

        if role_arn:
            # Exchange the role for temporary credentials. The profile is then
            # dropped so the Session relies on the injected keys instead
            # (presumably forwarded through `_session_kwargs`).
            creds = self._assume_role(profile_name, region_name, role_arn,
                                      role_session_name, duration_seconds)
            profile_name = None
            self._kwargs.update({
                'aws_access_key_id': creds['AccessKeyId'],
                'aws_secret_access_key': creds['SecretAccessKey'],
                'aws_session_token': creds['SessionToken'],
            })
        self._session = Session(profile_name=profile_name,
                                **self._session_kwargs)
        self._client = self._session.client('athena',
                                            region_name=region_name,
                                            **self._client_kwargs)
        self._converter = converter if converter else TypeConverter()
        self._formatter = formatter if formatter else ParameterFormatter()
        self._retry_config = retry_config if retry_config else RetryConfig()
        self.cursor_class = cursor_class
Ejemplo n.º 2
0
    def open(cls, connection):
        """Open ``connection`` via PyAthena ``connect`` unless it is already open.

        The PyAthena connection is built from ``connection.credentials``
        (staging dir, region, ``database`` as schema) with an ``AsyncCursor``
        and a ``RetryConfig`` taken from the credentials' retry settings. It is
        wrapped in ``ConnectionWrapper`` together with ``credentials.threads``.

        Returns:
            The same ``connection`` object, with ``state`` set to ``'open'``
            and ``handle`` populated (or untouched if it was already open).
        """
        if connection.state == 'open':
            logger.debug('Connection is already open, skipping open.')
            return connection

        creds = connection.credentials
        retry = RetryConfig(
            attempt=creds.max_retry_number,
            max_delay=creds.max_retry_delay,
        )
        athena_conn = connect(
            s3_staging_dir=creds.s3_staging_dir,
            region_name=creds.region_name,
            schema_name=creds.database,
            cursor_class=AsyncCursor,
            retry_config=retry,
        )

        connection.state = 'open'
        connection.handle = ConnectionWrapper(athena_conn, creds.threads)
        return connection
Ejemplo n.º 3
0
    def open(cls, connection):
        """Open ``connection`` via PyAthena ``connect`` unless it is already open.

        Connection settings come from ``connection.credentials``: staging
        dir, region, role ARN, work group and schema. An ``AsyncCursor`` and a
        ``RetryConfig`` built from the credentials' retry settings are used.
        The result is wrapped in ``ConnectionWrapper`` with
        ``credentials.threads``.

        Returns:
            The same ``connection`` object.
        """
        if connection.state == "open":
            logger.debug("Connection is already open, skipping open.")
            return connection

        creds = connection.credentials
        connect_kwargs = dict(
            s3_staging_dir=creds.s3_staging_dir,
            region_name=creds.region_name,
            role_arn=creds.role_arn,
            work_group=creds.work_group,
            schema_name=creds.schema,
            cursor_class=AsyncCursor,
            retry_config=RetryConfig(
                attempt=creds.max_retry_number,
                max_delay=creds.max_retry_delay,
            ),
        )
        athena_conn = connect(**connect_kwargs)

        connection.state = "open"
        connection.handle = ConnectionWrapper(athena_conn, creds.threads)
        return connection
Ejemplo n.º 4
0
    def __init__(self,
                 s3_staging_dir: Optional[str] = None,
                 region_name: Optional[str] = None,
                 schema_name: Optional[str] = "default",
                 catalog_name: Optional[str] = "awsdatacatalog",
                 work_group: Optional[str] = None,
                 poll_interval: float = 1,
                 encryption_option: Optional[str] = None,
                 kms_key: Optional[str] = None,
                 profile_name: Optional[str] = None,
                 role_arn: Optional[str] = None,
                 role_session_name: str = "PyAthena-session-{0}".format(
                     int(time.time())),
                 external_id: Optional[str] = None,
                 serial_number: Optional[str] = None,
                 duration_seconds: int = 3600,
                 converter: Optional[Converter] = None,
                 formatter: Optional[Formatter] = None,
                 retry_config: Optional[RetryConfig] = None,
                 cursor_class: Type[BaseCursor] = Cursor,
                 cursor_kwargs: Optional[Dict[str, Any]] = None,
                 kill_on_interrupt: bool = True,
                 session: Optional[Session] = None,
                 **kwargs) -> None:
        """Store connection settings and create (or reuse) the boto3 session and Athena client.

        Args:
            s3_staging_dir: S3 location for query output. Falls back to the
                environment variable named by ``self._ENV_S3_STAGING_DIR``.
            region_name: Region for the Session and the Athena client.
            schema_name: Default schema, stored on the instance.
            catalog_name: Default catalog, stored on the instance.
            work_group: Athena work group. Falls back to the environment
                variable named by ``self._ENV_WORK_GROUP``.
            poll_interval: Stored on the instance.
            encryption_option: Stored on the instance.
            kms_key: Stored on the instance.
            profile_name: Profile for the boto3 ``Session``. It is cleared
                when temporary credentials are obtained.
            role_arn: When given (and ``session`` is not), ``_assume_role`` is
                called and its temporary credentials are added to
                ``self._kwargs``.
            role_session_name: Passed to ``_assume_role``. NOTE: the default
                is evaluated once, when this function is defined, so every
                connection that omits it shares the same name.
            external_id: Passed to ``_assume_role``.
            serial_number: Used by ``_assume_role`` or, without ``role_arn``,
                by ``_get_session_token``.
            duration_seconds: Passed to ``_assume_role`` or
                ``_get_session_token``.
            duration_seconds: Passed to whichever credential call is made.
            converter: Stored as-is; may be ``None``.
            formatter: Defaults to ``DefaultParameterFormatter()`` when falsy.
            retry_config: Defaults to ``RetryConfig()`` when falsy.
            cursor_class: Stored on the instance.
            cursor_kwargs: Defaults to an empty dict when falsy.
            kill_on_interrupt: Stored on the instance.
            session: When given, used as-is and no role/MFA credential call
                is made.
            **kwargs: Extra options kept in ``self._kwargs``. They are
                presumably split into Session/client options by
                ``_session_kwargs`` / ``_client_kwargs`` (defined outside this
                view).

        Raises:
            AssertionError: If neither a staging dir nor a work group is found.
        """
        # Role/MFA settings are recorded together with the caller's extras.
        self._kwargs = {
            **kwargs,
            "role_arn": role_arn,
            "role_session_name": role_session_name,
            "external_id": external_id,
            "serial_number": serial_number,
            "duration_seconds": duration_seconds,
        }
        self.s3_staging_dir: Optional[str] = (
            s3_staging_dir or os.getenv(self._ENV_S3_STAGING_DIR, None))
        self.region_name = region_name
        self.schema_name = schema_name
        self.catalog_name = catalog_name
        self.work_group: Optional[str] = (
            work_group or os.getenv(self._ENV_WORK_GROUP, None))
        self.poll_interval = poll_interval
        self.encryption_option = encryption_option
        self.kms_key = kms_key
        self.profile_name = profile_name

        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the AssertionError type is kept for existing callers.
        if not (self.s3_staging_dir or self.work_group):
            raise AssertionError(
                "Required argument `s3_staging_dir` or `work_group` not found.")

        if session:
            self._session = session
        else:
            if role_arn or serial_number:
                # Obtain temporary credentials: assume the role when one is
                # given, otherwise request a session token for the MFA device.
                if role_arn:
                    creds = self._assume_role(
                        profile_name=self.profile_name,
                        region_name=self.region_name,
                        role_arn=role_arn,
                        role_session_name=role_session_name,
                        external_id=external_id,
                        serial_number=serial_number,
                        duration_seconds=duration_seconds,
                    )
                else:
                    creds = self._get_session_token(
                        profile_name=self.profile_name,
                        region_name=self.region_name,
                        serial_number=serial_number,
                        duration_seconds=duration_seconds,
                    )
                # The injected keys replace the named profile.
                self.profile_name = None
                self._kwargs.update({
                    "aws_access_key_id": creds["AccessKeyId"],
                    "aws_secret_access_key": creds["SecretAccessKey"],
                    "aws_session_token": creds["SessionToken"],
                })
            self._session = Session(region_name=self.region_name,
                                    profile_name=self.profile_name,
                                    **self._session_kwargs)
        self._client = self._session.client("athena",
                                            region_name=self.region_name,
                                            **self._client_kwargs)
        self._converter = converter
        self._formatter = formatter if formatter else DefaultParameterFormatter()
        self._retry_config = retry_config if retry_config else RetryConfig()
        self.cursor_class = cursor_class
        self.cursor_kwargs = cursor_kwargs if cursor_kwargs else dict()
        self.kill_on_interrupt = kill_on_interrupt
Ejemplo n.º 5
0
    def __init__(
        self,
        s3_staging_dir=None,
        region_name=None,
        schema_name="default",
        work_group=None,
        poll_interval=1,
        encryption_option=None,
        kms_key=None,
        profile_name=None,
        role_arn=None,
        role_session_name="PyAthena-session-{0}".format(int(time.time())),
        duration_seconds=3600,
        converter=None,
        formatter=None,
        retry_config=None,
        cursor_class=Cursor,
        kill_on_interrupt=True,
        **kwargs
    ):
        """Store connection settings and create the boto3 session and Athena client.

        Args:
            s3_staging_dir: S3 location for query output. Falls back to the
                environment variable named by ``self._ENV_S3_STAGING_DIR``.
            region_name: Region passed to ``_assume_role`` and the client.
            schema_name: Default schema; must be non-empty.
            work_group: Athena work group. Falls back to the environment
                variable named by ``self._ENV_WORK_GROUP``.
            poll_interval: Stored on the instance.
            encryption_option: Stored on the instance.
            kms_key: Stored on the instance.
            profile_name: Profile for the boto3 ``Session``. It is cleared
                when ``role_arn`` is given.
            role_arn: When given, ``_assume_role`` is called and its temporary
                credentials are added to ``self._kwargs``.
            role_session_name: Passed to ``_assume_role``. NOTE: the default
                is evaluated once, when this function is defined, so every
                connection that omits it shares the same name.
            duration_seconds: Passed to ``_assume_role``.
            converter: Stored as-is; may be ``None``.
            formatter: Defaults to ``DefaultParameterFormatter()`` when falsy.
            retry_config: Defaults to ``RetryConfig()`` when falsy.
            cursor_class: Stored on the instance.
            kill_on_interrupt: Stored on the instance.
            **kwargs: Extra options kept in ``self._kwargs``. They are
                presumably split by ``_session_kwargs`` / ``_client_kwargs``
                (defined outside this view).

        Raises:
            AssertionError: If ``schema_name`` is empty, or neither a staging
                dir nor a work group is found.
        """
        self._kwargs = kwargs
        if s3_staging_dir:
            self.s3_staging_dir = s3_staging_dir
        else:
            self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR, None)
        self.region_name = region_name
        self.schema_name = schema_name
        if work_group:
            self.work_group = work_group
        else:
            self.work_group = os.getenv(self._ENV_WORK_GROUP, None)
        self.poll_interval = poll_interval
        self.encryption_option = encryption_option
        self.kms_key = kms_key
        self.profile_name = profile_name

        # Explicit raises instead of `assert`: asserts are stripped under
        # `python -O`. AssertionError is kept for existing callers.
        if not self.schema_name:
            raise AssertionError("Required argument `schema_name` not found.")
        if not (self.s3_staging_dir or self.work_group):
            raise AssertionError(
                "Required argument `s3_staging_dir` or `work_group` not found."
            )

        if role_arn:
            # Swap the profile for temporary role credentials.
            creds = self._assume_role(
                self.profile_name,
                self.region_name,
                role_arn,
                role_session_name,
                duration_seconds,
            )
            self.profile_name = None
            self._kwargs.update(
                {
                    "aws_access_key_id": creds["AccessKeyId"],
                    "aws_secret_access_key": creds["SecretAccessKey"],
                    "aws_session_token": creds["SessionToken"],
                }
            )
        self._session = Session(profile_name=self.profile_name, **self._session_kwargs)
        self._client = self._session.client(
            "athena", region_name=self.region_name, **self._client_kwargs
        )
        self._converter = converter
        self._formatter = formatter if formatter else DefaultParameterFormatter()
        self._retry_config = retry_config if retry_config else RetryConfig()
        self.cursor_class = cursor_class
        self.kill_on_interrupt = kill_on_interrupt