def forward(self, batch, compute_bounds=True, cert_eps=1.0):
    """Run the convolutional text classifier on one batch.

    Args:
        batch: A batch dict from a TextClassificationDataset with the
            following keys:
            - x: tensor of word vector indices, size (B, n, 1)
            - mask: binary mask over words (1 for real, 0 for pad),
              size (B, n)
            - lengths: lengths of sequences, size (B,)
        compute_bounds: If True, compute the interval bounds and return an
            IntervalBoundedTensor as logits. Otherwise only the plain
            values (``batch['x'].val``) are fed through the network.
        cert_eps: Scaling factor for interval bounds of the input.

    Returns:
        The output of ``self.fc_output``; the inline shape notes give it
        as (B, 1).
    """
    tokens = batch['x'] if compute_bounds else batch['x'].val
    word_mask = batch['mask']
    seq_lengths = batch['lengths']

    embedded = self.embs(tokens)  # B, n, d
    # With early_ibp the discrete word choices are turned into interval
    # bounds before the input projection rather than after it.
    if self.early_ibp and isinstance(embedded, ibp.DiscreteChoiceTensor):
        embedded = embedded.to_interval_bounded(eps=cert_eps)
    if not self.no_wordvec_layer:
        embedded = self.linear_input(embedded)  # B, n, h
    # Anything still discrete at this point is converted now.
    if isinstance(embedded, ibp.DiscreteChoiceTensor):
        embedded = embedded.to_interval_bounded(eps=cert_eps)

    # ReLU on the word vectors only when the projection layer exists and
    # relu_wordvec is enabled.
    if not self.no_wordvec_layer and self.relu_wordvec:
        features = ibp.activation(F.relu, embedded)  # B, n, h
    else:
        features = embedded

    # Zero the padded positions, then move channels ahead of the sequence
    # axis for the convolution.
    conv_input = (features * word_mask.unsqueeze(-1)).permute(0, 2, 1)  # B, h, n
    conv_out = ibp.activation(F.relu, self.conv1(conv_input))  # B, h, n
    conv_masked = conv_out * word_mask.unsqueeze(1)  # B, h, n

    pool_kind = self.pool
    if pool_kind == 'mean':
        denom = seq_lengths.to(dtype=torch.float).view(-1, 1, 1)
        pooled = ibp.sum(conv_masked / denom, 2)  # B, h
    elif pool_kind == 'attn':
        pooled = attention_pool(
            conv_masked.permute(0, 2, 1), word_mask, self.attn_pool)  # B, h
    else:
        # Max pooling: zero-masking is safe because the ReLU output is
        # guaranteed to be >= 0.
        pooled = ibp.pool(torch.max, conv_masked, 2)  # B, h

    hidden = ibp.activation(
        F.relu, self.fc_hidden(self.dropout(pooled)))  # B, h
    return self.fc_output(self.dropout(hidden))  # B, 1
def forward(self, batch, compute_bounds=True, cert_eps=1.0):
    """Forward pass of BOWModel.

    Args:
        batch: A batch dict from a TextClassificationDataset with the
            following keys:
            - x: tensor of word vector indices, size (B, n, 1)
            - mask: binary mask over words (1 for real, 0 for pad),
              size (B, n)
            - lengths: lengths of sequences, size (B,)
        compute_bounds: If True, compute the interval bounds and return an
            IntervalBoundedTensor as logits. Otherwise only the plain
            values (``batch['x'].val``) are fed through the network.
        cert_eps: Scaling factor for interval bounds of the input.

    Returns:
        The output of ``self.linear_output``; the inline shape notes give
        it as (B, 1).
    """
    tokens = batch['x'] if compute_bounds else batch['x'].val
    word_mask = batch['mask']
    seq_lengths = batch['lengths']

    embedded = self.embs(tokens)  # B, n, d
    if not self.no_wordvec_layer:
        embedded = self.linear_input(embedded)  # B, n, h
    # Discrete word choices become interval bounds scaled by cert_eps.
    if isinstance(embedded, ibp.DiscreteChoiceTensor):
        embedded = embedded.to_interval_bounded(eps=cert_eps)

    # The ReLU is applied only when the input projection layer is in use.
    if not self.no_wordvec_layer:
        features = ibp.activation(F.relu, embedded)
    else:
        features = embedded
    masked = features * word_mask.unsqueeze(-1)  # B, n, h

    pool_kind = self.pool
    if pool_kind == 'mean':
        denom = seq_lengths.to(dtype=torch.float).view(-1, 1, 1)
        pooled = ibp.sum(masked / denom, 1)  # B, h
    elif pool_kind == 'attn':
        pooled = attention_pool(masked, word_mask, self.attn_pool)
    else:
        # Max pooling: zero-masking is safe because the ReLU output is
        # guaranteed to be >= 0.
        pooled = ibp.pool(torch.max, masked, 1)  # B, h

    hidden = ibp.activation(
        F.relu, self.linear_hidden(self.dropout(pooled)))  # B, h
    return self.linear_output(self.dropout(hidden))  # B, 1
