def test_TransformerDecoder_batch_beam_search_online(input_layer,
                                                     normalize_before,
                                                     use_output_layer, dtype,
                                                     decoder_class, tmp_path):
    token_list = ["<blank>", "a", "b", "c", "unk", "<eos>"]
    vocab_size = len(token_list)
    encoder_output_size = 8

    decoder = decoder_class(
        vocab_size=vocab_size,
        encoder_output_size=encoder_output_size,
        input_layer=input_layer,
        normalize_before=normalize_before,
        use_output_layer=use_output_layer,
        linear_units=10,
    )
    ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size)
    ctc.to(dtype)
    ctc_scorer = CTCPrefixScorer(ctc=ctc, eos=vocab_size - 1)
    beam = BatchBeamSearchOnlineSim(
        beam_size=3,
        vocab_size=vocab_size,
        weights={
            "test": 0.7,
            "ctc": 0.3
        },
        scorers={
            "test": decoder,
            "ctc": ctc_scorer
        },
        token_list=token_list,
        sos=vocab_size - 1,
        eos=vocab_size - 1,
        pre_beam_score_key=None,
    )
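    # Write a two-level config: config.yaml points at dummy.yaml, which holds
    # the block/hop/look-ahead sizes read by set_streaming_config().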
    cp = tmp_path / "config.yaml"
    yp = tmp_path / "dummy.yaml"
    with cp.open("w") as f:
        f.write("config: " + str(yp) + "\n")
    with yp.open("w") as f:
        f.write("encoder_conf:\n")
        f.write("    block_size: 4\n")
        f.write("    hop_size: 2\n")
        f.write("    look_ahead: 1\n")
    beam.set_streaming_config(cp)
    beam.set_block_size(4)
    beam.set_hop_size(2)
    beam.set_look_ahead(1)
    beam.to(dtype=dtype)

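    # Dummy encoder output: 10 frames of dimension encoder_output_size.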
    enc = torch.randn(10, encoder_output_size).type(dtype)
    with torch.no_grad():
        beam(
            x=enc,
            maxlenratio=0.0,
            minlenratio=0.0,
        )
Example #2
def test_Encoder_forward_backward(
    input_layer,
    positionwise_layer_type,
    interctc_layer_idx,
    interctc_use_conditioning,
):
    encoder = TransformerEncoder(
        20,
        output_size=40,
        input_layer=input_layer,
        positionwise_layer_type=positionwise_layer_type,
        interctc_layer_idx=interctc_layer_idx,
        interctc_use_conditioning=interctc_use_conditioning,
    )
    if input_layer == "embed":
        x = torch.randint(0, 10, [2, 10])
    else:
        x = torch.randn(2, 10, 20, requires_grad=True)
    x_lens = torch.LongTensor([10, 8])
    if len(interctc_layer_idx) > 0:
        ctc = None
        if interctc_use_conditioning:
            vocab_size = 5
            output_size = encoder.output_size()
            ctc = CTC(odim=vocab_size, encoder_output_size=output_size)
            encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, output_size)
        y, _, _ = encoder(x, x_lens, ctc=ctc)
        y = y[0]
    else:
        y, _, _ = encoder(x, x_lens)
    y.sum().backward()
Example #3
def test_maskctc(encoder_arch, interctc_layer_idx, interctc_use_conditioning,
                 interctc_weight):
    vocab_size = 5
    enc_out = 4
    encoder = encoder_arch(
        20,
        output_size=enc_out,
        linear_units=4,
        num_blocks=2,
        interctc_layer_idx=interctc_layer_idx,
        interctc_use_conditioning=interctc_use_conditioning,
    )
    decoder = MLMDecoder(
        vocab_size,
        enc_out,
        linear_units=4,
        num_blocks=2,
    )
    ctc = CTC(odim=vocab_size, encoder_output_size=enc_out)

    model = MaskCTCModel(
        vocab_size,
        token_list=["<blank>", "<unk>", "a", "i", "<eos>"],
        frontend=None,
        specaug=None,
        normalize=None,
        preencoder=None,
        encoder=encoder,
        postencoder=None,
        decoder=decoder,
        ctc=ctc,
        interctc_weight=interctc_weight,
    )

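    # Two dummy utterances: 10 and 8 feature frames (20-dim), with target
    # token sequences of lengths 4 and 3.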
    inputs = dict(
        speech=torch.randn(2, 10, 20, requires_grad=True),
        speech_lengths=torch.tensor([10, 8], dtype=torch.long),
        text=torch.randint(2, 4, [2, 4], dtype=torch.long),
        text_lengths=torch.tensor([4, 3], dtype=torch.long),
    )
    loss, *_ = model(**inputs)
    loss.backward()

    with torch.no_grad():
        model.eval()

        s2t = MaskCTCInference(
            asr_model=model,
            n_iterations=2,
            threshold_probability=0.5,
        )

        # Free-running inference on a dummy encoder output.
        inputs = dict(enc_out=torch.randn(2, 4))
        s2t(**inputs)
Example #4
def test_encoder_forward_backward(
    input_layer,
    positionwise_layer_type,
    rel_pos_type,
    pos_enc_layer_type,
    selfattention_layer_type,
    interctc_layer_idx,
    interctc_use_conditioning,
    stochastic_depth_rate,
):
    encoder = ConformerEncoder(
        20,
        output_size=2,
        attention_heads=2,
        linear_units=4,
        num_blocks=2,
        input_layer=input_layer,
        macaron_style=False,
        rel_pos_type=rel_pos_type,
        pos_enc_layer_type=pos_enc_layer_type,
        selfattention_layer_type=selfattention_layer_type,
        activation_type="swish",
        use_cnn_module=True,
        cnn_module_kernel=3,
        positionwise_layer_type=positionwise_layer_type,
        interctc_layer_idx=interctc_layer_idx,
        interctc_use_conditioning=interctc_use_conditioning,
        stochastic_depth_rate=stochastic_depth_rate,
    )
    if input_layer == "embed":
        x = torch.randint(0, 10, [2, 32])
    else:
        x = torch.randn(2, 32, 20, requires_grad=True)
    x_lens = torch.LongTensor([32, 28])
    if len(interctc_layer_idx) > 0:
        ctc = None
        if interctc_use_conditioning:
            vocab_size = 5
            output_size = encoder.output_size()
            ctc = CTC(odim=vocab_size, encoder_output_size=output_size)
            encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, output_size)
        y, _, _ = encoder(x, x_lens, ctc=ctc)
        y = y[0]
    else:
        y, _, _ = encoder(x, x_lens)
    y.sum().backward()
Example #5
    def build_model(cls, args: argparse.Namespace) -> ESPnetEnhASRModel:
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 0. Build the speech enhancement model (applied before ASR)
        enh_model = enh_choices.get_class(args.enh)(**args.enh_conf)

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 5. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)

        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder.output_size(),
            **args.decoder_conf,
        )

        # 6. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
        )

        # 7. RNN-T Decoder (Not implemented)
        rnnt_decoder = None

        # 8. Build model
        model = ESPnetEnhASRModel(
            vocab_size=vocab_size,
            enh=enh_model,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            rnnt_decoder=rnnt_decoder,
            token_list=token_list,
            **args.asr_model_conf,
        )

        # FIXME(kamo): Should be done in model?
        # 9. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
Example #6
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
        ctc: Optional[CTC] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
            ctc (CTC): Intermediate CTC module for self-conditioning;
                used only when interctc_use_conditioning is enabled.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size), or a tuple of
                (output, intermediate_outs) when intermediate CTC taps are set.
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if isinstance(self.embed, (Conv2dSubsampling, Conv2dSubsampling2,
                                   Conv2dSubsampling6, Conv2dSubsampling8)):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
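        # With no intermediate CTC taps configured, run the whole encoder stack
        # at once; otherwise iterate layer by layer to collect tapped outputs.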
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    if isinstance(encoder_out, tuple):
                        encoder_out = encoder_out[0]

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

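                    # Self-conditioned CTC: project the intermediate CTC
                    # posteriors and add them back into the encoder stream.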
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)

                        if isinstance(xs_pad, tuple):
                            x, pos_emb = xs_pad
                            x = x + self.conditioning_layer(ctc_out)
                            xs_pad = (x, pos_emb)
                        else:
                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

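        # Recover per-utterance output lengths from the (possibly subsampled) mask.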
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
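
A minimal sketch of how the tuple returned above is typically consumed when
intermediate CTC losses are enabled. The variable names, the call signature
ctc(hs, hlens, ys, ys_lens), and the 0.5 mixing weight are assumptions for
illustration (mirroring the interctc_weight argument in Example #3), not the
library's exact training code:

# Hedged sketch: mix the final CTC loss with intermediate-layer CTC losses.
(encoder_out, intermediate_outs), olens, _ = encoder(xs, ilens, ctc=ctc)
loss_ctc = ctc(encoder_out, olens, ys, ys_lens)
loss_interctc = sum(
    ctc(layer_out, olens, ys, ys_lens) for _, layer_out in intermediate_outs
) / len(intermediate_outs)
interctc_weight = 0.5  # assumed mixing weight
loss = (1 - interctc_weight) * loss_ctc + interctc_weight * loss_interctc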
Example #7
    def build_model(cls, args: argparse.Namespace) -> ESPnetSTModel:
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        if args.src_token_list is not None:
            if isinstance(args.src_token_list, str):
                with open(args.src_token_list, encoding="utf-8") as f:
                    src_token_list = [line.rstrip() for line in f]

                # Overwriting src_token_list to keep it as "portable".
                args.src_token_list = list(src_token_list)
            elif isinstance(args.src_token_list, (tuple, list)):
                src_token_list = list(args.src_token_list)
            else:
                raise RuntimeError("src_token_list must be str or list")
            src_vocab_size = len(src_token_list)
            logging.info(f"Source vocabulary size: {src_vocab_size}")
        else:
            src_token_list, src_vocab_size = None, None

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(input_size=encoder_output_size,
                                            **args.postencoder_conf)
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)

        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        if src_token_list is not None:
            ctc = CTC(
                odim=src_vocab_size,
                encoder_output_size=encoder_output_size,
                **args.ctc_conf,
            )
        else:
            ctc = None

        # 9. ASR extra decoder
        if (getattr(args, "extra_asr_decoder", None) is not None
                and src_token_list is not None):
            extra_asr_decoder_class = extra_asr_decoder_choices.get_class(
                args.extra_asr_decoder)
            extra_asr_decoder = extra_asr_decoder_class(
                vocab_size=src_vocab_size,
                encoder_output_size=encoder_output_size,
                **args.extra_asr_decoder_conf,
            )
        else:
            extra_asr_decoder = None

        # 10. MT extra decoder
        if getattr(args, "extra_mt_decoder", None) is not None:
            extra_mt_decoder_class = extra_mt_decoder_choices.get_class(
                args.extra_mt_decoder)
            extra_mt_decoder = extra_mt_decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.extra_mt_decoder_conf,
            )
        else:
            extra_mt_decoder = None

        # 11. Build model
        model = ESPnetSTModel(
            vocab_size=vocab_size,
            src_vocab_size=src_vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            extra_asr_decoder=extra_asr_decoder,
            extra_mt_decoder=extra_mt_decoder,
            token_list=token_list,
            src_token_list=src_token_list,
            **args.model_conf,
        )

        # FIXME(kamo): Should be done in model?
        # 12. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
Example #8
    def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel:
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(input_size=encoder_output_size,
                                            **args.postencoder_conf)
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)

        if args.decoder == "transducer":
            decoder = decoder_class(
                vocab_size,
                embed_pad=0,
                **args.decoder_conf,
            )

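            # Transducer path: the joint network fuses encoder and
            # prediction-network (decoder) outputs into token logits.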
            joint_network = JointNetwork(
                vocab_size,
                encoder.output_size(),
                decoder.dunits,
                **args.joint_net_conf,
            )
        else:
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.decoder_conf,
            )

            joint_network = None

        # 8. CTC
        ctc = CTC(odim=vocab_size,
                  encoder_output_size=encoder_output_size,
                  **args.ctc_conf)

        # 9. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("espnet")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            joint_network=joint_network,
            token_list=token_list,
            **args.model_conf,
        )

        # FIXME(kamo): Should be done in model?
        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
Example #9
asr_transformer_encoder = TransformerEncoder(
    32,
    output_size=16,
    linear_units=16,
    num_blocks=2,
)

asr_transformer_decoder = TransformerDecoder(
    len(token_list),
    16,
    linear_units=16,
    num_blocks=2,
)

asr_ctc = CTC(odim=len(token_list), encoder_output_size=16)


@pytest.mark.parametrize(
    "enh_encoder, enh_decoder",
    [(enh_stft_encoder, enh_stft_decoder)],
)
@pytest.mark.parametrize("enh_separator", [enh_rnn_separator])
@pytest.mark.parametrize("training", [True, False])
@pytest.mark.parametrize("loss_wrappers", [[fix_order_solver]])
@pytest.mark.parametrize("frontend", [default_frontend])
@pytest.mark.parametrize("s2t_encoder", [asr_transformer_encoder])
@pytest.mark.parametrize("s2t_decoder", [asr_transformer_decoder])
@pytest.mark.parametrize("s2t_ctc", [asr_ctc])
def test_enh_asr_model(
    enh_encoder,
    enh_decoder,
    enh_separator,
    training,
    loss_wrappers,
    frontend,
    s2t_encoder,
    s2t_decoder,
    s2t_ctc,
):
    ...  # test body truncated in the source listing
Example #10
def test_ctc_argmax(ctc_type, ctc_args):
    if ctc_type == "warpctc":
        pytest.importorskip("warpctc_pytorch")
    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
    ctc.argmax(ctc_args[0])
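
For context, a minimal sketch of best-path (greedy) decoding on top of
CTC.argmax: collapse repeated frame labels, then drop the blank index (0 by
convention here). The helper name is hypothetical, not part of the tested API:

# Hypothetical helper sketch: greedy CTC decoding from frame-level argmax ids.
def greedy_ctc_decode(ctc, hs_pad, blank=0):
    hyps = ctc.argmax(hs_pad)  # (batch, time) tensor of frame-level token ids
    results = []
    for hyp in hyps:
        prev = None
        tokens = []
        for t in hyp.tolist():
            if t != prev and t != blank:
                tokens.append(t)
            prev = t
        results.append(tokens)
    return results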
Example #11
def test_ctc_forward_backward(ctc_type, ctc_args):
    if ctc_type == "warpctc":
        pytest.importorskip("warpctc_pytorch")
    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
    ctc(*ctc_args).sum().backward()