コード例 #1
0
class StateTracker(object):
    """Track the dialog state shared by a simulated user and system.

    Dialog acts arrive as multi-hot vectors.  ``update_belief_sys`` and
    ``update_belief_usr`` decode them through the config's index-to-act tables
    and return an updated *copy* of the state dict.  Unless locked (see
    ``set_rollout``), the decoded acts are also passed to a
    ``MultiWozEvaluator``.  ``reset`` and ``step`` are placeholders here.
    """

    def __init__(self, data_dir, config):
        """
        Args:
            data_dir: data directory handed to ``DBQuery`` and
                ``MultiWozEvaluator``.
            config: config object; this class reads ``mapping``, ``idx2da``,
                ``idx2da_u``, ``a_dim_usr`` and ``belief_domains`` from it.
        """
        self.time_step = 0
        self.cfg = config
        self.db = DBQuery(data_dir, config)
        self.topic = ''  # domain currently in focus; '' until one is seen
        self.evaluator = MultiWozEvaluator(data_dir)
        # When True, decoded acts are not forwarded to the evaluator.
        # (The misspelled attribute name is kept for backward compatibility.)
        self.lock_evalutor = False

    def set_rollout(self, rollout):
        """Enter (``rollout=True``) or leave (``rollout=False``) rollout mode.

        Entering saves the turn counter and topic and stops forwarding acts
        to the evaluator.  Leaving restores both saved values and re-enables
        the evaluator.
        """
        if rollout:
            self.save_time_step = self.time_step
            self.save_topic = self.topic
            self.lock_evalutor = True
        else:
            self.time_step = self.save_time_step
            # Restore the topic saved on entry.  This previously overwrote the
            # saved copy instead, so topic changes made during a rollout leaked.
            self.topic = self.save_topic
            self.lock_evalutor = False

    def get_entities(self, s, domain):
        """Query the database for ``domain`` using the constraints in ``s``.

        Belief-state slots are skipped if their value is ``'?'`` or they have
        no entry in ``cfg.mapping[domain]``.  The remaining slots are
        translated to DB field names via that mapping.

        Returns:
            the entities returned by ``self.db.query``, shuffled in place.
        """
        constraint = [
            (self.cfg.mapping[domain][k], v)
            for k, v in s['belief_state'][domain].items()
            if v != '?' and k in self.cfg.mapping[domain]
        ]
        entities = self.db.query(domain, constraint)
        random.shuffle(entities)
        return entities

    def update_belief_sys(self, old_s, a):
        """Apply a system action to a copy of the dialog state.

        Args:
            old_s: current state dict (not modified).
            a: multi-hot system-action tensor.  Each non-zero index is looked
                up in ``cfg.idx2da`` as a ``domain-intent-slot-p`` string.

        Returns:
            the updated deep copy of ``old_s``.  In it, ``sys_action`` is
            rebuilt, ``belief_state``/``goal_state`` are updated and
            ``others['turn']`` is advanced.
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices

        self.time_step += 1
        s['others']['turn'] = self.time_step

        # update sys/user dialog act
        s['sys_action'] = dict()

        # Decode the acts and group them by domain.  A bare ``sorted(...)``
        # discards its result, so the sort has to happen in place.
        das = [self.cfg.idx2da[idx.item()].split('-') for idx in a_index]
        das.sort(key=lambda x: x[0])

        entities = [] if self.topic == '' else self.get_entities(s, self.topic)
        return_flag = False
        for domain, intent, slot, p in das:
            # switching to a new belief domain re-queries the candidate entities
            if domain in self.cfg.belief_domains and domain != self.topic:
                self.topic = domain
                entities = self.get_entities(s, domain)

            da = '-'.join((domain, intent, slot, p))
            if intent == 'request':
                s['sys_action'][da] = '?'
            elif intent in ['nooffer', 'nobook'] and self.topic != '':
                return_flag = True
                if slot in s['belief_state'][self.topic] and s['belief_state'][
                        self.topic][slot] != '?':
                    s['sys_action'][da] = s['belief_state'][self.topic][slot]
                else:
                    s['sys_action'][da] = 'none'
            elif slot == 'choice':
                s['sys_action'][da] = str(len(entities))
            elif slot == 'none':
                s['sys_action'][da] = 'none'
            else:
                # ``p`` is parsed as a 1-based index into the entity list
                num = int(p) - 1
                if self.topic and len(entities) > num and slot in self.cfg.mapping[self.topic]:
                    typ = self.cfg.mapping[self.topic][slot]
                    s['sys_action'][da] = entities[num][typ] if typ in entities[num] else 'none'
                else:
                    s['sys_action'][da] = 'none'

                if not self.topic:
                    continue
                if intent in ['inform', 'recommend', 'offerbook', 'offerbooked', 'book']:
                    # presumably drops ``slot`` from the belief state when it is
                    # still '?' (i.e. the request has now been answered)
                    discard(s['belief_state'][self.topic], slot, '?')
                    if slot in s['user_goal'][self.topic] and s['user_goal'][
                            self.topic][slot] == '?':
                        s['goal_state'][self.topic][slot] = s['sys_action'][da]

                # booked
                if intent == 'inform' and slot == 'car':  # taxi
                    if 'booked' not in s['belief_state']['taxi']:
                        s['belief_state']['taxi']['booked'] = 'taxi-booked'
                elif intent in ['offerbooked', 'book'] and slot == 'ref':  # train
                    if self.topic in ['taxi', 'hospital', 'police']:
                        s['belief_state'][self.topic]['booked'] = f'{self.topic}-booked'
                        s['sys_action'][da] = f'{self.topic}-booked'
                    elif entities:
                        # the reference number's prefix names the booked domain
                        book_domain = entities[0]['ref'].split('-')[0]
                        if 'booked' not in s['belief_state'][book_domain]:
                            s['belief_state'][book_domain]['booked'] = entities[0]['ref']
                            s['sys_action'][da] = entities[0]['ref']

        if return_flag:
            # After a nooffer/nobook, undo the user's informs for the current
            # topic and rebuild that domain's goal state (via project helpers
            # ``discard``/``reload``).
            for da in s['user_action']:
                d_usr, i_usr, s_usr = da.split('-')
                if i_usr == 'inform' and d_usr == self.topic:
                    discard(s['belief_state'][d_usr], s_usr)
            reload(s['goal_state'], s['user_goal'], self.topic)

        if not self.lock_evalutor:
            self.evaluator.add_sys_da(s['sys_action'])

        return s

    def update_belief_usr(self, old_s, a):
        """Apply a user action to a copy of the dialog state.

        Args:
            old_s: current state dict (not modified).
            a: multi-hot user-action tensor.  Index ``cfg.a_dim_usr - 1`` marks
                the terminal action.  The other non-zero indices are looked up
                in ``cfg.idx2da_u`` as ``domain-intent-slot`` strings.

        Returns:
            the updated deep copy of ``old_s``.  In it, ``user_action`` is
            rebuilt, ``belief_state``/``goal_state`` are updated and
            ``others['turn']``/``others['terminal']`` are set.
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices

        self.time_step += 1
        s['others']['turn'] = self.time_step
        end_idx = self.cfg.a_dim_usr - 1  # index of the terminal action
        s['others']['terminal'] = 1 if end_idx in a_index else 0

        # update sys/user dialog act
        s['user_action'] = dict()

        das = [
            self.cfg.idx2da_u[idx.item()].split('-') for idx in a_index
            if idx.item() != end_idx
        ]
        # Once the user mentions ``next_available_domain``, promote the first
        # invisible domain to be the next available one.
        if s['invisible_domains']:
            for da in das:
                if da[0] == s['next_available_domain']:
                    s['next_available_domain'] = s['invisible_domains'][0]
                    s['invisible_domains'].remove(s['next_available_domain'])
                    break
        # Group acts by domain.  A bare ``sorted(...)`` discarded its result.
        das.sort(key=lambda x: x[0])

        for domain, intent, slot in das:
            if domain in self.cfg.belief_domains and domain != self.topic:
                self.topic = domain

            da = '-'.join((domain, intent, slot))
            if intent == 'request':
                s['user_action'][da] = '?'
                if self.topic:  # no belief domain seen yet -> nothing to mark
                    s['belief_state'][self.topic][slot] = '?'
            elif slot == 'none':
                s['user_action'][da] = 'none'
            else:
                # NOTE(review): the guard checks ``self.topic`` but the lookups
                # use ``domain``.  They coincide whenever ``domain`` is a belief
                # domain, because the topic was just switched to it above.
                if self.topic and slot in s['user_goal'][
                        self.topic] and s['user_goal'][domain][slot] != '?':
                    s['user_action'][da] = s['user_goal'][domain][slot]
                else:
                    s['user_action'][da] = 'dont care'

                if not self.topic:
                    continue
                if intent == 'inform':
                    s['belief_state'][domain][slot] = s['user_action'][da]
                    if slot in s['user_goal'][self.topic] and s['user_goal'][
                            self.topic][slot] != '?':
                        discard(s['goal_state'][self.topic], slot)

        if not self.lock_evalutor:
            self.evaluator.add_usr_da(s['user_action'])

        return s

    def reset(self, random_seed=None):
        """Placeholder for resetting the dialog; does nothing here.

        Args:
            random_seed (int):
        Returns:
            init_state (dict):
        """
        pass

    def step(self, s, sys_a):
        """Placeholder for advancing the dialog by one exchange; does nothing here.

        Args:
            s (dict):
            sys_a (vector):
        Returns:
            next_s (dict):
            terminal (bool):
        """
        pass
