Ejemplo n.º 1
0
    def forward(self, h, x_seq_lens, y=None, y_seq_lens=None):
        """Decode `h` one step at a time until every sequence has ended.

        Each step runs the recurrent stack, the normalisation, the attention
        and the output layer to produce one prediction.  A sequence counts as
        ended once `self.num_eos` consecutive steps predicted `self.eos`;
        decoding stops early when that holds for the whole batch.

        Args:
            h: encoder output.  Dim 0 is the batch; the first slice along
                dim 1 is concatenated to the sos vector to form the first
                decoder input.
            x_seq_lens: passed to `self.get_mask` together with `h` when
                `self.masked_attend` is set; otherwise unused.
            y: optional targets.  When given and `self._is_sample_step()`
                is true, the slice of `y` at step `t` is fed to the next
                step (or the eos vector once `t` runs past `y.size(1)`);
                otherwise the model's own prediction is fed back.
            y_seq_lens: accepted for interface compatibility; not read here.

        Returns:
            A tuple `(y_hats, y_hats_seq_lens, attentions)`: the per-step
            predictions concatenated along dim 1, the detected length of
            each sequence (`self.max_seq_lens` when no end was detected),
            and the per-step attention outputs concatenated along dim 2.
        """
        batch_size = h.size(0)
        sos = int2onehot(h.new_full((batch_size, 1), self.sos),
                         num_classes=self.label_vec_size).float()
        eos = int2onehot(h.new_full((batch_size, 1), self.eos),
                         num_classes=self.label_vec_size).float()

        hidden = None
        y_hats = list()
        attentions = list()

        in_mask = self.get_mask(h, x_seq_lens) if self.masked_attend else None
        x = torch.cat([sos, h.narrow(1, 0, 1)], dim=-1)

        # Created without an explicit device (as before), so the returned
        # lengths stay on the same device callers have always received.
        y_hats_seq_lens = torch.ones(
            (batch_size, ), dtype=torch.int) * self.max_seq_lens

        # Ring buffer holding, for the last `self.num_eos` steps, whether
        # each batch entry predicted eos at that step.
        bi = torch.zeros((
            self.num_eos,
            batch_size,
        )).byte()
        if x.is_cuda:
            bi = bi.cuda()

        for t in range(self.max_seq_lens):
            s, hidden = self.rnns(x, hidden)
            s = self.norm(s)
            c, a = self.attention(s, h, in_mask)
            y_hat = self.chardist(torch.cat([s, c], dim=-1))
            y_hat = self.softmax(y_hat)

            y_hats.append(y_hat)
            attentions.append(a)

            # check `self.num_eos` consecutive eos occurrences
            bi[t % self.num_eos] = onehot2int(y_hat.squeeze()).eq(self.eos)
            # The mask must live on the same device as the tensor it
            # indexes: bring the eos flags over to `y_hats_seq_lens` rather
            # than pushing the comparison result to the GPU.
            all_eos = bi.prod(dim=0, dtype=torch.uint8).to(
                y_hats_seq_lens.device)
            # only sequences whose length has not been fixed yet
            ri = y_hats_seq_lens.gt(t)
            y_hats_seq_lens[all_eos * ri] = t + 1

            # early termination
            if y_hats_seq_lens.le(t + 1).all():
                break

            if y is None or not self._is_sample_step():  # non sampling step
                x = torch.cat([y_hat, c], dim=-1)
            elif t < y.size(1):  # scheduled sampling step
                x = torch.cat([y.narrow(1, t, 1), c], dim=-1)
            else:
                x = torch.cat([eos, c], dim=-1)

        y_hats = torch.cat(y_hats, dim=1)
        attentions = torch.cat(attentions, dim=2)

        return y_hats, y_hats_seq_lens, attentions
Ejemplo n.º 2
0
 def print_result(self, filename, ys_hat, words):
     """Log the decoding result for one wav file.

     Always logs the resolved path of `filename` and the decoded text.
     When `self.verbose` is set, also logs the label indices obtained from
     `ys_hat` via `onehot2int` and their phone symbols.

     Args:
         filename: path of the decoded wav file.
         ys_hat: model output handed to `onehot2int` (verbose mode only).
         words: word indices from the decoder; after `squeeze()` a
             zero-dimensional result is reported as a null output.
     """
     logger.info(f"decoding wav file: {str(Path(filename).resolve())}")

     if self.verbose:
         labels = onehot2int(ys_hat)
         logger.info(f"labels: {' '.join([str(x) for x in labels.tolist()])}")
         symbols = []
         for label in labels:
             symbols.append(self.decoder.labeler.idx2phone(label.item()))
         logger.info(f"symbols: {' '.join(symbols)}")

     words = words.squeeze()
     if words.dim():
         # at least one dimension left, so the tensor can be iterated
         text = ' '.join([self.decoder.labeler.idx2word(i) for i in words])
     else:
         text = '<null output from decoder>'
     logger.info(f"decoded text: {text}")
