예제 #1
0
def test_enum_tuple_state():
    """Element 1 of a tuple-built State equals States built from the same
    name given as a string or as the Enum member."""
    class MyStates(Enum):
        A = 0
        B = 1
    from_enum = State(MyStates.B)
    from_string = State('MyStates.B')
    from_tuple = State(('hello', 'MyStates.B'))
    second = from_tuple[1]
    assert second == from_string
    assert second == from_enum
예제 #2
0
def test_enum_tuple_state():
    """The second element of a tuple State compares equal to the string-built
    and the Enum-built State for the same member."""

    class MyStates(Enum):
        A = 0
        B = 1

    by_enum = State(MyStates.B)
    by_name = State("MyStates.B")
    pair = State(("hello", "MyStates.B"))
    for expected in (by_name, by_enum):
        assert pair[1] == expected
예제 #3
0
 def set_transition_natex(self, source, target, speaker, natex):
     """
     Store ``natex`` under the 'natex' key of the arc source -> target labelled ``speaker``.

     A plain string is compiled first, using this flow's macros: into a NatexNLU
     when ``speaker`` is Speaker.USER, otherwise into a NatexNLG. Any other value
     is stored unchanged.
     """
     source, target = module_source_target(source, target)
     arc = (State(source), State(target), speaker)
     if isinstance(natex, str):
         compiler = NatexNLU if speaker == Speaker.USER else NatexNLG
         natex = compiler(natex, macros=self._macros)
     self._graph.arc_data(*arc)['natex'] = natex
예제 #4
0
 def add_state(self, state: Union[Enum, str, tuple], error_successor: Union[Enum, str, tuple, None] = None, **settings):
     """
     Add a new node for ``state`` to the graph and initialise its settings.

     :param state: Enum member, string, or tuple identifying the state
     :param error_successor: optional state recorded via set_error_successor
     :param settings: overrides merged over the defaults
         user_multi_hop=False, system_multi_hop=False, switch=False, enter=None
     :raises ValueError: if a node for ``state`` already exists
     """
     node = State(module_state(state))
     if self.has_state(node):
         raise ValueError('state {} already exists'.format(node))
     merged = Settings(user_multi_hop=False, system_multi_hop=False, switch=False, enter=None)
     merged.update(**settings)
     self._graph.add_node(node)
     self.update_state_settings(node, **merged)
     if error_successor is None:
         return
     self.set_error_successor(node, State(error_successor))
예제 #5
0
def test_enum_string_state():
    """A State built from an Enum member and one built from its 'Class.MEMBER'
    string compare equal to each other, to the string, and to the member."""
    class MyStates(Enum):
        A = 0
        B = 1
    enum_state = State(MyStates.A)
    str_state = State('MyStates.A')
    for left, right in ((enum_state, str_state), (str_state, enum_state)):
        assert left == right
        assert left == 'MyStates.A'
        assert left == MyStates.A
예제 #6
0
 def error_successor(self, state):
     """
     Return the value stored under 'error' on the node for ``state``
     (as written by set_error_successor), or None when nothing is stored.
     """
     node_data = self._graph.data(State(module_state(state)))
     return node_data['error'] if 'error' in node_data else None
예제 #7
0
 def update_state_settings(self, state, **settings):
     """
     Merge keyword ``settings`` into the Settings stored on the node for ``state``.

     An empty Settings object is created on the node first if it has none.
     A 'global_nlu' entry is additionally registered through add_global_nlu,
     and a string 'enter' value is compiled into a NatexNLG (with this flow's
     macros) before merging.
     """
     node = State(module_state(state))
     if 'settings' not in self._graph.data(node):
         self._graph.data(node)['settings'] = Settings()
     if 'global_nlu' in settings:
         self.add_global_nlu(node, settings['global_nlu'])
     enter = settings.get('enter')
     if isinstance(enter, str):
         settings['enter'] = NatexNLG(enter, macros=self._macros)
     self.state_settings(node).update(**settings)
