Example #1
0
def test_transducer_beam_search(rnn_type, search_params):
    """Check beam search runs end-to-end for a given RNN decoder type.

    Args:
        rnn_type: RNN type for the Transducer decoder (e.g. "lstm").
        search_params: Options forwarded to BeamSearchTransducer. An
            optional "lm" entry overrides the default LM; the string
            "TransformerLM" selects a TransformerLM instead.

    """
    token_list = ["<blank>", "a", "b", "c", "<sos>"]
    vocab_size = len(token_list)
    beam_size = 1 if search_params["search_type"] == "greedy" else 2

    encoder_output_size = 4
    decoder_output_size = 4

    decoder = TransducerDecoder(vocab_size,
                                hidden_size=decoder_output_size,
                                rnn_type=rnn_type)
    joint_net = JointNetwork(vocab_size,
                             encoder_output_size,
                             decoder_output_size,
                             joint_space_size=2)

    # Sentinel keeps the default LM lazy: the previous pop(key, default)
    # form constructed a SequentialRNNLM even when "lm" was supplied.
    _missing = object()
    lm = search_params.pop("lm", _missing)
    if lm is _missing:
        lm = SequentialRNNLM(vocab_size, rnn_type="lstm")
    if isinstance(lm, str) and lm == "TransformerLM":
        lm = TransformerLM(vocab_size, pos_enc=None, unit=10, layer=2)

    beam = BeamSearchTransducer(
        decoder,
        joint_net,
        beam_size=beam_size,
        lm=lm,
        token_list=token_list,
        **search_params,
    )

    enc_out = torch.randn(30, encoder_output_size)

    with torch.no_grad():
        _ = beam(enc_out)
def test_transducer_error_calculator(report_opts):
    """Run the Transducer error calculator on random encoder output."""
    token_list = ["<blank>", "a", "b", "c", "<space>"]
    vocab_size = len(token_list)

    enc_size, dec_size = 4, 4

    decoder = TransducerDecoder(vocab_size, hidden_size=dec_size)
    joint_net = JointNetwork(
        vocab_size,
        enc_size,
        dec_size,
        joint_space_size=2,
    )

    error_calc = ErrorCalculatorTransducer(
        decoder,
        joint_net,
        token_list,
        "<space>",
        "<blank>",
        **report_opts,
    )

    enc_out = torch.randn(4, 30, enc_size)
    target = torch.randint(0, vocab_size, [4, 20], dtype=torch.int32)

    with torch.no_grad():
        _, _ = error_calc(enc_out, target)
Example #3
0
def test_integer_parameters_limits(search_opts):
    """Out-of-range integer search options must raise an AssertionError."""
    num_tokens = 4
    enc_size = 4

    decoder = StatelessDecoder(num_tokens, embed_size=4)
    joint_net = JointNetwork(num_tokens, enc_size, 4, joint_space_size=2)

    with pytest.raises(AssertionError):
        _ = BeamSearchTransducer(decoder, joint_net, **search_opts)
Example #4
0
def test_recombine_hyps():
    """Hypotheses sharing a label sequence merge with log-sum-exp scores."""
    decoder = StatelessDecoder(4, embed_size=4)
    joint_net = JointNetwork(4, 4, 4, joint_space_size=2)
    beam_search = BeamSearchTransducer(decoder, joint_net, 2)

    duplicated = [
        Hypothesis(score=0.0, yseq=[0, 1, 2], dec_state=None),
        Hypothesis(score=12.0, yseq=[0, 1, 2], dec_state=None),
    ]

    merged = beam_search.recombine_hyps(duplicated)

    assert len(merged) == 1
    assert merged[0].score == np.logaddexp(0.0, 12.0)
