def test_delete_multiple_blobs(self, mock_service):
    """Prefix deletion calls delete_blob (with snapshots) for every listed blob."""
    Blob = namedtuple('Blob', ['name'])
    service = mock_service.return_value
    listed_blobs = [Blob('blob_prefix/blob1'), Blob('blob_prefix/blob2')]
    service.list_blobs.return_value = iter(listed_blobs)

    wasb = WasbHook(wasb_conn_id='wasb_test_sas_token')
    wasb.delete_file('container', 'blob_prefix', is_prefix=True)

    for expected_name in ('blob_prefix/blob1', 'blob_prefix/blob2'):
        service.delete_blob.assert_any_call(
            'container', expected_name, delete_snapshots='include'
        )
def test_delete_multiple_blobs(self, mock_check, mock_get_blobslist, mock_delete_blobs):
    """Prefix deletion lists blobs under the prefix and passes them all to delete_blobs."""
    expected_names = ('blob_prefix/blob1', 'blob_prefix/blob2')
    mock_check.return_value = False
    mock_get_blobslist.return_value = list(expected_names)

    wasb = WasbHook(wasb_conn_id=self.shared_key_conn_id)
    wasb.delete_file('container', 'blob_prefix', is_prefix=True)

    mock_get_blobslist.assert_called_once_with('container', prefix='blob_prefix')
    mock_delete_blobs.assert_any_call('container', *expected_names)
def copy_files_to_wasb(self, sftp_files: List[SftpFile]) -> List[str]:
    """Upload a list of files from sftp_files to Azure Blob Storage with a new Blob Name.

    Each file is downloaded from SFTP into a local temporary file, then uploaded
    with ``WasbHook.load_file`` using ``self.load_options``.

    :param sftp_files: files to transfer; each supplies its SFTP path and target blob name.
    :return: the SFTP paths of the uploaded files, in processing order.
    """
    wasb = WasbHook(wasb_conn_id=self.wasb_conn_id)
    transferred: List[str] = []
    for entry in sftp_files:
        with NamedTemporaryFile("w") as local_copy:
            # Pull the remote file down to the temp path before pushing it to WASB.
            self.sftp_hook.retrieve_file(entry.sftp_file_path, local_copy.name)
            self.log.info(
                'Uploading %s to wasb://%s as %s',
                entry.sftp_file_path,
                self.container_name,
                entry.blob_name,
            )
            wasb.load_file(
                local_copy.name, self.container_name, entry.blob_name, **self.load_options
            )
            transferred.append(entry.sftp_file_path)
    return transferred
def _rank_movies(odbc_conn_id, wasb_conn_id, ratings_container, rankings_container, **context):
    """Run the monthly ranking query over ODBC and upload the result as CSV to Azure Blob.

    :param odbc_conn_id: Airflow connection id used to build the OdbcHook.
    :param wasb_conn_id: Airflow connection id of the WASB storage account; its
        ``login`` is used as the blob account name inside the query.
    :param ratings_container: container the query reads ratings from.
    :param rankings_container: container the ranking CSV is written to.
    :param context: Airflow task context; only ``execution_date`` is read.
    """
    from contextlib import closing

    year = context["execution_date"].year
    month = context["execution_date"].month

    # Determine storage account name, needed for query source URL.
    blob_account_name = WasbHook.get_connection(wasb_conn_id).login

    query = RANK_QUERY.format(
        year=year,
        month=month,
        blob_account_name=blob_account_name,
        blob_container=ratings_container,
    )
    logging.info("Executing query: %s", query)

    odbc_hook = OdbcHook(odbc_conn_id, driver="ODBC Driver 17 for SQL Server")
    # A pyodbc connection used as a context manager only commits on exit and
    # does not close, so closing() is needed to actually release the connection.
    with closing(odbc_hook.get_conn()) as conn:
        with conn, conn.cursor() as cursor:
            cursor.execute(query)
            rows = cursor.fetchall()
            colnames = [field[0] for field in cursor.description]

    ranking = pd.DataFrame.from_records(rows, columns=colnames)
    logging.info("Retrieved %s rows", ranking.shape[0])

    # Write ranking to temp file.
    logging.info("Writing results to %s/%s/%02d.csv", rankings_container, year, month)
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = path.join(tmp_dir, "ranking.csv")
        ranking.to_csv(tmp_path, index=False)

        # Upload file to Azure Blob (must happen before tmp_dir is removed).
        wasb_hook = WasbHook(wasb_conn_id)
        wasb_hook.load_file(
            tmp_path,
            container_name=rankings_container,
            blob_name=f"{year}/{month:02d}.csv",
        )
