Example #1
0
def _define_lm_name(dir_name, args):
    """Delegate naming of ``dir_name`` to the LM class selected by ``args.lm_type``.

    Class selection:
      * ``lm_type`` containing ``'gated_conv'``  -> GatedConvLM
      * ``lm_type == 'transformer'``             -> TransformerLM
      * ``lm_type == 'transformer_xl'``          -> TransformerXL
      * anything else                            -> RNNLM (fallback)

    Args:
        dir_name: Name forwarded unchanged to the selected class's ``define_name``.
        args: Config object; ``lm_type`` is read here and the whole object is
            forwarded to ``define_name``.

    Returns:
        Whatever the selected class's ``define_name(dir_name, args)`` returns.

    Raises:
        NotImplementedError: If the selected class has no ``define_name`` attribute.
    """
    lm_type = args.lm_type

    # Only the import in the matching branch executes, so only that model's
    # module is loaded.
    if 'gated_conv' in lm_type:
        from neural_sp.models.lm.gated_convlm import GatedConvLM as lm_cls
    elif lm_type == 'transformer':
        from neural_sp.models.lm.transformerlm import TransformerLM as lm_cls
    elif lm_type == 'transformer_xl':
        from neural_sp.models.lm.transformer_xl import TransformerXL as lm_cls
    else:
        from neural_sp.models.lm.rnnlm import RNNLM as lm_cls

    if not hasattr(lm_cls, 'define_name'):
        raise NotImplementedError(lm_cls)
    return lm_cls.define_name(dir_name, args)
Example #2
0
def _define_encoder_name(dir_name, args):
    """Delegate naming of ``dir_name`` to the encoder class selected by ``args.enc_type``.

    Class selection (checked in this order):
      * ``enc_type == 'tds'``                -> TDSEncoder
      * ``enc_type == 'gated_conv'``         -> GatedConvEncoder
      * ``enc_type`` containing ``'transformer'`` -> TransformerEncoder
      * ``enc_type`` containing ``'conformer'``   -> ConformerEncoder
      * anything else                        -> RNNEncoder (fallback)

    Args:
        dir_name: Name forwarded unchanged to the selected class's ``define_name``.
        args: Config object; ``enc_type`` is read here and the whole object is
            forwarded to ``define_name``.

    Returns:
        Whatever the selected class's ``define_name(dir_name, args)`` returns.

    Raises:
        NotImplementedError: If the selected class has no ``define_name`` attribute.
    """
    enc_type = args.enc_type

    # The two substring tests come after the exact matches. The 'transformer'
    # test precedes 'conformer', so a value containing both picks
    # TransformerEncoder.
    if enc_type == 'tds':
        from neural_sp.models.seq2seq.encoders.tds import TDSEncoder as enc_cls
    elif enc_type == 'gated_conv':
        from neural_sp.models.seq2seq.encoders.gated_conv import GatedConvEncoder as enc_cls
    elif 'transformer' in enc_type:
        from neural_sp.models.seq2seq.encoders.transformer import TransformerEncoder as enc_cls
    elif 'conformer' in enc_type:
        from neural_sp.models.seq2seq.encoders.conformer import ConformerEncoder as enc_cls
    else:
        from neural_sp.models.seq2seq.encoders.rnn import RNNEncoder as enc_cls

    if not hasattr(enc_cls, 'define_name'):
        raise NotImplementedError(enc_cls)
    return enc_cls.define_name(dir_name, args)
Example #3
0
def _define_decoder_name(dir_name, args):
    """Delegate naming of ``dir_name`` to the decoder class selected by ``args.dec_type``.

    Class selection:
      * ``'transformer'`` / ``'transformer_xl'``                        -> TransformerDecoder
      * ``'transformer_transducer'`` / ``'transformer_transducer_xl'``  -> TransformerTransducer
      * ``'lstm_transducer'`` / ``'gru_transducer'``                    -> RNNTransducer
      * ``'asg'``                                                      -> ASGDecoder
      * anything else                                                  -> RNNDecoder (from ``las``)

    Args:
        dir_name: Name forwarded unchanged to the selected class's ``define_name``.
        args: Config object; ``dec_type`` is read here and the whole object is
            forwarded to ``define_name``.

    Returns:
        Whatever the selected class's ``define_name(dir_name, args)`` returns.

    Raises:
        NotImplementedError: If the selected class has no ``define_name`` attribute.
    """
    dec_type = args.dec_type

    # All comparisons are exact matches. Only the matching branch's import runs.
    if dec_type in ('transformer', 'transformer_xl'):
        from neural_sp.models.seq2seq.decoders.transformer import TransformerDecoder as dec_cls
    elif dec_type in ('transformer_transducer', 'transformer_transducer_xl'):
        from neural_sp.models.seq2seq.decoders.transformer_transducer import TransformerTransducer as dec_cls
    elif dec_type in ('lstm_transducer', 'gru_transducer'):
        from neural_sp.models.seq2seq.decoders.rnn_transducer import RNNTransducer as dec_cls
    elif dec_type == 'asg':
        from neural_sp.models.seq2seq.decoders.asg import ASGDecoder as dec_cls
    else:
        from neural_sp.models.seq2seq.decoders.las import RNNDecoder as dec_cls

    if not hasattr(dec_cls, 'define_name'):
        raise NotImplementedError(dec_cls)
    return dec_cls.define_name(dir_name, args)