Example #1
0
class TransformerDecoder(nn.Module):
    """Transformer attention decoder with optional auxiliary CTC / KD losses.

    Token embeddings plus positional encoding pass through
    ``params.dec_num_layers`` ``TransformerDecoderLayer`` blocks that attend
    to the encoder output. A final LayerNorm and a linear projection then
    produce ``params.vocab_size`` logits.

    Args:
        params: hyper-parameter namespace. Every attribute read in
            ``__init__`` must exist, except ``mtl_ctc_add_sos_eos``, which
            defaults to False.
        cmlm: if True, act as a conditional masked LM: the target mask is
            built with ``make_src_mask`` (no causal masking) and
            ``MaskedLMLoss`` is used.
    """

    def __init__(self, params, cmlm=False):
        super(TransformerDecoder, self).__init__()

        self.vocab_size = params.vocab_size
        self.embed = nn.Embedding(self.vocab_size, params.dec_hidden_size)
        self.pe = PositionalEncoder(params.dec_hidden_size,
                                    params.dropout_dec_rate)
        self.dec_num_layers = params.dec_num_layers

        # TODO: rename to `decoders`
        self.transformers = nn.ModuleList()
        for _ in range(self.dec_num_layers):
            self.transformers += [
                TransformerDecoderLayer(
                    dec_num_attention_heads=params.dec_num_attention_heads,
                    dec_hidden_size=params.dec_hidden_size,
                    dec_intermediate_size=params.dec_intermediate_size,
                    dropout_dec_rate=params.dropout_dec_rate,
                    dropout_attn_rate=params.dropout_attn_rate,
                )
            ]

        # Auxiliary CTC branch (multi-task learning)
        self.mtl_ctc_weight = params.mtl_ctc_weight
        if self.mtl_ctc_weight > 0:
            self.ctc = CTCDecoder(params)
        # Optional config flag. Always define the attribute: `forward` reads
        # it whenever mtl_ctc_weight > 0 and would otherwise raise
        # AttributeError for configs that do not set it.
        self.mtl_ctc_add_sos_eos = getattr(params, "mtl_ctc_add_sos_eos",
                                           False)

        # normalize before: LayerNorm applied to the last layer's output
        # TODO: set `eps` to 1e-5 (default)
        self.norm = nn.LayerNorm(params.dec_hidden_size, eps=1e-12)
        self.output = nn.Linear(params.dec_hidden_size, self.vocab_size)

        self.cmlm = cmlm
        if self.cmlm:
            # TODO: label smoothing
            self.loss_fn = MaskedLMLoss(vocab_size=self.vocab_size)
        else:
            self.loss_fn = LabelSmoothingLoss(
                vocab_size=self.vocab_size,
                lsm_prob=params.lsm_prob,
                normalize_length=params.loss_normalize_length,
                normalize_batch=params.loss_normalize_batch,
            )

        # Knowledge distillation replaces whichever loss was chosen above.
        self.kd_weight = params.kd_weight
        if self.kd_weight > 0:
            self.loss_fn = DistillLoss(
                vocab_size=self.vocab_size,
                soft_label_weight=params.kd_weight,
                lsm_prob=params.lsm_prob,
                normalize_length=params.loss_normalize_length,
                normalize_batch=params.loss_normalize_batch,
            )

        self.blank_id = params.blank_id
        self.eos_id = params.eos_id
        self.max_decode_ylen = params.max_decode_ylen

    def forward(
        self,
        eouts,
        elens,
        eouts_inter=None,
        ys=None,
        ylens=None,
        ys_in=None,
        ys_out=None,  # labels
        soft_labels=None,
        ps=None,
        plens=None,
    ):
        """Teacher-forced decoding and loss computation.

        Args:
            eouts: encoder outputs.
            elens: encoder output lengths (used to build the source mask).
            eouts_inter: unused by this decoder.
            ys: target token ids for the auxiliary CTC loss. ``add_sos_eos``
                is applied to them first when ``mtl_ctc_add_sos_eos`` is set.
            ylens: lengths of ``ys``.
            ys_in: decoder input token ids.
            ys_out: label token ids; labels are ``ylens + 1`` long.
            soft_labels: teacher targets for knowledge distillation.
            ps, plens: unused by this decoder.

        Returns:
            ``logits`` if ``ys_out`` is None, otherwise
            ``(loss, loss_dict, logits)``.
        """
        loss = 0
        loss_dict = {}

        # embedding + positional encoding
        ys_in = self.pe(self.embed(ys_in))
        emask = make_src_mask(elens)

        if self.cmlm:  # Conditional Masked LM: no causal masking
            ymask = make_src_mask(ylens)
        else:
            # decoder inputs are ylens + 1 long (leading <sos>)
            ymask = make_tgt_mask(ylens + 1)

        for layer_id in range(self.dec_num_layers):
            ys_in, ymask, eouts, emask = self.transformers[layer_id](ys_in,
                                                                     ymask,
                                                                     eouts,
                                                                     emask)
        ys_in = self.norm(ys_in)  # normalize before
        logits = self.output(ys_in)

        if ys_out is None:
            return logits

        if self.kd_weight > 0 and soft_labels is not None:
            # NOTE: ys_out (label) have length ylens+1
            loss_att_kd, loss_kd, loss_att = self.loss_fn(
                logits, ys_out, soft_labels, ylens + 1)

            loss += loss_att_kd
            loss_dict["loss_kd"] = loss_kd
            loss_dict["loss_att"] = loss_att
        else:
            if self.cmlm:
                loss_att = self.loss_fn(logits, labels=ys_out, ylens=None)
            else:
                # NOTE: ys_out (label) have length ylens+1
                loss_att = self.loss_fn(logits, ys_out, ylens + 1)

            loss += loss_att
            loss_dict["loss_att"] = loss_att

        if self.mtl_ctc_weight > 0:
            if self.mtl_ctc_add_sos_eos:
                ys, ylens = add_sos_eos(ys, ylens, eos_id=self.eos_id)

            # NOTE: KD is not applied to auxiliary CTC
            loss_ctc, _, _ = self.ctc(eouts=eouts,
                                      elens=elens,
                                      ys=ys,
                                      ylens=ylens,
                                      soft_labels=None)
            loss += self.mtl_ctc_weight * loss_ctc  # auxiliary loss
            loss_dict["loss_ctc"] = loss_ctc

        loss_dict["loss_total"] = loss

        return loss, loss_dict, logits

    def forward_one_step(self, ys_in, ylens_in, eouts):
        """Run the decoder on a prefix and return logits for its last position.

        Args:
            ys_in: prefix token ids (including the leading <sos>).
            ylens_in: prefix lengths, used for the causal target mask.
            eouts: encoder outputs (no source mask is applied).

        Returns:
            Vocabulary logits for the last position of the prefix.
        """
        ys_in = self.pe(self.embed(ys_in))
        ymask = make_tgt_mask(ylens_in)

        for layer_id in range(self.dec_num_layers):
            ys_in, ymask, eouts, _ = self.transformers[layer_id](ys_in, ymask,
                                                                 eouts, None)

        ys_in = self.norm(ys_in[:, -1])  # normalize before
        logits = self.output(ys_in)
        return logits

    def decode(
        self,
        eouts,
        elens,
        eouts_inter=None,
        beam_width=1,
        len_weight=0,
        lm=None,
        lm_weight=0,
        decode_ctc_weight=0,
        decode_phone=False,
    ):
        """Beam search decoding (batch size must be 1).

        Optional fusion:
            * ``lm_weight > 0``: adds ``lm_weight`` times the scores from
              ``lm.predict``.
            * ``0 < decode_ctc_weight < 1``: rescores the top
              ``CTC_BEAM_WIDTH_RATIO * beam_width`` candidates with CTC prefix
              scores, interpolated with the attention scores.
            * ``decode_ctc_weight == 1``: greedy CTC decoding only.

        ``eouts_inter`` and ``decode_phone`` are unused.

        Returns:
            ``(hyps, scores, None, None)``. ``hyps`` have <eos> stripped and
            are sorted by score, best first.
        """
        bs = eouts.size(0)
        if decode_ctc_weight == 1:
            print("CTC is used")
            # greedy
            return self.ctc.decode(eouts, elens, beam_width=1)

        assert bs == 1

        # init: <eos> doubles as <sos>
        beam = {
            "hyp": [self.eos_id],
            "score": 0.0,
            "score_ctc": 0.0,
            "ctc_state": None,
            "score_lm": 0.0,
            "lm_state": None,
        }
        if decode_ctc_weight > 0:
            ctc_logits = self.ctc(eouts, elens)
            ctc_log_probs = torch.log_softmax(ctc_logits, dim=-1)

            ctc_scorer = CTCPrefixScorer(
                tensor2np(ctc_log_probs.squeeze(0)),
                blank_id=self.blank_id,
                eos_id=self.eos_id,
            )
            beam["score_ctc"] = 0.0
            beam["ctc_state"] = ctc_scorer.initial_state()
            # number of candidates rescored by CTC per hypothesis
            ctc_beam_width = min(ctc_log_probs.size(2),
                                 int(beam_width * CTC_BEAM_WIDTH_RATIO))
        beams = [beam]

        results = []

        for i in range(self.max_decode_ylen):
            new_beams = []

            for beam in beams:
                ys_in = torch.tensor([beam["hyp"]]).to(eouts.device)
                ylens_in = torch.tensor([i + 1]).to(eouts.device)

                scores_att = torch.log_softmax(self.forward_one_step(
                    ys_in, ylens_in, eouts),
                                               dim=-1)  # (1, vocab)
                scores = scores_att

                if lm_weight > 0:
                    scores_lm, _ = lm.predict(ys_in, ylens_in,
                                              states=None)  # (1, vocab)
                    scores_lm = scores_lm[:, :self.vocab_size]
                    # Out-of-place add: an in-place `scores += ...` would
                    # mutate `scores_att` through aliasing, so the CTC
                    # rescoring below would count the LM score twice.
                    scores = scores_att + lm_weight * scores_lm

                if decode_ctc_weight > 0:
                    score_ctc_prev = beam["score_ctc"]
                    ctc_state_prev = beam["ctc_state"]
                    scores_topb, v_topb = torch.topk(scores,
                                                     k=ctc_beam_width,
                                                     dim=1)
                    scores_ctc, ctc_state = ctc_scorer(beam["hyp"], v_topb[0],
                                                       ctc_state_prev)
                    # re-calculate score: interpolate attention and the
                    # incremental CTC prefix score of each candidate
                    scores = (1 - decode_ctc_weight) * scores_att[:, v_topb[
                        0]] + decode_ctc_weight * np2tensor(scores_ctc -
                                                            score_ctc_prev)
                    if lm_weight > 0:
                        scores += lm_weight * scores_lm[:, v_topb[0]]
                    scores_topk, ids_topk = torch.topk(scores,
                                                       k=beam_width,
                                                       dim=1)
                    # map positions within the candidate set to token ids
                    v_topk = v_topb[:, ids_topk[0]]
                else:
                    scores_topk, v_topk = torch.topk(scores,
                                                     k=beam_width,
                                                     dim=1)

                for j in range(beam_width):
                    new_beam = {}
                    new_beam["score"] = beam["score"] + float(scores_topk[0,
                                                                          j])
                    new_beam["hyp"] = beam["hyp"] + [int(v_topk[0, j])]
                    if decode_ctc_weight > 0:
                        new_beam["score_ctc"] = scores_ctc[ids_topk[0, j]]
                        new_beam["ctc_state"] = ctc_state[ids_topk[0, j]]
                    new_beams.append(new_beam)

            # update `beams`
            beams = sorted(new_beams, key=lambda x: x["score"],
                           reverse=True)[:beam_width]

            beams_extend = []
            for beam in beams:
                # ended beams
                if beam["hyp"][-1] == self.eos_id:
                    hyp_noeos = strip_eos(beam["hyp"], self.eos_id)
                    # only <eos> is not acceptable
                    if len(hyp_noeos) < 1:
                        continue

                    # add length penalty
                    score = beam["score"] + len_weight * len(beam["hyp"])

                    results.append({"hyp": hyp_noeos, "score": score})

                    if len(results) >= beam_width:
                        break
                else:
                    beams_extend.append(beam)

            if len(results) >= beam_width:
                break

            beams = beams_extend

        results = sorted(results, key=lambda x: x["score"], reverse=True)
        hyps = [result["hyp"] for result in results]
        scores = [result["score"] for result in results]
        logits = None
        aligns = None

        return hyps, scores, logits, aligns
