Esempio n. 1
0
    def testDockerComponentLauncherInBeam(self):
        """Runs the `_create_pipeline` pipeline on Beam and checks MLMD.

        After a single BeamDagRunner run, exactly one execution must be
        recorded in the SQLite metadata store at `self._metadata_path`.
        """
        runner = beam_dag_runner.BeamDagRunner()
        runner.run(
            _create_pipeline(
                pipeline_name=self._pipeline_name,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                name='docker_e2e_test_in_beam'))

        connection_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(connection_config) as mlmd:
            recorded_executions = mlmd.store.get_executions()
            self.assertEqual(1, len(recorded_executions))
Esempio n. 2
0
 def setUp(self):
   """Creates per-test scratch paths and an MLMD SQLite connection config.

   Sets `self._tmp_file` (an empty, closed temp file) and `self._tmp_dir`
   (a fresh temp directory), both located under
   `<TEST_UNDECLARED_OUTPUTS_DIR or temp dir>/<test method name>/`. It also
   backs up the current value of TFX_JSON_EXPORT_PIPELINE_ARGS_PATH and
   builds a SQLite connection config at `<self._tmp_dir>/metadata`.
   """
   super(PipelineTest, self).setUp()
   tmp_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir())
   # Create the scratch entries *inside* the per-test directory via `dir=`.
   # Joining tmp_dir with the absolute path that mkstemp returns would
   # discard tmp_dir entirely. Indexing mkdtemp()'s string result with [1]
   # would yield a single character, not a path.
   test_tmp_dir = os.path.join(tmp_dir, self._testMethodName)
   os.makedirs(test_tmp_dir, exist_ok=True)
   # mkstemp also returns an open OS-level descriptor; close it so it is not
   # leaked for the lifetime of the test process.
   tmp_fd, self._tmp_file = tempfile.mkstemp(
       prefix='cli_tmp_', dir=test_tmp_dir)
   os.close(tmp_fd)
   self._tmp_dir = tempfile.mkdtemp(prefix='cli_tmp_', dir=test_tmp_dir)
   # Back up the environmental variable.
   self._original_tmp_value = os.environ.get(
       'TFX_JSON_EXPORT_PIPELINE_ARGS_PATH')
   self._metadata_connection_config = metadata.sqlite_metadata_connection_config(
       os.path.join(self._tmp_dir, 'metadata'))
    def testPenguinPipelineLocal(self, model_framework):
        """Runs the local penguin pipeline three times and verifies caching.

        - Run 1 must record 9 executions.
        - Run 2 must add exactly 3 artifacts (from Evaluator and Pusher).
        - Run 3 must add no artifacts.
        - Every run adds another 9 executions.

        Args:
          model_framework: Forwarded to `self._module_file_name` to choose
            the module file for the pipeline.
        """
        penguin_pipeline = penguin_pipeline_local._create_pipeline(
            pipeline_name=self._pipeline_name,
            data_root=self._data_root,
            module_file=self._module_file_name(model_framework),
            accuracy_threshold=0.1,
            serving_model_dir=self._serving_model_dir,
            pipeline_root=self._pipeline_root,
            metadata_path=self._metadata_path,
            enable_tuning=False,
            examplegen_input_config=None,
            examplegen_range_config=None,
            resolver_range_config=None,
            beam_pipeline_args=[])

        logging.info('Starting the first pipeline run.')
        LocalDagRunner().run(penguin_pipeline)

        self.assertTrue(fileio.exists(self._serving_model_dir))
        self.assertTrue(fileio.exists(self._metadata_path))
        executions_per_run = 9  # 8 components + 1 resolver
        mlmd_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(mlmd_config) as mlmd:
            first_artifact_total = len(mlmd.store.get_artifacts())
            first_execution_total = len(mlmd.store.get_executions())
            self.assertGreaterEqual(first_artifact_total, first_execution_total)
            self.assertEqual(executions_per_run, first_execution_total)

        self._assertPipelineExecution(False)

        logging.info('Starting the second pipeline run. All components except '
                     'Evaluator and Pusher will use cached results.')
        LocalDagRunner().run(penguin_pipeline)

        with metadata.Metadata(mlmd_config) as mlmd:
            # Evaluator and Pusher re-execute and contribute 3 new artifacts.
            second_artifacts = mlmd.store.get_artifacts()
            self.assertLen(second_artifacts, first_artifact_total + 3)
            self.assertLen(mlmd.store.get_executions(), executions_per_run * 2)

        logging.info('Starting the third pipeline run. '
                     'All components will use cached results.')
        LocalDagRunner().run(penguin_pipeline)

        # Everything is cached now: no artifact growth, executions still grow.
        with metadata.Metadata(mlmd_config) as mlmd:
            self.assertLen(mlmd.store.get_artifacts(), len(second_artifacts))
            self.assertLen(mlmd.store.get_executions(), executions_per_run * 3)
