def _expand_request(self, action_name: str):
    """Build the system act for a ``request_*slot*`` action.

    Args:
        action_name: name of the request action; the first slot name
            returned by ``_get_slotnames_from_actionname`` is requested.

    Returns:
        A ``SysAct`` of type ``Request`` carrying the slot without a value.
    """
    slot_names = self._get_slotnames_from_actionname(action_name)
    request_act = SysAct()
    request_act.type = SysActionType.Request
    # Only the first slot encoded in the action name is requested.
    request_act.add_value(slot_names[0])
    return request_act
def _expand_hello(self):
    """Create the welcome act that opens a dialog.

    The act is stored in ``self.last_sys_act`` and written to the dialog
    log when a logger is configured.

    Returns:
        dict: ``{'sys_act': <welcome SysAct>}``
    """
    hello_action = SysAct()
    hello_action.type = SysActionType.Welcome
    self.last_sys_act = hello_action
    # The logger is optional (choose_sys_act guards it the same way), so
    # do not crash on the very first turn when none is attached.
    if self.logger:
        self.logger.dialog_turn("system action > " + str(hello_action))
    return {'sys_act': hello_action}
def _expand_confirm(self, action_name: str, beliefstate: BeliefState):
    """Build the system act for a ``confirm_*slot*`` action.

    The value to confirm is the most probable inform belief for the slot.
    When the belief state holds no candidate for the slot, a random value
    from the ontology is confirmed instead of raising ``KeyError``.

    Args:
        action_name: name of the confirm action; its first slot name is
            the slot to confirm.
        beliefstate: current belief state, queried through
            ``get_most_probable_inf_beliefs``.

    Returns:
        A ``SysAct`` of type ``Confirm`` with one slot/value pair.

    Raises:
        IndexError: if the fallback is needed and the domain reports no
            possible values for the slot (raised by ``random.choice``).
    """
    import random  # local import keeps this block self-contained

    act = SysAct()
    act.type = SysActionType.Confirm
    conf_slot = self._get_slotnames_from_actionname(action_name)[0]
    candidates = beliefstate.get_most_probable_inf_beliefs(
        consider_NONE=False, threshold=0.0, max_results=1)
    if conf_slot in candidates:
        conf_value = candidates[conf_slot]
    else:
        # Slot is not in the belief state: pick a random ontology value.
        conf_value = random.choice(
            self.domain.get_possible_values(conf_slot))
    act.add_value(conf_slot, conf_value)
    return act
def _expand_select(self, action_name: str, beliefstate: BeliefState):
    """Build the system act for a ``select_*slot*`` action.

    Args:
        action_name: name of the select action; its first slot name is the
            slot the user is asked to choose a value for.
        beliefstate: current belief state, queried through
            ``get_most_probable_sysreq_beliefs`` with ``max_results=2``.

    Returns:
        A ``SysAct`` of type ``Select`` holding the first two values
        returned for the slot.
    """
    select_act = SysAct()
    select_act.type = SysActionType.Select
    slot = self._get_slotnames_from_actionname(action_name)[0]
    top_beliefs = beliefstate.get_most_probable_sysreq_beliefs(
        consider_NONE=False, threshold=0.0, max_results=2)
    # Both alternatives are read before the act is modified.
    options = top_beliefs[slot]
    first_option, second_option = options[0], options[1]
    select_act.add_value(slot, first_option)
    select_act.add_value(slot, second_option)
    return select_act
