Пример #1
0
def test_get_next_action_probabilities_passes_interpreter_to_policies(
    monkeypatch: MonkeyPatch, ):
    """Check that the policy receives the processor's interpreter.

    ``predict_action_probabilities`` of a ``TEDPolicy`` is swapped for a stub
    which asserts that the ``interpreter`` argument equals the interpreter the
    ``MessageProcessor`` was built with.
    """
    expected_interpreter = Mock()

    def predict_action_probabilities(
        tracker: DialogueStateTracker,
        domain: Domain,
        interpreter: NaturalLanguageInterpreter,
        **kwargs,
    ) -> List[float]:
        # Fails the test when a different interpreter is handed over.
        assert interpreter == expected_interpreter
        return [1, 0]

    ted_policy = TEDPolicy()
    ted_policy.predict_action_probabilities = predict_action_probabilities

    empty_domain = Domain.empty()
    message_processor = MessageProcessor(
        expected_interpreter,
        SimplePolicyEnsemble(policies=[ted_policy]),
        empty_domain,
        InMemoryTrackerStore(empty_domain),
        Mock(),
    )

    listening_tracker = DialogueStateTracker.from_events(
        "lala", [ActionExecuted(ACTION_LISTEN_NAME)])

    # This should not raise
    message_processor._get_next_action_probabilities(listening_tracker)
Пример #2
0
def test_get_next_action_probabilities_pass_policy_predictions_without_interpreter_arg(
    predict_function: Callable,
):
    """Expect a ``DeprecationWarning`` when ``predict_function`` is the policy's predictor.

    ``predict_function`` is installed as ``predict_action_probabilities`` on a
    ``TEDPolicy``; asking the processor for the next action probabilities must
    then emit a ``DeprecationWarning``.
    """
    ted_policy = TEDPolicy()
    ted_policy.predict_action_probabilities = predict_function

    empty_domain = Domain.empty()
    message_processor = MessageProcessor(
        Mock(),
        SimplePolicyEnsemble(policies=[ted_policy]),
        empty_domain,
        InMemoryTrackerStore(empty_domain),
        InMemoryLockStore(),
        Mock(),
    )

    with pytest.warns(DeprecationWarning):
        listening_tracker = DialogueStateTracker.from_events(
            "lala", [ActionExecuted(ACTION_LISTEN_NAME)]
        )
        message_processor._get_next_action_probabilities(listening_tracker)
Пример #3
0
def test_predict_next_action_raises_limit_reached_exception(domain: Domain):
    """Check that ``predict_next_action`` raises ``ActionLimitReached``.

    The processor is limited to ``max_number_of_predictions=1`` and the tracker
    already holds one executed action after the user message.
    """
    processor = MessageProcessor(
        RegexInterpreter(),
        SimplePolicyEnsemble(policies=[RulePolicy(), MemoizationPolicy()]),
        domain,
        InMemoryTrackerStore(domain),
        InMemoryLockStore(),
        TemplatedNaturalLanguageGenerator(domain.responses),
        max_number_of_predictions=1,
    )

    history = [
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("Hi!"),
        ActionExecuted("test_action"),
    ]
    tracker = DialogueStateTracker.from_events("test", evts=history)
    tracker.set_latest_action({"action_name": "test_action"})

    with pytest.raises(ActionLimitReached):
        processor.predict_next_action(tracker)
Пример #4
0
def _in_training_data_fraction(action_list):
    """Given a list of action items, returns the fraction of actions

    that were predicted using one of the Memoization policies."""
    from rasa.core.policies import SimplePolicyEnsemble

    in_training_data = [
        a["action"] for a in action_list
        if not SimplePolicyEnsemble.is_not_memo_policy(a["policy"])
    ]

    return len(in_training_data) / len(action_list)
