def with_stages(pipeline_proto, stages):
    """Returns a copy of ``pipeline_proto`` whose transforms come from ``stages``.

    The copy keeps every component of ``pipeline_proto`` except the transforms,
    which are cleared. Each transform of each stage is then inserted under a
    fresh name derived from the stage name (via ``unique_name``). When a stage
    has a ``parent``, that parent and, recursively, its ancestors are copied
    from ``pipeline_proto`` with their subtransform lists emptied, and the new
    transform is appended to the parent's subtransforms.

    Args:
      pipeline_proto: the ``beam_runner_api_pb2.Pipeline`` to copy.
      stages: iterable of stage objects exposing ``name``, ``transforms`` and
        ``parent``; may be a generator.

    Returns:
      The new ``beam_runner_api_pb2.Pipeline``.
    """
    # In case it was a generator that mutates components as it
    # produces outputs (as is the case with most transformations).
    stages = list(stages)

    new_proto = beam_runner_api_pb2.Pipeline()
    new_proto.CopyFrom(pipeline_proto)
    components = new_proto.components
    components.transforms.clear()

    # Map each transform id to the id of the composite that contains it.
    parents = {
        child: parent
        for parent, proto in pipeline_proto.components.transforms.items()
        for child in proto.subtransforms
    }

    def add_parent(child, parent):
        # Attach `child` under `parent`, materializing the parent (and its own
        # ancestors) the first time it is referenced.
        if parent not in components.transforms:
            components.transforms[parent].CopyFrom(
                pipeline_proto.components.transforms[parent])
            # Subtransforms are rebuilt from the stages, not copied.
            del components.transforms[parent].subtransforms[:]
            if parent in parents:
                add_parent(parent, parents[parent])
        components.transforms[parent].subtransforms.append(child)

    for stage in stages:
        for transform in stage.transforms:
            # Named `transform_id` (not `id`) to avoid shadowing the builtin.
            transform_id = unique_name(components.transforms, stage.name)
            components.transforms[transform_id].CopyFrom(transform)
            if stage.parent:
                add_parent(transform_id, stage.parent)

    return new_proto
