def test_create_session_malformed_json(dag, mocker):
    """Creating a session must raise AirflowBadRequest on a malformed JSON body.

    The mocked HTTP connection answers with status 201 and a truncated JSON
    document, which ``create_session`` is expected to reject.
    """
    operator = LivySessionOperator(
        statements=[],
        task_id="test_create_session_malformed_json",
        dag=dag,
    )
    truncated_response = mock_http_calls(201, content=b'{"id":{}')
    mocker.patch.object(HttpHook, "get_conn", return_value=truncated_response)
    with raises(AirflowBadRequest) as exc_info:
        operator.create_session()
    message = (
        "\n\nImitated malformed JSON response when creating a session, "
        f"got the expected exception:\n<{exc_info.value}>"
    )
    print(message)
def test_create_session_get_id(dag, mocker):
    """A well-formed 201 response must populate the operator's ``session_id``."""
    expected_session_id = 456
    operator = LivySessionOperator(
        statements=[],
        task_id="test_create_session_get_id",
        dag=dag,
    )
    mocker.patch.object(
        HttpHook,
        "get_conn",
        return_value=mock_http_calls(201, content=b'{"id": 456}'),
    )
    operator.create_session()
    assert operator.session_id == expected_session_id
# Example #3
# 0
def test_run_session_successfully(dag, mocker):
    """Run a two-statement session twice, first without and then with log spilling.

    Checks that every statement is submitted on each run (the spy's counter is
    cumulative across runs). It also checks that session logs are spilled only
    when ``spill_logs`` is enabled: a successful run with ``spill_logs=False``
    must not spill them.
    """
    st1 = LivySessionOperator.Statement(kind="spark", code="x = 1;")
    st2 = LivySessionOperator.Statement(kind="pyspark", code="print 'hi';")
    statements = [st1, st2]
    op = LivySessionOperator(
        statements=statements,
        spill_logs=False,
        task_id="test_run_session_successfully",
        dag=dag,
    )
    spill_logs_spy = mocker.spy(op, "spill_session_logs")
    submit_statement_spy = mocker.spy(op, "submit_statement")
    mock_livy_session_responses(mocker)
    op.execute({})

    # Compare, don't assign: assigning to call_count would overwrite the spy's
    # counter and check nothing.
    assert submit_statement_spy.call_count == len(statements)
    # spill_logs is False and the session completed successfully, so we don't expect logs.
    spill_logs_spy.assert_not_called()

    op.spill_logs = True
    op.execute({})

    # The spy keeps counting across both executions.
    assert submit_statement_spy.call_count == len(statements) * 2
    # We set spill_logs to True this time, therefore expecting logs.
    spill_logs_spy.assert_called_once()
# Example #4
# 0
def test_run_session_logs_greater_than_page_size(dag, mocker):
    """Spilling logs for a session mocked with 321 log lines must fetch 4 pages."""
    operator = LivySessionOperator(
        statements=[],
        spill_logs=True,
        task_id="test_run_session_logs_greater_than_page_size",
        dag=dag,
    )
    page_fetch_spy = mocker.spy(operator, "fetch_log_page")
    mock_livy_session_responses(mocker, log_lines=321)
    operator.execute({})
    # NOTE(review): presumably a 100-line page size, i.e. ceil(321 / 100) == 4 — confirm.
    expected_page_fetches = 4
    assert page_fetch_spy.call_count == expected_page_fetches
def test_create_session_string_id(dag, mocker):
    """A string session ID in the create-session response must raise AirflowException."""
    operator = LivySessionOperator(
        statements=[],
        task_id="test_create_session_string_id",
        dag=dag,
    )
    string_id_response = mock_http_calls(
        201,
        content=b'{"id":"unexpectedly, a string!"}',
    )
    mocker.patch.object(HttpHook, "get_conn", return_value=string_id_response)
    with raises(AirflowException) as exc_info:
        operator.create_session()
    message = (
        "\n\nImitated server returning a string for a session ID, "
        f"got the expected exception:\n<{exc_info.value}>"
    )
    print(message)
# Example #6
# 0
def test_run_session_logs_one_page_size(dag, mocker):
    """Spilling logs for a session mocked with 100 log lines must fetch exactly one page."""
    operator = LivySessionOperator(
        statements=[],
        spill_logs=True,
        task_id="test_run_session_logs_one_page_size",
        dag=dag,
    )
    page_fetch_spy = mocker.spy(operator, "fetch_log_page")
    mock_livy_session_responses(mocker, log_lines=100)
    operator.execute({})
    page_fetch_spy.assert_called_once()
