Пример #1
0
def test_agent_log_level_responds_to_config(cloud_api):
    """Check that an Agent built under a temporary config reflects its values.

    The config sets the agent log level to ``"DEBUG"`` and an explicit agent
    address; the test asserts the logger level and ``agent_address`` match.
    """
    overrides = {
        "cloud.agent.auth_token": "TEST_TOKEN",
        "cloud.agent.level": "DEBUG",
        "cloud.agent.agent_address": "http://localhost:8000",
    }
    with set_temporary_config(overrides):
        configured = Agent()
        # 10 is the numeric value of the standard-library DEBUG level.
        assert configured.logger.level == 10
        assert configured.agent_address == "http://localhost:8000"
Пример #2
0
def test_agent_start_max_polls_count(monkeypatch, runner_token, cloud_api):
    """Check mock call counts after ``Agent(max_polls=2).start()`` returns.

    ``on_shutdown``, ``agent_process``, ``agent_connect`` and ``heartbeat``
    are replaced on the Agent class with mocks so only the number of calls
    made by ``start()`` is observed.
    """
    shutdown_mock = MagicMock()
    process_mock = MagicMock()
    connect_mock = MagicMock(return_value="id")
    heartbeat_mock = MagicMock()

    replacements = (
        ("prefect.agent.agent.Agent.on_shutdown", shutdown_mock),
        ("prefect.agent.agent.Agent.agent_process", process_mock),
        ("prefect.agent.agent.Agent.agent_connect", connect_mock),
        ("prefect.agent.agent.Agent.heartbeat", heartbeat_mock),
    )
    for target, replacement in replacements:
        monkeypatch.setattr(target, replacement)

    Agent(max_polls=2).start()

    # One shutdown, and one process/heartbeat call per poll.
    assert shutdown_mock.call_count == 1
    assert process_mock.call_count == 2
    assert heartbeat_mock.call_count == 2
Пример #3
0
def test_agent_process(monkeypatch, runner_token, cloud_api):
    """Check that ``agent_process`` submits work to the executor it is given.

    The GraphQL client is mocked to return one queued flow run with one task
    run; the test asserts a truthy result, that ``executor.submit`` was called
    and that a done-callback was attached to the returned future.
    """
    queued_flow_run = GraphQLResult(
        {
            "id": "id",
            "serialized_state": Scheduled().serialize(),
            "version": 1,
            "task_runs": [
                GraphQLResult(
                    {
                        "id": "id",
                        "version": 1,
                        "serialized_state": Scheduled().serialize(),
                    }
                )
            ],
        }
    )
    payload = MagicMock(
        set_flow_run_state=None,
        set_task_run_state=None,
        get_runs_in_queue=MagicMock(flow_run_ids=["id"]),
        flow_run=[queued_flow_run],
    )
    graphql_mock = MagicMock(return_value=MagicMock(data=payload))

    client_cls = MagicMock()
    client_cls.return_value.graphql = graphql_mock
    monkeypatch.setattr("prefect.agent.agent.Client", client_cls)

    pending_future = MagicMock()
    pool = MagicMock()
    pool.submit = MagicMock(return_value=pending_future)

    worker = Agent()
    assert worker.agent_process(pool)
    assert pool.submit.called
    assert pending_future.add_done_callback.called
Пример #4
0
def test_agent_api_health_check(cloud_api):
    """Check that the agent API server answers ``/api/health`` with HTTP 200.

    The test is skipped when ``requests`` is not installed. The server is
    always stopped, even when the health check or the status assertion fails,
    so a failing run does not leave the server running behind it.
    """
    requests = pytest.importorskip("requests")

    # Bind to port 0 to obtain a port number, then release the socket so the
    # agent API server can use that port.
    with socket.socket() as sock:
        sock.bind(("", 0))
        port = sock.getsockname()[1]

    agent = Agent(agent_address=f"http://127.0.0.1:{port}", max_polls=1)

    agent._start_agent_api_server()
    try:
        # May take a sec for the api server to startup
        for _ in range(5):
            try:
                resp = requests.get(f"http://127.0.0.1:{port}/api/health")
                break
            except Exception:
                # Deliberately broad: any failure here just means "retry".
                time.sleep(0.1)
        else:
            # pytest.fail is used rather than `assert False`, which is
            # stripped when Python runs with -O.
            pytest.fail("Failed to connect to health check")

        assert resp.status_code == 200
    finally:
        agent._stop_agent_api_server()

    assert not agent._api_server_thread.is_alive()
