class StateTracker(object):
    """Apply system/user dialog acts to a dialog state dict.

    The tracker keeps a turn counter (``time_step``), the domain currently
    under discussion (``topic``, ``''`` when none) and forwards the realised
    dialog acts to an evaluator unless a rollout is in progress.
    """

    def __init__(self, data_dir, config):
        """
        Args:
            data_dir: directory handed to ``DBQuery`` and ``MultiWozEvaluator``.
            config: configuration object; the methods below read
                ``idx2da``, ``idx2da_u``, ``a_dim_usr``, ``mapping`` and
                ``belief_domains`` from it.
        """
        self.time_step = 0
        self.cfg = config
        self.db = DBQuery(data_dir, config)
        self.topic = ''
        self.evaluator = MultiWozEvaluator(data_dir)
        # Attribute name is misspelled ("evalutor"); kept as-is because code
        # outside this class may read it.
        self.lock_evalutor = False

    def set_rollout(self, rollout):
        """Enter or leave rollout mode.

        Entering (``rollout`` truthy) snapshots ``time_step`` and ``topic``
        and locks the evaluator; leaving restores both values and unlocks
        the evaluator. Leaving without a prior enter raises AttributeError
        because no snapshot exists.
        """
        if rollout:
            self.save_time_step = self.time_step
            self.save_topic = self.topic
            self.lock_evalutor = True
        else:
            self.time_step = self.save_time_step
            # The original assigned ``self.save_topic = self.topic`` here,
            # which never restored the topic changed during the rollout.
            self.topic = self.save_topic
            self.lock_evalutor = False

    def get_entities(self, s, domain):
        """Query the database with the known constraints of ``domain``.

        Belief-state slots whose value is ``'?'`` (pending requests) or that
        have no entry in ``cfg.mapping[domain]`` are ignored.

        Returns:
            list: matching entities in random order.
        """
        origin = s['belief_state'][domain].items()
        constraint = []
        for k, v in origin:
            if v != '?' and k in self.cfg.mapping[domain]:
                constraint.append((self.cfg.mapping[domain][k], v))
        entities = self.db.query(domain, constraint)
        random.shuffle(entities)
        return entities

    def update_belief_sys(self, old_s, a):
        """Update belief/goal state with a system action.

        Args:
            old_s (dict): current state; it is deep-copied, not mutated.
            a: multi-hot action vector; non-zero indices are looked up in
                ``cfg.idx2da`` and must map to ``domain-intent-slot-p``.

        Returns:
            dict: the new state with ``sys_action`` filled in.
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices

        self.time_step += 1
        s['others']['turn'] = self.time_step

        # update sys/user dialog act
        s['sys_action'] = dict()

        # update belief part
        das = [self.cfg.idx2da[idx.item()] for idx in a_index]
        das = [da.split('-') for da in das]
        # NOTE(review): the original called ``sorted(das, key=...)`` here and
        # discarded the result, so acts have always been processed in index
        # order. That order is kept; confirm whether sorting by domain was
        # really intended before changing it.

        entities = [] if self.topic == '' else self.get_entities(s, self.topic)
        return_flag = False
        for domain, intent, slot, p in das:
            if domain in self.cfg.belief_domains and domain != self.topic:
                # topic switch: re-query entities for the new domain
                self.topic = domain
                entities = self.get_entities(s, domain)

            da = '-'.join((domain, intent, slot, p))
            if intent == 'request':
                s['sys_action'][da] = '?'
            elif intent in ['nooffer', 'nobook'] and self.topic != '':
                return_flag = True
                topic_belief = s['belief_state'][self.topic]
                if slot in topic_belief and topic_belief[slot] != '?':
                    s['sys_action'][da] = topic_belief[slot]
                else:
                    s['sys_action'][da] = 'none'
            elif slot == 'choice':
                s['sys_action'][da] = str(len(entities))
            elif slot == 'none':
                s['sys_action'][da] = 'none'
            else:
                # ``p`` is a 1-based index into the queried entities
                num = int(p) - 1
                if (self.topic and len(entities) > num
                        and slot in self.cfg.mapping[self.topic]):
                    typ = self.cfg.mapping[self.topic][slot]
                    if typ in entities[num]:
                        s['sys_action'][da] = entities[num][typ]
                    else:
                        s['sys_action'][da] = 'none'
                else:
                    s['sys_action'][da] = 'none'

            if not self.topic:
                continue

            if intent in ['inform', 'recommend', 'offerbook', 'offerbooked',
                          'book']:
                # the system answered: drop the pending '?' request
                discard(s['belief_state'][self.topic], slot, '?')
                if (slot in s['user_goal'][self.topic]
                        and s['user_goal'][self.topic][slot] == '?'):
                    s['goal_state'][self.topic][slot] = s['sys_action'][da]

            # booked
            if intent == 'inform' and slot == 'car':  # taxi
                if 'booked' not in s['belief_state']['taxi']:
                    s['belief_state']['taxi']['booked'] = 'taxi-booked'
            elif intent in ['offerbooked', 'book'] and slot == 'ref':  # train
                if self.topic in ['taxi', 'hospital', 'police']:
                    booked = f'{self.topic}-booked'
                    s['belief_state'][self.topic]['booked'] = booked
                    s['sys_action'][da] = booked
                elif entities:
                    ref = entities[0]['ref']
                    book_domain = ref.split('-')[0]
                    if 'booked' not in s['belief_state'][book_domain]:
                        s['belief_state'][book_domain]['booked'] = ref
                        s['sys_action'][da] = ref

        if return_flag:
            # After a nooffer/nobook, forget what the user just informed for
            # the current topic and reload that topic's goal state.
            for da in s['user_action']:
                d_usr, i_usr, s_usr = da.split('-')
                if i_usr == 'inform' and d_usr == self.topic:
                    discard(s['belief_state'][d_usr], s_usr)
            reload(s['goal_state'], s['user_goal'], self.topic)

        if not self.lock_evalutor:
            self.evaluator.add_sys_da(s['sys_action'])

        return s

    def update_belief_usr(self, old_s, a):
        """Update belief/goal state with a user action.

        Args:
            old_s (dict): current state; it is deep-copied, not mutated.
            a: multi-hot action vector; the last index
                (``cfg.a_dim_usr - 1``) marks the terminal act, the others
                are looked up in ``cfg.idx2da_u`` and must map to
                ``domain-intent-slot``.

        Returns:
            dict: the new state with ``user_action`` filled in.
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices

        self.time_step += 1
        s['others']['turn'] = self.time_step
        s['others']['terminal'] = 1 if (self.cfg.a_dim_usr - 1) in a_index else 0

        # update sys/user dialog act
        s['user_action'] = dict()

        # update belief part
        das = [
            self.cfg.idx2da_u[idx.item()] for idx in a_index
            if idx.item() != self.cfg.a_dim_usr - 1
        ]
        das = [da.split('-') for da in das]
        if s['invisible_domains']:
            # The user touched the next available domain: promote the first
            # still-invisible domain to be the next available one.
            for da in das:
                if da[0] == s['next_available_domain']:
                    s['next_available_domain'] = s['invisible_domains'][0]
                    s['invisible_domains'].remove(s['next_available_domain'])
                    break
        # NOTE(review): as in update_belief_sys, the original discarded the
        # result of ``sorted(das, ...)``; acts are processed in index order.

        for domain, intent, slot in das:
            if domain in self.cfg.belief_domains and domain != self.topic:
                self.topic = domain

            da = '-'.join((domain, intent, slot))
            if intent == 'request':
                s['user_action'][da] = '?'
                s['belief_state'][self.topic][slot] = '?'
            elif slot == 'none':
                s['user_action'][da] = 'none'
            else:
                if (self.topic and slot in s['user_goal'][self.topic]
                        and s['user_goal'][domain][slot] != '?'):
                    s['user_action'][da] = s['user_goal'][domain][slot]
                else:
                    s['user_action'][da] = 'dont care'

            if not self.topic:
                continue

            if intent == 'inform':
                s['belief_state'][domain][slot] = s['user_action'][da]
                if (slot in s['user_goal'][self.topic]
                        and s['user_goal'][self.topic][slot] != '?'):
                    # constraint has been expressed: remove it from the
                    # remaining goal state
                    discard(s['goal_state'][self.topic], slot)

        if not self.lock_evalutor:
            self.evaluator.add_usr_da(s['user_action'])

        return s

    def reset(self, random_seed=None):
        """
        Args:
            random_seed (int):
        Returns:
            init_state (dict):
        """
        pass

    def step(self, s, sys_a):
        """
        Args:
            s (dict):
            sys_a (vector):
        Returns:
            next_s (dict):
            terminal (bool):
        """
        pass
