def __init__(self):
    """Set up the preprocessor: spaCy pipeline, DB, raw corpus and delex dictionaries.

    Loads the zipped ConvLab-annotated corpus (lower-cased), the reference-number
    list, and the single-/multi-token/ambiguous value dictionaries. When the
    cached dictionaries are missing they are rebuilt with
    ``self.get_delex_valdict()``. Finally creates an empty ``utils.Vocab``.
    """
    def _read_json(path):
        # Parse a JSON file and close the handle (the original leaked it).
        with open(path, 'r') as f:
            return json.loads(f.read())

    self.nlp = spacy.load('en_core_web_sm')
    self.db = MultiWozDB(cfg.dbs)
    data_path = 'data/multi-woz/annotated_user_da_with_span_full.json'
    # Close both the zip archive and the member stream once the JSON is read.
    with zipfile.ZipFile(data_path + '.zip', 'r') as archive:
        with archive.open(data_path.split('/')[-1], 'r') as member:
            self.convlab_data = json.loads(
                member.read().decode('utf-8').lower())
    self.delex_sg_valdict_path = 'data/multi-woz-processed/delex_single_valdict.json'
    self.delex_mt_valdict_path = 'data/multi-woz-processed/delex_multi_valdict.json'
    self.ambiguous_val_path = 'data/multi-woz-processed/ambiguous_values.json'
    self.delex_refs_path = 'data/multi-woz-processed/reference_no.json'
    self.delex_refs = _read_json(self.delex_refs_path)
    if not os.path.exists(self.delex_sg_valdict_path):
        # No cached dictionaries yet: derive them from the DB.
        (self.delex_sg_valdict, self.delex_mt_valdict,
         self.ambiguous_vals) = self.get_delex_valdict()
    else:
        self.delex_sg_valdict = _read_json(self.delex_sg_valdict_path)
        self.delex_mt_valdict = _read_json(self.delex_mt_valdict_path)
        self.ambiguous_vals = _read_json(self.ambiguous_val_path)
    self.vocab = utils.Vocab(cfg.vocab_size)
def __init__(self):
    """Set up the preprocessor plus the TRADE slot list and gate mapping.

    Loads the zipped ConvLab-annotated corpus (lower-cased), the reference-number
    list and the delex value dictionaries (rebuilt via
    ``self.get_delex_valdict()`` when the cache is missing), creates an empty
    ``utils.Vocab``, and defines ``self.slot_list`` (domain-slot names) and
    ``self.gating_dict`` (gate label -> class index).
    """
    def _read_json(path):
        # Parse a JSON file and close the handle (the original leaked it).
        with open(path, 'r') as f:
            return json.loads(f.read())

    self.nlp = spacy.load('en_core_web_sm')
    self.db = MultiWozDB(cfg.dbs)
    data_path = 'data/multi-woz/annotated_user_da_with_span_full.json'
    with zipfile.ZipFile(data_path + '.zip', 'r') as archive:
        with archive.open(data_path.split('/')[-1], 'r') as member:
            # Decode before lower-casing: bytes.lower() only folds ASCII
            # letters, while str.lower() handles all of Unicode. This also
            # matches the other preprocessor __init__ variants.
            self.convlab_data = json.loads(
                member.read().decode('utf-8').lower())
    self.delex_sg_valdict_path = 'data/multi-woz-processed/delex_single_valdict.json'
    self.delex_mt_valdict_path = 'data/multi-woz-processed/delex_multi_valdict.json'
    self.ambiguous_val_path = 'data/multi-woz-processed/ambiguous_values.json'
    self.delex_refs_path = 'data/multi-woz-processed/reference_no.json'
    self.delex_refs = _read_json(self.delex_refs_path)
    if not os.path.exists(self.delex_sg_valdict_path):
        (self.delex_sg_valdict, self.delex_mt_valdict,
         self.ambiguous_vals) = self.get_delex_valdict()
    else:
        self.delex_sg_valdict = _read_json(self.delex_sg_valdict_path)
        self.delex_mt_valdict = _read_json(self.delex_mt_valdict_path)
        self.ambiguous_vals = _read_json(self.ambiguous_val_path)
    self.vocab = utils.Vocab(cfg.vocab_size)
    # NOTE(review): MultiWozReader.wrap_result spells the time slots
    # 'train-arriveby' / 'train-leaveat' / 'taxi-leaveat' / 'taxi-arriveby',
    # whereas this list uses 'arrive' / 'leave' -- confirm which naming the
    # TRADE labels are built with.
    self.slot_list = [
        'hotel-pricerange', 'hotel-type', 'hotel-parking', 'hotel-stay',
        'hotel-day', 'hotel-people', 'hotel-area', 'hotel-stars',
        'hotel-internet', 'train-destination', 'train-day', 'train-departure',
        'train-arrive', 'train-people', 'train-leave', 'attraction-area',
        'restaurant-food', 'restaurant-pricerange', 'restaurant-area',
        'attraction-name', 'restaurant-name', 'attraction-type', 'hotel-name',
        'taxi-leave', 'taxi-destination', 'taxi-departure', 'restaurant-time',
        'restaurant-day', 'restaurant-people', 'taxi-arrive',
        "hospital-department",
    ]
    self.gating_dict = {'ptr': 0, 'dontcare': 1, 'none': 2}
def __init__(self, vocab=None):
    """Initialize the reader: DB, data splits, experiment filter, vocabulary and data.

    Args:
        vocab: an existing ``utils.Vocab`` to reuse. When falsy, the
            vocabulary is loaded with ``self._build_vocab()``.

    Side effects: loads the domain-file map and the slot-value set (plus the
    multi-act file when ``cfg.multi_acts_training``), reads the dev/test lists,
    restricts to ``cfg.exp_domains`` unless it contains ``'all'``, calls
    ``self._load_data()`` and optionally builds the bspn/aspn masks.

    Raises:
        ValueError: if a domain in ``cfg.exp_domains`` has no file list.
    """
    super().__init__()
    self.nlp = spacy.load('en_core_web_sm')
    self.db = MultiWozDB(cfg.dbs)
    # Close every file after reading (the original leaked the handles).
    with open(cfg.domain_file_path, 'r') as f:
        self.domain_files = json.loads(f.read())
    with open(cfg.slot_value_set_path, 'r') as f:
        self.slot_value_set = json.loads(f.read())
    if cfg.multi_acts_training:
        with open(cfg.multi_acts_path, 'r') as f:
            self.multi_acts = json.loads(f.read())
    # TRADE gate label -> class index. The _get_encoded_data shown in
    # MultiWozReader (reached from _load_data) indexes self.gating_dict, so it
    # must exist before the data is loaded -- presumably the same here.
    self.gating_dict = {'ptr': 0, 'dontcare': 1, 'none': 2}

    with open(cfg.test_list, 'r') as f:
        test_list = [l.strip().lower() for l in f.readlines()]
    with open(cfg.dev_list, 'r') as f:
        dev_list = [l.strip().lower() for l in f.readlines()]
    self.dev_files, self.test_files = {}, {}
    for fn in test_list:
        self.test_files[fn.replace('.json', '')] = 1
    for fn in dev_list:
        self.dev_files[fn.replace('.json', '')] = 1

    self.exp_files = {}
    if 'all' not in cfg.exp_domains:
        for domain in cfg.exp_domains:
            fn_list = self.domain_files.get(domain)
            if not fn_list:
                raise ValueError('[%s] is an invalid experiment setting' % domain)
            for fn in fn_list:
                self.exp_files[fn.replace('.json', '')] = 1

    if vocab:
        self.vocab = vocab
        # _build_vocab reports the size via Vocab.vocab_size; read the same
        # attribute here instead of `vocab.size`.
        self.vocab_size = vocab.vocab_size
    else:
        self.vocab_size = self._build_vocab()
    self._load_data()

    if cfg.limit_bspn_vocab:
        self.bspn_masks = self._construct_bspn_constraint()
    if cfg.limit_aspn_vocab:
        self.aspn_masks = self._construct_aspn_constraint()

    self.multi_acts_record = None
