def test_source_wrong_credentials():
    """check_connection must report a falsy status for an invalid Sendgrid API key."""
    invalid_config = {"apikey": "wrong.api.key123"}
    sendgrid_source = SourceSendgrid()
    # The second tuple element (the error) is intentionally not asserted on.
    status, _error = sendgrid_source.check_connection(logger=AirbyteLogger(), config=invalid_config)
    assert not status
def _write_config(self, token):
    """
    Emit an info-level log line stating that credentials were refreshed.

    :param token: accepted for interface compatibility; the visible body does not read it
    """
    # A fresh logger is created on every call, exactly as before.
    AirbyteLogger().info("Credentials Refreshed")
class Error(Exception): """Base Error class for other exceptions""" # Define the instance of the Native Airbyte Logger logger = AirbyteLogger()
import sys import backoff from airbyte_cdk.logger import AirbyteLogger from airbyte_cdk.sources.streams.http.exceptions import DefaultBackoffException from requests import codes, exceptions TRANSIENT_EXCEPTIONS = ( DefaultBackoffException, exceptions.ConnectTimeout, exceptions.ReadTimeout, exceptions.ConnectionError, exceptions.HTTPError, ) logger = AirbyteLogger() def default_backoff_handler(max_tries: int, factor: int, **kwargs): def log_retry_attempt(details): _, exc, _ = sys.exc_info() logger.info(str(exc)) logger.info( f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." ) def should_give_up(exc): give_up = exc.response is not None and exc.response.status_code != codes.too_many_requests and 400 <= exc.response.status_code < 500 # Salesforce can return an error with a limit using a 403 code error. if exc.response is not None and exc.response.status_code == codes.forbidden:
class AbstractFileParser(ABC):
    """
    Abstract base for format-specific file parsers.

    Subclasses declare whether files are opened as binary, infer a JsonSchema-typed schema from a file
    and stream records from it. The static helpers convert between JsonSchema and PyArrow type names.
    """

    logger = AirbyteLogger()

    # This is a map of airbyte (JsonSchema) types to pyarrow types.
    # The first element of each tuple of pyarrow types is the one to use where a single type is required.
    # Built once at class creation instead of on every conversion call.
    _JSON_TO_PYARROW_TYPES = {
        "boolean": ("bool_", "bool"),
        "integer": ("int64", "int8", "int16", "int32", "uint8", "uint16", "uint32", "uint64"),
        "number": ("float64", "float16", "float32", "decimal128", "decimal256", "halffloat", "float", "double"),
        "string": ("large_string", "string"),
        # TODO: support object type rather than coercing to string
        "object": ("large_string",),
        # TODO: support array type rather than coercing to string
        "array": ("large_string",),
        "null": ("large_string",),
    }

    def __init__(self, format: dict, master_schema: dict = None):
        """
        :param format: file format specific mapping as described in spec.json
        :param master_schema: superset schema determined from all files, might be unused for some formats, defaults to None
        """
        self._format = format
        # this may need to be used differently by some formats, pyarrow allows extra columns in csv schema
        self._master_schema = master_schema

    @property
    @abstractmethod
    def is_binary(self) -> bool:
        """
        Override this per format so that file-like objects passed in are currently opened as binary or not
        """

    @abstractmethod
    def get_inferred_schema(self, file: Union[TextIO, BinaryIO]) -> dict:
        """
        Override this with format-specific logic to infer the schema of file
        Note: needs to return inferred schema with JsonSchema datatypes

        :param file: file-like object (opened via StorageFile)
        :return: mapping of {columns:datatypes} where datatypes are JsonSchema types
        """

    @abstractmethod
    def stream_records(self, file: Union[TextIO, BinaryIO], file_info: FileInfo) -> Iterator[Mapping[str, Any]]:
        """
        Override this with format-specific logic to stream each data row from the file as a mapping of {columns:values}
        Note: avoid loading the whole file into memory to avoid OOM breakages

        :param file: file-like object (opened via StorageFile)
        :param file_info: file metadata
        :yield: data record as a mapping of {columns:values}
        """

    @staticmethod
    def json_type_to_pyarrow_type(typ: str, reverse: bool = False, logger: AirbyteLogger = AirbyteLogger()) -> str:
        """
        Converts Json Type to PyArrow types to (or the other way around if reverse=True)

        :param typ: Json type if reverse is False, else PyArrow type
        :param reverse: switch to True for PyArrow type -> Json type, defaults to False
        :param logger: defaults to AirbyteLogger()
        :return: PyArrow type if reverse is False, else Json type
        """
        str_typ = str(typ)
        type_map = AbstractFileParser._JSON_TO_PYARROW_TYPES

        if not reverse:
            # Direct key lookup; the json type is matched case-insensitively.
            pyarrow_types = type_map.get(str_typ.lower())
            if pyarrow_types is not None:
                # better way might be necessary when we decide to handle more type complexity
                return str(getattr(pa, pyarrow_types[0])())
            logger.debug(f"JSON type '{str_typ}' is not mapped, falling back to default conversion to large_string")
            return str(pa.large_string())

        # Reverse direction: prefix match so parametrised pyarrow type strings still resolve to their family.
        # Dict insertion order decides which json type wins when several tuples contain a matching prefix.
        for json_type, pyarrow_types in type_map.items():
            if str_typ.startswith(pyarrow_types):
                return json_type
        logger.debug(f"PyArrow type '{str_typ}' is not mapped, falling back to default conversion to string")
        return "string"  # default type if unspecified in map

    @staticmethod
    def json_schema_to_pyarrow_schema(schema: Mapping[str, Any], reverse: bool = False) -> Mapping[str, Any]:
        """
        Converts a schema with JsonSchema datatypes to one with PyArrow types (or the other way if reverse=True)
        This utilises json_type_to_pyarrow_type() to convert each datatype

        :param schema: json/pyarrow schema to convert
        :param reverse: switch to True for PyArrow schema -> Json schema, defaults to False
        :return: converted schema dict
        """
        return {
            column: AbstractFileParser.json_type_to_pyarrow_type(json_type, reverse=reverse)
            for column, json_type in schema.items()
        }
class Client:
    """
    Wrapper around the Bing Ads SDK used by the connector.

    It refreshes the OAuth access token before it expires, caches service clients per
    (service, customer, account) and retries failed requests with exponential backoff.
    """

    api_version: int = 13
    refresh_token_safe_delta: int = 10  # in seconds
    logger: AirbyteLogger = AirbyteLogger()
    # retry on: rate limit errors, auth token expiration, internal errors
    # https://docs.microsoft.com/en-us/advertising/guides/services-protocol?view=bingads-13#throttling
    # https://docs.microsoft.com/en-us/advertising/guides/operation-error-codes?view=bingads-13
    retry_on_codes: Iterator[str] = ["117", "207", "4204", "109", "0"]
    max_retries: int = 3
    # A backoff factor to apply between attempts after the second try
    # {retry_factor} * (2 ** ({number of total retries} - 1))
    retry_factor: int = 15
    # environments supported by Microsoft Advertising: sandbox, production
    environment: str = "production"
    # The time interval in milliseconds between two status polling attempts.
    report_poll_interval: int = 15000

    def __init__(
        self,
        tenant_id: str,
        reports_start_date: str,
        developer_token: str = None,
        client_id: str = None,
        client_secret: str = None,
        refresh_token: str = None,
        **kwargs: Mapping[str, Any],
    ) -> None:
        """
        Build the OAuth client and immediately exchange the refresh token for an access token.

        :param tenant_id: tenant passed to the OAuth client
        :param reports_start_date: date string, parsed with pendulum and converted to UTC
        :param developer_token: developer token attached to every AuthorizationData
        :param client_id: OAuth application id
        :param client_secret: OAuth secret; only forwarded to the SDK when non-empty
        :param refresh_token: refresh token used to request access tokens
        :param kwargs: accepted and ignored
        """
        self.authorization_data: Mapping[str, AuthorizationData] = {}
        self.refresh_token = refresh_token
        self.developer_token = developer_token

        self.client_id = client_id
        self.client_secret = client_secret

        self.authentication = self._get_auth_client(client_id, tenant_id, client_secret)
        self.oauth: OAuthTokens = self._get_access_token()
        self.reports_start_date = pendulum.parse(reports_start_date).astimezone(tz=timezone.utc)

    def _get_auth_client(self, client_id: str, tenant_id: str, client_secret: str = None) -> OAuthWebAuthCodeGrant:
        """Create the SDK OAuth client; `client_secret` is set only when a non-empty value is given."""
        # https://github.com/BingAds/BingAds-Python-SDK/blob/e7b5a618e87a43d0a5e2c79d9aa4626e208797bd/bingads/authorization.py#L390
        auth_creds = {
            "client_id": client_id,
            "redirection_uri": "",  # should be empty string
            "client_secret": None,
            "tenant": tenant_id,
        }
        # the `client_secret` should be provided for `non-public clients` only
        # https://docs.microsoft.com/en-us/advertising/guides/authentication-oauth-get-tokens?view=bingads-13#request-accesstoken
        if client_secret:
            auth_creds["client_secret"] = client_secret
        return OAuthWebAuthCodeGrant(**auth_creds)

    @lru_cache(maxsize=None)
    def _get_auth_data(self, customer_id: str = None, account_id: Optional[str] = None) -> AuthorizationData:
        """Return (cached) AuthorizationData for the given customer/account pair."""
        return AuthorizationData(
            account_id=account_id,
            customer_id=customer_id,
            developer_token=self.developer_token,
            authentication=self.authentication,
        )

    def _get_access_token(self) -> OAuthTokens:
        """Request new OAuth tokens via the refresh token, dropping every cache that depends on auth data."""
        self.logger.info("Fetching access token ...")
        # clear caches to be able to use new access token
        self.get_service.cache_clear()
        # the reporting service is built from the same cached auth data, so it is dropped as well
        self._get_reporting_service.cache_clear()
        self._get_auth_data.cache_clear()
        return self.authentication.request_oauth_tokens_by_refresh_token(self.refresh_token)

    def is_token_expiring(self) -> bool:
        """
        Performs check if access token expiring in less than refresh_token_safe_delta seconds
        """
        token_age: timedelta = datetime.utcnow() - self.oauth.access_token_received_datetime
        # total_seconds() is required here: `timedelta.seconds` drops whole days, which would make
        # a token older than 24 hours look freshly issued.
        remaining_seconds: float = self.oauth.access_token_expires_in_seconds - token_age.total_seconds()
        return remaining_seconds <= self.refresh_token_safe_delta

    def should_retry(self, error: WebFault) -> bool:
        """
        Give-up predicate passed to backoff.

        NOTE: despite the name, this returns True when the error code is NOT in `retry_on_codes`,
        i.e. when retrying should stop.
        """
        error_code = str(errorcode_of_exception(error))
        give_up = error_code not in self.retry_on_codes
        if give_up:
            self.logger.error(f"Giving up for returned error code: {error_code}. Error details: {self._get_error_message(error)}")
        return give_up

    def _get_error_message(self, error: WebFault) -> str:
        """Render the fault of a WebFault as a string, falling back to str(error) when there is no `fault`."""
        return str(self.asdict(error.fault)) if hasattr(error, "fault") else str(error)

    def log_retry_attempt(self, details: Mapping[str, Any]) -> None:
        """Backoff `on_backoff` hook: log the exception currently being handled plus try/wait details."""
        _, exc, _ = sys.exc_info()
        self.logger.info(
            f"Caught retryable error: {self._get_error_message(exc)} after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        )

    def request(self, **kwargs: Mapping[str, Any]) -> Mapping[str, Any]:
        """Run `_request(**kwargs)`, retrying on WebFault with exponential backoff up to `max_retries` tries."""
        return backoff.on_exception(
            backoff.expo,
            WebFault,
            max_tries=self.max_retries,
            factor=self.retry_factor,
            jitter=None,
            on_backoff=self.log_retry_attempt,
            giveup=self.should_retry,
        )(self._request)(**kwargs)

    def _request(
        self,
        service_name: Optional[str],
        operation_name: str,
        customer_id: Optional[str],
        account_id: Optional[str],
        params: Mapping[str, Any],
        is_report_service: bool = False,
    ) -> Mapping[str, Any]:
        """
        Executes appropriate Service Operation on Bing Ads API
        """
        if self.is_token_expiring():
            self.oauth = self._get_access_token()

        if is_report_service:
            service = self._get_reporting_service(customer_id=customer_id, account_id=account_id)
        else:
            service = self.get_service(service_name=service_name, customer_id=customer_id, account_id=account_id)

        return getattr(service, operation_name)(**params)

    @lru_cache(maxsize=None)
    def get_service(
        self,
        service_name: str,
        customer_id: str = None,
        account_id: Optional[str] = None,
    ) -> ServiceClient:
        """Return a (cached) ServiceClient for the given service and customer/account pair."""
        return ServiceClient(
            service=service_name,
            version=self.api_version,
            authorization_data=self._get_auth_data(customer_id, account_id),
            environment=self.environment,
        )

    @lru_cache(maxsize=None)
    def _get_reporting_service(
        self,
        customer_id: Optional[str] = None,
        account_id: Optional[str] = None,
    ) -> ReportingServiceManager:
        """Return a (cached) ReportingServiceManager for the given customer/account pair."""
        return ReportingServiceManager(
            authorization_data=self._get_auth_data(customer_id, account_id),
            poll_interval_in_milliseconds=self.report_poll_interval,
            environment=self.environment,
        )

    @classmethod
    def asdict(cls, suds_object: sudsobject.Object) -> Mapping[str, Any]:
        """
        Converts nested Suds Object into serializable format.
        Input sample:
        {
            obj[] =
                {
                    value = 1
                },
                {
                    value = "str"
                },
        }
        Output sample: =>
        {'obj': [{'value': 1}, {'value': 'str'}]}
        """
        result = {}

        for field, val in sudsobject.asdict(suds_object).items():
            if hasattr(val, "__keylist__"):
                # nested suds object
                result[field] = cls.asdict(val)
            elif isinstance(val, list):
                # convert nested suds objects inside lists, keep every other item untouched
                result[field] = [cls.asdict(item) if hasattr(item, "__keylist__") else item for item in val]
            elif isinstance(val, datetime):
                result[field] = val.isoformat()
            else:
                result[field] = val
        return result
class Stream(ABC):
    """
    Abstract base for every Airbyte Stream. Nothing here depends on the transport protocol the stream uses.
    """

    # Use self.logger in subclasses to log any messages
    logger = AirbyteLogger()  # TODO use native "logging" loggers with custom handlers

    @property
    def name(self) -> str:
        """
        :return: The stream name: the snake_case form of the implementing class name unless overridden.
        """
        class_name = self.__class__.__name__
        return casing.camel_to_snake(class_name)

    @abstractmethod
    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: List[str] = None,
        stream_slice: Mapping[str, Any] = None,
        stream_state: Mapping[str, Any] = None,
    ) -> Iterable[Mapping[str, Any]]:
        """
        Subclasses implement this to read records for the given inputs.
        """

    def get_json_schema(self) -> Mapping[str, Any]:
        """
        :return: The JSON schema of this stream as a dict.

        By default the schema is loaded through ResourceSchemaLoader, using the package of the implementing
        class and this stream's "name" property. Override as needed.
        """
        # TODO show an example of using pydantic to define the JSON schema, or reading an OpenAPI spec
        package_name = package_name_from_class(self.__class__)
        schema_loader = ResourceSchemaLoader(package_name)
        return schema_loader.get_schema(self.name)

    def as_airbyte_stream(self) -> AirbyteStream:
        """
        Describe this stream as an AirbyteStream: name, schema, supported sync modes, cursor and primary key.
        """
        airbyte_stream = AirbyteStream(
            name=self.name,
            json_schema=dict(self.get_json_schema()),
            supported_sync_modes=[SyncMode.full_refresh],
        )

        if self.supports_incremental:
            airbyte_stream.source_defined_cursor = self.source_defined_cursor
            airbyte_stream.supported_sync_modes.append(SyncMode.incremental)  # type: ignore
            airbyte_stream.default_cursor_field = self._wrapped_cursor_field()

        wrapped_keys = Stream._wrapped_primary_key(self.primary_key)
        # None and the empty list are both skipped
        if wrapped_keys:
            airbyte_stream.source_defined_primary_key = wrapped_keys

        return airbyte_stream

    @property
    def supports_incremental(self) -> bool:
        """
        :return: True when a non-empty cursor field is defined, i.e. the stream can be read incrementally
        """
        wrapped_cursor = self._wrapped_cursor_field()
        return len(wrapped_cursor) > 0

    def _wrapped_cursor_field(self) -> List[str]:
        """Return the cursor field as a list, wrapping a plain string cursor into a one-element list."""
        if isinstance(self.cursor_field, str):
            return [self.cursor_field]
        return self.cursor_field

    @property
    def cursor_field(self) -> Union[str, List[str]]:
        """
        Override to return the default cursor field of this stream, e.g. an API entity that always uses
        created_at as its cursor.

        :return: The cursor field name; for a nested cursor, a list holding the path to it.
        """
        return []

    @property
    def source_defined_cursor(self) -> bool:
        """
        Return False if the user is allowed to configure the cursor.
        """
        return True

    @property
    @abstractmethod
    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
        """
        :return: a string for a single primary key, a list of strings for a composite primary key, a list of
          lists of strings for a composite primary key made of nested fields, or None when the stream has no
          primary key.
        """

    def stream_slices(
        self, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None
    ) -> Iterable[Optional[Mapping[str, Any]]]:
        """
        Override to define the slices of this stream; see the stream slicing section of the docs.
        The default is a single `None` slice.

        :param stream_state:
        :return:
        """
        default_slices = [None]
        return default_slices

    @property
    def state_checkpoint_interval(self) -> Optional[int]:
        """
        How often state is checkpointed (i.e. a STATE message is emitted).

        With a value of 100, state is persisted after 100 records, then 200, 300, and so on. 1000 is a
        reasonable default, although the best value depends on the underlying data source.

        Checkpointing lets a failed or cancelled sync avoid re-reading records.

        Return None when state must not be checkpointed, e.g. because the source does not return records in
        ascending cursor order (created_at or whatever the cursor is). State must then only be saved once the
        whole stream has been read.
        """
        return None

    def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]):
        """
        Override to derive state from the latest record; required for incremental sync.

        Compares the latest extracted record with the current state object and returns the updated state.

        Example: with a created_at based state of {'created_at': 10} and a latest_record of
        {'name': 'octavia', 'created_at': 20}, the method would return {'created_at': 20}.

        :param current_stream_state: The stream's current state object
        :param latest_record: The latest record extracted from the stream
        :return: An updated state object
        """
        return {}

    @staticmethod
    def _wrapped_primary_key(keys: Optional[Union[str, List[str], List[List[str]]]]) -> Optional[List[List[str]]]:
        """
        :return: the primary_key value wrapped into the list of lists of strings the Airbyte Stream object expects.
        """
        if not keys:
            return None

        if isinstance(keys, str):
            return [[keys]]

        if not isinstance(keys, list):
            raise ValueError("Element must be either list or str.")

        wrapped_keys = []
        for component in keys:
            if isinstance(component, str):
                wrapped_keys.append([component])
            elif isinstance(component, list):
                wrapped_keys.append(component)
            else:
                raise ValueError("Element must be either list or str.")
        return wrapped_keys
