def testDockerComponentLauncherInBeam(self):
    """Runs the test pipeline with BeamDagRunner and checks the MLMD store.

    After the run, the SQLite store at `self._metadata_path` must hold
    exactly one execution.
    """
    runner = beam_dag_runner.BeamDagRunner()
    test_pipeline = _create_pipeline(
        pipeline_name=self._pipeline_name,
        pipeline_root=self._pipeline_root,
        metadata_path=self._metadata_path,
        name='docker_e2e_test_in_beam')
    runner.run(test_pipeline)

    connection_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(connection_config) as mlmd:
        recorded_executions = mlmd.store.get_executions()
        self.assertEqual(1, len(recorded_executions))
def setUp(self):
    """Creates per-test scratch paths and an SQLite MLMD connection config.

    Attributes set on the test instance:
      _tmp_file: path of an empty temporary file created for this test.
      _tmp_dir: path of a temporary directory created for this test.
      _original_tmp_value: value of TFX_JSON_EXPORT_PIPELINE_ARGS_PATH at
        setup time, or None when the variable is unset.
      _metadata_connection_config: SQLite connection config pointing at
        `<_tmp_dir>/metadata`.
    """
    super(PipelineTest, self).setUp()
    base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                              self.get_temp_dir())
    test_dir = os.path.join(base_dir, self._testMethodName)
    os.makedirs(test_dir, exist_ok=True)

    # mkstemp returns (fd, absolute_path). Joining that absolute path onto
    # another directory silently discards the directory, so the file is
    # created directly inside the per-test directory instead. The descriptor
    # is closed right away because only the path is used.
    tmp_fd, self._tmp_file = tempfile.mkstemp(prefix='cli_tmp_', dir=test_dir)
    os.close(tmp_fd)

    # mkdtemp returns the path itself (a str), not a tuple; indexing it with
    # [1] would yield a single character of the path.
    self._tmp_dir = tempfile.mkdtemp(prefix='cli_tmp_', dir=test_dir)

    # Back up the environment variable (presumably restored by a tearDown
    # that is not visible here).
    self._original_tmp_value = os.environ.get(
        'TFX_JSON_EXPORT_PIPELINE_ARGS_PATH')

    self._metadata_connection_config = (
        metadata.sqlite_metadata_connection_config(
            os.path.join(self._tmp_dir, 'metadata')))
def testPenguinPipelineLocal(self, model_framework):
    """Runs the local penguin pipeline three times and checks MLMD counts.

    The first run must record 9 executions. The second run must add three
    artifacts and another 9 executions. The third run must add no artifacts
    and another 9 executions.

    Args:
      model_framework: passed to `self._module_file_name` to pick the module
        file handed to the pipeline.
    """
    executions_per_run = 9  # 8 components + 1 resolver

    trainer_module = self._module_file_name(model_framework)
    penguin = penguin_pipeline_local._create_pipeline(
        pipeline_name=self._pipeline_name,
        data_root=self._data_root,
        module_file=trainer_module,
        accuracy_threshold=0.1,
        serving_model_dir=self._serving_model_dir,
        pipeline_root=self._pipeline_root,
        metadata_path=self._metadata_path,
        enable_tuning=False,
        examplegen_input_config=None,
        examplegen_range_config=None,
        resolver_range_config=None,
        beam_pipeline_args=[])

    # --- First run ---
    logging.info('Starting the first pipeline run.')
    LocalDagRunner().run(penguin)

    self.assertTrue(fileio.exists(self._serving_model_dir))
    self.assertTrue(fileio.exists(self._metadata_path))
    connection_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(connection_config) as mlmd:
        artifact_count = len(mlmd.store.get_artifacts())
        execution_count = len(mlmd.store.get_executions())
        self.assertGreaterEqual(artifact_count, execution_count)
        self.assertEqual(executions_per_run, execution_count)

    self._assertPipelineExecution(False)

    # --- Second run ---
    logging.info('Starting the second pipeline run. All components except '
                 'Evaluator and Pusher will use cached results.')
    LocalDagRunner().run(penguin)

    with metadata.Metadata(connection_config) as mlmd:
        # Evaluator and Pusher add three artifacts.
        artifacts_after_second_run = mlmd.store.get_artifacts()
        self.assertLen(artifacts_after_second_run, artifact_count + 3)
        artifact_count = len(artifacts_after_second_run)
        self.assertLen(mlmd.store.get_executions(), executions_per_run * 2)

    # --- Third run ---
    logging.info('Starting the third pipeline run. '
                 'All components will use cached results.')
    LocalDagRunner().run(penguin)

    with metadata.Metadata(connection_config) as mlmd:
        # No new artifacts are expected.
        self.assertLen(mlmd.store.get_artifacts(), artifact_count)
        self.assertLen(mlmd.store.get_executions(), executions_per_run * 3)
