def forward(self, state, x):
    # update state with input label x
    if state is None:  # make initial states and cumulative probability vector
        self.var_word_eos = to_cuda(self, self.var_word_eos)
        self.var_word_unk = to_cuda(self, self.var_word_unk)
        wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
        cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
        new_node = self.lexroot
        xi = self.space
    else:
        wlm_state, cumsum_probs, node = state
        xi = int(x)
        if xi == self.space:  # inter-word transition
            if node is not None and node[1] >= 0:  # check if the node is a word end
                w = to_cuda(self, torch.LongTensor([node[1]]))
            else:  # this node is not a word end, which means <unk>
                w = self.var_word_unk
            # update wordlm state and cumulative probability vector
            wlm_state, z_wlm = self.wordlm(wlm_state, w)
            cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
            new_node = self.lexroot  # move to the tree root
        elif node is not None and xi in node[0]:  # intra-word transition
            new_node = node[0][xi]
        elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
            new_node = None
        else:  # if the open_vocab flag is disabled, return zero probabilities
            log_y = torch.full((1, self.subword_dict_size), self.logzero)
            return (wlm_state, None, None), log_y

    if new_node is not None:
        succ, wid, wids = new_node
        # compute parent node probability
        sum_prob = (cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]]) if wids is not None else 1.0
        if sum_prob < self.zero:
            log_y = torch.full((1, self.subword_dict_size), self.logzero)
            return (wlm_state, cumsum_probs, new_node), log_y
        # set <unk> probability as a default value
        unk_prob = cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
        y = torch.full((1, self.subword_dict_size), float(unk_prob) * self.oov_penalty)
        # compute transition probabilities to child nodes
        for cid, nd in succ.items():
            y[:, cid] = (cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]) / sum_prob
        # apply word-level probabilities for <space> and <eos> labels
        if wid >= 0:
            wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
            y[:, self.space] = wlm_prob
            y[:, self.eos] = wlm_prob
        elif xi == self.space:
            y[:, self.space] = self.zero
            y[:, self.eos] = self.zero
        log_y = torch.log(torch.max(y, self.zero_tensor))  # clip to avoid log(0)
    else:  # if no path in the tree, transition probability is one
        log_y = torch.zeros(1, self.subword_dict_size)
    return (wlm_state, cumsum_probs, new_node), log_y
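# --- Illustrative sketch (not part of the original class) --------------------
# The look-ahead LM above turns the word LM's softmax into a cumulative sum so
# that the probability mass covered by any lexicon-tree node, whose word ids
# form a contiguous range `wids`, is a single subtraction of two cumsum
# entries.  The helper and the id range below are hypothetical, but the
# arithmetic mirrors the `cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]]`
# expression used in forward().
import torch
import torch.nn.functional as F


def node_probability(cumsum_probs, wids):
    """Probability mass of the word-id range recorded in a lexicon-tree node."""
    return cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]]


if __name__ == "__main__":
    torch.manual_seed(0)
    z_wlm = torch.randn(1, 10)                    # fake word-LM logits over 10 words
    probs = F.softmax(z_wlm, dim=1)
    cumsum_probs = torch.cumsum(probs, dim=1)
    wids = (2, 6)                                 # hypothetical id range of one node
    # under this convention the subtraction equals the direct sum over ids 3..6
    assert torch.allclose(node_probability(cumsum_probs, wids),
                          probs[:, wids[0] + 1:wids[1] + 1].sum(dim=1))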
def forward(self, state, x):
    # update state with input label x
    if state is None:  # make initial states and log-prob vectors
        self.var_word_eos = to_cuda(self, self.var_word_eos)
        self.var_word_unk = to_cuda(self, self.var_word_unk)
        wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
        wlm_logprobs = F.log_softmax(z_wlm, dim=1)
        clm_state, z_clm = self.subwordlm(None, x)
        log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
        new_node = self.lexroot
        clm_logprob = 0.
        xi = self.space
    else:
        clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
        xi = int(x)
        if xi == self.space:  # inter-word transition
            if node is not None and node[1] >= 0:  # check if the node is a word end
                w = to_cuda(self, torch.LongTensor([node[1]]))
            else:  # this node is not a word end, which means <unk>
                w = self.var_word_unk
            # update wordlm state and log-prob vector
            wlm_state, z_wlm = self.wordlm(wlm_state, w)
            wlm_logprobs = F.log_softmax(z_wlm, dim=1)
            new_node = self.lexroot  # move to the tree root
            clm_logprob = 0.
        elif node is not None and xi in node[0]:  # intra-word transition
            new_node = node[0][xi]
            clm_logprob += log_y[0, xi]
        elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
            new_node = None
            clm_logprob += log_y[0, xi]
        else:  # if the open_vocab flag is disabled, return zero probabilities
            log_y = torch.full((1, self.subword_dict_size), self.logzero)
            return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.), log_y

        clm_state, z_clm = self.subwordlm(clm_state, x)
        log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight

    # apply word-level probabilities for <space> and <eos> labels
    if xi != self.space:
        if new_node is not None and new_node[1] >= 0:  # if the new node is a word end
            wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
        else:
            wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
        log_y[:, self.space] = wlm_logprob
        log_y[:, self.eos] = wlm_logprob
    else:
        log_y[:, self.space] = self.logzero
        log_y[:, self.eos] = self.logzero
    return (clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)), log_y
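# --- Illustrative sketch (hypothetical driver, not part of the original code) -
# Both LM wrappers above expose the same recurrent interface: calling the
# module as lm(state, x) runs forward() and returns (new_state, log_y), where
# state is None on the first call, x is a LongTensor holding one label id, and
# log_y (1, subword_dict_size) scores candidates for the *next* label.  A
# beam-search rescorer threads the returned state through successive calls;
# this sketch scores a whole subword-id sequence that way.
import torch


def score_sequence(lm, token_ids):
    """Sum of log-probabilities that `lm` assigns to token_ids[1:] given the prefix."""
    state = None
    x = torch.LongTensor([token_ids[0]])
    total = 0.0
    for next_id in token_ids[1:]:
        state, log_y = lm(state, x)       # log_y scores candidates for the next label
        total += float(log_y[0, next_id])
        x = torch.LongTensor([next_id])
    return total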
def forward(self, state, x):
    if state is None:
        state = {
            'c1': to_cuda(self, self.zero_state(x.size(0))),
            'h1': to_cuda(self, self.zero_state(x.size(0))),
            'c2': to_cuda(self, self.zero_state(x.size(0))),
            'h2': to_cuda(self, self.zero_state(x.size(0)))
        }
    h0 = self.embed(x)
    h1, c1 = self.l1(self.d0(h0), (state['h1'], state['c1']))
    h2, c2 = self.l2(self.d1(h1), (state['h2'], state['c2']))
    y = self.lo(self.d2(h2))
    state = {'c1': c1, 'h1': h1, 'c2': c2, 'h2': h2}
    return state, y
def get_asr_data(self, data, sort_by):
    # utt list of frame x dim
    xs = [d[1]['feat_asr'] for d in data]

    # remove 0-output-length utterances
    tids = [d[1]['output'][0]['tokenid'].split() for d in data]
    filtered_index = filter(lambda i: len(tids[i]) > 0, range(len(xs)))
    if sort_by == 'feat':
        sorted_index = sorted(filtered_index, key=lambda i: -len(xs[i]))
    elif sort_by == 'text':
        sorted_index = sorted(filtered_index, key=lambda i: -len(tids[i]))
    else:
        logging.error("Error: specify 'text' or 'feat' to sort")
        sys.exit()
    if len(sorted_index) != len(xs):
        logging.warning('Target sequences include empty tokenid (batch %d -> %d).' % (
            len(xs), len(sorted_index)))
    xs = [xs[i] for i in sorted_index]

    # utt list of olen
    texts = [np.fromiter(map(int, tids[i]), dtype=np.int64) for i in sorted_index]
    if torch_is_old:
        texts = [to_cuda(self, Variable(torch.from_numpy(y), volatile=not self.training))
                 for y in texts]
    else:
        texts = [to_cuda(self, torch.from_numpy(y)) for y in texts]

    # subsample frames
    xs = [xx[::self.subsample[0], :] for xx in xs]
    featlens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
    if torch_is_old:
        hs = [to_cuda(self, Variable(torch.from_numpy(xx), volatile=not self.training))
              for xx in xs]
    else:
        hs = [to_cuda(self, torch.from_numpy(xx)) for xx in xs]

    # 1. encoder
    feats = pad_list(hs, 0.0)
    return texts, feats, featlens
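# --- Illustrative sketch (standalone, hypothetical helper) -------------------
# get_asr_data() above skips frames with `xx[::self.subsample[0], :]` and then
# relies on pad_list() to zero-pad the variable-length feature matrices into a
# single (B, Tmax, D) batch.  This sketch reproduces that shape handling with
# plain numpy/torch so it can be checked in isolation; `subsample_and_pad` is
# not part of the original code.
import numpy as np
import torch


def subsample_and_pad(xs, factor, pad_value=0.0):
    """xs: list of (T_i, D) arrays -> (padded (B, Tmax, D) tensor, lengths)."""
    xs = [xx[::factor, :] for xx in xs]                       # frame subsampling
    lens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
    out = torch.full((len(xs), int(lens.max()), xs[0].shape[1]), pad_value)
    for i, xx in enumerate(xs):
        out[i, :xx.shape[0]] = torch.from_numpy(xx)
    return out, lens


# e.g. two fake utterances of 7 and 5 frames with 3-dim features, factor 2,
# give feats.shape == (2, 4, 3) and featlens == [4, 3]:
# feats, featlens = subsample_and_pad(
#     [np.random.randn(7, 3).astype(np.float32),
#      np.random.randn(5, 3).astype(np.float32)], 2)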
def forward(self, state, x):
    if state is None:
        c = [to_cuda(self, self.zero_state(x.size(0)))
             for n in six.moves.range(self.n_layers)]
        h = [to_cuda(self, self.zero_state(x.size(0)))
             for n in six.moves.range(self.n_layers)]
        state = {'c': c, 'h': h}

    h = [None] * self.n_layers
    c = [None] * self.n_layers
    emb = self.embed(x)
    h[0], c[0] = self.lstm[0](self.dropout[0](emb), (state['h'][0], state['c'][0]))
    for n in six.moves.range(1, self.n_layers):
        h[n], c[n] = self.lstm[n](self.dropout[n](h[n - 1]), (state['h'][n], state['c'][n]))
    y = self.lo(self.dropout[-1](h[-1]))
    state = {'c': c, 'h': h}
    return state, y
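# --- Illustrative sketch (hypothetical driver, not part of the original code) -
# The RNNLM forwards above keep their recurrence in an explicit state dict
# (per-layer 'h' and 'c' entries), so the caller owns the loop over time.
# Greedy continuation of a non-empty prefix could be driven as below; the
# greedy policy and the fixed step count are made up for the sketch.
import torch


def greedy_continue(lm, prefix_ids, n_steps):
    """Feed `prefix_ids` (non-empty), then extend greedily for `n_steps` labels."""
    state, y = None, None
    for t in prefix_ids:
        state, y = lm(state, torch.LongTensor([t]))
    out = []
    for _ in range(n_steps):
        next_id = int(y.argmax(dim=1))            # highest-scoring next label
        out.append(next_id)
        state, y = lm(state, torch.LongTensor([next_id]))
    return out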
def forward(self, xs, ilens, ys, labels, olens=None, spembs=None):
    """TACOTRON2 LOSS FORWARD CALCULATION

    :param torch.Tensor xs: batch of padded character ids (B, Tmax)
    :param list ilens: list of lengths of each input batch (B)
    :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
    :param torch.Tensor labels: batch of the sequences of stop token labels (B, Lmax)
    :param list olens: batch of the lengths of each target (B)
    :param torch.Tensor spembs: batch of speaker embedding vector (B, spk_embed_dim)
    :return: loss value
    :rtype: torch.Tensor
    """
    after_outs, before_outs, logits = self.model(xs, ilens, ys, spembs)
    if self.use_masking and olens is not None:
        # weight positive samples
        if self.bce_pos_weight != 1.0:
            weights = ys.new(*labels.size()).fill_(1)
            weights.masked_fill_(labels.eq(1), self.bce_pos_weight)
        else:
            weights = None
        # masking padded values
        mask = to_cuda(self, make_mask(olens, ys.size(2)))
        ys = ys.masked_select(mask)
        after_outs = after_outs.masked_select(mask)
        before_outs = before_outs.masked_select(mask)
        labels = labels.masked_select(mask[:, :, 0])
        logits = logits.masked_select(mask[:, :, 0])
        weights = weights.masked_select(mask[:, :, 0]) if weights is not None else None
        # calculate loss
        l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
        mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
        bce_loss = F.binary_cross_entropy_with_logits(logits, labels, weights)
        loss = l1_loss + mse_loss + bce_loss
    else:
        # calculate loss
        l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
        mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
        bce_loss = F.binary_cross_entropy_with_logits(logits, labels)
        loss = l1_loss + mse_loss + bce_loss

    # report loss values for logging
    logging.debug("loss = %.3e (bce: %.3e, l1: %.3e, mse: %.3e)" % (
        loss.item(), bce_loss.item(), l1_loss.item(), mse_loss.item()))
    self.reporter.report(l1_loss.item(), mse_loss.item(), bce_loss.item(), loss.item())

    return loss
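# --- Illustrative sketch (standalone, not the original class) ----------------
# The stop-token BCE above combines two steps: per-element weights that
# up-weight the rare positive (stop) frames by bce_pos_weight, and
# masked_select with a non-pad mask so padded frames contribute nothing.
# Reproduced here on explicit (B, Lmax) tensors; the helper name is made up.
import torch
import torch.nn.functional as F


def masked_weighted_bce(logits, labels, non_pad_mask, bce_pos_weight=1.0):
    """logits/labels/non_pad_mask: (B, Lmax); mask is True on real frames."""
    if bce_pos_weight != 1.0:
        weights = labels.new_ones(labels.size())
        weights.masked_fill_(labels.eq(1), bce_pos_weight)
        weights = weights.masked_select(non_pad_mask)
    else:
        weights = None
    logits = logits.masked_select(non_pad_mask)
    labels = labels.masked_select(non_pad_mask)
    return F.binary_cross_entropy_with_logits(logits, labels, weights)


# e.g. B=2, Lmax=3; the second utterance has only 2 real frames
# loss = masked_weighted_bce(torch.randn(2, 3),
#                            torch.tensor([[0., 0., 1.], [0., 1., 0.]]),
#                            torch.tensor([[True, True, True],
#                                          [True, True, False]]),
#                            bce_pos_weight=5.0)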
def forward(self, xs, ilens, ys, labels, olens, spembs=None, spcs=None):
    """TACOTRON2 LOSS FORWARD CALCULATION

    :param torch.Tensor xs: batch of padded character ids (B, Tmax)
    :param list ilens: list of lengths of each input batch (B)
    :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
    :param torch.Tensor labels: batch of the sequences of stop token labels (B, Lmax)
    :param list olens: batch of the lengths of each target (B)
    :param torch.Tensor spembs: batch of speaker embedding vector (B, spk_embed_dim)
    :param torch.Tensor spcs: batch of padded target features (B, Lmax, spc_dim)
    :return: loss value
    :rtype: torch.Tensor
    """
    # calculate outputs
    if self.use_cbhg:
        cbhg_outs, after_outs, before_outs, logits = self.model(xs, ilens, ys, olens, spembs)
    else:
        after_outs, before_outs, logits = self.model(xs, ilens, ys, olens, spembs)

    # remove mod part
    if self.reduction_factor > 1:
        olens = [olen - olen % self.reduction_factor for olen in olens]
        ys = ys[:, :max(olens)]
        labels = labels[:, :max(olens)]
        spcs = spcs[:, :max(olens)] if spcs is not None else None

    # prepare weight of positive samples in cross entropy
    if self.bce_pos_weight != 1.0:
        weights = ys.new(*labels.size()).fill_(1)
        weights.masked_fill_(labels.eq(1), self.bce_pos_weight)
    else:
        weights = None

    # perform masking for padded values
    if self.use_masking:
        mask = to_cuda(self, make_non_pad_mask(olens).unsqueeze(-1))
        ys = ys.masked_select(mask)
        after_outs = after_outs.masked_select(mask)
        before_outs = before_outs.masked_select(mask)
        labels = labels.masked_select(mask[:, :, 0])
        logits = logits.masked_select(mask[:, :, 0])
        weights = weights.masked_select(mask[:, :, 0]) if weights is not None else None
        if self.use_cbhg:
            spcs = spcs.masked_select(mask)
            cbhg_outs = cbhg_outs.masked_select(mask)

    # calculate loss
    l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
    mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
    bce_loss = F.binary_cross_entropy_with_logits(logits, labels, weights)

    if self.use_cbhg:
        # calculate cbhg loss and then integrate them
        cbhg_l1_loss = F.l1_loss(cbhg_outs, spcs)
        cbhg_mse_loss = F.mse_loss(cbhg_outs, spcs)
        loss = l1_loss + mse_loss + bce_loss + cbhg_l1_loss + cbhg_mse_loss
        # report loss values for logging
        self.reporter.report([
            {'l1_loss': l1_loss.item()},
            {'mse_loss': mse_loss.item()},
            {'bce_loss': bce_loss.item()},
            {'cbhg_l1_loss': cbhg_l1_loss.item()},
            {'cbhg_mse_loss': cbhg_mse_loss.item()},
            {'loss': loss.item()}])
    else:
        # integrate loss
        loss = l1_loss + mse_loss + bce_loss
        # report loss values for logging
        self.reporter.report([
            {'l1_loss': l1_loss.item()},
            {'mse_loss': mse_loss.item()},
            {'bce_loss': bce_loss.item()},
            {'loss': loss.item()}])

    return loss
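# --- Illustrative sketch (hypothetical helper) --------------------------------
# make_non_pad_mask(olens) above is assumed to return a (B, Lmax) boolean mask
# that is True on real frames and False on padding; unsqueeze(-1) then lets it
# broadcast over the feature dimension for masked_select.  A minimal
# equivalent built directly from the length list:
import torch


def non_pad_mask(olens, max_len=None):
    """olens: list or 1-D tensor of true lengths -> (B, Lmax) bool mask."""
    olens = torch.as_tensor(olens, dtype=torch.long)
    max_len = int(olens.max()) if max_len is None else max_len
    # position index < length  ->  True on non-padded frames
    return torch.arange(max_len).unsqueeze(0) < olens.unsqueeze(1)


# e.g. non_pad_mask([3, 1]) ->
# tensor([[ True,  True,  True],
#         [ True, False, False]])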
def get_tts_data(self, data, sort_by, use_speaker_embedding=None):
    # get eos
    eos = str(int(data[0][1]['output'][0]['shape'][1]) - 1)

    # get target features and input character sequences
    texts = [b[1]['output'][0]['tokenid'].split() + [eos] for b in data]
    feats = [b[1]['feat_tts'] for b in data]

    # remove empty sequences and sort by length
    filtered_idx = filter(lambda i: len(texts[i]) > 0, range(len(feats)))
    if sort_by == 'feat':
        sorted_idx = sorted(filtered_idx, key=lambda i: -len(feats[i]))
    elif sort_by == 'text':
        sorted_idx = sorted(filtered_idx, key=lambda i: -len(texts[i]))
    else:
        logging.error("Error: specify 'text' or 'feat' to sort")
        sys.exit()
    texts = [np.fromiter(map(int, texts[i]), dtype=np.int64) for i in sorted_idx]
    feats = [feats[i] for i in sorted_idx]

    # get list of lengths (must be tensor for DataParallel)
    textlens = torch.from_numpy(np.fromiter((x.shape[0] for x in texts), dtype=np.int64))
    featlens = torch.from_numpy(np.fromiter((y.shape[0] for y in feats), dtype=np.int64))

    # perform padding and convert to tensor
    texts = torch.from_numpy(pad_ndarray_list(texts, 0)).long()
    feats = torch.from_numpy(pad_ndarray_list(feats, 0)).float()

    # make labels for stop prediction
    labels = feats.new(feats.size(0), feats.size(1)).zero_()
    for i, l in enumerate(featlens):
        labels[i, l - 1:] = 1

    if torch_is_old:
        texts = to_cuda(self, texts, volatile=not self.training)
        feats = to_cuda(self, feats, volatile=not self.training)
        labels = to_cuda(self, labels, volatile=not self.training)
    else:
        texts = to_cuda(self, texts)
        feats = to_cuda(self, feats)
        labels = to_cuda(self, labels)

    # load speaker embeddings
    if use_speaker_embedding is not None:
        spembs = [b[1]['feat_spembs'] for b in data]
        spembs = [spembs[i] for i in sorted_idx]
        spembs = torch.from_numpy(np.array(spembs)).float()
        if torch_is_old:
            spembs = to_cuda(self, spembs, volatile=not self.training)
        else:
            spembs = to_cuda(self, spembs)
    else:
        spembs = None

    if self.return_targets:
        return texts, textlens, feats, labels, featlens, spembs
    else:
        return texts, textlens, feats, spembs
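# --- Illustrative sketch (standalone, hypothetical helper) --------------------
# The stop-prediction targets built in get_tts_data() mark the final real frame
# of each utterance, and everything after it (the padding), with 1, so the
# decoder learns to emit "stop" at the end of the target feature sequence.
import torch


def make_stop_labels(featlens, max_len):
    """featlens: list of true output lengths -> (B, max_len) float 0/1 labels."""
    labels = torch.zeros(len(featlens), max_len)
    for i, l in enumerate(featlens):
        labels[i, l - 1:] = 1.0               # last real frame onwards is "stop"
    return labels


# e.g. make_stop_labels([3, 2], 4) ->
# tensor([[0., 0., 1., 1.],
#         [0., 1., 1., 1.]])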
def forward(self, data, return_hidden=False, return_inout=False):
    asr_texts, asr_feats, asr_featlens = get_asr_data(self, data, 'feat')
    tts_texts, tts_textlens, tts_feats, tts_labels, tts_featlens, spembs = \
        get_tts_data(self, data, 'feat', self.use_speaker_embedding)

    # encoder
    hpad_pre_spk, hlens, feat_input, feat_len = self.asr_enc(asr_feats, asr_featlens, True)
    if self.use_speaker_embedding is not None:
        spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hpad_pre_spk.size(1), -1)
        hpad = torch.cat([hpad_pre_spk, spembs], dim=-1)
    else:
        hpad = hpad_pre_spk

    after_outs, before_outs_, logits, att_ws = self.tts_dec(hpad, hlens.tolist(), tts_feats)

    # copied from e2e_tts_th.py
    if self.use_masking and tts_featlens is not None:
        # weight positive samples
        if self.bce_pos_weight != 1.0:
            # TODO(kan-bayashi): need to be fixed in pytorch v4
            weights = tts_feats.data.new(*tts_labels.size()).fill_(1)
            if torch_is_old:
                weights = Variable(weights, volatile=tts_feats.volatile)
            weights.masked_fill_(tts_labels.eq(1), self.bce_pos_weight)
        else:
            weights = None
        # masking padded values
        mask = to_cuda(self, make_mask(tts_featlens, tts_feats.size(2)))
        feats = tts_feats.masked_select(mask)
        after_outs = after_outs.masked_select(mask)
        before_outs = before_outs_.masked_select(mask)
        labels = tts_labels.masked_select(mask[:, :, 0])
        logits = logits.masked_select(mask[:, :, 0])
        weights = weights.masked_select(mask[:, :, 0]) if weights is not None else None
        # calculate loss
        l1_loss = F.l1_loss(after_outs, feats) + F.l1_loss(before_outs, feats)
        mse_loss = F.mse_loss(after_outs, feats) + F.mse_loss(before_outs, feats)
        bce_loss = F.binary_cross_entropy_with_logits(logits, labels, weights)
        loss = l1_loss + mse_loss + bce_loss
    else:
        # calculate loss
        l1_loss = F.l1_loss(after_outs, tts_feats) + F.l1_loss(before_outs_, tts_feats)
        mse_loss = F.mse_loss(after_outs, tts_feats) + F.mse_loss(before_outs_, tts_feats)
        bce_loss = F.binary_cross_entropy_with_logits(logits, tts_labels)
        loss = l1_loss + mse_loss + bce_loss

    # report loss values for logging
    loss_data = loss.data[0] if torch_is_old else loss.item()
    l1_loss_data = l1_loss.data[0] if torch_is_old else l1_loss.item()
    bce_loss_data = bce_loss.data[0] if torch_is_old else bce_loss.item()
    mse_loss_data = mse_loss.data[0] if torch_is_old else mse_loss.item()
    logging.debug("loss = %.3e (bce: %.3e, l1: %.3e, mse: %.3e)" % (
        loss_data, bce_loss_data, l1_loss_data, mse_loss_data))

    if return_inout:
        return loss, feat_input, feat_len, before_outs_, tts_featlens
    if return_hidden:
        return loss, hpad_pre_spk, hlens
    return loss
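# --- Illustrative sketch (standalone) -----------------------------------------
# The speaker-embedding injection above L2-normalizes each (spk_embed_dim,)
# vector, repeats it across the encoder time axis, and concatenates it to every
# encoder frame.  The toy shapes below are made up; only the normalize /
# unsqueeze / expand / cat pattern is taken from the code.
import torch
import torch.nn.functional as F

B, T, H, S = 2, 5, 8, 4                     # batch, frames, encoder dim, spk-emb dim
hpad = torch.randn(B, T, H)                 # encoder outputs
spembs = torch.randn(B, S)                  # one speaker embedding per utterance
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hpad.size(1), -1)  # (B, T, S)
hpad_spk = torch.cat([hpad, spembs], dim=-1)                            # (B, T, H + S)
assert hpad_spk.shape == (B, T, H + S)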