def test_create_session_malformed_json(dag, mocker):
    """A truncated JSON body on session creation must raise AirflowBadRequest."""
    operator = LivySessionOperator(
        statements=[],
        task_id="test_create_session_malformed_json",
        dag=dag,
    )
    # b'{"id":{}' is missing its closing brace, so it is not valid JSON.
    mocker.patch.object(
        HttpHook,
        "get_conn",
        return_value=mock_http_calls(201, content=b'{"id":{}'),
    )
    with raises(AirflowBadRequest) as exc_info:
        operator.create_session()
    print(
        f"\n\nImitated malformed JSON response when creating a session, "
        f"got the expected exception:\n<{exc_info.value}>"
    )
def test_create_session_get_id(dag, mocker):
    """The integer ``id`` returned on session creation is stored as ``session_id``."""
    operator = LivySessionOperator(
        statements=[],
        task_id="test_create_session_get_id",
        dag=dag,
    )
    mocker.patch.object(
        HttpHook,
        "get_conn",
        return_value=mock_http_calls(201, content=b'{"id": 456}'),
    )
    operator.create_session()
    assert operator.session_id == 456
def test_run_session_successfully(dag, mocker):
    """Run a two-statement session end to end, first without and then with ``spill_logs``.

    Checks that every statement is submitted on each run. Session logs should be
    spilled only on the run where ``spill_logs`` is enabled.
    """
    st1 = LivySessionOperator.Statement(kind="spark", code="x = 1;")
    st2 = LivySessionOperator.Statement(kind="pyspark", code="print 'hi';")
    statements = [st1, st2]
    op = LivySessionOperator(
        statements=statements,
        spill_logs=False,
        task_id="test_run_session_successfully",
        dag=dag,
    )
    spill_logs_spy = mocker.spy(op, "spill_session_logs")
    submit_statement_spy = mocker.spy(op, "submit_statement")
    mock_livy_session_responses(mocker)

    op.execute({})
    # Compare, don't assign: assigning call_count checks nothing and silently
    # overwrites the spy's counter.
    assert submit_statement_spy.call_count == len(statements)
    # spill_logs is False and the session completed successfully, so no logs are expected.
    spill_logs_spy.assert_not_called()

    op.spill_logs = True
    op.execute({})
    # The spy accumulates across both runs.
    assert submit_statement_spy.call_count == len(statements) * 2
    # spill_logs is True this time, so logs are expected exactly once.
    spill_logs_spy.assert_called_once()
def test_run_session_logs_greater_than_page_size(dag, mocker):
    """Spilling 321 mocked log lines takes exactly four ``fetch_log_page`` calls."""
    operator = LivySessionOperator(
        statements=[],
        spill_logs=True,
        task_id="test_run_session_logs_greater_than_page_size",
        dag=dag,
    )
    page_fetch_spy = mocker.spy(operator, "fetch_log_page")
    mock_livy_session_responses(mocker, log_lines=321)

    operator.execute({})

    expected_pages = 4
    assert page_fetch_spy.call_count == expected_pages
def test_create_session_string_id(dag, mocker):
    """A non-integer (string) session id in the response must raise AirflowException."""
    operator = LivySessionOperator(
        statements=[],
        task_id="test_create_session_string_id",
        dag=dag,
    )
    fake_conn = mock_http_calls(201, content=b'{"id":"unexpectedly, a string!"}')
    mocker.patch.object(HttpHook, "get_conn", return_value=fake_conn)

    with raises(AirflowException) as exc_info:
        operator.create_session()
    print(
        f"\n\nImitated server returning a string for a session ID, "
        f"got the expected exception:\n<{exc_info.value}>"
    )
def test_run_session_logs_one_page_size(dag, mocker):
    """Spilling 100 mocked log lines takes a single ``fetch_log_page`` call."""
    operator = LivySessionOperator(
        statements=[],
        spill_logs=True,
        task_id="test_run_session_logs_one_page_size",
        dag=dag,
    )
    page_fetch_spy = mocker.spy(operator, "fetch_log_page")
    mock_livy_session_responses(mocker, log_lines=100)

    operator.execute({})

    page_fetch_spy.assert_called_once()
