def test_transducer_beam_search(rnn_type, search_params):
    """Smoke-test ``BeamSearchTransducer`` on a random encoder output.

    A tiny decoder and joint network are built over a four-token vocabulary,
    wrapped in a ``BeamSearchTransducer`` and run once over a random
    ``(30, encoder_output_size)`` tensor. Nothing is asserted on the result;
    the test passes when the search completes without raising.

    Args:
        rnn_type: RNN type forwarded to ``TransducerDecoder``.
        search_params: Keyword arguments for ``BeamSearchTransducer``. Must
            contain ``"search_type"``; may contain ``"lm"`` to replace the
            default ``SequentialRNNLM``.
    """
    # Work on a shallow copy: removing "lm" below must not leak into the
    # caller's dict, which may be reused for another invocation (pytest, for
    # instance, hands parametrized values to tests without copying them).
    params = dict(search_params)

    token_list = ["<blank>", "a", "b", "c"]
    vocab_size = len(token_list)
    # Greedy search keeps a single hypothesis; every other search type uses 2.
    beam_size = 1 if params["search_type"] == "greedy" else 2

    encoder_output_size = 4
    decoder_output_size = 4

    decoder = TransducerDecoder(vocab_size,
                                hidden_size=decoder_output_size,
                                rnn_type=rnn_type)
    joint_net = JointNetwork(vocab_size,
                             encoder_output_size,
                             decoder_output_size,
                             joint_space_size=2)

    # Only build the default LM when the parameters do not supply one.
    if "lm" in params:
        lm = params.pop("lm")
    else:
        lm = SequentialRNNLM(vocab_size, rnn_type="lstm")

    beam = BeamSearchTransducer(
        decoder,
        joint_net,
        beam_size=beam_size,
        lm=lm,
        **params,
    )

    enc_out = torch.randn(30, encoder_output_size)

    with torch.no_grad():
        _ = beam(enc_out)
Exemplo n.º 2
0
def test_transducer_beam_search(rnn_type, search_params):
    """Smoke-test ``BeamSearchTransducer`` with an RNN or Transformer LM.

    A tiny decoder and joint network are built over a five-token vocabulary,
    wrapped in a ``BeamSearchTransducer`` and run once over a random
    ``(30, encoder_output_size)`` tensor. Nothing is asserted on the result;
    the test passes when the search completes without raising.

    Args:
        rnn_type: RNN type forwarded to ``TransducerDecoder``.
        search_params: Keyword arguments for ``BeamSearchTransducer``. Must
            contain ``"search_type"``; may contain ``"lm"``, either an LM
            object or the string ``"TransformerLM"`` to request a small
            ``TransformerLM``. Without it a ``SequentialRNNLM`` is used.
    """
    # Work on a shallow copy: removing "lm" below must not leak into the
    # caller's dict. If the same dict is reused for another invocation
    # (pytest passes parametrized values without copying them), a mutated
    # dict would silently drop the "TransformerLM" request on later runs.
    params = dict(search_params)

    token_list = ["<blank>", "a", "b", "c", "<sos>"]
    vocab_size = len(token_list)
    # Greedy search keeps a single hypothesis; every other search type uses 2.
    beam_size = 1 if params["search_type"] == "greedy" else 2

    encoder_output_size = 4
    decoder_output_size = 4

    decoder = TransducerDecoder(vocab_size,
                                hidden_size=decoder_output_size,
                                rnn_type=rnn_type)
    joint_net = JointNetwork(vocab_size,
                             encoder_output_size,
                             decoder_output_size,
                             joint_space_size=2)

    # Only build the default LM when the parameters do not supply one.
    if "lm" in params:
        lm = params.pop("lm")
    else:
        lm = SequentialRNNLM(vocab_size, rnn_type="lstm")
    # A string value is a request to build the named LM here.
    if isinstance(lm, str) and lm == "TransformerLM":
        lm = TransformerLM(vocab_size, pos_enc=None, unit=10, layer=2)

    beam = BeamSearchTransducer(
        decoder,
        joint_net,
        beam_size=beam_size,
        lm=lm,
        token_list=token_list,
        **params,
    )

    enc_out = torch.randn(30, encoder_output_size)

    with torch.no_grad():
        _ = beam(enc_out)
