def setUp(self):
    """Prepares an MLMD connection, one registered execution, and a started
    ExecutionWatcher together with a client stub for it."""
    super().setUp()

    # MLMD connection: the SQLite file lives in a directory named after the
    # test id, under TEST_UNDECLARED_OUTPUTS_DIR when set, else a temp dir.
    output_base = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                 self.get_temp_dir())
    db_path = os.path.join(output_base, self.id(), 'metadata', 'metadata.db')
    sqlite_config = metadata.sqlite_metadata_connection_config(db_path)
    sqlite_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(connection_config=sqlite_config)

    with self._mlmd_connection as handle:
        test_type = metadata_store_pb2.ExecutionType(
            name='test_execution_type')
        self._execution = execution_publish_utils.register_execution(
            metadata_handler=handle,
            execution_type=test_type,
            contexts=[],
            input_artifacts=[])

    # gRPC side: a watcher on an unused port and a stub addressed at it.
    watcher_port = portpicker.pick_unused_port()
    self.sidecar = execution_watcher.ExecutionWatcher(
        watcher_port,
        mlmd_connection=self._mlmd_connection,
        execution=self._execution,
        creds=grpc.local_server_credentials())
    self.sidecar.start()
    watcher_address = self.sidecar.address
    self.stub = execution_watcher.generate_service_stub(
        watcher_address, grpc.local_channel_credentials())
def test_put_events_get_events(self):
    """Stores one DECLARED_OUTPUT event and reads it back, first by artifact
    id and then by execution id."""
    store = _get_metadata_store()

    execution_type_id = store.put_execution_type(
        metadata_store_pb2.ExecutionType(name=self._get_test_type_name()))
    [execution_id] = store.put_executions(
        [metadata_store_pb2.Execution(type_id=execution_type_id)])

    artifact_type_id = store.put_artifact_type(
        metadata_store_pb2.ArtifactType(name=self._get_test_type_name()))
    [artifact_id] = store.put_artifacts(
        [metadata_store_pb2.Artifact(type_id=artifact_type_id)])

    declared_output = metadata_store_pb2.Event(
        type=metadata_store_pb2.Event.DECLARED_OUTPUT,
        artifact_id=artifact_id,
        execution_id=execution_id)
    store.put_events([declared_output])

    def _check(stored_event):
        # Every lookup must return the event that was written above.
        self.assertEqual(stored_event.artifact_id, artifact_id)
        self.assertEqual(stored_event.execution_id, execution_id)
        self.assertEqual(stored_event.type,
                         metadata_store_pb2.Event.DECLARED_OUTPUT)

    [by_artifact] = store.get_events_by_artifact_ids([artifact_id])
    _check(by_artifact)
    [by_execution] = store.get_events_by_execution_ids([execution_id])
    _check(by_execution)
def _put_execution_type(self) -> int:
    """Puts the trainer execution type, with three STRING properties, into
    `self.store` and returns the store's result."""
    trainer_type = metadata_store_pb2.ExecutionType(name=results._TRAINER)
    string_keys = (
        results.RUN_ID_KEY,
        results._HPARAMS,
        results._COMPONENT_ID,
    )
    for key in string_keys:
        trainer_type.properties[key] = metadata_store_pb2.STRING
    return self.store.put_execution_type(trainer_type)
def connect_to_mlmd() -> metadata_store.MetadataStore:
    """Connects to the Metadata service, retrying up to 100 times.

    Host and port are read from METADATA_SERVICE_SERVICE_HOST and
    METADATA_SERVICE_SERVICE_PORT (defaults: 'metadata-service', 8080).

    Returns:
      The first store on which a probe write succeeded.

    Raises:
      RuntimeError: if every attempt failed.
    """
    host = os.environ.get('METADATA_SERVICE_SERVICE_HOST', 'metadata-service')
    port = int(os.environ.get('METADATA_SERVICE_SERVICE_PORT', 8080))
    client_config = metadata_store_pb2.MetadataStoreClientConfig(
        host=host,
        port=port,
    )

    # Checking the connection to the Metadata store.
    attempts_left = 100
    while attempts_left > 0:
        attempts_left -= 1
        try:
            store = metadata_store.MetadataStore(client_config)
            # All get requests fail when the DB is empty, so we have to use a
            # put request.
            # TODO: Replace with _ = mlmd_store.get_context_types() when
            # https://github.com/google/ml-metadata/issues/28 is fixed
            store.put_execution_type(
                metadata_store_pb2.ExecutionType(name="DummyExecutionType"))
        except Exception as error:
            # Any failure counts as "not reachable yet": report, wait, retry.
            print(
                'Failed to access the Metadata store. Exception: "{}"'.format(
                    str(error)),
                file=sys.stderr)
            sys.stderr.flush()
            sleep(1)
        else:
            return store

    raise RuntimeError('Could not connect to the Metadata store.')
