def log(self, sent_number):
    """Build a printable, multi-line report for this translation.

    The report contains the raw source sentence, the best prediction and
    its score, optionally the pharaoh-format word alignment, optionally
    the gold sentence and gold score, and — when more than one hypothesis
    exists — the full n-best list with per-hypothesis scores.

    Args:
        sent_number (int): 1-based index used to label the sentence.

    Returns:
        str: the assembled report (newline-terminated sections).
    """
    lines = ['\nSENT {}: {}\n'.format(sent_number, self.src_raw)]

    # Best hypothesis (index 0) and its score.
    top_sent = ' '.join(self.pred_sents[0])
    lines.append('PRED {}: {}\n'.format(sent_number, top_sent))
    lines.append("PRED SCORE: {:.4f}\n".format(self.pred_scores[0]))

    # Word alignment of the best hypothesis, pharaoh format, if present.
    if self.word_aligns is not None:
        pharaoh_tokens = build_align_pharaoh(self.word_aligns[0])
        lines.append("ALIGN: {}\n".format(' '.join(pharaoh_tokens)))

    # Reference translation and its score, if a target was provided.
    if self.gold_sent is not None:
        lines.append('GOLD {}: {}\n'.format(sent_number,
                                            ' '.join(self.gold_sent)))
        lines.append("GOLD SCORE: {:.4f}\n".format(self.gold_score))

    # Full n-best list; note each sentence is formatted as a token list.
    if len(self.pred_sents) > 1:
        lines.append('\nBEST HYP:\n')
        lines.extend("[{:.4f}] {}\n".format(hyp_score, hyp_sent)
                     for hyp_score, hyp_sent in zip(self.pred_scores,
                                                    self.pred_sents))

    return "".join(lines)
