def test_create_operator_with_correct_parameters_postgres_proxy_tcp(
         self, get_connections):
     connection = Connection()
     connection.parse_from_uri(
         "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type=postgres&"
         "project_id=example-project&location=europe-west1&instance=testdb&"
         "use_proxy=True&sql_proxy_use_tcp=True")
     get_connections.return_value = [connection]
     operator = CloudSqlQueryOperator(sql=['SELECT * FROM TABLE'],
                                      task_id='task_id')
     operator.cloudsql_db_hook.create_connection()
     try:
         db_hook = operator.cloudsql_db_hook.get_database_hook()
         conn = db_hook._get_connections_from_db(
             db_hook.postgres_conn_id)[0]
     finally:
         operator.cloudsql_db_hook.delete_connection()
     self.assertEqual('postgres', conn.conn_type)
     self.assertEqual('127.0.0.1', conn.host)
     self.assertNotEqual(3200, conn.port)
     self.assertEqual('testdb', conn.schema)
Example #2
    def test_get_conn_basic_auth(self, mock_get_connection, mock_connect,
                                 mock_basic_auth):
        mock_get_connection.return_value = Connection(login='******',
                                                      password='******',
                                                      host='host',
                                                      schema='hive')

        conn = PrestoHook().get_conn()
        mock_connect.assert_called_once_with(
            catalog='hive',
            host='host',
            port=None,
            http_scheme='http',
            schema='hive',
            source='airflow',
            user='******',
            isolation_level=0,
            auth=mock_basic_auth.return_value,
        )
        mock_basic_auth.assert_called_once_with('login', 'password')
        assert mock_connect.return_value == conn
Example #3
def mock_livy_session_responses(
    mocker: Mock,
    mock_create: Iterable[MockedResponse] = None,
    mock_get_session: Iterable[MockedResponse] = None,
    mock_post_statement: Iterable[MockedResponse] = None,
    mock_get_statement: Iterable[MockedResponse] = None,
    log_lines=5,
    log_override_response=None,
    mock_delete: int = None,
):
    _mock_create_response(mock_create)
    _mock_get_session_response(mock_get_session)
    _mock_post_statement_response(mock_post_statement)
    _mock_get_statement_response(mock_get_statement)
    _mock_log_response(log_override_response, log_lines)
    _mock_delete_response(mock_delete)
    mocker.patch.object(
        HttpHook,
        "get_connection",
        return_value=Connection(host=HOST, port=PORT),
    )
Example #4
 def test_serialize(self, session):
     connection_model = Connection(
         conn_id='mysql_default',
         conn_type='mysql',
         host='mysql',
         login='******',
         schema='testschema',
         port=80,
     )
     session.add(connection_model)
     session.commit()
     connection_model = session.query(Connection).first()
     deserialized_connection = connection_collection_item_schema.dump(connection_model)
     assert deserialized_connection == {
         'connection_id': "mysql_default",
         'conn_type': 'mysql',
         'host': 'mysql',
         'login': '******',
         'schema': 'testschema',
         'port': 80,
     }
Example #5
 def test_create_operator_with_wrong_parameters(self, project_id, location,
                                                instance_name,
                                                database_type, use_proxy,
                                                use_ssl, sql, message,
                                                get_connections):
     connection = Connection()
     connection.parse_from_uri(
         "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type={database_type}&"
         "project_id={project_id}&location={location}&instance={instance_name}&"
         "use_proxy={use_proxy}&use_ssl={use_ssl}".format(
             database_type=database_type,
             project_id=project_id,
             location=location,
             instance_name=instance_name,
             use_proxy=use_proxy,
             use_ssl=use_ssl))
     get_connections.return_value = [connection]
     with self.assertRaises(AirflowException) as cm:
         CloudSqlQueryOperator(sql=sql, task_id='task_id')
     err = cm.exception
     self.assertIn(message, str(err))
Example #6
def mock_livy_batch_responses(
    mocker: Mock,
    mock_create: Iterable[MockedResponse] = None,
    mock_get: Iterable[MockedResponse] = None,
    mock_spark: Iterable[MockedResponse] = None,
    mock_yarn: Iterable[MockedResponse] = None,
    log_lines=5,
    log_override_response=None,
    mock_delete: int = None,
):
    _mock_create_response(mock_create)
    _mock_get_response(mock_get)
    _mock_spark_response(mock_spark)
    _mock_yarn_response(mock_yarn)
    _mock_log_response(log_override_response, log_lines)
    _mock_delete_response(mock_delete)
    mocker.patch.object(
        BaseHook,
        "_get_connections_from_db",
        return_value=[Connection(host=HOST, port=PORT)],
    )
Example #7
    def registered(self, driver, frameworkId, masterInfo):
        self.log.info("AirflowScheduler registered to Mesos with framework ID %s", frameworkId.value)

        if configuration.conf.getboolean('mesos', 'CHECKPOINT') and \
                configuration.conf.get('mesos', 'FAILOVER_TIMEOUT'):
            # Import here to work around a circular import error
            from airflow.models import Connection

            # Update the Framework ID in the database.
            session = Session()
            conn_id = FRAMEWORK_CONNID_PREFIX + get_framework_name()
            connection = Session.query(Connection).filter_by(conn_id=conn_id).first()
            if connection is None:
                connection = Connection(conn_id=conn_id, conn_type='mesos_framework-id',
                                        extra=frameworkId.value)
            else:
                connection.extra = frameworkId.value

            session.add(connection)
            session.commit()
            Session.remove()
Example #8
 def get_connection(self, connection_id):
     try:
         conn = super().get_connection(connection_id)
         if not conn.schema:
             self.log.warn(
                 "Connection schema for {} was not set, setting https".
                 format(connection_id))
             conn.schema = "https"
         return conn
     except AirflowException:
         self.log.warn(
             "Didn't find connection with ID: {} - falling back to configuration"
             .format(connection_id))
         hopsworks_host = configuration.conf.get("webserver",
                                                 "hopsworks_host")
         hopsworks_port = configuration.conf.getint("webserver",
                                                    "hopsworks_port")
         return Connection(conn_id=connection_id,
                           schema="https",
                           host=self._parse_host(hopsworks_host),
                           port=hopsworks_port)
Example #9
def provide_wasb_default_connection(key_file_path: str):
    """
    Context manager to provide a temporary value for the wasb_default connection.

    :param key_file_path: Path to the JSON file with wasb_default credentials.
    :type key_file_path: str
    """
    if not key_file_path.endswith(".json"):
        raise AirflowException("Use a JSON key file.")
    with open(key_file_path) as credentials:
        creds = json.load(credentials)
    conn = Connection(
        conn_id=WASB_CONNECTION_ID,
        conn_type="wasb",
        host=creds.get("host", None),
        login=creds.get("login", None),
        password=creds.get("password", None),
        extra=json.dumps(creds.get('extra', None)),
    )
    with patch_environ({f"AIRFLOW_CONN_{conn.conn_id.upper()}": conn.get_uri()}):
        yield
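
# Usage sketch (not part of the original example): assumes provide_wasb_default_connection is
# wrapped with contextlib.contextmanager, as its docstring implies; the path is a placeholder.
with provide_wasb_default_connection("/path/to/wasb_default.json"):
    # While the block is active, the connection is exposed through its
    # AIRFLOW_CONN_<conn_id> environment variable, so any hook resolving that
    # conn_id picks up the temporary credentials.
    ...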
Example #10
 def setUp(self):
     # set up some test variables
     self.test_end_point = 'https://test_endpoint:443'
     self.test_master_key = 'magic_test_key'
     self.test_database_name = 'test_database_name'
     self.test_collection_name = 'test_collection_name'
     self.test_database_default = 'test_database_default'
     self.test_collection_default = 'test_collection_default'
     db.merge_conn(
         Connection(
             conn_id='azure_cosmos_test_key_id',
             conn_type='azure_cosmos',
             login=self.test_end_point,
             password=self.test_master_key,
             extra=json.dumps({
                 'database_name':
                 self.test_database_default,
                 'collection_name':
                 self.test_collection_default,
             }),
         ))
Example #11
def create_aws_credentials(dwh):
    """ Creates and configures an AWS connection.
    
    
    Parameters
    ----------
    dwh: dict
        A dictionary with the required parameters.
    
    Returns
    -------
    conn: airflow.models.Connection
        An instance of airflow.models.Connection.
    """
    conn = Connection(
        conn_id="aws_credentials",
        conn_type="aws",
        login=dwh['aws']['access_key_id'],
        password=dwh['aws']['secret_access_key'],
    )

    return conn
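
# Usage sketch (not from the original repo): create_aws_credentials only builds the Connection;
# persisting it with an Airflow session mirrors the create_postgres_connection example further
# below. The dwh dict here is a placeholder.
from airflow import settings

dwh = {"aws": {"access_key_id": "<ACCESS_KEY_ID>", "secret_access_key": "<SECRET_ACCESS_KEY>"}}

conn = create_aws_credentials(dwh)
session = settings.Session()
session.add(conn)       # persist the "aws_credentials" connection for hooks to resolve
session.commit()
session.close()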
Example #12
    def test_cli_delete_connections(self, session=None):
        merge_conn(Connection(conn_id="new1",
                              conn_type="mysql",
                              host="mysql",
                              login="******",
                              password="",
                              schema="airflow"),
                   session=session)
        # Delete connections
        with redirect_stdout(io.StringIO()) as stdout:
            connection_command.connections_delete(
                self.parser.parse_args(["connections", "delete", "new1"]))
            stdout = stdout.getvalue()

        # Check deletion stdout
        self.assertIn("\tSuccessfully deleted `conn_id`=new1", stdout)

        # Check deletions
        result = session.query(Connection).filter(
            Connection.conn_id == "new1").first()

        self.assertTrue(result is None)
Example #13
    def setUp(self):
        db.merge_conn(
            Connection(conn_id='azure_container_instance_test',
                       conn_type='azure_container_instances',
                       login='******',
                       password='******',
                       extra=json.dumps({
                           'tenantId': 'tenant_id',
                           'subscriptionId': 'subscription_id'
                       })))

        self.resources = ResourceRequirements(
            requests=ResourceRequests(memory_in_gb='4', cpu='1'))
        with patch(
                'azure.common.credentials.ServicePrincipalCredentials.__init__',
                autospec=True,
                return_value=None):
            with patch(
                    'azure.mgmt.containerinstance.ContainerInstanceManagementClient'
            ):
                self.hook = AzureContainerInstanceHook(
                    conn_id='azure_container_instance_test')
Example #14
    def load_connections(
        config: dict,
        session: Session = None,
    ):
        connections = config.get("connections", None)
        if connections is None:
            log.info("No connections found, skipping")
            return

        log.info("Loading variabels from config...")
        for key in connections.keys():
            val: dict = connections.get(key)
            if not isinstance(val, dict):
                log.warning(
                    f"Connection {key} skipped. Value must be a dictionary.")
                continue

            connection = session.query(Connection).filter_by(
                conn_id=key).first()
            if connection is not None:
                log.info(f"Connection exists, skipping: {key}")
                continue

            log.info("Setting connection: " + key)
            extra = val.get("extra", None)
            if extra is not None and not isinstance(extra, (int, str)):
                extra = json.dumps(extra)

            connection = Connection(
                conn_id=key,
                conn_type=val.get("conn_type", None),
                host=val.get("host", None),
                login=val.get("login", None),
                password=val.get("password", None),
                schema=val.get("schema", None),
                port=val.get("port", None),
                extra=extra,
            )
            session.add(connection)
        session.commit()
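
# Sketch (not part of the original): the config shape this loader expects, and a direct call.
# The session plumbing is an assumption; the original may rely on a provide_session-style
# decorator that the snippet omits.
from airflow import settings

config = {
    "connections": {
        "my_postgres": {
            "conn_type": "postgres",
            "host": "localhost",
            "login": "<LOGIN>",
            "password": "<PASSWORD>",
            "schema": "mydb",
            "port": 5432,
            "extra": {"sslmode": "require"},  # non-str extras are json.dumps-ed by the loader
        },
    },
}

load_connections(config, session=settings.Session())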
Example #15
 def setUpClass(cls):
     db.merge_conn(
         Connection(conn_id='livy_default',
                    conn_type='http',
                    host='host',
                    schema='http',
                    port='8998'))
     db.merge_conn(
         Connection(conn_id='default_port',
                    conn_type='http',
                    host='http://host'))
     db.merge_conn(
         Connection(conn_id='default_protocol',
                    conn_type='http',
                    host='host'))
     db.merge_conn(
         Connection(conn_id='port_set',
                    host='host',
                    conn_type='http',
                    port=1234))
     db.merge_conn(
         Connection(conn_id='schema_set',
                    host='host',
                    conn_type='http',
                    schema='zzz'))
     db.merge_conn(
         Connection(conn_id='dont_override_schema',
                    conn_type='http',
                    host='http://host',
                    schema='zzz'))
     db.merge_conn(
         Connection(conn_id='missing_host', conn_type='http', port=1234))
     db.merge_conn(
         Connection(conn_id='invalid_uri',
                    conn_type='http',
                    uri='http://invalid_uri:4321'))
Example #16
    def test_connection_extra_with_encryption_rotate_fernet_key(self, mock_get):
        """
        Tests rotating encrypted extras.
        """
        key1 = Fernet.generate_key()
        key2 = Fernet.generate_key()

        mock_get.return_value = key1.decode()
        test_connection = Connection(extra='testextra')
        self.assertTrue(test_connection.is_extra_encrypted)
        self.assertEqual(test_connection.extra, 'testextra')
        self.assertEqual(Fernet(key1).decrypt(test_connection._extra.encode()), b'testextra')

        # Test decrypt of old value with new key
        mock_get.return_value = ','.join([key2.decode(), key1.decode()])
        crypto._fernet = None
        self.assertEqual(test_connection.extra, 'testextra')

        # Test decrypt of new value with new key
        test_connection.rotate_fernet_key()
        self.assertTrue(test_connection.is_extra_encrypted)
        self.assertEqual(test_connection.extra, 'testextra')
        self.assertEqual(Fernet(key2).decrypt(test_connection._extra.encode()), b'testextra')
Example #17
 def test_serialize(self, session):
     connection_model = Connection(conn_id='mysql_default',
                                   conn_type='mysql',
                                   host='mysql',
                                   login='******',
                                   schema='testschema',
                                   port=80,
                                   password='******',
                                   extra="{'key':'string'}")
     session.add(connection_model)
     session.commit()
     connection_model = session.query(Connection).first()
     deserialized_connection = connection_schema.dump(connection_model)
     self.assertEqual(
         deserialized_connection[0], {
             'connection_id': "mysql_default",
             'conn_type': 'mysql',
             'host': 'mysql',
             'login': '******',
             'schema': 'testschema',
             'port': 80,
             'extra': "{'key':'string'}"
         })
Example #18
def create_redshift_connection(dwh):
    """ Creates and configures a Redshift connection.
    
    
    Parameters
    ----------
    dwh: dict
        A dictionary with the required parameters.
    
    Returns
    -------
    conn: airflow.models.Connection
        An instance of airflow.models.Connection.
    """
    conn = Connection(conn_id="redshift",
                      conn_type="postgresql",
                      host=dwh['redshift']['host'],
                      schema=dwh['redshift']['db_name'],
                      login=dwh['redshift']['db_user'],
                      password=dwh['redshift']['db_pass'],
                      port=dwh['redshift']['db_port'])

    return conn
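
# Sketch (not from the original repo): besides persisting it through a session, the returned
# Connection can be exposed through an environment variable, the same mechanism several helpers
# above rely on. The dwh dict below is a placeholder (5439 is the default Redshift port).
import os

dwh = {"redshift": {"host": "<HOST>", "db_name": "mydb", "db_user": "<USER>",
                    "db_pass": "<PASSWORD>", "db_port": 5439}}

conn = create_redshift_connection(dwh)
# Env-var connections (AIRFLOW_CONN_<CONN_ID>) take precedence over the metadata
# database when hooks resolve the "redshift" conn_id.
os.environ["AIRFLOW_CONN_REDSHIFT"] = conn.get_uri()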
Example #19
    def test_connection_extra_with_encryption_rotate_fernet_key(self):
        """
        Tests rotating encrypted extras.
        """
        key1 = Fernet.generate_key()
        key2 = Fernet.generate_key()

        with conf_vars({('core', 'fernet_key'): key1.decode()}):
            test_connection = Connection(extra='testextra')
            assert test_connection.is_extra_encrypted
            assert test_connection.extra == 'testextra'
            assert Fernet(key1).decrypt(test_connection._extra.encode()) == b'testextra'

        # Test decrypt of old value with new key
        with conf_vars({('core', 'fernet_key'): ','.join([key2.decode(), key1.decode()])}):
            crypto._fernet = None
            assert test_connection.extra == 'testextra'

            # Test decrypt of new value with new key
            test_connection.rotate_fernet_key()
            assert test_connection.is_extra_encrypted
            assert test_connection.extra == 'testextra'
            assert Fernet(key2).decrypt(test_connection._extra.encode()) == b'testextra'
Example #20
    def test_run_example_gcp_vision_autogenerated_id_dag(self):
        mock_connection = Connection(
            conn_type="aws",
            extra=json.dumps({
                "role_arn":
                ROLE_ANR,
                "assume_role_method":
                "assume_role_with_web_identity",
                "assume_role_with_web_identity_federation":
                'google',
                "assume_role_with_web_identity_federation_audience":
                AUDIENCE,
            }),
        )

        with mock.patch.dict(
                'os.environ',
                AIRFLOW_CONN_AWS_DEFAULT=mock_connection.get_uri()):
            hook = AwsBaseHook(client_type='s3')

            client = hook.get_conn()
            response = client.list_buckets()
            assert 'Buckets' in response
Example #21
def provide_facebook_connection(key_file_path: str):
    """
    Context manager that provides a temporary Facebook connection through an
    environment variable, built from the credentials in the provided JSON file.

    :param key_file_path: Path to the JSON file with Facebook credentials.
    :type key_file_path: str
    """
    if not key_file_path.endswith(".json"):
        raise AirflowException("Use a JSON key file.")
    with open(key_file_path) as credentials:
        creds = json.load(credentials)
    missing_keys = CONFIG_REQUIRED_FIELDS - creds.keys()
    if missing_keys:
        message = f"{missing_keys} fields are missing"
        raise AirflowException(message)
    conn = Connection(conn_id=FACEBOOK_CONNECTION_ID,
                      conn_type=CONNECTION_TYPE,
                      extra=json.dumps(creds))
    with patch_environ(
        {f"AIRFLOW_CONN_{conn.conn_id.upper()}": conn.get_uri()}):
        yield
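
# Usage sketch (not part of the original example): again assumes the function is wrapped with
# contextlib.contextmanager; the key file path is a placeholder.
with provide_facebook_connection("/path/to/facebook_credentials.json"):
    # Within the block the Facebook connection is available through its
    # AIRFLOW_CONN_* environment variable, e.g. for the Facebook Ads hook.
    ...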
Example #22
    def test_connection_extra_with_encryption_rotate_fernet_key(self):
        """
        Tests rotating encrypted extras.
        """
        key1 = Fernet.generate_key()
        key2 = Fernet.generate_key()

        with conf_vars({('core', 'FERNET_KEY'): key1.decode()}):
            test_connection = Connection(extra='testextra')
            self.assertTrue(test_connection.is_extra_encrypted)
            self.assertEqual(test_connection.extra, 'testextra')
            self.assertEqual(Fernet(key1).decrypt(test_connection._extra.encode()), b'testextra')

        # Test decrypt of old value with new key
        with conf_vars({('core', 'FERNET_KEY'): ','.join([key2.decode(), key1.decode()])}):
            crypto._fernet = None
            self.assertEqual(test_connection.extra, 'testextra')

            # Test decrypt of new value with new key
            test_connection.rotate_fernet_key()
            self.assertTrue(test_connection.is_extra_encrypted)
            self.assertEqual(test_connection.extra, 'testextra')
            self.assertEqual(Fernet(key2).decrypt(test_connection._extra.encode()), b'testextra')
Example #23
 def test_create_operator_with_correct_parameters_mysql_proxy_socket(
         self, get_connections):
     connection = Connection()
     connection.parse_from_uri(
         "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type=mysql&"
         "project_id=example-project&location=europe-west1&instance=testdb&"
         "use_proxy=True&sql_proxy_use_tcp=False")
     get_connections.return_value = [connection]
     operator = CloudSqlQueryOperator(sql=['SELECT * FROM TABLE'],
                                      task_id='task_id')
     operator.cloudsql_db_hook.create_connection()
     try:
         db_hook = operator.cloudsql_db_hook.get_database_hook()
         conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0]
     finally:
         operator.cloudsql_db_hook.delete_connection()
     self.assertEqual('mysql', conn.conn_type)
     self.assertEqual('localhost', conn.host)
     self.assertIn('/tmp', conn.extra_dejson['unix_socket'])
     self.assertIn('example-project:europe-west1:testdb',
                   conn.extra_dejson['unix_socket'])
     self.assertIsNone(conn.port)
     self.assertEqual('testdb', conn.schema)
Example #24
 def test_should_response_200(self, session):
     connection_model = Connection(conn_id='test-connection-id',
                                   conn_type='mysql',
                                   host='mysql',
                                   login='******',
                                   schema='testschema',
                                   port=80)
     session.add(connection_model)
     session.commit()
     result = session.query(Connection).all()
     assert len(result) == 1
     response = self.client.get("/api/v1/connections/test-connection-id")
     assert response.status_code == 200
     self.assertEqual(
         response.json,
         {
             "connection_id": "test-connection-id",
             "conn_type": 'mysql',
             "host": 'mysql',
             "login": '******',
             'schema': 'testschema',
             'port': 80
         },
     )
Example #25
def create_postgres_connection(conn_id,
                               host,
                               schema,
                               login,
                               password,
                               port=5432,
                               **kwargs):
    """
    Creates a Postgres Connection (for Hooks to use) for the Airflow session

    :param conn_id: Str - name of Airflow Connection
    :param host: Str - db host
    :param schema: Str - db name
    :param login: Str - db login
    :param password: Str - db password
    :param port: Int - db port (default 5432)
    :param kwargs: Dict - keyword arguments
    :return: None
    """
    logging.info('Creating Postgres Connection...')

    # Create connection to our Postgres instance
    pg_connection = Connection(conn_id=conn_id,
                               conn_type='postgres',
                               host=host,
                               schema=schema,
                               login=login,
                               password=password,
                               port=port)

    # Add the Connection to the Airflow session
    session = settings.Session()
    session.add(pg_connection)
    session.commit()
    logging.info('Successfully created Postgres Connection')
    session.close()
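
# Call sketch with placeholder values (not from the original repo); unlike the create_* helpers
# above, this function commits the connection to the metadata database itself.
create_postgres_connection(conn_id="my_postgres",
                           host="localhost",
                           schema="mydb",      # Airflow stores the database name in `schema`
                           login="<LOGIN>",
                           password="<PASSWORD>",
                           port=5432)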
Example #26
def test_connection() -> APIResponse:
    """
    To test a connection, this method first creates an in-memory dummy conn_id and exports it to an
    env var, because some hook classes try to look up the connection from their __init__ method and
    error out if it is not found. The dummy conn_id's env variable is deleted again after the test.
    """
    body = request.json
    dummy_conn_id = get_random_string()
    conn_env_var = f'{CONN_ENV_PREFIX}{dummy_conn_id.upper()}'
    try:
        data = connection_schema.load(body)
        data['conn_id'] = dummy_conn_id
        conn = Connection(**data)
        os.environ[conn_env_var] = conn.get_uri()
        status, message = conn.test_connection()
        return connection_test_schema.dump({
            "status": status,
            "message": message
        })
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))
    finally:
        if conn_env_var in os.environ:
            del os.environ[conn_env_var]
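
# Sketch (not part of the original endpoint): the env-var trick described in the docstring relies
# on Connection's URI round-trip. The conn_id and values below are illustrative only.
import os
from airflow.models import Connection

conn = Connection(conn_id="dummy_abc123", conn_type="postgres", host="localhost",
                  login="user", password="pass", schema="db")
os.environ["AIRFLOW_CONN_DUMMY_ABC123"] = conn.get_uri()
# A hook resolving conn_id "dummy_abc123" now rebuilds the Connection from this URI,
# which is why the endpoint can test arbitrary payloads without touching the DB.
del os.environ["AIRFLOW_CONN_DUMMY_ABC123"]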
Example #27
def add_docker_connection(ds, **kwargs):
    """"Add a airflow connection for google container registry"""
    new_conn = Connection(
        conn_id="gcr_docker_connection",  # TODO: parameterize
        conn_type="docker",
        host="gcr.io/wam-bam-258119",  # TODO: parameterize
        login="******",  # TODO: parameterize
    )

    # save contents of service account key into encrypted password field
    with open("service_account.json", "r") as file:
        data = file.read().replace("\n", "")  # replace new lines
        new_conn.set_password(data)

    session = settings.Session()
    if not (
        session.query(Connection).filter(Connection.conn_id == new_conn.conn_id).first()
    ):
        session.add(new_conn)
        session.commit()
    else:
        msg = "\n\tA connection with `conn_id`={conn_id} already exists\n"
        msg = msg.format(conn_id=new_conn.conn_id)
        print(msg)
Example #28
def create_default_connections(session=None):
    """Create default Airflow connections."""
    merge_conn(
        Connection(
            conn_id="airflow_db",
            conn_type="mysql",
            host="mysql",
            login="******",
            password="",
            schema="airflow",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="aws_default",
            conn_type="aws",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="azure_batch_default",
            conn_type="azure_batch",
            login="******",
            password="",
            extra='''{"account_url": "<ACCOUNT_URL>"}''',
        ))
    merge_conn(
        Connection(
            conn_id="azure_container_instances_default",
            conn_type="azure_container_instances",
            extra=
            '{"tenantId": "<TENANT>", "subscriptionId": "<SUBSCRIPTION ID>" }',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="azure_cosmos_default",
            conn_type="azure_cosmos",
            extra=
            '{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id='azure_data_explorer_default',
            conn_type='azure_data_explorer',
            host='https://<CLUSTER>.kusto.windows.net',
            extra=
            '''{"auth_method": "<AAD_APP | AAD_APP_CERT | AAD_CREDS | AAD_DEVICE>",
                    "tenant": "<TENANT ID>", "certificate": "<APPLICATION PEM CERTIFICATE>",
                    "thumbprint": "<APPLICATION CERTIFICATE THUMBPRINT>"}''',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="azure_data_lake_default",
            conn_type="azure_data_lake",
            extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="cassandra_default",
            conn_type="cassandra",
            host="cassandra",
            port=9042,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="databricks_default",
            conn_type="databricks",
            host="localhost",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="dingding_default",
            conn_type="http",
            host="",
            password="",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="druid_broker_default",
            conn_type="druid",
            host="druid-broker",
            port=8082,
            extra='{"endpoint": "druid/v2/sql"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="druid_ingest_default",
            conn_type="druid",
            host="druid-overlord",
            port=8081,
            extra='{"endpoint": "druid/indexer/v1/task"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="elasticsearch_default",
            conn_type="elasticsearch",
            host="localhost",
            schema="http",
            port=9200,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="emr_default",
            conn_type="emr",
            extra="""
                {   "Name": "default_job_flow_name",
                    "LogUri": "s3://my-emr-log-bucket/default_job_flow_location",
                    "ReleaseLabel": "emr-4.6.0",
                    "Instances": {
                        "Ec2KeyName": "mykey",
                        "Ec2SubnetId": "somesubnet",
                        "InstanceGroups": [
                            {
                                "Name": "Master nodes",
                                "Market": "ON_DEMAND",
                                "InstanceRole": "MASTER",
                                "InstanceType": "r3.2xlarge",
                                "InstanceCount": 1
                            },
                            {
                                "Name": "Slave nodes",
                                "Market": "ON_DEMAND",
                                "InstanceRole": "CORE",
                                "InstanceType": "r3.2xlarge",
                                "InstanceCount": 1
                            }
                        ],
                        "TerminationProtected": false,
                        "KeepJobFlowAliveWhenNoSteps": false
                    },
                    "Applications":[
                        { "Name": "Spark" }
                    ],
                    "VisibleToAllUsers": true,
                    "JobFlowRole": "EMR_EC2_DefaultRole",
                    "ServiceRole": "EMR_DefaultRole",
                    "Tags": [
                        {
                            "Key": "app",
                            "Value": "analytics"
                        },
                        {
                            "Key": "environment",
                            "Value": "development"
                        }
                    ]
                }
            """,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="facebook_default",
            conn_type="facebook_social",
            extra="""
                {   "account_id": "<AD_ACCOUNT_ID>",
                    "app_id": "<FACEBOOK_APP_ID>",
                    "app_secret": "<FACEBOOK_APP_SECRET>",
                    "access_token": "<FACEBOOK_AD_ACCESS_TOKEN>"
                }
            """,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="fs_default",
            conn_type="fs",
            extra='{"path": "/"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="google_cloud_default",
            conn_type="google_cloud_platform",
            schema="default",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="hive_cli_default",
            conn_type="hive_cli",
            port=10000,
            host="localhost",
            extra='{"use_beeline": true, "auth": ""}',
            schema="default",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="hiveserver2_default",
            conn_type="hiveserver2",
            host="localhost",
            schema="default",
            port=10000,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="http_default",
            conn_type="http",
            host="https://www.httpbin.org/",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id='kubernetes_default',
            conn_type='kubernetes',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id='kylin_default',
            conn_type='kylin',
            host='localhost',
            port=7070,
            login="******",
            password="******",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="leveldb_default",
            conn_type="leveldb",
            host="localhost",
        ),
        session,
    )
    merge_conn(
        Connection(conn_id="livy_default",
                   conn_type="livy",
                   host="livy",
                   port=8998), session)
    merge_conn(
        Connection(
            conn_id="local_mysql",
            conn_type="mysql",
            host="localhost",
            login="******",
            password="******",
            schema="airflow",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="metastore_default",
            conn_type="hive_metastore",
            host="localhost",
            extra='{"authMechanism": "PLAIN"}',
            port=9083,
        ),
        session,
    )
    merge_conn(
        Connection(conn_id="mongo_default",
                   conn_type="mongo",
                   host="mongo",
                   port=27017), session)
    merge_conn(
        Connection(
            conn_id="mssql_default",
            conn_type="mssql",
            host="localhost",
            port=1433,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="mysql_default",
            conn_type="mysql",
            login="******",
            schema="airflow",
            host="mysql",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="opsgenie_default",
            conn_type="http",
            host="",
            password="",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="pig_cli_default",
            conn_type="pig_cli",
            schema="default",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="pinot_admin_default",
            conn_type="pinot",
            host="localhost",
            port=9000,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="pinot_broker_default",
            conn_type="pinot",
            host="localhost",
            port=9000,
            extra='{"endpoint": "/query", "schema": "http"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="postgres_default",
            conn_type="postgres",
            login="******",
            password="******",
            schema="airflow",
            host="postgres",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="presto_default",
            conn_type="presto",
            host="localhost",
            schema="hive",
            port=3400,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="qubole_default",
            conn_type="qubole",
            host="localhost",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="redis_default",
            conn_type="redis",
            host="redis",
            port=6379,
            extra='{"db": 0}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="segment_default",
            conn_type="segment",
            extra='{"write_key": "my-segment-write-key"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="sftp_default",
            conn_type="sftp",
            host="localhost",
            port=22,
            login="******",
            extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="spark_default",
            conn_type="spark",
            host="yarn",
            extra='{"queue": "root.default"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="sqlite_default",
            conn_type="sqlite",
            host="/tmp/sqlite_default.db",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="sqoop_default",
            conn_type="sqoop",
            host="rdbms",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="ssh_default",
            conn_type="ssh",
            host="localhost",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="tableau_default",
            conn_type="tableau",
            host="https://tableau.server.url",
            login="******",
            password="******",
            extra='{"site_id": "my_site"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="vertica_default",
            conn_type="vertica",
            host="localhost",
            port=5433,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="wasb_default",
            conn_type="wasb",
            extra='{"sas_token": null}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="webhdfs_default",
            conn_type="hdfs",
            host="localhost",
            port=50070,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id='yandexcloud_default',
            conn_type='yandexcloud',
            schema='default',
        ),
        session,
    )
Example #29
import mock
from parameterized import parameterized

from airflow.models import Connection
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
from airflow.providers.microsoft.mssql.operators.mssql import MsSqlOperator
from airflow.providers.odbc.hooks.odbc import OdbcHook

ODBC_CONN = Connection(
    conn_id='test-odbc',
    conn_type='odbc',
)
PYMSSQL_CONN = Connection(
    conn_id='test-pymssql',
    conn_type='anything',
)


class TestMsSqlOperator:
    @parameterized.expand([(ODBC_CONN, OdbcHook), (PYMSSQL_CONN, MsSqlHook)])
    @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
    def test_get_hook(self, conn, hook_class, get_connection):
        """
        Operator should use odbc hook if conn type is ``odbc`` and pymssql-based hook otherwise.
        """
Example #30
 def _get_connection_from_env(cls, conn_id):
     environment_uri = os.environ.get(CONN_ENV_PREFIX + conn_id.upper())
     conn = None
     if environment_uri:
         conn = Connection(conn_id=conn_id, uri=environment_uri)
     return conn