Ejemplo n.º 1
0
    def log(self, sent_number):
        """
        Log translation.

        Build a human-readable, multi-line report for this translation.

        Args:
            sent_number: Label printed in the ``SENT``/``PRED``/``GOLD``
                header lines (callers pass a running sentence counter).

        Returns:
            str: The report. It always contains the source, the best
            prediction and its score. It additionally contains the best
            hypothesis' alignment when ``self.word_aligns`` is set, the gold
            sentence and its score when ``self.gold_sent`` is set, and every
            hypothesis with its score when there is more than one.
        """

        # src_raw is formatted as-is: it may not be a token list for every
        # data type, whereas the target-side token lists below are joined.
        msg = ['\nSENT {}: {}\n'.format(sent_number, self.src_raw)]

        best_pred = self.pred_sents[0]
        best_score = self.pred_scores[0]
        pred_sent = ' '.join(best_pred)
        msg.append('PRED {}: {}\n'.format(sent_number, pred_sent))
        msg.append("PRED SCORE: {:.4f}\n".format(best_score))

        if self.word_aligns is not None:
            pred_align = self.word_aligns[0]
            pred_align_pharaoh = build_align_pharaoh(pred_align)
            pred_align_sent = ' '.join(pred_align_pharaoh)
            msg.append("ALIGN: {}\n".format(pred_align_sent))

        if self.gold_sent is not None:
            tgt_sent = ' '.join(self.gold_sent)
            msg.append('GOLD {}: {}\n'.format(sent_number, tgt_sent))
            msg.append("GOLD SCORE: {:.4f}\n".format(self.gold_score))
        if len(self.pred_sents) > 1:
            msg.append('\nBEST HYP:\n')
            for score, sent in zip(self.pred_scores, self.pred_sents):
                # Join the tokens like the PRED line does, rather than
                # printing the Python list repr of the hypothesis.
                msg.append("[{:.4f}] {}\n".format(score, ' '.join(sent)))

        return "".join(msg)
Ejemplo n.º 2
0
    def translate(self,
                  src,
                  tgt=None,
                  src_dir=None,
                  batch_size=None,
                  batch_type="sents",
                  attn_debug=False,
                  align_debug=False,
                  phrase_table=""):
        """Translate content of ``src`` and get gold scores from ``tgt``.

        Each example's predictions are written to ``self.out_file`` (one
        per line) as soon as they are produced.

        Args:
            src: See :func:`self.src_reader.read()`.
            tgt: See :func:`self.tgt_reader.read()`.
            src_dir: See :func:`self.src_reader.read()` (only relevant
                for certain types of data).
            batch_size (int): size of examples per mini-batch
            batch_type (str): ``"tokens"`` sizes batches with
                ``max_tok_len``; any other value counts sentences.
            attn_debug (bool): enables the attention logging
            align_debug (bool): enables the word alignment logging
            phrase_table (str): not used by this method;
                ``self.phrase_table`` is passed to the builder instead.

        Returns:
            (`list`, `list`)

            * all_scores has one entry per translated example: the list
                of its (up to) `n_best` scores
            * all_predictions has one entry per translated example: the
                list of its (up to) `n_best` prediction strings (with
                `` ||| `` alignments appended when ``self.report_align``)

        Raises:
            ValueError: if ``batch_size`` is None.
        """

        if batch_size is None:
            raise ValueError("batch_size must be set")

        src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
        tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
        _readers, _data, _dir = inputters.Dataset.config([('src', src_data),
                                                          ('tgt', tgt_data)])

        data = inputters.Dataset(
            self.fields,
            readers=_readers,
            data=_data,
            dirs=_dir,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)

        xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt,
            self.phrase_table)

        def _emit(output):
            # Debug reports go to the logger when one is configured,
            # otherwise straight to stdout (file descriptor 1).
            if self.logger:
                self.logger.info(output)
            else:
                os.write(1, output.encode('utf-8'))

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data.src_vocabs,
                                              attn_debug)
            translations = xlation_builder.from_batch(batch_data)
            for trans in translations:
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    # +1 presumably accounts for the end-of-sentence token.
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                if self.report_align:
                    align_pharaohs = [
                        build_align_pharaoh(align)
                        for align in trans.word_aligns[:self.n_best]
                    ]
                    n_best_preds_align = [
                        " ".join(align) for align in align_pharaohs
                    ]
                    n_best_preds = [
                        pred + " ||| " + align for pred, align in zip(
                            n_best_preds, n_best_preds_align)
                    ]
                all_predictions += [n_best_preds]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    _emit(trans.log(sent_number))

                if attn_debug:
                    # Build a new list: appending in place would leak the
                    # '</s>' marker into trans.pred_sents[0], which the
                    # align_debug report below reuses as its target tokens.
                    preds = trans.pred_sents[0] + ['</s>']
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    _emit(report_matrix(srcs, preds, attns))

                if align_debug:
                    if trans.gold_sent is not None:
                        tgts = trans.gold_sent
                    else:
                        tgts = trans.pred_sents[0]
                    align = trans.word_aligns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(align[0]))]
                    _emit(report_matrix(srcs, tgts, align))

        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            # Guard the ratios: an empty (or fully filtered) input yields no
            # predictions, and a very fast run can measure zero seconds.
            if all_predictions:
                self._log("Average translation time (s): %f" %
                          (total_time / len(all_predictions)))
            if total_time > 0:
                self._log("Tokens per second: %f" %
                          (pred_words_total / total_time))

        if self.dump_beam:
            import json
            # NOTE(review): reads ``self.translator.beam_accum``; this method
            # never sets ``self.translator`` -- confirm the attribute exists.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        return all_scores, all_predictions
