Example #1
def forward(self, state, x):
    # lazily create zero-initialized LSTM states on the first call,
    # placed on the same device as the model
    if state is None:
        state = {
            'c1': to_cuda(self, self.zero_state(x.size(0))),
            'h1': to_cuda(self, self.zero_state(x.size(0))),
            'c2': to_cuda(self, self.zero_state(x.size(0))),
            'h2': to_cuda(self, self.zero_state(x.size(0)))
        }
    # embedding, then two dropout + LSTM cell layers, then output projection
    h0 = self.embed(x)
    h1, c1 = self.l1(self.d0(h0), (state['h1'], state['c1']))
    h2, c2 = self.l2(self.d1(h1), (state['h2'], state['c2']))
    y = self.lo(self.d2(h2))
    # hand the updated state back so the caller can thread it through time
    state = {'c1': c1, 'h1': h1, 'c2': c2, 'h2': h2}
    return state, y
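All four examples funnel tensors through a to_cuda(self, ...) helper before use. The helper itself is not shown anywhere in this section; below is a minimal sketch of what it plausibly does, assuming it simply moves a tensor onto whatever device the module's parameters occupy (an assumption, not the verbatim library implementation):

    import torch

    def to_cuda(module, x):
        # Sketch of the assumed helper: place x on the same device as the
        # module's parameters, so CPU-built tensors follow a GPU-resident
        # model automatically.
        device = next(module.parameters()).device
        return x.to(device)

With a helper along these lines, the state initialization above works unchanged whether the language model sits on the CPU or on a GPU.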
Example #2
    def recognize(self, x, recog_args, char_list):
        '''E2E greedy/beam search

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argument namespace containing decoding options (e.g. beam_size)
        :param list char_list: list of output characters
        :return: decoding result
        '''
        # remember the current mode and switch to evaluation
        prev = self.training
        self.eval()
        # subsample input frames
        x = x[::self.subsample[0], :]
        xlen = [x.shape[0]]
        # volatile=True disables autograd history in pre-0.4 PyTorch (inference mode)
        xpad = base.to_cuda(
            self,
            Variable(torch.from_numpy(np.array(x, dtype=np.float32)),
                     volatile=True))

        # 1. encoder
        # wrap the single utterance in a batch of one to reuse the encoder interface
        h, hlen = self.enc(xpad.unsqueeze(0), xlen)
        # h, hlen = self.forward_common(h, hlen)
        lpz = None

        # 2. decoder
        # decode the first (and only) utterance
        if recog_args.beam_size == 1:
            y = self.dec.recognize(h[0], recog_args)
        else:
            y = self.dec.recognize_beam(h[0], lpz, recog_args, char_list)

        # restore the previous training mode
        if prev:
            self.train()
        return y
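A hedged usage sketch for the method above. The model instance, feature matrix, and character list here are illustrative assumptions; only beam_size is actually referenced by the snippet itself:

    import numpy as np
    from argparse import Namespace

    # Toy (T, D) acoustic feature matrix, e.g. 300 frames of 83-dim filterbanks.
    feats = np.random.randn(300, 83).astype(np.float32)
    char_list = ['<blank>', '<unk>', 'a', 'b', 'c', '<eos>']  # hypothetical

    # beam_size == 1 takes the greedy branch; anything larger runs beam search.
    recog_args = Namespace(beam_size=5)

    # With a trained model in hand (not constructed here):
    # hyp = model.recognize(feats, recog_args, char_list)

Because the method flips the module into eval mode and restores the previous training mode afterwards, it is safe to call in the middle of a training run.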
Example #3
    def forward(self, xs, ilens, ys, labels, olens=None, spembs=None):
        """TACOTRON2 LOSS FORWARD CALCULATION

        :param torch.Tensor xs: batch of padded character ids (B, Tmax)
        :param list ilens: list of lengths of each input sequence in the batch (B)
        :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
        :param torch.Tensor labels: batch of stop token label sequences (B, Lmax)
        :param list olens: list of lengths of each target sequence in the batch (B)
        :param torch.Tensor spembs: batch of speaker embedding vectors (B, spk_embed_dim)
        :return: loss value
        :rtype: torch.Tensor
        """
        after_outs, before_outs, logits = self.model(xs, ilens, ys, spembs)
        if self.use_masking and olens is not None:
            # weight positive samples
            if self.bce_pos_weight != 1.0:
                # TODO(kan-bayashi): need to be fixed in pytorch v4
                weights = ys.data.new(*labels.size()).fill_(1)
                if torch_is_old:
                    weights = Variable(weights, volatile=ys.volatile)
                weights.masked_fill_(labels.eq(1), self.bce_pos_weight)
            else:
                weights = None
            # masking padded values
            mask = to_cuda(self, make_mask(olens, ys.size(2)))
            ys = ys.masked_select(mask)
            after_outs = after_outs.masked_select(mask)
            before_outs = before_outs.masked_select(mask)
            labels = labels.masked_select(mask[:, :, 0])
            logits = logits.masked_select(mask[:, :, 0])
            weights = weights.masked_select(
                mask[:, :, 0]) if weights is not None else None
            # calculate loss
            l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
            mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
            bce_loss = F.binary_cross_entropy_with_logits(
                logits, labels, weights)
            loss = l1_loss + mse_loss + bce_loss
        else:
            # calculate loss
            l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
            mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
            bce_loss = F.binary_cross_entropy_with_logits(logits, labels)
            loss = l1_loss + mse_loss + bce_loss

        # report loss values for logging
        loss_data = loss.data[0] if torch_is_old else loss.item()
        l1_loss_data = l1_loss.data[0] if torch_is_old else l1_loss.item()
        bce_loss_data = bce_loss.data[0] if torch_is_old else bce_loss.item()
        mse_loss_data = mse_loss.data[0] if torch_is_old else mse_loss.item()
        logging.debug("loss = %.3e (bce: %.3e, l1: %.3e, mse: %.3e)" %
                      (loss_data, bce_loss_data, l1_loss_data, mse_loss_data))
        self.reporter.report(l1_loss_data, mse_loss_data, bce_loss_data,
                             loss_data)

        return loss
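The masking branch depends on a make_mask(olens, odim) helper that the snippet does not define. Below is a minimal sketch under the assumption that it returns a (B, Lmax, odim) boolean mask, True over real frames and False over padding, so that masked_select keeps only valid positions:

    import torch

    def make_mask(olens, odim):
        # Assumed behavior: one mask row per utterance, True up to its
        # true length olens[b], False over the padded tail.
        maxlen = max(olens)
        mask = torch.zeros(len(olens), maxlen, odim, dtype=torch.bool)
        for b, olen in enumerate(olens):
            mask[b, :olen] = True
        return mask

Under this assumption, mask[:, :, 0] yields the matching (B, Lmax) mask that the snippet applies to the frame-level stop token labels and logits.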
Example #4
def sort_variables(self, xs, sorted_index):
    # reorder the utterances according to sorted_index (typically longest first)
    xs = [xs[i] for i in sorted_index]
    # wrap each array as a tensor and move it to the model's device
    xs = [base.to_cuda(self, Variable(torch.from_numpy(xx))) for xx in xs]
    # collect the sequence lengths in the new order
    xlens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
    return xs, xlens
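A hedged usage sketch: sorted_index is typically an argsort of the batch by descending length, the order that packed RNN inputs (pack_padded_sequence) expect. The names below are illustrative, not from the source:

    import numpy as np

    # Toy batch: three variable-length utterances of 40-dim features.
    xs = [np.random.randn(t, 40).astype(np.float32) for t in (7, 12, 5)]

    # Longest-first order for packed-sequence consumers.
    sorted_index = sorted(range(len(xs)), key=lambda i: -xs[i].shape[0])

    # On a model exposing the method above (not constructed here):
    # xs_sorted, xlens = model.sort_variables(xs, sorted_index)
    # xlens would then be array([12, 7, 5])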