Example 1: translateBatch with a copy mechanism and multi-head attention
    def translateBatch(self, srcBatch, featsBatch, tgtBatch):
        batchSize = srcBatch[0].size(1)
        beamSize = self.opt.beam_size

        #  (1) run the encoder on the src
        encStates, context = self.model.encoder(srcBatch, featsBatch)
        srcBatch = srcBatch[0]  # drop the lengths (they were only needed by the encoder)

        decStates = self.model.decIniter(encStates[1])  # batch, dec_hidden

        #  (3) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.
        context = context.data.repeat(1, beamSize, 1)
        decStates = decStates.unsqueeze(0).data.repeat(1, beamSize, 1)
        att_vec = self.model.make_init_att(context)
        padMask = srcBatch.data.eq(s2s.Constants.PAD).transpose(0, 1).unsqueeze(0).repeat(beamSize, 1, 1).float()
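        # Layout note (inferred from the repeats above): dim 1 of `context` and
        # `decStates` is beam-major, i.e. column = beamIdx * batchSize + sentIdx,
        # while `padMask` is (beam, batch, srcLen).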

        beam = [s2s.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]
        batchIdx = list(range(batchSize))
        remainingSents = batchSize

        for i in range(self.opt.max_dec_length):
            # Prepare decoder input.
            input = torch.stack([b.getCurrentState() for b in beam
                                 if not b.done]).transpose(0, 1).contiguous().view(1, 1, -1)

            print('input shape:', input.shape)  # (1, 1, beam * remainingSents)
            # input, hidden, context, src_pad_mask, init_att, base_flag
            _, g_predict, c_predict, copyGateOutputs, decStates, attn, att_vec, mul_head_attn, _, _, _, _ = \
                self.model.decoder(input, decStates, context, padMask.view(-1, padMask.size(2)), att_vec, True)
            #sample_y, g_outputs, c_outputs, copyGateOutputs, hidden, context_attention, cur_context, mul_head_attns,
            # is_Copys, all_pos, mul_cs, mul_as

            # g_predict: 1 x (beam*batch) x numWords
            # This decoder variant returns g_predict / c_predict directly; the
            # equivalent copy-gate mixing (cf. Example 2) is kept for reference:
            # copyGateOutputs = copyGateOutputs.view(-1, 1)
            # g_outputs = g_outputs.squeeze(0)
            # g_out_prob = self.model.generator.forward(g_outputs) + 1e-8
            # g_predict = torch.log(g_out_prob * ((1 - copyGateOutputs).expand_as(g_out_prob)))
            # c_outputs = c_outputs.squeeze(0) + 1e-8
            # c_predict = torch.log(c_outputs * (copyGateOutputs.expand_as(c_outputs)))
            mul_head_attn = mul_head_attn[0]
            num_head = len(mul_head_attn)
            mul_head_attn = torch.stack(mul_head_attn)
            # mul_head_attn : n_heads * (beam*batch) * src_len
            # batch x beam x numWords
            wordLk = g_predict.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()
            copyLk = c_predict.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()
            print('wordLk:', wordLk.shape, copyLk.shape)  # batch x beam x numWords
            attn = attn.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()
            # print('mul_head_attn.shape:', mul_head_attn.shape)
            mul_head_attn = mul_head_attn.view(beamSize, num_head, remainingSents, -1).transpose(1, 2).transpose(0, 1).contiguous()
            # print('attn.shape:', attn.shape)  # e.g. [64, 7, 88] = batch x beam x srcLen
            # print('mul_head_attn.shape:', mul_head_attn.shape)  # e.g. [64, 7, 8, 88] = batch x beam x heads x srcLen
            active = []
            father_idx = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], copyLk.data[idx], attn.data[idx], mul_head_attn.data[idx]):
                    active += [b]
                    father_idx.append(beam[b].prevKs[-1])  # backpointers: parent beam of each hypothesis

            if not active:
                break

            # to get the real father index
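            # (states are flattened beam-major, row = beamIdx * numActive + sentIdx,
            # so the parent at beam `idx` of sentence `kk` sits at row
            # idx * len(father_idx) + kk)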
            real_father_idx = []
            for kk, idx in enumerate(father_idx):
                real_father_idx.append(idx * len(father_idx) + kk)

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {b: idx for idx, b in enumerate(active)}  # original batch position -> compact index

            def updateActive(t, rnnSize):
                # select only the remaining active sentences
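                # (the beam*batch axis is exposed as (beam, remainingSents);
                # index_select keeps only the still-active sentence columns)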
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return view.index_select(1, activeIdx).view(*newSize)

            decStates = updateActive(decStates, self.dec_rnn_size)
            context = updateActive(context, self.enc_rnn_size)
            att_vec = updateActive(att_vec, self.enc_rnn_size)
            padMask = padMask.index_select(1, activeIdx)

            # set correct state for beam search
            previous_index = torch.stack(real_father_idx).transpose(0, 1).contiguous()
            decStates = decStates.view(-1, decStates.size(2)).index_select(0, previous_index.view(-1)).view(
                *decStates.size())
            att_vec = att_vec.view(-1, att_vec.size(1)).index_select(0, previous_index.view(-1)).view(*att_vec.size())

            remainingSents = len(active)

        # (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        allIsCopy, allCopyPosition = [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()
            allScores += [scores[:n_best]]
            valid_attn = srcBatch.data[:, b].ne(s2s.Constants.PAD).nonzero().squeeze(1)
            hyps, isCopy, copyPosition, attn, mul_attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
            attn = [a.index_select(1, valid_attn) for a in attn]
            allHyp += [hyps]
            allAttn += [attn]
            allIsCopy += [isCopy]
            allCopyPosition += [copyPosition]
            # print('allHyp:',len(hyps),len(hyps[0]),hyps[0])
            # print('allAttn:', len(attn),len(attn[0]))
            # print('allIsCopy:', len(isCopy),len(isCopy[0]))
            # print('allCopyPosition:', len(copyPosition),len(copyPosition[0]))

        # print(mul_attn[0].shape)
        # NOTE: mul_attn here is whatever the loop left behind for the last batch element
        return allHyp, allScores, allIsCopy, allCopyPosition, allAttn, mul_attn, None
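
All three examples share the same beam-state bookkeeping: decoder states are stored beam-major along a flattened (beam * batch) axis, and after every step index_select gathers each surviving hypothesis's parent state. Below is a minimal, self-contained sketch of just that gather; beam_size, num_active, hidden and the father_idx values are made up for illustration and are not from the repo.

import torch

beam_size, num_active, hidden = 3, 2, 4

# Beam-major layout: row r holds beam (r // num_active) of sentence (r % num_active).
dec_states = torch.arange(beam_size * num_active * hidden, dtype=torch.float)
dec_states = dec_states.view(1, beam_size * num_active, hidden)

# Per sentence, the parent beam of each new hypothesis (what beam[b].prevKs[-1]
# holds in the examples above).
father_idx = [torch.tensor([2, 0, 1]),  # sentence 0
              torch.tensor([1, 1, 0])]  # sentence 1

# Flat row of each parent: parentBeam * num_active + sentence index.
real_father_idx = [idx * num_active + kk for kk, idx in enumerate(father_idx)]

# (num_active, beam) -> (beam, num_active), so view(-1) matches the beam-major rows.
previous_index = torch.stack(real_father_idx).transpose(0, 1).contiguous()
dec_states = dec_states.view(-1, hidden).index_select(
    0, previous_index.view(-1)).view(1, beam_size * num_active, hidden)
print(dec_states.shape)  # torch.Size([1, 6, 4])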
Example 2: translateBatch with a copy gate and source-token swapping (SS)
    def translateBatch(self, srcBatch, bioBatch, featsBatch, tgtBatch,
                       guideData):
        batchSize = srcBatch[0].size(1)
        beamSize = self.opt.beam_size
        #  (1) run the encoder on the src
        encStates, context, backward_hids, wordEmb = self.model.encoder(
            srcBatch, bioBatch, featsBatch, guideData)
        srcLength = srcBatch[1]
        srcBatch = srcBatch[0]  # drop the lengths (they were only needed by the encoder)
        entsBatch = srcBatch.transpose(1, 0)  # (batch, srcLen) view of the source token ids

        decStates = self.model.decIniter(encStates[1])  # batch, dec_hidden

        #  (3) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.
        context = context.data.repeat(1, beamSize, 1)
        decStates = decStates.unsqueeze(0).data.repeat(1, beamSize, 1)
        att_vec = self.model.make_init_att(context)
        padMask = srcBatch.data.eq(s2s.Constants.PAD).transpose(
            0, 1).unsqueeze(0).repeat(beamSize, 1, 1).float()

        beam = [
            s2s.Beam(beamSize, srcLength[0][k], self.opt.cuda)
            for k in range(batchSize)
        ]
        batchIdx = list(range(batchSize))
        remainingSents = batchSize
        swaps = [0 for j in range(batchSize)]
        srcCur = [0 for j in range(batchSize)]
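        # swaps[j] flags that sentence j just emitted the special SS token, and
        # srcCur[j] tracks its position in the source; the two loops below then
        # feed the corresponding source token back in as the next decoder input.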

        for i in range(self.opt.max_sent_length):
            # Prepare decoder input.
            input = torch.stack([
                b.getCurrentState() for b in beam if not b.done
            ]).transpose(0, 1).contiguous().view(1, -1)
            cur = 0
            for j in range(batchSize):
                if not beam[j].done:
                    if input[0][cur].item() == s2s.Constants.SS:
                        swaps[j] = 1
                        srcCur[j] += 1
                    cur += 1
            cur = 0
            for j in range(batchSize):
                if not beam[j].done:
                    if swaps[j] == 1:
                        input[0][cur] = entsBatch[j][srcCur[j] - 1]
                    cur += 1
            swaps = [0 for j in range(batchSize)]
            g_outputs, c_outputs, copyGateOutputs, decStates, attn, att_vec, hiddens = \
                self.model.decoder(srcBatch, input, decStates, context,
                                   padMask.view(-1, padMask.size(2)), att_vec,
                                   True, False, backward_hids, wordEmb, srcCur)

            # g_outputs: 1 x (beam*batch) x numWords
            copyGateOutputs = copyGateOutputs.view(-1, 1)
            g_outputs = g_outputs.squeeze(0)
            g_out_prob = self.model.generator.forward(g_outputs) + 1e-8
            g_predict = torch.log(
                g_out_prob * ((1 - copyGateOutputs).expand_as(g_out_prob)))
            c_outputs = c_outputs.squeeze(0) + 1e-8
            c_predict = torch.log(c_outputs *
                                  (copyGateOutputs.expand_as(c_outputs)))
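            # Copy-gate mixing: with gate g = copyGateOutputs, vocabulary words get
            # (1 - g) * p_gen and source positions get g * p_copy; the 1e-8 floor
            # keeps the log finite when the gate saturates at 0 or 1.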

            # batch x beam x numWords
            wordLk = g_predict.view(beamSize, remainingSents,
                                    -1).transpose(0, 1).contiguous()
            copyLk = c_predict.view(beamSize, remainingSents,
                                    -1).transpose(0, 1).contiguous()
            attn = attn.view(beamSize, remainingSents,
                             -1).transpose(0, 1).contiguous()

            active = []
            father_idx = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], copyLk.data[idx],
                                       attn.data[idx]):
                    active += [b]
                    father_idx.append(
                        beam[b].prevKs[-1])  # backpointers: parent beam of each hypothesis

            if not active:
                break

            # to get the real father index
            real_father_idx = []
            for kk, idx in enumerate(father_idx):
                real_father_idx.append(idx * len(father_idx) + kk)

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {b: idx for idx, b in enumerate(active)}  # original batch position -> compact index

            def updateActive(t, rnnSize):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return view.index_select(1, activeIdx).view(*newSize)

            decStates = updateActive(decStates, self.dec_rnn_size)
            context = updateActive(context, self.enc_rnn_size)
            att_vec = updateActive(att_vec, self.enc_rnn_size)
            padMask = padMask.index_select(1, activeIdx)

            # set correct state for beam search
            previous_index = torch.stack(real_father_idx).transpose(
                0, 1).contiguous()
            decStates = decStates.view(-1, decStates.size(2)).index_select(
                0, previous_index.view(-1)).view(*decStates.size())
            att_vec = att_vec.view(-1, att_vec.size(1)).index_select(
                0, previous_index.view(-1)).view(*att_vec.size())

            remainingSents = len(active)

        # (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        allIsCopy, allCopyPosition = [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            valid_attn = srcBatch.data[:, b].ne(
                s2s.Constants.PAD).nonzero().squeeze(1)
            hyps, isCopy, copyPosition, attn = zip(
                *[beam[b].getHyp(k) for k in ks[:n_best]])
            attn = [a.index_select(1, valid_attn) for a in attn]
            allHyp += [hyps]
            allAttn += [attn]
            allIsCopy += [isCopy]
            allCopyPosition += [copyPosition]

        return allHyp, allScores, allIsCopy, allCopyPosition, allAttn, None
Example 3: translateBatch with insert/delete edit inputs
    def translateBatch(self, srcBatch, srcInsBatch, srcDelBatch, tgtBatch):
        batchSize = srcBatch[0].size(1)
        beamSize = self.opt.beam_size

        #  (1) run the encoder on the src
        encStates, context = self.model.encoder(srcBatch)
        srcBatch = srcBatch[0]  # drop the lengths (they were only needed by the encoder)
        # enc_ins_hidden = self.model.editEncoder(srcInsBatch).data.repeat(beamSize, 1)
        # enc_del_hidden = self.model.editEncoder(srcDelBatch).data.repeat(beamSize, 1)

        enc_ins_hidden = srcInsBatch[0].data.repeat(1, beamSize)
        enc_del_hidden = srcDelBatch[0].data.repeat(1, beamSize)
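        # Despite the names, enc_ins_hidden / enc_del_hidden hold raw token ids here
        # (the editEncoder calls above are commented out), tiled beam-major along
        # dim 1 so that eq(PAD) below can build the insert/delete masks.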

        decStates = self.model.decIniter(encStates[1])  # batch, dec_hidden

        #  (3) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.
        context = context.data.repeat(1, beamSize, 1)
        decStates = decStates.unsqueeze(0).data.repeat(1, beamSize, 1)
        att_vec = self.model.make_init_att(context)
        padMask = srcBatch.data.eq(s2s.Constants.PAD).transpose(
            0, 1).unsqueeze(0).repeat(beamSize, 1, 1).float()
        insMask = srcInsBatch[0].data.repeat(1, beamSize).eq(
            s2s.Constants.PAD).transpose(0, 1).float()
        delMask = srcDelBatch[0].data.repeat(1, beamSize).eq(
            s2s.Constants.PAD).transpose(0, 1).float()

        beam = [s2s.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]
        batchIdx = list(range(batchSize))
        remainingSents = batchSize

        for i in range(self.opt.max_sent_length):
            # Prepare decoder input.
            input = torch.stack([
                b.getCurrentState() for b in beam if not b.done
            ]).transpose(0, 1).contiguous().view(1, -1)
            # print(enc_ins_hidden.shape, input.shape, insMask.shape, padMask.shape)
            g_outputs, decStates, attn, att_vec = self.model.decoder(
                input, decStates, enc_ins_hidden, enc_del_hidden, context,
                padMask.view(-1, padMask.size(2)), att_vec, insMask, delMask)

            # g_outputs: 1 x (beam*batch) x numWords
            g_outputs = g_outputs.squeeze(0)
            g_out_prob = self.model.generator.forward(g_outputs)

            # batch x beam x numWords
            wordLk = g_out_prob.view(beamSize, remainingSents,
                                     -1).transpose(0, 1).contiguous()
            attn = attn.view(beamSize, remainingSents,
                             -1).transpose(0, 1).contiguous()

            active = []
            father_idx = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                    active += [b]
                    father_idx.append(
                        beam[b].prevKs[-1])  # backpointers: parent beam of each hypothesis

            if not active:
                break

            # to get the real father index
            real_father_idx = []
            for kk, idx in enumerate(father_idx):
                real_father_idx.append(idx * len(father_idx) + kk)

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {b: idx for idx, b in enumerate(active)}  # original batch position -> compact index

            def updateActiveIns(t, rnnSize):
                # select only the remaining active sentences
                view = t.data.view(rnnSize, -1, remainingSents)
                newSize = list(t.size())
                newSize[-1] = newSize[-1] * len(activeIdx) // remainingSents
                return view.index_select(2, activeIdx).view(*newSize)
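            # Unlike updateActive, the ins/del tensors are (seq_len, beam * batch),
            # so the still-active sentences are selected on the last axis after
            # exposing it as (seq_len, beam, remainingSents).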

            def updateActive(t, rnnSize):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return view.index_select(1, activeIdx).view(*newSize)

            # def updateActiveIns(t, rnnSize):
            #     # select only the remaining active sentences
            #     view = t.data.view(-1, remainingSents, rnnSize)
            #     newSize = list(t.size())
            #     newSize[-1] = newSize[-1] * len(activeIdx) // remainingSents
            #     return view.index_select(1, activeIdx).view(*newSize)

            decStates = updateActive(decStates, self.dec_rnn_size)
            context = updateActive(context, self.enc_rnn_size)
            att_vec = updateActive(att_vec, self.enc_rnn_size)
            enc_ins_hidden = updateActiveIns(enc_ins_hidden,
                                             enc_ins_hidden.size()[0])
            enc_del_hidden = updateActiveIns(enc_del_hidden,
                                             enc_del_hidden.size()[0])
            padMask = padMask.index_select(1, activeIdx)
            insMask = enc_ins_hidden.eq(s2s.Constants.PAD).transpose(
                0, 1).float()
            delMask = enc_del_hidden.eq(s2s.Constants.PAD).transpose(
                0, 1).float()
            # print(insMask.shape)

            # set correct state for beam search
            previous_index = torch.stack(real_father_idx).transpose(
                0, 1).contiguous()
            decStates = decStates.view(-1, decStates.size(2)).index_select(
                0, previous_index.view(-1)).view(*decStates.size())
            att_vec = att_vec.view(-1, att_vec.size(1)).index_select(
                0, previous_index.view(-1)).view(*att_vec.size())

            remainingSents = len(active)

        # (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            valid_attn = srcBatch.data[:, b].ne(
                s2s.Constants.PAD).nonzero().squeeze(1)
            hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
            attn = [a.index_select(1, valid_attn) for a in attn]
            allHyp += [hyps]
            allAttn += [attn]

        return allHyp, allScores, allAttn, None