Пример #5
0
def test_deploy_flow_run_sleeps_until_start_time(monkeypatch, cloud_api):
    """Check that ``_deploy_flow_run`` sleeps before deploying a future run.

    ``time.sleep`` is mocked; the flow run's state is scheduled 10 seconds
    after "now", and the test asserts the requested sleep is in ``(9, 10]``
    seconds and that ``deploy_flow`` was called exactly once.
    """
    log_response = MagicMock(
        data=MagicMock(write_run_logs=MagicMock(success=True))
    )
    client_instance = MagicMock()
    client_instance.return_value.write_run_logs = MagicMock(
        return_value=log_response
    )
    monkeypatch.setattr(
        "prefect.agent.agent.Client", MagicMock(return_value=client_instance)
    )

    fake_sleep = MagicMock()
    monkeypatch.setattr("time.sleep", fake_sleep)

    now = pendulum.now()
    worker = Agent()
    worker.deploy_flow = MagicMock()

    task_run = GraphQLResult(
        {
            "id": "id",
            "version": 1,
            "serialized_state": Scheduled(
                start_time=now.add(seconds=10)
            ).serialize(),
        }
    )
    future_run = GraphQLResult(
        {
            "id": "id",
            "serialized_state": Scheduled(
                start_time=now.add(seconds=10)
            ).serialize(),
            "scheduled_start_time": str(now),
            "version": 1,
            "task_runs": [task_run],
        }
    )
    worker._deploy_flow_run(flow_run=future_run)

    # First positional argument of the most recent sleep call.
    requested = fake_sleep.call_args[0][0]
    assert 9 < requested <= 10
    worker.deploy_flow.assert_called_once()
Пример #6
0
def test_get_ready_flow_runs_ignores_currently_submitting_runs(
        monkeypatch, cloud_api):
    """Check that ids already being submitted are left out of the second query.

    The queue reports ``id1`` and ``id2``; ``id2`` is marked as submitting, so
    the second GraphQL call is expected to filter on ``id1`` only.
    """
    queued = MagicMock(flow_run_ids=["id1", "id2"])
    runs = [
        GraphQLResult({
            "id": "id",
            "scheduled_start_time": str(pendulum.now())
        })
    ]
    graphql_mock = MagicMock(
        return_value=MagicMock(
            data=MagicMock(get_runs_in_queue=queued, flow_run=runs)
        )
    )
    client_cls = MagicMock()
    client_cls.return_value.graphql = graphql_mock
    monkeypatch.setattr("prefect.agent.agent.Client", client_cls)

    worker = Agent()
    worker.submitting_flow_runs.add("id2")
    worker._get_ready_flow_runs()

    recorded_calls = graphql_mock.call_args_list
    assert len(recorded_calls) == 2

    # Positional args of the second call; the first one carries the query.
    second_positional = recorded_calls[1][0]
    first_query_key = list(second_positional[0]["query"].keys())[0]
    assert 'id: { _in: ["id1"] }' in first_query_key
Пример #7
0
def test_update_states_passes_no_task_runs(monkeypatch, runner_token):
    """Check ``update_state`` returns a falsy value for a run with no task runs."""
    payload = MagicMock(set_flow_run_state=None, set_task_run_state=None)
    graphql_mock = MagicMock(return_value=MagicMock(data=payload))

    client_cls = MagicMock()
    client_cls.return_value.graphql = graphql_mock
    monkeypatch.setattr("prefect.agent.agent.Client", client_cls)

    run_without_tasks = GraphQLResult(
        {
            "id": "id",
            "serialized_state": Scheduled().serialize(),
            "version": 1,
            "task_runs": [],
        }
    )

    worker = Agent()
    outcome = worker.update_state(
        flow_run=run_without_tasks, deployment_info="test"
    )
    assert not outcome
Пример #8
0
def test_query_flow_runs_ignores_currently_submitting_runs(monkeypatch, runner_token):
    """Check that ``query_flow_runs`` filters out ids already being submitted.

    The queue reports ``id1`` and ``id2``; ``id2`` is marked as submitting, so
    the second GraphQL call is expected to filter on ``id1`` only.
    """
    payload = MagicMock(
        get_runs_in_queue=MagicMock(flow_run_ids=["id1", "id2"]),
        flow_run=[{"id1": "id1"}],
    )
    graphql_mock = MagicMock(return_value=MagicMock(data=payload))

    client_cls = MagicMock()
    client_cls.return_value.graphql = graphql_mock
    monkeypatch.setattr("prefect.agent.agent.Client", client_cls)

    worker = Agent()
    worker.submitting_flow_runs.add("id2")
    worker.query_flow_runs()

    recorded_calls = graphql_mock.call_args_list
    assert len(recorded_calls) == 2

    # Positional args of the second call; the first one carries the query.
    second_positional = recorded_calls[1][0]
    first_query_key = list(second_positional[0]["query"].keys())[0]
    assert 'id: { _in: ["id1"] }' in first_query_key
