Example #1
0
async def test_agent_with_model_server_in_thread(model_server: TestClient,
                                                 moodbot_domain: Domain,
                                                 moodbot_metadata: Any):
    """Load an agent from the model server and verify what it loaded.

    The agent is configured with a 2-second ``wait_time_between_pulls`` and
    the test then sleeps for 5 seconds before asserting (presumably so that
    scheduled pulls get a chance to run -- confirm against
    ``load_from_server``). Afterwards it checks the agent's fingerprint,
    domain hash and policy module paths, and that the server counted exactly
    one model request.

    Args:
        model_server: Test client fixture; must expose ``make_url`` and
            ``app.number_of_model_requests``.
        moodbot_domain: Domain the loaded agent's domain is compared to.
        moodbot_metadata: Mapping whose ``"policy_names"`` entry lists the
            expected policy module paths.
    """
    model_endpoint_config = EndpointConfig.from_dict({
        "url": model_server.make_url("/model"),
        "wait_time_between_pulls": 2,
    })

    # Everything that may start or rely on the scheduler runs inside the
    # try-block so the scheduler is shut down even when an assertion fails;
    # otherwise a failing run would skip the cleanup call entirely.
    try:
        agent = Agent()
        agent = await rasa.core.agent.load_from_server(
            agent, model_server=model_endpoint_config)

        await asyncio.sleep(5)

        assert agent.fingerprint == "somehash"
        assert hash(agent.domain) == hash(moodbot_domain)

        agent_policies = {
            rasa.shared.utils.common.module_path_from_instance(p)
            for p in agent.policy_ensemble.policies
        }
        moodbot_policies = set(moodbot_metadata["policy_names"])
        assert agent_policies == moodbot_policies
        assert model_server.app.number_of_model_requests == 1
    finally:
        jobs.kill_scheduler()
Example #2
0
async def test_agent_with_model_server_in_thread(
        model_server: TestClient, default_domain: Domain,
        unpacked_trained_rasa_model: Text):
    """Load an agent from the model server and verify what it loaded.

    The agent is configured with a 2-second ``wait_time_between_pulls`` and
    the test then sleeps for 5 seconds before asserting (presumably so that
    scheduled pulls get a chance to run -- confirm against
    ``load_from_server``). The expected policy names are read from the
    ``core`` sub-directory of the unpacked model.

    Args:
        model_server: Test client fixture; must expose ``make_url`` and
            ``app.number_of_model_requests``.
        default_domain: Domain whose ``as_dict()`` must match the agent's.
        unpacked_trained_rasa_model: Path of an unpacked model directory
            containing a ``core`` folder with policy metadata.
    """
    model_endpoint_config = EndpointConfig.from_dict({
        "url": model_server.make_url("/model"),
        "wait_time_between_pulls": 2,
    })

    # Run the scheduler-dependent part under try/finally so the scheduler is
    # always shut down, including when one of the assertions fails.
    try:
        agent = Agent()
        agent = await rasa.core.agent.load_from_server(
            agent, model_server=model_endpoint_config)

        await asyncio.sleep(5)

        assert agent.fingerprint == "somehash"
        assert agent.domain.as_dict() == default_domain.as_dict()

        expected_policies = PolicyEnsemble.load_metadata(
            str(Path(unpacked_trained_rasa_model, "core")))["policy_names"]

        agent_policies = {
            rasa.shared.utils.common.module_path_from_instance(p)
            for p in agent.policy_ensemble.policies
        }
        assert agent_policies == set(expected_policies)
        assert model_server.app.number_of_model_requests == 1
    finally:
        jobs.kill_scheduler()
Example #3
0
async def test_agent_with_model_server_in_thread(model_server, tmpdir,
                                                 zipped_moodbot_model,
                                                 moodbot_domain,
                                                 moodbot_metadata):
    """Load an agent from the model server and verify what it loaded.

    The agent is configured with a 2-second ``wait_time_between_pulls`` and
    the test then sleeps for 3 seconds before asserting (presumably so that
    scheduled pulls get a chance to run -- confirm against
    ``load_from_server``). ``tmpdir`` and ``zipped_moodbot_model`` are not
    used in the body; they are requested as fixtures only
    (NOTE(review): likely needed for their setup side effects -- confirm).
    """
    model_endpoint_config = EndpointConfig.from_dict({
        "url": model_server.make_url('/model'),
        "wait_time_between_pulls": 2,
    })

    # Run the scheduler-dependent part under try/finally so the scheduler is
    # always shut down, including when one of the assertions fails.
    try:
        agent = Agent()
        agent = await rasa.core.agent.load_from_server(
            agent, model_server=model_endpoint_config)

        await asyncio.sleep(3)

        assert agent.fingerprint == "somehash"

        assert agent.domain.as_dict() == moodbot_domain.as_dict()

        agent_policies = {
            utils.module_path_from_instance(p)
            for p in agent.policy_ensemble.policies
        }
        moodbot_policies = set(moodbot_metadata["policy_names"])
        assert agent_policies == moodbot_policies
        assert model_server.app.number_of_model_requests == 1
    finally:
        jobs.kill_scheduler()
