@torch.no_grad()
def predictByGreedySearch(self, inputSeq, maxAnswerLength=32, showAttention=False, figsize=(12, 6)):
    """Answer ``inputSeq`` by greedy decoding with the encoder/decoder RNNs.

    The input string is passed through ``filter_sent``, tokenised with
    ``jieba.lcut`` and reduced to words present in ``self.word2id``.
    ``eosToken`` is appended before encoding. Decoding starts from
    ``sosToken``. At each step the single highest-scoring token is fed
    back in, until ``eosToken`` is produced or ``maxAnswerLength`` tokens
    have been generated.

    Args:
        inputSeq: Raw input sentence.
        maxAnswerLength: Maximum number of decoding steps.
        showAttention: If True, plot the decoder attention weights with
            one row per generated token.
        figsize: Matplotlib figure size used when ``showAttention`` is True.

    Returns:
        The generated answer joined into one string. The final token is
        removed only when it is ``eosToken``.
    """
    inputSeq = filter_sent(inputSeq)
    # Drop tokens the model has no id for.
    inputSeq = [w for w in jieba.lcut(inputSeq) if w in self.word2id]
    # Encoder input: token ids followed by eosToken; XLens counts the eosToken too.
    X = seq2id(self.word2id, inputSeq) + [eosToken]
    XLens = torch.tensor([len(X)], dtype=torch.int, device=self.device)
    X = torch.tensor([X], dtype=torch.long, device=self.device)

    d = int(self.encoderRNN.bidirectional) + 1
    hidden = torch.zeros(
        (d * self.encoderRNN.numLayers, 1, self.hiddenSize),
        dtype=torch.float32, device=self.device)
    encoderOutput, hidden = self.encoderRNN(X, XLens, hidden)
    # Keep the last decoderRNN.numLayers layers, one state per layer.
    # Stepping by d (not a hard-coded 2) keeps unidirectional encoders
    # correct. Assumes PyTorch's layer-major / direction-minor hidden
    # layout, so with d == 2 these are the forward-direction states.
    hidden = hidden[-d * self.decoderRNN.numLayers::d].contiguous()

    attentionRows = []
    Y = []
    decoderInput = torch.tensor([[sosToken]], dtype=torch.long, device=self.device)
    while decoderInput.item() != eosToken and len(Y) < maxAnswerLength:
        decoderOutput, hidden, decoderAttentionWeight = self.decoderRNN(
            decoderInput, hidden, encoderOutput)
        topv, topi = decoderOutput.topk(1)
        decoderInput = topi[:, :, 0]
        if showAttention:
            # One row of attention weights per generated token.
            attentionRows.append(
                decoderAttentionWeight.data.cpu().numpy().reshape(1, -1))
        Y.append(decoderInput.item())
    outputSeq = id2seq(self.id2word, Y)

    if showAttention and attentionRows:
        attentionArrs = np.vstack(attentionRows)
        fig = plt.figure(figsize=figsize)
        # Must be the integer 111; matplotlib rejects the string '111'.
        ax = fig.add_subplot(111)
        cax = ax.matshow(attentionArrs, cmap='bone')
        fig.colorbar(cax)
        # NOTE(review): columns are the positions of X (input ids followed by
        # eosToken). Unless seq2id prepends a start token, these labels are
        # shifted by one and the EOS column is unlabelled; confirm against seq2id.
        ax.set_xticklabels(['', '<SOS>'] + inputSeq)
        ax.set_yticklabels([''] + outputSeq)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
        plt.show()

    # Strip the last token only when it really is eosToken. If decoding was
    # cut off by maxAnswerLength, every token belongs to the answer.
    if Y and Y[-1] == eosToken:
        outputSeq = outputSeq[:-1]
    return ''.join(outputSeq)
