Example #1
# Patching the hook class keeps the test from calling the real Dataflow API;
# the patch target below is the DataflowHook imported by the operator module.
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
def test_execute(self, mock_dataflow):
    start_flex_template = DataflowStartFlexTemplateOperator(
        task_id="start_flex_template_streaming_beam_sql",
        body={"launchParameter": TEST_FLEX_PARAMETERS},
        do_xcom_push=True,
        project_id=TEST_PROJECT_ID,
        location=TEST_LOCATION,
    )
    start_flex_template.execute(mock.MagicMock())
    mock_dataflow.return_value.start_flex_template.assert_called_once_with(
        body={"launchParameter": TEST_FLEX_PARAMETERS},
        location=TEST_LOCATION,
        project_id=TEST_PROJECT_ID,
        on_new_job_id_callback=mock.ANY,
    )
Example #2
def test_on_kill(self):
    start_flex_template = DataflowStartFlexTemplateOperator(
        task_id="start_flex_template_streaming_beam_sql",
        body={"launchParameter": TEST_FLEX_PARAMETERS},
        do_xcom_push=True,
        location=TEST_LOCATION,
        project_id=TEST_PROJECT_ID,
    )
    start_flex_template.hook = mock.MagicMock()
    start_flex_template.job_id = JOB_ID
    start_flex_template.on_kill()
    # on_kill should cancel the job that was recorded on the operator.
    start_flex_template.hook.cancel_job.assert_called_once_with(
        job_id=JOB_ID, project_id=TEST_PROJECT_ID
    )
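Both test snippets above rely on module-level fixtures that the excerpts omit. A minimal sketch follows; the values are illustrative, except JOB_ID, whose value appears in the original assertion.

from unittest import mock

from airflow.providers.google.cloud.operators.dataflow import DataflowStartFlexTemplateOperator

# Illustrative fixture values (only JOB_ID is taken from the original assertion).
TEST_PROJECT_ID = "test-project"
TEST_LOCATION = "us-central1"
JOB_ID = "test-dataflow-pipeline-id"
TEST_FLEX_PARAMETERS = {
    "containerSpecGcsPath": "gs://test-bucket/templates/streaming-beam-sql.json",
    "jobName": "test-streaming-beam-sql-job",
    "parameters": {"inputSubscription": "test-subscription"},
}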
Example #3
parameters = {
    # ...earlier entries omitted in this excerpt...
    'outputTableSpec': 'rabbitmq-integration-301417:rabbitmq.payload_history',
    'javascriptTextTransformGcsPath': 'gs://rabbitmq-storage/udf-template/transform.js',
    'javascriptTextTransformFunctionName': 'transform',
}
template = "gs://rabbitmq-storage/templates/rabbitmq-to-bigquery-image-spec.json"

# Define a DAG (directed acyclic graph) of tasks.
# Any task you create within the context manager is automatically added to the
# DAG object.
with models.DAG(
    # The DAG id you will see in the Airflow web UI
    "composer_dataflow_dag",
    default_args=default_args,
    # The interval with which to schedule the DAG
    schedule_interval=timedelta(days=1),  # Override to match your needs
) as dag:
    start_template_job = DataflowStartFlexTemplateOperator(
        # The task id of your job
        task_id="dataflow_operator_rabbitmq_to_bq",
        project_id="rabbitmq-integration-301417",
        location="us-central1",
        wait_until_finished=False,
        body={
            'launchParameter': {
                'jobName': job,
                'containerSpecGcsPath': template,
                'parameters': parameters,
            }
        },
    )
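The DAG above also references job and default_args, which this excerpt does not define. A minimal sketch with illustrative values:

from airflow.utils.dates import days_ago

# Illustrative only; the original file defines its own job name and defaults.
job = "rabbitmq-to-bigquery"
default_args = {
    "start_date": days_ago(1),
    "retries": 1,
    "retry_delay": timedelta(minutes=5),
}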
Example #4
GCS_FLEX_TEMPLATE_TEMPLATE_PATH = os.environ.get(
    'GCP_DATAFLOW_GCS_FLEX_TEMPLATE_TEMPLATE_PATH',
    "gs://INVALID BUCKET NAME/samples/dataflow/templates/streaming-beam-sql.json",
)
BQ_FLEX_TEMPLATE_DATASET = os.environ.get('GCP_DATAFLOW_BQ_FLEX_TEMPLATE_DATASET', 'airflow_dataflow_samples')
BQ_FLEX_TEMPLATE_LOCATION = os.environ.get('GCP_DATAFLOW_BQ_FLEX_TEMPLATE_LOCATION', 'us-west1')

with models.DAG(
    dag_id="example_gcp_dataflow_flex_template_java",
    start_date=datetime(2021, 1, 1),
    catchup=False,
    schedule_interval='@once',  # Override to match your needs
) as dag_flex_template:
    # [START howto_operator_start_template_job]
    start_flex_template = DataflowStartFlexTemplateOperator(
        task_id="start_flex_template_streaming_beam_sql",
        body={
            "launchParameter": {
                "containerSpecGcsPath": GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
                "jobName": DATAFLOW_FLEX_TEMPLATE_JOB_NAME,
                "parameters": {
                    "inputSubscription": PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION,
                    "outputTable": f"{GCP_PROJECT_ID}:{BQ_FLEX_TEMPLATE_DATASET}.streaming_beam_sql",
                },
            }
        },
        do_xcom_push=True,
        location=BQ_FLEX_TEMPLATE_LOCATION,
    )
    # [END howto_operator_start_template_job]
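The DAG also reads GCP_PROJECT_ID, DATAFLOW_FLEX_TEMPLATE_JOB_NAME, and PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION, which are defined elsewhere in the original example file. A sketch in the same environment-variable style; the variable names and defaults below are assumptions:

# Assumed definitions, mirroring the os.environ.get pattern above; defaults are illustrative.
GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project')
DATAFLOW_FLEX_TEMPLATE_JOB_NAME = os.environ.get('GCP_DATAFLOW_FLEX_TEMPLATE_JOB_NAME', 'dataflow-flex-template')
PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION = os.environ.get(
    'GCP_DATAFLOW_PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION', 'dataflow-flex-template-subscription'
)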
Example #5
def build_load_worldpop_dag(
    dag_id,
    output_bucket,
    countries,
    large_countries,
    destination_dataset_project_id,
    destination_dataset_name,
    destination_table_name,
    staging_dataset_project_id,
    staging_dataset_name,
    dataflow_template_path,
    dataflow_environment,
    notification_emails=None,
    load_start_date=datetime(2000, 1, 1),
    load_schedule_interval="0 0 * * *",
    load_max_active_runs=None,
    load_concurrency=None,
    load_retries=5,
    load_retry_delay=300,
    output_path_prefix="export",
    **kwargs,
):
    if not output_bucket:
        raise ValueError("output_bucket is required")
    if not destination_dataset_project_id:
        raise ValueError("destination_dataset_project_id is required")
    if not destination_dataset_name:
        raise ValueError("destination_dataset_name is required")
    if not destination_table_name:
        raise ValueError("destination_table_name is required")

    default_dag_args = {
        "depends_on_past": False,
        "start_date": load_start_date,
        "end_date": None,
        "email_on_failure": True,
        "email_on_retry": False,
        "retries": load_retries,
        "retry_delay": timedelta(seconds=load_retry_delay),
    }

    if notification_emails and len(notification_emails) > 0:
        default_dag_args["email"] = [
            email.strip() for email in notification_emails.split(",")
        ]

    if load_max_active_runs is None:
        load_max_active_runs = configuration.conf.getint(
            "core", "max_active_runs_per_dag")

    dag = models.DAG(
        dag_id,
        schedule_interval=load_schedule_interval,
        max_active_runs=load_max_active_runs,
        concurrency=load_concurrency,
        default_args=default_dag_args,
        is_paused_upon_creation=True,
    )

    dags_folder = os.environ.get("DAGS_FOLDER", "/home/airflow/gcs/dags")

    def read_bigquery_schema(schema):
        schema_path = os.path.join(
            dags_folder,
            "resources/stages/load/schemas/{schema}.json".format(
                schema=schema),
        )
        return read_bigquery_schema_from_file(schema_path)

    def load_task(country, **context):
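        # Loads the per-country, per-year Parquet export from GCS into a
        # year-partitioned staging table via a BigQuery load job.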
        client = bigquery.Client()
        job_config = bigquery.LoadJobConfig()
        job_config.schema = read_bigquery_schema("world_pop")
        job_config.source_format = bigquery.SourceFormat.PARQUET
        job_config.write_disposition = "WRITE_TRUNCATE"
        job_config.ignore_unknown_values = True
        job_config.range_partitioning = RangePartitioning(
            field="year",
            range_=PartitionRange(start=1900, end=2100, interval=1),
        )

        execution_date = context["execution_date"]
        load_table_name = "{table}_{country}_{year}".format(
            table=destination_table_name,
            country=country,
            year=execution_date.strftime("%Y"),
        )
        table_ref = create_dataset(
            client,
            staging_dataset_name,
            project=staging_dataset_project_id,
        ).table(load_table_name)

        load_uri = "gs://{bucket}/{prefix}/world_pop/year={year}/parquet/{country}_{year}.parquet".format(
            bucket=output_bucket,
            prefix=output_path_prefix,
            country=country,
            year=execution_date.strftime("%Y"),
        )
        load_job = client.load_table_from_uri(
            load_uri,
            table_ref,
            job_config=job_config,
        )
        submit_bigquery_job(load_job, job_config)
        assert load_job.state == "DONE"

    def merge_task(country, **context):
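        # Creates the partitioned, clustered destination table if it does not
        # exist, then merges the per-country staging table into it using the
        # templated merge_worldpop.sql statement.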
        client = bigquery.Client()

        table_ref = create_dataset(
            client,
            destination_dataset_name,
            project=destination_dataset_project_id,
        ).table(destination_table_name)
        if not does_table_exist(client, table_ref):
            table = bigquery.Table(table_ref,
                                   schema=read_bigquery_schema("world_pop"))
            table.range_partitioning = RangePartitioning(
                field="year",
                range_=PartitionRange(start=1900, end=2100, interval=1),
            )
            table.clustering_fields = [
                "geography",
                "geography_polygon",
                "country",
            ]
            client.create_table(table)

        job_config = bigquery.QueryJobConfig()
        job_config.priority = bigquery.QueryPriority.INTERACTIVE

        sql_path = os.path.join(
            dags_folder, "resources/stages/load/sqls/merge_worldpop.sql")
        sql_template = read_file(sql_path)

        execution_date = context["execution_date"]
        year = execution_date.strftime("%Y")
        staging_table_name = "{table}_{country}_{year}".format(
            table=destination_table_name, country=country, year=year)

        template_context = {
            "year": year,
            "country": country,
            "source_table": staging_table_name,
            "source_project_id": staging_dataset_project_id,
            "source_dataset_name": staging_dataset_name,
            "destination_table": destination_table_name,
            "destination_dataset_project_id": destination_dataset_project_id,
            "destination_dataset_name": destination_dataset_name,
        }

        sql = context["task"].render_template(sql_template, template_context)
        job = client.query(sql, location="US", job_config=job_config)
        submit_bigquery_job(job, job_config)
        assert job.state == "DONE"

    priority = len(countries.split(","))  # one priority level per country
    for country in countries.split(","):
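        # Per-country pipeline: wait for the exported Parquet file in GCS, load
        # it into staging (a Dataflow Flex Template for large countries, a
        # direct BigQuery load job otherwise), then merge the staging table
        # into the destination table. Priority decreases so earlier countries
        # are scheduled first.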
        c = country.lower()
        wait_uri = "{prefix}/world_pop/year={year}/parquet/{country}_{year}.parquet".format(
            prefix=output_path_prefix,
            country=country,
            year='{{ execution_date.strftime("%Y") }}',
        )
        wait_gcs = GoogleCloudStorageObjectSensor(
            task_id=f"wait_{c}",
            timeout=60 * 60,
            poke_interval=60,
            bucket=output_bucket,
            object=wait_uri,
            weight_rule="upstream",
            priority_weight=priority,
            dag=dag,
        )
        if country in large_countries:
            input_file = "gs://{bucket}/{prefix}/world_pop/year={year}/parquet/{country}_{year}.parquet".format(
                bucket=output_bucket,
                prefix=output_path_prefix,
                country=country,
                year='{{ execution_date.strftime("%Y") }}',
            )
            output_table = "{table}_{country}_{year}".format(
                table=destination_table_name,
                country=country,
                year='{{ execution_date.strftime("%Y") }}',
            )
            load_operator = DataflowStartFlexTemplateOperator(
                task_id=f"run_dataflow_load_{c}",
                body={
                    "launchParameter": {
                        "containerSpecGcsPath": f"{dataflow_template_path}/load-parquet-0.1.0.json",
                        "jobName": f"load-parquet-{c}"
                        + '-{{ execution_date.strftime("%Y%m%d-%H%M%S") }}',
                        "parameters": {
                            "input-file": input_file,
                            "output-table": f"{staging_dataset_project_id}:{staging_dataset_name}.{output_table}",
                            "output-schema": "/dataflow/template/schemas/world_pop.json",
                        },
                        "environment": dataflow_environment,
                    }
                },
                location="us-central1",
                wait_until_finished=True,
                dag=dag,
            )
        else:
            load_operator = PythonOperator(
                task_id=f"load_{c}",
                python_callable=load_task,
                execution_timeout=timedelta(minutes=600),
                provide_context=True,
                op_kwargs={"country": country},
                retries=1,
                retry_delay=timedelta(seconds=300),
                weight_rule="upstream",
                priority_weight=priority,
                dag=dag,
            )
        merge_operator = PythonOperator(
            task_id=f"merge_{c}",
            python_callable=merge_task,
            execution_timeout=timedelta(minutes=600),
            provide_context=True,
            op_kwargs={"country": country},
            retries=1,
            retry_delay=timedelta(seconds=300),
            weight_rule="upstream",
            priority_weight=priority,
            dag=dag,
        )
        priority -= 1
        wait_gcs >> load_operator >> merge_operator

    return dag
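
A hypothetical invocation of the factory; every identifier and value below is illustrative. Assigning the returned DAG to a module-level name is what lets the Airflow scheduler pick it up.

# All values are illustrative; adjust to your own project, buckets, and datasets.
load_worldpop_dag = build_load_worldpop_dag(
    dag_id="load_worldpop",
    output_bucket="example-worldpop-bucket",
    countries="USA,GBR,BGD",
    large_countries=["USA"],
    destination_dataset_project_id="example-project",
    destination_dataset_name="world_pop",
    destination_table_name="population",
    staging_dataset_project_id="example-project",
    staging_dataset_name="world_pop_staging",
    dataflow_template_path="gs://example-bucket/dataflow/templates",
    dataflow_environment={"tempLocation": "gs://example-bucket/dataflow/temp"},
    notification_emails="alerts@example.com",
)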