def packed_sentence_tensor(self, size=50):
    """Flatten every sentence tensor in the batch, then pad and pack them.

    Sentences are visited abstract by abstract (``self.abstracts``) and, inside
    each abstract, in the order of its ``sentences``. Each sentence's
    ``tensor()`` result is wrapped in a ``Variable`` and cast to ``ftype``.

    Args:
        size: Passed through as the second argument of ``pad_and_pack``
            (presumably the padded length -- confirm against that helper).

    Returns:
        Whatever ``pad_and_pack`` returns for the flattened sentence list.
    """
    flat = []
    for abstract in self.abstracts:
        for sentence in abstract.sentences:
            flat.append(Variable(sentence.tensor()).type(ftype))

    return pad_and_pack(flat, size)
def train_batch(batch, s_encoder, r_encoder, classifier):
    """Build first / not-first examples for a batch and run the classifier.

    For every abstract yielded by ``batch.unpack_sentences`` and every position
    ``i`` except the last, two examples are generated from the sentences at and
    after ``i`` (the "right" side):

    * the sentence that really comes next (``right[0]``), labelled ``0``;
    * a randomly chosen later sentence (``right[1:]``), labelled ``1``.

    Each candidate is concatenated with a context vector (the previous two
    sentences, or zeros when they do not exist, plus the raw position index and
    the position scaled to 0..1) and paired with a shuffled copy of the right
    side, which is later encoded by ``r_encoder``.

    Args:
        batch: Object providing ``packed_sentence_tensor()`` and
            ``unpack_sentences(sents)``.
        s_encoder: Callable invoked as ``s_encoder(x, reorder)`` on the packed
            sentences.
        r_encoder: Callable invoked as ``r_encoder(rights, reorder)`` on the
            padded / packed shuffled right contexts.
        classifier: Callable applied to the stacked example vectors.

    Returns:
        Tuple ``(classifier(x), y)`` where ``y`` holds the 0 / 1 labels.

    Raises:
        ValueError: If no abstract in the batch has at least two sentences, so
            no training example can be built.
    """
    x, reorder = batch.packed_sentence_tensor()

    # Encode sentences.
    sents = s_encoder(x, reorder)

    # Generate x / y pairs.
    examples = []
    for ab in batch.unpack_sentences(sents):
        n_sents = len(ab)

        # A candidate needs at least one competitor to its right.
        if n_sents < 2:
            continue

        # Placeholder for "previous" sentences that do not exist; identical
        # for every position, so build it once per abstract.
        zeros = Variable(torch.zeros(ab.data.shape[1])).type(ftype)

        for i in range(n_sents - 1):
            right = ab[i:]

            # Previous 2 sentences.
            minus1 = ab[i - 1] if i > 0 else zeros
            minus2 = ab[i - 2] if i > 1 else zeros

            # Shuffle right.
            perm = torch.randperm(len(right)).type(itype)
            shuffled_right = right[perm]

            # Raw position index, 0 <-> 1 ratio.
            index = Variable(torch.Tensor([i])).type(ftype)
            ratio = Variable(torch.Tensor([i / (n_sents - 1)])).type(ftype)

            context = torch.cat([minus1, minus2, index, ratio])

            # Candidate + [n-1, n-2, 0-1]
            first = torch.cat([right[0], context])
            other = torch.cat([random.choice(right[1:]), context])

            # First / not-first.
            examples.append((first, shuffled_right, 0))
            examples.append((other, shuffled_right, 1))

    if not examples:
        # Without this guard the unpacking below fails with an opaque
        # "not enough values to unpack" message.
        raise ValueError(
            'train_batch needs an abstract with at least two sentences '
            'to build training examples.'
        )

    sents, rights, ys = zip(*examples)

    # Encode rights.
    rights, reorder = pad_and_pack(rights, 30)
    rights = r_encoder(rights, reorder)

    # <sent, right>
    x = torch.stack([torch.cat(pair) for pair in zip(sents, rights)])
    y = Variable(torch.LongTensor(ys)).type(itype)

    return classifier(x), y
