def translate(
        self,
        src,
        tgt=None,
        src_dir=None,
        batch_size=None,
        attn_debug=False,
        phrase_table=""):
    """Translate content of ``src`` and get gold scores from ``tgt``.

    Each example's n-best predictions are written to ``self.out_file``
    (one prediction per line) as soon as its batch has been decoded.

    Args:
        src: See :func:`self.src_reader.read()`.
        tgt: See :func:`self.tgt_reader.read()`.
        src_dir: See :func:`self.src_reader.read()` (only relevant
            for certain types of data). NOTE: not forwarded to the
            dataset here (the ``dirs=`` argument is commented out).
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging
        phrase_table (str): unused in this body; ``self.phrase_table``
            is passed to the translation builder instead.

    Returns:
        (`list`, `list`)

        * all_scores holds one entry per translated example, each a list
          of up to `n_best` scores
        * all_predictions holds one entry per translated example, each a
          list of up to `n_best` predictions (space-joined tokens)

    Raises:
        ValueError: if ``batch_size`` is None.
    """
    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.Dataset(
        self.fields,
        readers=([self.src_reader, self.tgt_reader]
                 if tgt else [self.src_reader]),
        data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
        # dirs=[src_dir, None] if tgt else [src_dir],
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred
    )

    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self._dev,
        batch_size=batch_size,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False
    )

    xlation_builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt,
        self.phrase_table
    )

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    start_time = time.time()

    for batch in data_iter:
        batch_data = self.translate_batch(
            batch, data.src_vocabs, attn_debug
        )
        translations = xlation_builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                # +1 accounts for the end-of-sentence token.
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *srcs) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    # Star-pad only the argmax column: star the first
                    # max_index+1 cells, then un-star the first max_index.
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            self._log(msg)
            if self.report_bleu:
                msg = self._report_bleu(tgt)
                self._log(msg)
            if self.report_rouge:
                msg = self._report_rouge(tgt)
                self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        # Guard against ZeroDivisionError on empty input (every example
        # filtered out) or a clock that did not advance.
        n_translated = len(all_predictions)
        self._log("Average translation time (s): %f" % (
            total_time / n_translated if n_translated else 0.0))
        self._log("Tokens per second: %f" % (
            pred_words_total / total_time if total_time > 0 else 0.0))

    if self.dump_beam:
        import json
        # Close the dump file deterministically instead of leaking the
        # handle. NOTE(review): ``self.translator`` is not defined in this
        # view; confirm the attribute exists on the owning class.
        with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
            json.dump(self.translator.beam_accum, beam_file)
    return all_scores, all_predictions
def translate(
    self,
    src_path=None,
    src_data_iter=None,
    tgt_path=None,
    tgt_data_iter=None,
    src_dir=None,
    batch_size=None,
    attn_debug=False,
):
    """
    Translate content of `src_data_iter` (if not None) or `src_path`
    and get gold scores if one of `tgt_data_iter` or `tgt_path` is set.

    Note: batch_size must not be None
    Note: one of ('src_path', 'src_data_iter') must not be None

    Unlike other variants, this body neither writes predictions to an
    output file nor reports scores; it only returns them (and optionally
    prints attention matrices to stdout).

    Args:
        src_path (str): filepath of source data
        src_data_iter (iterator): an iterator generating source data
            e.g. it may be a list or an opened file
        tgt_path (str): filepath of target data
        tgt_data_iter (iterator): an iterator generating target data
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    Returns:
        (`list`, `list`)

        * all_scores holds one entry per translated example, each a list
          of up to `n_best` scores
        * all_predictions holds one entry per translated example, each a
          list of up to `n_best` predictions (space-joined tokens)

    Raises:
        AssertionError: if both ``src_path`` and ``src_data_iter`` are None.
        ValueError: if ``batch_size`` is None.
    """
    assert src_data_iter is not None or src_path is not None

    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(
        self.fields,
        self.data_type,
        src_path=src_path,
        src_data_iter=src_data_iter,
        tgt_path=tgt_path,
        tgt_data_iter=tgt_data_iter,
        src_dir=src_dir,
        sample_rate=self.sample_rate,
        window_size=self.window_size,
        window_stride=self.window_stride,
        window=self.window,
        use_filter_pred=self.use_filter_pred,
        image_channel_size=self.image_channel_size,
    )

    if self.cuda:
        cur_device = "cuda"
    else:
        cur_device = "cpu"

    # sort_within_batch=True: examples inside a batch are reordered
    # (presumably by the dataset's sort key — confirm in OrderedIterator).
    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=cur_device,
        batch_size=batch_size,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False,
    )

    builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                self.n_best,
                                                self.replace_unk,
                                                tgt_path)

    # Statistics
    # NOTE(review): counter and the score/word totals are updated below but
    # never read in this body (no reporting step here).
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    for batch in data_iter:
        batch_data = self.translate_batch(batch, data, fast=self.fast)
        translations = builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt_path is not None:
                gold_score_total += trans.gold_score
                # +1 accounts for the end-of-sentence token.
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [
                " ".join(pred)
                for pred in trans.pred_sents[:self.n_best]
            ]
            all_predictions += [n_best_preds]

            # Debug attention.
            if attn_debug:
                # Appends '</s>' to the best hypothesis in place so it gets
                # a row in the attention matrix.
                preds = trans.pred_sents[0]
                preds.append("</s>")
                attns = trans.attns[0].tolist()
                if self.data_type == "text":
                    srcs = trans.src_raw
                else:
                    # Non-text sources have no tokens; label columns 0..N-1.
                    srcs = [str(item) for item in range(len(attns[0]))]
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *srcs) + "\n"
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    # Star-pad only the argmax column: star the first
                    # max_index+1 cells, then un-star the first max_index.
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + "\n"
                    # Reset the template for the next row.
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                # Written straight to file descriptor 1 (stdout).
                os.write(1, output.encode("utf-8"))

    return all_scores, all_predictions