Esempio n. 4
0
def create_pipeline(pipeline_name,
                    pipeline_root,
                    model_name,
                    text_path,
                    train_config,
                    mlflow_tracking_url,
                    encoding='utf-8',
                    combine=50000,
                    enable_cache=False):
    """Builds the GPT-2 fine-tuning TFX pipeline.

    Components: download the pretrained model, create the dataset, train,
    export for TF Serving, and import the results into MLflow.

    Args:
      pipeline_name: Name of the pipeline; also used in output/metadata paths.
      pipeline_root: Base directory. Outputs go under
        `<pipeline_root>/pipelines/<pipeline_name>`, and the MLMD SQLite DB
        goes to `<pipeline_root>/metadata/<pipeline_name>/metadata.db`.
      model_name: Pretrained model identifier passed to the components.
      text_path: Path to the training text.
      train_config: Training settings. They are copied into the module-level
        `default_train_config` (mutated in place) and also passed as-is to
        training/export.
      mlflow_tracking_url: MLflow tracking server URL.
      encoding: Text encoding of the input data.
      combine: Value forwarded to dataset creation and training.
      enable_cache: Whether TFX component caching is enabled.

    Returns:
      The assembled `pipeline.Pipeline`.
    """
    for key, value in train_config.items():
        default_train_config[key] = value

    pretrained_model = DownloadPretrainedModel(model_name=model_name)

    create_dataset = CreateDataset(
        text_path=text_path,
        model_path=pretrained_model.outputs["model_path"],
        encoding=encoding,
        combine=combine)

    train_gpt2 = TrainGPT2(dataset_path=create_dataset.outputs["dataset_path"],
                           model_path=pretrained_model.outputs["model_path"],
                           model_name=model_name,
                           train_config=train_config,
                           combine=combine,
                           encoding=encoding)
    export_tfserving = ExportToTFServing(
        model_path=pretrained_model.outputs["model_path"],
        checkpoint_dir=train_gpt2.outputs["checkpoint_dir"],
        train_config=train_config)
    mlflow_import = MLFlowImport(
        model_name=model_name,
        mlflow_tracking_url=mlflow_tracking_url,
        artifact_dir=train_gpt2.outputs["sample_dir"],
        hyperparameter_dir=train_gpt2.outputs["hyperparameter_dir"],
        metric_dir=train_gpt2.outputs["metric_dir"],
        model_dir=export_tfserving.outputs["export_dir"])

    # Derive both paths from the caller's base directory. Previously
    # pipeline_root was reassigned first, so the metadata DB ended up nested
    # inside the pipeline output directory.
    pipeline_path = os.path.join(pipeline_root, 'pipelines', pipeline_name)
    metadata_path = os.path.join(pipeline_root, 'metadata', pipeline_name,
                                 'metadata.db')

    tfx_pipeline = pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_path,
        components=[
            create_dataset, pretrained_model, train_gpt2, export_tfserving,
            mlflow_import
        ],
        enable_cache=enable_cache,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path))
    return tfx_pipeline
Esempio n. 5
0
 def setUp(self):
     """Builds a fake component, pipeline info, driver args, MLMD config, DAG.

     The SQLite metadata file lives under TEST_TMP_DIR when that is set,
     otherwise under the test's temp dir.
     """
     # Run the base test case's per-test initialization first; it was
     # previously skipped.
     super().setUp()
     self._component = _FakeComponent(
         _FakeComponentSpec(input=channel.Channel(type_name='type_a'),
                            output=channel.Channel(type_name='type_b')))
     self._pipeline_info = data_types.PipelineInfo('name', 'root')
     self._driver_args = data_types.DriverArgs(True)
     self._metadata_connection_config = metadata.sqlite_metadata_connection_config(
         os.path.join(os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
                      'metadata'))
     self._parent_dag = models.DAG(dag_id=self._pipeline_info.pipeline_name,
                                   start_date=datetime.datetime(2018, 1, 1),
                                   schedule_interval=None)
Esempio n. 6
0
def run_pipeline():
    """Builds the pipeline from the module-level settings and runs it on Beam.

    Caching is enabled. MLMD is stored in a SQLite DB at `METADATA_PATH`.
    """
    pipeline_kwargs = dict(
        pipeline_name=PIPELINE_NAME,
        pipeline_root=PIPELINE_ROOT,
        data_root=DATA_ROOT,
        test_data_root=TEST_DATA_ROOT,
        module_file=MODULE_FILE,
        serving_model_dir=SERVING_MODEL_DIR,
        enable_cache=True,
    )
    pipeline_kwargs['metadata_connection_config'] = (
        metadata.sqlite_metadata_connection_config(METADATA_PATH))
    BeamDagRunner().run(create_pipeline(**pipeline_kwargs))
Esempio n. 7
0
def run():
    """Runs the fishing-classifier pipeline with LocalDagRunner.

    Outputs, the exported model, and the MLMD SQLite database all go under
    the relative "outputs" directory, so paths resolve against the current
    working directory.
    """
    # Single source of truth for the output location. The metadata path was
    # previously an f-string with no placeholders.
    outputs_dir = "outputs"
    LocalDagRunner().run(
        pipeline.create_pipeline(
            pipeline_name="fishing-classifier",
            data_path="data",
            outputs_path=outputs_dir,
            output_model_path=f"{outputs_dir}/model",
            train_args=trainer_pb2.TrainArgs(num_steps=100),
            eval_args=trainer_pb2.EvalArgs(num_steps=15),
            eval_accuracy_threshold=0.6,
            metadata_connection_config=sqlite_metadata_connection_config(
                f"{outputs_dir}/metadata.db"),
        ))
