def predict_action_probabilities(self,
                                 tracker: DialogueStateTracker,
                                 domain: Domain) -> List[float]:
    """Score the actions that continue an active form, if there is one.

    With an active form:

    * right after ``action_listen`` the form action gets ``FORM_SCORE``.
      If the form was rejected and ``state_is_unhappy`` holds, a
      ``FormValidation(False)`` event is applied to ``tracker`` and
      all-zero scores are returned instead.
    * right after the form action itself ``action_listen`` gets
      ``FORM_SCORE``.

    Every other situation yields all-zero scores.
    """
    scores = [0.0] * domain.num_actions

    if not tracker.active_form.get('name'):
        logger.debug("There is no active form")
        return scores

    logger.debug("There is an active form '{}'"
                 "".format(tracker.active_form['name']))
    last_action = tracker.latest_action_name

    if last_action == ACTION_LISTEN_NAME:
        # the user just spoke: hand control back to the form
        rejected = tracker.active_form.get('rejected')
        if rejected and self.state_is_unhappy(tracker, domain):
            # side effect on the tracker before abstaining
            tracker.update(FormValidation(False))
            return scores
        form_idx = domain.index_for_action(tracker.active_form['name'])
        scores[form_idx] = FORM_SCORE
    elif last_action == tracker.active_form.get('name'):
        # the form just ran: wait for the user
        listen_idx = domain.index_for_action(ACTION_LISTEN_NAME)
        scores[listen_idx] = FORM_SCORE

    return scores
def run(self, dispatcher, tracker: DialogueStateTracker,
        domain: TemplateDomain):
    """Refresh the entity memory, queue ``action_listen`` and clear ``matches``.

    Entities of the latest user message are stored via
    ``ActionFalloutSlots._update_memory_with_new``. The memory countdown
    is then advanced with ``ActionFalloutSlots._countdown_memory``.

    Returns:
        ``[SlotSet("matches", None)]`` followed by the events returned by
        ``_countdown_memory`` (presumably resets for forgotten slots).
    """
    latest_entities = {}
    for entity in tracker.latest_message.entities:
        latest_entities[entity["entity"]] = entity["value"]

    # add new entities to memory
    ActionFalloutSlots._update_memory_with_new(latest_entities)
    # "forgetting procedure"
    forget_events = ActionFalloutSlots._countdown_memory(tracker)

    # this action is always followed by listen
    listen_action = domain.action_map[ACTION_LISTEN_NAME][1]
    tracker.trigger_follow_up_action(listen_action)

    clear_matches = [SlotSet("matches", None)]
    return clear_matches + forget_events
def test_missing_classes_filled_correctly(
        self, default_domain, trackers, tracker, featurizer):
    """Actions never seen in training must be predicted with probability 0.

    Every ``ActionExecuted`` in the training trackers is swapped for a
    random action from ``classes``. After training, every other action
    must score exactly 0 while the distribution still sums to 1.
    """
    policy = self.create_policy(featurizer=featurizer, cv=None)
    classes = [1, 3]

    def _relabelled(source):
        # copy ``source``, replacing each executed action by a random class
        copy = DialogueStateTracker(UserMessage.DEFAULT_SENDER_ID,
                                    default_domain.slots)
        for event in source.applied_events():
            if isinstance(event, ActionExecuted):
                replacement = default_domain.action_for_index(
                    np.random.choice(classes)).name()
                copy.update(ActionExecuted(replacement))
            else:
                copy.update(event)
        return copy

    new_trackers = [_relabelled(t) for t in trackers]
    policy.train(new_trackers, domain=default_domain)

    predicted_probabilities = policy.predict_action_probabilities(
        tracker, default_domain)

    assert len(predicted_probabilities) == default_domain.num_actions
    assert np.allclose(sum(predicted_probabilities), 1.0)
    for i, prob in enumerate(predicted_probabilities):
        if i in classes:
            assert prob >= 0.0
        else:
            assert prob == 0.0
def tracker_from_dialogue_file(filename: Text, domain: Domain = None):
    """Build a tracker by replaying the dialogue stored in ``filename``.

    ``domain`` provides the slots. When it is falsy, the domain at
    ``DEFAULT_DOMAIN_PATH`` is loaded instead.
    """
    dialogue = read_dialogue_file(filename)
    slot_domain = domain if domain else Domain.load(DEFAULT_DOMAIN_PATH)

    tracker = DialogueStateTracker(dialogue.name, slot_domain.slots)
    tracker.recreate_from_dialogue(dialogue)
    return tracker
def run(self, dispatcher, tracker: DialogueStateTracker, domain):
    """Recommend an input channel ("online" or "xml") for course data.

    "xml" is recommended when slot ``anz_kuerse`` (number of courses) is
    above 50 and slot ``change_freq`` is "täglich" (daily) or "wöchentlich"
    (weekly). Otherwise the recommendation is "online".

    Returns:
        A single ``SlotSet`` event for ``empfohlenes_kanal``.
    """
    anz_kurs = tracker.get_slot("anz_kuerse")
    change_freq = tracker.get_slot("change_freq")

    try:
        course_count = int(anz_kurs)
    except (TypeError, ValueError):
        # Slot unset (None) or not an integer. Fall back to the default
        # recommendation instead of crashing the action.
        course_count = 0

    eingabe_kanal = "online"
    if course_count > 50 and change_freq in ["täglich", "wöchentlich"]:
        eingabe_kanal = "xml"
    return [SlotSet(key="empfohlenes_kanal", value=eingabe_kanal)]
