def main():
    """Transcribe a single wave file and print the letter transcript.

    Command-line options come from ``parser``, a name defined outside this
    function; ``wav_path``, ``target_dict_path`` and ``w2v_path`` are read from
    the parsed result. The best Viterbi hypothesis is converted to a string
    with the target dictionary, post-processed with the ``'letter'`` scheme
    and printed.
    """
    opts = parser.parse_args()

    waveform = get_feature(opts.wav_path)
    dictionary = Dictionary.load(opts.target_dict_path)

    # load_model's result is indexed here: element 0 is switched to eval mode,
    # while the whole object is what the decoder receives.
    models = load_model(opts.w2v_path, dictionary)
    models[0].eval()

    decoder = W2lViterbiDecoder(dictionary)

    source = waveform.unsqueeze(0)
    # All-False mask of shape (1, source.size(1)): no position is padding.
    no_padding = torch.BoolTensor(source.size(1)).fill_(False).unsqueeze(0)
    batch = {"net_input": {"source": source, "padding_mask": no_padding}}

    with torch.no_grad():
        hypotheses = decoder.generate(models, batch, prefix_tokens=None)

    best = hypotheses[0][0]["tokens"].int().cpu()
    print(post_process(dictionary.string(best), 'letter'))
def predict_file(self, file_path):
    """Transcribe one audio file and return its transcript and encoder output.

    Args:
        file_path: path of the audio file, passed to ``get_feature``.

    Returns:
        ``(asr_result, audio_embedding)``: ``asr_result`` is the best Viterbi
        hypothesis rendered with ``bpe_symbol='none'`` and post-processed with
        ``'none'``; ``audio_embedding`` is the model's ``'encoder_out_no_proj'``
        output, squeezed on dim 1 and converted to a NumPy array.
    """
    generator = W2lViterbiDecoder(self.target_dict)

    feature = get_feature(file_path)
    source = feature.unsqueeze(0).to(device)
    # All-False mask of shape (1, source.size(1)): no position is padding.
    padding_mask = (
        torch.BoolTensor(source.size(1)).fill_(False).unsqueeze(0).to(device)
    )
    net_input = {"source": source, "padding_mask": padding_mask}
    sample = {"net_input": net_input}

    # Both passes are pure inference. Running the embedding forward pass inside
    # no_grad as well avoids building an autograd graph, and keeps ``.numpy()``
    # from failing on an output tensor that requires grad.
    with torch.no_grad():
        hypo = generator.generate(self.model, sample, prefix_tokens=None)
        hyp_pieces = self.target_dict.string(
            hypo[0][0]["tokens"].int().cpu(), bpe_symbol='none'
        )
        asr_result = post_process(hyp_pieces, 'none')

        encoder_output = self.model(**net_input)

    audio_embedding = (
        encoder_output['encoder_out_no_proj'].squeeze(1).cpu().numpy()
    )
    return asr_result, audio_embedding
def build_generator(cfg: UnsupGenerateConfig):
    """Instantiate the decoder selected by ``cfg.w2l_decoder``.

    The wav2letter decoders are built with ``task.target_dictionary``, where
    ``task`` is a name from an enclosing/module scope (not a parameter). Each
    decoder module is imported only inside its own branch.

    Raises:
        AssertionError: for the Kaldi decoder when ``cfg.kaldi_decoder_config``
            is None.
        NotImplementedError: for any other ``cfg.w2l_decoder`` value.
    """
    w2l_decoder = cfg.w2l_decoder
    if w2l_decoder == DecoderType.VITERBI:
        from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

        return W2lViterbiDecoder(cfg, task.target_dictionary)
    elif w2l_decoder == DecoderType.KENLM:
        from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

        return W2lKenLMDecoder(cfg, task.target_dictionary)
    elif w2l_decoder == DecoderType.FAIRSEQ:
        from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

        return W2lFairseqLMDecoder(cfg, task.target_dictionary)
    elif w2l_decoder == DecoderType.KALDI:
        from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoder

        assert cfg.kaldi_decoder_config is not None
        return KaldiDecoder(
            cfg.kaldi_decoder_config,
            cfg.beam,
        )
    else:
        # The message lists every branch above, including the Kaldi decoder.
        raise NotImplementedError(
            "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options "
            "or the kaldi decoder are supported at the moment but found "
            + str(w2l_decoder)
        )