コード例 #2
0
class GoalGenerator:
    """User goal generator"""

    def __init__(self, data_dir,
                 goal_model_path,
                 corpus_path=None,
                 boldify=False):
        """Load (or build and cache) the statistical user-goal model.

        Args:
            data_dir: directory that ``goal_model_path`` and ``corpus_path`` are
                relative to; also passed to ``DBQuery``.
            goal_model_path: path (relative to ``data_dir``) of a pickled goal
                model.  If the file exists it is loaded.  Otherwise a model is
                built from ``corpus_path`` and written there.
            corpus_path: path (relative to ``data_dir``) of a JSON dialog corpus
                used to build the goal model when none is cached.
            boldify: if True, values in generated messages are wrapped with
                ``do_boldify``; otherwise ``null_boldify`` is used.

        Raises:
            ValueError: if no cached goal model exists and ``corpus_path`` is
                None, so there is nothing to build one from.
        """
        self.dbquery = DBQuery(data_dir)
        self.goal_model_path = data_dir + '/' + goal_model_path
        self.corpus_path = data_dir + '/' + corpus_path if corpus_path is not None else None
        self.boldify = do_boldify if boldify else null_boldify
        if os.path.exists(self.goal_model_path):
            # NOTE: pickle.load can execute arbitrary code; only load trusted model files.
            with open(self.goal_model_path, 'rb') as f:
                self.ind_slot_dist, self.ind_slot_value_dist, self.domain_ordering_dist, self.book_dist = pickle.load(f)
            logging.info('Loading goal model is done')
        else:
            if self.corpus_path is None:
                raise ValueError(
                    f'goal model {self.goal_model_path!r} not found and no corpus_path given to build it')
            self._build_goal_model()
            logging.info('Building goal model is done')

        # Drop these requestable slots so sampled goals never ask for them.
        # ``pop`` with a default tolerates goal models whose corpus never
        # contained them (a plain ``del`` raised KeyError there).
        for domain, slot in (('police', 'postcode'),
                             ('hospital', 'postcode'),
                             ('hospital', 'address')):
            for dist in (self.ind_slot_dist, self.ind_slot_value_dist):
                dist.get(domain, {}).get('reqt', {}).pop(slot, None)

    def _build_goal_model(self):
        """Estimate the goal-model distributions from the dialog corpus and cache them.

        Reads the JSON corpus at ``self.corpus_path`` and sets:

        * ``domain_ordering_dist`` - relative frequency of each tuple of
          domains.  Domains are ordered by the first goal-message line that
          mentions them.
        * ``ind_slot_dist`` - per domain and goal part (``info``/``reqt``/
          ``book``), the fraction of dialogs with a non-empty goal for that
          domain in which the slot appears.  ``info``/``book`` values
          containing "care" do not count, and slots containing "invalid" are
          skipped.
        * ``ind_slot_value_dist`` - per ``info``/``book`` slot, the relative
          frequency of each value.  For ``reqt`` it holds the same numbers as
          ``ind_slot_dist``.
        * ``book_dist`` - per domain, the fraction of all dialogs whose goal
          has a ``book`` part.

        The four objects are pickled to ``self.goal_model_path`` as
        ``(ind_slot_dist, ind_slot_value_dist, domain_ordering_dist, book_dist)``.
        """
        with open(self.corpus_path) as f:
            dialogs = json.load(f)

        # domain ordering
        def _get_dialog_domains(dialog):
            return [x for x in dialog['goal'] if x in domains and len(dialog['goal'][x]) > 0]

        domain_orderings = []
        for d in dialogs:
            d_domains = _get_dialog_domains(dialogs[d])
            first_mentions = []  # (index of first message line mentioning the domain, domain)
            if d_domains:
                message = dialogs[d]['goal']['message']
                if isinstance(message, str):
                    message = [message]
                for domain in d_domains:
                    for i, m in enumerate(message):
                        if domain_keywords[domain].lower() in m.lower() or domain.lower() in m.lower():
                            first_mentions.append((i, domain))
                            break
            # Keep each index paired with its own domain.  Zipping a separate
            # index list against ``d_domains`` misaligned them whenever a
            # domain was never mentioned.  Ties keep goal order (stable sort).
            domain_orderings.append(tuple(dom for _, dom in sorted(first_mentions, key=lambda x: x[0])))
        domain_ordering_cnt = Counter(domain_orderings)
        total_orderings = sum(domain_ordering_cnt.values())
        self.domain_ordering_dist = deepcopy(domain_ordering_cnt)
        for order, cnt in domain_ordering_cnt.items():
            self.domain_ordering_dist[order] = cnt / total_orderings

        # independent goal slot distribution
        ind_slot_value_cnt = {domain: {} for domain in domains}
        domain_cnt = Counter()
        book_cnt = Counter()

        def _count_valued(domain, part, values):
            # Count the values of an ``info``/``book`` goal part.  A slot is
            # registered even when its value is a "dontcare" that is not counted.
            for slot in values:
                if 'invalid' in slot:
                    continue
                slot_cnt = ind_slot_value_cnt[domain].setdefault(part, {}).setdefault(slot, Counter())
                if 'care' in values[slot]:
                    continue
                slot_cnt[values[slot]] += 1

        for d in dialogs:
            goal = dialogs[d]['goal']
            for domain in domains:
                if goal[domain] != {}:
                    domain_cnt[domain] += 1
                if 'info' in goal[domain]:
                    _count_valued(domain, 'info', goal[domain]['info'])
                if 'reqt' in goal[domain]:
                    for slot in goal[domain]['reqt']:
                        ind_slot_value_cnt[domain].setdefault('reqt', Counter())[slot] += 1
                if 'book' in goal[domain]:
                    book_cnt[domain] += 1
                    _count_valued(domain, 'book', goal[domain]['book'])

        self.ind_slot_value_dist = deepcopy(ind_slot_value_cnt)
        self.ind_slot_dist = {domain: {} for domain in domains}
        self.book_dist = {}

        def _normalize_valued(domain, part):
            # slot probability per domain occurrence + per-slot value distribution
            for slot, slot_cnt in ind_slot_value_cnt[domain][part].items():
                slot_total = sum(slot_cnt.values())
                self.ind_slot_dist[domain].setdefault(part, {})[slot] = slot_total / domain_cnt[domain]
                value_dist = self.ind_slot_value_dist[domain][part][slot]
                for val in value_dist:
                    value_dist[val] = slot_cnt[val] / slot_total

        for domain in domains:
            if 'info' in ind_slot_value_cnt[domain]:
                _normalize_valued(domain, 'info')
            if 'reqt' in ind_slot_value_cnt[domain]:
                for slot, cnt in ind_slot_value_cnt[domain]['reqt'].items():
                    prob = cnt / domain_cnt[domain]
                    self.ind_slot_dist[domain].setdefault('reqt', {})[slot] = prob
                    self.ind_slot_value_dist[domain]['reqt'][slot] = prob
            if 'book' in ind_slot_value_cnt[domain]:
                _normalize_valued(domain, 'book')
            self.book_dist[domain] = book_cnt[domain] / len(dialogs)

        with open(self.goal_model_path, 'wb') as f:
            pickle.dump((self.ind_slot_dist, self.ind_slot_value_dist, self.domain_ordering_dist, self.book_dist), f)

    def _get_domain_goal(self, domain):
        """Sample a goal for a single ``domain`` from the goal model.

        Resamples (loops) until the goal has at least one request slot or a
        booking section.  An attempt is also discarded when it ends up with
        empty ``info`` constraints, or when 100 adjustment trials (below) fail.

        Within an attempt:

        * with probability 0.5, hotel/restaurant goals with a ``book`` part get
          a ``fail_book`` variant (``book`` with one value shifted);
        * if the sampled ``info`` matches no database entity, up to 100 random
          single-slot changes via ``_adjust_info`` are tried.  On success the
          adjusted constraints become ``info`` and, except for ``train``, the
          original constraints are kept as ``fail_info``.

        Slot inclusion probabilities come from the goal model plus the global
        ``pro_correction`` offsets.  The order of ``random`` calls determines
        the result for a seeded generator, so statement order is significant.

        Args:
            domain: domain name; must be a key of ``self.ind_slot_dist``,
                ``self.ind_slot_value_dist`` and ``self.book_dist``.

        Returns:
            dict with an ``'info'`` key and optionally ``'reqt'`` (list of
            slots), ``'book'``, ``'fail_book'`` and ``'fail_info'``.
        """
        cnt_slot = self.ind_slot_dist[domain]
        cnt_slot_value = self.ind_slot_value_dist[domain]
        pro_book = self.book_dist[domain]

        while True:
            # domain_goal = defaultdict(lambda: {})
            # domain_goal = {'info': {}, 'fail_info': {}, 'reqt': {}, 'book': {}, 'fail_book': {}}
            domain_goal = {'info': {}}
            # inform
            if 'info' in cnt_slot:
                for slot in cnt_slot['info']:
                    if random.random() < cnt_slot['info'][slot] + pro_correction['info']:
                        domain_goal['info'][slot] = nomial_sample(cnt_slot_value['info'][slot])

                # A name plus other constraints: keep only the name (with the
                # probability of sampling 'name'), otherwise drop the name.
                if domain in ['hotel', 'restaurant', 'attraction'] and 'name' in domain_goal['info'] and len(
                        domain_goal['info']) > 1:
                    if random.random() < cnt_slot['info']['name']:
                        domain_goal['info'] = {'name': domain_goal['info']['name']}
                    else:
                        del domain_goal['info']['name']

                # Taxi/train: keep exactly one of arriveBy/leaveAt, chosen in
                # proportion to how often each slot occurs.
                if domain in ['taxi', 'train'] and 'arriveBy' in domain_goal['info'] and 'leaveAt' in domain_goal[
                    'info']:
                    if random.random() < (
                            cnt_slot['info']['leaveAt'] / (cnt_slot['info']['arriveBy'] + cnt_slot['info']['leaveAt'])):
                        del domain_goal['info']['arriveBy']
                    else:
                        del domain_goal['info']['leaveAt']

                # ... and make sure one of them is present.
                if domain in ['taxi', 'train'] and 'arriveBy' not in domain_goal['info'] and 'leaveAt' not in \
                        domain_goal['info']:
                    if random.random() < (cnt_slot['info']['arriveBy'] / (
                            cnt_slot['info']['arriveBy'] + cnt_slot['info']['leaveAt'])):
                        domain_goal['info']['arriveBy'] = nomial_sample(cnt_slot_value['info']['arriveBy'])
                    else:
                        domain_goal['info']['leaveAt'] = nomial_sample(cnt_slot_value['info']['leaveAt'])

                # Taxi/train always need a departure and a destination.
                if domain in ['taxi', 'train'] and 'departure' not in domain_goal['info']:
                    domain_goal['info']['departure'] = nomial_sample(cnt_slot_value['info']['departure'])

                if domain in ['taxi', 'train'] and 'destination' not in domain_goal['info']:
                    domain_goal['info']['destination'] = nomial_sample(cnt_slot_value['info']['destination'])

                # If departure == destination, resample one of them once.
                # NOTE(review): the resampled value may still coincide.
                if domain in ['taxi', 'train'] and \
                        'departure' in domain_goal['info'] and \
                        'destination' in domain_goal['info'] and \
                        domain_goal['info']['departure'] == domain_goal['info']['destination']:
                    if random.random() < (cnt_slot['info']['departure'] / (
                            cnt_slot['info']['departure'] + cnt_slot['info']['destination'])):
                        domain_goal['info']['departure'] = nomial_sample(cnt_slot_value['info']['departure'])
                    else:
                        domain_goal['info']['destination'] = nomial_sample(cnt_slot_value['info']['destination'])
                # no constraints sampled -> start a new attempt
                if domain_goal['info'] == {}:
                    continue
            # request
            if 'reqt' in cnt_slot:
                # slots already constrained in 'info' are not requested
                reqt = [slot for slot in cnt_slot['reqt']
                        if random.random() < cnt_slot['reqt'][slot] + pro_correction['reqt'] and slot not in
                        domain_goal['info']]
                if len(reqt) > 0:
                    domain_goal['reqt'] = reqt

            # book
            if 'book' in cnt_slot and random.random() < pro_book + pro_correction['book']:
                if 'book' not in domain_goal:
                    domain_goal['book'] = {}

                for slot in cnt_slot['book']:
                    if random.random() < cnt_slot['book'][slot] + pro_correction['book']:
                        domain_goal['book'][slot] = nomial_sample(cnt_slot_value['book'][slot])

                # makes sure that there are all necessary slots for booking
                if domain == 'restaurant' and 'time' not in domain_goal['book']:
                    domain_goal['book']['time'] = nomial_sample(cnt_slot_value['book']['time'])

                if domain == 'hotel' and 'stay' not in domain_goal['book']:
                    domain_goal['book']['stay'] = nomial_sample(cnt_slot_value['book']['stay'])

                if domain in ['hotel', 'restaurant'] and 'day' not in domain_goal['book']:
                    domain_goal['book']['day'] = nomial_sample(cnt_slot_value['book']['day'])

                if domain in ['hotel', 'restaurant'] and 'people' not in domain_goal['book']:
                    domain_goal['book']['people'] = nomial_sample(cnt_slot_value['book']['people'])

                if domain == 'train' and len(domain_goal['book']) <= 0:
                    domain_goal['book']['people'] = nomial_sample(cnt_slot_value['book']['people'])

            # fail_book: a first booking attempt that differs from 'book' in one value
            if 'book' in domain_goal and random.random() < 0.5:
                if domain == 'hotel':
                    domain_goal['fail_book'] = deepcopy(domain_goal['book'])
                    if 'stay' in domain_goal['book'] and random.random() < 0.5:
                        # increase hotel-stay
                        domain_goal['fail_book']['stay'] = str(int(domain_goal['book']['stay']) + 1)
                    elif 'day' in domain_goal['book']:
                        # move hotel-day to the previous day of ``days``
                        domain_goal['fail_book']['day'] = days[(days.index(domain_goal['book']['day']) - 1) % 7]

                elif domain == 'restaurant':
                    domain_goal['fail_book'] = deepcopy(domain_goal['book'])
                    if 'time' in domain_goal['book'] and random.random() < 0.5:
                        # one hour later (hour is not zero-padded in the result)
                        hour, minute = domain_goal['book']['time'].split(':')
                        domain_goal['fail_book']['time'] = str((int(hour) + 1) % 24) + ':' + minute
                    elif 'day' in domain_goal['book']:
                        if random.random() < 0.5:
                            domain_goal['fail_book']['day'] = days[(days.index(domain_goal['book']['day']) - 1) % 7]
                        else:
                            domain_goal['fail_book']['day'] = days[(days.index(domain_goal['book']['day']) + 1) % 7]

            # fail_info: the constraints match nothing in the DB -> try random single-slot adjustments
            if 'info' in domain_goal and len(self.dbquery.query(domain, domain_goal['info'].items())) == 0:
                num_trial = 0
                while num_trial < 100:
                    adjusted_info = self._adjust_info(domain, domain_goal['info'])
                    if len(self.dbquery.query(domain, adjusted_info.items())) > 0:
                        if domain == 'train':
                            domain_goal['info'] = adjusted_info
                        else:
                            domain_goal['fail_info'] = domain_goal['info']
                            domain_goal['info'] = adjusted_info

                        break
                    num_trial += 1

                # no satisfiable adjustment found -> start a new attempt
                if num_trial >= 100:
                    continue

            # at least there is one request and book
            if 'reqt' in domain_goal or 'book' in domain_goal:
                break

        return domain_goal

    def get_user_goal(self, seed=None):
        """Sample a complete user goal that is consistent across domains.

        Args:
            seed: optional seed applied to both ``random`` and ``numpy.random``,
                so the same seed yields the same goal.

        Returns:
            dict mapping each sampled domain to its goal (see
            ``_get_domain_goal``), plus ``'domain_ordering'``: the tuple of
            domains in goal order (``'train'`` is removed if its adjusted
            constraints match nothing in the database).

        The number and order of ``random`` draws below is significant for
        seeded reproducibility; do not reorder them.
        """
        # seed the generator to get fixed goal
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)

        domain_ordering = ()
        while len(domain_ordering) <= 0:
            domain_ordering = nomial_sample(self.domain_ordering_dist)

        user_goal = {dom: self._get_domain_goal(dom) for dom in domain_ordering}
        assert len(user_goal.keys()) > 0

        def _require_address(dom):
            # Append 'address' to the domain's requests, keeping the existing
            # order.  ``list(set(...))`` made the order depend on string
            # hashing, so seeded goals differed between interpreter runs.
            reqt = user_goal[dom].get('reqt', [])
            if 'address' not in reqt:
                reqt = reqt + ['address']
            user_goal[dom]['reqt'] = reqt

        # Taxi commutes between earlier places: drop the taxi's destination
        # (and departure) and have the user ask for those places' addresses.
        if 'taxi' in domain_ordering:
            places = [dom for dom in domain_ordering[: domain_ordering.index('taxi')]
                      if 'address' in self.ind_slot_dist[dom].get('reqt', {})]
            if len(places) >= 1:
                del user_goal['taxi']['info']['destination']
                _require_address(places[-1])
                if places[-1] == 'restaurant' and 'book' in user_goal['restaurant']:
                    user_goal['taxi']['info']['arriveBy'] = user_goal['restaurant']['book']['time']
                    if 'leaveAt' in user_goal['taxi']['info']:
                        del user_goal['taxi']['info']['leaveAt']
            if len(places) >= 2:
                del user_goal['taxi']['info']['departure']
                _require_address(places[-2])

        # match area of attraction and restaurant
        if 'restaurant' in domain_ordering and \
                'attraction' in domain_ordering and \
                'fail_info' not in user_goal['restaurant'] and \
                domain_ordering.index('restaurant') > domain_ordering.index('attraction') and \
                'area' in user_goal['restaurant']['info'] and 'area' in user_goal['attraction']['info']:
            adjusted_restaurant_goal = deepcopy(user_goal['restaurant']['info'])
            adjusted_restaurant_goal['area'] = user_goal['attraction']['info']['area']
            if len(self.dbquery.query('restaurant', adjusted_restaurant_goal.items())) > 0 and random.random() < 0.5:
                user_goal['restaurant']['info']['area'] = user_goal['attraction']['info']['area']

        # match day and people of restaurant and hotel
        if 'restaurant' in domain_ordering and 'hotel' in domain_ordering and \
                'book' in user_goal['restaurant'] and 'book' in user_goal['hotel']:
            if random.random() < 0.5:
                user_goal['restaurant']['book']['people'] = user_goal['hotel']['book']['people']
                if 'fail_book' in user_goal['restaurant']:
                    user_goal['restaurant']['fail_book']['people'] = user_goal['hotel']['book']['people']
            # Always true, but the draw is kept so seeded goals stay reproducible.
            if random.random() < 1.0:
                user_goal['restaurant']['book']['day'] = user_goal['hotel']['book']['day']
                if 'fail_book' in user_goal['restaurant']:
                    user_goal['restaurant']['fail_book']['day'] = user_goal['hotel']['book']['day']
                    # a fail_book identical to book is no longer a failure case
                    if user_goal['restaurant']['book']['day'] == user_goal['restaurant']['fail_book']['day'] and \
                            user_goal['restaurant']['book']['time'] == user_goal['restaurant']['fail_book']['time'] and \
                            user_goal['restaurant']['book']['people'] == user_goal['restaurant']['fail_book']['people']:
                        del user_goal['restaurant']['fail_book']

        # match day of hotel and train
        if 'hotel' in domain_ordering and 'train' in domain_ordering and \
                'book' in user_goal['hotel'] and 'info' in user_goal['train']:
            if user_goal['train']['info']['destination'] == 'cambridge' and \
                    'day' in user_goal['hotel']['book']:
                user_goal['train']['info']['day'] = user_goal['hotel']['book']['day']
            elif user_goal['train']['info']['departure'] == 'cambridge' and \
                    'day' in user_goal['hotel']['book'] and 'stay' in user_goal['hotel']['book']:
                user_goal['train']['info']['day'] = days[
                    (days.index(user_goal['hotel']['book']['day']) + int(
                        user_goal['hotel']['book']['stay'])) % 7]
            # In case, we have no query results with adjusted train goal, we simply drop the train goal.
            if len(self.dbquery.query('train', user_goal['train']['info'].items())) == 0:
                del user_goal['train']
                # ``list.remove`` returns None, so the previous
                # ``tuple(list(...).remove('train'))`` always raised TypeError.
                domain_ordering = tuple(dom for dom in domain_ordering if dom != 'train')

        user_goal['domain_ordering'] = domain_ordering

        return user_goal

    def _adjust_info(self, domain, info):
        # adjust one of the slots of the info
        adjusted_info = deepcopy(info)
        slot = random.choice(list(info.keys()))
        adjusted_info[slot] = random.choice(list(self.ind_slot_value_dist[domain]['info'][slot].keys()))
        return adjusted_info

    def build_message(self, user_goal, boldify=null_boldify):
        message = []
        state = deepcopy(user_goal)

        for dom in user_goal['domain_ordering']:
            state = deepcopy(user_goal[dom])

            if not (dom == 'taxi' and len(state['info']) == 1):
                # intro
                m = [templates[dom]['intro']]

            # info
            def fill_info_template(user_goal, domain, slot, info):
                if slot != 'area' or not ('restaurant' in user_goal and
                                          'attraction' in user_goal and
                                          info in user_goal['restaurant'].keys() and
                                          info in user_goal['attraction'].keys() and
                                          'area' in user_goal['restaurant'][info] and
                                          'area' in user_goal['attraction'][info] and
                                          user_goal['restaurant'][info]['area'] == user_goal['attraction'][info]['area']):
                    return templates[domain][slot].format(self.boldify(user_goal[domain][info][slot]))
                else:
                    restaurant_index = user_goal['domain_ordering'].index('restaurant')
                    attraction_index = user_goal['domain_ordering'].index('attraction')
                    if restaurant_index > attraction_index and domain == 'restaurant':
                        return templates[domain][slot].format(self.boldify('same area as the attraction'))
                    elif attraction_index > restaurant_index and domain == 'attraction':
                        return templates[domain][slot].format(self.boldify('same area as the restaurant'))
                return templates[domain][slot].format(self.boldify(user_goal[domain][info][slot]))

            info = 'info'
            if 'fail_info' in user_goal[dom]:
                info = 'fail_info'
            if dom == 'taxi' and len(state[info]) == 1:
                taxi_index = user_goal['domain_ordering'].index('taxi')
                places = [dom for dom in user_goal['domain_ordering'][: taxi_index] if
                          dom in ['attraction', 'hotel', 'restaurant']]
                if len(places) >= 2:
                    random.shuffle(places)
                    m.append(templates['taxi']['commute'])
                    if 'arriveBy' in state[info]:
                        m.append('The taxi should arrive at the {} from the {} by {}.'.format(self.boldify(places[0]),
                                                                                              self.boldify(places[1]),
                                                                                              self.boldify(state[info]['arriveBy'])))
                    elif 'leaveAt' in state[info]:
                        m.append('The taxi should leave from the {} to the {} after {}.'.format(self.boldify(places[0]),
                                                                                                self.boldify(places[1]),
                                                                                                self.boldify(state[info]['leaveAt'])))
                    message.append(' '.join(m))
            else:
                while len(state[info]) > 0:
                    num_acts = random.randint(1, min(len(state[info]), 3))
                    slots = random.sample(list(state[info].keys()), num_acts)
                    sents = [fill_info_template(user_goal, dom, slot, info) for slot in slots if slot not in ['parking', 'internet']]
                    if 'parking' in slots:
                        sents.append(templates[dom]['parking ' + state[info]['parking']])
                    if 'internet' in slots:
                        sents.append(templates[dom]['internet ' + state[info]['internet']])
                    m.extend(sents)
                    message.append(' '.join(m))
                    m = []
                    for slot in slots:
                        del state[info][slot]

            # fail_info
            if 'fail_info' in user_goal[dom]:
            # if 'fail_info' in user_goal[dom]:
                adjusted_slot = list(filter(lambda x: x[0][1] != x[1][1],
                                            zip(user_goal[dom]['info'].items(), user_goal[dom]['fail_info'].items())))[0][0][0]
                if adjusted_slot in ['internet', 'parking']:
                    message.append(templates[dom]['fail_info ' + adjusted_slot + ' ' + user_goal[dom]['info'][adjusted_slot]])
                else:
                    message.append(templates[dom]['fail_info ' + adjusted_slot].format(self.boldify(user_goal[dom]['info'][adjusted_slot])))

            # reqt
            if 'reqt' in state:
                slot_strings = []
                for slot in state['reqt']:
                    if slot in ['internet', 'parking', 'food']:
                        continue
                    slot_strings.append(slot if slot not in request_slot_string_map else request_slot_string_map[slot])
                if len(slot_strings) > 0:
                    message.append(templates[dom]['request'].format(self.boldify(', '.join(slot_strings))))
                if 'internet' in state['reqt']:
                    message.append('Make sure to ask if the hotel includes free wifi.')
                if 'parking' in state['reqt']:
                    message.append('Make sure to ask if the hotel includes free parking.')
                if 'food' in state['reqt']:
                    message.append('Make sure to ask about what food it serves.')

            def get_same_people_domain(user_goal, domain, slot):
                if slot not in ['day', 'people']:
                    return None
                domain_index = user_goal['domain_ordering'].index(domain)
                previous_domains = user_goal['domain_ordering'][:domain_index]
                for prev in previous_domains:
                    if prev in ['restaurant', 'hotel', 'train'] and 'book' in user_goal[prev] and \
                            slot in user_goal[prev]['book'] and user_goal[prev]['book'][slot] == \
                            user_goal[domain]['book'][slot]:
                        return prev
                return None

            # book
            book = 'book'
            if 'fail_book' in user_goal[dom]:
                book = 'fail_book'
            if 'book' in state:
                slot_strings = []
                for slot in ['people', 'time', 'day', 'stay']:
                    if slot in state[book]:
                        if slot == 'people':
                            same_people_domain = get_same_people_domain(user_goal, dom, slot)
                            if same_people_domain is None:
                                slot_strings.append('for {} people'.format(self.boldify(state[book][slot])))
                            else:
                                slot_strings.append(self.boldify(
                                    'for the same group of people as the {} booking'.format(same_people_domain)))
                        elif slot == 'time':
                            slot_strings.append('at {}'.format(self.boldify(state[book][slot])))
                        elif slot == 'day':
                            same_people_domain = get_same_people_domain(user_goal, dom, slot)
                            if same_people_domain is None:
                                slot_strings.append('on {}'.format(self.boldify(state[book][slot])))
                            else:
                                slot_strings.append(
                                    self.boldify('on the same day as the {} booking'.format(same_people_domain)))
                        elif slot == 'stay':
                            slot_strings.append('for {} nights'.format(self.boldify(state[book][slot])))
                        del state[book][slot]

                assert len(state[book]) <= 0, state[book]

                if len(slot_strings) > 0:
                    message.append(templates[dom]['book'].format(' '.join(slot_strings)))

            # fail_book
            if 'fail_book' in user_goal[dom]:
                adjusted_slot = list(filter(lambda x: x[0][1] != x[1][1], zip(user_goal[dom]['book'].items(),
                                                                              user_goal[dom]['fail_book'].items())))[0][0][0]

                if adjusted_slot in ['internet', 'parking']:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot + ' ' + user_goal[dom]['book'][adjusted_slot]])
                else:
                    message.append(templates[dom]['fail_book ' + adjusted_slot].format(
                        self.boldify(user_goal[dom]['book'][adjusted_slot])))

        if boldify == do_boldify:
            for i, m in enumerate(message):
                message[i] = message[i].replace('wifi', "<b>wifi</b>")
                message[i] = message[i].replace('internet', "<b>internet</b>")
                message[i] = message[i].replace('parking', "<b>parking</b>")

        return message