class GoalGenerator:
    """User goal generator.

    Samples multi-domain user goals from slot/value/domain-ordering
    distributions that are either loaded from a pickled goal model or
    estimated from a dialog corpus, and renders goals as text messages.
    """

    def __init__(self, data_dir, goal_model_path, corpus_path=None, boldify=False):
        """
        Args:
            data_dir: base directory; also passed to ``DBQuery``.
            goal_model_path: path to a goal model, relative to ``data_dir``.
            corpus_path: path to a dialog corpus to build a goal model,
                relative to ``data_dir``; only read when no goal model file
                exists yet.
            boldify: if truthy, slot values in messages are wrapped by
                ``do_boldify``; otherwise ``null_boldify`` is used.
        """
        self.dbquery = DBQuery(data_dir)
        self.goal_model_path = data_dir + '/' + goal_model_path
        self.corpus_path = data_dir + '/' + corpus_path if corpus_path is not None else None
        self.boldify = do_boldify if boldify else null_boldify
        if os.path.exists(self.goal_model_path):
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only point goal_model_path at trusted files.
            with open(self.goal_model_path, 'rb') as f:
                (self.ind_slot_dist, self.ind_slot_value_dist,
                 self.domain_ordering_dist, self.book_dist) = pickle.load(f)
            logging.info('Loading goal model is done')
        else:
            self._build_goal_model()
            logging.info('Building goal model is done')

        # remove some slot
        del self.ind_slot_dist['police']['reqt']['postcode']
        del self.ind_slot_value_dist['police']['reqt']['postcode']
        del self.ind_slot_dist['hospital']['reqt']['postcode']
        del self.ind_slot_value_dist['hospital']['reqt']['postcode']
        del self.ind_slot_dist['hospital']['reqt']['address']
        del self.ind_slot_value_dist['hospital']['reqt']['address']

    def _build_goal_model(self):
        """Estimate the goal distributions from the corpus and pickle them.

        Sets ``domain_ordering_dist``, ``ind_slot_dist``,
        ``ind_slot_value_dist`` and ``book_dist`` and writes them to
        ``goal_model_path``.
        """
        with open(self.corpus_path) as f:
            dialogs = json.load(f)

        # domain ordering
        def _get_dialog_domains(dialog):
            return [x for x in dialog['goal']
                    if x in domains and len(dialog['goal'][x]) > 0]

        domain_orderings = []
        for d in dialogs:
            d_domains = _get_dialog_domains(dialogs[d])
            first_index = []
            for domain in d_domains:
                raw_message = dialogs[d]['goal']['message']
                message = [raw_message] if isinstance(raw_message, str) else raw_message
                # index of the first message sentence that mentions the domain
                for i, m in enumerate(message):
                    if domain_keywords[domain].lower() in m.lower() or domain.lower() in m.lower():
                        first_index.append(i)
                        break
            ordered = sorted(zip(first_index, d_domains), key=lambda pair: pair[0])
            domain_orderings.append(tuple(dom for _, dom in ordered))
        domain_ordering_cnt = Counter(domain_orderings)
        self.domain_ordering_dist = deepcopy(domain_ordering_cnt)
        ordering_total = sum(domain_ordering_cnt.values())
        for order in domain_ordering_cnt.keys():
            self.domain_ordering_dist[order] = domain_ordering_cnt[order] / ordering_total

        # independent goal slot distribution
        ind_slot_value_cnt = {domain: {} for domain in domains}
        domain_cnt = Counter()
        book_cnt = Counter()

        for d in dialogs:
            for domain in domains:
                goal = dialogs[d]['goal'][domain]
                if goal != {}:
                    domain_cnt[domain] += 1
                if 'info' in goal:
                    for slot in goal['info']:
                        if 'invalid' in slot:
                            continue
                        info_cnt = ind_slot_value_cnt[domain].setdefault('info', {})
                        slot_cnt = info_cnt.setdefault(slot, Counter())
                        if 'care' in goal['info'][slot]:
                            continue
                        slot_cnt[goal['info'][slot]] += 1
                if 'reqt' in goal:
                    for slot in goal['reqt']:
                        reqt_cnt = ind_slot_value_cnt[domain].setdefault('reqt', Counter())
                        reqt_cnt[slot] += 1
                if 'book' in goal:
                    book_cnt[domain] += 1
                    for slot in goal['book']:
                        if 'invalid' in slot:
                            continue
                        book_slot_cnt = ind_slot_value_cnt[domain].setdefault('book', {})
                        slot_cnt = book_slot_cnt.setdefault(slot, Counter())
                        if 'care' in goal['book'][slot]:
                            continue
                        slot_cnt[goal['book'][slot]] += 1

        self.ind_slot_value_dist = deepcopy(ind_slot_value_cnt)
        self.ind_slot_dist = {domain: {} for domain in domains}
        self.book_dist = {}
        for domain in domains:
            for kind in ('info', 'book'):
                # probability of a slot appearing, and of each value given
                # that the slot appears
                if kind in ind_slot_value_cnt[domain]:
                    for slot in ind_slot_value_cnt[domain][kind]:
                        slot_dist = self.ind_slot_dist[domain].setdefault(kind, {})
                        slot_total = sum(ind_slot_value_cnt[domain][kind][slot].values())
                        slot_dist[slot] = slot_total / domain_cnt[domain]
                        for val in self.ind_slot_value_dist[domain][kind][slot]:
                            self.ind_slot_value_dist[domain][kind][slot][val] = \
                                ind_slot_value_cnt[domain][kind][slot][val] / slot_total
            if 'reqt' in ind_slot_value_cnt[domain]:
                for slot in ind_slot_value_cnt[domain]['reqt']:
                    reqt_dist = self.ind_slot_dist[domain].setdefault('reqt', {})
                    prob = ind_slot_value_cnt[domain]['reqt'][slot] / domain_cnt[domain]
                    reqt_dist[slot] = prob
                    self.ind_slot_value_dist[domain]['reqt'][slot] = prob
            self.book_dist[domain] = book_cnt[domain] / len(dialogs)

        with open(self.goal_model_path, 'wb') as f:
            pickle.dump((self.ind_slot_dist, self.ind_slot_value_dist,
                         self.domain_ordering_dist, self.book_dist), f)

    def _get_domain_goal(self, domain):
        """Sample a goal for one domain.

        Resamples until the goal has at least one request or a booking and
        its (possibly adjusted) info constraints match a database entry.

        Returns:
            dict: always has ``'info'``; may have ``'reqt'``, ``'book'``,
            ``'fail_book'`` and ``'fail_info'``.
        """
        cnt_slot = self.ind_slot_dist[domain]
        cnt_slot_value = self.ind_slot_value_dist[domain]
        pro_book = self.book_dist[domain]

        while True:
            domain_goal = {'info': {}}
            info = domain_goal['info']
            # inform
            if 'info' in cnt_slot:
                slot_prob = cnt_slot['info']
                for slot in slot_prob:
                    if random.random() < slot_prob[slot] + pro_correction['info']:
                        info[slot] = nomial_sample(cnt_slot_value['info'][slot])

                # a name constraint is kept alone or dropped
                if domain in ['hotel', 'restaurant', 'attraction'] and 'name' in info and len(info) > 1:
                    if random.random() < slot_prob['name']:
                        domain_goal['info'] = {'name': info['name']}
                        info = domain_goal['info']
                    else:
                        del info['name']

                # exactly one of arriveBy / leaveAt for taxi and train
                if domain in ['taxi', 'train'] and 'arriveBy' in info and 'leaveAt' in info:
                    if random.random() < (
                            slot_prob['leaveAt'] / (slot_prob['arriveBy'] + slot_prob['leaveAt'])):
                        del info['arriveBy']
                    else:
                        del info['leaveAt']

                if domain in ['taxi', 'train'] and 'arriveBy' not in info and 'leaveAt' not in info:
                    if random.random() < (
                            slot_prob['arriveBy'] / (slot_prob['arriveBy'] + slot_prob['leaveAt'])):
                        info['arriveBy'] = nomial_sample(cnt_slot_value['info']['arriveBy'])
                    else:
                        info['leaveAt'] = nomial_sample(cnt_slot_value['info']['leaveAt'])

                if domain in ['taxi', 'train'] and 'departure' not in info:
                    info['departure'] = nomial_sample(cnt_slot_value['info']['departure'])

                if domain in ['taxi', 'train'] and 'destination' not in info:
                    info['destination'] = nomial_sample(cnt_slot_value['info']['destination'])

                # resample one end when departure and destination coincide
                if domain in ['taxi', 'train'] and \
                        'departure' in info and \
                        'destination' in info and \
                        info['departure'] == info['destination']:
                    if random.random() < (slot_prob['departure'] / (
                            slot_prob['departure'] + slot_prob['destination'])):
                        info['departure'] = nomial_sample(cnt_slot_value['info']['departure'])
                    else:
                        info['destination'] = nomial_sample(cnt_slot_value['info']['destination'])
                if info == {}:
                    continue

            # request
            if 'reqt' in cnt_slot:
                reqt = [slot for slot in cnt_slot['reqt']
                        if random.random() < cnt_slot['reqt'][slot] + pro_correction['reqt']
                        and slot not in domain_goal['info']]
                if len(reqt) > 0:
                    domain_goal['reqt'] = reqt

            # book
            if 'book' in cnt_slot and random.random() < pro_book + pro_correction['book']:
                if 'book' not in domain_goal:
                    domain_goal['book'] = {}
                book = domain_goal['book']

                for slot in cnt_slot['book']:
                    if random.random() < cnt_slot['book'][slot] + pro_correction['book']:
                        book[slot] = nomial_sample(cnt_slot_value['book'][slot])

                # makes sure that there are all necessary slots for booking
                if domain == 'restaurant' and 'time' not in book:
                    book['time'] = nomial_sample(cnt_slot_value['book']['time'])

                if domain == 'hotel' and 'stay' not in book:
                    book['stay'] = nomial_sample(cnt_slot_value['book']['stay'])

                if domain in ['hotel', 'restaurant'] and 'day' not in book:
                    book['day'] = nomial_sample(cnt_slot_value['book']['day'])

                if domain in ['hotel', 'restaurant'] and 'people' not in book:
                    book['people'] = nomial_sample(cnt_slot_value['book']['people'])

                if domain == 'train' and len(book) <= 0:
                    book['people'] = nomial_sample(cnt_slot_value['book']['people'])

            # fail_book
            if 'book' in domain_goal and random.random() < 0.5:
                book = domain_goal['book']
                if domain == 'hotel':
                    domain_goal['fail_book'] = deepcopy(book)
                    if 'stay' in book and random.random() < 0.5:
                        # increase hotel-stay
                        domain_goal['fail_book']['stay'] = str(int(book['stay']) + 1)
                    elif 'day' in book:
                        # push back hotel-day by a day
                        domain_goal['fail_book']['day'] = days[(days.index(book['day']) - 1) % 7]

                elif domain == 'restaurant':
                    domain_goal['fail_book'] = deepcopy(book)
                    if 'time' in book and random.random() < 0.5:
                        hour, minute = book['time'].split(':')
                        domain_goal['fail_book']['time'] = str((int(hour) + 1) % 24) + ':' + minute
                    elif 'day' in book:
                        if random.random() < 0.5:
                            domain_goal['fail_book']['day'] = days[(days.index(book['day']) - 1) % 7]
                        else:
                            domain_goal['fail_book']['day'] = days[(days.index(book['day']) + 1) % 7]

            # fail_info
            if 'info' in domain_goal and len(self.dbquery.query(domain, domain_goal['info'].items())) == 0:
                num_trial = 0
                while num_trial < 100:
                    adjusted_info = self._adjust_info(domain, domain_goal['info'])
                    if len(self.dbquery.query(domain, adjusted_info.items())) > 0:
                        if domain == 'train':
                            domain_goal['info'] = adjusted_info
                        else:
                            domain_goal['fail_info'] = domain_goal['info']
                            domain_goal['info'] = adjusted_info
                        break
                    num_trial += 1

                if num_trial >= 100:
                    # no satisfiable adjustment found: sample a new goal
                    continue

            # at least there is one request and book
            if 'reqt' in domain_goal or 'book' in domain_goal:
                break

        return domain_goal

    def get_user_goal(self, seed=None):
        """Sample a complete user goal.

        Args:
            seed: if not None, seeds ``random`` and ``np.random`` first so
                the goal is reproducible.

        Returns:
            dict: one entry per sampled domain plus ``'domain_ordering'``
            (tuple of the domains that remain in the goal).
        """
        # seed the generator to get fixed goal
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)

        domain_ordering = ()
        while len(domain_ordering) <= 0:
            domain_ordering = nomial_sample(self.domain_ordering_dist)

        user_goal = {dom: self._get_domain_goal(dom) for dom in domain_ordering}
        assert len(user_goal.keys()) > 0

        # using taxi to commute between places, removing destination and departure.
        if 'taxi' in domain_ordering:
            places = [dom for dom in domain_ordering[: domain_ordering.index('taxi')]
                      if 'address' in self.ind_slot_dist[dom]['reqt'].keys()]
            if len(places) >= 1:
                del user_goal['taxi']['info']['destination']
                user_goal[places[-1]]['reqt'] = list(
                    set(user_goal[places[-1]].get('reqt', [])).union({'address'}))
                if places[-1] == 'restaurant' and 'book' in user_goal['restaurant']:
                    user_goal['taxi']['info']['arriveBy'] = user_goal['restaurant']['book']['time']
                    if 'leaveAt' in user_goal['taxi']['info']:
                        del user_goal['taxi']['info']['leaveAt']
            if len(places) >= 2:
                del user_goal['taxi']['info']['departure']
                user_goal[places[-2]]['reqt'] = list(
                    set(user_goal[places[-2]].get('reqt', [])).union({'address'}))

        # match area of attraction and restaurant
        if 'restaurant' in domain_ordering and \
                'attraction' in domain_ordering and \
                'fail_info' not in user_goal['restaurant'] and \
                domain_ordering.index('restaurant') > domain_ordering.index('attraction') and \
                'area' in user_goal['restaurant']['info'] and 'area' in user_goal['attraction']['info']:
            adjusted_restaurant_goal = deepcopy(user_goal['restaurant']['info'])
            adjusted_restaurant_goal['area'] = user_goal['attraction']['info']['area']
            if len(self.dbquery.query('restaurant', adjusted_restaurant_goal.items())) > 0 \
                    and random.random() < 0.5:
                user_goal['restaurant']['info']['area'] = user_goal['attraction']['info']['area']

        # match day and people of restaurant and hotel
        if 'restaurant' in domain_ordering and 'hotel' in domain_ordering and \
                'book' in user_goal['restaurant'] and 'book' in user_goal['hotel']:
            restaurant = user_goal['restaurant']
            hotel_book = user_goal['hotel']['book']
            if random.random() < 0.5:
                restaurant['book']['people'] = hotel_book['people']
                if 'fail_book' in restaurant:
                    restaurant['fail_book']['people'] = hotel_book['people']
            if random.random() < 1.0:
                restaurant['book']['day'] = hotel_book['day']
                if 'fail_book' in restaurant:
                    restaurant['fail_book']['day'] = hotel_book['day']
                    # fail_book that no longer differs from book is pointless
                    if restaurant['book']['day'] == restaurant['fail_book']['day'] and \
                            restaurant['book']['time'] == restaurant['fail_book']['time'] and \
                            restaurant['book']['people'] == restaurant['fail_book']['people']:
                        del restaurant['fail_book']

        # match day and people of hotel and train
        if 'hotel' in domain_ordering and 'train' in domain_ordering and \
                'book' in user_goal['hotel'] and 'info' in user_goal['train']:
            train_info = user_goal['train']['info']
            hotel_book = user_goal['hotel']['book']
            if train_info['destination'] == 'cambridge' and 'day' in hotel_book:
                train_info['day'] = hotel_book['day']
            elif train_info['departure'] == 'cambridge' and \
                    'day' in hotel_book and 'stay' in hotel_book:
                train_info['day'] = days[
                    (days.index(hotel_book['day']) + int(hotel_book['stay'])) % 7]
            # In case, we have no query results with adjusted train goal, we simply drop the train goal.
            if len(self.dbquery.query('train', train_info.items())) == 0:
                del user_goal['train']
                # list.remove() returns None, so the original
                # ``tuple(list(...).remove('train'))`` raised TypeError here.
                domain_ordering = tuple(dom for dom in domain_ordering if dom != 'train')

        user_goal['domain_ordering'] = domain_ordering

        return user_goal

    def _adjust_info(self, domain, info):
        """Return a copy of ``info`` with one randomly chosen slot resampled
        uniformly from the values known for that slot."""
        # adjust one of the slots of the info
        adjusted_info = deepcopy(info)
        slot = random.choice(list(info.keys()))
        adjusted_info[slot] = random.choice(list(self.ind_slot_value_dist[domain]['info'][slot].keys()))
        return adjusted_info

    def build_message(self, user_goal, boldify=null_boldify):
        """Render a user goal as a list of natural-language sentences.

        Args:
            user_goal (dict): goal as returned by ``get_user_goal``.
            boldify: only compared against ``do_boldify`` to decide whether
                the words wifi/internet/parking get ``<b>`` tags at the end;
                slot values themselves are formatted with ``self.boldify``.

        Returns:
            list[str]: message sentences in domain order.
        """

        def fill_info_template(user_goal, domain, slot, info):
            """Fill the slot template, phrasing a shared restaurant/attraction
            area as 'same area as the ...' for the later of the two domains."""
            if slot != 'area' or not ('restaurant' in user_goal and
                                      'attraction' in user_goal and
                                      info in user_goal['restaurant'].keys() and
                                      info in user_goal['attraction'].keys() and
                                      'area' in user_goal['restaurant'][info] and
                                      'area' in user_goal['attraction'][info] and
                                      user_goal['restaurant'][info]['area'] ==
                                      user_goal['attraction'][info]['area']):
                return templates[domain][slot].format(self.boldify(user_goal[domain][info][slot]))
            restaurant_index = user_goal['domain_ordering'].index('restaurant')
            attraction_index = user_goal['domain_ordering'].index('attraction')
            if restaurant_index > attraction_index and domain == 'restaurant':
                return templates[domain][slot].format(self.boldify('same area as the attraction'))
            elif attraction_index > restaurant_index and domain == 'attraction':
                return templates[domain][slot].format(self.boldify('same area as the restaurant'))
            return templates[domain][slot].format(self.boldify(user_goal[domain][info][slot]))

        def get_same_people_domain(user_goal, domain, slot):
            """Return an earlier booking domain with the same value for
            ``slot`` ('day' or 'people'), or None."""
            if slot not in ['day', 'people']:
                return None
            domain_index = user_goal['domain_ordering'].index(domain)
            previous_domains = user_goal['domain_ordering'][:domain_index]
            for prev in previous_domains:
                if prev in ['restaurant', 'hotel', 'train'] and 'book' in user_goal[prev] and \
                        slot in user_goal[prev]['book'] and user_goal[prev]['book'][slot] == \
                        user_goal[domain]['book'][slot]:
                    return prev
            return None

        message = []

        for dom in user_goal['domain_ordering']:
            # working copy: slots are deleted from it as they are verbalised
            state = deepcopy(user_goal[dom])

            if not (dom == 'taxi' and len(state['info']) == 1):
                # intro
                m = [templates[dom]['intro']]

            # info
            info = 'info'
            if 'fail_info' in user_goal[dom]:
                info = 'fail_info'
            if dom == 'taxi' and len(state[info]) == 1:
                # taxi commuting between two earlier places
                taxi_index = user_goal['domain_ordering'].index('taxi')
                places = [place for place in user_goal['domain_ordering'][:taxi_index]
                          if place in ['attraction', 'hotel', 'restaurant']]
                if len(places) >= 2:
                    random.shuffle(places)
                    m.append(templates['taxi']['commute'])
                    if 'arriveBy' in state[info]:
                        m.append('The taxi should arrive at the {} from the {} by {}.'.format(
                            self.boldify(places[0]),
                            self.boldify(places[1]),
                            self.boldify(state[info]['arriveBy'])))
                    elif 'leaveAt' in state[info]:
                        m.append('The taxi should leave from the {} to the {} after {}.'.format(
                            self.boldify(places[0]),
                            self.boldify(places[1]),
                            self.boldify(state[info]['leaveAt'])))
                    message.append(' '.join(m))
            else:
                # verbalise 1-3 randomly chosen info slots per sentence group
                while len(state[info]) > 0:
                    num_acts = random.randint(1, min(len(state[info]), 3))
                    slots = random.sample(list(state[info].keys()), num_acts)
                    sents = [fill_info_template(user_goal, dom, slot, info)
                             for slot in slots if slot not in ['parking', 'internet']]
                    if 'parking' in slots:
                        sents.append(templates[dom]['parking ' + state[info]['parking']])
                    if 'internet' in slots:
                        sents.append(templates[dom]['internet ' + state[info]['internet']])
                    m.extend(sents)
                    message.append(' '.join(m))
                    m = []
                    for slot in slots:
                        del state[info][slot]

            # fail_info
            if 'fail_info' in user_goal[dom]:
                # first slot whose value differs between info and fail_info
                adjusted_slot = list(filter(
                    lambda x: x[0][1] != x[1][1],
                    zip(user_goal[dom]['info'].items(),
                        user_goal[dom]['fail_info'].items())))[0][0][0]
                if adjusted_slot in ['internet', 'parking']:
                    message.append(templates[dom][
                        'fail_info ' + adjusted_slot + ' ' + user_goal[dom]['info'][adjusted_slot]])
                else:
                    message.append(templates[dom]['fail_info ' + adjusted_slot].format(
                        self.boldify(user_goal[dom]['info'][adjusted_slot])))

            # reqt
            if 'reqt' in state:
                slot_strings = []
                for slot in state['reqt']:
                    if slot in ['internet', 'parking', 'food']:
                        continue
                    slot_strings.append(
                        slot if slot not in request_slot_string_map else request_slot_string_map[slot])
                if len(slot_strings) > 0:
                    message.append(templates[dom]['request'].format(self.boldify(', '.join(slot_strings))))
                if 'internet' in state['reqt']:
                    message.append('Make sure to ask if the hotel includes free wifi.')
                if 'parking' in state['reqt']:
                    message.append('Make sure to ask if the hotel includes free parking.')
                if 'food' in state['reqt']:
                    message.append('Make sure to ask about what food it serves.')

            # book
            book = 'book'
            if 'fail_book' in user_goal[dom]:
                book = 'fail_book'
            if 'book' in state:
                slot_strings = []
                for slot in ['people', 'time', 'day', 'stay']:
                    if slot in state[book]:
                        if slot == 'people':
                            same_people_domain = get_same_people_domain(user_goal, dom, slot)
                            if same_people_domain is None:
                                slot_strings.append('for {} people'.format(self.boldify(state[book][slot])))
                            else:
                                slot_strings.append(self.boldify(
                                    'for the same group of people as the {} booking'.format(same_people_domain)))
                        elif slot == 'time':
                            slot_strings.append('at {}'.format(self.boldify(state[book][slot])))
                        elif slot == 'day':
                            same_people_domain = get_same_people_domain(user_goal, dom, slot)
                            if same_people_domain is None:
                                slot_strings.append('on {}'.format(self.boldify(state[book][slot])))
                            else:
                                slot_strings.append(
                                    self.boldify('on the same day as the {} booking'.format(same_people_domain)))
                        elif slot == 'stay':
                            slot_strings.append('for {} nights'.format(self.boldify(state[book][slot])))
                        del state[book][slot]

                assert len(state[book]) <= 0, state[book]

                if len(slot_strings) > 0:
                    message.append(templates[dom]['book'].format(' '.join(slot_strings)))

            # fail_book
            if 'fail_book' in user_goal[dom]:
                adjusted_slot = list(filter(
                    lambda x: x[0][1] != x[1][1],
                    zip(user_goal[dom]['book'].items(),
                        user_goal[dom]['fail_book'].items())))[0][0][0]

                if adjusted_slot in ['internet', 'parking']:
                    message.append(templates[dom][
                        'fail_book ' + adjusted_slot + ' ' + user_goal[dom]['book'][adjusted_slot]])
                else:
                    message.append(templates[dom]['fail_book ' + adjusted_slot].format(
                        self.boldify(user_goal[dom]['book'][adjusted_slot])))

        if boldify == do_boldify:
            for i, text in enumerate(message):
                text = text.replace('wifi', "<b>wifi</b>")
                text = text.replace('internet', "<b>internet</b>")
                text = text.replace('parking', "<b>parking</b>")
                message[i] = text

        return message