Exemplo n.º 3
0
    def __init__(
        self,
        decoder: AbsDecoder,
        joint_network: torch.nn.Module,
        token_list: List[int],
        sym_space: str,
        sym_blank: str,
        report_cer: bool = False,
        report_wer: bool = False,
    ):
        """Set up the transducer beam search and the reporting options.

        Args:
            decoder: Decoder handed to the beam search and kept on the instance.
            joint_network: Joint network handed to the beam search.
            token_list: Token list, stored as ``self.token_list``.
            sym_space: Space symbol, stored as ``self.space``.
            sym_blank: Blank symbol, stored as ``self.blank``.
            report_cer: Flag stored as ``self.report_cer``.
            report_wer: Flag stored as ``self.report_wer``.
        """
        super().__init__()

        # The search configuration is fixed here and not exposed to callers.
        search_options = dict(
            beam_size=2,
            search_type="default",
            score_norm=False,
        )
        self.beam_search = BeamSearchTransducer(
            decoder=decoder,
            joint_network=joint_network,
            **search_options,
        )
        self.decoder = decoder

        # Symbols used when turning token sequences into text.
        self.token_list, self.space, self.blank = token_list, sym_space, sym_blank

        # Which error rates to report.
        self.report_cer, self.report_wer = report_cer, report_wer
