def execute(self, context):
    """
    Copy a single file, or all files matching a one-wildcard pattern, from SFTP to GCS.

    When ``source_path`` contains the wildcard, the text before it is used as
    a prefix and the text after it as a delimiter when walking the remote
    tree; every matching file is copied with ``base_path`` replaced by
    ``destination_path``. Otherwise the single file is copied to
    ``destination_path`` or, when that is empty, to an object named after
    the last component of ``source_path``.

    :param context: Task execution context (not used by this method).
    :raises AirflowException: If ``source_path`` contains more than one wildcard.
    """
    gcs_hook = GoogleCloudStorageHook(gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to)
    sftp_hook = SFTPHook(self.sftp_conn_id)

    if WILDCARD in self.source_path:
        total_wildcards = self.source_path.count(WILDCARD)
        if total_wildcards > 1:
            raise AirflowException(
                "Only one wildcard '*' is allowed in source_path parameter. "
                "Found {} in {}.".format(total_wildcards, self.source_path))

        prefix, delimiter = self.source_path.split(WILDCARD, 1)
        base_path = os.path.dirname(prefix)

        files, _, _ = sftp_hook.get_tree_map(base_path, prefix=prefix, delimiter=delimiter)

        for file_path in files:
            destination_path = file_path.replace(base_path, self.destination_path, 1)
            self._copy_single_object(gcs_hook, sftp_hook, file_path, destination_path)
    else:
        # [-1] rather than [1]: a source path without any "/" yields a
        # one-element list, and indexing [1] would raise IndexError.
        destination_object = (self.destination_path if self.destination_path
                              else self.source_path.rsplit("/", 1)[-1])
        self._copy_single_object(gcs_hook, sftp_hook, self.source_path, destination_object)
class SFTPSensor(BaseSensorOperator):
    """
    Waits for a file or directory to be present on SFTP.

    :param path: Remote file or directory path
    :type path: str
    :param sftp_conn_id: The connection to run the sensor against
    :type sftp_conn_id: str
    """
    template_fields = ('path', )

    @apply_defaults
    def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.path = path
        self.hook = None
        self.sftp_conn_id = sftp_conn_id

    def poke(self, context):
        """
        Check once whether ``self.path`` exists on the SFTP server.

        :param context: Task execution context (not used by this method).
        :return: ``True`` if the modification time of the path could be read,
            ``False`` if the server reported that the path does not exist.
        :raises OSError: For any SFTP error other than "no such file".
        """
        self.hook = SFTPHook(self.sftp_conn_id)
        self.log.info('Poking for %s', self.path)
        try:
            self.hook.get_mod_time(self.path)
        except OSError as e:
            if e.errno != SFTP_NO_SUCH_FILE:
                raise
            return False
        finally:
            # Close on every exit path; previously the connection was only
            # closed on success, so each unsuccessful poke leaked one.
            # NOTE(review): assumes close_conn() tolerates a connection that
            # was never opened - confirm against SFTPHook.
            self.hook.close_conn()
        return True
def _copy_single_object(
    self,
    gcs_hook: GoogleCloudStorageHook,
    sftp_hook: SFTPHook,
    source_path: str,
    destination_object: str,
) -> None:
    """
    Transfer one SFTP file into the destination bucket via a local temp file.

    :param gcs_hook: Hook used to upload to Google Cloud Storage.
    :param sftp_hook: Hook used to download (and optionally delete) the source.
    :param source_path: Path of the file on the SFTP server.
    :param destination_object: Name of the object to create in the bucket.
    """
    self.log.info(
        "Executing copy of %s to gs://%s/%s",
        source_path,
        self.destination_bucket,
        destination_object,
    )

    # The temp file only provides a path for staging; both hooks work on the
    # file by name, and it is removed when the ``with`` block exits.
    with NamedTemporaryFile("w") as staging_file:
        local_path = staging_file.name
        sftp_hook.retrieve_file(source_path, local_path)
        gcs_hook.upload(
            bucket_name=self.destination_bucket,
            object_name=destination_object,
            filename=local_path,
            mime_type=self.mime_type,
        )

    if not self.move_object:
        return

    # "Move" semantics: drop the source once the upload has gone through.
    self.log.info("Executing delete of %s", source_path)
    sftp_hook.delete_file(source_path)
