class E2E(ASRInterface, torch.nn.Module):
    @staticmethod
    def add_arguments(parser):
        group = parser.add_argument_group("transformer model setting")
        group.add_argument("--transformer-init", type=str, default="pytorch",
                           choices=["pytorch", "xavier_uniform", "xavier_normal",
                                    "kaiming_uniform", "kaiming_normal"],
                           help="How to initialize transformer parameters")
        group.add_argument("--transformer-input-layer", type=str, default="conv2d",
                           choices=["conv2d", "linear", "embed"],
                           help="Transformer input layer type")
        group.add_argument("--transformer-attn-dropout-rate", default=None, type=float,
                           help="Dropout rate in transformer attention. "
                                "Uses --dropout-rate if None is set")
        group.add_argument("--transformer-lr", default=10.0, type=float,
                           help="Initial value of learning rate")
        group.add_argument("--transformer-warmup-steps", default=25000, type=int,
                           help="Optimizer warmup steps")
        group.add_argument("--transformer-length-normalized-loss", default=True, type=strtobool,
                           help="Normalize loss by length")
        group.add_argument("--transformer-encoder-center-chunk-len", type=int, default=8,
                           help="Transformer chunk size, only used when "
                                "transformer-encoder-type is memory")
        group.add_argument("--transformer-encoder-left-chunk-len", type=int, default=0,
                           help="Transformer left chunk size")
        group.add_argument("--transformer-encoder-hop-len", type=int, default=0,
                           help="Transformer encoder hop length")
        group.add_argument("--transformer-encoder-right-chunk-len", type=int, default=0,
                           help="Future context of the encoder")
        group.add_argument("--transformer-encoder-abs-embed", type=int, default=1,
                           help="Whether the network uses absolute positional embedding")
        group.add_argument("--transformer-encoder-rel-embed", type=int, default=0,
                           help="Whether the network uses relative positional embedding")
        group.add_argument("--transformer-encoder-use-memory", type=int, default=0,
                           help="Whether the network uses memory to store history")
        return parser

    @property
    def attention_plot_class(self):
        return PlotAttentionReport

    def __init__(self, idim, odim, args, ignore_id=-1):
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            center_len=args.transformer_encoder_center_chunk_len,
            left_len=args.transformer_encoder_left_chunk_len,
            hop_len=args.transformer_encoder_hop_len,
            right_len=args.transformer_encoder_right_chunk_len,
            abs_pos=args.transformer_encoder_abs_embed,
            rel_pos=args.transformer_encoder_rel_embed,
            use_mem=args.transformer_encoder_use_memory,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = [1]
        self.reporter = Reporter()
        # self.lsm_weight = a
        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, args.lsm_weight,
                                            args.transformer_length_normalized_loss)
        # self.verbose = args.verbose
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim, args.adim, args.dropout_rate,
                           ctc_type=args.ctc_type, reduce=True)
        else:
            self.ctc = None
        if args.report_cer or args.report_wer or args.mtlalpha > 0.0:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(args.char_list, args.sym_space,
                                                    args.sym_blank,
                                                    args.report_cer, args.report_wer)
        else:
            self.error_calculator = None
        self.rnnlm = None

    def reset_parameters(self, args):
        if args.transformer_init == "pytorch":
            return
        # weight init
        for p in self.parameters():
            if p.dim() > 1:
                if args.transformer_init == "xavier_uniform":
                    torch.nn.init.xavier_uniform_(p.data)
                elif args.transformer_init == "xavier_normal":
                    torch.nn.init.xavier_normal_(p.data)
                elif args.transformer_init == "kaiming_uniform":
                    torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
                elif args.transformer_init == "kaiming_normal":
                    torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
                else:
                    raise ValueError("Unknown initialization: " + args.transformer_init)
        # bias init
        for p in self.parameters():
            if p.dim() == 1:
                p.data.zero_()
        # reset some modules with default init
        for m in self.modules():
            if isinstance(m, (torch.nn.Embedding, LayerNorm)):
                m.reset_parameters()

    def add_sos_eos(self, ys_pad):
        from espnet.nets.pytorch_backend.nets_utils import pad_list
        eos = ys_pad.new([self.eos])
        sos = ys_pad.new([self.sos])
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]
        return pad_list(ys_in, self.eos), pad_list(ys_out, self.ignore_id)

    def target_mask(self, ys_in_pad):
        ys_mask = ys_in_pad != self.ignore_id
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: CTC loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # forward decoder
        ys_in_pad, ys_out_pad = self.add_sos_eos(ys_pad)
        ys_mask = self.target_mask(ys_in_pad)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # compute loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                               ignore_label=self.ignore_id)

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        if self.mtlalpha == 0.0:
            loss_ctc = None
            cer_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss

    def recognize(self, feat, recog_args, char_list=None, rnnlm=None, use_jit=False):
        """Recognize feat

        :param ndarray feat: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argument namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list

        TODO(karita): do not recompute previous attention for faster decoding
        """
        self.eval()
        feat = torch.as_tensor(feat).unsqueeze(0)
        enc_output, _ = self.encoder(feat, None)

        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)
        logging.info('input lengths: ' + str(h.size(0)))

        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        if lpz is not None:
            import numpy
            from espnet.nets.ctc_prefix_score import CTCPrefixScore
            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        import six
        traced_decoder = None
        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy.unsqueeze(1)
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp['yseq']).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(self.decoder.recognize,
                                                         (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)
                else:
                    local_att_scores = self.decoder.recognize(ys, ys_mask, enc_output)

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores
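                # NOTE: joint CTC/attention scoring below: the attention scores
                # pre-select `ctc_beam` candidates, the CTC prefix scorer
                # rescores only those candidates, and the two scores are
                # interpolated with `ctc_weight` (plus an optional RNNLM term)
                # before the final top-`beam` selection.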
                if lpz is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                        + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last position in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypotheses to a final list, and remove them from current hypotheses
            # (this will be a problem when the number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            from espnet.nets.e2e_asr_common import end_detect
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remained hypotheses: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'],
            reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check the number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning('there are no N-best results, performing recognition '
                            'again with a smaller minlenratio.')
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            return self.recognize(feat, recog_args, char_list, rnnlm)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' +
                     str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                ret[name] = m.attn.cpu().numpy()
        return ret
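
# Illustrative sketch (hypothetical helper, not part of the model code above):
# reproduces what `E2E.add_sos_eos` does on a toy batch using plain torch only
# (no `pad_list`), so the resulting shapes are easy to inspect. `ignore_id=-1`
# and `odim=5` are assumptions chosen for the example.
def _demo_add_sos_eos():
    import torch
    ignore_id, odim = -1, 5
    sos = eos = odim - 1  # the last index serves as both <sos> and <eos>
    ys_pad = torch.tensor([[1, 2, 3], [1, 2, ignore_id]])  # second sequence is shorter
    ys = [y[y != ignore_id] for y in ys_pad]               # strip padding
    ys_in = [torch.cat([y.new([sos]), y]) for y in ys]     # decoder inputs: <sos> + y
    ys_out = [torch.cat([y, y.new([eos])]) for y in ys]    # decoder targets: y + <eos>
    # pad_list would then pad ys_in with eos and ys_out with ignore_id, giving
    # ys_in  -> [[4, 1, 2, 3], [4, 1, 2, 4]]
    # ys_out -> [[1, 2, 3, 4], [1, 2, 4, -1]]
    return ys_in, ys_out
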
class Transformer(TTSInterface, torch.nn.Module):
    """Text-to-Speech Transformer module.

    This is a module of the text-to-speech Transformer described in
    `Neural Speech Synthesis with Transformer Network`_, which converts a sequence of
    characters or phonemes into a sequence of Mel-filterbanks.

    .. _`Neural Speech Synthesis with Transformer Network`:
        https://arxiv.org/pdf/1809.08895.pdf

    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("transformer model setting")
        # network structure related
        group.add_argument("--embed-dim", default=512, type=int,
                           help="Dimension of character embedding in encoder prenet")
        group.add_argument("--eprenet-conv-layers", default=3, type=int,
                           help="Number of encoder prenet convolution layers")
        group.add_argument("--eprenet-conv-chans", default=256, type=int,
                           help="Number of encoder prenet convolution channels")
        group.add_argument("--eprenet-conv-filts", default=5, type=int,
                           help="Filter size of encoder prenet convolution")
        group.add_argument("--dprenet-layers", default=2, type=int,
                           help="Number of decoder prenet layers")
        group.add_argument("--dprenet-units", default=256, type=int,
                           help="Number of decoder prenet hidden units")
        group.add_argument("--elayers", default=3, type=int,
                           help="Number of encoder layers")
        group.add_argument("--eunits", default=1536, type=int,
                           help="Number of encoder hidden units")
        group.add_argument("--adim", default=384, type=int,
                           help="Number of attention transformation dimensions")
        group.add_argument("--aheads", default=4, type=int,
                           help="Number of heads for multi head attention")
        group.add_argument("--dlayers", default=3, type=int,
                           help="Number of decoder layers")
        group.add_argument("--dunits", default=1536, type=int,
                           help="Number of decoder hidden units")
        group.add_argument("--postnet-layers", default=5, type=int,
                           help="Number of postnet layers")
        group.add_argument("--postnet-chans", default=256, type=int,
                           help="Number of postnet channels")
        group.add_argument("--postnet-filts", default=5, type=int,
                           help="Filter size of postnet")
        group.add_argument("--use-scaled-pos-enc", default=True, type=strtobool,
                           help="Use trainable scaled positional encoding "
                                "instead of the fixed scale one.")
        group.add_argument("--use-batch-norm", default=True, type=strtobool,
                           help="Whether to use batch normalization")
        group.add_argument("--encoder-normalize-before", default=False, type=strtobool,
                           help="Whether to apply layer norm before encoder block")
        group.add_argument("--decoder-normalize-before", default=False, type=strtobool,
                           help="Whether to apply layer norm before decoder block")
        group.add_argument("--encoder-concat-after", default=False, type=strtobool,
                           help="Whether to concatenate attention layer's input "
                                "and output in encoder")
        group.add_argument("--decoder-concat-after", default=False, type=strtobool,
                           help="Whether to concatenate attention layer's input "
                                "and output in decoder")
        group.add_argument("--reduction-factor", default=1, type=int,
                           help="Reduction factor")
        group.add_argument("--spk-embed-dim", default=None, type=int,
                           help="Number of speaker embedding dimensions")
        group.add_argument("--spk-embed-integration-type", type=str, default="add",
                           choices=["add", "concat"],
                           help="How to integrate speaker embedding")
        # training related
        group.add_argument("--transformer-init", type=str, default="pytorch",
                           choices=["pytorch", "xavier_uniform", "xavier_normal",
                                    "kaiming_uniform", "kaiming_normal"],
                           help="How to initialize transformer parameters")
        group.add_argument("--initial-encoder-alpha", type=float, default=1.0,
                           help="Initial alpha value in encoder's ScaledPositionalEncoding")
        group.add_argument("--initial-decoder-alpha", type=float, default=1.0,
                           help="Initial alpha value in decoder's ScaledPositionalEncoding")
        group.add_argument("--transformer-lr", default=1.0, type=float,
                           help="Initial value of learning rate")
        group.add_argument("--transformer-warmup-steps", default=4000, type=int,
                           help="Optimizer warmup steps")
        group.add_argument("--transformer-enc-dropout-rate", default=0.1, type=float,
                           help="Dropout rate for transformer encoder except for attention")
        group.add_argument("--transformer-enc-positional-dropout-rate", default=0.1, type=float,
                           help="Dropout rate for transformer encoder positional encoding")
        group.add_argument("--transformer-enc-attn-dropout-rate", default=0.1, type=float,
                           help="Dropout rate for transformer encoder self-attention")
        group.add_argument("--transformer-dec-dropout-rate", default=0.1, type=float,
                           help="Dropout rate for transformer decoder "
                                "except for attention and pos encoding")
        group.add_argument("--transformer-dec-positional-dropout-rate", default=0.1, type=float,
                           help="Dropout rate for transformer decoder positional encoding")
        group.add_argument("--transformer-dec-attn-dropout-rate", default=0.1, type=float,
                           help="Dropout rate for transformer decoder self-attention")
        group.add_argument("--transformer-enc-dec-attn-dropout-rate", default=0.1, type=float,
                           help="Dropout rate for transformer encoder-decoder attention")
        group.add_argument("--eprenet-dropout-rate", default=0.5, type=float,
                           help="Dropout rate in encoder prenet")
        group.add_argument("--dprenet-dropout-rate", default=0.5, type=float,
                           help="Dropout rate in decoder prenet")
        group.add_argument("--postnet-dropout-rate", default=0.5, type=float,
                           help="Dropout rate in postnet")
        # loss related
        group.add_argument("--use-masking", default=True, type=strtobool,
                           help="Whether to use masking in calculation of loss")
        group.add_argument("--loss-type", default="L1", choices=["L1", "L2", "L1+L2"],
                           help="How to calc loss")
        group.add_argument("--bce-pos-weight", default=5.0, type=float,
                           help="Positive sample weight in BCE calculation "
                                "(only for use-masking=True)")
        group.add_argument("--use-guided-attn-loss", default=False, type=strtobool,
                           help="Whether to use guided attention loss")
        group.add_argument("--guided-attn-loss-sigma", default=0.4, type=float,
                           help="Sigma in guided attention loss")
        group.add_argument("--guided-attn-loss-lambda", default=1.0, type=float,
                           help="Lambda in guided attention loss")
        group.add_argument("--num-heads-applied-guided-attn", default=2, type=int,
                           help="Number of heads in each layer to be applied guided "
                                "attention loss. If set to -1, all of the heads will be applied.")
        group.add_argument("--num-layers-applied-guided-attn", default=2, type=int,
                           help="Number of layers to be applied guided attention loss. "
                                "If set to -1, all of the layers will be applied.")
        group.add_argument("--modules-applied-guided-attn", type=str, nargs="+",
                           default=["encoder-decoder"],
                           help="Module name list to be applied guided attention loss")
        return parser

    @property
    def attention_plot_class(self):
        """Return plot class for attention weight plot."""
        return TTSPlot

    def __init__(self, idim, odim, args=None):
        """Initialize TTS-Transformer module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - embed_dim (int): Dimension of character embedding.
                - eprenet_conv_layers (int): Number of encoder prenet convolution layers.
                - eprenet_conv_chans (int): Number of encoder prenet convolution channels.
                - eprenet_conv_filts (int): Filter size of encoder prenet convolution.
                - dprenet_layers (int): Number of decoder prenet layers.
                - dprenet_units (int): Number of decoder prenet hidden units.
                - elayers (int): Number of encoder layers.
                - eunits (int): Number of encoder hidden units.
                - adim (int): Number of attention transformation dimensions.
                - aheads (int): Number of heads for multi head attention.
                - dlayers (int): Number of decoder layers.
                - dunits (int): Number of decoder hidden units.
                - postnet_layers (int): Number of postnet layers.
                - postnet_chans (int): Number of postnet channels.
                - postnet_filts (int): Filter size of postnet.
                - use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding.
                - use_batch_norm (bool): Whether to use batch normalization in encoder prenet.
                - encoder_normalize_before (bool): Whether to perform layer normalization before encoder block.
                - decoder_normalize_before (bool): Whether to perform layer normalization before decoder block.
                - encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder.
                - decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spk_embed_integration_type: How to integrate speaker embedding.
                - transformer_init (float): How to initialize transformer parameters.
                - transformer_lr (float): Initial value of learning rate.
                - transformer_warmup_steps (int): Optimizer warmup steps.
                - transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding.
                - transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding.
                - transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module.
                - transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding.
                - transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding.
                - transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module.
                - transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-decoder attention module.
                - eprenet_dropout_rate (float): Dropout rate in encoder prenet.
                - dprenet_dropout_rate (float): Dropout rate in decoder prenet.
                - postnet_dropout_rate (float): Dropout rate in postnet.
                - use_masking (bool): Whether to use masking in calculation of loss.
                - bce_pos_weight (float): Positive sample weight in BCE calculation (only for use_masking=True).
                - loss_type (str): How to calculate loss.
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - num_heads_applied_guided_attn (int): Number of heads in each layer to apply guided attention loss.
                - num_layers_applied_guided_attn (int): Number of layers to apply guided attention loss.
                - modules_applied_guided_attn (list): List of module names to apply guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = args.spk_embed_integration_type
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.reduction_factor = args.reduction_factor
        self.loss_type = args.loss_type
        self.use_guided_attn_loss = args.use_guided_attn_loss
        if self.use_guided_attn_loss:
            if args.num_layers_applied_guided_attn == -1:
                self.num_layers_applied_guided_attn = args.elayers
            else:
                self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn
            if args.num_heads_applied_guided_attn == -1:
                self.num_heads_applied_guided_attn = args.aheads
            else:
                self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn
            self.modules_applied_guided_attn = args.modules_applied_guided_attn

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define transformer encoder
        if args.eprenet_conv_layers != 0:
            # encoder prenet
            encoder_input_layer = torch.nn.Sequential(
                EncoderPrenet(idim=idim,
                              embed_dim=args.embed_dim,
                              elayers=0,
                              econv_layers=args.eprenet_conv_layers,
                              econv_chans=args.eprenet_conv_chans,
                              econv_filts=args.eprenet_conv_filts,
                              use_batch_norm=args.use_batch_norm,
                              dropout_rate=args.eprenet_dropout_rate,
                              padding_idx=padding_idx),
                torch.nn.Linear(args.eprenet_conv_chans, args.adim))
        else:
            encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                     embedding_dim=args.adim,
                                                     padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after)

        # define projection layer
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim)
            else:
                self.projection = torch.nn.Linear(args.adim + self.spk_embed_dim, args.adim)

        # define transformer decoder
        if args.dprenet_layers != 0:
            # decoder prenet
            decoder_input_layer = torch.nn.Sequential(
                DecoderPrenet(idim=odim,
                              n_layers=args.dprenet_layers,
                              n_units=args.dprenet_units,
                              dropout_rate=args.dprenet_dropout_rate),
                torch.nn.Linear(args.dprenet_units, args.adim))
        else:
            decoder_input_layer = "linear"
        self.decoder = Decoder(
            odim=-1,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
            self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_enc_dec_attn_dropout_rate,
            input_layer=decoder_input_layer,
            use_output_layer=False,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after)

        # define final projection
        self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor)
        self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor)

        # define postnet
        self.postnet = None if args.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate)

        # define loss function
        self.criterion = TransformerLoss(use_masking=args.use_masking,
                                         bce_pos_weight=args.bce_pos_weight)
        if self.use_guided_attn_loss:
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )

        # initialize parameters
        self._reset_parameters(init_type=args.transformer_init,
                               init_enc_alpha=args.initial_encoder_alpha,
                               init_dec_alpha=args.initial_decoder_alpha)

    def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0):
        # initialize parameters
        initialize(self, init_type)
        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)

    def _add_first_frame_and_remove_last_frame(self, ys):
        ys_in = torch.cat(
            [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1)
        return ys_in

    def forward(self, xs, ilens, ys, labels, olens, spembs=None, asrtts=False, *args, **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (Tensor): Batch of padded stop labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
            asrtts (bool, optional): Whether to return unreduced losses and outputs
                for ASR-to-TTS training.

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_ilen = max(ilens)
        max_olen = max(olens)
        if max_ilen != xs.shape[1]:
            xs = xs[:, :max_ilen]
        if max_olen != ys.shape[1]:
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]

        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
        if self.reduction_factor > 1:
            ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
            olens_in = olens.new([olen // self.reduction_factor for olen in olens])
        else:
            ys_in, olens_in = ys, olens

        # add first zero frame and remove last frame for auto-regressive
        ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

        # forward decoder
        y_masks = self._target_mask(olens_in)
        xy_masks = self._source_to_target_mask(ilens, olens_in)
        zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)
        # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
        # (B, Lmax//r, r) -> (B, Lmax//r * r)
        logits = self.prob_out(zs).view(zs.size(0), -1)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose(1, 2)).transpose(1, 2)

        # debug visualization: plot the predicted and ground-truth Mel-filterbanks
        # of the first utterance in the batch (saving to file is disabled)
        import numpy as np
        name = np.random.choice(100, 1)
        import matplotlib.pyplot as plt
        fig = plt.figure()
        ax1 = plt.subplot(2, 1, 1)
        ax2 = plt.subplot(2, 1, 2)
        # for n in range(int(after_outs.size(0))):
        mel_fbank = np.transpose(after_outs[0].detach().cpu().numpy())
        gt_mel_fbank = np.transpose(ys[0].detach().cpu().numpy())
        ax1.imshow(mel_fbank, aspect='auto', cmap=plt.cm.jet)
        ax2.imshow(gt_mel_fbank, aspect='auto', cmap=plt.cm.jet)
        # fig.savefig('images_asr2tts_transformer/asr2tts.%d.%d.png' % (0, name),
        #             orientation='landscape')

        # modify mod part of ground-truth
        if self.reduction_factor > 1:
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate loss values
        if asrtts:
            l1_loss, l2_loss, bce_loss = self.criterion(after_outs, before_outs, logits,
                                                        ys, labels, olens, reduction=False)
            if self.loss_type == "L1":
                loss = l1_loss + bce_loss
            elif self.loss_type == "L2":
                loss = l2_loss + bce_loss
            elif self.loss_type == "L1+L2":
                loss = l1_loss + l2_loss + bce_loss
            else:
                raise ValueError("unknown --loss-type " + self.loss_type)
            report_keys = [
                {"l1_loss": l1_loss.mean().item()},
                {"l2_loss": l2_loss.mean().item()},
                {"bce_loss": bce_loss.mean().item()},
                {"loss": loss.mean().item()},
            ]
        else:
            l1_loss, l2_loss, bce_loss = self.criterion(after_outs, before_outs, logits,
                                                        ys, labels, olens)
            if self.loss_type == "L1":
                loss = l1_loss + bce_loss
            elif self.loss_type == "L2":
                loss = l2_loss + bce_loss
            elif self.loss_type == "L1+L2":
                loss = l1_loss + l2_loss + bce_loss
            else:
                raise ValueError("unknown --loss-type " + self.loss_type)
            report_keys = [
                {"l1_loss": l1_loss.item()},
                {"l2_loss": l2_loss.item()},
                {"bce_loss": bce_loss.item()},
                {"loss": loss.item()},
            ]
        # calculate guided attention loss
        if self.use_guided_attn_loss:
            # calculate for encoder
            if "encoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(reversed(range(len(self.encoder.encoders)))):
                    att_ws += [self.encoder.encoders[layer_idx].self_attn
                               .attn[:, :self.num_heads_applied_guided_attn]]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_in, T_in)
                enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
                loss = loss + enc_attn_loss
                report_keys += [{"enc_attn_loss": enc_attn_loss.item()}]
            # calculate for decoder
            if "decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(reversed(range(len(self.decoder.decoders)))):
                    att_ws += [self.decoder.decoders[layer_idx].self_attn
                               .attn[:, :self.num_heads_applied_guided_attn]]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_out, T_out)
                dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in)
                loss = loss + dec_attn_loss
                report_keys += [{"dec_attn_loss": dec_attn_loss.item()}]
            # calculate for encoder-decoder
            if "encoder-decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(reversed(range(len(self.decoder.decoders)))):
                    att_ws += [self.decoder.decoders[layer_idx].src_attn
                               .attn[:, :self.num_heads_applied_guided_attn]]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_out, T_in)
                enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens_in)
                loss = loss + enc_dec_attn_loss
                report_keys += [{"enc_dec_attn_loss": enc_dec_attn_loss.item()}]
        else:
            att_ws = None

        # report extra information
        if self.use_scaled_pos_enc:
            report_keys += [
                {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
                {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
            ]
        self.reporter.report(report_keys)

        if asrtts:
            return loss, after_outs, before_outs, logits, att_ws
        else:
            return loss

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace):
                - threshold (float): Threshold in inference.
                - minlenratio (float): Minimum length ratio in inference.
                - maxlenratio (float): Maximum length ratio in inference.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T).

        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio

        # forward encoder
        xs = x.unsqueeze(0)
        hs, _ = self.encoder(xs, None)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            spembs = spemb.unsqueeze(0)
            hs = self._integrate_with_spk_embed(hs, spembs)

        # set limits of length
        maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor)
        minlen = int(hs.size(1) * minlenratio / self.reduction_factor)

        # initialize
        idx = 0
        ys = hs.new_zeros(1, 1, self.odim)
        outs, probs = [], []

        # forward decoder step-by-step
        while True:
            # update index
            idx += 1

            # calculate output and stop prob at idx-th step
            y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device)
            z = self.decoder.recognize(ys, y_masks, hs)  # (B, adim)
            outs += [self.feat_out(z).view(self.reduction_factor, self.odim)]  # [(r, odim), ...]
            probs += [torch.sigmoid(self.prob_out(z))[0]]  # [(r), ...]

            # update next inputs
            ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)), dim=1)  # (1, idx + 1, odim)

            # check whether to finish generation
            if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
                # check minimum length
                if idx < minlen:
                    continue
                outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)  # (L, odim) -> (1, L, odim) -> (1, odim, L)
                if self.postnet is not None:
                    outs = outs + self.postnet(outs)  # (1, odim, L)
                outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
                probs = torch.cat(probs, dim=0)
                break

        # get attention weights
        att_ws = []
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention) and "src" in name:
                att_ws += [m.attn]
        att_ws = torch.cat(att_ws, dim=0)

        return outs, probs, att_ws

    def calculate_all_attentions(self, xs, ilens, ys, olens, spembs=None,
                                 skip_output=False, keep_tensor=False, *args, **kwargs):
        """Calculate all of the attention weights.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
            skip_output (bool, optional): Whether to skip calculating the final output.
            keep_tensor (bool, optional): Whether to keep the original tensor.

        Returns:
            dict: Dict of attention weights and outputs.

        """
        with torch.no_grad():
            # forward encoder
            x_masks = self._source_mask(ilens)
            hs, _ = self.encoder(xs, x_masks)

            # integrate speaker embedding
            if self.spk_embed_dim is not None:
                hs = self._integrate_with_spk_embed(hs, spembs)

            # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
            if self.reduction_factor > 1:
                ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                ys_in, olens_in = ys, olens

            # add first zero frame and remove last frame for auto-regressive
            ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

            # forward decoder
            y_masks = self._target_mask(olens_in)
            xy_masks = self._source_to_target_mask(ilens, olens_in)
            zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)

            # calculate final outputs
            if not skip_output:
                before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
                if self.postnet is None:
                    after_outs = before_outs
                else:
                    after_outs = before_outs + self.postnet(
                        before_outs.transpose(1, 2)).transpose(1, 2)

        # modify mod part of output lengths due to reduction factor > 1
        if self.reduction_factor > 1:
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])

        # store into dict
        att_ws_dict = dict()
        if keep_tensor:
            for name, m in self.named_modules():
                if isinstance(m, MultiHeadedAttention):
                    att_ws_dict[name] = m.attn
            if not skip_output:
                att_ws_dict["before_postnet_fbank"] = before_outs
                att_ws_dict["after_postnet_fbank"] = after_outs
        else:
            for name, m in self.named_modules():
                if isinstance(m, MultiHeadedAttention):
                    attn = m.attn.cpu().numpy()
                    if "encoder" in name:
                        attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
                    elif "decoder" in name:
                        if "src" in name:
                            attn = [a[:, :ol, :il] for a, il, ol in
                                    zip(attn, ilens.tolist(), olens_in.tolist())]
                        elif "self" in name:
                            attn = [a[:, :l, :l] for a, l in zip(attn, olens_in.tolist())]
                        else:
                            logging.warning("unknown attention module: " + name)
                    else:
                        logging.warning("unknown attention module: " + name)
                    att_ws_dict[name] = attn
            if not skip_output:
                before_outs = before_outs.cpu().numpy()
                after_outs = after_outs.cpu().numpy()
                att_ws_dict["before_postnet_fbank"] = [
                    m[:l].T for m, l in zip(before_outs, olens.tolist())]
                att_ws_dict["after_postnet_fbank"] = [
                    m[:l].T for m, l in zip(after_outs, olens.tolist())]

        return att_ws_dict

    def _integrate_with_spk_embed(self, hs, spembs):
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs

    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)

    def _target_mask(self, olens):
        """Make masks for masked self-attention.

        Examples:
            >>> olens = [5, 3]
            >>> self._target_mask(olens)
            tensor([[[1, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 0],
                     [1, 1, 1, 1, 1]],
                    [[1, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)

        """
        y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
        s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0)
        return y_masks.unsqueeze(-2) & s_masks & y_masks.unsqueeze(-1)

    def _source_to_target_mask(self, ilens, olens):
        """Make masks for encoder-decoder attention.

        Examples:
            >>> ilens = [4, 2]
            >>> olens = [5, 3]
            >>> self._source_to_target_mask(ilens, olens)
            tensor([[[1, 1, 1, 1],
                     [1, 1, 1, 1],
                     [1, 1, 1, 1],
                     [1, 1, 1, 1],
                     [1, 1, 1, 1]],
                    [[1, 1, 0, 0],
                     [1, 1, 0, 0],
                     [1, 1, 0, 0],
                     [0, 0, 0, 0],
                     [0, 0, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2) & y_masks.unsqueeze(-1)

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        The keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss` and
        `validation/main/loss` values.
        Also, `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "l2_loss", "bce_loss"]
        if self.use_scaled_pos_enc:
            plot_keys += ["encoder_alpha", "decoder_alpha"]
        if self.use_guided_attn_loss:
            if "encoder" in self.modules_applied_guided_attn:
                plot_keys += ["enc_attn_loss"]
            if "decoder" in self.modules_applied_guided_attn:
                plot_keys += ["dec_attn_loss"]
            if "encoder-decoder" in self.modules_applied_guided_attn:
                plot_keys += ["enc_dec_attn_loss"]

        return plot_keys
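
# Illustrative sketch (hypothetical helper, not part of the model code above):
# shows, on toy tensors, the two target-side transforms `forward` applies before
# the decoder: keeping every r-th frame (reduction-factor thinning) and the
# zero-frame shift done by `_add_first_frame_and_remove_last_frame` that makes
# decoding auto-regressive. `r = 2` and the tiny shapes are assumptions chosen
# for the example.
def _demo_reduction_factor_shift():
    import torch
    r = 2
    ys = torch.arange(12, dtype=torch.float32).view(1, 6, 2)  # (B=1, Lmax=6, odim=2)
    ys_in = ys[:, r - 1::r]                                   # thin out -> (1, 3, 2)
    ys_in = torch.cat([ys_in.new_zeros(1, 1, 2), ys_in[:, :-1]], dim=1)
    # ys_in is now [zero frame, frame 1, frame 3]: each decoder step sees only
    # previous target frames, never the frame it has to predict.
    return ys_in
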
class Transformer(TTSInterface, torch.nn.Module): """Text-to-Speech Transformer Reference: Neural Speech Synthesis with Transformer Network (https://arxiv.org/pdf/1809.08895.pdf) :param int idim: dimension of the inputs :param int odim: dimension of the outputs :param Namespace args: argments containing following attributes (int) embed_dim: dimension of character embedding (int) eprenet_conv_layers: number of encoder prenet convolution layers (int) eprenet_conv_chans: number of encoder prenet convolution channels (int) eprenet_conv_filts: filter size of encoder prenet convolution (int) dprenet_layers: number of decoder prenet layers (int) dprenet_units: number of decoder prenet hidden units (int) elayers: number of encoder layers (int) eunits: number of encoder hidden units (int) adim: number of attention transformation dimensions (int) aheads: number of heads for multi head attention (int) dlayers: number of decoder layers (int) dunits: number of decoder hidden units (int) postnet_layers: number of postnet layers (int) postnet_chans: number of postnet channels (int) postnet_filts: filter size of postnet (bool) use_scaled_pos_enc: whether to use trainable scaled positional encoding instead of the fixed scale one (bool) use_batch_norm: whether to use batch normalization in encoder prenet (bool) encoder_normalize_before: whether to perform layer normalization before encoder block (bool) decoder_normalize_before: whether to perform layer normalization before decoder block (bool) encoder_concat_after: whether to concatenate attention layer's input and output in encoder (bool) decoder_concat_after: whether to concatenate attention layer's input and output in decoder (int) reduction_factor: reduction factor (float) transformer_init: how to initialize transformer parameters (float) transformer_lr: initial value of learning rate (int) transformer_warmup_steps: optimizer warmup steps (float) transformer_enc_dropout_rate: dropout rate in encoder except for attention and positional encoding (float) transformer_enc_positional_dropout_rate: dropout rate after encoder positional encoding (float) transformer_enc_attn_dropout_rate: dropout rate in encoder self-attention module (float) transformer_dec_dropout_rate: dropout rate in decoder except for attention and positional encoding (float) transformer_dec_positional_dropout_rate: dropout rate after decoder positional encoding (float) transformer_dec_attn_dropout_rate: dropout rate in deocoder self-attention module (float) transformer_enc_dec_attn_dropout_rate: dropout rate in encoder-deocoder attention module (float) eprenet_dropout_rate: dropout rate in encoder prenet (float) dprenet_dropout_rate: dropout rate in decoder prenet (float) postnet_dropout_rate: dropout rate in postnet (bool) use_masking: whether to use masking in calculation of loss (float) bce_pos_weight: positive sample weight in bce calculation (only for use_masking=true) (str) loss_type: how to calculate loss (bool) use_guided_attn_loss: whether to use guided attention loss (int) num_heads_applied_guided_attn: number of heads in each layer to be applied guided attention loss (int) num_layers_applied_guided_attn: number of layers to be applied guided attention loss (list) modules_applied_guided_attn: list of module names to be applied guided attention loss """ @staticmethod def add_arguments(parser): group = parser.add_argument_group("transformer model setting") # network structure related group.add_argument( "--embed-dim", default=512, type=int, help="Dimension of character embedding in 
encoder prenet") group.add_argument("--eprenet-conv-layers", default=3, type=int, help="Number of encoder prenet convolution layers") group.add_argument( "--eprenet-conv-chans", default=256, type=int, help="Number of encoder prenet convolution channels") group.add_argument("--eprenet-conv-filts", default=5, type=int, help="Filter size of encoder prenet convolution") group.add_argument("--dprenet-layers", default=2, type=int, help="Number of decoder prenet layers") group.add_argument("--dprenet-units", default=256, type=int, help="Number of decoder prenet hidden units") group.add_argument("--elayers", default=3, type=int, help="Number of encoder layers") group.add_argument("--eunits", default=1536, type=int, help="Number of encoder hidden units") group.add_argument( "--adim", default=384, type=int, help="Number of attention transformation dimensions") group.add_argument("--aheads", default=4, type=int, help="Number of heads for multi head attention") group.add_argument("--dlayers", default=3, type=int, help="Number of decoder layers") group.add_argument("--dunits", default=1536, type=int, help="Number of decoder hidden units") group.add_argument("--postnet-layers", default=5, type=int, help="Number of postnet layers") group.add_argument("--postnet-chans", default=256, type=int, help="Number of postnet channels") group.add_argument("--postnet-filts", default=5, type=int, help="Filter size of postnet") group.add_argument( "--use-scaled-pos-enc", default=True, type=strtobool, help= "Use trainable scaled positional encoding instead of the fixed scale one." ) group.add_argument("--use-batch-norm", default=True, type=strtobool, help="Whether to use batch normalization") group.add_argument( "--encoder-normalize-before", default=False, type=strtobool, help="Whether to apply layer norm before encoder block") group.add_argument( "--decoder-normalize-before", default=False, type=strtobool, help="Whether to apply layer norm before decoder block") group.add_argument( "--encoder-concat-after", default=False, type=strtobool, help= "Whether to concatenate attention layer's input and output in encoder" ) group.add_argument( "--decoder-concat-after", default=False, type=strtobool, help= "Whether to concatenate attention layer's input and output in decoder" ) parser.add_argument("--reduction-factor", default=1, type=int, help="Reduction factor") # training related group.add_argument("--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" ], help="How to initialize transformer parameters") group.add_argument( "--initial-encoder-alpha", type=float, default=1.0, help="Initial alpha value in encoder's ScaledPositionalEncoding") group.add_argument( "--initial-decoder-alpha", type=float, default=1.0, help="Initial alpha value in decoder's ScaledPositionalEncoding") group.add_argument("--transformer-lr", default=1.0, type=float, help="Initial value of learning rate") group.add_argument("--transformer-warmup-steps", default=4000, type=int, help="Optimizer warmup steps") group.add_argument( "--transformer-enc-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder except for attention") group.add_argument( "--transformer-enc-positional-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder positional encoding") group.add_argument( "--transformer-enc-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder self-attention") group.add_argument( 
"--transformer-dec-dropout-rate", default=0.1, type=float, help= "Dropout rate for transformer decoder except for attention and pos encoding" ) group.add_argument( "--transformer-dec-positional-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer decoder positional encoding") group.add_argument( "--transformer-dec-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer decoder self-attention") group.add_argument( "--transformer-enc-dec-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder-decoder attention") group.add_argument("--eprenet-dropout-rate", default=0.5, type=float, help="Dropout rate in encoder prenet") group.add_argument("--dprenet-dropout-rate", default=0.5, type=float, help="Dropout rate in decoder prenet") group.add_argument("--postnet-dropout-rate", default=0.5, type=float, help="Dropout rate in postnet") # loss related group.add_argument( "--use-masking", default=True, type=strtobool, help="Whether to use masking in calculation of loss") group.add_argument("--loss-type", default="L1", choices=["L1", "L2", "L1+L2"], help="How to calc loss") group.add_argument( "--bce-pos-weight", default=5.0, type=float, help= "Positive sample weight in BCE calculation (only for use-masking=True)" ) group.add_argument("--use-guided-attn-loss", default=False, type=strtobool, help="Whether to use guided attention loss") group.add_argument("--guided-attn-loss-sigma", default=0.4, type=float, help="Sigma in guided attention loss") group.add_argument( "--num-heads-applied-guided-attn", default=2, type=int, help= "Number of heads in each layer to be applied guided attention loss" "if set -1, all of the heads will be applied.") group.add_argument( "--num-layers-applied-guided-attn", default=2, type=int, help="Number of layers to be applied guided attention loss" "if set -1, all of the layers will be applied.") group.add_argument( "--modules-applied-guided-attn", type=str, nargs="+", default=["encoder-decoder"], help="Module name list to be applied guided attention loss") return parser @property def attention_plot_class(self): return TTSPlot def __init__(self, idim, odim, args): # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # store hyperparameters self.idim = idim self.odim = odim self.use_scaled_pos_enc = args.use_scaled_pos_enc self.reduction_factor = args.reduction_factor self.loss_type = args.loss_type self.use_guided_attn_loss = args.use_guided_attn_loss if self.use_guided_attn_loss: if args.num_layers_applied_guided_attn == -1: self.num_layers_applied_guided_attn = args.elayers else: self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn if args.num_heads_applied_guided_attn == -1: self.num_heads_applied_guided_attn = args.aheads else: self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn self.modules_applied_guided_attn = args.modules_applied_guided_attn # use idx 0 as padding idx padding_idx = 0 # get positional encoding class pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding # define transformer encoder if args.eprenet_conv_layers != 0: # encoder prenet encoder_input_layer = torch.nn.Sequential( EncoderPrenet(idim=idim, embed_dim=args.embed_dim, elayers=0, econv_layers=args.eprenet_conv_layers, econv_chans=args.eprenet_conv_chans, econv_filts=args.eprenet_conv_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.eprenet_dropout_rate, padding_idx=padding_idx), 
torch.nn.Linear(args.eprenet_conv_chans, args.adim)) else: encoder_input_layer = torch.nn.Embedding(num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx) self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=encoder_input_layer, dropout_rate=args.transformer_enc_dropout_rate, positional_dropout_rate=args. transformer_enc_positional_dropout_rate, attention_dropout_rate=args.transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=args.encoder_normalize_before, concat_after=args.encoder_concat_after) # define transformer decoder if args.dprenet_layers != 0: # decoder prenet decoder_input_layer = torch.nn.Sequential( DecoderPrenet(idim=odim, n_layers=args.dprenet_layers, n_units=args.dprenet_units, dropout_rate=args.dprenet_dropout_rate), torch.nn.Linear(args.dprenet_units, args.adim)) else: decoder_input_layer = "linear" self.decoder = Decoder( odim=-1, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.transformer_dec_dropout_rate, positional_dropout_rate=args. transformer_dec_positional_dropout_rate, self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate, src_attention_dropout_rate=args. transformer_enc_dec_attn_dropout_rate, input_layer=decoder_input_layer, use_output_layer=False, pos_enc_class=pos_enc_class, normalize_before=args.decoder_normalize_before, concat_after=args.decoder_concat_after) # define final projection self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor) # define postnet self.postnet = None if args.postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=args.postnet_layers, n_chans=args.postnet_chans, n_filts=args.postnet_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.postnet_dropout_rate) # define loss function self.criterion = TransformerLoss(args) if self.use_guided_attn_loss: self.attn_criterion = GuidedMultiHeadAttentionLoss( args.guided_attn_loss_sigma) # initialize parameters self._reset_parameters(args) def _reset_parameters(self, args): if self.use_scaled_pos_enc: # alpha in scaled positional encoding init self.encoder.embed[-1].alpha.data = torch.tensor( args.initial_encoder_alpha) self.decoder.embed[-1].alpha.data = torch.tensor( args.initial_decoder_alpha) if args.transformer_init == "pytorch": return # weight init for p in self.parameters(): if p.dim() > 1: if args.transformer_init == "xavier_uniform": torch.nn.init.xavier_uniform_(p.data) elif args.transformer_init == "xavier_normal": torch.nn.init.xavier_normal_(p.data) elif args.transformer_init == "kaiming_uniform": torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") elif args.transformer_init == "kaiming_normal": torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") else: raise ValueError("Unknown initialization: " + args.transformer_init) # bias init for p in self.parameters(): if p.dim() == 1: p.data.zero_() # reset some modules with default init for m in self.modules(): if isinstance(m, (torch.nn.Embedding, LayerNorm)): m.reset_parameters() def _add_first_frame_and_remove_last_frame(self, ys): ys_in = torch.cat( [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1) return ys_in def forward(self, xs, ilens, ys, labels, olens, *args, **kwargs): """Calculate forward propagation :param torch.Tensor xs: batch of padded character ids (B, Tmax) :param 
    def forward(self, xs, ilens, ys, labels, olens, *args, **kwargs):
        """Calculate forward propagation.

        :param torch.Tensor xs: batch of padded character ids (B, Tmax)
        :param torch.Tensor ilens: list of lengths of each input batch (B)
        :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
        :param torch.Tensor labels: batch of padded stop token labels (B, Lmax)
        :param torch.Tensor olens: batch of the lengths of each target (B)
        :return: loss value
        :rtype: torch.Tensor
        """
        # remove unnecessary padded part (for multi-gpus)
        max_ilen = max(ilens)
        max_olen = max(olens)
        if max_ilen != xs.shape[1]:
            xs = xs[:, :max_ilen]
        if max_olen != ys.shape[1]:
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]

        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)

        # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
        if self.reduction_factor > 1:
            ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
            olens_in = olens.new([olen // self.reduction_factor for olen in olens])
        else:
            ys_in, olens_in = ys, olens

        # add first zero frame and remove last frame for auto-regressive
        ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

        # forward decoder
        y_masks = self._target_mask(olens_in)
        xy_masks = self._source_to_target_mask(ilens, olens_in)
        zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)
        # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
        # (B, Lmax//r, r) -> (B, Lmax//r * r)
        logits = self.prob_out(zs).view(zs.size(0), -1)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose(1, 2)).transpose(1, 2)

        # modify mod part of ground truth
        if self.reduction_factor > 1:
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate loss values
        l1_loss, l2_loss, bce_loss = self.criterion(
            after_outs, before_outs, logits, ys, labels, olens)
        if self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = l2_loss + bce_loss
        elif self.loss_type == "L1+L2":
            loss = l1_loss + l2_loss + bce_loss
        else:
            raise ValueError("unknown --loss-type " + self.loss_type)
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"l2_loss": l2_loss.item()},
            {"bce_loss": bce_loss.item()},
            {"loss": loss.item()},
        ]
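        # Illustrative note (not from the original source): the guided
        # attention loss below penalizes attention weights far from the
        # diagonal, using a weight matrix of the form
        #   W[t, n] = 1 - exp(-(n / N - t / T) ** 2 / (2 * sigma ** 2)),
        # with sigma set by --guided-attn-loss-sigma.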
        # calculate guided attention loss
        if self.use_guided_attn_loss:
            # calculate for encoder
            if "encoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.encoder.encoders)))):
                    att_ws += [self.encoder.encoders[layer_idx].self_attn
                               .attn[:, :self.num_heads_applied_guided_attn]]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_in, T_in)
                enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
                loss = loss + enc_attn_loss
                report_keys += [{"enc_attn_loss": enc_attn_loss.item()}]
            # calculate for decoder
            if "decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.decoder.decoders)))):
                    att_ws += [self.decoder.decoders[layer_idx].self_attn
                               .attn[:, :self.num_heads_applied_guided_attn]]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_out, T_out)
                dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in)
                loss = loss + dec_attn_loss
                report_keys += [{"dec_attn_loss": dec_attn_loss.item()}]
            # calculate for encoder-decoder
            if "encoder-decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.decoder.decoders)))):
                    att_ws += [self.decoder.decoders[layer_idx].src_attn
                               .attn[:, :self.num_heads_applied_guided_attn]]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_out, T_in)
                enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens_in)
                loss = loss + enc_dec_attn_loss
                report_keys += [{"enc_dec_attn_loss": enc_dec_attn_loss.item()}]

        # report extra information
        if self.use_scaled_pos_enc:
            report_keys += [
                {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
                {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
            ]
        self.reporter.report(report_keys)

        return loss
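    # Illustrative note (not from the original source): at inference time the
    # decoder below runs auto-regressively, emitting reduction_factor frames
    # per step and stopping once any predicted stop probability crosses the
    # threshold (or maxlen is reached).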
    def inference(self, x, inference_args, *args, **kwargs):
        """Generate the sequence of features from a sequence of characters.

        :param torch.Tensor x: the sequence of character ids (T)
        :param Namespace inference_args: arguments containing the following attributes
            (float) threshold: threshold in inference
            (float) minlenratio: minimum length ratio in inference
            (float) maxlenratio: maximum length ratio in inference
        :return: the sequence of features (L, odim)
        :rtype: torch.Tensor
        :return: the sequence of stop probabilities (L)
        :rtype: torch.Tensor
        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio

        # forward encoder
        xs = x.unsqueeze(0)
        hs, _ = self.encoder(xs, None)

        # set limits of length
        maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor)
        minlen = int(hs.size(1) * minlenratio / self.reduction_factor)

        # initialize
        idx = 0
        ys = hs.new_zeros(1, 1, self.odim)
        outs, probs = [], []

        # forward decoder step-by-step
        while True:
            # update index
            idx += 1

            # calculate output and stop prob at idx-th step
            y_masks = subsequent_mask(idx).unsqueeze(0)
            z = self.decoder.recognize(ys, y_masks, hs)  # (B, adim)
            outs += [self.feat_out(z).view(self.reduction_factor, self.odim)]  # [(r, odim), ...]
            probs += [torch.sigmoid(self.prob_out(z))[0]]  # [(r), ...]

            # update next inputs
            ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)),
                           dim=1)  # (1, idx + 1, odim)

            # check whether to finish generation
            if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
                # check minimum length
                if idx < minlen:
                    continue
                outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(
                    1, 2)  # (L, odim) -> (1, L, odim) -> (1, odim, L)
                if self.postnet is not None:
                    outs = outs + self.postnet(outs)  # (1, odim, L)
                outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
                probs = torch.cat(probs, dim=0)
                break

        return outs, probs

    def calculate_all_attentions(self, xs, ilens, ys, olens, *args, **kwargs):
        """Calculate attention weights of all of the layers.

        :param torch.Tensor xs: batch of padded character ids (B, Tmax)
        :param torch.Tensor ilens: list of lengths of each input batch (B)
        :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
        :param torch.Tensor olens: list of lengths of each output batch (B)
        :return: attention weights dict
        :rtype: dict
        """
        with torch.no_grad():
            # forward encoder
            x_masks = self._source_mask(ilens)
            hs, _ = self.encoder(xs, x_masks)

            # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
            if self.reduction_factor > 1:
                ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                ys_in, olens_in = ys, olens

            # add first zero frame and remove last frame for auto-regressive
            ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

            # forward decoder
            y_masks = self._target_mask(olens_in)
            xy_masks = self._source_to_target_mask(ilens, olens_in)
            zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)
            # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
            before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)

            # postnet -> (B, Lmax//r * r, odim)
            if self.postnet is None:
                after_outs = before_outs
            else:
                after_outs = before_outs + self.postnet(
                    before_outs.transpose(1, 2)).transpose(1, 2)

            # modify mod part of ground truth
            if self.reduction_factor > 1:
                olens = olens.new([olen - olen % self.reduction_factor for olen in olens])

        att_ws_dict = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                attn = m.attn.cpu().numpy()
                if "encoder" in name:
                    attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
                elif "decoder" in name:
                    if "src" in name:
                        attn = [a[:, :ol, :il] for a, il, ol in
                                zip(attn, ilens.tolist(), olens_in.tolist())]
                    elif "self" in name:
                        attn = [a[:, :l, :l] for a, l in zip(attn, olens_in.tolist())]
                    else:
                        logging.warning("unknown attention module: " + name)
                else:
                    logging.warning("unknown attention module: " + name)
                att_ws_dict[name] = attn
        att_ws_dict["before_postnet_fbank"] = [
            m[:l].T for m, l in zip(before_outs.cpu().numpy(), olens.tolist())]
        att_ws_dict["after_postnet_fbank"] = [
            m[:l].T for m, l in zip(after_outs.cpu().numpy(), olens.tolist())]

        return att_ws_dict
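    # Illustrative note (not from the original source): the three helpers
    # below build boolean (B, T, T) masks consumed by MultiHeadedAttention;
    # a 1 marks an attendable position, a 0 a padded or future one.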
    def _source_mask(self, ilens):
        """Make mask for MultiHeadedAttention using padded sequences.

        >>> ilens = [5, 3]
        >>> self._source_mask(ilens)
        tensor([[[1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1]],
                [[1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0]]], dtype=torch.uint8)
        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)

    def _target_mask(self, olens):
        """Make mask for MaskedMultiHeadedAttention using padded sequences.

        >>> olens = [5, 3]
        >>> self._target_mask(olens)
        tensor([[[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1]],
                [[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0]]], dtype=torch.uint8)
        """
        y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
        s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0)
        return y_masks.unsqueeze(-2) & s_masks & y_masks.unsqueeze(-1)

    def _source_to_target_mask(self, ilens, olens):
        """Make source-to-target mask for MultiHeadedAttention using padded sequences.

        >>> ilens = [4, 2]
        >>> olens = [5, 3]
        >>> self._source_to_target_mask(ilens, olens)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1],
                 [1, 1, 1, 1],
                 [1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0],
                 [1, 1, 0, 0],
                 [0, 0, 0, 0],
                 [0, 0, 0, 0]]], dtype=torch.uint8)
        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2) & y_masks.unsqueeze(-1)

    @property
    def base_plot_keys(self):
        """Base key names to plot during training.

        The keys should match what `chainer.reporter` reports: if you add the
        key `loss`, the reporter will report `main/loss` and
        `validation/main/loss` values. In addition, `loss.png` will be created
        as a figure visualizing `main/loss` and `validation/main/loss` values.

        :rtype list[str] plot_keys: base keys to plot during training
        """
        plot_keys = ["loss", "l1_loss", "l2_loss", "bce_loss"]
        if self.use_scaled_pos_enc:
            plot_keys += ["encoder_alpha", "decoder_alpha"]
        if self.use_guided_attn_loss:
            if "encoder" in self.modules_applied_guided_attn:
                plot_keys += ["enc_attn_loss"]
            if "decoder" in self.modules_applied_guided_attn:
                plot_keys += ["dec_attn_loss"]
            if "encoder-decoder" in self.modules_applied_guided_attn:
                plot_keys += ["enc_dec_attn_loss"]
        return plot_keys
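# ---------------------------------------------------------------------------
# A minimal, illustrative sketch (not part of the original file) of how the
# espnet mask helpers combine into the decoder self-attention mask built by
# _target_mask() above. `_demo_target_mask` is a hypothetical helper added
# only for exposition; the import paths assume the standard espnet layout.
def _demo_target_mask(olens=(5, 3)):
    from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
    from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

    # non-pad mask: (B, Lmax), nonzero for real frames
    y_masks = make_non_pad_mask(list(olens))
    # lower-triangular causal mask: (1, Lmax, Lmax)
    s_masks = subsequent_mask(y_masks.size(-1)).unsqueeze(0)
    # causal attention restricted to non-padded positions: (B, Lmax, Lmax)
    return y_masks.unsqueeze(-2) & s_masks & y_masks.unsqueeze(-1)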
class E2E(ASRInterface, torch.nn.Module):
    """E2E module.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (Namespace): argument Namespace containing options

    """

    @staticmethod
    def add_arguments(parser):
        """Extend arguments for transducer models.

        Both Transformer and RNN modules are supported.
        General options encapsulate both modules options.
        """
        group = parser.add_argument_group("transformer model setting")
        # Encoder - general
        group.add_argument(
            '--elayers', default=4, type=int,
            help='Number of encoder layers '
                 '(for shared recognition part in multi-speaker asr mode)')
        group.add_argument('--eunits', '-u', default=300, type=int,
                           help='Number of encoder hidden units')
        group.add_argument('--kernel-size', default=32, type=int)
        group.add_argument('--dropout-rate', default=0.0, type=float,
                           help='Dropout rate for the encoder')
        # Encoder - RNN
        group.add_argument('--eprojs', default=320, type=int,
                           help='Number of encoder projection units')
        group.add_argument(
            '--subsample', default="1", type=str,
            help='Subsample input frames x_y_z means subsample every x frame '
                 'at 1st layer, every y frame at 2nd layer etc.')
        # Attention - general
        group.add_argument(
            '--transformer-attn-dropout-rate', default=None, type=float,
            help='dropout in transformer attention. use --dropout-rate if None is set')
        group.add_argument('--adim', default=320, type=int,
                           help='Number of attention transformation dimensions')
        group.add_argument('--aheads', default=4, type=int,
                           help='Number of heads for multi head attention')
        # Decoder - general
        group.add_argument('--dtype', default='lstm', type=str,
                           choices=['lstm', 'gru', 'transformer'],
                           help='Type of decoder to use.')
        group.add_argument('--dlayers', default=1, type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits', default=320, type=int,
                           help='Number of decoder hidden units')
        group.add_argument('--dropout-rate-decoder', default=0.0, type=float,
                           help='Dropout rate for the decoder')
        group.add_argument('--relative-v', default=False, type=strtobool)
        # Decoder - RNN
        group.add_argument('--dec-embed-dim', default=320, type=int,
                           help='Number of decoder embeddings dimensions')
        group.add_argument('--dropout-rate-embed-decoder', default=0.0, type=float,
                           help='Dropout rate for the decoder embeddings')
        # Transformer
        group.add_argument(
            "--input-layer", type=str, default="conv2d",
            choices=["conv2d", "vgg2l", "linear", "embed", "custom"],
            help='transformer encoder input layer type')
        group.add_argument('--transformer-lr', default=10.0, type=float,
                           help='Initial value of learning rate')
        group.add_argument('--transformer-warmup-steps', default=25000, type=int,
                           help='optimizer warmup steps')
        group.add_argument('--transformer-length-normalized-loss', default=False,
                           type=strtobool, help='normalize loss by length')
        group.add_argument("--transformer-init", type=str, default="pytorch",
                           choices=["pytorch", "xavier_uniform", "xavier_normal",
                                    "kaiming_uniform", "kaiming_normal"],
                           help='how to initialize transformer parameters')
        group.add_argument('--causal', type=strtobool, default=False)
        return parser
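    # Illustrative example (not from the original source): a plausible small
    # configuration of the options above might look like
    #   --elayers 12 --eunits 2048 --adim 256 --aheads 4 \
    #   --dtype transformer --dlayers 6 --input-layer conv2d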
    def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0):
        """Construct an E2E object for transducer model.

        Args:
            idim (int): dimension of inputs
            odim (int): dimension of outputs
            args (Namespace): argument Namespace containing options

        """
        torch.nn.Module.__init__(self)

        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate

        self.encoder = Encoder(idim=idim,
                               d_model=args.adim,
                               n_heads=args.aheads,
                               d_ffn=args.eunits,
                               layers=args.elayers,
                               kernel_size=args.kernel_size,
                               input_layer=args.input_layer,
                               dropout_rate=args.dropout_rate,
                               causal=args.causal)
        args.eprojs = args.adim

        if args.mtlalpha < 1.0:
            self.decoder = Decoder(
                odim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate)

        self.sos = odim - 1
        self.eos = odim - 1
        self.ignore_id = ignore_id
        self.odim = odim
        self.adim = args.adim

        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id,
                                            args.lsm_weight,
                                            args.transformer_length_normalized_loss)
        # self.verbose = args.verbose
        self.reset_parameters(args)
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim, args.adim, args.dropout_rate,
                           ctc_type=args.ctc_type, reduce=True)
        else:
            self.ctc = None

    def reset_parameters(self, args):
        """Initialize parameters."""
        initialize(self, args.transformer_init)

    def forward(self, xs_pad, ilens, ys_pad, enc_mask=None):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value
        """
        # 1. encoder
        xs_pad = xs_pad[:, :max(ilens)]
        masks = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, masks)
        self.hs_pad = hs_pad

        # CTC forward
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max([len(y) for y in ys])
        ys_pad = ys_pad[:, :y_len]
        cer_ctc = None
        batch_size = xs_pad.size(0)
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                                hs_len, ys_pad)

        # trigger mask
        start_time = time.time()

        # 2. forward decoder
        # NOTE: assumes mtlalpha < 1.0, since self.decoder is only defined
        # in __init__ under that condition.
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                               ignore_label=self.ignore_id)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        return self.loss, loss_att_data, loss_ctc_data, self.acc
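    # Illustrative note (not from the original source): with, e.g.,
    # --mtlalpha 0.3 the returned loss is 0.3 * loss_ctc + 0.7 * loss_att.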
    def encode(self, x, streaming_mask=None):
        """Encode acoustic features.

        Args:
            x (ndarray): input acoustic feature (T, D)

        Returns:
            x (torch.Tensor): encoded features (T, attention_dim)
        """
        self.eval()
        # use the model's own device instead of assuming CUDA is available
        device = next(self.parameters()).device
        x = torch.as_tensor(x).unsqueeze(0).to(device)
        enc_output, _ = self.encoder(x, None, streaming_mask)
        return enc_output.squeeze(0)

    def greedy_recognize(self, x, recog_args, streaming_mask=None):
        """Greedy CTC decoding: frame-wise argmax, collapse repeats, drop blanks."""
        h = self.encode(x, streaming_mask).unsqueeze(0)
        lpz = self.ctc.log_softmax(h)
        ys_hat = lpz.argmax(dim=-1)[0].cpu()
        ids = []
        for i in range(len(ys_hat)):
            # skip blanks (id 0) and repeated labels; guard i == 0 so the
            # first frame is not compared against ys_hat[-1]
            if ys_hat[i].item() != 0 and (
                    i == 0 or ys_hat[i].item() != ys_hat[i - 1].item()):
                ids.append(ys_hat[i].item())
        rets = [{'score': 0.0, 'yseq': ids}]
        return rets

    def recognize(self, x, recog_args, char_list=None, rnnlm=None,
                  streaming_mask=None):
        """Recognize input features.

        Args:
            x (ndarray): input acoustic feature (T, D)
            recog_args (namespace): argument Namespace containing options
            char_list (list): list of characters
            rnnlm (torch.nn.Module): language model module

        Returns:
            y (list): n-best decoding results
        """
        h = self.encode(x, streaming_mask)
        params = [h, recog_args]

        if recog_args.beam_size == 1:
            nbest_hyps = self.decoder.recognize(*params)
        else:
            params.append(rnnlm)
            start_time = time.time()
            nbest_hyps = self.decoder.recognize_beam(*params)
            # nbest_hyps, decoder_time = self.decoder.beam_search(h, recog_args, prefix=False)
            end_time = time.time()
            decode_time = end_time - start_time

        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Args:
            xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax)

        Returns:
            ret (ndarray): attention weights with the following shape,
                1) multi-head case => attention weights (B, H, Lmax, Tmax),
                2) other case => attention weights (B, Lmax, Tmax).
        """
        # NOTE: self.etype and self.rnnt_mode are assumed to be set elsewhere;
        # they are not defined in __init__ above.
        if self.etype == 'transformer' and self.dtype != 'transformer' and \
                self.rnnt_mode == 'rnnt-att':
            raise NotImplementedError(
                "Transformer encoder with rnn attention decoder "
                "is not supported yet.")
        elif self.etype != 'transformer' and self.dtype != 'transformer':
            if self.rnnt_mode == 'rnnt':
                return []
            else:
                with torch.no_grad():
                    hs_pad, hlens = xs_pad, ilens
                    hpad, hlens, _ = self.encoder(hs_pad, hlens)
                    ret = self.decoder.calculate_all_attentions(hpad, hlens, ys_pad)
        else:
            with torch.no_grad():
                self.forward(xs_pad, ilens, ys_pad)

                ret = dict()
                for name, m in self.named_modules():
                    if isinstance(m, MultiHeadedAttention):
                        ret[name] = m.attn.cpu().numpy()

        return ret
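# ---------------------------------------------------------------------------
# A minimal, illustrative sketch (not part of the original file) of the CTC
# greedy decoding rule used in greedy_recognize() above, written over a plain
# list of frame-wise argmax ids. `collapse_ctc` is a hypothetical helper.
def collapse_ctc(frame_ids, blank_id=0):
    """Collapse repeats, then remove blanks: [0, 3, 3, 0, 5, 5] -> [3, 5]."""
    out = []
    prev = None
    for t in frame_ids:
        # emit a label only when it differs from the previous frame and is
        # not the blank symbol
        if t != blank_id and t != prev:
            out.append(t)
        prev = t
    return out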