def test_create_operator_with_correct_parameters_postgres_proxy_tcp(self, get_connections):
    """
    Postgres Cloud SQL connection using the proxy over TCP.

    The connection produced by the database hook must be of type
    ``postgres``, point at ``127.0.0.1``, use a port other than the
    original 3200 and keep the ``testdb`` schema.
    """
    cloudsql_uri = (
        "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type=postgres&"
        "project_id=example-project&location=europe-west1&instance=testdb&"
        "use_proxy=True&sql_proxy_use_tcp=True"
    )
    source_connection = Connection()
    source_connection.parse_from_uri(cloudsql_uri)
    get_connections.return_value = [source_connection]

    operator = CloudSqlQueryOperator(sql=['SELECT * FROM TABLE'], task_id='task_id')
    operator.cloudsql_db_hook.create_connection()
    try:
        database_hook = operator.cloudsql_db_hook.get_database_hook()
        resolved = database_hook._get_connections_from_db(database_hook.postgres_conn_id)
        conn = resolved[0]
    finally:
        # Always drop the temporary connection, even when the lookup fails.
        operator.cloudsql_db_hook.delete_connection()

    self.assertEqual('postgres', conn.conn_type)
    self.assertEqual('127.0.0.1', conn.host)
    self.assertNotEqual(3200, conn.port)
    self.assertEqual('testdb', conn.schema)
def test_get_conn_basic_auth(self, mock_get_connection, mock_connect, mock_basic_auth):
    """
    ``PrestoHook.get_conn`` must pass the connection's credentials on.

    The expected user / basic-auth arguments are derived from the fixture
    connection itself so the assertions cannot drift from the fixture.
    """
    connection = Connection(login='******', password='******', host='host', schema='hive')
    mock_get_connection.return_value = connection

    conn = PrestoHook().get_conn()

    mock_connect.assert_called_once_with(
        catalog='hive',
        host='host',
        port=None,
        http_scheme='http',
        schema='hive',
        source='airflow',
        user=connection.login,
        isolation_level=0,
        auth=mock_basic_auth.return_value,
    )
    # Previously asserted against literal 'login'/'password', which did not
    # match the credentials configured on the fixture connection above.
    mock_basic_auth.assert_called_once_with(connection.login, connection.password)
    assert mock_connect.return_value == conn
def mock_livy_session_responses(
    mocker: Mock,
    mock_create: Iterable[MockedResponse] = None,
    mock_get_session: Iterable[MockedResponse] = None,
    mock_post_statement: Iterable[MockedResponse] = None,
    mock_get_statement: Iterable[MockedResponse] = None,
    log_lines=5,
    log_override_response=None,
    mock_delete: int = None,
):
    """
    Set up the mocked responses for a Livy session test.

    Each response collection is forwarded to its matching ``_mock_*``
    helper, then ``HttpHook.get_connection`` is patched to return a
    connection built from ``HOST`` and ``PORT``.
    """
    registrations = (
        (_mock_create_response, mock_create),
        (_mock_get_session_response, mock_get_session),
        (_mock_post_statement_response, mock_post_statement),
        (_mock_get_statement_response, mock_get_statement),
    )
    for register, responses in registrations:
        register(responses)

    _mock_log_response(log_override_response, log_lines)
    _mock_delete_response(mock_delete)

    livy_connection = Connection(host=HOST, port=PORT)
    mocker.patch.object(HttpHook, "get_connection", return_value=livy_connection)
def test_serialize(self, session):
    """Dumping a stored connection through the collection-item schema yields its public fields."""
    expected = {
        'connection_id': "mysql_default",
        'conn_type': 'mysql',
        'host': 'mysql',
        'login': '******',
        'schema': 'testschema',
        'port': 80,
    }

    session.add(
        Connection(
            conn_id='mysql_default',
            conn_type='mysql',
            host='mysql',
            login='******',
            schema='testschema',
            port=80,
        )
    )
    session.commit()

    # Re-read the row so the schema is exercised on a persisted model.
    stored = session.query(Connection).first()
    assert connection_collection_item_schema.dump(stored) == expected
def test_create_operator_with_wrong_parameters(self, project_id, location, instance_name, database_type,
                                               use_proxy, use_ssl, sql, message, get_connections):
    """
    Building the operator from an invalid connection URI must fail.

    The URI is assembled from the parameterized values and the raised
    ``AirflowException`` is expected to contain ``message``.
    """
    uri_template = (
        "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type={database_type}&"
        "project_id={project_id}&location={location}&instance={instance_name}&"
        "use_proxy={use_proxy}&use_ssl={use_ssl}"
    )
    uri = uri_template.format(
        database_type=database_type,
        project_id=project_id,
        location=location,
        instance_name=instance_name,
        use_proxy=use_proxy,
        use_ssl=use_ssl,
    )
    bad_connection = Connection()
    bad_connection.parse_from_uri(uri)
    get_connections.return_value = [bad_connection]

    with self.assertRaises(AirflowException) as raised:
        CloudSqlQueryOperator(sql=sql, task_id='task_id')

    self.assertIn(message, str(raised.exception))