def _copy_single_object(
    self,
    gcs_hook: GCSHook,
    sftp_hook: SFTPHook,
    source_object: str,
    destination_path: str,
) -> None:
    """
    Transfer one object from the source bucket to SFTP via a local temp file.

    :param gcs_hook: Hook used to download (and optionally delete) the object.
    :param sftp_hook: Hook used to create the target directory and store the file.
    :param source_object: Name of the object in the source bucket.
    :param destination_path: Full path of the file to write on the SFTP server.
    """
    self.log.info(
        "Executing copy of gs://%s/%s to %s",
        self.source_bucket,
        source_object,
        destination_path,
    )

    # Create the parent directory of the target before writing into it.
    sftp_hook.create_directory(os.path.dirname(destination_path))

    with NamedTemporaryFile("w") as staging_file:
        local_path = staging_file.name
        gcs_hook.download(
            bucket_name=self.source_bucket,
            object_name=source_object,
            filename=local_path,
        )
        sftp_hook.store_file(destination_path, local_path)

    if not self.move_object:
        return

    # "Move" semantics: remove the source object once it has been stored.
    self.log.info(
        "Executing delete of gs://%s/%s", self.source_bucket, source_object
    )
    gcs_hook.delete(self.source_bucket, source_object)
def setUp(self):
    """Switch the connection login, build a hook and create the fixture tree."""
    # Remember the previous login so it can be restored later.
    self.old_login = self.update_connection(SFTP_CONNECTION_USER)
    self.hook = SFTPHook()

    os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))

    # One fixture file directly under TMP_PATH, one inside the nested sub-directory.
    fixture_files = (
        os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS),
        os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS),
    )
    for fixture_path in fixture_files:
        with open(fixture_path, 'a') as handle:
            handle.write('Test file')
def poke(self, context):
    """
    Check once whether ``self.path`` exists on the SFTP server.

    :param context: Task execution context (not used by this method).
    :return: ``True`` if the modification time of the path could be read,
        ``False`` if the server reported that the path does not exist.
    :raises OSError: For any SFTP error other than "no such file".
    """
    self.hook = SFTPHook(self.sftp_conn_id)
    self.log.info('Poking for %s', self.path)
    try:
        self.hook.get_mod_time(self.path)
    except OSError as e:
        if e.errno != SFTP_NO_SUCH_FILE:
            raise
        return False
    finally:
        # Close on every exit path; previously the connection was only
        # closed on success, so each unsuccessful poke leaked one.
        # NOTE(review): assumes close_conn() tolerates a connection that
        # was never opened - confirm against SFTPHook.
        self.hook.close_conn()
    return True
def execute(self, context):
    """
    Copy one object, or every object matching a one-wildcard pattern, from GCS to SFTP.

    Each object is written below ``destination_path`` under its own object name.

    :param context: Task execution context (not used by this method).
    :raises AirflowException: If ``source_object`` contains more than one wildcard.
    """
    gcs_hook = GoogleCloudStorageHook(gcp_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to)
    sftp_hook = SFTPHook(self.sftp_conn_id)

    # Simple case first: no wildcard means exactly one object to transfer.
    if WILDCARD not in self.source_object:
        target = os.path.join(self.destination_path, self.source_object)
        self._copy_single_object(gcs_hook, sftp_hook, self.source_object, target)
        self.log.info("Done. Uploaded '%s' file to %s", self.source_object, target)
        return

    wildcard_count = self.source_object.count(WILDCARD)
    if wildcard_count > 1:
        raise AirflowException(
            "Only one wildcard '*' is allowed in source_object parameter. "
            "Found {} in {}.".format(wildcard_count, self.source_object))

    # Text before the wildcard becomes the listing prefix, text after it the delimiter.
    prefix, delimiter = self.source_object.split(WILDCARD, 1)
    objects = gcs_hook.list(self.source_bucket, prefix=prefix, delimiter=delimiter)

    for object_name in objects:
        target = os.path.join(self.destination_path, object_name)
        self._copy_single_object(gcs_hook, sftp_hook, object_name, target)

    self.log.info("Done. Uploaded '%d' files to %s", len(objects), self.destination_path)
