Beispiel #1
0
def stores_to_be_tested():
    """Build one instance of each tracker store implementation under test.

    All stores share the module-level ``domain`` (defined outside this view).
    The SQL store keeps its SQLite file ``rasa.db`` inside a freshly created
    temporary directory, which is not removed by this function.

    Returns:
        A list of the mock Redis, in-memory and SQL tracker stores, in that
        order.
    """
    sqlite_path = os.path.join(tempfile.mkdtemp(), "rasa.db")
    redis_store = MockRedisTrackerStore(domain)
    in_memory_store = InMemoryTrackerStore(domain)
    sql_store = SQLTrackerStore(domain, db=sqlite_path)
    return [redis_store, in_memory_store, sql_store]
Beispiel #2
0
def test_session_scope_error(monkeypatch: MonkeyPatch, capsys: CaptureFixture,
                             default_domain: Domain):
    """A failing schema check inside ``session_scope()`` exits with a message.

    ``ensure_schema_exists`` is replaced by a mock that raises ``ValueError``.
    Entering the session scope must then raise ``SystemExit``, and a message
    naming the missing PostgreSQL schema must appear on captured stdout.
    """
    store = SQLTrackerStore(default_domain)
    store.sessionmaker = Mock()

    schema_name = uuid.uuid4().hex
    failing_schema_check = Mock(side_effect=ValueError(schema_name))
    monkeypatch.setattr(
        rasa.core.tracker_store,
        "ensure_schema_exists",
        failing_schema_check,
    )

    # the failing schema check is turned into `SystemExit`
    with pytest.raises(SystemExit):
        with store.session_scope():
            pass

    captured_stdout = capsys.readouterr().out
    expected_message = (
        f"Requested PostgreSQL schema '{schema_name}' was not found "
        f"in the database."
    )
    assert expected_message in captured_stdout
Beispiel #3
0
async def test_fetch_events_within_time_range_with_session_events(
        tmp_path: Path):
    """The exporter fetches every saved event, session-start events included."""
    sender_id = "test_fetch_events_within_time_range_with_sessions"
    db_file = tmp_path / f"{uuid.uuid4().hex}.db"

    store = SQLTrackerStore(
        dialect="sqlite",
        db=str(db_file),
        domain=Domain.empty(),
    )

    saved_events = [
        random_user_uttered_event(1),
        SessionStarted(2),
        ActionExecuted(timestamp=3, action_name=ACTION_SESSION_START_NAME),
        random_user_uttered_event(4),
    ]
    await store.save(
        DialogueStateTracker.from_events(sender_id, evts=saved_events))

    exporter = MockExporter(tracker_store=store)
    # noinspection PyProtectedMember
    fetched = await exporter._fetch_events_within_time_range()

    assert len(fetched) == len(saved_events)
Beispiel #4
0
def test_sql_tracker_store_with_login_db_db_already_exists(
    postgres_login_db_connection: sa.engine.Connection,
):
    """Constructing a store whose tracker database already exists succeeds.

    The tracker database is created up front through the login-database
    connection. Building a ``SQLTrackerStore`` with ``login_db`` afterwards
    must not fail, and exactly one database with that name must exist in
    ``pg_catalog.pg_database`` at the end.
    """
    # NOTE: PostgreSQL refuses `CREATE DATABASE` inside a transaction block,
    # which is presumably why AUTOCOMMIT is requested here.
    postgres_login_db_connection.execution_options(
        isolation_level="AUTOCOMMIT"
    ).execute(f"CREATE DATABASE {POSTGRES_TRACKER_STORE_DB}")

    tracker_store = SQLTrackerStore(
        dialect="postgresql",
        host=POSTGRES_HOST,
        port=POSTGRES_PORT,
        username=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        db=POSTGRES_TRACKER_STORE_DB,
        login_db=POSTGRES_LOGIN_DB,
    )

    try:
        matching_rows = (
            postgres_login_db_connection.execution_options(isolation_level="AUTOCOMMIT")
            .execute(
                sa.text(
                    "SELECT 1 FROM pg_catalog.pg_database WHERE datname = :database_name"
                ),
                database_name=POSTGRES_TRACKER_STORE_DB,
            )
            .rowcount
        )
        assert matching_rows == 1
    finally:
        # Dispose of the store's engine even when the query or the assertion
        # fails; previously a failing run skipped this and leaked its pool.
        tracker_store.engine.dispose()
Beispiel #5
0
def test_sql_additional_events(default_domain: Domain):
    """``_additional_events`` yields only the events that are new to the store."""
    store = SQLTrackerStore(default_domain)
    expected_new_events, tracker = create_tracker_with_partially_saved_events(
        store)

    with store.session_scope() as session:
        # noinspection PyProtectedMember
        new_events = list(store._additional_events(session, tracker))
        assert new_events == expected_new_events