def _upload_ratings(wasb_conn_id, container, **context):
    """Fetch one month of ratings and upload them as ``{year}/{MM}.csv`` to Azure Blob.

    :param wasb_conn_id: Airflow connection id for the WasbHook.
    :param container: target blob container.
    :param context: Airflow task context; only ``execution_date`` is read.
    """
    run_date = context["execution_date"]
    year, month = run_date.year, run_date.month

    logging.info(f"Fetching ratings for {year}/{month:02d}")
    ratings = fetch_ratings(year=year, month=month)
    logging.info(f"Fetched {ratings.shape[0]} rows")

    with tempfile.TemporaryDirectory() as scratch_dir:
        # Stage the CSV locally; the upload must happen before scratch_dir is removed.
        csv_path = path.join(scratch_dir, "ratings.csv")
        ratings.to_csv(csv_path, index=False)

        logging.info(f"Writing results to {container}/{year}/{month:02d}.csv")
        WasbHook(wasb_conn_id).load_file(
            csv_path,
            container_name=container,
            blob_name=f"{year}/{month:02d}.csv",
        )
def hook(self):
    """
    Returns WasbHook.

    Builds a WasbHook for the connection id configured under
    ``[logging] REMOTE_LOG_CONN_ID``. If the Azure provider cannot be imported
    or the hook cannot be created, the error is logged (with traceback) and
    ``None`` is returned.
    """
    remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID')
    try:
        # Imported here rather than at module level; a missing Azure provider
        # surfaces as ImportError, which the handler below now also covers.
        from airflow.providers.microsoft.azure.hooks.wasb import WasbHook

        return WasbHook(remote_conn_id)
    except (ImportError, AzureHttpError):
        self.log.exception(
            'Could not create an WasbHook with connection id "%s". '
            'Please make sure that airflow[azure] is installed and '
            'the Wasb connection exists.',
            remote_conn_id,
        )
        return None
def poke(self, context: dict) -> bool:
    """Return the result of ``WasbHook.check_for_prefix`` for ``self.prefix``.

    ``self.check_options`` is forwarded as keyword arguments to the hook call.
    """
    container, prefix = self.container_name, self.prefix
    self.log.info('Poking for prefix: %s in wasb://%s', prefix, container)
    wasb = WasbHook(wasb_conn_id=self.wasb_conn_id)
    return wasb.check_for_prefix(container, prefix, **self.check_options)
def test_check_for_blob_empty(self, mock_service):
    """check_for_blob is falsy when the mocked service's exists() returns False."""
    service = mock_service.return_value
    service.exists.return_value = False

    wasb = WasbHook(wasb_conn_id='wasb_test_sas_token')
    found = wasb.check_for_blob('container', 'blob')
    self.assertFalse(found)
def test_delete_single_blob(self, mock_service):
    """Non-prefix deletion issues exactly one delete_blob call, including snapshots."""
    service = mock_service.return_value
    wasb = WasbHook(wasb_conn_id='wasb_test_sas_token')

    wasb.delete_file('container', 'blob', is_prefix=False)

    service.delete_blob.assert_called_once_with(
        'container', 'blob', delete_snapshots='include'
    )
