Exemplo n.º 1
0
def main():
    """Translate ``opt.src`` (text, or HDF5 audio features) with an ensemble.

    Options are read from the module-level ``parser``. Hypotheses are
    written through ``translateBatch`` to ``opt.output`` (``"stdout"``
    means standard output). When ``opt.tgt`` is given, the matching
    reference lines are read in lockstep and gold scores are accumulated.
    With ``opt.verbose``, the PRED (and GOLD) totals are reported at the end.

    Audio input (``opt.encoder_type == "audio"``) is read from an HDF5 file
    whose datasets are keyed ``"0"``, ``"1"``, ... For each utterance:

    * frames are subsampled by ``opt.stride``;
    * frames are zero-padded and stacked ``opt.concat`` at a time;
    * up to ``opt.previous_context`` earlier utterances are prepended as
      context, and likewise earlier reference sentences (joined with
      ``" # "``).

    Raises:
        NotImplementedError: ``opt.input_type`` is neither ``'word'`` nor
            ``'char'``.
    """
    import json
    from contextlib import ExitStack

    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # Always pick n_best
    opt.n_best = opt.beam_size

    def tokenize(text):
        """Split `text` into words or characters according to opt.input_type."""
        if opt.input_type == 'word':
            return text.split()
        if opt.input_type == 'char':
            return list(text.strip())
        raise NotImplementedError("Input type unknown")

    predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0
    count = 0

    # ExitStack closes every file this function opens, even on error
    # (the original leaked outF, inFile and the beam-dump file).
    with ExitStack() as stack:
        if opt.output == "stdout":
            outF = sys.stdout
        else:
            outF = stack.enter_context(open(opt.output, 'w'))

        tgtF = stack.enter_context(open(opt.tgt)) if opt.tgt else None

        translator = onmt.EnsembleTranslator(opt)
        # The beam accumulator can only be initialised once the translator
        # exists; the original called initBeamAccum() before assigning it.
        if opt.dump_beam:
            translator.initBeamAccum()

        if opt.src == "stdin":
            inFile = sys.stdin
            opt.batch_size = 1
        elif opt.encoder_type == "audio":
            inFile = h5.File(opt.src, 'r')
            stack.callback(inFile.close)
        else:
            inFile = stack.enter_context(open(opt.src))

        def run_batch(translate_fn, srcBatch, tgtBatch, log_sizes):
            """Translate one batch, write it via translateBatch and update the totals."""
            nonlocal count, predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal
            if log_sizes:
                # Diagnostics go to stderr so they never interleave with the
                # translations when opt.output is "stdout".
                print("Batch size:", len(srcBatch), len(tgtBatch), file=sys.stderr)
            predBatch, predScore, predLength, goldScore, numGoldWords, allGoldScores = translate_fn(
                srcBatch, tgtBatch)
            if log_sizes:
                print("Result:", len(predBatch), file=sys.stderr)
            count, predScore, predWords, goldScore, goldWords = translateBatch(
                opt, tgtF, count, outF, translator, srcBatch, tgtBatch,
                predBatch, predScore, predLength, goldScore, numGoldWords,
                allGoldScores, opt.input_type)
            predScoreTotal += predScore
            predWordsTotal += predWords
            goldScoreTotal += goldScore
            goldWordsTotal += goldWords

        srcBatch, tgtBatch = [], []

        if opt.encoder_type == "audio":
            s_prev_context = []
            t_prev_context = []

            for utt_idx in range(len(inFile)):
                frames = np.array(inFile[str(utt_idx)])
                if opt.stride != 1:
                    frames = frames[0::opt.stride]
                line = torch.from_numpy(frames)

                if opt.concat != 1:
                    # Zero-pad to a multiple of opt.concat frames, then stack
                    # every opt.concat consecutive frames into one row.
                    add = (opt.concat - line.size()[0] % opt.concat) % opt.concat
                    z = torch.FloatTensor(add, line.size()[1]).zero_()
                    line = torch.cat((line, z), 0)
                    line = line.reshape((line.size()[0] // opt.concat,
                                         line.size()[1] * opt.concat))

                if opt.previous_context > 0:
                    # Prepend earlier utterances, each followed by one zero row.
                    s_prev_context.append(line)
                    for k in range(1, opt.previous_context + 1):
                        if k < len(s_prev_context):
                            line = torch.cat((torch.cat(
                                (s_prev_context[-k - 1],
                                 torch.zeros(1, line.size()[1]))), line))
                    if len(s_prev_context) > opt.previous_context:
                        s_prev_context = s_prev_context[-opt.previous_context:]
                srcBatch += [line]

                if tgtF:
                    tline = tgtF.readline().strip()
                    if opt.previous_context > 0:
                        t_prev_context.append(tline)
                        for k in range(1, opt.previous_context + 1):
                            # Bound by the target history itself.  The original
                            # tested the already-trimmed s_prev_context, so it
                            # prepended one reference sentence fewer than the
                            # number of source utterances.
                            if k < len(t_prev_context):
                                tline = t_prev_context[-k - 1] + " # " + tline
                        if len(t_prev_context) > opt.previous_context:
                            t_prev_context = t_prev_context[-opt.previous_context:]
                    tgtBatch += [tokenize(tline)]

                if len(srcBatch) < opt.batch_size:
                    continue

                run_batch(translator.translate_asr, srcBatch, tgtBatch, True)
                srcBatch, tgtBatch = [], []

            # Flush a final, incomplete batch.
            if srcBatch:
                run_batch(translator.translate_asr, srcBatch, tgtBatch, True)
                srcBatch, tgtBatch = [], []

        else:
            for line in addone(inFile):
                if line is not None:
                    srcBatch += [tokenize(line)]
                    if tgtF:
                        tgtBatch += [tokenize(tgtF.readline())]
                    if len(srcBatch) < opt.batch_size:
                        continue
                elif not srcBatch:
                    # addone() presumably yields a final None at end of input;
                    # with no pending sentences there is nothing to translate.
                    break

                # actually done beam search from the model
                run_batch(translator.translate, srcBatch, tgtBatch, False)
                srcBatch, tgtBatch = [], []

        if opt.verbose:
            reportScore('PRED', predScoreTotal, predWordsTotal)
            if tgtF:
                reportScore('GOLD', goldScoreTotal, goldWordsTotal)

    if opt.dump_beam:
        with open(opt.dump_beam, 'w') as beam_file:
            json.dump(translator.beam_accum, beam_file)