Пример #9
0
def test_mark_failed(monkeypatch, runner_token, cloud_api):
    """Check ``mark_failed`` returns a falsy value for a run with no task runs."""
    payload = MagicMock(set_flow_run_state=None, set_task_run_state=None)
    graphql_mock = MagicMock(return_value=MagicMock(data=payload))

    client_cls = MagicMock()
    client_cls.return_value.graphql = graphql_mock
    monkeypatch.setattr("prefect.agent.agent.Client", client_cls)

    failing_run = GraphQLResult(
        {
            "id": "id",
            "serialized_state": Scheduled().serialize(),
            "version": 1,
            "task_runs": [],
        }
    )

    worker = Agent()
    outcome = worker.mark_failed(flow_run=failing_run, exc=Exception())
    assert not outcome
Пример #10
0
def test_query_flow_runs_does_not_use_submitting_flow_runs_directly(
        monkeypatch, runner_token, caplog, cloud_api):
    """Check that ``query_flow_runs`` works on a copy of ``submitting_flow_runs``.

    ``submitting_flow_runs`` is replaced by a mock whose ``copy`` returns a set
    holding the only queued id. The test asserts nothing is returned, the
    debug log mentions the already-submitting id, and ``copy`` was called once
    with no arguments.
    """
    payload = MagicMock(
        get_runs_in_queue=MagicMock(flow_run_ids=["already-submitted-id"]),
        flow_run=[{"id": "id"}],
    )
    graphql_mock = MagicMock(return_value=MagicMock(data=payload))

    client_cls = MagicMock()
    client_cls.return_value.graphql = graphql_mock
    monkeypatch.setattr("prefect.agent.agent.Client", client_cls)

    worker = Agent()
    worker.logger.setLevel(logging.DEBUG)

    snapshot = MagicMock(return_value={"already-submitted-id"})
    worker.submitting_flow_runs = MagicMock(copy=snapshot)

    returned_runs = worker.query_flow_runs()

    assert returned_runs == []
    assert "1 already submitting: ['already-submitted-id']" in caplog.text
    snapshot.assert_called_once_with()
Пример #11
0
def test_mark_flow_as_submitted(monkeypatch, cloud_api, with_task_runs):
    """Check the state updates sent by ``_mark_flow_as_submitted``.

    The flow run is always expected to be set to ``Submitted``. A task run
    state update is expected only when ``with_task_runs`` is truthy.
    """
    if with_task_runs:
        task_runs = [
            GraphQLResult(
                {
                    "id": "task-id",
                    "version": 1,
                    "serialized_state": Scheduled().serialize(),
                }
            )
        ]
    else:
        task_runs = []

    run = GraphQLResult(
        {
            "id": "id",
            "serialized_state": Scheduled().serialize(),
            "version": 1,
            "task_runs": task_runs,
        }
    )

    worker = Agent()
    worker.client = MagicMock()
    worker._mark_flow_as_submitted(flow_run=run)

    worker.client.set_flow_run_state.assert_called_once_with(
        flow_run_id="id", version=1, state=Submitted(message="Submitted for execution")
    )

    if not with_task_runs:
        worker.client.set_task_run_state.assert_not_called()
        return

    worker.client.set_task_run_state.assert_called_once_with(
        task_run_id="task-id",
        version=1,
        state=Submitted(message="Submitted for execution"),
    )
Пример #12
0
def test_get_ready_flow_runs(monkeypatch, cloud_api):
    """Check that ``_get_ready_flow_runs`` returns the flow run from the mocked API."""
    stamp = str(pendulum.now())

    def build_run():
        # A fresh object per call, so the expectation is compared by equality.
        return GraphQLResult({
            "id": "id",
            "scheduled_start_time": stamp
        })

    payload = MagicMock(
        get_runs_in_queue=MagicMock(flow_run_ids=["id"]),
        flow_run=[build_run()],
    )
    graphql_mock = MagicMock(return_value=MagicMock(data=payload))

    client_cls = MagicMock()
    client_cls.return_value.graphql = graphql_mock
    monkeypatch.setattr("prefect.agent.agent.Client", client_cls)

    ready = Agent()._get_ready_flow_runs()
    assert ready == [build_run()]
