Esempio n. 1
0
    def test_put_contexts_get_contexts(self):
        """Two contexts written with put_contexts come back from get_contexts.

        Checks the ids assigned by put_contexts, the names, and one property
        on each returned context (results are read positionally).
        """
        store = _get_metadata_store()
        type_id = store.put_context_type(_create_example_context_type())

        first = metadata_store_pb2.Context()
        first.type_id = type_id
        first.name = "context0"
        first.properties["bar"].string_value = "Hello"

        second = metadata_store_pb2.Context()
        second.name = "context1"
        second.type_id = type_id
        second.properties["foo"].int_value = -9

        first_id, second_id = store.put_contexts([first, second])

        fetched = store.get_contexts()
        self.assertLen(fetched, 2)
        got_first, got_second = fetched

        self.assertEqual(got_first.id, first_id)
        self.assertEqual(got_first.name, "context0")
        self.assertEqual(got_first.properties["bar"].string_value, "Hello")

        self.assertEqual(got_second.id, second_id)
        self.assertEqual(got_second.name, "context1")
        self.assertEqual(got_second.properties["foo"].int_value, -9)
Esempio n. 2
0
 def test_puts_contexts_duplicated_name_with_the_same_type(self):
   """Putting two contexts with the same type and name must fail.

   Only the put_contexts call is wrapped in assertRaises: if the fixture setup
   (e.g. put_context_type) raised AlreadyExistsError itself, the test would
   otherwise pass without ever exercising the duplicate-name check.
   """
   store = _get_metadata_store()
   context_type = _create_example_context_type()
   type_id = store.put_context_type(context_type)
   context_0 = metadata_store_pb2.Context()
   context_0.type_id = type_id
   context_0.name = "the_same_name"
   context_1 = metadata_store_pb2.Context()
   context_1.type_id = type_id
   context_1.name = "the_same_name"
   with self.assertRaises(errors.AlreadyExistsError):
     store.put_contexts([context_0, context_1])
Esempio n. 3
0
  def setUp(self):
    """Builds a TestMLMD store fixture and wraps it in an Analytics instance.

    Contexts created:
      * a pipeline context named PIPELINE_NAME;
      * two run contexts, RUN_ID (create_time_since_epoch=5) and RUN_ID_3
        (create_time_since_epoch=0);
      * a component context named PIPELINE_NAME + '.' + COMPONENT_NAME, whose
        id is kept in self.component_context_id.

    Executions created:
      * execution_id: associated with the pipeline, RUN_ID and component
        contexts; has one INPUT and one OUTPUT artifact event.
      * execution_id_2: associated with the pipeline and RUN_ID_3 contexts
        only (no component association, no events).
      * execution_id_3: no put_association call is made for it; its input and
        output artifacts are nevertheless attributed to the component context.

    Sets self.component_context_id, self.input_artifact_id,
    self.output_artifact_id and self.analytics.
    """
    super().setUp()
    tm = test_mlmd.TestMLMD(context_type='pipeline')
    pipeline_ctx_id = tm.put_context(PIPELINE_NAME)
    # NOTE(review): presumably switches the context type used by subsequent
    # tm.put_context calls — confirm against test_mlmd.TestMLMD.
    tm.update_context_type(mlmd_analytics._IR_RUN_CONTEXT_NAME)
    run_context = metadata_store_pb2.Context(
        name=RUN_ID, create_time_since_epoch=5)
    run_context_2 = metadata_store_pb2.Context(
        name=RUN_ID_3, create_time_since_epoch=0)
    run_context_id = tm.put_context(context=run_context)
    run_context_id_2 = tm.put_context(context=run_context_2)
    tm.update_context_type(mlmd_analytics._IR_COMPONENT_NAME)
    self.component_context_id = tm.put_context(PIPELINE_NAME + '.' +
                                               COMPONENT_NAME)
    execution_id = tm.put_execution(RUN_ID,
                                    PIPELINE_NAME + '.' + COMPONENT_NAME)
    execution_id_2 = tm.put_execution(RUN_ID_3,
                                      PIPELINE_NAME + '.' + COMPONENT_NAME)

    # First execution: linked to pipeline, run and component contexts.
    tm.put_association(pipeline_ctx_id, execution_id)
    tm.put_association(run_context_id, execution_id)
    tm.put_association(self.component_context_id, execution_id)

    # Second execution: linked to pipeline and second run context only.
    tm.put_association(pipeline_ctx_id, execution_id_2)
    tm.put_association(run_context_id_2, execution_id_2)

    # Artifacts consumed/produced by the first execution, attributed to the
    # component context.
    self.input_artifact_id = tm.put_artifact({'name': INPUT_ARTIFACT_NAME})
    self.output_artifact_id = tm.put_artifact({'name': OUTPUT_ARTIFACT_NAME})
    tm.put_attribution(self.component_context_id, self.input_artifact_id)
    tm.put_attribution(self.component_context_id, self.output_artifact_id)
    tm.put_event(self.input_artifact_id, execution_id,
                 metadata_store_pb2.Event.INPUT)
    tm.put_event(self.output_artifact_id, execution_id,
                 metadata_store_pb2.Event.OUTPUT)

    # Attributed artifacts with unassociated executions. This is an anomaly
    # found in MLMD stores created by IR-Based orchestrators.
    execution_id_3 = tm.put_execution(RUN_ID_2, COMPONENT_NAME_2)
    input_artifact_id_2 = tm.put_artifact({'name': INPUT_ARTIFACT_NAME_2})
    output_artifact_id_2 = tm.put_artifact({'name': OUTPUT_ARTIFACT_NAME_2})
    tm.put_attribution(self.component_context_id, input_artifact_id_2)
    tm.put_attribution(self.component_context_id, output_artifact_id_2)
    tm.put_event(input_artifact_id_2, execution_id_3,
                 metadata_store_pb2.Event.INPUT)
    tm.put_event(output_artifact_id_2, execution_id_3,
                 metadata_store_pb2.Event.OUTPUT)

    self.analytics = mlmd_analytics.Analytics(store=tm.store)
