def main():
    """Translate ``opt.src`` with ``onmt.EnsembleTranslator`` and write the output.

    Options are read from the module-level ``parser``. When
    ``opt.encoder_type == "audio"`` the source is opened with ``h5.File`` and
    the utterances stored under the keys ``"0"``, ``"1"``, ... are decoded with
    ``translator.translate_asr``; otherwise the source is read line by line and
    decoded with ``translator.translate``. Per-batch output is delegated to
    ``translateBatch``; aggregate scores are reported when ``opt.verbose`` is
    set, and the beam accumulator is dumped as JSON when ``opt.dump_beam`` is
    set.

    Raises:
        NotImplementedError: if ``opt.input_type`` is neither ``'word'`` nor
            ``'char'`` at a point where tokens have to be produced.
    """
    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # Always pick n_best
    opt.n_best = opt.beam_size

    outF = sys.stdout if opt.output == "stdout" else open(opt.output, 'w')

    predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0
    srcBatch, tgtBatch = [], []
    count = 0

    tgtF = open(opt.tgt) if opt.tgt else None

    translator = onmt.EnsembleTranslator(opt)

    # The accumulator can only be initialised once the translator exists;
    # doing it earlier raised UnboundLocalError whenever dump_beam was set.
    if opt.dump_beam != "":
        translator.initBeamAccum()

    if opt.src == "stdin":
        inFile = sys.stdin
        opt.batch_size = 1
    elif opt.encoder_type == "audio":
        inFile = h5.File(opt.src, 'r')
    else:
        inFile = open(opt.src)

    if opt.encoder_type == "audio":
        s_prev_context = []
        t_prev_context = []

        for utt_idx in range(len(inFile)):
            features = np.array(inFile[str(utt_idx)])
            if opt.stride != 1:
                # keep every opt.stride-th row
                features = features[0::opt.stride]
            line = torch.from_numpy(features)

            if opt.concat != 1:
                # Zero-pad the rows to a multiple of opt.concat, then fold
                # opt.concat consecutive rows into one wider row.
                add = (opt.concat - line.size()[0] % opt.concat) % opt.concat
                z = torch.FloatTensor(add, line.size()[1]).zero_()
                line = torch.cat((line, z), 0)
                line = line.reshape((line.size()[0] // opt.concat,
                                     line.size()[1] * opt.concat))

            if opt.previous_context > 0:
                s_prev_context.append(line)
                # Prepend earlier utterances, each followed by one all-zero row.
                for k in range(1, opt.previous_context + 1):
                    if k < len(s_prev_context):
                        separator = torch.zeros(1, line.size()[1])
                        line = torch.cat(
                            (torch.cat((s_prev_context[-k - 1], separator)), line))
                if len(s_prev_context) > opt.previous_context:
                    s_prev_context = s_prev_context[-1 * opt.previous_context:]
            srcBatch += [line]

            if tgtF:
                tline = tgtF.readline().strip()
                if opt.previous_context > 0:
                    t_prev_context.append(tline)
                    for k in range(1, opt.previous_context + 1):
                        # Bound by the *target* history: the source history
                        # has already been trimmed at this point, so testing
                        # it dropped the oldest target context line.
                        if k < len(t_prev_context):
                            tline = t_prev_context[-k - 1] + " # " + tline
                    if len(t_prev_context) > opt.previous_context:
                        t_prev_context = t_prev_context[-1 * opt.previous_context:]

                if opt.input_type == 'word':
                    tgtTokens = tline.split()
                elif opt.input_type == 'char':
                    tgtTokens = list(tline.strip())
                else:
                    raise NotImplementedError("Input type unknown")

                tgtBatch += [tgtTokens]

            if len(srcBatch) < opt.batch_size:
                continue

            print("Batch size:", len(srcBatch), len(tgtBatch))
            predBatch, predScore, predLength, goldScore, numGoldWords, allGoldScores = translator.translate_asr(
                srcBatch, tgtBatch)
            print("Result:", len(predBatch))
            count, predScore, predWords, goldScore, goldWords = translateBatch(
                opt, tgtF, count, outF, translator, srcBatch, tgtBatch,
                predBatch, predScore, predLength, goldScore, numGoldWords,
                allGoldScores, opt.input_type)
            predScoreTotal += predScore
            predWordsTotal += predWords
            goldScoreTotal += goldScore
            goldWordsTotal += goldWords
            srcBatch, tgtBatch = [], []

        # Flush the last, incomplete batch.
        if len(srcBatch) != 0:
            print("Batch size:", len(srcBatch), len(tgtBatch))
            predBatch, predScore, predLength, goldScore, numGoldWords, allGoldScores = translator.translate_asr(
                srcBatch, tgtBatch)
            print("Result:", len(predBatch))
            count, predScore, predWords, goldScore, goldWords = translateBatch(
                opt, tgtF, count, outF, translator, srcBatch, tgtBatch,
                predBatch, predScore, predLength, goldScore, numGoldWords,
                allGoldScores, opt.input_type)
            predScoreTotal += predScore
            predWordsTotal += predWords
            goldScoreTotal += goldScore
            goldWordsTotal += goldWords
            srcBatch, tgtBatch = [], []
    else:
        # addone() yields the input lines and then a final None, which is
        # handled below as the end-of-file marker.
        for line in addone(inFile):
            if line is not None:
                if opt.input_type == 'word':
                    srcTokens = line.split()
                elif opt.input_type == 'char':
                    srcTokens = list(line.strip())
                else:
                    raise NotImplementedError("Input type unknown")
                srcBatch += [srcTokens]

                if tgtF:
                    if opt.input_type == 'word':
                        tgtTokens = tgtF.readline().split()
                    elif opt.input_type == 'char':
                        tgtTokens = list(tgtF.readline().strip())
                    else:
                        raise NotImplementedError("Input type unknown")
                    tgtBatch += [tgtTokens]

                if len(srcBatch) < opt.batch_size:
                    continue
            else:
                # at the end of file, check last batch
                if len(srcBatch) == 0:
                    break

            # run beam search on the batch
            predBatch, predScore, predLength, goldScore, numGoldWords, allGoldScores = translator.translate(
                srcBatch, tgtBatch)

            # convert output tensor to words
            count, predScore, predWords, goldScore, goldWords = translateBatch(
                opt, tgtF, count, outF, translator, srcBatch, tgtBatch,
                predBatch, predScore, predLength, goldScore, numGoldWords,
                allGoldScores, opt.input_type)
            predScoreTotal += predScore
            predWordsTotal += predWords
            goldScoreTotal += goldScore
            goldWordsTotal += goldWords
            srcBatch, tgtBatch = [], []

    if opt.verbose:
        reportScore('PRED', predScoreTotal, predWordsTotal)
        if tgtF:
            reportScore('GOLD', goldScoreTotal, goldWordsTotal)

    if tgtF:
        tgtF.close()

    if opt.dump_beam:
        import json
        with open(opt.dump_beam, 'w') as beam_file:
            json.dump(translator.beam_accum, beam_file)

    # Release the handles opened above; the process-wide std streams are left alone.
    if inFile is not sys.stdin:
        inFile.close()
    if outF is not sys.stdout:
        outF.close()
