Пример #1
0
    def forward(self, batch, compute_bounds=True, cert_eps=1.0):
        """Forward pass of BOWModel.

        Args:
          batch: A batch dict. For each side in ('prem', 'hypo') this method
            reads:
              - batch[side]['x']: the input handed to self.embs. When
                compute_bounds is False, its `.val` attribute is used
                instead of the object itself.
              - batch[side]['mask']: mask multiplied (after unsqueeze(-1))
                into the per-word encodings before they are summed;
                presumably binary, 1 for real words and 0 for padding.
          compute_bounds: If True compute the interval bounds and return an
            IntervalBoundedTensor as logits. Otherwise just use the values.
          cert_eps: float, scaling factor for the interval bounds.

        Returns:
          The result of self.layers applied to the concatenation (last axis)
          of the pooled premise and hypothesis encodings.

        Note:
          `batch` is left untouched. Overwriting batch[side]['x'] with its
          `.val` would make a later call on the same batch either fail
          (a plain value has no `.val`) or silently run without bounds.
        """
        def encode(sequence, mask):
            # Embed, rotate, and (if still a discrete-choice object) relax to
            # interval bounds scaled by cert_eps.
            vecs = self.embs(sequence)
            vecs = self.rotation(vecs)
            if isinstance(vecs, ibp.DiscreteChoiceTensor):
                vecs = vecs.to_interval_bounded(eps=cert_eps)
            z1 = ibp.activation(F.relu, vecs)
            # Zero out masked positions, then sum over the second-to-last axis.
            z1_masked = z1 * mask.unsqueeze(-1)
            return ibp.sum(z1_masked, -2)

        # Work on local references so the caller's batch dict is not mutated.
        prem_x = batch['prem']['x']
        hypo_x = batch['hypo']['x']
        if not compute_bounds:
            prem_x = prem_x.val
            hypo_x = hypo_x.val
        prem_encoded = encode(prem_x, batch['prem']['mask'])
        hypo_encoded = encode(hypo_x, batch['hypo']['mask'])
        input_encoded = ibp.cat([prem_encoded, hypo_encoded], -1)
        logits = self.layers(input_encoded)
        return logits
Пример #2
0
    def query(self, dataset, device, batch_size=1, return_bounds=False):
        """Run the model over every batch of a Dataset without gradients.

        Args:
          dataset: object providing get_loader(batch_size).
          device: device passed to data_util.dict_batch_to_device for each
            batch.
          batch_size: batch size handed to dataset.get_loader (default=1).
          return_bounds: forwarded to self.forward as compute_bounds.

        Returns:
          A pair (logits, gold): the per-batch forward outputs and the
          per-batch batch['y'] entries, each joined with ibp.cat along dim 0.
        """
        logit_chunks, label_chunks = [], []
        loader = dataset.get_loader(batch_size)
        with torch.no_grad():
            for raw_batch in loader:
                on_device = data_util.dict_batch_to_device(raw_batch, device)
                batch_logits = self.forward(on_device,
                                            compute_bounds=return_bounds)
                logit_chunks.append(batch_logits)
                label_chunks.append(on_device['y'])
        all_logits = ibp.cat(logit_chunks, dim=0)
        all_labels = ibp.cat(label_chunks, dim=0)
        return all_logits, all_labels
Пример #3
0
 def attend_on(self, source, target, attention):
     """Append to each target vector an attention-weighted mix of source.

     Args:
       - source: (bXsXe)
       - target: (bXtXe)
       - attention: (bXtXs)

     Returns:
       target concatenated on the last axis with the attended source
       vectors.

     NOTE(review): the softmax runs over dim 1, which the shapes above
     label as the t axis; confirm that this, rather than the s axis, is
     the intended normalization.
     """
     # exp(log_softmax(.)) gives the normalized weights.
     weights = ibp.activation(torch.exp, ibp.log_softmax(attention, 1))
     context = ibp.matmul_nneg(weights, source)  # (bXtXe)
     return ibp.cat([target, context], dim=-1)