def setUp(self):
    """Builds the ResolutionContext stored on `self._resolution_context`."""
    super(PlaceholderUtilsTest, self).setUp()

    examples_artifact = standard_artifacts.Examples()
    examples_artifact.uri = "/tmp"
    examples_artifact.split_names = artifact_utils.encode_split_names(
        ["train", "eval"])

    serving_spec = infra_validator_pb2.ServingSpec()
    serving_spec.tensorflow_serving.tags.extend(["latest", "1.15.0-gpu"])

    inputs = {
        "model": [standard_artifacts.Model()],
        "examples": [examples_artifact],
    }
    outputs = {"blessing": [standard_artifacts.ModelBlessing()]}
    # The serving spec is carried as a JSON string execution property.
    properties = {
        "proto_property":
            json_format.MessageToJson(
                message=serving_spec,
                sort_keys=True,
                preserving_proto_field_name=True,
                indent=0)
    }
    node = pipeline_pb2.PipelineNode(
        node_info=pipeline_pb2.NodeInfo(
            type=metadata_store_pb2.ExecutionType(name="infra_validator")))
    pipeline = pipeline_pb2.PipelineInfo(id="test_pipeline_id")

    exec_info = data_types.ExecutionInfo(
        input_dict=inputs,
        output_dict=outputs,
        exec_properties=properties,
        execution_output_uri="test_executor_output_uri",
        stateful_working_dir="test_stateful_working_dir",
        pipeline_node=node,
        pipeline_info=pipeline)
    self._resolution_context = placeholder_utils.ResolutionContext(
        exec_info=exec_info)
def _prepare_execution_type(self, type_name: Text,
                            exec_properties: Dict[Text, Any]) -> int:
    """Returns the id of execution type `type_name`, registering it if needed.

    An existing type is reused as-is. When the lookup raises NotFoundError a
    new type is put, with every property declared as STRING.

    Raises:
      RuntimeError: if the store returns None for `type_name`.
    """
    try:
        found_type = self._store.get_execution_type(type_name)
        if found_type is None:
            raise RuntimeError(
                'Execution type is None for {}.'.format(type_name))
        return found_type.id
    except tf.errors.NotFoundError:
        # Same assignment order as before: state, caller-supplied keys, the
        # optional checksum, then the fixed pipeline bookkeeping keys.
        property_names = ['state']
        property_names.extend(exec_properties)
        # TODO(ruoyu): Find a better place / solution to the checksum logic.
        if 'module_file' in exec_properties:
            property_names.append('checksum_md5')
        property_names.extend(
            ['pipeline_name', 'pipeline_root', 'run_id', 'component_id'])

        new_type = metadata_store_pb2.ExecutionType(name=type_name)
        for property_name in property_names:
            new_type.properties[property_name] = metadata_store_pb2.STRING
        return self._store.put_execution_type(new_type)
def __init__(self,
             phoenix_spec,
             study_name,
             study_owner,
             optimization_goal="minimize",
             optimization_metric="loss",
             connection_config=None):
    """Initializes a new MLMD connection instance.

    When no `connection_config` is given, one is derived from flags if they
    are parsed: the SQLite filename flag first, then the MySQL socket flag.
    Otherwise a SQLite file with a random name under /tmp is used.

    Args:
      phoenix_spec: PhoenixSpec proto.
      study_name: The name of the study.
      study_owner: The owner (username) of the study.
      optimization_goal: minimize or maximize (string).
      optimization_metric: what metric are we optimizing (string).
      connection_config: a metadata_store_pb2.ConnectionConfig() proto. If
        None, we fall back on the flags.
    """
    self._study_name = study_name
    self._study_owner = study_owner
    self._phoenix_spec = phoenix_spec
    self._optimization_goal = optimization_goal
    self._optimization_metric = optimization_metric
    self._connection_config = connection_config

    flags_parsed = FLAGS.is_parsed()
    if not flags_parsed:
        logging.error(
            "Flags are not parsed. Using default in file mlmd database."
            " Please run main with absl.app.run(main) to fix this. "
            "If running in distributed mode, this means that the "
            "trainers are not sharing information between one another.")

    if self._connection_config is None:
        if flags_parsed and FLAGS.mlmd_default_sqllite_filename:
            config = metadata_store_pb2.ConnectionConfig()
            config.sqlite.filename_uri = FLAGS.mlmd_default_sqllite_filename
            config.sqlite.connection_mode = 3
        elif flags_parsed and FLAGS.mlmd_socket:
            config = metadata_store_pb2.ConnectionConfig()
            config.mysql.socket = FLAGS.mlmd_socket
            config.mysql.database = FLAGS.mlmd_database
            config.mysql.user = FLAGS.mlmd_user
            config.mysql.password = FLAGS.mlmd_password
        else:
            config = metadata_store_pb2.ConnectionConfig()
            config.sqlite.filename_uri = (
                "/tmp/filedb-%d" % random.randint(0, 1000000))
            config.sqlite.connection_mode = 3
        self._connection_config = config

    self._store = metadata_store.MetadataStore(self._connection_config)

    # Register the "Trial" execution type: an INT id plus STRING fields.
    trial_type = metadata_store_pb2.ExecutionType(name="Trial")
    trial_type.properties["id"] = metadata_store_pb2.INT
    for string_property in ("state", "serialized_data", "model_dir",
                            "evaluation"):
        trial_type.properties[string_property] = metadata_store_pb2.STRING
    self._trial_type_id = self._store.put_execution_type(trial_type)
    self._trial_id_to_run_id = {}
def _write_test_execution(mlmd_handle):
    """Writes one execution of type 'foo' (version 'bar') and returns the
    execution as read back from the store."""
    foo_type = metadata_store_pb2.ExecutionType(name='foo', version='bar')
    type_id = mlmd_handle.store.put_execution_type(foo_type)

    new_execution = metadata_store_pb2.Execution(type_id=type_id)
    [execution_id] = mlmd_handle.store.put_executions([new_execution])

    [stored_execution] = mlmd_handle.store.get_executions_by_id(
        [execution_id])
    return stored_execution
