コード例 #1
0
ファイル: bst.py プロジェクト: zhiyin121/adviser
    def dialog_start(self):
        """
            Resets the belief state so it is ready for a new dialog.

            Replaces self.bs with a new BeliefState created from self.domain.
            Nothing is returned; the new state is only stored on the instance.
        """
        # drop the previous dialog's state and start from a fresh one
        self.bs = BeliefState(self.domain)
コード例 #2
0
ファイル: bst.py プロジェクト: wendywtchang/adviser
    def forward(self,
                dialog_graph,
                user_acts: List[UserAct] = None,
                beliefstate: BeliefState = None,
                sys_act: List[SysAct] = None,
                **kwargs) -> dict(beliefstate=BeliefState):
        """
            Function for updating the current dialog belief state (which tracks the system's
            knowledge about what has been said in the dialog) based on the user actions generated
            from the user's utterances

            Args:
                dialog_graph (DialogSystem): not used inside this method
                user_acts (list): a list of UserAct objects mapped from the user's last utterance;
                                  if None, only a new turn is started and the state is returned
                beliefstate (BeliefState): the belief state to update in place; it is
                                           dereferenced immediately, so despite the None default
                                           a real object must be passed
                sys_act (list): forwarded unchanged to self._handle_user_acts

            Returns:
                (dict): a dictionary with the key "beliefstate" and the value the updated
                        BeliefState object (the same object that was passed in)

        """
        # NOTE(review): raises AttributeError when beliefstate is None (the default)
        beliefstate.start_new_turn()

        if user_acts is None:
            # this check is required in case the BST is the first called module
            # e.g. usersimulation on semantic level:
            #   dialog acts as outputs -> no NLU
            return {'beliefstate': beliefstate}

        self._reset_informs(acts=user_acts, beliefstate=beliefstate)
        self._update_methods(beliefstate, user_acts)

        # TODO user acts should include probabilities and beliefstate should
        # update probabilities instead of always choosing 1.0

        # important to set these to zero since we don't want stale discourseAct
        for act in beliefstate['beliefs']['discourseAct']:
            beliefstate['beliefs']['discourseAct'][act] = 0.0
        beliefstate['beliefs']['discourseAct']['none'] = 1.0

        # forget the requests collected in the previous turn before the new
        # user acts are processed
        self.request_slots = {}

        self._handle_user_acts(beliefstate, user_acts, sys_act)

        beliefstate.update_num_dbmatches()
        return {'beliefstate': beliefstate}
コード例 #3
0
 def _expand_confirm(self, action_name: str, beliefstate: BeliefState):
     """ Build the system act for a confirm_*slot* action.

     The slot is the first name parsed out of ``action_name``. Its value is
     read from the most probable inform beliefs (queried with
     consider_NONE=False, threshold=0.0, max_results=1).
     """
     confirm_act = SysAct()
     confirm_act.type = SysActionType.Confirm
     slot_names = self._get_slotnames_from_actionname(action_name)
     slot = slot_names[0]
     top_beliefs = beliefstate.get_most_probable_inf_beliefs(
         consider_NONE=False, threshold=0.0, max_results=1)
     # the slot is looked up directly, there is no fallback value
     confirm_act.add_value(slot, top_beliefs[slot])
     return confirm_act
コード例 #4
0
ファイル: bst.py プロジェクト: wendywtchang/adviser
    def start_dialog(self, **kwargs):
        """
            Prepares the tracker for a new dialog.

            Clears self.inform_scores and self.request_slots, then creates a
            new belief state for self.domain.

            Returns:
                (dict): a dictionary with the single key 'beliefstate' whose value is
                        the new BeliefState object
        """
        # per-dialog bookkeeping starts out empty
        self.inform_scores, self.request_slots = {}, {}
        fresh_state = BeliefState(self.domain)
        return {'beliefstate': fresh_state}
コード例 #5
0
 def _expand_select(self, action_name: str, beliefstate: BeliefState):
     """ Build the system act for a select_*slot* action.

     The act offers the first two entries listed for the slot in the most
     probable system-requestable beliefs (queried with consider_NONE=False,
     threshold=0.0, max_results=2).
     """
     select_act = SysAct()
     select_act.type = SysActionType.Select
     slot = self._get_slotnames_from_actionname(action_name)[0]
     ranked = beliefstate.get_most_probable_sysreq_beliefs(
         consider_NONE=False, threshold=0.0, max_results=2)
     # read both options before touching the act, so a missing second option
     # fails before anything has been added
     options = (ranked[slot][0], ranked[slot][1])
     for option in options:
         select_act.add_value(slot, option)
     return select_act
