Ejemplo n.º 1
0
    def translate(
            self,
            src,
            tgt=None,
            src_dir=None,
            batch_size=None,
            attn_debug=False,
            phrase_table=""):
        """Translate content of ``src`` and get gold scores from ``tgt``.

        The n-best predictions of every translated example are written to
        ``self.out_file`` (one prediction per line) and flushed as soon as
        they are produced.

        Args:
            src: See :func:`self.src_reader.read()`.
            tgt: See :func:`self.tgt_reader.read()`.
            src_dir: Not used by this method (the ``dirs=`` argument of the
                dataset is disabled here); kept so callers need not change.
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging
            phrase_table (str): Not read by this method;
                ``self.phrase_table`` is what is handed to the
                translation builder.

        Returns:
            (`list`, `list`)

            * all_scores holds, per translated example, a list of up to
                `n_best` scores
            * all_predictions holds, per translated example, a list of up
                to `n_best` space-joined predictions

        Raises:
            ValueError: if ``batch_size`` is ``None``.
        """

        if batch_size is None:
            raise ValueError("batch_size must be set")

        def _emit(text):
            # Debug text goes to the logger when one is configured,
            # otherwise straight to file descriptor 1 (stdout).
            if self.logger:
                self.logger.info(text)
            else:
                os.write(1, text.encode('utf-8'))

        # NOTE: ``src_dir`` is not forwarded; no ``dirs=`` argument is
        # passed to the dataset in this variant.
        data = inputters.Dataset(
            self.fields,
            readers=([self.src_reader, self.tgt_reader]
                     if tgt else [self.src_reader]),
            data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred
        )

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False
        )

        xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt,
            self.phrase_table
        )

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        for batch in data_iter:
            batch_data = self.translate_batch(
                batch, data.src_vocabs, attn_debug
            )
            translations = xlation_builder.from_batch(batch_data)

            for trans in translations:
                all_scores.append(trans.pred_scores[:self.n_best])
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [" ".join(pred)
                                for pred in trans.pred_sents[:self.n_best]]
                all_predictions.append(n_best_preds)
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    _emit(trans.log(next(counter)))

                if attn_debug:
                    # NOTE: this appends to ``trans.pred_sents[0]`` in
                    # place; it happens after the predictions were joined
                    # and written above.
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    base_row_format = ("{:>10.10} "
                                       + "{:>10.7f} " * len(srcs))
                    lines = [header_format.format("", *srcs)]
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        # Star-pad the first ``max_index + 1`` cells, then
                        # undo the first ``max_index`` of them, so only the
                        # cell holding the row maximum stays starred.
                        row_format = base_row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        lines.append(row_format.format(word, *row))
                    _emit('\n'.join(lines) + '\n')

        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                self._log(msg)
                if self.report_bleu:
                    msg = self._report_bleu(tgt)
                    self._log(msg)
                if self.report_rouge:
                    msg = self._report_rouge(tgt)
                    self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            num_predictions = len(all_predictions)
            # Report 0.0 instead of raising ZeroDivisionError when nothing
            # was translated or no measurable time elapsed.
            avg_time = (total_time / num_predictions
                        if num_predictions else 0.0)
            tokens_per_sec = (pred_words_total / total_time
                              if total_time > 0 else 0.0)
            self._log("Total translation time (s): %f" % total_time)
            self._log("Average translation time (s): %f" % avg_time)
            self._log("Tokens per second: %f" % tokens_per_sec)

        if self.dump_beam:
            import json
            # Context manager guarantees the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        return all_scores, all_predictions
Ejemplo n.º 2
0
    def translate(
        self,
        src_path=None,
        src_data_iter=None,
        tgt_path=None,
        tgt_data_iter=None,
        src_dir=None,
        batch_size=None,
        attn_debug=False,
    ):
        """
        Translate the source given by `src_data_iter` or `src_path` and
        collect the n-best scores and predictions of every example.

        At least one of `src_path` / `src_data_iter` must be provided, and
        `batch_size` must not be None. Gold-score totals are accumulated
        only when `tgt_path` is not None.

        Args:
            src_path (str): filepath of source data
            src_data_iter (iterator): an iterator generating source data
                e.g. it may be a list or an opened file
            tgt_path (str): filepath of target data
            tgt_data_iter (iterator): an iterator generating target data
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): when set, an attention table for the best
                prediction of each example is written to stdout

        Returns:
            (`list`, `list`)

            * all_scores: per translated example, a list of up to
                `n_best` scores
            * all_predictions: per translated example, a list of up to
                `n_best` space-joined predictions

        Raises:
            ValueError: if `batch_size` is None.
        """
        assert src_data_iter is not None or src_path is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        dataset = inputters.build_dataset(
            self.fields,
            self.data_type,
            src_path=src_path,
            src_data_iter=src_data_iter,
            tgt_path=tgt_path,
            tgt_data_iter=tgt_data_iter,
            src_dir=src_dir,
            sample_rate=self.sample_rate,
            window_size=self.window_size,
            window_stride=self.window_stride,
            window=self.window,
            use_filter_pred=self.use_filter_pred,
            image_channel_size=self.image_channel_size,
        )

        device_name = "cuda" if self.cuda else "cpu"

        batches = inputters.OrderedIterator(
            dataset=dataset,
            device=device_name,
            batch_size=batch_size,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False,
        )

        translation_builder = onmt.translate.TranslationBuilder(
            dataset, self.fields, self.n_best, self.replace_unk, tgt_path)

        # Running totals; they are accumulated below but this method never
        # reports them. ``counter`` is likewise created but not consumed.
        counter = count(1)
        pred_score_total = 0
        pred_words_total = 0
        gold_score_total = 0
        gold_words_total = 0

        all_scores = []
        all_predictions = []

        plain_cell = "{:>10.7f} "
        starred_cell = "{:*>10.7f} "

        for batch in batches:
            decoded = self.translate_batch(batch, dataset, fast=self.fast)

            for trans in translation_builder.from_batch(decoded):
                all_scores.append(trans.pred_scores[:self.n_best])
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                all_predictions.append(
                    [" ".join(tokens)
                     for tokens in trans.pred_sents[:self.n_best]])

                if not attn_debug:
                    continue

                # Attention table: one header row of source labels, then one
                # row per predicted token (plus the end marker).
                best_tokens = trans.pred_sents[0]
                best_tokens.append("</s>")
                attn_rows = trans.attns[0].tolist()
                if self.data_type == "text":
                    col_labels = trans.src_raw
                else:
                    col_labels = [str(idx)
                                  for idx in range(len(attn_rows[0]))]
                n_cols = len(col_labels)
                header_format = "{:>10.10} " + "{:>10.7} " * n_cols
                table = [header_format.format("", *col_labels)]
                for token, weights in zip(best_tokens, attn_rows):
                    peak = weights.index(max(weights))
                    # Star the first ``peak + 1`` cells, then revert the
                    # first ``peak``: only the peak cell remains starred.
                    row_format = "{:>10.10} " + plain_cell * n_cols
                    row_format = row_format.replace(
                        plain_cell, starred_cell, peak + 1)
                    row_format = row_format.replace(
                        starred_cell, plain_cell, peak)
                    table.append(row_format.format(token, *weights))
                report = "\n".join(table) + "\n"
                os.write(1, report.encode("utf-8"))

        return all_scores, all_predictions
    def translate(self,
                  src,
                  src_pos,
                  tgt=None,
                  src_dir=None,
                  batch_size=None,
                  batch_type="sents",
                  attn_debug=False,
                  align_debug=False,
                  phrase_table=""):
        """Translate content of ``src`` and get gold scores from ``tgt``.

        The n-best predictions of every translated example (optionally
        suffixed with `` ||| `` and their alignments when
        ``self.report_align`` is set) are written to ``self.out_file`` and
        flushed as soon as they are produced.

        Args:
            src: See :func:`self.src_reader.read()`.
            src_pos: Data for the ``src_pos`` side; it is read with
                ``self.src_pos_reader``.
            tgt: See :func:`self.tgt_reader.read()`.
            src_dir: See :func:`self.src_reader.read()` (only relevant
                for certain types of data).
            batch_size (int): size of examples per mini-batch
            batch_type (str): when equal to ``"tokens"``, ``max_tok_len``
                is used as the iterator's ``batch_size_fn``; any other
                value leaves ``batch_size_fn`` as ``None``.
            attn_debug (bool): enables the attention logging
            align_debug (bool): enables the word alignment logging
            phrase_table (str): Not read by this method;
                ``self.phrase_table`` is what is handed to the
                translation builder.

        Returns:
            (`list`, `list`)

            * all_scores holds, per translated example, a list of up to
                `n_best` scores
            * all_predictions holds, per translated example, a list of up
                to `n_best` predictions

        Raises:
            ValueError: if ``batch_size`` is ``None``.
        """

        if batch_size is None:
            raise ValueError("batch_size must be set")

        def _emit(text):
            # Debug text goes to the logger when one is configured,
            # otherwise straight to file descriptor 1 (stdout).
            if self.logger:
                self.logger.info(text)
            else:
                os.write(1, text.encode('utf-8'))

        src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
        # src_dir here is useless.
        src_pos_data = {
            "reader": self.src_pos_reader,
            "data": src_pos,
            "dir": src_dir
        }
        tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
        _readers, _data, _dir = inputters.Dataset.config([('src', src_data),
                                                          ('src_pos',
                                                           src_pos_data),
                                                          ('tgt', tgt_data)])

        data = inputters.Dataset(
            self.fields,
            readers=_readers,
            data=_data,
            dirs=_dir,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)

        xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt,
            self.phrase_table)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data.src_vocabs,
                                              attn_debug)
            translations = xlation_builder.from_batch(batch_data)

            for trans in translations:
                all_scores.append(trans.pred_scores[:self.n_best])
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                if self.report_align:
                    # Append each prediction's alignment after " ||| ".
                    align_pharaohs = [
                        build_align_pharaoh(align)
                        for align in trans.word_aligns[:self.n_best]
                    ]
                    n_best_preds_align = [
                        " ".join(align) for align in align_pharaohs
                    ]
                    n_best_preds = [
                        pred + " ||| " + align for pred, align in zip(
                            n_best_preds, n_best_preds_align)
                    ]
                all_predictions.append(n_best_preds)
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    _emit(trans.log(next(counter)))

                if attn_debug:
                    # NOTE: this appends to ``trans.pred_sents[0]`` in
                    # place, so the ``align_debug`` branch below sees the
                    # extra '</s>' when it falls back to the predictions.
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    _emit(report_matrix(srcs, preds, attns))

                if align_debug:
                    if trans.gold_sent is not None:
                        tgts = trans.gold_sent
                    else:
                        tgts = trans.pred_sents[0]
                    align = trans.word_aligns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(align[0]))]
                    _emit(report_matrix(srcs, tgts, align))

        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                self._log(msg)
                if self.report_bleu:
                    msg = self._report_bleu(tgt)
                    self._log(msg)
                if self.report_rouge:
                    msg = self._report_rouge(tgt)
                    self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            num_predictions = len(all_predictions)
            # Report 0.0 instead of raising ZeroDivisionError when nothing
            # was translated or no measurable time elapsed.
            avg_time = (total_time / num_predictions
                        if num_predictions else 0.0)
            tokens_per_sec = (pred_words_total / total_time
                              if total_time > 0 else 0.0)
            self._log("Total translation time (s): %f" % total_time)
            self._log("Average translation time (s): %f" % avg_time)
            self._log("Tokens per second: %f" % tokens_per_sec)

        if self.dump_beam:
            import json
            # Context manager guarantees the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        return all_scores, all_predictions
