# Example #1
# 0
def test_run_session_successfully(dag, mocker):
    """Run a two-statement session twice and check submissions and log spilling.

    The first run uses ``spill_logs=False``. Because the mocked session completes
    successfully, logs must not be spilled. The second run flips ``spill_logs`` to
    True, so ``spill_session_logs`` must have been called exactly once overall.
    """
    st1 = LivySessionOperator.Statement(kind="spark", code="x = 1;")
    st2 = LivySessionOperator.Statement(kind="pyspark", code="print 'hi';")
    statements = [st1, st2]
    op = LivySessionOperator(
        statements=statements,
        spill_logs=False,
        task_id="test_run_session_successfully",
        dag=dag,
    )
    spill_logs_spy = mocker.spy(op, "spill_session_logs")
    submit_statement_spy = mocker.spy(op, "submit_statement")
    mock_livy_session_responses(mocker)
    op.execute({})

    # Each statement is expected to be submitted once per run. This must be an
    # assertion: assigning to ``call_count`` would silently reset the spy.
    assert submit_statement_spy.call_count == len(statements)

    # spill_logs is False and session completed successfully, so we don't expect logs.
    spill_logs_spy.assert_not_called()
    op.spill_logs = True
    op.execute({})

    # The spy accumulates across runs, so after two runs every statement has
    # been submitted twice.
    assert submit_statement_spy.call_count == len(statements) * 2
    # We set spill_logs to True this time, therefore expecting logs.
    spill_logs_spy.assert_called_once()
# Example #2
# 0
def test_jinja(dag):
    """Jinja templates in the session name and statement code get rendered.

    Both ``run_id`` and ``custom_param`` come from the render context. The
    ``replace`` filter used inside the second statement must also be applied.
    """
    spark_statement = LivySessionOperator.Statement(
        kind="spark", code="x=1+{{ custom_param }};")
    pyspark_statement = LivySessionOperator.Statement(
        kind="pyspark", code="print('{{run_id | replace(':', '-')}}')")
    templated_statements = [spark_statement, pyspark_statement]

    op = LivySessionOperator(
        name="test_jinja_{{ run_id }}",
        statements=templated_statements,
        task_id="test_jinja_session",
        dag=dag,
    )
    render_context = {"run_id": "hello:world", "custom_param": 3}
    op.render_template_fields(render_context)

    assert op.name == "test_jinja_hello:world"
    first_code = op.statements[0].code
    assert first_code == "x=1+3;"
    second_code = op.statements[1].code
    assert second_code == "print('hello-world')"


def test_allowed_statement_kinds():
    """Building a Statement with an unsupported kind raises AirflowException."""
    unsupported_kind = "unknown"
    with raises(AirflowException) as exc_info:
        LivySessionOperator.Statement(kind=unsupported_kind, code="a=5;")
    report = (
        f"\n\nTried to create a statement with kind '{unsupported_kind}', "
        f"got the expected exception:\n<{exc_info.value}>"
    )
    print(report)
# Example #4
# 0
def test_statement_repr():
    """The string form of a Statement shows its kind and its code between rulers."""
    statement = LivySessionOperator.Statement(kind="spark", code="a=5;\nprint('s');")
    expected_text = """
{
  Statement, kind: spark
  code:
--------------------------------------------------------------------------------
a=5;
print('s');
--------------------------------------------------------------------------------
}"""
    rendered = str(statement)
    assert rendered == expected_text