def translate(self,
              src,
              src_pos,
              tgt=None,
              src_dir=None,
              batch_size=None,
              batch_type="sents",
              attn_debug=False,
              align_debug=False,
              phrase_table=""):
    """Translate content of ``src`` and get gold scores from ``tgt``.

    Each example's n-best predictions (optionally suffixed with
    ``" ||| "`` plus Pharaoh-format alignments when ``self.report_align``)
    are written to ``self.out_file``, one per line.

    Args:
        src: See :func:`self.src_reader.read()`.
        src_pos: data read by ``self.src_pos_reader`` as the ``src_pos``
            side of the dataset.
        tgt: See :func:`self.tgt_reader.read()`.
        src_dir: See :func:`self.src_reader.read()` (only relevant
            for certain types of data).
        batch_size (int): size of examples per mini-batch
        batch_type (str): ``"tokens"`` sizes batches with ``max_tok_len``;
            anything else counts sentences.
        attn_debug (bool): enables the attention logging
        align_debug (bool): enables the word alignment logging
        phrase_table (str): unused in this body; ``self.phrase_table``
            is passed to the translation builder instead.

    Returns:
        (`list`, `list`)

        * all_scores holds one entry per translated example, each a list
          of up to `n_best` scores
        * all_predictions holds one entry per translated example, each a
          list of up to `n_best` predictions

    Raises:
        ValueError: if ``batch_size`` is None.
    """
    if batch_size is None:
        raise ValueError("batch_size must be set")

    src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
    # src_dir here is useless.
    src_pos_data = {
        "reader": self.src_pos_reader,
        "data": src_pos,
        "dir": src_dir
    }
    tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
    _readers, _data, _dir = inputters.Dataset.config(
        [('src', src_data), ('src_pos', src_pos_data), ('tgt', tgt_data)])

    data = inputters.Dataset(
        self.fields,
        readers=_readers,
        data=_data,
        dirs=_dir,
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred)

    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self._dev,
        batch_size=batch_size,
        batch_size_fn=max_tok_len if batch_type == "tokens" else None,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False)

    xlation_builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt,
        self.phrase_table)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    start_time = time.time()

    for batch in data_iter:
        batch_data = self.translate_batch(batch, data.src_vocabs,
                                          attn_debug)
        translations = xlation_builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                # +1 accounts for the end-of-sentence token.
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [
                " ".join(pred) for pred in trans.pred_sents[:self.n_best]
            ]
            if self.report_align:
                align_pharaohs = [
                    build_align_pharaoh(align)
                    for align in trans.word_aligns[:self.n_best]
                ]
                n_best_preds_align = [
                    " ".join(align) for align in align_pharaohs
                ]
                n_best_preds = [
                    pred + " ||| " + align
                    for pred, align in zip(n_best_preds, n_best_preds_align)
                ]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                output = report_matrix(srcs, preds, attns)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if align_debug:
                # Prefer the gold target for row labels when available.
                if trans.gold_sent is not None:
                    tgts = trans.gold_sent
                else:
                    tgts = trans.pred_sents[0]
                align = trans.word_aligns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(align[0]))]
                output = report_matrix(srcs, tgts, align)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            self._log(msg)
            if self.report_bleu:
                msg = self._report_bleu(tgt)
                self._log(msg)
            if self.report_rouge:
                msg = self._report_rouge(tgt)
                self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        # Guard against ZeroDivisionError on empty input or a clock
        # that did not advance.
        n_translated = len(all_predictions)
        self._log("Average translation time (s): %f" %
                  (total_time / n_translated if n_translated else 0.0))
        self._log("Tokens per second: %f" %
                  (pred_words_total / total_time if total_time > 0 else 0.0))

    if self.dump_beam:
        import json
        # Close the dump file deterministically. NOTE(review): confirm
        # ``self.translator`` exists on the owning class.
        with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
            json.dump(self.translator.beam_accum, beam_file)
    return all_scores, all_predictions
def translate(self,
              src_path=None,
              src_data_iter=None,
              tgt_path=None,
              tgt_data_iter=None,
              src_dir=None,
              batch_size=None,
              attn_debug=False):
    """
    Translate content of `src_data_iter` (if not None) or `src_path`
    and get gold scores if one of `tgt_data_iter` or `tgt_path` is set.

    Each example's n-best predictions are written to ``self.out_file``,
    one per line, as soon as its batch has been decoded.

    Note: batch_size must not be None
    Note: one of ('src_path', 'src_data_iter') must not be None

    Args:
        src_path (str): filepath of source data
        src_data_iter (iterator): an iterator generating source data
            e.g. it may be a list or an opened file
        tgt_path (str): filepath of target data
        tgt_data_iter (iterator): an iterator generating target data
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    Returns:
        (`list`, `list`)

        * all_scores holds one entry per translated example, each a list
          of up to `n_best` scores
        * all_predictions holds one entry per translated example, each a
          list of up to `n_best` predictions

    Raises:
        AssertionError: if both ``src_path`` and ``src_data_iter`` are None.
        ValueError: if ``batch_size`` is None.
    """
    assert src_data_iter is not None or src_path is not None

    if batch_size is None:
        raise ValueError("batch_size must be set")

    def _emit(msg):
        # Route report lines to the logger when configured, else stdout.
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    data = inputters.build_dataset(self.fields,
                                   self.data_type,
                                   src_path=src_path,
                                   src_data_iter=src_data_iter,
                                   tgt_path=tgt_path,
                                   tgt_data_iter=tgt_data_iter,
                                   src_dir=src_dir,
                                   sample_rate=self.sample_rate,
                                   window_size=self.window_size,
                                   window_stride=self.window_stride,
                                   window=self.window,
                                   use_filter_pred=self.use_filter_pred)

    cur_device = "cuda" if self.cuda else "cpu"

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                self.n_best,
                                                self.replace_unk,
                                                tgt_path)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    for batch in data_iter:
        batch_data = self.translate_batch(batch, data, fast=self.fast)
        translations = builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt_path is not None:
                gold_score_total += trans.gold_score
                # +1 accounts for the end-of-sentence token.
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [
                " ".join(pred) for pred in trans.pred_sents[:self.n_best]
            ]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            # Debug attention.
            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                # Only text data has raw source tokens; label other data
                # types' columns by position (as the sibling variants do).
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *srcs) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    # Star-pad only the argmax column.
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                os.write(1, output.encode('utf-8'))

    if self.report_score:
        _emit(self._report_score('PRED', pred_score_total,
                                 pred_words_total))
        if tgt_path is not None:
            _emit(self._report_score('GOLD', gold_score_total,
                                     gold_words_total))
            if self.report_bleu:
                _emit(self._report_bleu(tgt_path))
            if self.report_rouge:
                _emit(self._report_rouge(tgt_path))

    if self.dump_beam:
        import json
        # Close the dump file deterministically. NOTE(review): confirm
        # ``self.translator`` exists on the owning class.
        with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
            json.dump(self.translator.beam_accum, beam_file)
    return all_scores, all_predictions
def translate(self,
              src,
              tgt=None,
              src_dir=None,
              batch_size=None,
              attn_debug=False,
              phrase_table="",
              opt=None):
    """Translate content of ``src`` and get gold scores from ``tgt``.

    For ``self.data_type == "keyphrase"`` each translation is written to
    ``self.out_file`` as one JSON object per line; otherwise the n-best
    predictions are written one per line. Example order is preserved
    (``sort_within_batch=False``).

    Args:
        src: See :func:`self.src_reader.read()`. ``len(src)`` is used for
            the progress message, so it must be sized.
        tgt: See :func:`self.tgt_reader.read()`.
        src_dir: See :func:`self.src_reader.read()` (only relevant
            for certain types of data).
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging
        phrase_table (str): unused in this body; ``self.phrase_table``
            is passed to the translation builder instead.
        opt: unused in this body (only referenced by commented-out code).

    Returns:
        (`list`, `list`)

        * all_scores holds one entry per translation, each a list of up
          to `n_best` scores
        * all_predictions holds one entry per translation, each a list of
          up to `n_best` predictions

    Raises:
        ValueError: if ``batch_size`` is None.
    """
    if batch_size is None:
        raise ValueError("batch_size must be set")

    # modified by @memray to accommodate keyphrase
    data = inputters.str2dataset[self.data_type](
        self.fields,
        readers=([self.src_reader, self.tgt_reader]
                 if tgt else [self.src_reader]),
        data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
        dirs=[src_dir, None] if tgt else [src_dir],
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred)

    # @memray, as Dataset is only instantiated here, having to use this
    # plugin setter
    if isinstance(data, KeyphraseDataset):
        data.tgt_type = self.tgt_type

    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self._dev,
        batch_size=batch_size,
        train=False,
        sort=False,
        # @memray: keep the original order instead of sort_within_batch=True
        sort_within_batch=False,
        shuffle=False)

    xlation_builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt,
        self.phrase_table)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    start_time = time.time()
    num_examples = 0

    for batch in data_iter:
        # Count the examples actually in this batch; adding the nominal
        # batch_size overshoots len(src) on the final, partial batch.
        num_examples += batch.batch_size
        print("Translating %d/%d" % (num_examples, len(src)))
        batch_data = self.translate_batch(batch, data.src_vocabs,
                                          attn_debug)
        translations = xlation_builder.from_batch(batch_data)

        # @memray
        if self.data_type == "keyphrase":
            # post-process for one2seq outputs, split seq into individual
            # phrases
            if self.model_tgt_type != 'one2one':
                translations = self.segment_one2seq_trans(translations)
            # add statistics of kps (pred_num, beamstep_num etc.)
            translations = self.add_trans_stats(translations,
                                                self.model_tgt_type)

            # add copied flag
            vocab_size = len(self.fields['src'].base_field.vocab.itos)
            for t in translations:
                t.add_copied_flags(vocab_size)

        for tran in translations:
            all_scores += [tran.pred_scores[:self.n_best]]
            pred_score_total += tran.pred_scores[0]
            pred_words_total += len(tran.pred_sents[0])
            if tgt is not None:
                gold_score_total += tran.gold_score
                # +1 accounts for the end-of-sentence token.
                gold_words_total += len(tran.gold_sent) + 1

            n_best_preds = [
                " ".join(pred) for pred in tran.pred_sents[:self.n_best]
            ]
            all_predictions += [n_best_preds]

            if self.data_type == "keyphrase":
                # NOTE(review): relies on the Translation class defining a
                # callable ``__dict__`` method; confirm in translation.py.
                self.out_file.write(json.dumps(tran.__dict__()) + '\n')
            else:
                self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                if self.data_type == "keyphrase":
                    output = tran.log_kp(sent_number)
                else:
                    output = tran.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if attn_debug:
                preds = tran.pred_sents[0]
                preds.append('</s>')
                attns = tran.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = tran.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *srcs) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    # Star-pad only the argmax column.
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            self._log(msg)
            if self.report_bleu:
                msg = self._report_bleu(tgt)
                self._log(msg)
            if self.report_rouge:
                msg = self._report_rouge(tgt)
                self._log(msg)
            if self.report_kpeval:
                # don't run eval here, because in opt.tgt rare words are
                # replaced by <unk>
                pass

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        # Guard against ZeroDivisionError on empty input or a clock
        # that did not advance.
        n_translated = len(all_predictions)
        self._log("Average translation time (s): %f" %
                  (total_time / n_translated if n_translated else 0.0))
        self._log("Tokens per second: %f" %
                  (pred_words_total / total_time if total_time > 0 else 0.0))

    if self.dump_beam:
        # Close the dump file deterministically. NOTE(review): confirm
        # ``self.translator`` exists on the owning class.
        with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
            json.dump(self.translator.beam_accum, beam_file)
    return all_scores, all_predictions
