Esempio n. 1
0
 def setUp(self):
     super().setUp()
     # Register the 'Transform' and 'Trainer' execution types, in that order,
     # in the store returned by self.get_metadata().
     with self.get_metadata() as m:
         for type_name in ('Transform', 'Trainer'):
             common_utils.register_type_if_not_exist(
                 m, metadata_store_pb2.ExecutionType(name=type_name))
Esempio n. 2
0
 def testRegisterTypeReuseExisting(self, metadata_type_class):
     # Registering a type that only repeats the name of an already-registered
     # type (declaring no properties) must hand back the original registration.
     with metadata.Metadata(connection_config=self._connection_config) as m:
         original_registration = common_utils.register_type_if_not_exist(
             m, _create_type(metadata_type_class))

         name_only_type = metadata_type_class()
         text_format.Parse("name: 'my_type'", name_only_type)
         second_registration = common_utils.register_type_if_not_exist(
             m, name_only_type)

         self.assertProtoEquals(original_registration, second_registration)
Esempio n. 3
0
    def testGetArtifactsDict(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:

            def make_shuffled_artifacts(artifact_cls, uri_pattern, count):
                # Creates `count` artifacts with registered type ids, then
                # shuffles them. The shuffled order is what
                # `execution_lib.get_artifacts_dict` is expected to retain.
                artifacts = []
                for idx in range(count):
                    artifact = artifact_cls()
                    artifact.uri = uri_pattern.format(idx)
                    artifact.type_id = common_utils.register_type_if_not_exist(
                        m, artifact.artifact_type).id
                    artifacts.append(artifact)
                random.shuffle(artifacts)
                return artifacts

            input_examples = make_shuffled_artifacts(
                standard_artifacts.Examples, 'example{}', 10)
            output_models = make_shuffled_artifacts(
                standard_artifacts.Model, 'model{}', 8)
            m.store.put_artifacts(
                [a.mlmd_artifact for a in input_examples + output_models])

            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                state=metadata_store_pb2.Execution.RUNNING)
            contexts = self._generate_contexts(m)
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts={'examples': input_examples},
                output_artifacts={'model': output_models})

            # For each direction, exactly one key comes back and its artifacts
            # appear in the same (shuffled) order they were recorded in.
            for event_type, key, expected_artifacts in (
                    (metadata_store_pb2.Event.INPUT, 'examples', input_examples),
                    (metadata_store_pb2.Event.OUTPUT, 'model', output_models)):
                artifacts_dict = execution_lib.get_artifacts_dict(
                    m, execution.id, event_type)
                self.assertCountEqual([key], list(artifacts_dict.keys()))
                self.assertEqual([a.uri for a in expected_artifacts],
                                 [a.uri for a in artifacts_dict[key]])
Esempio n. 4
0
def prepare_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    state: metadata_store_pb2.Execution.State,
    exec_properties: Optional[Mapping[Text, types.Property]] = None,
) -> metadata_store_pb2.Execution:
    """Builds an Execution proto from a type, a state and optional properties.

    The execution type is passed through
    `common_utils.register_type_if_not_exist` and the returned id is attached.
    The execution itself is not written to MLMD here.

    A property is stored in `execution.properties` when `execution_type`
    declares its key with the same value type as the value supplied. Any other
    property is stored in `execution.custom_properties`.

    Args:
      metadata_handler: A handler to access MLMD store.
      execution_type: A metadata_store_pb2.ExecutionType message describing the
        type of the execution.
      state: The state of the execution.
      exec_properties: Execution properties that need to be attached.

    Returns:
      A metadata_store_pb2.Execution message.
    """
    execution = metadata_store_pb2.Execution()
    execution.last_known_state = state
    registered_type = common_utils.register_type_if_not_exist(
        metadata_handler, execution_type)
    execution.type_id = registered_type.id

    for key, value in (exec_properties or {}).items():
        matches_schema = (execution_type.properties.get(key) ==
                          common_utils.get_metadata_value_type(value))
        # Only the chosen map is indexed, so no empty entry is created in the
        # other one.
        if matches_schema:
            target = execution.properties[key]
        else:
            target = execution.custom_properties[key]
        common_utils.set_metadata_value(target, value)
    logging.debug('Prepared EXECUTION:\n %s', execution)
    return execution
