def load_data(except_domain, keep_pct=0):
    """Load the training dialogues, holding out (most of) one domain.

    Reads ``data/train_dials.json``, sorts each dialogue's turns by integer
    ``turn_idx`` and rewrites its ``domains`` annotation as the sorted union
    of the annotated domains and every domain that has a non-``'none'`` slot
    in some turn's belief state.  Dialogues touching a domain outside
    ``EXPERIMENT_DOMAINS`` are dropped.

    Args:
        except_domain: Domain whose dialogues are held out.
        keep_pct: Fraction of the held-out dialogues to keep anyway.  They
            are chosen by shuffling with the module-level ``random`` state.
            Nothing is kept unless ``keep_pct > 0``.

    Returns:
        The dialogues that do not involve ``except_domain``, followed by the
        kept share of the dialogues that do.
    """
    with open('data/train_dials.json') as fp:
        dialogues = json.load(fp)

    out_of_domain = []
    held_out = []
    for dialogue in dialogues:
        turns = dialogue['dialogue']
        turns.sort(key=lambda t: int(t['turn_idx']))

        # The annotation sometimes misses domains, so also collect the ones
        # that show up in the belief states.
        domains = set(dialogue['domains'])
        for turn in turns:
            belief = fix_general_label_error(
                turn["belief_state"], False, ALL_SLOTS)
            for slot_key, slot_value in belief.items():
                if slot_value != 'none':
                    slot_domain, _ = slot_key.split('-', maxsplit=1)
                    domains.add(slot_domain)
        dialogue['domains'] = sorted(domains)

        if any(d not in EXPERIMENT_DOMAINS for d in dialogue['domains']):
            continue

        if except_domain in dialogue['domains']:
            held_out.append(dialogue)
        else:
            out_of_domain.append(dialogue)

    if keep_pct > 0:
        random.shuffle(held_out)
        n_kept = int(keep_pct * len(held_out))
        return out_of_domain + held_out[:n_kept]
    return out_of_domain
def main():
    """Filter dialogues read from stdin, dropping those that touch a domain.

    Reads a JSON list of dialogues from standard input and writes, as
    indented JSON on standard output, only those that do not involve the
    domain named by ``sys.argv[1]``.  A dialogue involves a domain when the
    domain is listed in its ``domains`` annotation or when some turn's belief
    state holds a non-``'none'`` value for one of that domain's slots.

    The number of input and of retained dialogues is printed to standard
    error.
    """
    data = json.load(sys.stdin)

    new_data = []
    for dialogue in data:
        all_domains = set(dialogue['domains'])
        # The annotation sometimes misses domains, so add those found in the
        # belief states.
        for turn in dialogue['dialogue']:
            turn_belief_dict = fix_general_label_error(
                turn["belief_state"], False, ALL_SLOTS)
            # Iterate over (slot, value) pairs: iterating the dict itself
            # yields only the keys, which cannot be unpacked into two names.
            for slot_key, slot_value in turn_belief_dict.items():
                if slot_value == 'none':
                    continue
                domain, _ = slot_key.split('-', maxsplit=1)
                all_domains.add(domain)

        if sys.argv[1] in all_domains:
            continue

        new_data.append(dialogue)

    print(len(data), len(new_data), file=sys.stderr)
    json.dump(new_data, sys.stdout, indent=2)
def compute_templates(data):
    """Count (templated history, belief state) pairs over the dialogues.

    For each dialogue the turns are sorted by ``turn_idx`` (in place) and
    consumed up to and including the first turn whose belief state is
    non-empty.  The utterances seen so far are joined with ``' <sep> '``
    twice: once verbatim and once after ``replace_with_system_act`` /
    ``replace_with_slots`` were applied.

    Args:
        data: Iterable of dialogue dicts with a ``'dialogue'`` list of turns.

    Returns:
        A ``Counter`` keyed by ``(templated history, belief-state string)``.
        The number of distinct verbatim histories is printed to standard
        error as a side effect.
    """
    templates = Counter()
    seen_raw = set()

    for dialogue in data:
        turns = dialogue['dialogue']
        turns.sort(key=lambda t: t['turn_idx'])

        raw_parts = []
        templated_parts = []
        label_dict = {}

        for turn in turns:
            label_dict = fix_general_label_error(
                turn['belief_state'], False, ALL_SLOTS)

            # Turn index 0 contributes no system utterance.
            if turn['turn_idx'] > 0:
                raw_parts.append(turn['system_transcript'])
                templated_parts.append(replace_with_system_act(
                    turn['system_transcript'], turn['system_acts']))

            raw_parts.append(turn['transcript'])
            templated_parts.append(
                replace_with_slots(turn['transcript'], label_dict))

            # Stop at the first turn that carries any belief-state entry.
            if label_dict:
                break

        template_key = (' <sep> '.join(templated_parts),
                        belief_state_to_string(label_dict))
        templates[template_key] += 1
        seen_raw.add(' <sep> '.join(raw_parts))

    print(len(seen_raw), file=sys.stderr)
    return templates