def main():
    """Translate ``opt.src`` line by line and write the best hypotheses.

    Options are read from the module-level ``parser``. ``opt.version`` selects
    ``onmt.EnsembleTranslator`` (1.0) or ``onmt.Translator`` (2.0). The
    ``opt.batch_size`` check is disabled in this variant, so ``translate`` is
    called after every input line. Optional post-processing:

    * ``opt.normalize`` (version 1.0 only): rescore each hypothesis with
      ``lenPenalty`` and sort by the new score, best first.
    * ``opt.preferLongestOutputs``: reorder the hypotheses of the first batch
      entry by decreasing token count.

    Raises:
        NotImplementedError: if ``opt.input_type`` is neither ``'word'`` nor
            ``'char'``.
        ValueError: if ``opt.version`` is neither 1.0 nor 2.0.
    """
    opt = parser.parse_args()
    print(opt)
    opt.cuda = opt.gpu > -1
    onmt.Constants.cudaActivated = opt.cuda
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # Always pick n_best
    opt.n_best = opt.beam_size

    outF = sys.stdout if opt.output == "stdout" else open(opt.output, 'w')

    predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0
    srcBatch, tgtBatch = [], []
    count = 0

    tgtF = open(opt.tgt) if opt.tgt else None

    if opt.src == "stdin":
        inFile = sys.stdin
        opt.batch_size = 1
    else:
        inFile = open(opt.src)

    if opt.version == 1.0:
        translator = onmt.EnsembleTranslator(opt)
    elif opt.version == 2.0:
        translator = onmt.Translator(opt)
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError('Unsupported version {0}'.format(opt.version))

    # The accumulator can only be initialised once the translator exists;
    # doing it earlier raised UnboundLocalError whenever dump_beam was set.
    if opt.dump_beam != "":
        translator.initBeamAccum()

    # addone() yields the input lines and then a final None, which is handled
    # below as the end-of-file marker.
    for line in addone(inFile):
        if line is not None:
            if opt.input_type == 'word':
                srcBatch += [line.split()]
                if tgtF:
                    tgtBatch += [tgtF.readline().split()]
            elif opt.input_type == 'char':
                srcBatch += [list(line.strip())]
                if tgtF:
                    tgtBatch += [list(tgtF.readline().strip())]
            else:
                raise NotImplementedError("Input type unknown")
            # NOTE: accumulation up to opt.batch_size is intentionally left
            # disabled here, as in the original; every line is decoded alone.
        else:
            # at the end of file, check last batch
            if len(srcBatch) == 0:
                break

        predBatch, predScore, predLength, goldScore, numGoldWords = translator.translate(srcBatch, tgtBatch)

        if opt.normalize and opt.version == 1.0:
            predBatch_ = []
            predScore_ = []
            for bb, ss, ll in zip(predBatch, predScore, predLength):
                ss_ = [lenPenalty(s_, l_, opt.alpha) for _, s_, l_ in zip(bb, ss, ll)]
                sidx = numpy.argsort(ss_)[::-1]  # best rescored hypothesis first
                predBatch_.append([bb[s] for s in sidx])
                predScore_.append([ss_[s] for s in sidx])
            predBatch = predBatch_
            predScore = predScore_

        if opt.preferLongestOutputs:
            hyps = predBatch[0]
            scores = predScore[0]
            # Stable sort: hypotheses of equal length keep their relative order.
            order = sorted(range(len(hyps)), key=lambda k: len(hyps[k]), reverse=True)
            # Snapshot before writing back. The original aliased the lists
            # instead of copying them, so later slots were filled from entries
            # that had already been overwritten. Tensor elements are cloned
            # because indexing a tensor returns a view of the same storage.
            hyps_before = [hyps[k] for k in range(len(hyps))]
            scores_before = [
                scores[k].clone() if hasattr(scores[k], 'clone') else scores[k]
                for k in range(len(hyps))
            ]
            for new_pos, old_pos in enumerate(order):
                hyps[new_pos] = hyps_before[old_pos]
                scores[new_pos] = scores_before[old_pos]

        predScoreTotal += sum(score[0] for score in predScore)
        predWordsTotal += sum(len(x[0]) for x in predBatch)
        if tgtF is not None:
            goldScoreTotal += sum(goldScore).item()
            goldWordsTotal += numGoldWords

        for b in range(len(predBatch)):
            count += 1

            bestHyp = getSentenceFromTokens(predBatch[b][0], opt.input_type)
            if not opt.print_nbest:
                outF.write(bestHyp + '\n')
                outF.flush()

            if opt.verbose:
                srcSent = getSentenceFromTokens(srcBatch[b], opt.input_type)
                if translator.tgt_dict.lower:
                    srcSent = srcSent.lower()
                print('SENT %d: %s' % (count, srcSent))
                print('PRED %d: %s' % (count, bestHyp))
                print("PRED SCORE: %.4f" % predScore[b][0])

                if tgtF is not None:
                    tgtSent = getSentenceFromTokens(tgtBatch[b], opt.input_type)
                    if translator.tgt_dict.lower:
                        tgtSent = tgtSent.lower()
                    print('GOLD %d: %s ' % (count, tgtSent))
                    print("GOLD SCORE: %.4f" % goldScore[b])

                if opt.print_nbest:
                    print('\nBEST HYP:')
                    for n in range(opt.n_best):
                        sent = getSentenceFromTokens(predBatch[b][n], opt.input_type)
                        print("[%.4f] %s" % (predScore[b][n], sent))

                print('')

        srcBatch, tgtBatch = [], []

    if opt.verbose:
        reportScore('PRED', predScoreTotal, predWordsTotal)
        if tgtF:
            reportScore('GOLD', goldScoreTotal, goldWordsTotal)

    if tgtF:
        tgtF.close()

    if opt.dump_beam:
        import json
        with open(opt.dump_beam, 'w') as beam_file:
            json.dump(translator.beam_accum, beam_file)

    # Release the handles opened above; the process-wide std streams are left alone.
    if inFile is not sys.stdin:
        inFile.close()
    if outF is not sys.stdout:
        outF.close()