def create_pipeline(pipeline_name,
                    pipeline_root,
                    model_name,
                    text_path,
                    train_config,
                    mlflow_tracking_url,
                    encoding='utf-8',
                    combine=50000,
                    enable_cache=False):
    """Assembles the pipeline: download, dataset, train, export, MLflow.

    Args:
      pipeline_name: name of the pipeline, also used in the output paths.
      pipeline_root: base directory; outputs go to
        `<pipeline_root>/pipelines/<pipeline_name>`.
      model_name: passed to the download, training and MLflow components.
      text_path: passed to the dataset component.
      train_config: mapping copied into the module-level
        `default_train_config` and passed to the training and export
        components.
      mlflow_tracking_url: passed to the MLflow import component.
      encoding: passed to the dataset and training components.
      combine: passed to the dataset and training components.
      enable_cache: passed to `pipeline.Pipeline`.

    Returns:
      The constructed `pipeline.Pipeline`.
    """
    # Copy caller overrides into the shared module-level defaults.
    for option, setting in train_config.items():
        default_train_config[option] = setting

    downloader = DownloadPretrainedModel(model_name=model_name)
    base_model_path = downloader.outputs["model_path"]

    dataset_builder = CreateDataset(
        text_path=text_path,
        model_path=base_model_path,
        encoding=encoding,
        combine=combine)

    trainer = TrainGPT2(
        dataset_path=dataset_builder.outputs["dataset_path"],
        model_path=base_model_path,
        model_name=model_name,
        train_config=train_config,
        combine=combine,
        encoding=encoding)

    exporter = ExportToTFServing(
        model_path=base_model_path,
        checkpoint_dir=trainer.outputs["checkpoint_dir"],
        train_config=train_config)

    mlflow_importer = MLFlowImport(
        model_name=model_name,
        mlflow_tracking_url=mlflow_tracking_url,
        artifact_dir=trainer.outputs["sample_dir"],
        hyperparameter_dir=trainer.outputs["hyperparameter_dir"],
        metric_dir=trainer.outputs["metric_dir"],
        model_dir=exporter.outputs["export_dir"])

    pipeline_dir = os.path.join(pipeline_root, 'pipelines', pipeline_name)
    # The metadata database is placed under the pipeline directory computed
    # above, not directly under the caller-supplied root.
    metadata_db = os.path.join(pipeline_dir, 'metadata', pipeline_name,
                               'metadata.db')

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_dir,
        components=[
            dataset_builder,
            downloader,
            trainer,
            exporter,
            mlflow_importer,
        ],
        enable_cache=enable_cache,
        metadata_connection_config=(
            metadata.sqlite_metadata_connection_config(metadata_db)))
def setUp(self):
    """Builds the fake component, pipeline info, MLMD config and parent DAG.

    Attributes set on the test instance:
      _component: a `_FakeComponent` with a 'type_a' input channel and a
        'type_b' output channel.
      _pipeline_info: `PipelineInfo('name', 'root')`.
      _driver_args: `DriverArgs(True)`.
      _metadata_connection_config: SQLite config for a `metadata` file under
        TEST_TMP_DIR, or under the test's temp dir when that is unset.
      _parent_dag: an Airflow DAG named after the pipeline, with no schedule.
    """
    # Chain to the base class so its own fixture setup runs before
    # self.get_temp_dir() is used below.
    super().setUp()

    self._component = _FakeComponent(
        _FakeComponentSpec(
            input=channel.Channel(type_name='type_a'),
            output=channel.Channel(type_name='type_b')))
    self._pipeline_info = data_types.PipelineInfo('name', 'root')
    self._driver_args = data_types.DriverArgs(True)
    self._metadata_connection_config = (
        metadata.sqlite_metadata_connection_config(
            os.path.join(
                os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
                'metadata')))
    self._parent_dag = models.DAG(
        dag_id=self._pipeline_info.pipeline_name,
        start_date=datetime.datetime(2018, 1, 1),
        schedule_interval=None)
def run_pipeline():
    """Builds the pipeline from the module constants and runs it on Beam."""
    sqlite_config = metadata.sqlite_metadata_connection_config(METADATA_PATH)
    pipeline_kwargs = dict(
        pipeline_name=PIPELINE_NAME,
        pipeline_root=PIPELINE_ROOT,
        data_root=DATA_ROOT,
        test_data_root=TEST_DATA_ROOT,
        module_file=MODULE_FILE,
        serving_model_dir=SERVING_MODEL_DIR,
        enable_cache=True,
        metadata_connection_config=sqlite_config,
    )
    BeamDagRunner().run(create_pipeline(**pipeline_kwargs))
def run():
    """Builds the "fishing-classifier" pipeline and runs it locally.

    Inputs are read from `data`, outputs are written under `outputs`, and
    the SQLite metadata store is `outputs/metadata.db`.
    """
    runner = LocalDagRunner()
    fishing_pipeline = pipeline.create_pipeline(
        pipeline_name="fishing-classifier",
        data_path="data",
        outputs_path="outputs",
        output_model_path="outputs/model",
        train_args=trainer_pb2.TrainArgs(num_steps=100),
        eval_args=trainer_pb2.EvalArgs(num_steps=15),
        eval_accuracy_threshold=0.6,
        # Plain literal: the path has no placeholders, so no f-prefix.
        metadata_connection_config=sqlite_metadata_connection_config(
            "outputs/metadata.db"),
    )
    runner.run(fishing_pipeline)