Example #2
0
class LASDecoder(nn.Module):
    """Listen-Attend-Spell style LSTM attention decoder.

    Each step runs stacked ``nn.LSTMCell`` layers on ``[y_emb; ctx]``
    (recurrency). ``AttentionLoc`` computes a new context from the first
    layer's output (score). A tanh intermediate layer over ``[ctx; top
    output]`` follows, then a vocabulary projection (generate).

    Args:
        params: hyper-parameter namespace (attributes read in ``__init__``).
        phase: unused by this decoder.
    """

    def __init__(self, params, phase="train"):
        super().__init__()

        self.enc_hidden_size = params.enc_hidden_size
        self.dec_hidden_size = params.dec_hidden_size
        self.dec_num_layers = params.dec_num_layers
        self.mtl_ctc_weight = params.mtl_ctc_weight
        if self.mtl_ctc_weight > 0:
            self.ctc = CTCDecoder(params)

        self.embed = nn.Embedding(params.vocab_size, params.embedding_size)
        self.dropout_emb = nn.Dropout(p=params.dropout_dec_rate)

        # Recurrency: the first layer consumes [embedding; context]
        self.rnns = nn.ModuleList()
        input_size = params.embedding_size + params.enc_hidden_size
        for _ in range(self.dec_num_layers):
            self.rnns += [nn.LSTMCell(
                input_size,
                params.dec_hidden_size,
            )]
            input_size = params.dec_hidden_size

        # Score
        self.score = AttentionLoc(key_dim=params.enc_hidden_size,
                                  query_dim=params.dec_hidden_size,
                                  attn_dim=params.attn_dim)

        # Generate
        self.intermed = nn.Linear(
            params.enc_hidden_size + params.dec_hidden_size,
            params.dec_intermediate_size)
        self.output = nn.Linear(params.dec_intermediate_size,
                                params.vocab_size)

        self.dropout = nn.Dropout(p=params.dropout_dec_rate)

        self.loss_fn = LabelSmoothingLoss(
            vocab_size=params.vocab_size,
            lsm_prob=params.lsm_prob,
            normalize_length=params.loss_normalize_length,
            normalize_batch=params.loss_normalize_batch,
        )

        # Knowledge distillation replaces the label-smoothing loss.
        self.kd_weight = params.kd_weight
        if self.kd_weight > 0:
            self.loss_fn = DistillLoss(
                vocab_size=params.vocab_size,
                soft_label_weight=self.kd_weight,
                lsm_prob=params.lsm_prob,
                normalize_length=params.loss_normalize_length,
                normalize_batch=params.loss_normalize_batch,
            )

        self.eos_id = params.eos_id
        self.max_decode_ylen = params.max_decode_ylen

    def forward(
        self,
        eouts,
        elens,
        eouts_inter=None,
        ys=None,
        ylens=None,
        ys_in=None,
        ys_out=None,  # labels
        soft_labels=None,
        ps=None,
        plens=None,
    ):
        """Teacher-forced decoding and loss computation.

        Args:
            eouts: encoder outputs.
            elens: encoder output lengths (used for the attention mask).
            eouts_inter: unused by this decoder.
            ys, ylens: target token ids and lengths (auxiliary CTC loss).
            ys_in: decoder input token ids.
            ys_out: label token ids; labels are ``ylens + 1`` long.
            soft_labels: teacher targets for knowledge distillation.
            ps, plens: unused by this decoder.

        Returns:
            ``(loss, loss_dict, logits)``.
        """
        loss = 0
        loss_dict = {}

        bs = eouts.size(0)

        ys_emb = self.dropout_emb(self.embed(ys_in))
        dstate = None
        # context vector (zeros before the first attention step)
        ctx = eouts.new_zeros(bs, 1, self.enc_hidden_size)
        attn_weight = None
        attn_mask = make_nopad_mask(elens).unsqueeze(2)
        logits = []

        for i in range(ys_in.size(1)):
            y_emb = ys_emb[:, i:i + 1]  # (bs, 1, embedding_size)
            logit, ctx, dstate, attn_weight = self.forward_one_step(
                y_emb, ctx, eouts, dstate, attn_weight, attn_mask)
            logits.append(logit)  # (bs, 1, dec_intermediate_size)

        logits = self.output(torch.cat(logits, dim=1))  # (bs, ylen, vocab)

        if self.kd_weight > 0 and soft_labels is not None:
            # NOTE: ys_out (label) have length ylens+1
            loss_att_kd, loss_kd, loss_att = self.loss_fn(
                logits, ys_out, soft_labels, ylens + 1)
            loss += loss_att_kd
            loss_dict["loss_kd"] = loss_kd
            loss_dict["loss_att"] = loss_att
        else:
            loss_att = self.loss_fn(logits, ys_out, ylens + 1)
            loss += loss_att
            loss_dict["loss_att"] = loss_att

        if self.mtl_ctc_weight > 0:
            # NOTE: KD is not applied to auxiliary CTC
            loss_ctc, _, _ = self.ctc(eouts=eouts,
                                      elens=elens,
                                      ys=ys,
                                      ylens=ylens,
                                      soft_labels=None)
            loss += self.mtl_ctc_weight * loss_ctc  # auxiliary loss
            loss_dict["loss_ctc"] = loss_ctc

        loss_dict["loss_total"] = loss

        return loss, loss_dict, logits

    def forward_one_step(self,
                         y_emb,
                         ctx,
                         eouts,
                         dstate,
                         attn_weight,
                         attn_mask=None):
        """One decoder step: Recurrency -> Score -> Generate.

        Returns:
            ``(logit, ctx, dstate, attn_weight)``. ``logit`` is the tanh
            intermediate output, before ``self.output`` is applied.
        """
        dstate, douts_1, douts_top = self.recurrency(
            torch.cat([y_emb, ctx], dim=-1), dstate)
        # the attention query is the first LSTM layer's output
        ctx, attn_weight = self.score(eouts, eouts, douts_1, attn_weight,
                                      attn_mask)
        logit = self.generate(ctx, douts_top)
        return logit, ctx, dstate, attn_weight

    def recurrency(self, dins, dstate=None):
        """Run the stacked LSTM cells for a single time step.

        Args:
            dins: step input; ``squeeze(1)`` removes its time axis.
            dstate: ``{"hs", "cs"}`` tensors of shape
                (dec_num_layers, bs, dec_hidden_size), or None for zeros.

        Returns:
            ``(new_dstate, douts_1, douts_top)``. These are the new state,
            the first-layer output and the top-layer output; both outputs are
            after dropout and have a time axis of length 1.
        """
        bs = dins.size(0)
        douts = dins.squeeze(1)

        if dstate is None:
            dstate = {}
            dstate["hs"] = torch.zeros(self.dec_num_layers,
                                       bs,
                                       self.dec_hidden_size,
                                       device=dins.device)
            dstate["cs"] = torch.zeros(self.dec_num_layers,
                                       bs,
                                       self.dec_hidden_size,
                                       device=dins.device)

        new_hs, new_cs = [], []
        for layer_id in range(self.dec_num_layers):
            h, c = self.rnns[layer_id](
                douts, (dstate["hs"][layer_id], dstate["cs"][layer_id]))
            new_hs.append(h)
            new_cs.append(c)
            douts = self.dropout(h)
            if layer_id == 0:
                douts_1 = douts.unsqueeze(1)

        new_dstate = {}
        new_dstate["hs"] = torch.stack(new_hs, dim=0)
        new_dstate["cs"] = torch.stack(new_cs, dim=0)

        douts_top = douts.unsqueeze(1)

        return new_dstate, douts_1, douts_top

    def generate(self, ctx, douts):
        """Return ``tanh(intermed([ctx; douts]))``."""
        out = self.intermed(torch.cat([ctx, douts], dim=-1))

        return torch.tanh(out)

    def decode(
        self,
        eouts,
        elens,
        eouts_inter=None,
        beam_width=1,
        len_weight=0,
        lm=None,
        lm_weight=0,
        decode_ctc_weight=0,
        decode_phone=False,
    ):
        """Beam search decoding (batch size must be 1).

        ``decode_ctc_weight == 1`` uses greedy CTC decoding. Joint
        attention/CTC decoding (other positive weights) is not implemented
        and raises NotImplementedError. ``lm``/``lm_weight`` are currently
        ignored. ``eouts_inter`` and ``decode_phone`` are unused.

        Returns:
            ``(hyps, scores, None, None)``. ``hyps`` have <eos> stripped and
            are sorted by score, best first.

        Raises:
            NotImplementedError: if ``decode_ctc_weight`` is positive and
                not 1.
        """
        bs = eouts.size(0)
        if decode_ctc_weight == 1:
            print("CTC is used")
            # greedy
            return self.ctc.decode(eouts, elens, beam_width=1)

        if decode_ctc_weight > 0:
            # The former placeholder left `scores_topk` unbound and crashed
            # with UnboundLocalError inside the search loop.
            raise NotImplementedError(
                "LASDecoder.decode supports decode_ctc_weight of 0 or 1 only")

        assert bs == 1

        # init: <eos> doubles as <sos>
        beam = {
            "hyp": [self.eos_id],
            "dstate": None,
            "score": 0.0,
            "las_ctx": eouts.new_zeros(bs, 1, self.enc_hidden_size),
            "las_dstate": None,
            "las_attn_weight": None,
            "score_ctc": 0.0,
            "ctc_state": None,
            "score_lm": 0.0,
            "lm_state": None,
        }
        beams = [beam]

        results = []

        for i in range(self.max_decode_ylen):
            new_beams = []

            for beam in beams:
                # only the last token is fed; history lives in the state
                y_in = torch.tensor([[beam["hyp"][-1]]]).to(eouts.device)
                y_emb = self.dropout_emb(self.embed(y_in))
                ctx = beam["las_ctx"]
                dstate = beam["las_dstate"]
                attn_weight = beam["las_attn_weight"]

                logit, ctx, dstate, attn_weight = self.forward_one_step(
                    y_emb, ctx, eouts, dstate, attn_weight)
                logit = self.output(logit)

                scores_att = torch.log_softmax(logit.squeeze(0),
                                               dim=-1)  # (1, vocab)
                scores = scores_att

                # NOTE: LM shallow fusion is not implemented; `lm` and
                # `lm_weight` are ignored.
                scores_topk, v_topk = torch.topk(scores,
                                                 k=beam_width,
                                                 dim=1)

                for j in range(beam_width):
                    new_beam = {}
                    new_beam["score"] = beam["score"] + float(scores_topk[0,
                                                                          j])
                    new_beam["hyp"] = beam["hyp"] + [int(v_topk[0, j])]
                    # all children share the parent's post-step state
                    new_beam["las_ctx"] = ctx
                    new_beam["las_dstate"] = dstate
                    new_beam["las_attn_weight"] = attn_weight
                    new_beams.append(new_beam)

            # update `beams`
            beams = sorted(new_beams, key=lambda x: x["score"],
                           reverse=True)[:beam_width]

            beams_extend = []
            for beam in beams:
                # ended beams
                if beam["hyp"][-1] == self.eos_id:
                    hyp_noeos = strip_eos(beam["hyp"], self.eos_id)
                    # only <eos> is not acceptable
                    if len(hyp_noeos) < 1:
                        continue

                    # add length penalty
                    score = beam["score"] + len_weight * len(beam["hyp"])

                    results.append({"hyp": hyp_noeos, "score": score})

                    if len(results) >= beam_width:
                        break
                else:
                    beams_extend.append(beam)

            if len(results) >= beam_width:
                break

            beams = beams_extend

        results = sorted(results, key=lambda x: x["score"], reverse=True)
        hyps = [result["hyp"] for result in results]
        scores = [result["score"] for result in results]
        logits = None
        aligns = None

        return hyps, scores, logits, aligns