def check_config(self, logger: AirbyteLogger, config_path: str, config: json) -> AirbyteConnectionStatus:
    """
    Tests if the input configuration can be used to successfully connect to the integration
        e.g: if a provided Stripe API token can be used to connect to the Stripe API.

    The check succeeds as soon as at least one of the tested reports can be downloaded; it fails only
    when every report fails (or when an unexpected exception is raised).

    :param logger: Logging object to display debug/info/error to the logs
        (logs will not be accessible via airbyte UI if they are not passed to this logger)
    :param config_path: Path to the file containing the configuration json config
    :param config: Json object containing the configuration of this source, content of this json is as specified in
    the properties of the spec.json file

    :return: AirbyteConnectionStatus indicating a Success or Failure
    """
    try:
        # If an app on the appstore does not support subscriptions or sales, it cannot pull the relevant reports.
        # However, the way the Appstore API expresses this is not via clear error messages. Instead it expresses it by throwing an unrelated
        # error, in this case "invalid vendor ID". There is no way to distinguish if this error is due to invalid credentials or due to
        # the account not supporting this kind of report. So to "check connection" we see if any of the reports can be pulled and if so
        # return success. If no reports can be pulled we display the exception messages generated for all reports and return failure.
        api_fields_to_test = {
            "subscription_event_report": {
                "reportType": "SUBSCRIPTION_EVENT",
                "frequency": "DAILY",
                "reportSubType": "SUMMARY",
                "version": "1_2",
            },
            "subscriber_report": {"reportType": "SUBSCRIBER", "frequency": "DAILY", "reportSubType": "DETAILED", "version": "1_2"},
            "subscription_report": {"reportType": "SUBSCRIPTION", "frequency": "DAILY", "reportSubType": "SUMMARY", "version": "1_2"},
            "sales_report": {"reportType": "SALES", "frequency": "DAILY", "reportSubType": "SUMMARY", "version": "1_0"},
        }

        api = Api(config["key_id"], config["key_file"], config["issuer_id"])
        # Every report is requested for the same day (two days ago), so compute it once.
        report_date = (date.today() - timedelta(days=2)).strftime("%Y-%m-%d")

        stream_to_error = {}
        for stream, params in api_fields_to_test.items():
            report_filters = {"reportDate": report_date, "vendorNumber": f"{config['vendor']}"}
            report_filters.update(params)
            try:
                rep_tsv = api.download_sales_and_trends_reports(filters=report_filters)
                if isinstance(rep_tsv, dict):
                    raise Exception(f"An exception occurred: Received a JSON response instead of" f" the report: {str(rep_tsv)}")
            except Exception as e:
                logger.warn(f"Unable to download {stream}: {e}")
                stream_to_error[stream] = e

        # All streams have failed.
        # Compare the two sizes: the previous code compared an int with a dict view, which is never
        # equal, so this failure branch could not be reached.
        if len(stream_to_error) == len(api_fields_to_test):
            # .items() is needed to get (stream, error) pairs; iterating the dict yields only keys.
            message = "\n".join(f"Unable to access {stream} due to error: {e}" for stream, e in stream_to_error.items())
            return AirbyteConnectionStatus(status=Status.FAILED, message=message)

        return AirbyteConnectionStatus(status=Status.SUCCEEDED)
    except Exception as e:
        logger.warn(e)
        return AirbyteConnectionStatus(status=Status.FAILED, message=f"An exception occurred: {str(e)}")