Ejemplo n.º 4
0
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False):
        """
        Translate content of `src_data_iter` (if not None) or `src_path`.

        The n-best predictions of every translated example are written to
        ``self.out_file`` and flushed as soon as they are produced.

        Note: batch_size must not be None
        Note: one of ('src_path', 'src_data_iter') must not be None
        Note: gold statistics and the BLEU/ROUGE reports are gated on
            `tgt_path` only; passing just `tgt_data_iter` does not
            enable them in this method.

        Args:
            src_path (str): filepath of source data
            src_data_iter (iterator): an iterator generating source data
                e.g. it may be a list or an opened file
            tgt_path (str): filepath of target data
            tgt_data_iter (iterator): an iterator generating target data
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging

        Returns:
            (`list`, `list`)

            * all_scores holds, per translated example, a list of up to
                `n_best` scores
            * all_predictions holds, per translated example, a list of up
                to `n_best` space-joined predictions

        Raises:
            ValueError: if `batch_size` is None.
        """
        assert src_data_iter is not None or src_path is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        def _report(msg):
            # Score reports go to the logger when set, else to ``print``.
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

        data = inputters.build_dataset(self.fields,
                                       self.data_type,
                                       src_path=src_path,
                                       src_data_iter=src_data_iter,
                                       tgt_path=tgt_path,
                                       tgt_data_iter=tgt_data_iter,
                                       src_dir=src_dir,
                                       sample_rate=self.sample_rate,
                                       window_size=self.window_size,
                                       window_stride=self.window_stride,
                                       window=self.window,
                                       use_filter_pred=self.use_filter_pred)

        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                    self.n_best,
                                                    self.replace_unk, tgt_path)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data, fast=self.fast)
            translations = builder.from_batch(batch_data)

            for trans in translations:
                all_scores.append(trans.pred_scores[:self.n_best])
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                all_predictions.append(n_best_preds)
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                # Debug attention.
                if attn_debug:
                    srcs = trans.src_raw
                    # NOTE: this appends to ``trans.pred_sents[0]`` in
                    # place; it happens after the predictions were joined
                    # and written above.
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    base_row_format = ("{:>10.10} "
                                       + "{:>10.7f} " * len(srcs))
                    lines = [header_format.format("", *srcs)]
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        # Star-pad the first ``max_index + 1`` cells, then
                        # undo the first ``max_index`` of them, so only the
                        # cell holding the row maximum stays starred.
                        row_format = base_row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        lines.append(row_format.format(word, *row))
                    output = '\n'.join(lines) + '\n'
                    # The attention table always goes to stdout (fd 1).
                    os.write(1, output.encode('utf-8'))

        if self.report_score:
            _report(self._report_score('PRED', pred_score_total,
                                       pred_words_total))
            if tgt_path is not None:
                _report(self._report_score('GOLD', gold_score_total,
                                           gold_words_total))
                if self.report_bleu:
                    _report(self._report_bleu(tgt_path))
                if self.report_rouge:
                    _report(self._report_rouge(tgt_path))

        if self.dump_beam:
            import json
            # Context manager guarantees the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        return all_scores, all_predictions
