Beispiel #1
0
    def _decode(self, source_lang, target_lang, segments):
        """Translate ``segments`` with the translator and attach word alignments.

        The segments are turned into a batch, the translator's length limit
        (``max_len_b``) is set from the checkpoint for this language pair, and
        the translator is run on ``self._model``. Only the first (top-ranked)
        entry of every n-best list is used.

        Note: this assigns ``self._translator.max_len_b`` as a side effect.

        Args:
            source_lang: source language, forwarded to ``decode_length``.
            target_lang: target language; also passed to the batch builder as
                ``prefix_lang`` when the checkpoint has a multilingual target.
            segments: source segments; ``segments[i]`` is handed to
                ``clean_alignment`` together with the i-th hypothesis.

        Returns:
            list of ``Translation``: one per n-best list returned by the
            translator, carrying the hypothesis string, its alignment and
            ``math.exp`` of the hypothesis score.
        """
        if self._checkpoint.multilingual_target:
            prefix_lang = target_lang
        else:
            prefix_lang = None
        with_prefix = prefix_lang is not None

        batch, input_indexes, sentence_len = self._make_decode_batch(segments, prefix_lang=prefix_lang)

        # Cap the output length for this language pair, then run the search.
        length_limit = self._checkpoint.decode_length(source_lang, target_lang, sentence_len)
        self._translator.max_len_b = length_limit
        nbest_lists = self._translator.generate([self._model], batch)

        dictionary = self._checkpoint.subword_dictionary

        def build_translation(position, nbest):
            best = nbest[0]  # top-1 hypothesis only
            # presumably the score is a log-probability — confirm
            probability = math.exp(best['score'])
            tokens = best['tokens']
            token_ids = dictionary.indexes_of(tokens)
            text = dictionary.string(tokens)
            attention = np.asarray(best['attention'].data.cpu())

            if len(token_ids) == 0:
                # An empty hypothesis has nothing to align.
                alignment = []
            else:
                alignment = make_alignment(input_indexes[position], token_ids, attention,
                                           prefix_lang=with_prefix)
                alignment = clean_alignment(alignment, segments[position], text)

            return Translation(text, alignment=alignment, score=probability)

        return [build_translation(position, nbest) for position, nbest in enumerate(nbest_lists)]
Beispiel #2
0
    def _force_decode(self, target_lang, segments, translations):
        """Align the given translations to their source segments (forced decoding).

        Rather than searching for a translation, the model is run once over
        the supplied ``translations`` and its attention weights are converted
        into word alignments.

        Args:
            target_lang: target language; passed to the batch builder as
                ``prefix_lang`` when the checkpoint has a multilingual target.
            segments: source segments, matched by position with
                ``translations``.
            translations: target strings to align; each one is returned
                unchanged as the text of its ``Translation``.

        Returns:
            list of ``Translation``: one per attention entry produced for the
            batch, holding ``translations[i]`` and its alignment (no score is
            set). Pairs whose target side has no subword indexes get an empty
            alignment instead of calling ``make_alignment``.
        """
        import torch  # local import keeps this block self-contained

        prefix_lang = target_lang if self._checkpoint.multilingual_target else None
        with_prefix = prefix_lang is not None

        batch = self._make_force_decode_batch(segments,
                                              translations,
                                              prefix_lang=prefix_lang)

        src_tokens = batch['src_tokens']
        tgt_tokens = batch['trg_tokens']
        src_indexes = batch['src_indexes']
        tgt_indexes = batch['trg_indexes']
        src_lengths = batch['src_lengths']

        if self._device is not None:
            src_tokens = src_tokens.cuda(self._device)
            src_lengths = src_lengths.cuda(self._device)
            tgt_tokens = tgt_tokens.cuda(self._device)

        self._model.eval()
        # Inference only: the attention is read back as numpy below, so no
        # autograd graph is needed; disabling it avoids retaining activations.
        with torch.no_grad():
            _, attn = self._model(src_tokens, src_lengths, tgt_tokens)
        # The second model output may be a dict carrying the attention under 'attn'.
        if isinstance(attn, dict):
            attn = attn['attn']

        results = []
        for i, hypo_attention in enumerate(attn):  # one entry per batch element
            if len(tgt_indexes[i]) == 0:
                # Nothing to align on the target side (empty translation).
                hypo_alignment = []
            else:
                # After the transpose, axis 0 is sliced by source length and
                # axis 1 by target length.
                hypo_attention = hypo_attention.transpose(0, 1).cpu()
                # Keep only the trailing len(indexes) + 1 positions on each axis;
                # the +1 presumably covers an end-of-sentence token — confirm.
                src_span = len(src_indexes[i]) + 1
                tgt_span = len(tgt_indexes[i]) + 1
                hypo_attention = hypo_attention[hypo_attention.size(0) - src_span:,
                                                hypo_attention.size(1) - tgt_span:]

                hypo_alignment = make_alignment(src_indexes[i],
                                                tgt_indexes[i],
                                                hypo_attention.data.numpy(),
                                                prefix_lang=with_prefix)
                hypo_alignment = clean_alignment(hypo_alignment, segments[i],
                                                 translations[i])

            results.append(
                Translation(translations[i], alignment=hypo_alignment))

        return results