def test_create_session_bad_response_codes(dag, mocker, code):
    """An error HTTP status on session creation must raise AirflowException.

    ``code`` is supplied from outside this view (presumably a parametrize
    decorator or fixture; confirm).
    """
    operator = LivySessionOperator(
        statements=[],
        task_id="test_create_session_bad_response_codes",
        dag=dag,
    )
    fake_conn = mock_http_calls(code, content=b"Error content", reason="Good reason")
    mocker.patch.object(HttpHook, "get_conn", return_value=fake_conn)

    with raises(AirflowException) as exc_info:
        operator.create_session()
    print(
        f"\n\nImitated the {code} error response when creating a session, "
        f"got the expected exception:\n<{exc_info.value}>"
    )
def test_run_session_logs_missing_attrs_in_json(dag, mocker):
    """A /logs response lacking expected attributes must raise AirflowException."""
    operator = LivySessionOperator(
        statements=[],
        spill_logs=True,
        task_id="test_run_session_logs_missing_attrs_in_json",
        dag=dag,
    )
    # Valid JSON, but only "id" and "from" are present.
    incomplete_logs_body = '{"id": 1, "from": 2}'
    mock_livy_session_responses(mocker, log_override_response=incomplete_logs_body)

    with raises(AirflowException) as exc_info:
        operator.execute({})
    print(
        f"\n\nImitated missing attributes when calling /logs , "
        f"got the expected exception:\n<{exc_info.value}>"
    )
def test_run_session_logs_malformed_json(dag, mocker):
    """A malformed JSON body from the /logs endpoint must raise AirflowException."""
    op = LivySessionOperator(
        statements=[],
        spill_logs=True,
        # Use this test's own name as the task_id, like every other test here.
        # It was previously copy-pasted from the "greater_than_page_size" test.
        task_id="test_run_session_logs_malformed_json",
        dag=dag,
    )
    mock_livy_session_responses(mocker, log_override_response='{"invalid":json]}')
    with raises(AirflowException) as ae:
        op.execute({})
    print(
        f"\n\nImitated malformed response when calling /logs , "
        f"got the expected exception:\n<{ae.value}>"
    )
def test_allowed_statement_kinds():
    """Creating a Statement with an unsupported kind raises AirflowException."""
    bad_kind = "unknown"
    with raises(AirflowException) as exc_info:
        LivySessionOperator.Statement(kind=bad_kind, code="a=5;")
    print(
        f"\n\nTried to create a statement with kind '{bad_kind}', "
        f"got the expected exception:\n<{exc_info.value}>"
    )
def test_run_session_error_during_status_probing(dag, mocker, code):
    """An error response while probing session status raises and spills logs.

    Runs twice: once with ``spill_logs=True`` and once with ``spill_logs=False``.
    Logs are expected to be spilled both times, because the session id is
    already known when the error happens.
    """
    operator = LivySessionOperator(
        statements=[],
        spill_logs=True,
        task_id="test_run_session_error_during_status_probing",
        dag=dag,
    )
    spill_spy = mocker.spy(operator, "spill_session_logs")
    failing_probe = MockedResponse(code, body=f"Response from server:{code}")
    mock_livy_session_responses(mocker, mock_get_session=[failing_probe])

    with raises(AirflowException) as exc_info:
        operator.execute({})
    print(
        f"\n\nImitated {code} response from server during session creation probing , "
        f"got the expected exception:\n<{exc_info.value}>")
    # spill_logs=True and the operator already had a session_id when the error occurred.
    spill_spy.assert_called_once()

    operator.spill_logs = False
    with raises(AirflowException):
        operator.execute({})
    # spill_logs=False, but an error occurred after the session_id was known,
    # so logs are still spilled.
    assert spill_spy.call_count == 2
