Exemplo n.º 1
0
    def generate_dict(self):
        """Build dialog-act lookup tables and compute the state-vector size.

        Reads ``self.sys_da_voc`` and ``self.usr_da_voc`` and sets:

        * ``sys_da2id`` / ``id2sys_da`` and ``usr_da2id`` / ``id2usr_da``:
          label -> position and position -> label maps over each vocabulary;
        * ``sys_da_dim`` / ``usr_da_dim``: the vocabulary lengths;
        * ``belief_state_dim``: total number of slots summed over every domain
          in ``default_state()['belief_state']``;
        * ``db_res_dim``: fixed at 4;
        * ``state_dim``: the sum of all dimensions above plus one extra entry
          for the "terminated" flag.
        """
        # System-side dialog acts.
        self.sys_da2id = {label: pos for pos, label in enumerate(self.sys_da_voc)}
        self.id2sys_da = {pos: label for pos, label in enumerate(self.sys_da_voc)}
        self.sys_da_dim = len(self.sys_da_voc)  # original note: 155 for its vocabulary

        # User-side dialog acts.
        self.usr_da2id = {label: pos for pos, label in enumerate(self.usr_da_voc)}
        self.id2usr_da = {pos: label for pos, label in enumerate(self.usr_da_voc)}
        self.usr_da_dim = len(self.usr_da_voc)  # original note: 142 for its vocabulary

        # One entry per slot of every domain (original note: 26).
        default_belief = default_state()['belief_state']
        self.belief_state_dim = sum(len(slots) for slots in default_belief.values())

        # NOTE(review): presumably the width of a DB-result feature block — confirm.
        self.db_res_dim = 4

        # Trailing "+ 1" reserves one entry for the terminated flag.
        self.state_dim = (
            self.sys_da_dim
            + self.usr_da_dim
            + self.belief_state_dim
            + self.db_res_dim
            + 1
        )
Exemplo n.º 2
0
Arquivo: dst.py Projeto: zqwerty/tatk
    def update(self, usr_da=None):
        """Update the dialog state from the latest user dialog acts.

        Updates ``cur_domain``, ``belief_state`` and ``request_slots`` in
        ``self.state``. It uses the user acts ``usr_da`` and the previous
        system acts stored in ``self.state['system_action']``.

        :param usr_da: user dialog acts, each an ``[intent, domain, slot, value]``
            sequence. ``None`` is treated as "no acts" (it previously raised
            TypeError).
        :return: ``self.state``. It is the same dict object, mutated in place.
        """
        if usr_da is None:
            usr_da = []
        self.state['user_action'] = usr_da
        sys_da = self.state['system_action'] or []

        # Current-domain priority: user Select > user Request > user Inform >
        # system Inform/Recommend. Within a group, the most frequent domain wins.
        # If no act names a domain, the current domain becomes None.
        candidates = (
            Counter(x[1] for x in usr_da if x[0] == 'Select'),
            Counter(x[1] for x in usr_da if x[0] == 'Request'),
            Counter(x[1] for x in usr_da if x[0] == 'Inform'),
            Counter(x[1] for x in sys_da if x[0] in ('Inform', 'Recommend')),
        )
        self.state['cur_domain'] = next(
            (counts.most_common(1)[0][0] for counts in candidates if counts), None)

        # A NoOffer without any Inform means nothing matched the constraints:
        # reset the current domain's belief state to its default.
        sys_intents = {x[0] for x in sys_da}
        cur_domain = self.state['cur_domain']
        if 'NoOffer' in sys_intents and 'Inform' not in sys_intents and cur_domain:
            self.state['belief_state'][cur_domain] = deepcopy(
                default_state()['belief_state'][cur_domain])

        # Drop request slots the system has already answered. Comparing as
        # tuples makes this work whether acts are stored as lists or tuples.
        # The set is built once instead of once per request slot.
        answered = {(x[1], x[2]) for x in sys_da if x[0] in ('Inform', 'Recommend')}
        request_slots = self.state['request_slots']
        request_slots[:] = [rs for rs in request_slots if tuple(rs) not in answered]

        # Domain switch: "Select" carries the source domain in ``value``. If that
        # domain already has a chosen name, record it as the "周边<domain>" slot.
        for intent, domain, slot, value in usr_da:
            if intent != 'Select':
                continue
            from_domain = value
            name = self.state['belief_state'][from_domain]['名称']
            if not name:
                continue
            if domain == from_domain:
                self.state['belief_state'][domain] = deepcopy(
                    default_state()['belief_state'][domain])
            self.state['belief_state'][domain]['周边{}'.format(from_domain)] = name

        # Slots whose Inform value simply overwrites the stored constraint.
        overwrite_slots = ('名称', '游玩时间', '酒店类型', '出发地', '目的地',
                           '评分', '门票', '价格', '人均消费')
        belief = self.state['belief_state']
        for intent, domain, slot, value in usr_da:
            if intent == 'Request':
                request_slots.append([domain, slot])
            elif intent != 'Inform':
                continue
            elif slot in overwrite_slots:
                belief[domain][slot] = value
            elif slot == '推荐菜':
                # Recommended dishes accumulate as a space-separated list.
                if belief[domain][slot]:
                    belief[domain][slot] += ' ' + value
                else:
                    belief[domain][slot] = value
            elif '酒店设施' in slot and value == '是':
                # Facility slots look like "酒店设施-<facility>". Skip any slot
                # without the suffix instead of raising IndexError.
                parts = slot.split('-')
                if len(parts) < 2:
                    continue
                faci = parts[1]
                if belief[domain]['酒店设施']:
                    belief[domain]['酒店设施'] += ' ' + faci
                else:
                    belief[domain]['酒店设施'] = faci

        return self.state