def mock_livy_batch_responses(
    mocker: Mock,
    mock_create: Iterable[MockedResponse] = None,
    mock_get: Iterable[MockedResponse] = None,
    mock_spark: Iterable[MockedResponse] = None,
    mock_yarn: Iterable[MockedResponse] = None,
    log_lines=5,
    log_override_response=None,
    mock_delete: int = None,
):
    """
    Set up the mocked responses for a Livy batch test.

    Each response collection is forwarded to its matching ``_mock_*``
    helper, then ``BaseHook._get_connections_from_db`` is patched to
    return a single connection built from ``HOST`` and ``PORT``.
    """
    registrations = (
        (_mock_create_response, mock_create),
        (_mock_get_response, mock_get),
        (_mock_spark_response, mock_spark),
        (_mock_yarn_response, mock_yarn),
    )
    for register, responses in registrations:
        register(responses)

    _mock_log_response(log_override_response, log_lines)
    _mock_delete_response(mock_delete)

    livy_connections = [Connection(host=HOST, port=PORT)]
    mocker.patch.object(BaseHook, "_get_connections_from_db", return_value=livy_connections)
def registered(self, driver, frameworkId, masterInfo):
    """
    Record the framework ID once the scheduler has registered with Mesos.

    When checkpointing is enabled and a failover timeout is configured,
    the framework ID is stored in the ``extra`` field of a connection
    named ``FRAMEWORK_CONNID_PREFIX + get_framework_name()``, creating the
    row if it does not exist yet.

    :param driver: unused by this callback
    :param frameworkId: object whose ``value`` holds the framework ID
    :param masterInfo: unused by this callback
    """
    self.log.info("AirflowScheduler registered to Mesos with framework ID %s", frameworkId.value)

    # Keep the short-circuit: FAILOVER_TIMEOUT is only read when CHECKPOINT is on.
    if not (configuration.conf.getboolean('mesos', 'CHECKPOINT') and
            configuration.conf.get('mesos', 'FAILOVER_TIMEOUT')):
        return

    # Import here to work around a circular import error
    from airflow.models import Connection

    # Update the Framework ID in the database.
    session = Session()
    try:
        conn_id = FRAMEWORK_CONNID_PREFIX + get_framework_name()
        # Query through the same session instance that is committed below.
        connection = session.query(Connection).filter_by(conn_id=conn_id).first()
        if connection is None:
            connection = Connection(conn_id=conn_id,
                                    conn_type='mesos_framework-id',
                                    extra=frameworkId.value)
        else:
            connection.extra = frameworkId.value

        session.add(connection)
        session.commit()
    finally:
        # Release the scoped session even when the query or commit fails.
        Session.remove()
def get_connection(self, connection_id):
    """
    Return the connection for ``connection_id``, falling back to configuration.

    If the parent lookup raises ``AirflowException``, a connection is built
    from the ``webserver.hopsworks_host`` / ``webserver.hopsworks_port``
    configuration options with schema ``https``. A found connection with
    an empty schema gets its schema set to ``https``.

    :param connection_id: ID of the connection to look up
    :return: the resolved or synthesized ``Connection``
    """
    try:
        # Only the lookup itself is expected to raise AirflowException.
        conn = super().get_connection(connection_id)
    except AirflowException:
        self.log.warning(
            "Didn't find connection with ID: %s - falling back to configuration",
            connection_id)
        hopsworks_host = configuration.conf.get("webserver", "hopsworks_host")
        hopsworks_port = configuration.conf.getint("webserver", "hopsworks_port")
        return Connection(conn_id=connection_id,
                          schema="https",
                          host=self._parse_host(hopsworks_host),
                          port=hopsworks_port)

    if not conn.schema:
        # Logger.warn is a deprecated alias; use warning() with lazy %-args.
        self.log.warning(
            "Connection schema for %s was not set, setting https", connection_id)
        conn.schema = "https"
    return conn