def logger() -> AirbyteLogger:
    """Create and return a new AirbyteLogger instance."""
    airbyte_logger = AirbyteLogger()
    return airbyte_logger
class SourceHubspot(AbstractSource):
    """
    Hubspot source: checks the connection, builds the list of streams (filtered by granted OAuth scopes)
    and overrides `read` / `_read_incremental` of the base source.
    """

    logger = AirbyteLogger()

    def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
        """
        Check connection by reading the properties of the Contacts stream.

        :param logger: not used by this implementation
        :param config: connector configuration
        :return: (True, None) on success, (False, repr(error)) when an HTTPError is raised
        """
        common_params = self.get_common_params(config=config)
        alive = True
        error_msg = None
        try:
            contacts = Contacts(**common_params)
            _ = contacts.properties
        except HTTPError as error:
            alive = False
            error_msg = repr(error)

        return alive, error_msg

    def get_granted_scopes(self, authenticator):
        """
        Ask the Hubspot access-token endpoint which scopes the current access token has.

        :param authenticator: object exposing `get_access_token()`
        :return: the "scopes" value of the response
        """
        try:
            access_token = authenticator.get_access_token()
            url = f"https://api.hubapi.com/oauth/v1/access-tokens/{access_token}"
            response = requests.get(url=url)
            response.raise_for_status()
            response_json = response.json()
            granted_scopes = response_json["scopes"]
            return granted_scopes
        except Exception as e:
            # NOTE(review): on failure this returns a (False, repr(e)) tuple, which `streams` then passes
            # to `scope_is_granted` as if it were a collection of scopes. Kept unchanged because callers
            # may rely on it - confirm and consider raising or returning an empty list instead.
            return False, repr(e)

    @staticmethod
    def get_api(config: Mapping[str, Any]) -> API:
        """Build an API object from the "credentials" section of the config (empty dict when missing)."""
        credentials = config.get("credentials", {})
        return API(credentials=credentials)

    def get_common_params(self, config) -> Mapping[str, Any]:
        """
        Build the keyword arguments shared by all stream constructors.

        An "authenticator" entry is added only for "OAuth Credentials".
        """
        start_date = config.get("start_date")
        credentials = config["credentials"]
        api = self.get_api(config=config)
        common_params = dict(api=api, start_date=start_date, credentials=credentials)

        if credentials.get("credentials_title") == "OAuth Credentials":
            common_params["authenticator"] = api.get_authenticator()
        return common_params

    def streams(self, config: Mapping[str, Any]) -> List[Stream]:
        """
        Return the streams available for the given config.

        `Quotes` is only added for API key credentials. With OAuth, streams whose scope is not granted
        are filtered out (and logged).
        """
        credentials = config.get("credentials", {})
        common_params = self.get_common_params(config=config)
        streams = [
            Campaigns(**common_params),
            Companies(**common_params),
            ContactLists(**common_params),
            Contacts(**common_params),
            ContactsListMemberships(**common_params),
            DealPipelines(**common_params),
            Deals(**common_params),
            EmailEvents(**common_params),
            Engagements(**common_params),
            EngagementsCalls(**common_params),
            EngagementsEmails(**common_params),
            EngagementsMeetings(**common_params),
            EngagementsNotes(**common_params),
            EngagementsTasks(**common_params),
            FeedbackSubmissions(**common_params),
            Forms(**common_params),
            FormSubmissions(**common_params),
            LineItems(**common_params),
            MarketingEmails(**common_params),
            Owners(**common_params),
            Products(**common_params),
            PropertyHistory(**common_params),
            SubscriptionChanges(**common_params),
            Tickets(**common_params),
            TicketPipelines(**common_params),
            Workflows(**common_params),
        ]

        credentials_title = credentials.get("credentials_title")

        if credentials_title == "API Key Credentials":
            streams.append(Quotes(**common_params))

        api = API(credentials=credentials)
        if api.is_oauth2():
            # reuse the API instance created above instead of building a second one
            authenticator = api.get_authenticator()
            granted_scopes = self.get_granted_scopes(authenticator)
            self.logger.info(f"The following scopes were granted: {granted_scopes}")

            # single pass: each stream's scope is checked once
            available_streams = []
            unavailable_streams = []
            for stream in streams:
                if stream.scope_is_granted(granted_scopes):
                    available_streams.append(stream)
                else:
                    unavailable_streams.append(stream)
            self.logger.info(f"The following streams are unavailable: {[s.name for s in unavailable_streams]}")
        else:
            self.logger.info("No scopes to grant when authenticating with API key.")
            available_streams = streams

        return available_streams

    def read(
        self, logger: logging.Logger, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: MutableMapping[str, Any] = None
    ) -> Iterator[AirbyteMessage]:
        """
        This method is overridden to check whether the stream `quotes` exists in the source, if not skip reading that stream.

        :raises KeyError: when a configured stream other than `quotes` is not found in the source
        :raises AirbyteTracedException: when reading a stream fails and the stream provides a display message
        """
        connector_state = copy.deepcopy(state or {})
        logger.info(f"Starting syncing {self.name}")
        config, internal_config = split_config(config)
        # TODO assert all streams exist in the connector
        # get the streams once in case the connector needs to make any queries to generate them
        stream_instances = {s.name: s for s in self.streams(config)}
        self._stream_to_instance_map = stream_instances
        with create_timer(self.name) as timer:
            for configured_stream in catalog.streams:
                stream_instance = stream_instances.get(configured_stream.stream.name)
                if not stream_instance and configured_stream.stream.name == "quotes":
                    logger.warning("Stream `quotes` does not exist in the source. Skip reading `quotes` stream.")
                    continue
                if not stream_instance:
                    raise KeyError(
                        f"The requested stream {configured_stream.stream.name} was not found in the source. Available streams: {stream_instances.keys()}"
                    )

                try:
                    yield from self._read_stream(
                        logger=logger,
                        stream_instance=stream_instance,
                        configured_stream=configured_stream,
                        connector_state=connector_state,
                        internal_config=internal_config,
                    )
                except Exception as e:
                    logger.exception(f"Encountered an exception while reading stream {configured_stream.stream.name}")
                    display_message = stream_instance.get_error_display_message(e)
                    if display_message:
                        raise AirbyteTracedException.from_exception(e, message=display_message) from e
                    # bare raise re-raises the active exception
                    raise
                finally:
                    logger.info(f"Finished syncing {self.name}")
                    logger.info(timer.report())

        logger.info(f"Finished syncing {self.name}")

    def _read_incremental(
        self,
        logger: logging.Logger,
        stream_instance: Stream,
        configured_stream: ConfiguredAirbyteStream,
        connector_state: MutableMapping[str, Any],
        internal_config: InternalConfig,
    ) -> Iterator[AirbyteMessage]:
        """
        This method is overridden to checkpoint the latest actual state,
        because stream state is refreshed after reading each batch of records (if need_chunk is True),
        or reading all records in the stream.
        """
        yield from super()._read_incremental(
            logger=logger,
            stream_instance=stream_instance,
            configured_stream=configured_stream,
            connector_state=connector_state,
            internal_config=internal_config,
        )
        # after the parent finished, emit one more checkpoint with the state reported by the stream
        stream_state = stream_instance.get_updated_state(current_stream_state={}, latest_record={})
        yield self._checkpoint_state(stream_instance, stream_state, connector_state)