def tracker_from_dialogue_file(filename, domain=None):
    """Build a tracker by replaying the dialogue stored in ``filename``.

    Args:
        filename: path handed to ``read_dialogue_file``.
        domain: domain whose slots the tracker uses. When ``None``, the
            domain at ``DEFAULT_DOMAIN_PATH`` is loaded.

    Returns:
        A ``DialogueStateTracker`` with the dialogue's events replayed.
    """
    dialogue = read_dialogue_file(filename)
    if domain is None:
        domain = Domain.load(DEFAULT_DOMAIN_PATH)
    tracker = DialogueStateTracker(dialogue.name, domain.slots)
    tracker.recreate_from_dialogue(dialogue)
    return tracker
def tracker_from_dialogue_file(filename, domain=None):
    """Build a tracker by replaying the dialogue stored in ``filename``.

    Args:
        filename: path handed to ``read_dialogue_file``.
        domain: domain whose slots the tracker uses. When ``None``, the
            template domain at ``DEFAULT_DOMAIN_PATH`` is loaded.

    Returns:
        A ``DialogueStateTracker`` with the dialogue's events replayed.
    """
    dialogue = read_dialogue_file(filename)
    if domain is None:
        domain = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
    tracker = DialogueStateTracker(dialogue.name, domain.slots)
    tracker.recreate_from_dialogue(dialogue)
    return tracker
def log_bot_utterances_on_tracker(tracker: DialogueStateTracker,
                                  dispatcher: Dispatcher) -> None:
    """Move the dispatcher's pending bot messages onto the tracker.

    Each message in ``dispatcher.latest_bot_messages`` becomes a
    ``BotUttered`` event on ``tracker``, and the dispatcher's list is then
    reset to ``[]``. Nothing happens when there are no pending messages.
    """
    pending = dispatcher.latest_bot_messages
    if not pending:
        return

    for message in pending:
        utterance = BotUttered(text=message.text, data=message.data)
        logger.debug("Bot utterance '{}'".format(utterance))
        tracker.update(utterance)

    dispatcher.latest_bot_messages = []
def test_memorise_with_nlu(self, trained_policy, default_domain):
    """A dialogue containing NLU-parsed messages must be recalled."""
    dialogue = read_dialogue_file("data/test_dialogues/nlu_dialogue.json")

    tracker = DialogueStateTracker(dialogue.name, default_domain.slots)
    tracker.recreate_from_dialogue(dialogue)

    featurizer = trained_policy.featurizer
    states = featurizer.prediction_states([tracker], default_domain)[0]

    assert trained_policy.recall(states, tracker, default_domain) is not None
def test_memorise_with_nlu(self, trained_policy, default_domain):
    """A dialogue containing NLU-parsed messages must be recalled."""
    dialogue = read_dialogue_file("data/test_dialogues/nlu_dialogue.json")

    tracker = DialogueStateTracker(dialogue.name, default_domain.slots)
    tracker.recreate_from_dialogue(dialogue)

    featurizer = trained_policy.featurizer
    states = featurizer.prediction_states([tracker], default_domain)[0]

    assert trained_policy.recall(states, tracker, default_domain) is not None
def run(self, dispatcher, tracker: DialogueStateTracker, domain):
    """Derive the recommended procedure from channel and effort answers.

    * ``aufwand_ok == "ja"`` (effort accepted): the recommended channel.
    * ``aufwand_ok == "nein"`` with channel "xml": "extern".
    * anything else: "redaktion".

    Returns:
        A single ``SlotSet`` event for ``empfohlenes_verfahren``.
    """
    kanal = tracker.get_slot("empfohlenes_kanal")
    auf_ok = tracker.get_slot("aufwand_ok")

    if auf_ok == "ja":
        verfahren = kanal
    elif auf_ok == "nein" and kanal == "xml":
        verfahren = "extern"
    else:
        verfahren = "redaktion"

    return [SlotSet(key="empfohlenes_verfahren", value=verfahren)]
def get_user_tracker(user_id):
    """Rebuild the tracker for ``user_id`` from its pickled dialogue in Redis.

    Raises:
        KeyError: if Redis holds no dialogue for ``user_id``.
    """
    red = redis.StrictRedis()
    raw_dialogue = red.get(user_id)
    if raw_dialogue is None:
        # Previously this surfaced as an opaque TypeError from
        # ``_pickle.loads(None)``.
        raise KeyError("No stored dialogue for user '{}'".format(user_id))

    # SECURITY(review): unpickling runs arbitrary code if the Redis
    # contents can be influenced by an attacker. Only use with a trusted store.
    dialogue = _pickle.loads(raw_dialogue)

    domain = OntologyDomain.load(env_config['domain'])
    tracker = DialogueStateTracker(sender_id=dialogue.name,
                                   sys_goals=domain.sys_goals,
                                   usr_slots=domain.usr_slots,
                                   methods=domain.methods,
                                   processes=domain.processes)
    tracker.recreate_from_dialogue(dialogue)
    return tracker
def test_tracker_duplicate():
    """A replayed dialogue yields one prior tracker per action, plus one.

    The tracker is also duplicated for the action that would come next,
    which is not part of the recorded operations.
    """
    dialogue = read_dialogue_file(
        "data/test_dialogues/inform_no_change.json")
    tracker = DialogueStateTracker(dialogue.name, domain.slots)
    tracker.recreate_from_dialogue(dialogue)

    num_actions = sum(1 for event in dialogue.events
                      if isinstance(event, ActionExecuted))

    prior_trackers = list(tracker.generate_all_prior_trackers())
    assert len(prior_trackers) == num_actions + 1