Exemplo n.º 3
0
Arquivo: dst.py Projeto: zqwerty/tatk
 def init_session(self, state=None):
     """Reset ``self.state`` at the start of a dialog session.

     :param state: optional state to start from. A truthy value is deep-copied,
         so later updates never mutate the caller's object. ``None`` or any other
         falsy value (e.g. an empty dict) falls back to a fresh
         ``tatk.util.crosswoz.state.default_state()``.
     """
     if state:
         self.state = deepcopy(state)
     else:
         self.state = default_state()
Exemplo n.º 4
0
Arquivo: dst.py Projeto: zqwerty/tatk
 def __init__(self):
     """Initialize the instance with a default state and a database handle.

     Calls the parent initializer, sets ``self.state`` to a fresh
     ``default_state()``, then constructs ``Database()``.

     NOTE(review): ``Database`` is defined elsewhere; it presumably loads the
     CrossWOZ data. Confirm before assuming construction is cheap.
     """
     super().__init__()
     self.state = default_state()
     self.database = Database()
     self.database = Database()
Exemplo n.º 5
0
                            return False
                elif options.get('type') == 'multiple_in':
                    s = options['params']
                    if not s is None:
                        if absence:
                            return False
                        sarr = list(filter(lambda t: bool(t), s.split(' ')))
                        if len(list(filter(lambda t: contains(val, t), sarr))):
                            return False
                else:
                    s = options['params']
                    if not s is None:
                        if absence:
                            return False
                        if val.find(s) < 0:
                            return False
            return True

        return list(filter(func3, db))


if __name__ == '__main__':
    # Ad-hoc check: count how often each recommended dish ('推荐菜') appears
    # across the restaurant ('餐馆') query results for the default belief state.
    # Counts are accumulated in a Counter directly instead of using a
    # hand-rolled setdefault dict; the printed result is unchanged.
    db = Database()
    state = default_state()
    dish_counts = Counter()
    for _, record in db.query(state['belief_state'], '餐馆'):
        for dish in record['推荐菜']:
            dish_counts[dish] += 1
    pprint(dish_counts)