def test_no_host_key_check_no_ignore(self, get_connection):
    """``ignore_hostkey_verification: false`` in extras leaves host-key checking on."""
    get_connection.return_value = Connection(
        login='******',
        host='host',
        extra='{"ignore_hostkey_verification": false}',
    )

    sftp_hook = SFTPHook()

    self.assertEqual(sftp_hook.no_host_key_check, False)
def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
    """A non-boolean ``no_host_key_check`` extra ("foo") is treated as disabled."""
    get_connection.return_value = Connection(
        login='******',
        host='host',
        extra='{"no_host_key_check": "foo"}',
    )

    sftp_hook = SFTPHook()

    self.assertEqual(sftp_hook.no_host_key_check, False)
def test_no_host_key_check_enabled(self, get_connection):
    """``no_host_key_check: true`` in extras turns the hook flag on."""
    get_connection.return_value = Connection(
        login='******',
        host='host',
        extra='{"no_host_key_check": true}',
    )

    sftp_hook = SFTPHook()

    self.assertEqual(sftp_hook.no_host_key_check, True)
class TestSFTPHook(unittest.TestCase):
    """
    Tests for ``SFTPHook`` using the ``sftp_default`` connection.

    ``setUp`` creates a small fixture tree under ``TMP_PATH`` on the local
    filesystem and ``tearDown`` removes it again.
    NOTE(review): the non-mocked tests appear to need an SFTP server that
    exposes that same filesystem - confirm against the test environment.
    """

    @provide_session
    def update_connection(self, login, session=None):
        """
        Set the login of the ``sftp_default`` connection and return the old one.

        :param login: Login to store on the connection.
        :param session: Database session, supplied by ``provide_session``.
        :return: The login the connection had before the update.
        """
        connection = (session.query(Connection).filter(
            Connection.conn_id == "sftp_default").first())
        old_login = connection.login
        connection.login = login
        session.commit()
        return old_login

    def setUp(self):
        """Switch the connection login, build a hook and create the fixture tree."""
        self.old_login = self.update_connection(SFTP_CONNECTION_USER)
        self.hook = SFTPHook()
        os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))

        with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
            file.write('Test file')
        with open(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR,
                         TMP_FILE_FOR_TESTS), 'a') as file:
            file.write('Test file')

    def test_get_conn(self):
        output = self.hook.get_conn()
        self.assertIsInstance(output, pysftp.Connection)

    def test_close_conn(self):
        self.hook.conn = self.hook.get_conn()
        self.assertIsNotNone(self.hook.conn)
        self.hook.close_conn()
        self.assertIsNone(self.hook.conn)

    def test_describe_directory(self):
        output = self.hook.describe_directory(TMP_PATH)
        self.assertIn(TMP_DIR_FOR_TESTS, output)

    def test_list_directory(self):
        output = self.hook.list_directory(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertEqual(output, [SUB_DIR])

    def test_create_and_delete_directory(self):
        new_dir_name = 'new_dir'
        self.hook.create_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertIn(new_dir_name, output)
        self.hook.delete_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertNotIn(new_dir_name, output)

    def test_create_and_delete_directories(self):
        base_dir = "base_dir"
        sub_dir = "sub_dir"
        new_dir_path = os.path.join(base_dir, sub_dir)
        # Creating the nested path is expected to create base_dir as well.
        self.hook.create_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertIn(base_dir, output)
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
        self.assertIn(sub_dir, output)
        # Delete innermost directory first, then its parent.
        self.hook.delete_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
        self.hook.delete_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
        output = self.hook.describe_directory(
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertNotIn(new_dir_path, output)
        self.assertNotIn(base_dir, output)

    def test_store_retrieve_and_delete_file(self):
        self.hook.store_file(
            remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
                                          TMP_FILE_FOR_TESTS),
            local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
        output = self.hook.list_directory(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS])
        retrieved_file_name = 'retrieved.txt'
        self.hook.retrieve_file(
            remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
                                          TMP_FILE_FOR_TESTS),
            local_full_path=os.path.join(TMP_PATH, retrieved_file_name))
        self.assertIn(retrieved_file_name, os.listdir(TMP_PATH))
        os.remove(os.path.join(TMP_PATH, retrieved_file_name))
        self.hook.delete_file(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
        output = self.hook.list_directory(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertEqual(output, [SUB_DIR])

    def test_get_mod_time(self):
        self.hook.store_file(
            remote_full_path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
                                          TMP_FILE_FOR_TESTS),
            local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
        output = self.hook.get_mod_time(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
        # Only the length of the returned timestamp string is checked here.
        self.assertEqual(len(output), 14)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_default(self, get_connection):
        connection = Connection(login='******', host='host')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_enabled(self, get_connection):
        connection = Connection(login='******',
                                host='host',
                                extra='{"no_host_key_check": true}')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, True)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_disabled(self, get_connection):
        connection = Connection(login='******',
                                host='host',
                                extra='{"no_host_key_check": false}')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
        connection = Connection(login='******',
                                host='host',
                                extra='{"no_host_key_check": "foo"}')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_ignore(self, get_connection):
        connection = Connection(login='******',
                                host='host',
                                extra='{"ignore_hostkey_verification": true}')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, True)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_no_ignore(self, get_connection):
        connection = Connection(login='******',
                                host='host',
                                extra='{"ignore_hostkey_verification": false}')
        get_connection.return_value = connection
        hook = SFTPHook()
        self.assertEqual(hook.no_host_key_check, False)

    @parameterized.expand([
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
        (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True),
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False),
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False),
    ])
    def test_path_exists(self, path, exists):
        result = self.hook.path_exists(path)
        self.assertEqual(result, exists)

    @parameterized.expand([
        ("test/path/file.bin", None, None, True),
        ("test/path/file.bin", "test", None, True),
        ("test/path/file.bin", "test/", None, True),
        ("test/path/file.bin", None, "bin", True),
        ("test/path/file.bin", "test", "bin", True),
        ("test/path/file.bin", "test/", "file.bin", True),
        ("test/path/file.bin", None, "file.bin", True),
        ("test/path/file.bin", "diff", None, False),
        ("test/path/file.bin", "test//", None, False),
        ("test/path/file.bin", None, ".txt", False),
        ("test/path/file.bin", "diff", ".txt", False),
    ])
    def test_path_match(self, path, prefix, delimiter, match):
        result = self.hook._is_path_match(path=path,
                                          prefix=prefix,
                                          delimiter=delimiter)
        self.assertEqual(result, match)

    def test_get_tree_map(self):
        tree_map = self.hook.get_tree_map(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        files, dirs, unknowns = tree_map

        self.assertEqual(files, [
            os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR,
                         TMP_FILE_FOR_TESTS)
        ])
        self.assertEqual(dirs,
                         [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)])
        self.assertEqual(unknowns, [])

    def tearDown(self):
        """Remove the fixture tree and restore the original connection login."""
        shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
        self.update_connection(self.old_login)
def test_no_host_key_check_default(self, get_connection):
    """Without any extras the hook keeps host-key checking on."""
    get_connection.return_value = Connection(login='******', host='host')

    sftp_hook = SFTPHook()

    self.assertEqual(sftp_hook.no_host_key_check, False)