Example #1
    def decode(self, h, mask):  # Viterbi decoding
        # initialize backpointers and viterbi variables in log space
        batch_size = h.shape[0]
        bptr = maybe_cuda(torch.LongTensor())
        score = maybe_cuda(torch.full((batch_size, self.num_tags), -10000.))
        score[:, self.sos_idx] = 0.

        for t in range(h.size(1)):  # recursion through the sequence
            mask_t = mask[:, t].unsqueeze(1)
            score_t = score.unsqueeze(1) + self.trans  # [B, 1, C] -> [B, C, C]
            score_t, bptr_t = score_t.max(2)  # best previous scores and tags
            score_t += h[:, t]  # plus emission scores
            bptr = torch.cat((bptr, bptr_t.unsqueeze(1)), 1)
            score = score_t * mask_t + score * (1 - mask_t)
        score += self.trans[self.eos_idx]
        best_score, best_tag = torch.max(score, 1)

        # back-tracking
        bptr = bptr.tolist()
        best_path = [[i] for i in best_tag.tolist()]
        for b in range(batch_size):
            x = best_tag[b]  # best tag
            y = int(mask[b].sum().item())
            for bptr_t in reversed(bptr[b][:y]):
                x = bptr_t[x]
                best_path[b].append(x)
            best_path[b].pop()
            best_path[b].reverse()

        return best_path
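
For orientation, the same max-and-backtrack recursion written as a minimal single-sequence function, without batching or masking (a hypothetical standalone sketch, not part of the module above; it keeps the convention visible in the code that trans[i, j] scores a transition from tag j to tag i):

import torch

def viterbi_single(emissions, transitions, sos_idx, eos_idx):
    # emissions: [T, C] log emission scores for one sequence
    # transitions[i, j]: log score of moving from tag j to tag i
    T, C = emissions.shape
    score = torch.full((C,), -10000.)
    score[sos_idx] = 0.
    backpointers = []
    for t in range(T):
        # best previous tag for every current tag, then add the emission score
        score_t, bptr_t = (score.unsqueeze(0) + transitions).max(1)
        score = score_t + emissions[t]
        backpointers.append(bptr_t)
    score = score + transitions[eos_idx]
    best_tag = int(score.argmax())
    path = [best_tag]
    for bptr_t in reversed(backpointers):
        best_tag = int(bptr_t[best_tag])
        path.append(best_tag)
    path.pop()      # the last backpointer points at SOS; drop it
    path.reverse()
    return path
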
Example #2
    def forward(self, encoder_outputs, forward_step_fn, initial_decoder_hidden, max_length):
        batch_size = encoder_outputs.size(0)
        decoder_input = maybe_cuda(torch.full((batch_size, 1), self.sos_token_id, dtype=torch.int64))
        decoder_outputs = []
        sequence_symbols = []
        lengths = np.array([max_length] * batch_size)
        attentions = []
        for step in range(max_length):
            decoder_output, step_attention = forward_step_fn(
                input_var=decoder_input,
                encoder_hidden=initial_decoder_hidden,
                encoder_outputs=encoder_outputs)
            step_output = decoder_output.squeeze(1)
            decoder_outputs.append(step_output)
            if step_attention is not None:
                attentions.append(step_attention)
            symbols = step_output.topk(1)[1]
            decoder_input = symbols
            eos_batches = symbols.data.eq(self.eos_token_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)
            sequence_symbols.append(symbols)

        #print("greedy")
        #for i in range(batch_size):
        #    for j in range(max_length):
        #        print(sequence_symbols[j][i].item(), end=' ')
        #    print()
        return sequence_symbols, lengths, decoder_outputs, attentions
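
The loop above only fixes the contract of forward_step_fn: it is called with the keyword arguments (input_var, encoder_hidden, encoder_outputs) and must return a per-step output of shape [B, 1, V] plus an optional attention tensor. A toy stand-in satisfying that contract (dummy_forward_step and VOCAB_SIZE are hypothetical names, not from the original code):

import torch

VOCAB_SIZE = 8

def dummy_forward_step(input_var, encoder_hidden, encoder_outputs):
    # ignore the encoder and return uniform log-probabilities over a toy vocabulary
    batch_size = input_var.size(0)
    logits = torch.zeros(batch_size, 1, VOCAB_SIZE)   # [B, 1, V]
    return torch.log_softmax(logits, dim=-1), None    # (decoder_output, step_attention)
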
Example #3
    def score(self, h, y, mask):  # calculate the score of a given sequence
        batch_size = h.shape[0]
        score = maybe_cuda(torch.zeros(batch_size))
        h = h.unsqueeze(3)
        trans = self.trans.unsqueeze(2)
        for t in range(h.size(1)):  # recursion through the sequence
            mask_t = mask[:, t]
            emit_t = torch.cat([h_b[t, y_b[t + 1]] for h_b, y_b in zip(h, y)])
            trans_t = torch.cat([trans[y_b[t + 1], y_b[t]] for y_b in y])
            score += (emit_t + trans_t) * mask_t
        last_tag = y.gather(1, mask.sum(1).long().unsqueeze(1)).squeeze(1)
        score += self.trans[self.eos_idx, last_tag]
        return score
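
The torch.cat/zip expressions above can be hard to read; an equivalent per-step lookup written with explicit advanced indexing (gold_score_step is a hypothetical helper operating on the raw [B, T, C] emissions, assuming y[:, 0] is the SOS tag so that y[:, t + 1] is the gold tag emitted at step t):

import torch

def gold_score_step(h, y, trans, t):
    # emission score of the gold tag at step t, plus the transition score
    # from the previous gold tag (trans[i, j] scores a j -> i transition)
    batch_idx = torch.arange(h.size(0))
    emit_t = h[batch_idx, t, y[:, t + 1]]
    trans_t = trans[y[:, t + 1], y[:, t]]
    return emit_t + trans_t
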
Example #4
    def forward(self, h, mask):  # forward algorithm
        # initialize forward variables in log space
        batch_size = h.shape[0]
        score = maybe_cuda(torch.full((batch_size, self.num_tags), -10000.))  # [B, C]
        score[:, self.sos_idx] = 0.
        trans = self.trans.unsqueeze(0)  # [1, C, C]
        for t in range(h.size(1)):  # recursion through the sequence
            mask_t = mask[:, t].unsqueeze(1)
            emit_t = h[:, t].unsqueeze(2)  # [B, C, 1]
            score_t = score.unsqueeze(1) + emit_t + trans  # [B, 1, C] -> [B, C, C]
            score_t = log_sum_exp(score_t)  # [B, C, C] -> [B, C]
            score = score_t * mask_t + score * (1 - mask_t)
        score = log_sum_exp(score + self.trans[self.eos_idx])
        return score  # partition function
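
The forward algorithm relies on a log_sum_exp helper that is not shown in this listing; a minimal numerically stable sketch over the last dimension, matching the shapes used above ([B, C, C] -> [B, C] inside the loop, [B, C] -> [B] at the end):

import torch

def log_sum_exp(x):
    # subtract the per-row maximum before exponentiating to avoid overflow
    m = x.max(-1)[0]
    return m + torch.log(torch.sum(torch.exp(x - m.unsqueeze(-1)), -1))

Together with score() from the previous example, the usual CRF training objective would be the partition function minus the gold-path score, roughly (self.forward(h, mask) - self.score(h, y, mask)).mean(); this is the standard formulation, stated here only because the surrounding training loop is not part of the listing.
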
Example #5
    def recognize_beam_batch(
            self, states,
            configs,
            rnn_language_model=None,
            normalize_score=True, strm_idx=0):
        """
        :param encoder_outputs:
        :type states: DecodingStates
        :type configs: BeamSearchConfigs
        :param rnn_language_model:
        :param normalize_score:
        :param strm_idx:
        :return:
        """
        att_idx = min(strm_idx, len(self._attention) - 1)
        encoder_outputs = mask_by_length(encoder_outputs, encoder_output_lens, 0.0)

        # search params
        batch_size = len(encoder_output_lens)

        n_bb = batch_size * configs.beam_size
        n_bo = configs.beam_size * self._output_size
        n_bbo = n_bb * self._output_size
        pad_b = maybe_cuda(torch.LongTensor([i * configs.beam_size for i in range(batch_size)]).view(-1, 1))
        pad_bo = maybe_cuda(torch.LongTensor([i * n_bo for i in range(batch_size)]).view(-1, 1))
        pad_o = maybe_cuda(torch.LongTensor([i * self._output_size for i in range(n_bb)]).view(-1, 1))

        max_encoder_output_len = int(max(encoder_output_lens))
        max_len = max_encoder_output_len if configs.max_len_ratio == 0 \
            else max(1, int(configs.max_len_ratio * max_encoder_output_len))
        min_len = int(configs.min_len_ratio * max_encoder_output_len)

        # initialization
        c_prev = [maybe_cuda(torch.zeros(n_bb, self._hidden_size)) for _ in range(self._num_layers)]
        z_prev = [maybe_cuda(torch.zeros(n_bb, self._hidden_size)) for _ in range(self._num_layers)]
        c_list = [maybe_cuda(torch.zeros(n_bb, self._hidden_size)) for _ in range(self._num_layers)]
        z_list = [maybe_cuda(torch.zeros(n_bb, self._hidden_size)) for _ in range(self._num_layers)]
        vscores = maybe_cuda(torch.zeros(batch_size, configs.beam_size))

        a_prev = None
        rnn_language_model_prev = None

        self._attention[att_idx].reset()  # reset pre-computation of h

        y_seq = [[self._sos_id] for _ in range(n_bb)]
        stop_search = [False for _ in range(batch_size)]
        ended_hypotheses = [[] for _ in range(batch_size)]

        exp_encoder_output_lens = encoder_output_lens.repeat(configs.beam_size).view(configs.beam_size, batch_size).transpose(0, 1).contiguous()
        exp_encoder_output_lens = exp_encoder_output_lens.view(-1).tolist()
        exp_h = encoder_outputs.unsqueeze(1).repeat(1, configs.beam_size, 1, 1).contiguous()
        exp_h = exp_h.view(n_bb, encoder_outputs.size()[1], encoder_outputs.size()[2])

        for i in range(max_len):
            vy = maybe_cuda(torch.LongTensor(get_last_y_seq(y_seq)))
            ey = self.dropout_emb(self.embed(vy))
            att_c, att_w = self._attention[att_idx](exp_h, exp_encoder_output_lens, self._dropout_decoder[0](z_prev[0]), a_prev)
            ey = torch.cat((ey, att_c), dim=1)

            # attention decoder
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
            local_scores = F.log_softmax(self._output(self._dropout_decoder[-1](z_list[-1])), dim=1)

            # rnn_language_model
            if rnn_language_model:
                rnn_language_model_state, local_lm_scores = rnn_language_model.buff_predict(rnn_language_model_prev, vy, n_bb)
                local_scores = local_scores + self._language_model_weight * local_lm_scores
            local_scores = local_scores.view(batch_size, configs.beam_size, self._output_size)

            if i == 0:
                local_scores[:, 1:, :] = self._logzero
            local_best_scores, local_best_odims = torch.topk(
                local_scores.view(batch_size, configs.beam_size, self._output_size),
                configs.beam_size, 2)

            # local pruning (via xp)
            local_scores = np.full((n_bbo,), self._logzero)
            _best_odims = local_best_odims.view(n_bb, configs.beam_size) + pad_o
            _best_odims = _best_odims.view(-1).cpu().numpy()
            _best_score = local_best_scores.view(-1).cpu().detach().numpy()
            local_scores[_best_odims] = _best_score
            local_scores = maybe_cuda(torch.from_numpy(local_scores).float()).view(batch_size, configs.beam_size, self._output_size)

            # (or, equivalently, via explicit indexing)
            # local_scores = maybe_cuda(torch.full((batch_size, configs.beam_size, self._output_size), self._logzero))
            # _best_odims = local_best_odims
            # _best_score = local_best_scores
            # for si in range(batch_size):
            #     for bj in range(configs.beam_size):
            #         for bk in range(configs.beam_size):
            #             local_scores[si, bj, _best_odims[si, bj, bk]] = _best_score[si, bj, bk]

            eos_vscores = local_scores[:, :, self._eos_id] + vscores
            vscores = vscores.view(batch_size, configs.beam_size, 1).repeat(1, 1, self._output_size)
            vscores[:, :, self._eos_id] = self._logzero
            vscores = (vscores + local_scores).view(batch_size, n_bo)

            # global pruning
            accum_best_scores, accum_best_ids = torch.topk(vscores, configs.beam_size, 1)
            accum_odim_ids = torch.fmod(accum_best_ids, self._output_size).view(-1).data.cpu().tolist()
            accum_padded_beam_ids = ((accum_best_ids // self._output_size) + pad_b).view(-1).data.cpu().tolist()

            y_prev = y_seq[:][:]
            y_seq = index_select_list(y_seq, accum_padded_beam_ids)
            y_seq = append_ids(y_seq, accum_odim_ids)
            vscores = accum_best_scores
            vidx = maybe_cuda(torch.LongTensor(accum_padded_beam_ids))

            if isinstance(att_w, torch.Tensor):
                a_prev = torch.index_select(att_w.view(n_bb, *att_w.shape[1:]), 0, vidx)
            elif isinstance(att_w, list):  # multi-head attention
                a_prev = [torch.index_select(att_w_one.view(n_bb, -1), 0, vidx) for att_w_one in att_w]
            else:
                # handle the case of location_recurrent when return is a tuple
                a_prev_ = torch.index_select(att_w[0].view(n_bb, -1), 0, vidx)
                h_prev_ = torch.index_select(att_w[1][0].view(n_bb, -1), 0, vidx)
                c_prev_ = torch.index_select(att_w[1][1].view(n_bb, -1), 0, vidx)
                a_prev = (a_prev_, (h_prev_, c_prev_))
            z_prev = [torch.index_select(z_list[li].view(n_bb, -1), 0, vidx) for li in range(self._num_layers)]
            c_prev = [torch.index_select(c_list[li].view(n_bb, -1), 0, vidx) for li in range(self._num_layers)]

            if rnn_language_model:
                rnn_language_model_prev = index_select_lm_state(rnn_language_model_state, 0, vidx)

            # pick ended hypotheses
            if i > min_len:
                k = 0
                penalty_i = (i + 1) * configs.penalty
                thr = accum_best_scores[:, -1]
                for samp_i in range(batch_size):
                    if stop_search[samp_i]:
                        k = k + configs.beam_size
                        continue
                    for beam_j in range(configs.beam_size):
                        if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                            yk = y_prev[k][:]
                            yk.append(self._eos_id)
                            if len(yk) < 1000: # encoder_output_lens[samp_i]:
                                _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                                if normalize_score:
                                    _vscore = _vscore / len(yk)
                                _score = _vscore.data.cpu().numpy()
                                ended_hypotheses[samp_i].append({'y_seq': yk, 'vscore': _vscore, 'score': _score})
                        k = k + 1

            # end detection
            stop_search = [stop_search[samp_i] or end_detect(ended_hypotheses[samp_i], i)
                           for samp_i in range(batch_size)]
            stop_search_summary = list(set(stop_search))
            if len(stop_search_summary) == 1 and stop_search_summary[0]:
                break

            torch.cuda.empty_cache()

        dummy_hypotheses = [{'y_seq': [self._sos_id, self._eos_id], 'score': np.array([-float('inf')])}]
        ended_hypotheses = [
            ended_hypotheses[samp_i] if len(ended_hypotheses[samp_i]) != 0
            else dummy_hypotheses for samp_i in range(batch_size)]
        n_best_hypotheses = [
            sorted(ended_hypotheses[samp_i], key=lambda x: x['score'], reverse=True)
            [:min(len(ended_hypotheses[samp_i]), configs.n_best)] for samp_i in range(batch_size)]

        return n_best_hypotheses
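
The search also leans on a few small list utilities (get_last_y_seq, index_select_list, append_ids) that are not shown; minimal sketches of what they presumably do, inferred from the call sites rather than taken from the original code:

def get_last_y_seq(y_seq):
    # last emitted token id of every active hypothesis
    return [seq[-1] for seq in y_seq]

def index_select_list(y_seq, ids):
    # copy the surviving hypotheses (addressed by flat beam index) into their new slots
    return [y_seq[i][:] for i in ids]

def append_ids(y_seq, ids):
    # extend every hypothesis with its newly chosen token id
    return [seq + [token_id] for seq, token_id in zip(y_seq, ids)]
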
Example #6
    def forward(self, states, strm_idx=0, use_teacher_forcing=True):
        """Decoder forward

        :type states: DecodingStates
        :param torch.Tensor encoder_outputs: batch of padded hidden state sequences (B, Tmax, D)
        :param torch.Tensor encoder_output_lens: batch of lengths of hidden state sequences (B)
        :param torch.Tensor decoder_inputs: batch of padded character id sequence tensor (B, Lmax)
        :param int strm_idx: stream index indicates the index of decoding stream.
        :return: decoder outputs
        :rtype: torch.Tensor
        :return: attentions
        :rtype: list[torch.Tensor]
        :return: sequences
        :rtype: torch.Tensor
        :return: lengths
        :rtype: np.ndarray
        """
        # TODO(kan-bayashi): need to make more smart way
        batch_size = states.encoder_outputs.size(0)
        max_length = states.decoder_inputs.size(1) if states.decoder_inputs is not None else self._max_length
        # ys = [y[y != self.ignore_id] for y in decoder_inputs]  # parse padded ys
        # attention index for the attention module
        # in SPA (speaker parallel attention), att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self._attention) - 1)

        # initialization
        c_list = [self.zero_state(states.encoder_outputs)]
        z_list = [self.zero_state(states.encoder_outputs)]
        for _ in range(1, self._num_layers):
            c_list.append(self.zero_state(states.encoder_outputs))
            z_list.append(self.zero_state(states.encoder_outputs))
        att_w = None
        z_all = []
        self._attention[att_idx].reset()  # reset pre-computation of h

        if use_teacher_forcing:
            decoder_embedded_inputs = self.dropout_emb(self.embed(states.decoder_inputs[:, :-1]))  # utt x olen x zdim

            # loop for an output sequence
            for i in range(states.decoder_inputs.size(1) - 1):
                att_c, att_w = self._attention[att_idx](
                    states.encoder_outputs,
                    states.encoder_output_lens,
                    self._dropout_decoder[0](z_list[0]),
                    att_w)
                ey = torch.cat((decoder_embedded_inputs[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
                z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
                z_all.append(self._dropout_decoder[-1](z_list[-1]))
        else:
            for step in range(max_length):
                att_c, att_w = self._attention[att_idx](
                    states.encoder_outputs,
                    states.encoder_output_lens,
                    self._dropout_decoder[0](z_list[0]),
                    att_w)
                if step == 0:
                    embedded_sos = self.dropout_emb(self.embed(
                        maybe_cuda(torch.full((batch_size, ), self._sos_id, dtype=torch.int64))))
                    ey = torch.cat((embedded_sos, att_c), dim=1)
                else:
                    z_out = self._output(z_all[-1])
                    _, z_out = torch.max(z_out.detach(), dim=1)
                    z_out = self.dropout_emb(self.embed(maybe_cuda(z_out)))
                    ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)

                z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
                z_all.append(self._dropout_decoder[-1](z_list[-1]))

        z_all = torch.stack(z_all, dim=1).view(batch_size, -1, self._hidden_size)
        decoder_outputs = self._output(z_all)
        _, sequences = decoder_outputs.max(-1)

        lengths = np.array([max_length] * batch_size)
        for step in range(max_length - 1):
            eos_batches = sequences[:, step].eq(self._eos_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = step

        # acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
        # logger.info('att loss:' + ''.join(str(self.loss.item()).split('\n')))
        # attentions = []
        # att_ws = self.calculate_all_attentions(encoder_outputs, encoder_output_lens, decoder_inputs, strm_idx)
        # print(att_ws[0].shape)
        # attentions.append(att_c)
        # print(len(attentions))
        return decoder_outputs, None, sequences, lengths
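
The trailing length computation can also be read as a standalone routine; a sketch with a hypothetical name, mirroring the loop above (first EOS position per batch item, or max_length if EOS never appears):

import numpy as np

def sequence_lengths(sequences, eos_id, max_length):
    # sequences: [B, L] tensor of predicted token ids
    lengths = np.full(sequences.size(0), max_length)
    for step in range(max_length - 1):
        eos_batches = sequences[:, step].eq(eos_id).cpu().numpy()
        update_idx = (lengths > step) & eos_batches
        lengths[update_idx] = step
    return lengths
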