def forward(self, dialog_graph, user_acts: List[UserAct] = None,
            beliefstate: BeliefState = None, sim_goal: Goal = None,
            writer: SummaryWriter = None,
            **kwargs) -> dict(sys_act=SysAct):
    """Choose the next system act for the current dialog turn.

    Order of decisions:
      1. turn 0 of a dialog -> welcome act (not recorded),
      2. turn limit ``MAX_TURNS`` exceeded -> bye act,
      3. a user act of type ``Bye`` -> bye action index,
      4. otherwise an epsilon-greedy action from the state vector.

    Args:
        dialog_graph: provides ``num_dialogs`` and ``num_turns``.
        user_acts: user acts of this turn (may be ``None``).
        beliefstate: current belief state, converted to a state vector.
        sim_goal: forwarded to the parent ``forward``.
        writer: forwarded to the parent ``forward``.

    Returns:
        dict: ``{'sys_act': <SysAct>}``
    """
    super(DQNPolicy, self).forward(dialog_graph,
                                   user_acts=user_acts,
                                   beliefstate=beliefstate,
                                   sim_goal=sim_goal,
                                   writer=writer)

    self.num_dialogs = dialog_graph.num_dialogs % self.train_dialogs
    new_epoch = dialog_graph.num_dialogs == 0
    if new_epoch and self.target_model is not None:
        # Target and online nets start each epoch with equal weights.
        self.target_model.load_state_dict(self.model.state_dict())

    # Opening turn: greet the user, nothing is recorded.
    if dialog_graph.num_turns == 0:
        return self._expand_hello()

    # Turn limit reached: terminate the dialog.
    if dialog_graph.num_turns > MAX_TURNS:
        bye_action = SysAct()
        bye_action.type = SysActionType.Bye
        self.last_sys_act = bye_action
        logger.dialog_turn("system action > " + str(bye_action))
        return {'sys_act': bye_action}

    # Intermediate or closing turn.
    state_vector = self.beliefstate_dict_to_vector(beliefstate)
    next_action_idx = -1

    user_said_bye = user_acts is not None and any(
        user_act.type == UserActionType.Bye for user_act in user_acts)
    if user_said_bye:
        # User terminated the dialog -> answer with bye.
        next_action_idx = self.action_idx(SysActionType.Bye.value)

    if next_action_idx == -1:
        # Dialog continues.
        next_action_idx = self.select_action_eps_greedy(state_vector)

    # turn_end is expected to leave the expanded act in self.last_sys_act
    # (assumption; its body is not part of this block).
    self.turn_end(beliefstate, state_vector, next_action_idx)
    return {'sys_act': self.last_sys_act}
def _expand_confirm(self, action_name: str, beliefstate: BeliefState):
    """Build the system act for a ``confirm_*slot*`` action.

    Args:
        action_name: name of the confirm action; its first slot name is
            the slot to confirm.
        beliefstate: current belief state, queried through
            ``get_most_probable_inf_beliefs``.

    Returns:
        A ``SysAct`` of type ``Confirm`` with one slot/value pair. The
        value is the belief-state candidate for the slot, or a random
        ontology value when the slot has no candidate.
    """
    confirm_act = SysAct()
    confirm_act.type = SysActionType.Confirm
    slot = self._get_slotnames_from_actionname(action_name)[0]
    beliefs = beliefstate.get_most_probable_inf_beliefs(
        consider_NONE=False, threshold=0.0, max_results=1)

    if slot in beliefs:
        value = beliefs[slot]
    else:
        # Slot is missing from the belief state: draw from the ontology.
        value = random.choice(self.domain.get_possible_values(slot))

    confirm_act.add_value(slot, value)
    return confirm_act
def _expand_informbyalternatives(self, beliefstate: BeliefState):
    """Build the system act for the ``inform_byalternatives`` action.

    Looks for a database entity that matches the current constraints and
    whose primary key is not in
    ``self.sys_state['informedPrimKeyValsSinceNone']``.

    Args:
        beliefstate: current belief state, queried through
            ``get_most_probable_inf_beliefs``.

    Returns:
        A ``SysAct`` of type ``InformByAlternatives``. The primary key is
        set to ``'none'`` when the database has no match or every match
        was already informed.
    """
    act = SysAct()
    act.type = SysActionType.InformByAlternatives

    already_informed = set(self.sys_state['informedPrimKeyValsSinceNone'])
    candidates = beliefstate.get_most_probable_inf_beliefs(
        consider_NONE=True, threshold=0.7, max_results=1)
    filtered_slot_values = self._remove_dontcare_slots(candidates)
    db_matches = self.domain.find_entities(candidates)

    if not db_matches:
        # Nothing in the database: echo up to 5 random constraints.
        sample_size = min(5, len(filtered_slot_values))
        sampled_slots = common.numpy.random.choice(
            list(filtered_slot_values.keys()), sample_size, replace=False)
        for slot in sampled_slots:
            act.add_value(slot, filtered_slot_values[slot])
        act.add_value(self.primary_key, 'none')
        return act

    for db_match in db_matches:
        if db_match[self.primary_key] in already_informed:
            # Entity was mentioned before -> try the next one.
            continue

        # New entity: add constraints whose value is not 'dontcare'.
        sample_size = min(4 - (len(act.slot_values)),
                          len(filtered_slot_values))
        sampled_slots = common.numpy.random.choice(
            list(filtered_slot_values.keys()), sample_size, replace=False)
        for slot in sampled_slots:
            act.add_value(slot, filtered_slot_values[slot])

        # 'dontcare' constraints are passed on only while fewer than
        # four slot values are in the act.
        room_left = len(act.slot_values) < 4
        additional_constraints = {
            slot: value
            for slot, value in candidates.items()
            if room_left and value == 'dontcare'
        }

        entity_info = self.domain.find_info_about_entity(
            db_match[self.primary_key],
            requested_slots=self.domain.get_informable_slots())[0]
        self._db_results_to_sysact(act, additional_constraints, entity_info)
        return act

    # Every match was already mentioned.
    act.add_value(self.primary_key, 'none')
    return act
