Esempio n. 1
0
def test_form_wins_over_everything_else(ensemble: SimplePolicyEnsemble):
    """Check that an active form makes the form policy win the ensemble.

    A domain declaring a single form is built, and a tracker is created with
    that form as the active loop, followed by an action listen and a user
    message. The ensemble's prediction must come from the policy at index 0
    (named after ``FormPolicy``), and the predicted action must be the form.
    """
    form_name = "test-form"
    domain_yaml = f"""
    forms:
    - {form_name}
    """
    domain = Domain.from_yaml(domain_yaml)

    events = [
        # Use the same name the assertions check, so the two cannot drift apart.
        ActiveLoop(form_name),
        ActionExecuted(ACTION_LISTEN_NAME),
        utilities.user_uttered("test", 1),
    ]
    tracker = DialogueStateTracker.from_events("test", events, [])
    prediction = ensemble.probabilities_using_best_policy(
        tracker, domain, RegexInterpreter())

    next_action = rasa.core.actions.action.action_for_index(
        prediction.max_confidence_index, domain, None)

    # The form policy is placed first in the ensemble, so its name carries index 0.
    index_of_form_policy = 0
    assert (prediction.policy_name ==
            f"policy_{index_of_form_policy}_{FormPolicy.__name__}")
    assert next_action.name() == form_name
Esempio n. 2
0
def test_policy_priority():
    """Verify that the higher-priority policy wins regardless of ensemble order.

    Two ``ConstantPolicy`` instances with priorities 1 and 2 are combined in
    both orders. For each ensemble, the reported best policy must be the
    priority-2 one (named by its position in that ensemble), and the returned
    probabilities must equal that policy's own prediction.
    """
    domain = Domain.load("data/test_domains/default.yml")
    tracker = DialogueStateTracker.from_events("test", [UserUttered("hi")], [])

    low = ConstantPolicy(priority=1, predict_index=0)
    high = ConstantPolicy(priority=2, predict_index=1)

    # Each ensemble is paired with the position the priority-2 policy holds in it.
    cases = [
        (SimplePolicyEnsemble([low, high]), 1),
        (SimplePolicyEnsemble([high, low]), 0),
    ]

    expected_probabilities = high.predict_action_probabilities(
        tracker, domain)

    for ensemble, high_position in cases:
        probabilities, winner = ensemble.probabilities_using_best_policy(
            tracker, domain)
        assert winner == "policy_{}_{}".format(
            high_position, type(high).__name__)
        assert probabilities.tolist() == expected_probabilities
Esempio n. 3
0
    max_confidence_index = result.index(max(result))
    next_action = domain.action_for_index(max_confidence_index, None)

    index_of_mapping_policy = 0
    assert best_policy == f"policy_{index_of_mapping_policy}_{MappingPolicy.__name__}"
    assert next_action.name() == ACTION_RESTART_NAME


@pytest.mark.parametrize(
    "ensemble",
    [
        SimplePolicyEnsemble(
            [
                FormPolicy(),
                ConstantPolicy(FORM_POLICY_PRIORITY - 1, 0),
                FallbackPolicy(),
            ]
        ),
        SimplePolicyEnsemble([FormPolicy(), MappingPolicy()]),
    ],
)
def test_form_wins_over_everything_else(ensemble: SimplePolicyEnsemble):
    form_name = "test-form"
    domain = f"""
    forms:
    - {form_name}
    """
    domain = Domain.from_yaml(domain)

    events = [
Esempio n. 4
0
def test_is_not_in_training_data(
    policy_name: Text,
    confidence: Optional[float],
    not_in_training_data: bool,
):
    """Compare ``SimplePolicyEnsemble.is_not_in_training_data`` to the expected flag.

    The arguments are presumably supplied by a pytest parametrization that is
    not visible here (TODO confirm).
    """
    actual = SimplePolicyEnsemble.is_not_in_training_data(policy_name, confidence)
    assert actual == not_in_training_data