Beispiel #1
0
def test_tracker_store_endpoint_config_loading(endpoints_path: Text):
    """The ``tracker_store`` section loads as the expected Redis config with SSL settings."""
    loaded = read_endpoint_config(endpoints_path, "tracker_store")

    expected_settings = {
        "type": "redis",
        "url": "localhost",
        "port": 6379,
        "db": 0,
        "password": "******",
        "timeout": 30000,
        "use_ssl": True,
        "ssl_keyfile": "keyfile.key",
        "ssl_certfile": "certfile.crt",
        "ssl_ca_certs": "my-bundle.ca-bundle",
    }
    assert loaded == EndpointConfig.from_dict(expected_settings)
Beispiel #2
0
def test_db_url_with_query_from_endpoint_config():
    """A SQL tracker store config with a ``query`` mapping yields a DB URL carrying those params."""
    import itertools
    import os

    endpoint_config = """
    tracker_store:
      dialect: postgresql
      url: localhost
      port: 5123
      username: user
      password: pw
      login_db: login-db
      query:
        driver: my-driver
        another: query
    """

    # Write the config into a temporary directory and close the file before
    # reading it back. Reopening a still-open NamedTemporaryFile by name is
    # not supported on Windows, so the previous approach failed there.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = os.path.join(tmp_dir, "tmp_config_file.yml")
        with open(config_path, "w", encoding="utf-8") as config_file:
            config_file.write(endpoint_config)
        store_config = read_endpoint_config(config_path, "tracker_store")

    url = SQLTrackerStore.get_db_url(**store_config.kwargs)

    # order of query dictionary in yaml is random, test against both permutations
    connection_url = "postgresql://*****:*****@:5123/login-db?"
    assert any(
        str(url) == connection_url + "&".join(permutation)
        for permutation in itertools.permutations(("another=query", "driver=my-driver"))
    )
Beispiel #3
0
async def test_sql_broker_from_config():
    """An SQL ``event_broker`` section builds an SQLEventBroker on a sqlite engine."""
    endpoint_file = "data/test_endpoints/event_brokers/sql_endpoint.yml"
    broker_config = read_endpoint_config(endpoint_file, "event_broker")

    sql_event_broker = await EventBroker.create(broker_config)

    assert isinstance(sql_event_broker, SQLEventBroker)
    assert sql_event_broker.engine.name == "sqlite"
def test_db_url_with_query_from_endpoint_config(tmp_path: Path):
    """A SQL tracker store config with a ``query`` mapping yields a DB URL carrying those params."""
    import itertools

    endpoint_config = """
    tracker_store:
      dialect: postgresql
      url: localhost
      port: 5123
      username: user
      password: pw
      login_db: login-db
      query:
        driver: my-driver
        another: query
    """
    config_file = tmp_path / "tmp_config_file.yml"
    config_file.write_text(endpoint_config)
    store_kwargs = read_endpoint_config(str(config_file), "tracker_store").kwargs

    url = SQLTrackerStore.get_db_url(**store_kwargs)

    # The query parameters may be serialized in either order, so accept both.
    base_url = "postgresql://*****:*****@:5123/login-db?"
    query_parts = ("another=query", "driver=my-driver")
    acceptable_urls = [
        base_url + "&".join(ordering)
        for ordering in itertools.permutations(query_parts)
    ]
    assert str(url) in acceptable_urls
Beispiel #5
0
def test_tracker_store_from_string(default_domain: Domain):
    """The custom tracker store named in the endpoints file resolves to ExampleTrackerStore."""
    config = read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store"
    )

    resolved = TrackerStore.find_tracker_store(default_domain, config)
    assert isinstance(resolved, ExampleTrackerStore)
Beispiel #6
0
def test_tracker_store_deprecated_url_argument_from_string(domain: Domain):
    """Creating URLExampleTrackerStore through TrackerStore.create raises."""
    config = read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store"
    )
    # Redirect the config to the test-module store class via its dotted path.
    config.type = "tests.core.test_tracker_stores.URLExampleTrackerStore"

    with pytest.raises(Exception):
        TrackerStore.create(config, domain)
Beispiel #7
0
def test_file_broker_from_config():
    """A file ``event_broker`` section builds a FileEventBroker with the configured path."""
    config_path = "data/test_endpoints/event_brokers/file_endpoint.yml"
    file_broker = EventBroker.create(read_endpoint_config(config_path, "event_broker"))

    assert isinstance(file_broker, FileEventBroker)
    assert file_broker.path == "rasa_event.log"