Esempio n. 5
0
    def testPutExecutionGraph(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # The input artifact has to exist in MLMD before put_execution().
            input_example = standard_artifacts.Examples()
            input_example.uri = 'example'
            input_example.type_id = common_utils.register_type_if_not_exist(
                m, input_example.artifact_type).id
            [input_example.id] = m.store.put_artifacts(
                [input_example.mlmd_artifact])
            # The output artifact is not put beforehand; the assertions below
            # expect put_execution() to persist it.
            output_model = standard_artifacts.Model()
            output_model.uri = 'model'

            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                exec_properties={'p1': 1, 'p2': '2'},
                state=metadata_store_pb2.Execution.COMPLETE)
            contexts = self._generate_contexts(m)
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts={'example': [input_example]},
                output_artifacts={'model': [output_model]})

            self.assertProtoPartiallyEquals(
                output_model.mlmd_artifact,
                m.store.get_artifacts_by_id([output_model.id])[0],
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])

            # Each artifact has exactly one event linking it to the execution,
            # with the expected direction.
            for artifact, expected_event_type in (
                    (input_example, metadata_store_pb2.Event.INPUT),
                    (output_model, metadata_store_pb2.Event.OUTPUT)):
                [event] = m.store.get_events_by_artifact_ids([artifact.id])
                self.assertEqual(event.execution_id, execution.id)
                self.assertEqual(event.type, expected_event_type)

            # Both artifacts and the execution are attributed to every context.
            expected_context_ids = [context.id for context in contexts]
            for lookup, entity_id in (
                    (m.store.get_contexts_by_artifact, input_example.id),
                    (m.store.get_contexts_by_artifact, output_model.id),
                    (m.store.get_contexts_by_execution, execution.id)):
                self.assertCountEqual([c.id for c in lookup(entity_id)],
                                      expected_context_ids)
Esempio n. 6
0
 def testRegisterTypeModifiedKey(self, metadata_type_class):
     # Re-registering 'my_type' with 'p1' redeclared under a different value
     # type must be rejected.
     with metadata.Metadata(connection_config=self._connection_config) as m:
         common_utils.register_type_if_not_exist(
             m, _create_type(metadata_type_class))

         conflicting_type = text_format.Parse(
             """
       name: 'my_type'
       properties {
         key: 'p1'
         value: STRING  # This is different from the original registered type
       }
       """, metadata_type_class())

         with self.assertRaisesRegex(RuntimeError, 'Missing or modified key'):
             common_utils.register_type_if_not_exist(m, conflicting_type)
Esempio n. 7
0
    def testGetArtifactIdsForExecutionIdGroupedByEventType(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Both artifacts get registered types and are stored in MLMD before
            # the execution is recorded.
            input_example = standard_artifacts.Examples()
            input_example.uri = 'example'
            input_example.type_id = common_utils.register_type_if_not_exist(
                m, input_example.artifact_type).id
            output_model = standard_artifacts.Model()
            output_model.uri = 'model'
            output_model.type_id = common_utils.register_type_if_not_exist(
                m, output_model.artifact_type).id
            input_example.id, output_model.id = m.store.put_artifacts(
                [input_example.mlmd_artifact, output_model.mlmd_artifact])

            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                exec_properties={'p1': 1, 'p2': '2'},
                state=metadata_store_pb2.Execution.COMPLETE)
            contexts = self._generate_contexts(m)
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts={'example': [input_example]},
                output_artifacts={'model': [output_model]})

            grouped_ids = (
                execution_lib.get_artifact_ids_by_event_type_for_execution_id(
                    m, execution.id))
            self.assertDictEqual(
                {
                    metadata_store_pb2.Event.INPUT: {input_example.id},
                    metadata_store_pb2.Event.OUTPUT: {output_model.id},
                }, grouped_ids)
Esempio n. 8
0
 def testRegisterType(self, metadata_type_class):
     # The first type registered in a new connection is expected to come back
     # with id 1 and both properties declared by _create_type().
     with metadata.Metadata(connection_config=self._connection_config) as m:
         registered = common_utils.register_type_if_not_exist(
             m, _create_type(metadata_type_class))
         self.assertProtoEquals(
             """
       id: 1
       name: 'my_type'
       properties {
         key: 'p1'
         value: INT
       }
       properties {
         key: 'p2'
         value: STRING
       }
       """, registered)