Esempio n. 8
0
    def setUp(self):
        """Creates an MLMD connection, the test pipeline, nodes, and mocks.

        - Opens (lazily) a SQLite-backed MLMD connection under a per-test
          output directory.
        - Builds the pipeline via `self._make_pipeline` and extracts its nodes
          by id.
        - Creates an empty task queue.
        - Sets up an autospec'd `ServiceJobManager` mock that reports the
          example gen as a pure service node and the transform as a mixed
          service node.
        """
        super().setUp()
        # Per-test root: TEST_UNDECLARED_OUTPUTS_DIR when set, otherwise the
        # test temp dir, suffixed with the test id.
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())
        self._pipeline_root = pipeline_root

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        self._metadata_path = metadata_path
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        # Marks the `sqlite` sub-message as present on the proto.
        # NOTE(review): presumably redundant after the helper above — confirm.
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up the pipeline.
        # A random UUID string is passed as the second argument (presumably a
        # pipeline run id — confirm against _make_pipeline).
        pipeline = self._make_pipeline(self._pipeline_root, str(uuid.uuid4()))
        self._pipeline = pipeline

        # Extracts components.
        self._example_gen = test_utils.get_node(pipeline, 'my_example_gen')
        self._stats_gen = test_utils.get_node(pipeline, 'my_statistics_gen')
        self._schema_gen = test_utils.get_node(pipeline, 'my_schema_gen')
        self._transform = test_utils.get_node(pipeline, 'my_transform')
        self._example_validator = test_utils.get_node(pipeline,
                                                      'my_example_validator')
        self._trainer = test_utils.get_node(pipeline, 'my_trainer')
        self._evaluator = test_utils.get_node(pipeline, 'my_evaluator')
        self._chore_a = test_utils.get_node(pipeline, 'chore_a')
        self._chore_b = test_utils.get_node(pipeline, 'chore_b')

        self._task_queue = tq.TaskQueue()

        self._mock_service_job_manager = mock.create_autospec(
            service_jobs.ServiceJobManager, instance=True)

        # Only the example gen is a pure service node; only the transform is a
        # mixed service node. The lambdas read self.* at call time.
        self._mock_service_job_manager.is_pure_service_node.side_effect = (
            lambda _, node_id: node_id == self._example_gen.node_info.id)
        self._mock_service_job_manager.is_mixed_service_node.side_effect = (
            lambda _, node_id: node_id == self._transform.node_info.id)

        def _default_ensure_node_services(unused_pipeline_state, node_id):
            # Fails the test if services are requested for any node other
            # than the two service nodes above; otherwise reports SUCCESS.
            self.assertIn(
                node_id,
                (self._example_gen.node_info.id, self._transform.node_info.id))
            return service_jobs.ServiceStatus.SUCCESS

        self._mock_service_job_manager.ensure_node_services.side_effect = (
            _default_ensure_node_services)
Esempio n. 9
0
    def __init__(self,
                 pipeline_name: Optional[Text] = None,
                 pipeline_root: Optional[Text] = None,
                 metadata_connection_config: Optional[
                     metadata_store_pb2.ConnectionConfig] = None,
                 beam_pipeline_args: Optional[List[Text]] = None):
        """Initialize an InteractiveContext.

        Args:
          pipeline_name: Optional name of the pipeline for ML Metadata
            tracking. When falsy, "interactive-<timestamp>" is generated, with
            ':' replaced by '_' in the ISO timestamp.
          pipeline_root: Optional root path for pipeline outputs. When falsy,
            a fresh temporary directory prefixed "tfx-<pipeline_name>-" is
            created and a warning is logged.
          metadata_connection_config: Optional ML Metadata connection config.
            When falsy, a SQLite database named by
            `self._DEFAULT_SQLITE_FILENAME` inside `pipeline_root` is used and
            a warning is logged.
          beam_pipeline_args: Optional Beam pipeline args for beam jobs within
            the executor; defaults to an empty list.
        """
        if not pipeline_name:
            timestamp = datetime.datetime.now().isoformat()
            pipeline_name = 'interactive-%s' % timestamp.replace(':', '_')

        if not pipeline_root:
            pipeline_root = tempfile.mkdtemp(prefix='tfx-%s-' % pipeline_name)
            absl.logging.warning(
                'InteractiveContext pipeline_root argument not provided: using '
                'temporary directory %s as root for pipeline outputs.',
                pipeline_root)

        if not metadata_connection_config:
            # TODO(ccy): consider reconciling similar logic here with other instances
            # in tfx/orchestration/...
            sqlite_db_path = os.path.join(pipeline_root,
                                          self._DEFAULT_SQLITE_FILENAME)
            metadata_connection_config = (
                metadata.sqlite_metadata_connection_config(sqlite_db_path))
            absl.logging.warning(
                'InteractiveContext metadata_connection_config not provided: using '
                'SQLite ML Metadata database at %s.', sqlite_db_path)

        self.pipeline_name = pipeline_name
        self.pipeline_root = pipeline_root
        self.metadata_connection_config = metadata_connection_config
        self.beam_pipeline_args = beam_pipeline_args if beam_pipeline_args else []

        # IPython display formatters, then the standard artifact
        # visualizations.
        notebook_formatters.register_formatters()
        standard_visualizations.register_standard_visualizations()
Esempio n. 10
0
    def setUp(self):
        """Prepares MLMD, the launcher-test pipeline proto, and fake specs.

        - Loads `testdata/pipeline_for_launcher_test.pbtxt` and substitutes
          its run-id runtime parameter.
        - Stores the pipeline info and runtime spec, then extracts the first
          five nodes by position.
        - Builds fake executor/driver specs used by the launcher tests.
        """
        super(LauncherTest, self).setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        # Marks the `sqlite` sub-message as present on the proto.
        # NOTE(review): presumably redundant after the helper above — confirm.
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)
        self._testdata_dir = os.path.join(os.path.dirname(__file__),
                                          'testdata')

        # Sets up pipelines
        pipeline = pipeline_pb2.Pipeline()
        # NOTE(review): this path duplicates self._testdata_dir; could reuse it.
        self.load_proto_from_text(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'pipeline_for_launcher_test.pbtxt'), pipeline)
        # Substitute the runtime parameter to be a concrete run_id
        runtime_parameter_utils.substitute_runtime_parameter(
            pipeline, {
                constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'test_run',
            })
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        # The root and run id are then written directly into the runtime spec.
        # NOTE(review): 'test_run_0' here differs from the 'test_run' value
        # substituted above — confirm this is intentional.
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)
        self._pipeline_runtime_spec.pipeline_run_id.field_value.string_value = (
            'test_run_0')

        # Extracts components
        # Positions depend on node order in the .pbtxt test data.
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node
        self._importer = pipeline.nodes[3].pipeline_node
        self._resolver = pipeline.nodes[4].pipeline_node

        # Fakes an ExecutorSpec for Trainer
        self._trainer_executor_spec = _PYTHON_CLASS_EXECUTABLE_SPEC()
        # Fakes an executor operator
        self._test_executor_operators = {
            _PYTHON_CLASS_EXECUTABLE_SPEC: _FakeExecutorOperator
        }
        # Fakes a custom driver spec
        self._custom_driver_spec = _PYTHON_CLASS_EXECUTABLE_SPEC()
        self._custom_driver_spec.class_path = 'tfx.orchestration.portable.launcher_test._FakeExampleGenLikeDriver'
 def _get_latest_output_artifact(self, component_name, output_key):
   """Returns the newest output artifact of the matching component.

   The execution type is found by substring match of `component_name`
   against execution type names; exactly one type must match, otherwise
   tuple unpacking raises ValueError. Among the output artifacts of all
   executions of that type, the one with the largest
   `create_time_since_epoch` is returned. The test fails if there are none.

   Note: `output_key` is not used by this implementation.
   """
   connection_config = metadata.sqlite_metadata_connection_config(
       self._metadata_path)
   with metadata.Metadata(connection_config) as mlmd:
     candidate_types = [
         t.name for t in mlmd.store.get_execution_types()
         if component_name in t.name
     ]
     (type_name,) = candidate_types
     execution_ids = [
         execution.id
         for execution in mlmd.store.get_executions_by_type(type_name)
     ]
     produced_ids = [
         ev.artifact_id
         for ev in mlmd.store.get_events_by_execution_ids(execution_ids)
         if ev.type in _OUTPUT_EVENT_TYPES
     ]
     produced = mlmd.store.get_artifacts_by_id(produced_ids)
     self.assertNotEmpty(produced)
     return max(produced, key=lambda art: art.create_time_since_epoch)