コード例 #6
0
ファイル: policy_rl.py プロジェクト: zhiyin121/adviser
    def _expand_confirm(self, action_name: str, beliefstate: BeliefState):
        """ Build the system act for a confirm_*slot* action.

        The slot is the first name parsed out of ``action_name``. Its value is
        taken from the most probable inform beliefs when the slot is present
        there, otherwise a random value is drawn from the domain.
        """
        confirm_act = SysAct()
        confirm_act.type = SysActionType.Confirm
        slot = self._get_slotnames_from_actionname(action_name)[0]
        top_beliefs = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=False, threshold=0.0, max_results=1)

        if slot in top_beliefs:
            value = top_beliefs[slot]
        else:
            # slot is missing from the beliefs: pick any value the domain
            # lists as possible for it
            value = random.choice(self.domain.get_possible_values(slot))

        confirm_act.add_value(slot, value)
        return confirm_act
コード例 #7
0
ファイル: policy_rl.py プロジェクト: zhiyin121/adviser
    def _expand_informbyalternatives(self, beliefstate: BeliefState):
        """ Build the system act for an inform_byalternatives action.

        Queries the domain with the most probable inform beliefs and offers
        the first matching entity whose primary key is not yet listed in
        self.sys_state['informedPrimKeyValsSinceNone']. If the query has no
        match, or every match was already offered, the act carries the
        primary key with the value 'none'.
        """
        alt_act = SysAct()
        alt_act.type = SysActionType.InformByAlternatives
        # primary keys of entities that were already offered
        already_offered = set(self.sys_state['informedPrimKeyValsSinceNone'])
        constraints = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=True, threshold=0.7, max_results=1)
        informable = self._remove_dontcare_slots(constraints)
        entities = self.domain.find_entities(constraints)

        if not entities:
            # the query matched nothing: echo up to five of the constraints
            # together with primary key = 'none'
            echoed_slots = common.numpy.random.choice(
                list(informable.keys()),
                min(5, len(informable)),
                replace=False)
            for slot in echoed_slots:
                alt_act.add_value(slot, informable[slot])
            alt_act.add_value(self.primary_key, 'none')
            return alt_act

        for entity in entities:
            if entity[self.primary_key] in already_offered:
                # this one was mentioned before, try the next match
                continue

            sample_size = min(4 - (len(alt_act.slot_values)), len(informable))
            for slot in common.numpy.random.choice(
                    list(informable.keys()), sample_size, replace=False):
                alt_act.add_value(slot, informable[slot])

            # 'dontcare' constraints are only passed on while the act still
            # holds fewer than four slots
            dontcare_constraints = {
                slot: value
                for slot, value in constraints.items()
                if len(alt_act.slot_values) < 4 and value == 'dontcare'
            }
            entity_info = self.domain.find_info_about_entity(
                entity[self.primary_key],
                requested_slots=self.domain.get_informable_slots())[0]
            self._db_results_to_sysact(alt_act, dontcare_constraints,
                                       entity_info)
            return alt_act

        # every match had been offered already
        alt_act.add_value(self.primary_key, 'none')
        return alt_act
コード例 #8
0
ファイル: policy_rl.py プロジェクト: zhiyin121/adviser
    def _expand_confreq(self, action_name: str, beliefstate: BeliefState):
        """ Build the system act for a confreq_*confirmslot*_*requestslot* action.

        The first slot name parsed from ``action_name`` is confirmed with its
        value from the most probable inform beliefs; the second slot name is
        added without a value, as the request part.
        """
        confreq_act = SysAct()
        confreq_act.type = SysActionType.ConfirmRequest
        slot_names = self._get_slotnames_from_actionname(action_name)
        confirm_slot, request_slot = slot_names[0], slot_names[1]

        top_beliefs = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=False, threshold=0.0, max_results=1)
        # confirmation part: slot plus the value to be confirmed
        confreq_act.add_value(confirm_slot, top_beliefs[confirm_slot])
        # request part: slot only
        confreq_act.add_value(request_slot)
        return confreq_act