def main():
    """Translate ``opt.src`` in batches with ``onmt.EnsembleTranslator``.

    Options are read from the module-level ``parser``. Source lines are split
    on whitespace and collected until ``opt.batch_size`` sentences are
    available (or the input ends), then decoded with ``translator.translate``.
    With ``opt.normalize`` the hypotheses are rescored with ``len_penalty``
    and sorted best first. The best hypothesis of each sentence is written to
    the output unless ``opt.print_nbest`` is set; details are printed when
    ``opt.verbose`` is set.
    """
    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # Always pick n_best
    opt.n_best = opt.beam_size

    outF = sys.stdout if opt.output == "stdout" else open(opt.output, 'w')

    predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0
    srcBatch, tgtBatch = [], []
    count = 0

    tgtF = open(opt.tgt) if opt.tgt else None

    if opt.src == "stdin":
        inFile = sys.stdin
        opt.batch_size = 1
    else:
        inFile = open(opt.src)

    translator = onmt.EnsembleTranslator(opt)

    # The accumulator can only be initialised once the translator exists;
    # doing it earlier raised UnboundLocalError whenever dump_beam was set.
    if opt.dump_beam != "":
        translator.initBeamAccum()

    # addone() yields the input lines and then a final None, which is handled
    # below as the end-of-file marker.
    for line in addone(inFile):
        if line is not None:
            srcBatch += [line.split()]
            if tgtF:
                tgtBatch += [tgtF.readline().split()]

            if len(srcBatch) < opt.batch_size:
                continue
        else:
            # at the end of file, check last batch
            if len(srcBatch) == 0:
                break

        predBatch, predScore, predLength, goldScore, numGoldWords = translator.translate(srcBatch, tgtBatch)

        if opt.normalize:
            predBatch_ = []
            predScore_ = []
            for bb, ss, ll in zip(predBatch, predScore, predLength):
                ss_ = [len_penalty(s_, l_, opt.alpha) for _, s_, l_ in zip(bb, ss, ll)]
                sidx = numpy.argsort(ss_)[::-1]  # best rescored hypothesis first
                predBatch_.append([bb[s] for s in sidx])
                predScore_.append([ss_[s] for s in sidx])
            predBatch = predBatch_
            predScore = predScore_

        predScoreTotal += sum(score[0].item() for score in predScore)
        predWordsTotal += sum(len(x[0]) for x in predBatch)
        if tgtF is not None:
            goldScoreTotal += sum(goldScore).item()
            goldWordsTotal += numGoldWords

        for b in range(len(predBatch)):
            count += 1
            bestHyp = " ".join(predBatch[b][0])

            if not opt.print_nbest:
                outF.write(bestHyp + '\n')
                outF.flush()

            if opt.verbose:
                srcSent = ' '.join(srcBatch[b])
                if translator.tgt_dict.lower:
                    srcSent = srcSent.lower()
                print('SENT %d: %s' % (count, srcSent))
                print('PRED %d: %s' % (count, bestHyp))
                print("PRED SCORE: %.4f" % predScore[b][0])

                if tgtF is not None:
                    tgtSent = ' '.join(tgtBatch[b])
                    if translator.tgt_dict.lower:
                        tgtSent = tgtSent.lower()
                    print('GOLD %d: %s ' % (count, tgtSent))
                    print("GOLD SCORE: %.4f" % goldScore[b])

                if opt.print_nbest:
                    print('\nBEST HYP:')
                    for n in range(opt.n_best):
                        print("[%.4f] %s" % (predScore[b][n],
                                             " ".join(predBatch[b][n])))

                print('')

        srcBatch, tgtBatch = [], []

    if opt.verbose:
        reportScore('PRED', predScoreTotal, predWordsTotal)
        if tgtF:
            reportScore('GOLD', goldScoreTotal, goldWordsTotal)

    if tgtF:
        tgtF.close()

    if opt.dump_beam:
        import json
        with open(opt.dump_beam, 'w') as beam_file:
            json.dump(translator.beam_accum, beam_file)

    # Release the handles opened above; the process-wide std streams are left alone.
    if inFile is not sys.stdin:
        inFile.close()
    if outF is not sys.stdout:
        outF.close()
