Ejemplo n.º 1
0
    def __init__(self, component: base_component.BaseComponent,
                 tfx_pipeline: pipeline.Pipeline):
        """Initialize the _ComponentAsDoFn.

        Builds a private copy of the pipeline's additional args in which any
        user-supplied Beam ``--job_name`` is replaced by a generated one of the
        form ``<pipeline_name>-<component_id>-<unix_seconds>`` (lowercased,
        with runs of characters outside ``[0-9a-zA-Z-]`` collapsed to ``-``).
        It then wraps ``component`` in a ``ComponentLauncher`` that uses
        those args.

        Args:
          component: Component to be executed.
          tfx_pipeline: Logical pipeline that contains pipeline related
            information.
        """
        driver_args = data_types.DriverArgs(
            enable_cache=tfx_pipeline.enable_cache)

        # Copy so the injected job name does not leak into the shared pipeline
        # object; a missing/None mapping is treated as empty.
        self._additional_pipeline_args = dict(
            tfx_pipeline.additional_pipeline_args or {})

        # Unique per component and per second. The sanitizing presumably
        # targets runner job-name restrictions (e.g. Dataflow) -- confirm.
        job_name = re.sub(
            r'[^0-9a-zA-Z-]+', '-', '{pipeline_name}-{component}-{ts}'.format(
                pipeline_name=tfx_pipeline.pipeline_info.pipeline_name,
                component=component.component_id,
                ts=int(datetime.datetime.now().timestamp())).lower())

        # Strip any existing job name, in both the '--job_name=value' and the
        # two-token '--job_name value' forms. Removing only the flag token
        # would leave its value behind as a stray positional argument, and a
        # bare prefix match would also drop unrelated '--job_name*' flags.
        existing_args = (
            self._additional_pipeline_args.get('beam_pipeline_args') or [])
        beam_args = []
        drop_value = False
        for arg in existing_args:
            if drop_value:
                drop_value = False
                if not arg.startswith('-'):
                    # Value belonging to a preceding bare '--job_name'.
                    continue
            if arg == '--job_name':
                drop_value = True
            elif not arg.startswith('--job_name='):
                beam_args.append(arg)
        beam_args.append('--job_name={}'.format(job_name))
        self._additional_pipeline_args['beam_pipeline_args'] = beam_args

        self._component_launcher = component_launcher.ComponentLauncher(
            component=component,
            pipeline_info=tfx_pipeline.pipeline_info,
            driver_args=driver_args,
            metadata_connection_config=tfx_pipeline.metadata_connection_config,
            additional_pipeline_args=self._additional_pipeline_args)
        self._component_id = component.component_id
Ejemplo n.º 2
0
def _airflow_component_launcher(
        component: base_component.BaseComponent,
        pipeline_info: data_types.PipelineInfo,
        driver_args: data_types.DriverArgs,
        metadata_connection_config: metadata_store_pb2.ConnectionConfig,
        additional_pipeline_args: Dict[Text, Any], **kwargs) -> None:
    """Launch one TFX component execution from inside an Airflow task.

    Airflow calls this with its context objects. The DAG run id is read from
    the task instance, stored on ``pipeline_info``, and a ``ComponentLauncher``
    built from the remaining arguments is launched.

    Args:
      component: TFX BaseComponent instance. This instance holds all inputs and
        outputs placeholders as well as component properties.
      pipeline_info: a data_types.PipelineInfo instance that holds pipeline
        properties. Its ``run_id`` is overwritten in place.
      driver_args: component specific args for driver.
      metadata_connection_config: configuration for how to connect to metadata.
      additional_pipeline_args: a dict of additional pipeline args. Currently
        supporting following keys: beam_pipeline_args.
      **kwargs: Context arguments that will be passed in by Airflow, including:
        - ti: TaskInstance object from which we can get run_id of the running
          pipeline.
        For more details, please refer to the code:
        https://github.com/apache/airflow/blob/master/airflow/operators/python_operator.py
    """
    task_instance = kwargs['ti']
    # The launcher receives the run id only through pipeline_info.
    pipeline_info.run_id = task_instance.get_dagrun().run_id
    component_launcher.ComponentLauncher(
        component=component,
        pipeline_info=pipeline_info,
        driver_args=driver_args,
        metadata_connection_config=metadata_connection_config,
        additional_pipeline_args=additional_pipeline_args,
    ).launch()
