    def forward(self, padded_input, encoder_padded_outputs,
                encoder_input_lengths):
        """
        args:
            padded_input: B x T
            encoder_padded_outputs: B x T x H
            encoder_input_lengths: B
        returns:
            pred: B x T x vocab
            gold: B x T
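            decoder_self_attn_list, decoder_encoder_attn_list: per-layer
                attention weights
        example (a sketch; shapes and the "decoder" instance name are
        illustrative, not taken from this codebase):
            pred, gold, self_attns, enc_attns = decoder(
                padded_input,            # e.g. 4 x 25 target token ids
                encoder_padded_outputs,  # e.g. 4 x 120 x 512 encoder states
                encoder_input_lengths)   # e.g. tensor([120, 98, 87, 60])
            pred.shape  # 4 x 25 x vocab; gold.shape: 4 x 25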
        """
        decoder_self_attn_list, decoder_encoder_attn_list = [], []
        seq_in_pad, seq_out_pad = self.preprocess(padded_input)

        # Prepare masks
        non_pad_mask = get_non_pad_mask(seq_in_pad, pad_idx=constant.EOS_TOKEN)
        self_attn_mask_subseq = get_subsequent_mask(seq_in_pad)
        self_attn_mask_keypad = get_attn_key_pad_mask(
            seq_k=seq_in_pad, seq_q=seq_in_pad, pad_idx=constant.EOS_TOKEN)
        self_attn_mask = (self_attn_mask_keypad + self_attn_mask_subseq).gt(0)
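        # The summed mask is nonzero wherever attention must be blocked: either
        # the key position is padding or it lies in the future of the query
        # position. For a length-3 sequence with no padding, the subsequent
        # (causal) part alone is
        #   [[0, 1, 1],
        #    [0, 0, 1],
        #    [0, 0, 0]]
        # and .gt(0) converts the sum back into a boolean mask.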

        output_length = seq_in_pad.size(1)
        dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                              encoder_input_lengths,
                                              output_length)
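        # Cross-attention mask: for each of the output_length decoder positions,
        # encoder frames beyond encoder_input_lengths are marked as padding so
        # they receive no attention.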

        decoder_output = self.dropout(
            self.trg_embedding(seq_in_pad) * self.x_logit_scale +
            self.positional_encoding(seq_in_pad))
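        # Embedded targets are scaled by x_logit_scale (typically sqrt(d_model)
        # when the embedding is tied to the output projection), summed with the
        # positional encoding, and passed through dropout.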

        for layer in self.layers:
            decoder_output, decoder_self_attn, decoder_enc_attn = layer(
                decoder_output,
                encoder_padded_outputs,
                non_pad_mask=non_pad_mask,
                self_attn_mask=self_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

            decoder_self_attn_list.append(decoder_self_attn)
            decoder_encoder_attn_list.append(decoder_enc_attn)

        seq_logit = self.output_linear(decoder_output)
        pred, gold = seq_logit, seq_out_pad

        return pred, gold, decoder_self_attn_list, decoder_encoder_attn_list

    def beam_search(self,
                    encoder_padded_outputs,
                    beam_width=2,
                    nbest=5,
                    lm_rescoring=False,
                    lm=None,
                    lm_weight=0.1,
                    c_weight=1,
                    prob_weight=1.0):
        """
        Beam search, decode nbest utterances
        args:
            encoder_padded_outputs: B x T x H
            beam_width: int
            nbest: int
        output:
            batch_ids_nbest_hyps: flat list of hypotheses in ids
                (nbest entries per utterance)
            batch_strs_nbest_hyps: flat list of hypotheses in strings
                (nbest entries per utterance)
        """
        batch_size = encoder_padded_outputs.size(0)
        max_len = encoder_padded_outputs.size(1)

        batch_ids_nbest_hyps = []
        batch_strs_nbest_hyps = []

        for x in range(batch_size):
            encoder_output = encoder_padded_outputs[x].unsqueeze(
                0)  # 1 x T x H

            # add SOS_TOKEN
            ys = torch.ones(1, 1).fill_(
                constant.SOS_TOKEN).type_as(encoder_output).long()

            hyp = {'score': 0.0, 'yseq': ys}
            hyps = [hyp]
            ended_hyps = []
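            # Each hypothesis is a dict: 'yseq' holds the partial token
            # sequence (1 x len, starting with SOS) and 'score' its accumulated
            # log-probability; finished hypotheses also get 'final_score'
            # (plus 'lm_score' and 'num_words' when LM rescoring is enabled).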

            # decode for at most 300 steps (self.trg_max_length could be used instead)
            for i in range(300):
                hyps_best_kept = []
                for hyp in hyps:
                    ys = hyp['yseq']  # 1 x i

                    # Prepare masks
                    non_pad_mask = torch.ones_like(ys).float().unsqueeze(
                        -1)  # 1xix1
                    self_attn_mask = get_subsequent_mask(ys)

                    decoder_output = self.dropout(
                        self.trg_embedding(ys) * self.x_logit_scale +
                        self.positional_encoding(ys))

                    for layer in self.layers:
                        # print(decoder_output.size(), encoder_output.size())
                        decoder_output, _, _ = layer(
                            decoder_output,
                            encoder_output,
                            non_pad_mask=non_pad_mask,
                            self_attn_mask=self_attn_mask,
                            dec_enc_attn_mask=None)

                    seq_logit = self.output_linear(decoder_output[:, -1])
                    local_scores = F.log_softmax(seq_logit, dim=1)
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam_width, dim=1)

                    # calculate beam scores
                    for j in range(beam_width):
                        new_hyp = {}
                        new_hyp["score"] = hyp["score"] + local_best_scores[0,
                                                                            j]

                        new_hyp["yseq"] = torch.ones(
                            1,
                            (1 + ys.size(1))).type_as(encoder_output).long()
                        new_hyp["yseq"][:, :ys.size(1)] = hyp["yseq"].cpu()
                        new_hyp["yseq"][:, ys.size(1)] = int(
                            local_best_ids[0, j])  # adding new word

                        hyps_best_kept.append(new_hyp)

                    hyps_best_kept = sorted(hyps_best_kept,
                                            key=lambda x: x["score"],
                                            reverse=True)[:beam_width]

                hyps = hyps_best_kept
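                # hyps now holds at most beam_width partial hypotheses; trimming
                # inside the parent loop above is equivalent to a single trim
                # over all expansions made in this step.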

                # add EOS_TOKEN
                if i == max_len - 1:
                    for hyp in hyps:
                        hyp["yseq"] = torch.cat([
                            hyp["yseq"],
                            torch.ones(1, 1).fill_(constant.EOS_TOKEN).type_as(
                                encoder_output).long()
                        ],
                                                dim=1)
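                # When step max_len - 1 (the encoder length) is reached, every
                # surviving hypothesis is force-terminated with EOS so that it
                # is collected into ended_hyps below; note the outer loop itself
                # is capped at 300 steps.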

                # add hypothesis that have EOS_TOKEN to ended_hyps list
                unended_hyps = []
                for hyp in hyps:
                    if hyp["yseq"][0, -1] == constant.EOS_TOKEN:
                        if lm_rescoring:
                            # seq_str = "".join(self.id2label[char.item()] for char in hyp["yseq"][0]).replace(constant.PAD_CHAR,"").replace(constant.SOS_CHAR,"").replace(constant.EOS_CHAR,"")
                            # seq_str = seq_str.replace("  ", " ")
                            # num_words = len(seq_str.split())

                            hyp["lm_score"], hyp[
                                "num_words"], oov_token = calculate_lm_score(
                                    hyp["yseq"], lm, self.id2label)
                            num_words = hyp["num_words"]
                            hyp["lm_score"] -= oov_token * 2
                            hyp["final_score"] = hyp["score"] + lm_weight * hyp[
                                "lm_score"] + math.sqrt(num_words) * c_weight
                        else:
                            seq_str = "".join(
                                self.id2label[char.item()]
                                for char in hyp["yseq"][0]).replace(
                                    constant.PAD_CHAR,
                                    "").replace(constant.SOS_CHAR, "").replace(
                                        constant.EOS_CHAR, "")
                            seq_str = seq_str.replace("  ", " ")
                            num_words = len(seq_str.split())
                            hyp["final_score"] = hyp["score"] + math.sqrt(
                                num_words) * c_weight
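                        # In both branches sqrt(num_words) * c_weight acts as a
                        # length bonus so that longer finished hypotheses are
                        # not penalised solely for accumulating more negative
                        # log-probabilities; with LM rescoring, each OOV token
                        # reported by the LM additionally costs 2 points of
                        # lm_score.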

                        ended_hyps.append(hyp)

                    else:
                        unended_hyps.append(hyp)
                hyps = unended_hyps

                if len(hyps) == 0:
                    # decoding process is finished
                    break

            num_nbest = min(len(ended_hyps), nbest)
            nbest_hyps = sorted(ended_hyps,
                                key=lambda x: x["final_score"],
                                reverse=True)[:num_nbest]

            a_nbest_hyps = sorted(ended_hyps,
                                  key=lambda x: x["final_score"],
                                  reverse=True)[:beam_width]

            if lm_rescoring:
                for hyp in a_nbest_hyps:
                    seq_str = "".join(self.id2label[char.item()]
                                      for char in hyp["yseq"][0]).replace(
                                          constant.PAD_CHAR,
                                          "").replace(constant.SOS_CHAR,
                                                      "").replace(
                                                          constant.EOS_CHAR,
                                                          "")
                    seq_str = seq_str.replace("  ", " ")
                    num_words = len(seq_str.split())
                    # print("{}  || final:{} e2e:{} lm:{} num words:{}".format(seq_str, hyp["final_score"], hyp["score"], hyp["lm_score"], hyp["num_words"]))

            for hyp in nbest_hyps:
                hyp["yseq"] = hyp["yseq"][0].cpu().numpy().tolist()
                hyp_strs = self.post_process_hyp(hyp)

                batch_ids_nbest_hyps.append(hyp["yseq"])
                batch_strs_nbest_hyps.append(hyp_strs)
                # print(hyp["yseq"], hyp_strs)
        return batch_ids_nbest_hyps, batch_strs_nbest_hyps

    def greedy_search(self,
                      encoder_padded_outputs,
                      beam_width=2,
                      lm_rescoring=False,
                      lm=None,
                      lm_weight=0.1,
                      c_weight=1):
        """
        Greedy search, decode 1-best utterance
        args:
            encoder_padded_outputs: B x T x H
            beam_width: int, number of candidate tokens scored per step when
                lm_rescoring is True
        output:
            sent: list of 1-best decoded strings (size B)
        """
        max_seq_len = self.trg_max_length

        ys = torch.ones(encoder_padded_outputs.size(0),
                        1).fill_(constant.SOS_TOKEN).long()  # batch_size x 1
        if constant.args.cuda:
            ys = ys.cuda()

        decoded_words = []
        # decode for at most 300 steps (max_seq_len could be used instead)
        for t in range(300):
            # Prepare masks
            non_pad_mask = torch.ones_like(ys).float().unsqueeze(
                -1)  # batch_size x t x 1
            self_attn_mask = get_subsequent_mask(ys)  # batch_size x t x t

            decoder_output = self.dropout(
                self.trg_embedding(ys) * self.x_logit_scale +
                self.positional_encoding(ys))

            for layer in self.layers:
                decoder_output, _, _ = layer(decoder_output,
                                             encoder_padded_outputs,
                                             non_pad_mask=non_pad_mask,
                                             self_attn_mask=self_attn_mask,
                                             dec_enc_attn_mask=None)
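            # The whole prefix ys is re-embedded and re-decoded at every step
            # (no key/value caching), and the loop always runs its full number
            # of steps; EOS is only handled during the string post-processing
            # below.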

            prob = self.output_linear(
                decoder_output)  # batch_size x t x label_size
            # _, next_word = torch.max(prob[:, -1], dim=1)
            # decoded_words.append([constant.EOS_CHAR if ni.item() == constant.EOS_TOKEN else self.id2label[ni.item()] for ni in next_word.view(-1)])
            # next_word = next_word.unsqueeze(-1)

            # local_best_scores, local_best_ids = torch.topk(local_scores, beam_width, dim=1)

            if lm_rescoring:
                # Pick the next token among the beam_width most probable
                # candidates for the last step, combining the decoder score
                # with an external LM score (this branch assumes batch size 1).
                local_scores = F.log_softmax(prob[:, -1], dim=1)
                local_best_scores, local_best_ids = torch.topk(local_scores,
                                                               beam_width,
                                                               dim=1)

                best_score = -float('inf')
                best_word = None

                # calculate beam scores
                for j in range(beam_width):
                    cur_seq = " ".join(words[0] for words in decoded_words)
                    lm_score, num_words, oov_token = calculate_lm_score(
                        cur_seq, lm, self.id2label)
                    score = local_best_scores[0, j] + lm_score
                    if best_score < score:
                        best_score = score
                        best_word = local_best_ids[0, j]
                        next_word = best_word.view(1, 1)
                decoded_words.append([self.id2label[int(best_word)]])
            else:
                _, next_word = torch.max(prob[:, -1], dim=1)
                decoded_words.append([
                    constant.EOS_CHAR if ni.item() == constant.EOS_TOKEN else
                    self.id2label[ni.item()] for ni in next_word.view(-1)
                ])
                next_word = next_word.unsqueeze(-1)

            if constant.args.cuda:
                ys = torch.cat([ys, next_word.cuda()], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, next_word], dim=1)

        sent = []
        for row in np.transpose(decoded_words):
            st = ''
            for e in row:
                if e == constant.EOS_CHAR:
                    break
                else:
                    st += e
            sent.append(st)
        return sent