def test_load_string(self, mock_service):
    """load_string forwards its arguments and kwargs to create_blob_from_text."""
    service = mock_service.return_value
    wasb = WasbHook(wasb_conn_id='wasb_test_sas_token')

    wasb.load_string('big string', 'container', 'blob', max_connections=1)

    service.create_blob_from_text.assert_called_once_with(
        'container', 'blob', 'big string', max_connections=1
    )
def test_public_read(self):
    """A hook created with public_read=True still yields a BlobServiceClient."""
    wasb = WasbHook(wasb_conn_id=self.public_read_conn_id, public_read=True)
    client = wasb.get_conn()
    assert isinstance(client, BlobServiceClient)
def test_delete_multiple_nonexisting_blobs_fails(self, mock_getblobs):
    """Prefix deletion with no matching blobs raises AirflowException when not ignoring misses."""
    mock_getblobs.return_value = []
    # Build the hook outside pytest.raises so that a failure during construction
    # cannot be mistaken for the delete_file error this test is about.
    hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
    with pytest.raises(AirflowException):
        hook.delete_file(
            'container', 'nonexisting_blob_prefix', is_prefix=True, ignore_if_missing=False
        )
def test_sas_token_connection(self):
    """A SAS-token connection produces a BlobServiceClient."""
    client = WasbHook(wasb_conn_id=self.sas_conn_id).get_conn()
    assert isinstance(client, BlobServiceClient)
def get_hook(self):
    """Return a hook instance matching this connection's ``conn_type``.

    The hook class is imported lazily, so only the provider package for the
    requested connection type has to be importable.

    :raises AirflowException: if ``conn_type`` has no registered hook.
    """
    from importlib import import_module

    # conn_type -> (module path, hook class name, keyword used to pass conn_id)
    hook_registry = {
        'mysql': ('airflow.providers.mysql.hooks.mysql', 'MySqlHook', 'mysql_conn_id'),
        'google_cloud_platform': ('airflow.gcp.hooks.bigquery', 'BigQueryHook', 'bigquery_conn_id'),
        'postgres': ('airflow.providers.postgres.hooks.postgres', 'PostgresHook', 'postgres_conn_id'),
        'pig_cli': ('airflow.providers.apache.pig.hooks.pig', 'PigCliHook', 'pig_cli_conn_id'),
        'hive_cli': ('airflow.providers.apache.hive.hooks.hive', 'HiveCliHook', 'hive_cli_conn_id'),
        'presto': ('airflow.providers.presto.hooks.presto', 'PrestoHook', 'presto_conn_id'),
        'hiveserver2': ('airflow.providers.apache.hive.hooks.hive', 'HiveServer2Hook', 'hiveserver2_conn_id'),
        'sqlite': ('airflow.providers.sqlite.hooks.sqlite', 'SqliteHook', 'sqlite_conn_id'),
        'jdbc': ('airflow.providers.jdbc.hooks.jdbc', 'JdbcHook', 'jdbc_conn_id'),
        'mssql': ('airflow.providers.microsoft.mssql.hooks.mssql', 'MsSqlHook', 'mssql_conn_id'),
        'odbc': ('airflow.providers.odbc.hooks.odbc', 'OdbcHook', 'odbc_conn_id'),
        'oracle': ('airflow.providers.oracle.hooks.oracle', 'OracleHook', 'oracle_conn_id'),
        'vertica': ('airflow.providers.vertica.hooks.vertica', 'VerticaHook', 'vertica_conn_id'),
        'cloudant': ('airflow.providers.cloudant.hooks.cloudant', 'CloudantHook', 'cloudant_conn_id'),
        'jira': ('airflow.providers.jira.hooks.jira', 'JiraHook', 'jira_conn_id'),
        'redis': ('airflow.providers.redis.hooks.redis', 'RedisHook', 'redis_conn_id'),
        'wasb': ('airflow.providers.microsoft.azure.hooks.wasb', 'WasbHook', 'wasb_conn_id'),
        'docker': ('airflow.providers.docker.hooks.docker', 'DockerHook', 'docker_conn_id'),
        'azure_data_lake': (
            'airflow.providers.microsoft.azure.hooks.azure_data_lake',
            'AzureDataLakeHook',
            'azure_data_lake_conn_id',
        ),
        'azure_cosmos': (
            'airflow.providers.microsoft.azure.hooks.azure_cosmos',
            'AzureCosmosDBHook',
            'azure_cosmos_conn_id',
        ),
        'cassandra': (
            'airflow.providers.apache.cassandra.hooks.cassandra',
            'CassandraHook',
            'cassandra_conn_id',
        ),
        # MongoHook takes the generic ``conn_id`` keyword rather than a prefixed one.
        'mongo': ('airflow.providers.mongo.hooks.mongo', 'MongoHook', 'conn_id'),
        'gcpcloudsql': ('airflow.gcp.hooks.cloud_sql', 'CloudSQLDatabaseHook', 'gcp_cloudsql_conn_id'),
        'grpc': ('airflow.providers.grpc.hooks.grpc', 'GrpcHook', 'grpc_conn_id'),
    }

    try:
        module_path, class_name, conn_id_kwarg = hook_registry[self.conn_type]
    except KeyError:
        raise AirflowException("Unknown hook type {}".format(self.conn_type)) from None

    hook_class = getattr(import_module(module_path), class_name)
    return hook_class(**{conn_id_kwarg: self.conn_id})
