def testGetArtifactsDict(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    # Create and shuffle a few artifacts. The shuffled order should be
    # retained in the output of `execution_lib.get_artifacts_dict`.
    input_examples = []
    for i in range(10):
      input_example = standard_artifacts.Examples()
      input_example.uri = 'example{}'.format(i)
      input_example.type_id = common_utils.register_type_if_not_exist(
          m, input_example.artifact_type).id
      input_examples.append(input_example)
    random.shuffle(input_examples)
    output_models = []
    for i in range(8):
      output_model = standard_artifacts.Model()
      output_model.uri = 'model{}'.format(i)
      output_model.type_id = common_utils.register_type_if_not_exist(
          m, output_model.artifact_type).id
      output_models.append(output_model)
    random.shuffle(output_models)
    m.store.put_artifacts([
        a.mlmd_artifact
        for a in itertools.chain(input_examples, output_models)
    ])
    execution = execution_lib.prepare_execution(
        m,
        metadata_store_pb2.ExecutionType(name='my_execution_type'),
        state=metadata_store_pb2.Execution.RUNNING)
    contexts = self._generate_contexts(m)
    input_artifacts_dict = {'examples': input_examples}
    output_artifacts_dict = {'model': output_models}
    execution = execution_lib.put_execution(
        m,
        execution,
        contexts,
        input_artifacts=input_artifacts_dict,
        output_artifacts=output_artifacts_dict)

    # Verify that the same artifacts are returned in the correct order.
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, metadata_store_pb2.Event.INPUT)
    self.assertCountEqual(['examples'], list(artifacts_dict.keys()))
    self.assertEqual([ex.uri for ex in input_examples],
                     [a.uri for a in artifacts_dict['examples']])
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, metadata_store_pb2.Event.OUTPUT)
    self.assertCountEqual(['model'], list(artifacts_dict.keys()))
    self.assertEqual([model.uri for model in output_models],
                     [a.uri for a in artifacts_dict['model']])
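
# Condensed sketch of the round trip exercised by the test above: publish an
# execution with input/output artifact dicts, then read the inputs back. This
# is illustrative only; the helper name and the `execution_type_name`
# parameter are hypothetical, and `m`, `contexts`, and the artifact dicts are
# assumed to be prepared as in the test (artifact types registered, artifacts
# stored via `put_artifacts`).
def _put_and_read_back_inputs(m, contexts, input_artifacts_dict,
                              output_artifacts_dict,
                              execution_type_name='sketch_execution_type'):
  """Publishes an execution and reads its input artifacts back from MLMD."""
  execution = execution_lib.prepare_execution(
      m,
      metadata_store_pb2.ExecutionType(name=execution_type_name),
      state=metadata_store_pb2.Execution.RUNNING)
  execution = execution_lib.put_execution(
      m,
      execution,
      contexts,
      input_artifacts=input_artifacts_dict,
      output_artifacts=output_artifacts_dict)
  # Input events preserve the per-key artifact order used at publish time.
  return execution_lib.get_artifacts_dict(m, execution.id,
                                          metadata_store_pb2.Event.INPUT)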
def get_all_node_artifacts(
    pipeline: pipeline_pb2.Pipeline, mlmd_handle: metadata.Metadata
) -> Dict[str, Dict[int, Dict[str, List[metadata_store_pb2.Artifact]]]]:
  """Returns all output artifacts of all pipeline nodes if present.

  Args:
    pipeline: Pipeline proto associated with a `PipelineState` object.
    mlmd_handle: Handle to MLMD db.

  Returns:
    Dict of node id to Dict of execution id to Dict of key to output artifact
    list.
  """
  executions_dict = get_all_node_executions(pipeline, mlmd_handle)
  result = {}
  for node_id, executions in executions_dict.items():
    node_artifacts = {}
    for execution in executions:
      execution_artifacts = {}
      for key, artifacts in execution_lib.get_artifacts_dict(
          mlmd_handle, execution.id, [
              metadata_store_pb2.Event.OUTPUT,
              metadata_store_pb2.Event.DECLARED_OUTPUT
          ]).items():
        execution_artifacts[key] = [
            artifact.mlmd_artifact for artifact in artifacts
        ]
      node_artifacts[execution.id] = execution_artifacts
    result[node_id] = node_artifacts
  return result
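
# Minimal usage sketch for the nested mapping returned above; `pipeline` and
# `mlmd_handle` are assumed to come from the surrounding orchestration code
# (e.g. a `PipelineState`). The helper name is hypothetical.
def _iter_output_artifact_uris(pipeline, mlmd_handle):
  """Yields (node_id, execution_id, key, uri) for every output artifact."""
  for node_id, executions in get_all_node_artifacts(pipeline,
                                                    mlmd_handle).items():
    for execution_id, artifacts_by_key in executions.items():
      for key, artifacts in artifacts_by_key.items():
        for artifact in artifacts:
          yield node_id, execution_id, key, artifact.uri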
def _generate_task_from_execution(
    metadata_handler: metadata.Metadata, pipeline: pipeline_pb2.Pipeline,
    node: pipeline_pb2.PipelineNode,
    execution: metadata_store_pb2.Execution) -> task_lib.Task:
  """Generates `ExecNodeTask` given execution."""
  contexts = metadata_handler.store.get_contexts_by_execution(execution.id)
  exec_properties = _extract_properties(execution)
  input_artifacts = execution_lib.get_artifacts_dict(
      metadata_handler, execution.id, metadata_store_pb2.Event.INPUT)
  outputs_resolver = outputs_utils.OutputsResolver(node, pipeline.pipeline_info,
                                                   pipeline.runtime_spec,
                                                   pipeline.execution_mode)
  return task_lib.ExecNodeTask(
      node_uid=task_lib.NodeUid.from_pipeline_node(pipeline, node),
      execution=execution,
      contexts=contexts,
      exec_properties=exec_properties,
      input_artifacts=input_artifacts,
      output_artifacts=outputs_resolver.generate_output_artifacts(execution.id),
      executor_output_uri=outputs_resolver.get_executor_output_uri(
          execution.id),
      stateful_working_dir=outputs_resolver.get_stateful_working_directory(
          execution.id))
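
# Hypothetical simplification of the `_extract_properties` helper referenced
# above (its real implementation is not shown in this section): the general
# shape is a flattening of the execution's `properties` and
# `custom_properties` maps into a plain dict, unwrapping each
# `metadata_store_pb2.Value` oneof.
def _extract_properties_sketch(execution: metadata_store_pb2.Execution):
  """Returns a key -> raw value dict drawn from both property maps."""
  result = {}
  for key, value in itertools.chain(execution.properties.items(),
                                    execution.custom_properties.items()):
    which = value.WhichOneof('value')
    result[key] = getattr(value, which) if which else None
  return result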
def _cache_and_publish(self,
                       existing_execution: metadata_store_pb2.Execution):
  """Updates MLMD."""
  cached_execution_contexts = self._get_cached_execution_contexts(
      existing_execution)
  # Check if there are any previous attempts to cache and publish.
  prev_cache_executions = (
      execution_lib.get_executions_associated_with_all_contexts(
          self._mlmd, contexts=cached_execution_contexts))
  if not prev_cache_executions:
    new_execution = execution_publish_utils.register_execution(
        self._mlmd,
        execution_type=metadata_store_pb2.ExecutionType(
            id=existing_execution.type_id),
        contexts=cached_execution_contexts)
  else:
    if len(prev_cache_executions) > 1:
      logging.warning(
          'More than one previous cache execution seen when attempting '
          'reuse_node_outputs: %s', prev_cache_executions)
    if (prev_cache_executions[-1].last_known_state ==
        metadata_store_pb2.Execution.CACHED):
      return
    else:
      new_execution = prev_cache_executions[-1]

  output_artifacts = execution_lib.get_artifacts_dict(
      self._mlmd,
      existing_execution.id,
      event_types=list(event_lib.VALID_OUTPUT_EVENT_TYPES))
  execution_publish_utils.publish_cached_execution(
      self._mlmd,
      contexts=cached_execution_contexts,
      execution_id=new_execution.id,
      output_artifacts=output_artifacts)
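
# Condensed sketch of the happy path above when no prior cache attempts
# exist: register a fresh execution against the cached-execution contexts,
# copy the output artifacts of `existing_execution`, and publish them as a
# cached execution. The helper name is hypothetical; `mlmd_handle`,
# `existing_execution`, and `cached_execution_contexts` are assumed to be
# prepared as in `_cache_and_publish`.
def _publish_cache_sketch(mlmd_handle, existing_execution,
                          cached_execution_contexts):
  new_execution = execution_publish_utils.register_execution(
      mlmd_handle,
      execution_type=metadata_store_pb2.ExecutionType(
          id=existing_execution.type_id),
      contexts=cached_execution_contexts)
  output_artifacts = execution_lib.get_artifacts_dict(
      mlmd_handle,
      existing_execution.id,
      event_types=list(event_lib.VALID_OUTPUT_EVENT_TYPES))
  execution_publish_utils.publish_cached_execution(
      mlmd_handle,
      contexts=cached_execution_contexts,
      execution_id=new_execution.id,
      output_artifacts=output_artifacts)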
def testGetArtifactsDict(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    # Create and shuffle a few artifacts. The shuffled order should be
    # retained in the output of `execution_lib.get_artifacts_dict`.
    input_artifact_keys = ('input1', 'input2', 'input3')
    input_artifacts_dict = collections.OrderedDict()
    for input_key in input_artifact_keys:
      input_examples = []
      for i in range(10):
        input_example = standard_artifacts.Examples()
        input_example.uri = f'{input_key}/example{i}'
        input_example.type_id = common_utils.register_type_if_not_exist(
            m, input_example.artifact_type).id
        input_examples.append(input_example)
      random.shuffle(input_examples)
      input_artifacts_dict[input_key] = input_examples

    output_models = []
    for i in range(8):
      output_model = standard_artifacts.Model()
      output_model.uri = f'model{i}'
      output_model.type_id = common_utils.register_type_if_not_exist(
          m, output_model.artifact_type).id
      output_models.append(output_model)
    random.shuffle(output_models)
    output_artifacts_dict = {'model': output_models}

    # Store input artifacts only. Outputs will be saved in put_execution().
    input_mlmd_artifacts = [
        a.mlmd_artifact
        for a in itertools.chain(*input_artifacts_dict.values())
    ]
    artifact_ids = m.store.put_artifacts(input_mlmd_artifacts)
    for artifact_id, mlmd_artifact in zip(artifact_ids,
                                          input_mlmd_artifacts):
      mlmd_artifact.id = artifact_id

    execution = execution_lib.prepare_execution(
        m,
        metadata_store_pb2.ExecutionType(name='my_execution_type'),
        state=metadata_store_pb2.Execution.RUNNING)
    contexts = self._generate_contexts(m)

    # Change the order of the OrderedDict to shuffle the order of input keys.
    input_artifacts_dict.move_to_end('input1')
    execution = execution_lib.put_execution(
        m,
        execution,
        contexts,
        input_artifacts=input_artifacts_dict,
        output_artifacts=output_artifacts_dict)

    # Verify that the same artifacts are returned in the correct order.
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, [metadata_store_pb2.Event.INPUT])
    self.assertEqual(set(input_artifact_keys), set(artifacts_dict.keys()))
    for key in artifacts_dict:
      self.assertEqual([ex.uri for ex in input_artifacts_dict[key]],
                       [a.uri for a in artifacts_dict[key]],
                       f'for key={key}')
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, [metadata_store_pb2.Event.OUTPUT])
    self.assertEqual({'model'}, set(artifacts_dict.keys()))
    self.assertEqual([model.uri for model in output_models],
                     [a.uri for a in artifacts_dict['model']])
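
# Plain-Python illustration of the key reordering used in the test above:
# `move_to_end` pushes 'input1' to the back, so the publish order of the
# input keys differs from their creation order while the per-key artifact
# order is left untouched. The demo function is illustrative only.
def _move_to_end_demo():
  d = collections.OrderedDict([('input1', 1), ('input2', 2), ('input3', 3)])
  d.move_to_end('input1')
  assert list(d) == ['input2', 'input3', 'input1']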