def run(args): print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True) aligner = CtcAligner(args.am, cpt_tag=args.am_tag, device_id=args.device_id) if aligner.accept_raw: src_reader = AudioReader(args.feats_or_wav_scp, sr=args.sr, channel=args.channel) else: src_reader = ScriptReader(args.feats_or_wav_scp) if args.word_boundary: raise RuntimeError( "Now can't generate word boundary when using Kaldi's feature") txt_reader = Reader(args.text, num_tokens=-1, restrict=False) processor = TextPreProcessor(args.dict, space=args.space, spm=args.spm) ali_stdout, ali_fd = io_wrapper(args.alignment, "w") wdb_stdout, wdb_fd = False, None if args.word_boundary: wdb_stdout, wdb_fd = io_wrapper(args.word_boundary, "w") done = 0 tot_utts = len(src_reader) timer = SimpleTimer() for key, str_seq in txt_reader: done += 1 logger.info( f"Generate alignment for utterance {key} ({done}/{tot_utts}) ...") int_seq = processor.run(str_seq) wav_or_feats = src_reader[key] ali = aligner.run(wav_or_feats, int_seq) header = f"{ali['score']:.3f}, {len(ali['align_seq'])}" ali_fd.write(f"{key} {ali['align_str']}\n") logger.info(f"{key} ({header}) {ali['align_str']}") if wdb_fd: dur = wav_or_feats.shape[-1] * 1.0 / args.sr wdb = gen_word_boundary(key, dur, ali["align_str"]) wdb_fd.write("\n".join(wdb) + "\n") if not ali_stdout: ali_fd.close() if wdb_fd and not wdb_stdout: wdb_fd.close() cost = timer.elapsed() logger.info(f"Generate alignments for {tot_utts} utterance done, " + f"time cost = {cost:.2f}m")
def run(args):
    """
    Rescore n-best hypotheses with a kenlm n-gram LM and keep the best
    transcription per utterance.
    """
    nbest, nbest_hypos = read_nbest(args.nbest)
    ngram = kenlm.LanguageModel(args.lm)
    stdout, top1 = io_wrapper(args.top1, "w")
    for key, nbest_dict in nbest_hypos.items():
        rescore = []
        for hyp in nbest_dict:
            am_score, num_tokens, trans = hyp
            lm_score = ngram.score(trans, bos=True, eos=True)
            if args.len_norm:
                # length-normalize the AM score before fusion
                am_score /= num_tokens
            score = am_score + args.lm_weight * lm_score
            rescore.append((score, trans))
        # keep the hypothesis with the highest fused score
        rescore = sorted(rescore, key=lambda n: n[0], reverse=True)
        top1.write(f"{key}\t{rescore[0][1]}\n")
    if not stdout:
        top1.close()
    logger.info(f"Rescore {len(nbest_hypos)} utterances on {nbest}-best hypotheses")
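# read_nbest is not defined in this file. The decoders below dump an n-best file whose
# first line is the n-best size and which then holds, for every utterance, a key line
# followed by n lines of "score<TAB>length<TAB>transcription". A sketch of a compatible
# reader, assuming exactly that layout and exactly n hypotheses per utterance
# (hypothetical name and shape, not the project's actual helper):
def read_nbest_sketch(nbest_file: str):
    """Return (n, {key: [(am_score, num_tokens, trans), ...]})."""
    with open(nbest_file) as fd:
        lines = [line.rstrip("\n") for line in fd]
    num = int(lines[0])
    hypos = {}
    cursor = 1
    while cursor < len(lines):
        key = lines[cursor]
        hyps = []
        for raw in lines[cursor + 1:cursor + 1 + num]:
            score, num_tokens, trans = raw.split("\t", maxsplit=2)
            hyps.append((float(score), int(num_tokens), trans))
        hypos[key] = hyps
        cursor += num + 1
    return num, hypos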
def run(args): print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True) if args.batch_size == 1: warnings.warn("can use decode.py instead as batch_size == 1") decoder = BatchDecoder(args.am, device_id=args.device_id, cpt_tag=args.am_tag) if decoder.accept_raw: src_reader = AudioReader(args.feats_or_wav_scp, sr=args.sr, channel=args.channel) else: src_reader = ScriptReader(args.feats_or_wav_scp) if args.lm: if Path(args.lm).is_file(): from aps.asr.lm.ngram import NgramLM lm = NgramLM(args.lm, args.dict) logger.info( f"Load ngram LM from {args.lm}, weight = {args.lm_weight}") else: lm = NnetEvaluator(args.lm, device_id=args.device_id, cpt_tag=args.lm_tag) logger.info(f"Load RNN LM from {args.lm}: epoch {lm.epoch}, " + f"weight = {args.lm_weight}") lm = lm.nnet else: lm = None processor = TextPostProcessor(args.dict, space=args.space, show_unk=args.show_unk, spm=args.spm) stdout_top1, top1 = io_wrapper(args.best, "w") topn = None if args.dump_nbest: stdout_topn, topn = io_wrapper(args.dump_nbest, "w") nbest = min(args.beam_size, args.nbest) topn.write(f"{nbest}\n") ali_dir = args.dump_align if ali_dir: Path(ali_dir).mkdir(exist_ok=True, parents=True) logger.info(f"Dump alignments to dir: {ali_dir}") done = 0 timer = SimpleTimer() batches = [] dec_args = dict( filter(lambda x: x[0] in beam_search_params, vars(args).items())) dec_args["lm"] = lm tot_utts = len(src_reader) for key, src in src_reader: done += 1 batches.append({ "key": key, "inp": src, "len": src.shape[-1] if decoder.accept_raw else src.shape[0] }) end = (done == len(src_reader) and len(batches)) if len(batches) != args.batch_size and not end: continue # decode batches = sorted(batches, key=lambda b: b["len"], reverse=True) batch_nbest = decoder.run([bz["inp"] for bz in batches], **dec_args) keys = [bz["key"] for bz in batches] for key, nbest in zip(keys, batch_nbest): logger.info(f"Decoding utterance {key} ({done}/{tot_utts}) ...") nbest_hypos = [f"{key}\n"] for idx, hyp in enumerate(nbest): # remove SOS/EOS token = hyp["trans"][1:-1] trans = processor.run(token) score = hyp["score"] nbest_hypos.append(f"{score:.3f}\t{len(token):d}\t{trans}\n") if idx == 0: logger.info(f"{key} ({score:.3f}, {len(token):d}) {trans}") top1.write(f"{key}\t{trans}\n") if ali_dir: if hyp["align"] is None: raise RuntimeError( "Can not dump alignment out as it's None") np.save(f"{ali_dir}/{key}-nbest{idx+1}", hyp["align"].numpy()) if topn: topn.write("".join(nbest_hypos)) top1.flush() if topn: topn.flush() batches.clear() if not stdout_top1: top1.close() if topn and not stdout_topn: topn.close() cost = timer.elapsed() logger.info(f"Decode {tot_utts} utterance done, time cost = {cost:.2f}m")
def run(args): print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True) decoder = FasterDecoder(args.am, cpt_tag=args.am_tag, function=args.function, device_id=args.device_id) if decoder.accept_raw: src_reader = AudioReader(args.feats_or_wav_scp, sr=args.sr, channel=args.channel) else: src_reader = ScriptReader(args.feats_or_wav_scp) if args.lm: if Path(args.lm).is_file(): from aps.asr.lm.ngram import NgramLM lm = NgramLM(args.lm, args.dict) logger.info( f"Load ngram LM from {args.lm}, weight = {args.lm_weight}") else: lm = NnetEvaluator(args.lm, device_id=args.device_id, cpt_tag=args.lm_tag) logger.info(f"Load RNN LM from {args.lm}: epoch {lm.epoch}, " + f"weight = {args.lm_weight}") lm = lm.nnet else: lm = None processor = TextPostProcessor(args.dict, space=args.space, show_unk=args.show_unk, spm=args.spm) stdout_top1, top1 = io_wrapper(args.best, "w") topn = None if args.dump_nbest: stdout_topn, topn = io_wrapper(args.dump_nbest, "w") if args.function == "greedy_search": nbest = min(args.beam_size, args.nbest) else: nbest = 1 topn.write(f"{nbest}\n") ali_dir = args.dump_align if ali_dir: Path(ali_dir).mkdir(exist_ok=True, parents=True) logger.info(f"Dump alignments to dir: {ali_dir}") N = 0 timer = SimpleTimer() dec_args = dict( filter(lambda x: x[0] in beam_search_params, vars(args).items())) dec_args["lm"] = lm for key, src in src_reader: logger.info(f"Decoding utterance {key}...") nbest_hypos = decoder.run(src, **dec_args) nbest = [f"{key}\n"] for idx, hyp in enumerate(nbest_hypos): # remove SOS/EOS token = hyp["trans"][1:-1] trans = processor.run(token) score = hyp["score"] nbest.append(f"{score:.3f}\t{len(token):d}\t{trans}\n") if idx == 0: top1.write(f"{key}\t{trans}\n") if ali_dir: if hyp["align"] is None: raise RuntimeError( "Can not dump alignment out as it's None") np.save(f"{ali_dir}/{key}-nbest{idx+1}", hyp["align"].numpy()) if topn: topn.write("".join(nbest)) if not (N + 1) % 10: top1.flush() if topn: topn.flush() N += 1 if not stdout_top1: top1.close() if topn and not stdout_topn: topn.close() cost = timer.elapsed() logger.info( f"Decode {len(src_reader)} utterance done, time cost = {cost:.2f}m")