def test_run_session_error_before_session_created(dag, mocker):
    """A connection failure before any session exists propagates and spills no logs."""
    operator = LivySessionOperator(
        statements=[],
        spill_logs=True,
        task_id="test_run_session_error_before_session_created",
        dag=dag,
    )
    spill_spy = mocker.spy(operator, "spill_session_logs")
    # Only the connection is mocked; no HTTP responses are, so the request fails.
    unreachable = Connection(host="HOST", port=123)
    mocker.patch.object(HttpHook, "get_connection", return_value=unreachable)

    with raises(requests.exceptions.ConnectionError) as exc_info:
        operator.execute({})
    print(
        f"\n\nNo response from server was mocked, "
        f"got the expected exception:\n<{exc_info.value}>"
    )
    # spill_logs is True, but there is no session_id yet, so nothing can be spilled.
    spill_spy.assert_not_called()
def test_allowed_session_kinds(dag):
    """Constructing the operator with an unsupported session kind raises AirflowException."""
    bad_kind = "unknown"
    with raises(AirflowException) as exc_info:
        LivySessionOperator(
            kind=bad_kind,
            statements=[],
            task_id="test_allowed_session_kinds",
            dag=dag,
        )
    print(
        f"\n\nTried to create a session with kind '{bad_kind}', "
        f"got the expected exception:\n<{exc_info.value}>"
    )
def test_statement_repr():
    """Check the human-readable ``__str__`` rendering of a Statement."""
    st = LivySessionOperator.Statement(kind="spark", code="a=5;\nprint('s');")
    # NOTE(review): the statement code above contains a newline, but this expected
    # literal appears to have been collapsed onto a single line (newlines replaced
    # by spaces). Confirm it against the actual Statement.__str__ output.
    assert (st.__str__() == """ { Statement, kind: spark code: -------------------------------------------------------------------------------- a=5; print('s'); -------------------------------------------------------------------------------- }""")
def test_run_session_error_when_submitting_statement(dag, mocker, code):
    """Statement-submission failures raise, and logs are spilled once a session exists.

    Run 1 (spill_logs=True): the first statement is accepted and the second gets
    a ``code`` response. That is two submissions and one spill.
    Run 2 (spill_logs=False): the very first submission fails. Presumably it
    consumes the third mocked POST response, which has no "id". Logs are still
    spilled.
    """
    st1 = LivySessionOperator.Statement(kind="spark", code="x = 1;")
    st2 = LivySessionOperator.Statement(kind="pyspark", code="print 'hi';")
    op = LivySessionOperator(
        statements=[st1, st2],
        spill_logs=True,
        task_id="test_run_session_error_when_submitting_statement",
        dag=dag,
    )
    spill_logs_spy = mocker.spy(op, "spill_session_logs")
    submit_statement_spy = mocker.spy(op, "submit_statement")
    mock_livy_session_responses(
        mocker,
        mock_post_statement=[
            MockedResponse(200, json_body={"id": STATEMENT_ID}),
            MockedResponse(code, body=f"Response from server:{code}"),
            MockedResponse(200, json_body={"no id here": "haha"}),
        ],
    )

    with raises(AirflowException) as ae:
        op.execute({})
    print(
        f"\n\nImitated {code} response from server during second statement submission, "
        f"got the expected exception:\n<{ae.value}>")
    # spill_logs=True, and the operator had a session_id when the error occurred.
    spill_logs_spy.assert_called_once()
    assert submit_statement_spy.call_count == 2

    op.spill_logs = False
    # Re-bind `ae` so the message below reports THIS run's exception.
    # Previously it printed the stale exception from the first run.
    with raises(AirflowException) as ae:
        op.execute({})
    print(
        f"\n\nImitated {code} response from server during first statement submission, "
        f"got the expected exception:\n<{ae.value}>")
    # spill_logs=False, but an error occurred after the session_id was known.
    assert spill_logs_spy.call_count == 2
    assert submit_statement_spy.call_count == 3
def test_jinja(dag):
    """Templated ``name`` and statement ``code`` fields render from the given context."""
    templated_statements = [
        LivySessionOperator.Statement(kind="spark", code="x=1+{{ custom_param }};"),
        LivySessionOperator.Statement(
            kind="pyspark", code="print('{{run_id | replace(':', '-')}}')"
        ),
    ]
    operator = LivySessionOperator(
        name="test_jinja_{{ run_id }}",
        statements=templated_statements,
        task_id="test_jinja_session",
        dag=dag,
    )

    context = {"run_id": "hello:world", "custom_param": 3}
    operator.render_template_fields(context)

    assert operator.name == "test_jinja_hello:world"
    assert operator.statements[0].code == "x=1+3;"
    assert operator.statements[1].code == "print('hello-world')"