Ejemplo n.º 5
0
    def translate(self,
                  src,
                  tgt=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False,
                  phrase_table="",
                  opt=None):
        """Translate content of ``src`` and get gold scores from ``tgt``.

        Results are written to ``self.out_file`` as they are produced: one
        JSON line per example when ``self.data_type`` is ``"keyphrase"``,
        otherwise the n-best predictions, one per line. A progress line is
        printed for every batch.

        Args:
            src: See :func:`self.src_reader.read()`. ``len(src)`` is used
                for the progress message.
            tgt: See :func:`self.tgt_reader.read()`.
            src_dir: See :func:`self.src_reader.read()` (only relevant
                for certain types of data).
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging
            phrase_table (str): Not read by this method;
                ``self.phrase_table`` is what is handed to the
                translation builder.
            opt: Not read by this method (only referenced by the disabled
                keyphrase evaluation report).

        Returns:
            (`list`, `list`)

            * all_scores holds, per translated example, a list of up to
                `n_best` scores
            * all_predictions holds, per translated example, a list of up
                to `n_best` space-joined predictions

        Raises:
            ValueError: if ``batch_size`` is ``None``.
        """

        if batch_size is None:
            raise ValueError("batch_size must be set")

        def _emit(text):
            # Debug text goes to the logger when one is configured,
            # otherwise straight to file descriptor 1 (stdout).
            if self.logger:
                self.logger.info(text)
            else:
                os.write(1, text.encode('utf-8'))

        is_keyphrase = self.data_type == "keyphrase"

        # modified by @memray to accommodate keyphrase
        data = inputters.str2dataset[self.data_type](
            self.fields,
            readers=([self.src_reader, self.tgt_reader]
                     if tgt else [self.src_reader]),
            data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
            dirs=[src_dir, None] if tgt else [src_dir],
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)
        # @memray, as Dataset is only instantiated here, having to use this
        # plugin setter
        if isinstance(data, KeyphraseDataset):
            data.tgt_type = self.tgt_type

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            train=False,
            sort=False,
            sort_within_batch=False,  # @memray: to keep the original order
            shuffle=False)

        xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt,
            self.phrase_table)
        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        num_examples = 0
        for batch in data_iter:
            # Progress counts whole batches, so the last value printed can
            # exceed len(src) when the final batch is not full.
            num_examples += batch_size
            print("Translating %d/%d" % (num_examples, len(src)))

            batch_data = self.translate_batch(batch, data.src_vocabs,
                                              attn_debug)
            translations = xlation_builder.from_batch(batch_data)

            # @memray
            if is_keyphrase:
                # post-process for one2seq outputs, split seq into
                # individual phrases
                if self.model_tgt_type != 'one2one':
                    translations = self.segment_one2seq_trans(translations)
                # add statistics of kps(pred_num, beamstep_num etc.)
                translations = self.add_trans_stats(translations,
                                                    self.model_tgt_type)

                # add copied flag
                vocab_size = len(self.fields['src'].base_field.vocab.itos)
                for t in translations:
                    t.add_copied_flags(vocab_size)

            for tran in translations:
                all_scores.append(tran.pred_scores[:self.n_best])
                pred_score_total += tran.pred_scores[0]
                pred_words_total += len(tran.pred_sents[0])
                if tgt is not None:
                    gold_score_total += tran.gold_score
                    gold_words_total += len(tran.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in tran.pred_sents[:self.n_best]
                ]
                all_predictions.append(n_best_preds)
                if is_keyphrase:
                    self.out_file.write(json.dumps(tran.__dict__()) + '\n')
                else:
                    self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    if is_keyphrase:
                        output = tran.log_kp(sent_number)
                    else:
                        output = tran.log(sent_number)
                    _emit(output)

                if attn_debug:
                    # NOTE: this appends to ``tran.pred_sents[0]`` in
                    # place; it happens after the predictions were joined
                    # and written above.
                    preds = tran.pred_sents[0]
                    preds.append('</s>')
                    attns = tran.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = tran.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    base_row_format = ("{:>10.10} "
                                       + "{:>10.7f} " * len(srcs))
                    lines = [header_format.format("", *srcs)]
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        # Star-pad the first ``max_index + 1`` cells, then
                        # undo the first ``max_index`` of them, so only the
                        # cell holding the row maximum stays starred.
                        row_format = base_row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        lines.append(row_format.format(word, *row))
                    _emit('\n'.join(lines) + '\n')

        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                self._log(msg)
                if self.report_bleu:
                    msg = self._report_bleu(tgt)
                    self._log(msg)
                if self.report_rouge:
                    msg = self._report_rouge(tgt)
                    self._log(msg)
                if self.report_kpeval:
                    # don't run eval here. because in opt.tgt rare words are
                    # replaced by <unk>
                    pass

        if self.report_time:
            total_time = end_time - start_time
            num_predictions = len(all_predictions)
            # Report 0.0 instead of raising ZeroDivisionError when nothing
            # was translated or no measurable time elapsed.
            avg_time = (total_time / num_predictions
                        if num_predictions else 0.0)
            tokens_per_sec = (pred_words_total / total_time
                              if total_time > 0 else 0.0)
            self._log("Total translation time (s): %f" % total_time)
            self._log("Average translation time (s): %f" % avg_time)
            self._log("Tokens per second: %f" % tokens_per_sec)

        if self.dump_beam:
            # Context manager guarantees the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)

        return all_scores, all_predictions
