Ejemplo n.º 1
0
 def _expand_request(self, action_name: str):
     """ Build a Request system act for a request_*slot* action.

     The slot to ask about is the first slot name encoded in ``action_name``.
     """
     slot_names = self._get_slotnames_from_actionname(action_name)
     request_act = SysAct()
     request_act.type = SysActionType.Request
     request_act.add_value(slot_names[0])
     return request_act
Ejemplo n.º 2
0
    def _expand_hello(self):
        """ Open a dialog with a Welcome system act.

        The act is stored as ``self.last_sys_act`` and written to the dialog
        log before it is returned.

        Returns:
            (dict): ``{'sys_act': <Welcome SysAct>}``
        """
        welcome = SysAct()
        welcome.type = SysActionType.Welcome
        self.last_sys_act = welcome
        log_line = "system action > " + str(welcome)
        self.logger.dialog_turn(log_line)
        return dict(sys_act=welcome)
Ejemplo n.º 3
0
 def _expand_confirm(self, action_name: str, beliefstate: BeliefState):
     """ Expand confirm_*slot* action.

     Builds a Confirm act for the first slot encoded in ``action_name``. The
     value to confirm is the most probable informable belief for that slot.
     If the belief state holds no value for the slot (which used to raise a
     ``KeyError``), a random value from the domain ontology is confirmed
     instead.

     Args:
         action_name (str): name of the confirm action, encoding the slot
         beliefstate (BeliefState): current belief state

     Returns:
         SysAct: the Confirm act
     """
     import random

     act = SysAct()
     act.type = SysActionType.Confirm
     conf_slot = self._get_slotnames_from_actionname(action_name)[0]
     candidates = beliefstate.get_most_probable_inf_beliefs(
         consider_NONE=False, threshold=0.0, max_results=1)
     if conf_slot in candidates:
         conf_value = candidates[conf_slot]
     else:
         # slot not (yet) in the belief state: confirm any valid ontology value
         conf_value = random.choice(self.domain.get_possible_values(conf_slot))
     act.add_value(conf_slot, conf_value)
     return act
Ejemplo n.º 4
0
 def _expand_select(self, action_name: str, beliefstate: BeliefState):
     """ Build a Select system act for a select_*slot* action.

     Offers a choice between the two most probable system-request beliefs
     for the slot encoded in ``action_name``.
     """
     select_act = SysAct()
     select_act.type = SysActionType.Select
     slot = self._get_slotnames_from_actionname(action_name)[0]
     top_two = beliefstate.get_most_probable_sysreq_beliefs(
         consider_NONE=False, threshold=0.0, max_results=2)[slot]
     options = (top_two[0], top_two[1])
     for option in options:
         select_act.add_value(slot, option)
     return select_act
Ejemplo n.º 5
0
    def forward(self,
                dialog_graph,
                user_acts: List[UserAct] = None,
                beliefstate: BeliefState = None,
                sim_goal: Goal = None,
                writer: SummaryWriter = None,
                **kwargs) -> dict(sys_act=SysAct):
        """ Choose the system act for the current turn.

        Greets on the first turn, says bye once ``MAX_TURNS`` is exceeded,
        answers a user Bye with a system Bye, and otherwise selects an action
        epsilon-greedily from the vectorised belief state.

        Returns:
            (dict): ``{'sys_act': SysAct}``
        """
        super(DQNPolicy, self).forward(dialog_graph,
                                       user_acts=user_acts,
                                       beliefstate=beliefstate,
                                       sim_goal=sim_goal,
                                       writer=writer)
        dialogs_so_far = dialog_graph.num_dialogs
        self.num_dialogs = dialogs_so_far % self.train_dialogs
        if dialogs_so_far == 0 and self.target_model is not None:
            # new epoch: target net starts from the online net's weights
            self.target_model.load_state_dict(self.model.state_dict())

        turn = dialog_graph.num_turns
        if turn == 0:
            # opening turn: greet, nothing is recorded
            return self._expand_hello()
        if turn > MAX_TURNS:
            # turn budget exhausted -> close the dialog
            bye_action = SysAct()
            bye_action.type = SysActionType.Bye
            self.last_sys_act = bye_action
            logger.dialog_turn("system action > " + str(bye_action))
            return {'sys_act': bye_action}

        # intermediate or closing turn
        state_vector = self.beliefstate_dict_to_vector(beliefstate)
        user_said_bye = user_acts is not None and any(
            act.type == UserActionType.Bye for act in user_acts)

        next_action_idx = -1
        if user_said_bye:
            # the user ended the dialog -> answer with bye
            next_action_idx = self.action_idx(SysActionType.Bye.value)
        if next_action_idx == -1:
            # dialog continues (or no bye action index was found)
            next_action_idx = self.select_action_eps_greedy(state_vector)

        self.turn_end(beliefstate, state_vector, next_action_idx)
        return {'sys_act': self.last_sys_act}