Esempio n. 4
0
  def testPreExecutionCached(self, mock_verify_input_artifacts_fn):
    """pre_execution reuses cached results when get_cached_outputs has some.

    The mocked metadata handler resolves the inputs, registers a run context
    and an execution, and reports self._output_artifacts as cached outputs;
    the resulting decision must point at those outputs.
    """
    metadata = self._mock_metadata
    inputs = self._input_dict

    metadata.search_artifacts.return_value = list(inputs['input_string'].get())
    resolved_inputs = list(inputs['input_data'].get())
    resolved_inputs.extend(inputs['input_string'].get())
    metadata.get_artifacts_by_info.side_effect = resolved_inputs
    metadata.register_run_context_if_not_exists.side_effect = [
        metadata_store_pb2.Context()]
    metadata.register_execution.side_effect = [self._execution]
    metadata.get_cached_outputs.side_effect = [self._output_artifacts]

    driver = base_driver.BaseDriver(metadata_handler=metadata)
    decision = driver.pre_execution(
        input_dict=inputs,
        output_dict=self._output_dict,
        exec_properties=self._exec_properties,
        driver_args=self._driver_args,
        pipeline_info=self._pipeline_info,
        component_info=self._component_info)

    self.assertTrue(decision.use_cached_results)
    self.assertEqual(decision.execution_id, self._execution_id)
    self.assertCountEqual(decision.exec_properties, self._exec_properties)
    self.assertCountEqual(decision.output_dict, self._output_artifacts)
Esempio n. 5
0
    def test_put_contexts_get_context_by_type_and_name(self):
        """get_context_by_type_and_name finds a stored context by type+name and
        returns None when the type name, the context name, or both are unknown.
        """
        # Prepare test data.
        store = _get_metadata_store()
        context_type = _create_example_context_type(self._get_test_type_name())
        type_id = store.put_context_type(context_type)
        context = metadata_store_pb2.Context()
        context.type_id = type_id
        context.name = self._get_test_type_name()
        [context_id] = store.put_contexts([context])

        # Test Context found case.
        got_context = store.get_context_by_type_and_name(
            context_type.name, context.name)
        self.assertEqual(got_context.id, context_id)
        self.assertEqual(got_context.type_id, type_id)
        self.assertEqual(got_context.name, context.name)

        # Test Context not found cases: unknown type, unknown name, both.
        not_found_queries = [
            ("random_name", context.name),
            (context_type.name, "random_name"),
            ("random_name", "random_name"),
        ]
        for type_name, context_name in not_found_queries:
            # subTest reports which (type, name) pair failed.
            with self.subTest(type_name=type_name, context_name=context_name):
                self.assertIsNone(
                    store.get_context_by_type_and_name(type_name,
                                                       context_name))