Exemplo n.º 2
0
def main():
    """Translate ``opt.src`` sentence by sentence and write the best hypotheses.

    Options come from the module-level ``parser``. The translator is
    ``onmt.EnsembleTranslator`` for ``opt.version == 1.0`` and
    ``onmt.Translator`` for ``opt.version == 2.0``.

    Each input line is tokenised into words or characters (``opt.input_type``)
    and translated on its own; the batch-size check is disabled in this
    variant. Optionally (version 1.0 with ``opt.normalize``), each n-best
    list is re-ranked by ``lenPenalty`` and, with
    ``opt.preferLongestOutputs``, reordered longest-first. Unless
    ``opt.print_nbest`` is set, the best hypothesis is written to
    ``opt.output``. ``opt.verbose`` prints per-sentence details and totals.

    Raises:
        NotImplementedError: unknown ``opt.input_type``.
        ValueError: unsupported ``opt.version``.
    """
    import json
    from contextlib import ExitStack

    opt = parser.parse_args()
    # Diagnostics go to stderr so they never mix with translations on stdout.
    print(opt, file=sys.stderr)
    opt.cuda = opt.gpu > -1
    onmt.Constants.cudaActivated = opt.cuda

    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # Always pick n_best
    opt.n_best = opt.beam_size

    def tokenize(text):
        """Split `text` into words or characters according to opt.input_type."""
        if opt.input_type == 'word':
            return text.split()
        if opt.input_type == 'char':
            return list(text.strip())
        raise NotImplementedError("Input type unknown")

    predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0
    srcBatch, tgtBatch = [], []
    count = 0

    # ExitStack closes every file opened here, even on error.
    with ExitStack() as stack:
        if opt.output == "stdout":
            outF = sys.stdout
        else:
            outF = stack.enter_context(open(opt.output, 'w'))

        tgtF = stack.enter_context(open(opt.tgt)) if opt.tgt else None

        if opt.src == "stdin":
            inFile = sys.stdin
            opt.batch_size = 1
        else:
            inFile = stack.enter_context(open(opt.src))

        if opt.version == 1.0:
            translator = onmt.EnsembleTranslator(opt)
        elif opt.version == 2.0:
            translator = onmt.Translator(opt)
        else:
            # Previously fell through and crashed later with a NameError.
            raise ValueError('Unsupported translator version {0}'.format(opt.version))

        # Must run after `translator` exists; the original called
        # initBeamAccum() before the translator was created.
        if opt.dump_beam:
            translator.initBeamAccum()

        for line in addone(inFile):
            if line is not None:
                srcBatch += [tokenize(line)]
                if tgtF:
                    tgtBatch += [tokenize(tgtF.readline())]
                # NOTE: the `len(srcBatch) < opt.batch_size` check is
                # deliberately disabled here, so each line is translated on
                # its own (preferLongestOutputs below relies on that).
            elif not srcBatch:
                # at the end of file, check last batch
                break

            predBatch, predScore, predLength, goldScore, numGoldWords = translator.translate(
                srcBatch, tgtBatch)

            if opt.normalize and opt.version == 1.0:
                # Re-rank every n-best list by its length-penalised score.
                reranked_preds, reranked_scores = [], []
                for hyps, scores, lengths in zip(predBatch, predScore, predLength):
                    penalised = [lenPenalty(s_, l_, opt.alpha)
                                 for _, s_, l_ in zip(hyps, scores, lengths)]
                    order = numpy.argsort(penalised)[::-1]
                    reranked_preds.append([hyps[k] for k in order])
                    reranked_scores.append([penalised[k] for k in order])
                predBatch, predScore = reranked_preds, reranked_scores

                if opt.preferLongestOutputs:
                    # Reorder the first sentence's n-best list longest-first
                    # (stable, so equal lengths keep their score order).
                    # Take real copies first: the original aliased the lists
                    # and overwrote entries it still had to read.
                    hyps0 = list(predBatch[0])
                    scores0 = list(predScore[0])
                    order = sorted(range(len(hyps0)),
                                   key=lambda k: len(hyps0[k]), reverse=True)
                    # Both lists were built from the same ordering above, so
                    # they have equal length.
                    predBatch[0] = [hyps0[k] for k in order]
                    predScore[0] = [scores0[k] for k in order]

            predScoreTotal += sum(score[0] for score in predScore)
            predWordsTotal += sum(len(x[0]) for x in predBatch)
            if tgtF is not None:
                goldScoreTotal += sum(goldScore).item()
                goldWordsTotal += numGoldWords

            for b, hyps in enumerate(predBatch):
                count += 1

                bestHyp = getSentenceFromTokens(hyps[0], opt.input_type)
                if not opt.print_nbest:
                    outF.write(bestHyp + '\n')
                    outF.flush()

                if opt.verbose:
                    srcSent = getSentenceFromTokens(srcBatch[b], opt.input_type)
                    if translator.tgt_dict.lower:
                        srcSent = srcSent.lower()
                    print('SENT %d: %s' % (count, srcSent))
                    print('PRED %d: %s' % (count, bestHyp))
                    print("PRED SCORE: %.4f" % predScore[b][0])

                    if tgtF is not None:
                        tgtSent = getSentenceFromTokens(tgtBatch[b], opt.input_type)
                        if translator.tgt_dict.lower:
                            tgtSent = tgtSent.lower()
                        print('GOLD %d: %s ' % (count, tgtSent))
                        print("GOLD SCORE: %.4f" % goldScore[b])

                    if opt.print_nbest:
                        print('\nBEST HYP:')
                        # Never index past the hypotheses actually returned.
                        for idx in range(min(opt.n_best, len(hyps))):
                            sent = getSentenceFromTokens(hyps[idx], opt.input_type)
                            print("[%.4f] %s" % (predScore[b][idx], sent))

                    print('')

            srcBatch, tgtBatch = [], []

        if opt.verbose:
            reportScore('PRED', predScoreTotal, predWordsTotal)
            if tgtF:
                reportScore('GOLD', goldScoreTotal, goldWordsTotal)

    if opt.dump_beam:
        with open(opt.dump_beam, 'w') as beam_file:
            json.dump(translator.beam_accum, beam_file)