def setUp(self):
    """Prepares the MLMD connection, pipeline, nodes and mocked job manager.

    The pipeline root is `<TEST_UNDECLARED_OUTPUTS_DIR or temp dir>/<test id>`
    and the SQLite metadata file is `metadata/metadata.db` below it.
    """
    super().setUp()
    base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                              self.get_temp_dir())
    self._pipeline_root = os.path.join(base_dir, self.id())

    # Makes sure multiple connections within a test always connect to the
    # same MLMD instance.
    self._metadata_path = os.path.join(self._pipeline_root, 'metadata',
                                       'metadata.db')
    sqlite_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    sqlite_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(connection_config=sqlite_config)

    # Sets up the pipeline with a fresh random second argument.
    self._pipeline = self._make_pipeline(self._pipeline_root,
                                         str(uuid.uuid4()))

    def _node(node_id):
        # Looks up one node of the pipeline built above.
        return test_utils.get_node(self._pipeline, node_id)

    # Extracts components.
    self._example_gen = _node('my_example_gen')
    self._stats_gen = _node('my_statistics_gen')
    self._schema_gen = _node('my_schema_gen')
    self._transform = _node('my_transform')
    self._example_validator = _node('my_example_validator')
    self._trainer = _node('my_trainer')
    self._evaluator = _node('my_evaluator')
    self._chore_a = _node('chore_a')
    self._chore_b = _node('chore_b')

    self._task_queue = tq.TaskQueue()

    def _is_example_gen(_, node_id):
        return node_id == self._example_gen.node_info.id

    def _is_transform(_, node_id):
        return node_id == self._transform.node_info.id

    def _default_ensure_node_services(unused_pipeline_state, node_id):
        # Only the two service nodes may ask for their services.
        self.assertIn(
            node_id,
            (self._example_gen.node_info.id, self._transform.node_info.id))
        return service_jobs.ServiceStatus.SUCCESS

    job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)
    job_manager.is_pure_service_node.side_effect = _is_example_gen
    job_manager.is_mixed_service_node.side_effect = _is_transform
    job_manager.ensure_node_services.side_effect = (
        _default_ensure_node_services)
    self._mock_service_job_manager = job_manager
def __init__(self,
             pipeline_name: Optional[Text] = None,
             pipeline_root: Optional[Text] = None,
             metadata_connection_config: Optional[
                 metadata_store_pb2.ConnectionConfig] = None,
             beam_pipeline_args: Optional[List[Text]] = None):
    """Initialize an InteractiveContext.

    Args:
      pipeline_name: Optional pipeline name. When falsy, a name of the form
        'interactive-<ISO timestamp>' is generated, with ':' replaced by '_'.
      pipeline_root: Optional root path for pipeline outputs. When falsy, a
        temporary directory is created with `tempfile.mkdtemp` and a warning
        is logged.
      metadata_connection_config: Optional
        metadata_store_pb2.ConnectionConfig. When falsy, a SQLite config is
        created for the file `self._DEFAULT_SQLITE_FILENAME` inside
        `pipeline_root` and a warning is logged.
      beam_pipeline_args: Optional list of Beam pipeline args. Stored as an
        empty list when falsy.
    """
    if not pipeline_name:
        timestamp = datetime.datetime.now().isoformat().replace(':', '_')
        pipeline_name = 'interactive-%s' % timestamp

    if not pipeline_root:
        pipeline_root = tempfile.mkdtemp(prefix='tfx-%s-' % pipeline_name)
        absl.logging.warning(
            'InteractiveContext pipeline_root argument not provided: using '
            'temporary directory %s as root for pipeline outputs.',
            pipeline_root)

    if not metadata_connection_config:
        # TODO(ccy): consider reconciling similar logic here with other
        # instances in tfx/orchestration/...
        sqlite_path = os.path.join(pipeline_root,
                                   self._DEFAULT_SQLITE_FILENAME)
        metadata_connection_config = (
            metadata.sqlite_metadata_connection_config(sqlite_path))
        absl.logging.warning(
            'InteractiveContext metadata_connection_config not provided: using '
            'SQLite ML Metadata database at %s.', sqlite_path)

    self.pipeline_name = pipeline_name
    self.pipeline_root = pipeline_root
    self.metadata_connection_config = metadata_connection_config
    self.beam_pipeline_args = beam_pipeline_args if beam_pipeline_args else []

    # Register IPython formatters.
    notebook_formatters.register_formatters()

    # Register artifact visualizations.
    standard_visualizations.register_standard_visualizations()
def setUp(self):
    """Prepares MLMD, the launcher test pipeline, and fake executor specs."""
    super(LauncherTest, self).setUp()
    base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                              self.get_temp_dir())
    pipeline_root = os.path.join(base_dir, self.id())

    # Makes sure multiple connections within a test always connect to the
    # same MLMD instance.
    sqlite_config = metadata.sqlite_metadata_connection_config(
        os.path.join(pipeline_root, 'metadata', 'metadata.db'))
    sqlite_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(connection_config=sqlite_config)

    self._testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')

    # Sets up pipelines: load the proto text from the testdata directory.
    launcher_pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(
        os.path.join(self._testdata_dir, 'pipeline_for_launcher_test.pbtxt'),
        launcher_pipeline)

    # Substitute the runtime parameter to be a concrete run_id.
    runtime_parameter_utils.substitute_runtime_parameter(
        launcher_pipeline,
        {constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'test_run'})

    self._pipeline_info = launcher_pipeline.pipeline_info
    runtime_spec = launcher_pipeline.runtime_spec
    self._pipeline_runtime_spec = runtime_spec
    runtime_spec.pipeline_root.field_value.string_value = pipeline_root
    runtime_spec.pipeline_run_id.field_value.string_value = 'test_run_0'

    # Extracts components; they are taken by position in the proto.
    nodes = launcher_pipeline.nodes
    self._example_gen = nodes[0].pipeline_node
    self._transform = nodes[1].pipeline_node
    self._trainer = nodes[2].pipeline_node
    self._importer = nodes[3].pipeline_node
    self._resolver = nodes[4].pipeline_node

    # Fakes an ExecutorSpec for Trainer.
    self._trainer_executor_spec = _PYTHON_CLASS_EXECUTABLE_SPEC()

    # Fakes an executor operator.
    self._test_executor_operators = {
        _PYTHON_CLASS_EXECUTABLE_SPEC: _FakeExecutorOperator,
    }

    # Fakes a custom driver spec.
    driver_spec = _PYTHON_CLASS_EXECUTABLE_SPEC()
    driver_spec.class_path = (
        'tfx.orchestration.portable.launcher_test._FakeExampleGenLikeDriver')
    self._custom_driver_spec = driver_spec