コード例 #9
0
ファイル: policy_rl.py プロジェクト: zhiyin121/adviser
    def _expand_byconstraints(self, beliefstate: BeliefState):
        """ Build an inform act from the current constraints.

        The domain is queried with the most probable inform beliefs. On a hit
        the act is filled from one matching entity, preferring the one stored
        in self.sys_state['lastInformedPrimKeyVal']. Without a hit the act
        echoes some of the constraints and sets the primary key to 'none'.
        """
        inform_act = SysAct()
        inform_act.type = SysActionType.Inform

        constraints = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=True, threshold=0.7, max_results=1)
        db_matches = self.domain.find_entities(constraints,
                                               requested_slots=constraints)

        if db_matches:
            # stick to the entity informed last if it is among the matches
            last_informed_hits = [
                entity for entity in db_matches
                if entity[self.primary_key] ==
                self.sys_state['lastInformedPrimKeyVal']
            ]
            if last_informed_hits:
                assert len(last_informed_hits) == 1
                chosen = last_informed_hits[0]
            else:
                # otherwise any of the matches will do
                chosen = common.random.choice(db_matches)
            self._db_results_to_sysact(inform_act, constraints, chosen)
            return inform_act

        # nothing matched: report up to five constraints (never the 'name'
        # slot) and mark the primary key as 'none'
        remaining = self._remove_dontcare_slots(constraints)
        sampled_slots = common.numpy.random.choice(
            list(remaining.keys()),
            min(5, len(remaining)),
            replace=False)
        for slot in sampled_slots:
            if slot != 'name':
                inform_act.add_value(slot, remaining[slot])
        inform_act.add_value(self.primary_key, 'none')
        return inform_act
コード例 #10
0
ファイル: policy_rl.py プロジェクト: zhiyin121/adviser
    def _expand_informbyname(self, beliefstate: BeliefState):
        """ Expand inform_byname action

        Picks a primary key value (from beliefstate['informs'] if present,
        otherwise self.sys_state['lastInformedPrimKeyVal']), looks up matching
        entities in the domain and fills the act with values of one randomly
        chosen match. If no entity can be found at all, the act echoes some of
        the constraints and sets the primary key to 'none'.

        Returns:
            the populated SysAct of type SysActionType.InformByName
        """
        act = SysAct()
        act.type = SysActionType.InformByName

        # get most probable entity primary key: the key of the entry with the
        # largest value in beliefstate['informs'][primary_key]
        if self.primary_key in beliefstate['informs']:
            primkeyval = sorted(
                beliefstate["informs"][self.primary_key].items(),
                key=lambda kv: kv[1],
                reverse=True)[0][0]
        else:
            # try to use previously informed name instead
            primkeyval = self.sys_state['lastInformedPrimKeyVal']
            # TODO change behaviour from here, because primkeyval might be "**NONE**" and this might be an entity in the database
        # find db entry by primary key
        constraints = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=True, threshold=0.7, max_results=1)

        # the primary key entry is written last, so it overrides a primary
        # key that may already be part of the constraints
        db_matches = self.domain.find_entities({
            **constraints, self.primary_key:
            primkeyval
        })
        # NOTE usually not needed to give all constraints (shouldn't make a difference)
        if not db_matches:
            # select random entity if none could be found
            # NOTE(review): the randomly drawn primkeyval is never used below;
            # the second query runs on the constraints only. Removing the draw
            # would change how much of the shared RNG is consumed - confirm
            # the intent before touching it.
            primkeyvals = self.domain.get_possible_values(self.primary_key)
            primkeyval = common.random.choice(primkeyvals)
            db_matches = self.domain.find_entities(
                constraints, self.domain.get_requestable_slots())
            # use knowledge from current belief state

        if not db_matches:
            # no results found: echo up to five non-dontcare constraints and
            # report primary key = 'none'
            filtered_slot_values = self._remove_dontcare_slots(constraints)
            for slot in common.numpy.random.choice(
                    list(filtered_slot_values.keys()),
                    min(5, len(filtered_slot_values)),
                    replace=False):
                act.add_value(slot, filtered_slot_values[slot])
            act.add_value(self.primary_key, 'none')
            return act

        # select random match and fetch its requestable slots
        db_match = common.random.choice(db_matches)
        db_match = self.domain.find_info_about_entity(
            db_match[self.primary_key],
            requested_slots=self.domain.get_requestable_slots())[0]

        # get slots requested by user
        usr_requests = beliefstate.get_requested_slots()
        # remove primary key (to exclude from minimum number) since it is added anyway at the end
        if self.primary_key in usr_requests:
            usr_requests.remove(self.primary_key)
        if usr_requests:
            # add up to four user requested values into system act using db result
            for req_slot in common.numpy.random.choice(usr_requests,
                                                       min(
                                                           4,
                                                           len(usr_requests)),
                                                       replace=False):
                if req_slot in db_match:
                    act.add_value(req_slot, db_match[req_slot])
                else:
                    # the entity has no entry for the requested slot
                    act.add_value(req_slot, 'none')
        else:
            # no explicit requests: inform about up to four of the
            # non-dontcare constraints instead
            constraints = self._remove_dontcare_slots(constraints)
            if constraints:
                for inform_slot in common.numpy.random.choice(
                        list(constraints.keys()),
                        min(4, len(constraints)),
                        replace=False):
                    # NOTE(review): unguarded lookup, unlike the req_slot
                    # branch above - presumably every constraint slot is part
                    # of the db result; verify
                    value = db_match[inform_slot]
                    act.add_value(inform_slot, value)
            else:
                # add random slot and value if no user request was detected
                usr_requestable_slots = set(
                    self.domain.get_requestable_slots())
                # NOTE(review): set.remove raises KeyError if the primary key
                # is not a requestable slot
                usr_requestable_slots.remove(self.primary_key)
                random_slot = common.random.choice(list(usr_requestable_slots))
                value = db_match[random_slot]
                act.add_value(random_slot, value)
        # ensure entity primary key is included
        if self.primary_key not in act.slot_values:
            act.add_value(self.primary_key, db_match[self.primary_key])

        return act