def translate(self,
              src,
              tgt=None,
              src_dir=None,
              batch_size=None,
              batch_type="sents",
              attn_debug=False,
              align_debug=False,
              phrase_table="",
              partial=None,
              partialfcheck=True,
              dymax_len=None,
              n_best=None):
    """Translate content of ``src`` and get gold scores from ``tgt``.

    Optionally constrains decoding with a user-supplied target prefix
    (``partial``): its tokens are stored as vocab ids in ``self.partial``.
    When ``partialfcheck`` is set, the last token is treated as an
    incomplete word and ``self.partialf`` receives the ids of every
    non-zero-index vocab entry that starts with it.

    Args:
        src: See :func:`self.src_reader.read()`.
        tgt: See :func:`self.tgt_reader.read()`.
        src_dir: See :func:`self.src_reader.read()` (only relevant
            for certain types of data).
        batch_size (int): size of examples per mini-batch
        batch_type (str): ``"tokens"`` sizes batches with ``max_tok_len``;
            anything else counts sentences.
        attn_debug (bool): enables the attention logging
        align_debug (bool): enables the word alignment logging
        phrase_table (str): unused in this body; ``self.phrase_table``
            is passed to the translation builder instead.
        partial (str): whitespace-separated target prefix; None, empty or
            whitespace-only means "no prefix".
        partialfcheck (bool): treat the last prefix token as a partial word.
        dymax_len: stored on ``self.dymax_len`` for use during decoding.
        n_best (int): if truthy and <= 5, overrides ``self.n_best``
            (persistently, for later calls too).

    Returns:
        If ``attn_debug``: ``(all_scores, all_predictions, attns,
        pred_score_total, pred_words_total)``, otherwise
        ``(all_scores, all_predictions, pred_score_total,
        pred_words_total)``.

        * all_scores holds one entry per translated example, each a list
          of up to `n_best` scores
        * all_predictions holds one entry per translated example, each a
          list of up to `n_best` predictions
        * attns is the attention matrix (nested lists) of the best
          hypothesis of the LAST translated example, or None if nothing
          was translated

    Raises:
        ValueError: if ``batch_size`` is None.
    """
    self.dymax_len = dymax_len
    if n_best and n_best <= 5:
        self.n_best = n_best
    self.partialf = None
    self.partialfcheck = partialfcheck

    # Logic for partial and partialf. A whitespace-only ``partial`` yields
    # no tokens and is treated as "no prefix" (it used to crash on
    # ``partials[-1]``).
    partials = partial.split() if partial else []
    if partials:
        print(partials, '~~~~partials~~~')
        vocabdict = dict(self.fields)["tgt"].base_field.vocab
        if partialfcheck:
            # Every token but the last is complete; the last one is a
            # word prefix matched with startswith (an edit-distance match
            # was considered but is not implemented).
            self.partial = [vocabdict.stoi[x] for x in partials[:-1]]
            print("#########vocabdict.stoi########")
            print(self.partial)
            print("##################################")
            # ``and v`` drops entries mapped to index 0.
            self.partialf = [
                v for k, v in vocabdict.stoi.items()
                if k.startswith(partials[-1]) and v
            ]
        else:
            self.partial = [vocabdict.stoi[x] for x in partials]
    else:
        self.partial = None

    if batch_size is None:
        raise ValueError("batch_size must be set")

    src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
    tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
    _readers, _data, _dir = inputters.Dataset.config([('src', src_data),
                                                      ('tgt', tgt_data)])

    # corpus_id field is useless here
    if self.fields.get("corpus_id", None) is not None:
        self.fields.pop('corpus_id')

    data = inputters.Dataset(
        self.fields,
        readers=_readers,
        data=_data,
        dirs=_dir,
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred)

    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self._dev,
        batch_size=batch_size,
        batch_size_fn=max_tok_len if batch_type == "tokens" else None,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False)

    xlation_builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt,
        self.phrase_table)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []
    # Defined up front so the attn_debug return never hits a NameError
    # when no example is translated.
    attns = None

    start_time = time.time()

    for batch in data_iter:
        batch_data = self.translate_batch(batch, data.src_vocabs,
                                          attn_debug)
        translations = xlation_builder.from_batch(batch_data)

        for trans in translations:
            print("Loop")
            print(trans, trans.pred_sents)
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                # +1 accounts for the end-of-sentence token.
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [
                " ".join(pred) for pred in trans.pred_sents[:self.n_best]
            ]
            print("############n_best_preds###############")
            print(n_best_preds)
            print("############n_best_preds###############")
            if self.report_align:
                align_pharaohs = [
                    build_align_pharaoh(align)
                    for align in trans.word_aligns[:self.n_best]
                ]
                n_best_preds_align = [
                    " ".join(align) for align in align_pharaohs
                ]
                n_best_preds = [
                    pred + " ||| " + align
                    for pred, align in zip(n_best_preds, n_best_preds_align)
                ]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                # Attention matrix of this example's best hypothesis.
                output = report_matrix(srcs, preds, attns)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if align_debug:
                if trans.gold_sent is not None:
                    tgts = trans.gold_sent
                else:
                    tgts = trans.pred_sents[0]
                align = trans.word_aligns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(align[0]))]
                output = report_matrix(srcs, tgts, align)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        # Guard against ZeroDivisionError on empty input or a clock
        # that did not advance.
        n_translated = len(all_predictions)
        self._log("Average translation time (s): %f" %
                  (total_time / n_translated if n_translated else 0.0))
        self._log("Tokens per second: %f" %
                  (pred_words_total / total_time if total_time > 0 else 0.0))

    if self.dump_beam:
        import json
        # Close the dump file deterministically. NOTE(review): confirm
        # ``self.translator`` exists on the owning class.
        with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
            json.dump(self.translator.beam_accum, beam_file)

    if attn_debug:
        return (all_scores, all_predictions, attns, pred_score_total,
                pred_words_total)
    return all_scores, all_predictions, pred_score_total, pred_words_total