Ejemplo n.º 6
0
    def _expand_confirm(self, action_name: str, beliefstate: BeliefState):
        """ Build a Confirm system act for a confirm_*slot* action.

        Confirms the most probable informable value of the slot encoded in
        ``action_name``. When the belief state has no value for that slot, a
        random value from the domain ontology is confirmed instead.
        """
        confirm_act = SysAct()
        confirm_act.type = SysActionType.Confirm
        slot = self._get_slotnames_from_actionname(action_name)[0]
        best_beliefs = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=False, threshold=0.0, max_results=1)

        if slot in best_beliefs:
            value = best_beliefs[slot]
        else:
            # slot unknown so far -> fall back to a random ontology value
            value = random.choice(self.domain.get_possible_values(slot))

        confirm_act.add_value(slot, value)
        return confirm_act
Ejemplo n.º 7
0
    def _expand_informbyalternatives(self, beliefstate: BeliefState):
        """ Build an InformByAlternatives act offering an entity not yet mentioned.

        Entities whose primary key is listed in
        ``self.sys_state['informedPrimKeyValsSinceNone']`` are skipped. If the
        database has no match for the current constraints, or every match was
        already informed, the act carries ``primary_key='none'``.
        """
        act = SysAct()
        act.type = SysActionType.InformByAlternatives
        already_informed = set(self.sys_state['informedPrimKeyValsSinceNone'])
        candidates = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=True, threshold=0.7, max_results=1)
        non_dontcare = self._remove_dontcare_slots(candidates)
        db_matches = self.domain.find_entities(candidates)

        if not db_matches:
            # nothing matches: echo up to five constraints and report 'none'
            picked = common.numpy.random.choice(
                list(non_dontcare.keys()),
                min(5, len(non_dontcare)),
                replace=False)
            for slot in picked:
                act.add_value(slot, non_dontcare[slot])
            act.add_value(self.primary_key, 'none')
            return act

        fresh_match = next(
            (match for match in db_matches
             if match[self.primary_key] not in already_informed),
            None)
        if fresh_match is None:
            # every alternative has already been mentioned
            act.add_value(self.primary_key, 'none')
            return act

        budget = min(4 - len(act.slot_values), len(non_dontcare))
        for slot in common.numpy.random.choice(
                list(non_dontcare.keys()), budget, replace=False):
            act.add_value(slot, non_dontcare[slot])

        if len(act.slot_values) < 4:
            extra_constraints = {slot: value
                                 for slot, value in candidates.items()
                                 if value == 'dontcare'}
        else:
            extra_constraints = {}

        entity_info = self.domain.find_info_about_entity(
            fresh_match[self.primary_key],
            requested_slots=self.domain.get_informable_slots())[0]
        self._db_results_to_sysact(act, extra_constraints, entity_info)
        return act