def order_beam_search(ab, r_encoder, classifier, beam_size=100):
    """Order the sentences of ``ab`` with beam search.

    At step ``i`` every partial ordering in the beam is extended by each
    sentence it has not used yet. Each extension is scored by ``classifier`` on
    the candidate sentence concatenated with a context vector: the last two
    sentences already placed (zeros when they do not exist), the raw position
    index, the position scaled to 0..1, and the ``r_encoder`` encoding of the
    not-yet-placed sentences. The first component of each classifier output row
    is added to the path score, and the ``beam_size`` best paths survive.

    Args:
        ab: Indexable collection of sentence vectors (indexed with ints and
            with a ``LongTensor`` of indices; ``ab.data.shape[1]`` is read).
        r_encoder: Callable invoked as ``r_encoder(packed, reorder)``.
        classifier: Callable applied to the stacked candidate vectors.
        beam_size: Maximum number of partial orderings kept after each step.

    Returns:
        Tuple of indices into ``ab`` giving the highest-scoring ordering.
    """
    n_sents = len(ab)

    # With zero or one sentence there is only one possible ordering. Returning
    # it directly also avoids dividing by ``n_sents - 1 == 0`` below.
    if n_sents <= 1:
        return tuple(range(n_sents))

    # Placeholder for "previous" sentences that do not exist yet.
    zeros = Variable(torch.zeros(ab.data.shape[1])).type(ftype)

    beam = [((), 0)]

    for i in range(n_sents):

        # Raw position index, 0 <-> 1 ratio. Same for every path at step i.
        index = Variable(torch.Tensor([i])).type(ftype)
        ratio = Variable(torch.Tensor([i / (n_sents - 1)])).type(ftype)

        new_beam, x = [], []

        for order, score in beam:

            placed = set(order)
            right_idx = [j for j in range(n_sents) if j not in placed]

            # Right context.
            right = ab[torch.LongTensor(right_idx).type(itype)]

            # Previous 2 sentences, in predicted order.
            minus1 = ab[order[-1]] if i > 0 else zeros
            minus2 = ab[order[-2]] if i > 1 else zeros

            # Encoded right context.
            right_enc, reorder = pad_and_pack([right], 30)
            right_enc = r_encoder(right_enc, reorder)

            context = torch.cat([minus1, minus2, index, ratio, right_enc[0]])

            for r in right_idx:
                new_beam.append(((*order, r), score))
                x.append(torch.cat([ab[r], context]))

        y = classifier(torch.stack(x))

        # Update scores.
        new_beam = [
            (path, score + new_score.data[0])
            for (path, score), new_score in zip(new_beam, y)
        ]

        # Sort by score.
        new_beam = sorted(new_beam, key=lambda item: item[1], reverse=True)

        # Keep N highest scoring paths.
        beam = new_beam[:beam_size]

    return beam[0][0]
def forward(self, x, pad_size=30):
    """Encode each input sequence as a single vector via the LSTM.

    Args:
        x (list of Variable): Sequences to encode; handed unchanged to
            ``pad_and_pack``.
        pad_size (int): Second argument given to ``pad_and_pack``.

    Returns:
        The LSTM's final hidden state reshaped to one row per sequence and
        indexed with the ``reorder`` value returned by ``pad_and_pack``
        (presumably restoring the caller's original order -- confirm against
        that helper).
    """
    # Pad and pack before running the recurrent layer.
    packed, reorder = pad_and_pack(x, pad_size)

    # Only the final hidden state is kept; outputs and cell state are dropped.
    _, (hidden, _) = self.lstm(packed)

    # Move dim 1 of the hidden state to the front and flatten the rest into a
    # single row. Per the original author's note this concatenates the forward
    # and backward states -- assumes a bidirectional LSTM, TODO confirm.
    rows = hidden.data.shape[1]
    flat = hidden.transpose(0, 1).contiguous().view(rows, -1)

    return flat[reorder]