Example #4
0
async def wait_until_all_jobs_were_executed(
    timeout_after_seconds: Optional[float] = None, ) -> None:
    """Poll the job scheduler until it has no jobs left or the timeout hits.

    The scheduler's job list is checked every 0.1 seconds. The elapsed time
    is tracked by summing the sleep intervals, not by reading a clock, so it
    is an approximation.

    Args:
        timeout_after_seconds: Maximum time to wait. A falsy value (``None``
            or ``0``) means there is no deadline and the call waits until
            the scheduler reports no jobs.

    Raises:
        TimeoutError: If a deadline was given and the accumulated wait time
            reached it. The scheduler is killed before raising.
    """
    # The polling loop already treats a falsy timeout as "no deadline"; the
    # final check has to use the same rule. Without it, the default ``None``
    # made ``total_seconds >= None`` raise TypeError once all jobs were done,
    # and a timeout of 0 always ended in TimeoutError.
    has_deadline = bool(timeout_after_seconds)

    total_seconds = 0.0
    while len((await jobs.scheduler()).get_jobs()) > 0 and (
            not has_deadline or total_seconds < timeout_after_seconds):
        await asyncio.sleep(0.1)
        total_seconds += 0.1

    if has_deadline and total_seconds >= timeout_after_seconds:
        jobs.kill_scheduler()
        raise TimeoutError
Example #5
0
async def test_wait_time_between_pulls_without_interval(model_server, monkeypatch):
    """Loading with ``wait_time_between_pulls=None`` must not schedule pulls.

    ``rasa.core.agent.schedule_model_pulling`` is replaced by a stub that
    fails on any call, so the test errors out if loading from the server
    tries to schedule model pulling.
    """

    def _fail_if_called(*args):
        # Dividing by zero raises, which makes any call to the stub fail.
        return 1 / 0

    monkeypatch.setattr("rasa.core.agent.schedule_model_pulling", _fail_if_called)

    endpoint_settings = {
        "url": model_server.make_url("/model"),
        "wait_time_between_pulls": None,
    }
    server_endpoint = EndpointConfig.from_dict(endpoint_settings)

    # Should not call schedule_model_pulling; if it does, the stub raises.
    fresh_agent = Agent()
    await rasa.core.agent.load_from_server(fresh_agent, model_server=server_endpoint)
    jobs.kill_scheduler()
Example #6
0
async def test_agent_with_model_server_in_thread(
    model_server: TestClient, domain: Domain
):
    """Load an agent from the model server and verify what it loaded.

    The agent is configured with a 2-second ``wait_time_between_pulls`` and
    the test then sleeps for 5 seconds before asserting (presumably so that
    scheduled pulls get a chance to run -- confirm against
    ``load_from_server``). It then checks the agent's fingerprint and domain,
    that its processor has a truthy ``graph_runner``, and that the server
    counted exactly one model request.

    Args:
        model_server: Test client fixture; must expose ``make_url`` and
            ``app.number_of_model_requests``.
        domain: Domain whose ``as_dict()`` must match the agent's.
    """
    model_endpoint_config = EndpointConfig.from_dict(
        {"url": model_server.make_url("/model"), "wait_time_between_pulls": 2}
    )

    # Run the scheduler-dependent part under try/finally so the scheduler is
    # always shut down, including when one of the assertions fails.
    try:
        agent = Agent()
        agent = await rasa.core.agent.load_from_server(
            agent, model_server=model_endpoint_config
        )

        await asyncio.sleep(5)

        assert agent.fingerprint == "somehash"
        assert agent.domain.as_dict() == domain.as_dict()
        assert agent.processor.graph_runner

        assert model_server.app.number_of_model_requests == 1
    finally:
        jobs.kill_scheduler()