Esempio n. 1
0
def get_pipeline_outputs(
        metadata_connection_config: Optional[
            metadata_store_pb2.ConnectionConfig], pipeline_name: str
) -> Dict[Text, Dict[Text, Dict[int, types.Artifact]]]:
    """Returns a dictionary of pipeline output artifacts for every component.

    Opens MLMD with ``metadata_connection_config`` and takes the executions
    returned by ``pipeline_recorder_utils.get_latest_executions`` for
    ``pipeline_name``. For each execution, every OUTPUT event's artifact is
    fetched, deserialized with ``artifact_utils.deserialize_artifact`` and
    filed under the output key and index recorded in the event path.

    Args:
        metadata_connection_config: connection configuration to MLMD.
        pipeline_name: Name of the pipeline.

    Returns:
        A nested dictionary ``{component_id: {output_key: {index: artifact}}}``.
        ``output_key`` is the first step (``key``) of each OUTPUT event's path
        and ``index`` is its second step (``index``).

    Raises:
        ValueError: if an OUTPUT event's path has no ``key`` step, if an
            artifact's event path has no ``index`` step, or if the same
            ``(output_key, index)`` pair occurs twice for one execution.
    """
    output_map = {}
    with metadata.Metadata(metadata_connection_config) as m:
        executions = pipeline_recorder_utils.get_latest_executions(
            m, pipeline_name)
        for execution in executions:
            component_id = pipeline_recorder_utils.get_component_id_from_execution(
                m, execution)
            output_dict = {}
            for event in m.store.get_events_by_execution_ids([execution.id]):
                if event.type != metadata_store_pb2.Event.OUTPUT:
                    continue
                steps = event.path.steps
                if not steps or not steps[0].HasField('key'):
                    raise ValueError(
                        'Artifact key is not recorded in the MLMD.')
                key = steps[0].key
                # Query MLMD once per event, and only after the key has been
                # validated.
                artifacts = m.store.get_artifacts_by_id([event.artifact_id])
                artifacts_by_index = output_dict.setdefault(key, {})
                for pb_artifact in artifacts:
                    if len(steps) < 2 or not steps[1].HasField('index'):
                        raise ValueError(
                            'Artifact index is not recorded in the MLMD.')
                    artifact_index = steps[1].index
                    if artifact_index in artifacts_by_index:
                        raise ValueError('Artifact already in output_dict')
                    # Exactly one type is expected per type_id; the unpacking
                    # fails loudly otherwise.
                    [artifact_type] = m.store.get_artifact_types_by_id(
                        [pb_artifact.type_id])
                    artifacts_by_index[artifact_index] = (
                        artifact_utils.deserialize_artifact(
                            artifact_type, pb_artifact))
            output_map[component_id] = output_dict
    return output_map
  def testStubbedImdbPipelineBeam(self):
    """Runs the IMDB pipeline with stubbed executors and verifies its outputs.

    First compiles ``self.imdb_pipeline`` to IR, replaces its executors via
    ``pipeline_mock.replace_executor_with_stub`` (presumably with stubs that
    replay the outputs recorded under ``self._recorded_output_dir``), runs it
    and checks that every non-ResolverNode output artifact URI matches the
    recorded directory. Then runs ``self.imdb_pipeline`` itself and compares
    each component's output artifacts against the recorded ones, using a
    per-output-key verifier.
    """
    pipeline_ir = compiler.Compiler().compile(self.imdb_pipeline)

    pipeline_mock.replace_executor_with_stub(pipeline_ir,
                                             self._recorded_output_dir, [])

    # NOTE(review): the compiled IR is passed to `run`; confirm this
    # BeamDagRunner version accepts IR there (newer TFX releases provide
    # `run_with_ir` for compiled IR).
    BeamDagRunner().run(pipeline_ir)

    self.assertTrue(fileio.exists(self._metadata_path))

    metadata_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)

    # Verify that recorded files are successfully copied to the output uris.
    with metadata.Metadata(metadata_config) as m:
      for execution in m.store.get_executions():
        component_id = pipeline_recorder_utils.get_component_id_from_execution(
            m, execution)
        # ResolverNode has no executor, so nothing was recorded for it.
        if component_id.startswith('ResolverNode'):
          continue
        events = m.store.get_events_by_execution_ids([execution.id])
        output_events = [
            x for x in events if x.type == metadata_store_pb2.Event.OUTPUT
        ]
        for event in output_events:
          steps = event.path.steps
          # A TestCase assertion (unlike a bare `assert`) survives
          # `python -O` and is reported as a normal test failure.
          self.assertTrue(steps[0].HasField('key'))
          name = steps[0].key
          artifacts = m.store.get_artifacts_by_id([event.artifact_id])
          for idx, artifact in enumerate(artifacts):
            self.assertDirectoryEqual(
                artifact.uri,
                os.path.join(self._recorded_output_dir, component_id, name,
                             str(idx)))

    # Calls verifier for pipeline output artifacts, excluding the resolver node.
    BeamDagRunner().run(self.imdb_pipeline)
    pipeline_outputs = executor_verifier_utils.get_pipeline_outputs(
        self.imdb_pipeline.metadata_connection_config,
        self._pipeline_name)

    # Output keys with a dedicated comparison; any other key falls back to
    # self._verify_file_path.
    verifier_map = {
        'model': self._verify_model,
        'model_run': self._verify_model,
        'examples': self._verify_examples,
        'schema': self._verify_schema,
        'anomalies': self._verify_anomalies,
        'evaluation': self._verify_evaluation,
        # A subdirectory of updated_analyzer_cache has changing name.
        'updated_analyzer_cache': self._veryify_root_dir,
    }

    # List of components to verify. ResolverNode is ignored because it
    # doesn't have an executor.
    verify_component_ids = [
        component.id
        for component in self.imdb_pipeline.components
        if not component.id.startswith('ResolverNode')
    ]

    for component_id in verify_component_ids:
      # Log once per component rather than once per artifact.
      logging.info('Verifying %s', component_id)
      for key, artifact_dict in pipeline_outputs[component_id].items():
        verifier = verifier_map.get(key, self._verify_file_path)
        for idx, artifact in artifact_dict.items():
          recorded_uri = os.path.join(self._recorded_output_dir, component_id,
                                      key, str(idx))
          verifier(artifact.uri, recorded_uri)
    def testStubbedTaxiPipelineBeam(self):
        """Runs the taxi pipeline with stubbed executors and verifies outputs.

        First compiles ``self.taxi_pipeline`` to IR, replaces its executors
        via ``pipeline_mock.replace_executor_with_stub`` (presumably with
        stubs that replay the outputs recorded under
        ``self._recorded_output_dir``) and runs it with ``run_with_ir``. It
        then checks the MLMD artifact/execution counts and that every
        non-Resolver output artifact URI matches the recorded directory.
        Finally it runs ``self.taxi_pipeline`` itself and compares each
        component's output artifacts against the recorded ones, using a
        per-output-key verifier.
        """
        pipeline_ir = compiler.Compiler().compile(self.taxi_pipeline)

        logging.info('Replacing with test_data_dir:%s',
                     self._recorded_output_dir)
        pipeline_mock.replace_executor_with_stub(pipeline_ir,
                                                 self._recorded_output_dir, [])

        BeamDagRunner().run_with_ir(pipeline_ir)

        self.assertTrue(fileio.exists(self._metadata_path))

        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)

        # Verify that recorded files are successfully copied to the output uris.
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            executions = m.store.get_executions()
            execution_count = len(executions)
            # Artifact count is greater by 7 due to extra artifacts produced by
            # Evaluator(blessing and evaluation), Trainer(model and model_run) and
            # Transform(example, graph, cache, pre_transform_statistics,
            # pre_transform_schema, post_transform_statistics, post_transform_schema,
            # post_transform_anomalies) minus Resolver which doesn't generate
            # new artifact.
            # NOTE(review): the extras listed above sum to +8 (Evaluator +1,
            # Trainer +1, Transform +7, Resolver -1) while the assertion
            # checks +7 -- confirm which breakdown is accurate.
            self.assertEqual(artifact_count, execution_count + 7)
            self.assertLen(self.taxi_pipeline.components, execution_count)

            for execution in executions:
                component_id = pipeline_recorder_utils.get_component_id_from_execution(
                    m, execution)
                # Resolver has no executor, so nothing was recorded for it.
                if component_id.startswith('Resolver'):
                    continue
                events = m.store.get_events_by_execution_ids([execution.id])
                output_events = [
                    x for x in events
                    if x.type == metadata_store_pb2.Event.OUTPUT
                ]
                for event in output_events:
                    steps = event.path.steps
                    self.assertTrue(steps[0].HasField('key'))
                    name = steps[0].key
                    # Distinct name: this is the per-event fetch, not the
                    # store-wide artifact list used for the count above.
                    event_artifacts = m.store.get_artifacts_by_id(
                        [event.artifact_id])
                    for idx, artifact in enumerate(event_artifacts):
                        self.assertDirectoryEqual(
                            artifact.uri,
                            os.path.join(self._recorded_output_dir,
                                         component_id, name, str(idx)))

        # Calls verifier for pipeline output artifacts, excluding the resolver node.
        BeamDagRunner().run(self.taxi_pipeline)
        pipeline_outputs = executor_verifier_utils.get_pipeline_outputs(
            self.taxi_pipeline.metadata_connection_config, self._pipeline_name)

        # Output keys with a dedicated comparison; any other key falls back to
        # self._verify_file_path.
        verifier_map = {
            'model': self._verify_model,
            'model_run': self._verify_model,
            'examples': self._verify_examples,
            'schema': self._verify_schema,
            'anomalies': self._verify_anomalies,
            'evaluation': self._verify_evaluation,
            # A subdirectory of updated_analyzer_cache has changing name.
            'updated_analyzer_cache': self._veryify_root_dir,
        }

        # List of components to verify. Resolver is ignored because it
        # doesn't have an executor.
        verify_component_ids = [
            component.id for component in self.taxi_pipeline.components
            if not component.id.startswith('Resolver')
        ]

        for component_id in verify_component_ids:
            logging.info('Verifying %s', component_id)
            for key, artifact_dict in pipeline_outputs[component_id].items():
                verifier = verifier_map.get(key, self._verify_file_path)
                for idx, artifact in artifact_dict.items():
                    recorded_uri = os.path.join(self._recorded_output_dir,
                                                component_id, key, str(idx))
                    verifier(artifact.uri, recorded_uri)