Ejemplo n.º 1
0
class RuleDST(DST):
    """Rule based DST which trivially updates new values from NLU result to states.

    Attributes:
        state(dict):
            Dialog state. Function ``tatk.util.crosswoz.state.default_state`` returns a default state.
        database:
            ``Database`` instance used by :meth:`query`.
    """

    # Slots whose informed value simply overwrites the tracked value.
    _OVERWRITE_SLOTS = (
        '名称', '游玩时间', '酒店类型', '出发地', '目的地', '评分', '门票', '价格',
        '人均消费',
    )

    def __init__(self):
        """Call the parent initialiser and create the default state and database."""
        super().__init__()
        self.state = default_state()
        self.database = Database()

    def init_session(self, state=None):
        """Reset ``self.state`` for a new session.

        :param state: optional state to start from; a deep copy is stored so the
            caller's object is never mutated. A falsy value (``None``, ``{}``)
            selects a fresh ``default_state()`` instead.
        """
        self.state = default_state() if not state else deepcopy(state)

    @staticmethod
    def _pick_cur_domain(usr_da, sys_da):
        """Return the most frequent domain of the first non-empty intent group, or None.

        Groups are tried in priority order: user ``Select``, user ``Request``,
        user ``Inform``, then system ``Inform``/``Recommend``.
        """
        priority = (
            (usr_da, ('Select',)),
            (usr_da, ('Request',)),
            (usr_da, ('Inform',)),
            (sys_da, ('Inform', 'Recommend')),
        )
        for acts, intents in priority:
            counts = Counter(act[1] for act in acts if act[0] in intents)
            if counts:
                return counts.most_common(1)[0][0]
        return None

    @staticmethod
    def _append_token(slots, key, token):
        """Store ``token`` under ``key``, space-separated from any existing value."""
        current = slots[key]
        slots[key] = current + ' ' + token if current else token

    def update(self, usr_da=None):
        """Update ``user_action``, ``cur_domain``, ``request_slots`` and ``belief_state``.

        :param usr_da: user dialog acts, an iterable of
            ``[intent, domain, slot, value]`` items. ``None`` (the default) is
            treated as "no user acts" and stored as an empty list, so the
            method no longer fails when called without arguments.
        :return: the updated ``self.state`` dict (the same object, mutated in place).
        """
        if usr_da is None:
            usr_da = []
        self.state['user_action'] = usr_da
        sys_da = self.state['system_action']
        # Alias only: the dict is mutated in place, never rebound.
        belief_state = self.state['belief_state']

        cur_domain = self._pick_cur_domain(usr_da, sys_da)
        self.state['cur_domain'] = cur_domain

        # The system offered nothing and informed nothing: drop every
        # constraint of the current domain so the user can start over.
        sys_intents = [act[0] for act in sys_da]
        no_offer = 'NoOffer' in sys_intents and 'Inform' not in sys_intents
        if no_offer and cur_domain:
            belief_state[cur_domain] = deepcopy(
                default_state()['belief_state'][cur_domain])

        # Forget requested slots the system has just answered. The answered
        # pairs do not change inside the loop, so build them once.
        answered = [
            act[1:3] for act in sys_da if act[0] in ('Inform', 'Recommend')
        ]
        # Iterate over a copy because entries are removed from the live list.
        for domain, slot in list(self.state['request_slots']):
            if [domain, slot] in answered:
                self.state['request_slots'].remove([domain, slot])

        # Domain switch: a Select act carries the source domain in ``value``.
        # This pass must finish before the Inform pass below, because it may
        # reset a domain's constraints.
        for intent, domain, slot, value in usr_da:
            if intent != 'Select':
                continue
            from_domain = value
            name = belief_state[from_domain]['名称']
            if not name:
                continue
            if domain == from_domain:
                belief_state[domain] = deepcopy(
                    default_state()['belief_state'][domain])
            belief_state[domain]['周边{}'.format(from_domain)] = name

        for intent, domain, slot, value in usr_da:
            if intent == 'Inform':
                if slot in self._OVERWRITE_SLOTS:
                    belief_state[domain][slot] = value
                elif slot == '推荐菜':
                    # Recommended dishes accumulate instead of overwriting.
                    self._append_token(belief_state[domain], slot, value)
                elif '酒店设施' in slot:
                    # Only facilities informed with the value '是' are
                    # recorded; the facility name is the part of the slot
                    # after the first '-'.
                    if value == '是':
                        faci = slot.split('-')[1]
                        self._append_token(belief_state[domain], '酒店设施', faci)
            elif intent == 'Request':
                self.state['request_slots'].append([domain, slot])

        return self.state

    def query(self):
        """Query the database with the current belief state and current domain."""
        return self.database.query(self.state['belief_state'],
                                   self.state['cur_domain'])
