Пример #1
0
def test_disamb_does_not_trigger_when_data_is_missing_in_parse_data():
    """An empty parse result must not trigger disambiguation."""
    rules = load_yaml('./tests/disambiguator/test_disambiguator4.yaml')
    disambiguator = Disambiguator(disamb_rule=rules['disambiguation_policy'])

    # No "intent", "entities" or "intent_ranking" keys at all.
    empty_parse_data = {}

    assert disambiguator.should_disambiguate(empty_parse_data) is False
Пример #2
0
def test_buttons_without_fallback():
    """Check the disambiguation message built for two ranked intents.

    The expected message carries exactly one button per ranked intent
    (intentA and intentB) and nothing else.
    """
    rules = load_yaml('./tests/disambiguator/test_disambiguator4.yaml')
    disambiguator = Disambiguator(disamb_rule=rules['disambiguation_policy'])

    ranking = [("intentA", 0.6), ("intentB", 0.2)]
    parse_data = {
        "intent_ranking": [{
            "name": name,
            "confidence": confidence
        } for name, confidence in ranking]
    }

    expected_buttons = [
        {"title": "utter_disamb_intentA", "payload": "/intentA"},
        {"title": "utter_disamb_intentB", "payload": "/intentB"},
    ]
    expected = {"text": "utter_disamb_text", "buttons": expected_buttons}

    dispatcher = StubDispatcher()
    intents = disambiguator.get_intent_names(parse_data)
    # NOTE(review): `tracker` is not defined inside this function; it is
    # presumably a module-level object of the test file -- confirm.
    message = ActionDisambiguate.get_disambiguation_message(
        dispatcher,
        disambiguator.disamb_rule,
        disambiguator.get_payloads(parse_data, intents),
        intents,
        tracker)
    assert message == expected
Пример #3
0
 def update(self, rules_dict):
     """Replace the active rules with the content of ``rules_dict``.

     ``rules_dict`` may be ``None`` or empty; every rule family then falls
     back to its "nothing configured" value.
     """
     self.rules_dict = rules_dict

     # Work on an empty mapping when no rules were supplied so that every
     # lookup below can share the same code path.
     rules = rules_dict if rules_dict else {}

     self.allowed_entities = rules.get("allowed_entities", {})
     self.intent_substitutions = rules.get("intent_substitutions", [])

     if "input_validation" in rules:
         self.input_validation = InputValidator(rules["input_validation"])
     else:
         self.input_validation = []

     # The disambiguator receives the disambiguation rule first and the
     # fallback rule second; either may be None.
     self.disambiguation_policy = Disambiguator(
         rules.get("disambiguation_policy", None),
         rules.get("fallback_policy", None))
Пример #4
0
def test_fallback_does_not_trigger_when_intent_is_null():
    """Fallback must not fire for an empty intent name and empty ranking."""
    rules = load_yaml('./tests/disambiguator/test_disambiguator7.yaml')
    disambiguator = Disambiguator(fallback_rule=rules['fallback_policy'])

    blank_intent = {"name": "", "confidence": 0.0}
    parse_data = {
        "intent": blank_intent,
        "entities": [],
        "intent_ranking": [],
    }

    assert disambiguator.should_fallback(parse_data) is False
Пример #5
0
def test_disamb_exclude_regex():
    """Check the disambiguation message when some ranked intents are dropped.

    The ranking holds five intents, but the expected message only offers
    buttons for intentA and intentB plus a fallback button; each intent
    payload carries the parsed entity.
    """
    rules = load_yaml('./tests/disambiguator/test_disambiguator11.yaml')
    disambiguator = Disambiguator(disamb_rule=rules['disambiguation_policy'])

    ranking = [
        ("chitchat.insults", 0.3),
        ("intentA", 0.2),
        ("chitchat.this_is_bad", 0.2),
        ("basics.yes", 0.15),
        ("intentB", 0.15),
    ]
    parse_data = {
        "intent_ranking": [{
            "name": name,
            "confidence": confidence
        } for name, confidence in ranking],
        "entities": [{"entity": "entity1", "value": "value1"}],
    }

    expected_buttons = [
        {
            "title": "utter_disamb_intentA",
            "payload": "/intentA{\"entity1\": \"value1\"}"
        },
        {
            "title": "utter_disamb_intentB",
            "payload": "/intentB{\"entity1\": \"value1\"}"
        },
        {
            "title": "utter_fallback",
            "payload": "/fallback"
        },
    ]
    expected = {"text": "utter_disamb_text", "buttons": expected_buttons}

    dispatcher = StubDispatcher()
    intents = disambiguator.get_intent_names(parse_data)
    # NOTE(review): `tracker` is not defined inside this function; it is
    # presumably a module-level object of the test file -- confirm.
    message = ActionDisambiguate.get_disambiguation_message(
        dispatcher,
        disambiguator.disamb_rule,
        disambiguator.get_payloads(parse_data, intents),
        intents,
        tracker)
    assert message == expected