@torch.no_grad()
def predictByBeamSearch(self, inputSeq, beamWidth=10, maxAnswerLength=32, alpha=0.7, isRandomChoose=False, allRandomChoose=False, improve=True, showInfo=False):
    """Answer ``inputSeq`` with beam search over the decoder.

    The input is cleaned with ``filter_sent``, tokenised with ``jieba.lcut``,
    reduced to in-vocabulary words and terminated with ``eosToken`` before
    encoding. The first decoding step keeps the ``beamWidth`` best first
    tokens. A hypothesis that emits ``eosToken`` is frozen, so the number of
    live hypotheses shrinks as answers finish. Decoder outputs are summed
    per hypothesis; they are presumably log-probabilities (TODO confirm).

    Args:
        inputSeq: Raw input sentence.
        beamWidth: Number of hypotheses started after the first step.
        maxAnswerLength: Maximum number of tokens per hypothesis.
        alpha: Length-normalisation exponent. Each summed score is divided
            by ``length ** alpha``.
        isRandomChoose: Sample the returned answer with softmax weights over
            the normalised scores instead of returning the best one.
        allRandomChoose: When ``improve`` is True, expand beams with
            ``self._random_pick_k_by_prob`` instead of top-k. Also implies
            random choice of the final answer.
        improve: If True, candidates are ranked by their step score alone.
            If False, they are ranked by cumulative score.
        showInfo: Print every candidate answer with its score/weight.

    Returns:
        The chosen answer: tokens up to (excluding) the first ``eosToken``.
    """
    outputSize = len(self.id2word)
    inputSeq = filter_sent(inputSeq)
    inputSeq = [w for w in jieba.lcut(inputSeq) if w in self.word2id]
    X = seq2id(self.word2id, inputSeq) + [eosToken]
    XLens = torch.tensor([len(X)], dtype=torch.int, device=self.device)
    X = torch.tensor([X], dtype=torch.long, device=self.device)

    d = int(self.encoderRNN.bidirectional) + 1
    hidden = torch.zeros(
        (d * self.encoderRNN.numLayers, 1, self.hiddenSize),
        dtype=torch.float32, device=self.device)
    encoderOutput, hidden = self.encoderRNN(X, XLens, hidden)
    # Keep the last decoderRNN.numLayers layers, one state per layer.
    # Stepping by d keeps unidirectional encoders correct.
    hidden = hidden[-d * self.decoderRNN.numLayers::d].contiguous()

    # Y[row]: token ids of one hypothesis. Prefilled with eosToken so that
    # never-written positions read as "end of answer".
    Y = np.full([beamWidth, maxAnswerLength], eosToken, dtype='int32')
    # prob[row, 0]: cumulative decoder score of that hypothesis.
    prob = np.zeros([beamWidth, 1], dtype='float32')

    def keepLive(decoderInput, hidden, encoderOutput, localRestId):
        # Drop hypotheses whose latest token is eosToken (they are finished).
        live = (decoderInput != eosToken).cpu().view(-1)
        return (decoderInput[live], hidden[:, live, :], encoderOutput[live],
                localRestId[live.numpy().astype('bool')])

    # Step 0: expand the single <SOS> hypothesis into beamWidth first tokens.
    decoderInput = torch.tensor([[sosToken]], dtype=torch.long, device=self.device)
    decoderOutput, hidden, _ = self.decoderRNN(decoderInput, hidden, encoderOutput)
    topv, topi = decoderOutput.topk(beamWidth)
    decoderInput = topi.view(beamWidth, 1)
    Y[:, 0] = decoderInput.view(-1).cpu().numpy()
    prob += topv.view(beamWidth, 1).data.cpu().numpy()
    hidden = hidden.expand(-1, beamWidth, -1).contiguous()
    encoderOutput = encoderOutput.expand(beamWidth, -1, -1)
    # localRestId[j]: row of Y / prob owned by the j-th live hypothesis.
    localRestId = np.arange(beamWidth, dtype='int32')
    # A first token of eosToken finishes that hypothesis immediately.
    decoderInput, hidden, encoderOutput, localRestId = keepLive(
        decoderInput, hidden, encoderOutput, localRestId)
    beamWidth = len(localRestId)

    for i in range(1, maxAnswerLength):
        if beamWidth < 1:
            break
        decoderOutput, hidden, _ = self.decoderRNN(decoderInput, hidden, encoderOutput)
        # Flattened candidate index = parentIndex * outputSize + tokenId.
        if improve:
            stepScores = decoderOutput.view(-1, 1)
            if allRandomChoose:
                topv, topi = self._random_pick_k_by_prob(stepScores, k=beamWidth)
            else:
                topv, topi = stepScores.topk(beamWidth, dim=0)
        else:
            cumulative = torch.tensor(
                prob[localRestId], dtype=torch.float32,
                device=self.device).unsqueeze(2) + decoderOutput
            topv, topi = cumulative.view(-1, 1).topk(beamWidth, dim=0)
        decoderInput = topi % outputSize
        idFrom = topi.cpu().view(-1).numpy() // outputSize
        parentRows = localRestId[idFrom]
        Y[localRestId, :i + 1] = np.hstack(
            [Y[parentRows, :i], decoderInput.cpu().numpy()])
        selected = topv.data.cpu().numpy()
        if improve:
            prob[localRestId] = prob[parentRows] + selected
        else:
            # topv already includes the parent's cumulative score; adding
            # prob[parentRows] again would double-count the history.
            prob[localRestId] = selected
        hidden = hidden[:, idFrom, :]
        decoderInput, hidden, encoderOutput, localRestId = keepLive(
            decoderInput, hidden, encoderOutput, localRestId)
        beamWidth = len(localRestId)

    lens = [row.index(eosToken) if eosToken in row else maxAnswerLength
            for row in Y.tolist()]
    ans = [''.join(id2seq(self.id2word, row[:n])) for row, n in zip(Y, lens)]
    # Length-normalised scores. An empty hypothesis gets -inf explicitly
    # instead of dividing by zero.
    scores = [prob[row, 0] / np.power(n, alpha) if n > 0 else -np.inf
              for row, n in enumerate(lens)]

    if isRandomChoose or allRandomChoose:
        scoreArr = np.asarray(scores, dtype=np.float64)
        best = scoreArr.max()
        if np.isfinite(best):
            # Max-shifted softmax: same weights, without exp underflow.
            weights = np.exp(scoreArr - best)
        else:
            # Every hypothesis is empty: fall back to uniform weights.
            weights = np.ones_like(scoreArr)
        weights = (weights / weights.sum()).tolist()
        if showInfo:
            for answer, weight in zip(ans, weights):
                print((answer, weight))
        return random_pick(ans, weights)

    ansAndProb = sorted(zip(ans, scores), key=lambda x: x[1], reverse=True)
    if showInfo:
        for item in ansAndProb:
            print(item)
    return ansAndProb[0][0]