コード例 #3
0
class StateTracker(object):
    """Turn-by-turn dialog state tracker.

    ``update_belief_sys`` and ``update_belief_usr`` each take the current state
    dict plus a multi-hot dialog-act vector and return an updated deep copy
    (the input state is not modified). Both advance ``self.time_step``, and
    both may switch ``self.topic`` to the most recent belief domain
    (``'NONE'`` until one is seen).
    """

    def __init__(self, data_dir, config):
        """
        Args:
            data_dir: directory handed to ``DBQuery``.
            config: config object; the methods below read ``mapping``,
                ``idx2da``, ``idx2dau``, ``dau2idx``, ``belief_domains`` and
                ``a_dim_usr`` from it.
        """
        self.time_step = 0
        self.cfg = config
        self.db = DBQuery(data_dir)
        # 'NONE' means no belief domain has been mentioned yet.
        self.topic = 'NONE'

    def _action_to_dict(self, das):
        """Group ``{'domain-intent-slot-p': value}`` acts by ``'domain-intent'``.

        Returns:
            dict mapping ``'domain-intent'`` to a list of ``[slot, value]``
            pairs, in input order.
        """
        da_dict = {}
        for da, value in das.items():
            domain, intent, slot, p = da.split('-')
            domint = '-'.join((domain, intent))
            da_dict.setdefault(domint, []).append([slot, value])
        return da_dict

    def _dict_to_vec(self, das):
        """Encode grouped user acts as a multi-hot ``int32`` tensor.

        ``das`` is first passed to ``expand_da``; afterwards each value is
        unpacked as ``(slot, p, value)`` triples (NOTE(review): presumably
        ``expand_da`` rewrites ``das`` in place into that shape -- confirm).
        Acts not present in ``cfg.dau2idx`` are ignored.

        Returns:
            tensor of length ``cfg.a_dim_usr`` with 1 at every known act index.
        """
        da_vector = torch.zeros(self.cfg.a_dim_usr, dtype=torch.int32)
        expand_da(das)
        for domint in das:
            pairs = das[domint]
            for slot, p, value in pairs:
                da = '-'.join((domint, slot, p)).lower()
                if da in self.cfg.dau2idx:
                    idx = self.cfg.dau2idx[da]
                    da_vector[idx] = 1
        return da_vector

    def _mask_user_goal(self, goal):
        """Remove the 'hospital' and 'police' domains from ``goal`` in place.

        Both the domain entry and its position in ``goal['domain_ordering']``
        are removed; the ordering is stored back as a tuple.
        """
        domain_ordering = list(goal['domain_ordering'])
        if 'hospital' in goal:
            del (goal['hospital'])
            domain_ordering.remove('hospital')
        if 'police' in goal:
            del (goal['police'])
            domain_ordering.remove('police')
        goal['domain_ordering'] = tuple(domain_ordering)

    def get_entities(self, s, domain):
        """Query the DB for entities matching the informed constraints of ``domain``.

        Only slots listed in ``cfg.mapping[domain]`` are used, translated to
        their DB field names. The result list is shuffled before returning.
        """
        origin = s['belief_state']['inform'][domain].items()
        constraint = []
        for k, v in origin:
            if k in self.cfg.mapping[domain]:
                constraint.append((self.cfg.mapping[domain][k], v))
        entities = self.db.query(domain, constraint)
        random.shuffle(entities)
        return entities

    def update_belief_sys(self, old_s, a):
        """
        update belief state with sys action

        Args:
            old_s (dict): current state; not modified (a deep copy is updated).
            a: multi-hot system action vector; every non-zero index is looked
                up in ``cfg.idx2da`` as a ``'domain-intent-slot-p'`` string.
        Returns:
            dict: the new state. ``sys_action`` maps each act to its value,
            ``'?'`` for requests, or ``'none'`` when no value is available.
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices

        self.time_step += 1
        s['others']['turn'] = self.time_step

        # archive the previous system act and roll the user act over
        s['history']['sys'] = dict(s['history']['sys'], **s['last_sys_action'])
        del (s['last_sys_action'])
        s['last_user_action'] = s['user_action']
        s['user_action'] = dict()

        # Acts are processed in index order. Sorting by domain name would move
        # 'booking' acts (which resolve against self.topic) ahead of same-turn
        # acts such as 'hotel'/'train' and change which topic they use.
        das = [self.cfg.idx2da[idx.item()].split('-') for idx in a_index]

        entities = [] if self.topic == 'NONE' else self.get_entities(
            s, self.topic)
        for domain, intent, slot, p in das:
            # 'booking' acts are attributed to the current topic
            _domain = self.topic if domain == 'booking' else domain
            if domain in self.cfg.belief_domains and domain != self.topic:
                self.topic = domain
                entities = self.get_entities(s, domain)

            da = '-'.join((domain, intent, slot, p))
            if p == 'none':
                s['sys_action'][da] = 'none'
            elif p == '?':
                s['sys_action'][da] = '?'
            elif intent in ['nooffer', 'nobook']:
                # echo back the user's own constraint for this slot, if any
                if slot in s['belief_state']['inform'][_domain]:
                    s['sys_action'][da] = s['belief_state']['inform'][_domain][
                        slot]
                else:
                    s['sys_action'][da] = 'none'
            elif slot == 'choice':
                s['sys_action'][da] = str(len(entities))
            else:
                # p is a 1-based index into the current entity list
                num = int(p) - 1
                if len(entities) > num and slot in self.cfg.mapping[_domain]:
                    typ = self.cfg.mapping[_domain][slot]
                    s['sys_action'][da] = entities[num][typ]
                else:
                    s['sys_action'][da] = 'none'

            if intent == 'inform' and _domain != 'NONE':
                s['belief_state']['request'][_domain].discard(slot)

            # booked
            if intent == 'inform' and slot == 'car':  # taxi
                if not s['belief_state']['booked']['taxi']:
                    # assignment (was a no-op `==` comparison before)
                    s['belief_state']['booked']['taxi'] = 'booked'
            elif intent == 'offerbooked' and slot == 'ref':  # train
                s['belief_state']['request']['train'].discard('ref')
                if not s['belief_state']['booked']['train'] and entities:
                    s['belief_state']['booked']['train'] = entities[0]['ref']
            elif intent == 'book' and slot == 'ref':  # attraction, hotel, restaurant
                if _domain not in ['attraction', 'hotel', 'restaurant']:
                    continue
                s['belief_state']['request'][_domain].discard('ref')
                if not s['belief_state']['booked'][_domain] and entities:
                    # save entity id
                    s['belief_state']['booked'][_domain] = entities[0]['ref']

        return s

    def update_belief_usr(self, old_s, a, terminal):
        """
        update belief state with user action

        Args:
            old_s (dict): current state; not modified (a deep copy is updated).
            a: multi-hot user action vector; every non-zero index is looked up
                in ``cfg.idx2dau`` as a ``'domain-intent-slot-p'`` string.
            terminal (bool): stored as ``s['others']['terminal']``.
        Returns:
            dict: the new state. Informed slots take their value from
            ``s['user_goal']['inform']`` (or ``'none'``); requested slots are
            added to ``belief_state['request']``.
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices

        self.time_step += 1
        s['others']['turn'] = self.time_step
        s['others']['terminal'] = terminal

        # archive the previous user act and roll the system act over
        s['history']['user'] = dict(s['history']['user'],
                                    **s['last_user_action'])
        del (s['last_user_action'])
        s['last_sys_action'] = s['sys_action']
        s['sys_action'] = dict()

        # Acts are processed in index order (see update_belief_sys).
        das = [self.cfg.idx2dau[idx.item()].split('-') for idx in a_index]

        for domain, intent, slot, p in das:
            if domain in self.cfg.belief_domains and domain != self.topic:
                self.topic = domain

            da = '-'.join((domain, intent, slot, p))
            if p == 'none':
                s['user_action'][da] = 'none'
            elif p == '?':
                s['user_action'][da] = '?'
            else:
                if slot in s['user_goal']['inform'][domain]:
                    s['user_action'][da] = s['user_goal']['inform'][domain][
                        slot]
                else:
                    s['user_action'][da] = 'none'

            if slot != 'none':
                if intent == 'inform':
                    # update constraints with reasonable value according to user goal
                    if slot in s['user_goal']['inform'][domain]:
                        s['belief_state']['inform'][domain][slot] = s[
                            'user_goal']['inform'][domain][slot]  # value
                    else:
                        s['belief_state']['inform'][domain][slot] = 'none'

                elif intent == 'request':
                    s['belief_state']['request'][domain].add(slot)

        return s

    def reset(self, random_seed=None):
        """
        Not implemented in this tracker; returns None.

        Args:
            random_seed (int):
        Returns:
            init_state (dict):
        """
        pass

    def step(self, s, sys_a):
        """
        Not implemented in this tracker; returns None.

        Args:
            s (dict):
            sys_a (vector):
        Returns:
            next_s (dict):
            terminal (bool):
        """
        pass