def _get_latest_output_artifact(self, component_name, output_key):
    """Returns the most recently created output artifact of a component.

    Args:
      component_name: substring matched against execution type names. Exactly
        one type must match, otherwise the unpacking raises ValueError.
      output_key: not used by this implementation.

    Returns:
      The artifact with the largest `create_time_since_epoch` among artifacts
      linked by events whose type is in `_OUTPUT_EVENT_TYPES`.
    """
    connection_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(connection_config) as mlmd:
        store = mlmd.store

        matching_type_names = [
            execution_type.name
            for execution_type in store.get_execution_types()
            if component_name in execution_type.name
        ]
        # Exactly one execution type is expected to match.
        (type_name,) = matching_type_names

        execution_ids = [
            execution.id
            for execution in store.get_executions_by_type(type_name)
        ]
        artifact_ids = []
        for event in store.get_events_by_execution_ids(execution_ids):
            if event.type in _OUTPUT_EVENT_TYPES:
                artifact_ids.append(event.artifact_id)

        candidates = store.get_artifacts_by_id(artifact_ids)
        self.assertNotEmpty(candidates)
        return max(
            candidates,
            key=lambda artifact: artifact.create_time_since_epoch)
def create_pipeline(pipeline_name,
                    pipeline_root,
                    checkpoint_dir,
                    model_name,
                    text_dir,
                    train_config,
                    mlflow_tracking_url,
                    encoding='utf-8',
                    text_token_size=50000,
                    enable_cache=False,
                    end_token="<|endoftext|>"):
    """Assembles the pipeline: merge, copy, encode, train, export, MLflow.

    Args:
      pipeline_name: name of the pipeline, also used in the output paths.
      pipeline_root: base directory; outputs go to
        `<pipeline_root>/pipelines/<pipeline_name>` and the metadata database
        to `<pipeline_root>/metadata/<pipeline_name>/metadata.db`.
      checkpoint_dir: passed to the checkpoint copy component.
      model_name: passed to the training and MLflow components.
      text_dir: passed to the text merge component.
      train_config: mapping copied into the module-level
        `default_train_config` and passed to the training and export
        components.
      mlflow_tracking_url: passed to the MLflow import component.
      encoding: passed to the merge, encode and training components.
      text_token_size: not used by this implementation.
      enable_cache: passed to `pipeline.Pipeline`.
      end_token: passed to the merge and encode components.

    Returns:
      The constructed `pipeline.Pipeline`.
    """
    # Copy caller overrides into the shared module-level defaults.
    for option, setting in train_config.items():
        default_train_config[option] = setting

    text_merger = CreateMergedText(
        text_dir=text_dir, end_token=end_token, encoding=encoding)
    checkpoint_copy = CopyCheckpoint(checkpoint_dir=checkpoint_dir)
    copied_model_dir = checkpoint_copy.outputs["model_dir"]

    dataset_encoder = CreateEncodedDataset(
        merged_text_dir=text_merger.outputs["merged_text_dir"],
        encoding_dir=copied_model_dir,
        encoding=encoding,
        end_token=end_token)

    trainer = TrainGPT2(
        dataset_dir=dataset_encoder.outputs["dataset_dir"],
        checkpoint_dir=copied_model_dir,
        encoding_dir=copied_model_dir,
        model_name=model_name,
        train_config=train_config,
        encoding=encoding)

    exporter = ExportToTFServing(
        encoding_dir=copied_model_dir,
        checkpoint_dir=trainer.outputs["trained_checkpoint_dir"],
        train_config=train_config)

    mlflow_importer = MLFlowImport(
        model_name=model_name,
        mlflow_tracking_url=mlflow_tracking_url,
        artifact_dir=trainer.outputs["sample_dir"],
        hyperparameter_dir=trainer.outputs["hyperparameter_dir"],
        metric_dir=trainer.outputs["metric_dir"],
        model_dir=exporter.outputs["export_dir"])

    output_dir = os.path.join(pipeline_root, 'pipelines', pipeline_name)
    metadata_db = os.path.join(pipeline_root, 'metadata', pipeline_name,
                               'metadata.db')

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=output_dir,
        components=[
            text_merger,
            checkpoint_copy,
            dataset_encoder,
            trainer,
            exporter,
            mlflow_importer,
        ],
        enable_cache=enable_cache,
        metadata_connection_config=(
            metadata.sqlite_metadata_connection_config(metadata_db)))
def create_dag(name, url, output_dir, airflow_config):
    """Builds a one-component crawler pipeline and runs it via Airflow.

    Args:
      name: pipeline name, also used in the output paths.
      url: passed to `NewsCrawler`.
      output_dir: base directory for pipeline outputs and the metadata
        database.
      airflow_config: passed to `AirflowPipelineConfig`.

    Returns:
      Whatever `AirflowDagRunner.run` returns for the pipeline.
    """
    output_root = os.path.join(output_dir, 'pipelines', name)
    metadata_db = os.path.join(output_dir, 'metadata', name, 'metadata.db')

    news_crawler = NewsCrawler(url=url)
    crawler_pipeline = pipeline.Pipeline(
        pipeline_name=name,
        pipeline_root=output_root,
        components=[news_crawler],
        enable_cache=False,
        metadata_connection_config=(
            metadata.sqlite_metadata_connection_config(metadata_db)))

    runner = AirflowDagRunner(AirflowPipelineConfig(airflow_config))
    return runner.run(crawler_pipeline)
def setUp(self):
    """Opens an MLMD connection backed by a per-test SQLite file."""
    super(PipelineOpsTest, self).setUp()
    base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                              self.get_temp_dir())
    test_root = os.path.join(base_dir, self.id())

    # Makes sure multiple connections within a test always connect to the
    # same MLMD instance.
    self._metadata_path = os.path.join(test_root, 'metadata', 'metadata.db')
    sqlite_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    sqlite_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(connection_config=sqlite_config)