Beispiel #2
0
    def to_runner_api(self,
                      return_context=False,
                      context=None,
                      use_fake_coders=False,
                      default_environment=None):
        """For internal use only; no backwards-compatibility guarantees.

        Serializes this pipeline into a ``beam_runner_api_pb2.Pipeline`` proto.

        Note: before serializing, this walks the pipeline and, for transforms
        (and side inputs) that require keyed input, overwrites the affected
        PCollections' ``element_type`` in place with a KV-coerced type.

        Args:
          return_context: if true, return ``(proto, context)`` rather than just
            the proto.
          context: the PipelineContext to register components in. If None, a
            new one is created from ``use_fake_coders`` and
            ``default_environment``.
          use_fake_coders: passed to the newly created PipelineContext; not
            used when ``context`` is given.
          default_environment: passed to the newly created PipelineContext;
            must not be combined with an explicit ``context``.

        Returns:
          The Pipeline proto, or ``(proto, context)`` if ``return_context``.

        Raises:
          ValueError: if both ``context`` and ``default_environment`` are given.
        """
        # Function-local imports (presumably to avoid import cycles — confirm).
        from apache_beam.runners import pipeline_context
        from apache_beam.portability.api import beam_runner_api_pb2
        if context is None:
            context = pipeline_context.PipelineContext(
                use_fake_coders=use_fake_coders,
                default_environment=default_environment)
        elif default_environment is not None:
            # An explicit context already carries its own default environment.
            raise ValueError(
                'Only one of context or default_environment may be specified.')

        # The RunnerAPI spec requires certain transforms and side-inputs to have KV
        # inputs (and corresponding outputs).
        # Currently we only upgrade to KV pairs.  If there is a need for more
        # general shapes, potential conflicts will have to be resolved.
        # We also only handle single-input, and (for fixing the output) single
        # output, which is sufficient.
        class ForceKvInputTypes(PipelineVisitor):
            """Visitor that coerces keyed-input element types to KV types."""

            def enter_composite_transform(self, transform_node):
                # Composite transforms are coerced as well, not only leaves.
                self.visit_transform(transform_node)

            def visit_transform(self, transform_node):
                # Nodes whose transform is falsy have nothing to coerce.
                if not transform_node.transform:
                    return
                if transform_node.transform.runner_api_requires_keyed_input():
                    # Only the first main input is coerced (single-input case).
                    pcoll = transform_node.inputs[0]
                    pcoll.element_type = typehints.coerce_to_kv_type(
                        pcoll.element_type, transform_node.full_label)
                    if len(transform_node.outputs) == 1:
                        # The runner often has expectations about the output types as well.
                        output, = transform_node.outputs.values()
                        output.element_type = transform_node.transform.infer_output_type(
                            pcoll.element_type)
                # Side inputs are checked independently of the main input.
                for side_input in transform_node.transform.side_inputs:
                    if side_input.requires_keyed_input():
                        side_input.pvalue.element_type = typehints.coerce_to_kv_type(
                            side_input.pvalue.element_type,
                            transform_node.full_label,
                            side_input_producer=side_input.pvalue.producer.
                            full_label)

        # Must run before serialization so the coerced types are what get
        # registered in the context.
        self.visit(ForceKvInputTypes())

        # Mutates context; placing inline would force dependence on
        # argument evaluation order.
        root_transform_id = context.transforms.get_id(self._root_transform())
        proto = beam_runner_api_pb2.Pipeline(
            root_transform_ids=[root_transform_id],
            components=context.to_runner_api())
        # The root transform's unique_name is set to its id.
        proto.components.transforms[root_transform_id].unique_name = (
            root_transform_id)
        if return_context:
            return proto, context
        else:
            return proto
    def test_end_to_end(self):
        """An empty pipeline runs to completion on the local job service.

        Both the state stream and the state updates carried on the message
        stream must report STOPPED, STARTING, RUNNING and DONE, in that order.
        """
        servicer = local_job_service.LocalJobServicer()
        servicer.start_grpc_server()

        job_plan = TestJobServicePlan(servicer)
        _, messages, states = job_plan.submit(beam_runner_api_pb2.Pipeline())

        # Drain the state stream first, then the message stream.
        observed_states = [update.state for update in states]
        observed_messages = list(messages)

        want = [
            beam_job_api_pb2.JobState.STOPPED,
            beam_job_api_pb2.JobState.STARTING,
            beam_job_api_pb2.JobState.RUNNING,
            beam_job_api_pb2.JobState.DONE,
        ]
        self.assertEqual(observed_states, want)

        state_updates = [
            msg.state_response.state
            for msg in observed_messages
            if msg.HasField('state_response')
        ]
        self.assertEqual(state_updates, want)
    def test_error_messages_after_pipeline_failure(self):
        """A failing job exposes its error message live and after completion.

        Submits a pipeline declaring the requirement 'unsupported_requirement',
        expects the job to end in FAILED, and checks that a message mentioning
        that requirement appears both on the original message stream and on a
        message stream fetched again after the job has finished.
        """
        job_service = local_job_service.LocalJobServicer()
        job_service.start_grpc_server()

        plan = TestJobServicePlan(job_service)

        job_id, message_stream, state_stream = plan.submit(
            beam_runner_api_pb2.Pipeline(
                requirements=['unsupported_requirement']))

        message_results = list(message_stream)
        state_results = list(state_stream)

        expected_states = [
            beam_job_api_pb2.JobState.STOPPED,
            beam_job_api_pb2.JobState.STARTING,
            beam_job_api_pb2.JobState.RUNNING,
            beam_job_api_pb2.JobState.FAILED,
        ]
        self.assertEqual([s.state for s in state_results], expected_states)
        self.assertTrue(
            any('unsupported_requirement' in m.message_response.message_text
                for m in message_results), message_results)

        # Assert we still see the error message if we fetch error messages after
        # the job has completed. This must inspect the re-fetched stream;
        # re-checking `message_results` would make the assertion vacuous.
        messages_again = list(
            plan.job_service.GetMessageStream(
                beam_job_api_pb2.JobMessagesRequest(job_id=job_id)))
        self.assertTrue(
            any('unsupported_requirement' in m.message_response.message_text
                for m in messages_again), messages_again)
    def test_artifact_service_override(self):
        """An explicit --artifact_endpoint is used for staging and recorded.

        The job service handle must stage against the overridden endpoint, and
        the endpoint must also appear in the serialized pipeline options.
        """
        servicer = local_job_service.LocalJobServicer()
        grpc_port = servicer.start_grpc_server()

        artifact_endpoint = 'testartifactendpoint:4242'
        cli_args = [
            '--job_endpoint',
            'localhost:%d' % grpc_port,
            '--artifact_endpoint',
            artifact_endpoint,
        ]
        parsed_options = pipeline_options.PipelineOptions(cli_args)
        handle = PortableRunner().create_job_service(parsed_options)

        with mock.patch.object(handle, 'stage') as stage_mock:
            handle.submit(beam_runner_api_pb2.Pipeline())
            stage_mock.assert_called_once_with(
                mock.ANY, artifact_endpoint, mock.ANY)

        # The override must also be visible in the options protobuf.
        serialized_options = handle.get_pipeline_options()
        self.assertEqual(
            serialized_options['beam:option:artifact_endpoint:v1'],
            artifact_endpoint)
