def test_create_session_get_id(dag, mocker):
    op = LivySessionOperator(
        statements=[],
        task_id="test_create_session_get_id",
        dag=dag,
    )
    http_response = mock_http_calls(201, content=b'{"id": 456}')
    mocker.patch.object(HttpHook, "get_conn", return_value=http_response)
    op.create_session()
    assert op.session_id == 456
def test_create_session_malformed_json(dag, mocker):
    op = LivySessionOperator(statements=[],
                             task_id="test_create_session_malformed_json",
                             dag=dag)
    http_response = mock_http_calls(201, content=b'{"id":{}')
    mocker.patch.object(HttpHook, "get_conn", return_value=http_response)
    with raises(AirflowBadRequest) as bre:
        op.create_session()
    print(f"\n\nImitated malformed JSON response when creating a session, "
          f"got the expected exception:\n<{bre.value}>")
def test_create_session_string_id(dag, mocker):
    op = LivySessionOperator(statements=[],
                             task_id="test_create_session_string_id",
                             dag=dag)
    http_response = mock_http_calls(
        201, content=b'{"id":"unexpectedly, a string!"}')
    mocker.patch.object(HttpHook, "get_conn", return_value=http_response)
    with raises(AirflowException) as ae:
        op.create_session()
    print(f"\n\nImitated server returning a string for a session ID, "
          f"got the expected exception:\n<{ae.value}>")
def test_create_session_bad_response_codes(dag, mocker, code):
    op = LivySessionOperator(statements=[],
                             task_id="test_create_session_bad_response_codes",
                             dag=dag)
    http_response = mock_http_calls(code,
                                    content=b"Error content",
                                    reason="Good reason")
    mocker.patch.object(HttpHook, "get_conn", return_value=http_response)
    with raises(AirflowException) as ae:
        op.create_session()
    print(f"\n\nImitated the {code} error response when creating a session, "
          f"got the expected exception:\n<{ae.value}>")
def test_create_session_params(dag, mocker):
    heartbeat_timeout = 9
    session_start_timeout_sec = 11
    session_start_poll_period_sec = 22
    statemt_timeout_minutes = 33
    statemt_poll_period_sec = 44
    http_conn_id = "foo"
    spill_logs = True
    st1 = LivySessionOperator.Statement(kind="spark", code="x=1")
    st2 = LivySessionOperator.Statement(kind="sparkr", code="print hi")
    op = LivySessionOperator(
        statements=[st1, st2],
        kind="pyspark",
        proxy_user="******",
        jars=["jar1", "jar2"],
        py_files=["py_file1", "py_file2"],
        files=["file1", "file2"],
        driver_memory="driver_memory",
        driver_cores=1,
        executor_memory="executor_memory",
        executor_cores=2,
        num_executors=3,
        archives=["archive1", "archive2"],
        queue="queue",
        name="name",
        conf={
            "key1": "val1",
            "key2": 2
        },
        heartbeat_timeout=heartbeat_timeout,
        session_start_timeout_sec=session_start_timeout_sec,
        session_start_poll_period_sec=session_start_poll_period_sec,
        statemt_timeout_minutes=statemt_timeout_minutes,
        statemt_poll_period_sec=statemt_poll_period_sec,
        http_conn_id=http_conn_id,
        spill_logs=spill_logs,
        task_id="test_create_session_params",
        dag=dag,
    )
    mock_response = Response()
    mock_response._content = b'{"id": 1}'
    patched_hook = mocker.patch.object(HttpHook,
                                       "run",
                                       return_value=mock_response)

    op.create_session()

    assert op.statements[0] == st1
    assert op.statements[1] == st2
    assert op.heartbeat_timeout == heartbeat_timeout
    assert op.session_start_timeout_sec == session_start_timeout_sec
    assert op.session_start_poll_period_sec == session_start_poll_period_sec
    assert op.statemt_timeout_minutes == statemt_timeout_minutes
    assert op.statemt_poll_period_sec == statemt_poll_period_sec
    assert op.http_conn_id == http_conn_id
    assert op.spill_logs == spill_logs
    expected_json = json.loads("""{
          "kind": "pyspark",
          "proxyUser": "******",
          "jars": [
            "jar1",
            "jar2"
          ],
          "pyFiles": [
            "py_file1",
            "py_file2"
          ],
          "files": [
            "file1",
            "file2"
          ],
          "driverMemory": "driver_memory",
          "driverCores": 1,
          "executorMemory": "executor_memory",
          "executorCores": 2,
          "numExecutors": 3,
          "archives": [
            "archive1",
            "archive2"
          ],
          "queue": "queue",
          "name": "name",
          "conf": {
            "key1": "val1",
            "key2": 2
          },
          "heartbeatTimeoutInSecond": 9
        }""")
    actual_args, actual_kwargs = patched_hook._call_matcher(
        patched_hook.call_args)
    actual_json = find_json_in_args(actual_args, actual_kwargs)
    if actual_json is None:
        raise AssertionError(
            f"Can not find JSON in HttpHook args.\n"
            f"Args:\n{actual_args}\n"
            f"KWArgs (JSON should be under 'data' key):\n{actual_kwargs}")
    else:
        diff = DeepDiff(actual_json, expected_json, ignore_order=True)
        if diff:
            print(f"\nDifference:\n{json.dumps(diff, indent=2)}")
        assert not diff