def _get_iterator(self, path):
    """
    Creates an iterator object from a text file.

    Args:
        path (str): path to the text file to process

    Returns:
        data_iter (inputters.OrderedIterator): iterator object
    """
    # Create dataset object
    data = inputters.build_dataset(fields=self.fields,
                                   data_type='text',
                                   src_path=path,
                                   tgt_path=None,
                                   src_dir='',
                                   use_filter_pred=False)
    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self.gpu,
        batch_size=self.similar_pairs.batch_size,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False)
    return data_iter
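A minimal usage sketch for the iterator helper above, assuming an instantiated object `m` that carries the `fields`, `gpu`, and `similar_pairs` state the method expects (the object name and file path are illustrative, not from the source):

    # Hypothetical usage: stream batches from a tokenized text file.
    data_iter = m._get_iterator('data/sentences.txt')
    for batch in data_iter:
        src, src_lengths = batch.src  # text features: (seq_len, batch, n_feats)
        print(batch.batch_size, src_lengths.tolist())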
def build_save_dataset(corpus_type, fields, opt):
    """ Building and saving the dataset """
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        corpus = opt.train_dir
    else:
        corpus = opt.valid_dir

    dataset = inputters.build_dataset(
        fields, data_path=corpus,
        data_type=opt.data_type,
        total_token_length=opt.total_token_length,
        src_seq_length=opt.src_seq_length,
        src_sent_length=opt.src_sent_length,
        seq_length_trunc=opt.seq_length_trunc)

    # We save fields in vocab.pt separately, so make it empty.
    dataset.fields = []

    pt_file = "{:s}.{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)

    return pt_file
def build_save_dataset(corpus_type, fields, src_corpus, tgt_corpus,
                       savepath, args):
    """ Building and saving the dataset """
    assert corpus_type in ['train', 'dev', 'test']

    dataset = inputters.build_dataset(fields,
                                      data_type='text',
                                      src_path=src_corpus,
                                      tgt_path=tgt_corpus,
                                      src_dir='',
                                      src_seq_length=args.max_src_len,
                                      tgt_seq_length=args.max_tgt_len,
                                      src_seq_length_trunc=0,
                                      tgt_seq_length_trunc=0,
                                      dynamic_dict=True)

    # We save fields in vocab.pt separately, so make it empty.
    dataset.fields = []

    # Attach a graph built from each source sequence.
    for i in range(len(dataset)):
        if i % 500 == 0:
            print(i)
        setattr(dataset.examples[i], 'graph',
                myutils.str2graph(dataset.examples[i].src))

    pt_file = "{:s}/{:s}.pt".format(savepath, corpus_type)
    # torch.save(dataset, pt_file)
    with open(pt_file, 'wb') as f:
        pickle.dump(dataset, f)

    return [pt_file]
def generate_vectors(self, list_of_sentences, batch_size=1, cuda=False):
    """
    list_of_sentences: [str]
    batch_size: int
    :return: [np.array] numpy vectors of the sentences, in the same order
    """
    unique_filename = str(uuid.uuid4())

    # Delete accumulated tmp files.
    tmp_files = os.listdir(pjoin(self.temp_dir, "l2e"))
    if len(tmp_files) > 10:
        for f_n in tmp_files:
            os.remove(pjoin(self.temp_dir, "l2e", f_n))

    with open(pjoin(self.temp_dir, "l2e",
                    '{}.txt'.format(unique_filename)), 'w') as f:
        for s in list_of_sentences:
            f.write(s.strip() + '\n')

    data = inputters.build_dataset(
        self.fields,
        src_path=pjoin(self.temp_dir, "l2e",
                       '{}.txt'.format(unique_filename)),
        data_type='text',
        use_filter_pred=False)
    # src_seq_length=50, dynamic_dict=False)

    cur_device = "cuda" if cuda else "cpu"

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    all_vecs = []
    for batch in data_iter:
        # The translation model just translates; here we generate
        # vectors instead.
        src, enc_states, memory_bank, src_lengths = _run_encoder(
            self.model, batch, 'text')
        # enc_states[0]: (layer_size, batch_size, hid)
        reshaped_hid_states = enc_states[0].reshape(
            batch_size, self.model_opt.enc_layers,
            self.model_opt.enc_rnn_size)
        # reshaped_hid_states: (batch_size, layer_size, hid)
        # We only keep the last layer's hidden state.
        all_vecs.append(reshaped_hid_states[:, -1, :].data.cpu().numpy())

    all_vecs = np.vstack(all_vecs)
    return all_vecs
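A hedged usage sketch for `generate_vectors`, assuming a wrapper instance `enc` that already holds the model, fields, and temp-dir state the method uses (the instance name and sentences are illustrative):

    sentences = ["the cat sat on the mat .", "a quick brown fox ."]
    vecs = enc.generate_vectors(sentences, batch_size=1, cuda=False)
    # One encoder hidden-state vector per input sentence, in the original order.
    assert vecs.shape[0] == len(sentences)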
def build_save_dataset(corpus_type, fields, opt):
    assert corpus_type in ['train', 'valid', 'comp']

    if corpus_type == 'train':
        src = opt.train_src
        tgt = opt.train_tgt
    elif corpus_type == 'valid':
        src = opt.valid_src
        tgt = opt.valid_tgt
    else:
        src = opt.comp_train_src
        tgt = opt.comp_train_tgt

    logger.info("Reading source and target files: %s %s." % (src, tgt))

    src_len = _write_temp_shard_files(src, fields, corpus_type,
                                      opt.shard_size)
    tgt_len = _write_temp_shard_files(tgt, fields, corpus_type,
                                      opt.shard_size)
    assert src_len == tgt_len, "Source and target should be the same length"

    src_shards = sorted(glob.glob(src + '.*.txt'))
    tgt_shards = sorted(glob.glob(tgt + '.*.txt'))
    shard_pairs = zip(src_shards, tgt_shards)
    dataset_paths = []

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Building shard %d." % i)
        dataset = inputters.build_dataset(
            fields, opt.data_type,
            src_path=src_shard,
            tgt_path=tgt_shard,
            src_dir=opt.src_dir,
            src_seq_len=opt.src_seq_length,
            tgt_seq_len=opt.tgt_seq_length,
            src_seq_length_trunc=opt.src_seq_length_trunc,
            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
            dynamic_dict=opt.dynamic_dict,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            image_channel_size=opt.image_channel_size
        )

        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s."
                    % (i, corpus_type, data_path))
        dataset.save(data_path)

        os.remove(src_shard)
        os.remove(tgt_shard)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
def build_save_dataset(corpus_type, fields, opt):
    """ Building and saving the dataset """
    assert corpus_type in ['train', 'valid', 'monitor']

    if corpus_type == 'train':
        src_corpus = [opt.train_src]
        tgt_corpus = [opt.train_tgt]
    elif corpus_type == 'valid':
        src_corpus = [opt.valid_src]
        tgt_corpus = [opt.valid_tgt]
    else:
        assert len(opt.monitor_src) == len(opt.monitor_tgt)
        src_corpus = opt.monitor_src
        tgt_corpus = opt.monitor_tgt

    pt_files = []
    for i, (src, tgt) in enumerate(zip(src_corpus, tgt_corpus)):
        if "monitor" in corpus_type:
            fname = src.split("/" if "/" in src else "\\")[-1].split(
                ".")[0].replace("_src", "")
            corpus_type = "monitor_{}".format(fname)

        if opt.shard_size > 0:
            pt_file = build_save_in_shards_using_shards_size(
                src, tgt, fields, corpus_type, opt)
            pt_files.extend(pt_file)
        else:
            # For data_type == 'img' or 'audio', we currently don't do
            # preprocess sharding and only build a monolithic dataset.
            # But since the interfaces are uniform, it would not be hard
            # to add this should users need the feature.
            dataset = inputters.build_dataset(
                fields, opt.data_type,
                src_path=src,
                tgt_path=tgt,
                src_dir=opt.src_dir,
                src_seq_length=opt.src_seq_length,
                tgt_seq_length=opt.tgt_seq_length,
                src_seq_length_trunc=opt.src_seq_length_trunc,
                tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
                dynamic_dict=opt.dynamic_dict,
                sample_rate=opt.sample_rate,
                window_size=opt.window_size,
                window_stride=opt.window_stride,
                window=opt.window,
                image_channel_size=opt.image_channel_size)

            # We save fields in vocab.pt separately, so make it empty.
            dataset.fields = []

            pt_file = "{:s}.{:s}.pt".format(opt.save_data, corpus_type)
            logger.info(" * saving %s dataset to %s."
                        % (corpus_type, pt_file))
            torch.save(dataset, pt_file)
            pt_files.append(pt_file)

    return pt_files
def build_save_dataset(corpus_type, fields, opt):
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src = opt.train_src
        tgt = opt.train_tgt
        ans = opt.train_ans
    else:
        src = opt.valid_src
        tgt = opt.valid_tgt
        ans = opt.valid_ans

    logger.info("Reading source, answer and target files: %s %s %s."
                % (src, ans, tgt))

    src_shards = split_corpus(src, opt.shard_size)
    tgt_shards = split_corpus(tgt, opt.shard_size)
    ans_shards = split_corpus(ans, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards, ans_shards)
    dataset_paths = []

    for i, (src_shard, tgt_shard, ans_shard) in enumerate(shard_pairs):
        assert len(src_shard) == len(tgt_shard) == len(ans_shard)
        logger.info("Building shard %d." % i)
        dataset = inputters.build_dataset(
            fields, opt.data_type,
            src=src_shard,
            tgt=tgt_shard,
            ans=ans_shard,
            src_dir=opt.src_dir,
            src_seq_len=opt.src_seq_length,
            tgt_seq_len=opt.tgt_seq_length,
            ans_seq_len=opt.ans_seq_length,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            use_filter_pred=corpus_type == 'train' or opt.filter_valid)

        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s."
                    % (i, corpus_type, data_path))
        dataset.save(data_path)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
def build_save_dataset(corpus_type, fields, opt):
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src = opt.train_src
        tgt = opt.train_tgt
    else:
        src = opt.valid_src
        tgt = opt.valid_tgt

    logger.info("Reading source and target files: %s %s." % (src, tgt))

    src_shards = split_corpus(src, opt.shard_size)
    tgt_shards = split_corpus(tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)
    dataset_paths = []
    total_valid_ex_num = 0

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        assert len(src_shard) == len(tgt_shard)
        logger.info("Building shard %d." % i)
        dataset = inputters.build_dataset(
            fields, opt.data_type,
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            src_seq_len=opt.src_seq_length,
            tgt_seq_len=opt.tgt_seq_length,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            image_channel_size=opt.image_channel_size,
            use_filter_pred=corpus_type == 'train' or opt.filter_valid,
            src_seq_min_length=opt.src_seq_min_length,
            tgt_seq_min_length=opt.tgt_seq_min_length)

        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s. Example number: %d"
                    % (i, corpus_type, data_path, len(dataset.examples)))
        total_valid_ex_num += len(dataset.examples)
        dataset.save(data_path)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    logger.info(" * Total Example number: %d" % total_valid_ex_num)
    return dataset_paths
def build_save_dataset(corpus_type, fields, opt):
    """ Building and saving the dataset """
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src_corpus = opt.train_src
        tgt_corpus = opt.train_tgt
        src_ref_corpus = opt.train_ref_src
        tgt_ref_corpus = opt.train_ref_tgt
    else:
        src_corpus = opt.valid_src
        tgt_corpus = opt.valid_tgt
        src_ref_corpus = opt.valid_ref_src
        tgt_ref_corpus = opt.valid_ref_tgt

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(src_corpus,
                                                      tgt_corpus,
                                                      src_ref_corpus,
                                                      tgt_ref_corpus,
                                                      fields,
                                                      corpus_type,
                                                      opt)

    # For data_type == 'img' or 'audio', we currently don't do
    # preprocess sharding and only build a monolithic dataset.
    # But since the interfaces are uniform, it would not be hard
    # to add this should users need the feature.
    dataset = inputters.build_dataset(
        fields, opt.data_type,
        src_path=src_corpus,
        tgt_path=tgt_corpus,
        src_ref_path=src_ref_corpus,
        tgt_ref_path=tgt_ref_corpus,
        src_dir=opt.src_dir,
        src_seq_length=opt.src_seq_length,
        tgt_seq_length=opt.tgt_seq_length,
        src_seq_length_trunc=opt.src_seq_length_trunc,
        tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
        dynamic_dict=opt.dynamic_dict,
        sample_rate=opt.sample_rate,
        window_size=opt.window_size,
        window_stride=opt.window_stride,
        window=opt.window,
        image_channel_size=opt.image_channel_size)

    # We save fields in vocab.pt separately, so make it empty.
    dataset.fields = []

    pt_file = "{:s}.{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)

    return [pt_file]
def get_encodings(self, src_path=None, src_data_iter=None, tgt_path=None,
                  tgt_data_iter=None, src_dir=None, batch_size=None,
                  attn_debug=False):
    assert src_data_iter is not None or src_path is not None
    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(
        self.fields,
        self.data_type,
        src_path=src_path,            # path to input file
        src_data_iter=src_data_iter,  # None
        tgt_path=tgt_path,            # None
        tgt_data_iter=tgt_data_iter,  # None
        src_dir=src_dir,              # empty string ""
        sample_rate=self.sample_rate,
        window_size=self.window_size,
        window_stride=self.window_stride,
        window=self.window,
        use_filter_pred=self.use_filter_pred)

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    all_encodings = []
    for batch in data_iter:
        batch_data = self.Encode(batch, data)
        all_encodings.append(batch_data)
        print(batch_data)
    return all_encodings
def build_save_vectors(src_corpus, tgt_corpus, fields, corpus_type, opt):
    """
    Divide src_corpus and tgt_corpus into multiple smaller src-corpus
    and tgt-corpus files, then build shards; each shard will have
    opt.shard_size samples except the last one. We do this to avoid
    taking up too much memory from reading in a huge corpus file.
    """
    ret_list = []
    dataset = inputters.build_dataset(
        fields, opt.data_type,
        src_path=src_corpus,
        tgt_path=tgt_corpus,
        src_dir=opt.src_dir,
        src_seq_length=opt.src_seq_length,
        tgt_seq_length=opt.tgt_seq_length,
        src_seq_length_trunc=opt.src_seq_length_trunc,
        tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
        dynamic_dict=opt.dynamic_dict,
        sample_rate=opt.sample_rate,
        window_size=opt.window_size,
        window_stride=opt.window_stride,
        window=opt.window,
        image_channel_size=opt.image_channel_size)

    pt_file = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, 0)

    # We save fields in vocab.pt separately, so make it empty.
    dataset.fields = []

    logger.info(" * saving %sth %s data shard to %s."
                % (0, corpus_type, pt_file))
    torch.save(dataset, pt_file)
    ret_list.append(pt_file)

    del dataset.examples
    gc.collect()
    del dataset
    gc.collect()

    return ret_list
def run_one(param):
    index, src, opt, fields, tgt_list, condition_corpus, corpus_type = param
    dataset = inputters.build_dataset(
        fields, opt.data_type,
        src_path=src,
        tgt_path=tgt_list[index],
        src_dir=opt.src_dir,
        src_seq_length=opt.src_seq_length,
        tgt_seq_length=opt.tgt_seq_length,
        src_seq_length_trunc=opt.src_seq_length_trunc,
        tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
        dynamic_dict=opt.dynamic_dict,
        sample_rate=opt.sample_rate,
        window_size=opt.window_size,
        window_stride=opt.window_stride,
        window=opt.window,
        image_channel_size=opt.image_channel_size)

    pt_file = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, index)

    # We save fields in vocab.pt separately, so make it empty.
    dataset.fields = []

    if condition_corpus:
        # Load the target conditions.
        with open(condition_corpus) as f:
            target_condition = [int(s.rstrip()) for s in f.readlines()]
    else:
        # Without a condition file, pair each example with a placeholder
        # so the zip below still covers every example.
        target_condition = [None] * len(dataset.examples)

    tmp_example = []
    _ = [parrel_func(e, opt.with_3d_confomer) for e in dataset.examples]
    for cond, result in zip(target_condition, dataset.examples):
        if getattr(result, 'graph') is not None:
            if condition_corpus:
                setattr(result, 'condition_target', cond)
            tmp_example.append(result)
    dataset.examples = tmp_example

    with open(pt_file, 'wb') as f:
        pickle.dump(dataset, f)

    os.remove(src)
    os.remove(tgt_list[index])
    return pt_file
def encode_seq(self, src, tgt=None, src_dir=None, batch_size=None):
    assert src is not None
    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(
        self.fields, self.data_type, src=src, tgt=tgt, src_dir=src_dir)

    cur_device = "cuda" if self.cuda else "cpu"

    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=cur_device,
        batch_size=batch_size,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False
    )

    with torch.no_grad():
        for i, batch in enumerate(data_iter):
            batch_size = batch.batch_size
            # Encoder forward.
            src, enc_states, memory_bank, src_lengths = \
                self._run_encoder(batch, data.data_type)
            # memory_bank: (seq_len, batch_size, hidden_size).
            # Mean-pool over time and stream the sentence vectors
            # straight to the output file.
            sent_vec_batch = memory_bank.mean(dim=0).cpu().numpy()
            np.savetxt(self.outfile, sent_vec_batch, fmt='%.10e')
            if (i + 1) % 10 == 0:
                print(".", end="", flush=True)
            if (i + 1) % 100 == 0:
                print((i + 1) * batch_size, end="", flush=True)
def translate(self, src_path=None, src_data_iter=None, src_length=None,
              tgt_path=None, tgt_data_iter=None, src_dir=None,
              batch_size=None, attn_debug=False, search_mode=0,
              threshold=0, ref_path=None):
    assert src_data_iter is not None or src_path is not None
    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(
        self.fields, self.data_type,
        src_path=src_path,
        src_data_iter=src_data_iter,
        src_seq_length_trunc=src_length,
        tgt_path=tgt_path,
        tgt_data_iter=tgt_data_iter,
        src_dir=src_dir,
        sample_rate=self.sample_rate,
        window_size=self.window_size,
        window_stride=self.window_stride,
        window=self.window,
        use_filter_pred=self.use_filter_pred,
        ref_path=['%s.%d' % (ref_path, r) for r in range(self.refer)]
        if self.refer else None,
        ref_seq_length_trunc=self.max_sent_length,
        ignore_unk=False)

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    if self.refer:
        for i in range(self.refer):
            data.fields['ref%d' % i].vocab = data.fields['src'].vocab

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    if search_mode == 2:
        all_predictions = self.search(data_iter, data, src_path,
                                      train=False, threshold=threshold)
        for i in all_predictions:
            self.out_file.write(i)
            self.out_file.flush()
        return

    builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt_path)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    for batch in data_iter:
        batch_data = self.translate_batch(batch, data, fast=True,
                                          attn_debug=False)
        translations = builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt_path is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            all_predictions += [n_best_preds]
            # self.out_file.write('\n'.join(n_best_preds) + '\n')
            # self.out_file.flush()

    if search_mode == 1:
        sim_predictions = self.search(data_iter, data, src_path, threshold)
        for i in range(len(sim_predictions)):
            if not sim_predictions[i]:
                self.out_file.write('\n'.join(all_predictions[i]) + '\n')
                self.out_file.flush()
            else:
                self.out_file.write(sim_predictions[i])
                self.out_file.flush()
    else:
        for i in all_predictions:
            self.out_file.write('\n'.join(i) + '\n')
            self.out_file.flush()

    return all_scores, all_predictions
def index_documents(self, src_path=None, src_data_iter=None, tgt_path=None,
                    tgt_data_iter=None, src_dir=None, batch_size=None):
    data = inputters.build_dataset(
        self.fields, self.data_type,
        src_path=src_path,
        src_data_iter=src_data_iter,
        src_seq_length_trunc=self.max_sent_length,
        tgt_path=tgt_path,
        tgt_data_iter=tgt_data_iter,
        src_dir=src_dir,
        sample_rate=self.sample_rate,
        window_size=self.window_size,
        window_stride=self.window_stride,
        window=self.window,
        use_filter_pred=self.use_filter_pred,
        ignore_unk=True)

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    doc_feats = []
    shard = 1
    for batch in data_iter:
        # Encoder forward.
        src = inputters.make_features(batch, 'src', data.data_type)
        _, src_lengths = batch.src
        enc_states, memory_bank, _ = self.model.encoder(src, src_lengths)

        feature = torch.max(memory_bank, 0)[0]
        _, recover_indices = torch.sort(batch.indices, descending=False)
        feature = feature[recover_indices]
        doc_feats.append(feature)

        if len(doc_feats) % 1250 == 0:
            print('saving shard %d' % shard)
            doc_feats = torch.cat(doc_feats)
            torch.save(doc_feats,
                       '{}/indexes/codev{}.pt'.format(
                           '/'.join(src_path.split('/')[:2]), shard))
            doc_feats = []
            shard += 1

    if doc_feats:
        doc_feats = torch.cat(doc_feats)
        torch.save(doc_feats,
                   '{}/indexes/codev{}.pt'.format(
                       '/'.join(src_path.split('/')[:2]), shard))
    print('done.')
def translate(self, src, tgt=None, src_dir=None, batch_size=None,
              attn_debug=False, data_iter=None):
    """
    Translate content of `src` and get gold scores if `tgt` is set.

    Note: batch_size must not be None
    Note: `src` must not be None

    Args:
        src (str): filepath of source data
        tgt (str): filepath of target data or None
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists
          of `n_best` predictions
    """
    assert src is not None
    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(
        self.fields, self.data_type,
        src=src, src_reader=self.src_reader,
        tgt=tgt, tgt_reader=self.tgt_reader,
        src_dir=src_dir,
        use_filter_pred=self.use_filter_pred,
        bert=self.opt.bert,
        morph=self.opt.korean_morphs
    )

    cur_device = "cuda" if self.cuda else "cpu"

    # The iterator is passed in via `data_iter`; the default construction
    # is kept here for reference.
    # data_iter = inputters.OrderedIterator(
    #     dataset=data,
    #     device=cur_device,
    #     batch_size=batch_size,
    #     train=False,
    #     sort=False,
    #     sort_within_batch=True,
    #     shuffle=False
    # )

    builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt
    )

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    start_time = time.time()

    for batch_ in data_iter:
        batch = bbbb(batch_)
        batch_data = self.translate_batch(
            batch, batch.dataset.src_vocabs, attn_debug, fast=self.fast
        )
        translations = builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))
                print(list(trans.attns[0].max(1)[1].cpu().detach().numpy()))

            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *srcs) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                os.write(1, output.encode('utf-8'))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total, pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            self._log(msg)

    if self.report_bleu:
        msg = self._report_bleu(tgt)
        self._log(msg)
    if self.report_rouge:
        msg = self._report_rouge(tgt)
        self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        self._log("Average translation time (s): %f" % (
            total_time / len(all_predictions)))
        self._log("Tokens per second: %f" % (
            pred_words_total / total_time))

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    return all_scores, all_predictions
def translate(self, src_path=None, src_data_iter=None, tgt_path=None,
              tgt_data_iter=None, src_dir=None, batch_size=None,
              attn_debug=False, intervention=None, out_file=None):
    """
    Translate content of `src_data_iter` (if not None) or `src_path`
    and get gold scores if one of `tgt_data_iter` or `tgt_path` is set.

    Note: batch_size must not be None
    Note: one of ('src_path', 'src_data_iter') must not be None

    Args:
        src_path (str): filepath of source data
        src_data_iter (iterator): an iterator generating source data,
            e.g. it may be a list or an opened file
        tgt_path (str): filepath of target data
        tgt_data_iter (iterator): an iterator generating target data
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists
          of `n_best` predictions
    """
    assert src_data_iter is not None or src_path is not None

    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(self.fields,
                                   self.data_type,
                                   src_path=src_path,
                                   src_data_iter=src_data_iter,
                                   tgt_path=tgt_path,
                                   tgt_data_iter=tgt_data_iter,
                                   src_dir=src_dir,
                                   sample_rate=self.sample_rate,
                                   window_size=self.window_size,
                                   window_stride=self.window_stride,
                                   window=self.window,
                                   use_filter_pred=self.use_filter_pred)

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = inputters.OrderedIterator(
        dataset=data, device=cur_device,
        batch_size=batch_size, train=False, sort=False,
        sort_within_batch=True, shuffle=False)

    builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt_path)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []
    all_dumped_layers = []

    for batch in data_iter:
        if self.dump_layers != '':
            batch_data, dumped_layers = self.translate_batch(
                batch, data, intervention=intervention)
            # Get the correct order of sentences so that we can dump in
            # the same order as input occurred.
            inds, perm = torch.sort(batch_data['batch'].indices.data)
            # At this point dumped_layers is going to be an array of
            # (num_layers) packed sequences, each of which has
            # (len) x (batch) shape. We would like to transpose this,
            # so that we have an array of "sentences", each of which is
            # an array of "tokens", each of which is an array of
            # "layers", each of which is an array of "neurons".
            dumped_layers = [unpack(layer) for layer in dumped_layers]
            # Tuples of (tensor, lengths)
            dumped_layers = [
                [
                    [  # Array of layers
                        dumped_layers[i][0][t][idx]
                        for i in range(len(dumped_layers))
                    ]
                    # Array of tokens; dumped_layers[0][1] is the list of
                    # sentence lengths for the batch, so we can look up
                    # number of tokens here
                    for t in range(dumped_layers[0][1][idx])
                ]
                # Array of sentences
                for idx in perm
            ]
            # Accumulate all the dumped layers into one big list of
            # sentences.
            all_dumped_layers.extend(dumped_layers)
        else:
            batch_data = self.translate_batch(
                batch, data, intervention=intervention)

        translations = builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt_path is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            all_predictions += [n_best_preds]
            if out_file is None:
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()
            else:
                out_file.write('\n'.join(n_best_preds) + '\n')
                out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            # Debug attention.
            if attn_debug:
                srcs = trans.src_raw
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *trans.src_raw) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                os.write(1, output.encode('utf-8'))

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
        if tgt_path is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

    if self.report_bleu:
        msg = self._report_bleu(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
    if self.report_rouge:
        msg = self._report_rouge(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    if self.dump_layers and self.dump_layers != -1:
        torch.save(all_dumped_layers, self.dump_layers)
    elif self.dump_layers == -1:
        return all_dumped_layers, all_scores, all_predictions

    return all_scores, all_predictions
def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus, fields,
                                           corpus_type, opt):
    """
    Divide src_corpus and tgt_corpus into multiple smaller src-corpus
    and tgt-corpus files, then build shards; each shard will have
    opt.shard_size samples except the last one. We do this to avoid
    taking up too much memory from reading in a huge corpus file.
    """
    src_data = open(src_corpus, "r", encoding="utf-8").readlines()
    tgt_data = open(tgt_corpus, "r", encoding="utf-8").readlines()

    src_corpus = "".join(src_corpus.split(".")[:-1])
    tgt_corpus = "".join(tgt_corpus.split(".")[:-1])

    for x in range(int(len(src_data) / opt.shard_size)):
        open(src_corpus + ".{0}.txt".format(x), "w",
             encoding="utf-8").writelines(
                 src_data[x * opt.shard_size:(x + 1) * opt.shard_size])
        open(tgt_corpus + ".{0}.txt".format(x), "w",
             encoding="utf-8").writelines(
                 tgt_data[x * opt.shard_size:(x + 1) * opt.shard_size])

    src_list = sorted(glob.glob(src_corpus + '.*.txt'))
    tgt_list = sorted(glob.glob(tgt_corpus + '.*.txt'))

    ret_list = []
    for index, src in enumerate(src_list):
        dataset = inputters.build_dataset(
            fields, opt.data_type,
            src_path=src,
            tgt_path=tgt_list[index],
            src_dir=opt.src_dir,
            src_seq_length=opt.src_seq_length,
            tgt_seq_length=opt.tgt_seq_length,
            src_seq_length_trunc=opt.src_seq_length_trunc,
            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
            dynamic_dict=opt.dynamic_dict,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            image_channel_size=opt.image_channel_size)

        pt_file = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type,
                                             index)

        # We save fields in vocab.pt separately, so make it empty.
        dataset.fields = []

        logger.info(" * saving %sth %s data image shard to %s."
                    % (index, corpus_type, pt_file))
        torch.save(dataset, pt_file)
        ret_list.append(pt_file)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return ret_list
def translate(self, src, tgt=None, src_dir=None, batch_size=None,
              attn_debug=False):
    """
    Translate content of `src` and get gold scores if `tgt` is set.

    Note: batch_size must not be None
    Note: `src` must not be None

    Args:
        src (str): filepath of source data
        tgt (str): filepath of target data or None
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists
          of `n_best` predictions
    """
    assert src is not None
    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(
        self.fields, self.data_type,
        src=src, tgt=tgt, src_dir=src_dir,
        sample_rate=self.sample_rate,
        window_size=self.window_size,
        window_stride=self.window_stride,
        window=self.window,
        use_filter_pred=self.use_filter_pred,
        image_channel_size=self.image_channel_size,
    )

    cur_device = "cuda" if self.cuda else "cpu"

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                self.n_best,
                                                self.replace_unk, tgt)

    # Statistics
    counter = count(1)
    all_scores = []
    all_predictions = []
    results = []

    # TODO(daphne): Figure out why putting import at top of the file fails.
    import json

    # Iterating over batches.
    for num, batch in enumerate(data_iter):
        # Reinitialize previous hypotheses.
        self.prev_hyps = []

        inputs = ["" for i in range(batch.batch_size)]
        preds = [[] for i in range(batch.batch_size)]
        scores = [[] for i in range(batch.batch_size)]

        # If doing iterative beam search, beam search may run
        # multiple times.
        for i in range(self.beam_iters):
            batch_data = self.translate_batch(
                batch, data, attn_debug, builder, fast=self.fast,
            )
            translations = builder.from_batch(batch_data)

            # Iterate over examples in the batch.
            for j, trans in enumerate(translations):
                pred_scores = list(
                    float(s) for s in trans.pred_scores[:self.n_best])
                pred_sents = trans.pred_sents[:self.n_best]
                all_scores += [pred_scores]

                if 0 in [len(l) for l in pred_sents]:
                    print('Warning: (batch=%d, translation=%d) generated '
                          'an empty sequence' % (num, j))

                if tgt is not None:
                    # TODO(dei): Add back support for this.
                    raise ValueError('tgt not currently supported.')

                n_best_preds = [" ".join(pred) for pred in pred_sents]
                all_predictions += [n_best_preds]

                # Save predictions and scores into a dictionary, to be
                # added to the final results later.
                inputs[j] = trans.src_raw
                if self.beam_iters == 1:
                    preds[j] = pred_sents
                    scores[j] = pred_scores
                else:
                    # Check if the top candidate is empty.
                    # (TODO: why is this happening?)
                    k = 0
                    while not trans.pred_sents[k]:
                        k += 1
                    preds[j] += [trans.pred_sents[k]]
                    scores[j] += [float(trans.pred_scores[k])]

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if attn_debug:
                    preds[j] = trans.pred_sents[0]
                    preds[j].append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item)
                                for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

        assert len(inputs) == len(preds) == len(scores)
        for j in range(len(inputs)):
            results.append({
                'input': inputs[j],
                'pred': preds[j],
                'scores': scores[j]
            })

    # Compute overall per-token perplexity.
    pred_score_total = 0
    pred_token_total = 0
    for result in results:
        pred_score_total += sum(result['scores'])
        pred_token_total += sum(len(s) for s in result['pred'])
    try:
        score = pred_score_total / pred_token_total
        ppl = math.exp(-pred_score_total / pred_token_total)
    except Exception as e:
        print(e)
        print('WARNING: SCORE AND PPL COULD NOT BE COMPUTED '
              'DUE TO NUMERICAL ERRORS')
        score = np.nan
        ppl = np.nan

    # Save the results to json.
    json_dump = {'results': results, 'score': score, 'ppl': ppl}
    json.dump(json_dump, self.out_file)
    self.out_file.flush()

    if self.report_score:
        msg = self._report_score('PRED', json_dump['score'], 1)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', json_dump['score'], 1)
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

    if self.report_bleu:
        msg = self._report_bleu(tgt)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
    if self.report_rouge:
        msg = self._report_rouge(tgt)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    if self.dump_beam:
        raise ValueError('This code path seems broken.')
        import json
        json.dump(self.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    return all_scores, all_predictions
def decode_sentences(self, sents, cuda=False):
    """
    Takes in a list of sentences and returns a list of decoded sentences.

    decode_sentences(['this is fun !', "this is not fun"])
    [('this is fun !', "I 'm not a this .", -12.412576675415039),
     ('this is not fun', "I 'm not sure .", -10.160457611083984)]

    :param sents: [str]
    :return: [(src, tgt, log-likelihood-score)]
    """
    unique_filename = str(uuid.uuid4())

    # Delete accumulated tmp files.
    tmp_files = os.listdir(pjoin(self.temp_dir, "l2e"))
    if len(tmp_files) > 10:
        for f_n in tmp_files:
            os.remove(pjoin(self.temp_dir, "l2e", f_n))

    with open(pjoin(self.temp_dir, "l2e",
                    '{}.txt'.format(unique_filename)), 'w') as f:
        for s in sents:
            f.write(s.strip() + '\n')

    data = inputters.build_dataset(
        self.fields,
        src_path=pjoin(self.temp_dir, "l2e",
                       '{}.txt'.format(unique_filename)),
        data_type='text',
        use_filter_pred=False,
        dynamic_dict=False)
    # src_seq_length=50, dynamic_dict=False)

    if cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=1,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                n_best=1,
                                                replace_unk=True,
                                                has_tgt=False)

    # This is not really beam search: we don't keep statistics or scores.
    decoded_sents = []  # (src, tgt, score)
    for batch in data_iter:
        batch_data = self.translator.translate_batch(batch, data,
                                                     fast=False)
        translations = builder.from_batch(batch_data)
        # Go through each sentence in the batch.
        for trans in translations:
            n_best_preds = [
                " ".join(pred)
                for pred in trans.pred_sents[:self.translator.n_best]
            ]
            for i in range(len(n_best_preds)):
                decoded_sents.append((' '.join(trans.src_raw),
                                      n_best_preds[i],
                                      trans.pred_scores[i].item()))

    return decoded_sents
def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus, fields,
                                           corpus_type, opt):
    """
    Divide src_corpus and tgt_corpus into multiple smaller src-corpus
    and tgt-corpus files, then build shards; each shard will have
    opt.shard_size samples except the last one. We do this to avoid
    taking up too much memory from reading in a huge corpus file.
    """
    with codecs.open(src_corpus, "r", encoding="utf-8") as fsrc:
        with codecs.open(tgt_corpus, "r", encoding="utf-8") as ftgt:
            logger.info("Reading source and target files: %s %s."
                        % (src_corpus, tgt_corpus))
            src_data = fsrc.readlines()
            tgt_data = ftgt.readlines()

            num_shards = int(len(src_data) / opt.shard_size)
            for x in range(num_shards):
                logger.info("Splitting shard %d." % x)
                f = codecs.open(src_corpus + ".{0}.txt".format(x), "w",
                                encoding="utf-8")
                f.writelines(
                    src_data[x * opt.shard_size:(x + 1) * opt.shard_size])
                f.close()
                f = codecs.open(tgt_corpus + ".{0}.txt".format(x), "w",
                                encoding="utf-8")
                f.writelines(
                    tgt_data[x * opt.shard_size:(x + 1) * opt.shard_size])
                f.close()

            num_written = num_shards * opt.shard_size
            if len(src_data) > num_written:
                logger.info("Splitting shard %d." % num_shards)
                f = codecs.open(src_corpus + ".{0}.txt".format(num_shards),
                                'w', encoding="utf-8")
                f.writelines(src_data[num_shards * opt.shard_size:])
                f.close()
                f = codecs.open(tgt_corpus + ".{0}.txt".format(num_shards),
                                'w', encoding="utf-8")
                f.writelines(tgt_data[num_shards * opt.shard_size:])
                f.close()

    src_list = sorted(glob.glob(src_corpus + '.*.txt'))
    tgt_list = sorted(glob.glob(tgt_corpus + '.*.txt'))

    ret_list = []
    for index, src in enumerate(src_list):
        logger.info("Building shard %d." % index)
        dataset = inputters.build_dataset(
            fields, opt.data_type,
            src_path=src,
            tgt_path=tgt_list[index],
            src_dir=opt.src_dir,
            src_seq_length=opt.src_seq_length,
            tgt_seq_length=opt.tgt_seq_length,
            src_seq_length_trunc=opt.src_seq_length_trunc,
            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
            dynamic_dict=opt.dynamic_dict,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            image_channel_size=opt.image_channel_size,
            use_filter_pred=False)

        pt_file = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type,
                                             index)

        # We save fields in vocab.pt separately, so make it empty.
        dataset.fields = []

        logger.info(" * %d examples in shard." % len(dataset.examples))
        logger.info(" * saving %sth %s data shard to %s."
                    % (index, corpus_type, pt_file))
        torch.save(dataset, pt_file)
        ret_list.append(pt_file)

        os.remove(src)
        os.remove(tgt_list[index])

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return ret_list
def translate(self, src_path=None, src_data_iter=None, tgt_path=None,
              tgt_data_iter=None, src_dir=None, batch_size=None,
              ans_path=None, ans_data_iter=None):
    """
    Translate content of `src_data_iter` (if not None) or `src_path`
    and get gold scores if one of `tgt_data_iter` or `tgt_path` is set.

    Note: batch_size must not be None
    Note: one of ('src_path', 'src_data_iter') must not be None

    Args:
        src_path (str): filepath of source data
        src_data_iter (iterator): an iterator generating source data,
            e.g. it may be a list or an opened file
        tgt_path (str): filepath of target data
        tgt_data_iter (iterator): an iterator generating target data
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
    """
    assert src_data_iter is not None or src_path is not None
    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(self.fields,
                                   self.data_type,
                                   src_path=src_path,
                                   src_data_iter=src_data_iter,
                                   tgt_path=tgt_path,
                                   tgt_data_iter=tgt_data_iter,
                                   src_dir=src_dir,
                                   sample_rate=self.sample_rate,
                                   window_size=self.window_size,
                                   window_stride=self.window_stride,
                                   window=self.window,
                                   use_filter_pred=self.use_filter_pred,
                                   ans_data_iter=ans_data_iter,
                                   ans_path=ans_path)
    print(data)

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = inputters.OrderedIterator(
        dataset=data, device=cur_device,
        batch_size=batch_size, train=False, sort=False,
        sort_within_batch=True, shuffle=False)

    for batch in data_iter:
        stats = self.translate_batch(batch, data)
        logger.info(stats)
def translate(self, src_path=None, src_data_iter=None, tgt_path=None,
              tgt_data_iter=None, src_dir=None, batch_size=None,
              attn_debug=False):
    """
    Translate content of `src_data_iter` (if not None) or `src_path`
    and get gold scores if one of `tgt_data_iter` or `tgt_path` is set.

    Note: batch_size must not be None

    Args:
        src_path (str): filepath of source data
        src_data_iter (iterator): an iterator generating source data,
            e.g. it may be a list or an opened file
        tgt_path (str): filepath of target data
        tgt_data_iter (iterator): an iterator generating target data
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists
          of `n_best` predictions
    """
    # assert src_data_iter is not None or src_path is not None
    if batch_size is None:
        raise ValueError("batch_size must be set")

    resp_vocab = self.fields["tgt"].vocab

    # Statistics (accumulated across the interactive session).
    # counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0
    all_scores = []
    all_predictions = []

    while True:
        post = input("Type in a post:")
        if post == "exit":
            break
        keyword = input("Type in a keyword:")
        keyword_index = resp_vocab.stoi[keyword]
        seg_lst = jieba.cut(post)
        post = ' '.join(seg_lst)
        src_path = [post]

        data = inputters.build_dataset(self.fields,
                                       self.data_type,
                                       src_path=src_path,
                                       src_data_iter=src_data_iter,
                                       tgt_path=tgt_path,
                                       tgt_data_iter=tgt_data_iter,
                                       src_dir=src_dir,
                                       sample_rate=self.sample_rate,
                                       window_size=self.window_size,
                                       window_stride=self.window_stride,
                                       window=self.window,
                                       use_filter_pred=self.use_filter_pred)

        if self.cuda:
            cur_device = "cuda"
        else:
            cur_device = "cpu"

        data_iter = inputters.OrderedIterator(
            dataset=data, device=cur_device,
            batch_size=batch_size, train=False, sort=False,
            sort_within_batch=True, shuffle=False)

        builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt_path)

        for batch in data_iter:
            # Backward model: generate the front sequence in reverse.
            batch_data = self.translate_batch(self.bk_model, batch, data,
                                              keyword=keyword_index,
                                              fast=self.fast)
            translations = builder.from_batch(batch_data)

            for trans in translations:
                post = trans.src_raw
                resps = trans.pred_sents
                # print(resps)
                scores = [float(np.exp(s))
                          for s in trans.pred_scores[:self.n_best]]
                best_index = np.argmax(scores)
                # resp_front = resps[best_index][::-1][:-1]
                # resp_front.append(keyword)
                best_forward_score = 0
                for resp in resps:
                    resp_front = resp[::-1][:-1]
                    # Replace the last word with the known keyword.
                    resp_front.append(keyword)
                    resp_front_indexs = [resp_vocab.stoi[w]
                                         for w in resp_front]

                    # Generate the back sequence from the generated
                    # front sequence.
                    batch_data = self.translate_batch(
                        self.model, batch, data,
                        front_seq=resp_front_indexs, fast=self.fast)
                    translations = builder.from_batch(batch_data)
                    for trans in translations:
                        resps = trans.pred_sents
                        # print(resps)
                        scores = [float(np.exp(s))
                                  for s in trans.pred_scores[:self.n_best]]
                        if max(scores) > best_forward_score:
                            best_forward_score = max(scores)
                            best_resp_front = resp_front
                            best_index = np.argmax(scores)
                            best_resp_back = \
                                resps[best_index][len(resp_front):]

                resp = ''.join(best_resp_front + best_resp_back)
                print('response: ', resp)

                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [" ".join(pred)
                                for pred in trans.pred_sents[:self.n_best]]
                all_predictions += [n_best_preds]
                '''
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()
                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))
                '''

                # Debug attention.
                if attn_debug:
                    srcs = trans.src_raw
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *trans.src_raw) + '\n'
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
        if tgt_path is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

    if self.report_bleu:
        msg = self._report_bleu(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
    if self.report_rouge:
        msg = self._report_rouge(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    return all_scores, all_predictions
def translate(self, src_path=None, src_data_iter=None, tgt_path=None,
              tgt_data_iter=None, src_dir=None, batch_size=None,
              attn_debug=False):
    """
    Note: batch_size must not be None
    Note: one of ('src_path', 'src_data_iter') must not be None
    """
    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(self.fields,
                                   self.data_type,
                                   src_path=src_path,
                                   src_data_iter=src_data_iter,
                                   tgt_path=tgt_path,
                                   tgt_data_iter=tgt_data_iter,
                                   src_dir=src_dir,
                                   sample_rate=self.sample_rate,
                                   window_size=self.window_size,
                                   window_stride=self.window_stride,
                                   window=self.window,
                                   use_filter_pred=self.use_filter_pred)

    data_iter = inputters.OrderedIterator(
        dataset=data, device=self.gpu,
        batch_size=batch_size, train=False, sort=False,
        sort_within_batch=True, shuffle=False)

    builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt_path)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0
    all_scores = []

    for batch in data_iter:
        batch_data = self.translate_batch(batch, data)
        translations = builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[0]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt_path is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            # Debug attention.
            if attn_debug:
                srcs = trans.src_raw
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *trans.src_raw) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                os.write(1, output.encode('utf-8'))

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
        if tgt_path is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

    if self.report_bleu:
        msg = self._report_bleu(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
    if self.report_rouge:
        msg = self._report_rouge(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    return all_scores
def scoring(self, src_data_path=None, src_data_iter=None,
            tgt_data_path=None, tgt_data_iter=None, batch_size=32):
    if src_data_iter is not None:
        batch_size = len(src_data_iter)
    assert batch_size != 0

    data = inputters.build_dataset(self.fields, 'text',
                                   src_path=src_data_path,
                                   src_data_iter=src_data_iter,
                                   tgt_path=tgt_data_path,
                                   tgt_data_iter=tgt_data_iter,
                                   use_filter_pred=False,
                                   dynamic_dict=False)

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    scored_triplets = []
    for batch in data_iter:
        # [src_len, batch_size, num_features]
        src = inputters.make_features(batch, 'src', 'text')
        _, src_lengths = batch.src
        # [tgt_len, batch_size, num_features]
        tgt = inputters.make_features(batch, 'tgt', 'text')
        _, tgt_lengths = batch.tgt

        logits, probs = self.model(src, tgt, src_lengths, tgt_lengths)

        # Restore the original example order.
        inds, perm = torch.sort(batch.indices.data)
        # orig_src = batch.src[0].data.index_select(1, perm)
        # orig_tgt = batch.tgt[0].data.index_select(1, perm)
        orig_probs = probs.index_select(0, perm)

        for b in range(batch.batch_size):
            src_raw = data.examples[inds[b]].src
            tgt_raw = data.examples[inds[b]].tgt
            final_score = orig_probs[b].data.item()
            scored_triplets.append({
                'src': src_raw,
                'tgt': tgt_raw,
                'score': final_score
            })
            # if final_score > 0.5:
            #     print('=' * 30)
            #     print('src: {}'.format(' '.join(src_raw)))
            #     print('tgt: {}; score: {}'.format(' '.join(tgt_raw),
            #                                       final_score))
            #     print('=' * 30)

    return scored_triplets
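A small sketch of calling `scoring` with in-memory iterators rather than file paths, assuming a loaded scorer instance `s` (the instance name, example pair, and 0.5 threshold are illustrative only):

    triplets = s.scoring(src_data_iter=['how are you ?'],
                         tgt_data_iter=['i am fine , thanks .'])
    for t in triplets:
        if t['score'] > 0.5:  # illustrative threshold
            print(' '.join(t['src']), '->', ' '.join(t['tgt']), t['score'])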
def translate(self, src_path=None, src_data_iter=None, tgt_path=None,
              tgt_data_iter=None, src_dir=None, batch_size=None,
              attn_debug=False):
    """
    Note: batch_size must not be None
    Note: one of ('src_path', 'src_data_iter') must not be None
    """
    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(self.fields,
                                   self.data_type,
                                   src_path=src_path,
                                   src_data_iter=src_data_iter,
                                   tgt_path=tgt_path,
                                   tgt_data_iter=tgt_data_iter,
                                   src_dir=src_dir,
                                   sample_rate=self.sample_rate,
                                   window_size=self.window_size,
                                   window_stride=self.window_stride,
                                   window=self.window,
                                   use_filter_pred=self.use_filter_pred)

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=self.gpu,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                self.n_best,
                                                self.replace_unk,
                                                tgt_path)

    # ADDED ----------------------------------------------------------
    # Load the translation pieces list.
    # home_path = "/home/pmlf/Documents/github/OpenNMT-py-fork/"
    home_path = "/home/ubuntu/OpenNMT-py-fork/"
    tp_path = home_path + "extra_data/translation_pieces_md_10-th0pt5.pickle"
    translation_pieces = pickle.load(open(tp_path, 'rb'))
    tot_time = 0
    # END --------------------------------------------------------------

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0
    all_scores = []

    for ix, batch in enumerate(data_iter):
        # ADDED ----------------------------------------------------------
        start_time = time.time()
        # END --------------------------------------------------------------

        batch_data = self.translate_batch(batch, data, translation_pieces)
        translations = builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[0]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt_path is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [
                " ".join(pred)
                for pred in trans.pred_sents[:self.n_best]
            ]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            # Debug attention.
            if attn_debug:
                srcs = trans.src_raw
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *trans.src_raw) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                os.write(1, output.encode('utf-8'))

        # ADDED ----------------------------------------------------------
        duration = time.time() - start_time
        tot_time += duration
        tot_time_print = str(
            time.strftime("%H:%M:%S", time.gmtime(tot_time)))
        print("Batch {} - Duration: {:.2f} - Total: {}".format(
            ix, duration, tot_time_print))
        # END --------------------------------------------------------------

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
        if tgt_path is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

    if self.report_bleu:
        msg = self._report_bleu(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
    if self.report_rouge:
        msg = self._report_rouge(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    return all_scores
def translate(self,
              knl,
              src,
              tgt=None,
              src_dir=None,
              batch_size=None,
              attn_debug=False):
    """
    Translate the content of `src` (with its accompanying knowledge file
    `knl`) and get gold scores if `tgt` is set.

    Note: batch_size must not be None
    Note: `knl` and `src` must not be None

    Args:
        knl (str): filepath of knowledge data
        src (str): filepath of source data
        tgt (str): filepath of target data or None
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists
          of `n_best` predictions
    """
    assert src is not None
    assert knl is not None

    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(
        self.fields,
        self.data_type,
        knl=knl,
        src=src,
        tgt=tgt,
        knl_seq_length_trunc=200,
        src_seq_length_trunc=50,
        src_dir=src_dir,
        sample_rate=self.sample_rate,
        window_size=self.window_size,
        window_stride=self.window_stride,
        window=self.window,
        use_filter_pred=self.use_filter_pred,
        image_channel_size=self.image_channel_size,
        dynamic_dict=self.copy_attn)

    cur_device = "cuda" if self.cuda else "cpu"

    data_iter = inputters.OrderedIterator(
        dataset=data, device=cur_device,
        batch_size=batch_size, train=False, sort=False,
        sort_within_batch=True, shuffle=False)

    builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    for batch in data_iter:
        batch_data = self.translate_batch(batch, data, attn_debug,
                                          fast=self.fast)
        translations = builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *srcs) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                os.write(1, output.encode('utf-8'))

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)
    if self.report_bleu:
        msg = self._report_bleu(tgt)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
    if self.report_rouge:
        msg = self._report_rouge(tgt)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    return all_scores, all_predictions
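# Hedged usage sketch for the knowledge-grounded ``translate`` above. The
# ``build_translator`` factory and the file names are assumptions for
# illustration, not part of this codebase; any setup that yields this
# Translator with a knl-aware dataset works the same way.
#
#     translator = build_translator(opt)          # hypothetical setup
#     all_scores, all_predictions = translator.translate(
#         knl='test_knl.txt',                     # assumed knowledge file
#         src='test_src.txt',                     # assumed source file
#         tgt=None,                               # no gold targets
#         batch_size=30,
#         attn_debug=False)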
def translate(opt):
    out_file = codecs.open(opt.output, 'w+', 'utf-8')
    if opt.gpu > -1:
        torch.cuda.set_device(opt.gpu)
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    fields, model, model_opt = \
        onmt.model_builder.load_test_model(opt, dummy_opt.__dict__)

    data = inputters.build_dataset(fields, 'text',
                                   src_path=opt.src,
                                   src_data_iter=None,
                                   tgt_path=opt.tgt,
                                   tgt_data_iter=None,
                                   src_dir=opt.src_dir,
                                   sample_rate='16000',
                                   window_size=.02,
                                   window_stride=.01,
                                   window='hamming',
                                   use_filter_pred=False)

    device = torch.device('cuda' if opt.gpu > -1 else 'cpu')
    batch_size = 1
    data_iter = inputters.OrderedIterator(dataset=data, device=device,
                                          batch_size=batch_size, train=False,
                                          sort=False, sort_within_batch=True,
                                          shuffle=False)

    pair_size = model_opt.wpe_pair_size
    s_id = fields["tgt"].vocab.stoi['<s>']
    if '<sgo>' in fields["tgt"].vocab.stoi:
        ss_id = fields["tgt"].vocab.stoi['<sgo>']
    else:
        ss_id = fields['tgt'].vocab.stoi['<unk>']
    if '<seos>' in fields['tgt'].vocab.stoi:
        eos_id = fields['tgt'].vocab.stoi['<seos>']
    else:
        eos_id = fields['tgt'].vocab.stoi['</s>']

    for i, batch in enumerate(data_iter):
        tgt = torch.LongTensor(
            [s_id] * batch_size +
            [ss_id] * ((pair_size - 1) * batch_size)).view(
                pair_size, batch_size).unsqueeze(2).to(device)
        dec_state = None
        src = inputters.make_features(batch, 'src', 'text')
        _, src_lengths = batch.src
        result = None
        for _ in range(opt.max_length):
            outputs, _, dec_state = model(src, tgt, src_lengths, dec_state)
            scores = model.generator(outputs.view(-1, outputs.size(2)))
            indices = scores.argmax(dim=1)
            tgt = indices.view(pair_size, batch_size, 1)  # (pair_size x batch x feat)
            assert batch_size == 1
            if tgt[0][0][0].item() == eos_id:
                break
            if result is None:
                result = indices.view(pair_size, batch_size)
            else:
                result = torch.cat(
                    [result, indices.view(pair_size, batch_size)], 0)

        result = result.transpose(0, 1).tolist()
        for sent in result:
            sent = [fields["tgt"].vocab.itos[_] for _ in sent]
            sent = [_ for _ in sent if _ not in ['<blank>', '<seos>', '</s>']]
            sent = ' '.join(sent)
            out_file.write(sent + '\n')

    print('Translated {} batches'.format(i))
    out_file.close()
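# Illustration (with made-up token ids) of how the decoder input above is
# seeded: the first of the pair_size rows holds the sentence-start symbol
# and the remaining rows hold the sub-sequence start symbol, matching the
# word-pair decoding loop.
import torch

s_id, ss_id = 2, 4  # assumed vocabulary indices for '<s>' and '<sgo>'
pair_size, batch_size = 2, 1
tgt = torch.LongTensor([s_id] * batch_size +
                       [ss_id] * ((pair_size - 1) * batch_size)).view(
                           pair_size, batch_size).unsqueeze(2)
print(tgt.shape)                # torch.Size([2, 1, 1])
print(tgt.squeeze(2).tolist())  # [[2], [4]]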
def translate(self,
              src_path=None,
              src_data_iter=None,
              rk_path=None,
              rk_data_iter=None,
              key_indicator_path=None,
              key_indicator_iter=None,
              tgt_path=None,
              tgt_data_iter=None,
              src_dir=None,
              batch_size=None,
              attn_debug=False):
    """
    Translate content of `src_data_iter` (if not None) or `src_path`
    and get gold scores if one of `tgt_data_iter` or `tgt_path` is set.

    Note: batch_size must not be None
    Note: one of ('src_path', 'src_data_iter') must not be None

    Args:
        src_path (str): filepath of source data
        src_data_iter (iterator): an iterator generating source data,
            e.g. it may be a list or an opened file
        rk_path (str): filepath of retrieved keyphrases
        rk_data_iter (iterator): an iterator generating retrieved
            keyphrases, e.g. it may be a list or an opened file
        key_indicator_path (str): filepath of src keyword indicators
        key_indicator_iter (iterator): an iterator generating src keyword
            indicators, e.g. it may be a list or an opened file
        tgt_path (str): filepath of target data
        tgt_data_iter (iterator): an iterator generating target data
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists
          of `n_best` predictions
    """
    assert src_data_iter is not None or src_path is not None

    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(self.fields,
                                   self.data_type,
                                   src_path=src_path,
                                   src_data_iter=src_data_iter,
                                   rk_path=rk_path,
                                   rk_data_iter=rk_data_iter,
                                   key_indicator_path=key_indicator_path,
                                   key_indicator_iter=key_indicator_iter,
                                   tgt_path=tgt_path,
                                   tgt_data_iter=tgt_data_iter,
                                   src_dir=src_dir,
                                   sample_rate=self.sample_rate,
                                   window_size=self.window_size,
                                   window_stride=self.window_stride,
                                   window=self.window,
                                   use_filter_pred=self.use_filter_pred)

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = inputters.OrderedIterator(
        dataset=data, device=cur_device,
        batch_size=batch_size, train=False, sort=False,
        sort_within_batch=True, shuffle=False)

    builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt_path)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    for batch in data_iter:
        batch_data = self.translate_batch(batch, data, fast=self.fast)
        translations = builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt_path is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            n_best_preds_scores = [round(sc.exp().item(), 5)
                                   for sc in trans.pred_scores[:self.n_best]]
            all_predictions += [n_best_preds]
            self.out_file.write(' ; '.join(n_best_preds) + '\n')
            self.out_file.flush()
            self.scores_out_file.write(
                ' ; '.join([str(sc) for sc in n_best_preds_scores]) + '\n')
            if trans.selector_probs is not None:
                selector_probs = trans.selector_probs.tolist()
                selector_probs = [round(sp, 5) for sp in selector_probs
                                  if sp != 0.0]
                self.sel_probs_out_file.write(
                    ' ; '.join([str(sp) for sp in selector_probs]) + '\n')

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            # Debug attention.
            if attn_debug:
                srcs = trans.src_raw
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *trans.src_raw) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                os.write(1, output.encode('utf-8'))

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
        if tgt_path is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)
    if self.report_bleu:
        msg = self._report_bleu(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
    if self.report_rouge:
        msg = self._report_rouge(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    if self.opt is not None:
        evaluate_func(opts=self.opt, do_stem=True)

    return all_scores, all_predictions
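# The values written to ``scores_out_file`` above are exponentiated
# cumulative log-likelihoods, i.e. sequence probabilities. A standalone
# check with a made-up log score (math.exp stands in for tensor .exp()):
import math

log_score = -1.20397                  # assumed n-best log-probability
print(round(math.exp(log_score), 5))  # ~0.3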
def translate(self,
              src_path=None,
              src_data_iter=None,
              tgt_path=None,
              tgt_data_iter=None,
              src_dir=None,
              batch_size=None,
              attn_debug=False,
              node_type_seq=None,
              atc=None):
    """
    Translate content of `src_data_iter` (if not None) or `src_path`
    and get gold scores if one of `tgt_data_iter` or `tgt_path` is set.

    Note: batch_size must not be None
    Note: one of ('src_path', 'src_data_iter') must not be None

    Args:
        src_path (str): filepath of source data
        src_data_iter (iterator): an iterator generating source data,
            e.g. it may be a list or an opened file
        tgt_path (str): filepath of target data
        tgt_data_iter (iterator): an iterator generating target data
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists
          of `n_best` predictions
    """
    assert src_data_iter is not None or src_path is not None
    assert node_type_seq is not None, 'Node Types must be provided'
    node_type_scores = node_type_seq[1]
    node_type_seq = node_type_seq[0]

    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(self.fields,
                                   self.data_type,
                                   src_path=src_path,
                                   src_data_iter=src_data_iter,
                                   tgt_path=tgt_path,
                                   tgt_data_iter=tgt_data_iter,
                                   src_dir=src_dir,
                                   sample_rate=self.sample_rate,
                                   window_size=self.window_size,
                                   window_stride=self.window_stride,
                                   window=self.window,
                                   use_filter_pred=self.use_filter_pred)

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                self.n_best,
                                                self.replace_unk,
                                                tgt_path)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    # debug(self.option.tree_count)

    def check_correctness(preds, gold):
        for p in preds:
            if p.strip() == gold.strip():
                return 1
        return 0

    total_correct = 0
    for bidx, batch in enumerate(data_iter):
        # if bidx == 100:
        #     break
        example_idx = batch.indices.item()  # Only 1 item in this batch, guaranteed
        if bidx % 20 == 0:
            debug('Current Example : ', example_idx)
        nt_sequences = node_type_seq[example_idx]
        nt_scores = node_type_scores[example_idx]
        if atc is not None:
            atc_item = atc[example_idx]
        else:
            atc_item = None
        scores = []
        predictions = []
        tree_count = self.option.tree_count
        for type_sequence, type_score in zip(nt_sequences[:tree_count],
                                             nt_scores[:tree_count]):
            batch_data = self.translate_batch(batch, data,
                                              node_type_str=type_sequence,
                                              fast=self.fast,
                                              atc=atc_item)
            translations = builder.from_batch(batch_data)
            already_found = False
            for trans in translations:
                pred_scores = [score + type_score
                               for score in trans.pred_scores[:self.n_best]]
                # debug(len(pred_scores))
                scores += pred_scores
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [" ".join(pred)
                                for pred in trans.pred_sents[:self.n_best]]
                gold_sent = ' '.join(trans.gold_sent)
                correct = check_correctness(n_best_preds, gold_sent)
                # debug(correct == 1)
                if not already_found:
                    total_correct += correct
                    already_found = True
                # debug(len(n_best_preds))
                predictions += n_best_preds

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if attn_debug:
                    srcs = trans.src_raw
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *trans.src_raw) + '\n'
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

        all_scores += [scores]
        all_predictions += [predictions]

    if self.report_score:
        if tgt_path is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)
    if self.report_bleu:
        msg = self._report_bleu(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
    if self.report_rouge:
        msg = self._report_rouge(tgt_path)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    # debug(total_correct)
    return all_scores, all_predictions
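# Standalone sketch of the bookkeeping above, with made-up values: the
# per-tree type score is added to each n-best log-score (summing
# log-probabilities corresponds to a joint likelihood), and an example
# counts as correct when any n-best candidate matches the gold exactly.
def check_correctness(preds, gold):
    for p in preds:
        if p.strip() == gold.strip():
            return 1
    return 0

type_score = -0.5              # assumed log-score of one node-type sequence
pred_scores = [-1.0, -2.3]     # assumed n-best prediction log-scores
print([score + type_score for score in pred_scores])         # [-1.5, -2.8]
print(check_correctness(['a = b ;', 'a = c ;'], 'a = b ;'))  # 1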