def _create_pipeline(self, pipeline_name: Text,
                     components: List[BaseComponent]):
  """Creates a pipeline given name and list of components."""
  return tfx_pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=self._pipeline_root(pipeline_name),
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=components,
      log_root='/var/tmp/tfx/logs',
      additional_pipeline_args={
          # Use a fixed WORKFLOW_ID (which is used as run id) for testing,
          # for the purpose of making debugging easier.
          'WORKFLOW_ID': pipeline_name,
      },
  )
def setUp(self):
  super().setUp()
  self._connection_config = metadata_store_pb2.ConnectionConfig()
  self._connection_config.sqlite.SetInParent()
  self._metadata = self.enter_context(
      metadata.Metadata(connection_config=self._connection_config))
  self._store = self._metadata.store
  self._pipeline_info = data_types.PipelineInfo(
      pipeline_name='my_pipeline', pipeline_root='/tmp', run_id='my_run_id')
  self._component_info = data_types.ComponentInfo(
      component_type='a.b.c',
      component_id='my_component',
      pipeline_info=self._pipeline_info)
def setUp(self):
  super(ResolverDriverTest, self).setUp()
  self.connection_config = metadata_store_pb2.ConnectionConfig()
  self.connection_config.sqlite.SetInParent()
  self.component_info = data_types.ComponentInfo(
      component_type='c_type', component_id='c_id')
  self.pipeline_info = data_types.PipelineInfo(
      pipeline_name='p_name', pipeline_root='p_root', run_id='run_id')
  self.driver_args = data_types.DriverArgs(enable_cache=True)
  self.source_channel_key = 'source_channel'
  self.source_channels = {
      self.source_channel_key: types.Channel(type=standard_artifacts.Examples)
  }
def _create_pipeline():
  pipeline_name = _PIPELINE_NAME
  pipeline_root = os.path.join(_get_test_output_dir(), pipeline_name)
  components = test_utils.create_e2e_components(_get_csv_input_location())
  return tfx_pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=components[:2],  # Run two components only to reduce overhead.
      log_root='/var/tmp/tfx/logs',
      additional_pipeline_args={
          'WORKFLOW_ID': pipeline_name,
      },
  )
def testRun(self, mock_publisher):
  mock_publisher.return_value.publish_execution.return_value = {}

  test_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  metadata_connection = metadata.Metadata(connection_config)

  pipeline_root = os.path.join(test_dir, 'Test')
  input_path = os.path.join(test_dir, 'input')
  fileio.makedirs(os.path.dirname(input_path))
  file_io.write_string_to_file(input_path, 'test')

  input_artifact = test_utils._InputArtifact()
  input_artifact.uri = input_path

  component = test_utils._FakeComponent(
      name='FakeComponent',
      input_channel=channel_utils.as_channel([input_artifact]))

  pipeline_info = data_types.PipelineInfo(
      pipeline_name='Test', pipeline_root=pipeline_root, run_id='123')
  driver_args = data_types.DriverArgs(enable_cache=True)

  # We use InProcessComponentLauncher to test BaseComponentLauncher logic.
  launcher = in_process_component_launcher.InProcessComponentLauncher.create(
      component=component,
      pipeline_info=pipeline_info,
      driver_args=driver_args,
      metadata_connection=metadata_connection,
      beam_pipeline_args=[],
      additional_pipeline_args={})
  self.assertEqual(
      launcher._component_info.component_type,
      '.'.join([
          test_utils._FakeComponent.__module__,
          test_utils._FakeComponent.__name__
      ]))

  launcher.launch()

  output_path = component.outputs['output'].get()[0].uri
  self.assertTrue(fileio.exists(output_path))
  contents = file_io.read_file_to_string(output_path)
  self.assertEqual('test', contents)
def testRun(self):
  component_a = _FakeComponent(
      _FakeComponentSpecA(output=types.Channel(type=_ArtifactTypeA)),
      enable_cache=True)
  component_b = _FakeComponent(
      _FakeComponentSpecB(
          a=component_a.outputs['output'],
          output=types.Channel(type=_ArtifactTypeB)),
      enable_cache=False)
  component_c = _FakeComponent(
      _FakeComponentSpecC(
          a=component_a.outputs['output'],
          output=types.Channel(type=_ArtifactTypeC)), True)
  component_d = _FakeComponent(
      _FakeComponentSpecD(
          b=component_b.outputs['output'],
          c=component_c.outputs['output'],
          output=types.Channel(type=_ArtifactTypeD)), False)
  component_e = _FakeComponent(
      _FakeComponentSpecE(
          a=component_a.outputs['output'],
          b=component_b.outputs['output'],
          d=component_d.outputs['output'],
          output=types.Channel(type=_ArtifactTypeE)))

  test_pipeline = pipeline.Pipeline(
      pipeline_name='x',
      pipeline_root='y',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[
          component_d, component_c, component_a, component_b, component_e
      ])

  beam_dag_runner.BeamDagRunner().run(test_pipeline)
  self.assertCountEqual(_executed_components, [
      '_FakeComponent.a', '_FakeComponent.b', '_FakeComponent.c',
      '_FakeComponent.d', '_FakeComponent.e'
  ])
  self.assertEqual(_executed_components[0], '_FakeComponent.a')
  self.assertEqual(_executed_components[3], '_FakeComponent.d')
  self.assertEqual(_executed_components[4], '_FakeComponent.e')
  self.assertDictEqual(
      {
          '_FakeComponent.a': True,
          '_FakeComponent.b': False,
          '_FakeComponent.c': True,
          '_FakeComponent.d': False,
          '_FakeComponent.e': False,
      }, _executed_components_cached)
def sqlite_metadata_connection_config(
    metadata_db_uri: Text) -> metadata_store_pb2.ConnectionConfig:
  """Convenience function to create file based metadata connection config.

  Args:
    metadata_db_uri: uri to metadata db.

  Returns:
    A metadata_store_pb2.ConnectionConfig based on given metadata db uri.
  """
  tf.io.gfile.makedirs(os.path.dirname(metadata_db_uri))
  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.filename_uri = metadata_db_uri
  connection_config.sqlite.connection_mode = (
      metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE)
  return connection_config
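A minimal usage sketch for the helper above; the pipeline-root path is an assumption for illustration, not taken from the source.

import os

# Hypothetical pipeline root used only for illustration.
_EXAMPLE_PIPELINE_ROOT = '/tmp/my_pipeline'
connection_config = sqlite_metadata_connection_config(
    os.path.join(_EXAMPLE_PIPELINE_ROOT, 'metadata.sqlite'))
# The resulting ConnectionConfig can then be passed to
# tfx_pipeline.Pipeline(..., metadata_connection_config=connection_config).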
def get_default_kubernetes_metadata_config(
) -> metadata_store_pb2.ConnectionConfig:
  """Returns the default metadata connection config for a kubernetes cluster.

  Returns:
    A config proto that will be serialized as JSON and passed to the running
    container so the TFX component driver is able to communicate with MLMD in
    a kubernetes cluster.
  """
  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.mysql.host = 'mysql'
  connection_config.mysql.port = 3306
  connection_config.mysql.database = 'mysql'
  connection_config.mysql.user = '******'
  connection_config.mysql.password = ''
  return connection_config
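As a hedged sketch of the intent stated in the docstring, the returned proto could be serialized with the standard protobuf JSON helper before being handed to the container; the actual wiring into the container spec is not shown in the source.

from google.protobuf import json_format

# Serialize the default config to JSON (illustrative only).
config_json = json_format.MessageToJson(get_default_kubernetes_metadata_config())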
def setUp(self):
  super(KubernetesRemoteRunnerTest, self).setUp()
  self.component_a = _FakeComponent(
      _FakeComponentSpecA(output=types.Channel(type=_ArtifactTypeA)))
  self.component_b = _FakeComponent(
      _FakeComponentSpecB(
          a=self.component_a.outputs['output'],
          output=types.Channel(type=_ArtifactTypeB)))
  self.component_c = _FakeComponent(
      _FakeComponentSpecC(
          a=self.component_a.outputs['output'],
          b=self.component_b.outputs['output'],
          output=types.Channel(type=_ArtifactTypeC)))
  self.test_pipeline = tfx_pipeline.Pipeline(
      pipeline_name='x',
      pipeline_root='y',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[self.component_c, self.component_a, self.component_b])
def test_before_generating_trial_model(self):
  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.filename_uri = os.path.join(FLAGS.test_tmpdir, "1")
  connection_config.sqlite.connection_mode = 3
  handler = ml_metadata_db.MLMetaData(
      None, None, None, connection_config=connection_config)
  handler.before_generating_trial_model(trial_id=1, model_dir="/tmp/1")
  output = handler._store.get_executions_by_type("Trial")
  self.assertLen(output, 1)
  output = output[0]
  self.assertEqual(output.properties["id"].int_value, 1)
  self.assertEqual(output.properties["state"].string_value, "RUNNING")
  self.assertEqual(output.properties["serialized_data"].string_value, "")
  self.assertEqual(output.properties["model_dir"].string_value, "/tmp/1")
  self.assertEqual(output.properties["evaluation"].string_value, "")
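For readability: `connection_mode = 3` is the numeric value of `SqliteMetadataSourceConfig.READWRITE_OPENCREATE`, the same mode spelled out in `sqlite_metadata_connection_config` above. An equivalent, more explicit assignment would be:

connection_config.sqlite.connection_mode = (
    metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE)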
def test_report(self):
  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.filename_uri = os.path.join(FLAGS.test_tmpdir, "2")
  connection_config.sqlite.connection_mode = 3
  handler = ml_metadata_db.MLMetaData(
      None, None, None, connection_config=connection_config)
  handler.before_generating_trial_model(trial_id=1, model_dir="/tmp/1")
  handler.report(eval_dictionary={"loss": 0.5}, model_dir="/tmp/1")
  output = handler.get_completed_trials()
  self.assertLen(output, 1)
  output = output[0]
  self.assertEqual(output.id, 1)
  self.assertEqual(output.status, "COMPLETED")
  self.assertEqual(output.model_dir, "/tmp/1")
  self.assertEqual(output.final_measurement.objective_value, 0.5)
def setUp(self):
  super(MetadataTest, self).setUp()
  self._connection_config = metadata_store_pb2.ConnectionConfig()
  self._connection_config.sqlite.SetInParent()
  self._component_info = data_types.ComponentInfo(
      component_type='a.b.c', component_id='my_component')
  self._component_info2 = data_types.ComponentInfo(
      component_type='a.b.d', component_id='my_component_2')
  self._pipeline_info = data_types.PipelineInfo(
      pipeline_name='my_pipeline', pipeline_root='/tmp', run_id='my_run_id')
  self._pipeline_info2 = data_types.PipelineInfo(
      pipeline_name='my_pipeline', pipeline_root='/tmp', run_id='my_run_id2')
def test_run(self, mock_publisher):
  mock_publisher.return_value.publish_execution.return_value = {}

  test_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()

  pipeline_root = os.path.join(test_dir, 'Test')
  input_path = os.path.join(test_dir, 'input')
  tf.gfile.MakeDirs(os.path.dirname(input_path))
  file_io.write_string_to_file(input_path, 'test')

  input_artifact = types.TfxArtifact(type_name='InputPath')
  input_artifact.uri = input_path

  component = _FakeComponent(
      name='FakeComponent',
      input_channel=channel.as_channel([input_artifact]))

  pipeline_info = data_types.PipelineInfo(
      pipeline_name='Test', pipeline_root=pipeline_root, run_id='123')
  driver_args = data_types.DriverArgs(
      worker_name=component.component_id,
      base_output_dir=os.path.join(pipeline_root, component.component_id),
      enable_cache=True)

  launcher = component_launcher.ComponentLauncher(
      component=component,
      pipeline_info=pipeline_info,
      driver_args=driver_args,
      metadata_connection_config=connection_config,
      additional_pipeline_args={})
  self.assertEqual(
      launcher._component_info.component_type,
      '.'.join([_FakeComponent.__module__, _FakeComponent.__name__]))

  launcher.launch()

  output_path = os.path.join(pipeline_root, 'output')
  self.assertTrue(tf.gfile.Exists(output_path))
  contents = file_io.read_file_to_string(output_path)
  self.assertEqual('test', contents)
def testResolverWithResolverPolicy(self):
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(self._testdata_dir,
                   'pipeline_for_input_resolver_test.pbtxt'), pipeline)
  my_example_gen = pipeline.nodes[0].pipeline_node
  my_transform = pipeline.nodes[2].pipeline_node

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  with metadata.Metadata(connection_config=connection_config) as m:
    # Publishes two artifacts into ExampleGen's `output_examples` channel,
    # which will be consumed by the downstream Transform.
    output_example_1 = types.Artifact(
        my_example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example_1.uri = 'my_examples_uri_1'
    output_example_2 = types.Artifact(
        my_example_gen.outputs.outputs['output_examples'].artifact_spec.type)
    output_example_2.uri = 'my_examples_uri_2'
    contexts = context_lib.register_contexts_if_not_exists(
        m, my_example_gen.contexts)
    execution = execution_publish_utils.register_execution(
        m, my_example_gen.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'output_examples': [output_example_1, output_example_2],
        })

    my_transform.inputs.resolver_config.resolver_policy = (
        pipeline_pb2.ResolverConfig.LATEST_ARTIFACT)

    # Gets inputs for Transform. With the LATEST_ARTIFACT policy, only the
    # most recently published artifact in the `output_examples` channel
    # should be resolved.
    transform_inputs = inputs_utils.resolve_input_artifacts(
        m, my_transform.inputs)
    self.assertEqual(len(transform_inputs), 1)
    self.assertEqual(len(transform_inputs['examples']), 1)
    self.assertProtoPartiallyEquals(
        transform_inputs['examples'][0].mlmd_artifact,
        output_example_2.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def _get_metadata_store(grpc_max_receive_message_length=None):
  if FLAGS.use_grpc_backend:
    grpc_connection_config = metadata_store_pb2.MetadataStoreClientConfig()
    if grpc_max_receive_message_length:
      (grpc_connection_config.channel_arguments
       .max_receive_message_length) = grpc_max_receive_message_length
    if FLAGS.grpc_host is None:
      raise ValueError("grpc_host argument not set.")
    grpc_connection_config.host = FLAGS.grpc_host
    if not FLAGS.grpc_port:
      raise ValueError("grpc_port argument not set.")
    grpc_connection_config.port = FLAGS.grpc_port
    return metadata_store.MetadataStore(grpc_connection_config)

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  return metadata_store.MetadataStore(connection_config)
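A hedged usage sketch: callers can raise the gRPC receive limit when large responses are expected. The 32 MB value is an illustrative assumption, and the call requires the FLAGS referenced above to be defined and parsed.

# Illustrative call against whichever backend the flags select.
store = _get_metadata_store(grpc_max_receive_message_length=32 * 1024 * 1024)
artifact_types = store.get_artifact_types()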
def _create_pipeline():
  pipeline_name = _PIPELINE_NAME
  test_output_dir = 'gs://{}/test_output'.format(test_utils.BUCKET_NAME)
  pipeline_root = os.path.join(test_output_dir, pipeline_name)
  components = test_utils.create_e2e_components(pipeline_root,
                                                test_utils.DATA_ROOT,
                                                test_utils.TAXI_MODULE_FILE)
  return tfx_pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=components[:2],
      log_root='/var/tmp/tfx/logs',
      additional_pipeline_args={
          'WORKFLOW_ID': pipeline_name,
      },
  )
def setUp(self):
  super().setUp()
  self._connection_config = metadata_store_pb2.ConnectionConfig()
  self._connection_config.sqlite.SetInParent()
  self._module_file_path = os.path.join(self.tmp_dir, 'module_file')
  self._input_artifacts = {'input_examples': [standard_artifacts.Examples()]}
  self._output_artifacts = {'output_models': [standard_artifacts.Model()]}
  self._parameters = {'module_file': self._module_file_path}
  self._module_file_content = 'module content'
  self._pipeline_node = text_format.Parse(
      """
      executor {
        python_class_executor_spec {class_path: 'a.b.c'}
      }
      """, pipeline_pb2.PipelineNode())
  self._executor_class_path = 'a.b.c'
  self._pipeline_info = pipeline_pb2.PipelineInfo(id='pipeline_id')
def setUp(self):
  super().setUp()

  example_gen_output_config = data_types.RuntimeParameter(
      name='example-gen-output-config', ptype=str)

  example_gen = csv_example_gen_component.CsvExampleGen(
      input_base='data_root', output_config=example_gen_output_config)
  statistics_gen = statistics_gen_component.StatisticsGen(
      examples=example_gen.outputs['examples']).with_id('foo')

  test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
  pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[example_gen, statistics_gen],
  )

  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  self._metadata_config.mysql_db_service_host.environment_variable = (
      'MYSQL_SERVICE_HOST')
  self._tfx_ir = pipeline_pb2.Pipeline()
  with dsl.Pipeline('test_pipeline'):
    self.example_gen = base_component.BaseComponent(
        component=example_gen,
        depends_on=set(),
        pipeline=pipeline,
        pipeline_root=test_pipeline_root,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        tfx_ir=self._tfx_ir,
        pod_labels_to_attach={},
        runtime_parameters=[example_gen_output_config])
    self.statistics_gen = base_component.BaseComponent(
        component=statistics_gen,
        depends_on=set(),
        pipeline=pipeline,
        pipeline_root=test_pipeline_root,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        tfx_ir=self._tfx_ir,
        pod_labels_to_attach={},
        runtime_parameters=[])

  self.tfx_example_gen = example_gen
  self.tfx_statistics_gen = statistics_gen
def setUp(self):
  super(ImporterDriverTest, self).setUp()
  self.connection_config = metadata_store_pb2.ConnectionConfig()
  self.connection_config.sqlite.SetInParent()
  self.artifact_type = 'Examples'
  self.output_dict = {
      importer_node.IMPORT_RESULT_KEY:
          types.Channel(type_name=self.artifact_type)
  }
  self.source_uri = 'm/y/u/r/i'
  self.existing_artifact = types.Artifact(type_name=self.artifact_type)
  self.existing_artifact.uri = self.source_uri
  self.component_info = data_types.ComponentInfo(
      component_type='c_type', component_id='c_id')
  self.pipeline_info = data_types.PipelineInfo(
      pipeline_name='p_name', pipeline_root='p_root', run_id='run_id')
  self.driver_args = data_types.DriverArgs(enable_cache=True)
def __init__(self,
             components_to_always_add: List[BaseComponent],
             benchmark_subpipelines: List[BenchmarkSubpipeline],
             pipeline_name: Optional[str],
             pipeline_root: Optional[str],
             metadata_connection_config: Optional[
                 metadata_store_pb2.ConnectionConfig] = None,
             beam_pipeline_args: Optional[List[str]] = None,
             **kwargs):
  if not benchmark_subpipelines and not components_to_always_add:
    raise ValueError(
        "Requires at least one benchmark subpipeline or component to run. "
        "You may want to call `self.add(..., always=True)` in order "
        "to run Components, Subpipelines, or Pipelines even without requiring "
        "a call to `self.evaluate(...)`.")

  # Set defaults.
  if not pipeline_name:
    pipeline_name = "nitroml"
  if not pipeline_root:
    tmp_root_dir = os.path.join("/tmp", pipeline_name)
    tf.io.gfile.makedirs(tmp_root_dir)
    pipeline_root = tempfile.mkdtemp(dir=tmp_root_dir)
    logging.info("Creating tmp pipeline_root at %s", pipeline_root)
  if not metadata_connection_config:
    metadata_connection_config = metadata_store_pb2.ConnectionConfig(
        sqlite=metadata_store_pb2.SqliteMetadataSourceConfig(
            filename_uri=os.path.join(pipeline_root, "mlmd.sqlite")))

  # Ensure that pipeline dirs are created.
  _make_pipeline_dirs(pipeline_root, metadata_connection_config)

  components = set(components_to_always_add)
  for benchmark_subpipeline in benchmark_subpipelines:
    for component in benchmark_subpipeline.components:
      components.add(component)

  super().__init__(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=metadata_connection_config,
      components=list(components),
      beam_pipeline_args=beam_pipeline_args,
      **kwargs)
  self._subpipelines = benchmark_subpipelines
def testRun(self, mock_publisher):
  mock_publisher.return_value.publish_execution.return_value = {}

  example_gen = FileBasedExampleGen(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          avro_executor.Executor),
      input_base=self.avro_dir_path,
      input_config=self.input_config,
      output_config=self.output_config,
      instance_name='AvroExampleGen')

  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  pipeline_root = os.path.join(output_data_dir, 'Test')
  fileio.makedirs(pipeline_root)
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='Test', pipeline_root=pipeline_root, run_id='123')
  driver_args = data_types.DriverArgs(enable_cache=True)

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  metadata_connection = metadata.Metadata(connection_config)

  launcher = in_process_component_launcher.InProcessComponentLauncher.create(
      component=example_gen,
      pipeline_info=pipeline_info,
      driver_args=driver_args,
      metadata_connection=metadata_connection,
      beam_pipeline_args=[],
      additional_pipeline_args={})
  self.assertEqual(
      launcher._component_info.component_type,
      '.'.join(
          [FileBasedExampleGen.__module__, FileBasedExampleGen.__name__]))

  launcher.launch()
  mock_publisher.return_value.publish_execution.assert_called_once()

  # Check output paths.
  self.assertTrue(fileio.exists(os.path.join(pipeline_root, example_gen.id)))
def _input_artifacts(self, pipeline_name: Text,
                     input_artifacts: List[Artifact]) -> Channel:
  """Publish input artifacts for test to MLMD and return channel to them."""
  connection_config = metadata_store_pb2.ConnectionConfig(
      mysql=metadata_store_pb2.MySQLDatabaseConfig(
          host='127.0.0.1',
          port=3306,
          database=self._get_mlmd_db_name(pipeline_name),
          user='******',
          password=''))
  dummy_artifact = (input_artifacts[0].type_name, self._random_id())
  output_key = 'dummy_output_%s_%s' % dummy_artifact
  producer_component_id = 'dummy_producer_id_%s_%s' % dummy_artifact
  producer_component_type = 'dummy_producer_type_%s_%s' % dummy_artifact

  # Input artifacts must have a unique name and producer in MLMD.
  for artifact in input_artifacts:
    artifact.name = output_key
    artifact.pipeline_name = pipeline_name
    artifact.producer_component = producer_component_id

  with metadata.Metadata(connection_config=connection_config) as m:
    # Register a dummy execution to metadata store as producer execution.
    execution_id = m.register_execution(
        exec_properties={},
        pipeline_info=data_types.PipelineInfo(
            pipeline_name=pipeline_name,
            pipeline_root='/dummy_pipeline_root',
            # test_utils uses pipeline_name as fixed WORKFLOW_ID.
            run_id=pipeline_name,
        ),
        component_info=data_types.ComponentInfo(
            component_type=producer_component_type,
            component_id=producer_component_id))

    # Publish the test input artifact from the dummy execution.
    published_artifacts = m.publish_execution(
        execution_id=execution_id,
        input_dict={},
        output_dict={output_key: input_artifacts})

    return channel_utils.as_channel(published_artifacts[output_key])
def _create_pipeline():
  pipeline_name = _PIPELINE_NAME
  pipeline_root = os.path.join(test_utils.get_test_output_dir(), pipeline_name)
  components = test_utils.create_e2e_components(
      pipeline_root,
      test_utils.get_csv_input_location(),
      test_utils.get_transform_module(),
      test_utils.get_trainer_module(),
  )
  return tfx_pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=components[:4],
      log_root='/var/tmp/tfx/logs',
      additional_pipeline_args={
          'WORKFLOW_ID': pipeline_name,
      },
  )
def test_get_completed_trials(self):
  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.filename_uri = os.path.join(FLAGS.test_tmpdir, "3")
  connection_config.sqlite.connection_mode = 3
  handler = ml_metadata_db.MLMetaData(
      None, None, None, connection_config=connection_config)
  handler.before_generating_trial_model(trial_id=1, model_dir="/tmp/1")
  handler.before_generating_trial_model(trial_id=2, model_dir="/tmp/2")
  handler.report(eval_dictionary={"loss": 0.1}, model_dir="/tmp/1")
  handler.before_generating_trial_model(trial_id=3, model_dir="/tmp/3")
  handler.report(eval_dictionary={"loss": 0.3}, model_dir="/tmp/3")
  handler.report(eval_dictionary={"loss": 0.2}, model_dir="/tmp/2")
  output = handler.get_completed_trials()
  self.assertLen(output, 3)
  for i in range(3):
    self.assertEqual(output[i].status, "COMPLETED")
    self.assertEqual(output[i].model_dir, "/tmp/" + str(output[i].id))
    self.assertEqual(output[i].final_measurement.objective_value,
                     float(output[i].id) / 10)
def setUp(self):
  super(KubeflowMetadataAdapterTest, self).setUp()
  self._connection_config = metadata_store_pb2.ConnectionConfig()
  self._connection_config.sqlite.SetInParent()
  self._pipeline_info = data_types.PipelineInfo(
      pipeline_name='fake_pipeline_name',
      pipeline_root='/fake_pipeline_root',
      run_id='fake_run_id')
  self._pipeline_info2 = data_types.PipelineInfo(
      pipeline_name='fake_pipeline_name',
      pipeline_root='/fake_pipeline_root',
      run_id='fake_run_id2')
  self._component_info = data_types.ComponentInfo(
      component_type='fake.component.type',
      component_id='fake_component_id',
      pipeline_info=self._pipeline_info)
  self._component_info2 = data_types.ComponentInfo(
      component_type='fake.component.type',
      component_id='fake_component_id',
      pipeline_info=self._pipeline_info2)
def testRunWithSameSpec(self, mock_kube_utils):
  _initialize_executed_components()
  mock_kube_utils.is_inside_cluster.return_value = True

  component_a = _FakeComponent(
      spec=_FakeComponentSpecA(output=types.Channel(type=_ArtifactTypeA)))
  component_f1 = _FakeComponent(
      spec=_FakeComponentSpecF(a=component_a.outputs['output'])).with_id('f1')
  component_f2 = _FakeComponent(
      spec=_FakeComponentSpecF(a=component_a.outputs['output'])).with_id('f2')
  component_f2.add_upstream_node(component_f1)

  test_pipeline = pipeline.Pipeline(
      pipeline_name='x',
      pipeline_root='y',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[component_f1, component_f2, component_a])

  kubernetes_dag_runner.KubernetesDagRunner().run(test_pipeline)
  self.assertEqual(_executed_components,
                   ['a.Wrapper', 'f1.Wrapper', 'f2.Wrapper'])
def _two_step_pipeline() -> tfx_pipeline.Pipeline:
  default_input_config = json.dumps({
      'splits': [{
          'name': 'single_split',
          'pattern': 'SELECT * FROM default-table'
      }]
  })
  input_config = data_types.RuntimeParameter(
      name='input_config', ptype=str, default=default_input_config)
  example_gen = big_query_example_gen_component.BigQueryExampleGen(
      input_config=input_config, output_config=example_gen_pb2.Output())
  statistics_gen = statistics_gen_component.StatisticsGen(
      examples=example_gen.outputs['examples'])
  return tfx_pipeline.Pipeline(
      pipeline_name='two_step_pipeline',
      pipeline_root='pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[example_gen, statistics_gen],
  )
def testRun(self, mock_kube_utils):
  _initialize_executed_components()
  mock_kube_utils.is_inside_cluster.return_value = True

  component_a = _FakeComponent(
      spec=_FakeComponentSpecA(output=types.Channel(type=_ArtifactTypeA)))
  component_b = _FakeComponent(
      spec=_FakeComponentSpecB(
          a=component_a.outputs['output'],
          output=types.Channel(type=_ArtifactTypeB)))
  component_c = _FakeComponent(
      spec=_FakeComponentSpecC(
          a=component_a.outputs['output'],
          output=types.Channel(type=_ArtifactTypeC)))
  component_c.add_upstream_node(component_b)
  component_d = _FakeComponent(
      spec=_FakeComponentSpecD(
          b=component_b.outputs['output'],
          c=component_c.outputs['output'],
          output=types.Channel(type=_ArtifactTypeD)))
  component_e = _FakeComponent(
      spec=_FakeComponentSpecE(
          a=component_a.outputs['output'],
          b=component_b.outputs['output'],
          d=component_d.outputs['output'],
          output=types.Channel(type=_ArtifactTypeE)))

  test_pipeline = pipeline.Pipeline(
      pipeline_name='x',
      pipeline_root='y',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[
          component_d, component_c, component_a, component_b, component_e
      ])

  kubernetes_dag_runner.KubernetesDagRunner().run(test_pipeline)
  self.assertEqual(_executed_components, [
      '_FakeComponent.a.Wrapper', '_FakeComponent.b.Wrapper',
      '_FakeComponent.c.Wrapper', '_FakeComponent.d.Wrapper',
      '_FakeComponent.e.Wrapper'
  ])
def mysql_metadata_connection_config(
    host: Text, port: int, database: Text, username: Text,
    password: Text) -> metadata_store_pb2.ConnectionConfig:
  """Convenience function to create mysql-based metadata connection config.

  Args:
    host: The name or network address of the instance of MySQL to connect to.
    port: The port MySQL is using to listen for connections.
    database: The name of the database to use.
    username: The MySQL login account being used.
    password: The password for the MySQL account being used.

  Returns:
    A metadata_store_pb2.ConnectionConfig based on the given MySQL connection
    settings.
  """
  return metadata_store_pb2.ConnectionConfig(
      mysql=metadata_store_pb2.MySQLDatabaseConfig(
          host=host,
          port=port,
          database=database,
          user=username,
          password=password))
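Usage sketch for the helper above; the host, database name, and credentials below are placeholders for illustration, not values from the source.

connection_config = mysql_metadata_connection_config(
    host='127.0.0.1',
    port=3306,
    database='mlmd',
    username='mlmd_user',
    password='')  # placeholder credentials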
def testRun(self):
  component_a = _FakeComponent(
      _FakeComponentSpecA(output=types.Channel(type_name='a')))
  component_b = _FakeComponent(
      _FakeComponentSpecB(
          a=component_a.outputs['output'],
          output=types.Channel(type_name='b')))
  component_c = _FakeComponent(
      _FakeComponentSpecC(
          a=component_a.outputs['output'],
          output=types.Channel(type_name='c')))
  component_d = _FakeComponent(
      _FakeComponentSpecD(
          b=component_b.outputs['output'],
          c=component_c.outputs['output'],
          output=types.Channel(type_name='d')))
  component_e = _FakeComponent(
      _FakeComponentSpecE(
          a=component_a.outputs['output'],
          b=component_b.outputs['output'],
          d=component_d.outputs['output'],
          output=types.Channel(type_name='e')))

  test_pipeline = pipeline.Pipeline(
      pipeline_name='x',
      pipeline_root='y',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[
          component_d, component_c, component_a, component_b, component_e
      ])

  beam_dag_runner.BeamDagRunner().run(test_pipeline)
  self.assertItemsEqual(_executed_components, [
      '_FakeComponent.a', '_FakeComponent.b', '_FakeComponent.c',
      '_FakeComponent.d', '_FakeComponent.e'
  ])
  self.assertEqual(_executed_components[0], '_FakeComponent.a')
  self.assertEqual(_executed_components[3], '_FakeComponent.d')
  self.assertEqual(_executed_components[4], '_FakeComponent.e')