def forward(self, x, pad_size=30):
    """Encode the input sequences with the LSTM and regress a value from them.

    The original docstring describes this as encoding sentences into a single
    paragraph vector and predicting "KT" (presumably Kendall's tau -- confirm).

    Args:
        x: Sequences to encode; handed unchanged to ``pad_and_pack``.
        pad_size (int): Second argument given to ``pad_and_pack``.

    Returns:
        Output of ``self.out`` after five ReLU-activated linear layers, with
        all size-1 dimensions squeezed away.
    """
    # Pad and pack before running the recurrent layer.
    packed, reorder = pad_and_pack(x, pad_size)

    # Only the final hidden state is kept; outputs and cell state are dropped.
    _, (hidden, _) = self.lstm(packed)

    # One flattened row per sequence (forward + backward states per the
    # original author's note -- assumes a bidirectional LSTM, TODO confirm).
    rows = hidden.data.shape[1]
    y = hidden.transpose(0, 1).contiguous().view(rows, -1)
    y = y[reorder]

    # Five hidden layers, each followed by a ReLU.
    for layer in (self.lin1, self.lin2, self.lin3, self.lin4, self.lin5):
        y = F.relu(layer(y))

    y = self.out(y)

    # NOTE(review): a bare squeeze() also removes a batch dimension of size 1;
    # confirm callers never pass a single-item batch.
    return y.squeeze()
def order_greedy(ab, r_encoder, classifier):
    """Order the sentences of ``ab`` greedily, one position at a time.

    At each step every not-yet-placed sentence is scored by ``classifier`` on
    the candidate concatenated with a context vector: the last two sentences
    already placed (zeros when they do not exist), the raw position index, the
    position scaled to 0..1, and the ``r_encoder`` encoding of the remaining
    sentences. The candidate with the highest value in column 0 of the
    classifier output is placed next.

    Args:
        ab: Indexable collection of sentence vectors (indexed with ints and
            with a ``LongTensor`` of indices; ``ab.data.shape[1]`` is read).
        r_encoder: Callable invoked as ``r_encoder(packed, reorder)``.
        classifier: Callable applied to the stacked candidate vectors; its
            output is viewed as ``(n_candidates, 2)``.

    Returns:
        List of indices into ``ab`` in predicted order.
    """
    n_sents = len(ab)

    # With zero or one sentence there is only one possible ordering. Returning
    # it directly also avoids dividing by ``n_sents - 1 == 0`` below.
    if n_sents <= 1:
        return list(range(n_sents))

    # Placeholder for "previous" sentences that do not exist yet.
    zeros = Variable(torch.zeros(ab.data.shape[1])).type(ftype)

    order = []

    # Indices still to be placed, kept in ascending order.
    right_idx = list(range(n_sents))

    while len(order) < n_sents:

        i = len(order)

        # Right context.
        right = ab[torch.LongTensor(right_idx).type(itype)]

        # Previous 2 sentences *in the predicted order*. Indexing by input
        # position (i - 1, i - 2) would ignore the choices made so far and
        # disagree with order_beam_search.
        minus1 = ab[order[-1]] if i > 0 else zeros
        minus2 = ab[order[-2]] if i > 1 else zeros

        # Raw position index, 0 <-> 1 ratio.
        index = Variable(torch.Tensor([i])).type(ftype)
        ratio = Variable(torch.Tensor([i / (n_sents - 1)])).type(ftype)

        # Encoded right context.
        right_enc, reorder = pad_and_pack([right], 30)
        right_enc = r_encoder(right_enc, reorder)

        context = torch.cat([minus1, minus2, index, ratio, right_enc[0]])

        # Candidate sentences.
        x = torch.stack([torch.cat([sent, context]) for sent in right])

        preds = classifier(x).view(len(x), 2)
        preds = np.array(preds.data.tolist())

        # Column 0 is used as the "comes first" score; take the best candidate.
        best = int(np.argmax(preds[:, 0]))
        order.append(right_idx.pop(best))

    return order