def decode_mst(self, out_arc, out_arc_tag, mask, length, leading_symbolic):
    """Decode the best dependency tree per sentence with the MST algorithm.

    Args:
        out_arc: raw arc scores, passed through to ``self.pre_loss``.
        out_arc_tag: raw arc-tag scores, passed through to ``self.pre_loss``.
        mask: optional [batch, max_len] tensor marking valid tokens, or None.
        length: optional per-sentence lengths; if None they are derived from
            ``mask`` (or set to ``max_len`` when ``mask`` is also None).
        leading_symbolic: number of leading symbolic tags for
            ``parse.decode_MST`` to skip.

    Returns:
        Tuple ``(heads, arc_tags, scores)``: decoded head indices and arc-tag
        ids (each [batch, max_len]), plus the average per-token energy of
        each decoded tree ([batch]).
    """
    loss_arc, loss_arc_tag = self.pre_loss(
        out_arc, out_arc_tag, heads=None, arc_tags=None,
        mask=mask, length=length, use_log=True, temperature=1.0)
    batch_size, max_len, _ = loss_arc.size()
    # Derive lengths when not given: every sentence is max_len, or count
    # valid positions from the mask.
    if length is None:
        if mask is None:
            length = [max_len for _ in range(batch_size)]
        else:
            length = mask.data.sum(dim=1).long().cpu().numpy()
    # energy shape [batch_size, num_arc_tags, max_len, max_len]
    energy = torch.exp(loss_arc.unsqueeze(1) + loss_arc_tag)
    heads, arc_tags = parse.decode_MST(
        energy.data.cpu().numpy(), length,
        leading_symbolic=leading_symbolic, labeled=True)
    heads = from_numpy(heads)
    arc_tags = from_numpy(arc_tags)
    # Average energy of each decoded tree. A single vectorized gather
    # replaces the original per-element Python double loop:
    #   scores[b, i] = energy[b, arc_tags[b, i], heads[b, i], i]
    batch_size, max_len = heads.size()
    batch_idx = torch.arange(batch_size, device=energy.device).unsqueeze(1)
    pos_idx = torch.arange(max_len, device=energy.device).unsqueeze(0)
    scores = energy[batch_idx,
                    arc_tags.long().to(energy.device),
                    heads.long().to(energy.device),
                    pos_idx]
    if mask is not None:
        scores = scores.sum(1) / mask.sum(1)
    else:
        scores = scores.sum(1) / max_len
    return heads, arc_tags, scores
def unconstrained_decode_mst(self, model_path, input_word, input_pos,
                             input_lemma, out_arc, out_arc_tag, mask, length,
                             leading_symbolic):
    """MST decoding without additional constraints.

    ``model_path``, ``input_word``, ``input_pos`` and ``input_lemma`` are
    unused here; they are kept so the signature stays compatible with
    existing callers (presumably shared with a constrained variant —
    TODO confirm against the call sites).

    The previous body duplicated ``decode_mst`` verbatim, interleaved with
    dead debugging scaffolding (a commented ``pdb.set_trace()`` and a
    hard-coded absolute path for dumping the energy tensor); it now simply
    delegates to the sibling method.

    Returns:
        Tuple ``(heads, arc_tags, scores)`` as produced by ``decode_mst``.
    """
    return self.decode_mst(out_arc, out_arc_tag, mask, length,
                           leading_symbolic)