Esempio n. 12
0
def create_pipeline(pipeline_name, pipeline_root, checkpoint_dir, model_name, text_dir, train_config, mlflow_tracking_url,
                    encoding='utf-8', text_token_size=50000, enable_cache=False, end_token="<|endoftext|>"):
    """Builds the GPT-2 fine-tuning pipeline from a directory of text files.

    Stages: merge text -> copy checkpoint -> encode dataset -> train ->
    export for TF Serving -> import into MLflow.

    Args:
      pipeline_name: Name of the pipeline; also used in output/metadata paths.
      pipeline_root: Base directory. Outputs go to
        `<pipeline_root>/pipelines/<pipeline_name>`, and MLMD to
        `<pipeline_root>/metadata/<pipeline_name>/metadata.db`.
      checkpoint_dir: Source checkpoint copied into the pipeline.
      model_name: Model identifier forwarded to training and MLflow.
      text_dir: Directory of input text to merge.
      train_config: Training settings. They are copied into the module-level
        `default_train_config` (in place) and passed to training/export.
      mlflow_tracking_url: MLflow tracking server URL.
      encoding: Text encoding.
      text_token_size: Not used by this function body.
      enable_cache: Whether TFX component caching is enabled.
      end_token: Document separator inserted when merging/encoding text.

    Returns:
      The assembled `pipeline.Pipeline`.
    """
    for cfg_key, cfg_value in train_config.items():
        default_train_config[cfg_key] = cfg_value

    merged_text = CreateMergedText(text_dir=text_dir,
                                   end_token=end_token,
                                   encoding=encoding)
    checkpoint = CopyCheckpoint(checkpoint_dir=checkpoint_dir)
    model_dir = checkpoint.outputs["model_dir"]

    encoded_dataset = CreateEncodedDataset(
        merged_text_dir=merged_text.outputs["merged_text_dir"],
        encoding_dir=model_dir,
        encoding=encoding,
        end_token=end_token)

    trainer = TrainGPT2(dataset_dir=encoded_dataset.outputs["dataset_dir"],
                        checkpoint_dir=model_dir,
                        encoding_dir=model_dir,
                        model_name=model_name,
                        train_config=train_config,
                        encoding=encoding)

    exporter = ExportToTFServing(
        encoding_dir=model_dir,
        checkpoint_dir=trainer.outputs["trained_checkpoint_dir"],
        train_config=train_config)

    mlflow_importer = MLFlowImport(
        model_name=model_name,
        mlflow_tracking_url=mlflow_tracking_url,
        artifact_dir=trainer.outputs["sample_dir"],
        hyperparameter_dir=trainer.outputs["hyperparameter_dir"],
        metric_dir=trainer.outputs["metric_dir"],
        model_dir=exporter.outputs["export_dir"])

    outputs_root = os.path.join(pipeline_root, 'pipelines', pipeline_name)
    sqlite_path = os.path.join(pipeline_root, 'metadata', pipeline_name,
                               'metadata.db')
    stages = [merged_text, checkpoint, encoded_dataset, trainer, exporter,
              mlflow_importer]

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=outputs_root,
        components=stages,
        enable_cache=enable_cache,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            sqlite_path))
Esempio n. 13
0
def create_dag(name, url, output_dir, airflow_config):
    """Builds an uncached single-crawler pipeline and returns its Airflow DAG.

    Args:
      name: Pipeline name.
      url: URL passed to the NewsCrawler component.
      output_dir: Base directory. Outputs go to
        `<output_dir>/pipelines/<name>`, and MLMD to
        `<output_dir>/metadata/<name>/metadata.db`.
      airflow_config: Passed to `AirflowPipelineConfig`.

    Returns:
      Whatever `AirflowDagRunner.run` returns for the pipeline.
    """
    outputs_root = os.path.join(output_dir, 'pipelines', name)
    sqlite_path = os.path.join(output_dir, 'metadata', name, 'metadata.db')

    news_pipeline = pipeline.Pipeline(
        pipeline_name=name,
        pipeline_root=outputs_root,
        components=[NewsCrawler(url=url)],
        enable_cache=False,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            sqlite_path))

    dag_runner = AirflowDagRunner(AirflowPipelineConfig(airflow_config))
    return dag_runner.run(news_pipeline)