Ejemplo n.º 3
0
    def translate(self,
                  src,
                  tgt=None,
                  src_dir=None,
                  batch_size=None,
                  batch_type="sents",
                  attn_debug=False,
                  align_debug=False,
                  phrase_table=""):
        """Translate content of ``src`` and score it against ``tgt``.

        Each example's predictions are written to ``self.out_file`` (one
        per line) as soon as they are produced. When ``self.report_score``
        is set and ``tgt`` is given, BLEU is always reported (and ROUGE
        when ``self.report_rouge`` is set); the GOLD score report is
        intentionally not emitted.

        Args:
            src: See :func:`self.src_reader.read()`.
            tgt: See :func:`self.tgt_reader.read()`.
            src_dir: See :func:`self.src_reader.read()` (only relevant
                for certain types of data).
            batch_size (int): size of examples per mini-batch
            batch_type (str): ``"tokens"`` sizes batches with
                ``max_tok_len``; any other value counts sentences.
            attn_debug (bool): enables the attention logging
            align_debug (bool): enables the word alignment logging
            phrase_table (str): not used by this method;
                ``self.phrase_table`` is passed to the builder instead.

        Returns:
            (`list`, `list`)

            * all_scores has one entry per translated example: the list
                of its (up to) `n_best` scores
            * all_predictions has one entry per translated example: the
                list of its (up to) `n_best` prediction strings (with
                `` ||| `` alignments appended when ``self.report_align``)

        Raises:
            ValueError: if ``batch_size`` is None.
        """

        if batch_size is None:
            raise ValueError("batch_size must be set")

        src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
        tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
        _readers, _data, _dir = inputters.Dataset.config([('src', src_data),
                                                          ('tgt', tgt_data)])

        data = inputters.Dataset(
            self.fields,
            readers=_readers,
            data=_data,
            dirs=_dir,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)

        xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt,
            self.phrase_table)

        def _emit(output):
            # Debug reports go to the logger when one is configured,
            # otherwise straight to stdout (file descriptor 1).
            if self.logger:
                self.logger.info(output)
            else:
                os.write(1, output.encode('utf-8'))

        # Statistics
        pred_score_total, pred_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data.src_vocabs,
                                              attn_debug)
            translations = xlation_builder.from_batch(batch_data)

            for trans in translations:
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                if self.report_align:
                    align_pharaohs = [
                        build_align_pharaoh(align)
                        for align in trans.word_aligns[:self.n_best]
                    ]
                    n_best_preds_align = [
                        " ".join(align) for align in align_pharaohs
                    ]
                    n_best_preds = [
                        pred + " ||| " + align for pred, align in zip(
                            n_best_preds, n_best_preds_align)
                    ]
                all_predictions += [n_best_preds]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if attn_debug:
                    # Build a new list: appending in place would leak the
                    # '</s>' marker into trans.pred_sents[0], which the
                    # align_debug report below reuses as its target tokens.
                    preds = trans.pred_sents[0] + ['</s>']
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    _emit(report_matrix(srcs, preds, attns))

                if align_debug:
                    if trans.gold_sent is not None:
                        tgts = trans.gold_sent
                    else:
                        tgts = trans.pred_sents[0]
                    align = trans.word_aligns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(align[0]))]
                    _emit(report_matrix(srcs, tgts, align))

        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)

            if tgt is not None:
                # BLEU is reported unconditionally when references are
                # available, without overwriting the configured
                # ``self.report_bleu`` flag.
                self._log(self._report_bleu(tgt))
                if self.report_rouge:
                    self._log(self._report_rouge(tgt))

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            # Guard the ratios: an empty (or fully filtered) input yields no
            # predictions, and a very fast run can measure zero seconds.
            if all_predictions:
                self._log("Average translation time (s): %f" %
                          (total_time / len(all_predictions)))
            if total_time > 0:
                self._log("Tokens per second: %f" %
                          (pred_words_total / total_time))

        if self.dump_beam:
            import json
            # NOTE(review): reads ``self.translator.beam_accum``; this method
            # never sets ``self.translator`` -- confirm the attribute exists.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        return all_scores, all_predictions
