def spark_job():
    """Create a SparkBeamJob aimed at a fake Spark REST endpoint.

    The endpoint is fixed to ``http://host:6066``. The five string
    arguments that follow it are all empty, and a default-constructed
    ``SparkRunnerOptions`` is passed last.
    """
    runner_options = pipeline_options.SparkRunnerOptions()
    return spark_uber_jar_job_server.SparkBeamJob(
        'http://host:6066',
        '',
        '',
        '',
        '',
        '',
        runner_options)
def test_end_to_end(self, http_mock):
    """Drive a job through SparkUberJarJobServer against a mocked Spark REST API.

    The test builds a fake job-server jar, then prepares, stages and runs a
    pipeline. The mocked submission-status endpoint reports RUNNING twice and
    then ERROR with the message "oops". The test checks that the state stream
    and the message stream report the expected sequences.

    Args:
      http_mock: HTTP mock injected by the test's decorator. It is used via
        ``post(url, json=...)`` and ``get(url, [responses])``. It is presumably
        a ``requests_mock`` Mocker; confirm against the decorator above this
        method.
    """
    spark_rest_url = 'http://host:6066'
    submission_id = "submission-id"
    worker_host_port = "workerhost:12345"
    worker_id = "worker-id"
    server_spark_version = "1.2.3"

    def spark_submission_status_response(state, message=None):
        """Return a mocked SubmissionStatusResponse for driver state ``state``.

        If ``message`` is given, it is included as the response's
        ``message`` field.
        """
        body = {
            "action": "SubmissionStatusResponse",
            "driverState": state,
            "serverSparkVersion": server_spark_version,
            "submissionId": submission_id,
            "success": "true",
            "workerHostPort": worker_host_port,
            "workerId": worker_id,
        }
        if message is not None:
            body["message"] = message
        return {'json': body}

    with temp_name(suffix='fake.jar') as fake_jar:
        # Build a minimal fake job-server jar that holds only a Spark version
        # properties entry. The archive must be closed before the job server
        # reads it, so it gets its own `with` block.
        with zipfile.ZipFile(fake_jar, 'w') as jar_zip:
            with jar_zip.open('spark-version-info.properties', 'w') as fout:
                fout.write(b'version=4.5.6')

        options = pipeline_options.SparkRunnerOptions()
        options.spark_job_server_jar = fake_jar
        job_server = spark_uber_jar_job_server.SparkUberJarJobServer(
            spark_rest_url, options)

        # Prepare and stage the job.
        plan = TestJobServicePlan(job_server)
        prepare_response = plan.prepare(beam_runner_api_pb2.Pipeline())
        retrieval_token = plan.stage(
            beam_runner_api_pb2.Pipeline(),
            prepare_response.artifact_staging_endpoint.url,
            prepare_response.staging_session_token)

        # Now actually run the job.
        http_mock.post(
            spark_rest_url + '/v1/submissions/create',
            json={
                "action": "CreateSubmissionResponse",
                "message": "Driver successfully submitted as " + submission_id,
                "serverSparkVersion": server_spark_version,
                "submissionId": submission_id,
                "success": "true",
            })
        job_server.Run(
            beam_job_api_pb2.RunJobRequest(
                preparation_id=prepare_response.preparation_id,
                retrieval_token=retrieval_token))

        # Check the status until the job is "done" and get all error messages.
        # The mock returns the listed responses in order, one per status poll.
        http_mock.get(
            spark_rest_url + '/v1/submissions/status/' + submission_id,
            [
                spark_submission_status_response('RUNNING'),
                spark_submission_status_response('RUNNING'),
                spark_submission_status_response('ERROR', message='oops'),
            ])

        state_stream = job_server.GetStateStream(
            beam_job_api_pb2.GetJobStateRequest(
                job_id=prepare_response.preparation_id))
        self.assertEqual([s.state for s in state_stream], [
            beam_job_api_pb2.JobState.STOPPED,
            beam_job_api_pb2.JobState.RUNNING,
            beam_job_api_pb2.JobState.RUNNING,
            beam_job_api_pb2.JobState.FAILED,
        ])

        message_stream = job_server.GetMessageStream(
            beam_job_api_pb2.JobMessagesRequest(
                job_id=prepare_response.preparation_id))

        def get_item(x):
            """Reduce a message-stream entry to its message or its job state."""
            if x.HasField('message_response'):
                return x.message_response
            return x.state_response.state

        error_importance = (
            beam_job_api_pb2.JobMessage.MessageImportance.JOB_MESSAGE_ERROR)
        self.assertEqual([get_item(m) for m in message_stream], [
            beam_job_api_pb2.JobState.STOPPED,
            beam_job_api_pb2.JobState.RUNNING,
            beam_job_api_pb2.JobMessage(
                message_id='message0',
                time='0',
                importance=error_importance,
                message_text="oops"),
            beam_job_api_pb2.JobState.FAILED,
        ])