def provide_wasb_default_connection(key_file_path: str):
    """
    Temporarily expose a WASB connection through an environment variable.

    The credentials file is read as JSON, turned into a ``Connection`` and
    its URI is published as ``AIRFLOW_CONN_<CONN_ID>`` for the duration of
    the ``yield``.

    :param key_file_path: Path to the ``.json`` file holding the WASB credentials.
    :type key_file_path: str
    :raises AirflowException: if the path does not end with ``.json``
    """
    if not key_file_path.endswith(".json"):
        raise AirflowException("Use a JSON key file.")

    with open(key_file_path) as key_file:
        creds = json.load(key_file)

    extra_payload = json.dumps(creds.get('extra'))
    conn = Connection(
        conn_id=WASB_CONNECTION_ID,
        conn_type="wasb",
        host=creds.get("host"),
        login=creds.get("login"),
        password=creds.get("password"),
        extra=extra_payload,
    )

    env_overrides = {f"AIRFLOW_CONN_{conn.conn_id.upper()}": conn.get_uri()}
    with patch_environ(env_overrides):
        yield
def setUp(self):
    """Define the Cosmos test values and register the test connection."""
    # Endpoint and key used as the connection's login / password.
    self.test_end_point = 'https://test_endpoint:443'
    self.test_master_key = 'magic_test_key'
    # Explicit database / collection names used by the tests.
    self.test_database_name = 'test_database_name'
    self.test_collection_name = 'test_collection_name'
    # Defaults stored in the connection's extra field.
    self.test_database_default = 'test_database_default'
    self.test_collection_default = 'test_collection_default'

    connection_extra = json.dumps({
        'database_name': self.test_database_default,
        'collection_name': self.test_collection_default,
    })
    cosmos_connection = Connection(
        conn_id='azure_cosmos_test_key_id',
        conn_type='azure_cosmos',
        login=self.test_end_point,
        password=self.test_master_key,
        extra=connection_extra,
    )
    db.merge_conn(cosmos_connection)
def create_aws_credentials(dwh):
    """
    Build the ``aws_credentials`` Airflow connection.

    Parameters
    ----------
    dwh: dict
        Configuration dictionary; ``dwh['aws']['access_key_id']`` and
        ``dwh['aws']['secret_access_key']`` are read.

    Returns
    -------
    conn: airflow.models.Connection
        Connection of type ``aws`` with the key pair as login/password.
    """
    aws_settings = dwh['aws']
    access_key_id = aws_settings['access_key_id']
    secret_access_key = dwh['aws']['secret_access_key']
    return Connection(
        conn_id="aws_credentials",
        conn_type="aws",
        login=access_key_id,
        password=secret_access_key,
    )
def test_cli_delete_connections(self, session=None):
    """Deleting a connection via the CLI reports success and removes the row."""
    seeded = Connection(
        conn_id="new1",
        conn_type="mysql",
        host="mysql",
        login="******",
        password="",
        schema="airflow",
    )
    merge_conn(seeded, session=session)

    # Delete connections
    cli_args = ["connections", "delete", "new1"]
    with redirect_stdout(io.StringIO()) as captured:
        connection_command.connections_delete(self.parser.parse_args(cli_args))
    output = captured.getvalue()

    # Check deletion stdout
    self.assertIn("\tSuccessfully deleted `conn_id`=new1", output)

    # Check deletions
    remaining = session.query(Connection).filter(Connection.conn_id == "new1").first()
    self.assertTrue(remaining is None)
def setUp(self):
    """Register the test connection and build the hook with Azure clients patched out."""
    connection_extra = json.dumps({
        'tenantId': 'tenant_id',
        'subscriptionId': 'subscription_id'
    })
    db.merge_conn(
        Connection(
            conn_id='azure_container_instance_test',
            conn_type='azure_container_instances',
            login='******',
            password='******',
            extra=connection_extra,
        )
    )

    resource_requests = ResourceRequests(memory_in_gb='4', cpu='1')
    self.resources = ResourceRequirements(requests=resource_requests)

    # Patch the credentials constructor and the management client while the hook is created.
    credentials_patch = patch(
        'azure.common.credentials.ServicePrincipalCredentials.__init__',
        autospec=True,
        return_value=None,
    )
    with credentials_patch:
        with patch('azure.mgmt.containerinstance.ContainerInstanceManagementClient'):
            self.hook = AzureContainerInstanceHook(conn_id='azure_container_instance_test')