def test_check_for_prefix(self, mock_service):
    """check_for_prefix is truthy when list_blobs yields a blob; listing asks for one result."""
    service = mock_service.return_value
    service.list_blobs.return_value = iter(['blob_1'])

    wasb = WasbHook(wasb_conn_id='wasb_test_sas_token')
    found = wasb.check_for_prefix('container', 'prefix', timeout=3)

    self.assertTrue(found)
    service.list_blobs.assert_called_once_with(
        'container', 'prefix', num_results=1, timeout=3
    )
def test_delete_container(self, mock_service):
    """delete_container acquires the named container client and deletes through it."""
    get_container_client = mock_service.return_value.get_container_client

    WasbHook(wasb_conn_id=self.shared_key_conn_id).delete_container('mycontainer')

    get_container_client.assert_called_once_with('mycontainer')
    get_container_client.return_value.delete_container.assert_called()
def test_delete_single_blob(self, delete_blobs, mock_service):
    """Non-prefix deletion passes just the one blob name to delete_blobs."""
    wasb = WasbHook(wasb_conn_id=self.shared_key_conn_id)
    wasb.delete_file('container', 'blob', is_prefix=False)

    delete_blobs.assert_called_once_with('container', 'blob')
def test_check_for_prefix_empty(self, get_blobs_list):
    """check_for_prefix is falsy when get_blobs_list returns no blobs."""
    get_blobs_list.return_value = []

    wasb = WasbHook(wasb_conn_id=self.shared_key_conn_id)
    found = wasb.check_for_prefix('container', 'prefix', timeout=3)

    assert not found
    get_blobs_list.assert_called_once_with(
        container_name='container', prefix='prefix', timeout=3
    )
def test_key(self):
    """A key-based connection keeps its conn_id and exposes a BlobServiceClient."""
    conn_id = 'wasb_test_key'
    wasb = WasbHook(wasb_conn_id=conn_id)

    self.assertEqual(wasb.conn_id, conn_id)
    self.assertIsInstance(wasb.connection, BlobServiceClient)
def test_load_file(self, mock_upload):
    """load_file triggers an upload; builtins.open is mocked so no real file is read."""
    fake_open = mock.mock_open(read_data="data")
    with mock.patch("builtins.open", fake_open):
        wasb = WasbHook(wasb_conn_id=self.shared_key_conn_id)
        wasb.load_file('path', 'container', 'blob', max_connections=1)
        mock_upload.assert_called()
