Code Example #1
File: translation_model.py  Project: lrjxgl/seq2seq-1
    def decode_batch(self, sentence_tuples, batch_size, remove_unk=False, fix_edits=True):
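        # Batch the input tuples, convert each sentence to token ids, run greedy
        # decoding, and yield one (hypothesis, raw_hypothesis) pair per input.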
        if batch_size == 1:
            batches = ([sentence_tuple] for sentence_tuple in sentence_tuples)   # lazy
        else:
            batch_count = int(math.ceil(len(sentence_tuples) / batch_size))
            batches = [sentence_tuples[i * batch_size:(i + 1) * batch_size] for i in range(batch_count)]

        def map_to_ids(sentence_tuple):
            token_ids = [
                sentence if vocab is None else
                utils.sentence_to_token_ids(sentence, vocab.vocab, character_level=self.character_level.get(ext))
                for ext, vocab, sentence in zip(self.extensions, self.vocabs, sentence_tuple)
            ]
            return token_ids

        for batch_id, batch in enumerate(batches):
            token_ids = list(map(map_to_ids, batch))
            batch_token_ids = self.seq2seq_model.greedy_decoding(token_ids)
            batch_token_ids = zip(*batch_token_ids)

            for src_tokens, trg_token_ids in zip(batch, batch_token_ids):
                trg_tokens = []

                for trg_token_ids_, vocab in zip(trg_token_ids, self.trg_vocab):
                    trg_token_ids_ = list(trg_token_ids_)   # from np array to list
                    if utils.EOS_ID in trg_token_ids_:
                        trg_token_ids_ = trg_token_ids_[:trg_token_ids_.index(utils.EOS_ID)]

                    trg_tokens_ = [vocab.reverse[i] if i < len(vocab.reverse) else utils._UNK
                                   for i in trg_token_ids_]
                    trg_tokens.append(trg_tokens_)

                if self.pred_edits:
                    # first output is ops, second output is words
                    raw_hypothesis = ' '.join('_'.join(tokens) for tokens in zip(*trg_tokens))
                    trg_tokens = utils.reverse_edits(src_tokens[0].split(), trg_tokens, fix=fix_edits)
                    trg_tokens = [token for token in trg_tokens if token not in utils._START_VOCAB]
                    # FIXME: char-level
                else:
                    trg_tokens = trg_tokens[0]
                    raw_hypothesis = ''.join(trg_tokens) if self.char_output else ' '.join(trg_tokens)

                if remove_unk:
                    trg_tokens = [token for token in trg_tokens if token != utils._UNK]

                if self.char_output:
                    hypothesis = ''.join(trg_tokens)
                else:
                    hypothesis = ' '.join(trg_tokens).replace('@@ ', '')  # merge subword units

                yield hypothesis, raw_hypothesis
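A minimal consumption sketch (the `model` object, `tuples` list, and output file name below are illustrative assumptions, not part of the original project; `decode_batch` itself yields `(hypothesis, raw_hypothesis)` pairs as shown above):

# Hypothetical driver loop: `model` is assumed to expose the decode_batch
# method above, and `tuples` to be a list of per-extension sentence tuples.
with open('output.txt', 'w') as out:
    for hypothesis, raw_hypothesis in model.decode_batch(tuples, batch_size=32,
                                                         remove_unk=True):
        out.write(hypothesis + '\n')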
Code Example #2
    def evaluate(self,
                 sess,
                 beam_size,
                 score_function,
                 on_dev=True,
                 output=None,
                 remove_unk=False,
                 max_dev_size=None,
                 script_dir='scripts',
                 early_stopping=True,
                 use_edits=False,
                 **kwargs):
        """
        :param score_function: name of the scoring function used to score and rank models
          (typically 'bleu_score')
        :param on_dev: if True, evaluate the dev corpus, otherwise evaluate the test corpus
        :param output: save the hypotheses to this file
        :param remove_unk: remove the UNK symbols from the output
        :param max_dev_size: maximum number of lines to read from dev files
        :param script_dir: parameter of scoring functions
        :return: scores of each corpus to evaluate
        """
        utils.log('starting decoding')
        assert on_dev or len(self.filenames.test) == len(self.extensions)

        filenames = self.filenames.dev if on_dev else [self.filenames.test]

        # convert `output` into a list, for zip
        if isinstance(output, str):
            output = [output]
        elif output is None:
            output = [None] * len(filenames)

        scores = []

        for filenames_, output_ in zip(
                filenames, output):  # evaluation on multiple corpora
            lines = list(
                utils.read_lines(filenames_, self.extensions,
                                 self.binary_input))
            if on_dev and max_dev_size:
                lines = lines[:max_dev_size]

            hypotheses = []
            references = []

            output_file = None

            try:
                if output_ is not None:
                    output_file = open(output_, 'w')

                *src_sentences, trg_sentences = zip(*lines)
                src_sentences = list(zip(*src_sentences))

                hypothesis_iter = self._decode_batch(
                    sess,
                    src_sentences,
                    self.batch_size,
                    beam_size=beam_size,
                    early_stopping=early_stopping,
                    remove_unk=remove_unk,
                    use_edits=use_edits)
                for sources, hypothesis, reference in zip(
                        src_sentences, hypothesis_iter, trg_sentences):
                    if use_edits:
                        reference = utils.reverse_edits(sources[0], reference)

                    hypotheses.append(hypothesis)
                    references.append(reference.strip().replace('@@ ', ''))

                    if output_file is not None:
                        output_file.write(hypothesis + '\n')
                        output_file.flush()

            finally:
                if output_file is not None:
                    output_file.close()

            # default scoring function is 'bleu_score', looked up in the `evaluation` module
            score, score_summary = getattr(evaluation, score_function)(
                hypotheses, references, script_dir=script_dir)

            # print the scoring information
            score_info = []
            if self.name is not None:
                score_info.append(self.name)
            score_info.append('score={:.2f}'.format(score))
            if score_summary:
                score_info.append(score_summary)

            utils.log(' '.join(map(str, score_info)))
            scores.append(score)

        return scores
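The docstring above describes the main parameters; a rough, hypothetical call could look like this (`model`, `sess`, and the output file name are assumptions for illustration, not from the original code):

# Hypothetical usage sketch: `model` is assumed to be an instance of the class
# defining `evaluate`, and `sess` an open TensorFlow session.
scores = model.evaluate(
    sess,
    beam_size=1,
    score_function='bleu_score',  # name of a function in the `evaluation` module
    on_dev=True,                  # evaluate the dev corpus
    output='dev.hyp',             # write hypotheses to this file
    remove_unk=True,
    max_dev_size=500)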
Code Example #3
    def _decode_batch(self,
                      sess,
                      sentence_tuples,
                      batch_size,
                      beam_size=1,
                      remove_unk=False,
                      early_stopping=True,
                      use_edits=False):
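        # Beam search is used when beam_size > 1 or when `sess` is a list of
        # sessions; in that case decoding falls back to a batch size of 1.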
        beam_search = beam_size > 1 or isinstance(sess, list)

        if beam_search:
            batch_size = 1

        if batch_size == 1:
            batches = ([sentence_tuple]
                       for sentence_tuple in sentence_tuples)  # lazy
        else:
            batch_count = int(math.ceil(len(sentence_tuples) / batch_size))
            batches = [
                sentence_tuples[i * batch_size:(i + 1) * batch_size]
                for i in range(batch_count)
            ]

        def map_to_ids(sentence_tuple):
            token_ids = [
                utils.sentence_to_token_ids(
                    sentence, vocab.vocab, character_level=char_level)
                if vocab is not None else
                sentence  # when `sentence` is not a sentence but a vector...
                for vocab, sentence, char_level in zip(
                    self.vocabs, sentence_tuple, self.character_level)
            ]
            return token_ids

        for batch in batches:
            token_ids = list(map(map_to_ids, batch))

            if beam_search:
                hypotheses, _ = self.seq2seq_model.beam_search_decoding(
                    sess,
                    token_ids[0],
                    beam_size,
                    ngrams=self.ngrams,
                    early_stopping=early_stopping)
                batch_token_ids = [
                    hypotheses[0]
                ]  # first hypothesis is the highest scoring one

            else:
                batch_token_ids = self.seq2seq_model.greedy_decoding(
                    sess, token_ids)

            for src_tokens, trg_token_ids in zip(batch, batch_token_ids):
                trg_token_ids = list(trg_token_ids)

                if utils.EOS_ID in trg_token_ids:
                    trg_token_ids = trg_token_ids[:trg_token_ids.index(utils.EOS_ID)]

                trg_tokens = [
                    self.trg_vocab.reverse[i]
                    if i < len(self.trg_vocab.reverse) else utils._UNK
                    for i in trg_token_ids
                ]

                if use_edits:
                    trg_tokens = utils.reverse_edits(
                        src_tokens[0], ' '.join(trg_tokens)).split()

                if remove_unk:
                    trg_tokens = [
                        token for token in trg_tokens if token != utils._UNK
                    ]

                if self.character_level[-1]:
                    yield ''.join(trg_tokens)
                else:
                    yield ' '.join(trg_tokens).replace(
                        '@@ ', '')  # merge subword units
Code Example #4
File: reverse-edits.py  Project: zxsted/seq2seq-2
#!/usr/bin/env python3

import argparse
from translate import utils

parser = argparse.ArgumentParser()
parser.add_argument('source')
parser.add_argument('edits')

if __name__ == '__main__':
    args = parser.parse_args()
    with open(args.source) as src_file, open(args.edits) as edit_file:
        for src_line, edits in zip(src_file, edit_file):
            trg_line = utils.reverse_edits(src_line, edits)
            print(trg_line)
Code Example #5
#!/usr/bin/env python3

import argparse
from translate import utils

parser = argparse.ArgumentParser()
parser.add_argument('source')
parser.add_argument('edits')

if __name__ == '__main__':
    args = parser.parse_args()
    with open(args.source) as src_file, open(args.edits) as edit_file:
        for src_line, edits in zip(src_file, edit_file):
            trg_line = utils.reverse_edits(src_line.split(), [edits.split()])
            print(' '.join(trg_line))
Code Example #6
    def decode_batch(self,
                     sentence_tuples,
                     batch_size,
                     remove_unk=False,
                     fix_edits=True,
                     unk_replace=False,
                     align=False,
                     reverse=False,
                     output=None):
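        # Same batching and decoding flow as in the earlier examples, with optional
        # attention-based UNK replacement (`unk_replace`) and alignment heatmaps (`align`).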
        if batch_size == 1:
            batches = ([sentence_tuple]
                       for sentence_tuple in sentence_tuples)  # lazy
        else:
            batch_count = int(math.ceil(len(sentence_tuples) / batch_size))
            batches = [
                sentence_tuples[i * batch_size:(i + 1) * batch_size]
                for i in range(batch_count)
            ]

        def map_to_ids(sentence_tuple):
            token_ids = [
                sentence if vocab is None else utils.sentence_to_token_ids(
                    sentence,
                    vocab.vocab,
                    character_level=self.character_level.get(ext))
                for ext, vocab, sentence in zip(self.extensions, self.vocabs,
                                                sentence_tuple)
            ]
            return token_ids

        line_id = 0
        for batch_id, batch in enumerate(batches):
            token_ids = list(map(map_to_ids, batch))
            batch_token_ids, batch_weights = self.seq2seq_model.greedy_decoding(
                token_ids,
                beam_size=self.beam_size,
                align=unk_replace or align or self.debug)
            batch_token_ids = zip(*batch_token_ids)

            for sentence_id, (src_tokens, trg_token_ids) in enumerate(
                    zip(batch, batch_token_ids)):
                line_id += 1
                trg_tokens = []

                for trg_token_ids_, vocab in zip(trg_token_ids,
                                                 self.trg_vocab):
                    trg_token_ids_ = list(
                        trg_token_ids_)  # from np array to list
                    if utils.EOS_ID in trg_token_ids_:
                        trg_token_ids_ = trg_token_ids_[:trg_token_ids_.index(utils.EOS_ID)]

                    trg_tokens_ = [
                        vocab.reverse[i]
                        if i < len(vocab.reverse) else utils._UNK
                        for i in trg_token_ids_
                    ]
                    trg_tokens.append(trg_tokens_)

                if align:
                    weights_ = batch_weights[sentence_id].squeeze()
                    max_len_ = weights_.shape[1]

                    if self.binary[0]:
                        src_tokens_ = None
                    else:
                        src_tokens_ = src_tokens[0].split()[:max_len_ - 1] + [utils._EOS]
                        src_tokens_ = [
                            token
                            if token in self.src_vocab[0].vocab else utils._UNK
                            for token in src_tokens_
                        ]
                        weights_ = weights_[:, :len(src_tokens_)]

                    trg_tokens_ = trg_tokens[0][:weights_.shape[0] - 1] + [utils._EOS]
                    weights_ = weights_[:len(trg_tokens_)]
                    output_file = output and '{}.{}.pdf'.format(
                        output, line_id)
                    utils.heatmap(src_tokens_,
                                  trg_tokens_,
                                  weights_,
                                  reverse=reverse,
                                  output_file=output_file)

                if self.debug or unk_replace:
                    weights = batch_weights[sentence_id]
                    src_words = src_tokens[0].split()
                    align_ids = np.argmax(weights[:, :len(src_words)], axis=1)
                else:
                    align_ids = [0] * len(trg_tokens[0])

                def replace(token, align_id):
                    if self.debug and (not unk_replace or token == utils._UNK):
                        suffix = '({})'.format(align_id)
                    else:
                        suffix = ''
                    if token == utils._UNK and unk_replace:
                        token = src_words[align_id]
                        if (not token[0].isupper() and self.lexicon is not None
                                and token in self.lexicon):
                            token = self.lexicon[token]
                    return token + suffix

                trg_tokens[0] = [
                    replace(token, align_id)
                    for align_id, token in zip(align_ids, trg_tokens[0])
                ]

                if self.pred_edits:
                    # first output is ops, second output is words
                    raw_hypothesis = ' '.join('_'.join(tokens)
                                              for tokens in zip(*trg_tokens))
                    src_words = src_tokens[0].split()
                    trg_tokens = utils.reverse_edits(src_words,
                                                     trg_tokens,
                                                     fix=fix_edits)
                    trg_tokens = [
                        token for token in trg_tokens
                        if token not in utils._START_VOCAB
                    ]
                    # FIXME: char-level
                else:
                    trg_tokens = trg_tokens[0]
                    raw_hypothesis = ''.join(
                        trg_tokens) if self.char_output else ' '.join(
                            trg_tokens)

                if remove_unk:
                    trg_tokens = [
                        token for token in trg_tokens if token != utils._UNK
                    ]

                if self.char_output:
                    hypothesis = ''.join(trg_tokens)
                else:
                    hypothesis = ' '.join(trg_tokens).replace('@@ ', '')  # merge subword units

                yield hypothesis, raw_hypothesis
Code Example #7
    def decode_batch(self, sentence_tuples, batch_size, remove_unk=False, fix_edits=True, unk_replace=False,
                     align=False, reverse=False, output=None):
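        # Beam variant: all decoded hypotheses for a sentence are kept, so in the
        # non-edit path `hypothesis` and `raw_hypothesis` are lists of strings.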
        utils.log("start decode batch")
        if batch_size == 1:
            batches = ([sentence_tuple] for sentence_tuple in sentence_tuples)  # lazy
        else:
            batch_count = int(math.ceil(len(sentence_tuples) / batch_size))
            batches = [sentence_tuples[i * batch_size:(i + 1) * batch_size] for i in range(batch_count)]

        def map_to_ids(sentence_tuple):
            token_ids = [
                sentence if vocab is None else
                utils.sentence_to_token_ids(sentence, vocab.vocab, character_level=self.character_level.get(ext))
                for ext, vocab, sentence in zip(self.extensions, self.vocabs, sentence_tuple)
            ]
            return token_ids

        line_id = 0
        for batch_id, batch in enumerate(batches):
            token_ids = list(map(map_to_ids, batch))
            batch_token_ids, batch_weights = self.seq2seq_model.greedy_decoding(token_ids, align=unk_replace or align)
            utils.log("batch_token_ids")
            utils.log(batch_token_ids)
            utils.log(len(batch_token_ids))
            utils.log(len(batch_token_ids[0]))
            utils.log(len(batch_token_ids[0][0]))
            utils.log(len(batch_token_ids[0][0][0]))
            batch_token_ids = zip(*batch_token_ids)

            for sentence_id, (src_tokens, trg_token_ids) in enumerate(zip(batch, batch_token_ids)):
                # trg_token_ids, shape(64,10,50), [[[....50num....],[....50num....],[....50num....],....,[....50num....]]]
                line_id += 1

                trg_tokens = []

                # for single_trg_token_id in trg_token_ids:
                # single_trg_token_id, shape(50), [....50num....]
                for trg_token_ids_, vocab in zip(trg_token_ids, self.trg_vocab):
                    # trg_token_ids_, shape(10,50)
                    top_10_trg_tokens = []
                    for single_trg_token_ids in trg_token_ids_:
                        # single_trg_token_ids, [,,,,,,,] 50 nums
                        single_trg_token_ids = list(single_trg_token_ids)
                        if utils.EOS_ID in single_trg_token_ids:
                            single_trg_token_ids = single_trg_token_ids[:single_trg_token_ids.index(utils.EOS_ID)]
                        single_trg_token_ids = [vocab.reverse[i] if i < len(vocab.reverse) else utils._UNK
                                                for i in single_trg_token_ids]
                        top_10_trg_tokens.append(single_trg_token_ids)

                    # trg_token_ids_ = list(trg_token_ids_)  # from np array to list
                    # if utils.EOS_ID in trg_token_ids_:
                    #     trg_token_ids_ = trg_token_ids_[:trg_token_ids_.index(utils.EOS_ID)]
                    #
                    # trg_tokens_ = [vocab.reverse[i] if i < len(vocab.reverse) else utils._UNK
                    #            for i in trg_token_ids_]
                    # trg_tokens.append(trg_tokens_)
                    trg_tokens.append(top_10_trg_tokens)
                    # trg_tokens, shape(64, 10, ?)
                #   beam_trg_tokens.append(trg_tokens)
                #   trg_tokens = []

                if align:
                    weights_ = batch_weights[sentence_id].squeeze()
                    max_len_ = weights_.shape[1]
                    src_tokens_ = src_tokens[0].split()[:max_len_ - 1] + [utils._EOS]
                    src_tokens_ = [token if token in self.src_vocab[0].vocab else utils._UNK for token in src_tokens_]
                    trg_tokens_ = trg_tokens[0][0][:weights_.shape[0] - 1] + [utils._EOS]

                    weights_ = weights_[:len(trg_tokens_), :len(src_tokens_)]
                    output_file = output and '{}.{}.pdf'.format(output, line_id)
                    utils.heatmap(src_tokens_, trg_tokens_, weights_, reverse=reverse, output_file=output_file)

                if unk_replace:
                    weights = batch_weights[sentence_id]
                    src_words = src_tokens[0].split()
                    align_ids = np.argmax(weights[:, :len(src_words)], axis=1)

                    def replace(token, align_id):
                        if token == utils._UNK:
                            token = src_words[align_id]
                            if not token[0].isupper() and self.lexicon is not None and token in self.lexicon:
                                token = self.lexicon[token]
                        return token

                    for i in range(len(trg_tokens[0])):
                        trg_tokens[0][i] = [replace(token, align_id) for align_id, token in
                                            zip(align_ids, trg_tokens[0][i])]

                #########################################################################
                if self.pred_edits:
                    # first output is ops, second output is words
                    raw_hypothesis = ' '.join('_'.join(tokens) for tokens in zip(*trg_tokens))
                    src_words = src_tokens[0].split()
                    trg_tokens = utils.reverse_edits(src_words, trg_tokens, fix=fix_edits)
                    trg_tokens = [token for token in trg_tokens if token not in utils._START_VOCAB]
                    # FIXME: char-level
                else:
                    trg_tokens = trg_tokens[0]
                    raw_hypothesis = []
                    for single_trg_tokens in trg_tokens:
                        single_raw_hypothesis = ''.join(single_trg_tokens) if self.char_output else ' '.join(
                            single_trg_tokens)
                        raw_hypothesis.append(single_raw_hypothesis)
                    # raw_hypothesis = ''.join(trg_tokens) if self.char_output else ' '.join(trg_tokens)

                if remove_unk:
                    for i in range(len(trg_tokens)):
                        trg_tokens[i] = [token for token in trg_tokens[i] if token != utils._UNK]

                if self.char_output:
                    hypothesis = []
                    for i in range(len(trg_tokens)):
                        hypothesis.append(''.join(trg_tokens[i]))
                    # hypothesis = ''.join(trg_tokens)
                else:
                    hypothesis = []
                    for i in range(len(trg_tokens)):
                        hypothesis.append(' '.join(trg_tokens[i]).replace('@@ ', ''))
                    # hypothesis = ' '.join(trg_tokens).replace('@@ ', '')  # merge subword units

                yield hypothesis, raw_hypothesis
Code Example #8
#!/usr/bin/env python3

import argparse
from translate.utils import reverse_edits

parser = argparse.ArgumentParser()
parser.add_argument('source')
parser.add_argument('edits')
parser.add_argument('--not-strict', action='store_false', dest='strict')
parser.add_argument('--no-fix', action='store_false', dest='fix')

if __name__ == '__main__':
    args = parser.parse_args()
    with open(args.source) as src_file, open(args.edits) as edit_file:
        for source, edits in zip(src_file, edit_file):
            target = reverse_edits(source.strip('\n'),
                                   edits.strip('\n'),
                                   strict=args.strict,
                                   fix=args.fix)
            print(target)