Ejemplo n.º 8
0
    def _expand_confreq(self, action_name: str, beliefstate: BeliefState):
        """ Expand confreq_*confirmslot*_*requestslot* action.

        Builds a ConfirmRequest act that confirms the most probable value of
        the first slot encoded in ``action_name`` and requests the second one.
        If the belief state holds no value for the confirmation slot (which
        used to raise a ``KeyError``), a random value from the domain ontology
        is confirmed instead.

        Args:
            action_name (str): name of the confreq action, encoding both slots
            beliefstate (BeliefState): current belief state

        Returns:
            SysAct: the ConfirmRequest act
        """
        import random

        act = SysAct()
        act.type = SysActionType.ConfirmRequest
        # first slot name is confirmation, second is request
        slots = self._get_slotnames_from_actionname(action_name)
        conf_slot = slots[0]
        req_slot = slots[1]

        # get value that needs confirmation
        candidates = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=False, threshold=0.0, max_results=1)
        if conf_slot in candidates:
            conf_value = candidates[conf_slot]
        else:
            # slot not (yet) in the belief state: confirm any valid ontology value
            conf_value = random.choice(self.domain.get_possible_values(conf_slot))
        act.add_value(conf_slot, conf_value)
        # add request slot (no value)
        act.add_value(req_slot)
        return act
Ejemplo n.º 9
0
    def _expand_byconstraints(self, beliefstate: BeliefState):
        """ Build an Inform act for an entity matching the current constraints.

        Constraints are the most probable informable beliefs (threshold 0.7).
        If the database returns matches, the previously informed entity is
        preferred when it is among them; otherwise a random match is used, and
        the act is filled from it. With no match, up to five non-dontcare
        constraints (skipping ``'name'``) are repeated and the primary key is
        set to ``'none'``.
        """
        act = SysAct()
        act.type = SysActionType.Inform

        constraints = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=True, threshold=0.7, max_results=1)
        db_matches = self.domain.find_entities(constraints,
                                               requested_slots=constraints)

        if db_matches:
            last_informed = self.sys_state['lastInformedPrimKeyVal']
            same_entity = [entity for entity in db_matches
                           if entity[self.primary_key] == last_informed]
            if same_entity:
                # stick with the entity the user already heard about
                assert len(same_entity) == 1
                chosen = same_entity[0]
            else:
                chosen = common.random.choice(db_matches)
            self._db_results_to_sysact(act, constraints, chosen)
            return act

        # no entity satisfies the constraints -> inform 'none' plus constraints
        remaining = self._remove_dontcare_slots(constraints)
        sampled_slots = common.numpy.random.choice(
            list(remaining.keys()),
            min(5, len(remaining)),
            replace=False)
        for slot in sampled_slots:
            if slot != 'name':
                act.add_value(slot, remaining[slot])
        act.add_value(self.primary_key, 'none')
        return act
Ejemplo n.º 10
0
 def _expand_reqmore(self):
     """ Create and return a new SysAct of type ``RequestMore``. """
     reqmore_act = SysAct()
     reqmore_act.type = SysActionType.RequestMore
     return reqmore_act
Ejemplo n.º 11
0
 def _expand_bye(self):
     """ Create and return a new SysAct of type ``Bye``. """
     farewell = SysAct()
     farewell.type = SysActionType.Bye
     return farewell