Ejemplo n.º 6
0
    def translate(self,
                  src,
                  tgt=None,
                  src_dir=None,
                  batch_size=None,
                  batch_type="sents",
                  attn_debug=False,
                  align_debug=False,
                  phrase_table="",
                  partial=None,
                  partialfcheck=True,
                  dymax_len=None,
                  n_best=None):
        """Translate content of ``src`` and get gold scores from ``tgt``.

        Args:
            src: See :func:`self.src_reader.read()`.
            tgt: See :func:`self.tgt_reader.read()`.
            src_dir: See :func:`self.src_reader.read()` (only relevant
                for certain types of data).
            batch_size (int): size of examples per mini-batch
            batch_type (str): when ``"tokens"``, ``max_tok_len`` is used as
                the iterator's batch size function; otherwise none is used.
            attn_debug (bool): enables the attention logging
            align_debug (bool): enables the word alignment logging
            phrase_table: not read by this method (``self.phrase_table`` is
                what gets passed to the translation builder).
            partial (str): whitespace-separated target-side text. Its words
                are converted to target vocabulary indices and stored in
                ``self.partial``.
            partialfcheck (bool): when true, the last word of ``partial`` is
                treated as a prefix: it is left out of ``self.partial`` and
                ``self.partialf`` receives the non-zero indices of all
                vocabulary entries starting with it.
            dymax_len: stored unchanged in ``self.dymax_len``.
            n_best (int): overrides ``self.n_best`` when truthy and at most 5.

        Returns:
            ``(all_scores, all_predictions, pred_score_total,
            pred_words_total)``, or, when ``attn_debug`` is set,
            ``(all_scores, all_predictions, attns, pred_score_total,
            pred_words_total)``.

            * all_scores is a list with one list of up to `n_best` scores
                per translated sentence
            * all_predictions is a list with one list of up to `n_best`
                predictions per translated sentence
            * attns is the attention matrix (nested lists) of the best
                hypothesis of the last sentence processed, or ``None`` if
                nothing was translated

        Raises:
            ValueError: if ``batch_size`` is ``None``.
        """
        # NOTE: the instance state below is updated before batch_size is
        # validated, exactly as callers of the previous version observed.
        self.dymax_len = dymax_len
        if n_best and n_best <= 5:
            self.n_best = n_best
        self.partialf = None
        self.partialfcheck = partialfcheck

        if partial:
            partials = partial.split()
            print(partials, '~~~~partials~~~')
            vocabdict = dict(self.fields)["tgt"].base_field.vocab
            # A whitespace-only ``partial`` yields no words; there is no
            # trailing prefix to expand in that case, so fall through to the
            # plain lookup instead of indexing an empty list.
            if partialfcheck and partials:
                self.partial = [vocabdict.stoi[x] for x in partials[:-1]]
                print("#########vocabdict.stoi########")
                print(self.partial)
                print("##################################")

                last_prefix = partials[-1]
                # ``and v`` drops index 0 (presumably the unknown-word
                # entry -- TODO confirm against the vocabulary layout).
                self.partialf = [
                    v for k, v in vocabdict.stoi.items()
                    if k.startswith(last_prefix) and v
                ]
            else:
                self.partial = [vocabdict.stoi[x] for x in partials]
        else:
            self.partial = None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
        tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
        _readers, _data, _dir = inputters.Dataset.config([('src', src_data),
                                                          ('tgt', tgt_data)])

        # corpus_id field is useless here
        if self.fields.get("corpus_id", None) is not None:
            self.fields.pop('corpus_id')
        data = inputters.Dataset(
            self.fields,
            readers=_readers,
            data=_data,
            dirs=_dir,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)

        xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt,
            self.phrase_table)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []  # per sentence: scores of its n-best hypotheses
        all_predictions = []
        # Defined up front so the attn_debug return below cannot hit an
        # unbound name when the iterator yields no translations.
        attns = None

        start_time = time.time()

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data.src_vocabs,
                                              attn_debug)
            translations = xlation_builder.from_batch(batch_data)

            for trans in translations:
                print("Loop")
                print(trans, trans.pred_sents)
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]

                print("############n_best_preds###############")
                print(n_best_preds)
                print("############n_best_preds###############")

                if self.report_align:
                    # Append the alignment string to each prediction,
                    # separated by " ||| ".
                    align_pharaohs = [
                        build_align_pharaoh(align)
                        for align in trans.word_aligns[:self.n_best]
                    ]
                    n_best_preds_align = [
                        " ".join(align) for align in align_pharaohs
                    ]
                    n_best_preds = [
                        pred + " ||| " + align for pred, align in zip(
                            n_best_preds, n_best_preds_align)
                    ]
                all_predictions += [n_best_preds]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if attn_debug:
                    # Attention report for the first (best) hypothesis only.
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    output = report_matrix(srcs, preds, attns)

                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if align_debug:
                    if trans.gold_sent is not None:
                        tgts = trans.gold_sent
                    else:
                        tgts = trans.pred_sents[0]
                    align = trans.word_aligns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(align[0]))]
                    output = report_matrix(srcs, tgts, align)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            # Skip the ratios that would divide by zero (no sentences
            # translated, or a clock too coarse to measure the run).
            if all_predictions:
                self._log("Average translation time (s): %f" %
                          (total_time / len(all_predictions)))
            if total_time > 0:
                self._log("Tokens per second: %f" %
                          (pred_words_total / total_time))

        if self.dump_beam:
            import json
            # Context manager so the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)

        if attn_debug:
            return (all_scores, all_predictions, attns, pred_score_total,
                    pred_words_total)
        return all_scores, all_predictions, pred_score_total, pred_words_total
    def translate(self,
                  src,
                  tgt=None,
                  agenda=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False,
                  tag_shard=None):
        """Translate content of ``src`` and get gold scores from ``tgt``.

        Args:
            src: See :func:`self.src_reader.read()`.
            tgt: See :func:`self.tgt_reader.read()`.
            agenda: optional extra input read with ``self.agenda_reader``
                and added to the dataset under the name ``"agenda"``; when
                given and scores are reported, the agenda accuracy is
                logged as well.
            src_dir: See :func:`self.src_reader.read()` (only relevant
                for certain types of data).
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging
            tag_shard: optional sequence indexed by example index; each
                entry is a per-token sequence of numeric tags. The tags of
                a batch are padded into one float tensor and handed to
                ``translate_batch`` as ``tags``.

        Returns:
            (`list`, `list`)

            * all_scores is a list with one list of up to `n_best` scores
                per translated sentence
            * all_predictions is a list with one list of up to `n_best`
                predictions per translated sentence

        Raises:
            ValueError: if ``batch_size`` is ``None``, or if the padded tags
                of a batch do not match the source length of that batch.
        """

        if batch_size is None:
            raise ValueError("batch_size must be set")

        # One flag decides both whether a target reader is configured and
        # whether gold statistics are collected, so the two cannot disagree
        # for falsy-but-not-None targets.
        has_tgt = bool(tgt)

        readers, data_specs, dirs = [], [], []
        if self.src_reader:
            readers.append(self.src_reader)
            data_specs.append(("src", src))
            dirs.append(src_dir)
        if has_tgt:
            readers.append(self.tgt_reader)
            data_specs.append(("tgt", tgt))
            dirs.append(None)
        if agenda:
            readers.append(self.agenda_reader)
            data_specs.append(("agenda", agenda))
            dirs.append(None)

        data = inputters.Dataset(
            self.fields,
            readers=readers,
            data=data_specs,
            dirs=dirs,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=self._dev,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        for batch in data_iter:
            if tag_shard is not None:
                batch_tags = [tag_shard[i] for i in batch.indices]
                max_len = max(len(tags) for tags in batch_tags)
                cur_tags = [
                    torch.tensor(tags,
                                 device=batch.src[0].device,
                                 dtype=torch.float) for tags in batch_tags
                ]
                cur_tags_padded = torch.nn.utils.rnn.pad_sequence(
                    cur_tags, padding_value=0)
                # batch.src[1] presumably holds the per-example source
                # lengths -- TODO confirm. The padded tags must be exactly
                # as long as the padded source.
                longest_src = batch.src[1].max().item()
                if (batch.src[0].shape[0] != longest_src
                        or cur_tags_padded.shape[0] != max_len
                        or max_len != longest_src):
                    raise ValueError(
                        "tag_shard does not match the source batch: "
                        "src shape {}, max src length {}, "
                        "padded tag shape {}, max tag length {}".format(
                            tuple(batch.src[0].shape), longest_src,
                            tuple(cur_tags_padded.shape), max_len))
            else:
                cur_tags_padded = None

            batch_data = self.translate_batch(batch,
                                              data.src_vocabs,
                                              attn_debug,
                                              tags=cur_tags_padded)
            translations = xlation_builder.from_batch(batch_data)

            for trans in translations:
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if has_tgt:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                all_predictions += [n_best_preds]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if attn_debug:
                    # Print the attention matrix of the best hypothesis:
                    # one row per predicted word, one column per source
                    # item, with the row maximum padded with '*'.
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        # Star the first max_index + 1 cells, then un-star
                        # the first max_index of them: only the cell at
                        # max_index stays starred.
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if has_tgt:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                self._log(msg)
                if self.report_bleu:
                    msg = self._report_bleu(tgt)
                    self._log(msg)
                if self.report_rouge:
                    msg = self._report_rouge(tgt)
                    self._log(msg)
            if agenda:
                msg = self._report_agenda_accuray(agenda)
                self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            # Skip the ratios that would divide by zero (no sentences
            # translated, or a clock too coarse to measure the run).
            if all_predictions:
                self._log("Average translation time (s): %f" %
                          (total_time / len(all_predictions)))
            if total_time > 0:
                self._log("Tokens per second: %f" %
                          (pred_words_total / total_time))

        if self.dump_beam:
            import json
            # Context manager so the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        return all_scores, all_predictions
Ejemplo n.º 8
0
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False,
                  node_type_seq=None,
                  atc=None):
        """
        Translate content of `src_data_iter` (if not None) or `src_path`
        once per candidate node-type sequence and collect the predictions.

        Note: batch_size must not be None
        Note: one of ('src_path', 'src_data_iter') must not be None
        Note: every batch must hold exactly one example, because the
            example index is read with ``batch.indices.item()``.

        Args:
            src_path (str): filepath of source data
            src_data_iter (iterator): an iterator generating source data
                e.g. it may be a list or an opened file
            tgt_path (str): filepath of target data
            tgt_data_iter (iterator): an iterator generating target data
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): accepted for interface compatibility; not
                read by this method
            node_type_seq: indexable pair ``(sequences, scores)``. Both
                members are indexed by example index; each entry holds the
                candidate node-type sequences of that example and their
                scores. Only the first ``self.option.tree_count`` candidates
                are decoded.
            atc: optional container indexed by example index whose entry is
                forwarded to ``translate_batch`` as ``atc``.

        Returns:
            (`list`, `list`)

            * all_scores has one list per example, holding the scores of
                up to `n_best` predictions for every decoded node-type
                sequence, each increased by that sequence's score
            * all_predictions has one list per example with the matching
                prediction strings

        Raises:
            ValueError: if ``batch_size`` is ``None``.
        """
        assert src_data_iter is not None or src_path is not None
        assert node_type_seq is not None, 'Node Types must be provided'
        node_type_scores = node_type_seq[1]
        node_type_seq = node_type_seq[0]
        if batch_size is None:
            raise ValueError("batch_size must be set")
        data = inputters.build_dataset(self.fields,
                                       self.data_type,
                                       src_path=src_path,
                                       src_data_iter=src_data_iter,
                                       tgt_path=tgt_path,
                                       tgt_data_iter=tgt_data_iter,
                                       src_dir=src_dir,
                                       sample_rate=self.sample_rate,
                                       window_size=self.window_size,
                                       window_stride=self.window_stride,
                                       window=self.window,
                                       use_filter_pred=self.use_filter_pred)

        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                    self.n_best,
                                                    self.replace_unk, tgt_path)

        # Statistics
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        def check_correctness(preds, gold):
            """Return 1 if any prediction equals ``gold`` (ignoring
            surrounding whitespace), else 0."""
            target = gold.strip()
            return int(any(p.strip() == target for p in preds))

        total_correct = 0

        for bidx, batch in enumerate(data_iter):
            # Only 1 item in this batch, guaranteed
            example_idx = batch.indices.item()
            if bidx % 50 == 0:
                debug('Current Example : ', example_idx)
            nt_sequences = node_type_seq[example_idx]
            nt_scores = node_type_scores[example_idx]
            atc_item = atc[example_idx] if atc is not None else None
            scores = []
            predictions = []
            tree_count = self.option.tree_count
            for type_sequence, type_score in zip(nt_sequences[:tree_count],
                                                 nt_scores[:tree_count]):
                batch_data = self.translate_batch(batch,
                                                  data,
                                                  node_type_str=type_sequence,
                                                  atc=atc_item)
                translations = builder.from_batch(batch_data)
                # NOTE(review): this flag is reset for every node-type
                # sequence and set after the first translation whether or
                # not it was correct, so total_correct can grow by more
                # than one per example. Kept as is -- confirm the intent.
                already_found = False
                for trans in translations:
                    pred_scores = [
                        score + type_score
                        for score in trans.pred_scores[:self.n_best]
                    ]
                    scores += pred_scores
                    pred_score_total += trans.pred_scores[0]
                    pred_words_total += len(trans.pred_sents[0])
                    if tgt_path is not None:
                        gold_score_total += trans.gold_score
                        gold_words_total += len(trans.gold_sent) + 1
                    n_best_preds = [
                        " ".join(pred)
                        for pred in trans.pred_sents[:self.n_best]
                    ]
                    # Without a reference there is nothing to compare
                    # against; joining a missing gold sentence would raise.
                    if trans.gold_sent is not None:
                        gold_sent = ' '.join(trans.gold_sent)
                        correct = check_correctness(n_best_preds, gold_sent)
                    else:
                        correct = 0
                    if not already_found:
                        total_correct += correct
                        already_found = True
                    predictions += n_best_preds
            all_scores += [scores]
            all_predictions += [predictions]

        if self.dump_beam:
            import json
            # Context manager so the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        debug(total_correct)
        return all_scores, all_predictions