class StateTracker(object):
    """Apply system/user dialog acts to a dialog state dict.

    The tracker keeps a turn counter (``time_step``) and the domain
    currently under discussion (``topic``, ``'NONE'`` when none).
    """

    def __init__(self, data_dir, config):
        """
        Args:
            data_dir: directory handed to ``DBQuery``.
            config: configuration object; the methods below read
                ``idx2da``, ``idx2dau``, ``dau2idx``, ``a_dim_usr``,
                ``mapping`` and ``belief_domains`` from it.
        """
        self.time_step = 0
        self.cfg = config
        self.db = DBQuery(data_dir)
        self.topic = 'NONE'

    def _action_to_dict(self, das):
        """Group ``{'domain-intent-slot-p': value}`` acts by domain-intent.

        Returns:
            dict: ``{'domain-intent': [[slot, value], ...]}``.
        """
        da_dict = {}
        for da, value in das.items():
            domain, intent, slot, p = da.split('-')
            domint = '-'.join((domain, intent))
            da_dict.setdefault(domint, []).append([slot, value])
        return da_dict

    def _dict_to_vec(self, das):
        """Encode grouped user acts as a multi-hot int32 vector.

        ``das`` is first passed to ``expand_da`` (presumably it turns each
        pair into a ``(slot, p, value)`` triple in place - confirm against
        its definition). Acts missing from ``cfg.dau2idx`` are skipped.
        """
        da_vector = torch.zeros(self.cfg.a_dim_usr, dtype=torch.int32)
        expand_da(das)
        for domint in das:
            pairs = das[domint]
            for slot, p, value in pairs:
                da = '-'.join((domint, slot, p)).lower()
                if da in self.cfg.dau2idx:
                    idx = self.cfg.dau2idx[da]
                    da_vector[idx] = 1
        return da_vector

    def _mask_user_goal(self, goal):
        """Remove the hospital and police domains from ``goal`` in place."""
        domain_ordering = list(goal['domain_ordering'])
        if 'hospital' in goal:
            del goal['hospital']
            domain_ordering.remove('hospital')
        if 'police' in goal:
            del goal['police']
            domain_ordering.remove('police')
        goal['domain_ordering'] = tuple(domain_ordering)

    def get_entities(self, s, domain):
        """Query the database with the informed constraints of ``domain``.

        Slots without an entry in ``cfg.mapping[domain]`` are ignored.

        Returns:
            list: matching entities in random order.
        """
        origin = s['belief_state']['inform'][domain].items()
        constraint = []
        for k, v in origin:
            if k in self.cfg.mapping[domain]:
                constraint.append((self.cfg.mapping[domain][k], v))
        entities = self.db.query(domain, constraint)
        random.shuffle(entities)
        return entities

    def update_belief_sys(self, old_s, a):
        """Update belief state with a system action.

        Args:
            old_s (dict): current state; it is deep-copied, not mutated.
            a: multi-hot action vector; non-zero indices are looked up in
                ``cfg.idx2da`` and must map to ``domain-intent-slot-p``.

        Returns:
            dict: the new state with ``sys_action`` filled in.
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices

        self.time_step += 1
        s['others']['turn'] = self.time_step

        # update sys/user dialog act
        s['history']['sys'] = dict(s['history']['sys'], **s['last_sys_action'])
        del s['last_sys_action']
        s['last_user_action'] = s['user_action']
        s['user_action'] = dict()

        # update belief part
        das = [self.cfg.idx2da[idx.item()] for idx in a_index]
        das = [da.split('-') for da in das]
        # NOTE(review): the original called ``sorted(das, key=...)`` here and
        # discarded the result, so acts have always been processed in index
        # order. That order is kept; confirm whether sorting by domain was
        # really intended before changing it.

        entities = [] if self.topic == 'NONE' else self.get_entities(
            s, self.topic)
        for domain, intent, slot, p in das:
            # 'booking' acts refer to the topic active before this act
            _domain = self.topic if domain == 'booking' else domain
            if domain in self.cfg.belief_domains and domain != self.topic:
                self.topic = domain
                entities = self.get_entities(s, domain)

            da = '-'.join((domain, intent, slot, p))
            if p == 'none':
                s['sys_action'][da] = 'none'
            elif p == '?':
                s['sys_action'][da] = '?'
            elif intent in ['nooffer', 'nobook']:
                informed = s['belief_state']['inform'][_domain]
                if slot in informed:
                    s['sys_action'][da] = informed[slot]
                else:
                    s['sys_action'][da] = 'none'
            elif slot == 'choice':
                s['sys_action'][da] = str(len(entities))
            else:
                # ``p`` is a 1-based index into the queried entities
                num = int(p) - 1
                if len(entities) > num and slot in self.cfg.mapping[_domain]:
                    typ = self.cfg.mapping[_domain][slot]
                    s['sys_action'][da] = entities[num][typ]
                else:
                    s['sys_action'][da] = 'none'

            if intent == 'inform' and _domain != 'NONE':
                # the system answered: the slot is no longer requested
                s['belief_state']['request'][_domain].discard(slot)

            # booked
            if intent == 'inform' and slot == 'car':  # taxi
                if not s['belief_state']['booked']['taxi']:
                    # The original used ``==`` here, a no-op comparison, so
                    # the taxi booking was never recorded.
                    s['belief_state']['booked']['taxi'] = 'booked'
            elif intent == 'offerbooked' and slot == 'ref':  # train
                s['belief_state']['request']['train'].discard('ref')
                if not s['belief_state']['booked']['train'] and entities:
                    s['belief_state']['booked']['train'] = entities[0]['ref']
            elif intent == 'book' and slot == 'ref':  # attraction, hotel, restaurant
                if _domain not in ['attraction', 'hotel', 'restaurant']:
                    continue
                s['belief_state']['request'][_domain].discard('ref')
                if not s['belief_state']['booked'][_domain] and entities:
                    # save entity id
                    s['belief_state']['booked'][_domain] = entities[0]['ref']

        return s

    def update_belief_usr(self, old_s, a, terminal):
        """Update belief state with a user action.

        Args:
            old_s (dict): current state; it is deep-copied, not mutated.
            a: multi-hot action vector; non-zero indices are looked up in
                ``cfg.idx2dau`` and must map to ``domain-intent-slot-p``.
            terminal: stored unchanged in ``s['others']['terminal']``.

        Returns:
            dict: the new state with ``user_action`` filled in.
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices

        self.time_step += 1
        s['others']['turn'] = self.time_step
        s['others']['terminal'] = terminal

        # update sys/user dialog act
        s['history']['user'] = dict(s['history']['user'],
                                    **s['last_user_action'])
        del s['last_user_action']
        s['last_sys_action'] = s['sys_action']
        s['sys_action'] = dict()

        # update belief part
        das = [self.cfg.idx2dau[idx.item()] for idx in a_index]
        das = [da.split('-') for da in das]
        # NOTE(review): as in update_belief_sys, the original discarded the
        # result of ``sorted(das, ...)``; acts are processed in index order.

        for domain, intent, slot, p in das:
            if domain in self.cfg.belief_domains and domain != self.topic:
                self.topic = domain

            da = '-'.join((domain, intent, slot, p))
            if p == 'none':
                s['user_action'][da] = 'none'
            elif p == '?':
                s['user_action'][da] = '?'
            else:
                goal_inform = s['user_goal']['inform'][domain]
                if slot in goal_inform:
                    s['user_action'][da] = goal_inform[slot]
                else:
                    s['user_action'][da] = 'none'

            if slot != 'none':
                if intent == 'inform':
                    # update constraints with reasonable value according to user goal
                    goal_inform = s['user_goal']['inform'][domain]
                    if slot in goal_inform:
                        s['belief_state']['inform'][domain][slot] = goal_inform[slot]
                    else:
                        s['belief_state']['inform'][domain][slot] = 'none'
                elif intent == 'request':
                    s['belief_state']['request'][domain].add(slot)

        return s

    def reset(self, random_seed=None):
        """
        Args:
            random_seed (int):
        Returns:
            init_state (dict):
        """
        pass

    def step(self, s, sys_a):
        """
        Args:
            s (dict):
            sys_a (vector):
        Returns:
            next_s (dict):
            terminal (bool):
        """
        pass