# Example #5
# 0
def test_run_session_error_when_submitting_statement(dag, mocker, code):
    """Statement-submission failures raise AirflowException and still spill logs.

    ``code`` is the HTTP status the mocked server returns for the second POST of a
    statement. It is presumably supplied by a pytest parametrization outside this
    view; confirm.

    Judging by the ``submit_statement`` call counts asserted below, the mocked
    POST-statement responses are consumed in order across both ``execute`` calls:

    1. 200 with an ``id``   -> first statement of run 1 is accepted;
    2. ``code`` error body  -> second statement of run 1 fails;
    3. 200 without ``id``   -> first statement of run 2 fails.
    """
    st1 = LivySessionOperator.Statement(kind="spark", code="x = 1;")
    st2 = LivySessionOperator.Statement(kind="pyspark", code="print 'hi';")
    op = LivySessionOperator(
        statements=[st1, st2],
        spill_logs=True,
        task_id="test_run_session_error_when_submitting_statement",
        dag=dag,
    )
    spill_logs_spy = mocker.spy(op, "spill_session_logs")
    submit_statement_spy = mocker.spy(op, "submit_statement")
    mock_livy_session_responses(
        mocker,
        mock_post_statement=[
            MockedResponse(200, json_body={"id": STATEMENT_ID}),
            MockedResponse(code, body=f"Response from server:{code}"),
            MockedResponse(200, json_body={"no id here": "haha"}),
        ],
    )
    with raises(AirflowException) as first_error:
        op.execute({})
    print(
        f"\n\nImitated {code} response from server during second statement submission, "
        f"got the expected exception:\n<{first_error.value}>")
    # spill_logs=True, and Operator had the session_id by the time error occurred.
    spill_logs_spy.assert_called_once()
    assert submit_statement_spy.call_count == 2

    op.spill_logs = False
    # Capture this run's exception separately; reusing ``first_error`` would
    # report the stale exception from the first run.
    with raises(AirflowException) as second_error:
        op.execute({})
    print(
        f"\n\nImitated a response without a statement id from server during first "
        f"statement submission, got the expected exception:\n<{second_error.value}>")
    # spill_logs=False, but error occurred and Operator had the session_id.
    assert spill_logs_spy.call_count == 2
    assert submit_statement_spy.call_count == 3


def test_create_session_params(dag, mocker):
    """Constructor options are kept on the operator and sent in the session request.

    Builds an operator with every session option set. It then patches
    ``HttpHook.run`` to return a canned ``{"id": 1}`` response and calls
    ``create_session()``. Finally it checks two things: the operator-level
    settings stored on the instance, and the JSON body handed to ``HttpHook.run``
    (compared order-insensitively with DeepDiff).
    """
    spark_statement = LivySessionOperator.Statement(kind="spark", code="x=1")
    sparkr_statement = LivySessionOperator.Statement(kind="sparkr", code="print hi")
    # Settings that the operator is expected to store verbatim as attributes.
    operator_settings = {
        "heartbeat_timeout": 9,
        "session_start_timeout_sec": 11,
        "session_start_poll_period_sec": 22,
        "statemt_timeout_minutes": 33,
        "statemt_poll_period_sec": 44,
        "http_conn_id": "foo",
        "spill_logs": True,
    }
    op = LivySessionOperator(
        statements=[spark_statement, sparkr_statement],
        kind="pyspark",
        proxy_user="******",
        jars=["jar1", "jar2"],
        py_files=["py_file1", "py_file2"],
        files=["file1", "file2"],
        driver_memory="driver_memory",
        driver_cores=1,
        executor_memory="executor_memory",
        executor_cores=2,
        num_executors=3,
        archives=["archive1", "archive2"],
        queue="queue",
        name="name",
        conf={
            "key1": "val1",
            "key2": 2
        },
        **operator_settings,
        task_id="test_create_session_params",
        dag=dag,
    )
    canned_response = Response()
    canned_response._content = b'{"id": 1}'
    patched_run = mocker.patch.object(HttpHook, "run", return_value=canned_response)

    op.create_session()

    assert op.statements[0] == spark_statement
    assert op.statements[1] == sparkr_statement
    for attribute, expected_value in operator_settings.items():
        assert getattr(op, attribute) == expected_value

    expected_body = {
        "kind": "pyspark",
        "proxyUser": "******",
        "jars": ["jar1", "jar2"],
        "pyFiles": ["py_file1", "py_file2"],
        "files": ["file1", "file2"],
        "driverMemory": "driver_memory",
        "driverCores": 1,
        "executorMemory": "executor_memory",
        "executorCores": 2,
        "numExecutors": 3,
        "archives": ["archive1", "archive2"],
        "queue": "queue",
        "name": "name",
        "conf": {"key1": "val1", "key2": 2},
        "heartbeatTimeoutInSecond": 9,
    }
    call_args, call_kwargs = patched_run._call_matcher(patched_run.call_args)
    sent_body = find_json_in_args(call_args, call_kwargs)
    if sent_body is None:
        raise AssertionError(
            f"Can not find JSON in HttpHook args.\n"
            f"Args:\n{call_args}\n"
            f"KWArgs (JSON should be under 'data' key):\n{call_kwargs}")

    diff = DeepDiff(sent_body, expected_body, ignore_order=True)
    if diff:
        print(f"\nDifference:\n{json.dumps(diff, indent=2)}")
    assert not diff