def setUp(self):
    """Builds two ResolutionContexts: a fully populated one and one with
    empty dicts and None for executor spec and platform config."""
    super(PlaceholderUtilsTest, self).setUp()

    def _infra_validator_node():
        # Each context gets its own PipelineNode proto instance.
        return pipeline_pb2.PipelineNode(
            node_info=pipeline_pb2.NodeInfo(
                type=metadata_store_pb2.ExecutionType(
                    name="infra_validator")))

    examples_artifact = standard_artifacts.Examples()
    examples_artifact.uri = "/tmp"
    examples_artifact.split_names = artifact_utils.encode_split_names(
        ["train", "eval"])

    self._serving_spec = infra_validator_pb2.ServingSpec()
    self._serving_spec.tensorflow_serving.tags.extend(
        ["latest", "1.15.0-gpu"])

    populated_info = data_types.ExecutionInfo(
        input_dict={
            "model": [standard_artifacts.Model()],
            "examples": [examples_artifact],
        },
        output_dict={"blessing": [standard_artifacts.ModelBlessing()]},
        exec_properties={
            "proto_property": proto_utils.proto_to_json(self._serving_spec)
        },
        execution_output_uri="test_executor_output_uri",
        stateful_working_dir="test_stateful_working_dir",
        pipeline_node=_infra_validator_node(),
        pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id"))
    self._resolution_context = placeholder_utils.ResolutionContext(
        exec_info=populated_info,
        executor_spec=executable_spec_pb2.PythonClassExecutableSpec(
            class_path="test_class_path"),
    )

    # Resolution context to simulate missing optional values.
    empty_info = data_types.ExecutionInfo(
        input_dict={},
        output_dict={},
        exec_properties={},
        pipeline_node=_infra_validator_node(),
        pipeline_info=pipeline_pb2.PipelineInfo(id="test_pipeline_id"))
    self._none_resolution_context = placeholder_utils.ResolutionContext(
        exec_info=empty_info,
        executor_spec=None,
        platform_config=None)
def test_put_execution_with_context(self):
    """Puts an execution with two artifacts, one event and one context in a
    single call, then checks the stored rows and the links between them."""
    store = _get_metadata_store()

    execution_type_id = store.put_execution_type(
        metadata_store_pb2.ExecutionType(name=self._get_test_type_name()))
    execution = metadata_store_pb2.Execution(type_id=execution_type_id)

    artifact_type_id = store.put_artifact_type(
        metadata_store_pb2.ArtifactType(name=self._get_test_type_name()))
    input_artifact = metadata_store_pb2.Artifact(type_id=artifact_type_id)
    output_artifact = metadata_store_pb2.Artifact(type_id=artifact_type_id)
    output_event = metadata_store_pb2.Event(
        type=metadata_store_pb2.Event.DECLARED_INPUT)

    context_type_id = store.put_context_type(
        metadata_store_pb2.ContextType(name=self._get_test_type_name()))
    context_name = self._get_test_type_name()
    context = metadata_store_pb2.Context(
        type_id=context_type_id, name=context_name)

    execution_id, artifact_ids, context_ids = store.put_execution(
        execution,
        [[input_artifact], [output_artifact, output_event]],
        [context])

    # Test artifacts & events are correctly inserted.
    self.assertLen(artifact_ids, 2)
    stored_events = store.get_events_by_execution_ids([execution_id])
    self.assertLen(stored_events, 1)

    # Test the context is correctly inserted.
    stored_contexts = store.get_contexts_by_id(context_ids)
    self.assertLen(context_ids, 1)
    self.assertLen(stored_contexts, 1)

    # Test the association link between execution and the context is correct.
    execution_contexts = store.get_contexts_by_execution(execution_id)
    self.assertLen(execution_contexts, 1)
    self.assertEqual(execution_contexts[0].name, context_name)
    self.assertEqual(execution_contexts[0].type_id, context_type_id)
    context_executions = store.get_executions_by_context(context_ids[0])
    self.assertLen(context_executions, 1)

    # Test the attribution links between artifacts and the context are
    # correct.
    artifact_contexts = store.get_contexts_by_artifact(artifact_ids[0])
    self.assertLen(artifact_contexts, 1)
    self.assertEqual(artifact_contexts[0].name, context_name)
    self.assertEqual(artifact_contexts[0].type_id, context_type_id)
    context_artifacts = store.get_artifacts_by_context(context_ids[0])
    self.assertLen(context_artifacts, 2)
def testPutExecutionGraph(self):
    """Puts an execution with one input and one output artifact and checks
    the stored output artifact, the events, and the context links."""
    with metadata.Metadata(connection_config=self._connection_config) as m:
        # Prepares an input artifact. The artifact should be registered in
        # MLMD before the put_execution call.
        examples = standard_artifacts.Examples()
        examples.uri = 'example'
        examples.type_id = common_utils.register_type_if_not_exist(
            m, examples.artifact_type).id
        [examples.id] = m.store.put_artifacts([examples.mlmd_artifact])

        # Prepares an output artifact.
        model = standard_artifacts.Model()
        model.uri = 'model'

        prepared = execution_lib.prepare_execution(
            m,
            metadata_store_pb2.ExecutionType(name='my_execution_type'),
            exec_properties={
                'p1': 1,
                'p2': '2'
            },
            state=metadata_store_pb2.Execution.COMPLETE)
        contexts = self._generate_contexts(m)
        execution = execution_lib.put_execution(
            m,
            prepared,
            contexts,
            input_artifacts={'example': [examples]},
            output_artifacts={'model': [model]})

        stored_model = m.store.get_artifacts_by_id([model.id])[0]
        self.assertProtoPartiallyEquals(
            model.mlmd_artifact,
            stored_model,
            ignored_fields=[
                'create_time_since_epoch', 'last_update_time_since_epoch'
            ])

        # Verifies edges between artifacts and execution.
        [input_event] = m.store.get_events_by_artifact_ids([examples.id])
        self.assertEqual(input_event.execution_id, execution.id)
        self.assertEqual(input_event.type, metadata_store_pb2.Event.INPUT)
        [output_event] = m.store.get_events_by_artifact_ids([model.id])
        self.assertEqual(output_event.execution_id, execution.id)
        self.assertEqual(output_event.type, metadata_store_pb2.Event.OUTPUT)

        # Verifies edges connecting contexts and {artifacts, execution}.
        expected_ids = [context.id for context in contexts]
        input_context_ids = [
            c.id for c in m.store.get_contexts_by_artifact(examples.id)
        ]
        self.assertCountEqual(input_context_ids, expected_ids)
        output_context_ids = [
            c.id for c in m.store.get_contexts_by_artifact(model.id)
        ]
        self.assertCountEqual(output_context_ids, expected_ids)
        execution_context_ids = [
            c.id for c in m.store.get_contexts_by_execution(execution.id)
        ]
        self.assertCountEqual(execution_context_ids, expected_ids)
def test_put_execution_type_get_execution_type(self):
    """Puts an execution type and looks it up again by name."""
    store = _get_metadata_store()

    new_type = metadata_store_pb2.ExecutionType(name="test_type_1")
    new_type.properties["foo"] = metadata_store_pb2.INT
    new_type.properties["bar"] = metadata_store_pb2.STRING
    type_id = store.put_execution_type(new_type)

    stored_type = store.get_execution_type("test_type_1")
    self.assertEqual(stored_type.id, type_id)
    self.assertEqual(stored_type.name, "test_type_1")
def get_or_create_execution_type(
        store, type_name,
        properties: dict = None) -> metadata_store_pb2.ExecutionType:
    """Returns the execution type named `type_name`, creating it on failure.

    Args:
      store: object providing `get_execution_type(type_name=...)` and
        `put_execution_type(execution_type)`.
      type_name: name of the execution type to look up or create.
      properties: passed as `properties` when a new type has to be created.

    Returns:
      The type returned by the lookup, or the newly created type with its
      `id` set to the value returned by `put_execution_type`.
    """
    try:
        return store.get_execution_type(type_name=type_name)
    except Exception:
        # A failed lookup is treated as "type does not exist yet". Only
        # Exception is caught, so KeyboardInterrupt / SystemExit propagate
        # instead of being turned into a write to the store.
        # NOTE(review): the store's specific "not found" error type is not
        # visible here; narrow this further if it is importable.
        execution_type = metadata_store_pb2.ExecutionType(
            name=type_name,
            properties=properties,
        )
        execution_type.id = store.put_execution_type(
            execution_type)  # Returns ID
        return execution_type
def __init__(self,
             phoenix_spec,
             study_name,
             study_owner,
             optimization_goal="minimize",
             optimization_metric="loss",
             connection_config=None):
    """Initializes a new MLMD connection instance.

    When no `connection_config` is given, the SQLite filename flag is tried
    first, then the MySQL socket flag; with neither set, a SQLite file with
    a random name under /tmp is used.

    Args:
      phoenix_spec: PhoenixSpec proto.
      study_name: The name of the study.
      study_owner: The owner (username) of the study.
      optimization_goal: minimize or maximize (string).
      optimization_metric: what metric are we optimizing (string).
      connection_config: a metadata_store_pb2.ConnectionConfig() proto. If
        None, we fall back on the flags.
    """
    self._study_name = study_name
    self._study_owner = study_owner
    self._phoenix_spec = phoenix_spec
    self._optimization_goal = optimization_goal
    self._optimization_metric = optimization_metric
    self._connection_config = connection_config

    if self._connection_config is None:
        if FLAGS.mlmd_default_sqllite_filename:
            derived = metadata_store_pb2.ConnectionConfig()
            derived.sqlite.filename_uri = FLAGS.mlmd_default_sqllite_filename
            derived.sqlite.connection_mode = 3
        elif FLAGS.mlmd_socket:
            derived = metadata_store_pb2.ConnectionConfig()
            derived.mysql.socket = FLAGS.mlmd_socket
            derived.mysql.database = FLAGS.mlmd_database
            derived.mysql.user = FLAGS.mlmd_user
            derived.mysql.password = FLAGS.mlmd_password
        else:
            derived = metadata_store_pb2.ConnectionConfig()
            derived.sqlite.filename_uri = (
                "/tmp/filedb-%d" % random.randint(0, 1000000))
            derived.sqlite.connection_mode = 3
        self._connection_config = derived

    self._store = metadata_store.MetadataStore(self._connection_config)

    # Register the "Trial" execution type and remember its id.
    property_types = (
        ("id", metadata_store_pb2.INT),
        ("state", metadata_store_pb2.STRING),
        ("serialized_data", metadata_store_pb2.STRING),
        ("model_dir", metadata_store_pb2.STRING),
        ("evaluation", metadata_store_pb2.STRING),
    )
    trial_type = metadata_store_pb2.ExecutionType()
    trial_type.name = "Trial"
    for property_name, property_type in property_types:
        trial_type.properties[property_name] = property_type
    self._trial_type_id = self._store.put_execution_type(trial_type)
    self._trial_id_to_run_id = {}
def test_put_execution_type_with_update_get_execution_type(self):
    """Puts a type, re-puts it with an extra property and
    `can_add_fields=True`, then checks the stored type has both properties."""
    store = _get_metadata_store()

    initial_type = metadata_store_pb2.ExecutionType(name="test_type")
    initial_type.properties["foo"] = metadata_store_pb2.DOUBLE
    type_id = store.put_execution_type(initial_type)

    extended_type = metadata_store_pb2.ExecutionType(
        id=type_id, name="test_type")
    extended_type.properties["foo"] = metadata_store_pb2.DOUBLE
    extended_type.properties["new_property"] = metadata_store_pb2.INT
    store.put_execution_type(extended_type, can_add_fields=True)

    stored_type = store.get_execution_type("test_type")
    self.assertEqual(stored_type.id, type_id)
    self.assertEqual(stored_type.name, "test_type")
    self.assertEqual(stored_type.properties["foo"],
                     metadata_store_pb2.DOUBLE)
    self.assertEqual(stored_type.properties["new_property"],
                     metadata_store_pb2.INT)
def testGetExecutionsAssociatedWithAllContexts(self):
    """Checks that only executions linked to every requested context are
    returned by `get_executions_associated_with_all_contexts`."""
    with metadata.Metadata(connection_config=self._connection_config) as m:
        contexts = self._generate_contexts(m)
        self.assertLen(contexts, 2)

        def _put(state, linked_contexts):
            # Prepares and stores one execution of the shared test type.
            prepared = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                state)
            return execution_lib.put_execution(m, prepared, linked_contexts)

        # Create 2 executions and associate with one context each.
        execution1 = _put(metadata_store_pb2.Execution.RUNNING,
                          [contexts[0]])
        execution2 = _put(metadata_store_pb2.Execution.COMPLETE,
                          [contexts[1]])

        # Create another execution and associate with both contexts.
        execution3 = _put(metadata_store_pb2.Execution.NEW, contexts)

        # Verify that the right executions are returned.
        cases = (
            ((0,), [contexts[0]], [execution1.id, execution3.id]),
            ((1,), [contexts[1]], [execution2.id, execution3.id]),
            ((0, 1), contexts, [execution3.id]),
        )
        for label, queried_contexts, expected_ids in cases:
            with self.subTest(for_contexts=label):
                executions = (
                    execution_lib.get_executions_associated_with_all_contexts(
                        m, queried_contexts))
                self.assertCountEqual(expected_ids,
                                      [e.id for e in executions])