Ejemplo n.º 4
0
    def attention_analysis(self,
                           direction,
                           src,
                           tgt,
                           batch_type="sents",
                           phrase_table=""):
        """Translate ``src`` in one direction and dump attention matrices.

        The model's active ``encoder``/``decoder``/``generator`` are switched
        to their ``x2y`` or ``y2x`` variants (and ``self.direction`` is set).
        The whole of ``src`` is then translated with ``len(src)`` passed as
        the batch size. For every translation, the attention matrix of each
        reported hypothesis is printed to stdout and collected.

        Args:
            direction (str): ``'x2y'`` or ``'y2x'``; selects which
                encoder/decoder/generator the model uses.
            src: See :func:`self.src_reader.read()`. Its ``len()`` is passed
                as ``batch_size``.
            tgt: See :func:`self.tgt_reader.read()`; gold scores are
                accumulated against it.
            batch_type (str): ``"tokens"`` sizes batches with
                ``max_tok_len``; any other value counts sentences.
            phrase_table (str): not used by this method;
                ``self.phrase_table`` is passed to the builder instead.

        Returns:
            (`list`, `list`, `list`)

            * all_scores has one entry per example: its (up to) `n_best`
                prediction scores
            * all_predictions has one entry per example: its (up to)
                `n_best` prediction strings
            * all_attentions holds one numpy attention matrix per reported
                hypothesis, in iteration order

        Raises:
            AssertionError: if ``direction`` is not ``'x2y'`` or ``'y2x'``.
        """
        assert direction in ['x2y', 'y2x']
        self.model.encoder = (self.model.encoder_x2y if direction == 'x2y' else
                              self.model.encoder_y2x)
        self.model.decoder = (self.model.decoder_x2y if direction == 'x2y' else
                              self.model.decoder_y2x)
        self.model.generator = (self.model.generator_x2y if direction == 'x2y'
                                else self.model.generator_y2x)
        self.direction = direction

        batch_size = len(src)

        src_data = {"reader": self.src_reader, "data": src, "dir": None}
        tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
        _readers, _data, _dir = inputters.Dataset.config([('src', src_data),
                                                          ('tgt', tgt_data)])

        # corpus_id field is useless here (note: removed from self.fields
        # permanently, not just for this call)
        if self.fields.get("corpus_id", None) is not None:
            self.fields.pop('corpus_id')
        data = inputters.Dataset(
            self.fields,
            readers=_readers,
            data=_data,
            dirs=_dir,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)

        xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt,
            self.phrase_table)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []
        all_attentions = []

        start_time = time.time()

        for batch in tqdm(data_iter):
            batch_data = self.translate_batch(batch,
                                              data.src_vocabs,
                                              attn_debug=True,
                                              only_gold_score=False)
            translations = xlation_builder.from_batch(batch_data)

            for trans in translations:
                n_best_scores = trans.pred_scores[:self.n_best]
                all_scores += [n_best_scores]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])

                n_best_gold_scores = [trans.gold_score]
                gold_score_total += trans.gold_score
                # +1 presumably accounts for the end-of-sentence token.
                gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                if self.report_align:
                    align_pharaohs = [
                        build_align_pharaoh(align)
                        for align in trans.word_aligns[:self.n_best]
                    ]
                    n_best_preds_align = [
                        " ".join(align) for align in align_pharaohs
                    ]
                    n_best_preds = [
                        pred + " ||| " + align for pred, align in zip(
                            n_best_preds, n_best_preds_align)
                    ]
                all_predictions += [n_best_preds]
                if self.out_file:
                    if self.log_score:
                        # in BWD translation(tgt=product, n_best==1),
                        # we use gold score
                        if self.n_best == 1 and tgt is not None:
                            n_best_scores = n_best_gold_scores
                        n_best_preds_scores = [
                            pred + ',' + str(score.item())
                            for pred, score in zip(n_best_preds, n_best_scores)
                        ]
                        self.out_file.write('\n'.join(n_best_preds_scores) +
                                            '\n')
                        self.out_file.flush()
                    else:
                        self.out_file.write('\n'.join(n_best_preds) + '\n')
                        self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    os.write(1, output.encode('utf-8'))

                # A translation presumably holds only n_best hypotheses (see
                # the [:self.n_best] slices above), which can be fewer than
                # beam_size; never index past what is actually there.
                n_hyps = min(self.beam_size, len(trans.pred_sents),
                             len(trans.attns))
                for i in range(n_hyps):
                    # Build a new list instead of appending '</s>' to the
                    # translation's own token list.
                    preds = trans.pred_sents[i] + ['</s>']
                    attns = trans.attns[i].tolist()
                    srcs = trans.src_raw
                    print(srcs, len(srcs), len(attns), len(attns[0]))
                    output = report_matrix(srcs, preds, attns)
                    os.write(1, output.encode('utf-8'))
                    all_attentions.append(trans.attns[i].cpu().numpy())

        end_time = time.time()

        if self.report_score:
            msg = self._report_score('GOLD', gold_score_total,
                                     gold_words_total)
            self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            # Guard the ratios: an empty (or fully filtered) input yields no
            # predictions, and a very fast run can measure zero seconds.
            if all_predictions:
                self._log("Average translation time (s): %f" %
                          (total_time / len(all_predictions)))
            if total_time > 0:
                self._log("Tokens per second: %f" %
                          (pred_words_total / total_time))

        return all_scores, all_predictions, all_attentions
