예제 #1
0
    def gcs_to_psql_import(**kwargs):
        """Download a GCS object and bulk-load it into a Postgres table.

        Keyword arguments read here: ``gcp_conn_id``, ``bucket``, ``object``,
        ``postgres_conn_id``, ``database``, ``schema`` and ``table``.

        The object is staged in a local temporary file which is always
        removed afterwards, even when the download or the load fails.
        """
        fd, tmp_filename = tempfile.mkstemp(text=True)
        # Only the path is needed (the hook reopens it by name); close the
        # descriptor immediately so it cannot leak if anything below raises.
        os.close(fd)
        try:
            # download file locally
            gcs_hook = GoogleCloudStorageHook(
                google_cloud_storage_conn_id=kwargs['gcp_conn_id'])
            gcs_hook.download(bucket=kwargs['bucket'],
                              object=kwargs['object'],
                              filename=tmp_filename)

            # load the file into postgres
            pg_hook = PostgresHook(postgres_conn_id=kwargs['postgres_conn_id'],
                                   schema=kwargs['database'])
            pg_hook.bulk_load(
                '{schema}.{table}'.format(schema=kwargs['schema'],
                                          table=kwargs['table']), tmp_filename)

            # print whatever notices the Postgres connection collected
            for output in pg_hook.conn.notices:
                print(output)
        finally:
            # remove temp file regardless of success
            os.unlink(tmp_filename)
예제 #2
0
 def execute(self, context):
     """Copy one GCS object to S3 through a local temporary file.

     Downloads ``gs://<gcs_source_bucket>/<gcs_source_uri>`` and uploads it to
     ``s3://<s3_destination_bucket>/<s3_destination_uri>``, replacing any
     existing key and applying ``s3_acl_policy``.

     Raises:
         AirflowException: if the source object does not exist.
     """
     gcs_hook = GoogleCloudStorageHook(
         google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
         delegate_to=self.delegate_to
     )
     s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)
     if not gcs_hook.exists(self.gcs_source_bucket, self.gcs_source_uri):
         self.log.error('Skip object not found: gs://%s/%s', self.gcs_source_bucket, self.gcs_source_uri)
         # Exceptions do not interpolate %-style arguments the way loggers do,
         # so build the message explicitly instead of passing extra args.
         raise AirflowException('Skip object not found: gs://%s/%s' % (self.gcs_source_bucket, self.gcs_source_uri))
     # The temporary file is deleted when the block exits, even on failure.
     with tempfile.NamedTemporaryFile() as tmp:
         self.log.info('Download gs://%s/%s', self.gcs_source_bucket, self.gcs_source_uri)
         gcs_hook.download(
             bucket=self.gcs_source_bucket,
             object=self.gcs_source_uri,
             filename=tmp.name,
         )
         self.log.info('Upload s3://%s/%s', self.s3_destination_bucket, self.s3_destination_uri)
         s3_hook.load_file(
             filename=tmp.name,
             bucket_name=self.s3_destination_bucket,
             key=self.s3_destination_uri,
             replace=True,
             acl_policy=self.s3_acl_policy
         )
예제 #3
0
def _get_data_from_gcs(gcp_conn_id, bucket, input):
    """Download ``input`` from ``bucket`` into a persistent temporary file.

    :param gcp_conn_id: Airflow connection id used by the GCS hook.
    :param bucket: Source GCS bucket name.
    :param input: Object name inside the bucket.
    :return: Path of the local file. It is created with ``delete=False``,
        so the caller is responsible for removing it.
    """
    hook = GoogleCloudStorageHook(google_cloud_storage_conn_id=gcp_conn_id)
    # Only a unique path is needed; close our handle right away so the file
    # descriptor is not leaked (the file itself survives: delete=False).
    with NamedTemporaryFile(delete=False) as tmp_file:
        filename = tmp_file.name
    hook.download(bucket, input, filename)
    return filename
예제 #4
0
def compression(**kwargs):
    """Compress the output of the ``censor`` task with a Cloud Function.

    The object named by the ``censor`` task's XCom value is downloaded from
    the ``workflowstorage`` bucket into the working directory, POSTed to the
    ``Compression`` Cloud Function, and the response body is uploaded back
    to ``workflowstorage`` under a fresh UUID name.

    :return: The name of the uploaded object.
    :raises requests.HTTPError: If the Cloud Function returns an error status.
    """
    ti = kwargs['ti']
    fileName = ti.xcom_pull(task_ids="censor")
    gcs = GoogleCloudStorageHook()
    gcs.download("workflowstorage", fileName, fileName)
    # Close the file handle once the request body has been sent.
    with open(fileName, 'rb') as to_compress:
        response = requests.post(
            "https://us-central1-devops-218113.cloudfunctions.net/Compression",
            files={"to_compress": to_compress})
    # Fail the task instead of uploading an HTTP error page as the result.
    response.raise_for_status()
    newFileName = str(uuid.uuid4())
    with open(newFileName, "wb") as outfile:
        outfile.write(response.content)
    try:
        gcs.upload("workflowstorage", newFileName, newFileName, mime_type='application/octet-stream')
    finally:
        # Remove the local copy even if the upload fails.
        os.remove(newFileName)
    return newFileName
예제 #5
0
def _print_stats(ds, **context):
    """Print how many rockets launched on ``ds``, with their names if any.

    Reads the JSON object ``rocket_launches/ds=<ds>`` from ``nice_bucket``.
    The object must contain a ``"launches"`` list whose items have a
    ``"name"`` key.
    """
    gcloud_storage_hook = GoogleCloudStorageHook()
    # The temporary file is closed (and thus deleted) once the JSON is parsed.
    with NamedTemporaryFile(delete=True) as tmp_file_handle:
        gcloud_storage_hook.download(bucket="nice_bucket",
                                     object=f"rocket_launches/ds={ds}",
                                     filename=tmp_file_handle.name)
        data = json.load(tmp_file_handle)
    rockets_launched = [launch["name"] for launch in data["launches"]]
    rockets_str = ""
    if rockets_launched:
        rockets_str = f" ({' & '.join(rockets_launched)})"
    # Printed unconditionally so that days without launches are reported too.
    print(f"{len(rockets_launched)} rocket launch(es) on {ds}{rockets_str}.")
예제 #6
0
    def execute(self, context):
        """Copy ``gs://<gcs_bucket>/<gcs_dest>`` to ``sftp_dest_path`` over SFTP.

        The object is staged in a local temporary file that is removed when
        the transfer finishes. The SFTP and SSH clients are closed afterwards.
        """
        ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
        gcs_hook = GoogleCloudStorageHook(self.google_cloud_storage_conn_id)

        ssh_client = ssh_hook.get_conn()
        sftp_client = ssh_client.open_sftp()
        try:
            with NamedTemporaryFile("w") as f:
                filename = f.name
                gcs_hook.download(bucket=self.gcs_bucket,
                                  object=self.gcs_dest,
                                  filename=filename)
                file_msg = "from {0} to {1}".format(filename, self.sftp_dest_path)
                self.log.info("Starting to transfer file %s", file_msg)
                sftp_client.put(filename, self.sftp_dest_path, confirm=True)
        finally:
            # Release the SFTP channel and the SSH connection that carries it.
            sftp_client.close()
            ssh_client.close()
예제 #7
0
def censor(**kwargs):
    """Censor the output of the ``conversion`` task with a Cloud Function.

    The object named by the ``conversion`` task's XCom value is downloaded
    from the ``workflowstorage`` bucket into the working directory. It is
    POSTed, together with the ``profanity`` task's indexes (JSON-encoded),
    to the ``Censor`` Cloud Function. The response body is uploaded back to
    ``workflowstorage`` under a fresh UUID name.

    :return: The name of the uploaded object.
    :raises requests.HTTPError: If the Cloud Function returns an error status.
    """
    ti = kwargs['ti']
    indexes = ti.xcom_pull(task_ids="profanity")
    fileName = ti.xcom_pull(task_ids="conversion")
    gcs = GoogleCloudStorageHook()
    gcs.download("workflowstorage", fileName, fileName)
    # Close the file handle once the request body has been sent.
    with open(fileName, 'rb') as to_censor:
        message = {"to_censor": to_censor, "indexes": json.dumps(indexes)}
        response = requests.post(
            "https://us-central1-devops-218113.cloudfunctions.net/Censor",
            files=message)
    # Fail the task instead of uploading an HTTP error page as the result.
    response.raise_for_status()
    newFileName = str(uuid.uuid4())
    with open(newFileName, "wb") as outfile:
        outfile.write(response.content)
    try:
        gcs.upload("workflowstorage", newFileName, newFileName, mime_type='application/octet-stream')
    finally:
        # Remove the local copy even if the upload fails.
        os.remove(newFileName)
    return newFileName