def main():
    """Translate ``opt.src``, optionally simulating online (prefix) decoding.

    Options are read from the module-level ``parser``.

    * With ``opt.start_prefixing_from`` set, ``onmt.EnsembleTranslatorOnlineSim``
      is used and a pickled mapping ``<src without extension>.num.partial.seqs.pickle``
      is loaded; it is indexed by the utterance id without its trailing
      ``_<n>`` and compared against ``n`` to detect the last (full) partial
      utterance. Hypotheses of earlier partial utterances are fed back as
      forced target prefixes, and per-utterance prefix lengths ("latency") and
      token scores are written to ``opt.output_latency`` /
      ``opt.output_confidence``.
    * Otherwise ``onmt.EnsembleTranslator`` is used and a warning is issued.

    Audio input (``opt.encoder_type == "audio"``) is read from an h5 file or a
    Kaldi scp list depending on ``opt.asr_format``; any other encoder type is
    read as text lines.

    Raises:
        Exception: if the partial-sequence counter file cannot be loaded.
        NotImplementedError: if ``opt.input_type`` is neither ``'word'`` nor
            ``'char'`` at a point where tokens have to be produced.
        ValueError: on an unsupported ``opt.asr_format`` or
            ``opt.confidence_mode``, or when the number of collected
            confidence scores does not match the final hypothesis length.
    """
    def only_punct(tok):
        # True when every character of the token is ASCII punctuation.
        return all(c in string.punctuation for c in tok)

    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # Always pick n_best
    opt.n_best = opt.beam_size

    outF = sys.stdout if opt.output == "stdout" else open(opt.output, 'w')

    if opt.start_prefixing_from:
        counter_path = '.'.join(opt.src.split('.')[:-1]) + '.num.partial.seqs.pickle'
        try:
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # opt.src at data sets whose side files you trust.
            with open(counter_path, 'rb') as f:
                num_partial_seqs = pickle.load(f)
        except Exception as err:
            raise Exception('Failed to open partial sequence counter file.') from err

    # Report files are bound unconditionally (None when no path is configured)
    # so that the writes and the clean-up below can never hit an unbound name.
    out_latency_f = open(opt.output_latency, 'w') if opt.output_latency else None
    out_confidence_f = open(opt.output_confidence, 'w') if opt.output_confidence else None

    predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0
    srcBatch, tgtBatch, tgt_length_batch = [], [], []
    count = 0

    tgtF = open(opt.tgt) if opt.tgt else None
    tgt_lengthF = open(opt.tgt_length) if opt.tgt_length else None

    inFile = None

    if opt.start_prefixing_from:
        translator = onmt.EnsembleTranslatorOnlineSim(opt)
    else:
        translator = onmt.EnsembleTranslator(opt)
        warnings.warn('Not doing online decoding!')

    if opt.dump_beam != "":
        translator.init_beam_accum()

    if opt.src == "stdin":
        inFile = sys.stdin
        opt.batch_size = 1
    elif opt.encoder_type == "audio" and opt.asr_format == "h5":
        inFile = h5.File(opt.src, 'r')
    elif opt.encoder_type == "audio" and opt.asr_format == "scp":
        import kaldiio
        from kaldiio import ReadHelper
        audio_data = iter(ReadHelper('scp:' + opt.src))
    else:
        inFile = open(opt.src)

    if opt.encoder_type == "audio":
        s_prev_context = []
        t_prev_context = []

        # Cursor over the h5 keys "0", "1", ...; it has a dedicated name so
        # the inner loops below cannot overwrite it (they used to reuse `i`).
        h5_pos = 0
        latency = []
        latency_no_punct = []
        partial_hyps = []
        partial_scores = []  # this will hold likelihood of each token
        # Defined up front so the leftover flush after the loop can use it even
        # if no complete batch was ever decoded.
        is_full_utt = True

        while True:  # keep reading new utterances
            if opt.asr_format == "h5":
                if h5_pos == len(inFile):
                    break
                line = np.array(inFile[str(h5_pos)])
                h5_pos += 1
            elif opt.asr_format == "scp":
                try:
                    utt_id, line = next(audio_data)
                except StopIteration:
                    break
            else:
                # Previously this fell through and failed on the unbound `line`.
                raise ValueError('Unsupported asr_format {0}'.format(opt.asr_format))

            if opt.stride != 1:
                line = line[0::opt.stride]

            # NOTE(review): only the first 40 columns are kept; presumably the
            # feature dimension the models expect - confirm.
            line = line[:, :40]
            line = torch.from_numpy(line)

            if opt.concat != 1:
                # Zero-pad the rows to a multiple of opt.concat, then fold
                # opt.concat consecutive rows into one wider row.
                add = (opt.concat - line.size()[0] % opt.concat) % opt.concat
                z = torch.FloatTensor(add, line.size()[1]).zero_()
                line = torch.cat((line, z), 0)
                line = line.reshape((line.size()[0] // opt.concat,
                                     line.size()[1] * opt.concat))

            if opt.previous_context > 0:
                s_prev_context.append(line)
                # Prepend earlier utterances, each followed by one all-zero row.
                for k in range(1, opt.previous_context + 1):
                    if k < len(s_prev_context):
                        separator = torch.zeros(1, line.size()[1])
                        line = torch.cat(
                            (torch.cat((s_prev_context[-k - 1], separator)), line))
                if len(s_prev_context) > opt.previous_context:
                    s_prev_context = s_prev_context[-1 * opt.previous_context:]
            srcBatch += [line]  # make batch

            if tgtF:
                tline = tgtF.readline().strip()
                if opt.previous_context > 0:
                    t_prev_context.append(tline)
                    for k in range(1, opt.previous_context + 1):
                        # Bound by the *target* history: the source history
                        # has already been trimmed at this point, so testing
                        # it dropped the oldest target context line.
                        if k < len(t_prev_context):
                            tline = t_prev_context[-k - 1] + " # " + tline
                    if len(t_prev_context) > opt.previous_context:
                        t_prev_context = t_prev_context[-1 * opt.previous_context:]

                if opt.input_type == 'word':
                    tgtTokens = tline.split()
                elif opt.input_type == 'char':
                    tgtTokens = list(tline.strip())
                else:
                    raise NotImplementedError("Input type unknown")

                tgtBatch += [tgtTokens]

            if tgt_lengthF:
                # NOTE(review): this overwrites rather than appends, so with
                # batch_size > 1 only the last length is passed on - confirm
                # what translate_asr expects.
                tgt_length_batch = [int(tgt_lengthF.readline())]

            if len(srcBatch) < opt.batch_size:
                continue  # fetch next instance

            if opt.start_prefixing_from:  # online decoding
                # NOTE(review): utt_id is only bound when reading scp input.
                partial_utt_idx = int(utt_id.split('_')[-1])
                orig_utt_idx = '_'.join(utt_id.split('_')[:-1])
                is_full_utt = (partial_utt_idx == (num_partial_seqs[orig_utt_idx] - 1))
            else:
                is_full_utt = True  # always write out

            # Build the forced target prefix from earlier partial hypotheses.
            if opt.start_prefixing_from:
                if opt.require_prefix_agreements:
                    # Last index on which the two most recent hypotheses agree.
                    # -1 (empty prefix) when there are not yet two hypotheses;
                    # previously this case reused a stale value or failed.
                    idx_agree = -1
                    if len(partial_hyps) >= 2:
                        idx_agree = 0
                        for idx_agree in range(min(len(partial_hyps[-2]),
                                                   len(partial_hyps[-1]))):
                            if partial_hyps[-2][idx_agree] != partial_hyps[-1][idx_agree]:
                                idx_agree -= 1
                                break

                    if partial_utt_idx >= opt.start_prefixing_from:
                        tok_removed = len(partial_hyps[-1]) - (idx_agree + 1)
                        tgtBatch += [partial_hyps[-1][:idx_agree + 1]]
                        partial_scores = partial_scores[:len(partial_scores) - tok_removed]
                        latency.append(len(tgtBatch[0]))
                        tok_no_punct = [tok for tok in tgtBatch[0] if not only_punct(tok)]
                        latency_no_punct.append(len(tok_no_punct))
                    else:
                        latency.append(0)
                        latency_no_punct.append(0)
                else:  # normal prefix
                    if partial_utt_idx >= opt.start_prefixing_from:
                        if not opt.wait_if_worse or partial_scores[-1] > partial_scores[-2]:
                            tgtBatch += [partial_hyps[-1]]
                        else:
                            print('waiting, {0} < {1}'.format(partial_scores[-1], partial_scores[-2]))
                            # Drop the latest hypothesis and fall back to the
                            # one before it.
                            del partial_scores[-1]
                            del partial_hyps[-1]
                            tgtBatch += [partial_hyps[-1]]
                        latency.append(len(tgtBatch[0]))
                        tok_no_punct = [tok for tok in tgtBatch[0] if not only_punct(tok)]
                        latency_no_punct.append(len(tok_no_punct))
                    else:
                        latency.append(0)
                        latency_no_punct.append(0)

            print("Batch size:", len(srcBatch), len(tgtBatch))
            predBatch, predScore, predLength, goldScore, numGoldWords, allGoldScores, all_lk = translator.translate_asr(
                srcBatch, tgtBatch, tgt_length_batch)
            print("Result:", len(predBatch))
            count, predScore, predWords, goldScore, goldWords, reordered_pred_words, best_conf = translateBatch(
                opt, tgtF, count, outF, translator, srcBatch, tgtBatch,
                predBatch, predScore, predLength, goldScore, numGoldWords,
                allGoldScores, opt.input_type, all_lk, write_out=is_full_utt)

            _partial_hyp = reordered_pred_words[0][0]

            if is_full_utt:  # full seq
                latency.append(len(_partial_hyp))
                latency_no_punct.append(
                    len([tok for tok in _partial_hyp if not only_punct(tok)]))
            else:
                # strip until last token of partial seq is not only punctuation
                while _partial_hyp and only_punct(_partial_hyp[-1]):
                    del _partial_hyp[-1]
                    if best_conf:  # when continuously deleting punctuations
                        del best_conf[-1]
                    else:
                        del partial_scores[-1]

                if opt.max_out_per_segment:
                    allowed_len = latency[-1] + opt.max_out_per_segment
                    _partial_hyp = _partial_hyp[:allowed_len]
                    best_conf = best_conf[:opt.max_out_per_segment]

                if opt.remove_last_n:
                    prev_len = latency[-1]
                    # never cut into the prefix that was already committed
                    actual_removed = min(opt.remove_last_n, len(_partial_hyp) - prev_len)
                    _partial_hyp = _partial_hyp[:(len(_partial_hyp) - actual_removed)]
                    best_conf = best_conf[:(len(best_conf) - actual_removed)]

                if opt.confidence_mode and best_conf:  # new output non empty
                    if opt.confidence_mode == 4:
                        # Keep the new tokens up to and including the last one
                        # whose score exceeds min_confidence; keep none when
                        # no token qualifies.
                        keep = 0
                        for pos in range(len(best_conf) - 1, -1, -1):
                            if best_conf[pos] > opt.min_confidence:
                                print(best_conf[pos], '>', opt.min_confidence, 'breaking at', pos)
                                keep = pos + 1
                                break
                        _partial_hyp = _partial_hyp[:latency[-1] + keep]
                        best_conf = best_conf[:keep]
                    else:
                        raise ValueError('Invalid confidence_mode {0}'.format(opt.confidence_mode))

            partial_hyps.append(_partial_hyp)
            if is_full_utt or partial_utt_idx + 1 >= opt.start_prefixing_from:
                partial_scores.extend(best_conf)
            print('best conf', best_conf)

            predScoreTotal += predScore
            predWordsTotal += predWords
            goldScoreTotal += goldScore
            goldWordsTotal += goldWords
            # The original reset `tgt_length_int` here, which looks like a slip
            # for the batch list.
            srcBatch, tgtBatch, tgt_length_batch = [], [], []

            if is_full_utt:  # don't apply prefix anymore. Clear previous ones
                if latency[-1] != len(partial_scores):
                    raise ValueError('{0} tokens vs {1} confidence scores!!'.format(latency[-1], len(partial_scores)))

                if out_latency_f is not None:
                    out_latency_f.write(','.join([str(x) for x in latency]) + '\n')
                    out_latency_f.write(','.join([str(x) for x in latency_no_punct]) + '\n')
                    out_latency_f.flush()
                if out_confidence_f is not None:
                    out_confidence_f.write(','.join(['{0:.2f}'.format(x) for x in partial_scores]) + '\n')
                    out_confidence_f.flush()

                partial_hyps = []
                partial_scores = []
                latency = []
                latency_no_punct = []

        # after all utterances are done: flush the last, incomplete batch
        if len(srcBatch) != 0:
            print("Batch size:", len(srcBatch), len(tgtBatch))
            predBatch, predScore, predLength, goldScore, numGoldWords, allGoldScores, all_lk = translator.translate_asr(
                srcBatch, tgtBatch)
            print("Result:", len(predBatch))
            count, predScore, predWords, goldScore, goldWords, reordered_pred_words, best_conf = translateBatch(
                opt, tgtF, count, outF, translator, srcBatch, tgtBatch,
                predBatch, predScore, predLength, goldScore, numGoldWords,
                allGoldScores, opt.input_type, all_lk, write_out=is_full_utt)
            predScoreTotal += predScore
            predWordsTotal += predWords
            goldScoreTotal += goldScore
            goldWordsTotal += goldWords
            srcBatch, tgtBatch = [], []
    else:
        # addone() yields the input lines and then a final None, which is
        # handled below as the end-of-file marker.
        for line in addone(inFile):
            if line is not None:
                if opt.input_type == 'word':
                    srcTokens = line.split()
                elif opt.input_type == 'char':
                    srcTokens = list(line.strip())
                else:
                    raise NotImplementedError("Input type unknown")
                srcBatch += [srcTokens]

                if tgtF:
                    if opt.input_type == 'word':
                        tgtTokens = tgtF.readline().split()
                    elif opt.input_type == 'char':
                        tgtTokens = list(tgtF.readline().strip())
                    else:
                        raise NotImplementedError("Input type unknown")
                    tgtBatch += [tgtTokens]

                if len(srcBatch) < opt.batch_size:
                    continue
            else:
                # at the end of file, check last batch
                if len(srcBatch) == 0:
                    break

            # run beam search on the batch
            predBatch, predScore, predLength, goldScore, numGoldWords, allGoldScores, all_lk = translator.translate(
                srcBatch, tgtBatch)

            # convert output tensor to words; the audio branch unpacks two
            # further return values from translateBatch, so tolerate extras
            # here instead of failing on a fixed five-way unpack.
            count, predScore, predWords, goldScore, goldWords, *_ = translateBatch(
                opt, tgtF, count, outF, translator, srcBatch, tgtBatch,
                predBatch, predScore, predLength, goldScore, numGoldWords,
                allGoldScores, opt.input_type, all_lk)
            predScoreTotal += predScore
            predWordsTotal += predWords
            goldScoreTotal += goldScore
            goldWordsTotal += goldWords
            srcBatch, tgtBatch = [], []

    if opt.verbose:
        reportScore('PRED', predScoreTotal, predWordsTotal)
        if tgtF:
            reportScore('GOLD', goldScoreTotal, goldWordsTotal)

    if tgtF:
        tgtF.close()
    if tgt_lengthF:
        tgt_lengthF.close()
    if out_latency_f is not None:
        out_latency_f.close()
    if out_confidence_f is not None:
        out_confidence_f.close()

    if opt.dump_beam:
        import json
        with open(opt.dump_beam, 'w') as beam_file:
            json.dump(translator.beam_accum, beam_file)

    # Release the handles opened above; the process-wide std streams are left
    # alone. inFile stays None when reading from a Kaldi scp list.
    if inFile is not None and inFile is not sys.stdin:
        inFile.close()
    if outF is not sys.stdout:
        outF.close()
def __init__(self, model):
    """Create the ensemble translator for ``model``.

    ``model`` is passed to ``TranslatorParameter``; the resulting option
    object is used to construct ``onmt.EnsembleTranslator``, which is stored
    on ``self.translator``.
    """
    self.translator = onmt.EnsembleTranslator(TranslatorParameter(model))