def init_pipeline(components, pipeline_root: Text,
                  direct_num_workers: int) -> pipeline.Pipeline:
    """Builds a cached pipeline with an SQLite metadata store.

    `pipeline_name` and `metadata_path` are not parameters; they are read
    from the enclosing scope.

    Args:
      components: components passed to `pipeline.Pipeline`.
      pipeline_root: root directory for pipeline outputs.
      direct_num_workers: value of the `--direct_num_workers` Beam flag.

    Returns:
      The constructed `pipeline.Pipeline`.
    """
    sqlite_config = metadata.sqlite_metadata_connection_config(metadata_path)
    beam_flags = [f'--direct_num_workers={direct_num_workers}']
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=True,
        metadata_connection_config=sqlite_config,
        beam_pipeline_args=beam_flags)
def testMNISTPipelineNativeKeras(self):
    """Runs the MNIST native-Keras pipeline twice and checks MLMD counts.

    The first run must record 11 executions. The second run must add no
    artifacts and another 11 executions. The test is skipped when TensorFlow
    is not executing eagerly.
    """
    if not tf.executing_eagerly():
        self.skipTest('The test requires TF2.')

    executions_per_run = 11

    def _build_pipeline():
        # Each run gets a freshly constructed pipeline with the same args.
        return mnist_pipeline_native_keras._create_pipeline(
            pipeline_name=self._pipeline_name,
            data_root=self._data_root,
            module_file=self._module_file,
            module_file_lite=self._module_file_lite,
            serving_model_dir=self._serving_model_dir,
            serving_model_dir_lite=self._serving_model_dir_lite,
            pipeline_root=self._pipeline_root,
            metadata_path=self._metadata_path,
            beam_pipeline_args=[])

    # First run.
    BeamDagRunner().run(_build_pipeline())

    for expected_path in (self._serving_model_dir,
                          self._serving_model_dir_lite,
                          self._metadata_path):
        self.assertTrue(fileio.exists(expected_path))

    connection_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(connection_config) as mlmd:
        artifact_count = len(mlmd.store.get_artifacts())
        execution_count = len(mlmd.store.get_executions())
        self.assertGreaterEqual(artifact_count, execution_count)
        self.assertEqual(execution_count, executions_per_run)

    self.assertPipelineExecution()

    # Runs pipeline the second time.
    BeamDagRunner().run(_build_pipeline())

    # Asserts cache execution.
    with metadata.Metadata(connection_config) as mlmd:
        # Artifact count is unchanged.
        self.assertLen(mlmd.store.get_artifacts(), artifact_count)
        self.assertLen(mlmd.store.get_executions(), executions_per_run * 2)
def setUp(self):
    """Prepares MLMD, the sync pipeline, its nodes and a mocked job manager.

    The pipeline is loaded from `testdata/sync_pipeline.pbtxt`. Its
    'pipeline_root' and 'pipeline_run_id' runtime parameters are replaced by
    the per-test root and a random UUID.
    """
    super(SyncPipelineTaskGeneratorTest, self).setUp()
    base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                              self.get_temp_dir())
    self._pipeline_root = os.path.join(base_dir, self.id())

    # Makes sure multiple connections within a test always connect to the
    # same MLMD instance.
    self._metadata_path = os.path.join(self._pipeline_root, 'metadata',
                                       'metadata.db')
    sqlite_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    sqlite_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(connection_config=sqlite_config)

    # Sets up the pipeline.
    sync_pipeline = pipeline_pb2.Pipeline()
    pbtxt_path = os.path.join(
        os.path.dirname(__file__), 'testdata', 'sync_pipeline.pbtxt')
    self.load_proto_from_text(pbtxt_path, sync_pipeline)
    self._pipeline_run_id = str(uuid.uuid4())
    runtime_parameter_utils.substitute_runtime_parameter(
        sync_pipeline,
        {
            'pipeline_root': self._pipeline_root,
            'pipeline_run_id': self._pipeline_run_id,
        })
    self._pipeline = sync_pipeline

    # Extracts components.
    self._example_gen = _get_node(sync_pipeline, 'my_example_gen')
    self._stats_gen = _get_node(sync_pipeline, 'my_statistics_gen')
    self._schema_gen = _get_node(sync_pipeline, 'my_schema_gen')
    self._transform = _get_node(sync_pipeline, 'my_transform')
    self._example_validator = _get_node(sync_pipeline, 'my_example_validator')
    self._trainer = _get_node(sync_pipeline, 'my_trainer')

    self._task_queue = tq.TaskQueue()

    def _is_example_gen(_, node_id):
        return node_id == self._example_gen.node_info.id

    def _is_transform(_, node_id):
        return node_id == self._transform.node_info.id

    job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)
    job_manager.is_pure_service_node.side_effect = _is_example_gen
    job_manager.is_mixed_service_node.side_effect = _is_transform
    self._mock_service_job_manager = job_manager
def assertComponentExecutionCount(self, count: int) -> None:
    """Asserts that MLMD recorded exactly `count` executions.

    Also asserts that the metadata file at `self.metadata_path` exists and
    that the store holds at least as many artifacts as executions.

    Args:
      count: expected number of executions in the metadata store.
    """
    self.assertTrue(tf.io.gfile.exists(self.metadata_path))

    connection_config = metadata.sqlite_metadata_connection_config(
        self.metadata_path)
    with metadata.Metadata(connection_config) as mlmd:
        store = mlmd.store
        num_artifacts = len(store.get_artifacts())
        num_executions = len(store.get_executions())
        self.assertGreaterEqual(num_artifacts, num_executions)
        self.assertEqual(count, num_executions)