Пример #6
0
def test_fallback_trigger2():
    """Fallback must not fire when the top intent has confidence 0.6."""
    rules = load_yaml('./tests/disambiguator/test_disambiguator6.yaml')
    disambiguator = Disambiguator(fallback_rule=rules['fallback_policy'])

    ranking = [("intentA", 0.6), ("intentB", 0.2)]
    parse_data = {
        "intent_ranking": [{
            "name": name,
            "confidence": confidence
        } for name, confidence in ranking]
    }

    should_fall_back = disambiguator.should_fallback(parse_data)
    assert should_fall_back is False
Пример #7
0
def test_disamb_does_not_trigger_when_intent_is_null():
    """Disambiguation must not fire for a ``None`` intent and empty ranking."""
    rules = load_yaml('./tests/disambiguator/test_disambiguator4.yaml')
    disambiguator = Disambiguator(disamb_rule=rules['disambiguation_policy'])

    null_intent = {"name": None, "confidence": 0.0}
    parse_data = {
        "intent": null_intent,
        "entities": [],
        "intent_ranking": [],
    }

    assert disambiguator.should_disambiguate(parse_data) is False
Пример #8
0
def test_fallback_buttons_with_fallback():
    """Check the fallback message: intro text plus a yes and a no button."""
    rules = load_yaml('./tests/disambiguator/test_disambiguator6.yaml')
    disambiguator = Disambiguator(fallback_rule=rules['fallback_policy'])

    # Built for parity with the sibling tests; it is not passed to
    # get_fallback_message below.
    ranking = [("intentA", 0.49), ("intentB", 0.3)]
    parse_data = {
        "intent_ranking": [{
            "name": name,
            "confidence": confidence
        } for name, confidence in ranking]
    }

    expected_buttons = [
        {"title": "utter_fallback_yes", "payload": "/fallback"},
        {"title": "utter_fallback_no", "payload": "/restart"},
    ]
    expected = {"text": "utter_fallback_intro", "buttons": expected_buttons}

    dispatcher = StubDispatcher()
    # NOTE(review): `tracker` is not defined inside this function; it is
    # presumably a module-level object of the test file -- confirm.
    message = ActionFallback.get_fallback_message(
        dispatcher, disambiguator.fallback_rule, tracker)
    assert message == expected
Пример #9
0
def test_disamb_trigger1():
    """Disambiguation must fire for two close intents (0.6 vs 0.5)."""
    rules = load_yaml('./tests/disambiguator/test_disambiguator2.yaml')
    disambiguator = Disambiguator(disamb_rule=rules['disambiguation_policy'])

    ranking = [("intentA", 0.6), ("intentB", 0.5)]
    parse_data = {
        "intent_ranking": [{
            "name": name,
            "confidence": confidence
        } for name, confidence in ranking]
    }

    should_ask = disambiguator.should_disambiguate(parse_data)
    assert should_ask is True
Пример #10
0
def test_disamb_and_fallback_trigger_none():
    """Neither fallback nor disambiguation may fire for 0.51 vs 0.25."""
    policies = load_yaml('./tests/disambiguator/test_disambiguator9.yaml')
    disambiguator = Disambiguator(
        disamb_rule=policies['disambiguation_policy'],
        fallback_rule=policies['fallback_policy'])

    ranking = [("intentA", 0.51), ("intentB", 0.25)]
    parse_data = {
        "intent_ranking": [{
            "name": name,
            "confidence": confidence
        } for name, confidence in ranking]
    }

    # `or` short-circuits: disambiguation is only consulted when the
    # fallback check is falsy.
    triggered = (disambiguator.should_fallback(parse_data)
                 or disambiguator.should_disambiguate(parse_data))
    assert triggered is False
