def attention_pool(x, mask, layer):
    """Attention pooling

    Args:
      x: batch of inputs, shape (B, n, h)
      mask: binary mask, shape (B, n)
      layer: Linear layer mapping h -> 1
    Returns:
      pooled version of x, shape (B, h)
    """
    attn_raw = layer(x).squeeze(2)  # B, n, 1 -> B, n
    attn_raw = ibp.add(attn_raw, (1 - mask) * -1e20)
    attn_logsoftmax = ibp.log_softmax(attn_raw, 1)
    attn_probs = ibp.activation(torch.exp, attn_logsoftmax)  # B, n
    return ibp.bmm(attn_probs.unsqueeze(1), x).squeeze(1)  # B, 1, n x B, n, h -> B, h
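# For reference, a minimal plain-PyTorch sketch of the same masked-softmax pooling,
# with the ibp.* wrappers replaced by their torch equivalents. It assumes the torch /
# torch.nn.functional (F) imports already used in this file; the function and variable
# names below are illustrative, not part of this codebase.
def attention_pool_plain(x, mask, layer):
    """Plain-tensor version: masked softmax over positions, then a weighted sum."""
    attn_raw = layer(x).squeeze(2)                 # (B, n, 1) -> (B, n)
    attn_raw = attn_raw + (1 - mask) * -1e20       # push padded positions toward -inf
    attn_probs = F.softmax(attn_raw, dim=1)        # (B, n), each row sums to 1
    return torch.bmm(attn_probs.unsqueeze(1), x).squeeze(1)  # (B, 1, n) @ (B, n, h) -> (B, h)

# Usage sketch (hypothetical shapes):
#   B, n, h = 4, 7, 32
#   x = torch.randn(B, n, h)
#   mask = (torch.arange(n) < 5).float().expand(B, n)
#   pooled = attention_pool_plain(x, mask, torch.nn.Linear(h, 1))  # shape (B, 32)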
def forward(self, batch, compute_bounds=True, cert_eps=1.0):
    """Forward pass of DecompAttentionModel.

    Args:
      batch: A batch dict from an EntailmentDataset with the following keys:
        - prem: tensor of word vector indices for premise (B, p, 1)
        - hypo: tensor of word vector indices for hypothesis (B, h, 1)
        - prem_mask: binary mask over premise words (1 for real, 0 for pad), size (B, p)
        - hypo_mask: binary mask over hypothesis words (1 for real, 0 for pad), size (B, h)
        - prem_lengths: lengths of premises, size (B,)
        - hypo_lengths: lengths of hypotheses, size (B,)
      compute_bounds: If True, compute the interval bounds and return an
        IntervalBoundedTensor as logits. Otherwise just use the values.
      cert_eps: float, scaling factor for the interval bounds.
    """
    def encode(sequence, mask):
        vecs = self.embs(sequence)
        if isinstance(vecs, ibp.DiscreteChoiceTensor):
            null = torch.zeros_like(vecs.val[0])
            null_choice = torch.zeros_like(vecs.choice_mat[0])
            null[0] = self.null
            null_choice[0, 0] = self.null
            vecs.val = vecs.val + null
            vecs.choice_mat = vecs.choice_mat + null_choice
        else:
            null = torch.zeros_like(vecs[0])
            null[0] = self.null
            vecs = vecs + null
        vecs = self.rotation(vecs)
        if isinstance(vecs, ibp.DiscreteChoiceTensor):
            vecs = vecs.to_interval_bounded(eps=cert_eps)
        return ibp.activation(F.relu, vecs) * mask.unsqueeze(-1)

    if not compute_bounds:
        batch['prem']['x'] = batch['prem']['x'].val
        batch['hypo']['x'] = batch['hypo']['x'].val
    prem_encoded = encode(batch['prem']['x'], batch['prem']['mask'])  # (bXpXe)
    hypo_encoded = encode(batch['hypo']['x'], batch['hypo']['mask'])  # (bXhXe)
    prem_weights = self.feedforward(prem_encoded) * batch['prem']['mask'].unsqueeze(-1)  # (bXpX1)
    hypo_weights = self.feedforward(hypo_encoded) * batch['hypo']['mask'].unsqueeze(-1)  # (bXhX1)
    attention = ibp.bmm(prem_weights, hypo_weights.permute(0, 2, 1))  # (bXpX1) X (bX1Xh) => (bXpXh)
    attention_mask = batch['prem']['mask'].unsqueeze(-1) * batch['hypo']['mask'].unsqueeze(1)
    attention_masked = ibp.add(attention, (1 - attention_mask) * -1e20)
    attended_prem = self.attend_on(hypo_encoded, prem_encoded, attention_masked)  # (bXpX2e)
    attended_hypo = self.attend_on(prem_encoded, hypo_encoded, attention_masked.permute(0, 2, 1))  # (bXhX2e)
    compared_prem = self.compare_ff(attended_prem) * batch['prem']['mask'].unsqueeze(-1)  # (bXpXhid)
    compared_hypo = self.compare_ff(attended_hypo) * batch['hypo']['mask'].unsqueeze(-1)  # (bXhXhid)
    prem_aggregate = ibp.pool(torch.sum, compared_prem, dim=1)  # (bXhid)
    hypo_aggregate = ibp.pool(torch.sum, compared_hypo, dim=1)  # (bXhid)
    aggregate = ibp.cat([prem_aggregate, hypo_aggregate], dim=-1)  # (bX2hid)
    return self.output_layer(aggregate)  # (b)
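# The attend_on helper is not shown above. Under the standard decomposable-attention
# formulation (Parikh et al., 2016), it presumably soft-aligns one sentence against the
# other and concatenates each word with its attended context, which would match the
# (bXpX2e) / (bXhX2e) shapes annotated above. A hypothetical plain-tensor sketch of that
# step (names are illustrative, not from this codebase):
def attend_on_plain(source, target, attention):
    """Soft-align `target` against `source` using masked attention logits.

    source:    (B, s, e)  encoded words of the other sentence
    target:    (B, t, e)  encoded words of this sentence
    attention: (B, t, s)  attention logits, with padded positions already set to -1e20
    Returns:   (B, t, 2e) each target word concatenated with its attended context
    """
    attn_probs = F.softmax(attention, dim=-1)      # normalize over source positions
    attended = torch.bmm(attn_probs, source)       # (B, t, s) @ (B, s, e) -> (B, t, e)
    return torch.cat([target, attended], dim=-1)   # (B, t, 2e)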