예제 #8
0
 def add_global_nlu(self, state, nlu, score=0.5, post_nlu=None):
     """
     Register an update rule whose action is a #TRANSITION to ``state``.

     The state is added first if it does not exist; a tuple state is flattened
     to a ':'-joined string. A list or set ``nlu`` is joined into a single
     ``{a, b, ...}`` natex. The rule key is ``'<nlu> (<score>)'`` and the action
     is ``#TRANSITION(<state>, <score>)``, with ``post_nlu`` appended as a third
     argument when it is given.
     """
     target = State(module_state(state))
     if not self.has_state(target):
         self.add_state(target)
     if isinstance(target, tuple):
         target = ':'.join(target)
     if isinstance(nlu, (list, set)):
         nlu = '{' + ', '.join(nlu) + '}'
     rule_key = '{} ({})'.format(nlu, score)
     if post_nlu is None:
         action = '#TRANSITION({}, {})'.format(target, score)
     else:
         action = '#TRANSITION({}, {}, {})'.format(target, score, post_nlu)
     self._rules.add(rule_key, action)
예제 #9
0
 def set_state(self, state: Union[Enum, str, tuple]):
     """
     Make ``state`` the current dialogue state (stored in ``__state__``).

     While the system is the speaker, the state being left is first copied into
     ``__system_state__``, unless its name starts with an underscore. For tuple
     states the name checked is the second element. ``__system_state__`` is set to
     the string 'None' when there is no previous ``__state__``, or when the copy
     is skipped and no value exists yet.
     """
     new_state = State(module_state(state))
     if self.speaker() == Speaker.SYSTEM:
         flow_vars = self.vars()
         if '__state__' not in flow_vars:
             flow_vars['__system_state__'] = 'None'
         else:
             previous = flow_vars['__state__']
             name = previous[1] if isinstance(previous, tuple) else previous
             if not name.startswith('_'):
                 flow_vars['__system_state__'] = previous
             if '__system_state__' not in flow_vars:
                 flow_vars['__system_state__'] = 'None'
     self._vars['__state__'] = new_state
예제 #10
0
 def add_system_transition(self, source: Union[Enum, str, tuple], target: Union[Enum, str, tuple],
                           natex_nlg: Union[str, NatexNLG, List[str]], **settings):
     """
     Add a SYSTEM-labelled arc from ``source`` to ``target`` carrying ``natex_nlg``.

     Missing endpoint states are created. Keyword ``settings`` are merged over
     the default transition settings (score=1.0). When the flow was built with
     all_multi_hop, the source state is marked system_multi_hop. If a prepend
     is registered for the target in ``self._prepends``, it is prefixed to the
     stored natex.

     :raises ValueError: if a system transition source -> target already exists
     """
     source, target = module_source_target(source, target)
     source, target = State(source), State(target)
     if self.has_transition(source, target, Speaker.SYSTEM):
         raise ValueError('system transition {} -> {} already exists'.format(source, target))
     natex_nlg = NatexNLG(natex_nlg, macros=self._macros)
     for endpoint in (source, target):
         if not self.has_state(endpoint):
             self.add_state(endpoint)
     self._graph.add_arc(source, target, Speaker.SYSTEM)
     self.set_transition_natex(source, target, Speaker.SYSTEM, natex_nlg)
     defaults = Settings(score=1.0)
     defaults.update(**settings)
     self.set_transition_settings(source, target, Speaker.SYSTEM, defaults)
     if self._all_multi_hop:
         self.update_state_settings(source, system_multi_hop=True)
     if target not in self._prepends:
         return
     prefix = self._prepends[target]
     current = self.transition_natex(source, target, Speaker.SYSTEM)
     self.set_transition_natex(source, target, Speaker.SYSTEM, prefix + current)
예제 #11
0
 def transitions(self, source_state, speaker=None):
     """
     Lazily yield the (source, target, speaker) arcs leaving ``source_state``.

     :param source_state: state whose outgoing arcs are listed (required)
     :param speaker: if given, only arcs labelled with this speaker are yielded;
         nothing is yielded when the state has no arc with that label
     :return: a generator over (source, target, speaker) 3-tuples
     """
     origin = State(module_state(source_state))
     if speaker is None:
         yield from self._graph.arcs_out(origin)
     elif self._graph.has_arc_label(origin, speaker):
         yield from self._graph.arcs_out(origin, label=speaker)