def load_connections(
    config: dict,
    session: Session = None,
):
    """
    Create Airflow connections from the ``connections`` section of ``config``.

    Each key of ``config["connections"]`` is a connection ID mapped to a
    dictionary of connection attributes (``conn_type``, ``host``, ``login``,
    ``password``, ``schema``, ``port``, ``extra``). Entries whose value is
    not a dictionary are skipped, as are connection IDs that already exist
    in the database. An ``extra`` value that is neither ``int`` nor ``str``
    is serialized with ``json.dumps``.

    :param config: parsed configuration; read-only
    :param session: SQLAlchemy session used for lookups and inserts.
        NOTE(review): defaults to ``None`` but is dereferenced
        unconditionally - presumably injected by a decorator not visible
        here; confirm against the caller.
    """
    connections = config.get("connections", None)
    if connections is None:
        log.info("No connections found, skipping")
        return

    log.info("Loading connections from config...")
    for key, val in connections.items():
        if not isinstance(val, dict):
            log.warning("Connection %s skipped. Value must be a dictionary.", key)
            # Actually skip the entry; falling through would crash on val.get().
            continue

        existing = session.query(Connection).filter_by(conn_id=key).first()
        if existing is not None:
            log.info("Connection exists, skipping: %s", key)
            continue

        log.info("Setting connection: %s", key)
        extra = val.get("extra", None)
        if extra is not None and not isinstance(extra, (int, str)):
            extra = json.dumps(extra)

        session.add(
            Connection(
                conn_id=key,
                conn_type=val.get("conn_type", None),
                host=val.get("host", None),
                login=val.get("login", None),
                password=val.get("password", None),
                schema=val.get("schema", None),
                port=val.get("port", None),
                extra=extra,
            )
        )

    # Persist all newly added connections in one transaction.
    session.commit()
def setUpClass(cls):
    """Register the HTTP-type connections used by the tests in this class."""
    connection_specs = (
        dict(conn_id='livy_default', conn_type='http', host='host', schema='http', port='8998'),
        dict(conn_id='default_port', conn_type='http', host='http://host'),
        dict(conn_id='default_protocol', conn_type='http', host='host'),
        dict(conn_id='port_set', host='host', conn_type='http', port=1234),
        dict(conn_id='schema_set', host='host', conn_type='http', schema='zzz'),
        dict(conn_id='dont_override_schema', conn_type='http', host='http://host', schema='zzz'),
        dict(conn_id='missing_host', conn_type='http', port=1234),
        dict(conn_id='invalid_uri', conn_type='http', uri='http://invalid_uri:4321'),
    )
    # Build and merge one connection at a time, in declaration order.
    for spec in connection_specs:
        db.merge_conn(Connection(**spec))
def test_connection_extra_with_encryption_rotate_fernet_key(self, mock_get):
    """
    Rotating the Fernet key re-encrypts a connection's extra field.
    """
    old_key = Fernet.generate_key()
    new_key = Fernet.generate_key()
    plaintext = b'testextra'

    # Encrypt with the old key only.
    mock_get.return_value = old_key.decode()
    conn = Connection(extra='testextra')
    self.assertTrue(conn.is_extra_encrypted)
    self.assertEqual(conn.extra, 'testextra')
    self.assertEqual(Fernet(old_key).decrypt(conn._extra.encode()), plaintext)

    # Test decrypt of old value with new key
    key_chain = [new_key.decode(), old_key.decode()]
    mock_get.return_value = ','.join(key_chain)
    crypto._fernet = None  # drop the cached Fernet object
    self.assertEqual(conn.extra, 'testextra')

    # Test decrypt of new value with new key
    conn.rotate_fernet_key()
    self.assertTrue(conn.is_extra_encrypted)
    self.assertEqual(conn.extra, 'testextra')
    self.assertEqual(Fernet(new_key).decrypt(conn._extra.encode()), plaintext)
def test_serialize(self, session):
    """Dumping a stored connection through the full schema omits the password but keeps extra."""
    expected = {
        'connection_id': "mysql_default",
        'conn_type': 'mysql',
        'host': 'mysql',
        'login': '******',
        'schema': 'testschema',
        'port': 80,
        'extra': "{'key':'string'}"
    }

    session.add(
        Connection(
            conn_id='mysql_default',
            conn_type='mysql',
            host='mysql',
            login='******',
            schema='testschema',
            port=80,
            password='******',
            extra="{'key':'string'}",
        )
    )
    session.commit()

    # Re-read the row so the schema is exercised on a persisted model.
    stored = session.query(Connection).first()
    dumped = connection_schema.dump(stored)
    self.assertEqual(dumped[0], expected)
def create_redshift_connection(dwh):
    """
    Build the ``redshift`` Airflow connection.

    Parameters
    ----------
    dwh: dict
        Configuration dictionary; the ``host``, ``db_name``, ``db_user``,
        ``db_pass`` and ``db_port`` entries of ``dwh['redshift']`` are read.

    Returns
    -------
    conn: airflow.models.Connection
        Connection of type ``postgresql`` populated from those entries.
    """
    host = dwh['redshift']['host']
    db_name = dwh['redshift']['db_name']
    db_user = dwh['redshift']['db_user']
    db_pass = dwh['redshift']['db_pass']
    db_port = dwh['redshift']['db_port']
    return Connection(
        conn_id="redshift",
        conn_type="postgresql",
        host=host,
        schema=db_name,
        login=db_user,
        password=db_pass,
        port=db_port,
    )