def test_source_wrong_credentials():
    """check_connection must report a falsy status for an invalid Square API key (sandbox mode)."""
    invalid_config = {"api_key": "wrong.api.key", "is_sandbox": True}
    square_source = SourceSquare()
    # The second tuple element (the error) is intentionally not asserted on.
    status, _error = square_source.check_connection(logger=AirbyteLogger(), config=invalid_config)
    assert not status
def test_rate_limit_rest(stream_config, stream_api, configured_catalog, state):
    """
    The connector must stop the whole sync once one stream hits the rate limit.

    With streams stream_1, stream_2, stream_3, ...: if a 403 (Rate Limit) arrives while reading `stream_1`,
    that stream finishes successfully, the sync stops and the following streams are never executed.
    """
    stream_1: IncrementalSalesforceStream = generate_stream("Account", stream_config, stream_api, state=state)
    stream_2: IncrementalSalesforceStream = generate_stream("Asset", stream_config, stream_api, state=state)

    stream_1.state_checkpoint_interval = 3
    configure_request_params_mock(stream_1, stream_2)

    source = SourceSalesforce()
    source.streams = Mock()
    source.streams.return_value = [stream_1, stream_2]

    logger = AirbyteLogger()

    next_page_url = "/services/data/v52.0/query/012345"
    # Five records with IDs 1..5 dated 2021-11-15 .. 2021-11-19; the third one (2021-11-17) falls on the
    # check point interval.
    first_page_records = [{"ID": record_id, "LastModifiedDate": f"2021-11-{14 + record_id}"} for record_id in range(1, 6)]
    response_1 = {
        "done": False,
        "totalSize": 10,
        "nextRecordsUrl": next_page_url,
        "records": first_page_records,
    }
    response_2 = [{"errorCode": "REQUEST_LIMIT_EXCEEDED", "message": "TotalRequests Limit exceeded."}]

    with requests_mock.Mocker() as m:
        m.register_uri("GET", stream_1.path(), json=response_1, status_code=200)
        m.register_uri("GET", next_page_url, json=response_2, status_code=403)

        result = list(source.read(logger=logger, config=stream_config, catalog=configured_catalog, state=state))

        assert stream_1.request_params.called
        assert (
            not stream_2.request_params.called
        ), "The second stream should not be executed, because the first stream finished with Rate Limit."

        records = [message for message in result if message.type == Type.RECORD]
        assert len(records) == 5

        state_messages = [message for message in result if message.type == Type.STATE]
        state_record = state_messages[0]
        assert state_record.state.data["Account"]["LastModifiedDate"] == "2021-11-17"