def test_connection_string(self):
    """A connection-string connection keeps its conn_id and yields a BlobServiceClient."""
    conn_id = self.connection_string_id
    wasb = WasbHook(wasb_conn_id=conn_id)

    assert wasb.conn_id == conn_id
    assert isinstance(wasb.get_conn(), BlobServiceClient)
def test_load_string(self, mock_upload):
    """load_string delegates to upload with the same positional args and kwargs."""
    wasb = WasbHook(wasb_conn_id=self.shared_key_conn_id)
    wasb.load_string('big string', 'container', 'blob', max_connections=1)

    mock_upload.assert_called_once_with(
        'container', 'blob', 'big string', max_connections=1
    )
def test_read_file(self, mock_service):
    """read_file forwards its arguments to get_blob_to_text."""
    service = mock_service.return_value
    wasb = WasbHook(wasb_conn_id='wasb_test_sas_token')

    wasb.read_file('container', 'blob', max_connections=1)

    service.get_blob_to_text.assert_called_once_with(
        'container', 'blob', max_connections=1
    )
def test_get_file(self, mock_download):
    """get_file downloads the blob by keyword and reads all of the downloaded content."""
    fake_open = mock.mock_open(read_data="data")
    with mock.patch("builtins.open", fake_open):
        wasb = WasbHook(wasb_conn_id=self.shared_key_conn_id)
        wasb.get_file('path', 'container', 'blob', max_connections=1)

    mock_download.assert_called_once_with(
        container_name='container', blob_name='blob', max_connections=1
    )
    mock_download.return_value.readall.assert_called()
def test_sas_token(self):
    """A SAS-token connection keeps its conn_id and exposes a BlockBlobService."""
    from azure.storage.blob import BlockBlobService

    conn_id = 'wasb_test_sas_token'
    wasb = WasbHook(wasb_conn_id=conn_id)

    self.assertEqual(wasb.conn_id, conn_id)
    self.assertIsInstance(wasb.connection, BlockBlobService)
def test_read_file(self, mock_download, mock_service):
    """read_file delegates to download with the same positional args and kwargs."""
    wasb = WasbHook(wasb_conn_id=self.shared_key_conn_id)
    wasb.read_file('container', 'blob', max_connections=1)

    mock_download.assert_called_once_with('container', 'blob', max_connections=1)
def test_check_for_prefix_empty(self, mock_service):
    """check_for_prefix is falsy when list_blobs yields nothing."""
    service = mock_service.return_value
    service.list_blobs.return_value = iter([])

    wasb = WasbHook(wasb_conn_id='wasb_test_sas_token')
    found = wasb.check_for_prefix('container', 'prefix')
    self.assertFalse(found)
def test_get_blob_client(self, mock_service):
    """_get_blob_client resolves the container client first, then the blob client."""
    wasb = WasbHook(wasb_conn_id=self.shared_key_conn_id)
    wasb._get_blob_client(container_name='mycontainer', blob_name='myblob')

    get_container_client = mock_service.return_value.get_container_client
    get_container_client.assert_called_once_with('mycontainer')
    container_client = get_container_client.return_value
    container_client.get_blob_client.assert_called_once_with('myblob')
def poke(self, context: dict):
    """Return the result of ``WasbHook.check_for_blob`` for ``self.blob_name``.

    ``self.check_options`` is forwarded as keyword arguments to the hook call.
    """
    container, blob = self.container_name, self.blob_name
    self.log.info('Poking for blob: %s\nin wasb://%s', blob, container)
    wasb = WasbHook(wasb_conn_id=self.wasb_conn_id)
    return wasb.check_for_blob(container, blob, **self.check_options)
def test_create_container(self, mock_service):
    """create_container acquires the named container client and creates through it."""
    wasb = WasbHook(wasb_conn_id=self.shared_key_conn_id)
    wasb.create_container(container_name='mycontainer')

    get_container_client = mock_service.return_value.get_container_client
    get_container_client.assert_called_once_with('mycontainer')
    get_container_client.return_value.create_container.assert_called()