def with_stages(pipeline_proto, stages):
    """Returns a copy of ``pipeline_proto`` whose transforms come from ``stages``.

    The copy keeps every component of ``pipeline_proto`` except the transforms,
    which are cleared. Each transform of each stage is then inserted under a
    fresh name derived from the stage name (via ``unique_name``). When a stage
    has a ``parent``, that parent and, recursively, its ancestors are copied
    from ``pipeline_proto`` with their subtransform lists emptied, and the new
    transform is appended to the parent's subtransforms.

    Args:
      pipeline_proto: the ``beam_runner_api_pb2.Pipeline`` to copy.
      stages: iterable of stage objects exposing ``name``, ``transforms`` and
        ``parent``; may be a generator.

    Returns:
      The new ``beam_runner_api_pb2.Pipeline``.
    """
    # Materialize first: if `stages` is a generator that mutates the pipeline's
    # components as it produces outputs, those mutations must happen before
    # the CopyFrom below, otherwise the copy would miss them.
    stages = list(stages)

    new_proto = beam_runner_api_pb2.Pipeline()
    new_proto.CopyFrom(pipeline_proto)
    components = new_proto.components
    components.transforms.clear()

    # Map each transform id to the id of the composite that contains it.
    parents = {
        child: parent
        for parent, proto in pipeline_proto.components.transforms.items()
        for child in proto.subtransforms
    }

    def add_parent(child, parent):
        # Attach `child` under `parent`, materializing the parent (and its own
        # ancestors) the first time it is referenced.
        if parent not in components.transforms:
            components.transforms[parent].CopyFrom(
                pipeline_proto.components.transforms[parent])
            # Subtransforms are rebuilt from the stages, not copied.
            del components.transforms[parent].subtransforms[:]
            if parent in parents:
                add_parent(parent, parents[parent])
        components.transforms[parent].subtransforms.append(child)

    for stage in stages:
        for transform in stage.transforms:
            # Named `transform_id` (not `id`) to avoid shadowing the builtin.
            transform_id = unique_name(components.transforms, stage.name)
            components.transforms[transform_id].CopyFrom(transform)
            if stage.parent:
                add_parent(transform_id, stage.parent)

    return new_proto
Beispiel #7
0
 def test_stage_resources(self):
     """Every environment's FILE artifacts are handed to the legacy stager.

     The client is expected to call ``stage_job_resources`` exactly once with
     the ``(path, staged_name)`` pairs of env1 followed by env2, and with the
     configured staging location.
     """
     def staged_file(path, staged_name):
         # One FILE artifact with a STAGING_TO role naming `staged_name`.
         return beam_runner_api_pb2.ArtifactInformation(
             type_urn=common_urns.artifact_types.FILE.urn,
             type_payload=beam_runner_api_pb2.ArtifactFilePayload(
                 path=path).SerializeToString(),
             role_urn=common_urns.artifact_roles.STAGING_TO.urn,
             role_payload=beam_runner_api_pb2.ArtifactStagingToRolePayload(
                 staged_name=staged_name).SerializeToString())

     options = PipelineOptions([
         '--temp_location', 'gs://test-location/temp', '--staging_location',
         'gs://test-location/staging', '--no_auth'
     ])
     first_env = beam_runner_api_pb2.Environment(dependencies=[
         staged_file('/tmp/foo1', 'foo1'),
         staged_file('/tmp/bar1', 'bar1'),
     ])
     second_env = beam_runner_api_pb2.Environment(dependencies=[
         staged_file('/tmp/foo2', 'foo2'),
         staged_file('/tmp/bar2', 'bar2'),
     ])
     pipeline = beam_runner_api_pb2.Pipeline(
         components=beam_runner_api_pb2.Components(
             environments={'env1': first_env, 'env2': second_env}))

     client = apiclient.DataflowApplicationClient(options)
     with mock.patch.object(apiclient._LegacyDataflowStager,
                            'stage_job_resources') as stager_mock:
         client._stage_resources(pipeline, options)

     expected_resources = [
         ('/tmp/foo1', 'foo1'), ('/tmp/bar1', 'bar1'),
         ('/tmp/foo2', 'foo2'), ('/tmp/bar2', 'bar2'),
     ]
     stager_mock.assert_called_once_with(
         expected_resources, staging_location='gs://test-location/staging')
Beispiel #8
0
 def to_runner_api(self):
     """For internal use only; no backwards-compatibility guarantees.

     Serializes this pipeline, using a fresh PipelineContext, into a
     ``beam_runner_api_pb2.Pipeline`` proto with a single root transform.
     """
     from apache_beam.runners import pipeline_context
     from apache_beam.portability.api import beam_runner_api_pb2
     ctx = pipeline_context.PipelineContext()
     # Registering the root transform populates ``ctx``; it must happen before
     # the context's components are serialized on the next line.
     root_id = ctx.transforms.get_id(self._root_transform())
     serialized_components = ctx.to_runner_api()
     return beam_runner_api_pb2.Pipeline(
         root_transform_ids=[root_id],
         components=serialized_components)