def test_rate_limit_bulk(stream_config, stream_api, configured_catalog, state):
    """
    The connector must stop the whole sync once one stream hits the rate limit.

    With streams stream_1, stream_2, stream_3, ...: if a 403 (Rate Limit) arrives while reading `stream_1`,
    that stream finishes successfully, the sync stops and the following streams are never executed.
    """
    stream_1: BulkIncrementalSalesforceStream = generate_stream("Account", stream_config, stream_api)
    stream_2: BulkIncrementalSalesforceStream = generate_stream("Asset", stream_config, stream_api)
    streams = [stream_1, stream_2]
    configure_request_params_mock(stream_1, stream_2)

    stream_1.page_size = 6
    stream_1.state_checkpoint_interval = 5

    source = SourceSalesforce()
    source.streams = Mock()
    source.streams.return_value = streams
    logger = AirbyteLogger()

    json_response = [{"errorCode": "REQUEST_LIMIT_EXCEEDED", "message": "TotalRequests Limit exceeded."}]
    # CSV payload of a successful page: header plus 6 records per page
    csv_rows = ["Field1,LastModifiedDate,ID"] + [f"test,2021-11-0{i},{i}" for i in range(1, 7)]

    with requests_mock.Mocker() as m:
        for stream in streams:
            creation_responses = []
            for page in (1, 2):
                job_id = f"fake_job_{page}_{stream.name}"
                job_url = stream.path() + f"/{job_id}"
                creation_responses.append({"json": {"id": job_id}})

                m.register_uri("GET", job_url, json={"state": "JobComplete"})

                if page == 1:
                    # Read the first page successfully
                    m.register_uri("GET", job_url + "/results", text="\n".join(csv_rows))
                else:
                    # Requesting for results when reading second page should fail with 403 (Rate Limit error)
                    m.register_uri("GET", job_url + "/results", status_code=403, json=json_response)

                m.register_uri("DELETE", job_url)

            m.register_uri("POST", stream.path(), creation_responses)

        result = list(source.read(logger=logger, config=stream_config, catalog=configured_catalog, state=state))

        assert stream_1.request_params.called
        assert (
            not stream_2.request_params.called
        ), "The second stream should not be executed, because the first stream finished with Rate Limit."

        records = [message for message in result if message.type == Type.RECORD]
        assert len(records) == 6  # stream page size: 6

        state_messages = [message for message in result if message.type == Type.STATE]
        state_record = state_messages[0]
        assert state_record.state.data["Account"]["LastModifiedDate"] == "2021-11-05"  # state checkpoint interval is 5.
