def test_create_operator_with_wrong_parameters(self,
                                               project_id,
                                               location,
                                               instance_name,
                                               database_type,
                                               use_proxy,
                                               use_ssl,
                                               sql,
                                               message,
                                               get_connections):
    connection = Connection()
    connection.parse_from_uri(
        "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type={database_type}&"
        "project_id={project_id}&location={location}&instance={instance_name}&"
        "use_proxy={use_proxy}&use_ssl={use_ssl}".
        format(database_type=database_type,
               project_id=project_id,
               location=location,
               instance_name=instance_name,
               use_proxy=use_proxy,
               use_ssl=use_ssl))
    get_connections.return_value = [connection]
    with self.assertRaises(AirflowException) as cm:
        CloudSqlQueryOperator(
            sql=sql,
            task_id='task_id'
        )
    err = cm.exception
    self.assertIn(message, str(err))
def test_create_operator_with_correct_parameters_mysql_ssl(self, get_connections):
    connection = Connection()
    connection.parse_from_uri(
        "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type=mysql&"
        "project_id=example-project&location=europe-west1&instance=testdb&"
        "use_proxy=False&use_ssl=True&sslcert=/bin/bash&"
        "sslkey=/bin/bash&sslrootcert=/bin/bash")
    get_connections.return_value = [connection]
    operator = CloudSqlQueryOperator(
        sql=['SELECT * FROM TABLE'],
        task_id='task_id'
    )
    operator.cloudsql_db_hook.create_connection()
    try:
        db_hook = operator.cloudsql_db_hook.get_database_hook()
        conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0]
    finally:
        operator.cloudsql_db_hook.delete_connection()
    self.assertEqual('mysql', conn.conn_type)
    self.assertEqual('8.8.8.8', conn.host)
    self.assertEqual(3200, conn.port)
    self.assertEqual('testdb', conn.schema)
    self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['cert'])
    self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['key'])
    self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['ca'])
def test_create_operator_with_correct_parameters_mysql_proxy_socket(self, get_connections):
    connection = Connection()
    connection.parse_from_uri(
        "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type=mysql&"
        "project_id=example-project&location=europe-west1&instance=testdb&"
        "use_proxy=True&sql_proxy_use_tcp=False")
    get_connections.return_value = [connection]
    operator = CloudSqlQueryOperator(
        sql=['SELECT * FROM TABLE'],
        task_id='task_id'
    )
    operator.cloudsql_db_hook.create_connection()
    try:
        db_hook = operator.cloudsql_db_hook.get_database_hook()
        conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0]
    finally:
        operator.cloudsql_db_hook.delete_connection()
    self.assertEqual('mysql', conn.conn_type)
    self.assertEqual('localhost', conn.host)
    self.assertIn('/tmp', conn.extra_dejson['unix_socket'])
    self.assertIn('example-project:europe-west1:testdb',
                  conn.extra_dejson['unix_socket'])
    self.assertIsNone(conn.port)
    self.assertEqual('testdb', conn.schema)
@provide_session
def create_connection(self, session=None):
    """
    Create a connection in the Connection table, according to whether it uses
    proxy, TCP, UNIX sockets or SSL. The connection ID is randomly generated.

    :param session: Session of the SQL Alchemy ORM (automatically provided by
        the ``provide_session`` decorator).
    """
    connection = Connection(conn_id=self.db_conn_id)
    uri = self._generate_connection_uri()
    self.log.info("Creating connection %s", self.db_conn_id)
    connection.parse_from_uri(uri)
    session.add(connection)
    session.commit()
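# The URI built by _generate_connection_uri() uses the same "gcpcloudsql://"
# scheme exercised by the tests above. A minimal sketch of how such a URI
# round-trips through Connection.parse_from_uri(); the host, credentials and
# query parameters below are made up purely for illustration.
from airflow.models import Connection

example_uri = (
    "gcpcloudsql://user:password@8.8.8.8:3200/testdb?"
    "database_type=postgres&project_id=example-project&"
    "location=europe-west1&instance=testdb&use_proxy=False&use_ssl=False"
)

example_conn = Connection(conn_id="example_cloudsql")
example_conn.parse_from_uri(example_uri)

print(example_conn.host)          # 8.8.8.8
print(example_conn.port)          # 3200
print(example_conn.schema)        # testdb
print(example_conn.extra_dejson)  # query params: database_type, project_id, ...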
def test_cloudsql_hook_delete_connection_on_exception(
        self, get_connections, run, get_connection, delete_connection):
    connection = Connection()
    connection.parse_from_uri(
        "gcpcloudsql://*****:*****@127.0.0.1:3200/testdb?database_type=mysql&"
        "project_id=example-project&location=europe-west1&instance=testdb&"
        "use_proxy=False")
    get_connection.return_value = connection

    db_connection = Connection()
    db_connection.host = "127.0.0.1"
    db_connection.set_extra(json.dumps({"project_id": "example-project",
                                        "location": "europe-west1",
                                        "instance": "testdb",
                                        "database_type": "mysql"}))
    get_connections.return_value = [db_connection]
    run.side_effect = Exception("Exception when running a query")
    operator = CloudSqlQueryOperator(
        sql=['SELECT * FROM TABLE'],
        task_id='task_id'
    )
    with self.assertRaises(Exception) as cm:
        operator.execute(None)
    err = cm.exception
    self.assertEqual("Exception when running a query", str(err))
    delete_connection.assert_called_once_with()
def registered(self, driver, frameworkId, masterInfo):
    logging.info("AirflowScheduler registered to mesos with framework ID %s",
                 frameworkId.value)

    if configuration.getboolean('mesos', 'CHECKPOINT') and \
            configuration.get('mesos', 'FAILOVER_TIMEOUT'):
        # Import here to work around a circular import error
        from airflow.models import Connection

        # Update the Framework ID in the database.
        session = Session()
        conn_id = FRAMEWORK_CONNID_PREFIX + get_framework_name()
        connection = session.query(Connection).filter_by(conn_id=conn_id).first()
        if connection is None:
            connection = Connection(conn_id=conn_id,
                                    conn_type='mesos_framework-id',
                                    extra=frameworkId.value)
        else:
            connection.extra = frameworkId.value

        session.add(connection)
        session.commit()
        Session.remove()
def test_create_operator_with_correct_parameters_postgres(self, get_connections):
    connection = Connection()
    connection.parse_from_uri(
        "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type=postgres&"
        "project_id=example-project&location=europe-west1&instance=testdb&"
        "use_proxy=False&use_ssl=False")
    get_connections.return_value = [connection]
    operator = CloudSqlQueryOperator(
        sql=['SELECT * FROM TABLE'],
        task_id='task_id'
    )
    operator.cloudsql_db_hook.create_connection()
    try:
        db_hook = operator.cloudsql_db_hook.get_database_hook()
        conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0]
    finally:
        operator.cloudsql_db_hook.delete_connection()
    self.assertEqual('postgres', conn.conn_type)
    self.assertEqual('8.8.8.8', conn.host)
    self.assertEqual(3200, conn.port)
    self.assertEqual('testdb', conn.schema)
def test_connection_extra_with_encryption_rotate_fernet_key(self, mock_get):
    """Tests rotating encrypted extras."""
    key1 = Fernet.generate_key()
    key2 = Fernet.generate_key()

    mock_get.return_value = key1.decode()
    test_connection = Connection(extra='testextra')
    self.assertTrue(test_connection.is_extra_encrypted)
    self.assertEqual(test_connection.extra, 'testextra')
    self.assertEqual(Fernet(key1).decrypt(test_connection._extra.encode()), b'testextra')

    # Test decrypt of old value with new key
    mock_get.return_value = ','.join([key2.decode(), key1.decode()])
    crypto._fernet = None
    self.assertEqual(test_connection.extra, 'testextra')

    # Test decrypt of new value with new key
    test_connection.rotate_fernet_key()
    self.assertTrue(test_connection.is_extra_encrypted)
    self.assertEqual(test_connection.extra, 'testextra')
    self.assertEqual(Fernet(key2).decrypt(test_connection._extra.encode()), b'testextra')
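# The rotation test above relies on the standard multi-key Fernet scheme: the
# newest key is listed first and used for encryption, while older keys stay in
# the list so existing ciphertext can still be decrypted and re-encrypted.
# A minimal sketch of that mechanism using cryptography's MultiFernet directly,
# outside of Airflow.
from cryptography.fernet import Fernet, MultiFernet

old_key = Fernet.generate_key()   # key the original ciphertext was written with
new_key = Fernet.generate_key()   # key that should be used going forward

token = Fernet(old_key).encrypt(b"testextra")

# New key first, old key second: old tokens remain readable.
multi = MultiFernet([Fernet(new_key), Fernet(old_key)])
assert multi.decrypt(token) == b"testextra"

# rotate() re-encrypts the token with the first (newest) key.
rotated = multi.rotate(token)
assert Fernet(new_key).decrypt(rotated) == b"testextra"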
def _setup_connections(get_connections, uri):
    gcp_connection = mock.MagicMock()
    gcp_connection.extra_dejson = mock.MagicMock()
    gcp_connection.extra_dejson.get.return_value = 'empty_project'
    cloudsql_connection = Connection()
    cloudsql_connection.parse_from_uri(uri)
    cloudsql_connection2 = Connection()
    cloudsql_connection2.parse_from_uri(uri)
    get_connections.side_effect = [[gcp_connection],
                                   [cloudsql_connection],
                                   [cloudsql_connection2]]
def setUp(self):
    db.merge_conn(Connection(conn_id=DEFAULT_CONN, conn_type='HTTP'))
    db.merge_conn(
        Connection(conn_id=TEST_CONN, conn_type='HTTP', host='http://localhost/api'))
class TestAwsS3Hook: @mock_s3 def test_get_conn(self): hook = S3Hook() assert hook.get_conn() is not None @mock_s3 def test_use_threads_default_value(self): hook = S3Hook() assert hook.transfer_config.use_threads is True @mock_s3 def test_use_threads_set_value(self): hook = S3Hook(transfer_config_args={"use_threads": False}) assert hook.transfer_config.use_threads is False def test_parse_s3_url(self): parsed = S3Hook.parse_s3_url("s3://test/this/is/not/a-real-key.txt") assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url" def test_check_for_bucket(self, s3_bucket): hook = S3Hook() assert hook.check_for_bucket(s3_bucket) is True assert hook.check_for_bucket('not-a-bucket') is False def test_check_for_bucket_raises_error_with_invalid_conn_id(self, s3_bucket, monkeypatch): monkeypatch.delenv('AWS_PROFILE', raising=False) monkeypatch.delenv('AWS_ACCESS_KEY_ID', raising=False) monkeypatch.delenv('AWS_SECRET_ACCESS_KEY', raising=False) hook = S3Hook(aws_conn_id="does_not_exist") with pytest.raises(NoCredentialsError): hook.check_for_bucket(s3_bucket) @mock_s3 def test_get_bucket(self): hook = S3Hook() assert hook.get_bucket('bucket') is not None @mock_s3 def test_create_bucket_default_region(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') assert hook.get_bucket('new_bucket') is not None @mock_s3 def test_create_bucket_us_standard_region(self, monkeypatch): monkeypatch.delenv('AWS_DEFAULT_REGION', raising=False) hook = S3Hook() hook.create_bucket(bucket_name='new_bucket', region_name='us-east-1') bucket = hook.get_bucket('new_bucket') assert bucket is not None region = bucket.meta.client.get_bucket_location(Bucket=bucket.name).get('LocationConstraint') # https://github.com/spulec/moto/pull/1961 # If location is "us-east-1", LocationConstraint should be None assert region is None @mock_s3 def test_create_bucket_other_region(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket', region_name='us-east-2') bucket = hook.get_bucket('new_bucket') assert bucket is not None region = bucket.meta.client.get_bucket_location(Bucket=bucket.name).get('LocationConstraint') assert region == 'us-east-2' def test_check_for_prefix(self, s3_bucket): hook = S3Hook() bucket = hook.get_bucket(s3_bucket) bucket.put_object(Key='a', Body=b'a') bucket.put_object(Key='dir/b', Body=b'b') assert hook.check_for_prefix(bucket_name=s3_bucket, prefix='dir/', delimiter='/') is True assert hook.check_for_prefix(bucket_name=s3_bucket, prefix='a', delimiter='/') is False def test_list_prefixes(self, s3_bucket): hook = S3Hook() bucket = hook.get_bucket(s3_bucket) bucket.put_object(Key='a', Body=b'a') bucket.put_object(Key='dir/b', Body=b'b') assert [] == hook.list_prefixes(s3_bucket, prefix='non-existent/') assert ['dir/'] == hook.list_prefixes(s3_bucket, delimiter='/') assert ['a'] == hook.list_keys(s3_bucket, delimiter='/') assert ['dir/b'] == hook.list_keys(s3_bucket, prefix='dir/') def test_list_prefixes_paged(self, s3_bucket): hook = S3Hook() bucket = hook.get_bucket(s3_bucket) # we don't need to test the paginator that's covered by boto tests keys = [f"{i}/b" for i in range(2)] dirs = [f"{i}/" for i in range(2)] for key in keys: bucket.put_object(Key=key, Body=b'a') assert sorted(dirs) == sorted(hook.list_prefixes(s3_bucket, delimiter='/', page_size=1)) def test_list_keys(self, s3_bucket): hook = S3Hook() bucket = hook.get_bucket(s3_bucket) bucket.put_object(Key='a', Body=b'a') bucket.put_object(Key='dir/b', Body=b'b') assert [] == hook.list_keys(s3_bucket, 
prefix='non-existent/') assert ['a', 'dir/b'] == hook.list_keys(s3_bucket) assert ['a'] == hook.list_keys(s3_bucket, delimiter='/') assert ['dir/b'] == hook.list_keys(s3_bucket, prefix='dir/') def test_list_keys_paged(self, s3_bucket): hook = S3Hook() bucket = hook.get_bucket(s3_bucket) keys = [str(i) for i in range(2)] for key in keys: bucket.put_object(Key=key, Body=b'a') assert sorted(keys) == sorted(hook.list_keys(s3_bucket, delimiter='/', page_size=1)) def test_check_for_key(self, s3_bucket): hook = S3Hook() bucket = hook.get_bucket(s3_bucket) bucket.put_object(Key='a', Body=b'a') assert hook.check_for_key('a', s3_bucket) is True assert hook.check_for_key(f's3://{s3_bucket}//a') is True assert hook.check_for_key('b', s3_bucket) is False assert hook.check_for_key(f's3://{s3_bucket}//b') is False def test_check_for_key_raises_error_with_invalid_conn_id(self, monkeypatch, s3_bucket): monkeypatch.delenv('AWS_PROFILE', raising=False) monkeypatch.delenv('AWS_ACCESS_KEY_ID', raising=False) monkeypatch.delenv('AWS_SECRET_ACCESS_KEY', raising=False) hook = S3Hook(aws_conn_id="does_not_exist") with pytest.raises(NoCredentialsError): hook.check_for_key('a', s3_bucket) def test_get_key(self, s3_bucket): hook = S3Hook() bucket = hook.get_bucket(s3_bucket) bucket.put_object(Key='a', Body=b'a') assert hook.get_key('a', s3_bucket).key == 'a' assert hook.get_key(f's3://{s3_bucket}/a').key == 'a' def test_read_key(self, s3_bucket): hook = S3Hook() bucket = hook.get_bucket(s3_bucket) bucket.put_object(Key='my_key', Body=b'Cont\xC3\xA9nt') assert hook.read_key('my_key', s3_bucket) == 'Contént' # As of 1.3.2, Moto doesn't support select_object_content yet. @mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_client_type') def test_select_key(self, mock_get_client_type, s3_bucket): mock_get_client_type.return_value.select_object_content.return_value = { 'Payload': [{'Records': {'Payload': b'Cont\xC3\xA9nt'}}] } hook = S3Hook() assert hook.select_key('my_key', s3_bucket) == 'Contént' def test_check_for_wildcard_key(self, s3_bucket): hook = S3Hook() bucket = hook.get_bucket(s3_bucket) bucket.put_object(Key='abc', Body=b'a') bucket.put_object(Key='a/b', Body=b'a') assert hook.check_for_wildcard_key('a*', s3_bucket) is True assert hook.check_for_wildcard_key('abc', s3_bucket) is True assert hook.check_for_wildcard_key(f's3://{s3_bucket}//a*') is True assert hook.check_for_wildcard_key(f's3://{s3_bucket}//abc') is True assert hook.check_for_wildcard_key('a', s3_bucket) is False assert hook.check_for_wildcard_key('b', s3_bucket) is False assert hook.check_for_wildcard_key(f's3://{s3_bucket}//a') is False assert hook.check_for_wildcard_key(f's3://{s3_bucket}//b') is False def test_get_wildcard_key(self, s3_bucket): hook = S3Hook() bucket = hook.get_bucket(s3_bucket) bucket.put_object(Key='abc', Body=b'a') bucket.put_object(Key='a/b', Body=b'a') # The boto3 Class API is _odd_, and we can't do an isinstance check as # each instance is a different class, so lets just check one property # on S3.Object. Not great but... 
assert hook.get_wildcard_key('a*', s3_bucket).key == 'a/b' assert hook.get_wildcard_key('a*', s3_bucket, delimiter='/').key == 'abc' assert hook.get_wildcard_key('abc', s3_bucket, delimiter='/').key == 'abc' assert hook.get_wildcard_key(f's3://{s3_bucket}/a*').key == 'a/b' assert hook.get_wildcard_key(f's3://{s3_bucket}/a*', delimiter='/').key == 'abc' assert hook.get_wildcard_key(f's3://{s3_bucket}/abc', delimiter='/').key == 'abc' assert hook.get_wildcard_key('a', s3_bucket) is None assert hook.get_wildcard_key('b', s3_bucket) is None assert hook.get_wildcard_key(f's3://{s3_bucket}/a') is None assert hook.get_wildcard_key(f's3://{s3_bucket}/b') is None def test_load_string(self, s3_bucket): hook = S3Hook() hook.load_string("Contént", "my_key", s3_bucket) resource = boto3.resource('s3').Object(s3_bucket, 'my_key') # pylint: disable=no-member assert resource.get()['Body'].read() == b'Cont\xC3\xA9nt' def test_load_string_compress(self, s3_bucket): hook = S3Hook() hook.load_string("Contént", "my_key", s3_bucket, compression='gzip') resource = boto3.resource('s3').Object(s3_bucket, 'my_key') # pylint: disable=no-member data = gz.decompress(resource.get()['Body'].read()) assert data == b'Cont\xC3\xA9nt' def test_load_string_compress_exception(self, s3_bucket): hook = S3Hook() with pytest.raises(NotImplementedError): hook.load_string("Contént", "my_key", s3_bucket, compression='bad-compression') def test_load_string_acl(self, s3_bucket): hook = S3Hook() hook.load_string("Contént", "my_key", s3_bucket, acl_policy='public-read') response = boto3.client('s3').get_object_acl(Bucket=s3_bucket, Key="my_key", RequestPayer='requester') assert (response['Grants'][1]['Permission'] == 'READ') and ( response['Grants'][0]['Permission'] == 'FULL_CONTROL' ) def test_load_bytes(self, s3_bucket): hook = S3Hook() hook.load_bytes(b"Content", "my_key", s3_bucket) resource = boto3.resource('s3').Object(s3_bucket, 'my_key') # pylint: disable=no-member assert resource.get()['Body'].read() == b'Content' def test_load_bytes_acl(self, s3_bucket): hook = S3Hook() hook.load_bytes(b"Content", "my_key", s3_bucket, acl_policy='public-read') response = boto3.client('s3').get_object_acl(Bucket=s3_bucket, Key="my_key", RequestPayer='requester') assert (response['Grants'][1]['Permission'] == 'READ') and ( response['Grants'][0]['Permission'] == 'FULL_CONTROL' ) def test_load_fileobj(self, s3_bucket): hook = S3Hook() with tempfile.TemporaryFile() as temp_file: temp_file.write(b"Content") temp_file.seek(0) hook.load_file_obj(temp_file, "my_key", s3_bucket) resource = boto3.resource('s3').Object(s3_bucket, 'my_key') # pylint: disable=no-member assert resource.get()['Body'].read() == b'Content' def test_load_fileobj_acl(self, s3_bucket): hook = S3Hook() with tempfile.TemporaryFile() as temp_file: temp_file.write(b"Content") temp_file.seek(0) hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy='public-read') response = boto3.client('s3').get_object_acl( Bucket=s3_bucket, Key="my_key", RequestPayer='requester' ) # pylint: disable=no-member # noqa: E501 # pylint: disable=C0301 assert (response['Grants'][1]['Permission'] == 'READ') and ( response['Grants'][0]['Permission'] == 'FULL_CONTROL' ) def test_load_file_gzip(self, s3_bucket): hook = S3Hook() with tempfile.NamedTemporaryFile(delete=False) as temp_file: temp_file.write(b"Content") temp_file.seek(0) hook.load_file(temp_file.name, "my_key", s3_bucket, gzip=True) resource = boto3.resource('s3').Object(s3_bucket, 'my_key') # pylint: disable=no-member assert 
gz.decompress(resource.get()['Body'].read()) == b'Content' os.unlink(temp_file.name) def test_load_file_acl(self, s3_bucket): hook = S3Hook() with tempfile.NamedTemporaryFile(delete=False) as temp_file: temp_file.write(b"Content") temp_file.seek(0) hook.load_file(temp_file.name, "my_key", s3_bucket, gzip=True, acl_policy='public-read') response = boto3.client('s3').get_object_acl( Bucket=s3_bucket, Key="my_key", RequestPayer='requester' ) # pylint: disable=no-member # noqa: E501 # pylint: disable=C0301 assert (response['Grants'][1]['Permission'] == 'READ') and ( response['Grants'][0]['Permission'] == 'FULL_CONTROL' ) os.unlink(temp_file.name) def test_copy_object_acl(self, s3_bucket): hook = S3Hook() with tempfile.NamedTemporaryFile() as temp_file: temp_file.write(b"Content") temp_file.seek(0) hook.load_file_obj(temp_file, "my_key", s3_bucket) hook.copy_object("my_key", "my_key", s3_bucket, s3_bucket) response = boto3.client('s3').get_object_acl( Bucket=s3_bucket, Key="my_key", RequestPayer='requester' ) # pylint: disable=no-member # noqa: E501 # pylint: disable=C0301 assert (response['Grants'][0]['Permission'] == 'FULL_CONTROL') and (len(response['Grants']) == 1) @mock_s3 def test_delete_bucket_if_bucket_exist(self, s3_bucket): # assert if the bucket is created mock_hook = S3Hook() mock_hook.create_bucket(bucket_name=s3_bucket) assert mock_hook.check_for_bucket(bucket_name=s3_bucket) mock_hook.delete_bucket(bucket_name=s3_bucket, force_delete=True) assert not mock_hook.check_for_bucket(s3_bucket) @mock_s3 def test_delete_bucket_if_not_bucket_exist(self, s3_bucket): # assert if exception is raised if bucket not present mock_hook = S3Hook() with pytest.raises(ClientError) as ctx: assert mock_hook.delete_bucket(bucket_name=s3_bucket, force_delete=True) assert ctx.value.response['Error']['Code'] == 'NoSuchBucket' @mock.patch.object(S3Hook, 'get_connection', return_value=Connection(schema='test_bucket')) def test_provide_bucket_name(self, mock_get_connection): class FakeS3Hook(S3Hook): @provide_bucket_name def test_function(self, bucket_name=None): return bucket_name fake_s3_hook = FakeS3Hook() test_bucket_name = fake_s3_hook.test_function() assert test_bucket_name == mock_get_connection.return_value.schema test_bucket_name = fake_s3_hook.test_function(bucket_name='bucket') assert test_bucket_name == 'bucket' def test_delete_objects_key_does_not_exist(self, s3_bucket): hook = S3Hook() with pytest.raises(AirflowException) as ctx: hook.delete_objects(bucket=s3_bucket, keys=['key-1']) assert isinstance(ctx.value, AirflowException) assert str(ctx.value) == "Errors when deleting: ['key-1']" def test_delete_objects_one_key(self, mocked_s3_res, s3_bucket): key = 'key-1' mocked_s3_res.Object(s3_bucket, key).put(Body=b'Data') hook = S3Hook() hook.delete_objects(bucket=s3_bucket, keys=[key]) assert [o.key for o in mocked_s3_res.Bucket(s3_bucket).objects.all()] == [] def test_delete_objects_many_keys(self, mocked_s3_res, s3_bucket): num_keys_to_remove = 1001 keys = [] for index in range(num_keys_to_remove): key = f'key-{index}' mocked_s3_res.Object(s3_bucket, key).put(Body=b'Data') keys.append(key) assert sum(1 for _ in mocked_s3_res.Bucket(s3_bucket).objects.all()) == num_keys_to_remove hook = S3Hook() hook.delete_objects(bucket=s3_bucket, keys=keys) assert [o.key for o in mocked_s3_res.Bucket(s3_bucket).objects.all()] == [] def test_unify_bucket_name_and_key(self): class FakeS3Hook(S3Hook): @unify_bucket_name_and_key def test_function_with_wildcard_key(self, wildcard_key, bucket_name=None): return 
bucket_name, wildcard_key @unify_bucket_name_and_key def test_function_with_key(self, key, bucket_name=None): return bucket_name, key @unify_bucket_name_and_key def test_function_with_test_key(self, test_key, bucket_name=None): return bucket_name, test_key fake_s3_hook = FakeS3Hook() test_bucket_name_with_wildcard_key = fake_s3_hook.test_function_with_wildcard_key('s3://foo/bar*.csv') assert ('foo', 'bar*.csv') == test_bucket_name_with_wildcard_key test_bucket_name_with_key = fake_s3_hook.test_function_with_key('s3://foo/bar.csv') assert ('foo', 'bar.csv') == test_bucket_name_with_key with pytest.raises(ValueError) as ctx: fake_s3_hook.test_function_with_test_key('s3://foo/bar.csv') assert isinstance(ctx.value, ValueError) @mock.patch('airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile') def test_download_file(self, mock_temp_file): mock_temp_file.return_value.__enter__ = Mock(return_value=mock_temp_file) s3_hook = S3Hook(aws_conn_id='s3_test') s3_hook.check_for_key = Mock(return_value=True) s3_obj = Mock() s3_obj.download_fileobj = Mock(return_value=None) s3_hook.get_key = Mock(return_value=s3_obj) key = 'test_key' bucket = 'test_bucket' s3_hook.download_file(key=key, bucket_name=bucket) s3_hook.check_for_key.assert_called_once_with(key, bucket) s3_hook.get_key.assert_called_once_with(key, bucket) s3_obj.download_fileobj.assert_called_once_with(mock_temp_file) def test_generate_presigned_url(self, s3_bucket): hook = S3Hook() presigned_url = hook.generate_presigned_url( client_method="get_object", params={'Bucket': s3_bucket, 'Key': "my_key"} ) url = presigned_url.split("?")[1] params = {x[0]: x[1] for x in [x.split("=") for x in url[0:].split("&")]} assert {"AWSAccessKeyId", "Signature", "Expires"}.issubset(set(params.keys())) def test_should_throw_error_if_extra_args_is_not_dict(self): with pytest.raises(ValueError): S3Hook(extra_args=1) def test_should_throw_error_if_extra_args_contains_unknown_arg(self, s3_bucket): hook = S3Hook(extra_args={"unknown_s3_args": "value"}) with tempfile.TemporaryFile() as temp_file: temp_file.write(b"Content") temp_file.seek(0) with pytest.raises(ValueError): hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy='public-read') def test_should_pass_extra_args(self, s3_bucket): hook = S3Hook(extra_args={"ContentLanguage": "value"}) with tempfile.TemporaryFile() as temp_file: temp_file.write(b"Content") temp_file.seek(0) hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy='public-read') resource = boto3.resource('s3').Object(s3_bucket, 'my_key') # pylint: disable=no-member assert resource.get()['ContentLanguage'] == "value" @mock_s3 def test_get_bucket_tagging_no_tags_raises_error(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') with pytest.raises(ClientError, match=r".*NoSuchTagSet.*"): hook.get_bucket_tagging(bucket_name='new_bucket') @mock_s3 def test_get_bucket_tagging_no_bucket_raises_error(self): hook = S3Hook() with pytest.raises(ClientError, match=r".*NoSuchBucket.*"): hook.get_bucket_tagging(bucket_name='new_bucket') @mock_s3 def test_put_bucket_tagging_with_valid_set(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') tag_set = [{'Key': 'Color', 'Value': 'Green'}] hook.put_bucket_tagging(bucket_name='new_bucket', tag_set=tag_set) assert hook.get_bucket_tagging(bucket_name='new_bucket') == tag_set @mock_s3 def test_put_bucket_tagging_with_pair(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') tag_set = [{'Key': 'Color', 'Value': 'Green'}] key = 'Color' value = 'Green' 
hook.put_bucket_tagging(bucket_name='new_bucket', key=key, value=value) assert hook.get_bucket_tagging(bucket_name='new_bucket') == tag_set @mock_s3 def test_put_bucket_tagging_with_pair_and_set(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') expected = [{'Key': 'Color', 'Value': 'Green'}, {'Key': 'Fruit', 'Value': 'Apple'}] tag_set = [{'Key': 'Color', 'Value': 'Green'}] key = 'Fruit' value = 'Apple' hook.put_bucket_tagging(bucket_name='new_bucket', tag_set=tag_set, key=key, value=value) result = hook.get_bucket_tagging(bucket_name='new_bucket') assert len(result) == 2 assert result == expected @mock_s3 def test_put_bucket_tagging_with_key_but_no_value_raises_error(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') key = 'Color' with pytest.raises(ValueError): hook.put_bucket_tagging(bucket_name='new_bucket', key=key) @mock_s3 def test_put_bucket_tagging_with_value_but_no_key_raises_error(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') value = 'Color' with pytest.raises(ValueError): hook.put_bucket_tagging(bucket_name='new_bucket', value=value) @mock_s3 def test_put_bucket_tagging_with_key_and_set_raises_error(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') tag_set = [{'Key': 'Color', 'Value': 'Green'}] key = 'Color' with pytest.raises(ValueError): hook.put_bucket_tagging(bucket_name='new_bucket', key=key, tag_set=tag_set) @mock_s3 def test_put_bucket_tagging_with_value_and_set_raises_error(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') tag_set = [{'Key': 'Color', 'Value': 'Green'}] value = 'Green' with pytest.raises(ValueError): hook.put_bucket_tagging(bucket_name='new_bucket', value=value, tag_set=tag_set) @mock_s3 def test_put_bucket_tagging_when_tags_exist_overwrites(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') initial_tag_set = [{'Key': 'Color', 'Value': 'Green'}] hook.put_bucket_tagging(bucket_name='new_bucket', tag_set=initial_tag_set) assert len(hook.get_bucket_tagging(bucket_name='new_bucket')) == 1 assert hook.get_bucket_tagging(bucket_name='new_bucket') == initial_tag_set new_tag_set = [{'Key': 'Fruit', 'Value': 'Apple'}] hook.put_bucket_tagging(bucket_name='new_bucket', tag_set=new_tag_set) result = hook.get_bucket_tagging(bucket_name='new_bucket') assert len(result) == 1 assert result == new_tag_set @mock_s3 def test_delete_bucket_tagging(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') tag_set = [{'Key': 'Color', 'Value': 'Green'}] hook.put_bucket_tagging(bucket_name='new_bucket', tag_set=tag_set) hook.get_bucket_tagging(bucket_name='new_bucket') hook.delete_bucket_tagging(bucket_name='new_bucket') with pytest.raises(ClientError, match=r".*NoSuchTagSet.*"): hook.get_bucket_tagging(bucket_name='new_bucket') @mock_s3 def test_delete_bucket_tagging_with_no_tags(self): hook = S3Hook() hook.create_bucket(bucket_name='new_bucket') hook.delete_bucket_tagging(bucket_name='new_bucket') with pytest.raises(ClientError, match=r".*NoSuchTagSet.*"): hook.get_bucket_tagging(bucket_name='new_bucket')
def test_connection_test_success(self):
    conn = Connection(conn_id='test_uri', conn_type='sqlite')
    res = conn.test_connection()
    assert res[0] is True
    assert res[1] == 'Connection successfully tested'
def get_airflow_connection_with_port(conn_id=None):
    return Connection(conn_id='http_default', conn_type='http', host='test.com', port=1234)
def create_default_connections(session=None): merge_conn( Connection(conn_id="airflow_db", conn_type="mysql", host="mysql", login="******", password="", schema="airflow"), session, ) merge_conn( Connection( conn_id="local_mysql", conn_type="mysql", host="localhost", login="******", password="******", schema="airflow", ), session, ) merge_conn( Connection(conn_id="presto_default", conn_type="presto", host="localhost", schema="hive", port=3400), session, ) merge_conn( Connection(conn_id="google_cloud_default", conn_type="google_cloud_platform", schema="default"), session, ) merge_conn( Connection( conn_id="hive_cli_default", conn_type="hive_cli", port=10000, host="localhost", extra='{"use_beeline": true, "auth": ""}', schema="default", ), session, ) merge_conn( Connection(conn_id="pig_cli_default", conn_type="pig_cli", schema="default"), session) merge_conn( Connection( conn_id="hiveserver2_default", conn_type="hiveserver2", host="localhost", schema="default", port=10000, ), session, ) merge_conn( Connection( conn_id="metastore_default", conn_type="hive_metastore", host="localhost", extra='{"authMechanism": "PLAIN"}', port=9083, ), session, ) merge_conn( Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017), session) merge_conn( Connection(conn_id="mysql_default", conn_type="mysql", login="******", schema="airflow", host="mysql"), session, ) merge_conn( Connection( conn_id="postgres_default", conn_type="postgres", login="******", password="******", schema="airflow", host="postgres", ), session, ) merge_conn( Connection(conn_id="sqlite_default", conn_type="sqlite", host="/tmp/sqlite_default.db"), session) merge_conn( Connection(conn_id="http_default", conn_type="http", host="https://www.httpbin.org/"), session) merge_conn( Connection(conn_id="mssql_default", conn_type="mssql", host="localhost", port=1433), session) merge_conn( Connection(conn_id="vertica_default", conn_type="vertica", host="localhost", port=5433), session) merge_conn( Connection(conn_id="wasb_default", conn_type="wasb", extra='{"sas_token": null}'), session) merge_conn( Connection(conn_id="webhdfs_default", conn_type="hdfs", host="localhost", port=50070), session) merge_conn( Connection(conn_id="ssh_default", conn_type="ssh", host="localhost"), session) merge_conn( Connection( conn_id="sftp_default", conn_type="sftp", host="localhost", port=22, login="******", extra=""" {"key_file": "~/.ssh/id_rsa", "no_host_key_check": true} """, ), session, ) merge_conn( Connection(conn_id="fs_default", conn_type="fs", extra='{"path": "/"}'), session) merge_conn(Connection(conn_id="aws_default", conn_type="aws"), session) merge_conn( Connection(conn_id="spark_default", conn_type="spark", host="yarn", extra='{"queue": "root.default"}'), session, ) merge_conn( Connection( conn_id="druid_broker_default", conn_type="druid", host="druid-broker", port=8082, extra='{"endpoint": "druid/v2/sql"}', ), session, ) merge_conn( Connection( conn_id="druid_ingest_default", conn_type="druid", host="druid-overlord", port=8081, extra='{"endpoint": "druid/indexer/v1/task"}', ), session, ) merge_conn( Connection(conn_id="redis_default", conn_type="redis", host="redis", port=6379, extra='{"db": 0}'), session, ) merge_conn( Connection(conn_id="sqoop_default", conn_type="sqoop", host="rmdbs", extra=""), session) merge_conn( Connection( conn_id="emr_default", conn_type="emr", extra=""" { "Name": "default_job_flow_name", "LogUri": "s3://my-emr-log-bucket/default_job_flow_location", "ReleaseLabel": "emr-4.6.0", "Instances": { "Ec2KeyName": 
"mykey", "Ec2SubnetId": "somesubnet", "InstanceGroups": [ { "Name": "Master nodes", "Market": "ON_DEMAND", "InstanceRole": "MASTER", "InstanceType": "r3.2xlarge", "InstanceCount": 1 }, { "Name": "Slave nodes", "Market": "ON_DEMAND", "InstanceRole": "CORE", "InstanceType": "r3.2xlarge", "InstanceCount": 1 } ], "TerminationProtected": false, "KeepJobFlowAliveWhenNoSteps": false }, "Applications":[ { "Name": "Spark" } ], "VisibleToAllUsers": true, "JobFlowRole": "EMR_EC2_DefaultRole", "ServiceRole": "EMR_DefaultRole", "Tags": [ { "Key": "app", "Value": "analytics" }, { "Key": "environment", "Value": "development" } ] } """, ), session, ) merge_conn( Connection(conn_id="databricks_default", conn_type="databricks", host="localhost"), session) merge_conn( Connection(conn_id="qubole_default", conn_type="qubole", host="localhost"), session) merge_conn( Connection(conn_id="segment_default", conn_type="segment", extra='{"write_key": "my-segment-write-key"}'), session, ), merge_conn( Connection( conn_id="azure_data_lake_default", conn_type="azure_data_lake", extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }', ), session, ) merge_conn( Connection( conn_id="azure_cosmos_default", conn_type="azure_cosmos", extra= '{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }', ), session, ) merge_conn( Connection( conn_id="azure_container_instances_default", conn_type="azure_container_instances", extra= '{"tenantId": "<TENANT>", "subscriptionId": "<SUBSCRIPTION ID>" }', ), session, ) merge_conn( Connection(conn_id="cassandra_default", conn_type="cassandra", host="cassandra", port=9042), session) merge_conn( Connection(conn_id="dingding_default", conn_type="http", host="", password=""), session) merge_conn( Connection(conn_id="opsgenie_default", conn_type="http", host="", password=""), session) merge_conn( Connection(conn_id="pinot_admin_default", conn_type="pinot", host="localhost", port=9000), session)
@classmethod
def _get_connection_from_env(cls, conn_id):
    environment_uri = os.environ.get(CONN_ENV_PREFIX + conn_id.upper())
    conn = None
    if environment_uri:
        conn = Connection(conn_id=conn_id, uri=environment_uri)
    return conn
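# Connections resolved this way come from environment variables named
# CONN_ENV_PREFIX plus the upper-cased connection ID, holding a connection URI.
# Sketch only: it assumes CONN_ENV_PREFIX == "AIRFLOW_CONN_", the conventional
# Airflow prefix, which is not defined in the snippet above.
import os

from airflow.models import Connection

os.environ["AIRFLOW_CONN_MY_POSTGRES"] = "postgres://user:pass@localhost:5432/mydb"

uri = os.environ.get("AIRFLOW_CONN_" + "my_postgres".upper())
conn = Connection(conn_id="my_postgres", uri=uri) if uri else None
print(conn.host if conn else "not found")  # localhost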
def setUp(self):
    db.merge_conn(Connection(conn_id=DEFAULT_CONN, conn_type='HTTP'))
def setUp(self, mock_batch, mock_hook): # set up the test variable self.test_vm_conn_id = "test_azure_batch_vm2" self.test_cloud_conn_id = "test_azure_batch_cloud2" self.test_account_name = "test_account_name" self.test_account_key = 'test_account_key' self.test_account_url = "http://test-endpoint:29000" self.test_vm_size = "test-vm-size" self.test_vm_publisher = "test.vm.publisher" self.test_vm_offer = "test.vm.offer" self.test_vm_sku = "test-sku" self.test_cloud_os_family = "test-family" self.test_cloud_os_version = "test-version" self.test_node_agent_sku = "test-node-agent-sku" # connect with vm configuration db.merge_conn( Connection( conn_id=self.test_vm_conn_id, conn_type="azure_batch", extra=json.dumps({ "account_name": self.test_account_name, "account_key": self.test_account_key, "account_url": self.test_account_url, "vm_publisher": self.test_vm_publisher, "vm_offer": self.test_vm_offer, "vm_sku": self.test_vm_sku, "node_agent_sku_id": self.test_node_agent_sku, }), )) # connect with cloud service db.merge_conn( Connection( conn_id=self.test_cloud_conn_id, conn_type="azure_batch", extra=json.dumps({ "account_name": self.test_account_name, "account_key": self.test_account_key, "account_url": self.test_account_url, "os_family": self.test_cloud_os_family, "os_version": self.test_cloud_os_version, "node_agent_sku_id": self.test_node_agent_sku, }), )) self.operator = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, batch_pool_vm_size=BATCH_VM_SIZE, batch_job_id=BATCH_JOB_ID, batch_task_id=BATCH_TASK_ID, vm_publisher=self.test_vm_publisher, vm_offer=self.test_vm_offer, vm_sku=self.test_vm_sku, vm_node_agent_sku_id=self.test_node_agent_sku, sku_starts_with=self.test_vm_sku, batch_task_command_line="echo hello", azure_batch_conn_id=self.test_vm_conn_id, target_dedicated_nodes=1, timeout=2, ) self.operator2_pass = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, batch_pool_vm_size=BATCH_VM_SIZE, batch_job_id=BATCH_JOB_ID, batch_task_id=BATCH_TASK_ID, os_family="4", batch_task_command_line="echo hello", azure_batch_conn_id=self.test_vm_conn_id, enable_auto_scale=True, auto_scale_formula=FORMULA, timeout=2, ) self.operator2_no_formula = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, batch_pool_vm_size=BATCH_VM_SIZE, batch_job_id=BATCH_JOB_ID, batch_task_id=BATCH_TASK_ID, os_family='4', batch_task_command_line="echo hello", azure_batch_conn_id=self.test_vm_conn_id, enable_auto_scale=True, timeout=2, ) self.operator_fail = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, batch_pool_vm_size=BATCH_VM_SIZE, batch_job_id=BATCH_JOB_ID, batch_task_id=BATCH_TASK_ID, os_family='4', batch_task_command_line="echo hello", azure_batch_conn_id=self.test_vm_conn_id, timeout=2, ) self.operator_mutual_exclusive = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, batch_pool_vm_size=BATCH_VM_SIZE, batch_job_id=BATCH_JOB_ID, batch_task_id=BATCH_TASK_ID, vm_publisher=self.test_vm_publisher, vm_offer=self.test_vm_offer, vm_sku=self.test_vm_sku, vm_node_agent_sku_id=self.test_node_agent_sku, os_family="5", sku_starts_with=self.test_vm_sku, batch_task_command_line="echo hello", azure_batch_conn_id=self.test_vm_conn_id, target_dedicated_nodes=1, timeout=2, ) self.operator_invalid = AzureBatchOperator( task_id=TASK_ID, batch_pool_id=BATCH_POOL_ID, batch_pool_vm_size=BATCH_VM_SIZE, batch_job_id=BATCH_JOB_ID, batch_task_id=BATCH_TASK_ID, batch_task_command_line="echo hello", azure_batch_conn_id=self.test_vm_conn_id, 
target_dedicated_nodes=1, timeout=2, ) self.batch_client = mock_batch.return_value self.mock_instance = mock_hook.return_value self.assertEqual(self.batch_client, self.operator.hook.connection)
class TestGKEPodOperator(unittest.TestCase): def setUp(self): self.gke_op = GKEStartPodOperator(project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, cluster_name=CLUSTER_NAME, task_id=PROJECT_TASK_ID, name=TASK_NAME, namespace=NAMESPACE, image=IMAGE) def test_template_fields(self): self.assertTrue( set(KubernetesPodOperator.template_fields).issubset( GKEStartPodOperator.template_fields)) # pylint: disable=unused-argument @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections", return_value=[Connection(extra=json.dumps({}))]) @mock.patch( 'airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.execute' ) @mock.patch('tempfile.NamedTemporaryFile') @mock.patch("subprocess.check_call") @mock.patch.dict(os.environ, {CREDENTIALS: '/tmp/local-creds'}) def test_execute_conn_id_none(self, proc_mock, file_mock, exec_mock, get_conn): type( file_mock.return_value.__enter__.return_value).name = PropertyMock( side_effect=[FILE_NAME]) def assert_credentials(*args, **kwargs): # since we passed in keyfile_path we should get a file self.assertIn(CREDENTIALS, os.environ) self.assertEqual(os.environ[CREDENTIALS], '/tmp/local-creds') proc_mock.side_effect = assert_credentials self.gke_op.execute(None) # Assert Environment Variable is being set correctly self.assertIn(KUBE_ENV_VAR, os.environ) self.assertEqual(os.environ[KUBE_ENV_VAR], FILE_NAME) # Assert the gcloud command being called correctly proc_mock.assert_called_once_with( GCLOUD_COMMAND.format(CLUSTER_NAME, PROJECT_LOCATION, TEST_GCP_PROJECT_ID).split()) self.assertEqual(self.gke_op.config_file, FILE_NAME) # pylint: disable=unused-argument @mock.patch( "airflow.hooks.base_hook.BaseHook.get_connections", return_value=[ Connection(extra=json.dumps( {'extra__google_cloud_platform__key_path': '/path/to/file'})) ]) @mock.patch( 'airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.execute' ) @mock.patch('tempfile.NamedTemporaryFile') @mock.patch("subprocess.check_call") @mock.patch.dict(os.environ, {}) def test_execute_conn_id_path(self, proc_mock, file_mock, exec_mock, get_con_mock): type( file_mock.return_value.__enter__.return_value).name = PropertyMock( side_effect=[FILE_NAME]) def assert_credentials(*args, **kwargs): # since we passed in keyfile_path we should get a file self.assertIn(CREDENTIALS, os.environ) self.assertEqual(os.environ[CREDENTIALS], '/path/to/file') proc_mock.side_effect = assert_credentials self.gke_op.execute(None) # Assert Environment Variable is being set correctly self.assertIn(KUBE_ENV_VAR, os.environ) self.assertEqual(os.environ[KUBE_ENV_VAR], FILE_NAME) # Assert the gcloud command being called correctly proc_mock.assert_called_once_with( GCLOUD_COMMAND.format(CLUSTER_NAME, PROJECT_LOCATION, TEST_GCP_PROJECT_ID).split()) self.assertEqual(self.gke_op.config_file, FILE_NAME) # pylint: disable=unused-argument @mock.patch.dict(os.environ, {}) @mock.patch("airflow.hooks.base_hook.BaseHook.get_connections", return_value=[ Connection(extra=json.dumps({ "extra__google_cloud_platform__keyfile_dict": '{"private_key": "r4nd0m_k3y"}' })) ]) @mock.patch( 'airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.execute' ) @mock.patch('tempfile.NamedTemporaryFile') @mock.patch("subprocess.check_call") def test_execute_conn_id_dict(self, proc_mock, file_mock, exec_mock, get_con_mock): type( file_mock.return_value.__enter__.return_value).name = PropertyMock( side_effect=[FILE_NAME, '/path/to/new-file']) def assert_credentials(*args, **kwargs): # since we 
passed in keyfile_dict we should get a new file self.assertIn(CREDENTIALS, os.environ) self.assertEqual(os.environ[CREDENTIALS], '/path/to/new-file') proc_mock.side_effect = assert_credentials self.gke_op.execute(None) # Assert Environment Variable is being set correctly self.assertIn(KUBE_ENV_VAR, os.environ) self.assertEqual(os.environ[KUBE_ENV_VAR], FILE_NAME) # Assert the gcloud command being called correctly proc_mock.assert_called_once_with( GCLOUD_COMMAND.format(CLUSTER_NAME, PROJECT_LOCATION, TEST_GCP_PROJECT_ID).split()) self.assertEqual(self.gke_op.config_file, FILE_NAME)
def _create_connections(self, count):
    return [
        Connection(conn_id='TEST_CONN_ID' + str(i), conn_type='TEST_CONN_TYPE' + str(i))
        for i in range(1, count + 1)
    ]
def _create_connection(self, session):
    connection_model = Connection(conn_id='test-connection-id', conn_type='test_type')
    session.add(connection_model)
    session.commit()
def setUp(self):
    db.merge_conn(
        Connection(conn_id='imap_default',
                   host='imap_server_address',
                   login='******',
                   password='******'))
def test_connection_test_hook_method_missing(self):
    conn = Connection(conn_id='test_uri_hook_method_mising', conn_type='ftp')
    res = conn.test_connection()
    assert res[0] is False
    assert res[1] == "Hook FTPHook doesn't implement or inherit test_connection method"
def test_connection_test_no_hook(self):
    conn = Connection(conn_id='test_uri_no_hook', conn_type='fs')
    res = conn.test_connection()
    assert res[0] is False
    assert res[1] == 'Unknown hook type "fs"'
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from unittest import mock

from airflow import PY38
from airflow.models import Connection

if not PY38:
    from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook

PYMSSQL_CONN = Connection(host='ip', schema='share', login='******', password='******', port=8081)


class TestMsSqlHook(unittest.TestCase):
    @unittest.skipIf(PY38, "Mssql package not available when Python >= 3.8.")
    @mock.patch('airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn')
    @mock.patch('airflow.hooks.dbapi.DbApiHook.get_connection')
    def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn):
        get_connection.return_value = PYMSSQL_CONN
        mssql_get_conn.return_value = mock.Mock()

        hook = MsSqlHook()
def setUp(self):
    db.merge_conn(
        Connection(conn_id=self.conn_id,
                   host='https://api.opsgenie.com/',
                   password='******'))
def create_default_connections(session: Session = NEW_SESSION): """Create default Airflow connections.""" merge_conn( Connection( conn_id="airflow_db", conn_type="mysql", host="mysql", login="******", password="", schema="airflow", ), session, ) merge_conn( Connection( conn_id="aws_default", conn_type="aws", ), session, ) merge_conn( Connection( conn_id="azure_batch_default", conn_type="azure_batch", login="******", password="", extra='''{"account_url": "<ACCOUNT_URL>"}''', ) ) merge_conn( Connection( conn_id="azure_cosmos_default", conn_type="azure_cosmos", extra='{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }', ), session, ) merge_conn( Connection( conn_id='azure_data_explorer_default', conn_type='azure_data_explorer', host='https://<CLUSTER>.kusto.windows.net', extra='''{"auth_method": "<AAD_APP | AAD_APP_CERT | AAD_CREDS | AAD_DEVICE>", "tenant": "<TENANT ID>", "certificate": "<APPLICATION PEM CERTIFICATE>", "thumbprint": "<APPLICATION CERTIFICATE THUMBPRINT>"}''', ), session, ) merge_conn( Connection( conn_id="azure_data_lake_default", conn_type="azure_data_lake", extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }', ), session, ) merge_conn( Connection( conn_id="azure_default", conn_type="azure", ), session, ) merge_conn( Connection( conn_id="cassandra_default", conn_type="cassandra", host="cassandra", port=9042, ), session, ) merge_conn( Connection( conn_id="databricks_default", conn_type="databricks", host="localhost", ), session, ) merge_conn( Connection( conn_id="dingding_default", conn_type="http", host="", password="", ), session, ) merge_conn( Connection( conn_id="drill_default", conn_type="drill", host="localhost", port=8047, extra='{"dialect_driver": "drill+sadrill", "storage_plugin": "dfs"}', ), session, ) merge_conn( Connection( conn_id="druid_broker_default", conn_type="druid", host="druid-broker", port=8082, extra='{"endpoint": "druid/v2/sql"}', ), session, ) merge_conn( Connection( conn_id="druid_ingest_default", conn_type="druid", host="druid-overlord", port=8081, extra='{"endpoint": "druid/indexer/v1/task"}', ), session, ) merge_conn( Connection( conn_id="elasticsearch_default", conn_type="elasticsearch", host="localhost", schema="http", port=9200, ), session, ) merge_conn( Connection( conn_id="emr_default", conn_type="emr", extra=""" { "Name": "default_job_flow_name", "LogUri": "s3://my-emr-log-bucket/default_job_flow_location", "ReleaseLabel": "emr-4.6.0", "Instances": { "Ec2KeyName": "mykey", "Ec2SubnetId": "somesubnet", "InstanceGroups": [ { "Name": "Master nodes", "Market": "ON_DEMAND", "InstanceRole": "MASTER", "InstanceType": "r3.2xlarge", "InstanceCount": 1 }, { "Name": "Core nodes", "Market": "ON_DEMAND", "InstanceRole": "CORE", "InstanceType": "r3.2xlarge", "InstanceCount": 1 } ], "TerminationProtected": false, "KeepJobFlowAliveWhenNoSteps": false }, "Applications":[ { "Name": "Spark" } ], "VisibleToAllUsers": true, "JobFlowRole": "EMR_EC2_DefaultRole", "ServiceRole": "EMR_DefaultRole", "Tags": [ { "Key": "app", "Value": "analytics" }, { "Key": "environment", "Value": "development" } ] } """, ), session, ) merge_conn( Connection( conn_id="facebook_default", conn_type="facebook_social", extra=""" { "account_id": "<AD_ACCOUNT_ID>", "app_id": "<FACEBOOK_APP_ID>", "app_secret": "<FACEBOOK_APP_SECRET>", "access_token": "<FACEBOOK_AD_ACCESS_TOKEN>" } """, ), session, ) merge_conn( Connection( conn_id="fs_default", conn_type="fs", extra='{"path": "/"}', ), session, ) merge_conn( Connection( conn_id="google_cloud_default", 
conn_type="google_cloud_platform", schema="default", ), session, ) merge_conn( Connection( conn_id="hive_cli_default", conn_type="hive_cli", port=10000, host="localhost", extra='{"use_beeline": true, "auth": ""}', schema="default", ), session, ) merge_conn( Connection( conn_id="hiveserver2_default", conn_type="hiveserver2", host="localhost", schema="default", port=10000, ), session, ) merge_conn( Connection( conn_id="http_default", conn_type="http", host="https://www.httpbin.org/", ), session, ) merge_conn( Connection( conn_id='kubernetes_default', conn_type='kubernetes', ), session, ) merge_conn( Connection( conn_id='kylin_default', conn_type='kylin', host='localhost', port=7070, login="******", password="******", ), session, ) merge_conn( Connection( conn_id="leveldb_default", conn_type="leveldb", host="localhost", ), session, ) merge_conn(Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998), session) merge_conn( Connection( conn_id="local_mysql", conn_type="mysql", host="localhost", login="******", password="******", schema="airflow", ), session, ) merge_conn( Connection( conn_id="metastore_default", conn_type="hive_metastore", host="localhost", extra='{"authMechanism": "PLAIN"}', port=9083, ), session, ) merge_conn(Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017), session) merge_conn( Connection( conn_id="mssql_default", conn_type="mssql", host="localhost", port=1433, ), session, ) merge_conn( Connection( conn_id="mysql_default", conn_type="mysql", login="******", schema="airflow", host="mysql", ), session, ) merge_conn( Connection( conn_id="opsgenie_default", conn_type="http", host="", password="", ), session, ) merge_conn( Connection( conn_id="oss_default", conn_type="oss", extra='''{ "auth_type": "AK", "access_key_id": "<ACCESS_KEY_ID>", "access_key_secret": "<ACCESS_KEY_SECRET>"} ''', ), session, ) merge_conn( Connection( conn_id="pig_cli_default", conn_type="pig_cli", schema="default", ), session, ) merge_conn( Connection( conn_id="pinot_admin_default", conn_type="pinot", host="localhost", port=9000, ), session, ) merge_conn( Connection( conn_id="pinot_broker_default", conn_type="pinot", host="localhost", port=9000, extra='{"endpoint": "/query", "schema": "http"}', ), session, ) merge_conn( Connection( conn_id="postgres_default", conn_type="postgres", login="******", password="******", schema="airflow", host="postgres", ), session, ) merge_conn( Connection( conn_id="presto_default", conn_type="presto", host="localhost", schema="hive", port=3400, ), session, ) merge_conn( Connection( conn_id="qubole_default", conn_type="qubole", host="localhost", ), session, ) merge_conn( Connection( conn_id="redis_default", conn_type="redis", host="redis", port=6379, extra='{"db": 0}', ), session, ) merge_conn( Connection( conn_id="segment_default", conn_type="segment", extra='{"write_key": "my-segment-write-key"}', ), session, ) merge_conn( Connection( conn_id="sftp_default", conn_type="sftp", host="localhost", port=22, login="******", extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}', ), session, ) merge_conn( Connection( conn_id="spark_default", conn_type="spark", host="yarn", extra='{"queue": "root.default"}', ), session, ) merge_conn( Connection( conn_id="sqlite_default", conn_type="sqlite", host=os.path.join(gettempdir(), "sqlite_default.db"), ), session, ) merge_conn( Connection( conn_id="sqoop_default", conn_type="sqoop", host="rdbms", ), session, ) merge_conn( Connection( conn_id="ssh_default", conn_type="ssh", 
host="localhost", ), session, ) merge_conn( Connection( conn_id="tableau_default", conn_type="tableau", host="https://tableau.server.url", login="******", password="******", extra='{"site_id": "my_site"}', ), session, ) merge_conn( Connection( conn_id="trino_default", conn_type="trino", host="localhost", schema="hive", port=3400, ), session, ) merge_conn( Connection( conn_id="vertica_default", conn_type="vertica", host="localhost", port=5433, ), session, ) merge_conn( Connection( conn_id="wasb_default", conn_type="wasb", extra='{"sas_token": null}', ), session, ) merge_conn( Connection( conn_id="webhdfs_default", conn_type="hdfs", host="localhost", port=50070, ), session, ) merge_conn( Connection( conn_id='yandexcloud_default', conn_type='yandexcloud', schema='default', ), session, )
class LivyBatchHook(BaseHook):
    """
    Uses the Apache Livy `Batch API <https://livy.incubator.apache.org/docs/latest/rest-api.html>`_
    to submit Spark jobs to a Livy server, get batch state, verify batch state by querying
    either the Spark history server or the YARN resource manager, and spill the logs of the
    Spark job after completion.
    """
    template_fields = ["file", "proxy_user", "class_name", "arguments", "jars", "py_files",
                       "files", "driver_memory", "driver_cores", "executor_memory",
                       "executor_cores", "num_executors", "archives", "queue", "name",
                       "conf", "azure_conn_id", "cluster_name", "batch_id"]

    class LocalConnHttpHook(HttpHook):
        def __init__(self, batch_hook, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.batch_hook = batch_hook

        def get_connection(self, conn_id):
            if conn_id == 'livy_conn_id':
                return self.batch_hook.livy_conn
            if conn_id == 'spark_conn_id':
                return self.batch_hook.spark_conn
            if conn_id == 'yarn_conn_id':
                return self.batch_hook.yarn_conn

    def __init__(self, file=None, proxy_user=None, class_name=None, arguments=None,
                 jars=None, py_files=None, files=None, driver_memory=None,
                 driver_cores=None, executor_memory=None, executor_cores=None,
                 num_executors=None, archives=None, queue=None, name=None, conf=None,
                 azure_conn_id=None, cluster_name=None, batch_id=None, verify_in=None):
        """
        A batch hook object represents a call to livy
        `POST /batches <https://livy.incubator.apache.org/docs/latest/rest-api.html>`_

        :param file: File containing the application to execute
        :type file: string
        :param proxy_user: User to impersonate when running the job
        :type proxy_user: string
        :param class_name: Application Java/Spark main class
        :type class_name: string
        :param arguments: Command line arguments for the application
        :type arguments: list[string]
        :param jars: jars to be used in this session
        :type jars: list[string]
        :param py_files: Python files to be used in this session
        :type py_files: list[string]
        :param files: files to be used in this session
        :type files: list[string]
        :param driver_memory: Amount of memory to use for the driver process
        :type driver_memory: string
        :param driver_cores: Number of cores to use for the driver process
        :type driver_cores: int
        :param executor_memory: Amount of memory to use per executor process
        :type executor_memory: string
        :param executor_cores: Number of cores to use for each executor
        :type executor_cores: int
        :param num_executors: Number of executors to launch for this session
        :type num_executors: int
        :param archives: Archives to be used in this session
        :type archives: list[string]
        :param queue: The name of the YARN queue to which submitted
        :type queue: string
        :param name: The name of this session
        :type name: string
        :param conf: Spark configuration properties
        :type conf: dict
        :param azure_conn_id: Connection ID for this Azure HDInsight connection.
        :type azure_conn_id: string
        :param cluster_name: Unique cluster name of the HDInsight cluster
        :type cluster_name: string
        :param batch_id: Livy Batch ID as returned by the API
        :type batch_id: string
        :param verify_in: Specify the additional verification method.
            Either `spark` or `yarn`
        :type verify_in: string
        """
        super().__init__(source=None)
        self.file = file
        self.proxy_user = proxy_user
        self.class_name = class_name
        self.arguments = arguments
        self.jars = jars
        self.py_files = py_files
        self.files = files
        self.driver_memory = driver_memory
        self.driver_cores = driver_cores
        self.executor_memory = executor_memory
        self.executor_cores = executor_cores
        self.num_executors = num_executors
        self.archives = archives
        self.queue = queue
        self.name = name
        self.conf = conf
        self.azure_conn_id = azure_conn_id
        self.cluster_name = cluster_name
        self.batch_id = batch_id
        self.verify_in = verify_in
        self.connections_created = False

    def create_livy_connections(self):
        """Creates the Livy, Spark and YARN connections dynamically"""
        session = settings.Session()
        azure_conn = session.query(Connection).filter(
            Connection.conn_id == self.azure_conn_id).first()
        if not azure_conn:
            raise AirflowException(f"Azure connection not found: {self.azure_conn_id}")
        username = azure_conn.extra_dejson['CLUSTER_LOGIN_USER_NAME']
        password = azure_conn.extra_dejson['CLUSTER_PASSWORD']
        self.livy_conn = Connection(conn_id='livy_conn_id')
        self.livy_conn.login = username
        self.livy_conn.set_password(password)
        self.livy_conn.schema = 'https'
        self.livy_conn.extra = f"{{ \"X-Requested-By\": \"{username}\" }}"
        self.livy_conn.host = f"https://{self.cluster_name}.azurehdinsight.net/livy"
        self.spark_conn = Connection(conn_id='spark_conn_id')
        self.spark_conn.login = username
        self.spark_conn.set_password(password)
        self.spark_conn.schema = 'https'
        self.spark_conn.extra = f"{{ \"X-Requested-By\": \"{username}\" }}"
        self.spark_conn.host = f"https://{self.cluster_name}.azurehdinsight.net/sparkhistory"
        self.yarn_conn = Connection(conn_id='yarn_conn_id')
        self.yarn_conn.login = username
        self.yarn_conn.set_password(password)
        self.yarn_conn.schema = 'https'
        self.yarn_conn.extra = f"{{ \"X-Requested-By\": \"{username}\" }}"
        self.yarn_conn.host = f"https://{self.cluster_name}.azurehdinsight.net/yarnui"
        self.connections_created = True

    def submit_batch(self):
        """
        Submit a livy batch

        :return: the batch id returned by the livy server
        :rtype: string
        """
        if not self.connections_created:
            self.create_livy_connections()
        headers = {"X-Requested-By": "airflow", "Content-Type": "application/json"}
        unfiltered_payload = {
            "file": self.file,
            "proxyUser": self.proxy_user,
            "className": self.class_name,
            "args": self.arguments,
            "jars": self.jars,
            "pyFiles": self.py_files,
            "files": self.files,
            "driverMemory": self.driver_memory,
            "driverCores": self.driver_cores,
            "executorMemory": self.executor_memory,
            "executorCores": self.executor_cores,
            "numExecutors": self.num_executors,
            "archives": self.archives,
            "queue": self.queue,
            "name": self.name,
            "conf": self.conf,
        }
        payload = {k: v for k, v in unfiltered_payload.items() if v}
        self.log.info(
            f"Submitting the batch to Livy... "
            f"Payload:\n{json.dumps(payload, indent=2)}"
        )
        response = self.LocalConnHttpHook(self, http_conn_id='livy_conn_id').run(
            LIVY_ENDPOINT, json.dumps(payload), headers
        )
        try:
            batch_id = json.loads(response.content)["id"]
        except (JSONDecodeError, LookupError) as ex:
            self._log_response_error("$.id", response)
            raise AirflowBadRequest(ex)
        if not isinstance(batch_id, Number):
            raise AirflowException(
                f"ID of the created batch is not a number ({batch_id}). "
                "Are you sure we're calling Livy API?"
            )
        self.batch_id = batch_id
        self.log.info("Batch successfully submitted with id %s", self.batch_id)
        return self.batch_id

    def get_batch_state(self):
        """
        Queries and returns the current livy batch state

        :return: the livy batch state
        :rtype: string
        """
        if not self.connections_created:
            self.create_livy_connections()
        self.log.info("Getting batch %s status...", self.batch_id)
        endpoint = f"{LIVY_ENDPOINT}/{self.batch_id}"
        response = self.LocalConnHttpHook(self, method="GET",
                                          http_conn_id='livy_conn_id').run(endpoint)
        try:
            return json.loads(response.content)["state"]
        except (JSONDecodeError, LookupError) as ex:
            self._log_response_error("$.state", response, self.batch_id)
            raise AirflowBadRequest(ex)

    def verify(self):
        """
        Does additional verification of a livy batch by either querying the yarn
        resource manager or the spark history server.

        :raises AirflowException: when the job is verified to have failed
        """
        if not self.connections_created:
            self.create_livy_connections()
        app_id = self._get_spark_app_id(self.batch_id)
        if app_id is None:
            raise AirflowException(f"Spark appId was null for batch {self.batch_id}")
        self.log.info("Found app id '%s' for batch id %s.", app_id, self.batch_id)
        if self.verify_in == "spark":
            self._check_spark_app_status(app_id)
        else:
            self._check_yarn_app_status(app_id)
        self.log.info("App '%s' associated with batch %s completed!", app_id, self.batch_id)

    def _get_spark_app_id(self, batch_id):
        """
        Gets the spark application ID of a livy batch job

        :param batch_id: the batch id of the livy batch job
        :return: spark application ID
        :rtype: string
        """
        self.log.info("Getting Spark app id from Livy API for batch %s...", batch_id)
        endpoint = f"{LIVY_ENDPOINT}/{batch_id}"
        response = self.LocalConnHttpHook(self, method="GET", http_conn_id='livy_conn_id').run(
            endpoint
        )
        try:
            return json.loads(response.content)["appId"]
        except (JSONDecodeError, LookupError, AirflowException) as ex:
            self._log_response_error("$.appId", response, batch_id)
            raise AirflowBadRequest(ex)

    def _check_spark_app_status(self, app_id):
        """
        Verifies whether this spark job has succeeded or failed by querying
        the spark history server

        :param app_id: application ID of the spark job
        :raises AirflowException: when the job is verified to have failed
        """
        self.log.info("Getting app status (id=%s) from Spark REST API...", app_id)
        endpoint = f"{SPARK_ENDPOINT}/{app_id}/jobs"
        response = self.LocalConnHttpHook(self, method="GET", http_conn_id='spark_conn_id').run(
            endpoint
        )
        try:
            jobs = json.loads(response.content)
            expected_status = "SUCCEEDED"
            for job in jobs:
                job_id = job["jobId"]
                job_status = job["status"]
                self.log.info(
                    "Job id %s associated with application '%s' is '%s'",
                    job_id, app_id, job_status
                )
                if job_status != expected_status:
                    raise AirflowException(
                        f"Job id '{job_id}' associated with application '{app_id}' "
                        f"is '{job_status}', expected status is '{expected_status}'"
                    )
        except (JSONDecodeError, LookupError, TypeError) as ex:
            self._log_response_error("$.jobId, $.status", response)
            raise AirflowBadRequest(ex)

    def _check_yarn_app_status(self, app_id):
        """
        Verifies whether this YARN job has succeeded or failed by querying
        the YARN Resource Manager

        :param app_id: the YARN application ID
        :raises AirflowException: when the job is verified to have failed
        """
        self.log.info("Getting app status (id=%s) from YARN RM REST API...", app_id)
        endpoint = f"{YARN_ENDPOINT}/{app_id}"
        response = self.LocalConnHttpHook(self, method="GET", http_conn_id='yarn_conn_id').run(
            endpoint
        )
        try:
            status = json.loads(response.content)["app"]["finalStatus"]
        except (JSONDecodeError, LookupError, TypeError) as ex:
            self._log_response_error("$.app.finalStatus", response)
            raise AirflowBadRequest(ex)
        expected_status = "SUCCEEDED"
        if status != expected_status:
            raise AirflowException(
                f"YARN app {app_id} is '{status}', expected status: '{expected_status}'"
            )

    def spill_batch_logs(self):
        """Gets paginated batch logs from the livy batch API and logs them"""
        if not self.connections_created:
            self.create_livy_connections()
        dashes = 50
        self.log.info(f"{'-'*dashes}Full log for batch %s{'-'*dashes}", self.batch_id)
        endpoint = f"{LIVY_ENDPOINT}/{self.batch_id}/log"
        hook = self.LocalConnHttpHook(self, method="GET", http_conn_id='livy_conn_id')
        line_from = 0
        line_to = LOG_PAGE_LINES
        while True:
            log_page = self._fetch_log_page(hook, endpoint, line_from, line_to)
            try:
                logs = log_page["log"]
                for log in logs:
                    self.log.info(log.replace("\\n", "\n"))
                actual_line_from = log_page["from"]
                total_lines = log_page["total"]
            except LookupError as ex:
                self._log_response_error("$.log, $.from, $.total", log_page)
                raise AirflowBadRequest(ex)
            actual_lines = len(logs)
            if actual_line_from + actual_lines >= total_lines:
                self.log.info(
                    f"{'-' * dashes}End of full log for batch %s"
                    f"{'-' * dashes}", self.batch_id
                )
                break
            line_from = actual_line_from + actual_lines

    def _fetch_log_page(self, hook: LocalConnHttpHook, endpoint, line_from, line_to):
        """Fetch a paginated log page from the livy batch API"""
        prepd_endpoint = endpoint + f"?from={line_from}&size={line_to}"
        response = hook.run(prepd_endpoint)
        try:
            return json.loads(response.content)
        except JSONDecodeError as ex:
            self._log_response_error("$", response)
            raise AirflowBadRequest(ex)

    def close_batch(self):
        """Close a livy batch"""
        self.log.info("Closing batch with id = %s", self.batch_id)
        batch_endpoint = f"{LIVY_ENDPOINT}/{self.batch_id}"
        self.LocalConnHttpHook(self, method="DELETE", http_conn_id='livy_conn_id').run(
            batch_endpoint
        )
        self.log.info("Batch %s has been closed", self.batch_id)

    def _log_response_error(self, lookup_path, response, batch_id=None):
        """Log an error response from the livy batch API"""
        msg = "Can not parse JSON response."
        if batch_id is not None:
            msg += f" Batch id={batch_id}."
        try:
            pp_response = (
                json.dumps(json.loads(response.content), indent=2)
                if "application/json" in response.headers.get("Content-Type", "")
                else response.content
            )
        except AttributeError:
            pp_response = json.dumps(response, indent=2)
        msg += f"\nTried to find JSON path: {lookup_path}, but response was:\n{pp_response}"
        self.log.error(msg)
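A minimal usage sketch for the hook above, assuming an Azure HDInsight connection with the expected CLUSTER_LOGIN_USER_NAME/CLUSTER_PASSWORD extras already exists under the azure_conn_id used here; the file, class, cluster name and polling loop are illustrative only, and the state names follow the Livy batch-session states:

import time

# Hypothetical driver code: submit a batch, poll its state, spill the logs,
# verify the result through YARN and always close the batch afterwards.
hook = LivyBatchHook(
    file="wasb:///example/jars/spark-examples.jar",    # illustrative path
    class_name="org.apache.spark.examples.SparkPi",
    arguments=["10"],
    azure_conn_id="azure_hdinsight_default",           # assumed to exist
    cluster_name="my-hdinsight-cluster",               # assumed cluster name
    verify_in="yarn",
)

batch_id = hook.submit_batch()
try:
    while True:
        state = hook.get_batch_state()
        if state not in ("not_started", "starting", "running"):
            break
        time.sleep(30)
    hook.spill_batch_logs()
    if state != "success":
        raise Exception(f"Livy batch {batch_id} finished in state '{state}'")
    hook.verify()
finally:
    hook.close_batch()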
def initdb(rbac=False): session = settings.Session() from airflow import models from airflow.models import Connection upgradedb() merge_conn( Connection(conn_id='airflow_db', conn_type='mysql', host='mysql', login='******', password='', schema='airflow')) merge_conn( Connection(conn_id='beeline_default', conn_type='beeline', port=10000, host='localhost', extra="{\"use_beeline\": true, \"auth\": \"\"}", schema='default')) merge_conn( Connection(conn_id='bigquery_default', conn_type='google_cloud_platform', schema='default')) merge_conn( Connection(conn_id='local_mysql', conn_type='mysql', host='localhost', login='******', password='******', schema='airflow')) merge_conn( Connection(conn_id='presto_default', conn_type='presto', host='localhost', schema='hive', port=3400)) merge_conn( Connection( conn_id='google_cloud_default', conn_type='google_cloud_platform', schema='default', )) merge_conn( Connection( conn_id='hive_cli_default', conn_type='hive_cli', schema='default', )) merge_conn( Connection(conn_id='hiveserver2_default', conn_type='hiveserver2', host='localhost', schema='default', port=10000)) merge_conn( Connection(conn_id='metastore_default', conn_type='hive_metastore', host='localhost', extra="{\"authMechanism\": \"PLAIN\"}", port=9083)) merge_conn( Connection(conn_id='mongo_default', conn_type='mongo', host='mongo', port=27017)) merge_conn( Connection(conn_id='mysql_default', conn_type='mysql', login='******', schema='airflow', host='mysql')) merge_conn( Connection(conn_id='postgres_default', conn_type='postgres', login='******', password='******', schema='airflow', host='postgres')) merge_conn( Connection(conn_id='sqlite_default', conn_type='sqlite', host='/tmp/sqlite_default.db')) merge_conn( Connection(conn_id='http_default', conn_type='http', host='https://www.google.com/')) merge_conn( Connection(conn_id='mssql_default', conn_type='mssql', host='localhost', port=1433)) merge_conn( Connection(conn_id='vertica_default', conn_type='vertica', host='localhost', port=5433)) merge_conn( Connection(conn_id='wasb_default', conn_type='wasb', extra='{"sas_token": null}')) merge_conn( Connection(conn_id='webhdfs_default', conn_type='hdfs', host='localhost', port=50070)) merge_conn( Connection(conn_id='ssh_default', conn_type='ssh', host='localhost')) merge_conn( Connection(conn_id='sftp_default', conn_type='sftp', host='localhost', port=22, login='******', extra=''' {"key_file": "~/.ssh/id_rsa", "no_host_key_check": true} ''')) merge_conn( Connection(conn_id='fs_default', conn_type='fs', extra='{"path": "/"}')) merge_conn( Connection(conn_id='aws_default', conn_type='aws', extra='{"region_name": "us-east-1"}')) merge_conn( Connection(conn_id='spark_default', conn_type='spark', host='yarn', extra='{"queue": "root.default"}')) merge_conn( Connection(conn_id='druid_broker_default', conn_type='druid', host='druid-broker', port=8082, extra='{"endpoint": "druid/v2/sql"}')) merge_conn( Connection(conn_id='druid_ingest_default', conn_type='druid', host='druid-overlord', port=8081, extra='{"endpoint": "druid/indexer/v1/task"}')) merge_conn( Connection(conn_id='redis_default', conn_type='redis', host='redis', port=6379, extra='{"db": 0}')) merge_conn( Connection(conn_id='sqoop_default', conn_type='sqoop', host='rmdbs', extra='')) merge_conn( Connection(conn_id='emr_default', conn_type='emr', extra=''' { "Name": "default_job_flow_name", "LogUri": "s3://my-emr-log-bucket/default_job_flow_location", "ReleaseLabel": "emr-4.6.0", "Instances": { "Ec2KeyName": "mykey", "Ec2SubnetId": "somesubnet", 
"InstanceGroups": [ { "Name": "Master nodes", "Market": "ON_DEMAND", "InstanceRole": "MASTER", "InstanceType": "r3.2xlarge", "InstanceCount": 1 }, { "Name": "Slave nodes", "Market": "ON_DEMAND", "InstanceRole": "CORE", "InstanceType": "r3.2xlarge", "InstanceCount": 1 } ], "TerminationProtected": false, "KeepJobFlowAliveWhenNoSteps": false }, "Applications":[ { "Name": "Spark" } ], "VisibleToAllUsers": true, "JobFlowRole": "EMR_EC2_DefaultRole", "ServiceRole": "EMR_DefaultRole", "Tags": [ { "Key": "app", "Value": "analytics" }, { "Key": "environment", "Value": "development" } ] } ''')) merge_conn( Connection(conn_id='databricks_default', conn_type='databricks', host='localhost')) merge_conn( Connection(conn_id='qubole_default', conn_type='qubole', host='localhost')) merge_conn( Connection(conn_id='segment_default', conn_type='segment', extra='{"write_key": "my-segment-write-key"}')), merge_conn( Connection( conn_id='azure_data_lake_default', conn_type='azure_data_lake', extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }')) merge_conn( Connection( conn_id='azure_cosmos_default', conn_type='azure_cosmos', extra= '{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }' )) merge_conn( Connection( conn_id='azure_container_instances_default', conn_type='azure_container_instances', extra= '{"tenantId": "<TENANT>", "subscriptionId": "<SUBSCRIPTION ID>" }') ) merge_conn( Connection(conn_id='cassandra_default', conn_type='cassandra', host='cassandra', port=9042)) merge_conn( Connection(conn_id='dingding_default', conn_type='http', host='', password='')) merge_conn( Connection(conn_id='opsgenie_default', conn_type='http', host='', password='')) # Known event types KET = models.KnownEventType if not session.query(KET).filter(KET.know_event_type == 'Holiday').first(): session.add(KET(know_event_type='Holiday')) if not session.query(KET).filter(KET.know_event_type == 'Outage').first(): session.add(KET(know_event_type='Outage')) if not session.query(KET).filter( KET.know_event_type == 'Natural Disaster').first(): session.add(KET(know_event_type='Natural Disaster')) if not session.query(KET).filter( KET.know_event_type == 'Marketing Campaign').first(): session.add(KET(know_event_type='Marketing Campaign')) session.commit() dagbag = models.DagBag() # Save individual DAGs in the ORM for dag in dagbag.dags.values(): dag.sync_to_db() # Deactivate the unknown ones models.DAG.deactivate_unknown_dags(dagbag.dags.keys()) Chart = models.Chart chart_label = "Airflow task instance by type" chart = session.query(Chart).filter(Chart.label == chart_label).first() if not chart: chart = Chart( label=chart_label, conn_id='airflow_db', chart_type='bar', x_is_date=False, sql=("SELECT state, COUNT(1) as number " "FROM task_instance " "WHERE dag_id LIKE 'example%' " "GROUP BY state"), ) session.add(chart) session.commit() if rbac: from flask_appbuilder.security.sqla import models from flask_appbuilder.models.sqla import Base Base.metadata.create_all(settings.engine)
import configparser

from airflow import settings
from airflow.models import Connection, Variable

# fetch AWS_KEY and AWS_SECRET
config = configparser.ConfigParser()
config.read('/Users/sathishkaliamoorthy/.aws/credentials')
AWS_KEY = config['default']['aws_access_key_id']
AWS_SECRET = config['default']['aws_secret_access_key']
# AWS_KEY = os.environ.get('AWS_KEY')
# AWS_SECRET = os.environ.get('AWS_SECRET')

# inserting new connection object programmatically
aws_conn = Connection(conn_id='aws_credentials',
                      conn_type='Amazon Web Services',
                      login=AWS_KEY,
                      password=AWS_SECRET)
redshift_conn = Connection(
    conn_id='aws_redshift',
    conn_type='Postgres',
    host='my-sparkify-dwh.cdc1pzfmi32k.us-west-2.redshift.amazonaws.com',
    port=5439,
    schema='sparkify_dwh',
    login='******',
    password='******')

session = settings.Session()
session.add(aws_conn)
session.add(redshift_conn)
session.commit()
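The commented-out lines hint at an environment-variable variant of the same script; a sketch of that approach (the AWS_KEY/AWS_SECRET variable names are illustrative, and conn_type='aws' matches the aws_default entry created by initdb above), which also skips the insert if the connection already exists:

import os

from airflow import settings
from airflow.models import Connection

AWS_KEY = os.environ.get('AWS_KEY')          # illustrative env var names
AWS_SECRET = os.environ.get('AWS_SECRET')

aws_conn = Connection(conn_id='aws_credentials',
                      conn_type='aws',
                      login=AWS_KEY,
                      password=AWS_SECRET)

session = settings.Session()
# Only add the connection if it is not already present in the metastore.
if not session.query(Connection).filter(Connection.conn_id == aws_conn.conn_id).first():
    session.add(aws_conn)
    session.commit()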
def connections_add(args): """Adds new connection""" # Check that the conn_id and conn_uri args were passed to the command: missing_args = [] invalid_args = [] if args.conn_uri: if not _valid_uri(args.conn_uri): raise SystemExit( f'The URI provided to --conn-uri is invalid: {args.conn_uri}') for arg in alternative_conn_specs: if getattr(args, arg) is not None: invalid_args.append(arg) elif not args.conn_type: missing_args.append('conn-uri or conn-type') if missing_args: raise SystemExit( f'The following args are required to add a connection: {missing_args!r}' ) if invalid_args: raise SystemExit(f'The following args are not compatible with the ' f'add flag and --conn-uri flag: {invalid_args!r}') if args.conn_uri: new_conn = Connection(conn_id=args.conn_id, description=args.conn_description, uri=args.conn_uri) else: new_conn = Connection( conn_id=args.conn_id, conn_type=args.conn_type, description=args.conn_description, host=args.conn_host, login=args.conn_login, password=args.conn_password, schema=args.conn_schema, port=args.conn_port, ) if args.conn_extra is not None: new_conn.set_extra(args.conn_extra) with create_session() as session: if not session.query(Connection).filter( Connection.conn_id == new_conn.conn_id).first(): session.add(new_conn) msg = 'Successfully added `conn_id`={conn_id} : {uri}' msg = msg.format( conn_id=new_conn.conn_id, uri=args.conn_uri or urlunparse(( args.conn_type, '{login}:{password}@{host}:{port}'.format( login=args.conn_login or '', password='******' if args.conn_password else '', host=args.conn_host or '', port=args.conn_port or '', ), args.conn_schema or '', '', '', '', )), ) print(msg) else: msg = f'A connection with `conn_id`={new_conn.conn_id} already exists.' raise SystemExit(msg)
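This is roughly the code path behind `airflow connections add my_conn --conn-uri ...`; a hedged sketch of driving it directly with an argparse Namespace (every attribute the function reads must be present on the namespace, and the values below are illustrative):

from argparse import Namespace

# Hypothetical invocation using the URI form; the alternative_conn_specs
# fields must stay None when --conn-uri is given.
args = Namespace(
    conn_id='my_postgres',
    conn_uri='postgres://user:pass@localhost:5432/mydb',
    conn_description='example connection',
    conn_type=None,
    conn_host=None,
    conn_login=None,
    conn_password=None,
    conn_schema=None,
    conn_port=None,
    conn_extra=None,
)
connections_add(args)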
import json

import datahub.emitter.mce_builder as builder
from datahub.integrations.airflow.get_provider_info import get_provider_info
from datahub.integrations.airflow.hooks import DatahubKafkaHook, DatahubRestHook
from datahub.integrations.airflow.operators import DatahubEmitterOperator

from airflow.models import Connection

lineage_mce = builder.make_lineage_mce(
    [
        builder.make_dataset_urn("bigquery", "upstream1"),
        builder.make_dataset_urn("bigquery", "upstream2"),
    ],
    builder.make_dataset_urn("bigquery", "downstream1"),
)

datahub_rest_connection_config = Connection(
    conn_id="datahub_rest_test",
    conn_type="datahub_rest",
    host="http://test_host:8080/",
    extra=None,
)

datahub_kafka_connection_config = Connection(
    conn_id="datahub_kafka_test",
    conn_type="datahub_kafka",
    host="test_broker:9092",
    extra=json.dumps({
        "connection": {
            "producer_config": {},
            "schema_registry_url": "http://localhost:8081",
        }
    }),
)
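One way these pieces are usually tied together is to register one of the connections in the metastore and emit the MCE from a task; a hedged sketch (the DatahubEmitterOperator parameter names follow the datahub Airflow integration but should be checked against the installed version, and the task would still need to be attached to a DAG):

# Hypothetical task emitting the lineage MCE built above over the REST sink.
emit_lineage_task = DatahubEmitterOperator(
    task_id="emit_lineage",
    datahub_conn_id=datahub_rest_connection_config.conn_id,
    mces=[lineage_mce],
)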
def get_airflow_connection(conn_id=None):
    return Connection(
        conn_id='http_default',
        conn_type='http',
        host='test:8080/',
        extra='{"bareer": "test"}'
    )
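A factory like this is typically used as a side_effect when patching connection lookup in HttpHook tests; a minimal sketch (the patch target follows the usual Airflow 2 module path and may need adjusting for older versions):

from unittest import mock

from airflow.providers.http.hooks.http import HttpHook

def test_base_url_from_connection():
    # Patch connection resolution so HttpHook uses the fixture above instead
    # of querying the metastore.
    with mock.patch('airflow.hooks.base.BaseHook.get_connection',
                    side_effect=get_airflow_connection):
        hook = HttpHook(method='GET', http_conn_id='http_default')
        hook.get_conn({})
        assert hook.base_url == 'http://test:8080/'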
new_conn = Connection(conn_id='slack_token')
new_conn.set_password(SLACK_LEGACY_TOKEN)

if not (session.query(Connection).filter(Connection.conn_id == new_conn.conn_id).first()):
    session.add(new_conn)
    session.commit()
else:
    msg = '\n\tA connection with `conn_id`={conn_id} already exists\n'
    msg = msg.format(conn_id=new_conn.conn_id)
    print(msg)

creds = {"user": myservice.get_user(), "pwd": myservice.get_pwd()}
c = Connection(conn_id='your_airflow_connection_id_here',
               login=creds["user"],
               host=None)
c.set_password(creds["pwd"])
merge_conn(c)

def create_conn(username, password, host=None):
    conn = Connection(conn_id=f'{username}_connection',
                      login=username,
                      host=host if host else None)
    conn.set_password(password)

dag = DAG(
    'create_connection',
    default_args=default_args,
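The DAG definition above is cut off; a sketch of how such a connection-creating task is commonly wired up with a PythonOperator (the import path assumes Airflow 2, and the credentials are placeholders). Note that create_conn as defined only builds the Connection object; a session.add or merge_conn would still be needed to persist it:

from airflow.operators.python import PythonOperator

# Hypothetical completion: run create_conn() as a task inside the DAG.
create_connection_task = PythonOperator(
    task_id='create_connection_task',
    python_callable=create_conn,
    op_kwargs={'username': 'example_user', 'password': 'example_password'},
    dag=dag,
)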
class TestWinRMHook(unittest.TestCase): @patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol') def test_get_conn_exists(self, mock_protocol): winrm_hook = WinRMHook() winrm_hook.client = mock_protocol.return_value.open_shell.return_value conn = winrm_hook.get_conn() self.assertEqual(conn, winrm_hook.client) def test_get_conn_missing_remote_host(self): with self.assertRaises(AirflowException): WinRMHook().get_conn() @patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol') def test_get_conn_error(self, mock_protocol): mock_protocol.side_effect = Exception('Error') with self.assertRaises(AirflowException): WinRMHook(remote_host='host').get_conn() @patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol', autospec=True) @patch( 'airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection', return_value=Connection(login='******', password='******', host='remote_host', extra="""{ "endpoint": "endpoint", "remote_port": 123, "transport": "plaintext", "service": "service", "keytab": "keytab", "ca_trust_path": "ca_trust_path", "cert_pem": "cert_pem", "cert_key_pem": "cert_key_pem", "server_cert_validation": "validate", "kerberos_delegation": "true", "read_timeout_sec": 124, "operation_timeout_sec": 123, "kerberos_hostname_override": "kerberos_hostname_override", "message_encryption": "auto", "credssp_disable_tlsv1_2": "true", "send_cbt": "false" }""")) def test_get_conn_from_connection(self, mock_get_connection, mock_protocol): connection = mock_get_connection.return_value winrm_hook = WinRMHook(ssh_conn_id='conn_id') winrm_hook.get_conn() mock_get_connection.assert_called_once_with(winrm_hook.ssh_conn_id) mock_protocol.assert_called_once_with( endpoint=str(connection.extra_dejson['endpoint']), transport=str(connection.extra_dejson['transport']), username=connection.login, password=connection.password, service=str(connection.extra_dejson['service']), keytab=str(connection.extra_dejson['keytab']), ca_trust_path=str(connection.extra_dejson['ca_trust_path']), cert_pem=str(connection.extra_dejson['cert_pem']), cert_key_pem=str(connection.extra_dejson['cert_key_pem']), server_cert_validation=str( connection.extra_dejson['server_cert_validation']), kerberos_delegation=str( connection.extra_dejson['kerberos_delegation']).lower() == 'true', read_timeout_sec=int(connection.extra_dejson['read_timeout_sec']), operation_timeout_sec=int( connection.extra_dejson['operation_timeout_sec']), kerberos_hostname_override=str( connection.extra_dejson['kerberos_hostname_override']), message_encryption=str( connection.extra_dejson['message_encryption']), credssp_disable_tlsv1_2=str( connection.extra_dejson['credssp_disable_tlsv1_2']).lower() == 'true', send_cbt=str( connection.extra_dejson['send_cbt']).lower() == 'true') @patch('airflow.providers.microsoft.winrm.hooks.winrm.getpass.getuser', return_value='user') @patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol') def test_get_conn_no_username(self, mock_protocol, mock_getuser): winrm_hook = WinRMHook(remote_host='host', password='******') winrm_hook.get_conn() self.assertEqual(mock_getuser.return_value, winrm_hook.username) @patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol') def test_get_conn_no_endpoint(self, mock_protocol): winrm_hook = WinRMHook(remote_host='host', password='******') winrm_hook.get_conn() self.assertEqual( 'http://{0}:{1}/wsman'.format(winrm_hook.remote_host, winrm_hook.remote_port), winrm_hook.endpoint)
def create_conn(username, password, host=None):
    conn = Connection(conn_id=f'{username}_connection',
                      login=username,
                      host=host if host else None)
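As written, this variant of create_conn only constructs the Connection in memory; a hedged sketch of completing it so the connection is actually stored, mirroring the session handling used in the other snippets:

from airflow import settings
from airflow.models import Connection

def create_conn(username, password, host=None):
    conn = Connection(conn_id=f'{username}_connection',
                      login=username,
                      host=host if host else None)
    conn.set_password(password)
    # Persist the connection, assuming no connection with this conn_id exists yet.
    session = settings.Session()
    session.add(conn)
    session.commit()
    return conn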
def setUp(self):
    db.merge_conn(
        Connection(
            conn_id='spark_default', conn_type='spark',
            host='yarn://yarn-master'))
class TestRedisHook(unittest.TestCase): def test_get_conn(self): hook = RedisHook(redis_conn_id='redis_default') self.assertEqual(hook.redis, None) self.assertEqual(hook.host, None, 'host initialised as None.') self.assertEqual(hook.port, None, 'port initialised as None.') self.assertEqual(hook.password, None, 'password initialised as None.') self.assertEqual(hook.db, None, 'db initialised as None.') self.assertIs(hook.get_conn(), hook.get_conn(), 'Connection initialized only if None.') @mock.patch('airflow.providers.redis.hooks.redis.Redis') @mock.patch('airflow.providers.redis.hooks.redis.RedisHook.get_connection', return_value=Connection(password='******', host='remote_host', port=1234, extra="""{ "db": 2, "ssl": true, "ssl_cert_reqs": "required", "ssl_ca_certs": "/path/to/custom/ca-cert", "ssl_keyfile": "/path/to/key-file", "ssl_cert_file": "/path/to/cert-file", "ssl_check_hostname": true }""")) def test_get_conn_with_extra_config(self, mock_get_connection, mock_redis): connection = mock_get_connection.return_value hook = RedisHook() hook.get_conn() mock_redis.assert_called_once_with( host=connection.host, password=connection.password, port=connection.port, db=connection.extra_dejson["db"], ssl=connection.extra_dejson["ssl"], ssl_cert_reqs=connection.extra_dejson["ssl_cert_reqs"], ssl_ca_certs=connection.extra_dejson["ssl_ca_certs"], ssl_keyfile=connection.extra_dejson["ssl_keyfile"], ssl_cert_file=connection.extra_dejson["ssl_cert_file"], ssl_check_hostname=connection.extra_dejson["ssl_check_hostname"]) def test_get_conn_password_stays_none(self): hook = RedisHook(redis_conn_id='redis_default') hook.get_conn() self.assertEqual(hook.password, None) @pytest.mark.integration("redis") def test_real_ping(self): hook = RedisHook(redis_conn_id='redis_default') redis = hook.get_conn() self.assertTrue(redis.ping(), 'Connection to Redis with PING works.') @pytest.mark.integration("redis") def test_real_get_and_set(self): hook = RedisHook(redis_conn_id='redis_default') redis = hook.get_conn() self.assertTrue(redis.set('test_key', 'test_value'), 'Connection to Redis with SET works.') self.assertEqual(redis.get('test_key'), b'test_value', 'Connection to Redis with GET works.') self.assertEqual(redis.delete('test_key'), 1, 'Connection to Redis with DELETE works.')
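The integration-marked tests above assume a redis_default connection is present; a sketch of registering one in the test metastore, reusing the same defaults that initdb creates and the db.merge_conn pattern used in the neighbouring snippets:

from airflow.models import Connection
from airflow.utils import db

# Illustrative: register the redis_default connection the integration tests rely on.
db.merge_conn(
    Connection(conn_id='redis_default', conn_type='redis',
               host='redis', port=6379, extra='{"db": 0}'))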
def setUpClass(cls) -> None: db.merge_conn( Connection( conn_id=cls.CONN_SSH_WITH_EXTRA, host='localhost', conn_type='ssh', extra= '{"compress" : true, "no_host_key_check" : "true", "allow_host_key_change": false}', )) db.merge_conn( Connection( conn_id=cls.CONN_SSH_WITH_EXTRA_FALSE_LOOK_FOR_KEYS, host='localhost', conn_type='ssh', extra='{"compress" : true, "no_host_key_check" : "true", ' '"allow_host_key_change": false, "look_for_keys": false}', )) db.merge_conn( Connection( conn_id=cls.CONN_SSH_WITH_PRIVATE_KEY_EXTRA, host='localhost', conn_type='ssh', extra=json.dumps({ "private_key": TEST_PRIVATE_KEY, }), )) db.merge_conn( Connection( conn_id=cls.CONN_SSH_WITH_PRIVATE_KEY_PASSPHRASE_EXTRA, host='localhost', conn_type='ssh', extra=json.dumps({ "private_key": TEST_ENCRYPTED_PRIVATE_KEY, "private_key_passphrase": PASSPHRASE }), )) db.merge_conn( Connection( conn_id=cls.CONN_SSH_WITH_HOST_KEY_EXTRA, host='localhost', conn_type='ssh', extra=json.dumps({ "private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY }), )) db.merge_conn( Connection( conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE, host='remote_host', conn_type='ssh', extra=json.dumps({ "private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY, "no_host_key_check": False }), )) db.merge_conn( Connection( conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE, host='remote_host', conn_type='ssh', extra=json.dumps({ "private_key": TEST_PRIVATE_KEY, "host_key": TEST_HOST_KEY, "no_host_key_check": True }), )) db.merge_conn( Connection( conn_id=cls. CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE, host='remote_host', conn_type='ssh', extra=json.dumps({ "private_key": TEST_PRIVATE_KEY, "no_host_key_check": False }), ))
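These fixtures are normally exercised by instantiating SSHHook against each conn_id; a minimal hedged example of the kind of assertion that goes with them (the attribute names follow the SSH provider hook and should be checked against the installed version):

from airflow.providers.ssh.hooks.ssh import SSHHook

def test_conn_with_extra_parameters(self):
    # The hook is expected to pick up compress/no_host_key_check/
    # allow_host_key_change from the extras registered in setUpClass.
    hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_EXTRA)
    self.assertTrue(hook.compress)
    self.assertTrue(hook.no_host_key_check)
    self.assertFalse(hook.allow_host_key_change)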
def test_get_connection_first_try(self, mock_env_get, mock_meta_get):
    mock_env_get.side_effect = ["something"]  # returns something
    Connection.get_connection_from_secrets("fake_conn_id")
    mock_env_get.assert_called_once_with(conn_id="fake_conn_id")
    mock_meta_get.assert_not_called()
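The ordering asserted here (environment-variables backend consulted before the metastore) can also be exercised without mocks through Airflow's AIRFLOW_CONN_<CONN_ID> convention; a small hedged sketch:

import os

from airflow.models import Connection

# Define a connection purely through the environment; with the default secrets
# backend order it is found before any metastore lookup happens.
os.environ["AIRFLOW_CONN_FAKE_CONN_ID"] = "postgres://user:pass@localhost:5432/mydb"

conn = Connection.get_connection_from_secrets("fake_conn_id")
print(conn.conn_type, conn.host, conn.port)  # expected: postgres localhost 5432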