def test_tracker_duplicate():
    """A replayed dialogue yields one prior tracker per action, plus one.

    The tracker is also duplicated for the action that would come next,
    which is not part of the recorded operations.
    """
    dialogue = read_dialogue_file(
        "data/test_dialogues/inform_no_change.json")
    tracker = DialogueStateTracker(dialogue.name, domain.slots)
    tracker.recreate_from_dialogue(dialogue)

    num_actions = sum(1 for event in dialogue.events
                      if isinstance(event, ActionExecuted))

    prior_trackers = list(tracker.generate_all_prior_trackers())
    assert len(prior_trackers) == num_actions + 1
def tracker_from_dialogue_file(filename, domain=None):
    """Build a topic-aware tracker by replaying the dialogue in ``filename``.

    Topics set by ``TopicSet`` events in the dialogue are appended to
    ``domain.topics`` before the tracker is created. When ``domain`` is
    ``None``, the template domain at ``DEFAULT_DOMAIN_PATH`` is loaded.

    NOTE: this mutates ``domain.topics`` of the domain passed in. Calling it
    repeatedly with the same domain keeps extending that list.
    """
    dialogue = read_dialogue_file(filename)
    dialogue_topics = {Topic(event.topic)
                       for event in dialogue.events
                       if isinstance(event, TopicSet)}
    if domain is None:
        domain = TemplateDomain.load(DEFAULT_DOMAIN_PATH)
    domain.topics.extend(dialogue_topics)
    tracker = DialogueStateTracker(dialogue.name, domain.slots,
                                   domain.topics, domain.default_topic)
    tracker.recreate_from_dialogue(dialogue)
    return tracker
def predict_action_probabilities(self,
                                 tracker: DialogueStateTracker,
                                 domain: Domain) -> List[float]:
    """Predicts the next action if NLU confidence is low.

    Branches are checked in this order, and the first match wins:

    1. ``_is_user_input_expected`` -> ``action_listen``.
    2. the user denied the suggested intents -> ``_results_for_user_denied``.
    3. the user rephrased and ``should_fallback`` holds -> ask for
       affirmation.
    4. the user affirmed, or rephrased -> revert the fallback events.
    5. the last action asked for affirmation -> revert the fallback events
       if ``should_fallback`` is False, else ``fallback_action_name``.
    6. ``should_fallback`` holds -> ask for affirmation.
    7. otherwise -> ``fallback_scores(domain, core_threshold)``.

    Raises:
        InvalidDomain: if the affirm or the deny intent is missing from
            the domain.
    """
    if (USER_INTENT_AFFIRM not in domain.intents or
            USER_INTENT_DENY not in domain.intents):
        # the message used to name the affirm intent twice
        raise InvalidDomain('The intents {} and {} must be present in the '
                            'domain file to use this policy.'.format(
                                USER_INTENT_AFFIRM, USER_INTENT_DENY))

    nlu_data = tracker.latest_message.parse_data
    nlu_confidence = nlu_data["intent"].get("confidence", 1.0)
    last_intent_name = nlu_data['intent'].get('name', None)
    should_fallback = self.should_fallback(nlu_confidence,
                                           tracker.latest_action_name)
    user_rephrased = has_user_rephrased(tracker)

    if self._is_user_input_expected(tracker):
        result = confidence_scores_for(ACTION_LISTEN_NAME, FALLBACK_SCORE,
                                       domain)
    elif _has_user_denied(last_intent_name, tracker):
        logger.debug("User '{}' denied suggested intents.".format(
            tracker.sender_id))
        result = self._results_for_user_denied(tracker, domain)
    elif user_rephrased and should_fallback:
        logger.debug("Ambiguous rephrasing of user '{}' "
                     "for intent '{}'".format(tracker.sender_id,
                                              last_intent_name))
        result = confidence_scores_for(ACTION_DEFAULT_ASK_AFFIRMATION_NAME,
                                       FALLBACK_SCORE, domain)
    elif has_user_affirmed(last_intent_name, tracker) or user_rephrased:
        logger.debug("User '{}' affirmed intent by affirmation or "
                     "rephrasing.".format(tracker.sender_id))
        result = confidence_scores_for(ACTION_REVERT_FALLBACK_EVENTS_NAME,
                                       FALLBACK_SCORE, domain)
    elif tracker.last_executed_action_has(
            ACTION_DEFAULT_ASK_AFFIRMATION_NAME):
        if not should_fallback:
            logger.debug("User '{}' rephrased intent '{}' instead "
                         "of affirming.".format(tracker.sender_id,
                                                last_intent_name))
            result = confidence_scores_for(
                ACTION_REVERT_FALLBACK_EVENTS_NAME, FALLBACK_SCORE, domain)
        else:
            result = confidence_scores_for(self.fallback_action_name,
                                           FALLBACK_SCORE, domain)
    elif should_fallback:
        logger.debug("User '{}' has to affirm intent '{}'.".format(
            tracker.sender_id, last_intent_name))
        result = confidence_scores_for(ACTION_DEFAULT_ASK_AFFIRMATION_NAME,
                                       FALLBACK_SCORE, domain)
    else:
        result = self.fallback_scores(domain, self.core_threshold)

    return result