Example #5
0
def test_model_training(enc_params, enc_gen_params, dec_params,
                        joint_net_params, main_params, stats_file):
    """Forward a full Transducer model in training mode, then re-run in
    eval mode when CER/WER reporting is enabled.

    Args:
        enc_params: Encoder block configurations.
        enc_gen_params: Encoder main configuration.
        dec_params: Decoder configuration.
        joint_net_params: Joint network options.
        main_params: Model-level options ("specaug", "normalize",
            "report_cer", "report_wer", ...). Consumed via pop().
        stats_file: Path to normalization statistics (for GlobalMVN).

    """
    batch_size = 2
    input_size = 10

    token_list = ["<blank>", "a", "b", "c", "<space>"]
    vocab_size = len(token_list)

    encoder = Encoder(input_size, enc_params, main_conf=enc_gen_params)
    decoder = get_decoder(vocab_size, dec_params)

    joint_network = JointNetwork(vocab_size, encoder.output_size,
                                 decoder.output_size, **joint_net_params)

    specaug = get_specaug() if main_params.pop("specaug", False) else None

    normalize = main_params.pop("normalize", None)
    if normalize is not None:
        if normalize == "utterance":
            normalize = UtteranceMVN(norm_means=True,
                                     norm_vars=True,
                                     eps=1e-13)
        else:
            normalize = GlobalMVN(stats_file, norm_means=True, norm_vars=True)

    model = ESPnetASRTransducerModel(
        vocab_size,
        token_list,
        frontend=None,
        specaug=specaug,
        normalize=normalize,
        encoder=encoder,
        decoder=decoder,
        joint_network=joint_network,
        **main_params,
    )

    feats, labels, feat_len, label_len = prepare(model, input_size, vocab_size,
                                                 batch_size)

    _ = model(feats, feat_len, labels, label_len)

    if main_params.get("report_cer") or main_params.get("report_wer"):
        # eval() propagates the training flag to every submodule;
        # assigning `model.training = False` only toggled the top module.
        model.eval()

        _ = model(feats, feat_len, labels, label_len)
Example #6
0
def test_activation(act_type, act_params):
    """Build and forward a model using the given activation configuration."""
    batch_size, input_size = 2, 10

    token_list = ["<blank>", "a", "b", "c", "<space>"]
    vocab_size = len(token_list)

    conformer_block = {
        "block_type": "conformer",
        "hidden_size": 8,
        "linear_size": 4,
        "conv_mod_kernel_size": 3,
    }
    encoder = Encoder(input_size, [conformer_block], main_conf=act_params)
    decoder = StatelessDecoder(vocab_size, embed_size=4)

    joint_network = JointNetwork(
        vocab_size,
        encoder.output_size,
        decoder.output_size,
        joint_activation_type=act_type,
        **act_params,
    )

    model = ESPnetASRTransducerModel(
        vocab_size,
        token_list,
        frontend=None,
        specaug=None,
        normalize=None,
        encoder=encoder,
        decoder=decoder,
        joint_network=joint_network,
    )

    feats, labels, feat_len, label_len = prepare(
        model, input_size, vocab_size, batch_size
    )

    _ = model(feats, feat_len, labels, label_len)
Example #7
0
def test_collect_feats(extract_feats):
    """collect_feats must expose exactly 'feats' and 'feats_lengths'."""
    token_list = ["<blank>", "a", "b", "c", "<space>"]
    vocab_size = len(token_list)

    block_conf = {
        "block_type": "conformer",
        "hidden_size": 4,
        "linear_size": 4,
        "conv_mod_kernel_size": 3,
    }
    encoder = Encoder(20, [block_conf])
    decoder = StatelessDecoder(vocab_size, embed_size=4)

    joint_network = JointNetwork(
        vocab_size, encoder.output_size, decoder.output_size, 8
    )

    model = ESPnetASRTransducerModel(
        vocab_size,
        token_list,
        frontend=None,
        specaug=None,
        normalize=None,
        encoder=encoder,
        decoder=decoder,
        joint_network=joint_network,
    )
    model.extract_feats_in_collect_stats = extract_feats

    speech = torch.randn(2, 12)
    speech_lengths = torch.tensor([12, 11])
    text = torch.randn(2, 8)
    text_lengths = torch.tensor([8, 8])

    feats_dict = model.collect_feats(speech, speech_lengths, text, text_lengths)

    assert feats_dict.keys() == {"feats", "feats_lengths"}