예제 #8
0
 def loadKeyword(self, gcs_filepath):
     """
         Read a keyword file from GCS and return its list of keywords.

         The object ``gcs_filepath`` in ``self.gcs_bucket`` is downloaded to
         ``tmp/<gcs_filepath>`` (relative to the working directory). For each
         line, the last comma-separated field, stripped, is the keyword. The
         first line is treated as a header and dropped.
     """
     import os

     gcshook = GoogleCloudStorageHook(self.gcp_conn_id)
     file_path = "tmp/" + gcs_filepath
     # gcs_filepath may contain "/" separators; make sure the local directory
     # tree exists before downloading into it.
     os.makedirs(os.path.dirname(file_path), exist_ok=True)
     gcshook.download(self.gcs_bucket,
                      object=gcs_filepath,
                      filename=file_path)
     with open(file_path, "r") as file:
         keyword = [line.split(",")[-1].strip() for line in file]
     return keyword[1:]
    def execute(self, context):
        """Load ``source_objects`` from ``bucket`` into a BigQuery table.

        The schema is ``schema_fields`` when set; otherwise it is read as JSON
        from ``gs://<bucket>/<schema_object>``. When ``max_id_key`` is set,
        ``SELECT MAX(<max_id_key>)`` runs on the destination table after the
        load, and that value (``0`` when falsy) is returned.
        """
        gcs_hook = GoogleCloudStorageHook(google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
                                          delegate_to=self.delegate_to)
        bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
                               delegate_to=self.delegate_to)

        if self.schema_fields:
            schema_fields = self.schema_fields
        else:
            schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object))
        # Build a real list: on Python 3 ``map`` returns a one-shot lazy
        # iterator, which cannot be serialised into the load job request.
        source_uris = ['gs://{}/{}'.format(self.bucket, source_object)
                       for source_object in self.source_objects]
        conn = bq_hook.get_conn()
        cursor = conn.cursor()
        cursor.run_load(
            destination_project_dataset_table=self.destination_project_dataset_table,
            schema_fields=schema_fields,
            source_uris=source_uris,
            source_format=self.source_format,
            create_disposition=self.create_disposition,
            skip_leading_rows=self.skip_leading_rows,
            write_disposition=self.write_disposition,
            field_delimiter=self.field_delimiter)

        if self.max_id_key:
            # Identifiers cannot be bound as query parameters; both values come
            # from the operator's own configuration.
            cursor.execute('SELECT MAX({}) FROM {}'.format(self.max_id_key, self.destination_project_dataset_table))
            row = cursor.fetchone()
            max_id = row[0] if row[0] else 0
            logging.info('Loaded BQ data with max %s.%s=%s',
                         self.destination_project_dataset_table, self.max_id_key, max_id)
            return max_id
예제 #10
0
    def execute(self, context):
        """Create an empty BigQuery table with the configured schema and options.

        Explicit ``schema_fields`` take precedence. The schema is fetched from
        the GCS URL ``gcs_schema_object`` and decoded as UTF-8 JSON only when
        no explicit fields are set.
        """
        bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
                               delegate_to=self.delegate_to)

        if self.schema_fields or not self.gcs_schema_object:
            schema_fields = self.schema_fields
        else:
            schema_bucket, schema_path = _parse_gcs_url(self.gcs_schema_object)
            storage = GoogleCloudStorageHook(
                google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
                delegate_to=self.delegate_to)
            schema_bytes = storage.download(schema_bucket, schema_path)
            schema_fields = json.loads(schema_bytes.decode("utf-8"))

        bq_cursor = bq_hook.get_conn().cursor()
        bq_cursor.create_empty_table(
            project_id=self.project_id,
            dataset_id=self.dataset_id,
            table_id=self.table_id,
            schema_fields=schema_fields,
            time_partitioning=self.time_partitioning,
            labels=self.labels
        )
예제 #11
0
    def execute(self, context):
        """Create a BigQuery external table over ``source_objects`` in ``bucket``.

        Unless ``schema_fields`` is given, the source is a Datastore backup,
        or no ``schema_object`` is set, the schema is downloaded from
        ``gs://<bucket>/<schema_object>`` and decoded as UTF-8 JSON.
        """
        bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
                               delegate_to=self.delegate_to)

        fetch_schema_from_gcs = (
            not self.schema_fields
            and self.schema_object
            and self.source_format != 'DATASTORE_BACKUP'
        )
        if fetch_schema_from_gcs:
            storage = GoogleCloudStorageHook(
                google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
                delegate_to=self.delegate_to)
            schema_bytes = storage.download(self.bucket, self.schema_object)
            schema_fields = json.loads(schema_bytes.decode("utf-8"))
        else:
            schema_fields = self.schema_fields

        source_uris = ['gs://{}/{}'.format(self.bucket, name)
                       for name in self.source_objects]
        bq_cursor = bq_hook.get_conn().cursor()
        bq_cursor.create_external_table(
            external_project_dataset_table=self.destination_project_dataset_table,
            schema_fields=schema_fields,
            source_uris=source_uris,
            source_format=self.source_format,
            compression=self.compression,
            skip_leading_rows=self.skip_leading_rows,
            field_delimiter=self.field_delimiter,
            max_bad_records=self.max_bad_records,
            quote_character=self.quote_character,
            allow_quoted_newlines=self.allow_quoted_newlines,
            allow_jagged_rows=self.allow_jagged_rows,
            src_fmt_configs=self.src_fmt_configs,
            labels=self.labels
        )
예제 #12
0
파일: gcs_to_bq.py 프로젝트: yunhen/airflow
    def execute(self, context):
        """Load ``source_objects`` from ``bucket`` into ``destination_dataset_table``.

        The schema is ``schema_fields`` when set; otherwise it is read as JSON
        from ``gs://<bucket>/<schema_object>``. When ``max_id_key`` is set,
        ``SELECT MAX(<max_id_key>)`` runs on the destination table after the
        load, and that value (``0`` when falsy) is returned.
        """
        gcs_hook = GoogleCloudStorageHook(
            google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
            delegate_to=self.delegate_to)
        bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
                               delegate_to=self.delegate_to)

        if self.schema_fields:
            schema_fields = self.schema_fields
        else:
            schema_fields = json.loads(
                gcs_hook.download(self.bucket, self.schema_object))
        # Build a real list: on Python 3 ``map`` returns a one-shot lazy
        # iterator, which cannot be serialised into the load job request.
        source_uris = ['gs://{}/{}'.format(self.bucket, source_object)
                       for source_object in self.source_objects]
        conn = bq_hook.get_conn()
        cursor = conn.cursor()
        cursor.run_load(
            destination_dataset_table=self.destination_dataset_table,
            schema_fields=schema_fields,
            source_uris=source_uris,
            source_format=self.source_format,
            create_disposition=self.create_disposition,
            skip_leading_rows=self.skip_leading_rows,
            write_disposition=self.write_disposition,
            field_delimiter=self.field_delimiter)

        if self.max_id_key:
            # Identifiers cannot be bound as query parameters; both values come
            # from the operator's own configuration.
            cursor.execute('SELECT MAX({}) FROM {}'.format(
                self.max_id_key, self.destination_dataset_table))
            row = cursor.fetchone()
            max_id = row[0] if row[0] else 0
            logging.info('Loaded BQ data with max %s.%s=%s',
                         self.destination_dataset_table, self.max_id_key, max_id)
            return max_id
예제 #13
0
    def execute(self, context):
        """Create an empty BigQuery table with the configured schema.

        Explicit ``schema_fields`` take precedence. The schema is fetched from
        the GCS URL ``gcs_schema_object`` and decoded as UTF-8 JSON only when
        no explicit fields are set.
        """
        bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
                               delegate_to=self.delegate_to)

        if self.schema_fields or not self.gcs_schema_object:
            schema_fields = self.schema_fields
        else:
            schema_bucket, schema_path = _parse_gcs_url(self.gcs_schema_object)
            storage = GoogleCloudStorageHook(
                google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
                delegate_to=self.delegate_to)
            schema_bytes = storage.download(schema_bucket, schema_path)
            schema_fields = json.loads(schema_bytes.decode("utf-8"))

        bq_cursor = bq_hook.get_conn().cursor()
        bq_cursor.create_empty_table(
            project_id=self.project_id,
            dataset_id=self.dataset_id,
            table_id=self.table_id,
            schema_fields=schema_fields,
            time_partitioning=self.time_partitioning
        )