def test_create_session_bad_response_codes(dag, mocker, code):
    """An error status on session creation must raise AirflowException.

    ``code`` is an HTTP status code, presumably supplied by a pytest
    parametrization that is not visible here.
    """
    operator = LivySessionOperator(
        statements=[],
        task_id="test_create_session_bad_response_codes",
        dag=dag,
    )
    error_response = mock_http_calls(
        code,
        content=b"Error content",
        reason="Good reason",
    )
    mocker.patch.object(HttpHook, "get_conn", return_value=error_response)
    with raises(AirflowException) as exc_info:
        operator.create_session()
    message = (
        f"\n\nImitated the {code} error response when creating a session, "
        f"got the expected exception:\n<{exc_info.value}>"
    )
    print(message)
# Example #8
# 0
def test_run_session_logs_missing_attrs_in_json(dag, mocker):
    """A /logs response lacking expected attributes must raise AirflowException.

    The overridden logs response carries only ``id`` and ``from``.
    """
    operator = LivySessionOperator(
        statements=[],
        spill_logs=True,
        task_id="test_run_session_logs_missing_attrs_in_json",
        dag=dag,
    )
    incomplete_logs_body = '{"id": 1, "from": 2}'
    mock_livy_session_responses(mocker, log_override_response=incomplete_logs_body)
    with raises(AirflowException) as exc_info:
        operator.execute({})
    message = (
        "\n\nImitated missing attributes when calling /logs , "
        f"got the expected exception:\n<{exc_info.value}>"
    )
    print(message)
# Example #9
# 0
def test_run_session_logs_malformed_json(dag, mocker):
    """A malformed JSON body from /logs must raise AirflowException.

    The task_id matches this test's name. It was previously copy-pasted from
    ``test_run_session_logs_greater_than_page_size``, which could collide on a
    shared DAG and mislabels the task.
    """
    op = LivySessionOperator(
        statements=[],
        spill_logs=True,
        task_id="test_run_session_logs_malformed_json",
        dag=dag,
    )
    mock_livy_session_responses(mocker,
                                log_override_response='{"invalid":json]}')
    with raises(AirflowException) as ae:
        op.execute({})
    print(f"\n\nImitated malformed response when calling /logs , "
          f"got the expected exception:\n<{ae.value}>")
def test_allowed_statement_kinds():
    """Constructing a Statement with an unsupported kind must raise AirflowException."""
    unsupported_kind = "unknown"
    with raises(AirflowException) as exc_info:
        LivySessionOperator.Statement(kind=unsupported_kind, code="a=5;")
    message = (
        f"\n\nTried to create a statement with kind '{unsupported_kind}', "
        f"got the expected exception:\n<{exc_info.value}>"
    )
    print(message)
# Example #11
# 0
def test_run_session_error_during_status_probing(dag, mocker, code):
    """An error while polling session status must raise and spill session logs.

    ``code`` is an HTTP status code, presumably supplied by a pytest
    parametrization that is not visible here. The spy counts show that logs
    are spilled after a failure even when ``spill_logs`` is False, provided
    the operator already held a session_id.
    """
    operator = LivySessionOperator(
        statements=[],
        spill_logs=True,
        task_id="test_run_session_error_during_status_probing",
        dag=dag,
    )
    spill_spy = mocker.spy(operator, "spill_session_logs")
    failing_status = MockedResponse(code, body=f"Response from server:{code}")
    mock_livy_session_responses(mocker, mock_get_session=[failing_status])

    with raises(AirflowException) as exc_info:
        operator.execute({})
    message = (
        f"\n\nImitated {code} response from server during session creation probing , "
        f"got the expected exception:\n<{exc_info.value}>"
    )
    print(message)
    # First run: spill_logs=True and a session_id existed when the error occurred.
    spill_spy.assert_called_once()

    operator.spill_logs = False
    with raises(AirflowException):
        operator.execute({})
    # Second run: spill_logs=False, yet the failure after obtaining a
    # session_id still triggers a spill.
    assert spill_spy.call_count == 2
