예제 #1
0
    def test_worker(self, _nbatch):
        '''
        Beam-search decode the test set and write predictions to disk.

        For each of the ``_nbatch`` test minibatches this loads the batch
        (with explicit OOV ids when ``self.args.oov_explicit`` is set),
        tiles every source row ``self.args.beam_size`` times, decodes with
        ``fast_beam_search`` using the union of ``self.base_models`` and
        ``self.train_models``, runs ``word_copy`` on the result and writes
        one line per example to ``<data_dir>/nats_results/<task_key>.txt``
        in the form ``<prediction><sec><reference>``.

        Args:
            _nbatch: number of test minibatches to decode.

        Raises:
            ValueError: if ``self.args.task_key`` is neither ``'title'``
                nor ``'summary'``; previously this surfaced only as an
                UnboundLocalError after the first beam search had run.
        '''
        task_key = self.args.task_key
        # The reference side of every output line depends only on the task,
        # so validate it once, before any decoding work is done.
        if task_key not in ('title', 'summary'):
            raise ValueError(
                "task_key must be 'title' or 'summary', got {!r}".format(
                    task_key))

        def tile_for_beam(src):
            # (batch, len) -> (batch * beam_size, len): each source row is
            # repeated beam_size times consecutively, then moved to device.
            return src.unsqueeze(1).clone().repeat(
                1, self.args.beam_size, 1).view(
                    -1, src.size(1)).to(self.args.device)

        # The model set is identical for every batch, so build it once.
        # Entries in train_models override base_models entries of the same
        # name (same precedence as the original per-batch loops).
        models = {}
        for model_group in (self.base_models, self.train_models):
            for model_name in model_group:
                models[model_name] = model_group[model_name]

        start_time = time.time()
        out_path = os.path.join(
            self.args.data_dir, 'nats_results', task_key + '.txt')
        # ``with`` guarantees the results file is closed even if decoding
        # raises part-way through the test set.
        with open(out_path, 'w') as fout:
            for batch_id in range(_nbatch):
                if self.args.oov_explicit:
                    # Explicit-OOV batches also return an extended-vocab copy
                    # of the source plus the mapping stored as 'ext_id2oov'.
                    ext_id2oov, src_var, src_var_ex, src_arr, src_msk, \
                        sum_arr, ttl_arr = process_minibatch_explicit_test(
                            batch_id=batch_id, path_=self.args.data_dir,
                            batch_size=self.args.test_batch_size,
                            vocab2id=self.batch_data['vocab2id'],
                            src_lens=self.args.src_seq_lens)
                else:
                    src_var, src_arr, src_msk, sum_arr, ttl_arr = \
                        process_minibatch_test(
                            batch_id=batch_id, path_=self.args.data_dir,
                            batch_size=self.args.test_batch_size,
                            vocab2id=self.batch_data['vocab2id'],
                            src_lens=self.args.src_seq_lens)
                    ext_id2oov = {}
                self.batch_data['ext_id2oov'] = ext_id2oov
                src_msk = src_msk.to(self.args.device)
                src_var = src_var.to(self.args.device)

                curr_batch_size = src_var.size(0)
                src_text_rep = tile_for_beam(src_var)
                if self.args.oov_explicit:
                    src_text_rep_ex = tile_for_beam(src_var_ex)
                else:
                    src_text_rep_ex = src_text_rep.clone()

                beam_seq, _, beam_attn_ = fast_beam_search(
                    self.args, models, self.batch_data,
                    src_text_rep, src_text_rep_ex, curr_batch_size, task_key)
                trg_arr = ttl_arr if task_key == 'title' else sum_arr
                # Copy unknown words from the source text.
                out_arr = word_copy(
                    self.args, beam_seq, beam_attn_, src_msk, src_arr,
                    curr_batch_size, self.batch_data['id2vocab'],
                    self.batch_data['ext_id2oov'])
                for k in range(curr_batch_size):
                    fout.write('<sec>'.join([out_arr[k], trg_arr[k]]) + '\n')

                show_progress(
                    batch_id, _nbatch,
                    str((time.time() - start_time) / 3600)[:8] + "h")