Ejemplo n.º 9
0
    def translate(self,
                  src,
                  tgt=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False):
        """
        Translate content of `src` and get gold scores if `tgt` is set.

        When ``self.attn_output`` is not None, an attention heat map is
        additionally saved there for every sentence and n-best index as
        ``<sentence number>_<n_best index>.png``.

        Note: batch_size must not be None
        Note: src must not be None

        Args:
            src: source data, forwarded to ``inputters.build_dataset``
            tgt: target data or None, forwarded to
                ``inputters.build_dataset``
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging

        Returns:
            (`list`, `list`)

            * all_scores is a list with one list of up to `n_best` scores
                per translated sentence
            * all_predictions is a list with one list of up to `n_best`
                predictions per translated sentence

        Raises:
            ValueError: if ``batch_size`` is ``None``.
        """
        assert src is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        def _emit(msg):
            # Report through the logger when there is one, else stdout.
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

        data = inputters.build_dataset(
            self.fields,
            self.data_type,
            src=src,
            tgt=tgt,
            src_dir=src_dir,
            sample_rate=self.sample_rate,
            window_size=self.window_size,
            window_stride=self.window_stride,
            window=self.window,
            use_filter_pred=self.use_filter_pred,
            image_channel_size=self.image_channel_size,
            dynamic_dict=self.copy_attn)

        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                    self.n_best,
                                                    self.replace_unk, tgt)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        for batch_count, batch in enumerate(data_iter):
            batch_data = self.translate_batch(batch,
                                              data,
                                              attn_debug,
                                              fast=self.fast)
            translations = builder.from_batch(batch_data)

            for trans_count, trans in enumerate(translations):
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                all_predictions += [n_best_preds]
                if self.out_file is not None:
                    self.out_file.write('\n'.join(n_best_preds) + '\n')
                    self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if attn_debug:
                    # Print the attention matrix of the best hypothesis:
                    # one row per predicted word, one column per source
                    # item, with the row maximum padded with '*'.
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        # Star the first max_index + 1 cells, then un-star
                        # the first max_index of them: only the cell at
                        # max_index stays starred.
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

                if self.attn_output is not None:
                    fig, ax = plt.subplots(figsize=(20, 20))
                    threshold = self.attn_min_threshold
                    for n_best_index in range(self.n_best):
                        # .cpu() keeps this working when decoding on CUDA;
                        # Tensor.numpy() only accepts CPU tensors.
                        attn_matrix = (
                            trans.attns[n_best_index].data.cpu().numpy().T)
                        row_total = attn_matrix.shape[0]

                        # First and last rows holding a value above the
                        # threshold. Reset for every candidate so a miss
                        # is detected instead of reusing stale indices.
                        start_index = None
                        end_index = None
                        for row_index in range(row_total):
                            if any(value > threshold
                                   for value in attn_matrix[row_index, :]):
                                start_index = row_index
                                break
                        for row_index in reversed(range(row_total)):
                            if any(value > threshold
                                   for value in attn_matrix[row_index, :]):
                                end_index = row_index
                                break

                        if start_index is None or end_index is None:
                            # os.write needs bytes, not str.
                            warning = (
                                "Cannot find attention values more than " +
                                str(threshold) + "\n" +
                                "Please lower attn_min_threshold\n")
                            os.write(1, warning.encode('utf-8'))
                            continue

                        # Rows are trimmed in parallel with trans.src_raw.
                        attn_matrix = attn_matrix[start_index:end_index +
                                                  1, :]
                        src_raw = trans.src_raw[start_index:end_index + 1]
                        if len(src_raw) > self.attn_max_src_length:
                            continue
                        draw(attn_matrix, trans.pred_sents[n_best_index],
                             src_raw, ax)
                        sentence_number = (batch_count * batch_size +
                                           trans_count + 1)
                        plt.savefig(
                            os.path.join(
                                self.attn_output,
                                str(sentence_number) + "_" +
                                str(n_best_index) + ".png"))
                        plt.cla()
                    plt.close(fig)

        if self.report_score:
            _emit(self._report_score('PRED', pred_score_total,
                                     pred_words_total))
            if tgt is not None:
                _emit(self._report_score('GOLD', gold_score_total,
                                         gold_words_total))
                if self.report_bleu:
                    _emit(self._report_bleu(tgt))
                if self.report_rouge:
                    _emit(self._report_rouge(tgt))

        if self.dump_beam:
            import json
            # Context manager so the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        return all_scores, all_predictions