Esempio n. 9
0
def prepare_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    state: metadata_store_pb2.Execution.State,
    exec_properties: Optional[Mapping[str, types.ExecPropertyTypes]] = None,
) -> metadata_store_pb2.Execution:
    """Builds an Execution proto from a type, a state and optional properties.

    The execution type is passed through
    `common_utils.register_type_if_not_exist` and the returned id is attached.
    The execution itself is not written to MLMD here.

    Each property value is first encoded into a `pipeline_pb2.Value`.
    - If the encoded value carries a `schema`, that schema is serialized to
      JSON and stored in `custom_properties` under `get_schema_key(key)`.
    - The encoded `field_value` is copied into `properties` when
      `execution_type` declares the key with the same value type. Otherwise it
      is copied into `custom_properties`.

    Args:
      metadata_handler: A handler to access MLMD store.
      execution_type: A metadata_store_pb2.ExecutionType message describing the
        type of the execution.
      state: The state of the execution.
      exec_properties: Execution properties that need to be attached.

    Returns:
      A metadata_store_pb2.Execution message.
    """
    execution = metadata_store_pb2.Execution()
    execution.last_known_state = state
    registered_type = common_utils.register_type_if_not_exist(
        metadata_handler, execution_type)
    execution.type_id = registered_type.id

    for key, raw_value in (exec_properties or {}).items():
        encoded = data_types_utils.set_parameter_value(pipeline_pb2.Value(),
                                                       raw_value)

        if encoded.HasField('schema'):
            # Keep the schema of non-primitive values next to them so later
            # stages can parse the value back.
            data_types_utils.set_metadata_value(
                execution.custom_properties[get_schema_key(key)],
                proto_utils.proto_to_json(encoded.schema))

        declared_in_schema = (
            execution_type.properties.get(key) ==
            data_types_utils.get_metadata_value_type(raw_value))
        destination = (execution.properties
                       if declared_in_schema else execution.custom_properties)
        destination[key].CopyFrom(encoded.field_value)
    logging.debug('Prepared EXECUTION:\n %s', execution)
    return execution
Esempio n. 10
0
def _create_artifact_and_event_pairs(
    metadata_handler: metadata.Metadata,
    artifact_dict: MutableMapping[Text, Sequence[types.Artifact]],
    event_type: metadata_store_pb2.Event.Type,
) -> List[Tuple[metadata_store_pb2.Artifact, metadata_store_pb2.Event]]:
    """Creates a list of (Artifact, Event) tuples.

    The result of this function will be used in a MLMD put_execution() call.

    Args:
      metadata_handler: A handler to access MLMD store.
      artifact_dict: The source of artifacts to work on. One tuple is created
        per artifact. All artifacts under the same key are expected to share
        the same artifact type. That type is registered once, for the first
        artifact of the key, and reused for the remaining artifacts.
      event_type: The event type of the event to be attached to each artifact.

    Returns:
      A list of (Artifact, Event) tuples, in `artifact_dict` iteration order
      and then list order. As a side effect, each artifact gets its MLMD
      artifact type set via `set_mlmd_artifact_type`.

    Raises:
      AssertionError: If artifacts under the same key have different artifact
        type names (only checked when Python assertions are enabled).
    """
    result = []
    for key, artifact_list in artifact_dict.items():
        # Registered type for the current key, set by the key's first artifact.
        artifact_type = None
        for index, artifact in enumerate(artifact_list):
            # TODO(b/153904840): If artifact id is present, skip putting the artifact
            # into the pair when MLMD API is ready.
            event = event_lib.generate_event(event_type=event_type,
                                             key=key,
                                             index=index)
            if artifact_type is None:
                # Only the first artifact of a key pays for the MLMD type
                # registration round trip. Artifacts in the same list share the
                # same type, so later ones reuse the result.
                artifact_type = common_utils.register_type_if_not_exist(
                    metadata_handler, artifact.artifact_type)
            else:
                assert artifact_type.name == artifact.artifact_type.name, (
                    'Artifacts under the same key should share the same artifact type.'
                )
            artifact.set_mlmd_artifact_type(artifact_type)
            result.append((artifact.mlmd_artifact, event))
    return result