예제 #14
0
    def execute(self, context):
        """Create a BigQuery external table over ``source_objects`` in ``bucket``.

        Unless ``schema_fields`` is given, the source is a Datastore backup,
        or no ``schema_object`` is set, the schema is downloaded from
        ``gs://<bucket>/<schema_object>`` and decoded as UTF-8 JSON.
        """
        bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
                               delegate_to=self.delegate_to)

        fetch_schema_from_gcs = (
            not self.schema_fields
            and self.schema_object
            and self.source_format != 'DATASTORE_BACKUP'
        )
        if fetch_schema_from_gcs:
            storage = GoogleCloudStorageHook(
                google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
                delegate_to=self.delegate_to)
            schema_bytes = storage.download(self.bucket, self.schema_object)
            schema_fields = json.loads(schema_bytes.decode("utf-8"))
        else:
            schema_fields = self.schema_fields

        source_uris = ['gs://{}/{}'.format(self.bucket, name)
                       for name in self.source_objects]
        bq_cursor = bq_hook.get_conn().cursor()
        bq_cursor.create_external_table(
            external_project_dataset_table=self.destination_project_dataset_table,
            schema_fields=schema_fields,
            source_uris=source_uris,
            source_format=self.source_format,
            compression=self.compression,
            skip_leading_rows=self.skip_leading_rows,
            field_delimiter=self.field_delimiter,
            max_bad_records=self.max_bad_records,
            quote_character=self.quote_character,
            allow_quoted_newlines=self.allow_quoted_newlines,
            allow_jagged_rows=self.allow_jagged_rows,
            src_fmt_configs=self.src_fmt_configs
        )
예제 #15
0
    def schema(self):
        """Return the parsed JSON schema stored in GCS for ``self.table``.

        Reads ``<schemas_clean_path>/<table>.json`` from the bucket named by
        ``self.config['bucket_name']``.
        """
        hook = GoogleCloudStorageHook()
        bucket_name = self.config['bucket_name']
        schema_object = '{}/{}.json'.format(self.config['schemas_clean_path'], self.table)
        return json.loads(hook.download(bucket_name, schema_object))
예제 #16
0
    def execute(self, context):
        """Download every object under ``prefix`` in ``bucket`` into ``directory``.

        Each object name is flattened into a local file name by replacing
        ``/`` with ``_``. The list of local paths is pushed to XCom under the
        key ``downloaded_files``.
        """
        # Third value is the destination directory (the bucket was previously
        # logged a second time here by mistake).
        self.log.info('Executing download: %s, %s, %s', self.bucket,
                      self.prefix, self.directory)
        hook = GoogleCloudStorageHook(
            google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
            delegate_to=self.delegate_to)

        downloaded_files = []
        # ``object_name`` rather than ``object`` so the builtin is not shadowed.
        for object_name in hook.list(bucket=self.bucket, prefix=self.prefix):
            self.log.info('Downloading object: %s', object_name)
            filename = os.path.join(self.directory, object_name.replace('/', '_'))
            hook.download(bucket=self.bucket, object=object_name, filename=filename)
            downloaded_files.append(filename)

        task_instance = context['task_instance']
        task_instance.xcom_push('downloaded_files', downloaded_files)
예제 #17
0
    def execute(self, context):
        """Copy the GCS objects listed by the parent operator to S3.

        Each object is uploaded to ``dest_s3_key + <object name>``. When
        ``replace`` is false, objects whose key already exists under the
        ``dest_s3_key`` prefix are skipped.

        :return: list of GCS object names that were (or needed to be) uploaded.
        """
        # use the super to list all files in an Google Cloud Storage bucket
        files = super(GoogleCloudStorageToS3Operator, self).execute(context)
        s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)

        if not self.replace:
            # If we are not replacing, keep only files present in GCS but not
            # yet in S3. Look only under the destination prefix.
            bucket_name, prefix = S3Hook.parse_s3_url(self.dest_s3_key)
            existing_files = s3_hook.list_keys(bucket_name, prefix=prefix)
            # Some S3Hook versions return None instead of [] when nothing matches.
            existing_files = existing_files if existing_files is not None else []
            # Strip the destination prefix so keys are comparable to GCS names.
            existing_files = [key.replace(prefix, '', 1) for key in existing_files]
            files = list(set(files) - set(existing_files))

        if files:
            hook = GoogleCloudStorageHook(
                google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
                delegate_to=self.delegate_to
            )

            for file in files:
                file_bytes = hook.download(self.bucket, file)

                dest_key = self.dest_s3_key + file
                self.log.info("Saving file to %s", dest_key)

                s3_hook.load_bytes(file_bytes,
                                   key=dest_key,
                                   replace=self.replace)

            self.log.info("All done, uploaded %d files to S3", len(files))
        else:
            self.log.info("In sync, no files needed to be uploaded to S3")

        return files
예제 #18
0
    def execute(self, context):
        """Copy the GCS objects listed by the parent operator to S3.

        Each object is uploaded to ``dest_s3_key + <object name>``. When
        ``replace`` is false, objects whose key already exists under the
        ``dest_s3_key`` prefix are skipped.

        :return: list of GCS object names that were (or needed to be) uploaded.
        """
        # use the super to list all files in an Google Cloud Storage bucket
        files = super().execute(context)
        s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id,
                         verify=self.dest_verify)

        if not self.replace:
            # If we are not replacing, keep only files present in GCS but not
            # yet in S3. Look only under the destination prefix.
            bucket_name, prefix = S3Hook.parse_s3_url(self.dest_s3_key)
            existing_files = s3_hook.list_keys(bucket_name, prefix=prefix)
            # Some S3Hook versions return None instead of [] when nothing matches.
            existing_files = existing_files if existing_files is not None else []
            # Strip the destination prefix so keys are comparable to GCS names.
            existing_files = [key.replace(prefix, '', 1) for key in existing_files]
            files = list(set(files) - set(existing_files))

        if files:
            hook = GoogleCloudStorageHook(
                google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
                delegate_to=self.delegate_to)

            for file in files:
                file_bytes = hook.download(self.bucket, file)

                dest_key = self.dest_s3_key + file
                self.log.info("Saving file to %s", dest_key)

                s3_hook.load_bytes(file_bytes,
                                   key=dest_key,
                                   replace=self.replace)

            self.log.info("All done, uploaded %d files to S3", len(files))
        else:
            self.log.info("In sync, no files needed to be uploaded to S3")

        return files
예제 #19
0
 def execute(self, context):
     """Download ``object`` from ``bucket`` to ``filename`` and print the hook's result."""
     logging.info('Executing download: %s, %s, %s',
                  self.bucket, self.object, self.filename)
     storage = GoogleCloudStorageHook(
         google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
         delegate_to=self.delegate_to)
     result = storage.download(self.bucket, self.object, self.filename)
     print(result)
예제 #20
0
class GoogleCloudBucketHelper(object):
    """Resolve file names that may point at Google Cloud Storage to local paths."""
    GCS_PREFIX_LENGTH = 5

    def __init__(self,
                 gcp_conn_id='google_cloud_default',
                 delegate_to=None):
        # Hook used for every download performed by this helper.
        self._gcs_hook = GoogleCloudStorageHook(gcp_conn_id, delegate_to)

    def google_cloud_to_local(self, file_name):
        """
        Return a local path for ``file_name``.

        Names that do not start with ``gs://`` are returned untouched. A
        ``gs://bucket/object`` URL is downloaded into the system temporary
        directory as ``dataflow<8 hex chars>-<basename>``, and that local
        path is returned.

        :param file_name: Local path or ``gs://`` URL of the input file.
        :type file_name: str
        :return: The full path of the local file.
        :rtype: str
        :raises Exception: if the URL has no object part, or the downloaded
            file is empty.
        """
        if not file_name.startswith('gs://'):
            return file_name

        # Drop the 'gs://' prefix, then split the bucket off at the first '/'.
        remainder = file_name[self.GCS_PREFIX_LENGTH:]
        bucket_id, separator, object_id = remainder.partition('/')
        if not separator:
            raise Exception(
                'Invalid Google Cloud Storage (GCS) object path: {}'
                .format(file_name))

        base_name = object_id.rsplit('/', 1)[-1]
        unique_tag = str(uuid.uuid4())[:8]
        local_file = os.path.join(
            tempfile.gettempdir(),
            'dataflow{}-{}'.format(unique_tag, base_name)
        )
        self._gcs_hook.download(bucket_id, object_id, local_file)

        if os.stat(local_file).st_size <= 0:
            raise Exception(
                'Failed to download Google Cloud Storage (GCS) object: {}'
                .format(file_name))
        return local_file
 def apply_validate_fn(*args, **kwargs):
     """Validate the prediction summary stored under ``prediction_path``.

     ``kwargs["templates_dict"]["prediction_path"]`` must be a
     ``gs://bucket/path`` URL. ``<path>/prediction.summary.json`` is
     downloaded from that bucket, parsed as JSON and passed to
     ``validate_fn``, whose result is returned.

     Raises:
         ValueError: if the path is not a ``gs://`` URL with both a bucket and
             an object path.
     """
     prediction_path = kwargs["templates_dict"]["prediction_path"]
     parts = urlsplit(prediction_path)
     if not (parts.scheme == "gs" and parts.netloc and parts.path):
         raise ValueError("Wrong format prediction_path: {}".format(prediction_path))
     summary_object = os.path.join(parts.path.strip("/"), "prediction.summary.json")
     storage = GoogleCloudStorageHook()
     return validate_fn(json.loads(storage.download(parts.netloc, summary_object)))
