Example #1
File: test.py  Project: zhfneu/rasa_core
def collect_story_predictions(
        completed_trackers: List['DialogueStateTracker'],
        agent: 'Agent',
        fail_on_prediction_errors: bool = False,
        use_e2e: bool = False) -> Tuple[StoryEvalution, int]:
    """Test the stories from a file, running them through the stored model."""
    from rasa_nlu.test import get_evaluation_metrics
    from tqdm import tqdm

    story_eval_store = EvaluationStore()
    failed = []
    correct_dialogues = []
    num_stories = len(completed_trackers)

    logger.info("Evaluating {} stories\n" "Progress:".format(num_stories))

    action_list = []

    for tracker in tqdm(completed_trackers):
        tracker_results, predicted_tracker, tracker_actions = \
            _predict_tracker_actions(tracker, agent,
                                     fail_on_prediction_errors, use_e2e)

        story_eval_store.merge_store(tracker_results)

        action_list.extend(tracker_actions)

        if tracker_results.has_prediction_target_mismatch():
            # there is at least one wrong prediction
            failed.append(predicted_tracker)
            correct_dialogues.append(0)
        else:
            correct_dialogues.append(1)

    logger.info("Finished collecting predictions.")
    with warnings.catch_warnings():
        from sklearn.exceptions import UndefinedMetricWarning

        warnings.simplefilter("ignore", UndefinedMetricWarning)
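        # The target vector is all 1s because every story is expected to be
        # predicted fully correctly; `correct_dialogues` holds the 0/1 outcome
        # per story, so the accuracy below is the fraction of flawless stories.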
        report, precision, f1, accuracy = get_evaluation_metrics(
            [1] * len(completed_trackers), correct_dialogues)

    in_training_data_fraction = _in_training_data_fraction(action_list)

    log_evaluation_table([1] * len(completed_trackers),
                         "END-TO-END" if use_e2e else "CONVERSATION",
                         report,
                         precision,
                         f1,
                         accuracy,
                         in_training_data_fraction,
                         include_report=False)

    return (StoryEvalution(
        evaluation_store=story_eval_store,
        failed_stories=failed,
        action_list=action_list,
        in_training_data_fraction=in_training_data_fraction), num_stories)
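
The metrics above reduce to simple 0/1 comparisons: the target vector is all 1s and the predictions are the per-story correctness flags. Below is a minimal sketch of the same arithmetic with plain scikit-learn (get_evaluation_metrics presumably wraps metrics like these, with its own averaging settings); the sample flags are made up for illustration:

from sklearn.metrics import accuracy_score, f1_score, precision_score

# Hypothetical per-story outcomes: 1 = every action in the story was predicted correctly.
correct_dialogues = [1, 0, 1, 1]
targets = [1] * len(correct_dialogues)  # every story *should* be fully correct

# With an all-ones target vector, accuracy is the fraction of flawless stories.
print(accuracy_score(targets, correct_dialogues))   # 0.75
print(precision_score(targets, correct_dialogues))  # 1.0
print(f1_score(targets, correct_dialogues))         # ~0.857
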
Example #2
File: test.py  Project: zhfneu/rasa_core
async def test(stories: Text,
               agent: 'Agent',
               max_stories: Optional[int] = None,
               out_directory: Optional[Text] = None,
               fail_on_prediction_errors: bool = False,
               use_e2e: bool = False):
    """Run the evaluation of the stories, optionally plot the results."""
    from rasa_nlu.test import get_evaluation_metrics

    completed_trackers = await _generate_trackers(stories, agent, max_stories,
                                                  use_e2e)

    story_evaluation, _ = collect_story_predictions(completed_trackers, agent,
                                                    fail_on_prediction_errors,
                                                    use_e2e)

    evaluation_store = story_evaluation.evaluation_store

    with warnings.catch_warnings():
        from sklearn.exceptions import UndefinedMetricWarning

        warnings.simplefilter("ignore", UndefinedMetricWarning)
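        # Unlike in collect_story_predictions, the metrics here are computed over
        # the serialised target/prediction pairs accumulated in the evaluation
        # store, not over whole stories.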
        report, precision, f1, accuracy = get_evaluation_metrics(
            evaluation_store.serialise_targets(),
            evaluation_store.serialise_predictions())

    if out_directory:
        plot_story_evaluation(evaluation_store.action_targets,
                              evaluation_store.action_predictions, report,
                              precision, f1, accuracy,
                              story_evaluation.in_training_data_fraction,
                              out_directory)

    log_failed_stories(story_evaluation.failed_stories, out_directory)

    return {
        "report": report,
        "precision": precision,
        "f1": f1,
        "accuracy": accuracy,
        "actions": story_evaluation.action_list,
        "in_training_data_fraction":
        story_evaluation.in_training_data_fraction,
        "is_end_to_end_evaluation": use_e2e
    }
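
A minimal sketch of how this coroutine might be invoked, assuming a trained rasa_core dialogue model on disk and a test-stories file; the paths below are placeholders, not part of the project above:

import asyncio

from rasa_core.agent import Agent

# Hypothetical locations of the trained model and the evaluation stories.
agent = Agent.load("models/dialogue")

results = asyncio.get_event_loop().run_until_complete(
    test("data/test_stories.md", agent,
         out_directory="results",  # enables plot_story_evaluation / log_failed_stories output
         use_e2e=False))

print(results["accuracy"], results["in_training_data_fraction"])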