def build_generator(
    self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
):
    """Return a CTC Viterbi decoder over ``self.tgt_dict``.

    Only ``self.args.decoder_type == "ctc"`` is supported. ``models``,
    ``seq_gen_cls`` and ``extra_gen_cls_kwargs`` are accepted for interface
    compatibility but are not used.

    Raises:
        NotImplementedError: for any other ``decoder_type``.
    """
    if self.args.decoder_type != "ctc":
        raise NotImplementedError("only ctc decoder is supported at the moment")

    from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

    return W2lViterbiDecoder(args, self.tgt_dict)
def build_generator(self, task, models, cfg):
    """Build a single-best Viterbi decoder over ``task.target_dictionary``.

    ``cfg.nbest`` is set to 1 inside ``open_dict(cfg)`` before the decoder is
    constructed from ``cfg``. ``models`` is accepted but not used.

    Raises:
        Exception: if the wav2letter decoder module cannot be imported; the
            original error is chained as ``__cause__``.
    """
    try:
        from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
    except Exception as err:
        # Chain the underlying failure so the real missing dependency and its
        # traceback stay visible to whoever reads the error.
        raise Exception(
            "Cannot run this test without flashlight dependency") from err
    with open_dict(cfg):
        cfg.nbest = 1
    return W2lViterbiDecoder(cfg, task.target_dictionary)
def build_generator(self, args):
    """Pick a wav2letter decoder from ``args.w2l_decoder``.

    ``"viterbi"`` and ``"kenlm"`` build the matching decoder over
    ``self.target_dictionary``. Any other value (including a missing
    attribute) defers to the parent class's ``build_generator``.
    """
    decoder_kind = getattr(args, "w2l_decoder", None)

    if decoder_kind == "viterbi":
        from examples.speech_recognition.w2l_decoder import (
            W2lViterbiDecoder as decoder_cls,
        )
    elif decoder_kind == "kenlm":
        from examples.speech_recognition.w2l_decoder import (
            W2lKenLMDecoder as decoder_cls,
        )
    else:
        return super().build_generator(args)

    return decoder_cls(args, self.target_dictionary)
def _init_model(self, dict_file):
    """Load the target dictionary, model and Viterbi decoder onto ``self``.

    ``dict_file`` is resolved relative to the directory that contains
    ``self.model_path``. Both paths are run through the parser returned by
    ``self._create_parser(dict_file)``, and the parsed namespace is stored as
    ``self.args``.
    """
    cli = self._create_parser(dict_file)
    dict_path = os.path.join(os.path.dirname(self.model_path), dict_file)
    parsed = cli.parse_args(
        ['--target_dict_path', dict_path, '--w2v_path', self.model_path]
    )

    dictionary = Dictionary.load(parsed.target_dict_path)
    # _load_model's result is indexed; only its first element is kept.
    self.model = self._load_model(parsed.w2v_path, dictionary)[0]
    self.model.eval()

    self.generator = W2lViterbiDecoder(parsed, dictionary)
    self.args = parsed
    self.target_dict = dictionary
def get_decoder(decoder_args_dict, dictionary):
    """Build a wav2letter decoder from a plain mapping of options.

    Args:
        decoder_args_dict: decoder options, converted to an ``argparse``
            ``Namespace``; its ``decoder_type`` selects ``"viterbi"`` or
            ``"kenlm"``.
        dictionary: target dictionary handed to the decoder. For ``"kenlm"``,
            ``len(dictionary)`` also becomes ``beam_size_token``.

    Returns:
        The decoder instance, or ``None`` for any other ``decoder_type``.
    """
    options = Namespace(**decoder_args_dict)
    kind = options.decoder_type

    if kind == "viterbi":
        from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

        return W2lViterbiDecoder(options, dictionary)

    if kind != "kenlm":
        return None

    from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

    options.beam_size_token = len(dictionary)
    # NOTE(review): eval() executes arbitrary code from the config value; this
    # is only safe while decoder configs are trusted. Presumably it lets string
    # expressions (e.g. an infinite weight) be written in config — confirm.
    if isinstance(options.unk_weight, str):
        options.unk_weight = eval(options.unk_weight)
    return W2lKenLMDecoder(options, dictionary)