def forward(self, batch, compute_bounds=True, cert_eps=1.0):
    """Forward pass of DecompAttentionModel.

    Args:
        batch: A batch dict (from an EntailmentDataset, per the original
            documentation) with two entries, ``'prem'`` and ``'hypo'``,
            each itself a dict holding:
            - x: word vector indices for the premise / hypothesis; the
              original notes give sizes (B, p, 1) and (B, h, 1)
            - mask: binary mask over words (1 for real, 0 for pad); the
              original notes give sizes (B, p) and (B, h)
            The batch is only read; it is never modified by this method.
        compute_bounds: If True, compute the interval bounds and return an
            IntervalBoundedTensor as logits. Otherwise only the plain
            values (``x.val``) are fed through the network.
        cert_eps: float, scaling factor for the interval bounds.

    Returns:
        The output of ``self.output_layer`` applied to the concatenated
        premise and hypothesis aggregates.
    """
    def encode(sequence, mask):
        """Embed, add the null vector, rotate, apply ReLU, then mask."""
        vecs = self.embs(sequence)
        if isinstance(vecs, ibp.DiscreteChoiceTensor):
            # self.null is written into index 0 of an otherwise-zero
            # per-example template, which is then broadcast-added to both
            # the values and the choice matrix.
            # NOTE(review): this updates `vecs` in place; presumably
            # self.embs returns a fresh object per call -- confirm.
            null = torch.zeros_like(vecs.val[0])
            null_choice = torch.zeros_like(vecs.choice_mat[0])
            null[0] = self.null
            null_choice[0, 0] = self.null
            vecs.val = vecs.val + null
            vecs.choice_mat = vecs.choice_mat + null_choice
        else:
            null = torch.zeros_like(vecs[0])
            null[0] = self.null
            vecs = vecs + null
        vecs = self.rotation(vecs)
        if isinstance(vecs, ibp.DiscreteChoiceTensor):
            vecs = vecs.to_interval_bounded(eps=cert_eps)
        return ibp.activation(F.relu, vecs) * mask.unsqueeze(-1)

    # Unwrap into locals instead of writing back into `batch`: overwriting
    # batch['prem']['x'] would make a second forward() on the same batch
    # either fail on `.val` or silently skip bound computation.
    prem_x = batch['prem']['x']
    hypo_x = batch['hypo']['x']
    if not compute_bounds:
        prem_x = prem_x.val
        hypo_x = hypo_x.val
    prem_mask = batch['prem']['mask']
    hypo_mask = batch['hypo']['mask']

    prem_encoded = encode(prem_x, prem_mask)  # (bXpXe)
    hypo_encoded = encode(hypo_x, hypo_mask)  # (bXhXe)
    prem_weights = self.feedforward(
        prem_encoded) * prem_mask.unsqueeze(-1)  # (bXpX1)
    hypo_weights = self.feedforward(
        hypo_encoded) * hypo_mask.unsqueeze(-1)  # (bXhX1)
    attention = ibp.bmm(
        prem_weights,
        hypo_weights.permute(0, 2, 1))  # (bXpX1) X (bX1Xh) => (bXpXh)

    # Pairs involving a padded word get a very large negative score.
    # NOTE(review): presumably so they vanish under a softmax inside
    # attend_on -- confirm against that helper.
    attention_mask = prem_mask.unsqueeze(-1) * hypo_mask.unsqueeze(1)
    attention_masked = ibp.add(attention, (1 - attention_mask) * -1e20)

    attended_prem = self.attend_on(
        hypo_encoded, prem_encoded, attention_masked)  # (bXpX2e)
    attended_hypo = self.attend_on(
        prem_encoded, hypo_encoded,
        attention_masked.permute(0, 2, 1))  # (bXhX2e)
    compared_prem = self.compare_ff(
        attended_prem) * prem_mask.unsqueeze(-1)  # (bXpXhid)
    compared_hypo = self.compare_ff(
        attended_hypo) * hypo_mask.unsqueeze(-1)  # (bXhXhid)
    prem_aggregate = ibp.pool(torch.sum, compared_prem, dim=1)  # (bXhid)
    hypo_aggregate = ibp.pool(torch.sum, compared_hypo, dim=1)  # (bXhid)
    aggregate = ibp.cat(
        [prem_aggregate, hypo_aggregate], dim=-1)  # (bX2hid)
    return self.output_layer(aggregate)  # (b)