def __enter__(self) -> 'PipelineState':
    """Enters the atomic-op context of the underlying MLMD execution.

    Creates `mlmd_state.mlmd_execution_atomic_op` for `self.mlmd_handle` and
    `self.execution_id`, enters it, and keeps both the context manager and
    the value it yields on this instance.

    Returns:
      This object.
    """
    atomic_op = mlmd_state.mlmd_execution_atomic_op(self.mlmd_handle,
                                                    self.execution_id)
    # The instance attributes are only assigned once entering has succeeded.
    entered_execution = atomic_op.__enter__()
    self._mlmd_execution_atomic_op_context = atomic_op
    self._execution = entered_execution
    return self
def test_mlmd_execution_update(self):
    """Checks that a change made inside the atomic op is persisted."""
    canceled = metadata_store_pb2.Execution.CANCELED
    with self._mlmd_connection as m:
        expected_execution = _write_test_execution(m)

        # Change the state while holding the atomic op.
        with mlmd_state.mlmd_execution_atomic_op(
                m, expected_execution.id) as mutable_execution:
            self.assertEqual(expected_execution, mutable_execution)
            mutable_execution.last_known_state = canceled

        # The MLMD store must report the new state.
        [stored_execution] = m.store.get_executions_by_id(
            [mutable_execution.id])
        self.assertEqual(canceled, stored_execution.last_known_state)

        # A subsequent atomic op must yield the new state as well.
        with mlmd_state.mlmd_execution_atomic_op(
                m, expected_execution.id) as reloaded_execution:
            self.assertEqual(canceled, reloaded_execution.last_known_state)
def test_pipeline_failure_strategies(self, fail_fast):
    """Tests pipeline failure strategies."""

    def _assert_pipeline_aborts():
        # The next generation must produce exactly one finalize-pipeline
        # task carrying the ABORTED status code.
        [finalize_task] = self._generate(False, True, fail_fast=fail_fast)
        self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
        self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code)

    test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen,
                                    1, 1)
    for upstream_node in (self._stats_gen, self._schema_gen):
        self._run_next(False, expect_nodes=[upstream_node],
                       fail_fast=fail_fast)

    # example-validator and transform become ready at the same time.
    [validator_task, transform_task] = self._generate(
        False, True, fail_fast=fail_fast)
    self.assertEqual(self._example_validator.node_info.id,
                     validator_task.node_uid.node_id)
    self.assertEqual(self._transform.node_info.id,
                     transform_task.node_uid.node_id)

    # Transform completes successfully ...
    self._finish_node_execution(False, transform_task)

    # ... whereas the example-validator execution is marked FAILED and
    # given an error message.
    with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
            m, validator_task.execution_id) as validator_exec:
        validator_exec.last_known_state = metadata_store_pb2.Execution.FAILED
        data_types_utils.set_metadata_value(
            validator_exec.custom_properties[
                constants.EXECUTION_ERROR_MSG_KEY],
            'example-validator error')

    if fail_fast:
        # With fail-fast the example-validator failure ends the run at once.
        _assert_pipeline_aborts()
    else:
        # Transform has finished, so trainer and the nodes after it still
        # run; example-validator is not upstream of them. The run fails
        # only once no further progress is possible.
        for downstream_node in (self._trainer, self._chore_a, self._chore_b):
            self._run_next(False, expect_nodes=[downstream_node],
                           fail_fast=fail_fast)
        _assert_pipeline_aborts()
def test_triggering_upon_exec_properties_change(self):
    """Checks that changed execution properties trigger a new execution."""
    test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen,
                                    1, 1)

    [transform_task] = self._generate_and_test(
        False,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._transform],
        ignore_update_node_state_tasks=True)

    # Mark the execution registered above as FAILED.
    with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
            m, transform_task.execution_id) as failed_execution:
        failed_execution.last_known_state = (
            metadata_store_pb2.Execution.FAILED)

    # Nothing changed since the last run, so no task may be generated.
    self._generate_and_test(
        False,
        num_initial_executions=2,
        num_tasks_generated=0,
        num_new_executions=0,
        num_active_executions=0,
        ignore_update_node_state_tasks=True)

    # Alter a property recorded on the last run's execution.
    with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
            m, transform_task.execution_id) as altered_execution:
        altered_execution.custom_properties['a_param'].int_value = 20

    # The properties now differ, so a new execution is expected.
    self._generate_and_test(
        False,
        num_initial_executions=2,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._transform],
        ignore_update_node_state_tasks=True)
