def forward(self, ys, ylens=None, labels=None, ps=None, plens=None):
    if ylens is None:
        attention_mask = None
    else:
        attention_mask = make_nopad_mask(ylens).float().to(ys.device)
        # DataParallel
        ys = ys[:, :max(ylens)]

    # generator: predict the original tokens at the masked positions
    gloss, glogits = self.gmodel(ys, attention_mask=attention_mask, labels=labels)

    generated_ids = ys.clone()
    masked_indices = labels.long() != -100
    original_ids = ys.clone()
    original_ids[masked_indices] = labels[masked_indices]

    # sampling replacements from the generator distribution
    sample_ids = sample_temp(glogits)
    generated_ids[masked_indices] = sample_ids[masked_indices]

    # discriminator labels: 1 where the sampled token differs from the original
    labels_replaced = (generated_ids.long() != original_ids.long()).long()

    dloss, dlogits = self.dmodel(generated_ids, attention_mask=attention_mask,
                                 labels=labels_replaced)

    loss = gloss + self.electra_disc_weight * dloss

    loss_dict = {}
    loss_dict["loss_gen"] = gloss
    loss_dict["loss_disc"] = dloss
    loss_dict["num_replaced"] = labels_replaced.sum().long() / ys.size(0)
    loss_dict["num_masked"] = masked_indices.sum().long() / ys.size(0)

    return loss, loss_dict
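# The generator's predictions are sampled rather than argmax-ed so that the
# discriminator sees plausible but imperfect replacements. `sample_temp` is not
# shown in this file; the following is a minimal sketch of a temperature-based
# categorical sampler with the same call pattern, written as an assumption
# about its behavior rather than the actual implementation.
import torch


def sample_temp(logits, temperature=1.0):
    """Sample token ids from (bs, seq_len, vocab) logits with a temperature."""
    probs = torch.softmax(logits / temperature, dim=-1)
    bs, seq_len, vocab = probs.size()
    # torch.multinomial expects a 2-D input, so flatten the batch/time axes
    sample_ids = torch.multinomial(probs.view(-1, vocab), num_samples=1)
    return sample_ids.view(bs, seq_len)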
def forward(
    self,
    eouts,
    elens,
    eouts_inter=None,
    ys=None,
    ylens=None,
    ys_in=None,
    ys_out=None,  # labels
    soft_labels=None,
    ps=None,
    plens=None,
):
    loss = 0
    loss_dict = {}
    bs = eouts.size(0)

    ys_emb = self.dropout_emb(self.embed(ys_in))
    dstate = None
    # context vector
    ctx = eouts.new_zeros(bs, 1, self.enc_hidden_size)
    attn_weight = None
    attn_mask = make_nopad_mask(elens).unsqueeze(2)

    logits = []
    for i in range(ys_in.size(1)):
        y_emb = ys_emb[:, i:i + 1]  # (bs, 1, embedding_size)
        logit, ctx, dstate, attn_weight = self.forward_one_step(
            y_emb, ctx, eouts, dstate, attn_weight, attn_mask)
        logits.append(logit)  # (bs, 1, dec_intermediate_size)
    logits = self.output(torch.cat(logits, dim=1))  # (bs, ylen, vocab)

    if self.kd_weight > 0 and soft_labels is not None:
        # NOTE: ys_out (labels) has length ylens + 1
        loss_att_kd, loss_kd, loss_att = self.loss_fn(
            logits, ys_out, soft_labels, ylens + 1)
        loss += loss_att_kd
        loss_dict["loss_kd"] = loss_kd
        loss_dict["loss_att"] = loss_att
    else:
        loss_att = self.loss_fn(logits, ys_out, ylens + 1)
        loss += loss_att
        loss_dict["loss_att"] = loss_att

    if self.mtl_ctc_weight > 0:
        # NOTE: KD is not applied to auxiliary CTC
        loss_ctc, _, _ = self.ctc(eouts=eouts, elens=elens,
                                  ys=ys, ylens=ylens, soft_labels=None)
        loss += self.mtl_ctc_weight * loss_ctc  # auxiliary loss
        loss_dict["loss_ctc"] = loss_ctc

    loss_dict["loss_total"] = loss

    return loss, loss_dict, logits
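# `self.loss_fn` is defined elsewhere; when teacher soft labels are given it
# returns the interpolated loss together with its two components. The helper
# below is only a sketch of such a distillation criterion, under the
# assumptions that `soft_labels` are teacher probability distributions over
# the vocabulary and that `kd_weight` interpolates the two terms; it is not
# the implementation used in this codebase.
import torch
import torch.nn.functional as F


def kd_attention_loss(logits, ys_out, soft_labels, ylens, kd_weight=0.5):
    """Interpolate hard-label cross-entropy with a soft-label (KD) term."""
    bs, max_len, vocab = logits.size()

    # mask of valid (non-padded) target positions, built from the label lengths
    ylens = torch.as_tensor(ylens, device=logits.device)
    pos = torch.arange(max_len, device=logits.device).unsqueeze(0)  # (1, max_len)
    mask = (pos < ylens.unsqueeze(1)).float()                       # (bs, max_len)

    # hard-label cross-entropy, averaged over valid positions only
    # (padded ids are clamped to 0 and then masked out)
    ce = F.cross_entropy(logits.view(-1, vocab), ys_out.clamp(min=0).view(-1),
                         reduction="none")
    loss_att = (ce * mask.view(-1)).sum() / mask.sum()

    # soft-label term: cross-entropy against the teacher distribution
    log_probs = F.log_softmax(logits, dim=-1)
    loss_kd = -(soft_labels * log_probs).sum(-1).mul(mask).sum() / mask.sum()

    loss_att_kd = (1 - kd_weight) * loss_att + kd_weight * loss_kd
    return loss_att_kd, loss_kd, loss_att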
def forward_disc(self, ys, ylens=None, error_labels=None):
    if ylens is None:
        attention_mask = None
    else:
        attention_mask = make_nopad_mask(ylens).float().to(ys.device)
        # DataParallel
        ys = ys[:, :max(ylens)]

    loss, _ = self.dmodel(ys, attention_mask=attention_mask, labels=error_labels)

    loss_dict = {"loss_total": loss}

    return loss, loss_dict
def predict(self, ys, ylens, states=None):
    """ predict next token for Shallow Fusion """
    attention_mask = make_nopad_mask(ylens).float().to(ys.device)

    with torch.no_grad():
        (logits,) = self.transformer(ys, attention_mask, causal=True)
    log_probs = torch.log_softmax(logits, dim=-1)

    # collect the distribution at the last valid position of each sequence
    log_probs_next = []
    bs = len(ys)
    for b in range(bs):
        log_probs_next.append(tensor2np(log_probs[b, ylens[b] - 1]))

    return torch.tensor(log_probs_next).to(ys.device), states
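# In shallow fusion, `predict` is called once per decoding step and its
# next-token log-probabilities are interpolated with the ASR decoder's scores
# before the beam is pruned. The helper below is an illustrative sketch, not
# part of this codebase; `lm` is assumed to expose the `predict` interface
# above, and `asr_log_probs` / `lm_weight` are hypothetical names.
import torch


def fuse_lm_scores(asr_log_probs, lm, ys, ylens, states=None, lm_weight=0.3):
    """Combine ASR and LM next-token log-probabilities for one beam-search step."""
    lm_log_probs, states = lm.predict(ys, ylens, states)
    fused = asr_log_probs + lm_weight * lm_log_probs  # (bs, vocab)
    return fused, states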
def score(self, ys, ylens, batch_size=None):
    """ score token sequence for Rescoring """
    attention_mask = make_nopad_mask(ylens).float().to(ys.device)

    logits, = self.dmodel(ys, attention_mask=attention_mask)
    # probability that each token was replaced (i.e. is an error)
    probs = torch.sigmoid(logits)

    if ys.size(0) == 1:
        # single hypothesis: no padding to mask out; keep the same sign
        # convention as the batched path (higher score = fewer likely errors)
        return [(-1) * torch.sum(probs, dim=-1).item()]

    score_lms = []
    bs = len(ys)
    for b in range(bs):
        score_lm = (-1) * torch.sum(probs[b, :ylens[b]], dim=-1).item()
        score_lms.append(score_lm)

    return score_lms
def forward(self, ys, ylens=None, labels=None, ps=None, plens=None):
    if ylens is None:
        attention_mask = None
    else:
        attention_mask = make_nopad_mask(ylens).float().to(ys.device)
        # DataParallel
        ys = ys[:, :max(ylens)]

    if labels is None:
        (logits,) = self.bert(ys, attention_mask=attention_mask)
        return logits

    if ylens is not None:
        labels = labels[:, :max(ylens)]

    loss, logits = self.bert(ys, attention_mask=attention_mask, labels=labels)

    loss_dict = {"loss_total": loss}

    return loss, loss_dict
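# The `labels` tensor follows the usual masked-LM convention: positions that
# were not masked carry the ignore index -100 (the same convention the ELECTRA
# forward above relies on), and masked positions carry the original token id.
# A minimal sketch of building such inputs; `mask_id` and `mask_prob` are
# illustrative, not taken from this codebase.
import torch


def mask_for_mlm(ys, mask_id, mask_prob=0.15, ignore_id=-100):
    """Randomly mask tokens and build labels for masked-LM training."""
    labels = ys.clone()
    masked = torch.rand_like(ys, dtype=torch.float) < mask_prob
    ys_masked = ys.clone()
    ys_masked[masked] = mask_id    # replace selected tokens with [MASK]
    labels[~masked] = ignore_id    # loss is computed only on masked positions
    return ys_masked, labels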
def score(self, ys, ylens, batch_size=None):
    """ score token sequence for Rescoring """
    attention_mask = make_nopad_mask(ylens).float().to(ys.device)

    with torch.no_grad():
        (logits,) = self.transformer(ys, attention_mask, causal=True)
    log_probs = torch.log_softmax(logits, dim=-1)

    score_lms = []
    bs = len(ys)
    for b in range(bs):
        score_lm = 0
        for i in range(0, ylens[b] - 1):
            v = ys[b, i + 1].item()  # predict next
            score_lm += log_probs[b, i, v].item()
        score_lms.append(score_lm)

    return score_lms
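# For N-best rescoring, `score` is applied to the padded hypotheses of one
# utterance and the resulting LM scores are interpolated with the first-pass
# ASR scores before picking the best hypothesis. An illustrative sketch;
# `asr_scores` and `lm_weight` are assumed names, not part of this codebase.
import torch


def rescore_nbest(asr_scores, lm, ys, ylens, lm_weight=0.3):
    """Return the hypothesis index with the best fused ASR + LM score."""
    lm_scores = lm.score(ys, ylens)  # one LM score per hypothesis
    total = [a + lm_weight * s for a, s in zip(asr_scores, lm_scores)]
    return int(torch.tensor(total).argmax())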
def forward(self, ys, ylens=None, labels=None, ps=None, plens=None):
    if ylens is None:
        attention_mask = None
    else:
        attention_mask = make_nopad_mask(ylens).float().to(ys.device)
        # DataParallel
        ys = ys[:, :max(ylens)]

    if labels is None:
        # NOTE: causal attention mask
        (logits,) = self.transformer(ys, attention_mask=attention_mask, causal=True)
        return logits

    if ylens is not None:
        labels = labels[:, :max(ylens)]

    # NOTE: causal attention mask
    loss, logits = self.transformer(
        ys, attention_mask=attention_mask, causal=True, labels=labels
    )

    loss_dict = {"loss_total": loss}

    return loss, loss_dict