def translate(self,
              src,
              tgt=None,
              src_dir=None,
              batch_size=None,
              batch_type="sents",
              attn_debug=False,
              align_debug=False,
              phrase_table=""):
    """Translate content of ``src`` and get gold scores from ``tgt``.

    Args:
        src: See :func:`self.src_reader.read()`.
        tgt: See :func:`self.tgt_reader.read()`.
        src_dir: See :func:`self.src_reader.read()` (only relevant
            for certain types of data).
        batch_size (int): size of examples per mini-batch
        batch_type (str): ``"sents"`` (fixed sentence count) or
            ``"tokens"`` (token-budget batching via ``max_tok_len``)
        attn_debug (bool): enables the attention logging
        align_debug (bool): enables the word alignment logging
        phrase_table (str): accepted but unused here; the builder reads
            ``self.phrase_table`` instead — TODO confirm intent

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists
          of `n_best` predictions

    Raises:
        ValueError: if ``batch_size`` is not given.
    """
    if batch_size is None:
        raise ValueError("batch_size must be set")

    # Wrap src/tgt with their readers so Dataset.config can pair them up.
    src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
    tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
    _readers, _data, _dir = inputters.Dataset.config(
        [('src', src_data), ('tgt', tgt_data)])

    data = inputters.Dataset(
        self.fields, readers=_readers, data=_data, dirs=_dir,
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred)

    # Inference iterator: no shuffling/global sort, but sort within each
    # batch (required for efficient packed-sequence decoding).
    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self._dev,
        batch_size=batch_size,
        batch_size_fn=max_tok_len if batch_type == "tokens" else None,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False)

    # Turns raw batch output (ids, attention) back into token sequences.
    xlation_builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt,
        self.phrase_table)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    start_time = time.time()

    for batch in data_iter:
        batch_data = self.translate_batch(
            batch, data.src_vocabs, attn_debug)
        translations = xlation_builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                # +1 presumably accounts for the EOS token — TODO confirm
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            if self.report_align:
                # Append pharaoh-format alignments after a " ||| " separator.
                align_pharaohs = [build_align_pharaoh(align)
                                  for align
                                  in trans.word_aligns[:self.n_best]]
                n_best_preds_align = [" ".join(align)
                                      for align in align_pharaohs]
                n_best_preds = [pred + " ||| " + align
                                for pred, align in zip(
                                    n_best_preds, n_best_preds_align)]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    # Write straight to stdout fd to avoid re-encoding issues.
                    os.write(1, output.encode('utf-8'))

            if attn_debug:
                # Print the attention matrix of the best hypothesis.
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    # Non-text inputs have no tokens; label columns by index.
                    srcs = [str(item) for item in range(len(attns[0]))]
                output = report_matrix(srcs, preds, attns)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if align_debug:
                # Print the word-alignment matrix (vs. gold when available).
                if trans.gold_sent is not None:
                    tgts = trans.gold_sent
                else:
                    tgts = trans.pred_sents[0]
                align = trans.word_aligns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(align[0]))]
                output = report_matrix(srcs, tgts, align)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        self._log("Average translation time (s): %f" % (
            total_time / len(all_predictions)))
        self._log("Tokens per second: %f" % (
            pred_words_total / total_time))

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    return all_scores, all_predictions
def translate(self,
              src,
              tgt=None,
              src_dir=None,
              batch_size=None,
              batch_type="sents",
              attn_debug=False,
              align_debug=False,
              phrase_table=""):
    """Translate content of ``src`` and get gold scores from ``tgt``.

    Experimental variant: gold-score reporting is disabled and BLEU is
    force-enabled (see the report section below); remnants of a
    word-accuracy experiment remain as never-updated counters.

    Args:
        src: See :func:`self.src_reader.read()`.
        tgt: See :func:`self.tgt_reader.read()`.
        src_dir: See :func:`self.src_reader.read()` (only relevant
            for certain types of data).
        batch_size (int): size of examples per mini-batch
        batch_type (str): ``"sents"`` or ``"tokens"`` batching
        attn_debug (bool): enables the attention logging
        align_debug (bool): enables the word alignment logging
        phrase_table (str): accepted but unused here — TODO confirm intent

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists
          of `n_best` predictions

    Raises:
        ValueError: if ``batch_size`` is not given.
    """
    if batch_size is None:
        raise ValueError("batch_size must be set")

    # Wrap src/tgt with their readers so Dataset.config can pair them up.
    src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
    tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
    _readers, _data, _dir = inputters.Dataset.config(
        [('src', src_data), ('tgt', tgt_data)])

    data = inputters.Dataset(
        self.fields, readers=_readers, data=_data, dirs=_dir,
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred)

    # Inference iterator: no shuffling/global sort, sort within batch only.
    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self._dev,
        batch_size=batch_size,
        batch_size_fn=max_tok_len if batch_type == "tokens" else None,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False)

    xlation_builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt,
        self.phrase_table)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0
    # compute accuracy like we do during training across the entire test
    # set (or single word)
    # NOTE(review): these counters (and `skipped`, `total_num_utts`,
    # `batch_index`, `trans_index` below) belonged to a commented-out
    # "IOHAVOC" word-accuracy experiment and are never updated by live
    # code; `skipped`/`total_num_utts` are still printed as 0 later.
    total_correct_words, total_num_words = 0, 0

    all_scores = []
    all_predictions = []

    start_time = time.time()
    skipped = 0
    total_num_utts = 0

    batch_index = 0
    for batch in data_iter:
        batch_data = self.translate_batch(
            batch, data.src_vocabs, attn_debug)
        translations = xlation_builder.from_batch(batch_data)

        trans_index = 0
        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                # +1 presumably accounts for the EOS token — TODO confirm
                gold_words_total += len(trans.gold_sent) + 1

            # NOTE(review): a large commented-out experiment was here: it
            # compared trans.gold_sent to trans.pred_sents[0] token by
            # token, skipping <unk>, to accumulate total_correct_words /
            # total_num_words and log per-utterance accuracy.

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            if self.report_align:
                # Append pharaoh-format alignments after " ||| ".
                align_pharaohs = [build_align_pharaoh(align)
                                  for align
                                  in trans.word_aligns[:self.n_best]]
                n_best_preds_align = [" ".join(align)
                                      for align in align_pharaohs]
                n_best_preds = [pred + " ||| " + align
                                for pred, align in zip(
                                    n_best_preds, n_best_preds_align)]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            # NOTE(review): the usual `if self.verbose: trans.log(...)`
            # per-utterance logging block is commented out in this variant
            # ("UNCOMMENT THIS FOR ERROR ANALYSES").

            if attn_debug:
                # Print the attention matrix of the best hypothesis.
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                output = report_matrix(srcs, preds, attns)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if align_debug:
                # Print the word-alignment matrix (vs. gold when available).
                if trans.gold_sent is not None:
                    tgts = trans.gold_sent
                else:
                    tgts = trans.pred_sents[0]
                align = trans.word_aligns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(align[0]))]
                output = report_matrix(srcs, tgts, align)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

    end_time = time.time()

    if self.report_score:
        # NOTE(review): commented-out accuracy summary (total correct /
        # total words) was here; `skipped`/`total_num_utts` below always
        # print 0 since nothing updates them anymore.
        print("\nskipped: " + str(skipped))
        print("total_num_utts: " + str(total_num_utts))
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        self._log(msg)
        if tgt is not None:
            # Gold-score reporting is disabled in this variant
            # ("GOLD SCORE IS ANNOYING"); BLEU/ROUGE are reported instead.
            # NOTE(review): force-enables BLEU regardless of configuration.
            self.report_bleu = True
            if self.report_bleu:
                msg = self._report_bleu(tgt)
                self._log(msg)
            if self.report_rouge:
                msg = self._report_rouge(tgt)
                self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        self._log("Average translation time (s): %f" % (
            total_time / len(all_predictions)))
        self._log("Tokens per second: %f" % (
            pred_words_total / total_time))

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    return all_scores, all_predictions
def attention_analysis(self,
                       direction,
                       src,
                       tgt,
                       batch_type="sents",
                       phrase_table=""):
    """Translate ``src`` in one direction of a dual-direction model and
    collect scores, predictions and per-beam attention matrices.

    The model holds paired submodules (``encoder_x2y``/``encoder_y2x``,
    etc.); this method swaps the active encoder/decoder/generator to the
    requested direction before translating. All examples are put in a
    single batch (``batch_size = len(src)``).

    Args:
        direction (str): ``'x2y'`` or ``'y2x'``; selects the submodules.
        src: See :func:`self.src_reader.read()`.
        tgt: See :func:`self.tgt_reader.read()`; gold scores are always
            accumulated, so a target is expected — TODO confirm.
        batch_type (str): ``"sents"`` or ``"tokens"`` batching
        phrase_table (str): accepted but unused here — TODO confirm intent

    Returns:
        (`list`, `list`, `list`)

        * all_scores: per-example lists of `n_best` scores
        * all_predictions: per-example lists of `n_best` predictions
        * all_attentions: per-beam attention arrays (numpy, CPU)
    """
    assert direction in ['x2y', 'y2x']
    # Point the generic model entry points at the requested direction.
    self.model.encoder = (self.model.encoder_x2y
                          if direction == 'x2y'
                          else self.model.encoder_y2x)
    self.model.decoder = (self.model.decoder_x2y
                          if direction == 'x2y'
                          else self.model.decoder_y2x)
    self.model.generator = (self.model.generator_x2y
                            if direction == 'x2y'
                            else self.model.generator_y2x)
    self.direction = direction

    # One batch covering the full input.
    batch_size = len(src)
    src_data = {"reader": self.src_reader, "data": src, "dir": None}
    tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
    _readers, _data, _dir = inputters.Dataset.config(
        [('src', src_data), ('tgt', tgt_data)])

    # corpus_id field is useless here
    if self.fields.get("corpus_id", None) is not None:
        self.fields.pop('corpus_id')

    data = inputters.Dataset(
        self.fields, readers=_readers, data=_data, dirs=_dir,
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred)

    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self._dev,
        batch_size=batch_size,
        batch_size_fn=max_tok_len if batch_type == "tokens" else None,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False)

    xlation_builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt,
        self.phrase_table)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_gold_scores = []
    all_predictions = []
    all_attentions = []

    start_time = time.time()

    for batch in tqdm(data_iter):
        # Attention is always requested so it can be collected below.
        batch_data = self.translate_batch(
            batch, data.src_vocabs, attn_debug=True,
            only_gold_score=False)
        translations = xlation_builder.from_batch(batch_data)

        for trans in translations:
            n_best_scores = trans.pred_scores[:self.n_best]
            all_scores += [n_best_scores]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])

            n_best_gold_scores = [trans.gold_score]
            all_gold_scores += [n_best_gold_scores]
            gold_score_total += trans.gold_score
            # +1 presumably accounts for the EOS token — TODO confirm
            gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            if self.report_align:
                # Append pharaoh-format alignments after " ||| ".
                align_pharaohs = [build_align_pharaoh(align)
                                  for align
                                  in trans.word_aligns[:self.n_best]]
                n_best_preds_align = [" ".join(align)
                                      for align in align_pharaohs]
                n_best_preds = [pred + " ||| " + align
                                for pred, align in zip(
                                    n_best_preds, n_best_preds_align)]
            all_predictions += [n_best_preds]

            if self.out_file:
                if self.log_score:
                    # in BWD translation(tgt=product, n_best==1),
                    # we use gold score
                    if self.n_best == 1 and tgt is not None:
                        n_best_scores = n_best_gold_scores
                    # Emit "prediction,score" CSV-style lines.
                    n_best_preds_scores = [
                        pred + ',' + str(score.item())
                        for pred, score in zip(n_best_preds,
                                               n_best_scores)
                    ]
                    self.out_file.write(
                        '\n'.join(n_best_preds_scores) + '\n')
                    self.out_file.flush()
                else:
                    self.out_file.write('\n'.join(n_best_preds) + '\n')
                    self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                os.write(1, output.encode('utf-8'))

            # Dump and collect the attention matrix of every beam.
            for i in range(self.beam_size):
                preds = trans.pred_sents[i]
                preds.append('</s>')
                attns = trans.attns[i].tolist()
                srcs = trans.src_raw
                # Debug print of matrix dimensions — left in deliberately?
                # NOTE(review): consider routing through self._log.
                print(srcs, len(srcs), len(attns), len(attns[0]))
                output = report_matrix(srcs, preds, attns)
                os.write(1, output.encode('utf-8'))
                all_attentions.append(trans.attns[i].cpu().numpy())

    end_time = time.time()

    if self.report_score:
        msg = self._report_score('GOLD', gold_score_total,
                                 gold_words_total)
        self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        self._log("Average translation time (s): %f" % (
            total_time / len(all_predictions)))
        self._log("Tokens per second: %f" % (
            pred_words_total / total_time))

    return all_scores, all_predictions, all_attentions