예제 #12
0
 def has_transition(self, source, target, speaker):
     """Return whether the graph has an arc source -> target labelled ``speaker``."""
     src, tgt = module_source_target(source, target)
     return self._graph.has_arc(State(src), State(tgt), speaker)
예제 #13
0
 def set_error_successor(self, state, error_successor):
     """Store ``error_successor`` under 'error' on the node for ``state``."""
     node, fallback = module_source_target(state, error_successor)
     node, fallback = State(node), State(fallback)
     self._graph.data(node)['error'] = fallback
예제 #14
0
 def __init__(self, initial_state: Union[Enum, str, tuple], initial_speaker=Speaker.SYSTEM,
              macros: Dict[str, Macro] = None, kb: Union[KnowledgeBase, str, List[str]] = None,
              default_system_state=None, end_state='__end__', all_multi_hop=True, wordnet=False):
     """
     Create a dialogue flow whose graph initially contains only ``end_state``.

     :param initial_state: state the flow starts in
     :param initial_speaker: who takes the first turn (Speaker.SYSTEM by default)
     :param macros: extra macros merged last, over the built-in and common tables
     :param kb: an existing KnowledgeBase, a JSON file path, or a list of JSON
         file paths loaded into a fresh KnowledgeBase; None gives an empty one
     :param default_system_state: state that system_transition falls back to
         when no system transition is valid
     :param end_state: name of the terminal state, added to the graph here
     :param all_multi_hop: when true, add_system_transition marks each source
         state system_multi_hop
     :param wordnet: forwarded to the WN macro
     """
     # Graph and turn-taking state.
     self._graph = GraphDatabase()
     self._initial_state = State(initial_state)
     self._potential_transition = None
     self._initial_speaker = initial_speaker
     self._speaker = initial_speaker
     # Variable store and pending-transition queues.
     self._vars = HashableDict()
     self._transitions = []
     self._update_transitions = []
     # __state__ is seeded before set_state so set_state finds a previous state.
     self.vars()['__state__'] = self._initial_state
     self.set_state(self._initial_state)
     # Bookkeeping for gates, prepends, goals, and module composition.
     self._gates = defaultdict(list)
     self._prepends = {}
     self._var_dependencies = defaultdict(set)
     self._error_transitioned = False
     self._default_state = default_system_state
     self._end_state = end_state
     self._goals = {}
     self._all_multi_hop = all_multi_hop
     self._composite_dialogue_flow = None
     self._namespace = None
     self.vars()['__stack__'] = []
     if initial_speaker == Speaker.USER:
         self.vars()['__system_state__'] = 'None'
     else:
         self.vars()['__system_state__'] = self._initial_state
     # Knowledge base: reuse a given instance, otherwise build one and load JSON files.
     if kb is None or isinstance(kb, (str, list)):
         self._kb = KnowledgeBase()
         json_files = [] if kb is None else [kb] if isinstance(kb, str) else kb
         for json_file in json_files:
             self._kb.load_json_file(json_file)
     else:
         self._kb = kb
     onte = ONTE(self._kb)
     kbe = KBE(self._kb)
     goal_exit_macro = GoalExit(self)
     self._macros = {
         'WN': WN(wordnet),
         # 'ONT' aliases ONTE and 'KBQ' aliases KBE (same macro objects).
         'ONT': onte,
         'ONTE': onte,
         'ONTUL': ONTUL(self._kb),
         'KBQ': kbe,
         'KBE': kbe,
         'ONTN': ONTN(self._kb),
         'EXP': EXP(self._kb),
         'ONT_NEG': ONT_NEG(self._kb),
         'FPP': FirstPersonPronoun(self._kb),
         'TPP': ThirdPersonPronoun(self._kb),
         'PSP': PossessivePronoun(self._kb),
         'GATE': Gate(self),
         'TRANSITION': Transition(self),
         'GOAL': GoalPursuit(goal_exit_macro, self),
         'GCOM': GoalCompletion(self),
         'GEXT': goal_exit_macro,
         'GSRET': SetGoalReturnPoint(),
         'GRET': GoalReturn(self),
         'GCLR': ClearGoalStack(),
         'VT': VirtualTransitions(self),
         'CE': CanEnter(self),
         'EXTR': ExtractList(self._kb),
     }
     # Later tables override earlier ones; user-supplied macros win.
     for overrides in (macros_common_dict, natex_macros_common, macros or {}):
         self._macros.update(overrides)
     self._rules = UpdateRules(vars=self._vars, macros=self._macros)
     self.add_state(end_state)
     self._vars['__user_utterance__'] = None