class DataPreprocessor(object):
    """Turn the ConvLab-annotated MultiWOZ dump into the delexicalized corpus.

    For every dialogue it produces raw and delexicalized user/system
    utterances, belief-state spans, system-act spans, turn domains and DB
    pointers, and it grows ``self.vocab`` along the way.
    """

    def __init__(self):
        self.nlp = spacy.load('en_core_web_sm')
        self.db = MultiWozDB(cfg.dbs)
        data_path = 'data/multi-woz/annotated_user_da_with_span_full.json'
        # The corpus is stored zipped; close both the archive and the member
        # stream once the (lower-cased) JSON has been read.
        with zipfile.ZipFile(data_path + '.zip', 'r') as archive:
            with archive.open(data_path.split('/')[-1], 'r') as member:
                self.convlab_data = json.loads(
                    member.read().decode('utf-8').lower())
        self.delex_sg_valdict_path = 'data/multi-woz-processed/delex_single_valdict.json'
        self.delex_mt_valdict_path = 'data/multi-woz-processed/delex_multi_valdict.json'
        self.ambiguous_val_path = 'data/multi-woz-processed/ambiguous_values.json'
        self.delex_refs_path = 'data/multi-woz-processed/reference_no.json'
        self.delex_refs = self._read_json(self.delex_refs_path)
        if not os.path.exists(self.delex_sg_valdict_path):
            # No cached dictionaries yet: derive them from the DB (this also
            # writes them to disk for the next run).
            (self.delex_sg_valdict, self.delex_mt_valdict,
             self.ambiguous_vals) = self.get_delex_valdict()
        else:
            self.delex_sg_valdict = self._read_json(self.delex_sg_valdict_path)
            self.delex_mt_valdict = self._read_json(self.delex_mt_valdict_path)
            self.ambiguous_vals = self._read_json(self.ambiguous_val_path)
        self.vocab = utils.Vocab(cfg.vocab_size)

    @staticmethod
    def _read_json(path):
        """Parse the JSON file at *path*, closing the file handle afterwards."""
        with open(path, 'r') as f:
            return json.loads(f.read())

    def delex_by_annotation(self, dial_turn):
        """Delexicalize a turn using its span annotations.

        Tokens covered by an annotated span are blanked, and the span's first
        token becomes ``[value_<slot>]``. Spans labelled ``'open'`` are skipped.

        Args:
            dial_turn: turn dict with ``'text'`` and ``'span_info'``. Each span
                is read as ``s[1]`` (slot, possibly abbreviated) and
                ``s[3]``/``s[4]`` (first/last whitespace-token index, inclusive).

        Returns:
            The delexicalized utterance as a space-joined string.
        """
        u = dial_turn['text'].split()
        span = dial_turn['span_info']
        for s in span:
            slot = s[1]
            if slot == 'open':
                continue
            if ontology.da_abbr_to_slot_name.get(slot):
                slot = ontology.da_abbr_to_slot_name[slot]
            for idx in range(s[3], s[4] + 1):
                u[idx] = ''
            try:
                u[s[3]] = '[value_' + slot + ']'
            except IndexError:
                # Only reachable when the span range is empty (s[3] > s[4]);
                # otherwise the loop above would already have raised.
                # Historical fallback position kept as-is.
                u[5] = '[value_' + slot + ']'
        # Compare by value; `t is not ''` relied on CPython string interning.
        u_delex = ' '.join([t for t in u if t != ''])
        # Collapse placeholders produced by adjacent/split annotations.
        u_delex = u_delex.replace(
            '[value_address] , [value_address] , [value_address]',
            '[value_address]')
        u_delex = u_delex.replace('[value_address] , [value_address]',
                                  '[value_address]')
        u_delex = u_delex.replace('[value_name] [value_name]', '[value_name]')
        u_delex = u_delex.replace('[value_name]([value_phone] )',
                                  '[value_name] ( [value_phone] )')
        return u_delex

    def delex_by_valdict(self, text):
        """Delexicalize *text* that has no span annotations.

        Applied in order:

        1. Regexes for phone, stars, price, train id and postcode.
        2. Multi-token dictionary values, replaced as substrings.
        3. Single-token dictionary values, replaced on exact token match.
        4. Ambiguous values, labelled from the closest cue word in front of
           their first occurrence.

        Returns:
            The cleaned, delexicalized string.
        """
        text = clean_text(text)

        text = re.sub(r'\d{5}\s?\d{5,7}', '[value_phone]', text)
        text = re.sub(r'\d[\s-]stars?', '[value_stars]', text)
        text = re.sub(r'\$\d+|\$?\d+.?(\d+)?\s(pounds?|gbps?)',
                      '[value_price]', text)
        text = re.sub(r'tr[\d]{4}', '[value_id]', text)
        text = re.sub(
            r'([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})',
            '[value_postcode]', text)

        for value, slot in self.delex_mt_valdict.items():
            text = text.replace(value, '[value_%s]' % slot)

        for value, slot in self.delex_sg_valdict.items():
            tokens = text.split()
            for idx, tk in enumerate(tokens):
                if tk == value:
                    tokens[idx] = '[value_%s]' % slot
            text = ' '.join(tokens)

        arrive_cues = [
            'arrive', 'arrives', 'arrived', 'arriving', 'arrival',
            'destination', 'there', 'reach', 'to', 'by', 'before'
        ]
        leave_cues = [
            'leave', 'leaves', 'leaving', 'depart', 'departs', 'departing',
            'departure', 'from', 'after', 'pulls'
        ]
        for ambg_ent in self.ambiguous_vals:
            # Leading space: "ely" is a place but also appears inside words
            # like "moderately".
            start_idx = text.find(' ' + ambg_ent)
            if start_idx == -1:
                continue
            front_words = text[:start_idx].split()
            ent_type = 'time' if ':' in ambg_ent else 'place'
            for fw in reversed(front_words):
                if fw in arrive_cues:
                    slot = '[value_arrive]' if ent_type == 'time' else '[value_destination]'
                elif fw in leave_cues:
                    slot = '[value_leave]' if ent_type == 'time' else '[value_departure]'
                else:
                    continue
                # Literal replacement: values may contain regex metacharacters
                # (e.g. '.') that re.sub would have interpreted.
                text = text.replace(' ' + ambg_ent, ' ' + slot)
                # The nearest cue word decides; stop scanning.
                break

        text = text.replace('[value_car] [value_car]', '[value_car]')
        return text

    def get_delex_valdict(self):
        """Build, save and return the value -> slot dictionaries used for delex.

        Values come from every DB entry, except the fields listed in
        ``skip_entry_type``, normalized with ``clean_slot_values`` and
        re-tokenized with spaCy. A value seen under more than one slot is
        treated as ambiguous and removed from the dictionary; 'cambridge' is
        the one exception. A few manual entries are then added.

        Returns:
            ``(single_token_values, multi_token_values, ambiguous_entities)``.
            The first two are OrderedDicts sorted by value length, longest
            first; the third is a sorted list. All three are also written to
            the paths configured in ``__init__``.

        Raises:
            TypeError: if a non-skipped DB value is not a string.
        """
        skip_entry_type = {
            'taxi': ['taxi_phone'],
            'police': ['id'],
            'hospital': ['id'],
            'hotel': [
                'id', 'location', 'internet', 'parking', 'takesbookings',
                'stars', 'price', 'n', 'postcode', 'phone'
            ],
            'attraction': [
                'id', 'location', 'pricerange', 'price', 'openhours',
                'postcode', 'phone'
            ],
            'train': ['price', 'id'],
            'restaurant': [
                'id', 'location', 'introduction', 'signature', 'type',
                'postcode', 'phone'
            ],
        }
        entity_value_to_slot = {}
        ambiguous_entities = []
        for domain, db_data in self.db.dbs.items():
            print('Processing entity values in [%s]' % domain)
            if domain != 'taxi':
                for db_entry in db_data:
                    for slot, value in db_entry.items():
                        if slot in skip_entry_type[domain]:
                            continue
                        if not isinstance(value, str):
                            raise TypeError(
                                "value '%s' in domain '%s' should be rechecked"
                                % (slot, domain))
                        slot, value = clean_slot_values(domain, slot, value)
                        value = ' '.join(
                            [token.text for token in self.nlp(value)]).strip()
                        if (value in entity_value_to_slot
                                and entity_value_to_slot[value] != slot):
                            ambiguous_entities.append(value)
                        entity_value_to_slot[value] = slot
            else:
                # taxi db: one entry mapping each field to a list of values.
                db_entry = db_data[0]
                for slot, ent_list in db_entry.items():
                    if slot not in skip_entry_type[domain]:
                        for ent in ent_list:
                            entity_value_to_slot[ent] = 'car'

        # Ambiguous values (departure or destination? arrive or leave time?)
        # are resolved from context in delex_by_valdict instead.
        ambiguous_entities = set(ambiguous_entities)
        # discard(): 'cambridge' may not be ambiguous in every DB version.
        ambiguous_entities.discard('cambridge')
        # Sorted for a reproducible ambiguous_values.json.
        ambiguous_entities = sorted(ambiguous_entities)
        for amb_ent in ambiguous_entities:
            entity_value_to_slot.pop(amb_ent)

        # Manual additions. Keys must be lower-case: the corpus is lowercased
        # when loaded in __init__.
        entity_value_to_slot['parkside'] = 'address'
        entity_value_to_slot['parkside, cambridge'] = 'address'
        entity_value_to_slot['cambridge belfry'] = 'name'
        entity_value_to_slot['hills road'] = 'address'
        entity_value_to_slot['hills rd'] = 'address'
        entity_value_to_slot['parkside police station'] = 'name'

        single_token_values = {}
        multi_token_values = {}
        for val, slt in entity_value_to_slot.items():
            if val in ['cambridge']:
                continue
            if len(val.split()) > 1:
                multi_token_values[val] = slt
            else:
                single_token_values[val] = slt

        # Longest values first so delex_by_valdict replaces them before any
        # shorter value they contain.
        single_token_values = OrderedDict(
            sorted(single_token_values.items(),
                   key=lambda kv: len(kv[0]),
                   reverse=True))
        multi_token_values = OrderedDict(
            sorted(multi_token_values.items(),
                   key=lambda kv: len(kv[0]),
                   reverse=True))

        with open(self.delex_sg_valdict_path, 'w') as f:
            json.dump(single_token_values, f, indent=2)
            print('single delex value dict saved!')
        with open(self.delex_mt_valdict_path, 'w') as f:
            json.dump(multi_token_values, f, indent=2)
            print('multi delex value dict saved!')
        with open(self.ambiguous_val_path, 'w') as f:
            json.dump(ambiguous_entities, f, indent=2)
            print('ambiguous value dict saved!')

        return single_token_values, multi_token_values, ambiguous_entities

    def preprocess_main(self, save_path=None, is_test=False):
        """Convert every ConvLab dialogue into the processed turn format.

        Each user turn and the following system turn are merged into one
        record with keys 'user', 'user_delex', 'resp', 'pointer', 'match',
        'constraint', 'cons_delex', 'sys_act', 'turn_num' and 'turn_domain'.

        Side effects:
            - Builds and saves ``self.vocab`` to 'data/multi-woz-processed/vocab'.
            - Records every domain-act pair in ``self.unique_da``.
            - Writes 'dialog_acts.json' and 'dialog_act_type.json' under
              'data/multi-woz-analysis/'.

        Args:
            save_path: unused; kept for interface compatibility.
            is_test: unused; kept for interface compatibility.

        Returns:
            dict mapping dialogue file name to ``{'goal': ..., 'log': [...]}``.
        """
        data = {}
        self.unique_da = {}
        ordered_sysact_dict = {}
        for fn, raw_dial in tqdm(list(self.convlab_data.items())):
            compressed_goal = {}
            dial_domains = []
            for dom, g in raw_dial['goal'].items():
                if dom != 'topic' and dom != 'message' and g:
                    if g.get('reqt'):
                        # Normalize requested-slot names in place.
                        for i, req_slot in enumerate(g['reqt']):
                            if ontology.normlize_slot_names.get(req_slot):
                                g['reqt'][i] = ontology.normlize_slot_names[req_slot]
                    compressed_goal[dom] = g
                    if dom in ontology.all_domains:
                        dial_domains.append(dom)

            dial = {'goal': compressed_goal, 'log': []}
            single_turn = {}
            constraint_dict = OrderedDict()
            prev_constraint_dict = {}
            prev_turn_domain = ['general']
            ordered_sysact_dict[fn] = {}

            for dial_turn in raw_dial['log']:
                dial_state = dial_turn['metadata']
                if not dial_state:
                    # User turn (no belief-state metadata).
                    u = ' '.join(clean_text(dial_turn['text']).split())
                    if dial_turn['span_info']:
                        u_delex = clean_text(self.delex_by_annotation(dial_turn))
                    else:
                        u_delex = self.delex_by_valdict(dial_turn['text'])
                    single_turn['user'] = u
                    single_turn['user_delex'] = u_delex
                    continue

                # System turn.
                if dial_turn['span_info']:
                    s_delex = clean_text(self.delex_by_annotation(dial_turn))
                else:
                    if not dial_turn['text']:
                        print(fn)
                    s_delex = self.delex_by_valdict(dial_turn['text'])
                single_turn['resp'] = s_delex

                # Belief state: accumulate non-empty semi and book slots.
                for domain in dial_domains:
                    if not constraint_dict.get(domain):
                        constraint_dict[domain] = OrderedDict()
                    info_sv = dial_state[domain]['semi']
                    for s, v in info_sv.items():
                        s, v = clean_slot_values(domain, s, v)
                        if len(v.split()) > 1:
                            v = ' '.join([token.text for token in self.nlp(v)]).strip()
                        if v != '':
                            constraint_dict[domain][s] = v
                    book_sv = dial_state[domain]['book']
                    for s, v in book_sv.items():
                        if s == 'booked':
                            continue
                        s, v = clean_slot_values(domain, s, v)
                        if len(v.split()) > 1:
                            v = ' '.join([token.text for token in self.nlp(v)]).strip()
                        if v != '':
                            constraint_dict[domain][s] = v

                constraints = []
                cons_delex = []
                turn_dom_bs = []  # domains whose constraints changed this turn
                for domain, info_slots in constraint_dict.items():
                    if info_slots:
                        constraints.append('[' + domain + ']')
                        cons_delex.append('[' + domain + ']')
                        for slot, value in info_slots.items():
                            constraints.append(slot)
                            constraints.extend(value.split())
                            cons_delex.append(slot)
                        if domain not in prev_constraint_dict:
                            turn_dom_bs.append(domain)
                        elif prev_constraint_dict[domain] != constraint_dict[domain]:
                            turn_dom_bs.append(domain)

                sys_act_dict = {}
                turn_dom_da = set()
                for act in dial_turn['dialog_act']:
                    d, a = act.split('-')
                    turn_dom_da.add(d)
                turn_dom_da = list(turn_dom_da)
                if len(turn_dom_da) != 1 and 'general' in turn_dom_da:
                    turn_dom_da.remove('general')
                if len(turn_dom_da) != 1 and 'booking' in turn_dom_da:
                    turn_dom_da.remove('booking')

                # Turn domain: belief-state domains first, then act domains.
                turn_domain = turn_dom_bs
                for dom in turn_dom_da:
                    if dom != 'booking' and dom not in turn_domain:
                        turn_domain.append(dom)
                if not turn_domain:
                    turn_domain = prev_turn_domain
                if len(turn_domain) == 2 and 'general' in turn_domain:
                    turn_domain.remove('general')
                if len(turn_domain) == 2:
                    if (len(prev_turn_domain) == 1
                            and prev_turn_domain[0] == turn_domain[1]):
                        turn_domain = turn_domain[::-1]

                # System action.
                for dom in turn_domain:
                    sys_act_dict[dom] = {}
                add_to_last_collect = []
                booking_act_map = {'inform': 'offerbook', 'book': 'offerbooked'}
                for act, params in dial_turn['dialog_act'].items():
                    if act == 'general-greet':
                        continue
                    d, a = act.split('-')
                    if d == 'general' and d not in sys_act_dict:
                        sys_act_dict[d] = {}
                    if d == 'booking':
                        # Booking acts are attributed to the primary turn domain.
                        d = turn_domain[0]
                        a = booking_act_map.get(a, a)
                    add_p = []
                    for param in params:
                        p = param[0]
                        if p == 'none':
                            continue
                        elif ontology.da_abbr_to_slot_name.get(p):
                            p = ontology.da_abbr_to_slot_name[p]
                        if p not in add_p:
                            add_p.append(p)
                    # These acts are placed after all others in the span.
                    if a in ['request', 'reqmore', 'bye', 'offerbook']:
                        add_to_last_collect.append((d, a, add_p))
                    else:
                        sys_act_dict[d][a] = add_p
                for d, a, add_p in add_to_last_collect:
                    sys_act_dict[d][a] = add_p

                for d in list(sys_act_dict):
                    acts = sys_act_dict[d]
                    if not acts:
                        del sys_act_dict[d]
                    if 'inform' in acts and 'offerbooked' in acts:
                        for s in sys_act_dict[d]['inform']:
                            sys_act_dict[d]['offerbooked'].append(s)
                        del sys_act_dict[d]['inform']

                ordered_sysact_dict[fn][len(dial['log'])] = sys_act_dict

                sys_act = []
                if 'general-greet' in dial_turn['dialog_act']:
                    sys_act.extend(['[general]', '[greet]'])
                for d, acts in sys_act_dict.items():
                    sys_act += ['[' + d + ']']
                    for a, slots in acts.items():
                        self.unique_da[d + '-' + a] = 1
                        sys_act += ['[' + a + ']']
                        sys_act += slots

                # DB pointers.
                matnums = self.db.get_match_num(constraint_dict)
                match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1]
                match = matnums[match_dom]
                dbvec = self.db.addDBPointer(match_dom, match)
                bkvec = self.db.addBookingPointer(dial_turn['dialog_act'])

                single_turn['pointer'] = ','.join([str(d) for d in dbvec + bkvec])
                single_turn['match'] = str(match)
                single_turn['constraint'] = ' '.join(constraints)
                single_turn['cons_delex'] = ' '.join(cons_delex)
                single_turn['sys_act'] = ' '.join(sys_act)
                single_turn['turn_num'] = len(dial['log'])
                single_turn['turn_domain'] = ' '.join(
                    ['[' + d + ']' for d in turn_domain])

                prev_turn_domain = copy.deepcopy(turn_domain)
                prev_constraint_dict = copy.deepcopy(constraint_dict)

                if 'user' in single_turn:
                    dial['log'].append(single_turn)
                    for t in (single_turn['user'].split()
                              + single_turn['resp'].split()
                              + constraints + sys_act):
                        self.vocab.add_word(t)
                    # split() is evaluated once, so rewriting user_delex
                    # inside the loop is safe.
                    for t in single_turn['user_delex'].split():
                        if ('[' in t and ']' in t and not t.startswith('[')
                                and not t.endswith(']')):
                            # Strip text glued around a placeholder. The result
                            # must be stored: str.replace returns a new string
                            # (the original discarded it, making this a no-op).
                            single_turn['user_delex'] = single_turn['user_delex'].replace(
                                t, t[t.index('['):t.index(']') + 1])
                        elif not self.vocab.has_word(t):
                            self.vocab.add_word(t)

                single_turn = {}

            data[fn] = dial

        self.vocab.construct()
        self.vocab.save_vocab('data/multi-woz-processed/vocab')
        with open('data/multi-woz-analysis/dialog_acts.json', 'w') as f:
            json.dump(ordered_sysact_dict, f, indent=2)
        with open('data/multi-woz-analysis/dialog_act_type.json', 'w') as f:
            json.dump(self.unique_da, f, indent=2)
        return data
