Code example #1
    def beam_search_decode(self, encoder_output, dot_attention_mask,
                           beam_size):
        batch_size = encoder_output.size(0)
        device = encoder_output.device
        feature_length = encoder_output.size(1)

        self.beam_steper = BeamSteper(batch_size, beam_size, self.bos_id,
                                      self.eos_id, self.vocab_size, device)
        # expand the encoder output and its attention mask from (B, ...) to
        # (B * beam_size, ...) so that every beam shares the same acoustic context
        encoder_output = encoder_output.unsqueeze(1).repeat(
            1, beam_size, 1, 1).view(batch_size * beam_size, feature_length,
                                     -1)
        dot_attention_mask = dot_attention_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1).view(batch_size * beam_size, 1, feature_length)
        with t.no_grad():
            for i in range(self.max_length):
                try:
                    length = self.beam_steper.length_container
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(
                        token_mask, token_mask)

                    # the first step feeds the <bos> tokens; later steps feed the
                    # partial hypotheses kept by the beam steper
                    token_id = (self.beam_steper.get_first_step_token()
                                if i == 0 else
                                self.beam_steper.token_container.view(
                                    batch_size * beam_size, -1))
                    last_prob = self.beam_decode_step(token_id,
                                                      encoder_output,
                                                      token_mask,
                                                      self_attention_mask,
                                                      dot_attention_mask,
                                                      topk=beam_size,
                                                      return_last=True)
                    self.beam_steper.step(last_prob)
                except Exception:
                    # the original code stops decoding on any error raised inside
                    # a step (e.g. once every beam has finished)
                    break
        return self.beam_steper.batch_best_saver
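
The Masker helper used throughout these examples is not shown. As an orientation only, here is a minimal sketch of what the three calls above (get_mask, get_dot_mask, get_forward_mask) might do; the exact shapes and dtypes are assumptions, not the project's actual implementation.

import torch as t


class Masker:
    """Hypothetical sketch of the masking helpers used in these snippets."""

    @staticmethod
    def get_mask(lengths):
        # (B,) lengths -> (B, L) padding mask with 1 at valid positions
        max_len = int(lengths.max())
        positions = t.arange(max_len, device=lengths.device)
        return (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

    @staticmethod
    def get_dot_mask(query_mask, key_mask):
        # (B, Lq) x (B, Lk) -> (B, Lq, Lk) attention mask
        return query_mask.unsqueeze(2) * key_mask.unsqueeze(1)

    @staticmethod
    def get_forward_mask(dot_mask):
        # keep the lower triangle so a position cannot attend to future tokens
        size = dot_mask.size(-1)
        causal = t.tril(t.ones(size, size, device=dot_mask.device,
                               dtype=dot_mask.dtype))
        return dot_mask * causal.unsqueeze(0)
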
Code example #2

    def greedy_decode(self, encoder_output, feature_mask):
        """
        batched greedy decode
        """
        batch_size = encoder_output.size(0)
        device = encoder_output.device
        # start every sequence with <bos>; length tracks the still-growing part
        token_id = t.full((batch_size, 1), fill_value=self.bos_id,
                          dtype=t.long, device=device)
        length = t.LongTensor([1] * batch_size).to(device)
        with t.no_grad():
            for i in range(self.max_length):
                try:
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                    self_attention_mask = Masker.get_forward_mask(self_attention_mask)
                    dot_attention_mask = Masker.get_dot_mask(token_mask, feature_mask)
                    last_prob, last_token_id = self.decode_step(
                        token_id, encoder_output, token_mask, self_attention_mask,
                        dot_attention_mask, topk=1, return_last=True)
                    token_id = t.cat([token_id, last_token_id], dim=1)
                    # only grow the length of sequences that have not emitted <eos>
                    for index, new_id in enumerate(last_token_id.squeeze(1)):
                        if new_id != self.eos_id:
                            length[index] += 1
                except Exception:
                    # TODO: to be more conscious
                    break
        return token_id
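
A hedged usage sketch for the greedy decoder; the `model` instance, its `_prepare_feature` and `spec_encoder` attributes, and the call order are assumptions taken from the later examples.

# assumed inference flow: build masks, encode the padded features, decode greedily
feature, feature_mask, self_attention_mask = model._prepare_feature(feature, feature_length)
encoder_output = model.spec_encoder(feature, feature_mask, self_attention_mask)
hypotheses = model.greedy_decode(encoder_output, feature_mask)  # (B, <= max_length) token ids
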
Code example #3
    def _prepare_token(self, token, token_length):
        """
        build decoder input/output targets and the corresponding token masks
        """
        input_token, output_token, token_length = self._rebuild_target(
            token, token_length)
        token_mask = Masker.get_mask(token_length)
        token_self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
        token_self_attention_mask = Masker.get_forward_mask(
            token_self_attention_mask)
        return input_token, output_token, token_length, token_mask, token_self_attention_mask
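
`_rebuild_target` is not included in these snippets. A plausible sketch, assuming it prepends `<bos>` to form the decoder input, appends `<eos>` to form the target, and extends the lengths by one; all of this is inferred from the usage above, not taken from the original code.

def _rebuild_target(self, token, token_length):
    # hypothetical sketch: token is (B, L), zero-padded; self.bos_id / self.eos_id assumed
    batch_size = token.size(0)
    bos = token.new_full((batch_size, 1), self.bos_id)
    pad = token.new_zeros((batch_size, 1))
    input_token = t.cat([bos, token], dim=1)    # <bos> t1 t2 ... pad
    output_token = t.cat([token, pad], dim=1)   # t1 t2 ... <eos> pad
    # put <eos> right after the last real token of every sequence
    rows = t.arange(batch_size, device=token.device)
    output_token[rows, token_length] = self.eos_id
    return input_token, output_token, token_length + 1
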
Code example #4
    def _prepare_feature(self,
                         feature,
                         feature_length,
                         restrict_left_length=None,
                         restrict_right_length=None):
        """
        do spec augment and build mask
        """
        if self.enable_spec_augment:
            feature = self.spec_augment(feature, feature_length)
        feature_mask = Masker.get_mask(feature_length)
        self_attention_mask = Masker.get_dot_mask(feature_mask, feature_mask)
        self_attention_mask = Masker.get_restricted_mask(
            self_attention_mask, restrict_left_length, restrict_right_length)
        return feature, feature_mask, self_attention_mask
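
`Masker.get_restricted_mask` limits self-attention to a window around each frame. A minimal sketch of that idea as a plain function (it would live on the hypothetical Masker class sketched earlier as a staticmethod); the window semantics are an assumption based on the restrict_left_length/restrict_right_length names.

def get_restricted_mask(dot_mask, restrict_left_length=None, restrict_right_length=None):
    # hypothetical sketch: zero out keys outside [i - left, i + right] for query i
    if restrict_left_length is None and restrict_right_length is None:
        return dot_mask
    size = dot_mask.size(-1)
    idx = t.arange(size, device=dot_mask.device)
    offset = idx.unsqueeze(1) - idx.unsqueeze(0)        # offset[i, j] = i - j
    window = t.ones(size, size, device=dot_mask.device, dtype=dot_mask.dtype)
    if restrict_left_length is not None:
        window = window * (offset <= restrict_left_length).to(dot_mask.dtype)
    if restrict_right_length is not None:
        window = window * (offset >= -restrict_right_length).to(dot_mask.dtype)
    return dot_mask * window.unsqueeze(0)
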
Code example #5
    def forward(self,
                feature,
                feature_length,
                ori_token,
                ori_token_length,
                cal_ce_loss=True):
        t.cuda.empty_cache()
        # spec augment (if enabled) and the acoustic masks
        feature, feature_mask, feature_self_attention_mask = self._prepare_feature(
            feature,
            feature_length,
            restrict_left_length=self.restrict_left_length,
            restrict_right_length=self.restrict_right_length)
        # decoder input/output tokens, their masks and the language-switch targets
        input_token, output_token, token_length, token_mask, token_self_attention_mask, switch_target = self._prepare_token(
            ori_token, ori_token_length)
        # acoustic encoder
        spec_feature = self.spec_encoder(feature, feature_mask,
                                         feature_self_attention_mask)
        # linear projection of the encoder output (used for the CTC scores in recognition)
        spec_output = self.encoder_linear(spec_feature)
        # token decoder attending over the encoded features
        dot_attention_mask = Masker.get_dot_mask(token_mask, feature_mask)
        output, switch = self.token_decoder(input_token, spec_feature,
                                            token_mask,
                                            token_self_attention_mask,
                                            dot_attention_mask)
        switch_loss = self.switch_loss(switch, switch_target)
        if cal_ce_loss:
            ce_loss = self.cal_ce_loss(output, output_token, type='lbce')
        else:
            ce_loss = None
        return output, output_token, spec_output, feature_length, ori_token, ori_token_length, ce_loss, switch_loss
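
`cal_ce_loss(..., type='lbce')` presumably computes a label-smoothed cross entropy over the decoder output. A standalone sketch under that assumption; the smoothing value and padding index 0 are guesses, not values from the original code.

import torch as t
import torch.nn.functional as F


def label_smoothed_ce(logits, target, smoothing=0.1, padding_idx=0):
    # logits: (B, L, V), target: (B, L); hypothetical stand-in for cal_ce_loss(type='lbce')
    log_prob = F.log_softmax(logits, dim=-1)
    nll = -log_prob.gather(-1, target.unsqueeze(-1)).squeeze(-1)  # per-token NLL
    smooth = -log_prob.mean(dim=-1)                               # uniform smoothing term
    loss = (1.0 - smoothing) * nll + smoothing * smooth
    mask = target.ne(padding_idx).float()                         # ignore padded targets
    return (loss * mask).sum() / mask.sum().clamp(min=1.0)
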
Code example #6
    def _prepare_token(self, token, token_length):
        """
        build target and mask
        """
        input_token, output_token, token_length = self._rebuild_target(
            token, token_length)
        token_mask = Masker.get_mask(token_length)
        token_self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
        token_self_attention_mask = Masker.get_forward_mask(
            token_self_attention_mask)

        switch = t.ones_like(output_token,
                             device=token.device).long()  # eng = 1
        switch.masked_fill_(output_token.eq(0), 0)  # pad=0
        switch.masked_fill_((output_token.ge(12) & output_token.le(4211)),
                            2)  # ch = 2
        switch.masked_fill_((output_token.ge(1) & output_token.le(10)),
                            3)  # other = 3
        return input_token, output_token, token_length, token_mask, token_self_attention_mask, switch
Code example #7
    def beam_search_decode(self, encoder_output, feature_mask, beam_size, best_k=5, lp_eps=0.0):
        batch_size, feature_length, _ = encoder_output.size()
        device = encoder_output.device
        self.beam_steper = BeamSteper(
            batch_size=batch_size, beam_size=beam_size, bos_id=self.bos_id, eos_id=self.eos_id,
            vocab_size=self.vocab_size, device=device, k_best=best_k, lp_eps=lp_eps
        )

        beam_feature_mask = feature_mask.unsqueeze(1).repeat(1, beam_size, 1).view(batch_size*beam_size, -1)
        beam_encoder_output = encoder_output.unsqueeze(1).repeat(1, beam_size, 1, 1).view(batch_size*beam_size, feature_length, -1)
        with t.no_grad():
            for i in range(self.max_length):
                if i == 0:
                    token_id = self.beam_steper.get_first_step_token()
                    length = self.beam_steper.get_first_step_length()
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                    self_attention_mask = Masker.get_forward_mask(self_attention_mask)
                    dot_attention_mask = Masker.get_dot_mask(token_mask, feature_mask)
                    last_prob = self.beam_decode_step(
                        token_id, encoder_output, token_mask, self_attention_mask, dot_attention_mask)
                    if_continue = self.beam_steper.first_step(last_prob)
                    if not if_continue:
                        break
                else:
                    token_id = self.beam_steper.token_container
                    token_id = token_id.view(batch_size * beam_size, -1)
                    length = self.beam_steper.length_container
                    length = length.view(batch_size * beam_size)
                    token_mask = Masker.get_mask(length)
                    self_attention_mask = Masker.get_dot_mask(token_mask, token_mask)
                    self_attention_mask = Masker.get_forward_mask(self_attention_mask)
                    dot_attention_mask = Masker.get_dot_mask(token_mask, beam_feature_mask)
                    last_prob = self.beam_decode_step(
                        token_id, beam_encoder_output, token_mask, self_attention_mask, dot_attention_mask
                    )
                    if_continue = self.beam_steper.step(last_prob.view(batch_size, beam_size, -1))
                    if not if_continue:
                        break
        output_token = self.beam_steper.batch_best_saver.batch
        return output_token
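
A small worked example of the switch targets built in `_prepare_token` above: id 0 is padding, ids 1-10 are treated as special symbols (3), ids 12-4211 as Chinese (2), and everything else as English (1). The concrete ids below are made up for illustration.

import torch as t

output_token = t.tensor([[4500, 13, 2, 0, 0]])   # illustrative ids: eng, ch, special, pad, pad
switch = t.ones_like(output_token).long()                               # default: eng = 1
switch.masked_fill_(output_token.eq(0), 0)                              # pad = 0
switch.masked_fill_(output_token.ge(12) & output_token.le(4211), 2)     # ch = 2
switch.masked_fill_(output_token.ge(1) & output_token.le(10), 3)        # other = 3
print(switch)   # tensor([[1, 2, 3, 0, 0]])
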
Code example #8
    def recognize(self,
                  feature,
                  feature_length,
                  beam=5,
                  penalty=0,
                  ctc_weight=0.3,
                  maxlenratio=0,
                  minlenratio=0,
                  char_list=None,
                  rnnlm=None,
                  lm_weight=0.1,
                  nbest=1):
        """Recognize input speech.

        :param ndnarray x: input acoustic feature (B, T, D) ,length (B)
        :param Namespace recog_args: argment Namespace contraining options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        assert feature.size(0) == 1
        enc_output, feature_mask = self._encode(feature, feature_length)
        if ctc_weight > 0.0:
            lpz = t.nn.functional.log_softmax(self.encoder_linear(enc_output),
                                              -1).squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)
        # print('input lengths: ' + str(h.size(0)))
        # prepare the <sos>/<bos> start symbol
        y = self.vocab.bos_id
        vy = h.new_zeros(1).long()

        if maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(maxlenratio * h.size(0)))
        minlen = int(minlenratio * h.size(0))
        # print('max output length: ' + str(maxlen))
        # print('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(),
                                              self.vocab.blank_id,
                                              self.vocab.eos_id, np)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        import six
        for i in six.moves.range(maxlen):
            # print('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = Masker.get_mask(t.LongTensor([i + 1]))
                ys_self_attention_mask = Masker.get_dot_mask(ys_mask, ys_mask)
                ys_self_attention_mask = Masker.get_forward_mask(
                    ys_self_attention_mask)
                dot_attention_mask = Masker.get_dot_mask(ys_mask, feature_mask)
                ys = t.tensor(hyp['yseq']).unsqueeze(0)
                local_att_scores = self.token_decoder.forward_one_step(
                    ys, enc_output, ys_mask, ys_self_attention_mask,
                    dot_attention_mask)[0]
                local_att_scores = t.nn.functional.log_softmax(
                    local_att_scores, -1)
                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(
                        hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    local_best_scores, local_best_ids = t.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                        + ctc_weight * t.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += lm_weight * local_lm_scores[:,
                                                                    local_best_ids[
                                                                        0]]
                    local_best_scores, joint_best_ids = t.topk(local_scores,
                                                               beam,
                                                               dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = t.topk(local_scores,
                                                               beam,
                                                               dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(
                        local_best_scores[0, j])
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0,
                                                                           j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[
                            0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[
                            0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            # print('number of pruned hypotheses: ' + str(len(hyps)))
            # if char_list is not None:
            #     print(
            #         'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add <eos> in the final iteration so that at least some hypotheses end
            if i == maxlen - 1:
                # print('adding <eos> in the last position in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.vocab.eos_id)

            # add ended hypotheses to the final list and remove them from the
            # current hypotheses (this can be a problem: number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.vocab.eos_id:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            # if end_detect(ended_hyps, i) and maxlenratio == 0.0:
            #     print('end detected at %d', i)
            #     # break

            hyps = remained_hyps
            # if len(hyps) > 0:
            #     # pass
            #     print('remained hypotheses: ' + str(len(hyps)))
            # else:
            #     print('no hypothesis. Finish decoding.')
            #     # break

            # if char_list is not None:
            #     for hyp in hyps:
            #         print(
            #             'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))
            #
            # print('number of ended hypotheses: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'],
                            reverse=True)[:min(len(ended_hyps), nbest)]

        # check the number of hypotheses
        if len(nbest_hyps) == 0:
            # there are no N-best results; the caller could retry with a smaller
            # minlenratio (should copy because the Namespace would be overwritten globally)
            return None
            # recog_args = Namespace(**vars(recog_args))
            # recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            # return self.recognize(x, recog_args, char_list, rnnlm)

        # print('total log probability: ' + str(nbest_hyps[0]['score']))
        # print('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
        return nbest_hyps
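
A hedged usage sketch for `recognize`; only the signature and the hypothesis dictionary keys come from the snippet above, while `model`, `char_list` and the id-to-character mapping are assumptions.

# feature: (1, T, D) float tensor, feature_length: (1,) long tensor
nbest_hyps = model.recognize(feature, feature_length, beam=5, ctc_weight=0.3, nbest=3)
if nbest_hyps is not None:
    best = nbest_hyps[0]
    # drop the leading <bos> and trailing <eos> ids, then map ids to characters
    text = ''.join(char_list[int(idx)] for idx in best['yseq'][1:-1])
    print(best['score'], text)
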
Code example #9
    def _prepare_feature(self, feature, feature_length):
        # same as the earlier _prepare_feature but without the restricted attention window
        if self.enable_spec_augment:
            feature = self.spec_augment(feature, feature_length)
        feature_mask = Masker.get_mask(feature_length)
        self_attention_mask = Masker.get_dot_mask(feature_mask, feature_mask)
        return feature, feature_mask, self_attention_mask
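
`self.spec_augment` is not shown in any of these snippets. Below is a minimal SpecAugment-style sketch of what it presumably does (one frequency mask and one time mask per utterance); the mask widths and the copy-then-mask behaviour are assumptions.

import torch as t


def spec_augment(feature, feature_length, freq_mask_width=27, time_mask_width=40):
    # hypothetical sketch: feature is (B, T, D); zero one frequency band and one
    # time span per utterance, keeping the time mask inside the unpadded frames
    batch_size, _, num_freq = feature.size()
    feature = feature.clone()
    for b in range(batch_size):
        # frequency mask
        f = int(t.randint(0, freq_mask_width + 1, (1,)))
        f0 = int(t.randint(0, max(1, num_freq - f), (1,)))
        feature[b, :, f0:f0 + f] = 0.0
        # time mask
        valid = int(feature_length[b])
        w = int(t.randint(0, min(time_mask_width, valid) + 1, (1,)))
        t0 = int(t.randint(0, max(1, valid - w), (1,)))
        feature[b, t0:t0 + w, :] = 0.0
    return feature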