示例#1
0
文件: test.py 项目: spawn08/rasa
async def _get_e2e_entity_evaluation_result(
    processor: "MessageProcessor",
    tracker: DialogueStateTracker,
    prediction: PolicyPrediction,
) -> Optional[EntityEvaluationResult]:
    """Collect entity targets and policy predictions for the last user turn.

    Looks at the most recent event of ``tracker`` (skipping back past a
    trailing ``SlotSet``). If it is a ``UserUttered`` event, its annotated
    entities are paired with the entities the policies added through
    ``EntitiesAdded`` events in ``prediction``. The message text is re-parsed
    with ``processor`` solely to obtain its tokens.

    Args:
        processor: Processor whose ``parse_message`` is used to tokenize the
            user message text.
        tracker: Dialogue tracker whose latest events are inspected.
        prediction: Policy prediction whose ``EntitiesAdded`` events supply
            the predicted entities.

    Returns:
        An ``EntityEvaluationResult`` built from the target entities, the
        predicted entities, the tokens and the message text. Returns ``None``
        in any of these cases:

        - the tracker has no events;
        - the relevant event is not a ``UserUttered``;
        - there are neither target nor predicted entities;
        - the message has no text;
        - parsing returned nothing.
    """
    # Indexing an empty event sequence would raise IndexError; treat an empty
    # tracker like any other "nothing to evaluate" case.
    if not tracker.events:
        return None
    previous_event = tracker.events[-1]

    if isinstance(previous_event, SlotSet):
        # UserUttered events with entities can be followed by SlotSet events
        # if slots are defined in the domain, so look further back.
        previous_event = tracker.get_last_event_for(
            (UserUttered, ActionExecuted))

    # get_last_event_for may return an ActionExecuted or None; both mean
    # there is no user message to evaluate.
    if not isinstance(previous_event, UserUttered):
        return None

    entities_predicted_by_policies = [
        entity for prediction_event in prediction.events
        if isinstance(prediction_event, EntitiesAdded)
        for entity in prediction_event.entities
    ]
    entity_targets = previous_event.entities
    if not (entity_targets or entities_predicted_by_policies):
        return None

    text = previous_event.text
    if not text:
        return None

    parsed_message = await processor.parse_message(UserMessage(text=text))
    if not parsed_message:
        return None

    # Each token entry is unpacked as a (start, end) pair; the token text is
    # sliced from the original message so it matches it character-for-character.
    tokens = [
        Token(text[start:end], start, end)
        for start, end in parsed_message.get(TOKENS_NAMES[TEXT], [])
    ]
    return EntityEvaluationResult(
        entity_targets, entities_predicted_by_policies, tokens, text)
示例#2
0
def _get_e2e_entity_evaluation_result(
    processor: "MessageProcessor",
    tracker: DialogueStateTracker,
    prediction: PolicyPrediction,
) -> Optional[EntityEvaluationResult]:
    """Collect entity targets and policy predictions for the last user turn.

    If the most recent tracker event is a ``UserUttered``, its annotated
    entities are paired with the entities the policies added through
    ``EntitiesAdded`` events in ``prediction``. The message text is featurized
    with ``processor.interpreter`` to obtain its tokens.

    Args:
        processor: Processor whose interpreter featurizes the message text.
        tracker: Dialogue tracker whose latest event is inspected.
        prediction: Policy prediction whose ``EntitiesAdded`` events supply
            the predicted entities.

    Returns:
        An ``EntityEvaluationResult`` built from the target entities, the
        predicted entities, the tokens (an empty list if none were produced)
        and the message text. Returns ``None`` in any of these cases:

        - the tracker has no events;
        - the last event is not a ``UserUttered``;
        - there are neither target nor predicted entities;
        - the message has no text;
        - featurization returned ``None``.
    """
    # Indexing an empty event sequence would raise IndexError.
    if not tracker.events:
        return None
    previous_event = tracker.events[-1]
    if not isinstance(previous_event, UserUttered):
        return None

    entities_predicted_by_policies = [
        entity for prediction_event in prediction.events
        if isinstance(prediction_event, EntitiesAdded)
        for entity in prediction_event.entities
    ]
    entity_targets = previous_event.entities
    if not (entity_targets or entities_predicted_by_policies):
        return None

    text = previous_event.text
    if not text:
        # Without text there is nothing to featurize or tokenize.
        return None

    parsed_message = processor.interpreter.featurize_message(
        Message(data={TEXT: text}))
    if parsed_message is None:
        return None

    # Fall back to an empty token list (the same "no tokens" value the test
    # fixtures use) instead of passing None downstream.
    tokens = parsed_message.get(TOKENS_NAMES[TEXT]) or []
    return EntityEvaluationResult(entity_targets,
                                  entities_predicted_by_policies,
                                  tokens, text)
示例#3
0
        "start": 42,
        "end": 56,
        "value": "Alexanderplatz",
        "entity": "location",
        "extractor": "EntityExtractorA",
    },
    {
        "start": 42,
        "end": 64,
        "value": "Alexanderplatz tonight",
        "entity": "movie",
        "extractor": "EntityExtractorB",
    },
]

# Entity evaluation fixture for the English example: targets, predictions
# and the English tokens.
EN_entity_result = EntityEvaluationResult(
    EN_targets,
    EN_predicted,
    EN_tokens,
)

# Same targets and predictions, but with an empty token list.
EN_entity_result_no_tokens = EntityEvaluationResult(
    EN_targets,
    EN_predicted,
    [],
)


def test_token_entity_intersection():
    """Check determine_intersection for an included and an outside token."""
    # Token included in the entity: the intersection spans the whole token.
    included_token = CH_correct_segmentation[1]
    overlap = determine_intersection(included_token, CH_correct_entity)
    assert overlap == len(included_token.text)

    # Token completely outside the entity: no intersection at all.
    outside_token = CH_correct_segmentation[2]
    overlap = determine_intersection(outside_token, CH_correct_entity)
    assert overlap == 0
示例#4
0
        "start": 42,
        "end": 56,
        "value": "Alexanderplatz",
        "entity": "location",
        "extractor": "EntityExtractorA",
    },
    {
        "start": 42,
        "end": 64,
        "value": "Alexanderplatz tonight",
        "entity": "movie",
        "extractor": "EntityExtractorB",
    },
]

# Entity evaluation fixture for the English example. The message text is
# rebuilt by joining the token texts with single spaces.
EN_entity_result = EntityEvaluationResult(
    EN_targets,
    EN_predicted,
    EN_tokens,
    " ".join(token.text for token in EN_tokens),
)

# Variant with no tokens and an empty message text.
EN_entity_result_no_tokens = EntityEvaluationResult(
    EN_targets, EN_predicted, [], ""
)


def test_token_entity_intersection():
    """determine_intersection handles contained and disjoint tokens."""
    entity = CH_correct_entity

    # Case 1: the token is included in the entity, so the intersection
    # equals the full length of the token text.
    contained = CH_correct_segmentation[1]
    result = determine_intersection(contained, entity)
    assert result == len(contained.text)

    # Case 2: the token lies completely outside the entity.
    disjoint = CH_correct_segmentation[2]
    result = determine_intersection(disjoint, entity)
    assert result == 0