コード例 #11
0
    def _expand_informbyalternatives(self, beliefstate: BeliefState):
        """ Build the system act for an inform_byalternatives action.

        Queries the domain with the most probable inform beliefs and offers
        the first matching entity whose primary key is not yet listed in
        beliefstate['system']['informedPrimKeyValsSinceNone']. If the query
        has no match, or every match was already offered, the act carries the
        primary key with the value 'none'.
        """
        act = SysAct()
        act.type = SysActionType.InformByAlternatives
        # primary keys of entities that were already offered
        seen_keys = set(beliefstate['system']['informedPrimKeyValsSinceNone'])
        candidates = beliefstate.get_most_probable_inf_beliefs(
            consider_NONE=True, threshold=0.7, max_results=1)
        concrete_values = self._remove_dontcare_slots(candidates)
        db_matches = self.domain.find_entities(candidates)

        if db_matches:
            # first match that has not been offered yet, if there is one
            fresh_match = next(
                (entity for entity in db_matches
                 if entity[self.primary_key] not in seen_keys), None)
            if fresh_match is None:
                # no alternatives found (that were not already mentioned)
                act.add_value(self.primary_key, 'none')
                return act

            # TODO move to _db_results_to_sysact method (including maximum inform slot-values)
            # add slots for which the value != 'dontcare'
            budget = min(4 - (len(act.slot_values)), len(concrete_values))
            chosen_slots = common.numpy.random.choice(
                list(concrete_values.keys()), budget, replace=False)
            for slot in chosen_slots:
                act.add_value(slot, concrete_values[slot])

            # 'dontcare' constraints are only passed on while the act still
            # holds fewer than four slots
            dontcare_constraints = {}
            for slot, value in candidates.items():
                if len(act.slot_values) >= 4:
                    continue
                if value == 'dontcare':
                    dontcare_constraints[slot] = value

            details = self.domain.find_info_about_entity(
                fresh_match[self.primary_key],
                requested_slots=self.domain.get_informable_slots())[0]
            self._db_results_to_sysact(act, dontcare_constraints, details)
            return act

        # the query matched nothing: echo up to five of the constraints
        # together with primary key = 'none'
        echoed_slots = common.numpy.random.choice(
            list(concrete_values.keys()),
            min(5, len(concrete_values)),
            replace=False)
        for slot in echoed_slots:
            act.add_value(slot, concrete_values[slot])
        act.add_value(self.primary_key, 'none')
        return act
コード例 #12
0
ファイル: bst.py プロジェクト: zhiyin121/adviser
 def __init__(self, domain=None, logger=None):
     """ Set up the tracker with an initial belief state.

     Args:
         domain: passed on to Service.__init__ and to the BeliefState constructor
         logger: stored as self.logger; not used inside this constructor
     """
     Service.__init__(self, domain=domain)
     self.logger = logger
     # belief state tracked by this instance
     self.bs = BeliefState(domain)
コード例 #13
0
def beliefstate(domain):
    """Return a new BeliefState created for ``domain``."""
    return BeliefState(domain)