Beispiel #9
0
    def test_end_to_end(self):
        """Runs an empty pipeline through the local job service end to end.

        The call sequence is taken roughly from PortableRunner.run_pipeline():
        prepare the job, commit an empty artifact manifest, subscribe to the
        state and message streams, run the job, then check that both streams
        report STOPPED, STARTING, RUNNING and DONE.
        """
        job_service = local_job_service.LocalJobServicer()
        job_service.start_grpc_server()

        # Prepare the job.
        prepare_response = job_service.Prepare(
            beam_job_api_pb2.PrepareJobRequest(
                job_name='job', pipeline=beam_runner_api_pb2.Pipeline()))

        # Commit an empty manifest to obtain a retrieval token. The channel is
        # closed in `finally` so a failing RPC does not leak it.
        channel = grpc.insecure_channel(
            prepare_response.artifact_staging_endpoint.url)
        try:
            staging_stub = (
                beam_artifact_api_pb2_grpc.ArtifactStagingServiceStub(channel))
            retrieval_token = staging_stub.CommitManifest(
                beam_artifact_api_pb2.CommitManifestRequest(
                    staging_session_token=prepare_response.
                    staging_session_token,
                    manifest=beam_artifact_api_pb2.Manifest())).retrieval_token
        finally:
            channel.close()

        state_stream = job_service.GetStateStream(
            beam_job_api_pb2.GetJobStateRequest(
                job_id=prepare_response.preparation_id))
        # If there's an error, we don't always get it until we try to read.
        # Fortunately, there's always an immediate current state published.
        # `increment_iter` (defined elsewhere) presumably forces that first
        # read now, before the job runs — confirm against its definition.
        state_stream = increment_iter(state_stream)

        message_stream = job_service.GetMessageStream(
            beam_job_api_pb2.JobMessagesRequest(
                job_id=prepare_response.preparation_id))

        job_service.Run(
            beam_job_api_pb2.RunJobRequest(
                preparation_id=prepare_response.preparation_id,
                retrieval_token=retrieval_token))

        state_results = list(state_stream)
        message_results = list(message_stream)

        expected_states = [
            beam_job_api_pb2.JobState.STOPPED,
            beam_job_api_pb2.JobState.STARTING,
            beam_job_api_pb2.JobState.RUNNING,
            beam_job_api_pb2.JobState.DONE,
        ]
        self.assertEqual([s.state for s in state_results], expected_states)

        self.assertEqual([s.state_response.state for s in message_results],
                         expected_states)
def pipeline_from_stages(
    pipeline_proto, stages, known_runner_urns, partial):
  """Builds a new pipeline proto whose leaf transforms are the given stages.

  The copy keeps every component of ``pipeline_proto`` except the transforms,
  which are cleared. For each stage, either its single transform (obtained via
  ``only_element``, when ``partial`` is true) or its executable-stage
  transform is inserted under a fresh name derived from the stage name.
  Enclosing composites are copied from ``pipeline_proto`` with emptied
  subtransform lists and re-linked bottom-up; transforms with no parent become
  the new ``root_transform_ids``, in first-seen order.

  Args:
    pipeline_proto: the ``beam_runner_api_pb2.Pipeline`` to start from.
    stages: iterable of stages; may be a generator.
    known_runner_urns: forwarded to ``executable_stage_transform``.
    partial: if true, use each stage's single transform as-is instead of
      building an executable-stage transform.

  Returns:
    The new ``beam_runner_api_pb2.Pipeline``.
  """
  # In case it was a generator that mutates components as it
  # produces outputs (as is the case with most transformations).
  stages = list(stages)

  new_proto = beam_runner_api_pb2.Pipeline()
  new_proto.CopyFrom(pipeline_proto)
  components = new_proto.components
  components.transforms.clear()

  # Root transform ids in first-seen order. An OrderedDict serves as an
  # ordered set so that root_transform_ids is deterministic; iterating a
  # plain set of strings depends on the interpreter's hash seed.
  roots = collections.OrderedDict()
  # Map each transform id to the id of the composite that contains it.
  parents = {
      child: parent
      for parent, proto in pipeline_proto.components.transforms.items()
      for child in proto.subtransforms
  }

  def add_parent(child, parent):
    # Attach `child` under `parent`; a None parent makes `child` a root.
    if parent is None:
      roots[child] = None
    else:
      if parent not in components.transforms:
        components.transforms[parent].CopyFrom(
            pipeline_proto.components.transforms[parent])
        # Subtransforms are rebuilt from the stages, not copied.
        del components.transforms[parent].subtransforms[:]
        add_parent(parent, parents.get(parent))
      components.transforms[parent].subtransforms.append(child)

  # For each input PCollection id, the Python object ids (id()) of the stage
  # transforms that consume it.
  all_consumers = collections.defaultdict(set)
  for stage in stages:
    for transform in stage.transforms:
      for pcoll in transform.inputs.values():
        all_consumers[pcoll].add(id(transform))

  for stage in stages:
    if partial:
      transform = only_element(stage.transforms)
    else:
      transform = stage.executable_stage_transform(
          known_runner_urns, all_consumers, components)
    transform_id = unique_name(components.transforms, stage.name)
    components.transforms[transform_id].CopyFrom(transform)
    add_parent(transform_id, stage.parent)

  del new_proto.root_transform_ids[:]
  new_proto.root_transform_ids.extend(list(roots))

  return new_proto