Example #3
0
class RNNTDecoder(nn.Module):
    """RNN-Transducer decoder: LSTM prediction network plus joint network.

    Args:
        params: hyper-parameter namespace (attributes read in ``__init__``).
        phase: ``"train"`` logs the warp_rnnt version and enables the
            knowledge-distillation losses when ``params.kd_weight > 0``.

    Raises:
        ValueError: if KD is enabled in training with an unsupported
            ``params.kd_type`` (only ``"word"`` and ``"align"`` exist).
    """

    def __init__(self, params, phase="train"):
        super(RNNTDecoder, self).__init__()

        self.dec_num_layers = params.dec_num_layers
        self.dec_hidden_size = params.dec_hidden_size
        self.eos_id = params.eos_id
        self.blank_id = params.blank_id
        # cap on emitted tokens in greedy decoding
        self.max_seq_len = 256
        self.mtl_ctc_weight = params.mtl_ctc_weight
        self.kd_weight = params.kd_weight

        # Prediction network (decoder)
        # TODO: -> class
        self.embed = nn.Embedding(params.vocab_size, params.embedding_size)
        self.dropout_emb = nn.Dropout(p=params.dropout_emb_rate)
        self.dropout = nn.Dropout(p=params.dropout_dec_rate)

        self.rnns = nn.ModuleList()
        input_size = params.embedding_size
        for _ in range(self.dec_num_layers):
            self.rnns += [
                nn.LSTM(
                    input_size=input_size,
                    hidden_size=params.dec_hidden_size,
                    num_layers=1,
                    batch_first=True,
                )
            ]
            input_size = params.dec_hidden_size

        # Joint network
        # TODO: -> class
        self.w_enc = nn.Linear(params.enc_hidden_size,
                               params.joint_hidden_size)
        self.w_dec = nn.Linear(params.dec_hidden_size,
                               params.joint_hidden_size)
        self.output = nn.Linear(params.joint_hidden_size, params.vocab_size)

        if self.mtl_ctc_weight > 0:
            self.ctc = CTCDecoder(params)

        if phase == "train":
            logging.info(f"warp_rnnt version: {warp_rnnt.__version__}")

        if self.kd_weight > 0 and phase == "train":
            self.kd_type = params.kd_type
            self.reduce_main_loss_kd = params.reduce_main_loss_kd
            if self.kd_type == "word":
                self.transducer_kd_loss = RNNTWordDistillLoss()
            elif self.kd_type == "align":
                self.transducer_kd_loss = RNNTAlignDistillLoss()

                # cuda init only if forced aligner is used
                from asr.modeling.decoders.rnnt_aligner import \
                    RNNTForcedAligner

                self.forced_aligner = RNNTForcedAligner(blank_id=self.blank_id)
            else:
                # Fail fast: otherwise `forward` would hit an unbound
                # `loss_kd` on the first batch with soft labels.
                raise ValueError(f"Unsupported kd_type: {self.kd_type}")

    def forward(
        self,
        eouts,
        elens,
        eouts_inter=None,
        ys=None,
        ylens=None,
        ys_in=None,
        ys_out=None,
        soft_labels=None,
        ps=None,
        plens=None,
    ):
        """Compute the RNN-T loss (plus optional auxiliary CTC and KD losses).

        Args:
            eouts: encoder outputs.
            elens: encoder output lengths.
            eouts_inter: unused by this decoder.
            ys: target token ids; ``ys_in`` must be one token longer.
            ylens: lengths of ``ys``.
            ys_in: prediction-network input token ids.
            ys_out: unused by this decoder.
            soft_labels: teacher targets for knowledge distillation.
            ps, plens: unused by this decoder.

        Returns:
            ``(loss, loss_dict, logits)``; logits are (B, T, L + 1, vocab).
        """
        loss = 0
        loss_dict = {}

        # Prediction network
        douts, _ = self.recurrency(ys_in, dstate=None)

        # Joint network
        logits = self.joint(eouts, douts)  # (B, T, L + 1, vocab)
        log_probs = torch.log_softmax(logits, dim=-1)
        assert log_probs.size(2) == ys.size(1) + 1

        # NOTE: rnnt_loss only accepts ys, elens, ylens with torch.int
        loss_rnnt = warp_rnnt.rnnt_loss(
            log_probs,
            ys.int(),
            elens.int(),
            ylens.int(),
            average_frames=False,
            reduction="mean",
            blank=self.blank_id,
            gather=False,
        )
        loss += loss_rnnt  # main loss
        loss_dict["loss_rnnt"] = loss_rnnt

        if self.mtl_ctc_weight > 0:
            # NOTE: KD is not applied to auxiliary CTC
            loss_ctc, _, _ = self.ctc(eouts=eouts,
                                      elens=elens,
                                      ys=ys,
                                      ylens=ylens,
                                      soft_labels=None)
            loss += self.mtl_ctc_weight * loss_ctc  # auxiliary loss
            loss_dict["loss_ctc"] = loss_ctc

        if self.kd_weight > 0 and soft_labels is not None:
            # kd_type was validated in __init__ to be "word" or "align"
            if self.kd_type == "word":
                loss_kd = self.transducer_kd_loss(logits, soft_labels, elens,
                                                  ylens)
            elif self.kd_type == "align":
                aligns = self.forced_aligner(log_probs, elens, ys, ylens)
                loss_kd = self.transducer_kd_loss(logits, ys, soft_labels,
                                                  aligns, elens, ylens)

            loss_dict["loss_kd"] = loss_kd

            if self.reduce_main_loss_kd:
                loss = (1 - self.kd_weight) * loss + self.kd_weight * loss_kd
            else:
                loss += self.kd_weight * loss_kd

        loss_dict["loss_total"] = loss

        return loss, loss_dict, logits

    def joint(self, eouts, douts):
        """Joint network: ``output(tanh(w_enc(eouts) + w_dec(douts)))``.

        Encoder (B, T, ·) and decoder (B, L, ·) outputs are broadcast to
        (B, T, L, vocab).
        """
        eouts = eouts.unsqueeze(2)  # (B, T, 1, enc_hidden_size)
        douts = douts.unsqueeze(1)  # (B, 1, L, dec_hidden_size)

        out = torch.tanh(self.w_enc(eouts) + self.w_dec(douts))
        out = self.output(out)  # (B, T, L, vocab)

        return out

    def recurrency(self, ys_in, dstate):
        """Prediction network.

        Args:
            ys_in: (B, L) token ids.
            dstate: ``{"hs", "cs"}`` tensors of shape
                (dec_num_layers, B, dec_hidden_size), or None for zeros.

        Returns:
            ``(douts, new_dstate)``. ``douts`` is the top LSTM layer's output
            after dropout.
        """
        ys_emb = self.dropout_emb(self.embed(ys_in))
        bs = ys_emb.size(0)

        if dstate is None:
            dstate = {}
            dstate["hs"] = torch.zeros(self.dec_num_layers,
                                       bs,
                                       self.dec_hidden_size,
                                       device=ys_in.device)
            dstate["cs"] = torch.zeros(self.dec_num_layers,
                                       bs,
                                       self.dec_hidden_size,
                                       device=ys_in.device)

        new_hs, new_cs = [], []
        for layer_id in range(self.dec_num_layers):
            self.rnns[layer_id].flatten_parameters()

            ys_emb, (h, c) = self.rnns[layer_id](
                ys_emb,
                hx=(
                    dstate["hs"][layer_id:layer_id +
                                 1],  # (1, B, dec_hidden_size)
                    dstate["cs"][layer_id:layer_id + 1],
                ),
            )
            new_hs.append(h)
            new_cs.append(c)
            ys_emb = self.dropout(ys_emb)

        new_dstate = {}
        new_dstate["hs"] = torch.cat(new_hs, dim=0)
        new_dstate["cs"] = torch.cat(new_cs, dim=0)

        return ys_emb, new_dstate

    def _greedy(self, eouts, elens, decode_ctc_weight=0):
        """Greedy decoding.

        Blank advances one encoder frame. A non-blank token is emitted and
        fed back into the prediction network. Decoding of an utterance stops
        early once more than ``max_seq_len`` tokens are emitted.

        Returns:
            ``(hyps, scores, logits, aligns)``. ``scores`` is a list of None
            (TODO), ``logits`` is None, and ``aligns`` holds every emitted
            symbol per utterance, including blanks. If
            ``decode_ctc_weight == 1``, the CTC greedy result is returned
            instead.
        """
        if decode_ctc_weight == 1:
            # greedy
            return self.ctc.decode(eouts, elens, beam_width=1)

        bs = eouts.size(0)

        hyps = []
        scores = []
        logits = None  # TODO
        aligns = []

        for b in range(bs):
            hyp = []
            align = []

            ys = eouts.new_zeros((1, 1),
                                 dtype=torch.long).fill_(self.eos_id)  # <sos>
            dout, dstate = self.recurrency(ys, None)

            T = elens[b]

            t = 0
            while t < T:
                out = self.joint(eouts[b:b + 1, t:t + 1],
                                 dout)  # (B, 1, 1, vocab_size)
                new_ys = out.squeeze(2).argmax(-1)
                token_id = new_ys[0].item()

                align.append(token_id)

                if token_id == self.blank_id:
                    t += 1
                else:
                    hyp.append(token_id)
                    dout, dstate = self.recurrency(new_ys, dstate)
                if len(hyp) > self.max_seq_len:
                    break

            hyps.append(hyp)
            # TODO
            scores.append(None)
            aligns.append(align)

        return hyps, scores, logits, aligns

    def _beam_search(self,
                     eouts,
                     elens,
                     beam_width=1,
                     len_weight=0,
                     lm=None,
                     lm_weight=0):
        """ Beam search decoding (batch size must be 1)

        Each frame allows up to ``NUM_EXPANDS - 1`` non-blank expansions per
        hypothesis. Identical hypotheses are merged by log-sum-exp of their
        scores. ``len_weight``, ``lm`` and ``lm_weight`` are currently unused.

        Returns:
            list of hypotheses (token ids without the leading <sos>).

        Reference:
            ALIGNMENT-LENGTH SYNCHRONOUS DECODING FOR RNN TRANSDUCER
            https://ieeexplore.ieee.org/document/9053040
        """
        bs = eouts.size(0)
        assert bs == 1
        NUM_EXPANDS = 3

        def by_score(beams_):
            return sorted(beams_, key=lambda x: x["score"], reverse=True)

        # init. Invariant: "dstate" is the prediction-network state *before*
        # consuming hyp[-1].
        beam = {
            "hyp": [self.eos_id],  # <sos>
            "score": 0.0,
            "score_asr": 0.0,
            "dstate": {
                "hs":
                torch.zeros(self.dec_num_layers,
                            bs,
                            self.dec_hidden_size,
                            device=eouts.device),
                "cs":
                torch.zeros(self.dec_num_layers,
                            bs,
                            self.dec_hidden_size,
                            device=eouts.device)
            }
        }
        beams = [beam]

        # time synchronous decoding
        for t in range(eouts.size(1)):
            new_beams = []  # A
            beams_v = beams[:]  # C <- B

            for v in range(NUM_EXPANDS):
                new_beams_v = []  # D

                # prediction network (batched over the current hypotheses)
                ys = torch.zeros((len(beams_v), 1),
                                 dtype=torch.int64,
                                 device=eouts.device)
                for i, beam in enumerate(beams_v):
                    ys[i] = beam["hyp"][-1]
                dstates_prev = {
                    "hs":
                    torch.cat([beam["dstate"]["hs"] for beam in beams_v],
                              dim=1),
                    "cs":
                    torch.cat([beam["dstate"]["cs"] for beam in beams_v],
                              dim=1)
                }
                douts, dstates = self.recurrency(ys, dstates_prev)

                # joint network
                logits = self.joint(eouts[:, t:t + 1], douts)
                scores_asr = torch.log_softmax(logits.squeeze(2).squeeze(1),
                                               dim=-1)

                # blank expansion: moves on to the next frame
                for i, beam in enumerate(beams_v):
                    blank_score = scores_asr[i, self.blank_id].item()
                    new_beams.append(beam.copy())
                    new_beams[-1]["score"] += blank_score
                    new_beams[-1]["score_asr"] += blank_score
                    # NOTE: do not update `dstate`

                # the blank copies above keep the old state object; only the
                # beams expanded with non-blank tokens get the new state
                for i, beam in enumerate(beams_v):
                    beams_v[i]["dstate"] = {
                        "hs": dstates["hs"][:, i:i + 1],
                        "cs": dstates["cs"][:, i:i + 1]
                    }

                # non-blank expansion
                if v < NUM_EXPANDS - 1:
                    for i, beam in enumerate(beams_v):
                        # Mask blank instead of slicing `[1:]` and adding 1,
                        # which silently assumed blank_id == 0.
                        scores_nb = scores_asr[i].clone()
                        scores_nb[self.blank_id] = float("-inf")
                        scores_topk, v_topk = torch.topk(scores_nb,
                                                         k=beam_width,
                                                         dim=-1,
                                                         largest=True,
                                                         sorted=True)

                        for k in range(beam_width):
                            v_index = v_topk[k].item()
                            new_beams_v.append({
                                "hyp":
                                beam["hyp"] + [v_index],
                                "score":
                                beam["score"] + scores_topk[k].item(),
                                "score_asr":
                                beam["score_asr"] + scores_topk[k].item(),
                                "dout":
                                None,
                                "dstate":
                                beam["dstate"]
                            })

                # Local pruning at each expansion. Merging raises scores
                # (log-sum-exp), so re-sort before truncating.
                new_beams_v = self._merge_rnnt_paths(by_score(new_beams_v))
                beams_v = by_score(new_beams_v)[:beam_width]  # C <- D

            # Local pruning at t-th index
            new_beams = self._merge_rnnt_paths(by_score(new_beams))
            beams = by_score(new_beams)[:beam_width]  # B <- A

        # drop the leading <sos> so results match `_greedy`
        hyps = [beam["hyp"][1:] for beam in beams]

        return hyps

    def decode(
        self,
        eouts,
        elens,
        eouts_inter=None,
        beam_width=1,
        len_weight=0,
        lm=None,
        lm_weight=0,
        decode_ctc_weight=0,
        decode_phone=False,
    ):
        """Decode with greedy search (``beam_width <= 1``) or beam search.

        Returns:
            ``(hyps, scores, logits, aligns)``. The greedy path (including
            CTC greedy when ``decode_ctc_weight == 1``) returns its results
            unchanged. Beam search returns only ``hyps``; the other three
            values are None.
        """
        if beam_width <= 1:
            # pass greedy scores/aligns through instead of discarding them
            return self._greedy(eouts, elens, decode_ctc_weight)

        hyps = self._beam_search(eouts, elens, beam_width, len_weight, lm,
                                 lm_weight)
        return hyps, None, None, None

    @staticmethod
    def _merge_rnnt_paths(beams):
        """Merge beams with identical hypotheses (log-sum-exp of scores).

        The first beam seen for each hypothesis is kept, and its ``score`` is
        updated in place. Input order of first occurrences is preserved.
        """
        merged_beams = {}

        for beam in beams:
            hyp = ints2str(beam["hyp"])
            if hyp in merged_beams:
                merged_beams[hyp]["score"] = np.logaddexp(
                    merged_beams[hyp]["score"], beam["score"])
            else:
                merged_beams[hyp] = beam

        return list(merged_beams.values())