def test_connection_extra_with_encryption_rotate_fernet_key(self):
    """
    Rotating the Fernet key re-encrypts a connection's extra field.
    """
    old_key = Fernet.generate_key()
    new_key = Fernet.generate_key()
    fernet_option = ('core', 'fernet_key')

    # Encrypt with the old key only.
    with conf_vars({fernet_option: old_key.decode()}):
        conn = Connection(extra='testextra')
        assert conn.is_extra_encrypted
        assert conn.extra == 'testextra'
        assert Fernet(old_key).decrypt(conn._extra.encode()) == b'testextra'

    # Test decrypt of old value with new key
    key_chain = ','.join([new_key.decode(), old_key.decode()])
    with conf_vars({fernet_option: key_chain}):
        crypto._fernet = None  # drop the cached Fernet object
        assert conn.extra == 'testextra'

        # Test decrypt of new value with new key
        conn.rotate_fernet_key()
        assert conn.is_extra_encrypted
        assert conn.extra == 'testextra'
        assert Fernet(new_key).decrypt(conn._extra.encode()) == b'testextra'
def test_run_example_gcp_vision_autogenerated_id_dag(self):
    """
    List buckets through an AWS connection that assumes a role via web identity.

    The connection is supplied through the ``AIRFLOW_CONN_AWS_DEFAULT``
    environment variable and configured for Google identity federation.
    """
    connection_extra = {
        "role_arn": ROLE_ANR,
        "assume_role_method": "assume_role_with_web_identity",
        "assume_role_with_web_identity_federation": 'google',
        "assume_role_with_web_identity_federation_audience": AUDIENCE,
    }
    mock_connection = Connection(conn_type="aws", extra=json.dumps(connection_extra))

    with mock.patch.dict('os.environ', AIRFLOW_CONN_AWS_DEFAULT=mock_connection.get_uri()):
        client = AwsBaseHook(client_type='s3').get_conn()
        response = client.list_buckets()
        assert 'Buckets' in response
def provide_facebook_connection(key_file_path: str):
    """
    Context manager that temporarily provides a Facebook connection.

    The JSON credentials file is loaded, checked against
    ``CONFIG_REQUIRED_FIELDS`` and stored as the ``extra`` of a new
    connection, whose URI is published as ``AIRFLOW_CONN_<CONN_ID>`` for
    the duration of the ``yield``.

    :param key_file_path: Path to file with FACEBOOK credentials .json file.
    :type key_file_path: str
    :raises AirflowException: if the path does not end with ``.json`` or
        required fields are missing from the file
    """
    if not key_file_path.endswith(".json"):
        raise AirflowException("Use a JSON key file.")

    with open(key_file_path) as key_file:
        creds = json.load(key_file)

    missing_keys = CONFIG_REQUIRED_FIELDS - creds.keys()
    if missing_keys:
        raise AirflowException(f"{missing_keys} fields are missing")

    conn = Connection(
        conn_id=FACEBOOK_CONNECTION_ID,
        conn_type=CONNECTION_TYPE,
        extra=json.dumps(creds),
    )
    env_overrides = {f"AIRFLOW_CONN_{conn.conn_id.upper()}": conn.get_uri()}
    with patch_environ(env_overrides):
        yield
def test_connection_extra_with_encryption_rotate_fernet_key(self):
    """
    Rotating the Fernet key re-encrypts a connection's extra field.
    """
    old_key = Fernet.generate_key()
    new_key = Fernet.generate_key()
    fernet_option = ('core', 'FERNET_KEY')
    plaintext = b'testextra'

    # Encrypt with the old key only.
    with conf_vars({fernet_option: old_key.decode()}):
        conn = Connection(extra='testextra')
        self.assertTrue(conn.is_extra_encrypted)
        self.assertEqual(conn.extra, 'testextra')
        self.assertEqual(Fernet(old_key).decrypt(conn._extra.encode()), plaintext)

    # Test decrypt of old value with new key
    key_chain = ','.join([new_key.decode(), old_key.decode()])
    with conf_vars({fernet_option: key_chain}):
        crypto._fernet = None  # drop the cached Fernet object
        self.assertEqual(conn.extra, 'testextra')

        # Test decrypt of new value with new key
        conn.rotate_fernet_key()
        self.assertTrue(conn.is_extra_encrypted)
        self.assertEqual(conn.extra, 'testextra')
        self.assertEqual(Fernet(new_key).decrypt(conn._extra.encode()), plaintext)