class GoalGenerator:
    """User goal generator.

    Samples user goals from four distributions (slot presence, slot values,
    domain orderings and booking probability).  The distributions are loaded
    from a pickled goal model when one exists, otherwise they are estimated
    from a dialog corpus and the resulting model is pickled for later runs.
    """

    def __init__(self,
                 data_dir,
                 cfg,
                 goal_model_path,
                 corpus_path=None,
                 boldify=False):
        """Load the goal model, or build it from the corpus when it is missing.

        After construction the following attributes are populated:
            self.ind_slot_dist: per domain and goal section, how often each
                slot appears.
            self.ind_slot_value_dist: per domain, section and slot, the
                distribution over slot values.
            self.domain_ordering_dist: distribution over domain orderings.
            self.book_dist: per domain, fraction of dialogs with a booking.

        Args:
            data_dir: directory that is prefixed to ``goal_model_path`` and
                ``corpus_path`` and handed to ``DBQuery``.
            cfg: configuration object.  A truthy ``cfg.d`` restricts the
                generator to that single domain; ``cfg`` is also handed to
                ``DBQuery``.
            goal_model_path: path (relative to ``data_dir``) of the pickled
                goal model.
            corpus_path: path (relative to ``data_dir``) of the dialog corpus
                used to build the goal model when no pickle exists.
            boldify: when true, slot values in generated messages are wrapped
                with ``do_boldify``; otherwise ``null_boldify`` is used.
        """
        self.cfg = cfg
        if self.cfg.d:
            self.domains = [self.cfg.d]
        else:
            print(self.cfg)
            self.domains = {
                'attraction', 'hotel', 'restaurant', 'train', 'taxi',
                'hospital', 'police'
            }
        self.dbquery = DBQuery(data_dir, cfg)
        self.goal_model_path = data_dir + '/' + goal_model_path
        self.corpus_path = (data_dir + '/' + corpus_path
                            if corpus_path is not None else None)
        self.boldify = do_boldify if boldify else null_boldify

        if os.path.exists(self.goal_model_path):
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # the goal model must come from a trusted location.
            with open(self.goal_model_path, 'rb') as f:
                (self.ind_slot_dist, self.ind_slot_value_dist,
                 self.domain_ordering_dist, self.book_dist) = pickle.load(f)
            logging.info('Loading goal model is done')
        else:
            self._build_goal_model()
            logging.info('Building goal model is done')

        # Remove some requestable slots.  The removal tolerates domains that
        # have no 'reqt' statistics (or lack the slot) instead of raising.
        removed_reqt_slots = (
            ('police', ('postcode',)),
            ('hospital', ('postcode', 'address')),
        )
        for domain, slots in removed_reqt_slots:
            if domain not in self.ind_slot_dist:
                continue
            for dist in (self.ind_slot_dist, self.ind_slot_value_dist):
                reqt = dist.get(domain, {}).get('reqt', {})
                for slot in slots:
                    reqt.pop(slot, None)

    def _build_goal_model(self):
        """Estimate all goal distributions from the corpus and pickle them.

        Reads the JSON corpus at ``self.corpus_path``, fills
        ``self.domain_ordering_dist``, ``self.ind_slot_dist``,
        ``self.ind_slot_value_dist`` and ``self.book_dist``, then writes the
        four of them as one tuple to ``self.goal_model_path``.
        """
        with open(self.corpus_path) as f:
            dialogs = json.load(f)

        def _get_dialog_domains(dialog):
            """Collect the domains of interest that have a non-empty goal."""
            return [
                x for x in dialog['goal']
                if x in self.domains and len(dialog['goal'][x]) > 0
            ]

        # ---- domain ordering distribution ----
        domain_orderings = []
        for d in dialogs:
            d_domains = _get_dialog_domains(dialogs[d])
            if self.cfg.d:
                # Single-domain mode: skip dialogs that switch between
                # domains (or contain none).  Remove this check to study
                # multi-domain switching.
                if len(d_domains) != 1:
                    continue
            raw_message = dialogs[d]['goal']['message']
            message = [raw_message] if isinstance(raw_message, str) else raw_message

            # Find the first message line that mentions each domain.  The
            # index is stored together with its domain so that a domain which
            # is never mentioned cannot shift the pairing of the others.
            first_mentions = []
            for domain in d_domains:
                for i, m in enumerate(message):
                    if domain_keywords[domain].lower() in m.lower(
                    ) or domain.lower() in m.lower():
                        first_mentions.append((i, domain))
                        break
            # Order the domains by the line on which they first appear.
            first_mentions.sort(key=lambda x: x[0])
            domain_orderings.append(
                tuple(domain for _, domain in first_mentions))

        domain_ordering_cnt = Counter(domain_orderings)
        self.domain_ordering_dist = deepcopy(domain_ordering_cnt)
        # Convert counts to probabilities.
        total_orderings = sum(domain_ordering_cnt.values())
        for order, cnt in domain_ordering_cnt.items():
            self.domain_ordering_dist[order] = cnt / total_orderings

        # ---- independent goal slot counts ----
        ind_slot_value_cnt = {domain: {} for domain in self.domains}
        domain_cnt = Counter()
        book_cnt = Counter()

        for d in dialogs:
            goal = dialogs[d]['goal']
            for domain in self.domains:
                domain_goal = goal[domain]
                # Number of dialogs in which the domain occurs.
                if domain_goal != {}:
                    domain_cnt[domain] += 1

                # 'info': count each slot and each slot value.
                if 'info' in domain_goal:
                    for slot in domain_goal['info']:
                        if 'invalid' in slot:
                            continue
                        info_cnt = ind_slot_value_cnt[domain].setdefault(
                            'info', {})
                        slot_cnt = info_cnt.setdefault(slot, Counter())
                        slot_value = domain_goal['info'][slot]
                        # Values containing 'care' (e.g. "dontcare") are not
                        # counted, but the slot entry itself is kept.
                        if 'care' in slot_value:
                            continue
                        slot_cnt[slot_value] += 1

                # 'reqt': requested slots carry no value, count the slot only.
                if 'reqt' in domain_goal:
                    for slot in domain_goal['reqt']:
                        reqt_cnt = ind_slot_value_cnt[domain].setdefault(
                            'reqt', Counter())
                        reqt_cnt[slot] += 1

                # 'book': count bookings per domain and booking slot values.
                if 'book' in domain_goal:
                    book_cnt[domain] += 1
                    for slot in domain_goal['book']:
                        if 'invalid' in slot:
                            continue
                        book_slot_cnt = ind_slot_value_cnt[domain].setdefault(
                            'book', {})
                        slot_cnt = book_slot_cnt.setdefault(slot, Counter())
                        slot_value = domain_goal['book'][slot]
                        if 'care' in slot_value:
                            continue
                        slot_cnt[slot_value] += 1

        # ---- turn the counts into per-domain distributions ----
        self.ind_slot_value_dist = deepcopy(ind_slot_value_cnt)
        self.ind_slot_dist = {domain: {} for domain in self.domains}
        self.book_dist = {}
        for domain in self.domains:
            for section in ('info', 'reqt', 'book'):
                if section not in ind_slot_value_cnt[domain]:
                    continue
                section_cnt = ind_slot_value_cnt[domain][section]
                for slot in section_cnt:
                    section_dist = self.ind_slot_dist[domain].setdefault(
                        section, {})
                    if section == 'reqt':
                        # Requested slots have a single count, no values.
                        ratio = section_cnt[slot] / domain_cnt[domain]
                        section_dist[slot] = ratio
                        self.ind_slot_value_dist[domain]['reqt'][slot] = ratio
                        continue
                    slot_total = sum(section_cnt[slot].values())
                    section_dist[slot] = slot_total / domain_cnt[domain]
                    value_dist = self.ind_slot_value_dist[domain][section][slot]
                    for val in value_dist:
                        value_dist[val] = section_cnt[slot][val] / slot_total
            self.book_dist[domain] = book_cnt[domain] / len(dialogs)

        # ---- persist the goal model ----
        goal_model_path_dir = os.path.dirname(self.goal_model_path)
        if len(goal_model_path_dir) > 0 and not os.path.exists(
                goal_model_path_dir):
            os.makedirs(goal_model_path_dir)
        with open(self.goal_model_path, 'wb') as f:
            pickle.dump((self.ind_slot_dist, self.ind_slot_value_dist,
                         self.domain_ordering_dist, self.book_dist), f)

    def _get_domain_goal(self, domain):
        """Sample a goal for one domain from the learned distributions.

        Sampling is retried until the goal has at least one 'reqt' or 'book'
        entry and (after up to 100 single-slot adjustments) its 'info'
        constraints match at least one database entry.

        Args:
            domain: domain name, a key of the learned distributions.

        Returns:
            dict with an 'info' entry and, depending on the draw, 'reqt',
            'book', 'fail_book' and 'fail_info' entries.
        """
        cnt_slot = self.ind_slot_dist[domain]
        cnt_slot_value = self.ind_slot_value_dist[domain]
        pro_book = self.book_dist[domain]
        is_transport = domain in ['taxi', 'train']

        while True:
            domain_goal = {'info': {}}

            # ---- inform ----
            if 'info' in cnt_slot:
                for slot in cnt_slot['info']:
                    # Constrain the slot with its (corrected) probability.
                    if random.random(
                    ) < cnt_slot['info'][slot] + pro_correction['info']:
                        domain_goal['info'][slot] = nomial_sample(
                            cnt_slot_value['info'][slot])

                # hotel / restaurant / attraction: a name alone identifies the
                # entity, so either keep only the name or drop the name.
                if domain in ['hotel', 'restaurant', 'attraction'
                              ] and 'name' in domain_goal['info'] and len(
                                  domain_goal['info']) > 1:
                    if random.random() < cnt_slot['info']['name']:
                        domain_goal['info'] = {
                            'name': domain_goal['info']['name']
                        }
                    else:
                        del domain_goal['info']['name']

                # taxi / train: keep only one of leaveAt and arriveBy.
                if is_transport and 'arriveBy' in domain_goal[
                        'info'] and 'leaveAt' in domain_goal['info']:
                    if random.random() < (cnt_slot['info']['leaveAt'] /
                                          (cnt_slot['info']['arriveBy'] +
                                           cnt_slot['info']['leaveAt'])):
                        del domain_goal['info']['arriveBy']
                    else:
                        del domain_goal['info']['leaveAt']

                # ... but at least one of them has to be given.
                if is_transport and 'arriveBy' not in domain_goal['info'] and \
                        'leaveAt' not in domain_goal['info']:
                    if random.random() < (cnt_slot['info']['arriveBy'] /
                                          (cnt_slot['info']['arriveBy'] +
                                           cnt_slot['info']['leaveAt'])):
                        domain_goal['info']['arriveBy'] = nomial_sample(
                            cnt_slot_value['info']['arriveBy'])
                    else:
                        domain_goal['info']['leaveAt'] = nomial_sample(
                            cnt_slot_value['info']['leaveAt'])

                # Departure and destination are mandatory for taxi / train.
                if is_transport and 'departure' not in domain_goal['info']:
                    domain_goal['info']['departure'] = nomial_sample(
                        cnt_slot_value['info']['departure'])
                if is_transport and 'destination' not in domain_goal['info']:
                    domain_goal['info']['destination'] = nomial_sample(
                        cnt_slot_value['info']['destination'])

                # Resample one end until departure and destination differ.
                if is_transport and \
                        'departure' in domain_goal['info'] and \
                        'destination' in domain_goal['info']:
                    while domain_goal['info']['departure'] == domain_goal[
                            'info']['destination']:
                        if random.random() < (
                                cnt_slot['info']['departure'] /
                                (cnt_slot['info']['departure'] +
                                 cnt_slot['info']['destination'])):
                            domain_goal['info']['departure'] = nomial_sample(
                                cnt_slot_value['info']['departure'])
                        else:
                            domain_goal['info']['destination'] = nomial_sample(
                                cnt_slot_value['info']['destination'])

                if domain_goal['info'] == {}:
                    # Nothing to inform about: sample again.
                    continue

            # ---- request ----
            # Requested slots are drawn independently; slots already informed
            # are never requested.
            if 'reqt' in cnt_slot:
                reqt = [
                    slot for slot in cnt_slot['reqt']
                    if random.random() < cnt_slot['reqt'][slot] +
                    pro_correction['reqt'] and slot not in domain_goal['info']
                ]
                if len(reqt) > 0:
                    domain_goal['reqt'] = reqt

            # ---- book ----
            # Whether to book is drawn independently of the 'info' section.
            if 'book' in cnt_slot and random.random(
            ) < pro_book + pro_correction['book']:
                if 'book' not in domain_goal:
                    domain_goal['book'] = {}

                for slot in cnt_slot['book']:
                    if random.random(
                    ) < cnt_slot['book'][slot] + pro_correction['book']:
                        domain_goal['book'][slot] = nomial_sample(
                            cnt_slot_value['book'][slot])

                # Make sure all slots necessary for a booking are present.
                # A restaurant booking needs a time.
                if domain == 'restaurant' and 'time' not in domain_goal['book']:
                    domain_goal['book']['time'] = nomial_sample(
                        cnt_slot_value['book']['time'])
                # A hotel booking needs the length of stay.
                if domain == 'hotel' and 'stay' not in domain_goal['book']:
                    domain_goal['book']['stay'] = nomial_sample(
                        cnt_slot_value['book']['stay'])
                # Hotel and restaurant bookings need a day and a head count.
                if domain in ['hotel', 'restaurant'
                              ] and 'day' not in domain_goal['book']:
                    domain_goal['book']['day'] = nomial_sample(
                        cnt_slot_value['book']['day'])
                if domain in ['hotel', 'restaurant'
                              ] and 'people' not in domain_goal['book']:
                    domain_goal['book']['people'] = nomial_sample(
                        cnt_slot_value['book']['people'])
                # A train booking needs at least the head count.
                if domain == 'train' and len(domain_goal['book']) <= 0:
                    domain_goal['book']['people'] = nomial_sample(
                        cnt_slot_value['book']['people'])

            # ---- fail_book ----
            # Only hotel and restaurant get a failing booking variant.
            if 'book' in domain_goal and random.random() < 0.5:
                if domain == 'hotel':
                    domain_goal['fail_book'] = deepcopy(domain_goal['book'])
                    if 'stay' in domain_goal['book'] and random.random() < 0.5:
                        # Increase hotel-stay by one night.
                        domain_goal['fail_book']['stay'] = str(
                            int(domain_goal['book']['stay']) + 1)
                    elif 'day' in domain_goal['book']:
                        # Move hotel-day one day earlier.
                        domain_goal['fail_book']['day'] = days[
                            (days.index(domain_goal['book']['day']) - 1) % 7]
                elif domain == 'restaurant':
                    domain_goal['fail_book'] = deepcopy(domain_goal['book'])
                    if 'time' in domain_goal['book'] and random.random() < 0.5:
                        hour, minute = domain_goal['book']['time'].split(':')
                        # Move the time one hour later.
                        domain_goal['fail_book']['time'] = str(
                            (int(hour) + 1) % 24) + ':' + minute
                    elif 'day' in domain_goal['book']:
                        if random.random() < 0.5:
                            # One day earlier.
                            domain_goal['fail_book']['day'] = days[
                                (days.index(domain_goal['book']['day']) - 1) %
                                7]
                        else:
                            # One day later.
                            domain_goal['fail_book']['day'] = days[
                                (days.index(domain_goal['book']['day']) + 1) %
                                7]

            # ---- fail_info ----
            if 'info' in domain_goal and len(
                    self.dbquery.query(domain,
                                       domain_goal['info'].items())) == 0:
                # The sampled constraints match nothing in the database: try
                # up to 100 single-slot adjustments to find a matching goal.
                num_trial = 0
                while num_trial < 100:
                    adjusted_info = self._adjust_info(domain,
                                                      domain_goal['info'])
                    if len(self.dbquery.query(domain,
                                              adjusted_info.items())) > 0:
                        if domain == 'train':
                            # Trains get no fail_info; just use the fix.
                            domain_goal['info'] = adjusted_info
                        else:
                            domain_goal['fail_info'] = domain_goal['info']
                            domain_goal['info'] = adjusted_info
                        break
                    num_trial += 1

                if num_trial >= 100:
                    # No matching goal found: sample a fresh goal.
                    continue

            # The user must have at least one request or one booking.
            if 'reqt' in domain_goal or 'book' in domain_goal:
                break

        return domain_goal

    def get_user_goal(self, seed=None):
        """Sample a complete, cross-domain-consistent user goal.

        Args:
            seed: optional seed applied to both ``random`` and ``np.random``
                so that the same goal can be reproduced.

        Returns:
            dict mapping each sampled domain to its goal, plus a
            'domain_ordering' entry holding the list of domains in order.
        """
        # Seed the generators to get a fixed goal.
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)

        # Decide the domain ordering first; empty orderings are rejected.
        domain_ordering = []
        while not domain_ordering:
            domain_ordering = list(nomial_sample(self.domain_ordering_dist))
        np.random.shuffle(domain_ordering)

        user_goal = {
            dom: self._get_domain_goal(dom)
            for dom in domain_ordering
        }
        assert len(user_goal.keys()) > 0

        # Use the taxi to commute between places: remove destination and
        # departure and make the user request the addresses of the places
        # visited before instead.
        if 'taxi' in domain_ordering:
            places = [
                dom
                for dom in domain_ordering[:domain_ordering.index('taxi')]
                # .get: a domain without 'reqt' statistics has no address.
                if 'address' in self.ind_slot_dist[dom].get('reqt', {})
            ]
            if len(places) >= 1:
                del user_goal['taxi']['info']['destination']
                user_goal[places[-1]]['reqt'] = list(
                    set(user_goal[places[-1]].get('reqt',
                                                  [])).union({'address'}))
                # Keep the times consistent with a restaurant booking.
                if places[-1] == 'restaurant' and 'book' in user_goal[
                        'restaurant']:
                    user_goal['taxi']['info']['arriveBy'] = user_goal[
                        'restaurant']['book']['time']
                    if 'leaveAt' in user_goal['taxi']['info']:
                        del user_goal['taxi']['info']['leaveAt']
            # With two earlier places the departure is implied as well.
            if len(places) >= 2:
                del user_goal['taxi']['info']['departure']
                user_goal[places[-2]]['reqt'] = list(
                    set(user_goal[places[-2]].get('reqt',
                                                  [])).union({'address'}))

        # Match the area of attraction and restaurant.
        if 'restaurant' in domain_ordering and \
                'attraction' in domain_ordering and \
                'fail_info' not in user_goal['restaurant'] and \
                domain_ordering.index('restaurant') > domain_ordering.index('attraction') and \
                'area' in user_goal['restaurant']['info'] and \
                'area' in user_goal['attraction']['info']:
            adjusted_restaurant_goal = deepcopy(
                user_goal['restaurant']['info'])
            adjusted_restaurant_goal['area'] = user_goal['attraction']['info'][
                'area']
            # Only move the restaurant when the database has a match, and
            # then only half of the time.
            if len(
                    self.dbquery.query('restaurant',
                                       adjusted_restaurant_goal.items())
            ) > 0 and random.random() < 0.5:
                user_goal['restaurant']['info']['area'] = user_goal[
                    'attraction']['info']['area']

        # Match day and people of restaurant and hotel.  If that makes the
        # restaurant's fail_book identical to its book, fail_book is dropped.
        if 'restaurant' in domain_ordering and 'hotel' in domain_ordering and \
                'book' in user_goal['restaurant'] and 'book' in user_goal['hotel']:
            restaurant = user_goal['restaurant']
            hotel_book = user_goal['hotel']['book']
            if random.random() < 0.5:
                restaurant['book']['people'] = hotel_book['people']
                if 'fail_book' in restaurant:
                    restaurant['fail_book']['people'] = hotel_book['people']
            # Always true; the draw is kept so the random stream (and hence
            # seeded goals) stays unchanged.
            if random.random() < 1.0:
                restaurant['book']['day'] = hotel_book['day']
                if 'fail_book' in restaurant:
                    restaurant['fail_book']['day'] = hotel_book['day']
                    if restaurant['book']['day'] == restaurant['fail_book']['day'] and \
                            restaurant['book']['time'] == restaurant['fail_book']['time'] and \
                            restaurant['book']['people'] == restaurant['fail_book']['people']:
                        del restaurant['fail_book']

        # Match day and people of hotel and train.
        if 'hotel' in domain_ordering and 'train' in domain_ordering and \
                'book' in user_goal['hotel'] and 'info' in user_goal['train']:
            hotel_book = user_goal['hotel']['book']
            train_info = user_goal['train']['info']
            if train_info['destination'] == 'cambridge' and \
                    'day' in hotel_book:
                train_info['day'] = hotel_book['day']
            elif train_info['departure'] == 'cambridge' and \
                    'day' in hotel_book and 'stay' in hotel_book:
                train_info['day'] = days[(days.index(hotel_book['day']) +
                                          int(hotel_book['stay'])) % 7]
            # In case we have no query results with the adjusted train goal,
            # we simply drop the train goal.
            if len(self.dbquery.query('train', train_info.items())) == 0:
                del user_goal['train']
                domain_ordering.remove('train')

        user_goal['domain_ordering'] = domain_ordering
        return user_goal

    def _adjust_info(self, domain, info):
        """Return a copy of ``info`` with one random slot given a random value.

        Args:
            domain: domain whose value distribution supplies the candidates.
            info: dict of slot -> value constraints; it is not modified.

        Returns:
            A deep copy of ``info`` in which one randomly chosen slot has been
            replaced by a value picked uniformly from the known values of
            that slot.
        """
        adjusted_info = deepcopy(info)
        slot = random.choice(list(info.keys()))
        adjusted_info[slot] = random.choice(
            list(self.ind_slot_value_dist[domain]['info'][slot].keys()))
        return adjusted_info

    def build_message(self, user_goal, boldify=null_boldify):
        """Render a user goal as a list of natural-language instructions.

        Args:
            user_goal: goal dict as returned by ``get_user_goal``.
            boldify: when equal to ``do_boldify``, the words 'wifi',
                'internet' and 'parking' are additionally wrapped in ``<b>``
                tags in the finished messages.  Slot values are always
                formatted with ``self.boldify``.

        Returns:
            list of message strings, in domain order.
        """

        def fill_info_template(user_goal, domain, slot, info):
            """Fill the template of one informable slot of ``domain``."""
            shares_area = (
                slot == 'area' and 'restaurant' in user_goal and
                'attraction' in user_goal and
                info in user_goal['restaurant'].keys() and
                info in user_goal['attraction'].keys() and
                'area' in user_goal['restaurant'][info] and
                'area' in user_goal['attraction'][info] and
                user_goal['restaurant'][info]['area'] ==
                user_goal['attraction'][info]['area'])
            if shares_area:
                # The later of the two domains refers back to the earlier one.
                restaurant_index = user_goal['domain_ordering'].index(
                    'restaurant')
                attraction_index = user_goal['domain_ordering'].index(
                    'attraction')
                if restaurant_index > attraction_index and domain == 'restaurant':
                    return templates[domain][slot].format(
                        self.boldify('same area as the attraction'))
                if attraction_index > restaurant_index and domain == 'attraction':
                    return templates[domain][slot].format(
                        self.boldify('same area as the restaurant'))
            return templates[domain][slot].format(
                self.boldify(user_goal[domain][info][slot]))

        def get_same_people_domain(user_goal, domain, slot):
            """Find an earlier booking domain sharing this 'day'/'people'."""
            if slot not in ['day', 'people']:
                return None
            domain_index = user_goal['domain_ordering'].index(domain)
            previous_domains = user_goal['domain_ordering'][:domain_index]
            for prev in previous_domains:
                if prev in ['restaurant', 'hotel', 'train'] and 'book' in user_goal[prev] and \
                        slot in user_goal[prev]['book'] and user_goal[prev]['book'][slot] == \
                        user_goal[domain]['book'][slot]:
                    return prev
            return None

        def first_changed_slot(reference, failed):
            """Name of the first slot whose value differs (pairwise order)."""
            return [
                slot
                for (slot, value), (_, failed_value) in zip(
                    reference.items(), failed.items())
                if value != failed_value
            ][0]

        message = []

        for dom in user_goal['domain_ordering']:
            state = deepcopy(user_goal[dom])

            # intro -- `m` is (re)initialised for every domain so the taxi
            # commute branch below never sees an unbound or stale list.
            if dom == 'taxi' and len(state['info']) == 1:
                m = []
            else:
                m = [templates[dom]['intro']]

            # info
            info = 'info'
            if 'fail_info' in user_goal[dom]:
                info = 'fail_info'
            if dom == 'taxi' and len(state[info]) == 1:
                # Only a time is left: the taxi commutes between two places
                # visited earlier in the dialog.
                taxi_index = user_goal['domain_ordering'].index('taxi')
                places = [
                    d for d in user_goal['domain_ordering'][:taxi_index]
                    if d in ['attraction', 'hotel', 'restaurant']
                ]
                if len(places) >= 2:
                    random.shuffle(places)
                    m.append(templates['taxi']['commute'])
                    if 'arriveBy' in state[info]:
                        m.append(
                            'The taxi should arrive at the {} from the {} by {}.'
                            .format(self.boldify(places[0]),
                                    self.boldify(places[1]),
                                    self.boldify(state[info]['arriveBy'])))
                    elif 'leaveAt' in state[info]:
                        m.append(
                            'The taxi should leave from the {} to the {} after {}.'
                            .format(self.boldify(places[0]),
                                    self.boldify(places[1]),
                                    self.boldify(state[info]['leaveAt'])))
                    message.append(' '.join(m))
            else:
                # Emit the constraints in random groups of one to three
                # slots, one message per group.
                while len(state[info]) > 0:
                    num_acts = random.randint(1, min(len(state[info]), 3))
                    slots = random.sample(list(state[info].keys()), num_acts)
                    sents = [
                        fill_info_template(user_goal, dom, slot, info)
                        for slot in slots
                        if slot not in ['parking', 'internet']
                    ]
                    if 'parking' in slots:
                        sents.append(templates[dom]['parking ' +
                                                    state[info]['parking']])
                    if 'internet' in slots:
                        sents.append(templates[dom]['internet ' +
                                                    state[info]['internet']])
                    m.extend(sents)
                    message.append(' '.join(m))
                    m = []
                    for slot in slots:
                        del state[info][slot]

            # fail_info
            if 'fail_info' in user_goal[dom]:
                adjusted_slot = first_changed_slot(user_goal[dom]['info'],
                                                   user_goal[dom]['fail_info'])
                if adjusted_slot in ['internet', 'parking']:
                    message.append(
                        templates[dom]['fail_info ' + adjusted_slot + ' ' +
                                       user_goal[dom]['info'][adjusted_slot]])
                else:
                    message.append(
                        templates[dom]['fail_info ' + adjusted_slot].format(
                            self.boldify(
                                user_goal[dom]['info'][adjusted_slot])))

            # reqt
            if 'reqt' in state:
                slot_strings = []
                for slot in state['reqt']:
                    if slot in ['internet', 'parking', 'food']:
                        continue
                    slot_strings.append(
                        slot if slot not in request_slot_string_map else
                        request_slot_string_map[slot])
                if len(slot_strings) > 0:
                    message.append(templates[dom]['request'].format(
                        self.boldify(', '.join(slot_strings))))
                if 'internet' in state['reqt']:
                    message.append(
                        'Make sure to ask if the hotel includes free wifi.')
                if 'parking' in state['reqt']:
                    message.append(
                        'Make sure to ask if the hotel includes free parking.')
                if 'food' in state['reqt']:
                    message.append(
                        'Make sure to ask about what food it serves.')

            # book
            book = 'book'
            if 'fail_book' in user_goal[dom]:
                book = 'fail_book'
            if 'book' in state:
                slot_strings = []
                for slot in ['people', 'time', 'day', 'stay']:
                    if slot in state[book]:
                        if slot == 'people':
                            same_people_domain = get_same_people_domain(
                                user_goal, dom, slot)
                            if same_people_domain is None:
                                slot_strings.append('for {} people'.format(
                                    self.boldify(state[book][slot])))
                            else:
                                slot_strings.append(
                                    self.boldify(
                                        'for the same group of people as the {} booking'
                                        .format(same_people_domain)))
                        elif slot == 'time':
                            slot_strings.append('at {}'.format(
                                self.boldify(state[book][slot])))
                        elif slot == 'day':
                            same_people_domain = get_same_people_domain(
                                user_goal, dom, slot)
                            if same_people_domain is None:
                                slot_strings.append('on {}'.format(
                                    self.boldify(state[book][slot])))
                            else:
                                slot_strings.append(
                                    self.boldify(
                                        'on the same day as the {} booking'.
                                        format(same_people_domain)))
                        elif slot == 'stay':
                            slot_strings.append('for {} nights'.format(
                                self.boldify(state[book][slot])))
                        del state[book][slot]

                # Every booking slot must have been rendered above.
                assert len(state[book]) <= 0, state[book]

                if len(slot_strings) > 0:
                    message.append(templates[dom]['book'].format(
                        ' '.join(slot_strings)))

            # fail_book
            if 'fail_book' in user_goal[dom]:
                adjusted_slot = first_changed_slot(user_goal[dom]['book'],
                                                   user_goal[dom]['fail_book'])
                if adjusted_slot in ['internet', 'parking']:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot + ' ' +
                                       user_goal[dom]['book'][adjusted_slot]])
                else:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot].format(
                            self.boldify(
                                user_goal[dom]['book'][adjusted_slot])))

        # NOTE(review): this checks the `boldify` argument, not self.boldify,
        # which formats the slot values above -- confirm that is intended.
        if boldify == do_boldify:
            for i, m in enumerate(message):
                message[i] = message[i].replace('wifi', "<b>wifi</b>")
                message[i] = message[i].replace('internet', "<b>internet</b>")
                message[i] = message[i].replace('parking', "<b>parking</b>")
        return message