def _update_execution_state_in_mlmd(
        mlmd_handle: metadata.Metadata,
        execution_id: int,
        new_state: metadata_store_pb2.Execution.State,
        error_msg: str) -> None:
    """Sets the state, and optionally an error message, of an execution.

    Args:
      mlmd_handle: Handle passed to `mlmd_state.mlmd_execution_atomic_op`.
      execution_id: Id of the execution to modify.
      new_state: Value assigned to the execution's `last_known_state`.
      error_msg: When non-empty, stored in the execution's custom properties
        under `constants.EXECUTION_ERROR_MSG_KEY`.
    """
    with mlmd_state.mlmd_execution_atomic_op(mlmd_handle,
                                             execution_id) as target:
        target.last_known_state = new_state
        if error_msg:
            # The property is only touched when there is a message to store.
            error_slot = target.custom_properties[
                constants.EXECUTION_ERROR_MSG_KEY]
            data_types_utils.set_metadata_value(error_slot, error_msg)
def resume_manual_node(mlmd_handle: metadata.Metadata,
                       node_uid: task_lib.NodeUid) -> None:
    """Marks the active execution of a manual node as completed.

    Args:
      mlmd_handle: A handle to the MLMD db.
      node_uid: Uid of the manual node to be resumed.

    Raises:
      status_lib.StatusNotOkError: With code NOT_FOUND when not exactly one
        pipeline node matches `node_uid` or the node has no active
        execution, INVALID_ARGUMENT when the node is not a manual node, and
        INTERNAL when the node has more than one active execution.
    """
    logging.info('Received request to resume manual node; node uid: %s',
                 node_uid)
    with pstate.PipelineState.load(
            mlmd_handle, node_uid.pipeline_uid) as pipeline_state:
        matching_nodes = [
            candidate
            for candidate in pstate.get_all_pipeline_nodes(
                pipeline_state.pipeline)
            if candidate.node_info.id == node_uid.node_id
        ]
        if len(matching_nodes) != 1:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.NOT_FOUND,
                message=f'Unable to find manual node to resume: {node_uid}')

        [node] = matching_nodes
        if node.node_info.type.name != constants.MANUAL_NODE_TYPE:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.INVALID_ARGUMENT,
                message=('Unable to resume a non-manual node. '
                         f'Got non-manual node id: {node_uid}'))

        active_executions = [
            candidate
            for candidate in task_gen_utils.get_executions(mlmd_handle, node)
            if execution_lib.is_execution_active(candidate)
        ]
        if not active_executions:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.NOT_FOUND,
                message=(
                    'Unable to find active manual node to resume: '
                    f'{node_uid}'))
        if len(active_executions) > 1:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.INTERNAL,
                message=(
                    'Unexpected multiple active executions for manual node: '
                    f'{node_uid}'))

        # Exactly one active execution remains at this point.
        [active_execution] = active_executions
        with mlmd_state.mlmd_execution_atomic_op(
                mlmd_handle=mlmd_handle,
                execution_id=active_execution.id) as execution:
            completed_state = manual_task_scheduler.ManualNodeState(
                state=manual_task_scheduler.ManualNodeState.COMPLETED)
            state_property = execution.custom_properties.get_or_create(
                manual_task_scheduler.NODE_STATE_PROPERTY_KEY)
            completed_state.set_mlmd_value(state_property)