Пример #11
0
def test_disamb_and_fallback_trigger_both():
    """Both fallback and disambiguation must fire for 0.49 vs 0.4."""
    policies = load_yaml('./tests/disambiguator/test_disambiguator9.yaml')
    disambiguator = Disambiguator(
        disamb_rule=policies['disambiguation_policy'],
        fallback_rule=policies['fallback_policy'])

    ranking = [("intentA", 0.49), ("intentB", 0.4)]
    parse_data = {
        "intent_ranking": [{
            "name": name,
            "confidence": confidence
        } for name, confidence in ranking]
    }

    # In the actual program flow (rules.py) fallback has precedence and
    # short-circuits disambiguation; here both checks are required to hold.
    both_triggered = (disambiguator.should_fallback(parse_data)
                      and disambiguator.should_disambiguate(parse_data))
    assert both_triggered is True
Пример #12
0
class Rules(object):
    """Rule set applied to NLU parse data before it is processed further.

    A rules dictionary may contain the keys ``allowed_entities``,
    ``intent_substitutions``, ``input_validation``, ``disambiguation_policy``
    and ``fallback_policy``. Missing keys (or a missing dictionary) simply
    disable the corresponding behaviour.
    """

    def __init__(self, rules_dict):
        """Create the rule set from ``rules_dict`` (may be ``None``)."""
        self.rules_dict = {}
        self.input_validation = None
        self.allowed_entities = None
        self.intent_substitutions = None
        self.disambiguation_policy = None
        # Actions skipped when looking for the previously executed action.
        self.actions_to_ignore = ['action_listen', 'action_invalid_utterance']
        self.update(rules_dict)

    def update(self, rules_dict):
        """Replace the active rules with the content of ``rules_dict``.

        ``rules_dict`` is stored unchanged (even when it is ``None``) and is
        what :meth:`get` returns.
        """
        self.rules_dict = rules_dict

        # Use an empty mapping when nothing was supplied so every lookup
        # below shares one code path.
        rules = rules_dict or {}

        self.allowed_entities = rules.get("allowed_entities", {})
        self.intent_substitutions = rules.get("intent_substitutions", [])
        if "input_validation" in rules:
            self.input_validation = InputValidator(rules["input_validation"])
        else:
            self.input_validation = []
        # Positional order: disambiguation rule first, fallback rule second.
        self.disambiguation_policy = Disambiguator(
            rules.get("disambiguation_policy"),
            rules.get("fallback_policy"))

    def get(self):
        """Return the rules dictionary last passed to :meth:`update`."""
        return self.rules_dict

    def interrupts(self, dispatcher, parse_data, tracker, run_action):
        """Apply all rules to ``parse_data`` and report an interruption.

        A deep copy of the incoming data is stored under
        ``parse_data['original_data']``. Fallback is tried first, then
        disambiguation; if neither fires, intent substitutions, entity
        filtering and input validation are applied in that order. When the
        data ends up unchanged, the ``original_data`` copy is removed again.

        Returns:
            ``True`` when fallback, disambiguation or an input validation
            error took over; otherwise ``None``.
        """
        parse_data['original_data'] = copy.deepcopy(parse_data)

        # fallback has precedence
        if self.disambiguation_policy.fallback(parse_data, tracker, dispatcher, run_action) or \
                self.disambiguation_policy.disambiguate(parse_data, tracker, dispatcher, run_action):
            return True

        self.run_swap_intent_rules(parse_data, tracker)

        self.filter_entities(parse_data)

        if self.input_validation:
            error_template = self.input_validation.get_error(
                parse_data, tracker)
            if error_template is not None:
                self._utter_error_and_roll_back(dispatcher, tracker,
                                                error_template, run_action)
                return True

        current_data = {
            key: val
            for key, val in parse_data.items() if key != 'original_data'
        }
        if current_data == parse_data['original_data']:
            # Nothing has changed
            del parse_data['original_data']

    @staticmethod
    def _utter_error_and_roll_back(dispatcher, tracker, template, run_action):
        """Run an ``ActionInvalidUtterance`` built from ``template``."""
        action = ActionInvalidUtterance(template)
        run_action(action, tracker, dispatcher)

    def filter_entities(self, parse_data):
        """Drop entities that are not allowed for the parsed intent.

        Intents without an entry in ``allowed_entities`` are left untouched.
        ``parse_data['entities']`` is only replaced (and a warning logged)
        when at least one entity was actually removed.
        """
        intent_name = parse_data['intent']['name']
        if intent_name not in self.allowed_entities:
            return

        allowed = self.allowed_entities[intent_name]
        filtered = [
            ent for ent in parse_data['entities'] if ent['entity'] in allowed
        ]

        if len(filtered) < len(parse_data['entities']):
            # logging first
            logger.warning("entity(ies) were removed from parse stories")
            parse_data['entities'] = filtered

    def run_swap_intent_rules(self, parse_data, tracker):
        """Apply the first intent substitution rule that matches."""
        previous_action = self._get_previous_action(tracker)

        for rule in self.intent_substitutions:
            if Rules._swap_intent(parse_data, previous_action, rule):
                break

    @staticmethod
    def _swap_intent(parse_data, previous_action, rule):
        """Apply ``rule`` to ``parse_data`` if it matches.

        A rule with an ``after`` key only applies when ``previous_action``
        matches that pattern. A rule without ``after`` is a general
        substitution and applies when its ``intent`` pattern matches the
        parsed intent name.

        Returns:
            ``True`` when the intent was swapped, ``False`` otherwise.
        """
        if 'after' in rule:
            if previous_action and re.match(rule['after'], previous_action):
                return Rules._swap_intent_after(parse_data, rule)
            return False

        # General substitution: nothing to match if either side is missing.
        intent_name = parse_data['intent']['name']
        if rule['intent'] is None or intent_name is None:
            return False
        if re.match(rule['intent'], intent_name):
            return Rules.swap_intent_with(parse_data, rule)
        return False

    @staticmethod
    def _swap_intent_after(parse_data, rule):
        """Force the intent to ``rule['intent']`` unless it is excluded.

        The swap is skipped when the parsed intent name is listed in the
        rule's optional ``unless`` list. On a swap the confidence is set to
        1.0 and any ``intent_ranking`` is dropped.

        Returns:
            ``True`` when the intent was swapped, ``False`` otherwise.
        """
        # Read the exclusions without writing an 'unless' key back into the
        # rule, which is shared configuration.
        excluded = rule.get('unless') or []
        intent_name = parse_data['intent']['name']
        if intent_name is None or intent_name not in excluded:
            logger.debug("intent '{}' was replaced with '{}'".format(
                intent_name, rule['intent']))
            parse_data['intent']['name'] = rule['intent']
            parse_data['intent']['confidence'] = 1.0
            parse_data.pop('intent_ranking', None)
            return True
        return False

    @staticmethod
    def swap_intent_with(parse_data, rule):
        """Replace the parsed intent with ``rule['with']``.

        ``intent_ranking`` is replaced by a single entry for the new intent
        with confidence 1.0. Entities listed under ``rule['entities']['add']``
        are appended; their ``name`` and ``value`` templates may reference
        the original intent name as ``{intent}``.

        Returns:
            Always ``True``.
        """
        # Remember the original name before overwriting it: it is needed
        # for the log message and for the entity templates.
        original_name = parse_data["intent"]["name"]

        parse_data['intent']['name'] = rule['with']
        parse_data['intent_ranking'] = [{
            "name": rule['with'],
            "confidence": 1.0
        }]
        logger.debug("intent '{}' was replaced with '{}'".format(
            original_name, rule['with']))

        if 'entities' in rule and 'add' in rule["entities"]:
            for entity in rule["entities"]["add"]:
                parse_data.setdefault('entities', []).append({
                    "entity": entity["name"].format(intent=original_name),
                    "value": entity["value"].format(intent=original_name)
                })
        return True

    def _get_previous_action(self, tracker):
        """Return the name of the most recent relevant executed action.

        Events are scanned from newest to oldest. Only events whose type is
        exactly ``ActionExecuted`` and whose action is not in
        ``actions_to_ignore`` count. The event at index 0 is never
        inspected.

        Returns:
            The action name, or ``None`` when no such event exists.
        """
        events = tracker.events
        for i in range(len(events) - 1, 0, -1):
            event = events[i]
            if type(event) is ActionExecuted \
                    and event.action_name not in self.actions_to_ignore:
                return event.action_name

        return None

    @staticmethod
    def _load_yaml(rules_file):
        """Parse ``rules_file`` as UTF-8 YAML.

        Raises:
            ValueError: when the file is not valid YAML.
        """
        with io.open(rules_file, 'r', encoding='utf-8') as stream:
            try:
                # NOTE(review): yaml.load without an explicit Loader can
                # construct arbitrary Python objects and is rejected by
                # recent PyYAML releases; consider yaml.safe_load if rules
                # files only contain plain mappings/lists -- confirm.
                return yaml.load(stream)
            except yaml.YAMLError as exc:
                raise ValueError(exc)

    @classmethod
    def load_from_remote(cls, endpoint):
        """Fetch rules from ``endpoint`` with a GET request.

        Returns:
            A new rules instance on a 200 response; ``None`` for any other
            status code or when the request raised ``RequestException``.
        """
        try:
            logger.debug("Requesting rules from server {}..."
                         "".format(endpoint.url))
            response = endpoint.request(method='get')
        except RequestException as e:
            logger.warning(
                "Tried to fetch rules from server, but couldn't reach "
                "server. We'll retry later... Error: {}."
                "".format(e))
            return None

        status = response.status_code
        if status == 200:
            return cls(response.json())

        if status in [204, 304]:
            logger.debug(
                "Model server returned {} status code, indicating "
                "that no new rules are available.".format(status))
        elif status == 404:
            logger.warning(
                "Tried to fetch rules from server but got a 404 response")
        else:
            logger.warning(
                "Tried to fetch rules from server, but server response "
                "status code is {}. We'll retry later..."
                "".format(status))
        return None

    @classmethod
    def load_from_file(cls, rules_file):
        """Build a rules instance from the YAML file ``rules_file``.

        Any error raised while reading or parsing the file propagates to
        the caller.
        """
        logger.debug('Loading rules from the {} file'.format(rules_file))
        return cls(cls._load_yaml(rules_file))
