Example #1
def test_source_wrong_credentials():
    source = SourceSendgrid()
    status, error = source.check_connection(
        logger=AirbyteLogger(), config={"apikey": "wrong.api.key123"})
    assert not status
Example #2
File: source.py Project: Mu-L/airbyte
 def _write_config(self, token):
     logger = AirbyteLogger()
     logger.info("Credentials Refreshed")
Example #3
class Error(Exception):
    """Base Error class for other exceptions"""

    # Define the instance of the Native Airbyte Logger
    logger = AirbyteLogger()
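
The Error base class above attaches a shared AirbyteLogger to every exception derived from it. Below is a minimal sketch of a hypothetical subclass (the name and message are illustrative, not taken from the source) that reports itself through that shared logger:

class CredentialsError(Error):
    """Illustrative subclass: raised when the provided credentials are rejected."""

    def __init__(self, message: str = "Invalid credentials"):
        super().__init__(message)
        # self.logger is the class-level AirbyteLogger defined on Error above.
        self.logger.error(message)
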
Example #4
import sys

import backoff
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.sources.streams.http.exceptions import DefaultBackoffException
from requests import codes, exceptions

TRANSIENT_EXCEPTIONS = (
    DefaultBackoffException,
    exceptions.ConnectTimeout,
    exceptions.ReadTimeout,
    exceptions.ConnectionError,
    exceptions.HTTPError,
)

logger = AirbyteLogger()


def default_backoff_handler(max_tries: int, factor: int, **kwargs):
    def log_retry_attempt(details):
        _, exc, _ = sys.exc_info()
        logger.info(str(exc))
        logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        )

    def should_give_up(exc):
        give_up = (
            exc.response is not None
            and exc.response.status_code != codes.too_many_requests
            and 400 <= exc.response.status_code < 500
        )

        # Salesforce can return an error with a limit using a 403 code error.
        if exc.response is not None and exc.response.status_code == codes.forbidden:
            # Assumed completion of the helper: treat a 403 carrying REQUEST_LIMIT_EXCEEDED as terminal.
            error_data = exc.response.json()[0]
            if error_data.get("errorCode", "") == "REQUEST_LIMIT_EXCEEDED":
                give_up = True

        if give_up:
            logger.info(f"Giving up for returned HTTP status: {exc.response.status_code}")
        return give_up

    return backoff.on_exception(
        backoff.expo,
        TRANSIENT_EXCEPTIONS,
        jitter=None,
        on_backoff=log_retry_attempt,
        giveup=should_give_up,
        max_tries=max_tries,
        factor=factor,
        **kwargs,
    )
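
Assuming default_backoff_handler returns the configured backoff.on_exception decorator (as completed above), a hedged usage sketch follows; the wrapped function and URL are hypothetical. Because requests.exceptions.HTTPError appears in TRANSIENT_EXCEPTIONS, the raise_for_status() call below is retried with exponential backoff until should_give_up says otherwise:

import requests


@default_backoff_handler(max_tries=5, factor=15)
def fetch_page(url: str) -> dict:
    # Transient failures listed in TRANSIENT_EXCEPTIONS trigger a retry; most 4xx errors give up early.
    response = requests.get(url)
    response.raise_for_status()
    return response.json()
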
Example #5
class AbstractFileParser(ABC):
    logger = AirbyteLogger()

    def __init__(self, format: dict, master_schema: dict = None):
        """
        :param format: file format specific mapping as described in spec.json
        :param master_schema: superset schema determined from all files, might be unused for some formats, defaults to None
        """
        self._format = format
        self._master_schema = (
            master_schema
            # this may need to be used differently by some formats, pyarrow allows extra columns in csv schema
        )

    @property
    @abstractmethod
    def is_binary(self) -> bool:
        """
        Override this per format so that file-like objects passed in are correctly opened as binary or not
        """

    @abstractmethod
    def get_inferred_schema(self, file: Union[TextIO, BinaryIO]) -> dict:
        """
        Override this with format-specific logic to infer the schema of the file
        Note: needs to return inferred schema with JsonSchema datatypes

        :param file: file-like object (opened via StorageFile)
        :return: mapping of {columns:datatypes} where datatypes are JsonSchema types
        """

    @abstractmethod
    def stream_records(self, file: Union[TextIO, BinaryIO],
                       file_info: FileInfo) -> Iterator[Mapping[str, Any]]:
        """
        Override this with format-specific logic to stream each data row from the file as a mapping of {columns:values}
        Note: avoid loading the whole file into memory to avoid OOM breakages

        :param file: file-like object (opened via StorageFile)
        :param file_info: file metadata
        :yield: data record as a mapping of {columns:values}
        """

    @staticmethod
    def json_type_to_pyarrow_type(
        typ: str,
        reverse: bool = False,
        logger: AirbyteLogger = AirbyteLogger()) -> str:
        """
        Converts a JSON type to a PyArrow type (or the other way around if reverse=True)

        :param typ: Json type if reverse is False, else PyArrow type
        :param reverse: switch to True for PyArrow type -> Json type, defaults to False
        :param logger: defaults to AirbyteLogger()
        :return: PyArrow type if reverse is False, else Json type
        """
        str_typ = str(typ)
        # this is a map of airbyte types to pyarrow types. The first list element of the pyarrow types should be the one to use where required.
        map = {
            "boolean": ("bool_", "bool"),
            "integer": ("int64", "int8", "int16", "int32", "uint8", "uint16",
                        "uint32", "uint64"),
            "number": ("float64", "float16", "float32", "decimal128",
                       "decimal256", "halffloat", "float", "double"),
            "string": ("large_string", "string"),
            # TODO: support object type rather than coercing to string
            "object": ("large_string", ),
            # TODO: support array type rather than coercing to string
            "array": ("large_string", ),
            "null": ("large_string", ),
        }
        if not reverse:
            for json_type, pyarrow_types in map.items():
                if str_typ.lower() == json_type:
                    return str(
                        getattr(pa, pyarrow_types[0]).__call__()
                    )  # better way might be necessary when we decide to handle more type complexity
            logger.debug(
                f"JSON type '{str_typ}' is not mapped, falling back to default conversion to large_string"
            )
            return str(pa.large_string())
        else:
            for json_type, pyarrow_types in map.items():
                if any(
                        str_typ.startswith(pa_type)
                        for pa_type in pyarrow_types):
                    return json_type
            logger.debug(
                f"PyArrow type '{str_typ}' is not mapped, falling back to default conversion to string"
            )
            return "string"  # default type if unspecified in map

    @staticmethod
    def json_schema_to_pyarrow_schema(
            schema: Mapping[str, Any],
            reverse: bool = False) -> Mapping[str, Any]:
        """
        Converts a schema with JsonSchema datatypes to one with PyArrow types (or the other way if reverse=True)
        This utilises json_type_to_pyarrow_type() to convert each datatype

        :param schema: json/pyarrow schema to convert
        :param reverse: switch to True for PyArrow schema -> Json schema, defaults to False
        :return: converted schema dict
        """
        return {
            column:
            AbstractFileParser.json_type_to_pyarrow_type(json_type,
                                                         reverse=reverse)
            for column, json_type in schema.items()
        }
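
A small sketch (assuming the class above is importable and pyarrow is installed) of what the two static conversion helpers return for a simple, made-up schema:

json_schema = {"id": "integer", "name": "string", "price": "number"}

pyarrow_schema = AbstractFileParser.json_schema_to_pyarrow_schema(json_schema)
# e.g. {"id": "int64", "name": "large_string", "price": "double"}  (string forms of pyarrow types)

roundtrip = AbstractFileParser.json_schema_to_pyarrow_schema(pyarrow_schema, reverse=True)
# e.g. {"id": "integer", "name": "string", "price": "number"}
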
Example #6
class Client:
    api_version: int = 13
    refresh_token_safe_delta: int = 10  # in seconds
    logger: AirbyteLogger = AirbyteLogger()
    # retry on: rate limit errors, auth token expiration, internal errors
    # https://docs.microsoft.com/en-us/advertising/guides/services-protocol?view=bingads-13#throttling
    # https://docs.microsoft.com/en-us/advertising/guides/operation-error-codes?view=bingads-13
    retry_on_codes: List[str] = ["117", "207", "4204", "109", "0"]
    max_retries: int = 3
    # A backoff factor to apply between attempts after the second try
    # {retry_factor} * (2 ** ({number of total retries} - 1))
    retry_factor: int = 15
    # environments supported by Microsoft Advertising: sandbox, production
    environment: str = "production"
    # The time interval in milliseconds between two status polling attempts.
    report_poll_interval: int = 15000

    def __init__(
        self,
        tenant_id: str,
        reports_start_date: str,
        developer_token: str = None,
        client_id: str = None,
        client_secret: str = None,
        refresh_token: str = None,
        **kwargs: Mapping[str, Any],
    ) -> None:
        self.authorization_data: Mapping[str, AuthorizationData] = {}
        self.refresh_token = refresh_token
        self.developer_token = developer_token

        self.client_id = client_id
        self.client_secret = client_secret

        self.authentication = self._get_auth_client(client_id, tenant_id,
                                                    client_secret)
        self.oauth: OAuthTokens = self._get_access_token()
        self.reports_start_date = pendulum.parse(
            reports_start_date).astimezone(tz=timezone.utc)

    def _get_auth_client(self,
                         client_id: str,
                         tenant_id: str,
                         client_secret: str = None) -> OAuthWebAuthCodeGrant:
        # https://github.com/BingAds/BingAds-Python-SDK/blob/e7b5a618e87a43d0a5e2c79d9aa4626e208797bd/bingads/authorization.py#L390
        auth_creds = {
            "client_id": client_id,
            "redirection_uri": "",  # should be empty string
            "client_secret": None,
            "tenant": tenant_id,
        }
        # the `client_secret` should be provided for `non-public clients` only
        # https://docs.microsoft.com/en-us/advertising/guides/authentication-oauth-get-tokens?view=bingads-13#request-accesstoken
        if client_secret and client_secret != "":
            auth_creds["client_secret"] = client_secret
        return OAuthWebAuthCodeGrant(**auth_creds)

    @lru_cache(maxsize=None)
    def _get_auth_data(self,
                       customer_id: str = None,
                       account_id: Optional[str] = None) -> AuthorizationData:
        return AuthorizationData(
            account_id=account_id,
            customer_id=customer_id,
            developer_token=self.developer_token,
            authentication=self.authentication,
        )

    def _get_access_token(self) -> OAuthTokens:
        self.logger.info("Fetching access token ...")
        # clear caches to be able to use new access token
        self.get_service.cache_clear()
        self._get_auth_data.cache_clear()
        return self.authentication.request_oauth_tokens_by_refresh_token(
            self.refresh_token)

    def is_token_expiring(self) -> bool:
        """
        Checks whether the access token expires in less than `refresh_token_safe_delta` seconds.
        """
        token_total_lifetime: timedelta = datetime.utcnow() - self.oauth.access_token_received_datetime
        token_updated_expires_in: int = self.oauth.access_token_expires_in_seconds - token_total_lifetime.seconds
        return token_updated_expires_in <= self.refresh_token_safe_delta

    def should_retry(self, error: WebFault) -> bool:
        error_code = str(errorcode_of_exception(error))
        give_up = error_code not in self.retry_on_codes
        if give_up:
            self.logger.error(
                f"Giving up for returned error code: {error_code}. Error details: {self._get_error_message(error)}"
            )
        return give_up

    def _get_error_message(self, error: WebFault) -> str:
        return str(self.asdict(error.fault)) if hasattr(
            error, "fault") else str(error)

    def log_retry_attempt(self, details: Mapping[str, Any]) -> None:
        _, exc, _ = sys.exc_info()
        self.logger.info(
            f"Caught retryable error: {self._get_error_message(exc)} after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        )

    def request(self, **kwargs: Mapping[str, Any]) -> Mapping[str, Any]:
        return backoff.on_exception(
            backoff.expo,
            WebFault,
            max_tries=self.max_retries,
            factor=self.retry_factor,
            jitter=None,
            on_backoff=self.log_retry_attempt,
            giveup=self.should_retry,
        )(self._request)(**kwargs)

    def _request(
        self,
        service_name: Optional[str],
        operation_name: str,
        customer_id: Optional[str],
        account_id: Optional[str],
        params: Mapping[str, Any],
        is_report_service: bool = False,
    ) -> Mapping[str, Any]:
        """
        Executes appropriate Service Operation on Bing Ads API
        """
        if self.is_token_expiring():
            self.oauth = self._get_access_token()

        if is_report_service:
            service = self._get_reporting_service(customer_id=customer_id,
                                                  account_id=account_id)
        else:
            service = self.get_service(service_name=service_name,
                                       customer_id=customer_id,
                                       account_id=account_id)

        return getattr(service, operation_name)(**params)

    @lru_cache(maxsize=None)
    def get_service(
        self,
        service_name: str,
        customer_id: str = None,
        account_id: Optional[str] = None,
    ) -> ServiceClient:
        return ServiceClient(
            service=service_name,
            version=self.api_version,
            authorization_data=self._get_auth_data(customer_id, account_id),
            environment=self.environment,
        )

    @lru_cache(maxsize=None)
    def _get_reporting_service(
        self,
        customer_id: Optional[str] = None,
        account_id: Optional[str] = None,
    ) -> ServiceClient:
        return ReportingServiceManager(
            authorization_data=self._get_auth_data(customer_id, account_id),
            poll_interval_in_milliseconds=self.report_poll_interval,
            environment=self.environment,
        )

    @classmethod
    def asdict(cls, suds_object: sudsobject.Object) -> Mapping[str, Any]:
        """
        Converts nested Suds Object into serializable format.
        Input sample:
        {
            obj[] =
                {
                    value = 1
                },
                {
                    value = "str"
                },
        }
        Output sample: =>
        {'obj': [{'value': 1}, {'value': 'str'}]}
        """
        result: Mapping[str, Any] = {}

        for field, val in sudsobject.asdict(suds_object).items():
            if hasattr(val, "__keylist__"):
                result[field] = cls.asdict(val)
            elif isinstance(val, list):
                result[field] = []
                for item in val:
                    if hasattr(item, "__keylist__"):
                        result[field].append(cls.asdict(item))
                    else:
                        result[field].append(item)
            elif isinstance(val, datetime):
                result[field] = val.isoformat()
            else:
                result[field] = val
        return result
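
The retry wiring in Client.request above is plain backoff.on_exception usage. The standalone sketch below applies the same pattern to a toy function; every name in it is illustrative and unrelated to the Bing Ads connector:

import backoff


class TransientApiError(Exception):
    """Stand-in for WebFault in this sketch; carries an error_code attribute."""

    def __init__(self, message: str, error_code: str = "117"):
        super().__init__(message)
        self.error_code = error_code


def log_attempt(details) -> None:
    print(f"Retry #{details['tries']}, waiting {details['wait']:.1f}s before the next call")


def give_up(exc: Exception) -> bool:
    # Mirrors should_retry(): stop retrying when the error code is not in the allow-list.
    return getattr(exc, "error_code", None) not in {"117", "207", "4204", "109", "0"}


@backoff.on_exception(backoff.expo, TransientApiError, max_tries=3, factor=15, jitter=None,
                      on_backoff=log_attempt, giveup=give_up)
def call_api():
    raise TransientApiError("throttled", error_code="117")
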
Example #7
File: core.py Project: zestyping/airbyte
class Stream(ABC):
    """
    Base abstract class for an Airbyte Stream. Makes no assumption of the Stream's underlying transport protocol.
    """

    # Use self.logger in subclasses to log any messages
    logger = AirbyteLogger()  # TODO use native "logging" loggers with custom handlers

    @property
    def name(self) -> str:
        """
        :return: Stream name. By default this is the implementing class name, but it can be overridden as needed.
        """
        return casing.camel_to_snake(self.__class__.__name__)

    @abstractmethod
    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: List[str] = None,
        stream_slice: Mapping[str, Any] = None,
        stream_state: Mapping[str, Any] = None,
    ) -> Iterable[Mapping[str, Any]]:
        """
        This method should be overridden by subclasses to read records based on the inputs
        """

    def get_json_schema(self) -> Mapping[str, Any]:
        """
        :return: A dict of the JSON schema representing this stream.

        The default implementation of this method looks for a JSONSchema file with the same name as this stream's "name" property.
        Override as needed.
        """
        # TODO show an example of using pydantic to define the JSON schema, or reading an OpenAPI spec
        return ResourceSchemaLoader(package_name_from_class(
            self.__class__)).get_schema(self.name)

    def as_airbyte_stream(self) -> AirbyteStream:
        stream = AirbyteStream(name=self.name,
                               json_schema=dict(self.get_json_schema()),
                               supported_sync_modes=[SyncMode.full_refresh])

        if self.supports_incremental:
            stream.source_defined_cursor = self.source_defined_cursor
            stream.supported_sync_modes.append(
                SyncMode.incremental)  # type: ignore
            stream.default_cursor_field = self._wrapped_cursor_field()

        keys = Stream._wrapped_primary_key(self.primary_key)
        if keys and len(keys) > 0:
            stream.source_defined_primary_key = keys

        return stream

    @property
    def supports_incremental(self) -> bool:
        """
        :return: True if this stream supports incrementally reading data
        """
        return len(self._wrapped_cursor_field()) > 0

    def _wrapped_cursor_field(self) -> List[str]:
        return [self.cursor_field] if isinstance(self.cursor_field,
                                                 str) else self.cursor_field

    @property
    def cursor_field(self) -> Union[str, List[str]]:
        """
        Override to return the default cursor field used by this stream e.g: an API entity might always use created_at as the cursor field.
        :return: The name of the field used as a cursor. If the cursor is nested, return an array consisting of the path to the cursor.
        """
        return []

    @property
    def source_defined_cursor(self) -> bool:
        """
        Return False if the cursor can be configured by the user.
        """
        return True

    @property
    @abstractmethod
    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
        """
        :return: string if single primary key, list of strings if composite primary key, list of list of strings if composite primary key consisting of nested fields.
        If the stream has no primary keys, return None.
        """

    def stream_slices(
        self,
        sync_mode: SyncMode,
        cursor_field: List[str] = None,
        stream_state: Mapping[str, Any] = None
    ) -> Iterable[Optional[Mapping[str, Any]]]:
        """
        Override to define the slices for this stream. See the stream slicing section of the docs for more information.

        :param stream_state:
        :return:
        """
        return [None]

    @property
    def state_checkpoint_interval(self) -> Optional[int]:
        """
        Decides how often to checkpoint state (i.e: emit a STATE message). E.g: if this returns a value of 100, then state is persisted after reading
        100 records, then 200, 300, etc.. A good default value is 1000 although your mileage may vary depending on the underlying data source.

        Checkpointing a stream avoids re-reading records in the case a sync is failed or cancelled.

        return None if state should not be checkpointed e.g: because records returned from the underlying data source are not returned in
        ascending order with respect to the cursor field. This can happen if the source does not support reading records in ascending order of
        created_at date (or whatever the cursor is). In those cases, state must only be saved once the full stream has been read.
        """
        return None

    def get_updated_state(self, current_stream_state: MutableMapping[str, Any],
                          latest_record: Mapping[str, Any]):
        """
        Override to extract state from the latest record. Needed to implement incremental sync.

        Inspects the latest record extracted from the data source and the current state object and return an updated state object.

        For example: if the state object is based on created_at timestamp, and the current state is {'created_at': 10}, and the latest_record is
        {'name': 'octavia', 'created_at': 20 } then this method would return {'created_at': 20} to indicate state should be updated to this object.

        :param current_stream_state: The stream's current state object
        :param latest_record: The latest record extracted from the stream
        :return: An updated state object
        """
        return {}

    @staticmethod
    def _wrapped_primary_key(
        keys: Optional[Union[str, List[str], List[List[str]]]]
    ) -> Optional[List[List[str]]]:
        """
        :return: wrap the primary_key property in a list of list of strings required by the Airbyte Stream object.
        """
        if not keys:
            return None

        if isinstance(keys, str):
            return [[keys]]
        elif isinstance(keys, list):
            wrapped_keys = []
            for component in keys:
                if isinstance(component, str):
                    wrapped_keys.append([component])
                elif isinstance(component, list):
                    wrapped_keys.append(component)
                else:
                    raise ValueError("Element must be either list or str.")
            return wrapped_keys
        else:
            raise ValueError("Element must be either list or str.")
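
A minimal, hypothetical subclass of the abstract Stream above, showing the smallest surface a concrete stream needs to implement; the stream name, fields, and records are invented for illustration:

class Employees(Stream):
    # A plain class attribute is enough to satisfy the abstract primary_key property.
    primary_key = "id"

    @property
    def cursor_field(self) -> str:
        # A non-empty cursor field makes supports_incremental evaluate to True.
        return "updated_at"

    def read_records(self, sync_mode, cursor_field=None, stream_slice=None, stream_state=None):
        # A real stream would page through an API here; static records keep the sketch short.
        yield {"id": 1, "updated_at": "2021-11-01"}
        yield {"id": 2, "updated_at": "2021-11-02"}

With these overrides, as_airbyte_stream() advertises both full_refresh and incremental sync modes, a default cursor field of ["updated_at"], and a source-defined primary key of [["id"]].
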
Example #8
    def check_config(self, logger: AirbyteLogger, config_path: str,
                     config: json) -> AirbyteConnectionStatus:
        """
        Tests if the input configuration can be used to successfully connect to the integration
            e.g: if a provided Stripe API token can be used to connect to the Stripe API.

        :param logger: Logging object to display debug/info/error to the logs
            (logs will not be accessible via airbyte UI if they are not passed to this logger)
        :param config_path: Path to the file containing the configuration json config
        :param config: Json object containing the configuration of this source, content of this json is as specified in
        the properties of the spec.json file

        :return: AirbyteConnectionStatus indicating a Success or Failure
        """
        try:
            # If an app on the appstore does not support subscriptions or sales, it cannot pull the relevant reports.
            # However, the way the Appstore API expresses this is not via clear error messages. Instead it expresses it by throwing an unrelated
            # error, in this case "invalid vendor ID". There is no way to distinguish if this error is due to invalid credentials or due to
            # the account not supporting this kind of report. So to "check connection" we see if any of the reports can be pulled and if so
            # return success. If no reports can be pulled we display the exception messages generated for all reports and return failure.
            api_fields_to_test = {
                "subscription_event_report": {
                    "reportType": "SUBSCRIPTION_EVENT",
                    "frequency": "DAILY",
                    "reportSubType": "SUMMARY",
                    "version": "1_2",
                },
                "subscriber_report": {
                    "reportType": "SUBSCRIBER",
                    "frequency": "DAILY",
                    "reportSubType": "DETAILED",
                    "version": "1_2"
                },
                "subscription_report": {
                    "reportType": "SUBSCRIPTION",
                    "frequency": "DAILY",
                    "reportSubType": "SUMMARY",
                    "version": "1_2"
                },
                "sales_report": {
                    "reportType": "SALES",
                    "frequency": "DAILY",
                    "reportSubType": "SUMMARY",
                    "version": "1_0"
                },
            }

            api = Api(config["key_id"], config["key_file"],
                      config["issuer_id"])
            stream_to_error = {}
            for stream, params in api_fields_to_test.items():
                test_date = date.today() - timedelta(days=2)
                report_filters = {
                    "reportDate": test_date.strftime("%Y-%m-%d"),
                    "vendorNumber": f"{config['vendor']}"
                }
                report_filters.update(api_fields_to_test[stream])
                try:
                    rep_tsv = api.download_sales_and_trends_reports(
                        filters=report_filters)
                    if isinstance(rep_tsv, dict):
                        raise Exception(
                            f"An exception occurred: Received a JSON response instead of"
                            f" the report: {str(rep_tsv)}")
                except Exception as e:
                    logger.warn(f"Unable to download {stream}: {e}")
                    stream_to_error[stream] = e

            # All streams have failed
            if len(stream_to_error) == len(api_fields_to_test):
                message = "\n".join([
                    f"Unable to access {stream} due to error: {e}"
                    for stream, e in stream_to_error.items()
                ])
                return AirbyteConnectionStatus(status=Status.FAILED,
                                               message=message)

            return AirbyteConnectionStatus(status=Status.SUCCEEDED)
        except Exception as e:
            logger.warn(e)
            return AirbyteConnectionStatus(
                status=Status.FAILED,
                message=f"An exception occurred: {str(e)}")
Example #9
def logger() -> AirbyteLogger:
    return AirbyteLogger()
Example #10
class SourceHubspot(AbstractSource):
    logger = AirbyteLogger()

    def check_connection(
            self, logger: logging.Logger,
            config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
        """Check connection"""
        common_params = self.get_common_params(config=config)
        alive = True
        error_msg = None
        try:
            contacts = Contacts(**common_params)
            _ = contacts.properties
        except HTTPError as error:
            alive = False
            error_msg = repr(error)
        return alive, error_msg

    def get_granted_scopes(self, authenticator):
        try:
            access_token = authenticator.get_access_token()
            url = f"https://api.hubapi.com/oauth/v1/access-tokens/{access_token}"
            response = requests.get(url=url)
            response.raise_for_status()
            response_json = response.json()
            granted_scopes = response_json["scopes"]
            return granted_scopes
        except Exception as e:
            return False, repr(e)

    @staticmethod
    def get_api(config: Mapping[str, Any]) -> API:
        credentials = config.get("credentials", {})
        return API(credentials=credentials)

    def get_common_params(self, config) -> Mapping[str, Any]:
        start_date = config.get("start_date")
        credentials = config["credentials"]
        api = self.get_api(config=config)
        common_params = dict(api=api,
                             start_date=start_date,
                             credentials=credentials)

        if credentials.get("credentials_title") == "OAuth Credentials":
            common_params["authenticator"] = api.get_authenticator()
        return common_params

    def streams(self, config: Mapping[str, Any]) -> List[Stream]:
        credentials = config.get("credentials", {})
        common_params = self.get_common_params(config=config)
        streams = [
            Campaigns(**common_params),
            Companies(**common_params),
            ContactLists(**common_params),
            Contacts(**common_params),
            ContactsListMemberships(**common_params),
            DealPipelines(**common_params),
            Deals(**common_params),
            EmailEvents(**common_params),
            Engagements(**common_params),
            EngagementsCalls(**common_params),
            EngagementsEmails(**common_params),
            EngagementsMeetings(**common_params),
            EngagementsNotes(**common_params),
            EngagementsTasks(**common_params),
            FeedbackSubmissions(**common_params),
            Forms(**common_params),
            FormSubmissions(**common_params),
            LineItems(**common_params),
            MarketingEmails(**common_params),
            Owners(**common_params),
            Products(**common_params),
            PropertyHistory(**common_params),
            SubscriptionChanges(**common_params),
            Tickets(**common_params),
            TicketPipelines(**common_params),
            Workflows(**common_params),
        ]

        credentials_title = credentials.get("credentials_title")
        if credentials_title == "API Key Credentials":
            streams.append(Quotes(**common_params))

        api = API(credentials=credentials)
        if api.is_oauth2():
            authenticator = API(credentials=credentials).get_authenticator()
            granted_scopes = self.get_granted_scopes(authenticator)
            self.logger.info(
                f"The following scopes were granted: {granted_scopes}")

            available_streams = [
                stream for stream in streams
                if stream.scope_is_granted(granted_scopes)
            ]
            unavailable_streams = [
                stream for stream in streams
                if not stream.scope_is_granted(granted_scopes)
            ]
            self.logger.info(
                f"The following streams are unavailable: {[s.name for s in unavailable_streams]}"
            )
        else:
            self.logger.info(
                "No scopes to grant when authenticating with API key.")
            available_streams = streams

        return available_streams

    def read(
            self,
            logger: logging.Logger,
            config: Mapping[str, Any],
            catalog: ConfiguredAirbyteCatalog,
            state: MutableMapping[str,
                                  Any] = None) -> Iterator[AirbyteMessage]:
        """
        This method is overridden to check whether the stream `quotes` exists in the source, if not skip reading that stream.
        """
        connector_state = copy.deepcopy(state or {})
        logger.info(f"Starting syncing {self.name}")
        config, internal_config = split_config(config)
        # TODO assert all streams exist in the connector
        # get the streams once in case the connector needs to make any queries to generate them
        stream_instances = {s.name: s for s in self.streams(config)}
        self._stream_to_instance_map = stream_instances
        with create_timer(self.name) as timer:
            for configured_stream in catalog.streams:
                stream_instance = stream_instances.get(
                    configured_stream.stream.name)
                if not stream_instance and configured_stream.stream.name == "quotes":
                    logger.warning(
                        "Stream `quotes` does not exist in the source. Skip reading `quotes` stream."
                    )
                    continue
                if not stream_instance:
                    raise KeyError(
                        f"The requested stream {configured_stream.stream.name} was not found in the source. Available streams: {stream_instances.keys()}"
                    )

                try:
                    yield from self._read_stream(
                        logger=logger,
                        stream_instance=stream_instance,
                        configured_stream=configured_stream,
                        connector_state=connector_state,
                        internal_config=internal_config,
                    )
                except Exception as e:
                    logger.exception(
                        f"Encountered an exception while reading stream {configured_stream.stream.name}"
                    )
                    display_message = stream_instance.get_error_display_message(
                        e)
                    if display_message:
                        raise AirbyteTracedException.from_exception(
                            e, message=display_message) from e
                    raise e
                finally:
                    logger.info(f"Finished syncing {self.name}")
                    logger.info(timer.report())

        logger.info(f"Finished syncing {self.name}")

    def _read_incremental(
        self,
        logger: logging.Logger,
        stream_instance: Stream,
        configured_stream: ConfiguredAirbyteStream,
        connector_state: MutableMapping[str, Any],
        internal_config: InternalConfig,
    ) -> Iterator[AirbyteMessage]:
        """
        This method is overridden to checkpoint the latest actual state,
        because stream state is refreshed after reading each batch of records (if need_chunk is True),
        or reading all records in the stream.
        """
        yield from super()._read_incremental(
            logger=logger,
            stream_instance=stream_instance,
            configured_stream=configured_stream,
            connector_state=connector_state,
            internal_config=internal_config,
        )
        stream_state = stream_instance.get_updated_state(
            current_stream_state={}, latest_record={})
        yield self._checkpoint_state(stream_instance, stream_state,
                                     connector_state)
Example #11
def test_source_wrong_credentials():
    source = SourceSquare()
    status, error = source.check_connection(logger=AirbyteLogger(), config={"api_key": "wrong.api.key", "is_sandbox": True})
    assert not status
Example #12
File: api_test.py Project: Mu-L/airbyte
def test_rate_limit_rest(stream_config, stream_api, configured_catalog, state):
    """
    Connector should stop the sync if one stream reached rate limit
    stream_1, stream_2, stream_3, ...
    While reading `stream_1` if 403 (Rate Limit) is received, it should finish that stream with success and stop the sync process.
    Next streams should not be executed.
    """

    stream_1: IncrementalSalesforceStream = generate_stream("Account",
                                                            stream_config,
                                                            stream_api,
                                                            state=state)
    stream_2: IncrementalSalesforceStream = generate_stream("Asset",
                                                            stream_config,
                                                            stream_api,
                                                            state=state)

    stream_1.state_checkpoint_interval = 3
    configure_request_params_mock(stream_1, stream_2)

    source = SourceSalesforce()
    source.streams = Mock()
    source.streams.return_value = [stream_1, stream_2]

    logger = AirbyteLogger()

    next_page_url = "/services/data/v52.0/query/012345"
    response_1 = {
        "done": False,
        "totalSize": 10,
        "nextRecordsUrl": next_page_url,
        "records": [
            {"ID": 1, "LastModifiedDate": "2021-11-15"},
            {"ID": 2, "LastModifiedDate": "2021-11-16"},
            {"ID": 3, "LastModifiedDate": "2021-11-17"},  # checkpoint interval
            {"ID": 4, "LastModifiedDate": "2021-11-18"},
            {"ID": 5, "LastModifiedDate": "2021-11-19"},
        ],
    }
    response_2 = [{
        "errorCode": "REQUEST_LIMIT_EXCEEDED",
        "message": "TotalRequests Limit exceeded."
    }]

    with requests_mock.Mocker() as m:
        m.register_uri("GET",
                       stream_1.path(),
                       json=response_1,
                       status_code=200)
        m.register_uri("GET", next_page_url, json=response_2, status_code=403)

        result = [
            i for i in source.read(logger=logger,
                                   config=stream_config,
                                   catalog=configured_catalog,
                                   state=state)
        ]

        assert stream_1.request_params.called
        assert (
            not stream_2.request_params.called
        ), "The second stream should not be executed, because the first stream finished with Rate Limit."

        records = [item for item in result if item.type == Type.RECORD]
        assert len(records) == 5

        state_record = [item for item in result if item.type == Type.STATE][0]
        assert state_record.state.data["Account"][
            "LastModifiedDate"] == "2021-11-17"
Example #13
File: api_test.py Project: Mu-L/airbyte
def test_rate_limit_bulk(stream_config, stream_api, configured_catalog, state):
    """
    Connector should stop the sync if one stream reached rate limit
    stream_1, stream_2, stream_3, ...
    While reading `stream_1` if 403 (Rate Limit) is received, it should finish that stream with success and stop the sync process.
    Next streams should not be executed.
    """
    stream_1: BulkIncrementalSalesforceStream = generate_stream(
        "Account", stream_config, stream_api)
    stream_2: BulkIncrementalSalesforceStream = generate_stream(
        "Asset", stream_config, stream_api)
    streams = [stream_1, stream_2]
    configure_request_params_mock(stream_1, stream_2)

    stream_1.page_size = 6
    stream_1.state_checkpoint_interval = 5

    source = SourceSalesforce()
    source.streams = Mock()
    source.streams.return_value = streams
    logger = AirbyteLogger()

    json_response = [{
        "errorCode": "REQUEST_LIMIT_EXCEEDED",
        "message": "TotalRequests Limit exceeded."
    }]
    with requests_mock.Mocker() as m:
        for stream in streams:
            creation_responses = []
            for page in [1, 2]:
                job_id = f"fake_job_{page}_{stream.name}"
                creation_responses.append({"json": {"id": job_id}})

                m.register_uri("GET",
                               stream.path() + f"/{job_id}",
                               json={"state": "JobComplete"})

                resp = ["Field1,LastModifiedDate,ID"
                        ] + [f"test,2021-11-0{i},{i}"
                             for i in range(1, 7)]  # 6 records per page

                if page == 1:
                    # Read the first page successfully
                    m.register_uri("GET",
                                   stream.path() + f"/{job_id}/results",
                                   text="\n".join(resp))
                else:
                    # Requesting for results when reading second page should fail with 403 (Rate Limit error)
                    m.register_uri("GET",
                                   stream.path() + f"/{job_id}/results",
                                   status_code=403,
                                   json=json_response)

                m.register_uri("DELETE", stream.path() + f"/{job_id}")

            m.register_uri("POST", stream.path(), creation_responses)

        result = [
            i for i in source.read(logger=logger,
                                   config=stream_config,
                                   catalog=configured_catalog,
                                   state=state)
        ]
        assert stream_1.request_params.called
        assert (
            not stream_2.request_params.called
        ), "The second stream should not be executed, because the first stream finished with Rate Limit."

        records = [item for item in result if item.type == Type.RECORD]
        assert len(records) == 6  # stream page size: 6

        state_record = [item for item in result if item.type == Type.STATE][0]
        assert state_record.state.data["Account"][
            "LastModifiedDate"] == "2021-11-05"  # state checkpoint interval is 5.
Example #14
def test_client_wrong_credentials():
    source = SourcePosthog()
    status, error = source.check_connection(logger=AirbyteLogger(),
                                            config={"api_key": "blahblah"})
    assert not status
Example #15
    def check(self, logger: AirbyteLogger, config: json) -> AirbyteConnectionStatus:
        try:
            if "max_batch_size" in config:
                # Max batch size must be between 1 and 10
                if config["max_batch_size"] > 10 or config["max_batch_size"] < 1:
                    raise Exception("max_batch_size must be between 1 and 10")
            if "max_wait_time" in config:
                # Max wait time must be between 1 and 20
                if config["max_wait_time"] > 20 or config["max_wait_time"] < 1:
                    raise Exception("max_wait_time must be between 1 and 20")

            # Required properties
            queue_url = config["queue_url"]
            logger.debug("Amazon SQS Source Config Check - queue_url: " + queue_url)
            queue_region = config["region"]
            logger.debug("Amazon SQS Source Config Check - region: " + queue_region)
            # Sensitive Properties
            access_key = config["access_key"]
            logger.debug("Amazon SQS Source Config Check - access_key (ends with): " + access_key[-1])
            secret_key = config["secret_key"]
            logger.debug("Amazon SQS Source Config Check - secret_key (ends with): " + secret_key[-1])

            logger.debug("Amazon SQS Source Config Check - Starting connection test ---")
            session = boto3.Session(aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=queue_region)
            sqs = session.resource("sqs")
            queue = sqs.Queue(url=queue_url)
            if hasattr(queue, "attributes"):
                logger.debug("Amazon SQS Source Config Check - Connection test successful ---")
                return AirbyteConnectionStatus(status=Status.SUCCEEDED)
            else:
                return AirbyteConnectionStatus(status=Status.FAILED, message="Amazon SQS Source Config Check - Could not connect to queue")
        except ClientError as e:
            return AirbyteConnectionStatus(status=Status.FAILED, message=f"Amazon SQS Source Config Check - Error in AWS Client: {str(e)}")
        except Exception as e:
            return AirbyteConnectionStatus(
                status=Status.FAILED, message=f"Amazon SQS Source Config Check - An exception occurred: {str(e)}"
            )
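
For reference, an illustrative config payload (placeholder values, not real credentials) containing the keys that the check above reads and validates:

example_config = {
    "queue_url": "https://sqs.eu-west-1.amazonaws.com/123456789012/example-queue",
    "region": "eu-west-1",
    "access_key": "AKIAXXXXXXXXXXXXXXXX",
    "secret_key": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
    "max_batch_size": 10,      # must be between 1 and 10
    "max_wait_time": 20,       # must be between 1 and 20
    "delete_messages": False,  # consumed by read(), shown in a later example
}
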
Example #16
from traceback import format_exc
from typing import Any, Iterable, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union

from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models.airbyte_protocol import SyncMode
from airbyte_cdk.sources.streams import Stream
from wcmatch.glob import GLOBSTAR, SPLIT, globmatch

from .formats.csv_parser import CsvParser
from .formats.parquet_parser import ParquetParser

JSON_TYPES = [
    "string", "number", "integer", "object", "array", "boolean", "null"
]

LOGGER = AirbyteLogger()


class ConfigurationError(Exception):
    """Client mis-configured"""


class FileStream(Stream, ABC):
    @property
    def fileformatparser_map(self):
        """Mapping where every key is equal  'filetype' and values are  corresponding  parser classes."""
        return {
            "csv": CsvParser,
            "parquet": ParquetParser,
        }
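
A short sketch of the lookup that fileformatparser_map enables; the format config shape follows the AbstractFileParser example earlier, and the sample master schema is made up:

format_config = {"filetype": "csv", "delimiter": ","}

# The same mapping the property above returns, spelled out for the sketch.
parser_map = {"csv": CsvParser, "parquet": ParquetParser}

parser_cls = parser_map[format_config["filetype"]]  # -> CsvParser
parser = parser_cls(format=format_config, master_schema={"id": "integer"})
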
Example #17
    def read(
        self, logger: AirbyteLogger, config: json, catalog: ConfiguredAirbyteCatalog, state: Dict[str, any]
    ) -> Generator[AirbyteMessage, None, None]:
        stream_name = self.parse_queue_name(config["queue_url"])
        logger.debug("Amazon SQS Source Read - stream is: " + stream_name)

        # Required properties
        queue_url = config["queue_url"]
        queue_region = config["region"]
        delete_messages = config["delete_messages"]

        # Optional Properties
        max_batch_size = config.get("max_batch_size", 10)
        max_wait_time = config.get("max_wait_time", 20)
        visibility_timeout = config.get("visibility_timeout")
        attributes_to_return = config.get("attributes_to_return")
        if attributes_to_return is None:
            attributes_to_return = ["All"]
        else:
            attributes_to_return = attributes_to_return.split(",")

        # Sensitive Properties
        access_key = config["access_key"]
        secret_key = config["secret_key"]

        logger.debug("Amazon SQS Source Read - Creating SQS connection ---")
        session = boto3.Session(aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=queue_region)
        sqs = session.resource("sqs")
        queue = sqs.Queue(url=queue_url)
        logger.debug("Amazon SQS Source Read - Connected to SQS Queue ---")
        timed_out = False
        while not timed_out:
            try:
                logger.debug("Amazon SQS Source Read - Beginning message poll ---")
                messages = queue.receive_messages(
                    MessageAttributeNames=attributes_to_return, MaxNumberOfMessages=max_batch_size, WaitTimeSeconds=max_wait_time
                )

                if not messages:
                    logger.debug("Amazon SQS Source Read - No messages recieved during poll, time out reached ---")
                    timed_out = True
                    break

                for msg in messages:
                    logger.debug("Amazon SQS Source Read - Message recieved: " + msg.message_id)
                    if visibility_timeout:
                        logger.debug("Amazon SQS Source Read - Setting message visibility timeout: " + msg.message_id)
                        self.change_message_visibility(msg, visibility_timeout)
                        logger.debug("Amazon SQS Source Read - Message visibility timeout set: " + msg.message_id)

                    data = {
                        "id": msg.message_id,
                        "body": msg.body,
                        "attributes": msg.message_attributes,
                    }

                    # TODO: Support a 'BATCH OUTPUT' mode that outputs the full batch in a single AirbyteRecordMessage
                    yield AirbyteMessage(
                        type=Type.RECORD,
                        record=AirbyteRecordMessage(stream=stream_name, data=data, emitted_at=int(datetime.now().timestamp()) * 1000),
                    )
                    if delete_messages:
                        logger.debug("Amazon SQS Source Read - Deleting message: " + msg.message_id)
                        self.delete_message(msg)
                        logger.debug("Amazon SQS Source Read - Message deleted: " + msg.message_id)
                        # TODO: Delete messages in batches to reduce amount of requests?

            except ClientError as error:
                raise Exception("Error in AWS Client: " + str(error))
Example #18
    def check(self, logger: AirbyteLogger,
              config: json) -> AirbyteConnectionStatus:
        # Check involves verifying that the specified spreadsheet is reachable with our credentials.
        try:
            client = GoogleSheetsClient(self.get_credentials(config))
        except Exception as e:
            return AirbyteConnectionStatus(
                status=Status.FAILED,
                message=f"Please use valid credentials json file. Error: {e}")

        spreadsheet_id = Helpers.get_spreadsheet_id(config["spreadsheet_id"])

        try:
            # Attempt to get first row of sheet
            client.get(spreadsheetId=spreadsheet_id,
                       includeGridData=False,
                       ranges="1:1")
        except errors.HttpError as err:
            reason = str(err)
            # Give a clearer message if it's a common error like 404.
            if err.resp.status == status_codes.NOT_FOUND:
                reason = "Requested spreadsheet was not found."
            logger.error(f"Formatted error: {reason}")
            return AirbyteConnectionStatus(
                status=Status.FAILED,
                message=
                f"Unable to connect with the provided credentials to spreadsheet. Error: {reason}"
            )

        # Check for duplicate headers
        spreadsheet_metadata = Spreadsheet.parse_obj(
            client.get(spreadsheetId=spreadsheet_id, includeGridData=False))

        grid_sheets = Helpers.get_grid_sheets(spreadsheet_metadata)

        duplicate_headers_in_sheet = {}
        for sheet_name in grid_sheets:
            try:
                header_row_data = Helpers.get_first_row(
                    client, spreadsheet_id, sheet_name)
                _, duplicate_headers = Helpers.get_valid_headers_and_duplicates(
                    header_row_data)
                if duplicate_headers:
                    duplicate_headers_in_sheet[sheet_name] = duplicate_headers
            except Exception as err:
                if str(err).startswith(
                        "Expected data for exactly one row for sheet"):
                    logger.warn(f"Skip empty sheet: {sheet_name}")
                else:
                    logger.error(str(err))
                    return AirbyteConnectionStatus(
                        status=Status.FAILED,
                        message=
                        f"Unable to read the schema of sheet {sheet_name}. Error: {str(err)}"
                    )
        if duplicate_headers_in_sheet:
            duplicate_headers_error_message = ", ".join([
                f"[sheet:{sheet_name}, headers:{duplicate_sheet_headers}]"
                for sheet_name, duplicate_sheet_headers in
                duplicate_headers_in_sheet.items()
            ])
            return AirbyteConnectionStatus(
                status=Status.FAILED,
                message=
                "The following duplicate headers were found in the following sheets. Please fix them to continue: "
                + duplicate_headers_error_message,
            )

        return AirbyteConnectionStatus(status=Status.SUCCEEDED)
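
The duplicate-header handling above delegates to Helpers.get_valid_headers_and_duplicates, whose implementation is not shown here. The standalone sketch below illustrates the same idea and is not the real helper:

from collections import Counter
from typing import List, Tuple


def find_duplicate_headers(header_row: List[str]) -> Tuple[List[str], List[str]]:
    counts = Counter(header_row)
    valid = [header for header, count in counts.items() if count == 1]
    duplicates = [header for header, count in counts.items() if count > 1]
    return valid, duplicates


find_duplicate_headers(["id", "name", "name", "email"])  # -> (["id", "email"], ["name"])
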
Example #19
class ChargebeeStream(Stream):
    supports_incremental = True
    primary_key = "id"
    default_cursor_field = "updated_at"
    logger = AirbyteLogger()

    def __init__(self):
        self.next_offset = None
        # Request params below
        # according to Chargebee's guidance on pagination
        # https://apidocs.chargebee.com/docs/api/#pagination_and_filtering
        self.params = {
            "limit": 100,  # Limit at 100
            "sort_by[asc]":
            self.default_cursor_field,  # Sort ascending by updated_at
        }
        super().__init__()

    @backoff.on_exception(
        backoff.expo,  # Exponential back-off
        OperationFailedError,  # Only on Chargebee's OperationFailedError
        max_tries=MAX_TRIES,
        max_time=MAX_TIME,
    )
    def _send_request(self) -> ListResult:
        """
        Just a wrapper to allow @backoff decorator
        Reference: https://apidocs.chargebee.com/docs/api/#error_codes_list
        """
        # From Chargebee
        # Link: https://apidocs.chargebee.com/docs/api/#api_rate_limits
        list_result = self.api.list(self.params)
        return list_result

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: List[str] = None,
        stream_slice: Mapping[str, Any] = None,
        stream_state: Mapping[str, Any] = None,
    ) -> Iterable[Mapping[str, Any]]:
        """
        Override airbyte_cdk Stream's read_records method
        """
        # Add offset to params if found
        # Reference for Chargebee's pagination strategy below:
        # https://apidocs.chargebee.com/docs/api/#pagination_and_filtering
        pagination_completed = False
        if stream_state:
            self.params.update(stream_state)
        # Loop until pagination is completed
        while not pagination_completed:
            # Request the ListResult object from Chargebee
            # with back-off implemented through self._send_request()
            list_result = self._send_request()
            # Read message from results
            for message in list_result:
                yield message._response[self.name]
            # Get next page token
            self.next_offset = list_result.next_offset
            if self.next_offset:
                self.params.update({"offset": self.next_offset})
            else:
                pagination_completed = True

        # Always return an empty generator just in case no records were ever yielded
        yield from []

    def get_updated_state(
        self,
        current_stream_state: MutableMapping[str, Any],
        latest_record: Mapping[str, Any],
    ):
        """
        Override airbyte_cdk Stream's get_updated_state method
        to get the latest Chargebee stream state
        """
        # Init the current_stream_state
        current_stream_state = current_stream_state or {}

        # Get current timestamp
        # so Stream will sync all records
        # that have been updated before now
        now = pendulum.now().int_timestamp
        current_stream_state.update({
            "update_at[before]": now,
        })

        # Get the updated_at field from the latest record
        # using Chargebee's Model class
        # so Stream will sync all records
        # that have been updated since then
        latest_updated_at = latest_record.get("updated_at")
        if latest_updated_at:
            current_stream_state.update({
                "update_at[after]": latest_updated_at,
            })

        return current_stream_state
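
A quick worked illustration (run in the context of the stream module above; the record values are invented) of the state object produced by get_updated_state: the current timestamp bounds the sync window from above, and the latest record's updated_at bounds it from below.

stream = ChargebeeStream()
state = stream.get_updated_state(
    current_stream_state={},
    latest_record={"id": "sub_123", "updated_at": 1637000000},
)
# state now holds a "before" bound set to the current epoch seconds and an
# "after" bound set to the record's updated_at value (1637000000).
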