def _create_pipeline(pipeline_name: Text, pipeline_root: Text,
                     data_root: Text,
                     metadata_path: Text) -> pipeline.Pipeline:
    """Builds a cached pipeline containing only a CsvExampleGen.

    Args:
      pipeline_name: name of the pipeline.
      pipeline_root: root directory for pipeline outputs.
      data_root: passed to `CsvExampleGen` as `input_base`.
      metadata_path: path of the SQLite metadata database.

    Returns:
      The constructed `pipeline.Pipeline`.
    """
    # Brings data into the pipeline.
    csv_example_gen = CsvExampleGen(input_base=data_root)
    sqlite_config = metadata.sqlite_metadata_connection_config(metadata_path)

    taxi_pipeline = pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[csv_example_gen],
        enable_cache=True,
        metadata_connection_config=sqlite_config,
        additional_pipeline_args={},
    )
    return taxi_pipeline
def setUp(self):
    """Builds the fake component, pipeline info, MLMD config and parent DAG."""
    super(AirflowComponentTest, self).setUp()

    fake_spec = _FakeComponentSpec(
        input=types.Channel(type=_ArtifactTypeA),
        output=types.Channel(type=_ArtifactTypeB))
    self._component = _FakeComponent(fake_spec, enable_cache=True)

    self._pipeline_info = data_types.PipelineInfo('name', 'root')
    self._driver_args = data_types.DriverArgs(True)

    # The metadata file is placed under TEST_TMP_DIR when set, otherwise
    # under the test's own temp dir.
    scratch_dir = os.environ.get('TEST_TMP_DIR', self.get_temp_dir())
    self._metadata_connection_config = (
        metadata.sqlite_metadata_connection_config(
            os.path.join(scratch_dir, 'metadata')))

    self._parent_dag = models.DAG(
        dag_id=self._pipeline_info.pipeline_name,
        start_date=datetime.datetime(2018, 1, 1),
        schedule_interval=None)
def create_dag(name,
               url,
               airflow_config,
               backup_dir="pipelines_backup",
               mongo_ip=None,
               mongo_port=None,
               dag_type="default",
               output_dir="/output",
               updated_collections=None,
               update_collections=None):
    """Builds a pipeline for the requested DAG type and runs it via Airflow.

    Args:
      name: DAG name; any ".py" substring is removed to form the pipeline
        name.
      url: passed to `NewsCrawler` (only for dag_type "default").
      airflow_config: dict passed to `AirflowPipelineConfig`. A copy with
        "catchup" set to False is used; the caller's dict is not modified.
      backup_dir: sub-directory of "/output" passed to `OldNewsImport` (only
        for dag_type "backup").
      mongo_ip: passed to the Mongo-related components.
      mongo_port: passed to the Mongo-related components.
      dag_type: "default" (crawl and import), "backup" (import old news) or
        "update" (update news). Any other value gives a pipeline with no
        components.
      output_dir: base directory for pipeline outputs and the metadata
        database.
      updated_collections: passed to `UpdateMongoNews`; None means an empty
        list.
      update_collections: passed to `UpdateMongoNews`; None means an empty
        list.

    Returns:
      Whatever `AirflowDagRunner.run` returns for the pipeline.
    """
    # Use None sentinels rather than shared mutable list defaults.
    if updated_collections is None:
        updated_collections = []
    if update_collections is None:
        update_collections = []

    pipeline_name = name.replace(".py", "")
    pipeline_root = os.path.join(output_dir, 'pipelines', pipeline_name)
    metadata_path = os.path.join(output_dir, 'metadata', pipeline_name,
                                 'metadata.db')

    components = []
    if dag_type == "default":
        crawler = NewsCrawler(url=url)
        mongo = MongoImport(
            ip=mongo_ip,
            port=mongo_port,
            rss_feed=crawler.outputs["rss_feed"],
            colname=pipeline_name)
        components = [crawler, mongo]
    elif dag_type == "backup":
        # NOTE(review): the backup location is rooted at the literal
        # "/output", not at `output_dir` - confirm this is intended.
        load_news = OldNewsImport(
            backup_dir=os.path.join("/output", backup_dir),
            ip=mongo_ip,
            port=mongo_port)
        components = [load_news]
    elif dag_type == "update":
        update_news = UpdateMongoNews(
            ip=mongo_ip,
            port=mongo_port,
            updated_collections=updated_collections,
            update_collections=update_collections)
        components = [update_news]

    # Work on a copy so the caller's configuration dict is left untouched.
    airflow_config = dict(airflow_config)
    airflow_config["catchup"] = False

    tfx_pipeline = pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=False,
        metadata_connection_config=(
            metadata.sqlite_metadata_connection_config(metadata_path)))
    return AirflowDagRunner(
        AirflowPipelineConfig(airflow_config)).run(tfx_pipeline)
def _create_pipeline(
    pipeline_name,
    pipeline_root,
    metadata_path,
    name,
):
    """Builds a cached pipeline containing one `_HelloWorldComponent`.

    Args:
      pipeline_name: name of the pipeline.
      pipeline_root: root directory for pipeline outputs.
      metadata_path: path of the SQLite metadata database.
      name: passed to `_HelloWorldComponent`.

    Returns:
      The constructed `pipeline.Pipeline`.
    """
    greeter = _HelloWorldComponent(name=name)
    sqlite_config = metadata.sqlite_metadata_connection_config(metadata_path)
    hello_pipeline = pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[greeter],
        enable_cache=True,
        metadata_connection_config=sqlite_config,
        additional_pipeline_args={},
    )
    return hello_pipeline