def predict_action_probabilities(self,
                                 tracker: DialogueStateTracker,
                                 domain: Domain) -> List[float]:
    """Predicts the assigned action.

    After ``action_listen``:

    * an action mapped via the intent's ``triggers`` property scores 1.
    * the restart intent triggers ``ACTION_RESTART_NAME``, scored 1.
    * the back intent triggers ``ACTION_BACK_NAME``, scored 1.

    After the mapped action itself, ``action_listen`` scores 1, but only if
    this policy was the one that predicted that action. Otherwise every
    action scores 0.
    """
    prediction = [0.0] * domain.num_actions
    intent = tracker.latest_message.intent.get('name')
    action = domain.intent_properties.get(intent, {}).get('triggers')

    if tracker.latest_action_name == ACTION_LISTEN_NAME:
        if action:
            idx = domain.index_for_action(action)
            if idx is None:
                # typo fixed: was "unkown"
                logger.warning("MappingPolicy tried to predict unknown "
                               "action '{}'.".format(action))
            else:
                prediction[idx] = 1
        elif intent == USER_INTENT_RESTART:
            idx = domain.index_for_action(ACTION_RESTART_NAME)
            prediction[idx] = 1
        elif intent == USER_INTENT_BACK:
            idx = domain.index_for_action(ACTION_BACK_NAME)
            prediction[idx] = 1
    elif tracker.latest_action_name == action and action is not None:
        latest_action = tracker.get_last_event_for(ActionExecuted)
        assert latest_action.action_name == action
        if latest_action.policy == type(self).__name__:
            # this ensures that we only predict listen, if we predicted
            # the mapped action
            idx = domain.index_for_action(ACTION_LISTEN_NAME)
            prediction[idx] = 1

    return prediction
def retrieve(self, sender_id):
    """Load the tracker for ``sender_id`` from MongoDB.

    A conversation stored under the ``int`` form of a numeric sender id is
    rewritten in place to use the string id. ``None`` is returned when
    nothing is stored, or (with a warning) when this store has no domain.
    """
    stored = self.conversations.find_one({"sender_id": sender_id})

    # look for conversations which have used an `int` sender_id in the past
    # and update them.
    if stored is None and sender_id.isdigit():
        from pymongo import ReturnDocument
        stored = self.conversations.find_one_and_update(
            {"sender_id": int(sender_id)},
            {"$set": {"sender_id": str(sender_id)}},
            return_document=ReturnDocument.AFTER)

    if stored is None:
        return None

    if not self.domain:
        logger.warning("Can't recreate tracker from mongo storage "
                       "because no domain is set. Returning `None` "
                       "instead.")
        return None

    return DialogueStateTracker.from_dict(sender_id,
                                          stored.get("events"),
                                          self.domain.slots)
async def continue_training(request: Request):
    """Continue training the agent's policies on the posted events.

    Query parameters ``epochs`` (default 30) and ``batch_size`` (default 5)
    must be integers. The JSON body is a list of events, replayed onto a
    tracker for the default sender.

    Responds 204 on success, 400 for invalid parameters or events, and 500
    if training fails.
    """
    try:
        # query-string values arrive as strings; training needs ints
        epochs = int(request.raw_args.get("epochs", 30))
        batch_size = int(request.raw_args.get("batch_size", 5))
    except (TypeError, ValueError) as e:
        raise ErrorResponse(400, "InvalidParameter",
                            "Parameters 'epochs' and 'batch_size' must be "
                            "integers. {}".format(e),
                            {"parameter": "epochs, batch_size",
                             "in": "query"})

    request_params = request.json
    sender_id = UserMessage.DEFAULT_SENDER_ID

    try:
        tracker = DialogueStateTracker.from_dict(sender_id,
                                                 request_params,
                                                 app.agent.domain.slots)
    except Exception as e:
        raise ErrorResponse(400, "InvalidParameter",
                            "Supplied events are not valid. {}".format(e),
                            {"parameter": "", "in": "body"})

    try:
        app.agent.continue_training([tracker],
                                    epochs=epochs,
                                    batch_size=batch_size)
        return response.text('', 204)
    except Exception as e:
        logger.exception("Caught an exception during training.")
        raise ErrorResponse(500, "TrainingException",
                            "Server failure. Error: {}".format(e))
def continue_training():
    """Continue training the agent's policies on the posted events.

    Query parameters ``epochs`` (default 30) and ``batch_size`` (default 5)
    are read as integers. The JSON body is a list of events, replayed onto
    a tracker for the default sender.

    Responds 204 on success, 400 for invalid events, and 500 if training
    fails.
    """
    # Query args are strings. ``type=int`` converts them and falls back to
    # the default when the value is not a valid integer.
    epochs = request.args.get("epochs", 30, type=int)
    batch_size = request.args.get("batch_size", 5, type=int)
    request_params = request.get_json(force=True)
    sender_id = UserMessage.DEFAULT_SENDER_ID

    try:
        tracker = DialogueStateTracker.from_dict(sender_id,
                                                 request_params,
                                                 agent.domain.slots)
    except Exception as e:
        return error(400, "InvalidParameter",
                     "Supplied events are not valid. {}".format(e),
                     {"parameter": "", "in": "body"})

    try:
        agent.continue_training([tracker],
                                epochs=epochs,
                                batch_size=batch_size)
        return '', 204
    except Exception as e:
        logger.exception("Caught an exception during training.")
        return error(500, "TrainingException",
                     "Server failure. Error: {}".format(e))
