Exemple #1
0
    def execute(self, context):
        """
        Copy one file, or every file matched by a single-wildcard pattern,
        from SFTP to Google Cloud Storage.

        When ``self.source_path`` contains the wildcard, the text before it is
        used as ``prefix`` and the text after it as ``delimiter`` for
        ``SFTPHook.get_tree_map``; every returned file is copied with the
        directory part of the prefix replaced by ``self.destination_path``.
        Otherwise the single file is copied to ``self.destination_path`` or,
        when that is empty, to the last path component of ``source_path``.

        :param context: task context supplied by the caller (not used here)
        :raises AirflowException: if ``source_path`` holds more than one
            wildcard
        """
        gcs_hook = GoogleCloudStorageHook(gcp_conn_id=self.gcp_conn_id,
                                          delegate_to=self.delegate_to)

        sftp_hook = SFTPHook(self.sftp_conn_id)

        if WILDCARD in self.source_path:
            total_wildcards = self.source_path.count(WILDCARD)
            if total_wildcards > 1:
                raise AirflowException(
                    "Only one wildcard '*' is allowed in source_path parameter. "
                    "Found {} in {}.".format(total_wildcards,
                                             self.source_path))

            prefix, delimiter = self.source_path.split(WILDCARD, 1)
            base_path = os.path.dirname(prefix)

            files, _, _ = sftp_hook.get_tree_map(base_path,
                                                 prefix=prefix,
                                                 delimiter=delimiter)

            for file in files:
                # Swap only the leading base directory for the destination.
                destination_path = file.replace(base_path,
                                                self.destination_path, 1)
                self._copy_single_object(gcs_hook, sftp_hook, file,
                                         destination_path)

        else:
            # [-1] rather than [1]: a source_path without any "/" yields a
            # one-element list, and the bare file name is then the object name.
            destination_object = (self.destination_path
                                  if self.destination_path else
                                  self.source_path.rsplit("/", 1)[-1])
            self._copy_single_object(gcs_hook, sftp_hook, self.source_path,
                                     destination_object)
Exemple #2
0
class SFTPSensor(BaseSensorOperator):
    """
    Waits for a file or directory to be present on SFTP.

    :param path: Remote file or directory path
    :type path: str
    :param sftp_conn_id: The connection to run the sensor against
    :type sftp_conn_id: str
    """
    template_fields = ('path', )

    @apply_defaults
    def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.path = path
        # The hook is created on every poke, not at construction time.
        self.hook = None
        self.sftp_conn_id = sftp_conn_id

    def poke(self, context):
        """
        Check once whether ``self.path`` exists on the SFTP server.

        :param context: task context supplied by the caller (not used here)
        :return: True when the modification time of the path could be read,
            False when the lookup failed with ``SFTP_NO_SUCH_FILE``
        :raises OSError: any other OS-level error is propagated unchanged
        """
        self.hook = SFTPHook(self.sftp_conn_id)
        self.log.info('Poking for %s', self.path)
        try:
            self.hook.get_mod_time(self.path)
        except OSError as e:
            if e.errno != SFTP_NO_SUCH_FILE:
                raise
            # A fresh hook is built on every poke, so release the connection
            # on the "not there yet" path too instead of leaking one per poke.
            # Guarded because the error may have occurred before any
            # connection was opened.
            if getattr(self.hook, 'conn', None) is not None:
                self.hook.close_conn()
            return False
        self.hook.close_conn()
        return True
Exemple #3
0
    def _copy_single_object(
        self,
        gcs_hook: GoogleCloudStorageHook,
        sftp_hook: SFTPHook,
        source_path: str,
        destination_object: str,
    ) -> None:
        """
        Transfer one SFTP file into the destination bucket.

        The file is first retrieved from SFTP into a local temporary file and
        then uploaded from there. When ``self.move_object`` is truthy the
        source file is deleted from SFTP afterwards.

        :param gcs_hook: hook used for the upload
        :param sftp_hook: hook used to retrieve (and optionally delete) the file
        :param source_path: path of the file on the SFTP server
        :param destination_object: object name to upload to
        """
        bucket = self.destination_bucket
        self.log.info(
            "Executing copy of %s to gs://%s/%s",
            source_path,
            bucket,
            destination_object,
        )

        # The temporary file only serves as a named staging location; it is
        # removed automatically when the ``with`` block exits.
        with NamedTemporaryFile("w") as staging:
            local_name = staging.name
            sftp_hook.retrieve_file(source_path, local_name)
            gcs_hook.upload(
                bucket_name=bucket,
                object_name=destination_object,
                filename=local_name,
                mime_type=self.mime_type,
            )

        if not self.move_object:
            return

        self.log.info("Executing delete of %s", source_path)
        sftp_hook.delete_file(source_path)