Пример #4
0
    def forward(self, batch, compute_bounds=True, cert_eps=1.0):
        """Forward pass of DecompAttentionModel.

        Args:
          batch: A batch dict. For each side in ('prem', 'hypo') this method
            reads:
              - batch[side]['x']: the input handed to self.embs. When
                compute_bounds is False, its `.val` attribute is used
                instead of the object itself.
              - batch[side]['mask']: mask over words (presumably binary,
                1 for real and 0 for pad), used to zero encodings, weights
                and compared vectors and to build the attention mask.
          compute_bounds: If True compute the interval bounds and return an
            IntervalBoundedTensor as logits. Otherwise just use the values.
          cert_eps: float, scaling factor for the interval bounds.

        Returns:
          The result of self.output_layer applied to the concatenated
          premise and hypothesis aggregates.

        Note:
          `batch` is left untouched. Overwriting batch[side]['x'] with its
          `.val` would make a later call on the same batch either fail
          (a plain value has no `.val`) or silently run without bounds.
        """
        def encode(sequence, mask):
            vecs = self.embs(sequence)
            # Add self.null at index 0 of the second axis (the first word
            # position, per the (bXpXe) shape comments below) of every
            # example; `null` is shaped like one example and broadcasts.
            if isinstance(vecs, ibp.DiscreteChoiceTensor):
                null = torch.zeros_like(vecs.val[0])
                null_choice = torch.zeros_like(vecs.choice_mat[0])
                null[0] = self.null
                null_choice[0, 0] = self.null
                vecs.val = vecs.val + null
                vecs.choice_mat = vecs.choice_mat + null_choice
            else:
                null = torch.zeros_like(vecs[0])
                null[0] = self.null
                vecs = vecs + null
            vecs = self.rotation(vecs)
            if isinstance(vecs, ibp.DiscreteChoiceTensor):
                vecs = vecs.to_interval_bounded(eps=cert_eps)
            return ibp.activation(F.relu, vecs) * mask.unsqueeze(-1)

        # Work on local references so the caller's batch dict is not mutated.
        prem_x = batch['prem']['x']
        hypo_x = batch['hypo']['x']
        if not compute_bounds:
            prem_x = prem_x.val
            hypo_x = hypo_x.val
        prem_mask = batch['prem']['mask']
        hypo_mask = batch['hypo']['mask']

        prem_encoded = encode(prem_x, prem_mask)  # (bXpXe)
        hypo_encoded = encode(hypo_x, hypo_mask)  # (bXhXe)
        prem_weights = self.feedforward(
            prem_encoded) * prem_mask.unsqueeze(-1)  # (bXpX1)
        hypo_weights = self.feedforward(
            hypo_encoded) * hypo_mask.unsqueeze(-1)  # (bXhX1)
        attention = ibp.bmm(prem_weights, hypo_weights.permute(
            0, 2, 1))  # (bXpX1) X (bX1Xh) => (bXpXh)
        # Pairs involving a masked position get a huge negative score so
        # they vanish after the softmax inside attend_on.
        attention_mask = prem_mask.unsqueeze(-1) * hypo_mask.unsqueeze(1)
        attention_masked = ibp.add(attention, (1 - attention_mask) * -1e20)
        attended_prem = self.attend_on(hypo_encoded, prem_encoded,
                                       attention_masked)  # (bXpX2e)
        attended_hypo = self.attend_on(prem_encoded, hypo_encoded,
                                       attention_masked.permute(0, 2,
                                                                1))  # (bXhX2e)
        compared_prem = self.compare_ff(
            attended_prem) * prem_mask.unsqueeze(-1)  # (bXpXhid)
        compared_hypo = self.compare_ff(
            attended_hypo) * hypo_mask.unsqueeze(-1)  # (bXhXhid)
        prem_aggregate = ibp.pool(torch.sum, compared_prem, dim=1)  # (bXhid)
        hypo_aggregate = ibp.pool(torch.sum, compared_hypo, dim=1)  # (bXhid)
        aggregate = ibp.cat([prem_aggregate, hypo_aggregate],
                            dim=-1)  # (bX2hid)
        return self.output_layer(aggregate)  # (b)
Пример #5
0
 def cal_h_cat_c(h_left, h_right, c_left, c_right):  # (n, d)
     """Combine two children's hidden and cell states.

     `self` comes from the enclosing scope. The children's hidden states
     are concatenated on the last axis; a sigmoid of self.U_f on that
     concatenation yields two gates (first and second self.h_size columns)
     that weight c_left and c_right respectively.

     Returns:
       (self.U_iou(concatenated hidden states), gated sum of the cells).
     """
     joined = ibp.cat([h_left, h_right], dim=-1)
     gates = ibp.activation(th.sigmoid, self.U_f(joined))  # (n, 2 * d)
     left_gate = gates[:, :self.h_size]
     right_gate = gates[:, self.h_size:]
     cell = left_gate * c_left + right_gate * c_right
     return self.U_iou(joined), cell