def _set_up_test_execution_info(self,
                                input_dict=None,
                                output_dict=None,
                                exec_properties=None):
    """Returns an ExecutionInfo for a 'Docker_executor' node; any falsy
    dict argument is replaced by an empty dict."""
    docker_node = pipeline_pb2.PipelineNode(
        node_info=pipeline_pb2.NodeInfo(
            type=metadata_store_pb2.ExecutionType(name='Docker_executor')))
    test_pipeline = pipeline_pb2.PipelineInfo(id='test_pipeline_id')
    return data_types.ExecutionInfo(
        input_dict=input_dict if input_dict else {},
        output_dict=output_dict if output_dict else {},
        exec_properties=exec_properties if exec_properties else {},
        execution_output_uri='/testing/executor/output/',
        stateful_working_dir='/testing/stateful/dir',
        pipeline_node=docker_node,
        pipeline_info=test_pipeline)
def testGetArtifactsDict(self):
    """Checks that `get_artifacts_dict` returns input and output artifacts
    under their keys and in the order they were put."""
    with metadata.Metadata(connection_config=self._connection_config) as m:

        def _shuffled_artifacts(artifact_cls, uri_template, count):
            # Builds `count` artifacts with registered types, then shuffles.
            built = []
            for index in range(count):
                item = artifact_cls()
                item.uri = uri_template.format(index)
                item.type_id = common_utils.register_type_if_not_exist(
                    m, item.artifact_type).id
                built.append(item)
            random.shuffle(built)
            return built

        # Create and shuffle a few artifacts. The shuffled order should be
        # retained in the output of `execution_lib.get_artifacts_dict`.
        input_examples = _shuffled_artifacts(standard_artifacts.Examples,
                                             'example{}', 10)
        output_models = _shuffled_artifacts(standard_artifacts.Model,
                                            'model{}', 8)
        m.store.put_artifacts([
            a.mlmd_artifact
            for a in itertools.chain(input_examples, output_models)
        ])

        prepared = execution_lib.prepare_execution(
            m,
            metadata_store_pb2.ExecutionType(name='my_execution_type'),
            state=metadata_store_pb2.Execution.RUNNING)
        contexts = self._generate_contexts(m)
        execution = execution_lib.put_execution(
            m,
            prepared,
            contexts,
            input_artifacts={'examples': input_examples},
            output_artifacts={'model': output_models})

        # Verify that the same artifacts are returned in the correct order.
        inputs_read = execution_lib.get_artifacts_dict(
            m, execution.id, metadata_store_pb2.Event.INPUT)
        self.assertCountEqual(['examples'], list(inputs_read.keys()))
        self.assertEqual([ex.uri for ex in input_examples],
                         [a.uri for a in inputs_read['examples']])

        outputs_read = execution_lib.get_artifacts_dict(
            m, execution.id, metadata_store_pb2.Event.OUTPUT)
        self.assertCountEqual(['model'], list(outputs_read.keys()))
        self.assertEqual([model.uri for model in output_models],
                         [a.uri for a in outputs_read['model']])