def test_create_session_params(dag, mocker):
    """Constructor params are stored on the operator and sent in the session-creation payload.

    ``HttpHook.run`` is patched to return ``{"id": 1}``. The JSON payload passed to
    it is then located with ``find_json_in_args`` and compared with the expected
    body, ignoring list order.
    """
    heartbeat_timeout = 9
    session_start_timeout_sec = 11
    session_start_poll_period_sec = 22
    statemt_timeout_minutes = 33
    statemt_poll_period_sec = 44
    http_conn_id = "foo"
    spill_logs = True
    st1 = LivySessionOperator.Statement(kind="spark", code="x=1")
    st2 = LivySessionOperator.Statement(kind="sparkr", code="print hi")
    op = LivySessionOperator(
        statements=[st1, st2],
        kind="pyspark",
        proxy_user="******",
        jars=["jar1", "jar2"],
        py_files=["py_file1", "py_file2"],
        files=["file1", "file2"],
        driver_memory="driver_memory",
        driver_cores=1,
        executor_memory="executor_memory",
        executor_cores=2,
        num_executors=3,
        archives=["archive1", "archive2"],
        queue="queue",
        name="name",
        conf={"key1": "val1", "key2": 2},
        heartbeat_timeout=heartbeat_timeout,
        session_start_timeout_sec=session_start_timeout_sec,
        session_start_poll_period_sec=session_start_poll_period_sec,
        statemt_timeout_minutes=statemt_timeout_minutes,
        statemt_poll_period_sec=statemt_poll_period_sec,
        http_conn_id=http_conn_id,
        spill_logs=spill_logs,
        task_id="test_create_session_params",
        dag=dag,
    )
    mock_response = Response()
    mock_response._content = b'{"id": 1}'
    patched_hook = mocker.patch.object(HttpHook, "run", return_value=mock_response)
    op.create_session()

    assert op.statements[0] == st1
    assert op.statements[1] == st2
    assert op.heartbeat_timeout == heartbeat_timeout
    assert op.session_start_timeout_sec == session_start_timeout_sec
    assert op.session_start_poll_period_sec == session_start_poll_period_sec
    assert op.statemt_timeout_minutes == statemt_timeout_minutes
    assert op.statemt_poll_period_sec == statemt_poll_period_sec
    assert op.http_conn_id == http_conn_id
    assert op.spill_logs == spill_logs

    # Same structure json.loads produced from the previous JSON text literal.
    expected_json = {
        "kind": "pyspark",
        "proxyUser": "******",
        "jars": ["jar1", "jar2"],
        "pyFiles": ["py_file1", "py_file2"],
        "files": ["file1", "file2"],
        "driverMemory": "driver_memory",
        "driverCores": 1,
        "executorMemory": "executor_memory",
        "executorCores": 2,
        "numExecutors": 3,
        "archives": ["archive1", "archive2"],
        "queue": "queue",
        "name": "name",
        "conf": {"key1": "val1", "key2": 2},
        "heartbeatTimeoutInSecond": heartbeat_timeout,
    }

    # call_args unpacks to (args, kwargs) directly; Mock's private _call_matcher
    # is not needed for a non-autospecced mock.
    actual_args, actual_kwargs = patched_hook.call_args
    actual_json = find_json_in_args(actual_args, actual_kwargs)
    if actual_json is None:
        raise AssertionError(
            f"Can not find JSON in HttpHook args.\n"
            f"Args:\n{actual_args}\n"
            f"KWArgs (JSON should be under 'data' key):\n{actual_kwargs}")

    diff = DeepDiff(actual_json, expected_json, ignore_order=True)
    if diff:
        # default=str: DeepDiff reports can contain values json.dumps can't encode
        # (ordered sets of added/removed items, type objects). Without it the
        # debug print raises TypeError and hides the real assertion failure.
        print(f"\nDifference:\n{json.dumps(diff, indent=2, default=str)}")
    assert not diff