예제 #2
0
    def app_worker(self):
        '''
        Beam-search decode every pending application input file.

        Each ``*_in.json`` file in ``self.args.app_data_dir`` is read.  Its
        ``content`` text is lower-cased, tokenised with ``nlp`` and stored
        back as ``content_token``.  A ``title`` and a ``summary`` are then
        decoded with ``self.base_models``.  Each is stored as a list of
        ``{"key": word, "attention": ...}`` dicts.  The enriched record is
        written to the matching ``*_out.json`` path, and the input file is
        deleted.

        Side effects: sets ``self.args.src_seq_lens`` to the article length
        and ``self.args.task_key`` to each task as it is decoded (so it is
        ``'summary'`` after a file has been processed).
        '''
        in_suffix = '_in.json'

        def decode_task(task_key, src_text_rep, src_text_rep_ex, src_msk,
                        src_arr, curr_batch_size):
            # Decode one task and pair each output word with its attention
            # row; stops at '<stop>' and skips '<s>' / '</s>' tokens.
            self.args.task_key = task_key
            beam_seq, _, beam_attn_ = fast_beam_search(
                self.args, self.base_models, self.batch_data, src_text_rep,
                src_text_rep_ex, curr_batch_size, task_key)
            # Take index 0 on the third axis of beam_attn_ (presumably the
            # top beam -- confirm against fast_beam_search), squeeze, and
            # truncate the last kept axis to the article length.
            beam_out = beam_attn_[:, :, 0].squeeze(
            )[:, :self.args.src_seq_lens].data.cpu().numpy()
            beam_out = self.attnWeight2rgbPercent(beam_out)
            trg_arr = word_copy(self.args, beam_seq, beam_attn_, src_msk,
                                src_arr, curr_batch_size,
                                self.batch_data['id2vocab'],
                                self.batch_data['ext_id2oov'])
            out_arr = []
            # Only the first decoded sequence is used.
            for idx, wd in enumerate(re.split(r'\s', trg_arr[0])):
                if wd == '<stop>':
                    break
                if wd not in ('<s>', '</s>'):
                    out_arr.append({
                        "key": wd,
                        "attention": beam_out[idx].tolist()
                    })
            return out_arr

        files_ = glob.glob(
            os.path.join(self.args.app_data_dir, '*' + in_suffix))
        for curr_file in files_:
            print("Read {}.".format(curr_file))
            # Output path = input path with the '_in.json' suffix replaced.
            # The previous re.split on the first '_' of the *whole* path
            # misnamed / misplaced the output whenever the directory or the
            # file stem contained an underscore.
            out_stem = curr_file[:-len(in_suffix)]
            with open(curr_file, 'r') as fp:
                data_input = json.load(fp)
            tokens = nlp(data_input['content'].lower())
            article = ' '.join([wd.text for wd in tokens])
            article = list(filter(None, re.split(r'\s', article)))
            data_input['content_token'] = article

            self.args.src_seq_lens = len(article)
            ext_id2oov, src_var, src_var_ex, src_arr, src_msk = \
                process_data_app(data_input, self.batch_data['vocab2id'],
                                 self.args.src_seq_lens)
            self.batch_data['ext_id2oov'] = ext_id2oov
            src_msk = src_msk.to(self.args.device)

            curr_batch_size = src_var.size(0)
            # Repeat each source row beam_size times: (B, L) -> (B*beam, L).
            src_text_rep = src_var.unsqueeze(1).clone().repeat(
                1, self.args.beam_size,
                1).view(-1, src_var.size(1)).to(self.args.device)
            if self.args.oov_explicit:
                src_text_rep_ex = src_var_ex.unsqueeze(1).clone().repeat(
                    1, self.args.beam_size,
                    1).view(-1, src_var_ex.size(1)).to(self.args.device)
            else:
                src_text_rep_ex = src_text_rep.clone()

            for task_key in ('title', 'summary'):
                data_input[task_key] = decode_task(
                    task_key, src_text_rep, src_text_rep_ex, src_msk,
                    src_arr, curr_batch_size)

            out_file = out_stem + '_out.json'
            print('Write {}.'.format(out_file))
            with open(out_file, 'w') as fout:
                json.dump(data_input, fout)

            os.unlink(curr_file)