Exemplo n.º 4
0
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        transducer_conf: dict = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        ngram_scorer: str = "full",
        ngram_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        enh_s2t_task: bool = False,
        quantize_asr_model: bool = False,
        quantize_lm: bool = False,
        quantize_modules: List[str] = ["Linear"],
        quantize_dtype: str = "qint8",
    ):
        """Load the ASR model and optional LMs, then build the decoding objects.

        Depending on ``asr_model.use_transducer_decoder`` either a
        ``BeamSearchTransducer`` or a ``BeamSearch`` is created; the other
        attribute (``self.beam_search`` / ``self.beam_search_transducer``) is
        set to ``None``.

        Args:
            asr_train_config: Training config passed to
                ``build_model_from_file``; also passed to
                ``set_streaming_config`` when ``streaming`` is selected.
            asr_model_file: Model file passed to ``build_model_from_file``.
            transducer_conf: Extra keyword arguments for
                ``BeamSearchTransducer``. ``None`` means no extra arguments.
            lm_train_config: When not ``None``, an LM is loaded from it and
                its ``.lm`` attribute is registered as the ``"lm"`` scorer.
            lm_file: LM model file, used together with ``lm_train_config``.
            ngram_scorer: ``"full"`` selects ``NgramFullScorer``; any other
                value selects ``NgramPartScorer``.
            ngram_file: When not ``None``, an n-gram scorer is built from it.
            token_type: Tokenizer type; falls back to
                ``asr_train_args.token_type`` when ``None``.
            bpemodel: BPE model; falls back to ``asr_train_args.bpemodel``
                when ``None``.
            device: Device used to load the models and, on the non-transducer
                path, to place the beam search and the scorers.
            maxlenratio: Stored unchanged as ``self.maxlenratio``.
            minlenratio: Stored unchanged as ``self.minlenratio``.
            batch_size: Only the value ``1`` triggers the attempt to switch to
                a batched beam-search implementation.
            dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``).
            beam_size: Beam size for either beam search.
            ctc_weight: Weight of the ``"ctc"`` scorer; the ``"decoder"``
                scorer gets ``1.0 - ctc_weight``.
            lm_weight: Weight of the ``"lm"`` scorer.
            ngram_weight: Weight of the ``"ngram"`` scorer.
            penalty: Weight of the ``"length_bonus"`` scorer.
            nbest: Stored unchanged as ``self.nbest``.
            streaming: With all-batch scorers and ``batch_size == 1``, selects
                ``BatchBeamSearchOnlineSim`` instead of ``BatchBeamSearch``.
            enh_s2t_task: Load the model through ``EnhS2TTask`` instead of
                ``ASRTask`` and inherit the listed s2t attributes.
            quantize_asr_model: Dynamically quantize the ASR model.
            quantize_lm: Dynamically quantize the LM (when one is loaded).
            quantize_modules: Names of ``torch.nn`` classes to quantize.
            quantize_dtype: Name of the ``torch`` dtype used for quantization.

        Raises:
            ValueError: If quantization is requested with ``"float16"`` on a
                torch version older than 1.5.0.
        """
        assert check_argument_types()

        task = ASRTask if not enh_s2t_task else EnhS2TTask

        if quantize_asr_model or quantize_lm:
            if quantize_dtype == "float16" and torch.__version__ < LooseVersion(
                    "1.5.0"):
                raise ValueError(
                    "float16 dtype for dynamic quantization is not supported with "
                    "torch version < 1.5.0. Switch to qint8 dtype instead.")

        # Resolve the string arguments to the torch objects they name. The
        # list default of ``quantize_modules`` is only read and rebound here,
        # never mutated, so sharing it across calls is harmless.
        quantize_modules = {getattr(torch.nn, q) for q in quantize_modules}
        quantize_dtype = getattr(torch, quantize_dtype)

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = task.build_model_from_file(
            asr_train_config, asr_model_file, device)
        if enh_s2t_task:
            asr_model.inherite_attributes(inherite_s2t_attrs=[
                "ctc",
                "decoder",
                "eos",
                "joint_network",
                "sos",
                "token_list",
                "use_transducer_decoder",
            ])
        torch_dtype = getattr(torch, dtype)
        asr_model.to(dtype=torch_dtype).eval()

        if quantize_asr_model:
            logging.info("Use quantized asr model for decoding.")

            asr_model = torch.quantization.quantize_dynamic(
                asr_model, qconfig_spec=quantize_modules, dtype=quantize_dtype)

        decoder = asr_model.decoder

        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device)

            if quantize_lm:
                logging.info("Use quantized lm for decoding.")

                lm = torch.quantization.quantize_dynamic(
                    lm, qconfig_spec=quantize_modules, dtype=quantize_dtype)

            scorers["lm"] = lm.lm

        # 3. Build ngram model
        if ngram_file is not None:
            # Imported lazily so the n-gram module is only needed when used.
            if ngram_scorer == "full":
                from espnet.nets.scorers.ngram import NgramFullScorer

                ngram = NgramFullScorer(ngram_file, token_list)
            else:
                from espnet.nets.scorers.ngram import NgramPartScorer

                ngram = NgramPartScorer(ngram_file, token_list)
        else:
            ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        if asr_model.use_transducer_decoder:
            beam_search_transducer = BeamSearchTransducer(
                decoder=asr_model.decoder,
                joint_network=asr_model.joint_network,
                beam_size=beam_size,
                # "lm" is only present in scorers when an LM was loaded.
                lm=scorers.get("lm"),
                lm_weight=lm_weight,
                # transducer_conf defaults to None; unpacking None would
                # raise TypeError, so fall back to no extra options.
                **(transducer_conf or {}),
            )
            beam_search = None
        else:
            beam_search_transducer = None

            weights = dict(
                decoder=1.0 - ctc_weight,
                ctc=ctc_weight,
                lm=lm_weight,
                ngram=ngram_weight,
                length_bonus=penalty,
            )
            beam_search = BeamSearch(
                beam_size=beam_size,
                weights=weights,
                scorers=scorers,
                sos=asr_model.sos,
                eos=asr_model.eos,
                vocab_size=len(token_list),
                token_list=token_list,
                pre_beam_score_key=None if ctc_weight == 1.0 else "full",
            )

            # TODO(karita): make all scorers batchfied
            if batch_size == 1:
                non_batch = [
                    k for k, v in beam_search.full_scorers.items()
                    if not isinstance(v, BatchScorerInterface)
                ]
                if not non_batch:
                    # Every full scorer supports batching, so the instance is
                    # switched in place to a batched implementation.
                    if streaming:
                        beam_search.__class__ = BatchBeamSearchOnlineSim
                        beam_search.set_streaming_config(asr_train_config)
                        logging.info(
                            "BatchBeamSearchOnlineSim implementation is selected."
                        )
                    else:
                        beam_search.__class__ = BatchBeamSearch
                        logging.info(
                            "BatchBeamSearch implementation is selected.")
                else:
                    logging.warning(
                        f"As non-batch scorers {non_batch} are found, "
                        f"fall back to non-batch implementation.")

            beam_search.to(device=device, dtype=torch_dtype).eval()
            for scorer in scorers.values():
                if isinstance(scorer, torch.nn.Module):
                    scorer.to(device=device, dtype=torch_dtype).eval()
            logging.info(f"Beam_search: {beam_search}")
            logging.info(f"Decoding device={device}, dtype={dtype}")

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            # A BPE tokenizer can only be built when a BPE model is known.
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type,
                                            bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