def test_client_wrong_credentials():
    """check_connection must report a falsy status for an invalid Posthog API key."""
    invalid_config = {"api_key": "blahblah"}
    posthog_source = SourcePosthog()
    # The second tuple element (the error) is intentionally not asserted on.
    status, _error = posthog_source.check_connection(logger=AirbyteLogger(), config=invalid_config)
    assert not status
def check(self, logger: AirbyteLogger, config: json) -> AirbyteConnectionStatus:
    """
    Validate the optional limits of the config and try to reach the configured SQS queue.

    :param logger: logger receiving debug output about each step
    :param config: connector configuration (queue_url, region, access_key, secret_key and optional limits)
    :return: SUCCEEDED when the queue object exposes `attributes`, FAILED (with a message) otherwise or on
        any exception
    """
    try:
        if "max_batch_size" in config:
            # Max batch size must be between 1 and 10
            batch_size = config["max_batch_size"]
            if batch_size > 10 or batch_size < 1:
                raise Exception("max_batch_size must be between 1 and 10")

        if "max_wait_time" in config:
            # Max wait time must be between 1 and 20
            wait_time = config["max_wait_time"]
            if wait_time > 20 or wait_time < 1:
                raise Exception("max_wait_time must be between 1 and 20")

        # Required properties
        queue_url = config["queue_url"]
        logger.debug("Amazon SQS Source Config Check - queue_url: " + queue_url)
        queue_region = config["region"]
        logger.debug("Amazon SQS Source Config Check - region: " + queue_region)

        # Sensitive properties: only the last character is written to the log
        access_key = config["access_key"]
        logger.debug("Amazon SQS Source Config Check - access_key (ends with): " + access_key[-1])
        secret_key = config["secret_key"]
        logger.debug("Amazon SQS Source Config Check - secret_key (ends with): " + secret_key[-1])

        logger.debug("Amazon SQS Source Config Check - Starting connection test ---")
        aws_session = boto3.Session(
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
            region_name=queue_region,
        )
        sqs_resource = aws_session.resource("sqs")
        queue = sqs_resource.Queue(url=queue_url)

        if not hasattr(queue, "attributes"):
            return AirbyteConnectionStatus(
                status=Status.FAILED, message="Amazon SQS Source Config Check - Could not connect to queue"
            )

        logger.debug("Amazon SQS Source Config Check - Connection test successful ---")
        return AirbyteConnectionStatus(status=Status.SUCCEEDED)
    except ClientError as e:
        client_error_message = f"Amazon SQS Source Config Check - Error in AWS Client: {str(e)}"
        return AirbyteConnectionStatus(status=Status.FAILED, message=client_error_message)
    except Exception as e:
        generic_error_message = f"Amazon SQS Source Config Check - An exception occurred: {str(e)}"
        return AirbyteConnectionStatus(status=Status.FAILED, message=generic_error_message)