Ejemplo n.º 5
0
    def translate(self,
                  src,
                  tgt=None,
                  src_dir=None,
                  batch_size=None,
                  batch_type="sents",
                  attn_debug=False,
                  align_debug=False,
                  phrase_table="",
                  partial=None,
                  dymax_len=None):
        """Translate content of ``src`` and get gold scores from ``tgt``.

        Args:
            src: See :func:`self.src_reader.read()`.
            tgt: See :func:`self.tgt_reader.read()`.
            src_dir: See :func:`self.src_reader.read()` (only relevant
                for certain types of data).
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging
            align_debug (bool): enables the word alignment logging

        Returns:
            (`list`, `list`)

            * all_scores is a list of `batch_size` lists of `n_best` scores
            * all_predictions is a list of `batch_size` lists
                of `n_best` predictions
            * attns is a list of attention scores for translation having highest cumilative log likelihood
        """
        self.dymax_len = dymax_len
        self.partialf = None

        # To check with partial words
        partialfcheck = True

        # To check with editdistance, put True. To check with just startswith which will be prone to errors due to spelling mistakes, put False.
        partialfedit = False

        # partialopt = True

        # Logic for partial and partialf
        if partial and partial != '':
            partials = partial.split()
            print(partials, '~~~~partials~~~')
            vocabdict = dict(self.fields)["tgt"].base_field.vocab
            # if vocabdict.stoi[partials[-1]] == 0:
            if partialfcheck:
                # if partialfedit:
                #     parlen = len(partials[-1])
                #     f = lambda x: 1 + editdistance.eval(x[:parlen], partials[-1]) * 20
                # else:
                #     f = lambda x: float('inf') if not x.startswith(partials[-1]) else float('1.0')

                # editarr = [(f(k) , v) for k, v in vocabdict.stoi.items() if v]
                # self.partialf = [20.0] + [i[0] for i in sorted(editarr, key=lambda x: x[1])]

                self.partial = [vocabdict.stoi[x] for x in partials[:-1]]
                print("#########vocabdict.stoi########")
                print(self.partial)
                print("##################################")

                self.partialf = [
                    v for k, v in vocabdict.stoi.items()
                    if k.startswith(partials[-1]) and v
                ]
            else:
                self.partial = [vocabdict.stoi[x] for x in partials]
            # else:
            #     self.partialf = None
            #     self.partial = [vocabdict.stoi[x] for x in partials]
        else:
            self.partial = None
            # self.partialf = None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
        tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
        _readers, _data, _dir = inputters.Dataset.config([('src', src_data),
                                                          ('tgt', tgt_data)])

        # corpus_id field is useless here
        if self.fields.get("corpus_id", None) is not None:
            self.fields.pop('corpus_id')
        data = inputters.Dataset(
            self.fields,
            readers=_readers,
            data=_data,
            dirs=_dir,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False)

        xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt,
            self.phrase_table)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = [
        ]  # I guess this is the cumilative log likelihood score of each sentence
        all_predictions = []

        start_time = time.time()

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data.src_vocabs,
                                              attn_debug)
            translations = xlation_builder.from_batch(batch_data)

            for trans in translations:
                print("Loop")
                print(trans, trans.pred_sents)
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]

                print("############n_best_preds###############")
                print(n_best_preds)
                print("############n_best_preds###############")

                if self.report_align:
                    align_pharaohs = [
                        build_align_pharaoh(align)
                        for align in trans.word_aligns[:self.n_best]
                    ]
                    n_best_preds_align = [
                        " ".join(align) for align in align_pharaohs
                    ]
                    n_best_preds = [
                        pred + " ||| " + align for pred, align in zip(
                            n_best_preds, n_best_preds_align)
                    ]
                all_predictions += [n_best_preds]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if attn_debug:
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    output = report_matrix(
                        srcs, preds, attns
                    )  # This prints attentions in output for the sentence having highest cumilative log likelihood score

                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if align_debug:
                    if trans.gold_sent is not None:
                        tgts = trans.gold_sent
                    else:
                        tgts = trans.pred_sents[0]
                    align = trans.word_aligns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(align[0]))]
                    output = report_matrix(srcs, tgts, align)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            self._log("Average translation time (s): %f" %
                      (total_time / len(all_predictions)))
            self._log("Tokens per second: %f" %
                      (pred_words_total / total_time))

        if self.dump_beam:
            import json
            json.dump(self.translator.beam_accum,
                      codecs.open(self.dump_beam, 'w', 'utf-8'))

        if attn_debug:
            return all_scores, all_predictions, attns, pred_score_total, pred_words_total
        else:
            return all_scores, all_predictions, pred_score_total, pred_words_total