Exemplo n.º 5
0
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        transducer_conf: dict = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        ngram_scorer: str = "full",
        ngram_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
    ):
        """Load the ASR model and optional LMs, then build the decoding objects.

        Depending on ``asr_model.use_transducer_decoder`` either a
        ``BeamSearchTransducer`` or a ``BeamSearch`` is created; the other
        attribute (``self.beam_search`` / ``self.beam_search_transducer``) is
        set to ``None``.

        Args:
            asr_train_config: Training config passed to
                ``build_model_from_file``; also passed to
                ``set_streaming_config`` when ``streaming`` is selected.
            asr_model_file: Model file passed to ``build_model_from_file``.
            transducer_conf: Extra keyword arguments for
                ``BeamSearchTransducer``. ``None`` means no extra arguments.
            lm_train_config: When not ``None``, an LM is loaded from it and
                its ``.lm`` attribute is registered as the ``"lm"`` scorer.
            lm_file: LM model file, used together with ``lm_train_config``.
            ngram_scorer: ``"full"`` selects ``NgramFullScorer``; any other
                value selects ``NgramPartScorer``.
            ngram_file: When not ``None``, an n-gram scorer is built from it.
            token_type: Tokenizer type; falls back to
                ``asr_train_args.token_type`` when ``None``.
            bpemodel: BPE model; falls back to ``asr_train_args.bpemodel``
                when ``None``.
            device: Device used to load the models and, on the non-transducer
                path, to place the beam search and the scorers.
            maxlenratio: Stored unchanged as ``self.maxlenratio``.
            minlenratio: Stored unchanged as ``self.minlenratio``.
            batch_size: Only the value ``1`` triggers the attempt to switch to
                a batched beam-search implementation.
            dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``).
            beam_size: Beam size for either beam search.
            ctc_weight: Weight of the ``"ctc"`` scorer; the ``"decoder"``
                scorer gets ``1.0 - ctc_weight``.
            lm_weight: Weight of the ``"lm"`` scorer.
            ngram_weight: Weight of the ``"ngram"`` scorer.
            penalty: Weight of the ``"length_bonus"`` scorer.
            nbest: Stored unchanged as ``self.nbest``.
            streaming: With all-batch scorers and ``batch_size == 1``, selects
                ``BatchBeamSearchOnlineSim`` instead of ``BatchBeamSearch``.
        """
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, device)
        torch_dtype = getattr(torch, dtype)
        asr_model.to(dtype=torch_dtype).eval()

        decoder = asr_model.decoder

        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device)
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        if ngram_file is not None:
            # Imported lazily so the n-gram module is only needed when used.
            if ngram_scorer == "full":
                from espnet.nets.scorers.ngram import NgramFullScorer

                ngram = NgramFullScorer(ngram_file, token_list)
            else:
                from espnet.nets.scorers.ngram import NgramPartScorer

                ngram = NgramPartScorer(ngram_file, token_list)
        else:
            ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        if asr_model.use_transducer_decoder:
            beam_search_transducer = BeamSearchTransducer(
                decoder=asr_model.decoder,
                joint_network=asr_model.joint_network,
                beam_size=beam_size,
                # "lm" is only present in scorers when an LM was loaded.
                lm=scorers.get("lm"),
                lm_weight=lm_weight,
                # transducer_conf defaults to None; unpacking None would
                # raise TypeError, so fall back to no extra options.
                **(transducer_conf or {}),
            )
            beam_search = None
        else:
            beam_search_transducer = None

            weights = dict(
                decoder=1.0 - ctc_weight,
                ctc=ctc_weight,
                lm=lm_weight,
                # Without this entry the "ngram" scorer registered above has
                # no weight and the ngram_weight argument is never used.
                ngram=ngram_weight,
                length_bonus=penalty,
            )
            beam_search = BeamSearch(
                beam_size=beam_size,
                weights=weights,
                scorers=scorers,
                sos=asr_model.sos,
                eos=asr_model.eos,
                vocab_size=len(token_list),
                token_list=token_list,
                pre_beam_score_key=None if ctc_weight == 1.0 else "full",
            )

            # TODO(karita): make all scorers batchfied
            if batch_size == 1:
                non_batch = [
                    k for k, v in beam_search.full_scorers.items()
                    if not isinstance(v, BatchScorerInterface)
                ]
                if not non_batch:
                    # Every full scorer supports batching, so the instance is
                    # switched in place to a batched implementation.
                    if streaming:
                        beam_search.__class__ = BatchBeamSearchOnlineSim
                        beam_search.set_streaming_config(asr_train_config)
                        logging.info(
                            "BatchBeamSearchOnlineSim implementation is selected."
                        )
                    else:
                        beam_search.__class__ = BatchBeamSearch
                        logging.info(
                            "BatchBeamSearch implementation is selected.")
                else:
                    logging.warning(
                        f"As non-batch scorers {non_batch} are found, "
                        f"fall back to non-batch implementation.")

            beam_search.to(device=device, dtype=torch_dtype).eval()
            for scorer in scorers.values():
                if isinstance(scorer, torch.nn.Module):
                    scorer.to(device=device, dtype=torch_dtype).eval()
            logging.info(f"Beam_search: {beam_search}")
            logging.info(f"Decoding device={device}, dtype={dtype}")

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            # A BPE tokenizer can only be built when a BPE model is known.
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type,
                                            bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest