def _target_mask(self, olens):
    """Build the mask used by the decoder's masked self-attention.

    The result is the element-wise AND of a non-padding mask over the key
    axis (from ``make_non_pad_mask``) and the mask returned by
    ``subsequent_mask``, broadcast over the batch.

    Args:
        olens (LongTensor or List): Batch of lengths (B,).

    Returns:
        Tensor: Mask tensor for masked self-attention.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        >>> olens = [5, 3]
        >>> self._target_mask(olens)
        tensor([[[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1]],
                [[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0]]], dtype=torch.uint8)

    """
    # Non-padding mask, placed on the same device as the model parameters.
    non_pad = make_non_pad_mask(olens)
    model_device = next(self.parameters()).device
    non_pad = non_pad.to(model_device)

    # Step mask sized to the longest sequence, with a leading batch axis.
    max_len = non_pad.size(-1)
    step_mask = subsequent_mask(max_len, device=non_pad.device).unsqueeze(0)

    # Broadcast the padding mask over the query axis and combine.
    key_mask = non_pad.unsqueeze(-2)
    return key_mask & step_mask
def inference(self, x, inference_args, *args, **kwargs):
    """Generate the sequence of features from a sequence of characters.

    The decoder is run step by step, feeding the last generated frame back
    as the next input, until a stop probability reaches the threshold or the
    maximum length is hit (and the minimum length has been satisfied).

    :param torch.Tensor x: the sequence of character ids (T)
    :param Namespace inference_args: arguments containing following attributes
        (float) threshold: threshold in inference
        (float) minlenratio: minimum length ratio in inference
        (float) maxlenratio: maximum length ratio in inference
    :return: the sequence of features (L, odim)
    :rtype: torch.Tensor
    :return: the sequence of stop probabilities (L)
    :rtype: torch.Tensor
    """
    # get options
    threshold = inference_args.threshold
    minlenratio = inference_args.minlenratio
    maxlenratio = inference_args.maxlenratio

    # forward encoder
    xs = x.unsqueeze(0)
    hs, _ = self.encoder(xs, None)

    # set limits of length
    maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor)
    minlen = int(hs.size(1) * minlenratio / self.reduction_factor)

    # initialize
    idx = 0
    ys = hs.new_zeros(1, 1, self.odim)
    outs, probs = [], []

    # forward decoder step-by-step
    while True:
        # update index
        idx += 1

        # calculate output and stop prob at idx-th step
        # NOTE: the mask must live on the same device as the inputs,
        # otherwise decoding fails when the model is not on CPU.
        y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device)
        z = self.decoder.recognize(ys, y_masks, hs)  # (B, adim)
        outs += [self.feat_out(z).view(self.reduction_factor, self.odim)]  # [(r, odim), ...]
        probs += [torch.sigmoid(self.prob_out(z))[0]]  # [(r), ...]

        # update next inputs: only the last frame of this step is fed back
        ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)), dim=1)  # (1, idx + 1, odim)

        # check whether to finish generation
        if bool((probs[-1] >= threshold).any()) or idx >= maxlen:
            # check minimum length
            if idx < minlen:
                continue
            outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)  # (L, odim) -> (1, L, odim) -> (1, odim, L)
            if self.postnet is not None:
                outs = outs + self.postnet(outs)  # (1, odim, L)
            outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
            probs = torch.cat(probs, dim=0)
            break

    return outs, probs
def _target_mask(self, olens):
    """Build the MaskedMultiHeadedAttention mask for padded sequences.

    Three masks are AND-ed together: the non-padding mask over the key
    axis, the mask returned by ``subsequent_mask``, and the non-padding
    mask over the query axis, so rows belonging to padded steps are all
    zero.

    >>> olens = [5, 3]
    >>> self._target_mask(olens)
    tensor([[[1, 0, 0, 0, 0],
             [1, 1, 0, 0, 0],
             [1, 1, 1, 0, 0],
             [1, 1, 1, 1, 0],
             [1, 1, 1, 1, 1]],
            [[1, 0, 0, 0, 0],
             [1, 1, 0, 0, 0],
             [1, 1, 1, 0, 0],
             [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]]], dtype=torch.uint8)
    """
    # Non-padding mask, placed on the same device as the model parameters.
    non_pad = make_non_pad_mask(olens)
    model_device = next(self.parameters()).device
    non_pad = non_pad.to(model_device)

    # Step mask sized to the longest sequence, with a leading batch axis.
    max_len = non_pad.size(-1)
    step_mask = subsequent_mask(max_len, device=non_pad.device).unsqueeze(0)

    # Mask padded keys (columns) and padded queries (rows).
    key_mask = non_pad.unsqueeze(-2)
    query_mask = non_pad.unsqueeze(-1)
    return key_mask & step_mask & query_mask