Пример #13
0
def test_agent_poke_api(monkeypatch, runner_token, cloud_api):
    """Check that a request to ``/api/poke`` ends the agent's wait early.

    The agent is started with a 5 second loop interval and ``max_polls=1``
    while a background thread waits for ``/api/health`` to respond and then
    requests ``/api/poke``. The test asserts that ``start()`` returned in under
    5 seconds, that the API server thread is no longer alive after
    ``cleanup()``, and that ``heartbeat``, ``agent_process`` and
    ``agent_connect`` were each called exactly once.

    The test is skipped when ``requests`` is not installed.
    """
    import threading

    requests = pytest.importorskip("requests")

    def _poke_agent(agent_address):
        """Wait for the health endpoint at ``agent_address``, then poke the agent."""
        # NOTE(review): this runs in a worker thread, so a failing assert here
        # raises in that thread only and does not by itself fail the test.
        # May take a sec for the api server to startup
        for attempt in range(5):
            try:
                resp = requests.get(f"{agent_address}/api/health")
                break
            except Exception:
                # Any failure is treated as "not up yet"; retry after a pause.
                time.sleep(0.1)
        else:
            assert False, "Failed to connect to health check"

        assert resp.status_code == 200
        # Agent API is now available. Poke agent to start processing.
        requests.get(f"{agent_address}/api/poke")

    # Replace the agent's work methods with mocks so only call counts matter.
    agent_process = MagicMock()
    monkeypatch.setattr("prefect.agent.agent.Agent.agent_process",
                        agent_process)

    agent_connect = MagicMock(return_value="id")
    monkeypatch.setattr("prefect.agent.agent.Agent.agent_connect",
                        agent_connect)

    heartbeat = MagicMock()
    monkeypatch.setattr("prefect.agent.agent.Agent.heartbeat", heartbeat)

    # Bind to port 0 to obtain a port number, then release the socket so the
    # agent API server can use that port.
    with socket.socket() as sock:
        sock.bind(("", 0))
        port = sock.getsockname()[1]

    agent_address = f"http://127.0.0.1:{port}"

    # Poke agent in separate thread as main thread is blocked by main agent
    # process waiting for loop interval to complete.
    poke_agent_thread = threading.Thread(target=_poke_agent,
                                         args=(agent_address, ))
    poke_agent_thread.start()

    # Wall-clock timing brackets agent construction and the whole start() call.
    agent_start_time = time.time()
    agent = Agent(agent_address=agent_address, max_polls=1)
    # Override loop interval to 5 seconds.
    agent.start(_loop_intervals={0: 5.0})
    agent_stop_time = time.time()

    agent.cleanup()

    # Returning in under the 5 second interval is taken as proof of the poke.
    assert agent_stop_time - agent_start_time < 5.0

    assert not agent._api_server_thread.is_alive()
    assert heartbeat.call_count == 1
    assert agent_process.call_count == 1
    assert agent_connect.call_count == 1
Пример #14
0
def test_deploy_flow_run_logs_flow_run_exceptions(monkeypatch, caplog,
                                                  cloud_api):
    """Check that a ``deploy_flow`` failure is written to the run logs.

    ``deploy_flow`` is mocked to raise ``Exception("Error Here")``. The test
    asserts ``write_run_logs`` was called with an ERROR entry carrying that
    message, and that the agent's own log mentions the flow run id.
    """
    log_response = MagicMock(
        data=MagicMock(write_run_logs=MagicMock(success=True))
    )
    client_instance = MagicMock()
    client_instance.return_value.write_run_logs = MagicMock(
        return_value=log_response
    )
    monkeypatch.setattr("prefect.agent.agent.Client",
                        MagicMock(return_value=client_instance))

    worker = Agent()
    worker.deploy_flow = MagicMock(side_effect=Exception("Error Here"))

    run = GraphQLResult({
        "id": "id",
        "serialized_state": Scheduled().serialize(),
        "scheduled_start_time": str(pendulum.now()),
        "version": 1,
        "task_runs": [
            GraphQLResult({
                "id": "id",
                "version": 1,
                "serialized_state": Scheduled().serialize(),
            })
        ],
    })
    worker._deploy_flow_run(flow_run=run)

    expected_entry = {
        "flow_run_id": "id",
        "level": "ERROR",
        "message": "Error Here",
        "name": "agent",
    }
    assert client_instance.write_run_logs.called
    client_instance.write_run_logs.assert_called_with([expected_entry])
    assert "Encountered exception while deploying flow run id" in caplog.text