Ejemplo n.º 3
0
    def forward(self, h, y=None):
        """Decode `h` step by step, teacher-forced when `y` is given.

        Args:
            h: encoder output.  Dim 0 is the batch; the first slice along
                dim 1 is concatenated to the sos vector to form the first
                decoder input.
            y: optional targets.  When given, the number of steps is
                `y.size(1)` and the slice of `y` at the current step is fed
                to the next one; otherwise `self.max_seq_len` steps are run
                and the model's own output is fed back.

        Returns:
            A tuple `(y_hats, seq_lens, attentions)`: the per-step outputs
            concatenated along dim 1, for each batch entry the position of
            its first eos prediction (the step limit when there is none),
            and the per-step attention outputs concatenated along dim 2.
        """
        n_batch = h.size(0)
        sos_vec = int2onehot(h.new_full((n_batch, 1), self.sos),
                             num_classes=self.label_vec_size).float()
        step_in = torch.cat([sos_vec, h.narrow(1, 0, 1)], dim=-1)

        states = [None for _ in range(self.rnn_num_layers)]
        outputs = []
        attn_maps = []
        if y is None:
            max_seq_len = self.max_seq_len
        else:
            max_seq_len = y.size(1)
        unit_len = torch.ones((n_batch, ))

        for step in range(max_seq_len):
            # feed the single-step input through every recurrent layer
            for layer, rnn in enumerate(self.rnns):
                step_in, states[layer] = rnn(step_in, unit_len, states[layer])
            context, attn = self.attention(step_in, h)
            dist = self.chardist(torch.cat([step_in, context], dim=-1))

            outputs.append(dist)
            attn_maps.append(attn)

            # stop as soon as no batch entry predicts anything but eos
            not_eos = onehot2int(dist.squeeze()).ne(self.eos)
            if not_eos.nonzero().numel() == 0:
                break

            # next input: own prediction, or the target when teacher forcing
            if y is None:
                feedback = dist
            else:
                feedback = y.narrow(1, step, 1)
            step_in = torch.cat([feedback, context], dim=-1)

        y_hats = torch.cat(outputs, dim=1)
        attentions = torch.cat(attn_maps, dim=2)

        seq_lens = torch.full((n_batch, ), max_seq_len, dtype=torch.int)
        for b, seq in enumerate(y_hats):
            eos_pos = onehot2int(seq).eq(self.eos).nonzero()
            if eos_pos.numel():
                seq_lens[b] = eos_pos[0][0]

        return y_hats, seq_lens, attentions
Ejemplo n.º 4
0
 def unit_validate(self, data):
     """Run the model on one validation batch.

     Args:
         data: a 6-item batch; the items used are the inputs, the
             concatenated targets, the frame lengths and the label
             lengths.  The remaining items are ignored.

     Returns:
         A tuple `(hyps, refs)`: one hypothesis per utterance (labels from
         `onehot2int` over the valid frames, passed through
         `remove_duplicates` with blank 0) and the matching target slices.
     """
     xs, ys, frame_lens, label_lens, _filenames, _ = data
     if self.use_cuda:
         xs = xs.cuda(non_blocking=True)
     ys_hat, frame_lens = self.model(xs, frame_lens)
     if self.fp16:
         ys_hat = ys_hat.float()

     # turn the model output into CTC label sequences
     raw_paths = []
     for scores, n_frames in zip(ys_hat, frame_lens):
         raw_paths.append(onehot2int(scores[:n_frames]).squeeze())
     hyps = [remove_duplicates(path, blank=0) for path in raw_paths]

     # cut the concatenated targets back into one slice per utterance
     zero = torch.zeros((1, ), dtype=torch.long)
     bounds = torch.cat((zero, torch.cumsum(label_lens, dim=0)))
     refs = [ys[begin:end] for begin, end in zip(bounds[:-1], bounds[1:])]
     return hyps, refs
Ejemplo n.º 5
0
 def unit_validate(self, data):
     """Run the model on one validation batch of packed frames.

     The model output is split along dim 0 into one chunk per utterance
     using the cumulative frame lengths, and each chunk is converted to
     labels with `onehot2int`.

     Args:
         data: a 6-item batch; the items used are the inputs, the
             concatenated targets, the frame lengths and the label
             lengths.  The remaining items are ignored.

     Returns:
         A tuple `(hyps, refs)`: one label sequence per utterance and the
         matching target slices.
     """
     xs, ys, frame_lens, label_lens, _filenames, _ = data
     if self.use_cuda:
         xs = xs.cuda()
     ys_hat = self.model(xs)

     # split the output into per-utterance chunks
     frame_bounds = torch.cat((
         torch.zeros((1, ), dtype=torch.long),
         torch.cumsum(frame_lens, dim=0),
     ))
     chunks = []
     for start, length in zip(frame_bounds[:-1], frame_lens):
         chunks.append(ys_hat.narrow(0, start, length).clone())

     # convert each chunk to phone labels
     hyps = []
     for chunk, n_frames in zip(chunks, frame_lens):
         hyps.append(onehot2int(chunk[:n_frames]).squeeze())

     # cut the concatenated targets back into one slice per utterance
     label_bounds = torch.cat((
         torch.zeros((1, ), dtype=torch.long),
         torch.cumsum(label_lens, dim=0),
     ))
     refs = [
         ys[begin:end]
         for begin, end in zip(label_bounds[:-1], label_bounds[1:])
     ]
     return hyps, refs