Exemple #1
0
async def test_wait_time_between_pulls_without_interval(
    model_server: TestClient, monkeypatch: MonkeyPatch
):
    """Loading with ``wait_time_between_pulls=None`` must not schedule pulling."""

    def _fail_if_called(*args):
        # Dividing by zero raises, so any call to the patched function fails the test.
        return 1 / 0

    monkeypatch.setattr("rasa.core.agent._schedule_model_pulling", _fail_if_called)

    server_settings = {
        "url": model_server.make_url("/model"),
        "wait_time_between_pulls": None,
    }
    model_endpoint_config = EndpointConfig.from_dict(server_settings)

    # If load_from_server tried to schedule pulling, the patched function would raise here.
    await rasa.core.agent.load_from_server(Agent(), model_server=model_endpoint_config)
Exemple #2
0
async def test_pika_connection_error(monkeypatch: MonkeyPatch):
    """An AMQP authentication failure on connect surfaces as ``ConnectionException``."""

    async def _refuse_connection(self) -> None:
        # Stand-in for PikaEventBroker.connect that always fails to authenticate.
        raise aio_pika.exceptions.ProbableAuthenticationError("Oups")

    monkeypatch.setattr(PikaEventBroker, "connect", _refuse_connection)

    broker_settings = dict(
        type="pika",
        url="localhost",
        username="******",
        password="******",
        queues=["queue-1"],
        connection_attempts=1,
        retry_delay_in_seconds=0,
    )
    cfg = EndpointConfig.from_dict(broker_settings)

    with pytest.raises(ConnectionException):
        await EventBroker.create(cfg)
Exemple #3
0
def test_tracker_store_endpoint_config_loading(endpoints_path: Text):
    """The ``tracker_store`` section of the endpoints file parses to the expected config."""
    loaded = read_endpoint_config(endpoints_path, "tracker_store")

    expected_settings = dict(
        type="redis",
        url="localhost",
        port=6379,
        db=0,
        password="******",
        timeout=30000,
        use_ssl=True,
        ssl_keyfile="keyfile.key",
        ssl_certfile="certfile.crt",
        ssl_ca_certs="my-bundle.ca-bundle",
    )
    assert loaded == EndpointConfig.from_dict(expected_settings)
Exemple #4
0
async def test_agent_with_model_server_in_thread(
    model_server: TestClient, domain: Domain
):
    """Load an agent from the model server with periodic pulling enabled.

    Checks the loaded model's fingerprint, domain and graph runner, and that
    the model server received exactly one model request. The scheduler is
    shut down in a ``finally`` block so that a failing assertion cannot leave
    the background pulling job running into later tests.
    """
    model_endpoint_config = EndpointConfig.from_dict(
        {"url": model_server.make_url("/model"), "wait_time_between_pulls": 2}
    )

    agent = Agent()
    agent = await rasa.core.agent.load_from_server(
        agent, model_server=model_endpoint_config
    )

    try:
        # Give the background pulling job (interval 2s) time to run.
        await asyncio.sleep(5)

        assert agent.fingerprint == "somehash"
        assert agent.domain.as_dict() == domain.as_dict()
        assert agent.processor.graph_runner

        assert model_server.app.number_of_model_requests == 1
    finally:
        # Always stop the scheduler, even if an assertion above failed.
        jobs.kill_scheduler()
Exemple #5
0
async def test_pull_model_with_invalid_domain(
    model_server: TestClient, monkeypatch: MonkeyPatch, caplog: LogCaptureFixture
):
    """A pulled model with an invalid domain is logged, not raised."""
    error_message = "domain is invalid"

    # Make every Domain.load() behave as if the domain YAML were invalid.
    domain_load_mock = Mock(side_effect=InvalidDomain(error_message))
    monkeypatch.setattr(Domain, "load", domain_load_mock)

    endpoint = EndpointConfig.from_dict(
        {"url": model_server.make_url("/model"), "wait_time_between_pulls": None}
    )
    await rasa.core.agent.load_from_server(Agent(), model_server=endpoint)

    # The loader attempted to read the domain exactly once ...
    domain_load_mock.assert_called_once()
    # ... and the resulting error ended up in the captured log output.
    assert error_message in caplog.text
Exemple #6
0
    async def load_model(request: Request):
        """Load a model into the running server and swap in the new agent.

        Reads ``model_file``, ``model_server`` and ``remote_storage`` from the
        JSON request body. If ``model_server`` is given, it is converted into an
        ``EndpointConfig``. The agent returned by ``_load_agent`` replaces
        ``app.agent``.

        Returns:
            An empty JSON response with status ``204`` on success.

        Raises:
            ErrorResponse: ``400 BadRequest`` when ``model_server`` cannot be
                converted into an ``EndpointConfig``.
        """
        validate_request_body(request, "No path to model file defined in request_body.")

        model_path = request.json.get("model_file", None)
        model_server = request.json.get("model_server", None)
        remote_storage = request.json.get("remote_storage", None)
        if model_server:
            try:
                model_server = EndpointConfig.from_dict(model_server)
            except TypeError as e:
                logger.debug(traceback.format_exc())
                # Chain the original TypeError so its traceback is kept as the cause.
                raise ErrorResponse(
                    400,
                    "BadRequest",
                    "Supplied 'model_server' is not valid. Error: {}".format(e),
                    {"parameter": "model_server", "in": "body"},
                ) from e
        app.agent = await _load_agent(
            model_path, model_server, remote_storage, endpoints
        )

        # Lazy %-style args: the message is only formatted if DEBUG is enabled.
        logger.debug("Successfully loaded model '%s'.", model_path)
        return response.json(None, status=204)
