Example #1
import torch

# `ibp` is the project-local interval-bound-propagation module used in these
# examples; it mirrors torch ops for IntervalBoundedTensor inputs.

def attention_pool(x, mask, layer):
    """Attention pooling.

    Args:
      x: batch of inputs, shape (B, n, h)
      mask: binary mask, shape (B, n)
      layer: Linear layer mapping h -> 1
    Returns:
      pooled version of x, shape (B, h)
    """
    attn_raw = layer(x).squeeze(2)  # (B, n, 1) -> (B, n)
    # Push padded positions to effectively -inf so they get ~0 attention weight.
    attn_raw = ibp.add(attn_raw, (1 - mask) * -1e20)
    attn_logsoftmax = ibp.log_softmax(attn_raw, 1)
    attn_probs = ibp.activation(torch.exp, attn_logsoftmax)  # (B, n)
    # Weighted sum over the sequence dimension.
    return ibp.bmm(attn_probs.unsqueeze(1),
                   x).squeeze(1)  # (B, 1, n) x (B, n, h) -> (B, h)
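
For reference, here is the same masked attention pooling on ordinary tensors, without the ibp wrappers. This is a minimal sketch; the name attention_pool_plain and the example shapes are illustrative, not from the original code:

import torch
import torch.nn as nn

def attention_pool_plain(x, mask, layer):
    attn_raw = layer(x).squeeze(2)               # (B, n)
    attn_raw = attn_raw + (1 - mask) * -1e20     # padded logits -> effectively -inf
    attn_probs = torch.softmax(attn_raw, dim=1)  # padded positions get ~0 weight
    return torch.bmm(attn_probs.unsqueeze(1), x).squeeze(1)  # (B, h)

# Example usage:
B, n, h = 4, 7, 32
x = torch.randn(B, n, h)
lengths = torch.tensor([7, 5, 3, 6])
mask = (torch.arange(n).unsqueeze(0) < lengths.unsqueeze(1)).float()  # (B, n)
pooled = attention_pool_plain(x, mask, nn.Linear(h, 1))               # (B, h)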
Example #2
    def forward(self, batch, compute_bounds=True, cert_eps=1.0):
        """Forward pass of DecompAttentionModel.

        Args:
          batch: A batch dict from an EntailmentDataset with the following keys:
            - prem: tensor of word vector indices for the premise, size (B, p, 1)
            - hypo: tensor of word vector indices for the hypothesis, size (B, h, 1)
            - prem_mask: binary mask over premise words (1 for real, 0 for pad), size (B, p)
            - hypo_mask: binary mask over hypothesis words (1 for real, 0 for pad), size (B, h)
            - prem_lengths: lengths of premises, size (B,)
            - hypo_lengths: lengths of hypotheses, size (B,)
          compute_bounds: If True, compute interval bounds and return an
            IntervalBoundedTensor as logits. Otherwise, use only the point values.
          cert_eps: float, scaling factor for the interval bounds.
        """
        def encode(sequence, mask):
            vecs = self.embs(sequence)
            if isinstance(vecs, ibp.DiscreteChoiceTensor):
                # Add the learned null vector to position 0 of both the point
                # values and every substitution choice.
                null = torch.zeros_like(vecs.val[0])
                null_choice = torch.zeros_like(vecs.choice_mat[0])
                null[0] = self.null
                null_choice[0, 0] = self.null
                vecs.val = vecs.val + null
                vecs.choice_mat = vecs.choice_mat + null_choice
            else:
                null = torch.zeros_like(vecs[0])
                null[0] = self.null
                vecs = vecs + null
            vecs = self.rotation(vecs)
            if isinstance(vecs, ibp.DiscreteChoiceTensor):
                # Collapse the discrete substitution choices into interval
                # bounds, scaled by cert_eps.
                vecs = vecs.to_interval_bounded(eps=cert_eps)
            # ReLU, then zero out padded positions.
            return ibp.activation(F.relu, vecs) * mask.unsqueeze(-1)

        if not compute_bounds:
            # Discard substitution choices and keep only the point values.
            batch['prem']['x'] = batch['prem']['x'].val
            batch['hypo']['x'] = batch['hypo']['x'].val
        prem_encoded = encode(batch['prem']['x'],
                              batch['prem']['mask'])  # (B, p, e)
        hypo_encoded = encode(batch['hypo']['x'],
                              batch['hypo']['mask'])  # (B, h, e)
        prem_weights = self.feedforward(
            prem_encoded) * batch['prem']['mask'].unsqueeze(-1)  # (B, p, 1)
        hypo_weights = self.feedforward(
            hypo_encoded) * batch['hypo']['mask'].unsqueeze(-1)  # (B, h, 1)
        attention = ibp.bmm(prem_weights, hypo_weights.permute(
            0, 2, 1))  # (B, p, 1) x (B, 1, h) -> (B, p, h)
        attention_mask = batch['prem']['mask'].unsqueeze(
            -1) * batch['hypo']['mask'].unsqueeze(1)
        # Mask out any (premise, hypothesis) pair that involves padding.
        attention_masked = ibp.add(attention, (1 - attention_mask) * -1e20)
        attended_prem = self.attend_on(hypo_encoded, prem_encoded,
                                       attention_masked)  # (B, p, 2e)
        attended_hypo = self.attend_on(prem_encoded, hypo_encoded,
                                       attention_masked.permute(0, 2, 1))  # (B, h, 2e)
        compared_prem = self.compare_ff(
            attended_prem) * batch['prem']['mask'].unsqueeze(-1)  # (B, p, hid)
        compared_hypo = self.compare_ff(
            attended_hypo) * batch['hypo']['mask'].unsqueeze(-1)  # (B, h, hid)
        prem_aggregate = ibp.pool(torch.sum, compared_prem, dim=1)  # (B, hid)
        hypo_aggregate = ibp.pool(torch.sum, compared_hypo, dim=1)  # (B, hid)
        aggregate = ibp.cat([prem_aggregate, hypo_aggregate],
                            dim=-1)  # (B, 2*hid)
        return self.output_layer(aggregate)  # (B,)
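
The attend_on method is not shown in this excerpt. Below is a speculative plain-PyTorch sketch consistent with the shape comments above: softmax over the masked scores, a weighted sum of the source sequence, concatenated with the target to give the 2e feature size. The name attend_on_plain and the argument names are assumptions, not the original implementation:

import torch
import torch.nn.functional as F

def attend_on_plain(source, target, attention):
    # source: (B, s, e); target: (B, t, e); attention: (B, t, s) masked scores
    probs = F.softmax(attention, dim=-1)          # normalize over source positions
    attended = torch.bmm(probs, source)           # (B, t, e)
    return torch.cat([target, attended], dim=-1)  # (B, t, 2e)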