Ejemplo n.º 1
0
def test_agent_with_model_server(tmpdir, zipped_moodbot_model,
                                 moodbot_domain, moodbot_metadata):
    """Load an agent from a mocked model server and compare it to moodbot.

    A GET on the model URL is registered with ``responses`` so that it
    returns the zipped moodbot model together with an ``ETag`` header. The
    agent loaded through ``rasa_core.agent.load_from_server`` is then checked
    for its fingerprint, its domain and the module paths of its policies.
    """
    expected_fingerprint = 'somehash'
    server_config = EndpointConfig.from_dict(
        {"url": 'http://server.com/model/default_core@latest'})

    # Read the zipped model once, then register it as the mocked response.
    with io.open(zipped_moodbot_model, 'rb') as model_file:
        model_bytes = model_file.read()
    responses.add(responses.GET,
                  server_config.url,
                  headers={"ETag": expected_fingerprint},
                  body=model_bytes,
                  content_type='application/zip',
                  stream=True)

    agent = rasa_core.agent.load_from_server(model_server=server_config)

    # The fingerprint must equal the ETag value sent by the mocked server.
    assert agent.fingerprint == expected_fingerprint
    assert agent.domain.as_dict() == moodbot_domain.as_dict()

    loaded_policy_paths = set()
    for policy in agent.policy_ensemble.policies:
        loaded_policy_paths.add(utils.module_path_from_instance(policy))
    expected_policy_paths = set(moodbot_metadata["policy_names"])
    assert loaded_policy_paths == expected_policy_paths
Ejemplo n.º 2
0
async def test_agent_with_model_server_in_thread(model_server, tmpdir,
                                                 zipped_moodbot_model,
                                                 moodbot_domain,
                                                 moodbot_metadata):
    """Load an agent from the ``model_server`` fixture with periodic pulls.

    The endpoint is configured with ``wait_time_between_pulls`` of 2; after
    loading, the test sleeps for 3 seconds and then checks the agent's
    fingerprint, domain and policies, and that the server app counted
    exactly one model request.
    """
    model_endpoint_config = EndpointConfig.from_dict({
        "url": model_server.make_url('/model'),
        "wait_time_between_pulls": 2
    })

    try:
        agent = Agent()
        agent = await rasa_core.agent.load_from_server(
            agent, model_server=model_endpoint_config)

        # Sleep longer than the configured wait time between pulls (2).
        await asyncio.sleep(3)

        assert agent.fingerprint == "somehash"

        assert agent.domain.as_dict() == moodbot_domain.as_dict()

        agent_policies = {
            utils.module_path_from_instance(p)
            for p in agent.policy_ensemble.policies
        }
        moodbot_policies = set(moodbot_metadata["policy_names"])
        assert agent_policies == moodbot_policies
        assert model_server.app.number_of_model_requests == 1
    finally:
        # Run the scheduler teardown even when an assertion above fails, so
        # a failing run does not skip it.
        # NOTE(review): presumably this stops the background pulling job --
        # confirm against rasa_core.jobs.
        jobs.kill_scheduler()
Ejemplo n.º 3
0
def test_nlg(http_nlg, default_agent_path):
    """Check that an agent using an NLG endpoint answers ``/greet``.

    The agent is loaded from ``default_agent_path`` with an endpoint config
    pointing at ``http_nlg`` as its generator. Handling ``/greet`` must give
    exactly one response, addressed to the generated sender id.
    """
    sender = str(uuid.uuid1())
    generator_config = EndpointConfig.from_dict({"url": http_nlg})
    agent = Agent.load(
        default_agent_path,
        None,
        generator=generator_config,
    )

    replies = agent.handle_message("/greet", sender_id=sender)

    expected_reply = {"text": "Hey there!", "recipient_id": sender}
    assert len(replies) == 1
    assert replies[0] == expected_reply
Ejemplo n.º 4
0
def test_tracker_store_endpoint_config_loading():
    """Check the ``tracker_store`` section read from the endpoints file.

    The config read from ``DEFAULT_ENDPOINTS_FILE`` must equal an
    ``EndpointConfig`` built from the expected redis settings below.
    """
    loaded = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE,
                                        "tracker_store")

    expected_settings = {
        "type": "redis",
        "url": "localhost",
        "port": 6379,
        "db": 0,
        "password": "******",
        "timeout": 30000
    }
    expected = EndpointConfig.from_dict(expected_settings)
    assert loaded == expected
Ejemplo n.º 5
0
def test_nlg(http_nlg, default_agent_path):
    """Send ``/greet`` to an agent whose generator is an NLG endpoint.

    Expects a single response containing the text "Hey there!" and the
    sender id as recipient.
    """
    sender = str(uuid.uuid1())

    endpoint_settings = {"url": http_nlg}
    nlg_config = EndpointConfig.from_dict(endpoint_settings)
    agent = Agent.load(default_agent_path, None, generator=nlg_config)

    answers = agent.handle_message("/greet", sender_id=sender)

    assert len(answers) == 1
    assert answers[0] == {"text": "Hey there!", "recipient_id": sender}