Esempio n. 14
0
  def setUp(self):
    """Creates a per-test SQLite MLMD connection (not yet opened).

    The database lives at
    `<TEST_UNDECLARED_OUTPUTS_DIR or temp dir>/<test id>/metadata/metadata.db`.
    """
    super(PipelineOpsTest, self).setUp()
    base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                              self.get_temp_dir())
    # A fixed per-test path keeps every connection opened within one test
    # pointed at the same MLMD database.
    self._metadata_path = os.path.join(base_dir, self.id(), 'metadata',
                                       'metadata.db')
    mlmd_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    mlmd_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(connection_config=mlmd_config)
Esempio n. 15
0
def init_pipeline(components, pipeline_root: Text,
                  direct_num_workers: int) -> pipeline.Pipeline:
    """Builds a cached TFX pipeline with a sized Beam DirectRunner.

    NOTE(review): `pipeline_name` and `metadata_path` are not parameters;
    they must be module-level names defined elsewhere in this file.

    Args:
      components: Components to include in the pipeline.
      pipeline_root: Root directory for pipeline outputs.
      direct_num_workers: Value for Beam's `--direct_num_workers` flag.

    Returns:
      The configured `pipeline.Pipeline`.
    """
    worker_flag = '--direct_num_workers={}'.format(direct_num_workers)
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=[worker_flag])
    def testMNISTPipelineNativeKeras(self):
        """Runs the native-Keras MNIST pipeline twice on Beam.

        - Run 1 must record 11 executions.
        - Run 2 must add no artifacts but another 11 executions.
        - The test is skipped unless TF runs eagerly (TF2).
        """
        if not tf.executing_eagerly():
            self.skipTest('The test requires TF2.')

        def _build_pipeline():
            # Each run gets a freshly constructed pipeline object.
            return mnist_pipeline_native_keras._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                module_file_lite=self._module_file_lite,
                serving_model_dir=self._serving_model_dir,
                serving_model_dir_lite=self._serving_model_dir_lite,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[])

        BeamDagRunner().run(_build_pipeline())

        for expected_path in (self._serving_model_dir,
                              self._serving_model_dir_lite,
                              self._metadata_path):
            self.assertTrue(fileio.exists(expected_path))
        mlmd_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        executions_per_run = 11
        with metadata.Metadata(mlmd_config) as mlmd:
            first_artifact_total = len(mlmd.store.get_artifacts())
            first_execution_total = len(mlmd.store.get_executions())
            self.assertGreaterEqual(first_artifact_total, first_execution_total)
            self.assertEqual(first_execution_total, executions_per_run)

        self.assertPipelineExecution()

        # Second run: everything should come from cache.
        BeamDagRunner().run(_build_pipeline())

        with metadata.Metadata(mlmd_config) as mlmd:
            self.assertLen(mlmd.store.get_artifacts(), first_artifact_total)
            self.assertLen(mlmd.store.get_executions(), executions_per_run * 2)
Esempio n. 17
0
  def setUp(self):
    """Creates an MLMD connection, the sync test pipeline, nodes, and mocks.

    - Loads `testdata/sync_pipeline.pbtxt` and substitutes `pipeline_root`
      and a random `pipeline_run_id` into its runtime parameters.
    - Extracts the nodes by id and creates an empty task queue.
    - Sets up an autospec'd `ServiceJobManager` mock that reports the example
      gen as a pure service node and the transform as a mixed service node.
    """
    super(SyncPipelineTaskGeneratorTest, self).setUp()
    # Per-test root: TEST_UNDECLARED_OUTPUTS_DIR when set, otherwise the test
    # temp dir, suffixed with the test id.
    pipeline_root = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self.id())
    self._pipeline_root = pipeline_root

    # Makes sure multiple connections within a test always connect to the same
    # MLMD instance.
    metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
    self._metadata_path = metadata_path
    connection_config = metadata.sqlite_metadata_connection_config(
        metadata_path)
    # Marks the `sqlite` sub-message as present on the proto.
    # NOTE(review): presumably redundant after the helper above — confirm.
    connection_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(
        connection_config=connection_config)

    # Sets up the pipeline.
    pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(
        os.path.join(
            os.path.dirname(__file__), 'testdata', 'sync_pipeline.pbtxt'),
        pipeline)
    # A fresh UUID per test, so run ids do not collide between tests.
    self._pipeline_run_id = str(uuid.uuid4())
    runtime_parameter_utils.substitute_runtime_parameter(
        pipeline, {
            'pipeline_root': pipeline_root,
            'pipeline_run_id': self._pipeline_run_id
        })
    self._pipeline = pipeline

    # Extracts components.
    self._example_gen = _get_node(pipeline, 'my_example_gen')
    self._stats_gen = _get_node(pipeline, 'my_statistics_gen')
    self._schema_gen = _get_node(pipeline, 'my_schema_gen')
    self._transform = _get_node(pipeline, 'my_transform')
    self._example_validator = _get_node(pipeline, 'my_example_validator')
    self._trainer = _get_node(pipeline, 'my_trainer')

    self._task_queue = tq.TaskQueue()

    self._mock_service_job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)

    # Only the example gen is a pure service node; only the transform is a
    # mixed service node. The lambdas read self.* at call time.
    self._mock_service_job_manager.is_pure_service_node.side_effect = (
        lambda _, node_id: node_id == self._example_gen.node_info.id)
    self._mock_service_job_manager.is_mixed_service_node.side_effect = (
        lambda _, node_id: node_id == self._transform.node_info.id)
Esempio n. 18
0
    def assertComponentExecutionCount(self, count: int) -> None:
        """Asserts how many executions MLMD has recorded.

        The metadata file at `self.metadata_path` must exist. It must hold
        exactly `count` executions and at least as many artifacts as
        executions.

        Args:
          count: Expected number of executions recorded in MLMD.
        """
        self.assertTrue(tf.io.gfile.exists(self.metadata_path))
        mlmd_config = metadata.sqlite_metadata_connection_config(
            self.metadata_path)
        with metadata.Metadata(mlmd_config) as mlmd:
            num_artifacts = len(mlmd.store.get_artifacts())
            num_executions = len(mlmd.store.get_executions())
        self.assertGreaterEqual(num_artifacts, num_executions)
        self.assertEqual(count, num_executions)
Esempio n. 19
0
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     metadata_path: Text) -> pipeline.Pipeline:
    """Builds a cached pipeline containing only a CsvExampleGen.

    Args:
      pipeline_name: Name of the pipeline.
      pipeline_root: Root directory for pipeline outputs.
      data_root: Directory of CSV input read by CsvExampleGen.
      metadata_path: Path of the SQLite MLMD database.

    Returns:
      A `pipeline.Pipeline` with caching enabled and no additional args.
    """
    # Ingest the raw CSV data; this is the pipeline's only component.
    ingest = CsvExampleGen(input_base=data_root)
    mlmd_config = metadata.sqlite_metadata_connection_config(metadata_path)
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[ingest],
        enable_cache=True,
        metadata_connection_config=mlmd_config,
        additional_pipeline_args={},
    )
Esempio n. 20
0
 def setUp(self):
   """Builds a cached fake component, pipeline info, MLMD config, and DAG.

   The SQLite metadata file lives under TEST_TMP_DIR when set, otherwise
   under the test's temp dir.
   """
   super(AirflowComponentTest, self).setUp()
   fake_spec = _FakeComponentSpec(
       input=types.Channel(type=_ArtifactTypeA),
       output=types.Channel(type=_ArtifactTypeB))
   self._component = _FakeComponent(fake_spec, enable_cache=True)
   self._pipeline_info = data_types.PipelineInfo('name', 'root')
   self._driver_args = data_types.DriverArgs(True)
   scratch_dir = os.environ.get('TEST_TMP_DIR', self.get_temp_dir())
   self._metadata_connection_config = (
       metadata.sqlite_metadata_connection_config(
           os.path.join(scratch_dir, 'metadata')))
   self._parent_dag = models.DAG(
       dag_id=self._pipeline_info.pipeline_name,
       start_date=datetime.datetime(2018, 1, 1),
       schedule_interval=None)
Esempio n. 21
0
def create_dag(name,
               url,
               airflow_config,
               backup_dir="pipelines_backup",
               mongo_ip=None,
               mongo_port=None,
               dag_type="default",
               output_dir="/output",
               updated_collections=None,
               update_collections=None):
    """Builds an uncached news pipeline and returns its Airflow DAG.

    Args:
      name: DAG file name; a trailing ".py" is stripped to form the pipeline
        name.
      url: Feed URL for the crawler ("default" type only).
      airflow_config: Airflow settings. A copy is made with `catchup=False`;
        the caller's dict is not modified.
      backup_dir: Backup sub-directory used by the "backup" type.
      mongo_ip: MongoDB host.
      mongo_port: MongoDB port.
      dag_type: "default" (crawl and import into Mongo), "backup" (import
        old news from a backup dir), or "update" (update Mongo collections).
      output_dir: Base directory. Outputs go to
        `<output_dir>/pipelines/<pipeline_name>`, and MLMD to
        `<output_dir>/metadata/<pipeline_name>/metadata.db`.
      updated_collections: Passed to UpdateMongoNews; defaults to [].
      update_collections: Passed to UpdateMongoNews; defaults to [].

    Returns:
      Whatever `AirflowDagRunner.run` returns for the pipeline.

    Raises:
      ValueError: If `dag_type` is not one of the supported values.
    """
    # Fresh lists per call, instead of shared mutable defaults.
    if updated_collections is None:
        updated_collections = []
    if update_collections is None:
        update_collections = []

    # Strip only a trailing ".py"; str.replace would also remove inner
    # occurrences (e.g. "a.pyx.py" -> "ax").
    pipeline_name = name[:-len(".py")] if name.endswith(".py") else name
    pipeline_root = os.path.join(output_dir, 'pipelines', pipeline_name)
    metadata_path = os.path.join(output_dir, 'metadata', pipeline_name,
                                 'metadata.db')

    if dag_type == "default":
        crawler = NewsCrawler(url=url)
        mongo = MongoImport(ip=mongo_ip,
                            port=mongo_port,
                            rss_feed=crawler.outputs["rss_feed"],
                            colname=pipeline_name)
        components = [crawler, mongo]
    elif dag_type == "backup":
        # NOTE(review): hard-coded "/output" rather than output_dir; the two
        # only agree under the default output_dir — confirm intent.
        load_news = OldNewsImport(backup_dir=os.path.join(
            "/output", backup_dir),
                                  ip=mongo_ip,
                                  port=mongo_port)
        components = [load_news]
    elif dag_type == "update":
        update_news = UpdateMongoNews(ip=mongo_ip,
                                      port=mongo_port,
                                      updated_collections=updated_collections,
                                      update_collections=update_collections)
        components = [update_news]
    else:
        # Previously an unknown type silently produced a DAG with no tasks.
        raise ValueError(
            "Unsupported dag_type %r; expected 'default', 'backup' or "
            "'update'." % (dag_type,))

    # Copy rather than mutate the caller's configuration.
    airflow_config = dict(airflow_config, catchup=False)
    tfx_pipeline = pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=False,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path))

    return AirflowDagRunner(
        AirflowPipelineConfig(airflow_config)).run(tfx_pipeline)
def _create_pipeline(
    pipeline_name,
    pipeline_root,
    metadata_path,
    name,
):
    """Builds a cached pipeline with a single `_HelloWorldComponent`.

    Args:
      pipeline_name: Name of the pipeline.
      pipeline_root: Root directory for pipeline outputs.
      metadata_path: Path of the SQLite MLMD database.
      name: Value passed as `name` to `_HelloWorldComponent`.

    Returns:
      A `pipeline.Pipeline` with caching enabled and no additional args.
    """
    only_component = _HelloWorldComponent(name=name)
    mlmd_config = metadata.sqlite_metadata_connection_config(metadata_path)
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[only_component],
        enable_cache=True,
        metadata_connection_config=mlmd_config,
        additional_pipeline_args={},
    )
Esempio n. 23
0
    def testCIFAR10PipelineNativeKeras(self):
        """Runs the CIFAR-10 native-Keras pipeline three times on Beam.

        - Run 1 must record 9 executions.
        - Run 2 must add exactly 3 artifacts (Evaluator and Pusher re-run).
        - Run 3 must add none.
        - Each run adds another 9 executions.
        """
        cifar_pipeline = cifar10_pipeline_native_keras._create_pipeline(
            pipeline_name=self._pipeline_name,
            data_root=self._data_root,
            module_file=self._module_file,
            serving_model_dir_lite=self._serving_model_dir_lite,
            pipeline_root=self._pipeline_root,
            metadata_path=self._metadata_path,
            labels_path=self._labels_path,
            beam_pipeline_args=[])

        BeamDagRunner().run(cifar_pipeline)

        self.assertTrue(fileio.exists(self._serving_model_dir_lite))
        self.assertTrue(fileio.exists(self._metadata_path))
        executions_per_run = 9  # 8 components + 1 resolver
        mlmd_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(mlmd_config) as mlmd:
            first_artifact_total = len(mlmd.store.get_artifacts())
            first_execution_total = len(mlmd.store.get_executions())
            self.assertGreaterEqual(first_artifact_total, first_execution_total)
            self.assertEqual(executions_per_run, first_execution_total)

        self.assertPipelineExecution()

        # Second run: only Evaluator and Pusher execute.
        BeamDagRunner().run(cifar_pipeline)

        with metadata.Metadata(mlmd_config) as mlmd:
            second_artifact_total = len(mlmd.store.get_artifacts())
            self.assertEqual(first_artifact_total + 3, second_artifact_total)
            self.assertEqual(executions_per_run * 2,
                             len(mlmd.store.get_executions()))

        # Third run: fully cached.
        BeamDagRunner().run(cifar_pipeline)

        with metadata.Metadata(mlmd_config) as mlmd:
            self.assertEqual(second_artifact_total,
                             len(mlmd.store.get_artifacts()))
            self.assertEqual(executions_per_run * 3,
                             len(mlmd.store.get_executions()))
def init_pipeline(components, pipeline_root: Text,
                  direct_num_workers: int) -> pipeline.Pipeline:
    """Builds a cached TFX pipeline that runs Beam in multi-processing mode.

    The pipeline name and the SQLite metadata path come from the ``config``
    module; the output root comes from the ``pipeline_root`` argument.

    Args:
        components: TFX components that make up the pipeline.
        pipeline_root: Root directory for the pipeline's outputs.
        direct_num_workers: Value forwarded as Beam's ``--direct_num_workers``.

    Returns:
        The configured ``pipeline.Pipeline`` with caching enabled.
    """
    beam_arg = [
        f"--direct_num_workers={direct_num_workers}",
        "--direct_running_mode=multi_processing",
    ]
    tfx_pipeline = pipeline.Pipeline(
        pipeline_name=config.PIPELINE_NAME,
        # Use the caller-supplied root (not config.PIPELINE_ROOT) so the
        # pipeline_root parameter actually takes effect.
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            config.METADATA_PATH),
        beam_pipeline_args=beam_arg,
    )
    return tfx_pipeline
  def setUp(self):
    super(SyncPipelineTaskGeneratorTest, self).setUp()
    # Every test gets its own root directory, keyed by the test id.
    outputs_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                 self.get_temp_dir())
    root_dir = os.path.join(outputs_dir, self.id())
    self._pipeline_root = root_dir

    # A single on-disk SQLite file, so that every connection opened during
    # the test reaches the same MLMD instance.
    db_path = os.path.join(root_dir, 'metadata', 'metadata.db')
    self._metadata_path = db_path
    mlmd_config = metadata.sqlite_metadata_connection_config(db_path)
    mlmd_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(connection_config=mlmd_config)

    # Parses the test pipeline IR from the checked-in text proto.
    pipeline_ir = pipeline_pb2.Pipeline()
    testdata_file = os.path.join(
        os.path.dirname(__file__), 'testdata', 'sync_pipeline.pbtxt')
    self.load_proto_from_text(testdata_file, pipeline_ir)
    self._pipeline = pipeline_ir
    self._pipeline_info = pipeline_ir.pipeline_info

    # runtime_spec is a sub-message of pipeline_ir, so writes through this
    # alias update the loaded pipeline in place.
    runtime_spec = pipeline_ir.runtime_spec
    self._pipeline_runtime_spec = runtime_spec
    runtime_spec.pipeline_root.field_value.string_value = root_dir
    runtime_spec.pipeline_run_id.field_value.string_value = 'run_0'

    # The first three pipeline nodes are kept as example_gen, transform and
    # trainer respectively.
    ir_nodes = pipeline_ir.nodes
    self._example_gen = ir_nodes[0].pipeline_node
    self._transform = ir_nodes[1].pipeline_node
    self._trainer = ir_nodes[2].pipeline_node

    self._task_queue = tq.TaskQueue()

    service_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)
    self._mock_service_job_manager = service_manager
    # Only the example_gen node is reported as a pure service node.
    service_manager.is_pure_service_node.side_effect = (
        lambda unused_pipeline_state, node_id:
        node_id == self._example_gen.node_info.id)
