def forward(self, batch, compute_bounds=True, cert_eps=1.0):
    """
    Forward pass of BOWModel.

    Args:
      batch: A batch dict from an EntailmentDataset with the following keys:
        - prem: tensor of word vector indices for premise (B, p, 1)
        - hypo: tensor of word vector indices for hypothesis (B, h, 1)
        - prem_mask: binary mask over premise words (1 for real, 0 for pad), size (B, p)
        - hypo_mask: binary mask over hypothesis words (1 for real, 0 for pad), size (B, h)
        - prem_lengths: lengths of premises, size (B,)
        - hypo_lengths: lengths of hypotheses, size (B,)
      compute_bounds: If True, compute the interval bounds and return an
        IntervalBoundedTensor as logits. Otherwise just use the values.
      cert_eps: float, scaling factor for the interval bounds.
    """
    def encode(sequence, mask):
        vecs = self.embs(sequence)
        vecs = self.rotation(vecs)
        if isinstance(vecs, ibp.DiscreteChoiceTensor):
            vecs = vecs.to_interval_bounded(eps=cert_eps)
        # ReLU, zero out padding positions, then sum-pool over the sequence.
        z1 = ibp.activation(F.relu, vecs)
        z1_masked = z1 * mask.unsqueeze(-1)
        z1_pooled = ibp.sum(z1_masked, -2)
        return z1_pooled

    if not compute_bounds:
        batch['prem']['x'] = batch['prem']['x'].val
        batch['hypo']['x'] = batch['hypo']['x'].val
    prem_encoded = encode(batch['prem']['x'], batch['prem']['mask'])
    hypo_encoded = encode(batch['hypo']['x'], batch['hypo']['mask'])
    input_encoded = ibp.cat([prem_encoded, hypo_encoded], -1)
    logits = self.layers(input_encoded)
    return logits
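# A minimal plain-torch sketch (no interval bounds) of the masked sum-pooling
# performed by `encode` above; all shapes and values are illustrative
# assumptions, not part of the model.
def _bow_pooling_sketch():
    import torch
    import torch.nn.functional as F
    B, p, e = 2, 4, 3
    vecs = torch.randn(B, p, e)               # word vectors after rotation
    mask = torch.tensor([[1., 1., 0., 0.],
                         [1., 1., 1., 1.]])   # (B, p): 1 = real token, 0 = pad
    z1 = F.relu(vecs)                         # same nonlinearity as above
    z1_masked = z1 * mask.unsqueeze(-1)       # zero out padded positions
    return z1_masked.sum(dim=-2)              # (B, e) bag-of-words encoding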
def query(self, dataset, device, batch_size=1, return_bounds=False):
    """Query the model on a Dataset.

    Args:
      dataset: a Dataset.
      device: torch device.
      batch_size: batch size (default=1).
      return_bounds: if True, also compute interval bounds on the logits.

    Returns:
      Tensor of logits and tensor of gold labels.
    """
    data = dataset.get_loader(batch_size)
    output = []
    gold = []
    with torch.no_grad():
        for batch in data:
            batch = data_util.dict_batch_to_device(batch, device)
            output.append(self.forward(batch, compute_bounds=return_bounds))
            gold.append(batch['y'])
    return ibp.cat(output, dim=0), ibp.cat(gold, dim=0)
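# A hedged usage sketch for `query`. The model/dataset construction and the
# thresholding of a single entailment logit per example are assumptions made
# for illustration only.
def _query_usage_sketch(model, dataset, device):
    logits, gold = model.query(dataset, device, batch_size=32, return_bounds=False)
    preds = (logits.view(-1) > 0).long()       # assumes one logit per example
    return (preds == gold.view(-1).long()).float().mean().item()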
def attend_on(self, source, target, attention):
    """
    Args:
      - source: (bXsXe)
      - target: (bXtXe)
      - attention: (bXtXs)
    """
    attention_logsoftmax = ibp.log_softmax(attention, 1)
    attention_normalized = ibp.activation(torch.exp, attention_logsoftmax)
    attended_target = ibp.matmul_nneg(attention_normalized, source)  # (bXtXe)
    return ibp.cat([target, attended_target], dim=-1)
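# A plain-torch sketch of the attend-and-concatenate pattern in `attend_on`
# (no interval bounds); batch and sequence sizes are illustrative assumptions.
def _attend_on_sketch():
    import torch
    import torch.nn.functional as F
    b, s, t, e = 2, 5, 4, 8
    source = torch.randn(b, s, e)
    target = torch.randn(b, t, e)
    attention = torch.randn(b, t, s)
    weights = torch.exp(F.log_softmax(attention, dim=1))  # mirrors the dim used above
    attended = torch.bmm(weights, source)                 # (b, t, e)
    return torch.cat([target, attended], dim=-1)          # (b, t, 2e)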
def forward(self, batch, compute_bounds=True, cert_eps=1.0):
    """
    Forward pass of DecompAttentionModel.

    Args:
      batch: A batch dict from an EntailmentDataset with the following keys:
        - prem: tensor of word vector indices for premise (B, p, 1)
        - hypo: tensor of word vector indices for hypothesis (B, h, 1)
        - prem_mask: binary mask over premise words (1 for real, 0 for pad), size (B, p)
        - hypo_mask: binary mask over hypothesis words (1 for real, 0 for pad), size (B, h)
        - prem_lengths: lengths of premises, size (B,)
        - hypo_lengths: lengths of hypotheses, size (B,)
      compute_bounds: If True, compute the interval bounds and return an
        IntervalBoundedTensor as logits. Otherwise just use the values.
      cert_eps: float, scaling factor for the interval bounds.
    """
    def encode(sequence, mask):
        vecs = self.embs(sequence)
        if isinstance(vecs, ibp.DiscreteChoiceTensor):
            # Add the learned null vector to both the values and the
            # discrete substitution choices at position 0.
            null = torch.zeros_like(vecs.val[0])
            null_choice = torch.zeros_like(vecs.choice_mat[0])
            null[0] = self.null
            null_choice[0, 0] = self.null
            vecs.val = vecs.val + null
            vecs.choice_mat = vecs.choice_mat + null_choice
        else:
            null = torch.zeros_like(vecs[0])
            null[0] = self.null
            vecs = vecs + null
        vecs = self.rotation(vecs)
        if isinstance(vecs, ibp.DiscreteChoiceTensor):
            vecs = vecs.to_interval_bounded(eps=cert_eps)
        return ibp.activation(F.relu, vecs) * mask.unsqueeze(-1)

    if not compute_bounds:
        batch['prem']['x'] = batch['prem']['x'].val
        batch['hypo']['x'] = batch['hypo']['x'].val
    prem_encoded = encode(batch['prem']['x'], batch['prem']['mask'])  # (bXpXe)
    hypo_encoded = encode(batch['hypo']['x'], batch['hypo']['mask'])  # (bXhXe)
    prem_weights = self.feedforward(prem_encoded) * batch['prem']['mask'].unsqueeze(-1)  # (bXpX1)
    hypo_weights = self.feedforward(hypo_encoded) * batch['hypo']['mask'].unsqueeze(-1)  # (bXhX1)
    attention = ibp.bmm(prem_weights, hypo_weights.permute(0, 2, 1))  # (bXpX1) X (bX1Xh) => (bXpXh)
    attention_mask = batch['prem']['mask'].unsqueeze(-1) * batch['hypo']['mask'].unsqueeze(1)
    attention_masked = ibp.add(attention, (1 - attention_mask) * -1e20)
    attended_prem = self.attend_on(hypo_encoded, prem_encoded, attention_masked)  # (bXpX2e)
    attended_hypo = self.attend_on(prem_encoded, hypo_encoded, attention_masked.permute(0, 2, 1))  # (bXhX2e)
    compared_prem = self.compare_ff(attended_prem) * batch['prem']['mask'].unsqueeze(-1)  # (bXpXhid)
    compared_hypo = self.compare_ff(attended_hypo) * batch['hypo']['mask'].unsqueeze(-1)  # (bXhXhid)
    prem_aggregate = ibp.pool(torch.sum, compared_prem, dim=1)  # (bXhid)
    hypo_aggregate = ibp.pool(torch.sum, compared_hypo, dim=1)  # (bXhid)
    aggregate = ibp.cat([prem_aggregate, hypo_aggregate], dim=-1)  # (bX2hid)
    return self.output_layer(aggregate)  # (b)
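# A small sketch of the attention-masking step above: positions involving a
# padded premise or hypothesis token get a large negative score so they
# receive ~0 weight after softmax. Plain torch, illustrative sizes.
def _attention_mask_sketch():
    import torch
    prem_mask = torch.tensor([[1., 1., 0.]])      # (b=1, p=3)
    hypo_mask = torch.tensor([[1., 0.]])          # (b=1, h=2)
    attention = torch.randn(1, 3, 2)              # (b, p, h) raw scores
    pair_mask = prem_mask.unsqueeze(-1) * hypo_mask.unsqueeze(1)  # (b, p, h)
    return attention + (1 - pair_mask) * -1e20    # masked scores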
def reduce_func_dp(self, nodes):
    h = get(nodes.mailbox, "h")  # (n, 2, Del, Ins, Sub, d)
    c = get(nodes.mailbox, "c")
    unk_mask = nodes.mailbox["unk_mask"]  # (n, 2)
    # If both children can be deleted, then the parent can be deleted.
    new_unk_mask = th.where((unk_mask[:, 0] > 0) & (unk_mask[:, 1] > 0),
                            th.sum(unk_mask, 1),
                            th.zeros_like(unk_mask[:, 0]))

    def cal_h_cat_c(h_left, h_right, c_left, c_right):
        # (n, d)
        h_cat = ibp.cat([h_left, h_right], dim=-1)
        f_cat = ibp.activation(th.sigmoid, self.U_f(h_cat))  # (n, 2 * d)
        c = f_cat[:, :self.h_size] * c_left + f_cat[:, self.h_size:] * c_right
        return self.U_iou(h_cat), c

    new_iou = []  # (n, Del, Ins, Sub, 3 * d)
    new_c = []  # (n, Del, Ins, Sub, d)
    for deltas in itertools.product(*self.deltas_p1_ranges):
        deltas_ranges = [range(x + 1) for x in deltas]
        piece = None
        for deltas_left in itertools.product(*deltas_ranges):
            deltas_right = [y - x for (x, y) in zip(deltas_left, deltas)]
            tmp = cal_h_cat_c(
                h[:, 0, deltas_left[0], deltas_left[1], deltas_left[2], :],
                h[:, 1, deltas_right[0], deltas_right[1], deltas_right[2], :],
                c[:, 0, deltas_left[0], deltas_left[1], deltas_left[2], :],
                c[:, 1, deltas_right[0], deltas_right[1], deltas_right[2], :])
            piece = ibp.merge(piece, tmp)
        new_iou.append(piece[0].unsqueeze(1))
        new_c.append(piece[1].unsqueeze(1))

    auxh = []
    auxc = []
    for delta0 in range(self.deltas_p1[0]):
        inds = th.arange(unk_mask.shape[0])
        clip_unk_mask0 = th.clamp(unk_mask[:, 0], 1, delta0).long()
        clip_unk_mask1 = th.clamp(unk_mask[:, 1], 1, delta0).long()
        aux_list = []
        for x in [h, c]:
            aux_l = ibp.where(
                ((unk_mask[:, 0] > 0) & (unk_mask[:, 0] <= delta0)).view(-1, 1, 1, 1),
                # (n, Ins, Sub, d)
                x[inds, 1, delta0 - clip_unk_mask0],
                ibp.IntervalBoundedTensor.bottom(
                    (h.shape[0], self.deltas_p1[1], self.deltas_p1[2], self.h_size),
                    device=self.device))
            aux_r = ibp.where(
                ((unk_mask[:, 1] > 0) & (unk_mask[:, 1] <= delta0)).view(-1, 1, 1, 1),
                # (n, Ins, Sub, d)
                x[inds, 0, delta0 - clip_unk_mask1],
                ibp.IntervalBoundedTensor.bottom(
                    (h.shape[0], self.deltas_p1[1], self.deltas_p1[2], self.h_size),
                    device=self.device))
            aux_list.append(aux_l.merge(aux_r))
        auxh.append(aux_list[0].unsqueeze(dim=1))
        auxc.append(aux_list[1].unsqueeze(dim=1))
    auxh = ibp.cat(auxh, dim=1)
    auxc = ibp.cat(auxc, dim=1)

    new_iou = ibp.cat(new_iou, dim=1).view(
        -1, self.deltas_p1[0], self.deltas_p1[1], self.deltas_p1[2], self.h_size * 3)
    new_c = ibp.cat(new_c, dim=1).view(
        -1, self.deltas_p1[0], self.deltas_p1[1], self.deltas_p1[2], self.h_size)
    ret = {}
    add(ret, "iou", new_iou)
    add(ret, "c", new_c)
    add(ret, "auxh", auxh)
    add(ret, "auxc", auxc)
    ret["unk_mask"] = new_unk_mask
    return ret
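# A standalone sketch of the budget-splitting enumeration in `reduce_func_dp`:
# for every parent perturbation budget (Del, Ins, Sub), each way of dividing it
# between the left and right child is enumerated (and merged above via
# ibp.merge). The example ranges below are assumptions, not model settings.
def _enumerate_splits_sketch(deltas_p1_ranges=(range(2), range(2), range(1))):
    import itertools
    splits = []
    for deltas in itertools.product(*deltas_p1_ranges):
        for deltas_left in itertools.product(*[range(x + 1) for x in deltas]):
            deltas_right = tuple(y - x for x, y in zip(deltas_left, deltas))
            splits.append((deltas, deltas_left, deltas_right))
    return splits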