def translate(self,
              src,
              tgt=None,
              src_dir=None,
              batch_size=None,
              batch_type="sents",
              attn_debug=False,
              align_debug=False,
              phrase_table="",
              partial=None,
              dymax_len=None):
    """Translate content of ``src`` and get gold scores from ``tgt``.

    Variant supporting constrained decoding from a partial target prefix:
    complete words in ``partial`` are forced, and the (possibly
    incomplete) last word restricts the next token to vocabulary entries
    that start with it.

    Args:
        src: See :func:`self.src_reader.read()`.
        tgt: See :func:`self.tgt_reader.read()`.
        src_dir: See :func:`self.src_reader.read()` (only relevant
            for certain types of data).
        batch_size (int): size of examples per mini-batch
        batch_type (str): ``"sents"`` or ``"tokens"`` batching
        attn_debug (bool): enables the attention logging
        align_debug (bool): enables the word alignment logging
        phrase_table (str): accepted but unused here — TODO confirm intent
        partial (str): space-separated target prefix to constrain
            decoding; falsy/empty disables the constraint.
        dymax_len: dynamic max decode length, stored on self for use by
            the decoding code — TODO confirm consumer.

    Returns:
        (`list`, `list`, ...) — with ``attn_debug`` enabled, also the
        attention of the last processed sentence plus score totals;
        otherwise predictions plus score totals.

    Raises:
        ValueError: if ``batch_size`` is not given.
    """
    self.dymax_len = dymax_len

    self.partialf = None
    # To check with partial words
    partialfcheck = True
    # To check with editdistance, put True. To check with just startswith
    # which will be prone to errors due to spelling mistakes, put False.
    # NOTE(review): the editdistance path is commented out below, so this
    # flag is currently unused.
    partialfedit = False
    # partialopt = True

    # Logic for partial and partialf
    if partial and partial != '':
        partials = partial.split()
        print(partials, '~~~~partials~~~')
        vocabdict = dict(self.fields)["tgt"].base_field.vocab
        if partialfcheck:
            # NOTE(review): a commented-out alternative was here that
            # ranked the whole vocabulary by edit distance (or a
            # startswith penalty) against the last partial word.
            # Force all complete words of the prefix...
            self.partial = [vocabdict.stoi[x] for x in partials[:-1]]
            print("#########vocabdict.stoi########")
            print(self.partial)
            print("##################################")
            # ...and allow only vocab ids whose surface form starts with
            # the (possibly incomplete) last word; `and v` drops id 0.
            self.partialf = [
                v for k, v in vocabdict.stoi.items()
                if k.startswith(partials[-1]) and v
            ]
        else:
            # Treat every word of the prefix as complete.
            self.partial = [vocabdict.stoi[x] for x in partials]
    else:
        self.partial = None

    if batch_size is None:
        raise ValueError("batch_size must be set")

    # Wrap src/tgt with their readers so Dataset.config can pair them up.
    src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
    tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
    _readers, _data, _dir = inputters.Dataset.config(
        [('src', src_data), ('tgt', tgt_data)])

    # corpus_id field is useless here
    if self.fields.get("corpus_id", None) is not None:
        self.fields.pop('corpus_id')

    data = inputters.Dataset(
        self.fields, readers=_readers, data=_data, dirs=_dir,
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred)

    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self._dev,
        batch_size=batch_size,
        batch_size_fn=max_tok_len if batch_type == "tokens" else None,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False)

    xlation_builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt,
        self.phrase_table)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    # Cumulative log-likelihood scores per sentence.
    all_scores = []
    all_predictions = []

    start_time = time.time()

    for batch in data_iter:
        batch_data = self.translate_batch(
            batch, data.src_vocabs, attn_debug)
        translations = xlation_builder.from_batch(batch_data)

        for trans in translations:
            print("Loop")
            print(trans, trans.pred_sents)
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                # +1 presumably accounts for the EOS token — TODO confirm
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            print("############n_best_preds###############")
            print(n_best_preds)
            print("############n_best_preds###############")
            if self.report_align:
                # Append pharaoh-format alignments after " ||| ".
                align_pharaohs = [build_align_pharaoh(align)
                                  for align
                                  in trans.word_aligns[:self.n_best]]
                n_best_preds_align = [" ".join(align)
                                      for align in align_pharaohs]
                n_best_preds = [pred + " ||| " + align
                                for pred, align in zip(
                                    n_best_preds, n_best_preds_align)]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                # This prints attentions in output for the sentence having
                # highest cumulative log likelihood score
                output = report_matrix(srcs, preds, attns)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if align_debug:
                if trans.gold_sent is not None:
                    tgts = trans.gold_sent
                else:
                    tgts = trans.pred_sents[0]
                align = trans.word_aligns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(align[0]))]
                output = report_matrix(srcs, tgts, align)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        self._log("Average translation time (s): %f" % (
            total_time / len(all_predictions)))
        self._log("Tokens per second: %f" % (
            pred_words_total / total_time))

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    if attn_debug:
        # NOTE(review): `attns` leaks out of the loop above — it holds only
        # the LAST sentence's attention, and raises NameError if the input
        # was empty. Looks intentional for single-sentence use; confirm.
        return all_scores, all_predictions, attns, \
            pred_score_total, pred_words_total
    else:
        return all_scores, all_predictions, \
            pred_score_total, pred_words_total