コード例 #4
0
class GoalGenerator:
    """User goal generator"""
    def __init__(self,
                 data_dir,
                 cfg,
                 goal_model_path,
                 corpus_path=None,
                 boldify=False):
        """Load (or build and cache) the statistical goal model.

        The goal model consists of four distributions:

        * ``self.ind_slot_dist``: per-domain probability that each slot is
          constrained (``info``), requested (``reqt``) or given for booking
          (``book``).
        * ``self.ind_slot_value_dist``: per-domain, per-slot value
          distribution.
        * ``self.domain_ordering_dist``: distribution over domain orderings.
        * ``self.book_dist``: per-domain booking probability.

        ``cfg`` is forwarded to ``DBQuery``; ``cfg.d`` (if truthy) restricts
        generation to that single domain. It does not otherwise control the
        distributions.

        Args:
            data_dir: base directory; ``goal_model_path`` and ``corpus_path``
                are resolved relative to it.
            cfg: configuration object (see above).
            goal_model_path: path to a goal model (relative to ``data_dir``);
                if it does not exist it is built from ``corpus_path`` and
                written there.
            corpus_path: path to a dialog corpus to build a goal model
                (relative to ``data_dir``).
            boldify: when True ``self.boldify`` is ``do_boldify``, otherwise
                ``null_boldify``.
        """
        self.cfg = cfg
        if self.cfg.d:
            self.domains = [self.cfg.d]
        else:
            logging.debug('GoalGenerator config: %s', self.cfg)
            self.domains = {
                'attraction', 'hotel', 'restaurant', 'train', 'taxi',
                'hospital', 'police'
            }
        self.dbquery = DBQuery(data_dir, cfg)
        self.goal_model_path = data_dir + '/' + goal_model_path
        self.corpus_path = data_dir + '/' + corpus_path if corpus_path is not None else None
        self.boldify = do_boldify if boldify else null_boldify
        if os.path.exists(self.goal_model_path):
            # NOTE: pickle.load can execute arbitrary code -- only load goal
            # models from trusted locations.
            with open(self.goal_model_path, 'rb') as f:
                (self.ind_slot_dist, self.ind_slot_value_dist,
                 self.domain_ordering_dist, self.book_dist) = pickle.load(f)
            logging.info('Loading goal model is done')
        else:
            self._build_goal_model()
            logging.info('Building goal model is done')

        # Never generate requests for these slots. pop(..., None) tolerates a
        # model whose corpus had no 'reqt' for the domain or never requested
        # the slot (a plain `del` raised KeyError there).
        unused_reqt_slots = {
            'police': ('postcode',),
            'hospital': ('postcode', 'address'),
        }
        for domain, slots in unused_reqt_slots.items():
            if domain not in self.ind_slot_dist:
                continue
            for dist in (self.ind_slot_dist, self.ind_slot_value_dist):
                reqt = dist.get(domain, {}).get('reqt', {})
                for slot in slots:
                    reqt.pop(slot, None)

    def _build_goal_model(self):
        """Estimate the goal model from the dialog corpus and pickle it.

        Reads the JSON corpus at ``self.corpus_path`` (dialog id -> dialog,
        each with a ``'goal'`` entry) and fills ``self.domain_ordering_dist``,
        ``self.ind_slot_dist``, ``self.ind_slot_value_dist`` and
        ``self.book_dist``. The four are then written as a tuple to
        ``self.goal_model_path``, creating its directory if needed.

        Raises:
            ValueError: if no corpus path was configured.
        """
        if self.corpus_path is None:
            raise ValueError('corpus_path is required to build the goal model '
                             '(%s does not exist)' % self.goal_model_path)
        with open(self.corpus_path, encoding='utf-8') as f:
            dialogs = json.load(f)

        # ---- domain ordering ----
        def _get_dialog_domains(dialog):
            """Return the domains of self.domains with a non-empty goal in ``dialog``."""
            return [x for x in dialog['goal']
                    if x in self.domains and len(dialog['goal'][x]) > 0]

        domain_orderings = []
        for d in dialogs:
            d_domains = _get_dialog_domains(dialogs[d])

            # Single-domain mode skips dialogs that switch domains (or have
            # none); drop this check to study multi-domain switching.
            if self.cfg.d and len(d_domains) != 1:
                continue

            message = dialogs[d]['goal']['message']
            if isinstance(message, str):
                message = [message]

            # Index of the first message line mentioning each domain. A domain
            # never mentioned is placed last, so first_index stays aligned with
            # d_domains (otherwise zip() would pair indices with the wrong
            # domains and silently drop one).
            first_index = []
            for domain in d_domains:
                for i, m in enumerate(message):
                    if domain_keywords[domain].lower() in m.lower(
                    ) or domain.lower() in m.lower():
                        first_index.append(i)
                        break
                else:
                    first_index.append(len(message))
            # order the domains by where they first appear in the message
            domain_orderings.append(tuple(
                dom for _, dom in sorted(zip(first_index, d_domains),
                                         key=lambda x: x[0])))
        domain_ordering_cnt = Counter(domain_orderings)
        total_orderings = sum(domain_ordering_cnt.values())
        self.domain_ordering_dist = deepcopy(domain_ordering_cnt)
        # convert counts to probabilities
        for order, cnt in domain_ordering_cnt.items():
            self.domain_ordering_dist[order] = cnt / total_orderings

        # ---- independent goal slot counts ----
        ind_slot_value_cnt = {domain: {} for domain in self.domains}
        domain_cnt = Counter()
        book_cnt = Counter()

        for d in dialogs:
            for domain in self.domains:
                dom_goal = dialogs[d]['goal'][domain]
                # number of dialogs in which the domain appears
                if dom_goal != {}:
                    domain_cnt[domain] += 1
                # info: count each constrained slot and its value; the slot is
                # registered even if its value is a "don't care"
                if 'info' in dom_goal:
                    for slot, value in dom_goal['info'].items():
                        if 'invalid' in slot:
                            continue
                        slot_cnt = ind_slot_value_cnt[domain].setdefault(
                            'info', {}).setdefault(slot, Counter())
                        if 'care' in value:
                            continue
                        slot_cnt[value] += 1
                # reqt: requested slots have no value, so count the slot itself
                if 'reqt' in dom_goal:
                    for slot in dom_goal['reqt']:
                        ind_slot_value_cnt[domain].setdefault(
                            'reqt', Counter())[slot] += 1
                # book: count booking goals and each booking slot value
                if 'book' in dom_goal:
                    book_cnt[domain] += 1
                    for slot, value in dom_goal['book'].items():
                        if 'invalid' in slot:
                            continue
                        slot_cnt = ind_slot_value_cnt[domain].setdefault(
                            'book', {}).setdefault(slot, Counter())
                        if 'care' in value:
                            continue
                        slot_cnt[value] += 1

        # ---- normalise counts into distributions ----
        self.ind_slot_value_dist = deepcopy(ind_slot_value_cnt)
        self.ind_slot_dist = {domain: {} for domain in self.domains}
        self.book_dist = {}

        def _normalize_valued(domain, part):
            """Fill slot and value probabilities for an 'info' or 'book' section."""
            for slot, value_cnt in ind_slot_value_cnt[domain].get(part, {}).items():
                slot_total = sum(value_cnt.values())
                self.ind_slot_dist[domain].setdefault(part, {})[slot] = \
                    slot_total / domain_cnt[domain]
                value_dist = self.ind_slot_value_dist[domain][part][slot]
                for val in value_dist:
                    value_dist[val] = value_cnt[val] / slot_total

        for domain in self.domains:
            _normalize_valued(domain, 'info')
            for slot, cnt in ind_slot_value_cnt[domain].get('reqt', {}).items():
                prob = cnt / domain_cnt[domain]
                self.ind_slot_dist[domain].setdefault('reqt', {})[slot] = prob
                self.ind_slot_value_dist[domain]['reqt'][slot] = prob
            _normalize_valued(domain, 'book')
            # NOTE: normalised by all dialogs, not by dialogs containing the domain
            self.book_dist[domain] = book_cnt[domain] / len(dialogs)

        goal_model_path_dir = os.path.dirname(self.goal_model_path)
        if len(goal_model_path_dir) > 0 and not os.path.exists(
                goal_model_path_dir):
            os.makedirs(goal_model_path_dir)

        with open(self.goal_model_path, 'wb') as f:
            pickle.dump((self.ind_slot_dist, self.ind_slot_value_dist,
                         self.domain_ordering_dist, self.book_dist), f)

    def _get_domain_goal(self, domain):
        """Sample a goal for a single ``domain`` from the goal model.

        Draws ``info`` constraints, ``reqt`` requests and optionally ``book``
        constraints from ``self.ind_slot_dist`` / ``self.ind_slot_value_dist``.
        It may then add ``fail_book`` (hotel/restaurant only) and ``fail_info``
        (when the sampled constraints match nothing in the database).
        Sampling repeats until the goal has at least one request or booking.

        Returns:
            dict with keys among ``'info'``, ``'reqt'``, ``'book'``,
            ``'fail_book'`` and ``'fail_info'``.
        """
        cnt_slot = self.ind_slot_dist[domain]
        cnt_slot_value = self.ind_slot_value_dist[domain]
        pro_book = self.book_dist[domain]

        while True:
            domain_goal = {'info': {}}
            # inform
            if 'info' in cnt_slot:
                for slot in cnt_slot['info']:
                    # constrain the slot with its (corrected) corpus probability
                    if random.random(
                    ) < cnt_slot['info'][slot] + pro_correction['info']:
                        domain_goal['info'][slot] = nomial_sample(
                            cnt_slot_value['info'][slot])
                # hotel/restaurant/attraction: a name makes other constraints
                # redundant, so keep either only the name or everything but it
                if domain in ['hotel', 'restaurant', 'attraction'
                              ] and 'name' in domain_goal['info'] and len(
                                  domain_goal['info']) > 1:
                    if random.random() < cnt_slot['info']['name']:
                        domain_goal['info'] = {
                            'name': domain_goal['info']['name']
                        }
                    else:
                        del domain_goal['info']['name']

                # taxi/train: keep only one of leaveAt/arriveBy, otherwise the
                # answer would be unique
                if domain in ['taxi', 'train'] and 'arriveBy' in domain_goal[
                        'info'] and 'leaveAt' in domain_goal['info']:
                    if random.random() < (cnt_slot['info']['leaveAt'] /
                                          (cnt_slot['info']['arriveBy'] +
                                           cnt_slot['info']['leaveAt'])):
                        del domain_goal['info']['arriveBy']
                    else:
                        del domain_goal['info']['leaveAt']

                # ...but one of arriveBy/leaveAt must still be given
                if domain in ['taxi', 'train'] and 'arriveBy' not in domain_goal['info'] and 'leaveAt' not in \
                        domain_goal['info']:
                    if random.random() < (cnt_slot['info']['arriveBy'] /
                                          (cnt_slot['info']['arriveBy'] +
                                           cnt_slot['info']['leaveAt'])):
                        domain_goal['info']['arriveBy'] = nomial_sample(
                            cnt_slot_value['info']['arriveBy'])
                    else:
                        domain_goal['info']['leaveAt'] = nomial_sample(
                            cnt_slot_value['info']['leaveAt'])

                # departure and destination are always required
                if domain in ['taxi', 'train'
                              ] and 'departure' not in domain_goal['info']:
                    domain_goal['info']['departure'] = nomial_sample(
                        cnt_slot_value['info']['departure'])
                if domain in ['taxi', 'train'
                              ] and 'destination' not in domain_goal['info']:
                    domain_goal['info']['destination'] = nomial_sample(
                        cnt_slot_value['info']['destination'])

                # resample one endpoint until departure and destination differ
                # NOTE(review): never terminates if both value distributions
                # only contain the same single place.
                if domain in ['taxi', 'train'] and \
                        'departure' in domain_goal['info'] and \
                        'destination' in domain_goal['info']:
                    while domain_goal['info']['departure'] == domain_goal[
                            'info']['destination']:
                        if random.random() < (
                                cnt_slot['info']['departure'] /
                            (cnt_slot['info']['departure'] +
                             cnt_slot['info']['destination'])):
                            domain_goal['info']['departure'] = nomial_sample(
                                cnt_slot_value['info']['departure'])
                        else:
                            domain_goal['info']['destination'] = nomial_sample(
                                cnt_slot_value['info']['destination'])
                if domain_goal['info'] == {}:
                    # no user constraints were sampled: start over
                    continue
            # request: sample requested slots, skipping already-constrained ones
            if 'reqt' in cnt_slot:
                reqt = [
                    slot for slot in cnt_slot['reqt']
                    if random.random() < cnt_slot['reqt'][slot] +
                    pro_correction['reqt'] and slot not in domain_goal['info']
                ]
                if len(reqt) > 0:
                    domain_goal['reqt'] = reqt

            # book: decided with some probability; the booking slots are
            # sampled independently of the 'info' constraints
            if 'book' in cnt_slot and random.random(
            ) < pro_book + pro_correction['book']:
                domain_goal['book'] = {}

                for slot in cnt_slot['book']:
                    if random.random(
                    ) < cnt_slot['book'][slot] + pro_correction['book']:
                        domain_goal['book'][slot] = nomial_sample(
                            cnt_slot_value['book'][slot])

                # make sure all slots needed for the booking are present:
                # a restaurant needs a time
                if domain == 'restaurant' and 'time' not in domain_goal['book']:
                    domain_goal['book']['time'] = nomial_sample(
                        cnt_slot_value['book']['time'])
                # a hotel needs a length of stay
                if domain == 'hotel' and 'stay' not in domain_goal['book']:
                    domain_goal['book']['stay'] = nomial_sample(
                        cnt_slot_value['book']['stay'])
                # hotel and restaurant both need a day and a party size
                if domain in ['hotel', 'restaurant'
                              ] and 'day' not in domain_goal['book']:
                    domain_goal['book']['day'] = nomial_sample(
                        cnt_slot_value['book']['day'])
                if domain in ['hotel', 'restaurant'
                              ] and 'people' not in domain_goal['book']:
                    domain_goal['book']['people'] = nomial_sample(
                        cnt_slot_value['book']['people'])

                # a train booking needs at least the number of people
                if domain == 'train' and len(domain_goal['book']) <= 0:
                    domain_goal['book']['people'] = nomial_sample(
                        cnt_slot_value['book']['people'])

            # fail_book: only hotel and restaurant bookings get a failing variant
            if 'book' in domain_goal and random.random() < 0.5:
                if domain == 'hotel':
                    domain_goal['fail_book'] = deepcopy(domain_goal['book'])
                    if 'stay' in domain_goal['book'] and random.random() < 0.5:
                        # the failing attempt asks for one more night
                        domain_goal['fail_book']['stay'] = str(
                            int(domain_goal['book']['stay']) + 1)
                    elif 'day' in domain_goal['book']:
                        # the failing attempt is for the previous day
                        domain_goal['fail_book']['day'] = days[
                            (days.index(domain_goal['book']['day']) - 1) % 7]

                elif domain == 'restaurant':
                    domain_goal['fail_book'] = deepcopy(domain_goal['book'])
                    if 'time' in domain_goal['book'] and random.random() < 0.5:
                        hour, minute = domain_goal['book']['time'].split(':')
                        # the failing attempt is one hour later; keep the
                        # original hour width so '08:00' -> '09:00' and
                        # '23:30' -> '00:30' (not '9:00' / '0:30')
                        new_hour = str((int(hour) + 1) % 24).zfill(len(hour))
                        domain_goal['fail_book']['time'] = new_hour + ':' + minute
                    elif 'day' in domain_goal['book']:
                        if random.random() < 0.5:
                            # the previous day
                            domain_goal['fail_book']['day'] = days[
                                (days.index(domain_goal['book']['day']) - 1) %
                                7]
                        else:
                            # the next day
                            domain_goal['fail_book']['day'] = days[
                                (days.index(domain_goal['book']['day']) + 1) %
                                7]

            # fail_info: when the sampled constraints match nothing in the DB,
            # try (up to 100 times) to adjust them until they do
            if 'info' in domain_goal and len(
                    self.dbquery.query(domain,
                                       domain_goal['info'].items())) == 0:
                num_trial = 0
                while num_trial < 100:
                    adjusted_info = self._adjust_info(domain,
                                                      domain_goal['info'])
                    if len(self.dbquery.query(domain,
                                              adjusted_info.items())) > 0:
                        # train goals just take the adjusted constraints,
                        # without a fail_info
                        if domain == 'train':
                            domain_goal['info'] = adjusted_info
                        else:
                            domain_goal['fail_info'] = domain_goal['info']
                            domain_goal['info'] = adjusted_info

                        break
                    num_trial += 1

                if num_trial >= 100:
                    # no satisfiable adjustment found: regenerate the goal
                    continue

            # at least there is one request and book
            if 'reqt' in domain_goal or 'book' in domain_goal:
                break

        return domain_goal

    def get_user_goal(self, seed=None):
        """Sample a complete multi-domain user goal.

        Each domain's goal comes from ``self._get_domain_goal``; the domain
        goals are then cross-adjusted so they agree with each other (taxi
        endpoints and arrival time, restaurant/attraction area,
        restaurant/hotel booking day and party size, train/hotel day).

        Args:
            seed: if not None, seeds both ``random`` and ``np.random`` before
                sampling (intended to make the sampled goal repeatable).

        Returns:
            dict mapping each sampled domain name to its goal dict, plus a
            ``'domain_ordering'`` key holding the shuffled list of domains.
            The train domain is removed (from both the dict and the ordering)
            when its adjusted ``info`` constraints match nothing in the DB.
        """
        # seed the generator to get fixed goal
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)

        # First decide domain_ordering: resample until at least one domain is
        # drawn, then shuffle the order.
        domain_ordering = []
        while not domain_ordering:
            domain_ordering = list(nomial_sample(self.domain_ordering_dist))
        np.random.shuffle(domain_ordering)

        user_goal = {
            dom: self._get_domain_goal(dom)
            for dom in domain_ordering
        }
        assert len(user_goal.keys()) > 0

        # using taxi to commute between places, removing destination and departure.
        # If the taxi connects places, its endpoints must agree with the domains
        # that precede it; `places` lists the earlier domains whose request-slot
        # distribution contains 'address'.
        # NOTE(review): the `del` statements below raise KeyError if the taxi
        # info lacks 'destination'/'departure' — confirm _get_domain_goal
        # always provides them.
        if 'taxi' in domain_ordering:
            places = [
                dom for dom in domain_ordering[:domain_ordering.index('taxi')]
                if 'address' in self.ind_slot_dist[dom]['reqt'].keys()
            ]
            if len(places) >= 1:
                # The taxi goes to the most recent place; the user must ask
                # for that place's address.
                del user_goal['taxi']['info']['destination']
                user_goal[places[-1]]['reqt'] = list(
                    set(user_goal[places[-1]].get('reqt',
                                                  [])).union({'address'}))
                # Keep the times consistent: when heading to a booked
                # restaurant, arrive by the booking time instead of giving a
                # departure time.
                if places[-1] == 'restaurant' and 'book' in user_goal[
                        'restaurant']:
                    user_goal['taxi']['info']['arriveBy'] = user_goal[
                        'restaurant']['book']['time']
                    if 'leaveAt' in user_goal['taxi']['info']:
                        del user_goal['taxi']['info']['leaveAt']
            # Departure point as well: the taxi leaves from the second most
            # recent place, whose address is then also requested.
            if len(places) >= 2:
                del user_goal['taxi']['info']['departure']
                user_goal[places[-2]]['reqt'] = list(
                    set(user_goal[places[-2]].get('reqt',
                                                  [])).union({'address'}))

        # match area of attraction and restaurant
        # The attraction and the restaurant should be in the same area (only
        # when the restaurant comes after the attraction, has no fail_info, and
        # both goals constrain 'area').
        if 'restaurant' in domain_ordering and \
                'attraction' in domain_ordering and \
                'fail_info' not in user_goal['restaurant'] and \
                domain_ordering.index('restaurant') > domain_ordering.index('attraction') and \
                'area' in user_goal['restaurant']['info'] and 'area' in user_goal['attraction']['info']:
            adjusted_restaurant_goal = deepcopy(
                user_goal['restaurant']['info'])
            adjusted_restaurant_goal['area'] = user_goal['attraction']['info'][
                'area']
            # If the DB still has matching restaurants in the attraction's
            # area, move the restaurant there with probability 0.5. Because of
            # short-circuiting, random() is only drawn when the query is
            # non-empty.
            if len(
                    self.dbquery.query('restaurant',
                                       adjusted_restaurant_goal.items())
            ) > 0 and random.random() < 0.5:
                user_goal['restaurant']['info']['area'] = user_goal[
                    'attraction']['info']['area']

        # match day and people of restaurant and hotel
        # The restaurant and hotel bookings should share party size and day.
        # After adjusting, check whether fail_book now coincides with book; if
        # so, delete fail_book.
        if 'restaurant' in domain_ordering and 'hotel' in domain_ordering and \
                'book' in user_goal['restaurant'] and 'book' in user_goal['hotel']:
            # Party size is copied with probability 0.5.
            if random.random() < 0.5:
                user_goal['restaurant']['book']['people'] = user_goal['hotel'][
                    'book']['people']
                if 'fail_book' in user_goal['restaurant']:
                    user_goal['restaurant']['fail_book']['people'] = user_goal[
                        'hotel']['book']['people']
            # random() < 1.0 is always true, so the day is always copied; the
            # draw still advances the RNG state.
            if random.random() < 1.0:
                user_goal['restaurant']['book']['day'] = user_goal['hotel'][
                    'book']['day']
                if 'fail_book' in user_goal['restaurant']:
                    user_goal['restaurant']['fail_book']['day'] = user_goal[
                        'hotel']['book']['day']
                    # fail_book is pointless once it no longer differs from
                    # book on day, time and people.
                    if user_goal['restaurant']['book']['day'] == user_goal['restaurant']['fail_book']['day'] and \
                            user_goal['restaurant']['book']['time'] == user_goal['restaurant']['fail_book']['time'] and \
                            user_goal['restaurant']['book']['people'] == user_goal['restaurant']['fail_book']['people']:
                        del user_goal['restaurant']['fail_book']

        # match day and people of hotel and train
        # NOTE(review): train info is indexed with ['destination'] and
        # ['departure'] directly; this assumes both keys are always present.
        if 'hotel' in domain_ordering and 'train' in domain_ordering and \
                'book' in user_goal['hotel'] and 'info' in user_goal['train']:
            # Travelling into cambridge: travel on the hotel check-in day.
            if user_goal['train']['info']['destination'] == 'cambridge' and \
                    'day' in user_goal['hotel']['book']:
                user_goal['train']['info']['day'] = user_goal['hotel']['book'][
                    'day']
            # Leaving cambridge: travel on check-in day + number of nights,
            # wrapping modulo 7 (presumably `days` lists the seven weekdays).
            elif user_goal['train']['info']['departure'] == 'cambridge' and \
                    'day' in user_goal['hotel']['book'] and 'stay' in user_goal['hotel']['book']:
                user_goal['train']['info']['day'] = days[
                    (days.index(user_goal['hotel']['book']['day']) +
                     int(user_goal['hotel']['book']['stay'])) % 7]
            # In case, we have no query results with adjusted train goal, we simply drop the train goal.
            if len(
                    self.dbquery.query(
                        'train', user_goal['train']['info'].items())) == 0:
                del user_goal['train']
                domain_ordering.remove('train')

        user_goal['domain_ordering'] = domain_ordering

        return user_goal

    def _adjust_info(self, domain, info):
        # adjust one of the slots of the info
        # 随机选择一个slot进行随机替换
        adjusted_info = deepcopy(info)
        slot = random.choice(list(info.keys()))
        adjusted_info[slot] = random.choice(
            list(self.ind_slot_value_dist[domain]['info'][slot].keys()))
        return adjusted_info

    def build_message(self, user_goal, boldify=null_boldify):
        """Render ``user_goal`` as a list of instruction sentences.

        Domains are processed in ``user_goal['domain_ordering']``. For each
        domain the following are appended, in order:

        * the domain intro plus the ``info`` constraints (``fail_info`` when
          present), emitted in random groups of 1-3 slots. A taxi whose
          constraint dict has exactly one slot is instead described as a
          commute between two earlier places (when at least two exist).
        * for ``fail_info``: the ``'fail_info <slot>'`` template of the first
          slot whose value differs between ``info`` and ``fail_info``, filled
          with the ``info`` value.
        * the requested slots (``reqt``).
        * the booking constraints (``book``, or ``fail_book`` when present).
        * for ``fail_book``: the ``'fail_book <slot>'`` template of the first
          differing booking slot, filled with the ``book`` value.

        Args:
            user_goal: goal dict as returned by ``get_user_goal``. It must
                hold a ``'domain_ordering'`` list plus one goal dict per
                listed domain.
            boldify: only compared with ``do_boldify`` at the end, to decide
                whether the words "wifi", "internet" and "parking" get wrapped
                in ``<b>`` tags. Slot values are always formatted with
                ``self.boldify``.

        Returns:
            list of str.
        """
        message = []

        def fill_info_template(user_goal, domain, slot, info):
            """Format the ``info``-section template of one slot.

            If restaurant and attraction share the same area, the later of
            the two says "same area as the <other>" instead of repeating the
            area value.
            """
            if slot != 'area' or not (
                    'restaurant' in user_goal and 'attraction' in user_goal
                    and info in user_goal['restaurant'].keys()
                    and info in user_goal['attraction'].keys()
                    and 'area' in user_goal['restaurant'][info]
                    and 'area' in user_goal['attraction'][info]
                    and user_goal['restaurant'][info]['area']
                    == user_goal['attraction'][info]['area']):
                return templates[domain][slot].format(
                    self.boldify(user_goal[domain][info][slot]))
            restaurant_index = user_goal['domain_ordering'].index('restaurant')
            attraction_index = user_goal['domain_ordering'].index('attraction')
            if restaurant_index > attraction_index and domain == 'restaurant':
                return templates[domain][slot].format(
                    self.boldify('same area as the attraction'))
            if attraction_index > restaurant_index and domain == 'attraction':
                return templates[domain][slot].format(
                    self.boldify('same area as the restaurant'))
            return templates[domain][slot].format(
                self.boldify(user_goal[domain][info][slot]))

        def get_same_people_domain(user_goal, domain, slot):
            """Return the earliest preceding restaurant/hotel/train domain
            whose booking has the same ``slot`` value ('day'/'people' only),
            or None."""
            if slot not in ['day', 'people']:
                return None
            domain_index = user_goal['domain_ordering'].index(domain)
            for prev in user_goal['domain_ordering'][:domain_index]:
                if prev in ['restaurant', 'hotel', 'train'] and 'book' in user_goal[prev] and \
                        slot in user_goal[prev]['book'] and user_goal[prev]['book'][slot] == \
                        user_goal[domain]['book'][slot]:
                    return prev
            return None

        def first_changed_slot(goal_slots, failed_slots):
            """Return the first key of ``goal_slots`` whose value differs in
            ``failed_slots``, or None if every value matches.

            Compares by key instead of zipping the two ``items()`` views
            positionally, so the result does not depend on both dicts having
            identical keys in identical insertion order.
            """
            for slot, value in goal_slots.items():
                if failed_slots.get(slot) != value:
                    return slot
            return None

        for dom in user_goal['domain_ordering']:
            state = deepcopy(user_goal[dom])

            info = 'fail_info' if 'fail_info' in user_goal[dom] else 'info'
            taxi_commute = dom == 'taxi' and len(state[info]) == 1

            # Always start each domain with a fresh sentence buffer. Resetting
            # it only when an intro was emitted let the taxi commute path
            # append to a buffer left over from an earlier domain (or hit an
            # unbound name when taxi came first).
            m = [] if taxi_commute else [templates[dom]['intro']]

            # info
            if taxi_commute:
                taxi_index = user_goal['domain_ordering'].index('taxi')
                places = [
                    prev for prev in user_goal['domain_ordering'][:taxi_index]
                    if prev in ['attraction', 'hotel', 'restaurant']
                ]
                if len(places) >= 2:
                    random.shuffle(places)
                    m.append(templates['taxi']['commute'])
                    if 'arriveBy' in state[info]:
                        m.append(
                            'The taxi should arrive at the {} from the {} by {}.'
                            .format(self.boldify(places[0]),
                                    self.boldify(places[1]),
                                    self.boldify(state[info]['arriveBy'])))
                    elif 'leaveAt' in state[info]:
                        m.append(
                            'The taxi should leave from the {} to the {} after {}.'
                            .format(self.boldify(places[0]),
                                    self.boldify(places[1]),
                                    self.boldify(state[info]['leaveAt'])))
                    message.append(' '.join(m))
            else:
                # Emit the constraints in random groups of 1-3 slots; the
                # first group is prefixed with the intro.
                while len(state[info]) > 0:
                    num_acts = random.randint(1, min(len(state[info]), 3))
                    slots = random.sample(list(state[info].keys()), num_acts)
                    sents = [
                        fill_info_template(user_goal, dom, slot, info)
                        for slot in slots
                        if slot not in ['parking', 'internet']
                    ]
                    if 'parking' in slots:
                        sents.append(templates[dom]['parking ' +
                                                    state[info]['parking']])
                    if 'internet' in slots:
                        sents.append(templates[dom]['internet ' +
                                                    state[info]['internet']])
                    m.extend(sents)
                    message.append(' '.join(m))
                    m = []
                    for slot in slots:
                        del state[info][slot]

            # fail_info: tell the user which constraint to relax.
            if 'fail_info' in user_goal[dom]:
                adjusted_slot = first_changed_slot(
                    user_goal[dom]['info'], user_goal[dom]['fail_info'])
                if adjusted_slot is None:
                    pass  # nothing differs, so there is nothing to relax
                elif adjusted_slot in ['internet', 'parking']:
                    message.append(
                        templates[dom]['fail_info ' + adjusted_slot + ' ' +
                                       user_goal[dom]['info'][adjusted_slot]])
                else:
                    message.append(
                        templates[dom]['fail_info ' + adjusted_slot].format(
                            self.boldify(
                                user_goal[dom]['info'][adjusted_slot])))

            # reqt
            if 'reqt' in state:
                slot_strings = []
                for slot in state['reqt']:
                    if slot in ['internet', 'parking', 'food']:
                        continue  # these get dedicated sentences below
                    slot_strings.append(slot if slot not in
                                        request_slot_string_map else
                                        request_slot_string_map[slot])
                if len(slot_strings) > 0:
                    message.append(templates[dom]['request'].format(
                        self.boldify(', '.join(slot_strings))))
                if 'internet' in state['reqt']:
                    message.append(
                        'Make sure to ask if the hotel includes free wifi.')
                if 'parking' in state['reqt']:
                    message.append(
                        'Make sure to ask if the hotel includes free parking.')
                if 'food' in state['reqt']:
                    message.append(
                        'Make sure to ask about what food it serves.')

            # book
            book = 'fail_book' if 'fail_book' in user_goal[dom] else 'book'
            if 'book' in state:
                slot_strings = []
                for slot in ['people', 'time', 'day', 'stay']:
                    if slot not in state[book]:
                        continue
                    if slot == 'people':
                        same_people_domain = get_same_people_domain(
                            user_goal, dom, slot)
                        if same_people_domain is None:
                            slot_strings.append('for {} people'.format(
                                self.boldify(state[book][slot])))
                        else:
                            slot_strings.append(
                                self.boldify(
                                    'for the same group of people as the {} booking'
                                    .format(same_people_domain)))
                    elif slot == 'time':
                        slot_strings.append('at {}'.format(
                            self.boldify(state[book][slot])))
                    elif slot == 'day':
                        same_people_domain = get_same_people_domain(
                            user_goal, dom, slot)
                        if same_people_domain is None:
                            slot_strings.append('on {}'.format(
                                self.boldify(state[book][slot])))
                        else:
                            slot_strings.append(
                                self.boldify(
                                    'on the same day as the {} booking'.
                                    format(same_people_domain)))
                    elif slot == 'stay':
                        slot_strings.append('for {} nights'.format(
                            self.boldify(state[book][slot])))
                    del state[book][slot]

                # Every booking slot must have been rendered above.
                assert len(state[book]) <= 0, state[book]

                if len(slot_strings) > 0:
                    message.append(templates[dom]['book'].format(
                        ' '.join(slot_strings)))

            # fail_book: tell the user which booking detail to change.
            if 'fail_book' in user_goal[dom]:
                adjusted_slot = first_changed_slot(
                    user_goal[dom]['book'], user_goal[dom]['fail_book'])
                if adjusted_slot is None:
                    pass  # nothing differs, so there is nothing to change
                elif adjusted_slot in ['internet', 'parking']:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot + ' ' +
                                       user_goal[dom]['book'][adjusted_slot]])
                else:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot].format(
                            self.boldify(
                                user_goal[dom]['book'][adjusted_slot])))

        if boldify == do_boldify:
            message = [
                sentence.replace('wifi', "<b>wifi</b>")
                .replace('internet', "<b>internet</b>")
                .replace('parking', "<b>parking</b>")
                for sentence in message
            ]

        return message