Beispiel #11
0
  def prune_subgraph_for(self, pipeline, required_transform_ids):
    """Returns a pipeline proto containing only the required subgraph.

    Serializes ``pipeline`` and keeps the root transforms, the transforms in
    ``required_transform_ids`` and whatever ``self._required_components``
    collects from them (following both inputs and outputs), plus the
    collected PCollections and the context's coders and windowing strategies.
    Root subtransforms that did not survive pruning are dropped.

    Args:
      pipeline: the pipeline to prune.
      required_transform_ids: iterable of transform ids that must be kept.

    Returns:
      A new ``beam_runner_api_pb2.Pipeline``.
    """
    # Create the pipeline_proto to read all the components from. It will later
    # create a new pipeline proto from the cut out components.
    pipeline_proto, context = pipeline.to_runner_api(
        return_context=True, use_fake_coders=False)

    # Get all the root transforms. The caching transforms will be subtransforms
    # of one of these roots.
    roots = list(pipeline_proto.root_transform_ids)

    # list() lets callers pass any iterable of ids, not only a list.
    (t, p) = self._required_components(
        pipeline_proto,
        roots + list(required_transform_ids),
        set(),
        follow_outputs=True,
        follow_inputs=True)

    def set_proto_map(proto_map, new_value):
      # Replace the contents of a protobuf map field with `new_value`.
      proto_map.clear()
      for key, value in new_value.items():
        proto_map[key].CopyFrom(value)

    # Serialize the context once; coders and windowing strategies both come
    # from this snapshot.
    context_components = context.to_runner_api()

    # Copy the transforms into the new pipeline.
    pipeline_to_execute = beam_runner_api_pb2.Pipeline()
    pipeline_to_execute.root_transform_ids[:] = roots
    set_proto_map(pipeline_to_execute.components.transforms, t)
    set_proto_map(pipeline_to_execute.components.pcollections, p)
    set_proto_map(
        pipeline_to_execute.components.coders, context_components.coders)
    set_proto_map(
        pipeline_to_execute.components.windowing_strategies,
        context_components.windowing_strategies)

    # Cut out all subtransforms in the root that aren't the required transforms.
    kept_transforms = pipeline_to_execute.components.transforms
    for root_id in roots:
      root = kept_transforms[root_id]
      root.subtransforms[:] = [
          transform_id for transform_id in root.subtransforms
          if transform_id in kept_transforms
      ]

    return pipeline_to_execute
