def launch_container_component(
    component: base_node.BaseNode,
    component_launcher_class: Type[
        base_component_launcher.BaseComponentLauncher],
    component_config: base_component_config.BaseComponentConfig,
    pipeline: tfx_pipeline.Pipeline):
  """Uses the Kubernetes component launcher to launch the component.

  Args:
    component: Container component to be executed.
    component_launcher_class: The class of the launcher to launch the
      component.
    component_config: Component config to launch the component.
    pipeline: Logical pipeline that contains pipeline related information.
  """
  driver_args = data_types.DriverArgs(enable_cache=pipeline.enable_cache)
  metadata_connection = metadata.Metadata(pipeline.metadata_connection_config)

  component_launcher = component_launcher_class.create(
      component=component,
      pipeline_info=pipeline.pipeline_info,
      driver_args=driver_args,
      metadata_connection=metadata_connection,
      beam_pipeline_args=pipeline.beam_pipeline_args,
      additional_pipeline_args=pipeline.additional_pipeline_args,
      component_config=component_config)
  logging.info('Component %s is running.', component.id)
  component_launcher.launch()
  logging.info('Component %s is finished.', component.id)
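
# Hedged usage sketch (not part of the original source): how the helper above
# might be invoked from a Kubernetes orchestrator. `my_pipeline` and
# `my_container_component` are hypothetical placeholders, and the launcher
# import path is assumed from the type annotations above.
from tfx.orchestration.launcher import kubernetes_component_launcher

launch_container_component(
    component=my_container_component,  # hypothetical container-based component
    component_launcher_class=(
        kubernetes_component_launcher.KubernetesComponentLauncher),
    component_config=None,  # assuming no component-specific config is needed
    pipeline=my_pipeline)  # hypothetical tfx_pipeline.Pipeline instance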
def __init__(self, component: base_component.BaseComponent,
             tfx_pipeline: pipeline.Pipeline):
  """Initializes the _ComponentAsDoFn.

  Args:
    component: Component to be executed.
    tfx_pipeline: Logical pipeline that contains pipeline related information.
  """
  driver_args = data_types.DriverArgs(enable_cache=tfx_pipeline.enable_cache)
  self._additional_pipeline_args = tfx_pipeline.additional_pipeline_args.copy()

  _job_name = re.sub(
      r'[^0-9a-zA-Z-]+', '-',
      '{pipeline_name}-{component}-{ts}'.format(
          pipeline_name=tfx_pipeline.pipeline_info.pipeline_name,
          component=component.component_id,
          ts=int(datetime.datetime.timestamp(
              datetime.datetime.now()))).lower())

  self._additional_pipeline_args['beam_pipeline_args'] = [
      arg for arg in self._additional_pipeline_args.setdefault(
          'beam_pipeline_args', []) if not arg.startswith('--job_name')
  ]
  self._additional_pipeline_args['beam_pipeline_args'].append(
      '--job_name={}'.format(_job_name))

  self._component_launcher = component_launcher.ComponentLauncher(
      component=component,
      pipeline_info=tfx_pipeline.pipeline_info,
      driver_args=driver_args,
      metadata_connection_config=tfx_pipeline.metadata_connection_config,
      additional_pipeline_args=self._additional_pipeline_args)
  self._component_id = component.component_id
def _create_launcher_context(self, component_config=None):
  test_dir = self.get_temp_dir()

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()

  pipeline_root = os.path.join(test_dir, 'Test')
  input_artifact = test_utils._InputArtifact()
  input_artifact.uri = os.path.join(test_dir, 'input')

  component = test_utils._FakeComponent(
      name='FakeComponent',
      input_channel=channel_utils.as_channel([input_artifact]),
      custom_executor_spec=executor_spec.ExecutorContainerSpec(
          image='gcr://test', args=['{{input_dict["input"][0].uri}}']))

  pipeline_info = data_types.PipelineInfo(
      pipeline_name='Test', pipeline_root=pipeline_root, run_id='123')

  driver_args = data_types.DriverArgs(enable_cache=True)

  launcher = kubernetes_component_launcher.KubernetesComponentLauncher.create(
      component=component,
      pipeline_info=pipeline_info,
      driver_args=driver_args,
      metadata_connection_config=connection_config,
      beam_pipeline_args=[],
      additional_pipeline_args={},
      component_config=component_config)

  return {'launcher': launcher, 'input_artifact': input_artifact}
def __init__(self,
             parent_dag: models.DAG,
             component: base_component.BaseComponent,
             pipeline_info: data_types.PipelineInfo,
             enable_cache: bool,
             metadata_connection_config: metadata_store_pb2.ConnectionConfig,
             additional_pipeline_args: Dict[Text, Any]):
  """Constructs an Airflow implementation of TFX component.

  Args:
    parent_dag: An AirflowPipeline instance as the pipeline DAG.
    component: An instance of base_component.BaseComponent that holds all
      properties of a logical component.
    pipeline_info: An instance of data_types.PipelineInfo that holds pipeline
      properties.
    enable_cache: Whether or not cache is enabled for this component run.
    metadata_connection_config: A config proto for metadata connection.
    additional_pipeline_args: Additional pipeline args.
  """
  # Prepare parameters to create TFX worker.
  driver_args = data_types.DriverArgs(enable_cache=enable_cache)

  super(AirflowComponent, self).__init__(
      task_id=component.component_id,
      provide_context=True,
      python_callable=functools.partial(
          _airflow_component_launcher,
          component=component,
          pipeline_info=pipeline_info,
          driver_args=driver_args,
          metadata_connection_config=metadata_connection_config,
          additional_pipeline_args=additional_pipeline_args),
      dag=parent_dag)
def __init__(self, component: base_component.BaseComponent,
             component_launcher_class: Type[
                 base_component_launcher.BaseComponentLauncher],
             component_config: base_component_config.BaseComponentConfig,
             tfx_pipeline: pipeline.Pipeline):
  """Initializes the _ComponentAsDoFn.

  Args:
    component: Component to be executed.
    component_launcher_class: The class of the launcher to launch the
      component.
    component_config: Component config to launch the component.
    tfx_pipeline: Logical pipeline that contains pipeline related information.
  """
  enable_cache = (component.enable_cache if component.enable_cache is not None
                  else tfx_pipeline.enable_cache)
  driver_args = data_types.DriverArgs(enable_cache=enable_cache)
  metadata_connection = metadata.Metadata(
      tfx_pipeline.metadata_connection_config)

  self._component_launcher = component_launcher_class.create(
      component=component,
      pipeline_info=tfx_pipeline.pipeline_info,
      driver_args=driver_args,
      metadata_connection=metadata_connection,
      beam_pipeline_args=tfx_pipeline.beam_pipeline_args,
      additional_pipeline_args=tfx_pipeline.additional_pipeline_args,
      component_config=component_config)
  self._component_id = component.id
def run(self, component: base_component.BaseComponent,
        enable_cache: bool = True) -> execution_result.ExecutionResult:
  """Run a given TFX component in the interactive context.

  Args:
    component: Component instance to be run.
    enable_cache: Whether caching logic should be enabled in the driver.

  Returns:
    execution_result.ExecutionResult object.
  """
  run_id = datetime.datetime.now().isoformat()
  pipeline_info = data_types.PipelineInfo(
      pipeline_name=self.pipeline_name,
      pipeline_root=self.pipeline_root,
      run_id=run_id)
  driver_args = data_types.DriverArgs(
      enable_cache=enable_cache, interactive_resolution=True)
  additional_pipeline_args = {}
  for name, output in component.outputs.get_all().items():
    for artifact in output.get():
      artifact.pipeline_name = self.pipeline_name
      artifact.producer_component = component.component_id
      artifact.run_id = run_id
      artifact.name = name
  # TODO(hongyes): figure out how to resolve launcher class in the interactive
  # context.
  launcher = in_process_component_launcher.InProcessComponentLauncher.create(
      component, pipeline_info, driver_args, self.metadata_connection_config,
      additional_pipeline_args)
  execution_id = launcher.launch()
  return execution_result.ExecutionResult(
      component=component, execution_id=execution_id)
def setUp(self):
  super(StubComponentLauncherTest, self).setUp()

  test_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  self.metadata_connection = metadata.Metadata(connection_config)

  self.pipeline_root = os.path.join(test_dir, 'Test')
  self.input_dir = os.path.join(test_dir, 'input')
  self.output_dir = os.path.join(test_dir, 'output')
  self.record_dir = os.path.join(test_dir, 'record')
  tf.io.gfile.makedirs(self.input_dir)
  tf.io.gfile.makedirs(self.output_dir)
  tf.io.gfile.makedirs(self.record_dir)

  input_artifact = test_utils._InputArtifact()  # pylint: disable=protected-access
  input_artifact.uri = os.path.join(self.input_dir, 'result.txt')
  output_artifact = test_utils._OutputArtifact()  # pylint: disable=protected-access
  output_artifact.uri = os.path.join(self.output_dir, 'result.txt')
  self.component = test_utils._FakeComponent(  # pylint: disable=protected-access
      name='FakeComponent',
      input_channel=channel_utils.as_channel([input_artifact]),
      output_channel=channel_utils.as_channel([output_artifact]))
  self.driver_args = data_types.DriverArgs(enable_cache=True)
  self.pipeline_info = data_types.PipelineInfo(
      pipeline_name='Test', pipeline_root=self.pipeline_root, run_id='123')
def setUp(self):
  super(ImporterDriverTest, self).setUp()
  self.connection_config = metadata_store_pb2.ConnectionConfig()
  self.connection_config.sqlite.SetInParent()
  self.output_dict = {
      importer_node.IMPORT_RESULT_KEY:
          types.Channel(type=standard_artifacts.Examples)
  }
  self.source_uri = 'm/y/u/r/i'
  self.properties = {
      'split_names': artifact_utils.encode_split_names(['train', 'eval'])
  }
  self.custom_properties = {
      'string_custom_property': 'abc',
      'int_custom_property': 123,
  }
  self.existing_artifacts = []
  existing_artifact = standard_artifacts.Examples()
  existing_artifact.uri = self.source_uri
  existing_artifact.split_names = self.properties['split_names']
  self.existing_artifacts.append(existing_artifact)
  self.pipeline_info = data_types.PipelineInfo(
      pipeline_name='p_name', pipeline_root='p_root', run_id='run_id')
  self.component_info = data_types.ComponentInfo(
      component_type='c_type',
      component_id='c_id',
      pipeline_info=self.pipeline_info)
  self.driver_args = data_types.DriverArgs(enable_cache=True)
def setUp(self):
  self._mock_metadata = tf.test.mock.Mock()
  self._input_dict = {
      'input_data': [types.TfxArtifact(type_name='InputType')],
  }
  input_dir = os.path.join(
      os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
      self._testMethodName, 'input_dir')
  # Valid input artifacts must have a uri pointing to an existing directory.
  for key, input_list in self._input_dict.items():
    for index, artifact in enumerate(input_list):
      artifact.id = index + 1
      uri = os.path.join(input_dir, key, str(artifact.id), '')
      artifact.uri = uri
      tf.gfile.MakeDirs(uri)
  self._output_dict = {
      'output_data': [types.TfxArtifact(type_name='OutputType')],
  }
  self._exec_properties = {
      'key': 'value',
  }
  self._base_output_dir = os.path.join(
      os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
      self._testMethodName, 'base_output_dir')
  self._driver_args = data_types.DriverArgs(
      worker_name='worker_name',
      base_output_dir=self._base_output_dir,
      enable_cache=True)
  self._execution_id = 100
def setUp(self):
  super(ImporterDriverTest, self).setUp()
  self.connection_config = metadata_store_pb2.ConnectionConfig()
  self.connection_config.sqlite.SetInParent()
  self.output_dict = {
      importer_node.IMPORT_RESULT_KEY:
          types.Channel(type=standard_artifacts.Examples)
  }
  self.source_uri = ['m/y/u/r/i/1', 'm/y/u/r/i/2']
  self.split = ['train', 'eval']

  self.existing_artifacts = []
  for uri, split in zip(self.source_uri, self.split):
    existing_artifact = standard_artifacts.Examples()
    existing_artifact.uri = uri
    existing_artifact.split_names = artifact_utils.encode_split_names([split])
    self.existing_artifacts.append(existing_artifact)

  self.component_info = data_types.ComponentInfo(
      component_type='c_type', component_id='c_id')
  self.pipeline_info = data_types.PipelineInfo(
      pipeline_name='p_name', pipeline_root='p_root', run_id='run_id')
  self.driver_args = data_types.DriverArgs(enable_cache=True)
def testRun(self, mock_publisher):
  mock_publisher.return_value.publish_execution.return_value = {}

  example_gen = FileBasedExampleGen(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          avro_executor.Executor),
      input=external_input(self.avro_dir_path),
      input_config=self.input_config,
      output_config=self.output_config,
      instance_name='AvroExampleGen')

  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  pipeline_root = os.path.join(output_data_dir, 'Test')
  tf.io.gfile.makedirs(pipeline_root)
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='Test', pipeline_root=pipeline_root, run_id='123')
  driver_args = data_types.DriverArgs(enable_cache=True)

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  metadata_connection = metadata.Metadata(connection_config)

  launcher = in_process_component_launcher.InProcessComponentLauncher.create(
      component=example_gen,
      pipeline_info=pipeline_info,
      driver_args=driver_args,
      metadata_connection=metadata_connection,
      beam_pipeline_args=[],
      additional_pipeline_args={})
  self.assertEqual(
      launcher._component_info.component_type,
      '.'.join([FileBasedExampleGen.__module__, FileBasedExampleGen.__name__]))

  launcher.launch()
  mock_publisher.return_value.publish_execution.assert_called_once()

  # Get output paths.
  component_id = example_gen.id
  output_path = os.path.join(pipeline_root, component_id, 'examples/1')
  examples = standard_artifacts.Examples()
  examples.uri = output_path
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])

  # Check Avro example gen outputs.
  train_output_file = os.path.join(examples.uri, 'train',
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(examples.uri, 'eval',
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.io.gfile.exists(train_output_file))
  self.assertTrue(tf.io.gfile.exists(eval_output_file))
  self.assertGreater(
      tf.io.gfile.GFile(train_output_file).size(),
      tf.io.gfile.GFile(eval_output_file).size())
def test_run(self, mock_publisher):
  mock_publisher.return_value.publish_execution.return_value = {}

  example_gen = FileBasedExampleGen(
      executor_class=avro_executor.Executor,
      input_base=external_input(self.avro_dir_path),
      input_config=self.input_config,
      output_config=self.output_config,
      name='AvroExampleGenComponent')

  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  pipeline_root = os.path.join(output_data_dir, 'Test')
  tf.gfile.MakeDirs(pipeline_root)
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='Test', pipeline_root=pipeline_root, run_id='123')
  driver_args = data_types.DriverArgs(enable_cache=True)

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()

  launcher = component_launcher.ComponentLauncher(
      component=example_gen,
      pipeline_info=pipeline_info,
      driver_args=driver_args,
      metadata_connection_config=connection_config,
      additional_pipeline_args={})
  self.assertEqual(
      launcher._component_info.component_type,
      '.'.join([FileBasedExampleGen.__module__, FileBasedExampleGen.__name__]))

  launcher.launch()
  mock_publisher.return_value.publish_execution.assert_called_once()

  # Get output paths.
  component_id = '.'.join([example_gen.component_name, example_gen.name])
  output_path = os.path.join(pipeline_root, component_id, 'examples/1')
  train_examples = types.TfxArtifact(type_name='ExamplesPath', split='train')
  train_examples.uri = os.path.join(output_path, 'train')
  eval_examples = types.TfxArtifact(type_name='ExamplesPath', split='eval')
  eval_examples.uri = os.path.join(output_path, 'eval')

  # Check Avro example gen outputs.
  train_output_file = os.path.join(train_examples.uri,
                                   'data_tfrecord-00000-of-00001.gz')
  eval_output_file = os.path.join(eval_examples.uri,
                                  'data_tfrecord-00000-of-00001.gz')
  self.assertTrue(tf.gfile.Exists(train_output_file))
  self.assertTrue(tf.gfile.Exists(eval_output_file))
  self.assertGreater(
      tf.gfile.GFile(train_output_file).size(),
      tf.gfile.GFile(eval_output_file).size())
def run(
    self,
    component: base_node.BaseNode,
    enable_cache: bool = True,
    beam_pipeline_args: Optional[List[Text]] = None
) -> execution_result.ExecutionResult:
  """Run a given TFX component in the interactive context.

  Args:
    component: Component instance to be run.
    enable_cache: Whether caching logic should be enabled in the driver.
    beam_pipeline_args: Optional Beam pipeline args for Beam jobs within the
      executor. The executor uses the Beam DirectRunner by default. If
      provided, this overrides the beam_pipeline_args specified in the
      constructor.

  Returns:
    execution_result.ExecutionResult object.
  """
  run_id = datetime.datetime.now().isoformat()
  pipeline_info = data_types.PipelineInfo(
      pipeline_name=self.pipeline_name,
      pipeline_root=self.pipeline_root,
      run_id=run_id)
  driver_args = data_types.DriverArgs(
      enable_cache=enable_cache, interactive_resolution=True)
  metadata_connection = metadata.Metadata(self.metadata_connection_config)
  beam_pipeline_args = beam_pipeline_args or self.beam_pipeline_args
  additional_pipeline_args = {}
  for name, output in component.outputs.items():
    for artifact in output.get():
      artifact.pipeline_name = self.pipeline_name
      artifact.producer_component = component.id
      artifact.name = name
  # Special treatment for pip dependencies.
  # TODO(b/187122662): Pass through pip dependencies as a first-class
  # component flag.
  if isinstance(component, base_component.BaseComponent):
    component._resolve_pip_dependencies(self.pipeline_root)  # pylint: disable=protected-access
  # TODO(hongyes): figure out how to resolve launcher class in the interactive
  # context.
  launcher = in_process_component_launcher.InProcessComponentLauncher.create(
      component, pipeline_info, driver_args, metadata_connection,
      beam_pipeline_args, additional_pipeline_args)
  try:
    import colab  # pytype: disable=import-error # pylint: disable=g-import-not-at-top, unused-import, unused-variable
    runner_label = 'interactivecontext-colab'
  except ImportError:
    runner_label = 'interactivecontext'
  with telemetry_utils.scoped_labels({
      telemetry_utils.LABEL_TFX_RUNNER: runner_label,
  }):
    execution_id = launcher.launch().execution_id

  return execution_result.ExecutionResult(
      component=component, execution_id=execution_id)
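
# Hedged usage sketch (not part of the original source): running a single
# component through InteractiveContext.run in a notebook. Assumes the public
# tfx.orchestration.experimental.interactive API; '/path/to/csv_dir' is a
# hypothetical placeholder.
from tfx.components import CsvExampleGen
from tfx.orchestration.experimental.interactive.interactive_context import (
    InteractiveContext)

context = InteractiveContext()  # Creates an ephemeral pipeline_root and SQLite MLMD store.
example_gen = CsvExampleGen(input_base='/path/to/csv_dir')
result = context.run(example_gen, enable_cache=False)
print(result.execution_id)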
def main():
  # Log to the container's stdout so it can be streamed by the orchestrator.
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
  logging.getLogger().setLevel(logging.INFO)

  parser = argparse.ArgumentParser()
  parser.add_argument('--pipeline_name', type=str, required=True)
  parser.add_argument('--pipeline_root', type=str, required=True)
  parser.add_argument('--run_id', type=str, required=True)
  parser.add_argument('--metadata_config', type=str, required=True)
  parser.add_argument('--beam_pipeline_args', type=str, required=True)
  parser.add_argument('--additional_pipeline_args', type=str, required=True)
  parser.add_argument(
      '--component_launcher_class_path', type=str, required=True)
  parser.add_argument('--enable_cache', action='store_true')
  parser.add_argument('--serialized_component', type=str, required=True)
  parser.add_argument('--component_config', type=str, required=True)

  args = parser.parse_args()

  component = json_utils.loads(args.serialized_component)
  component_config = json_utils.loads(args.component_config)
  component_launcher_class = import_utils.import_class_by_path(
      args.component_launcher_class_path)
  if not issubclass(component_launcher_class,
                    base_component_launcher.BaseComponentLauncher):
    raise TypeError(
        'component_launcher_class "%s" is not a subclass of '
        'base_component_launcher.BaseComponentLauncher' %
        component_launcher_class)

  metadata_config = metadata_store_pb2.ConnectionConfig()
  json_format.Parse(args.metadata_config, metadata_config)
  driver_args = data_types.DriverArgs(enable_cache=args.enable_cache)
  beam_pipeline_args = json.loads(args.beam_pipeline_args)
  additional_pipeline_args = json.loads(args.additional_pipeline_args)

  launcher = component_launcher_class.create(
      component=component,
      pipeline_info=data_types.PipelineInfo(
          pipeline_name=args.pipeline_name,
          pipeline_root=args.pipeline_root,
          run_id=args.run_id,
      ),
      driver_args=driver_args,
      metadata_connection=metadata.Metadata(
          connection_config=metadata_config),
      beam_pipeline_args=beam_pipeline_args,
      additional_pipeline_args=additional_pipeline_args,
      component_config=component_config)

  # Attach necessary labels to distinguish between different runners and DSLs.
  with telemetry_utils.scoped_labels({
      telemetry_utils.LABEL_TFX_RUNNER: 'kubernetes',
  }):
    launcher.launch()
def main():
  # Log to the container's stdout so Kubeflow Pipelines UI can display logs to
  # the user.
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
  logging.getLogger().setLevel(logging.INFO)

  parser = argparse.ArgumentParser()
  parser.add_argument('--pipeline_name', type=str, required=True)
  parser.add_argument('--pipeline_root', type=str, required=True)
  parser.add_argument('--kubeflow_metadata_config', type=str, required=True)
  parser.add_argument('--beam_pipeline_args', type=str, required=True)
  parser.add_argument('--additional_pipeline_args', type=str, required=True)
  parser.add_argument(
      '--component_launcher_class_path', type=str, required=True)
  parser.add_argument('--enable_cache', action='store_true')
  parser.add_argument('--serialized_component', type=str, required=True)
  parser.add_argument('--component_config', type=str, required=True)

  args = parser.parse_args()

  component = json_utils.loads(args.serialized_component)
  component_config = json_utils.loads(args.component_config)
  component_launcher_class = import_utils.import_class_by_path(
      args.component_launcher_class_path)
  if not issubclass(component_launcher_class,
                    base_component_launcher.BaseComponentLauncher):
    raise TypeError(
        'component_launcher_class "%s" is not a subclass of '
        'base_component_launcher.BaseComponentLauncher' %
        component_launcher_class)

  kubeflow_metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  json_format.Parse(args.kubeflow_metadata_config, kubeflow_metadata_config)
  metadata_connection = kubeflow_metadata_adapter.KubeflowMetadataAdapter(
      _get_metadata_connection_config(kubeflow_metadata_config))
  driver_args = data_types.DriverArgs(enable_cache=args.enable_cache)

  beam_pipeline_args = _make_beam_pipeline_args(args.beam_pipeline_args)
  additional_pipeline_args = json.loads(args.additional_pipeline_args)

  launcher = component_launcher_class.create(
      component=component,
      pipeline_info=data_types.PipelineInfo(
          pipeline_name=args.pipeline_name,
          pipeline_root=args.pipeline_root,
          run_id=os.environ['WORKFLOW_ID']),
      driver_args=driver_args,
      metadata_connection=metadata_connection,
      beam_pipeline_args=beam_pipeline_args,
      additional_pipeline_args=additional_pipeline_args,
      component_config=component_config)

  execution_info = launcher.launch()

  # Dump the UI metadata.
  _dump_ui_metadata(component, execution_info)
def setUp(self):
  super(BaseDriverTest, self).setUp()
  self._mock_metadata = tf.compat.v1.test.mock.Mock()
  self._input_dict = {
      'input_data':
          types.Channel(
              type=_InputArtifact,
              artifacts=[_InputArtifact()],
              producer_component_id='c',
              output_key='k'),
      'input_string':
          types.Channel(
              type=standard_artifacts.String,
              artifacts=[
                  standard_artifacts.String(),
                  standard_artifacts.String()
              ],
              producer_component_id='c2',
              output_key='k2'),
  }
  input_dir = os.path.join(
      os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
      self._testMethodName, 'input_dir')
  # Valid input artifacts must have a uri pointing to an existing directory.
  for key, input_channel in self._input_dict.items():
    for index, artifact in enumerate(input_channel.get()):
      artifact.id = index + 1
      uri = os.path.join(input_dir, key, str(artifact.id))
      artifact.uri = uri
      tf.io.gfile.makedirs(uri)
  self._output_dict = {
      'output_data':
          types.Channel(type=_OutputArtifact, artifacts=[_OutputArtifact()]),
      'output_multi_data':
          types.Channel(
              type=_OutputArtifact, matching_channel_name='input_string')
  }
  self._input_artifacts = channel_utils.unwrap_channel_dict(self._input_dict)
  self._output_artifacts = channel_utils.unwrap_channel_dict(self._output_dict)
  self._exec_properties = {
      'key': 'value',
  }
  self._execution_id = 100
  self._execution = metadata_store_pb2.Execution()
  self._execution.id = self._execution_id
  self._context_id = 123
  self._driver_args = data_types.DriverArgs(enable_cache=True)
  self._pipeline_info = data_types.PipelineInfo(
      pipeline_name='my_pipeline_name',
      pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
      run_id='my_run_id')
  self._component_info = data_types.ComponentInfo(
      component_type='a.b.c',
      component_id='my_component_id',
      pipeline_info=self._pipeline_info)
def setUp(self):
  self._component = _FakeComponent(
      _FakeComponentSpec(
          input=types.Channel(type_name='type_a'),
          output=types.Channel(type_name='type_b')))
  self._pipeline_info = data_types.PipelineInfo('name', 'root')
  self._driver_args = data_types.DriverArgs(True)
  self._metadata_connection_config = metadata.sqlite_metadata_connection_config(
      os.path.join(
          os.environ.get('TEST_TMP_DIR', self.get_temp_dir()), 'metadata'))
  self._parent_dag = models.DAG(
      dag_id=self._pipeline_info.pipeline_name,
      start_date=datetime.datetime(2018, 1, 1),
      schedule_interval=None)
def setUp(self):
  super(ResolverDriverTest, self).setUp()
  self.connection_config = metadata_store_pb2.ConnectionConfig()
  self.connection_config.sqlite.SetInParent()
  self.component_info = data_types.ComponentInfo(
      component_type='c_type', component_id='c_id')
  self.pipeline_info = data_types.PipelineInfo(
      pipeline_name='p_name', pipeline_root='p_root', run_id='run_id')
  self.driver_args = data_types.DriverArgs(enable_cache=True)
  self.source_channel_key = 'source_channel'
  self.source_channels = {
      self.source_channel_key: types.Channel(type=standard_artifacts.Examples)
  }
def __init__(self, parent_dag, component_name, unique_name, driver, executor,
             input_dict, output_dict, exec_properties):
  # Prepare parameters to create TFX worker.
  if unique_name:
    worker_name = component_name + '.' + unique_name
  else:
    worker_name = component_name
  task_id = parent_dag.dag_id + '.' + worker_name

  # Create output object of appropriate type.
  output_dir = self._get_working_dir(parent_dag.project_path, component_name,
                                     unique_name or '')

  # Update the output dict before providing it to downstream components.
  for k, output_list in output_dict.items():
    for single_output in output_list:
      single_output.source = _OrchestrationSource(key=k, component_id=task_id)

  my_logger_config = logging_utils.LoggerConfig(
      log_root=parent_dag.logger_config.log_root,
      log_level=parent_dag.logger_config.log_level,
      pipeline_name=parent_dag.logger_config.pipeline_name,
      worker_name=worker_name)
  driver_args = data_types.DriverArgs(
      worker_name=worker_name,
      base_output_dir=output_dir,
      enable_cache=parent_dag.enable_cache)

  worker = _TfxWorker(
      component_name=component_name,
      task_id=task_id,
      parent_dag=parent_dag,
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
      driver_args=driver_args,
      driver_class=driver,
      executor_class=executor,
      additional_pipeline_args=parent_dag.additional_pipeline_args,
      metadata_connection_config=parent_dag.metadata_connection_config,
      logger_config=my_logger_config)
  subdag = subdag_operator.SubDagOperator(
      subdag=worker, task_id=worker_name, dag=parent_dag)

  parent_dag.add_node_to_graph(
      node=subdag,
      consumes=input_dict.values(),
      produces=output_dict.values())
def testPreExecutionNewExecution(self, mock_verify_input_artifacts_fn):
  input_dict = {
      'input_a':
          types.Channel(
              type_name='input_a',
              artifacts=[types.Artifact(type_name='input_a')])
  }
  output_dict = {
      'output_a':
          types.Channel(
              type_name='output_a',
              artifacts=[types.Artifact(type_name='output_a', split='split')])
  }
  execution_id = 1
  context_id = 123
  exec_properties = copy.deepcopy(self._exec_properties)
  driver_args = data_types.DriverArgs(enable_cache=True)
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='my_pipeline_name',
      pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
      run_id='my_run_id')
  component_info = data_types.ComponentInfo(
      component_type='a.b.c', component_id='my_component_id')
  self._mock_metadata.get_artifacts_by_info.side_effect = list(
      input_dict['input_a'].get())
  self._mock_metadata.register_execution.side_effect = [execution_id]
  self._mock_metadata.previous_execution.side_effect = [None]
  self._mock_metadata.register_run_context_if_not_exists.side_effect = [
      context_id
  ]

  driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
  execution_decision = driver.pre_execution(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
      driver_args=driver_args,
      pipeline_info=pipeline_info,
      component_info=component_info)

  self.assertFalse(execution_decision.use_cached_results)
  self.assertEqual(execution_decision.execution_id, 1)
  self.assertItemsEqual(execution_decision.exec_properties, exec_properties)
  self.assertEqual(
      execution_decision.output_dict['output_a'][0].uri,
      os.path.join(pipeline_info.pipeline_root, component_info.component_id,
                   'output_a', str(execution_id), 'split', ''))
def __init__(self,
             parent_dag: models.DAG,
             component: base_node.BaseNode,
             component_launcher_class: Type[
                 base_component_launcher.BaseComponentLauncher],
             pipeline_info: data_types.PipelineInfo,
             enable_cache: bool,
             metadata_connection_config: metadata_store_pb2.ConnectionConfig,
             beam_pipeline_args: List[Text],
             additional_pipeline_args: Dict[Text, Any],
             component_config: base_component_config.BaseComponentConfig):
  """Constructs an Airflow implementation of TFX component.

  Args:
    parent_dag: An AirflowPipeline instance as the pipeline DAG.
    component: An instance of base_node.BaseNode that holds all properties of
      a logical component.
    component_launcher_class: The class of the launcher to launch the
      component.
    pipeline_info: An instance of data_types.PipelineInfo that holds pipeline
      properties.
    enable_cache: Whether or not cache is enabled for this component run.
    metadata_connection_config: A config proto for metadata connection.
    beam_pipeline_args: Pipeline arguments for Beam-powered components.
    additional_pipeline_args: Additional pipeline args.
    component_config: Component config to launch the component.
  """
  # Prepare parameters to create TFX worker.
  driver_args = data_types.DriverArgs(enable_cache=enable_cache)
  exec_properties = component.exec_properties

  super(AirflowComponent, self).__init__(
      task_id=component.id,
      # TODO(b/183172663): Delete `provide_context` when we drop support of
      # airflow 1.x.
      provide_context=True,
      python_callable=functools.partial(
          _airflow_component_launcher,
          component=component,
          component_launcher_class=component_launcher_class,
          pipeline_info=pipeline_info,
          driver_args=driver_args,
          metadata_connection_config=metadata_connection_config,
          beam_pipeline_args=beam_pipeline_args,
          additional_pipeline_args=additional_pipeline_args,
          component_config=component_config),
      # op_kwargs is a templated field for PythonOperator, which means Airflow
      # will inspect the dictionary and resolve any templated fields.
      op_kwargs={'exec_properties': exec_properties},
      dag=parent_dag)
def testRun(self, mock_publisher):
  mock_publisher.return_value.publish_execution.return_value = {}

  test_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  metadata_connection = metadata.Metadata(connection_config)

  pipeline_root = os.path.join(test_dir, 'Test')
  input_path = os.path.join(test_dir, 'input')
  fileio.makedirs(os.path.dirname(input_path))
  file_io.write_string_to_file(input_path, 'test')

  input_artifact = test_utils._InputArtifact()
  input_artifact.uri = input_path

  component = test_utils._FakeComponent(
      name='FakeComponent',
      input_channel=channel_utils.as_channel([input_artifact]))

  pipeline_info = data_types.PipelineInfo(
      pipeline_name='Test', pipeline_root=pipeline_root, run_id='123')

  driver_args = data_types.DriverArgs(enable_cache=True)

  # We use InProcessComponentLauncher to test BaseComponentLauncher logic.
  launcher = in_process_component_launcher.InProcessComponentLauncher.create(
      component=component,
      pipeline_info=pipeline_info,
      driver_args=driver_args,
      metadata_connection=metadata_connection,
      beam_pipeline_args=[],
      additional_pipeline_args={})
  self.assertEqual(
      launcher._component_info.component_type,
      '.'.join([
          test_utils._FakeComponent.__module__,
          test_utils._FakeComponent.__name__
      ]))
  launcher.launch()

  output_path = component.outputs['output'].get()[0].uri
  self.assertTrue(fileio.exists(output_path))
  contents = file_io.read_file_to_string(output_path)
  self.assertEqual('test', contents)
def __init__(self, component: base_component.BaseComponent,
             tfx_pipeline: pipeline.Pipeline):
  """Initializes the _ComponentAsDoFn.

  Args:
    component: Component to be executed.
    tfx_pipeline: Logical pipeline that contains pipeline related information.
  """
  driver_args = data_types.DriverArgs(enable_cache=tfx_pipeline.enable_cache)
  self._component_launcher = component_launcher.ComponentLauncher(
      component=component,
      pipeline_info=tfx_pipeline.pipeline_info,
      driver_args=driver_args,
      metadata_connection_config=tfx_pipeline.metadata_connection_config,
      additional_pipeline_args=tfx_pipeline.additional_pipeline_args)
  self._component_id = component.component_id
def test_pre_execution_cached(self):
  input_dict = {
      'input_a':
          channel.Channel(
              type_name='input_a',
              artifacts=[types.TfxArtifact(type_name='input_a')])
  }
  output_dict = {
      'output_a':
          channel.Channel(
              type_name='output_a',
              artifacts=[
                  types.TfxArtifact(type_name='output_a', split='split')
              ])
  }
  execution_id = 1
  exec_properties = copy.deepcopy(self._exec_properties)
  driver_args = data_types.DriverArgs(
      worker_name='worker_name', base_output_dir='base', enable_cache=True)
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='my_pipeline_name',
      pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
      run_id='my_run_id')
  component_info = data_types.ComponentInfo(
      component_type='a.b.c', component_id='my_component_id')
  self._mock_metadata.get_artifacts_by_info.side_effect = list(
      input_dict['input_a'].get())
  self._mock_metadata.register_execution.side_effect = [execution_id]
  self._mock_metadata.previous_execution.side_effect = [2]
  self._mock_metadata.fetch_previous_result_artifacts.side_effect = [
      self._output_dict
  ]

  driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
  execution_decision = driver.pre_execution(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
      driver_args=driver_args,
      pipeline_info=pipeline_info,
      component_info=component_info)

  self.assertTrue(execution_decision.use_cached_results)
  self.assertEqual(execution_decision.execution_id, 1)
  self.assertItemsEqual(execution_decision.exec_properties, exec_properties)
  self.assertItemsEqual(execution_decision.output_dict, self._output_dict)
def test_run(self, mock_publisher):
  mock_publisher.return_value.publish_execution.return_value = {}

  test_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()

  pipeline_root = os.path.join(test_dir, 'Test')
  input_path = os.path.join(test_dir, 'input')
  tf.gfile.MakeDirs(os.path.dirname(input_path))
  file_io.write_string_to_file(input_path, 'test')

  input_artifact = types.TfxArtifact(type_name='InputPath')
  input_artifact.uri = input_path

  component = _FakeComponent(
      name='FakeComponent',
      input_channel=channel.as_channel([input_artifact]))

  pipeline_info = data_types.PipelineInfo(
      pipeline_name='Test', pipeline_root=pipeline_root, run_id='123')

  driver_args = data_types.DriverArgs(
      worker_name=component.component_id,
      base_output_dir=os.path.join(pipeline_root, component.component_id),
      enable_cache=True)

  launcher = component_launcher.ComponentLauncher(
      component=component,
      pipeline_info=pipeline_info,
      driver_args=driver_args,
      metadata_connection_config=connection_config,
      additional_pipeline_args={})
  self.assertEqual(
      launcher._component_info.component_type,
      '.'.join([_FakeComponent.__module__, _FakeComponent.__name__]))
  launcher.launch()

  output_path = os.path.join(pipeline_root, 'output')
  self.assertTrue(tf.gfile.Exists(output_path))
  contents = file_io.read_file_to_string(output_path)
  self.assertEqual('test', contents)
def setUp(self):
  super(ImporterDriverTest, self).setUp()
  self.connection_config = metadata_store_pb2.ConnectionConfig()
  self.connection_config.sqlite.SetInParent()
  self.artifact_type = 'Examples'
  self.output_dict = {
      importer_node.IMPORT_RESULT_KEY:
          types.Channel(type_name=self.artifact_type)
  }
  self.source_uri = 'm/y/u/r/i'
  self.existing_artifact = types.Artifact(type_name=self.artifact_type)
  self.existing_artifact.uri = self.source_uri
  self.component_info = data_types.ComponentInfo(
      component_type='c_type', component_id='c_id')
  self.pipeline_info = data_types.PipelineInfo(
      pipeline_name='p_name', pipeline_root='p_root', run_id='run_id')
  self.driver_args = data_types.DriverArgs(enable_cache=True)
def __init__(self,
             parent_dag: models.DAG,
             component: base_component.BaseComponent,
             component_launcher_class: Type[
                 base_component_launcher.BaseComponentLauncher],
             pipeline_info: data_types.PipelineInfo,
             enable_cache: bool,
             metadata_connection_config: metadata_store_pb2.ConnectionConfig,
             beam_pipeline_args: List[Text],
             additional_pipeline_args: Dict[Text, Any],
             component_config: base_component_config.BaseComponentConfig):
  """Constructs an Airflow implementation of TFX component.

  Args:
    parent_dag: An AirflowPipeline instance as the pipeline DAG.
    component: An instance of base_component.BaseComponent that holds all
      properties of a logical component.
    component_launcher_class: The class of the launcher to launch the
      component.
    pipeline_info: An instance of data_types.PipelineInfo that holds pipeline
      properties.
    enable_cache: Whether or not cache is enabled for this component run.
    metadata_connection_config: A config proto for metadata connection.
    beam_pipeline_args: Beam pipeline args for Beam jobs within the executor.
    additional_pipeline_args: Additional pipeline args.
    component_config: Component config to launch the component.
  """
  # Prepare parameters to create TFX worker.
  enable_cache = (component.enable_cache
                  if component.enable_cache is not None else enable_cache)
  driver_args = data_types.DriverArgs(enable_cache=enable_cache)

  super(AirflowComponent, self).__init__(
      task_id=component.id,
      provide_context=True,
      python_callable=functools.partial(
          _airflow_component_launcher,
          component=component,
          component_launcher_class=component_launcher_class,
          pipeline_info=pipeline_info,
          driver_args=driver_args,
          metadata_connection_config=metadata_connection_config,
          beam_pipeline_args=beam_pipeline_args,
          additional_pipeline_args=additional_pipeline_args,
          component_config=component_config),
      dag=parent_dag)
def testRun(self, mock_publisher):
  mock_publisher.return_value.publish_execution.return_value = {}

  example_gen = FileBasedExampleGen(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          avro_executor.Executor),
      input_base=self.avro_dir_path,
      input_config=self.input_config,
      output_config=self.output_config,
      instance_name='AvroExampleGen')

  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  pipeline_root = os.path.join(output_data_dir, 'Test')
  fileio.makedirs(pipeline_root)
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='Test', pipeline_root=pipeline_root, run_id='123')
  driver_args = data_types.DriverArgs(enable_cache=True)

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  metadata_connection = metadata.Metadata(connection_config)

  launcher = in_process_component_launcher.InProcessComponentLauncher.create(
      component=example_gen,
      pipeline_info=pipeline_info,
      driver_args=driver_args,
      metadata_connection=metadata_connection,
      beam_pipeline_args=[],
      additional_pipeline_args={})
  self.assertEqual(
      launcher._component_info.component_type,
      '.'.join([FileBasedExampleGen.__module__, FileBasedExampleGen.__name__]))

  launcher.launch()
  mock_publisher.return_value.publish_execution.assert_called_once()

  # Check output paths.
  self.assertTrue(fileio.exists(os.path.join(pipeline_root, example_gen.id)))
def run(self, tfx_pipeline: pipeline.Pipeline) -> None:
  """Runs the given logical pipeline locally.

  Args:
    tfx_pipeline: Logical pipeline containing pipeline args and components.
  """
  # For CLI, while creating or updating pipeline, pipeline_args are extracted
  # and hence we avoid executing the pipeline.
  if 'TFX_JSON_EXPORT_PIPELINE_ARGS_PATH' in os.environ:
    return

  tfx_pipeline.pipeline_info.run_id = datetime.datetime.now().isoformat()

  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_RUNNER: 'local'}):
    # Run each component. Note that the pipeline.components list is in
    # topological order.
    #
    # TODO(b/171319478): After IR-based execution is used, use multi-threaded
    # execution so that independent components can be run in parallel.
    for component in tfx_pipeline.components:
      # TODO(b/187122662): Pass through pip dependencies as a first-class
      # component flag.
      if isinstance(component, base_component.BaseComponent):
        component._resolve_pip_dependencies(  # pylint: disable=protected-access
            tfx_pipeline.pipeline_info.pipeline_root)
      (component_launcher_class, component_config) = (
          config_utils.find_component_launch_info(self._config, component))
      driver_args = data_types.DriverArgs(
          enable_cache=tfx_pipeline.enable_cache)
      metadata_connection = metadata.Metadata(
          tfx_pipeline.metadata_connection_config)
      node_launcher = component_launcher_class.create(
          component=component,
          pipeline_info=tfx_pipeline.pipeline_info,
          driver_args=driver_args,
          metadata_connection=metadata_connection,
          beam_pipeline_args=tfx_pipeline.beam_pipeline_args,
          additional_pipeline_args=tfx_pipeline.additional_pipeline_args,
          component_config=component_config)
      logging.info('Component %s is running.', component.id)
      node_launcher.launch()
      logging.info('Component %s is finished.', component.id)
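
# Hedged usage sketch (not part of the original source): driving the run()
# method above through the public local runner API. Assumes
# tfx.orchestration.local.local_dag_runner.LocalDagRunner wraps this runner;
# all paths below are hypothetical placeholders.
from tfx.components import CsvExampleGen
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.local.local_dag_runner import LocalDagRunner

example_gen = CsvExampleGen(input_base='/path/to/csv_dir')
my_pipeline = pipeline.Pipeline(
    pipeline_name='local_demo',
    pipeline_root='/tmp/local_demo_root',
    components=[example_gen],
    metadata_connection_config=metadata.sqlite_metadata_connection_config(
        '/tmp/local_demo_root/metadata.db'),
    enable_cache=True)
LocalDagRunner().run(my_pipeline)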
def run(self, component: base_component.BaseComponent,
        enable_cache: bool = True) -> execution_result.ExecutionResult:
  """Run a given TFX component in the interactive context.

  Args:
    component: Component instance to be run.
    enable_cache: Whether caching logic should be enabled in the driver.

  Returns:
    execution_result.ExecutionResult object.
  """
  run_id = datetime.datetime.now().isoformat()
  pipeline_info = data_types.PipelineInfo(
      pipeline_name=self.pipeline_name,
      pipeline_root=self.pipeline_root,
      run_id=run_id)
  driver_args = data_types.DriverArgs(
      enable_cache=enable_cache, interactive_resolution=True)
  try:
    parallelism = multiprocessing.cpu_count()
  except NotImplementedError:
    absl.logging.info('Using a single process for Beam pipeline execution.')
    parallelism = 1
  beam_pipeline_args = ['--direct_num_workers=%d' % parallelism]
  additional_pipeline_args = {}
  for name, output in component.outputs.get_all().items():
    for artifact in output.get():
      artifact.pipeline_name = self.pipeline_name
      artifact.producer_component = component.id
      artifact.run_id = run_id
      artifact.name = name
  # TODO(hongyes): figure out how to resolve launcher class in the interactive
  # context.
  launcher = in_process_component_launcher.InProcessComponentLauncher.create(
      component, pipeline_info, driver_args, self.metadata_connection_config,
      beam_pipeline_args, additional_pipeline_args)
  execution_id = launcher.launch()
  return execution_result.ExecutionResult(
      component=component, execution_id=execution_id)