Esempio n. 11
0
def _generate_context_proto(
        metadata_handler: metadata.Metadata,
        context_spec: pipeline_pb2.ContextSpec) -> metadata_store_pb2.Context:
    """Builds a metadata_store_pb2.Context message from a ContextSpec.

    The context type from the spec is passed through
    `common_utils.register_type_if_not_exist`. The context itself is only
    constructed here, not written to MLMD.

    Spec properties whose key is declared by the registered context type go
    into `properties`, and their value type must match the declaration.
    Undeclared keys go into `custom_properties`.

    Args:
      metadata_handler: A handler to access MLMD store.
      context_spec: A pipeline_pb2.ContextSpec message that instructs registering
        of a context.

    Returns:
      A metadata_store_pb2.Context message.

    Raises:
      RuntimeError: When actual property type does not match provided metadata
        type schema.
    """
    registered_type = common_utils.register_type_if_not_exist(
        metadata_handler, context_spec.type)
    name = common_utils.get_value(context_spec.name)
    assert isinstance(name, Text), 'context name should be string.'
    context = metadata_store_pb2.Context(type_id=registered_type.id, name=name)

    for key, value in context_spec.properties.items():
        if key not in registered_type.properties:
            # Keys outside the registered schema are kept as custom properties.
            common_utils.set_metadata_value(context.custom_properties[key],
                                            value)
            continue

        actual_type = common_utils.get_metadata_value_type(value)
        declared_type = registered_type.properties.get(key)
        if declared_type != actual_type:
            raise RuntimeError(
                'Property type %s different from provided metadata type property type %s for key %s'
                % (actual_type, declared_type, key))
        common_utils.set_metadata_value(context.properties[key], value)
    return context
Esempio n. 12
0
    def testGetArtifactsDict(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:

            def make_shuffled_artifacts(artifact_cls, uri_prefix, count):
                # Creates `count` artifacts with registered type ids, then
                # shuffles them. The shuffled order is what
                # `execution_lib.get_artifacts_dict` is expected to retain.
                artifacts = []
                for idx in range(count):
                    artifact = artifact_cls()
                    artifact.uri = f'{uri_prefix}{idx}'
                    artifact.type_id = common_utils.register_type_if_not_exist(
                        m, artifact.artifact_type).id
                    artifacts.append(artifact)
                random.shuffle(artifacts)
                return artifacts

            input_artifact_keys = ('input1', 'input2', 'input3')
            input_artifacts_dict = collections.OrderedDict(
                (input_key,
                 make_shuffled_artifacts(standard_artifacts.Examples,
                                         f'{input_key}/example', 10))
                for input_key in input_artifact_keys)
            output_models = make_shuffled_artifacts(standard_artifacts.Model,
                                                    'model', 8)
            output_artifacts_dict = {'model': output_models}

            # Only the inputs are stored up front. put_execution() saves the
            # outputs.
            input_mlmd_artifacts = [
                a.mlmd_artifact
                for artifacts in input_artifacts_dict.values()
                for a in artifacts
            ]
            for mlmd_artifact, artifact_id in zip(
                    input_mlmd_artifacts,
                    m.store.put_artifacts(input_mlmd_artifacts)):
                mlmd_artifact.id = artifact_id

            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                state=metadata_store_pb2.Execution.RUNNING)
            contexts = self._generate_contexts(m)

            # Reorder the input keys so the key order differs from insertion.
            input_artifacts_dict.move_to_end('input1')
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts=input_artifacts_dict,
                output_artifacts=output_artifacts_dict)

            # Every key must come back with its artifacts in recorded order.
            input_result = execution_lib.get_artifacts_dict(
                m, execution.id, [metadata_store_pb2.Event.INPUT])
            self.assertEqual(set(input_artifact_keys), set(input_result))
            for key in input_result:
                self.assertEqual([ex.uri for ex in input_artifacts_dict[key]],
                                 [a.uri for a in input_result[key]],
                                 f'for key={key}')

            output_result = execution_lib.get_artifacts_dict(
                m, execution.id, [metadata_store_pb2.Event.OUTPUT])
            self.assertEqual({'model'}, set(output_result))
            self.assertEqual([model.uri for model in output_models],
                             [a.uri for a in output_result['model']])