Пример #13
0
def test_disamb_and_fallback_trigger_wrong_format():
    """Loading and building from test_disambiguator8.yaml must raise SchemaError."""
    yaml_path = './tests/disambiguator/test_disambiguator8.yaml'
    with pytest.raises(SchemaError):
        policies = load_yaml(yaml_path)
        Disambiguator(disamb_rule=policies['disambiguation_policy'],
                      fallback_rule=policies['fallback_policy'])
Пример #14
0
def test_fallback_does_not_trigger_when_data_is_missing_in_parse_data():
    """An empty parse result must not trigger the fallback."""
    rules = load_yaml('./tests/disambiguator/test_disambiguator7.yaml')
    disambiguator = Disambiguator(fallback_rule=rules['fallback_policy'])

    # No "intent", "entities" or "intent_ranking" keys at all.
    empty_parse_data = {}

    assert disambiguator.should_fallback(empty_parse_data) is False
Пример #15
0
def test_fallback_trigger_wrong_format():
    """Building a fallback rule from test_disambiguator5.yaml must raise SchemaError."""
    yaml_path = './tests/disambiguator/test_disambiguator5.yaml'
    with pytest.raises(SchemaError):
        policies = load_yaml(yaml_path)
        Disambiguator(fallback_rule=policies['fallback_policy'])
Пример #16
0
def test_disamb_trigger_wrong_format():
    """Building a disambiguation rule from test_disambiguator3.yaml must raise SchemaError."""
    yaml_path = './tests/disambiguator/test_disambiguator3.yaml'
    with pytest.raises(SchemaError):
        policies = load_yaml(yaml_path)
        Disambiguator(disamb_rule=policies['disambiguation_policy'])