import copy
import json

import torch

# NOTE: Tokenizer (and NLU, if it lives in a separate module) are
# project-local classes; import them according to the repository layout.


class NLG:
    def __init__(self, NLG_param_dir, NLG_model_fname, tokenizer,
                 NLU_param_dir=None, NLU_model_fname=None):
        self.tokenizer = Tokenizer(tokenizer, '../tokenizer/e2e.model')
        self.tokenizer_mode = tokenizer
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # load the trained NLG model and its dictionary
        saved_data = torch.load(
            NLG_param_dir.rstrip('/') + '/' + NLG_model_fname)
        self.model_NLG = saved_data['model']
        with open(NLG_param_dir.rstrip('/') + '/dictionary.json', 'r',
                  encoding='utf-8') as f:
            self.dictionary = json.load(f)

        # beam-search settings
        self.n_beam = 5

        # optional NLU model used to rerank beam-search candidates
        if (NLU_param_dir is not None) and (NLU_model_fname is not None):
            self.NLU = NLU(NLU_param_dir, NLU_model_fname, tokenizer)
        else:
            self.NLU = None

    def convert_nlg(self, input_mr_obj, search, lex_flag, startword=''):
        def _shape_txt(input_mr_obj, output_token, lex_flag):
            # join tokens into a surface string
            if self.tokenizer_mode == 'sentencepiece':
                output_txt = ''.join(output_token).replace('▁', ' ')
                output_txt = output_txt.lstrip(' ')
            else:
                output_txt = ''
                for i in range(len(output_token)):
                    # no space before '.', ',' and clitics such as "'s"
                    if (i > 0) and (output_token[i] != '.') and (
                            output_token[i] != ',') and (
                            output_token[i][0] != '\''):
                        output_txt += ' '
                    output_txt += output_token[i]
            # lexicalisation: restore the original name/near values
            if lex_flag:
                output_txt = output_txt.replace('NAME', input_mr_obj['name'])
                output_txt = output_txt.replace('NEAR', input_mr_obj['near'])
            return output_txt

        # delexicalise the name/near slots before tokenisation
        input_mr_obj_org = copy.deepcopy(input_mr_obj)
        if lex_flag:
            if input_mr_obj['name'] != '':
                input_mr_obj['name'] = 'NAME'
            if input_mr_obj['near'] != '':
                input_mr_obj['near'] = 'NEAR'
        input_mr_token = self.tokenizer.mr(input_mr_obj)

        if search == 'greedy':
            output_txt_token, attention = self.translate_nlg_greedy_search(
                input_mr_token, startword)
        elif search == 'beam':
            output_txt_token, attention = self.translate_nlg_beam_search(
                input_mr_token, lex_flag, startword)
        else:
            output_txt_token, attention = self.translate_nlg(
                input_mr_token, lex_flag, startword)

        output_txt = _shape_txt(input_mr_obj_org, output_txt_token, lex_flag)
        return output_txt, attention

    def translate_nlg_encode(self, input_mr_token):
        # map MR tokens to vocabulary indices (<unk> for OOV tokens)
        mr_indexes = []
        for token in input_mr_token:
            if token in self.dictionary['mr_s2i']:
                mr_indexes.append(self.dictionary['mr_s2i'][token])
            else:
                mr_indexes.append(self.dictionary['mr_s2i']['<unk>'])
        mr_tensor = torch.LongTensor(mr_indexes).unsqueeze(0).to(self.device)
        mr_mask = self.model_NLG.make_mr_mask(mr_tensor)
        with torch.no_grad():
            enc_mr = self.model_NLG.encoder(mr_tensor, mr_mask)
        return enc_mr, mr_mask

    def translate_nlg_greedy_search(self, input_mr_token, startword=''):
        self.model_NLG.eval()

        # encode
        enc_mr, mr_mask = self.translate_nlg_encode(input_mr_token)

        # decode: seed the output with <sos> plus the optional start word
        token_startword = self.tokenizer.txt(startword)
        txt_indexes = [self.dictionary['txt_s2i']['<sos>']]
        for token in token_startword:
            if token in self.dictionary['txt_s2i']:
                txt_indexes.append(self.dictionary['txt_s2i'][token])
            else:
                txt_indexes.append(self.dictionary['txt_s2i']['<unk>'])
        num_token = len(txt_indexes)

        for i in range(self.dictionary['max_txt_length'] - num_token):
            txt_tensor = torch.LongTensor(txt_indexes).unsqueeze(0).to(
                self.device)
            txt_mask = self.model_NLG.make_txt_mask(txt_tensor)
            with torch.no_grad():
                output, attention = self.model_NLG.decoder(
                    txt_tensor, enc_mr, txt_mask, mr_mask)
            pred_token = output.argmax(2)[:, -1].item()
            txt_indexes.append(pred_token)
            if pred_token == self.dictionary['txt_s2i']['<eos>']:
                break

        txt_tokens = [self.dictionary['txt_i2s'][i] for i in txt_indexes]
        txt_tokens = txt_tokens[1:-1]  # strip <sos> and <eos>
        return txt_tokens, attention

    def translate_nlg_beam_search(self, input_mr_token, lex_flag,
                                  startword=''):
        self.model_NLG.eval()

        # encode
        enc_mr, mr_mask = self.translate_nlg_encode(input_mr_token)

        # decode: seed the hypothesis with <sos> plus the optional start word
        token_startword = self.tokenizer.txt(startword)
        offset = len(token_startword)
        a_cand_prev = [{
            'idx': [self.dictionary['txt_s2i']['<sos>']],
            'val': 1.0
        }]
        for token in token_startword:
            if token in self.dictionary['txt_s2i']:
                a_cand_prev[0]['idx'].append(self.dictionary['txt_s2i'][token])
            else:
                a_cand_prev[0]['idx'].append(
                    self.dictionary['txt_s2i']['<unk>'])
        num_token = len(a_cand_prev[0]['idx'])

        a_out = []
        for i in range(self.dictionary['max_txt_length'] - num_token):
            a_cand = []
            for j in range(len(a_cand_prev)):
                txt_tensor = torch.LongTensor(
                    a_cand_prev[j]['idx']).unsqueeze(0).to(self.device)
                txt_mask = self.model_NLG.make_txt_mask(txt_tensor)
                with torch.no_grad():
                    output, attention = self.model_NLG.decoder(
                        txt_tensor, enc_mr, txt_mask, mr_mask)
                output = torch.softmax(output, dim=-1)
                # expand each hypothesis with its n_beam most probable tokens
                for n in range(self.n_beam):
                    a_cand.append(copy.deepcopy(a_cand_prev[j]))
                    idx = (torch.argsort(output, dim=2)[0, i + offset,
                                                        -(n + 1)]).item()
                    val = output[0, i + offset, idx].item()
                    a_cand[-1]['idx'].append(idx)
                    a_cand[-1]['val'] *= val
            # keep the n_beam best hypotheses; finished ones go to a_out
            a_cand_sort = sorted(a_cand, key=lambda x: x['val'], reverse=True)
            a_cand_prev = []
            nloop = min(len(a_cand_sort), self.n_beam)
            for j in range(nloop):
                if a_cand_sort[j]['idx'][-1] == \
                        self.dictionary['txt_s2i']['<eos>']:
                    a_out.append(a_cand_sort[j])
                    if len(a_out) == self.n_beam:
                        break
                else:
                    a_cand_prev.append(a_cand_sort[j])
            if len(a_out) == self.n_beam:
                break

        # build the reference MR used for NLU-based reranking
        if lex_flag is False:
            ref_mr_token = input_mr_token
        else:
            tmp_mr_list = ''.join(input_mr_token).split('|')
            if tmp_mr_list[0] != '':
                tmp_mr_list[0] = 'NAME'
            if tmp_mr_list[7] != '':
                tmp_mr_list[7] = 'NEAR'
            tmp_mr_obj = {
                'name': tmp_mr_list[0],
                'eatType': tmp_mr_list[1],
                'food': tmp_mr_list[2],
                'priceRange': tmp_mr_list[3],
                'customer rating': tmp_mr_list[4],
                'area': tmp_mr_list[5],
                'familyFriendly': tmp_mr_list[6],
                'near': tmp_mr_list[7]
            }
            ref_mr_token = self.tokenizer.mr(tmp_mr_obj)

        # return the best candidate whose NLU back-translation matches the MR
        flag = False
        for n in range(len(a_out)):
            txt_tokens_tmp = [
                self.dictionary['txt_i2s'][idx] for idx in a_out[n]['idx']
            ]
            nlu_output_token, _ = self.NLU.translate_nlu_greedy_search(
                txt_tokens_tmp[1:-1])
            if nlu_output_token == ref_mr_token:
                txt_tokens = txt_tokens_tmp[1:-1]
                flag = True
                break
        if flag is False:
            if len(a_out) > 0:
                # no candidate matched: fall back to the most probable one
                txt_tokens = [
                    self.dictionary['txt_i2s'][idx] for idx in a_out[0]['idx']
                ]
                txt_tokens = txt_tokens[1:-1]
            else:
                # no hypothesis finished: fall back to greedy search
                txt_tokens, attention = self.translate_nlg_greedy_search(
                    input_mr_token, startword)
        return txt_tokens, attention

    def translate_nlg(self, input_mr_token, lex_flag, startword=''):
        self.model_NLG.eval()

        # encode
        enc_mr, mr_mask = self.translate_nlg_encode(input_mr_token)

        # decode: first try greedy search
        token_startword = self.tokenizer.txt(startword)
        offset = len(token_startword)
        txt_indexes = [self.dictionary['txt_s2i']['<sos>']]
        for token in token_startword:
            if token in self.dictionary['txt_s2i']:
                txt_indexes.append(self.dictionary['txt_s2i'][token])
            else:
                txt_indexes.append(self.dictionary['txt_s2i']['<unk>'])
        num_token = len(txt_indexes)
        for i in range(self.dictionary['max_txt_length'] - num_token):
            txt_tensor = torch.LongTensor(txt_indexes).unsqueeze(0).to(
                self.device)
            txt_mask = self.model_NLG.make_txt_mask(txt_tensor)
            with torch.no_grad():
                output, attention = self.model_NLG.decoder(
                    txt_tensor, enc_mr, txt_mask, mr_mask)
            pred_token = output.argmax(2)[:, -1].item()
            txt_indexes.append(pred_token)
            if pred_token == self.dictionary['txt_s2i']['<eos>']:
                break
        txt_tokens_greedy = [
            self.dictionary['txt_i2s'][i] for i in txt_indexes
        ]
        attention_greedy = attention

        # verify the greedy output by back-translating it with the NLU model
        nlu_output_token, _ = self.NLU.translate_nlu_greedy_search(
            txt_tokens_greedy[1:-1])
        if lex_flag is False:
            ref_mr_token = input_mr_token
        else:
            tmp_mr_list = ''.join(input_mr_token).split('|')
            if tmp_mr_list[0] != '':
                tmp_mr_list[0] = 'NAME'
            if tmp_mr_list[7] != '':
                tmp_mr_list[7] = 'NEAR'
            tmp_mr_obj = {
                'name': tmp_mr_list[0],
                'eatType': tmp_mr_list[1],
                'food': tmp_mr_list[2],
                'priceRange': tmp_mr_list[3],
                'customer rating': tmp_mr_list[4],
                'area': tmp_mr_list[5],
                'familyFriendly': tmp_mr_list[6],
                'near': tmp_mr_list[7]
            }
            ref_mr_token = self.tokenizer.mr(tmp_mr_obj)

        if nlu_output_token == ref_mr_token:
            # the greedy output already realises the input MR
            txt_tokens = txt_tokens_greedy[1:-1]
            attention = attention_greedy
        else:
            # otherwise rerank beam-search candidates with the NLU model
            a_cand_prev = [{
                'idx': [self.dictionary['txt_s2i']['<sos>']],
                'val': 1.0
            }]
            for token in token_startword:
                if token in self.dictionary['txt_s2i']:
                    a_cand_prev[0]['idx'].append(
                        self.dictionary['txt_s2i'][token])
                else:
                    a_cand_prev[0]['idx'].append(
                        self.dictionary['txt_s2i']['<unk>'])
            num_token = len(a_cand_prev[0]['idx'])

            a_out = []
            for i in range(self.dictionary['max_txt_length'] - num_token):
                a_cand = []
                for j in range(len(a_cand_prev)):
                    txt_tensor = torch.LongTensor(
                        a_cand_prev[j]['idx']).unsqueeze(0).to(self.device)
                    txt_mask = self.model_NLG.make_txt_mask(txt_tensor)
                    with torch.no_grad():
                        output, attention = self.model_NLG.decoder(
                            txt_tensor, enc_mr, txt_mask, mr_mask)
                    output = torch.softmax(output, dim=-1)
                    for n in range(self.n_beam):
                        a_cand.append(copy.deepcopy(a_cand_prev[j]))
                        idx = (torch.argsort(output, dim=2)[0, i + offset,
                                                            -(n + 1)]).item()
                        val = output[0, i + offset, idx].item()
                        a_cand[-1]['idx'].append(idx)
                        a_cand[-1]['val'] *= val
                a_cand_sort = sorted(
                    a_cand, key=lambda x: x['val'], reverse=True)
                a_cand_prev = []
                nloop = min(len(a_cand_sort), self.n_beam)
                for j in range(nloop):
                    if a_cand_sort[j]['idx'][-1] == \
                            self.dictionary['txt_s2i']['<eos>']:
                        a_out.append(a_cand_sort[j])
                        if len(a_out) == self.n_beam:
                            break
                    else:
                        a_cand_prev.append(a_cand_sort[j])
                if len(a_out) == self.n_beam:
                    break

            flag = False
            for n in range(len(a_out)):
                txt_tokens_tmp = [
                    self.dictionary['txt_i2s'][idx] for idx in a_out[n]['idx']
                ]
                nlu_output_token, _ = self.NLU.translate_nlu_greedy_search(
                    txt_tokens_tmp[1:-1])
                if nlu_output_token == ref_mr_token:
                    txt_tokens = txt_tokens_tmp[1:-1]
                    flag = True
                    break
            if flag is False:
                # no candidate matched: fall back to the greedy output
                txt_tokens = txt_tokens_greedy[1:-1]
                attention = attention_greedy
        return txt_tokens, attention
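
# ---------------------------------------------------------------------------
# Hypothetical usage sketch of the NLG class above. The parameter directory,
# model filenames, and MR values below are placeholders, not the repository's
# actual paths; the real entry-point script may differ.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    mr_obj = {
        'name': 'The Eagle',
        'eatType': 'coffee shop',
        'food': 'French',
        'priceRange': 'moderate',
        'customer rating': '3 out of 5',
        'area': 'riverside',
        'familyFriendly': 'yes',
        'near': 'Burger King'
    }
    nlg = NLG('./NLG_model', 'model.pt', 'nltk',
              NLU_param_dir='./NLU_model', NLU_model_fname='model.pt')
    # beam search with NLU reranking; name/near are delexicalised internally
    output_txt, attention = nlg.convert_nlg(mr_obj, 'beam', lex_flag=True)
    print(output_txt)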
print('TXT: ' + obj_txt['txt'])
#print(obj_txt)

# dump the generated text
with open(args.o, 'w', encoding='utf-8') as f:
    json.dump(obj_txt, f, ensure_ascii=False, indent=4, sort_keys=False)
print('** done **')

# attention visualisation: delexicalise the text/MR before tokenising them
tokenizer = Tokenizer('nltk', '../../tokenizer/e2e.model')
txt = nlg_txt
mr = mr_obj
if read_mr_obj['name'] != '':
    txt = txt.replace(read_mr_obj['name'], 'NAME')
    mr['name'] = 'NAME'
if read_mr_obj['near'] != '':
    txt = txt.replace(read_mr_obj['near'], 'NEAR')
    mr['near'] = 'NEAR'
mr_token = tokenizer.mr(mr)
txt_token = tokenizer.txt(txt)
'''
print(mr_token)
print(str(len(mr_token)))
print(txt_token)
print(str(len(txt_token)))
print(attention_txt.size())
'''
if args.display:
    display_attention(mr_token, txt_token, attention_txt)
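
# ---------------------------------------------------------------------------
# display_attention is called above but not defined in this excerpt; in the
# actual script it would be defined or imported before use. Below is a minimal
# sketch with matplotlib, assuming the attention tensor has shape
# [1, n_heads, txt_len, mr_len]; the repository's actual helper may differ.
# ---------------------------------------------------------------------------
import matplotlib.pyplot as plt


def display_attention(mr_token, txt_token, attention, n_heads=8,
                      n_rows=4, n_cols=2):
    # one heat map per attention head
    assert n_rows * n_cols == n_heads
    fig = plt.figure(figsize=(16, 24))
    for i in range(n_heads):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        # [1, n_heads, txt_len, mr_len] -> [txt_len, mr_len] for head i
        head = attention.squeeze(0)[i].cpu().detach().numpy()
        ax.matshow(head, cmap='bone')
        ax.tick_params(labelsize=8)
        ax.set_xticks(range(len(mr_token)))
        ax.set_xticklabels(mr_token, rotation=90)
        ax.set_yticks(range(len(txt_token)))
        ax.set_yticklabels(txt_token)
    plt.show()
    plt.close()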
# build a lookup of texts already present in the training data
exist_txt = {}
for obj in a_obj_in_train:
    if obj['txt_lex'] not in exist_txt:
        exist_txt[obj['txt_lex']] = True
del a_obj_in_train

with open('tmp_data.txt', 'r', encoding='utf-8') as fi:
    a_in = fi.readlines()

a_txt_out = []
for line in a_in:
    mr = conv_text2mr(line.rstrip('\n').split('\t')[0])
    txt = line.rstrip('\n').split('\t')[1]
    # skip texts that already exist in the training data
    if txt in exist_txt:
        continue
    # skip texts that exceed the token-length limit
    # (-2 presumably reserves room for <sos>/<eos>)
    a_token_txt = tokenizer.txt(txt)
    if len(a_token_txt) > dictionary['txt_max_num_token'] - 2:
        continue
    a_txt_out.append({'mr': mr, 'txt': txt})
del a_in

# remove duplicated data
a_txt_out = remove_duplication(a_txt_out)

# dump augmented data
with open('e2e_train_aug.json', 'w', encoding='utf-8') as fo:
    json.dump(a_txt_out, fo, ensure_ascii=False, indent=4, sort_keys=False)
print('** done **')
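
# ---------------------------------------------------------------------------
# remove_duplication is used above but not defined in this excerpt; in the
# actual script it would be defined or imported before use. One plausible
# implementation, assuming each entry is a JSON-serialisable dict:
# ---------------------------------------------------------------------------
def remove_duplication(a_obj):
    # deduplicate by the canonical JSON serialisation of each entry,
    # which also handles dict-valued 'mr' fields
    seen = set()
    a_uniq = []
    for obj in a_obj:
        key = json.dumps(obj, ensure_ascii=False, sort_keys=True)
        if key not in seen:
            seen.add(key)
            a_uniq.append(obj)
    return a_uniq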
import copy
import json

import torch

# NOTE: Tokenizer is a project-local class; import it according to the
# repository layout.


class NLU:
    def __init__(self, param_dir, model_fname, tokenizer):
        self.tokenizer = Tokenizer(tokenizer, '../tokenizer/e2e.model')
        self.tokenizer_mode = tokenizer
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # load the trained NLU model and its dictionary
        saved_data = torch.load(param_dir.rstrip('/') + '/' + model_fname)
        self.model = saved_data['model']
        with open(param_dir.rstrip('/') + '/dictionary.json', 'r',
                  encoding='utf-8') as f:
            self.dictionary = json.load(f)

        # beam-search settings
        self.n_beam = 5

    def convert_nlu(self, input_txt, search):
        def _shape_mr(output_token):
            # join tokens into the '|'-separated MR string
            if self.tokenizer_mode == 'sentencepiece':
                output_data = ''.join(output_token).replace(
                    '▁', ' ').lstrip(' ').split('|')
                output_mr = '|'.join(data.lstrip(' ') for data in output_data)
            else:
                output_mr = ''
                for i in range(len(output_token)):
                    # no space right after a '|' separator or at the start
                    if (i > 0 and output_mr != '' and output_mr[-1] != '|'
                            and output_token[i] != '|'
                            and output_token[i] != ''):
                        output_mr += ' '
                    output_mr += output_token[i]
            return output_mr

        input_token = self.tokenizer.txt(input_txt)
        if search == 'greedy':
            output_token, attention = self.translate_nlu_greedy_search(
                input_token)
        else:
            output_token, attention = self.translate_nlu_beam_search(
                input_token)
        output_mr = _shape_mr(output_token)
        return output_mr, attention

    def translate_nlu_encode(self, a_token_txt):
        # map text tokens to vocabulary indices (<unk> for OOV tokens)
        txt_indexes = []
        for token in a_token_txt:
            if token in self.dictionary['txt_s2i']:
                txt_indexes.append(self.dictionary['txt_s2i'][token])
            else:
                txt_indexes.append(self.dictionary['txt_s2i']['<unk>'])
        txt_tensor = torch.LongTensor(txt_indexes).unsqueeze(0).to(self.device)
        txt_mask = self.model.make_txt_mask(txt_tensor)
        with torch.no_grad():
            enc_txt = self.model.encoder(txt_tensor, txt_mask)
        return enc_txt, txt_mask

    def translate_nlu_greedy_search(self, a_token_txt):
        self.model.eval()

        # encode
        enc_txt, txt_mask = self.translate_nlu_encode(a_token_txt)

        # decode
        mr_indexes = [self.dictionary['mr_s2i']['<sos>']]
        for i in range(self.dictionary['max_mr_length'] - 1):
            mr_tensor = torch.LongTensor(mr_indexes).unsqueeze(0).to(
                self.device)
            mr_mask = self.model.make_mr_mask(mr_tensor)
            with torch.no_grad():
                output, attention = self.model.decoder(
                    mr_tensor, enc_txt, mr_mask, txt_mask)
            pred_token = output.argmax(2)[:, -1].item()
            mr_indexes.append(pred_token)
            if pred_token == self.dictionary['mr_s2i']['<eos>']:
                break
        mr_tokens = [self.dictionary['mr_i2s'][i] for i in mr_indexes]
        return mr_tokens[1:-1], attention

    def translate_nlu_beam_search(self, a_token_txt):
        self.model.eval()

        # encode
        enc_txt, txt_mask = self.translate_nlu_encode(a_token_txt)

        # decode
        a_cand_prev = [{
            'idx': [self.dictionary['mr_s2i']['<sos>']],
            'val': 1.0
        }]
        a_out = []
        for i in range(self.dictionary['max_mr_length'] - 1):
            a_cand = []
            for j in range(len(a_cand_prev)):
                mr_tensor = torch.LongTensor(
                    a_cand_prev[j]['idx']).unsqueeze(0).to(self.device)
                mr_mask = self.model.make_mr_mask(mr_tensor)
                with torch.no_grad():
                    output, attention = self.model.decoder(
                        mr_tensor, enc_txt, mr_mask, txt_mask)
                output = torch.softmax(output, dim=-1)
                # expand each hypothesis with its n_beam most probable tokens
                for n in range(self.n_beam):
                    a_cand.append(copy.deepcopy(a_cand_prev[j]))
                    idx = (torch.argsort(output, dim=2)[0, i, -(n + 1)]).item()
                    val = output[0, i, idx].item()
                    a_cand[-1]['idx'].append(idx)
                    a_cand[-1]['val'] *= val
            # keep the n_beam best hypotheses; finished ones go to a_out
            a_cand_sort = sorted(a_cand, key=lambda x: x['val'], reverse=True)
            a_cand_prev = []
            nloop = min(len(a_cand_sort), self.n_beam)
            for j in range(nloop):
                if a_cand_sort[j]['idx'][-1] == \
                        self.dictionary['mr_s2i']['<eos>']:
                    # accept only hypotheses covering all eight MR slots
                    # (i.e. exactly seven '|' separators)
                    if a_cand_sort[j]['idx'].count(
                            self.dictionary['mr_s2i']['|']) == 7:
                        a_out.append(a_cand_sort[j])
                        if len(a_out) == self.n_beam:
                            break
                else:
                    a_cand_prev.append(a_cand_sort[j])
            if len(a_out) == self.n_beam:
                break

        if len(a_out) > 0:
            mr_tokens = [
                self.dictionary['mr_i2s'][idx] for idx in a_out[0]['idx']
            ]
            mr_tokens = mr_tokens[1:-1]  # strip <sos> and <eos>
        else:
            # no complete hypothesis found: fall back to greedy search
            mr_tokens, attention = self.translate_nlu_greedy_search(
                a_token_txt)
        return mr_tokens, attention
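
# ---------------------------------------------------------------------------
# Hypothetical usage sketch of the NLU class above. The parameter directory
# and model filename are placeholders; the returned MR is the eight slots
# joined by '|', as produced by _shape_mr.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    nlu = NLU('./NLU_model', 'model.pt', 'nltk')
    input_txt = ('The Eagle is a moderately priced family friendly French '
                 'coffee shop near Burger King in the riverside area.')
    output_mr, attention = nlu.convert_nlu(input_txt, 'beam')
    print(output_mr)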