Example #1
    def compute_hyps(self,
                     current_hyps,
                     current_frame,
                     total_frame,
                     enc_output,
                     hat_att,
                     enc_mask,
                     chunk=True):
        """Run one batched decoder step per hypothesis length and memoize
        the resulting attention scores and caches in hat_att."""
        for length, hyps_t in current_hyps.items():
            # Causal mask shared by every hypothesis of this length.
            ys_mask = subsequent_mask(length).unsqueeze(0).cuda()
            ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1)

            # Batch the token-id sequences and tile the encoder output so
            # the decoder sees one row per hypothesis.
            l_id = [hyp_t['yseq'] for hyp_t in hyps_t]
            ys4use = torch.tensor(l_id).cuda()
            enc_output4use = enc_output.repeat(len(hyps_t), 1, 1)
            if hyps_t[0]["cache"] is None:
                cache4use = None
            else:
                # Stack each decoder layer's cache across hypotheses so
                # forward_one_step can consume the whole batch at once.
                cache4use = []
                for decode_num in range(len(hyps_t[0]["cache"])):
                    current_cache = []
                    for hyp_t in hyps_t:
                        current_cache.append(
                            hyp_t["cache"][decode_num].squeeze(0))
                    cache4use.append(torch.stack(current_cache))

            partial_mask4use = []
            for hyp_t in hyps_t:
                # Each already-emitted token may attend up to the encoder
                # frame it fired at; the token being scored now may attend
                # up to the current frame.
                align = [0] * length
                align[:length - 1] = hyp_t['last_time'][:]
                align[-1] = current_frame
                align_tensor = torch.tensor(align).unsqueeze(0)
                if chunk:
                    # Chunk-based streaming: index the precomputed chunk mask.
                    partial_mask = enc_mask[0][align_tensor]
                else:
                    # Trigger-based streaming: build a window around each
                    # trigger point from the configured left/right context.
                    right_window = self.right_window
                    partial_mask = trigger_mask(1, total_frame, align_tensor,
                                                self.left_window, right_window)
                partial_mask4use.append(partial_mask)

            partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1)
            # One batched decoder step for every hypothesis of this length.
            local_att_scores_b, new_cache_b = self.decoder.forward_one_step(
                ys4use, ys_mask4use, enc_output4use, partial_mask4use,
                cache4use)
            for idx, hyp_t in enumerate(hyps_t):
                # Split the batched outputs back out per hypothesis and
                # memoize them under the hypothesis' string key.
                hyp_t['tmp_cache'] = [
                    new_cache_b[decode_num][idx].unsqueeze(0)
                    for decode_num in range(len(new_cache_b))
                ]
                hyp_t['tmp_att'] = local_att_scores_b[idx].unsqueeze(0)
                hat_att[hyp_t['seq']] = {
                    'cache': hyp_t['tmp_cache'],
                    'att_scores': hyp_t['tmp_att'],
                }
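
Both methods assume hypotheses arrive pre-grouped by target length, so every batch shares one mask and cache shape. Below is a minimal sketch, not from the source, of that grouping and of the per-hypothesis record the code reads; the key names ('yseq', 'cache', 'last_time', 'seq') mirror the fields used above, while group_hyps_by_length, sos_id, and the concrete values are hypothetical illustrations:

    from collections import defaultdict

    def group_hyps_by_length(hyps):
        """Hypothetical helper: group hypothesis dicts by len(hyp['yseq'])
        so each batch handed to the decoder shares one mask/cache shape."""
        grouped = defaultdict(list)
        for hyp in hyps:
            grouped[len(hyp['yseq'])].append(hyp)
        return dict(grouped)

    sos_id = 0  # assumed start-of-sentence id, for illustration only
    hyp = {
        'yseq': [sos_id, 12, 7],  # token ids decoded so far
        'cache': None,            # per-layer decoder cache; None before the first step
        'last_time': [3, 8],      # encoder frame each emitted token fired at
        'seq': '0 12 7',          # space-joined key used by the hat_att memo table
    }
    current_hyps = group_hyps_by_length([hyp])  # {3: [hyp]}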
Example #2
    def compute_hyps_ctc(self,
                         hyps_ctc_cluster,
                         total_frame,
                         enc_output,
                         hat_att,
                         enc_mask,
                         chunk=True):
        """Batched decoder step for CTC hypothesis clusters: score each
        hypothesis' prefix (all tokens but the last) and memoize the
        results in hat_att under the prefix key."""
        for length, hyps_t in hyps_ctc_cluster.items():
            # Only the prefix is scored, hence masks of size length - 1.
            ys_mask = subsequent_mask(length - 1).unsqueeze(0).cuda()
            ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1)
            l_id = [hyp_t['yseq'][:-1] for hyp_t in hyps_t]
            ys4use = torch.tensor(l_id).cuda()
            enc_output4use = enc_output.repeat(len(hyps_t), 1, 1)
            if "precache" not in hyps_t[0] or hyps_t[0]["precache"] is None:
                cache4use = None
            else:
                cache4use = []
                for decode_num in range(len(hyps_t[0]["precache"])):
                    current_cache = []
                    for hyp_t in hyps_t:
                        # print(length, hyp_t["yseq"], hyp_t["cache"][0].shape,
                        #       hyp_t["cache"][2].shape, hyp_t["cache"][4].shape)
                        current_cache.append(
                            hyp_t["precache"][decode_num].squeeze(0))
                    current_cache = torch.stack(current_cache)
                    cache4use.append(current_cache)
            partial_mask4use = []
            for hyp_t in hyps_t:
                # The prefix tokens have all fired already, so last_time
                # gives the complete alignment for the mask.
                align = hyp_t['last_time']
                align_tensor = torch.tensor(align).unsqueeze(0)
                if chunk:
                    partial_mask = enc_mask[0][align_tensor]
                else:
                    right_window = self.right_window
                    partial_mask = trigger_mask(1, total_frame, align_tensor,
                                                self.left_window, right_window)
                partial_mask4use.append(partial_mask)

            partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1)

            local_att_scores_b, new_cache_b = self.decoder.forward_one_step(
                ys4use, ys_mask4use, enc_output4use, partial_mask4use,
                cache4use)
            for idx, hyp_t in enumerate(hyps_t):
                hyp_t['tmp_cur_new_cache'] = [
                    new_cache_b[decode_num][idx].unsqueeze(0)
                    for decode_num in range(len(new_cache_b))
                ]
                hyp_t['tmp_cur_att_scores'] = \
                    local_att_scores_b[idx].unsqueeze(0)
                # Memoize under the prefix key: the sequence minus its
                # last token.
                l_minus = ' '.join(hyp_t['seq'].split()[:-1])
                hat_att[l_minus] = {
                    'att_scores': hyp_t['tmp_cur_att_scores'],
                    'cache': hyp_t['tmp_cur_new_cache'],
                }
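
Both methods end by writing into hat_att, which serves as a memo table keyed on the space-joined token prefix. Here is a minimal sketch, assuming a surrounding beam-search loop that is not shown in the source, of how a caller might consult the table before paying for another forward_one_step; lookup_hat_att is a hypothetical helper:

    def lookup_hat_att(hat_att, hyp):
        """Hypothetical helper: return the memoized (att_scores, cache) pair
        for the prefix of hyp, or None on a miss, mirroring the keys
        written by compute_hyps_ctc above."""
        prefix_key = ' '.join(hyp['seq'].split()[:-1])
        entry = hat_att.get(prefix_key)
        if entry is None:
            return None
        return entry['att_scores'], entry['cache']

On a hit, the caller can reuse the cached attention scores and decoder cache instead of re-running the decoder step for that prefix.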