Пример #6
0
    def reduce_func_dp(self, nodes):
        """Combine each node's two children for every perturbation budget.

        Reads "h" and "c" (through the `get` helper) and "unk_mask" from
        ``nodes.mailbox``. For every budget triple it merges the results of
        all ways of splitting that budget between the two children, and it
        also builds auxiliary h/c states indexed by the first budget axis.

        Args:
          nodes: node batch exposing a ``mailbox`` mapping (presumably a DGL
            NodeBatch passed to a reduce function -- TODO confirm).

        Returns:
          dict holding "iou", "c", "auxh" and "auxc" (stored through the
          `add` helper) plus "unk_mask".
        """
        h = get(nodes.mailbox, "h")  # (n, 2, Del, Ins, Sub, d)
        c = get(nodes.mailbox, "c")
        unk_mask = nodes.mailbox["unk_mask"]  # (n, 2)
        # if both children can be deleted, then the parent can be deleted
        # (the parent's value is the sum of the two children's entries, else 0)
        new_unk_mask = th.where((unk_mask[:, 0] > 0) & (unk_mask[:, 1] > 0), th.sum(unk_mask, 1),
                                th.zeros_like(unk_mask[:, 0]))

        def cal_h_cat_c(h_left, h_right, c_left, c_right):  # (n, d)
            """Return (U_iou of the concatenated h, gated sum of the two cells)."""
            h_cat = ibp.cat([h_left, h_right], dim=-1)
            f_cat = ibp.activation(th.sigmoid, self.U_f(h_cat))  # (n, 2 * d)
            # First h_size columns gate the left cell, the rest gate the right cell.
            c = f_cat[:, :self.h_size] * c_left + f_cat[:, self.h_size:] * c_right
            return self.U_iou(h_cat), c

        # (n, Del, Ins, Sub, 3 * d)
        new_iou = []
        # (n, Del, Ins, Sub, d)
        new_c = []
        # Outer loop: every total budget triple. Inner loop: every split of
        # that budget where the left child gets `deltas_left` and the right
        # child gets the remainder.
        for deltas in itertools.product(*self.deltas_p1_ranges):
            deltas_ranges = [range(x + 1) for x in deltas]
            # ibp.merge is first called with None as the accumulator.
            # NOTE(review): presumably merge joins interval bounds -- confirm.
            piece = None
            for deltas_left in itertools.product(*deltas_ranges):
                deltas_right = [y - x for (x, y) in zip(deltas_left, deltas)]
                tmp = cal_h_cat_c(h[:, 0, deltas_left[0], deltas_left[1], deltas_left[2], :],
                                  h[:, 1, deltas_right[0], deltas_right[1], deltas_right[2], :],
                                  c[:, 0, deltas_left[0], deltas_left[1], deltas_left[2], :],
                                  c[:, 1, deltas_right[0], deltas_right[1], deltas_right[2], :])
                piece = ibp.merge(piece, tmp)
            # piece[0] is the merged iou, piece[1] the merged cell; add an
            # axis at dim 1 so the per-budget results can be concatenated.
            new_iou.append(piece[0].unsqueeze(1))
            new_c.append(piece[1].unsqueeze(1))

        auxh = []
        auxc = []
        # For each index delta0 on the first budget axis: where one child's
        # unk_mask entry lies in (0, delta0], take the OTHER child's state at
        # first-axis index delta0 - unk_mask; elsewhere use
        # IntervalBoundedTensor.bottom(...).
        # NOTE(review): presumably models dropping one child entirely and
        # charging its size to the first budget -- confirm.
        for delta0 in range(self.deltas_p1[0]):
            inds = th.arange(unk_mask.shape[0])
            # Clamped copies are only used as indices; rows outside the
            # valid range are discarded by the where() conditions below.
            clip_unk_mask0 = th.clamp(unk_mask[:, 0], 1, delta0).long()
            clip_unk_mask1 = th.clamp(unk_mask[:, 1], 1, delta0).long()
            aux_list = []
            # Same construction for h (index 0 of aux_list) and c (index 1).
            for x in [h, c]:
                aux_l = ibp.where(((unk_mask[:, 0] > 0) & (unk_mask[:, 0] <= delta0)).view(-1, 1, 1, 1),
                                  # n, Ins, Sub, d
                                  x[inds, 1, delta0 - clip_unk_mask0],
                                  ibp.IntervalBoundedTensor.bottom(
                                      (h.shape[0], self.deltas_p1[1], self.deltas_p1[2], self.h_size),
                                      device=self.device))
                aux_r = ibp.where(((unk_mask[:, 1] > 0) & (unk_mask[:, 1] <= delta0)).view(-1, 1, 1, 1),
                                  # n, Ins, Sub, d
                                  x[inds, 0, delta0 - clip_unk_mask1],
                                  ibp.IntervalBoundedTensor.bottom(
                                      (h.shape[0], self.deltas_p1[1], self.deltas_p1[2], self.h_size),
                                      device=self.device))
                aux_list.append(aux_l.merge(aux_r))

            auxh.append(aux_list[0].unsqueeze(dim=1))
            auxc.append(aux_list[1].unsqueeze(dim=1))

        # Stack the per-delta0 pieces along dim 1.
        auxh = ibp.cat(auxh, dim=1)
        auxc = ibp.cat(auxc, dim=1)
        # Unflatten the budget-triple axis back into (Del, Ins, Sub).
        new_iou = ibp.cat(new_iou, dim=1).view(-1, self.deltas_p1[0], self.deltas_p1[1], self.deltas_p1[2],
                                               self.h_size * 3)
        new_c = ibp.cat(new_c, dim=1).view(-1, self.deltas_p1[0], self.deltas_p1[1], self.deltas_p1[2], self.h_size)
        ret = {}
        add(ret, "iou", new_iou)
        add(ret, "c", new_c)
        add(ret, "auxh", auxh)
        add(ret, "auxc", auxc)
        ret["unk_mask"] = new_unk_mask
        return ret