def _is_in_training_data(tracker, agent, fail_on_prediction_errors=False):
    """Check whether every action in ``tracker`` is predicted by memoization.

    The tracker's events are replayed onto a fresh tracker. At each
    ``ActionExecuted`` the agent's prediction is collected. The result is
    ``False`` if any prediction came from a policy for which
    ``SimplePolicyEnsemble.is_not_memo_policy`` is true, else ``True``.
    """
    processor = agent.create_processor()
    events = list(tracker.events)

    partial_tracker = DialogueStateTracker.from_events(tracker.sender_id,
                                                       events[:1],
                                                       agent.domain.slots)
    in_training_data = True

    for event in events[1:]:
        if isinstance(event, ActionExecuted):
            # Presumably this helper also applies the action (or its
            # misprediction) to ``partial_tracker``. The other replay loops
            # in this module only update the tracker for non-action events.
            _, policy = _collect_action_executed_predictions(
                processor, partial_tracker, event, fail_on_prediction_errors)
            if (in_training_data and policy is not None and
                    SimplePolicyEnsemble.is_not_memo_policy(policy)):
                in_training_data = False
        else:
            # Replay user messages, slot sets etc. Without this, later
            # predictions were made on a tracker missing that state.
            partial_tracker.update(event)

    return in_training_data
def _predict_tracker_actions(tracker, agent, fail_on_prediction_errors=False):
    """Replay ``tracker`` and compare predicted with recorded actions.

    Returns:
        ``(golds, predictions, partial_tracker)``: the recorded action
        names, the predicted action names in the same order, and the
        replayed tracker. In that tracker, mispredictions appear as
        ``WronglyPredictedAction`` events.

    Raises:
        ValueError: on the first misprediction when
            ``fail_on_prediction_errors`` is set.
    """
    processor = agent.create_processor()
    golds, predictions = [], []

    events = list(tracker.events)
    partial_tracker = DialogueStateTracker.from_events(tracker.sender_id,
                                                       events[:1],
                                                       agent.domain.slots)

    for event in events[1:]:
        if not isinstance(event, ActionExecuted):
            partial_tracker.update(event)
            continue

        action, _, _ = processor.predict_next_action(partial_tracker)
        predicted = action.name()
        gold = event.action_name
        predictions.append(predicted)
        golds.append(gold)

        if predicted == gold:
            partial_tracker.update(event)
            continue

        partial_tracker.update(WronglyPredictedAction(gold, predicted))
        if fail_on_prediction_errors:
            raise ValueError(
                "Model predicted a wrong action. Failed Story: "
                "\n\n{}".format(partial_tracker.export_stories()))

    return golds, predictions, partial_tracker
def test_restart(default_dispatcher, default_domain):
    """Running ``ActionRestart`` yields exactly one ``Restarted`` event."""
    domain = default_domain
    tracker = DialogueStateTracker("default", domain.slots,
                                   domain.topics, domain.default_topic)

    produced = ActionRestart().run(default_dispatcher, tracker, domain)
    assert produced == [Restarted()]
def _predict_tracker_actions(tracker, agent, fail_on_prediction_errors=False,
                             use_e2e=False):
    """Replay ``tracker`` and collect per-event evaluation results.

    ``ActionExecuted`` events are evaluated with
    ``_collect_action_executed_predictions``. When ``use_e2e`` is set,
    ``UserUttered`` events are evaluated with
    ``_collect_user_uttered_predictions``. All other events are applied
    directly to the partial tracker.

    Returns:
        ``(tracker_eval_store, partial_tracker)``.
    """
    processor = agent.create_processor()
    tracker_eval_store = EvaluationStore()

    events = list(tracker.events)
    partial_tracker = DialogueStateTracker.from_events(tracker.sender_id,
                                                       events[:1],
                                                       agent.domain.slots)

    for event in events[1:]:
        if isinstance(event, ActionExecuted):
            event_result = _collect_action_executed_predictions(
                processor, partial_tracker, event, fail_on_prediction_errors)
        elif use_e2e and isinstance(event, UserUttered):
            event_result = _collect_user_uttered_predictions(
                event, partial_tracker, fail_on_prediction_errors)
        else:
            partial_tracker.update(event)
            continue
        tracker_eval_store.merge_store(event_result)

    return tracker_eval_store, partial_tracker
def tracker_predict():
    """Score every action for the tracker described by the posted events.

    The JSON body is a list of events, replayed onto a tracker for the
    default sender. Responds 400 if the events cannot be deserialised.
    Otherwise it returns the per-action scores, the policy that produced
    them, and the tracker state.
    """
    sender_id = UserMessage.DEFAULT_SENDER_ID
    request_params = request.get_json(force=True)
    verbosity = event_verbosity_parameter(EventVerbosity.AFTER_RESTART)

    try:
        tracker = DialogueStateTracker.from_dict(sender_id,
                                                 request_params,
                                                 agent.domain.slots)
    except Exception as e:
        return error(400, "InvalidParameter",
                     "Supplied events are not valid. {}".format(e),
                     {"parameter": "", "in": "body"})

    ensemble = agent.policy_ensemble
    probabilities, policy = ensemble.probabilities_using_best_policy(
        tracker, agent.domain)

    scores = []
    for action_name, probability in zip(agent.domain.action_names,
                                        probabilities):
        scores.append({"action": action_name, "score": probability})

    return jsonify({
        "scores": scores,
        "policy": policy,
        "tracker": tracker.current_state(verbosity)
    })