Exemplo n.º 3
0
def main():
    """Translate whitespace-tokenised ``opt.src`` with an ensemble translator.

    Options come from the module-level ``parser``. Lines are grouped into
    batches of ``opt.batch_size`` (1 when reading stdin) and translated with
    ``onmt.EnsembleTranslator``. With ``opt.normalize``, each n-best list is
    re-ranked by ``len_penalty``. Unless ``opt.print_nbest`` is set, the best
    hypothesis of every sentence is written, space-joined, to ``opt.output``.
    ``opt.verbose`` prints per-sentence details and the PRED/GOLD totals.
    """
    import json
    from contextlib import ExitStack

    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # Always pick n_best
    opt.n_best = opt.beam_size

    predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0
    srcBatch, tgtBatch = [], []
    count = 0

    # ExitStack closes every file opened here, even on error.
    with ExitStack() as stack:
        if opt.output == "stdout":
            outF = sys.stdout
        else:
            outF = stack.enter_context(open(opt.output, 'w'))

        tgtF = stack.enter_context(open(opt.tgt)) if opt.tgt else None

        if opt.src == "stdin":
            inFile = sys.stdin
            opt.batch_size = 1
        else:
            inFile = stack.enter_context(open(opt.src))

        translator = onmt.EnsembleTranslator(opt)
        # Must run after `translator` exists; the original called
        # initBeamAccum() before the translator was created.
        if opt.dump_beam:
            translator.initBeamAccum()

        for line in addone(inFile):
            if line is not None:
                srcBatch += [line.split()]
                if tgtF:
                    tgtBatch += [tgtF.readline().split()]
                if len(srcBatch) < opt.batch_size:
                    continue
            elif not srcBatch:
                # at the end of file, check last batch
                break

            predBatch, predScore, predLength, goldScore, numGoldWords = translator.translate(
                srcBatch, tgtBatch)

            if opt.normalize:
                # Re-rank every n-best list by its length-penalised score.
                reranked_preds, reranked_scores = [], []
                for hyps, scores, lengths in zip(predBatch, predScore, predLength):
                    penalised = [len_penalty(s_, l_, opt.alpha)
                                 for _, s_, l_ in zip(hyps, scores, lengths)]
                    order = numpy.argsort(penalised)[::-1]
                    reranked_preds.append([hyps[k] for k in order])
                    reranked_scores.append([penalised[k] for k in order])
                predBatch, predScore = reranked_preds, reranked_scores

            predScoreTotal += sum(score[0].item() for score in predScore)
            predWordsTotal += sum(len(x[0]) for x in predBatch)
            if tgtF is not None:
                goldScoreTotal += sum(goldScore).item()
                goldWordsTotal += numGoldWords

            for b, hyps in enumerate(predBatch):
                count += 1

                if not opt.print_nbest:
                    outF.write(" ".join(hyps[0]) + '\n')
                    outF.flush()

                if opt.verbose:
                    srcSent = ' '.join(srcBatch[b])
                    if translator.tgt_dict.lower:
                        srcSent = srcSent.lower()
                    print('SENT %d: %s' % (count, srcSent))
                    print('PRED %d: %s' % (count, " ".join(hyps[0])))
                    print("PRED SCORE: %.4f" % predScore[b][0])

                    if tgtF is not None:
                        tgtSent = ' '.join(tgtBatch[b])
                        if translator.tgt_dict.lower:
                            tgtSent = tgtSent.lower()
                        print('GOLD %d: %s ' % (count, tgtSent))
                        print("GOLD SCORE: %.4f" % goldScore[b])

                    if opt.print_nbest:
                        print('\nBEST HYP:')
                        # Never index past the hypotheses actually returned.
                        for idx in range(min(opt.n_best, len(hyps))):
                            print("[%.4f] %s" % (predScore[b][idx],
                                                 " ".join(hyps[idx])))

                    print('')

            srcBatch, tgtBatch = [], []

        if opt.verbose:
            reportScore('PRED', predScoreTotal, predWordsTotal)
            if tgtF:
                reportScore('GOLD', goldScoreTotal, goldWordsTotal)

    if opt.dump_beam:
        with open(opt.dump_beam, 'w') as beam_file:
            json.dump(translator.beam_accum, beam_file)