def get_full_belief(turn):
    """Return the turn's belief state as ``slot-value`` strings plus slot names.

    Args:
        turn: Turn dict with a ``'belief_state'`` entry.

    Returns:
        A pair ``(slot_value_strings, slot_names)``.  Both lists follow the
        iteration order of the mapping returned by
        ``fix_general_label_error``; each string is the slot name and its
        value joined by ``'-'``.
    """
    belief = fix_general_label_error(turn['belief_state'], False, ALL_SLOTS)
    slot_value_strings = ['-'.join((name, belief[name])) for name in belief]
    slot_names = list(belief.keys())
    return slot_value_strings, slot_names
def read_langs(file_name, gating_dict, SLOTS, dataset, lang, mem_lang, sequicity, training, max_line=None):
    """Read a dialogue JSON file and build one example per dialogue turn.

    Relies on the module-level ``args`` object (``all_vocab``, ``data_ratio``,
    ``only_domain``, ``except_domain``) and on ``EXPERIMENT_DOMAINS``.

    Args:
        file_name: Path of the JSON file holding the list of dialogues.
        gating_dict: Mapping with the keys ``"ptr"``, ``"dontcare"`` and
            ``"none"`` giving the gate label for each kind of slot value.
        SLOTS: Sequence of slot names.  The domain filters test whether the
            domain name occurs as a substring of the slot name.
        dataset: Split name; ``"train"``, ``"dev"`` and ``"test"`` are
            treated specially.
        lang: Vocabulary object; utterances are passed to its
            ``index_words`` method.
        mem_lang: Vocabulary object; belief states and the ``t{i}`` tokens
            are passed to its ``index_words`` method.
        sequicity: Unused in this function; kept for interface compatibility.
        training: Whether the vocabularies should be extended.
        max_line: Optional cap.  A counter starting at 1 is incremented for
            every processed dialogue and reading stops once it reaches
            ``max_line``.

    Returns:
        ``(data, max_resp_len, slot_temp)`` where ``data`` is the list of
        per-turn example dicts, ``max_resp_len`` the largest number of
        whitespace-separated tokens in any dialogue history, and
        ``slot_temp`` the slot list after domain filtering.
    """
    print(("Reading from {}".format(file_name)))
    data = []
    max_resp_len, max_value_len = 0, 0
    domain_counter = {}

    with open(file_name) as f:
        dials = json.load(f)

    index_vocab = (args.all_vocab or dataset == "train") and training

    # Create the utterance vocabulary first, over the full (unsampled) file.
    if index_vocab:
        for dial_dict in dials:
            for turn in dial_dict["dialogue"]:
                lang.index_words(turn["system_transcript"], 'utter')
                lang.index_words(turn["transcript"], 'utter')

    # Determine the training data ratio; the default is 100%.
    if training and dataset == "train" and args.data_ratio != 100:
        random.Random(10).shuffle(dials)
        dials = dials[:int(len(dials) * 0.01 * args.data_ratio)]

    # Domain-dependent slot filter.  It depends only on `dataset` and `args`,
    # so it is computed once up front.  This also guarantees `slot_temp` is
    # bound even when no turn is processed (it used to be assigned only
    # inside the turn loop, making the final `return` raise
    # UnboundLocalError on empty or fully filtered input).
    if args.except_domain != "":
        filter_domain = args.except_domain
        # Train/dev drop the held-out domain's slots; any other split keeps
        # only that domain's slots.
        keep_matching = not (dataset == "train" or dataset == "dev")
    elif args.only_domain != "":
        filter_domain = args.only_domain
        keep_matching = True
    else:
        filter_domain = None
        keep_matching = True

    def keep_slot(slot_name):
        """Tell whether a slot survives the active domain filter."""
        return (filter_domain in slot_name) == keep_matching

    if filter_domain is None:
        slot_temp = SLOTS
    else:
        slot_temp = [k for k in SLOTS if keep_slot(k)]

    cnt_lin = 1
    for dial_dict in dials:
        dialog_history = ""

        # Count the experiment domains of every dialogue, including the ones
        # skipped by the filters below.
        for domain in dial_dict["domains"]:
            if domain not in EXPERIMENT_DOMAINS:
                continue
            domain_counter[domain] = domain_counter.get(domain, 0) + 1

        # Unseen domain setting
        if args.only_domain != "" and args.only_domain not in dial_dict["domains"]:
            continue
        if args.except_domain != "":
            if (dataset == "test" and args.except_domain not in dial_dict["domains"]) or \
                    (dataset != "test" and [args.except_domain] == dial_dict["domains"]):
                continue

        # Reading data
        for turn in dial_dict["dialogue"]:
            turn_domain = turn["domain"]
            turn_id = turn["turn_idx"]
            turn_uttr = turn["system_transcript"] + " ; " + turn["transcript"]
            turn_uttr_strip = turn_uttr.strip()
            dialog_history += (turn["system_transcript"] + " ; " + turn["transcript"] + " ; ")
            source_text = dialog_history.strip()
            turn_belief_dict = fix_general_label_error(turn["belief_state"], False, SLOTS)

            # Restrict the belief state to the same slots as `slot_temp`.
            if filter_domain is not None:
                turn_belief_dict = OrderedDict(
                    (k, v) for k, v in turn_belief_dict.items() if keep_slot(k))

            turn_belief_list = [str(k) + '-' + str(v) for k, v in turn_belief_dict.items()]

            if index_vocab:
                mem_lang.index_words(turn_belief_dict, 'belief')

            generate_y, gating_label = [], []
            for slot in slot_temp:
                if slot not in turn_belief_dict:
                    generate_y.append("none")
                    gating_label.append(gating_dict["none"])
                    continue

                value = turn_belief_dict[slot]
                generate_y.append(value)

                if value == "dontcare":
                    gating_label.append(gating_dict["dontcare"])
                elif value == "none":
                    gating_label.append(gating_dict["none"])
                else:
                    gating_label.append(gating_dict["ptr"])

                max_value_len = max(max_value_len, len(value))

            data_detail = {
                "ID": dial_dict["dialogue_idx"],
                "domains": dial_dict["domains"],
                "turn_domain": turn_domain,
                "turn_id": turn_id,
                "dialog_history": source_text,
                "turn_belief": turn_belief_list,
                "gating_label": gating_label,
                "turn_uttr": turn_uttr_strip,
                'generate_y': generate_y
            }
            data.append(data_detail)

            max_resp_len = max(max_resp_len, len(source_text.split()))

        cnt_lin += 1
        if max_line and cnt_lin >= max_line:
            break

    # add t{} to the lang file
    if "t{}".format(max_value_len - 1) not in mem_lang.word2index.keys() and training:
        for time_i in range(max_value_len):
            mem_lang.index_words("t{}".format(time_i), 'utter')

    print("domain_counter", domain_counter)
    return data, max_resp_len, slot_temp