Exemple #7
0
async def test_agent_with_model_server_in_thread(
    model_server: TestClient, moodbot_domain: Domain, moodbot_metadata: Any
):
    """Load the moodbot agent from the model server with periodic pulling enabled.

    Checks the fingerprint and the domain hash. Checks that the agent's policy
    module paths match ``moodbot_metadata["policy_names"]``, and that exactly
    one model request was served. The scheduler is shut down in a ``finally``
    block so that a failing assertion cannot leak the pulling job into later
    tests.
    """
    model_endpoint_config = EndpointConfig.from_dict(
        {"url": model_server.make_url("/model"), "wait_time_between_pulls": 2}
    )

    agent = Agent()
    agent = await rasa.core.agent.load_from_server(
        agent, model_server=model_endpoint_config
    )

    try:
        # Give the background pulling job (interval 2s) time to run.
        await asyncio.sleep(5)

        assert agent.fingerprint == "somehash"
        assert hash(agent.domain) == hash(moodbot_domain)

        agent_policies = {
            utils.module_path_from_instance(p) for p in agent.policy_ensemble.policies
        }
        moodbot_policies = set(moodbot_metadata["policy_names"])
        assert agent_policies == moodbot_policies
        assert model_server.app.number_of_model_requests == 1
    finally:
        # Always stop the scheduler, even if an assertion above failed.
        jobs.kill_scheduler()
def test_custom_token_name():
    """A ``token_name`` supplied in the dict is exposed on the resulting config."""
    config = EndpointConfig.from_dict(
        dict(url="http://test", token="token", token_name="test_token")
    )

    assert config.token_name == "test_token"
Exemple #9
0
def test_nlg_endpoint_config_loading():
    """The ``nlg`` section of ``DEFAULT_ENDPOINTS_FILE`` points at the local NLG URL."""
    loaded = read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "nlg")
    expected = EndpointConfig.from_dict({"url": "http://localhost:5055/nlg"})
    assert loaded == expected
def test_tracker_store_connection_error(config: Dict, default_domain: Domain):
    """Creating a tracker store from ``config`` raises ``ConnectionException``."""
    endpoint_config = EndpointConfig.from_dict(config)
    with pytest.raises(ConnectionException):
        TrackerStore.create(endpoint_config, default_domain)
Exemple #11
0
 def from_credentials(cls, credentials: Optional[Dict[Text, Any]]) -> InputChannel:
     """Build the channel from its credentials mapping.

     The credentials are converted into an ``EndpointConfig``, which is passed
     to the constructor as its only positional argument.
     """
     endpoint = EndpointConfig.from_dict(credentials)
     return cls(endpoint)
Exemple #12
0
 def from_credentials(cls, credentials):
     """Build the channel from ``credentials`` converted to an ``EndpointConfig``."""
     endpoint = EndpointConfig.from_dict(credentials)
     return cls(endpoint)
Exemple #13
0
def test_nlg_endpoint_config_loading(endpoints_path: Text):
    """The ``nlg`` section of the given endpoints file points at the local NLG URL."""
    loaded = read_endpoint_config(endpoints_path, "nlg")
    expected = EndpointConfig.from_dict({"url": "http://localhost:5055/nlg"})
    assert loaded == expected
Exemple #14
0
import asyncio

import yaml
from rasa.core.agent import Agent
from rasa.shared.constants import DEFAULT_ENDPOINTS_PATH
from rasa.utils.endpoints import EndpointConfig

# A model must already have been trained into the ``models`` directory.
# The endpoints file is read as UTF-8 explicitly rather than relying on the
# platform default encoding. ``safe_load`` is used because PyYAML >= 6 requires
# an explicit Loader for ``yaml.load``, and a plain-data config needs no more
# than the safe loader.
with open(DEFAULT_ENDPOINTS_PATH, encoding="utf-8") as fp:
    endpoint = EndpointConfig.from_dict(yaml.safe_load(fp).get("action_endpoint"))

agent = Agent.load_local_model(
    model_path="models",
    action_endpoint=endpoint
)


async def _chat() -> None:
    """Send each demo message to the agent and print its responses in order."""
    for message in ("rasax", "你是机器人吗?", "成都天气好吗?"):
        print(await agent.handle_text(message))


# Run all messages on a single event loop (asyncio.run should ideally be
# called only once per program).
asyncio.run(_chat())