Beispiel #12
0
    def background_caching_pipeline_proto(self):
        """Returns the background caching pipeline.

        This method creates a background caching pipeline by: adding writes to
        cache from each unbounded source (done in the instrument method), and
        cutting out all components (transforms, PCollections, coders,
        windowing strategies) that are not the unbounded sources or writes to
        cache (or subtransforms thereof).
        """
        # Create the pipeline_proto to read all the components from. It will later
        # create a new pipeline proto from the cut out components.
        pipeline_proto = self._background_caching_pipeline.to_runner_api(
            return_context=False, use_fake_coders=True)

        # Get all the sources we want to cache.
        sources = unbounded_sources(self._background_caching_pipeline)

        # Get all the root transforms. The caching transforms will be subtransforms
        # of one of these roots.
        roots = list(pipeline_proto.root_transform_ids)

        # Get the transform IDs of the caching transforms. These caching operations
        # are added to the _background_caching_pipeline in the instrument() method.
        # They are added there so that multiple calls to this method won't add
        # multiple caching operations (idempotent).
        transforms = pipeline_proto.components.transforms
        caching_transform_ids = [
            t_id for root in roots for t_id in transforms[root].subtransforms
            if WRITE_CACHE in t_id
        ]

        # Get the IDs of the unbounded sources. A set gives O(1) membership
        # tests instead of scanning every label for every transform.
        required_transform_labels = {src.full_label for src in sources}
        unbounded_source_ids = [
            k for k, v in transforms.items()
            if v.unique_name in required_transform_labels
        ]

        # The required transforms are the transforms that we want to cut out of
        # the pipeline_proto and insert into a new pipeline to return.
        required_transform_ids = (roots + caching_transform_ids +
                                  unbounded_source_ids)
        (t, p, c, w) = self._required_components(pipeline_proto,
                                                 required_transform_ids)

        def set_proto_map(proto_map, new_value):
            # Replace the contents of a protobuf map field with `new_value`.
            proto_map.clear()
            for key, value in new_value.items():
                proto_map[key].CopyFrom(value)

        # Copy the transforms into the new pipeline.
        pipeline_to_execute = beam_runner_api_pb2.Pipeline()
        pipeline_to_execute.root_transform_ids[:] = roots
        set_proto_map(pipeline_to_execute.components.transforms, t)
        set_proto_map(pipeline_to_execute.components.pcollections, p)
        set_proto_map(pipeline_to_execute.components.coders, c)
        set_proto_map(pipeline_to_execute.components.windowing_strategies, w)

        # Cut out all subtransforms in the root that aren't the required transforms.
        kept_transforms = pipeline_to_execute.components.transforms
        for root_id in roots:
            root = kept_transforms[root_id]
            root.subtransforms[:] = [
                transform_id for transform_id in root.subtransforms
                if transform_id in kept_transforms
            ]

        return pipeline_to_execute
    def test_end_to_end(self, http_mock):
        """Runs a job through SparkUberJarJobServer against a mocked REST API.

        The mocked Spark status endpoint reports RUNNING twice and then ERROR
        with message "oops"; the test checks the resulting state stream and
        the interleaving of states and the error message on the message
        stream.
        """
        submission_id = "submission-id"
        worker_host_port = "workerhost:12345"
        worker_id = "worker-id"
        server_spark_version = "1.2.3"

        def spark_submission_status_response(state):
            # Mocked status-poll response body for the given driver state.
            return {
                'json': {
                    "action": "SubmissionStatusResponse",
                    "driverState": state,
                    "serverSparkVersion": server_spark_version,
                    "submissionId": submission_id,
                    "success": "true",
                    "workerHostPort": worker_host_port,
                    "workerId": worker_id
                }
            }

        with temp_name(suffix='fake.jar') as fake_jar:
            # Named `jar` rather than `zip` to avoid shadowing the builtin.
            with zipfile.ZipFile(fake_jar, 'w') as jar:
                with jar.open('spark-version-info.properties', 'w') as fout:
                    fout.write(b'version=4.5.6')

            options = pipeline_options.SparkRunnerOptions()
            options.spark_job_server_jar = fake_jar
            job_server = spark_uber_jar_job_server.SparkUberJarJobServer(
                'http://host:6066', options)

            # Prepare the job.
            plan = TestJobServicePlan(job_server)
            prepare_response = plan.prepare(beam_runner_api_pb2.Pipeline())
            retrieval_token = plan.stage(
                beam_runner_api_pb2.Pipeline(),
                prepare_response.artifact_staging_endpoint.url,
                prepare_response.staging_session_token)

            # Now actually run the job.
            http_mock.post(
                'http://host:6066/v1/submissions/create',
                json={
                    "action": "CreateSubmissionResponse",
                    "message":
                    "Driver successfully submitted as submission-id",
                    "serverSparkVersion": "1.2.3",
                    "submissionId": "submission-id",
                    "success": "true"
                })
            job_server.Run(
                beam_job_api_pb2.RunJobRequest(
                    preparation_id=prepare_response.preparation_id,
                    retrieval_token=retrieval_token))

            # Check the status until the job is "done" and get all error messages.
            http_mock.get(
                'http://host:6066/v1/submissions/status/submission-id', [
                    spark_submission_status_response('RUNNING'),
                    spark_submission_status_response('RUNNING'), {
                        'json': {
                            "action": "SubmissionStatusResponse",
                            "driverState": "ERROR",
                            "message": "oops",
                            "serverSparkVersion": "1.2.3",
                            "submissionId": submission_id,
                            "success": "true",
                            "workerHostPort": worker_host_port,
                            "workerId": worker_id
                        }
                    }
                ])

            state_stream = job_server.GetStateStream(
                beam_job_api_pb2.GetJobStateRequest(
                    job_id=prepare_response.preparation_id))

            self.assertEqual([s.state for s in state_stream], [
                beam_job_api_pb2.JobState.STOPPED,
                beam_job_api_pb2.JobState.RUNNING,
                beam_job_api_pb2.JobState.RUNNING,
                beam_job_api_pb2.JobState.FAILED
            ])

            message_stream = job_server.GetMessageStream(
                beam_job_api_pb2.JobMessagesRequest(
                    job_id=prepare_response.preparation_id))

            def get_item(x):
                # Reduce a JobMessagesResponse to its message or its state.
                if x.HasField('message_response'):
                    return x.message_response
                else:
                    return x.state_response.state

            self.assertEqual([get_item(m) for m in message_stream], [
                beam_job_api_pb2.JobState.STOPPED,
                beam_job_api_pb2.JobState.RUNNING,
                beam_job_api_pb2.JobMessage(
                    message_id='message0',
                    time='0',
                    importance=beam_job_api_pb2.JobMessage.MessageImportance.
                    JOB_MESSAGE_ERROR,
                    message_text="oops"),
                beam_job_api_pb2.JobState.FAILED,
            ])
