class E2E(ASRInterface, torch.nn.Module):
    @staticmethod
    def add_arguments(parser):
        group = parser.add_argument_group("transformer model setting")

        group.add_argument("--transformer-init",
                           type=str,
                           default="pytorch",
                           choices=[
                               "pytorch", "xavier_uniform", "xavier_normal",
                               "kaiming_uniform", "kaiming_normal"
                           ],
                           help='how to initialize transformer parameters')
        group.add_argument("--transformer-input-layer",
                           type=str,
                           default="conv2d",
                           choices=["conv2d", "linear", "embed"],
                           help='transformer input layer type')
        group.add_argument('--transformer-attn-dropout-rate',
                           default=None,
                           type=float,
                           help='dropout in transformer attention. '
                                'use --dropout-rate if None is set')
        group.add_argument('--transformer-lr',
                           default=10.0,
                           type=float,
                           help='Initial value of learning rate')
        group.add_argument('--transformer-warmup-steps',
                           default=25000,
                           type=int,
                           help='optimizer warmup steps')
        group.add_argument('--transformer-length-normalized-loss',
                           default=True,
                           type=strtobool,
                           help='normalize loss by length')

        group.add_argument('--dropout-rate',
                           default=0.0,
                           type=float,
                           help='Dropout rate for the encoder')
        # Encoder
        group.add_argument('--elayers',
                           default=4,
                           type=int,
                           help='Number of encoder layers (for shared '
                                'recognition part in multi-speaker asr mode)')
        group.add_argument('--eunits',
                           '-u',
                           default=300,
                           type=int,
                           help='Number of encoder hidden units')
        # Attention
        group.add_argument(
            '--adim',
            default=320,
            type=int,
            help='Number of attention transformation dimensions')
        group.add_argument('--aheads',
                           default=4,
                           type=int,
                           help='Number of heads for multi head attention')
        # Decoder
        group.add_argument('--dlayers',
                           default=1,
                           type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits',
                           default=320,
                           type=int,
                           help='Number of decoder hidden units')
        return parser
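
    # Usage sketch (assumes a plain argparse parser; the training script may
    # pass its own parser object):
    #     import argparse
    #     parser = E2E.add_arguments(argparse.ArgumentParser())
    #     args = parser.parse_args(['--adim', '256', '--elayers', '6'])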

    @property
    def attention_plot_class(self):
        return PlotAttentionReport

    def __init__(self, idim, odim, args, ignore_id=-1):
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = [1]
        self.reporter = Reporter()

        # self.lsm_weight = a
        self.criterion = LabelSmoothingLoss(
            self.odim, self.ignore_id, args.lsm_weight,
            args.transformer_length_normalized_loss)
        # self.verbose = args.verbose
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim,
                           args.adim,
                           args.dropout_rate,
                           ctc_type=args.ctc_type,
                           reduce=True)
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(args.char_list,
                                                    args.sym_space,
                                                    args.sym_blank,
                                                    args.report_cer,
                                                    args.report_wer)
        else:
            self.error_calculator = None
        self.rnnlm = None

    def unfreeze_layers(self, slu_tune_weights):
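        # `slu_tune_weights` is a '+'-separated spec (e.g. the hypothetical
        # 'asr2+emb1'). The 3-letter prefix selects the sub-module ('asr' =
        # Transformer encoder blocks, 'slu' = the attached SLU encoder, 'emb' =
        # the encoder input/embedding layer) and the signed integer selects the
        # layers to unfreeze: for 'asr' a positive n means the last n blocks
        # and a negative n the first |n|; for 'slu' the sign works the other
        # way round; for 'emb' any non-zero n unfreezes the whole input layer.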
        if slu_tune_weights == '':
            return

        tuned_models = []

        for item in slu_tune_weights.split('+'):
            n = int(item[3:])

            if n > 0:
                if item[:3] == 'asr':
                    tuned_models += self.encoder.encoders[-n:]
                elif item[:3] == 'slu' and hasattr(self.slu, 'encoder'):
                    tuned_models += self.slu.encoder.layer[:n]
                elif item[:3] == 'emb':
                    tuned_models += [self.encoder.embed]
            elif n < 0:
                n = n * -1
                if item[:3] == 'asr':
                    tuned_models += self.encoder.encoders[:n]
                elif item[:3] == 'slu' and hasattr(self.slu, 'encoder'):
                    tuned_models += self.slu.encoder.layer[-n:]
                elif item[:3] == 'emb':
                    tuned_models += [self.encoder.embed]

        for tuned_model in tuned_models:
            for parameter in tuned_model.parameters():
                parameter.requires_grad = True

    def add_slu(self, slu_model, slu_loss, slu_tune_weights, slu_pooling):
        if slu_model == 'none':
            self.slu = torch.nn.Identity()
            self.slu_mapper = torch.nn.Linear(self.adim, 768)
        else:
            if 'nli' in slu_model:
                from sentence_transformers import SentenceTransformer
                self.slu = SentenceTransformer(slu_model)[0].bert
            else:
                from transformers import AutoModel
                self.slu = AutoModel.from_pretrained(slu_model)
            del self.slu.embeddings
            self.slu_mapper = torch.nn.Linear(self.adim,
                                              self.slu.config.hidden_size)

        del self.decoder

        for parameter in self.parameters():
            parameter.requires_grad = False

        self.unfreeze_layers(slu_tune_weights)

        self.slu_loss = getattr(torch.nn.functional, slu_loss)

        if slu_loss == 'cosine_embedding_loss':
            self.slu_loss_label = True
        else:
            self.slu_loss_label = False

        if slu_pooling != '':
            kernel_size = int(slu_pooling[3:])

            if slu_pooling[:3] == 'max':
                self.slu_pooler = torch.nn.MaxPool1d(kernel_size)
            elif slu_pooling[:3] == 'avg':
                self.slu_pooler = torch.nn.AvgPool1d(kernel_size)
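
    # Usage sketch (hypothetical argument values):
    #     model.add_slu('bert-base-uncased', 'mse_loss', 'asr2', 'avg2')
    # drops the attention decoder, attaches a pretrained transformer body as
    # the SLU head (its own embeddings are deleted), freezes all parameters
    # except the last two ASR encoder blocks, and average-pools the encoder
    # output (kernel size 2) before `slu_mapper` projects it to the SLU
    # hidden size.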

    def reset_parameters(self, args):
        # initialize parameters
        initialize(self, args.transformer_init)

    def add_sos_eos(self, ys_pad):
        from espnet.nets.pytorch_backend.nets_utils import pad_list
        eos = ys_pad.new([self.eos])
        sos = ys_pad.new([self.sos])
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]
        return pad_list(ys_in, self.eos), pad_list(ys_out, self.ignore_id)
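
    # Toy example (assuming ignore_id=-1 and sos = eos = odim - 1):
    #     ys_pad = [[3, 5, -1]]  ->  ys_in = [[sos, 3, 5]], ys_out = [[3, 5, eos]]
    # i.e. the decoder input is the target prefixed with <sos> and the training
    # target is the same tokens followed by <eos>; ys_in is padded with eos and
    # ys_out with ignore_id.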

    def target_mask(self, ys_in_pad):
        ys_mask = ys_in_pad != self.ignore_id
        m = subsequent_mask(ys_mask.size(-1),
                            device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m
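
    # Shape note: `ys_mask.unsqueeze(-2)` is (B, 1, L) and `subsequent_mask`
    # gives (1, L, L); their broadcast logical AND is a (B, L, L) mask that
    # hides both padding and future positions for decoder self-attention.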

    def forward(self, xs_pad, ilens, ys_pad):
        '''E2E forward

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        '''
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        if getattr(self, 'slu_loss', None) is not None:
            if hasattr(self, 'slu_pooler'):
                hs_pad = self.slu_pooler(
                    hs_pad.permute(0, 2, 1)).permute(0, 2, 1)

            hidden_state_mapped = self.slu_mapper(hs_pad)

            if hasattr(self.slu, 'encoder'):
                attention_mask = to_device(
                    self,
                    torch.ones(hidden_state_mapped.shape[0],
                               hidden_state_mapped.shape[1]))
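                # build a HuggingFace-style additive attention mask: 0.0 keeps
                # a position and -10000.0 suppresses it; since `attention_mask`
                # is all ones here, nothing is actually masked (no padding
                # information is propagated to the SLU encoder)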
                extended_attention_mask = attention_mask.unsqueeze(
                    1).unsqueeze(2)
                extended_attention_mask = (1.0 -
                                           extended_attention_mask) * -10000.0

                head_mask = [None] * self.slu.config.num_hidden_layers
                encoded_layers = self.slu.encoder(hidden_state_mapped,
                                                  extended_attention_mask,
                                                  head_mask)
                embeddings = torch.mean(encoded_layers[0], 1)
            else:
                embeddings = torch.mean(hidden_state_mapped, 1)

            if ys_pad is None:
                return embeddings
            else:
                if self.slu_loss_label:
                    self.loss = self.slu_loss(
                        embeddings, ys_pad.squeeze(1),
                        to_device(self, torch.ones(embeddings.size(0))))
                else:
                    self.loss = self.slu_loss(embeddings, ys_pad.squeeze(1))

                # only the SLU loss is meaningful here; the other reporter
                # fields (CTC/attention losses, accuracy, CER/WER) are
                # placeholder values
                self.reporter.report(0.0, 0.0, 0.0, 100.0, 100.0, 100.0,
                                     float(self.loss))
                return self.loss

        # 2. forward decoder
        ys_in_pad, ys_out_pad = self.add_sos_eos(ys_pad)
        ys_mask = self.target_mask(ys_in_pad)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
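        # 4. compute ctc loss (and CTC-based CER when an error calculator is set)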
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss

    def recognize(self,
                  feat,
                  recog_args,
                  char_list=None,
                  rnnlm=None,
                  use_jit=False):
        '''recognize feat

        :param ndarray feat: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argument namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list

        TODO(karita): do not recompute previous attention for faster decoding
        '''
        self.eval()
        feat = torch.as_tensor(feat).unsqueeze(0)
        enc_output, _ = self.encoder(feat, None)
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)

        logging.info('input lengths: ' + str(h.size(0)))
        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
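        # each hypothesis is a dict holding the cumulative 'score', the decoded
        # token ids 'yseq', and (when enabled) the previous RNNLM state and the
        # CTC prefix-scoring state/score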
        if lpz is not None:
            import numpy

            from espnet.nets.ctc_prefix_score import CTCPrefixScore

            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0,
                                              self.eos, numpy)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        import six
        traced_decoder = None
        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy.unsqueeze(1)
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp['yseq']).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.recognize, (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)
                else:
                    local_att_scores = self.decoder.recognize(
                        ys, ys_mask, enc_output)

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(
                        hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
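                    # rescore only the attention-best `ctc_beam` candidates
                    # with the CTC prefix score and interpolate the two scores
                    # with `ctc_weight` (plus the RNNLM score when enabled)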
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                        + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * \
                            local_lm_scores[:, local_best_ids[0]]
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(
                        local_best_scores[0, j])
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0,
                                                                           j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[
                            0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[
                            0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    'best hypo: ' +
                    ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop so that at least some hypotheses end
            if i == maxlen - 1:
                logging.info('adding <eos> in the last position in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypotheses to the final list and remove them from the
            # current hypotheses
            # (this can be a problem when the number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            from espnet.nets.e2e_asr_common import end_detect
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remaining hypotheses: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        'hypo: ' +
                        ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'],
            reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check the number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning('there are no N-best results; performing '
                            'recognition again with a smaller minlenratio.')
            # should copy because the Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            return self.recognize(feat, recog_args, char_list, rnnlm)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' +
                     str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        '''E2E attention calculation

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        '''
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                ret[name] = m.attn.cpu().numpy()
        return ret
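
# Rough usage sketch (hypothetical values; `args` must come from the full
# ESPnet training parser, of which `E2E.add_arguments` covers only a part):
#     model = E2E(idim=83, odim=5002, args=args)
#     model.add_slu('bert-base-uncased', 'mse_loss', 'asr2', '')
#     loss = model(xs_pad, ilens, ys_pad)  # ys_pad holds target embeddings
#                                          # once the SLU head is attached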