def test_bmm():
    """Smoke-check ``ibp.bmm`` by printing its output for four input mixes.

    The cases are: both operands as zero-width intervals, both operands
    widened by 0.1 on each side, only the first operand widened, and only
    the second operand widened. For each case the value, lower bound and
    upper bound of the result are printed; nothing is asserted.
    """
    left = torch.tensor([[-1, 2], [3, 2], [-3, 1]],
                        dtype=torch.float).view(1, 3, 2)
    right = torch.tensor([[4, 5], [-4, -5]],
                         dtype=torch.float).view(1, 2, 2)

    # Zero-width intervals: lower and upper bound both equal the value.
    left_exact = ibp.IntervalBoundedTensor(left, left, left)
    right_exact = ibp.IntervalBoundedTensor(right, right, right)

    # Intervals of half-width 0.1 around each entry.
    left_bounded = ibp.IntervalBoundedTensor(
        left, left - torch.tensor(0.1), left + torch.tensor(0.1))
    right_bounded = ibp.IntervalBoundedTensor(
        right, right - torch.tensor(0.1), right + torch.tensor(0.1))

    cases = (
        ('ibp.bmm, exact:', left_exact, right_exact),
        ('ibp.bmm, bound both:', left_bounded, right_bounded),
        ('ibp.bmm, bound first:', left_bounded, right),
        ('ibp.bmm, bound second:', left, right_bounded),
    )
    for label, lhs, rhs in cases:
        product = ibp.bmm(lhs, rhs)
        print(label, product.val, product.lb, product.ub)
def attention_pool(x, mask, layer):
    """Pool a sequence into one vector using learned attention weights.

    Args:
      x: batch of inputs, shape (B, n, h)
      mask: binary mask, shape (B, n)
      layer: Linear layer mapping h -> 1
    Returns:
      pooled version of x, shape (B, h)
    """
    # One raw score per position: (B, n, 1) -> (B, n)
    scores = layer(x).squeeze(2)

    # Add a huge negative offset wherever mask == 0 so those positions
    # receive (numerically) zero weight after the softmax.
    pad_offset = (1 - mask) * -1e20
    scores = ibp.add(scores, pad_offset)

    # Softmax over positions, computed as exp(log_softmax).
    log_weights = ibp.log_softmax(scores, 1)
    weights = ibp.activation(torch.exp, log_weights)  # (B, n)

    # (B, 1, n) x (B, n, h) -> (B, 1, h) -> (B, h)
    pooled = ibp.bmm(weights.unsqueeze(1), x)
    return pooled.squeeze(1)
def forward(self, batch, compute_bounds=True, cert_eps=1.0):
    """
    Forward pass of DecompAttentionModel.

    Args:
      batch: A batch dict from an EntailmentDataset. The entries read here are:
        - batch['prem']['x']: premise input passed to ``self.embs``. When
          ``compute_bounds`` is False only its ``.val`` attribute is used.
        - batch['hypo']['x']: hypothesis input, handled the same way.
        - batch['prem']['mask']: binary mask over premise words
          (1 for real, 0 for pad), size (B, p)
        - batch['hypo']['mask']: binary mask over hypothesis words
          (1 for real, 0 for pad), size (B, h)
        ``batch`` is not modified by this method.
      compute_bounds: If True compute the interval bounds and return an
        IntervalBoundedTensor as logits. Otherwise just use the values.
      cert_eps: float, scaling factor for the interval bounds.
    Returns:
      The result of ``self.output_layer`` applied to the concatenated
      premise and hypothesis aggregates.
    """
    def encode(sequence, mask):
        """Embed, add the null vector at position 0, rotate, ReLU, mask."""
        vecs = self.embs(sequence)
        if isinstance(vecs, ibp.DiscreteChoiceTensor):
            # Add self.null to index 0 along the sequence axis of both the
            # value and the choice matrix; the offsets are built from the
            # first batch element and added to every element of the batch.
            null = torch.zeros_like(vecs.val[0])
            null_choice = torch.zeros_like(vecs.choice_mat[0])
            null[0] = self.null
            null_choice[0, 0] = self.null
            vecs.val = vecs.val + null
            vecs.choice_mat = vecs.choice_mat + null_choice
        else:
            null = torch.zeros_like(vecs[0])
            null[0] = self.null
            vecs = vecs + null
        vecs = self.rotation(vecs)
        if isinstance(vecs, ibp.DiscreteChoiceTensor):
            vecs = vecs.to_interval_bounded(eps=cert_eps)
        return ibp.activation(F.relu, vecs) * mask.unsqueeze(-1)

    prem_mask = batch['prem']['mask']
    hypo_mask = batch['hypo']['mask']

    # Keep the inputs in locals instead of writing `.val` back into `batch`:
    # overwriting the caller's dict would discard the bound information and
    # make a second forward pass on the same batch fail or silently differ.
    prem_x = batch['prem']['x']
    hypo_x = batch['hypo']['x']
    if not compute_bounds:
        prem_x = prem_x.val
        hypo_x = hypo_x.val

    prem_encoded = encode(prem_x, prem_mask)  # (bXpXe)
    hypo_encoded = encode(hypo_x, hypo_mask)  # (bXhXe)

    prem_weights = self.feedforward(
        prem_encoded) * prem_mask.unsqueeze(-1)  # (bXpX1)
    hypo_weights = self.feedforward(
        hypo_encoded) * hypo_mask.unsqueeze(-1)  # (bXhX1)
    attention = ibp.bmm(
        prem_weights,
        hypo_weights.permute(0, 2, 1))  # (bXpX1) X (bX1Xh) => (bXpXh)

    # A (premise, hypothesis) pair is valid only if both words are real;
    # every other pair gets a huge negative offset added to its score.
    attention_mask = prem_mask.unsqueeze(-1) * hypo_mask.unsqueeze(1)
    attention_masked = ibp.add(attention, (1 - attention_mask) * -1e20)

    attended_prem = self.attend_on(
        hypo_encoded, prem_encoded, attention_masked)  # (bXpX2e)
    attended_hypo = self.attend_on(
        prem_encoded, hypo_encoded,
        attention_masked.permute(0, 2, 1))  # (bXhX2e)

    compared_prem = self.compare_ff(
        attended_prem) * prem_mask.unsqueeze(-1)  # (bXpXhid)
    compared_hypo = self.compare_ff(
        attended_hypo) * hypo_mask.unsqueeze(-1)  # (bXhXhid)

    prem_aggregate = ibp.pool(torch.sum, compared_prem, dim=1)  # (bXhid)
    hypo_aggregate = ibp.pool(torch.sum, compared_hypo, dim=1)  # (bXhid)
    aggregate = ibp.cat([prem_aggregate, hypo_aggregate], dim=-1)  # (bX2hid)
    return self.output_layer(aggregate)  # (b)