Пример #5
0
    async def train(self):
        """Train and persist the NLU and Core models of the configured skill.

        Training is skipped entirely when the directory ``data/<skill-id>``
        already exists.
        """
        nltk.download('punkt')
        lang = self.config['language']
        skill_id = self.config['skill-id']
        model_dir = 'data/' + skill_id
        if os.path.exists(model_dir):
            # A model directory for this skill is already present.
            return

        _LOGGER.info("Starting Skill training.")
        _LOGGER.info("Generating stories.")
        data, domain_data, stories = await GenerateStories.run(
            skill_id, lang, self.asm)

        # NLU part
        nlu_config = RasaNLUModelConfig({
            "language": lang,
            "pipeline": self.config['pipeline'],
            "data": None
        })
        trainer = Trainer(nlu_config, None, True)
        _LOGGER.info("Training Arcus NLU")
        trainer.train(TrainingData(training_examples=data))
        trainer.persist(model_dir, None, 'nlu')

        # Rasa core
        domain = Domain.from_dict(domain_data)
        story_reader = StoryFileReader(domain, RegexInterpreter(), None, False)
        story_graph = StoryGraph(await story_reader.process_lines(stories))

        generator = TrainingDataGenerator(
            story_graph,
            domain,
            remove_duplicates=True,
            unique_last_num_states=None,
            augmentation_factor=20,
            tracker_limit=None,
            use_story_concatenation=True,
            debug_plots=False,
        )
        training_trackers = generator.generate()

        policy_ensemble = SimplePolicyEnsemble(
            SimplePolicyEnsemble.from_dict(
                {"policies": self.config['policies']}))

        _LOGGER.info("Training Arcus Core")
        policy_ensemble.train(training_trackers, domain)

        core_dir = model_dir + "/core"
        policy_ensemble.persist(core_dir, False)
        domain.persist(core_dir + "/model")
        domain.persist_specification(core_dir)
Пример #6
0
def is_predicted_event_in_training_data(policy: Optional[Text]):
    """Tell whether the event predicted by `policy` counts as training data.

    An event counts as being in the training data when no policy predicted it
    or when the predicting policy is a memo policy according to
    `SimplePolicyEnsemble.is_not_memo_policy`.

    Args:
        policy: Policy of the predicted event.

    Returns:
        `True` if `policy` is empty or `None`, otherwise `True` if the event
        was predicted by a memo policy, else `False`.
    """
    if policy:
        return not SimplePolicyEnsemble.is_not_memo_policy(policy)

    # No policy involved: the event was not predicted at all.
    return True
Пример #7
0
    async def get_response(self, request):
        """Answer one user message using the persisted Core model.

        The domain and the tracker store are created on first use and kept in
        ``self.config`` under the keys ``'domain'`` and ``'tracker_store'``.

        Args:
            request: Mapping read for the keys ``'text'``, ``'user'`` and
                ``'channel'``; it is also passed to ``LocalNLUInterpreter``.

        Returns:
            ``{"text": <text of the first response>}``, or
            ``{"text": "error"}`` when the processor returned no response.
        """
        core_dir = "data/" + self.config['skill-id'] + "/core"

        # Plain assignment instead of ``setdefault``: a key that is present
        # but holds ``None`` has to be replaced as well.
        if self.config.get('domain') is None:
            self.config['domain'] = Domain.from_file(core_dir + "/model")
        # Checked separately so a pre-set domain never leaves the tracker
        # store missing.
        if self.config.get('tracker_store') is None:
            self.config['tracker_store'] = ArcusTrackerStore(
                self.config['domain'], self.asm)

        domain = self.config.get('domain')
        tracker_store = self.config.get('tracker_store')
        nlg = NaturalLanguageGenerator.create(None, domain)
        policy_ensemble = SimplePolicyEnsemble.load(core_dir)
        interpreter = LocalNLUInterpreter(request)

        url = 'http://localhost:8080/api/v1/skill/generic_action'
        processor = MessageProcessor(interpreter,
                                     policy_ensemble,
                                     domain,
                                     tracker_store,
                                     nlg,
                                     action_endpoint=EndpointConfig(url),
                                     message_preprocessor=None)

        message_nlu = UserMessage(request['text'],
                                  None,
                                  request['user'],
                                  input_channel=request['channel'])

        result = await processor.handle_message(message_nlu)
        if result:
            return {"text": result[0]['text']}

        # No response was produced; log what came back for debugging.
        _LOGGER.info(result)
        return {"text": "error"}
