def forward(self, eouts, elens, ys, task='all', teacher_logits=None,
            recog_params={}, idx2token=None):
    """Forward pass.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (list): length `B`, each element of which is a list of token indices of length `[L]`
        task (str): all/ys*/ys_sub*
        teacher_logits (FloatTensor): `[B, L, vocab]`
        recog_params (dict): parameters for MBR training
        idx2token (): converter from token indices to text (used for MBR training)
    Returns:
        loss (FloatTensor): `[1]`
        observation (dict): sub-losses and statistics for logging

    """
    observation = {'loss': None, 'loss_att': None, 'loss_ctc': None,
                   'loss_mbr': None, 'acc_att': None, 'ppl_att': None}
    loss = eouts.new_zeros(1)

    # CTC loss
    trigger_points = None
    if self.ctc_weight > 0 and (task == 'all' or 'ctc' in task):
        # Forced alignment is needed for CTC-triggered training and for
        # triggered attention.
        forced_align = (self.ctc_trigger and self.training) or self.attn_type == 'triggered_attention'
        loss_ctc, trigger_points = self.ctc(eouts, elens, ys, forced_align=forced_align)
        observation['loss_ctc'] = tensor2scalar(loss_ctc)
        if self.mtl_per_batch:
            loss += loss_ctc
        else:
            loss += loss_ctc * self.ctc_weight

    # XE (attention) loss
    if self.att_weight > 0 and (task == 'all' or 'ctc' not in task):
        loss_att, acc_att, ppl_att, losses_auxiliary = self.forward_att(
            eouts, elens, ys, trigger_points=trigger_points)
        observation['loss_att'] = tensor2scalar(loss_att)
        observation['acc_att'] = acc_att
        observation['ppl_att'] = ppl_att

        # Auxiliary regularizers for monotonic chunkwise attention (MoChA)
        if self.attn_type == 'mocha':
            if self._quantity_loss_weight > 0:
                loss_att += losses_auxiliary['loss_quantity'] * self._quantity_loss_weight
            observation['loss_quantity'] = tensor2scalar(losses_auxiliary['loss_quantity'])
            if self.headdiv_loss_weight > 0:
                loss_att += losses_auxiliary['loss_headdiv'] * self.headdiv_loss_weight
                observation['loss_headdiv'] = tensor2scalar(losses_auxiliary['loss_headdiv'])

        if self.latency_metric:
            observation['loss_latency'] = tensor2scalar(losses_auxiliary['loss_latency']) if self.training else 0
            # 'decot' does not add an explicit latency loss term
            if self.latency_metric != 'decot' and self.latency_loss_weight > 0:
                loss_att += losses_auxiliary['loss_latency'] * self.latency_loss_weight

        if self.mtl_per_batch:
            loss += loss_att
        else:
            loss += loss_att * self.att_weight

    observation['loss'] = tensor2scalar(loss)
    return loss, observation
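
# --- Illustrative sketch (not part of the decoder above) ---
# A minimal, self-contained example of the multi-task loss interpolation
# performed in the forward pass: when `mtl_per_batch` is False, the CTC and
# attention losses are mixed within every batch using their respective
# weights; when it is True, each batch carries a single pre-selected loss at
# full weight. The weight and loss values below are made-up placeholders.
if __name__ == '__main__':
    import torch

    ctc_weight, att_weight = 0.3, 0.7   # assumed task weights (sum to 1.0)
    loss_ctc = torch.tensor([12.4])     # dummy scalar CTC loss, shape `[1]`
    loss_att = torch.tensor([8.1])      # dummy scalar attention loss, shape `[1]`

    # mtl_per_batch=False: weighted sum of both objectives in one batch
    loss = loss_ctc * ctc_weight + loss_att * att_weight
    print(loss)  # tensor([9.3900]) = 12.4 * 0.3 + 8.1 * 0.7
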
def forward(self, eouts, elens, ys, task='all', teacher_logits=None,
            recog_params={}, idx2token=None, trigger_points=None):
    """Forward pass.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (IntTensor): `[B]`
        ys (list): length `B`, each element of which is a list of token indices of length `[L]`
        task (str): all/ys*/ys_sub*
        teacher_logits (FloatTensor): `[B, L, vocab]`
        recog_params (dict): parameters for MBR training
        idx2token (): converter from token indices to text (used for MBR training)
        trigger_points (np.ndarray): `[B, L]` (not used here; kept for API compatibility)
    Returns:
        loss (FloatTensor): `[1]`
        observation (dict): sub-losses for logging

    """
    observation = {'loss': None, 'loss_transducer': None,
                   'loss_ctc': None, 'loss_mbr': None}
    loss = eouts.new_zeros((1,))

    # CTC loss
    if self.ctc_weight > 0 and (task == 'all' or 'ctc' in task):
        loss_ctc, _ = self.ctc(eouts, elens, ys)
        observation['loss_ctc'] = tensor2scalar(loss_ctc)
        if self.mtl_per_batch:
            loss += loss_ctc
        else:
            loss += loss_ctc * self.ctc_weight

    # RNN-T loss
    if self.rnnt_weight > 0 and (task == 'all' or 'ctc' not in task):
        loss_transducer = self.forward_transducer(eouts, elens, ys)
        observation['loss_transducer'] = tensor2scalar(loss_transducer)
        if self.mtl_per_batch:
            loss += loss_transducer
        else:
            loss += loss_transducer * self.rnnt_weight

    observation['loss'] = tensor2scalar(loss)
    return loss, observation
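
# --- Illustrative sketch (not part of the decoder above) ---
# Demonstrates how the `task` string gates which loss branches run in both
# forward passes: 'all' computes every enabled loss, a task name containing
# 'ctc' selects the CTC branch only, and any other name selects the
# attention/transducer branch only. The helper, weights, and task names
# below are assumptions for illustration, not part of the original module.
def active_branches(task, ctc_weight=0.3, rnnt_weight=0.7):
    """Return which loss branches the forward pass would compute."""
    branches = []
    if ctc_weight > 0 and (task == 'all' or 'ctc' in task):
        branches.append('ctc')
    if rnnt_weight > 0 and (task == 'all' or 'ctc' not in task):
        branches.append('transducer')
    return branches

assert active_branches('all') == ['ctc', 'transducer']  # both losses
assert active_branches('ys.ctc') == ['ctc']             # CTC branch only
assert active_branches('ys') == ['transducer']          # transducer only
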