def forward(self, **kwargs): """ encoder_output= "encoder_out": x, "encoded": encoded, "encoder_padding_mask": padding_mask, # B x T "padding_mask": padding_mask, """ encoder_output = self.encoder(tbc=False, **kwargs) alphas = CIFFcModelV2.get_alphas(encoder_output) if self.training: _alphas, num_output = self.resize(alphas, kwargs['target_lengths']) padding_mask = ~utils.sequence_mask(kwargs['target_lengths']).bool() else: _alphas, num_output = self.resize(alphas) padding_mask = ~utils.sequence_mask(torch.round(num_output).int()).bool() cif_outputs = self.cif(encoder_output['encoder_out'][:, :, :-1], _alphas) hidden = self.proj(cif_outputs) if self.training: gold_rate = self.set_gold_rate() input_ids = kwargs['bert_input'].long() else: input_ids = None gold_rate = 0.0 bert_output, gold_embedding, pred_mask = self.forward_embeded( hidden, padding_mask, input_ids, gold_rate) logits = self.final_proj(bert_output) return {'logits': logits, 'len_logits': kwargs['target_lengths'], 'alphas': alphas, 'num_output': num_output, 'embedding': hidden, 'gold_embedding': gold_embedding, 'pred_mask': pred_mask, 'gold_rate': gold_rate}
def generate(self, models, sample, **unused):
    """Generate a batch of inferences."""
    model = models[0]
    encoder_output = model.encoder(tbc=False, **sample["net_input"])
    alphas = CIFFcModelV2.get_alphas(encoder_output)
    decode_length = torch.round(alphas.sum(-1)).int()
    _alphas, num_output = model.resize(alphas, decode_length, noise=0.0)
    padding_mask = ~utils.sequence_mask(decode_length).bool()
    cif_outputs = model.cif(encoder_output['encoder_out'][:, :, :-1], _alphas)
    hidden = model.proj(cif_outputs)
    logits_ac = model.to_vocab_ac(hidden)
    infer_threash = self.infer_threshold if self.infer_threshold else model.args.infer_threash

    for i in range(1):
        logits, gold_embedding, pred_mask, token_mask = model.bert_forward(
            hidden, logits_ac, padding_mask, None, 0.0,
            # threash=0.8)
            threash=infer_threash)

    logits = logits_ac + model.args.lambda_lm * logits
    probs = utils.softmax(logits.float(), dim=-1)

    res = []
    for distribution, length in zip(probs, decode_length):
        result = distribution.argmax(-1)
        score = 0.0
        res.append([{'tokens': result[:length], "score": score}])

    return res
def generate(self, models, sample, **unused):
    """Generate a batch of inferences.

    EncoderOut(
        encoder_out=encoder_out['encoder_out'],  # T x B x C
        encoder_embedding=None,
        encoder_padding_mask=encoder_out['encoder_padding_mask'],  # B x T
        encoder_states=None,
        src_tokens=None,
        src_lengths=None,
    )
    """
    encoder_output = models[0].get_encoder_output(sample['net_input'])
    encoder_out = {
        "encoder_out": encoder_output.encoder_out.transpose(0, 1),  # B x T x C
        "padding_mask": encoder_output.encoder_padding_mask,
    }
    alphas, _ = models[0].assigner(encoder_out)
    # _alphas, num_output = self.resize(alphas, kwargs['target_lengths'], at_least_one=True)
    cif_outputs = models[0].cif(encoder_out, alphas)
    src_lengths = torch.round(alphas.sum(-1)).int()
    self.step_forward_fn = models[0].decode

    encoder_output = EncoderOut(
        encoder_out=cif_outputs.transpose(0, 1),  # T x B x C
        encoder_embedding=None,
        encoder_padding_mask=~utils.sequence_mask(src_lengths, dtype=torch.bool),  # B x T
        encoder_states=None,
        src_tokens=None,
        src_lengths=src_lengths,
    )

    return self.decode(encoder_output)
def forward(self, **kwargs): """ encoder_output= "encoder_out": x, "encoded": encoded, "encoder_padding_mask": padding_mask, # B x T "padding_mask": padding_mask, """ encoder_output = self.encoder(tbc=False, **kwargs) hidden_encoded = encoder_output['encoder_out'][:, :, :-1] hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0) logits_ctc = self.to_vocab_ctc(hidden_ctc) len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long() alphas = CIFFcModelV2.get_alphas(encoder_output) if self.training: gold_rate = self.set_gold_rate() decode_length = kwargs['target_lengths'] gold_ids = kwargs['bert_input'].long() noise = 0.0 else: gold_rate = 0.0 decode_length = torch.round(alphas.sum(-1)).int() gold_ids = None noise = 0.0 _alphas, num_output = self.resize(alphas, decode_length, noise=noise) padding_mask = ~utils.sequence_mask(decode_length).bool() cif_outputs = self.cif(hidden_encoded, _alphas) hidden_ac = self.proj(cif_outputs) logits_ac = self.to_vocab_ac(hidden_ac) ft = self.freeze_lm_finetune_updates <= self.num_updates with torch.no_grad() if not ft else contextlib.ExitStack(): logits_lm, gold_embedding, pred_mask, token_mask = self.bert_forward( hidden_ac, logits_ac, padding_mask, gold_ids, gold_rate, threash=self.args.infer_threash) logits = self.args.lambda_am * logits_ac + self.args.lambda_lm * logits_lm logits *= (~padding_mask).unsqueeze(-1).float() return { 'logits': logits, 'len_logits': decode_length, 'alphas': alphas, 'num_output': num_output, 'gold_rate': gold_rate, 'logits_ctc': logits_ctc, 'len_logits_ctc': len_logits_ctc, 'pred_mask': pred_mask[:, 1:-1], 'token_mask': token_mask[:, 1:-1] }
def forward(self, **kwargs):
    encoder_output = self.w2v_encoder(tbc=False, **kwargs['net_input'])
    hidden_encoded = encoder_output['encoder_out']

    # ctc part
    logits_ctc = self.to_vocab_ctc(hidden_encoded)

    # cif part
    alphas = get_alphas(self.fc_alpha, encoder_output)
    if self.training:
        decode_length = kwargs['target_lengths']
        targets = kwargs['target']
        targets_embs = self.gpt2.transformer.wte(targets.long())
    else:
        decode_length = torch.round(alphas.sum(-1)).int()
        targets = None
        targets_embs = None
    _alphas, num_output = self.resize(alphas, decode_length)
    padding_mask = ~utils.sequence_mask(decode_length).bool()
    cif_outputs = self.cif(hidden_encoded, _alphas).type_as(hidden_encoded)
    logits_ac = self.to_vocab(cif_outputs)

    # gpt2 part
    with torch.no_grad():
        device = cif_outputs.device
        bos_indx = torch.tensor([self.bos_idx]).to(device)
        sos_embeds = self.gpt2.transformer.wte(bos_indx).expand(
            cif_outputs.size(0), 1, cif_outputs.size(2))
        token_mask = F.pad(~padding_mask, [1, 0, 0, 0], value=0)
        attention_mask = token_mask.int()
        gpt_inputs = torch.cat((sos_embeds, cif_outputs), 1)
        gpt_outputs = self.gpt2(inputs_embeds=gpt_inputs,
                                attention_mask=attention_mask).logits[:, 1:, :]

    logits = self.cfg.lambda_am * logits_ac + self.cfg.lambda_lm * gpt_outputs

    result = {
        "encoder_out": logits_ctc.transpose(0, 1),  # T x B x C
        "padding_mask": encoder_output['padding_mask'],
        "cif_out": logits_ac,  # B x T x C
        "cif_embeds": cif_outputs,
        "targets_embs": targets_embs,
        "len_logits": decode_length,
        "alphas": alphas,
        "num_output": num_output,
        "gpt2_out": gpt_outputs,
        "attention_mask": attention_mask,
        "logits": logits,
    }
    return result
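# `get_alphas(self.fc_alpha, encoder_output)` above produces the per-frame CIF
# weights. Its definition lives elsewhere in the repo; the sketch below shows
# the assumed pattern (a sigmoid over a scalar projection of the encoder
# states, zeroed at padded frames). The other variants in this file appear to
# take the weight from a reserved last encoder channel instead, which is why
# they slice encoder_out with [:, :, :-1].
def get_alphas_sketch(fc_alpha, encoder_output):
    hidden = encoder_output['encoder_out']                   # [B, T, C]
    padding_mask = encoder_output['padding_mask']            # [B, T], True at padding
    alphas = torch.sigmoid(fc_alpha(hidden)).squeeze(-1)     # [B, T]
    return alphas * (~padding_mask).float()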
def forward(self, **kwargs): """ encoder_output= "encoder_out": x, "encoded": encoded, "encoder_padding_mask": padding_mask, # B x T "padding_mask": padding_mask, """ encoder_output = self.encoder(tbc=False, **kwargs) alphas = CIFFcModelV2.get_alphas(encoder_output) input_ids = kwargs['bert_input'].long() if self.training: _alphas, num_output = self.resize(alphas, kwargs['target_lengths']) padding_mask = ~utils.sequence_mask(kwargs['target_lengths']).bool() gold_rate = self.set_gold_rate() else: decode_length = kwargs['decode_length'] # _alphas, num_output = self.resize(alphas) # padding_mask = ~utils.sequence_mask(torch.round(num_output).int()).bool() _alphas, num_output = self.resize(alphas, decode_length) padding_mask = ~utils.sequence_mask(decode_length).bool() gold_rate = 0.0 cif_outputs = self.cif(encoder_output['encoder_out'][:, :, :-1], _alphas) hidden = self.proj(cif_outputs) logits_ac = self.to_vocab_ac(hidden) logits, gold_embedding, pred_mask, token_mask = self.bert_forward( hidden, logits_ac, padding_mask, input_ids, gold_rate, threash=self.args.infer_threash) # logits = GradMultiply.apply(logits, 0.1) logits = logits_ac + 0.1 * logits return {'logits': logits, 'len_logits': kwargs['target_lengths'], 'alphas': alphas, 'num_output': num_output, 'embedding': hidden, 'gold_embedding': gold_embedding, 'pred_mask': pred_mask, 'token_mask': token_mask, 'gold_rate': gold_rate}
def forward(self, **kwargs): """ encoder_output= "encoder_out": x, "encoded": encoded, "encoder_padding_mask": padding_mask, # B x T "padding_mask": padding_mask, """ encoder_output = self.encoder(tbc=False, **kwargs) hidden_encoded = encoder_output['encoder_out'][:, :, :-1] hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0) logits_ctc = self.to_vocab_ctc(hidden_ctc) len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long() alphas = get_alphas(encoder_output) decode_length = kwargs[ 'target_lengths'] if self.training else torch.round( alphas.sum(-1)).int() padding_mask = ~utils.sequence_mask(decode_length).bool() _alphas, num_output = self.resize(alphas, decode_length) # if not self.training: # import pdb; pdb.set_trace() encoder_out = EncoderOut( encoder_out=encoder_output['encoder_out'].transpose( 0, 1), # T x B x C encoder_embedding=None, encoder_padding_mask=encoder_output[ 'encoder_padding_mask'], # B x T encoder_states=None, src_tokens=None, src_lengths=None, ) prev_output_tokens = torch.ones_like( padding_mask) * self.tgt_dict.bos() decoder_out = self.decoder(encoder_out=encoder_out, prev_output_tokens=prev_output_tokens) logits = decoder_out["logits"] logits *= (~padding_mask).unsqueeze(-1).float() return { 'logits': logits, 'len_logits': decode_length, 'alphas': alphas, 'num_output': num_output, 'logits_ctc': logits_ctc, 'len_logits_ctc': len_logits_ctc }
def decode(self, encoder_shrunk_out):
    encoded_logits = encoder_shrunk_out["encoded_shrunk"]
    padding_mask = utils.sequence_mask(encoder_shrunk_out["len_encoded_shrunk"],
                                       dtype=torch.bool, reverse=True)
    # prob = torch.softmax(encoded_logits[:, :, :-1], -1)
    ft = self.freeze_lm_finetune_updates <= self.num_updates
    with torch.no_grad() if not ft else contextlib.ExitStack():
        # embedded = torch.mm(prob.view(-1, prob.size(-1)),
        #                     self.lm.encoder.encoder.sentence_encoder.embed_tokens.weight[:-1, :]
        #                     ).view(prob.size(0), prob.size(1), -1)
        # embedded = self.proj(encoded_logits)
        # logits = self.lm.forward_embeded(embedded, padding_mask)
        logits = self.proj(encoded_logits)
    logits.batch_first = True

    return logits
def forward(self, **kwargs): """ encoder_output= "encoder_out": x, "encoded": encoded, "encoder_padding_mask": padding_mask, # B x T "padding_mask": padding_mask, """ encoder_output = self.encoder(tbc=False, **kwargs) hidden_encoded = encoder_output['encoder_out'][:, :, :-1] hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0) logits_ctc = self.to_vocab_ctc(hidden_ctc) len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long() alphas = CIFFcModelV2.get_alphas(encoder_output) if self.training: decode_length = kwargs['target_lengths'] else: decode_length = torch.round(alphas.sum(-1)).int() decode_length = torch.max(decode_length, torch.ones_like(decode_length)) padding_mask = ~utils.sequence_mask(decode_length).bool() _alphas, num_output = self.resize(alphas, decode_length) cif_outputs = self.cif(hidden_encoded, _alphas) hidden_ac = self.proj(cif_outputs) logits = self.to_vocab_ac(hidden_ac) logits *= (~padding_mask).unsqueeze(-1).float() gold_rate = 0.0 return { 'logits': logits, 'len_logits': decode_length, 'alphas': alphas, 'num_output': num_output, 'gold_rate': gold_rate, 'logits_ctc': logits_ctc, 'len_logits_ctc': len_logits_ctc } return logits
def forward(self, src_tokens, src_lengths):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`

    Returns:
        namedtuple:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of padding
              elements of shape `(batch, src_len)`
            - **encoder_embedding** (Tensor): the (scaled) embedding lookup
              of shape `(batch, src_len, embed_dim)`
    """
    x = self.dropout(self.pe(src_tokens))

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # compute padding mask
    encoder_padding_mask = (1 - utils.sequence_mask(src_lengths)).bool()

    # encoder layers
    for layer in self.layers:
        x = layer(x, encoder_padding_mask)

    x = self.layer_norm(x)

    return EncoderOut(
        encoder_out=x,  # T x B x C
        encoder_embedding=None,
        encoder_padding_mask=encoder_padding_mask,  # B x T
        encoder_states=None,
        src_tokens=None,
        src_lengths=None,
    )
def forward(self, **kwargs): """ encoder_output= "encoder_out": x, "encoded": encoded, "encoder_padding_mask": padding_mask, # B x T "padding_mask": padding_mask, """ encoder_output = self.encoder(tbc=False, **kwargs) hidden_encoded = encoder_output['encoder_out'][:, :, :-1] hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0) logits_ctc = self.to_vocab_ctc(hidden_ctc) len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long() alphas = get_alphas(encoder_output) decode_length = kwargs[ 'target_lengths'] if self.training else torch.round( alphas.sum(-1)).int() _, num_output = self.resize(alphas, decode_length) padding_mask = ~utils.sequence_mask(decode_length).bool() token_mask = ~padding_mask mask_ids = torch.ones_like(padding_mask) * self.tgt_dict.bos() # emcoded = self.proj(hidden_encoded) encoder_out = EncoderOut( encoder_out=hidden_ctc.transpose(0, 1), # T x B x C encoder_embedding=None, encoder_padding_mask=encoder_output[ 'encoder_padding_mask'], # B x T encoder_states=None, src_tokens=None, src_lengths=None, ) if self.training: gold_ids = kwargs['target'].long() rand = torch.rand(gold_ids.size(), device=gold_ids.device) * token_mask list_pred_mask = [] for i, l in enumerate(decode_length): k = random.randint(1, l) list_pred_mask.append( rand[i] >= torch.topk(rand[i], k).values.min()) pred_mask = torch.stack(list_pred_mask, 0) * token_mask gold_mask = ~pred_mask * token_mask gold_rate = gold_mask.sum() * 1.0 / token_mask.sum() decoder_input_ids = torch.where(pred_mask, mask_ids, gold_ids) logits = self.decoder(encoder_out=encoder_out, prev_output_tokens=decoder_input_ids) # import pdb; pdb.set_trace() else: pred_mask = gold_rate = 0.0 decoder_input_ids = mask_ids for _ in range(10): logits = self.decoder(encoder_out=encoder_out, prev_output_tokens=decoder_input_ids) probs, pred_ids = utils.softmax(logits, dim=-1).max(-1) gold_mask = probs > 0.9 decoder_input_ids = torch.where(gold_mask, pred_ids, mask_ids) * token_mask logits *= token_mask.unsqueeze(-1).float() return { 'logits': logits, 'len_logits': decode_length, 'gold_rate': gold_rate, 'alphas': alphas, 'num_output': num_output, 'pred_mask': pred_mask, 'logits_ctc': logits_ctc, 'len_logits_ctc': len_logits_ctc }
def forward(self, **kwargs): """ encoder_output= "encoder_out": x, "encoded": encoded, "encoder_padding_mask": padding_mask, # B x T "padding_mask": padding_mask, """ encoder_output = self.encoder(tbc=False, **kwargs) hidden_encoded = encoder_output['encoder_out'][:, :, :-1] hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0) logits_ctc = self.to_vocab_ctc(hidden_ctc) len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long() alphas = CIFFcModelV2.get_alphas(encoder_output) if self.training: decode_length = kwargs['target_lengths'] else: decode_length = torch.round(alphas.sum(-1)).int() decode_length = torch.max(decode_length, torch.ones_like(decode_length)) _alphas, num_output = self.resize(alphas, decode_length) cif_outputs = self.cif(hidden_encoded, _alphas) hidden_ac = self.proj(cif_outputs) logits_ac = self.to_vocab_ac(hidden_ac) # other inputs B, T = hidden_ac.size(0), hidden_ac.size(1) padding_mask = ~utils.sequence_mask(decode_length).bool() # [batch_size, num_heads, seq_length, seq_length] # zeros = torch.zeros([B, 1, T, T]).cuda() ones = torch.ones([B, 1, T, T]).cuda() diag = torch.diag(torch.ones([T])).cuda()[None, None, :, :] tril = torch.tril(torch.ones([T, T])).cuda()[None, None, :, :] rm_padding_mask = (~padding_mask)[:, None, None, :] * \ (~padding_mask)[:, None, None, :].permute(0, 1, 3, 2) # mask_acQac = ones * rm_padding_mask # mask_lmQac = diag * rm_padding_mask # mask_lmQac = zeros mask_lmQac = ones * rm_padding_mask mask_lmQlm = tril * rm_padding_mask mask_lm = torch.cat([mask_lmQac, mask_lmQlm], dim=-1) # mask_ac = torch.ones_like(mask_lm) attention_mask = torch.cat([mask_lm, mask_lm], dim=-2) if self.training: input_ids = kwargs['prev_output_tokens'] gold_rate = self.set_gold_rate() input_ids = self.schedule_samlping(gold_rate, input_ids, logits_ac, padding_mask) text_embs = self.gpt2.transformer.wte(input_ids) ft = self.freeze_lm_finetune_updates <= self.num_updates with torch.no_grad() if not ft else contextlib.ExitStack(): outputs = self.gpt2( inputs_embeds=text_embs, external_embeds=hidden_ac, # attention_mask=attention_mask, ) logits_lm = outputs[0] logits = self.args.lambda_am * logits_ac + self.args.lambda_lm * logits_lm else: gold_rate = 0.0 list_logits = [] device, dtype = kwargs['prev_output_tokens'].device, kwargs[ 'prev_output_tokens'].dtype decoded = torch.ones([B, 1], device=device, dtype=dtype) * self.tgt_dict.bos() text_embs = self.gpt2.transformer.wte(decoded) for i in range(T): outputs = self.gpt2( inputs_embeds=text_embs, external_embeds=hidden_ac, # attention_mask=attention_mask[:, :, :T+i+1, :T+i+1] ) logits_lm = outputs[0][..., -1, :] logits_i = self.args.lambda_am * logits_ac[ ..., i, :] + self.args.lambda_lm * logits_lm list_logits.append(logits_i.unsqueeze(1)) preds = torch.argmax(logits_i, -1)[:, None] cur_embs = self.gpt2.transformer.wte(preds) text_embs = torch.cat([text_embs, cur_embs], dim=1) logits = torch.cat(list_logits, 1) logits *= (~padding_mask).unsqueeze(-1).float() return { 'logits': logits, 'len_logits': decode_length, 'alphas': alphas, 'num_output': num_output, 'gold_rate': gold_rate, 'logits_ctc': logits_ctc, 'len_logits_ctc': len_logits_ctc } return logits
def forward(self, **kwargs): """ encoder_output= "encoder_out": x, "encoded": encoded, "encoder_padding_mask": padding_mask, # B x T "padding_mask": padding_mask, """ encoder_output = self.encoder(tbc=False, **kwargs) hidden_encoded = encoder_output['encoder_out'][:, :, :-1] hidden_ctc = F.pad(hidden_encoded, [0, 1, 0, 0, 0, 0], value=0) logits_ctc = self.to_vocab_ctc(hidden_ctc) len_logits_ctc = (~encoder_output['padding_mask']).sum(-1).long() alphas = CIFFcModelV2.get_alphas(encoder_output) if self.training: decode_length = kwargs['target_lengths'] else: decode_length = torch.round(alphas.sum(-1)).int() decode_length = torch.max(decode_length, torch.ones_like(decode_length)) _alphas, num_output = self.resize(alphas, decode_length) cif_outputs = self.cif(hidden_encoded, _alphas) hidden_ac = self.proj(cif_outputs) logits_ac = self.to_vocab_ac(hidden_ac) # other inputs B, T = hidden_ac.size(0), hidden_ac.size(1) padding_mask = ~utils.sequence_mask(decode_length).bool() position_ac = torch.arange(T).repeat(B, 1).long().cuda() type_ac = torch.ones((B, T)).long().cuda() * 103 # [batch_size, num_heads, seq_length, seq_length] zeros = torch.zeros([B, 1, T, T]).cuda() ones = torch.ones([T, T]).cuda()[None, None, :, :] diag = torch.diag(torch.ones([T])).cuda()[None, None, :, :] tril = torch.tril(torch.ones([T, T])).cuda()[None, None, :, :] rm_padding_mask = (~padding_mask)[:, None, None, :] * \ (~padding_mask)[:, None, None, :].permute(0, 1, 3, 2) # mask_acQac = ones * rm_padding_mask mask_acQac = diag * rm_padding_mask mask_lmQac = diag * rm_padding_mask mask_lmQlm = tril * rm_padding_mask mask_ac = torch.cat([mask_acQac, zeros], dim=-1) mask_lm = torch.cat([mask_lmQac, mask_lmQlm], dim=-1) attention_mask = torch.cat([mask_ac, mask_lm], dim=-2) gold_rate = 0.0 if self.training: self.gpt2.eval() input_ids = kwargs['prev_output_tokens'] # input_ids = self.tokenizer.encode("The Manhattan bridge is a major") # input_ids = torch.tensor([[self.tokenizer.bos_token_id] + input_ids + [100] * 3]).cuda() text_embs = self.gpt2.transformer.wte(input_ids) ft = self.freeze_lm_finetune_updates <= self.num_updates with torch.no_grad() if not ft else contextlib.ExitStack(): input_embs = text_embs # input_embs = torch.cat([hidden_ac, text_embs], dim=1) # token_type = torch.zeros_like(input_ids) # type_ids = torch.cat([type_ac, token_type], dim=1) # position_ids = torch.cat([position_ac+self.args.position_bias, position_ac], dim=1) outputs = self.gpt2( inputs_embeds=input_embs, # token_type_ids=type_ids if not self.args.no_type_id else None, # position_ids=position_ids, # attention_mask=attention_mask, ) logits = outputs[0] # print(torch.argmax(logits, -1)) print(torch.argmax(logits, -1)[:, -T:]) import pdb pdb.set_trace() else: list_logits = [] token_type = torch.zeros_like(type_ac) decoded = torch.ones([B, 1], device=type_ac.device, dtype=type_ac.dtype) * self.tgt_dict.bos() text_embs = self.gpt2.transformer.wte(decoded) input_embs = text_embs # input_embs = torch.cat([hidden_ac, text_embs], dim=1) # type_ids = torch.cat([type_ac, token_type], dim=1) # position_ids = torch.cat([position_ac+self.args.position_bias, position_ac], dim=1) for i in range(T): outputs = self.gpt2( inputs_embeds=input_embs, # token_type_ids=type_ids[:, :T+i+1] if not self.args.no_type_id else None, # position_ids=position_ids[:, :T+i+1], # attention_mask=attention_mask[:, :, :T+i+1, :T+i+1] ) logits_lm = outputs[0][..., -1, :] # logits_i = self.args.lambda_am * logits_ac[..., i, :] + self.args.lambda_lm * logits_lm logits_i = logits_lm 
list_logits.append(logits_i.unsqueeze(1)) preds = torch.argmax(logits_i, -1)[:, None] cur_embs = self.gpt2.transformer.wte(preds) input_embs = torch.cat([input_embs, cur_embs], dim=1) logits = torch.cat(list_logits, 1) logits *= (~padding_mask).unsqueeze(-1).float() return { 'logits': logits, 'len_logits': decode_length, 'alphas': alphas, 'num_output': num_output, 'gold_rate': gold_rate, 'logits_ctc': logits_ctc, 'len_logits_ctc': len_logits_ctc } return logits