def testCIFAR10PipelineNativeKeras(self):
    """Runs the CIFAR-10 native-Keras pipeline three times and checks MLMD.

    The first run must record 9 executions. The second run must add three
    artifacts and another 9 executions. The third run must add no artifacts
    and another 9 executions.
    """
    executions_per_run = 9  # 8 components + 1 resolver

    cifar_pipeline = cifar10_pipeline_native_keras._create_pipeline(
        pipeline_name=self._pipeline_name,
        data_root=self._data_root,
        module_file=self._module_file,
        serving_model_dir_lite=self._serving_model_dir_lite,
        pipeline_root=self._pipeline_root,
        metadata_path=self._metadata_path,
        labels_path=self._labels_path,
        beam_pipeline_args=[])

    # First run.
    BeamDagRunner().run(cifar_pipeline)

    self.assertTrue(fileio.exists(self._serving_model_dir_lite))
    self.assertTrue(fileio.exists(self._metadata_path))
    connection_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(connection_config) as mlmd:
        artifact_count = len(mlmd.store.get_artifacts())
        execution_count = len(mlmd.store.get_executions())
        self.assertGreaterEqual(artifact_count, execution_count)
        self.assertEqual(executions_per_run, execution_count)

    self.assertPipelineExecution()

    # Runs pipeline the second time.
    BeamDagRunner().run(cifar_pipeline)

    # All executions but Evaluator and Pusher are cached.
    with metadata.Metadata(connection_config) as mlmd:
        # Artifact count is increased by 3 caused by Evaluator and Pusher.
        artifacts_now = len(mlmd.store.get_artifacts())
        self.assertEqual(artifact_count + 3, artifacts_now)
        artifact_count = artifacts_now
        self.assertEqual(executions_per_run * 2,
                         len(mlmd.store.get_executions()))

    # Runs pipeline the third time.
    BeamDagRunner().run(cifar_pipeline)

    # Asserts cache execution.
    with metadata.Metadata(connection_config) as mlmd:
        # Artifact count is unchanged.
        self.assertEqual(artifact_count, len(mlmd.store.get_artifacts()))
        self.assertEqual(executions_per_run * 3,
                         len(mlmd.store.get_executions()))
def init_pipeline(components, pipeline_root: Text,
                  direct_num_workers: int) -> pipeline.Pipeline:
    """Builds a cached pipeline from the settings in `config`.

    Args:
      components: components passed to `pipeline.Pipeline`.
      pipeline_root: accepted for interface compatibility. NOTE(review): this
        value is not used; the pipeline root comes from
        `config.PIPELINE_ROOT` - confirm callers pass the same path.
      direct_num_workers: value of the `--direct_num_workers` Beam flag.

    Returns:
      The constructed `pipeline.Pipeline`.
    """
    beam_arg = [
        f"--direct_num_workers={direct_num_workers}",
        # Plain literal: this flag has no placeholders, so no f-prefix.
        "--direct_running_mode=multi_processing",
    ]
    tfx_pipeline = pipeline.Pipeline(
        pipeline_name=config.PIPELINE_NAME,
        pipeline_root=config.PIPELINE_ROOT,
        components=components,
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            config.METADATA_PATH),
        beam_pipeline_args=beam_arg,
    )
    return tfx_pipeline
def setUp(self):
    """Prepares MLMD, the sync pipeline, its nodes and a mocked job manager.

    The pipeline is loaded from `testdata/sync_pipeline.pbtxt`. Its runtime
    spec is pointed at the per-test root and the run id 'run_0'.
    """
    super(SyncPipelineTaskGeneratorTest, self).setUp()
    base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                              self.get_temp_dir())
    self._pipeline_root = os.path.join(base_dir, self.id())

    # Makes sure multiple connections within a test always connect to the
    # same MLMD instance.
    self._metadata_path = os.path.join(self._pipeline_root, 'metadata',
                                       'metadata.db')
    sqlite_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    sqlite_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(connection_config=sqlite_config)

    # Sets up the pipeline.
    sync_pipeline = pipeline_pb2.Pipeline()
    pbtxt_path = os.path.join(
        os.path.dirname(__file__), 'testdata', 'sync_pipeline.pbtxt')
    self.load_proto_from_text(pbtxt_path, sync_pipeline)
    self._pipeline = sync_pipeline
    self._pipeline_info = sync_pipeline.pipeline_info

    runtime_spec = sync_pipeline.runtime_spec
    self._pipeline_runtime_spec = runtime_spec
    runtime_spec.pipeline_root.field_value.string_value = self._pipeline_root
    runtime_spec.pipeline_run_id.field_value.string_value = 'run_0'

    # Extracts components; they are taken by position in the proto.
    nodes = sync_pipeline.nodes
    self._example_gen = nodes[0].pipeline_node
    self._transform = nodes[1].pipeline_node
    self._trainer = nodes[2].pipeline_node

    self._task_queue = tq.TaskQueue()

    job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)
    # Only the example gen node is reported as a pure service node.
    job_manager.is_pure_service_node.side_effect = (
        lambda unused_pipeline_state, node_id:
        node_id == self._example_gen.node_info.id)
    self._mock_service_job_manager = job_manager
