예제 #1
0
def get_resource(resource_arn: str,
                 secret_arn: str,
                 transaction_id: Optional[str] = None) -> Resource:
    """Build a Resource for ``resource_arn`` after validating ``secret_arn``.

    Args:
        resource_arn: key into ``RESOURCE_METAS``.
        secret_arn: secret whose user name and password must equal the
            credentials registered for the resource.
        transaction_id: when given, the pooled connection for this ID is
            reused; otherwise a new connection is created.

    Returns:
        ``meta.resource_type(connection, transaction_id)``.

    Raises:
        InternalServerErrorException: the resource is unknown, or the secret
            lookup fails, while ``transaction_id`` is in ``CONNECTION_POOL``.
        BadRequestException: the resource is unknown, the secret lookup fails,
            or the secret's credentials do not match.  Errors raised by
            ``create_connection``/``get_connection`` propagate unchanged.
    """
    if resource_arn not in RESOURCE_METAS:
        # Inside a pooled transaction this is reported as a server-side error.
        if transaction_id in CONNECTION_POOL:
            raise InternalServerErrorException
        raise BadRequestException(
            f'HttpEndPoint is not enabled for {resource_arn}')

    try:
        secret: Secret = get_secret(secret_arn)
    except BadRequestException:
        if transaction_id not in CONNECTION_POOL:
            raise
        raise InternalServerErrorException

    meta: ResourceMeta = RESOURCE_METAS[resource_arn]

    # TODO: support multiple secret_arn for a resource
    credentials_mismatch = (secret.user_name != meta.user_name
                            or secret.password != meta.password)
    if credentials_mismatch:
        raise BadRequestException('Invalid secret_arn')

    connection: Connection = (
        create_connection(resource_arn) if transaction_id is None
        else get_connection(transaction_id))
    return meta.resource_type(connection, transaction_id)
예제 #2
0
 def create_query(cls, sql: str, params: Dict[str, Any]) -> str:
     """Render ``sql`` as a literal SQL string with ``params`` inlined.

     ``None`` values are bound as SQL ``NULL``.  The statement is compiled
     for ``cls.DIALECT`` with ``literal_binds`` enabled, so the bound values
     are written directly into the returned text.

     When compilation fails with an ``ArgumentError`` whose message matches
     ``UNDEFINED_PARAMETER_MESSAGE`` (presumably: a supplied parameter the
     SQL does not declare), the captured parameter is dropped and rendering
     is retried recursively.

     Raises:
         BadRequestException: a ``CompileError`` whose message matches
             ``INVALID_PARAMETER_MESSAGE``.
         CompileError, ArgumentError: any non-matching error of these types
             is re-raised unchanged.
     """
     statement: TextClause = text(sql)
     compile_options = dict(
         dialect=cls.DIALECT,
         compile_kwargs={"literal_binds": True},
     )
     try:
         bound_values = {
             name: null() if value is None else value
             for name, value in params.items()
         }
         compiled = statement.bindparams(**bound_values).compile(
             **compile_options)
         return str(compiled)
     except CompileError as error:
         missing_match = re.match(INVALID_PARAMETER_MESSAGE, error.args[0])
         if missing_match:  # pragma: no cover
             raise BadRequestException(
                 message=f'Cannot find parameter: {missing_match.group(1)}')
         raise  # pragma: no cover
     except ArgumentError as error:
         undefined_match = re.match(UNDEFINED_PARAMETER_MESSAGE,
                                    error.args[0])
         if undefined_match:  # pragma: no cover
             unused_name: str = undefined_match.group(1)
             remaining = {
                 name: value
                 for name, value in params.items() if name != unused_name
             }
             return cls.create_query(sql, remaining)
         raise  # pragma: no cover
예제 #3
0
    def execute(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None,
        include_result_metadata: bool = False,
    ) -> ExecuteStatementResponse:
        """Run ``sql`` on a JDBC cursor and wrap the outcome.

        When ``params`` is non-empty the statement text comes from
        ``self.create_query``; otherwise ``sql`` is passed through ``text``.

        Returns:
            If the cursor reports a ``description`` (a result set): every row
            converted value-by-value with ``self.get_filed_from_jdbc_type``,
            plus the column metadata when ``include_result_metadata`` is true.
            Otherwise: the cursor's ``rowcount`` and, when
            ``self.last_generated_id`` returns a positive id, that id as a
            generated field.

        Raises:
            BadRequestException: wraps any ``jaydebeapi.DatabaseError``, using
                the deepest message found on its nested ``args`` chain.
        """
        try:
            cursor: Optional[jaydebeapi.Cursor] = None
            try:
                cursor = self.connection.cursor()
                self.reset_generated_id(cursor)

                query = self.create_query(sql, params) if params else str(text(sql))
                cursor.execute(query)

                if not cursor.description:
                    # No result set: report affected rows / generated id.
                    rowcount: int = cursor.rowcount
                    last_generated_id: int = self.last_generated_id(cursor)
                    generated_fields: List[Field] = (
                        [self.get_field_from_value(last_generated_id)]
                        if last_generated_id > 0
                        else []
                    )
                    return ExecuteStatementResponse(
                        numberOfRecordsUpdated=rowcount,
                        generatedFields=generated_fields,
                    )

                column_metadata_set = self.create_column_metadata_set(cursor)
                records = [
                    [
                        self.get_filed_from_jdbc_type(value, metadata.type)
                        for value, metadata in zip(row, column_metadata_set)
                    ]
                    for row in cursor.fetchall()
                ]
                response = ExecuteStatementResponse(
                    numberOfRecordsUpdated=0,
                    records=records,
                )
                if include_result_metadata:
                    response.columnMetadata = column_metadata_set
                return response
            finally:
                if cursor:  # pragma: no cover
                    cursor.close()

        except jaydebeapi.DatabaseError as e:
            # Drill down e.args[0] -> .args[0] -> .cause.message, keeping the
            # most specific message that is available at each level.
            message: Any = 'Unknown'
            outer_args = getattr(e, 'args', [])
            if len(outer_args):
                wrapped = outer_args[0]
                message = wrapped
                inner_args = getattr(wrapped, 'args', [])
                if len(inner_args):
                    inner = inner_args[0]
                    message = inner
                    if getattr(inner, 'cause', None):
                        message = inner.cause.message
            raise BadRequestException(str(message))
