Example #1
    def test_read_key(self):
        hook = S3Hook(aws_conn_id=None)
        conn = hook.get_conn()
        # We need to create the bucket since this is all in Moto's 'virtual'
        # AWS account
        conn.create_bucket(Bucket='mybucket')
        conn.put_object(Bucket='mybucket',
                        Key='my_key',
                        Body=b'Cont\xC3\xA9nt')

        self.assertEqual(hook.read_key('my_key', 'mybucket'), 'Contént')
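Note: the tests in these examples assume that moto's S3 mock is active (the comments refer to Moto's "virtual" AWS account), which the excerpts do not show. A minimal, self-contained sketch of the same check under that assumption (import paths vary by Airflow version):
    # Hedged, standalone sketch: moto's mock_s3 keeps every S3 call in memory,
    # so the bucket has to be created before it can be read back.
    from moto import mock_s3
    from airflow.hooks.S3_hook import S3Hook  # path differs in newer Airflow versions

    @mock_s3
    def test_read_key_standalone():
        hook = S3Hook(aws_conn_id=None)
        conn = hook.get_conn()
        conn.create_bucket(Bucket='mybucket')
        conn.put_object(Bucket='mybucket', Key='my_key', Body=b'Cont\xC3\xA9nt')
        assert hook.read_key('my_key', 'mybucket') == 'Contént'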
Example #2
    def test_load_bytes(self):
        hook = S3Hook(aws_conn_id=None)
        conn = hook.get_conn()
        # We need to create the bucket since this is all in Moto's 'virtual'
        # AWS account
        conn.create_bucket(Bucket="mybucket")

        hook.load_bytes(b"Content", "my_key", "mybucket")
        body = boto3.resource('s3').Object('mybucket',
                                           'my_key').get()['Body'].read()

        self.assertEqual(body, b'Content')
Example #3
    def test_list_keys(self):
        hook = S3Hook(aws_conn_id=None)
        bucket = hook.get_bucket('bucket')
        bucket.create()
        bucket.put_object(Key='a', Body=b'a')
        bucket.put_object(Key='dir/b', Body=b'b')

        self.assertIsNone(hook.list_keys('bucket', prefix='non-existent/'))
        self.assertListEqual(['a', 'dir/b'], hook.list_keys('bucket'))
        self.assertListEqual(['a'], hook.list_keys('bucket', delimiter='/'))
        self.assertListEqual(['dir/b'], hook.list_keys('bucket',
                                                       prefix='dir/'))
Example #4
    def test_list_keys_paged(self):
        hook = S3Hook(aws_conn_id=None)
        bucket = hook.get_bucket('bucket')
        bucket.create()

        keys = [str(i) for i in range(2)]
        for key in keys:
            bucket.put_object(Key=key, Body=b'a')

        self.assertListEqual(
            sorted(keys),
            sorted(hook.list_keys('bucket', delimiter='/', page_size=1)))
Example #5
    def execute(self, context):
        self.s3_key = self.get_s3_key(self.s3_key)
        ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
        s3_hook = S3Hook(self.s3_conn_id)

        sftp_client = ssh_hook.get_conn().open_sftp()

        with NamedTemporaryFile("w") as f:
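            # Note: SFTPClient.get() downloads the remote file to the local
            # path f.name; NamedTemporaryFile here only supplies that path.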
            sftp_client.get(self.sftp_path, f.name)

            s3_hook.load_file(filename=f.name,
                              key=self.s3_key,
                              bucket_name=self.s3_bucket,
                              replace=True)
Example #6
    def test_load_fileobj(self):
        hook = S3Hook(aws_conn_id=None)
        conn = hook.get_conn()
        # We need to create the bucket since this is all in Moto's 'virtual'
        # AWS account
        conn.create_bucket(Bucket="mybucket")
        with tempfile.TemporaryFile() as temp_file:
            temp_file.write(b"Content")
            temp_file.seek(0)
            hook.load_file_obj(temp_file, "my_key", "mybucket")
            body = boto3.resource('s3').Object('mybucket',
                                               'my_key').get()['Body'].read()

            self.assertEqual(body, b'Content')
Example #7
    def test_list_prefixes_paged(self):
        hook = S3Hook(aws_conn_id=None)
        bucket = hook.get_bucket('bucket')
        bucket.create()

        # we don't need to test the paginator
        # that's covered by boto tests
        keys = ["%s/b" % i for i in range(2)]
        dirs = ["%s/" % i for i in range(2)]
        for key in keys:
            bucket.put_object(Key=key, Body=b'a')

        self.assertListEqual(
            sorted(dirs),
            sorted(hook.list_prefixes('bucket', delimiter='/', page_size=1)))
Example #8
    def test_check_for_prefix(self):
        hook = S3Hook(aws_conn_id=None)
        bucket = hook.get_bucket('bucket')
        bucket.create()
        bucket.put_object(Key='a', Body=b'a')
        bucket.put_object(Key='dir/b', Body=b'b')

        self.assertTrue(
            hook.check_for_prefix(bucket_name='bucket',
                                  prefix='dir/',
                                  delimiter='/'))
        self.assertFalse(
            hook.check_for_prefix(bucket_name='bucket',
                                  prefix='a',
                                  delimiter='/'))
Example #9
    def test_check_for_wildcard_key(self):
        hook = S3Hook(aws_conn_id=None)
        bucket = hook.get_bucket('bucket')
        bucket.create()
        bucket.put_object(Key='abc', Body=b'a')
        bucket.put_object(Key='a/b', Body=b'a')

        self.assertTrue(hook.check_for_wildcard_key('a*', 'bucket'))
        self.assertTrue(hook.check_for_wildcard_key('s3://bucket//a*'))
        self.assertTrue(hook.check_for_wildcard_key('abc', 'bucket'))
        self.assertTrue(hook.check_for_wildcard_key('s3://bucket//abc'))
        self.assertFalse(hook.check_for_wildcard_key('a', 'bucket'))
        self.assertFalse(hook.check_for_wildcard_key('s3://bucket//a'))
        self.assertFalse(hook.check_for_wildcard_key('b', 'bucket'))
        self.assertFalse(hook.check_for_wildcard_key('s3://bucket//b'))
Example #10
    def setUp(self):
        hook = SSHHook(ssh_conn_id='ssh_default')
        s3_hook = S3Hook('aws_default')
        hook.no_host_key_check = True
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
        }
        dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
        dag.schedule_interval = '@once'

        self.hook = hook
        self.s3_hook = s3_hook

        self.ssh_client = self.hook.get_conn()
        self.sftp_client = self.ssh_client.open_sftp()

        self.dag = dag
        self.s3_bucket = BUCKET
        self.sftp_path = SFTP_PATH
        self.s3_key = S3_KEY
Example #11
    def check_s3_url(self, s3url):
        """
        Check if an S3 URL exists

        :param s3url: S3 URL
        :type s3url: str
        :rtype: bool
        """
        bucket, key = S3Hook.parse_s3_url(s3url)
        if not self.s3_hook.check_for_bucket(bucket_name=bucket):
            raise AirflowException(
                "The input S3 Bucket {} does not exist ".format(bucket))
        if key and not self.s3_hook.check_for_key(key=key, bucket_name=bucket)\
           and not self.s3_hook.check_for_prefix(
                prefix=key, bucket_name=bucket, delimiter='/'):
            # check if s3 key exists in the case user provides a single file
            # or if s3 prefix exists in the case user provides multiple files in
            # a prefix
            raise AirflowException(
                "The input S3 Key "
                "or Prefix {} does not exist in the Bucket {}".format(
                    s3url, bucket))
        return True
Example #12
    def execute(self, context):
        self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        credentials = self.s3.get_credentials()
        copy_options = '\n\t\t\t'.join(self.copy_options)

        copy_query = """
            COPY {schema}.{table}
            FROM 's3://{s3_bucket}/{s3_key}/{table}'
            with credentials
            'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
            {copy_options};
        """.format(schema=self.schema,
                   table=self.table,
                   s3_bucket=self.s3_bucket,
                   s3_key=self.s3_key,
                   access_key=credentials.access_key,
                   secret_key=credentials.secret_key,
                   copy_options=copy_options)

        self.log.info('Executing COPY command...')
        self.hook.run(copy_query, self.autocommit)
        self.log.info("COPY command complete...")
Example #13
    def test_execute(self, mock_hook, mock_hook2):
        mock_hook.return_value.list.return_value = MOCK_FILES
        mock_hook.return_value.download.return_value = b"testing"
        mock_hook2.return_value.list.return_value = MOCK_FILES

        operator = GoogleCloudStorageToS3Operator(task_id=TASK_ID,
                                                  bucket=GCS_BUCKET,
                                                  prefix=PREFIX,
                                                  delimiter=DELIMITER,
                                                  dest_aws_conn_id=None,
                                                  dest_s3_key=S3_BUCKET,
                                                  replace=False)
        # create dest bucket without files
        hook = S3Hook(aws_conn_id=None)
        b = hook.get_bucket('bucket')
        b.create()

        # we expect all MOCK_FILES to be uploaded
        # and all MOCK_FILES to be present at the S3 bucket
        uploaded_files = operator.execute(None)
        self.assertEqual(sorted(MOCK_FILES),
                         sorted(uploaded_files))
        self.assertEqual(sorted(MOCK_FILES),
                         sorted(hook.list_keys('bucket', delimiter='/')))
Example #14
    def execute(self, context):
        """
        This function executes the transfer from the email server (via IMAP) into S3.

        :param context: The context while executing.
        :type context: dict
        """
        self.log.info(
            'Transferring mail attachment %s from mail server via imap to s3 key %s...',
            self.imap_attachment_name, self.s3_key)

        with ImapHook(imap_conn_id=self.imap_conn_id) as imap_hook:
            imap_mail_attachments = imap_hook.retrieve_mail_attachments(
                name=self.imap_attachment_name,
                check_regex=self.imap_check_regex,
                latest_only=True,
                mail_folder=self.imap_mail_folder,
                mail_filter=self.imap_mail_filter,
            )

        s3_hook = S3Hook(aws_conn_id=self.s3_conn_id)
        s3_hook.load_bytes(bytes_data=imap_mail_attachments[0][1],
                           key=self.s3_key,
                           replace=self.s3_overwrite)
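The load_bytes call above passes no bucket_name, which suggests self.s3_key is expected to be a full s3:// URL that the hook splits into bucket and key itself. A hedged illustration of the equivalent call with the bucket made explicit (the URL and the attachment_payload variable are made up for the example):
    # Hypothetical explicit form; 'attachment_payload' stands in for
    # imap_mail_attachments[0][1] above.
    bucket, key = S3Hook.parse_s3_url('s3://example-bucket/attachments/report.csv')
    s3_hook.load_bytes(bytes_data=attachment_payload, key=key,
                       bucket_name=bucket, replace=True)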
Example #15
    def _load_data_to_s3(self, data):
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
        s3_hook.load_string(string_data=json.dumps(data),
                            key=self.s3_destination_key,
                            replace=self.s3_overwrite)
Example #16
    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}".format(
                    self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException("The key {0} does not exists".format(
                    self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)

        _, file_ext = os.path.splitext(s3_key_object.key)
        if (self.select_expression and self.input_compressed
                and file_ext.lower() != '.gz'):
            raise AirflowException("GZIP is the only compression " +
                                   "format Amazon S3 Select supports")

        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info("Dumping S3 key %s contents to local file %s",
                          s3_key_object.key, f.name)
            if self.select_expression:
                option = {}
                if self.headers:
                    option['FileHeaderInfo'] = 'USE'
                if self.delimiter:
                    option['FieldDelimiter'] = self.delimiter

                input_serialization = {'CSV': option}
                if self.input_compressed:
                    input_serialization['CompressionType'] = 'GZIP'

                content = self.s3.select_key(
                    bucket_name=s3_key_object.bucket_name,
                    key=s3_key_object.key,
                    expression=self.select_expression,
                    input_serialization=input_serialization)
                f.write(content.encode("utf-8"))
            else:
                s3_key_object.download_fileobj(f)
            f.flush()

            if self.select_expression or not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(f.name,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name, file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (self._delete_top_row_and_compress(
                    fn_uncompressed, file_ext, tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
Example #17
    def test_create_bucket_default_region(self):
        hook = S3Hook(aws_conn_id=None)
        hook.create_bucket(bucket_name='new_bucket')
        bucket = hook.get_bucket('new_bucket')
        self.assertIsNotNone(bucket)
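create_bucket is called here without a region_name, so the hook falls back to its default region. A hedged companion sketch that pins the region explicitly (the test name and region are chosen purely for illustration; S3Hook.create_bucket accepts an optional region_name):
    def test_create_bucket_other_region(self):
        # Illustrative only: the same flow with an explicit region_name
        hook = S3Hook(aws_conn_id=None)
        hook.create_bucket(bucket_name='new_bucket', region_name='us-east-2')
        self.assertIsNotNone(hook.get_bucket('new_bucket'))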
Example #18
    def test_get_bucket(self):
        hook = S3Hook(aws_conn_id=None)
        bucket = hook.get_bucket('bucket')
        self.assertIsNotNone(bucket)
Example #19
    def test_check_for_bucket_raises_error_with_invalid_conn_id(self):
        hook = S3Hook(aws_conn_id="does_not_exist")

        with self.assertRaises(NoCredentialsError):
            hook.check_for_bucket('bucket')
Example #20
    def test_parse_s3_url(self):
        parsed = S3Hook.parse_s3_url(self.s3_test_url)
        self.assertEqual(parsed, ("test", "this/is/not/a-real-key.txt"),
                         "Incorrect parsing of the s3 url")
Example #21
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
        self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id)
Example #22
    def execute(self, context):
        # use the super method to list all the files in an S3 bucket/key
        files = super().execute(context)

        gcs_hook = GoogleCloudStorageHook(
            google_cloud_storage_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to)

        if not self.replace:
            # if we are not replacing -> list all files in the GCS bucket
            # and only keep those files which are present in
            # S3 and not in Google Cloud Storage
            bucket_name, object_prefix = _parse_gcs_url(self.dest_gcs)
            existing_files_prefixed = gcs_hook.list(bucket_name,
                                                    prefix=object_prefix)

            existing_files = []

            if existing_files_prefixed:
                # Remove the object prefix itself, an empty directory was found
                if object_prefix in existing_files_prefixed:
                    existing_files_prefixed.remove(object_prefix)

                # Remove the object prefix from all object string paths
                for f in existing_files_prefixed:
                    if f.startswith(object_prefix):
                        existing_files.append(f[len(object_prefix):])
                    else:
                        existing_files.append(f)

            files = list(set(files) - set(existing_files))
            if len(files) > 0:
                self.log.info('%s files are going to be synced: %s.',
                              len(files), files)
            else:
                self.log.info(
                    'There are no new files to sync. Have a nice day!')

        if files:
            hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

            for file in files:
                # GCS hook builds its own in-memory file so we have to create
                # and pass the path
                file_object = hook.get_key(file, self.bucket)
                with NamedTemporaryFile(mode='wb', delete=True) as f:
                    file_object.download_fileobj(f)
                    f.flush()

                    dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url(
                        self.dest_gcs)
                    # There will always be a '/' before file because it is
                    # enforced at instantiation time
                    dest_gcs_object = dest_gcs_object_prefix + file

                    # Sync is sequential and the hook already logs too much
                    # so skip this for now
                    # self.log.info(
                    #     'Saving file {0} from S3 bucket {1} in GCS bucket {2}'
                    #     ' as object {3}'.format(file, self.bucket,
                    #                             dest_gcs_bucket,
                    #                             dest_gcs_object))

                    gcs_hook.upload(dest_gcs_bucket,
                                    dest_gcs_object,
                                    f.name,
                                    gzip=self.gzip)

            self.log.info(
                "All done, uploaded %d files to Google Cloud Storage",
                len(files))
        else:
            self.log.info(
                'In sync, no files needed to be uploaded to Google Cloud '
                'Storage')

        return files
Example #23
    def execute(self, context):
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        s3_hook.copy_object(self.source_bucket_key, self.dest_bucket_key,
                            self.source_bucket_name, self.dest_bucket_name,
                            self.source_version_id)
Example #24
    def test_select_key(self, mock_get_client_type):
        mock_get_client_type.return_value.select_object_content.return_value = \
            {'Payload': [{'Records': {'Payload': b'Cont\xC3\xA9nt'}}]}
        hook = S3Hook(aws_conn_id=None)
        self.assertEqual(hook.select_key('my_key', 'mybucket'), 'Contént')
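Note: the mock_get_client_type parameter implies a mock.patch decorator, omitted from the excerpt, that replaces the hook's get_client_type so that select_object_content returns the canned payload above; select_key then concatenates the Records payloads and decodes them as UTF-8, which is why the bytes b'Cont\xC3\xA9nt' come back as 'Contént'.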