Ejemplo n.º 10
0
def save_hidden_states(opt, args):
    """Dump per-token model states for the train, valid and test splits.

    Loads the checkpoint ``args['onmt']['translate']['model']`` found under
    ``opt.save_model``, runs each split through the model with gradients
    disabled, and writes two files per split into
    ``args['reporter']['results_path']``:

    * ``cache.<split>.npy`` -- the states selected at non-padding target
      positions, stacked with ``np.vstack``.
    * ``cache_idxs.<split>.csv`` -- one ``index,length`` row per example,
      where ``index`` comes from ``batch.indices`` and ``length`` is the
      number of non-padding target positions of that example.

    At most 10001 batches are processed per split.

    Args:
        opt: option namespace passed to ``OnmtArgumentParser``,
            ``build_model`` and the dataset helpers.
        args (dict): nested configuration. Keys read here:
            ``['onmt']['translate']['model']``, ``['data']['test_path']``,
            ``['lm']['use_eos']``, ``['device']`` and
            ``['reporter']['results_path']``.

    Returns:
        None. Results are written to disk.
    """
    OnmtArgumentParser.update_model_opts(opt)
    OnmtArgumentParser.validate_model_opts(opt)

    # load model
    model_path = os.path.join(opt.save_model,
                              args['onmt']['translate']['model'])
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_opt = OnmtArgumentParser.ckpt_model_opts(checkpoint["opt"])
    OnmtArgumentParser.update_model_opts(model_opt)
    OnmtArgumentParser.validate_model_opts(model_opt)
    vocab = checkpoint['vocab']
    model = build_model(model_opt, opt, vocab, checkpoint)

    # The padding index depends only on the vocabulary, so look it up once
    # instead of once per split.
    tgt_field = vocab["tgt"].base_field
    tgt_pad_idx = tgt_field.vocab.stoi[tgt_field.pad_token]

    for split in ('train', 'valid', 'test'):
        # Fresh accumulators for every split.
        cache = []
        # stores index into the data and length of sentence
        cache_idxs = []

        if split == 'test':
            # The test set is a raw file with one "src<TAB>tgt" pair per line.
            test_file = args['data']['test_path']
            # Context manager so the handle is closed deterministically.
            with open(test_file, "r") as test_handle:
                test_data = test_handle.readlines()
            test_src, test_tgt = zip(*[line.split("\t") for line in test_data])

            src_reader = inputters.str2reader["text"].from_opt(opt)
            tgt_reader = inputters.str2reader["text"].from_opt(opt)
            src_data = {"reader": src_reader, "data": test_src, "dir": None}
            tgt_data = {"reader": tgt_reader, "data": test_tgt, "dir": None}
            _readers, _data, _dir = inputters.Dataset.config([
                ('src', src_data), ('tgt', tgt_data)
            ])

            data = inputters.Dataset(vocab,
                                     readers=_readers,
                                     data=_data,
                                     dirs=_dir,
                                     sort_key=inputters.str2sortkey["text"],
                                     filter_pred=None)

            batch_iter = inputters.OrderedIterator(dataset=data,
                                                   device=args['device'],
                                                   batch_size=64,
                                                   batch_size_fn=None,
                                                   train=False,
                                                   sort=False,
                                                   sort_within_batch=True,
                                                   shuffle=False)
        else:
            # train / valid are read from the preprocessed dataset shards.
            train_dataset_paths = get_dataset_paths(opt,
                                                    split,
                                                    eos=args['lm']['use_eos'])

            batch_iter = DatasetLazyIter(
                train_dataset_paths,
                vocab,  # vocab
                64,  # batch size
                None,  # "batch_fn"
                1,  # "batch_size_multiple"
                args['device'],  # device
                True,  # is train
                8192,  # pool factor
                repeat=False,
                num_batches_multiple=1,
                yield_raw_example=False)

        for batch_i, batch in tqdm(enumerate(batch_iter), desc=f"[{split}]"):
            # run through model; hard cap on the number of batches per split
            if batch_i > 10000:
                break
            src, src_lengths = batch.src
            tgt = batch.tgt
            with torch.no_grad():
                hidden_states, attn = model(src,
                                            tgt,
                                            src_lengths,
                                            bptt=False,
                                            with_align=False)

            # save src idxs and hidden states
            # The first target step is dropped before masking out padding.
            # NOTE(review): assumes tgt is (tgt_len, batch, 1) and that
            # hidden_states lines up with tgt[1:] -- confirm against model.
            pad_masks = (tgt[1:] != tgt_pad_idx).squeeze(2)
            cache.extend(hidden_states[pad_masks].cpu().numpy())
            cache_idxs.extend([(
                batch.indices[i].item(),
                pad_masks[:, i].sum().item(),
            ) for i in range(pad_masks.size(1))])

        cache = np.vstack(cache)
        # save the cache and the cache indices
        save_path = args['reporter']['results_path']
        print(save_path)
        print(cache.shape)
        np.save(os.path.join(save_path, f"cache.{split}.npy"), cache)
        with open(os.path.join(save_path, f"cache_idxs.{split}.csv"),
                  "w") as csvfile:
            csvfile.write("\n".join(
                [f"{idx},{length}" for idx, length in cache_idxs]))