def translate(self,
              src,
              tgt=None,
              agenda=None,
              src_dir=None,
              batch_size=None,
              attn_debug=False,
              tag_shard=None):
    """Translate content of ``src`` and get gold scores from ``tgt``.

    Each example's n-best predictions are written to ``self.out_file``,
    one per line.

    Args:
        src: See :func:`self.src_reader.read()`. Only read when
            ``self.src_reader`` is truthy.
        tgt: See :func:`self.tgt_reader.read()`.
        agenda: optional data read by ``self.agenda_reader``; when given,
            agenda accuracy is reported alongside the gold scores.
        src_dir: See :func:`self.src_reader.read()` (only relevant
            for certain types of data).
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging
        tag_shard: optional sequence indexed by example index; each entry
            is a per-source-token list of numbers that is padded into a
            float tensor and passed to ``translate_batch`` as ``tags``.

    Returns:
        (`list`, `list`)

        * all_scores holds one entry per translated example, each a list
          of up to `n_best` scores
        * all_predictions holds one entry per translated example, each a
          list of up to `n_best` predictions

    Raises:
        ValueError: if ``batch_size`` is None, or if the padded tag
            tensor's length disagrees with the batch's source lengths.
    """
    if batch_size is None:
        raise ValueError("batch_size must be set")

    # Assemble only the dataset sides that are actually present.
    readers, side_data, dirs = [], [], []
    if self.src_reader:
        readers += [self.src_reader]
        side_data += [("src", src)]
        dirs += [src_dir]
    if tgt:
        readers += [self.tgt_reader]
        side_data += [("tgt", tgt)]
        dirs += [None]
    if agenda:
        readers += [self.agenda_reader]
        side_data += [("agenda", agenda)]
        dirs += [None]

    data = inputters.Dataset(
        self.fields,
        readers=readers,
        data=side_data,
        dirs=dirs,
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred)

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=self._dev,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    xlation_builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    start_time = time.time()

    for batch in data_iter:
        if tag_shard is not None:
            max_len = max([len(tag_shard[i]) for i in batch.indices])
            cur_tags = [
                torch.tensor(tag_shard[i],
                             device=batch.src[0].device,
                             dtype=torch.float) for i in batch.indices
            ]
            cur_tags_padded = torch.nn.utils.rnn.pad_sequence(
                cur_tags, padding_value=0)
            max_src_len = batch.src[1].max().item()
            # The tags must line up with the padded source length.
            if (batch.src[0].shape[0] != max_src_len
                    or cur_tags_padded.shape[0] != max_len
                    or max_len != max_src_len):
                raise ValueError(
                    "tag/source length mismatch: src shape %s, max src "
                    "len %d, padded tags shape %s, max tag len %d" %
                    (tuple(batch.src[0].shape), max_src_len,
                     tuple(cur_tags_padded.shape), max_len))
        else:
            cur_tags_padded = None
        batch_data = self.translate_batch(batch,
                                          data.src_vocabs,
                                          attn_debug,
                                          tags=cur_tags_padded)
        translations = xlation_builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                # +1 accounts for the end-of-sentence token.
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [
                " ".join(pred) for pred in trans.pred_sents[:self.n_best]
            ]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *srcs) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    # Star-pad only the argmax column.
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                # Honour the logger like the verbose branch does.
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total,
                                 pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            self._log(msg)
            if self.report_bleu:
                msg = self._report_bleu(tgt)
                self._log(msg)
            if self.report_rouge:
                msg = self._report_rouge(tgt)
                self._log(msg)
            # NOTE(review): nesting of this report was reconstructed from
            # a collapsed source; confirm it belongs under the gold check.
            if agenda:
                msg = self._report_agenda_accuray(agenda)
                self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        # Guard against ZeroDivisionError on empty input or a clock
        # that did not advance.
        n_translated = len(all_predictions)
        self._log("Average translation time (s): %f" %
                  (total_time / n_translated if n_translated else 0.0))
        self._log("Tokens per second: %f" %
                  (pred_words_total / total_time if total_time > 0 else 0.0))

    if self.dump_beam:
        import json
        # Close the dump file deterministically. NOTE(review): confirm
        # ``self.translator`` exists on the owning class.
        with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
            json.dump(self.translator.beam_accum, beam_file)
    return all_scores, all_predictions