Ejemplo n.º 2
0
class CrossWozVector(Vector):
    """Convert dialog states and dialog acts to and from fixed-size vectors."""

    def __init__(self, sys_da_voc_json, usr_da_voc_json):
        """Load both dialog-act vocabularies and build the lookup tables.

        :param sys_da_voc_json: path to the JSON file with the system dialog-act vocabulary.
        :param usr_da_voc_json: path to the JSON file with the user dialog-act vocabulary.
        """
        # ``with`` closes the handles deterministically; JSON text is UTF-8,
        # so do not depend on the platform's default encoding.
        with open(sys_da_voc_json, encoding='utf-8') as voc_file:
            self.sys_da_voc = json.load(voc_file)
        with open(usr_da_voc_json, encoding='utf-8') as voc_file:
            self.usr_da_voc = json.load(voc_file)
        self.database = Database()

        self.generate_dict()

    def generate_dict(self):
        """Build act<->id mappings and compute the dimension of every vector part."""
        self.sys_da2id = {act: i for i, act in enumerate(self.sys_da_voc)}
        self.id2sys_da = dict(enumerate(self.sys_da_voc))

        # 155 according to the original author's note
        self.sys_da_dim = len(self.sys_da_voc)

        self.usr_da2id = {act: i for i, act in enumerate(self.usr_da_voc)}
        self.id2usr_da = dict(enumerate(self.usr_da_voc))

        # 142 according to the original author's note
        self.usr_da_dim = len(self.usr_da_voc)

        # One entry per slot of every domain (26 according to the original
        # author's note).
        self.belief_state_dim = sum(
            len(svs) for svs in default_state()['belief_state'].values())

        self.db_res_dim = 4

        # The trailing +1 is the ``terminated`` flag.
        self.state_dim = (self.sys_da_dim + self.usr_da_dim
                          + self.belief_state_dim + self.db_res_dim + 1)

    @staticmethod
    def _multi_hot(da, da2id, dim):
        """Return a ``dim``-long vector with 1. at the id of every known delexicalized act."""
        vec = np.zeros(dim)
        for act in delexicalize_da(da):
            # Acts missing from the vocabulary are silently ignored.
            if act in da2id:
                vec[da2id[act]] = 1.
        return vec

    def state_vectorize(self, state):
        """Encode a dialog state as a flat vector.

        Reads ``belief_state``, ``cur_domain``, ``user_action``,
        ``system_action`` and ``terminated`` from ``state``. Also caches
        ``self.belief_state``, ``self.cur_domain`` and ``self.db_res``, which
        :meth:`action_devectorize` relies on.

        :param state: dialog state dict.
        :return: concatenation of [user acts, system acts, belief state,
            database result bucket, terminated flag].
        """
        self.belief_state = state['belief_state']
        self.cur_domain = state['cur_domain']

        usr_act_vec = self._multi_hot(
            state['user_action'], self.usr_da2id, self.usr_da_dim)
        sys_act_vec = self._multi_hot(
            state['system_action'], self.sys_da2id, self.sys_da_dim)

        # One position per slot, in dict iteration order; 1. marks a filled slot.
        belief_state_vec = np.zeros(self.belief_state_dim)
        slot_values = (value
                       for svs in state['belief_state'].values()
                       for value in svs.values())
        for i, value in enumerate(slot_values):
            if value:
                belief_state_vec[i] = 1.

        self.db_res = self.database.query(state['belief_state'], state['cur_domain'])
        db_res_num = len(self.db_res)
        # One-hot bucket of the number of matches: 0, 1, 2-4, 5 or more.
        if db_res_num == 0:
            bucket = 0
        elif db_res_num == 1:
            bucket = 1
        elif db_res_num < 5:
            bucket = 2
        else:
            bucket = 3
        db_res_vec = np.zeros(self.db_res_dim)
        db_res_vec[bucket] = 1.

        terminated = 1. if state['terminated'] else 0.

        return np.r_[usr_act_vec, sys_act_vec, belief_state_vec, db_res_vec, terminated]

    def action_devectorize(self, action_vec):
        """Turn a system action vector back into lexicalized dialog acts.

        ``state_vectorize`` must have been called before, because this method
        reads the ``self.cur_domain`` and ``self.db_res`` it caches.

        :param action_vec: sequence with one entry per system act; entries
            equal to 1 select the corresponding act.
        :return: the result of ``lexicalize_da`` for the selected acts.
        """
        da = [self.id2sys_da[i]
              for i, flag in enumerate(action_vec) if flag == 1]
        return lexicalize_da(da=da, cur_domain=self.cur_domain, entities=self.db_res)

    def action_vectorize(self, da):
        """Encode system dialog acts as a multi-hot vector of length ``sys_da_dim``.

        :param da: dialog acts, delexicalized here before the vocabulary lookup.
        :return: vector with 1. at the id of every act found in the system vocabulary.
        """
        return self._multi_hot(da, self.sys_da2id, self.sys_da_dim)
Ejemplo n.º 3
0
 def __init__(self):
     """Run the parent initialiser, then create the default state and a database handle."""
     super().__init__()
     # The state starts as whatever ``default_state()`` returns.
     self.state = default_state()
     self.database = Database()
Ejemplo n.º 4
0
 def __init__(self, sys_da_voc_json, usr_da_voc_json):
     """Load both dialog-act vocabularies, create the database and build the lookup tables.

     :param sys_da_voc_json: path to the JSON file with the system dialog-act vocabulary.
     :param usr_da_voc_json: path to the JSON file with the user dialog-act vocabulary.
     """
     # ``with`` closes the handles deterministically; JSON text is UTF-8,
     # so do not depend on the platform's default encoding.
     with open(sys_da_voc_json, encoding='utf-8') as voc_file:
         self.sys_da_voc = json.load(voc_file)
     with open(usr_da_voc_json, encoding='utf-8') as voc_file:
         self.usr_da_voc = json.load(voc_file)
     self.database = Database()

     self.generate_dict()