# Example #12
# 0
def test_run_session_error_before_session_created(dag, mocker):
    """A connection error before a session exists must propagate without spilling logs.

    No HTTP response is mocked; only the connection lookup is patched, so the
    request fails with ``requests.exceptions.ConnectionError``.
    """
    operator = LivySessionOperator(
        statements=[],
        spill_logs=True,
        task_id="test_run_session_error_before_session_created",
        dag=dag,
    )
    spill_spy = mocker.spy(operator, "spill_session_logs")
    fake_connection = Connection(host="HOST", port=123)
    mocker.patch.object(HttpHook, "get_connection", return_value=fake_connection)

    with raises(requests.exceptions.ConnectionError) as exc_info:
        operator.execute({})
    message = (
        "\n\nNo response from server was mocked, "
        f"got the expected exception:\n<{exc_info.value}>"
    )
    print(message)
    # spill_logs is True, but no session_id was obtained, so nothing is spilled.
    spill_spy.assert_not_called()
def test_allowed_session_kinds(dag):
    """Constructing the operator with an unsupported session kind must raise AirflowException."""
    unsupported_kind = "unknown"
    with raises(AirflowException) as exc_info:
        LivySessionOperator(
            kind=unsupported_kind,
            statements=[],
            task_id="test_allowed_session_kinds",
            dag=dag,
        )
    message = (
        f"\n\nTried to create a session with kind '{unsupported_kind}', "
        f"got the expected exception:\n<{exc_info.value}>"
    )
    print(message)
# Example #14
# 0
def test_statement_repr():
    """A Statement's string form shows its kind, then its code between dashed rules."""
    statement = LivySessionOperator.Statement(kind="spark", code="a=5;\nprint('s');")
    expected_text = """
{
  Statement, kind: spark
  code:
--------------------------------------------------------------------------------
a=5;
print('s');
--------------------------------------------------------------------------------
}"""
    assert str(statement) == expected_text
# Example #15
# 0
def test_run_session_error_when_submitting_statement(dag, mocker, code):
    """Statement-submission failures must raise and still spill session logs.

    ``code`` is an HTTP status code, presumably supplied by a pytest
    parametrization that is not visible here. The mocked statement POST
    responses, in order, are:

    1. a valid statement id;
    2. a ``code`` error;
    3. a 200 response without an ``id`` key.

    The first run fails on the second statement. The second run fails on its
    first statement, using the third response. This assumes the mocked
    responses are consumed in order across both runs, as the final
    submit_statement count of 3 indicates.
    """
    st1 = LivySessionOperator.Statement(kind="spark", code="x = 1;")
    st2 = LivySessionOperator.Statement(kind="pyspark", code="print 'hi';")
    op = LivySessionOperator(
        statements=[st1, st2],
        spill_logs=True,
        task_id="test_run_session_error_when_submitting_statement",
        dag=dag,
    )
    spill_logs_spy = mocker.spy(op, "spill_session_logs")
    submit_statement_spy = mocker.spy(op, "submit_statement")
    mock_livy_session_responses(
        mocker,
        mock_post_statement=[
            MockedResponse(200, json_body={"id": STATEMENT_ID}),
            MockedResponse(code, body=f"Response from server:{code}"),
            MockedResponse(200, json_body={"no id here": "haha"}),
        ],
    )

    with raises(AirflowException) as first_error:
        op.execute({})
    print(
        f"\n\nImitated {code} response from server during second statement submission, "
        f"got the expected exception:\n<{first_error.value}>")
    # spill_logs=True, and the operator had a session_id by the time the error occurred.
    spill_logs_spy.assert_called_once()
    assert submit_statement_spy.call_count == 2

    op.spill_logs = False
    # Capture this run's exception; printing the first run's would be misleading.
    with raises(AirflowException) as second_error:
        op.execute({})
    print(
        "\n\nImitated a response without a statement id during first statement submission, "
        f"got the expected exception:\n<{second_error.value}>")
    # spill_logs=False, but an error occurred while the operator had a session_id.
    assert spill_logs_spy.call_count == 2
    assert submit_statement_spy.call_count == 3
# Example #16
# 0
def test_jinja(dag):
    """The templated session name and statement code must render from the context."""
    templated_statements = [
        LivySessionOperator.Statement(kind="spark", code="x=1+{{ custom_param }};"),
        LivySessionOperator.Statement(
            kind="pyspark", code="print('{{run_id | replace(':', '-')}}')"
        ),
    ]
    operator = LivySessionOperator(
        name="test_jinja_{{ run_id }}",
        statements=templated_statements,
        task_id="test_jinja_session",
        dag=dag,
    )
    render_context = {"run_id": "hello:world", "custom_param": 3}
    operator.render_template_fields(render_context)

    assert operator.name == "test_jinja_hello:world"
    assert operator.statements[0].code == "x=1+3;"
    assert operator.statements[1].code == "print('hello-world')"