예제 #15
0
 def has_state(self, state):
     """Return whether the graph has a node for ``state`` (after module_state and State)."""
     return self._graph.has_node(State(module_state(state)))
예제 #16
0
 def remove_transition(self, source, target, speaker):
     """Remove the arc source -> target labelled ``speaker`` from the graph."""
     src, tgt = module_source_target(source, target)
     src, tgt = State(src), State(tgt)
     # NOTE(review): this calls MapMultidigraph.remove_arc directly on self.graph().
     # That bypasses any override on the graph's own class, while the sibling
     # methods go through self._graph. Confirm this is intentional.
     MapMultidigraph.remove_arc(self.graph(), src, tgt, speaker)
예제 #17
0
 def user_transition(self, natural_language: str, state: Union[Enum, str, tuple], debugging=False):
     """
     Choose the successor state for a user turn.

     Every USER transition out of ``state`` is matched against the utterance.
     Transitions queued in ``self._transitions`` are matched too, including
     ones queued while matching. A match becomes an option when the target's
     ``enter`` natex (if any) does not yield None and its gate is not closed.
     The option with the highest score is selected with ``random_max``, and
     its variable copy is committed through ``update_vars``.

     :param natural_language: raw user utterance; replaced by
         ``vars['__user_utterance__']`` when that is not None, otherwise
         lower-cased and stripped of everything but letters and spaces
     :param state: state to transition from; None means the current state
     :param debugging: print matching, timing and var-update diagnostics
     :return: the successor state representing the highest score user transition
              that matches natural_language; if none match, the error successor
              of the current state (self.state()), which may be None
     """
     # Drop any gate left over from a previous turn.
     if '__gate__' in self._vars:
         del self._vars['__gate__']
     if '__user_utterance__' in self.vars() and self.vars()['__user_utterance__'] is not None:
         natural_language = self.vars()['__user_utterance__']
     else:
         # Normalize: keep letters and spaces only, lower-cased.
         natural_language = ''.join([c.lower() for c in natural_language if c.isalpha() or c == ' '])
     state = module_state(state)
     self._error_transitioned = False
     ti = time()
     if state is None:
         state = self.state()
     else:
         state = State(state)
     transition_options = []
     transition_items = []
     for transition in self.transitions(state, Speaker.USER):
         natex = self.transition_natex(*transition)
         score = self.transition_settings(*transition).score
         transition_items.append((natex, transition, score))
     # Also evaluate transitions already queued in self._transitions
     # (presumably pushed by macros; the in-loop drain below suggests matching can add to it).
     while self._transitions:
         natex, transition, score = self._transitions.pop()
         transition_items.append((natex, transition, score))
     ngrams = Ngrams(natural_language, n=10)
     # transition_items may grow during this loop (see the drain at the end of the
     # body); iterating a list also visits entries appended mid-iteration.
     for natex, transition, score in transition_items:
         self._potential_transition = transition
         # Outside a module, transitions whose target is a tuple are skipped.
         if not self.is_module() and isinstance(transition[1], tuple):
             continue
         t1 = time()
         if debugging:
             print('Evaluating transition {}'.format(transition[:2]))
         # Match against a fresh HashableDict built from self._vars; only the
         # chosen option's dict is committed later via update_vars.
         vars = HashableDict(self._vars)
         try:
             match = natex.match(natural_language, vars, self._macros, ngrams, debugging)
         except Exception as e:
             print()
             print('Transition {}: {} failed'.format(str(transition), natex))
             traceback.print_exc(file=sys.stdout)
             print()
             match = None
         source, target, speaker = transition
         # Matching may redirect the transition through __source__/__target__.
         if '__source__' in vars:
             source = State(module_state(vars['__source__']))
             del vars['__source__']
         if '__target__' in vars:
             target = State(module_state(vars['__target__']))
             del vars['__target__']
         transition = source, target, speaker
         # A tuple target inside a module is resolved through the composite flow.
         if self.is_module() and isinstance(target, tuple):
             enter_natex = self.composite_dialogue_flow().state_settings(*target).enter
         else:
             enter_natex = self.state_settings(target).enter
         enter_natex_pass = True
         if enter_natex is not None:
             try:
                 enter_natex_pass = enter_natex.generate(vars=vars, macros=self._macros, debugging=debugging)
             except Exception as e:
                 print()
                 print(e)
                 print('Enter Natex {}: {} failed'.format(str(target), enter_natex))
                 print()
                 enter_natex_pass = None
         # Accept only a truthy match whose enter natex did not yield None.
         if match and enter_natex_pass is not None:
             if debugging:
                 print('Transition {} matched "{}"'.format(transition[:2], natural_language))
             # Matching may override the transition's score.
             if '__score__' in vars:
                 score = vars['__score__']
                 del vars['__score__']
             gate_closed = False
             gate_var_config = None
             gate_target_id = None
             if '__gate__' in vars:
                 # The gate is closed if this var configuration was already recorded
                 # for the target; inside a module, plain targets are keyed as
                 # (namespace, target).
                 gate_var_config = vars['__gate__']
                 gate_target_id = (self.namespace(), target) if (
                             not isinstance(target, tuple) and self.is_module()) else target
                 for vc in self.gates()[gate_target_id]:
                     if gate_var_config == vc:
                         gate_closed = True
                 del vars['__gate__']
             if not gate_closed:
                 transition_options.append((score, natex, transition, vars, gate_var_config, gate_target_id))
         t2 = time()
         if debugging:
             print('Transition {} evaluated in {:.5f}'.format(transition, t2-t1))
         # Pick up transitions queued while matching this one.
         while self._transitions:
             natex, transition, score = self._transitions.pop()
             transition_items.append((natex, transition, score))
     # Discard anything still queued.
     self._transitions.clear()
     if transition_options:
         if debugging:
             print('Transition options: ------------')
             for option in transition_options:
                 print('{} {}: {}'.format(option[0], option[2][1], option[1]))
             print('--------------------------------')
         score, natex, transition, vars, gate_var_config, gate_target_id = random_max(transition_options, key=lambda x: x[0])
         # Record the chosen gate configuration so later turns see the gate as closed.
         if gate_var_config is not None:
             self.gates()[gate_target_id].append(gate_var_config)
         if debugging:
             updates = {}
             for k, v in vars.items():
                 if k not in self._vars or v != self._vars[k]:
                     updates[k] = v
             if updates:
                 print('Updating vars:')
                 for k, v in updates.items():
                     if k in self._vars:
                         print('  {} = {} -> {}'.format(k, self._vars[k], v))
                     else:
                         print('  {} = None -> {}'.format(k, v))
         self.update_vars(vars)
         next_state = transition[1]
         if debugging:
             print('User transition in {:.5f}'.format(time() - ti))
             print('Transitioning {} -> {}'.format(self.state(), next_state))
         return next_state
     else:
         # No option survived: flag the error and fall back to the error successor.
         # NOTE(review): this uses self.state(), not the `state` argument; confirm
         # that is intended when a different state was passed in.
         self._error_transitioned = True
         next_state = self.error_successor(self.state())
         if debugging:
             print('User transition in {:.5f}'.format(time() - ti))
             print('Error transition {} -> {}'.format(self.state(), next_state))
         return next_state