Beispiel #8
0
def test_file_broker_from_config():
    """A file ``event_broker`` section builds a FileProducer with the configured path."""
    config_path = "data/test_endpoints/event_brokers/file_endpoint.yml"
    producer = broker_utils.from_endpoint_config(
        read_endpoint_config(config_path, "event_broker")
    )

    assert isinstance(producer, FileProducer)
    assert producer.path == "rasa_event.log"
Beispiel #9
0
def test_sql_broker_from_config():
    """An SQL ``event_broker`` section builds an SQLProducer on a sqlite engine."""
    config_path = "data/test_endpoints/event_brokers/sql_endpoint.yml"
    producer = broker_utils.from_endpoint_config(
        read_endpoint_config(config_path, "event_broker")
    )

    assert isinstance(producer, SQLProducer)
    assert producer.engine.name == "sqlite"
Beispiel #10
0
async def test_kafka_broker_from_config():
    """A SASL_PLAINTEXT Kafka config matches an equivalently constructed KafkaEventBroker."""
    cfg = read_endpoint_config(
        "data/test_endpoints/event_brokers/kafka_sasl_plaintext_endpoint.yml",
        "event_broker",
    )
    actual = await KafkaEventBroker.from_endpoint_config(cfg)

    expected = KafkaEventBroker(
        "localhost",
        sasl_username="******",
        sasl_password="******",
        sasl_mechanism="PLAIN",
        topic="topic",
        partition_by_sender=True,
        security_protocol="SASL_PLAINTEXT",
        convert_intent_id_to_string=True,
    )

    compared_attributes = (
        "url",
        "sasl_username",
        "sasl_password",
        "sasl_mechanism",
        "topic",
        "partition_by_sender",
        "convert_intent_id_to_string",
    )
    for attribute in compared_attributes:
        assert getattr(actual, attribute) == getattr(expected, attribute), attribute
Beispiel #11
0
def test_tracker_store_from_invalid_module(default_domain: Domain):
    """An unimportable tracker store type makes find_tracker_store return InMemoryTrackerStore."""
    config = read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store"
    )
    config.type = "a.module.which.cannot.be.found"

    fallback = TrackerStore.find_tracker_store(default_domain, config)

    assert isinstance(fallback, InMemoryTrackerStore)
Beispiel #12
0
def test_kafka_broker_security_protocols(file: Text, exception: Exception):
    """Building the Kafka producer for the given endpoint file raises ``exception``."""
    broker_config = read_endpoint_config(
        f"data/test_endpoints/event_brokers/{file}", "event_broker"
    )
    kafka_broker = KafkaEventBroker.from_endpoint_config(broker_config)

    # The error is expected when the producer is built, not when the broker is.
    with pytest.raises(exception):
        # noinspection PyProtectedMember
        kafka_broker._create_producer()
Beispiel #13
0
def test_tracker_store_from_invalid_string(default_domain: Domain):
    """A non-module type string makes find_tracker_store return InMemoryTrackerStore."""
    config = read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store"
    )
    config.type = "any string"

    fallback = TrackerStore.find_tracker_store(default_domain, config)

    assert isinstance(fallback, InMemoryTrackerStore)
Beispiel #14
0
def test_find_tracker_store(default_domain: Domain, monkeypatch: MonkeyPatch):
    """When building RedisTrackerStore fails, find_tracker_store yields an in-memory store.

    ``RedisTrackerStore`` is replaced by a mock that raises on construction,
    so the store resolved from the default endpoints file must be compatible
    with ``InMemoryTrackerStore``.
    """
    store = read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "tracker_store")
    mock = Mock(side_effect=Exception("ignore this"))
    monkeypatch.setattr(rasa.core.tracker_store, "RedisTrackerStore", mock)

    # Use the `default_domain` fixture; the previous code referenced an
    # undefined name `domain` here.
    assert isinstance(
        InMemoryTrackerStore(default_domain),
        type(TrackerStore.find_tracker_store(default_domain, store)),
    )
Beispiel #15
0
def test_pika_broker_from_config():
    """A Pika ``event_broker`` section builds a PikaProducer with host, user and queue."""
    endpoint_file = "data/test_endpoints/event_brokers/pika_endpoint.yml"
    producer = broker.from_endpoint_config(
        read_endpoint_config(endpoint_file, "event_broker")
    )

    assert isinstance(producer, PikaProducer)
    assert producer.host == "localhost"
    assert producer.credentials.username == "username"
    assert producer.queue == "queue"
