class OnlineTranslator(object):
    """Thin convenience wrapper exposing a one-sentence translate() API.

    Wraps a FastTranslator built from a single model checkpoint path.
    (A duplicated, byte-identical ``__init__`` definition was removed —
    the second def simply re-bound the same name.)
    """

    def __init__(self, model):
        # Build translator options from the model checkpoint path.
        opt = TranslatorParameter(model)
        # Imported lazily so importing this module does not pull in the
        # fast-translator dependencies unless a translator is constructed.
        from onmt.inference.fast_translator import FastTranslator
        self.translator = FastTranslator(opt)
        # self.translator = onmt.EnsembleTranslator(opt)

    def translate(self, input):
        """Translate one whitespace-tokenized sentence.

        :param input: source sentence as a single whitespace-separated string
        :return: best hypothesis, re-joined with single spaces
        """
        pred_batch, pred_score, pred_length, gold_score, num_gold_words, all_gold_scores = \
            self.translator.translate([input.split()], [])
        # predBatch[0] holds the n-best list for the only sentence; take the 1-best.
        return " ".join(pred_batch[0][0])
def main():
    """CLI entry point: batch-translate text or audio input and report scores.

    Reads options from the module-level ``parser``, builds the appropriate
    translator (streaming / fast / default), feeds batches through
    ``translator.translate`` and ``translate_batch``, and accumulates
    prediction / gold score totals. Output goes to ``opt.output`` or stdout.

    BUGFIX: the h5 utterance index (previously ``i``) was clobbered by the
    inner ``for i in range(1, opt.previous_context + 1)`` context loops,
    corrupting the read position whenever ``asr_format == 'h5'`` and
    ``previous_context > 0``. Dedicated variables are now used.
    """
    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # Always pick n_best
    opt.n_best = opt.beam_size

    if opt.output == "stdout":
        outF = sys.stdout
    else:
        outF = open(opt.output, 'w')

    pred_score_total, pred_words_total, gold_score_total, gold_words_total = 0, 0, 0, 0

    src_batches = []               # audio mode: one list of feature tensors per model
    src_batch, tgt_batch = [], []

    count = 0                      # running sentence counter threaded through translate_batch

    tgtF = open(opt.tgt) if opt.tgt else None

    in_file = None

    if opt.src == "stdin":
        in_file = sys.stdin
        opt.batch_size = 1
    elif opt.encoder_type == "audio" and opt.asr_format == "h5":
        in_file = h5.File(opt.src, 'r')
    elif opt.encoder_type == "audio" and opt.asr_format == "scp":
        import kaldiio
        from kaldiio import ReadHelper
        audio_data = iter(ReadHelper('scp:' + opt.src))
    else:
        in_file = open(opt.src)

    if opt.streaming:
        if opt.batch_size != 1:
            opt.batch_size = 1
            print("Warning: Streaming only works with batch size 1")
        if opt.global_search:
            print(" Using global search algorithm ")
            from onmt.inference.global_translator import GlobalStreamTranslator
            translator = GlobalStreamTranslator(opt)
        else:
            translator = StreamTranslator(opt)
    else:
        if opt.fast_translate:
            translator = FastTranslator(opt)
        else:
            translator = onmt.Translator(opt)

    # Audio processing for the source batch
    if opt.encoder_type == "audio":
        s_prev_context = []   # previously seen source feature tensors (context window)
        t_prev_context = []   # previously seen target lines (context window)

        utt_index = 0         # h5 utterance index (fixed: no longer shadowed)

        concats = opt.concat.split("|")
        n_models = len(opt.model.split("|"))
        if len(concats) == 1:
            concats = concats * n_models
        assert len(concats) == n_models, \
            "The number of models must match the number of concat configs"
        for j, _ in enumerate(concats):
            src_batches.append(list())

        while True:
            # Fetch the next utterance's feature matrix.
            if opt.asr_format == "h5":
                if utt_index == len(in_file):
                    break
                line = np.array(in_file[str(utt_index)])
                utt_index += 1
            elif opt.asr_format == "scp":
                try:
                    _, line = next(audio_data)
                except StopIteration:
                    break

            if opt.stride != 1:
                # Subsample frames: keep every stride-th row.
                line = line[0::opt.stride]
            line = torch.from_numpy(line)
            original_line = line

            for j, concat_ in enumerate(concats):
                concat = int(concat_)
                line = original_line
                if concat != 1:
                    # Zero-pad so the frame count divides evenly, then fold
                    # `concat` consecutive frames into one wider frame.
                    add = (concat - line.size()[0] % concat) % concat
                    z = torch.FloatTensor(add, line.size()[1]).zero_()
                    line = torch.cat((line, z), 0)
                    line = line.reshape(
                        (line.size()[0] // concat, line.size()[1] * concat))

                if opt.previous_context > 0:
                    s_prev_context.append(line)
                    # BUGFIX: dedicated loop variable (was `i`, which clobbered
                    # the h5 utterance index above).
                    for ctx in range(1, opt.previous_context + 1):
                        if ctx < len(s_prev_context):
                            # Prepend previous context, separated by a zero frame.
                            line = torch.cat((torch.cat((
                                s_prev_context[-ctx - 1],
                                torch.zeros(1, line.size()[1]))), line))
                    if len(s_prev_context) > opt.previous_context:
                        s_prev_context = s_prev_context[-1 * opt.previous_context:]

                src_batches[j] += [line]

            if tgtF:
                tline = tgtF.readline().strip()
                if opt.previous_context > 0:
                    t_prev_context.append(tline)
                    # BUGFIX: dedicated loop variable (was `i`, see above).
                    # NOTE(review): the bound checks s_prev_context rather than
                    # t_prev_context — looks like a copy-paste slip, but the
                    # original behavior is preserved here; confirm intent.
                    for ctx in range(1, opt.previous_context + 1):
                        if ctx < len(s_prev_context):
                            tline = t_prev_context[-ctx - 1] + " # " + tline
                    if len(t_prev_context) > opt.previous_context:
                        t_prev_context = t_prev_context[-1 * opt.previous_context:]

                if opt.input_type == 'word':
                    tgt_tokens = tline.split() if tgtF else None
                elif opt.input_type == 'char':
                    tgt_tokens = list(tline.strip()) if tgtF else None
                else:
                    raise NotImplementedError("Input type unknown")
                tgt_batch += [tgt_tokens]

            if len(src_batches[0]) < opt.batch_size:
                continue

            # TODO: if opt.concat is a list
            print("Batch size:", len(src_batches[0]), len(tgt_batch))
            pred_batch, pred_score, pred_length, gold_score, num_gold_words, all_gold_scores = \
                translator.translate(src_batches, tgt_batch, type='asr')
            print("Result:", len(pred_batch))
            count, pred_score, pred_words, gold_score, goldWords = \
                translate_batch(opt, tgtF, count, outF, translator,
                                src_batches[0], tgt_batch,
                                pred_batch, pred_score, pred_length,
                                gold_score, num_gold_words,
                                all_gold_scores, opt.input_type)
            pred_score_total += pred_score
            pred_words_total += pred_words
            gold_score_total += gold_score
            gold_words_total += goldWords
            src_batch, tgt_batch = [], []
            for j, _ in enumerate(src_batches):
                src_batches[j] = []

        # catch the last batch
        if len(src_batches[0]) != 0:
            print("Batch size:", len(src_batches[0]), len(tgt_batch))
            pred_batch, pred_score, pred_length, gold_score, num_gold_words, all_gold_scores = \
                translator.translate(src_batches, tgt_batch, type='asr')
            print("Result:", len(pred_batch))
            count, pred_score, pred_words, gold_score, goldWords = \
                translate_batch(opt, tgtF, count, outF, translator,
                                src_batches[0], tgt_batch,
                                pred_batch, pred_score, pred_length,
                                gold_score, num_gold_words,
                                all_gold_scores, opt.input_type)
            pred_score_total += pred_score
            pred_words_total += pred_words
            gold_score_total += gold_score
            gold_words_total += goldWords
            src_batch, tgt_batch = [], []
            for j, _ in enumerate(src_batches):
                src_batches[j] = []

    # Text processing
    else:
        for line in addone(in_file):
            if line is not None:
                if opt.input_type == 'word':
                    src_tokens = line.split()
                elif opt.input_type == 'char':
                    src_tokens = list(line.strip())
                else:
                    raise NotImplementedError("Input type unknown")
                if line.strip() == "":
                    # Blank line = document break in streaming mode.
                    if opt.streaming:
                        print("Found a document break")
                        translator.reset_stream()
                    continue
                src_batch += [src_tokens]
                if tgtF:
                    if opt.input_type == 'word':
                        tgt_tokens = tgtF.readline().split() if tgtF else None
                    elif opt.input_type == 'char':
                        tgt_tokens = list(
                            tgtF.readline().strip()) if tgtF else None
                    else:
                        raise NotImplementedError("Input type unknown")
                    tgt_batch += [tgt_tokens]

                if len(src_batch) < opt.batch_size:
                    continue
            else:
                # at the end of file, check last batch
                if len(src_batch) == 0:
                    break

            # actually done beam search from the model
            pred_batch, pred_score, pred_length, gold_score, num_gold_words, all_gold_scores = \
                translator.translate(src_batch, tgt_batch)
            # convert output tensor to words
            count, pred_score, pred_words, gold_score, goldWords = translate_batch(
                opt, tgtF, count, outF, translator, src_batch, tgt_batch,
                pred_batch, pred_score, pred_length, gold_score,
                num_gold_words, all_gold_scores, opt.input_type)
            pred_score_total += pred_score
            pred_words_total += pred_words
            gold_score_total += gold_score
            gold_words_total += goldWords
            src_batch, tgt_batch = [], []

    if opt.verbose:
        reportScore('PRED', pred_score_total, pred_words_total)
        if tgtF:
            reportScore('GOLD', gold_score_total, gold_words_total)

    if tgtF:
        tgtF.close()

    if opt.dump_beam:
        json.dump(translator.beam_accum, open(opt.dump_beam, 'w'))