Exemplo n.º 4
0
def main():
    """Decode audio (or text) input, optionally simulating online decoding.

    Options come from the module-level ``parser``.

    Audio input (``opt.encoder_type == "audio"``) is read from one of:

    * an HDF5 file (``opt.asr_format == "h5"``, datasets keyed ``"0"``,
      ``"1"``, ...);
    * a Kaldi scp archive (``opt.asr_format == "scp"``).

    When ``opt.start_prefixing_from`` is set, scp utterance ids are expected
    to end in ``_<partial index>``. From partial index
    ``opt.start_prefixing_from`` on, the hypothesis kept from the previous
    partial sequence is appended to ``tgtBatch``, which is passed to
    ``translate_asr``. The agreed prefix is used instead when
    ``opt.require_prefix_agreements`` is set.

    After each full utterance:

    * the per-step prefix lengths ("latency", with and without
      punctuation-only tokens) are written to ``opt.output_latency``;
    * the accumulated confidence values are written to
      ``opt.output_confidence``.

    Each report is written only if its path is set.

    Raises:
        Exception: the partial-sequence counter pickle cannot be loaded.
        NotImplementedError: unknown ``opt.input_type``.
        ValueError:
            - unsupported audio ``asr_format``;
            - online decoding without scp input;
            - invalid ``confidence_mode``;
            - a token/confidence count mismatch.
    """
    import json
    from contextlib import ExitStack

    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # Always pick n_best
    opt.n_best = opt.beam_size

    def tokenize(text):
        """Split `text` into words or characters according to opt.input_type."""
        if opt.input_type == 'word':
            return text.split()
        if opt.input_type == 'char':
            return list(text.strip())
        raise NotImplementedError("Input type unknown")

    def is_punct(token):
        """True when every character of `token` is in string.punctuation."""
        return all(c in string.punctuation for c in token)

    predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0
    srcBatch, tgtBatch, tgt_length_batch = [], [], []
    count = 0

    # ExitStack closes every file opened here, even on error.
    with ExitStack() as stack:
        if opt.output == "stdout":
            outF = sys.stdout  # out file
        else:
            outF = stack.enter_context(open(opt.output, 'w'))

        num_partial_seqs = None
        if opt.start_prefixing_from:
            # Maps an original utterance id to its number of partial
            # sequences.  The file sits next to opt.src, with the last
            # extension replaced.
            # NOTE(review): pickle.load can execute arbitrary code, so only
            # point this at trusted files.
            counter_path = '.'.join(opt.src.split('.')[:-1]) + '.num.partial.seqs.pickle'
            try:
                with open(counter_path, 'rb') as f:
                    num_partial_seqs = pickle.load(f)
            except Exception as err:
                # Narrowed from a bare except; keep the root cause chained.
                raise Exception('Failed to open partial sequence counter file.') from err

        # When using a prefix, also write out the latency report.  The report
        # files are opened only when a path is configured; the original
        # opened them unconditionally, which fails for an unset option.
        out_latency_f = (stack.enter_context(open(opt.output_latency, 'w'))
                         if opt.output_latency else None)
        out_confidence_f = (stack.enter_context(open(opt.output_confidence, 'w'))
                            if opt.output_confidence else None)

        tgtF = stack.enter_context(open(opt.tgt)) if opt.tgt else None
        tgt_lengthF = stack.enter_context(open(opt.tgt_length)) if opt.tgt_length else None

        if opt.start_prefixing_from:
            translator = onmt.EnsembleTranslatorOnlineSim(opt)
        else:
            translator = onmt.EnsembleTranslator(opt)
            warnings.warn('Not doing online decoding!')

        if opt.dump_beam:
            translator.init_beam_accum()

        inFile = None
        audio_data = None
        if opt.src == "stdin":
            inFile = sys.stdin
            opt.batch_size = 1
        elif opt.encoder_type == "audio" and opt.asr_format == "h5":
            inFile = h5.File(opt.src, 'r')
            stack.callback(inFile.close)
        elif opt.encoder_type == "audio" and opt.asr_format == "scp":
            from kaldiio import ReadHelper
            audio_data = iter(ReadHelper('scp:' + opt.src))
        elif opt.encoder_type == "audio":
            # Previously fell through and failed later on an unbound `line`.
            raise ValueError('Unsupported asr_format {0}'.format(opt.asr_format))
        else:
            inFile = stack.enter_context(open(opt.src))

        if opt.encoder_type == "audio":
            if opt.start_prefixing_from and audio_data is None:
                # Partial-sequence indices are parsed from scp utterance ids.
                raise ValueError('start_prefixing_from requires asr_format "scp"')

            s_prev_context = []
            t_prev_context = []
            # Next h5 dataset key.  It has its own name: the original reused
            # `i`, which the inner context / confidence loops overwrote.
            h5_idx = 0

            latency = []  # prefix length (tokens) at each step of the current utterance
            latency_no_punct = []  # same, ignoring punctuation-only tokens
            partial_hyps = []  # hypothesis kept after each partial step
            partial_scores = []  # this will hold likelihood of each token
            is_full_utt = True

            def record_latency(tokens):
                """Log the total and non-punctuation token counts of `tokens`."""
                latency.append(len(tokens))
                latency_no_punct.append(len([tok for tok in tokens if not is_punct(tok)]))

            while True:  # keep reading new utterances
                if opt.asr_format == "h5":
                    if h5_idx == len(inFile):
                        break
                    line = np.array(inFile[str(h5_idx)])
                    h5_idx += 1
                else:  # scp
                    try:
                        utt_id, line = next(audio_data)
                    except StopIteration:
                        break

                if opt.stride != 1:
                    line = line[0::opt.stride]

                # Keep only the first 40 feature dimensions (hard-coded).
                line = torch.from_numpy(line[:, :40])

                if opt.concat != 1:
                    # Zero-pad to a multiple of opt.concat frames, then stack
                    # every opt.concat consecutive frames into one row.
                    add = (opt.concat - line.size()[0] % opt.concat) % opt.concat
                    z = torch.FloatTensor(add, line.size()[1]).zero_()
                    line = torch.cat((line, z), 0)
                    line = line.reshape((line.size()[0] // opt.concat,
                                         line.size()[1] * opt.concat))

                if opt.previous_context > 0:
                    # Prepend earlier utterances, each followed by one zero row.
                    s_prev_context.append(line)
                    for k in range(1, opt.previous_context + 1):
                        if k < len(s_prev_context):
                            line = torch.cat((torch.cat(
                                (s_prev_context[-k - 1],
                                 torch.zeros(1, line.size()[1]))), line))
                    if len(s_prev_context) > opt.previous_context:
                        s_prev_context = s_prev_context[-opt.previous_context:]
                srcBatch += [line]  # make batch

                if tgtF:
                    tline = tgtF.readline().strip()
                    if opt.previous_context > 0:
                        t_prev_context.append(tline)
                        for k in range(1, opt.previous_context + 1):
                            # Bound by the target history itself; the trimmed
                            # s_prev_context would drop one reference sentence.
                            if k < len(t_prev_context):
                                tline = t_prev_context[-k - 1] + " # " + tline
                        if len(t_prev_context) > opt.previous_context:
                            t_prev_context = t_prev_context[-opt.previous_context:]
                    tgtBatch += [tokenize(tline)]

                if tgt_lengthF:
                    # One length per batched utterance (previously overwritten
                    # and never reset because of a typo).
                    tgt_length_batch += [int(tgt_lengthF.readline())]

                if len(srcBatch) < opt.batch_size:
                    continue  # fetch next instance

                if opt.start_prefixing_from:  # online decoding
                    partial_utt_idx = int(utt_id.split('_')[-1])
                    orig_utt_idx = '_'.join(utt_id.split('_')[:-1])
                    is_full_utt = (partial_utt_idx == (num_partial_seqs[orig_utt_idx] - 1))

                    # modify tgtBatch (the prefix), if needed
                    if partial_utt_idx < opt.start_prefixing_from:
                        latency.append(0)
                        latency_no_punct.append(0)
                    elif opt.require_prefix_agreements:
                        # Length of the common prefix of the last two partial
                        # hypotheses (0 if fewer than two, or if either is
                        # empty; the original reported 1 for an empty one).
                        n_agree = 0
                        if len(partial_hyps) >= 2:
                            for prev_tok, cur_tok in zip(partial_hyps[-2], partial_hyps[-1]):
                                if prev_tok != cur_tok:
                                    break
                                n_agree += 1
                        tok_removed = len(partial_hyps[-1]) - n_agree
                        tgtBatch += [partial_hyps[-1][:n_agree]]
                        # Drop the confidences of the tokens cut from the prefix.
                        partial_scores = partial_scores[:max(0, len(partial_scores) - tok_removed)]
                        record_latency(tgtBatch[0])
                    else:  # normal prefix
                        if not opt.wait_if_worse or partial_scores[-1] > partial_scores[-2]:
                            tgtBatch += [partial_hyps[-1]]
                        else:
                            print('waiting, {0} < {1}'.format(partial_scores[-1], partial_scores[-2]),
                                  file=sys.stderr)
                            # Fall back to the previous (better-scoring) hypothesis.
                            del partial_scores[-1]
                            del partial_hyps[-1]
                            tgtBatch += [partial_hyps[-1]]
                        record_latency(tgtBatch[0])
                else:
                    is_full_utt = True  # always write out

                # Diagnostics go to stderr so they never interleave with the
                # translations when opt.output is "stdout".
                print("Batch size:", len(srcBatch), len(tgtBatch), file=sys.stderr)
                predBatch, predScore, predLength, goldScore, numGoldWords, allGoldScores, all_lk = \
                    translator.translate_asr(srcBatch, tgtBatch, tgt_length_batch)
                print("Result:", len(predBatch), file=sys.stderr)

                count, predScore, predWords, goldScore, goldWords, reordered_pred_words, best_conf = translateBatch(
                    opt, tgtF, count, outF, translator, srcBatch, tgtBatch,
                    predBatch, predScore, predLength, goldScore, numGoldWords,
                    allGoldScores, opt.input_type, all_lk, write_out=is_full_utt)
                _partial_hyp = reordered_pred_words[0][0]

                if is_full_utt:  # full seq
                    record_latency(_partial_hyp)
                else:
                    # Strip trailing punctuation-only tokens of a partial sequence.
                    while _partial_hyp and is_punct(_partial_hyp[-1]):
                        del _partial_hyp[-1]
                        if best_conf:  # when continuously deleting punctuations
                            del best_conf[-1]
                        else:
                            del partial_scores[-1]

                    if opt.max_out_per_segment:
                        allowed_len = latency[-1] + opt.max_out_per_segment
                        _partial_hyp = _partial_hyp[:allowed_len]
                        best_conf = best_conf[:opt.max_out_per_segment]

                    if opt.remove_last_n:
                        # Never cut back into the prefix committed at the previous step.
                        prev_len = latency[-1]
                        actual_removed = min(opt.remove_last_n, len(_partial_hyp) - prev_len)
                        _partial_hyp = _partial_hyp[:(len(_partial_hyp) - actual_removed)]
                        best_conf = best_conf[:(len(best_conf) - actual_removed)]

                    if opt.confidence_mode and best_conf:  # new output non empty
                        if opt.confidence_mode != 4:
                            raise ValueError('Invalid confidence_mode {0}'.format(opt.confidence_mode))
                        # Keep new tokens up to the last one whose confidence
                        # exceeds opt.min_confidence (none if no such token).
                        n_keep = 0
                        for k in range(len(best_conf) - 1, -1, -1):
                            if best_conf[k] > opt.min_confidence:
                                print(best_conf[k], '>', opt.min_confidence, 'breaking at', k,
                                      file=sys.stderr)
                                n_keep = k + 1
                                break
                        _partial_hyp = _partial_hyp[:latency[-1] + n_keep]
                        best_conf = best_conf[:n_keep]

                partial_hyps.append(_partial_hyp)
                if is_full_utt or partial_utt_idx + 1 >= opt.start_prefixing_from:
                    partial_scores.extend(best_conf)
                    print('best conf', best_conf, file=sys.stderr)

                predScoreTotal += predScore
                predWordsTotal += predWords
                goldScoreTotal += goldScore
                goldWordsTotal += goldWords
                srcBatch, tgtBatch, tgt_length_batch = [], [], []

                if is_full_utt:  # don't apply prefix anymore. Clear previous ones
                    if latency[-1] != len(partial_scores):
                        raise ValueError('{0} tokens vs {1} confidence scores!!'.format(
                            latency[-1], len(partial_scores)))
                    if out_latency_f:
                        out_latency_f.write(','.join([str(x) for x in latency]) + '\n')
                        out_latency_f.write(','.join([str(x) for x in latency_no_punct]) + '\n')
                        out_latency_f.flush()
                    if out_confidence_f:
                        out_confidence_f.write(','.join(['{0:.2f}'.format(x) for x in partial_scores]) + '\n')
                        out_confidence_f.flush()
                    partial_hyps = []
                    partial_scores = []
                    latency = []
                    latency_no_punct = []

            # After all utterances: flush a final, incomplete batch.  It is
            # always written out, and it gets its target lengths like every
            # other batch.
            if srcBatch:
                print("Batch size:", len(srcBatch), len(tgtBatch), file=sys.stderr)
                predBatch, predScore, predLength, goldScore, numGoldWords, allGoldScores, all_lk = \
                    translator.translate_asr(srcBatch, tgtBatch, tgt_length_batch)
                print("Result:", len(predBatch), file=sys.stderr)

                count, predScore, predWords, goldScore, goldWords, _, _ = translateBatch(
                    opt, tgtF, count, outF, translator, srcBatch, tgtBatch,
                    predBatch, predScore, predLength, goldScore, numGoldWords,
                    allGoldScores, opt.input_type, all_lk, write_out=True)
                predScoreTotal += predScore
                predWordsTotal += predWords
                goldScoreTotal += goldScore
                goldWordsTotal += goldWords
                srcBatch, tgtBatch, tgt_length_batch = [], [], []

        else:
            for line in addone(inFile):
                if line is not None:
                    srcBatch += [tokenize(line)]
                    if tgtF:
                        tgtBatch += [tokenize(tgtF.readline())]
                    if len(srcBatch) < opt.batch_size:
                        continue
                elif not srcBatch:
                    # at the end of file, check last batch
                    break

                # actually done beam search from the model
                predBatch, predScore, predLength, goldScore, numGoldWords, allGoldScores, all_lk = \
                    translator.translate(srcBatch, tgtBatch)

                # convert output tensor to words
                result = translateBatch(opt, tgtF, count, outF, translator,
                                        srcBatch, tgtBatch,
                                        predBatch, predScore, predLength,
                                        goldScore, numGoldWords,
                                        allGoldScores, opt.input_type, all_lk)
                # The audio path unpacks seven values from translateBatch;
                # take the first five so this path works with either arity.
                count, predScore, predWords, goldScore, goldWords = result[:5]
                predScoreTotal += predScore
                predWordsTotal += predWords
                goldScoreTotal += goldScore
                goldWordsTotal += goldWords
                srcBatch, tgtBatch = [], []

        if opt.verbose:
            reportScore('PRED', predScoreTotal, predWordsTotal)
            if tgtF:
                reportScore('GOLD', goldScoreTotal, goldWordsTotal)

    if opt.dump_beam:
        with open(opt.dump_beam, 'w') as beam_file:
            json.dump(translator.beam_accum, beam_file)
 def __init__(self, model):
     """Build an ensemble translator from `model`.

     `model` is passed to TranslatorParameter to produce the option object
     given to onmt.EnsembleTranslator.
     NOTE(review): what `model` is (a path, a config string?) is not visible
     in this excerpt; confirm against TranslatorParameter.
     """
     opt = TranslatorParameter(model)
     # Kept on the instance; presumably used by other methods of this class.
     self.translator = onmt.EnsembleTranslator(opt)