예제 #4
0
def get_secret(secret_arn: str) -> Secret:
    """Look up a registered secret by its ARN.

    Args:
        secret_arn: key into the module-level ``SECRETS`` mapping.

    Returns:
        The ``Secret`` stored under ``secret_arn``.

    Raises:
        BadRequestException: ``secret_arn`` is not in ``SECRETS``; the message
            mimics a Secrets Manager ``ResourceNotFoundException`` error.
    """
    if secret_arn not in SECRETS:
        raise BadRequestException(
            f'Error fetching secret {secret_arn} : Secrets Manager can’t find the specified '
            f'secret. (Service: AWSSecretsManager; Status Code: 400; Error Code: '
            f'ResourceNotFoundException; Request ID:  00000000-1111-2222-3333-44444444444)'
        )
    return SECRETS[secret_arn]
예제 #5
0
def test_get_resource_exception(clear, secrets, mocker) -> None:
    """get_resource maps each lookup failure to the expected exception type.

    Without a pooled transaction, failures raise BadRequestException; when
    the given transaction ID is in CONNECTION_POOL, resource/secret lookup
    failures raise InternalServerErrorException.  Mismatched credentials
    raise BadRequestException.

    Fixtures (assumed from usage; confirm in conftest): ``clear`` resets
    module state, ``secrets`` is a mock replacing the secret lookup.
    """
    resource_arn: str = 'dummy_resource_arn'

    connection_maker = SQLite.create_connection_maker()

    RESOURCE_METAS[resource_arn] = ResourceMeta(SQLite, connection_maker,
                                                'localhost', 3306, 'test',
                                                'pw')

    # Unknown resource, no transaction -> BadRequestException.
    with pytest.raises(BadRequestException):
        get_resource('invalid', 'dummy')

    # Unknown resource inside a pooled transaction -> server error.
    # Setup stays outside pytest.raises so only get_resource can satisfy the
    # assertion, and the pool entry is removed even if the assertion fails.
    CONNECTION_POOL['dummy'] = connection_maker()
    try:
        with pytest.raises(InternalServerErrorException):
            get_resource('invalid', 'dummy', 'dummy')
    finally:
        del CONNECTION_POOL['dummy']

    # Secret lookup fails, no transaction -> BadRequestException.
    secrets.side_effect = BadRequestException('error')
    with pytest.raises(BadRequestException):
        get_resource(resource_arn, 'dummy')

    # Secret lookup fails inside a pooled transaction -> server error.
    CONNECTION_POOL['dummy'] = connection_maker()
    try:
        with pytest.raises(InternalServerErrorException):
            get_resource(resource_arn, 'dummy', 'dummy')
    finally:
        # Don't leak the pooled entry into later tests.
        del CONNECTION_POOL['dummy']
    secrets.side_effect = None

    # User name differs from the registered one.
    secret = mocker.Mock()
    secret.user_name = 'invalid'
    secret.password = '******'
    secrets.return_value = secret
    with pytest.raises(BadRequestException):
        get_resource(resource_arn, 'dummy')

    # User name matches but password differs from the registered 'pw'.
    secret = mocker.Mock()
    secret.user_name = 'test'
    secret.password = '******'
    secrets.return_value = secret
    with pytest.raises(BadRequestException):
        get_resource(resource_arn, 'dummy')