def test_sql_additional_events_with_session_start(default_domain: Domain):
    """After saving several sessions, one new utterance is the only new event."""
    conversation_id = "test_sql_additional_events_with_session_start"
    store = SQLTrackerStore(default_domain)
    tracker = _saved_tracker_with_multiple_session_starts(store, conversation_id)

    tracker.update(UserUttered("hi2"), default_domain)

    with store.session_scope() as session:
        # noinspection PyProtectedMember
        new_events = list(store._additional_events(session, tracker))
        assert len(new_events) == 1
        (only_event,) = new_events
        assert isinstance(only_event, UserUttered)
Beispiel #7
0
def test_sql_tracker_store_with_login_db_race_condition(
    postgres_login_db_connection: sa.engine.Connection,
    caplog: LogCaptureFixture,
    monkeypatch: MonkeyPatch,
):
    """A tracker database created concurrently by another client is tolerated.

    ``Connection.execute`` is patched. Any call whose keyword arguments are
    exactly ``database_name=POSTGRES_TRACKER_STORE_DB`` first creates the
    tracker database and then reports ``rowcount == 0``. Presumably the
    store's existence check therefore sees "missing" and it attempts its own
    creation, which collides.

    The test asserts two things:
    - store construction did not raise, and a "Could not create database"
      error was logged;
    - exactly one database with that name exists afterwards.
    """
    original_execute = sa.engine.Connection.execute

    def mock_execute(self, *args, **kwargs):
        # this simulates a race condition: the database appears between the
        # store's check and its own attempt to create it
        if kwargs == {"database_name": POSTGRES_TRACKER_STORE_DB}:
            original_execute(
                self.execution_options(isolation_level="AUTOCOMMIT"),
                f"CREATE DATABASE {POSTGRES_TRACKER_STORE_DB}",
            )
            # report "no matching rows" although the database now exists
            return Mock(rowcount=0)
        else:
            return original_execute(self, *args, **kwargs)

    # The patch is scoped to this context, so the verification query below
    # runs against the real `Connection.execute`.
    with monkeypatch.context() as mp:
        mp.setattr(sa.engine.Connection, "execute", mock_execute)
        with caplog.at_level(logging.ERROR):
            tracker_store = SQLTrackerStore(
                dialect="postgresql",
                host=POSTGRES_HOST,
                port=POSTGRES_PORT,
                username=POSTGRES_USER,
                password=POSTGRES_PASSWORD,
                db=POSTGRES_TRACKER_STORE_DB,
                login_db=POSTGRES_LOGIN_DB,
            )

    try:
        # IntegrityError has been caught and we log the error
        assert any(
            f"Could not create database '{POSTGRES_TRACKER_STORE_DB}'" in record.message
            for record in caplog.records
        )
        matching_rows = (
            postgres_login_db_connection.execution_options(isolation_level="AUTOCOMMIT")
            .execute(
                sa.text(
                    "SELECT 1 FROM pg_catalog.pg_database WHERE datname = :database_name"
                ),
                database_name=POSTGRES_TRACKER_STORE_DB,
            )
            .rowcount
        )
        assert matching_rows == 1
    finally:
        # Dispose of the store's engine even when an assertion fails;
        # previously a failing run skipped this and leaked its pool.
        tracker_store.engine.dispose()
Beispiel #8
0
def marker_sqlite_tracker(tmp_path: Path) -> Tuple[SQLTrackerStore, Text]:
    """Fill a SQLite tracker store with five small conversations.

    Each conversation (sender ids ``"0"`` to ``"4"``) is built as follows:
    - slots ``"0"``-``"4"`` are set;
    - a session-start action is added;
    - a ``"hello"`` user message is added;
    - slots ``"5"``-``"9"`` are set;
    - the tracker is saved to the store.

    Returns:
        The populated store and the path of its SQLite database file.
    """
    empty_domain = Domain.empty()
    db_path = str(tmp_path / "rasa.db")
    store = SQLTrackerStore(dialect="sqlite", db=db_path)

    for sender_index in range(5):
        tracker = DialogueStateTracker(str(sender_index), None)
        tracker.update_with_events(
            [SlotSet(str(slot_index), "slot") for slot_index in range(5)],
            empty_domain,
        )
        tracker.update(ActionExecuted(ACTION_SESSION_START_NAME))
        tracker.update(UserUttered("hello"))
        tracker.update_with_events(
            [SlotSet(str(slot_index), "slot") for slot_index in range(5, 10)],
            empty_domain,
        )
        store.save(tracker)

    return store, db_path