def build_generator(args):
    """Build the wav2letter decoder selected by ``args.w2l_decoder``.

    Every decoder is constructed with ``task.target_dictionary``; ``task`` is a
    name from an enclosing scope, not a parameter. The kenlm and fairseqlm
    branches also print the dictionary's symbols.

    Raises:
        NotImplementedError: if ``w2l_decoder`` is missing or is not one of
            ``"viterbi"``, ``"kenlm"`` or ``"fairseqlm"``.
    """
    w2l_decoder = getattr(args, "w2l_decoder", None)
    if w2l_decoder == "viterbi":
        from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

        return W2lViterbiDecoder(args, task.target_dictionary)
    elif w2l_decoder == "kenlm":
        from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

        print(task.target_dictionary.symbols)
        return W2lKenLMDecoder(args, task.target_dictionary)
    elif w2l_decoder == "fairseqlm":
        from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

        print(task.target_dictionary.symbols)
        return W2lFairseqLMDecoder(args, task.target_dictionary)
    else:
        # Zero-argument super() only works in a method defined in a class body.
        # In this plain function the former ``super().build_generator(args)``
        # fallback raised RuntimeError, so report the real problem instead.
        raise NotImplementedError(
            "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options "
            "are supported at the moment but found " + str(w2l_decoder)
        )
def build_generator(args):
    """Return the wav2letter decoder named by ``args.w2l_decoder``.

    Supported names are ``"viterbi"``, ``"kenlm"`` and ``"fairseqlm"``; each is
    built with ``task.target_dictionary`` (``task`` comes from an enclosing
    scope). Any other value, including a missing attribute, prints a notice
    and yields ``None``.
    """
    choice = getattr(args, "w2l_decoder", None)

    if choice == "viterbi":
        from examples.speech_recognition.w2l_decoder import (
            W2lViterbiDecoder as decoder_cls,
        )
    elif choice == "kenlm":
        from examples.speech_recognition.w2l_decoder import (
            W2lKenLMDecoder as decoder_cls,
        )
    elif choice == "fairseqlm":
        from examples.speech_recognition.w2l_decoder import (
            W2lFairseqLMDecoder as decoder_cls,
        )
    else:
        print('only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment')
        return None

    return decoder_cls(args, task.target_dictionary)
