class MultiWozVector(Vector):
    """Vectorize MultiWOZ dialog states and (de)vectorize dialog acts.

    The state vector built by ``state_vectorize`` concatenates, in order:

    * a multi-hot of the opposite party's last delexicalized dialog acts,
    * a multi-hot of this party's last delexicalized dialog acts,
    * one bit per belief-state ``semi`` slot, set when the slot has a value,
    * one bit per DB domain, set when ``['book']['booked']`` is truthy,
    * a 6-way one-hot DB match-count bucket per DB domain,
    * one bit for ``state['terminated']``.
    """

    def __init__(self,
                 voc_file,
                 voc_opp_file,
                 character='sys',
                 intent_file=os.path.join(
                     os.path.dirname(
                         os.path.dirname(
                             os.path.dirname(
                                 os.path.dirname(os.path.abspath(__file__))))),
                     'data/multiwoz/trackable_intent.json')):
        """
        Args:
            voc_file: text file, one dialog act per line, listing the acts of
                ``character`` (the party whose own actions are vectorized).
            voc_opp_file: text file, one dialog act per line, listing the acts
                of the opposite party.
            character: ``'sys'`` to take the system's point of view; any other
                value takes the user's point of view.
            intent_file: JSON file with ``'informable'`` and ``'requestable'``
                entries.
        """
        self.belief_domains = [
            'Attraction', 'Restaurant', 'Train', 'Hotel', 'Taxi', 'Hospital',
            'Police'
        ]
        self.db_domains = ['Attraction', 'Restaurant', 'Train', 'Hotel']

        with open(intent_file) as f:
            intents = json.load(f)
        self.informable = intents['informable']
        self.requestable = intents['requestable']
        self.db = Database()

        with open(voc_file) as f:
            self.da_voc = f.read().splitlines()
        with open(voc_opp_file) as f:
            self.da_voc_opp = f.read().splitlines()
        self.character = character
        self.generate_dict()

    def generate_dict(self):
        """Build the act <-> index lookups and compute every vector dimension."""
        self.act2vec = {a: i for i, a in enumerate(self.da_voc)}
        self.vec2act = {v: k for k, v in self.act2vec.items()}
        self.da_dim = len(self.da_voc)
        self.opp2vec = {a: i for i, a in enumerate(self.da_voc_opp)}
        self.da_opp_dim = len(self.da_voc_opp)

        # One bit per 'semi' slot of every belief domain. default_state() is
        # built once here instead of once per domain.
        default_belief = default_state()['belief_state']
        self.belief_state_dim = sum(
            len(default_belief[domain.lower()]['semi'])
            for domain in self.belief_domains)

        # opp acts + own acts + semi slots + booked bits + 6 DB buckets per
        # DB domain + terminated bit.
        self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \
            len(self.db_domains) + 6 * len(self.db_domains) + 1

    @staticmethod
    def _db_constraint(belief_state, domain):
        """Translate ``domain``'s ``semi`` slots into DB (column, value) pairs via ``mapping``."""
        dom = domain.lower()
        return [(mapping[dom][slot], value)
                for slot, value in belief_state[dom]['semi'].items()
                if slot in mapping[dom]]

    def _multi_hot(self, dialog_act, index, dim):
        """Delexicalize and flatten ``dialog_act``; set 1 at ``index[da]`` for every known act."""
        vec = np.zeros(dim)
        for da in flat_da(delexicalize_da(dialog_act, self.requestable)):
            if da in index:
                vec[index[da]] = 1.
        return vec

    def pointer(self, turn):
        """Return the concatenated DB match-count one-hots for every DB domain.

        Args:
            turn (dict): belief state, indexed by lower-cased domain name.
        """
        pointer_vector = np.zeros(6 * len(self.db_domains))
        for domain in self.db_domains:
            entities = self.db.query(domain.lower(),
                                     self._db_constraint(turn, domain))
            pointer_vector = self.one_hot_vector(len(entities), domain,
                                                 pointer_vector)
        return pointer_vector

    def one_hot_vector(self, num, domain, vector):
        """Write a 6-way one-hot bucket of ``num`` (number of available entities)
        into ``domain``'s slice of ``vector`` and return ``vector``.

        Buckets are 0, 1, 2, 3, 4, >=5 for regular domains and
        0, <=2, <=5, <=10, <=40, >40 for the train branch.

        NOTE(review): ``self.db_domains`` holds capitalized names ('Train'), so
        for real callers ``domain != 'train'`` is always true and trains use the
        regular buckets; passing ``'train'`` raises ValueError from ``index``.
        The train buckets are therefore unreachable. This is left unchanged
        because changing the encoding would invalidate policies trained on
        these vectors. Confirm before fixing.
        """
        idx = self.db_domains.index(domain)
        if domain != 'train':
            if num >= 5:
                bucket = 5
            elif num in (0, 1, 2, 3, 4):
                bucket = int(num)
            else:
                return vector
        else:
            if num == 0:
                bucket = 0
            elif num <= 2:
                bucket = 1
            elif num <= 5:
                bucket = 2
            elif num <= 10:
                bucket = 3
            elif num <= 40:
                bucket = 4
            elif num > 40:
                bucket = 5
            else:
                return vector
        start = idx * 6
        vector[start:start + 6] = np.eye(6)[bucket]
        return vector

    def state_vectorize(self, state):
        """Vectorize a dialog state.

        Also stores ``state['belief_state']`` in ``self.state``. A later
        ``action_devectorize`` call lexicalizes against it.

        Args:
            state (dict): dialog state; reads ``'belief_state'``,
                ``'user_action'``, ``'system_action'`` and ``'terminated'``.

        Returns:
            state_vec (np.array): dialog state vector of length ``self.state_dim``.
        """
        self.state = state['belief_state']
        if self.character == 'sys':
            opp_key, own_key = 'user_action', 'system_action'
        else:
            opp_key, own_key = 'system_action', 'user_action'

        opp_act_vec = self._multi_hot(state[opp_key], self.opp2vec,
                                      self.da_opp_dim)
        last_act_vec = self._multi_hot(state[own_key], self.act2vec,
                                       self.da_dim)

        belief = state['belief_state']
        belief_state = np.zeros(self.belief_state_dim)
        i = 0
        for domain in self.belief_domains:
            for value in belief[domain.lower()]['semi'].values():
                if value:
                    belief_state[i] = 1.
                i += 1

        book = np.zeros(len(self.db_domains))
        for i, domain in enumerate(self.db_domains):
            if belief[domain.lower()]['book']['booked']:
                book[i] = 1.

        degree = self.pointer(belief)
        final = 1. if state['terminated'] else 0.

        state_vec = np.r_[opp_act_vec, last_act_vec, belief_state, book,
                          degree, final]
        assert len(state_vec) == self.state_dim
        return state_vec

    def action_devectorize(self, action_vec):
        """Recover a lexicalized dialog act from a 0/1 action vector.

        Entries equal to 1 select acts from the vocabulary. For each domain
        other than general/booking, DB entities are queried using the belief
        state stored by the last ``state_vectorize`` call.

        Args:
            action_vec (np.array): dialog act vector

        Returns:
            the dialog act produced by ``lexicalize_da``
        """
        act_array = [self.vec2act[i]
                     for i, flag in enumerate(action_vec) if flag == 1]
        action = deflat_da(act_array)
        entities = {}
        for domint in action:
            domain, _intent = domint.split('-')
            if domain not in entities and domain.lower() not in [
                    'general', 'booking'
            ]:
                entities[domain] = self.db.query(
                    domain.lower(), self._db_constraint(self.state, domain))
        return lexicalize_da(action, entities, self.state, self.requestable)

    def action_vectorize(self, action):
        """Return the multi-hot vector (length ``da_dim``) of a dialog act."""
        return self._multi_hot(action, self.act2vec, self.da_dim)
