def test_revert_action_event(default_domain):
    """Reverting the last action rolls the tracker back by one action.

    After ``ActionReverted`` the latest action name goes back to
    ``"my_action"`` and one fewer prior tracker is generated. A tracker
    recreated from the reverted tracker's dialogue must end up in the same
    state.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    intent = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent, []))
    tracker.update(ActionExecuted("my_action"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    # Expecting count of 4:
    #   +3 executed actions
    #   +1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(list(tracker.generate_all_prior_trackers())) == 4

    tracker.update(ActionReverted())

    # Expecting count of 3:
    #   +3 executed actions
    #   +1 final state
    #   -1 reverted action
    assert tracker.latest_action_name == "my_action"
    assert len(list(tracker.generate_all_prior_trackers())) == 3

    dialogue = tracker.as_dialogue()

    recovered = DialogueStateTracker("default", default_domain.slots)
    recovered.recreate_from_dialogue(dialogue)

    assert recovered.current_state() == tracker.current_state()
    # Assert on the *recovered* tracker: re-checking ``tracker`` here would
    # only repeat the assertions above and never exercise the round trip.
    assert recovered.latest_action_name == "my_action"
    assert len(list(recovered.generate_all_prior_trackers())) == 3
def test_restart_event(default_domain):
    """A ``Restarted`` event resets the conversation but keeps the event log.

    After the restart the tracker:
    - still holds every event, including the restart itself;
    - schedules ``ACTION_LISTEN_NAME`` as the follow-up action;
    - has no latest message text;
    - yields a single prior tracker.

    A tracker recreated from its dialogue must match all of this.
    """

    def prior_count(some_tracker):
        # number of trackers yielded by the prior-tracker generator
        return sum(1 for _ in some_tracker.generate_all_prior_trackers())

    tracker = DialogueStateTracker("default", default_domain.slots)
    # a freshly created tracker holds no events yet
    assert len(tracker.events) == 0

    greet = {"name": "greet", "confidence": 1.0}
    for event in (
        ActionExecuted(ACTION_LISTEN_NAME),
        UserUttered("/greet", greet, []),
        ActionExecuted("my_action"),
        ActionExecuted(ACTION_LISTEN_NAME),
    ):
        tracker.update(event)

    assert len(tracker.events) == 4
    assert tracker.latest_message.text == "/greet"
    assert prior_count(tracker) == 4

    tracker.update(Restarted())

    assert len(tracker.events) == 5
    followup = tracker.followup_action
    assert followup is not None
    assert followup == ACTION_LISTEN_NAME
    assert tracker.latest_message.text is None
    assert prior_count(tracker) == 1

    dialogue = tracker.as_dialogue()
    recovered = DialogueStateTracker("default", default_domain.slots)
    recovered.recreate_from_dialogue(dialogue)

    assert recovered.current_state() == tracker.current_state()
    assert len(recovered.events) == 5
    assert recovered.latest_message.text is None
    assert prior_count(recovered) == 1
def tracker_from_dialogue_file(filename: Text, domain: Domain = None):
    """Build a ``DialogueStateTracker`` by replaying a stored dialogue.

    :param filename: path of the dialogue file passed to ``read_dialogue_file``.
    :param domain: domain whose slots the tracker is created with. When falsy,
        the domain at ``DEFAULT_DOMAIN_PATH`` is loaded instead.
    :return: a tracker that has replayed every event of the dialogue.
    """
    dialogue = read_dialogue_file(filename)
    slots = (domain or Domain.load(DEFAULT_DOMAIN_PATH)).slots

    replayed = DialogueStateTracker(dialogue.name, slots)
    replayed.recreate_from_dialogue(dialogue)
    return replayed
def test_memorise_with_nlu(self, trained_policy, default_domain):
    """The trained policy recalls something for the default test dialogue.

    Rebuilds a tracker from ``data/test_dialogues/default.json``, featurizes
    its prediction states with the policy's own featurizer, and checks that
    ``recall`` returns a non-``None`` result for them.
    """
    dialogue = read_dialogue_file("data/test_dialogues/default.json")

    tracker = DialogueStateTracker(dialogue.name, default_domain.slots)
    tracker.recreate_from_dialogue(dialogue)

    featurizer = trained_policy.featurizer
    all_states = featurizer.prediction_states([tracker], default_domain)
    first_states = all_states[0]

    assert trained_policy.recall(first_states, tracker, default_domain) is not None
def tracker_from_dialogue_file(
    filename: Text, domain: Optional[Domain] = None
) -> DialogueStateTracker:
    """Replay the dialogue stored in ``filename`` into a fresh tracker.

    :param filename: dialogue file read via ``read_dialogue_file``.
    :param domain: domain providing the tracker's slots. If falsy, the domain
        at ``DEFAULT_DOMAIN_PATH_WITH_SLOTS`` is loaded.
    :return: the tracker after ``recreate_from_dialogue`` has run.
    """
    dialogue = read_dialogue_file(filename)
    effective_domain = (
        domain if domain else Domain.load(DEFAULT_DOMAIN_PATH_WITH_SLOTS)
    )

    result = DialogueStateTracker(dialogue.name, effective_domain.slots)
    result.recreate_from_dialogue(dialogue)
    return result
def test_tracker_duplicate():
    """Replaying moodbot.json yields one prior tracker per action, plus one.

    Relies on the module-level ``domain`` for the tracker's slots.
    """
    dialogue = read_dialogue_file("data/test_dialogues/moodbot.json")

    tracker = DialogueStateTracker(dialogue.name, domain.slots)
    tracker.recreate_from_dialogue(dialogue)

    num_actions = sum(isinstance(evt, ActionExecuted) for evt in dialogue.events)

    # One prior tracker is generated per executed action, and one more for
    # the action that would come next (which is not among the events).
    prior_trackers = list(tracker.generate_all_prior_trackers())
    assert len(prior_trackers) == num_actions + 1
def test_tracker_duplicate():
    """Replay moodbot.json, visualise it, and check the prior-tracker count.

    Relies on the module-level ``domain``, ``PRJ_DIR``, ``viz_events`` and
    ``viz_trackers`` defined elsewhere in the file.
    """
    filename = "{}/data/test_dialogues/moodbot.json".format(PRJ_DIR)
    dialogue = read_dialogue_file(filename)
    tracker = DialogueStateTracker(dialogue.name, domain.slots)
    tracker.recreate_from_dialogue(dialogue)

    num_actions = sum(
        isinstance(event, ActionExecuted) for event in dialogue.events
    )

    viz_events(dialogue.events)

    # There is always one duplicated tracker more than we have actions,
    # as the tracker also gets duplicated for the
    # action that would be next (but isn't part of the operations).
    # Materialise the generator once and reuse it for the assert and the
    # visualisation instead of regenerating it.
    # NOTE(review): if generate_all_prior_trackers yields a single tracker
    # object that it keeps mutating, every list element is the same
    # final-state object, so the per-step visualisation would be misleading.
    # Confirm before relying on viz_trackers output.
    prior_trackers = list(tracker.generate_all_prior_trackers())
    assert len(prior_trackers) == num_actions + 1

    viz_trackers(prior_trackers)
def test_revert_user_utterance_event(default_domain):
    """Reverting a user utterance rewinds to before that utterance.

    After ``UserUtteranceReverted`` the latest action name goes back to
    ``"my_action_1"`` and three fewer prior trackers are generated. A tracker
    recreated from the rewound tracker's dialogue must end up in the same
    state.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    intent1 = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent1, []))
    tracker.update(ActionExecuted("my_action_1"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    intent2 = {"name": "goodbye", "confidence": 1.0}
    tracker.update(UserUttered("/goodbye", intent2, []))
    tracker.update(ActionExecuted("my_action_2"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    # Expecting count of 6:
    #   +5 executed actions
    #   +1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(list(tracker.generate_all_prior_trackers())) == 6

    tracker.update(UserUtteranceReverted())

    # Expecting count of 3:
    #   +5 executed actions
    #   +1 final state
    #   -2 rewound actions associated with the /goodbye
    #   -1 rewound action from the listen right before /goodbye
    assert tracker.latest_action_name == "my_action_1"
    assert len(list(tracker.generate_all_prior_trackers())) == 3

    dialogue = tracker.as_dialogue()

    recovered = DialogueStateTracker("default", default_domain.slots)
    recovered.recreate_from_dialogue(dialogue)

    assert recovered.current_state() == tracker.current_state()
    # Assert on the *recovered* tracker: re-checking ``tracker`` here would
    # only repeat the assertions above and never exercise the round trip.
    assert recovered.latest_action_name == "my_action_1"
    assert len(list(recovered.generate_all_prior_trackers())) == 3