Example #2
import torch

# Imports assumed from the surrounding NeuroNLP2-style project; exact module
# paths may differ. DeepBiAffine (the base class) is assumed to be defined
# alongside this class in the same module.
from neuronlp2.nn import TreeCRF
from neuronlp2.tasks import parser


class NeuroMST(DeepBiAffine):
    def __init__(self,
                 word_dim,
                 num_words,
                 char_dim,
                 num_chars,
                 pos_dim,
                 num_pos,
                 rnn_mode,
                 hidden_size,
                 num_layers,
                 num_labels,
                 arc_space,
                 type_space,
                 embedd_word=None,
                 embedd_char=None,
                 embedd_pos=None,
                 p_in=0.33,
                 p_out=0.33,
                 p_rnn=(0.33, 0.33),
                 pos=True,
                 activation='elu'):
        super(NeuroMST, self).__init__(word_dim,
                                       num_words,
                                       char_dim,
                                       num_chars,
                                       pos_dim,
                                       num_pos,
                                       rnn_mode,
                                       hidden_size,
                                       num_layers,
                                       num_labels,
                                       arc_space,
                                       type_space,
                                       embedd_word=embedd_word,
                                       embedd_char=embedd_char,
                                       embedd_pos=embedd_pos,
                                       p_in=p_in,
                                       p_out=p_out,
                                       p_rnn=p_rnn,
                                       pos=pos,
                                       activation=activation)
        # The parent's biaffine arc scorer is not used here; arc scoring is
        # handled by the TreeCRF instead.
        self.biaffine = None
        self.treecrf = TreeCRF(arc_space)

    def forward(self, input_word, input_char, input_pos, mask=None):
        # arc and type representations from the encoder RNN: [batch, length, dim]
        arc, out_type = self._get_rnn_output(input_word,
                                             input_char,
                                             input_pos,
                                             mask=mask)
        # arc energies from the TreeCRF: [batch, length_head, length_child]
        out_arc = self.treecrf(arc[0], arc[1], mask=mask)
        return out_arc, out_type

    # @overrides
    def loss(self, input_word, input_char, input_pos, heads, types, mask=None):
        # arc and type representations from the encoder RNN: [batch, length, dim]
        arc, out_type = self._get_rnn_output(input_word,
                                             input_char,
                                             input_pos,
                                             mask=mask)
        # arc loss from the TreeCRF, one value per sentence: [batch]
        loss_arc = self.treecrf.loss(arc[0], arc[1], heads, mask=mask)
        # out_type is a pair (type_h, type_c), each [batch, length, type_space]
        type_h, type_c = out_type

        # gather the head representation for each token: [batch, length, type_space]
        type_h = type_h.gather(dim=1,
                               index=heads.unsqueeze(2).expand(type_h.size()))
        # label scores per token: [batch, length, num_labels]
        out_type = self.bilinear(type_h, type_c)
        loss_type = self.criterion(out_type.transpose(1, 2), types)

        # zero out padded positions so they do not contribute to the summed loss
        if mask is not None:
            loss_type = loss_type * mask

        # exclude position 0 (the root placeholder) when summing the label loss
        return loss_arc, loss_type[:, 1:].sum(dim=1)

    # @overrides
    def decode(self,
               input_word,
               input_char,
               input_pos,
               mask=None,
               leading_symbolic=0):
        """
        Args:
            input_word: Tensor
                the word input tensor with shape = [batch, length]
            input_char: Tensor
                the character input tensor with shape = [batch, length, char_length]
            input_pos: Tensor
                the pos input tensor with shape = [batch, length]
            mask: Tensor or None
                the mask tensor with shape = [batch, length]
            leading_symbolic: int
                number of symbolic labels leading in type alphabets (set it to 0 if you are not sure)

        Returns: (Tensor, Tensor)
                predicted heads and types.

        """
        # arc energies from forward(): [batch, length_head, length_child]
        energy, out_type = self(input_word, input_char, input_pos, mask=mask)
        # sentence lengths from the mask; assume no padding if no mask is given
        if mask is None:
            mask = torch.ones_like(input_word, dtype=torch.float)
        length = mask.sum(dim=1).long()
        heads, _ = parser.decode_MST(energy.cpu().numpy(),
                                     length.cpu().numpy(),
                                     leading_symbolic=leading_symbolic,
                                     labeled=False)
        types = self._decode_types(out_type,
                                   torch.from_numpy(heads).type_as(length),
                                   leading_symbolic)
        return heads, types.cpu().numpy()
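
For reference, here is a minimal usage sketch of the class above. It assumes the NeuroNLP2-style import path from neuronlp2.models import NeuroMST; the vocabulary sizes, dimensions, rnn_mode value, and batch contents below are made up purely for illustration and are not part of the example itself.

import torch
from neuronlp2.models import NeuroMST  # assumed import path

# Hypothetical vocabulary sizes and dimensions.
num_words, num_chars, num_pos, num_labels = 5000, 100, 40, 40
network = NeuroMST(word_dim=100, num_words=num_words,
                   char_dim=30, num_chars=num_chars,
                   pos_dim=50, num_pos=num_pos,
                   rnn_mode='LSTM', hidden_size=256, num_layers=2,
                   num_labels=num_labels, arc_space=256, type_space=128)

# Dummy batch: word, character and POS indices plus a padding mask
# (shapes follow the decode docstring above); gold heads/types are random.
batch, length, char_length = 8, 20, 15
words = torch.randint(0, num_words, (batch, length))
chars = torch.randint(0, num_chars, (batch, length, char_length))
postags = torch.randint(0, num_pos, (batch, length))
heads = torch.randint(0, length, (batch, length))
types = torch.randint(0, num_labels, (batch, length))
mask = torch.ones(batch, length)

# Training step: both loss terms have one value per sentence, so reduce
# them to a scalar before backpropagating.
loss_arc, loss_type = network.loss(words, chars, postags, heads, types, mask=mask)
loss = (loss_arc + loss_type).sum() / batch
loss.backward()

# Inference: MST decoding over the arc energies, then label prediction.
with torch.no_grad():
    pred_heads, pred_types = network.decode(words, chars, postags,
                                            mask=mask, leading_symbolic=0)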