def translate(self,
              src,
              tgt=None,
              src_dir=None,
              batch_size=None,
              batch_type="sents",
              attn_debug=False,
              align_debug=False,
              phrase_table="",
              opt=None):
    """Translate content of ``src`` and get gold scores from ``tgt``.

    Keyphrase-generation variant (@memray): uses the dataset class
    registered for ``self.data_type``, post-processes one2seq outputs
    into individual phrases, and dumps keyphrase results as JSON lines.

    Args:
        src: See :func:`self.src_reader.read()`.
        tgt: See :func:`self.tgt_reader.read()`.
        src_dir: See :func:`self.src_reader.read()` (only relevant
            for certain types of data).
        batch_size (int): size of examples per mini-batch
        batch_type (str): ``"sents"`` or ``"tokens"`` batching
        attn_debug (bool): enables the attention logging
        align_debug (bool): enables the word alignment logging
        phrase_table (str): accepted but unused here — TODO confirm intent
        opt: options namespace; only referenced in commented-out
            kp-eval code — TODO confirm it is still needed.

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists
          of `n_best` predictions

    Raises:
        ValueError: if ``batch_size`` is not given.
    """
    if batch_size is None:
        raise ValueError("batch_size must be set")

    # modified by @memray to accommodate keyphrase
    src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
    tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
    _readers, _data, _dir = inputters.Dataset.config(
        [('src', src_data), ('tgt', tgt_data)])

    # Dataset class is looked up per data_type (e.g. KeyphraseDataset).
    data = inputters.str2dataset[self.data_type](
        self.fields, readers=_readers, data=_data, dirs=_dir,
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred)

    # @memray, as Dataset is only instantiated here, having to use this
    # plugin setter
    if isinstance(data, KeyphraseDataset):
        data.tgt_type = self.tgt_type

    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self._dev,
        batch_size=batch_size,
        batch_size_fn=max_tok_len if batch_type == "tokens" else None,
        train=False,
        sort=False,
        # sort_within_batch=True,
        sort_within_batch=False,  # @memray: to keep the original order
        shuffle=False)

    xlation_builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt,
        self.phrase_table)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    start_time = time.time()

    num_examples = 0
    for batch in data_iter:
        # Progress report; over-counts on a final partial batch since it
        # adds batch_size, not the actual batch length.
        num_examples += batch_size
        print("Translating %d/%d" % (num_examples, len(src)))

        batch_data = self.translate_batch(
            batch, data.src_vocabs, attn_debug)
        translations = xlation_builder.from_batch(batch_data)

        # @memray
        if self.data_type == "keyphrase":
            # post-process for one2seq outputs, split seq into individual
            # phrases
            if self.model_tgt_type != 'one2one':
                translations = self.segment_one2seq_trans(translations)
            # add statistics of kps(pred_num, beamstep_num etc.)
            translations = self.add_trans_stats(
                translations, self.model_tgt_type)
            # add copied flag
            vocab_size = len(self.fields['src'].base_field.vocab.itos)
            for t in translations:
                t.add_copied_flags(vocab_size)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                # +1 presumably accounts for the EOS token — TODO confirm
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            if self.report_align:
                # Append pharaoh-format alignments after " ||| ".
                align_pharaohs = [build_align_pharaoh(align)
                                  for align
                                  in trans.word_aligns[:self.n_best]]
                n_best_preds_align = [" ".join(align)
                                      for align in align_pharaohs]
                n_best_preds = [pred + " ||| " + align
                                for pred, align in zip(
                                    n_best_preds, n_best_preds_align)]
            all_predictions += [n_best_preds]

            if self.out_file:
                import json
                if self.data_type == "keyphrase":
                    # Keyphrase results are serialized as one JSON object
                    # per line (trans exposes __dict__() — project API).
                    self.out_file.write(
                        json.dumps(trans.__dict__()) + '\n')
                    self.out_file.flush()
                else:
                    self.out_file.write('\n'.join(n_best_preds) + '\n')
                    self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                if self.data_type == "keyphrase":
                    output = trans.log_kp(sent_number)
                else:
                    output = trans.log(sent_number)
                # NOTE(review): this inner `if self.verbose` is redundant —
                # it is already guaranteed by the outer check.
                if self.verbose:
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                output = report_matrix(srcs, preds, attns)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if align_debug:
                if trans.gold_sent is not None:
                    tgts = trans.gold_sent
                else:
                    tgts = trans.pred_sents[0]
                align = trans.word_aligns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(align[0]))]
                output = report_matrix(srcs, tgts, align)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            self._log(msg)
            if self.report_bleu:
                msg = self._report_bleu(tgt)
                self._log(msg)
            if self.report_rouge:
                msg = self._report_rouge(tgt)
                self._log(msg)
            if self.report_kpeval:
                # don't run eval here. because in opt.tgt rare words are
                # replaced by <unk>
                pass
                # msg = self._report_kpeval(opt.src, opt.tgt, opt.output)
                # self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        self._log("Average translation time (s): %f" % (
            total_time / len(all_predictions)))
        self._log("Tokens per second: %f" % (
            pred_words_total / total_time))

    if self.dump_beam:
        import json
        json.dump(self.translator.beam_accum,
                  codecs.open(self.dump_beam, 'w', 'utf-8'))

    return all_scores, all_predictions
