Ejemplo n.º 1
0
def test_evaluation_store_serialise(entity_predictions, entity_targets):
    """Verify that ``EvaluationStore.serialise`` keeps entities aligned.

    ``serialise`` yields two parallel sequences: serialised targets and
    serialised predictions. Slots without an entity hold the string "None".
    Two counters walk ``entity_predictions`` and ``entity_targets``, and each
    advances only past a non-"None" slot on its own side. For every slot where
    both sides are present, the test checks three things:

    * the serialised prediction equals ``TrainingDataWriter.generate_entity``
      applied to the matching prediction dict and its "text" value;
    * that prediction's "start" equals the matching target's "start";
    * that prediction's "end" equals the matching target's "end".
    """
    from rasa.shared.nlu.training_data.formats.readerwriter import TrainingDataWriter

    store = EvaluationStore(
        entity_predictions=entity_predictions, entity_targets=entity_targets
    )
    targets, predictions = store.serialise()

    # The sequences must be the same length for the pairing below to be valid.
    assert len(targets) == len(predictions)

    pred_cursor, target_cursor = 0, 0
    for serialised_pred, serialised_target in zip(predictions, targets):
        has_pred = serialised_pred != "None"
        has_target = serialised_target != "None"

        if has_pred and has_target:
            expected = entity_predictions[pred_cursor]
            gold = entity_targets[target_cursor]
            rendered = TrainingDataWriter.generate_entity(
                expected.get("text"), expected
            )
            assert serialised_pred == rendered
            assert expected.get("start") == gold.get("start")
            assert expected.get("end") == gold.get("end")

        # Each cursor moves only past slots that hold a real entity.
        if has_pred:
            pred_cursor += 1
        if has_target:
            target_cursor += 1
Ejemplo n.º 2
0
 def _generate_entity_training_data(entity: Dict[Text, Any]) -> Text:
     """Serialise ``entity`` with ``TrainingDataWriter.generate_entity``.

     The entity's own ``"text"`` value (``None`` if the key is absent) is
     passed as the text argument, and the full dict as the entity argument.
     """
     entity_text = entity.get("text")
     return TrainingDataWriter.generate_entity(entity_text, entity)