Ejemplo n.º 6
0
def test_wait_time_between_pulls_with_not_number(monkeypatch):
    """Call ``load_from_server`` with a non-numeric wait time.

    ``start_model_pulling_in_worker`` is replaced by a stub that raises
    ``ZeroDivisionError`` when called, and ``_update_model_from_server`` by a
    stub returning ``True``. The test errors if the stub's exception
    propagates out of ``load_from_server``.
    """
    def fail_when_called(*args):
        # Division by zero raises as soon as this stub is invoked.
        return 1 / 0

    def pretend_updated(*args):
        return True

    monkeypatch.setattr("rasa_core.agent.start_model_pulling_in_worker",
                        fail_when_called)
    monkeypatch.setattr("rasa_core.agent._update_model_from_server",
                        pretend_updated)

    endpoint_settings = {
        "url": 'http://server.com/model/default_core@latest',
        # The string "None", not the None object.
        "wait_time_between_pulls": "None"
    }
    server_config = EndpointConfig.from_dict(endpoint_settings)

    rasa_core.agent.load_from_server(model_server=server_config)
Ejemplo n.º 7
0
def test_agent_with_model_server(tmpdir, zipped_moodbot_model):
    """Load an agent from a mocked model server and check its fingerprint.

    The zipped moodbot model is served through ``responses`` with an
    ``ETag`` header; the loaded agent's fingerprint must equal that value.
    """
    expected_fingerprint = 'somehash'
    endpoint_settings = {"url": 'http://server.com/model/default_core@latest'}
    server_config = EndpointConfig.from_dict(endpoint_settings)

    # Read the zipped model once, then register it as the mocked response.
    with io.open(zipped_moodbot_model, 'rb') as model_file:
        model_bytes = model_file.read()
    responses.add(
        responses.GET,
        server_config.url,
        headers={"ETag": expected_fingerprint},
        body=model_bytes,
        content_type='application/zip',
        stream=True,
    )

    agent = rasa_core.agent.load_from_server(model_server=server_config)
    assert agent.fingerprint == expected_fingerprint
Ejemplo n.º 8
0
def test_wait_time_between_pulls_str(monkeypatch):
    """Call ``load_from_server`` with the wait time given as the string "10".

    ``start_model_pulling_in_worker`` is replaced by a stub returning
    ``True`` and ``_update_model_from_server`` by a stub that raises a plain
    ``Exception`` when called. The test errors if that exception propagates
    out of ``load_from_server``.
    """
    from future.utils import raise_

    def pretend_started(*args):
        return True

    def fail_when_called(*args):
        return raise_(Exception())

    monkeypatch.setattr("rasa_core.agent.start_model_pulling_in_worker",
                        pretend_started)
    monkeypatch.setattr("rasa_core.agent._update_model_from_server",
                        fail_when_called)

    endpoint_settings = {
        "url": 'http://server.com/model/default_core@latest',
        # A numeric value passed as a string.
        "wait_time_between_pulls": "10"
    }
    server_config = EndpointConfig.from_dict(endpoint_settings)

    rasa_core.agent.load_from_server(model_server=server_config)
Ejemplo n.º 9
0
async def test_wait_time_between_pulls_without_interval(
        model_server, monkeypatch):
    """Load from a server with ``wait_time_between_pulls`` set to ``None``.

    ``schedule_model_pulling`` is replaced by a stub that raises
    ``ZeroDivisionError`` when called, so the test errors if
    ``load_from_server`` calls it and lets the exception propagate.
    """
    monkeypatch.setattr("rasa_core.agent.schedule_model_pulling",
                        lambda *args: 1 / 0)  # will raise an exception

    model_endpoint_config = EndpointConfig.from_dict({
        "url": model_server.make_url('/model'),
        "wait_time_between_pulls": None
    })

    agent = Agent()
    try:
        # should not call schedule_model_pulling; if it does, this will raise
        await rasa_core.agent.load_from_server(
            agent, model_server=model_endpoint_config)
    finally:
        # Run the scheduler teardown even when loading raises, so a failing
        # run does not skip it.
        # NOTE(review): presumably this stops any scheduler that was
        # started -- confirm against rasa_core.jobs.
        jobs.kill_scheduler()
Ejemplo n.º 10
0
def test_nlg_endpoint_config_loading():
    """Check the ``nlg`` section read from the endpoints file.

    The config read from ``DEFAULT_ENDPOINTS_FILE`` must equal an
    ``EndpointConfig`` built from the expected NLG URL.
    """
    loaded = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "nlg")

    expected = EndpointConfig.from_dict({"url": "http://localhost:5055/nlg"})
    assert loaded == expected
Ejemplo n.º 11
0
 def from_credentials(cls, credentials):
     """Create an instance of ``cls`` from ``credentials``.

     ``credentials`` is handed to ``EndpointConfig.from_dict`` and the
     resulting config is passed as the only argument to ``cls``.
     """
     endpoint = EndpointConfig.from_dict(credentials)
     return cls(endpoint)
Ejemplo n.º 12
0
def test_nlg_endpoint_config_loading():
    """Compare the ``nlg`` endpoint config from file with the expected one."""
    loaded = utils.read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "nlg")

    expected_settings = {"url": "http://localhost:5055/nlg"}
    assert loaded == EndpointConfig.from_dict(expected_settings)