Beispiel #14
0
    def _analyze_pipeline(self):
        """Analyzes the pipeline and sets the variables that can be queried.

        This function constructs the Pipeline proto to execute by
          1. Starting from the target PCollections and recursively inserting
             the producing PTransforms of those PCollections, where the
             producing PTransforms are either ReadCache or PTransforms in the
             original pipeline.
          2. Appending WriteCache PTransforms to the pipeline.

        After running this function, the following variables will be set:
          self._pipeline_proto_to_execute
          self._top_level_referenced_pcoll_ids
          self._top_level_required_transforms
          self._caches_used
          self._read_cache_ids
          self._write_cache_ids
        The first three are assigned directly at the end of this method; the
        others are presumably populated by the ``_insert_*`` helpers
        (NOTE(review): confirm).
        """
        # We filter PTransforms to be executed bottom-up from these PCollections.
        desired_pcollections = self._desired_pcollections(self._pipeline_info)

        required_transforms = collections.OrderedDict()
        top_level_required_transforms = collections.OrderedDict()

        for pcoll_id in desired_pcollections:
            # TODO(qinyeli): Collections consumed by no-output transforms.
            self._insert_producing_transforms(pcoll_id, required_transforms,
                                              top_level_required_transforms)

        top_level_referenced_pcoll_ids = self._referenced_pcoll_ids(
            top_level_required_transforms)

        for pcoll_id in self._pipeline_info.all_pcollections():
            if pcoll_id not in top_level_referenced_pcoll_ids:
                continue

            if (pcoll_id in desired_pcollections
                    and pcoll_id not in self._caches_used):
                self._insert_caching_transforms(pcoll_id, required_transforms,
                                                top_level_required_transforms)

            # Also write a 'sample' cache for PCollections that lack one.
            if not self._cache_manager.exists(
                    'sample', self._pipeline_info.cache_label(pcoll_id)):
                self._insert_caching_transforms(pcoll_id,
                                                required_transforms,
                                                top_level_required_transforms,
                                                sample=True)

        # A synthetic root whose subtransforms are the top-level transforms.
        required_transforms['_root'] = beam_runner_api_pb2.PTransform(
            subtransforms=list(top_level_required_transforms))

        referenced_pcoll_ids = self._referenced_pcoll_ids(required_transforms)
        referenced_pcollections = {}
        for pcoll_id in referenced_pcoll_ids:
            obj = self._context.pcollections.get_by_id(pcoll_id)
            proto = self._context.pcollections.get_proto(obj)
            referenced_pcollections[pcoll_id] = proto

        pipeline_to_execute = beam_runner_api_pb2.Pipeline()
        pipeline_to_execute.root_transform_ids[:] = ['_root']
        set_proto_map(pipeline_to_execute.components.transforms,
                      required_transforms)
        set_proto_map(pipeline_to_execute.components.pcollections,
                      referenced_pcollections)
        # Serialize the context once; coders and windowing strategies both
        # come from this snapshot.
        context_components = self._context.to_runner_api()
        set_proto_map(pipeline_to_execute.components.coders,
                      context_components.coders)
        set_proto_map(pipeline_to_execute.components.windowing_strategies,
                      context_components.windowing_strategies)

        self._pipeline_proto_to_execute = pipeline_to_execute
        self._top_level_referenced_pcoll_ids = top_level_referenced_pcoll_ids
        self._top_level_required_transforms = top_level_required_transforms