Esempio n. 6
0
    def _register_run_context(self,
                              pipeline_info: data_types.PipelineInfo) -> int:
        """Creates a new context in metadata for the current pipeline run.

        The run context type is looked up first; it is registered (with
        string properties `pipeline_name` and `run_id`) only when the lookup
        raises NotFoundError.

        Args:
          pipeline_info: pipeline information for current run.

        Returns:
          The id assigned to the newly stored context.
        """
        try:
            run_type = self._store.get_context_type(_CONTEXT_TYPE_RUN)
        except tf.errors.NotFoundError:
            run_type = metadata_store_pb2.ContextType(name=_CONTEXT_TYPE_RUN)
            run_type.properties['pipeline_name'] = metadata_store_pb2.STRING
            run_type.properties['run_id'] = metadata_store_pb2.STRING
            # TODO(b/139485894): add DAG as properties.
            run_type_id = self._store.put_context_type(run_type)
        else:
            assert run_type, 'Context type is None for %s.' % (
                _CONTEXT_TYPE_RUN)
            run_type_id = run_type.id

        run_context = metadata_store_pb2.Context(
            type_id=run_type_id, name=pipeline_info.run_context_name)
        run_props = run_context.properties
        run_props['pipeline_name'].string_value = pipeline_info.pipeline_name
        run_props['run_id'].string_value = pipeline_info.run_id
        [new_context_id] = self._store.put_contexts([run_context])
        return new_context_id
Esempio n. 7
0
  def test_put_duplicated_attributions_and_empty_associations(self):
    """A repeated attribution yields a single link, and no associations.

    The same attribution is passed twice with an empty association list; the
    context/artifact must then be linked exactly once and the context must
    have no executions.
    """
    store = _get_metadata_store()

    ctx_type_id = store.put_context_type(_create_example_context_type())
    want_context = metadata_store_pb2.Context(type_id=ctx_type_id,
                                              name="context")
    [ctx_id] = store.put_contexts([want_context])
    want_context.id = ctx_id

    art_type_id = store.put_artifact_type(_create_example_artifact_type())
    want_artifact = metadata_store_pb2.Artifact(type_id=art_type_id,
                                                uri="testuri")
    [art_id] = store.put_artifacts([want_artifact])
    want_artifact.id = art_id

    link = metadata_store_pb2.Attribution(artifact_id=want_artifact.id,
                                          context_id=want_context.id)
    store.put_attributions_and_associations([link, link], [])

    contexts_for_artifact = store.get_contexts_by_artifact(want_artifact.id)
    self.assertLen(contexts_for_artifact, 1)
    self.assertEqual(contexts_for_artifact[0].id, want_context.id)
    self.assertEqual(contexts_for_artifact[0].name, want_context.name)

    artifacts_for_context = store.get_artifacts_by_context(want_context.id)
    self.assertLen(artifacts_for_context, 1)
    self.assertEqual(artifacts_for_context[0].uri, want_artifact.uri)

    self.assertEmpty(store.get_executions_by_context(want_context.id))
Esempio n. 8
0
    def _get_context_id(self, reuse_workspace_if_exists):
        """Returns the id of the MLMD context backing this workspace.

        If a context for the workspace already exists, its id is returned when
        `reuse_workspace_if_exists` is true; otherwise ValueError is raised.
        When none exists, the workspace context type is put (which creates it
        or yields the existing type id) and a new context carrying the
        optional description and JSON-encoded labels is stored. Both store
        calls go through _retry.
        """
        existing = self._get_existing_context()
        if existing is not None:
            if not reuse_workspace_if_exists:
                raise ValueError(
                    'Workspace name {} already exists with id {}. You can initialize workspace with reuse_workspace_if_exists=True if want to reuse it'
                    .format(self.name, existing.id))
            return existing.id

        workspace_type = mlpb.ContextType(
            name=self.CONTEXT_TYPE_NAME,
            properties={"description": mlpb.STRING, "labels": mlpb.STRING})
        workspace_type_id = _retry(
            lambda: self.store.put_context_type(workspace_type))

        # Only properties that were actually provided are set.
        workspace_props = {}
        if self.description is not None:
            workspace_props["description"] = mlpb.Value(
                string_value=self.description)
        if self.labels is not None:
            workspace_props["labels"] = mlpb.Value(
                string_value=json.dumps(self.labels))

        workspace_ctx = mlpb.Context(type_id=workspace_type_id,
                                     name=self.name,
                                     properties=workspace_props)
        return _retry(lambda: self.store.put_contexts([workspace_ctx])[0])
