def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
    batch_size = h.size(0)
    sos = int2onehot(h.new_full((batch_size, 1), self.sos), num_classes=self.label_vec_size).float()
    eos = int2onehot(h.new_full((batch_size, 1), self.eos), num_classes=self.label_vec_size).float()
    hidden = None
    y_hats = list()
    attentions = list()
    in_mask = self.get_mask(h, x_seq_lens) if self.masked_attend else None
    # initial decoder input: <sos> label concatenated with the first encoder frame
    x = torch.cat([sos, h.narrow(1, 0, 1)], dim=-1)
    y_hats_seq_lens = torch.ones((batch_size, ), dtype=torch.int) * self.max_seq_lens
    bi = torch.zeros((self.num_eos, batch_size)).byte()
    if x.is_cuda:
        bi = bi.cuda()
    for t in range(self.max_seq_lens):
        s, hidden = self.rnns(x, hidden)
        s = self.norm(s)
        c, a = self.attention(s, h, in_mask)
        y_hat = self.chardist(torch.cat([s, c], dim=-1))
        y_hat = self.softmax(y_hat)
        y_hats.append(y_hat)
        attentions.append(a)
        # check for `num_eos` consecutive eos occurrences per batch entry
        bi[t % self.num_eos] = onehot2int(y_hat.squeeze()).eq(self.eos)
        ri = y_hats_seq_lens.gt(t)
        if bi.is_cuda:
            ri = ri.cuda()
        y_hats_seq_lens[bi.prod(dim=0, dtype=torch.uint8) * ri] = t + 1
        # early termination once every batch entry has ended
        if y_hats_seq_lens.le(t + 1).all():
            break
        if y is None or not self._is_sample_step():
            # non-sampling step: feed back the model's own prediction
            x = torch.cat([y_hat, c], dim=-1)
        elif t < y.size(1):
            # scheduled sampling step: feed the ground-truth label
            x = torch.cat([y.narrow(1, t, 1), c], dim=-1)
        else:
            x = torch.cat([eos, c], dim=-1)
    y_hats = torch.cat(y_hats, dim=1)
    attentions = torch.cat(attentions, dim=2)
    return y_hats, y_hats_seq_lens, attentions
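# The decoder code above and below relies on int2onehot / onehot2int conversion
# helpers. The following is a minimal sketch of how such helpers could look; the
# bodies here are assumptions for illustration, not the repository's actual code.
import torch

def int2onehot(idx, num_classes):
    # idx: integer tensor of shape (batch, seq); returns (batch, seq, num_classes)
    onehot = torch.zeros(*idx.size(), num_classes, device=idx.device)
    return onehot.scatter_(-1, idx.long().unsqueeze(-1), 1.0)

def onehot2int(onehot):
    # inverse mapping: argmax over the last (class) dimension
    return onehot.argmax(dim=-1)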
def print_result(self, filename, ys_hat, words):
    logger.info(f"decoding wav file: {str(Path(filename).resolve())}")
    if self.verbose:
        labels = onehot2int(ys_hat)
        logger.info(f"labels: {' '.join([str(x) for x in labels.tolist()])}")
        symbols = [self.decoder.labeler.idx2phone(x.item()) for x in labels]
        logger.info(f"symbols: {' '.join(symbols)}")
    words = words.squeeze()
    text = ' '.join([self.decoder.labeler.idx2word(i) for i in words]) \
        if words.dim() else '<null output from decoder>'
    logger.info(f"decoded text: {text}")
def forward(self, h, y=None):
    batch_size = h.size(0)
    sos = int2onehot(h.new_full((batch_size, 1), self.sos), num_classes=self.label_vec_size).float()
    # initial decoder input: <sos> label concatenated with the first encoder frame
    x = torch.cat([sos, h.narrow(1, 0, 1)], dim=-1)
    hidden = [None] * self.rnn_num_layers
    y_hats = list()
    attentions = list()
    max_seq_len = self.max_seq_len if y is None else y.size(1)
    unit_len = torch.ones((batch_size, ))
    for i in range(max_seq_len):
        for l, rnn in enumerate(self.rnns):
            x, hidden[l] = rnn(x, unit_len, hidden[l])
        c, a = self.attention(x, h)
        y_hat = self.chardist(torch.cat([x, c], dim=-1))
        y_hats.append(y_hat)
        attentions.append(a)
        # stop iterating once eos has occurred in every batch entry
        if not onehot2int(y_hat.squeeze()).ne(self.eos).nonzero().numel():
            break
        if y is None:
            x = torch.cat([y_hat, c], dim=-1)
        else:
            # teacher forcing: feed the ground-truth label
            x = torch.cat([y.narrow(1, i, 1), c], dim=-1)
    y_hats = torch.cat(y_hats, dim=1)
    attentions = torch.cat(attentions, dim=2)
    # per-entry sequence length: position of the first eos, or max_seq_len if none
    seq_lens = torch.full((batch_size, ), max_seq_len, dtype=torch.int)
    for b, y_hat in enumerate(y_hats):
        idx = onehot2int(y_hat).eq(self.eos).nonzero()
        if idx.numel():
            seq_lens[b] = idx[0][0]
    return y_hats, seq_lens, attentions
def unit_validate(self, data):
    xs, ys, frame_lens, label_lens, filenames, _ = data
    if self.use_cuda:
        xs = xs.cuda(non_blocking=True)
    ys_hat, frame_lens = self.model(xs, frame_lens)
    if self.fp16:
        ys_hat = ys_hat.float()
    # convert likelihoods to ctc labels
    hyps = [onehot2int(yh[:s]).squeeze() for yh, s in zip(ys_hat, frame_lens)]
    hyps = [remove_duplicates(h, blank=0) for h in hyps]
    # slice the flattened targets into per-utterance references
    pos = torch.cat((torch.zeros((1, ), dtype=torch.long), torch.cumsum(label_lens, dim=0)))
    refs = [ys[s:l] for s, l in zip(pos[:-1], pos[1:])]
    return hyps, refs
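# unit_validate above collapses the greedy CTC output with remove_duplicates.
# Below is a minimal sketch of the usual CTC collapse rule (merge runs of
# identical labels, then drop the blank); the signature follows the call site,
# but the body is an assumption, not the repository's actual implementation.
def remove_duplicates(labels, blank=0):
    collapsed = []
    prev = None
    for l in labels.tolist():
        # keep a label only when it differs from the previous frame and is not blank
        if l != prev and l != blank:
            collapsed.append(l)
        prev = l
    return labels.new_tensor(collapsed)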
def unit_validate(self, data):
    xs, ys, frame_lens, label_lens, filenames, _ = data
    if self.use_cuda:
        xs = xs.cuda()
    ys_hat = self.model(xs)
    # slice the flattened frame-level outputs into per-utterance chunks
    pos = torch.cat((torch.zeros((1, ), dtype=torch.long), torch.cumsum(frame_lens, dim=0)))
    ys_hat = [ys_hat.narrow(0, p, l).clone() for p, l in zip(pos[:-1], frame_lens)]
    # convert likelihoods to phn labels
    hyps = [onehot2int(yh[:s]).squeeze() for yh, s in zip(ys_hat, frame_lens)]
    # slice the flattened targets into per-utterance references
    pos = torch.cat((torch.zeros((1, ), dtype=torch.long), torch.cumsum(label_lens, dim=0)))
    refs = [ys[s:l] for s, l in zip(pos[:-1], pos[1:])]
    return hyps, refs