Beispiel #15
0
    def test_end_to_end(self, http_mock):
        """Runs a job through FlinkUberJarJobServer against a mocked REST API.

        The mocked Flink endpoints report IN_PROGRESS twice and then
        COMPLETED; the test checks the state stream (RUNNING, DONE) and that
        the message stream carries the mocked exception text followed by the
        DONE state.
        """
        with temp_name(suffix='fake.jar') as fake_jar:
            # Create the jar file with some trivial contents. Named `jar`
            # rather than `zip` to avoid shadowing the builtin.
            with zipfile.ZipFile(fake_jar, 'w') as jar:
                with jar.open('FakeClass.class', 'w') as fout:
                    fout.write(b'[original_contents]')

            job_server = flink_uber_jar_job_server.FlinkUberJarJobServer(
                'http://flink', fake_jar)

            # Prepare the job.
            prepare_response = job_server.Prepare(
                beam_job_api_pb2.PrepareJobRequest(
                    job_name='job', pipeline=beam_runner_api_pb2.Pipeline()))

            # Commit an empty manifest to obtain a retrieval token. The channel
            # is closed in `finally` so a failing RPC does not leak it.
            channel = grpc.insecure_channel(
                prepare_response.artifact_staging_endpoint.url)
            try:
                staging_stub = (
                    beam_artifact_api_pb2_grpc.ArtifactStagingServiceStub(
                        channel))
                retrieval_token = staging_stub.CommitManifest(
                    beam_artifact_api_pb2.CommitManifestRequest(
                        staging_session_token=prepare_response.
                        staging_session_token,
                        manifest=beam_artifact_api_pb2.Manifest())
                ).retrieval_token
            finally:
                channel.close()

            # Now actually run the job.
            http_mock.post('http://flink/v1/jars/upload',
                           json={'filename': '/path/to/jar/nonce'})
            http_mock.post('http://flink/v1/jars/nonce/run',
                           json={'jobid': 'some_job_id'})
            job_server.Run(
                beam_job_api_pb2.RunJobRequest(
                    preparation_id=prepare_response.preparation_id,
                    retrieval_token=retrieval_token))

            # Check the status until the job is "done" and get all error messages.
            http_mock.get('http://flink/v1/jobs/some_job_id/execution-result',
                          [{
                              'json': {
                                  'status': {
                                      'id': 'IN_PROGRESS'
                                  }
                              }
                          }, {
                              'json': {
                                  'status': {
                                      'id': 'IN_PROGRESS'
                                  }
                              }
                          }, {
                              'json': {
                                  'status': {
                                      'id': 'COMPLETED'
                                  }
                              }
                          }])
            http_mock.get('http://flink/v1/jobs/some_job_id',
                          json={'state': 'FINISHED'})
            http_mock.delete('http://flink/v1/jars/nonce')

            state_stream = job_server.GetStateStream(
                beam_job_api_pb2.GetJobStateRequest(
                    job_id=prepare_response.preparation_id))
            self.assertEqual([s.state for s in state_stream], [
                beam_job_api_pb2.JobState.RUNNING,
                beam_job_api_pb2.JobState.DONE
            ])

            http_mock.get('http://flink/v1/jobs/some_job_id/exceptions',
                          json={
                              'all-exceptions': [{
                                  'exception': 'exc_text',
                                  'timestamp': 0
                              }]
                          })
            message_stream = job_server.GetMessageStream(
                beam_job_api_pb2.JobMessagesRequest(
                    job_id=prepare_response.preparation_id))
            self.assertEqual([m for m in message_stream], [
                beam_job_api_pb2.JobMessagesResponse(
                    message_response=beam_job_api_pb2.JobMessage(
                        message_id='message0',
                        time='0',
                        importance=beam_job_api_pb2.JobMessage.
                        MessageImportance.JOB_MESSAGE_ERROR,
                        message_text='exc_text')),
                beam_job_api_pb2.JobMessagesResponse(
                    state_response=beam_job_api_pb2.GetJobStateResponse(
                        state=beam_job_api_pb2.JobState.DONE)),
            ])
    def test_end_to_end(self, http_mock):
        """Runs a job through FlinkUberJarJobServer against a mocked Flink REST API.

        Prepares and stages an empty pipeline and runs it against fake
        ``http://flink`` endpoints. It then checks the sequence of job states
        and job messages reported by the server, including an error message
        served by the mocked ``exceptions`` endpoint.

        Args:
            http_mock: HTTP mocker used to register fake Flink REST responses.
                Presumably a ``requests_mock`` Mocker injected by a decorator
                not visible here (TODO confirm).
        """
        with temp_name(suffix='fake.jar') as fake_jar:
            # Create the jar file with some trivial contents. The archive handle
            # is deliberately not named ``zip`` so the builtin is not shadowed.
            with zipfile.ZipFile(fake_jar, 'w') as jar_zip:
                with jar_zip.open('FakeClass.class', 'w') as fout:
                    fout.write(b'[original_contents]')

            options = pipeline_options.FlinkRunnerOptions()
            options.flink_job_server_jar = fake_jar
            job_server = flink_uber_jar_job_server.FlinkUberJarJobServer(
                'http://flink', options)

            plan = TestJobServicePlan(job_server)

            # Prepare the job.
            prepare_response = plan.prepare(beam_runner_api_pb2.Pipeline())
            plan.stage(beam_runner_api_pb2.Pipeline(),
                       prepare_response.artifact_staging_endpoint.url,
                       prepare_response.staging_session_token)

            # Now actually run the job.
            http_mock.post('http://flink/v1/jars/upload',
                           json={'filename': '/path/to/jar/nonce'})
            http_mock.post('http://flink/v1/jars/nonce/run',
                           json={'jobid': 'some_job_id'})

            _, message_stream, state_stream = plan.run(
                prepare_response.preparation_id)

            # Check the status until the job is "done" and get all error messages.
            # NOTE(review): these mocks are registered only after plan.run()
            # returns the streams, which presumably means the streams are
            # consumed lazily and Flink is only polled while they are iterated.
            # Keep this ordering.
            # The list of responses is served in order across successive polls:
            # two IN_PROGRESS results followed by COMPLETED.
            http_mock.get('http://flink/v1/jobs/some_job_id/execution-result',
                          [{
                              'json': {
                                  'status': {
                                      'id': 'IN_PROGRESS'
                                  }
                              }
                          }, {
                              'json': {
                                  'status': {
                                      'id': 'IN_PROGRESS'
                                  }
                              }
                          }, {
                              'json': {
                                  'status': {
                                      'id': 'COMPLETED'
                                  }
                              }
                          }])
            http_mock.get('http://flink/v1/jobs/some_job_id',
                          json={'state': 'FINISHED'})
            http_mock.delete('http://flink/v1/jars/nonce')

            self.assertEqual([s.state for s in state_stream], [
                beam_job_api_pb2.JobState.STOPPED,
                beam_job_api_pb2.JobState.RUNNING,
                beam_job_api_pb2.JobState.RUNNING,
                beam_job_api_pb2.JobState.DONE
            ])

            # Registered only after the state stream is drained. The message
            # stream below is expected to surface this exception text as a
            # JOB_MESSAGE_ERROR message.
            http_mock.get('http://flink/v1/jobs/some_job_id/exceptions',
                          json={
                              'all-exceptions': [{
                                  'exception': 'exc_text',
                                  'timestamp': 0
                              }]
                          })

            def get_item(x):
                # Flattens a message-stream entry into either the JobMessage
                # it carries or the JobState of its state_response.
                if x.HasField('message_response'):
                    return x.message_response
                else:
                    return x.state_response.state

            self.assertEqual([get_item(m) for m in message_stream], [
                beam_job_api_pb2.JobState.STOPPED,
                beam_job_api_pb2.JobState.RUNNING,
                beam_job_api_pb2.JobMessage(
                    message_id='message0',
                    time='0',
                    importance=beam_job_api_pb2.JobMessage.MessageImportance.
                    JOB_MESSAGE_ERROR,
                    message_text='exc_text'),
                beam_job_api_pb2.JobState.DONE,
            ])