Example #1
def build_las(
    input_size: int,
    config: DictConfig,
    vocab: Vocabulary,
    device: torch.device,
) -> nn.DataParallel:
    """ Build a Listen, Attend and Spell model and wrap it in DataParallel. """
    model = ListenAttendSpell(
        input_dim=input_size,
        num_classes=len(vocab),
        encoder_hidden_state_dim=config.model.hidden_dim,
        decoder_hidden_state_dim=config.model.hidden_dim <<
        (1 if config.model.use_bidirectional else 0),
        num_encoder_layers=config.model.num_encoder_layers,
        num_decoder_layers=config.model.num_decoder_layers,
        bidirectional=config.model.use_bidirectional,
        extractor=config.model.extractor,
        activation=config.model.activation,
        rnn_type=config.model.rnn_type,
        max_length=config.model.max_len,
        pad_id=vocab.pad_id,
        sos_id=vocab.sos_id,
        eos_id=vocab.eos_id,
        attn_mechanism=config.model.attn_mechanism,
        num_heads=config.model.num_heads,
        encoder_dropout_p=config.model.dropout,
        decoder_dropout_p=config.model.dropout,
        joint_ctc_attention=config.model.joint_ctc_attention,
    )
    model.flatten_parameters()

    return nn.DataParallel(model).to(device)
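
For context, here is a minimal usage sketch of the builder above. The OmegaConf layout and the stand-in vocabulary below are assumptions for illustration only; the real project supplies its own config files and Vocabulary class.

import torch
from omegaconf import OmegaConf

class StubVocabulary:
    """ Hypothetical stand-in exposing only what build_las reads. """
    pad_id, sos_id, eos_id = 0, 1, 2

    def __len__(self):
        return 10

config = OmegaConf.create({
    'model': {
        'hidden_dim': 256,
        'use_bidirectional': True,
        'num_encoder_layers': 3,
        'num_decoder_layers': 2,
        'extractor': 'vgg',
        'activation': 'hardtanh',
        'rnn_type': 'lstm',
        'max_len': 120,
        'attn_mechanism': 'multi-head',
        'num_heads': 4,
        'dropout': 0.3,
        'joint_ctc_attention': False,
    }
})

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_las(input_size=80, config=config, vocab=StubVocabulary(), device=device)
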
Example #2
def build_las(input_size, opt, vocab, device):
    """ Build the listener and speller and assemble a Listen, Attend and Spell model. """
    listener = build_listener(input_size=input_size,
                              num_classes=len(vocab),
                              hidden_dim=opt.hidden_dim,
                              dropout_p=opt.dropout,
                              num_layers=opt.num_encoder_layers,
                              bidirectional=opt.use_bidirectional,
                              extractor=opt.extractor,
                              activation=opt.activation,
                              rnn_type=opt.rnn_type,
                              device=device,
                              mask_conv=opt.mask_conv,
                              joint_ctc_attention=opt.joint_ctc_attention)
    speller = build_speller(num_classes=len(vocab),
                            max_len=opt.max_len,
                            pad_id=vocab.pad_id,
                            sos_id=vocab.sos_id,
                            eos_id=vocab.eos_id,
                            hidden_dim=opt.hidden_dim <<
                            (1 if opt.use_bidirectional else 0),
                            num_layers=opt.num_decoder_layers,
                            rnn_type=opt.rnn_type,
                            dropout_p=opt.dropout,
                            num_heads=opt.num_heads,
                            attn_mechanism=opt.attn_mechanism,
                            device=device)

    model = ListenAttendSpell(listener, speller)
    model.flatten_parameters()

    return nn.DataParallel(model).to(device)
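
A quick note on the opt.hidden_dim << (1 if opt.use_bidirectional else 0) expression above: the bit shift doubles the speller's hidden size when the listener is bidirectional, so it matches the concatenation of the forward and backward encoder states. A minimal illustration (512 is just an example value):

hidden_dim = 512                 # illustrative value only
use_bidirectional = True
speller_dim = hidden_dim << (1 if use_bidirectional else 0)
assert speller_dim == (hidden_dim * 2 if use_bidirectional else hidden_dim)
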
Example #3
def search(self, model: ListenAttendSpell, queue: Queue, device: str, print_every: int) -> float:
    # Swap the model's decoder for a top-k (beam search) decoder, going through
    # .module when the model is wrapped in nn.DataParallel, then delegate to the
    # base class' search loop.
    if isinstance(model, nn.DataParallel):
        topk_decoder = TopKDecoder(model.module.decoder, self.k)
        model.module.set_decoder(topk_decoder)
    else:
        topk_decoder = TopKDecoder(model.decoder, self.k)
        model.set_decoder(topk_decoder)
    return super(BeamSearch, self).search(model, queue, device, print_every)
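
The isinstance check above is needed because nn.DataParallel hides the wrapped model behind its .module attribute. A small, self-contained illustration of that unwrapping pattern (nn.Linear stands in for the real LAS model):

import torch.nn as nn

net = nn.Linear(4, 2)            # stand-in for the actual model
wrapped = nn.DataParallel(net)

# Attribute access must go through .module once the model is wrapped.
inner = wrapped.module if isinstance(wrapped, nn.DataParallel) else wrapped
assert inner is net
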
Example #4
def build_las(input_size: int, config: DictConfig, vocab: Vocabulary,
              device: torch.device) -> nn.DataParallel:
    """ Build the listener and speller and assemble a Listen, Attend and Spell model. """
    listener = build_listener(
        input_size=input_size,
        num_classes=len(vocab),
        hidden_dim=config.model.hidden_dim,
        dropout_p=config.model.dropout,
        num_layers=config.model.num_encoder_layers,
        bidirectional=config.model.use_bidirectional,
        extractor=config.model.extractor,
        activation=config.model.activation,
        rnn_type=config.model.rnn_type,
        device=device,
        mask_conv=config.model.mask_conv,
        joint_ctc_attention=config.model.joint_ctc_attention,
    )
    speller = build_speller(
        num_classes=len(vocab),
        max_len=config.model.max_len,
        pad_id=vocab.pad_id,
        sos_id=vocab.sos_id,
        eos_id=vocab.eos_id,
        hidden_dim=config.model.hidden_dim <<
        (1 if config.model.use_bidirectional else 0),
        num_layers=config.model.num_decoder_layers,
        rnn_type=config.model.rnn_type,
        dropout_p=config.model.dropout,
        num_heads=config.model.num_heads,
        attn_mechanism=config.model.attn_mechanism,
        device=device,
    )

    model = ListenAttendSpell(listener, speller)
    model.flatten_parameters()

    return nn.DataParallel(model).to(device)
Example #5

import torch
import torch.nn as nn

from kospeech.models import ListenAttendSpell

B, T, D, H = 3, 12345, 80, 32

cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')

model = ListenAttendSpell(
    input_dim=D,
    num_classes=10,
    encoder_hidden_state_dim=H,
    decoder_hidden_state_dim=H << 1,
    bidirectional=True,
    max_length=10,
).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-04)

for i in range(10):
    inputs = torch.rand(B, T, D).to(device)
    input_lengths = torch.IntTensor([12345, 12300, 12000])
    targets = torch.LongTensor([[1, 3, 3, 3, 3, 3, 4, 5, 6, 2],
                                [1, 3, 3, 3, 3, 3, 4, 5, 2, 0],
                                [1, 3, 3, 3, 3, 3, 4, 2, 0, 0]]).to(device)
    outputs, output_lengths, encoder_log_probs = model(
        inputs, input_lengths, targets, teacher_forcing_ratio=1.0)
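
The loop above stops at the forward pass even though a criterion and an optimizer were created. A minimal continuation of the loop body might look like the sketch below; it assumes the first returned tensor has shape (batch, seq_len, num_classes) and lines up with targets[:, 1:], which can vary between kospeech versions.

    # Hypothetical continuation of the loop body above.
    loss = criterion(
        outputs.contiguous().view(-1, outputs.size(-1)),
        targets[:, 1:].contiguous().view(-1),
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()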
Example #6

import torch

from kospeech.models import ListenAttendSpell
from kospeech.models.las.encoder import EncoderRNN
from kospeech.models.las.decoder import DecoderRNN

B, T, D, H = 3, 12345, 80, 32

cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')

inputs = torch.rand(B, T, D).to(device)
input_lengths = torch.IntTensor([T, T - 100, T - 1000])
targets = torch.LongTensor([[1, 1, 2], [3, 4, 2], [7, 2, 0]])  # defined but not used by recognize()

model = ListenAttendSpell(
    input_dim=D,
    num_classes=10,
    encoder_hidden_state_dim=H,
    decoder_hidden_state_dim=H << 1,
    bidirectional=True,
    max_length=10,
).to(device)

outputs = model.recognize(inputs, input_lengths)
print(outputs.size())
print("LAS Recognize PASS")