def test_create_operator_with_correct_parameters_mysql_proxy_socket(self, get_connections):
    """
    MySQL Cloud SQL connection using the proxy over a UNIX socket.

    The connection produced by the database hook must be of type
    ``mysql``, point at ``localhost``, carry a ``unix_socket`` extra under
    ``/tmp`` naming the instance, have no port and keep the ``testdb``
    schema.
    """
    cloudsql_uri = (
        "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type=mysql&"
        "project_id=example-project&location=europe-west1&instance=testdb&"
        "use_proxy=True&sql_proxy_use_tcp=False"
    )
    source_connection = Connection()
    source_connection.parse_from_uri(cloudsql_uri)
    get_connections.return_value = [source_connection]

    operator = CloudSqlQueryOperator(sql=['SELECT * FROM TABLE'], task_id='task_id')
    operator.cloudsql_db_hook.create_connection()
    try:
        database_hook = operator.cloudsql_db_hook.get_database_hook()
        resolved = database_hook._get_connections_from_db(database_hook.mysql_conn_id)
        conn = resolved[0]
    finally:
        # Always drop the temporary connection, even when the lookup fails.
        operator.cloudsql_db_hook.delete_connection()

    self.assertEqual('mysql', conn.conn_type)
    self.assertEqual('localhost', conn.host)
    self.assertIn('/tmp', conn.extra_dejson['unix_socket'])
    self.assertIn('example-project:europe-west1:testdb', conn.extra_dejson['unix_socket'])
    self.assertIsNone(conn.port)
    self.assertEqual('testdb', conn.schema)
def test_should_response_200(self, session):
    """GET on an existing connection returns 200 and the connection's public fields."""
    expected_payload = {
        "connection_id": "test-connection-id",
        "conn_type": 'mysql',
        "host": 'mysql',
        "login": '******',
        'schema': 'testschema',
        'port': 80
    }

    session.add(
        Connection(
            conn_id='test-connection-id',
            conn_type='mysql',
            host='mysql',
            login='******',
            schema='testschema',
            port=80,
        )
    )
    session.commit()

    # Exactly one connection must be stored before the endpoint is called.
    stored = session.query(Connection).all()
    assert len(stored) == 1

    response = self.client.get("/api/v1/connections/test-connection-id")
    assert response.status_code == 200
    self.assertEqual(response.json, expected_payload)
def create_postgres_connection(conn_id, host, schema, login, password, port=5432, **kwargs):
    """
    Creates a Postgres Connection (for Hooks to use) for the Airflow session

    :param conn_id: Str - name of Airflow Connection
    :param host: Str - db host
    :param schema: Str - db name
    :param login: Str - db login
    :param password: Str - db password
    :param port: Int - db port (defaults to 5432)
    :param kwargs: Dict - keyword arguments (accepted but not used)
    :return: None
    """
    logging.info('Creating Postgres Connection...')

    # Create connection to our Postgres instance
    pg_connection = Connection(conn_id=conn_id,
                               conn_type='postgres',
                               host=host,
                               schema=schema,
                               login=login,
                               password=password,
                               port=port)

    # Add the Connection to the Airflow session
    session = settings.Session()
    try:
        session.add(pg_connection)
        session.commit()
    except Exception:
        # Leave the session clean before propagating the failure.
        session.rollback()
        raise
    else:
        logging.info('Successfully created Postgres Connection')
    finally:
        # Close the session on both the success and the failure path.
        session.close()
def test_connection() -> APIResponse:
    """
    Test the connection described by the request body.

    A connection with a random ID is built in memory and its URI exported
    to an environment variable first, because some hook classes look the
    connection up in ``__init__`` and fail when it is missing. The
    environment variable is removed again afterwards.
    """
    body = request.json
    dummy_conn_id = get_random_string()
    conn_env_var = f'{CONN_ENV_PREFIX}{dummy_conn_id.upper()}'
    try:
        conn_kwargs = connection_schema.load(body)
        conn_kwargs['conn_id'] = dummy_conn_id
        conn = Connection(**conn_kwargs)
        os.environ[conn_env_var] = conn.get_uri()
        status, message = conn.test_connection()
        result = {"status": status, "message": message}
        return connection_test_schema.dump(result)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))
    finally:
        # Remove the temporary variable if it was set.
        if conn_env_var in os.environ:
            del os.environ[conn_env_var]
