def _define_lm_name(dir_name, args):
    """Build the experiment directory name for a language model.

    The LM class is chosen from ``args.lm_type``:
    a substring match on ``'gated_conv'`` selects GatedConvLM,
    ``'transformer'`` selects TransformerLM,
    ``'transformer_xl'`` selects TransformerXL, and
    anything else falls back to RNNLM.
    The class is imported lazily, so only the selected module is loaded.

    Args:
        dir_name: directory name to be extended by the selected class.
        args: configuration object exposing ``lm_type``.

    Returns:
        The value returned by the selected class's ``define_name``.

    Raises:
        NotImplementedError: if the selected class has no ``define_name``.
    """
    lm_type = args.lm_type
    # Substring test first: any type containing 'gated_conv' wins.
    if 'gated_conv' in lm_type:
        from neural_sp.models.lm.gated_convlm import GatedConvLM as lm_cls
    elif lm_type == 'transformer':
        from neural_sp.models.lm.transformerlm import TransformerLM as lm_cls
    elif lm_type == 'transformer_xl':
        from neural_sp.models.lm.transformer_xl import TransformerXL as lm_cls
    else:
        # Unrecognized types default to the RNN language model.
        from neural_sp.models.lm.rnnlm import RNNLM as lm_cls

    if not hasattr(lm_cls, 'define_name'):
        raise NotImplementedError(lm_cls)
    return lm_cls.define_name(dir_name, args)
def _define_encoder_name(dir_name, args):
    """Build the experiment directory name for a seq2seq encoder.

    The encoder class is chosen from ``args.enc_type`` in this order:
    ``'tds'`` (exact) selects TDSEncoder,
    ``'gated_conv'`` (exact) selects GatedConvEncoder,
    a substring match on ``'transformer'`` selects TransformerEncoder,
    a substring match on ``'conformer'`` selects ConformerEncoder, and
    anything else falls back to RNNEncoder.
    The class is imported lazily, so only the selected module is loaded.

    Args:
        dir_name: directory name to be extended by the selected class.
        args: configuration object exposing ``enc_type``.

    Returns:
        The value returned by the selected class's ``define_name``.

    Raises:
        NotImplementedError: if the selected class has no ``define_name``.
    """
    enc_type = args.enc_type
    if enc_type == 'tds':
        from neural_sp.models.seq2seq.encoders.tds import TDSEncoder as enc_cls
    elif enc_type == 'gated_conv':
        from neural_sp.models.seq2seq.encoders.gated_conv import GatedConvEncoder as enc_cls
    # The next two checks are substring tests, so their order is significant.
    elif 'transformer' in enc_type:
        from neural_sp.models.seq2seq.encoders.transformer import TransformerEncoder as enc_cls
    elif 'conformer' in enc_type:
        from neural_sp.models.seq2seq.encoders.conformer import ConformerEncoder as enc_cls
    else:
        # Unrecognized types default to the RNN encoder.
        from neural_sp.models.seq2seq.encoders.rnn import RNNEncoder as enc_cls

    if not hasattr(enc_cls, 'define_name'):
        raise NotImplementedError(enc_cls)
    return enc_cls.define_name(dir_name, args)
def _define_decoder_name(dir_name, args):
    """Build the experiment directory name for a seq2seq decoder.

    The decoder class is chosen from ``args.dec_type``:
    ``'transformer'`` / ``'transformer_xl'`` select TransformerDecoder,
    ``'transformer_transducer'`` / ``'transformer_transducer_xl'`` select
    TransformerTransducer,
    ``'lstm_transducer'`` / ``'gru_transducer'`` select RNNTransducer,
    ``'asg'`` selects ASGDecoder, and
    anything else falls back to the LAS RNNDecoder.
    The class is imported lazily, so only the selected module is loaded.

    Args:
        dir_name: directory name to be extended by the selected class.
        args: configuration object exposing ``dec_type``.

    Returns:
        The value returned by the selected class's ``define_name``.

    Raises:
        NotImplementedError: if the selected class has no ``define_name``.
    """
    dec_type = args.dec_type
    if dec_type in ('transformer', 'transformer_xl'):
        from neural_sp.models.seq2seq.decoders.transformer import TransformerDecoder as dec_cls
    elif dec_type in ('transformer_transducer', 'transformer_transducer_xl'):
        from neural_sp.models.seq2seq.decoders.transformer_transducer import TransformerTransducer as dec_cls
    elif dec_type in ('lstm_transducer', 'gru_transducer'):
        from neural_sp.models.seq2seq.decoders.rnn_transducer import RNNTransducer as dec_cls
    elif dec_type == 'asg':
        from neural_sp.models.seq2seq.decoders.asg import ASGDecoder as dec_cls
    else:
        # Unrecognized types default to the attention-based RNN (LAS) decoder.
        from neural_sp.models.seq2seq.decoders.las import RNNDecoder as dec_cls

    if not hasattr(dec_cls, 'define_name'):
        raise NotImplementedError(dec_cls)
    return dec_cls.define_name(dir_name, args)