class GoogleCloudBucketHelper(object):
    """GoogleCloudStorageHook helper class to download GCS object."""
    GCS_PREFIX_LENGTH = 5

    def __init__(self,
                 gcp_conn_id='google_cloud_default',
                 delegate_to=None):
        self._gcs_hook = GoogleCloudStorageHook(gcp_conn_id, delegate_to)

    def google_cloud_to_local(self, file_name):
        """
        Checks whether the file specified by file_name is stored in Google Cloud
        Storage (GCS), if so, downloads the file and saves it locally. The full
        path of the saved file will be returned. Otherwise the local file_name
        will be returned immediately.

        :param file_name: The full path of input file.
        :type file_name: str
        :return: The full path of local file.
        :rtype: str
        :raises Exception: if the GCS path has no object part, or the
            downloaded file is empty.
        """
        import tempfile

        if not file_name.startswith('gs://'):
            return file_name

        # Extracts bucket_id and object_id by first removing 'gs://' prefix and
        # then split the remaining by path delimiter '/'.
        path_components = file_name[self.GCS_PREFIX_LENGTH:].split('/')
        if len(path_components) < 2:
            raise Exception(
                'Invalid Google Cloud Storage (GCS) object path: {}'
                .format(file_name))

        bucket_id = path_components[0]
        object_id = '/'.join(path_components[1:])
        # Use the platform temporary directory instead of a hard-coded '/tmp',
        # which does not exist on every OS.
        local_file = os.path.join(
            tempfile.gettempdir(),
            'dataflow{}-{}'.format(str(uuid.uuid4())[:8], path_components[-1])
        )
        self._gcs_hook.download(bucket_id, object_id, local_file)

        if os.stat(local_file).st_size > 0:
            return local_file
        raise Exception(
            'Failed to download Google Cloud Storage (GCS) object: {}'
            .format(file_name))
예제 #23
0
    def execute(self, context):
        """Geocode a GCS file with Bing Maps and return the JSON response.

        ``self.gcs_file_path`` should look like ``gs://bucket/path/file.csv``.
        The object is downloaded to ``/home/airflow/gcs/data/<file>``. A
        prepared copy ``Ready_<file>`` gets the
        ``Bing Spatial Data Services, 2.0`` version line prepended and ``_``
        replaced by ``/`` in its first line. The copy is then POSTed through
        ``BingMapsHook`` with ``input='pipe'`` and ``output='json'``.

        :return: the parsed JSON body of the Bing Maps response.
        """
        gcs_hook = GoogleCloudStorageHook()

        # gs://bucket/path/file.csv -> ['gs:', '', 'bucket', 'path', 'file.csv']
        file_parts = self.gcs_file_path.split('/')
        bucket = file_parts[2]
        # Object name is everything after the bucket. Joining the directory and
        # the file name separately yielded '/file.csv' for bucket-root files.
        object_name = '/'.join(file_parts[3:])
        file_name = file_parts[-1]

        # setting the local path and preparing a "Ready" path for the prepared file
        data_dir = '/home/airflow/gcs/data'
        local_file_path = '{}/{}'.format(data_dir, file_name)
        prepared_file_path = '{}/Ready_{}'.format(data_dir, file_name)

        gcs_hook.download(bucket, object_name, local_file_path)

        # Prepend the version header, and on the first line only replace '_'
        # with '/' to conform to the pattern required by the Geocode Job.
        with open(local_file_path, 'r') as rf, open(prepared_file_path, 'w') as wf:
            wf.write('Bing Spatial Data Services, 2.0\n')
            for num, line in enumerate(rf, 1):
                wf.write(line if num > 1 else line.replace('_', '/'))

        # preparing and making the call
        bm_hook = BingMapsHook(bing_maps_conn_id=self.bing_maps_conn_id)
        api_params = {
            'description': file_name,
            'input': 'pipe',
            'output': 'json'
        }
        response = bm_hook.call(method='',
                                api_params=api_params,
                                operation='POST',
                                file_path=prepared_file_path)

        return response.json()
예제 #24
0
    def execute(self, context):
        """Download a GCS object either into XCom or onto local disk.

        When ``store_to_xcom_key`` is set, the object's bytes are pushed to
        XCom under that key, provided ``sys.getsizeof`` of them is below
        ``MAX_XCOM_SIZE``. Otherwise the object is written to ``filename``.

        :raises RuntimeError: if the object is too large to push to XCom.
        """
        self.log.info('Executing download: %s, %s, %s', self.bucket,
                      self.object, self.filename)
        hook = GoogleCloudStorageHook(
            google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
            delegate_to=self.delegate_to)

        if not self.store_to_xcom_key:
            hook.download(bucket=self.bucket,
                          object=self.object,
                          filename=self.filename)
            return

        payload = hook.download(bucket=self.bucket, object=self.object)
        if sys.getsizeof(payload) >= MAX_XCOM_SIZE:
            raise RuntimeError(
                'The size of the downloaded file is too large to push to XCom!'
            )
        context['ti'].xcom_push(key=self.store_to_xcom_key, value=payload)
예제 #25
0
 def execute(self, context):
     """Download ``object`` to ``filename``, optionally push it to XCom, then print it.

     When ``store_to_xcom_key`` is set and ``sys.getsizeof`` of the content is
     below 48000, the content is pushed to XCom under that key.

     :raises RuntimeError: if the content is too large to push to XCom.
     """
     logging.info('Executing download: %s, %s, %s', self.bucket, self.object, self.filename)
     hook = GoogleCloudStorageHook(google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
                                   delegate_to=self.delegate_to)
     payload = hook.download(self.bucket, self.object, self.filename)
     xcom_size_limit = 48000
     if self.store_to_xcom_key:
         if sys.getsizeof(payload) >= xcom_size_limit:
             raise RuntimeError('The size of the downloaded file is too large to push to XCom!')
         context['ti'].xcom_push(key=self.store_to_xcom_key, value=payload)
     print(payload)
 def apply_validate_fn(*args, **kwargs):
     """Validate the prediction summary stored under ``prediction_path``.

     ``kwargs["templates_dict"]["prediction_path"]`` must be a
     ``gs://bucket/path`` URL. ``<path>/prediction.summary.json`` is
     downloaded from that bucket, parsed as JSON and passed to
     ``validate_fn``, whose result is returned.

     Raises:
         ValueError: if the path is not a ``gs://`` URL with both a bucket and
             an object path.
     """
     prediction_path = kwargs["templates_dict"]["prediction_path"]
     scheme, bucket, obj, _, _ = urlsplit(prediction_path)
     if scheme != "gs" or not bucket or not obj:
         # Exceptions do not apply %-style args; format the message here so it
         # is a readable string rather than a tuple.
         raise ValueError("Wrong format prediction_path: {}".format(prediction_path))
     summary = os.path.join(obj.strip("/"),
                            "prediction.summary.json")
     gcs_hook = GoogleCloudStorageHook()
     summary = json.loads(gcs_hook.download(bucket, summary))
     return validate_fn(summary)