class GoalGenerator:
    """User goal generator.

    Samples MultiWOZ-style user goals from a goal model. The model consists of
    independent slot/value distributions, domain orderings, booking
    probabilities and slot-combination counts. It is loaded from
    ``goal_model_path`` when that file exists. Otherwise it is estimated from
    ``corpus_path`` and pickled to ``goal_model_path``.
    """

    def __init__(self,
                 goal_model_path=os.path.join(
                     get_root_path(), 'data/multiwoz/goal/new_goal_model.pkl'),
                 corpus_path=None,
                 boldify=False,
                 sample_info_from_trainset=True,
                 sample_reqt_from_trainset=False):
        """
        Args:
            goal_model_path: path to a goal model
            corpus_path: path to a dialog corpus to build a goal model
            boldify: highlight some information in the goal message
            sample_info_from_trainset: if True, sample info slots combination from train set, else sample each slot independently
            sample_reqt_from_trainset: if True, sample reqt slots combination from train set, else sample each slot independently
        """
        self.goal_model_path = goal_model_path
        self.corpus_path = corpus_path
        self.db = Database()
        self.boldify = do_boldify if boldify else null_boldify
        self.sample_info_from_trainset = sample_info_from_trainset
        self.sample_reqt_from_trainset = sample_reqt_from_trainset
        self.train_database = self.db.query('train', [])
        if os.path.exists(self.goal_model_path):
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # goal models from trusted locations.
            with open(self.goal_model_path, 'rb') as f:
                (self.ind_slot_dist, self.ind_slot_value_dist,
                 self.domain_ordering_dist, self.book_dist,
                 self.slots_num_dist,
                 self.slots_combination_dist) = pickle.load(f)
            print('Loading goal model is done')
        else:
            self._build_goal_model()
            print('Building goal model is done')

        # Remove slots that should never be requested. Use pop() rather than
        # del so a goal model lacking one of them does not crash the constructor.
        for domain, slot in (('police', 'postcode'), ('hospital', 'postcode'),
                             ('hospital', 'address')):
            self.ind_slot_dist[domain]['reqt'].pop(slot, None)
            self.ind_slot_value_dist[domain]['reqt'].pop(slot, None)

    def _build_goal_model(self):
        """Estimate the goal model from ``self.corpus_path`` and pickle it to ``self.goal_model_path``."""
        with open(self.corpus_path) as f:
            dialogs = json.load(f)

        # domain ordering
        def _get_dialog_domains(dialog):
            return list(
                filter(lambda x: x in domains and len(dialog['goal'][x]) > 0,
                       dialog['goal']))

        domain_orderings = []
        for d in dialogs:
            d_domains = _get_dialog_domains(dialogs[d])
            # (index of the first goal message mentioning the domain, domain).
            # Keeping the pair together means a never-mentioned domain is
            # simply left out. Zipping two separate lists would instead shift
            # every later domain onto the wrong index.
            first_mentions = []
            for domain in d_domains:
                message = dialogs[d]['goal']['message']
                if isinstance(message, str):
                    message = [message]
                for i, m in enumerate(message):
                    if domain_keywords[domain].lower() in m.lower(
                    ) or domain.lower() in m.lower():
                        first_mentions.append((i, domain))
                        break
            domain_orderings.append(
                tuple(dom for _, dom in sorted(first_mentions,
                                               key=lambda x: x[0])))
        domain_ordering_cnt = Counter(domain_orderings)
        total_orderings = sum(domain_ordering_cnt.values())
        self.domain_ordering_dist = deepcopy(domain_ordering_cnt)
        for order in domain_ordering_cnt.keys():
            self.domain_ordering_dist[order] = domain_ordering_cnt[
                order] / total_orderings

        # independent goal slot distribution
        ind_slot_value_cnt = dict([(domain, {}) for domain in domains])
        domain_cnt = Counter()
        book_cnt = Counter()
        self.slots_combination_dist = {domain: {} for domain in domains}
        self.slots_num_dist = {domain: {} for domain in domains}

        for d in dialogs:
            for domain in domains:
                goal = dialogs[d]['goal'][domain]
                if goal != {}:
                    domain_cnt[domain] += 1
                if 'info' in goal:
                    if 'info' not in self.slots_combination_dist[domain]:
                        self.slots_combination_dist[domain]['info'] = {}
                        self.slots_num_dist[domain]['info'] = {}
                    slots = sorted(list(goal['info'].keys()))
                    combos = self.slots_combination_dist[domain]['info']
                    combos[tuple(slots)] = combos.get(tuple(slots), 0) + 1
                    nums = self.slots_num_dist[domain]['info']
                    nums[len(slots)] = nums.get(len(slots), 0) + 1
                    for slot in goal['info']:
                        if 'invalid' in slot:
                            continue
                        if 'info' not in ind_slot_value_cnt[domain]:
                            ind_slot_value_cnt[domain]['info'] = {}
                        if slot not in ind_slot_value_cnt[domain]['info']:
                            ind_slot_value_cnt[domain]['info'][slot] = Counter()
                        if 'care' in goal['info'][slot]:
                            continue
                        ind_slot_value_cnt[domain]['info'][slot][
                            goal['info'][slot]] += 1
                if 'reqt' in goal:
                    if 'reqt' not in self.slots_combination_dist[domain]:
                        self.slots_combination_dist[domain]['reqt'] = {}
                        self.slots_num_dist[domain]['reqt'] = {}
                    slots = sorted(goal['reqt'])
                    if domain in ['police', 'hospital'
                                  ] and 'postcode' in slots:
                        slots.remove('postcode')
                    else:
                        assert len(slots) > 0, (sorted(goal['reqt']), [slots])
                    if len(slots) > 0:
                        combos = self.slots_combination_dist[domain]['reqt']
                        combos[tuple(slots)] = combos.get(tuple(slots), 0) + 1
                        nums = self.slots_num_dist[domain]['reqt']
                        nums[len(slots)] = nums.get(len(slots), 0) + 1
                    for slot in goal['reqt']:
                        if 'reqt' not in ind_slot_value_cnt[domain]:
                            ind_slot_value_cnt[domain]['reqt'] = Counter()
                        ind_slot_value_cnt[domain]['reqt'][slot] += 1
                if 'book' in goal:
                    book_cnt[domain] += 1
                    for slot in goal['book']:
                        if 'invalid' in slot:
                            continue
                        if 'book' not in ind_slot_value_cnt[domain]:
                            ind_slot_value_cnt[domain]['book'] = {}
                        if slot not in ind_slot_value_cnt[domain]['book']:
                            ind_slot_value_cnt[domain]['book'][slot] = Counter()
                        if 'care' in goal['book'][slot]:
                            continue
                        ind_slot_value_cnt[domain]['book'][slot][
                            goal['book'][slot]] += 1

        # Normalize counts: ind_slot_dist holds P(slot | domain present);
        # ind_slot_value_dist holds P(value | slot) for info/book and
        # P(slot | domain present) for reqt.
        self.ind_slot_value_dist = deepcopy(ind_slot_value_cnt)
        self.ind_slot_dist = dict([(domain, {}) for domain in domains])
        self.book_dist = {}
        for domain in domains:
            value_cnt = ind_slot_value_cnt[domain]
            for kind in ('info', 'reqt', 'book'):
                if kind not in value_cnt:
                    continue
                for slot in value_cnt[kind]:
                    slot_dist = self.ind_slot_dist[domain].setdefault(kind, {})
                    if kind == 'reqt':
                        prob = value_cnt['reqt'][slot] / domain_cnt[domain]
                        slot_dist[slot] = prob
                        self.ind_slot_value_dist[domain]['reqt'][slot] = prob
                        continue
                    slot_total = sum(value_cnt[kind][slot].values())
                    slot_dist[slot] = slot_total / domain_cnt[domain]
                    value_dist = self.ind_slot_value_dist[domain][kind][slot]
                    for val in value_dist:
                        value_dist[val] = value_cnt[kind][slot][val] / slot_total
            self.book_dist[domain] = book_cnt[domain] / len(dialogs)

        with open(self.goal_model_path, 'wb') as f:
            pickle.dump((self.ind_slot_dist, self.ind_slot_value_dist,
                         self.domain_ordering_dist, self.book_dist,
                         self.slots_num_dist, self.slots_combination_dist), f)

    def _get_domain_goal(self, domain):
        """Sample the goal of a single ``domain``.

        Candidates are resampled until one has a request or a booking. A
        candidate is also rejected when:

        * its info ends up empty (for domains that have info slots), or
        * its info has no DB match and none was found after adjusting one
          info slot within 100 trials.
        """
        cnt_slot = self.ind_slot_dist[domain]
        cnt_slot_value = self.ind_slot_value_dist[domain]
        pro_book = self.book_dist[domain]

        while True:
            domain_goal = {'info': {}}
            # inform
            if 'info' in cnt_slot:
                if self.sample_info_from_trainset:
                    slots = random.choices(
                        list(self.slots_combination_dist[domain]['info'].keys()),
                        list(self.slots_combination_dist[domain]['info'].values()))[0]
                    for slot in slots:
                        domain_goal['info'][slot] = nomial_sample(
                            cnt_slot_value['info'][slot])
                else:
                    for slot in cnt_slot['info']:
                        if random.random() < cnt_slot['info'][slot] + pro_correction['info']:
                            domain_goal['info'][slot] = nomial_sample(
                                cnt_slot_value['info'][slot])

                # Either search by name alone or drop the name.
                if domain in ['hotel', 'restaurant', 'attraction'
                              ] and 'name' in domain_goal['info'] and len(
                                  domain_goal['info']) > 1:
                    if random.random() < cnt_slot['info']['name']:
                        domain_goal['info'] = {
                            'name': domain_goal['info']['name']
                        }
                    else:
                        del domain_goal['info']['name']

                # Keep exactly one of arriveBy / leaveAt.
                if domain in ['taxi', 'train'] and 'arriveBy' in domain_goal[
                        'info'] and 'leaveAt' in domain_goal['info']:
                    if random.random() < (cnt_slot['info']['leaveAt'] /
                                          (cnt_slot['info']['arriveBy'] +
                                           cnt_slot['info']['leaveAt'])):
                        del domain_goal['info']['arriveBy']
                    else:
                        del domain_goal['info']['leaveAt']

                if domain in ['taxi', 'train'] and 'arriveBy' not in domain_goal['info'] and \
                        'leaveAt' not in domain_goal['info']:
                    if random.random() < (cnt_slot['info']['arriveBy'] /
                                          (cnt_slot['info']['arriveBy'] +
                                           cnt_slot['info']['leaveAt'])):
                        domain_goal['info']['arriveBy'] = nomial_sample(
                            cnt_slot_value['info']['arriveBy'])
                    else:
                        domain_goal['info']['leaveAt'] = nomial_sample(
                            cnt_slot_value['info']['leaveAt'])

                # Trains take a real (departure, destination) pair from the DB.
                if domain in ['train']:
                    random_train = random.choice(self.train_database)
                    domain_goal['info']['departure'] = random_train['departure']
                    domain_goal['info']['destination'] = random_train['destination']

                if domain in ['taxi'] and 'departure' not in domain_goal['info']:
                    domain_goal['info']['departure'] = nomial_sample(
                        cnt_slot_value['info']['departure'])

                if domain in ['taxi'] and 'destination' not in domain_goal['info']:
                    domain_goal['info']['destination'] = nomial_sample(
                        cnt_slot_value['info']['destination'])

                # Resample one endpoint when a taxi would go nowhere.
                if domain in ['taxi'] and \
                        'departure' in domain_goal['info'] and \
                        'destination' in domain_goal['info'] and \
                        domain_goal['info']['departure'] == domain_goal['info']['destination']:
                    if random.random() < (cnt_slot['info']['departure'] /
                                          (cnt_slot['info']['departure'] +
                                           cnt_slot['info']['destination'])):
                        domain_goal['info']['departure'] = nomial_sample(
                            cnt_slot_value['info']['departure'])
                    else:
                        domain_goal['info']['destination'] = nomial_sample(
                            cnt_slot_value['info']['destination'])
                if domain_goal['info'] == {}:
                    continue

            # request
            if 'reqt' in cnt_slot:
                if self.sample_reqt_from_trainset:
                    # Only combinations that do not overlap the info slots.
                    not_in_info_slots = {}
                    for slots in self.slots_combination_dist[domain]['reqt']:
                        for slot in slots:
                            if slot in domain_goal['info']:
                                break
                        else:
                            not_in_info_slots[
                                slots] = self.slots_combination_dist[domain][
                                    'reqt'][slots]
                    reqt = list(
                        random.choices(list(not_in_info_slots.keys()),
                                       list(not_in_info_slots.values()))[0])
                else:
                    reqt = [
                        slot for slot in cnt_slot['reqt']
                        if random.random() < cnt_slot['reqt'][slot] +
                        pro_correction['reqt']
                        and slot not in domain_goal['info']
                    ]
                if len(reqt) > 0:
                    domain_goal['reqt'] = reqt

            # book
            if 'book' in cnt_slot and random.random(
            ) < pro_book + pro_correction['book']:
                if 'book' not in domain_goal:
                    domain_goal['book'] = {}

                for slot in cnt_slot['book']:
                    if random.random(
                    ) < cnt_slot['book'][slot] + pro_correction['book']:
                        domain_goal['book'][slot] = nomial_sample(
                            cnt_slot_value['book'][slot])

                # makes sure that there are all necessary slots for booking
                if domain == 'restaurant' and 'time' not in domain_goal['book']:
                    domain_goal['book']['time'] = nomial_sample(
                        cnt_slot_value['book']['time'])

                if domain == 'hotel' and 'stay' not in domain_goal['book']:
                    domain_goal['book']['stay'] = nomial_sample(
                        cnt_slot_value['book']['stay'])

                if domain in ['hotel', 'restaurant'
                              ] and 'day' not in domain_goal['book']:
                    domain_goal['book']['day'] = nomial_sample(
                        cnt_slot_value['book']['day'])

                if domain in ['hotel', 'restaurant'
                              ] and 'people' not in domain_goal['book']:
                    domain_goal['book']['people'] = nomial_sample(
                        cnt_slot_value['book']['people'])

                if domain == 'train' and len(domain_goal['book']) <= 0:
                    domain_goal['book']['people'] = nomial_sample(
                        cnt_slot_value['book']['people'])

            # fail_book goals are no longer generated (since 2020.8.18).

            # fail_info: if the info matches nothing, adjust one slot until it does.
            if 'info' in domain_goal and len(
                    self.db.query(domain, domain_goal['info'].items())) == 0:
                num_trial = 0
                while num_trial < 100:
                    adjusted_info = self._adjust_info(domain,
                                                      domain_goal['info'])
                    if len(self.db.query(domain, adjusted_info.items())) > 0:
                        if domain == 'train':
                            domain_goal['info'] = adjusted_info
                        else:
                            # first ask fail_info which return no result then ask info
                            domain_goal['fail_info'] = domain_goal['info']
                            domain_goal['info'] = adjusted_info

                        break
                    num_trial += 1

                if num_trial >= 100:
                    continue

            # at least there is one request and book
            if 'reqt' in domain_goal or 'book' in domain_goal:
                break

        return domain_goal

    def get_user_goal(self):
        """Sample a complete multi-domain user goal.

        Returns:
            dict mapping each sampled domain to its goal, plus a
            ``'domain_ordering'`` entry with the tuple of domains.
        """
        domain_ordering = ()
        while len(domain_ordering) <= 0:
            domain_ordering = nomial_sample(self.domain_ordering_dist)

        user_goal = {
            dom: self._get_domain_goal(dom)
            for dom in domain_ordering
        }
        assert len(user_goal.keys()) > 0

        # using taxi to communte between places, removing destination and departure.
        if 'taxi' in domain_ordering:
            places = [
                dom for dom in domain_ordering[:domain_ordering.index('taxi')]
                if dom in
                ['attraction', 'hotel', 'restaurant', 'police', 'hospital']
            ]
            if len(places) >= 1:
                del user_goal['taxi']['info']['destination']
                if places[-1] == 'restaurant' and 'book' in user_goal[
                        'restaurant']:
                    # Arrive at the restaurant in time for the booking.
                    user_goal['taxi']['info']['arriveBy'] = user_goal[
                        'restaurant']['book']['time']
                    if 'leaveAt' in user_goal['taxi']['info']:
                        del user_goal['taxi']['info']['leaveAt']
            if len(places) >= 2:
                del user_goal['taxi']['info']['departure']

        # match area of attraction and restaurant
        if 'restaurant' in domain_ordering and \
                'attraction' in domain_ordering and \
                'fail_info' not in user_goal['restaurant'] and \
                domain_ordering.index('restaurant') > domain_ordering.index('attraction') and \
                'area' in user_goal['restaurant']['info'] and 'area' in user_goal['attraction']['info']:
            adjusted_restaurant_goal = deepcopy(
                user_goal['restaurant']['info'])
            adjusted_restaurant_goal['area'] = user_goal['attraction']['info'][
                'area']
            if len(
                    self.db.query('restaurant', adjusted_restaurant_goal.items(
                    ))) > 0 and random.random() < 0.5:
                user_goal['restaurant']['info']['area'] = user_goal[
                    'attraction']['info']['area']

        # match day and people of restaurant and hotel
        if 'restaurant' in domain_ordering and 'hotel' in domain_ordering and \
                'book' in user_goal['restaurant'] and 'book' in user_goal['hotel']:
            if random.random() < 0.5:
                user_goal['restaurant']['book']['people'] = user_goal['hotel'][
                    'book']['people']
                if 'fail_book' in user_goal['restaurant']:
                    user_goal['restaurant']['fail_book']['people'] = user_goal[
                        'hotel']['book']['people']
            if random.random() < 1.0:
                user_goal['restaurant']['book']['day'] = user_goal['hotel'][
                    'book']['day']
                if 'fail_book' in user_goal['restaurant']:
                    user_goal['restaurant']['fail_book']['day'] = user_goal[
                        'hotel']['book']['day']
                    if user_goal['restaurant']['book']['day'] == user_goal['restaurant']['fail_book']['day'] and \
                            user_goal['restaurant']['book']['time'] == user_goal['restaurant']['fail_book']['time'] and \
                            user_goal['restaurant']['book']['people'] == user_goal['restaurant']['fail_book']['people']:
                        del user_goal['restaurant']['fail_book']

        # match day and people of hotel and train
        if 'hotel' in domain_ordering and 'train' in domain_ordering and \
                'book' in user_goal['hotel'] and 'info' in user_goal['train']:
            if user_goal['train']['info']['destination'] == 'cambridge' and \
                    'day' in user_goal['hotel']['book']:
                user_goal['train']['info']['day'] = user_goal['hotel']['book'][
                    'day']
            elif user_goal['train']['info']['departure'] == 'cambridge' and \
                    'day' in user_goal['hotel']['book'] and 'stay' in user_goal['hotel']['book']:
                user_goal['train']['info']['day'] = days[
                    (days.index(user_goal['hotel']['book']['day']) +
                     int(user_goal['hotel']['book']['stay'])) % 7]
            # In case, we have no query results with adjusted train goal, we simply drop the train goal.
            if len(self.db.query('train',
                                 user_goal['train']['info'].items())) == 0:
                del user_goal['train']
                # list.remove() returns None, so filter instead of
                # tuple(list(...).remove('train')), which raised TypeError.
                domain_ordering = tuple(dom for dom in domain_ordering
                                        if dom != 'train')

        for domain in user_goal:
            if not user_goal[domain]['info']:
                user_goal[domain]['info'] = {'none': 'none'}

        user_goal['domain_ordering'] = domain_ordering

        return user_goal

    def _adjust_info(self, domain, info):
        """Return a copy of ``info`` with one random slot resampled from its known values."""
        adjusted_info = deepcopy(info)
        slot = random.choice(list(info.keys()))
        adjusted_info[slot] = random.choice(
            list(self.ind_slot_value_dist[domain]['info'][slot].keys()))
        return adjusted_info

    def build_message(self, user_goal, boldify=null_boldify):
        """Render ``user_goal`` as natural-language instructions.

        Returns:
            (message, message_by_domain): the list of instruction sentences
            and, for each domain in ``user_goal['domain_ordering']``, the
            space-joined sentences produced for that domain.
        """
        message = []
        message_by_domain = []
        mess_ptr4domain = 0

        def fill_info_template(user_goal, domain, slot, info):
            # 'area' becomes "same area as the X" when restaurant and
            # attraction share an area and this domain comes second.
            if slot != 'area' or not (
                    'restaurant' in user_goal and 'attraction' in user_goal
                    and info in user_goal['restaurant'].keys()
                    and info in user_goal['attraction'].keys()
                    and 'area' in user_goal['restaurant'][info]
                    and 'area' in user_goal['attraction'][info]
                    and user_goal['restaurant'][info]['area'] ==
                    user_goal['attraction'][info]['area']):
                return templates[domain][slot].format(
                    self.boldify(user_goal[domain][info][slot]))
            else:
                restaurant_index = user_goal['domain_ordering'].index(
                    'restaurant')
                attraction_index = user_goal['domain_ordering'].index(
                    'attraction')
                if restaurant_index > attraction_index and domain == 'restaurant':
                    return templates[domain][slot].format(
                        self.boldify('same area as the attraction'))
                elif attraction_index > restaurant_index and domain == 'attraction':
                    return templates[domain][slot].format(
                        self.boldify('same area as the restaurant'))
            return templates[domain][slot].format(
                self.boldify(user_goal[domain][info][slot]))

        def get_same_people_domain(user_goal, domain, slot):
            # Earlier booked domain whose `slot` value equals this domain's, if any.
            if slot not in ['day', 'people']:
                return None
            domain_index = user_goal['domain_ordering'].index(domain)
            previous_domains = user_goal['domain_ordering'][:domain_index]
            for prev in previous_domains:
                if prev in ['restaurant', 'hotel', 'train'] and 'book' in user_goal[prev] and \
                        slot in user_goal[prev]['book'] and user_goal[prev]['book'][slot] == \
                        user_goal[domain]['book'][slot]:
                    return prev
            return None

        for dom in user_goal['domain_ordering']:
            state = deepcopy(user_goal[dom])

            if not (dom == 'taxi' and len(state['info']) == 1):
                # intro
                m = [templates[dom]['intro']]
            else:
                # Commute-only taxi: no intro. Start from a fresh buffer instead
                # of whatever the previous domain left in `m`.
                m = []

            # info
            info = 'info'
            if 'fail_info' in user_goal[dom]:
                info = 'fail_info'
            if dom == 'taxi' and len(state[info]) == 1:
                taxi_index = user_goal['domain_ordering'].index('taxi')
                places = [
                    place for place in user_goal['domain_ordering'][:taxi_index]
                    if place in ['attraction', 'hotel', 'restaurant']
                ]
                if len(places) >= 2:
                    random.shuffle(places)
                    m.append(templates['taxi']['commute'])
                    if 'arriveBy' in state[info]:
                        m.append(
                            'The taxi should arrive at the {} from the {} by {}.'
                            .format(self.boldify(places[0]),
                                    self.boldify(places[1]),
                                    self.boldify(state[info]['arriveBy'])))
                    elif 'leaveAt' in state[info]:
                        m.append(
                            'The taxi should leave from the {} to the {} after {}.'
                            .format(self.boldify(places[0]),
                                    self.boldify(places[1]),
                                    self.boldify(state[info]['leaveAt'])))
                    message.append(' '.join(m))
            else:
                # Emit the info slots in random groups of 1-3 sentences.
                while len(state[info]) > 0:
                    num_acts = random.randint(1, min(len(state[info]), 3))
                    slots = random.sample(list(state[info].keys()), num_acts)
                    sents = [
                        fill_info_template(user_goal, dom, slot, info)
                        for slot in slots
                        if slot not in ['parking', 'internet']
                    ]
                    if 'parking' in slots:
                        sents.append(templates[dom]['parking ' +
                                                    state[info]['parking']])
                    if 'internet' in slots:
                        sents.append(templates[dom]['internet ' +
                                                    state[info]['internet']])
                    m.extend(sents)
                    message.append(' '.join(m))
                    m = []
                    for slot in slots:
                        del state[info][slot]

            # fail_info: the slot whose value differs between info and fail_info
            if 'fail_info' in user_goal[dom]:
                adjusted_slot = list(
                    filter(
                        lambda x: x[0][1] != x[1][1],
                        zip(user_goal[dom]['info'].items(),
                            user_goal[dom]['fail_info'].items())))[0][0][0]
                if adjusted_slot in ['internet', 'parking']:
                    message.append(
                        templates[dom]['fail_info ' + adjusted_slot + ' ' +
                                       user_goal[dom]['info'][adjusted_slot]])
                else:
                    message.append(
                        templates[dom]['fail_info ' + adjusted_slot].format(
                            self.boldify(
                                user_goal[dom]['info'][adjusted_slot])))

            # reqt
            if 'reqt' in state:
                slot_strings = []
                for slot in state['reqt']:
                    if slot in ['internet', 'parking', 'food']:
                        continue
                    slot_strings.append(slot if slot not in
                                        request_slot_string_map else
                                        request_slot_string_map[slot])
                if len(slot_strings) > 0:
                    message.append(templates[dom]['request'].format(
                        self.boldify(', '.join(slot_strings))))
                if 'internet' in state['reqt']:
                    message.append(
                        'Make sure to ask if the hotel includes free wifi.')
                if 'parking' in state['reqt']:
                    message.append(
                        'Make sure to ask if the hotel includes free parking.')
                if 'food' in state['reqt']:
                    message.append(
                        'Make sure to ask about what food it serves.')

            # book
            book = 'book'
            if 'fail_book' in user_goal[dom]:
                book = 'fail_book'
            if 'book' in state:
                slot_strings = []
                for slot in ['people', 'time', 'day', 'stay']:
                    if slot in state[book]:
                        if slot == 'people':
                            same_people_domain = get_same_people_domain(
                                user_goal, dom, slot)
                            if same_people_domain is None:
                                slot_strings.append('for {} people'.format(
                                    self.boldify(state[book][slot])))
                            else:
                                slot_strings.append(
                                    self.boldify(
                                        'for the same group of people as the {} booking'
                                        .format(same_people_domain)))
                        elif slot == 'time':
                            slot_strings.append('at {}'.format(
                                self.boldify(state[book][slot])))
                        elif slot == 'day':
                            same_people_domain = get_same_people_domain(
                                user_goal, dom, slot)
                            if same_people_domain is None:
                                slot_strings.append('on {}'.format(
                                    self.boldify(state[book][slot])))
                            else:
                                slot_strings.append(
                                    self.boldify(
                                        'on the same day as the {} booking'.
                                        format(same_people_domain)))
                        elif slot == 'stay':
                            slot_strings.append('for {} nights'.format(
                                self.boldify(state[book][slot])))
                        del state[book][slot]

                assert len(state[book]) <= 0, state[book]

                if len(slot_strings) > 0:
                    message.append(templates[dom]['book'].format(
                        ' '.join(slot_strings)))

            # fail_book
            if 'fail_book' in user_goal[dom]:
                adjusted_slot = list(
                    filter(
                        lambda x: x[0][1] != x[1][1],
                        zip(user_goal[dom]['book'].items(),
                            user_goal[dom]['fail_book'].items())))[0][0][0]

                if adjusted_slot in ['internet', 'parking']:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot + ' ' +
                                       user_goal[dom]['book'][adjusted_slot]])
                else:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot].format(
                            self.boldify(
                                user_goal[dom]['book'][adjusted_slot])))

            # Sentences added for this domain.
            dm = message[mess_ptr4domain:]
            mess_ptr4domain = len(message)
            message_by_domain.append(' '.join(dm))

        if boldify == do_boldify:
            for i, m in enumerate(message):
                message[i] = message[i].replace('wifi', "<b>wifi</b>")
                message[i] = message[i].replace('internet', "<b>internet</b>")
                message[i] = message[i].replace('parking', "<b>parking</b>")

        return message, message_by_domain