def add_docker_connection(ds, **kwargs):
    """
    Add an Airflow connection for Google Container Registry.

    The contents of ``service_account.json`` (newlines removed) are stored
    as the connection password. The connection is inserted only when no
    connection with the same ``conn_id`` exists; otherwise a notice is
    printed.

    :param ds: unused
    :param kwargs: unused
    """
    new_conn = Connection(
        conn_id="gcr_docker_connection",  # TODO: parameterize
        conn_type="docker",
        host="gcr.io/wam-bam-258119",  # TODO: parameterize
        login="******",  # TODO: parameterize
    )

    # save contents of service account key into encrypted password field
    with open("service_account.json", "r") as file:
        data = file.read().replace("\n", "")  # replace new lines
    new_conn.set_password(data)

    session = settings.Session()
    try:
        existing = (
            session.query(Connection)
            .filter(Connection.conn_id == new_conn.conn_id)
            .first()
        )
        if not existing:
            session.add(new_conn)
            session.commit()
        else:
            msg = "\n\tA connection with `conn_id`={conn_id} already exists\n"
            msg = msg.format(conn_id=new_conn.conn_id)
            print(msg)
    finally:
        # Close the session whether or not the insert happened or failed.
        session.close()
def create_default_connections(session=None):
    """
    Create default Airflow connections.

    Every default connection is passed to ``merge_conn`` together with
    ``session`` so that all of them go through the same session.

    :param session: session forwarded to each ``merge_conn`` call
    """
    merge_conn(
        Connection(
            conn_id="airflow_db",
            conn_type="mysql",
            host="mysql",
            login="******",
            password="",
            schema="airflow",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="aws_default",
            conn_type="aws",
        ),
        session,
    )
    # This call previously omitted ``session``, unlike every other default.
    merge_conn(
        Connection(
            conn_id="azure_batch_default",
            conn_type="azure_batch",
            login="******",
            password="",
            extra='''{"account_url": "<ACCOUNT_URL>"}''',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="azure_container_instances_default",
            conn_type="azure_container_instances",
            extra='{"tenantId": "<TENANT>", "subscriptionId": "<SUBSCRIPTION ID>" }',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="azure_cosmos_default",
            conn_type="azure_cosmos",
            extra='{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id='azure_data_explorer_default',
            conn_type='azure_data_explorer',
            host='https://<CLUSTER>.kusto.windows.net',
            extra='''{"auth_method": "<AAD_APP | AAD_APP_CERT | AAD_CREDS | AAD_DEVICE>",
                    "tenant": "<TENANT ID>", "certificate": "<APPLICATION PEM CERTIFICATE>",
                    "thumbprint": "<APPLICATION CERTIFICATE THUMBPRINT>"}''',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="azure_data_lake_default",
            conn_type="azure_data_lake",
            extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="cassandra_default",
            conn_type="cassandra",
            host="cassandra",
            port=9042,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="databricks_default",
            conn_type="databricks",
            host="localhost",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="dingding_default",
            conn_type="http",
            host="",
            password="",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="druid_broker_default",
            conn_type="druid",
            host="druid-broker",
            port=8082,
            extra='{"endpoint": "druid/v2/sql"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="druid_ingest_default",
            conn_type="druid",
            host="druid-overlord",
            port=8081,
            extra='{"endpoint": "druid/indexer/v1/task"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="elasticsearch_default",
            conn_type="elasticsearch",
            host="localhost",
            schema="http",
            port=9200,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="emr_default",
            conn_type="emr",
            extra="""
                {
                    "Name": "default_job_flow_name",
                    "LogUri": "s3://my-emr-log-bucket/default_job_flow_location",
                    "ReleaseLabel": "emr-4.6.0",
                    "Instances": {
                        "Ec2KeyName": "mykey",
                        "Ec2SubnetId": "somesubnet",
                        "InstanceGroups": [
                            {
                                "Name": "Master nodes",
                                "Market": "ON_DEMAND",
                                "InstanceRole": "MASTER",
                                "InstanceType": "r3.2xlarge",
                                "InstanceCount": 1
                            },
                            {
                                "Name": "Slave nodes",
                                "Market": "ON_DEMAND",
                                "InstanceRole": "CORE",
                                "InstanceType": "r3.2xlarge",
                                "InstanceCount": 1
                            }
                        ],
                        "TerminationProtected": false,
                        "KeepJobFlowAliveWhenNoSteps": false
                    },
                    "Applications":[
                        { "Name": "Spark" }
                    ],
                    "VisibleToAllUsers": true,
                    "JobFlowRole": "EMR_EC2_DefaultRole",
                    "ServiceRole": "EMR_DefaultRole",
                    "Tags": [
                        {
                            "Key": "app",
                            "Value": "analytics"
                        },
                        {
                            "Key": "environment",
                            "Value": "development"
                        }
                    ]
                }
            """,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="facebook_default",
            conn_type="facebook_social",
            extra="""
                {
                    "account_id": "<AD_ACCOUNT_ID>",
                    "app_id": "<FACEBOOK_APP_ID>",
                    "app_secret": "<FACEBOOK_APP_SECRET>",
                    "access_token": "<FACEBOOK_AD_ACCESS_TOKEN>"
                }
            """,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="fs_default",
            conn_type="fs",
            extra='{"path": "/"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="google_cloud_default",
            conn_type="google_cloud_platform",
            schema="default",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="hive_cli_default",
            conn_type="hive_cli",
            port=10000,
            host="localhost",
            extra='{"use_beeline": true, "auth": ""}',
            schema="default",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="hiveserver2_default",
            conn_type="hiveserver2",
            host="localhost",
            schema="default",
            port=10000,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="http_default",
            conn_type="http",
            host="https://www.httpbin.org/",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id='kubernetes_default',
            conn_type='kubernetes',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id='kylin_default',
            conn_type='kylin',
            host='localhost',
            port=7070,
            login="******",
            password="******",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="leveldb_default",
            conn_type="leveldb",
            host="localhost",
        ),
        session,
    )
    merge_conn(Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998), session)
    merge_conn(
        Connection(
            conn_id="local_mysql",
            conn_type="mysql",
            host="localhost",
            login="******",
            password="******",
            schema="airflow",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="metastore_default",
            conn_type="hive_metastore",
            host="localhost",
            extra='{"authMechanism": "PLAIN"}',
            port=9083,
        ),
        session,
    )
    merge_conn(Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017), session)
    merge_conn(
        Connection(
            conn_id="mssql_default",
            conn_type="mssql",
            host="localhost",
            port=1433,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="mysql_default",
            conn_type="mysql",
            login="******",
            schema="airflow",
            host="mysql",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="opsgenie_default",
            conn_type="http",
            host="",
            password="",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="pig_cli_default",
            conn_type="pig_cli",
            schema="default",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="pinot_admin_default",
            conn_type="pinot",
            host="localhost",
            port=9000,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="pinot_broker_default",
            conn_type="pinot",
            host="localhost",
            port=9000,
            extra='{"endpoint": "/query", "schema": "http"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="postgres_default",
            conn_type="postgres",
            login="******",
            password="******",
            schema="airflow",
            host="postgres",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="presto_default",
            conn_type="presto",
            host="localhost",
            schema="hive",
            port=3400,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="qubole_default",
            conn_type="qubole",
            host="localhost",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="redis_default",
            conn_type="redis",
            host="redis",
            port=6379,
            extra='{"db": 0}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="segment_default",
            conn_type="segment",
            extra='{"write_key": "my-segment-write-key"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="sftp_default",
            conn_type="sftp",
            host="localhost",
            port=22,
            login="******",
            extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="spark_default",
            conn_type="spark",
            host="yarn",
            extra='{"queue": "root.default"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="sqlite_default",
            conn_type="sqlite",
            host="/tmp/sqlite_default.db",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="sqoop_default",
            conn_type="sqoop",
            host="rdbms",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="ssh_default",
            conn_type="ssh",
            host="localhost",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="tableau_default",
            conn_type="tableau",
            host="https://tableau.server.url",
            login="******",
            password="******",
            extra='{"site_id": "my_site"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="vertica_default",
            conn_type="vertica",
            host="localhost",
            port=5433,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="wasb_default",
            conn_type="wasb",
            extra='{"sas_token": null}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="webhdfs_default",
            conn_type="hdfs",
            host="localhost",
            port=50070,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id='yandexcloud_default',
            conn_type='yandexcloud',
            schema='default',
        ),
        session,
    )
