Example #1
class OnlineTranslator(object):
    def __init__(self, model):
        opt = TranslatorParameter(model)
        from onmt.inference.fast_translator import FastTranslator
        self.translator = FastTranslator(opt)
        # alternative: decode with an ensemble of checkpoints
        # self.translator = onmt.EnsembleTranslator(opt)

    def translate(self, text):
        pred_batch, pred_score, pred_length, gold_score, num_gold_words, all_gold_scores = \
            self.translator.translate([text.split()], [])

        # pred_batch[sentence][hypothesis] is a token list; return the
        # 1-best hypothesis of the single input sentence
        return " ".join(pred_batch[0][0])
Example #2
def main():
    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # return one hypothesis per beam entry
    opt.n_best = opt.beam_size

    if opt.output == "stdout":
        outF = sys.stdout
    else:
        outF = open(opt.output, 'w')

    pred_score_total, pred_words_total, gold_score_total, gold_words_total = 0, 0, 0, 0

    src_batches = []
    src_batch, tgt_batch = [], []

    count = 0

    # open the optional reference file; gold scores are reported when present
    tgtF = open(opt.tgt) if opt.tgt else None
    in_file = None

    if opt.src == "stdin":
        in_file = sys.stdin
        opt.batch_size = 1
    elif opt.encoder_type == "audio" and opt.asr_format == "h5":
        in_file = h5.File(opt.src, 'r')
    elif opt.encoder_type == "audio" and opt.asr_format == "scp":
        from kaldiio import ReadHelper
        audio_data = iter(ReadHelper('scp:' + opt.src))
    else:
        in_file = open(opt.src)

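    # Choose the translator implementation: streaming mode keeps decoder state
    # across consecutive inputs, while opt.fast_translate selects the faster
    # FastTranslator decoding path instead of the default onmt.Translator.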
    if opt.streaming:
        if opt.batch_size != 1:
            opt.batch_size = 1
            print("Warning: Streaming only works with batch size 1")

        if opt.global_search:
            print(" Using global search algorithm ")
            from onmt.inference.global_translator import GlobalStreamTranslator
            translator = GlobalStreamTranslator(opt)
        else:
            translator = StreamTranslator(opt)
    else:
        if opt.fast_translate:
            translator = FastTranslator(opt)
        else:
            translator = onmt.Translator(opt)

    # Audio processing for the source batch
    if opt.encoder_type == "audio":

        s_prev_context = []
        t_prev_context = []

        i = 0  # utterance index into the h5 input

        # one frame-concatenation factor per model; a single value is broadcast
        concats = opt.concat.split("|")

        n_models = len(opt.model.split("|"))
        if len(concats) == 1:
            concats = concats * n_models

        assert len(concats) == n_models, \
            "The number of models must match the number of concat configs"

        # one source batch (a list of feature tensors) per model
        for j in range(n_models):
            src_batches.append([])

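        # read features one utterance at a time: h5 inputs are indexed by the
        # string keys "0", "1", ...; scp inputs yield (utterance_id, matrix)
        # pairs through kaldiio's ReadHelper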
        while True:
            if opt.asr_format == "h5":
                if i == len(in_file):
                    break
                line = np.array(in_file[str(i)])
                i += 1
            elif opt.asr_format == "scp":
                try:
                    _, line = next(audio_data)
                except StopIteration:
                    break

            if opt.stride != 1:
                # subsample frames, keeping every stride-th feature vector
                line = line[0::opt.stride]
            line = torch.from_numpy(line)

            original_line = line

            for j, concat_ in enumerate(concats):
                concat = int(concat_)
                line = original_line

                if concat != 1:
                    # zero-pad so the frame count is a multiple of `concat`,
                    # then stack every `concat` consecutive frames into one
                    # wider feature vector
                    add = (concat - line.size(0) % concat) % concat
                    z = torch.zeros(add, line.size(1))
                    line = torch.cat((line, z), 0)
                    line = line.reshape(line.size(0) // concat,
                                        line.size(1) * concat)

                if opt.previous_context > 0:
                    s_prev_context.append(line)
                    # prepend up to `previous_context` past utterances, each
                    # separated by an all-zero frame; use `k` here because `i`
                    # indexes the h5 input above
                    for k in range(1, opt.previous_context + 1):
                        if k < len(s_prev_context):
                            line = torch.cat((s_prev_context[-k - 1],
                                              torch.zeros(1, line.size(1)),
                                              line))
                    if len(s_prev_context) > opt.previous_context:
                        s_prev_context = s_prev_context[-opt.previous_context:]

                src_batches[j] += [line]

            if tgtF:
                tline = tgtF.readline().strip()
                if opt.previous_context > 0:
                    t_prev_context.append(tline)
                    for k in range(1, opt.previous_context + 1):
                        if k < len(t_prev_context):
                            tline = t_prev_context[-k - 1] + " # " + tline
                    if len(t_prev_context) > opt.previous_context:
                        t_prev_context = t_prev_context[-opt.previous_context:]

                if opt.input_type == 'word':
                    tgt_tokens = tline.split()
                elif opt.input_type == 'char':
                    tgt_tokens = list(tline)
                else:
                    raise NotImplementedError("Input type unknown")

                tgt_batch += [tgt_tokens]

            if len(src_batches[0]) < opt.batch_size:
                continue

            # TODO: if opt.concat is a list
            print("Batch size:", len(src_batches[0]), len(tgt_batch))
            pred_batch, pred_score, pred_length, gold_score, num_gold_words, all_gold_scores = translator.translate(
                src_batches, tgt_batch, type='asr')

            print("Result:", len(pred_batch))
            count, pred_score, pred_words, gold_score, gold_words = \
                translate_batch(opt, tgtF, count, outF, translator,
                                src_batches[0], tgt_batch, pred_batch,
                                pred_score, pred_length, gold_score,
                                num_gold_words, all_gold_scores,
                                opt.input_type)

            pred_score_total += pred_score
            pred_words_total += pred_words
            gold_score_total += gold_score
            gold_words_total += gold_words
            src_batch, tgt_batch = [], []
            for j, _ in enumerate(src_batches):
                src_batches[j] = []

        # catch the last batch
        if len(src_batches[0]) != 0:
            print("Batch size:", len(src_batches[0]), len(tgt_batch))
            pred_batch, pred_score, pred_length, gold_score, num_gold_words, all_gold_scores = translator.translate(
                src_batches, tgt_batch, type='asr')
            print("Result:", len(pred_batch))
            count, pred_score, pred_words, gold_score, gold_words = \
                translate_batch(opt, tgtF, count, outF, translator,
                                src_batches[0], tgt_batch, pred_batch,
                                pred_score, pred_length, gold_score,
                                num_gold_words, all_gold_scores,
                                opt.input_type)

            pred_score_total += pred_score
            pred_words_total += pred_words
            gold_score_total += gold_score
            gold_words_total += gold_words
            src_batch, tgt_batch = [], []
            for j, _ in enumerate(src_batches):
                src_batches[j] = []
    # Text processing
    else:
        for line in addone(in_file):
            if line is not None:
                if opt.input_type == 'word':
                    src_tokens = line.split()
                elif opt.input_type == 'char':
                    src_tokens = list(line.strip())
                else:
                    raise NotImplementedError("Input type unknown")
                if line.strip() == "":
                    if opt.streaming:
                        print("Found a document break")
                        translator.reset_stream()
                        continue

                src_batch += [src_tokens]
                if tgtF:
                    if opt.input_type == 'word':
                        tgt_tokens = tgtF.readline().split()
                    elif opt.input_type == 'char':
                        tgt_tokens = list(tgtF.readline().strip())
                    else:
                        raise NotImplementedError("Input type unknown")
                    tgt_batch += [tgt_tokens]

                if len(src_batch) < opt.batch_size:
                    continue
            else:
                # end of file: flush whatever remains in the last batch
                if len(src_batch) == 0:
                    break

            # run beam search on the current batch
            pred_batch, pred_score, pred_length, gold_score, num_gold_words, all_gold_scores = translator.translate(
                src_batch, tgt_batch)

            # convert output tensor to words
            count, pred_score, pred_words, gold_score, gold_words = translate_batch(
                opt, tgtF, count, outF, translator, src_batch, tgt_batch,
                pred_batch, pred_score, pred_length, gold_score,
                num_gold_words, all_gold_scores, opt.input_type)
            pred_score_total += pred_score
            pred_words_total += pred_words
            gold_score_total += gold_score
            gold_words_total += gold_words
            src_batch, tgt_batch = [], []

    if opt.verbose:
        reportScore('PRED', pred_score_total, pred_words_total)
        if tgtF:
            reportScore('GOLD', gold_score_total, gold_words_total)

    if tgtF:
        tgtF.close()

    if opt.dump_beam:
        import json  # imported locally, mirroring the kaldiio import above
        with open(opt.dump_beam, 'w') as beam_file:
            json.dump(translator.beam_accum, beam_file)
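
To make the frame-stacking transform in the audio branch concrete, here is an isolated sketch of the zero-pad-and-reshape step; the shapes are invented for illustration:

import torch

frames = torch.randn(7, 40)   # 7 feature vectors of dimension 40
concat = 4

# pad with zero frames so the length divides evenly by `concat` (7 -> 8)
add = (concat - frames.size(0) % concat) % concat
frames = torch.cat((frames, torch.zeros(add, frames.size(1))), 0)

# every `concat` consecutive frames become one wider vector
stacked = frames.reshape(frames.size(0) // concat, frames.size(1) * concat)
print(stacked.shape)          # torch.Size([2, 160])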