예제 #18
0
 def incoming_transitions(self, target_state):
     """Lazily yield the graph's incoming arcs (``arcs_in``) for ``target_state``."""
     yield from self._graph.arcs_in(State(module_state(target_state)))
예제 #19
0
 def update_transition_settings(self, source, target, speaker, **settings):
     """Merge keyword ``settings`` into the settings of arc source -> target labelled ``speaker``."""
     src, tgt = module_source_target(source, target)
     current = self.transition_settings(State(src), State(tgt), speaker)
     current.update(**settings)
예제 #20
0
 def set_transition_settings(self, source, target, speaker, settings):
     """Replace the 'settings' entry on arc source -> target labelled ``speaker``."""
     src, tgt = module_source_target(source, target)
     arc_data = self._graph.arc_data(State(src), State(tgt), speaker)
     arc_data['settings'] = settings
예제 #21
0
 def transition_settings(self, source: Union[Enum, str, tuple], target: Union[Enum, str, tuple], speaker: Enum):
     """Return the 'settings' entry stored on arc source -> target labelled ``speaker``."""
     src, tgt = module_source_target(source, target)
     arc_data = self._graph.arc_data(State(src), State(tgt), speaker)
     return arc_data['settings']
예제 #22
0
 def state_settings(self, state):
     """Return the 'settings' entry stored on the node for ``state``."""
     node = State(module_state(state))
     return self._graph.data(node)['settings']
