    def forward(self, state, x):
        # update state with input label x
        if state is None:  # make initial states and cumulative probability vector
            self.var_word_eos = to_cuda(self, self.var_word_eos)
            self.var_word_unk = to_cuda(self, self.var_word_unk)
            wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
            cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
            new_node = self.lexroot
            xi = self.space
        else:
            wlm_state, cumsum_probs, node = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                if node is not None and node[1] >= 0:  # check if the node is word end
                    w = to_cuda(self, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update wordlm state and cumulative probability vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
                new_node = self.lexroot  # move to the tree root
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
            elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
                new_node = None
            else:  # if open_vocab flag is disabled, return 0 probabilities
                log_y = torch.full((1, self.subword_dict_size), self.logzero)
                return (wlm_state, None, None), log_y

        if new_node is not None:
            succ, wid, wids = new_node
            # compute parent node probability
            sum_prob = (cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]]) if wids is not None else 1.0
            if sum_prob < self.zero:
                log_y = torch.full((1, self.subword_dict_size), self.logzero)
                return (wlm_state, cumsum_probs, new_node), log_y
            # set <unk> probability as a default value
            unk_prob = cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
            y = torch.full((1, self.subword_dict_size), float(unk_prob) * self.oov_penalty)
            # compute transition probabilities to child nodes
            for cid, nd in succ.items():
                y[:, cid] = (cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]) / sum_prob
            # apply word-level probabilities for <space> and <eos> labels
            if wid >= 0:
                wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
                y[:, self.space] = wlm_prob
                y[:, self.eos] = wlm_prob
            elif xi == self.space:
                y[:, self.space] = self.zero
                y[:, self.eos] = self.zero
            log_y = torch.log(torch.max(y, self.zero_tensor))  # clip to avoid log(0)
        else:  # if no path in the tree, transition probability is one
            log_y = torch.zeros(1, self.subword_dict_size)
        return (wlm_state, cumsum_probs, new_node), log_y
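The look-ahead above relies on the word ids being arranged so that all words sharing the current subword prefix occupy one contiguous id range, stored in each tree node as wids; the prefix mass is then a single difference of cumulative softmax values. A minimal stand-alone sketch of that trick (toy logits and an assumed id range, not the model above):

import torch
import torch.nn.functional as F

z_wlm = torch.randn(1, 10)  # toy word-LM logits over 10 word ids
cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)

# assume the words sharing the current prefix occupy the contiguous id range (lo, hi]
lo, hi = 3, 7
prefix_mass = cumsum_probs[:, hi] - cumsum_probs[:, lo]
assert torch.allclose(prefix_mass, F.softmax(z_wlm, dim=1)[:, lo + 1:hi + 1].sum(dim=1))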
    def forward(self, state, x):
        # update state with input label x
        if state is None:  # make initial states and log-prob vectors
            self.var_word_eos = to_cuda(self, self.var_word_eos)
            self.var_word_unk = to_cuda(self, self.var_word_unk)
            wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
            wlm_logprobs = F.log_softmax(z_wlm, dim=1)
            clm_state, z_clm = self.subwordlm(None, x)
            log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
            new_node = self.lexroot
            clm_logprob = 0.
            xi = self.space
        else:
            clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                if node is not None and node[1] >= 0:  # check if the node is word end
                    w = to_cuda(self, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update wordlm state and log-prob vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                wlm_logprobs = F.log_softmax(z_wlm, dim=1)
                new_node = self.lexroot  # move to the tree root
                clm_logprob = 0.
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
                clm_logprob += log_y[0, xi]
            elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
                new_node = None
                clm_logprob += log_y[0, xi]
            else:  # if open_vocab flag is disabled, return 0 probabilities
                log_y = torch.full((1, self.subword_dict_size), self.logzero)
                return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.), log_y

            clm_state, z_clm = self.subwordlm(clm_state, x)
            log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight

        # apply word-level probabilities for <space> and <eos> labels
        if xi != self.space:
            if new_node is not None and new_node[1] >= 0:  # if new node is word end
                wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
            else:
                wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
            log_y[:, self.space] = wlm_logprob
            log_y[:, self.eos] = wlm_logprob
        else:
            log_y[:, self.space] = self.logzero
            log_y[:, self.eos] = self.logzero

        return (clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)), log_y
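On the line "wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob", subtracting the subword-level score accumulated inside the word makes it cancel once the <space> transition is added, so a completed in-vocabulary word ends up scored by the word LM alone. A toy check with assumed numbers:

clm_steps = [-1.2, -0.7, -2.1]        # assumed per-subword log-probs inside one word
clm_logprob = sum(clm_steps)          # what the snippet accumulates step by step
wlm_logprob_word = -3.5               # assumed word-LM log-prob of the completed word
space_score = wlm_logprob_word - clm_logprob
total = clm_logprob + space_score     # subword steps + <space> transition
assert abs(total - wlm_logprob_word) < 1e-12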
Example #3
    def forward(self, state, x):
        # make initial hidden and cell states for both LSTM layers on the first call
        if state is None:
            state = {
                'c1': to_cuda(self, self.zero_state(x.size(0))),
                'h1': to_cuda(self, self.zero_state(x.size(0))),
                'c2': to_cuda(self, self.zero_state(x.size(0))),
                'h2': to_cuda(self, self.zero_state(x.size(0)))
            }
        # embed the input label and run it through two LSTM cells with dropout
        h0 = self.embed(x)
        h1, c1 = self.l1(self.d0(h0), (state['h1'], state['c1']))
        h2, c2 = self.l2(self.d1(h1), (state['h2'], state['c2']))
        y = self.lo(self.d2(h2))
        state = {'c1': c1, 'h1': h1, 'c2': c2, 'h2': h2}
        return state, y
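All of the forward(self, state, x) methods above share the same stateful interface: pass state=None on the first call, then feed the returned state back in with each new label. A hedged usage sketch with a stub module that only mimics that interface (the stub is not one of the models above):

import torch
import torch.nn as nn

class StubLM(nn.Module):
    """Minimal stand-in that follows the (state, x) -> (state, y) convention."""
    def __init__(self, n_vocab=5, n_units=8):
        super().__init__()
        self.n_units = n_units
        self.embed = nn.Embedding(n_vocab, n_units)
        self.cell = nn.LSTMCell(n_units, n_units)
        self.lo = nn.Linear(n_units, n_vocab)

    def forward(self, state, x):
        if state is None:  # first call: zero hidden and cell states
            h = x.new_zeros(x.size(0), self.n_units, dtype=torch.float)
            state = {'h': h, 'c': h.clone()}
        h, c = self.cell(self.embed(x), (state['h'], state['c']))
        return {'h': h, 'c': c}, self.lo(h)

lm = StubLM()
state = None
for label in [1, 4, 2]:  # feed labels one at a time, threading the state through
    state, y = lm(state, torch.LongTensor([label]))
print(y.shape)  # torch.Size([1, 5])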
Example #4
def get_asr_data(self, data, sort_by):
    # utt list of frame x dim
    xs = [d[1]['feat_asr'] for d in data]
    # remove 0-output-length utterances
    tids = [d[1]['output'][0]['tokenid'].split() for d in data]
    filtered_index = filter(lambda i: len(tids[i]) > 0, range(len(xs)))
    if sort_by == 'feat':
        sorted_index = sorted(filtered_index, key=lambda i: -len(xs[i]))
    elif sort_by == 'text':
        sorted_index = sorted(filtered_index, key=lambda i: -len(tids[i]))
    else:
        logging.error("Error: specify 'text' or 'feat' to sort")
        sys.exit()
    if len(sorted_index) != len(xs):
        logging.warning(
            'Target sequences include empty tokenid (batch %d -> %d).' %
            (len(xs), len(sorted_index)))
    xs = [xs[i] for i in sorted_index]
    # utt list of olen
    texts = [
        np.fromiter(map(int, tids[i]), dtype=np.int64) for i in sorted_index
    ]
    if torch_is_old:
        texts = [
            to_cuda(self,
                    Variable(torch.from_numpy(y), volatile=not self.training))
            for y in texts
        ]
    else:
        texts = [to_cuda(self, torch.from_numpy(y)) for y in texts]

    # subsample frame
    xs = [xx[::self.subsample[0], :] for xx in xs]
    featlens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
    if torch_is_old:
        hs = [
            to_cuda(self,
                    Variable(torch.from_numpy(xx), volatile=not self.training))
            for xx in xs
        ]
    else:
        hs = [to_cuda(self, torch.from_numpy(xx)) for xx in xs]

    # 1. encoder
    feats = pad_list(hs, 0.0)

    return texts, feats, featlens
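A small stand-alone illustration of the sort-by-length and padding step, using torch.nn.utils.rnn.pad_sequence as an assumed substitute for the pad_list helper used above:

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

xs = [np.random.randn(t, 3).astype(np.float32) for t in (7, 4, 9)]  # toy utterances
order = sorted(range(len(xs)), key=lambda i: -len(xs[i]))           # longest first
hs = [torch.from_numpy(xs[i]) for i in order]
feats = pad_sequence(hs, batch_first=True, padding_value=0.0)       # (B, Tmax, 3)
featlens = np.fromiter((h.shape[0] for h in hs), dtype=np.int64)
print(feats.shape, featlens)  # torch.Size([3, 9, 3]) [9 7 4]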
Example #5
    def forward(self, state, x):
        if state is None:
            c = [
                to_cuda(self, self.zero_state(x.size(0)))
                for n in six.moves.range(self.n_layers)
            ]
            h = [
                to_cuda(self, self.zero_state(x.size(0)))
                for n in six.moves.range(self.n_layers)
            ]
            state = {'c': c, 'h': h}

        h = [None] * self.n_layers
        c = [None] * self.n_layers
        emb = self.embed(x)
        h[0], c[0] = self.lstm[0](self.dropout[0](emb),
                                  (state['h'][0], state['c'][0]))
        for n in six.moves.range(1, self.n_layers):
            h[n], c[n] = self.lstm[n](self.dropout[n](h[n - 1]),
                                      (state['h'][n], state['c'][n]))
        y = self.lo(self.dropout[-1](h[-1]))
        state = {'c': c, 'h': h}
        return state, y
Example #6
    def forward(self, xs, ilens, ys, labels, olens=None, spembs=None):
        """TACOTRON2 LOSS FORWARD CALCULATION

        :param torch.Tensor xs: batch of padded character ids (B, Tmax)
        :param list ilens: list of lengths of each input batch (B)
        :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
        :param torch.Tensor labels: batch of the sequences of stop token labels (B, Lmax)
        :param list olens: batch of the lengths of each target (B)
        :param torch.Tensor spembs: batch of speaker embedding vector (B, spk_embed_dim)
        :return: loss value
        :rtype: torch.Tensor
        """
        after_outs, before_outs, logits = self.model(xs, ilens, ys, spembs)
        if self.use_masking and olens is not None:
            # weight positive samples
            if self.bce_pos_weight != 1.0:
                weights = ys.new(*labels.size()).fill_(1)
                weights.masked_fill_(labels.eq(1), self.bce_pos_weight)
            else:
                weights = None
            # masking padded values
            mask = to_cuda(self, make_mask(olens, ys.size(2)))
            ys = ys.masked_select(mask)
            after_outs = after_outs.masked_select(mask)
            before_outs = before_outs.masked_select(mask)
            labels = labels.masked_select(mask[:, :, 0])
            logits = logits.masked_select(mask[:, :, 0])
            weights = weights.masked_select(
                mask[:, :, 0]) if weights is not None else None
            # calculate loss
            l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
            mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
            bce_loss = F.binary_cross_entropy_with_logits(
                logits, labels, weights)
            loss = l1_loss + mse_loss + bce_loss
        else:
            # calculate loss
            l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
            mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
            bce_loss = F.binary_cross_entropy_with_logits(logits, labels)
            loss = l1_loss + mse_loss + bce_loss

        # report loss values for logging
        logging.debug(
            "loss = %.3e (bce: %.3e, l1: %.3e, mse: %.3e)" %
            (loss.item(), bce_loss.item(), l1_loss.item(), mse_loss.item()))
        self.reporter.report(l1_loss.item(), mse_loss.item(), bce_loss.item(),
                             loss.item())

        return loss
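The masking branch above selects only the valid frames before computing each loss term. A toy sketch with random tensors, building a simple length mask with arange in place of the make_mask helper (an assumption about its behaviour):

import torch
import torch.nn.functional as F

B, Lmax, odim = 2, 5, 3
olens = torch.tensor([5, 3])
ys = torch.randn(B, Lmax, odim)
outs = torch.randn(B, Lmax, odim)
logits = torch.randn(B, Lmax)
labels = torch.zeros(B, Lmax)
labels[0, -1] = labels[1, 2] = 1.0  # stop label at the last valid frame

mask = (torch.arange(Lmax)[None, :] < olens[:, None]).unsqueeze(-1)  # (B, Lmax, 1)
l1_loss = F.l1_loss(outs.masked_select(mask), ys.masked_select(mask))
bce_loss = F.binary_cross_entropy_with_logits(logits.masked_select(mask[:, :, 0]),
                                              labels.masked_select(mask[:, :, 0]))
print(float(l1_loss), float(bce_loss))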
Example #7
    def forward(self, xs, ilens, ys, labels, olens, spembs=None, spcs=None):
        """TACOTRON2 LOSS FORWARD CALCULATION

        :param torch.Tensor xs: batch of padded character ids (B, Tmax)
        :param list ilens: list of lengths of each input batch (B)
        :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
        :param torch.Tensor labels: batch of the sequences of stop token labels (B, Lmax)
        :param list olens: batch of the lengths of each target (B)
        :param torch.Tensor spembs: batch of speaker embedding vector (B, spk_embed_dim)
        :param torch.Tensor spcs: batch of padded target features (B, Lmax, spc_dim)
        :return: loss value
        :rtype: torch.Tensor
        """
        # calculate outputs
        if self.use_cbhg:
            cbhg_outs, after_outs, before_outs, logits = self.model(xs, ilens, ys, olens, spembs)
        else:
            after_outs, before_outs, logits = self.model(xs, ilens, ys, olens, spembs)

        # remove mod part
        if self.reduction_factor > 1:
            olens = [olen - olen % self.reduction_factor for olen in olens]
            ys = ys[:, :max(olens)]
            labels = labels[:, :max(olens)]
            spcs = spcs[:, :max(olens)] if spcs is not None else None

        # prepare the weight of positive samples in cross entropy
        if self.bce_pos_weight != 1.0:
            weights = ys.new(*labels.size()).fill_(1)
            weights.masked_fill_(labels.eq(1), self.bce_pos_weight)
        else:
            weights = None

        # perform masking for padded values
        if self.use_masking:
            mask = to_cuda(self, make_non_pad_mask(olens).unsqueeze(-1))
            ys = ys.masked_select(mask)
            after_outs = after_outs.masked_select(mask)
            before_outs = before_outs.masked_select(mask)
            labels = labels.masked_select(mask[:, :, 0])
            logits = logits.masked_select(mask[:, :, 0])
            weights = weights.masked_select(mask[:, :, 0]) if weights is not None else None
            if self.use_cbhg:
                spcs = spcs.masked_select(mask)
                cbhg_outs = cbhg_outs.masked_select(mask)

        # calculate loss
        l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
        mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
        bce_loss = F.binary_cross_entropy_with_logits(logits, labels, weights)
        if self.use_cbhg:
            # calculate cbhg losses and then integrate them
            cbhg_l1_loss = F.l1_loss(cbhg_outs, spcs)
            cbhg_mse_loss = F.mse_loss(cbhg_outs, spcs)
            loss = l1_loss + mse_loss + bce_loss + cbhg_l1_loss + cbhg_mse_loss
            # report loss values for logging
            self.reporter.report([
                {'l1_loss': l1_loss.item()},
                {'mse_loss': mse_loss.item()},
                {'bce_loss': bce_loss.item()},
                {'cbhg_l1_loss': cbhg_l1_loss.item()},
                {'cbhg_mse_loss': cbhg_mse_loss.item()},
                {'loss': loss.item()}])
        else:
            # integrate loss
            loss = l1_loss + mse_loss + bce_loss
            # report loss values for logging
            self.reporter.report([
                {'l1_loss': l1_loss.item()},
                {'mse_loss': mse_loss.item()},
                {'bce_loss': bce_loss.item()},
                {'loss': loss.item()}])

        return loss
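A toy check of the "remove mod part" step above: when reduction_factor > 1, each target length is rounded down to a multiple of the reduction factor and the padded targets are trimmed to the new maximum (assumed shapes):

import torch

reduction_factor = 2
olens = [7, 5]
ys = torch.randn(2, 7, 4)
olens = [olen - olen % reduction_factor for olen in olens]  # [6, 4]
ys = ys[:, :max(olens)]                                     # now (2, 6, 4)
print(olens, ys.shape)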
Example #8
def get_tts_data(self, data, sort_by, use_speaker_embedding=None):
    # get eos
    eos = str(int(data[0][1]['output'][0]['shape'][1]) - 1)

    # get target features and input character sequence
    texts = [b[1]['output'][0]['tokenid'].split() + [eos] for b in data]
    feats = [b[1]['feat_tts'] for b in data]

    # remove empty sequences and sort by length
    filtered_idx = filter(lambda i: len(texts[i]) > 0, range(len(feats)))
    if sort_by == 'feat':
        sorted_idx = sorted(filtered_idx, key=lambda i: -len(feats[i]))
    elif sort_by == 'text':
        sorted_idx = sorted(filtered_idx, key=lambda i: -len(texts[i]))
    else:
        logging.error("Error: specify 'text' or 'feat' to sort")
        sys.exit()
    texts = [
        np.fromiter(map(int, texts[i]), dtype=np.int64) for i in sorted_idx
    ]
    feats = [feats[i] for i in sorted_idx]

    # get list of lengths (must be tensor for DataParallel)
    textlens = torch.from_numpy(
        np.fromiter((x.shape[0] for x in texts), dtype=np.int64))
    featlens = torch.from_numpy(
        np.fromiter((y.shape[0] for y in feats), dtype=np.int64))

    # perform padding and convert to tensor
    texts = torch.from_numpy(pad_ndarray_list(texts, 0)).long()
    feats = torch.from_numpy(pad_ndarray_list(feats, 0)).float()

    # make labels for stop prediction
    labels = feats.new(feats.size(0), feats.size(1)).zero_()
    for i, l in enumerate(featlens):
        labels[i, l - 1:] = 1

    if torch_is_old:
        texts = to_cuda(self, texts, volatile=not self.training)
        feats = to_cuda(self, feats, volatile=not self.training)
        labels = to_cuda(self, labels, volatile=not self.training)
    else:
        texts = to_cuda(self, texts)
        feats = to_cuda(self, feats)
        labels = to_cuda(self, labels)

    # load speaker embedding
    if use_speaker_embedding is not None:
        spembs = [b[1]['feat_spembs'] for b in data]
        spembs = [spembs[i] for i in sorted_idx]
        spembs = torch.from_numpy(np.array(spembs)).float()

        if torch_is_old:
            spembs = to_cuda(self, spembs, volatile=not self.training)
        else:
            spembs = to_cuda(self, spembs)
    else:
        spembs = None

    if self.return_targets:
        return texts, textlens, feats, labels, featlens, spembs
    else:
        return texts, textlens, feats, spembs
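A small stand-alone version of the stop-label construction above: each utterance gets label 1 from its last valid frame onward and 0 before it (toy lengths):

import torch

featlens = torch.tensor([4, 2])
feats = torch.zeros(2, 5, 3)  # toy padded target features (B, Lmax, odim)
labels = feats.new(feats.size(0), feats.size(1)).zero_()
for i, l in enumerate(featlens):
    labels[i, l - 1:] = 1
print(labels)
# tensor([[0., 0., 0., 1., 1.],
#         [0., 1., 1., 1., 1.]])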
Example #9
    def forward(self, data, return_hidden=False, return_inout=False):
        asr_texts, asr_feats, asr_featlens = get_asr_data(self, data, 'feat')
        tts_texts, tts_textlens, tts_feats, tts_labels, tts_featlens, spembs = \
            get_tts_data(self, data, 'feat', self.use_speaker_embedding)

        # encoder
        hpad_pre_spk, hlens, feat_input, feat_len = self.asr_enc(
            asr_feats, asr_featlens, True)
        if self.use_speaker_embedding is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hpad_pre_spk.size(1), -1)
            hpad = torch.cat([hpad_pre_spk, spembs], dim=-1)
        else:
            hpad = hpad_pre_spk

        after_outs, before_outs_, logits, att_ws = self.tts_dec(
            hpad, hlens.tolist(), tts_feats)
        # copied from e2e_tts_th.py
        if self.use_masking and tts_featlens is not None:
            # weight positive samples
            if self.bce_pos_weight != 1.0:
                # TODO(kan-bayashi): need to be fixed in pytorch v4
                weights = tts_feats.data.new(*tts_labels.size()).fill_(1)
                if torch_is_old:
                    weights = Variable(weights, volatile=tts_feats.volatile)
                weights.masked_fill_(tts_labels.eq(1), self.bce_pos_weight)
            else:
                weights = None
            # masking padded values
            mask = to_cuda(self, make_mask(tts_featlens, tts_feats.size(2)))
            feats = tts_feats.masked_select(mask)
            after_outs = after_outs.masked_select(mask)
            before_outs = before_outs_.masked_select(mask)
            labels = tts_labels.masked_select(mask[:, :, 0])
            logits = logits.masked_select(mask[:, :, 0])
            weights = weights.masked_select(
                mask[:, :, 0]) if weights is not None else None
            # calculate loss
            l1_loss = F.l1_loss(after_outs, feats) + F.l1_loss(
                before_outs, feats)
            mse_loss = F.mse_loss(after_outs, feats) + F.mse_loss(
                before_outs, feats)
            bce_loss = F.binary_cross_entropy_with_logits(
                logits, labels, weights)
            loss = l1_loss + mse_loss + bce_loss
        else:
            # calculate loss
            l1_loss = F.l1_loss(after_outs, tts_feats) + F.l1_loss(
                before_outs, tts_feats)
            mse_loss = F.mse_loss(after_outs, tts_feats) + F.mse_loss(
                before_outs, tts_feats)
            bce_loss = F.binary_cross_entropy_with_logits(logits, tts_labels)
            loss = l1_loss + mse_loss + bce_loss

        # report loss values for logging
        loss_data = loss.data[0] if torch_is_old else loss.item()
        l1_loss_data = l1_loss.data[0] if torch_is_old else l1_loss.item()
        bce_loss_data = bce_loss.data[0] if torch_is_old else bce_loss.item()
        mse_loss_data = mse_loss.data[0] if torch_is_old else mse_loss.item()
        logging.debug("loss = %.3e (bce: %.3e, l1: %.3e, mse: %.3e)" %
                      (loss_data, bce_loss_data, l1_loss_data, mse_loss_data))
        if return_inout:
            return loss, feat_input, feat_len, before_outs_, tts_featlens
        if return_hidden:
            return loss, hpad_pre_spk, hlens
        return loss
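A toy sketch (assumed sizes) of the speaker-embedding broadcast used in the encoder step above: the embedding is L2-normalized, repeated over the time axis, and concatenated to the encoder output along the feature dimension:

import torch
import torch.nn.functional as F

B, T, H, spk_embed_dim = 2, 6, 4, 3
hpad_pre_spk = torch.randn(B, T, H)
spembs = torch.randn(B, spk_embed_dim)
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hpad_pre_spk.size(1), -1)
hpad = torch.cat([hpad_pre_spk, spembs], dim=-1)
print(hpad.shape)  # torch.Size([2, 6, 7])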