Ejemplo n.º 3
0
    def test_run(self, mock_publisher):
        """Runs an Avro FileBasedExampleGen end to end via ComponentLauncher.

        Publishing goes through ``mock_publisher`` (presumably patched in by a
        decorator not visible here). The test checks:

        * the launcher's recorded component type;
        * that exactly one execution is published;
        * that gzipped TFRecord outputs exist for both splits, with the train
          file larger than the eval file.
        """
        publisher = mock_publisher.return_value
        publisher.publish_execution.return_value = {}

        example_gen = FileBasedExampleGen(
            executor_class=avro_executor.Executor,
            input_base=external_input(self.avro_dir_path),
            input_config=self.input_config,
            output_config=self.output_config,
            name='AvroExampleGenComponent')

        base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
        pipeline_root = os.path.join(base_dir, self._testMethodName, 'Test')
        tf.gfile.MakeDirs(pipeline_root)

        info = data_types.PipelineInfo(pipeline_name='Test',
                                       pipeline_root=pipeline_root,
                                       run_id='123')
        cache_args = data_types.DriverArgs(enable_cache=True)

        # In-memory SQLite metadata store.
        mlmd_config = metadata_store_pb2.ConnectionConfig()
        mlmd_config.sqlite.SetInParent()

        launcher = component_launcher.ComponentLauncher(
            component=example_gen,
            pipeline_info=info,
            driver_args=cache_args,
            metadata_connection_config=mlmd_config,
            additional_pipeline_args={})
        expected_type = '{}.{}'.format(FileBasedExampleGen.__module__,
                                       FileBasedExampleGen.__name__)
        self.assertEqual(launcher._component_info.component_type,
                         expected_type)

        launcher.launch()
        publisher.publish_execution.assert_called_once()

        # Output location: <root>/<component_name>.<name>/examples/1/<split>.
        examples_dir = os.path.join(
            pipeline_root,
            '.'.join([example_gen.component_name, example_gen.name]),
            'examples/1')
        split_artifacts = []
        for split in ('train', 'eval'):
            artifact = types.TfxArtifact(type_name='ExamplesPath', split=split)
            artifact.uri = os.path.join(examples_dir, split)
            split_artifacts.append(artifact)

        train_file, eval_file = [
            os.path.join(artifact.uri, 'data_tfrecord-00000-of-00001.gz')
            for artifact in split_artifacts
        ]
        for path in (train_file, eval_file):
            self.assertTrue(tf.gfile.Exists(path))
        self.assertGreater(
            tf.gfile.GFile(train_file).size(),
            tf.gfile.GFile(eval_file).size())
Ejemplo n.º 4
0
  def __init__(self, component: base_component.BaseComponent,
               tfx_pipeline: pipeline.Pipeline):
    """Initialize the _ComponentAsDoFn.

    Wraps ``component`` in a ComponentLauncher. The launcher is configured
    from the pipeline's info, cache flag, metadata connection config and
    additional pipeline args.

    Args:
      component: Component to be executed.
      tfx_pipeline: Logical pipeline that contains pipeline related information.
    """
    cache_driver_args = data_types.DriverArgs(
        enable_cache=tfx_pipeline.enable_cache)
    launcher = component_launcher.ComponentLauncher(
        component=component,
        pipeline_info=tfx_pipeline.pipeline_info,
        driver_args=cache_driver_args,
        metadata_connection_config=tfx_pipeline.metadata_connection_config,
        additional_pipeline_args=tfx_pipeline.additional_pipeline_args)
    self._component_launcher = launcher
    self._component_id = component.component_id
Ejemplo n.º 5
0
    def test_run(self, mock_publisher):
        """Launches a _FakeComponent and checks its output file.

        A single input file containing ``'test'`` is written and fed to the
        component. After launch, the test expects ``<pipeline_root>/output``
        to exist with the same contents. ``mock_publisher`` is presumably
        patched in by a decorator not visible here.
        """
        mock_publisher.return_value.publish_execution.return_value = {}

        test_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # In-memory SQLite metadata store.
        mlmd_config = metadata_store_pb2.ConnectionConfig()
        mlmd_config.sqlite.SetInParent()

        pipeline_root = os.path.join(test_dir, 'Test')
        source_file = os.path.join(test_dir, 'input')
        tf.gfile.MakeDirs(os.path.dirname(source_file))
        file_io.write_string_to_file(source_file, 'test')

        source_artifact = types.TfxArtifact(type_name='InputPath')
        source_artifact.uri = source_file

        fake = _FakeComponent(
            name='FakeComponent',
            input_channel=channel.as_channel([source_artifact]))

        info = data_types.PipelineInfo(pipeline_name='Test',
                                       pipeline_root=pipeline_root,
                                       run_id='123')

        fake_id = fake.component_id
        worker_args = data_types.DriverArgs(
            worker_name=fake_id,
            base_output_dir=os.path.join(pipeline_root, fake_id),
            enable_cache=True)

        launcher = component_launcher.ComponentLauncher(
            component=fake,
            pipeline_info=info,
            driver_args=worker_args,
            metadata_connection_config=mlmd_config,
            additional_pipeline_args={})
        expected_type = '{}.{}'.format(_FakeComponent.__module__,
                                       _FakeComponent.__name__)
        self.assertEqual(launcher._component_info.component_type,
                         expected_type)
        launcher.launch()

        result_file = os.path.join(pipeline_root, 'output')
        self.assertTrue(tf.gfile.Exists(result_file))
        self.assertEqual('test', file_io.read_file_to_string(result_file))
Ejemplo n.º 6
0
  def __init__(self, component: base_component.BaseComponent,
               tfx_pipeline: pipeline.Pipeline):
    """Initialize the _ComponentAsDoFn.

    Builds DriverArgs that name the worker after the component id and place
    its outputs under ``<pipeline_root>/<component_id>/``. It then wraps
    ``component`` in a ComponentLauncher configured from the pipeline.

    Args:
      component: Component to be executed.
      tfx_pipeline: Logical pipeline that contains pipeline related information.
    """
    info = tfx_pipeline.pipeline_info
    worker_id = component.component_id
    # TODO(ruoyu): refactor DriverArgs, drop duplicated base dir and name.
    # The trailing '' component makes the directory path end with a separator.
    output_dir = os.path.join(info.pipeline_root, worker_id, '')
    driver_args = data_types.DriverArgs(
        worker_name=worker_id,
        base_output_dir=output_dir,
        enable_cache=tfx_pipeline.enable_cache)
    self._component_launcher = component_launcher.ComponentLauncher(
        component=component,
        pipeline_info=info,
        driver_args=driver_args,
        metadata_connection_config=tfx_pipeline.metadata_connection_config,
        additional_pipeline_args=tfx_pipeline.additional_pipeline_args)
    self._name = component.component_name