예제 #1
0
    def test_load_from_orchestrator_context(self):
        """Loading a state from its orchestrator context matches MLMD."""
        with self._mlmd_connection as mlmd_handle:
            test_pipeline = _test_pipeline('pipeline1')
            # Only the creation of the pipeline state is needed here; nothing
            # is done inside the context manager.
            with pstate.PipelineState.new(mlmd_handle, test_pipeline):
                pass

            contexts_before = pstate.get_orchestrator_contexts(mlmd_handle)
            self.assertLen(contexts_before, 1)
            loaded_state = (
                pstate.PipelineState.load_from_orchestrator_context(
                    mlmd_handle, contexts_before[0]))

            # Loading must not have produced any additional contexts.
            contexts_after = pstate.get_orchestrator_contexts(mlmd_handle)
            self.assertLen(contexts_after, 1)
            only_context = contexts_after[0]
            self.assertProtoPartiallyEquals(only_context,
                                            loaded_state.context)

            executions = mlmd_handle.store.get_executions_by_context(
                only_context.id)
            self.assertLen(executions, 1)
            self.assertProtoPartiallyEquals(executions[0],
                                            loaded_state.execution)

            self.assertEqual(test_pipeline, loaded_state.pipeline)
            expected_uid = task_lib.PipelineUid.from_pipeline(test_pipeline)
            self.assertEqual(expected_uid, loaded_state.pipeline_uid)
예제 #2
0
    def test_new_pipeline_state(self):
        """A freshly created pipeline state is reflected in MLMD."""
        # Timestamp fields are excluded from the proto comparisons below
        # (presumably because MLMD populates them on write).
        timestamp_fields = ('create_time_since_epoch',
                            'last_update_time_since_epoch')
        with self._mlmd_connection as mlmd_handle:
            test_pipeline = _test_pipeline('pipeline1')
            with pstate.PipelineState.new(mlmd_handle,
                                          test_pipeline) as new_state:
                pass

            contexts = pstate.get_orchestrator_contexts(mlmd_handle)
            self.assertLen(contexts, 1)
            self.assertProtoPartiallyEquals(
                contexts[0],
                new_state.context,
                ignored_fields=list(timestamp_fields))

            executions = mlmd_handle.store.get_executions_by_context(
                contexts[0].id)
            self.assertLen(executions, 1)
            self.assertProtoPartiallyEquals(
                executions[0],
                new_state.execution,
                ignored_fields=list(timestamp_fields))

            self.assertEqual(test_pipeline, new_state.pipeline)
            expected_uid = task_lib.PipelineUid.from_pipeline(test_pipeline)
            self.assertEqual(expected_uid, new_state.pipeline_uid)
예제 #3
0
    def test_save_and_remove_property(self):
        """A custom property can be saved and later removed."""
        key, value = 'key', 'value'
        with self._mlmd_connection as mlmd_handle:
            test_pipeline = _test_pipeline('pipeline1')
            with pstate.PipelineState.new(mlmd_handle,
                                          test_pipeline) as state:
                state.save_property(key, value)

            context_id = pstate.get_orchestrator_contexts(mlmd_handle)[0].id

            # After saving, the execution carries the property as a string.
            executions = mlmd_handle.store.get_executions_by_context(
                context_id)
            self.assertLen(executions, 1)
            saved = executions[0].custom_properties.get(key)
            self.assertIsNotNone(saved)
            self.assertEqual(saved.string_value, value)

            uid = task_lib.PipelineUid.from_pipeline(test_pipeline)
            with pstate.PipelineState.load(mlmd_handle, uid) as state:
                state.remove_property(key)

            # After removal, the property is gone from the execution.
            executions = mlmd_handle.store.get_executions_by_context(
                context_id)
            self.assertLen(executions, 1)
            self.assertIsNone(executions[0].custom_properties.get(key))
예제 #4
0
def _get_pipeline_states(
        mlmd_handle: metadata.Metadata) -> List[pstate.PipelineState]:
    """Loads the pipeline state for every orchestrator context in MLMD.

    Contexts whose state lookup fails with NOT_FOUND are logged and skipped;
    any other status error propagates to the caller.

    Args:
      mlmd_handle: Handle to the MLMD store to scan.

    Returns:
      The loaded pipeline states, in the order the contexts were returned
      by `pstate.get_orchestrator_contexts`.

    Raises:
      status_lib.StatusNotOkError: If loading fails with a code other than
        NOT_FOUND.
    """
    states = []
    for orchestrator_context in pstate.get_orchestrator_contexts(mlmd_handle):
        try:
            loaded = pstate.PipelineState.load_from_orchestrator_context(
                mlmd_handle, orchestrator_context)
        except status_lib.StatusNotOkError as err:
            if err.code != status_lib.Code.NOT_FOUND:
                raise
            # Old contexts with no associated active pipeline are ignored.
            logging.info(err.message)
        else:
            states.append(loaded)
    return states
예제 #5
0
def _get_pipeline_details(mlmd_handle: metadata.Metadata,
                          task_queue: tq.TaskQueue) -> List[_PipelineDetail]:
  """Scans MLMD and returns pipeline details.

  For each orchestrator context in MLMD, loads its pipeline state and pairs
  it with a task generator chosen by the pipeline's execution mode. Pipelines
  for which stop has been initiated get `generator=None`. Contexts whose
  pipeline state is not found are skipped.

  Args:
    mlmd_handle: Handle to the MLMD store to scan.
    task_queue: Task queue whose `contains_task_id` is handed to each
      generator.

  Returns:
    One `_PipelineDetail` per orchestrator context with a loadable state.

  Raises:
    status_lib.StatusNotOkError: If loading a pipeline state fails with a code
      other than NOT_FOUND, or (with FAILED_PRECONDITION) if a pipeline's
      execution mode is neither SYNC nor ASYNC.
  """
  result = []

  contexts = pstate.get_orchestrator_contexts(mlmd_handle)

  for context in contexts:
    try:
      pipeline_state = pstate.PipelineState.load_from_orchestrator_context(
          mlmd_handle, context)
    except status_lib.StatusNotOkError as e:
      if e.code == status_lib.Code.NOT_FOUND:
        # Ignore old contexts with no associated active pipeline.
        continue
      # Any other error must propagate. Swallowing it would leave
      # `pipeline_state` unbound (NameError) or stale from the previous
      # iteration, which would append a duplicate detail.
      raise

    if pipeline_state.is_stop_initiated():
      # Stopping pipelines get no generator.
      generator = None
    else:
      pipeline = pipeline_state.pipeline
      if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle, pipeline, task_queue.contains_task_id)
      elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, pipeline, task_queue.contains_task_id)
      else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    result.append(
        _PipelineDetail(pipeline_state=pipeline_state, generator=generator))

  return result