def translate(self,
              src_path=None,
              src_data_iter=None,
              tgt_path=None,
              tgt_data_iter=None,
              src_dir=None,
              batch_size=None,
              attn_debug=False,
              node_type_seq=None,
              atc=None):
    """
    Translate content of `src_data_iter` (if not None) or `src_path`,
    decoding every example once per candidate node-type sequence, and get
    gold scores if `tgt_path` is set.

    Note: batch_size must not be None
    Note: one of ('src_path', 'src_data_iter') must not be None

    Args:
        src_path (str): filepath of source data
        src_data_iter (iterator): an iterator generating source data
            e.g. it may be a list or an opened file
        tgt_path (str): filepath of target data
        tgt_data_iter (iterator): an iterator generating target data
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): accepted for interface compatibility; unused here
        node_type_seq: pair where ``[0]`` holds the candidate node-type
            sequences and ``[1]`` their scores, each indexed by example
            index; only the first ``self.option.tree_count`` are tried
        atc: optional per-example data indexed by example index and
            forwarded to ``translate_batch(atc=...)``

    Returns:
        (`list`, `list`)

        * all_scores: one list per example holding the `n_best` scores of
          every tried node-type sequence, each offset by that sequence's
          score
        * all_predictions: one flat list per example holding the matching
          `n_best` predictions
    """
    assert src_data_iter is not None or src_path is not None
    assert node_type_seq is not None, 'Node Types must be provided'
    node_type_scores = node_type_seq[1]
    node_type_seq = node_type_seq[0]

    if batch_size is None:
        raise ValueError("batch_size must be set")
    data = inputters.build_dataset(self.fields,
                                   self.data_type,
                                   src_path=src_path,
                                   src_data_iter=src_data_iter,
                                   tgt_path=tgt_path,
                                   tgt_data_iter=tgt_data_iter,
                                   src_dir=src_dir,
                                   sample_rate=self.sample_rate,
                                   window_size=self.window_size,
                                   window_stride=self.window_stride,
                                   window=self.window,
                                   use_filter_pred=self.use_filter_pred)

    cur_device = "cuda" if self.cuda else "cpu"

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                self.n_best,
                                                self.replace_unk,
                                                tgt_path)

    # Statistics
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    def check_correctness(preds, gold):
        """Return 1 if any prediction equals `gold` (ignoring surrounding
        whitespace), else 0."""
        gold = gold.strip()
        return int(any(p.strip() == gold for p in preds))

    total_correct = 0
    for bidx, batch in enumerate(data_iter):
        # Only 1 item in this batch, guaranteed (.item() requires it).
        example_idx = batch.indices.item()
        if bidx % 50 == 0:
            debug('Current Example : ', example_idx)
        nt_sequences = node_type_seq[example_idx]
        nt_scores = node_type_scores[example_idx]
        atc_item = atc[example_idx] if atc is not None else None

        scores = []
        predictions = []
        # An example is counted as correct at most once, however many of
        # its node-type sequences produce the gold output.
        example_correct = False
        tree_count = self.option.tree_count
        for type_sequence, type_score in zip(nt_sequences[:tree_count],
                                             nt_scores[:tree_count]):
            batch_data = self.translate_batch(batch,
                                              data,
                                              node_type_str=type_sequence,
                                              atc=atc_item)
            translations = builder.from_batch(batch_data)
            for trans in translations:
                scores += [score + type_score
                           for score in trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [" ".join(pred)
                                for pred in trans.pred_sents[:self.n_best]]
                # gold_sent is None when no target was given to the builder.
                if not example_correct and trans.gold_sent is not None:
                    gold_sent = ' '.join(trans.gold_sent)
                    if check_correctness(n_best_preds, gold_sent):
                        total_correct += 1
                        example_correct = True
                predictions += n_best_preds
        all_scores += [scores]
        all_predictions += [predictions]

    if self.dump_beam:
        import json
        # NOTE(review): `self.translator` is not defined in this view;
        # confirm it exists (vs. `self.beam_accum`).
        beam_accum = self.translator.beam_accum
        with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
            json.dump(beam_accum, beam_file)

    debug(total_correct)
    return all_scores, all_predictions
def translate(self,
              src,
              tgt=None,
              src_dir=None,
              batch_size=None,
              attn_debug=False):
    """
    Translate content of `src` and get gold scores from `tgt` if it is set.

    Note: batch_size must not be None

    Args:
        src: source data forwarded to ``inputters.build_dataset``;
            must not be None
        tgt: target data forwarded to ``inputters.build_dataset``, or None
        src_dir (str): source directory path
            (used for Audio and Image datasets)
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    When ``self.attn_output`` is set, one attention heatmap per n-best
    hypothesis is saved there as ``<sentence number>_<n-best index>.png``,
    trimmed to the source rows holding a value above
    ``self.attn_min_threshold`` and skipped when the trimmed source is longer
    than ``self.attn_max_src_length``.

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists of `n_best`
          predictions
    """
    assert src is not None

    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.build_dataset(
        self.fields,
        self.data_type,
        src=src,
        tgt=tgt,
        src_dir=src_dir,
        sample_rate=self.sample_rate,
        window_size=self.window_size,
        window_stride=self.window_stride,
        window=self.window,
        use_filter_pred=self.use_filter_pred,
        image_channel_size=self.image_channel_size,
        dynamic_dict=self.copy_attn)

    cur_device = "cuda" if self.cuda else "cpu"

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                self.n_best,
                                                self.replace_unk, tgt)

    def _emit(msg):
        """Log through self.logger when configured, else print."""
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    def _active_row_span(attn_matrix, threshold):
        """Return (first, last) indices of the rows of `attn_matrix` that
        hold any value above `threshold`, or (None, None) if none do."""
        active = [r for r in range(attn_matrix.shape[0])
                  if (attn_matrix[r, :] > threshold).any()]
        if not active:
            return None, None
        return active[0], active[-1]

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    for batch_count, batch in enumerate(data_iter):
        batch_data = self.translate_batch(batch, data, attn_debug,
                                          fast=self.fast)
        translations = builder.from_batch(batch_data)

        for trans_count, trans in enumerate(translations):
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            all_predictions += [n_best_preds]
            if self.out_file is not None:
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            # Debug attention.
            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *srcs) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                os.write(1, output.encode('utf-8'))

            if self.attn_output is not None:
                fig, ax = plt.subplots(figsize=(20, 20))
                for n_best_index in range(self.n_best):
                    # .cpu() so this also works when decoding on a GPU.
                    attn_matrix = (
                        trans.attns[n_best_index].data.cpu().numpy().T)
                    # Recomputed per hypothesis so a failed search never
                    # reuses the previous hypothesis' span.
                    start_index, end_index = _active_row_span(
                        attn_matrix, self.attn_min_threshold)
                    if start_index is None:
                        # os.write needs bytes, not str.
                        os.write(1, ("Cannot find attention values more than "
                                     + str(self.attn_min_threshold)
                                     ).encode('utf-8'))
                        os.write(1, "Please lower attn_min_threshold"
                                 .encode('utf-8'))
                        continue
                    attn_matrix = attn_matrix[start_index:end_index + 1, :]
                    src_raw = trans.src_raw[start_index:end_index + 1]
                    if len(src_raw) > self.attn_max_src_length:
                        continue
                    draw(attn_matrix, trans.pred_sents[n_best_index],
                         src_raw, ax)
                    plt.savefig(
                        os.path.join(
                            self.attn_output,
                            str(batch_count * batch_size + trans_count + 1)
                            + "_" + str(n_best_index) + ".png"))
                    plt.cla()
                plt.close(fig)

    if self.report_score:
        _emit(self._report_score('PRED', pred_score_total,
                                 pred_words_total))
        if tgt is not None:
            _emit(self._report_score('GOLD', gold_score_total,
                                     gold_words_total))
            if self.report_bleu:
                _emit(self._report_bleu(tgt))
            if self.report_rouge:
                _emit(self._report_rouge(tgt))

    if self.dump_beam:
        import json
        # NOTE(review): `self.translator` is not defined in this view;
        # confirm it exists (vs. `self.beam_accum`).
        beam_accum = self.translator.beam_accum
        with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
            json.dump(beam_accum, beam_file)
    return all_scores, all_predictions
def save_hidden_states(opt, args):
    """Run a checkpointed model over the train/valid/test splits and save
    the hidden states at every non-padding target position.

    For each split two files are written to
    ``args['reporter']['results_path']``:

    * ``cache.<split>.npy``: the selected hidden states stacked with
      ``np.vstack``.
    * ``cache_idxs.<split>.csv``: one ``<example index>,<length>`` line per
      example, where length is the number of non-padding positions in the
      example's target after its first token (``tgt[1:]``).

    At most the first 10001 batches of each split are processed.

    Args:
        opt: options object; ``opt.save_model`` is the directory holding
            the checkpoint.
        args (dict): nested config; reads ``['onmt']['translate']['model']``,
            ``['data']['test_path']``, ``['lm']['use_eos']``, ``['device']``
            and ``['reporter']['results_path']``.
    """
    OnmtArgumentParser.update_model_opts(opt)
    OnmtArgumentParser.validate_model_opts(opt)

    # load model
    model_path = os.path.join(opt.save_model,
                              args['onmt']['translate']['model'])
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_opt = OnmtArgumentParser.ckpt_model_opts(checkpoint["opt"])
    OnmtArgumentParser.update_model_opts(model_opt)
    OnmtArgumentParser.validate_model_opts(model_opt)
    vocab = checkpoint['vocab']
    model = build_model(model_opt, opt, vocab, checkpoint)
    # Inference only: disable dropout so the cached states are deterministic.
    model.eval()

    # The vocab does not change between splits, so look these up once.
    tgt_field = vocab["tgt"].base_field
    tgt_pad_idx = tgt_field.vocab.stoi[tgt_field.pad_token]
    save_path = args['reporter']['results_path']

    for split in ('train', 'valid', 'test'):
        cache = []
        # (index into the split's data, number of non-pad target positions)
        cache_idxs = []

        if split == 'test':
            test_file = args['data']['test_path']
            with open(test_file, "r") as test_handle:
                test_data = test_handle.readlines()
            # Each line is "<src>\t<tgt>".
            test_src, test_tgt = zip(*[line.split("\t")
                                       for line in test_data])
            src_reader = inputters.str2reader["text"].from_opt(opt)
            tgt_reader = inputters.str2reader["text"].from_opt(opt)
            src_data = {"reader": src_reader, "data": test_src, "dir": None}
            tgt_data = {"reader": tgt_reader, "data": test_tgt, "dir": None}
            _readers, _data, _dir = inputters.Dataset.config([
                ('src', src_data), ('tgt', tgt_data)
            ])
            data = inputters.Dataset(vocab,
                                     readers=_readers,
                                     data=_data,
                                     dirs=_dir,
                                     sort_key=inputters.str2sortkey["text"],
                                     filter_pred=None)
            batch_iter = inputters.OrderedIterator(dataset=data,
                                                   device=args['device'],
                                                   batch_size=64,
                                                   batch_size_fn=None,
                                                   train=False,
                                                   sort=False,
                                                   sort_within_batch=True,
                                                   shuffle=False)
        else:
            train_dataset_paths = get_dataset_paths(
                opt, split, eos=args['lm']['use_eos'])
            batch_iter = DatasetLazyIter(
                train_dataset_paths,
                vocab,  # vocab
                64,  # batch size
                None,  # "batch_fn"
                1,  # "batch_size_multiple"
                args['device'],  # device
                True,  # is train
                8192,  # pool factor
                repeat=False,
                num_batches_multiple=1,
                yield_raw_example=False)

        for batch_i, batch in tqdm(enumerate(batch_iter),
                                   desc=f"[{split}]"):
            # Cap the amount of work per split.
            if batch_i > 10000:
                break
            # run through model
            src, src_lengths = batch.src
            tgt = batch.tgt
            with torch.no_grad():
                hidden_states, _ = model(src, tgt, src_lengths,
                                         bptt=False, with_align=False)

            # save src idxs and hidden states
            pad_masks = (tgt[1:] != tgt_pad_idx).squeeze(2)
            # NOTE(review): boolean-mask indexing flattens in the mask's
            # dimension order; if dim 0 is time (as tgt[1:] suggests), rows
            # of different examples are interleaved rather than contiguous.
            # Confirm consumers of cache_idxs expect that ordering.
            cache.extend(hidden_states[pad_masks].cpu().numpy())
            cache_idxs.extend(
                (batch.indices[i].item(), pad_masks[:, i].sum().item())
                for i in range(pad_masks.size(1)))

        cache = np.vstack(cache)

        # save the cache and the cache indices
        print(save_path)
        print(cache.shape)
        np.save(os.path.join(save_path, f"cache.{split}.npy"), cache)
        with open(os.path.join(save_path, f"cache_idxs.{split}.csv"),
                  "w") as csvfile:
            csvfile.write("\n".join(
                f"{idx},{length}" for idx, length in cache_idxs))
def translate(
    self,
    src,
    tgt=None,
    batch_size=None,
    batch_type="sents",
    attn_debug=False,
    align_debug=False,
    phrase_table="",
):
    """Translate content of ``src`` and get gold scores from ``tgt``.

    Args:
        src: See :func:`self.src_reader.read()`.
        tgt: See :func:`self.tgt_reader.read()`.
        batch_size (int): size of examples per mini-batch
        batch_type (str): ``"tokens"`` to size batches by token count
            (via ``max_tok_len``), anything else sizes them by sentences
        attn_debug (bool): enables the attention logging
        align_debug (bool): enables the word alignment logging
        phrase_table (str): accepted for interface compatibility; the
            builder is given ``self.phrase_table`` instead.

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists of `n_best`
          predictions
    """

    if batch_size is None:
        raise ValueError("batch_size must be set")

    if self.tgt_prefix and tgt is None:
        raise ValueError("Prefix should be feed to tgt if -tgt_prefix.")

    src_data = {"reader": self.src_reader, "data": src}
    tgt_data = {"reader": self.tgt_reader, "data": tgt}
    _readers, _data = inputters.Dataset.config([("src", src_data),
                                                ("tgt", tgt_data)])

    data = inputters.Dataset(
        self.fields,
        readers=_readers,
        data=_data,
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred,
    )

    data_iter = inputters.OrderedIterator(
        dataset=data,
        device=self._dev,
        batch_size=batch_size,
        batch_size_fn=max_tok_len if batch_type == "tokens" else None,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False,
    )

    xlation_builder = onmt.translate.TranslationBuilder(
        data,
        self.fields,
        self.n_best,
        self.replace_unk,
        tgt,
        self.phrase_table,
    )

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    start_time = time.time()

    for batch in data_iter:
        batch_data = self.translate_batch(batch, data.src_vocabs, attn_debug)
        translations = xlation_builder.from_batch(batch_data)

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [
                " ".join(pred) for pred in trans.pred_sents[:self.n_best]
            ]
            if self.report_align:
                align_pharaohs = [
                    build_align_pharaoh(align)
                    for align in trans.word_aligns[:self.n_best]
                ]
                n_best_preds_align = [
                    " ".join(align) for align in align_pharaohs
                ]
                n_best_preds = [
                    pred + DefaultTokens.ALIGNMENT_SEPARATOR + align
                    for pred, align in zip(n_best_preds, n_best_preds_align)
                ]
            all_predictions += [n_best_preds]
            self.out_file.write("\n".join(n_best_preds) + "\n")
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode("utf-8"))

            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append(DefaultTokens.EOS)
                attns = trans.attns[0].tolist()
                if self.data_type == "text":
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                output = report_matrix(srcs, preds, attns)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode("utf-8"))

            if align_debug:
                tgts = trans.pred_sents[0]
                align = trans.word_aligns[0].tolist()
                if self.data_type == "text":
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(align[0]))]
                output = report_matrix(srcs, tgts, align)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode("utf-8"))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score("PRED", pred_score_total, pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score("GOLD", gold_score_total,
                                     gold_words_total)
            self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        # Guard the divisions: every example may have been filtered out,
        # and a very short run can measure 0 seconds.
        if all_predictions:
            self._log("Average translation time (s): %f" %
                      (total_time / len(all_predictions)))
        if total_time > 0:
            self._log("Tokens per second: %f" %
                      (pred_words_total / total_time))

    if self.dump_beam:
        import json

        # NOTE(review): `self.translator` is not defined in this view;
        # confirm it exists (vs. `self.beam_accum`).
        beam_accum = self.translator.beam_accum
        with codecs.open(self.dump_beam, "w", "utf-8") as beam_file:
            json.dump(beam_accum, beam_file)
    return all_scores, all_predictions
def translate(self,
              src_path=None,
              src_data_iter=None,
              tgt_path=None,
              tgt_data_iter=None,
              batch_size=None,
              attn_debug=False):
    """
    Translate content of `src_data_iter` (if not None) or `src_path`,
    evaluate the click (SKU) ranking, and get gold scores if `tgt_path`
    is set.

    Note: batch_size must not be None
    Note: one of ('src_path', 'src_data_iter') must not be None

    For every example the SKUs are ranked by descending
    ``batch_data["score_sku_prob"]``; Recall@k and MRR@k against the gold
    SKU are logged for k in (1, 5, 10, 15, 20). When gold sentences are
    available, ROUGE-1/2/L of the best hypothesis against the gold sentence
    is logged as well.

    Args:
        src_path (str): filepath of source data
        src_data_iter (iterator): an iterator generating source data
            e.g. it may be a list or an opened file
        tgt_path (str): filepath of target data
        tgt_data_iter (iterator): an iterator generating target data
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists of `n_best`
          predictions
    """
    assert src_data_iter is not None or src_path is not None

    if batch_size is None:
        raise ValueError("batch_size must be set")
    data = inputters.build_dataset(self.fields,
                                   self.data_type,
                                   src_path=src_path,
                                   src_data_iter=src_data_iter,
                                   tgt_path=tgt_path,
                                   tgt_data_iter=tgt_data_iter,
                                   sample_rate=self.sample_rate,
                                   window_size=self.window_size,
                                   window_stride=self.window_stride,
                                   window=self.window,
                                   use_filter_pred=self.use_filter_pred,
                                   topk_keywords=self.topk_keywords)

    cur_device = "cuda" if self.cuda else "cpu"

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=cur_device,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                self.n_best,
                                                self.replace_unk,
                                                tgt_path)

    def _log(msg):
        """Log through self.logger when configured, else print."""
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    # Click-prediction metrics: one value per example for every cutoff.
    cutoffs = (1, 5, 10, 15, 20)
    recall_results = {k: [] for k in cutoffs}
    mrr_results = {k: [] for k in cutoffs}

    # Best hypothesis / gold sentence pairs for ROUGE.
    pred_sents_list = []
    gold_sents_list = []

    for batch in data_iter:
        batch_data = self.translate_batch(batch,
                                          data,
                                          fast=self.fast,
                                          skip_keyphrase=self.skip_keyphrase)

        # for click prediction
        scores_sku = batch_data["score_sku_prob"]  # [b X 20]
        # Tensor.argsort keeps everything in torch, so this also works for
        # CUDA tensors (np.argsort cannot convert those).
        sorted_index_t = scores_sku.detach().argsort(
            dim=-1, descending=True)[:, :max(cutoffs)]
        batch_data["pred_sku"] = sorted_index_t
        # reshape(-1) rather than squeeze(): a batch of one must still give
        # a one-element list, not a scalar.
        target_sku = batch_data["batch"].tgt_sku[0].reshape(-1)

        for pred_skus, gold_sku in zip(sorted_index_t.tolist(),
                                       target_sku.tolist()):
            for k in cutoffs:
                top_k = pred_skus[:k]
                if gold_sku in top_k:
                    recall_results[k].append(1)
                    mrr_results[k].append(1.0 / (top_k.index(gold_sku) + 1))
                else:
                    recall_results[k].append(0)
                    mrr_results[k].append(0)

        # if batch_data only contains batch, score_sku_prob and pred_sku,
        # skip the translation of decoder
        if len(batch_data.keys()) == 3:
            continue

        translations = builder.from_batch(batch_data)
        for tran in translations:
            # gold_sent is None when no target was given to the builder.
            if tran.gold_sent is not None:
                pred_sents_list.append(' '.join(tran.pred_sents[0]))
                gold_sents_list.append(' '.join(tran.gold_sent))

        for trans in translations:
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt_path is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            # Debug attention.
            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *srcs) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                os.write(1, output.encode('utf-8'))

    # print click results
    for k in cutoffs:
        _log('Recall@{}: {}\t\tMrr@{}: {}'.format(
            k, np.array(recall_results[k]).mean(),
            k, np.array(mrr_results[k]).mean()))

    if gold_sents_list:
        rouge = Rouge()
        rouge_scores = rouge.get_scores(pred_sents_list,
                                        gold_sents_list,
                                        avg=True)
        for key, label in (('rouge-1', 'Rouge-1'),
                           ('rouge-2', 'Rouge-2'),
                           ('rouge-l', 'Rouge-L')):
            _log('{}:\tP:{}\t\tR:{}\t\tF1:{}'.format(
                label, rouge_scores[key]['p'], rouge_scores[key]['r'],
                rouge_scores[key]['f']))

    if self.report_score:
        _log(self._report_score('PRED', pred_score_total, pred_words_total))
        if tgt_path is not None:
            _log(self._report_score('GOLD', gold_score_total,
                                    gold_words_total))
            if self.report_bleu:
                _log(self._report_bleu(tgt_path))
            if self.report_rouge:
                _log(self._report_rouge(tgt_path))

    if self.dump_beam:
        import json
        # NOTE(review): `self.translator` is not defined in this view;
        # confirm it exists (vs. `self.beam_accum`).
        beam_accum = self.translator.beam_accum
        with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
            json.dump(beam_accum, beam_file)
    return all_scores, all_predictions
def _embed_shards(translator, opt, shards):
    """Embed every batch of `shards` with the translator's encoder embedding.

    Returns the embedding of the last batch processed, or None when the
    shards yield no batches.
    """
    embed = None
    for shard in shards:
        src_data = {
            "reader": translator.src_reader,
            "data": shard,
            "dir": opt.src_dir
        }
        _readers, _data, _dir = inputters.Dataset.config([('src', src_data)])
        data = inputters.Dataset(
            translator.fields,
            readers=_readers,
            data=_data,
            dirs=_dir,
            sort_key=inputters.str2sortkey[translator.data_type],
            filter_pred=translator._filter_pred)
        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=translator._dev,
            batch_size=opt.batch_size,
            batch_size_fn=(translator.max_tok_len
                           if opt.batch_type == "tokens" else None),
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)
        for batch in data_iter:
            src, src_lengths = batch.src
            # Each batch overwrites the previous one; only the last survives.
            embed = translator.model.encoder.embed(src, src_lengths)
    return embed


def translate(opt):
    """Return the source and baseline embeddings.

    ``opt.src`` and ``opt.baseline`` are split into shards of
    ``opt.shard_size`` and run through the translator's encoder embedding.
    As before, only the embedding of the LAST batch of each input is
    returned, so inputs are expected to fit in a single batch.

    Returns:
        (src_embed, bline_embed)

    Raises:
        ValueError: if either input produced no batches.
    """
    ArgumentParser.validate_translate_opts(opt)
    # Called for its side effect of configuring logging.
    init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    baseline_shards = split_corpus(opt.baseline, opt.shard_size)

    print("\nEmbedding source and baseline...\n")

    src_embed = _embed_shards(translator, opt, src_shards)
    bline_embed = _embed_shards(translator, opt, baseline_shards)
    if src_embed is None or bline_embed is None:
        raise ValueError("source and baseline must each yield at least "
                         "one batch to embed")

    return src_embed, bline_embed
def translate_gold_diff(self,
                        src,
                        tgt=None,
                        tgt2=None,
                        src_dir=None,
                        batch_size=None,
                        batch_type="sents",
                        attn_debug=False,
                        align_debug=False,
                        phrase_table="",
                        src_embed=None,
                        hidden_state=None,
                        unlearn=False):
    """Translate content of ``src`` and get the difference of the
    exponentiated gold scores of ``tgt`` and ``tgt2``.

    Args:
        src: See :func:`self.src_reader.read()`.
        tgt: See :func:`self.tgt_reader.read()`.
        tgt2: second target, read with ``self.tgt2_reader``; when None the
            second score is 0.
        src_dir: See :func:`self.src_reader.read()` (only relevant
            for certain types of data).
        batch_size (int): size of examples per mini-batch
        batch_type (str): ``"tokens"`` to size batches by token count
        attn_debug (bool): enables the attention logging
        align_debug (bool): accepted for interface compatibility; unused
        phrase_table (str): accepted for interface compatibility; unused
        src_embed: embeddings of the source
        hidden_state (torch tensor): output of the encoder layers
        unlearn (bool): when True the ``tgt`` pass runs under
            ``torch.no_grad()``

    Returns:
        ``exp(gold_scores_1) - exp(gold_scores_2)`` (torch tensor), using the
        scores of the last batch of each pass. The ``tgt2`` pass reuses the
        encoder outputs of the last ``tgt`` batch.

    Raises:
        ValueError: if ``batch_size`` is None or ``tgt`` yields no batches.
    """
    if batch_size is None:
        raise ValueError("batch_size must be set")

    def _make_iter(src_side, tgt_side):
        """Build the dataset and ordered iterator for one (src, tgt) pair."""
        _readers, _data, _dir = inputters.Dataset.config([('src', src_side),
                                                          ('tgt', tgt_side)])
        dataset = inputters.Dataset(
            self.fields,
            readers=_readers,
            data=_data,
            dirs=_dir,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)
        iterator = inputters.OrderedIterator(
            dataset=dataset,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)
        return dataset, iterator

    src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
    tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
    data, data_iter = _make_iter(src_data, tgt_data)

    def _first_pass():
        """Score `tgt`; return the last batch's translate_batch output."""
        result = None
        for batch in data_iter:
            result = self.translate_batch(batch,
                                          data.src_vocabs,
                                          attn_debug,
                                          src_embed=src_embed,
                                          hidden_state=hidden_state)
        return result

    if unlearn:
        with torch.no_grad():
            first = _first_pass()
    else:
        first = _first_pass()
    if first is None:
        raise ValueError("tgt produced no batches to score")
    gold_scores_1, src_tensor, enc_states, memory_bank, src_lengths = first
    gold_scores_1 = torch.exp(gold_scores_1)

    if tgt2 is None:
        gold_scores_2 = 0
    else:
        tgt2_data = {"reader": self.tgt2_reader, "data": tgt2, "dir": None}
        data, data_iter = _make_iter(src_data, tgt2_data)
        for batch in data_iter:
            # Reuse the encoder outputs of the tgt pass.
            gold_scores_2 = self.translate_batch(batch,
                                                 data.src_vocabs,
                                                 attn_debug,
                                                 src_tensor,
                                                 enc_states,
                                                 memory_bank,
                                                 src_lengths,
                                                 src_embed,
                                                 tgt2=True,
                                                 hidden_state=hidden_state)
        gold_scores_2 = torch.exp(gold_scores_2)

    return gold_scores_1 - gold_scores_2
def translate(self,
              src,
              tgt=None,
              src_dir=None,
              batch_size=None,
              attn_debug=False,
              attn_vis=False):
    """Translate content of ``src`` and get gold scores from ``tgt``.

    Args:
        src: See :func:`self.src_reader.read()`.
        tgt: See :func:`self.tgt_reader.read()`.
        src_dir: See :func:`self.src_reader.read()` (only relevant
            for certain types of data).
        batch_size (int): size of examples per mini-batch
        attn_debug (bool): enables the attention logging
        attn_vis (bool): shows seaborn heatmaps of the best hypothesis'
            attention over the source, and of encoder self-attention heads
            0-3 for layers 0, 2 and 4

    Returns:
        (`list`, `list`)

        * all_scores is a list of `batch_size` lists of `n_best` scores
        * all_predictions is a list of `batch_size` lists of `n_best`
          predictions
    """
    if batch_size is None:
        raise ValueError("batch_size must be set")

    data = inputters.Dataset(
        self.fields,
        readers=([self.src_reader, self.tgt_reader]
                 if tgt else [self.src_reader]),
        data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
        dirs=[src_dir, None] if tgt else [src_dir],
        sort_key=inputters.str2sortkey[self.data_type],
        filter_pred=self._filter_pred)

    data_iter = inputters.OrderedIterator(dataset=data,
                                          device=self._dev,
                                          batch_size=batch_size,
                                          train=False,
                                          sort=False,
                                          sort_within_batch=True,
                                          shuffle=False)

    xlation_builder = onmt.translate.TranslationBuilder(
        data, self.fields, self.n_best, self.replace_unk, tgt)

    def draw(values, x, y, ax, title):
        """Draw a heatmap of `values` (scaled to [0, 1]) on `ax`."""
        ax.set_title(title)
        seaborn.heatmap(values,
                        xticklabels=x,
                        square=True,
                        yticklabels=y,
                        vmin=0.0,
                        vmax=1.0,
                        cbar=False,
                        ax=ax)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0

    all_scores = []
    all_predictions = []

    start_time = time.time()

    for batch in data_iter:
        batch_data = self.translate_batch(batch, data.src_vocabs, attn_debug)
        translations = xlation_builder.from_batch(batch_data)

        for i, trans in enumerate(translations):
            all_scores += [trans.pred_scores[:self.n_best]]
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if tgt is not None:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent) + 1

            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:self.n_best]]
            all_predictions += [n_best_preds]
            self.out_file.write('\n'.join(n_best_preds) + '\n')
            self.out_file.flush()

            if attn_vis:
                src_tokens = trans.src_raw
                # Build a new list instead of appending, so that
                # trans.pred_sents (used by trans.log() below) is unchanged.
                pred_tokens = trans.pred_sents[0] + ["'</s>'"]
                attns = trans.attns[0]
                attns = attns[:len(pred_tokens), :len(src_tokens)].data.cpu()

                # plots attention over inputs
                seaborn.heatmap(attns,
                                xticklabels=src_tokens,
                                square=True,
                                yticklabels=pred_tokens,
                                vmin=0.0,
                                vmax=1.0,
                                cbar=False)
                plt.show()

                # NOTE(review): `i` is the translation's position after
                # TranslationBuilder.from_batch; confirm it matches the row
                # order of the encoder's cached self-attention.
                for layer in range(0, 6, 2):
                    fig, axs = plt.subplots(1, 4, figsize=(12, 6))
                    for h in range(4):
                        # Not named `data`: that would overwrite the Dataset
                        # whose src_vocabs the next batch needs.
                        head_attn = list(self.model.encoder.children(
                        ))[1][layer].self_attn.attn.data.cpu()
                        head_attn = head_attn[i, h]
                        head_attn = head_attn[:len(src_tokens),
                                              :len(src_tokens)]
                        draw(head_attn,
                             src_tokens,
                             src_tokens if h == 0 else [],
                             ax=axs[h],
                             title="Head {}".format(h))
                    plt.show()

            if self.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                if self.logger:
                    self.logger.info(output)
                else:
                    os.write(1, output.encode('utf-8'))

            if attn_debug:
                preds = trans.pred_sents[0]
                preds.append('</s>')
                attns = trans.attns[0].tolist()
                if self.data_type == 'text':
                    srcs = trans.src_raw
                else:
                    srcs = [str(item) for item in range(len(attns[0]))]
                header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                output = header_format.format("", *srcs) + '\n'
                for word, row in zip(preds, attns):
                    max_index = row.index(max(row))
                    row_format = row_format.replace(
                        "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                    row_format = row_format.replace(
                        "{:*>10.7f} ", "{:>10.7f} ", max_index)
                    output += row_format.format(word, *row) + '\n'
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                os.write(1, output.encode('utf-8'))

    end_time = time.time()

    if self.report_score:
        msg = self._report_score('PRED', pred_score_total, pred_words_total)
        self._log(msg)
        if tgt is not None:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            self._log(msg)
            if self.report_bleu:
                msg = self._report_bleu(tgt)
                self._log(msg)
            if self.report_rouge:
                msg = self._report_rouge(tgt)
                self._log(msg)

    if self.report_time:
        total_time = end_time - start_time
        self._log("Total translation time (s): %f" % total_time)
        # Guard the divisions: every example may have been filtered out,
        # and a very short run can measure 0 seconds.
        if all_predictions:
            self._log("Average translation time (s): %f" %
                      (total_time / len(all_predictions)))
        if total_time > 0:
            self._log("Tokens per second: %f" %
                      (pred_words_total / total_time))

    if self.dump_beam:
        import json
        # NOTE(review): `self.translator` is not defined in this view;
        # confirm it exists (vs. `self.beam_accum`).
        beam_accum = self.translator.beam_accum
        with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
            json.dump(beam_accum, beam_file)
    return all_scores, all_predictions