class MultiWozReader(_ReaderBase): def __init__(self): super().__init__() self.nlp = spacy.load('en_core_web_sm') self.db = MultiWozDB(cfg.dbs) self.vocab_size = self._build_vocab() self.domain_files = json.loads(open(cfg.domain_file_path, 'r').read()) self.slot_value_set = json.loads( open(cfg.slot_value_set_path, 'r').read()) if cfg.multi_acts_training: self.multi_acts = json.loads(open(cfg.multi_acts_path, 'r').read()) self.gating_dict = {'ptr': 0, 'dontcare': 1, 'none': 2} test_list = [ l.strip().lower() for l in open(cfg.test_list, 'r').readlines() ] dev_list = [ l.strip().lower() for l in open(cfg.dev_list, 'r').readlines() ] self.dev_files, self.test_files = {}, {} for fn in test_list: self.test_files[fn.replace('.json', '')] = 1 for fn in dev_list: self.dev_files[fn.replace('.json', '')] = 1 self.exp_files = {} if 'all' not in cfg.exp_domains: for domain in cfg.exp_domains: fn_list = self.domain_files.get(domain) if not fn_list: raise ValueError('[%s] is an invalid experiment setting' % domain) for fn in fn_list: self.exp_files[fn.replace('.json', '')] = 1 self._load_data() if cfg.limit_bspn_vocab: self.bspn_masks = self._construct_bspn_constraint() if cfg.limit_aspn_vocab: self.aspn_masks = self._construct_aspn_constraint() self.multi_acts_record = None def _build_vocab(self): self.vocab = utils.Vocab(cfg.vocab_size) vp = cfg.vocab_path_train if cfg.mode == 'train' or cfg.vocab_path_eval is None else cfg.vocab_path_eval # vp = cfg.vocab_path+'.json.freq.json' self.vocab.load_vocab(vp) return self.vocab.vocab_size def _construct_bspn_constraint(self): bspn_masks = {} valid_domains = [ 'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital' ] all_dom_codes = [ self.vocab.encode('[' + d + ']') for d in valid_domains ] all_slot_codes = [self.vocab.encode(s) for s in ontology.all_slots] bspn_masks[self.vocab.encode( '<go_b>')] = all_dom_codes + [self.vocab.encode('<eos_b>'), 0] bspn_masks[self.vocab.encode('<eos_b>')] = [self.vocab.encode('<pad>')] 
bspn_masks[self.vocab.encode('<pad>')] = [self.vocab.encode('<pad>')] for domain, slot_values in self.slot_value_set.items(): if domain == 'police': continue dom_code = self.vocab.encode('[' + domain + ']') bspn_masks[dom_code] = [] for slot, values in slot_values.items(): slot_code = self.vocab.encode(slot) if slot_code not in bspn_masks: bspn_masks[slot_code] = [] if slot_code not in bspn_masks[dom_code]: bspn_masks[dom_code].append(slot_code) for value in values: for idx, v in enumerate(value.split()): if not self.vocab.has_word(v): continue v_code = self.vocab.encode(v) if v_code not in bspn_masks: # print(self.vocab._word2idx) bspn_masks[v_code] = [] if idx == 0 and v_code not in bspn_masks[slot_code]: bspn_masks[slot_code].append(v_code) if idx == (len(value.split()) - 1): for w in all_dom_codes + all_slot_codes: if self.vocab.encode( '<eos_b>') not in bspn_masks[v_code]: bspn_masks[v_code].append( self.vocab.encode('<eos_b>')) if w not in bspn_masks[v_code]: bspn_masks[v_code].append(w) break if not self.vocab.has_word(value.split()[idx + 1]): continue next_v_code = self.vocab.encode(value.split()[idx + 1]) if next_v_code not in bspn_masks[v_code]: bspn_masks[v_code].append(next_v_code) bspn_masks[self.vocab.encode('<unk>')] = list(bspn_masks.keys()) with open('data/multi-woz-processed/bspn_masks.txt', 'w') as f: for i, j in bspn_masks.items(): f.write( self.vocab.decode(i) + ': ' + ' '.join([self.vocab.decode(int(m)) for m in j]) + '\n') return bspn_masks def _construct_aspn_constraint(self): aspn_masks = {} aspn_masks = {} all_dom_codes = [ self.vocab.encode('[' + d + ']') for d in ontology.dialog_acts.keys() ] all_act_codes = [ self.vocab.encode('[' + a + ']') for a in ontology.dialog_act_params ] all_slot_codes = [ self.vocab.encode(s) for s in ontology.dialog_act_all_slots ] aspn_masks[self.vocab.encode( '<go_a>')] = all_dom_codes + [self.vocab.encode('<eos_a>'), 0] aspn_masks[self.vocab.encode('<eos_a>')] = [self.vocab.encode('<pad>')] 
aspn_masks[self.vocab.encode('<pad>')] = [self.vocab.encode('<pad>')] # for d in all_dom_codes: # aspn_masks[d] = all_act_codes for a in all_act_codes: aspn_masks[a] = all_dom_codes + all_slot_codes + [ self.vocab.encode('<eos_a>') ] for domain, acts in ontology.dialog_acts.items(): dom_code = self.vocab.encode('[' + domain + ']') aspn_masks[dom_code] = [] for a in acts: act_code = self.vocab.encode('[' + a + ']') if act_code not in aspn_masks[dom_code]: aspn_masks[dom_code].append(act_code) # for a, slots in ontology.dialog_act_params.items(): # act_code = self.vocab.encode('['+a+']') # slot_codes = [self.vocab.encode(s) for s in slots] # aspn_masks[act_code] = all_dom_codes + slot_codes + [self.vocab.encode('<eos_a>')] for s in all_slot_codes: aspn_masks[s] = all_dom_codes + all_slot_codes + [ self.vocab.encode('<eos_a>') ] aspn_masks[self.vocab.encode('<unk>')] = list(aspn_masks.keys()) with open('data/multi-woz-processed/aspn_masks.txt', 'w') as f: for i, j in aspn_masks.items(): f.write( self.vocab.decode(i) + ': ' + ' '.join([self.vocab.decode(int(m)) for m in j]) + '\n') return aspn_masks def _load_data(self, save_temp=False): self.data = json.loads( open(cfg.data_path + cfg.data_file, 'r', encoding='utf-8').read().lower()) self.train, self.dev, self.test = [], [], [] for fn, dial in self.data.items(): if 'all' in cfg.exp_domains or self.exp_files.get(fn): if self.dev_files.get(fn): self.dev.append(self._get_encoded_data(fn, dial)) elif self.test_files.get(fn): self.test.append(self._get_encoded_data(fn, dial)) else: self.train.append(self._get_encoded_data(fn, dial)) if save_temp: json.dump(self.test, open('data/multi-woz-analysis/test.encoded.json', 'w'), indent=2) self.vocab.save_vocab('data/multi-woz-analysis/vocab_temp') random.shuffle(self.train) random.shuffle(self.dev) random.shuffle(self.test) def _get_encoded_data(self, fn, dial): encoded_dial = [] for idx, t in enumerate(dial['log']): enc = {} enc['dial_id'] = fn enc['user'] = 
self.vocab.sentence_encode(t['user'].split() + ['<eos_u>']) enc['usdx'] = self.vocab.sentence_encode(t['user_delex'].split() + ['<eos_u>']) enc['resp'] = self.vocab.sentence_encode(t['resp'].split() + ['<eos_r>']) enc['bspn'] = self.vocab.sentence_encode(t['constraint'].split() + ['<eos_b>']) enc['bsdx'] = self.vocab.sentence_encode(t['cons_delex'].split() + ['<eos_b>']) enc['aspn'] = self.vocab.sentence_encode(t['sys_act'].split() + ['<eos_a>']) enc['dspn'] = self.vocab.sentence_encode(t['turn_domain'].split() + ['<eos_d>']) enc['pointer'] = [int(i) for i in t['pointer'].split(',')] enc['turn_domain'] = t['turn_domain'].split() enc['turn_num'] = t['turn_num'] # TRADE labels enc["gating_label"] = [ self.gating_dict[label] for label in t["gating_label"].split(",") ] enc["ptr_label"] = [] for labels in t["ptr_label"].split(","): temp = [] for label in labels.split(): label_idx = self.vocab.encode(label) if label_idx >= self.vocab.vocab_size: temp.append(self.vocab.encode("<unk>")) else: temp.append(label_idx) temp.append(self.vocab.encode("<eos_b>")) enc["ptr_label"].append(temp) if cfg.multi_acts_training: enc['aspn_aug'] = [] if fn in self.multi_acts: turn_ma = self.multi_acts[fn].get(str(idx), {}) for act_type, act_spans in turn_ma.items(): enc['aspn_aug'].append([ self.vocab.sentence_encode(a.split() + ['<eos_a>']) for a in act_spans ]) encoded_dial.append(enc) return encoded_dial def bspan_to_constraint_dict(self, bspan, bspn_mode='bspn'): bspan = bspan.split() if isinstance(bspan, str) else bspan constraint_dict = {} domain = None conslen = len(bspan) for idx, cons in enumerate(bspan): cons = self.vocab.decode(cons) if type(cons) is not str else cons if cons == '<eos_b>': break if '[' in cons: if cons[1:-1] not in ontology.all_domains: continue domain = cons[1:-1] elif cons in ontology.get_slot: if domain is None: continue if cons == 'people': # handle confusion of value name "people's portraits..." 
and slot people try: ns = bspan[idx + 1] ns = self.vocab.decode(ns) if type( ns) is not str else ns if ns == "'s": continue except: continue if not constraint_dict.get(domain): constraint_dict[domain] = {} if bspn_mode == 'bsdx': constraint_dict[domain][cons] = 1 continue vidx = idx + 1 if vidx == conslen: break vt_collect = [] vt = bspan[vidx] vt = self.vocab.decode(vt) if type(vt) is not str else vt while vidx < conslen and vt != '<eos_b>' and '[' not in vt and vt not in ontology.get_slot: vt_collect.append(vt) vidx += 1 if vidx == conslen: break vt = bspan[vidx] vt = self.vocab.decode(vt) if type(vt) is not str else vt if vt_collect: constraint_dict[domain][cons] = ' '.join(vt_collect) return constraint_dict def bspan_to_DBpointer(self, bspan, turn_domain): constraint_dict = self.bspan_to_constraint_dict(bspan) # print(constraint_dict) matnums = self.db.get_match_num(constraint_dict) match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1] match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom match = matnums[match_dom] vector = self.db.addDBPointer(match_dom, match) return vector def aspan_to_act_list(self, aspan): aspan = aspan.split() if isinstance(aspan, str) else aspan acts = [] domain = None conslen = len(aspan) for idx, cons in enumerate(aspan): cons = self.vocab.decode(cons) if type(cons) is not str else cons if cons == '<eos_a>': break if '[' in cons and cons[1:-1] in ontology.dialog_acts: domain = cons[1:-1] elif '[' in cons and cons[1:-1] in ontology.dialog_act_params: if domain is None: continue vidx = idx + 1 if vidx == conslen: acts.append(domain + '-' + cons[1:-1] + '-none') break vt = aspan[vidx] vt = self.vocab.decode(vt) if type(vt) is not str else vt no_param_act = True while vidx < conslen and vt != '<eos_a>' and '[' not in vt: no_param_act = False acts.append(domain + '-' + cons[1:-1] + '-' + vt) vidx += 1 if vidx == conslen: break vt = aspan[vidx] vt = self.vocab.decode(vt) if type(vt) is not str else vt if 
no_param_act: acts.append(domain + '-' + cons[1:-1] + '-none') return acts def dspan_to_domain(self, dspan): domains = {} dspan = dspan.split() if isinstance(dspan, str) else dspan for d in dspan: dom = self.vocab.decode(d) if type(d) is not str else d if dom != '<eos_d>': domains[dom] = 1 else: break return domains def convert_batch(self, py_batch, py_prev, first_turn=False): inputs = {} if first_turn: for item, py_list in py_prev.items(): batch_size = len(py_batch['user']) inputs[item + '_np'] = np.array([[1]] * batch_size) inputs[item + '_unk_np'] = np.array([[1]] * batch_size) else: for item, py_list in py_prev.items(): if py_list is None: continue if not cfg.enable_aspn and 'aspn' in item: continue if not cfg.enable_bspn and 'bspn' in item: continue if not cfg.enable_dspn and 'dspn' in item: continue prev_np = utils.padSeqs(py_list, truncated=cfg.truncated, trunc_method='pre') inputs[item + '_np'] = prev_np if item in ['pv_resp', 'pv_bspn', "pv_aspn"]: inputs[item + '_unk_np'] = deepcopy(inputs[item + '_np']) inputs[item + '_unk_np'][inputs[item + '_unk_np'] >= self.vocab_size] = 2 # <unk> else: inputs[item + '_unk_np'] = inputs[item + '_np'] for item in ['user', 'usdx', 'resp', 'bspn', 'aspn', 'bsdx', 'dspn']: if not cfg.enable_aspn and item == 'aspn': continue if not cfg.enable_bspn and item == 'bspn': continue if not cfg.enable_dspn and item == 'dspn': continue py_list = py_batch[item] trunc_method = 'post' if item == 'resp' else 'pre' # max_length = cfg.max_nl_length if item in ['user', 'usdx', 'resp'] else cfg.max_span_length inputs[item + '_np'] = utils.padSeqs(py_list, truncated=cfg.truncated, trunc_method=trunc_method) if item in ['user', 'usdx', 'resp', 'bspn', "aspn"]: inputs[item + '_unk_np'] = deepcopy(inputs[item + '_np']) inputs[item + '_unk_np'][ inputs[item + '_unk_np'] >= self.vocab_size] = 2 # <unk> else: inputs[item + '_unk_np'] = inputs[item + '_np'] if cfg.multi_acts_training and cfg.mode == 'train': inputs['aspn_bidx'], multi_aspn = [], 
[] for bidx, aspn_type_list in enumerate(py_batch['aspn_aug']): if aspn_type_list: for aspn_list in aspn_type_list: random.shuffle(aspn_list) aspn = aspn_list[ 0] #choose one random act span in each act type multi_aspn.append(aspn) inputs['aspn_bidx'].append(bidx) if cfg.multi_act_sampling_num > 1: for i in range(cfg.multi_act_sampling_num): if len(aspn_list) >= i + 2: aspn = aspn_list[ i + 1] #choose one random act span in each act type multi_aspn.append(aspn) inputs['aspn_bidx'].append(bidx) if multi_aspn: inputs['aspn_aug_np'] = utils.padSeqs(multi_aspn, truncated=cfg.truncated, trunc_method='pre') inputs['aspn_aug_unk_np'] = inputs[ 'aspn_aug_np'] # [all available aspn num in the batch, T] inputs['db_np'] = np.array(py_batch['pointer']) inputs['turn_domain'] = py_batch['turn_domain'] # TRADE inputs["gating_label_np"] = np.array(py_batch["gating_label"]) batch_size = len(py_batch["ptr_label"]) slot_num = inputs["gating_label_np"].shape[1] ptr_max_len = max( [len(label) for batch in py_batch["ptr_label"] for label in batch]) inputs["ptr_label_np"] = np.zeros((batch_size, slot_num, ptr_max_len)) for bidx, batch in enumerate(py_batch["ptr_label"]): for idx, label in enumerate(batch): label_len = len(label) inputs["ptr_label_np"][bidx, idx, :label_len] = np.array(label) return inputs def wrap_result(self, result_dict, eos_syntax=None): decode_fn = self.vocab.sentence_decode results = [] eos_syntax = ontology.eos_tokens if not eos_syntax else eos_syntax gating_dict = {'ptr': 0, 'dontcare': 1, 'none': 2} slot_list = [ 'hotel-pricerange', 'hotel-type', 'hotel-parking', 'hotel-stay', 'hotel-day', 'hotel-people', \ 'hotel-area', 'hotel-stars', 'hotel-internet', 'train-destination', 'train-day', 'train-departure', 'train-arriveby', \ 'train-people', 'train-leaveat', 'attraction-area', 'restaurant-food', 'restaurant-pricerange', 'restaurant-area', \ 'attraction-name', 'restaurant-name', 'attraction-type', 'hotel-name', 'taxi-leaveat', 'taxi-destination', 'taxi-departure', 
\ 'restaurant-time', 'restaurant-day', 'restaurant-people', 'taxi-arriveby', "hospital-department" ] if cfg.enable_trade: field = [ 'dial_id', 'turn_num', 'user', 'bspn_gen', 'bspn', 'resp_gen', 'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'pointer', "gating_label", "trade_gate", "ptr_label", "trade_ptr" ] elif cfg.bspn_mode == 'bspn': field = [ 'dial_id', 'turn_num', 'user', 'bspn_gen', 'bspn', 'resp_gen', 'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'pointer' ] elif not cfg.enable_dst: field = [ 'dial_id', 'turn_num', 'user', 'bsdx_gen', 'bsdx', 'resp_gen', 'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'bspn', 'pointer' ] else: field = [ 'dial_id', 'turn_num', 'user', 'bsdx_gen', 'bsdx', 'resp_gen', 'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'bspn_gen', 'bspn', 'pointer' ] if self.multi_acts_record is not None: field.insert(7, 'multi_act_gen') for dial_id, turns in result_dict.items(): entry = {'dial_id': dial_id, 'turn_num': len(turns)} for prop in field[2:]: entry[prop] = '' results.append(entry) # info of dialogue for turn_no, turn in enumerate(turns): entry = {'dial_id': dial_id} trade_gate = [] for key in field: if key in ['dial_id']: continue v = turn.get(key, '') if key == "trade_gate": trade_gate = v entry[key] = v continue elif key == "trade_ptr": trade_ptr = [] for idx, slot in enumerate(v): if trade_gate[idx].item() == 0: trade_ptr.append( decode_fn(slot.tolist(), eos=eos_syntax["trade"])) elif trade_gate[idx].item() == 1: trade_ptr.append("do n't care") else: trade_ptr.append("none") entry[key] = trade_ptr continue elif key == "ptr_label": ptr_label = [] for slot in v: ptr_label.append( decode_fn(slot, eos=eos_syntax["trade"])) entry[key] = ptr_label continue if key == 'turn_domain': v = ' '.join(v) entry[key] = decode_fn( v, eos=eos_syntax[key] ) if key in eos_syntax and v != '' else v results.append(entry) return results, field def restore(self, resp, domain, constraint_dict, mat_ents): restored = resp restored = 
restored.replace('[value_reference]', '53022') restored = restored.replace('[value_car]', 'BMW') # restored.replace('[value_phone]', '830-430-6666') for d in domain: constraint = constraint_dict.get(d, None) if constraint: if 'stay' in constraint: restored = restored.replace('[value_stay]', constraint['stay']) if 'day' in constraint: restored = restored.replace('[value_day]', constraint['day']) if 'people' in constraint: restored = restored.replace('[value_people]', constraint['people']) if 'time' in constraint: restored = restored.replace('[value_time]', constraint['time']) if 'type' in constraint: restored = restored.replace('[value_type]', constraint['type']) if d in mat_ents and len(mat_ents[d]) == 0: for s in constraint: if s == 'pricerange' and d in [ 'hotel', 'restaurant' ] and 'price]' in restored: restored = restored.replace( '[value_price]', constraint['pricerange']) if s + ']' in restored: restored = restored.replace( '[value_%s]' % s, constraint[s]) if '[value_choice' in restored and mat_ents.get(d): restored = restored.replace('[value_choice]', str(len(mat_ents[d]))) if '[value_choice' in restored: restored = restored.replace('[value_choice]', '3') # restored.replace('[value_car]', 'BMW') try: ent = mat_ents.get(domain[-1], []) if ent: ent = ent[0] for t in restored.split(): if '[value' in t: slot = t[7:-1] if ent.get(slot): if domain[-1] == 'hotel' and slot == 'price': slot = 'pricerange' restored = restored.replace(t, ent[slot]) elif slot == 'price': if ent.get('pricerange'): restored = restored.replace( t, ent['pricerange']) else: print(restored, domain) except: print(resp) print(restored) quit() restored = restored.replace('[value_phone]', '62781111') restored = restored.replace('[value_postcode]', 'CG9566') restored = restored.replace('[value_address]', 'Parkside, Cambridge') # if '[value_' in restored: # print(domain) # # print(mat_ents) # print(resp) # print(restored) return restored def record_utterance(self, result_dict): decode_fn = 
self.vocab.sentence_decode ordered_dial = {} for dial_id, turns in result_dict.items(): diverse = 0 turn_count = 0 for turn_no, turn in enumerate(turns): act_collect = {} act_type_collect = {} slot_score = 0 for i in range(cfg.nbest): aspn = decode_fn(turn['multi_act'][i], eos=ontology.eos_tokens['aspn']) pred_acts = self.aspan_to_act_list(' '.join(aspn)) act_type = '' for act in pred_acts: d, a, s = act.split('-') if d + '-' + a not in act_collect: act_collect[d + '-' + a] = {s: 1} slot_score += 1 act_type += d + '-' + a + ';' elif s not in act_collect: act_collect[d + '-' + a][s] = 1 slot_score += 1 act_type_collect[act_type] = 1 turn_count += 1 diverse += len(act_collect) * 3 + slot_score ordered_dial[dial_id] = diverse / turn_count ordered_dial = sorted(ordered_dial.keys(), key=lambda x: -ordered_dial[x]) dialog_record = {} with open(cfg.eval_load_path + '/dialogue_record.csv', 'w') as rf: writer = csv.writer(rf) for dial_id in ordered_dial: dialog_record[dial_id] = [] turns = result_dict[dial_id] writer.writerow([dial_id]) for turn_no, turn in enumerate(turns): user = decode_fn(turn['user'], eos=ontology.eos_tokens['user']) bspn = decode_fn(turn['bspn'], eos=ontology.eos_tokens['bspn']) aspn = decode_fn(turn['aspn'], eos=ontology.eos_tokens['aspn']) resp = decode_fn(turn['resp'], eos=ontology.eos_tokens['resp']) constraint_dict = self.bspan_to_constraint_dict(bspn) # print(constraint_dict) mat_ents = self.db.get_match_num(constraint_dict, True) domain = [ i[1:-1] for i in self.dspan_to_domain(turn['dspn']).keys() ] restored = self.restore(resp, domain, constraint_dict, mat_ents) writer.writerow([ turn_no, user, turn['pointer'], domain, restored, resp ]) turn_record = { 'user': user, 'bspn': bspn, 'aspn': aspn, 'dom': domain, 'resp': resp, 'resp_res': restored } resp_col = [] aspn_col = [] resp_restore_col = [] for i in range(cfg.nbest): aspn = decode_fn(turn['multi_act'][i], eos=ontology.eos_tokens['aspn']) resp = decode_fn(turn['multi_resp'][i], 
eos=ontology.eos_tokens['resp']) restored = self.restore(resp, domain, constraint_dict, mat_ents) resp_col.append(resp) resp_restore_col.append(restored) aspn_col.append(aspn) zipped = list(zip(resp_restore_col, resp_col, aspn_col)) zipped.sort(key=lambda s: len(s[0])) resp_restore_col = list(list(zip(*zipped))[0]) aspn_col = list(list(zip(*zipped))[2]) resp_col = list(list(zip(*zipped))[1]) turn_record['aspn_col'] = aspn_col turn_record['resp_col'] = resp_col turn_record['resp_res_col'] = resp_restore_col for i in range(cfg.nbest): # aspn = decode_fn(turn['multi_act'][i], eos=ontology.eos_tokens['aspn']) resp = resp_col[i] aspn = aspn_col[i] resp_restore = resp_restore_col[i] writer.writerow(['', resp_restore, resp, aspn]) dialog_record[dial_id].append(turn_record)
class MultiWozReader(_ReaderBase):
    """Load, encode and post-process the processed MultiWOZ corpus.

    Reads the train/dev/test split from ``cfg.data_path + cfg.data_file``,
    encodes every turn with ``self.vocab``, optionally builds next-token masks
    for belief and act spans, and converts model outputs back to text.
    """

    # Words left lowercase when capitalising restored names and addresses.
    _RESTORE_STOPWORDS = frozenset([
        "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you",
        "your", "yours", "yourself", "yourselves", "he", "him", "his",
        "himself", "she", "her", "hers", "herself", "it", "its", "itself",
        "they", "them", "their", "theirs", "themselves", "what", "which", "who",
        "whom", "this", "that", "these", "those", "am", "is", "are", "was",
        "were", "be", "been", "being", "have", "has", "had", "having", "do",
        "does", "did", "doing", "a", "an", "the", "and", "but", "if", "or",
        "because", "as", "until", "while", "of", "at", "by", "for", "with",
        "about", "against", "between", "into", "through", "during", "before",
        "after", "above", "below", "to", "from", "up", "down", "in", "out",
        "on", "off", "over", "under", "again", "further", "then", "once",
        "here", "there", "when", "where", "why", "how", "all", "any", "both",
        "each", "few", "more", "most", "other", "some", "such", "no", "nor",
        "not", "only", "own", "same", "so", "than", "too", "very", "s", "t",
        "can", "will", "just", "don", "should", "now"
    ])

    def __init__(self, vocab=None):
        """Load split lists, the vocabulary and the encoded data.

        Args:
            vocab: optional pre-built vocabulary. When falsy, the vocabulary is
                loaded by ``_build_vocab``.

        Raises:
            ValueError: if a domain in ``cfg.exp_domains`` has no dialogue files.
        """
        super().__init__()
        self.nlp = spacy.load('en_core_web_sm')
        self.db = MultiWozDB(cfg.dbs)
        self.domain_files = self._load_json_file(cfg.domain_file_path)
        self.slot_value_set = self._load_json_file(cfg.slot_value_set_path)
        if cfg.multi_acts_training:
            self.multi_acts = self._load_json_file(cfg.multi_acts_path)
        test_list = self._load_name_list(cfg.test_list)
        dev_list = self._load_name_list(cfg.dev_list)
        self.dev_files, self.test_files = {}, {}
        for fn in test_list:
            self.test_files[fn.replace('.json', '')] = 1
        for fn in dev_list:
            self.dev_files[fn.replace('.json', '')] = 1

        # Restrict to the dialogues of the requested domains unless 'all'.
        self.exp_files = {}
        if 'all' not in cfg.exp_domains:
            for domain in cfg.exp_domains:
                fn_list = self.domain_files.get(domain)
                if not fn_list:
                    raise ValueError('[%s] is an invalid experiment setting' % domain)
                for fn in fn_list:
                    self.exp_files[fn.replace('.json', '')] = 1

        if vocab:
            self.vocab = vocab
            # NOTE(review): _build_vocab reads `vocab_size` while this branch
            # reads `size`; confirm utils.Vocab exposes both attributes.
            self.vocab_size = vocab.size
        else:
            self.vocab_size = self._build_vocab()
        self._load_data()
        if cfg.limit_bspn_vocab:
            self.bspn_masks = self._construct_bspn_constraint()
        if cfg.limit_aspn_vocab:
            self.aspn_masks = self._construct_aspn_constraint()
        self.multi_acts_record = None

    @staticmethod
    def _load_json_file(path):
        """Parse the JSON file at *path*; the file handle is always closed."""
        with open(path, 'r') as f:
            return json.loads(f.read())

    @staticmethod
    def _load_name_list(path):
        """Return the stripped, lower-cased lines of the text file at *path*."""
        with open(path, 'r') as f:
            return [line.strip().lower() for line in f.readlines()]

    def _to_token(self, tok):
        """Return *tok* unchanged if it is a str, else decode it with the vocab."""
        return self.vocab.decode(tok) if type(tok) is not str else tok

    def _build_vocab(self):
        """Create ``self.vocab`` from the train or eval vocab file; return its size."""
        self.vocab = utils.Vocab(cfg.vocab_size)
        if cfg.mode == 'train' or cfg.vocab_path_eval is None:
            vp = cfg.vocab_path_train
        else:
            vp = cfg.vocab_path_eval
        self.vocab.load_vocab(vp)
        return self.vocab.vocab_size

    def _dump_masks(self, masks, path):
        """Write *masks* to *path*, one ``token: allowed tokens`` line per key."""
        with open(path, 'w') as f:
            for key, allowed in masks.items():
                f.write(
                    self.vocab.decode(key) + ': ' +
                    ' '.join([self.vocab.decode(int(m)) for m in allowed]) + '\n')

    def _construct_bspn_constraint(self):
        """Build the belief-span mask: token id -> token ids allowed to follow it.

        Transitions come from ``self.slot_value_set`` (the police domain is
        skipped). The ``<unk>`` entry allows every key. A readable copy is
        written to ``data/multi-woz-processed/bspn_masks.txt``.
        """
        encode = self.vocab.encode
        eos_b = encode('<eos_b>')
        pad = encode('<pad>')
        valid_domains = [
            'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital'
        ]
        all_dom_codes = [encode('[' + d + ']') for d in valid_domains]
        all_slot_codes = [encode(s) for s in ontology.all_slots]

        bspn_masks = {}
        bspn_masks[encode('<go_b>')] = all_dom_codes + [eos_b, 0]
        bspn_masks[eos_b] = [pad]
        bspn_masks[pad] = [pad]
        for domain, slot_values in self.slot_value_set.items():
            if domain == 'police':
                continue
            dom_code = encode('[' + domain + ']')
            bspn_masks[dom_code] = []
            for slot, values in slot_values.items():
                slot_code = encode(slot)
                if slot_code not in bspn_masks:
                    bspn_masks[slot_code] = []
                if slot_code not in bspn_masks[dom_code]:
                    bspn_masks[dom_code].append(slot_code)
                for value in values:
                    words = value.split()
                    for idx, v in enumerate(words):
                        if not self.vocab.has_word(v):
                            continue
                        v_code = encode(v)
                        allowed = bspn_masks.setdefault(v_code, [])
                        if idx == 0 and v_code not in bspn_masks[slot_code]:
                            bspn_masks[slot_code].append(v_code)
                        if idx == len(words) - 1:
                            # Last word of a value: the span may end here or
                            # continue with any domain or slot token.
                            for w in all_dom_codes + all_slot_codes:
                                if eos_b not in allowed:
                                    allowed.append(eos_b)
                                if w not in allowed:
                                    allowed.append(w)
                            break
                        next_word = words[idx + 1]
                        if not self.vocab.has_word(next_word):
                            continue
                        next_v_code = encode(next_word)
                        if next_v_code not in allowed:
                            allowed.append(next_v_code)
        bspn_masks[encode('<unk>')] = list(bspn_masks.keys())
        self._dump_masks(bspn_masks, 'data/multi-woz-processed/bspn_masks.txt')
        return bspn_masks

    def _construct_aspn_constraint(self):
        """Build the act-span mask: token id -> token ids allowed to follow it.

        Transitions come from ``ontology.dialog_acts`` and
        ``ontology.dialog_act_all_slots``. A readable copy is written to
        ``data/multi-woz-processed/aspn_masks.txt``.
        """
        encode = self.vocab.encode
        eos_a = encode('<eos_a>')
        pad = encode('<pad>')
        all_dom_codes = [encode('[' + d + ']') for d in ontology.dialog_acts.keys()]
        all_act_codes = [encode('[' + a + ']') for a in ontology.dialog_act_params]
        all_slot_codes = [encode(s) for s in ontology.dialog_act_all_slots]

        aspn_masks = {}
        aspn_masks[encode('<go_a>')] = all_dom_codes + [eos_a, 0]
        aspn_masks[eos_a] = [pad]
        aspn_masks[pad] = [pad]
        for a in all_act_codes:
            aspn_masks[a] = all_dom_codes + all_slot_codes + [eos_a]
        for domain, acts in ontology.dialog_acts.items():
            dom_code = encode('[' + domain + ']')
            aspn_masks[dom_code] = []
            for a in acts:
                act_code = encode('[' + a + ']')
                if act_code not in aspn_masks[dom_code]:
                    aspn_masks[dom_code].append(act_code)
        for s in all_slot_codes:
            aspn_masks[s] = all_dom_codes + all_slot_codes + [eos_a]
        aspn_masks[encode('<unk>')] = list(aspn_masks.keys())
        self._dump_masks(aspn_masks, 'data/multi-woz-processed/aspn_masks.txt')
        return aspn_masks

    def _load_data(self, save_temp=False):
        """Load, encode and split the corpus into shuffled train/dev/test lists.

        The whole JSON text is lower-cased before parsing. Dev/test membership
        comes from the dev/test lists read in ``__init__``.

        Args:
            save_temp: if true, also dump the encoded test set and the vocab
                under ``data/multi-woz-analysis/``.
        """
        with open(cfg.data_path + cfg.data_file, 'r', encoding='utf-8') as f:
            self.data = json.loads(f.read().lower())
        self.train, self.dev, self.test = [], [], []
        for fn, dial in self.data.items():
            if 'all' not in cfg.exp_domains and not self.exp_files.get(fn):
                continue
            encoded = self._get_encoded_data(fn, dial)
            if self.dev_files.get(fn):
                self.dev.append(encoded)
            elif self.test_files.get(fn):
                self.test.append(encoded)
            else:
                self.train.append(encoded)
        if save_temp:
            with open('data/multi-woz-analysis/test.encoded.json', 'w') as f:
                json.dump(self.test, f, indent=2)
            self.vocab.save_vocab('data/multi-woz-analysis/vocab_temp')
        random.shuffle(self.train)
        random.shuffle(self.dev)
        random.shuffle(self.test)

    def _get_encoded_data(self, fn, dial):
        """Encode every turn of dialogue *dial* (id *fn*) into a dict of id lists."""
        encode = self.vocab.sentence_encode
        encoded_dial = []
        for idx, t in enumerate(dial['log']):
            enc = {}
            enc['dial_id'] = fn
            enc['user'] = encode(t['user'].split() + ['<eos_u>'])
            enc['usdx'] = encode(t['user_delex'].split() + ['<eos_u>'])
            enc['resp'] = encode(t['resp'].split() + ['<eos_r>'])
            enc['bspn'] = encode(t['constraint'].split() + ['<eos_b>'])
            enc['bsdx'] = encode(t['cons_delex'].split() + ['<eos_b>'])
            enc['aspn'] = encode(t['sys_act'].split() + ['<eos_a>'])
            enc['dspn'] = encode(t['turn_domain'].split() + ['<eos_d>'])
            enc['pointer'] = [int(i) for i in t['pointer'].split(',')]
            enc['input_pointer'] = enc['pointer']
            enc['turn_domain'] = t['turn_domain'].split()
            enc['turn_num'] = t['turn_num']
            if cfg.multi_acts_training:
                enc['aspn_aug'] = []
                if fn in self.multi_acts:
                    turn_ma = self.multi_acts[fn].get(str(idx), {})
                    for act_type, act_spans in turn_ma.items():
                        enc['aspn_aug'].append(
                            [encode(a.split() + ['<eos_a>']) for a in act_spans])
            encoded_dial.append(enc)
        return encoded_dial

    def bspan_to_constraint_dict(self, bspan, bspn_mode='bspn'):
        """Parse a belief span into ``{domain: {slot: value}}``.

        Args:
            bspan: space-separated string or sequence of tokens/ids.
            bspn_mode: ``'bsdx'`` records each slot with value ``1`` instead of
                reading its value words.

        Parsing stops at ``<eos_b>``.
        """
        bspan = bspan.split() if isinstance(bspan, str) else bspan
        constraint_dict = {}
        domain = None
        conslen = len(bspan)
        for idx, cons in enumerate(bspan):
            cons = self._to_token(cons)
            if cons == '<eos_b>':
                break
            if '[' in cons:
                if cons[1:-1] not in ontology.all_domains:
                    continue
                domain = cons[1:-1]
            elif cons in ontology.get_slot:
                if domain is None:
                    continue
                if cons == 'people':
                    # "people" followed by "'s" belongs to a value such as
                    # "people's portraits", not to the slot. A trailing
                    # "people" has no value and is skipped too.
                    if idx + 1 >= conslen or self._to_token(bspan[idx + 1]) == "'s":
                        continue
                if not constraint_dict.get(domain):
                    constraint_dict[domain] = {}
                if bspn_mode == 'bsdx':
                    constraint_dict[domain][cons] = 1
                    continue
                vidx = idx + 1
                if vidx == conslen:
                    break
                vt_collect = []
                vt = self._to_token(bspan[vidx])
                while (vidx < conslen and vt != '<eos_b>' and '[' not in vt
                       and vt not in ontology.get_slot):
                    vt_collect.append(vt)
                    vidx += 1
                    if vidx == conslen:
                        break
                    vt = self._to_token(bspan[vidx])
                if vt_collect:
                    constraint_dict[domain][cons] = ' '.join(vt_collect)
        return constraint_dict

    def bspan_to_DBpointer(self, bspan, turn_domain):
        """Return the DB pointer vector for the turn's match domain.

        The match domain is the first entry of *turn_domain* when it has one
        entry, otherwise the second.
        """
        constraint_dict = self.bspan_to_constraint_dict(bspan)
        matnums = self.db.get_match_num(constraint_dict)
        match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1]
        match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom
        match = matnums[match_dom]
        return self.db.addDBPointer(match_dom, match)

    def aspan_to_act_list(self, aspan):
        """Parse an act span into ``'domain-act-slot'`` strings.

        An act with no slot yields ``'domain-act-none'``. Parsing stops at
        ``<eos_a>``.
        """
        aspan = aspan.split() if isinstance(aspan, str) else aspan
        acts = []
        domain = None
        conslen = len(aspan)
        for idx, cons in enumerate(aspan):
            cons = self._to_token(cons)
            if cons == '<eos_a>':
                break
            if '[' in cons and cons[1:-1] in ontology.dialog_acts:
                domain = cons[1:-1]
            elif '[' in cons and cons[1:-1] in ontology.dialog_act_params:
                if domain is None:
                    continue
                act = cons[1:-1]
                vidx = idx + 1
                if vidx == conslen:
                    acts.append(domain + '-' + act + '-none')
                    break
                vt = self._to_token(aspan[vidx])
                no_param_act = True
                while vidx < conslen and vt != '<eos_a>' and '[' not in vt:
                    no_param_act = False
                    acts.append(domain + '-' + act + '-' + vt)
                    vidx += 1
                    if vidx == conslen:
                        break
                    vt = self._to_token(aspan[vidx])
                if no_param_act:
                    acts.append(domain + '-' + act + '-none')
        return acts

    def dspan_to_domain(self, dspan):
        """Return ``{domain_token: 1}`` for the tokens before ``<eos_d>``."""
        dspan = dspan.split() if isinstance(dspan, str) else dspan
        domains = {}
        for d in dspan:
            dom = self._to_token(d)
            if dom == '<eos_d>':
                break
            domains[dom] = 1
        return domains

    def convert_batch(self, py_batch, py_prev, first_turn=False):
        """Pad a python batch (and previous-turn context) into numpy inputs.

        Every produced key ``X_np`` has a companion ``X_unk_np``. For
        ``pv_resp``, ``pv_bspn``, ``user``, ``usdx``, ``resp`` and ``bspn`` that
        companion is a copy with ids ``>= self.vocab_size`` replaced by 2
        (<unk>). On the first turn, previous-turn inputs are ``[[1]]`` per
        example.
        """
        inputs = {}
        if first_turn:
            batch_size = len(py_batch['user'])
            for item in py_prev:
                inputs[item + '_np'] = np.array([[1]] * batch_size)
                inputs[item + '_unk_np'] = np.array([[1]] * batch_size)
        else:
            for item, py_list in py_prev.items():
                if py_list is None:
                    continue
                if not cfg.enable_aspn and 'aspn' in item:
                    continue
                if not cfg.enable_bspn and 'bspn' in item:
                    continue
                if not cfg.enable_dspn and 'dspn' in item:
                    continue
                inputs[item + '_np'] = utils.padSeqs(
                    py_list, truncated=cfg.truncated, trunc_method='pre')
                if item in ['pv_resp', 'pv_bspn']:
                    inputs[item + '_unk_np'] = deepcopy(inputs[item + '_np'])
                    inputs[item + '_unk_np'][
                        inputs[item + '_unk_np'] >= self.vocab_size] = 2  # <unk>
                else:
                    inputs[item + '_unk_np'] = inputs[item + '_np']

        for item in ['user', 'usdx', 'resp', 'bspn', 'aspn', 'bsdx', 'dspn']:
            if not cfg.enable_aspn and item == 'aspn':
                continue
            if not cfg.enable_bspn and item == 'bspn':
                continue
            if not cfg.enable_dspn and item == 'dspn':
                continue
            trunc_method = 'post' if item == 'resp' else 'pre'
            inputs[item + '_np'] = utils.padSeqs(
                py_batch[item], truncated=cfg.truncated, trunc_method=trunc_method)
            if item in ['user', 'usdx', 'resp', 'bspn']:
                inputs[item + '_unk_np'] = deepcopy(inputs[item + '_np'])
                inputs[item + '_unk_np'][
                    inputs[item + '_unk_np'] >= self.vocab_size] = 2  # <unk>
            else:
                inputs[item + '_unk_np'] = inputs[item + '_np']

        if cfg.multi_acts_training and cfg.mode == 'train':
            inputs['aspn_bidx'], multi_aspn = [], []
            for bidx, aspn_type_list in enumerate(py_batch['aspn_aug']):
                if aspn_type_list:
                    for aspn_list in aspn_type_list:
                        random.shuffle(aspn_list)
                        # choose one random act span in each act type
                        multi_aspn.append(aspn_list[0])
                        inputs['aspn_bidx'].append(bidx)
                        if cfg.multi_act_sampling_num > 1:
                            for i in range(cfg.multi_act_sampling_num):
                                if len(aspn_list) >= i + 2:
                                    multi_aspn.append(aspn_list[i + 1])
                                    inputs['aspn_bidx'].append(bidx)
            if multi_aspn:
                inputs['aspn_aug_np'] = utils.padSeqs(
                    multi_aspn, truncated=cfg.truncated, trunc_method='pre')
                # [all available aspn num in the batch, T]
                inputs['aspn_aug_unk_np'] = inputs['aspn_aug_np']

        inputs['db_np'] = np.array(py_batch['input_pointer'])
        inputs['turn_domain'] = py_batch['turn_domain']
        return inputs

    def wrap_result(self, result_dict, eos_syntax=None):
        """Flatten per-turn model outputs into row dicts for evaluation.

        Returns:
            ``(results, field)``. For every dialogue ``results`` holds one
            header row (``turn_num`` = number of turns, other fields ``''``)
            followed by one row per turn. Values whose key is in ``eos_syntax``
            (default ``ontology.eos_tokens``) are decoded to strings.
        """
        decode_fn = self.vocab.sentence_decode
        eos_syntax = ontology.eos_tokens if not eos_syntax else eos_syntax

        if cfg.bspn_mode == 'bspn':
            field = [
                'dial_id', 'turn_num', 'user', 'bspn_gen', 'bspn', 'resp_gen',
                'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'pointer'
            ]
        elif not cfg.enable_dst:
            field = [
                'dial_id', 'turn_num', 'user', 'bsdx_gen', 'bsdx', 'resp_gen',
                'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'bspn',
                'pointer'
            ]
        else:
            field = [
                'dial_id', 'turn_num', 'user', 'bsdx_gen', 'bsdx', 'resp_gen',
                'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn', 'bspn_gen',
                'bspn', 'pointer'
            ]
        if self.multi_acts_record is not None:
            field.insert(7, 'multi_act_gen')

        results = []
        for dial_id, turns in result_dict.items():
            header = {'dial_id': dial_id, 'turn_num': len(turns)}
            for prop in field[2:]:
                header[prop] = ''
            results.append(header)
            for turn in turns:
                entry = {'dial_id': dial_id}
                for key in field:
                    if key == 'dial_id':
                        continue
                    v = turn.get(key, '')
                    if key == 'turn_domain':
                        v = ' '.join(v)
                    entry[key] = decode_fn(
                        v, eos=eos_syntax[key]
                    ) if key in eos_syntax and v != '' else v
                results.append(entry)
        return results, field

    def _entity_slot_text(self, ent, slot, dom):
        """Return the display text for *slot* of DB entity *ent*, or None.

        For the hotel domain a truthy ``price`` is read from ``pricerange``.
        Names and addresses are capitalised word by word, except stopwords.
        Ids and postcodes are upper-cased. A missing ``price`` falls back to
        ``pricerange``.
        """
        if ent.get(slot):
            if dom == 'hotel' and slot == 'price':
                slot = 'pricerange'
            if slot not in ent:
                return None
            value = ent[slot]
            if slot in ('name', 'address'):
                return ' '.join(
                    w.capitalize() if w not in self._RESTORE_STOPWORDS else w
                    for w in value.split())
            if slot in ('id', 'postcode'):
                return value.upper()
            return value
        if slot == 'price' and ent.get('pricerange'):
            return ent['pricerange']
        return None

    def restore(self, resp, domain, constraint_dict, mat_ents=None):
        """Turn a delexicalised response into readable text.

        Steps, in order:
          * re-attach split suffixes (``-s``, ``-ly``, ``-er``);
          * fill fixed placeholders (reference, car);
          * fill values from the belief constraints;
          * fill the match count for ``[value_choice]``;
          * fill values from the DB entities matched for the last domain;
          * replace any remaining ``[value_*]`` with ``UNKNOWN``;
          * capitalise sentence starts.

        Args:
            resp: delexicalised response string.
            domain: list of domain names (without brackets); may be empty.
            constraint_dict: domain -> {slot: value} parsed from the belief span.
            mat_ents: optional domain -> list of matched entity dicts. When
                None it is queried via ``self.db.get_match_num``. Callers that
                already hold the result can pass it to skip the lookup.
        """
        restored = resp.capitalize()
        restored = restored.replace(' -s', 's')
        restored = restored.replace(' -ly', 'ly')
        restored = restored.replace(' -er', 'er')
        if mat_ents is None:
            mat_ents = self.db.get_match_num(constraint_dict, True)
        # NOTE(review): delex_refs is not set in this class's __init__;
        # presumably provided by _ReaderBase -- confirm.
        ref = random.choice(self.delex_refs)
        restored = restored.replace('[value_reference]', ref.upper())
        restored = restored.replace('[value_car]', 'BMW')

        for d in domain:
            constraint = constraint_dict.get(d, None)
            if constraint:
                for slot in ('stay', 'day', 'people', 'time', 'type'):
                    if slot in constraint:
                        restored = restored.replace('[value_%s]' % slot, constraint[slot])
                if d in mat_ents and len(mat_ents[d]) == 0:
                    # Nothing matched in the DB: echo the user's constraints back.
                    for s in constraint:
                        if s == 'pricerange' and d in [
                                'hotel', 'restaurant'
                        ] and 'price]' in restored:
                            restored = restored.replace('[value_price]',
                                                        constraint['pricerange'])
                        if s + ']' in restored:
                            restored = restored.replace('[value_%s]' % s, constraint[s])
            if '[value_choice' in restored and mat_ents.get(d):
                restored = restored.replace('[value_choice]', str(len(mat_ents[d])))
        if '[value_choice' in restored:
            restored = restored.replace('[value_choice]',
                                        str(random.choice([1, 2, 3, 4, 5])))

        # An empty domain list used to raise IndexError here.
        dom = domain[-1] if domain else None
        ents = mat_ents.get(dom, []) if domain else []
        if ents:
            # A placeholder repeated k times (k <= number of matches) is filled
            # with successive entities, one occurrence per entity.
            tokens = restored.split()
            token_count = Counter(tokens)
            for t in tokens:
                if '[value' in t and 1 < token_count[t] <= len(ents):
                    slot = t[7:-1]
                    for e in ents:
                        rep = self._entity_slot_text(e, slot, dom)
                        if rep is not None:
                            # Literal replace: entity text is never treated as
                            # a regex replacement template.
                            restored = restored.replace(t, rep, 1)
            # Remaining placeholders come from the first matched entity.
            first = ents[0]
            for t in restored.split():
                if '[value' in t:
                    rep = self._entity_slot_text(first, t[7:-1], dom)
                    if rep is not None:
                        restored = restored.replace(t, rep)

        for t in restored.split():
            if '[value' in t:
                restored = restored.replace(t, 'UNKNOWN')

        words = restored.split()
        for idx in range(1, len(words)):
            if words[idx - 1] in ['.', '?', '!']:
                words[idx] = words[idx].capitalize()
        return ' '.join(words)

    def _nbest_act_diversity(self, turn):
        """Score the diversity of a turn's n-best act spans.

        Each distinct ``domain-act`` pair adds 3. Each distinct slot under a
        pair adds 1, counting the slot that introduced the pair.
        """
        decode_fn = self.vocab.sentence_decode
        act_collect = {}
        slot_score = 0
        for i in range(cfg.nbest):
            aspn = decode_fn(turn['multi_act'][i], eos=ontology.eos_tokens['aspn'])
            # aspn is already a space-separated string; joining it again would
            # put a space between every character.
            for act in self.aspan_to_act_list(aspn):
                d, a, s = act.split('-', 2)
                slots = act_collect.setdefault(d + '-' + a, {})
                if s not in slots:
                    slots[s] = 1
                    slot_score += 1
        return len(act_collect) * 3 + slot_score

    def record_utterance(self, result_dict):
        """Write decoded dialogues and their n-best candidates to a CSV file.

        Dialogues are written most-diverse first, ranked by the mean
        ``_nbest_act_diversity`` over their turns. The output file is
        ``cfg.eval_load_path + '/dialogue_record.csv'``. For each turn it holds
        one row with the reference response and one row per n-best candidate,
        shortest restored response first.
        """
        decode_fn = self.vocab.sentence_decode
        eos = ontology.eos_tokens

        diversity = {}
        for dial_id, turns in result_dict.items():
            total = sum(self._nbest_act_diversity(turn) for turn in turns)
            diversity[dial_id] = total / len(turns) if turns else 0
        ordered_dial = sorted(diversity, key=lambda x: -diversity[x])

        dialog_record = {}
        # The csv module requires newline='' on files it writes.
        with open(cfg.eval_load_path + '/dialogue_record.csv', 'w',
                  newline='') as rf:
            writer = csv.writer(rf)
            for dial_id in ordered_dial:
                dialog_record[dial_id] = []
                writer.writerow([dial_id])
                for turn_no, turn in enumerate(result_dict[dial_id]):
                    user = decode_fn(turn['user'], eos=eos['user'])
                    bspn = decode_fn(turn['bspn'], eos=eos['bspn'])
                    aspn = decode_fn(turn['aspn'], eos=eos['aspn'])
                    resp = decode_fn(turn['resp'], eos=eos['resp'])
                    constraint_dict = self.bspan_to_constraint_dict(bspn)
                    mat_ents = self.db.get_match_num(constraint_dict, True)
                    domain = [
                        i[1:-1] for i in self.dspan_to_domain(turn['dspn']).keys()
                    ]
                    # restore() accepts mat_ents as an optional 4th argument
                    # (it previously took only three and raised TypeError here).
                    restored = self.restore(resp, domain, constraint_dict, mat_ents)
                    writer.writerow(
                        [turn_no, user, turn['pointer'], domain, restored, resp])
                    turn_record = {
                        'user': user,
                        'bspn': bspn,
                        'aspn': aspn,
                        'dom': domain,
                        'resp': resp,
                        'resp_res': restored
                    }
                    candidates = []
                    for i in range(cfg.nbest):
                        cand_aspn = decode_fn(turn['multi_act'][i], eos=eos['aspn'])
                        cand_resp = decode_fn(turn['multi_resp'][i], eos=eos['resp'])
                        cand_restored = self.restore(cand_resp, domain,
                                                     constraint_dict, mat_ents)
                        candidates.append((cand_restored, cand_resp, cand_aspn))
                    # Shortest restored response first; the sort is stable.
                    candidates.sort(key=lambda c: len(c[0]))
                    turn_record['aspn_col'] = [c[2] for c in candidates]
                    turn_record['resp_col'] = [c[1] for c in candidates]
                    turn_record['resp_res_col'] = [c[0] for c in candidates]
                    for cand_restored, cand_resp, cand_aspn in candidates:
                        writer.writerow(['', cand_restored, cand_resp, cand_aspn])
                    dialog_record[dial_id].append(turn_record)