def test_restart_node_cancelled_due_to_stopping(self):
    """Checks that a node cancelled by a stop request can be restarted."""
    test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen,
                                    1, 1)
    [first_task] = self._generate_and_test(
        False,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        ignore_update_node_state_tasks=True)
    node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline,
                                                   self._stats_gen)
    self.assertEqual(node_uid, first_task.node_uid)

    # Imitate a stop request arriving mid-execution: the execution ends up
    # CANCELED with an explanatory message.
    with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
            m, first_task.execution_id) as cancelled_exec:
        cancelled_exec.last_known_state = (
            metadata_store_pb2.Execution.CANCELED)
        data_types_utils.set_metadata_value(
            cancelled_exec.custom_properties[
                constants.EXECUTION_ERROR_MSG_KEY],
            'manually stopped')

    # Move the node into the STARTING state.
    with self._mlmd_connection as m:
        pipeline_state = test_utils.get_or_create_pipeline_state(
            m, self._pipeline)
        with pipeline_state, pipeline_state.node_state_update_context(
                node_uid) as node_state:
            node_state.update(pstate.NodeState.STARTING)

    # A STARTING node whose previous execution was cancelled must get a
    # fresh execution.
    [state_task, restarted_task] = self._generate_and_test(
        False,
        num_initial_executions=2,
        num_tasks_generated=2,
        num_new_executions=1,
        num_active_executions=1)
    self.assertTrue(task_lib.is_update_node_state_task(state_task))
    self.assertEqual(node_uid, state_task.node_uid)
    self.assertEqual(pstate.NodeState.RUNNING, state_task.state)
    self.assertEqual(node_uid, restarted_task.node_uid)
def test_mlmd_execution_update(self):
    """Checks commit, cache sync and the on_commit callback of the op."""
    committed = threading.Event()
    canceled = metadata_store_pb2.Execution.CANCELED
    with self._mlmd_connection as m:
        expected_execution = _write_test_execution(m)

        # Change the state while holding the atomic op.
        with mlmd_state.mlmd_execution_atomic_op(
                m, expected_execution.id,
                on_commit=committed.set) as mutable_execution:
            self.assertEqual(expected_execution, mutable_execution)
            mutable_execution.last_known_state = canceled
            # The callback must not have run while the op is still open.
            self.assertFalse(committed.is_set())

        # The MLMD store must report the new state.
        [stored_execution] = m.store.get_executions_by_id(
            [mutable_execution.id])
        self.assertEqual(canceled, stored_execution.last_known_state)

        # The cached copy must equal what the store returned.
        cached_execution = mlmd_state._execution_cache._cache[
            stored_execution.id]
        self.assertEqual(stored_execution, cached_execution)

        # The callback must have run by now.
        self.assertTrue(committed.is_set())

        # A subsequent atomic op must yield the updated execution.
        with mlmd_state.mlmd_execution_atomic_op(
                m, expected_execution.id) as reloaded_execution:
            self.assertEqual(stored_execution, reloaded_execution)
def test_triggering_upon_executor_spec_change(self):
    """Checks that a changed executor spec triggers a new execution."""

    def _generate_with_executor_spec(spec_version, **expectations):
        # Runs one generate-and-test round while `get_executor_spec` is
        # patched to use `_fake_executor_spec(spec_version)`.
        with mock.patch.object(task_gen_utils,
                               'get_executor_spec') as patched_get_spec:
            patched_get_spec.side_effect = _fake_executor_spec(spec_version)
            return self._generate_and_test(
                False, ignore_update_node_state_tasks=True, **expectations)

    test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen,
                                    1, 1)

    [transform_task] = _generate_with_executor_spec(
        1,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._transform])

    # Mark the execution registered above as FAILED.
    with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
            m, transform_task.execution_id) as failed_execution:
        failed_execution.last_known_state = (
            metadata_store_pb2.Execution.FAILED)

    # The executor spec is unchanged, so no task may be generated.
    _generate_with_executor_spec(
        1,
        num_initial_executions=2,
        num_tasks_generated=0,
        num_new_executions=0,
        num_active_executions=0)

    # The executor spec differs, so a new execution is expected.
    _generate_with_executor_spec(
        2,
        num_initial_executions=2,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._transform])
def schedule(self) -> task_scheduler.TaskSchedulerResult:
    """Polls the execution's manual-node state until done or cancelled.

    Each round waits on `self._cancel` for `_POLLING_INTERVAL_SECS`, then
    reads the `NODE_STATE_PROPERTY_KEY` custom property of the task's
    execution.

    Returns:
      A result with status OK once the stored node state is COMPLETED, or a
      result with status CANCELLED once the wait on `self._cancel` returns a
      truthy value.
    """
    while True:
        if self._cancel.wait(_POLLING_INTERVAL_SECS):
            break
        with mlmd_state.mlmd_execution_atomic_op(
                mlmd_handle=self.mlmd_handle,
                execution_id=self.task.execution_id) as execution:
            raw_state = execution.custom_properties.get(
                NODE_STATE_PROPERTY_KEY)
            parsed_state = ManualNodeState.from_mlmd_value(raw_state)
            if parsed_state.state == ManualNodeState.COMPLETED:
                return task_scheduler.TaskSchedulerResult(
                    status=status_lib.Status(code=status_lib.Code.OK),
                    output=task_scheduler.ExecutorNodeOutput())

    return task_scheduler.TaskSchedulerResult(
        status=status_lib.Status(code=status_lib.Code.CANCELLED),
        output=task_scheduler.ExecutorNodeOutput())