Ejemplo n.º 11
0
    def translate(
        self,
        src,
        tgt=None,
        batch_size=None,
        batch_type="sents",
        attn_debug=False,
        align_debug=False,
        phrase_table="",
    ):
        """Translate content of ``src`` and get gold scores from ``tgt``.

        Every n-best list is written to ``self.out_file`` (one hypothesis
        per line, flushed after each example) as well as being returned.

        Args:
            src: See :func:`self.src_reader.read()`.
            tgt: See :func:`self.tgt_reader.read()`.
            batch_size (int): size of examples per mini-batch
            batch_type (str): ``"tokens"`` batches with ``max_tok_len`` as
                the size function; any other value uses no size function.
            attn_debug (bool): enables the attention logging
            align_debug (bool): enables the word alignment logging
            phrase_table (str): accepted for interface compatibility; this
                method passes ``self.phrase_table`` to the builder instead.

        Returns:
            (`list`, `list`)

            * all_scores is a list of `batch_size` lists of `n_best` scores
            * all_predictions is a list of `batch_size` lists
                of `n_best` predictions

        Raises:
            ValueError: if ``batch_size`` is None, or if ``self.tgt_prefix``
                is set while ``tgt`` is None.
        """

        if batch_size is None:
            raise ValueError("batch_size must be set")

        if self.tgt_prefix and tgt is None:
            raise ValueError("Prefix should be feed to tgt if -tgt_prefix.")

        src_data = {"reader": self.src_reader, "data": src}
        tgt_data = {"reader": self.tgt_reader, "data": tgt}
        _readers, _data = inputters.Dataset.config([("src", src_data),
                                                    ("tgt", tgt_data)])

        data = inputters.Dataset(
            self.fields,
            readers=_readers,
            data=_data,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred,
        )

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False,
        )

        xlation_builder = onmt.translate.TranslationBuilder(
            data,
            self.fields,
            self.n_best,
            self.replace_unk,
            tgt,
            self.phrase_table,
        )

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data.src_vocabs,
                                              attn_debug)
            translations = xlation_builder.from_batch(batch_data)

            for trans in translations:
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                if self.report_align:
                    # Append each hypothesis' alignment after the separator.
                    align_pharaohs = [
                        build_align_pharaoh(align)
                        for align in trans.word_aligns[:self.n_best]
                    ]
                    n_best_preds_align = [
                        " ".join(align) for align in align_pharaohs
                    ]
                    n_best_preds = [
                        pred + DefaultTokens.ALIGNMENT_SEPARATOR + align for
                        pred, align in zip(n_best_preds, n_best_preds_align)
                    ]
                all_predictions += [n_best_preds]
                self.out_file.write("\n".join(n_best_preds) + "\n")
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode("utf-8"))

                if attn_debug:
                    # NOTE: this appends to trans.pred_sents[0] in place; the
                    # output line was already written above, but the
                    # align_debug branch below sees the extra EOS token.
                    preds = trans.pred_sents[0]
                    preds.append(DefaultTokens.EOS)
                    attns = trans.attns[0].tolist()
                    if self.data_type == "text":
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    output = report_matrix(srcs, preds, attns)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode("utf-8"))

                if align_debug:
                    tgts = trans.pred_sents[0]
                    align = trans.word_aligns[0].tolist()
                    if self.data_type == "text":
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(align[0]))]
                    output = report_matrix(srcs, tgts, align)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode("utf-8"))

        end_time = time.time()

        if self.report_score:
            msg = self._report_score("PRED", pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score("GOLD", gold_score_total,
                                         gold_words_total)
                self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            # Guard the ratios: an empty input yields no predictions, and a
            # zero elapsed time would otherwise divide by zero.
            if all_predictions:
                self._log("Average translation time (s): %f" %
                          (total_time / len(all_predictions)))
            if total_time > 0:
                self._log("Tokens per second: %f" %
                          (pred_words_total / total_time))

        if self.dump_beam:
            import json

            # Context manager so the dump is flushed and the handle closed.
            with codecs.open(self.dump_beam, "w", "utf-8") as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        return all_scores, all_predictions
Ejemplo n.º 12
0
    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  batch_size=None,
                  attn_debug=False):
        """
        Translate content of `src_data_iter` (if not None) or `src_path`
        and get gold scores if one of `tgt_data_iter` or `tgt_path` is set.

        Besides decoding, every batch is scored for click prediction: the
        20 highest entries of ``batch_data["score_sku_prob"]`` are compared
        with ``batch.tgt_sku`` and Recall@k / MRR@k are logged for
        k in (1, 5, 10, 15, 20). ROUGE-1/2/L are logged when at least one
        gold sentence was collected.

        Note: batch_size must not be None
        Note: one of ('src_path', 'src_data_iter') must not be None

        Args:
            src_path (str): filepath of source data
            src_data_iter (iterator): an iterator generating source data
                e.g. it may be a list or an opened file
            tgt_path (str): filepath of target data
            tgt_data_iter (iterator): an iterator generating target data
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging

        Returns:
            (`list`, `list`)

            * all_scores is a list of `batch_size` lists of `n_best` scores
            * all_predictions is a list of `batch_size` lists
                of `n_best` predictions

        Raises:
            AssertionError: if both `src_path` and `src_data_iter` are None.
            ValueError: if `batch_size` is None.
        """
        assert src_data_iter is not None or src_path is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        def _emit(msg):
            # Single place for the logger-or-stdout fallback, so every
            # report in this method works when no logger is configured.
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)

        data = inputters. \
            build_dataset(self.fields,
                          self.data_type,
                          src_path=src_path,
                          src_data_iter=src_data_iter,
                          tgt_path=tgt_path,
                          tgt_data_iter=tgt_data_iter,
                          sample_rate=self.sample_rate,
                          window_size=self.window_size,
                          window_stride=self.window_stride,
                          window=self.window,
                          use_filter_pred=self.use_filter_pred,
                          topk_keywords=self.topk_keywords)

        if self.cuda:
            cur_device = "cuda"
        else:
            cur_device = "cpu"

        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=cur_device,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        builder = onmt.translate.TranslationBuilder(data, self.fields,
                                                    self.n_best,
                                                    self.replace_unk, tgt_path)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        # Per-cutoff accumulators: one 0/1 hit and one reciprocal rank per
        # example, replacing five hand-copied pairs of lists.
        cutoffs = (1, 5, 10, 15, 20)
        recall_at = {k: [] for k in cutoffs}
        mrr_at = {k: [] for k in cutoffs}

        pred_sents_list = []
        gold_sents_list = []

        for batch in data_iter:
            batch_data = self.translate_batch(
                batch,
                data,
                fast=self.fast,
                skip_keyphrase=self.skip_keyphrase)

            # for click prediction
            scores_sku = batch_data["score_sku_prob"]
            # [b X 20]: indices of the 20 highest scores, best first
            sorted_index_t = np.argsort(-scores_sku.detach())[:, :20].to(
                scores_sku.device)
            batch_data["pred_sku"] = sorted_index_t
            target_sku = batch_data["batch"].tgt_sku[0].squeeze()

            _preds = sorted_index_t.tolist()
            _golds = target_sku.tolist()
            if not isinstance(_golds, list):
                # A batch of one squeezes to a scalar; keep it iterable so
                # it still pairs with the single row of predictions.
                _golds = [_golds]

            for (_pred, _gold) in zip(_preds, _golds):
                for k in cutoffs:
                    top_k = _pred[:k]
                    if _gold in top_k:
                        recall_at[k].append(1)
                        mrr_at[k].append(1.0 / (top_k.index(_gold) + 1))
                    else:
                        recall_at[k].append(0)
                        mrr_at[k].append(0)

            # if batch_data only contains batch data and score_sku_prob
            # (plus the "pred_sku" key added above, hence 3), skip the
            # translation of decoder
            if len(batch_data.keys()) == 3:
                continue

            translations = builder.from_batch(batch_data)
            for _tran in translations:
                # Only collect aligned (prediction, gold) pairs for ROUGE.
                # NOTE(review): presumably gold_sent is None when the
                # builder has no targets -- confirm against the builder.
                if _tran.gold_sent is None:
                    continue
                pred_sents_list.append(' '.join(_tran.pred_sents[0]))
                gold_sents_list.append(' '.join(_tran.gold_sent))

            for trans in translations:
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                all_predictions += [n_best_preds]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                # Debug attention.
                if attn_debug:
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        # Star-pad only the cell holding the row maximum:
                        # mark the first max_index + 1 cells, then unmark
                        # the first max_index of them.
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

        # print click results
        for k in cutoffs:
            _emit('Recall@{0}: {1}\t\tMrr@{0}: {2}'.format(
                k,
                np.array(recall_at[k]).mean(),
                np.array(mrr_at[k]).mean()))

        # NOTE(review): the message is kept for log compatibility, but this
        # method does not write prediction files itself.
        _emit('Saving prediction results.')

        if len(gold_sents_list) > 0:
            rouge = Rouge()
            rouge_scores = rouge.get_scores(pred_sents_list,
                                            gold_sents_list,
                                            avg=True)
            _emit('Rouge-1:\tP:{}\t\tR:{}\t\tF1:{}'.format(
                rouge_scores['rouge-1']['p'], rouge_scores['rouge-1']['r'],
                rouge_scores['rouge-1']['f']))
            _emit('Rouge-2:\tP:{}\t\tR:{}\t\tF1:{}'.format(
                rouge_scores['rouge-2']['p'], rouge_scores['rouge-2']['r'],
                rouge_scores['rouge-2']['f']))
            _emit('Rouge-L:\tP:{}\t\tR:{}\t\tF1:{}'.format(
                rouge_scores['rouge-l']['p'], rouge_scores['rouge-l']['r'],
                rouge_scores['rouge-l']['f']))

        if self.report_score:
            _emit(self._report_score('PRED', pred_score_total,
                                     pred_words_total))
            if tgt_path is not None:
                _emit(self._report_score('GOLD', gold_score_total,
                                         gold_words_total))
                if self.report_bleu:
                    _emit(self._report_bleu(tgt_path))
                if self.report_rouge:
                    _emit(self._report_rouge(tgt_path))

        if self.dump_beam:
            import json
            # Context manager so the dump is flushed and the handle closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        return all_scores, all_predictions
Ejemplo n.º 13
0
def _embed_last_batch(translator, opt, shards):
    """Embed every batch of every shard; return the last batch's embedding.

    Args:
        translator: object providing ``src_reader``, ``fields``,
            ``data_type``, ``_filter_pred``, ``_dev``, ``max_tok_len`` and
            ``model.encoder.embed``.
        opt: options; ``src_dir``, ``batch_size`` and ``batch_type`` are read.
        shards: iterable of source shards to embed.

    Returns:
        The result of ``translator.model.encoder.embed`` for the final batch
        of the final shard, or ``None`` when no batch was produced.
    """
    embedding = None
    for shard in shards:
        shard_data = {
            "reader": translator.src_reader,
            "data": shard,
            "dir": opt.src_dir
        }
        _readers, _data, _dir = inputters.Dataset.config([('src',
                                                           shard_data)])

        data = inputters.Dataset(
            translator.fields,
            readers=_readers,
            data=_data,
            dirs=_dir,
            sort_key=inputters.str2sortkey[translator.data_type],
            filter_pred=translator._filter_pred)

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=translator._dev,
            batch_size=opt.batch_size,
            batch_size_fn=translator.max_tok_len
            if opt.batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)

        for batch in data_iter:
            src, src_lengths = batch.src
            # Each batch overwrites the previous result; only the last one
            # survives, as in the original inline loops.
            embedding = translator.model.encoder.embed(src, src_lengths)
    return embedding


def translate(opt):
    """Return the encoder embeddings of the source and baseline corpora.

    ``opt.src`` and ``opt.baseline`` are split into shards of
    ``opt.shard_size`` and every batch is passed through
    ``translator.model.encoder.embed``. Only the embedding of the LAST batch
    of each corpus is returned, so the result covers a whole corpus only
    when that corpus fits in a single batch.

    Args:
        opt: translation options; validated with
            ``ArgumentParser.validate_translate_opts``.

    Returns:
        tuple: ``(src_embed, bline_embed)``. An element is ``None`` when the
        corresponding corpus produced no batches.
    """

    ArgumentParser.validate_translate_opts(opt)
    # Called for its logging set-up; the returned logger is not used here.
    init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    baseline_shards = split_corpus(opt.baseline, opt.shard_size)

    print("\nEmbedding source and baseline...\n")

    # Source first, then baseline, matching the original processing order.
    src_embed = _embed_last_batch(translator, opt, src_shards)
    bline_embed = _embed_last_batch(translator, opt, baseline_shards)
    return src_embed, bline_embed