def _translate(self,
               data,
               tgt=None,
               batch_size=None,
               batch_type="sents",
               attn_debug=False,
               align_debug=False,
               phrase_table="",
               dynamic=False,
               transform=None):
    """Run translation over an already-built ``data`` set and report stats.

    Core loop shared by the public entry points: iterates batches,
    decodes, writes n-best predictions to ``self.out_file``, and logs
    score/time statistics.

    Args:
        data: an inputters dataset (already constructed by the caller).
        tgt: target examples used for gold scoring; ``None`` disables it.
        batch_size (int): size of examples per mini-batch.
        batch_type (str): ``"sents"`` or ``"tokens"`` batching.
        attn_debug (bool): enables the attention logging.
        align_debug (bool): enables the word alignment logging.
        phrase_table (str): accepted but unused here; the builder reads
            ``self.phrase_table`` instead — TODO confirm intent.
        dynamic (bool): when True, ``transform.apply_reverse`` is applied
            to every prediction (dynamic-data/tokenization pipeline).
        transform: transform object used when ``dynamic`` is True.

    Returns:
        (`list`, `list`)

        * all_scores: per-example lists of `n_best` scores
        * all_predictions: per-example lists of `n_best` predictions
    """
    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self._dev,
        batch_size=batch_size,
        batch_size_fn=max_tok_len if batch_type == "tokens" else None,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False,
    )

    xlation_builder = onmt.translate.TranslationBuilder(
        data,
        self.fields,
        self.n_best,
        self.replace_unk,
        tgt,
        self.phrase_table,
    )

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    start_time = time.time()

    for batch in data_iter:
        batch_data = self.translate_batch(
            batch, data.src_vocabs, attn_debug)
        translations = xlation_builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                # +1 presumably accounts for the EOS token — TODO confirm
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            if self.report_align:
                # Append pharaoh-format alignments after the configured
                # separator token.
                align_pharaohs = [build_align_pharaoh(align)
                                  for align
                                  in trans.word_aligns[:self.n_best]]
                n_best_preds_align = [" ".join(align)
                                      for align in align_pharaohs]
                n_best_preds = [
                    pred + DefaultTokens.ALIGNMENT_SEPARATOR + align
                    for pred, align in zip(n_best_preds,
                                           n_best_preds_align)
                ]

            if dynamic:
                # Undo on-the-fly transforms (e.g. subword tokenization).
                n_best_preds = [
                    transform.apply_reverse(x) for x in n_best_preds
                ]
            all_predictions += [n_best_preds]
            self.out_file.write("\n".join(n_best_preds) + "\n")
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode("utf-8"))

            if attn_debug:
                # Print the attention matrix of the best hypothesis.
                preds = trans.pred_sents[0]
                preds.append(DefaultTokens.EOS)
                attns = trans.attns[0].tolist()
                if self.data_type == "text":
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                output = report_matrix(srcs, preds, attns)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode("utf-8"))

            if align_debug:
                # Print the word-alignment matrix against the best
                # hypothesis (gold is not consulted in this variant).
                tgts = trans.pred_sents[0]
                align = trans.word_aligns[0].tolist()
                if self.data_type == "text":
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(align[0]))]
                output = report_matrix(srcs, tgts, align)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode("utf-8"))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score("PRED", pred_score_total,
                                 pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score("GOLD", gold_score_total,
                                     gold_words_total)
            self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        self._log("Average translation time (s): %f" % (
            total_time / len(all_predictions)))
        self._log("Tokens per second: %f" % (
            pred_words_total / total_time))

    if self.dump_beam:
        import json
        json.dump(
            self.translator.beam_accum,
            codecs.open(self.dump_beam, "w", "utf-8"),
        )

    return all_scores, all_predictions