Example #1
    def forward(self, **kwargs):
        """
        encoder_output= "encoder_out": x,
                        "encoded": encoded,
                        "encoder_padding_mask": padding_mask,  # B x T
                        "padding_mask": padding_mask,
        """
        encoder_output = self.encoder(tbc=False, **kwargs)
        alphas = CIFFcModelV2.get_alphas(encoder_output)
        if self.training:
            _alphas, num_output = self.resize(alphas, kwargs['target_lengths'])
            padding_mask = ~utils.sequence_mask(kwargs['target_lengths']).bool()
        else:
            _alphas, num_output = self.resize(alphas)
            padding_mask = ~utils.sequence_mask(torch.round(num_output).int()).bool()

        cif_outputs = self.cif(encoder_output['encoder_out'][:, :, :-1], _alphas)
        hidden = self.proj(cif_outputs)

        if self.training:
            gold_rate = self.set_gold_rate()
            input_ids = kwargs['bert_input'].long()
        else:
            input_ids = None
            gold_rate = 0.0

        bert_output, gold_embedding, pred_mask = self.forward_embeded(
            hidden, padding_mask, input_ids, gold_rate)

        logits = self.final_proj(bert_output)

        return {'logits': logits, 'len_logits': kwargs['target_lengths'],
                'alphas': alphas, 'num_output': num_output,
                'embedding': hidden, 'gold_embedding': gold_embedding, 'pred_mask': pred_mask,
                'gold_rate': gold_rate}
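
Every example here builds its padding mask from a `utils.sequence_mask` helper that is not part of the snippets themselves. A minimal sketch of the assumed semantics (True for valid positions, so `~sequence_mask(lengths)` marks the padded frames), for readers who want to run the code:

import torch

def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
    """Assumed helper: mask[b, t] is True for t < lengths[b] (valid positions)."""
    maxlen = maxlen if maxlen is not None else int(lengths.max())
    steps = torch.arange(maxlen, device=lengths.device)   # (T,)
    mask = steps.unsqueeze(0) < lengths.unsqueeze(1)       # (B, T)
    return mask.to(dtype)

# lengths = torch.tensor([2, 4]); padding_mask = ~sequence_mask(lengths)
# -> tensor([[False, False,  True,  True],
#            [False, False, False, False]])
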
Example #2
    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        model = models[0]

        encoder_output = model.encoder(tbc=False, **sample["net_input"])
        alphas = CIFFcModelV2.get_alphas(encoder_output)
        decode_length = torch.round(alphas.sum(-1)).int()
        _alphas, num_output = model.resize(alphas, decode_length, noise=0.0)

        padding_mask = ~utils.sequence_mask(decode_length).bool()
        cif_outputs = model.cif(encoder_output['encoder_out'][:, :, :-1], _alphas)
        hidden = model.proj(cif_outputs)
        logits_ac = model.to_vocab_ac(hidden)

        infer_threash = self.infer_threshold if self.infer_threshold else model.args.infer_threash
        for i in range(1):
            logits, gold_embedding, pred_mask, token_mask = model.bert_forward(
                hidden, logits_ac, padding_mask, None, 0.0,
                # threash=0.8)
                threash=infer_threash)
            logits = logits_ac + model.args.lambda_lm * logits
        probs = utils.softmax(logits.float(), dim=-1)

        res = []
        for distribution, length in zip(probs, decode_length):
            result = distribution.argmax(-1)
            score = 0.0
            res.append([{'tokens': result[:length],
                         "score": score}])

        return res
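
`model.resize(alphas, decode_length, noise=0.0)` (called as `self.resize(...)` in the other examples) is not shown in these snippets. A minimal sketch of the CIF scaling strategy it presumably implements, i.e. rescaling the fire weights so they sum to the requested output length; the signature and the noise handling are assumptions:

import torch

def resize(alphas, target_lengths=None, noise=0.0):
    """Assumed CIF 'scaling strategy': rescale fire weights so that they sum
    to the desired output length (noise, if any, jitters that length)."""
    num_output = alphas.sum(-1)                    # (B,) predicted output lengths
    if target_lengths is None:                     # inference without a length target
        return alphas, num_output
    target = target_lengths.float()
    if noise > 0:                                  # assumption: small training-time jitter
        target = target + noise * (torch.rand_like(target) - 0.5)
    _alphas = alphas * (target / num_output).unsqueeze(-1)
    return _alphas, num_output
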
Example #3
    def generate(self, models, sample, **unused):
        """Generate a batch of inferences.
        EncoderOut(
            encoder_out=encoder_out['encoder_out'],  # T x B x C
            encoder_embedding=None,
            encoder_padding_mask=encoder_out['encoder_padding_mask'],  # B x T
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )
        """
        encoder_output = models[0].get_encoder_output(sample['net_input'])
        encoder_out = {
            "encoder_out":
            encoder_output.encoder_out.transpose(0, 1),  # B x T x C
            "padding_mask": encoder_output.encoder_padding_mask
        }
        alphas, _ = models[0].assigner(encoder_out)
        # _alphas, num_output = self.resize(alphas, kwargs['target_lengths'], at_least_one=True)
        cif_outputs = models[0].cif(encoder_out, alphas)
        src_lengths = torch.round(alphas.sum(-1)).int()
        self.step_forward_fn = models[0].decode
        encoder_output = EncoderOut(
            encoder_out=cif_outputs.transpose(0, 1),  # T x B x C
            encoder_embedding=None,
            encoder_padding_mask=~utils.sequence_mask(
                src_lengths, dtype=torch.bool),  # B x T
            encoder_states=None,
            src_tokens=None,
            src_lengths=src_lengths,
        )

        return self.decode(encoder_output)
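
`cif(...)` is the continuous integrate-and-fire step shared by all of these models (its exact signature differs slightly between examples). A minimal, loop-based sketch of the standard CIF accumulation with a firing threshold of 1.0; production implementations vectorize this over the batch:

import torch

def cif(hidden, alphas, threshold=1.0):
    """Assumed CIF: integrate encoder frames weighted by alphas and fire one
    output vector each time the accumulated weight reaches the threshold."""
    B, T, C = hidden.size()
    max_len = int(torch.round(alphas.sum(-1)).max())
    out = hidden.new_zeros(B, max_len, C)
    for b in range(B):
        accum_w = hidden.new_zeros(())     # integrated weight so far
        accum_h = hidden.new_zeros(C)      # integrated (weighted) hidden state
        n = 0
        for t in range(T):
            a = alphas[b, t]
            if accum_w + a < threshold:
                accum_w = accum_w + a
                accum_h = accum_h + a * hidden[b, t]
            else:
                fire_w = threshold - accum_w        # part of this frame fires now
                out[b, n] = accum_h + fire_w * hidden[b, t]
                n += 1
                accum_w = a - fire_w                # remainder seeds the next output
                accum_h = accum_w * hidden[b, t]
                if n == max_len:
                    break
    return out                                      # B x N x C, N = rounded sum of alphas
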
Example #4
    def forward(self, **kwargs):
        """
        encoder_output= "encoder_out": x,
                        "encoded": encoded,
                        "encoder_padding_mask": padding_mask,  # B x T
                        "padding_mask": padding_mask,
        """
        encoder_output = self.encoder(tbc=False, **kwargs)
        hidden_encoded = encoder_output['encoder_out'][:, :, :-1]
        hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0)
        logits_ctc = self.to_vocab_ctc(hidden_ctc)
        len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long()
        alphas = CIFFcModelV2.get_alphas(encoder_output)

        if self.training:
            gold_rate = self.set_gold_rate()
            decode_length = kwargs['target_lengths']
            gold_ids = kwargs['bert_input'].long()
            noise = 0.0
        else:
            gold_rate = 0.0
            decode_length = torch.round(alphas.sum(-1)).int()
            gold_ids = None
            noise = 0.0

        _alphas, num_output = self.resize(alphas, decode_length, noise=noise)
        padding_mask = ~utils.sequence_mask(decode_length).bool()
        cif_outputs = self.cif(hidden_encoded, _alphas)
        hidden_ac = self.proj(cif_outputs)
        logits_ac = self.to_vocab_ac(hidden_ac)

        ft = self.freeze_lm_finetune_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            logits_lm, gold_embedding, pred_mask, token_mask = self.bert_forward(
                hidden_ac,
                logits_ac,
                padding_mask,
                gold_ids,
                gold_rate,
                threash=self.args.infer_threash)
        logits = self.args.lambda_am * logits_ac + self.args.lambda_lm * logits_lm
        logits *= (~padding_mask).unsqueeze(-1).float()

        return {
            'logits': logits,
            'len_logits': decode_length,
            'alphas': alphas,
            'num_output': num_output,
            'gold_rate': gold_rate,
            'logits_ctc': logits_ctc,
            'len_logits_ctc': len_logits_ctc,
            'pred_mask': pred_mask[:, 1:-1],
            'token_mask': token_mask[:, 1:-1]
        }
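
The `with torch.no_grad() if not ft else contextlib.ExitStack():` line used here (and in several later examples) relies on `contextlib.ExitStack()` being an empty context manager: the language model is frozen for the first `freeze_lm_finetune_updates` updates and fine-tuned afterwards. A tiny self-contained illustration of the idiom:

import contextlib
import torch

x = torch.ones(1, requires_grad=True)
for num_updates, freeze_until in [(0, 10), (20, 10)]:
    ft = freeze_until <= num_updates
    with torch.no_grad() if not ft else contextlib.ExitStack():
        y = x * 2
    print(num_updates, y.requires_grad)   # 0 -> False (still frozen), 20 -> True (fine-tuning)
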
Example #5
    def forward(self, **kwargs):
        encoder_output = self.w2v_encoder(tbc=False, **kwargs['net_input'])
        hidden_encoded = encoder_output['encoder_out']
        # ctc part
        logits_ctc = self.to_vocab_ctc(hidden_encoded)
        # cif part
        alphas = get_alphas(self.fc_alpha, encoder_output)

        if self.training:
            decode_length = kwargs['target_lengths']
            targets = kwargs['target']
            targets_embs = self.gpt2.transformer.wte(targets.long())
        else:
            decode_length = torch.round(alphas.sum(-1)).int()
            targets = None
            targets_embs = None

        _alphas, num_output = self.resize(alphas, decode_length)
        padding_mask = ~utils.sequence_mask(decode_length).bool()
        cif_outputs = self.cif(hidden_encoded, _alphas).type_as(hidden_encoded)
        logits_ac = self.to_vocab(cif_outputs)

        # gpt2 part
        with torch.no_grad():
            device = cif_outputs.device
            bos_indx = torch.tensor([self.bos_idx]).to(device)
            sos_embeds = self.gpt2.transformer.wte(bos_indx).expand(cif_outputs.size(0), 1, cif_outputs.size(2))
            token_mask = F.pad(~padding_mask, [1, 0, 0, 0], value=0)
            attention_mask = token_mask.int()
            gpt_inputs = torch.cat((sos_embeds, cif_outputs), 1)
            gpt_outputs = self.gpt2(inputs_embeds=gpt_inputs, attention_mask=attention_mask).logits[:, 1:, :]

            logits = self.cfg.lambda_am * logits_ac + self.cfg.lambda_lm * gpt_outputs

        result = {
            "encoder_out": logits_ctc.transpose(0, 1),  # T x B x C
            "padding_mask": encoder_output['padding_mask'],
            "cif_out": logits_ac,  # B x T x C
            "cif_embeds": cif_outputs,
            "targets_embs": targets_embs,
            "len_logits": decode_length,
            "alphas": alphas,
            "num_output": num_output,
            "gpt2_out": gpt_outputs,
            "attention_mask": attention_mask,
            "logits": logits
        }
        return result
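
Example #5 drives GPT-2 directly through `inputs_embeds`, prepending a BOS embedding to the CIF outputs. A self-contained sketch of that Hugging Face call path with a stock `GPT2LMHeadModel`; the checkpoint name and tensor sizes are illustrative only:

import torch
from transformers import GPT2LMHeadModel

gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")        # illustrative checkpoint
wte = gpt2.transformer.wte                            # token embedding table

B, T, C = 2, 5, gpt2.config.n_embd
cif_outputs = torch.randn(B, T, C)                    # stand-in for the CIF embeddings
bos = wte(torch.tensor([gpt2.config.bos_token_id])).expand(B, 1, C)
inputs_embeds = torch.cat([bos, cif_outputs], dim=1)  # prepend the BOS embedding
attention_mask = torch.ones(B, T + 1, dtype=torch.long)

logits = gpt2(inputs_embeds=inputs_embeds,
              attention_mask=attention_mask).logits[:, 1:, :]   # drop the BOS position
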
Example #6
    def forward(self, **kwargs):
        """
        encoder_output= "encoder_out": x,
                        "encoded": encoded,
                        "encoder_padding_mask": padding_mask,  # B x T
                        "padding_mask": padding_mask,
        """
        encoder_output = self.encoder(tbc=False, **kwargs)
        alphas = CIFFcModelV2.get_alphas(encoder_output)
        input_ids = kwargs['bert_input'].long()
        if self.training:
            _alphas, num_output = self.resize(alphas, kwargs['target_lengths'])
            padding_mask = ~utils.sequence_mask(kwargs['target_lengths']).bool()
            gold_rate = self.set_gold_rate()
        else:
            decode_length = kwargs['decode_length']
            # _alphas, num_output = self.resize(alphas)
            # padding_mask = ~utils.sequence_mask(torch.round(num_output).int()).bool()
            _alphas, num_output = self.resize(alphas, decode_length)
            padding_mask = ~utils.sequence_mask(decode_length).bool()
            gold_rate = 0.0

        cif_outputs = self.cif(encoder_output['encoder_out'][:, :, :-1], _alphas)
        hidden = self.proj(cif_outputs)
        logits_ac = self.to_vocab_ac(hidden)

        logits, gold_embedding, pred_mask, token_mask = self.bert_forward(
            hidden, logits_ac, padding_mask, input_ids, gold_rate, threash=self.args.infer_threash)
        # logits = GradMultiply.apply(logits, 0.1)
        logits = logits_ac + 0.1 * logits

        return {'logits': logits, 'len_logits': kwargs['target_lengths'],
                'alphas': alphas, 'num_output': num_output,
                'embedding': hidden, 'gold_embedding': gold_embedding,
                'pred_mask': pred_mask, 'token_mask': token_mask,
                'gold_rate': gold_rate}
Example #7
    def forward(self, **kwargs):
        """
        encoder_output= "encoder_out": x,
                        "encoded": encoded,
                        "encoder_padding_mask": padding_mask,  # B x T
                        "padding_mask": padding_mask,
        """
        encoder_output = self.encoder(tbc=False, **kwargs)
        hidden_encoded = encoder_output['encoder_out'][:, :, :-1]
        hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0)
        logits_ctc = self.to_vocab_ctc(hidden_ctc)
        len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long()
        alphas = get_alphas(encoder_output)
        decode_length = (kwargs['target_lengths'] if self.training
                         else torch.round(alphas.sum(-1)).int())
        padding_mask = ~utils.sequence_mask(decode_length).bool()
        _alphas, num_output = self.resize(alphas, decode_length)

        # if not self.training:
        #     import pdb; pdb.set_trace()
        encoder_out = EncoderOut(
            encoder_out=encoder_output['encoder_out'].transpose(
                0, 1),  # T x B x C
            encoder_embedding=None,
            encoder_padding_mask=encoder_output[
                'encoder_padding_mask'],  # B x T
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )
        prev_output_tokens = torch.ones_like(
            padding_mask) * self.tgt_dict.bos()
        decoder_out = self.decoder(encoder_out=encoder_out,
                                   prev_output_tokens=prev_output_tokens)
        logits = decoder_out["logits"]
        logits *= (~padding_mask).unsqueeze(-1).float()

        return {
            'logits': logits,
            'len_logits': decode_length,
            'alphas': alphas,
            'num_output': num_output,
            'logits_ctc': logits_ctc,
            'len_logits_ctc': len_logits_ctc
        }
Example #8
    def decode(self, encoder_shrunk_out):
        encoded_logits = encoder_shrunk_out["encoded_shrunk"]
        padding_mask = utils.sequence_mask(encoder_shrunk_out["len_encoded_shrunk"],
                                           dtype=torch.bool, reverse=True)
        # prob = torch.softmax(encoded_logits[:, :, :-1], -1)
        ft = self.freeze_lm_finetune_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            # embedded = torch.mm(prob.view(-1, prob.size(-1)),
            #                     self.lm.encoder.encoder.sentence_encoder.embed_tokens.weight[:-1, :]
            #                     ).view(prob.size(0), prob.size(1), -1)
            # embedded = self.proj(encoded_logits)
            # logits = self.lm.forward_embeded(embedded, padding_mask)
            logits = self.proj(encoded_logits)

        logits.batch_first = True

        return logits
Example #9
    def forward(self, **kwargs):
        """
        encoder_output= "encoder_out": x,
                        "encoded": encoded,
                        "encoder_padding_mask": padding_mask,  # B x T
                        "padding_mask": padding_mask,
        """
        encoder_output = self.encoder(tbc=False, **kwargs)
        hidden_encoded = encoder_output['encoder_out'][:, :, :-1]
        hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0)
        logits_ctc = self.to_vocab_ctc(hidden_ctc)
        len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long()
        alphas = CIFFcModelV2.get_alphas(encoder_output)

        if self.training:
            decode_length = kwargs['target_lengths']
        else:
            decode_length = torch.round(alphas.sum(-1)).int()
            decode_length = torch.max(decode_length,
                                      torch.ones_like(decode_length))
        padding_mask = ~utils.sequence_mask(decode_length).bool()
        _alphas, num_output = self.resize(alphas, decode_length)
        cif_outputs = self.cif(hidden_encoded, _alphas)
        hidden_ac = self.proj(cif_outputs)
        logits = self.to_vocab_ac(hidden_ac)
        logits *= (~padding_mask).unsqueeze(-1).float()
        gold_rate = 0.0

        return {
            'logits': logits,
            'len_logits': decode_length,
            'alphas': alphas,
            'num_output': num_output,
            'gold_rate': gold_rate,
            'logits_ctc': logits_ctc,
            'len_logits_ctc': len_logits_ctc
        }
Example #10
    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
        """
        x = self.dropout(self.pe(src_tokens))
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = (1 - utils.sequence_mask(src_lengths)).bool()

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)
        x = self.layer_norm(x)

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_embedding=None,
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )
Example #11
    def forward(self, **kwargs):
        """
        encoder_output= "encoder_out": x,
                        "encoded": encoded,
                        "encoder_padding_mask": padding_mask,  # B x T
                        "padding_mask": padding_mask,
        """
        encoder_output = self.encoder(tbc=False, **kwargs)
        hidden_encoded = encoder_output['encoder_out'][:, :, :-1]
        hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0)
        logits_ctc = self.to_vocab_ctc(hidden_ctc)
        len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long()
        alphas = get_alphas(encoder_output)
        decode_length = (kwargs['target_lengths'] if self.training
                         else torch.round(alphas.sum(-1)).int())
        _, num_output = self.resize(alphas, decode_length)

        padding_mask = ~utils.sequence_mask(decode_length).bool()
        token_mask = ~padding_mask
        mask_ids = torch.ones_like(padding_mask) * self.tgt_dict.bos()
        # emcoded = self.proj(hidden_encoded)
        encoder_out = EncoderOut(
            encoder_out=hidden_ctc.transpose(0, 1),  # T x B x C
            encoder_embedding=None,
            encoder_padding_mask=encoder_output[
                'encoder_padding_mask'],  # B x T
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

        if self.training:
            gold_ids = kwargs['target'].long()
            rand = torch.rand(gold_ids.size(),
                              device=gold_ids.device) * token_mask
            list_pred_mask = []
            for i, l in enumerate(decode_length):
                k = random.randint(1, l)
                list_pred_mask.append(
                    rand[i] >= torch.topk(rand[i], k).values.min())
            pred_mask = torch.stack(list_pred_mask, 0) * token_mask
            gold_mask = ~pred_mask * token_mask
            gold_rate = gold_mask.sum() * 1.0 / token_mask.sum()
            decoder_input_ids = torch.where(pred_mask, mask_ids, gold_ids)
            logits = self.decoder(encoder_out=encoder_out,
                                  prev_output_tokens=decoder_input_ids)
            # import pdb; pdb.set_trace()
        else:
            pred_mask = gold_rate = 0.0
            decoder_input_ids = mask_ids
            for _ in range(10):
                logits = self.decoder(encoder_out=encoder_out,
                                      prev_output_tokens=decoder_input_ids)
                probs, pred_ids = utils.softmax(logits, dim=-1).max(-1)
                gold_mask = probs > 0.9
                decoder_input_ids = torch.where(gold_mask, pred_ids,
                                                mask_ids) * token_mask

        logits *= token_mask.unsqueeze(-1).float()

        return {
            'logits': logits,
            'len_logits': decode_length,
            'gold_rate': gold_rate,
            'alphas': alphas,
            'num_output': num_output,
            'pred_mask': pred_mask,
            'logits_ctc': logits_ctc,
            'len_logits_ctc': len_logits_ctc
        }
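
The training branch above picks a random number k of positions per utterance to mask, using a `topk` trick on uniform noise. A tiny standalone illustration of what `rand[i] >= torch.topk(rand[i], k).values.min()` computes:

import torch

torch.manual_seed(0)
rand = torch.rand(6)          # scores for one utterance with 6 valid positions
k = 3                         # mask three positions in this training step
pred_mask = rand >= torch.topk(rand, k).values.min()
# True exactly at the k positions with the largest scores, i.e. k positions
# chosen uniformly at random; the remaining positions keep their gold tokens.
print(pred_mask.sum())        # tensor(3)
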
Example #12
    def forward(self, **kwargs):
        """
        encoder_output= "encoder_out": x,
                        "encoded": encoded,
                        "encoder_padding_mask": padding_mask,  # B x T
                        "padding_mask": padding_mask,
        """
        encoder_output = self.encoder(tbc=False, **kwargs)
        hidden_encoded = encoder_output['encoder_out'][:, :, :-1]
        hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0)
        logits_ctc = self.to_vocab_ctc(hidden_ctc)
        len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long()
        alphas = CIFFcModelV2.get_alphas(encoder_output)

        if self.training:
            decode_length = kwargs['target_lengths']
        else:
            decode_length = torch.round(alphas.sum(-1)).int()
            decode_length = torch.max(decode_length,
                                      torch.ones_like(decode_length))
        _alphas, num_output = self.resize(alphas, decode_length)
        cif_outputs = self.cif(hidden_encoded, _alphas)
        hidden_ac = self.proj(cif_outputs)
        logits_ac = self.to_vocab_ac(hidden_ac)

        # other inputs
        B, T = hidden_ac.size(0), hidden_ac.size(1)
        padding_mask = ~utils.sequence_mask(decode_length).bool()
        # [batch_size, num_heads, seq_length, seq_length]
        # zeros = torch.zeros([B, 1, T, T]).cuda()
        ones = torch.ones([B, 1, T, T]).cuda()
        diag = torch.diag(torch.ones([T])).cuda()[None, None, :, :]
        tril = torch.tril(torch.ones([T, T])).cuda()[None, None, :, :]
        rm_padding_mask = (~padding_mask)[:, None, None, :] * \
                          (~padding_mask)[:, None, None, :].permute(0, 1, 3, 2)
        # mask_acQac = ones * rm_padding_mask
        # mask_lmQac = diag * rm_padding_mask
        # mask_lmQac = zeros
        mask_lmQac = ones * rm_padding_mask
        mask_lmQlm = tril * rm_padding_mask
        mask_lm = torch.cat([mask_lmQac, mask_lmQlm], dim=-1)
        # mask_ac = torch.ones_like(mask_lm)
        attention_mask = torch.cat([mask_lm, mask_lm], dim=-2)

        if self.training:
            input_ids = kwargs['prev_output_tokens']
            gold_rate = self.set_gold_rate()
            input_ids = self.schedule_samlping(gold_rate, input_ids, logits_ac,
                                               padding_mask)
            text_embs = self.gpt2.transformer.wte(input_ids)
            ft = self.freeze_lm_finetune_updates <= self.num_updates
            with torch.no_grad() if not ft else contextlib.ExitStack():
                outputs = self.gpt2(
                    inputs_embeds=text_embs,
                    external_embeds=hidden_ac,
                    # attention_mask=attention_mask,
                )
            logits_lm = outputs[0]
            logits = self.args.lambda_am * logits_ac + self.args.lambda_lm * logits_lm
        else:
            gold_rate = 0.0
            list_logits = []
            device, dtype = kwargs['prev_output_tokens'].device, kwargs[
                'prev_output_tokens'].dtype
            decoded = torch.ones([B, 1], device=device,
                                 dtype=dtype) * self.tgt_dict.bos()
            text_embs = self.gpt2.transformer.wte(decoded)
            for i in range(T):
                outputs = self.gpt2(
                    inputs_embeds=text_embs,
                    external_embeds=hidden_ac,
                    # attention_mask=attention_mask[:, :, :T+i+1, :T+i+1]
                )
                logits_lm = outputs[0][..., -1, :]
                logits_i = self.args.lambda_am * logits_ac[
                    ..., i, :] + self.args.lambda_lm * logits_lm
                list_logits.append(logits_i.unsqueeze(1))
                preds = torch.argmax(logits_i, -1)[:, None]
                cur_embs = self.gpt2.transformer.wte(preds)
                text_embs = torch.cat([text_embs, cur_embs], dim=1)
            logits = torch.cat(list_logits, 1)
        logits *= (~padding_mask).unsqueeze(-1).float()

        return {
            'logits': logits,
            'len_logits': decode_length,
            'alphas': alphas,
            'num_output': num_output,
            'gold_rate': gold_rate,
            'logits_ctc': logits_ctc,
            'len_logits_ctc': len_logits_ctc
        }
Example #13
    def forward(self, **kwargs):
        """
        encoder_output= "encoder_out": x,
                        "encoded": encoded,
                        "encoder_padding_mask": padding_mask,  # B x T
                        "padding_mask": padding_mask,
        """
        encoder_output = self.encoder(tbc=False, **kwargs)
        hidden_encoded = encoder_output['encoder_out'][:, :, :-1]
        hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0)
        logits_ctc = self.to_vocab_ctc(hidden_ctc)
        len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long()
        alphas = CIFFcModelV2.get_alphas(encoder_output)

        if self.training:
            decode_length = kwargs['target_lengths']
        else:
            decode_length = torch.round(alphas.sum(-1)).int()
            decode_length = torch.max(decode_length,
                                      torch.ones_like(decode_length))
        _alphas, num_output = self.resize(alphas, decode_length)
        cif_outputs = self.cif(hidden_encoded, _alphas)
        hidden_ac = self.proj(cif_outputs)
        logits_ac = self.to_vocab_ac(hidden_ac)

        # other inputs
        B, T = hidden_ac.size(0), hidden_ac.size(1)
        padding_mask = ~utils.sequence_mask(decode_length).bool()
        position_ac = torch.arange(T).repeat(B, 1).long().cuda()
        type_ac = torch.ones((B, T)).long().cuda() * 103
        # [batch_size, num_heads, seq_length, seq_length]
        zeros = torch.zeros([B, 1, T, T]).cuda()
        ones = torch.ones([T, T]).cuda()[None, None, :, :]
        diag = torch.diag(torch.ones([T])).cuda()[None, None, :, :]
        tril = torch.tril(torch.ones([T, T])).cuda()[None, None, :, :]
        rm_padding_mask = (~padding_mask)[:, None, None, :] * \
                          (~padding_mask)[:, None, None, :].permute(0, 1, 3, 2)
        # mask_acQac = ones * rm_padding_mask
        mask_acQac = diag * rm_padding_mask
        mask_lmQac = diag * rm_padding_mask
        mask_lmQlm = tril * rm_padding_mask
        mask_ac = torch.cat([mask_acQac, zeros], dim=-1)
        mask_lm = torch.cat([mask_lmQac, mask_lmQlm], dim=-1)
        attention_mask = torch.cat([mask_ac, mask_lm], dim=-2)
        gold_rate = 0.0

        if self.training:
            self.gpt2.eval()
            input_ids = kwargs['prev_output_tokens']
            # input_ids = self.tokenizer.encode("The Manhattan bridge is a major")
            # input_ids = torch.tensor([[self.tokenizer.bos_token_id] + input_ids + [100] * 3]).cuda()
            text_embs = self.gpt2.transformer.wte(input_ids)
            ft = self.freeze_lm_finetune_updates <= self.num_updates
            with torch.no_grad() if not ft else contextlib.ExitStack():
                input_embs = text_embs
                # input_embs = torch.cat([hidden_ac, text_embs], dim=1)
                # token_type = torch.zeros_like(input_ids)
                # type_ids = torch.cat([type_ac, token_type], dim=1)
                # position_ids = torch.cat([position_ac+self.args.position_bias, position_ac], dim=1)
                outputs = self.gpt2(
                    inputs_embeds=input_embs,
                    # token_type_ids=type_ids if not self.args.no_type_id else None,
                    # position_ids=position_ids,
                    # attention_mask=attention_mask,
                )
            logits = outputs[0]

            # print(torch.argmax(logits, -1))
            # print(torch.argmax(logits, -1)[:, -T:])
            # import pdb; pdb.set_trace()
        else:
            list_logits = []
            token_type = torch.zeros_like(type_ac)
            decoded = torch.ones([B, 1],
                                 device=type_ac.device,
                                 dtype=type_ac.dtype) * self.tgt_dict.bos()
            text_embs = self.gpt2.transformer.wte(decoded)
            input_embs = text_embs
            # input_embs = torch.cat([hidden_ac, text_embs], dim=1)
            # type_ids = torch.cat([type_ac, token_type], dim=1)
            # position_ids = torch.cat([position_ac+self.args.position_bias, position_ac], dim=1)
            for i in range(T):
                outputs = self.gpt2(
                    inputs_embeds=input_embs,
                    # token_type_ids=type_ids[:, :T+i+1] if not self.args.no_type_id else None,
                    # position_ids=position_ids[:, :T+i+1],
                    # attention_mask=attention_mask[:, :, :T+i+1, :T+i+1]
                )
                logits_lm = outputs[0][..., -1, :]
                # logits_i = self.args.lambda_am * logits_ac[..., i, :] + self.args.lambda_lm * logits_lm
                logits_i = logits_lm
                list_logits.append(logits_i.unsqueeze(1))
                preds = torch.argmax(logits_i, -1)[:, None]
                cur_embs = self.gpt2.transformer.wte(preds)
                input_embs = torch.cat([input_embs, cur_embs], dim=1)
            logits = torch.cat(list_logits, 1)
        logits *= (~padding_mask).unsqueeze(-1).float()

        return {
            'logits': logits,
            'len_logits': decode_length,
            'alphas': alphas,
            'num_output': num_output,
            'gold_rate': gold_rate,
            'logits_ctc': logits_ctc,
            'len_logits_ctc': len_logits_ctc
        }