Ejemplo n.º 12
0
    def _expand_informbyname(self, beliefstate: BeliefState):
        """ Expand inform_byname action.

        Chooses an entity that satisfies the current constraints and builds an
        InformByName act for it. The entity's primary key is looked up first
        from the most probable primary-key belief, else from the last informed
        entity. Slots are filled, in order of preference, from:

        * the slots requested by the user (up to 4),
        * the non-dontcare constraints (up to 4),
        * one random requestable slot.

        The entity's primary key is always included. If no entity matches at
        all, up to five constraints are echoed together with
        ``primary_key='none'``.

        Args:
            beliefstate (BeliefState): current belief state

        Returns:
            SysAct: the InformByName act
        """
        act = SysAct()
        act.type = SysActionType.InformByName

        # get most probable entity primary key
        if self.primary_key in beliefstate['informs']:
            primkeyval = sorted(
                beliefstate["informs"][self.primary_key].items(),
                key=lambda kv: kv[1],
                reverse=True)[0][0]
        else:
            # try to use previously informed name instead
            primkeyval = self.sys_state['lastInformedPrimKeyVal']
            # TODO change behaviour from here, because primkeyval might be "**NONE**" and this might be an entity in the database
        # find db entry by primary key
        constraints = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=True, threshold=0.7, max_results=1)

        db_matches = self.domain.find_entities({
            **constraints, self.primary_key: primkeyval
        })
        # NOTE usually not needed to give all constraints (shouldn't make a difference)
        if not db_matches:
            # NOTE(review): the random primary key drawn here was never used by
            # the lookup below; the draw is kept only so the shared RNG
            # sequence (and hence seeded experiments) stays unchanged.
            common.random.choice(
                self.domain.get_possible_values(self.primary_key))
            # fall back to any entity matching the belief-state constraints
            db_matches = self.domain.find_entities(
                constraints, self.domain.get_requestable_slots())

        if not db_matches:
            # no results found
            filtered_slot_values = self._remove_dontcare_slots(constraints)
            for slot in common.numpy.random.choice(
                    list(filtered_slot_values.keys()),
                    min(5, len(filtered_slot_values)),
                    replace=False):
                act.add_value(slot, filtered_slot_values[slot])
            act.add_value(self.primary_key, 'none')
            return act

        # select random match and fetch all requestable information about it
        db_match = common.random.choice(db_matches)
        db_match = self.domain.find_info_about_entity(
            db_match[self.primary_key],
            requested_slots=self.domain.get_requestable_slots())[0]

        # get slots requested by user; copy so that removing the primary key
        # below cannot mutate a list that may be owned by the belief state
        usr_requests = list(beliefstate.get_requested_slots())
        # remove primary key (to exclude from minimum number) since it is added anyway at the end
        if self.primary_key in usr_requests:
            usr_requests.remove(self.primary_key)
        if usr_requests:
            # add user requested values into system act using db result
            for req_slot in common.numpy.random.choice(
                    usr_requests, min(4, len(usr_requests)), replace=False):
                if req_slot in db_match:
                    act.add_value(req_slot, db_match[req_slot])
                else:
                    act.add_value(req_slot, 'none')
        else:
            constraints = self._remove_dontcare_slots(constraints)
            if constraints:
                for inform_slot in common.numpy.random.choice(
                        list(constraints.keys()),
                        min(4, len(constraints)),
                        replace=False):
                    act.add_value(inform_slot, db_match[inform_slot])
            else:
                # add random slot and value if no user request was detected.
                # discard(): the primary key need not be requestable.
                # sorted(): set iteration order of str depends on per-process
                # hash randomisation, so choosing from list(set) would not be
                # reproducible under a fixed RNG seed.
                usr_requestable_slots = set(
                    self.domain.get_requestable_slots())
                usr_requestable_slots.discard(self.primary_key)
                if usr_requestable_slots:
                    random_slot = common.random.choice(
                        sorted(usr_requestable_slots))
                    act.add_value(random_slot, db_match[random_slot])
        # ensure entity primary key is included
        if self.primary_key not in act.slot_values:
            act.add_value(self.primary_key, db_match[self.primary_key])

        return act
Ejemplo n.º 13
0
    def _expand_informbyalternatives(self, beliefstate: BeliefState):
        """ Build an InformByAlternatives act offering an entity not yet mentioned.

        Primary keys listed in
        ``beliefstate['system']['informedPrimKeyValsSinceNone']`` are skipped.
        Without any database match, or when all matches were already
        informed, the act reports ``primary_key='none'``.
        """
        act = SysAct()
        act.type = SysActionType.InformByAlternatives
        seen_keys = set(beliefstate['system']['informedPrimKeyValsSinceNone'])
        candidates = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=True, threshold=0.7, max_results=1)
        relevant = self._remove_dontcare_slots(candidates)
        db_matches = self.domain.find_entities(candidates)

        if not db_matches:
            # no entity at all: repeat up to five constraints and say 'none'
            for slot in common.numpy.random.choice(
                    list(relevant.keys()),
                    min(5, len(relevant)),
                    replace=False):
                act.add_value(slot, relevant[slot])
            act.add_value(self.primary_key, 'none')
            return act

        for entity in db_matches:
            if entity[self.primary_key] in seen_keys:
                continue
            # first entity the user has not heard about yet
            # TODO move to _db_results_to_sysact method (including maximum inform slot-values)
            quota = min(4 - len(act.slot_values), len(relevant))
            for slot in common.numpy.random.choice(
                    list(relevant.keys()), quota, replace=False):
                act.add_value(slot, relevant[slot])
            dontcare_constraints = {}
            if len(act.slot_values) < 4:
                dontcare_constraints = {slot: value
                                        for slot, value in candidates.items()
                                        if value == 'dontcare'}
            entity_info = self.domain.find_info_about_entity(
                entity[self.primary_key],
                requested_slots=self.domain.get_informable_slots())[0]
            self._db_results_to_sysact(act, dontcare_constraints, entity_info)
            return act

        # every alternative has already been mentioned
        act.add_value(self.primary_key, 'none')
        return act
Ejemplo n.º 14
0
    def choose_sys_act(self,
                       beliefstate: BeliefState = None
                       ) -> dict(sys_act=SysAct):
        """
            Determine the next system act based on the given beliefstate

            Args:
                beliefstate (BeliefState): everything the system currently knows
                                           about the environment (i.e. the user)

            Returns:
                (dict): key "sys_act" holds the act chosen by the policy; key
                        "sys_state" holds additional information (e.g. the last
                        act, last informed primary key, last requested slot)
                        which the NLU may need to disambiguate utterances.
        """
        self.num_dialogs = self.cumulative_train_dialogs % self.train_dialogs
        if self.cumulative_train_dialogs == 0 and self.target_model is not None:
            # new epoch: target and online network start from the same weights
            self.target_model.load_state_dict(self.model.state_dict())

        self.turns += 1
        if self.turns == 1:
            # opening turn: greet, nothing is recorded
            greeting = self._expand_hello()
            greeting["sys_state"] = {"last_act": greeting["sys_act"]}
            return greeting

        if self.turns > self.max_turns:
            # turn budget exhausted -> close the dialog
            bye_action = SysAct()
            bye_action.type = SysActionType.Bye
            self.last_sys_act = bye_action
            if self.logger:
                self.logger.dialog_turn("system action > " + str(bye_action))
            return {'sys_act': bye_action, "sys_state": {"last_act": bye_action}}

        # intermediate or closing turn
        state_vector = self.beliefstate_dict_to_vector(beliefstate)
        if UserActionType.Bye in beliefstate["user_acts"]:
            # the user ended the dialog -> answer with bye
            next_action_idx = self.action_idx(SysActionType.Bye.value)
        else:
            next_action_idx = -1
        if next_action_idx == -1:
            # dialog continues (or no bye action index was found)
            next_action_idx = self.select_action_eps_greedy(state_vector)

        self.turn_end(beliefstate, state_vector, next_action_idx)

        # remember what was informed / requested for the next turn
        last_act = self.last_sys_act
        if last_act.type in (SysActionType.InformByName,
                             SysActionType.InformByAlternatives):
            informed_keys = last_act.get_values(self.domain.get_primary_key())
            if informed_keys:
                self.sys_state['lastInformedPrimKeyVal'] = informed_keys[0]
        elif last_act.type == SysActionType.Request and last_act.slot_values:
            self.sys_state['lastRequestSlot'] = next(iter(last_act.slot_values))

        self.sys_state["last_act"] = last_act
        return {'sys_act': last_act, "sys_state": self.sys_state}