def decode(self, input_word, input_char, input_pos, mask=None, leading_symbolic=0):
    """Decode unlabeled heads with the MST algorithm, then predict types.

    Args:
        input_word: Tensor
            the word input tensor with shape = [batch, length]
        input_char: Tensor
            the character input tensor with shape = [batch, length, char_length]
        input_pos: Tensor
            the pos input tensor with shape = [batch, length]
        mask: Tensor or None
            the mask tensor with shape = [batch, length]
        leading_symbolic: int
            number of symbolic labels leading in type alphabets (set it to 0 if you are not sure)

    Returns: (ndarray, ndarray)
        predicted heads and types.
    """
    # energy shape [batch, length_h, length_c]
    energy, out_type = self(input_word, input_char, input_pos, mask=mask)
    # Compute sentence lengths. Without a mask every sentence spans the
    # full padded length (the original raised AttributeError on mask=None
    # even though None is the declared default).
    if mask is None:
        batch, max_len = input_word.size(0), input_word.size(1)
        length = torch.full((batch,), max_len, dtype=torch.long, device=input_word.device)
    else:
        length = mask.sum(dim=1).long()
    heads, _ = parser.decode_MST(energy.cpu().numpy(), length.cpu().numpy(),
                                 leading_symbolic=leading_symbolic, labeled=False)
    # Greedily pick the best type for each predicted head arc.
    types = self._decode_types(out_type, torch.from_numpy(heads).type_as(length), leading_symbolic)
    return heads, types.cpu().numpy()
def decode(self, input_word, input_char, input_pos, mask=None, leading_symbolic=0):
    """Decode labeled dependency trees with the MST algorithm.

    Args:
        input_word: Tensor
            the word input tensor with shape = [batch, length]
        input_char: Tensor
            the character input tensor with shape = [batch, length, char_length]
        input_pos: Tensor
            the pos input tensor with shape = [batch, length]
        mask: Tensor or None
            the mask tensor with shape = [batch, length]
        leading_symbolic: int
            number of symbolic labels leading in type alphabets (set it to 0 if you are not sure)

    Returns: (ndarray, ndarray)
        predicted heads and types.
    """
    # out_arc shape [batch, length_h, length_c]
    out_arc, out_type = self(input_word, input_char, input_pos, mask=mask)
    # out_type shape [batch, length, type_space]
    type_h, type_c = out_type
    batch, max_len, type_space = type_h.size()
    # Pair every head position with every child position.
    type_h = type_h.unsqueeze(2).expand(batch, max_len, max_len, type_space).contiguous()
    type_c = type_c.unsqueeze(1).expand(batch, max_len, max_len, type_space).contiguous()
    # compute output for type [batch, length_h, length_c, num_labels]
    out_type = self.bilinear(type_h, type_c)
    if mask is not None:
        # Mask padded head positions to -inf so log_softmax ignores them.
        minus_mask = mask.eq(0).unsqueeze(2)
        out_arc.masked_fill_(minus_mask, float('-inf'))
    # loss_arc shape [batch, length_h, length_c]
    loss_arc = F.log_softmax(out_arc, dim=1)
    # loss_type shape [batch, length_h, length_c, num_labels]
    loss_type = F.log_softmax(out_type, dim=3).permute(0, 3, 1, 2)
    # [batch, num_labels, length_h, length_c]
    energy = loss_arc.unsqueeze(1) + loss_type
    # Compute sentence lengths. Without a mask every sentence spans max_len
    # (the original raised AttributeError on mask=None despite the default);
    # the list fallback matches the convention used by decode_mst.
    if mask is None:
        length = [max_len for _ in range(batch)]
    else:
        length = mask.sum(dim=1).long().cpu().numpy()
    return parser.decode_MST(energy.cpu().numpy(), length,
                             leading_symbolic=leading_symbolic, labeled=True)
def decode_mst(self, input_word, input_char, input_pos, mask=None, length=None, hx=None, leading_symbolic=0):
    """Decode the highest-scoring labeled dependency tree via MST.

    Args:
        input_word: Tensor
            the word input tensor with shape = [batch, length]
        input_char: Tensor
            the character input tensor with shape = [batch, length, char_length]
        input_pos: Tensor
            the pos input tensor with shape = [batch, length]
        mask: Tensor or None
            the mask tensor with shape = [batch, length]
        length: Tensor or None
            the length tensor with shape = [batch]
        hx: Tensor or None
            the initial states of RNN
        leading_symbolic: int
            number of symbolic labels leading in type alphabets (set it to 0 if you are not sure)

    Returns: (Tensor, Tensor)
        predicted heads and types.
    """
    # Run the network; forward also returns the (possibly recomputed) mask and length.
    # arc_scores shape [batch, length, length]; head/child reprs [batch, length, type_space].
    arc_scores, (head_repr, child_repr), mask, length = self.forward(
        input_word, input_char, input_pos, mask=mask, length=length, hx=hx)
    batch, max_len, type_space = head_repr.size()

    # Fall back to mask-derived lengths, or the full padded length when unmasked.
    if length is None:
        length = [max_len] * batch if mask is None else mask.data.sum(dim=1).long().cpu().numpy()

    # Pair every head position with every child position:
    # [batch, length, length, type_space]
    head_repr = head_repr.unsqueeze(2).expand(batch, max_len, max_len, type_space).contiguous()
    child_repr = child_repr.unsqueeze(1).expand(batch, max_len, max_len, type_space).contiguous()
    # Label scores for every (head, child) pair: [batch, length, length, num_labels].
    label_scores = self.bilinear(head_repr, child_repr)

    # Push padded positions toward -inf (on both axes) so log_softmax ignores them.
    if mask is not None:
        pad_penalty = (1 - mask) * -1e8
        arc_scores = arc_scores + pad_penalty.unsqueeze(2) + pad_penalty.unsqueeze(1)

    # arc log-probabilities: [batch, length, length]
    arc_logp = F.log_softmax(arc_scores, dim=1)
    # label log-probabilities: [batch, num_labels, length, length]
    label_logp = F.log_softmax(label_scores, dim=3).permute(0, 3, 1, 2)
    # Joint arc/label probabilities: [batch, num_labels, length, length].
    energy = torch.exp(arc_logp.unsqueeze(1) + label_logp)
    return parser.decode_MST(energy.data.cpu().numpy(), length,
                             leading_symbolic=leading_symbolic, labeled=True)