def _expand_confreq(self, action_name: str, beliefstate: BeliefState):
    """Build the act for a ``confreq_*confirmslot*_*requestslot*`` action.

    The first slot in the action name is confirmed, the second is
    requested. The confirmed value is the most probable inform belief for
    the slot. When the belief state holds no candidate for it, a random
    value from the ontology is used instead of raising ``KeyError``
    (same fallback as ``_expand_confirm``).

    Args:
        action_name: name of the confirm-request action.
        beliefstate: current belief state, queried through
            ``get_most_probable_inf_beliefs``.

    Returns:
        A ``SysAct`` of type ``ConfirmRequest`` with the confirm
        slot/value pair and the request slot without a value.

    Raises:
        IndexError: if the fallback is needed and the domain reports no
            possible values for the slot (raised by ``random.choice``).
    """
    import random  # local import keeps this block self-contained

    act = SysAct()
    act.type = SysActionType.ConfirmRequest

    # First slot name is the confirmation, second is the request.
    slots = self._get_slotnames_from_actionname(action_name)
    conf_slot = slots[0]
    req_slot = slots[1]

    # Value that needs confirmation.
    candidates = beliefstate.get_most_probable_inf_beliefs(
        consider_NONE=False, threshold=0.0, max_results=1)
    if conf_slot in candidates:
        conf_value = candidates[conf_slot]
    else:
        # Slot is not in the belief state: pick a random ontology value.
        conf_value = random.choice(
            self.domain.get_possible_values(conf_slot))
    act.add_value(conf_slot, conf_value)

    # Request slot carries no value.
    act.add_value(req_slot)
    return act
def _expand_byconstraints(self, beliefstate: BeliefState):
    """Create an inform act from a database lookup by constraints.

    Args:
        beliefstate: current belief state, queried through
            ``get_most_probable_inf_beliefs``.

    Returns:
        A ``SysAct`` of type ``Inform``. With at least one database match
        it is filled from one match, preferring the entity stored in
        ``self.sys_state['lastInformedPrimKeyVal']``. Without a match it
        echoes up to five non-'dontcare' constraints and sets the primary
        key to ``'none'``.
    """
    act = SysAct()
    act.type = SysActionType.Inform

    # Query the database with the current constraints.
    constraints = beliefstate.get_most_probable_inf_beliefs(
        consider_NONE=True, threshold=0.7, max_results=1)
    db_matches = self.domain.find_entities(constraints,
                                           requested_slots=constraints)

    if db_matches:
        # Stick to the last informed entity if it is among the matches.
        last_informed_key = self.sys_state['lastInformedPrimKeyVal']
        same_entity = [
            candidate for candidate in db_matches
            if candidate[self.primary_key] == last_informed_key
        ]
        if same_entity:
            assert len(same_entity) == 1
            chosen = same_entity[0]
        else:
            # No match is the last informed entity -> pick one at random.
            chosen = common.random.choice(db_matches)
        # Fill the act with values from the database.
        self._db_results_to_sysact(act, constraints, chosen)
        return act

    # No matching entity: inform primary key = none plus some constraints.
    filtered_slot_values = self._remove_dontcare_slots(constraints)
    sample_size = min(5, len(filtered_slot_values))
    sampled_slots = common.numpy.random.choice(
        list(filtered_slot_values.keys()), sample_size, replace=False)
    for slot in sampled_slots:
        if slot != 'name':
            act.add_value(slot, filtered_slot_values[slot])
    act.add_value(self.primary_key, 'none')
    return act
def _expand_reqmore(self):
    """Build the system act for the ``reqmore`` action.

    Returns:
        A ``SysAct`` of type ``RequestMore`` without slot values.
    """
    reqmore_act = SysAct()
    reqmore_act.type = SysActionType.RequestMore
    return reqmore_act
def _expand_bye(self):
    """Build the system act for the ``bye`` action.

    Returns:
        A ``SysAct`` of type ``Bye`` without slot values.
    """
    bye_act = SysAct()
    bye_act.type = SysActionType.Bye
    return bye_act
def _expand_informbyname(self, beliefstate: BeliefState):
    """Build the system act for the ``inform_byname`` action.

    Steps:
      1. pick the primary key value: the highest-scored entry of
         ``beliefstate['informs'][primary_key]`` if present, otherwise
         ``self.sys_state['lastInformedPrimKeyVal']``;
      2. look the entity up together with the current constraints, and
         fall back to a constraints-only lookup when nothing is found;
      3. fill the act with user-requested slots, else with constraint
         slots, else with one random requestable slot;
      4. make sure the primary key is part of the act.

    Args:
        beliefstate: current belief state.

    Returns:
        A ``SysAct`` of type ``InformByName``. The primary key is
        ``'none'`` when no entity could be found at all.
    """
    act = SysAct()
    act.type = SysActionType.InformByName

    # Most probable entity primary key.
    if self.primary_key in beliefstate['informs']:
        scored_keys = sorted(
            beliefstate["informs"][self.primary_key].items(),
            key=lambda kv: kv[1],
            reverse=True)
        primkeyval = scored_keys[0][0]
    else:
        # Fall back to the previously informed entity.
        primkeyval = self.sys_state['lastInformedPrimKeyVal']
    # TODO change behaviour from here, because primkeyval might be
    # "**NONE**" and this might be an entity in the database

    # Find the database entry by primary key (plus constraints).
    constraints = beliefstate.get_most_probable_inf_beliefs(
        consider_NONE=True, threshold=0.7, max_results=1)
    lookup = dict(constraints)
    lookup[self.primary_key] = primkeyval
    # NOTE usually not needed to give all constraints
    # (shouldn't make a difference)
    db_matches = self.domain.find_entities(lookup)

    if not db_matches:
        # Random primary key is drawn here; the lookup below relies on
        # the constraints only (the draw still advances the RNG).
        primkeyvals = self.domain.get_possible_values(self.primary_key)
        primkeyval = common.random.choice(primkeyvals)
        db_matches = self.domain.find_entities(
            constraints, self.domain.get_requestable_slots())

        if not db_matches:
            # Still nothing: echo up to 5 constraints, primary key = none.
            filtered_slot_values = self._remove_dontcare_slots(constraints)
            sample_size = min(5, len(filtered_slot_values))
            sampled_slots = common.numpy.random.choice(
                list(filtered_slot_values.keys()), sample_size,
                replace=False)
            for slot in sampled_slots:
                act.add_value(slot, filtered_slot_values[slot])
            act.add_value(self.primary_key, 'none')
            return act

    # Select a random match and load its requestable slots.
    picked = common.random.choice(db_matches)
    db_match = self.domain.find_info_about_entity(
        picked[self.primary_key],
        requested_slots=self.domain.get_requestable_slots())[0]

    # Slots requested by the user; the primary key is excluded from the
    # count because it is added at the end anyway.
    usr_requests = beliefstate.get_requested_slots()
    if self.primary_key in usr_requests:
        usr_requests.remove(self.primary_key)

    if usr_requests:
        # Answer up to 4 user requests from the database result.
        sampled_requests = common.numpy.random.choice(
            usr_requests, min(4, len(usr_requests)), replace=False)
        for req_slot in sampled_requests:
            if req_slot in db_match:
                act.add_value(req_slot, db_match[req_slot])
            else:
                act.add_value(req_slot, 'none')
    else:
        constraints = self._remove_dontcare_slots(constraints)
        if constraints:
            # Inform about up to 4 constraint slots.
            sampled_informs = common.numpy.random.choice(
                list(constraints.keys()), min(4, len(constraints)),
                replace=False)
            for inform_slot in sampled_informs:
                act.add_value(inform_slot, db_match[inform_slot])
        else:
            # No user request detected: add one random slot and value.
            usr_requestable_slots = set(self.domain.get_requestable_slots())
            usr_requestable_slots.remove(self.primary_key)
            random_slot = common.random.choice(list(usr_requestable_slots))
            act.add_value(random_slot, db_match[random_slot])

    # Ensure the entity primary key is included.
    if self.primary_key not in act.slot_values:
        act.add_value(self.primary_key, db_match[self.primary_key])
    return act