Esempio n. 9
0
    def testPreExecutionNewExecution(self, mock_verify_input_artifacts_fn):
        """pre_execution does not reuse cache when no cached outputs exist.

        get_cached_outputs is mocked to return None; the decision must not use
        cached results, and the output_data artifact URI must be
        <pipeline_root>/<component_id>/output_data/<execution_id>.
        """
        mock_md = self._mock_metadata
        mock_md.get_artifacts_by_info.side_effect = list(
            self._input_dict['input_data'].get())
        mock_md.register_execution.side_effect = [self._execution]
        mock_md.get_cached_outputs.side_effect = [None]
        mock_md.register_run_context_if_not_exists.side_effect = [
            metadata_store_pb2.Context()]

        decision = base_driver.BaseDriver(
            metadata_handler=mock_md).pre_execution(
                input_dict=self._input_dict,
                output_dict=self._output_dict,
                exec_properties=self._exec_properties,
                driver_args=self._driver_args,
                pipeline_info=self._pipeline_info,
                component_info=self._component_info)

        self.assertFalse(decision.use_cached_results)
        self.assertEqual(decision.execution_id, self._execution_id)
        self.assertCountEqual(decision.exec_properties, self._exec_properties)
        actual_uri = decision.output_dict['output_data'][0].uri
        expected_uri = os.path.join(self._pipeline_info.pipeline_root,
                                    self._component_info.component_id,
                                    'output_data', str(self._execution_id))
        self.assertEqual(actual_uri, expected_uri)
Esempio n. 10
0
    def _prepare_context(
        self,
        context_type_name: Text,
        context_name: Text,
        properties: Optional[Dict[Text, Union[int, float, Text]]] = None
    ) -> metadata_store_pb2.Context:
        """Prepares a context proto.

        Registers the context type (if it does not exist yet) with one property
        per entry of `properties`, typed after the Python value, and returns a
        Context of that type carrying the property values. The context itself
        is not written to the store.

        Args:
          context_type_name: name of the context type to use or register.
          context_name: name of the new context.
          properties: optional mapping from property name to an int, float or
            text value.

        Returns:
          A metadata_store_pb2.Context (not yet stored).

        Raises:
          RuntimeError: if a property value is not an int, float or text
            string. Raised before the context type is registered, so an
            invalid call has no side effect on the store.
        """
        # TODO(ruoyu): Centralize the type definition / mapping along with Artifact
        # property types.
        properties = properties or {}

        # Validate and type every value up front. Looking types up with
        # `mapping[type(v)]` used to raise KeyError for unsupported values
        # (making the RuntimeError below unreachable) and accepted bytes for
        # registration even though they could not be stored afterwards.
        property_types = {}
        for k, v in properties.items():
            # bool is an int subclass but has no MLMD counterpart.
            if isinstance(v, bool):
                raise RuntimeError('Unexpected property type: %s' % type(v))
            if isinstance(v, int):
                property_types[k] = metadata_store_pb2.INT
            elif isinstance(v, six.string_types):
                property_types[k] = metadata_store_pb2.STRING
            elif isinstance(v, float):
                property_types[k] = metadata_store_pb2.DOUBLE
            else:
                raise RuntimeError('Unexpected property type: %s' % type(v))

        context_type_id = self._register_context_type_if_not_exist(
            context_type_name, property_types)

        context = metadata_store_pb2.Context(type_id=context_type_id,
                                             name=context_name)
        for k, v in properties.items():
            if isinstance(v, int):
                context.properties[k].int_value = v
            elif isinstance(v, float):
                context.properties[k].double_value = v
            else:  # Text; every other type was rejected above.
                context.properties[k].string_value = v
        return context
Esempio n. 11
0
 def test_puts_contexts_empty_name(self):
   """put_contexts rejects a context whose name is left unset.

   Only the put_contexts call is wrapped in assertRaises, so an
   InvalidArgumentError raised while setting up the context type cannot
   satisfy the assertion by accident.
   """
   store = _get_metadata_store()
   context_type = _create_example_context_type()
   type_id = store.put_context_type(context_type)
   context_0 = metadata_store_pb2.Context()
   context_0.type_id = type_id
   with self.assertRaises(errors.InvalidArgumentError):
     store.put_contexts([context_0])
Esempio n. 12
0
  def test_put_contexts_get_contexts_by_type(self):
    """get_contexts_by_type returns only contexts of the requested type.

    Two contexts share the name "context_name" but have different types;
    querying the second type must return just the second context.
    """
    store = _get_metadata_store()
    first_type_id = store.put_context_type(_create_example_context_type())
    second_type = _create_example_context_type_2()
    second_type_id = store.put_context_type(second_type)

    same_named = [
        metadata_store_pb2.Context(type_id=tid, name="context_name")
        for tid in (first_type_id, second_type_id)
    ]
    _, second_context_id = store.put_contexts(same_named)

    by_type = store.get_contexts_by_type(second_type.name)
    self.assertLen(by_type, 1)
    self.assertEqual(by_type[0].id, second_context_id)
Esempio n. 13
0
    def test_put_execution_with_context(self):
        """put_execution writes an execution, artifacts, events and a context in
        one call and links them together.

        Checks that both artifacts and the single event are written, that one
        context is created, that the context is associated with the execution,
        and that it is attributed to both artifacts.
        """
        store = _get_metadata_store()
        execution_type = metadata_store_pb2.ExecutionType()
        execution_type.name = self._get_test_type_name()
        execution_type_id = store.put_execution_type(execution_type)
        execution = metadata_store_pb2.Execution()
        execution.type_id = execution_type_id

        artifact_type = metadata_store_pb2.ArtifactType()
        artifact_type.name = self._get_test_type_name()
        artifact_type_id = store.put_artifact_type(artifact_type)
        input_artifact = metadata_store_pb2.Artifact()
        input_artifact.type_id = artifact_type_id
        output_artifact = metadata_store_pb2.Artifact()
        output_artifact.type_id = artifact_type_id
        output_event = metadata_store_pb2.Event()
        # NOTE(review): despite its name this event is typed DECLARED_INPUT;
        # the assertions below only count events, so the type is not checked.
        output_event.type = metadata_store_pb2.Event.DECLARED_INPUT

        context_type = metadata_store_pb2.ContextType()
        context_type.name = self._get_test_type_name()
        context_type_id = store.put_context_type(context_type)
        context = metadata_store_pb2.Context()
        context.type_id = context_type_id
        context_name = self._get_test_type_name()
        context.name = context_name

        # Each artifact entry appears to be [artifact] or [artifact, event];
        # only the output artifact carries an event, hence one event below.
        execution_id, artifact_ids, context_ids = store.put_execution(
            execution, [[input_artifact], [output_artifact, output_event]],
            [context])

        # Test artifacts & events are correctly inserted.
        self.assertLen(artifact_ids, 2)
        events = store.get_events_by_execution_ids([execution_id])
        self.assertLen(events, 1)

        # Test the context is correctly inserted.
        got_contexts = store.get_contexts_by_id(context_ids)
        self.assertLen(context_ids, 1)
        self.assertLen(got_contexts, 1)

        # Test the association link between execution and the context is correct.
        contexts_by_execution_id = store.get_contexts_by_execution(
            execution_id)
        self.assertLen(contexts_by_execution_id, 1)
        self.assertEqual(contexts_by_execution_id[0].name, context_name)
        self.assertEqual(contexts_by_execution_id[0].type_id, context_type_id)
        executions_by_context = store.get_executions_by_context(context_ids[0])
        self.assertLen(executions_by_context, 1)

        # Test the attribution links between artifacts and the context are correct.
        contexts_by_artifact_id = store.get_contexts_by_artifact(
            artifact_ids[0])
        self.assertLen(contexts_by_artifact_id, 1)
        self.assertEqual(contexts_by_artifact_id[0].name, context_name)
        self.assertEqual(contexts_by_artifact_id[0].type_id, context_type_id)
        artifacts_by_context = store.get_artifacts_by_context(context_ids[0])
        self.assertLen(artifacts_by_context, 2)
Esempio n. 14
0
 def _create_context_with_type(self, context_name: str, type_name: str,
                               property_types: dict = None,
                               properties: dict = None):
     """Stores a new context of type `type_name` and returns it.

     The context type is obtained from _get_or_create_context_type using
     `property_types`; the context is written with put_contexts and returned
     with the id the store assigned.
     """
     ctx_type = self._get_or_create_context_type(
         type_name=type_name, properties=property_types)
     new_context = metadata_store_pb2.Context(
         name=context_name, type_id=ctx_type.id, properties=properties)
     stored_ids = self.store.put_contexts([new_context])
     new_context.id = stored_ids[0]
     return new_context
Esempio n. 15
0
    def test_put_contexts_get_contexts(self):
        """put_contexts adds exactly two contexts that get_contexts returns.

        The store may already contain contexts, so the test checks the count
        delta and then looks up the two new contexts by id.
        """
        store = _get_metadata_store()
        context_type = _create_example_context_type(self._get_test_type_name())
        type_id = store.put_context_type(context_type)
        context_0 = metadata_store_pb2.Context()
        context_0.type_id = type_id
        context_0_name = self._get_test_type_name()
        context_0.name = context_0_name
        context_0.properties["bar"].string_value = "Hello"
        context_1 = metadata_store_pb2.Context()
        context_1_name = self._get_test_type_name()
        context_1.name = context_1_name
        context_1.type_id = type_id
        context_1.properties["foo"].int_value = -9

        try:
            existing_contexts_count = len(store.get_contexts())
        except errors.NotFoundError:
            # Presumably raised when the store holds no contexts yet.
            existing_contexts_count = 0
        [context_id_0,
         context_id_1] = store.put_contexts([context_0, context_1])

        all_contexts = store.get_contexts()
        self.assertEqual(existing_contexts_count + 2, len(all_contexts))

        # Do not rely on the order of get_contexts results; index the new
        # contexts by id, and fail with a clear message if either is missing.
        new_ids = (context_id_0, context_id_1)
        by_id = {c.id: c for c in all_contexts if c.id in new_ids}
        self.assertCountEqual(by_id, new_ids)
        context_result_0 = by_id[context_id_0]
        context_result_1 = by_id[context_id_1]

        self.assertEqual(context_result_0.name, context_0_name)
        self.assertEqual(context_result_0.properties["bar"].string_value,
                         "Hello")
        self.assertEqual(context_result_1.name, context_1_name)
        self.assertEqual(context_result_1.properties["foo"].int_value, -9)
Esempio n. 16
0
    def _register_context_if_not_exist(
        self, context_type_name: Text, context_name: Text,
        properties: Dict[Text, Union[int, float, Text]]
    ) -> metadata_store_pb2.Context:
        """Registers a context if not exist, otherwise returns the existing one.

        Args:
          context_type_name: the name of the context type desired.
          context_name: the name of the context.
          properties: properties to set in the context; values must be int,
            float or text.

        Returns:
          The newly registered or pre-existing context, with its id populated.

        Raises:
          RuntimeError: when meeting unexpected property type (raised before
            anything is written to the store), or when put_contexts reports
            the context as already existing but it cannot be fetched back.
        """
        # TODO(ruoyu): Centralize the type definition / mapping along with Artifact
        # property types.
        #
        # Each property is typed by its VALUE. The previous lookup
        # `mapping[type(k)]` used the key's type, so every property was
        # registered as STRING; it also raised KeyError for unsupported values
        # instead of the documented RuntimeError.
        property_types = {}
        for k, v in properties.items():
            # bool is an int subclass but has no MLMD counterpart.
            if isinstance(v, bool):
                raise RuntimeError('Unexpected property type: %s' % type(v))
            if isinstance(v, int):
                property_types[k] = metadata_store_pb2.INT
            elif isinstance(v, six.string_types):
                property_types[k] = metadata_store_pb2.STRING
            elif isinstance(v, float):
                property_types[k] = metadata_store_pb2.DOUBLE
            else:
                raise RuntimeError('Unexpected property type: %s' % type(v))

        context_type_id = self._register_context_type_if_not_exist(
            context_type_name, property_types)

        context = metadata_store_pb2.Context(type_id=context_type_id,
                                             name=context_name)
        for k, v in properties.items():
            if isinstance(v, int):
                context.properties[k].int_value = v
            elif isinstance(v, float):
                context.properties[k].double_value = v
            else:  # Text; every other type was rejected above.
                context.properties[k].string_value = v

        try:
            [context_id] = self.store.put_contexts([context])
        except tf.errors.AlreadyExistsError:
            absl.logging.debug('Run context %s already exists.', context_name)
            context = self.store.get_context_by_type_and_name(
                context_type_name, context_name)
            # Explicit check rather than `assert`, which is stripped under -O.
            if context is None:
                raise RuntimeError('Run context is missing for %s.' %
                                   (context_name))
        else:
            context.id = context_id

        absl.logging.debug('ID of run context %s is %s.', context_name,
                           context.id)
        return context