Ejemplo n.º 6
0
    def translate(self,
                  src,
                  tgt=None,
                  src_dir=None,
                  batch_size=None,
                  batch_type="sents",
                  attn_debug=False,
                  align_debug=False,
                  phrase_table="",
                  opt=None):
        """Translate content of ``src`` and get gold scores from ``tgt``.

        Args:
            src: See :func:`self.src_reader.read()`.
            tgt: See :func:`self.tgt_reader.read()`.
            src_dir: See :func:`self.src_reader.read()` (only relevant
                for certain types of data).
            batch_size (int): size of a mini-batch; counted in examples, or
                in tokens when ``batch_type == "tokens"``.
            batch_type (str): ``"sents"`` or ``"tokens"``; selects how
                ``batch_size`` is interpreted.
            attn_debug (bool): enables the attention logging
            align_debug (bool): enables the word alignment logging
            phrase_table (str): not used by this method;
                ``self.phrase_table`` is what reaches the builder.
            opt: not used by this method; kept for call-site compatibility.

        Returns:
            (`list`, `list`)

            * all_scores holds, per translated example, a list of up to
              `n_best` scores
            * all_predictions holds, per translated example, a list of up
              to `n_best` prediction strings

        Raises:
            ValueError: if ``batch_size`` is None.
        """
        import json

        if batch_size is None:
            raise ValueError("batch_size must be set")

        # modified by @memray to accommodate keyphrase
        src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
        tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
        _readers, _data, _dir = inputters.Dataset.config([('src', src_data),
                                                          ('tgt', tgt_data)])

        data = inputters.str2dataset[self.data_type](
            self.fields,
            readers=_readers,
            data=_data,
            dirs=_dir,
            sort_key=inputters.str2sortkey[self.data_type],
            filter_pred=self._filter_pred)

        # @memray, as Dataset is only instantiated here, having to use this plugin setter
        if isinstance(data, KeyphraseDataset):
            data.tgt_type = self.tgt_type

        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            # @memray: no sorting within a batch, to keep the original order
            sort_within_batch=False,
            shuffle=False)

        xlation_builder = onmt.translate.TranslationBuilder(
            data, self.fields, self.n_best, self.replace_unk, tgt,
            self.phrase_table)
        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        num_examples = 0
        # Count against the built dataset (after filter_pred), not len(src):
        # src may be a path, and filtered examples are never iterated.
        num_total = len(data)
        for batch in data_iter:
            # Use the real number of examples in this batch: ``batch_size``
            # is a token budget for batch_type == "tokens", and the last
            # batch may be short.
            num_examples += batch.batch_size
            print("Translating %d/%d" % (num_examples, num_total))

            batch_data = self.translate_batch(batch, data.src_vocabs,
                                              attn_debug)
            translations = xlation_builder.from_batch(batch_data)

            # @memray
            if self.data_type == "keyphrase":
                # post-process for one2seq outputs, split seq into individual phrases
                if self.model_tgt_type != 'one2one':
                    translations = self.segment_one2seq_trans(translations)
                # add statistics of kps(pred_num, beamstep_num etc.)
                translations = self.add_trans_stats(translations,
                                                    self.model_tgt_type)

                # add copied flag
                vocab_size = len(self.fields['src'].base_field.vocab.itos)
                for t in translations:
                    t.add_copied_flags(vocab_size)

            for trans in translations:
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                if self.report_align:
                    align_pharaohs = [
                        build_align_pharaoh(align)
                        for align in trans.word_aligns[:self.n_best]
                    ]
                    n_best_preds_align = [
                        " ".join(align) for align in align_pharaohs
                    ]
                    n_best_preds = [
                        pred + " ||| " + align for pred, align in zip(
                            n_best_preds, n_best_preds_align)
                    ]
                all_predictions += [n_best_preds]

                if self.out_file:
                    if self.data_type == "keyphrase":
                        # NOTE(review): the translation object exposes a
                        # callable ``__dict__()`` here; presumably it returns
                        # a JSON-serializable dict. Confirm.
                        self.out_file.write(
                            json.dumps(trans.__dict__()) + '\n')
                    else:
                        self.out_file.write('\n'.join(n_best_preds) + '\n')
                    self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    if self.data_type == "keyphrase":
                        output = trans.log_kp(sent_number)
                    else:
                        output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if attn_debug:
                    # Build a new list rather than appending in place, so
                    # trans.pred_sents[0] (reused by align_debug) is untouched.
                    preds = list(trans.pred_sents[0]) + ['</s>']
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    output = report_matrix(srcs, preds, attns)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                if align_debug:
                    if trans.gold_sent is not None:
                        tgts = trans.gold_sent
                    else:
                        tgts = trans.pred_sents[0]
                    align = trans.word_aligns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(align[0]))]
                    output = report_matrix(srcs, tgts, align)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

        end_time = time.time()

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                self._log(msg)
                if self.report_bleu:
                    msg = self._report_bleu(tgt)
                    self._log(msg)
                if self.report_rouge:
                    msg = self._report_rouge(tgt)
                    self._log(msg)
                # NOTE: report_kpeval is deliberately not evaluated here,
                # because in opt.tgt rare words are replaced by <unk>.

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            # Guard against division by zero when nothing was translated.
            if all_predictions:
                self._log("Average translation time (s): %f" %
                          (total_time / len(all_predictions)))
            if total_time > 0:
                self._log("Tokens per second: %f" %
                          (pred_words_total / total_time))

        if self.dump_beam:
            # Context manager ensures the dump file is flushed and closed.
            with codecs.open(self.dump_beam, 'w', 'utf-8') as beam_file:
                json.dump(self.translator.beam_accum, beam_file)

        return all_scores, all_predictions