Beispiel #9
0
def test_sql_additional_events(
        domain: Domain, retrieve_events_from_previous_conversation_sessions):
    """Only not-yet-saved events are reported as additional events.

    Runs for whatever value of
    ``retrieve_events_from_previous_conversation_sessions`` the caller (e.g. a
    fixture or parametrisation, defined elsewhere) supplies.
    """
    store = SQLTrackerStore(
        domain,
        retrieve_events_from_previous_conversation_sessions=(
            retrieve_events_from_previous_conversation_sessions),
    )
    expected_new_events, tracker = create_tracker_with_partially_saved_events(
        store)

    with store.session_scope() as session:
        # noinspection PyProtectedMember
        new_events = list(store._additional_events(session, tracker))
        assert new_events == expected_new_events
def test_sql_tracker_store_logs_do_not_show_password(caplog: LogCaptureFixture):
    """The database URL logged by ``SQLTrackerStore`` masks the password.

    The credentials must differ from each other and must not contain the
    ``***`` mask. Otherwise the password also occurs in the logged URL
    (e.g. as the username), and the ``not in`` check fails even when masking
    works correctly.
    """
    dialect = "postgresql"
    host = "localhost"
    port = 9901
    db = "some-database"
    # previously both were "******": the username is logged verbatim, so the
    # password check below could never pass
    username = "db-user"
    password = "some-password"

    with caplog.at_level(logging.DEBUG):
        _ = SQLTrackerStore(None, dialect, host, port, db, username, password)

    # the URL in the logs does not contain the password
    assert password not in caplog.text

    # instead the password is displayed as '***'
    assert f"postgresql://{username}:***@{host}:{port}/{db}" in caplog.text
def test_load_sessions(tmp_path):
    """A tracker spanning several sessions loads back as one complete tracker."""
    empty_domain = Domain.empty()
    store = SQLTrackerStore(empty_domain, db=os.path.join(tmp_path, "temp.db"))

    tracker = DialogueStateTracker("test123", None)
    events_to_add = [
        UserUttered("0"),
        UserUttered("1"),
        SessionStarted(),
        UserUttered("2"),
        UserUttered("3"),
    ]
    tracker.update_with_events(events_to_add, empty_domain)
    store.save(tracker)

    loaded_trackers = list(MarkerTrackerLoader(store, STRATEGY_ALL).load())

    # one conversation -> one tracker, carrying the events of both sessions
    assert len(loaded_trackers) == 1
    assert len(loaded_trackers[0].events) == len(tracker.events)
def test_tracker_store_deprecated_session_retrieval_kwarg():
    """With the deprecated kwarg enabled, ``retrieve`` calls ``retrieve_full_tracker``."""
    store = SQLTrackerStore(
        Domain.empty(),
        retrieve_events_from_previous_conversation_sessions=True,
    )

    conversation_id = uuid.uuid4().hex
    initial_events = [
        ActionExecuted(ACTION_SESSION_START_NAME),
        SessionStarted(),
        UserUttered("hi"),
    ]
    tracker = DialogueStateTracker.from_events(conversation_id, initial_events)

    # replace the full-tracker retrieval with a spy so the call can be checked
    full_tracker_spy = Mock()
    store.retrieve_full_tracker = full_tracker_spy

    store.save(tracker)
    store.retrieve(conversation_id)

    full_tracker_spy.assert_called_once()
def test_login_db_with_no_postgresql(tmp_path: Path):
    """Passing ``login_db`` without a PostgreSQL dialect emits a ``UserWarning``."""
    sqlite_db = str(tmp_path / "rasa.db")
    with pytest.warns(UserWarning):
        SQLTrackerStore(db=sqlite_db, login_db="other")
Beispiel #14
0
def _get_rasa_x_tracker_store() -> SQLTrackerStore:
    """Return a SQL tracker store on ``tracker.db`` that uses an empty domain."""
    empty_domain = Domain.empty()
    return SQLTrackerStore(empty_domain, db="tracker.db")
Beispiel #15
0
def test_sql_tracker_store_with_token_serialisation(
    domain: Domain, response_selector_agent: Agent
):
    """Run the shared token-serialisation check against a SQL tracker store.

    ``host`` is passed as an ordinary keyword argument. Unpacking a literal
    dict (``**{"host": ...}``) was equivalent but needlessly indirect.
    """
    tracker_store = SQLTrackerStore(domain, host="sqlite:///")
    prepare_token_serialisation(tracker_store, response_selector_agent, "sql")
Beispiel #16
0
def _get_rasa_x_tracker_store(endpoints_file: Optional[Text]) -> TrackerStore:
    """Return the tracker store to use for Rasa X.

    If ``endpoints_file`` is given and the path exists, the store configured
    in that file is returned. Otherwise this falls back to a SQL tracker
    store on ``tracker.db`` with an empty domain.
    """
    if not endpoints_file or not os.path.exists(endpoints_file):
        return SQLTrackerStore(Domain.empty(), db="tracker.db")
    return _get_tracker_store_from_endpoints_config(endpoints_file)