Beispiel #16
0
def test_tracker_store_from_invalid_module(domain: Domain):
    """An unimportable tracker store type warns and falls back to InMemoryTrackerStore."""
    config = read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store"
    )
    config.type = "a.module.which.cannot.be.found"

    with pytest.warns(UserWarning):
        created = TrackerStore.create(config, domain)

    assert isinstance(created, InMemoryTrackerStore)
Beispiel #17
0
def test_pika_broker_from_config():
    """A Pika ``event_broker`` section builds a PikaEventBroker with host, user and queues."""
    endpoint_file = "data/test_endpoints/event_brokers/pika_endpoint.yml"
    pika_broker = EventBroker.create(read_endpoint_config(endpoint_file, "event_broker"))

    assert isinstance(pika_broker, PikaEventBroker)
    assert pika_broker.host == "localhost"
    assert pika_broker.username == "username"
    assert pika_broker.queues == ["queue-1"]
def test_tracker_store_deprecated_url_argument_from_string(default_domain: Domain):
    """Creating URLExampleTrackerStore emits a DeprecationWarning but still succeeds."""
    config = read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store"
    )
    config.type = "tests.core.test_tracker_stores.URLExampleTrackerStore"

    with pytest.warns(DeprecationWarning):
        created = TrackerStore.create(config, default_domain)

    assert isinstance(created, URLExampleTrackerStore)
def test_tracker_store_from_invalid_string(default_domain: Domain):
    """A non-module type string warns and falls back to InMemoryTrackerStore."""
    config = read_endpoint_config(
        "data/test_endpoints/custom_tracker_endpoints.yml", "tracker_store"
    )
    config.type = "any string"

    with pytest.warns(UserWarning):
        created = TrackerStore.create(config, default_domain)

    assert isinstance(created, InMemoryTrackerStore)
Beispiel #20
0
def test_dashbot_config():
    """Build an event broker from ``data/rasa_endpoints.yml`` and check its proxies and API key.

    The endpoints file is resolved relative to the directory of this test module.
    """
    cfg = read_endpoint_config(
        os.path.join(os.path.dirname(__file__), "data/rasa_endpoints.yml"),
        "event_broker")
    actual = EventBroker.create(cfg)

    # NOTE(review): `rasa` is used elsewhere in this file as a package
    # (e.g. `rasa.core.tracker_store`), not a class. `isinstance` with a
    # module as its second argument raises TypeError, so this test cannot
    # pass as written. TODO: replace `rasa` with the expected broker class.
    assert isinstance(actual, rasa)
    assert actual.proxies['http'] == 'http://10.10.1.10:3128'
    assert actual.proxies['https'] == 'http://10.10.1.10:1080'
    assert actual.apiKey == 'here'
def test_tracker_store_with_host_argument_from_string(default_domain: Domain):
    """Creating HostExampleTrackerStore from its dotted path emits no warnings at all."""
    import warnings

    endpoints_path = "data/test_endpoints/custom_tracker_endpoints.yml"
    store_config = read_endpoint_config(endpoints_path, "tracker_store")
    store_config.type = "tests.core.test_tracker_stores.HostExampleTrackerStore"

    # `pytest.warns(None)` is deprecated (pytest 7) and rejected by newer
    # pytest releases. Record every warning with the stdlib instead. The
    # "always" filter keeps already-seen warnings from being deduplicated
    # away.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        tracker_store = TrackerStore.create(store_config, default_domain)

    assert len(record) == 0

    assert isinstance(tracker_store, HostExampleTrackerStore)
Beispiel #22
0
def test_raise_connection_exception_redis_tracker_store_creation(
    domain: Domain, monkeypatch: MonkeyPatch, endpoints_path: Text
):
    """A ConnectionError while building the Redis store surfaces as ConnectionException."""
    store = read_endpoint_config(endpoints_path, "tracker_store")
    failing_redis_store = Mock(side_effect=ConnectionError())
    monkeypatch.setattr(
        rasa.core.tracker_store, "RedisTrackerStore", failing_redis_store
    )

    with pytest.raises(ConnectionException):
        TrackerStore.create(store, domain)
Beispiel #23
0
def test_tracker_store_endpoint_config_loading(endpoints_path: Text):
    """The ``tracker_store`` section of the given endpoints file loads as the Redis config."""
    loaded = read_endpoint_config(endpoints_path, "tracker_store")

    expected = EndpointConfig.from_dict(
        dict(
            type="redis",
            url="localhost",
            port=6379,
            db=0,
            password="******",
            timeout=30000,
        )
    )
    assert loaded == expected