def get_data(file_data):
    """Read a dialogue JSON file and build one example per dialogue turn.

    Relies on module-level names defined elsewhere in the file
    (``EXPERIMENT_DOMAINS``, ``ALL_SLOTS``, ``gating_dict``,
    ``sub_gating_dict``, ``tokenizer``, ``use2turn`` and the special tokens).

    Args:
        file_data: Path of the JSON file holding the list of dialogues.

    Returns:
        ``(data, (max_input_len, max_input), (max_value_len, max_value))``:
        ``data`` is the list of per-turn example dicts over all dialogues;
        the second item is the length and tokens of the longest
        ``input_seq``; the third is the length and tokens of the longest
        tokenized slot value.
    """
    print(("Reading from {}".format(file_data)))
    with open(file_data) as f:
        dials = json.load(f)

    data = []  # every turn of every dialogue, packaged turn by turn
    max_value_len, max_input_len, max_value, max_input = 0, 0, '', ''
    experiment_domains = set(EXPERIMENT_DOMAINS)
    num_slots = len(ALL_SLOTS)

    for dial_dict in tqdm(dials, total=len(dials)):
        # Skip the whole dialogue when it touches a domain outside the
        # experiment set.  This must be a subset test (`<=`): the previous
        # proper-subset test (`<`) also discarded dialogues that cover
        # every experiment domain.
        if not set(dial_dict["domains"]) <= experiment_domains:
            continue

        # State before the first turn (B_0): no utterances and NULL for every
        # slot, so the first turn's gates are computed against all-NULL values.
        previous_utterances = A_token + " " + U_token
        previous_generate_y = [NULL_token] * num_slots
        previous_belief = _convert_bs({}).strip()
        previous_gating_label = [gating_dict["carryover"]] * num_slots

        for turn in dial_dict["dialogue"]:
            # 1. Basics
            turn_domain = turn["domain"]
            turn_id = turn["turn_idx"]
            turn_belief_dict = fix_general_label_error(
                turn["belief_state"], False, ALL_SLOTS)

            if turn_belief_dict:
                turn_belief_list = [
                    str(k) + '-' + fix(v)
                    for k, v in turn_belief_dict.items() if v != "none"
                ]
                turn_belief_dict = {
                    k: fix(v) for k, v in turn_belief_dict.items()
                }
            else:
                turn_belief_list = []

            # 2. Labels
            # 2.1 generate_y: the slot values laid out in ALL_SLOTS order.
            if turn_belief_dict:
                generate_y = []
                for slot in ALL_SLOTS:
                    if slot in turn_belief_dict:
                        # NOTE(review): the value was already passed through
                        # fix() above; the second call is kept as in the
                        # original - confirm fix() is idempotent.
                        value = fix(turn_belief_dict[slot])
                        value_seq = tokenizer.tokenize(turn_belief_dict[slot])
                        if max_value_len < len(value_seq):
                            max_value_len = len(value_seq)
                            max_value = value_seq
                    else:
                        value = NULL_token
                    generate_y.append(value)
            else:
                generate_y = [NULL_token] * num_slots

            # 2.2 gating_label: one operation per slot (ALL_SLOTS order),
            # derived from the values of two adjacent turns.
            gating_label = []
            for p_v, n_v in zip(previous_generate_y, generate_y):
                if n_v == p_v:
                    operation = gating_dict["carryover"]
                # presumably the yes/no/dontcare values - verify against
                # the definition of sub_gating_dict
                elif n_v in sub_gating_dict:
                    operation = gating_dict["confirm"]
                else:
                    operation = gating_dict["update"]
                gating_label.append(operation)

            # 2.3 input_seq and previous_belief.
            # story: previous and current utterances, each followed by SEP.
            # input_seq: tokenized story + previous belief, wrapped in
            # CLS ... EOS.
            # 2.3.1 story
            current_utterances = (A_token + " " + turn["system_transcript"]
                                  + " " + U_token + " " + turn["transcript"])
            current_utterances = current_utterances.strip()
            story = (previous_utterances + " " + SEP_token + " "
                     + current_utterances + " " + SEP_token)

            # 2.3.2 belief of this turn; becomes previous_belief next turn
            current_belief = _convert_bs(turn_belief_dict).strip()

            # 2.3.3 input_seq
            input_seq = ([CLS_token]
                         + tokenizer.tokenize(story + " " + previous_belief)
                         + [EOS_token])
            if max_input_len < len(input_seq):
                max_input_len = len(input_seq)
                max_input = input_seq

            # 2.4 domain_focus.  Per the original author's note: 0 means the
            # domain is not attended at all, 1 means at least one slot of the
            # domain is non-carryover.
            domain_focus = get_domain_focus(generate_y, gating_label)
            if use2turn:
                pre_domain_focus = get_domain_focus(
                    previous_generate_y, previous_gating_label)
                domain_focus = [
                    cur | prev
                    for cur, prev in zip(domain_focus, pre_domain_focus)
                ]

            # 3. Package the turn.
            data_detail = {
                "ID": dial_dict["dialogue_idx"],
                # domains covered by the dialogue this turn belongs to
                "domains": dial_dict["domains"],
                "turn_domain": turn_domain,
                "turn_id": turn_id,
                "turn_belief": turn_belief_list,
                'previous_generate_y': previous_generate_y,
                'generate_y': generate_y,
                "gating_label": gating_label,
                "previous_utterances": previous_utterances,
                "current_utterances": current_utterances,
                'domain_focus': domain_focus,
                # tokenized story + previous belief (list of tokens)
                'input_seq': input_seq,
            }
            data.append(data_detail)

            previous_utterances = current_utterances
            previous_generate_y = generate_y
            previous_belief = current_belief
            previous_gating_label = gating_label

    return data, (max_input_len, max_input), (max_value_len, max_value)
