Example #1
def register_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    contexts: Sequence[metadata_store_pb2.Context],
    input_artifacts: Optional[MutableMapping[str,
                                             Sequence[types.Artifact]]] = None,
    exec_properties: Optional[Mapping[str, types.Property]] = None,
) -> metadata_store_pb2.Execution:
    """Registers a new execution in MLMD.

  Along with the execution:
  -  the input artifacts will be linked to the execution.
  -  the contexts will be linked to both the execution and its input artifacts.

  Args:
    metadata_handler: A handler to access MLMD.
    execution_type: The type of the execution.
    contexts: MLMD contexts to associated with the execution.
    input_artifacts: Input artifacts of the execution. Each artifact will be
      linked with the execution through an event.
    exec_properties: Execution properties. Will be attached to the execution.

  Returns:
    An MLMD execution that is registered in MLMD, with id populated.
  """
    execution = execution_lib.prepare_execution(
        metadata_handler, execution_type, metadata_store_pb2.Execution.RUNNING,
        exec_properties)
    return execution_lib.put_execution(metadata_handler,
                                       execution,
                                       contexts,
                                       input_artifacts=input_artifacts)
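A minimal usage sketch (illustrative only): it assumes an open metadata.Metadata handle `m` and an already-registered `context`, both placeholder names, with `register_execution` imported from the module above.

# Illustrative usage; `m` and `context` are placeholder names.
execution = register_execution(
    metadata_handler=m,
    execution_type=metadata_store_pb2.ExecutionType(name='my_trainer'),
    contexts=[context],
    exec_properties={'num_steps': 100},
)
assert execution.id  # put_execution populates the id on success.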
Example #2
    def testPutExecutionGraph(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Prepares an input artifact. The artifact should be registered in MLMD
            # before the put_execution call.
            input_example = standard_artifacts.Examples()
            input_example.uri = 'example'
            input_example.type_id = common_utils.register_type_if_not_exist(
                m, input_example.artifact_type).id
            [input_example.id
             ] = m.store.put_artifacts([input_example.mlmd_artifact])
            # Prepares an output artifact.
            output_model = standard_artifacts.Model()
            output_model.uri = 'model'
            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                exec_properties={
                    'p1': 1,
                    'p2': '2'
                },
                state=metadata_store_pb2.Execution.COMPLETE)
            contexts = self._generate_contexts(m)
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts={'example': [input_example]},
                output_artifacts={'model': [output_model]})

            self.assertProtoPartiallyEquals(
                output_model.mlmd_artifact,
                m.store.get_artifacts_by_id([output_model.id])[0],
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
            # Verifies edges between artifacts and execution.
            [input_event
             ] = m.store.get_events_by_artifact_ids([input_example.id])
            self.assertEqual(input_event.execution_id, execution.id)
            self.assertEqual(input_event.type, metadata_store_pb2.Event.INPUT)
            [output_event
             ] = m.store.get_events_by_artifact_ids([output_model.id])
            self.assertEqual(output_event.execution_id, execution.id)
            self.assertEqual(output_event.type,
                             metadata_store_pb2.Event.OUTPUT)
            # Verifies edges connecting contexts and {artifacts, execution}.
            context_ids = [context.id for context in contexts]
            self.assertCountEqual([
                c.id
                for c in m.store.get_contexts_by_artifact(input_example.id)
            ], context_ids)
            self.assertCountEqual([
                c.id for c in m.store.get_contexts_by_artifact(output_model.id)
            ], context_ids)
            self.assertCountEqual([
                c.id for c in m.store.get_contexts_by_execution(execution.id)
            ], context_ids)
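Beyond tests, the same event edges drive lineage queries; a hedged sketch built only from the MetadataStore calls exercised above:

# Sketch: find the execution that produced an artifact by walking its
# OUTPUT event. Uses only MetadataStore methods shown in the test.
def producing_execution_id(store, artifact_id):
    for event in store.get_events_by_artifact_ids([artifact_id]):
        if event.type == metadata_store_pb2.Event.OUTPUT:
            return event.execution_id
    return None  # No recorded producer.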
Example #3
    def testGetExecutionsAssociatedWithAllContexts(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            self.assertLen(contexts, 2)

            # Create 2 executions and associate with one context each.
            execution1 = execution_lib.prepare_execution(
                m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
                metadata_store_pb2.Execution.RUNNING)
            execution1 = execution_lib.put_execution(m, execution1,
                                                     [contexts[0]])
            execution2 = execution_lib.prepare_execution(
                m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
                metadata_store_pb2.Execution.COMPLETE)
            execution2 = execution_lib.put_execution(m, execution2,
                                                     [contexts[1]])

            # Create another execution and associate with both contexts.
            execution3 = execution_lib.prepare_execution(
                m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
                metadata_store_pb2.Execution.NEW)
            execution3 = execution_lib.put_execution(m, execution3, contexts)

            # Verify that the right executions are returned.
            with self.subTest(for_contexts=(0, )):
                executions = execution_lib.get_executions_associated_with_all_contexts(
                    m, [contexts[0]])
                self.assertCountEqual([execution1.id, execution3.id],
                                      [e.id for e in executions])
            with self.subTest(for_contexts=(1, )):
                executions = execution_lib.get_executions_associated_with_all_contexts(
                    m, [contexts[1]])
                self.assertCountEqual([execution2.id, execution3.id],
                                      [e.id for e in executions])
            with self.subTest(for_contexts=(0, 1)):
                executions = execution_lib.get_executions_associated_with_all_contexts(
                    m, contexts)
                self.assertCountEqual([execution3.id],
                                      [e.id for e in executions])
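Semantically, `get_executions_associated_with_all_contexts` behaves like an intersection of per-context queries; a minimal illustrative equivalent (not the library's actual implementation) against the raw MetadataStore API:

# Illustrative only: the intersection semantics verified above.
def executions_in_all_contexts(store, contexts):
    id_sets = [
        {e.id for e in store.get_executions_by_context(c.id)}
        for c in contexts
    ]
    common = set.intersection(*id_sets) if id_sets else set()
    return store.get_executions_by_id(sorted(common))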
Example #4
    def new(cls, mlmd_handle: metadata.Metadata,
            pipeline: pipeline_pb2.Pipeline) -> 'PipelineState':
        """Creates a `PipelineState` object for a new pipeline.

    No active pipeline with the same pipeline uid should exist for the call to
    be successful.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: IR of the pipeline.

    Returns:
      A `PipelineState` object.

    Raises:
      status_lib.StatusNotOkError: If a pipeline with same UID already exists.
    """
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        context = context_lib.register_context_if_not_exists(
            mlmd_handle,
            context_type_name=_ORCHESTRATOR_RESERVED_ID,
            context_name=orchestrator_context_name(pipeline_uid))

        executions = mlmd_handle.store.get_executions_by_context(context.id)
        if any(e for e in executions if execution_lib.is_execution_active(e)):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.ALREADY_EXISTS,
                message=f'Pipeline with uid {pipeline_uid} already active.')

        execution = execution_lib.prepare_execution(
            mlmd_handle,
            _ORCHESTRATOR_EXECUTION_TYPE,
            metadata_store_pb2.Execution.NEW,
            exec_properties={
                _PIPELINE_IR:
                base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
            },
        )
        if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            data_types_utils.set_metadata_value(
                execution.custom_properties[_PIPELINE_RUN_ID],
                pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

        execution = execution_lib.put_execution(mlmd_handle, execution,
                                                [context])
        record_state_change_time()

        return cls(mlmd_handle=mlmd_handle,
                   pipeline=pipeline,
                   execution_id=execution.id)
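The `_PIPELINE_IR` property holds a base64-encoded serialized proto, so it round-trips with the inverse transform; a sketch (the helper name is hypothetical):

# Sketch: recover the pipeline IR written by `new()` above.
import base64
from tfx.proto.orchestration import pipeline_pb2

def decode_pipeline_ir(encoded: str) -> pipeline_pb2.Pipeline:
    pipeline = pipeline_pb2.Pipeline()
    pipeline.ParseFromString(base64.b64decode(encoded))
    return pipeline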
Example #5
    def testGetArtifactsDict(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Create and shuffle a few artifacts. The shuffled order should be
            # retained in the output of `execution_lib.get_artifacts_dict`.
            input_examples = []
            for i in range(10):
                input_example = standard_artifacts.Examples()
                input_example.uri = 'example{}'.format(i)
                input_example.type_id = common_utils.register_type_if_not_exist(
                    m, input_example.artifact_type).id
                input_examples.append(input_example)
            random.shuffle(input_examples)
            output_models = []
            for i in range(8):
                output_model = standard_artifacts.Model()
                output_model.uri = 'model{}'.format(i)
                output_model.type_id = common_utils.register_type_if_not_exist(
                    m, output_model.artifact_type).id
                output_models.append(output_model)
            random.shuffle(output_models)
            m.store.put_artifacts([
                a.mlmd_artifact
                for a in itertools.chain(input_examples, output_models)
            ])
            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                state=metadata_store_pb2.Execution.RUNNING)
            contexts = self._generate_contexts(m)
            input_artifacts_dict = {'examples': input_examples}
            output_artifacts_dict = {'model': output_models}
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts=input_artifacts_dict,
                output_artifacts=output_artifacts_dict)

            # Verify that the same artifacts are returned in the correct order.
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, metadata_store_pb2.Event.INPUT)
            self.assertCountEqual(['examples'], list(artifacts_dict.keys()))
            self.assertEqual([ex.uri for ex in input_examples],
                             [a.uri for a in artifacts_dict['examples']])
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, metadata_store_pb2.Event.OUTPUT)
            self.assertCountEqual(['model'], list(artifacts_dict.keys()))
            self.assertEqual([model.uri for model in output_models],
                             [a.uri for a in artifacts_dict['model']])
Example #6
def initiate_pipeline_start(
        mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline) -> metadata_store_pb2.Execution:
    """Initiates a pipeline start operation.

  Upon success, MLMD is updated to signal that the given pipeline must be
  started.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline to start.

  Returns:
    The pipeline-level MLMD execution proto upon success.

  Raises:
    status_lib.StatusNotOkError: Failure to initiate pipeline start or if
      execution is not inactive after waiting `timeout_secs`.
  """
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    context = context_lib.register_context_if_not_exists(
        mlmd_handle,
        context_type_name=_ORCHESTRATOR_RESERVED_ID,
        context_name=_orchestrator_context_name(pipeline_uid))

    executions = mlmd_handle.store.get_executions_by_context(context.id)
    if any(e for e in executions if execution_lib.is_execution_active(e)):
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.ALREADY_EXISTS,
            message=f'Pipeline with uid {pipeline_uid} already started.')

    execution = execution_lib.prepare_execution(
        mlmd_handle,
        _ORCHESTRATOR_EXECUTION_TYPE,
        metadata_store_pb2.Execution.NEW,
        exec_properties={
            _PIPELINE_IR:
            base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
        })
    execution = execution_lib.put_execution(mlmd_handle, execution, [context])
    logging.info('Registered execution (id: %s) for the pipeline with uid: %s',
                 execution.id, pipeline_uid)
    return execution
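Callers typically handle the ALREADY_EXISTS failure explicitly; a hedged sketch of that pattern (it assumes `StatusNotOkError` exposes the `code` it was constructed with):

# Sketch: defensive pipeline start. Illustrative only.
try:
    execution = initiate_pipeline_start(mlmd_handle, pipeline)
except status_lib.StatusNotOkError as e:
    if e.code == status_lib.Code.ALREADY_EXISTS:
        logging.info('Pipeline already started; nothing to do.')
    else:
        raise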
Example #7
    def testGetArtifactIdsForExecutionIdGroupedByEventType(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Register an input and output artifacts in MLMD.
            input_example = standard_artifacts.Examples()
            input_example.uri = 'example'
            input_example.type_id = common_utils.register_type_if_not_exist(
                m, input_example.artifact_type).id
            output_model = standard_artifacts.Model()
            output_model.uri = 'model'
            output_model.type_id = common_utils.register_type_if_not_exist(
                m, output_model.artifact_type).id
            [input_example.id, output_model.id] = m.store.put_artifacts(
                [input_example.mlmd_artifact, output_model.mlmd_artifact])
            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                exec_properties={
                    'p1': 1,
                    'p2': '2'
                },
                state=metadata_store_pb2.Execution.COMPLETE)
            contexts = self._generate_contexts(m)
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts={'example': [input_example]},
                output_artifacts={'model': [output_model]})

            artifact_ids_by_event_type = (
                execution_lib.get_artifact_ids_by_event_type_for_execution_id(
                    m, execution.id))
            self.assertDictEqual(
                {
                    metadata_store_pb2.Event.INPUT: set([input_example.id]),
                    metadata_store_pb2.Event.OUTPUT: set([output_model.id]),
                }, artifact_ids_by_event_type)
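An illustrative equivalent of the grouping checked above, assembled from a raw MetadataStore call rather than the library function:

# Illustrative only: group one execution's artifact ids by event type.
import collections

def artifact_ids_by_event_type(store, execution_id):
    grouped = collections.defaultdict(set)
    for event in store.get_events_by_execution_ids([execution_id]):
        grouped[event.type].add(event.artifact_id)
    return dict(grouped)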
Example #8
    def testPrepareExecution(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            execution_type = metadata_store_pb2.ExecutionType()
            text_format.Parse(
                """
                name: 'my_execution'
                properties {
                  key: 'p2'
                  value: STRING
                }
                """, execution_type)
            result = execution_lib.prepare_execution(
                m,
                execution_type,
                exec_properties={
                    'p1': 1,
                    'p2': '2'
                },
                state=metadata_store_pb2.Execution.COMPLETE)
            self.assertProtoEquals(
                """
                type_id: 1
                last_known_state: COMPLETE
                properties {
                  key: 'p2'
                  value {
                    string_value: '2'
                  }
                }
                custom_properties {
                  key: 'p1'
                  value {
                    int_value: 1
                  }
                }
                """, result)
Example #9
    def new(
        cls,
        mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline,
        pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
    ) -> 'PipelineState':
        """Creates a `PipelineState` object for a new pipeline.

    No active pipeline with the same pipeline uid should exist for the call to
    be successful.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: IR of the pipeline.
      pipeline_run_metadata: Pipeline run metadata.

    Returns:
      A `PipelineState` object.

    Raises:
      status_lib.StatusNotOkError: If a pipeline with same UID already exists.
    """
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        context = context_lib.register_context_if_not_exists(
            mlmd_handle,
            context_type_name=_ORCHESTRATOR_RESERVED_ID,
            context_name=orchestrator_context_name(pipeline_uid))

        executions = mlmd_handle.store.get_executions_by_context(context.id)
        if any(e for e in executions if execution_lib.is_execution_active(e)):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.ALREADY_EXISTS,
                message=f'Pipeline with uid {pipeline_uid} already active.')

        exec_properties = {_PIPELINE_IR: _base64_encode(pipeline)}
        if pipeline_run_metadata:
            exec_properties[_PIPELINE_RUN_METADATA] = json_utils.dumps(
                pipeline_run_metadata)

        execution = execution_lib.prepare_execution(
            mlmd_handle,
            _ORCHESTRATOR_EXECUTION_TYPE,
            metadata_store_pb2.Execution.NEW,
            exec_properties=exec_properties)
        if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            data_types_utils.set_metadata_value(
                execution.custom_properties[_PIPELINE_RUN_ID],
                pipeline.runtime_spec.pipeline_run_id.field_value.string_value)
            # Set the node state to COMPLETE for any nodes that are marked to be
            # skipped in a partial pipeline run.
            node_states_dict = {}
            for node in get_all_pipeline_nodes(pipeline):
                if node.execution_options.HasField('skip'):
                    logging.info('Node %s is skipped in this partial run.',
                                 node.node_info.id)
                    node_states_dict[node.node_info.id] = NodeState(
                        state=NodeState.COMPLETE)
            if node_states_dict:
                _save_node_states_dict(execution, node_states_dict)

        execution = execution_lib.put_execution(mlmd_handle, execution,
                                                [context])
        record_state_change_time()

        return cls(mlmd_handle=mlmd_handle,
                   pipeline=pipeline,
                   execution_id=execution.id)
Example #10
    def testPrepareExecution(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            execution_type = metadata_store_pb2.ExecutionType()
            text_format.Parse(
                """
                name: 'my_execution'
                properties {
                  key: 'p2'
                  value: STRING
                }
                """, execution_type)
            result = execution_lib.prepare_execution(
                m,
                execution_type,
                exec_properties={
                    'p1': 1,
                    'p2': '2',
                    'p3': True,
                    'p4': ['24', '56']
                },
                state=metadata_store_pb2.Execution.COMPLETE)
            result.ClearField('type_id')
            self.assertProtoEquals(
                """
                last_known_state: COMPLETE
                properties {
                  key: 'p2'
                  value {
                    string_value: '2'
                  }
                }
                custom_properties {
                  key: 'p1'
                  value {
                    int_value: 1
                  }
                }
                custom_properties {
                  key: 'p3'
                  value {
                    string_value: 'true'
                  }
                }
                custom_properties {
                  key: '__schema__p3__'
                  value {
                    string_value: '{\\n  \\"value_type\\": {\\n    \\"boolean_type\\": {}\\n  }\\n}'
                  }
                }
                custom_properties {
                  key: 'p4'
                  value {
                    string_value: '["24", "56"]'
                  }
                }
                custom_properties {
                  key: '__schema__p4__'
                  value {
                    string_value: '{\\n  \\"value_type\\": {\\n    \\"list_type\\": {}\\n  }\\n}'
                  }
                }
                """, result)
Example #11
    def testGetArtifactsDict(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Create and shuffle a few artifacts. The shuffled order should be
            # retained in the output of `execution_lib.get_artifacts_dict`.
            input_artifact_keys = ('input1', 'input2', 'input3')
            input_artifacts_dict = collections.OrderedDict()
            for input_key in input_artifact_keys:
                input_examples = []
                for i in range(10):
                    input_example = standard_artifacts.Examples()
                    input_example.uri = f'{input_key}/example{i}'
                    input_example.type_id = common_utils.register_type_if_not_exist(
                        m, input_example.artifact_type).id
                    input_examples.append(input_example)
                random.shuffle(input_examples)
                input_artifacts_dict[input_key] = input_examples

            output_models = []
            for i in range(8):
                output_model = standard_artifacts.Model()
                output_model.uri = f'model{i}'
                output_model.type_id = common_utils.register_type_if_not_exist(
                    m, output_model.artifact_type).id
                output_models.append(output_model)
            random.shuffle(output_models)
            output_artifacts_dict = {'model': output_models}

            # Store input artifacts only. Outputs will be saved in put_execution().
            input_mlmd_artifacts = [
                a.mlmd_artifact
                for a in itertools.chain(*input_artifacts_dict.values())
            ]
            artifact_ids = m.store.put_artifacts(input_mlmd_artifacts)
            for artifact_id, mlmd_artifact in zip(artifact_ids,
                                                  input_mlmd_artifacts):
                mlmd_artifact.id = artifact_id

            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                state=metadata_store_pb2.Execution.RUNNING)
            contexts = self._generate_contexts(m)

            # Change the order of the OrderedDict to shuffle the order of input keys.
            input_artifacts_dict.move_to_end('input1')
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts=input_artifacts_dict,
                output_artifacts=output_artifacts_dict)

            # Verify that the same artifacts are returned in the correct order.
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, [metadata_store_pb2.Event.INPUT])
            self.assertEqual(set(input_artifact_keys),
                             set(artifacts_dict.keys()))
            for key in artifacts_dict:
                self.assertEqual([ex.uri for ex in input_artifacts_dict[key]],
                                 [a.uri for a in artifacts_dict[key]],
                                 f'for key={key}')
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, [metadata_store_pb2.Event.OUTPUT])
            self.assertEqual({'model'}, set(artifacts_dict.keys()))
            self.assertEqual([model.uri for model in output_models],
                             [a.uri for a in artifacts_dict['model']])