Example no. 1
def asr_config_file(tmp_path: Path, token_list):
    enc_body_conf = (
        "{'body_conf': [{'block_type': 'conformer',"
        " 'hidden_size': 4, 'linear_size': 4,"
        " 'conv_mod_kernel_size': 3}]}"
    )
    decoder_conf = "{'hidden_size': 4}"
    joint_net_conf = "{'joint_space_size': 4}"

    ASRTransducerTask.main(
        cmd=[
            "--dry_run",
            "true",
            "--output_dir",
            str(tmp_path / "asr"),
            "--token_list",
            str(token_list),
            "--token_type",
            "char",
            "--encoder_conf",
            enc_body_conf,
            "--decoder",
            "rnn",
            "--decoder_conf",
            decoder_conf,
            "--joint_network_conf",
            joint_net_conf,
        ]
    )
    return tmp_path / "asr" / "config.yaml"
Example no. 2
def test_build_model():
    args = get_dummy_namespace()

    _ = ASRTransducerTask.build_model(args)

    with pytest.raises(RuntimeError):
        args.token_list = -1

        _ = ASRTransducerTask.build_model(args)
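
The get_dummy_namespace helper used by several of these tests is not shown in the excerpts. A plausible reconstruction, assuming it derives an argparse.Namespace from the task parser's defaults (the helper body and token values are assumptions, not the project's actual code):

# Hypothetical sketch; the real helper lives in the test module and may differ.
# Assumes: from espnet2.tasks.asr_transducer import ASRTransducerTask
def get_dummy_namespace():
    parser = ASRTransducerTask.get_parser()
    args = parser.parse_args(["--token_type", "char"])
    args.token_list = ["<blank>", "a", "b", "<space>"]  # assumed dummy tokens
    return args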
Example no. 3
def main(cmd=None):
    r"""ASR Transducer training.

    Example:

        % python asr_transducer_train.py asr --print_config \
                --optim adadelta > conf/train_asr.yaml
        % python asr_transducer_train.py \
                --config conf/tuning/transducer/train_rnn_transducer.yaml
    """
    ASRTransducerTask.main(cmd=cmd)
Example no. 4
def test_required_data_names(inference):
    retval = ASRTransducerTask.required_data_names(True, inference)

    if inference:
        assert retval == ("speech",)
    else:
        assert retval == ("speech", "text")
Example no. 5
def test_build_preprocess_fn(use_preprocessor):
    args = get_dummy_namespace()

    new_args = {
        "use_preprocessor": use_preprocessor,
        "bpemodel": None,
        "non_linguistic_symbols": args.token_list,
        "cleaner": None,
        "g2p": None,
    }
    args.__dict__.update(new_args)

    _ = ASRTransducerTask.build_preprocess_fn(args, True)
Example no. 6
def asr_stream_config_file(request, tmp_path: Path, token_list):
    enc_body_conf = (
        "{'body_conf': [{'block_type': 'conformer', 'hidden_size': 4, "
        "'linear_size': 4, 'conv_mod_kernel_size': 3},"
        "{'block_type': 'conv1d', 'kernel_size': 2, 'output_size': 2, "
        "'batch_norm': True, 'relu': True}], "
        "'main_conf': {'dynamic_chunk_training': True',"
        "'short_chunk_size': 1, 'left_chunk_size': 1}}"
    )

    if request.param == "vgg":
        enc_body_conf = enc_body_conf[:-1] + ",'input_conf': {'vgg_like': True}}"

    decoder_conf = "{'hidden_size': 4}"
    joint_net_conf = "{'joint_space_size': 4}"

    ASRTransducerTask.main(
        cmd=[
            "--dry_run",
            "true",
            "--output_dir",
            str(tmp_path / "asr_stream"),
            "--token_list",
            str(token_list),
            "--token_type",
            "char",
            "--encoder_conf",
            enc_body_conf,
            "--decoder",
            "rnn",
            "--decoder_conf",
            decoder_conf,
            "--joint_network_conf",
            joint_net_conf,
        ]
    )
    return tmp_path / "asr_stream" / "config.yaml"
Example no. 7
def test_build_collate_fn():
    args = get_dummy_namespace()

    _ = ASRTransducerTask.build_collate_fn(args, True)
Example no. 8
def test_add_arguments():
    ASRTransducerTask.get_parser()
Example no. 9
    def __init__(
        self,
        asr_train_config: Optional[Union[Path, str]] = None,
        asr_model_file: Optional[Union[Path, str]] = None,
        beam_search_config: Optional[Dict[str, Any]] = None,
        lm_train_config: Optional[Union[Path, str]] = None,
        lm_file: Optional[Union[Path, str]] = None,
        token_type: Optional[str] = None,
        bpemodel: Optional[str] = None,
        device: str = "cpu",
        beam_size: int = 5,
        dtype: str = "float32",
        lm_weight: float = 1.0,
        quantize_asr_model: bool = False,
        quantize_modules: Optional[List[str]] = None,
        quantize_dtype: str = "qint8",
        nbest: int = 1,
        streaming: bool = False,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
        display_partial_hypotheses: bool = False,
    ) -> None:
        assert check_argument_types()

        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
            asr_train_config, asr_model_file, device)

        if quantize_asr_model:
            if quantize_modules is not None:
                if not all([q in ["LSTM", "Linear"]
                            for q in quantize_modules]):
                    raise ValueError(
                        "Only 'Linear' and 'LSTM' modules are currently supported"
                        " by PyTorch and in --quantize_modules")

                q_config = {getattr(torch.nn, q) for q in quantize_modules}
            else:
                q_config = {torch.nn.Linear}

            if quantize_dtype == "float16" and V(torch.__version__) < V("1.5.0"):
                raise ValueError(
                    "float16 dtype for dynamic quantization is not supported with"
                    " torch version < 1.5.0.")
            q_dtype = getattr(torch, quantize_dtype)

            asr_model = torch.quantization.quantize_dynamic(
                asr_model, q_config, dtype=q_dtype).eval()
        else:
            asr_model.to(dtype=getattr(torch, dtype)).eval()

        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device)
            lm_scorer = lm.lm
        else:
            lm_scorer = None

        # Build the BeamSearch object
        if beam_search_config is None:
            beam_search_config = {}

        beam_search = BeamSearchTransducer(
            asr_model.decoder,
            asr_model.joint_network,
            beam_size,
            lm=lm_scorer,
            lm_weight=lm_weight,
            nbest=nbest,
            **beam_search_config,
        )

        token_list = asr_model.token_list

        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type,
                                            bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.device = device
        self.dtype = dtype
        self.nbest = nbest

        self.converter = converter
        self.tokenizer = tokenizer

        self.beam_search = beam_search
        self.streaming = streaming
        self.chunk_size = max(chunk_size, 0)
        self.left_context = max(left_context, 0)
        self.right_context = max(right_context, 0)

        if not streaming or chunk_size == 0:
            self.streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False

        self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512)
        self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128)

        if asr_train_args.frontend_conf.get("win_length", None) is not None:
            self.frontend_window_size = asr_train_args.frontend_conf[
                "win_length"]
        else:
            self.frontend_window_size = self.n_fft

        self.window_size = self.chunk_size + self.right_context
        self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size(
            self.window_size, self.hop_length)

        self.last_chunk_length = (self.asr_model.encoder.embed.min_frame_length
                                  + self.right_context + 1) * self.hop_length

        self.reset_inference_cache()
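
A minimal offline-decoding sketch built on this constructor. The model paths and audio file are placeholders, and the call pattern is inferred from the decoding loop in Example no. 17:

# Sketch only: replace the placeholder paths with a real trained model.
# Assumes: from espnet2.bin.asr_transducer_inference import Speech2Text
import soundfile

speech2text = Speech2Text(
    asr_train_config="exp/asr_train/config.yaml",        # placeholder
    asr_model_file="exp/asr_train/valid.loss.best.pth",  # placeholder
    beam_size=5,
    nbest=1,
)
speech, rate = soundfile.read("utt1.wav")  # placeholder audio
hyps = speech2text(speech)
text, tokens, token_ids, hyp = speech2text.hypotheses_to_results(hyps)[0]
print(text)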
Example no. 10
def test_main_with_no_args():
    with pytest.raises(SystemExit):
        ASRTransducerTask.main(cmd=[])
Example no. 11
def test_main_print_config():
    with pytest.raises(SystemExit):
        ASRTransducerTask.main(cmd=["--print_config"])
Example no. 12
def test_main_help():
    with pytest.raises(SystemExit):
        ASRTransducerTask.main(cmd=["--help"])
Example no. 13
def test_optional_data_names():
    retval = ASRTransducerTask.optional_data_names(True, True)

    assert not retval
Example no. 14
def test_add_arguments_help():
    parser = ASRTransducerTask.get_parser()
    with pytest.raises(SystemExit):
        parser.parse_args(["--help"])
Example no. 15
def test_print_config_and_load_it(tmp_path):
    config_file = tmp_path / "config.yaml"
    with config_file.open("w") as f:
        ASRTransducerTask.print_config(f)
    parser = ASRTransducerTask.get_parser()
    parser.parse_args(["--config", str(config_file)])
Example no. 16
def get_parser():
    """Get parser for ASR Transducer task."""
    parser = ASRTransducerTask.get_parser()
    return parser
Example no. 17
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    lm_weight: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    beam_search_config: Optional[dict],
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    model_tag: Optional[str],
    token_type: Optional[str],
    bpemodel: Optional[str],
    key_file: Optional[str],
    allow_variable_data_keys: bool,
    quantize_asr_model: Optional[bool],
    quantize_modules: Optional[List[str]],
    quantize_dtype: Optional[str],
    streaming: Optional[bool],
    chunk_size: Optional[int],
    left_context: Optional[int],
    right_context: Optional[int],
    display_partial_hypotheses: bool,
) -> None:
    """Transducer model inference.

    Args:
        output_dir: Output directory path.
        batch_size: Batch decoding size.
        dtype: Data type.
        beam_size: Beam size.
        ngpu: Number of GPUs.
        seed: Random number generator seed.
        lm_weight: Weight of language model.
        nbest: Number of final hypotheses.
        num_workers: Number of workers.
        log_level: Verbosity level for logs.
        data_path_and_name_and_type: Path, name and type triplets
            describing the input data.
        asr_train_config: ASR model training config path.
        asr_model_file: ASR model path.
        beam_search_config: Beam search configuration.
        lm_train_config: Language Model training config path.
        lm_file: Language Model path.
        model_tag: Model tag.
        token_type: Type of token units.
        bpemodel: BPE model path.
        key_file: Key file path.
        allow_variable_data_keys: Whether to allow variable data keys.
        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
        quantize_modules: List of module names to apply dynamic quantization on.
        quantize_dtype: Dynamic quantization data type.
        streaming: Whether to perform chunk-by-chunk inference.
        chunk_size: Number of frames in chunk AFTER subsampling.
        left_context: Number of frames in left context AFTER subsampling.
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.

    """
    assert check_argument_types()

    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        beam_search_config=beam_search_config,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        dtype=dtype,
        beam_size=beam_size,
        lm_weight=lm_weight,
        nbest=nbest,
        quantize_asr_model=quantize_asr_model,
        quantize_modules=quantize_modules,
        quantize_dtype=quantize_dtype,
        streaming=streaming,
        chunk_size=chunk_size,
        left_context=left_context,
        right_context=right_context,
    )
    speech2text = Speech2Text.from_pretrained(
        model_tag=model_tag,
        **speech2text_kwargs,
    )

    # 3. Build data-iterator
    loader = ASRTransducerTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=ASRTransducerTask.build_preprocess_fn(
            speech2text.asr_train_args, False),
        collate_fn=ASRTransducerTask.build_collate_fn(
            speech2text.asr_train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Start for-loop
    with DatadirWriter(output_dir) as writer:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys

            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch = {
                k: v[0]
                for k, v in batch.items() if not k.endswith("_lengths")
            }
            assert len(batch.keys()) == 1

            try:
                if speech2text.streaming:
                    speech = batch["speech"]

                    _steps = len(speech) // speech2text._raw_ctx
                    _end = 0

                    for i in range(_steps):
                        _end = (i + 1) * speech2text._raw_ctx

                        speech2text.streaming_decode(
                            speech[i * speech2text._raw_ctx:_end],
                            is_final=False)

                    final_hyps = speech2text.streaming_decode(
                        speech[_end:len(speech)], is_final=True)
                else:
                    final_hyps = speech2text(**batch)

                results = speech2text.hypotheses_to_results(final_hyps)
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
                results = [[" ", ["<space>"], [2], hyp]] * nbest

            key = keys[0]
            for n, (text, token, token_int,
                    hyp) in zip(range(1, nbest + 1), results):
                ibest_writer = writer[f"{n}best_recog"]

                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["score"][key] = str(hyp.score)

                if text is not None:
                    ibest_writer["text"][key] = text