def transfer_data(original_data, from_domain, to_domain):
    """Augment the data with ``from_domain`` dialogues rewritten for ``to_domain``.

    Every original dialogue is kept.  For each dialogue involving
    ``from_domain`` a deep copy is made in which the ``from_domain`` phrases
    are replaced by randomly chosen ``to_domain`` phrases (both taken from
    ``TRANSFER_PHRASES``) and the ``from_domain`` slots are renamed to
    ``to_domain``.  The copy is discarded when

    * a slot already belongs to ``to_domain`` (treated as a mislabel),
    * a renamed slot is not in ``ALL_SLOTS``, or
    * a turn has a transferred slot while the replace bag reports no use
      so far (``used > 0`` is false).

    Surviving copies get a new ``dialogue_idx`` and ``domains`` list, are
    passed through ``Augmenter(only_domain=to_domain).augment`` and are
    appended right after their original.

    Args:
        original_data: Iterable of dialogue dicts.
        from_domain: Domain to transfer from.
        to_domain: Domain to transfer to.

    Returns:
        A new list with the original dialogues and the accepted copies.
    """
    new_data = []
    source_phrases = TRANSFER_PHRASES[from_domain]
    target_phrases = TRANSFER_PHRASES[to_domain]
    assert len(source_phrases) > 0
    assert len(target_phrases) > 0

    for dialogue in original_data:
        new_data.append(dialogue)
        if from_domain not in dialogue['domains']:
            continue

        new_dialogue = copy.deepcopy(dialogue)

        # Pair every source phrase with a random target phrase.
        replace_bag = ReplaceBag()
        for phrase in source_phrases:
            source_tokens = phrase.split(' ')
            target_tokens = random.choice(target_phrases).split(' ')
            replace_bag.add(source_tokens, target_tokens)

        good_dialogue = True
        for turn in new_dialogue['dialogue']:
            # Turn index 0 has its system transcript left untouched.
            if int(turn['turn_idx']) > 0:
                system_text = turn['system_transcript']
                turn['original_system_transcript'] = system_text
                turn['system_transcript'] = ' '.join(
                    apply_replacement(system_text, replace_bag))

            user_text = turn['transcript']
            turn['original_transcript'] = user_text
            turn['transcript'] = ' '.join(
                apply_replacement(user_text, replace_bag))

            found_transfer_phrase = replace_bag.used > 0

            label_dict = remove_none_slots(fix_general_label_error(
                turn['belief_state'], False, ALL_SLOTS))

            found_transfer_slot = False
            found_bad_slot = False
            new_label_dict = {}
            for slot_key, slot_value in label_dict.items():
                domain, slot_name = slot_key.split('-', maxsplit=1)

                # All "to_domain" data was removed beforehand, so seeing it
                # here means a mislabel: mark the dialogue as bad.
                if domain == to_domain:
                    good_dialogue = False
                    break

                if domain == from_domain:
                    found_transfer_slot = True
                    new_slot_key = to_domain + '-' + slot_name
                    if new_slot_key not in ALL_SLOTS:
                        found_bad_slot = True
                        break
                    new_label_dict[new_slot_key] = slot_value
                else:
                    new_label_dict[slot_key] = slot_value

            turn['belief_state'] = belief_to_json(new_label_dict)
            turn['turn_label'] = list(new_label_dict.items())
            if turn['domain'] == from_domain:
                turn['domain'] = to_domain

            if found_bad_slot or (found_transfer_slot
                                  and not found_transfer_phrase):
                good_dialogue = False
                break

        if not good_dialogue:
            continue

        new_dialogue['dialogue_idx'] = (new_dialogue['dialogue_idx'] + '/'
                                        + from_domain + '->' + to_domain)
        remaining_domains = [
            d for d in new_dialogue['domains'] if d != from_domain
        ]
        new_dialogue['domains'] = remaining_domains + [to_domain]

        # Replace the values in the new dialogue with values that make sense
        # for the target domain.
        augmenter = Augmenter(only_domain=to_domain)
        augmenter.augment(new_dialogue['dialogue'])

        new_data.append(new_dialogue)

    return new_data
dials_v2 = [] for dial_dict in dials: dial_domains = dial_dict["domains"] prev_turn_state = {} for slot in slot_meta: prev_turn_state[slot] = "none" for ti, turn in enumerate(dial_dict["dialogue"]): dial_dict["dialogue"][ti]["system_transcript"] = normalize_text( turn["system_transcript"]) dial_dict["dialogue"][ti]["transcript"] = normalize_text( turn["transcript"]) # state turn_dialog_state = fix_general_label_error( turn["belief_state"], False, slot_meta) for slot in slot_meta: if slot not in turn_dialog_state or slot.split( '-')[0] not in dial_domains: turn_dialog_state[slot] = "none" else: turn_dialog_state[slot] = normalize_label( slot, turn_dialog_state[slot]) if turn_dialog_state[slot] == "dontcare": turn_dialog_state[slot] = "do not care" ontology_modified[slot].append(turn_dialog_state[slot]) dial_dict["dialogue"][ti]["belief_state"] = []