Example #1
def test_evaluation_script():
    actual, preds = collect_story_predictions(
        story_file="examples/concerts/data/stories.md",
        policy_model_path="examples/concerts/models/policy/init",
        nlu_model_path=None,
        max_stories=None,
        shuffle_stories=False)
    assert len(actual) == 14
    assert len(preds) == 14
Example #2
def test_end_to_end_evaluation_script(tmpdir, default_agent):
    completed_trackers = evaluate._generate_trackers(END_TO_END_STORY_FILE,
                                                     default_agent,
                                                     use_e2e=True)

    evaluation_result, failed_stories, _ = collect_story_predictions(
        completed_trackers, default_agent, use_e2e=True)

    assert not evaluation_result.has_prediction_target_mismatch()
    assert len(failed_stories) == 0
Example #3
def test_action_evaluation_script(tmpdir, default_agent):
    completed_trackers = evaluate._generate_trackers(DEFAULT_STORIES_FILE,
                                                     default_agent,
                                                     use_e2e=False)

    evaluation_result, failed_stories, _ = collect_story_predictions(
        completed_trackers, default_agent, use_e2e=False)

    assert not evaluation_result.has_prediction_target_mismatch()
    assert len(failed_stories) == 0
Example #4
def test_evaluation_script(tmpdir, default_agent):
    completed_trackers = evaluate._generate_trackers(DEFAULT_STORIES_FILE,
                                                     default_agent)

    golds, predictions, failed_stories = collect_story_predictions(
        completed_trackers, default_agent)

    assert len(golds) == 14
    assert len(predictions) == 14
    assert len(failed_stories) == 0
Example #5
def test_action_evaluation_script(tmpdir, default_agent):
    completed_trackers = evaluate._generate_trackers(DEFAULT_STORIES_FILE,
                                                     default_agent,
                                                     use_e2e=False)
    story_evaluation, num_stories = collect_story_predictions(
        completed_trackers, default_agent, use_e2e=False)

    assert not story_evaluation.evaluation_store. \
        has_prediction_target_mismatch()
    assert len(story_evaluation.failed_stories) == 0
    assert num_stories == 3
Example #6
def test_evaluation_script(tmpdir, default_agent):
    model_path = os.path.join(tmpdir.strpath, "model")
    default_agent.persist(model_path)
    actual, preds = collect_story_predictions(
        resource_name=DEFAULT_STORIES_FILE,
        policy_model_path=model_path,
        nlu_model_path=None,
        max_stories=None,
        shuffle_stories=False)
    assert len(actual) == 14
    assert len(preds) == 14
Example #7
def test_evaluation_script(tmpdir, default_agent):
    model_path = tmpdir.join("model").strpath
    default_agent.persist(model_path)

    actual, preds, failed_stories = collect_story_predictions(
        resource_name=DEFAULT_STORIES_FILE,
        policy_model_path=model_path,
        nlu_model_path=None,
        max_stories=None)
    assert len(actual) == 14
    assert len(preds) == 14
    assert len(failed_stories) == 0
Example #8
def test_end_to_end_evaluation_script(tmpdir, default_agent):
    completed_trackers = evaluate._generate_trackers(END_TO_END_STORY_FILE,
                                                     default_agent,
                                                     use_e2e=True)

    story_evaluation, num_stories = collect_story_predictions(
        completed_trackers, default_agent, use_e2e=True)

    assert not story_evaluation.evaluation_store. \
        has_prediction_target_mismatch()
    assert len(story_evaluation.failed_stories) == 0
    assert num_stories == 2
Example #9
def test_end_to_end_evaluation_script_unknown_entity(tmpdir, default_agent):
    completed_trackers = evaluate._generate_trackers(
        E2E_STORY_FILE_UNKNOWN_ENTITY, default_agent, use_e2e=True)

    story_evaluation, num_stories = collect_story_predictions(
        completed_trackers,
        default_agent,
        use_e2e=True)

    # the story references an entity the model does not know, so the
    # prediction/target comparison reports a mismatch and the single story fails
    assert story_evaluation.evaluation_store. \
        has_prediction_target_mismatch()
    assert len(story_evaluation.failed_stories) == 1
    assert num_stories == 1
Example #10
def run_evaluation(file_to_evaluate,
                   fail_on_prediction_errors=False,
                   max_stories=None,
                   use_e2e=False):

    # read the endpoint config, load the NLU interpreter and the trained Core agent
    _endpoints = AvailableEndpoints.read_endpoints(None)
    _interpreter = NaturalLanguageInterpreter.create(NLU_DIR)
    _agent = load_agent(CORE_DIR,
                        interpreter=_interpreter,
                        endpoints=_endpoints)

    # replay the stories against the agent and collect its action predictions
    completed_trackers = _generate_trackers(file_to_evaluate, _agent,
                                            max_stories, use_e2e)
    story_evaluation, _ = collect_story_predictions(completed_trackers, _agent,
                                                    fail_on_prediction_errors,
                                                    use_e2e)
    _failed_stories = story_evaluation.failed_stories

    _num_stories = len(completed_trackers)
    _file_result = FileResult(num_stories=_num_stories,
                              num_failed_stories=len(_failed_stories))

    file_message = "EVALUATING STORIES FOR FILE '{}':".format(file_to_evaluate)
    utils.print_color('\n' + '#' * 80, BOLD_COLOR)
    utils.print_color(file_message, BOLD_COLOR)

    files_results[file_to_evaluate] = _file_result

    if len(_failed_stories) == 0:
        success_message = "The stories have passed for file '{}'!!" \
                          .format(file_to_evaluate)
        utils.print_color('\n' + '=' * len(success_message), BLUE_COLOR)
        utils.print_color(success_message, BLUE_COLOR)
        utils.print_color('=' * len(success_message), BLUE_COLOR)
    else:
        # record every failed story as "<file> - <story name>" for the final summary
        for failed_story in _failed_stories:
            process_failed_story(failed_story.export_stories())
            story_name = re.search('## (.*)',
                                   failed_story.export_stories()).group(1)
            all_failed_stories.append(file_to_evaluate + ' - ' + story_name)

    utils.print_color('#' * 80 + '\n', BOLD_COLOR)
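
Example #10 relies on module-level state that the snippet does not show. The sketch below illustrates one way that state could be declared and how run_evaluation might be driven over several story files; FileResult as a namedtuple, the files_results and all_failed_stories collections, the process_failed_story helper, and the story file paths are assumptions for illustration, not part of the original code.

# Minimal driver sketch for run_evaluation from Example #10.
# FileResult, files_results, all_failed_stories, process_failed_story and the
# story file paths are assumed here; the original snippet does not define them.
from collections import namedtuple

# assumed per-file result record written by run_evaluation
FileResult = namedtuple("FileResult", ["num_stories", "num_failed_stories"])

# assumed module-level collections that run_evaluation updates
files_results = {}
all_failed_stories = []


def process_failed_story(exported_story):
    # assumed helper: simply print the failed story in its exported Markdown form
    print(exported_story)


if __name__ == "__main__":
    # assumed story files to evaluate; replace with real paths
    for story_file in ["data/stories_happy_path.md", "data/stories_edge_cases.md"]:
        run_evaluation(story_file, fail_on_prediction_errors=False)

    print("Failed stories:", all_failed_stories)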