コード例 #1
0
def test_load_first_n(marker_trackerstore: TrackerStore):
    """Tests the 'first_n' strategy loads exactly `count` stored trackers."""
    trackers = list(
        MarkerTrackerLoader(marker_trackerstore, STRATEGY_FIRST_N, 3).load()
    )

    assert len(trackers) == 3
    # Every loaded tracker must belong to the store it was loaded from.
    assert all(marker_trackerstore.exists(t.sender_id) for t in trackers)
コード例 #2
0
def test_load_all(marker_trackerstore: TrackerStore):
    """Tests the 'all' strategy loads one tracker per key in the store."""
    loaded = list(
        MarkerTrackerLoader(marker_trackerstore, STRATEGY_ALL).load()
    )
    expected_count = len(list(marker_trackerstore.keys()))

    assert len(loaded) == expected_count
    # Every loaded tracker must belong to the store it was loaded from.
    assert all(marker_trackerstore.exists(t.sender_id) for t in loaded)
コード例 #3
0
def test_load_sample(marker_trackerstore: TrackerStore):
    """Tests the 'sample' strategy loads `count` distinct stored trackers."""
    sampled = list(
        MarkerTrackerLoader(marker_trackerstore, STRATEGY_SAMPLE_N, 3).load()
    )

    assert len(sampled) == 3

    sender_ids = [tracker.sender_id for tracker in sampled]
    for sender_id in sender_ids:
        assert marker_trackerstore.exists(sender_id)
    # The sample must not contain the same sender more than once.
    assert len(set(sender_ids)) == len(sender_ids)
コード例 #4
0
def test_load_sample_with_seed(marker_trackerstore: TrackerStore):
    """Tests a seeded 'sample' load returns a fixed sequence of senders."""
    seeded_loader = MarkerTrackerLoader(
        marker_trackerstore, STRATEGY_SAMPLE_N, 3, seed=3
    )
    sampled = list(seeded_loader.load())

    assert len(sampled) == 3
    # With seed=3 the loader is expected to pick exactly these senders, in order.
    assert [tracker.sender_id for tracker in sampled] == ["1", "4", "3"]
    for tracker in sampled:
        assert marker_trackerstore.exists(tracker.sender_id)
コード例 #5
0
async def test_markers_cli_results_save_correctly(tmp_path: Path):
    """Tests marker evaluation results are written correctly to a CSV file.

    Five trackers (sender ids "0" to "4") are saved. Each one:
    - sets slots "0"-"4";
    - runs the session start action and one user utterance;
    - sets slots "5"-"9".

    `marker1` (slot "2") is therefore expected only in the first session, and
    `marker2` (slot "7") only in the second.
    """
    domain = Domain.empty()
    store = InMemoryTrackerStore(domain)

    for i in range(5):
        tracker = DialogueStateTracker(str(i), None)
        tracker.update_with_events(
            [SlotSet(str(j), "slot") for j in range(5)], domain
        )
        tracker.update(ActionExecuted(ACTION_SESSION_START_NAME))
        tracker.update(UserUttered("hello"))
        tracker.update_with_events(
            [SlotSet(str(5 + j), "slot") for j in range(5)], domain
        )
        await store.save(tracker)

    tracker_loader = MarkerTrackerLoader(store, "all")

    results_path = tmp_path / "results.csv"

    markers = OrMarker(
        markers=[
            SlotSetMarker("2", name="marker1"),
            SlotSetMarker("7", name="marker2"),
        ]
    )
    await markers.evaluate_trackers(tracker_loader.load(), results_path)

    # The csv module expects files to be opened with newline="" so that
    # line endings inside quoted fields are handled correctly.
    with open(results_path, "r", newline="") as results:
        result_reader = csv.DictReader(results)
        senders = set()
        seen_markers = set()

        for row in result_reader:
            senders.add(row["sender_id"])
            seen_markers.add(row["marker"])
            if row["marker"] == "marker1":
                assert row["session_idx"] == "0"
                assert int(row["event_idx"]) >= 2
                assert row["num_preceding_user_turns"] == "0"

            if row["marker"] == "marker2":
                assert row["session_idx"] == "1"
                assert int(row["event_idx"]) >= 3
                assert row["num_preceding_user_turns"] == "1"

    # Every saved tracker must appear in the results, and nothing else.
    assert senders == {str(i) for i in range(5)}
    # Without this, the per-marker checks above would pass vacuously if a
    # marker never showed up in the output.
    assert {"marker1", "marker2"} <= seen_markers
コード例 #6
0
def test_load_sessions(tmp_path):
    """Tests a tracker spanning several sessions is loaded as a single tracker."""
    domain = Domain.empty()
    db_path = str(tmp_path / "temp.db")
    store = SQLTrackerStore(domain, db=db_path)

    events = [
        UserUttered("0"),
        UserUttered("1"),
        SessionStarted(),
        UserUttered("2"),
        UserUttered("3"),
    ]
    tracker = DialogueStateTracker("test123", None)
    tracker.update_with_events(events, domain)
    store.save(tracker)

    loaded = list(MarkerTrackerLoader(store, STRATEGY_ALL).load())
    # The session boundary must not split the tracker into several trackers.
    assert len(loaded) == 1
    assert len(loaded[0].events) == len(tracker.events)
コード例 #7
0
def test_warn_count_exceeds_store(marker_trackerstore: TrackerStore):
    """Tests a UserWarning is emitted when 'count' exceeds the stored trackers."""
    oversized_loader = MarkerTrackerLoader(
        marker_trackerstore, STRATEGY_SAMPLE_N, 6
    )
    with pytest.warns(UserWarning):
        # `load()` is a lazy generator; draining it is what triggers the warning.
        for _ in oversized_loader.load():
            pass