# software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import mock from parameterized import parameterized from airflow.models import Connection from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook from airflow.providers.microsoft.mssql.operators.mssql import MsSqlOperator from airflow.providers.odbc.hooks.odbc import OdbcHook ODBC_CONN = Connection( conn_id='test-odbc', conn_type='odbc', ) PYMSSQL_CONN = Connection( conn_id='test-pymssql', conn_type='anything', ) class TestMsSqlOperator: @parameterized.expand([(ODBC_CONN, OdbcHook), (PYMSSQL_CONN, MsSqlHook)]) @mock.patch('airflow.hooks.base_hook.BaseHook.get_connection') def test_get_hook(self, conn, hook_class, get_connection): """ Operator should use odbc hook if conn type is ``odbc`` and pymssql-based hook otherwise. """
def _get_connection_from_env(cls, conn_id):
    """
    Build a connection from the environment, if one is defined there.

    The variable named ``CONN_ENV_PREFIX`` followed by the upper-cased
    ``conn_id`` is read as a connection URI.

    :param conn_id: ID of the connection to look up
    :return: a ``Connection`` built from the URI, or ``None`` when the
        variable is unset or empty
    """
    env_var_name = CONN_ENV_PREFIX + conn_id.upper()
    environment_uri = os.environ.get(env_var_name)
    if not environment_uri:
        return None
    return Connection(conn_id=conn_id, uri=environment_uri)