def test_create_session_params(dag, mocker):
    """Operator arguments must be stored and forwarded as the session-creation JSON.

    Builds an operator with every session parameter set and patches
    ``HttpHook.run`` to return a response with ``{"id": 1}``. After calling
    ``create_session()``, it checks the operator's attributes and the JSON
    payload found in the ``HttpHook.run`` call arguments.
    """
    heartbeat_timeout = 9
    session_start_timeout_sec = 11
    session_start_poll_period_sec = 22
    statemt_timeout_minutes = 33
    statemt_poll_period_sec = 44
    http_conn_id = "foo"
    spill_logs = True
    st1 = LivySessionOperator.Statement(kind="spark", code="x=1")
    st2 = LivySessionOperator.Statement(kind="sparkr", code="print hi")
    op = LivySessionOperator(
        statements=[st1, st2],
        kind="pyspark",
        proxy_user="******",
        jars=["jar1", "jar2"],
        py_files=["py_file1", "py_file2"],
        files=["file1", "file2"],
        driver_memory="driver_memory",
        driver_cores=1,
        executor_memory="executor_memory",
        executor_cores=2,
        num_executors=3,
        archives=["archive1", "archive2"],
        queue="queue",
        name="name",
        conf={
            "key1": "val1",
            "key2": 2
        },
        heartbeat_timeout=heartbeat_timeout,
        session_start_timeout_sec=session_start_timeout_sec,
        session_start_poll_period_sec=session_start_poll_period_sec,
        statemt_timeout_minutes=statemt_timeout_minutes,
        statemt_poll_period_sec=statemt_poll_period_sec,
        http_conn_id=http_conn_id,
        spill_logs=spill_logs,
        task_id="test_create_session_params",
        dag=dag,
    )
    mock_response = Response()
    mock_response._content = b'{"id": 1}'
    patched_hook = mocker.patch.object(HttpHook,
                                       "run",
                                       return_value=mock_response)

    op.create_session()

    assert op.statements[0] == st1
    assert op.statements[1] == st2
    assert op.heartbeat_timeout == heartbeat_timeout
    assert op.session_start_timeout_sec == session_start_timeout_sec
    assert op.session_start_poll_period_sec == session_start_poll_period_sec
    assert op.statemt_timeout_minutes == statemt_timeout_minutes
    assert op.statemt_poll_period_sec == statemt_poll_period_sec
    assert op.http_conn_id == http_conn_id
    assert op.spill_logs == spill_logs

    # Payload field names are camelCase; values mirror the operator arguments above.
    expected_json = {
        "kind": "pyspark",
        "proxyUser": "******",
        "jars": ["jar1", "jar2"],
        "pyFiles": ["py_file1", "py_file2"],
        "files": ["file1", "file2"],
        "driverMemory": "driver_memory",
        "driverCores": 1,
        "executorMemory": "executor_memory",
        "executorCores": 2,
        "numExecutors": 3,
        "archives": ["archive1", "archive2"],
        "queue": "queue",
        "name": "name",
        "conf": {
            "key1": "val1",
            "key2": 2
        },
        "heartbeatTimeoutInSecond": heartbeat_timeout,
    }
    # Use the public (args, kwargs) pair instead of the private
    # Mock._call_matcher. The mock has no spec, so both give the same call.
    actual_args, actual_kwargs = patched_hook.call_args
    actual_json = find_json_in_args(actual_args, actual_kwargs)
    if actual_json is None:
        raise AssertionError(
            f"Can not find JSON in HttpHook args.\n"
            f"Args:\n{actual_args}\n"
            f"KWArgs (JSON should be under 'data' key):\n{actual_kwargs}")
    diff = DeepDiff(actual_json, expected_json, ignore_order=True)
    if diff:
        # DeepDiff results can contain sets and type objects that json cannot
        # encode. default=str keeps this diagnostic print from raising
        # TypeError and hiding the real assertion failure below.
        print(f"\nDifference:\n{json.dumps(diff, indent=2, default=str)}")
    assert not diff