Пример #15
0
def test_query_flow_runs_ordered_by_start_time(monkeypatch, cloud_api):
    """Check that ``query_flow_runs`` returns runs earliest-scheduled first.

    The mocked API lists the later run before the earlier one; the result is
    expected in the opposite order.
    """
    earlier = pendulum.now()
    later = pendulum.now().add(hours=1)

    def build_run(run_id, when):
        # A fresh object per call, so the expectation is compared by equality.
        return GraphQLResult({"id": run_id, "scheduled_start_time": str(when)})

    payload = MagicMock(
        get_runs_in_queue=MagicMock(flow_run_ids=["id"]),
        flow_run=[build_run("id2", later), build_run("id", earlier)],
    )
    graphql_mock = MagicMock(return_value=MagicMock(data=payload))

    client_cls = MagicMock()
    client_cls.return_value.graphql = graphql_mock
    monkeypatch.setattr("prefect.agent.agent.Client", client_cls)

    ordered = Agent().query_flow_runs()
    assert ordered == [build_run("id", earlier), build_run("id2", later)]
Пример #16
0
def test_agent_start_max_polls(cloud_api, max_polls):
    """Check that ``start()`` submits deploy jobs exactly ``max_polls`` times.

    Args:
        cloud_api: fixture supplied by the test suite.
        max_polls: number of polls the agent is told to run before stopping.
    """
    agent = Agent(max_polls=max_polls)
    # Mock the backend API to avoid immediate failure
    agent._setup_api_connection = MagicMock(return_value="id")
    # Mock the deployment func to count calls
    agent._submit_deploy_flow_run_jobs = MagicMock()

    agent.start()

    # The comparison must be asserted; as a bare expression it was discarded
    # and the test could never fail.
    assert agent._submit_deploy_flow_run_jobs.call_count == max_polls
Пример #17
0
def test_agent_fails_no_runner_token(monkeypatch, cloud_api):
    """Check that starting an agent raises ``AuthorizationError`` for a USER-scoped token.

    ``requests.Session`` is mocked so every POST returns a JSON body whose
    ``auth_info.api_token_scope`` is ``"USER"``.
    """
    post = MagicMock(
        return_value=MagicMock(
            json=MagicMock(
                return_value=dict(
                    data=dict(auth_info=MagicMock(api_token_scope="USER"))
                )
            )
        )
    )
    session = MagicMock()
    session.return_value.post = post
    monkeypatch.setattr("requests.Session", session)

    # start() is expected to raise, so its result is not bound to a name.
    with pytest.raises(AuthorizationError):
        Agent().start()
Пример #18
0
def test_agent_fails_no_runner_token(monkeypatch, cloud_api):
    """Check that a USER-scoped token makes ``start()`` raise ``RuntimeError``.

    ``requests.Session`` is mocked so every POST returns a JSON body whose
    ``auth_info.api_token_scope`` is ``"USER"``. The raised error must be
    chained from an ``AuthorizationError``.
    """
    body = {"data": {"auth_info": MagicMock(api_token_scope="USER")}}
    fake_response = MagicMock(json=MagicMock(return_value=body))

    session_cls = MagicMock()
    session_cls.return_value.post = MagicMock(return_value=fake_response)
    monkeypatch.setattr("requests.Session", session_cls)

    with pytest.raises(RuntimeError, match="Error while contacting API") as excinfo:
        Agent().start()

    root_cause = excinfo.value.__cause__
    assert isinstance(root_cause, AuthorizationError)
Пример #19
0
def test_setup_api_connection_runs_test_query(test_query_succeeds, cloud_api):
    """Check ``_setup_api_connection`` against a working or unmocked test query.

    When ``test_query_succeeds`` is truthy, ``client.graphql`` is mocked and
    the call must not raise; otherwise the call is expected to raise.
    """
    worker = Agent()

    # Ignore the token check and registration
    worker._verify_token = MagicMock()
    worker._register_agent = MagicMock()

    if test_query_succeeds:
        # Create a successful test query
        worker.client.graphql = MagicMock(return_value="Hello")
        expectation = nullcontext()
    else:
        expectation = pytest.raises(Exception)

    with expectation:
        worker._setup_api_connection()
