Example No. 1
def inference(
    output_dir: str,
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    asr_train_config: str,
    asr_model_file: str,
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    word_lm_train_config: Optional[str],
    word_lm_file: Optional[str],
    blank_symbol: str,
    token_type: Optional[str],
    bpemodel: Optional[str],
    allow_variable_data_keys: bool,
):
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build ASR model
    scorers = {}
    asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device)
    asr_model.eval()

    decoder = asr_model.decoder
    ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
    token_list = asr_model.token_list
    scorers.update(
        decoder=decoder,
        ctc=ctc,
        length_bonus=LengthBonus(len(token_list)),
    )

    # 3. Build Language model
    if lm_train_config is not None:
        lm, lm_train_args = LMTask.build_model_from_file(
            lm_train_config, lm_file, device)
        scorers["lm"] = lm.lm

    # 4. Build BeamSearch object
    weights = dict(
        decoder=1.0 - ctc_weight,
        ctc=ctc_weight,
        lm=lm_weight,
        length_bonus=penalty,
    )
    beam_search = BeamSearch(
        beam_size=beam_size,
        weights=weights,
        scorers=scorers,
        sos=asr_model.sos,
        eos=asr_model.eos,
        vocab_size=len(token_list),
        token_list=token_list,
    )
    beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
    for scorer in scorers.values():
        if isinstance(scorer, torch.nn.Module):
            scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
    logging.info(f"Beam_search: {beam_search}")
    logging.info(f"Decoding device={device}, dtype={dtype}")

    # 5. Build data-iterator
    loader = ASRTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=ASRTask.build_preprocess_fn(asr_train_args, False),
        collate_fn=ASRTask.build_collate_fn(asr_train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 6. [Optional] Build text converter: e.g. BPE symbols -> text
    if token_type is None:
        token_type = asr_train_args.token_type
    if bpemodel is None:
        bpemodel = asr_train_args.bpemodel

    if token_type is None:
        tokenizer = None
    elif token_type == "bpe":
        if bpemodel is not None:
            tokenizer = build_tokenizer(token_type=token_type,
                                        bpemodel=bpemodel)
        else:
            tokenizer = None
    else:
        tokenizer = build_tokenizer(token_type=token_type)
    converter = TokenIDConverter(token_list=token_list)
    logging.info(f"Text tokenizer: {tokenizer}")

    # 7. Start the decoding loop
    # FIXME(kamo): The output format needs further discussion
    with DatadirWriter(output_dir) as writer:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            with torch.no_grad():
                # a. To device
                batch = to_device(batch, device)

                # b. Forward Encoder
                enc, _ = asr_model.encode(**batch)
                assert len(enc) == batch_size, len(enc)

                # c. Pass the encoder output to the beam search
                nbest_hyps = beam_search(x=enc[0],
                                         maxlenratio=maxlenratio,
                                         minlenratio=minlenratio)
                nbest_hyps = nbest_hyps[:nbest]

            # Only supporting batch_size==1
            key = keys[0]
            for n, hyp in enumerate(nbest_hyps, 1):
                assert isinstance(hyp, Hypothesis), type(hyp)

                # remove sos/eos and get results
                token_int = hyp.yseq[1:-1].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != 0, token_int))

                # Change integer-ids to tokens
                token = converter.ids2tokens(token_int)

                # Create a directory: outdir/{n}best_recog
                ibest_writer = writer[f"{n}best_recog"]

                # Write the results to each file
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["score"][key] = str(hyp.score)

                if tokenizer is not None:
                    text = tokenizer.tokens2text(token)
                    ibest_writer["text"][key] = text
Example No. 2
class MaskCTCInference(torch.nn.Module):
    """Mask-CTC-based non-autoregressive inference"""
    def __init__(
        self,
        asr_model: MaskCTCModel,
        n_iterations: int,
        threshold_probability: float,
    ):
        """Initialize Mask-CTC inference"""
        super().__init__()
        self.ctc = asr_model.ctc
        self.mlm = asr_model.decoder
        self.mask_token = asr_model.mask_token
        self.n_iterations = n_iterations
        self.threshold_probability = threshold_probability
        self.converter = TokenIDConverter(token_list=asr_model.token_list)

    def ids2text(self, ids: List[int]):
        text = "".join(self.converter.ids2tokens(ids))
        return text.replace("<mask>", "_").replace("<space>", " ")

    def forward(self, enc_out: torch.Tensor) -> Hypothesis:
        """Perform Mask-CTC inference"""
        # greedy ctc outputs
        enc_out = enc_out.unsqueeze(0)
        ctc_probs, ctc_ids = torch.exp(
            self.ctc.log_softmax(enc_out)).max(dim=-1)
        y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])])
        y_idx = torch.nonzero(y_hat != 0).squeeze(-1)

        logging.info("ctc:{}".format(self.ids2text(y_hat[y_idx].tolist())))

        # calculate token-level ctc probabilities by taking
        # the maximum probability of consecutive frames with
        # the same ctc symbols
        probs_hat = []
        cnt = 0
        for i, y in enumerate(y_hat.tolist()):
            probs_hat.append(-1)
            while cnt < ctc_ids.shape[1] and y == ctc_ids[0][cnt]:
                if probs_hat[i] < ctc_probs[0][cnt]:
                    probs_hat[i] = ctc_probs[0][cnt].item()
                cnt += 1
        probs_hat = torch.from_numpy(numpy.array(probs_hat))

        # mask ctc outputs based on ctc probabilities
        p_thres = self.threshold_probability
        mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1)
        confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
        mask_num = len(mask_idx)

        y_in = torch.zeros(1, len(y_idx), dtype=torch.long) + self.mask_token
        y_in[0][confident_idx] = y_hat[y_idx][confident_idx]

        logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

        # iterative decoding
        if mask_num > 0:
            K = self.n_iterations
            num_iter = K if mask_num >= K and K > 0 else mask_num

            for t in range(num_iter - 1):
                pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in,
                                   [y_in.size(1)])
                pred_score, pred_id = pred[0][mask_idx].max(dim=-1)
                cand = torch.topk(pred_score, mask_num // num_iter, -1)[1]
                y_in[0][mask_idx[cand]] = pred_id[cand]
                mask_idx = torch.nonzero(
                    y_in[0] == self.mask_token).squeeze(-1)

                logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

            # predict leftover masks (|masks| < mask_num // num_iter)
            pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in,
                               [y_in.size(1)])
            y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1)

            logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

        # pad with mask tokens to ensure compatibility with sos/eos tokens
        yseq = torch.tensor([self.mask_token] + y_in.tolist()[0] +
                            [self.mask_token],
                            device=y_in.device)

        return Hypothesis(yseq=yseq)
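A minimal usage sketch, assuming asr_model is a trained MaskCTCModel and enc_out is a single utterance's encoder output of shape (T, D); the values are illustrative:

mask_ctc = MaskCTCInference(
    asr_model=asr_model,            # assumed: a trained MaskCTCModel instance
    n_iterations=10,                # K iterations of mask prediction
    threshold_probability=0.99,     # CTC tokens below this confidence are masked
)
with torch.no_grad():
    hyp = mask_ctc(enc_out)                 # Hypothesis padded with mask tokens at both ends
    token_int = hyp.yseq[1:-1].tolist()     # strip the sos/eos-compatible padding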
Example No. 3
def test_idstokens():
    converter = TokenIDConverter(["a", "b", "c", "<unk>"])
    assert converter.ids2tokens([0, 1, 2]) == ["a", "b", "c"]
Example No. 4
def test_input_2dim_array():
    converter = TokenIDConverter(["a", "b", "c", "<unk>"])
    with pytest.raises(ValueError):
        converter.ids2tokens(np.random.randn(2, 2))
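For reference, a short sketch of the behavior these two tests exercise, assuming the ESPnet2 import path for TokenIDConverter:

import numpy as np
from espnet2.text.token_id_converter import TokenIDConverter

converter = TokenIDConverter(["a", "b", "c", "<unk>"])
print(converter.ids2tokens([2, 0, 1]))         # ['c', 'a', 'b']
print(converter.tokens2ids(["b", "oov"]))      # out-of-vocabulary tokens map to <unk>: [1, 3]
print(converter.ids2tokens(np.array([0, 2])))  # 1-D arrays are accepted: ['a', 'c']
# converter.ids2tokens(np.random.randn(2, 2)) raises ValueError, as the test above checks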
Example No. 5
class ASR(object):
    def __init__(
        self,
        zip_model_file: Union[Path, str],
    ) -> None:

        self.zip_model_file = abspath(zip_model_file)
        self.device = 'cpu'
        self.model = None
        self.beam_search = None
        self.tokenizer = None
        self.converter = None
        self.global_cmvn = None
        self.extract_zip_model_file(self.zip_model_file)

    def extract_zip_model_file(self, zip_model_file: str) -> None:
        """Extracts the data from a zip file containing the model state file and configurations.

        The .yaml file used during training is loaded into ``self.model_config``
        so the model can be rebuilt correctly.

        Args:
            zip_model_file (str): ZipFile of the model produced by the training scripts

        Raises:
            ValueError: If the file is not a valid zip file
            FileNotFoundError: If the zip file does not contain the required files
        """
        print("Unzipping model")
        if not zipfile.is_zipfile(zip_model_file):
            raise ValueError(f"File {zip_model_file} is not a zipfile")
        zipfile.ZipFile(zip_model_file).extractall(dirname(zip_model_file))

        # check that the required files were actually extracted
        # (like the rest of this method, paths are assumed relative to the current directory)
        check = ['exp', 'meta.yaml']
        if not all(exists(x) for x in check):
            raise FileNotFoundError(f"Zip file must contain {check}")

        print("Load yaml file")
        with open('meta.yaml') as f:
            meta = yaml.load(f, Loader=yaml.FullLoader)

        model_stats_file = meta['files']['asr_model_file']
        asr_model_config_file = meta['yaml_files']['asr_train_config']

        self.model_config = {}
        with open(asr_model_config_file) as f:
            self.model_config = yaml.load(f, Loader=yaml.FullLoader)
            try:
                self.global_cmvn = self.model_config['normalize_conf'][
                    'stats_file']
            except KeyError:
                self.global_cmvn = None

        print(f'Loading model config from {asr_model_config_file}')
        print(f'Loading model state from {model_stats_file}')

        # Build the model
        print('Building model')
        self.model, _ = ASRTask.build_model_from_file(asr_model_config_file,
                                                      model_stats_file,
                                                      self.device)
        self.model.to(dtype=getattr(torch, 'float32')).eval()

        #print("Loading extra modules")
        self.build_beam_search()
        self.build_tokenizer()

    def build_beam_search(self, ctc_weight: float = 0.4, beam_size: int = 1):
        """Constroi o objeto de decodificação beam_search.

        Esse objeto faz a decodificação do vetor de embeddings da saída da parte encoder
        do modelo passando pelos decoders da rede que são o módulo CTC e Transformer ou RNN.

        Como:
        Loss = (1-λ)*DecoderLoss + λ*CTCLoss 
        Se ctc_weight=1 apenas o módulo CTC será usado na decodificação

        Args:
            ctc_weight (float, optional): Peso dado ao módulo CTC da rede. Defaults to 0.4.
            beam_size (int, optional): Tamanho do feixe de busca durante a codificação. Defaults to 1.
        """
        scorers = {}
        ctc = CTCPrefixScorer(ctc=self.model.ctc, eos=self.model.eos)
        token_list = self.model.token_list
        scorers.update(
            decoder=self.model.decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # Weights for each component of the decoding.
        # "lm" refers to a language model; none is used here, but the key is required by the object.
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=1.0,
            length_bonus=0.0,
        )

        # Create the beam_search object
        self.beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=self.model.sos,
            eos=self.model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )

        self.beam_search.to(device=self.device,
                            dtype=getattr(torch, 'float32')).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=self.device, dtype=getattr(torch,
                                                            'float32')).eval()

    def build_tokenizer(self):
        """Cria um objeto tokenizer para conversão dos tokens inteiros para o dicionário
        de caracteres correspondente.

        Caso o modelo possua um modelo BPE de tokenização, ele é utilizado. Se não, apenas a lista
        de caracteres no arquivo de configuração é usada.
        """
        token_type = self.model_config['token_type']
        if token_type == 'bpe':
            bpemodel = self.model_config['bpemodel']
            self.tokenizer = build_tokenizer(token_type=token_type,
                                             bpemodel=bpemodel)
        else:
            self.tokenizer = build_tokenizer(token_type=token_type)

        self.converter = TokenIDConverter(token_list=self.model.token_list)

    def get_layers(self) -> Dict[str, Dict[str, torch.Size]]:
        """Retorna as camadas nomeadas e os respectivos shapes para todos os módulos da rede.

        Os módulos são:
            Encoder: RNN, VGGRNN, TransformerEncoder
            Decoder: RNN, TransformerDecoder
            CTC

        Returns:
            Dict[str, Dict[str, torch.Size]]: Dicionário de cada módulo com seus respectivos layers e shape
        """
        modules = {
            'frontend': self.model.frontend,
            'specaug': self.model.specaug,
            'normalize': self.model.normalize,
            'encoder': self.model.encoder,
            'decoder': self.model.decoder,
            'ctc': self.model.ctc,
        }
        return {
            name: {k: v.shape for k, v in module.state_dict().items()}
            for name, module in modules.items()
        }

    def frontend(self,
                 audiofile: Union[Path, str, bytes],
                 normalize: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Executa o frontend do modelo, transformando as amostras de áudio em parâmetros log mel spectrogram

        Args:
            audiofile (Union[Path, str]): arquivo de áudio

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Parâmetros, Tamanho do vetor de parâmetros
        """
        if isinstance(audiofile, (str, Path)):
            audio_samples, rate = librosa.load(audiofile, sr=16000)
        elif isinstance(audiofile, bytes):
            audio_samples, rate = librosa.core.load(io.BytesIO(audiofile),
                                                    sr=16000)
        else:
            raise ValueError(f"Unsupported audio input: {type(audiofile)}")

        if isinstance(audio_samples, np.ndarray):
            audio_samples = torch.tensor(audio_samples)
        audio_samples = audio_samples.unsqueeze(0).to(getattr(
            torch, 'float32'))
        lengths = audio_samples.new_full([1],
                                         dtype=torch.long,
                                         fill_value=audio_samples.size(1))
        features, features_length = self.model.frontend(audio_samples, lengths)

        if normalize:
            features, features_length = self.model.normalize(
                features, features_length)

        return features, features_length

    def specaug(
            self, features: torch.Tensor, features_length: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Executa o módulo specaug, da parte de 'data augmentation'.
        Útil para visualização apenas. 
        Não é utilizado na inferência, apenas no treinamento.

        Args:
            features (torch.Tensor): Parâmetros
            features_length (torch.Tensor): tamanho do vetor de parâmetros

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Parâmetros com máscaras temporais, em frequência e distoção. Tamanho dos vetores
        """
        return self.model.specaug(features, features_length)

    def __del__(self) -> None:
        """Remove os arquivos temporários
        """
        for f in ['exp', 'meta.yaml']:
            print(f"Removing {f}")
            ff = join(dirname(self.zip_model_file), f)
            if exists(ff):
                if isdir(ff):
                    shutil.rmtree(ff)
                elif isfile(ff):
                    os.remove(ff)
                else:
                    raise ValueError("Error ao remover arquivos temporários")

    @torch.no_grad()
    def recognize(self, audiofile: Union[Path, str, bytes]) -> Result:

        result = Result()

        if isinstance(audiofile, (str, Path)):
            audio_samples, rate = librosa.load(audiofile, sr=16000)
        elif isinstance(audiofile, bytes):
            audio_samples, rate = librosa.core.load(io.BytesIO(audiofile),
                                                    sr=16000)
        else:
            raise ValueError(f"Unsupported audio input: {type(audiofile)}")

        result.audio_samples = copy.deepcopy(audio_samples)

        # the model input must be a torch.Tensor
        if isinstance(audio_samples, np.ndarray):
            audio_samples = torch.tensor(audio_samples)
        audio_samples = audio_samples.unsqueeze(0).to(getattr(
            torch, 'float32'))

        lengths = audio_samples.new_full([1],
                                         dtype=torch.long,
                                         fill_value=audio_samples.size(1))
        batch = {"speech": audio_samples, "speech_lengths": lengths}
        batch = to_device(batch, device=self.device)

        # model encoder
        enc, _ = self.model.encode(**batch)

        # model decoder
        nbest_hyps = self.beam_search(x=enc[0])

        # keep only the best hypothesis
        best_hyps = nbest_hyps[0]

        # convert training token ids to text
        token_int = best_hyps.yseq[1:-1].tolist()
        token_int = list(filter(lambda x: x != 0, token_int))
        token = self.converter.ids2tokens(token_int)
        text = self.tokenizer.tokens2text(token)

        # fill in the result object
        result.text = text
        result.encoded_vector = enc[0]  # [0] removes the batch dimension

        # compute all attention matrices
        text_tensor = torch.Tensor(token_int).unsqueeze(0).to(
            getattr(torch, 'long'))
        batch["text"] = text_tensor
        batch["text_lengths"] = text_tensor.new_full(
            [1], dtype=torch.long, fill_value=text_tensor.size(1))

        result.attention_weights = calculate_all_attentions(self.model, batch)
        result.tokens_txt = token

        # CTC posteriors
        logp = self.model.ctc.log_softmax(enc)[0]  # enc already carries the batch dimension
        result.ctc_posteriors = logp.exp_().numpy()
        result.tokens_int = best_hyps.yseq
        result.mel_features, _ = self.frontend(audiofile, normalize=False)
        return result

    def __call__(self, audiofile: Union[Path, str, bytes]) -> Result:
        return self.recognize(audiofile)
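Finally, a minimal end-to-end usage sketch of this class; the file names are illustrative, and the expected archive layout is inferred from the keys read in extract_zip_model_file:

# Expected (assumed) zip layout, inferred from extract_zip_model_file:
#   meta.yaml  -> provides meta['files']['asr_model_file'] and meta['yaml_files']['asr_train_config']
#   exp/       -> model state and training config referenced by meta.yaml
asr = ASR("model.zip")                     # hypothetical model archive
result = asr("utterance.wav")              # same as asr.recognize(...); raw bytes also accepted
print(result.text)                         # best-hypothesis transcript
print(result.tokens_txt)                   # decoded tokens
features, lengths = asr.frontend("utterance.wav", normalize=False)  # log-mel features for inspection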