class HDSA_predictor():
    """BERT sequence classifier that predicts 44 hierarchical dialog-act labels for HDSA."""

    def __init__(self, archive_file, model_file=None, use_cuda=False):
        """
        Args:
            archive_file: local path to the zipped model archive.
            model_file: location passed to ``cached_path`` when
                ``archive_file`` is not an existing file.
            use_cuda: run the model on ``'cuda'`` instead of ``'cpu'``.

        Raises:
            Exception: if ``archive_file`` is missing and no ``model_file``
                is given.
        """
        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for DA-predictor is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(os.path.join(model_dir, 'checkpoints')):
            with zipfile.ZipFile(archive_file, 'r') as archive:
                archive.extractall(model_dir)

        load_dir = os.path.join(model_dir,
                                "checkpoints/predictor/save_step_23926")
        self.db = Database()
        if not os.path.exists(load_dir):
            with zipfile.ZipFile('{}.zip'.format(load_dir), 'r') as archive:
                archive.extractall(os.path.dirname(load_dir))

        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                       do_lower_case=True)
        self.max_seq_length = 256
        # Last domain seen in a user action; persists across calls.
        self.domain = 'restaurant'
        self.model = BertForSequenceClassification.from_pretrained(
            load_dir,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(-1)),
            num_labels=44)
        self.device = 'cuda' if use_cuda else 'cpu'
        self.model.to(self.device)
        # Inference only: disable dropout so predictions are deterministic.
        self.model.eval()

    def gen_example(self, state):
        """Build the model input example and the KB summary for ``state``.

        ``src`` is a textual rendering of the first DB match for the current
        domain's constraints. ``usr`` and ``sys`` are the last user utterance
        and the preceding system utterance, if any.

        Returns:
            (example, kb): the ``InputExample`` and a dict with the first DB
            match (or ``{'count': '0'}``), plus ``'count'`` and ``'domain'``.
        """
        file_name = ''
        turn = 0
        guid = 'infer'

        act = state['user_action']
        for w in act:
            d = w[1]
            if Constants.domains.index(d.lower()) < 8:
                self.domain = d.lower()
        hierarchical_act_vecs = [0 for _ in range(44)]  # fake target

        meta = state['belief_state']
        constraints = []
        if self.domain != 'bus':
            for slot in meta[self.domain]['semi']:
                if meta[self.domain]['semi'][slot] != "":
                    constraints.append([slot, meta[self.domain]['semi'][slot]])
        query_result = self.db.query(self.domain, constraints)
        if not query_result:
            kb = {'count': '0'}
            src = "no information"
        else:
            # Copy so the 'count'/'domain' keys added below do not leak into
            # the record returned by the database.
            kb = dict(query_result[0])
            kb['count'] = str(len(query_result))
            src = []
            for k, v in kb.items():
                k = examine(self.domain, k.lower())
                if k != 'illegal' and isinstance(v, str):
                    src.extend([k, 'is', v])
            src = " ".join(src)

        usr = state['history'][-1][-1]
        sys = state['history'][-2][-1] if len(state['history']) > 1 else None

        example = InputExample(file_name, turn, guid, src, usr, sys,
                               hierarchical_act_vecs)
        kb['domain'] = self.domain
        return example, kb

    def gen_feature(self, example):
        """Tokenize ``example`` into fixed-length (``max_seq_length``) BERT features.

        Layout: [CLS] text_a [SEP] text_b [SEP] text_m [SEP], zero-padded.
        text_m only gets the space left after the text_a/text_b pair.
        """
        tokens_a = self.tokenizer.tokenize(example.text_a)
        tokens_b = self.tokenizer.tokenize(example.text_b)
        tokens_m = self.tokenizer.tokenize(example.text_m)
        # Modifies `tokens_a` and `tokens_b` in place so that the total
        # length is less than the specified length.
        # Account for [CLS], [SEP], [SEP] with "- 3"
        truncate_seq_pair(tokens_a, tokens_b, self.max_seq_length - 3)

        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        segment_ids = [0] * (len(tokens_a) + 2)

        assert len(tokens) == len(segment_ids)

        tokens += tokens_b + ["[SEP]"]
        segment_ids += [1] * (len(tokens_b) + 1)

        if len(tokens) < self.max_seq_length:
            if len(tokens_m) > self.max_seq_length - len(tokens) - 1:
                tokens_m = tokens_m[:self.max_seq_length - len(tokens) - 1]

            tokens += tokens_m + ['[SEP]']
            segment_ids += [0] * (len(tokens_m) + 1)

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (self.max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == self.max_seq_length
        assert len(input_mask) == self.max_seq_length
        assert len(segment_ids) == self.max_seq_length

        feature = InputFeatures(file=example.file,
                                turn=example.turn,
                                input_ids=input_ids,
                                input_mask=input_mask,
                                segment_ids=segment_ids,
                                label_id=example.label)
        return feature

    def predict(self, state):
        """Predict the dialog-act labels for ``state``.

        Returns:
            (preds, kb): a float 0/1 tensor with 1 wherever
            sigmoid(logit) > 0.4, and the KB summary from ``gen_example``.
        """
        example, kb = self.gen_example(state)
        feature = self.gen_feature(example)

        input_ids = torch.tensor([feature.input_ids],
                                 dtype=torch.long).to(self.device)
        input_masks = torch.tensor([feature.input_mask],
                                   dtype=torch.long).to(self.device)
        segment_ids = torch.tensor([feature.segment_ids],
                                   dtype=torch.long).to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids,
                                segment_ids,
                                input_masks,
                                labels=None)
            logits = torch.sigmoid(logits)
        preds = (logits > 0.4).float()
        return preds, kb