Example #8
0
def test_transducer_beam_search(decoder_class, decoder_opts, search_opts):
    """Check beam search runs end-to-end for a given decoder class.

    Args:
        decoder_class: Transducer decoder class under test.
        decoder_opts: Extra keyword options for the decoder.
        search_opts: Options forwarded to BeamSearchTransducer. An
            optional "lm" entry overrides the default language model.

    """
    token_list = ["<blank>", "a", "b", "c"]
    vocab_size = len(token_list)

    encoder_size = 4

    decoder = decoder_class(vocab_size, embed_size=4, **decoder_opts)
    joint_net = JointNetwork(vocab_size, encoder_size, 4, joint_space_size=2)

    # Sentinel keeps the default LM lazy: the previous pop(key, default)
    # form constructed a SequentialRNNLM even when "lm" was supplied.
    _missing = object()
    lm = search_opts.pop("lm", _missing)
    if lm is _missing:
        lm = SequentialRNNLM(vocab_size, unit=8, nlayers=1, rnn_type="lstm")

    beam = BeamSearchTransducer(
        decoder,
        joint_net,
        beam_size=2,
        lm=lm,
        **search_opts,
    )

    enc_out = torch.randn(30, encoder_size)

    with torch.no_grad():
        _ = beam(enc_out)
Example #9
0
File: asr.py  Project: espnet/espnet
    def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel:
        """Build an ESPnet ASR model from parsed task arguments.

        Assembles frontend, specaug, normalization, (pre/post)encoder,
        decoder (attention or transducer), CTC and the model wrapper,
        then optionally initializes the parameters.

        Args:
            cls: Task class this classmethod is bound to.
            args: Parsed command-line / config arguments.

        Returns:
            model: Assembled (and optionally initialized) ASR model.

        Raises:
            RuntimeError: If ``args.token_list`` is neither str nor list.

        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            # token_list given as a file path: one token per line.
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size }")

        # 1. Frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Features come from the data-loader; clear frontend config
            # so the saved arguments reflect that no frontend is used.
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            # Preencoder output replaces the raw feature size.
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)

        if args.decoder == "transducer":
            decoder = decoder_class(
                vocab_size,
                embed_pad=0,
                **args.decoder_conf,
            )

            # NOTE(review): joint input uses encoder.output_size(), not the
            # postencoder-adjusted encoder_output_size — confirm intended.
            joint_network = JointNetwork(
                vocab_size,
                encoder.output_size(),
                decoder.dunits,
                **args.joint_net_conf,
            )
        else:
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.decoder_conf,
            )

            joint_network = None

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )

        # 9. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            # Older configs may lack args.model; fall back to the default.
            model_class = model_choices.get_class("espnet")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            joint_network=joint_network,
            token_list=token_list,
            **args.model_conf,
        )

        # FIXME(kamo): Should be done in model?
        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
Example #10
0
    def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
        """Build an ASR Transducer model from parsed task arguments.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.

        Return:
            model: ASR Transducer model.

        Raises:
            RuntimeError: If ``args.token_list`` is neither str nor list.
            NotImplementedError: If ``args.init`` is set (parameter
                initialization is not supported yet).

        """
        assert check_argument_types()

        if isinstance(args.token_list, str):
            # token_list given as a file path: one token per line.
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size }")

        # 1. Frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Features come from the data-loader.
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Encoder (output_size is a property here, not a method)
        encoder = Encoder(input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size

        # 5. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size,
            **args.decoder_conf,
        )
        decoder_output_size = decoder.output_size

        # 6. Joint Network
        joint_network = JointNetwork(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **args.joint_network_conf,
        )

        # 7. Build model
        model = ESPnetASRTransducerModel(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            encoder=encoder,
            decoder=decoder,
            joint_network=joint_network,
            **args.model_conf,
        )

        # 8. Initialize
        if args.init is not None:
            raise NotImplementedError(
                "Currently not supported.",
                "Initialization will be reworked in a short future.",
            )

        assert check_return_type(model)

        return model