def _expand_informbyalternatives(self, beliefstate: BeliefState):
    """Build the system act for the ``inform_byalternatives`` action.

    Looks for a database entity that matches the current constraints and
    whose primary key is not in
    ``beliefstate['system']['informedPrimKeyValsSinceNone']``.

    Args:
        beliefstate: current belief state; also holds the set of already
            informed primary key values under its ``'system'`` entry.

    Returns:
        A ``SysAct`` of type ``InformByAlternatives``. The primary key is
        set to ``'none'`` when the database has no match or every match
        was already informed.
    """
    act = SysAct()
    act.type = SysActionType.InformByAlternatives

    # Primary key values the user has already been told about.
    seen_keys = set(beliefstate['system']['informedPrimKeyValsSinceNone'])

    candidates = beliefstate.get_most_probable_inf_beliefs(
        consider_NONE=True, threshold=0.7, max_results=1)
    filtered_slot_values = self._remove_dontcare_slots(candidates)

    # Query the database by constraints.
    db_matches = self.domain.find_entities(candidates)
    if not db_matches:
        # No results: echo up to 5 random constraints, primary key = none.
        how_many = min(5, len(filtered_slot_values))
        for slot in common.numpy.random.choice(
                list(filtered_slot_values.keys()), how_many,
                replace=False):
            act.add_value(slot, filtered_slot_values[slot])
        act.add_value(self.primary_key, 'none')
        return act

    # Skip entities that were already informed.
    for db_match in db_matches:
        if db_match[self.primary_key] in seen_keys:
            continue

        # New entity found.
        # TODO move slot selection (user-requested slots and the maximum
        # number of informed slot-values) to _db_results_to_sysact

        # Add slots whose value is not 'dontcare'.
        how_many = min(4 - (len(act.slot_values)),
                       len(filtered_slot_values))
        for slot in common.numpy.random.choice(
                list(filtered_slot_values.keys()), how_many,
                replace=False):
            act.add_value(slot, filtered_slot_values[slot])

        # 'dontcare' constraints are passed on only while fewer than
        # four slot values are in the act.
        has_room = len(act.slot_values) < 4
        additional_constraints = {
            slot: value
            for slot, value in candidates.items()
            if has_room and value == 'dontcare'
        }

        entity_details = self.domain.find_info_about_entity(
            db_match[self.primary_key],
            requested_slots=self.domain.get_informable_slots())[0]
        self._db_results_to_sysact(act, additional_constraints,
                                   entity_details)
        return act

    # No alternatives left that were not already mentioned.
    act.add_value(self.primary_key, 'none')
    return act
def choose_sys_act(self, beliefstate: BeliefState = None
                   ) -> dict(sys_act=SysAct):
    """Determine the next system act based on the given belief state.

    Args:
        beliefstate (BeliefState): everything the system knows about the
            environment (here: the user). Its ``"user_acts"`` entry is
            checked for a user ``Bye``.

    Returns:
        (dict): ``"sys_act"`` holds the action chosen by the policy and
        ``"sys_state"`` holds additional information which might be
        needed by the NLU to disambiguate challenging utterances.
    """
    self.num_dialogs = self.cumulative_train_dialogs % self.train_dialogs
    first_dialog = self.cumulative_train_dialogs == 0
    if first_dialog and self.target_model is not None:
        # Target and online nets start each epoch with equal weights.
        self.target_model.load_state_dict(self.model.state_dict())

    self.turns += 1

    # First turn of a dialog: say hello, nothing is recorded.
    if self.turns == 1:
        result = self._expand_hello()
        result["sys_state"] = {"last_act": result["sys_act"]}
        return result

    # Turn limit reached: terminate the dialog.
    if self.turns > self.max_turns:
        bye_action = SysAct()
        bye_action.type = SysActionType.Bye
        self.last_sys_act = bye_action
        if self.logger:
            self.logger.dialog_turn("system action > " + str(bye_action))
        return {'sys_act': bye_action,
                "sys_state": {"last_act": bye_action}}

    # Intermediate or closing turn.
    state_vector = self.beliefstate_dict_to_vector(beliefstate)
    next_action_idx = -1

    if UserActionType.Bye in beliefstate["user_acts"]:
        # User terminated the dialog -> answer with bye.
        next_action_idx = self.action_idx(SysActionType.Bye.value)

    if next_action_idx == -1:
        # Dialog continues.
        next_action_idx = self.select_action_eps_greedy(state_vector)

    # turn_end is expected to leave the expanded act in self.last_sys_act
    # (assumption; its body is not part of this block).
    self.turn_end(beliefstate, state_vector, next_action_idx)

    # Update the sys_state from the act that was just chosen.
    chosen_act = self.last_sys_act
    inform_types = (SysActionType.InformByName,
                    SysActionType.InformByAlternatives)
    if chosen_act.type in inform_types:
        values = chosen_act.get_values(self.domain.get_primary_key())
        if values:
            self.sys_state['lastInformedPrimKeyVal'] = values[0]
    elif chosen_act.type == SysActionType.Request:
        requested_slots = list(chosen_act.slot_values.keys())
        if len(requested_slots) > 0:
            self.sys_state['lastRequestSlot'] = requested_slots[0]

    self.sys_state["last_act"] = self.last_sys_act
    return {'sys_act': self.last_sys_act, "sys_state": self.sys_state}