Example #1
0
def test_differently_ordered_entity_predictions_tolerated():
    """The order in which entities were extracted shouldn't matter.

    Imagine an utterance with two entities, each found by a different entity
    extractor. Then, the order in which entities are extracted from the utterance
    depends on the order in which the extractors are listed in the NLU pipeline.
    However, the expected order is given by where the entities are found in the
    utterance. Hence, it's reasonable for the expected and extracted order to not
    match and it shouldn't be flagged as a prediction error.

    Here the utterance is "[Algeria](country) and [Albania](country)." and the
    predictions list the entities in utterance order while the targets list them
    reversed.
    """
    # First entity in utterance order: "Algeria", characters 0-7 of the text.
    entity1 = {
        ENTITY_ATTRIBUTE_TEXT: "Algeria and Albania",
        ENTITY_ATTRIBUTE_START: 0,
        ENTITY_ATTRIBUTE_END: 7,
        ENTITY_ATTRIBUTE_VALUE: "Algeria",
        ENTITY_ATTRIBUTE_TYPE: "country",
    }
    # Second entity in utterance order: "Albania", characters 12-19 of the text.
    entity2 = {
        ENTITY_ATTRIBUTE_TEXT: "Algeria and Albania",
        ENTITY_ATTRIBUTE_START: 12,
        ENTITY_ATTRIBUTE_END: 19,
        ENTITY_ATTRIBUTE_VALUE: "Albania",
        ENTITY_ATTRIBUTE_TYPE: "country",
    }
    # Same entities on both sides, only in different order -> no mismatch expected.
    evaluation = EvaluationStore(entity_predictions=[entity1, entity2],
                                 entity_targets=[entity2, entity1])
    assert not evaluation.check_prediction_target_mismatch()
Example #2
0
def test_evaluation_store_serialise(entity_predictions, entity_targets):
    """Serialised targets and predictions line up pairwise.

    `serialise()` must yield equally long target/prediction sequences, and each
    non-"None" prediction must render via `TrainingDataWriter.generate_entity`
    and share its span (start/end) with the corresponding target entity.
    """
    from rasa.shared.nlu.training_data.formats.readerwriter import TrainingDataWriter

    store = EvaluationStore(
        entity_predictions=entity_predictions, entity_targets=entity_targets
    )

    targets, predictions = store.serialise()

    assert len(targets) == len(predictions)

    # Separate cursors into the raw entity lists; each advances only when the
    # corresponding serialised slot holds an actual entity (not "None").
    pred_cursor = 0
    target_cursor = 0
    for prediction, target in zip(predictions, targets):
        if prediction != "None" and target != "None":
            raw_prediction = entity_predictions[pred_cursor]
            rendered = TrainingDataWriter.generate_entity(
                raw_prediction.get("text"), raw_prediction
            )
            assert prediction == rendered

            raw_target = entity_targets[target_cursor]
            assert raw_prediction.get("start") == raw_target.get("start")
            assert raw_prediction.get("end") == raw_target.get("end")

        if prediction != "None":
            pred_cursor += 1
        if target != "None":
            target_cursor += 1
Example #3
0
def test_write_classification_errors():
    """Misclassifications appear as `# predicted:` comments in the story dump.

    A wrongly classified user utterance and a wrongly predicted action are
    placed on a tracker; dumping it as a YAML story should annotate both steps
    with the (wrong) predicted values.
    """
    eval_store = EvaluationStore(
        action_predictions=["utter_goodbye"],
        action_targets=["utter_greet"],
        intent_predictions=["goodbye"],
        intent_targets=["greet"],
        entity_predictions=None,
        entity_targets=None,
    )
    wrong_utterance = WronglyClassifiedUserUtterance(
        UserUttered("Hello", {"name": "goodbye"}), eval_store
    )
    wrong_action = WronglyPredictedAction("utter_greet", "", "utter_goodbye")

    tracker = DialogueStateTracker.from_events(
        "default", [wrong_utterance, wrong_action]
    )
    dump = YAMLStoryWriter().dumps(tracker.as_story().story_steps)

    expected = textwrap.dedent("""
        version: "3.0"
        stories:
        - story: default
          steps:
          - intent: greet  # predicted: goodbye: Hello
          - action: utter_greet  # predicted: utter_goodbye

    """).strip()
    assert dump.strip() == expected
Example #4
0
def test_duplicated_entity_predictions_tolerated():
    """Same entity extracted multiple times shouldn't be flagged as prediction error.

    This can happen when multiple entity extractors extract the same entity but a test
    story only lists the entity once. For completeness, the other case (entity listed
    twice in test story and extracted once) is also tested here because it should work
    the same way.
    """
    country_entity = {
        ENTITY_ATTRIBUTE_TEXT: "Algeria",
        ENTITY_ATTRIBUTE_START: 0,
        ENTITY_ATTRIBUTE_END: 7,
        ENTITY_ATTRIBUTE_VALUE: "Algeria",
        ENTITY_ATTRIBUTE_TYPE: "country",
    }

    # Both directions of the duplication must be tolerated: duplicate
    # prediction vs. single target, and single prediction vs. duplicate target.
    cases = (
        ([country_entity, country_entity], [country_entity]),
        ([country_entity], [country_entity, country_entity]),
    )
    for predictions, targets in cases:
        evaluation = EvaluationStore(
            entity_predictions=predictions, entity_targets=targets
        )
        assert not evaluation.check_prediction_target_mismatch()
Example #5
0
                            "id": 12,
                            "confidence": 0.03813415765762329,
                            "intent_response_key": "chitchat/ask_weather",
                        },
                    ],
                },
            },
        },
    ),
    SessionStarted(),
    ActionExecuted(action_name="action_listen"),
    AgentUttered(),
    EndToEndUserUtterance(),
    WronglyClassifiedUserUtterance(
        event=UserUttered(),
        eval_store=EvaluationStore(intent_targets=["test"])),
    WronglyPredictedAction(
        action_name_prediction="test",
        action_name_target="demo",
        action_text_target="example",
    ),
    WarningPredictedAction(action_name_prediction="test"),
]


@pytest.mark.parametrize("event", tested_events)
def test_event_fingerprint_consistency(event: Event):
    f1 = event.fingerprint()

    event2 = copy.deepcopy(event)
    f2 = event2.fingerprint()
Example #6
0
                        {
                            "id": 12,
                            "confidence": 0.03813415765762329,
                            "intent_response_key": "chitchat/ask_weather",
                        },
                    ],
                },
            },
        },
    ),
    SessionStarted(),
    ActionExecuted(action_name="action_listen"),
    AgentUttered(),
    EndToEndUserUtterance(),
    WronglyClassifiedUserUtterance(
        event=UserUttered(), eval_store=EvaluationStore(intent_targets=["test"])
    ),
    WronglyPredictedAction(
        action_name_prediction="test",
        action_name_target="demo",
        action_text_target="example",
    ),
    WarningPredictedAction(action_name="action_listen", action_name_prediction="test"),
]


@pytest.mark.parametrize("event", tested_events)
def test_event_fingerprint_consistency(event: Event):
    f1 = event.fingerprint()

    event2 = copy.deepcopy(event)