def testPartiallySubstituteRuntimeParameter(self):
  pipeline = pipeline_pb2.Pipeline()
  expected = pipeline_pb2.Pipeline()
  self.load_proto_from_text('pipeline_with_runtime_parameter.pbtxt',
                            pipeline)
  self.load_proto_from_text(
      'pipeline_with_runtime_parameter_partially_substituted.pbtxt',
      expected)
  runtime_parameter_utils.substitute_runtime_parameter(
      pipeline, {
          'context_name_rp': 'my_context',
      })
  self.assertProtoEquals(pipeline, expected)
Example #2
  def setUp(self):
    super(BaseComponentWithPipelineParamTest, self).setUp()

    test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
    example_gen_buckets = data_types.RuntimeParameter(
        name='example-gen-buckets', ptype=int, default=10)

    examples = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input=channel_utils.as_channel([examples]),
        output_config={
            'split_config': {
                'splits': [{
                    'name': 'examples',
                    'hash_buckets': example_gen_buckets
                }]
            }
        })
    statistics_gen = statistics_gen_component.StatisticsGen(
        examples=example_gen.outputs['examples'], instance_name='foo')

    pipeline = tfx_pipeline.Pipeline(
        pipeline_name=self._test_pipeline_name,
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    self._tfx_ir = pipeline_pb2.Pipeline()
    with dsl.Pipeline('test_pipeline'):
      self.example_gen = base_component.BaseComponent(
          component=example_gen,
          component_launcher_class=in_process_component_launcher
          .InProcessComponentLauncher,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
          tfx_ir=self._tfx_ir)
      self.statistics_gen = base_component.BaseComponent(
          component=statistics_gen,
          component_launcher_class=in_process_component_launcher
          .InProcessComponentLauncher,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
          tfx_ir=self._tfx_ir
      )

    self.tfx_example_gen = example_gen
    self.tfx_statistics_gen = statistics_gen
Example #3
    def setUp(self):
        super(LauncherTest, self).setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        # Make sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Setup pipelines
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text('pipeline_for_launcher_test.pbtxt', pipeline)
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)
        self._pipeline_runtime_spec.pipeline_run_id.field_value.string_value = (
            'test_run_0')

        # Extract components
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node

        # Fake an executor operator
        self._test_executor_operators = {
            pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec:
            FakeExecutorOperator
        }
Example #4
 def testReplaceExecutorWithStub(self):
     pipeline = text_format.Parse(
         """
     deployment_config {
       [type.googleapis.com/tfx.orchestration.IntermediateDeploymentConfig] {
         executor_specs {
           key: "CsvExampleGen"
           value {
             [type.googleapis.com/tfx.orchestration.executable_spec.PythonClassExecutableSpec] {
               class_path: "tfx.components.example_gen.csv_example_gen.executor.Executor"
               extra_flags: "--my_testing_beam_pipeline_args=foo"
             }
           }
         }
       }
     }""", pipeline_pb2.Pipeline())
     expected = """
     deployment_config {
       [type.googleapis.com/tfx.orchestration.IntermediateDeploymentConfig] {
         executor_specs {
           key: "CsvExampleGen"
           value {
             [type.googleapis.com/tfx.orchestration.executable_spec.PythonClassExecutableSpec] {
               class_path: "tfx.experimental.pipeline_testing.base_stub_executor.BaseStubExecutor"
               extra_flags: "--test_data_dir=/dummy/a"
               extra_flags: "--component_id=CsvExampleGen"
             }
           }
         }
       }
     }"""
     pipeline_mock.replace_executor_with_stub(pipeline, '/dummy/a', [])
     self.assertProtoEquals(expected, pipeline)
Example #5
    def setUp(self):
        super().setUp()
        example_gen = csv_example_gen_component.CsvExampleGen(
            input_base='data_input')
        statistics_gen = statistics_gen_component.StatisticsGen(
            examples=example_gen.outputs['examples']).with_id('foo')

        pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[example_gen, statistics_gen],
        )

        test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')

        self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
        self._tfx_ir = pipeline_pb2.Pipeline()
        with dsl.Pipeline('test_pipeline'):
            self.component = base_component.BaseComponent(
                component=statistics_gen,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                tfx_ir=self._tfx_ir,
                pod_labels_to_attach={},
                runtime_parameters=[])
        self.tfx_component = statistics_gen
Example #6
 def setUp(self):
     super(BeamDagRunnerTest, self).setUp()
     # Setup pipelines
     self._pipeline = pipeline_pb2.Pipeline()
     self.load_proto_from_text(
         os.path.join(os.path.dirname(__file__), 'testdata',
                      'pipeline_for_launcher_test.pbtxt'), self._pipeline)
Example #7
    def test_stop_node_no_active_executions(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'async_pipeline.pbtxt'), pipeline)
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        node_uid = task_lib.NodeUid(node_id='my_trainer',
                                    pipeline_uid=pipeline_uid)
        with self._mlmd_connection as m:
            pstate.PipelineState.new(m, pipeline)
            pipeline_ops.stop_node(m, node_uid)
            pipeline_state = pstate.PipelineState.load(m, pipeline_uid)

            # The node should be stop-initiated even when the node is inactive, to
            # prevent future triggers.
            with pipeline_state:
                self.assertEqual(
                    status_lib.Code.CANCELLED,
                    pipeline_state.node_stop_initiated_reason(node_uid).code)

            # Restart node.
            pipeline_state = pipeline_ops.initiate_node_start(m, node_uid)
            with pipeline_state:
                self.assertIsNone(
                    pipeline_state.node_stop_initiated_reason(node_uid))
Example #8
  def test_stop_node_wait_for_inactivation(self):
    pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(
        os.path.join(
            os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
        pipeline)
    trainer = pipeline.nodes[2].pipeline_node
    test_utils.fake_component_output(
        self._mlmd_connection, trainer, active=True)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)
    with self._mlmd_connection as m:
      pstate.PipelineState.new(m, pipeline).commit()

      def _inactivate(execution):
        time.sleep(2.0)
        with pipeline_ops._PIPELINE_OPS_LOCK:
          execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
          m.store.put_executions([execution])

      execution = task_gen_utils.get_executions(m, trainer)[0]
      thread = threading.Thread(
          target=_inactivate, args=(copy.deepcopy(execution),))
      thread.start()
      pipeline_ops.stop_node(m, node_uid, timeout_secs=5.0)
      thread.join()

      pipeline_state = pstate.PipelineState.load(m, pipeline_uid)
      self.assertEqual(status_lib.Code.CANCELLED,
                       pipeline_state.node_stop_initiated_reason(node_uid).code)

      # Restart node.
      pipeline_state = pipeline_ops.initiate_node_start(m, node_uid)
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
Example #9
def _test_pipeline(pipeline_id,
                   execution_mode: pipeline_pb2.Pipeline.ExecutionMode = (
                       pipeline_pb2.Pipeline.ASYNC)):
    pipeline = pipeline_pb2.Pipeline()
    pipeline.pipeline_info.id = pipeline_id
    pipeline.execution_mode = execution_mode
    return pipeline
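A minimal usage sketch for the helper above; the pipeline id and execution mode below are illustrative values, not taken from the original test:

# Sketch only: exercising the _test_pipeline helper defined above.
from tfx.proto.orchestration import pipeline_pb2  # standard TFX proto import

sync_pipeline = _test_pipeline('my_pipeline', pipeline_pb2.Pipeline.SYNC)
assert sync_pipeline.pipeline_info.id == 'my_pipeline'
assert sync_pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC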
Example #10
    def compile(self,
                tfx_pipeline: pipeline.Pipeline) -> pipeline_pb2.Pipeline:
        """Compiles a tfx pipeline into uDSL proto.

    Args:
      tfx_pipeline: A TFX pipeline.

    Returns:
      A Pipeline proto that encodes all necessary information of the pipeline.
    """
        context = _CompilerContext(tfx_pipeline.pipeline_info)
        pipeline_pb = pipeline_pb2.Pipeline()
        pipeline_pb.pipeline_info.id = context.pipeline_info.pipeline_name
        compiler_utils.set_runtime_parameter_pb(
            pipeline_pb.runtime_spec.pipeline_root.runtime_parameter,
            constants.PIPELINE_ROOT_PARAMETER_NAME, str,
            context.pipeline_info.pipeline_root)
        compiler_utils.set_runtime_parameter_pb(
            pipeline_pb.runtime_spec.pipeline_run_id.runtime_parameter,
            constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

        assert compiler_utils.ensure_topological_order(
            tfx_pipeline.components), (
                "Pipeline components are not topologically sorted.")
        for node in tfx_pipeline.components:
            component_pb = self._compile_node(node, context)
            pipeline_or_node = pipeline_pb.PipelineOrNode()
            pipeline_or_node.pipeline_node.CopyFrom(component_pb)
            # TODO(b/158713812): Support sub-pipeline.
            pipeline_pb.nodes.append(pipeline_or_node)
            context.component_pbs[node.id] = component_pb

        # Currently only synchronous mode is supported
        pipeline_pb.execution_mode = pipeline_pb2.Pipeline.ExecutionMode.SYNC
        return pipeline_pb
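A minimal invocation sketch for the compile() method above; the Compiler class name and my_tfx_pipeline are assumptions for illustration, not taken from the snippet:

# Sketch only: `Compiler` is assumed to be the class hosting the compile() method
# above, and `my_tfx_pipeline` an already-constructed tfx_pipeline.Pipeline.
pipeline_ir = Compiler().compile(my_tfx_pipeline)
print(pipeline_ir.pipeline_info.id)
assert pipeline_ir.execution_mode == pipeline_pb2.Pipeline.ExecutionMode.SYNC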
Example #11
    def setUp(self):
        super(AsyncPipelineTaskGeneratorTest, self).setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())
        self._pipeline_root = pipeline_root

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        self._metadata_path = metadata_path
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up the pipeline.
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'async_pipeline.pbtxt'), pipeline)
        self._pipeline = pipeline
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)

        # Extracts components.
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node

        self._task_queue = tq.TaskQueue()
        self._ignore_node_ids = set([self._example_gen.node_info.id])
Example #12
    def setUp(self):
        super(BaseComponentTest, self).setUp()
        example_gen = csv_example_gen_component.CsvExampleGen(
            input_base='data_input')
        statistics_gen = statistics_gen_component.StatisticsGen(
            examples=example_gen.outputs['examples'], instance_name='foo')

        pipeline = tfx_pipeline.Pipeline(
            pipeline_name=self._test_pipeline_name,
            pipeline_root='test_pipeline_root',
            metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
            components=[example_gen, statistics_gen],
        )

        test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')

        self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
        self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
        self._tfx_ir = pipeline_pb2.Pipeline()
        with dsl.Pipeline('test_pipeline'):
            self.component = base_component.BaseComponent(
                component=statistics_gen,
                component_launcher_class=in_process_component_launcher.
                InProcessComponentLauncher,
                depends_on=set(),
                pipeline=pipeline,
                pipeline_name=self._test_pipeline_name,
                pipeline_root=test_pipeline_root,
                tfx_image='container_image',
                kubeflow_metadata_config=self._metadata_config,
                component_config=None,
                tfx_ir=self._tfx_ir,
            )
        self.tfx_component = statistics_gen
Example #13
 def _set_up_test_pipeline_pb(self):
   """Read expected pipeline pb from a text proto file."""
   test_pb_filepath = os.path.join(
       os.path.dirname(__file__), "testdata", "iris_pipeline_ir.pbtxt")
   with open(test_pb_filepath) as text_pb_file:
     self._pipeline_pb = text_format.ParseLines(text_pb_file,
                                                pipeline_pb2.Pipeline())
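A small counterpart sketch to the loader above, writing a Pipeline proto back out in text format; the output path and pipeline id are illustrative only:

# Sketch: serialize a Pipeline proto to a .pbtxt file with the standard protobuf
# text_format API (the path and id below are illustrative).
from google.protobuf import text_format
from tfx.proto.orchestration import pipeline_pb2

pipeline_pb = pipeline_pb2.Pipeline()
pipeline_pb.pipeline_info.id = 'iris'
with open('/tmp/iris_pipeline_ir.pbtxt', 'w') as text_pb_file:
  text_pb_file.write(text_format.MessageToString(pipeline_pb))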
Example #14
    def setUp(self):
        super().setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)
        self._testdata_dir = os.path.join(os.path.dirname(__file__),
                                          'testdata')

        # Sets up pipelines
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'pipeline_for_resolver_test.pbtxt'), pipeline)
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        runtime_parameter_utils.substitute_runtime_parameter(
            pipeline, {
                constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'my_pipeline_run',
            })

        # Extracts components
        self._my_trainer = pipeline.nodes[0].pipeline_node
        self._resolver_node = pipeline.nodes[1].pipeline_node
Example #15
    def test_node_state_for_skipped_nodes_in_partial_pipeline_run(self):
        """Tests that nodes marked to be skipped in a partial pipeline run have the right node state."""
        with self._mlmd_connection as m:
            pipeline = pipeline_pb2.Pipeline()
            pipeline.pipeline_info.id = 'pipeline1'
            pipeline.execution_mode = pipeline_pb2.Pipeline.SYNC
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'

            # Mark ExampleGen and Transform to be skipped.
            pipeline.nodes[0].pipeline_node.execution_options.skip.SetInParent(
            )
            pipeline.nodes[1].pipeline_node.execution_options.skip.SetInParent(
            )

            eg_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
            transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform')
            trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')

            pstate.PipelineState.new(m, pipeline)
            with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
                self.assertEqual(
                    {
                        eg_node_uid:
                        pstate.NodeState(state=pstate.NodeState.COMPLETE),
                        transform_node_uid:
                        pstate.NodeState(state=pstate.NodeState.COMPLETE),
                        trainer_node_uid:
                        pstate.NodeState(state=pstate.NodeState.STARTED),
                    }, pipeline_state.get_node_states_dict())
Example #16
def _get_pipeline_from_orchestrator_execution(
        execution: metadata_store_pb2.Execution) -> pipeline_pb2.Pipeline:
    pipeline_ir_b64 = data_types_utils.get_metadata_value(
        execution.properties[_PIPELINE_IR])
    pipeline = pipeline_pb2.Pipeline()
    pipeline.ParseFromString(base64.b64decode(pipeline_ir_b64))
    return pipeline
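For context, a sketch of the encoding direction this helper assumes: the pipeline IR is serialized and base64-encoded before being stored on the orchestrator execution (the pipeline id is illustrative):

# Sketch: produce the kind of base64-encoded IR the helper above decodes,
# using only the standard protobuf and base64 APIs.
import base64
from tfx.proto.orchestration import pipeline_pb2

pipeline = pipeline_pb2.Pipeline()
pipeline.pipeline_info.id = 'my_pipeline'
pipeline_ir_b64 = base64.b64encode(pipeline.SerializeToString()).decode('utf-8')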
Example #17
 def _get_test_pipeline_pb(self, file_name: str) -> pipeline_pb2.Pipeline:
     """Reads expected pipeline pb from a text proto file."""
     test_pb_filepath = os.path.join(os.path.dirname(__file__), "testdata",
                                     file_name)
     with open(test_pb_filepath) as text_pb_file:
         return text_format.ParseLines(text_pb_file,
                                       pipeline_pb2.Pipeline())
Example #18
  def test_stop_node_wait_for_inactivation_timeout(self):
    pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(
        os.path.join(
            os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
        pipeline)
    trainer = pipeline.nodes[2].pipeline_node
    test_utils.fake_component_output(
        self._mlmd_connection, trainer, active=True)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)
    with self._mlmd_connection as m:
      pstate.PipelineState.new(m, pipeline).commit()
      with self.assertRaisesRegex(
          status_lib.StatusNotOkError,
          'Timed out.*waiting for execution inactivation.'
      ) as exception_context:
        pipeline_ops.stop_node(m, node_uid, timeout_secs=1.0)
      self.assertEqual(status_lib.Code.DEADLINE_EXCEEDED,
                       exception_context.exception.code)

      # Even if `wait_for_inactivation` times out, the node should be stop-initiated
      # to prevent future triggers.
      pipeline_state = pstate.PipelineState.load(m, pipeline_uid)
      self.assertEqual(status_lib.Code.CANCELLED,
                       pipeline_state.node_stop_initiated_reason(node_uid).code)
Example #19
def _get_pipeline_details(mlmd_handle: metadata.Metadata,
                          task_queue: tq.TaskQueue) -> List[_PipelineDetail]:
    """Scans MLMD and returns pipeline details."""
    result = []

    contexts = mlmd_handle.store.get_contexts_by_type(
        _ORCHESTRATOR_RESERVED_ID)

    for context in contexts:
        active_executions = [
            e for e in mlmd_handle.store.get_executions_by_context(context.id)
            if execution_lib.is_execution_active(e)
        ]
        if len(active_executions) > 1:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.INTERNAL,
                message=(
                    f'Expected 1 but found {len(active_executions)} active '
                    f'executions for context named: {context.name}'))
        if not active_executions:
            continue
        execution = active_executions[0]

        # TODO(goutham): Instead of parsing the pipeline IR each time, we could
        # cache the parsed pipeline IR in `initiate_pipeline_start` and reuse it.
        pipeline_ir_b64 = common_utils.get_metadata_value(
            execution.properties[_PIPELINE_IR])
        pipeline = pipeline_pb2.Pipeline()
        pipeline.ParseFromString(base64.b64decode(pipeline_ir_b64))

        stop_initiated = _is_stop_initiated(execution)

        if stop_initiated:
            generator = None
        else:
            if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
                generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
                    mlmd_handle, pipeline, task_queue.contains_task_id)
            elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
                generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
                    mlmd_handle, pipeline, task_queue.contains_task_id)
            else:
                raise status_lib.StatusNotOkError(
                    code=status_lib.Code.FAILED_PRECONDITION,
                    message=
                    (f'Only SYNC and ASYNC pipeline execution modes supported; '
                     f'found pipeline with execution mode: {pipeline.execution_mode}'
                     ))

        result.append(
            _PipelineDetail(
                context=context,
                execution=execution,
                pipeline=pipeline,
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                stop_initiated=stop_initiated,
                generator=generator))

    return result
Example #20
 def pipeline(self) -> pipeline_pb2.Pipeline:
     if not self._pipeline:
         pipeline_ir_b64 = data_types_utils.get_metadata_value(
             self.execution.properties[_PIPELINE_IR])
         pipeline = pipeline_pb2.Pipeline()
         pipeline.ParseFromString(base64.b64decode(pipeline_ir_b64))
         self._pipeline = pipeline
     return self._pipeline
Example #21
def _test_pipeline(pipeline_id,
                   execution_mode: pipeline_pb2.Pipeline.ExecutionMode = (
                       pipeline_pb2.Pipeline.ASYNC)):
    pipeline = pipeline_pb2.Pipeline()
    pipeline.pipeline_info.id = pipeline_id
    pipeline.execution_mode = execution_mode
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    return pipeline
Example #22
 def testFullySubstituteRuntimeParameter(self):
     pipeline = pipeline_pb2.Pipeline()
     expected = pipeline_pb2.Pipeline()
     self.load_proto_from_text(
         os.path.join(self._testdata_dir,
                      'pipeline_with_runtime_parameter.pbtxt'), pipeline)
     self.load_proto_from_text(
         os.path.join(self._testdata_dir,
                      'pipeline_with_runtime_parameter_substituted.pbtxt'),
         expected)
     runtime_parameter_utils.substitute_runtime_parameter(
         pipeline, {
             'context_name_rp': 'my_context',
             'prop_one_rp': 2,
             'prop_two_rp': 'X'
         })
     self.assertProtoEquals(pipeline, expected)
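Example #23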
  def setUp(self):
    super(AsyncPipelineTaskGeneratorTest, self).setUp()
    pipeline_root = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self.id())
    self._pipeline_root = pipeline_root

    # Makes sure multiple connections within a test always connect to the same
    # MLMD instance.
    metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
    self._metadata_path = metadata_path
    connection_config = metadata.sqlite_metadata_connection_config(
        metadata_path)
    connection_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(
        connection_config=connection_config)

    # Sets up the pipeline.
    pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(
        os.path.join(
            os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
        pipeline)
    self._pipeline = pipeline
    self._pipeline_info = pipeline.pipeline_info
    self._pipeline_runtime_spec = pipeline.runtime_spec
    self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
        pipeline_root)

    # Extracts components.
    self._example_gen = pipeline.nodes[0].pipeline_node
    self._transform = pipeline.nodes[1].pipeline_node
    self._trainer = pipeline.nodes[2].pipeline_node

    self._task_queue = tq.TaskQueue()

    self._mock_service_job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)

    def _is_pure_service_node(unused_pipeline_state, node_id):
      return node_id == self._example_gen.node_info.id

    def _is_mixed_service_node(unused_pipeline_state, node_id):
      return node_id == self._transform.node_info.id

    self._mock_service_job_manager.is_pure_service_node.side_effect = (
        _is_pure_service_node)
    self._mock_service_job_manager.is_mixed_service_node.side_effect = (
        _is_mixed_service_node)

    def _default_ensure_node_services(unused_pipeline_state, node_id):
      self.assertIn(
          node_id,
          (self._example_gen.node_info.id, self._transform.node_info.id))
      return service_jobs.ServiceStatus.RUNNING

    self._mock_service_job_manager.ensure_node_services.side_effect = (
        _default_ensure_node_services)
Example #24
def _test_pipeline(pipeline_id,
                   execution_mode: pipeline_pb2.Pipeline.ExecutionMode = (
                       pipeline_pb2.Pipeline.ASYNC)):
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = pipeline_id
  pipeline.execution_mode = execution_mode
  if execution_mode == pipeline_pb2.Pipeline.SYNC:
    pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
  return pipeline
Example #25
 def test_node_uid_from_pipeline_node(self):
   pipeline = pipeline_pb2.Pipeline()
   pipeline.pipeline_info.id = 'pipeline'
   node = pipeline_pb2.PipelineNode()
   node.node_info.id = 'Trainer'
   self.assertEqual(
       task_lib.NodeUid(
           pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
           node_id='Trainer'),
       task_lib.NodeUid.from_pipeline_node(pipeline, node))
Example #26
def setUp(self):
  super(BeamDagRunnerTest, self).setUp()
  # Setup pipelines
  self._pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(os.path.dirname(__file__), 'testdata',
                   'pipeline_for_launcher_test.pbtxt'), self._pipeline)
  _executed_components.clear()
  _component_executors.clear()
  _component_drivers.clear()
Example #27
def test_node_uid_from_pipeline_node(self):
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = 'pipeline'
  pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
  node = pipeline_pb2.PipelineNode()
  node.node_info.id = 'Trainer'
  self.assertEqual(
      task_lib.NodeUid(
          pipeline_uid=task_lib.PipelineUid(
              pipeline_id='pipeline', pipeline_run_id='run0'),
          node_id='Trainer'),
      task_lib.NodeUid.from_pipeline_node(pipeline, node))
Example #28
def _test_pipeline(pipeline_id,
                   execution_mode: pipeline_pb2.Pipeline.ExecutionMode = (
                       pipeline_pb2.Pipeline.ASYNC),
                   param=1):
    pipeline = pipeline_pb2.Pipeline()
    pipeline.pipeline_info.id = pipeline_id
    pipeline.execution_mode = execution_mode
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline.nodes[0].pipeline_node.parameters.parameters[
        'param'].field_value.int_value = param
    return pipeline
Example #29
    def setUp(self):
        super(LauncherTest, self).setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        # Set a constant version for artifact version tag.
        patcher = mock.patch('tfx.version.__version__')
        patcher.start()
        tfx_version.__version__ = '0.123.4.dev'
        self.addCleanup(patcher.stop)

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up pipelines
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'pipeline_for_launcher_test.pbtxt'), pipeline)
        # Substitutes the runtime parameter with a concrete run_id.
        runtime_parameter_utils.substitute_runtime_parameter(
            pipeline, {
                constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'test_run',
            })
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)
        self._pipeline_runtime_spec.pipeline_run_id.field_value.string_value = (
            'test_run_0')

        # Extracts components
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node
        self._importer = pipeline.nodes[3].pipeline_node
        self._resolver = pipeline.nodes[4].pipeline_node

        # Fakes an ExecutorSpec for Trainer
        self._trainer_executor_spec = _PYTHON_CLASS_EXECUTABLE_SPEC()
        # Fakes an executor operator
        self._test_executor_operators = {
            _PYTHON_CLASS_EXECUTABLE_SPEC: _FakeExecutorOperator
        }
        # Fakes a custom driver spec
        self._custom_driver_spec = _PYTHON_CLASS_EXECUTABLE_SPEC()
        self._custom_driver_spec.class_path = 'tfx.orchestration.portable.launcher_test._FakeExampleGenLikeDriver'
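Example #30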
def _make_pipeline(self, pipeline_root, pipeline_run_id):
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(os.path.dirname(__file__), 'testdata',
                   'pipeline_with_importer.pbtxt'), pipeline)
  runtime_parameter_utils.substitute_runtime_parameter(
      pipeline, {
          'pipeline_root': pipeline_root,
          'pipeline_run_id': pipeline_run_id,
      })
  return pipeline