def test_Encoder_forward_backward(
    input_layer,
    positionwise_layer_type,
    interctc_layer_idx,
    interctc_use_conditioning,
):
    encoder = TransformerEncoder(
        20,
        output_size=40,
        input_layer=input_layer,
        positionwise_layer_type=positionwise_layer_type,
        interctc_layer_idx=interctc_layer_idx,
        interctc_use_conditioning=interctc_use_conditioning,
    )
    if input_layer == "embed":
        x = torch.randint(0, 10, [2, 10])
    else:
        x = torch.randn(2, 10, 20, requires_grad=True)
    x_lens = torch.LongTensor([10, 8])
    if len(interctc_layer_idx) > 0:
        ctc = None
        if interctc_use_conditioning:
            vocab_size = 5
            output_size = encoder.output_size()
            ctc = CTC(odim=vocab_size, encoder_output_size=output_size)
            encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size)
        y, _, _ = encoder(x, x_lens, ctc=ctc)
        y = y[0]
    else:
        y, _, _ = encoder(x, x_lens)
    y.sum().backward()
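# Illustrative sketch (not part of the original tests): when interctc_layer_idx
# is non-empty, the first return value of the encoder above is a tuple of
# (final_output, [(layer_idx, intermediate_output), ...]) rather than a plain
# tensor, which is why the test unpacks it with `y = y[0]`. Assuming that
# structure, a hypothetical helper to walk the intermediate outputs:
def _unpack_encoder_output(encoder_out):
    if isinstance(encoder_out, tuple):
        final_out, intermediate_outs = encoder_out
        for layer_idx, inter_out in intermediate_outs:
            # Each intermediate output can feed an auxiliary (inter-)CTC loss.
            print(f"intermediate output from layer {layer_idx}: {inter_out.shape}")
        return final_out
    return encoder_out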
def test_TransformerDecoder_batch_beam_search_online(
    input_layer, normalize_before, use_output_layer, dtype, decoder_class, tmp_path
):
    token_list = ["<blank>", "a", "b", "c", "unk", "<eos>"]
    vocab_size = len(token_list)
    encoder_output_size = 8
    decoder = decoder_class(
        vocab_size=vocab_size,
        encoder_output_size=encoder_output_size,
        input_layer=input_layer,
        normalize_before=normalize_before,
        use_output_layer=use_output_layer,
        linear_units=10,
    )
    ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size)
    ctc.to(dtype)
    ctc_scorer = CTCPrefixScorer(ctc=ctc, eos=vocab_size - 1)
    beam = BatchBeamSearchOnlineSim(
        beam_size=3,
        vocab_size=vocab_size,
        weights={"test": 0.7, "ctc": 0.3},
        scorers={"test": decoder, "ctc": ctc_scorer},
        token_list=token_list,
        sos=vocab_size - 1,
        eos=vocab_size - 1,
        pre_beam_score_key=None,
    )
    cp = tmp_path / "config.yaml"
    yp = tmp_path / "dummy.yaml"
    with cp.open("w") as f:
        f.write("config: " + str(yp) + "\n")
    with yp.open("w") as f:
        f.write("encoder_conf:\n")
        f.write("    block_size: 4\n")
        f.write("    hop_size: 2\n")
        f.write("    look_ahead: 1\n")
    beam.set_streaming_config(cp)
    beam.set_block_size(4)
    beam.set_hop_size(2)
    beam.set_look_ahead(1)
    beam.to(dtype=dtype)
    enc = torch.randn(10, encoder_output_size).type(dtype)
    with torch.no_grad():
        beam(
            x=enc,
            maxlenratio=0.0,
            minlenratio=0.0,
        )
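# Illustrative sketch (hypothetical helper, not ESPnet API): with the
# block_size=4, hop_size=2, look_ahead=1 values written to dummy.yaml above,
# the simulated online beam search walks the encoder frames in overlapping
# windows roughly like this; the trailing look_ahead frames of each block act
# as future context rather than committed output.
def _block_windows(num_frames, block_size, hop_size):
    start = 0
    while start < num_frames:
        yield start, min(start + block_size, num_frames)
        start += hop_size

# For the 10-frame input above:
# list(_block_windows(10, 4, 2)) -> [(0, 4), (2, 6), (4, 8), (6, 10), (8, 10)]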
def test_maskctc(
    encoder_arch, interctc_layer_idx, interctc_use_conditioning, interctc_weight
):
    vocab_size = 5
    enc_out = 4
    encoder = encoder_arch(
        20,
        output_size=enc_out,
        linear_units=4,
        num_blocks=2,
        interctc_layer_idx=interctc_layer_idx,
        interctc_use_conditioning=interctc_use_conditioning,
    )
    decoder = MLMDecoder(
        vocab_size,
        enc_out,
        linear_units=4,
        num_blocks=2,
    )
    ctc = CTC(odim=vocab_size, encoder_output_size=enc_out)
    model = MaskCTCModel(
        vocab_size,
        token_list=["<blank>", "<unk>", "a", "i", "<eos>"],
        frontend=None,
        specaug=None,
        normalize=None,
        preencoder=None,
        encoder=encoder,
        postencoder=None,
        decoder=decoder,
        ctc=ctc,
        interctc_weight=interctc_weight,
    )

    inputs = dict(
        speech=torch.randn(2, 10, 20, requires_grad=True),
        speech_lengths=torch.tensor([10, 8], dtype=torch.long),
        text=torch.randint(2, 4, [2, 4], dtype=torch.long),
        text_lengths=torch.tensor([4, 3], dtype=torch.long),
    )
    loss, *_ = model(**inputs)
    loss.backward()

    with torch.no_grad():
        model.eval()
        s2t = MaskCTCInference(
            asr_model=model,
            n_iterations=2,
            threshold_probability=0.5,
        )

        # free running
        inputs = dict(enc_out=torch.randn(2, 4))
        s2t(**inputs)
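# Illustrative sketch (simplified, hypothetical): MaskCTCInference starts from
# the CTC greedy output, masks tokens whose confidence falls below
# threshold_probability, and re-predicts the masked positions with the MLM
# decoder for n_iterations. A toy version of that loop, where `predict` stands
# in for the MLM decoder (token ids -> token ids):
import torch

def _toy_mask_ctc(tokens, probs, predict, mask_id, n_iterations, threshold):
    tokens = tokens.clone()
    tokens[probs < threshold] = mask_id  # mask low-confidence tokens
    for _ in range(n_iterations):
        masked = tokens == mask_id
        if not masked.any():
            break
        tokens[masked] = predict(tokens)[masked]  # refill masked slots only
    return tokens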
def test_encoder_forward_backward(
    input_layer,
    positionwise_layer_type,
    rel_pos_type,
    pos_enc_layer_type,
    selfattention_layer_type,
    interctc_layer_idx,
    interctc_use_conditioning,
    stochastic_depth_rate,
):
    encoder = ConformerEncoder(
        20,
        output_size=2,
        attention_heads=2,
        linear_units=4,
        num_blocks=2,
        input_layer=input_layer,
        macaron_style=False,
        rel_pos_type=rel_pos_type,
        pos_enc_layer_type=pos_enc_layer_type,
        selfattention_layer_type=selfattention_layer_type,
        activation_type="swish",
        use_cnn_module=True,
        cnn_module_kernel=3,
        positionwise_layer_type=positionwise_layer_type,
        interctc_layer_idx=interctc_layer_idx,
        interctc_use_conditioning=interctc_use_conditioning,
        stochastic_depth_rate=stochastic_depth_rate,
    )
    if input_layer == "embed":
        x = torch.randint(0, 10, [2, 32])
    else:
        x = torch.randn(2, 32, 20, requires_grad=True)
    x_lens = torch.LongTensor([32, 28])
    if len(interctc_layer_idx) > 0:
        ctc = None
        if interctc_use_conditioning:
            vocab_size = 5
            output_size = encoder.output_size()
            ctc = CTC(odim=vocab_size, encoder_output_size=output_size)
            encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size)
        y, _, _ = encoder(x, x_lens, ctc=ctc)
        y = y[0]
    else:
        y, _, _ = encoder(x, x_lens)
    y.sum().backward()
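# Illustrative sketch (simplified, hypothetical): the stochastic_depth_rate
# argument above makes each Conformer block skippable during training; a block
# with skip probability p is bypassed at random, roughly:
import torch

def _stochastic_depth(block, x, p, training):
    if training and torch.rand(1).item() < p:
        return x  # skip the whole block for this forward pass
    return block(x)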
def build_model(cls, args: argparse.Namespace) -> ESPnetEnhASRModel:
    assert check_argument_types()
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]

        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    logging.info(f"Vocabulary size: {vocab_size}")

    # 0. Build pre-enhancement model
    enh_model = enh_choices.get_class(args.enh)(**args.enh_conf)

    # 1. frontend
    if args.input_size is None:
        # Extract features in the model
        frontend_class = frontend_choices.get_class(args.frontend)
        frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # Give features from data-loader
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size

    # 2. Data augmentation for spectrogram
    if args.specaug is not None:
        specaug_class = specaug_choices.get_class(args.specaug)
        specaug = specaug_class(**args.specaug_conf)
    else:
        specaug = None

    # 3. Normalization layer
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None

    # 4. Encoder
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(input_size=input_size, **args.encoder_conf)

    # 5. Decoder
    decoder_class = decoder_choices.get_class(args.decoder)
    decoder = decoder_class(
        vocab_size=vocab_size,
        encoder_output_size=encoder.output_size(),
        **args.decoder_conf,
    )

    # 6. CTC
    ctc = CTC(
        odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
    )

    # 7. RNN-T Decoder (Not implemented)
    rnnt_decoder = None

    # 8. Build model
    model = ESPnetEnhASRModel(
        vocab_size=vocab_size,
        enh=enh_model,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        encoder=encoder,
        decoder=decoder,
        ctc=ctc,
        rnnt_decoder=rnnt_decoder,
        token_list=token_list,
        **args.asr_model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 9. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
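# Illustrative sketch (not part of build_model): the CTC module built in step 6
# takes padded encoder states, their lengths, padded targets, and target
# lengths, as exercised by the CTC tests in this file. A minimal
# forward/backward pass, assuming the espnet2 CTC module:
def _ctc_usage_sketch():
    import torch
    from espnet2.asr.ctc import CTC

    ctc = CTC(odim=5, encoder_output_size=16)
    hs_pad = torch.randn(2, 10, 16, requires_grad=True)  # (batch, time, feature)
    hlens = torch.LongTensor([10, 8])
    ys_pad = torch.randint(1, 5, (2, 4))  # token id 0 is reserved for <blank>
    ys_lens = torch.LongTensor([4, 3])
    ctc(hs_pad, hlens, ys_pad, ys_lens).sum().backward()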
def build_model(cls, args: argparse.Namespace) -> ESPnetSTModel:
    assert check_argument_types()
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]

        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    logging.info(f"Vocabulary size: {vocab_size}")

    if args.src_token_list is not None:
        if isinstance(args.src_token_list, str):
            with open(args.src_token_list, encoding="utf-8") as f:
                src_token_list = [line.rstrip() for line in f]

            # Overwriting src_token_list to keep it as "portable".
            args.src_token_list = list(src_token_list)
        elif isinstance(args.src_token_list, (tuple, list)):
            src_token_list = list(args.src_token_list)
        else:
            raise RuntimeError("src_token_list must be str or list")
        src_vocab_size = len(src_token_list)
        logging.info(f"Source vocabulary size: {src_vocab_size}")
    else:
        src_token_list, src_vocab_size = None, None

    # 1. frontend
    if args.input_size is None:
        # Extract features in the model
        frontend_class = frontend_choices.get_class(args.frontend)
        frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # Give features from data-loader
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size

    # 2. Data augmentation for spectrogram
    if args.specaug is not None:
        specaug_class = specaug_choices.get_class(args.specaug)
        specaug = specaug_class(**args.specaug_conf)
    else:
        specaug = None

    # 3. Normalization layer
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None

    # 4. Pre-encoder input block
    # NOTE(kan-bayashi): Use getattr to keep the compatibility
    if getattr(args, "preencoder", None) is not None:
        preencoder_class = preencoder_choices.get_class(args.preencoder)
        preencoder = preencoder_class(**args.preencoder_conf)
        input_size = preencoder.output_size()
    else:
        preencoder = None

    # 5. Encoder
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(input_size=input_size, **args.encoder_conf)

    # 6. Post-encoder block
    # NOTE(kan-bayashi): Use getattr to keep the compatibility
    encoder_output_size = encoder.output_size()
    if getattr(args, "postencoder", None) is not None:
        postencoder_class = postencoder_choices.get_class(args.postencoder)
        postencoder = postencoder_class(
            input_size=encoder_output_size, **args.postencoder_conf
        )
        encoder_output_size = postencoder.output_size()
    else:
        postencoder = None

    # 7. Decoder
    decoder_class = decoder_choices.get_class(args.decoder)
    decoder = decoder_class(
        vocab_size=vocab_size,
        encoder_output_size=encoder_output_size,
        **args.decoder_conf,
    )

    # 8. CTC (over the source vocabulary, for the auxiliary ASR task)
    if src_token_list is not None:
        ctc = CTC(
            odim=src_vocab_size,
            encoder_output_size=encoder_output_size,
            **args.ctc_conf,
        )
    else:
        ctc = None

    # 9. ASR extra decoder
    if (
        getattr(args, "extra_asr_decoder", None) is not None
        and src_token_list is not None
    ):
        extra_asr_decoder_class = extra_asr_decoder_choices.get_class(
            args.extra_asr_decoder
        )
        extra_asr_decoder = extra_asr_decoder_class(
            vocab_size=src_vocab_size,
            encoder_output_size=encoder_output_size,
            **args.extra_asr_decoder_conf,
        )
    else:
        extra_asr_decoder = None
    # 10. MT extra decoder
    if getattr(args, "extra_mt_decoder", None) is not None:
        extra_mt_decoder_class = extra_mt_decoder_choices.get_class(
            args.extra_mt_decoder
        )
        extra_mt_decoder = extra_mt_decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.extra_mt_decoder_conf,
        )
    else:
        extra_mt_decoder = None

    # 11. Build model
    model = ESPnetSTModel(
        vocab_size=vocab_size,
        src_vocab_size=src_vocab_size,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        preencoder=preencoder,
        encoder=encoder,
        postencoder=postencoder,
        decoder=decoder,
        ctc=ctc,
        extra_asr_decoder=extra_asr_decoder,
        extra_mt_decoder=extra_mt_decoder,
        token_list=token_list,
        src_token_list=src_token_list,
        **args.model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 12. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel:
    assert check_argument_types()
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]

        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    logging.info(f"Vocabulary size: {vocab_size}")

    # 1. frontend
    if args.input_size is None:
        # Extract features in the model
        frontend_class = frontend_choices.get_class(args.frontend)
        frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # Give features from data-loader
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size

    # 2. Data augmentation for spectrogram
    if args.specaug is not None:
        specaug_class = specaug_choices.get_class(args.specaug)
        specaug = specaug_class(**args.specaug_conf)
    else:
        specaug = None

    # 3. Normalization layer
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None

    # 4. Pre-encoder input block
    # NOTE(kan-bayashi): Use getattr to keep the compatibility
    if getattr(args, "preencoder", None) is not None:
        preencoder_class = preencoder_choices.get_class(args.preencoder)
        preencoder = preencoder_class(**args.preencoder_conf)
        input_size = preencoder.output_size()
    else:
        preencoder = None

    # 5. Encoder
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(input_size=input_size, **args.encoder_conf)

    # 6. Post-encoder block
    # NOTE(kan-bayashi): Use getattr to keep the compatibility
    encoder_output_size = encoder.output_size()
    if getattr(args, "postencoder", None) is not None:
        postencoder_class = postencoder_choices.get_class(args.postencoder)
        postencoder = postencoder_class(
            input_size=encoder_output_size, **args.postencoder_conf
        )
        encoder_output_size = postencoder.output_size()
    else:
        postencoder = None

    # 7. Decoder (with a joint network for the transducer case)
    decoder_class = decoder_choices.get_class(args.decoder)
    if args.decoder == "transducer":
        decoder = decoder_class(
            vocab_size,
            embed_pad=0,
            **args.decoder_conf,
        )
        joint_network = JointNetwork(
            vocab_size,
            encoder.output_size(),
            decoder.dunits,
            **args.joint_net_conf,
        )
    else:
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )
        joint_network = None

    # 8. CTC
    ctc = CTC(
        odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
    )

    # 9. Build model
    try:
        model_class = model_choices.get_class(args.model)
    except AttributeError:
        model_class = model_choices.get_class("espnet")
    model = model_class(
        vocab_size=vocab_size,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        preencoder=preencoder,
        encoder=encoder,
        postencoder=postencoder,
        decoder=decoder,
        ctc=ctc,
        joint_network=joint_network,
        token_list=token_list,
        **args.model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 10. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
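# Illustrative sketch (simplified, hypothetical): in the transducer branch
# above, the joint network combines encoder frames and prediction-network
# states into per-(frame, token) logits, conceptually:
import torch

def _toy_joint(enc_out, dec_out, lin_enc, lin_dec, lin_out):
    # enc_out: (B, T, D_enc), dec_out: (B, U, D_dec) -> logits: (B, T, U, V)
    joint = torch.tanh(lin_enc(enc_out).unsqueeze(2) + lin_dec(dec_out).unsqueeze(1))
    return lin_out(joint)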
asr_transformer_encoder = TransformerEncoder(
    32,
    output_size=16,
    linear_units=16,
    num_blocks=2,
)
asr_transformer_decoder = TransformerDecoder(
    len(token_list),
    16,
    linear_units=16,
    num_blocks=2,
)
asr_ctc = CTC(odim=len(token_list), encoder_output_size=16)


@pytest.mark.parametrize(
    "enh_encoder, enh_decoder",
    [(enh_stft_encoder, enh_stft_decoder)],
)
@pytest.mark.parametrize("enh_separator", [enh_rnn_separator])
@pytest.mark.parametrize("training", [True, False])
@pytest.mark.parametrize("loss_wrappers", [[fix_order_solver]])
@pytest.mark.parametrize("frontend", [default_frontend])
@pytest.mark.parametrize("s2t_encoder", [asr_transformer_encoder])
@pytest.mark.parametrize("s2t_decoder", [asr_transformer_decoder])
@pytest.mark.parametrize("s2t_ctc", [asr_ctc])
def test_enh_asr_model(
    enh_encoder,
def test_ctc_argmax(ctc_type, ctc_args):
    if ctc_type == "warpctc":
        pytest.importorskip("warpctc_pytorch")
    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
    ctc.argmax(ctc_args[0])
def test_ctc_forward_backward(ctc_type, ctc_args):
    if ctc_type == "warpctc":
        pytest.importorskip("warpctc_pytorch")
    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
    ctc(*ctc_args).sum().backward()
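# Illustrative sketch (hypothetical fixture): both tests above expect ctc_args
# to be (hs_pad, hlens, ys_pad, ys_lens) tensors consistent with
# encoder_output_size=10 and odim=5, e.g.:
def _example_ctc_args():
    import torch

    return (
        torch.randn(2, 10, 10, requires_grad=True),  # hs_pad: (batch, time, feature)
        torch.LongTensor([10, 8]),                   # hlens
        torch.randint(1, 5, (2, 4)),                 # ys_pad (0 reserved for <blank>)
        torch.LongTensor([4, 3]),                    # ys_lens
    )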