Example #1
def launch_container_component(
        component: base_node.BaseNode, component_launcher_class: Type[
            base_component_launcher.BaseComponentLauncher],
        component_config: base_component_config.BaseComponentConfig,
        pipeline: tfx_pipeline.Pipeline):
    """Use the kubernetes component launcher to launch the component.

  Args:
    component: Container component to be executed.
    component_launcher_class: The class of the launcher to launch the component.
    component_config: component config to launch the component.
    pipeline: Logical pipeline that contains pipeline related information.
  """
    driver_args = data_types.DriverArgs(enable_cache=pipeline.enable_cache)
    metadata_connection = metadata.Metadata(
        pipeline.metadata_connection_config)

    component_launcher = component_launcher_class.create(
        component=component,
        pipeline_info=pipeline.pipeline_info,
        driver_args=driver_args,
        metadata_connection=metadata_connection,
        beam_pipeline_args=pipeline.beam_pipeline_args,
        additional_pipeline_args=pipeline.additional_pipeline_args,
        component_config=component_config)
    logging.info('Component %s is running.', component.id)
    component_launcher.launch()
    logging.info('Component %s is finished.', component.id)
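
A hedged sketch of a call site for the function above; the component, config, and pipeline objects are assumptions, and the launcher class path follows Example #3:

# Minimal usage sketch (names below are illustrative assumptions).
launch_container_component(
    component=my_component,                # a base_node.BaseNode
    component_launcher_class=(
        kubernetes_component_launcher.KubernetesComponentLauncher),
    component_config=my_component_config,  # a BaseComponentConfig
    pipeline=my_pipeline)                  # a tfx_pipeline.Pipeline
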
Example #2
    def __init__(self, component: base_component.BaseComponent,
                 tfx_pipeline: pipeline.Pipeline):
        """Initialize the _ComponentAsDoFn.

        Args:
          component: Component to be executed.
          tfx_pipeline: Logical pipeline that contains pipeline-related
            information.
        """
        driver_args = data_types.DriverArgs(
            enable_cache=tfx_pipeline.enable_cache)
        self._additional_pipeline_args = (
            tfx_pipeline.additional_pipeline_args.copy())
        _job_name = re.sub(
            r'[^0-9a-zA-Z-]+', '-', '{pipeline_name}-{component}-{ts}'.format(
                pipeline_name=tfx_pipeline.pipeline_info.pipeline_name,
                component=component.component_id,
                ts=int(datetime.datetime.timestamp(
                    datetime.datetime.now()))).lower())
        self._additional_pipeline_args['beam_pipeline_args'] = [
            arg for arg in self._additional_pipeline_args.setdefault(
                'beam_pipeline_args', []) if not arg.startswith('--job_name')
        ]
        self._additional_pipeline_args['beam_pipeline_args'].append(
            '--job_name={}'.format(_job_name))

        self._component_launcher = component_launcher.ComponentLauncher(
            component=component,
            pipeline_info=tfx_pipeline.pipeline_info,
            driver_args=driver_args,
            metadata_connection_config=tfx_pipeline.metadata_connection_config,
            additional_pipeline_args=self._additional_pipeline_args)
        self._component_id = component.component_id
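
The job-name logic above lowercases the formatted string and collapses every run of characters outside [0-9a-zA-Z-] into a single hyphen, producing a name that is safe for Beam runners such as Dataflow. A standalone illustration with made-up values:

import re

raw = '{pipeline_name}-{component}-{ts}'.format(
    pipeline_name='my_pipeline', component='CsvExampleGen', ts=1600000000)
job_name = re.sub(r'[^0-9a-zA-Z-]+', '-', raw.lower())
print(job_name)  # my-pipeline-csvexamplegen-1600000000
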
Example #3
    def _create_launcher_context(self, component_config=None):
        test_dir = self.get_temp_dir()

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()

        pipeline_root = os.path.join(test_dir, 'Test')

        input_artifact = test_utils._InputArtifact()
        input_artifact.uri = os.path.join(test_dir, 'input')

        component = test_utils._FakeComponent(
            name='FakeComponent',
            input_channel=channel_utils.as_channel([input_artifact]),
            custom_executor_spec=executor_spec.ExecutorContainerSpec(
                image='gcr://test', args=['{{input_dict["input"][0].uri}}']))

        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(enable_cache=True)

        launcher = kubernetes_component_launcher.KubernetesComponentLauncher.create(
            component=component,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection_config=connection_config,
            beam_pipeline_args=[],
            additional_pipeline_args={},
            component_config=component_config)

        return {'launcher': launcher, 'input_artifact': input_artifact}
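
A test method would then unpack the returned dict; a minimal sketch of how this fixture might be consumed (assumes it runs inside the same test class, with the Kubernetes client mocked out before launch() is called):

context = self._create_launcher_context()
input_uri = context['input_artifact'].uri  # resolved into the container args
context['launcher'].launch()               # drives the component end to end
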
Example #4
    def __init__(
            self, parent_dag: models.DAG,
            component: base_component.BaseComponent,
            pipeline_info: data_types.PipelineInfo, enable_cache: bool,
            metadata_connection_config: metadata_store_pb2.ConnectionConfig,
            additional_pipeline_args: Dict[Text, Any]):
        """Constructs an Airflow implementation of TFX component.

        Args:
          parent_dag: An AirflowPipeline instance as the pipeline DAG.
          component: An instance of base_component.BaseComponent that holds all
            properties of a logical component.
          pipeline_info: An instance of data_types.PipelineInfo that holds
            pipeline properties.
          enable_cache: Whether or not cache is enabled for this component run.
          metadata_connection_config: A config proto for metadata connection.
          additional_pipeline_args: Additional pipeline args.
        """
        # Prepare parameters to create TFX worker.
        driver_args = data_types.DriverArgs(enable_cache=enable_cache)

        super(AirflowComponent, self).__init__(
            task_id=component.component_id,
            provide_context=True,
            python_callable=functools.partial(
                _airflow_component_launcher,
                component=component,
                pipeline_info=pipeline_info,
                driver_args=driver_args,
                metadata_connection_config=metadata_connection_config,
                additional_pipeline_args=additional_pipeline_args),
            dag=parent_dag)
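
The constructor defers all launch details to functools.partial, which pre-binds the TFX-specific keyword arguments so that Airflow can later invoke the callable with only its own context kwargs. The binding behavior in isolation, using placeholder names:

import functools

def _launcher(component, driver_args, **airflow_context):
    print(component, driver_args, sorted(airflow_context))

bound = functools.partial(_launcher, component='c1', driver_args='d1')
bound(ti=None, ds='2021-01-01')  # prints: c1 d1 ['ds', 'ti']
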
Example #5
    def __init__(self, component: base_component.BaseComponent,
                 component_launcher_class: Type[
                     base_component_launcher.BaseComponentLauncher],
                 component_config: base_component_config.BaseComponentConfig,
                 tfx_pipeline: pipeline.Pipeline):
        """Initialize the _ComponentAsDoFn.

        Args:
          component: Component to be executed.
          component_launcher_class: The class of the launcher used to launch
            the component.
          component_config: Component config used to launch the component.
          tfx_pipeline: Logical pipeline that contains pipeline-related
            information.
        """
        enable_cache = (component.enable_cache if component.enable_cache
                        is not None else tfx_pipeline.enable_cache)
        driver_args = data_types.DriverArgs(enable_cache=enable_cache)
        metadata_connection = metadata.Metadata(
            tfx_pipeline.metadata_connection_config)
        self._component_launcher = component_launcher_class.create(
            component=component,
            pipeline_info=tfx_pipeline.pipeline_info,
            driver_args=driver_args,
            metadata_connection=metadata_connection,
            beam_pipeline_args=tfx_pipeline.beam_pipeline_args,
            additional_pipeline_args=tfx_pipeline.additional_pipeline_args,
            component_config=component_config)
        self._component_id = component.id
Example #6
    def run(self,
            component: base_component.BaseComponent,
            enable_cache: bool = True) -> execution_result.ExecutionResult:
        """Run a given TFX component in the interactive context.

        Args:
          component: Component instance to be run.
          enable_cache: Whether caching logic should be enabled in the driver.

        Returns:
          An execution_result.ExecutionResult object.
        """
        run_id = datetime.datetime.now().isoformat()
        pipeline_info = data_types.PipelineInfo(
            pipeline_name=self.pipeline_name,
            pipeline_root=self.pipeline_root,
            run_id=run_id)
        driver_args = data_types.DriverArgs(enable_cache=enable_cache,
                                            interactive_resolution=True)
        additional_pipeline_args = {}
        for name, output in component.outputs.get_all().items():
            for artifact in output.get():
                artifact.pipeline_name = self.pipeline_name
                artifact.producer_component = component.component_id
                artifact.run_id = run_id
                artifact.name = name
        # TODO(hongyes): figure out how to resolve launcher class in the interactive
        # context.
        launcher = in_process_component_launcher.InProcessComponentLauncher.create(
            component, pipeline_info, driver_args,
            self.metadata_connection_config, additional_pipeline_args)
        execution_id = launcher.launch()

        return execution_result.ExecutionResult(component=component,
                                                execution_id=execution_id)
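
In a notebook this method is reached through an InteractiveContext instance; a hedged sketch (the context and component objects are assumptions):

result = context.run(example_gen, enable_cache=False)
result.execution_id  # the id returned by launcher.launch()
result.component     # the component instance that was run
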
Example #7
    def setUp(self):
        super(StubComponentLauncherTest, self).setUp()
        test_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        self.metadata_connection = metadata.Metadata(connection_config)

        self.pipeline_root = os.path.join(test_dir, 'Test')
        self.input_dir = os.path.join(test_dir, 'input')
        self.output_dir = os.path.join(test_dir, 'output')
        self.record_dir = os.path.join(test_dir, 'record')
        tf.io.gfile.makedirs(self.input_dir)
        tf.io.gfile.makedirs(self.output_dir)
        tf.io.gfile.makedirs(self.record_dir)

        input_artifact = test_utils._InputArtifact()  # pylint: disable=protected-access
        input_artifact.uri = os.path.join(self.input_dir, 'result.txt')
        output_artifact = test_utils._OutputArtifact()  # pylint: disable=protected-access
        output_artifact.uri = os.path.join(self.output_dir, 'result.txt')
        self.component = test_utils._FakeComponent(  # pylint: disable=protected-access
            name='FakeComponent',
            input_channel=channel_utils.as_channel([input_artifact]),
            output_channel=channel_utils.as_channel([output_artifact]))
        self.driver_args = data_types.DriverArgs(enable_cache=True)

        self.pipeline_info = data_types.PipelineInfo(
            pipeline_name='Test',
            pipeline_root=self.pipeline_root,
            run_id='123')
Example #8
    def setUp(self):
        super(ImporterDriverTest, self).setUp()
        self.connection_config = metadata_store_pb2.ConnectionConfig()
        self.connection_config.sqlite.SetInParent()
        self.output_dict = {
            importer_node.IMPORT_RESULT_KEY:
            types.Channel(type=standard_artifacts.Examples)
        }
        self.source_uri = 'm/y/u/r/i'
        self.properties = {
            'split_names': artifact_utils.encode_split_names(['train', 'eval'])
        }
        self.custom_properties = {
            'string_custom_property': 'abc',
            'int_custom_property': 123,
        }

        self.existing_artifacts = []
        existing_artifact = standard_artifacts.Examples()
        existing_artifact.uri = self.source_uri
        existing_artifact.split_names = self.properties['split_names']
        self.existing_artifacts.append(existing_artifact)

        self.pipeline_info = data_types.PipelineInfo(pipeline_name='p_name',
                                                     pipeline_root='p_root',
                                                     run_id='run_id')
        self.component_info = data_types.ComponentInfo(
            component_type='c_type',
            component_id='c_id',
            pipeline_info=self.pipeline_info)
        self.driver_args = data_types.DriverArgs(enable_cache=True)
Example #9
 def setUp(self):
     self._mock_metadata = tf.test.mock.Mock()
     self._input_dict = {
         'input_data': [types.TfxArtifact(type_name='InputType')],
     }
     input_dir = os.path.join(
         os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
         self._testMethodName, 'input_dir')
     # Valid input artifacts must have a URI pointing to an existing directory.
     for key, input_list in self._input_dict.items():
         for index, artifact in enumerate(input_list):
             artifact.id = index + 1
             uri = os.path.join(input_dir, key, str(artifact.id), '')
             artifact.uri = uri
             tf.gfile.MakeDirs(uri)
     self._output_dict = {
         'output_data': [types.TfxArtifact(type_name='OutputType')],
     }
     self._exec_properties = {
         'key': 'value',
     }
     self._base_output_dir = os.path.join(
         os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
         self._testMethodName, 'base_output_dir')
     self._driver_args = data_types.DriverArgs(
         worker_name='worker_name',
         base_output_dir=self._base_output_dir,
         enable_cache=True)
     self._execution_id = 100
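
Note the API drift visible across these snippets: this older example constructs DriverArgs with worker_name and base_output_dir, while the newer examples pass only enable_cache (plus interactive_resolution in the interactive context). Side by side:

# Older signature (as in Examples #9, #19, #24 and #25):
driver_args = data_types.DriverArgs(
    worker_name='worker_name',
    base_output_dir='/tmp/base_output_dir',
    enable_cache=True)

# Newer signature (most other examples here):
driver_args = data_types.DriverArgs(enable_cache=True)
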
Example #10
    def setUp(self):
        super(ImporterDriverTest, self).setUp()
        self.connection_config = metadata_store_pb2.ConnectionConfig()
        self.connection_config.sqlite.SetInParent()
        self.output_dict = {
            importer_node.IMPORT_RESULT_KEY:
            types.Channel(type=standard_artifacts.Examples)
        }
        self.source_uri = ['m/y/u/r/i/1', 'm/y/u/r/i/2']
        self.split = ['train', 'eval']

        self.existing_artifacts = []
        for uri, split in zip(self.source_uri, self.split):
            existing_artifact = standard_artifacts.Examples()
            existing_artifact.uri = uri
            existing_artifact.split_names = artifact_utils.encode_split_names(
                [split])
            self.existing_artifacts.append(existing_artifact)

        self.component_info = data_types.ComponentInfo(component_type='c_type',
                                                       component_id='c_id')
        self.pipeline_info = data_types.PipelineInfo(pipeline_name='p_name',
                                                     pipeline_root='p_root',
                                                     run_id='run_id')
        self.driver_args = data_types.DriverArgs(enable_cache=True)
Example #11
    def testRun(self, mock_publisher):
        mock_publisher.return_value.publish_execution.return_value = {}

        example_gen = FileBasedExampleGen(
            custom_executor_spec=executor_spec.ExecutorClassSpec(
                avro_executor.Executor),
            input=external_input(self.avro_dir_path),
            input_config=self.input_config,
            output_config=self.output_config,
            instance_name='AvroExampleGen')

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        pipeline_root = os.path.join(output_data_dir, 'Test')
        tf.io.gfile.makedirs(pipeline_root)
        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(enable_cache=True)

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        metadata_connection = metadata.Metadata(connection_config)

        launcher = in_process_component_launcher.InProcessComponentLauncher.create(
            component=example_gen,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection=metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})
        self.assertEqual(
            launcher._component_info.component_type, '.'.join(
                [FileBasedExampleGen.__module__,
                 FileBasedExampleGen.__name__]))

        launcher.launch()
        mock_publisher.return_value.publish_execution.assert_called_once()

        # Get output paths.
        component_id = example_gen.id
        output_path = os.path.join(pipeline_root, component_id, 'examples/1')
        examples = standard_artifacts.Examples()
        examples.uri = output_path
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])

        # Check Avro example gen outputs.
        train_output_file = os.path.join(examples.uri, 'train',
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(examples.uri, 'eval',
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.io.gfile.exists(train_output_file))
        self.assertTrue(tf.io.gfile.exists(eval_output_file))
        self.assertGreater(
            tf.io.gfile.GFile(train_output_file).size(),
            tf.io.gfile.GFile(eval_output_file).size())
Example #12
    def test_run(self, mock_publisher):
        mock_publisher.return_value.publish_execution.return_value = {}

        example_gen = FileBasedExampleGen(
            executor_class=avro_executor.Executor,
            input_base=external_input(self.avro_dir_path),
            input_config=self.input_config,
            output_config=self.output_config,
            name='AvroExampleGenComponent')

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        pipeline_root = os.path.join(output_data_dir, 'Test')
        tf.gfile.MakeDirs(pipeline_root)
        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(enable_cache=True)

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()

        launcher = component_launcher.ComponentLauncher(
            component=example_gen,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection_config=connection_config,
            additional_pipeline_args={})
        self.assertEqual(
            launcher._component_info.component_type, '.'.join(
                [FileBasedExampleGen.__module__,
                 FileBasedExampleGen.__name__]))

        launcher.launch()
        mock_publisher.return_value.publish_execution.assert_called_once()

        # Get output paths.
        component_id = '.'.join([example_gen.component_name, example_gen.name])
        output_path = os.path.join(pipeline_root, component_id, 'examples/1')
        train_examples = types.TfxArtifact(type_name='ExamplesPath',
                                           split='train')
        train_examples.uri = os.path.join(output_path, 'train')
        eval_examples = types.TfxArtifact(type_name='ExamplesPath',
                                          split='eval')
        eval_examples.uri = os.path.join(output_path, 'eval')

        # Check Avro example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        self.assertGreater(
            tf.gfile.GFile(train_output_file).size(),
            tf.gfile.GFile(eval_output_file).size())
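
This is an older counterpart of Example #11; the filesystem and artifact APIs were renamed between the two snapshots. The gfile rename can be checked directly (the tf.compat.v1 alias is assumed to be available in TF 2.x):

import tensorflow as tf

tf.compat.v1.gfile.MakeDirs('/tmp/gfile_demo')  # legacy spelling (Example #12)
tf.io.gfile.makedirs('/tmp/gfile_demo')         # current spelling (Example #11)
assert tf.io.gfile.exists('/tmp/gfile_demo')
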
Example #13
    def run(
        self,
        component: base_node.BaseNode,
        enable_cache: bool = True,
        beam_pipeline_args: Optional[List[Text]] = None
    ) -> execution_result.ExecutionResult:
        """Run a given TFX component in the interactive context.

        Args:
          component: Component instance to be run.
          enable_cache: Whether caching logic should be enabled in the driver.
          beam_pipeline_args: Optional Beam pipeline args for Beam jobs within
            the executor. The executor uses the Beam DirectRunner by default.
            If provided, this overrides the beam_pipeline_args specified in the
            constructor.

        Returns:
          An execution_result.ExecutionResult object.
        """
        run_id = datetime.datetime.now().isoformat()
        pipeline_info = data_types.PipelineInfo(
            pipeline_name=self.pipeline_name,
            pipeline_root=self.pipeline_root,
            run_id=run_id)
        driver_args = data_types.DriverArgs(enable_cache=enable_cache,
                                            interactive_resolution=True)
        metadata_connection = metadata.Metadata(
            self.metadata_connection_config)
        beam_pipeline_args = beam_pipeline_args or self.beam_pipeline_args
        additional_pipeline_args = {}
        for name, output in component.outputs.items():
            for artifact in output.get():
                artifact.pipeline_name = self.pipeline_name
                artifact.producer_component = component.id
                artifact.name = name
        # Special treatment for pip dependencies.
        # TODO(b/187122662): Pass through pip dependencies as a first-class
        # component flag.
        if isinstance(component, base_component.BaseComponent):
            component._resolve_pip_dependencies(self.pipeline_root)  # pylint: disable=protected-access
        # TODO(hongyes): figure out how to resolve launcher class in the interactive
        # context.
        launcher = in_process_component_launcher.InProcessComponentLauncher.create(
            component, pipeline_info, driver_args, metadata_connection,
            beam_pipeline_args, additional_pipeline_args)
        try:
            import colab  # pytype: disable=import-error # pylint: disable=g-import-not-at-top, unused-import, unused-variable
            runner_label = 'interactivecontext-colab'
        except ImportError:
            runner_label = 'interactivecontext'
        with telemetry_utils.scoped_labels({
                telemetry_utils.LABEL_TFX_RUNNER:
                runner_label,
        }):
            execution_id = launcher.launch().execution_id

        return execution_result.ExecutionResult(component=component,
                                                execution_id=execution_id)
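
Per the docstring, beam_pipeline_args passed to run() override the ones the context was constructed with; a hedged per-run override sketch (context and component are assumptions):

result = context.run(
    trainer,
    enable_cache=True,
    beam_pipeline_args=['--direct_num_workers=0'])  # 0 lets Beam auto-detect
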
Example #14
def main():
    # Log to the container's stdout so it can be streamed by the orchestrator.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument('--pipeline_name', type=str, required=True)
    parser.add_argument('--pipeline_root', type=str, required=True)
    parser.add_argument('--run_id', type=str, required=True)
    parser.add_argument('--metadata_config', type=str, required=True)
    parser.add_argument('--beam_pipeline_args', type=str, required=True)
    parser.add_argument('--additional_pipeline_args', type=str, required=True)
    parser.add_argument('--component_launcher_class_path',
                        type=str,
                        required=True)
    parser.add_argument('--enable_cache', action='store_true')
    parser.add_argument('--serialized_component', type=str, required=True)
    parser.add_argument('--component_config', type=str, required=True)

    args = parser.parse_args()

    component = json_utils.loads(args.serialized_component)
    component_config = json_utils.loads(args.component_config)
    component_launcher_class = import_utils.import_class_by_path(
        args.component_launcher_class_path)
    if not issubclass(component_launcher_class,
                      base_component_launcher.BaseComponentLauncher):
        raise TypeError(
            'component_launcher_class "%s" is not a subclass of '
            'base_component_launcher.BaseComponentLauncher' %
            component_launcher_class)

    metadata_config = metadata_store_pb2.ConnectionConfig()
    json_format.Parse(args.metadata_config, metadata_config)
    driver_args = data_types.DriverArgs(enable_cache=args.enable_cache)
    beam_pipeline_args = json.loads(args.beam_pipeline_args)
    additional_pipeline_args = json.loads(args.additional_pipeline_args)

    launcher = component_launcher_class.create(
        component=component,
        pipeline_info=data_types.PipelineInfo(
            pipeline_name=args.pipeline_name,
            pipeline_root=args.pipeline_root,
            run_id=args.run_id,
        ),
        driver_args=driver_args,
        metadata_connection=metadata.Metadata(
            connection_config=metadata_config),
        beam_pipeline_args=beam_pipeline_args,
        additional_pipeline_args=additional_pipeline_args,
        component_config=component_config)

    # Attach necessary labels to distinguish different runners and DSLs.
    with telemetry_utils.scoped_labels({
            telemetry_utils.LABEL_TFX_RUNNER:
            'kubernetes',
    }):
        launcher.launch()
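
Given the argparse definition above, the orchestrator must serialize every config into flag values; an illustrative (hypothetical) argv for this entrypoint:

argv = [
    '--pipeline_name=test',
    '--pipeline_root=/tmp/pipeline_root',
    '--run_id=run-20210101',
    '--metadata_config={"sqlite": {}}',  # JSON for a ConnectionConfig proto
    '--beam_pipeline_args=[]',
    '--additional_pipeline_args={}',
    '--component_launcher_class_path=tfx.orchestration.launcher.'
    'in_process_component_launcher.InProcessComponentLauncher',
    '--enable_cache',
    '--serialized_component=...',   # output of json_utils.dumps(component)
    '--component_config=...',       # output of json_utils.dumps(config)
]
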
Example #15
def main():
    # Log to the container's stdout so Kubeflow Pipelines UI can display logs to
    # the user.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument('--pipeline_name', type=str, required=True)
    parser.add_argument('--pipeline_root', type=str, required=True)
    parser.add_argument('--kubeflow_metadata_config', type=str, required=True)
    parser.add_argument('--beam_pipeline_args', type=str, required=True)
    parser.add_argument('--additional_pipeline_args', type=str, required=True)
    parser.add_argument('--component_launcher_class_path',
                        type=str,
                        required=True)
    parser.add_argument('--enable_cache', action='store_true')
    parser.add_argument('--serialized_component', type=str, required=True)
    parser.add_argument('--component_config', type=str, required=True)

    args = parser.parse_args()

    component = json_utils.loads(args.serialized_component)
    component_config = json_utils.loads(args.component_config)
    component_launcher_class = import_utils.import_class_by_path(
        args.component_launcher_class_path)
    if not issubclass(component_launcher_class,
                      base_component_launcher.BaseComponentLauncher):
        raise TypeError(
            'component_launcher_class "%s" is not a subclass of '
            'base_component_launcher.BaseComponentLauncher' %
            component_launcher_class)

    kubeflow_metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    json_format.Parse(args.kubeflow_metadata_config, kubeflow_metadata_config)
    metadata_connection = kubeflow_metadata_adapter.KubeflowMetadataAdapter(
        _get_metadata_connection_config(kubeflow_metadata_config))
    driver_args = data_types.DriverArgs(enable_cache=args.enable_cache)

    beam_pipeline_args = _make_beam_pipeline_args(args.beam_pipeline_args)

    additional_pipeline_args = json.loads(args.additional_pipeline_args)

    launcher = component_launcher_class.create(
        component=component,
        pipeline_info=data_types.PipelineInfo(
            pipeline_name=args.pipeline_name,
            pipeline_root=args.pipeline_root,
            run_id=os.environ['WORKFLOW_ID']),
        driver_args=driver_args,
        metadata_connection=metadata_connection,
        beam_pipeline_args=beam_pipeline_args,
        additional_pipeline_args=additional_pipeline_args,
        component_config=component_config)

    execution_info = launcher.launch()

    # Dump the UI metadata.
    _dump_ui_metadata(component, execution_info)
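
Unlike Example #14, the run id here is read from the WORKFLOW_ID environment variable, which the Kubeflow Pipelines runner is expected to inject into the container; a defensive read might look like:

import os

run_id = os.environ.get('WORKFLOW_ID')
if run_id is None:
    raise RuntimeError('WORKFLOW_ID must be set by the KFP orchestrator.')
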
Example #16
 def setUp(self):
   super(BaseDriverTest, self).setUp()
   self._mock_metadata = tf.compat.v1.test.mock.Mock()
   self._input_dict = {
       'input_data':
           types.Channel(
               type=_InputArtifact,
               artifacts=[_InputArtifact()],
               producer_component_id='c',
               output_key='k'),
       'input_string':
           types.Channel(
               type=standard_artifacts.String,
               artifacts=[
                   standard_artifacts.String(),
                   standard_artifacts.String()
               ],
               producer_component_id='c2',
               output_key='k2'),
   }
   input_dir = os.path.join(
       os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
       self._testMethodName, 'input_dir')
    # Valid input artifacts must have a URI pointing to an existing directory.
   for key, input_channel in self._input_dict.items():
     for index, artifact in enumerate(input_channel.get()):
       artifact.id = index + 1
       uri = os.path.join(input_dir, key, str(artifact.id))
       artifact.uri = uri
       tf.io.gfile.makedirs(uri)
   self._output_dict = {
       'output_data':
           types.Channel(type=_OutputArtifact, artifacts=[_OutputArtifact()]),
       'output_multi_data':
           types.Channel(
               type=_OutputArtifact, matching_channel_name='input_string')
   }
   self._input_artifacts = channel_utils.unwrap_channel_dict(self._input_dict)
   self._output_artifacts = channel_utils.unwrap_channel_dict(
       self._output_dict)
   self._exec_properties = {
       'key': 'value',
   }
   self._execution_id = 100
   self._execution = metadata_store_pb2.Execution()
   self._execution.id = self._execution_id
   self._context_id = 123
   self._driver_args = data_types.DriverArgs(enable_cache=True)
   self._pipeline_info = data_types.PipelineInfo(
       pipeline_name='my_pipeline_name',
       pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
       run_id='my_run_id')
   self._component_info = data_types.ComponentInfo(
       component_type='a.b.c',
       component_id='my_component_id',
       pipeline_info=self._pipeline_info)
Example #17
 def setUp(self):
     self._component = _FakeComponent(
         _FakeComponentSpec(input=types.Channel(type_name='type_a'),
                            output=types.Channel(type_name='type_b')))
     self._pipeline_info = data_types.PipelineInfo('name', 'root')
     self._driver_args = data_types.DriverArgs(True)
     self._metadata_connection_config = metadata.sqlite_metadata_connection_config(
         os.path.join(os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
                      'metadata'))
     self._parent_dag = models.DAG(dag_id=self._pipeline_info.pipeline_name,
                                   start_date=datetime.datetime(2018, 1, 1),
                                   schedule_interval=None)
Example #18
 def setUp(self):
   super(ResolverDriverTest, self).setUp()
   self.connection_config = metadata_store_pb2.ConnectionConfig()
   self.connection_config.sqlite.SetInParent()
   self.component_info = data_types.ComponentInfo(
       component_type='c_type', component_id='c_id')
   self.pipeline_info = data_types.PipelineInfo(
       pipeline_name='p_name', pipeline_root='p_root', run_id='run_id')
   self.driver_args = data_types.DriverArgs(enable_cache=True)
   self.source_channel_key = 'source_channel'
   self.source_channels = {
       self.source_channel_key: types.Channel(type=standard_artifacts.Examples)
   }
Example #19
    def __init__(self, parent_dag, component_name, unique_name, driver,
                 executor, input_dict, output_dict, exec_properties):
        # Prepare parameters to create TFX worker.
        if unique_name:
            worker_name = component_name + '.' + unique_name
        else:
            worker_name = component_name
        task_id = parent_dag.dag_id + '.' + worker_name

        # Create output object of appropriate type.
        output_dir = self._get_working_dir(parent_dag.project_path,
                                           component_name, unique_name or '')

        # Update the output dict before providing it to downstream components.
        for k, output_list in output_dict.items():
            for single_output in output_list:
                single_output.source = _OrchestrationSource(
                    key=k, component_id=task_id)

        my_logger_config = logging_utils.LoggerConfig(
            log_root=parent_dag.logger_config.log_root,
            log_level=parent_dag.logger_config.log_level,
            pipeline_name=parent_dag.logger_config.pipeline_name,
            worker_name=worker_name)
        driver_args = data_types.DriverArgs(
            worker_name=worker_name,
            base_output_dir=output_dir,
            enable_cache=parent_dag.enable_cache)

        worker = _TfxWorker(
            component_name=component_name,
            task_id=task_id,
            parent_dag=parent_dag,
            input_dict=input_dict,
            output_dict=output_dict,
            exec_properties=exec_properties,
            driver_args=driver_args,
            driver_class=driver,
            executor_class=executor,
            additional_pipeline_args=parent_dag.additional_pipeline_args,
            metadata_connection_config=parent_dag.metadata_connection_config,
            logger_config=my_logger_config)
        subdag = subdag_operator.SubDagOperator(subdag=worker,
                                                task_id=worker_name,
                                                dag=parent_dag)

        parent_dag.add_node_to_graph(node=subdag,
                                     consumes=input_dict.values(),
                                     produces=output_dict.values())
Example #20
    def testPreExecutionNewExecution(self, mock_verify_input_artifacts_fn):
        input_dict = {
            'input_a':
            types.Channel(type_name='input_a',
                          artifacts=[types.Artifact(type_name='input_a')])
        }
        output_dict = {
            'output_a':
            types.Channel(type_name='output_a',
                          artifacts=[
                              types.Artifact(type_name='output_a',
                                             split='split')
                          ])
        }
        execution_id = 1
        context_id = 123
        exec_properties = copy.deepcopy(self._exec_properties)
        driver_args = data_types.DriverArgs(enable_cache=True)
        pipeline_info = data_types.PipelineInfo(
            pipeline_name='my_pipeline_name',
            pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
            run_id='my_run_id')
        component_info = data_types.ComponentInfo(
            component_type='a.b.c', component_id='my_component_id')
        self._mock_metadata.get_artifacts_by_info.side_effect = list(
            input_dict['input_a'].get())
        self._mock_metadata.register_execution.side_effect = [execution_id]
        self._mock_metadata.previous_execution.side_effect = [None]
        self._mock_metadata.register_run_context_if_not_exists.side_effect = [
            context_id
        ]

        driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
        execution_decision = driver.pre_execution(
            input_dict=input_dict,
            output_dict=output_dict,
            exec_properties=exec_properties,
            driver_args=driver_args,
            pipeline_info=pipeline_info,
            component_info=component_info)
        self.assertFalse(execution_decision.use_cached_results)
        self.assertEqual(execution_decision.execution_id, 1)
        self.assertItemsEqual(execution_decision.exec_properties,
                              exec_properties)
        self.assertEqual(
            execution_decision.output_dict['output_a'][0].uri,
            os.path.join(pipeline_info.pipeline_root,
                         component_info.component_id, 'output_a',
                         str(execution_id), 'split', ''))
Example #21
    def __init__(
            self, parent_dag: models.DAG, component: base_node.BaseNode,
            component_launcher_class: Type[
                base_component_launcher.BaseComponentLauncher],
            pipeline_info: data_types.PipelineInfo, enable_cache: bool,
            metadata_connection_config: metadata_store_pb2.ConnectionConfig,
            beam_pipeline_args: List[Text],
            additional_pipeline_args: Dict[Text, Any],
            component_config: base_component_config.BaseComponentConfig):
        """Constructs an Airflow implementation of TFX component.

        Args:
          parent_dag: An AirflowPipeline instance as the pipeline DAG.
          component: An instance of base_node.BaseNode that holds all
            properties of a logical component.
          component_launcher_class: The class of the launcher used to launch
            the component.
          pipeline_info: An instance of data_types.PipelineInfo that holds
            pipeline properties.
          enable_cache: Whether or not cache is enabled for this component run.
          metadata_connection_config: A config proto for metadata connection.
          beam_pipeline_args: Pipeline arguments for Beam-powered components.
          additional_pipeline_args: Additional pipeline args.
          component_config: Component config used to launch the component.
        """
        # Prepare parameters to create TFX worker.
        driver_args = data_types.DriverArgs(enable_cache=enable_cache)

        exec_properties = component.exec_properties

        super(AirflowComponent, self).__init__(
            task_id=component.id,
            # TODO(b/183172663): Delete `provide_context` when we drop support
            # for Airflow 1.x.
            provide_context=True,
            python_callable=functools.partial(
                _airflow_component_launcher,
                component=component,
                component_launcher_class=component_launcher_class,
                pipeline_info=pipeline_info,
                driver_args=driver_args,
                metadata_connection_config=metadata_connection_config,
                beam_pipeline_args=beam_pipeline_args,
                additional_pipeline_args=additional_pipeline_args,
                component_config=component_config),
            # op_kwargs is a templated field for PythonOperator, which means Airflow
            # will inspect the dictionary and resolve any templated fields.
            op_kwargs={'exec_properties': exec_properties},
            dag=parent_dag)
Example #22
    def testRun(self, mock_publisher):
        mock_publisher.return_value.publish_execution.return_value = {}

        test_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        metadata_connection = metadata.Metadata(connection_config)

        pipeline_root = os.path.join(test_dir, 'Test')
        input_path = os.path.join(test_dir, 'input')
        fileio.makedirs(os.path.dirname(input_path))
        file_io.write_string_to_file(input_path, 'test')

        input_artifact = test_utils._InputArtifact()
        input_artifact.uri = input_path

        component = test_utils._FakeComponent(
            name='FakeComponent',
            input_channel=channel_utils.as_channel([input_artifact]))

        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(enable_cache=True)

        # We use InProcessComponentLauncher to test BaseComponentLauncher logic.
        launcher = in_process_component_launcher.InProcessComponentLauncher.create(
            component=component,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection=metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})
        self.assertEqual(
            launcher._component_info.component_type, '.'.join([
                test_utils._FakeComponent.__module__,
                test_utils._FakeComponent.__name__
            ]))
        launcher.launch()

        output_path = component.outputs['output'].get()[0].uri
        self.assertTrue(fileio.exists(output_path))
        contents = file_io.read_file_to_string(output_path)
        self.assertEqual('test', contents)
Example #23
  def __init__(self, component: base_component.BaseComponent,
               tfx_pipeline: pipeline.Pipeline):
    """Initialize the _ComponentAsDoFn.

    Args:
      component: Component to be executed.
      tfx_pipeline: Logical pipeline that contains pipeline related information.
    """
    driver_args = data_types.DriverArgs(enable_cache=tfx_pipeline.enable_cache)
    self._component_launcher = component_launcher.ComponentLauncher(
        component=component,
        pipeline_info=tfx_pipeline.pipeline_info,
        driver_args=driver_args,
        metadata_connection_config=tfx_pipeline.metadata_connection_config,
        additional_pipeline_args=tfx_pipeline.additional_pipeline_args)
    self._component_id = component.component_id
Example #24
    def test_pre_execution_cached(self):
        input_dict = {
            'input_a':
            channel.Channel(type_name='input_a',
                            artifacts=[types.TfxArtifact(type_name='input_a')])
        }
        output_dict = {
            'output_a':
            channel.Channel(type_name='output_a',
                            artifacts=[
                                types.TfxArtifact(type_name='output_a',
                                                  split='split')
                            ])
        }
        execution_id = 1
        exec_properties = copy.deepcopy(self._exec_properties)
        driver_args = data_types.DriverArgs(worker_name='worker_name',
                                            base_output_dir='base',
                                            enable_cache=True)
        pipeline_info = data_types.PipelineInfo(
            pipeline_name='my_pipeline_name',
            pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
            run_id='my_run_id')
        component_info = data_types.ComponentInfo(
            component_type='a.b.c', component_id='my_component_id')
        self._mock_metadata.get_artifacts_by_info.side_effect = list(
            input_dict['input_a'].get())
        self._mock_metadata.register_execution.side_effect = [execution_id]
        self._mock_metadata.previous_execution.side_effect = [2]
        self._mock_metadata.fetch_previous_result_artifacts.side_effect = [
            self._output_dict
        ]

        driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
        execution_decision = driver.pre_execution(
            input_dict=input_dict,
            output_dict=output_dict,
            exec_properties=exec_properties,
            driver_args=driver_args,
            pipeline_info=pipeline_info,
            component_info=component_info)
        self.assertTrue(execution_decision.use_cached_results)
        self.assertEqual(execution_decision.execution_id, 1)
        self.assertItemsEqual(execution_decision.exec_properties,
                              exec_properties)
        self.assertItemsEqual(execution_decision.output_dict,
                              self._output_dict)
Example #25
    def test_run(self, mock_publisher):
        mock_publisher.return_value.publish_execution.return_value = {}

        test_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()

        pipeline_root = os.path.join(test_dir, 'Test')
        input_path = os.path.join(test_dir, 'input')
        tf.gfile.MakeDirs(os.path.dirname(input_path))
        file_io.write_string_to_file(input_path, 'test')

        input_artifact = types.TfxArtifact(type_name='InputPath')
        input_artifact.uri = input_path

        component = _FakeComponent(name='FakeComponent',
                                   input_channel=channel.as_channel(
                                       [input_artifact]))

        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(worker_name=component.component_id,
                                            base_output_dir=os.path.join(
                                                pipeline_root,
                                                component.component_id),
                                            enable_cache=True)

        launcher = component_launcher.ComponentLauncher(
            component=component,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection_config=connection_config,
            additional_pipeline_args={})
        self.assertEqual(
            launcher._component_info.component_type,
            '.'.join([_FakeComponent.__module__, _FakeComponent.__name__]))
        launcher.launch()

        output_path = os.path.join(pipeline_root, 'output')
        self.assertTrue(tf.gfile.Exists(output_path))
        contents = file_io.read_file_to_string(output_path)
        self.assertEqual('test', contents)
Example #26
 def setUp(self):
   super(ImporterDriverTest, self).setUp()
   self.connection_config = metadata_store_pb2.ConnectionConfig()
   self.connection_config.sqlite.SetInParent()
   self.artifact_type = 'Examples'
   self.output_dict = {
       importer_node.IMPORT_RESULT_KEY:
           types.Channel(type_name=self.artifact_type)
   }
   self.source_uri = 'm/y/u/r/i'
   self.existing_artifact = types.Artifact(type_name=self.artifact_type)
   self.existing_artifact.uri = self.source_uri
   self.component_info = data_types.ComponentInfo(
       component_type='c_type', component_id='c_id')
   self.pipeline_info = data_types.PipelineInfo(
       pipeline_name='p_name', pipeline_root='p_root', run_id='run_id')
   self.driver_args = data_types.DriverArgs(enable_cache=True)
Example #27
    def __init__(
            self, parent_dag: models.DAG,
            component: base_component.BaseComponent,
            component_launcher_class: Type[
                base_component_launcher.BaseComponentLauncher],
            pipeline_info: data_types.PipelineInfo, enable_cache: bool,
            metadata_connection_config: metadata_store_pb2.ConnectionConfig,
            beam_pipeline_args: List[Text],
            additional_pipeline_args: Dict[Text, Any],
            component_config: base_component_config.BaseComponentConfig):
        """Constructs an Airflow implementation of TFX component.

        Args:
          parent_dag: An AirflowPipeline instance as the pipeline DAG.
          component: An instance of base_component.BaseComponent that holds all
            properties of a logical component.
          component_launcher_class: The class of the launcher used to launch
            the component.
          pipeline_info: An instance of data_types.PipelineInfo that holds
            pipeline properties.
          enable_cache: Whether or not cache is enabled for this component run.
          metadata_connection_config: A config proto for metadata connection.
          beam_pipeline_args: Beam pipeline args for Beam jobs within the
            executor.
          additional_pipeline_args: Additional pipeline args.
          component_config: Component config used to launch the component.
        """
        # Prepare parameters to create TFX worker.
        enable_cache = (component.enable_cache if component.enable_cache
                        is not None else enable_cache)
        driver_args = data_types.DriverArgs(enable_cache=enable_cache)

        super(AirflowComponent, self).__init__(
            task_id=component.id,
            provide_context=True,
            python_callable=functools.partial(
                _airflow_component_launcher,
                component=component,
                component_launcher_class=component_launcher_class,
                pipeline_info=pipeline_info,
                driver_args=driver_args,
                metadata_connection_config=metadata_connection_config,
                beam_pipeline_args=beam_pipeline_args,
                additional_pipeline_args=additional_pipeline_args,
                component_config=component_config),
            dag=parent_dag)
Example #28
    def testRun(self, mock_publisher):
        mock_publisher.return_value.publish_execution.return_value = {}

        example_gen = FileBasedExampleGen(
            custom_executor_spec=executor_spec.ExecutorClassSpec(
                avro_executor.Executor),
            input_base=self.avro_dir_path,
            input_config=self.input_config,
            output_config=self.output_config,
            instance_name='AvroExampleGen')

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        pipeline_root = os.path.join(output_data_dir, 'Test')
        fileio.makedirs(pipeline_root)
        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(enable_cache=True)

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        metadata_connection = metadata.Metadata(connection_config)

        launcher = in_process_component_launcher.InProcessComponentLauncher.create(
            component=example_gen,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection=metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})
        self.assertEqual(
            launcher._component_info.component_type, '.'.join(
                [FileBasedExampleGen.__module__,
                 FileBasedExampleGen.__name__]))

        launcher.launch()
        mock_publisher.return_value.publish_execution.assert_called_once()

        # Check output paths.
        self.assertTrue(
            fileio.exists(os.path.join(pipeline_root, example_gen.id)))
Example #29
  def run(self, tfx_pipeline: pipeline.Pipeline) -> None:
    """Runs given logical pipeline locally.

    Args:
      tfx_pipeline: Logical pipeline containing pipeline args and components.
    """
    # When the CLI creates or updates a pipeline, it only extracts
    # pipeline_args, so we skip executing the pipeline here.
    if 'TFX_JSON_EXPORT_PIPELINE_ARGS_PATH' in os.environ:
      return

    tfx_pipeline.pipeline_info.run_id = datetime.datetime.now().isoformat()

    with telemetry_utils.scoped_labels(
        {telemetry_utils.LABEL_TFX_RUNNER: 'local'}):
      # Run each component. Note that the pipeline.components list is in
      # topological order.
      #
      # TODO(b/171319478): After IR-based execution is used, use multi-threaded
      # execution so that independent components can be run in parallel.
      for component in tfx_pipeline.components:
        # TODO(b/187122662): Pass through pip dependencies as a first-class
        # component flag.
        if isinstance(component, base_component.BaseComponent):
          component._resolve_pip_dependencies(  # pylint: disable=protected-access
              tfx_pipeline.pipeline_info.pipeline_root)
        (component_launcher_class, component_config) = (
            config_utils.find_component_launch_info(self._config, component))
        driver_args = data_types.DriverArgs(
            enable_cache=tfx_pipeline.enable_cache)
        metadata_connection = metadata.Metadata(
            tfx_pipeline.metadata_connection_config)
        node_launcher = component_launcher_class.create(
            component=component,
            pipeline_info=tfx_pipeline.pipeline_info,
            driver_args=driver_args,
            metadata_connection=metadata_connection,
            beam_pipeline_args=tfx_pipeline.beam_pipeline_args,
            additional_pipeline_args=tfx_pipeline.additional_pipeline_args,
            component_config=component_config)
        logging.info('Component %s is running.', component.id)
        node_launcher.launch()
        logging.info('Component %s is finished.', component.id)
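
A hedged sketch of driving this runner; the LocalDagRunner class name and the pipeline object are assumptions consistent with the method above:

runner = LocalDagRunner()  # the class owning the run() method above
runner.run(my_pipeline)    # executes components in topological order
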
Example #30
  def run(self,
          component: base_component.BaseComponent,
          enable_cache: bool = True) -> execution_result.ExecutionResult:
    """Run a given TFX component in the interactive context.

    Args:
      component: Component instance to be run.
      enable_cache: Whether caching logic should be enabled in the driver.

    Returns:
      An execution_result.ExecutionResult object.
    """
    run_id = datetime.datetime.now().isoformat()
    pipeline_info = data_types.PipelineInfo(
        pipeline_name=self.pipeline_name,
        pipeline_root=self.pipeline_root,
        run_id=run_id)
    driver_args = data_types.DriverArgs(
        enable_cache=enable_cache,
        interactive_resolution=True)
    try:
      parallelism = multiprocessing.cpu_count()
    except NotImplementedError:
      absl.logging.info('Using a single process for Beam pipeline execution.')
      parallelism = 1
    beam_pipeline_args = ['--direct_num_workers=%d' % parallelism]
    additional_pipeline_args = {}
    for name, output in component.outputs.get_all().items():
      for artifact in output.get():
        artifact.pipeline_name = self.pipeline_name
        artifact.producer_component = component.id
        artifact.run_id = run_id
        artifact.name = name
    # TODO(hongyes): figure out how to resolve launcher class in the interactive
    # context.
    launcher = in_process_component_launcher.InProcessComponentLauncher.create(
        component, pipeline_info, driver_args, self.metadata_connection_config,
        beam_pipeline_args, additional_pipeline_args)
    execution_id = launcher.launch()

    return execution_result.ExecutionResult(
        component=component,
        execution_id=execution_id)