Example #1
import copy
import json

import torch

# Tokenizer and NLU are assumed to be imported from project-local modules
# (their definitions are not shown in this example).
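
# NLG model for E2E-style meaning representations (MRs): generates a text
# realisation from an MR with a seq2seq encoder-decoder, optionally re-checked
# for semantic consistency by an auxiliary NLU model.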
class NLG():
    def __init__(self,
                 NLG_param_dir,
                 NLG_model_fname,
                 tokenizer,
                 NLU_param_dir=None,
                 NLU_model_fname=None):
        self.tokenizer = Tokenizer(tokenizer, '../tokenizer/e2e.model')
        self.tokenizer_mode = tokenizer
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        saved_data = torch.load(
            NLG_param_dir.rstrip('/') + '/' + NLG_model_fname)
        self.model_NLG = saved_data['model']
        with open(NLG_param_dir.rstrip('/') + '/dictionary.json',
                  'r',
                  encoding='utf-8') as f:
            self.dictionary = json.load(f)

        # beam-search settings
        self.n_beam = 5

        # Optional NLU model used to verify the semantic fidelity of generated text
        if (NLU_param_dir is not None) and (NLU_model_fname is not None):
            self.NLU = NLU(NLU_param_dir, NLU_model_fname, tokenizer)
        else:
            self.NLU = None

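    # Generate text for the given MR: optionally swap name/near for NAME/NEAR
    # placeholders, tokenise the MR, decode with the requested search strategy
    # ('greedy', 'beam', or the combined default), then detokenise and
    # re-lexicalise the output.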
    def convert_nlg(self, input_mr_obj, search, lex_flag, startword=''):
        def _shape_txt(input_mr_obj, output_token, lex_flag):
            if self.tokenizer_mode == 'sentencepiece':
                output_txt = ''.join(output_token).replace('▁', ' ')
                output_txt = output_txt.lstrip(' ')
            else:
                output_txt = ''
                for i in range(len(output_token)):
                    if (i > 0) and (output_token[i] != '.') and (
                            output_token[i] != ',') and (output_token[i][0] !=
                                                         '\''):
                        output_txt += ' '
                    output_txt += output_token[i]
            # Lexicalisation
            if lex_flag is True:
                output_txt = output_txt.replace('NAME', input_mr_obj['name'])
                output_txt = output_txt.replace('NEAR', input_mr_obj['near'])
            return output_txt

        input_mr_obj_org = copy.deepcopy(input_mr_obj)
        if lex_flag is True:
            if input_mr_obj['name'] != '':
                input_mr_obj['name'] = 'NAME'
            if input_mr_obj['near'] != '':
                input_mr_obj['near'] = 'NEAR'
        input_mr_token = self.tokenizer.mr(input_mr_obj)
        if search == 'greedy':
            output_txt_token, attention = self.translate_nlg_greedy_search(
                input_mr_token, startword)
        elif search == 'beam':
            output_txt_token, attention = self.translate_nlg_beam_search(
                input_mr_token, lex_flag, startword)
        else:
            output_txt_token, attention = self.translate_nlg(
                input_mr_token, lex_flag, startword)
        output_txt = _shape_txt(input_mr_obj_org, output_txt_token, lex_flag)
        return output_txt, attention

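    # Map the MR tokens to vocabulary indices (<unk> for unknown tokens), run
    # the encoder once, and return the encoded MR together with its mask.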
    def translate_nlg_encode(self, input_mr_token):
        mr_indexes = []
        for token in input_mr_token:
            if token in self.dictionary['mr_s2i']:
                mr_indexes.append(self.dictionary['mr_s2i'][token])
            else:
                mr_indexes.append(self.dictionary['mr_s2i']['<unk>'])
        mr_tensor = torch.LongTensor(mr_indexes).unsqueeze(0).to(self.device)
        mr_mask = self.model_NLG.make_mr_mask(mr_tensor)
        with torch.no_grad():
            enc_mr = self.model_NLG.encoder(mr_tensor, mr_mask)
        return enc_mr, mr_mask

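    # Greedy decoding: starting from <sos> (plus any startword tokens), append
    # the most probable next token until <eos> or max_txt_length is reached.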
    def translate_nlg_greedy_search(self, input_mr_token, startword=''):
        self.model_NLG.eval()

        ## encode
        enc_mr, mr_mask = self.translate_nlg_encode(input_mr_token)

        ## decode
        # startword
        token_startword = self.tokenizer.txt(startword)

        txt_indexes = [self.dictionary['txt_s2i']['<sos>']]
        for token in token_startword:
            if token in self.dictionary['txt_s2i']:
                txt_indexes.append(self.dictionary['txt_s2i'][token])
            else:
                txt_indexes.append(self.dictionary['txt_s2i']['<unk>'])

        num_token = len(txt_indexes)
        for i in range(self.dictionary['max_txt_length'] - num_token):
            txt_tensor = torch.LongTensor(txt_indexes).unsqueeze(0).to(
                self.device)
            txt_mask = self.model_NLG.make_txt_mask(txt_tensor)
            with torch.no_grad():
                output, attention = self.model_NLG.decoder(
                    txt_tensor, enc_mr, txt_mask, mr_mask)

            pred_token = output.argmax(2)[:, -1].item()
            txt_indexes.append(pred_token)

            if pred_token == self.dictionary['txt_s2i']['<eos>']:
                break
        txt_tokens = [self.dictionary['txt_i2s'][i] for i in txt_indexes]
        txt_tokens = txt_tokens[1:-1]

        return txt_tokens, attention

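    # Beam search with width self.n_beam: finished hypotheses are parsed back
    # into an MR by the auxiliary NLU model, and the first candidate whose
    # predicted MR matches the input MR is returned; otherwise the top-scoring
    # hypothesis (or a greedy fallback) is used. Requires self.NLU to be set.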
    def translate_nlg_beam_search(self,
                                  input_mr_token,
                                  lex_flag,
                                  startword=''):
        self.model_NLG.eval()

        ## encode
        enc_mr, mr_mask = self.translate_nlg_encode(input_mr_token)

        ## decode
        # startword
        token_startword = self.tokenizer.txt(startword)
        offset = len(token_startword)

        a_cand_prev = [{
            'idx': [self.dictionary['txt_s2i']['<sos>']],
            'val': 1.0
        }]
        for token in token_startword:
            if token in self.dictionary['txt_s2i']:
                a_cand_prev[0]['idx'].append(self.dictionary['txt_s2i'][token])
            else:
                a_cand_prev[0]['idx'].append(
                    self.dictionary['txt_s2i']['<unk>'])
        num_token = len(a_cand_prev[0]['idx'])
        a_out = []
        for i in range(self.dictionary['max_txt_length'] - num_token):
            a_cand = []
            for j in range(len(a_cand_prev)):
                txt_tensor = torch.LongTensor(
                    a_cand_prev[j]['idx']).unsqueeze(0).to(self.device)
                txt_mask = self.model_NLG.make_txt_mask(txt_tensor)
                with torch.no_grad():
                    output, attention = self.model_NLG.decoder(
                        txt_tensor, enc_mr, txt_mask, mr_mask)
                    output = torch.softmax(output, dim=-1)
                for n in range(self.n_beam):
                    a_cand.append(copy.deepcopy(a_cand_prev[j]))
                    idx = torch.argsort(
                        output, dim=2)[0, i + offset, -(n + 1)].item()
                    val = output[0, i + offset, idx].item()
                    a_cand[len(a_cand) - 1]['idx'].append(idx)
                    a_cand[len(a_cand) - 1]['val'] *= val

            a_cand_sort = sorted(a_cand, key=lambda x: x['val'], reverse=True)
            a_cand_prev = []
            nloop = min(len(a_cand_sort), self.n_beam)
            for j in range(nloop):
                if a_cand_sort[j]['idx'][
                        len(a_cand_sort[j]['idx']) -
                        1] == self.dictionary['txt_s2i']['<eos>']:
                    a_out.append(a_cand_sort[j])
                    if len(a_out) == self.n_beam:
                        break
                else:
                    a_cand_prev.append(a_cand_sort[j])
            if len(a_out) == self.n_beam:
                break

        if lex_flag is False:
            ref_mr_token = input_mr_token
        else:
            tmp_mr_text = ''
            for token in input_mr_token:
                tmp_mr_text += token
            tmp_mr_list = tmp_mr_text.split('|')
            if tmp_mr_list[0] != '':
                tmp_mr_list[0] = 'NAME'
            if tmp_mr_list[7] != '':
                tmp_mr_list[7] = 'NEAR'
            tmp_mr_obj = {
                'name': tmp_mr_list[0],
                'eatType': tmp_mr_list[1],
                'food': tmp_mr_list[2],
                'priceRange': tmp_mr_list[3],
                'customer rating': tmp_mr_list[4],
                'area': tmp_mr_list[5],
                'familyFriendly': tmp_mr_list[6],
                'near': tmp_mr_list[7]
            }
            ref_mr_token = self.tokenizer.mr(tmp_mr_obj)

        flag = False
        for n in range(len(a_out)):
            txt_tokens_tmp = [
                self.dictionary['txt_i2s'][idx] for idx in a_out[n]['idx']
            ]
            nlu_output_token, _ = self.NLU.translate_nlu_greedy_search(
                txt_tokens_tmp[1:-1])
            if nlu_output_token == ref_mr_token:
                txt_tokens = txt_tokens_tmp[1:-1]
                flag = True
                break
        if flag is False:
            if len(a_out) > 0:
                txt_tokens = [
                    self.dictionary['txt_i2s'][idx] for idx in a_out[0]['idx']
                ]
                txt_tokens = txt_tokens[1:-1]
            else:
                txt_tokens, attention = self.translate_nlg_greedy_search(
                    input_mr_token, startword)
        return txt_tokens, attention

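    # Combined strategy: run greedy decoding first and accept it if the NLU
    # round-trip reproduces the input MR; otherwise fall back to beam search
    # with the same consistency check, and finally to the greedy output.
    # Requires self.NLU to be set.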
    def translate_nlg(self, input_mr_token, lex_flag, startword=''):
        self.model_NLG.eval()

        ## encode
        enc_mr, mr_mask = self.translate_nlg_encode(input_mr_token)

        ## decode
        # startword
        token_startword = self.tokenizer.txt(startword)
        offset = len(token_startword)

        # greedy search
        txt_indexes = [self.dictionary['txt_s2i']['<sos>']]
        for token in token_startword:
            if token in self.dictionary['txt_s2i']:
                txt_indexes.append(self.dictionary['txt_s2i'][token])
            else:
                txt_indexes.append(self.dictionary['txt_s2i']['<unk>'])

        num_token = len(txt_indexes)
        for i in range(self.dictionary['max_txt_length'] - num_token):
            txt_tensor = torch.LongTensor(txt_indexes).unsqueeze(0).to(
                self.device)
            txt_mask = self.model_NLG.make_txt_mask(txt_tensor)

            with torch.no_grad():
                output, attention = self.model_NLG.decoder(
                    txt_tensor, enc_mr, txt_mask, mr_mask)

            pred_token = output.argmax(2)[:, -1].item()
            txt_indexes.append(pred_token)

            if pred_token == self.dictionary['txt_s2i']['<eos>']:
                break
        txt_tokens_greedy = [
            self.dictionary['txt_i2s'][i] for i in txt_indexes
        ]
        attention_greedy = attention

        nlu_output_token, _ = self.NLU.translate_nlu_greedy_search(
            txt_tokens_greedy[1:-1])
        if lex_flag is False:
            ref_mr_token = input_mr_token
        else:
            tmp_mr_text = ''
            for token in input_mr_token:
                tmp_mr_text += token
            tmp_mr_list = tmp_mr_text.split('|')
            if tmp_mr_list[0] != '':
                tmp_mr_list[0] = 'NAME'
            if tmp_mr_list[7] != '':
                tmp_mr_list[7] = 'NEAR'
            tmp_mr_obj = {
                'name': tmp_mr_list[0],
                'eatType': tmp_mr_list[1],
                'food': tmp_mr_list[2],
                'priceRange': tmp_mr_list[3],
                'customer rating': tmp_mr_list[4],
                'area': tmp_mr_list[5],
                'familyFriendly': tmp_mr_list[6],
                'near': tmp_mr_list[7]
            }
            ref_mr_token = self.tokenizer.mr(tmp_mr_obj)

        if nlu_output_token == ref_mr_token:
            txt_tokens = txt_tokens_greedy[1:-1]
            attention = attention_greedy
        else:
            a_cand_prev = [{
                'idx': [self.dictionary['txt_s2i']['<sos>']],
                'val': 1.0
            }]
            for token in token_startword:
                if token in self.dictionary['txt_s2i']:
                    a_cand_prev[0]['idx'].append(
                        self.dictionary['txt_s2i'][token])
                else:
                    a_cand_prev[0]['idx'].append(
                        self.dictionary['txt_s2i']['<unk>'])
            num_token = len(a_cand_prev[0]['idx'])
            a_out = []
            for i in range(self.dictionary['max_txt_length'] - num_token):
                a_cand = []
                for j in range(len(a_cand_prev)):
                    txt_tensor = torch.LongTensor(
                        a_cand_prev[j]['idx']).unsqueeze(0).to(self.device)
                    txt_mask = self.model_NLG.make_txt_mask(txt_tensor)
                    with torch.no_grad():
                        output, attention = self.model_NLG.decoder(
                            txt_tensor, enc_mr, txt_mask, mr_mask)
                        output = torch.softmax(output, dim=-1)
                    for n in range(self.n_beam):
                        a_cand.append(copy.deepcopy(a_cand_prev[j]))
                        idx = torch.argsort(
                            output, dim=2)[0, i + offset, -(n + 1)].item()
                        val = output[0, i + offset, idx].item()
                        a_cand[len(a_cand) - 1]['idx'].append(idx)
                        a_cand[len(a_cand) - 1]['val'] *= val

                a_cand_sort = sorted(a_cand,
                                     key=lambda x: x['val'],
                                     reverse=True)
                a_cand_prev = []
                nloop = min(len(a_cand_sort), self.n_beam)
                for j in range(nloop):
                    if a_cand_sort[j]['idx'][
                            len(a_cand_sort[j]['idx']) -
                            1] == self.dictionary['txt_s2i']['<eos>']:
                        a_out.append(a_cand_sort[j])
                        if len(a_out) == self.n_beam:
                            break
                    else:
                        a_cand_prev.append(a_cand_sort[j])
                if len(a_out) == self.n_beam:
                    break

            flag = False
            for n in range(len(a_out)):
                txt_tokens_tmp = [
                    self.dictionary['txt_i2s'][idx] for idx in a_out[n]['idx']
                ]
                nlu_output_token, _ = self.NLU.translate_nlu_greedy_search(
                    txt_tokens_tmp[1:-1])
                if nlu_output_token == ref_mr_token:
                    txt_tokens = txt_tokens_tmp[1:-1]
                    flag = True
                    break

            if flag is False:
                txt_tokens = txt_tokens_greedy[1:-1]
                attention = attention_greedy

        return txt_tokens, attention
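
# Hypothetical usage sketch (not part of the original example): the directory
# names, file names, and MR values below are assumptions, and the saved model
# and dictionary files must already exist for this to run.
if __name__ == '__main__':
    nlg = NLG('../model/nlg', 'nlg_model.pt', 'sentencepiece',
              NLU_param_dir='../model/nlu', NLU_model_fname='nlu_model.pt')
    mr = {'name': 'The Phoenix', 'eatType': 'pub', 'food': 'English',
          'priceRange': 'cheap', 'customer rating': '5 out of 5',
          'area': 'riverside', 'familyFriendly': 'yes', 'near': 'Café Sicilia'}
    output_txt, attention = nlg.convert_nlg(mr, search='beam', lex_flag=True)
    print(output_txt)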
Example #2
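                        # Innermost loop of the MR augmentation: for the current
                        # combination of attribute values (wordA..wordF), add one
                        # MR per 'near' value, skip duplicates via exist_mr, and
                        # track the longest tokenised MR.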
                        for wordG in mr_list['near']:
                            mr = {
                                'name': 'NAME',
                                'eatType': wordA,
                                'food': wordB,
                                'priceRange': wordC,
                                'customer rating': wordD,
                                'area': wordE,
                                'familyFriendly': wordF,
                                'near': wordG
                            }
                            mr_text = conv_mr2text(mr)
                            if mr_text not in exist_mr:
                                a_mr_out.append(mr)
                                exist_mr[mr_text] = ''
                            mr_token = tokenizer.mr(mr)
                            if mr_lex_max_num_token < len(mr_token):
                                mr_lex_max_num_token = len(mr_token)
del exist_mr
with open('e2e_mr_lex_max_num_token.json', 'w', encoding='utf-8') as fo:
    json.dump(mr_lex_max_num_token,
              fo,
              ensure_ascii=False,
              indent=4,
              sort_keys=False)
with open('e2e_train_mr_aug.json', 'w', encoding='utf-8') as fo:
    json.dump(a_mr_out, fo, ensure_ascii=False, indent=4, sort_keys=False)
del a_mr_out
print('** done(MR) **')