def test_remote_action_logs_events(default_dispatcher_collecting,
                                   default_domain):
    """``RemoteAction`` posts the tracker and applies the returned payload.

    Checks the request body sent to the action endpoint, the returned
    events, and the messages sent to the output channel.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)

    endpoint = EndpointConfig("https://abc.defg/webhooks/actions")
    remote_action = action.RemoteAction("my_action", endpoint)

    response = {
        "events": [
            {"event": "slot", "value": "rasa", "name": "name"}],
        "responses": [{"text": "test text",
                       "buttons": [{"title": "cheap", "payload": "cheap"}]},
                      {"template": "utter_greet"}]}

    httpretty.register_uri(
        httpretty.POST,
        'https://abc.defg/webhooks/actions',
        body=json.dumps(response))

    httpretty.enable()
    try:
        events = remote_action.run(default_dispatcher_collecting,
                                   tracker,
                                   default_domain)
    finally:
        # Always restore real sockets, even if ``run`` raises. Otherwise
        # later tests would still be intercepted by httpretty.
        httpretty.disable()

    assert (httpretty.latest_requests[-1].path ==
            "/webhooks/actions")

    b = httpretty.latest_requests[-1].body.decode("utf-8")

    assert json.loads(b) == {
        'domain': default_domain.as_dict(),
        'next_action': 'my_action',
        'sender_id': 'default',
        'tracker': {
            'latest_message': {
                'entities': [],
                'intent': {},
                'text': None
            },
            'sender_id': 'default',
            'paused': False,
            'followup_action': 'action_listen',
            'latest_event_time': None,
            'slots': {'name': None},
            'events': [],
            'latest_input_channel': None
        }
    }

    assert events == [SlotSet("name", "rasa")]

    channel = default_dispatcher_collecting.output_channel
    assert channel.messages == [
        {"text": "test text", "recipient_id": "my-sender",
         "buttons": [{"title": "cheap", "payload": "cheap"}]},
        {"text": "hey there None!", "recipient_id": "my-sender"}]
def _is_reminder_still_valid(tracker: DialogueStateTracker,
                             reminder_event: ReminderScheduled) -> bool:
    """Check if the conversation has been restarted after reminder.

    Returns ``True`` while an event matching the reminder's name is still
    among the tracker's applied events. If it is gone, the conversation
    has been restarted since, and ``False`` is returned.
    """
    name = reminder_event.name
    return any(MessageProcessor._is_reminder(e, name)
               for e in reversed(tracker.applied_events()))
def test_persist_and_load(self, trained_policy, default_domain, tmpdir):
    """A persisted and reloaded policy predicts exactly like the original."""
    model_dir = tmpdir.strpath
    trained_policy.persist(model_dir)
    loaded = type(trained_policy).load(model_dir,
                                       trained_policy.featurizer,
                                       trained_policy.max_history)

    stories = extract_stories_from_file(DEFAULT_STORIES_FILE, default_domain)
    for story in stories:
        tracker = DialogueStateTracker("default", default_domain.slots)
        tracker.recreate_from_dialogue(
            story.as_dialogue("default", default_domain))

        from_loaded = loaded.predict_action_probabilities(
            tracker, default_domain)
        from_trained = trained_policy.predict_action_probabilities(
            tracker, default_domain)
        assert from_loaded == from_trained
def init_tracker(self, sender_id):
    """Create an empty tracker for ``sender_id``.

    Returns ``None`` when this store has no domain to take slots from.
    """
    if not self.domain:
        return None
    return DialogueStateTracker(sender_id,
                                self.domain.slots,
                                max_event_history=self.max_event_history)
def test_prediction_on_empty_tracker(self, trained_policy, default_domain):
    """A fresh tracker gets one probability in [0, 1] per domain action."""
    empty = DialogueStateTracker(UserMessage.DEFAULT_SENDER_ID,
                                 default_domain.slots)
    probs = trained_policy.predict_action_probabilities(empty,
                                                        default_domain)

    assert len(probs) == default_domain.num_actions
    highest, lowest = max(probs), min(probs)
    assert highest <= 1.0
    assert lowest >= 0.0
def generate(self,
             template_name: Text,
             tracker: DialogueStateTracker,
             output_channel: Text,
             **kwargs: Any
             ) -> Optional[Dict[Text, Any]]:
    """Generate a response for the requested template.

    The tracker's current slot values are used as the fill-in values.
    Everything else, including extra keyword arguments, is forwarded to
    ``generate_from_slots``.
    """
    slot_values = tracker.current_slot_values()
    return self.generate_from_slots(template_name, slot_values,
                                    output_channel, **kwargs)
def load_tracker_from_json(tracker_dump, domain):
    # type: (Text, Domain) -> DialogueStateTracker
    """Read the json dump from the file and instantiate a tracker from it.

    A missing ``sender_id`` defaults to ``UserMessage.DEFAULT_SENDER_ID``
    and missing ``events`` to ``[]``.
    """
    dump = json.loads(utils.read_file(tracker_dump))
    sender_id = dump.get("sender_id", UserMessage.DEFAULT_SENDER_ID)
    events = dump.get("events", [])
    # NOTE(review): ``domain`` itself (not ``domain.slots``) is passed here,
    # matching a ``from_dict(sender_id, events, domain)`` signature. Confirm
    # against the DialogueStateTracker version in use.
    return DialogueStateTracker.from_dict(sender_id, events, domain)
def replace_events(sender_id):
    """Use a list of events to set a conversations tracker to a state.

    The JSON body is a list of events that replaces the stored tracker of
    ``sender_id`` entirely. Responds 400 if the events cannot be
    deserialised, otherwise with the new tracker state.
    """
    request_params = request.get_json(force=True)
    try:
        tracker = DialogueStateTracker.from_dict(sender_id,
                                                 request_params,
                                                 agent.domain.slots)
    except Exception as e:
        # Malformed events are a client error (previously an unhandled 500),
        # reported the same way as in tracker_predict.
        return error(400, "InvalidParameter",
                     "Supplied events are not valid. {}".format(e),
                     {"parameter": "", "in": "body"})

    # will override an existing tracker with the same id!
    agent.tracker_store.save(tracker)
    return jsonify(tracker.current_state(EventVerbosity.AFTER_RESTART))
def test_action():
    """Run ``QuoraSearch`` against ``domain.yml`` (only checks it doesn't raise)."""
    domain = Domain.load('domain.yml')
    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    dispatcher = Dispatcher("my-sender", CollectingOutputChannel(), nlg)

    tracker = DialogueStateTracker(str(uuid.uuid1()), domain.slots)
    QuoraSearch().run(dispatcher, tracker, domain)
def load_tracker_from_json(tracker_dump: Text,
                           domain: Domain) -> DialogueStateTracker:
    """Read the json dump from the file and instantiate a tracker from it.

    A missing ``sender_id`` defaults to ``UserMessage.DEFAULT_SENDER_ID``
    and missing ``events`` to ``[]``. The tracker uses ``domain.slots``.
    """
    dump = json.loads(utils.read_file(tracker_dump))
    return DialogueStateTracker.from_dict(
        dump.get("sender_id", UserMessage.DEFAULT_SENDER_ID),
        dump.get("events", []),
        domain.slots)
def replace_events(sender_id):
    """Use a list of events to set a conversations tracker to a state.

    The JSON body replaces the stored tracker of ``sender_id``. The
    response is the new tracker state including its events.
    """
    events = request.get_json(force=True)
    tracker = DialogueStateTracker.from_dict(sender_id, events,
                                             agent.domain.slots)

    # will override an existing tracker with the same id!
    agent.tracker_store.save(tracker)

    state = tracker.current_state(should_include_events=True)
    return jsonify(state)
def generate_response(nlg_call, domain):
    """Render the template requested in ``nlg_call``.

    ``nlg_call`` is a dict with ``template``, optional ``arguments``
    (forwarded as keyword arguments), ``channel``, and a ``tracker`` dict
    holding ``sender_id`` and ``events``.
    """
    tracker_data = nlg_call.get("tracker", {})
    tracker = DialogueStateTracker.from_dict(tracker_data.get("sender_id"),
                                             tracker_data.get("events"),
                                             domain.slots)

    nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    return nlg.generate(nlg_call.get("template"),
                        tracker,
                        nlg_call.get("channel"),
                        **nlg_call.get("arguments", {}))
def retrieve(self, sender_id):
    """Load the tracker for ``sender_id`` from MongoDB, or return ``None``.

    ``None`` is returned when nothing is stored for the sender, and also
    (with a warning) when this store has no domain to take slots from.
    """
    stored = self.conversations.find_one({"sender_id": sender_id})
    if stored is None:
        return None

    if not self.domain:
        logger.warning("Can't recreate tracker from mongo storage "
                       "because no domain is set. Returning `None` "
                       "instead.")
        return None

    return DialogueStateTracker.from_dict(sender_id,
                                          stored.get("events"),
                                          self.domain.slots)
def tracker(self,
            sender_id,  # type: Text
            domain,  # type: Domain
            only_events_after_latest_restart=False,  # type: bool
            include_events=True,  # type: bool
            until=None  # type: Optional[int]
            ):
    """Retrieve and recreate a tracker fetched from the remote instance.

    The query options are forwarded to ``tracker_json``. Only the
    ``events`` of the returned JSON (``[]`` if absent) are replayed.
    """
    dump = self.tracker_json(sender_id,
                             only_events_after_latest_restart,
                             include_events,
                             until)
    # NOTE(review): ``domain`` itself is passed, not ``domain.slots``. This
    # matches a ``from_dict(sender_id, events, domain)`` signature; confirm.
    return DialogueStateTracker.from_dict(sender_id,
                                          dump.get("events", []),
                                          domain)
def tracker_predict():
    """Given a list of events, predict the probability of every action.

    Responds 400 if any posted event lacks an ``event`` key. Otherwise it
    returns a mapping from action name to probability.
    """
    sender_id = UserMessage.DEFAULT_SENDER_ID
    request_params = request.get_json(force=True)

    # every entry must name its event type
    if any(param.get('event', None) is None for param in request_params):
        return Response("Invalid list of events provided.", status=400)

    tracker = DialogueStateTracker.from_dict(sender_id,
                                             request_params,
                                             agent.domain.slots)
    ensemble = agent.policy_ensemble
    probabilities = ensemble.probabilities_using_best_policy(tracker,
                                                             agent.domain)

    probability_dict = {}
    for idx, probability in enumerate(probabilities):
        action_obj = agent.domain.action_for_index(idx, agent.action_endpoint)
        probability_dict[action_obj.name()] = probability

    return jsonify(probability_dict)
def test_tracker_entity_retrieval(default_domain):
    """``get_latest_entity_values`` reports entities of the latest message."""
    tracker = DialogueStateTracker("default", default_domain.slots)

    # the retrieved tracker should be empty
    assert len(tracker.events) == 0
    assert list(tracker.get_latest_entity_values("entity_name")) == []

    greet_entity = {
        "start": 1,
        "end": 5,
        "value": "greet",
        "entity": "entity_name",
        "extractor": "manual"
    }
    greet_intent = {"name": "greet", "confidence": 1.0}
    tracker.update(UserUttered("/greet", greet_intent, [greet_entity]))

    assert list(tracker.get_latest_entity_values("entity_name")) == ["greet"]
    assert list(tracker.get_latest_entity_values("unknown")) == []
def continue_training():
    """Continue training the agent's policies on the posted events.

    Query parameters ``epochs`` (default 30) and ``batch_size`` (default 5)
    are read as integers. The JSON body is a list of events, replayed onto
    a tracker for the default sender.

    Responds 204 on success, 400 for invalid events, and 500 with a JSON
    ``error`` message if training fails.
    """
    # Query args are strings. ``type=int`` converts them and falls back to
    # the default when the value is not a valid integer.
    epochs = request.args.get("epochs", 30, type=int)
    batch_size = request.args.get("batch_size", 5, type=int)
    request_params = request.get_json(force=True)

    try:
        tracker = DialogueStateTracker.from_dict(
            UserMessage.DEFAULT_SENDER_ID,
            request_params,
            agent.domain.slots)
    except Exception as e:
        return jsonify(error="Supplied events are not valid. {}"
                             "".format(e)), 400

    try:
        agent.continue_training([tracker],
                                epochs=epochs,
                                batch_size=batch_size)
        return '', 204
    except Exception as e:
        logger.exception("Caught an exception during training.")
        # ``jsonify`` already builds a JSON response; return it with the
        # status instead of wrapping it in another ``Response``.
        return jsonify(error="Server failure. Error: {}"
                             "".format(e)), 500
def test_traveling_back_in_time(default_domain):
    """``travel_back_in_time`` drops every event logged after the cut-off."""
    import time

    tracker = DialogueStateTracker("default", default_domain.slots)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    greet = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", greet, []))

    # leave a clear gap on both sides of the cut-off timestamp
    time.sleep(1)
    cutoff = time.time()
    time.sleep(1)

    tracker.update(ActionExecuted("my_action"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    # 3 executed actions + 1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(tracker.events) == 4
    assert len(list(tracker.generate_all_prior_trackers())) == 4

    tracker = tracker.travel_back_in_time(cutoff)

    # 1 executed action + 1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(tracker.events) == 2
    assert len(list(tracker.generate_all_prior_trackers())) == 2
def test_revert_action_event(default_domain):
    """``ActionReverted`` undoes the latest action, also after a round trip.

    The tracker is rebuilt from ``as_dialogue()``, and the rebuilt copy must
    show the same reverted state as the original.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    intent = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent, []))
    tracker.update(ActionExecuted("my_action"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    # Expecting count of 4:
    #   +3 executed actions
    #   +1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(list(tracker.generate_all_prior_trackers())) == 4

    tracker.update(ActionReverted())

    # Expecting count of 3:
    #   +3 executed actions
    #   +1 final state
    #   -1 reverted action
    assert tracker.latest_action_name == "my_action"
    assert len(list(tracker.generate_all_prior_trackers())) == 3

    dialogue = tracker.as_dialogue()
    recovered = DialogueStateTracker("default", default_domain.slots)
    recovered.recreate_from_dialogue(dialogue)

    assert recovered.current_state() == tracker.current_state()
    # Check the rebuilt tracker. These asserts previously re-checked
    # ``tracker``, duplicating the assertions above.
    assert recovered.latest_action_name == "my_action"
    assert len(list(recovered.generate_all_prior_trackers())) == 3
def test_revert_user_utterance_event(default_domain):
    """``UserUtteranceReverted`` rewinds the latest turn, also after a round trip.

    The tracker is rebuilt from ``as_dialogue()``, and the rebuilt copy must
    show the same rewound state as the original.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    intent1 = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent1, []))
    tracker.update(ActionExecuted("my_action_1"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    intent2 = {"name": "goodbye", "confidence": 1.0}
    tracker.update(UserUttered("/goodbye", intent2, []))
    tracker.update(ActionExecuted("my_action_2"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    # Expecting count of 6:
    #   +5 executed actions
    #   +1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(list(tracker.generate_all_prior_trackers())) == 6

    tracker.update(UserUtteranceReverted())

    # Expecting count of 3:
    #   +5 executed actions
    #   +1 final state
    #   -2 rewound actions associated with the /goodbye
    #   -1 rewound action from the listen right before /goodbye
    assert tracker.latest_action_name == "my_action_1"
    assert len(list(tracker.generate_all_prior_trackers())) == 3

    dialogue = tracker.as_dialogue()
    recovered = DialogueStateTracker("default", default_domain.slots)
    recovered.recreate_from_dialogue(dialogue)

    assert recovered.current_state() == tracker.current_state()
    # Check the rebuilt tracker. These asserts previously re-checked
    # ``tracker``, duplicating the assertions above.
    assert recovered.latest_action_name == "my_action_1"
    assert len(list(recovered.generate_all_prior_trackers())) == 3
def test_restart_event(default_domain):
    """``Restarted`` resets the dialogue state, also after a round trip."""

    def _assert_listen_follow_up(t):
        # a restart queues action_listen as the follow-up action
        assert t.follow_up_action is not None
        assert t.follow_up_action.name() == ACTION_LISTEN_NAME

    tracker = DialogueStateTracker("default", default_domain.slots)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    greet = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", greet, []))
    tracker.update(ActionExecuted("my_action"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    assert len(tracker.events) == 4
    assert tracker.latest_message.text == "/greet"
    assert len(list(tracker.generate_all_prior_trackers())) == 4

    tracker.update(Restarted())

    # the event log keeps growing, but the derived state is wiped
    assert len(tracker.events) == 5
    _assert_listen_follow_up(tracker)
    assert tracker.latest_message.text is None
    assert len(list(tracker.generate_all_prior_trackers())) == 1

    recovered = DialogueStateTracker("default", default_domain.slots)
    recovered.recreate_from_dialogue(tracker.as_dialogue())

    assert recovered.current_state() == tracker.current_state()
    assert len(recovered.events) == 5
    _assert_listen_follow_up(tracker)
    assert recovered.latest_message.text is None
    assert len(list(recovered.generate_all_prior_trackers())) == 1