Example #1
def test_encoder(input_conf, body_conf, main_conf):
    input_size = 8

    encoder = Encoder(input_size, body_conf, input_conf=input_conf, main_conf=main_conf)

    sequence = torch.randn(2, 30, input_size, requires_grad=True)
    sequence_len = torch.tensor([30, 18], dtype=torch.long)

    _ = encoder(sequence, sequence_len)
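
`input_conf`, `body_conf`, and `main_conf` arrive as pytest parameters. A minimal, hypothetical parametrization, reusing the conformer block layout seen in the other examples (the exact keys accepted by `Encoder` should be checked against the ESPnet2 source), could look like this:

import pytest

# Hypothetical parametrization: one conformer body block, default input block,
# empty main configuration. Not an exhaustive list of supported options.
@pytest.mark.parametrize(
    "input_conf, body_conf, main_conf",
    [
        (
            {},
            [
                {
                    "block_type": "conformer",
                    "hidden_size": 4,
                    "linear_size": 2,
                    "conv_mod_kernel_size": 3,
                }
            ],
            {},
        ),
    ],
)
def test_encoder(input_conf, body_conf, main_conf):
    ...  # body as shown above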
Example #2
def test_wrong_subsampling_factor():
    input_conf = {"block_type": "conv2d", "subsampling_factor": 8}
    body_conf = [
        {
            "block_type": "conformer",
            "hidden_size": 4,
            "linear_size": 2,
            "conv_mod_kernel_size": 1,
        }
    ]

    with pytest.raises(ValueError):
        _ = Encoder(8, body_conf, input_conf=input_conf)
Example #3
def test_model_training(enc_params, enc_gen_params, dec_params,
                        joint_net_params, main_params, stats_file):
    batch_size = 2
    input_size = 10

    token_list = ["<blank>", "a", "b", "c", "<space>"]
    vocab_size = len(token_list)

    encoder = Encoder(input_size, enc_params, main_conf=enc_gen_params)
    decoder = get_decoder(vocab_size, dec_params)

    joint_network = JointNetwork(vocab_size, encoder.output_size,
                                 decoder.output_size, **joint_net_params)

    specaug = get_specaug() if main_params.pop("specaug", False) else None

    normalize = main_params.pop("normalize", None)
    if normalize is not None:
        if normalize == "utterance":
            normalize = UtteranceMVN(norm_means=True,
                                     norm_vars=True,
                                     eps=1e-13)
        else:
            normalize = GlobalMVN(stats_file, norm_means=True, norm_vars=True)

    model = ESPnetASRTransducerModel(
        vocab_size,
        token_list,
        frontend=None,
        specaug=specaug,
        normalize=normalize,
        encoder=encoder,
        decoder=decoder,
        joint_network=joint_network,
        **main_params,
    )

    feats, labels, feat_len, label_len = prepare(model, input_size, vocab_size,
                                                 batch_size)

    _ = model(feats, feat_len, labels, label_len)

    if main_params.get("report_cer") or main_params.get("report_wer"):
        # Switch off training mode so the CER/WER reporting branch is also exercised.
        model.training = False

        _ = model(feats, feat_len, labels, label_len)
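
`prepare` is a helper defined in the same test module and not shown here. A rough, hypothetical stand-in that produces what the call above consumes (random features plus label sequences with matching length tensors; the real helper may pad and mask differently) is:

import torch

def prepare(model, input_size, vocab_size, batch_size):
    # Hypothetical helper: the model argument is only kept for signature
    # compatibility with the call sites in these tests.
    n_frames, n_tokens = 30, 6

    feats = torch.randn(batch_size, n_frames, input_size)
    feat_len = torch.tensor([n_frames] * batch_size, dtype=torch.long)

    # Labels avoid index 0, which is reserved for <blank>.
    labels = torch.randint(1, vocab_size, (batch_size, n_tokens), dtype=torch.long)
    label_len = torch.tensor([n_tokens] * batch_size, dtype=torch.long)

    return feats, labels, feat_len, label_len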
Example #4
def test_activation(act_type, act_params):
    batch_size = 2
    input_size = 10

    token_list = ["<blank>", "a", "b", "c", "<space>"]
    vocab_size = len(token_list)

    encoder = Encoder(
        input_size,
        [
            {
                "block_type": "conformer",
                "hidden_size": 8,
                "linear_size": 4,
                "conv_mod_kernel_size": 3,
            }
        ],
        main_conf=act_params,
    )
    decoder = StatelessDecoder(vocab_size, embed_size=4)

    joint_network = JointNetwork(
        vocab_size,
        encoder.output_size,
        decoder.output_size,
        joint_activation_type=act_type,
        **act_params,
    )

    model = ESPnetASRTransducerModel(
        vocab_size,
        token_list,
        frontend=None,
        specaug=None,
        normalize=None,
        encoder=encoder,
        decoder=decoder,
        joint_network=joint_network,
    )

    feats, labels, feat_len, label_len = prepare(
        model, input_size, vocab_size, batch_size
    )

    _ = model(feats, feat_len, labels, label_len)
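
`act_type` and `act_params` are pytest parameters as well. A plausible parametrization is sketched below; the activation names are common ones, but the set actually registered for `joint_activation_type` should be taken from the ESPnet2 source:

import pytest

# Hypothetical values; verify which activations ESPnet2 actually registers.
@pytest.mark.parametrize(
    "act_type, act_params",
    [
        ("tanh", {}),
        ("relu", {}),
        ("swish", {}),
    ],
)
def test_activation(act_type, act_params):
    ...  # body as shown above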
Example #5
def test_too_short_utterance(input_conf, inputs):
    input_size = 20

    body_conf = [
        {
            "block_type": "conformer",
            "hidden_size": 4,
            "linear_size": 2,
            "conv_mod_kernel_size": 3,
        }
    ]

    encoder = Encoder(input_size, body_conf, input_conf=input_conf)

    sequence = torch.randn(len(inputs), inputs[0], input_size, requires_grad=True)
    sequence_len = torch.tensor(inputs, dtype=torch.long)

    with pytest.raises(TooShortUttError):
        _ = encoder(sequence, sequence_len)
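
`inputs` holds the per-utterance lengths, with `inputs[0]` being the batch's longest utterance. A hypothetical parametrization with lengths deliberately below what the subsampling front-end can handle (the exact minimum depends on the input block) might be:

import pytest

# Hypothetical: lengths chosen to be shorter than the front-end's minimum.
@pytest.mark.parametrize(
    "input_conf, inputs",
    [
        ({"block_type": "conv2d", "subsampling_factor": 4}, [4, 3]),
        ({"block_type": "conv2d", "subsampling_factor": 2}, [2, 1]),
    ],
)
def test_too_short_utterance(input_conf, inputs):
    ...  # body as shown above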
Example #6
def test_collect_feats(extract_feats):
    token_list = ["<blank>", "a", "b", "c", "<space>"]
    vocab_size = len(token_list)

    encoder = Encoder(
        20,
        [{
            "block_type": "conformer",
            "hidden_size": 4,
            "linear_size": 4,
            "conv_mod_kernel_size": 3,
        }],
    )
    decoder = StatelessDecoder(vocab_size, embed_size=4)

    joint_network = JointNetwork(vocab_size, encoder.output_size,
                                 decoder.output_size, 8)

    model = ESPnetASRTransducerModel(
        vocab_size,
        token_list,
        frontend=None,
        specaug=None,
        normalize=None,
        encoder=encoder,
        decoder=decoder,
        joint_network=joint_network,
    )
    model.extract_feats_in_collect_stats = extract_feats

    feats_dict = model.collect_feats(
        torch.randn(2, 12),
        torch.tensor([12, 11]),
        torch.randn(2, 8),
        torch.tensor([8, 8]),
    )

    assert feats_dict.keys() == {"feats", "feats_lengths"}
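
`extract_feats` only toggles `extract_feats_in_collect_stats`, and both settings are expected to return the same dictionary keys, so the parametrization is simply:

import pytest

@pytest.mark.parametrize("extract_feats", [True, False])
def test_collect_feats(extract_feats):
    ...  # body as shown above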
Example #7
    def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
        """Required data depending on task mode.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.

        Returns:
            model: ASR Transducer model.

        """
        assert check_argument_types()

        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwrite token_list so that the saved configuration stays portable.
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size }")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Encoder
        encoder = Encoder(input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size

        # 5. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size,
            **args.decoder_conf,
        )
        decoder_output_size = decoder.output_size

        # 6. Joint Network
        joint_network = JointNetwork(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **args.joint_network_conf,
        )

        # 7. Build model
        model = ESPnetASRTransducerModel(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            encoder=encoder,
            decoder=decoder,
            joint_network=joint_network,
            **args.model_conf,
        )

        # 8. Initialize
        if args.init is not None:
            raise NotImplementedError(
                "Currently not supported. "
                "Initialization will be reworked in the near future."
            )

        assert check_return_type(model)

        return model
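
In practice `build_model` is invoked by the task's training entry point with an `args` namespace built from the experiment configuration. Purely for illustration, a hand-assembled `argparse.Namespace` covering the attributes read above might look like the sketch below; the field names and choice strings (e.g. the stateless decoder) are assumptions to be checked against `ASRTransducerTask.get_parser`.

import argparse

from espnet2.tasks.asr_transducer import ASRTransducerTask  # assumed import path

# Hypothetical, minimal namespace; only the attributes accessed in
# build_model above are filled in, and the choice names are assumptions.
args = argparse.Namespace(
    token_list=["<blank>", "a", "b", "c", "<space>"],
    input_size=80,  # features come from the data loader, so no frontend is built
    frontend=None, frontend_conf={},
    specaug=None, specaug_conf={},
    normalize=None, normalize_conf={},
    encoder_conf={
        "body_conf": [
            {
                "block_type": "conformer",
                "hidden_size": 8,
                "linear_size": 4,
                "conv_mod_kernel_size": 3,
            }
        ]
    },
    decoder="stateless",
    decoder_conf={"embed_size": 4},
    joint_network_conf={},
    model_conf={},
    init=None,
)

model = ASRTransducerTask.build_model(args)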
Example #8
def test_wrong_block_io(body_conf):
    with pytest.raises(ValueError):
        _ = Encoder(8, body_conf)
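
The parametrized `body_conf` values here chain blocks whose output and input dimensions do not line up. A hypothetical offending configuration (the actual validation rules live in the ESPnet2 encoder builder) could be:

import pytest

# Hypothetical: two conformer blocks with mismatched hidden sizes, so the
# first block's output size does not match the second block's input size.
@pytest.mark.parametrize(
    "body_conf",
    [
        [
            {
                "block_type": "conformer",
                "hidden_size": 4,
                "linear_size": 2,
                "conv_mod_kernel_size": 1,
            },
            {
                "block_type": "conformer",
                "hidden_size": 8,
                "linear_size": 2,
                "conv_mod_kernel_size": 1,
            },
        ],
    ],
)
def test_wrong_block_io(body_conf):
    ...  # body as shown above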
Example #9
def test_block_type(input_conf, body_conf):
    with pytest.raises(ValueError):
        _ = Encoder(8, body_conf, input_conf=input_conf)
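
Here the configurations carry block types the encoder does not recognize. An illustrative, hypothetical parametrization mirroring the config shapes used throughout these examples:

import pytest

# Hypothetical: "foo" is not a registered block type, so constructing the
# Encoder is expected to fail with a ValueError.
@pytest.mark.parametrize(
    "input_conf, body_conf",
    [
        (
            {},
            [{"block_type": "foo", "hidden_size": 4}],
        ),
    ],
)
def test_block_type(input_conf, body_conf):
    ...  # body as shown above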