Exemple #4
0
    def _copy_single_object(
        self,
        gcs_hook: GCSHook,
        sftp_hook: SFTPHook,
        source_object: str,
        destination_path: str,
    ) -> None:
        """
        Transfer one object from the source bucket to a path on SFTP.

        The parent directory of ``destination_path`` is created on the SFTP
        side, the object is downloaded into a local temporary file and then
        stored on SFTP. When ``self.move_object`` is truthy the source object
        is deleted from the bucket afterwards.

        :param gcs_hook: hook used to download (and optionally delete) the object
        :param sftp_hook: hook used to create the directory and store the file
        :param source_object: name of the object in ``self.source_bucket``
        :param destination_path: full target path on the SFTP server
        """
        bucket = self.source_bucket
        self.log.info(
            "Executing copy of gs://%s/%s to %s",
            bucket,
            source_object,
            destination_path,
        )

        # Make sure the target directory exists before storing the file.
        sftp_hook.create_directory(os.path.dirname(destination_path))

        # The temporary file is a named staging location, removed on exit.
        with NamedTemporaryFile("w") as staging:
            local_name = staging.name
            gcs_hook.download(
                bucket_name=bucket,
                object_name=source_object,
                filename=local_name,
            )
            sftp_hook.store_file(destination_path, local_name)

        if not self.move_object:
            return

        self.log.info(
            "Executing delete of gs://%s/%s", bucket, source_object
        )
        gcs_hook.delete(bucket, source_object)
Exemple #5
0
    def setUp(self):
        """
        Prepare the connection login, the hook and the on-disk fixtures.

        Creates the nested test directory plus one test file directly under
        ``TMP_PATH`` and one inside the nested sub-directory.
        """
        self.old_login = self.update_connection(SFTP_CONNECTION_USER)
        self.hook = SFTPHook()

        nested_dir = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)
        os.makedirs(nested_dir)

        fixture_paths = (
            os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS),
            os.path.join(nested_dir, TMP_FILE_FOR_TESTS),
        )
        for fixture_path in fixture_paths:
            with open(fixture_path, 'a') as file:
                file.write('Test file')
Exemple #6
0
 def poke(self, context):
     """
     Check once whether ``self.path`` exists on the SFTP server.

     :param context: task context supplied by the caller (not used here)
     :return: True when the modification time of the path could be read,
         False when the lookup failed with ``SFTP_NO_SUCH_FILE``
     :raises OSError: any other OS-level error is propagated unchanged
     """
     self.hook = SFTPHook(self.sftp_conn_id)
     self.log.info('Poking for %s', self.path)
     try:
         self.hook.get_mod_time(self.path)
     except OSError as e:
         if e.errno != SFTP_NO_SUCH_FILE:
             raise
         # A fresh hook is built on every poke, so release the connection
         # on the "not there yet" path too instead of leaking one per poke.
         # Guarded because the error may have occurred before any
         # connection was opened.
         if getattr(self.hook, 'conn', None) is not None:
             self.hook.close_conn()
         return False
     self.hook.close_conn()
     return True
    def execute(self, context):
        """
        Copy one object, or every object matched by a single-wildcard
        pattern, from the source bucket to SFTP.

        Without a wildcard, ``self.source_object`` is copied to
        ``self.destination_path`` joined with the object name. With a
        wildcard, the text before it is used as ``prefix`` and the text after
        it as ``delimiter`` when listing the bucket, and each listed object is
        copied to ``self.destination_path`` joined with its name.

        :param context: task context supplied by the caller (not used here)
        :raises AirflowException: if ``source_object`` holds more than one
            wildcard
        """
        gcs_hook = GoogleCloudStorageHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
        )
        sftp_hook = SFTPHook(self.sftp_conn_id)

        # Single-object case first, so the wildcard handling reads flat.
        if WILDCARD not in self.source_object:
            target = os.path.join(self.destination_path, self.source_object)
            self._copy_single_object(gcs_hook, sftp_hook, self.source_object,
                                     target)
            self.log.info("Done. Uploaded '%s' file to %s", self.source_object,
                          target)
            return

        wildcard_count = self.source_object.count(WILDCARD)
        if wildcard_count > 1:
            raise AirflowException(
                "Only one wildcard '*' is allowed in source_object parameter. "
                "Found {} in {}.".format(wildcard_count, self.source_object))

        prefix, delimiter = self.source_object.split(WILDCARD, 1)
        matched = gcs_hook.list(self.source_bucket,
                                prefix=prefix,
                                delimiter=delimiter)

        for object_name in matched:
            target = os.path.join(self.destination_path, object_name)
            self._copy_single_object(gcs_hook, sftp_hook, object_name, target)

        self.log.info("Done. Uploaded '%d' files to %s", len(matched),
                      self.destination_path)
    def test_no_host_key_check_no_ignore(self, get_connection):
        # "ignore_hostkey_verification": false in the connection extras must
        # leave no_host_key_check at False.
        get_connection.return_value = Connection(
            login='******',
            host='host',
            extra='{"ignore_hostkey_verification": false}',
        )
        sftp_hook = SFTPHook()
        self.assertEqual(sftp_hook.no_host_key_check, False)
    def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
        # A non-boolean value ("foo") for "no_host_key_check" must leave
        # no_host_key_check at False.
        get_connection.return_value = Connection(
            login='******',
            host='host',
            extra='{"no_host_key_check": "foo"}',
        )
        sftp_hook = SFTPHook()
        self.assertEqual(sftp_hook.no_host_key_check, False)
    def test_no_host_key_check_enabled(self, get_connection):
        # "no_host_key_check": true in the connection extras must turn
        # no_host_key_check on.
        get_connection.return_value = Connection(
            login='******',
            host='host',
            extra='{"no_host_key_check": true}',
        )
        sftp_hook = SFTPHook()
        self.assertEqual(sftp_hook.no_host_key_check, True)
class TestSFTPHook(unittest.TestCase):
    """
    Tests for ``SFTPHook`` run against the ``sftp_default`` connection.

    Each test starts from a fixture tree created in ``setUp`` under
    ``TMP_PATH`` and removed again in ``tearDown``.
    """

    @provide_session
    def update_connection(self, login, session=None):
        """
        Set the login of the ``sftp_default`` connection.

        :param login: login to store on the connection
        :param session: database session, injected by ``provide_session``
        :return: the login the connection had before the change
        """
        sftp_default = session.query(Connection).filter(
            Connection.conn_id == "sftp_default").first()
        previous_login = sftp_default.login
        sftp_default.login = login
        session.commit()
        return previous_login

    def setUp(self):
        """Switch the connection login, build the hook and create fixtures."""
        self.old_login = self.update_connection(SFTP_CONNECTION_USER)
        self.hook = SFTPHook()

        nested_dir = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)
        os.makedirs(nested_dir)

        fixture_paths = (
            os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS),
            os.path.join(nested_dir, TMP_FILE_FOR_TESTS),
        )
        for fixture_path in fixture_paths:
            with open(fixture_path, 'a') as file:
                file.write('Test file')

    def test_get_conn(self):
        # get_conn must hand back exactly a pysftp.Connection.
        conn = self.hook.get_conn()
        self.assertEqual(type(conn), pysftp.Connection)

    def test_close_conn(self):
        # close_conn must reset the stored connection to None.
        self.hook.conn = self.hook.get_conn()
        self.assertIsNotNone(self.hook.conn)
        self.hook.close_conn()
        self.assertIsNone(self.hook.conn)

    def test_describe_directory(self):
        description = self.hook.describe_directory(TMP_PATH)
        self.assertIn(TMP_DIR_FOR_TESTS, description)

    def test_list_directory(self):
        listing = self.hook.list_directory(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        self.assertEqual(listing, [SUB_DIR])

    def test_create_and_delete_directory(self):
        new_dir_name = 'new_dir'
        parent_dir = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)
        new_dir_path = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name)

        self.hook.create_directory(new_dir_path)
        self.assertIn(new_dir_name, self.hook.describe_directory(parent_dir))

        self.hook.delete_directory(new_dir_path)
        self.assertNotIn(new_dir_name,
                         self.hook.describe_directory(parent_dir))

    def test_create_and_delete_directories(self):
        base_dir = "base_dir"
        sub_dir = "sub_dir"
        nested = os.path.join(base_dir, sub_dir)

        root = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)
        base_full = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir)
        nested_full = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, nested)

        # Creating the nested path must create the intermediate directory too.
        self.hook.create_directory(nested_full)
        self.assertIn(base_dir, self.hook.describe_directory(root))
        self.assertIn(sub_dir, self.hook.describe_directory(base_full))

        # Remove innermost first, then its parent.
        self.hook.delete_directory(nested_full)
        self.hook.delete_directory(base_full)
        remaining = self.hook.describe_directory(root)
        self.assertNotIn(nested, remaining)
        self.assertNotIn(base_dir, remaining)

    def test_store_retrieve_and_delete_file(self):
        remote_dir = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)
        remote_file = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
                                   TMP_FILE_FOR_TESTS)
        retrieved_file_name = 'retrieved.txt'
        retrieved_path = os.path.join(TMP_PATH, retrieved_file_name)

        self.hook.store_file(
            remote_full_path=remote_file,
            local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
        self.assertEqual(self.hook.list_directory(path=remote_dir),
                         [SUB_DIR, TMP_FILE_FOR_TESTS])

        self.hook.retrieve_file(remote_full_path=remote_file,
                                local_full_path=retrieved_path)
        self.assertIn(retrieved_file_name, os.listdir(TMP_PATH))
        os.remove(retrieved_path)

        self.hook.delete_file(path=remote_file)
        self.assertEqual(self.hook.list_directory(path=remote_dir),
                         [SUB_DIR])

    def test_get_mod_time(self):
        remote_file = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS,
                                   TMP_FILE_FOR_TESTS)
        self.hook.store_file(
            remote_full_path=remote_file,
            local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
        mod_time = self.hook.get_mod_time(path=remote_file)
        # The test only pins the length of the returned value.
        self.assertEqual(len(mod_time), 14)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_default(self, get_connection):
        # No extras at all: host key checking stays enabled.
        get_connection.return_value = Connection(login='******', host='host')
        sftp_hook = SFTPHook()
        self.assertEqual(sftp_hook.no_host_key_check, False)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_enabled(self, get_connection):
        get_connection.return_value = Connection(
            login='******',
            host='host',
            extra='{"no_host_key_check": true}',
        )
        sftp_hook = SFTPHook()
        self.assertEqual(sftp_hook.no_host_key_check, True)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_disabled(self, get_connection):
        get_connection.return_value = Connection(
            login='******',
            host='host',
            extra='{"no_host_key_check": false}',
        )
        sftp_hook = SFTPHook()
        self.assertEqual(sftp_hook.no_host_key_check, False)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
        # Any value other than true (here the string "foo") counts as off.
        get_connection.return_value = Connection(
            login='******',
            host='host',
            extra='{"no_host_key_check": "foo"}',
        )
        sftp_hook = SFTPHook()
        self.assertEqual(sftp_hook.no_host_key_check, False)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_ignore(self, get_connection):
        # The "ignore_hostkey_verification" extra is honoured as well.
        get_connection.return_value = Connection(
            login='******',
            host='host',
            extra='{"ignore_hostkey_verification": true}',
        )
        sftp_hook = SFTPHook()
        self.assertEqual(sftp_hook.no_host_key_check, True)

    @mock.patch(
        'airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
    def test_no_host_key_check_no_ignore(self, get_connection):
        get_connection.return_value = Connection(
            login='******',
            host='host',
            extra='{"ignore_hostkey_verification": false}',
        )
        sftp_hook = SFTPHook()
        self.assertEqual(sftp_hook.no_host_key_check, False)

    @parameterized.expand([
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
        (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True),
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False),
        (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False),
    ])
    def test_path_exists(self, path, exists):
        # Existing directory and file are found; the two "abc" variants
        # are not.
        self.assertEqual(self.hook.path_exists(path), exists)

    @parameterized.expand([
        ("test/path/file.bin", None, None, True),
        ("test/path/file.bin", "test", None, True),
        ("test/path/file.bin", "test/", None, True),
        ("test/path/file.bin", None, "bin", True),
        ("test/path/file.bin", "test", "bin", True),
        ("test/path/file.bin", "test/", "file.bin", True),
        ("test/path/file.bin", None, "file.bin", True),
        ("test/path/file.bin", "diff", None, False),
        ("test/path/file.bin", "test//", None, False),
        ("test/path/file.bin", None, ".txt", False),
        ("test/path/file.bin", "diff", ".txt", False),
    ])
    def test_path_match(self, path, prefix, delimiter, match):
        # Each row is (path, prefix, delimiter, expected match result).
        outcome = self.hook._is_path_match(path=path,
                                           prefix=prefix,
                                           delimiter=delimiter)
        self.assertEqual(outcome, match)

    def test_get_tree_map(self):
        files, dirs, unknowns = self.hook.get_tree_map(
            path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))

        sub_dir_path = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)
        nested_file = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR,
                                   TMP_FILE_FOR_TESTS)

        self.assertEqual(files, [nested_file])
        self.assertEqual(dirs, [sub_dir_path])
        self.assertEqual(unknowns, [])

    def tearDown(self):
        """Remove the fixtures and restore the original connection login."""
        shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
        os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
        self.update_connection(self.old_login)
 def test_no_host_key_check_default(self, get_connection):
     # No extras at all: no_host_key_check must default to False.
     get_connection.return_value = Connection(login='******', host='host')
     sftp_hook = SFTPHook()
     self.assertEqual(sftp_hook.no_host_key_check, False)