def test_put_events_no_artifact_id(self):
    """Expects InvalidArgumentError when an event without an artifact id is
    put."""
    # No execution_id throws the same error type, so we just test this.
    store = _get_metadata_store()

    execution_type_id = store.put_execution_type(
        metadata_store_pb2.ExecutionType(name="execution_type"))
    [execution_id] = store.put_executions(
        [metadata_store_pb2.Execution(type_id=execution_type_id)])

    incomplete_event = metadata_store_pb2.Event(
        type=metadata_store_pb2.Event.DECLARED_OUTPUT,
        execution_id=execution_id)
    with self.assertRaises(errors.InvalidArgumentError):
        store.put_events([incomplete_event])
def test_put_executions_get_executions_by_id(self):
    """Puts an execution with an int and a string property and reads both
    values back by execution id."""
    store = _get_metadata_store()

    typed = metadata_store_pb2.ExecutionType(name="test_type_1")
    typed.properties["foo"] = metadata_store_pb2.INT
    typed.properties["bar"] = metadata_store_pb2.STRING
    type_id = store.put_execution_type(typed)

    new_execution = metadata_store_pb2.Execution(type_id=type_id)
    new_execution.properties["foo"].int_value = 3
    new_execution.properties["bar"].string_value = "Hello"
    [execution_id] = store.put_executions([new_execution])

    [stored] = store.get_executions_by_id([execution_id])
    self.assertEqual(stored.properties["bar"].string_value, "Hello")
    self.assertEqual(stored.properties["foo"].int_value, 3)