from traceback import format_exc
from typing import Any, Iterable, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union

from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models.airbyte_protocol import SyncMode
from airbyte_cdk.sources.streams import Stream
from wcmatch.glob import GLOBSTAR, SPLIT, globmatch

from .formats.csv_parser import CsvParser
from .formats.parquet_parser import ParquetParser

# Names of the JSON types; presumably the set of types accepted in user-provided schemas - TODO confirm usage.
JSON_TYPES = [
    "string", "number", "integer", "object", "array", "boolean", "null"
]

# Module-level logger instance.
LOGGER = AirbyteLogger()


class ConfigurationError(Exception):
    """Client mis-configured"""


# NOTE(review): `ABC` is used below but not imported in the visible lines - confirm it is imported
# elsewhere in the file.
class FileStream(Stream, ABC):
    @property
    def fileformatparser_map(self):
        """Mapping where every key is equal 'filetype' and values are corresponding parser classes."""
        return {
            "csv": CsvParser,
            "parquet": ParquetParser,
        }
def read(
    self, logger: AirbyteLogger, config: json, catalog: ConfiguredAirbyteCatalog, state: Dict[str, any]
) -> Generator[AirbyteMessage, None, None]:
    """
    Poll the configured SQS queue and emit one record message per SQS message.

    Polling repeats until a poll returns no messages. Each message is emitted
    as a record with the keys "id", "body" and "attributes"; its visibility
    timeout is changed first when "visibility_timeout" is configured, and it is
    deleted after being emitted when "delete_messages" is truthy.

    :param logger: logger that receives debug traces of the read
    :param config: connector configuration; must contain "queue_url", "region",
        "delete_messages", "access_key" and "secret_key"; may contain
        "max_batch_size" (default 10), "max_wait_time" (default 20),
        "visibility_timeout" and "attributes_to_return" (comma separated)
    :param catalog: configured catalog (not used by this method)
    :param state: connector state (not used by this method)
    :yield: an AirbyteMessage of type RECORD for every received SQS message
    :raises Exception: when the AWS client raises a ClientError; the original
        error is attached as the cause
    """
    stream_name = self.parse_queue_name(config["queue_url"])
    logger.debug("Amazon SQS Source Read - stream is: " + stream_name)

    # Required properties
    queue_url = config["queue_url"]
    queue_region = config["region"]
    delete_messages = config["delete_messages"]

    # Optional properties
    max_batch_size = config.get("max_batch_size", 10)
    max_wait_time = config.get("max_wait_time", 20)
    visibility_timeout = config.get("visibility_timeout")
    attributes_to_return = config.get("attributes_to_return")
    if attributes_to_return is None:
        attributes_to_return = ["All"]
    else:
        attributes_to_return = attributes_to_return.split(",")

    # Sensitive properties
    access_key = config["access_key"]
    secret_key = config["secret_key"]

    logger.debug("Amazon SQS Source Read - Creating SQS connection ---")
    session = boto3.Session(aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=queue_region)
    sqs = session.resource("sqs")
    queue = sqs.Queue(url=queue_url)
    logger.debug("Amazon SQS Source Read - Connected to SQS Queue ---")

    # Keep polling until a poll comes back empty.
    while True:
        try:
            logger.debug("Amazon SQS Source Read - Beginning message poll ---")
            messages = queue.receive_messages(
                MessageAttributeNames=attributes_to_return, MaxNumberOfMessages=max_batch_size, WaitTimeSeconds=max_wait_time
            )

            if not messages:
                logger.debug("Amazon SQS Source Read - No messages recieved during poll, time out reached ---")
                break

            for msg in messages:
                logger.debug("Amazon SQS Source Read - Message recieved: " + msg.message_id)
                if visibility_timeout:
                    logger.debug("Amazon SQS Source Read - Setting message visibility timeout: " + msg.message_id)
                    self.change_message_visibility(msg, visibility_timeout)
                    logger.debug("Amazon SQS Source Read - Message visibility timeout set: " + msg.message_id)

                data = {
                    "id": msg.message_id,
                    "body": msg.body,
                    "attributes": msg.message_attributes,
                }

                # Scale to milliseconds BEFORE truncating, otherwise the
                # sub-second part is dropped and the value always ends in 000.
                emitted_at = int(datetime.now().timestamp() * 1000)

                # TODO: Support a 'BATCH OUTPUT' mode that outputs the full batch in a single AirbyteRecordMessage
                yield AirbyteMessage(
                    type=Type.RECORD,
                    record=AirbyteRecordMessage(stream=stream_name, data=data, emitted_at=emitted_at),
                )

                # The message is only deleted after it has been handed to the consumer.
                if delete_messages:
                    logger.debug("Amazon SQS Source Read - Deleting message: " + msg.message_id)
                    self.delete_message(msg)
                    logger.debug("Amazon SQS Source Read - Message deleted: " + msg.message_id)
                    # TODO: Delete messages in batches to reduce amount of requests?

        except ClientError as error:
            # Chain the original error so its traceback is kept as the cause.
            raise Exception("Error in AWS Client: " + str(error)) from error
