from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from google.cloud import bigquery


def load(gcp_conn_id: str, combined_data: str, gcs_bucket: str, gcs_object: str):
    # Stage the combined CSV in GCS, then load it into BigQuery.
    gcs_hook = GCSHook(gcp_conn_id=gcp_conn_id)
    gcs_hook.upload(bucket_name=gcs_bucket, data=combined_data, object_name=gcs_object)

    # gcp_conn_id supersedes the deprecated bigquery_conn_id argument.
    bq_hook = BigQueryHook(gcp_conn_id=gcp_conn_id)
    bq_hook.run_load(
        destination_project_dataset_table="augmented-works-297410.demo_dataset.sales_interactions2",
        # run_load expects a list of URIs, not a bare string.
        source_uris=[f"gs://{gcs_bucket}/{gcs_object}"],
        write_disposition="WRITE_APPEND",
        source_format="CSV",
        skip_leading_rows=1,
        autodetect=False,
        schema_fields=[
            bigquery.SchemaField("date", "DATETIME").to_api_repr(),
            bigquery.SchemaField("location_name", "STRING").to_api_repr(),
            bigquery.SchemaField("average_temp", "FLOAT").to_api_repr(),
            bigquery.SchemaField("fullVisitorId", "STRING").to_api_repr(),
            bigquery.SchemaField("city", "STRING").to_api_repr(),
            bigquery.SchemaField("country", "STRING").to_api_repr(),
            bigquery.SchemaField("region", "STRING").to_api_repr(),
            bigquery.SchemaField("productCategory", "STRING").to_api_repr(),
            bigquery.SchemaField("productName", "STRING").to_api_repr(),
            bigquery.SchemaField("action_type", "INTEGER").to_api_repr(),
            bigquery.SchemaField("action_step", "INTEGER").to_api_repr(),
            bigquery.SchemaField("quantity", "FLOAT").to_api_repr(),
            bigquery.SchemaField("price", "FLOAT").to_api_repr(),
            bigquery.SchemaField("revenue", "FLOAT").to_api_repr(),
            bigquery.SchemaField("isImpression", "BOOL").to_api_repr(),
            bigquery.SchemaField("transactionId", "STRING").to_api_repr(),
            bigquery.SchemaField("transactionRevenue", "FLOAT").to_api_repr(),
            bigquery.SchemaField("transactionTax", "FLOAT").to_api_repr(),
            bigquery.SchemaField("transactionShipping", "FLOAT").to_api_repr(),
        ],
    )
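To show how this function would actually be scheduled, here is a minimal sketch of wiring load into a DAG via PythonOperator. The dag_id, task ids, bucket, and object names are illustrative, and it assumes an upstream "transform" task pushes the combined CSV string to XCom:

from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

with DAG(
    dag_id="sales_interactions_etl",  # hypothetical dag_id
    start_date=datetime(2021, 1, 1),
    schedule_interval="@daily",
) as dag:
    load_task = PythonOperator(
        task_id="load",
        python_callable=load,
        op_kwargs={
            "gcp_conn_id": "google_cloud_default",
            # Assumes a prior task with task_id "transform" returned the CSV.
            "combined_data": "{{ ti.xcom_pull(task_ids='transform') }}",
            "gcs_bucket": "my-staging-bucket",  # placeholder bucket
            "gcs_object": "sales_interactions/{{ ds }}.csv",
        },
    )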
def execute(self, context: 'Context'):
    bq_hook = BigQueryHook(
        gcp_conn_id=self.gcp_conn_id,
        delegate_to=self.delegate_to,
        location=self.location,
        impersonation_chain=self.impersonation_chain,
    )

    # Resolve the schema: explicit schema_fields win; otherwise fall back to
    # a JSON schema object stored in GCS (not applicable to Datastore backups).
    if not self.schema_fields:
        if self.schema_object and self.source_format != 'DATASTORE_BACKUP':
            gcs_hook = GCSHook(
                gcp_conn_id=self.gcp_conn_id,
                delegate_to=self.delegate_to,
                impersonation_chain=self.impersonation_chain,
            )
            blob = gcs_hook.download(
                bucket_name=self.bucket,
                object_name=self.schema_object,
            )
            schema_fields = json.loads(blob.decode("utf-8"))
        else:
            schema_fields = None
    else:
        schema_fields = self.schema_fields

    # Normalize source_objects to a list and build the gs:// source URIs.
    self.source_objects = (
        self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
    )
    source_uris = [f'gs://{self.bucket}/{source_object}' for source_object in self.source_objects]

    if self.external_table:
        # External table: BigQuery reads the files in place instead of loading them.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            bq_hook.create_external_table(
                external_project_dataset_table=self.destination_project_dataset_table,
                schema_fields=schema_fields,
                source_uris=source_uris,
                source_format=self.source_format,
                autodetect=self.autodetect,
                compression=self.compression,
                skip_leading_rows=self.skip_leading_rows,
                field_delimiter=self.field_delimiter,
                max_bad_records=self.max_bad_records,
                quote_character=self.quote_character,
                ignore_unknown_values=self.ignore_unknown_values,
                allow_quoted_newlines=self.allow_quoted_newlines,
                allow_jagged_rows=self.allow_jagged_rows,
                encoding=self.encoding,
                src_fmt_configs=self.src_fmt_configs,
                encryption_configuration=self.encryption_configuration,
                labels=self.labels,
                description=self.description,
            )
    else:
        # Regular load job into a native BigQuery table.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            bq_hook.run_load(
                destination_project_dataset_table=self.destination_project_dataset_table,
                schema_fields=schema_fields,
                source_uris=source_uris,
                source_format=self.source_format,
                autodetect=self.autodetect,
                create_disposition=self.create_disposition,
                skip_leading_rows=self.skip_leading_rows,
                write_disposition=self.write_disposition,
                field_delimiter=self.field_delimiter,
                max_bad_records=self.max_bad_records,
                quote_character=self.quote_character,
                ignore_unknown_values=self.ignore_unknown_values,
                allow_quoted_newlines=self.allow_quoted_newlines,
                allow_jagged_rows=self.allow_jagged_rows,
                encoding=self.encoding,
                schema_update_options=self.schema_update_options,
                src_fmt_configs=self.src_fmt_configs,
                time_partitioning=self.time_partitioning,
                cluster_fields=self.cluster_fields,
                encryption_configuration=self.encryption_configuration,
                labels=self.labels,
                description=self.description,
            )

    # Optionally query the MAX() of a key column after the load, e.g. to
    # track a high-water mark for incremental loads.
    if self.max_id_key:
        select_command = f'SELECT MAX({self.max_id_key}) FROM `{self.destination_project_dataset_table}`'
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            job_id = bq_hook.run_query(
                sql=select_command,
                use_legacy_sql=False,
            )
        row = list(bq_hook.get_job(job_id).result())
        if row:
            max_id = row[0] if row[0] else 0
            self.log.info(
                'Loaded BQ data with max %s.%s=%s',
                self.destination_project_dataset_table,
                self.max_id_key,
                max_id,
            )
        else:
            raise RuntimeError(f"The {select_command} returned no rows!")
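Since GCSToBigQueryOperator already implements this GCS-to-BigQuery load, the hand-rolled load function above can be replaced with a single declarative task. A minimal sketch follows; the bucket, object, and schema-object paths are placeholders, and it assumes the table schema has been exported as a JSON file in GCS:

from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator

load_to_bq = GCSToBigQueryOperator(
    task_id="gcs_to_bq",
    bucket="my-staging-bucket",  # placeholder bucket
    source_objects=["sales_interactions/{{ ds }}.csv"],
    destination_project_dataset_table="augmented-works-297410.demo_dataset.sales_interactions2",
    source_format="CSV",
    skip_leading_rows=1,
    write_disposition="WRITE_APPEND",
    autodetect=False,
    # Hypothetical JSON schema file in GCS; alternatively pass schema_fields inline.
    schema_object="schemas/sales_interactions.json",
    gcp_conn_id="google_cloud_default",
)

Letting the operator handle staging and loading keeps the schema handling, URI normalization, and deprecation shims shown in execute() out of DAG code.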