def main():
    """Command-line entry point: transcribe one wave file with wav2vec 2.0.

    Parses ``--wav_path``, ``--w2v_path`` and ``--target_dict_path``, then runs
    Viterbi decoding on the utterance (moved to the module-level ``device``).
    It prints three things: the best hypothesis' token tensor size, the number
    of whitespace-separated pieces, and the ``'letter_b'`` post-processed
    transcript.
    """
    cli = argparse.ArgumentParser(description='Wav2vec-2.0 Recognize')
    option_specs = (
        ('--wav_path', '~/xxx.wav', 'path of wave file'),
        ('--w2v_path', 'pre_train_weights/wav2vec_vox_960h_pl.pt',
         'path of pre-trained wav2vec-2.0 model'),
        ('--target_dict_path', 'pre_train_weights/dict.ltr.txt',
         'path of target dict (dict.ltr.txt)'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    opts = cli.parse_args()

    feature = get_feature(opts.wav_path)
    dictionary = Dictionary.load(opts.target_dict_path)
    model = load_model(opts.w2v_path, dictionary)
    model.eval()
    decoder = W2lViterbiDecoder(dictionary)

    source = feature.unsqueeze(0).to(device)
    # All-False mask of shape (1, source.size(1)): no position is padding.
    no_padding = (
        torch.BoolTensor(source.size(1)).fill_(False).unsqueeze(0).to(device)
    )
    batch = {"net_input": {"source": source, "padding_mask": no_padding}}

    with torch.no_grad():
        hypotheses = decoder.generate(model, batch, prefix_tokens=None)

    best_tokens = hypotheses[0][0]["tokens"]
    print(best_tokens.size())
    pieces = dictionary.string(best_tokens.int().cpu(), bpe_symbol='none')
    print(len(pieces.split()))
    print(post_process(pieces, 'letter_b'))
def build_generator(args):
    """Build the decoder selected by ``args.w2l_decoder``.

    Every decoder is constructed with ``task.target_dictionary``; ``task`` is a
    name from an enclosing scope, not a parameter. Each decoder module is
    imported only inside its own branch. The ``*_lm_decoder`` variants pass a
    pair of empty dicts as the third argument where the plain variants pass a
    single empty dict.

    Args:
        args: options object; ``w2l_decoder`` (read with ``getattr``, default
            ``None``) selects the decoder, and the object is passed on to it.

    Raises:
        NotImplementedError: if ``w2l_decoder`` is missing or unrecognised.
    """
    w2l_decoder = getattr(args, "w2l_decoder", None)
    if w2l_decoder == "viterbi":
        from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

        return W2lViterbiDecoder(args, task.target_dictionary)
    elif w2l_decoder == "kenlm":
        from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

        return W2lKenLMDecoder(args, task.target_dictionary)
    elif w2l_decoder == "fairseqlm":
        from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

        return W2lFairseqLMDecoder(args, task.target_dictionary)
    elif w2l_decoder == "ctc_decoder":
        from examples.speech_recognition.ctc_decoder import CTCDecoder

        return CTCDecoder(args, task.target_dictionary)
    elif w2l_decoder == "cif_decoder":
        from examples.speech_recognition.cif_decoder import CIFDecoder

        return CIFDecoder(args, task.target_dictionary, {})
    elif w2l_decoder == "cif_lm_decoder":
        from examples.speech_recognition.cif_decoder import CIFDecoder

        return CIFDecoder(args, task.target_dictionary, ({}, {}))
    elif w2l_decoder == "cif_bert_decoder":
        from examples.speech_recognition.cif_bert_decoder import CIF_BERT_Decoder

        return CIF_BERT_Decoder(args, task.target_dictionary)
    elif w2l_decoder == "seq2seq_decoder":
        from examples.speech_recognition.seq2seq_decoder import Seq2seqDecoder

        return Seq2seqDecoder(args, task.target_dictionary, {})
    elif w2l_decoder == "seq2seq_lm_decoder":
        from examples.speech_recognition.seq2seq_decoder import Seq2seqDecoder

        return Seq2seqDecoder(args, task.target_dictionary, ({}, {}))
    else:
        # Zero-argument super() only works in a method defined in a class body.
        # In this plain function the former ``super().build_generator(args)``
        # fallback raised RuntimeError, so report the real problem instead.
        raise NotImplementedError(
            "unsupported w2l_decoder " + str(w2l_decoder) + "; expected one of "
            "viterbi, kenlm, fairseqlm, ctc_decoder, cif_decoder, cif_lm_decoder, "
            "cif_bert_decoder, seq2seq_decoder, seq2seq_lm_decoder"
        )
def __init__(self, model_weight, target_dict):
    """Load the target dictionary, the model and a Viterbi decoder.

    Args:
        model_weight: checkpoint path handed to ``load_model``.
        target_dict: path of the dictionary file read by ``Dictionary.load``;
            the loaded dictionary (not the path) is stored as
            ``self.target_dict``.
    """
    dictionary = Dictionary.load(target_dict)
    self.target_dict = dictionary

    self.model = load_model(model_weight, dictionary)
    self.model.eval()

    self.generator = W2lViterbiDecoder(dictionary)
def _main(cfg: DictConfig, output_file):
    """Run batched generation over ``cfg.dataset.gen_subset`` and score it.

    Generation loop (logging under ``fairseq_cli.generate``) that selects a
    decoding path from the exact class of the first loaded model, stored as
    ``token_type``: ``'bart'``, ``'chr'``, ``'chrctc'`` or ``'bert'``.
    Output goes to two places:
      * S-/T-/H-/D-/P- lines (and optionally A-/I-/E- lines) are written to
        ``output_file`` unless ``cfg.common_eval.quiet`` is set.
      * Debugging output is printed to stdout.

    Args:
        cfg: full generation config (task, dataset, generation, common_eval,
            checkpoint, distributed_training, ...).
        output_file: text stream used both as the logging stream and for the
            generation lines.

    Returns:
        The scorer built by ``scoring.build_scorer``, after the top hypothesis
        of every sample with a reference has been added to it.

    Raises:
        ValueError: if the first model's class is not one of the recognised
            model classes.
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    # When the task config has a ``label_dir`` field, point it at the directory
    # part of the generation subset, resolved under ``cfg.task.data``.
    if 'label_dir' in cfg.task:
        manifest_dir, _ = os.path.split(cfg.dataset.gen_subset)
        # read_write(cfg) is used so the assignment is allowed on this config
        # (presumably it lifts omegaconf read-only protection — confirm).
        with read_write(cfg):
            cfg.task.label_dir = os.path.join(cfg.task.data, manifest_dir)
        # NOTE(review): the label says 'cfg.task.data' but the value printed is label_dir.
        print('cfg.task.data', cfg.task.label_dir)
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        # Neither a token budget nor a sentence budget was given: default to 12000 tokens.
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Load dataset splits
    task = tasks.setup_task(cfg.task)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        # source_dictionary may raise instead of being absent; treat as no source side.
        src_dict = None
    tgt_dict = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Select how targets/hypotheses are decoded from the exact class of the
    # first model. Because this uses ``type(...) ==`` rather than isinstance,
    # subclasses of these models are not recognised and hit the ValueError.
    token_type = None
    if type(models[0]) == Wav2Bart or type(models[0]) == WavTransBart or type(models[0]) == WavLinearBart or type(models[0]) == WavBart2Bart:
        token_type = 'bart'
    elif type(models[0]) == Wav2BartChr:
        token_type = 'chr'
    elif type(models[0]) == Wav2VecCtc or type(models[0]) == Wav2BertChr or type(models[0]) == Wav2BertMixChr:
        token_type = 'chrctc'
    elif type(models[0]) == Wav2Bert:
        token_type = 'bert'
    else:
        raise ValueError(f'token_type not defined for {type(models[0])}')
    print(f'token_type is {token_type}')

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None
            )
        except:
            # Only adds a hint to the log; the original error is re-raised below.
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1
    else:
        lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
    print('cfg.generation', cfg.generation)
    # print(cfg.task._name == 'audio_pretraining')
    # The two audio-pretraining task names bypass task.build_generator and use a
    # single-best W2lViterbiDecoder instead. extra_gen_cls_kwargs (and therefore
    # any LM loaded above) is not passed to that decoder.
    if cfg.task._name != 'audio_pretraining' and cfg.task._name != 'audio_pretraining_bertbpe':
        generator = task.build_generator(
            models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )
    else:
        print('use W2lViterbiDecoder')
        from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
        from easydict import EasyDict as edict
        args = edict({
            'criterion': 'ctc',
            'nbest': 1,
        })
        generator = W2lViterbiDecoder(args, task.target_dictionary)

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def decode_fn(x):
        # Undo BPE first, then tokenization; each step is skipped when not configured.
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, tgt_dict)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for si, sample in enumerate(progress):
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        # print('hypos', hypos)
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
                )

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
                    sample_id
                )
                target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
                    sample_id
                )
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    # Decode the reference with the scheme matching token_type.
                    if token_type == 'chr':
                        target_str = tgt_dict.string(
                            target_tokens,
                            cfg.common_eval.post_process,
                            escape_unk=True,
                            extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                                generator
                            ),
                        )
                    elif token_type == 'bart':
                        target_str = task.bart.decode(target_tokens.int().cpu())
                    elif token_type == 'bert':
                        target_str = task.bert.decode(target_tokens.int().cpu())
                    elif token_type == 'chrctc':
                        target_str = tgt_dict.string(
                            target_tokens,
                            cfg.common_eval.post_process,
                            escape_unk=True,
                        )
                    else:
                        raise ValueError(f'token_type not defined for {type(models[0])}')

            src_str = decode_fn(src_str)
            if has_target and token_type == 'chr':
                target_str = decode_fn(target_str)
            elif has_target and token_type == 'chrctc':
                # Remove all whitespace, then map '|' (presumably the word
                # separator symbol — confirm) to a space.
                target_str = ''.join(target_str.split()).replace('|', ' ')

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str), file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str), file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
                # print('align', hypo["alignment"])
                if token_type == 'bart':
                    hypo_tokens = hypo["tokens"].int().cpu()
                    hypo_str = task.bart.decode(hypo["tokens"].int().cpu())
                    alignment = hypo["alignment"]
                elif token_type == 'chr':
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=cfg.common_eval.post_process,
                        # extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                    )
                elif token_type == 'chrctc':
                    hypo_tokens = hypo["tokens"].int().cpu()
                    hypo_str = task.target_dictionary.string(hypo_tokens)
                    # Overwrites the positional scores with a single 0. that the P- line
                    # below prints (presumably the decoder's own scores are unusable
                    # here — confirm).
                    # NOTE(review): ``alignment`` is never assigned in this branch, so
                    # print_alignment == "hard"/"soft" would raise NameError for chrctc models.
                    hypo["positional_scores"] = torch.FloatTensor([0.])
                elif token_type == 'bert':
                    hypo_tokens = hypo["tokens"].int().cpu()
                    hypo_str = task.bert.decode(hypo["tokens"].int().cpu())
                    alignment = hypo["alignment"]
                else:
                    raise ValueError(f'token_type not defined for {type(models[0])}')
                detok_hypo_str = decode_fn(hypo_str)
                # NOTE(review): these debug prints go to stdout and read target_str without
                # checking has_target; with no reference and no align_dict this can raise
                # NameError or print a value left over from an earlier sample.
                if token_type == 'chr' or token_type == 'chrctc':
                    print('target_str', ''.join(target_str.split()).replace('|', ' '))
                    print('typo_str', ''.join(detok_hypo_str.split()).replace('|', ' '))
                    detok_hypo_str = ''.join(detok_hypo_str.split()).replace('|', ' ')
                    # target_str = ''.join(target_str.split()).replace('|', ' ')
                elif token_type == 'bart':
                    print('target_str', target_str)
                    print('typo_str', detok_hypo_str)
                #elif token_type == 'chrctc':
                #    print('target_str', ''.join(target_str.split()).replace('|', ' '))
                #    print('typo_str', ''.join(detok_hypo_str.split()).replace('|', ' '))
                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print(
                        "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
                        file=output_file,
                    )
                    # detokenized hypothesis
                    print(
                        "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
                        file=output_file,
                    )
                    print(
                        "P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    # convert from base e to base 2
                                    # (div_ rescales hypo["positional_scores"] in place)
                                    hypo["positional_scores"]
                                    .div_(math.log(2))
                                    .tolist(),
                                )
                            ),
                        ),
                        file=output_file,
                    )

                    if cfg.generation.print_alignment == "hard":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [
                                        "{}-{}".format(src_idx, tgt_idx)
                                        for src_idx, tgt_idx in alignment
                                    ]
                                ),
                            ),
                            file=output_file,
                        )
                    if cfg.generation.print_alignment == "soft":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [
                                        ",".join(src_probs)
                                        for src_probs in alignment
                                    ]
                                ),
                            ),
                            file=output_file,
                        )

                    if cfg.generation.print_step:
                        print(
                            "I-{}\t{}".format(sample_id, hypo["steps"]),
                            file=output_file,
                        )

                    if cfg.generation.retain_iter_history:
                        for step, h in enumerate(hypo["history"]):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h["tokens"].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print(
                                "E-{}_{}\t{}".format(sample_id, step, h_str),
                                file=output_file,
                            )

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or cfg.common_eval.post_process is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True
                        )
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True
                        )
                    # String-based scorers get the detokenized text; others get token ids.
                    if hasattr(scorer, "add_string"):
                        # print('add_string 1', target_str, '2', detok_hypo_str)
                        # if si > 2:
                        #     raise
                        print('2', target_str, detok_hypo_str)
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (
            sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
        )

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info(
        "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        )
    )
    # has_target reflects the last processed sample (it is reassigned per sample above).
    if has_target:
        if cfg.bpe and not cfg.generation.sacrebleu:
            if cfg.common_eval.post_process:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
                )
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            "Generate {} with beam={}: {}".format(
                cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
            ),
            file=output_file,
        )

    return scorer