def check(self, logger: AirbyteLogger, config: json) -> AirbyteConnectionStatus:
    """
    Verify that the configured spreadsheet is reachable and has usable headers.

    The check runs in three stages: build a client from the configured
    credentials, request the first row of the spreadsheet, then read the first
    row of every grid sheet and collect duplicated headers.

    :param logger: logger that receives warnings and errors found on the way
    :param config: connector configuration; "spreadsheet_id" is read from it and
        the whole mapping is passed to self.get_credentials
    :return: SUCCEEDED when every stage passes; FAILED with an explanatory
        message when the client cannot be built, the spreadsheet request raises
        an HttpError, a sheet's first row cannot be read, or duplicate headers
        are found
    """
    # Stage 1: credentials
    try:
        sheets_client = GoogleSheetsClient(self.get_credentials(config))
    except Exception as e:
        return AirbyteConnectionStatus(status=Status.FAILED, message=f"Please use valid credentials json file. Error: {e}")

    spreadsheet_id = Helpers.get_spreadsheet_id(config["spreadsheet_id"])

    # Stage 2: reachability - request only the first row
    try:
        sheets_client.get(spreadsheetId=spreadsheet_id, includeGridData=False, ranges="1:1")
    except errors.HttpError as err:
        # A 404 gets a friendlier explanation than the raw error text.
        reason = str(err)
        if err.resp.status == status_codes.NOT_FOUND:
            reason = "Requested spreadsheet was not found."
        logger.error(f"Formatted error: {reason}")
        return AirbyteConnectionStatus(
            status=Status.FAILED, message=f"Unable to connect with the provided credentials to spreadsheet. Error: {reason}"
        )

    # Stage 3: look for duplicate headers in every grid sheet
    raw_metadata = sheets_client.get(spreadsheetId=spreadsheet_id, includeGridData=False)
    metadata = Spreadsheet.parse_obj(raw_metadata)

    duplicates_by_sheet = {}
    for sheet_name in Helpers.get_grid_sheets(metadata):
        try:
            first_row = Helpers.get_first_row(sheets_client, spreadsheet_id, sheet_name)
            _, repeated = Helpers.get_valid_headers_and_duplicates(first_row)
            if repeated:
                duplicates_by_sheet[sheet_name] = repeated
        except Exception as err:
            # An error with this message prefix is treated as an empty sheet and skipped.
            if str(err).startswith("Expected data for exactly one row for sheet"):
                logger.warn(f"Skip empty sheet: {sheet_name}")
                continue

            logger.error(str(err))
            return AirbyteConnectionStatus(
                status=Status.FAILED, message=f"Unable to read the schema of sheet {sheet_name}. Error: {str(err)}"
            )

    if not duplicates_by_sheet:
        return AirbyteConnectionStatus(status=Status.SUCCEEDED)

    duplicate_headers_error_message = ", ".join(
        f"[sheet:{sheet_name}, headers:{duplicate_sheet_headers}]"
        for sheet_name, duplicate_sheet_headers in duplicates_by_sheet.items()
    )
    return AirbyteConnectionStatus(
        status=Status.FAILED,
        message="The following duplicate headers were found in the following sheets. Please fix them to continue: "
        + duplicate_headers_error_message,
    )
class ChargebeeStream(Stream):
    """
    Stream that lists Chargebee resources page by page.

    Requests are issued through `self.api.list`; `api` is not defined in this
    class and is presumably provided by each concrete subclass (verify against
    the subclasses).
    """

    supports_incremental = True
    primary_key = "id"
    default_cursor_field = "updated_at"
    logger = AirbyteLogger()

    def __init__(self):
        self.next_offset = None
        # Request params below
        # according to Chargebee's guidance on pagination
        # https://apidocs.chargebee.com/docs/api/#pagination_and_filtering
        self.params = {
            "limit": 100,  # Limit at 100
            "sort_by[asc]": self.default_cursor_field,  # Sort ascending by updated_at
        }

        super().__init__()

    @backoff.on_exception(
        backoff.expo,  # Exponential back-off
        OperationFailedError,  # Only on Chargebee's OperationFailedError
        max_tries=MAX_TRIES,
        max_time=MAX_TIME,
    )
    def _send_request(self) -> ListResult:
        """
        Request one page of results using the current `self.params`.

        Just a wrapper to allow the @backoff decorator to retry the call.
        Reference: https://apidocs.chargebee.com/docs/api/#error_codes_list
        Rate limits: https://apidocs.chargebee.com/docs/api/#api_rate_limits

        :return: the result of `self.api.list` for the current params
        """
        return self.api.list(self.params)

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: List[str] = None,
        stream_slice: Mapping[str, Any] = None,
        stream_state: Mapping[str, Any] = None,
    ) -> Iterable[Mapping[str, Any]]:
        """
        Override airbyte_cdk Stream's read_records method.

        Reference for Chargebee's pagination strategy:
        https://apidocs.chargebee.com/docs/api/#pagination_and_filtering

        :param sync_mode: not used by this method
        :param cursor_field: not used by this method
        :param stream_slice: not used by this method
        :param stream_state: when truthy, its entries are merged into the
            request params before the first request
        :yield: `message._response[self.name]` for every message of every page
        """
        # `self.params` lives on the instance, so an "offset" left over from a
        # previous read would make this read resume from that page. Drop it so
        # every read starts at the first page.
        self.params.pop("offset", None)

        if stream_state:
            self.params.update(stream_state)

        # Loop until pagination is completed
        while True:
            # Request the ListResult object from Chargebee
            # with back-off implemented through self._send_request()
            list_result = self._send_request()

            # Read message from results
            for message in list_result:
                yield message._response[self.name]

            # Get next page token; a falsy value means there are no more pages
            self.next_offset = list_result.next_offset
            if not self.next_offset:
                break
            self.params.update({"offset": self.next_offset})

        # Always return an empty generator just in case no records were ever yielded
        yield from []

    def get_updated_state(
        self,
        current_stream_state: MutableMapping[str, Any],
        latest_record: Mapping[str, Any],
    ) -> MutableMapping[str, Any]:
        """
        Override airbyte_cdk Stream's get_updated_state method
        to get the latest Chargebee stream state.

        :param current_stream_state: state so far; updated in place when truthy,
            otherwise a new dict is created
        :param latest_record: the most recently read record
        :return: the state with "update_at[before]" set to the current timestamp
            and, when the record has a truthy "updated_at", "update_at[after]"
            set to that value
        """
        # NOTE(review): the state keys are spelled "update_at[...]" while the
        # cursor field is "updated_at". They are kept unchanged because this
        # state is persisted and merged into the request params by
        # read_records - confirm the intended filter name against the Chargebee
        # API before renaming them.

        # Init the current_stream_state
        current_stream_state = current_stream_state or {}

        # Upper bound: the current timestamp
        current_stream_state.update({"update_at[before]": pendulum.now().int_timestamp})

        # Lower bound: the updated_at value of the latest record, when present.
        # (The record is intentionally not printed: stdout carries the
        # connector's protocol messages.)
        latest_updated_at = latest_record.get("updated_at")
        if latest_updated_at:
            current_stream_state.update({"update_at[after]": latest_updated_at})

        return current_stream_state