Example no. 1
class E2E(ASRInterface, torch.nn.Module):

    @staticmethod
    def add_arguments(parser):
        group = parser.add_argument_group("transformer model setting")
        group.add_argument("--transformer-init", type=str, default="pytorch",
                           choices=["pytorch", "xavier_uniform", "xavier_normal",
                                    "kaiming_uniform", "kaiming_normal"],
                           help='how to initialize transformer parameters')
        group.add_argument("--transformer-input-layer", type=str, default="conv2d",
                           choices=["conv2d", "linear", "embed"],
                           help='transformer input layer type')
        group.add_argument('--transformer-attn-dropout-rate', default=None, type=float,
                           help='dropout in transformer attention. use --dropout-rate if None is set')
        group.add_argument('--transformer-lr', default=10.0, type=float,
                           help='Initial value of learning rate')
        group.add_argument('--transformer-warmup-steps', default=25000, type=int,
                           help='optimizer warmup steps')
        group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool,
                           help='normalize loss by length')
        group.add_argument("--transformer-encoder-center-chunk-len", type=int, default=8,
                           help='transformer chunk size, only used when transformer-encoder-type is memory')
        group.add_argument("--transformer-encoder-left-chunk-len", type=int, default=0,
                           help='transformer left chunk size')
        group.add_argument("--transformer-encoder-hop-len", type=int, default=0,
                           help='transformer encoder hop len, default')
        group.add_argument("--transformer-encoder-right-chunk-len", type=int, default=0,
                           help='future data of the encoder')
        group.add_argument("--transformer-encoder-abs-embed", type=int, default=1,
                           help='whether the network use absolute embed')
        group.add_argument("--transformer-encoder-rel-embed", type=int, default=0,
                           help='whether the network us reality embed')
        group.add_argument("--transformer-encoder-use-memory", type=int, default=0,
                           help='whether the network us memory to store history')
        return parser
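
    # Hedged usage sketch (not part of the original source): these options are
    # registered on an argparse parser by the training entry point; only the
    # flags defined above are assumed here.
    #
    #   import argparse
    #   parser = argparse.ArgumentParser()
    #   parser = E2E.add_arguments(parser)
    #   args = parser.parse_args(["--transformer-init", "xavier_uniform"])
    #   assert args.transformer_warmup_steps == 25000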

    @property
    def attention_plot_class(self):
        return PlotAttentionReport

    def __init__(self, idim, odim, args, ignore_id=-1):
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            center_len=args.transformer_encoder_center_chunk_len,
            left_len=args.transformer_encoder_left_chunk_len,
            hop_len=args.transformer_encoder_hop_len,
            right_len=args.transformer_encoder_right_chunk_len,
            abs_pos=args.transformer_encoder_abs_embed,
            rel_pos=args.transformer_encoder_rel_embed,
            use_mem=args.transformer_encoder_use_memory,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate
        )
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate
        )
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = [1]
        self.reporter = Reporter()

        # self.lsm_weight = a
        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, args.lsm_weight,
                                            args.transformer_length_normalized_loss)
        # self.verbose = args.verbose
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        else:
            self.ctc = None

        if args.report_cer or args.report_wer or args.mtlalpha > 0.0:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(args.char_list,
                                                    args.sym_space, args.sym_blank,
                                                    args.report_cer, args.report_wer)
        else:
            self.error_calculator = None
        self.rnnlm = None

    def reset_parameters(self, args):
        if args.transformer_init == "pytorch":
            return
        # weight init
        for p in self.parameters():
            if p.dim() > 1:
                if args.transformer_init == "xavier_uniform":
                    torch.nn.init.xavier_uniform_(p.data)
                elif args.transformer_init == "xavier_normal":
                    torch.nn.init.xavier_normal_(p.data)
                elif args.transformer_init == "kaiming_uniform":
                    torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
                elif args.transformer_init == "kaiming_normal":
                    torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
                else:
                    raise ValueError("Unknown initialization: " + args.transformer_init)
        # bias init
        for p in self.parameters():
            if p.dim() == 1:
                p.data.zero_()
        # reset some modules with default init
        for m in self.modules():
            if isinstance(m, (torch.nn.Embedding, LayerNorm)):
                m.reset_parameters()

    def add_sos_eos(self, ys_pad):
        from espnet.nets.pytorch_backend.nets_utils import pad_list
        eos = ys_pad.new([self.eos])
        sos = ys_pad.new([self.sos])
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]
        return pad_list(ys_in, self.eos), pad_list(ys_out, self.ignore_id)
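
    # Illustrative example (toy values, not from the source): with ignore_id=-1
    # and sos = eos = odim - 1, a padded label batch
    #   ys_pad = [[3, 5, -1],
    #             [2, -1, -1]]
    # is turned into decoder inputs/targets
    #   ys_in  = [[sos, 3, 5],          ys_out = [[3, 5, eos],
    #             [sos, 2, eos]]                  [2, eos, -1]]
    # i.e. <sos> is prepended to the inputs (padded with eos) and <eos> is
    # appended to the targets (padded with ignore_id).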

    def target_mask(self, ys_in_pad):
        ys_mask = ys_in_pad != self.ignore_id
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m
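
    # Illustrative example (toy values): for decoder inputs of length 3 with no
    # ignore_id entries, target_mask combines the padding mask with
    # subsequent_mask(3) and yields the causal mask
    #   [[1, 0, 0],
    #    [1, 1, 0],
    #    [1, 1, 1]]
    # per utterance; positions equal to ignore_id are additionally masked out.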

    def forward(self, xs_pad, ilens, ys_pad):
        '''E2E forward

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: total loss value (weighted sum of the CTC and attention losses)
        :rtype: torch.Tensor
        '''
        # forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # forward decoder
        ys_in_pad, ys_out_pad = self.add_sos_eos(ys_pad)
        ys_mask = self.target_mask(ys_in_pad)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # compute loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                               ignore_label=self.ignore_id)

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        if self.mtlalpha == 0.0:
            loss_ctc = None
            cer_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)

            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
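
    # Note on the returned loss (worked example with made-up numbers): forward
    # returns the standard hybrid CTC/attention interpolation
    #   loss = mtlalpha * loss_ctc + (1 - mtlalpha) * loss_att
    # e.g. mtlalpha = 0.3, loss_ctc = 120.0, loss_att = 40.0 gives
    #   loss = 0.3 * 120.0 + 0.7 * 40.0 = 64.0
    # with the pure attention (mtlalpha = 0) and pure CTC (mtlalpha = 1) cases
    # handled separately above.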

    def recognize(self, feat, recog_args, char_list=None, rnnlm=None, use_jit=False):
        '''recognize feat

        :param ndarray feat: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list

        TODO(karita): do not recompute previous attention for faster decoding
        '''
        self.eval()
        feat = torch.as_tensor(feat).unsqueeze(0)
        enc_output, _ = self.encoder(feat, None)
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)

        logging.info('input lengths: ' + str(h.size(0)))
        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        if lpz is not None:
            import numpy

            from espnet.nets.ctc_prefix_score import CTCPrefixScore

            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        import six
        traced_decoder = None
        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy.unsqueeze(1)
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp['yseq']).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(self.decoder.recognize, (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)
                else:
                    local_att_scores = self.decoder.recognize(ys, ys_mask, enc_output)

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                        + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last position in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypotheses to a final list, and remove them from the current hypotheses
            # (this can be a problem: the number of hyps may drop below beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            from espnet.nets.e2e_asr_common import end_detect
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remaining hypotheses: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check the number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning('there are no N-best results, perform recognition again with a smaller minlenratio.')
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            return self.recognize(feat, recog_args, char_list, rnnlm)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
        return nbest_hyps
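
    # Hedged usage sketch (the attribute names mirror the recog_args consumed
    # above; the Namespace construction itself is an assumption, not part of
    # the original source):
    #
    #   from argparse import Namespace
    #   recog_args = Namespace(beam_size=5, penalty=0.0, ctc_weight=0.3,
    #                          lm_weight=0.0, maxlenratio=0.0, minlenratio=0.0,
    #                          nbest=1)
    #   nbest_hyps = model.recognize(feat, recog_args, char_list=char_list)
    #   best_ids = nbest_hyps[0]['yseq'][1:-1]  # strip leading <sos> and trailing <eos>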

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        '''E2E attention calculation

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        '''
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                ret[name] = m.attn.cpu().numpy()
        return ret


class Transformer(TTSInterface, torch.nn.Module):
    """Text-to-Speech Transformer module.

    This is a module of text-to-speech Transformer described in `Neural Speech Synthesis with Transformer Network`_,
    which converts the sequence of characters or phonemes into a sequence of Mel-filterbanks.

    .. _`Neural Speech Synthesis with Transformer Network`:
        https://arxiv.org/pdf/1809.08895.pdf

    """
    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("transformer model setting")
        # network structure related
        group.add_argument(
            "--embed-dim",
            default=512,
            type=int,
            help="Dimension of character embedding in encoder prenet")
        group.add_argument("--eprenet-conv-layers",
                           default=3,
                           type=int,
                           help="Number of encoder prenet convolution layers")
        group.add_argument(
            "--eprenet-conv-chans",
            default=256,
            type=int,
            help="Number of encoder prenet convolution channels")
        group.add_argument("--eprenet-conv-filts",
                           default=5,
                           type=int,
                           help="Filter size of encoder prenet convolution")
        group.add_argument("--dprenet-layers",
                           default=2,
                           type=int,
                           help="Number of decoder prenet layers")
        group.add_argument("--dprenet-units",
                           default=256,
                           type=int,
                           help="Number of decoder prenet hidden units")
        group.add_argument("--elayers",
                           default=3,
                           type=int,
                           help="Number of encoder layers")
        group.add_argument("--eunits",
                           default=1536,
                           type=int,
                           help="Number of encoder hidden units")
        group.add_argument(
            "--adim",
            default=384,
            type=int,
            help="Number of attention transformation dimensions")
        group.add_argument("--aheads",
                           default=4,
                           type=int,
                           help="Number of heads for multi head attention")
        group.add_argument("--dlayers",
                           default=3,
                           type=int,
                           help="Number of decoder layers")
        group.add_argument("--dunits",
                           default=1536,
                           type=int,
                           help="Number of decoder hidden units")
        group.add_argument("--postnet-layers",
                           default=5,
                           type=int,
                           help="Number of postnet layers")
        group.add_argument("--postnet-chans",
                           default=256,
                           type=int,
                           help="Number of postnet channels")
        group.add_argument("--postnet-filts",
                           default=5,
                           type=int,
                           help="Filter size of postnet")
        group.add_argument(
            "--use-scaled-pos-enc",
            default=True,
            type=strtobool,
            help=
            "Use trainable scaled positional encoding instead of the fixed scale one."
        )
        group.add_argument("--use-batch-norm",
                           default=True,
                           type=strtobool,
                           help="Whether to use batch normalization")
        group.add_argument(
            "--encoder-normalize-before",
            default=False,
            type=strtobool,
            help="Whether to apply layer norm before encoder block")
        group.add_argument(
            "--decoder-normalize-before",
            default=False,
            type=strtobool,
            help="Whether to apply layer norm before decoder block")
        group.add_argument(
            "--encoder-concat-after",
            default=False,
            type=strtobool,
            help=
            "Whether to concatenate attention layer's input and output in encoder"
        )
        group.add_argument(
            "--decoder-concat-after",
            default=False,
            type=strtobool,
            help=
            "Whether to concatenate attention layer's input and output in decoder"
        )
        group.add_argument("--reduction-factor",
                           default=1,
                           type=int,
                           help="Reduction factor")
        group.add_argument("--spk-embed-dim",
                           default=None,
                           type=int,
                           help="Number of speaker embedding dimensions")
        group.add_argument("--spk-embed-integration-type",
                           type=str,
                           default="add",
                           choices=["add", "concat"],
                           help="How to integrate speaker embedding")
        # training related
        group.add_argument("--transformer-init",
                           type=str,
                           default="pytorch",
                           choices=[
                               "pytorch", "xavier_uniform", "xavier_normal",
                               "kaiming_uniform", "kaiming_normal"
                           ],
                           help="How to initialize transformer parameters")
        group.add_argument(
            "--initial-encoder-alpha",
            type=float,
            default=1.0,
            help="Initial alpha value in encoder's ScaledPositionalEncoding")
        group.add_argument(
            "--initial-decoder-alpha",
            type=float,
            default=1.0,
            help="Initial alpha value in decoder's ScaledPositionalEncoding")
        group.add_argument("--transformer-lr",
                           default=1.0,
                           type=float,
                           help="Initial value of learning rate")
        group.add_argument("--transformer-warmup-steps",
                           default=4000,
                           type=int,
                           help="Optimizer warmup steps")
        group.add_argument(
            "--transformer-enc-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder except for attention")
        group.add_argument(
            "--transformer-enc-positional-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder positional encoding")
        group.add_argument(
            "--transformer-enc-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder self-attention")
        group.add_argument(
            "--transformer-dec-dropout-rate",
            default=0.1,
            type=float,
            help=
            "Dropout rate for transformer decoder except for attention and pos encoding"
        )
        group.add_argument(
            "--transformer-dec-positional-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer decoder positional encoding")
        group.add_argument(
            "--transformer-dec-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer decoder self-attention")
        group.add_argument(
            "--transformer-enc-dec-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder-decoder attention")
        group.add_argument("--eprenet-dropout-rate",
                           default=0.5,
                           type=float,
                           help="Dropout rate in encoder prenet")
        group.add_argument("--dprenet-dropout-rate",
                           default=0.5,
                           type=float,
                           help="Dropout rate in decoder prenet")
        group.add_argument("--postnet-dropout-rate",
                           default=0.5,
                           type=float,
                           help="Dropout rate in postnet")
        # loss related
        group.add_argument(
            "--use-masking",
            default=True,
            type=strtobool,
            help="Whether to use masking in calculation of loss")
        group.add_argument("--loss-type",
                           default="L1",
                           choices=["L1", "L2", "L1+L2"],
                           help="How to calc loss")
        group.add_argument(
            "--bce-pos-weight",
            default=5.0,
            type=float,
            help=
            "Positive sample weight in BCE calculation (only for use-masking=True)"
        )
        group.add_argument("--use-guided-attn-loss",
                           default=False,
                           type=strtobool,
                           help="Whether to use guided attention loss")
        group.add_argument("--guided-attn-loss-sigma",
                           default=0.4,
                           type=float,
                           help="Sigma in guided attention loss")
        group.add_argument("--guided-attn-loss-lambda",
                           default=1.0,
                           type=float,
                           help="Lambda in guided attention loss")
        group.add_argument(
            "--num-heads-applied-guided-attn",
            default=2,
            type=int,
            help="Number of heads in each layer to which guided attention loss is applied; "
            "if set to -1, all of the heads will be used.")
        group.add_argument(
            "--num-layers-applied-guided-attn",
            default=2,
            type=int,
            help="Number of layers to which guided attention loss is applied; "
            "if set to -1, all of the layers will be used.")
        group.add_argument(
            "--modules-applied-guided-attn",
            type=str,
            nargs="+",
            default=["encoder-decoder"],
            help="Module name list to be applied guided attention loss")
        return parser

    @property
    def attention_plot_class(self):
        """Return plot class for attention weight plot."""
        return TTSPlot

    def __init__(self, idim, odim, args=None):
        """Initialize TTS-Transformer module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - embed_dim (int): Dimension of character embedding.
                - eprenet_conv_layers (int): Number of encoder prenet convolution layers.
                - eprenet_conv_chans (int): Number of encoder prenet convolution channels.
                - eprenet_conv_filts (int): Filter size of encoder prenet convolution.
                - dprenet_layers (int): Number of decoder prenet layers.
                - dprenet_units (int): Number of decoder prenet hidden units.
                - elayers (int): Number of encoder layers.
                - eunits (int): Number of encoder hidden units.
                - adim (int): Number of attention transformation dimensions.
                - aheads (int): Number of heads for multi head attention.
                - dlayers (int): Number of decoder layers.
                - dunits (int): Number of decoder hidden units.
                - postnet_layers (int): Number of postnet layers.
                - postnet_chans (int): Number of postnet channels.
                - postnet_filts (int): Filter size of postnet.
                - use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding.
                - use_batch_norm (bool): Whether to use batch normalization in encoder prenet.
                - encoder_normalize_before (bool): Whether to perform layer normalization before encoder block.
                - decoder_normalize_before (bool): Whether to perform layer normalization before decoder block.
                - encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder.
                - decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spk_embed_integration_type: How to integrate speaker embedding.
                - transformer_init (float): How to initialize transformer parameters.
                - transformer_lr (float): Initial value of learning rate.
                - transformer_warmup_steps (int): Optimizer warmup steps.
                - transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding.
                - transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding.
                - transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module.
                - transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding.
                - transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding.
                - transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module.
                - transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-decoder attention module.
                - eprenet_dropout_rate (float): Dropout rate in encoder prenet.
                - dprenet_dropout_rate (float): Dropout rate in decoder prenet.
                - postnet_dropout_rate (float): Dropout rate in postnet.
                - use_masking (bool): Whether to use masking in calculation of loss.
                - bce_pos_weight (float): Positive sample weight in bce calculation (only for use_masking=true).
                - loss_type (str): How to calculate loss.
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - num_heads_applied_guided_attn (int): Number of heads in each layer to apply guided attention loss.
                - num_layers_applied_guided_attn (int): Number of layers to apply guided attention loss.
                - modules_applied_guided_attn (list): List of module names to apply guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = args.spk_embed_integration_type
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.reduction_factor = args.reduction_factor
        self.loss_type = args.loss_type
        self.use_guided_attn_loss = args.use_guided_attn_loss
        if self.use_guided_attn_loss:
            if args.num_layers_applied_guided_attn == -1:
                self.num_layers_applied_guided_attn = args.elayers
            else:
                self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn
            if args.num_heads_applied_guided_attn == -1:
                self.num_heads_applied_guided_attn = args.aheads
            else:
                self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn
            self.modules_applied_guided_attn = args.modules_applied_guided_attn

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define transformer encoder
        if args.eprenet_conv_layers != 0:
            # encoder prenet
            encoder_input_layer = torch.nn.Sequential(
                EncoderPrenet(idim=idim,
                              embed_dim=args.embed_dim,
                              elayers=0,
                              econv_layers=args.eprenet_conv_layers,
                              econv_chans=args.eprenet_conv_chans,
                              econv_filts=args.eprenet_conv_filts,
                              use_batch_norm=args.use_batch_norm,
                              dropout_rate=args.eprenet_dropout_rate,
                              padding_idx=padding_idx),
                torch.nn.Linear(args.eprenet_conv_chans, args.adim))
        else:
            encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                     embedding_dim=args.adim,
                                                     padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.
            transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after)

        # define projection layer
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim,
                                                  args.adim)
            else:
                self.projection = torch.nn.Linear(
                    args.adim + self.spk_embed_dim, args.adim)

        # define transformer decoder
        if args.dprenet_layers != 0:
            # decoder prenet
            decoder_input_layer = torch.nn.Sequential(
                DecoderPrenet(idim=odim,
                              n_layers=args.dprenet_layers,
                              n_units=args.dprenet_units,
                              dropout_rate=args.dprenet_dropout_rate),
                torch.nn.Linear(args.dprenet_units, args.adim))
        else:
            decoder_input_layer = "linear"
        self.decoder = Decoder(
            odim=-1,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.
            transformer_dec_positional_dropout_rate,
            self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            src_attention_dropout_rate=args.
            transformer_enc_dec_attn_dropout_rate,
            input_layer=decoder_input_layer,
            use_output_layer=False,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after)

        # define final projection
        self.feat_out = torch.nn.Linear(args.adim,
                                        odim * args.reduction_factor)
        self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor)

        # define postnet
        self.postnet = None if args.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate)

        # define loss function
        self.criterion = TransformerLoss(use_masking=args.use_masking,
                                         bce_pos_weight=args.bce_pos_weight)
        if self.use_guided_attn_loss:
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )

        # initialize parameters
        self._reset_parameters(init_type=args.transformer_init,
                               init_enc_alpha=args.initial_encoder_alpha,
                               init_dec_alpha=args.initial_decoder_alpha)

    def _reset_parameters(self,
                          init_type,
                          init_enc_alpha=1.0,
                          init_dec_alpha=1.0):
        # initialize parameters
        initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)

    def _add_first_frame_and_remove_last_frame(self, ys):
        ys_in = torch.cat(
            [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1)
        return ys_in
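
    # Illustrative example (toy values): for a single utterance with frames
    #   ys    = [[f1, f2, f3]]      # (B=1, Lmax=3, odim)
    # the decoder input becomes
    #   ys_in = [[0,  f1, f2]]      # all-zero "go" frame prepended, last frame dropped
    # so that frame t is predicted from frames < t (teacher forcing).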

    def forward(self,
                xs,
                ilens,
                ys,
                labels,
                olens,
                spembs=None,
                asrtts=False,
                *args,
                **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of padded stop token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
            asrtts (bool, optional): If True, also return the network outputs in addition to the loss.

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_ilen = max(ilens)
        max_olen = max(olens)
        if max_ilen != xs.shape[1]:
            xs = xs[:, :max_ilen]
        if max_olen != ys.shape[1]:
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]

        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # thin out frames for reduction factor (B, Lmax, odim) ->  (B, Lmax//r, odim)
        if self.reduction_factor > 1:
            ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
            olens_in = olens.new(
                [olen // self.reduction_factor for olen in olens])
        else:
            ys_in, olens_in = ys, olens

        # add first zero frame and remove last frame for auto-regressive
        ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

        # forward decoder
        y_masks = self._target_mask(olens_in)
        xy_masks = self._source_to_target_mask(ilens, olens_in)
        zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)
        # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
        # (B, Lmax//r, r) -> (B, Lmax//r * r)
        logits = self.prob_out(zs).view(zs.size(0), -1)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)
        # Debug-only visualization of the predicted vs. ground-truth mel-filterbanks
        # for the first utterance in the batch; the figure is not saved by default
        # (uncomment fig.savefig below to write it to disk).
        import matplotlib.pyplot as plt
        import numpy as np
        name = np.random.choice(100, 1)
        fig = plt.figure()
        ax1 = plt.subplot(2, 1, 1)
        ax2 = plt.subplot(2, 1, 2)
        mel_fbank = np.transpose(after_outs[0].detach().cpu().numpy())
        gt_mel_fbank = np.transpose(ys[0].detach().cpu().numpy())
        ax1.imshow(mel_fbank, aspect='auto', cmap=plt.cm.jet)
        ax2.imshow(gt_mel_fbank, aspect='auto', cmap=plt.cm.jet)
        # fig.savefig('images_asr2tts_transformer/asr2tts.%d.%d.png' % (0, name), orientation='landscape')
        plt.close(fig)  # avoid accumulating open figures across forward passes

        # modify mod part of the ground truth (trim lengths to a multiple of the reduction factor)
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate loss values
        if asrtts:
            l1_loss, l2_loss, bce_loss = self.criterion(after_outs,
                                                        before_outs,
                                                        logits,
                                                        ys,
                                                        labels,
                                                        olens,
                                                        reduction=False)
            if self.loss_type == "L1":
                loss = l1_loss + bce_loss
            elif self.loss_type == "L2":
                loss = l2_loss + bce_loss
            elif self.loss_type == "L1+L2":
                loss = l1_loss + l2_loss + bce_loss
            else:
                raise ValueError("unknown --loss-type " + self.loss_type)
            report_keys = [
                {
                    "l1_loss": l1_loss.mean().item()
                },
                {
                    "l2_loss": l2_loss.mean().item()
                },
                {
                    "bce_loss": bce_loss.mean().item()
                },
                {
                    "loss": loss.mean().item()
                },
            ]
        else:
            l1_loss, l2_loss, bce_loss = self.criterion(
                after_outs, before_outs, logits, ys, labels, olens)
            if self.loss_type == "L1":
                loss = l1_loss + bce_loss
            elif self.loss_type == "L2":
                loss = l2_loss + bce_loss
            elif self.loss_type == "L1+L2":
                loss = l1_loss + l2_loss + bce_loss
            else:
                raise ValueError("unknown --loss-type " + self.loss_type)
            report_keys = [
                {
                    "l1_loss": l1_loss.item()
                },
                {
                    "l2_loss": l2_loss.item()
                },
                {
                    "bce_loss": bce_loss.item()
                },
                {
                    "loss": loss.item()
                },
            ]

        # calculate guided attention loss
        if self.use_guided_attn_loss:
            # calculate for encoder
            if "encoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.encoder.encoders)))):
                    att_ws += [
                        self.encoder.encoders[layer_idx].self_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_in, T_in)
                enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
                loss = loss + enc_attn_loss
                report_keys += [{"enc_attn_loss": enc_attn_loss.item()}]
            # calculate for decoder
            if "decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.decoder.decoders)))):
                    att_ws += [
                        self.decoder.decoders[layer_idx].self_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_out, T_out)
                dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in)
                loss = loss + dec_attn_loss
                report_keys += [{"dec_attn_loss": dec_attn_loss.item()}]
            # calculate for encoder-decoder
            if "encoder-decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.decoder.decoders)))):
                    att_ws += [
                        self.decoder.decoders[layer_idx].src_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_out, T_in)
                enc_dec_attn_loss = self.attn_criterion(
                    att_ws, ilens, olens_in)
                loss = loss + enc_dec_attn_loss
                report_keys += [{
                    "enc_dec_attn_loss": enc_dec_attn_loss.item()
                }]
        else:
            att_ws = None
        # report extra information
        if self.use_scaled_pos_enc:
            report_keys += [
                {
                    "encoder_alpha": self.encoder.embed[-1].alpha.data.item()
                },
                {
                    "decoder_alpha": self.decoder.embed[-1].alpha.data.item()
                },
            ]
        self.reporter.report(report_keys)
        if asrtts:
            return loss, after_outs, before_outs, logits, att_ws
        else:
            return loss
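
    # Shape note (illustrative, following the comments in forward above): with
    # reduction factor r the decoder predicts r frames per step, so
    #   zs (B, Lmax//r, adim) --feat_out--> (B, Lmax//r, odim * r)
    #                         --view------> (B, Lmax//r * r, odim)
    # and olens/labels are trimmed to multiples of r before the loss.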

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace):
                - threshold (float): Threshold in inference.
                - minlenratio (float): Minimum length ratio in inference.
                - maxlenratio (float): Maximum length ratio in inference.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T).

        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio

        # forward encoder
        xs = x.unsqueeze(0)
        hs, _ = self.encoder(xs, None)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            spembs = spemb.unsqueeze(0)
            hs = self._integrate_with_spk_embed(hs, spembs)

        # set limits of length
        maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor)
        minlen = int(hs.size(1) * minlenratio / self.reduction_factor)

        # initialize
        idx = 0
        ys = hs.new_zeros(1, 1, self.odim)
        outs, probs = [], []

        # forward decoder step-by-step
        while True:
            # update index
            idx += 1

            # calculate output and stop prob at idx-th step
            y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device)
            z = self.decoder.recognize(ys, y_masks, hs)  # (B, adim)
            outs += [self.feat_out(z).view(self.reduction_factor,
                                           self.odim)]  # [(r, odim), ...]
            probs += [torch.sigmoid(self.prob_out(z))[0]]  # [(r), ...]

            # update next inputs
            ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)),
                           dim=1)  # (1, idx + 1, odim)

            # check whether to finish generation
            if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
                # check minimum length
                if idx < minlen:
                    continue
                outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(
                    1, 2)  # (L, odim) -> (1, L, odim) -> (1, odim, L)
                if self.postnet is not None:
                    outs = outs + self.postnet(outs)  # (1, odim, L)
                outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
                probs = torch.cat(probs, dim=0)
                break

        # get attention weights
        att_ws = []
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention) and "src" in name:
                att_ws += [m.attn]
        att_ws = torch.cat(att_ws, dim=0)

        return outs, probs, att_ws
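
    # Hedged usage sketch (attribute names taken from the docstring above; the
    # Namespace construction itself is an assumption):
    #
    #   from argparse import Namespace
    #   inference_args = Namespace(threshold=0.5, minlenratio=0.0, maxlenratio=10.0)
    #   outs, probs, att_ws = model.inference(x, inference_args)
    #   # outs: (L, odim) mel-filterbank features, probs: (L,) stop probabilities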

    def calculate_all_attentions(self,
                                 xs,
                                 ilens,
                                 ys,
                                 olens,
                                 spembs=None,
                                 skip_output=False,
                                 keep_tensor=False,
                                 *args,
                                 **kwargs):
        """Calculate all of the attention weights.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
            skip_output (bool, optional): Whether to skip calculating the final outputs.
            keep_tensor (bool, optional): Whether to keep attention weights as tensors instead of converting them to numpy arrays.

        Returns:
            dict: Dict of attention weights and outputs.

        """
        with torch.no_grad():
            # forward encoder
            x_masks = self._source_mask(ilens)
            hs, _ = self.encoder(xs, x_masks)

            # integrate speaker embedding
            if self.spk_embed_dim is not None:
                hs = self._integrate_with_spk_embed(hs, spembs)

            # thin out frames for reduction factor (B, Lmax, odim) ->  (B, Lmax//r, odim)
            if self.reduction_factor > 1:
                ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                ys_in, olens_in = ys, olens

            # add first zero frame and remove last frame for auto-regressive
            ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

            # forward decoder
            y_masks = self._target_mask(olens_in)
            xy_masks = self._source_to_target_mask(ilens, olens_in)
            zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)

            # calculate final outputs
            if not skip_output:
                before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
                if self.postnet is None:
                    after_outs = before_outs
                else:
                    after_outs = before_outs + self.postnet(
                        before_outs.transpose(1, 2)).transpose(1, 2)

        # modify mod part of output lengths due to reduction factor > 1
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])

        # store into dict
        att_ws_dict = dict()
        if keep_tensor:
            for name, m in self.named_modules():
                if isinstance(m, MultiHeadedAttention):
                    att_ws_dict[name] = m.attn
            if not skip_output:
                att_ws_dict["before_postnet_fbank"] = before_outs
                att_ws_dict["after_postnet_fbank"] = after_outs
        else:
            for name, m in self.named_modules():
                if isinstance(m, MultiHeadedAttention):
                    attn = m.attn.cpu().numpy()
                    if "encoder" in name:
                        attn = [
                            a[:, :l, :l] for a, l in zip(attn, ilens.tolist())
                        ]
                    elif "decoder" in name:
                        if "src" in name:
                            attn = [
                                a[:, :ol, :il] for a, il, ol in zip(
                                    attn, ilens.tolist(), olens_in.tolist())
                            ]
                        elif "self" in name:
                            attn = [
                                a[:, :l, :l]
                                for a, l in zip(attn, olens_in.tolist())
                            ]
                        else:
                            logging.warning("unknown attention module: " +
                                            name)
                    else:
                        logging.warning("unknown attention module: " + name)
                    att_ws_dict[name] = attn
            if not skip_output:
                before_outs = before_outs.cpu().numpy()
                after_outs = after_outs.cpu().numpy()
                att_ws_dict["before_postnet_fbank"] = [
                    m[:l].T for m, l in zip(before_outs, olens.tolist())
                ]
                att_ws_dict["after_postnet_fbank"] = [
                    m[:l].T for m, l in zip(after_outs, olens.tolist())
                ]

        return att_ws_dict

    def _integrate_with_spk_embed(self, hs, spembs):
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs
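
    # Shape sketch for the two integration modes above (B, Tmax, adim, spk_embed_dim assumed):
    #   "add":    spembs (B, spk_embed_dim) -> projection -> (B, adim) -> unsqueeze(1)
    #             -> (B, 1, adim), broadcast-added to hs (B, Tmax, adim)
    #   "concat": spembs normalized and expanded to (B, Tmax, spk_embed_dim), concatenated
    #             with hs along the last dim, then projected back to (B, Tmax, adim)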

    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)

    def _target_mask(self, olens):
        """Make masks for masked self-attention.

        Examples:
            >>> olens = [5, 3]
            >>> self._target_mask(olens)
            tensor([[[1, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 0],
                     [1, 1, 1, 1, 1]],
                    [[1, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)

        """
        y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
        s_masks = subsequent_mask(y_masks.size(-1),
                                  device=y_masks.device).unsqueeze(0)
        return y_masks.unsqueeze(-2) & s_masks & y_masks.unsqueeze(-1)

    def _source_to_target_mask(self, ilens, olens):
        """Make masks for encoder-decoder attention.

        Examples:
            >>> ilens = [4, 2]
            >>> olens = [5, 3]
            >>> self._source_to_target_mask(ilens, olens)
            tensor([[[1, 1, 1, 1],
                     [1, 1, 1, 1],
                     [1, 1, 1, 1],
                     [1, 1, 1, 1],
                     [1, 1, 1, 1]],
                    [[1, 1, 0, 0],
                     [1, 1, 0, 0],
                     [1, 1, 0, 0],
                     [0, 0, 0, 0],
                     [0, 0, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2) & y_masks.unsqueeze(-1)

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training. keys should match what `chainer.reporter` reports.

        If you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values.
        also `loss.png` will be created as a figure visulizing `main/loss` and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "l2_loss", "bce_loss"]
        if self.use_scaled_pos_enc:
            plot_keys += ["encoder_alpha", "decoder_alpha"]
        if self.use_guided_attn_loss:
            if "encoder" in self.modules_applied_guided_attn:
                plot_keys += ["enc_attn_loss"]
            if "decoder" in self.modules_applied_guided_attn:
                plot_keys += ["dec_attn_loss"]
            if "encoder-decoder" in self.modules_applied_guided_attn:
                plot_keys += ["enc_dec_attn_loss"]

        return plot_keys
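
The padding/causal mask doctests above can be reproduced with the minimal, self-contained sketch below. It deliberately avoids the `make_non_pad_mask` and `subsequent_mask` helpers imported by this module and uses simplified stand-ins built from `torch.arange` and `torch.tril`; the lengths are the same toy values as in the docstrings.

import torch


def _non_pad_mask(lens):
    """Stand-in for make_non_pad_mask: True for real frames, False for padding."""
    max_len = max(lens)
    seq = torch.arange(max_len).unsqueeze(0)          # (1, Tmax)
    return seq < torch.tensor(lens).unsqueeze(1)      # (B, Tmax), bool


olens = [5, 3]
y_masks = _non_pad_mask(olens)                        # (B, Tmax)
causal = torch.tril(torch.ones(max(olens), max(olens), dtype=torch.bool))  # subsequent_mask stand-in

source_mask = y_masks.unsqueeze(-2) & y_masks.unsqueeze(-1)                         # _source_mask analogue
target_mask = y_masks.unsqueeze(-2) & causal.unsqueeze(0) & y_masks.unsqueeze(-1)   # _target_mask analogue
print(source_mask.int())  # matches the _source_mask example above
print(target_mask.int())  # matches the _target_mask example above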


class Transformer(TTSInterface, torch.nn.Module):
    """Text-to-Speech Transformer.

    Reference:
        Neural Speech Synthesis with Transformer Network
        (https://arxiv.org/pdf/1809.08895.pdf)

    :param int idim: dimension of the inputs
    :param int odim: dimension of the outputs
    :param Namespace args: arguments containing the following attributes
        (int) embed_dim: dimension of character embedding
        (int) eprenet_conv_layers: number of encoder prenet convolution layers
        (int) eprenet_conv_chans: number of encoder prenet convolution channels
        (int) eprenet_conv_filts: filter size of encoder prenet convolution
        (int) dprenet_layers: number of decoder prenet layers
        (int) dprenet_units: number of decoder prenet hidden units
        (int) elayers: number of encoder layers
        (int) eunits: number of encoder hidden units
        (int) adim: number of attention transformation dimensions
        (int) aheads: number of heads for multi head attention
        (int) dlayers: number of decoder layers
        (int) dunits: number of decoder hidden units
        (int) postnet_layers: number of postnet layers
        (int) postnet_chans: number of postnet channels
        (int) postnet_filts: filter size of postnet
        (bool) use_scaled_pos_enc: whether to use trainable scaled positional encoding instead of the fixed scale one
        (bool) use_batch_norm: whether to use batch normalization in encoder prenet
        (bool) encoder_normalize_before: whether to perform layer normalization before encoder block
        (bool) decoder_normalize_before: whether to perform layer normalization before decoder block
        (bool) encoder_concat_after: whether to concatenate attention layer's input and output in encoder
        (bool) decoder_concat_after: whether to concatenate attention layer's input and output in decoder
        (int) reduction_factor: reduction factor
        (float) transformer_init: how to initialize transformer parameters
        (float) transformer_lr: initial value of learning rate
        (int) transformer_warmup_steps: optimizer warmup steps
        (float) transformer_enc_dropout_rate: dropout rate in encoder except for attention and positional encoding
        (float) transformer_enc_positional_dropout_rate: dropout rate after encoder positional encoding
        (float) transformer_enc_attn_dropout_rate: dropout rate in encoder self-attention module
        (float) transformer_dec_dropout_rate: dropout rate in decoder except for attention and positional encoding
        (float) transformer_dec_positional_dropout_rate: dropout rate after decoder positional encoding
        (float) transformer_dec_attn_dropout_rate: dropout rate in decoder self-attention module
        (float) transformer_enc_dec_attn_dropout_rate: dropout rate in encoder-decoder attention module
        (float) eprenet_dropout_rate: dropout rate in encoder prenet
        (float) dprenet_dropout_rate: dropout rate in decoder prenet
        (float) postnet_dropout_rate: dropout rate in postnet
        (bool) use_masking: whether to use masking in calculation of loss
        (float) bce_pos_weight: positive sample weight in bce calculation (only for use_masking=true)
        (str) loss_type: how to calculate loss
        (bool) use_guided_attn_loss: whether to use guided attention loss
        (int) num_heads_applied_guided_attn: number of heads in each layer to be applied guided attention loss
        (int) num_layers_applied_guided_attn: number of layers to be applied guided attention loss
        (list) modules_applied_guided_attn: list of module names to be applied guided attention loss
    """
    @staticmethod
    def add_arguments(parser):
        group = parser.add_argument_group("transformer model setting")
        # network structure related
        group.add_argument(
            "--embed-dim",
            default=512,
            type=int,
            help="Dimension of character embedding in encoder prenet")
        group.add_argument("--eprenet-conv-layers",
                           default=3,
                           type=int,
                           help="Number of encoder prenet convolution layers")
        group.add_argument(
            "--eprenet-conv-chans",
            default=256,
            type=int,
            help="Number of encoder prenet convolution channels")
        group.add_argument("--eprenet-conv-filts",
                           default=5,
                           type=int,
                           help="Filter size of encoder prenet convolution")
        group.add_argument("--dprenet-layers",
                           default=2,
                           type=int,
                           help="Number of decoder prenet layers")
        group.add_argument("--dprenet-units",
                           default=256,
                           type=int,
                           help="Number of decoder prenet hidden units")
        group.add_argument("--elayers",
                           default=3,
                           type=int,
                           help="Number of encoder layers")
        group.add_argument("--eunits",
                           default=1536,
                           type=int,
                           help="Number of encoder hidden units")
        group.add_argument(
            "--adim",
            default=384,
            type=int,
            help="Number of attention transformation dimensions")
        group.add_argument("--aheads",
                           default=4,
                           type=int,
                           help="Number of heads for multi head attention")
        group.add_argument("--dlayers",
                           default=3,
                           type=int,
                           help="Number of decoder layers")
        group.add_argument("--dunits",
                           default=1536,
                           type=int,
                           help="Number of decoder hidden units")
        group.add_argument("--postnet-layers",
                           default=5,
                           type=int,
                           help="Number of postnet layers")
        group.add_argument("--postnet-chans",
                           default=256,
                           type=int,
                           help="Number of postnet channels")
        group.add_argument("--postnet-filts",
                           default=5,
                           type=int,
                           help="Filter size of postnet")
        group.add_argument(
            "--use-scaled-pos-enc",
            default=True,
            type=strtobool,
            help=
            "Use trainable scaled positional encoding instead of the fixed scale one."
        )
        group.add_argument("--use-batch-norm",
                           default=True,
                           type=strtobool,
                           help="Whether to use batch normalization")
        group.add_argument(
            "--encoder-normalize-before",
            default=False,
            type=strtobool,
            help="Whether to apply layer norm before encoder block")
        group.add_argument(
            "--decoder-normalize-before",
            default=False,
            type=strtobool,
            help="Whether to apply layer norm before decoder block")
        group.add_argument(
            "--encoder-concat-after",
            default=False,
            type=strtobool,
            help=
            "Whether to concatenate attention layer's input and output in encoder"
        )
        group.add_argument(
            "--decoder-concat-after",
            default=False,
            type=strtobool,
            help=
            "Whether to concatenate attention layer's input and output in decoder"
        )
        parser.add_argument("--reduction-factor",
                            default=1,
                            type=int,
                            help="Reduction factor")
        # training related
        group.add_argument("--transformer-init",
                           type=str,
                           default="pytorch",
                           choices=[
                               "pytorch", "xavier_uniform", "xavier_normal",
                               "kaiming_uniform", "kaiming_normal"
                           ],
                           help="How to initialize transformer parameters")
        group.add_argument(
            "--initial-encoder-alpha",
            type=float,
            default=1.0,
            help="Initial alpha value in encoder's ScaledPositionalEncoding")
        group.add_argument(
            "--initial-decoder-alpha",
            type=float,
            default=1.0,
            help="Initial alpha value in decoder's ScaledPositionalEncoding")
        group.add_argument("--transformer-lr",
                           default=1.0,
                           type=float,
                           help="Initial value of learning rate")
        group.add_argument("--transformer-warmup-steps",
                           default=4000,
                           type=int,
                           help="Optimizer warmup steps")
        group.add_argument(
            "--transformer-enc-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder except for attention")
        group.add_argument(
            "--transformer-enc-positional-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder positional encoding")
        group.add_argument(
            "--transformer-enc-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder self-attention")
        group.add_argument(
            "--transformer-dec-dropout-rate",
            default=0.1,
            type=float,
            help=
            "Dropout rate for transformer decoder except for attention and pos encoding"
        )
        group.add_argument(
            "--transformer-dec-positional-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer decoder positional encoding")
        group.add_argument(
            "--transformer-dec-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer decoder self-attention")
        group.add_argument(
            "--transformer-enc-dec-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder-decoder attention")
        group.add_argument("--eprenet-dropout-rate",
                           default=0.5,
                           type=float,
                           help="Dropout rate in encoder prenet")
        group.add_argument("--dprenet-dropout-rate",
                           default=0.5,
                           type=float,
                           help="Dropout rate in decoder prenet")
        group.add_argument("--postnet-dropout-rate",
                           default=0.5,
                           type=float,
                           help="Dropout rate in postnet")
        # loss related
        group.add_argument(
            "--use-masking",
            default=True,
            type=strtobool,
            help="Whether to use masking in calculation of loss")
        group.add_argument("--loss-type",
                           default="L1",
                           choices=["L1", "L2", "L1+L2"],
                           help="How to calc loss")
        group.add_argument(
            "--bce-pos-weight",
            default=5.0,
            type=float,
            help=
            "Positive sample weight in BCE calculation (only for use-masking=True)"
        )
        group.add_argument("--use-guided-attn-loss",
                           default=False,
                           type=strtobool,
                           help="Whether to use guided attention loss")
        group.add_argument("--guided-attn-loss-sigma",
                           default=0.4,
                           type=float,
                           help="Sigma in guided attention loss")
        group.add_argument(
            "--num-heads-applied-guided-attn",
            default=2,
            type=int,
            help=
            "Number of heads in each layer to be applied guided attention loss"
            "if set -1, all of the heads will be applied.")
        group.add_argument(
            "--num-layers-applied-guided-attn",
            default=2,
            type=int,
            help="Number of layers to be applied guided attention loss"
            "if set -1, all of the layers will be applied.")
        group.add_argument(
            "--modules-applied-guided-attn",
            type=str,
            nargs="+",
            default=["encoder-decoder"],
            help="Module name list to be applied guided attention loss")
        return parser

    @property
    def attention_plot_class(self):
        return TTSPlot

    def __init__(self, idim, odim, args):
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.reduction_factor = args.reduction_factor
        self.loss_type = args.loss_type
        self.use_guided_attn_loss = args.use_guided_attn_loss
        if self.use_guided_attn_loss:
            if args.num_layers_applied_guided_attn == -1:
                self.num_layers_applied_guided_attn = args.elayers
            else:
                self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn
            if args.num_heads_applied_guided_attn == -1:
                self.num_heads_applied_guided_attn = args.aheads
            else:
                self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn
            self.modules_applied_guided_attn = args.modules_applied_guided_attn

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define transformer encoder
        if args.eprenet_conv_layers != 0:
            # encoder prenet
            encoder_input_layer = torch.nn.Sequential(
                EncoderPrenet(idim=idim,
                              embed_dim=args.embed_dim,
                              elayers=0,
                              econv_layers=args.eprenet_conv_layers,
                              econv_chans=args.eprenet_conv_chans,
                              econv_filts=args.eprenet_conv_filts,
                              use_batch_norm=args.use_batch_norm,
                              dropout_rate=args.eprenet_dropout_rate,
                              padding_idx=padding_idx),
                torch.nn.Linear(args.eprenet_conv_chans, args.adim))
        else:
            encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                     embedding_dim=args.adim,
                                                     padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.
            transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after)

        # define transformer decoder
        if args.dprenet_layers != 0:
            # decoder prenet
            decoder_input_layer = torch.nn.Sequential(
                DecoderPrenet(idim=odim,
                              n_layers=args.dprenet_layers,
                              n_units=args.dprenet_units,
                              dropout_rate=args.dprenet_dropout_rate),
                torch.nn.Linear(args.dprenet_units, args.adim))
        else:
            decoder_input_layer = "linear"
        self.decoder = Decoder(
            odim=-1,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.
            transformer_dec_positional_dropout_rate,
            self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            src_attention_dropout_rate=args.
            transformer_enc_dec_attn_dropout_rate,
            input_layer=decoder_input_layer,
            use_output_layer=False,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after)

        # define final projection
        self.feat_out = torch.nn.Linear(args.adim,
                                        odim * args.reduction_factor)
        self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor)

        # define postnet
        self.postnet = None if args.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate)

        # define loss function
        self.criterion = TransformerLoss(args)
        if self.use_guided_attn_loss:
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                args.guided_attn_loss_sigma)

        # initialize parameters
        self._reset_parameters(args)

    def _reset_parameters(self, args):
        if self.use_scaled_pos_enc:
            # alpha in scaled positional encoding init
            self.encoder.embed[-1].alpha.data = torch.tensor(
                args.initial_encoder_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(
                args.initial_decoder_alpha)

        if args.transformer_init == "pytorch":
            return
        # weight init
        for p in self.parameters():
            if p.dim() > 1:
                if args.transformer_init == "xavier_uniform":
                    torch.nn.init.xavier_uniform_(p.data)
                elif args.transformer_init == "xavier_normal":
                    torch.nn.init.xavier_normal_(p.data)
                elif args.transformer_init == "kaiming_uniform":
                    torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
                elif args.transformer_init == "kaiming_normal":
                    torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
                else:
                    raise ValueError("Unknown initialization: " +
                                     args.transformer_init)
        # bias init
        for p in self.parameters():
            if p.dim() == 1:
                p.data.zero_()
        # reset some modules with default init
        for m in self.modules():
            if isinstance(m, (torch.nn.Embedding, LayerNorm)):
                m.reset_parameters()

    def _add_first_frame_and_remove_last_frame(self, ys):
        ys_in = torch.cat(
            [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1)
        return ys_in
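
    # Illustration of the shift above with hypothetical values (B=1, Lmax=3):
    #   ys    = [[f1, f2, f3]]   ->   ys_in = [[0, f1, f2]]
    # i.e. a zero frame is prepended and the last frame dropped, so during teacher
    # forcing the decoder predicts frame t only from frames < t.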

    def forward(self, xs, ilens, ys, labels, olens, *args, **kwargs):
        """Calculate forward propagation

        :param torch.Tensor xs: batch of padded character ids (B, Tmax)
        :param torch.Tensor ilens: list of lengths of each input batch (B)
        :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
        :param torch.Tensor olens: batch of the lengths of each target (B)
        :return: loss value
        :rtype: torch.Tensor
        """
        # remove unnecessary padded part (for multi-gpus)
        max_ilen = max(ilens)
        max_olen = max(olens)
        if max_ilen != xs.shape[1]:
            xs = xs[:, :max_ilen]
        if max_olen != ys.shape[1]:
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]

        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)

        # thin out frames for reduction factor (B, Lmax, odim) ->  (B, Lmax//r, odim)
        if self.reduction_factor > 1:
            ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
            olens_in = olens.new(
                [olen // self.reduction_factor for olen in olens])
        else:
            ys_in, olens_in = ys, olens

        # add first zero frame and remove last frame for auto-regressive
        ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

        # forward decoder
        y_masks = self._target_mask(olens_in)
        xy_masks = self._source_to_target_mask(ilens, olens_in)
        zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)
        # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
        # (B, Lmax//r, r) -> (B, Lmax//r * r)
        logits = self.prob_out(zs).view(zs.size(0), -1)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)

        # modify mod part of ground truth
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate loss values
        l1_loss, l2_loss, bce_loss = self.criterion(after_outs, before_outs,
                                                    logits, ys, labels, olens)
        if self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = l2_loss + bce_loss
        elif self.loss_type == "L1+L2":
            loss = l1_loss + l2_loss + bce_loss
        else:
            raise ValueError("unknown --loss-type " + self.loss_type)
        report_keys = [
            {
                "l1_loss": l1_loss.item()
            },
            {
                "l2_loss": l2_loss.item()
            },
            {
                "bce_loss": bce_loss.item()
            },
            {
                "loss": loss.item()
            },
        ]

        # calculate guided attention loss
        if self.use_guided_attn_loss:
            # calculate for encoder
            if "encoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.encoder.encoders)))):
                    att_ws += [
                        self.encoder.encoders[layer_idx].self_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_in, T_in)
                enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
                loss = loss + enc_attn_loss
                report_keys += [{"enc_attn_loss": enc_attn_loss.item()}]
            # calculate for decoder
            if "decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.decoder.decoders)))):
                    att_ws += [
                        self.decoder.decoders[layer_idx].self_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_out, T_out)
                dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in)
                loss = loss + dec_attn_loss
                report_keys += [{"dec_attn_loss": dec_attn_loss.item()}]
            # calculate for encoder-decoder
            if "encoder-decoder" in self.modules_applied_guided_attn:
                att_ws = []
                for idx, layer_idx in enumerate(
                        reversed(range(len(self.decoder.decoders)))):
                    att_ws += [
                        self.decoder.decoders[layer_idx].src_attn.
                        attn[:, :self.num_heads_applied_guided_attn]
                    ]
                    if idx + 1 == self.num_layers_applied_guided_attn:
                        break
                att_ws = torch.cat(att_ws, dim=1)  # (B, H*L, T_out, T_in)
                enc_dec_attn_loss = self.attn_criterion(
                    att_ws, ilens, olens_in)
                loss = loss + enc_dec_attn_loss
                report_keys += [{
                    "enc_dec_attn_loss": enc_dec_attn_loss.item()
                }]

        # report extra information
        if self.use_scaled_pos_enc:
            report_keys += [
                {
                    "encoder_alpha": self.encoder.embed[-1].alpha.data.item()
                },
                {
                    "decoder_alpha": self.decoder.embed[-1].alpha.data.item()
                },
            ]
        self.reporter.report(report_keys)

        return loss

    def inference(self, x, inference_args, *args, **kwargs):
        """Generates the sequence of features from given a sequences of characters

        :param torch.Tensor x: the sequence of character ids (T)
        :param Namespace inference_args: argments containing following attributes
            (float) threshold: threshold in inference
            (float) minlenratio: minimum length ratio in inference
            (float) maxlenratio: maximum length ratio in inference
        :rtype: torch.Tensor
        :return: the sequence of stop probabilities (L)
        :rtype: torch.Tensor
        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio

        # forward encoder
        xs = x.unsqueeze(0)
        hs, _ = self.encoder(xs, None)

        # set limits of length
        maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor)
        minlen = int(hs.size(1) * minlenratio / self.reduction_factor)

        # initialize
        idx = 0
        ys = hs.new_zeros(1, 1, self.odim)
        outs, probs = [], []

        # forward decoder step-by-step
        while True:
            # update index
            idx += 1

            # calculate output and stop prob at idx-th step
            y_masks = subsequent_mask(idx).unsqueeze(0)
            z = self.decoder.recognize(ys, y_masks, hs)  # (B, adim)
            outs += [self.feat_out(z).view(self.reduction_factor,
                                           self.odim)]  # [(r, odim), ...]
            probs += [torch.sigmoid(self.prob_out(z))[0]]  # [(r), ...]

            # update next inputs
            ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)),
                           dim=1)  # (1, idx + 1, odim)

            # check whether to finish generation
            if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
                # check minimum length
                if idx < minlen:
                    continue
                outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(
                    1, 2)  # (L, odim) -> (1, L, odim) -> (1, odim, L)
                if self.postnet is not None:
                    outs = outs + self.postnet(outs)  # (1, odim, L)
                outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
                probs = torch.cat(probs, dim=0)
                break

        return outs, probs
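
    # Sketch of one decoding step in the loop above (r = reduction_factor, values assumed):
    #   z = self.decoder.recognize(ys, y_masks, hs)   -> (1, adim) state for the newest position
    #   self.feat_out(z).view(r, self.odim)           -> r output frames for this step
    #   torch.sigmoid(self.prob_out(z))               -> r stop probabilities
    # Generation stops once any stop probability reaches `threshold` (after at least
    # `minlen` steps) or when `maxlen` decoder steps have been produced.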

    def calculate_all_attentions(self, xs, ilens, ys, olens, *args, **kwargs):
        """Calculate attention weights of all of the layers

        :param torch.Tensor xs: batch of padded character ids (B, Tmax)
        :param torch.Tensor ilens: list of lengths of each input batch (B)
        :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
        :param torch.Tensor olens: list of lengths of each output batch (B)
        :return: attention weights dict
        :rtype: dict
        """
        with torch.no_grad():
            # forward encoder
            x_masks = self._source_mask(ilens)
            hs, _ = self.encoder(xs, x_masks)

            # thin out frames for reduction factor (B, Lmax, odim) ->  (B, Lmax//r, odim)
            if self.reduction_factor > 1:
                ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                ys_in, olens_in = ys, olens

            # add first zero frame and remove last frame for auto-regressive
            ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

            # forward decoder
            y_masks = self._target_mask(olens_in)
            xy_masks = self._source_to_target_mask(ilens, olens_in)
            zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)
            # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
            before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
            # postnet -> (B, Lmax//r * r, odim)
            if self.postnet is None:
                after_outs = before_outs
            else:
                after_outs = before_outs + self.postnet(
                    before_outs.transpose(1, 2)).transpose(1, 2)

        # modify mod part of ground truth
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])

        att_ws_dict = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                attn = m.attn.cpu().numpy()
                if "encoder" in name:
                    attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
                elif "decoder" in name:
                    if "src" in name:
                        attn = [
                            a[:, :ol, :il] for a, il, ol in zip(
                                attn, ilens.tolist(), olens_in.tolist())
                        ]
                    elif "self" in name:
                        attn = [
                            a[:, :l, :l]
                            for a, l in zip(attn, olens_in.tolist())
                        ]
                    else:
                        logging.warning("unknown attention module: " + name)
                else:
                    logging.warning("unknown attention module: " + name)
                att_ws_dict[name] = attn

        att_ws_dict["before_postnet_fbank"] = [
            m[:l].T for m, l in zip(before_outs.cpu().numpy(), olens.tolist())
        ]
        att_ws_dict["after_postnet_fbank"] = [
            m[:l].T for m, l in zip(after_outs.cpu().numpy(), olens.tolist())
        ]
        return att_ws_dict

    def _source_mask(self, ilens):
        """Make mask for MultiHeadedAttention using padded sequences

        >>> ilens = [5, 3]
        >>> self._source_mask(ilens)
        tensor([[[1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1]],

                [[1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0]]], dtype=torch.uint8)
        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)

    def _target_mask(self, olens):
        """Make mask for MaskedMultiHeadedAttention using padded sequences

        >>> olens = [5, 3]
        >>> self._target_mask(olens)
        tensor([[[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1]],

                [[1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0]]], dtype=torch.uint8)
        """
        y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
        s_masks = subsequent_mask(y_masks.size(-1),
                                  device=y_masks.device).unsqueeze(0)
        return y_masks.unsqueeze(-2) & s_masks & y_masks.unsqueeze(-1)

    def _source_to_target_mask(self, ilens, olens):
        """Make source to target mask for MultiHeadedAttention using padded sequences

        >>> ilens = [4, 2]
        >>> olens = [5, 3]
        >>> self._source_to_target_mask(ilens, olens)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1],
                 [1, 1, 1, 1],
                 [1, 1, 1, 1],
                 [1, 1, 1, 1]],

                [[1, 1, 0, 0],
                 [1, 1, 0, 0],
                 [1, 1, 0, 0],
                 [0, 0, 0, 0],
                 [0, 0, 0, 0]]], dtype=torch.uint8)
        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2) & y_masks.unsqueeze(-1)

    @property
    def base_plot_keys(self):
        """base key names to plot during training. keys should match what `chainer.reporter` reports

        if you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values.
        also `loss.png` will be created as a figure visulizing `main/loss` and `validation/main/loss` values.

        :rtype list[str] plot_keys: base keys to plot during training
        """
        plot_keys = ["loss", "l1_loss", "l2_loss", "bce_loss"]
        if self.use_scaled_pos_enc:
            plot_keys += ["encoder_alpha", "decoder_alpha"]
        if self.use_guided_attn_loss:
            if "encoder" in self.modules_applied_guided_attn:
                plot_keys += ["enc_attn_loss"]
            if "decoder" in self.modules_applied_guided_attn:
                plot_keys += ["dec_attn_loss"]
            if "encoder-decoder" in self.modules_applied_guided_attn:
                plot_keys += ["enc_dec_attn_loss"]

        return plot_keys
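
Before the next example, here is a small self-contained sketch of the reduction-factor bookkeeping used by `Transformer.forward` above. All sizes (B, Lmax, odim, r, adim) are hypothetical, and `feat_out` below is a fresh `torch.nn.Linear`, not the trained layer of the model.

import torch

B, Lmax, odim, r, adim = 2, 12, 80, 3, 256           # hypothetical sizes
ys = torch.randn(B, Lmax, odim)                       # padded target features
olens = torch.tensor([12, 10])                        # true target lengths

# decoder-rate targets: keep every r-th frame
ys_in = ys[:, r - 1::r]                               # (B, Lmax//r, odim)
olens_in = torch.tensor([int(o) // r for o in olens])

# feat_out maps each decoder state to r output frames, which are then unfolded
feat_out = torch.nn.Linear(adim, odim * r)
zs = torch.randn(B, Lmax // r, adim)                  # stand-in for decoder outputs
before_outs = feat_out(zs).view(zs.size(0), -1, odim)  # (B, Lmax//r * r, odim)

# ground truth lengths are trimmed to a multiple of r before computing the loss
olens_trim = torch.tensor([int(o) - int(o) % r for o in olens])
print(ys_in.shape, before_outs.shape, olens_in.tolist(), olens_trim.tolist())
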
Example no. 4
class E2E(ASRInterface, torch.nn.Module):
    """E2E module.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (Namespace): argument Namespace containing options

    """
    @staticmethod
    def add_arguments(parser):
        """Extend arguments for transducer models.

        Both Transformer and RNN modules are supported.
        General options encapsulate the options of both modules.

        """
        group = parser.add_argument_group("transformer model setting")

        # Encoder - general
        group.add_argument(
            '--elayers',
            default=4,
            type=int,
            help=
            'Number of encoder layers (for shared recognition part in multi-speaker asr mode)'
        )
        group.add_argument('--eunits',
                           '-u',
                           default=300,
                           type=int,
                           help='Number of encoder hidden units')
        group.add_argument('--kernel-size', default=32, type=int)
        group.add_argument('--dropout-rate',
                           default=0.0,
                           type=float,
                           help='Dropout rate for the encoder')
        # Encoder - RNN
        group.add_argument('--eprojs',
                           default=320,
                           type=int,
                           help='Number of encoder projection units')
        group.add_argument(
            '--subsample',
            default="1",
            type=str,
            help='Subsample input frames x_y_z means subsample every x frame '
                 'at 1st layer, every y frame at 2nd layer etc.')
        # Attention - general
        group.add_argument(
            '--transformer-attn-dropout-rate',
            default=None,
            type=float,
            help=
            'dropout in transformer attention. use --dropout-rate if None is set'
        )
        group.add_argument(
            '--adim',
            default=320,
            type=int,
            help='Number of attention transformation dimensions')
        group.add_argument('--aheads',
                           default=4,
                           type=int,
                           help='Number of heads for multi head attention')

        # Decoder - general
        group.add_argument('--dtype',
                           default='lstm',
                           type=str,
                           choices=['lstm', 'gru', 'transformer'],
                           help='Type of decoder to use.')
        group.add_argument('--dlayers',
                           default=1,
                           type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits',
                           default=320,
                           type=int,
                           help='Number of decoder hidden units')
        group.add_argument('--dropout-rate-decoder',
                           default=0.0,
                           type=float,
                           help='Dropout rate for the decoder')
        group.add_argument('--relative-v', default=False, type=strtobool)
        # Decoder - RNN
        group.add_argument('--dec-embed-dim',
                           default=320,
                           type=int,
                           help='Number of decoder embeddings dimensions')
        group.add_argument('--dropout-rate-embed-decoder',
                           default=0.0,
                           type=float,
                           help='Dropout rate for the decoder embeddings')
        # Transformer
        group.add_argument(
            "--input-layer",
            type=str,
            default="conv2d",
            choices=["conv2d", "vgg2l", "linear", "embed", "custom"],
            help='transformer encoder input layer type')
        group.add_argument('--transformer-lr',
                           default=10.0,
                           type=float,
                           help='Initial value of learning rate')
        group.add_argument('--transformer-warmup-steps',
                           default=25000,
                           type=int,
                           help='optimizer warmup steps')
        group.add_argument('--transformer-length-normalized-loss',
                           default=False,
                           type=strtobool,
                           help='normalize loss by length')
        group.add_argument("--transformer-init",
                           type=str,
                           default="pytorch",
                           choices=[
                               "pytorch", "xavier_uniform", "xavier_normal",
                               "kaiming_uniform", "kaiming_normal"
                           ],
                           help='how to initialize transformer parameters')
        group.add_argument('--causal', type=strtobool, default=False)

        return parser

    def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0):
        """Construct an E2E object for transducer model.

        Args:
            idim (int): dimension of inputs
            odim (int): dimension of outputs
            args (Namespace): argument Namespace containing options

        """
        torch.nn.Module.__init__(self)

        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(idim=idim,
                               d_model=args.adim,
                               n_heads=args.aheads,
                               d_ffn=args.eunits,
                               layers=args.elayers,
                               kernel_size=args.kernel_size,
                               input_layer=args.input_layer,
                               dropout_rate=args.dropout_rate,
                               causal=args.causal)

        args.eprojs = args.adim
        if args.mtlalpha < 1.0:
            self.decoder = Decoder(
                odim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate)

        self.sos = odim - 1
        self.eos = odim - 1
        self.ignore_id = ignore_id

        self.odim = odim
        self.adim = args.adim

        self.criterion = LabelSmoothingLoss(
            self.odim, self.ignore_id, args.lsm_weight,
            args.transformer_length_normalized_loss)
        # self.verbose = args.verbose
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim,
                           args.adim,
                           args.dropout_rate,
                           ctc_type=args.ctc_type,
                           reduce=True)
        else:
            self.ctc = None

    def reset_parameters(self, args):
        """Initialize parameters."""
        # initialize parameters
        initialize(self, args.transformer_init)

    def forward(self, xs_pad, ilens, ys_pad, enc_mask=None):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): combined CTC/attention loss value
            loss_att_data (float): attention loss value (None if mtlalpha == 1)
            loss_ctc_data (float): CTC loss value (None if mtlalpha == 0)
            acc (float): attention decoder accuracy

        """
        # 1. encoder
        xs_pad = xs_pad[:, :max(ilens)]
        masks = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, masks)

        self.hs_pad = hs_pad

        # CTC forward
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max([len(y) for y in ys])
        ys_pad = ys_pad[:, :y_len]
        self.hs_pad = hs_pad
        cer_ctc = None
        batch_size = xs_pad.size(0)
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)

        # trigger mask
        start_time = time.time()
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        return self.loss, loss_att_data, loss_ctc_data, self.acc

    def encode(self, x, streaming_mask=None):
        """Encode acoustic features.

        Args:
            x (ndarray): input acoustic feature (T, D)

        Returns:
            x (torch.Tensor): encoded features (T, attention_dim)
        """
        self.eval()

        x = torch.as_tensor(x).unsqueeze(0).to(next(self.parameters()).device)
        enc_output, _ = self.encoder(x, None, streaming_mask)

        return enc_output.squeeze(0)

    def greedy_recognize(self, x, recog_args, streaming_mask=None):
        """Greedy CTC decoding of the input features (no beam search)."""
        h = self.encode(x, streaming_mask).unsqueeze(0)
        lpz = self.ctc.log_softmax(h)
        ys_hat = lpz.argmax(dim=-1)[0].cpu()
        # collapse the frame-wise argmax path: drop consecutive repeats, then blanks (id 0)
        ids = []
        prev = None
        for y in ys_hat.tolist():
            if y != 0 and y != prev:
                ids.append(y)
            prev = y
        rets = [{'score': 0.0, 'yseq': ids}]
        return rets

    def recognize(self,
                  x,
                  recog_args,
                  char_list=None,
                  rnnlm=None,
                  streaming_mask=None):
        """
        Recognize input features.

        Args:
            x (ndarray): input acoustic feature (T, D)
            recog_args (namespace): argument Namespace containing options
            char_list (list): list of characters
            rnnlm (torch.nn.Module): language model module

        Returns:
            y (list): n-best decoding results

        """
        h = self.encode(x, streaming_mask)
        params = [h, recog_args]

        if recog_args.beam_size == 1:
            nbest_hyps = self.decoder.recognize(*params)
        else:
            params.append(rnnlm)
            start_time = time.time()
            nbest_hyps = self.decoder.recognize_beam(*params)
            #nbest_hyps, decoder_time = self.decoder.beam_search(h, recog_args, prefix=False)
            end_time = time.time()
            decode_time = end_time - start_time

        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Args:
            xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax)

        Returns:
            ret (ndarray): attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).

        """
        if self.etype == 'transformer' and self.dtype != 'transformer' and \
           self.rnnt_mode == 'rnnt-att':
            raise NotImplementedError(
                "Transformer encoder with rnn attention decoder "
                "is not supported yet.")
        elif self.etype != 'transformer' and self.dtype != 'transformer':
            if self.rnnt_mode == 'rnnt':
                return []
            else:
                with torch.no_grad():
                    hs_pad, hlens = xs_pad, ilens
                    hpad, hlens, _ = self.encoder(hs_pad, hlens)

                    ret = self.decoder.calculate_all_attentions(
                        hpad, hlens, ys_pad)
        else:
            with torch.no_grad():
                self.forward(xs_pad, ilens, ys_pad)

                ret = dict()
                for name, m in self.named_modules():
                    if isinstance(m, MultiHeadedAttention):
                        ret[name] = m.attn.cpu().numpy()

        return ret
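
Finally, the greedy CTC path used by `greedy_recognize` above can be expressed as a small standalone helper. The sketch below assumes random dummy log-probabilities and blank id 0; it is illustrative only and not part of the model class.

import torch


def ctc_greedy_collapse(log_probs, blank_id=0):
    """Collapse the per-frame argmax path: drop consecutive repeats, then blanks."""
    ys_hat = log_probs.argmax(dim=-1).tolist()
    ids, prev = [], None
    for y in ys_hat:
        if y != blank_id and y != prev:
            ids.append(y)
        prev = y
    return ids


lpz = torch.log_softmax(torch.randn(6, 5), dim=-1)   # (T=6, vocab=5) dummy CTC log-probs
print(ctc_greedy_collapse(lpz))                      # collapsed label sequence, e.g. [3, 1, 4]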