Пример #8
0
def test_predict_next_action_with_hidden_rules():
    """Check the ``hide_rule_turn`` flag on predictions for rule and story turns.

    An ensemble of ``RulePolicy`` and ``MemoizationPolicy`` is trained on one
    rule tracker and one story tracker. The test then walks a conversation
    through a rule turn followed by a story turn and asserts the predicted
    action and the ``hide_rule_turn`` flag after every prediction.
    """
    rule_intent = "rule_intent"
    rule_action = "rule_action"
    story_intent = "story_intent"
    story_action = "story_action"
    rule_slot = "rule_slot"
    story_slot = "story_slot"
    # Domain with the two intents, two actions and two text slots named above.
    domain = Domain.from_yaml(f"""
        version: "2.0"
        intents:
        - {rule_intent}
        - {story_intent}
        actions:
        - {rule_action}
        - {story_action}
        slots:
          {rule_slot}:
            type: text
          {story_slot}:
            type: text
        """)

    # Training tracker marked as a rule: rule_intent -> rule_action, which
    # sets rule_slot, then listen.
    rule = TrackerWithCachedStates.from_events(
        "rule",
        domain=domain,
        slots=domain.slots,
        evts=[
            ActionExecuted(RULE_SNIPPET_ACTION_NAME),
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": rule_intent}),
            ActionExecuted(rule_action),
            SlotSet(rule_slot, rule_slot),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
        is_rule_tracker=True,
    )
    # Training tracker for the story: story_intent -> story_action, which
    # sets story_slot, then listen.
    story = TrackerWithCachedStates.from_events(
        "story",
        domain=domain,
        slots=domain.slots,
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": story_intent}),
            ActionExecuted(story_action),
            SlotSet(story_slot, story_slot),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )
    interpreter = RegexInterpreter()
    ensemble = SimplePolicyEnsemble(
        policies=[RulePolicy(), MemoizationPolicy()])
    ensemble.train([rule, story], domain, interpreter)

    tracker_store = InMemoryTrackerStore(domain)
    lock_store = InMemoryLockStore()
    processor = MessageProcessor(
        interpreter,
        ensemble,
        domain,
        tracker_store,
        lock_store,
        TemplatedNaturalLanguageGenerator(domain.responses),
    )

    # Conversation under test starts with the user sending the rule intent.
    tracker = DialogueStateTracker.from_events(
        "casd",
        evts=[
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": rule_intent}),
        ],
        slots=domain.slots,
    )
    # Rule turn: rule_action is expected and flagged as hidden.
    action, prediction = processor.predict_next_action(tracker)
    assert action._name == rule_action
    assert prediction.hide_rule_turn

    # Record the action together with the slot event from the rule.
    processor._log_action_on_tracker(tracker, action,
                                     [SlotSet(rule_slot, rule_slot)],
                                     prediction)

    # The listen action closing the rule turn is expected to be hidden too.
    action, prediction = processor.predict_next_action(tracker)
    assert isinstance(action, ActionListen)
    assert prediction.hide_rule_turn

    processor._log_action_on_tracker(tracker, action, None, prediction)

    # The user continues with the story intent.
    tracker.events.append(UserUttered(intent={"name": story_intent}))

    # rules are hidden correctly if memo policy predicts next actions correctly
    action, prediction = processor.predict_next_action(tracker)
    assert action._name == story_action
    assert not prediction.hide_rule_turn

    processor._log_action_on_tracker(tracker, action,
                                     [SlotSet(story_slot, story_slot)],
                                     prediction)

    # The listen action after the story action must not be hidden either.
    action, prediction = processor.predict_next_action(tracker)
    assert isinstance(action, ActionListen)
    assert not prediction.hide_rule_turn