def testGetCachedOutputArtifactsForNodesWithNoOuput(self):
    """Checks cached outputs before and after a succeeded execution with no
    outputs is published under the cache context."""
    with metadata.Metadata(connection_config=self._connection_config) as m:
        cache_context = context_lib.register_context_if_not_exists(
            m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')

        # No succeeded execution is associated with this context yet, so the
        # cached output is None.
        outputs_before = cache_utils.get_cached_outputs(m, cache_context)
        self.assertIsNone(outputs_before)

        registered = execution_publish_utils.register_execution(
            m, metadata_store_pb2.ExecutionType(name='my_type'),
            [cache_context])
        execution_publish_utils.publish_succeeded_execution(
            m, registered.id, [cache_context])

        # A succeeded execution is now associated with this context, so the
        # cached output is not None but an empty dict.
        outputs_after = cache_utils.get_cached_outputs(m, cache_context)
        self.assertIsNotNone(outputs_after)
        self.assertEmpty(outputs_after)
def test_update_execution_get_execution(self):
    """Puts an execution, puts it again under the same id with new property
    values, and checks the stored execution has the new values."""
    store = _get_metadata_store()

    typed = metadata_store_pb2.ExecutionType(name="test_type_1")
    typed.properties["foo"] = metadata_store_pb2.INT
    typed.properties["bar"] = metadata_store_pb2.STRING
    type_id = store.put_execution_type(typed)

    first_version = metadata_store_pb2.Execution(type_id=type_id)
    first_version.properties["bar"].string_value = "Hello"
    [execution_id] = store.put_executions([first_version])

    second_version = metadata_store_pb2.Execution(
        id=execution_id, type_id=type_id)
    second_version.properties["foo"].int_value = 12
    second_version.properties["bar"].string_value = "Goodbye"
    [updated_id] = store.put_executions([second_version])
    self.assertEqual(execution_id, updated_id)

    [stored] = store.get_executions_by_id([execution_id])
    self.assertEqual(stored.properties["bar"].string_value, "Goodbye")
    self.assertEqual(stored.properties["foo"].int_value, 12)
def test_publish_execution(self):
    """Puts an execution with two artifacts and one event in a single call
    and checks the artifact and event counts."""
    store = _get_metadata_store()

    execution_type_id = store.put_execution_type(
        metadata_store_pb2.ExecutionType(name="execution_type"))
    execution = metadata_store_pb2.Execution(type_id=execution_type_id)

    artifact_type_id = store.put_artifact_type(
        metadata_store_pb2.ArtifactType(name="artifact_type"))
    input_artifact = metadata_store_pb2.Artifact(type_id=artifact_type_id)
    output_artifact = metadata_store_pb2.Artifact(type_id=artifact_type_id)
    output_event = metadata_store_pb2.Event(
        type=metadata_store_pb2.Event.DECLARED_INPUT)

    artifact_event_pairs = [[input_artifact], [output_artifact, output_event]]
    execution_id, artifact_ids = store.put_execution(execution,
                                                     artifact_event_pairs)
    self.assertLen(artifact_ids, 2)

    stored_events = store.get_events_by_execution_ids([execution_id])
    self.assertLen(stored_events, 1)
def _cache_and_publish(self,
                       existing_execution: metadata_store_pb2.Execution):
    """Publishes a cached execution that reuses `existing_execution`'s
    outputs.

    A new execution is registered unless one already exists for the cached
    execution contexts; if the last such execution is already CACHED,
    nothing is done.
    """
    cache_contexts = self._get_cached_execution_contexts(existing_execution)
    # Check if there are any previous attempts to cache and publish.
    earlier_attempts = (
        execution_lib.get_executions_associated_with_all_contexts(
            self._mlmd, contexts=cache_contexts))

    if earlier_attempts:
        if len(earlier_attempts) > 1:
            logging.warning(
                'More than one previous cache executions seen when attempting '
                'reuse_node_outputs: %s', earlier_attempts)
        latest_attempt = earlier_attempts[-1]
        if (latest_attempt.last_known_state ==
                metadata_store_pb2.Execution.CACHED):
            return
        target_execution = latest_attempt
    else:
        target_execution = execution_publish_utils.register_execution(
            self._mlmd,
            execution_type=metadata_store_pb2.ExecutionType(
                id=existing_execution.type_id),
            contexts=cache_contexts)

    reused_outputs = execution_lib.get_artifacts_dict(
        self._mlmd,
        existing_execution.id,
        event_types=list(event_lib.VALID_OUTPUT_EVENT_TYPES))
    execution_publish_utils.publish_cached_execution(
        self._mlmd,
        contexts=cache_contexts,
        execution_id=target_execution.id,
        output_artifacts=reused_outputs)
def testGetArtifactIdsForExecutionIdGroupedByEventType(self):
    """Checks that artifact ids of an execution are grouped under INPUT and
    OUTPUT event types."""
    with metadata.Metadata(connection_config=self._connection_config) as m:
        # Register an input and output artifacts in MLMD.
        examples = standard_artifacts.Examples()
        examples.uri = 'example'
        examples.type_id = common_utils.register_type_if_not_exist(
            m, examples.artifact_type).id
        model = standard_artifacts.Model()
        model.uri = 'model'
        model.type_id = common_utils.register_type_if_not_exist(
            m, model.artifact_type).id
        [examples.id, model.id] = m.store.put_artifacts(
            [examples.mlmd_artifact, model.mlmd_artifact])

        prepared = execution_lib.prepare_execution(
            m,
            metadata_store_pb2.ExecutionType(name='my_execution_type'),
            exec_properties={
                'p1': 1,
                'p2': '2'
            },
            state=metadata_store_pb2.Execution.COMPLETE)
        contexts = self._generate_contexts(m)
        execution = execution_lib.put_execution(
            m,
            prepared,
            contexts,
            input_artifacts={'example': [examples]},
            output_artifacts={'model': [model]})

        grouped_ids = (
            execution_lib.get_artifact_ids_by_event_type_for_execution_id(
                m, execution.id))
        expected = {
            metadata_store_pb2.Event.INPUT: {examples.id},
            metadata_store_pb2.Event.OUTPUT: {model.id},
        }
        self.assertDictEqual(expected, grouped_ids)
def test_put_events_with_paths_same_artifact(self):
    """Puts two events with different path keys for one artifact and two
    executions, then checks each stored event keeps its own path."""
    store = _get_metadata_store()

    execution_type_id = store.put_execution_type(
        metadata_store_pb2.ExecutionType(name="execution_type"))
    [execution_id_0, execution_id_1] = store.put_executions([
        metadata_store_pb2.Execution(type_id=execution_type_id),
        metadata_store_pb2.Execution(type_id=execution_type_id),
    ])

    artifact_type_id = store.put_artifact_type(
        metadata_store_pb2.ArtifactType(name="artifact_type"))
    [artifact_id] = store.put_artifacts(
        [metadata_store_pb2.Artifact(type_id=artifact_type_id)])

    def _declared_input(execution_id, step_key):
        # One DECLARED_INPUT event on the shared artifact with a single
        # keyed path step.
        event = metadata_store_pb2.Event(
            type=metadata_store_pb2.Event.DECLARED_INPUT,
            artifact_id=artifact_id,
            execution_id=execution_id)
        event.path.steps.add().key = step_key
        return event

    store.put_events([
        _declared_input(execution_id_0, "ggg"),
        _declared_input(execution_id_1, "fff"),
    ])

    [stored_0, stored_1] = store.get_events_by_artifact_ids([artifact_id])
    self.assertLen(stored_0.path.steps, 1)
    self.assertEqual(stored_0.path.steps[0].key, "ggg")
    self.assertLen(stored_1.path.steps, 1)
    self.assertEqual(stored_1.path.steps[0].key, "fff")
def testPrepareExecution(self):
    """Checks the Execution built by `prepare_execution`: 'p2' is declared
    on the type and lands in properties, 'p1' in custom_properties."""
    with metadata.Metadata(connection_config=self._connection_config) as m:
        declared_type = metadata_store_pb2.ExecutionType()
        text_format.Parse(
            """ name: 'my_execution' properties { key: 'p2' value: STRING } """,
            declared_type)

        prepared = execution_lib.prepare_execution(
            m,
            declared_type,
            exec_properties={
                'p1': 1,
                'p2': '2'
            },
            state=metadata_store_pb2.Execution.COMPLETE)

        self.assertProtoEquals(
            """ type_id: 1 last_known_state: COMPLETE properties { key: 'p2' value { string_value: '2' } } custom_properties { key: 'p1' value { int_value: 1 } } """,
            prepared)
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import status as status_lib
from ml_metadata.proto import metadata_store_pb2

# Name used for the orchestrator execution type declared below.
_ORCHESTRATOR_RESERVED_ID = '__ORCHESTRATOR__'
# String keys and key prefixes.
# NOTE(review): presumably MLMD (custom) property names on the orchestrator
# execution — confirm against the code that reads and writes them.
_PIPELINE_IR = 'pipeline_ir'
_STOP_INITIATED = 'stop_initiated'
_PIPELINE_RUN_ID = 'pipeline_run_id'
_PIPELINE_STATUS_CODE = 'pipeline_status_code'
_PIPELINE_STATUS_MSG = 'pipeline_status_msg'
_NODE_STOP_INITIATED_PREFIX = 'node_stop_initiated_'
_NODE_STATUS_CODE_PREFIX = 'node_status_code_'
_NODE_STATUS_MSG_PREFIX = 'node_status_msg_'
# Execution type named after the reserved id, declaring one STRING property
# keyed by _PIPELINE_IR.
_ORCHESTRATOR_EXECUTION_TYPE = metadata_store_pb2.ExecutionType(
    name=_ORCHESTRATOR_RESERVED_ID,
    properties={_PIPELINE_IR: metadata_store_pb2.STRING})

# Module-level state for state-change tracking.
# NOTE(review): -1.0 presumably means "no state change recorded yet", and the
# lock presumably guards the timestamp — neither use is visible here.
_last_state_change_time_secs = -1.0
_state_change_time_lock = threading.Lock()


def record_state_change_time() -> None:
  """Records current time at the point of function call as state change time.

  This function may be called after any operation that changes pipeline state
  or node execution state that requires further processing in the next
  iteration of the orchestration loop. As an optimization, the orchestration
  loop can elide wait period in between iterations when such state change is
  detected.
  """
  # NOTE(review): only the `global` declaration is visible in this view; the
  # statements that update the timestamp appear to be cut off.
  global _last_state_change_time_secs
def _prepare_execution_type(self, type_name: Text,
                            exec_properties: Dict[Text, Any]) -> int:
    """Gets execution type given execution type name and properties.

    Uses the existing type if its schema is a superset of what is needed.
    Otherwise tries to register a new execution type (or extend the existing
    one) with the required properties.

    Args:
      type_name: the name of the execution type
      exec_properties: the execution properties included by the execution

    Returns:
      execution type id

    Raises:
      RuntimeError: if the store returns None for `type_name`.
      AssertionError: if `exec_properties` contains a reserved key.
      ValueError: if new execution type conflicts with existing schema in
        MLMD.
    """
    # Bound up front so the conflict message in the AlreadyExistsError
    # handler can always be formatted, even when the lookup itself raised
    # NotFoundError and never assigned this name.
    existing_execution_type = None
    try:
        existing_execution_type = self.store.get_execution_type(type_name)
        if existing_execution_type is None:
            raise RuntimeError('Execution type is None for %s.' % type_name)
        if all(k in existing_execution_type.properties
               for k in exec_properties):
            return existing_execution_type.id
        # The existing schema lacks some keys: take the same registration
        # path as for a type that does not exist at all.
        raise tf.errors.NotFoundError(
            None, None, 'No qualified execution type found.')
    except tf.errors.NotFoundError:
        execution_type = metadata_store_pb2.ExecutionType(name=type_name)
        execution_type.properties[
            _EXECUTION_TYPE_KEY_STATE] = metadata_store_pb2.STRING
        # If exec_properties contains new entries, execution type schema
        # will be updated in MLMD.
        for k in exec_properties:
            assert k not in _EXECUTION_TYPE_RESERVED_KEYS, (
                'execution properties with reserved key %s') % k
            execution_type.properties[k] = metadata_store_pb2.STRING
        # TODO(ruoyu): Find a better place / solution to the checksum logic.
        # The checksum property is declared only when a module file is among
        # the execution properties.
        if 'module_file' in exec_properties:
            execution_type.properties[
                _EXECUTION_TYPE_KEY_CHECKSUM] = metadata_store_pb2.STRING
        execution_type.properties[
            _EXECUTION_TYPE_KEY_PIPELINE_NAME] = metadata_store_pb2.STRING
        execution_type.properties[
            _EXECUTION_TYPE_KEY_PIPELINE_ROOT] = metadata_store_pb2.STRING
        execution_type.properties[
            _EXECUTION_TYPE_KEY_RUN_ID] = metadata_store_pb2.STRING
        execution_type.properties[
            _EXECUTION_TYPE_KEY_COMPONENT_ID] = metadata_store_pb2.STRING

        try:
            execution_type_id = self.store.put_execution_type(
                execution_type=execution_type, can_add_fields=True)
            absl.logging.debug(
                'Registering a new execution type with id %s.',
                execution_type_id)
            return execution_type_id
        except tf.errors.AlreadyExistsError as e:
            warning_str = (
                'missing or modified key in exec_properties comparing with '
                'existing execution type with the same type name. Existing type: '
                '%s, New type: %s') % (existing_execution_type, execution_type)
            absl.logging.warning(warning_str)
            # Chain the store error so the original cause stays visible.
            raise ValueError(warning_str) from e
# Register the schema artifact type in the Metadata Store; the call returns
# the id assigned to the type.
schema_artifact_type_id = store.put_artifact_type(schema_artifact_type)

print('Data artifact type:\n', data_artifact_type)
print('Schema artifact type:\n', schema_artifact_type)
print('Data artifact type ID:', data_artifact_type_id)
print('Schema artifact type ID:', schema_artifact_type_id)


# ## Register ExecutionType
#
# Next, create the execution types you need. For this simple setup, declare a single type for the data validation component, with a `state` property that records whether the process is running or already completed.

# In[5]:


# ExecutionType for the Data Validation component, declared in one step with
# its name and a string-valued `state` property.
dv_execution_type = metadata_store_pb2.ExecutionType(
    name='Data Validation',
    properties={'state': metadata_store_pb2.STRING})

# Register the execution type in the Metadata Store and keep its id.
dv_execution_type_id = store.put_execution_type(dv_execution_type)

print('Data validation execution type:\n', dv_execution_type)
print('Data validation execution type ID:', dv_execution_type_id)


# ## Generate input artifact unit
#
# With the artifact types created, you can now create instances of those types. The next cell creates the artifact for the input dataset and records it in the metadata store through the `put_artifacts()` function. As before, the call generates an `id` that can be used for reference.

# In[6]: