Example #1
 def setUp(self):
     super().setUp()
     # Set up MLMD connection.
     pipeline_root = os.path.join(
         os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
         self.id())
     metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
     connection_config = metadata.sqlite_metadata_connection_config(
         metadata_path)
     connection_config.sqlite.SetInParent()
     self._mlmd_connection = metadata.Metadata(
         connection_config=connection_config)
     with self._mlmd_connection as m:
         self._execution = execution_publish_utils.register_execution(
             metadata_handler=m,
             execution_type=metadata_store_pb2.ExecutionType(
                 name='test_execution_type'),
             contexts=[],
             input_artifacts=[])
     # Set up gRPC stub.
     port = portpicker.pick_unused_port()
     self.sidecar = execution_watcher.ExecutionWatcher(
         port,
         mlmd_connection=self._mlmd_connection,
         execution=self._execution,
         creds=grpc.local_server_credentials())
     self.sidecar.start()
     self.stub = execution_watcher.generate_service_stub(
         self.sidecar.address, grpc.local_channel_credentials())
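A matching teardown sketch (an assumption, not part of the original test: it presumes the ExecutionWatcher sidecar exposes a stop() method to mirror the start() call above):

 def tearDown(self):
     self.sidecar.stop()
     super().tearDown()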
Example #2
    def test_put_events_get_events(self):
        store = _get_metadata_store()
        execution_type = metadata_store_pb2.ExecutionType()
        execution_type.name = self._get_test_type_name()
        execution_type_id = store.put_execution_type(execution_type)
        execution = metadata_store_pb2.Execution()
        execution.type_id = execution_type_id
        [execution_id] = store.put_executions([execution])
        artifact_type = metadata_store_pb2.ArtifactType()
        artifact_type.name = self._get_test_type_name()
        artifact_type_id = store.put_artifact_type(artifact_type)
        artifact = metadata_store_pb2.Artifact()
        artifact.type_id = artifact_type_id
        [artifact_id] = store.put_artifacts([artifact])

        event = metadata_store_pb2.Event()
        event.type = metadata_store_pb2.Event.DECLARED_OUTPUT
        event.artifact_id = artifact_id
        event.execution_id = execution_id
        store.put_events([event])
        [event_result] = store.get_events_by_artifact_ids([artifact_id])
        self.assertEqual(event_result.artifact_id, artifact_id)
        self.assertEqual(event_result.execution_id, execution_id)
        self.assertEqual(event_result.type,
                         metadata_store_pb2.Event.DECLARED_OUTPUT)

        [event_result_2] = store.get_events_by_execution_ids([execution_id])
        self.assertEqual(event_result_2.artifact_id, artifact_id)
        self.assertEqual(event_result_2.execution_id, execution_id)
        self.assertEqual(event_result_2.type,
                         metadata_store_pb2.Event.DECLARED_OUTPUT)
Example #3
 def _put_execution_type(self) -> int:
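     """Registers the trainer execution type in MLMD and returns its type id."""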
     exec_type = metadata_store_pb2.ExecutionType()
     exec_type.name = results._TRAINER
     exec_type.properties[results.RUN_ID_KEY] = metadata_store_pb2.STRING
     exec_type.properties[results._HPARAMS] = metadata_store_pb2.STRING
     exec_type.properties[results._COMPONENT_ID] = metadata_store_pb2.STRING
     return self.store.put_execution_type(exec_type)
Example #4
def connect_to_mlmd() -> metadata_store.MetadataStore:
    metadata_service_host = os.environ.get('METADATA_SERVICE_SERVICE_HOST',
                                           'metadata-service')
    metadata_service_port = int(
        os.environ.get('METADATA_SERVICE_SERVICE_PORT', 8080))

    mlmd_connection_config = metadata_store_pb2.MetadataStoreClientConfig(
        host=metadata_service_host,
        port=metadata_service_port,
    )

    # Check the connection to the Metadata store, retrying up to 100 times.
    for _ in range(100):
        try:
            mlmd_store = metadata_store.MetadataStore(mlmd_connection_config)
            # All get requests fail when the DB is empty, so we have to use a put request.
            # TODO: Replace with _ = mlmd_store.get_context_types() when https://github.com/google/ml-metadata/issues/28 is fixed
            _ = mlmd_store.put_execution_type(
                metadata_store_pb2.ExecutionType(name="DummyExecutionType", ))
            return mlmd_store
        except Exception as e:
            print(
                'Failed to access the Metadata store. Exception: "{}"'.format(
                    str(e)),
                file=sys.stderr)
            sys.stderr.flush()
            sleep(1)

    raise RuntimeError('Could not connect to the Metadata store.')
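A usage sketch for this helper, assuming the METADATA_SERVICE_* environment variables point at a reachable gRPC metadata service (get_execution_types is a standard MetadataStore method):

if __name__ == '__main__':
    mlmd_store = connect_to_mlmd()
    # Quick smoke test: list the registered execution type names.
    for execution_type in mlmd_store.get_execution_types():
        print(execution_type.name)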
Example #5
 def setUp(self):
   super(PlaceholderUtilsTest, self).setUp()
   examples = [standard_artifacts.Examples()]
   examples[0].uri = "/tmp"
   examples[0].split_names = artifact_utils.encode_split_names(
       ["train", "eval"])
   serving_spec = infra_validator_pb2.ServingSpec()
   serving_spec.tensorflow_serving.tags.extend(["latest", "1.15.0-gpu"])
   self._resolution_context = placeholder_utils.ResolutionContext(
       exec_info=data_types.ExecutionInfo(
           input_dict={
               "model": [standard_artifacts.Model()],
               "examples": examples,
           },
           output_dict={"blessing": [standard_artifacts.ModelBlessing()]},
           exec_properties={
               "proto_property":
                   json_format.MessageToJson(
                       message=serving_spec,
                       sort_keys=True,
                       preserving_proto_field_name=True,
                       indent=0)
           },
           execution_output_uri="test_executor_output_uri",
           stateful_working_dir="test_stateful_working_dir",
           pipeline_node=pipeline_pb2.PipelineNode(
               node_info=pipeline_pb2.NodeInfo(
                   type=metadata_store_pb2.ExecutionType(
                       name="infra_validator"))),
           pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id")))
Example #6
    def _prepare_execution_type(self, type_name: Text,
                                exec_properties: Dict[Text, Any]) -> int:
        """Get a execution type. Use existing type if available."""
        try:
            execution_type = self._store.get_execution_type(type_name)
            if execution_type is None:
                raise RuntimeError(
                    'Execution type is None for {}.'.format(type_name))
            return execution_type.id
        except tf.errors.NotFoundError:
            execution_type = metadata_store_pb2.ExecutionType(name=type_name)
            execution_type.properties['state'] = metadata_store_pb2.STRING
            for k in exec_properties.keys():
                execution_type.properties[k] = metadata_store_pb2.STRING
            # TODO(ruoyu): Find a better place / solution to the checksum logic.
            if 'module_file' in exec_properties:
                execution_type.properties[
                    'checksum_md5'] = metadata_store_pb2.STRING
            execution_type.properties[
                'pipeline_name'] = metadata_store_pb2.STRING
            execution_type.properties[
                'pipeline_root'] = metadata_store_pb2.STRING
            execution_type.properties['run_id'] = metadata_store_pb2.STRING
            execution_type.properties[
                'component_id'] = metadata_store_pb2.STRING

            return self._store.put_execution_type(execution_type)
Example #7
    def __init__(self,
                 phoenix_spec,
                 study_name,
                 study_owner,
                 optimization_goal="minimize",
                 optimization_metric="loss",
                 connection_config=None):
        """Initializes a new MLMD connection instance.

    Args:
      phoenix_spec: PhoenixSpec proto.
      study_name: The name of the study.
      study_owner: The owner (username) of the study.
      optimization_goal: "minimize" or "maximize" (string).
      optimization_metric: the metric to optimize (string).
      connection_config: a metadata_store_pb2.ConnectionConfig proto. If None,
        we fall back on the module-level flags.
    """
        self._study_name = study_name
        self._study_owner = study_owner
        self._phoenix_spec = phoenix_spec
        self._optimization_goal = optimization_goal
        self._optimization_metric = optimization_metric

        self._connection_config = connection_config
        if not FLAGS.is_parsed():
            logging.error(
                "Flags are not parsed. Using the default file-based MLMD "
                "database. Please run main with absl.app.run(main) to fix "
                "this. If running in distributed mode, this means that the "
                "trainers are not sharing information with one another.")
        if self._connection_config is None:
            if FLAGS.is_parsed() and FLAGS.mlmd_default_sqllite_filename:
                self._connection_config = metadata_store_pb2.ConnectionConfig()
                self._connection_config.sqlite.filename_uri = (
                    FLAGS.mlmd_default_sqllite_filename)
                self._connection_config.sqlite.connection_mode = 3
            elif FLAGS.is_parsed() and FLAGS.mlmd_socket:
                self._connection_config = metadata_store_pb2.ConnectionConfig()
                self._connection_config.mysql.socket = FLAGS.mlmd_socket
                self._connection_config.mysql.database = FLAGS.mlmd_database
                self._connection_config.mysql.user = FLAGS.mlmd_user
                self._connection_config.mysql.password = FLAGS.mlmd_password
            else:
                self._connection_config = metadata_store_pb2.ConnectionConfig()
                self._connection_config.sqlite.filename_uri = (
                    "/tmp/filedb-%d" % random.randint(0, 1000000))
                self._connection_config.sqlite.connection_mode = 3
        self._store = metadata_store.MetadataStore(self._connection_config)

        trial_type = metadata_store_pb2.ExecutionType()
        trial_type.name = "Trial"
        trial_type.properties["id"] = metadata_store_pb2.INT
        trial_type.properties["state"] = metadata_store_pb2.STRING
        trial_type.properties["serialized_data"] = metadata_store_pb2.STRING
        trial_type.properties["model_dir"] = metadata_store_pb2.STRING
        trial_type.properties["evaluation"] = metadata_store_pb2.STRING
        self._trial_type_id = self._store.put_execution_type(trial_type)
        self._trial_id_to_run_id = {}
Example #8
def _write_test_execution(mlmd_handle):
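    """Registers a test execution type, writes one execution, and returns it."""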
    execution_type = metadata_store_pb2.ExecutionType(name='foo',
                                                      version='bar')
    execution_type_id = mlmd_handle.store.put_execution_type(execution_type)
    [execution_id] = mlmd_handle.store.put_executions(
        [metadata_store_pb2.Execution(type_id=execution_type_id)])
    [execution] = mlmd_handle.store.get_executions_by_id([execution_id])
    return execution
Example #9
 def setUp(self):
     super(PlaceholderUtilsTest, self).setUp()
     examples = [standard_artifacts.Examples()]
     examples[0].uri = "/tmp"
     examples[0].split_names = artifact_utils.encode_split_names(
         ["train", "eval"])
     self._serving_spec = infra_validator_pb2.ServingSpec()
     self._serving_spec.tensorflow_serving.tags.extend(
         ["latest", "1.15.0-gpu"])
     self._resolution_context = placeholder_utils.ResolutionContext(
         exec_info=data_types.ExecutionInfo(
             input_dict={
                 "model": [standard_artifacts.Model()],
                 "examples": examples,
             },
             output_dict={"blessing": [standard_artifacts.ModelBlessing()]},
             exec_properties={
                 "proto_property":
                 proto_utils.proto_to_json(self._serving_spec)
             },
             execution_output_uri="test_executor_output_uri",
             stateful_working_dir="test_stateful_working_dir",
             pipeline_node=pipeline_pb2.PipelineNode(
                 node_info=pipeline_pb2.NodeInfo(
                     type=metadata_store_pb2.ExecutionType(
                         name="infra_validator"))),
             pipeline_info=pipeline_pb2.PipelineInfo(
                 id="test_pipeline_id")),
         executor_spec=executable_spec_pb2.PythonClassExecutableSpec(
             class_path="test_class_path"),
     )
     # Resolution context to simulate missing optional values.
     self._none_resolution_context = placeholder_utils.ResolutionContext(
         exec_info=data_types.ExecutionInfo(
             input_dict={},
             output_dict={},
             exec_properties={},
             pipeline_node=pipeline_pb2.PipelineNode(
                 node_info=pipeline_pb2.NodeInfo(
                     type=metadata_store_pb2.ExecutionType(
                         name="infra_validator"))),
             pipeline_info=pipeline_pb2.PipelineInfo(
                 id="test_pipeline_id")),
         executor_spec=None,
         platform_config=None)
Example #10
    def test_put_execution_with_context(self):
        store = _get_metadata_store()
        execution_type = metadata_store_pb2.ExecutionType()
        execution_type.name = self._get_test_type_name()
        execution_type_id = store.put_execution_type(execution_type)
        execution = metadata_store_pb2.Execution()
        execution.type_id = execution_type_id

        artifact_type = metadata_store_pb2.ArtifactType()
        artifact_type.name = self._get_test_type_name()
        artifact_type_id = store.put_artifact_type(artifact_type)
        input_artifact = metadata_store_pb2.Artifact()
        input_artifact.type_id = artifact_type_id
        output_artifact = metadata_store_pb2.Artifact()
        output_artifact.type_id = artifact_type_id
        output_event = metadata_store_pb2.Event()
        output_event.type = metadata_store_pb2.Event.DECLARED_OUTPUT

        context_type = metadata_store_pb2.ContextType()
        context_type.name = self._get_test_type_name()
        context_type_id = store.put_context_type(context_type)
        context = metadata_store_pb2.Context()
        context.type_id = context_type_id
        context_name = self._get_test_type_name()
        context.name = context_name

        execution_id, artifact_ids, context_ids = store.put_execution(
            execution, [[input_artifact], [output_artifact, output_event]],
            [context])

        # Test that the artifacts & events are correctly inserted.
        self.assertLen(artifact_ids, 2)
        events = store.get_events_by_execution_ids([execution_id])
        self.assertLen(events, 1)

        # Test that the context is correctly inserted.
        got_contexts = store.get_contexts_by_id(context_ids)
        self.assertLen(context_ids, 1)
        self.assertLen(got_contexts, 1)

        # Test that the association link between the execution and the context is correct.
        contexts_by_execution_id = store.get_contexts_by_execution(
            execution_id)
        self.assertLen(contexts_by_execution_id, 1)
        self.assertEqual(contexts_by_execution_id[0].name, context_name)
        self.assertEqual(contexts_by_execution_id[0].type_id, context_type_id)
        executions_by_context = store.get_executions_by_context(context_ids[0])
        self.assertLen(executions_by_context, 1)

        # Test that the attribution links between the artifacts and the context are correct.
        contexts_by_artifact_id = store.get_contexts_by_artifact(
            artifact_ids[0])
        self.assertLen(contexts_by_artifact_id, 1)
        self.assertEqual(contexts_by_artifact_id[0].name, context_name)
        self.assertEqual(contexts_by_artifact_id[0].type_id, context_type_id)
        artifacts_by_context = store.get_artifacts_by_context(context_ids[0])
        self.assertLen(artifacts_by_context, 2)
Example #11
    def testPutExecutionGraph(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Prepares an input artifact. The artifact should be registered in MLMD
            # before the put_execution call.
            input_example = standard_artifacts.Examples()
            input_example.uri = 'example'
            input_example.type_id = common_utils.register_type_if_not_exist(
                m, input_example.artifact_type).id
            [input_example.id
             ] = m.store.put_artifacts([input_example.mlmd_artifact])
            # Prepares an output artifact.
            output_model = standard_artifacts.Model()
            output_model.uri = 'model'
            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                exec_properties={
                    'p1': 1,
                    'p2': '2'
                },
                state=metadata_store_pb2.Execution.COMPLETE)
            contexts = self._generate_contexts(m)
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts={'example': [input_example]},
                output_artifacts={'model': [output_model]})

            self.assertProtoPartiallyEquals(
                output_model.mlmd_artifact,
                m.store.get_artifacts_by_id([output_model.id])[0],
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
            # Verifies edges between artifacts and execution.
            [input_event
             ] = m.store.get_events_by_artifact_ids([input_example.id])
            self.assertEqual(input_event.execution_id, execution.id)
            self.assertEqual(input_event.type, metadata_store_pb2.Event.INPUT)
            [output_event
             ] = m.store.get_events_by_artifact_ids([output_model.id])
            self.assertEqual(output_event.execution_id, execution.id)
            self.assertEqual(output_event.type,
                             metadata_store_pb2.Event.OUTPUT)
            # Verifies edges connecting contexts and {artifacts, execution}.
            context_ids = [context.id for context in contexts]
            self.assertCountEqual([
                c.id
                for c in m.store.get_contexts_by_artifact(input_example.id)
            ], context_ids)
            self.assertCountEqual([
                c.id for c in m.store.get_contexts_by_artifact(output_model.id)
            ], context_ids)
            self.assertCountEqual([
                c.id for c in m.store.get_contexts_by_execution(execution.id)
            ], context_ids)
Example #12
 def test_put_execution_type_get_execution_type(self):
   store = _get_metadata_store()
   execution_type = metadata_store_pb2.ExecutionType()
   execution_type.name = "test_type_1"
   execution_type.properties["foo"] = metadata_store_pb2.INT
   execution_type.properties["bar"] = metadata_store_pb2.STRING
   type_id = store.put_execution_type(execution_type)
   execution_type_result = store.get_execution_type("test_type_1")
   self.assertEqual(execution_type_result.id, type_id)
   self.assertEqual(execution_type_result.name, "test_type_1")
Example #13
def get_or_create_execution_type(store, type_name, properties: dict = None) -> metadata_store_pb2.ExecutionType:
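    """Returns the execution type with the given name, registering it if missing."""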
    try:
        execution_type = store.get_execution_type(type_name=type_name)
        return execution_type
    except Exception:  # The type does not exist yet; register it.
        execution_type = metadata_store_pb2.ExecutionType(
            name=type_name,
            properties=properties,
        )
        execution_type.id = store.put_execution_type(execution_type) # Returns ID
        return execution_type
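A minimal usage sketch, assuming an in-memory fake database (metadata_store_pb2.ConnectionConfig.fake_database) and a hypothetical 'Trainer' type name:

from ml_metadata import metadata_store
from ml_metadata.proto import metadata_store_pb2

config = metadata_store_pb2.ConnectionConfig()
config.fake_database.SetInParent()  # In-memory store for testing.
store = metadata_store.MetadataStore(config)

trainer_type = get_or_create_execution_type(
    store, 'Trainer', properties={'state': metadata_store_pb2.STRING})
print(trainer_type.id, trainer_type.name)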
Example #14
    def __init__(self,
                 phoenix_spec,
                 study_name,
                 study_owner,
                 optimization_goal="minimize",
                 optimization_metric="loss",
                 connection_config=None):
        """Initializes a new MLMD connection instance.

    Args:
      phoenix_spec: PhoenixSpec proto.
      study_name: The name of the study.
      study_owner: The owner (username) of the study.
      optimization_goal: "minimize" or "maximize" (string).
      optimization_metric: the metric to optimize (string).
      connection_config: a metadata_store_pb2.ConnectionConfig proto. If None,
        we fall back on the module-level flags.
    """
        self._study_name = study_name
        self._study_owner = study_owner
        self._phoenix_spec = phoenix_spec
        self._optimization_goal = optimization_goal
        self._optimization_metric = optimization_metric

        self._connection_config = connection_config
        if self._connection_config is None:
            if FLAGS.mlmd_default_sqllite_filename:
                self._connection_config = metadata_store_pb2.ConnectionConfig()
                self._connection_config.sqlite.filename_uri = (
                    FLAGS.mlmd_default_sqllite_filename)
                self._connection_config.sqlite.connection_mode = 3
            elif FLAGS.mlmd_socket:
                self._connection_config = metadata_store_pb2.ConnectionConfig()
                self._connection_config.mysql.socket = FLAGS.mlmd_socket
                self._connection_config.mysql.database = FLAGS.mlmd_database
                self._connection_config.mysql.user = FLAGS.mlmd_user
                self._connection_config.mysql.password = FLAGS.mlmd_password
            else:
                self._connection_config = metadata_store_pb2.ConnectionConfig()
                self._connection_config.sqlite.filename_uri = (
                    "/tmp/filedb-%d" % random.randint(0, 1000000))
                self._connection_config.sqlite.connection_mode = 3
        self._store = metadata_store.MetadataStore(self._connection_config)

        trial_type = metadata_store_pb2.ExecutionType()
        trial_type.name = "Trial"
        trial_type.properties["id"] = metadata_store_pb2.INT
        trial_type.properties["state"] = metadata_store_pb2.STRING
        trial_type.properties["serialized_data"] = metadata_store_pb2.STRING
        trial_type.properties["model_dir"] = metadata_store_pb2.STRING
        trial_type.properties["evaluation"] = metadata_store_pb2.STRING
        self._trial_type_id = self._store.put_execution_type(trial_type)
        self._trial_id_to_run_id = {}
Example #15
  def test_put_execution_type_with_update_get_execution_type(self):
    store = _get_metadata_store()
    execution_type = metadata_store_pb2.ExecutionType()
    execution_type.name = "test_type"
    execution_type.properties["foo"] = metadata_store_pb2.DOUBLE
    type_id = store.put_execution_type(execution_type)

    want_execution_type = metadata_store_pb2.ExecutionType()
    want_execution_type.id = type_id
    want_execution_type.name = "test_type"
    want_execution_type.properties["foo"] = metadata_store_pb2.DOUBLE
    want_execution_type.properties["new_property"] = metadata_store_pb2.INT
    store.put_execution_type(want_execution_type, can_add_fields=True)

    got_execution_type = store.get_execution_type("test_type")
    self.assertEqual(got_execution_type.id, type_id)
    self.assertEqual(got_execution_type.name, "test_type")
    self.assertEqual(got_execution_type.properties["foo"],
                     metadata_store_pb2.DOUBLE)
    self.assertEqual(got_execution_type.properties["new_property"],
                     metadata_store_pb2.INT)
Example #16
    def testGetExecutionsAssociatedWithAllContexts(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            self.assertLen(contexts, 2)

            # Create 2 executions and associate with one context each.
            execution1 = execution_lib.prepare_execution(
                m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
                metadata_store_pb2.Execution.RUNNING)
            execution1 = execution_lib.put_execution(m, execution1,
                                                     [contexts[0]])
            execution2 = execution_lib.prepare_execution(
                m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
                metadata_store_pb2.Execution.COMPLETE)
            execution2 = execution_lib.put_execution(m, execution2,
                                                     [contexts[1]])

            # Create another execution and associate with both contexts.
            execution3 = execution_lib.prepare_execution(
                m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
                metadata_store_pb2.Execution.NEW)
            execution3 = execution_lib.put_execution(m, execution3, contexts)

            # Verify that the right executions are returned.
            with self.subTest(for_contexts=(0, )):
                executions = execution_lib.get_executions_associated_with_all_contexts(
                    m, [contexts[0]])
                self.assertCountEqual([execution1.id, execution3.id],
                                      [e.id for e in executions])
            with self.subTest(for_contexts=(1, )):
                executions = execution_lib.get_executions_associated_with_all_contexts(
                    m, [contexts[1]])
                self.assertCountEqual([execution2.id, execution3.id],
                                      [e.id for e in executions])
            with self.subTest(for_contexts=(0, 1)):
                executions = execution_lib.get_executions_associated_with_all_contexts(
                    m, contexts)
                self.assertCountEqual([execution3.id],
                                      [e.id for e in executions])
Example #17
 def _set_up_test_execution_info(self,
                                 input_dict=None,
                                 output_dict=None,
                                 exec_properties=None):
   return data_types.ExecutionInfo(
       input_dict=input_dict or {},
       output_dict=output_dict or {},
       exec_properties=exec_properties or {},
       execution_output_uri='/testing/executor/output/',
       stateful_working_dir='/testing/stateful/dir',
       pipeline_node=pipeline_pb2.PipelineNode(
           node_info=pipeline_pb2.NodeInfo(
               type=metadata_store_pb2.ExecutionType(name='Docker_executor'))),
       pipeline_info=pipeline_pb2.PipelineInfo(id='test_pipeline_id'))
Example #18
    def testGetArtifactsDict(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Create and shuffle a few artifacts. The shuffled order should be
            # retained in the output of `execution_lib.get_artifacts_dict`.
            input_examples = []
            for i in range(10):
                input_example = standard_artifacts.Examples()
                input_example.uri = 'example{}'.format(i)
                input_example.type_id = common_utils.register_type_if_not_exist(
                    m, input_example.artifact_type).id
                input_examples.append(input_example)
            random.shuffle(input_examples)
            output_models = []
            for i in range(8):
                output_model = standard_artifacts.Model()
                output_model.uri = 'model{}'.format(i)
                output_model.type_id = common_utils.register_type_if_not_exist(
                    m, output_model.artifact_type).id
                output_models.append(output_model)
            random.shuffle(output_models)
            m.store.put_artifacts([
                a.mlmd_artifact
                for a in itertools.chain(input_examples, output_models)
            ])
            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                state=metadata_store_pb2.Execution.RUNNING)
            contexts = self._generate_contexts(m)
            input_artifacts_dict = {'examples': input_examples}
            output_artifacts_dict = {'model': output_models}
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts=input_artifacts_dict,
                output_artifacts=output_artifacts_dict)

            # Verify that the same artifacts are returned in the correct order.
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, metadata_store_pb2.Event.INPUT)
            self.assertCountEqual(['examples'], list(artifacts_dict.keys()))
            self.assertEqual([ex.uri for ex in input_examples],
                             [a.uri for a in artifacts_dict['examples']])
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, metadata_store_pb2.Event.OUTPUT)
            self.assertCountEqual(['model'], list(artifacts_dict.keys()))
            self.assertEqual([model.uri for model in output_models],
                             [a.uri for a in artifacts_dict['model']])
Example #19
  def test_put_events_no_artifact_id(self):
    # A missing execution_id raises the same error type, so we only test the
    # missing artifact_id case.
    store = _get_metadata_store()
    execution_type = metadata_store_pb2.ExecutionType()
    execution_type.name = "execution_type"
    execution_type_id = store.put_execution_type(execution_type)
    execution = metadata_store_pb2.Execution()
    execution.type_id = execution_type_id
    [execution_id] = store.put_executions([execution])

    event = metadata_store_pb2.Event()
    event.type = metadata_store_pb2.Event.DECLARED_OUTPUT
    event.execution_id = execution_id
    with self.assertRaises(errors.InvalidArgumentError):
      store.put_events([event])
Example #20
 def test_put_executions_get_executions_by_id(self):
   store = _get_metadata_store()
   execution_type = metadata_store_pb2.ExecutionType()
   execution_type.name = "test_type_1"
   execution_type.properties["foo"] = metadata_store_pb2.INT
   execution_type.properties["bar"] = metadata_store_pb2.STRING
   type_id = store.put_execution_type(execution_type)
   execution = metadata_store_pb2.Execution()
   execution.type_id = type_id
   execution.properties["foo"].int_value = 3
   execution.properties["bar"].string_value = "Hello"
   [execution_id] = store.put_executions([execution])
   [execution_result] = store.get_executions_by_id([execution_id])
   self.assertEqual(execution_result.properties["bar"].string_value, "Hello")
   self.assertEqual(execution_result.properties["foo"].int_value, 3)
Example #21
 def testGetCachedOutputArtifactsForNodesWithNoOutput(self):
     with metadata.Metadata(connection_config=self._connection_config) as m:
         cache_context = context_lib.register_context_if_not_exists(
             m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
         cached_output = cache_utils.get_cached_outputs(m, cache_context)
         # No succeeded execution is associated with this context yet, so the
         # cached output is None.
         self.assertIsNone(cached_output)
         execution_one = execution_publish_utils.register_execution(
             m, metadata_store_pb2.ExecutionType(name='my_type'),
             [cache_context])
         execution_publish_utils.publish_succeeded_execution(
             m, execution_one.id, [cache_context])
         cached_output = cache_utils.get_cached_outputs(m, cache_context)
         # A succeeded execution is now associated with this context, so the
         # cached output is not None but an empty dict.
         self.assertIsNotNone(cached_output)
         self.assertEmpty(cached_output)
Example #22
  def test_update_execution_get_execution(self):
    store = _get_metadata_store()
    execution_type = metadata_store_pb2.ExecutionType()
    execution_type.name = "test_type_1"
    execution_type.properties["foo"] = metadata_store_pb2.INT
    execution_type.properties["bar"] = metadata_store_pb2.STRING
    type_id = store.put_execution_type(execution_type)
    execution = metadata_store_pb2.Execution()
    execution.type_id = type_id
    execution.properties["bar"].string_value = "Hello"

    [execution_id] = store.put_executions([execution])
    execution_2 = metadata_store_pb2.Execution()
    execution_2.id = execution_id
    execution_2.type_id = type_id
    execution_2.properties["foo"].int_value = 12
    execution_2.properties["bar"].string_value = "Goodbye"
    [execution_id_2] = store.put_executions([execution_2])
    self.assertEqual(execution_id, execution_id_2)

    [execution_result] = store.get_executions_by_id([execution_id])
    self.assertEqual(execution_result.properties["bar"].string_value, "Goodbye")
    self.assertEqual(execution_result.properties["foo"].int_value, 12)
Example #23
  def test_publish_execution(self):
    store = _get_metadata_store()
    execution_type = metadata_store_pb2.ExecutionType()
    execution_type.name = "execution_type"
    execution_type_id = store.put_execution_type(execution_type)
    execution = metadata_store_pb2.Execution()
    execution.type_id = execution_type_id

    artifact_type = metadata_store_pb2.ArtifactType()
    artifact_type.name = "artifact_type"
    artifact_type_id = store.put_artifact_type(artifact_type)
    input_artifact = metadata_store_pb2.Artifact()
    input_artifact.type_id = artifact_type_id
    output_artifact = metadata_store_pb2.Artifact()
    output_artifact.type_id = artifact_type_id
    output_event = metadata_store_pb2.Event()
    output_event.type = metadata_store_pb2.Event.DECLARED_OUTPUT

    execution_id, artifact_ids = store.put_execution(
        execution, [[input_artifact], [output_artifact, output_event]])
    self.assertLen(artifact_ids, 2)
    events = store.get_events_by_execution_ids([execution_id])
    self.assertLen(events, 1)
Example #24
    def _cache_and_publish(self,
                           existing_execution: metadata_store_pb2.Execution):
        """Updates MLMD."""
        cached_execution_contexts = self._get_cached_execution_contexts(
            existing_execution)
        # Check if there are any previous attempts to cache and publish.
        prev_cache_executions = (
            execution_lib.get_executions_associated_with_all_contexts(
                self._mlmd, contexts=cached_execution_contexts))
        if not prev_cache_executions:
            new_execution = execution_publish_utils.register_execution(
                self._mlmd,
                execution_type=metadata_store_pb2.ExecutionType(
                    id=existing_execution.type_id),
                contexts=cached_execution_contexts)
        else:
            if len(prev_cache_executions) > 1:
                logging.warning(
                    'More than one previous cache execution seen when '
                    'attempting reuse_node_outputs: %s', prev_cache_executions)

            if (prev_cache_executions[-1].last_known_state ==
                    metadata_store_pb2.Execution.CACHED):
                return
            else:
                new_execution = prev_cache_executions[-1]

        output_artifacts = execution_lib.get_artifacts_dict(
            self._mlmd,
            existing_execution.id,
            event_types=list(event_lib.VALID_OUTPUT_EVENT_TYPES))

        execution_publish_utils.publish_cached_execution(
            self._mlmd,
            contexts=cached_execution_contexts,
            execution_id=new_execution.id,
            output_artifacts=output_artifacts)
Example #25
    def testGetArtifactIdsForExecutionIdGroupedByEventType(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Register input and output artifacts in MLMD.
            input_example = standard_artifacts.Examples()
            input_example.uri = 'example'
            input_example.type_id = common_utils.register_type_if_not_exist(
                m, input_example.artifact_type).id
            output_model = standard_artifacts.Model()
            output_model.uri = 'model'
            output_model.type_id = common_utils.register_type_if_not_exist(
                m, output_model.artifact_type).id
            [input_example.id, output_model.id] = m.store.put_artifacts(
                [input_example.mlmd_artifact, output_model.mlmd_artifact])
            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                exec_properties={
                    'p1': 1,
                    'p2': '2'
                },
                state=metadata_store_pb2.Execution.COMPLETE)
            contexts = self._generate_contexts(m)
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts={'example': [input_example]},
                output_artifacts={'model': [output_model]})

            artifact_ids_by_event_type = (
                execution_lib.get_artifact_ids_by_event_type_for_execution_id(
                    m, execution.id))
            self.assertDictEqual(
                {
                    metadata_store_pb2.Event.INPUT: set([input_example.id]),
                    metadata_store_pb2.Event.OUTPUT: set([output_model.id]),
                }, artifact_ids_by_event_type)
Example #26
  def test_put_events_with_paths_same_artifact(self):
    store = _get_metadata_store()
    execution_type = metadata_store_pb2.ExecutionType()
    execution_type.name = "execution_type"
    execution_type_id = store.put_execution_type(execution_type)
    execution_0 = metadata_store_pb2.Execution()
    execution_0.type_id = execution_type_id
    execution_1 = metadata_store_pb2.Execution()
    execution_1.type_id = execution_type_id
    [execution_id_0,
     execution_id_1] = store.put_executions([execution_0, execution_1])
    artifact_type = metadata_store_pb2.ArtifactType()
    artifact_type.name = "artifact_type"
    artifact_type_id = store.put_artifact_type(artifact_type)
    artifact = metadata_store_pb2.Artifact()
    artifact.type_id = artifact_type_id
    [artifact_id] = store.put_artifacts([artifact])

    event_0 = metadata_store_pb2.Event()
    event_0.type = metadata_store_pb2.Event.DECLARED_INPUT
    event_0.artifact_id = artifact_id
    event_0.execution_id = execution_id_0
    event_0.path.steps.add().key = "ggg"

    event_1 = metadata_store_pb2.Event()
    event_1.type = metadata_store_pb2.Event.DECLARED_INPUT
    event_1.artifact_id = artifact_id
    event_1.execution_id = execution_id_1
    event_1.path.steps.add().key = "fff"

    store.put_events([event_0, event_1])
    [event_result_0,
     event_result_1] = store.get_events_by_artifact_ids([artifact_id])
    self.assertLen(event_result_0.path.steps, 1)
    self.assertEqual(event_result_0.path.steps[0].key, "ggg")
    self.assertLen(event_result_1.path.steps, 1)
    self.assertEqual(event_result_1.path.steps[0].key, "fff")
Example #27
 def testPrepareExecution(self):
     with metadata.Metadata(connection_config=self._connection_config) as m:
         execution_type = metadata_store_pb2.ExecutionType()
         text_format.Parse(
             """
       name: 'my_execution'
       properties {
         key: 'p2'
         value: STRING
       }
       """, execution_type)
         result = execution_lib.prepare_execution(
             m,
             execution_type,
             exec_properties={
                 'p1': 1,
                 'p2': '2'
             },
             state=metadata_store_pb2.Execution.COMPLETE)
         self.assertProtoEquals(
             """
       type_id: 1
       last_known_state: COMPLETE
       properties {
         key: 'p2'
         value {
           string_value: '2'
         }
       }
       custom_properties {
         key: 'p1'
         value {
           int_value: 1
         }
       }
       """, result)
Example #28
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import status as status_lib

from ml_metadata.proto import metadata_store_pb2

_ORCHESTRATOR_RESERVED_ID = '__ORCHESTRATOR__'
_PIPELINE_IR = 'pipeline_ir'
_STOP_INITIATED = 'stop_initiated'
_PIPELINE_RUN_ID = 'pipeline_run_id'
_PIPELINE_STATUS_CODE = 'pipeline_status_code'
_PIPELINE_STATUS_MSG = 'pipeline_status_msg'
_NODE_STOP_INITIATED_PREFIX = 'node_stop_initiated_'
_NODE_STATUS_CODE_PREFIX = 'node_status_code_'
_NODE_STATUS_MSG_PREFIX = 'node_status_msg_'
_ORCHESTRATOR_EXECUTION_TYPE = metadata_store_pb2.ExecutionType(
    name=_ORCHESTRATOR_RESERVED_ID,
    properties={_PIPELINE_IR: metadata_store_pb2.STRING})

_last_state_change_time_secs = -1.0
_state_change_time_lock = threading.Lock()


def record_state_change_time() -> None:
    """Records current time at the point of function call as state change time.

  This function may be called after any operation that changes pipeline state or
  node execution state and requires further processing in the next iteration of
  the orchestration loop. As an optimization, the orchestration loop can elide
  the wait period between iterations when such a state change is detected.
  """
    global _last_state_change_time_secs
    # Assumed completion of this truncated snippet: record the current time
    # under the module-level lock declared above (requires `import time`).
    with _state_change_time_lock:
        _last_state_change_time_secs = time.time()
Example #29
    def _prepare_execution_type(self, type_name: Text,
                                exec_properties: Dict[Text, Any]) -> int:
        """Gets execution type given execution type name and properties.

    Uses existing type if schema is superset of what is needed. Otherwise tries
    to register new execution type.

    Args:
      type_name: the name of the execution type
      exec_properties: the execution properties included by the execution

    Returns:
      execution type id
    Raises:
      ValueError if new execution type conflicts with existing schema in MLMD.
    """
        try:
            existing_execution_type = self.store.get_execution_type(type_name)
            if existing_execution_type is None:
                raise RuntimeError('Execution type is None for %s.' %
                                   type_name)
            if all(k in existing_execution_type.properties
                   for k in exec_properties.keys()):
                return existing_execution_type.id
            else:
                raise tf.errors.NotFoundError(
                    None, None, 'No qualified execution type found.')
        except tf.errors.NotFoundError:
            execution_type = metadata_store_pb2.ExecutionType(name=type_name)
            execution_type.properties[
                _EXECUTION_TYPE_KEY_STATE] = metadata_store_pb2.STRING
            # If exec_properties contains new entries, execution type schema will be
            # updated in MLMD.
            for k in exec_properties.keys():
                assert k not in _EXECUTION_TYPE_RESERVED_KEYS, (
                    'execution properties with reserved key %s') % k
                execution_type.properties[k] = metadata_store_pb2.STRING
            # TODO(ruoyu): Find a better place / solution to the checksum logic.
            if 'module_file' in exec_properties:
                execution_type.properties[
                    _EXECUTION_TYPE_KEY_CHECKSUM] = metadata_store_pb2.STRING
            execution_type.properties[
                _EXECUTION_TYPE_KEY_PIPELINE_NAME] = metadata_store_pb2.STRING
            execution_type.properties[
                _EXECUTION_TYPE_KEY_PIPELINE_ROOT] = metadata_store_pb2.STRING
            execution_type.properties[
                _EXECUTION_TYPE_KEY_RUN_ID] = metadata_store_pb2.STRING
            execution_type.properties[
                _EXECUTION_TYPE_KEY_COMPONENT_ID] = metadata_store_pb2.STRING

            try:
                execution_type_id = self.store.put_execution_type(
                    execution_type=execution_type, can_add_fields=True)
                absl.logging.debug(
                    'Registering a new execution type with id %s.' %
                    execution_type_id)
                return execution_type_id
            except tf.errors.AlreadyExistsError:
                warning_str = (
                    'missing or modified key in exec_properties comparing with '
                    'existing execution type with the same type name. Existing type: '
                    '%s, New type: %s') % (existing_execution_type,
                                           execution_type)
                absl.logging.warning(warning_str)
                raise ValueError(warning_str)
Example #30
# Register artifact type to the Metadata Store
schema_artifact_type_id = store.put_artifact_type(schema_artifact_type)

print('Data artifact type:\n', data_artifact_type)
print('Schema artifact type:\n', schema_artifact_type)
print('Data artifact type ID:', data_artifact_type_id)
print('Schema artifact type ID:', schema_artifact_type_id)

# ## Register ExecutionType
#
# You will then create the execution types needed. For the simple setup, you will just declare one for the data validation component with a `state` property so you can record whether the process is running or already completed.

# In[5]:

# Create ExecutionType for Data Validation component
dv_execution_type = metadata_store_pb2.ExecutionType()
dv_execution_type.name = 'Data Validation'
dv_execution_type.properties['state'] = metadata_store_pb2.STRING

# Register execution type to the Metadata Store
dv_execution_type_id = store.put_execution_type(dv_execution_type)

print('Data validation execution type:\n', dv_execution_type)
print('Data validation execution type ID:', dv_execution_type_id)

# ## Generate input artifact unit
#
# With the artifact types created, you can now create instances of those types. The cell below creates the artifact for the input dataset. This artifact is recorded in the metadata store through the `put_artifacts()` function. Again, it generates an `id` that can be used for reference.

# In[6]:
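# A sketch of the cell described above (not shown in this excerpt): the uri is
# a placeholder and any type-specific properties are omitted.
data_artifact = metadata_store_pb2.Artifact()
data_artifact.type_id = data_artifact_type_id
data_artifact.uri = './data/train.csv'

# Register the artifact instance to the Metadata Store
[data_artifact_id] = store.put_artifacts([data_artifact])

print('Data artifact:\n', data_artifact)
print('Data artifact ID:', data_artifact_id)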