Ejemplo n.º 7
0
    def _translate(self,
                   data,
                   tgt=None,
                   batch_size=None,
                   batch_type="sents",
                   attn_debug=False,
                   align_debug=False,
                   phrase_table="",
                   dynamic=False,
                   transform=None):
        """Translate an already-built dataset and write n-best outputs.

        Args:
            data: dataset to iterate over; its ``src_vocabs`` are passed to
                ``translate_batch``.
            tgt: optional gold target; when not ``None`` gold scores are
                accumulated and reported.
            batch_size (int): size of a mini-batch; counted in examples, or
                in tokens when ``batch_type == "tokens"``.
            batch_type (str): ``"sents"`` or ``"tokens"``.
            attn_debug (bool): log the attention matrix of the best hypothesis.
            align_debug (bool): log the word-alignment matrix of the best
                hypothesis.
            phrase_table (str): not used by this method;
                ``self.phrase_table`` is what reaches the builder.
            dynamic (bool): if True, each n-best prediction string is passed
                through ``transform.apply_reverse`` before being output.
            transform: object providing ``apply_reverse``; only used (and
                then required) when ``dynamic`` is True.

        Returns:
            (list, list): ``all_scores``, holding one list of up to
            ``n_best`` scores per example, and ``all_predictions``, holding
            one list of up to ``n_best`` prediction strings per example.
        """
        data_iter = inputters.OrderedIterator(
            dataset=data,
            device=self._dev,
            batch_size=batch_size,
            batch_size_fn=max_tok_len if batch_type == "tokens" else None,
            train=False,
            sort=False,
            sort_within_batch=True,
            shuffle=False,
        )

        xlation_builder = onmt.translate.TranslationBuilder(
            data,
            self.fields,
            self.n_best,
            self.replace_unk,
            tgt,
            self.phrase_table,
        )

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        start_time = time.time()

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data.src_vocabs,
                                              attn_debug)
            translations = xlation_builder.from_batch(batch_data)

            for trans in translations:
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [
                    " ".join(pred) for pred in trans.pred_sents[:self.n_best]
                ]
                if self.report_align:
                    align_pharaohs = [
                        build_align_pharaoh(align)
                        for align in trans.word_aligns[:self.n_best]
                    ]
                    n_best_preds_align = [
                        " ".join(align) for align in align_pharaohs
                    ]
                    n_best_preds = [
                        pred + DefaultTokens.ALIGNMENT_SEPARATOR + align for
                        pred, align in zip(n_best_preds, n_best_preds_align)
                    ]

                if dynamic:
                    n_best_preds = [
                        transform.apply_reverse(x) for x in n_best_preds
                    ]
                all_predictions += [n_best_preds]
                self.out_file.write("\n".join(n_best_preds) + "\n")
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode("utf-8"))

                if attn_debug:
                    # Build a new list rather than appending in place, so
                    # trans.pred_sents[0] (reused by align_debug) keeps no EOS.
                    preds = list(trans.pred_sents[0]) + [DefaultTokens.EOS]
                    attns = trans.attns[0].tolist()
                    if self.data_type == "text":
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    output = report_matrix(srcs, preds, attns)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode("utf-8"))

                if align_debug:
                    tgts = trans.pred_sents[0]
                    align = trans.word_aligns[0].tolist()
                    if self.data_type == "text":
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(align[0]))]
                    output = report_matrix(srcs, tgts, align)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode("utf-8"))

        end_time = time.time()

        if self.report_score:
            msg = self._report_score("PRED", pred_score_total,
                                     pred_words_total)
            self._log(msg)
            if tgt is not None:
                msg = self._report_score("GOLD", gold_score_total,
                                         gold_words_total)
                self._log(msg)

        if self.report_time:
            total_time = end_time - start_time
            self._log("Total translation time (s): %f" % total_time)
            # Guard against division by zero when nothing was translated.
            if all_predictions:
                self._log("Average translation time (s): %f" %
                          (total_time / len(all_predictions)))
            if total_time > 0:
                self._log("Tokens per second: %f" %
                          (pred_words_total / total_time))

        if self.dump_beam:
            import json

            # Context manager ensures the dump file is flushed and closed.
            with codecs.open(self.dump_beam, "w", "utf-8") as beam_file:
                json.dump(self.translator.beam_accum, beam_file)
        return all_scores, all_predictions