Пример #20
0
def test_catch_errors_in_heartbeat_thread(monkeypatch, cloud_api, caplog):
    """Errors raised by the heartbeat are caught and logged, and heartbeats continue.

    ``heartbeat`` is mocked to raise ``ValueError`` on every call. The test
    asserts it was still called more than once and that an "Error in agent
    heartbeat" message was logged.
    """
    failing_heartbeat = MagicMock(side_effect=ValueError)
    replacements = (
        ("prefect.agent.agent.Agent._submit_deploy_flow_run_jobs", MagicMock()),
        ("prefect.agent.agent.Agent._setup_api_connection", MagicMock(return_value="id")),
        ("prefect.agent.agent.Agent.heartbeat", failing_heartbeat),
    )
    for target, replacement in replacements:
        monkeypatch.setattr(target, replacement)

    worker = Agent(max_polls=2)

    # Ignore registration
    worker._register_agent = MagicMock()

    worker.heartbeat_period = 0.1
    worker.start()

    assert failing_heartbeat.call_count > 1
    heartbeat_errors = [
        message for message in caplog.messages
        if "Error in agent heartbeat" in message
    ]
    assert heartbeat_errors
Пример #21
0
def test_agent_labels(runner_token, cloud_api):
    """Check that labels passed to the constructor are exposed as ``agent.labels``."""
    token_config = {"cloud.agent.auth_token": "TEST_TOKEN"}
    with set_temporary_config(token_config):
        labelled = Agent(labels=["test", "2"])
        assert labelled.labels == ["test", "2"]
Пример #22
0
def test_agent_max_polls(runner_token, cloud_api):
    """Check that ``max_polls`` passed to the constructor is stored on the agent."""
    token_config = {"cloud.agent.auth_token": "TEST_TOKEN"}
    with set_temporary_config(token_config):
        limited = Agent(max_polls=10)
        assert limited.max_polls == 10
Пример #23
0
def test_agent_env_vars(runner_token, cloud_api):
    """Check that ``env_vars`` passed to the constructor are exposed unchanged."""
    token_config = {"cloud.agent.auth_token": "TEST_TOKEN"}
    with set_temporary_config(token_config):
        with_env = Agent(env_vars={"AUTH_THING": "foo"})
        assert with_env.env_vars == {"AUTH_THING": "foo"}
Пример #24
0
def test_agent_log_level(runner_token, cloud_api):
    """Check the agent logger's level when no level is configured."""
    token_config = {"cloud.agent.auth_token": "TEST_TOKEN"}
    with set_temporary_config(token_config):
        default_agent = Agent()
        # 20 is the numeric value of the standard-library INFO level.
        assert default_agent.logger.level == 20
Пример #25
0
def test_on_flow_run_deploy_attempt_removes_id(monkeypatch, runner_token, cloud_api):
    """Check that ``on_flow_run_deploy_attempt`` drops the id from ``submitting_flow_runs``."""
    worker = Agent()
    worker.submitting_flow_runs.add("id")

    # The first argument is passed as None, exactly as the test always did.
    worker.on_flow_run_deploy_attempt(None, "id")

    remaining = len(worker.submitting_flow_runs)
    assert remaining == 0
Пример #26
0
def test_heartbeat_passes_base_agent(runner_token, cloud_api):
    """Check that the base agent's ``heartbeat`` returns a falsy value."""
    beat_result = Agent().heartbeat()
    assert not beat_result
Пример #27
0
def test_deploy_flows_passes_base_agent(runner_token, cloud_api):
    """Check that the base agent's ``deploy_flow`` raises ``NotImplementedError``."""
    base_agent = Agent()
    expectation = pytest.raises(NotImplementedError)
    with expectation:
        base_agent.deploy_flow(None)
Пример #28
0
def test_multiple_agent_init_doesnt_duplicate_logs(runner_token, cloud_api):
    """Check that the third agent created still has exactly one log handler."""
    # All three instances stay referenced until the assertion has run.
    agents = [Agent() for _ in range(3)]
    last_created = agents[-1]
    assert len(last_created.logger.handlers) == 1
Пример #29
0
def test_agent_init(runner_token, cloud_api):
    """Check that an Agent can be constructed with no arguments and is truthy."""
    instance = Agent()
    assert instance
Пример #30
0
def test_agent_fails_no_auth_token(cloud_api):
    """Check that starting a default agent here raises ``AuthorizationError``."""
    # start() is expected to raise, so its result is not bound to a name.
    with pytest.raises(AuthorizationError):
        Agent().start()