import json
import os

# The import paths below are assumptions about this project's layout; adjust
# them to the actual modules that define these field/corpus/dataset helpers.
from source.inputters.field import tokenize, TextField, NumberField, FixedField
from source.inputters.corpus import Corpus
from source.inputters.dataset import WithBowDataset


class SrcTgtCorpus(Corpus):

    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False):
        super(SrcTgtCorpus, self).__init__(data_dir=data_dir,
                                           data_prefix=data_prefix,
                                           min_freq=min_freq,
                                           max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        self.fields = {'src': self.SRC, 'tgt': self.TGT}

        def src_filter_pred(src):
            return min_len <= len(self.SRC.tokenize_fn(src)) <= max_len

        def tgt_filter_pred(tgt):
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: src_filter_pred(ex['src']) and \
            tgt_filter_pred(ex['tgt'])

    def read_data(self, data_file, data_type="train"):
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                src, tgt = line.strip().split('\t')[:2]
                data.append({'src': src, 'tgt': tgt})

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data
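
# Usage sketch (illustrative, not part of the original code): builds the
# tab-separated SrcTgtCorpus above over a hypothetical "data/demo.train" file
# of "src<TAB>tgt" pairs. The paths and the "demo" prefix are assumptions.
# The default argument binds the class defined above at definition time, so
# the later class of the same name does not shadow it.
def _demo_src_tgt_corpus(corpus_cls=SrcTgtCorpus):
    corpus = corpus_cls(data_dir="data",
                        data_prefix="demo",
                        min_freq=2,
                        max_vocab_size=30000,
                        min_len=1,
                        max_len=50,
                        share_vocab=True)
    # Examples whose src or tgt falls outside [min_len, max_len] tokens are
    # dropped by filter_pred inside read_data.
    train_data = corpus.read_data("data/demo.train", data_type="train")
    return corpus, train_data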
class SrcTgtCorpus(Corpus): """ SrcTgtCorpus """ def __init__(self, config): super(SrcTgtCorpus, self).__init__(config) self.min_len = config.min_len self.max_len = config.max_len self.embed_file = config.embed_file self.share_vocab = config.share_vocab self.SRC = TextField(tokenize_fn=tokenize, embed_file=self.embed_file) if self.share_vocab: self.TGT = self.SRC else: self.TGT = TextField(tokenize_fn=tokenize, embed_file=self.embed_file) self.OUTPUT = self.TGT self.fields = {'src': self.SRC, 'tgt': self.TGT, 'output': self.OUTPUT} def src_filter_pred(src): return self.min_len <= len( self.SRC.tokenize_fn(src)) <= self.max_len def tgt_filter_pred(tgt): return self.min_len <= len( self.TGT.tokenize_fn(tgt)) <= self.max_len # self.filter_pred = lambda ex: src_filter_pred(ex['src']) and tgt_filter_pred(ex['tgt']) self.filter_pred = None def read_data(self, data_file, data_type="train"): data = [] filtered = 0 with open(data_file, "r", encoding="utf-8") as f: for line in f: src = json.loads(line.strip())['src'] tgt = json.loads(line.strip())['tgt'] data.append({'src': src, 'tgt': tgt}) filtered_num = len(data) if self.filter_pred is not None: data = [ex for ex in data if self.filter_pred(ex)] filtered_num -= len(data) print("Read {} {} examples ({} filtered)".format( len(data), data_type.upper(), filtered_num)) return data
class TopicGuide2Corpus(Corpus):
    """
    TopicGuide2Corpus
    """

    def __init__(self, config):
        super(TopicGuide2Corpus, self).__init__(config)
        self.min_len = config.min_len
        self.max_len = config.max_len
        self.share_vocab = config.share_vocab
        self.embed_file = config.embed_file
        self.topic_words_num = config.topic_words_num
        self.topic_vocab_file = config.topic_vocab_file

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=self.embed_file)
        self.TGT = TextField(tokenize_fn=tokenize, embed_file=self.embed_file)
        self.LABEL = NumberField(dtype=int)
        self.OUTPUT = TextField(tokenize_fn=tokenize,
                                embed_file=self.embed_file)
        self.TOPIC = FixedField(fix_file=self.topic_vocab_file,
                                embed_file=self.embed_file)

        self.fields = {
            'src': self.SRC,
            'tgt': self.TGT,
            'output': self.OUTPUT,
            'topic_src_label': self.LABEL,
            'topic_tgt_label': self.LABEL,
            'topic': self.TOPIC,
        }

        def src_filter_pred(src):
            return self.min_len <= len(
                self.SRC.tokenize_fn(src)) <= self.max_len

        def tgt_filter_pred(tgt):
            return self.min_len <= len(
                self.TGT.tokenize_fn(tgt)) <= self.max_len

        # Length filtering is disabled here; restore the lambda below to
        # re-enable it.
        # self.filter_pred = lambda ex: src_filter_pred(ex['src']) and tgt_filter_pred(ex['tgt'])
        self.filter_pred = None

    def read_data(self, data_file, data_type="train"):
        """
        read_data
        """
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                example = json.loads(line.strip())
                data.append({
                    'src': example['src'],
                    'tgt': example['tgt'],
                    'output': example['tgt'],
                    'topic_src_label': example['topic_src_label'],
                    'topic_tgt_label': example['topic_tgt_label'],
                    'topic': example['src'],
                })

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data

    def build_vocab(self, data):
        """
        Args
        ----
        data: ``List[Dict]``
        """
        # Group example texts by field object, so shared fields see all of
        # their data at once.
        field_data_dict = {}
        for name in data[0].keys():
            field = self.fields.get(name)
            if isinstance(field, TextField):
                xs = [x[name] for x in data]
                if field not in field_data_dict:
                    field_data_dict[field] = xs
                else:
                    field_data_dict[field] += xs

        vocab_dict = {}
        field_dict = {}
        for name, field in self.fields.items():
            if name == 'topic':
                field.build_vocab()
                field_dict[name] = field
                continue
            if field in field_data_dict:
                print("Building vocabulary of field {} ...".format(
                    name.upper()))
                if field.vocab_size == 0:
                    if name != 'output':
                        field.build_vocab(field_data_dict[field],
                                          min_freq=self.min_freq,
                                          max_size=self.max_vocab_size)
                field_dict[name] = field

        # The output vocabulary is the union of the target and topic
        # vocabularies.
        field_dict['output'].add_with_other_field(field_dict['tgt'])
        field_dict['output'].add_with_other_field(field_dict['topic'])
        if self.embed_file is not None:
            field_dict['output'].embeddings = \
                field_dict['output'].build_word_embeddings(self.embed_file)

        for name, field in field_dict.items():
            vocab_dict[name] = field.dump_vocab()
        return vocab_dict

    def load_vocab(self, prepared_vocab_file):
        super().load_vocab(prepared_vocab_file)
        self.topic_bow_vocab_size = self.TOPIC.vocab_size
        self.Dataset = lambda x: WithBowDataset(
            data=x, bow_vocab_size=self.topic_bow_vocab_size)
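
# Illustrative input record for TopicGuide2Corpus.read_data (hypothetical
# values): each line of the data file is a JSON object carrying the source,
# target, and per-side topic labels; 'output' and 'topic' are then derived
# from tgt and src respectively.
_EXAMPLE_TOPIC_GUIDE_LINE = json.dumps({
    "src": "do you like jazz",
    "tgt": "yes i listen to jazz every night",
    "topic_src_label": 3,
    "topic_tgt_label": 7,
})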
# The knowledge-sentence separator character was lost when this file was
# extracted (the source shows split('')); '\x01' is an assumption, since a
# non-printing control character is commonly used to join knowledge
# sentences in such corpora.
KNOWLEDGE_SEP = '\x01'


class KnowledgeCorpus(Corpus):

    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False,
                 with_label=False):
        super(KnowledgeCorpus, self).__init__(data_dir=data_dir,
                                              data_prefix=data_prefix,
                                              min_freq=min_freq,
                                              max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab
        self.with_label = with_label

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
            self.CUE = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)
            self.CUE = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        if self.with_label:
            self.INDEX = NumberField()
            self.fields = {'src': self.SRC, 'tgt': self.TGT,
                           'cue': self.CUE, 'index': self.INDEX}
        else:
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE}

        def src_filter_pred(src):
            return min_len <= len(self.SRC.tokenize_fn(src)) <= max_len

        def tgt_filter_pred(tgt):
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: src_filter_pred(ex['src']) and \
            tgt_filter_pred(ex['tgt'])

    def read_data(self, data_file, data_type="train"):
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                if self.with_label:
                    src, tgt, knowledge, label = line.strip().split('\t')[:4]
                else:
                    src, tgt, knowledge = line.strip().split('\t')[:3]
                # Truncate each knowledge sentence to at most max_len tokens.
                filter_knowledge = []
                for sent in knowledge.split(KNOWLEDGE_SEP):
                    filter_knowledge.append(
                        ' '.join(sent.split()[:self.max_len]))
                example = {'src': src, 'tgt': tgt, 'cue': filter_knowledge}
                if self.with_label:
                    example['index'] = label
                data.append(example)

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print(
            f"Read {len(data)} {data_type.upper()} examples ({filtered_num} filtered)")
        return data
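
# Illustrative input line for the tab-separated KnowledgeCorpus above
# (hypothetical content): src, tgt, and KNOWLEDGE_SEP-joined knowledge
# sentences, plus a trailing label column when with_label=True.
_EXAMPLE_KNOWLEDGE_LINE = '\t'.join([
    "where are you from",
    "i am from new york",
    KNOWLEDGE_SEP.join(["i live in new york", "i like pizza"]),
])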
class KnowledgeCorpus(Corpus): """ KnowledgeCorpus """ def __init__(self, data_dir, data_prefix, min_freq=0, max_vocab_size=None, vocab_file=None, min_len=0, max_len=100, embed_file=None, share_vocab=False, with_label=False): super(KnowledgeCorpus, self).__init__(data_dir=data_dir, data_prefix=data_prefix, min_freq=min_freq, max_vocab_size=max_vocab_size, vocab_file=vocab_file) self.min_len = min_len self.max_len = max_len self.share_vocab = share_vocab self.with_label = with_label self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file) if self.share_vocab: self.TGT = self.SRC self.CUE = self.SRC else: self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file) self.CUE = TextField(tokenize_fn=tokenize, embed_file=embed_file) if self.with_label: self.INDEX = NumberField() self.fields = { 'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE, 'index': self.INDEX } else: self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE} # load vocab if not os.path.exists(self.prepared_vocab_file): self.build() else: self.load_vocab() self.padding_idx = self.TGT.stoi[self.TGT.pad_token] def src_filter_pred(src): """ src_filter_pred """ return min_len <= len(self.SRC.tokenize_fn(src)) <= max_len def tgt_filter_pred(tgt): """ tgt_filter_pred """ return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len self.filter_pred = lambda ex: src_filter_pred(ex[ 'src']) and tgt_filter_pred(ex['tgt']) def read_data(self, data_file, data_type="train"): """ read_data:q """ num = 0 data = [] with open(data_file, "r", encoding="utf-8") as f: for line in f: dialog = json.loads(line, encoding='utf-8') history = dialog["dialog"] uid = [int(i) for i in dialog["uid"]] profile = dialog["profile"] if "responder_profile" in dialog.keys(): responder_profile = dialog["responder_profile"] elif "response_profile" in dialog.keys(): responder_profile = dialog["response_profile"] else: raise ValueError( "No responder_profile or response_profile!") src = "" for idx, sent in zip(uid, history): #tag_list = profile[idx]["tag"][0].split(';') #loc_content = profile[idx]["loc"] #tag_list.append(loc_content) #tag_content = ' '.join(tag_list) sent_content = sent[0] #src += tag_content #src += ' ' src += sent_content src += ' ' src = src.strip() tgt = dialog["golden_response"][0] filter_knowledge = [] if type(responder_profile["tag"]) is list: filter_knowledge.append(' '.join( responder_profile["tag"][0].split(';'))) else: filter_knowledge.append(' '.join( responder_profile["tag"].split(';'))) filter_knowledge.append(responder_profile["loc"]) data.append({'src': src, 'tgt': tgt, 'cue': filter_knowledge}) num += 1 if num < 10: print("src:", src) print("tgt:", tgt) print("cue:", filter_knowledge) print("\n") filtered_num = len(data) if not data_type == "test" and self.filter_pred is not None: data = [ex for ex in data if self.filter_pred(ex)] filtered_num -= len(data) print("Read {} {} examples ({} filtered)".format( len(data), data_type.upper(), filtered_num)) return data
class PersonaCorpus(Corpus):
    """
    PersonaCorpus
    """

    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False,
                 with_label=False):
        super(PersonaCorpus, self).__init__(data_dir=data_dir,
                                            data_prefix=data_prefix,
                                            min_freq=min_freq,
                                            max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab
        self.with_label = with_label

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        # self.LABEL = NumberField(dtype=float)
        # self.LABEL = NumberField(sequential=True, dtype=int)
        if self.share_vocab:
            self.TGT = self.SRC
            self.CUE = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)
            self.CUE = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        if self.with_label:
            self.LABEL = NumberField(sequential=False, dtype=int)
            self.INDEX = NumberField(sequential=True, dtype=int)
            self.fields = {
                'src': self.SRC,
                'tgt': self.TGT,
                'cue': self.CUE,
                'label': self.LABEL,
                'index': self.INDEX
            }
        else:
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE}

        def src_filter_pred(src):
            """
            src_filter_pred
            """
            # Every utterance in the source history must satisfy the length
            # bounds.
            for sen in src:
                if not (min_len <= len(self.SRC.tokenize_fn(sen)) <= max_len):
                    return False
            return True

        def tgt_filter_pred(tgt):
            """
            tgt_filter_pred
            """
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: src_filter_pred(ex['src']) and \
            tgt_filter_pred(ex['tgt'])

    def read_data(self, data_file, data_type="train"):
        """
        read_data
        """
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                if self.with_label:
                    query, response, personas, persona_label, key_index = \
                        line.strip().split('\t')[:5]
                    # Truncate each persona sentence to at most max_len
                    # tokens.
                    filter_personas = []
                    for sent in personas.split('**'):
                        filter_personas.append(
                            ' '.join(sent.split()[:self.max_len]))
                    data.append({
                        'src': query,
                        'tgt': response,
                        'cue': filter_personas,
                        'label': persona_label,
                        'index': key_index
                    })
                else:
                    queries, response, persona = line.strip().split('\t')[:3]
                    src = queries.split('**')
                    # filter_persona = ' '.join(persona.split()[:self.max_len])
                    filter_persona = persona
                    data.append({
                        'src': src,
                        'tgt': response,
                        'cue': filter_persona
                    })

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data

    def read_data_multitask(self, data_file1, data_file2, data_type="train"):
        """
        read_data_multitask
        """
        data1 = []
        data2 = []

        # Task 2: labeled persona-selection examples.
        with open(data_file2, "r", encoding="utf-8") as f:
            for line in f:
                query, response, personas, persona_label, key_index = \
                    line.strip().split('\t')[:5]
                filter_personas = []
                for sent in personas.split('**'):
                    filter_personas.append(
                        ' '.join(sent.split()[:self.max_len]))
                data2.append({
                    'src': query,
                    'tgt': response,
                    'cue': filter_personas,
                    'label': persona_label,
                    'index': key_index
                })

        filtered_num = len(data2)
        if self.filter_pred is not None:
            data2 = [ex for ex in data2 if self.filter_pred(ex)]
        filtered_num -= len(data2)
        print("Read {} {} examples ({} filtered)".format(
            len(data2), data_type.upper() + 'task2', filtered_num))

        # Task 1: unlabeled response-generation examples.
        with open(data_file1, "r", encoding="utf-8") as f:
            for line in f:
                queries, response, persona = line.strip().split('\t')[:3]
                src = queries.split('**')
                # filter_persona = ' '.join(persona.split()[:self.max_len])
                filter_persona = persona
                data1.append({
                    'src': src,
                    'tgt': response,
                    'cue': filter_persona
                })

        filtered_num = len(data1)
        if self.filter_pred is not None:
            data1 = [ex for ex in data1 if self.filter_pred(ex)]
        filtered_num -= len(data1)
        print("Read {} {} examples ({} filtered)".format(
            len(data1), data_type.upper() + 'task1', filtered_num))

        return data1, data2
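
# Illustrative multitask inputs for PersonaCorpus (hypothetical content).
# Task-2 lines are labeled: query, response, '**'-separated personas, an
# integer persona_label, and a space-separated key_index; task-1 lines hold
# '**'-separated queries, a response, and a persona string.
_EXAMPLE_PERSONA_TASK2_LINE = '\t'.join([
    "what do you do for fun",
    "i paint landscapes on weekends",
    "i am a painter**i have two dogs",
    "0",
    "3 4",
])
_EXAMPLE_PERSONA_TASK1_LINE = '\t'.join([
    "hi**how was your day",
    "pretty good i walked my dogs",
    "i have two dogs",
])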