Example #1
class GoogleSheetsClient:

    logger = AirbyteLogger()

    def __init__(self, config: Dict):
        self.config = config
        self.retries = 100  # max number of backoff retries

    def authorize(self) -> pygsheets_client:
        input_creds = self.config.get("credentials")
        auth_creds = client_account.Credentials.from_authorized_user_info(info=input_creds)
        client = pygsheets.authorize(custom_credentials=auth_creds)

        # Increase max number of retries if Rate Limit is reached. Error: <HttpError 429>
        client.drive.retries = self.retries  # for google drive api
        client.sheet.retries = self.retries  # for google sheets api

        # check if token is expired and refresh it
        if client.oauth.expired:
            self.logger.info("Auth session is expired. Refreshing...")
            client.oauth.refresh(Request())
            if not client.oauth.expired:
                self.logger.info("Successfully refreshed auth session")
            else:
                self.logger.fatal("The token is expired and could not be refreshed, please check the credentials are still valid!")

        return client
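
A minimal usage sketch for the client above (not part of the original connector code): the shape of the config dict follows the `self.config.get("credentials")` lookup in authorize(), and the credential values are placeholders.

# Hypothetical config; "credentials" must hold google-auth authorized-user info.
config = {
    "credentials": {
        "client_id": "<client_id>",
        "client_secret": "<client_secret>",
        "refresh_token": "<refresh_token>",
    }
}

sheets_client = GoogleSheetsClient(config).authorize()
# sheets_client is a pygsheets client with retries raised to 100
# for both the Drive and the Sheets APIs, and a refreshed OAuth token.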
Example #2
def run_load_dataframes(config, expected_columns=10, expected_rows=42):
    df_list = SourceFile.load_dataframes(config=config, logger=AirbyteLogger(), skip_data=False)
    assert len(df_list) == 1  # Properly load 1 DataFrame
    df = df_list[0]
    assert len(df.columns) == expected_columns  # DataFrame should have 10 columns
    assert len(df.index) == expected_rows  # DataFrame should have 42 rows of data
    return df
Example #3
            def wrapper_balance_rate_limit(*args, **kwargs):
                sleep_time = 0
                free_load = float("inf")
                # find the Response among the positional args (keep the first match)
                response = None
                for arg in args:
                    if isinstance(arg, requests.models.Response):
                        response = arg
                        break

                # Get the rate_limits from response
                rate_limits = ([
                    (response.headers.get(rate_remaining_limit_header),
                     response.headers.get(rate_max_limit_header))
                    for rate_remaining_limit_header, rate_max_limit_header in
                    rate_limits_headers
                ] if response else None)

                # define current load from rate_limits
                if rate_limits:
                    for current_rate, max_rate_limit in rate_limits:
                        free_load = min(
                            free_load,
                            int(current_rate) / int(max_rate_limit))

                # define sleep time based on load conditions
                if free_load <= threshold:
                    sleep_time = sleep_on_high_load

                # sleep based on load conditions
                sleep(sleep_time)
                AirbyteLogger().info(
                    f"Sleep {sleep_time} seconds based on load conditions.")

                return func(*args, **kwargs)
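
The wrapper above closes over `func`, `threshold`, `sleep_on_high_load` and `rate_limits_headers`, none of which appear in the snippet. A self-contained sketch of the decorator factory they imply could look like the following; the default header names and values are assumptions for illustration, not the connector's actual configuration.

import functools
from time import sleep

import requests

# Assumed header pair; real connectors pass their own (remaining, max) header names.
DEFAULT_RATE_LIMIT_HEADERS = [("X-RateLimit-Remaining", "X-RateLimit-Limit")]


def balance_rate_limit(threshold: float = 0.05,
                       sleep_on_high_load: float = 60.0,
                       rate_limits_headers=DEFAULT_RATE_LIMIT_HEADERS):
    """Sleep before calling the wrapped function when the remaining API quota is low."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper_balance_rate_limit(*args, **kwargs):
            # find the Response among the positional args, if any
            response = next(
                (arg for arg in args if isinstance(arg, requests.models.Response)), None)

            # free_load is the fraction of quota still available (1.0 = untouched)
            free_load = float("inf")
            if response is not None:
                for remaining_header, max_header in rate_limits_headers:
                    remaining = response.headers.get(remaining_header)
                    maximum = response.headers.get(max_header)
                    if remaining and maximum:
                        free_load = min(free_load, int(remaining) / int(maximum))

            # back off only when the remaining quota dips below the threshold
            sleep(sleep_on_high_load if free_load <= threshold else 0)
            return func(*args, **kwargs)

        return wrapper_balance_rate_limit

    return decorator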
Example #4
    def streams(self, config: Mapping[str, Any]) -> List[Stream]:
        authenticator = TokenAuthenticator(config["api_token"])
        default_start_date = pendulum.parse(config["start_date"])
        threads_lookback_window = pendulum.Duration(
            days=config["lookback_window"])

        streams = [
            Channels(authenticator=authenticator),
            ChannelMembers(authenticator=authenticator),
            ChannelMessages(authenticator=authenticator,
                            default_start_date=default_start_date),
            Threads(authenticator=authenticator,
                    default_start_date=default_start_date,
                    lookback_window=threads_lookback_window),
            Users(authenticator=authenticator),
        ]

        # To sync data from channels, the bot backed by this token needs to join all those channels. This operation is idempotent.
        if config["join_channels"]:
            logger = AirbyteLogger()
            logger.info("joining Slack channels")
            join_channels_stream = JoinChannelsStream(
                authenticator=authenticator)
            for stream_slice in join_channels_stream.stream_slices():
                for message in join_channels_stream.read_records(
                        sync_mode=SyncMode.full_refresh,
                        stream_slice=stream_slice):
                    logger.info(message["message"])

        return streams
Example #5
def test_local_storage_spec():
    """Checks spec properties"""
    source = SourceFileSecure()
    spec = source.spec(logger=AirbyteLogger())
    for provider in spec.connectionSpecification["properties"]["provider"][
            "oneOf"]:
        assert provider["properties"]["storage"][
            "const"] != LOCAL_STORAGE_NAME, "This connector shouldn't work with local files."
Example #6
class NotImplementedAuth(Exception):
    """Not implemented Auth option error"""

    logger = AirbyteLogger()

    def __init__(self, auth_method: str = None):
        self.message = f"Not implemented Auth method = {auth_method}"
        self.logger.error(self.message)
        super().__init__(self.message)
Example #7
    def set_sub_primary_key(self):
        if isinstance(self.primary_key, list):
            for index, value in enumerate(self.primary_key):
                setattr(self, f"sub_primary_key_{index + 1}", value)
        else:
            logger = AirbyteLogger()
            logger.error(
                "Failed during setting sub primary keys. Primary key should be list."
            )
Example #8
class AbstractTestParser(ABC):
    """ Prefix this class with Abstract so the tests don't run here but only in the children """

    logger = AirbyteLogger()

    @property
    @abstractmethod
    def test_files(self) -> List[Mapping[str, Any]]:
        """return a list of test_file dicts in structure:
        [
            {"AbstractFileParser": CsvParser(format, master_schema), "filepath": "...", "num_records": 5, "inferred_schema": {...}, line_checks:{}, fails: []},
            {"AbstractFileParser": CsvParser(format, master_schema), "filepath": "...", "num_records": 16, "inferred_schema": {...}, line_checks:{}, fails: []}
        ]
        note: line_checks index is 1-based to align with row numbers
        """

    def _get_readmode(self, test_name, test_file):
        self.logger.info(
            f"testing {test_name}() with {test_file.get('test_alias', test_file['filepath'].split('/')[-1])} ..."
        )
        return "rb" if test_file["AbstractFileParser"].is_binary else "r"

    def test_get_inferred_schema(self):
        for test_file in self.test_files:
            with smart_open(
                    test_file["filepath"],
                    self._get_readmode("get_inferred_schema", test_file)) as f:
                if "test_get_inferred_schema" in test_file["fails"]:
                    with pytest.raises(Exception) as e_info:
                        test_file["AbstractFileParser"].get_inferred_schema(f)
                        self.logger.debug(str(e_info))
                else:
                    assert test_file["AbstractFileParser"].get_inferred_schema(
                        f) == test_file["inferred_schema"]

    def test_stream_records(self):
        for test_file in self.test_files:
            with smart_open(test_file["filepath"],
                            self._get_readmode("stream_records",
                                               test_file)) as f:
                if "test_stream_records" in test_file["fails"]:
                    with pytest.raises(Exception) as e_info:
                        [
                            print(r) for r in
                            test_file["AbstractFileParser"].stream_records(f)
                        ]
                        self.logger.debug(str(e_info))
                else:
                    records = [
                        r for r in
                        test_file["AbstractFileParser"].stream_records(f)
                    ]

                    assert len(records) == test_file["num_records"]
                    for index, expected_record in test_file[
                            "line_checks"].items():
                        assert records[index - 1] == expected_record
Example #9
    def _read_records(conf, catalog, state=None) -> Tuple[List[AirbyteMessage], List[AirbyteMessage]]:
        records = []
        states = []
        for message in SourceFacebookMarketing().read(AirbyteLogger(), conf, catalog, state=state):
            if message.type == Type.RECORD:
                records.append(message)
            elif message.type == Type.STATE:
                states.append(message)

        return records, states
Example #10
def test_check_connection_should_fail_when_api_call_fails(mocker):
    # We patch the object inside source.py because that's the calling context
    # https://docs.python.org/3/library/unittest.mock.html#where-to-patch
    mocker.patch("source_google_ads.source.GoogleAds",
                 MockErroringGoogleAdsClient)
    source = SourceGoogleAds()
    check_successful, message = source.check_connection(
        AirbyteLogger(),
        {
            "credentials": {
                "developer_token": "fake_developer_token",
                "client_id": "fake_client_id",
                "client_secret": "fake_client_secret",
                "refresh_token": "fake_refresh_token",
            },
            "customer_id":
            "fake_customer_id",
            "start_date":
            "2022-01-01",
            "conversion_window_days":
            14,
            "custom_queries": [
                {
                    "query":
                    "SELECT campaign.accessible_bidding_strategy, segments.ad_destination_type, campaign.start_date, campaign.end_date FROM campaign",
                    "primary_key": None,
                    "cursor_field": "campaign.start_date",
                    "table_name": "happytable",
                },
                {
                    "query":
                    "SELECT segments.ad_destination_type, segments.ad_network_type, segments.day_of_week, customer.auto_tagging_enabled, customer.id, metrics.conversions, campaign.start_date FROM campaign",
                    "primary_key": "customer.id",
                    "cursor_field": None,
                    "table_name": "unhappytable",
                },
                {
                    "query":
                    "SELECT ad_group.targeting_setting.target_restrictions FROM ad_group",
                    "primary_key": "customer.id",
                    "cursor_field": None,
                    "table_name": "ad_group_custom",
                },
            ],
        },
    )
    assert not check_successful
    assert message.startswith(
        "Unable to connect to Google Ads API with the provided configuration")
Example #11
def test_invalid_custom_query_handled(mocked_gads_api, config):
    # limit to one custom query, otherwise need to mock more side effects
    config["custom_queries"] = [next(iter(config["custom_queries"]))]
    mocked_gads_api(
        response=[{
            "customer.id": "8765"
        }],
        failure_msg=
        "Unrecognized field in the query: 'ad_group_ad.ad.video_ad.media_file'",
        error_type="request_error",
    )
    source = SourceGoogleAds()
    status_ok, error = source.check_connection(AirbyteLogger(), config)
    assert not status_ok
    assert error == (
        "Unable to connect to Google Ads API with the provided configuration - Unrecognized field in the query: "
        "'ad_group_ad.ad.video_ad.media_file'")
Example #12
def test_check_connection_should_pass_when_config_valid(mocker):
    mocker.patch("source_google_ads.source.GoogleAds", MockGoogleAdsClient)
    source = SourceGoogleAds()
    check_successful, message = source.check_connection(
        AirbyteLogger(),
        {
            "credentials": {
                "developer_token": "fake_developer_token",
                "client_id": "fake_client_id",
                "client_secret": "fake_client_secret",
                "refresh_token": "fake_refresh_token",
            },
            "customer_id":
            "fake_customer_id",
            "start_date":
            "2022-01-01",
            "conversion_window_days":
            14,
            "custom_queries": [
                {
                    "query":
                    "SELECT campaign.accessible_bidding_strategy, segments.ad_destination_type, campaign.start_date, campaign.end_date FROM campaign",
                    "primary_key": None,
                    "cursor_field": "campaign.start_date",
                    "table_name": "happytable",
                },
                {
                    "query":
                    "SELECT segments.ad_destination_type, segments.ad_network_type, segments.day_of_week, customer.auto_tagging_enabled, customer.id, metrics.conversions, campaign.start_date FROM campaign",
                    "primary_key": "customer.id",
                    "cursor_field": None,
                    "table_name": "unhappytable",
                },
                {
                    "query":
                    "SELECT ad_group.targeting_setting.target_restrictions FROM ad_group",
                    "primary_key": "customer.id",
                    "cursor_field": None,
                    "table_name": "ad_group_custom",
                },
            ],
        },
    )
    assert check_successful
    assert message is None
Example #13
class AbstractTestParser(ABC):
    """Prefix this class with Abstract so the tests don't run here but only in the children"""

    logger = AirbyteLogger()

    def _get_readmode(self, test_name, test_file):
        self.logger.info(
            f"testing {test_name}() with {test_file.get('test_alias', test_file['filepath'].split('/')[-1])} ..."
        )
        return "rb" if test_file["AbstractFileParser"].is_binary else "r"

    def test_get_inferred_schema(self, test_file):
        with smart_open(test_file["filepath"],
                        self._get_readmode("get_inferred_schema",
                                           test_file)) as f:
            if "test_get_inferred_schema" in test_file["fails"]:
                with pytest.raises(Exception) as e_info:
                    test_file["AbstractFileParser"].get_inferred_schema(f)
                    self.logger.debug(str(e_info))
            else:
                inferred_schema = test_file[
                    "AbstractFileParser"].get_inferred_schema(f)
                expected_schema = test_file["inferred_schema"]
                assert inferred_schema == expected_schema

    def test_stream_records(self, test_file):
        with smart_open(test_file["filepath"],
                        self._get_readmode("stream_records", test_file)) as f:
            if "test_stream_records" in test_file["fails"]:
                with pytest.raises(Exception) as e_info:
                    [
                        print(r) for r in
                        test_file["AbstractFileParser"].stream_records(f)
                    ]
                    self.logger.debug(str(e_info))
            else:
                records = [
                    r
                    for r in test_file["AbstractFileParser"].stream_records(f)
                ]

                assert len(records) == test_file["num_records"]
                for index, expected_record in test_file["line_checks"].items():
                    assert records[index - 1] == expected_record
Example #14
    def request_params(self,
                       stream_state: Mapping[str, Any] = None,
                       next_page_token: Mapping[str, Any] = None,
                       **kwargs):
        params = super().request_params(stream_state=stream_state,
                                        next_page_token=next_page_token,
                                        **kwargs)
        AirbyteLogger().log("INFO", f"using params: {params}")
        # If there is a next page token then we should only send pagination-related parameters.
        if not next_page_token:
            params["orderby"] = self.order_field
            params["order"] = "asc"
            if stream_state:
                start_date = stream_state.get(self.cursor_field)
                start_date = pendulum.parse(start_date).replace(tzinfo=None)
                start_date = start_date.subtract(
                    days=self.conversion_window_days)

                params["after"] = start_date
        return params
Example #15
    def streams(self, config: Mapping[str, Any]) -> List[Stream]:
        authenticator = TokenAuthenticator(config["api_token"])
        default_start_date = pendulum.now().subtract(days=14)  # TODO make this configurable
        threads_lookback_window = {"days": 7}  # TODO make this configurable

        streams = [
            Channels(authenticator=authenticator),
            ChannelMembers(authenticator=authenticator),
            ChannelMessages(authenticator=authenticator, default_start_date=default_start_date),
            Threads(authenticator=authenticator, default_start_date=default_start_date, lookback_window=threads_lookback_window),
            Users(authenticator=authenticator),
        ]

        # To sync data from channels, the bot backed by this token needs to join all those channels. This operation is idempotent.
        # TODO make joining configurable. Also make joining archived and private channels configurable
        logger = AirbyteLogger()
        logger.info("joining Slack channels")
        join_channels_stream = JoinChannelsStream(authenticator=authenticator)
        for stream_slice in join_channels_stream.stream_slices():
            for message in join_channels_stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice):
                logger.info(message["message"])

        return streams
Example #16
class TestIntegrationCsvFiles:
    logger = AirbyteLogger()

    @memory_limit(150)  # max used memory should be less than 150Mb
    def read_source(self, credentials: Dict[str, Any],
                    catalog: Dict[str, Any]) -> int:
        read_count = 0
        for msg in SourceS3().read(logger=self.logger,
                                   config=credentials,
                                   catalog=catalog):
            if msg.record:
                read_count += 1
        return read_count

    @pytest.mark.order(1)
    def test_big_file(self, minio_credentials: Dict[str, Any]) -> None:
        """tests a big csv file (>= 1.0G records)"""
        # generates a big CSV files separately
        big_file_folder = os.path.join(TMP_FOLDER, "minio_data", "test-bucket",
                                       "big_files")
        shutil.rmtree(big_file_folder, ignore_errors=True)
        os.makedirs(big_file_folder)
        filepath = os.path.join(big_file_folder, "file.csv")

        # please change this value if you need to test another file size
        future_file_size = 0.5  # in gigabytes
        _, file_size = generate_big_file(filepath, future_file_size, 500)
        expected_count = sum(1 for _ in open(filepath)) - 1
        self.logger.info(
            f"generated file {filepath} with size {file_size}Gb, lines: {expected_count}"
        )

        minio_credentials["path_pattern"] = "big_files/file.csv"
        minio_credentials["format"]["block_size"] = 5 * 1024**2
        source = SourceS3()
        catalog = source.read_catalog(HERE / "configured_catalog.json")
        assert self.read_source(minio_credentials, catalog) == expected_count
Example #17
def test_check_invalid_config():
    outcome = DestinationAmazonSqs().check(AirbyteLogger(), {"secret_key": "not_a_real_secret"})
    assert outcome.status == Status.FAILED
Example #18
def test_check_valid_config(config: Mapping):
    outcome = DestinationAmazonSqs().check(AirbyteLogger(), config)
    assert outcome.status == Status.SUCCEEDED
Example #19
from datetime import datetime
from typing import Any, Dict, List, Mapping
from unittest.mock import MagicMock, patch

import pytest
from airbyte_cdk import AirbyteLogger
from airbyte_cdk.models import SyncMode
from source_s3.source_files_abstract.file_info import FileInfo
from source_s3.source_files_abstract.storagefile import StorageFile
from source_s3.source_files_abstract.stream import IncrementalFileStream
from source_s3.stream import IncrementalFileStreamS3

from .abstract_test_parser import create_by_local_file, memory_limit

LOGGER = AirbyteLogger()


def mock_big_size_object():
    mock = MagicMock()
    mock.__sizeof__.return_value = 1000000001
    return mock


class TestIncrementalFileStream:
    @pytest.mark.parametrize(  # set return_schema to None for an expected fail
        "schema_string, return_schema",
        [
            (
                '{"id": "integer", "name": "string", "valid": "boolean", "code": "integer", "degrees": "number", "birthday": '
                '"string", "last_seen": "string"}',
Example #20
class Salesforce:
    logger = AirbyteLogger()
    version = "v52.0"

    def __init__(
        self,
        refresh_token: str = None,
        token: str = None,
        client_id: str = None,
        client_secret: str = None,
        is_sandbox: bool = None,
        start_date: str = None,
        api_type: str = None,
        **kwargs,
    ):
        self.api_type = api_type.upper() if api_type else None
        self.refresh_token = refresh_token
        self.token = token
        self.client_id = client_id
        self.client_secret = client_secret
        self.access_token = None
        self.instance_url = None
        self.session = requests.Session()
        self.is_sandbox = is_sandbox is True or (isinstance(
            is_sandbox, str) and is_sandbox.lower() == "true")
        self.start_date = start_date

    def _get_standard_headers(self):
        return {"Authorization": "Bearer {}".format(self.access_token)}

    def get_streams_black_list(self) -> List[str]:
        black_list = QUERY_RESTRICTED_SALESFORCE_OBJECTS + QUERY_INCOMPATIBLE_SALESFORCE_OBJECTS
        if self.api_type == "REST":
            return black_list
        else:
            return black_list + UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS

    def filter_streams(self, stream_name: str) -> bool:
        # REST and BULK API do not support all entities that end with `ChangeEvent`.
        if stream_name.endswith(
                "ChangeEvent") or stream_name in self.get_streams_black_list():
            return False
        return True

    def get_validated_streams(self,
                              config: Mapping[str, Any],
                              catalog: ConfiguredAirbyteCatalog = None):
        salesforce_objects = self.describe()["sobjects"]
        stream_names = [
            stream_object["name"] for stream_object in salesforce_objects
        ]
        if catalog:
            return [
                configured_stream.stream.name
                for configured_stream in catalog.streams
            ]

        if config.get("streams_criteria"):
            filtered_stream_list = []
            for stream_criteria in config["streams_criteria"]:
                filtered_stream_list += filter_streams(
                    streams_list=stream_names,
                    search_word=stream_criteria["value"],
                    search_criteria=stream_criteria["criteria"])
            stream_names = list(set(filtered_stream_list))

        validated_streams = [
            stream_name for stream_name in stream_names
            if self.filter_streams(stream_name)
        ]
        return validated_streams

    @default_backoff_handler(max_tries=5, factor=15)
    def _make_request(self,
                      http_method: str,
                      url: str,
                      headers: dict = None,
                      body: dict = None,
                      stream: bool = False,
                      params: dict = None) -> requests.models.Response:
        try:
            if http_method == "GET":
                resp = self.session.get(url,
                                        headers=headers,
                                        stream=stream,
                                        params=params)
            elif http_method == "POST":
                resp = self.session.post(url, headers=headers, data=body)
            resp.raise_for_status()
        except HTTPError as err:
            self.logger.warn(f"http error body: {err.response.text}")
            raise
        return resp

    def login(self):
        login_url = f"https://{'test' if self.is_sandbox else 'login'}.salesforce.com/services/oauth2/token"
        login_body = {
            "grant_type": "refresh_token",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "refresh_token": self.refresh_token,
        }

        resp = self._make_request(
            "POST",
            login_url,
            body=login_body,
            headers={"Content-Type": "application/x-www-form-urlencoded"})

        auth = resp.json()
        self.access_token = auth["access_token"]
        self.instance_url = auth["instance_url"]

    def describe(self, sobject: str = None) -> Mapping[str, Any]:
        """Describes all objects or a specific object"""
        headers = self._get_standard_headers()

        endpoint = "sobjects" if not sobject else f"sobjects/{sobject}/describe"

        url = f"{self.instance_url}/services/data/{self.version}/{endpoint}"
        resp = self._make_request("GET", url, headers=headers)
        return resp.json()

    def generate_schema(self, stream_name: str = None) -> Mapping[str, Any]:
        response = self.describe(stream_name)
        schema = {
            "$schema": "http://json-schema.org/draft-07/schema#",
            "type": "object",
            "additionalProperties": True,
            "properties": {}
        }
        for field in response["fields"]:
            schema["properties"][
                field["name"]] = self.field_to_property_schema(field)
        return schema

    @staticmethod
    def get_pk_and_replication_key(
            json_schema: Mapping[str,
                                 Any]) -> Tuple[Optional[str], Optional[str]]:
        fields_list = json_schema.get("properties", {}).keys()

        pk = "Id" if "Id" in fields_list else None
        replication_key = None
        if "SystemModstamp" in fields_list:
            replication_key = "SystemModstamp"
        elif "LastModifiedDate" in fields_list:
            replication_key = "LastModifiedDate"
        elif "CreatedDate" in fields_list:
            replication_key = "CreatedDate"
        elif "LoginTime" in fields_list:
            replication_key = "LoginTime"

        return pk, replication_key

    @staticmethod
    def field_to_property_schema(
            field_params: Mapping[str, Any]) -> Mapping[str, Any]:
        sf_type = field_params["type"]
        property_schema = {}

        if sf_type in STRING_TYPES:
            property_schema["type"] = ["string", "null"]
        elif sf_type in DATE_TYPES:
            property_schema = {
                "type": ["string", "null"],
                "format": "date-time" if sf_type == "datetime" else "date"
            }
        elif sf_type in NUMBER_TYPES:
            property_schema["type"] = ["number", "null"]
        elif sf_type == "address":
            property_schema = {
                "type": ["object", "null"],
                "properties": {
                    "street": {
                        "type": ["null", "string"]
                    },
                    "state": {
                        "type": ["null", "string"]
                    },
                    "postalCode": {
                        "type": ["null", "string"]
                    },
                    "city": {
                        "type": ["null", "string"]
                    },
                    "country": {
                        "type": ["null", "string"]
                    },
                    "longitude": {
                        "type": ["null", "number"]
                    },
                    "latitude": {
                        "type": ["null", "number"]
                    },
                    "geocodeAccuracy": {
                        "type": ["null", "string"]
                    },
                },
            }
        elif sf_type == "base64":
            property_schema = {"type": ["string", "null"], "format": "base64"}
        elif sf_type == "int":
            property_schema["type"] = ["integer", "null"]
        elif sf_type == "boolean":
            property_schema["type"] = ["boolean", "null"]
        elif sf_type in LOOSE_TYPES:
            """
            LOOSE_TYPES can return data of completely different types (more than 99% of them are `strings`),
            and in order to avoid conflicts in schemas and destinations, we cast this data to the `string` type.
            """
            property_schema["type"] = ["string", "null"]
        elif sf_type == "location":
            property_schema = {
                "type": ["object", "null"],
                "properties": {
                    "longitude": {
                        "type": ["null", "number"]
                    },
                    "latitude": {
                        "type": ["null", "number"]
                    }
                },
            }
        else:
            raise TypeSalesforceException(
                "Found unsupported type: {}".format(sf_type))

        return property_schema
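
A small hedged check of the static mapping above, assuming the module-level type constants (STRING_TYPES, DATE_TYPES, NUMBER_TYPES) match the real connector, where "datetime" falls under DATE_TYPES as the `"format"` branch implies:

# A Salesforce "datetime" field maps to a nullable date-time string in the JSON schema.
assert Salesforce.field_to_property_schema({"type": "datetime"}) == {
    "type": ["string", "null"],
    "format": "date-time",
}

# An "int" field maps to a nullable integer.
assert Salesforce.field_to_property_schema({"type": "int"}) == {"type": ["integer", "null"]}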
Example #21
import os
import tempfile
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, List, Mapping

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from airbyte_cdk import AirbyteLogger
from source_s3.source_files_abstract.formats.parquet_parser import PARQUET_TYPES, ParquetParser

from .abstract_test_parser import AbstractTestParser

SAMPLE_DIRECTORY = Path(__file__).resolve().parent.joinpath("sample_files/")

logger = AirbyteLogger()
filetype = "parquet"


def tmp_folder():
    return os.path.join(tempfile.mkdtemp())


def _save_parquet_file(filename: str, columns: List[str],
                       rows: List[List[Any]]) -> str:
    data = {}
    for col_values in zip(columns, *rows):
        data[col_values[0]] = list(col_values[1:])

    if rows:
        df = pd.DataFrame(data)
Example #22
def test_check_invalid_config():
    outcome = DestinationKvdb().check(AirbyteLogger(),
                                      {"bucket_id": "not_a_real_id"})
    assert outcome.status == Status.FAILED
Example #23
def logger():
    return AirbyteLogger()
Example #24
def test_check_invalid_config(config):
    outcome = DestinationSftpJson().check(
        AirbyteLogger(), {
            **config, "destination_path": "/doesnotexist"
        })
    assert outcome.status == Status.FAILED
Example #25
class AbstractTestParser(ABC):
    """Prefix this class with Abstract so the tests don't run here but only in the children"""

    logger = AirbyteLogger()
    record_types: Mapping[str, Any] = {}

    @classmethod
    def _generate_row(cls, types: List[str]) -> List[Any]:
        """Generates random values with request types"""
        row = []
        for needed_type in types:
            for json_type in cls.record_types:
                if json_type == needed_type:
                    row.append(cls._generate_value(needed_type))
                    break
        return row

    @classmethod
    def _generate_value(cls, typ: str) -> Any:
        if typ not in ["boolean", "integer"
                       ] and cls._generate_value("boolean"):
            # return 'None' for +- 33% of all requests
            return None

        if typ == "number":
            while True:
                int_value = cls._generate_value("integer")
                if int_value:
                    break
            return float(int_value) + random.random()
        elif typ == "integer":
            return random.randint(-sys.maxsize - 1, sys.maxsize)
            # return random.randint(0, 1000)
        elif typ == "boolean":
            return random.choice([True, False, None])
        elif typ == "string":
            random_length = random.randint(0, 10 *
                                           1024)  # max size of bytes is 10k
            return os.urandom(random_length)
        elif typ == "timestamp":
            return datetime.now() + timedelta(seconds=random.randint(0, 7200))
        elif typ == "date":
            dt = cls._generate_value("timestamp")
            return dt.date() if dt else None
        elif typ == "time":
            dt = cls._generate_value("timestamp")
            return dt.time() if dt else None
        raise Exception(f"not supported type: {typ}")

    @classmethod
    @lru_cache(maxsize=None)
    def cached_cases(cls) -> Mapping[str, Any]:
        return cls.cases()

    @classmethod
    @abstractmethod
    def cases(cls) -> Mapping[str, Any]:
        """return a map of test_file dicts in structure:
        {
           "small_file": {"AbstractFileParser": CsvParser(format, master_schema), "filepath": "...", "num_records": 5, "inferred_schema": {...}, line_checks:{}, fails: []},
           "big_file": {"AbstractFileParser": CsvParser(format, master_schema), "filepath": "...", "num_records": 16, "inferred_schema": {...}, line_checks:{}, fails: []}
        }
        note: line_checks index is 1-based to align with row numbers
        """

    def _get_readmode(self, file_info: Mapping[str, Any]) -> str:
        return "rb" if file_info["AbstractFileParser"].is_binary else "r"

    @memory_limit(1024)
    def test_suite_inferred_schema(self, file_info: Mapping[str, Any]) -> None:
        with smart_open(file_info["filepath"],
                        self._get_readmode(file_info)) as f:
            if "test_get_inferred_schema" in file_info["fails"]:
                with pytest.raises(Exception) as e_info:
                    file_info["AbstractFileParser"].get_inferred_schema(f)
                    self.logger.debug(str(e_info))
            else:
                assert file_info["AbstractFileParser"].get_inferred_schema(
                    f) == file_info["inferred_schema"]

    @memory_limit(1024)
    def test_stream_suite_records(self, file_info: Mapping[str, Any]) -> None:
        filepath = file_info["filepath"]
        self.logger.info(
            f"read the file: {filepath}, size: {os.stat(filepath).st_size / (1024 ** 2)}Mb"
        )
        with smart_open(filepath, self._get_readmode(file_info)) as f:
            if "test_stream_records" in file_info["fails"]:
                with pytest.raises(Exception) as e_info:
                    [
                        print(r) for r in
                        file_info["AbstractFileParser"].stream_records(f)
                    ]
                    self.logger.debug(str(e_info))
            else:
                records = [
                    r
                    for r in file_info["AbstractFileParser"].stream_records(f)
                ]

                assert len(records) == file_info["num_records"]
                for index, expected_record in file_info["line_checks"].items():
                    assert records[index - 1] == expected_record
Example #26
def test_check_invalid_config():
    outcome = DestinationFirestore().check(AirbyteLogger(),
                                           {"project_id": "not_a_real_id"})
    assert outcome.status == Status.FAILED
Example #27
def run_load_nested_json_schema(config, expected_columns=10, expected_rows=42):
    data_list = SourceFile.load_nested_json(config, logger=AirbyteLogger())
    assert len(data_list) == 1  # Properly load data
    df = data_list[0]
    assert len(df) == expected_rows  # DataFrame should have 42 items
    return df
Example #28
class Destination(Connector, ABC):
    logger = AirbyteLogger()
    VALID_CMDS = {"spec", "check", "write"}

    @abstractmethod
    def write(
            self, config: Mapping[str, Any],
            configured_catalog: ConfiguredAirbyteCatalog,
            input_messages: Iterable[AirbyteMessage]
    ) -> Iterable[AirbyteMessage]:
        """Implement to define how the connector writes data to the destination"""

    def _run_check(self, config: Mapping[str, Any]) -> AirbyteMessage:
        check_result = self.check(self.logger, config)
        return AirbyteMessage(type=Type.CONNECTION_STATUS,
                              connectionStatus=check_result)

    def _parse_input_stream(
            self, input_stream: io.TextIOWrapper) -> Iterable[AirbyteMessage]:
        """Reads from stdin, converting to Airbyte messages"""
        for line in input_stream:
            try:
                yield AirbyteMessage.parse_raw(line)
            except ValidationError:
                self.logger.info(
                    f"ignoring input which can't be deserialized as Airbyte Message: {line}"
                )

    def _run_write(self, config: Mapping[str,
                                         Any], configured_catalog_path: str,
                   input_stream: io.TextIOWrapper) -> Iterable[AirbyteMessage]:
        catalog = ConfiguredAirbyteCatalog.parse_file(configured_catalog_path)
        input_messages = self._parse_input_stream(input_stream)
        self.logger.info("Begin writing to the destination...")
        yield from self.write(config=config,
                              configured_catalog=catalog,
                              input_messages=input_messages)
        self.logger.info("Writing complete.")

    def parse_args(self, args: List[str]) -> argparse.Namespace:
        """
        :param args: commandline arguments
        :return:
        """

        parent_parser = argparse.ArgumentParser(add_help=False)
        main_parser = argparse.ArgumentParser()
        subparsers = main_parser.add_subparsers(title="commands",
                                                dest="command")

        # spec
        subparsers.add_parser(
            "spec",
            help="outputs the json configuration specification",
            parents=[parent_parser])

        # check
        check_parser = subparsers.add_parser(
            "check",
            help="checks the config can be used to connect",
            parents=[parent_parser])
        required_check_parser = check_parser.add_argument_group(
            "required named arguments")
        required_check_parser.add_argument(
            "--config",
            type=str,
            required=True,
            help="path to the json configuration file")

        # write
        write_parser = subparsers.add_parser(
            "write",
            help="Writes data to the destination",
            parents=[parent_parser])
        write_required = write_parser.add_argument_group(
            "required named arguments")
        write_required.add_argument("--config",
                                    type=str,
                                    required=True,
                                    help="path to the JSON configuration file")
        write_required.add_argument(
            "--catalog",
            type=str,
            required=True,
            help="path to the configured catalog JSON file")

        parsed_args = main_parser.parse_args(args)
        cmd = parsed_args.command
        if not cmd:
            raise Exception("No command entered. ")
        elif cmd not in ["spec", "check", "write"]:
            # This is technically dead code since parse_args() would fail if this was the case
            # But it's non-obvious enough to warrant placing it here anyways
            raise Exception(f"Unknown command entered: {cmd}")

        return parsed_args

    def run_cmd(self,
                parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]:
        cmd = parsed_args.command
        if cmd not in self.VALID_CMDS:
            raise Exception(f"Unrecognized command: {cmd}")

        spec = self.spec(self.logger)
        if cmd == "spec":
            yield AirbyteMessage(type=Type.SPEC, spec=spec)
            return
        config = self.read_config(config_path=parsed_args.config)
        if self.check_config_against_spec or cmd == "check":
            check_config_against_spec_or_exit(config, spec, self.logger)

        if cmd == "check":
            yield self._run_check(config=config)
        elif cmd == "write":
            # Wrap in UTF-8 to override any other input encodings
            wrapped_stdin = io.TextIOWrapper(sys.stdin.buffer,
                                             encoding="utf-8")
            yield from self._run_write(
                config=config,
                configured_catalog_path=parsed_args.catalog,
                input_stream=wrapped_stdin)

    def run(self, args: List[str]):
        parsed_args = self.parse_args(args)
        output_messages = self.run_cmd(parsed_args)
        for message in output_messages:
            print(message.json(exclude_unset=True))
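
A minimal sketch of the entrypoint a concrete destination built on this base class typically uses; `DestinationExample` and the file names are placeholders, not part of the snippet above.

import sys

from destination_example import DestinationExample  # hypothetical concrete subclass


def main():
    # e.g. `python main.py write --config config.json --catalog configured_catalog.json`
    DestinationExample().run(sys.argv[1:])


if __name__ == "__main__":
    main()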
Example #29
def test_local_storage_check():
    """Checks working with a local options"""
    source = SourceFileSecure()
    with pytest.raises(RuntimeError) as exc:
        source.check(logger=AirbyteLogger(), config=local_storage_config)
    assert "not supported" in str(exc.value)
Example #30
class WriteBufferMixin:

    # Default instance of AirbyteLogger
    logger = AirbyteLogger()
    # interval after which the records_buffer should be cleaned up for selected stream
    flush_interval = 1000

    def __init__(self):
        # Buffer for input records
        self.records_buffer = {}
        # Placeholder for streams metadata
        self.stream_info = {}

    @property
    def default_missing(self) -> str:
        """
        Default value for keys missing from a record, compared to the configured_stream catalog.
        Override if needed.
        """
        return ""

    def init_buffer_stream(self, configured_stream: AirbyteStream):
        """
        Saves important stream information for later use.

        In particular, creates the data structure for `records_buffer`.
        Populates the `stream_info` placeholder with stream metadata.
        """
        stream = configured_stream.stream
        self.records_buffer[stream.name] = []
        self.stream_info[stream.name] = {
            "headers": sorted(list(stream.json_schema.get("properties").keys())),
            "is_set": False,
        }

    def add_to_buffer(self, stream_name: str, record: Mapping):
        """
        Appends an input record to `records_buffer`.

        1) normalizes the input record
        2) coerces the normalized values to str
        3) stores the values as a list, in sorted key order.
        """

        norm_record = self._normalize_record(stream_name, record)
        norm_values = list(map(str, norm_record.values()))
        self.records_buffer[stream_name].append(norm_values)

    def clear_buffer(self, stream_name: str):
        """
        Cleans up the `records_buffer` values belonging to the input stream.
        """
        self.records_buffer[stream_name].clear()

    def _normalize_record(self, stream_name: str, record: Mapping) -> Mapping[str, Any]:
        """
        Aligns the record keys with the keys declared in the configured_stream catalog.

        Handles two scenarios:
        1) when a record has fewer keys than the catalog declares (undersetting)
        2) when a record has more keys than the catalog declares (oversetting)

        Returns: alphabetically sorted, catalog-normalized Mapping[str, Any].

        EXAMPLE:
        - UnderSetting:
            * Catalog:
                - has 3 entities:
                    [ 'id', 'key1', 'key2' ]
                              ^
            * Input record:
                - missing 1 entity, compared to the catalog
                    { 'id': 123,    'key2': 'value' }
                                  ^
            * Result:
                - 'key1' has been added to the record, because it was declared in the catalog, to keep the data structure.
                    {'id': 123, 'key1': '', 'key2': 'value' }
                                  ^
        - OverSetting:
            * Catalog:
                - has 3 entities:
                    [ 'id', 'key1', 'key2',   ]
                                            ^
            * Input record:
                - doesn't have entity 'key1'
                - has 1 more entity, compared to the catalog: 'key3'
                    { 'id': 123,     ,'key2': 'value', 'key3': 'value' }
                                  ^                      ^
            * Result:
                - 'key1' was added, because it is expected to be part of the record, to keep the data structure
                - 'key3' was dropped, because it was not declared in catalog, to keep the data structure
                    { 'id': 123, 'key1': '', 'key2': 'value',   }
                                   ^                          ^

        """
        headers = self.stream_info[stream_name]["headers"]
        # undersetting scenario
        [record.update({key: self.default_missing}) for key in headers if key not in record.keys()]
        # oversetting scenario
        [record.pop(key) for key in record.copy().keys() if key not in headers]

        return dict(sorted(record.items(), key=lambda x: x[0]))
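
A short hedged walk-through of the under/over-setting behavior described in the `_normalize_record` docstring; the stream name and headers are set by hand for illustration instead of coming from a configured catalog.

buffer = WriteBufferMixin()
# init_buffer_stream() would normally derive this from the stream's json_schema.
buffer.records_buffer["users"] = []
buffer.stream_info["users"] = {"headers": ["id", "key1", "key2"], "is_set": False}

# 'key3' is dropped (oversetting) and 'key1' is filled with the default "" (undersetting).
record = {"id": 123, "key2": "value", "key3": "extra"}
assert buffer._normalize_record("users", record) == {"id": 123, "key1": "", "key2": "value"}

# add_to_buffer() stores the normalized values as strings, in sorted key order.
buffer.add_to_buffer("users", {"id": 123, "key2": "value", "key3": "extra"})
assert buffer.records_buffer["users"] == [["123", "", "value"]]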