コード例 #14
0
    def forward(self,
                dialog_graph,
                user_utterance: str = "",
                user_acts: List[UserAct] = None,
                beliefstate: BeliefState = None,
                sys_act: List[SysAct] = None,
                **kwargs) -> dict(beliefstate=BeliefState):
        """
            Updates the belief state for one turn, combining rule-based updates of the
            'method' and 'discourseAct' beliefs with the outputs of the slot trackers.

            Args:
                dialog_graph: only its num_turns attribute is read, to tell the trackers
                              whether this is the first turn
                user_utterance (str): raw user utterance; stripped, lower-cased and
                                      tokenized with nlp before it is given to the trackers
                user_acts (list): UserAct objects for the last user utterance; if None, the
                                  belief state is returned right after the turn was started
                beliefstate (BeliefState): state to update in place; a new one is created
                                           for self.domain if None
                sys_act: last system act(s), converted to text via self._sysact_to_list

            Returns:
                (dict): a dictionary with the key 'beliefstate' and the value the updated
                        BeliefState object
        """
        # initialize belief state
        if beliefstate is None:
            beliefstate = BeliefState(self.domain)
        else:
            beliefstate.start_new_turn()

        if user_acts is None:
            # this check is required in case the BST is the first called module
            # e.g. usersimulation on semantic level:
            #   dialog acts as outputs -> no NLU
            return {'beliefstate': beliefstate}

        # get all different action types in user inputs
        action_types = self._get_all_usr_action_types(user_acts)
        # update methods: exactly one method gets probability 1.0
        self._zero_all_scores(beliefstate['beliefs']['method'])
        if len(action_types) == 0:
            # no user actions
            beliefstate['beliefs']['method']['none'] = 1.0
        else:
            if UserActionType.RequestAlternatives in action_types:
                beliefstate['beliefs']['method']['byalternatives'] = 1.0
            elif UserActionType.Bye in action_types:
                beliefstate['beliefs']['method']['finished'] = 1.0
            elif UserActionType.Inform in action_types:
                # check if inform by primary key value or by constraints
                inform_primkeyval = \
                    [primkey_inform_act for primkey_inform_act in user_acts
                     if primkey_inform_act.type == UserActionType.Inform and
                        self.primary_key == primkey_inform_act.slot]
                if len(inform_primkeyval) > 0:
                    # inform by name
                    beliefstate['beliefs']['method']['byprimarykey'] = 1.0
                else:
                    # inform by constraints
                    beliefstate['beliefs']['method']['byconstraints'] = 1.0
            elif (UserActionType.Request in action_types or
                  UserActionType.Confirm in action_types) and \
                 UserActionType.Inform not in action_types and \
                 UserActionType.Deny not in action_types and \
                 not (beliefstate['system']['lastInformedPrimKeyVal'] == '**NONE**' or
                      beliefstate['system']['lastInformedPrimKeyVal'] == ''):
                # request/confirm about an entity that was informed before
                beliefstate['beliefs']['method']['byprimarykey'] = 1.0
            else:
                beliefstate['beliefs']['method']['none'] = 1.0

        # important to set these to zero since we don't want stale discourseAct
        for discourse_act in beliefstate['beliefs']['discourseAct']:
            beliefstate['beliefs']['discourseAct'][discourse_act] = 0.0
        beliefstate['beliefs']['discourseAct']['none'] = 1.0

        for act_in in user_acts:
            if act_in.type is None:
                pass
            elif act_in.type is UserActionType.Bad:
                beliefstate['beliefs']['discourseAct']['none'] = 0.0
                beliefstate['beliefs']['discourseAct']['bad'] = 1.0
            elif act_in.type == UserActionType.Hello:
                beliefstate['beliefs']['discourseAct']['none'] = 0.0
                beliefstate['beliefs']['discourseAct']['hello'] = 1.0
            elif act_in.type == UserActionType.Thanks:
                beliefstate['beliefs']['discourseAct']['none'] = 0.0
                beliefstate['beliefs']['discourseAct']['thanks'] = 1.0
            # TODO adapt user actions to dstc action names
            elif act_in.type in [UserActionType.Bye]:
                # nothing to do here, but needed to circumvent warning
                pass
            elif act_in.type == UserActionType.RequestAlternatives:
                # must test the current user act (act_in); the former check
                # used the leftover key variable of the reset loop above and
                # failed on its missing .type attribute
                pass
            elif act_in.type == UserActionType.Ack:
                beliefstate['beliefs']['discourseAct']['none'] = 0.0
                beliefstate['beliefs']['discourseAct']['ack'] = 1.0
            # else:
            #     # unknown Dialog Act
            #     # To be handled:
            #     self.logger.warning("user act not handled by BST: " + str(act_in))

        # track informables and requestables
        utterance = " ".join([
            tok.text for tok in nlp(user_utterance.strip().lower())
        ])  # tokenize
        # convert system utterance to text triples
        sys_utterance = []
        self._sysact_to_list(sys_act, sys_utterance)
        sys_utterance = " ".join(sys_utterance)

        # track informables
        for key in self.inf_trackers:
            output = self.inf_trackers[key].forward(
                sys_utterance,
                utterance,
                first_turn=(dialog_graph.num_turns == 0))
            output = output.squeeze()
            probabilities = F.softmax(output, dim=0)

            # copy beliefstate from network to complete beliefstate for policy
            for val in self.data_mappings.inf_values[key]:
                beliefstate['beliefs'][key][val] = probabilities[
                    self.data_mappings.get_informable_slot_value_index(
                        key, val)].item()

        # track requestables
        for key in self.req_trackers:
            output = self.req_trackers[key].forward(
                sys_utterance,
                utterance,
                first_turn=(dialog_graph.num_turns == 0))
            output = output.squeeze()
            probabilities = F.softmax(output, dim=0)
            # copy beliefstate from network to complete beliefstate for policy
            beliefstate['beliefs']['requested'][key] = probabilities[1].item(
            )  # true

        return {'beliefstate': beliefstate}