예제 #23
0
 def system_transition(self, state: Union[Enum, str, tuple], debugging=False):
     """
     Choose the system response and successor state from ``state``.

     Every SYSTEM transition out of ``state`` is generated, along with anything
     queued in ``self._transitions`` and ``self._update_transitions``. A
     candidate becomes an option when:
       - its natex generates a non-None response,
       - the relevant ``enter`` natexes do not yield None, and
       - its gate is not closed.
     The highest-scoring option (selected with ``random_max``) has its variable
     copy committed and its gate configurations recorded.

     A target written as ``'x->y'``, or ``(namespace, 'x->y')`` for another
     module, chains a second transition x -> y. That transition's generation
     is appended to the response, and y becomes the next state.

     If ``vars['__response_prefix__']`` holds something other than the string
     'None', it is prepended to the response and then reset to 'None'.

     :param state: state to transition from; None means the current state
     :param debugging: print generation, timing and var-update diagnostics
     :return: a (response, next_state) tuple
     :raises AssertionError: if no option is valid and no default system state is set
     """
     # Drop any gate left over from a previous turn.
     if '__gate__' in self._vars:
         del self._vars['__gate__']
     state = module_state(state)
     ti = time()
     if state is None:
         state = self.state()
     else:
         state = State(state)
     transition_options = []
     transitions = list(self.transitions(state, Speaker.SYSTEM))
     transition_items = []
     for transition in transitions:
         natex = self.transition_natex(*transition)
         score = self.transition_settings(*transition).score
         transition_items.append((natex, transition, score))
     # Queued transitions from both queues are evaluated alongside the graph's arcs.
     while self._transitions:
         natex, transition, score = self._transitions.pop()
         transition_items.append((natex, transition, score))
     while self._update_transitions:
         natex, transition, score = self._update_transitions.pop()
         transition_items.append((natex, transition, score))
     # transition_items may grow during this loop (see the drain at the end of the
     # body); iterating a list also visits entries appended mid-iteration.
     for natex, transition, score in transition_items:
         t1 = time()
         # Intermediate state of a cross-module 'x->y' chain, if any.
         transition_transition_enter = None
         # Generate against a fresh HashableDict built from self._vars; only the
         # chosen option's dict is committed later via update_vars.
         vars = HashableDict(self._vars)
         self._potential_transition = transition # MOVED, todo
         try:
             generation = natex.generate(vars=vars, macros=self._macros, debugging=debugging)
         except Exception as e:
             print()
             print('Transition {}: {} failed'.format(str(transition), natex))
             traceback.print_exc(file=sys.stdout)
             print()
             generation = None
         source, target, speaker = transition
         # Generation may redirect the transition through __source__/__target__.
         if '__source__' in vars:
             source = State(module_state(vars['__source__']))
             del vars['__source__']
         if '__target__' in vars:
             target = State(module_state(vars['__target__']))
             del vars['__target__']
         transition = source, target, speaker
         # if not self.is_module() and isinstance(target, tuple):
         #     continue
         if '->' in transition[1]:
             # Target 'x->y' in this flow: generate x -> y as well and append its text.
             # The chained transition replaces the original one from here on.
             # NOTE(review): _src is not passed through State(module_state(...)) the way _tar is.
             _src, _tar = target.split('->')[0], target.split('->')[1]
             _tar = State(module_state(_tar))
             transition = (_src, _tar, speaker)
             try:
                 appended_generation = self.transition_natex(*transition).generate(vars=vars, macros=self._macros, debugging=debugging)
                 if appended_generation is None:
                     generation = None
                 else:
                     # NOTE(review): unlike the cross-module branch, generation is not
                     # checked for None here. A None generation raises TypeError, which
                     # the except below turns into generation = None, printing a
                     # traceback that names the outer natex.
                     generation = generation + ' ' + appended_generation
             except Exception as e:
                 print()
                 print('Transition {}: {} failed'.format(str(transition), natex))
                 traceback.print_exc(file=sys.stdout)
                 print()
                 generation = None
         elif isinstance(transition[1], tuple) and '->' in transition[1][1]:
             # Target (namespace, 'x->y'): chain x -> y inside another module through
             # the composite flow. x's enter natex is checked further down.
             namespace = transition[1][0]
             source, target = (namespace, target[1].split('->')[0]), target[1].split('->')[1]
             target = State(module_state(target))
             transition_transition_enter = source
             transition = (source, target, speaker)
             try:
                 appended_generation = self.composite_dialogue_flow().transition_natex(
                     namespace, *transition).generate(vars=vars, macros=self._macros, debugging=debugging)
                 if generation is None or appended_generation is None:
                     generation = None
                 else:
                     generation = generation + ' ' + appended_generation
             except Exception as e:
                 print()
                 print('Transition {}: {} failed'.format(str(transition), natex))
                 traceback.print_exc(file=sys.stdout)
                 print()
                 generation = None
         # The chained generation may redirect again through __source__/__target__.
         source, target, speaker = transition
         if '__source__' in vars:
             source = State(module_state(vars['__source__']))
             del vars['__source__']
         if '__target__' in vars:
             target = State(module_state(vars['__target__']))
             del vars['__target__']
         transition = source, target, speaker
         enter_natex_pass = True
         # NOTE(review): this binds the same dict object as vars (no copy), so the
         # intermediate enter natex below writes into vars as well. A copy may have
         # been intended.
         transition_transition_enter_vars = vars
         if transition_transition_enter is not None:
             if self.is_module() and isinstance(transition_transition_enter, tuple):
                 enter_natex = self.composite_dialogue_flow().state_settings(*transition_transition_enter).enter
             else:
                 enter_natex = self.state_settings(transition_transition_enter).enter
             if enter_natex is not None:
                 try:
                     enter_natex_pass = enter_natex.generate(vars=transition_transition_enter_vars, macros=self._macros, debugging=debugging)
                 except Exception as e:
                     print()
                     print(e)
                     print('Enter Natex {}: {} failed'.format(str(transition_transition_enter), enter_natex))
                     print()
                     enter_natex_pass = None
         # The target's enter natex is skipped when the intermediate one produced a falsy value.
         # NOTE(review): an intermediate result of '' (falsy, but not None) skips the
         # target check and is still accepted by the None test below. Confirm this is intended.
         if enter_natex_pass:
             if self.is_module() and isinstance(target, tuple):
                 enter_natex = self.composite_dialogue_flow().state_settings(*target).enter
             else:
                 enter_natex = self.state_settings(target).enter
             if enter_natex is not None:
                 try:
                     enter_natex_pass = enter_natex.generate(vars=vars, macros=self._macros, debugging=debugging)
                 except Exception as e:
                     print()
                     print(e)
                     print('Enter Natex {}: {} failed'.format(str(target), enter_natex))
                     print()
                     enter_natex_pass = None
         # Keep the option only if the response and the enter natex(es) are non-None.
         if generation is not None and enter_natex_pass is not None:
             # Generation may override the transition's score.
             if '__score__' in vars:
                 score = vars['__score__']
                 del vars['__score__']
             gate_closed = False
             gate_var_config = None
             gate_target_id = None
             if '__gate__' in vars:
                 # The gate is closed if this var configuration was already recorded for
                 # the target. Inside a module, plain targets are keyed as (namespace, target).
                 gate_var_config = vars['__gate__']
                 gate_target_id = (self.namespace(), target) if (not isinstance(target, tuple) and self.is_module()) else target
                 for vc in self.gates()[gate_target_id]:
                     if gate_var_config == vc:
                         gate_closed = True
                 del vars['__gate__']
             tt_gate_var_config = None
             tt_gate_target_id = None
             # NOTE(review): transition_transition_enter_vars is the same object as vars,
             # and any __gate__ was already deleted above. As written this branch cannot
             # fire, and the update(vars) below is a self-update.
             if transition_transition_enter is not None and '__gate__' in transition_transition_enter_vars:
                 tt_gate_var_config = transition_transition_enter_vars['__gate__']
                 tt_gate_target_id = (self.namespace(), transition_transition_enter) if \
                     (not isinstance(transition_transition_enter, tuple) and self.is_module()) else transition_transition_enter
                 for vc in self.gates()[tt_gate_target_id]:
                     if tt_gate_var_config == vc:
                         gate_closed = True
                 del transition_transition_enter_vars['__gate__']
             transition_transition_enter_vars.update(vars)
             vars = transition_transition_enter_vars
             if not gate_closed:
                 transition_options.append((score, natex, generation, transition, vars, gate_var_config, gate_target_id, tt_gate_var_config, tt_gate_target_id))
         t2 = time()
         if debugging:
             print('Transition {} evaluated in {:.5f}'.format(transition, t2-t1))
         # Pick up transitions queued while generating this one.
         while self._transitions:
             natex, transition, score = self._transitions.pop()
             transition_items.append((natex, transition, score))
     # Discard anything still queued.
     self._transitions.clear()
     if transition_options:
         if debugging:
             print('Transition options: ------------')
             for option in transition_options:
                 print('{} {}: {}'.format(option[0], option[3][1], option[1]))
             print('--------------------------------')
         score, natex, response, transition, vars, gate_var_config, gate_target_id, tt_gate_var_config, tt_gate_target_id =\
             random_max(transition_options, key=lambda x: x[0])
         # Record the chosen gate configurations so later turns see those gates as closed.
         if gate_var_config is not None:
             self.gates()[gate_target_id].append(gate_var_config)
         if tt_gate_var_config is not None:
             self.gates()[tt_gate_target_id].append(tt_gate_var_config)
         if debugging:
             updates = {}
             for k, v in vars.items():
                 if k not in self._vars or v != self._vars[k]:
                     updates[k] = v
             if updates:
                 print('Updating vars:')
                 for k, v in updates.items():
                     if k in self._vars:
                         print('  {} = {} -> {}'.format(k, self._vars[k], v))
                     else:
                         print('  {} = None -> {}'.format(k, v))
         self.update_vars(vars)
         next_state = transition[1]
         if debugging:
             tf = time()
             print('System transition in {:.5f}'.format(tf-ti))
             print('Transitioning {} -> {}'.format(self.state(), next_state))
         # Apply a one-shot response prefix, then reset it (the string 'None' means unset).
         if '__response_prefix__' in self.vars() and self.vars()['__response_prefix__'] != 'None':
             response = self.vars()['__response_prefix__'] + ' ' + response
             self.vars()['__response_prefix__'] = 'None'
         return response, next_state
     else:
         # No valid option: retry from the default system state if one is configured.
         # NOTE(review): if the default state also has no valid option, this appears
         # to recurse until RecursionError.
         if self._default_state is not None:
             self.set_state(self._default_state)
             if debugging:
                 print('No valid system transitions found, going to default state...')
             return self.system_transition(self.state(), debugging=debugging)
         raise AssertionError('dialogue flow system transition found no valid options from state {}'.format(state))