def test_load_from_orchestrator_context(self):
  """Checks that a state loaded from its orchestrator context matches MLMD."""
  with self._mlmd_connection as mlmd_handle:
    test_pipeline = _test_pipeline('pipeline1')
    # Create the pipeline state and close it right away; only the resulting
    # MLMD records are inspected below.
    with pstate.PipelineState.new(mlmd_handle, test_pipeline):
      pass

    contexts = pstate.get_orchestrator_contexts(mlmd_handle)
    self.assertLen(contexts, 1)
    loaded_state = pstate.PipelineState.load_from_orchestrator_context(
        mlmd_handle, contexts[0])

    # Re-query MLMD: loading must still leave exactly one orchestrator context.
    contexts = pstate.get_orchestrator_contexts(mlmd_handle)
    self.assertLen(contexts, 1)
    self.assertProtoPartiallyEquals(contexts[0], loaded_state.context)

    executions = mlmd_handle.store.get_executions_by_context(contexts[0].id)
    self.assertLen(executions, 1)
    self.assertProtoPartiallyEquals(executions[0], loaded_state.execution)

    self.assertEqual(test_pipeline, loaded_state.pipeline)
    expected_uid = task_lib.PipelineUid.from_pipeline(test_pipeline)
    self.assertEqual(expected_uid, loaded_state.pipeline_uid)
def test_new_pipeline_state(self):
  """Checks that a new pipeline state is stored as one context and execution."""
  # Timestamp fields are excluded from the proto comparisons below.
  timestamp_fields = [
      'create_time_since_epoch', 'last_update_time_since_epoch'
  ]
  with self._mlmd_connection as mlmd_handle:
    test_pipeline = _test_pipeline('pipeline1')
    with pstate.PipelineState.new(mlmd_handle, test_pipeline) as new_state:
      pass

    contexts = pstate.get_orchestrator_contexts(mlmd_handle)
    self.assertLen(contexts, 1)
    self.assertProtoPartiallyEquals(
        contexts[0], new_state.context, ignored_fields=timestamp_fields)

    executions = mlmd_handle.store.get_executions_by_context(contexts[0].id)
    self.assertLen(executions, 1)
    self.assertProtoPartiallyEquals(
        executions[0], new_state.execution, ignored_fields=timestamp_fields)

    self.assertEqual(test_pipeline, new_state.pipeline)
    expected_uid = task_lib.PipelineUid.from_pipeline(test_pipeline)
    self.assertEqual(expected_uid, new_state.pipeline_uid)
def test_save_and_remove_property(self):
  """Checks that a saved custom property is persisted and can be removed."""
  key, value = 'key', 'value'
  with self._mlmd_connection as mlmd_handle:
    test_pipeline = _test_pipeline('pipeline1')
    with pstate.PipelineState.new(mlmd_handle, test_pipeline) as state:
      state.save_property(key, value)

    # The context id is reused for the post-removal lookup below.
    context_id = pstate.get_orchestrator_contexts(mlmd_handle)[0].id
    executions = mlmd_handle.store.get_executions_by_context(context_id)
    self.assertLen(executions, 1)
    saved = executions[0].custom_properties.get(key)
    self.assertIsNotNone(saved)
    self.assertEqual(saved.string_value, value)

    # Reopen the same pipeline by uid and drop the property.
    pipeline_uid = task_lib.PipelineUid.from_pipeline(test_pipeline)
    with pstate.PipelineState.load(mlmd_handle, pipeline_uid) as state:
      state.remove_property(key)

    executions = mlmd_handle.store.get_executions_by_context(context_id)
    self.assertLen(executions, 1)
    self.assertIsNone(executions[0].custom_properties.get(key))
def _get_pipeline_states(
    mlmd_handle: metadata.Metadata) -> List[pstate.PipelineState]:
  """Scans MLMD and returns pipeline states.

  A pipeline state is loaded for each orchestrator context found in MLMD.
  Contexts whose load fails with NOT_FOUND are logged and skipped; any other
  `status_lib.StatusNotOkError` is propagated.

  Args:
    mlmd_handle: Handle used to query MLMD.

  Returns:
    The loaded pipeline states, in the order the contexts were returned.
  """
  states = []
  for orchestrator_context in pstate.get_orchestrator_contexts(mlmd_handle):
    try:
      loaded = pstate.PipelineState.load_from_orchestrator_context(
          mlmd_handle, orchestrator_context)
    except status_lib.StatusNotOkError as err:
      if err.code != status_lib.Code.NOT_FOUND:
        raise
      # Ignore any old contexts with no associated active pipelines.
      logging.info(err.message)
    else:
      states.append(loaded)
  return states
def _get_pipeline_details(mlmd_handle: metadata.Metadata,
                          task_queue: tq.TaskQueue) -> List[_PipelineDetail]:
  """Scans MLMD and returns pipeline details.

  For each orchestrator context in MLMD, loads the pipeline state and pairs it
  with a task generator chosen by the pipeline's execution mode. A pipeline
  for which stop has been initiated gets no generator (`None`). Contexts whose
  pipeline state load fails with NOT_FOUND are skipped.

  Args:
    mlmd_handle: Handle used to query MLMD.
    task_queue: Queue whose `contains_task_id` is handed to each generator.

  Returns:
    A list of `_PipelineDetail`, one per successfully loaded pipeline state.

  Raises:
    status_lib.StatusNotOkError: If loading a pipeline state fails with any
      code other than NOT_FOUND, or (with FAILED_PRECONDITION) if a pipeline's
      execution mode is neither SYNC nor ASYNC.
  """
  result = []
  contexts = pstate.get_orchestrator_contexts(mlmd_handle)
  for context in contexts:
    try:
      pipeline_state = pstate.PipelineState.load_from_orchestrator_context(
          mlmd_handle, context)
    except status_lib.StatusNotOkError as e:
      if e.code == status_lib.Code.NOT_FOUND:
        continue
      # Any other error must propagate. Swallowing it would fall through with
      # `pipeline_state` unbound (NameError) or still holding the previous
      # context's state, yielding a duplicate detail for that pipeline.
      raise

    if pipeline_state.is_stop_initiated():
      # Stopping pipelines get no generator, so no new tasks are produced.
      generator = None
    else:
      pipeline = pipeline_state.pipeline
      if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle, pipeline, task_queue.contains_task_id)
      elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, pipeline, task_queue.contains_task_id)
      else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    result.append(
        _PipelineDetail(pipeline_state=pipeline_state, generator=generator))
  return result