コード例 #15
0
 def start_dialog(self, **kwargs):
     """ Create the belief state for a new dialog.

     Returns:
         (dict): a dictionary with the single key 'beliefstate' whose value
                 is a new BeliefState created for self.domain
     """
     # initialize belief state
     return {'beliefstate': BeliefState(self.domain)}
コード例 #16
0
    def __init__(self,
                 domain: JSONLookupDomain,
                 subgraph=None,
                 buffer_cls=UniformBuffer,
                 buffer_size=6000,
                 batch_size=64,
                 discount_gamma=0.99,
                 include_confreq=False):
        """
        Sets up the state space, the action space and the experience replay
        buffer of the policy.

        Arguments:
            domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain

        Keyword Arguments:
            subgraph {[type]} -- [see modules.Module] (default: {None})
            buffer_cls {modules.policy.rl.experience_buffer.Buffer}
            -- [Experience replay buffer *class*, **not** an instance - it is
                instantiated by this constructor] (default: {UniformBuffer})
            buffer_size {int} -- [see modules.policy.rl.experience_buffer.
                                  Buffer] (default: {6000})
            batch_size {int} -- [see modules.policy.rl.experience_buffer.
                                  Buffer] (default: {64})
            discount_gamma {float} -- [Discount factor] (default: {0.99})
            include_confreq {bool} -- [Use confirm_request actions]
                                       (default: {False})
        """

        super(RLPolicy, self).__init__(domain, subgraph=subgraph)
        # evaluator used during training
        self.evaluator = ObjectiveReachedEvaluator(domain)

        self.buffer_size, self.batch_size = buffer_size, batch_size
        self.discount_gamma = discount_gamma

        self.writer = None

        # the state size is the width of the vector built from an initial
        # belief state
        initial_beliefs = BeliefState(domain)._init_beliefstate()
        self.state_dim = self.beliefstate_dict_to_vector(
            initial_beliefs).size(1)
        logger.info(f"state space dim: {self.state_dim}")

        # system action list, starting with the slot-independent actions
        self.actions = [
            "inform",
            "inform_byname",  # TODO rename to 'bykey'
            "inform_alternatives",
            "reqmore",
        ]
        # TODO badaction
        # TODO repeat not supported by user simulator
        for req_slot in self.domain.get_system_requestable_slots():
            self.actions.extend(
                prefix + req_slot
                for prefix in ('request#', 'confirm#', 'select#'))
            if not include_confreq:
                continue
            # one confreq action per confirm slot that differs from the
            # request slot
            self.actions.extend(
                'confreq#' + conf_slot + '#' + req_slot
                for conf_slot in self.domain.get_system_requestable_slots()
                if conf_slot != req_slot)

        # closingmsg is appended after the count on purpose: it is not part
        # of the learnable actions
        self.action_dim = len(self.actions)
        self.actions.append('closingmsg')
        logger.info(f"action space dim: {self.action_dim}")

        self.primary_key = self.domain.get_primary_key()

        # replay memory
        self.buffer = buffer_cls(buffer_size,
                                 batch_size,
                                 self.state_dim,
                                 discount_gamma=discount_gamma)

        self.last_sys_act = None