Esempio n. 17
0
    def test_put_and_use_attributions_and_associations(self):
        """Attribution and association links are queryable in both directions.

        One context is attributed to an artifact and associated with an
        execution; each relationship is then read back from both ends.
        """
        store = _get_metadata_store()

        ctx_type_id = store.put_context_type(
            _create_example_context_type(self._get_test_type_name()))
        want_context = metadata_store_pb2.Context(
            type_id=ctx_type_id, name=self._get_test_type_name())
        [want_context.id] = store.put_contexts([want_context])

        exec_type_id = store.put_execution_type(
            _create_example_execution_type(self._get_test_type_name()))
        want_execution = metadata_store_pb2.Execution(type_id=exec_type_id)
        want_execution.properties["foo"].int_value = 3
        [want_execution.id] = store.put_executions([want_execution])

        art_type_id = store.put_artifact_type(
            _create_example_artifact_type(self._get_test_type_name()))
        want_artifact = metadata_store_pb2.Artifact(type_id=art_type_id,
                                                    uri="testuri")
        [want_artifact.id] = store.put_artifacts([want_artifact])

        # insert attribution and association and test querying the relationship
        store.put_attributions_and_associations(
            [metadata_store_pb2.Attribution(artifact_id=want_artifact.id,
                                            context_id=want_context.id)],
            [metadata_store_pb2.Association(execution_id=want_execution.id,
                                            context_id=want_context.id)])

        def check_single_want_context(found):
            self.assertLen(found, 1)
            self.assertEqual(found[0].id, want_context.id)
            self.assertEqual(found[0].name, want_context.name)

        # test querying the relationship
        check_single_want_context(
            store.get_contexts_by_artifact(want_artifact.id))

        linked_artifacts = store.get_artifacts_by_context(want_context.id)
        self.assertLen(linked_artifacts, 1)
        self.assertEqual(linked_artifacts[0].uri, want_artifact.uri)

        linked_executions = store.get_executions_by_context(want_context.id)
        self.assertLen(linked_executions, 1)
        self.assertEqual(linked_executions[0].properties["foo"],
                         want_execution.properties["foo"])

        check_single_want_context(
            store.get_contexts_by_execution(want_execution.id))
Esempio n. 18
0
  def test_update_context_get_context(self):
    """Putting a context with an existing id updates that context in place.

    The second put keeps the same id and replaces the name and properties.
    """
    store = _get_metadata_store()
    type_id = store.put_context_type(_create_example_context_type())

    before = metadata_store_pb2.Context(type_id=type_id, name="context1")
    before.properties["bar"].string_value = "Hello"
    [context_id] = store.put_contexts([before])

    after = metadata_store_pb2.Context(id=context_id, name="context2",
                                       type_id=type_id)
    after.properties["foo"].int_value = 12
    after.properties["bar"].string_value = "Goodbye"
    [updated_id] = store.put_contexts([after])
    self.assertEqual(context_id, updated_id)

    [stored] = store.get_contexts_by_id([context_id])
    self.assertEqual(stored.name, after.name)
    self.assertEqual(stored.properties["bar"].string_value, "Goodbye")
    self.assertEqual(stored.properties["foo"].int_value, 12)