예제 #27
0
    def execute(self, context):
        """Copy every object under ``gcs_source_prefix`` from GCS to S3.

        Objects are listed from ``gcs_source_bucket`` (``maxResults=1000``).
        Each one is uploaded to ``s3_destination_bucket`` under the same key,
        replacing existing keys and applying ``s3_acl_policy``.

        An object that no longer exists at copy time is skipped. If
        ``fail_on_missing`` is true, it is logged as an error, and once all
        other objects have been copied, AirflowException is raised.
        """
        gcs_hook = GoogleCloudStorageHook(
            google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
            delegate_to=self.delegate_to
        )
        s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)

        gcs_source_objects = gcs_hook.list(bucket=self.gcs_source_bucket, prefix=self.gcs_source_prefix, maxResults=1000)
        if not gcs_source_objects:
            # ``warn`` is a deprecated alias of ``warning``.
            self.log.warning('SKIP: No objects found matching the prefix "%s"', self.gcs_source_prefix)
            return

        self.log.info('Number of object to compose: %d', len(gcs_source_objects))

        for gcs_uri in gcs_source_objects:
            if not gcs_hook.exists(self.gcs_source_bucket, gcs_uri):
                if self.fail_on_missing is True:
                    self.log.error('Execution will fail Object not found: gs://%s/%s', self.gcs_source_bucket, gcs_uri)
                    self.is_failed = True
                else:
                    self.log.warning('Skipping. Object not found: gs://%s/%s', self.gcs_source_bucket, gcs_uri)
                continue

            # Create the staging file only for objects that exist; the context
            # manager deletes it even if the transfer raises.
            with tempfile.NamedTemporaryFile() as tmp:
                self.log.info('Download gs://%s/%s', self.gcs_source_bucket, gcs_uri)
                gcs_hook.download(
                    bucket=self.gcs_source_bucket,
                    object=gcs_uri,
                    filename=tmp.name
                )
                self.log.info('Upload s3://%s/%s', self.s3_destination_bucket, gcs_uri)
                s3_hook.load_file(
                    filename=tmp.name,
                    bucket_name=self.s3_destination_bucket,
                    key=gcs_uri,
                    replace=True,
                    acl_policy=self.s3_acl_policy
                )

        # Checked once after the loop. Raising inside it aborted the remaining
        # copies and never fired when the missing object was the last one.
        # getattr: NOTE(review) is_failed is presumably initialised in
        # __init__; the default keeps this safe if it is not.
        if getattr(self, 'is_failed', False):
            raise AirflowException('Some object were not found at the source.')
예제 #28
0
def download_and_transform_erf(self, partner_id=None):
  """Load and Transform ERF files to Newline Delimited JSON.

  Then upload this file to the project GCS.

  Both local temporary files (the downloaded ERF and the converted JSON) are
  removed before returning, including when an error occurs.

  Args:
    self: The operator this is being used in.
    partner_id: A string of the DCM id of the partner.

  Returns:
    entity_read_file_ndj: The filename for the converted entity read file.
  """
  if partner_id:
    self.erf_bucket = 'gdbm-%s' % partner_id
  else:
    self.erf_bucket = 'gdbm-public'

  gcs_hook = GoogleCloudStorageHook(google_cloud_storage_conn_id=self.gcp_conn_id)
  # Only a unique local path is needed for the download: close our handle
  # straight away (delete=False keeps the file) and unlink it in `finally`.
  entity_read_file = tempfile.NamedTemporaryFile(delete=False)
  entity_read_file.close()
  temp_file = None
  # Creating temp file. Not using the delete-on-close functionality
  # as opening the file for reading while still open for writing
  # will not work on all platform
  # https://docs.python.org/2/library/tempfile.html#tempfile.NamedTemporaryFile
  try:
    gcs_hook.download(self.erf_bucket, self.erf_object, entity_read_file.name)
    temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False)
    temp_file.writelines(json_to_jsonlines(entity_read_file.name))
    temp_file.close()
    # Random here used as a nonce for writing multiple files at once.
    filename = '%s_%s_%d.json' % (randint(1, 1000000), self.entity_type,
                                  time.time() * 1e+9)
    gcs_hook.upload(self.gcs_bucket, filename, temp_file.name)

  finally:
    # Only touch temp_file if it was created; unlinking unconditionally raised
    # AttributeError on None and masked the original error.
    if temp_file:
      temp_file.close()
      os.unlink(temp_file.name)
    os.unlink(entity_read_file.name)

  return filename
예제 #29
0
def transform(**kwargs):
    """
    Pull the staged Fitbit sleep and weather JSON for run date ``ds`` from the
    ``sleep-staging`` GCS bucket, clean the sleep record with ``process_sleep``
    and insert one row into ``daily_sleep_data`` on the ``sleep_dw`` postgres
    connection.

    Nothing is inserted when ``process_sleep`` returns a falsy value.
    """
    run_date = kwargs.get('ds')
    warehouse = PostgresHook(postgres_conn_id='sleep_dw')
    staging = GoogleCloudStorageHook(google_cloud_storage_conn_id='sleep-gcp')

    # fetch the raw staged payloads
    raw_sleep = json.loads(
        staging.download('sleep-staging', f'{run_date}/sleep.json'))
    raw_weather = json.loads(
        staging.download('sleep-staging', f'{run_date}/weather.json'))

    night = process_sleep(raw_sleep)
    if not night:
        logging.info(f'No sleep data recorded for {run_date}')
        return

    stages = night['levels']['summary']
    day_weather = raw_weather['data'][0]

    # one row per night in the warehouse
    insert_sql = """INSERT INTO daily_sleep_data
    (ds, efficiency, startTime, endTime, events, deep, light, rem, wake,
    minAfterWakeup, minAsleep, minAwake, minInBed, temp, maxTemp, minTemp,
    precip)
    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
    """
    row = (
        run_date,
        night['efficiency'], night['startTime'], night['endTime'],
        night['events'],
        *(stages[stage]['minutes'] for stage in ('deep', 'light', 'rem', 'wake')),
        night['minutesAfterWakeup'], night['minutesAsleep'],
        night['minutesAwake'], night['timeInBed'],
        day_weather['temp'], day_weather['max_temp'], day_weather['min_temp'],
        day_weather['precip'],
    )

    warehouse.run(insert_sql, parameters=row)
    logging.info('Done!!')
예제 #30
0
    def execute(self, context):
        """
        Build ``SELECT <columns> FROM <table>`` from the table's JSON schema in
        GCS, store it in ``self.parameters['query']`` and run the extractor.

        The schema is read from ``<schemas_raw_path>/<table>.json`` in the
        configured bucket and is expected to be a list of objects with a
        ``name`` key.
        """
        gcs = GoogleCloudStorageHook()
        bucket = self.config['bucket_name']
        schema_path = '{}/{}.json'.format(self.config['schemas_raw_path'],
                                          self.table)
        schema = json.loads(gcs.download(bucket, schema_path))

        column_list = ', '.join(column['name'] for column in schema)
        self.parameters['query'] = "SELECT {} FROM {}".format(column_list,
                                                              self.table)

        ExtractorTemplateOperator.execute(self, context)
예제 #31
0
    def execute(self, context):
        """
        Copy each ``gcs_source_uris[i]`` from ``gcs_source_bucket`` to
        ``s3_destination_uris[i]`` in ``s3_destination_bucket``.

        Missing source objects are skipped and logged. If ``fail_on_missing``
        is set, all remaining objects are still copied, and an
        AirflowException is raised once every object has been processed.

        :raises AirflowException: if ``fail_on_missing`` is set and at least
            one source object did not exist.
        """
        gcs_hook = GoogleCloudStorageHook(
            google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
            delegate_to=self.delegate_to
        )
        s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)

        for i, gcs_obj in enumerate(self.gcs_source_uris):
            s3_obj = self.s3_destination_uris[i]
            if not gcs_hook.exists(self.gcs_source_bucket, gcs_obj):
                if self.fail_on_missing:
                    self.log.error('Execution will fail Object not found: gs://%s/%s', self.gcs_source_bucket, gcs_obj)
                    self.is_failed = True
                else:
                    self.log.warning('Skipping. Object not found: gs://%s/%s', self.gcs_source_bucket, gcs_obj)
                continue

            # The context manager guarantees the temp file is removed even if
            # the download or upload raises.
            with tempfile.NamedTemporaryFile() as tmp:
                self.log.info('Download gs://%s/%s', self.gcs_source_bucket, gcs_obj)
                gcs_hook.download(
                    bucket=self.gcs_source_bucket,
                    object=gcs_obj,
                    filename=tmp.name
                )
                self.log.info('Upload s3://%s/%s', self.s3_destination_bucket, s3_obj)
                s3_hook.load_file(
                    filename=tmp.name,
                    bucket_name=self.s3_destination_bucket,
                    key=s3_obj,
                    replace=True,
                    acl_policy=self.s3_acl_policy
                )

        # Check once, after the loop. Inside the loop a missing *last* object
        # was never reported, and a missing earlier one stopped the remaining
        # copies after the next successful upload.
        if self.is_failed:
            raise AirflowException('Some object were not found at the source.')