Beispiel #24
0
def test_tracker_store_endpoint_config_loading():
    """The ``tracker_store`` section of DEFAULT_ENDPOINTS_FILE loads as the Redis config."""
    loaded = read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "tracker_store")

    expected = EndpointConfig.from_dict(
        dict(
            type="redis",
            url="localhost",
            port=6379,
            db=0,
            password="******",
            timeout=30000,
        )
    )
    assert loaded == expected
def test_create_tracker_store_from_endpoint_config(default_domain: Domain):
    """TrackerStore.create on the default endpoints yields a type compatible with RedisTrackerStore."""
    endpoint = read_endpoint_config(DEFAULT_ENDPOINTS_FILE, "tracker_store")
    redis_store = RedisTrackerStore(
        domain=default_domain,
        host="localhost",
        port=6379,
        db=0,
        password="******",
        record_exp=3000,
    )

    created_type = type(TrackerStore.create(endpoint, default_domain))
    assert isinstance(redis_store, created_type)
Beispiel #26
0
def test_pika_broker_from_config(monkeypatch: MonkeyPatch):
    """A Pika ``event_broker`` section builds a PikaEventBroker without a live RabbitMQ."""
    # Stub out `_connect` so constructing the broker never opens a connection.
    monkeypatch.setattr(PikaEventBroker, "_connect", lambda _: None)

    pika_config = read_endpoint_config(
        "data/test_endpoints/event_brokers/pika_endpoint.yml", "event_broker"
    )
    pika_broker = EventBroker.create(pika_config)

    assert isinstance(pika_broker, PikaEventBroker)
    assert pika_broker.host == "localhost"
    assert pika_broker.username == "username"
    assert pika_broker.queues == ["queue-1"]
Beispiel #27
0
def test_create_tracker_store_from_endpoint_config(
    domain: Domain, endpoints_path: Text
):
    """TrackerStore.create on the given endpoints yields a type compatible with RedisTrackerStore."""
    endpoint = read_endpoint_config(endpoints_path, "tracker_store")
    redis_store = RedisTrackerStore(
        domain=domain,
        host="localhost",
        port=6379,
        db=0,
        password="******",
        record_exp=3000,
    )

    created_type = type(TrackerStore.create(endpoint, domain))
    assert isinstance(redis_store, created_type)
Beispiel #28
0
    def read_endpoints(cls, endpoint_file: Text) -> "AvailableEndpoints":
        """Load every supported endpoint section from ``endpoint_file``.

        Each section is read with ``read_endpoint_config`` and the results are
        passed positionally to ``cls`` in this order: nlg, nlu, action
        endpoint, models, tracker store, lock store, event broker.
        """
        section_names = (
            "nlg",
            "nlu",
            "action_endpoint",
            "models",
            "tracker_store",
            "lock_store",
            "event_broker",
        )
        sections = [
            read_endpoint_config(endpoint_file, endpoint_type=section)
            for section in section_names
        ]
        return cls(*sections)
Beispiel #29
0
def test_sql_broker_logs_to_sql_db():
    """Events published to the SQL producer are persisted in publish order."""
    sql_producer = broker_utils.from_endpoint_config(
        read_endpoint_config(
            "data/test_endpoints/event_brokers/sql_endpoint.yml", "event_broker"
        )
    )

    for test_event in TEST_EVENTS:
        sql_producer.publish(test_event.as_dict())

    stored_rows = sql_producer.session.query(sql_producer.SQLBrokerEvent).all()
    events_types = [json.loads(row.data)["event"] for row in stored_rows]

    assert events_types == ["user", "slot", "restart"]
Beispiel #30
0
def test_file_broker_from_config(tmp_path: Path):
    """A file ``event_broker`` written to a temp endpoints file builds a FileEventBroker."""
    # backslashes need to be encoded (windows...) otherwise we run into unicode issues
    path = str(tmp_path / "rasa_test_event.log").replace("\\", "\\\\")
    endpoint_config = textwrap.dedent(f"""
        event_broker:
          path: "{path}"
          type: "file"
    """)
    endpoint_file = tmp_path / "endpoint.yml"
    rasa.utils.io.write_text_file(endpoint_config, endpoint_file)

    file_broker = EventBroker.create(
        read_endpoint_config(str(endpoint_file), "event_broker")
    )

    assert isinstance(file_broker, FileEventBroker)
    assert file_broker.path.endswith("rasa_test_event.log")