Example #1
    def forward(self, eouts, elens, ys, forced_align=False):
        """Compute CTC loss.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (List): length `B`
            ys (List): length `B`, each of which contains a list of size `[L]`
        Returns:
            loss (FloatTensor): `[1]`
            trigger_points (IntTensor): `[B, L]`

        """
        # Concatenate all elements in ys for warpctc_pytorch
        ylens = np2tensor(np.fromiter([len(y) for y in ys], dtype=np.int32))
        ys_ctc = torch.cat([np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int32))
                            for y in ys], dim=0)
        # NOTE: do not copy to GPUs here

        # Compute CTC loss
        logits = self.output(eouts)
        loss = self.loss_fn(logits.transpose(1, 0), ys_ctc, elens, ylens)

        # Label smoothing for CTC
        if self.lsm_prob > 0:
            loss = loss * (1 - self.lsm_prob) + kldiv_lsm_ctc(logits, elens) * self.lsm_prob

        trigger_points = self.forced_align(logits, elens, ys, ylens) if forced_align else None

        if not self.training:
            self.data_dict['elens'] = tensor2np(elens)
            self.prob_dict['probs'] = tensor2np(torch.softmax(logits, dim=-1))

        return loss, trigger_points
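A minimal, self-contained sketch of the same loss interpolation, using PyTorch's built-in `F.ctc_loss` in place of warpctc_pytorch and a uniform-KL term standing in for `kldiv_lsm_ctc`; the function name, shapes, and the unmasked KL term are illustrative assumptions, not the repository's API.

import torch
import torch.nn.functional as F

def ctc_loss_with_lsm(logits, elens, ys_flat, ylens, lsm_prob=0.1, blank=0):
    """Interpolate CTC loss with a uniform-smoothing KL term (sketch).

    logits: `[B, T, vocab]`; ys_flat: concatenated targets `[sum(L)]`.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # F.ctc_loss expects `[T, B, vocab]` log-probabilities
    ctc = F.ctc_loss(log_probs.transpose(0, 1), ys_flat, elens, ylens,
                     blank=blank, reduction='mean', zero_infinity=True)
    # KL divergence to the uniform distribution (label smoothing for CTC);
    # the real kldiv_lsm_ctc also masks padded frames
    uniform = torch.full_like(log_probs, 1.0 / logits.size(-1))
    kld = F.kl_div(log_probs, uniform, reduction='batchmean')
    return ctc * (1 - lsm_prob) + kld * lsm_prob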
Example #2
 def ctc_probs_topk(self, eouts, temperature, topk):
     probs = F.softmax(self.ctc.output(eouts) / temperature, dim=-1)
     if topk is None:
         topk = probs.size(-1)
     _, topk_ids = torch.topk(probs.sum(1),
                              k=topk,
                              dim=-1,
                              largest=True,
                              sorted=True)
     return tensor2np(probs), tensor2np(topk_ids)
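The top-K selection above can be exercised standalone: summing the posteriors over time gives a per-class activity score, and `torch.topk` picks the most active classes. The sizes below are toy assumptions.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 50, 100)             # `[B, T, vocab]`, toy sizes
probs = F.softmax(logits / 2.0, dim=-1)      # temperature = 2.0
activity = probs.sum(1)                      # per-class mass over time, `[B, vocab]`
_, topk_ids = torch.topk(activity, k=5, dim=-1, largest=True, sorted=True)
print(topk_ids.shape)                        # torch.Size([2, 5])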
Example #3
 def ctc_posteriors(self, eouts, x_lens, temperature, topk):
     # Pass through the softmax layer
     logits_ctc = self.output_ctc(eouts)
     ctc_probs = F.softmax(logits_ctc / temperature, dim=-1)
     if topk is None:
         topk = ctc_probs.size(-1)
     _, indices_topk = torch.topk(ctc_probs.sum(1),
                                  k=topk,
                                  dim=-1,
                                  largest=True,
                                  sorted=True)
     return tensor2np(ctc_probs), tensor2np(indices_topk)
Example #4
    def add_ctc_score(self,
                      hyp,
                      topk_ids,
                      ctc_state,
                      total_scores_topk,
                      ctc_prefix_scorer,
                      new_chunk=False,
                      backward=False):
        beam_width = self.beam_width_bwd if backward else self.beam_width
        if ctc_prefix_scorer is None:
            return None, topk_ids.new_zeros(beam_width), total_scores_topk

        ctc_scores, new_ctc_states = ctc_prefix_scorer(hyp,
                                                       tensor2np(topk_ids[0]),
                                                       ctc_state,
                                                       new_chunk=new_chunk)
        total_scores_ctc = torch.from_numpy(ctc_scores).to(self.device)
        total_scores_topk += total_scores_ctc * self.ctc_weight
        # Sort again
        total_scores_topk, joint_ids_topk = torch.topk(total_scores_topk,
                                                       k=beam_width,
                                                       dim=1,
                                                       largest=True,
                                                       sorted=True)
        topk_ids = topk_ids[:, joint_ids_topk[0]]
        new_ctc_states = new_ctc_states[joint_ids_topk[0].cpu().numpy()]
        return new_ctc_states, total_scores_ctc, total_scores_topk
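The re-sorting step can be isolated as follows: add the weighted CTC prefix scores to the attention top-K scores, run `torch.topk` again, and re-index the surviving token candidates. All tensors here are dummies for illustration.

import torch

beam_width, ctc_weight = 4, 0.3
total_scores_topk = torch.randn(1, 10)       # attention scores of 10 candidates
topk_ids = torch.arange(10).unsqueeze(0)     # their token ids, `[1, 10]`
ctc_scores = torch.randn(1, 10)              # CTC prefix scores per candidate

total_scores_topk = total_scores_topk + ctc_scores * ctc_weight
total_scores_topk, joint_ids_topk = torch.topk(
    total_scores_topk, k=beam_width, dim=1, largest=True, sorted=True)
topk_ids = topk_ids[:, joint_ids_topk[0]]    # keep ids of the surviving candidates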
Example #5
    def decode(self, ys, state=None, mems=None, cache=None, incremental=False):
        """Decode function.

        Args:
            ys (LongTensor): `[B, L]`
            state (list): dummy interface for RNNLM
            mems (list): length `n_layers`, each of which contains a FloatTensor `[B, mlen, d_model]`
            cache (list): length `L`, each of which contains a FloatTensor `[B, L-1, d_model]`
            incremental (bool): ASR decoding mode
        Returns:
            logits (FloatTensor): `[B, L, vocab]`
            out (FloatTensor): `[B, L, d_model]`
            new_cache (list): length `n_layers`, each of which contains a FloatTensor `[B, L, d_model]`

        """
        # for ASR decoding
        if cache is None:
        cache = [None] * self.n_layers  # one slot per layer (1 to n_layers)

        if mems is None:
            mems = self.init_memory()

        # Create the self-attention mask
        bs, ylen = ys.size()[:2]
        if incremental and cache[0] is not None:
            ylen = cache[0].size(1) + 1
        causal_mask = ys.new_ones(ylen, ylen).byte()
        causal_mask = torch.tril(causal_mask, diagonal=0,
                                 out=causal_mask).unsqueeze(0)
        causal_mask = causal_mask.repeat([bs, 1, 1])

        out = self.pos_enc(self.embed(ys.long()))

        new_mems = [None] * self.n_layers
        new_cache = [None] * self.n_layers
        hidden_states = [out]
        for lth, (mem, layer) in enumerate(zip(mems, self.layers)):
            out = layer(out, causal_mask, cache=cache[lth], memory=mem)
            if incremental:
                new_cache[lth] = out
            elif lth < self.n_layers - 1:
                hidden_states.append(out)
                # NOTE: outputs from the last layer are not used for memory
            if not self.training and layer.yy_aws is not None:
                setattr(self, 'yy_aws_layer%d' % lth, tensor2np(layer.yy_aws))
        out = self.norm_out(out)
        if self.adaptive_softmax is None:
            logits = self.output(out)
        else:
            logits = out

        if incremental:
            # NOTE: do not update memory here during ASR decoding
            return logits, out, new_cache
        elif self.mem_len > 0:
            # Update memory
            new_mems = self.update_memory(mems, hidden_states)
            return logits, out, new_mems
        else:
            return logits, out, mems
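`update_memory` is not shown in this snippet; a plausible sketch of the TransformerXL-style update is below: concatenate each layer's cached states with its new hidden states along time and keep only the last `mem_len` positions, detached from the graph. This is an assumption-level illustration, not the repository's implementation.

import torch

def update_memory_sketch(mems, hidden_states, mem_len=16):
    """Concatenate along time and truncate to the last `mem_len` frames (sketch)."""
    new_mems = []
    for mem, h in zip(mems, hidden_states):
        cat = h if mem is None else torch.cat([mem, h], dim=1)  # `[B, mlen + L, d_model]`
        new_mems.append(cat[:, -mem_len:].detach())  # no backprop through memory
    return new_mems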
Example #6
    def ctc_forced_align(self, xs, ys, task='ys'):
        """CTC-based forced alignment.

        Args:
            xs (FloatTensor): `[B, T, idim]`
            ys (List): length `B`, each of which contains a list of size `[L]`
        Returns:
            trigger_points (np.ndarray): `[B, L]`

        """
        from neural_sp.models.seq2seq.decoders.ctc import CTCForcedAligner
        forced_aligner = CTCForcedAligner()

        self.eval()
        with torch.no_grad():
            eout_dict = self.encode(xs, 'ys')
            # NOTE: support the main task only
            ctc = getattr(self, 'dec_fwd').ctc
            logits = ctc.output(eout_dict[task]['xs'])
            ylens = np2tensor(np.fromiter([len(y) for y in ys],
                                          dtype=np.int32))
            trigger_points = forced_aligner(logits, eout_dict[task]['xlens'],
                                            ys, ylens)

        return tensor2np(trigger_points)
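For intuition, trigger points are the frames at which the alignment first emits each non-blank label. The rough sketch below derives them from the frame-wise argmax (best-path) alignment; the actual `CTCForcedAligner` aligns against the reference transcript instead, so this is only an approximation.

import torch

def greedy_trigger_points(logits, blank=0):
    """Approximate trigger points from the best path (sketch, single utterance)."""
    path = logits.argmax(-1)          # `[T]`
    points, prev = [], blank
    for t, token in enumerate(path.tolist()):
        if token != blank and token != prev:
            points.append(t)          # first frame of a newly emitted label
        prev = token
    return points

print(greedy_trigger_points(torch.randn(20, 10)))  # toy `[T, vocab]` input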
Example #7
    def get_ctc_probs(self, xs, task='ys', temperature=1, topk=None):
        self.eval()
        with torch.no_grad():
            eout_dict = self.encode(xs, task)
            dir = 'fwd' if self.fwd_weight >= self.bwd_weight else 'bwd'
            if task == 'ys_sub1':
                dir += '_sub1'
            elif task == 'ys_sub2':
                dir += '_sub2'

            if task == 'ys':
                assert self.ctc_weight > 0
            elif task == 'ys_sub1':
                assert self.ctc_weight_sub1 > 0
            elif task == 'ys_sub2':
                assert self.ctc_weight_sub2 > 0
            ctc_probs, indices_topk = getattr(self, 'dec_' + dir).ctc_probs_topk(
                eout_dict[task]['xs'], temperature, topk)
            return tensor2np(ctc_probs), tensor2np(indices_topk), eout_dict[task]['xlens']
Example #8
    def forward(self, xs, xlens, task):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list): `[B]`
            task (str): not supported now
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (list): `[B]`

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Pass through CNN blocks before the self-attention layers
            xs, xlens = self.conv(xs, xlens)

        # Create the self-attention mask
        bs, xmax = xs.size()[:2]
        xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(1).expand(
            bs, xmax, xmax)
        xx_mask = xx_mask.unsqueeze(1).expand(bs, self.attn_n_heads, xmax,
                                              xmax)

        xs = self.pos_enc(xs)
        for l in range(self.n_layers):
            xs, xx_aws = self.layers[l](xs, xx_mask)
            if not self.training:
                setattr(self, 'xx_aws_layer%d' % l, tensor2np(xx_aws))
        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        eouts['ys']['xs'] = xs
        eouts['ys']['xlens'] = xlens
        return eouts
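`make_pad_mask` can be sketched as comparing a position index against each sequence length; the encoder then broadcasts the `[B, xmax]` mask to `[B, xmax, xmax]` and over heads. A minimal version, assuming True marks valid frames:

import torch

def make_pad_mask_sketch(xlens):
    """True for valid positions, False for padding (sketch). Returns `[B, xmax]`."""
    xlens = torch.as_tensor(xlens)
    xmax = int(xlens.max())
    return torch.arange(xmax).unsqueeze(0) < xlens.unsqueeze(1)

mask = make_pad_mask_sketch([3, 5])               # `[2, 5]`
xx_mask = mask.unsqueeze(1).expand(2, 5, 5)       # `[B, xmax, xmax]`
print(mask)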
Example #9
    def decode(self, ys, state=None, mems=None, cache=None, incremental=False):
        """Decode function.

        Args:
            ys (LongTensor): `[B, L]`
            state (List): dummy interface for RNNLM
            mems (List): length `n_layers` (inter-utterance),
                each of which contains a FloatTensor of size `[B, mlen, d_model]`
            cache (List): length `n_layers` (intra-utterance),
                each of which contains a FloatTensor of size `[B, L-1, d_model]`
            incremental (bool): ASR decoding mode
        Returns:
            logits (FloatTensor): `[B, L, vocab]`
            out (FloatTensor): `[B, L, d_model]`
            new_cache (List): length `n_layers`,
                each of which contains a FloatTensor of size `[B, L, d_model]`

        """
        # for ASR decoding
        if cache is None:
        cache = [None] * self.n_layers  # one slot per layer (1 to n_layers)

        bs, ylen = ys.size()[:2]
        n_hist = 0
        if incremental and cache[0] is not None:
            n_hist = cache[0].size(1)
            ylen += n_hist

        # Create the self-attention mask
        causal_mask = ys.new_ones(ylen, ylen).byte()
        causal_mask = torch.tril(causal_mask).unsqueeze(0)
        causal_mask = causal_mask.repeat([bs, 1, 1])  # `[B, L, L]`

        out = self.pos_enc(self.embed_token_id(ys),
                           scale=True,
                           offset=max(0, n_hist))  # scaled + dropout

        new_cache = [None] * self.n_layers
        hidden_states = [out]
        for lth, layer in enumerate(self.layers):
            out = layer(out, causal_mask, cache=cache[lth])
            if incremental:
                new_cache[lth] = out
            elif lth < self.n_layers - 1:
                hidden_states.append(out)
                # NOTE: outputs from the last layer are not used for cache
            if not self.training and layer.yy_aws is not None:
                setattr(self, 'yy_aws_layer%d' % lth, tensor2np(layer.yy_aws))
        out = self.norm_out(out)
        if self.adaptive_softmax is None:
            logits = self.output(out)
        else:
            logits = out

        return logits, out, new_cache
Example #10
    def get_ctc_probs(self, xs, task='ys', temperature=1, topk=None):
        """Get CTC top-K probabilities.

        Args:
            xs (FloatTensor): `[B, T, idim]`
            task (str): task to evaluate
            temperature (float): softmax temperature
            topk (int): top-K classes to sample
        Returns:
            probs (np.ndarray): `[B, T, vocab]`
            topk_ids (np.ndarray): `[B, T, topk]`
            elens (IntTensor): `[B]`

        """
        self.eval()
        with torch.no_grad():
            eout_dict = self.encode(xs, task)
            dir = 'fwd' if self.fwd_weight >= self.bwd_weight else 'bwd'
            if task == 'ys_sub1':
                dir += '_sub1'
            elif task == 'ys_sub2':
                dir += '_sub2'

            if task == 'ys':
                assert self.ctc_weight > 0
            elif task == 'ys_sub1':
                assert self.ctc_weight_sub1 > 0
            elif task == 'ys_sub2':
                assert self.ctc_weight_sub2 > 0

            probs = getattr(self,
                            'dec_' + dir).ctc.probs(eout_dict[task]['xs'])
            if topk is None:
                topk = probs.size(-1)  # return all classes
            _, topk_ids = torch.topk(probs,
                                     k=topk,
                                     dim=-1,
                                     largest=True,
                                     sorted=True)

            return tensor2np(probs), tensor2np(
                topk_ids), eout_dict[task]['xlens']
Example #11
 def sub_module(self, xs, xx_mask, lth, pos_embs=None, module='sub1'):
     if self.task_specific_layer:
         xs_sub = getattr(self, 'layer_' + module)(xs, xx_mask, pos_embs=pos_embs)
     else:
         xs_sub = xs.clone()
     xs_sub = getattr(self, 'norm_out_' + module)(xs_sub)
     if getattr(self, 'bridge_' + module) is not None:
         xs_sub = getattr(self, 'bridge_' + module)(xs_sub)
     if not self.training:
         self.aws_dict['xx_aws_%s_layer%d' % (module, lth)] = tensor2np(getattr(self, 'layer_' + module).xx_aws)
     return xs_sub
Example #12
    def ctc_forced_align(self, xs, ys, task='ys'):
        """CTC-based forced alignment.

        Args:
            xs (FloatTensor): `[B, T, idim]`
            ys (List): length `B`, each of which contains a list of size `[L]`
        Returns:
            trigger_points (np.ndarray): `[B, L]`

        """
        self.eval()
        with torch.no_grad():
            eout_dict = self.encode(xs, 'ys')
            # NOTE: support the main task only
            trigger_points = getattr(self, 'dec_fwd').ctc_forced_align(
                eout_dict[task]['xs'], eout_dict[task]['xlens'], ys)
        return tensor2np(trigger_points)
Example #13
    def decode(self, ys, state=None, is_asr=False):
        """Decode function.

        Args:
            ys (LongTensor): `[B, L]`
            state: previous tokens
            is_asr (bool):
        Returns:
            ys_emb (FloatTensor): `[B, L, n_units]`
            state: previous tokens

        """
        # Concatenate previous tokens
        if is_asr and state is not None:
            ys = torch.cat([state, ys], dim=1)
            # NOTE: this is used for ASR decoding

        ys_emb = self.embed(ys.long())

        # Create the self-attention mask
        bs, ymax = ys_emb.size()[:2]
        ylens = torch.IntTensor([ymax] * bs)
        yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand(
            bs, ymax, ymax)
        yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax,
                                              ymax)
        subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(),
                                     diagonal=0)
        subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, ymax)
        yy_mask = yy_mask & subsequent_mask

        ys_emb = self.pos_enc(ys_emb)
        for l in range(self.n_layers):
            ys_emb, yy_aws, _ = self.layers[l](ys_emb, yy_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
        ys_emb = self.norm_out(ys_emb)

        if is_asr:
            state = ys

        return ys_emb, state
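The mask construction above combines a padding mask with a lower-triangular (subsequent) mask; a compact standalone sketch with boolean masks and toy sizes:

import torch

bs, n_heads, ymax = 2, 4, 5
ylens = torch.tensor([5, 3])
pad = torch.arange(ymax).unsqueeze(0) < ylens.unsqueeze(1)        # `[B, ymax]`
yy_mask = pad.unsqueeze(1).unsqueeze(2).expand(bs, n_heads, ymax, ymax)
subsequent = torch.tril(torch.ones(ymax, ymax, dtype=torch.bool))
yy_mask = yy_mask & subsequent        # broadcasts over batch and heads
print(yy_mask[0, 0].int())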
Example #14
    def decode(self, ys, state=None, mems=None, cache=None, incremental=False):
        """Decode function.

        Args:
            ys (LongTensor): `[B, L]`
            state (list): dummy interface for RNNLM
            mems (list): dummy interface for TransformerXL
            cache (list): length `L`, each of which contains a FloatTensor `[B, L-1, d_model]`
            incremental (bool): ASR decoding mode
        Returns:
            logits (FloatTensor): `[B, L, vocab]`
            out (FloatTensor): `[B, L, d_model]`
            new_cache (list): length `n_layers`, each of which contains a FloatTensor `[B, L, d_model]`
            new_mems: dummy interface for TransformerXL

        """
        # for ASR decoding
        if cache is None:
            cache = [None] * self.n_layers

        # Create the self-attention mask
        bs, ylen = ys.size()[:2]
        if incremental and cache[0] is not None:
            ylen = cache[0].size(1) + 1
        causal_mask = ys.new_ones(ylen, ylen).byte()
        causal_mask = torch.tril(causal_mask, diagonal=0, out=causal_mask).unsqueeze(0)
        causal_mask = causal_mask.repeat([bs, 1, 1])

        new_cache = [None] * self.n_layers
        out = self.pos_enc(self.embed(ys.long()))
        for l, layer in enumerate(self.layers):
            out, yy_aws = layer(out, causal_mask, cache=cache[l])[:2]
            if incremental:
                new_cache[l] = out
            if not self.training and yy_aws is not None:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
        out = self.norm_out(out)
        if self.adaptive_softmax is None:
            logits = self.output(out)
        else:
            logits = out

        return logits, out, new_cache
Example #15
    def decode_ctc(self, eouts, x_lens, beam_width=1, rnnlm=None):
        """Decoding by the CTC layer in the inference stage.

        This is only used for the joint CTC-attention model.
        Args:
            eouts (FloatTensor): `[B, T, enc_units]`
            beam_width (int): the size of beam
            rnnlm ():
        Returns:
            best_hyps (list): A list of length `[B]`, which contains arrays of size `[L]`

        """
        logits_ctc = self.output_ctc(eouts)
        if beam_width == 1:
            best_hyps = self.decode_ctc_greedy(tensor2np(logits_ctc), x_lens)
        else:
            best_hyps = self.decode_ctc_beam(F.log_softmax(logits_ctc, dim=-1),
                                             x_lens, beam_width, rnnlm)
            # TODO(hirofumi): decoding parameters

        return best_hyps
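`decode_ctc_greedy` is not shown here; best-path CTC decoding can be sketched as: argmax per frame, collapse repeated labels, drop blanks. An illustrative stand-in, not the repository's implementation:

import numpy as np

def ctc_greedy_decode_sketch(logits, x_lens, blank=0):
    """Best-path decoding: argmax per frame, collapse repeats, remove blanks."""
    best_hyps = []
    for b, T in enumerate(x_lens):
        path = logits[b, :T].argmax(-1)   # `[T]`
        prev, hyp = blank, []
        for token in path:
            if token != blank and token != prev:
                hyp.append(int(token))
            prev = token
        best_hyps.append(np.array(hyp))
    return best_hyps

print(ctc_greedy_decode_sketch(np.random.randn(2, 30, 10), x_lens=[30, 25]))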
Example #16
    def decode(self, ys, ys_prev=None, cache=False):
        """Decode function.

        Args:
            ys (LongTensor): `[B, L]`
            ys_prev (LongTensor): previous tokens
            cache (bool): concatenate previous tokens
        Returns:
            logits (FloatTensor): `[B, L, vocab]`
            ys_emb (FloatTensor): `[B, L, d_model]` (for ys_prev)
            ys_prev (LongTensor): previous tokens

        """
        # Concatenate previous tokens
        if cache and ys_prev is not None:
            ys = torch.cat([ys_prev, ys], dim=1)
            # NOTE: this is used for ASR decoding

        # Create the self-attention mask
        bs, ymax = ys.size()[:2]
        ylens = torch.IntTensor([ymax] * bs)
        tgt_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat(
            [1, ymax, 1])
        subsequent_mask = tgt_mask.new_ones(ymax, ymax).byte()
        subsequent_mask = torch.tril(subsequent_mask,
                                     out=subsequent_mask).unsqueeze(0)
        tgt_mask = tgt_mask & subsequent_mask

        out = self.pos_enc(self.embed(ys.long()))
        for l in range(self.n_layers):
            out, yy_aws, _ = self.layers[l](out, tgt_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
        out = self.norm_out(out)
        if self.adaptive_softmax is None:
            logits = self.output(out)
        else:
            logits = out

        return logits, out, ys
Example #17
    def decode(self, ys, state=None, cache=False):
        """Decode function.

        Args:
            ys (LongTensor): `[B, L]`
            state (LongTensor): `[B, L]`
            cache (bool): concatenate previous tokens
        Returns:
            logits (FloatTensor): `[B, L, vocab]`
            out (FloatTensor): `[B, L, d_model]`
            new_state (LongTensor): previous tokens

        """
        # Concatenate previous tokens
        if cache and state is not None:
            ys = torch.cat([state, ys], dim=1)
            # NOTE: this is used for ASR decoding

        # Create the self-attention mask
        bs, ylen = ys.size()[:2]
        causal_mask = ys.new_ones(ylen, ylen).byte()
        causal_mask = torch.tril(causal_mask, diagonal=0, out=causal_mask).unsqueeze(0)
        causal_mask = causal_mask.repeat([bs, 1, 1])

        out = self.pos_enc(self.embed(ys.long()))
        for l, layer in enumerate(self.layers):
            out, yy_aws = layer(out, causal_mask)[:2]
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
        out = self.norm_out(out)
        if self.adaptive_softmax is None:
            logits = self.output(out)
        else:
            logits = out

        return logits, out, ys
Example #18
    def beam_search(self, eouts, elens, params, idx2token=None,
                    lm=None, lm_second=None, lm_second_bwd=None, ctc_log_probs=None,
                    nbest=1, exclude_eos=False,
                    refs_id=None, utt_ids=None, speakers=None,
                    ensmbl_eouts=[], ensmbl_elens=[], ensmbl_decs=[], cache_states=True):
        """Beam search decoding.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            params (dict): decoding hyperparameters
            idx2token (): converter from index to token
            lm (torch.nn.module): first-pass LM
            lm_second (torch.nn.module): second-pass LM
            lm_second_bwd (torch.nn.module): second-pass backward LM
            ctc_log_probs (FloatTensor):
            nbest (int): size of the N-best list
            exclude_eos (bool): exclude <eos> from hypothesis
            refs_id (List): reference list
            utt_ids (List): utterance id list
            speakers (List): speaker list
            ensmbl_eouts (List[FloatTensor]): encoder outputs for ensemble models
            ensmbl_elens (List[IntTensor]): encoder output lengths for ensemble models
            ensmbl_decs (List[torch.nn.Module]): decoders for ensemble models
            cache_states (bool): cache decoder states for fast decoding
        Returns:
            nbest_hyps_idx (List): length `[B]`, each of which contains a list of N hypotheses
            aws (List): length `[B]`, each of which contains arrays of size `[H, L, T]`
            scores (List):

        """
        bs, xmax, _ = eouts.size()
        n_models = len(ensmbl_decs) + 1

        beam_width = params.get('recog_beam_width')
        assert 1 <= nbest <= beam_width
        ctc_weight = params.get('recog_ctc_weight')
        max_len_ratio = params.get('recog_max_len_ratio')
        min_len_ratio = params.get('recog_min_len_ratio')
        lp_weight = params.get('recog_length_penalty')
        length_norm = params.get('recog_length_norm')
        cache_emb = params.get('recog_cache_embedding')
        lm_weight = params.get('recog_lm_weight')
        lm_weight_second = params.get('recog_lm_second_weight')
        lm_weight_second_bwd = params.get('recog_lm_bwd_weight')
        eos_threshold = params.get('recog_eos_threshold')
        lm_state_carry_over = params.get('recog_lm_state_carry_over')
        softmax_smoothing = params.get('recog_softmax_smoothing')
        eps_wait = params.get('recog_mma_delay_threshold')

        helper = BeamSearch(beam_width, self.eos, ctc_weight, lm_weight, self.device)
        lm = helper.verify_lm_eval_mode(lm, lm_weight, cache_emb)
        lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second, cache_emb)
        lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd, lm_weight_second_bwd, cache_emb)

        # cache token embeddings
        if cache_emb:
            self.cache_embedding(eouts.device)

        if ctc_log_probs is not None:
            assert ctc_weight > 0
            ctc_log_probs = tensor2np(ctc_log_probs)

        nbest_hyps_idx, aws, scores = [], [], []
        eos_flags = []
        for b in range(bs):
            # Initialization per utterance
            lmstate = None
            ys = eouts.new_zeros((1, 1), dtype=torch.int64).fill_(self.eos)
            for layer in self.layers:
                layer.reset()

            # For joint CTC-Attention decoding
            ctc_prefix_scorer = None
            if ctc_log_probs is not None:
                if self.bwd:
                    ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b][::-1], self.blank, self.eos)
                else:
                    ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b], self.blank, self.eos)

            if speakers is not None:
                if speakers[b] == self.prev_spk:
                    if lm_state_carry_over and isinstance(lm, RNNLM):
                        lmstate = self.lmstate_final
                self.prev_spk = speakers[b]

            end_hyps = []
            hyps = [{'hyp': [self.eos],
                     'ys': ys,
                     'cache': None,
                     'score': 0.,
                     'score_att': 0.,
                     'score_ctc': 0.,
                     'score_lm': 0.,
                     'aws': [None],
                     'lmstate': lmstate,
                     'ensmbl_cache': [[None] * dec.n_layers for dec in ensmbl_decs] if n_models > 1 else None,
                     'ctc_state': ctc_prefix_scorer.initial_state() if ctc_prefix_scorer is not None else None,
                     'quantity_rate': 1.,
                     'streamable': True,
                     'streaming_failed_point': 1000}]
            streamable_global = True
            ymax = math.ceil(elens[b] * max_len_ratio)
            for i in range(ymax):
                # batchfy all hypotheses for batch decoding
                cache = [None] * self.n_layers
                if cache_states and i > 0:
                    for lth in range(self.n_layers):
                        cache[lth] = torch.cat([beam['cache'][lth] for beam in hyps], dim=0)
                ys = eouts.new_zeros((len(hyps), i + 1), dtype=torch.int64)
                for j, beam in enumerate(hyps):
                    ys[j, :] = beam['ys']
                if i > 0:
                    xy_aws_prev = torch.cat([beam['aws'][-1] for beam in hyps], dim=0)  # `[B, n_layers, H_ma, 1, klen]`
                else:
                    xy_aws_prev = None

                # Update LM states for shallow fusion
                y_lm = ys[:, -1:].clone()  # NOTE: this is important
                _, lmstate, scores_lm = helper.update_rnnlm_state_batch(lm, hyps, y_lm)

                # for the main model
                causal_mask = eouts.new_ones(i + 1, i + 1, dtype=torch.uint8)
                causal_mask = torch.tril(causal_mask).unsqueeze(0).repeat([ys.size(0), 1, 1])
                out = self.pos_enc(self.embed_token_id(ys), scale=True)  # scaled + dropout
                n_heads_total = 0
                eouts_b = eouts[b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1]) # [Beam, T, dim]
                new_cache = [None] * self.n_layers
                xy_aws_layers = []
                xy_aws = None
                lth_s = self.mma_first_layer - 1
                # autoregressive pass through the decoder layers
                for lth, layer in enumerate(self.layers):
                    out = layer(
                        out, causal_mask, eouts_b, None,
                        cache=cache[lth],
                        xy_aws_prev=xy_aws_prev[:, lth - lth_s] if lth >= lth_s and i > 0 else None,
                        eps_wait=eps_wait)
                    xy_aws = layer.xy_aws

                    new_cache[lth] = out
                    if xy_aws is not None:
                        xy_aws_layers.append(xy_aws)
                logits = self.output(self.norm_out(out[:, -1]))  # output distribution at the current step
                probs = torch.softmax(logits * softmax_smoothing, dim=1)
                xy_aws_layers = torch.stack(xy_aws_layers, dim=1)  # `[B, H, n_layers, L, T]`

                # Ensemble initialization
                ensmbl_cache = [[None] * dec.n_layers for dec in ensmbl_decs]
                if n_models > 1 and cache_states and i > 0:
                    for i_e, dec in enumerate(ensmbl_decs):
                        for lth in range(dec.n_layers):
                            ensmbl_cache[i_e][lth] = torch.cat([beam['ensmbl_cache'][i_e][lth]
                                                                for beam in hyps], dim=0)

                # for the ensemble
                ensmbl_new_cache = [[None] * dec.n_layers for dec in ensmbl_decs]
                for i_e, dec in enumerate(ensmbl_decs):
                    out_e = dec.pos_enc(dec.embed(ys))  # scaled + dropout
                    eouts_e = ensmbl_eouts[i_e][b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
                    for lth in range(dec.n_layers):
                        out_e = dec.layers[lth](out_e, causal_mask, eouts_e, None,
                                                cache=ensmbl_cache[i_e][lth])
                        ensmbl_new_cache[i_e][lth] = out_e
                    logits_e = dec.output(dec.norm_out(out_e[:, -1]))
                    probs += torch.softmax(logits_e * softmax_smoothing, dim=1)
                    # NOTE: sum in the probability scale (not log-scale)

                # Ensemble: average the models in the probability scale
                scores_att = torch.log(probs / n_models)  # `[n_hyps, vocab]`
                new_hyps = []
                for j, beam in enumerate(hyps):  # expand each hypothesis into beam_width candidates
                    # Attention scores
                    total_scores_att = beam['score_att'] + scores_att[j:j + 1]  # accumulated score, `[1, vocab]`
                    total_scores = total_scores_att * (1 - ctc_weight)

                    # Add LM score <before> top-K selection
                    if lm is not None:
                        total_scores_lm = beam['score_lm'] + scores_lm[j:j + 1, -1]
                        total_scores += total_scores_lm * lm_weight
                    else:
                        total_scores_lm = eouts.new_zeros(1, self.vocab)

                    total_scores_topk, topk_ids = torch.topk(
                        total_scores, k=beam_width, dim=1, largest=True, sorted=True)

                    # Add length penalty
                    if lp_weight > 0:
                        total_scores_topk += (len(beam['hyp'][1:]) + 1) * lp_weight

                    # Add CTC score
                    new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                        beam['hyp'], topk_ids, beam['ctc_state'],
                        total_scores_topk, ctc_prefix_scorer)

                    new_aws = beam['aws'] + [xy_aws_layers[j:j + 1, :, :, -1:]]
                    aws_j = torch.cat(new_aws[1:], dim=3)  # `[1, H, n_layers, L, T]`

                    # forward direction
                    for k in range(beam_width):
                        idx = topk_ids[0, k].item()  # token id of the k-th candidate
                        length_norm_factor = len(beam['hyp'][1:]) + 1 if length_norm else 1
                        total_score = total_scores_topk[0, k].item() / length_norm_factor  # normalize by hypothesis length

                        if idx == self.eos:
                            # Exclude short hypotheses
                            # drop hypotheses shorter than min_len_ratio * input length
                            if len(beam['hyp'][1:]) < elens[b] * min_len_ratio:
                                continue
                            # EOS threshold
                            # find the best score among non-<eos> tokens
                            max_score_no_eos = scores_att[j, :idx].max(0)[0].item()
                            max_score_no_eos = max(max_score_no_eos, scores_att[j, idx + 1:].max(0)[0].item())
                            if scores_att[j, idx].item() <= eos_threshold * max_score_no_eos:
                                # keep decoding; do not emit <eos> yet
                                continue

                        streaming_failed_point = beam['streaming_failed_point']
                        quantity_rate = 1.
                        # streamability bookkeeping (MoChA only)
                        if self.attn_type == 'mocha':
                            n_tokens_hyp_k = i + 1
                            n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                            quantity_diff = n_tokens_hyp_k * n_heads_total - n_quantity_k

                            if quantity_diff != 0:
                                if idx == self.eos:
                                    n_tokens_hyp_k -= 1  # NOTE: do not count <eos> for streamability
                                    n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                                else:
                                    streamable_global = False
                                if n_tokens_hyp_k * n_heads_total == 0:
                                    quantity_rate = 0
                                else:
                                    quantity_rate = n_quantity_k / (n_tokens_hyp_k * n_heads_total)

                            if beam['streamable'] and not streamable_global:
                                streaming_failed_point = i

                        new_hyps.append(
                            {'hyp': beam['hyp'] + [idx],
                             'ys': torch.cat([beam['ys'], eouts.new_zeros((1, 1), dtype=torch.int64).fill_(idx)], dim=-1),
                             'cache': [new_cache_l[j:j + 1] for new_cache_l in new_cache] if cache_states else cache,
                             'score': total_score,
                             'score_att': total_scores_att[0, idx].item(),
                             'score_ctc': total_scores_ctc[k].item(),
                             'score_lm': total_scores_lm[0, idx].item(),
                             'aws': new_aws,
                             'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                         'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                             'ctc_state': new_ctc_states[k] if ctc_prefix_scorer is not None else None,
                             'ensmbl_cache': [[new_cache_e_l[j:j + 1] for new_cache_e_l in new_cache_e]
                                              for new_cache_e in ensmbl_new_cache] if cache_states else None,
                             'streamable': streamable_global,
                             'streaming_failed_point': streaming_failed_point,
                             'quantity_rate': quantity_rate})

                # Local pruning: keep the beam_width best new hypotheses
                new_hyps_sorted = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

                # Remove complete hypotheses
                # prune down to a list of beam_width hypotheses
                new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                    new_hyps_sorted, end_hyps, prune=True)
                hyps = new_hyps[:]
                if is_finish:
                    break

            # Global pruning (end of this utterance)
            if len(end_hyps) == 0:
                end_hyps = hyps[:]
            elif len(end_hyps) < nbest and nbest > 1:
                end_hyps.extend(hyps[:nbest - len(end_hyps)])

            # forward/backward second-pass LM rescoring
            end_hyps = helper.lm_rescoring(end_hyps, lm_second, lm_weight_second,
                                           length_norm=length_norm, tag='second')
            end_hyps = helper.lm_rescoring(end_hyps, lm_second_bwd, lm_weight_second_bwd,
                                           length_norm=length_norm, tag='second_bwd')

            # Sort by score
            end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

            # TODO: 
            for j in range(len(end_hyps[0]['aws'][1:])):
                tmp = end_hyps[0]['aws'][j + 1]
                end_hyps[0]['aws'][j + 1] = tmp.view(1, -1, tmp.size(-2), tmp.size(-1))

            # metrics for streaming inference
            self.streamable = end_hyps[0]['streamable']
            self.quantity_rate = end_hyps[0]['quantity_rate']
            self.last_success_frame_ratio = None

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for k in range(len(end_hyps)):
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(
                        end_hyps[k]['hyp'][1:][::-1] if self.bwd else end_hyps[k]['hyp'][1:]))
                    logger.info('num tokens (hyp): %d' % len(end_hyps[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                    logger.info('log prob (hyp, att): %.7f' %
                                (end_hyps[k]['score_att'] * (1 - ctc_weight)))
                    if ctc_prefix_scorer is not None:
                        logger.info('log prob (hyp, ctc): %.7f' %
                                    (end_hyps[k]['score_ctc'] * ctc_weight))
                    if lm is not None:
                        logger.info('log prob (hyp, first-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm'] * lm_weight))
                    if lm_second is not None:
                        logger.info('log prob (hyp, second-pass lm): %.7f' %
                                    (end_hyps[k]['score_lm_second'] * lm_weight_second))
                    if lm_second_bwd is not None:
                        logger.info('log prob (hyp, second-pass lm, reverse): %.7f' %
                                    (end_hyps[k]['score_lm_second_bwd'] * lm_weight_second_bwd))
                    if self.attn_type == 'mocha':
                        logger.info('streamable: %s' % end_hyps[k]['streamable'])
                        logger.info('streaming failed point: %d' %
                                    (end_hyps[k]['streaming_failed_point'] + 1))
                        logger.info('quantity rate [%%]: %.2f' %
                                    (end_hyps[k]['quantity_rate'] * 100))
                    logger.info('-' * 50)

                if self.attn_type == 'mocha' and end_hyps[0]['streaming_failed_point'] < 1000:
                    assert not self.streamable
                    aws_last_success = end_hyps[0]['aws'][1:][end_hyps[0]['streaming_failed_point'] - 1]
                    rightmost_frame = max(0, aws_last_success[0, :, 0].nonzero()[:, -1].max().item()) + 1
                    frame_ratio = rightmost_frame * 100 / xmax
                    self.last_success_frame_ratio = frame_ratio
                    logger.info('streaming last success frame ratio: %.2f' % frame_ratio)

            # N-best list
            if self.bwd:
                # Reverse the order
                nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:][::-1]) for n in range(nbest)]]
                aws += [[tensor2np(torch.cat(end_hyps[n]['aws'][1:][::-1], dim=2).squeeze(0)) for n in range(nbest)]]
            else:
                nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]
                aws += [[tensor2np(torch.cat(end_hyps[n]['aws'][1:], dim=2).squeeze(0)) for n in range(nbest)]]
            scores += [[end_hyps[n]['score_att'] for n in range(nbest)]]

            # Check <eos>
            eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.bwd:
                nbest_hyps_idx = [[nbest_hyps_idx[b][n][1:] if eos_flags[b][n]
                                   else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
                aws = [[aws[b][n][:, 1:] if eos_flags[b][n] else aws[b][n] for n in range(nbest)] for b in range(bs)]
            else:
                nbest_hyps_idx = [[nbest_hyps_idx[b][n][:-1] if eos_flags[b][n]
                                   else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
                aws = [[aws[b][n][:, :-1] if eos_flags[b][n] else aws[b][n] for n in range(nbest)] for b in range(bs)]

        # Store ASR/LM state
        if bs == 1:
            self.lmstate_final = end_hyps[0]['lmstate']

        return nbest_hyps_idx, aws, scores
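The per-step expansion and local pruning above reduce to a simple pattern: score every candidate extension, then keep the `beam_width` best. A stripped-down sketch with dummy hypotheses; here the length normalization is applied at sort time, whereas the loop above folds it into `total_score` beforehand.

beam_width, length_norm = 2, True
new_hyps = [
    {'hyp': [2, 5, 7], 'score': -3.2},   # hyp[0] is the <sos>/<eos> seed token
    {'hyp': [2, 5], 'score': -2.5},
    {'hyp': [2, 9, 4], 'score': -3.0},
]

def norm_score(beam):
    # divide by the number of emitted tokens when length_norm is enabled
    return beam['score'] / max(1, len(beam['hyp'][1:])) if length_norm else beam['score']

hyps = sorted(new_hyps, key=norm_score, reverse=True)[:beam_width]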
Example #19
    def greedy(self, eouts, elens, max_len_ratio, idx2token,
               exclude_eos=False, refs_id=None, utt_ids=None, speakers=None,
               cache_states=True):
        """Greedy decoding.

        Args:
            eouts (FloatTensor): `[B, T, enc_units]`
            elens (IntTensor): `[B]`
            max_len_ratio (float): maximum output length relative to the input length
            idx2token (): converter from index to token
            exclude_eos (bool): exclude <eos> from hypothesis
            refs_id (List): reference list
            utt_ids (List): utterance id list
            speakers (List): speaker list
            cache_states (bool): cache decoder states for fast decoding
        Returns:
            hyps (List): length `[B]`, each of which contains arrays of size `[L]`
            aws (List): length `[B]`, each of which contains arrays of size `[H * n_layers, L, T]`

        """
        bs, xmax = eouts.size()[:2]
        ys = eouts.new_zeros((bs, 1), dtype=torch.int64).fill_(self.eos)
        for layer in self.layers:
            layer.reset()

        cache = [None] * self.n_layers

        hyps_batch = []
        ylens = torch.zeros(bs).int()
        eos_flags = [False] * bs
        xy_aws_layers_steps = []
        ymax = math.ceil(xmax * max_len_ratio)
        for i in range(ymax):  # up to the maximum output length
            # lower-triangular mask that blocks future positions
            causal_mask = eouts.new_ones(i + 1, i + 1, dtype=torch.uint8)
            causal_mask = torch.tril(causal_mask).unsqueeze(0).repeat([bs, 1, 1])

            new_cache = [None] * self.n_layers
            xy_aws_layers = []
            out = self.pos_enc(self.embed_token_id(ys), scale=True)  # scaled + dropout
            for lth, layer in enumerate(self.layers): # decoder layer
                out = layer(out, causal_mask, eouts, None, cache=cache[lth])
                new_cache[lth] = out
                if layer.xy_aws is not None:
                    xy_aws_layers.append(layer.xy_aws[:, :, -1:])

            if cache_states:
                cache = new_cache[:]

            # Pick up 1-best
            y = self.output(self.norm_out(out))[:, -1:].argmax(-1)
            hyps_batch += [y]
            xy_aws_layers = torch.stack(xy_aws_layers, dim=2)  # `[B, H, n_layers, 1, T]`
            xy_aws_layers_steps.append(xy_aws_layers)

            # Count lengths of hypotheses
            for b in range(bs):
                if not eos_flags[b]:
                    if y[b].item() == self.eos:
                        eos_flags[b] = True
                    ylens[b] += 1  # include <eos>

            # Break if <eos> has been output for all utterances in the mini-batch
            if sum(eos_flags) == bs:
                break
            if i == ymax - 1:
                break

            ys = torch.cat([ys, y], dim=-1)

        # Concatenate in L dimension
        hyps_batch = tensor2np(torch.cat(hyps_batch, dim=1))
        xy_aws_layers_steps = torch.cat(xy_aws_layers_steps, dim=-2)  # `[B, H, n_layers, L, T]`
        xy_aws_layers_steps = xy_aws_layers_steps.reshape(bs, self.n_heads * self.n_layers, ys.size(1), xmax)
        xy_aws = tensor2np(xy_aws_layers_steps)

        # Truncate by the first <eos> (<sos> in case of the backward decoder)
        if self.bwd:
            # Reverse the order
            hyps = [hyps_batch[b, :ylens[b]][::-1] for b in range(bs)]
            aws = [xy_aws[b, :, :ylens[b], :][:, ::-1] for b in range(bs)]
        else:
            hyps = [hyps_batch[b, :ylens[b]] for b in range(bs)]
            aws = [xy_aws[b, :, :ylens[b], :] for b in range(bs)]

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.bwd:
                hyps = [hyps[b][1:] if eos_flags[b] else hyps[b] for b in range(bs)]
                aws = [aws[b][:, 1:] if eos_flags[b] else aws[b] for b in range(bs)]
            else:
                hyps = [hyps[b][:-1] if eos_flags[b] else hyps[b] for b in range(bs)]
                aws = [aws[b][:, :-1] if eos_flags[b] else aws[b] for b in range(bs)]

        if idx2token is not None: # idx -> token
            for b in range(bs):
                if utt_ids is not None:
                    logger.debug('Utt-id: %s' % utt_ids[b])
                if refs_id is not None and self.vocab == idx2token.vocab:
                    logger.debug('Ref: %s' % idx2token(refs_id[b]))
                if self.bwd:
                    logger.debug('Hyp: %s' % idx2token(hyps[b][::-1]))
                else:
                    logger.debug('Hyp: %s' % idx2token(hyps[b]))
                logger.info('=' * 200)
                # NOTE: do not show with logger.info here

        return hyps, aws
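The control flow of the greedy loop can be seen in isolation with a stub in place of the decoder stack; `step`, the vocabulary size, and the loop bound below are illustrative assumptions.

import torch

vocab, eos, bs, ymax = 10, 2, 2, 20

def step(ys):
    """Stub for one decoder pass; returns logits for the last position."""
    return torch.randn(ys.size(0), vocab)

ys = torch.full((bs, 1), eos, dtype=torch.int64)  # seed with <sos> (= <eos> id)
eos_flags = [False] * bs
for i in range(ymax):
    y = step(ys).argmax(-1, keepdim=True)         # pick the 1-best token, `[B, 1]`
    for b in range(bs):
        if not eos_flags[b] and y[b].item() == eos:
            eos_flags[b] = True
    if all(eos_flags):
        break
    ys = torch.cat([ys, y], dim=-1)               # feed the hypothesis back in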
Example #20
    def forward_att(self, eouts, elens, ys, trigger_points=None):
        """Compute XE loss for the Transformer decoder.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (List): length `[B]`, each of which contains a list of size `[L]`
            trigger_points (IntTensor): `[B, L]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity
            losses_auxiliary (dict):

        """
        losses_auxiliary = {}

        # Append <sos> and <eos>
        ys_in, ys_out, ylens = append_sos_eos(ys, self.eos, self.eos, self.pad, self.device, self.bwd)
        if not self.training:
            self.data_dict['elens'] = tensor2np(elens)
            self.data_dict['ylens'] = tensor2np(ylens)
            self.data_dict['ys'] = tensor2np(ys_out)

        # Create target self-attention mask
        bs, ymax = ys_in.size()[:2]
        tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
        causal_mask = tgt_mask.new_ones(ymax, ymax, dtype=tgt_mask.dtype)
        causal_mask = torch.tril(causal_mask).unsqueeze(0)
        tgt_mask = tgt_mask & causal_mask  # `[B, L (query), L (key)]`

        # Create source-target mask
        src_mask = make_pad_mask(elens.to(self.device)).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

        # Create attention padding mask for quantity loss
        if self.attn_type == 'mocha':
            attn_mask = (ys_out != self.pad).unsqueeze(1).unsqueeze(3)  # `[B, 1, L, 1]`
        else:
            attn_mask = None

        # external LM integration
        lmout = None
        if self.lm is not None:
            self.lm.eval()
            with torch.no_grad():
                lmout, lmstate, _ = self.lm.predict(ys_in, None)
            lmout = self.lm_output_proj(lmout)

        out = self.pos_enc(self.embed_token_id(ys_in), scale=True)  # scaled + dropout

        xy_aws_layers = []
        xy_aws = None
        for lth, layer in enumerate(self.layers):
            out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout)
            # Attention padding
            xy_aws = layer.xy_aws
            if xy_aws is not None and self.attn_type == 'mocha':
                xy_aws_masked = xy_aws.masked_fill_(attn_mask.expand_as(xy_aws) == 0, 0)
                # NOTE: attention padding is quite effective for quantity loss
                xy_aws_layers.append(xy_aws_masked.clone())
            if not self.training:
                self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws)
                self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws)
                self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta)
                self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose)
                self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm)
        logits = self.output(self.norm_out(out))

        # Compute XE loss (+ label smoothing)
        loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)

        # Quantity loss
        losses_auxiliary['loss_quantity'] = 0.
        if self.attn_type == 'mocha':
            # Average over all heads across all layers
            n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
            # NOTE: count <eos> tokens
            n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                                 for aws in xy_aws_layers])  # `[B]`
            n_tokens_pred /= len(xy_aws_layers)
            losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out, self.pad)

        return loss, acc, ppl, losses_auxiliary
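`cross_entropy_lsm` is not shown in this snippet; a plausible sketch of cross entropy with label smoothing over non-padding positions, returning loss and perplexity, is below. The function name and the `pad` convention are assumptions for illustration.

import torch
import torch.nn.functional as F

def cross_entropy_lsm_sketch(logits, ys_out, lsm_prob=0.1, pad=-1):
    """XE with label smoothing over non-padding tokens (sketch).

    logits: `[B, L, vocab]`; ys_out: `[B, L]` with `pad` marking ignored slots.
    """
    vocab = logits.size(-1)
    log_probs = F.log_softmax(logits, dim=-1).view(-1, vocab)
    ys = ys_out.view(-1)
    mask = ys != pad
    nll = -log_probs[torch.arange(ys.size(0)), ys.clamp(min=0)]  # true-class term
    smooth = -log_probs.mean(-1)                                 # uniform-smoothing term
    loss = ((1 - lsm_prob) * nll + lsm_prob * smooth)[mask].mean()
    ppl = torch.exp(nll[mask].mean()).item()
    return loss, ppl

print(cross_entropy_lsm_sketch(torch.randn(2, 4, 10), torch.randint(0, 10, (2, 4))))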
Example #21
    def beam_search(self, eouts, elens, params, idx2token=None,
                    lm=None, lm_second=None, lm_bwd=None, ctc_log_probs=None,
                    nbest=1, exclude_eos=False,
                    refs_id=None, utt_ids=None, speakers=None,
                    ensmbl_eouts=None, ensmbl_elens=None, ensmbl_decs=[], cache_states=True):
        """Beam search decoding.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            params (dict): hyperparameters for decoding
            idx2token (): converter from index to token
            lm: first-pass LM
            lm_second: second-pass LM
            lm_bwd: first/second-pass backward LM
            ctc_log_probs (FloatTensor):
            nbest (int):
            exclude_eos (bool): exclude <eos> from hypothesis
            refs_id (list): reference list
            utt_ids (list): utterance id list
            speakers (list): speaker list
            ensmbl_eouts (list): list of FloatTensor
            ensmbl_elens (list): list of lists
            ensmbl_decs (list): list of torch.nn.Module
            cache_states (bool): cache decoder states for fast decoding
        Returns:
            nbest_hyps_idx (list): length `B`, each of which contains list of N hypotheses
            aws (list): length `B`, each of which contains arrays of size `[H, L, T]`
            scores (list):

        """
        bs, xmax, _ = eouts.size()
        n_models = len(ensmbl_decs) + 1

        beam_width = params['recog_beam_width']
        assert 1 <= nbest <= beam_width
        ctc_weight = params['recog_ctc_weight']
        max_len_ratio = params['recog_max_len_ratio']
        min_len_ratio = params['recog_min_len_ratio']
        lp_weight = params['recog_length_penalty']
        length_norm = params['recog_length_norm']
        lm_weight = params['recog_lm_weight']
        lm_weight_second = params['recog_lm_second_weight']
        lm_weight_bwd = params['recog_lm_bwd_weight']
        eos_threshold = params['recog_eos_threshold']
        lm_state_carry_over = params['recog_lm_state_carry_over']
        softmax_smoothing = params['recog_softmax_smoothing']
        eps_wait = params['recog_mma_delay_threshold']

        if lm is not None:
            assert lm_weight > 0
            lm.eval()
        if lm_second is not None:
            assert lm_weight_second > 0
            lm_second.eval()
        if lm_bwd is not None:
            assert lm_weight_bwd > 0
            lm_bwd.eval()

        if ctc_log_probs is not None:
            assert ctc_weight > 0
            ctc_log_probs = tensor2np(ctc_log_probs)

        nbest_hyps_idx, aws, scores = [], [], []
        eos_flags = []
        for b in range(bs):
            # Initialization per utterance
            lmstate = None
            ys = eouts.new_zeros(1, 1).fill_(self.eos).long()

            # For joint CTC-Attention decoding
            ctc_prefix_scorer = None
            if ctc_log_probs is not None:
                if self.bwd:
                    ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b][::-1], self.blank, self.eos)
                else:
                    ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b], self.blank, self.eos)

            if speakers is not None:
                if speakers[b] == self.prev_spk:
                    if lm_state_carry_over and isinstance(lm, RNNLM):
                        lmstate = self.lmstate_final
                self.prev_spk = speakers[b]

            helper = BeamSearch(beam_width, self.eos, ctc_weight, self.device_id)

            end_hyps = []
            ymax = int(math.floor(elens[b] * max_len_ratio)) + 1
            hyps = [{'hyp': [self.eos],
                     'ys': ys,
                     'cache': None,
                     'score': 0.,
                     'score_attn': 0.,
                     'score_ctc': 0.,
                     'score_lm': 0.,
                     'aws': [None],
                     'lmstate': lmstate,
                     'ensmbl_aws':[[None]] * (n_models - 1),
                     'ctc_state': ctc_prefix_scorer.initial_state() if ctc_prefix_scorer is not None else None,
                     'streamable': True,
                     'streaming_failed_point': 1000}]
            streamable_global = True
            for t in range(ymax):
                # batchfy all hypotheses for batch decoding
                cache = [None] * self.n_layers
                if cache_states and t > 0:
                    for lth in range(self.n_layers):
                        cache[lth] = torch.cat([beam['cache'][lth] for beam in hyps], dim=0)
                ys = eouts.new_zeros(len(hyps), t + 1).long()
                for j, beam in enumerate(hyps):
                    ys[j, :] = beam['ys']
                if t > 0:
                    xy_aws_prev = torch.cat([beam['aws'][-1] for beam in hyps], dim=0)  # `[B, n_layers, H_ma, 1, klen]`
                else:
                    xy_aws_prev = None

                # Update LM states for shallow fusion
                lmstate, scores_lm = None, None
                if lm is not None:
                    if hyps[0]['lmstate'] is not None:
                        lm_hxs = torch.cat([beam['lmstate']['hxs'] for beam in hyps], dim=1)
                        lm_cxs = torch.cat([beam['lmstate']['cxs'] for beam in hyps], dim=1)
                        lmstate = {'hxs': lm_hxs, 'cxs': lm_cxs}
                    y = ys[:, -1:].clone()  # NOTE: this is important
                    _, lmstate, scores_lm = lm.predict(y, lmstate)

                # for the main model
                causal_mask = eouts.new_ones(t + 1, t + 1).byte()
                causal_mask = torch.tril(causal_mask, out=causal_mask).unsqueeze(0).repeat([ys.size(0), 1, 1])

                out = self.pos_enc(self.embed(ys))  # scaled

                mlen = 0  # TODO: fix later
                if self.memory_transformer:
                    # NOTE: TransformerXL does not use positional encoding in the token embedding
                    mems = self.init_memory()
                    # adopt zero-centered offset
                    pos_idxs = torch.arange(mlen - 1, -(t + 1) - 1, -1.0, dtype=torch.float)
                    pos_embs = self.pos_emb(pos_idxs, self.device_id)
                    out = self.dropout_emb(out)
                    hidden_states = [out]

                n_heads_total = 0
                eouts_b = eouts[b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
                new_cache = [None] * self.n_layers
                xy_aws_all_layers = []
                lth_s = self.mocha_first_layer - 1
                for lth, layer in enumerate(self.layers):
                    if self.memory_transformer:
                        out = layer(
                            out, causal_mask, eouts_b, None,
                            cache=cache[lth],
                            pos_embs=pos_embs, memory=mems[lth], u=self.u, v=self.v)
                        hidden_states.append(out)
                    else:
                        out = layer(
                            out, causal_mask, eouts_b, None,
                            cache=cache[lth],
                            xy_aws_prev=xy_aws_prev[:, lth - lth_s] if lth >= lth_s and t > 0 else None,
                            eps_wait=eps_wait)

                    new_cache[lth] = out
                    if layer.xy_aws is not None:
                        xy_aws_all_layers.append(layer.xy_aws)
                        n_heads_total += layer.xy_aws.size(1)  # accumulate heads across MoChA layers
                logits = self.output(self.norm_out(out))
                probs = torch.softmax(logits[:, -1] * softmax_smoothing, dim=1)
                xy_aws_all_layers = torch.stack(xy_aws_all_layers, dim=1)  # `[B, H, n_layers, L, T]`

                # for the ensemble
                ensmbl_new_cache = []
                if n_models > 1:
                    # Ensemble initialization
                    # ensmbl_cache = []
                    # cache_e = [None] * self.n_layers
                    # if cache_states and t > 0:
                    #     for lth in range(self.n_layers):
                    #         cache_e[lth] = torch.cat([beam['ensmbl_cache'][lth] for beam in hyps], dim=0)
                    for i_e, dec in enumerate(ensmbl_decs):
                        out_e = dec.pos_enc(dec.embed(ys))  # scaled
                        eouts_e = ensmbl_eouts[i_e][b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
                        new_cache_e = [None] * dec.n_layers
                        for lth in range(dec.n_layers):
                            out_e, _, xy_aws_e, _, _ = dec.layers[lth](out_e, causal_mask, eouts_e, None,
                                                                       cache=cache[lth])
                            new_cache_e[lth] = out_e
                        ensmbl_new_cache.append(new_cache_e)
                        logits_e = dec.output(dec.norm_out(out_e))
                        probs += torch.softmax(logits_e[:, -1] * softmax_smoothing, dim=1)
                        # NOTE: sum in the probability scale (not log-scale)

                # Ensemble in log-scale
                scores_attn = torch.log(probs) / n_models

                new_hyps = []
                for j, beam in enumerate(hyps):
                    # Attention scores
                    total_scores_attn = beam['score_attn'] + scores_attn[j:j + 1]
                    total_scores = total_scores_attn * (1 - ctc_weight)

                    # Add LM score <before> top-K selection
                    if lm is not None:
                        total_scores_lm = beam['score_lm'] + scores_lm[j:j + 1, -1]
                        total_scores += total_scores_lm * lm_weight
                    else:
                        total_scores_lm = eouts.new_zeros(1, self.vocab)

                    total_scores_topk, topk_ids = torch.topk(
                        total_scores, k=beam_width, dim=1, largest=True, sorted=True)

                    # Add length penalty
                    if lp_weight > 0:
                        total_scores_topk += (len(beam['hyp'][1:]) + 1) * lp_weight

                    # Add CTC score
                    new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                        beam['hyp'], topk_ids, beam['ctc_state'],
                        total_scores_topk, ctc_prefix_scorer)

                    new_aws = beam['aws'] + [xy_aws_all_layers[j:j + 1, :, :, -1:]]
                    aws_j = torch.cat(new_aws[1:], dim=3)  # `[1, H, n_layers, L, T]`
                    streaming_failed_point = beam['streaming_failed_point']

                    # forward direction
                    length_norm_factor = len(beam['hyp'][1:]) + 1 if length_norm else 1
                    total_scores_topk /= length_norm_factor  # NOTE: apply length normalization once per step
                    for k in range(beam_width):
                        idx = topk_ids[0, k].item()

                        if idx == self.eos:
                            # Exclude short hypotheses
                            if len(beam['hyp']) - 1 < elens[b] * min_len_ratio:
                                continue
                            # EOS threshold
                            max_score_no_eos = scores_attn[j, :idx].max(0)[0].item()
                            max_score_no_eos = max(max_score_no_eos, scores_attn[j, idx + 1:].max(0)[0].item())
                            if scores_attn[j, idx].item() <= eos_threshold * max_score_no_eos:
                                continue

                        quantity_rate = 1.
                        if 'mocha' in self.attn_type:
                            n_tokens_hyp_k = t + 1
                            n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                            quantity_diff = n_tokens_hyp_k * n_heads_total - n_quantity_k

                            if quantity_diff != 0:
                                if idx == self.eos:
                                    n_tokens_hyp_k -= 1  # NOTE: do not count <eos> for streamability
                                    n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                                else:
                                    streamable_global = False
                                if n_tokens_hyp_k * n_heads_total == 0:
                                    quantity_rate = 0
                                else:
                                    quantity_rate = n_quantity_k / (n_tokens_hyp_k * n_heads_total)

                            if beam['streamable'] and not streamable_global:
                                streaming_failed_point = t

                        new_hyps.append(
                            {'hyp': beam['hyp'] + [idx],
                             'ys': torch.cat([beam['ys'], eouts.new_zeros(1, 1).fill_(idx).long()], dim=-1),
                             'cache': [new_cache_l[j:j + 1] for new_cache_l in new_cache] if cache_states else cache,
                             'score': total_scores_topk[0, k].item(),
                             'score_attn': total_scores_attn[0, idx].item(),
                             'score_ctc': total_scores_ctc[k].item(),
                             'score_lm': total_scores_lm[0, idx].item(),
                             'aws': new_aws,
                             'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                         'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                             'ctc_state': new_ctc_states[k] if ctc_prefix_scorer is not None else None,
                             'ensmbl_cache': ensmbl_new_cache,
                             'streamable': streamable_global,
                             'streaming_failed_point': streaming_failed_point,
                             'quantity_rate': quantity_rate})

                # Local pruning
                new_hyps_sorted = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

                # Remove complete hypotheses
                new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                    new_hyps_sorted, end_hyps, prune=True)
                hyps = new_hyps[:]
                if is_finish:
                    break

            # Global pruning
            if len(end_hyps) == 0:
                end_hyps = hyps[:]
            elif len(end_hyps) < nbest and nbest > 1:
                end_hyps.extend(hyps[:nbest - len(end_hyps)])

            # forward second path LM rescoring
            if lm_second is not None:
                self.lm_rescoring(end_hyps, lm_second, lm_weight_second, tag='second')

            # backward second path LM rescoring
            if lm_bwd is not None and lm_weight_bwd > 0:
                self.lm_rescoring(end_hyps, lm_bwd, lm_weight_bwd, tag='second_bwd')

            # Sort by score
            end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

            for j in range(len(end_hyps[0]['aws'][1:])):
                tmp = end_hyps[0]['aws'][j + 1]
                end_hyps[0]['aws'][j + 1] = tmp.view(1, -1, tmp.size(-2), tmp.size(-1))

            # metrics for streaming inference
            self.streamable = end_hyps[0]['streamable']
            self.quantity_rate = end_hyps[0]['quantity_rate']
            self.last_success_frame_ratio = None

            if idx2token is not None:
                if utt_ids is not None:
                    logger.info('Utt-id: %s' % utt_ids[b])
                assert self.vocab == idx2token.vocab
                logger.info('=' * 200)
                for k in range(len(end_hyps)):
                    if refs_id is not None:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(
                        end_hyps[k]['hyp'][1:][::-1] if self.bwd else end_hyps[k]['hyp'][1:]))
                    logger.info('num tokens (hyp): %d' % len(end_hyps[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                    logger.info('log prob (hyp, att): %.7f' % (end_hyps[k]['score_attn'] * (1 - ctc_weight)))
                    if ctc_prefix_scorer is not None:
                        logger.info('log prob (hyp, ctc): %.7f' % (end_hyps[k]['score_ctc'] * ctc_weight))
                    if lm is not None:
                        logger.info('log prob (hyp, first-path lm): %.7f' % (end_hyps[k]['score_lm'] * lm_weight))
                    if lm_second is not None:
                        logger.info('log prob (hyp, second-path lm): %.7f' %
                                    (end_hyps[k]['score_lm_second'] * lm_weight_second))
                    if lm_bwd is not None:
                        logger.info('log prob (hyp, second-path lm-bwd): %.7f' %
                                    (end_hyps[k]['score_lm_second_bwd'] * lm_weight_bwd))
                    if 'mocha' in self.attn_type:
                        logger.info('streamable: %s' % end_hyps[k]['streamable'])
                        logger.info('streaming failed point: %d' % (end_hyps[k]['streaming_failed_point'] + 1))
                        logger.info('quantity rate [%%]: %.2f' % (end_hyps[k]['quantity_rate'] * 100))
                    logger.info('-' * 50)

                if 'mocha' in self.attn_type and end_hyps[0]['streaming_failed_point'] < 1000:
                    assert not self.streamable
                    aws_last_success = end_hyps[0]['aws'][1:][end_hyps[0]['streaming_failed_point'] - 1]
                    rightmost_frame = max(0, aws_last_success[0, :, 0].nonzero()[:, -1].max().item()) + 1
                    frame_ratio = rightmost_frame * 100 / xmax
                    self.last_success_frame_ratio = frame_ratio
                    logger.info('streaming last success frame ratio: %.2f' % frame_ratio)

            # N-best list
            if self.bwd:
                # Reverse the order
                nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:][::-1]) for n in range(nbest)]]
                aws += [tensor2np(torch.cat(end_hyps[0]['aws'][1:][::-1], dim=2).squeeze(0))]
            else:
                nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]
                aws += [tensor2np(torch.cat(end_hyps[0]['aws'][1:], dim=2).squeeze(0))]
            scores += [[end_hyps[n]['score_attn'] for n in range(nbest)]]

            # Check <eos>
            eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.bwd:
                nbest_hyps_idx = [[nbest_hyps_idx[b][n][1:] if eos_flags[b][n]
                                   else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
            else:
                nbest_hyps_idx = [[nbest_hyps_idx[b][n][:-1] if eos_flags[b][n]
                                   else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]

        # Store ASR/LM state
        if len(end_hyps) > 0:
            self.lmstate_final = end_hyps[0]['lmstate']

        return nbest_hyps_idx, aws, scores
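
For reference, a minimal sketch (not from the original code) of how the attention, CTC, and LM scores above are interpolated before top-K selection; the weights and per-token log probabilities below are invented for illustration.

import torch

# Hedged sketch: joint attention/CTC/LM scoring as in the beam search above,
# with invented weights and toy per-token log probabilities.
ctc_weight, lm_weight = 0.3, 0.5
attn_lp = torch.log(torch.tensor([0.6, 0.3, 0.1]))  # attention decoder
ctc_lp = torch.log(torch.tensor([0.5, 0.4, 0.1]))   # CTC prefix scorer
lm_lp = torch.log(torch.tensor([0.2, 0.7, 0.1]))    # shallow-fusion LM

total = (1 - ctc_weight) * attn_lp + ctc_weight * ctc_lp + lm_weight * lm_lp
scores_topk, topk_ids = torch.topk(total, k=2, largest=True, sorted=True)
print(topk_ids.tolist(), scores_topk.tolist())
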
Example #22
    def greedy(self, eouts, elens, max_len_ratio, idx2token,
               exclude_eos=False, refs_id=None, utt_ids=None, speakers=None,
               cache_states=True):
        """Greedy decoding.

        Args:
            eouts (FloatTensor): `[B, T, enc_units]`
            elens (IntTensor): `[B]`
            max_len_ratio (float): maximum length ratio of output tokens to encoder frames
            idx2token (): converter from index to token
            exclude_eos (bool): exclude <eos> from hypothesis
            refs_id (list): reference list
            utt_ids (list): utterance id list
            speakers (list): speaker list
            cache_states (bool): cache decoder states for faster incremental decoding
        Returns:
            hyps (list): length `B`, each of which contains arrays of size `[L]`
            aw (list): length `B`, each of which contains arrays of size `[L, T]`

        """
        bs, xtime = eouts.size()[:2]
        ys = eouts.new_zeros(bs, 1).fill_(self.eos).long()

        cache = [None] * self.n_layers

        hyps_batch = []
        ylens = torch.zeros(bs).int()
        eos_flags = [False] * bs
        ymax = int(math.floor(xtime * max_len_ratio)) + 1
        for t in range(ymax):
            causal_mask = eouts.new_ones(t + 1, t + 1).byte()
            causal_mask = torch.tril(causal_mask, out=causal_mask).unsqueeze(0)

            new_cache = [None] * self.n_layers
            out = self.pos_enc(self.embed(ys))  # scaled
            for lth, layer in enumerate(self.layers):
                out = layer(out, causal_mask, eouts, None, cache=cache[lth])
                new_cache[lth] = out

            if cache_states:
                cache = new_cache[:]

            # Pick up 1-best
            y = self.output(self.norm_out(out))[:, -1:].argmax(-1)
            hyps_batch += [y]

            # Count lengths of hypotheses
            for b in range(bs):
                if not eos_flags[b]:
                    if y[b].item() == self.eos:
                        eos_flags[b] = True
                    ylens[b] += 1  # include <eos>

            # Break if <eos> is output for all utterances in the mini-batch
            if sum(eos_flags) == bs:
                break
            if t == ymax - 1:
                break

            ys = torch.cat([ys, y], dim=-1)

        # Concatenate in L dimension
        hyps_batch = tensor2np(torch.cat(hyps_batch, dim=1))
        xy_aws = tensor2np(layer.xy_aws.transpose(1, 2).transpose(2, 3))

        # Truncate by the first <eos> (<sos> in case of the backward decoder)
        if self.bwd:
            # Reverse the order
            hyps = [hyps_batch[b, :ylens[b]][::-1] for b in range(bs)]
            aws = [xy_aws[b, :, :ylens[b]][::-1] for b in range(bs)]
        else:
            hyps = [hyps_batch[b, :ylens[b]] for b in range(bs)]
            aws = [xy_aws[b, :, :ylens[b]] for b in range(bs)]

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.bwd:
                hyps = [hyps[b][1:] if eos_flags[b] else hyps[b] for b in range(bs)]
            else:
                hyps = [hyps[b][:-1] if eos_flags[b] else hyps[b] for b in range(bs)]

        for b in range(bs):
            if utt_ids is not None:
                logger.debug('Utt-id: %s' % utt_ids[b])
            if refs_id is not None and self.vocab == idx2token.vocab:
                logger.debug('Ref: %s' % idx2token(refs_id[b]))
            if self.bwd:
                logger.debug('Hyp: %s' % idx2token(hyps[b][::-1]))
            else:
                logger.debug('Hyp: %s' % idx2token(hyps[b]))

        return hyps, aws
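
The greedy loop above reduces to the following toy sketch; random logits stand in for the Transformer decoder, so only the <eos> bookkeeping and truncation logic are meaningful.

import torch

# Hedged sketch of the greedy loop above: random logits replace
# self.output(self.norm_out(out)); the <eos> bookkeeping mirrors the snippet.
eos, vocab, bs, ymax = 0, 5, 2, 6
torch.manual_seed(0)
ys = torch.full((bs, 1), eos, dtype=torch.long)
hyps_batch, ylens, eos_flags = [], [0] * bs, [False] * bs
for t in range(ymax):
    logits = torch.randn(bs, vocab)      # stand-in for the decoder output
    y = logits.argmax(-1, keepdim=True)  # pick up 1-best
    hyps_batch.append(y)
    for b in range(bs):
        if not eos_flags[b]:
            if y[b].item() == eos:
                eos_flags[b] = True
            ylens[b] += 1                # include <eos>
    if all(eos_flags):
        break
    ys = torch.cat([ys, y], dim=-1)
hyps_batch = torch.cat(hyps_batch, dim=1)
print([hyps_batch[b, :ylens[b]].tolist() for b in range(bs)])
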
Example #23
    def forward_att(self, eouts, elens, ys,
                    return_logits=False, teacher_logits=None, trigger_points=None):
        """Compute XE loss for the Transformer decoder.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): length `B`, each of which contains a list of size `[L]`
            return_logits (bool): return logits for knowledge distillation
            teacher_logits (FloatTensor): `[B, L, vocab]`
            trigger_points (IntTensor): `[B, T]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity
            losses_auxiliary (dict): auxiliary losses (e.g., loss_quantity)

        """
        # Append <sos> and <eos>
        ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.eos, self.pad, self.bwd)
        if not self.training:
            self.data_dict['elens'] = tensor2np(elens)
            self.data_dict['ylens'] = tensor2np(ylens)
            self.data_dict['ys'] = tensor2np(ys_out)

        # Create target self-attention mask
        xmax = eouts.size(1)
        bs, ymax = ys_in.size()[:2]
        mlen = 0
        tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
        causal_mask = tgt_mask.new_ones(ymax, ymax).byte()
        causal_mask = torch.tril(causal_mask, diagonal=0 + mlen, out=causal_mask).unsqueeze(0)
        tgt_mask = tgt_mask & causal_mask  # `[B, L (query), L (key)]`

        # Create source-target mask
        src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

        # external LM integration
        lmout = None
        if self.lm is not None:
            self.lm.eval()
            with torch.no_grad():
                lmout, lmstate, _ = self.lm.predict(ys_in, None)
            lmout = self.lm_output_proj(lmout)

        out = self.pos_enc(self.embed(ys_in))  # scaled

        mems = self.init_memory()
        pos_embs = None
        if self.memory_transformer:
            out = self.dropout_emb(out)
            # NOTE: TransformerXL does not use positional encoding in the token embedding
            # adopt zero-centered offset
            pos_idxs = torch.arange(mlen - 1, -ymax - 1, -1.0, dtype=torch.float)
            pos_embs = self.pos_emb(pos_idxs, self.device_id)

        hidden_states = [out]
        xy_aws_layers = []
        for lth, (mem, layer) in enumerate(zip(mems, self.layers)):
            out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout,
                        pos_embs=pos_embs, memory=mem, u=self.u, v=self.v)
            if lth < self.n_layers - 1:
                hidden_states.append(out)
                # NOTE: outputs from the last layer are not used for memory
            # Attention padding
            xy_aws = layer.xy_aws
            if xy_aws is not None and 'mocha' in self.attn_type:
                tgt_mask_v2 = (ys_out != self.pad).unsqueeze(1).unsqueeze(3)  # `[B, 1, L, 1]`
                xy_aws = xy_aws.masked_fill_(tgt_mask_v2.repeat([1, xy_aws.size(1), 1, xmax]) == 0, 0)
                # NOTE: attention padding is quite effective for quantity loss
                xy_aws_layers.append(xy_aws.clone())
            if not self.training:
                if layer.yy_aws is not None:
                    self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws)
                if layer.xy_aws is not None:
                    self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws)
                if layer.xy_aws_beta is not None:
                    self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta)
                if layer.xy_aws_p_choose is not None:
                    self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose)
                if layer.yy_aws_lm is not None:
                    self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm)
        logits = self.output(self.norm_out(out))

        # for knowledge distillation
        if return_logits:
            return logits

        # Compute XE loss (+ label smoothing)
        loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)
        losses_auxiliary = {}

        # Quantity loss
        losses_auxiliary['loss_quantity'] = 0.
        if 'mocha' in self.attn_type:
            # Average over all heads across all layers
            n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
            # NOTE: count <eos> tokens
            n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                                 for aws in xy_aws_layers])  # `[B]`
            n_tokens_pred /= len(xy_aws_layers)
            losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out, self.pad)

        return loss, acc, ppl, losses_auxiliary
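
The quantity loss above can be checked in isolation. A toy computation on random weights, assuming the `[B, H, L, T]` attention layout used in the snippet:

import torch

# Hedged sketch of the MoChA quantity loss on random attention weights;
# two "layers" are averaged, matching the loop over xy_aws_layers above.
B, H, L, T = 2, 4, 3, 10
xy_aws_layers = [torch.rand(B, H, L, T) for _ in range(2)]
n_tokens_ref = torch.full((B,), float(L))      # reference token counts, incl. <eos>

n_tokens_pred = sum(torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                    for aws in xy_aws_layers)  # average over heads per layer
n_tokens_pred /= len(xy_aws_layers)            # average over layers
loss_quantity = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))
print(loss_quantity.item())
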
Example #24
    def greedy(self, eouts, elens, max_len_ratio, exclude_eos=False):
        """Greedy decoding in the inference stage.

        Args:
            eouts (FloatTensor): `[B, T, enc_units]`
            elens (list): A list of length `[B]`
            max_len_ratio (float): maximum length ratio of output tokens to encoder frames
            exclude_eos (bool): exclude <eos> from hypothesis
        Returns:
            best_hyps (list): A list of length `[B]`, which contains arrays of size `[L]`
            aw (list): A list of length `[B]`, which contains arrays of size `[L, T]`

        """
        bs, max_xlen, d_model = eouts.size()

        # Start from <sos> (<eos> in case of the backward decoder)
        ys = eouts.new_zeros(bs, 1).fill_(self.eos).long()

        yy_mask = None

        best_hyps_tmp = []
        ylens = np.zeros((bs, ), dtype=np.int32)
        yy_aws_tmp = [None] * bs
        xy_aws_tmp = [None] * bs
        eos_flags = [False] * bs
        for t in range(int(np.floor(max_xlen * max_len_ratio)) + 1):
            # Make source-target attention mask
            yx_mask = eouts.new_ones(bs, t + 1, max_xlen)
            for b in range(bs):
                if elens[b] < max_xlen:
                    yx_mask[b, :, elens[b]:] = 0

            # Add positional embedding
            out = self.embed(ys) * (self.d_model**0.5)
            if self.pe_type:
                out = self.pos_emb_out(out)

            for l in range(self.n_layers):
                out, yy_aw, xy_aw = self.layers[l](eouts, out, yx_mask,
                                                   yy_mask)
                # xy_aw: `[B, head, T, L]`
            out = self.layer_norm_top(out)
            logits_t = self.output(out)

            # Pick up 1-best
            y = logits_t.detach().argmax(-1)[:, -1:]
            best_hyps_tmp += [y]

            # Count lengths of hypotheses
            for b in range(bs):
                if not eos_flags[b]:
                    if y[b].item() == self.eos:
                        eos_flags[b] = True
                        yy_aws_tmp[b] = yy_aw[b:b + 1]  # TODO: fix this
                        xy_aws_tmp[b] = xy_aw[b:b + 1]
                    ylens[b] += 1
                    # NOTE: include <eos>

            # Break if <eos> is output for all utterances in the mini-batch
            if sum(eos_flags) == bs:
                break

            ys = torch.cat([ys, y], dim=-1)

        # Concatenate in L dimension
        best_hyps_tmp = torch.cat(best_hyps_tmp, dim=1)
        # xy_aws_tmp = torch.stack(xy_aws_tmp, dim=0)

        # Convert to numpy
        best_hyps_tmp = tensor2np(best_hyps_tmp)
        # xy_aws_tmp = tensor2np(xy_aws_tmp)

        # if self.score.attn_nheads > 1:
        #     xy_aws_tmp = xy_aws_tmp[:, :, :, 0]
        #     # TODO(hirofumi): fix for MHA

        # Truncate by the first <eos> (<sos> in case of the backward decoder)
        if self.backward:
            # Reverse the order
            best_hyps = [best_hyps_tmp[b, :ylens[b]][::-1] for b in range(bs)]
            # aws = [xy_aws_tmp[b, :ylens[b]][::-1] for b in range(bs)]
        else:
            best_hyps = [best_hyps_tmp[b, :ylens[b]] for b in range(bs)]
            # aws = [xy_aws_tmp[b, :ylens[b]] for b in range(bs)]

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.backward:
                best_hyps = [
                    best_hyps[b][1:] if eos_flags[b] else best_hyps[b]
                    for b in range(bs)
                ]
            else:
                best_hyps = [
                    best_hyps[b][:-1] if eos_flags[b] else best_hyps[b]
                    for b in range(bs)
                ]

        # return best_hyps, aws
        return best_hyps, None
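
The yx_mask construction above is easy to sanity-check on its own. A minimal sketch with toy lengths (not part of the original code):

import torch

# Hedged sketch of the source-target mask above: frames beyond each
# utterance length are zeroed so the decoder cannot attend to padding.
bs, t, max_xlen = 2, 3, 6
elens = [6, 4]
yx_mask = torch.ones(bs, t + 1, max_xlen)
for b in range(bs):
    if elens[b] < max_xlen:
        yx_mask[b, :, elens[b]:] = 0
print(yx_mask[1, 0])  # tensor([1., 1., 1., 1., 0., 0.])
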
Example #25
    def forward(self,
                xs,
                xlens,
                task,
                streaming=False,
                lookback=False,
                lookahead=False):
        """Forward pass.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (IntTensor): `[B]` (on CPU)
            task (str): ys/ys_sub1/ys_sub2
            streaming (bool): streaming encoding
            lookback (bool): truncate leftmost frames for lookback in CNN context
            lookahead (bool): truncate rightmost frames for lookahead in CNN context
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (IntTensor): `[B]` (on CPU)

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        bs, xmax = xs.size()[:2]
        n_chunks = 0
        unidir = self.unidir
        lc_bidir = self.lc_bidir
        N_l, N_c, N_r = self.chunk_size_left, self.chunk_size_current, self.chunk_size_right

        if streaming and self.streaming_type == 'mask':
            assert xmax <= N_c
        elif streaming and self.streaming_type == 'reshape':
            assert xmax <= (N_l + N_c + N_r)

        if lc_bidir:
            if self.streaming_type == 'mask' and not streaming:
                xs = chunkwise(xs, 0, N_c, 0, padding=True)  # `[B * n_chunks, N_c, idim]`
                # NOTE: CNN consumes inputs in the current chunk to avoid extra lookahead latency
                # That is, CNN outputs are independent of chunk boundaries
            elif self.streaming_type == 'reshape':
                xs = chunkwise(xs, N_l, N_c, N_r, padding=not streaming)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
            n_chunks = xs.size(0) // bs
            assert bs * n_chunks == xs.size(0)
            if streaming:
                assert n_chunks == 1, xs.size()

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Path through CNN blocks
            xs, xlens = self.conv(xs,
                                  xlens,
                                  lookback=False if lc_bidir else lookback,
                                  lookahead=False if lc_bidir else lookahead)
            # NOTE: CNN lookahead beyond the chunk boundary is not allowed in chunkwise processing
            N_l = max(0, N_l // self.conv.subsampling_factor)
            N_c = N_c // self.conv.subsampling_factor
            N_r = N_r // self.conv.subsampling_factor

        if lc_bidir:
            # Do nothing in the streaming mode
            if self.streaming_type == 'mask' and not streaming:
                # back to the original shape (during training only)
                xs = xs.contiguous().view(bs, -1, xs.size(2))[:, :xlens.max()]  # `[B, emax, d_model]`
        elif streaming:
            xs = xs[:, :xlens.max()]  # for unidirectional

        if self.enc_type == 'conv':
            eouts['ys']['xs'] = xs
            eouts['ys']['xlens'] = xlens
            return eouts

        if not streaming:
            self.reset_cache()
        n_hist = self.cache[0]['input_san'].size(1) if streaming and self.cache[0] is not None else 0

        # positional encoding
        if self.pe_type in ['relative', 'relative_xl']:
            xs = xs * self.scale  # NOTE: first layer only
            rel_pos_embs = self.pos_emb(xs, mlen=n_hist)
        else:
            xs = self.pos_enc(xs, scale=True, offset=max(0, n_hist))
            rel_pos_embs = None

        new_cache = [None] * self.n_layers
        if lc_bidir:
            # chunkwise streaming encoder
            if self.streaming_type == 'reshape':
                xx_mask = None  # NOTE: no mask to avoid masking all frames in a chunk
            elif self.streaming_type == 'mask':
                if streaming:
                    n_chunks = math.ceil((xlens.max().item() + n_hist) / N_c)
                xx_mask = make_chunkwise_san_mask(xs, xlens + n_hist, N_l, N_c,
                                                  n_chunks)

            for lth, layer in enumerate(self.layers):
                xs, cache = layer(xs,
                                  xx_mask,
                                  cache=self.cache[lth],
                                  pos_embs=rel_pos_embs,
                                  u_bias=self.u_bias,
                                  v_bias=self.v_bias)
                if self.streaming_type == 'mask':
                    new_cache[lth] = cache
                if not self.training and not streaming:
                    if self.streaming_type == 'reshape':
                        n_heads = layer.xx_aws.size(1)
                        xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c, N_l:N_l + N_c]
                        xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c)
                        emax = xlens.max().item()
                        xx_aws_center = xx_aws.new_zeros(bs, n_heads, emax, emax)
                        for chunk_idx in range(n_chunks):
                            offset = chunk_idx * N_c
                            emax_chunk = xx_aws_center[:, :, offset:offset + N_c].size(2)
                            xx_aws_chunk = xx_aws[:, chunk_idx, :, :emax_chunk, :emax_chunk]
                            xx_aws_center[:, :, offset:offset + N_c, offset:offset + N_c] = xx_aws_chunk
                        self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws_center)
                    elif self.streaming_type == 'mask':
                        self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(layer.xx_aws)
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                if self.subsample is not None:
                    xs, xlens = self.subsample[lth](xs, xlens)
                    N_l = max(0, N_l // self.subsample[lth].factor)
                    N_c = N_c // self.subsample[lth].factor
                    N_r = N_r // self.subsample[lth].factor
                    if self.pe_type in ['relative', 'relative_xl']:
                        rel_pos_embs = self.pos_emb(xs)
                    if self.streaming_type == 'mask':
                        xx_mask = make_chunkwise_san_mask(
                            xs, xlens, N_l, N_c, n_chunks)

            # Extract the center region
            if self.streaming_type == 'reshape':
                xs = xs[:, N_l:N_l + N_c]  # `[B * n_chunks, N_c, d_model]`
                xs = xs.contiguous().view(bs, -1, xs.size(2))
                xs = xs[:, :xlens.max()]

        else:
            xx_mask = make_san_mask(xs, xlens + n_hist, unidir,
                                    self.lookaheads[0])
            for lth, layer in enumerate(self.layers):
                xs, cache = layer(xs,
                                  xx_mask,
                                  cache=self.cache[lth],
                                  pos_embs=rel_pos_embs,
                                  u_bias=self.u_bias,
                                  v_bias=self.v_bias)
                new_cache[lth] = cache
                if not self.training and not streaming:
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                        layer.xx_aws)
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                # Pick up outputs in the sub task before the projection layer
                if lth == self.n_layers_sub1 - 1:
                    xs_sub1 = self.sub_module(xs, xx_mask, lth, rel_pos_embs,
                                              'sub1')
                    xlens_sub1 = xlens.clone()
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens_sub1
                        return eouts
                if lth == self.n_layers_sub2 - 1:
                    xs_sub2 = self.sub_module(xs, xx_mask, lth, rel_pos_embs,
                                              'sub2')
                    xlens_sub2 = xlens.clone()
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens_sub2
                        return eouts

                if lth < len(self.layers) - 1:
                    if self.subsample is not None and self.subsample[lth].factor > 1:
                        xs, xlens = self.subsample[lth](xs, xlens)
                        n_hist = (self.cache[lth + 1]['input_san'].size(1)
                                  if streaming and self.cache[lth + 1] is not None else 0)
                        if self.pe_type in ['relative', 'relative_xl']:
                            rel_pos_embs = self.pos_emb(xs, mlen=n_hist)
                        xx_mask = make_san_mask(xs, xlens + n_hist, unidir,
                                                self.lookaheads[lth + 1])
                    elif self.lookaheads[lth] != self.lookaheads[lth + 1]:
                        xx_mask = make_san_mask(xs, xlens + n_hist, unidir,
                                                self.lookaheads[lth + 1])

        xs = self.norm_out(xs)

        if streaming:
            self.cache = new_cache

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens_sub1
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens_sub2
        return eouts
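
The `chunkwise` helper used above is not shown in this example, so the sketch below only illustrates the general idea under assumed semantics: pad the sequence, cut overlapping [N_l, N_c, N_r] windows, and stack them on the batch axis. The real helper may differ in padding and chunk ordering.

import math
import torch
import torch.nn.functional as F

# Hedged sketch of a chunkwise split: `[B, T, d]` -> `[B * n_chunks,
# N_l + N_c + N_r, d]`, assuming zero padding outside the sequence.
def chunkwise_sketch(xs, N_l, N_c, N_r):
    B, T, d = xs.size()
    n_chunks = math.ceil(T / N_c)
    # left/right context plus tail padding up to a full last chunk
    xs = F.pad(xs, (0, 0, N_l, N_r + n_chunks * N_c - T))
    chunks = [xs[:, i * N_c:i * N_c + N_l + N_c + N_r] for i in range(n_chunks)]
    xs = torch.stack(chunks, dim=1)  # `[B, n_chunks, N_l + N_c + N_r, d]`
    return xs.view(B * n_chunks, N_l + N_c + N_r, d)

out = chunkwise_sketch(torch.randn(2, 50, 8), N_l=8, N_c=16, N_r=8)
print(out.size())  # torch.Size([8, 32, 8]): 4 chunks per utterance
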
Example #26
    def beam_search(self,
                    eouts,
                    elens,
                    params,
                    idx2token,
                    lm=None,
                    lm_second=None,
                    lm_second_bwd=None,
                    ctc_log_probs=None,
                    nbest=1,
                    exclude_eos=False,
                    refs_id=None,
                    utt_ids=None,
                    speakers=None,
                    ensmbl_eouts=None,
                    ensmbl_elens=None,
                    ensmbl_decs=[]):
        """Beam search decoding.

        Args:
            eouts (FloatTensor): `[B, T, enc_n_units]`
            elens (IntTensor): `[B]`
            params (dict):
                recog_beam_width (int): size of beam
                recog_max_len_ratio (float): maximum length ratio of output tokens to encoder frames
                recog_min_len_ratio (float): minimum length ratio of output tokens to encoder frames
                recog_length_penalty (float): length penalty
                recog_coverage_penalty (float): coverage penalty
                recog_coverage_threshold (float): threshold for coverage penalty
                recog_lm_weight (float): weight of LM score
            idx2token (): converter from index to token
            lm: first path LM
            lm_second: second path LM
            lm_second_bwd: second path backward LM
            ctc_log_probs (FloatTensor): `[B, T, vocab]`
            nbest (int):
            exclude_eos (bool): exclude <eos> from hypothesis
            refs_id (list): reference list
            utt_ids (list): utterance id list
            speakers (list): speaker list
            ensmbl_eouts (list): list of FloatTensor
            ensmbl_elens (list): list of list
            ensmbl_decs (list): list of torch.nn.Module
        Returns:
            nbest_hyps_idx (list): A list of length `[B]`, which contains list of N hypotheses
            aws: dummy
            scores: dummy

        """
        bs = eouts.size(0)

        beam_width = params['recog_beam_width']
        ctc_weight = params['recog_ctc_weight']
        lm_weight = params['recog_lm_weight']
        lm_weight_second = params['recog_lm_second_weight']
        lm_weight_second_bwd = params['recog_lm_bwd_weight']
        asr_state_carry_over = params['recog_asr_state_carry_over']
        lm_state_carry_over = params['recog_lm_state_carry_over']

        if lm is not None:
            assert lm_weight > 0
            lm.eval()
        if lm_second is not None:
            assert lm_weight_second > 0
            lm_second.eval()
        if lm_second_bwd is not None:
            assert lm_weight_second_bwd > 0
            lm_second_bwd.eval()

        if ctc_log_probs is not None:
            assert ctc_weight > 0
            ctc_log_probs = tensor2np(ctc_log_probs)

        nbest_hyps_idx = []
        eos_flags = []
        for b in range(bs):
            # Initialization per utterance
            y = eouts.new_zeros(1, 1).fill_(self.eos).long()  # per-utterance initialization
            y_emb = self.dropout_emb(self.embed(y))
            dout, dstate = self.recurrency(y_emb, None)
            lmstate = None

            # For joint CTC-Attention decoding
            ctc_prefix_scorer = None
            if ctc_log_probs is not None:
                ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b],
                                                   self.blank, self.eos)

            if speakers is not None:
                if speakers[b] == self.prev_spk:
                    if lm_state_carry_over and isinstance(lm, RNNLM):
                        lmstate = self.lmstate_final
                self.prev_spk = speakers[b]

            helper = BeamSearch(beam_width, self.eos, ctc_weight,
                                self.device_id)

            end_hyps = []
            hyps = [{'hyp': [self.eos],
                     'ref_id': [self.eos],
                     'score': 0.,
                     'score_rnnt': 0.,
                     'score_lm': 0.,
                     'score_ctc': 0.,
                     'dout': dout,
                     'dstate': dstate,
                     'lmstate': lmstate,
                     'ctc_state': ctc_prefix_scorer.initial_state() if ctc_prefix_scorer is not None else None}]
            for t in range(elens[b]):
                # preprocess for batch decoding
                douts = torch.cat([beam['dout'] for beam in hyps], dim=0)
                outs = self.joint(
                    eouts[b:b + 1, t:t + 1].repeat([douts.size(0), 1, 1]),
                    douts)
                scores_rnnt = torch.log_softmax(outs.squeeze(2).squeeze(1),
                                                dim=-1)

                # Update LM states for shallow fusion
                y = eouts.new_zeros(len(hyps), 1).long()
                for j, beam in enumerate(hyps):
                    y[j, 0] = beam['hyp'][-1]
                lmstate, scores_lm = None, None
                if lm is not None:
                    if hyps[0]['lmstate'] is not None:
                        lm_hxs = torch.cat([beam['lmstate']['hxs'] for beam in hyps], dim=1)
                        lm_cxs = torch.cat([beam['lmstate']['cxs'] for beam in hyps], dim=1)
                        lmstate = {'hxs': lm_hxs, 'cxs': lm_cxs}
                    lmout, lmstate, scores_lm = lm.predict(y, lmstate)

                new_hyps = []
                for j, beam in enumerate(hyps):
                    dout = douts[j:j + 1]
                    dstate = beam['dstate']
                    lmstate = beam['lmstate']

                    # RNN-T scores
                    total_scores_rnnt = beam['score_rnnt'] + scores_rnnt[j:j + 1]
                    total_scores = total_scores_rnnt * (1 - ctc_weight)

                    # Add LM score <after> top-K selection
                    total_scores_topk, topk_ids = torch.topk(total_scores,
                                                             k=beam_width,
                                                             dim=-1,
                                                             largest=True,
                                                             sorted=True)
                    if lm is not None:
                        total_scores_lm = beam['score_lm'] + scores_lm[j, -1, topk_ids[0]]
                        total_scores_topk += total_scores_lm * lm_weight
                    else:
                        total_scores_lm = eouts.new_zeros(beam_width)

                    # Add CTC score
                    new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                        beam['hyp'], topk_ids, beam['ctc_state'],
                        total_scores_topk, ctc_prefix_scorer)

                    for k in range(beam_width):
                        idx = topk_ids[0, k].item()

                        if idx == self.blank:
                            beam['score'] = total_scores_topk[0, k].item()
                            beam['score_rnnt'] = total_scores_topk[0, k].item()
                            new_hyps.append(beam.copy())
                            continue

                        # skip blank-dominant frames
                        # if total_scores_topk[0, self.blank].item() > 0.7:
                        #     continue

                        # Update prediction network only when predicting non-blank labels
                        hyp_id = beam['hyp'] + [idx]
                        hyp_str = ' '.join(list(map(str, hyp_id)))
                        # if hyp_str in self.state_cache.keys():
                        #     # from cache
                        #     dout = self.state_cache[hyp_str]['dout']
                        #     new_dstate = self.state_cache[hyp_str]['dstate']
                        #     lmstate = self.state_cache[hyp_str]['lmstate']
                        # else:
                        y = eouts.new_zeros(1, 1).fill_(idx).long()
                        y_emb = self.dropout_emb(self.embed(y))
                        dout, new_dstate = self.recurrency(y_emb, dstate)

                        # store in cache
                        self.state_cache[hyp_str] = {
                            'dout': dout,
                            'dstate': new_dstate,
                            'lmstate': {
                                'hxs': lmstate['hxs'][:, j:j + 1],
                                'cxs': lmstate['cxs'][:, j:j + 1]
                            } if lmstate is not None else None,
                        }

                        new_hyps.append(
                            {'hyp': hyp_id,
                             'score': total_scores_topk[0, k].item(),
                             'score_rnnt': total_scores_rnnt[0, idx].item(),
                             'score_ctc': total_scores_ctc[k].item(),
                             'score_lm': total_scores_lm[k].item(),
                             'dout': dout,
                             'dstate': new_dstate,
                             'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                         'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                             'ctc_state': new_ctc_states[k] if ctc_prefix_scorer is not None else None})

                # Merge hypotheses having the same token sequences
                new_hyps_merged = {}
                for beam in new_hyps:
                    hyp_str = ' '.join(list(map(str, beam['hyp'])))
                    if hyp_str not in new_hyps_merged.keys():
                        new_hyps_merged[hyp_str] = beam
                    elif beam['score'] > new_hyps_merged[hyp_str]['score']:
                        new_hyps_merged[hyp_str] = beam
                new_hyps = [v for v in new_hyps_merged.values()]

                # Local pruning
                new_hyps_tmp = sorted(new_hyps,
                                      key=lambda x: x['score'],
                                      reverse=True)[:beam_width]

                # Remove complete hypotheses
                new_hyps = []
                for hyp in new_hyps_tmp:
                    new_hyps += [hyp]
                if len(end_hyps) >= beam_width:
                    end_hyps = end_hyps[:beam_width]
                    break
                hyps = new_hyps[:]

            # Global pruning
            if len(end_hyps) == 0:
                end_hyps = hyps[:]
            elif len(end_hyps) < nbest and nbest > 1:
                end_hyps.extend(hyps[:nbest - len(end_hyps)])

            # forward second path LM rescoring
            if lm_second is not None:
                self.lm_rescoring(end_hyps,
                                  lm_second,
                                  lm_weight_second,
                                  tag='second')

            # backward second path LM rescoring
            if lm_second_bwd is not None:
                self.lm_rescoring(end_hyps,
                                  lm_second_bwd,
                                  lm_weight_second_bwd,
                                  tag='second_rev')

            end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

            # Reset state cache
            self.state_cache = OrderedDict()

            if utt_ids is not None:
                logger.info('Utt-id: %s' % utt_ids[b])
            if idx2token is not None:
                logger.info('=' * 200)
                for k in range(len(end_hyps)):
                    if refs_id is not None and self.vocab == idx2token.vocab:
                        logger.info('Ref: %s' % idx2token(refs_id[b]))
                    logger.info('Hyp: %s' % idx2token(end_hyps[k]['hyp'][1:]))
                    logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                    if ctc_log_probs is not None:
                        logger.info('log prob (hyp, ctc): %.7f' %
                                    (end_hyps[k]['score_ctc']))
                    if lm is not None:
                        logger.info('log prob (hyp, lm): %.7f' %
                                    (end_hyps[k]['score_lm']))
                    logger.info('-' * 50)

            # N-best list
            nbest_hyps_idx += [[
                np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)
            ]]

            # Check <eos>
            eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos)
                              for n in range(nbest)])

        return nbest_hyps_idx, None, None
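
The merge step above (same label sequence, keep the best score) behaves as in this toy sketch; the hypotheses and scores are invented:

# Hedged sketch of the hypothesis merge above: beams with identical token
# sequences are collapsed, keeping the highest-scoring one.
hyps = [{'hyp': [0, 3, 5], 'score': -1.2},
        {'hyp': [0, 3, 5], 'score': -0.7},  # same labels, better score
        {'hyp': [0, 4], 'score': -0.9}]

merged = {}
for beam in hyps:
    key = ' '.join(map(str, beam['hyp']))
    if key not in merged or beam['score'] > merged[key]['score']:
        merged[key] = beam
hyps = sorted(merged.values(), key=lambda x: x['score'], reverse=True)
print([(h['hyp'], h['score']) for h in hyps])
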
Example #27
    def greedy(self,
               eouts,
               elens,
               max_len_ratio,
               exclude_eos=False,
               idx2token=None,
               refs_id=None,
               speakers=None,
               oracle=False):
        """Greedy decoding in the inference stage (used only for evaluation during training).

        Args:
            eouts (FloatTensor): `[B, T, enc_units]`
            elens (IntTensor): `[B]`
            max_len_ratio (float): maximum length ratio of output tokens to encoder frames
            exclude_eos (bool): exclude <eos> from hypothesis
            idx2token (): converter from index to token
            refs_id (list): reference list
            speakers (list): speaker list
            oracle (bool):
        Returns:
            best_hyps (list): A list of length `[B]`, which contains arrays of size `[L]`
            aw (list): A list of length `[B]`, which contains arrays of size `[L, T]`

        """
        bs, xmax = eouts.size()[:2]

        # Start from <sos> (<eos> in case of the backward decoder)
        ys_all = eouts.new_zeros(bs, 1).fill_(self.eos).long()

        # TODO(hirofumi): Create the source-target mask for batch decoding

        best_hyps_batch = []
        ylens = torch.zeros(bs).int()
        yy_aws_tmp = [None] * bs
        xy_aws_tmp = [None] * bs
        eos_flags = [False] * bs
        for t in range(int(np.floor(xmax * max_len_ratio)) + 1):
            # Create the self-attention mask
            yy_mask = make_pad_mask(ylens + 1,
                                    self.device_id).unsqueeze(1).expand(
                                        bs, t + 1, t + 1)
            yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, t + 1,
                                                  t + 1)
            subsequent_mask = torch.tril(yy_mask.new_ones(
                (t + 1, t + 1)).byte(),
                                         diagonal=0)
            subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
                bs, self.attn_n_heads, t + 1, t + 1)
            yy_mask = yy_mask & subsequent_mask

            # Create the source-target mask
            xmax = eouts.size(1)
            x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand(
                bs, t + 1, xmax)
            y_mask = make_pad_mask(ylens + 1,
                                   self.device_id).unsqueeze(2).expand(
                                       bs, t + 1, xmax)
            xy_mask = (x_mask * y_mask).unsqueeze(1).expand(
                bs, self.attn_n_heads, t + 1, xmax)

            out = self.pos_enc(self.embed(ys_all))
            for l in range(self.n_layers):
                out, yy_aws, xy_aws = self.layers[l](out, yy_mask, eouts,
                                                     xy_mask)
            out = self.norm_out(out)

            # Pick up 1-best
            y = self.output(out).argmax(-1)[:, -1:]
            best_hyps_batch += [y]

            # Count lengths of hypotheses
            for b in range(bs):
                if not eos_flags[b]:
                    if y[b].item() == self.eos:
                        eos_flags[b] = True
                        yy_aws_tmp[b] = yy_aws[b:b + 1]  # TODO: fix this
                        xy_aws_tmp[b] = xy_aws[b:b + 1]
                    ylens[b] += 1
                    # NOTE: include <eos>

            # Break if <eos> is output for all utterances in the mini-batch
            if sum(eos_flags) == bs:
                break

            ys_all = torch.cat([ys_all, y], dim=-1)

        # Concatenate in L dimension
        best_hyps_batch = torch.cat(best_hyps_batch, dim=1)
        # xy_aws_tmp = torch.stack(xy_aws_tmp, dim=0)

        # Convert to numpy
        best_hyps_batch = tensor2np(best_hyps_batch)
        # xy_aws_tmp = tensor2np(xy_aws_tmp)

        # if self.score.attn_n_heads > 1:
        #     xy_aws_tmp = xy_aws_tmp[:, :, :, 0]
        #     # TODO(hirofumi): fix for MHA

        # Truncate by the first <eos> (<sos> in case of the backward decoder)
        if self.bwd:
            # Reverse the order
            best_hyps = [
                best_hyps_batch[b, :ylens[b]][::-1] for b in range(bs)
            ]
            # aws = [xy_aws_tmp[b, :ylens[b]][::-1] for b in range(bs)]
        else:
            best_hyps = [best_hyps_batch[b, :ylens[b]] for b in range(bs)]
            # aws = [xy_aws_tmp[b, :ylens[b]] for b in range(bs)]

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.bwd:
                best_hyps = [
                    best_hyps[b][1:] if eos_flags[b] else best_hyps[b]
                    for b in range(bs)
                ]
            else:
                best_hyps = [
                    best_hyps[b][:-1] if eos_flags[b] else best_hyps[b]
                    for b in range(bs)
                ]

        # return best_hyps, aws
        return best_hyps, None
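
The combined self-attention mask above (padding mask ANDed with a lower-triangular subsequent mask) can be visualized with a toy example:

import torch

# Hedged sketch of the yy_mask above for one utterance with one padded step:
# True entries are the positions a query may attend to.
t = 3
pad_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)
pad_mask = pad_mask.unsqueeze(1).expand(1, t + 1, t + 1)
subsequent_mask = torch.tril(torch.ones(t + 1, t + 1, dtype=torch.bool)).unsqueeze(0)
yy_mask = pad_mask & subsequent_mask
print(yy_mask[0].int())
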
Example #28
    def forward_att(self, eouts, elens, ys, return_logits=False):
        """Compute XE loss for the sequence-to-sequence model.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
            return_logits (bool): return logits for knowledge distillation
        Returns:
            loss (FloatTensor): `[1]`
            acc (float):
            ppl (float):

        """
        bs = eouts.size(0)

        # Append <sos> and <eos>
        eos = eouts.new_zeros(1).fill_(self.eos).long()
        ys = [
            np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int64),
                      self.device_id) for y in ys
        ]
        ylens = np2tensor(
            np.fromiter([y.size(0) + 1 for y in ys],
                        dtype=np.int32))  # +1 for <eos>
        ys_in_pad = pad_list([torch.cat([eos, y], dim=0) for y in ys],
                             self.pad)
        ys_out_pad = pad_list([torch.cat([y, eos], dim=0) for y in ys],
                              self.pad)

        # Create the self-attention mask
        bs, ymax = ys_in_pad.size()[:2]
        yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand(
            bs, ymax, ymax)
        yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax,
                                              ymax)
        subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(),
                                     diagonal=0)
        subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, ymax)
        yy_mask = yy_mask & subsequent_mask

        # Create the source-target mask
        xmax = eouts.size(1)
        x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand(
            bs, ymax, xmax)
        y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).expand(
            bs, ymax, xmax)
        xy_mask = (x_mask * y_mask).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, xmax)

        ys_emb = self.pos_enc(self.embed(ys_in_pad))
        for l in range(self.n_layers):
            ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts,
                                                    xy_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
                setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
        logits = self.norm_out(ys_emb)
        if self.adaptive_softmax is None:
            logits = self.output(logits)
        if return_logits:
            return logits

        # Compute XE sequence loss
        if self.adaptive_softmax is None:
            if self.lsm_prob > 0 and self.training:
                # Label smoothing
                loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                         ys_out_pad.view(-1), self.lsm_prob,
                                         self.pad)
            else:
                loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                       ys_out_pad.view(-1),
                                       ignore_index=self.pad,
                                       reduction='mean')  # NOTE: replaces deprecated size_average=True

            # Focal loss
            if self.focal_loss_weight > 0:
                fl = focal_loss(logits,
                                ys_out_pad,
                                ylens,
                                alpha=self.focal_loss_weight,
                                gamma=self.focal_loss_gamma)
                loss = loss * (
                    1 - self.focal_loss_weight) + fl * self.focal_loss_weight
        else:
            loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                         ys_out_pad.view(-1)).loss

        # Compute token-level accuracy under teacher forcing
        if self.adaptive_softmax is None:
            acc = compute_accuracy(logits, ys_out_pad, self.pad)
        else:
            acc = compute_accuracy(
                self.adaptive_softmax.log_prob(
                    logits.view((-1, logits.size(2)))), ys_out_pad, self.pad)
        ppl = min(np.exp(loss.item()), np.inf)

        # Scale the token-averaged XE loss by the mean target length to match the CTC loss scale
        loss *= ylens.float().mean()

        return loss, acc, ppl
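
The decisive detail in forward_att is the self-attention mask: a padding mask (which blanks out padded target positions) is intersected with a lower-triangular causal mask (which blocks attention to future tokens). Below is a minimal, self-contained sketch of that combination; the batch size, ymax, head count, and lengths are illustrative values, not taken from the model above.

import torch

bs, ymax, n_heads = 2, 3, 4
ylens = torch.tensor([3, 2])  # target lengths per utterance (illustrative)

# Padding mask: True at valid (non-padded) target positions
pad_mask = torch.arange(ymax)[None, :] < ylens[:, None]        # [B, ymax]
pad_mask = pad_mask[:, None, :].expand(bs, ymax, ymax)         # [B, ymax, ymax]
pad_mask = pad_mask[:, None].expand(bs, n_heads, ymax, ymax)   # [B, H, ymax, ymax]

# Causal mask: position i may attend to positions <= i only
causal = torch.tril(torch.ones(ymax, ymax, dtype=torch.bool))
causal = causal[None, None].expand(bs, n_heads, ymax, ymax)

# Plays the same role as `yy_mask & subsequent_mask` above
yy_mask = pad_mask & causal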
Example #29
0
    def beam_search(self,
                    eouts,
                    elens,
                    params,
                    rnnlm,
                    nbest=1,
                    exclude_eos=False,
                    id2token=None,
                    refs=None):
        """Beam search decoding in the inference stage.

        Args:
            eouts (FloatTensor): `[B, T, enc_units]`
            elens (list): A list of length `[B]`
            params (dict):
                beam_width (int): the beam size
                max_len_ratio (float): maximum output length, relative to the input length
                min_len_ratio (float): minimum output length, relative to the input length
                length_penalty (float): length penalty
                coverage_penalty (float): coverage penalty
                coverage_threshold (float): threshold for coverage penalty
                rnnlm_weight (float): the weight of RNNLM score
            rnnlm (torch.nn.Module):
            nbest (int):
            exclude_eos (bool):
            id2token (callable): converts token ids into a string (for logging)
            refs (list): reference transcripts (for logging)
        Returns:
            nbest_hyps (list): A list of length `[B]`, each of which contains a list of the n-best hypotheses
            aws (list): A list of length `[B]`, which contains arrays of size `[L, T]`
            scores (list):

        """
        bs, _, enc_nunits = eouts.size()
        device_id = eouts.get_device()

        # For cold fusion
        if params['rnnlm_weight'] > 0 and not self.cold_fusion:
            assert self.rnnlm_cf
            self.rnnlm_cf.eval()

        # For shallow fusion
        if rnnlm is not None:
            rnnlm.eval()

        if self.backward:
            sos, eos = self.eos, self.sos
        else:
            sos, eos = self.sos, self.eos

        nbest_hyps, aws, scores = [], [], []
        eos_flags = []
        for b in range(bs):
            # Initialization per utterance
            dout, (hx_list,
                   cx_list) = self.init_dec_state(1, self.nlayers, device_id,
                                                  eouts[b:b + 1],
                                                  elens[b:b + 1])
            _dout, _dstate = self.init_dec_state(1, 1, device_id,
                                                 eouts[b:b + 1],
                                                 elens[b:b + 1])
            context = eouts.new_zeros(1, 1, enc_nunits)
            self.score.reset()

            complete = []
            beam = [{
                'hyp': [sos],
                'score': 0,
                'scores': [0],
                'score_raw': 0,
                'dout': dout,
                'hx_list': hx_list,
                'cx_list': cx_list,
                'context': context,
                'aws': [None],
                'rnnlm_hx_list': None,
                'rnnlm_cx_list': None,
                'prev_cov': 0,
                '_dout': _dout,
                '_dstate': _dstate
            }]
            for t in range(
                    int(math.floor(elens[b] * params['max_len_ratio'])) + 1):
                new_beam = []
                for i_beam in range(len(beam)):
                    # Recurrency
                    y = eouts.new_zeros(1, 1).fill_(
                        beam[i_beam]['hyp'][-1]).long()
                    y_emb = self.embed(y)
                    dout, (hx_list, cx_list), _dout, _dstate = self.recurrency(
                        y_emb, beam[i_beam]['context'],
                        (beam[i_beam]['hx_list'], beam[i_beam]['cx_list']),
                        beam[i_beam]['_dstate'])

                    # Score
                    context, aw = self.score(eouts[b:b + 1, :elens[b]],
                                             elens[b:b + 1], dout,
                                             beam[i_beam]['aws'][-1])

                    if self.rnnlm_cf:
                        # Update RNNLM states for cold fusion
                        y_lm = eouts.new_zeros(1, 1).fill_(
                            beam[i_beam]['hyp'][-1]).long()
                        y_lm_emb = self.rnnlm_cf.embed(y_lm).squeeze(1)
                        logits_lm_t, lm_out, rnnlm_state = self.rnnlm_cf.predict(
                            y_lm_emb, (beam[i_beam]['rnnlm_hx_list'],
                                       beam[i_beam]['rnnlm_cx_list']))
                    elif rnnlm is not None:
                        # Update RNNLM states for shallow fusion
                        y_lm = eouts.new_zeros(1, 1).fill_(
                            beam[i_beam]['hyp'][-1]).long()
                        y_lm_emb = rnnlm.embed(y_lm).squeeze(1)
                        logits_lm_t, lm_out, rnnlm_state = rnnlm.predict(
                            y_lm_emb, (beam[i_beam]['rnnlm_hx_list'],
                                       beam[i_beam]['rnnlm_cx_list']))
                    else:
                        logits_lm_t, lm_out, rnnlm_state = None, None, None

                    # Generate
                    attentional_t = self.generate(context, dout, logits_lm_t,
                                                  lm_out)
                    if self.rnnlm_init and self.internal_lm:
                        # Residual connection
                        attentional_t += _dout
                    logits_t = self.output(attentional_t)

                    # Pass through the softmax layer & convert to the log scale
                    log_probs = F.log_softmax(logits_t.squeeze(1),
                                              dim=1)  # log-prob-level
                    # log_probs = logits_t.squeeze(1)  # logits-level
                    # NOTE: `[1 (B), 1, vocab]` -> `[1 (B), vocab]`

                    # Pick up the top-k scores
                    log_probs_topk, indices_topk = torch.topk(
                        log_probs,
                        k=params['beam_width'],
                        dim=1,
                        largest=True,
                        sorted=True)

                    for k in range(params['beam_width']):
                        # Exclude short hypotheses
                        if indices_topk[0, k].item() == eos and \
                                len(beam[i_beam]['hyp']) < elens[b] * params['min_len_ratio']:
                            continue

                        # Add length penalty
                        score_raw = beam[i_beam]['score_raw'] + log_probs_topk[
                            0, k].item()
                        score = score_raw + params['length_penalty']

                        # Add coverage penalty
                        if params['coverage_penalty'] > 0:
                            # Recompute the coverage penalty at each step
                            score -= beam[i_beam]['prev_cov'] * params[
                                'coverage_penalty']
                            aw_stack = torch.stack(beam[i_beam]['aws'][1:] +
                                                   [aw],
                                                   dim=-1)
                            cov_sum = aw_stack.detach().cpu().numpy()
                            if params['coverage_threshold'] == 0:
                                cov_sum = np.sum(cov_sum) / self.score.nheads
                            else:
                                cov_sum = np.sum(cov_sum[np.where(
                                    cov_sum > params['coverage_threshold'])[0]]
                                                 ) / self.score.nheads
                            score += cov_sum * params['coverage_penalty']
                        else:
                            cov_sum = 0

                        # Add RNNLM score
                        if params['rnnlm_weight'] > 0:
                            lm_log_probs = F.log_softmax(
                                logits_lm_t.squeeze(1), dim=1)
                            assert log_probs.size() == lm_log_probs.size()
                            score += lm_log_probs[0, indices_topk[
                                0, k].item()].item() * params['rnnlm_weight']

                        new_beam.append({
                            'hyp': beam[i_beam]['hyp'] + [indices_topk[0, k].item()],
                            'score': score,
                            'scores': beam[i_beam]['scores'] + [score],
                            'score_raw': score_raw,
                            'score_lm': 0,  # TODO(hirofumi):
                            'score_lp': 0,  # TODO(hirofumi):
                            'score_cp': 0,  # TODO(hirofumi):
                            'hx_list': hx_list[:],
                            'cx_list': cx_list[:] if cx_list is not None else None,
                            'dout': dout,
                            'context': context,
                            'aws': beam[i_beam]['aws'] + [aw],
                            'rnnlm_hx_list': rnnlm_state[0][:] if rnnlm_state is not None else None,
                            'rnnlm_cx_list': rnnlm_state[1][:] if rnnlm_state is not None else None,
                            'prev_cov': cov_sum,
                            '_dout': _dout,
                            '_dstate': _dstate[:]
                        })

                new_beam = sorted(new_beam,
                                  key=lambda x: x['score'],
                                  reverse=True)

                # Separate complete and incomplete hypotheses
                not_complete = []
                for cand in new_beam[:params['beam_width']]:
                    if cand['hyp'][-1] == eos:
                        complete += [cand]
                    else:
                        not_complete += [cand]

                if len(complete) >= params['beam_width']:
                    complete = complete[:params['beam_width']]
                    break

                beam = not_complete[:params['beam_width']]

            # Sort by score
            if len(complete) == 0:
                complete = beam
            elif len(complete) < nbest and nbest > 1:
                complete.extend(beam[:nbest - len(complete)])
            complete = sorted(complete, key=lambda x: x['score'], reverse=True)

            # N-best list
            if self.backward:
                # Reverse the order
                nbest_hyps += [[
                    np.array(complete[n]['hyp'][1:][::-1])
                    for n in range(nbest)
                ]]
                if self.score.nheads > 1:
                    # Keep the first attention head only (as in greedy decoding)
                    aws += [[[aw_t[0] for aw_t in complete[n]['aws'][1:]][::-1]
                             for n in range(nbest)]]
                else:
                    aws += [[
                        complete[n]['aws'][1:][::-1] for n in range(nbest)
                    ]]
                scores += [[
                    complete[n]['scores'][1:][::-1] for n in range(nbest)
                ]]
            else:
                nbest_hyps += [[
                    np.array(complete[n]['hyp'][1:]) for n in range(nbest)
                ]]
                if self.score.nheads > 1:
                    # Keep the first attention head only (as in greedy decoding)
                    aws += [[[aw_t[0] for aw_t in complete[n]['aws'][1:]]
                             for n in range(nbest)]]
                else:
                    aws += [[complete[n]['aws'][1:] for n in range(nbest)]]
                scores += [[complete[n]['scores'][1:] for n in range(nbest)]]
            # scores += [[complete[n]['score_raw'] for n in range(nbest)]]

            # Check <eos>
            eos_flag = [complete[n]['hyp'][-1] == eos for n in range(nbest)]
            eos_flags.append(eos_flag)

            if id2token is not None:
                if refs is not None:
                    logger.info('Ref: %s' % refs[b].lower())
                for n in range(nbest):
                    logger.info('Hyp: %s' % id2token(nbest_hyps[b][n]))
            if refs is not None:
                logger.info('log prob (ref): ')
            for n in range(nbest):
                logger.info('log prob (hyp): %.3f' % complete[n]['score'])
                logger.info('log prob (hyp, raw): %.3f' %
                            complete[n]['score_raw'])

        # Concatenate in L dimension
        for b in range(len(aws)):
            for n in range(nbest):
                aws[b][n] = tensor2np(torch.stack(aws[b][n], dim=1).squeeze(0))

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.backward:
                nbest_hyps = [[
                    nbest_hyps[b][n][1:]
                    if eos_flags[b][n] else nbest_hyps[b][n]
                    for n in range(nbest)
                ] for b in range(bs)]
            else:
                nbest_hyps = [[
                    nbest_hyps[b][n][:-1]
                    if eos_flags[b][n] else nbest_hyps[b][n]
                    for n in range(nbest)
                ] for b in range(bs)]

        return nbest_hyps, aws, scores
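
The scoring inside the time loop above reduces to: accumulate token log-probabilities into score_raw, add the length penalty (plus the optional coverage and RNNLM terms) to form score, then keep the best beam_width candidates. A stripped-down sketch of one expansion step, with toy logits and only the length penalty (the <sos> id and the hyperparameters are illustrative):

import torch
import torch.nn.functional as F

beam_width, vocab, length_penalty = 3, 10, 0.2
beam = [{'hyp': [1], 'score': 0.0, 'score_raw': 0.0}]  # start from <sos> (id 1)

logits = torch.randn(len(beam), vocab)  # stand-in for the decoder outputs
log_probs = F.log_softmax(logits, dim=-1)

new_beam = []
for i, cand in enumerate(beam):
    topk_lp, topk_ids = torch.topk(log_probs[i], k=beam_width)
    for lp, idx in zip(topk_lp.tolist(), topk_ids.tolist()):
        score_raw = cand['score_raw'] + lp
        new_beam.append({'hyp': cand['hyp'] + [idx],
                         'score_raw': score_raw,
                         'score': score_raw + length_penalty})
beam = sorted(new_beam, key=lambda x: x['score'], reverse=True)[:beam_width]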
Example #30
0
    def greedy(self, eouts, elens, max_len_ratio, exclude_eos=False):
        """Greedy decoding in the inference stage.

        Args:
            eouts (FloatTensor): `[B, T, enc_units]`
            elens (list): A list of length `[B]`
            max_len_ratio (float): maximum output length, relative to the number of encoder frames
            exclude_eos (bool):
        Returns:
            best_hyps (list): A list of length `[B]`, which contains arrays of size `[L]`
            aw (list): A list of length `[B]`, which contains arrays of size `[L, T]`

        """
        bs, enc_time, enc_nunits = eouts.size()
        device_id = eouts.get_device()

        # Initialization
        dout, dstate = self.init_dec_state(bs, self.nlayers, device_id, eouts,
                                           elens)
        _dout, _dstate = self.init_dec_state(bs, 1, device_id, eouts, elens)
        context = eouts.new_zeros(bs, 1, enc_nunits)
        self.score.reset()
        aw = None
        rnnlm_state = None

        if self.backward:
            sos, eos = self.eos, self.sos
        else:
            sos, eos = self.sos, self.eos

        # Start from <sos> (<eos> in case of the backward decoder)
        y = eouts.new_zeros(bs, 1).fill_(sos).long()

        best_hyps_tmp, aws_tmp = [], []
        y_lens = np.zeros((bs, ), dtype=np.int32)
        eos_flags = [False] * bs
        for t in range(int(math.floor(enc_time * max_len_ratio)) + 1):
            # Recurrency
            y_emb = self.embed(y)
            dout, dstate, _dout, _dstate = self.recurrency(
                y_emb, context, dstate, _dstate)

            # Update RNNLM states for cold fusion
            if self.rnnlm_cf:
                y_lm = self.rnnlm_cf.embed(y)
                logits_lm_t, lm_out, rnnlm_state = self.rnnlm_cf.predict(
                    y_lm, rnnlm_state)
            else:
                logits_lm_t, lm_out = None, None

            # Score
            context, aw = self.score(eouts, elens, dout, aw)

            # Generate
            attentional_t = self.generate(context, dout, logits_lm_t, lm_out)
            if self.rnnlm_init and self.internal_lm:
                # Residual connection
                attentional_t += _dout
            logits_t = self.output(attentional_t)

            # Pick up 1-best
            y = torch.argmax(logits_t.squeeze(1).detach(),
                             dim=1).unsqueeze(1)
            best_hyps_tmp += [y]
            if self.score.nheads > 1:
                aws_tmp += [aw[0]]
            else:
                aws_tmp += [aw]

            # Count lengths of hypotheses
            for b in range(bs):
                if not eos_flags[b]:
                    if y[b].item() == eos:
                        eos_flags[b] = True
                    y_lens[b] += 1
                    # NOTE: include <eos>

        # Break if <eos> has been output in all utterances of the mini-batch
            if sum(eos_flags) == bs:
                break

        # Concatenate in L dimension
        best_hyps_tmp = torch.cat(best_hyps_tmp, dim=1)
        aws_tmp = torch.stack(aws_tmp, dim=1)

        # Convert to numpy
        best_hyps_tmp = tensor2np(best_hyps_tmp)
        aws_tmp = tensor2np(aws_tmp)

        # Truncate by the first <eos> (<sos> in case of the backward decoder)
        if self.backward:
            # Reverse the order
            best_hyps = [best_hyps_tmp[b, :y_lens[b]][::-1] for b in range(bs)]
            aws = [aws_tmp[b, :y_lens[b]][::-1] for b in range(bs)]
        else:
            best_hyps = [best_hyps_tmp[b, :y_lens[b]] for b in range(bs)]
            aws = [aws_tmp[b, :y_lens[b]] for b in range(bs)]

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.backward:
                best_hyps = [
                    best_hyps[b][1:] if eos_flags[b] else best_hyps[b]
                    for b in range(bs)
                ]
            else:
                best_hyps = [
                    best_hyps[b][:-1] if eos_flags[b] else best_hyps[b]
                    for b in range(bs)
                ]

        return best_hyps, aws
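
The bookkeeping in greedy that decides where each hypothesis ends can be isolated into a few lines. A minimal sketch with toy id sequences (the eos id and the arrays are illustrative) of how y_lens and eos_flags drive the truncation and the optional <eos> removal:

import numpy as np

eos = 2
best_hyps_tmp = np.array([[5, 7, 2, 0],    # <eos> emitted at step 2
                          [4, 2, 0, 0]])   # <eos> emitted at step 1
bs, max_len = best_hyps_tmp.shape

y_lens = np.zeros((bs,), dtype=np.int32)
eos_flags = [False] * bs
for t in range(max_len):
    for b in range(bs):
        if not eos_flags[b]:
            if best_hyps_tmp[b, t] == eos:
                eos_flags[b] = True
            y_lens[b] += 1  # NOTE: the length includes <eos>

# Truncate by the first <eos>
best_hyps = [best_hyps_tmp[b, :y_lens[b]] for b in range(bs)]
# exclude_eos then drops the trailing <eos>
best_hyps = [h[:-1] if eos_flags[b] else h for b, h in enumerate(best_hyps)]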