0
ファイル: policy_rl.py プロジェクト: zhiyin121/adviser
    def __init__(self,
                 domain: JSONLookupDomain,
                 buffer_cls=UniformBuffer,
                 buffer_size=6000,
                 batch_size=64,
                 discount_gamma=0.99,
                 max_turns: int = 25,
                 include_confreq=False,
                 logger: DiasysLogger = DiasysLogger(),
                 include_select: bool = False,
                 device=torch.device('cpu'),
                 obj_evaluator: ObjectiveReachedEvaluator = None):
        """
        Creates state- and action spaces, initializes experience replay
        buffers.

        Arguments:
            domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain

        Keyword Arguments:
            buffer_cls {services.policy.rl.experience_buffer.Buffer}
            -- [Experience replay buffer *class*, **not** an instance - will be
                initialized by this constructor!] (default: {UniformBuffer})
            buffer_size {int} -- [see services.policy.rl.experience_buffer.
                                  Buffer] (default: {6000})
            batch_size {int} -- [see services.policy.rl.experience_buffer.
                                  Buffer] (default: {64})
            discount_gamma {float} -- [Discount factor] (default: {0.99})
            max_turns {int} -- [stored as self.max_turns] (default: {25})
            include_confreq {bool} -- [Use confirm_request actions]
                                       (default: {False})
            logger {DiasysLogger} -- [logger used for the dimension messages;
                                      stored as self.logger]
            include_select {bool} -- [Use select actions] (default: {False})
            device {torch.device} -- [stored as self.device and passed to the
                                      replay buffer] (default: cpu)
            obj_evaluator {ObjectiveReachedEvaluator} -- [stored as
                                      self.evaluator] (default: {None})
        """

        self.device = device
        # Bookkeeping about what the system has informed so far. The expand
        # methods index this dict by these keys, so it must stay populated;
        # it used to be replaced by an empty dict at the end of this
        # constructor, which made those lookups fail with KeyError.
        self.sys_state = {
            "lastInformedPrimKeyVal": None,
            "lastActionInformNone": False,
            "offerHappened": False,
            'informedPrimKeyValsSinceNone': []
        }

        self.max_turns = max_turns
        self.logger = logger
        self.domain = domain
        # setup evaluator for training
        self.evaluator = obj_evaluator  #  ObjectiveReachedEvaluator(domain, logger=logger)

        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.discount_gamma = discount_gamma

        self.writer = None

        # get state size: width of the vector built from an initial belief state
        self.state_dim = self.beliefstate_dict_to_vector(
            BeliefState(domain)._init_beliefstate()).size(1)
        self.logger.info("state space dim: " + str(self.state_dim))

        # get system action list
        self.actions = [
            "inform_byname",  # TODO rename to 'bykey'
            "inform_alternatives",
            "reqmore"
        ]
        # TODO badaction
        for req_slot in self.domain.get_system_requestable_slots():
            self.actions.append('request#' + req_slot)
            self.actions.append('confirm#' + req_slot)
            if include_select:
                self.actions.append('select#' + req_slot)
            if include_confreq:
                for conf_slot in self.domain.get_system_requestable_slots():
                    if not req_slot == conf_slot:
                        # skip case where confirm slot = request slot
                        self.actions.append('confreq#' + conf_slot + '#' +
                                            req_slot)
        # counted before 'closingmsg' is appended:
        # don't include closingmsg in learnable actions
        self.action_dim = len(self.actions)
        self.actions.append('closingmsg')
        self.logger.info("action space dim: " + str(self.action_dim))

        self.primary_key = self.domain.get_primary_key()

        # init replay memory
        self.buffer = buffer_cls(buffer_size,
                                 batch_size,
                                 self.state_dim,
                                 discount_gamma=discount_gamma,
                                 device=device)

        self.last_sys_act = None
