import logging
from typing import Optional, Sequence, Tuple, Union

import torch
from typeguard import check_argument_types

# Import paths below assume the espnet/espnet2 package layout.
from espnet.nets.beam_search import BeamSearch, Hypothesis
from espnet.nets.scorers.ctc import CTCPrefixScorer
from espnet.nets.scorers.length_bonus import LengthBonus
from espnet2.fileio.datadir_writer import DatadirWriter
from espnet2.tasks.asr import ASRTask
from espnet2.tasks.lm import LMTask
from espnet2.text.build_tokenizer import build_tokenizer
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed


def inference(
    output_dir: str,
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    asr_train_config: str,
    asr_model_file: str,
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    word_lm_train_config: Optional[str],
    word_lm_file: Optional[str],
    blank_symbol: str,
    token_type: Optional[str],
    bpemodel: Optional[str],
    allow_variable_data_keys: bool,
):
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random seed
    set_all_random_seed(seed)

    # 2. Build ASR model
    scorers = {}
    asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    asr_model.eval()

    decoder = asr_model.decoder
    ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
    token_list = asr_model.token_list
    scorers.update(
        decoder=decoder,
        ctc=ctc,
        length_bonus=LengthBonus(len(token_list)),
    )

    # 3. Build language model
    if lm_train_config is not None:
        lm, lm_train_args = LMTask.build_model_from_file(
            lm_train_config, lm_file, device
        )
        scorers["lm"] = lm.lm

    # 4. Build BeamSearch object
    weights = dict(
        decoder=1.0 - ctc_weight,
        ctc=ctc_weight,
        lm=lm_weight,
        length_bonus=penalty,
    )
    beam_search = BeamSearch(
        beam_size=beam_size,
        weights=weights,
        scorers=scorers,
        sos=asr_model.sos,
        eos=asr_model.eos,
        vocab_size=len(token_list),
        token_list=token_list,
    )
    beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
    for scorer in scorers.values():
        if isinstance(scorer, torch.nn.Module):
            scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
    logging.info(f"Beam_search: {beam_search}")
    logging.info(f"Decoding device={device}, dtype={dtype}")

    # 5. Build data iterator
    loader = ASRTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=ASRTask.build_preprocess_fn(asr_train_args, False),
        collate_fn=ASRTask.build_collate_fn(asr_train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 6. [Optional] Build text converter: e.g. bpe-sym -> text
    if token_type is None:
        token_type = asr_train_args.token_type
    if bpemodel is None:
        bpemodel = asr_train_args.bpemodel

    if token_type is None:
        tokenizer = None
    elif token_type == "bpe":
        if bpemodel is not None:
            tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
        else:
            tokenizer = None
    else:
        tokenizer = build_tokenizer(token_type=token_type)
    converter = TokenIDConverter(token_list=token_list)
    logging.info(f"Text tokenizer: {tokenizer}")

    # 7. Start for-loop
    # FIXME(kamo): The output format should be discussed
    with DatadirWriter(output_dir) as writer:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            with torch.no_grad():
                # a. To device
                batch = to_device(batch, device)

                # b. Forward the encoder
                enc, _ = asr_model.encode(**batch)
                assert len(enc) == batch_size, len(enc)

                # c. Pass the encoder output to the beam search
                nbest_hyps = beam_search(
                    x=enc[0], maxlenratio=maxlenratio, minlenratio=minlenratio
                )
                nbest_hyps = nbest_hyps[:nbest]

            # Only batch_size==1 is supported
            key = keys[0]
            for n in range(1, nbest + 1):
                hyp = nbest_hyps[n - 1]
                assert isinstance(hyp, Hypothesis), type(hyp)

                # Remove sos/eos and get the result
                token_int = hyp.yseq[1:-1].tolist()

                # Remove the blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != 0, token_int))

                # Change integer ids to tokens
                token = converter.ids2tokens(token_int)

                # Create a directory: outdir/{n}best_recog
                ibest_writer = writer[f"{n}best_recog"]

                # Write the result to each file
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["score"][key] = str(hyp.score)

                if tokenizer is not None:
                    text = tokenizer.tokens2text(token)
                    ibest_writer["text"][key] = text
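# A minimal usage sketch for `inference` (not part of the original file).
# Every path below is a hypothetical placeholder; the remaining values
# mirror common espnet2 decoding settings.
def _example_decode():
    inference(
        output_dir="exp/decode",  # hypothetical output directory
        maxlenratio=0.0,
        minlenratio=0.0,
        batch_size=1,
        dtype="float32",
        beam_size=20,
        ngpu=0,
        seed=0,
        ctc_weight=0.5,
        lm_weight=0.0,
        penalty=0.0,
        nbest=1,
        num_workers=1,
        log_level="INFO",
        data_path_and_name_and_type=[("dump/raw/test/wav.scp", "speech", "sound")],
        key_file=None,
        asr_train_config="exp/asr_train/config.yaml",  # hypothetical
        asr_model_file="exp/asr_train/valid.acc.ave.pth",  # hypothetical
        lm_train_config=None,
        lm_file=None,
        word_lm_train_config=None,
        word_lm_file=None,
        blank_symbol="<blank>",
        token_type=None,
        bpemodel=None,
        allow_variable_data_keys=False,
    )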
from itertools import groupby
from typing import List

import numpy

# Assumes the espnet2 Mask-CTC model definition.
from espnet2.asr.maskctc_model import MaskCTCModel


class MaskCTCInference(torch.nn.Module):
    """Mask-CTC-based non-autoregressive inference."""

    def __init__(
        self,
        asr_model: MaskCTCModel,
        n_iterations: int,
        threshold_probability: float,
    ):
        """Initialize Mask-CTC inference."""
        super().__init__()
        self.ctc = asr_model.ctc
        self.mlm = asr_model.decoder
        self.mask_token = asr_model.mask_token
        self.n_iterations = n_iterations
        self.threshold_probability = threshold_probability
        self.converter = TokenIDConverter(token_list=asr_model.token_list)

    def ids2text(self, ids: List[int]):
        text = "".join(self.converter.ids2tokens(ids))
        return text.replace("<mask>", "_").replace("<space>", " ")

    def forward(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Perform Mask-CTC inference."""
        # Greedy CTC outputs
        enc_out = enc_out.unsqueeze(0)
        ctc_probs, ctc_ids = torch.exp(self.ctc.log_softmax(enc_out)).max(dim=-1)
        y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])])
        y_idx = torch.nonzero(y_hat != 0).squeeze(-1)

        logging.info("ctc:{}".format(self.ids2text(y_hat[y_idx].tolist())))

        # Calculate token-level CTC probabilities by taking
        # the maximum probability of consecutive frames with
        # the same CTC symbols
        probs_hat = []
        cnt = 0
        for i, y in enumerate(y_hat.tolist()):
            probs_hat.append(-1)
            while cnt < ctc_ids.shape[1] and y == ctc_ids[0][cnt]:
                if probs_hat[i] < ctc_probs[0][cnt]:
                    probs_hat[i] = ctc_probs[0][cnt].item()
                cnt += 1
        probs_hat = torch.from_numpy(numpy.array(probs_hat))

        # Mask CTC outputs based on CTC probabilities
        p_thres = self.threshold_probability
        mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1)
        confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
        mask_num = len(mask_idx)

        y_in = torch.zeros(1, len(y_idx), dtype=torch.long) + self.mask_token
        y_in[0][confident_idx] = y_hat[y_idx][confident_idx]

        logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

        # Iterative decoding
        if not mask_num == 0:
            K = self.n_iterations
            num_iter = K if mask_num >= K and K > 0 else mask_num

            for t in range(num_iter - 1):
                pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in, [y_in.size(1)])
                pred_score, pred_id = pred[0][mask_idx].max(dim=-1)
                cand = torch.topk(pred_score, mask_num // num_iter, -1)[1]
                y_in[0][mask_idx[cand]] = pred_id[cand]
                mask_idx = torch.nonzero(y_in[0] == self.mask_token).squeeze(-1)

                logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

            # Predict leftover masks (|masks| < mask_num // num_iter)
            pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in, [y_in.size(1)])
            y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1)

            logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

        # Pad with mask tokens to ensure compatibility with sos/eos tokens
        yseq = torch.tensor(
            [self.mask_token] + y_in.tolist()[0] + [self.mask_token],
            device=y_in.device,
        )

        return Hypothesis(yseq=yseq)
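# A minimal usage sketch (not part of the original file): decode one
# utterance with Mask-CTC. `model` must be a trained MaskCTCModel and
# `speech` a 1-D waveform tensor; both names are placeholders.
def _example_maskctc_decode(model: MaskCTCModel, speech: torch.Tensor) -> str:
    maskctc = MaskCTCInference(
        asr_model=model,
        n_iterations=10,
        threshold_probability=0.99,
    )
    lengths = torch.tensor([speech.size(0)], dtype=torch.long)
    enc, _ = model.encode(speech.unsqueeze(0), lengths)
    hyp = maskctc(enc[0])  # forward() adds the batch dimension itself
    # Strip the mask tokens that stand in for sos/eos
    return maskctc.ids2text(hyp.yseq[1:-1].tolist())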
import numpy as np
import pytest


def test_ids2tokens():
    converter = TokenIDConverter(["a", "b", "c", "<unk>"])
    assert converter.ids2tokens([0, 1, 2]) == ["a", "b", "c"]
def test_input_2dim_array():
    converter = TokenIDConverter(["a", "b", "c", "<unk>"])
    with pytest.raises(ValueError):
        converter.ids2tokens(np.random.randn(2, 2))
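# An additional test sketch; it assumes espnet2's TokenIDConverter maps
# out-of-vocabulary tokens to the id of its unk symbol ("<unk>", id 3 here).
def test_tokens2ids_unk():
    converter = TokenIDConverter(["a", "b", "c", "<unk>"])
    assert converter.tokens2ids(["a", "x"]) == [0, 3]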
import copy
import io
import os
import shutil
import zipfile
from os.path import abspath, dirname, exists, isdir, isfile, join
from pathlib import Path
from typing import Any, Dict, Tuple, Union

import librosa
import numpy as np
import torch
import yaml

# Import paths below assume the espnet/espnet2 package layout; `Result` is
# the result container defined elsewhere in this repository.
from espnet.nets.beam_search import BeamSearch
from espnet.nets.scorers.ctc import CTCPrefixScorer
from espnet.nets.scorers.length_bonus import LengthBonus
from espnet2.main_funcs.calculate_all_attentions import calculate_all_attentions
from espnet2.tasks.asr import ASRTask
from espnet2.text.build_tokenizer import build_tokenizer
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.torch_utils.device_funcs import to_device


class ASR(object):
    def __init__(
        self,
        zip_model_file: Union[Path, str],
    ) -> None:
        self.zip_model_file = abspath(zip_model_file)
        self.device = 'cpu'
        self.model = None
        self.beam_search = None
        self.tokenizer = None
        self.converter = None
        self.global_cmvn = None
        self.extract_zip_model_file(self.zip_model_file)

    def extract_zip_model_file(self, zip_model_file: str) -> Dict[str, Any]:
        """Extracts the data from a zip file containing the model state and
        its configuration.

        Args:
            zip_model_file (str): zip file of the model produced by the
                training scripts

        Raises:
            ValueError: if the file is not a valid zip file
            FileNotFoundError: if the zip file does not contain the
                required files

        Returns:
            Dict[str, Any]: dictionary of the .yaml file used during
                training, needed to load the model correctly
        """
        print("Unzipping model")
        if not zipfile.is_zipfile(zip_model_file):
            raise ValueError(f"File {zip_model_file} is not a zipfile")
        extract_dir = dirname(zip_model_file)
        zipfile.ZipFile(zip_model_file).extractall(extract_dir)

        # The training scripts are expected to pack these two entries
        check = ['exp', 'meta.yaml']
        if not all(exists(join(extract_dir, f)) for f in check):
            raise FileNotFoundError(f"Zip file must contain {check}")

        print("Load yaml file")
        with open(join(extract_dir, 'meta.yaml')) as f:
            meta = yaml.load(f, Loader=yaml.FullLoader)

        # Paths stored inside meta.yaml are used as-is
        model_stats_file = meta['files']['asr_model_file']
        asr_model_config_file = meta['yaml_files']['asr_train_config']

        self.model_config = {}
        with open(asr_model_config_file) as f:
            self.model_config = yaml.load(f, Loader=yaml.FullLoader)

        try:
            self.global_cmvn = self.model_config['normalize_conf']['stats_file']
        except KeyError:
            self.global_cmvn = None

        print(f'Loading model config from {asr_model_config_file}')
        print(f'Loading model state from {model_stats_file}')

        # Build model
        print('Building model')
        self.model, _ = ASRTask.build_model_from_file(
            asr_model_config_file, model_stats_file, self.device
        )
        self.model.to(dtype=torch.float32).eval()

        self.build_beam_search()
        self.build_tokenizer()
        return self.model_config

    def build_beam_search(self, ctc_weight: float = 0.4, beam_size: int = 1):
        """Builds the beam_search decoding object.

        This object decodes the embedding vector produced by the encoder
        part of the model, passing it through the network decoders, i.e.
        the CTC module and a Transformer or RNN.

        Since:

            Loss = (1 - λ) * DecoderLoss + λ * CTCLoss

        if ctc_weight=1, only the CTC module is used for decoding.

        Args:
            ctc_weight (float, optional): weight given to the CTC module of
                the network. Defaults to 0.4.
            beam_size (int, optional): width of the search beam during
                decoding. Defaults to 1.
        """
        scorers = {}
        ctc = CTCPrefixScorer(ctc=self.model.ctc, eos=self.model.eos)
        token_list = self.model.token_list
        scorers.update(
            decoder=self.model.decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # Weights for each part of the decoding.
        # No language model (lm) is used here, but the key is still
        # required by the BeamSearch object.
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=1.0,
            length_bonus=0.0,
        )

        # Create the beam_search object
        self.beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=self.model.sos,
            eos=self.model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        self.beam_search.to(device=self.device, dtype=torch.float32).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=self.device, dtype=torch.float32).eval()

    def build_tokenizer(self):
        """Creates a tokenizer object to convert the integer tokens back into
        the corresponding character vocabulary.

        If the model has a BPE tokenization model, it is used; otherwise only
        the character list from the configuration file is used.
        """
        token_type = self.model_config['token_type']
        if token_type == 'bpe':
            bpemodel = self.model_config['bpemodel']
            self.tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
        else:
            self.tokenizer = build_tokenizer(token_type=token_type)
        self.converter = TokenIDConverter(token_list=self.model.token_list)

    def get_layers(self) -> Dict[str, Dict[str, torch.Size]]:
        """Returns the named layers and their shapes for every module of the
        network.

        The modules are:
            Encoder: RNN, VGGRNN, TransformerEncoder
            Decoder: RNN, TransformerDecoder
            CTC

        Returns:
            Dict[str, Dict[str, torch.Size]]: dictionary of each module with
                its layers and shapes
        """
        r = {}
        for name in ('frontend', 'specaug', 'normalize', 'encoder', 'decoder', 'ctc'):
            module = getattr(self.model, name)
            r[name] = {k: v.shape for k, v in module.state_dict().items()}
        return r

    def frontend(
        self,
        audiofile: Union[Path, str, bytes],
        normalize: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Runs the model frontend, turning the audio samples into log mel
        spectrogram features.

        Args:
            audiofile (Union[Path, str, bytes]): audio file

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: features, feature-vector length
        """
        if isinstance(audiofile, (str, Path)):
            audio_samples, rate = librosa.load(audiofile, sr=16000)
        elif isinstance(audiofile, bytes):
            audio_samples, rate = librosa.load(io.BytesIO(audiofile), sr=16000)
        else:
            raise ValueError("Failed to load audio file")

        if isinstance(audio_samples, np.ndarray):
            audio_samples = torch.tensor(audio_samples)
        audio_samples = audio_samples.unsqueeze(0).to(torch.float32)
        lengths = audio_samples.new_full(
            [1], dtype=torch.long, fill_value=audio_samples.size(1)
        )
        features, features_length = self.model.frontend(audio_samples, lengths)
        if normalize:
            features, features_length = self.model.normalize(features, features_length)
        return features, features_length

    def specaug(
        self, features: torch.Tensor, features_length: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Runs the specaug data-augmentation module. Useful for
        visualization only; it is used during training, not inference.

        Args:
            features (torch.Tensor): features
            features_length (torch.Tensor): feature-vector length

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: features with time, frequency
                and distortion masks applied, and the vector lengths
        """
        return self.model.specaug(features, features_length)

    def __del__(self) -> None:
        """Removes the temporary files."""
        for f in ['exp', 'meta.yaml']:
            print(f"Removing {f}")
            ff = join(dirname(self.zip_model_file), f)
            if exists(ff):
                if isdir(ff):
                    shutil.rmtree(ff)
                elif isfile(ff):
                    os.remove(ff)
                else:
                    raise ValueError("Failed to remove temporary files")

    @torch.no_grad()
    def recognize(self, audiofile: Union[Path, str, bytes]) -> Result:
        result = Result()

        if isinstance(audiofile, (str, Path)):
            audio_samples, rate = librosa.load(audiofile, sr=16000)
        elif isinstance(audiofile, bytes):
            audio_samples, rate = librosa.load(io.BytesIO(audiofile), sr=16000)
        else:
            raise ValueError("Failed to load audio file")
        result.audio_samples = copy.deepcopy(audio_samples)

        # The model input must be a torch.Tensor
        if isinstance(audio_samples, np.ndarray):
            audio_samples = torch.tensor(audio_samples)
        audio_samples = audio_samples.unsqueeze(0).to(torch.float32)
        lengths = audio_samples.new_full(
            [1], dtype=torch.long, fill_value=audio_samples.size(1)
        )
        batch = {"speech": audio_samples, "speech_lengths": lengths}
        batch = to_device(batch, device=self.device)

        # Model encoder
        enc, _ = self.model.encode(**batch)
        # Model decoder
        nbest_hyps = self.beam_search(x=enc[0])
        # Keep only the best hypothesis
        best_hyp = nbest_hyps[0]

        # Convert the training token ids to text
        token_int = best_hyp.yseq[1:-1].tolist()
        token_int = list(filter(lambda x: x != 0, token_int))
        token = self.converter.ids2tokens(token_int)
        text = self.tokenizer.tokens2text(token)

        # Fill in the result object
        result.text = text
        result.encoded_vector = enc[0]  # [0] drops the batch dimension

        # Compute all the attention matrices
        text_tensor = torch.Tensor(token_int).unsqueeze(0).to(torch.long)
        batch["text"] = text_tensor
        batch["text_lengths"] = text_tensor.new_full(
            [1], dtype=torch.long, fill_value=text_tensor.size(1)
        )
        result.attention_weights = calculate_all_attentions(self.model, batch)
        result.tokens_txt = token

        # CTC posteriors (`enc` already has a batch dimension)
        logp = self.model.ctc.log_softmax(enc)[0]
        result.ctc_posteriors = logp.exp_().numpy()
        result.tokens_int = best_hyp.yseq
        result.mel_features, _ = self.frontend(audiofile, normalize=False)
        return result

    def __call__(self, input: Union[Path, str, bytes]) -> Result:
        return self.recognize(input)
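# A minimal usage sketch; "model.zip" and "utterance.wav" are hypothetical
# placeholders for a packed model and an input recording.
if __name__ == "__main__":
    asr = ASR("model.zip")
    result = asr("utterance.wav")  # equivalent to asr.recognize("utterance.wav")
    print(result.text)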