    def testCustomStubExecutor(self, mock_publisher):
        # Verify that a custom stub executor can be substituted for the
        # component's real executor.
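        # Stub out publishing so launch() does not write execution results to
        # a real MLMD backend.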
        mock_publisher.return_value.publish_execution.return_value = {}

        component_map = {'_FakeComponent.FakeComponent': CustomStubExecutor}

        my_stub_launcher = stub_component_launcher.get_stub_launcher_class(
            test_data_dir=self.record_dir,
            stubbed_component_ids=[],
            stubbed_component_map=component_map)

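        # Build a launcher class that routes _FakeComponent to
        # CustomStubExecutor instead of its real executor.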
        launcher = my_stub_launcher.create(
            component=self.component,
            pipeline_info=self.pipeline_info,
            driver_args=self.driver_args,
            metadata_connection=self.metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})
        launcher.launch()

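        # CustomStubExecutor is expected to write 'custom component' into
        # result.txt under the component's output artifact URI.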
        output_path = self.component.outputs['output'].get()[0].uri
        generated_file = os.path.join(output_path, 'result.txt')
        self.assertTrue(tf.io.gfile.exists(generated_file))
        contents = io_utils.read_string_file(generated_file)
        self.assertEqual('custom component', contents)

    def testStubExecutor(self, mock_publisher):
        # Verify that the base stub executor replaces the component's real
        # executor and replays recorded outputs.
        mock_publisher.return_value.publish_execution.return_value = {}

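        # Pre-record an output file; the stub executor should copy it into the
        # component's output artifact instead of running the real executor.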
        record_file = os.path.join(self.record_dir, 'output', 'recorded.txt')
        io_utils.write_string_file(record_file, 'hello world')
        component_ids = ['_FakeComponent.FakeComponent']

        my_stub_launcher = stub_component_launcher.get_stub_launcher_class(
            test_data_dir=self.record_dir,
            stubbed_component_ids=component_ids,
            stubbed_component_map={})

        launcher = my_stub_launcher.create(
            component=self.component,
            pipeline_info=self.pipeline_info,
            driver_args=self.driver_args,
            metadata_connection=self.metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})
        launcher.launch()

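        # The recorded file should have been copied verbatim into the
        # component's output URI.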
        output_path = self.component.outputs['output'].get()[0].uri
        copied_file = os.path.join(output_path, 'recorded.txt')
        self.assertTrue(tf.io.gfile.exists(copied_file))
        contents = io_utils.read_string_file(copied_file)
        self.assertEqual('hello world', contents)

    def testExecutor(self, mock_publisher):
        # Verify that the original executor still runs when nothing is stubbed.
        mock_publisher.return_value.publish_execution.return_value = {}

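        # Seed the component's input; the un-stubbed executor is expected to
        # propagate it to the output artifact.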
        io_utils.write_string_file(os.path.join(self.input_dir, 'result.txt'),
                                   'test')

        my_stub_launcher = stub_component_launcher.get_stub_launcher_class(
            test_data_dir=self.record_dir,
            stubbed_component_ids=[],
            stubbed_component_map={})

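        # With no stubbed ids and an empty component map, the launcher falls
        # back to the component's original executor.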
        launcher = my_stub_launcher.create(
            component=self.component,
            pipeline_info=self.pipeline_info,
            driver_args=self.driver_args,
            metadata_connection=self.metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})
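        # The launcher should report the fully qualified type of the fake
        # component (module path plus class name).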
        self.assertEqual(
            launcher._component_info.component_type,  # pylint: disable=protected-access
            '.'.join([
                test_utils._FakeComponent.__module__,  # pylint: disable=protected-access
                test_utils._FakeComponent.__name__,  # pylint: disable=protected-access
            ]))
        launcher.launch()

        output_path = self.component.outputs['output'].get()[0].uri
        self.assertTrue(tf.io.gfile.exists(output_path))
        contents = io_utils.read_string_file(output_path)
        self.assertEqual('test', contents)

    def testTaxiPipelineBeam(self):
        # Run the pipeline once and record its outputs to
        # self._recorded_output_dir.
        record_taxi_pipeline = taxi_pipeline_beam._create_pipeline(  # pylint:disable=protected-access
            pipeline_name=self._pipeline_name,
            data_root=self._data_root,
            module_file=self._module_file,
            serving_model_dir=self._serving_model_dir,
            pipeline_root=self._pipeline_root,
            metadata_path=self._recorded_mlmd_path,
            beam_pipeline_args=[])
        BeamDagRunner().run(record_taxi_pipeline)
        pipeline_recorder_utils.record_pipeline(
            output_dir=self._recorded_output_dir,
            metadata_db_uri=self._recorded_mlmd_path,
            host=None,
            port=None,
            pipeline_name=self._pipeline_name,
            run_id=None)

        # Run pipeline with stub executors.
        taxi_pipeline = taxi_pipeline_beam._create_pipeline(  # pylint:disable=protected-access
            pipeline_name=self._pipeline_name,
            data_root=self._data_root,
            module_file=self._module_file,
            serving_model_dir=self._serving_model_dir,
            pipeline_root=self._pipeline_root,
            metadata_path=self._metadata_path,
            beam_pipeline_args=[])

        model_resolver_id = 'ResolverNode.latest_blessed_model_resolver'
        stubbed_component_ids = [
            component.id for component in taxi_pipeline.components
            if component.id != model_resolver_id
        ]

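        # Stub every component except the model resolver, which produces no
        # new artifacts to replay from the recording.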
        stub_launcher = stub_component_launcher.get_stub_launcher_class(
            test_data_dir=self._recorded_output_dir,
            stubbed_component_ids=stubbed_component_ids,
            stubbed_component_map={})
        stub_pipeline_config = pipeline_config.PipelineConfig(
            supported_launcher_classes=[
                stub_launcher,
            ])
        BeamDagRunner(config=stub_pipeline_config).run(taxi_pipeline)

        self.assertTrue(tf.io.gfile.exists(self._metadata_path))

        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)

        # Verify that the recorded files were copied to the output artifact URIs.
        with metadata.Metadata(metadata_config) as m:
            artifacts = m.store.get_artifacts()
            artifact_count = len(artifacts)
            executions = m.store.get_executions()
            execution_count = len(executions)
            # The artifact count exceeds the execution count by 3: Evaluator
            # (blessing and evaluation), Trainer (model and model_run), and
            # Transform (example, graph, cache) each produce multiple
            # artifacts, while the Resolver produces none.
            self.assertEqual(artifact_count, execution_count + 3)
            self.assertLen(taxi_pipeline.components, execution_count)

            for execution in executions:
                component_id = execution.properties[
                    metadata._EXECUTION_TYPE_KEY_COMPONENT_ID].string_value  # pylint: disable=protected-access
                if component_id == model_resolver_id:
                    continue
                eid = [execution.id]
                events = m.store.get_events_by_execution_ids(eid)
                output_events = [
                    x for x in events
                    if x.type == metadata_store_pb2.Event.OUTPUT
                ]
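                # Each output artifact URI should mirror the recorded layout:
                # <record_dir>/<component_id>/<output_key>/<index>.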
                for event in output_events:
                    steps = event.path.steps
                    self.assertTrue(steps[0].HasField('key'))
                    name = steps[0].key
                    output_artifacts = m.store.get_artifacts_by_id(
                        [event.artifact_id])
                    for idx, artifact in enumerate(output_artifacts):
                        self.assertDirectoryEqual(
                            artifact.uri,
                            os.path.join(self._recorded_output_dir,
                                         component_id, name, str(idx)))