0
ファイル: bst.py プロジェクト: zhiyin121/adviser
class HandcraftedBST(Service):
    """
    A rule-based approach to belief state tracking.
    """
    def __init__(self, domain=None, logger=None):
        """
            Creates the tracker together with a fresh belief state.

            Args:
                domain: forwarded to Service.__init__ and used to construct the initial
                        BeliefState
                logger: stored on the instance as self.logger
        """
        Service.__init__(self, domain=domain)
        self.logger = logger
        # belief state modified by update_bst and replaced by dialog_start
        self.bs = BeliefState(domain)

    @PublishSubscribe(sub_topics=["user_acts"], pub_topics=["beliefstate"])
    def update_bst(self, user_acts: List[UserAct] = None) \
            -> dict(beliefstate=BeliefState):
        """
            Starts a new turn on the belief state and applies the latest user acts to it.

            When user acts are given, informs for re-informed slots and all previous
            requests are cleared, the set of user act types is stored under "user_acts",
            the acts are applied, and the values reported by get_num_dbmatches() are stored
            under "num_matches" and "discriminable".

            Args:
                user_acts (list): a list of UserAct objects mapped from the user's last
                                  utterance; if None or empty only the new turn is started

            Returns:
                (dict): a dictionary with the key "beliefstate" and the value the updated
                        BeliefState object

        """
        # start a new turn (keeps the last turn in memory)
        self.bs.start_new_turn()

        if not user_acts:
            # no acts to apply: publish the belief state as it is after starting the turn
            return {'beliefstate': self.bs}

        # clear what the new acts supersede before applying them
        self._reset_informs(user_acts)
        self._reset_requests()
        self.bs["user_acts"] = self._get_all_usr_action_types(user_acts)

        self._handle_user_acts(user_acts)

        # get_num_dbmatches() yields a pair: (number of entries, discriminable)
        self.bs["num_matches"], self.bs["discriminable"] = self.bs.get_num_dbmatches()

        return {'beliefstate': self.bs}

    def dialog_start(self):
        """
            Resets the belief state so it is ready for a new dialog.

            Replaces self.bs with a new BeliefState built for self.domain. The method has
            no return value; the new belief state is only stored on the instance.
        """
        # initialize belief state
        self.bs = BeliefState(self.domain)

    def _reset_informs(self, acts: List[UserAct]):
        """
            Removes the stored informs of every slot that the user informs again in this
            turn, so that the old entry does not survive next to the new value.

            Args:
                acts (List[UserAct]): the user acts of the current turn
        """
        # collect the slots mentioned by Inform acts
        informed_slots = set()
        for user_act in acts:
            if user_act.type == UserActionType.Inform:
                informed_slots.add(user_act.slot)

        # snapshot the affected keys first so the dict is not modified while iterating it
        outdated = [name for name in self.bs['informs'] if name in informed_slots]
        for name in outdated:
            del self.bs['informs'][name]

    def _reset_requests(self):
        """
            Clears the requests stored in the belief state by replacing them with an
            empty dict, so requests from the previous turn are not carried over.
        """
        self.bs['requests'] = dict()

    def _get_all_usr_action_types(
            self, user_acts: List[UserAct]) -> Set[UserActionType]:
        """
        Collects the distinct action types occurring in user_acts.

        Args:
            user_acts (List[UserAct]): list of UserAct objects

        Returns:
            set containing the `type` of every act (empty if user_acts is empty)
        """
        return {user_act.type for user_act in user_acts}

    def _handle_user_acts(self, user_acts: List[UserAct]):
        """
            Updates the belief state based on the information contained in the user act(s)

            Reads self.bs["user_acts"] (the set of act types of this turn), so that entry
            has to be set before this method is called, as update_bst does.

            Args:
                user_acts (list[UserAct]): the list of user acts to use to update the belief state

        """

        # reset any offers if the user informs any new information:
        # the inform stored for the primary key slot is dropped
        if self.domain.get_primary_key() in self.bs['informs'] \
                and UserActionType.Inform in self.bs["user_acts"]:
            del self.bs['informs'][self.domain.get_primary_key()]

        # We choose to interpret switching as wanting to start a new dialog and do not support
        # resuming an old dialog
        # NOTE: as an elif, this full reset only runs when the condition above was false
        elif UserActionType.SelectDomain in self.bs["user_acts"]:
            self.bs["informs"] = {}
            self.bs["requests"] = {}

        # Handle user acts
        for act in user_acts:
            if act.type == UserActionType.Request:
                # remember the requested slot together with the act's score
                self.bs['requests'][act.slot] = act.score
            elif act.type == UserActionType.Inform:
                # add informs and their scores to the beliefstate
                if act.slot in self.bs["informs"]:
                    self.bs['informs'][act.slot][act.value] = act.score
                else:
                    self.bs['informs'][act.slot] = {act.value: act.score}
            elif act.type == UserActionType.NegativeInform:
                # remove the negated value from the slot; the slot entry itself is kept,
                # even if no values remain in it
                if act.slot in self.bs['informs']:
                    if act.value in self.bs['informs'][act.slot]:
                        del self.bs['informs'][act.slot][act.value]
            elif act.type == UserActionType.RequestAlternatives:
                # This way it is clear that the user is no longer asking about that one item
                if self.domain.get_primary_key() in self.bs['informs']:
                    del self.bs['informs'][self.domain.get_primary_key()]