Ejemplo n.º 14
0
    def translate_gold_diff(self,
                            src,
                            tgt=None,
                            tgt2=None,
                            src_dir=None,
                            batch_size=None,
                            batch_type="sents",
                            attn_debug=False,
                            align_debug=False,
                            phrase_table="",
                            src_embed=None,
                            hidden_state=None,
                            unlearn=False):
        """Translate content of ``src`` and get gold score difference of ``tgt`` and ``tgt2``.

        The gold scores returned by ``self.translate_batch`` are passed
        through ``torch.exp`` before being subtracted. When several batches
        are produced, only the scores of the last batch are used.

        Args:
            src: See :func:`self.src_reader.read()`.
            tgt: See :func:`self.tgt_reader.read()`.
            tgt2: second target, read with ``self.tgt2_reader``. When None
                the subtracted term is 0.
            src_dir: See :func:`self.src_reader.read()` (only relevant
                for certain types of data).
            batch_size (int): size of examples per mini-batch
            batch_type (str): ``"tokens"`` batches with ``max_tok_len`` as
                the size function; any other value uses no size function.
            attn_debug (bool): enables the attention logging
            align_debug (bool): accepted but not used by this method.
            phrase_table (str): accepted but not used by this method.
            src_embed: embeddings of the source
            hidden_state (torch tensor) output of the encoder layers
            unlearn (bool): when True the first scoring pass (``tgt``) runs
                with gradients disabled.
        Returns:
            exp(gold_score1) - exp(gold_score2) (torch tensor)

        Raises:
            ValueError: if ``batch_size`` is None, or if a scoring pass
                produced no batches.
        """

        if batch_size is None:
            raise ValueError("batch_size must be set")

        src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
        tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
        _readers, _data, _dir = inputters.Dataset.config([('src', src_data),
                                                          ('tgt', tgt_data)])

        data = inputters.Dataset(
            self.fields,
            readers=_readers,
            data=_data,
            dirs=_dir,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)

        gold_scores_1 = None
        # Gradients are switched off only for ``unlearn``; otherwise the
        # caller's current grad mode is left untouched.
        with torch.set_grad_enabled(torch.is_grad_enabled() and not unlearn):
            for batch in data_iter:
                # ``src`` is rebound to the value returned by translate_batch;
                # the raw input is already captured in ``src_data``.
                (gold_scores_1, src, enc_states, memory_bank,
                 src_lengths) = self.translate_batch(
                     batch,
                     data.src_vocabs,
                     attn_debug,
                     src_embed=src_embed,
                     hidden_state=hidden_state)
        if gold_scores_1 is None:
            # Previously surfaced as an UnboundLocalError further down.
            raise ValueError("no batches were produced from src/tgt")
        gold_scores_1 = torch.exp(gold_scores_1)

        if tgt2 is not None:
            tgt2_data = {"reader": self.tgt2_reader, "data": tgt2, "dir": None}
            _readers, _data, _dir = inputters.Dataset.config([
                ('src', src_data), ('tgt', tgt2_data)
            ])

            data = inputters.Dataset(
                self.fields,
                readers=_readers,
                data=_data,
                dirs=_dir,
                sort_key=inputters.str2sortkey[self.data_type],
                filter_pred=self._filter_pred)

            data_iter = inputters.OrderedIterator(
                dataset=data,
                device=self._dev,
                batch_size=batch_size,
                batch_size_fn=max_tok_len if batch_type == "tokens" else None,
                train=False,
                sort=False,
                sort_within_batch=True,
                shuffle=False)

            gold_scores_2 = None
            for batch in data_iter:
                # The values returned by the first pass are handed back in.
                gold_scores_2 = self.translate_batch(batch,
                                                     data.src_vocabs,
                                                     attn_debug,
                                                     src,
                                                     enc_states,
                                                     memory_bank,
                                                     src_lengths,
                                                     src_embed,
                                                     tgt2=True,
                                                     hidden_state=hidden_state)
            if gold_scores_2 is None:
                raise ValueError("no batches were produced from src/tgt2")
            gold_scores_2 = torch.exp(gold_scores_2)
        else:
            gold_scores_2 = 0

        return gold_scores_1 - gold_scores_2
Ejemplo n.º 15
0
    def translate(self,
                  src,
                  tgt=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False,
                  attn_vis=False):
        """Translate content of ``src`` and get gold scores from ``tgt``.

        Args:
            src: See :func:`self.src_reader.read()`.
            tgt: See :func:`self.tgt_reader.read()`.
            src_dir: See :func:`self.src_reader.read()` (only relevant
                for certain types of data).
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging

        Returns:
            (`list`, `list`)

            * all_scores is a list of `batch_size` lists of `n_best` scores
            * all_predictions is a list of `batch_size` lists
                of `n_best` predictions
        """
        if batch_size is None:
            raise ValueError("batch_size must be set")
        data = inputters.Dataset(
            self.fields,
            readers=([self.src_reader, self.tgt_reader]
                     if tgt else [self.src_reader]),
            data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
            dirs=[src_dir, None] if tgt else [src_dir],
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)
        data_iter = inputters.OrderedIterator(dataset=data,
                                              device=self._dev,
                                              batch_size=batch_size,
                                              train=False,
                                              sort=False,
                                              sort_within_batch=True,
                                              shuffle=False)

        xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data.src_vocabs,
                                              attn_debug)

            translations = xlation_builder.from_batch(batch_data)

            for i, trans in enumerate(translations):
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                all_predictions += [n_best_preds]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if attn_vis:
                    src = trans.src_raw
                    pred = trans.pred_sents[0]
                    pred.append("'</s>'")
                    attns = trans.attns[0]
                    attns = attns[:len(pred), :len(src)].data.cpu()

                    def draw(data, x, y, ax, title):
                        """Plot ``data`` as a heatmap on ``ax`` under ``title``.

                        Args:
                            data: matrix handed to ``seaborn.heatmap``.
                            x: tick labels for the x axis.
                            y: tick labels for the y axis.
                            ax: axes object that receives both the title
                                and the heatmap.
                            title: text set as the title of ``ax``.
                        """
                        ax.set_title(title)
                        # Same options for every panel: fixed 0.0-1.0 value
                        # range, square cells and no colour bar.
                        heatmap_options = dict(
                            xticklabels=x,
                            yticklabels=y,
                            square=True,
                            vmin=0.0,
                            vmax=1.0,
                            cbar=False,
                            ax=ax,
                        )
                        seaborn.heatmap(data, **heatmap_options)

                    # plots attention over inputs
                    seaborn.heatmap(attns,
                                    xticklabels=src,
                                    square=True,
                                    yticklabels=pred,
                                    vmin=0.0,
                                    vmax=1.0,
                                    cbar=False)

                    plt.show()

                    for layer in range(0, 6, 2):
                        fig, axs = plt.subplots(1, 4, figsize=(12, 6))
                        for h in range(4):
                            data = list(self.model.encoder.children()
                                        )[1][layer].self_attn.attn.data.cpu()
                            data = data[i, h]
                            data = data[:len(src), :len(src)]
                            draw(data,
                                 src,
                                 src if h == 0 else [],
                                 ax=axs[h],
                                 title="Head {}".format(h))
#                         plt.suptitle("Attention Distribution for Layer {}".format(layer), y=.8)
                        plt.show()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if attn_debug:
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))
        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                self._log(msg)
                if self.report_bleu:
                    msg = self._report_bleu(tgt)
                    self._log(msg)
                if self.report_rouge:
                    msg = self._report_rouge(tgt)
                    self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            self._log("Average translation time (s): %f" %
                      (total_time / len(all_predictions)))
            self._log("Tokens per second: %f" %
                      (pred_words_total / total_time))

        if self.dump_beam:
            import json
            json.dump(self.translator.beam_accum,
                      codecs.open(self.dump_beam, 'w', 'utf-8'))
        return all_scores, all_predictions