def testIrisPipelineNativeKeras(self):
    """Runs the Iris native-Keras pipeline three times and checks MLMD.

    The first run must record 9 executions. The second run must add three
    artifacts and another 9 executions. The third run must add no artifacts
    and another 9 executions.
    """
    executions_per_run = 9  # 8 components + 1 resolver

    iris_pipeline = iris_pipeline_native_keras._create_pipeline(
        pipeline_name=self._pipeline_name,
        data_root=self._data_root,
        module_file=self._module_file,
        serving_model_dir=self._serving_model_dir,
        pipeline_root=self._pipeline_root,
        metadata_path=self._metadata_path,
        enable_tuning=False,
        direct_num_workers=1)

    # First run.
    BeamDagRunner().run(iris_pipeline)

    self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
    self.assertTrue(tf.io.gfile.exists(self._metadata_path))
    connection_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(connection_config) as mlmd:
        artifact_count = len(mlmd.store.get_artifacts())
        execution_count = len(mlmd.store.get_executions())
        self.assertGreaterEqual(artifact_count, execution_count)
        self.assertEqual(executions_per_run, execution_count)

    self.assertPipelineExecution(False)

    # Runs pipeline the second time.
    BeamDagRunner().run(iris_pipeline)

    # All executions but Evaluator and Pusher are cached.
    with metadata.Metadata(connection_config) as mlmd:
        # Artifact count is increased by 3 caused by Evaluator and Pusher.
        artifacts_now = len(mlmd.store.get_artifacts())
        self.assertEqual(artifact_count + 3, artifacts_now)
        artifact_count = artifacts_now
        self.assertEqual(executions_per_run * 2,
                         len(mlmd.store.get_executions()))

    # Runs pipeline the third time.
    BeamDagRunner().run(iris_pipeline)

    # Asserts cache execution.
    with metadata.Metadata(connection_config) as mlmd:
        # Artifact count is unchanged.
        self.assertEqual(artifact_count, len(mlmd.store.get_artifacts()))
        self.assertEqual(executions_per_run * 3,
                         len(mlmd.store.get_executions()))
def init_beam_pipeline(components, pipeline_root: Text,
                       direct_num_workers: int) -> pipeline.Pipeline:
    """Builds an uncached pipeline with an SQLite metadata store.

    `pipeline_name`, `metadata_path` and `requirement_file` are not
    parameters; they are read from the enclosing scope.

    Args:
      components: components passed to `pipeline.Pipeline`.
      pipeline_root: root directory for pipeline outputs; also logged.
      direct_num_workers: value of the `--direct_num_workers` Beam flag.

    Returns:
      The constructed `pipeline.Pipeline`.
    """
    absl.logging.info(f'Pipeline root set to: {pipeline_root}')

    workers_flag = f'--direct_num_workers={direct_num_workers}'
    requirements_flag = f'--requirements_file={requirement_file}'  # optional
    sqlite_config = metadata.sqlite_metadata_connection_config(metadata_path)

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=False,
        metadata_connection_config=sqlite_config,
        beam_pipeline_args=[workers_flag, requirements_flag])
def init_pipeline(components, pipeline_root: Text,
                  direct_num_workers: int) -> pipeline.Pipeline:
    """Builds a cached pipeline that runs Beam in multi-processing mode.

    `pipeline_name` and `metadata_path` are not parameters; they are read
    from the enclosing scope.

    Args:
      components: components passed to `pipeline.Pipeline`.
      pipeline_root: root directory for pipeline outputs.
      direct_num_workers: value of the `--direct_num_workers` Beam flag.

    Returns:
      The constructed `pipeline.Pipeline`.
    """
    sqlite_config = metadata.sqlite_metadata_connection_config(metadata_path)
    # The Beam flags are passed as a tuple, as in the original code.
    beam_flags = (
        f"--direct_num_workers={direct_num_workers}",
        "--direct_running_mode=multi_processing",
    )
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=True,
        metadata_connection_config=sqlite_config,
        beam_pipeline_args=beam_flags,
    )
def testDockerComponentLauncherInBeam(self):
    """Runs the test pipeline on Beam with the Docker component launcher.

    The runner is configured with `DockerComponentLauncher` as its only
    supported launcher class and a default `DockerComponentConfig`. After
    the run, the SQLite store at `self._metadata_path` must hold exactly one
    execution.
    """
    launcher_config = pipeline_config.PipelineConfig(
        supported_launcher_classes=[
            docker_component_launcher.DockerComponentLauncher,
        ],
        default_component_configs=[
            docker_component_config.DockerComponentConfig(),
        ])
    runner = beam_dag_runner.BeamDagRunner(config=launcher_config)

    test_pipeline = _create_pipeline(
        pipeline_name=self._pipeline_name,
        pipeline_root=self._pipeline_root,
        metadata_path=self._metadata_path,
        name='docker_e2e_test_in_beam')
    runner.run(test_pipeline)

    connection_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(connection_config) as mlmd:
        recorded_executions = mlmd.store.get_executions()
        self.assertEqual(1, len(recorded_executions))
def _create_pipeline(pipeline_name: Text, pipeline_root: Text,
                     data_root: Text,
                     metadata_path: Text) -> pipeline.Pipeline:
    """Builds a cached pipeline with CsvExampleGen and StatisticsGen.

    Args:
      pipeline_name: name of the pipeline.
      pipeline_root: root directory for pipeline outputs.
      data_root: wrapped with `external_input` and fed to `CsvExampleGen`.
      metadata_path: path of the SQLite metadata database.

    Returns:
      The constructed `pipeline.Pipeline`.
    """
    # Brings data into the pipeline.
    csv_example_gen = CsvExampleGen(input_base=external_input(data_root))

    # Computes statistics over the generated examples.
    stats_gen = StatisticsGen(input_data=csv_example_gen.outputs['examples'])

    sqlite_config = metadata.sqlite_metadata_connection_config(metadata_path)
    taxi_pipeline = pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[csv_example_gen, stats_gen],
        enable_cache=True,
        metadata_connection_config=sqlite_config)
    return taxi_pipeline