    def forward(self, eouts, elens, ys, task='all',
                teacher_logits=None, recog_params={}, idx2token=None):
        """Forward pass.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): length `B`; each element is a list of token ids of length `[L]`
            task (str): all/ys*/ys_sub*
            teacher_logits (FloatTensor): `[B, L, vocab]`
            recog_params (dict): parameters for MBR training
            idx2token (): converter from token ids to text
        Returns:
            loss (FloatTensor): `[1]`
            observation (dict):

        """
        observation = {'loss': None, 'loss_att': None, 'loss_ctc': None, 'loss_mbr': None,
                       'acc_att': None, 'ppl_att': None}
        loss = eouts.new_zeros(1)  # scalar loss accumulator on the same device/dtype as eouts

        # CTC loss
        trigger_points = None
        if self.ctc_weight > 0 and (task == 'all' or 'ctc' in task):
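            # Forced alignment is only needed to extract CTC spike positions
            # (trigger points) for MoChA training or triggered attention.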
            forced_align = (self.ctc_trigger and self.training) or self.attn_type == 'triggered_attention'
            loss_ctc, trigger_points = self.ctc(eouts, elens, ys, forced_align=forced_align)
            observation['loss_ctc'] = tensor2scalar(loss_ctc)
            if self.mtl_per_batch:
                # Task alternation per batch: add the unscaled loss.
                loss += loss_ctc
            else:
                # Joint multi-task training: interpolate by the task weight.
                loss += loss_ctc * self.ctc_weight

        # Attention decoder cross-entropy (XE) loss
        if self.att_weight > 0 and (task == 'all' or 'ctc' not in task):
            loss_att, acc_att, ppl_att, losses_auxiliary = self.forward_att(
                eouts, elens, ys, trigger_points=trigger_points)
            observation['loss_att'] = tensor2scalar(loss_att)
            observation['acc_att'] = acc_att
            observation['ppl_att'] = ppl_att
            if self.attn_type == 'mocha':
                if self._quantity_loss_weight > 0:
                    loss_att += losses_auxiliary['loss_quantity'] * self._quantity_loss_weight
                observation['loss_quantity'] = tensor2scalar(losses_auxiliary['loss_quantity'])
            if self.headdiv_loss_weight > 0:
                loss_att += losses_auxiliary['loss_headdiv'] * self.headdiv_loss_weight
                observation['loss_headdiv'] = tensor2scalar(losses_auxiliary['loss_headdiv'])
            if self.latency_metric:
                observation['loss_latency'] = tensor2scalar(losses_auxiliary['loss_latency']) if self.training else 0
                if self.latency_metric != 'decot' and self.latency_loss_weight > 0:
                    loss_att += losses_auxiliary['loss_latency'] * self.latency_loss_weight
            if self.mtl_per_batch:
                loss += loss_att
            else:
                loss += loss_att * self.att_weight

        observation['loss'] = tensor2scalar(loss)
        return loss, observation
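
Both examples log scalar copies of each loss term through a `tensor2scalar` helper that is not shown in the snippets. A minimal sketch of what such a helper might look like, assuming it only needs to detach a tensor and return its Python value (the original utility may differ):

def tensor2scalar(x):
    # Pass plain floats through unchanged; otherwise detach from the
    # autograd graph and extract the Python scalar, so the observation
    # dict holds plain numbers rather than tensors.
    if isinstance(x, float):
        return x
    return x.detach().item()
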
Example #2
    def forward(self,
                eouts,
                elens,
                ys,
                task='all',
                teacher_logits=None,
                recog_params={},
                idx2token=None,
                trigger_points=None):
        """Forward pass.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (IntTensor): `[B]`
            ys (list): length `B`; each element is a list of token ids of length `[L]`
            task (str): all/ys*/ys_sub*
            teacher_logits (FloatTensor): `[B, L, vocab]`
            recog_params (dict): parameters for MBR training
            idx2token (): converter from token ids to text
            trigger_points (np.ndarray): `[B, L]`
        Returns:
            loss (FloatTensor): `[1]`
            observation (dict):

        """
        observation = {
            'loss': None,
            'loss_transducer': None,
            'loss_ctc': None,
            'loss_mbr': None
        }
        loss = eouts.new_zeros((1,))  # scalar loss accumulator on the same device/dtype as eouts

        # CTC loss
        if self.ctc_weight > 0 and (task == 'all' or 'ctc' in task):
            loss_ctc, _ = self.ctc(eouts, elens, ys)
            observation['loss_ctc'] = tensor2scalar(loss_ctc)
            if self.mtl_per_batch:
                loss += loss_ctc
            else:
                loss += loss_ctc * self.ctc_weight

        # RNN-T loss
        if self.rnnt_weight > 0 and (task == 'all' or 'ctc' not in task):
            loss_transducer = self.forward_transducer(eouts, elens, ys)
            observation['loss_transducer'] = tensor2scalar(loss_transducer)
            if self.mtl_per_batch:
                loss += loss_transducer
            else:
                loss += loss_transducer * self.rnnt_weight

        observation['loss'] = tensor2scalar(loss)
        return loss, observation
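
For orientation, a hypothetical smoke test for the transducer variant above. `decoder` stands in for an instance of the decoder class these methods belong to, and every concrete value below is made up; only the shapes follow the docstring (`eouts` is `[B, T, enc_n_units]`):

import torch

B, T, enc_n_units = 4, 50, 256
eouts = torch.randn(B, T, enc_n_units)                   # mock encoder outputs
elens = torch.tensor([50, 48, 45, 40], dtype=torch.int)  # valid lengths per utterance
ys = [[3, 10, 7], [5, 2], [8, 8, 8, 1], [4]]             # token ids, one list per utterance

loss, observation = decoder(eouts, elens, ys, task='all')  # joint CTC + RNN-T objective
loss.backward()
print(observation)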