class RuleBasedMultiwozBot(Policy):
    """Rule-based system policy implemented for the MultiWOZ dataset.

    Every turn the user's dialog acts are grouped by ``"<Domain>-<Intent>"``,
    the belief state is turned into database constraints and a system dialog
    act is assembled from hand-written rules (inform, recommend, no-offer,
    request, book). Many choices are randomised with the ``random`` module.
    """

    # -1 while inactive; set to 0 when an entity is recommended, bumped by one
    # at the start of every ``predict`` call and reset to -1 once it exceeds 1.
    recommend_flag = -1
    # Reset to "" together with ``recommend_flag``; not otherwise read here.
    choice = ""
    # Probability of recommending one entity when several match the user's
    # constraints. ``random.random()`` lies in [0.0, 1.0), so the default 1.0
    # always recommends (this replaces a hard-coded ``if True:``). Lower it
    # (e.g. to 0.3) to enable the "select" branch (p < 0.5) and the
    # "request more constraints" branch (p >= 0.5).
    recommend_prob = 1.0

    def __init__(self):
        super().__init__()
        self.last_state = {}
        # Per-turn cache of database results keyed by domain; rebuilt in
        # ``predict`` and read by ``_judge_booking``.
        self.kb_result = {}
        self.db = Database()

    def init_session(self):
        """Reset the per-dialog state before a new session."""
        self.last_state = {}

    def predict(self, state):
        """Produce the system dialog act for the current turn.

        Args:
            state (dict): dialog state, see util/state.py. ``'user_action'``
                (a list of ``[intent, domain, slot, value]``) and
                ``'belief_state'`` are read; ``'system_action'`` is
                overwritten with the result.

        Returns:
            list: system dialog acts as ``[intent, domain, slot, value]``
            lists. ``general-greet`` is returned when no rule fired.
        """
        if self.recommend_flag != -1:
            self.recommend_flag += 1

        self.kb_result = {}
        DA = {}

        if state.get('user_action'):
            user_action = {}
            for intent, domain, slot, value in state['user_action']:
                user_action.setdefault('-'.join((domain, intent)), []).append([slot, value])
        else:
            user_action = check_diff(self.last_state, state)

        self.last_state = state

        for user_act in user_action:
            domain, _ = user_act.split('-')

            if domain == 'general':
                # Respond to general greetings
                self._update_greeting(user_act, state, DA)
            elif domain == 'Taxi':
                # Book taxi for user
                self._book_taxi(user_act, state, DA)
            elif domain == 'Booking':
                self._update_booking(user_act, state, DA)
            elif domain != "Train":
                # User's talking about other domain
                self._update_DA(user_act, user_action, state, DA)
            else:
                # Info about train
                self._update_train(user_act, user_action, state, DA)

            # Judge if user want to book
            self._judge_booking(user_act, user_action, DA)

            if 'Booking-Book' in DA:
                if random.random() < 0.5:
                    DA['general-reqmore'] = []
                # Once a booking is made, reply only with the booking and the
                # optional reqmore. The reqmore used to be deleted right after
                # being added, which made the random draw above pointless.
                for key in [k for k in DA if k not in ('Booking-Book', 'general-reqmore')]:
                    del DA[key]

        if not DA:
            DA = {'general-greet': [['none', 'none']]}

        tuples = []
        for domain_intent, svs in DA.items():
            domain, intent = domain_intent.split('-')
            if not svs:
                tuples.append([intent, domain, 'none', 'none'])
            else:
                for slot, value in svs:
                    tuples.append([intent, domain, slot, value])
        state['system_action'] = tuples
        return tuples

    def _update_recommend_flag(self):
        """Expire an active recommendation once the flag exceeds 1."""
        if self.recommend_flag > 1:
            self.recommend_flag = -1
            self.choice = ""
        # NOTE: the original also had ``elif self.recommend_flag == 1:
        # self.recommend_flag == 0``, a comparison used as a statement (a
        # no-op). It is dropped. Turning it into an assignment would keep the
        # flag from ever expiring, because this helper runs via
        # ``_judge_booking`` for every user act.

    def _update_greeting(self, user_act, state, DA):
        """Answer ``general-*`` user acts (bye / thank)."""
        _, intent_type = user_act.split('-')

        if intent_type == 'bye':
            # Respond to goodbye
            DA.setdefault('general-bye', [])
            if random.random() < 0.3:
                DA.setdefault('general-welcome', [])
        elif intent_type == 'thank':
            DA['general-welcome'] = []

    def _book_taxi(self, user_act, state, DA):
        """Book a taxi, or request the information still missing for it.

        Departure and destination are both required, plus at least one of
        leaveAt / arriveBy. When everything is known, a generated car type
        and phone number are informed.
        """
        semi = state['belief_state']['taxi']['semi']
        # Bug fix: the original compared the whole ``semi`` dict to "", so
        # departure/destination were never reported as missing.
        blank_info = [
            REF_USR_DA['Taxi'].get(info, info)
            for info in ['departure', 'destination']
            if semi[info] == ""
        ]
        if semi['leaveAt'] == "" and semi['arriveBy'] == "":
            blank_info += ['Leave', 'Arrive']

        if not blank_info:
            # Finish booking, tell user car type and phone number
            inform = DA.setdefault('Taxi-Inform', [])
            inform.append(['Car', generate_car()])
            inform.append(['Phone', generate_phone_num(11)])
            return

        # Need essential info to finish booking: ask for the first 1..N items.
        request_num = random.randint(0, 999999) % len(blank_info) + 1
        req_list = DA.setdefault('Taxi-Request', [])
        for info in blank_info[:request_num]:
            req_list.append([REF_USR_DA.get(info, info), '?'])

    def _update_booking(self, user_act, state, DA):
        """No-op: ``Booking-*`` user acts get no dedicated reply here
        (``_judge_booking`` still runs for them in ``predict``)."""
        pass

    def _update_DA(self, user_act, user_action, state, DA):
        """Answer a user act about any domain other than Taxi or Train.

        Queries the database with every non-empty belief-state slot of the
        domain and caches the result in ``self.kb_result[domain]``. A
        ``Request`` is answered directly. Otherwise the reply depends on the
        number of matches: none, exactly one, or several.
        """
        domain, intent_type = user_act.split('-')
        semi = state['belief_state'][domain.lower()]['semi']

        constraints = [[slot, value] for slot, value in semi.items() if value != ""]
        kb_result = self.db.query(domain.lower(), constraints)
        self.kb_result[domain] = deepcopy(kb_result)

        if intent_type == 'Request':
            # Respond to user's request
            self._update_recommend_flag()
            self._answer_request(domain, user_action[user_act], kb_result, DA)
        elif not kb_result:
            # There's no result matching user's constraint
            self._inform_no_offer(domain, semi, DA)
        elif len(kb_result) == 1:
            # There's exactly one result matching user's constraint
            self._inform_single_result(domain, semi, kb_result[0], DA)
        else:
            # There are multiple results matching user's constraint
            self._handle_multiple_results(domain, semi, kb_result, DA)

    def _answer_request(self, domain, slots, kb_result, DA):
        """Inform each requested slot from the first matching entity ("unknown" if absent)."""
        inform = DA.setdefault(domain + "-Inform", [])
        if not kb_result:
            return
        for slot in slots:
            kb_slot_name = REF_SYS_DA[domain].get(slot[0], slot[0])
            inform.append([slot[0], kb_result[0].get(kb_slot_name, "unknown")])

    def _inform_no_offer(self, domain, semi, DA):
        """Report the unsatisfiable constraints and sometimes ask to change them."""
        no_offer = DA.setdefault(domain + "-NoOffer", [])
        for slot, value in semi.items():
            if value != "" and value != "do n't care":
                no_offer.append([REF_USR_DA[domain].get(slot, slot), value])

        p = random.random()
        # Ask user if he wants to change constraint. Guarded on a non-empty
        # list: the modulo below raised ZeroDivisionError when nothing was set.
        if p < 0.3 and no_offer:
            req_num = min(random.randint(0, 999999) % len(no_offer) + 1, 3)
            req_list = DA.setdefault(domain + "-Request", [])
            for slot_name, _ in no_offer[:req_num]:
                req_list.append([REF_USR_DA[domain].get(slot_name, slot_name), "?"])

    def _inform_single_result(self, domain, semi, entity, DA):
        """Inform a random, non-empty subset of the domain's slots for the only match."""
        inform = DA.setdefault(domain + "-Inform", [])
        props = list(semi)
        if not props:
            return
        info_num = random.randint(0, 999999) % len(props) + 1
        random.shuffle(props)
        for prop in props[:info_num]:
            # "unknown" mirrors _answer_request when the entity lacks the slot.
            inform.append([REF_USR_DA[domain].get(prop, prop), entity.get(prop, "unknown")])

    def _handle_multiple_results(self, domain, semi, kb_result, DA):
        """Recommend a match, ask the user to select, or ask for more constraints."""
        p = random.random()
        if p < self.recommend_prob:
            self._recommend(domain, kb_result, DA)
        elif p < 0.5:
            self._ask_select(domain, kb_result, DA)
        else:
            self._ask_more_constraints(domain, semi, DA)

    def _recommend(self, domain, kb_result, DA):
        """Inform the number of matches and recommend one random entity."""
        inform = DA.setdefault(domain + "-Inform", [])
        recommend = DA.setdefault(domain + "-Recommend", [])
        inform.append(["Choice", str(len(kb_result))])
        choice = kb_result[random.randint(0, 999999) % len(kb_result)]
        if domain in ["Hotel", "Attraction", "Police", "Restaurant"]:
            recommend.append(['Name', choice['name']])
        self.recommend_flag = 0
        self.candidate = choice

        # Additionally recommend up to two random informable properties.
        props = [[prop, choice[prop]] for prop in choice]
        prop_num = min(random.randint(0, 999999) % 3, len(props))
        random.shuffle(props)
        for slot, value in props[:prop_num]:
            string = REF_USR_DA[domain].get(slot, slot)
            if string in INFORMABLE_SLOTS:
                recommend.append([string, str(value)])

    def _ask_select(self, domain, kb_result, DA):
        """Ask the user to choose between up to 5 values of a distinguishing slot."""
        props = []
        for prop in kb_result[0]:
            values = []
            for candidate in kb_result:
                if prop in candidate and candidate[prop] not in values:
                    values.append(candidate[prop])
            if len(values) > 1:
                props.append([prop, values])
        random.shuffle(props)
        props = [item for item in props if item[0] in SELECTABLE_SLOTS[domain]]
        if not props:
            # The original indexed props[0] here and raised IndexError.
            return

        prop, values = props[0]
        select = DA.setdefault(domain + "-Select", [])
        slot_name = REF_USR_DA[domain].get(prop, prop)
        for value in values[:5]:
            select.append([slot_name, value])

    def _ask_more_constraints(self, domain, semi, DA):
        """Request one or two still-empty, requestable slots."""
        reqs = [[REF_USR_DA[domain].get(prop, prop), "?"]
                for prop, value in semi.items() if value == ""]
        reqs = [req for req in reqs if req[0] in REQUESTABLE_SLOTS]
        random.shuffle(reqs)
        if not reqs:
            return
        req_num = min(random.randint(0, 999999) % len(reqs) + 1, 2)
        req_list = DA.setdefault(domain + "-Request", [])
        for slot_name, mark in reqs[:req_num]:
            req_list.append([REF_USR_DA[domain].get(slot_name, slot_name), mark])

    def _update_train(self, user_act, user_action, state, DA):
        """Answer a user act about the Train domain.

        Missing times/day/destination/departure are requested. A
        ``Train-Request`` is answered from the first matching train.
        Otherwise the reply is NoOffer when nothing matches, or OfferBook
        once at least four constraints are known.
        """
        semi = state['belief_state']['train']['semi']
        constraints = [[slot, semi[slot]] for slot in ['leaveAt', 'arriveBy'] if semi[slot] != ""]

        if not constraints:
            # Neither time is known: ask for one of them or both.
            p = random.random()
            req_list = DA.setdefault('Train-Request', [])
            if p < 0.33:
                req_list.append(['Leave', '?'])
            elif p < 0.66:
                req_list.append(['Arrive', '?'])
            else:
                req_list.append(['Leave', '?'])
                req_list.append(['Arrive', '?'])

        req_list = DA.setdefault('Train-Request', [])
        for prop in ['day', 'destination', 'departure']:
            if semi[prop] == "":
                req_list.append([REF_USR_DA['Train'].get(prop, prop), '?'])
            else:
                constraints.append([prop, semi[prop]])

        kb_result = self.db.query('train', constraints)
        self.kb_result['Train'] = deepcopy(kb_result)

        if user_act == 'Train-Request':
            del DA['Train-Request']
            inform = DA.setdefault('Train-Inform', [])
            for slot in user_action[user_act]:
                slot_name = REF_SYS_DA['Train'].get(slot[0], slot[0])
                # Explicit check replaces a bare ``except: pass`` around the
                # lookup: skip the slot when there is no train or no value.
                if kb_result and slot_name in kb_result[0]:
                    inform.append([slot[0], kb_result[0][slot_name]])
            return

        if not kb_result:
            no_offer = DA.setdefault('Train-NoOffer', [])
            for slot, value in constraints:
                no_offer.append([REF_USR_DA['Train'].get(slot, slot), value])
            DA.pop('Train-Request', None)
        elif len(constraints) >= 4:
            DA.pop('Train-Request', None)
            offer = DA.setdefault('Train-OfferBook', [])
            for slot, value in constraints:
                offer.append([REF_USR_DA['Train'].get(slot, slot), value])

    def _judge_booking(self, user_act, user_action, DA):
        """Answer a booking attempt with a reference number.

        If any slot of ``user_act`` is listed in ``booking_info`` for its
        domain and a database result for that domain was cached this turn,
        ``Booking-Book`` is set to the first result's ``Ref`` (or ``"N/A"``
        when the result has none). An existing ``Booking-Book`` is kept.
        """
        self._update_recommend_flag()
        domain, _ = user_act.split('-')
        if domain not in booking_info or 'Booking-Book' in DA:
            return
        if not any(slot[0] in booking_info[domain] for slot in user_action[user_act]):
            return
        results = self.kb_result.get(domain)
        if results:
            DA['Booking-Book'] = [["Ref", results[0].get('Ref', "N/A")]]