def inference(self, x, inference_args, spemb=None, *args, **kwargs):
    """Generate the sequence of features given the sequences of characters.

    The decoder is run one step at a time with a cache of previous states;
    the last generated frame is fed back as the next input until a stop
    probability reaches the threshold or the maximum length is hit (and the
    minimum length has been satisfied).

    Args:
        x (Tensor): Input sequence of characters (T,).
        inference_args (Namespace):
            - threshold (float): Threshold in inference.
            - minlenratio (float): Minimum length ratio in inference.
            - maxlenratio (float): Maximum length ratio in inference.
            - use_att_constraint (bool, optional): Accepted for
              compatibility only; a warning is logged and it is ignored.
        spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
            Required when ``self.spk_embed_dim`` is not None.

    Returns:
        Tensor: Output sequence of features (L, odim).
        Tensor: Output sequence of stop probabilities (L,).
        Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T).

    """
    # get options
    threshold = inference_args.threshold
    minlenratio = inference_args.minlenratio
    maxlenratio = inference_args.maxlenratio
    use_att_constraint = getattr(
        inference_args, "use_att_constraint", False
    )  # keep compatibility
    if use_att_constraint:
        logging.warning(
            "Attention constraint is not yet supported in Transformer. Not enabled."
        )

    # forward encoder
    xs = x.unsqueeze(0)
    hs, _ = self.encoder(xs, None)

    # integrate speaker embedding
    if self.spk_embed_dim is not None:
        spembs = spemb.unsqueeze(0)
        hs = self._integrate_with_spk_embed(hs, spembs)

    # set limits of length
    maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor)
    minlen = int(hs.size(1) * minlenratio / self.reduction_factor)

    # initialize
    idx = 0
    ys = hs.new_zeros(1, 1, self.odim)
    outs, probs = [], []

    # forward decoder step-by-step
    z_cache = self.decoder.init_state(x)
    while True:
        # update index
        idx += 1

        # calculate output and stop prob at idx-th step
        y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device)
        z, z_cache = self.decoder.forward_one_step(
            ys, y_masks, hs, cache=z_cache
        )  # (B, adim)
        outs += [self.feat_out(z).view(self.reduction_factor, self.odim)]  # [(r, odim), ...]
        probs += [torch.sigmoid(self.prob_out(z))[0]]  # [(r), ...]

        # update next inputs: only the last frame of this step is fed back
        ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)), dim=1)  # (1, idx + 1, odim)

        # get attention weights of the newest step from every source-attention module
        att_ws_ = [
            m.attn[0, :, -1].unsqueeze(1)
            for name, m in self.named_modules()
            if isinstance(m, MultiHeadedAttention) and "src" in name
        ]  # [(#heads, 1, T), ...]
        if idx == 1:
            att_ws = att_ws_
        else:
            # append the new step along the output-length axis
            # [(#heads, l, T), ...]
            att_ws = [
                torch.cat([att_w, att_w_], dim=1)
                for att_w, att_w_ in zip(att_ws, att_ws_)
            ]

        # check whether to finish generation
        if bool((probs[-1] >= threshold).any()) or idx >= maxlen:
            # check minimum length
            if idx < minlen:
                continue
            outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)  # (L, odim) -> (1, L, odim) -> (1, odim, L)
            if self.postnet is not None:
                outs = outs + self.postnet(outs)  # (1, odim, L)
            outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
            probs = torch.cat(probs, dim=0)
            break

    # concatenate attention weights -> (#layers, #heads, L, T)
    att_ws = torch.stack(att_ws, dim=0)

    return outs, probs, att_ws
def inference(self, x, inference_args, spemb=None, *args, **kwargs):
    """Generate the sequence of features given the sequences of characters.

    Runs the encoder once, then decodes step by step, feeding the last
    generated frame back as the next decoder input. Decoding ends once a
    stop probability reaches the threshold or the maximum length is hit,
    provided the minimum length has been reached.

    Args:
        x (Tensor): Input sequence of characters (T,).
        inference_args (Namespace):
            - threshold (float): Threshold in inference.
            - minlenratio (float): Minimum length ratio in inference.
            - maxlenratio (float): Maximum length ratio in inference.
        spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

    Returns:
        Tensor: Output sequence of features (L, odim).
        Tensor: Output sequence of stop probabilities (L,).
        Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T).

    """
    # Decoding options.
    stop_threshold = inference_args.threshold
    min_ratio = inference_args.minlenratio
    max_ratio = inference_args.maxlenratio

    # Encode the input with a leading batch axis.
    enc_out, _ = self.encoder(x.unsqueeze(0), None)

    # Optionally fold the speaker embedding into the encoder output.
    if self.spk_embed_dim is not None:
        enc_out = self._integrate_with_spk_embed(enc_out, spemb.unsqueeze(0))

    # Length limits derived from the encoder output length.
    max_steps = int(enc_out.size(1) * max_ratio / self.reduction_factor)
    min_steps = int(enc_out.size(1) * min_ratio / self.reduction_factor)

    # Decoder state: start from a single all-zero frame.
    step = 0
    dec_in = enc_out.new_zeros(1, 1, self.odim)
    frames, stop_probs = [], []

    while True:
        step += 1

        # Output frames and stop probabilities for this step.
        mask = subsequent_mask(step).unsqueeze(0).to(x.device)
        dec_out = self.decoder.recognize(dec_in, mask, enc_out)  # (B, adim)
        frames.append(
            self.feat_out(dec_out).view(self.reduction_factor, self.odim)
        )  # (r, odim)
        stop_probs.append(torch.sigmoid(self.prob_out(dec_out))[0])  # (r)

        # Feed the last frame of this step back as the next input.
        last_frame = frames[-1][-1].view(1, 1, self.odim)
        dec_in = torch.cat((dec_in, last_frame), dim=1)  # (1, step + 1, odim)

        # Stop when triggered by threshold or max length, once min length is met.
        triggered = int(sum(stop_probs[-1] >= stop_threshold)) > 0 or step >= max_steps
        if triggered and not step < min_steps:
            break

    # (L, odim) -> (1, L, odim) -> (1, odim, L)
    outs = torch.cat(frames, dim=0).unsqueeze(0).transpose(1, 2)
    if self.postnet is not None:
        outs = outs + self.postnet(outs)  # (1, odim, L)
    outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
    probs = torch.cat(stop_probs, dim=0)

    # Collect the attention maps held by every source-attention module.
    src_att = [
        module.attn
        for mod_name, module in self.named_modules()
        if isinstance(module, MultiHeadedAttention) and "src" in mod_name
    ]
    att_ws = torch.cat(src_att, dim=0)

    return outs, probs, att_ws