Example #1
0
class E2E(MTInterface, torch.nn.Module):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options

    """
    @staticmethod
    def add_arguments(parser):
        """Register the transformer MT options on ``parser``.

        Every option is added to a single argument group titled
        "transformer model setting", in the order of the table below.

        :param parser: object providing ``add_argument_group`` (for example
            an ``argparse.ArgumentParser``)
        :return: the ``parser`` that was passed in
        """
        selfattn_layer_types = ("selfattn", "lightconv", "dynamicconv")
        # Help text shared by the encoder/decoder kernel-length options.
        kernel_length_help = (
            "kernel size for lightweight/dynamic convolution: "
            '{} side. For example, "21_23_25" means kernel length 21 for '
            "First layer, 23 for Second layer and so on."
        )

        # (option strings, add_argument keyword arguments), registration order.
        option_table = [
            (["--transformer-init"], {
                "type": str,
                "default": "xavier_uniform",
                "choices": [
                    "pytorch",
                    "xavier_uniform",
                    "xavier_normal",
                    "kaiming_uniform",
                    "kaiming_normal",
                ],
                "help": "how to initialize transformer parameters",
            }),
            (["--transformer-attn-dropout-rate"], {
                "default": None,
                "type": float,
                "help": "dropout in transformer attention. "
                        "use --dropout-rate if None is set",
            }),
            (["--transformer-lr"], {
                "default": 1.0,
                "type": float,
                "help": "Initial value of learning rate",
            }),
            (["--transformer-warmup-steps"], {
                "default": 4000,
                "type": int,
                "help": "optimizer warmup steps",
            }),
            (["--transformer-length-normalized-loss"], {
                "default": False,
                "type": strtobool,
                "help": "normalize loss by length",
            }),
            (["--transformer-encoder-selfattn-layer-type"], {
                "type": str,
                "default": "selfattn",
                "choices": list(selfattn_layer_types),
                "help": "transformer encoder self-attention layer type",
            }),
            (["--transformer-decoder-selfattn-layer-type"], {
                "type": str,
                "default": "selfattn",
                "choices": list(selfattn_layer_types),
                "help": "transformer decoder self-attention layer type",
            }),
            # Lightweight/Dynamic convolution related parameters.
            # See https://arxiv.org/abs/1912.11793v2
            # and https://arxiv.org/abs/1901.10430 for detail of the method.
            # Configurations used in the first paper are in
            # egs/{csj, librispeech}/asr1/conf/tuning/ld_conv/
            (["--wshare"], {
                "default": 4,
                "type": int,
                "help": "Number of parameter shargin for lightweight convolution",
            }),
            (["--ldconv-encoder-kernel-length"], {
                "default": "21_23_25_27_29_31_33_35_37_39_41_43",
                "type": str,
                "help": kernel_length_help.format("Encoder"),
            }),
            (["--ldconv-decoder-kernel-length"], {
                "default": "11_13_15_17_19_21",
                "type": str,
                "help": kernel_length_help.format("Decoder"),
            }),
            (["--ldconv-usebias"], {
                "type": strtobool,
                "default": False,
                "help": "use bias term in lightweight/dynamic convolution",
            }),
            (["--dropout-rate"], {
                "default": 0.1,
                "type": float,
                "help": "Dropout rate for the encoder",
            }),
            # Encoder
            (["--elayers"], {
                "default": 6,
                "type": int,
                "help": "Number of encoder layers",
            }),
            (["--eunits", "-u"], {
                "default": 2048,
                "type": int,
                "help": "Number of encoder hidden units",
            }),
            # Attention
            (["--adim"], {
                "default": 256,
                "type": int,
                "help": "Number of attention transformation dimensions",
            }),
            (["--aheads"], {
                "default": 4,
                "type": int,
                "help": "Number of heads for multi head attention",
            }),
            # Decoder
            (["--dlayers"], {
                "default": 6,
                "type": int,
                "help": "Number of decoder layers",
            }),
            (["--dunits"], {
                "default": 2048,
                "type": int,
                "help": "Number of decoder hidden units",
            }),
        ]

        group = parser.add_argument_group("transformer model setting")
        for option_strings, spec in option_table:
            group.add_argument(*option_strings, **spec)
        return parser

    @property
    def attention_plot_class(self):
        """Give callers the ``PlotAttentionReport`` class (not an instance)."""
        plot_report_cls = PlotAttentionReport
        return plot_report_cls

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Build the encoder, decoder, loss and bookkeeping attributes.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: stored as ``self.ignore_id`` and handed to
            ``LabelSmoothingLoss``
        :raises ValueError: if ``args.tie_src_tgt_embedding`` is set while
            ``idim`` and ``odim`` differ
        """
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        # Attention dropout falls back to the general dropout rate; the
        # resolved value is written back onto ``args``.
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        dropout = args.dropout_rate
        attn_dropout = args.transformer_attn_dropout_rate

        # Keyword arguments given identically to the encoder and the decoder.
        shared_kwargs = dict(
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_usebias=args.ldconv_usebias,
            dropout_rate=dropout,
            positional_dropout_rate=dropout,
        )
        self.encoder = Encoder(
            idim=idim,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            conv_kernel_length=args.ldconv_encoder_kernel_length,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer="embed",
            attention_dropout_rate=attn_dropout,
            **shared_kwargs,
        )
        self.decoder = Decoder(
            odim=odim,
            selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
            conv_kernel_length=args.ldconv_decoder_kernel_length,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            self_attention_dropout_rate=attn_dropout,
            src_attention_dropout_rate=attn_dropout,
            **shared_kwargs,
        )

        self.pad = 0  # use <blank> for padding
        # <sos> and <eos> share the last output index.
        self.sos = self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="mt", arch="transformer")
        self.reporter = Reporter()

        # tie source and target embeddings
        if args.tie_src_tgt_embedding:
            if idim != odim:
                raise ValueError(
                    "When using tie_src_tgt_embedding, idim and odim must be equal."
                )
            self.encoder.embed[0].weight = self.decoder.embed[0].weight

        # tie embeddings and the classifier
        if args.tie_classifier:
            self.decoder.output_layer.weight = self.decoder.embed[0].weight

        length_normalized = args.transformer_length_normalized_loss
        self.criterion = LabelSmoothingLoss(
            self.odim,
            self.ignore_id,
            args.lsm_weight,
            length_normalized,
        )
        self.normalize_length = length_normalized  # for PPL
        self.reset_parameters(args)
        self.adim = args.adim
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_bleu,
        )
        self.rnnlm = None

        # multilingual MT related
        self.multilingual = args.multilingual

    def reset_parameters(self, args):
        """Initialize the model weights, then re-draw both embedding tables.

        :param Namespace args: options; ``transformer_init`` and ``adim``
            are read here
        """
        initialize(self, args.transformer_init)

        embed_std = args.adim ** -0.5
        # Encoder first, then decoder: normal init followed by zeroing the
        # row that belongs to the padding index.
        for side in (self.encoder, self.decoder):
            embed_weight = side.embed[0].weight
            torch.nn.init.normal_(embed_weight, mean=0, std=embed_std)
            torch.nn.init.constant_(embed_weight[self.pad], 0)

    def forward(self, xs_pad, ilens, ys_pad):
        """Run encoder and decoder on a batch and compute the loss.

        Besides returning the loss, this stores ``self.loss``, ``self.acc``,
        ``self.bleu`` and ``self.ppl`` and hands them to ``self.reporter``
        unless the loss is NaN.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: loss value produced by ``self.criterion``
        :rtype: torch.Tensor
        """
        # 1. encoder pass
        longest = max(ilens)
        xs_pad = xs_pad[:, :longest]  # for data parallel
        src_mask = ~make_pad_mask(ilens.tolist())
        src_mask = src_mask.to(xs_pad.device).unsqueeze(-2)
        xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. decoder pass
        ys_in_pad, ys_out_pad = add_sos_eos(
            ys_pad, self.sos, self.eos, self.ignore_id
        )
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # 3. loss and accuracy
        self.loss = self.criterion(pred_pad, ys_out_pad)
        flat_pred = pred_pad.view(-1, self.odim)
        self.acc = th_accuracy(
            flat_pred, ys_out_pad, ignore_label=self.ignore_id
        )

        # 4. corpus-level bleu in a mini-batch (not computed while training)
        if not self.training:
            ys_hat = pred_pad.argmax(dim=-1)
            self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        else:
            self.bleu = None

        # 5. perplexity
        loss_data = float(self.loss)
        if self.normalize_length:
            self.ppl = np.exp(loss_data)
        else:
            # The batch size is taken before the targets are flattened.
            batch_size = ys_out_pad.size(0)
            flat_targets = ys_out_pad.view(-1)  # (B*T,)
            n_ignored = (flat_targets == self.ignore_id).sum().item()
            total_n_tokens = len(flat_targets) - n_ignored
            self.ppl = np.exp(loss_data * batch_size / total_n_tokens)

        if math.isnan(loss_data):
            logging.warning("loss (=%f) is not correct", loss_data)
        else:
            self.reporter.report(loss_data, self.acc, self.ppl, self.bleu)
        return self.loss

    def scorers(self):
        """Scorers."""
        return dict(decoder=self.decoder)

    def encode(self, xs):
        """Encode source sentences."""
        self.eval()
        xs = torch.as_tensor(xs).unsqueeze(0)
        enc_output, _ = self.encoder(xs, None)
        return enc_output.squeeze(0)

    def target_forcing(self, xs_pad, ys_pad=None, tgt_lang=None):
        """Prepend target language IDs to source sentences for multilingual MT.

        These tags are prepended in source/target sentences as pre-processing.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax)
        :return: source text without language IDs
        :rtype: torch.Tensor
        :return: target text without language IDs
        :rtype: torch.Tensor
        :return: target language IDs
        :rtype: torch.Tensor (B, 1)
        """
        if self.multilingual:
            xs_pad = xs_pad[:, 1:]  # remove source language IDs here
            if ys_pad is not None:
                # remove language ID in the beginning
                lang_ids = ys_pad[:, 0].unsqueeze(1)
                ys_pad = ys_pad[:, 1:]
            elif tgt_lang is not None:
                lang_ids = xs_pad.new_zeros(xs_pad.size(0), 1).fill_(tgt_lang)
            else:
                raise ValueError("Set ys_pad or tgt_lang.")

            # prepend target language ID to source sentences
            xs_pad = torch.cat([lang_ids, xs_pad], dim=1)
        return xs_pad, ys_pad

    def translate(self, x, trans_args, char_list=None):
        """Translate source text with beam search.

        :param list x: input source text feature (T,); only ``x[0]`` is read,
            and its first element is skipped when ``self.multilingual`` is set
        :param Namespace trans_args: argument Namespace containing options
            (``beam_size``, ``penalty``, ``maxlenratio``, ``minlenratio``,
            ``nbest`` and ``tgt_lang`` are read here)
        :param list char_list: list of characters; needed to resolve
            ``trans_args.tgt_lang`` and used for debug logging
        :return: N-best decoding results, each a dict with the keys
            ``"score"`` and ``"yseq"``
        :rtype: list
        """
        self.eval(
        )  # NOTE: this is important because self.encode() is not used
        assert isinstance(x, list)

        # make a utt list (1) to use the same interface for encoder.
        # The tensor gets its own name: ``x`` must remain the caller's list
        # because the retry at the bottom re-enters translate(), whose
        # isinstance assertion would reject a tensor.
        # NOTE(review): in multilingual mode the first element is skipped
        # here and target_forcing() drops one more leading column -- confirm
        # the input really carries two leading tags.
        src_tokens = x[0][1:] if self.multilingual else x[0]
        src = to_device(
            self,
            torch.from_numpy(np.fromiter(map(int, src_tokens),
                                         dtype=np.int64)))

        logging.info("input lengths: " + str(src.size(0)))
        xs_pad = src.unsqueeze(0)
        tgt_lang = None
        if trans_args.tgt_lang:
            tgt_lang = char_list.index(trans_args.tgt_lang)
        xs_pad, _ = self.target_forcing(xs_pad, tgt_lang=tgt_lang)
        h, _ = self.encoder(xs_pad, None)
        logging.info("encoder output lengths: " + str(h.size(1)))

        # search params
        beam = trans_args.beam_size
        penalty = trans_args.penalty

        if trans_args.maxlenratio == 0:
            maxlen = h.size(1)
        else:
            # maxlen >= 1
            maxlen = max(1, int(trans_args.maxlenratio * h.size(1)))
        minlen = int(trans_args.minlenratio * h.size(1))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # initialize hypothesis
        hyp = {"score": 0.0, "yseq": [self.sos]}
        hyps = [hyp]
        ended_hyps = []

        for i in range(maxlen):
            logging.debug("position " + str(i))

            # batchify: every live hypothesis has exactly i + 1 tokens
            ys = h.new_zeros((len(hyps), i + 1), dtype=torch.int64)
            for j, hyp in enumerate(hyps):
                ys[j, :] = torch.tensor(hyp["yseq"])
            ys_mask = subsequent_mask(i + 1).unsqueeze(0).to(h.device)

            local_scores = self.decoder.forward_one_step(
                ys, ys_mask, h.repeat([len(hyps), 1, 1]))[0]

            hyps_best_kept = []
            for j, hyp in enumerate(hyps):
                local_best_scores, local_best_ids = torch.topk(
                    local_scores[j:j + 1], beam, dim=1)

                # extend this hypothesis by each of its ``beam`` best tokens
                for k in range(beam):
                    new_hyp = {
                        "score":
                        hyp["score"] + float(local_best_scores[0, k]),
                        "yseq":
                        hyp["yseq"] + [int(local_best_ids[0, k])],
                    }
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda cand: cand["score"],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug("number of pruned hypothes: " + str(len(hyps)))
            if char_list is not None:
                logging.debug("best hypo: " + "".join(
                    [char_list[int(tok)] for tok in hyps[0]["yseq"][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info("adding <eos> in the last postion in the loop")
                for hyp in hyps:
                    hyp["yseq"].append(self.eos)

            # add ended hypotheses to a final list, and remove them from the
            # current hypotheses (this can leave fewer than ``beam`` hyps)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0:
                logging.info("end detected at %d", i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug("remeined hypothes: " + str(len(hyps)))
            else:
                logging.info("no hypothesis. Finish decoding.")
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug("hypo: " + "".join(
                        [char_list[int(tok)] for tok in hyp["yseq"][1:]]))

            logging.debug("number of ended hypothes: " + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda cand: cand["score"],
            reverse=True)[:min(len(ended_hyps), trans_args.nbest)]

        # check number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning("there is no N-best results, perform translation "
                            "again with smaller minlenratio.")
            # should copy because Namespace will be overwritten globally
            trans_args = Namespace(**vars(trans_args))
            trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1)
            # ``x`` is still the original list, so the retry passes the
            # isinstance assertion at the top of this method.
            return self.translate(x, trans_args, char_list)

        logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
        logging.info("normalized log probability: " +
                     str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"])))
        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Runs ``forward`` in evaluation mode without gradients and collects
        the ``attn`` tensor of every ``MultiHeadedAttention`` sub-module that
        has one.  The training/evaluation mode the model was in before the
        call is restored afterwards, also when ``forward`` raises.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: attention weights keyed by the sub-module name
        :rtype: dict of numpy arrays
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                self.forward(xs_pad, ilens, ys_pad)
            ret = {
                name: m.attn.cpu().numpy()
                for name, m in self.named_modules()
                if isinstance(m, MultiHeadedAttention) and m.attn is not None
            }
        finally:
            # Put the model back into the mode it was in before the call
            # instead of unconditionally switching to training mode.
            self.train(was_training)
        return ret
class E2E(MTInterface, torch.nn.Module):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options

    """
    @staticmethod
    def add_arguments(parser):
        """Register the transformer options on ``parser``.

        All options go into one argument group titled
        "transformer model setting".

        :param parser: object providing ``add_argument_group`` (for example
            an ``argparse.ArgumentParser``)
        :return: the ``parser`` that was passed in
        """
        group = parser.add_argument_group("transformer model setting")

        # The only option restricted to a fixed set of choices.
        init_choices = [
            "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform",
            "kaiming_normal"
        ]
        group.add_argument("--transformer-init",
                           type=str,
                           default="xavier_uniform",
                           choices=init_choices,
                           help='how to initialize transformer parameters')

        # (option strings, default, type, help) for the remaining options,
        # listed in registration order.
        scalar_options = (
            (('--transformer-attn-dropout-rate', ), None, float,
             'dropout in transformer attention. '
             'use --dropout-rate if None is set'),
            (('--transformer-lr', ), 1.0, float,
             'Initial value of learning rate'),
            (('--transformer-warmup-steps', ), 4000, int,
             'optimizer warmup steps'),
            (('--transformer-length-normalized-loss', ), False, strtobool,
             'normalize loss by length'),
            (('--dropout-rate', ), 0.1, float,
             'Dropout rate for the encoder'),
            # Encoder
            (('--elayers', ), 6, int,
             'Number of encoder layers '
             '(for shared recognition part in multi-speaker asr mode)'),
            (('--eunits', '-u'), 2048, int,
             'Number of encoder hidden units'),
            # Attention
            (('--adim', ), 256, int,
             'Number of attention transformation dimensions'),
            (('--aheads', ), 4, int,
             'Number of heads for multi head attention'),
            # Decoder
            (('--dlayers', ), 6, int, 'Number of decoder layers'),
            (('--dunits', ), 2048, int, 'Number of decoder hidden units'),
        )
        for option_strings, default, value_type, help_text in scalar_options:
            group.add_argument(*option_strings,
                               default=default,
                               type=value_type,
                               help=help_text)
        return parser

    @property
    def attention_plot_class(self):
        """Hand back the ``PlotAttentionReport`` class itself."""
        report_cls = PlotAttentionReport
        return report_cls

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Set up encoder, decoder, loss and bookkeeping attributes.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options; its
            ``transformer_attn_dropout_rate`` is overwritten with
            ``dropout_rate`` when it is None
        :param int ignore_id: stored as ``self.ignore_id`` and handed to
            ``LabelSmoothingLoss``
        :raises ValueError: if ``args.tie_src_tgt_embedding`` is set while
            ``idim`` and ``odim`` differ
        """
        torch.nn.Module.__init__(self)

        # Attention dropout falls back to the general dropout rate.
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        dropout = args.dropout_rate
        attn_dropout = args.transformer_attn_dropout_rate
        adim, aheads = args.adim, args.aheads

        self.encoder = Encoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer='embed',
            dropout_rate=dropout,
            positional_dropout_rate=dropout,
            attention_dropout_rate=attn_dropout,
        )
        self.decoder = Decoder(
            odim=odim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=dropout,
            positional_dropout_rate=dropout,
            self_attention_dropout_rate=attn_dropout,
            src_attention_dropout_rate=attn_dropout,
        )

        self.pad = 0
        # <sos> and <eos> share the last output index.
        self.sos = self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode='mt', arch='transformer')
        self.reporter = Reporter()

        # tie source and target embeddings
        if args.tie_src_tgt_embedding:
            if idim != odim:
                raise ValueError(
                    'When using tie_src_tgt_embedding, idim and odim must be equal.'
                )
            self.encoder.embed[0].weight = self.decoder.embed[0].weight

        # tie embeddings and the classifier
        if args.tie_classifier:
            self.decoder.output_layer.weight = self.decoder.embed[0].weight

        length_normalized = args.transformer_length_normalized_loss
        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id,
                                            args.lsm_weight,
                                            length_normalized)
        self.normalize_length = length_normalized  # for PPL
        self.reset_parameters(args)
        self.adim = args.adim

        # The BLEU calculator is only imported and built when requested.
        if not args.report_bleu:
            self.error_calculator = None
        else:
            from espnet.nets.e2e_mt_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(args.char_list,
                                                    args.sym_space,
                                                    args.report_bleu)
        self.rnnlm = None

        # multilingual NMT related
        self.multilingual = args.multilingual

    def reset_parameters(self, args):
        """Initialize all weights, then re-draw the two embedding tables.

        :param Namespace args: options; ``transformer_init`` and ``adim``
            are read here
        """
        initialize(self, args.transformer_init)

        std = args.adim ** -0.5
        # Same treatment for both tables, encoder first: normal init, then
        # zero the row that belongs to the padding index.
        for network in (self.encoder, self.decoder):
            table = network.embed[0].weight
            torch.nn.init.normal_(table, mean=0, std=std)
            torch.nn.init.constant_(table[self.pad], 0)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Besides returning the loss, this stores ``self.hs_pad``,
        ``self.pred_pad``, ``self.loss``, ``self.acc`` and ``self.ppl`` and
        reports loss, accuracy, perplexity and BLEU to ``self.reporter``
        unless the loss is NaN.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: loss value produced by ``self.criterion``
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        self.loss = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute bleu (0.0 while training or without an error calculator)
        if self.training or self.error_calculator is None:
            bleu = 0.0
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # 5. compute perplexity
        loss_data = float(self.loss)
        if self.normalize_length:
            self.ppl = np.exp(loss_data)
        else:
            # Read the batch size BEFORE flattening: after ``view(-1)``,
            # ``size(0)`` is the number of target positions (B * L), which
            # would inflate the exponent by a factor of L.
            # NOTE(review): assumes the criterion divides the summed loss by
            # the batch size when length normalization is off -- confirm
            # against LabelSmoothingLoss.
            batch_size = ys_out_pad.size(0)
            ys_out_pad = ys_out_pad.view(-1)
            ignore = ys_out_pad == self.ignore_id  # (B*L,)
            total = len(ys_out_pad) - ignore.sum().item()
            self.ppl = np.exp(loss_data * batch_size / total)
        if not math.isnan(loss_data):
            self.reporter.report(loss_data, self.acc, self.ppl, bleu)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss

    def scorers(self):
        """Scorers."""
        return dict(decoder=self.decoder)

    def encode(self, xs):
        """Encode source sentences."""
        self.eval()
        xs = torch.as_tensor(xs).unsqueeze(0)
        enc_output, _ = self.encoder(xs, None)
        return enc_output.squeeze(0)

    def target_forcing(self, xs_pad, ys_pad=None, tgt_lang=None):
        """Prepend target language IDs to source sentences for multilingual NMT.

        These tags are prepended in source/target sentences as pre-processing.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax)
        :return: source text without language IDs
        :rtype: torch.Tensor
        :return: target text without language IDs
        :rtype: torch.Tensor
        :return: target language IDs
        :rtype: torch.Tensor (B, 1)
        """
        if self.multilingual:
            xs_pad = xs_pad[:, 1:]  # remove source language IDs here
            if ys_pad is not None:
                # remove language ID in the beginning
                lang_ids = ys_pad[:, 0].unsqueeze(1)
                ys_pad = ys_pad[:, 1:]
            elif tgt_lang is not None:
                lang_ids = xs_pad.new_zeros(xs_pad.size(0), 1).fill_(tgt_lang)
            else:
                raise ValueError("Set ys_pad or tgt_lang.")

            # prepend target language ID to source sentences
            xs_pad = torch.cat([lang_ids, xs_pad], dim=1)
        return xs_pad, ys_pad

    def translate(self,
                  x,
                  trans_args,
                  char_list=None,
                  rnnlm=None,
                  use_jit=False):
        """Translate source text.

        :param list x: input source text feature (T,)
        :param Namespace trans_args: argment Namespace contraining options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        self.eval(
        )  # NOTE: this is important because self.encode() is not used
        assert isinstance(x, list)

        # make a utt list (1) to use the same interface for encoder
        if self.multilingual:
            x = to_device(
                self,
                torch.from_numpy(
                    np.fromiter(map(int, x[0][1:]), dtype=np.int64)))
        else:
            x = to_device(
                self,
                torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64)))

        xs_pad = x.unsqueeze(0)
        tgt_lang = None
        if trans_args.tgt_lang:
            tgt_lang = char_list.index(trans_args.tgt_lang)
        xs_pad, _ = self.target_forcing(xs_pad, tgt_lang=tgt_lang)
        enc_output, _ = self.encoder(xs_pad, None)
        h = enc_output.squeeze(0)

        logging.info('input lengths: ' + str(h.size(0)))
        # search parms
        beam = trans_args.beam_size
        penalty = trans_args.penalty

        # preprare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if trans_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(trans_args.maxlenratio * h.size(0)))
        minlen = int(trans_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        hyps = [hyp]
        ended_hyps = []

        import six
        traced_decoder = None
        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp['yseq']).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.forward_one_step,
                            (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(ys, ys_mask,
                                                      enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(
                        ys, ys_mask, enc_output)[0]

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(
                        hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + trans_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                local_best_scores, local_best_ids = torch.topk(local_scores,
                                                               beam,
                                                               dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(
                        local_best_scores[0, j])
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0,
                                                                           j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypothes: ' + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    'best hypo: ' +
                    ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last postion in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypothes to a final list, and removed them from current hypothes
            # (this will be a probmlem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += trans_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            from espnet.nets.e2e_asr_common import end_detect
            if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remeined hypothes: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        'hypo: ' +
                        ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'],
            reverse=True)[:min(len(ended_hyps), trans_args.nbest)]

        # check number of hypotheis
        if len(nbest_hyps) == 0:
            logging.warning(
                'there is no N-best results, perform recognition again with smaller minlenratio.'
            )
            # should copy becasuse Namespace will be overwritten globally
            trans_args = Namespace(**vars(trans_args))
            trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1)
            return self.translate(x, trans_args, char_list, rnnlm)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' +
                     str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Runs a single forward pass with gradient tracking disabled and then
        collects the ``attn`` attribute of every ``MultiHeadedAttention``
        sub-module found via ``named_modules()``.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: mapping from sub-module name to its attention weights, each
            with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: dict of (str, float ndarray)
        """
        # Run the pass in evaluation mode so that train-only stochastic layers
        # (e.g. dropout) do not perturb the collected attention; the previous
        # mode is restored afterwards, even if forward() raises.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                self.forward(xs_pad, ilens, ys_pad)
        finally:
            self.train(was_training)

        # Skip attention modules whose ``attn`` is still None (nothing was
        # stored for them), instead of failing on ``None.cpu()``.
        return {
            name: m.attn.cpu().numpy()
            for name, m in self.named_modules()
            if isinstance(m, MultiHeadedAttention) and m.attn is not None
        }