def test_worker(self, _nbatch):
    '''
    Beam-search decoding over the test set. Writes '<sec>'-joined
    (output, reference) pairs to nats_results/<task_key>.txt.
    Assumes module-level imports of os and time, plus the minibatch
    and beam-search helpers used below.
    '''
    start_time = time.time()
    fout = open(os.path.join(
        self.args.data_dir, 'nats_results',
        self.args.task_key + '.txt'), 'w')
    for batch_id in range(_nbatch):
        if self.args.oov_explicit:
            # Extended vocabulary: OOV words get temporary ids in ext_id2oov.
            ext_id2oov, src_var, src_var_ex, src_arr, src_msk, sum_arr, ttl_arr \
                = process_minibatch_explicit_test(
                    batch_id=batch_id,
                    path_=self.args.data_dir,
                    batch_size=self.args.test_batch_size,
                    vocab2id=self.batch_data['vocab2id'],
                    src_lens=self.args.src_seq_lens)
            src_msk = src_msk.to(self.args.device)
            src_var = src_var.to(self.args.device)
            src_var_ex = src_var_ex.to(self.args.device)
        else:
            src_var, src_arr, src_msk, sum_arr, ttl_arr \
                = process_minibatch_test(
                    batch_id=batch_id,
                    path_=self.args.data_dir,
                    batch_size=self.args.test_batch_size,
                    vocab2id=self.batch_data['vocab2id'],
                    src_lens=self.args.src_seq_lens)
            src_msk = src_msk.to(self.args.device)
            src_var = src_var.to(self.args.device)
            src_var_ex = src_var.clone()
            ext_id2oov = {}
        self.batch_data['ext_id2oov'] = ext_id2oov
        curr_batch_size = src_var.size(0)
        # Tile every source sequence beam_size times:
        # (batch, seq_len) -> (batch * beam_size, seq_len).
        src_text_rep = src_var.unsqueeze(1).clone().repeat(
            1, self.args.beam_size, 1
        ).view(-1, src_var.size(1)).to(self.args.device)
        if self.args.oov_explicit:
            src_text_rep_ex = src_var_ex.unsqueeze(1).clone().repeat(
                1, self.args.beam_size, 1
            ).view(-1, src_var_ex.size(1)).to(self.args.device)
        else:
            src_text_rep_ex = src_text_rep.clone()
        # Merge frozen base models with the models being trained.
        models = {}
        for model_name in self.base_models:
            models[model_name] = self.base_models[model_name]
        for model_name in self.train_models:
            models[model_name] = self.train_models[model_name]
        beam_seq, beam_prb, beam_attn_ = fast_beam_search(
            self.args, models, self.batch_data,
            src_text_rep, src_text_rep_ex,
            curr_batch_size, self.args.task_key)
        # Pick the reference text that matches the current task.
        if self.args.task_key == 'title':
            trg_arr = ttl_arr
        if self.args.task_key == 'summary':
            trg_arr = sum_arr
        # Copy unknown words from the source using the attention weights.
        out_arr = word_copy(
            self.args, beam_seq, beam_attn_, src_msk, src_arr,
            curr_batch_size, self.batch_data['id2vocab'],
            self.batch_data['ext_id2oov'])
        for k in range(curr_batch_size):
            fout.write('<sec>'.join([out_arr[k], trg_arr[k]]) + '\n')
        end_time = time.time()
        show_progress(batch_id, _nbatch,
                      str((end_time - start_time) / 3600)[:8] + 'h')
    fout.close()
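# ----------------------------------------------------------------------
# The unsqueeze/repeat/view idiom in test_worker tiles every source
# sequence beam_size times so all beams can be decoded as one batch.
# A minimal standalone sketch of that replication (PyTorch; the tensor
# sizes are made up for illustration, not taken from the code above):

import torch

batch_size, seq_len, beam_size = 2, 5, 3
src_var = torch.arange(batch_size * seq_len).view(batch_size, seq_len)

# (batch, seq) -> (batch, 1, seq) -> (batch, beam, seq) -> (batch*beam, seq)
src_text_rep = src_var.unsqueeze(1).repeat(1, beam_size, 1).view(-1, seq_len)

assert src_text_rep.shape == (batch_size * beam_size, seq_len)
# Rows 0 .. beam_size-1 are identical copies of the first source sequence.
assert torch.equal(src_text_rep[0], src_text_rep[beam_size - 1])
# ----------------------------------------------------------------------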
def app_worker(self):
    '''
    Beam-search decoding in the application setting. For every
    *_in.json file, generates a title and a summary (tokens plus
    per-source-token attention) and writes the result to *_out.json.
    Assumes module-level imports of glob, json, os, and re, and a
    loaded spaCy pipeline bound to `nlp`.
    '''
    files_ = glob.glob(os.path.join(self.args.app_data_dir, '*_in.json'))
    for curr_file in files_:
        print('Read {}.'.format(curr_file))
        # Output file shares the input's prefix: foo_in.json -> foo_out.json.
        # (The original split on '_' breaks when the path itself contains
        # underscores, so strip the known suffix instead.)
        fTmp = curr_file[:-len('_in.json')]
        fp = open(curr_file, 'r')
        data_input = json.load(fp)
        fp.close()
        # Tokenize the article with spaCy and drop empty tokens.
        article = nlp(data_input['content'].lower())
        article = ' '.join([wd.text for wd in article])
        article = re.split(r'\s', article)
        article = list(filter(None, article))
        data_input['content_token'] = article
        self.args.src_seq_lens = len(article)
        ext_id2oov, src_var, src_var_ex, src_arr, src_msk = \
            process_data_app(
                data_input, self.batch_data['vocab2id'],
                self.args.src_seq_lens)
        self.batch_data['ext_id2oov'] = ext_id2oov
        src_msk = src_msk.to(self.args.device)
        curr_batch_size = src_var.size(0)
        # Tile the source sequence beam_size times, as in test_worker.
        src_text_rep = src_var.unsqueeze(1).clone().repeat(
            1, self.args.beam_size, 1
        ).view(-1, src_var.size(1)).to(self.args.device)
        if self.args.oov_explicit:
            src_text_rep_ex = src_var_ex.unsqueeze(1).clone().repeat(
                1, self.args.beam_size, 1
            ).view(-1, src_var_ex.size(1)).to(self.args.device)
        else:
            src_text_rep_ex = src_text_rep.clone()
        # The title and summary blocks were verbatim duplicates; run both
        # tasks through the same loop instead.
        for task_key in ['title', 'summary']:
            self.args.task_key = task_key
            beam_seq, beam_prb, beam_attn_ = fast_beam_search(
                self.args, self.base_models, self.batch_data,
                src_text_rep, src_text_rep_ex,
                curr_batch_size, self.args.task_key)
            # Attention of the best beam over the source tokens, mapped
            # to RGB percentages for visualization.
            beam_out = beam_attn_[:, :, 0].squeeze(
            )[:, :self.args.src_seq_lens].data.cpu().numpy()
            beam_out = self.attnWeight2rgbPercent(beam_out)
            trg_arr = word_copy(
                self.args, beam_seq, beam_attn_, src_msk, src_arr,
                curr_batch_size, self.batch_data['id2vocab'],
                self.batch_data['ext_id2oov'])
            trg_arr = re.split(r'\s', trg_arr[0])
            # Keep generated tokens up to <stop>, skipping <s> and </s>,
            # and attach each token's attention vector.
            out_arr = []
            for idx, wd in enumerate(trg_arr):
                if wd == '<stop>':
                    break
                if wd != '<s>' and wd != '</s>':
                    out_arr.append({
                        'key': wd,
                        'attention': beam_out[idx].tolist()
                    })
            data_input[self.args.task_key] = out_arr
        print('Write {}.'.format(fTmp + '_out.json'))
        fout = open(fTmp + '_out.json', 'w')
        json.dump(data_input, fout)
        fout.close()
        os.unlink(curr_file)
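# ----------------------------------------------------------------------
# Each *_out.json written by app_worker carries, under 'title' and
# 'summary', a list of {"key": token, "attention": [...]} records, with
# the attention vector produced by attnWeight2rgbPercent. A small sketch
# of reading one back (the file name here is hypothetical):

import json

with open('article_out.json', 'r') as fp:
    result = json.load(fp)

# Reassemble the generated title and inspect the attention the first
# summary token placed on the leading source tokens.
title_text = ' '.join(item['key'] for item in result['title'])
print(title_text)
print(result['summary'][0]['attention'][:5])
# ----------------------------------------------------------------------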