def test_enh_asr_model(
    enh_encoder,
    enh_decoder,
    enh_separator,
    training,
    loss_wrappers,
    frontend,
    s2t_encoder,
    s2t_decoder,
    s2t_ctc,
):
    """Run a forward pass of the joint enhancement + ASR model on random data.

    Builds an ``ESPnetEnhancementModel`` and an ``ESPnetASRModel`` from the
    provided fixtures, wraps them in ``ESPnetEnhS2TModel``, and checks that a
    forward call returns (loss, stats, weight) without raising, in both
    train and eval modes.
    """
    # Dummy batch: two waveforms plus a single-speaker reference and targets.
    speech_mix = torch.randn(2, 300)
    speech_lengths = torch.LongTensor([300, 200])
    reference = torch.randn(2, 300).float()
    tokens = torch.LongTensor([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]])
    token_lengths = torch.LongTensor([5, 5])

    # Front-end speech-enhancement model.
    enhancement = ESPnetEnhancementModel(
        encoder=enh_encoder,
        separator=enh_separator,
        decoder=enh_decoder,
        mask_module=None,
        loss_wrappers=loss_wrappers,
    )

    # Downstream speech-to-text model (token_list is a module-level fixture).
    recognizer = ESPnetASRModel(
        vocab_size=len(token_list),
        token_list=token_list,
        frontend=frontend,
        encoder=s2t_encoder,
        decoder=s2t_decoder,
        ctc=s2t_ctc,
        specaug=None,
        normalize=None,
        preencoder=None,
        postencoder=None,
        joint_network=None,
    )

    joint = ESPnetEnhS2TModel(
        enh_model=enhancement,
        s2t_model=recognizer,
    )

    if not training:
        joint.eval()
    else:
        joint.train()

    loss, stats, weight = joint(
        speech=speech_mix,
        speech_lengths=speech_lengths,
        speech_ref1=reference,
        text=tokens,
        text_lengths=token_lengths,
    )
def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel:
    """Assemble an ``ESPnetASRModel`` from parsed task arguments.

    Args:
        args: Task configuration namespace carrying the token list, module
            choice names (frontend/specaug/normalize/encoder/decoder), and
            their ``*_conf`` keyword dictionaries.

    Returns:
        The constructed (and, if ``args.init`` is set, initialized) model.

    Raises:
        RuntimeError: If ``args.token_list`` is neither a path nor a list.
    """
    assert check_argument_types()
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    logging.info(f"Vocabulary size: {vocab_size}")

    # 1. frontend
    if args.input_size is None:
        # Extract features in the model
        frontend_class = frontend_choices.get_class(args.frontend)
        frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # Give features from data-loader
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size

    # 2. Data augmentation for spectrogram
    if args.specaug is not None:
        specaug_class = specaug_choices.get_class(args.specaug)
        specaug = specaug_class(**args.specaug_conf)
    else:
        specaug = None

    # 3. Normalization layer
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None

    # 4. Encoder
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(input_size=input_size, **args.encoder_conf)

    # 5. Decoder
    decoder_class = decoder_choices.get_class(args.decoder)
    decoder = decoder_class(
        vocab_size=vocab_size,
        encoder_output_size=encoder.output_size(),
        **args.decoder_conf,
    )

    # 6. CTC
    # BUG FIX: keyword was misspelled "encoder_output_sizse", which would
    # raise TypeError on construction (CTC expects "encoder_output_size").
    ctc = CTC(
        odim=vocab_size,
        encoder_output_size=encoder.output_size(),
        **args.ctc_conf,
    )

    # 7. RNN-T Decoder (Not implemented)
    rnnt_decoder = None

    # 8. Build model
    model = ESPnetASRModel(
        vocab_size=vocab_size,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        encoder=encoder,
        decoder=decoder,
        ctc=ctc,
        rnnt_decoder=rnnt_decoder,
        token_list=token_list,
        **args.model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 9. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model
def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel:
    """Assemble an ``ESPnetASRModel`` (CTC/attention or transducer) from args.

    Args:
        args: Task configuration namespace carrying the token list, module
            choice names (frontend/specaug/normalize/preencoder/encoder/
            postencoder/decoder), and their ``*_conf`` keyword dictionaries.

    Returns:
        The constructed (and, if ``args.init`` is set, initialized) model.

    Raises:
        RuntimeError: If ``args.token_list`` is neither a path nor a list.
    """
    assert check_argument_types()
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    logging.info(f"Vocabulary size: {vocab_size}")

    # 1. frontend
    if args.input_size is None:
        # Extract features in the model
        frontend_class = frontend_choices.get_class(args.frontend)
        frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # Give features from data-loader
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size

    # 2. Data augmentation for spectrogram
    if args.specaug is not None:
        specaug_class = specaug_choices.get_class(args.specaug)
        specaug = specaug_class(**args.specaug_conf)
    else:
        specaug = None

    # 3. Normalization layer
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None

    # 4. Pre-encoder input block
    # NOTE(kan-bayashi): Use getattr to keep the compatibility
    if getattr(args, "preencoder", None) is not None:
        preencoder_class = preencoder_choices.get_class(args.preencoder)
        preencoder = preencoder_class(**args.preencoder_conf)
        input_size = preencoder.output_size()
    else:
        preencoder = None

    # 5. Encoder
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(input_size=input_size, **args.encoder_conf)

    # 6. Post-encoder block
    # NOTE(kan-bayashi): Use getattr to keep the compatibility
    encoder_output_size = encoder.output_size()
    if getattr(args, "postencoder", None) is not None:
        postencoder_class = postencoder_choices.get_class(args.postencoder)
        postencoder = postencoder_class(
            input_size=encoder_output_size, **args.postencoder_conf
        )
        encoder_output_size = postencoder.output_size()
    else:
        postencoder = None

    # 7. Decoder
    decoder_class = decoder_choices.get_class(args.decoder)
    if args.decoder == "transducer":
        # Transducer path: prediction network + joint network.
        decoder = decoder_class(
            vocab_size,
            embed_pad=0,
            **args.decoder_conf,
        )
        joint_network = JointNetwork(
            vocab_size,
            encoder.output_size(),
            decoder.dunits,
            **args.joint_net_conf,
        )
    else:
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )
        joint_network = None

    # 8. CTC
    # BUG FIX: keyword was misspelled "encoder_output_sizse", which would
    # raise TypeError on construction (CTC expects "encoder_output_size").
    ctc = CTC(
        odim=vocab_size,
        encoder_output_size=encoder_output_size,
        **args.ctc_conf,
    )

    # 9. Build model
    model = ESPnetASRModel(
        vocab_size=vocab_size,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        preencoder=preencoder,
        encoder=encoder,
        postencoder=postencoder,
        decoder=decoder,
        ctc=ctc,
        joint_network=joint_network,
        token_list=token_list,
        **args.model_conf,
    )

    # FIXME(kamo): Should be done in model?
    # 10. Initialize
    if args.init is not None:
        initialize(model, args.init)

    assert check_return_type(model)
    return model