Esempio n. 26
0
    def testIrisPipelineNativeKeras(self):
        """Runs the iris pipeline three times and checks MLMD bookkeeping."""
        iris_pipeline = iris_pipeline_native_keras._create_pipeline(
            pipeline_name=self._pipeline_name,
            data_root=self._data_root,
            module_file=self._module_file,
            serving_model_dir=self._serving_model_dir,
            pipeline_root=self._pipeline_root,
            metadata_path=self._metadata_path,
            enable_tuning=False,
            direct_num_workers=1)

        # First run.
        BeamDagRunner().run(iris_pipeline)

        self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
        self.assertTrue(tf.io.gfile.exists(self._metadata_path))
        executions_per_run = 9  # 8 components + 1 resolver
        mlmd_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(mlmd_config) as mlmd:
            num_artifacts = len(mlmd.store.get_artifacts())
            num_executions = len(mlmd.store.get_executions())
            self.assertGreaterEqual(num_artifacts, num_executions)
            self.assertEqual(executions_per_run, num_executions)

        self.assertPipelineExecution(False)

        # Second run: all executions except Evaluator and Pusher are cached,
        # which adds 3 artifacts. Third run: everything is cached, so the
        # artifact count is unchanged.
        for run_number, new_artifacts in ((2, 3), (3, 0)):
            BeamDagRunner().run(iris_pipeline)
            with metadata.Metadata(mlmd_config) as mlmd:
                current_artifacts = len(mlmd.store.get_artifacts())
                self.assertEqual(num_artifacts + new_artifacts,
                                 current_artifacts)
                num_artifacts = current_artifacts
                self.assertEqual(executions_per_run * run_number,
                                 len(mlmd.store.get_executions()))
def init_beam_pipeline(components, pipeline_root: Text,
                       direct_num_workers: int) -> pipeline.Pipeline:
    """Builds a non-cached TFX pipeline with Beam DirectRunner arguments.

    NOTE(review): ``pipeline_name``, ``metadata_path`` and ``requirement_file``
    are not parameters of this function; they are presumably module-level
    globals defined elsewhere in this file -- confirm.

    Args:
        components: TFX components that make up the pipeline.
        pipeline_root: Root directory for the pipeline's outputs.
        direct_num_workers: Value forwarded as Beam's ``--direct_num_workers``.

    Returns:
        The configured ``pipeline.Pipeline`` with caching disabled.
    """
    absl.logging.info('Pipeline root set to: %s', pipeline_root)
    beam_arg = [f'--direct_num_workers={direct_num_workers}']
    # The requirements file is optional: only forward the flag when one is
    # configured, rather than handing Beam "--requirements_file=None" or an
    # empty path to stage.
    if requirement_file:
        beam_arg.append(f'--requirements_file={requirement_file}')

    p = pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=False,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_arg)
    return p
def init_pipeline(components, pipeline_root: Text,
                  direct_num_workers: int) -> pipeline.Pipeline:
    """Builds a cached TFX pipeline that runs Beam in multi-processing mode.

    NOTE(review): ``pipeline_name`` and ``metadata_path`` are not parameters;
    they are presumably module-level globals defined elsewhere -- confirm.

    Args:
        components: TFX components that make up the pipeline.
        pipeline_root: Root directory for the pipeline's outputs.
        direct_num_workers: Value forwarded as Beam's ``--direct_num_workers``.

    Returns:
        The configured ``pipeline.Pipeline`` with caching enabled.
    """
    # A list (not a tuple) to match Pipeline's declared
    # ``beam_pipeline_args: Optional[List[Text]]`` and any downstream code
    # that extends or concatenates it with other lists.
    beam_arg = [
        f"--direct_num_workers={direct_num_workers}",
        "--direct_running_mode=multi_processing",
    ]

    p = pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_arg,
    )
    return p
Esempio n. 29
0
    def testDockerComponentLauncherInBeam(self):
        """Runs the e2e pipeline with only the Docker launcher enabled."""
        # The runner supports only DockerComponentLauncher and applies a
        # default DockerComponentConfig to components.
        runner_config = pipeline_config.PipelineConfig(
            supported_launcher_classes=[
                docker_component_launcher.DockerComponentLauncher
            ],
            default_component_configs=[
                docker_component_config.DockerComponentConfig()
            ])
        runner = beam_dag_runner.BeamDagRunner(config=runner_config)
        e2e_pipeline = _create_pipeline(pipeline_name=self._pipeline_name,
                                        pipeline_root=self._pipeline_root,
                                        metadata_path=self._metadata_path,
                                        name='docker_e2e_test_in_beam')
        runner.run(e2e_pipeline)

        # Exactly one execution should have been recorded in MLMD.
        mlmd_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(mlmd_config) as mlmd:
            recorded = mlmd.store.get_executions()
            self.assertEqual(1, len(recorded))
Esempio n. 30
0
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     metadata_path: Text) -> pipeline.Pipeline:
    """Builds a minimal Chicago taxi TFX pipeline: ingestion plus statistics.

    Args:
        pipeline_name: Name given to the pipeline.
        pipeline_root: Root directory for the pipeline's outputs.
        data_root: Input location handed to ``external_input`` and then to
            CsvExampleGen.
        metadata_path: Path of the SQLite ML Metadata database.

    Returns:
        A cached ``pipeline.Pipeline`` containing CsvExampleGen followed by
        StatisticsGen.
    """
    # Ingest the raw input into the pipeline as examples.
    example_gen = CsvExampleGen(input_base=external_input(data_root))

    # Compute statistics over the ingested examples.
    statistics_gen = StatisticsGen(input_data=example_gen.outputs['examples'])

    mlmd_config = metadata.sqlite_metadata_connection_config(metadata_path)
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[example_gen, statistics_gen],
        enable_cache=True,
        metadata_connection_config=mlmd_config)