def init_asr_model(configs):
    """Construct an ``ASRModel`` from an encoder, a Transformer decoder and CTC.

    Args:
        configs: Configuration dict. Keys consulted:
            ``cmvn_file`` -- CMVN statistics file, or None to skip global CMVN;
            ``is_json_cmvn`` -- forwarded to ``load_cmvn`` (only read when
            ``cmvn_file`` is not None);
            ``input_dim`` -- passed as the encoder's first argument;
            ``output_dim`` -- used as the vocabulary size;
            ``encoder`` (optional, default ``'conformer'``) -- ``'conformer'``
            selects ``ConformerEncoder``, any other value selects
            ``TransformerEncoder``;
            ``encoder_conf`` / ``decoder_conf`` / ``model_conf`` -- extra
            keyword arguments for the encoder, the decoder and ``ASRModel``.

    Returns:
        The assembled ``ASRModel``.
    """
    cmvn_path = configs['cmvn_file']
    if cmvn_path is None:
        global_cmvn = None
    else:
        mean, istd = load_cmvn(cmvn_path, configs['is_json_cmvn'])
        # load_cmvn's results go through torch.from_numpy, so they are
        # presumably NumPy arrays; .float() casts them to float32 tensors.
        global_cmvn = GlobalCMVN(torch.from_numpy(mean).float(),
                                 torch.from_numpy(istd).float())

    input_dim = configs['input_dim']
    vocab_size = configs['output_dim']

    use_conformer = configs.get('encoder', 'conformer') == 'conformer'
    encoder_cls = ConformerEncoder if use_conformer else TransformerEncoder
    encoder = encoder_cls(input_dim,
                          global_cmvn=global_cmvn,
                          **configs['encoder_conf'])

    # The decoder and the CTC head both take the encoder's output width.
    enc_dim = encoder.output_size()
    decoder = TransformerDecoder(vocab_size, enc_dim,
                                 **configs['decoder_conf'])
    ctc = CTC(vocab_size, enc_dim)

    return ASRModel(vocab_size=vocab_size,
                    encoder=encoder,
                    decoder=decoder,
                    ctc=ctc,
                    **configs['model_conf'])
def init_model(configs):
    """Build a joint CTC/attention model or a transducer from a config dict.

    Args:
        configs: Configuration dict. Keys consulted:
            ``cmvn_file`` -- CMVN statistics file, or None to skip global CMVN;
            ``is_json_cmvn`` -- forwarded to ``load_cmvn`` (only read when
            ``cmvn_file`` is not None);
            ``input_dim`` -- passed as the encoder's first argument;
            ``output_dim`` -- used as the vocabulary size;
            ``encoder`` (optional, default ``'conformer'``) -- ``'conformer'``
            selects ``ConformerEncoder``, anything else ``TransformerEncoder``;
            ``decoder`` (optional, default ``'bitransformer'``) --
            ``'transformer'`` selects ``TransformerDecoder``, anything else
            ``BiTransformerDecoder``;
            ``encoder_conf`` / ``decoder_conf`` / ``model_conf`` -- extra
            keyword arguments for the encoder, decoder and model;
            ``predictor`` / ``predictor_conf`` / ``joint_conf`` -- only used
            when a ``predictor`` key is present (transducer model).
            ``configs`` itself is not modified.

    Returns:
        A ``Transducer`` if ``configs`` contains a ``predictor`` key,
        otherwise an ``ASRModel``.

    Raises:
        ValueError: If the bidirectional decoder is selected and
            ``model_conf['reverse_weight']`` is not strictly between 0 and 1,
            or ``decoder_conf['r_num_blocks']`` is not positive.
        NotImplementedError: If ``predictor`` is anything other than 'rnn'.
    """
    if configs['cmvn_file'] is not None:
        mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn'])
        # .float() casts the loaded statistics to float32 tensors.
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float())
    else:
        global_cmvn = None

    input_dim = configs['input_dim']
    vocab_size = configs['output_dim']

    encoder_type = configs.get('encoder', 'conformer')
    decoder_type = configs.get('decoder', 'bitransformer')

    if encoder_type == 'conformer':
        encoder = ConformerEncoder(input_dim,
                                   global_cmvn=global_cmvn,
                                   **configs['encoder_conf'])
    else:
        encoder = TransformerEncoder(input_dim,
                                     global_cmvn=global_cmvn,
                                     **configs['encoder_conf'])
    # Single source of truth for the encoder width: decoder, CTC and the
    # transducer joint must all agree on it, and encoder_conf may omit
    # 'output_size' and rely on the encoder's default.
    encoder_output_size = encoder.output_size()

    if decoder_type == 'transformer':
        decoder = TransformerDecoder(vocab_size, encoder_output_size,
                                     **configs['decoder_conf'])
    else:
        # Explicit checks rather than `assert`, which `python -O` strips.
        reverse_weight = configs['model_conf']['reverse_weight']
        if not 0.0 < reverse_weight < 1.0:
            raise ValueError(
                "bitransformer decoder requires 0.0 < "
                f"model_conf['reverse_weight'] < 1.0, got {reverse_weight!r}")
        r_num_blocks = configs['decoder_conf']['r_num_blocks']
        if not r_num_blocks > 0:
            raise ValueError(
                "bitransformer decoder requires "
                f"decoder_conf['r_num_blocks'] > 0, got {r_num_blocks!r}")
        decoder = BiTransformerDecoder(vocab_size, encoder_output_size,
                                       **configs['decoder_conf'])
    ctc = CTC(vocab_size, encoder_output_size)

    # Init joint CTC/Attention or Transducer model
    if 'predictor' in configs:
        predictor_type = configs.get('predictor', 'rnn')
        if predictor_type == 'rnn':
            predictor = RNNPredictor(vocab_size, **configs['predictor_conf'])
        else:
            raise NotImplementedError("only rnn type support now")
        # Work on a copy so the caller's configs['joint_conf'] is untouched.
        joint_conf = dict(configs['joint_conf'])
        joint_conf['enc_output_size'] = encoder_output_size
        joint_conf['pred_output_size'] = configs['predictor_conf'][
            'output_size']
        joint = TransducerJoint(vocab_size, **joint_conf)
        model = Transducer(vocab_size=vocab_size,
                           blank=0,
                           predictor=predictor,
                           encoder=encoder,
                           attention_decoder=decoder,
                           joint=joint,
                           ctc=ctc,
                           **configs['model_conf'])
    else:
        model = ASRModel(vocab_size=vocab_size,
                         encoder=encoder,
                         decoder=decoder,
                         ctc=ctc,
                         **configs['model_conf'])
    return model