Example #1
    def align(self, output=None, align_encoder_id=0, reverse=False, **kwargs):
        if len(self.filenames.test) != len(self.extensions):
            raise Exception('wrong number of input files')

        binary = self.binary and any(self.binary)

        paths = self.filenames.test or [None]
        line_iter = utils.read_lines(paths, binary=self.binary)

        for line_id, lines in enumerate(line_iter):
            token_ids = [
                sentence if vocab is None else
                utils.sentence_to_token_ids(sentence, vocab.vocab, character_level=self.character_level.get(ext))
                for ext, vocab, sentence in zip(self.extensions, self.vocabs, lines)
            ]

            # forward pass in alignment mode: also returns the attention weights
            _, weights = self.seq2seq_model.step(data=[token_ids], align=True, update_model=False)

            trg_vocab = self.trg_vocab[0]
            trg_token_ids = token_ids[len(self.src_ext)]  # the first target sequence follows the source inputs
            # map target ids back to tokens, falling back to UNK for out-of-vocabulary ids
            trg_tokens = [trg_vocab.reverse[i] if i < len(trg_vocab.reverse) else utils._UNK for i in trg_token_ids]

            weights = weights.squeeze()
            max_len = weights.shape[1]

            if binary:
                src_tokens = None  # binary input (e.g. speech features) has no source tokens to display
            else:
                src_tokens = lines[align_encoder_id].split()[:max_len - 1] + [utils._EOS]
            # truncate both sides to the attention matrix dimensions and append EOS
            trg_tokens = trg_tokens[:weights.shape[0] - 1] + [utils._EOS]

            output_file = output and '{}.{}.pdf'.format(output, line_id + 1)

            utils.heatmap(src_tokens, trg_tokens, weights, output_file=output_file, reverse=reverse)
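
Each align variant ends by handing the cropped attention matrix to utils.heatmap. That helper is not shown on this page; the following is a minimal matplotlib sketch of what such a plot could look like (the name plot_attention, the labels, and the sample data are illustrative assumptions, not the project's actual API):

import matplotlib.pyplot as plt
import numpy as np

def plot_attention(src_tokens, trg_tokens, weights, output_file=None):
    # weights is assumed to have shape (len(trg_tokens), len(src_tokens))
    fig, ax = plt.subplots()
    ax.imshow(weights, cmap='Greys', aspect='auto')
    ax.set_xticks(range(len(src_tokens)))
    ax.set_xticklabels(src_tokens, rotation=90)
    ax.set_yticks(range(len(trg_tokens)))
    ax.set_yticklabels(trg_tokens)
    if output_file is not None:
        fig.savefig(output_file, bbox_inches='tight')  # e.g. a .pdf or .svg path
    else:
        plt.show()

# illustrative call with random weights
plot_attention(['le', 'chat', 'dort', '</s>'], ['the', 'cat', '</s>'], np.random.rand(3, 4))
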
Example #2
    def align(self, sess, output=None, wav_files=None, **kwargs):
        if len(self.src_ext) != 1:
            raise NotImplementedError

        if len(self.filenames.test) != len(self.extensions):
            raise Exception('wrong number of input files')

        for line_id, lines in enumerate(
                utils.read_lines(self.filenames.test, self.extensions,
                                 self.binary_input)):
            token_ids = [
                utils.sentence_to_token_ids(sentence, vocab.vocab, character_level=char_level)
                if vocab is not None else sentence
                for vocab, sentence, char_level in zip(self.vocabs, lines, self.character_level)
            ]

            _, weights = self.seq2seq_model.step(sess,
                                                 data=[token_ids],
                                                 forward_only=True,
                                                 align=True,
                                                 update_model=False)
            trg_tokens = [
                self.trg_vocab.reverse[i]
                if i < len(self.trg_vocab.reverse) else utils._UNK
                for i in token_ids[-1]
            ]

            # crop the attention matrix to the actual sequence lengths and transpose to (src, trg)
            weights = weights.squeeze()[:len(trg_tokens), :len(token_ids[0])].T
            max_len = weights.shape[0]

            if self.binary_input[0]:
                src_tokens = None
            else:
                src_tokens = lines[0].split()[:max_len]

            if wav_files is not None:
                wav_file = wav_files[line_id]
            else:
                wav_file = None

            output_file = '{}.{}.svg'.format(output, line_id + 1) if output is not None else None
            utils.heatmap(src_tokens,
                          trg_tokens,
                          weights.T,
                          wav_file=wav_file,
                          output_file=output_file)
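
Both align variants rebuild target tokens from ids via a reverse vocabulary lookup with an UNK fallback. A small standalone sketch of that pattern (the Vocab namedtuple and the _UNK marker are illustrative stand-ins for the project's own vocabulary objects):

from collections import namedtuple

Vocab = namedtuple('Vocab', ['vocab', 'reverse'])  # forward map and id -> token list
_UNK = '<UNK>'

def ids_to_tokens(token_ids, vocab):
    # ids outside the known vocabulary fall back to the UNK symbol
    return [vocab.reverse[i] if i < len(vocab.reverse) else _UNK for i in token_ids]

words = ['the', 'cat', 'sat']
vocab = Vocab(vocab={w: i for i, w in enumerate(words)}, reverse=words)
print(ids_to_tokens([0, 2, 7], vocab))  # ['the', 'sat', '<UNK>']
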
Example #3
    def align(self, sess, output=None, align_encoder_id=0, **kwargs):
        if self.binary and any(self.binary):
            raise NotImplementedError

        if len(self.filenames.test) != len(self.extensions):
            raise Exception('wrong number of input files')

        for line_id, lines in enumerate(utils.read_lines(self.filenames.test)):
            token_ids = [
                sentence if vocab is None else
                utils.sentence_to_token_ids(sentence, vocab.vocab, character_level=self.character_level.get(ext))
                for ext, vocab, sentence in zip(self.extensions, self.vocabs, lines)
            ]

            _, weights = self.seq2seq_model.step(sess,
                                                 data=[token_ids],
                                                 forward_only=True,
                                                 align=True,
                                                 update_model=False)

            trg_vocab = self.trg_vocab[0]  # FIXME
            trg_token_ids = token_ids[len(self.src_ext)]
            trg_tokens = [
                trg_vocab.reverse[i]
                if i < len(trg_vocab.reverse) else utils._UNK
                for i in trg_token_ids
            ]

            weights = weights.squeeze()
            max_len = weights.shape[1]

            utils.debug(weights)

            trg_tokens.append(utils._EOS)
            src_tokens = lines[align_encoder_id].split()[:max_len - 1] + [utils._EOS]

            output_file = '{}.{}.svg'.format(output, line_id + 1) if output is not None else None

            utils.heatmap(src_tokens,
                          trg_tokens,
                          weights,
                          output_file=output_file)
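
The token_ids comprehension above converts each input line with utils.sentence_to_token_ids unless the input is already numeric (vocab is None). A rough sketch of what such a helper might do, assuming whitespace tokenization and a hypothetical UNK_ID placeholder:

UNK_ID = 3  # illustrative placeholder id

def sentence_to_token_ids(sentence, vocab, character_level=False):
    # split into characters or whitespace tokens, then map to ids with an UNK fallback
    tokens = list(sentence.strip()) if character_level else sentence.split()
    return [vocab.get(token, UNK_ID) for token in tokens]

vocab = {'hello': 4, 'world': 5}
print(sentence_to_token_ids('hello there world', vocab))  # [4, 3, 5]
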
Example #4
    def decode_batch(self,
                     sentence_tuples,
                     batch_size,
                     remove_unk=False,
                     fix_edits=True,
                     unk_replace=False,
                     align=False,
                     reverse=False,
                     output=None):
        if batch_size == 1:
            batches = ([sentence_tuple] for sentence_tuple in sentence_tuples)  # lazy
        else:
            batch_count = int(math.ceil(len(sentence_tuples) / batch_size))
            batches = [
                sentence_tuples[i * batch_size:(i + 1) * batch_size]
                for i in range(batch_count)
            ]

        def map_to_ids(sentence_tuple):
            token_ids = [
                sentence if vocab is None else utils.sentence_to_token_ids(
                    sentence,
                    vocab.vocab,
                    character_level=self.character_level.get(ext))
                for ext, vocab, sentence in zip(self.extensions, self.vocabs,
                                                sentence_tuple)
            ]
            return token_ids

        line_id = 0
        for batch_id, batch in enumerate(batches):
            token_ids = list(map(map_to_ids, batch))
            batch_token_ids, batch_weights = self.seq2seq_model.greedy_decoding(
                token_ids,
                beam_size=self.beam_size,
                align=unk_replace or align or self.debug)
            batch_token_ids = zip(*batch_token_ids)

            for sentence_id, (src_tokens, trg_token_ids) in enumerate(
                    zip(batch, batch_token_ids)):
                line_id += 1
                trg_tokens = []

                for trg_token_ids_, vocab in zip(trg_token_ids,
                                                 self.trg_vocab):
                    trg_token_ids_ = list(trg_token_ids_)  # from np array to list
                    if utils.EOS_ID in trg_token_ids_:
                        # cut the hypothesis at the first end-of-sentence token
                        trg_token_ids_ = trg_token_ids_[:trg_token_ids_.index(utils.EOS_ID)]

                    trg_tokens_ = [
                        vocab.reverse[i]
                        if i < len(vocab.reverse) else utils._UNK
                        for i in trg_token_ids_
                    ]
                    trg_tokens.append(trg_tokens_)

                if align:
                    weights_ = batch_weights[sentence_id].squeeze()
                    max_len_ = weights_.shape[1]

                    if self.binary[0]:
                        src_tokens_ = None
                    else:
                        src_tokens_ = src_tokens[0].split()[:max_len_ - 1] + [utils._EOS]
                        src_tokens_ = [
                            token
                            if token in self.src_vocab[0].vocab else utils._UNK
                            for token in src_tokens_
                        ]
                        weights_ = weights_[:, :len(src_tokens_)]

                    trg_tokens_ = trg_tokens[0][:weights_.shape[0] - 1] + [utils._EOS]
                    weights_ = weights_[:len(trg_tokens_)]
                    output_file = output and '{}.{}.pdf'.format(
                        output, line_id)
                    utils.heatmap(src_tokens_,
                                  trg_tokens_,
                                  weights_,
                                  reverse=reverse,
                                  output_file=output_file)

                if self.debug or unk_replace:
                    weights = batch_weights[sentence_id]
                    src_words = src_tokens[0].split()
                    align_ids = np.argmax(weights[:, :len(src_words)], axis=1)
                else:
                    align_ids = [0] * len(trg_tokens[0])

                def replace(token, align_id):
                    if self.debug and (not unk_replace or token == utils._UNK):
                        suffix = '({})'.format(align_id)
                    else:
                        suffix = ''
                    if token == utils._UNK and unk_replace:
                        token = src_words[align_id]
                        if not token[0].isupper() and self.lexicon is not None and token in self.lexicon:
                            token = self.lexicon[token]
                    return token + suffix

                trg_tokens[0] = [
                    replace(token, align_id)
                    for align_id, token in zip(align_ids, trg_tokens[0])
                ]

                if self.pred_edits:
                    # first output is ops, second output is words
                    raw_hypothesis = ' '.join('_'.join(tokens)
                                              for tokens in zip(*trg_tokens))
                    src_words = src_tokens[0].split()
                    trg_tokens = utils.reverse_edits(src_words,
                                                     trg_tokens,
                                                     fix=fix_edits)
                    trg_tokens = [
                        token for token in trg_tokens
                        if token not in utils._START_VOCAB
                    ]
                    # FIXME: char-level
                else:
                    trg_tokens = trg_tokens[0]
                    raw_hypothesis = ''.join(trg_tokens) if self.char_output else ' '.join(trg_tokens)

                if remove_unk:
                    trg_tokens = [
                        token for token in trg_tokens if token != utils._UNK
                    ]

                if self.char_output:
                    hypothesis = ''.join(trg_tokens)
                else:
                    hypothesis = ' '.join(trg_tokens).replace('@@ ', '')  # merge subword units

                yield hypothesis, raw_hypothesis
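
decode_batch first splits the sentences into ceil(len / batch_size) batches and, at the end, merges BPE subword units by stripping the '@@ ' continuation markers. Both steps in a standalone sketch (the function names are illustrative):

import math

def make_batches(items, batch_size):
    # same ceil-division batching as decode_batch
    batch_count = int(math.ceil(len(items) / batch_size))
    return [items[i * batch_size:(i + 1) * batch_size] for i in range(batch_count)]

def merge_subwords(tokens):
    # join tokens and remove the BPE continuation marker '@@ '
    return ' '.join(tokens).replace('@@ ', '')

print(make_batches(list(range(7)), 3))                       # [[0, 1, 2], [3, 4, 5], [6]]
print(merge_subwords(['un@@', 'believ@@', 'able', 'news']))  # unbelievable news
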
Example #5
    def decode_batch(self, sentence_tuples, batch_size, remove_unk=False, fix_edits=True, unk_replace=False,
                     align=False, reverse=False, output=None):
        utils.log("start decode batch")
        if batch_size == 1:
            batches = ([sentence_tuple] for sentence_tuple in sentence_tuples)  # lazy
        else:
            batch_count = int(math.ceil(len(sentence_tuples) / batch_size))
            batches = [sentence_tuples[i * batch_size:(i + 1) * batch_size] for i in range(batch_count)]

        def map_to_ids(sentence_tuple):
            token_ids = [
                sentence if vocab is None else
                utils.sentence_to_token_ids(sentence, vocab.vocab, character_level=self.character_level.get(ext))
                for ext, vocab, sentence in zip(self.extensions, self.vocabs, sentence_tuple)
            ]
            return token_ids

        line_id = 0
        for batch_id, batch in enumerate(batches):
            token_ids = list(map(map_to_ids, batch))
            batch_token_ids, batch_weights = self.seq2seq_model.greedy_decoding(token_ids, align=unk_replace or align)
            # debug: log the nesting of the decoder output (decoders x batch x beam x length)
            utils.log('batch_token_ids: {} x {} x {} x {}'.format(
                len(batch_token_ids), len(batch_token_ids[0]),
                len(batch_token_ids[0][0]), len(batch_token_ids[0][0][0])))
            batch_token_ids = zip(*batch_token_ids)

            for sentence_id, (src_tokens, trg_token_ids) in enumerate(zip(batch, batch_token_ids)):
                # trg_token_ids: beam hypotheses for this sentence, e.g. beam_size x max_len (10 x 50) per decoder
                line_id += 1

                trg_tokens = []

                for trg_token_ids_, vocab in zip(trg_token_ids, self.trg_vocab):
                    # trg_token_ids_: beam hypotheses for this decoder, e.g. shape (beam_size, max_len)
                    top_10_trg_tokens = []
                    for single_trg_token_ids in trg_token_ids_:
                        # single_trg_token_ids: one beam hypothesis (a list of token ids)
                        single_trg_token_ids = list(single_trg_token_ids)
                        if utils.EOS_ID in single_trg_token_ids:
                            single_trg_token_ids = single_trg_token_ids[:single_trg_token_ids.index(utils.EOS_ID)]
                        single_trg_token_ids = [vocab.reverse[i] if i < len(vocab.reverse) else utils._UNK
                                                for i in single_trg_token_ids]
                        top_10_trg_tokens.append(single_trg_token_ids)

                    # (the single-best lookup used in Example #4 is replaced by the per-beam loop above)
                    trg_tokens.append(top_10_trg_tokens)
                    # trg_tokens: one list of beam hypotheses per decoder

                if align:
                    weights_ = batch_weights[sentence_id].squeeze()
                    max_len_ = weights_.shape[1]
                    src_tokens_ = src_tokens[0].split()[:max_len_ - 1] + [utils._EOS]
                    src_tokens_ = [token if token in self.src_vocab[0].vocab else utils._UNK for token in src_tokens_]
                    trg_tokens_ = trg_tokens[0][0][:weights_.shape[0] - 1] + [utils._EOS]

                    weights_ = weights_[:len(trg_tokens_), :len(src_tokens_)]
                    output_file = output and '{}.{}.pdf'.format(output, line_id)
                    utils.heatmap(src_tokens_, trg_tokens_, weights_, reverse=reverse, output_file=output_file)

                if unk_replace:
                    weights = batch_weights[sentence_id]
                    src_words = src_tokens[0].split()
                    align_ids = np.argmax(weights[:, :len(src_words)], axis=1)

                    def replace(token, align_id):
                        if token == utils._UNK:
                            token = src_words[align_id]
                            if not token[0].isupper() and self.lexicon is not None and token in self.lexicon:
                                token = self.lexicon[token]
                        return token

                    for i in range(len(trg_tokens[0])):
                        trg_tokens[0][i] = [replace(token, align_id) for align_id, token in
                                            zip(align_ids, trg_tokens[0][i])]

                #########################################################################
                if self.pred_edits:
                    # first output is ops, second output is words
                    raw_hypothesis = ' '.join('_'.join(tokens) for tokens in zip(*trg_tokens))
                    src_words = src_tokens[0].split()
                    trg_tokens = utils.reverse_edits(src_words, trg_tokens, fix=fix_edits)
                    trg_tokens = [token for token in trg_tokens if token not in utils._START_VOCAB]
                    # FIXME: char-level
                else:
                    trg_tokens = trg_tokens[0]
                    raw_hypothesis = []
                    for single_trg_tokens in trg_tokens:
                        single_raw_hypothesis = ''.join(single_trg_tokens) if self.char_output else ' '.join(
                            single_trg_tokens)
                        raw_hypothesis.append(single_raw_hypothesis)
                    # raw_hypothesis = ''.join(trg_tokens) if self.char_output else ' '.join(trg_tokens)

                if remove_unk:
                    for i in range(len(trg_tokens)):
                        trg_tokens[i] = [token for token in trg_tokens[i] if token != utils._UNK]

                if self.char_output:
                    hypothesis = []
                    for i in range(len(trg_tokens)):
                        hypothesis.append(''.join(trg_tokens[i]))
                    # hypothesis = ''.join(trg_tokens)
                else:
                    hypothesis = []
                    for i in range(len(trg_tokens)):
                        hypothesis.append(' '.join(trg_tokens[i]).replace('@@ ', ''))
                    # hypothesis = ' '.join(trg_tokens).replace('@@ ', '')  # merge subwords units

                yield hypothesis, raw_hypothesis
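
The beam handling above cuts every candidate at the first EOS_ID before mapping ids back to tokens. A minimal standalone version of that truncation step (the EOS_ID value here is an illustrative assumption; the real constant lives in utils):

EOS_ID = 2  # illustrative value

def truncate_at_eos(hypotheses):
    # drop everything from the first end-of-sentence id onwards, per hypothesis
    truncated = []
    for ids in hypotheses:
        ids = list(ids)
        if EOS_ID in ids:
            ids = ids[:ids.index(EOS_ID)]
        truncated.append(ids)
    return truncated

print(truncate_at_eos([[5, 7, 2, 0, 0], [4, 2, 9]]))  # [[5, 7], [4]]
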