def test_node_failed(self, fail_fast):
    """Tests task generation when a node registers a failed execution."""
    error_message = 'foobar error'
    test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen,
                                    1, 1)
    [stats_gen_task] = self._generate_and_test(
        False,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        ignore_update_node_state_tasks=True,
        fail_fast=fail_fast)
    expected_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline,
                                                       self._stats_gen)
    self.assertEqual(expected_uid, stats_gen_task.node_uid)

    # Mark the stats-gen execution as FAILED and attach an error message.
    with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
            m, stats_gen_task.execution_id) as failed_exec:
        failed_exec.last_known_state = metadata_store_pb2.Execution.FAILED
        data_types_utils.set_metadata_value(
            failed_exec.custom_properties[constants.EXECUTION_ERROR_MSG_KEY],
            error_message)

    # The next round must emit a node-state update and a finalize task.
    [state_task, finalize_task] = self._generate_and_test(
        True,
        num_initial_executions=2,
        num_tasks_generated=2,
        num_new_executions=0,
        num_active_executions=0,
        fail_fast=fail_fast)

    self.assertTrue(task_lib.is_update_node_state_task(state_task))
    self.assertEqual('my_statistics_gen', state_task.node_uid.node_id)
    self.assertEqual(pstate.NodeState.FAILED, state_task.state)
    self.assertRegexMatch(state_task.status.message, [error_message])

    self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
    self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code)
    self.assertRegexMatch(finalize_task.status.message, [error_message])
def test_node_failed(self, use_task_queue):
    """Tests task generation when a node registers a failed execution."""
    error_message = 'foobar error'
    otu.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1, 1)

    def _ensure_node_services(unused_pipeline_state, node_id):
        # Only the example-gen node is expected to ask for services.
        self.assertEqual(self._example_gen.node_info.id, node_id)
        return service_jobs.ServiceStatus.SUCCESS

    service_manager = self._mock_service_job_manager
    service_manager.ensure_node_services.side_effect = _ensure_node_services

    [stats_gen_task] = self._generate_and_test(
        use_task_queue,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1)
    expected_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline,
                                                       self._stats_gen)
    self.assertEqual(expected_uid, stats_gen_task.node_uid)

    # Mark the stats-gen execution as FAILED and attach an error message.
    with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
            m, stats_gen_task.execution_id) as failed_exec:
        failed_exec.last_known_state = metadata_store_pb2.Execution.FAILED
        data_types_utils.set_metadata_value(
            failed_exec.custom_properties[constants.EXECUTION_ERROR_MSG_KEY],
            error_message)

    if use_task_queue:
        # Take the queued task off the queue and mark it done.
        queued_task = self._task_queue.dequeue()
        self._task_queue.task_done(queued_task)

    # The next round must emit a single finalize-pipeline task.
    [finalize_task] = self._generate_and_test(
        True,
        num_initial_executions=2,
        num_tasks_generated=1,
        num_new_executions=0,
        num_active_executions=0)
    self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
    self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code)
    self.assertRegexMatch(finalize_task.status.message, [error_message])
def test_mlmd_execution_absent(self):
    """Checks that opening the op for execution id 1 raises ValueError."""
    absent_execution_id = 1
    expected_error = 'Execution not found for execution id'
    with self._mlmd_connection as m:
        with self.assertRaisesRegex(ValueError, expected_error), \
                mlmd_state.mlmd_execution_atomic_op(m, absent_execution_id):
            pass
def test_mlmd_execution_absent(self):
    """Checks that the op for execution id 1 yields None."""
    absent_execution_id = 1
    with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
            m, absent_execution_id) as yielded_execution:
        self.assertIsNone(yielded_execution)