コード例 #1
0
  def test_cached_execution(self):
    """Verifies that a cached execution is used when one is available.

    The first pipeline run fakes an ExampleGen execution and finishes the
    StatsGen task that the generator emits. A second pipeline with a fresh
    run id is then given a faked cached ExampleGen execution; generating tasks
    for it is expected to leave StatsGen with a successful "CACHED" execution
    (and no task) while producing a task for SchemaGen.
    """
    # Simulate a completed ExampleGen run for the original pipeline.
    first_example_gen_execution = otu.fake_example_gen_run(
        self._mlmd_connection, self._example_gen, 1, 1)

    # The generator is expected to emit exactly one ExecNodeTask: StatsGen.
    first_run_tasks = self._generate_and_test(
        False,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1)
    (stats_gen_task,) = first_run_tasks
    self.assertEqual('my_statistics_gen', stats_gen_task.node_uid.node_id)

    # Mark the StatsGen execution as finished.
    otu.fake_execute_node(self._mlmd_connection, stats_gen_task)

    # Build a second pipeline that differs only in its pipeline_run_id.
    new_run_id = str(uuid.uuid4())
    new_pipeline = self._make_pipeline(self._pipeline_root, new_run_id)

    with self._mlmd_connection as handle:
      example_gen_contexts = handle.store.get_contexts_by_execution(
          first_example_gen_execution.id)
      # The node context doubles as the cache context for ease of testing.
      node_contexts = [
          context for context in example_gen_contexts
          if context.name == 'my_example_gen'
      ]
      cache_context = node_contexts[0]

    # Register a cached ExampleGen execution for the new pipeline.
    new_example_gen = otu.get_node(new_pipeline, 'my_example_gen')
    otu.fake_cached_execution(self._mlmd_connection, cache_context,
                              new_example_gen)

    stats_gen = otu.get_node(new_pipeline, 'my_statistics_gen')

    # Generating tasks for the new pipeline should:
    # 1. Complete StatsGen with state "CACHED" without emitting an
    #    ExecNodeTask for it.
    # 2. Emit an ExecNodeTask for SchemaGen (the component downstream of
    #    StatsGen), backed by an active execution in MLMD.
    second_run_tasks = self._generate_and_test(
        False,
        pipeline=new_pipeline,
        num_initial_executions=3,
        num_tasks_generated=1,
        num_new_executions=2,
        num_active_executions=1)
    (schema_gen_task,) = second_run_tasks
    self.assertEqual('my_schema_gen', schema_gen_task.node_uid.node_id)

    # StatsGen must have a single, successful execution in state "CACHED".
    with self._mlmd_connection as handle:
      stats_gen_executions = task_gen_utils.get_executions(handle, stats_gen)
      self.assertLen(stats_gen_executions, 1)
      (cached_execution,) = stats_gen_executions
      self.assertTrue(execution_lib.is_execution_successful(cached_execution))
      self.assertEqual(metadata_store_pb2.Execution.CACHED,
                       cached_execution.last_known_state)
コード例 #2
0
    def setUp(self):
        """Builds the per-test fixture.

        Creates a pipeline root unique to this test, a SQLite-backed MLMD
        connection stored under that root, the pipeline and handles to its
        nodes, an empty task queue, and an autospecced service job manager
        whose service-node predicates and `ensure_node_services` are stubbed.
        """
        super().setUp()

        # Everything for this test lives under a directory named after its id.
        output_base = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                     self.get_temp_dir())
        self._pipeline_root = os.path.join(output_base, self.id())

        # Using one database path makes sure multiple connections within a
        # test always connect to the same MLMD instance.
        self._metadata_path = os.path.join(self._pipeline_root, 'metadata',
                                           'metadata.db')
        sqlite_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        sqlite_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=sqlite_config)

        # Pipeline under test, created with a fresh run id.
        self._pipeline = self._make_pipeline(self._pipeline_root,
                                             str(uuid.uuid4()))

        def _node(node_id):
            """Looks up a node of the pipeline under test by its id."""
            return test_utils.get_node(self._pipeline, node_id)

        # Handles to the individual pipeline nodes.
        self._example_gen = _node('my_example_gen')
        self._stats_gen = _node('my_statistics_gen')
        self._schema_gen = _node('my_schema_gen')
        self._transform = _node('my_transform')
        self._example_validator = _node('my_example_validator')
        self._trainer = _node('my_trainer')
        self._evaluator = _node('my_evaluator')
        self._chore_a = _node('chore_a')
        self._chore_b = _node('chore_b')

        self._task_queue = tq.TaskQueue()

        job_manager = mock.create_autospec(service_jobs.ServiceJobManager,
                                           instance=True)
        self._mock_service_job_manager = job_manager

        def _is_pure_service_node(_, node_id):
            """Reports ExampleGen as the only pure service node."""
            return node_id == self._example_gen.node_info.id

        def _is_mixed_service_node(_, node_id):
            """Reports Transform as the only mixed service node."""
            return node_id == self._transform.node_info.id

        def _default_ensure_node_services(unused_pipeline_state, node_id):
            """Succeeds for service nodes; fails the test for any other node."""
            service_node_ids = (self._example_gen.node_info.id,
                                self._transform.node_info.id)
            self.assertIn(node_id, service_node_ids)
            return service_jobs.ServiceStatus.SUCCESS

        job_manager.is_pure_service_node.side_effect = _is_pure_service_node
        job_manager.is_mixed_service_node.side_effect = _is_mixed_service_node
        job_manager.ensure_node_services.side_effect = (
            _default_ensure_node_services)
コード例 #3
0
ファイル: sample_mlmd_creator.py プロジェクト: htahir1/tfx
def _execute_nodes(handle: metadata.Metadata, pipeline: pipeline_pb2.Pipeline,
                   version: int):
    """Records fake executions for the pipeline's nodes through `handle`.

    ExampleGen gets a fake run (called with `1` and `version`); StatsGen,
    SchemaGen, Transform, ExampleValidator and Trainer each get a fake,
    non-active component output, in that order.

    Args:
        handle: Metadata handle passed to the fake-execution helpers.
        pipeline: Pipeline IR from which the nodes are looked up by id.
        version: Forwarded as the last argument of the fake ExampleGen run.
    """
    node_ids = (
        'my_example_gen',
        'my_statistics_gen',
        'my_schema_gen',
        'my_transform',
        'my_example_validator',
        'my_trainer',
    )
    # Resolve every node up front so a failed lookup happens before any fake
    # execution has been written.
    example_gen, *downstream_nodes = [
        test_utils.get_node(pipeline, node_id) for node_id in node_ids
    ]

    test_utils.fake_example_gen_run_with_handle(handle, example_gen, 1,
                                                version)
    for node in downstream_nodes:
        test_utils.fake_component_output_with_handle(handle,
                                                     node,
                                                     active=False)