예제 #32
0
    def execute(self, context):
        """
        Load ``self.source_objects`` from ``self.bucket`` into BigQuery.

        The schema is ``self.schema_fields`` when set. Otherwise, when a
        ``schema_object`` is configured and the format is not
        ``DATASTORE_BACKUP``, the schema is downloaded from the same bucket and
        parsed as UTF-8 JSON. If ``max_id_key`` is set, the maximum value of
        that column in the destination table is logged and returned (``0``
        when the query yields a falsy value). Otherwise ``None`` is returned.
        """
        bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
                               delegate_to=self.delegate_to)

        fetch_schema = (not self.schema_fields
                        and self.schema_object
                        and self.source_format != 'DATASTORE_BACKUP')
        if not fetch_schema:
            schema_fields = self.schema_fields
        else:
            gcs_hook = GoogleCloudStorageHook(
                google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
                delegate_to=self.delegate_to)
            schema_blob = gcs_hook.download(self.bucket, self.schema_object)
            schema_fields = json.loads(schema_blob.decode("utf-8"))

        source_uris = ['gs://{}/{}'.format(self.bucket, name)
                       for name in self.source_objects]
        cursor = bq_hook.get_conn().cursor()
        cursor.run_load(
            destination_project_dataset_table=self.destination_project_dataset_table,
            schema_fields=schema_fields,
            source_uris=source_uris,
            source_format=self.source_format,
            create_disposition=self.create_disposition,
            skip_leading_rows=self.skip_leading_rows,
            write_disposition=self.write_disposition,
            field_delimiter=self.field_delimiter,
            max_bad_records=self.max_bad_records,
            quote_character=self.quote_character,
            allow_quoted_newlines=self.allow_quoted_newlines,
            allow_jagged_rows=self.allow_jagged_rows,
            schema_update_options=self.schema_update_options,
            src_fmt_configs=self.src_fmt_configs,
            time_partitioning=self.time_partitioning)

        if not self.max_id_key:
            return None

        table = self.destination_project_dataset_table
        cursor.execute('SELECT MAX({}) FROM {}'.format(self.max_id_key, table))
        max_id = cursor.fetchone()[0] or 0
        self.log.info('Loaded BQ data with max %s.%s=%s',
                      table, self.max_id_key, max_id)
        return max_id
예제 #33
0
    def execute(self, context):
        """
        Run a BigQuery load job for every object in ``self.source_objects``.

        Without explicit ``schema_fields``, a JSON schema is fetched from
        ``schema_object`` in ``self.bucket``, unless the source is a
        ``DATASTORE_BACKUP``. When ``max_id_key`` is configured, the post-load
        maximum of that column is logged and returned (falsy results become
        ``0``).
        """
        bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
                               delegate_to=self.delegate_to)

        schema_fields = self.schema_fields
        if (not schema_fields and self.schema_object
                and self.source_format != 'DATASTORE_BACKUP'):
            gcs_hook = GoogleCloudStorageHook(
                google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
                delegate_to=self.delegate_to)
            schema_bytes = gcs_hook.download(self.bucket, self.schema_object)
            schema_fields = json.loads(schema_bytes.decode("utf-8"))

        source_uris = ['gs://{}/{}'.format(self.bucket, obj)
                       for obj in self.source_objects]
        cursor = bq_hook.get_conn().cursor()
        cursor.run_load(
            destination_project_dataset_table=self.destination_project_dataset_table,
            schema_fields=schema_fields,
            source_uris=source_uris,
            source_format=self.source_format,
            create_disposition=self.create_disposition,
            skip_leading_rows=self.skip_leading_rows,
            write_disposition=self.write_disposition,
            field_delimiter=self.field_delimiter,
            max_bad_records=self.max_bad_records,
            quote_character=self.quote_character,
            allow_quoted_newlines=self.allow_quoted_newlines,
            allow_jagged_rows=self.allow_jagged_rows,
            schema_update_options=self.schema_update_options,
            src_fmt_configs=self.src_fmt_configs,
            time_partitioning=self.time_partitioning)

        if not self.max_id_key:
            return None

        max_query = 'SELECT MAX({}) FROM {}'.format(
            self.max_id_key, self.destination_project_dataset_table)
        cursor.execute(max_query)
        first_row = cursor.fetchone()
        max_id = first_row[0] or 0
        self.log.info(
            'Loaded BQ data with max %s.%s=%s',
            self.destination_project_dataset_table, self.max_id_key, max_id
        )
        return max_id
예제 #34
0
파일: gcs_to_s3.py 프로젝트: jxiao0/airflow
    def execute(self, context):
        """
        Copy the objects listed by the parent GCS list operator to S3 under
        ``self.dest_s3_key``.

        When ``self.replace`` is false, objects whose name (relative to the S3
        prefix) already exists in the destination are skipped. Returns the
        list of object names selected for upload.
        """
        # the parent operator lists all files in the Google Cloud Storage bucket
        files = super(GoogleCloudStorageToS3Operator, self).execute(context)
        s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id,
                         verify=self.dest_verify)

        if not self.replace:
            # Only look under the destination bucket/prefix, then strip the
            # prefix so S3 keys are comparable with GCS object names.
            bucket_name, prefix = S3Hook.parse_s3_url(self.dest_s3_key)
            s3_keys = s3_hook.list_keys(bucket_name, prefix=prefix) or []
            already_synced = {key.replace(prefix, '', 1) for key in s3_keys}
            files = list(set(files) - already_synced)

        if not files:
            self.log.info("In sync, no files needed to be uploaded to S3")
            return files

        gcs_hook = GoogleCloudStorageHook(
            google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
            delegate_to=self.delegate_to)

        for object_name in files:
            payload = gcs_hook.download(self.bucket, object_name)
            dest_key = self.dest_s3_key + object_name
            self.log.info("Saving file to %s", dest_key)
            s3_hook.load_bytes(payload, key=dest_key, replace=self.replace)

        self.log.info("All done, uploaded %d files to S3", len(files))
        return files
예제 #35
0
    def execute(self, context):
        """
        Sync the GCS objects listed by the parent operator into S3 under the
        ``self.dest_s3_key`` prefix.

        Unless ``self.replace`` is set, objects already present in S3 (compared
        by name with the prefix stripped) are skipped. Returns the names of the
        objects that were selected for upload.
        """
        # the parent operator lists all files in the Google Cloud Storage bucket
        files = super().execute(context)
        s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)

        if not self.replace:
            bucket_name, prefix = S3Hook.parse_s3_url(self.dest_s3_key)
            listed = s3_hook.list_keys(bucket_name, prefix=prefix)
            present_in_s3 = set()
            for key in listed or []:
                present_in_s3.add(key.replace(prefix, '', 1))
            files = list(set(files) - present_in_s3)

        if not files:
            self.log.info("In sync, no files needed to be uploaded to S3")
            return files

        gcs = GoogleCloudStorageHook(
            google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
            delegate_to=self.delegate_to
        )
        for name in files:
            content = gcs.download(self.bucket, name)
            target_key = self.dest_s3_key + name
            self.log.info("Saving file to %s", target_key)
            s3_hook.load_bytes(content, key=target_key, replace=self.replace)

        self.log.info("All done, uploaded %d files to S3", len(files))
        return files
예제 #36
0
def compare_events(df_final, gcp_bucket, gs_path_eventslist):
    '''
    Return the events in ``df_final`` that are not in the previously stored
    events list.

    The previous list is downloaded from the Cloud Storage bucket and
    unpickled as a pandas DataFrame. The rows of ``df_final`` whose index
    label does not appear in that DataFrame's index are returned. This
    function does not schedule DAGs or update the stored list itself.

    :param df_final: dataframe with list of events scraped from the web
    :param gcp_bucket: the GCS bucket where the events list file is stored
    :param gs_path_eventslist: the (relative) path to the events file in the
        cloud storage bucket
    :return: the rows of ``df_final`` that are new, in their original order
    '''
    import io

    # establish connection with cloudstorage bucket using default Airflow Settings
    conn = GoogleCloudStorageHook()
    # Without a filename, download() returns the object's bytes, not a path,
    # so it has to be wrapped in a buffer before pandas can read it.
    prev_events_bytes = conn.download(bucket=gcp_bucket, object=gs_path_eventslist)

    # reconstruct dataframe
    # NOTE(review): unpickling can execute arbitrary code; this is only safe
    # if the bucket contents are trusted.
    df_previous_events = pd.read_pickle(io.BytesIO(prev_events_bytes))

    # Boolean mask instead of .loc[set]: pandas rejects sets as indexers, and
    # the mask keeps df_final's row order.
    is_new = ~df_final.index.isin(df_previous_events.index)
    return df_final[is_new]