예제 #6
0
def get_resource(
    resource_arn: str,
    secret_arn: str,
    transaction_id: Optional[str] = None,
    database: Optional[str] = None,
) -> Resource:
    """Build a Resource for ``resource_arn`` after validating ``secret_arn``.

    Args:
        resource_arn: key into ``RESOURCE_METAS``.
        secret_arn: secret whose user name and password must equal the
            credentials registered for the resource.
        transaction_id: when given, the pooled connection for this ID is
            reused; otherwise a new connection is created.
        database: database for a new connection.  For an existing
            transaction, if non-empty it must equal the database name the
            pooled connection reports.

    Returns:
        ``meta.resource_type(connection, transaction_id)``.

    Raises:
        InternalServerErrorException: the resource is unknown, or the secret
            lookup fails, while ``transaction_id`` is in ``CONNECTION_POOL``.
        BadRequestException: the resource is unknown, the secret lookup fails,
            the credentials do not match, or ``database`` differs from the
            transaction's database.  Errors raised by
            ``create_connection``/``get_connection`` propagate unchanged.
    """
    if resource_arn not in RESOURCE_METAS:
        # Inside a pooled transaction this is reported as a server-side error.
        if transaction_id in CONNECTION_POOL:
            raise InternalServerErrorException
        raise BadRequestException(
            f'HttpEndPoint is not enabled for {resource_arn}')

    try:
        secret: Secret = get_secret(secret_arn)
    except BadRequestException:
        if transaction_id not in CONNECTION_POOL:
            raise
        raise InternalServerErrorException

    meta: ResourceMeta = RESOURCE_METAS[resource_arn]

    # TODO: support multiple secret_arn for a resource
    credentials_mismatch = (secret.user_name != meta.user_name
                            or secret.password != meta.password)
    if credentials_mismatch:
        raise BadRequestException('Invalid secret_arn')

    if transaction_id is not None:
        connection: Connection = get_connection(transaction_id)
        if database:
            try:
                current_database: Optional[str] = connection.database
            except AttributeError:  # pragma: no cover
                # Fallback for psycopg2: read the name from the DSN parameters.
                current_database = connection.get_dsn_parameters()['dbname']
            if database != current_database:  # pragma: no cover
                raise BadRequestException(
                    'Database name is not the same as when transaction was created'
                )
    else:
        connection = create_connection(resource_arn, database)

    return meta.resource_type(connection, transaction_id)
예제 #7
0
    def execute(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None,
        database_name: Optional[str] = None,
        include_result_metadata: bool = False,
    ) -> ExecuteStatementResponse:
        """Execute ``sql`` on ``self.connection`` and wrap the result.

        Args:
            sql: SQL text to execute.
            params: values inlined into ``sql`` via ``self.create_query``;
                when empty or ``None`` the SQL is passed through ``text``.
            database_name: if non-empty, ``self.use_database`` is called first.
            include_result_metadata: for result-set statements, also attach
                column metadata built from ``cursor.description``.

        Returns:
            For statements producing a result set: every row converted with
            ``Field.from_value``.  Otherwise: the cursor's ``rowcount`` and,
            when ``cursor.lastrowid`` is a positive id, that id as a
            generated field.

        Raises:
            BadRequestException: every failure is reported as this type.  A
                ``BadRequestException`` raised along the way propagates
                unchanged.  Any other exception is converted, using
                ``e.orig.args[1]`` when available, else ``e.args[0]``, else
                ``'Unknown'`` as the message.
        """
        try:
            if database_name:
                self.use_database(database_name)

            cursor: Optional[Cursor] = None
            try:
                cursor = self.connection.cursor()
                if params:
                    cursor.execute(self.create_query(sql, params))
                else:
                    cursor.execute(str(text(sql)))

                if cursor.description:
                    response: ExecuteStatementResponse = ExecuteStatementResponse(
                        numberOfRecordsUpdated=0,
                        records=[[Field.from_value(column) for column in row]
                                 for row in cursor.fetchall()],
                    )
                    if include_result_metadata:
                        response.columnMetadata = [
                            create_column_metadata(*d)
                            for d in cursor.description
                        ]
                    return response

                rowcount: int = cursor.rowcount
                # DB-API allows lastrowid to be None (sqlite3 starts at None
                # on a fresh cursor); `None > 0` would raise TypeError.
                last_generated_id: Optional[int] = cursor.lastrowid
                generated_fields: List[Field] = []
                if last_generated_id is not None and last_generated_id > 0:
                    generated_fields.append(
                        Field.from_value(last_generated_id))
                return ExecuteStatementResponse(
                    numberOfRecordsUpdated=rowcount,
                    generatedFields=generated_fields,
                )
            finally:
                if cursor:  # pragma: no cover
                    cursor.close()

        except BadRequestException:
            # Already a client-facing error (e.g. from create_query);
            # re-wrapping it could replace its message with 'Unknown'.
            raise
        except Exception as e:
            message: str = 'Unknown'
            orig_args = getattr(getattr(e, 'orig', None), 'args', None)
            # (code, message)-style driver args: use the message part, but
            # don't IndexError when the driver supplied fewer elements.
            if orig_args is not None and len(orig_args) > 1:
                message = str(orig_args[1])
            elif len(getattr(e, 'args', [])) and e.args[0]:
                message = str(e.args[0])
            raise BadRequestException(message) from e
예제 #8
0
def get_connection(transaction_id: str) -> Connection:
    """Return the pooled connection registered for ``transaction_id``.

    Raises:
        BadRequestException: no connection is pooled under that ID.
    """
    if transaction_id not in CONNECTION_POOL:
        raise BadRequestException('Invalid transaction ID')
    return CONNECTION_POOL[transaction_id]