示例#1
0
                print(sth)

                tgt = torch.LongTensor(hypo).to(device).view(1, -1)
                attn = model.get_attn(enc_out, mask, tgt)
                attn = attn[0]

                if len(sth) > len(stable_hypo):
                    cs = torch.cumsum(attn[head], dim=1)
                    ep = cs.le(1. - padding).sum(dim=1)
                    ep = ep.cpu().numpy()
                    for j in range(len(stable_hypo), len(sth)):
                        if ep[j - 1] * 4 + stable_time > end: break
                    if j > len(stable_hypo):
                        sth = sth[0:j]
                        latency += end * (len(sth) - len(stable_hypo))
                        stable_hypo = sth

            latency += time_len * (len(hypo) - len(stable_hypo))
            latency /= time_len * (len(hypo) - 1)
            count += 1
            total_latency += latency
            print('Latency: %0.3f' % latency)

            write_ctm([hypo[1:]], [score[1:]], fctm, [utt], dic, word_dic,
                      args.space)
    fctm.close()
    print('Final Latency: %0.3f' % (total_latency / count))
    time_elapsed = time.time() - since
    print("  Elapsed Time: %.0fm %.0fs" %
          (time_elapsed // 60, time_elapsed % 60))
示例#2
0
    reader.initialize()

    since = time.time()
    batch_size = args.batch_size
    fout = open(args.output, 'w')
    while True:
        src_seq, src_mask, utts = reader.read_batch_utt(batch_size)
        if len(utts) == 0: break
        with torch.no_grad():
            src_seq, src_mask = src_seq.to(device), src_mask.to(device)
            hypos, scores = beam_search(model,
                                        src_seq,
                                        src_mask,
                                        device,
                                        args.beam_size,
                                        args.max_len,
                                        len_norm=args.len_norm,
                                        coverage=args.coverage,
                                        lm=lm,
                                        lm_scale=args.lm_scale)
            hypos, scores = hypos.tolist(), scores.tolist()
            if args.format == 'ctm':
                write_ctm(hypos, scores, fout, utts, dic, word_dic, args.space)
            else:
                write_text(hypos, scores, fout, utts, dic, args.space)

    fout.close()
    time_elapsed = time.time() - since
    print("  Elapsed Time: %.0fm %.0fs" %
          (time_elapsed // 60, time_elapsed % 60))