예제 #37
0
class GCSLog(object):
    """
    Utility class for reading and writing logs in GCS. Requires
    airflow[gcp_api] and setting the REMOTE_BASE_LOG_FOLDER and
    REMOTE_LOG_CONN_ID configuration options in airflow.cfg.

    Failures are logged rather than raised, so a broken remote log store
    does not fail the caller.
    """
    def __init__(self):
        """
        Attempt to create hook with airflow[gcp_api].

        On failure ``self.hook`` stays ``None``, and ``read``/``write`` then
        only log an error.
        """
        remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID')
        self.hook = None

        try:
            from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
            self.hook = GoogleCloudStorageHook(
                google_cloud_storage_conn_id=remote_conn_id)
        except Exception:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate. logging.exception keeps the cause.
            logging.exception(
                'Could not create a GoogleCloudStorageHook with connection id '
                '"%s". Please make sure that airflow[gcp_api] is installed '
                'and the GCS connection exists.', remote_conn_id)

    def read(self, remote_log_location, return_error=False):
        """
        Returns the log found at the remote_log_location.

        :param remote_log_location: the log's location in remote storage
        :type remote_log_location: string (path)
        :param return_error: if True, returns a string error message if an
            error occurs. Otherwise returns '' when an error occurs.
        :type return_error: bool
        """
        failure = None
        if self.hook:
            try:
                bkt, blob = self.parse_gcs_url(remote_log_location)
                return self.hook.download(bkt, blob).decode()
            except Exception as e:
                failure = e

        # raise/return error if we get here
        err = 'Could not read logs from {}'.format(remote_log_location)
        if failure is None:
            logging.error(err)
        else:
            # Include the underlying cause, which a bare except would discard.
            logging.error('%s: %s', err, failure)
        return err if return_error else ''

    def write(self, log, remote_log_location, append=False):
        """
        Writes the log to the remote_log_location. Does nothing if no hook
        was created. Upload failures are logged, not raised.

        :param log: the log to write to the remote_log_location
        :type log: string
        :param remote_log_location: the log's location in remote storage
        :type remote_log_location: string (path)
        :param append: if False, any existing log file is overwritten. If True,
            the new log is appended to any existing logs.
        :type append: bool

        """
        if not self.hook:
            return

        if append:
            old_log = self.read(remote_log_location)
            # read() returns '' when there is no existing log (or it could not
            # be read). Don't prepend a blank line in that case.
            if old_log:
                log = old_log + '\n' + log

        try:
            bkt, blob = self.parse_gcs_url(remote_log_location)
            from tempfile import NamedTemporaryFile
            with NamedTemporaryFile(mode='w+') as tmpfile:
                tmpfile.write(log)
                # Force the file to be flushed, since we're doing the
                # upload from within the file context (it hasn't been
                # closed).
                tmpfile.flush()
                self.hook.upload(bkt, blob, tmpfile.name)
        except Exception:
            # Best effort: report with traceback, but never raise to the caller.
            logging.exception('Could not write logs to %s', remote_log_location)

    def parse_gcs_url(self, gsurl):
        """
        Given a Google Cloud Storage URL (gs://<bucket>/<blob>), returns a
        tuple containing the corresponding bucket and blob.

        :raises AirflowException: if the URL has no bucket (netloc) part.
        """
        # Python 3
        try:
            from urllib.parse import urlparse
        # Python 2
        except ImportError:
            from urlparse import urlparse

        parsed_url = urlparse(gsurl)
        if not parsed_url.netloc:
            raise AirflowException('Please provide a bucket name')
        bucket = parsed_url.netloc
        blob = parsed_url.path.strip('/')
        return (bucket, blob)
예제 #38
0
파일: gcs_to_bq.py 프로젝트: zzmg/airflow
    def execute(self, context):
        """
        Load ``self.source_objects`` from ``self.bucket`` into BigQuery, either
        as a load job or, when ``external_table`` is set, as an external table.

        Schema resolution:
          * ``schema_fields`` if provided;
          * else, when ``schema_object`` is set and the format is not
            ``DATASTORE_BACKUP``, the object is downloaded from ``self.bucket``
            and parsed as UTF-8 JSON;
          * else an error is raised if ``schema_object`` is None and
            ``autodetect`` is False;
          * otherwise ``None`` is passed as the schema.

        :return: when ``max_id_key`` is set, the maximum value of that column
            in the destination table (``0`` if the query yields a falsy
            value); otherwise ``None``.
        :raises AirflowException: if no way to determine the schema was given.
        """
        bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
                               delegate_to=self.delegate_to)

        if not self.schema_fields:
            if self.schema_object and self.source_format != 'DATASTORE_BACKUP':
                gcs_hook = GoogleCloudStorageHook(
                    google_cloud_storage_conn_id=self.
                    google_cloud_storage_conn_id,
                    delegate_to=self.delegate_to)
                schema_fields = json.loads(
                    gcs_hook.download(self.bucket,
                                      self.schema_object).decode("utf-8"))
            elif self.schema_object is None and self.autodetect is False:
                raise AirflowException(
                    'At least one of `schema_fields`, '
                    '`schema_object`, or `autodetect` must be passed.')
            else:
                schema_fields = None

        else:
            schema_fields = self.schema_fields

        source_uris = [
            'gs://{}/{}'.format(self.bucket, source_object)
            for source_object in self.source_objects
        ]
        conn = bq_hook.get_conn()
        cursor = conn.cursor()

        if self.external_table:
            cursor.create_external_table(
                external_project_dataset_table=self.
                destination_project_dataset_table,
                schema_fields=schema_fields,
                source_uris=source_uris,
                source_format=self.source_format,
                compression=self.compression,
                skip_leading_rows=self.skip_leading_rows,
                field_delimiter=self.field_delimiter,
                max_bad_records=self.max_bad_records,
                quote_character=self.quote_character,
                ignore_unknown_values=self.ignore_unknown_values,
                allow_quoted_newlines=self.allow_quoted_newlines,
                allow_jagged_rows=self.allow_jagged_rows,
                src_fmt_configs=self.src_fmt_configs,
                encryption_configuration=self.encryption_configuration)
        else:
            cursor.run_load(
                destination_project_dataset_table=self.
                destination_project_dataset_table,
                schema_fields=schema_fields,
                source_uris=source_uris,
                source_format=self.source_format,
                autodetect=self.autodetect,
                create_disposition=self.create_disposition,
                skip_leading_rows=self.skip_leading_rows,
                write_disposition=self.write_disposition,
                field_delimiter=self.field_delimiter,
                max_bad_records=self.max_bad_records,
                quote_character=self.quote_character,
                ignore_unknown_values=self.ignore_unknown_values,
                allow_quoted_newlines=self.allow_quoted_newlines,
                allow_jagged_rows=self.allow_jagged_rows,
                schema_update_options=self.schema_update_options,
                src_fmt_configs=self.src_fmt_configs,
                time_partitioning=self.time_partitioning,
                cluster_fields=self.cluster_fields,
                encryption_configuration=self.encryption_configuration)

        if self.max_id_key:
            cursor.execute('SELECT MAX({}) FROM {}'.format(
                self.max_id_key, self.destination_project_dataset_table))
            row = cursor.fetchone()
            max_id = row[0] if row[0] else 0
            self.log.info('Loaded BQ data with max %s.%s=%s',
                          self.destination_project_dataset_table,
                          self.max_id_key, max_id)
            # Return the value so it is available to callers (e.g. via XCom),
            # as the other GCS-to-BigQuery loaders do. Previously it was
            # computed and logged but then discarded.
            return max_id