class MultiWozEvaluator(Evaluator):
    """Session-level evaluator for MultiWOZ dialogs.

    System and user dialog acts are recorded with ``add_sys_da`` /
    ``add_usr_da``, together with the entities the system booked. The session
    is then scored against the user goal: booking rate, inform
    precision/recall/F1, per-domain and overall task success, and whether the
    final goal is consistent with the database.
    """

    _INFORM_INTENTS = ('inform', 'recommend', 'offerbook', 'offerbooked')

    def __init__(self):
        self.sys_da_array = []
        self.usr_da_array = []
        self.goal = {}
        self.cur_domain = ''
        # One empty entry per belief domain, so add_sys_da does not raise
        # KeyError when it runs before add_goal (previously ``{}``).
        self.booked = self._init_dict_booked()
        self.database = Database()
        self.dbs = self.database.dbs

    def _init_dict(self):
        """Return an empty goal: ``{'info': {}, 'book': {}, 'reqt': []}`` per belief domain."""
        return {domain: {'info': {}, 'book': {}, 'reqt': []} for domain in belief_domains}

    def _init_dict_booked(self):
        """Return ``{domain: None}`` for every belief domain (nothing booked yet)."""
        return {domain: None for domain in belief_domains}

    def _expand(self, _goal):
        """Return a deep copy of ``_goal`` with info/book/reqt present for every belief domain."""
        goal = deepcopy(_goal)
        for domain in belief_domains:
            if domain not in goal:
                goal[domain] = {'info': {}, 'book': {}, 'reqt': []}
                continue
            if 'info' not in goal[domain]:
                goal[domain]['info'] = {}
            if 'book' not in goal[domain]:
                goal[domain]['book'] = {}
            if 'reqt' not in goal[domain]:
                goal[domain]['reqt'] = []
        return goal

    def add_goal(self, goal):
        """Start a new session with ``goal`` and clear all recorded acts and bookings.

        Args:
            goal: ``dict[domain] -> {'info': dict, 'book': dict, 'reqt': ...}``.
                ``reqt`` is iterated as slot names; it may be a list of slots
                or a dict keyed by slot.
        """
        self.sys_da_array = []
        self.usr_da_array = []
        self.goal = goal
        self.cur_domain = ''
        self.booked = self._init_dict_booked()

    def add_sys_da(self, da_turn):
        """Record one system turn and any booking it confirms.

        Args:
            da_turn: list of ``[intent, domain, slot, value]``.
        """
        for intent, domain, slot, value in da_turn:
            dom_int = '-'.join([domain, intent])
            domain = dom_int.split('-')[0].lower()
            # Remember the last concrete domain so a generic
            # ``Booking-Book-Ref`` can be attributed to it.
            if domain in belief_domains and domain != self.cur_domain:
                self.cur_domain = domain
            da = (dom_int + '-' + slot).lower()
            value = str(value)
            self.sys_da_array.append(da + '-' + value)

            if da == 'booking-book-ref' and self.cur_domain in ['hotel', 'restaurant', 'train']:
                self._book_from_ref(self.cur_domain, value)
            elif da in ('train-offerbooked-ref', 'train-inform-ref'):
                self._book_from_ref('train', value)
            elif da == 'taxi-inform-car':
                if not self.booked['taxi']:
                    self.booked['taxi'] = 'booked'

    def _book_from_ref(self, domain, ref):
        """Record the DB entry at index ``int(ref)`` as the booking for ``domain``.

        Only the first valid booking per domain is kept. ``ref`` must be an
        8-digit string that indexes into ``self.dbs[domain]``; the stored copy
        gets an extra ``'Ref'`` key.
        """
        if not self.booked[domain] and re.match(r'^\d{8}$', ref) and \
                len(self.dbs[domain]) > int(ref):
            entity = self.dbs[domain][int(ref)].copy()
            entity['Ref'] = ref
            self.booked[domain] = entity

    def add_usr_da(self, da_turn):
        """Record one user turn.

        Args:
            da_turn: list of ``[intent, domain, slot, value]``.
        """
        for intent, domain, slot, value in da_turn:
            dom_int = '-'.join([domain, intent])
            domain = dom_int.split('-')[0].lower()
            if domain in belief_domains and domain != self.cur_domain:
                self.cur_domain = domain
            da = (dom_int + '-' + slot).lower()
            value = str(value)
            self.usr_da_array.append(da + '-' + value)

    @staticmethod
    def _time_to_int(text):
        """Convert ``'HH:MM'`` to ``HH * 100 + MM`` (raises ValueError/IndexError if malformed)."""
        parts = text.split(':')
        return int(parts[0]) * 100 + int(parts[1])

    def _book_rate_goal(self, goal, booked_entity, domains=None):
        """Score how well each booked entity meets the goal's info constraints.

        Returns:
            list: one score in [0, 1] per domain that has a booking goal and
            at least one info constraint. 0 means nothing was booked; 1 means
            taxi was booked or all scored constraints matched. destination and
            departure are not scored, and unparseable times count as a match.
        """
        if domains is None:
            domains = belief_domains
        score = []
        for domain in domains:
            if not ('book' in goal[domain] and goal[domain]['book']):
                continue
            tot = len(goal[domain]['info'].keys())
            if tot == 0:
                continue
            entity = booked_entity[domain]
            if entity is None:
                score.append(0)
                continue
            if domain == 'taxi':
                score.append(1)
                continue
            match = 0
            for k, v in goal[domain]['info'].items():
                if k in ['destination', 'departure']:
                    tot -= 1
                elif k in ('leaveAt', 'arriveBy'):
                    try:
                        v_constraint = self._time_to_int(v)
                        v_select = self._time_to_int(entity[k])
                    except (ValueError, IndexError):
                        match += 1
                    else:
                        # Leave no earlier / arrive no later than requested.
                        if k == 'leaveAt':
                            satisfied = v_constraint <= v_select
                        else:
                            satisfied = v_constraint >= v_select
                        if satisfied:
                            match += 1
                elif v.strip() == entity[k].strip():
                    match += 1
            if tot != 0:
                score.append(match / tot)
        return score

    def _inform_F1_goal(self, goal, sys_history, domains=None):
        """Count how well the system informed the requested slots.

        Args:
            goal: expanded goal dict (see ``_build_goal``).
            sys_history: system acts as ``"domain-intent-slot-value"`` strings.
            domains: domains to score (default: all belief domains).

        Returns:
            tuple: ``(TP, FP, FN, bad_inform, reqt_not_inform, inform_not_reqt)``.
            An informed value failing ``_check_value`` is an FP. A requested
            slot never informed is an FN. An informed requestable slot the
            goal neither requested nor constrained is an FP.
        """
        if domains is None:
            domains = belief_domains
        inform_slot = {domain: set() for domain in domains}
        TP, FP, FN = 0, 0, 0
        inform_not_reqt = set()
        reqt_not_inform = set()
        bad_inform = set()
        for da in sys_history:
            domain, intent, slot, value = da.split('-', 3)
            if intent in self._INFORM_INTENTS and domain in domains and \
                    slot in mapping[domain] and value.strip() not in NUL_VALUE:
                key = mapping[domain][slot]
                if self._check_value(domain, key, value):
                    inform_slot[domain].add(key)
                else:
                    bad_inform.add((intent, domain, key))
                    FP += 1

        for domain in domains:
            for k in goal[domain]['reqt']:
                if k in inform_slot[domain]:
                    TP += 1
                else:
                    reqt_not_inform.add(('request', domain, k))
                    FN += 1
            for k in inform_slot[domain]:
                # exclude slots that are informed by users
                if k not in goal[domain]['reqt'] \
                        and k not in goal[domain]['info'] \
                        and k in requestable[domain]:
                    inform_not_reqt.add(('inform', domain, k))
                    FP += 1
        return TP, FP, FN, bad_inform, reqt_not_inform, inform_not_reqt

    def _check_value(self, domain, key, value):
        """Return a truthy value if ``value`` is well-formed for slot ``key``."""
        if key == "area":
            return value.lower() in ["centre", "east", "south", "west", "north"]
        elif key == "arriveBy" or key == "leaveAt":
            return time_re.match(value)
        elif key == "day":
            return value.lower() in ["monday", "tuesday", "wednesday", "thursday",
                                     "friday", "saturday", "sunday"]
        elif key == "duration":
            return 'minute' in value
        elif key == "internet" or key == "parking":
            return value in ["yes", "no", "none"]
        elif key == "phone":
            return re.match(r'^\d{11}$', value) or domain == "restaurant"
        elif key == "price":
            return 'pound' in value
        elif key == "pricerange":
            return value in ["cheap", "expensive", "moderate", "free"] or domain == "attraction"
        elif key == "postcode":
            return re.match(r'^cb\d{1,3}[a-z]{2,3}$', value) or value == 'pe296fl'
        elif key == "stars":
            return re.match(r'^\d$', value)
        elif key == "trainID":
            return re.match(r'^tr\d{4}$', value.lower())
        else:
            return True

    def _build_goal(self, ref2goal, domain=None):
        """Return the goal a metric is computed against.

        Args:
            ref2goal: if True, use the stored goal expanded with ``_expand``.
                Otherwise rebuild the goal from the user's own acts:
                inform-like acts fill ``info`` (via ``mapping``), request acts
                fill ``reqt``, and ``book`` is copied from the stored goal.
            domain: if given, only that domain is returned; with
                ``ref2goal=False`` only its user acts are used.
        """
        if ref2goal:
            goal = self._expand(self.goal)
        else:
            goal = self._init_dict()
            if domain is not None:
                goal.setdefault(domain, {'info': {}, 'book': {}, 'reqt': []})
            for dom, dom_goal in goal.items():
                if dom in self.goal and 'book' in self.goal[dom]:
                    dom_goal['book'] = self.goal[dom]['book']
            for da in self.usr_da_array:
                d, i, s, v = da.split('-', 3)
                # Skip acts of domains the goal does not track (used to raise KeyError).
                if d not in goal or (domain is not None and d != domain):
                    continue
                if i in self._INFORM_INTENTS and d in mapping and s in mapping[d]:
                    goal[d]['info'][mapping[d][s]] = v
                elif i == 'request':
                    goal[d]['reqt'].append(s)
        if domain is not None:
            return {domain: goal[domain]}
        return goal

    @staticmethod
    def _is_success(book_rate, inform_rec):
        """Success: each metric is 1 or not applicable (None), and not both are None."""
        return (book_rate == 1 and inform_rec == 1) \
            or (book_rate == 1 and inform_rec is None) \
            or (book_rate is None and inform_rec == 1)

    def book_rate(self, ref2goal=True, aggregate=True):
        """Booking score(s): their mean (None if no domain was scored), or the list if not ``aggregate``."""
        score = self._book_rate_goal(self._build_goal(ref2goal), self.booked)
        if aggregate:
            return np.mean(score) if score else None
        return score

    def inform_F1(self, ref2goal=True, aggregate=True):
        """Return ``(precision, recall, F1)``, or ``[TP, FP, FN]`` if not ``aggregate``.

        ``(None, None, None)`` is returned when nothing was requested, and
        ``(0, recall, 0)`` when precision or F1 is undefined.
        """
        TP, FP, FN, _, _, _ = self._inform_F1_goal(self._build_goal(ref2goal), self.sys_da_array)
        if not aggregate:
            return [TP, FP, FN]
        try:
            rec = TP / (TP + FN)
        except ZeroDivisionError:
            return None, None, None
        try:
            prec = TP / (TP + FP)
            F1 = 2 * prec * rec / (prec + rec)
        except ZeroDivisionError:
            return 0, rec, 0
        return prec, rec, F1

    def task_success(self, ref2goal=True):
        """Return 1 if bookings and informs succeed and the final goal fully matches the DB, else 0."""
        book_sess = self.book_rate(ref2goal)
        inform_sess = self.inform_F1(ref2goal)
        goal_sess = self.final_goal_analyze()
        # book rate == 1 & inform recall == 1
        if self._is_success(book_sess, inform_sess[1]) and goal_sess == 1:
            return 1
        return 0

    def domain_reqt_inform_analyze(self, domain, ref2goal=True):
        """Return the ``_inform_F1_goal`` tuple for one domain (None if not in the goal)."""
        if domain not in self.goal:
            return None
        goal = self._build_goal(ref2goal, domain)
        return self._inform_F1_goal(goal, self.sys_da_array, [domain])

    def domain_success(self, domain, ref2goal=True):
        """Return 1 if the domain (subtask) succeeded, 0 if not, None if not in the goal."""
        if domain not in self.goal:
            return None
        goal = self._build_goal(ref2goal, domain)

        book_rate = self._book_rate_goal(goal, self.booked, [domain])
        book_rate = np.mean(book_rate) if book_rate else None

        inform = self._inform_F1_goal(goal, self.sys_da_array, [domain])
        try:
            inform_rec = inform[0] / (inform[0] + inform[2])
        except ZeroDivisionError:
            inform_rec = None

        return 1 if self._is_success(book_rate, inform_rec) else 0

    def _final_goal_analyze(self):
        """Count goal domains whose final goal is consistent with the database.

        Each domain is queried with its ``info`` items as constraints and, when
        ``reqt`` is a dict, its items as soft constraints. A domain mismatches
        if nothing is found, or if a booked entity (a dict) has a ``Ref`` that
        is not among the results.

        Returns:
            tuple: ``(match, mismatch)``.
        """
        match = mismatch = 0
        for domain, dom_goal_dict in self.goal.items():
            reqt = dom_goal_dict.get('reqt', {})
            # reqt can be a list of slot names (see add_goal / _expand), which
            # carries no values; only the dict form yields soft constraints.
            # The original called .items() unconditionally and crashed on lists.
            reqt_constraints = list(reqt.items()) if isinstance(reqt, dict) else []
            info_constraints = list(dom_goal_dict.get('info', {}).items())
            query_result = self.database.query(
                domain, info_constraints, soft_contraints=reqt_constraints)
            if not query_result:
                mismatch += 1
                continue

            booked = self.booked.get(domain)
            if not self.goal[domain].get('book'):
                match += 1
            elif isinstance(booked, dict):
                ref = booked['Ref']
                if any(found['Ref'] == ref for found in query_result):
                    match += 1
                else:
                    mismatch += 1
            else:
                # NOTE(review): a missing booking (None) also counts as a match
                # here; booking failures are scored by book_rate instead.
                match += 1
        return match, mismatch

    def final_goal_analyze(self):
        """Fraction of goal domains whose final goal satisfies the database constraints.

        Returns 1 when no domain was counted.
        """
        match, mismatch = self._final_goal_analyze()
        if match == mismatch == 0:
            return 1
        return match / (match + mismatch)