示例#1
0
 def search(self, model: SpeechSeq2seq, queue: Queue, device: str,
            print_every: int) -> float:
     """Install a top-k decoder on ``model`` and delegate to the parent search.

     The model's current decoder is wrapped in a ``SpeechTopKDecoder`` built
     with ``self.k`` and set back on the model before the parent class'
     ``search`` is invoked with the same arguments.

     Args:
         model: network to search with; may be wrapped in ``nn.DataParallel``.
         queue: forwarded unchanged to the parent ``search``.
         device: forwarded unchanged to the parent ``search``.
         print_every: forwarded unchanged to the parent ``search``.

     Returns:
         Whatever the parent ``search`` returns (annotated as ``float``).
     """
     # ``nn.DataParallel`` exposes the wrapped network as ``.module``; the
     # decoder has to be swapped on that inner object, not on the wrapper.
     target = model.module if isinstance(model, nn.DataParallel) else model

     beam_decoder = SpeechTopKDecoder(target.decoder, self.k)
     target.set_decoder(beam_decoder)

     # The (possibly wrapped) model itself is what gets passed on.
     return super(BeamSearch, self).search(model, queue, device, print_every)
示例#2
0
def build_seq2seq(input_size, opt, device):
    """Build an encoder and decoder, join them, and wrap for data parallelism.

    Args:
        input_size: passed to ``build_seq2seq_encoder`` as ``input_size``.
        opt: options object; the attributes read are ``hidden_dim``,
            ``dropout``, ``num_encoder_layers``, ``num_decoder_layers``,
            ``use_bidirectional``, ``extractor``, ``activation``,
            ``rnn_type``, ``mask_conv``, ``max_len``, ``num_heads`` and
            ``attn_mechanism``.
        device: handed to both builders and used as the target of ``.to``.

    Returns:
        The ``SpeechSeq2seq`` model wrapped in ``nn.DataParallel`` and moved
        to ``device``.
    """
    encoder_kwargs = dict(
        input_size=input_size,
        hidden_dim=opt.hidden_dim,
        dropout_p=opt.dropout,
        num_layers=opt.num_encoder_layers,
        bidirectional=opt.use_bidirectional,
        extractor=opt.extractor,
        activation=opt.activation,
        rnn_type=opt.rnn_type,
        device=device,
        mask_conv=opt.mask_conv,
    )
    encoder = build_seq2seq_encoder(**encoder_kwargs)

    # Shifting left by one bit doubles the decoder width when the encoder is
    # bidirectional; a shift of zero leaves it equal to ``opt.hidden_dim``.
    shift_bits = 1 if opt.use_bidirectional else 0
    decoder_hidden_dim = opt.hidden_dim << shift_bits

    decoder_kwargs = dict(
        num_classes=len(char2id),
        max_len=opt.max_len,
        sos_id=SOS_token,
        eos_id=EOS_token,
        hidden_dim=decoder_hidden_dim,
        num_layers=opt.num_decoder_layers,
        rnn_type=opt.rnn_type,
        dropout_p=opt.dropout,
        num_heads=opt.num_heads,
        attn_mechanism=opt.attn_mechanism,
        device=device,
    )
    decoder = build_seq2seq_decoder(**decoder_kwargs)

    seq2seq = SpeechSeq2seq(encoder, decoder)
    seq2seq.flatten_parameters()

    return nn.DataParallel(seq2seq).to(device)
示例#3
0
from kospeech.models.acoustic.seq2seq.encoder import SpeechEncoderRNN
from kospeech.models.acoustic.seq2seq.decoder import SpeechDecoderRNN
from kospeech.models.acoustic.seq2seq.seq2seq import SpeechSeq2seq

# Example script: construct an encoder and a decoder with hard-coded
# arguments, combine them into a SpeechSeq2seq model, and print the result.
# NOTE(review): the meaning of each positional argument below is not visible
# in this snippet — confirm against the SpeechEncoderRNN / SpeechDecoderRNN
# constructor signatures before changing any value.
encoder = SpeechEncoderRNN(80, 512, 'cpu')
decoder = SpeechDecoderRNN(2038, 151, 1024, 1, 2)
model = SpeechSeq2seq(encoder, decoder)

# Print the assembled model's string representation for inspection.
print(model)