예제 #39
0
class GCSLog(object):
    """
    Utility class for reading and writing logs in GCS. Requires
    airflow[gcp_api] and setting the REMOTE_BASE_LOG_FOLDER and
    REMOTE_LOG_CONN_ID configuration options in airflow.cfg.

    Failures are logged rather than raised, so a broken remote log store
    does not fail the caller.
    """
    def __init__(self):
        """
        Attempt to create hook with airflow[gcp_api].

        On failure ``self.hook`` stays ``None``, and ``read``/``write`` then
        only log an error.
        """
        remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID')
        self.hook = None

        try:
            from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
            self.hook = GoogleCloudStorageHook(
                google_cloud_storage_conn_id=remote_conn_id)
        except Exception:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate. logging.exception keeps the cause.
            logging.exception(
                'Could not create a GoogleCloudStorageHook with connection id '
                '"%s". Please make sure that airflow[gcp_api] is installed '
                'and the GCS connection exists.', remote_conn_id)

    def read(self, remote_log_location, return_error=False):
        """
        Returns the log found at the remote_log_location.

        :param remote_log_location: the log's location in remote storage
        :type remote_log_location: string (path)
        :param return_error: if True, returns a string error message if an
            error occurs. Otherwise returns '' when an error occurs.
        :type return_error: bool
        """
        failure = None
        if self.hook:
            try:
                bkt, blob = self.parse_gcs_url(remote_log_location)
                return self.hook.download(bkt, blob).decode()
            except Exception as e:
                failure = e

        # raise/return error if we get here
        err = 'Could not read logs from {}'.format(remote_log_location)
        if failure is None:
            logging.error(err)
        else:
            # Include the underlying cause, which a bare except would discard.
            logging.error('%s: %s', err, failure)
        return err if return_error else ''

    def write(self, log, remote_log_location, append=True):
        """
        Writes the log to the remote_log_location. Does nothing if no hook
        was created. Upload failures are logged, not raised.

        :param log: the log to write to the remote_log_location
        :type log: string
        :param remote_log_location: the log's location in remote storage
        :type remote_log_location: string (path)
        :param append: if False, any existing log file is overwritten. If True,
            the new log is appended to any existing logs.
        :type append: bool

        """
        if not self.hook:
            return

        if append:
            old_log = self.read(remote_log_location)
            # read() returns '' when there is no existing log (or it could not
            # be read). Don't prepend a blank line in that case.
            if old_log:
                log = old_log + '\n' + log

        try:
            bkt, blob = self.parse_gcs_url(remote_log_location)
            from tempfile import NamedTemporaryFile
            with NamedTemporaryFile(mode='w+') as tmpfile:
                tmpfile.write(log)
                # Force the file to be flushed, since we're doing the
                # upload from within the file context (it hasn't been
                # closed).
                tmpfile.flush()
                self.hook.upload(bkt, blob, tmpfile.name)
        except Exception:
            # Best effort: report with traceback, but never raise to the caller.
            logging.exception('Could not write logs to %s', remote_log_location)

    def parse_gcs_url(self, gsurl):
        """
        Given a Google Cloud Storage URL (gs://<bucket>/<blob>), returns a
        tuple containing the corresponding bucket and blob.

        :raises AirflowException: if the URL has no bucket (netloc) part.
        """
        # Python 3
        try:
            from urllib.parse import urlparse
        # Python 2
        except ImportError:
            from urlparse import urlparse

        parsed_url = urlparse(gsurl)
        if not parsed_url.netloc:
            raise AirflowException('Please provide a bucket name')
        bucket = parsed_url.netloc
        blob = parsed_url.path.strip('/')
        return (bucket, blob)
예제 #40
0
 def execute(self, context):
     """
     Download ``self.object`` from ``self.bucket`` to the local path
     ``self.filename``.

     The downloaded payload is not printed, since it may be large or binary.
     Only its size is logged when the hook returns it.
     """
     logging.info('Executing download: %s, %s, %s', self.bucket, self.object, self.filename)
     hook = GoogleCloudStorageHook(google_cloud_storage_conn_id=self.google_cloud_storage_conn_id)
     downloaded = hook.download(self.bucket, self.object, self.filename)
     if downloaded is not None:
         logging.info('Downloaded %d bytes from gs://%s/%s',
                      len(downloaded), self.bucket, self.object)
예제 #41
0
class GcsToGDriveOperator(BaseOperator):
    """
    Copies objects from a Google Cloud Storage service to Google Drive service, with renaming
    if requested.

    Using this operator requires the following OAuth 2.0 scope:

    .. code-block:: none

        https://www.googleapis.com/auth/drive

    :param source_bucket: The source Google Cloud Storage bucket where the object is. (templated)
    :type source_bucket: str
    :param source_object: The source name of the object to copy in the Google cloud
        storage bucket. (templated)
        You can use only one wildcard for objects (filenames) within your bucket. The wildcard can appear
        inside the object name or at the end of the object name. Appending a wildcard to the bucket name
        is unsupported.
    :type source_object: str
    :param destination_object: The destination name of the object in the destination Google Drive
        service. (templated)
        If a wildcard is supplied in the source_object argument, this is the prefix that will be prepended
        to the final destination objects' paths.
        Note that the source path's part before the wildcard will be removed;
        if it needs to be retained it should be appended to destination_object.
        For example, with prefix ``foo/*`` and destination_object ``blah/``, the file ``foo/baz`` will be
        copied to ``blah/baz``; to retain the prefix write the destination_object as e.g. ``blah/foo``, in
        which case the copied file will be named ``blah/foo/baz``.
        If not given, the source object name is used as the destination.
    :type destination_object: str
    :param move_object: When move object is True, the object is moved instead of copied to the new location.
        This is the equivalent of a mv command as opposed to a cp command.
    :type move_object: bool
    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    :param delegate_to: The account to impersonate, if any.
        For this to work, the service account making the request must have domain-wide delegation enabled.
    :type delegate_to: str
    """

    template_fields = ("source_bucket", "source_object", "destination_object")
    ui_color = "#f0eee4"

    @apply_defaults
    def __init__(self,
                 source_bucket,
                 source_object,
                 destination_object=None,
                 move_object=False,
                 gcp_conn_id="google_cloud_default",
                 delegate_to=None,
                 *args,
                 **kwargs):
        super(GcsToGDriveOperator, self).__init__(*args, **kwargs)

        self.source_bucket = source_bucket
        self.source_object = source_object
        self.destination_object = destination_object
        self.move_object = move_object
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.gcs_hook = None  # type: Optional[GoogleCloudStorageHook]
        self.gdrive_hook = None  # type: Optional[GoogleDriveHook]

    def execute(self, context):
        """Copy (or move) the matching GCS object(s) to Google Drive."""
        self.gcs_hook = GoogleCloudStorageHook(
            google_cloud_storage_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to)
        self.gdrive_hook = GoogleDriveHook(gcp_conn_id=self.gcp_conn_id,
                                           delegate_to=self.delegate_to)

        if WILDCARD in self.source_object:
            total_wildcards = self.source_object.count(WILDCARD)
            if total_wildcards > 1:
                error_msg = (
                    "Only one wildcard '*' is allowed in source_object parameter. "
                    "Found {} in {}.".format(total_wildcards,
                                             self.source_object))

                raise AirflowException(error_msg)

            # The text after the wildcard is passed as the listing delimiter,
            # which filters objects by that suffix.
            prefix, delimiter = self.source_object.split(WILDCARD, 1)
            objects = self.gcs_hook.list(self.source_bucket,
                                         prefix=prefix,
                                         delimiter=delimiter)

            for source_object in objects:
                if self.destination_object is None:
                    destination_object = source_object
                else:
                    destination_object = source_object.replace(
                        prefix, self.destination_object, 1)

                self._copy_single_object(source_object=source_object,
                                         destination_object=destination_object)
        else:
            # Fall back to the source name, like the wildcard branch does,
            # instead of passing None as the Drive location.
            destination_object = (self.source_object
                                  if self.destination_object is None
                                  else self.destination_object)
            self._copy_single_object(
                source_object=self.source_object,
                destination_object=destination_object)

    def _copy_single_object(self, source_object, destination_object):
        """Stage one GCS object in a local temp file and upload it to Drive."""
        self.log.info(
            "Executing copy of gs://%s/%s to gdrive://%s",
            self.source_bucket,
            source_object,
            destination_object,
        )

        with tempfile.NamedTemporaryFile() as tmp_file:
            filename = tmp_file.name
            self.gcs_hook.download(bucket=self.source_bucket,
                                   object=source_object,
                                   filename=filename)
            self.gdrive_hook.upload_file(local_location=filename,
                                         remote_location=destination_object)

        if self.move_object:
            self.gcs_hook.delete(self.source_bucket, source_object)