Esempio n. 19
0
 def test_put_contexts_get_contexts_by_id(self):
   """A context with typed and custom properties round-trips by id."""
   store = _get_metadata_store()
   type_id = store.put_context_type(_create_example_context_type())
   original = metadata_store_pb2.Context(type_id=type_id, name="context1")
   original.properties["foo"].int_value = 3
   original.custom_properties["abc"].string_value = "s"
   [context_id] = store.put_contexts([original])

   [fetched] = store.get_contexts_by_id([context_id])
   self.assertEqual(fetched.name, original.name)
   want_foo = original.properties["foo"].int_value
   self.assertEqual(fetched.properties["foo"].int_value, want_foo)
   want_abc = original.custom_properties["abc"].string_value
   self.assertEqual(fetched.custom_properties["abc"].string_value, want_abc)
Esempio n. 20
0
def create_context_with_type(
    store,
    context_name: str,
    type_name: str,
    properties: dict = None,
    type_properties: dict = None,
) -> metadata_store_pb2.Context:
    """Stores a new context of type `type_name` and returns it with its id.

    Note: `context_name` must be unique.

    Args:
      store: the metadata store to write to.
      context_name: name of the new context.
      type_name: context type name, fetched or created via
        get_or_create_context_type.
      properties: property values for the context.
      type_properties: property schema used if the type must be created.

    Returns:
      The stored metadata_store_pb2.Context with `id` set.
    """
    ctx_type = get_or_create_context_type(
        store=store, type_name=type_name, properties=type_properties)
    new_context = metadata_store_pb2.Context(
        name=context_name, type_id=ctx_type.id, properties=properties)
    assigned_ids = store.put_contexts([new_context])
    new_context.id = assigned_ids[0]
    return new_context
Esempio n. 21
0
def _generate_context_proto(
        metadata_handler: metadata.Metadata,
        context_spec: pipeline_pb2.ContextSpec) -> metadata_store_pb2.Context:
    """Generates metadata_pb2.Context based on the ContextSpec message.

    The context type from the spec is registered if needed. Spec properties
    declared on that type are type-checked and written to `properties`; all
    other spec properties go to `custom_properties`.

    Args:
      metadata_handler: A handler to access MLMD store.
      context_spec: A pipeline_pb2.ContextSpec message that instructs
        registering of a context.

    Returns:
      A metadata_store_pb2.Context message.

    Raises:
      RuntimeError: When actual property type does not match provided metadata
        type schema.
    """
    context_type = common_utils.register_type_if_not_exist(
        metadata_handler, context_spec.type)
    context_name = common_utils.get_value(context_spec.name)
    assert isinstance(context_name, Text), 'context name should be string.'
    context = metadata_store_pb2.Context(type_id=context_type.id,
                                         name=context_name)

    declared_types = context_type.properties
    for key, spec_value in context_spec.properties.items():
        if key not in declared_types:
            common_utils.set_metadata_value(context.custom_properties[key],
                                            spec_value)
            continue
        actual_type = common_utils.get_metadata_value_type(spec_value)
        expected_type = declared_types.get(key)
        if expected_type != actual_type:
            raise RuntimeError(
                'Property type %s different from provided metadata type property type %s for key %s'
                % (actual_type, expected_type, key))
        common_utils.set_metadata_value(context.properties[key], spec_value)
    return context
Esempio n. 22
0
# In[14]:

# Define an 'Experiment' ContextType with a single string property `note`,
# then register it with the Metadata Store.
expt_context_type = metadata_store_pb2.ContextType(name='Experiment')
expt_context_type.properties['note'] = metadata_store_pb2.STRING
expt_context_type_id = store.put_context_type(expt_context_type)

# Similarly, you can create an instance of this context type and use the
# `put_contexts()` method to register it with the store.

# In[15]:

# Build the 'Demo' experiment context and submit it to the Metadata Store.
expt_context = metadata_store_pb2.Context(type_id=expt_context_type_id,
                                          name='Demo')
expt_context.properties['note'].string_value = 'Walkthrough of metadata'
expt_context_id = store.put_contexts([expt_context])[0]

print('Experiment Context type:\n', expt_context_type)
print('Experiment Context type ID: ', expt_context_type_id)

print('Experiment Context:\n', expt_context)
print('Experiment Context ID: ', expt_context_id)

# ## Generate attribution and association relationships