Example #1
0
def _airflow_component_launcher(
        component: base_component.BaseComponent,
        pipeline_info: data_types.PipelineInfo,
        driver_args: data_types.DriverArgs,
        metadata_connection_config: metadata_store_pb2.ConnectionConfig,
        additional_pipeline_args: Dict[Text, Any],
        **kwargs) -> None:
    """Run one TFX component from inside an Airflow task.

    Airflow calls this helper with its execution context in ``kwargs``. The
    run id of the current DAG run is copied onto ``pipeline_info``, then a
    ``component_launcher.ComponentLauncher`` is built for ``component`` and
    launched.

    Args:
      component: TFX BaseComponent instance; holds the input and output
        placeholders as well as the component properties.
      pipeline_info: data_types.PipelineInfo holding pipeline properties.
        Its ``run_id`` attribute is overwritten by this call.
      driver_args: Component-specific arguments for the driver.
      metadata_connection_config: Configuration describing how to connect
        to the metadata store.
      additional_pipeline_args: Dict of additional pipeline args; currently
        documented to support the ``beam_pipeline_args`` key.
      **kwargs: Airflow context arguments. Must include ``ti``, the
        TaskInstance whose ``get_dagrun().run_id`` identifies the running
        pipeline. See
        https://github.com/apache/airflow/blob/master/airflow/operators/python_operator.py

    Raises:
      KeyError: If ``kwargs`` does not contain ``ti``.
    """
    # The TaskInstance gives access to the DagRun, whose run_id tags this
    # pipeline execution.
    task_instance = kwargs['ti']
    pipeline_info.run_id = task_instance.get_dagrun().run_id

    component_launcher.ComponentLauncher(
        component=component,
        pipeline_info=pipeline_info,
        driver_args=driver_args,
        metadata_connection_config=metadata_connection_config,
        additional_pipeline_args=additional_pipeline_args,
    ).launch()
Example #2
0
def _airflow_component_launcher(
        component: base_node.BaseNode,
        component_launcher_class: Type[
            base_component_launcher.BaseComponentLauncher],
        pipeline_info: data_types.PipelineInfo,
        driver_args: data_types.DriverArgs,
        metadata_connection_config: metadata_store_pb2.ConnectionConfig,
        beam_pipeline_args: List[Text],
        additional_pipeline_args: Dict[Text, Any],
        component_config: base_component_config.BaseComponentConfig,
        exec_properties: Dict[Text, Any],
        **kwargs) -> None:
    """Run one TFX component from inside an Airflow task.

    The helper does four things in order:

    1. It merges ``exec_properties`` into ``component.exec_properties``.
    2. It copies the current Airflow run id onto ``pipeline_info``.
    3. It builds a launcher through ``component_launcher_class.create``.
    4. It launches that launcher under the ``'airflow'`` runner telemetry
       label.

    Args:
      component: TFX node (``base_node.BaseNode``) to execute. Its
        ``exec_properties`` dict is updated in place.
      component_launcher_class: Launcher class whose ``create`` factory
        builds the launcher for ``component``.
      pipeline_info: data_types.PipelineInfo holding pipeline properties.
        Its ``run_id`` attribute is overwritten by this call.
      driver_args: Component-specific arguments for the driver.
      metadata_connection_config: Configuration used to construct the
        ``metadata.Metadata`` handle passed to the launcher.
      beam_pipeline_args: Pipeline arguments for Beam-powered components.
      additional_pipeline_args: Dict of additional pipeline args.
      component_config: Component config used to launch the component.
      exec_properties: Execution properties from the ComponentSpec; merged
        into ``component.exec_properties`` before launching.
      **kwargs: Airflow context arguments. Must include ``ti``, the
        TaskInstance whose ``get_dagrun().run_id`` identifies the running
        pipeline. See
        https://github.com/apache/airflow/blob/master/airflow/operators/python_operator.py

    Raises:
      KeyError: If ``kwargs`` does not contain ``ti``.
    """
    # Runtime-supplied properties must be on the component before the
    # launcher is created from it.
    component.exec_properties.update(exec_properties)

    # The TaskInstance gives access to the DagRun, whose run_id tags this
    # pipeline execution.
    dag_run = kwargs['ti'].get_dagrun()
    pipeline_info.run_id = dag_run.run_id

    metadata_handle = metadata.Metadata(metadata_connection_config)
    launcher = component_launcher_class.create(
        component=component,
        pipeline_info=pipeline_info,
        driver_args=driver_args,
        metadata_connection=metadata_handle,
        beam_pipeline_args=beam_pipeline_args,
        additional_pipeline_args=additional_pipeline_args,
        component_config=component_config)

    # Tag everything emitted during the launch as coming from the Airflow
    # runner.
    runner_labels = {telemetry_utils.LABEL_TFX_RUNNER: 'airflow'}
    with telemetry_utils.scoped_labels(runner_labels):
        launcher.launch()