Example #1
    def beam_sample(self, batch, use_cuda, beam_size=1):
        src, src_len, src_mask = batch.src, batch.src_len, batch.src_mask
        # if self.use_title:
        #     src, src_len, src_mask = batch.title, batch.title_len, batch.title_mask
        # else:
        #     src, src_len, src_mask = batch.ori_content, batch.ori_content_len, batch.ori_content_mask
        if use_cuda:
            src, src_len, src_mask = src.cuda(), src_len.cuda(), src_mask.cuda()
        # beam_size = self.config.beam_size
        batch_size = src.size(0)

        # (1) Run the encoder on the src.
        contexts, encState = self.encoder(src, src_len)

        #  (1b) Initialize for the decoder.
        def rvar(a):
            return a.repeat(1, beam_size, 1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        # Repeat everything beam_size times.
        # (batch, seq, nh) -> (beam*batch, seq, nh)
        contexts = contexts.repeat(beam_size, 1, 1)
        # (batch, seq) -> (beam*batch, seq)
        src_mask = src_mask.repeat(beam_size, 1)
        assert contexts.size(0) == src_mask.size(0), (contexts.size(),
                                                      src_mask.size())
        assert contexts.size(1) == src_mask.size(1), (contexts.size(),
                                                      src_mask.size())
        decState = (rvar(encState[0]), rvar(encState[1]))  # (layer, beam*batch, nh)
        # decState.repeat_beam_size_times(beam_size)
        beam = [
            models.Beam(beam_size, n_best=1, cuda=use_cuda)
            for _ in range(batch_size)
        ]

        # (2) run the decoder to generate sentences, using beam search.

        for i in range(self.config.max_tgt_len):

            if all((b.done() for b in beam)):
                break

            # Construct beam*batch next words.
            # Get all the pending current beam words and arrange for forward.
            # beam is batch-sized, so stack on dimension 1, not 0
            inp = torch.stack([b.getCurrentState() for b in beam],
                              1).contiguous().view(-1)
            if use_cuda:
                inp = inp.cuda()

            # Run one step.
            output, decState, attn = self.decoder.sample_one(
                inp, decState, contexts)  # NOTE: src_mask is prepared above but not passed here
            # decOut: beam x rnn_size

            # (b) Compute a vector of batch*beam word scores.

            # log_softmax is the numerically stabler equivalent:
            # output = unbottle(self.log_softmax(output, -1))
            output = F.softmax(output, -1)
            output = unbottle(torch.log(output))

            attn = unbottle(attn)
            # beam x tgt_vocab

            # (c) Advance each beam.
            # update state
            for j, b in enumerate(beam):  # one Beam per batch element
                b.advance(output.data[:, j],
                          attn.data[:, j])  # output is beam-first
                b.beam_update(decState, j)

        # (3) Package everything up.
        allHyps, allScores, allAttn = [], [], []

        for j in range(batch_size):
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                attn.append(att.max(1)[1])
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])

        # print(allHyps)
        # print(allAttn)
        return allHyps, allAttn
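
A note on the tiling convention used above: `repeat(beam_size, 1, 1)` makes the beam index the slow (leading) dimension, which is exactly why `unbottle` can recover it with `view(beam_size, batch_size, -1)` and why `output.data[:, j]` gathers all beams of batch element j. A minimal standalone sketch (not from this repo) verifying the layout:

    import torch

    batch_size, beam_size, nh = 2, 3, 4
    contexts = torch.arange(batch_size * nh, dtype=torch.float).view(batch_size, 1, nh)

    # repeat along dim 0 tiles whole copies: row k*batch_size + b holds
    # batch element b of beam k
    tiled = contexts.repeat(beam_size, 1, 1)            # (beam*batch, 1, nh)
    unbottled = tiled.view(beam_size, batch_size, -1)   # (beam, batch, nh)

    for j in range(batch_size):
        # every beam's row for batch element j equals the original row j
        assert torch.equal(unbottled[:, j], contexts[j].expand(beam_size, nh))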
Example #2
    def beam_sample(self, src, src_len, beam_size=1):

        # (1) Run the encoder on the src.

        lengths, indices = torch.sort(src_len, dim=0, descending=True)
        _, ind = torch.sort(indices)
        src = torch.index_select(src, dim=0, index=indices)
        src = src.t()
        batch_size = src.size(1)
        contexts, encState, embeds = self.encoder(src, lengths.data.tolist())

        #  (1b) Initialize for the decoder.
        def var(a):
            # legacy PyTorch (<0.4): volatile Variables disable autograd at inference
            return Variable(a, volatile=True)

        def rvar(a):
            return var(a.repeat(1, beam_size, 1))

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        # Repeat everything beam_size times.
        contexts = rvar(contexts.data)
        embeds = rvar(embeds.data)

        if self.config.cell == 'lstm':
            decState = (rvar(encState[0].data), rvar(encState[1].data))
            memory = decState[0][-1]
        else:
            decState = rvar(encState.data)
            memory = rvar(encState[-1].data)
        #print(decState[0].size(), memory.size())
        #decState.repeat_beam_size_times(beam_size)
        beam = [
            models.Beam(beam_size,
                        n_best=1,
                        cuda=self.use_cuda,
                        length_norm=self.config.length_norm)
            for __ in range(batch_size)
        ]
        if self.decoder.attention is not None:
            self.decoder.attention.init_context(contexts)

        # (2) run the decoder to generate sentences, using beam search.

        for i in range(self.config.max_time_step):

            if all((b.done() for b in beam)):
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(
                torch.stack([b.getCurrentState()
                             for b in beam]).t().contiguous().view(-1))

            # Run one step.
            output, decState, attn, memory = self.decoder(
                inp, decState, embeds, memory)
            # decOut: beam x rnn_size
            #print(decState.size(), memory.size())
            # (b) Compute a vector of batch*beam word scores.
            output = unbottle(self.log_softmax(output))
            attn = unbottle(attn)
            # beam x tgt_vocab

            # (c) Advance each beam.
            # update state
            for j, b in enumerate(beam):
                b.advance(output.data[:, j], attn.data[:, j])
                b.beam_update(decState, j)
                b.beam_update_memory(memory, j)

        # (3) Package everything up.
        allHyps, allScores, allAttn = [], [], []

        for j in ind.data:
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                attn.append(att.max(1)[1])
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])

        return allHyps, allAttn
Example #3
File: seq2seq.py  Project: Frances255/SGM
    def beam_sample(self, src, src_len, beam_size=1):

        #beam_size = self.config.beam_size
        batch_size = src.size(1)

        # (1) Run the encoder on the src.
        if self.use_cuda:
            src = src.cuda()
            src_len = src_len.cuda()

        lengths, indices = torch.sort(src_len, dim=0, descending=True)
        _, ind = torch.sort(indices)
        src = Variable(torch.index_select(src, dim=1, index=indices), volatile=True)
        contexts, encState = self.encoder(src, lengths.tolist())

        #  (1b) Initialize for the decoder.
        def var(a):
            return Variable(a, volatile=True)

        def rvar(a):
            return var(a.repeat(1, beam_size, 1))

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        # Repeat everything beam_size times.
        contexts = rvar(contexts.data).transpose(0, 1)
        decState = (rvar(encState[0].data), rvar(encState[1].data))
        #decState.repeat_beam_size_times(beam_size)
        beam = [models.Beam(beam_size, n_best=1, cuda=self.use_cuda)
                for __ in range(batch_size)]

        # (2) run the decoder to generate sentences, using beam search.
        
        mask = None
        soft_score = None
        for i in range(self.config.max_tgt_len):

            if all((b.done() for b in beam)):
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(torch.stack([b.getCurrentState() for b in beam])
                      .t().contiguous().view(-1))

            # Run one step.
            output, decState, attn = self.decoder.sample_one(inp, soft_score, decState, contexts, mask)
            soft_score = F.softmax(output, dim=-1)
            predicted = output.max(1)[1]
            if self.config.mask:
                if mask is None:
                    mask = predicted.unsqueeze(1).long()
                else:
                    mask = torch.cat((mask, predicted.unsqueeze(1)), 1)
            # decOut: beam x rnn_size

            # (b) Compute a vector of batch*beam word scores.
            output = unbottle(self.log_softmax(output))
            attn = unbottle(attn)
            # beam x tgt_vocab

            # (c) Advance each beam.
            # update state
            for j, b in enumerate(beam):
                b.advance(output.data[:, j], attn.data[:, j])
                b.beam_update(decState, j)

        # (3) Package everything up.
        allHyps, allScores, allAttn = [], [], []

        for j in ind:
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                attn.append(att.max(1)[1])
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])

        #print(allHyps)
        #print(allAttn)
        return allHyps, allAttn
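
This variant (like #6 and #8 below) accumulates a `mask` of previously predicted indices and feeds it back into `sample_one`. The masking itself happens inside the decoder and is not shown here; a hedged sketch of the usual idea, with a hypothetical helper name, is to push already-emitted labels to -inf so a multi-label decoder cannot predict the same label twice:

    import torch

    def apply_prediction_mask(logits, mask):
        # logits: (batch*beam, vocab); mask: (batch*beam, steps) of past
        # predictions, or None on the first step (hypothetical helper)
        if mask is None:
            return logits
        return logits.scatter(1, mask, float('-inf'))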
Example #4
    def beam_sample(self, src, src_len, beam_size=1, eval_=False):
        """
        谢谢小可爱!
        评测时生成摘要,使用beam_search,
        """
        # (1) Run the encoder on the src.
        # sort 返回排序结果和排序结果在原先 tensor 的索引
        print("小可爱棒棒哒!")
        lengths, indices = torch.sort(src_len, dim=0, descending=True)
        # indices that restore the original order after sorting
        _, ind = torch.sort(indices)
        src = torch.index_select(src, dim=0, index=indices)
        src = src.t()
        batch_size = src.size(1)
        contexts, encState = self.encoder(src, lengths.tolist())

        #  (1b) Initialize for the decoder.
        # copy the data without tracking gradients
        def var(a):
            return a.clone().detach()

        def rvar(a):
            return var(a.repeat(1, beam_size, 1))

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        # Repeat everything beam_size times.
        # contexts = rvar(contexts.data)
        contexts = rvar(contexts)

        if self.config.cell == 'lstm':
            decState = (rvar(encState[0]), rvar(encState[1]))
        else:
            decState = rvar(encState)

        # custom Beam; length_norm enables length normalization of beam scores
        beam = [
            models.Beam(beam_size,
                        n_best=1,
                        cuda=self.use_cuda,
                        length_norm=self.config.length_norm)
            for __ in range(batch_size)
        ]
        if self.decoder.attention is not None:
            self.decoder.attention.init_context(contexts)

        # (2) run the decoder to generate sentences, using beam search.
        for i in range(self.config.max_time_step):
            if all((b.done() for b in beam)):
                break
            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(
                torch.stack([b.getCurrentState()
                             for b in beam]).t().contiguous().view(-1))
            # Run one step.
            output, decState, attn, p = self.decoder(inp, decState)
            # decOut: beam x rnn_size
            # (b) Compute a vector of batch*beam word scores.
            output = unbottle(self.log_softmax(output))
            attn = unbottle(attn)
            # beam x tgt_vocab
            # (c) Advance each beam.
            # update state
            for j, b in enumerate(beam):
                b.advance(output[:, j], attn[:, j])
                if self.config.cell == 'lstm':
                    b.beam_update(decState, j)
                else:
                    b.beam_update_gru(decState, j)

        # (3) Package everything up.
        allHyps, allScores, allAttn = [], [], []
        if eval_:
            allWeight = []

        # for j in ind.data:
        for j in ind:
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            if eval_:
                weight = []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                attn.append(att.max(1)[1])
                if eval_:
                    weight.append(att)
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])
            if eval_:
                allWeight.append(weight[0])

        if eval_:
            return allHyps, allAttn, allWeight

        return allHyps, allAttn
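
The `length_norm` flag passed to `models.Beam` above is what distinguishes this variant from #3. The Beam class itself is not shown; a hedged one-line sketch of the usual length normalization, assuming it ranks finished hypotheses by average per-token log-probability:

    def length_normalized_score(logprobs):
        # logprobs: per-step log-probabilities of one finished hypothesis;
        # dividing by length keeps long candidates from being over-penalized
        return sum(logprobs) / max(len(logprobs), 1)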
Example #5
    def beam_sample(self, batch, use_cuda, beam_size=1):
        src, adjs, concept, concept_mask = batch.src, batch.adj, batch.concept, batch.concept_mask
        src_mask = batch.src_mask
        concept_vocab = batch.concept_vocab
        title_index = batch.title_index
        if use_cuda:
            src = [s.cuda() for s in src]
            src_mask = [s.cuda() for s in src_mask]
            adjs = [adj.cuda() for adj in adjs]
            concept = [c.cuda() for c in concept]
            concept_mask = concept_mask.cuda()
            title_index = title_index.cuda()
        # beam_size = self.config.beam_size
        batch_size = len(src)

        # (1) Run the encoder on the src.
        contexts, state = self.encode(src, src_mask, concept, concept_mask,
                                      title_index, adjs)
        c0, h0 = self.build_init_state(state, self.config.num_layers)

        def rvar(a):
            return a.repeat(1, beam_size, 1)

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        # Repeat everything beam_size times.
        contexts = contexts.repeat(beam_size, 1, 1)
        concept_mask = concept_mask.repeat(beam_size, 1)
        concept_vocab = concept_vocab.repeat(beam_size, 1)
        title_index = title_index.repeat(beam_size)
        decState = (c0.repeat(1, beam_size, 1), h0.repeat(1, beam_size, 1))
        beam = [
            models.Beam(beam_size, n_best=1, cuda=use_cuda)
            for _ in range(batch_size)
        ]

        # (2) run the decoder to generate sentences, using beam search.

        for i in range(self.config.max_tgt_len):

            if all((b.done() for b in beam)):
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            inp = torch.stack([b.getCurrentState()
                               for b in beam]).t().contiguous().view(-1)

            # Run one step.
            if self.use_copy:
                output, decState, attn, p_gen = self.decoder.sample_one(
                    inp,
                    decState,
                    contexts,
                    concept_mask,
                    title_index,
                    max_oov=0,
                    extend_vocab=concept_vocab)
            else:
                output, decState, attn = self.decoder.sample_one(
                    inp, decState, contexts)
                output = F.softmax(output, -1)
            # decOut: beam x rnn_size

            # (b) Compute a vector of batch*beam word scores.
            output = unbottle(torch.log(output))
            attn = unbottle(attn)
            # beam x tgt_vocab

            # (c) Advance each beam.
            # update state
            for j, b in enumerate(beam):
                b.advance(output.data[:, j], attn.data[:, j])
                b.beam_update(decState, j)

        # (3) Package everything up.
        allHyps, allScores, allAttn = [], [], []

        for j in range(batch_size):
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                attn.append(att.max(1)[1])
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])

        # print(allHyps)
        # print(allAttn)
        return allHyps, allAttn
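
The copy branch above returns a `p_gen` gate and passes `extend_vocab`/`max_oov` into `sample_one`, so `output` already comes back as probabilities (hence the bare `torch.log`). A simplified, hedged sketch of the standard pointer-generator mixing step that such a decoder typically performs internally (names and shapes are assumptions; extended-vocabulary handling is omitted):

    import torch

    def mix_copy_distribution(p_vocab, p_attn, p_gen, src_ids):
        # p_vocab: (batch, vocab) generation probs; p_attn: (batch, src_len)
        # attention probs; p_gen: (batch, 1) gate in [0, 1];
        # src_ids: (batch, src_len) LongTensor of source token ids
        out = p_gen * p_vocab
        return out.scatter_add(1, src_ids, (1 - p_gen) * p_attn)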
Example #6
    def beam_sample(self, src, src_len, dict_spk2idx, tgt, beam_size=1, src_original=None, mix_wav=None):

        if src_original is None:
            src_original = src
        src_original = src_original.transpose(0, 1)  # input is (len, batch); move batch to dim 0

        src = src.transpose(0, 1)
        # beam_size = self.config.beam_size
        batch_size = src.size(0)
        if mix_wav is not None:
            mix_wav = mix_wav.transpose(0, 1)

        # (1) Run the encoder on the src.
        if self.use_cuda:
            src = src.cuda()
            src_len = src_len.cuda()

        lengths, indices = torch.sort(src_len, dim=0, descending=True)
        # _, ind = torch.sort(indices)
        # src = Variable(torch.index_select(src, dim=1, index=indices), volatile=True)
        contexts, *_ = self.encoder(src, lengths.data.cpu().numpy()[0])
        best_hyps_dict = self.decoder.recognize_beam_greddy(
            contexts, list(dict_spk2idx.keys()), None)[0]
        print('hyps:', best_hyps_dict['yseq'])

        if self.config.use_emb:
            ss_embs = best_hyps_dict['dec_embs_input'][:, 1:]  # to [bs, decLen(3), dim]
        else:
            ss_embs = best_hyps_dict['dec_hiddens'][:, :-1]  # to [bs, decLen(3), dim]


        query = ss_embs
        if self.config.use_tas:
            predicted_maps = self.ss_model(mix_wav, query)
        else:
            predicted_maps = self.ss_model(src_original, query, tgt[1:-1], dict_spk2idx)
        return best_hyps_dict['yseq'][1:], predicted_maps.transpose(0, 1)
        # NOTE: everything below this early return is unreachable (kept from an
        # older beam-search code path)

        #  (1b) Initialize for the decoder.
        def var(a):
            return Variable(a, volatile=True)

        def rvar(a):
            return var(a.repeat(1, beam_size, 1))

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        # Repeat everything beam_size times.
        contexts = rvar(contexts.data).transpose(0, 1)
        # decState = (rvar(encState[0].data), rvar(encState[1].data))
        # decState.repeat_beam_size_times(beam_size)
        beam = [models.Beam(beam_size, dict_spk2idx, n_best=1,
                            cuda=self.use_cuda)
                for __ in range(batch_size)]

        # (2) run the decoder to generate sentences, using beam search.

        mask = None
        soft_score = None
        tmp_hiddens = []
        tmp_soft_score = []
        for i in range(self.config.max_tgt_len):

            if all((b.done() for b in beam)):
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(torch.stack([b.getCurrentState() for b in beam])
                      .t().contiguous().view(-1))
            if self.config.schmidt and i > 0:
                assert len(beam[0].sch_hiddens[-1]) == i
                tmp_hiddens = []
                for xxx in range(i):  # collect hidden states from every previous step
                    one_len = []
                    for bm_idx in range(beam_size):
                        for bs_idx in range(batch_size):
                            one_len.append(beam[bs_idx].sch_hiddens[-1][xxx][bm_idx, :])
                    tmp_hiddens.append(var(torch.stack(one_len)))

            # Run one step.
            output, decState, attn, hidden, emb = self.decoder.sample_one(inp, soft_score, tmp_hiddens,
                                                                          contexts.transpose(0,1), mask)
            # print "sample after decState:",decState[0].data.cpu().numpy().mean()
            if self.config.schmidt:
                tmp_hiddens += [hidden]
            if self.config.ct_recu:
                contexts = (1 - (attn > 0.03).float()).unsqueeze(-1) * contexts
            soft_score = F.softmax(output, dim=-1)
            if self.config.tmp_score:
                tmp_soft_score += [soft_score]
                if i == 1:
                    kl_loss = np.array([])
                    for kk in range(self.config.beam_size):
                        kl_loss = np.append(kl_loss, F.kl_div(soft_score[kk], tmp_soft_score[0][kk].data).data[0])
                    kl_loss = Variable(torch.from_numpy(kl_loss).float().cuda().unsqueeze(-1))
            predicted = output.max(1)[1]
            if self.config.mask:
                if mask is None:
                    mask = predicted.unsqueeze(1).long()
                else:
                    mask = torch.cat((mask, predicted.unsqueeze(1)), 1)
            # decOut: beam x rnn_size

            # (b) Compute a vector of batch*beam word scores.
            if self.config.tmp_score and i == 1:
                output = unbottle(self.log_softmax(output) + self.config.tmp_score * kl_loss)
            else:
                output = unbottle(self.log_softmax(output))
            attn = unbottle(attn)
            hidden = unbottle(hidden)
            emb = unbottle(emb)
            # beam x tgt_vocab

            # (c) Advance each beam.
            # update state
            for j, b in enumerate(beam):
                b.advance(output.data[:, j], attn.data[:, j], hidden.data[:, j], emb.data[:, j])
                # b.beam_update(decState, j)  # updates decState in place (by assignment, not by return)
                # print('pre root',b.prevKs)
                # print('next root',b.nextYs)
                # print('score',b.scores)
                if self.config.ct_recu:
                    b.beam_update_context(contexts, j)  # updates contexts in place (by assignment, not by return)
            # print "beam after decState:",decState[0].data.cpu().numpy().mean()

        # (3) Package everything up.
        allHyps, allScores, allAttn, allHiddens, allEmbs = [], [], [], [], []

        ind = list(range(batch_size))
        for j in ind:
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn, hiddens, embs = [], [], [], []
            pred = []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att, hidden, emb = b.getHyp(times, k)
                # if self.config.relitu:
                #     relitu_line(626, 1, att[0].cpu().numpy())
                #     relitu_line(626, 1, att[1].cpu().numpy())
                hyps.append(hyp)
                attn.append(att.max(1)[1])
                hiddens.append(hidden)
                embs.append(emb)
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])
            allHiddens.append(hiddens[0])
            allEmbs.append(embs[0])
        print('allHyps:\n',allHyps)

        # from sklearn.metrics.pairwise import euclidean_distances,cosine_distances
        # print(cosine_distances(allEmbs[0].data.cpu().numpy()))
        # print(cosine_distances(allHiddens[0].data.cpu().numpy()))

        if not self.config.global_emb:
            outputs = Variable(torch.stack(allHiddens, 0).transpose(0, 1))  # to [decLen, bs, dim]
            if not self.config.hidden_mix:
                predicted_maps = self.ss_model(src, outputs[:-1, :], tgt[1:-1])
            else:
                ss_embs = Variable(torch.stack(allEmbs, 0).transpose(0, 1))  # to [decLen, bs, dim]
                mix = torch.cat((outputs[:-1, :], ss_embs[1:]), dim=2)
                predicted_maps = self.ss_model(src, mix, tgt[1:-1])
            if self.config.top1:
                predicted_maps = predicted_maps[:, 0].unsqueeze(1)
        else:
            # allEmbs=[j[1:self.config.MAX_MIX] for j in allEmbs]
            ss_embs = Variable(torch.stack(allEmbs, 0).transpose(0, 1))  # to [decLen, bs, dim]
            if self.config.use_tas:
                if self.config.global_hidden:
                    ss_hidden = Variable(torch.stack(allHiddens, 0).transpose(0, 1))  # to [decLen, bs, dim]
                    # predicted_maps = self.ss_model(mix_wav, ss_hidden[:-1, :])  # aligned version
                    predicted_maps = self.ss_model(mix_wav, ss_hidden[1:, :].transpose(0, 1))  # shifted by one
                else:
                    predicted_maps = self.ss_model(mix_wav,ss_embs[1:].transpose(0,1))
            elif self.config.global_hidden:
                ss_hidden = Variable(torch.stack(allHiddens, 0).transpose(0, 1))  # to [decLen, bs, dim]
                print((ss_hidden.shape))

                predicted_maps = self.ss_model(src,ss_hidden[1:, :], tgt[1:-1], dict_spk2idx)
            elif self.config.top1:
                predicted_maps = self.ss_model(src, ss_embs[1:2], tgt[1:2])
            else:
                predicted_maps = self.ss_model(src, ss_embs[1:, :], tgt[1:-1], dict_spk2idx)
                # predicted_maps = self.ss_model(src, ss_embs, tgt[1:2])
        if self.config.top1:
            predicted_maps = predicted_maps[:, 0:1]  # (bs, 1, len)
        return allHyps, allAttn, allHiddens, predicted_maps  # .transpose(0,1)
Example #7
    def beam_sample(self,
                    src,
                    src_len,
                    knowledge=None,
                    knowledge_len=None,
                    beam_size=1,
                    eval_=False):
        """
        beam search
        :param src: source input
        :param src_len: source length
        :param beam_size: beam size
        :param eval_: evaluation or not
        :return: prediction, attention max score and attention weights
        """
        lengths, indices = torch.sort(src_len, dim=0,
                                      descending=True)  # [batch]
        _, ind = torch.sort(indices)
        src = torch.index_select(src, dim=0, index=indices)  # [batch, len]
        src = src.t()  # [len, batch]
        batch_size = src.size(1)

        if self.config.knowledge:
            knowledge = knowledge.t()

        if self.config.positional:
            if self.config.knowledge:
                mask = (knowledge.t() != 0).float()
                knowledge_contexts = self.knowledge_encoder(
                    knowledge, is_knowledge=True).transpose(0, 1)
            contexts = self.encoder(src,
                                    src_len.tolist())  # [len, batch, size]
            if self.config.knowledge:
                contexts = contexts.transpose(0, 1)
                contexts = self.encoder.condition_context_attn(
                    contexts, knowledge_contexts, mask)
                contexts = self.encoder.bi_attn_transform(contexts)
                contexts = contexts.transpose(0, 1)
        else:
            contexts, state = self.encoder(
                src, lengths.tolist())  # [len, batch, size]

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(batch_size, beam_size, -1)

        beam = [
            models.Beam(beam_size,
                        n_best=1,
                        cuda=self.use_cuda,
                        length_norm=self.config.length_norm)
            for __ in range(batch_size)
        ]  # [batch, beam]

        contexts = tile(contexts, beam_size, 1)  # [len, batch*beam, size]
        src = tile(src, beam_size, 1)  # [len, batch*beam]

        if not self.config.positional:
            h = tile(state[0], beam_size, 0)
            c = tile(state[1], beam_size, 0)
            state = (h, c)  # [len, batch*beam, size]

        # self.decoder.init_state(src, contexts)
        models.transformer.init_state(self.decoder, src, contexts,
                                      self.decoder.num_layers)

        # sequential generation
        for i in range(self.config.max_time_step):
            # finish beam search
            if all((b.done() for b in beam)):
                break

            inp = torch.stack([b.getCurrentState() for b in beam])
            inp = inp.view(1, -1)  # [1, batch*beam]

            if self.config.positional:
                output, attn = self.decoder(inp, contexts,
                                            step=i)  # [len, batch*beam, size]
                state = None
            else:
                output, attn, state = self.decoder(
                    inp, contexts, state, step=i)  # [len, batch*beam, size]
            output = self.compute_score(output.transpose(0, 1)).squeeze(
                1)  # [batch*beam, size]

            output = unbottle(self.log_softmax(output))  # [batch, beam, size]
            attn = unbottle(attn.squeeze(0))  # [batch, beam, k_len]

            select_indices_array = []
            # scan beams in a batch
            for j, b in enumerate(beam):
                # update each beam
                b.advance(output[j, :], attn[j, :])  # batch index
                select_indices_array.append(b.getCurrentOrigin() +
                                            j * beam_size)
            select_indices = torch.cat(select_indices_array)
            self.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))
            if state is not None:
                state = (state[0].index_select(0, select_indices),
                         state[1].index_select(0, select_indices))

        allHyps, allScores, allAttn = [], [], []
        if eval_:
            allWeight = []

        for j in ind:
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            if eval_:
                weight = []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                attn.append(att.max(1)[1])
                if eval_:
                    weight.append(att)
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])
            if eval_:
                allWeight.append(weight[0])
        if eval_:
            return allHyps, allAttn, allWeight

        return allHyps, allAttn
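
Unlike the `repeat`-based variants, this example tiles with a `tile` helper and unbottles with `view(batch_size, beam_size, -1)`, i.e. batch-major layout: the beam copies of one batch element are adjacent. The helper ships with the repo (it mirrors OpenNMT-py's `tile`); a plausible self-contained implementation under that assumption:

    import torch

    def tile(x, count, dim=0):
        """Repeat each slice of x along `dim` count times, keeping the copies
        of one element adjacent: [x0, x0, ..., x1, x1, ...]."""
        perm = list(range(x.dim()))
        if dim != 0:
            perm[0], perm[dim] = perm[dim], perm[0]
            x = x.permute(perm).contiguous()
        n, *rest = x.size()
        x = x.unsqueeze(1).expand(n, count, *rest).contiguous().view(n * count, *rest)
        if dim != 0:
            x = x.permute(perm).contiguous()
        return x

    # layout check: tile(contexts, beam, 1) keeps each batch element's beam
    # copies adjacent, matching unbottle's view(batch_size, beam_size, -1)
    contexts = torch.arange(6.).view(1, 3, 2)   # (len, batch, size)
    tiled = tile(contexts, 2, 1)                # (len, batch*beam, size)
    assert torch.equal(tiled[0, 0], tiled[0, 1])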
Example #8
    def beam_sample(self, src, src_len, beam_size=1):
        '''
        :param src: [maxlen, batch]
        :param src_len: [batch]
        :param beam_size: from config
        :return:
        1. Sort by sentence length, run the encoder, and get contexts, encState:
            contexts [maxlen, batch, hidden*num_dirc]
            state -> (h, c) -> [num_layer, batch, dir * hidden]
        2.
        '''
        # beam_size = self.config.beam_size
        batch_size = src.size(1)

        # (1) Run the encoder on the src.
        if self.use_cuda:
            src = src.cuda()
            src_len = src_len.cuda()

        lengths, indices = torch.sort(src_len, dim=0, descending=True)
        _, ind = torch.sort(indices)
        with torch.no_grad():
            src = torch.index_select(src, dim=1, index=indices)
        # contexts [maxlen, batch, hidden*num_dirc]
        # state -> (h, c) -> [num_layer, batch, dir * hidden]
        contexts, encState = self.encoder(src, lengths.tolist())

        #  (1b) Initialize for the decoder.
        def var(a):
            # `with torch.no_grad(): return a` has no effect on an already
            # created tensor; detach() actually drops gradient tracking
            return a.detach()

        def rvar(a):
            return var(a.repeat(1, beam_size, 1))

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        # Repeat everything beam_size times.
        contexts = rvar(contexts.data).transpose(
            0, 1)  # [len, batch, hid*dirc] -> [batch*beam_sz, len, hid*dirc]
        decState = (rvar(encState[0].data),
                    rvar(encState[1].data))  # [layer, batch*beam_sz, dir*hidden]
        # decState.repeat_beam_size_times(beam_size)
        beam = [
            models.Beam(beam_size, n_best=1, cuda=self.use_cuda)
            for __ in range(batch_size)
        ]  # one Beam object per batch element

        # (2) run the decoder to generate sentences, using beam search.

        mask = None
        soft_score = None
        for i in range(self.config.max_tgt_len):

            if all((b.done() for b in beam)):
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            # the nextYs in each b are initially [BOS] + [EOS] * (beam_sz - 1)
            inp = var(
                torch.stack([
                    b.getCurrentState() for b in beam
                ])  # b.getCurrentState returns a LongTensor of shape [beam_sz]
                # then stack it to [batch, beam_sz]; t() -> [beam_sz, batch] -> [beam*batch],
                # which matches the decoder's expected input layout
                .t().contiguous().view(-1))

            # Run one step. inp: [beam*batch]; soft_score -> initially None;
            # decState -> [layer, batch*beam_sz, dir*hidden];
            # contexts -> [batch*beam_sz, len, hid*dirc]; mask initially None
            # HINT: output is the score over the target vocab

            output, decState, attn = self.decoder.sample_one(
                inp, soft_score, decState, contexts, mask)
            soft_score = F.softmax(output, dim=1)
            '''
            import time
            print(soft_score[:3])
            print(output.max(1)[0])
            time.sleep(20)
            '''

            if self.config.view_score:
                import time
                print(output[:10])
                time.sleep(20)

            if self.config.cheat:
                output[:, -5:] += self.config.cheat
            predicted = output.max(1)[1]
            # predicted = torch.multinomial(soft_score, 1).squeeze()

            # print(output.shape, predicted.shape)
            # print(predicted[:8])
            # time.sleep(20)
            '''
            the_index = (torch.zeros(predicted.shape).float().cuda() + 0.1).cuda().gt(output.max(1)[0])
            # the_index = output.max(1)[0].gt(torch.zeros(predicted.shape).cuda() + 0.1).cuda()
            # print(the_index, predicted)
            predicted = predicted.masked_fill_(the_index, 3)
            '''
            '''
            import time
            print(predicted)
            time.sleep(20)
            '''

            if self.config.mask:
                if mask is None:
                    mask = predicted.unsqueeze(1).long()
                else:
                    mask = torch.cat((mask, predicted.unsqueeze(1)), 1)
            # decOut: beam x rnn_size

            # (b) Compute a vector of batch*beam word scores.
            output = unbottle(self.log_softmax(output))
            attn = unbottle(attn)
            # beam x tgt_vocab

            # (c) Advance each beam.
            # update state
            for j, b in enumerate(beam):
                b.advance(output.data[:, j], attn.data[:, j])
                b.beam_update(decState, j)

        # (3) Package everything up.
        allHyps, allScores, allAttn = [], [], []

        for j in ind:
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                attn.append(att.max(1)[1])  # position with the highest attention
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])

        # print(allHyps)
        # print(allAttn)
        return allHyps, allAttn
Example #9
    def beam_sample(self, fc_feats, att_feats, att_masks, opt={}):
        # (1) Run the encoder on the src.
        eval_ = False
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(
            fc_feats, att_feats, att_masks)

        #contexts, encState = self.encoder(src, lengths.tolist())
        ind = range(batch_size)
        decState = self.init_hidden(beam_size * batch_size)
        # tmp_fc_feats = p_fc_feats.unsqueeze(1).expand(batch_size, beam_size, p_fc_feats.size(1)).reshape(-1, p_fc_feats.size(1))
        # tmp_att_feats = p_att_feats.unsqueeze(1).expand((batch_size, beam_size, p_att_feats.size(1), p_att_feats.size(2))).reshape(
        #     -1, p_att_feats.size(1), p_att_feats.size(2)).contiguous()
        # tmp_p_att_feats = pp_att_feats.unsqueeze(1).expand(batch_size, beam_size, pp_att_feats.size(1), pp_att_feats.size(2))\
        #     .reshape(-1,pp_att_feats.size(1), pp_att_feats.size(2)) .contiguous()
        # tmp_att_masks = p_att_masks.unsqueeze(1).expand(batch_size, beam_size, p_att_masks.size(1)).reshape(
        #     -1, p_att_masks.size(1)).contiguous() if att_masks is not None else None
        tmp_fc_feats = p_fc_feats.repeat(beam_size, 1)
        tmp_att_feats = p_att_feats.repeat(beam_size, 1, 1)
        tmp_p_att_feats = pp_att_feats.repeat(beam_size, 1, 1)
        tmp_att_masks = p_att_masks.repeat(
            beam_size, 1) if att_masks is not None else None

        #  (1b) Initialize for the decoder.
        def var(a):
            return torch.tensor(a, requires_grad=False)

        def rvar(a):
            return var(a.repeat(1, beam_size, 1))

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        # Repeat everything beam_size times.
        # contexts = rvar(contexts.data)
        # contexts = rvar(contexts)

        # if self.config.cell == 'lstm':
        #     decState = (rvar(encState[0]), rvar(encState[1]))
        # else:
        #     decState = rvar(encState)

        beam = [
            models.Beam(beam_size, n_best=1, cuda=1, length_norm=0)
            for __ in range(batch_size)
        ]
        # if self.decoder.attention is not None:
        #     self.decoder.attention.init_context(contexts)

        # (2) run the decoder to generate sentences, using beam search.

        for i in range(self.seq_length):

            if all((b.done() for b in beam)):
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(
                torch.stack([b.getCurrentState()
                             for b in beam]).t().contiguous().view(-1))

            # Run one step.
            #output, decState, attn = self.decoder(inp, decState)
            #print(inp.size(), tmp_fc_feats.size())
            output, decState = self.get_logprobs_state(inp, tmp_fc_feats,
                                                       tmp_att_feats,
                                                       tmp_p_att_feats,
                                                       tmp_att_masks, decState)

            # decOut: beam x rnn_size

            # (b) Compute a vector of batch*beam word scores.
            #output = unbottle(self.log_softmax(output))
            #print(output.size())
            output = unbottle(output)
            # suppress UNK (last index of the vocab)
            output[:, :, -1] = output[:, :, -1] - 1000
            # attn = unbottle(attn)
            # beam x tgt_vocab

            # (c) Advance each beam.
            # update state
            tmp_decState = (unbottle(decState[0]), unbottle(decState[1]))
            t_tmp_fc_feats = unbottle(tmp_fc_feats)
            t_tmp_att_feats = tmp_att_feats.view(beam_size, batch_size,
                                                 tmp_att_feats.size(1),
                                                 tmp_att_feats.size(2))
            t_tmp_p_att_feats = tmp_p_att_feats.view(beam_size, batch_size,
                                                     tmp_p_att_feats.size(1),
                                                     tmp_p_att_feats.size(2))
            # keep the flat tmp_att_masks intact for the next iteration; only
            # the unbottled view is indexed per batch element below
            t_tmp_att_masks = unbottle(tmp_att_masks)

            for j, b in enumerate(beam):
                tmp = [
                    t_tmp_fc_feats[:, j, :], t_tmp_att_feats[:, j, :, :],
                    t_tmp_p_att_feats[:, j, :, :], t_tmp_att_masks[:, j, :]
                ]
                b.advance(output[:, j])
                b.beam_update(decState, j)

        # (3) Package everything up.
        allHyps, allScores, allAttn = [], [], []
        if eval_:
            allWeight = []

        # for j in ind.data:
        #print(len(beam))
        for j in ind:
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            if eval_:
                weight = []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                if eval_:
                    weight.append(att)
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            if eval_:
                allWeight.append(weight[0])

        if eval_:
            return allHyps, allAttn, allWeight
        seq = torch.zeros(batch_size, self.seq_length).long()
        for i in range(batch_size):
            for j, w in enumerate(allHyps[i]):
                seq[i, j] = w
        # print(seq)
        # exit(0)
        return seq, allAttn
Example #10
    def beam_sample(self, src, dict_spk2idx, dict_dir2idx, beam_size=1):

        mix_wav = src.transpose(0, 1)  # [batch, sample]
        mix = self.TasNetEncoder(mix_wav)  # [batch, featuremap, timeStep]

        mix_infer = mix.transpose(1, 2)  # [batch, timeStep, featuremap]
        batch_size, lengths, _ = mix_infer.size()
        lengths = Variable(
            torch.LongTensor(self.config.batch_size).zero_() +
            lengths).unsqueeze(0).cuda()
        lengths, indices = torch.sort(lengths.squeeze(0),
                                      dim=0,
                                      descending=True)

        contexts, encState = self.encoder(
            mix_infer, lengths.data.tolist())  # contexts: [max_len, batch_size, hidden_size*2]

        #  (1b) Initialize for the decoder.
        def var(a):
            return Variable(a, volatile=True)

        def rvar(a):
            return var(a.repeat(1, beam_size, 1))

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        # Repeat everything beam_size times.
        contexts = rvar(contexts.data).transpose(0, 1)
        decState = (rvar(encState[0].data), rvar(encState[1].data))
        decState_dir = (rvar(encState[0].data), rvar(encState[1].data))
        # decState.repeat_beam_size_times(beam_size)
        beam = [
            models.Beam(beam_size, dict_spk2idx, n_best=1, cuda=self.use_cuda)
            for __ in range(batch_size)
        ]

        beam_dir = [
            models.Beam(beam_size, dict_dir2idx, n_best=1, cuda=self.use_cuda)
            for __ in range(batch_size)
        ]
        # (2) run the decoder to generate sentences, using beam search.

        mask = None
        mask_dir = None
        soft_score = None
        tmp_hiddens = []
        tmp_soft_score = []

        soft_score_dir = None
        tmp_hiddens_dir = []
        tmp_soft_score_dir = []
        output_list = []
        output_dir_list = []
        predicted_list = []
        predicted_dir_list = []
        output_bk_list = []
        output_bk_dir_list = []
        hidden_list = []
        hidden_dir_list = []
        emb_list = []
        emb_dir_list = []
        for i in range(self.config.max_tgt_len):

            if all((b.done() for b in beam)):
                break
            if all((b_dir.done() for b_dir in beam_dir)):
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(
                torch.stack([b.getCurrentState()
                             for b in beam]).t().contiguous().view(-1))
            inp_dir = var(
                torch.stack([b_dir.getCurrentState()
                             for b_dir in beam_dir]).t().contiguous().view(-1))

            # Run one step.
            output, output_dir, decState, decState_dir, attn_weights, attn_weights_dir, hidden, hidden_dir, emb, emb_dir, output_bk, output_bk_dir = self.decoder.sample_one(
                inp, inp_dir, soft_score, soft_score_dir, decState,
                decState_dir, tmp_hiddens, tmp_hiddens_dir, contexts, mask,
                mask_dir)
            soft_score = F.softmax(output, dim=-1)
            soft_score_dir = F.softmax(output_dir, dim=-1)

            predicted = output.max(1)[1]
            predicted_dir = output_dir.max(1)[1]
            if self.config.mask:
                if mask is None:
                    mask = predicted.unsqueeze(1).long()
                    mask_dir = predicted_dir.unsqueeze(1).long()
                else:
                    mask = torch.cat((mask, predicted.unsqueeze(1)), 1)
                    mask_dir = torch.cat(
                        (mask_dir, predicted_dir.unsqueeze(1)), 1)
            # decOut: beam x rnn_size

            # (b) Compute a vector of batch*beam word scores.

            output_list.append(output[0])
            output_dir_list.append(output_dir[0])

            output = unbottle(self.log_softmax(output))
            output_dir = unbottle(F.sigmoid(output_dir))

            attn = unbottle(attn_weights)
            hidden = unbottle(hidden)
            emb = unbottle(emb)
            attn_dir = unbottle(attn_weights_dir)
            hidden_dir = unbottle(hidden_dir)
            emb_dir = unbottle(emb_dir)
            # beam x tgt_vocab

            output_bk_list.append(output_bk[0])
            output_bk_dir_list.append(output_bk_dir[0])
            hidden_list.append(hidden[0])
            hidden_dir_list.append(hidden_dir[0])
            emb_list.append(emb[0])
            emb_dir_list.append(emb_dir[0])

            predicted_list.append(predicted)
            predicted_dir_list.append(predicted_dir)

            # (c) Advance each beam.
            # update state

            for j, b in enumerate(beam):
                b.advance(output.data[:, j], attn.data[:, j],
                          hidden.data[:, j], emb.data[:, j])
                b.beam_update(decState,
                              j)  # updates decState in place (by assignment, not by return)
                if self.config.ct_recu:
                    b.beam_update_context(
                        contexts, j)  # updates contexts in place (by assignment, not by return)
            for i, a in enumerate(beam_dir):
                a.advance(output_dir.data[:, i], attn_dir.data[:, i],
                          hidden_dir.data[:, i], emb_dir.data[:, i])
                a.beam_update(decState_dir,
                              i)  # updates decState_dir in place (by assignment, not by return)
                if self.config.ct_recu:
                    a.beam_update_context(
                        contexts, i)  # updates contexts in place (by assignment, not by return)
            # print "beam after decState:",decState[0].data.cpu().numpy().mean()

        # (3) Package everything up.
        allHyps, allHyps_dir, allScores, allAttn, allHiddens, allEmbs = [], [], [], [], [], []

        ind = range(batch_size)
        for j in ind:
            b = beam[j]
            c = beam_dir[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, hyps_dir, attn, hiddens, embs = [], [], [], [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att, hidden, emb = b.getHyp(times, k)
                hyp_dir, att_dir, hidden_dir, emb_dir = c.getHyp(times, k)
                if self.config.relitu:
                    relitu_line(626, 1, att[0].cpu().numpy())
                    relitu_line(626, 1, att[1].cpu().numpy())
                hyps.append(hyp)
                attn.append(att.max(1)[1])
                hiddens.append(hidden + hidden_dir)
                embs.append(emb + emb_dir)
                hyps_dir.append(hyp_dir)
            allHyps.append(hyps[0])
            allHyps_dir.append(hyps_dir[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])
            allHiddens.append(hiddens[0])
            allEmbs.append(embs[0])

        ss_embs = Variable(torch.stack(allEmbs, 0).transpose(0, 1))  # to [decLen, bs, dim]
        if self.config.use_tas:
            predicted_maps = self.ss_model(mix, ss_embs[1:].transpose(0, 1))

        predicted_signal = self.TasNetDecoder(
            mix, predicted_maps)  # [batch, spkN, timeStep]
        return allHyps, allHyps_dir, allAttn, allHiddens, predicted_signal.transpose(
            0, 1
        ), output_list, output_dir_list, output_bk_list, output_bk_dir_list, hidden_list, hidden_dir_list, emb_list, emb_dir_list
Example #11
    def beam_sample(self, src, src_len, dict_spk2idx, tgt, beam_size=1):

        src = src.transpose(0, 1)
        #beam_size = self.config.beam_size
        batch_size = src.size(0)

        # (1) Run the encoder on the src.
        if self.use_cuda:
            src = src.cuda()
            src_len = src_len.cuda()

        lengths, indices = torch.sort(src_len, dim=0, descending=True)
        # _, ind = torch.sort(indices)
        # src = Variable(torch.index_select(src, dim=1, index=indices), volatile=True)
        contexts, encState = self.encoder(src, lengths.data.cpu().numpy()[0])

        #  (1b) Initialize for the decoder.
        def var(a):
            return Variable(a, volatile=True)

        def rvar(a):
            return var(a.repeat(1, beam_size, 1))

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        # Repeat everything beam_size times.
        contexts = rvar(contexts.data).transpose(0, 1)
        decState = (rvar(encState[0].data), rvar(encState[1].data))
        #decState.repeat_beam_size_times(beam_size)
        beam = [
            models.Beam(beam_size, dict_spk2idx, n_best=1, cuda=self.use_cuda)
            for __ in range(batch_size)
        ]

        # (2) run the decoder to generate sentences, using beam search.

        mask = None
        soft_score = None
        tmp_hiddens = []
        for i in range(self.config.max_tgt_len):

            if all((b.done() for b in beam)):
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(
                torch.stack([b.getCurrentState()
                             for b in beam]).t().contiguous().view(-1))
            if self.config.schmidt and i > 0:
                assert len(beam[0].sch_hiddens[-1]) == i
                tmp_hiddens = []
                for xxx in range(i):  # collect hidden states from every previous step
                    one_len = []
                    for bm_idx in range(beam_size):
                        for bs_idx in range(batch_size):
                            one_len.append(
                                beam[bs_idx].sch_hiddens[-1][xxx][bm_idx, :])
                    tmp_hiddens.append(var(torch.stack(one_len)))

            # Run one step.
            output, decState, attn, hidden, emb = self.decoder.sample_one(
                inp, soft_score, decState, tmp_hiddens, contexts, mask)
            # print "sample after decState:",decState[0].data.cpu().numpy().mean()
            if self.config.schmidt:
                tmp_hiddens += [hidden]
            if self.config.ct_recu:
                contexts = (1 -
                            (attn > 0.003).float()).unsqueeze(-1) * contexts
            soft_score = F.softmax(output, dim=-1)
            predicted = output.max(1)[1]
            if self.config.mask:
                if mask is None:
                    mask = predicted.unsqueeze(1).long()
                else:
                    mask = torch.cat((mask, predicted.unsqueeze(1)), 1)
            # decOut: beam x rnn_size

            # (b) Compute a vector of batch*beam word scores.
            output = unbottle(self.log_softmax(output))
            attn = unbottle(attn)
            hidden = unbottle(hidden)
            emb = unbottle(emb)
            # beam x tgt_vocab

            # (c) Advance each beam.
            # update state
            for j, b in enumerate(beam):
                b.advance(output.data[:, j], attn.data[:, j],
                          hidden.data[:, j], emb.data[:, j])
                b.beam_update(decState,
                              j)  # updates decState in place (by assignment, not by return)
                if self.config.ct_recu:
                    b.beam_update_context(
                        contexts, j)  # updates contexts in place (by assignment, not by return)
            # print "beam after decState:",decState[0].data.cpu().numpy().mean()

        # (3) Package everything up.
        allHyps, allScores, allAttn, allHiddens, allEmbs = [], [], [], [], []

        ind = range(batch_size)
        for j in ind:
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn, hiddens, embs = [], [], [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att, hidden, emb = b.getHyp(times, k)
                if self.config.relitu:
                    relitu_line(626, 1, att[0].cpu().numpy())
                    relitu_line(626, 1, att[1].cpu().numpy())
                hyps.append(hyp)
                attn.append(att.max(1)[1])
                hiddens.append(hidden)
                embs.append(emb)
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])
            allHiddens.append(hiddens[0])
            allEmbs.append(embs[0])
        print(allHyps)

        if not self.config.global_emb:
            outputs = Variable(torch.stack(allHiddens, 0).transpose(
                0, 1))  # to [decLen, bs, dim]
            if not self.config.hidden_mix:
                predicted_maps = self.ss_model(src, outputs[:-1, :], tgt[1:-1])
            else:
                ss_embs = Variable(torch.stack(allEmbs, 0).transpose(
                    0, 1))  # to [decLen, bs, dim]
                mix = torch.cat((outputs[:-1, :], ss_embs[1:]), dim=2)
                predicted_maps = self.ss_model(src, mix, tgt[1:-1])
            if self.config.top1:
                predicted_maps = predicted_maps[:, 0].unsqueeze(1)
        else:
            ss_embs = Variable(torch.stack(allEmbs, 0).transpose(
                0, 1))  # to [decLen, bs, dim]
            if not self.config.top1:
                # src=src[:,:self.config.buffer_shift]
                predicted_maps = self.ss_model(src, ss_embs[1:, :], tgt[1:-1])
            else:
                predicted_maps = self.ss_model(src, ss_embs[1:2], tgt[1:2])
        return allHyps, allAttn, allHiddens, predicted_maps  #.transpose(0,1)