Example #1
 def make_mask(self, x, seq_len):
     """Make the padding mask for a batch-loaded dataset.

     :param x: Tensor of shape [batch_size, max_len], a padded mini-batch
     :param seq_len: list of int, the true sequence lengths before padding
     :return mask: FloatTensor of shape [batch_size, max_len] on the same
         device as x; 1.0 marks real tokens, 0.0 marks padding
     """
     batch_size, max_len = x.size(0), x.size(1)
     mask = seq_mask(seq_len, max_len)
     mask = mask.view(batch_size, max_len)
     mask = mask.to(x).float()  # match x's device, then cast to float
     return mask
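This and the later snippets all call a seq_mask helper that is not shown. A minimal sketch of such a helper, assuming it should return 1 for real tokens and 0 for padding (the name matches the calls above; the body is an assumption):

    import torch

    def seq_mask(seq_len, max_len):
        """Sketch of the mask helper assumed by these examples.

        :param seq_len: list of int or 1-D LongTensor, [batch_size]
        :param max_len: int, padded length of the mini-batch
        :return: ByteTensor of shape [batch_size, max_len]
        """
        if not isinstance(seq_len, torch.Tensor):
            seq_len = torch.tensor(seq_len, dtype=torch.long)
        # positions [0 .. max_len-1] broadcast against the length column
        positions = torch.arange(max_len, dtype=torch.long, device=seq_len.device)[None, :]
        return (positions < seq_len[:, None]).byte()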
Example #2
 def data_forward(self, network, x):
     """
     :param network: the PyTorch model
     :param x: list of list, [batch_size, max_len]
     :return y: [batch_size, num_classes]
     """
     seq_len = [len(seq) for seq in x]
     x = torch.Tensor(x).long()
     self.batch_size = x.size(0)
     self.max_len = x.size(1)
     self.mask = seq_mask(seq_len, self.max_len)
     y = network(x)
     return y
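Because torch.tensor needs a rectangular list, the true lengths have to be recorded before padding for the mask to be meaningful. A sketch of that preparation step (the pad value 0 is an assumption):

    # Hypothetical preparation: record true lengths, then pad with 0.
    raw_batch = [[4, 8, 15], [16, 23, 42, 7]]
    seq_len = [len(seq) for seq in raw_batch]  # [3, 4]
    max_len = max(seq_len)
    padded = [seq + [0] * (max_len - len(seq)) for seq in raw_batch]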
Example #3
    def data_forward(self, network, inputs):
        if not isinstance(inputs, tuple):
            raise RuntimeError(
                "output_length must be true for sequence modeling. Received {}.".format(type(inputs)))
        # unpack the returned value from make_batch
        x, seq_len = inputs[0], inputs[1]

        batch_size, max_len = x.size(0), x.size(1)
        mask = utils.seq_mask(seq_len, max_len)
        mask = mask.byte().view(batch_size, max_len)

        if torch.cuda.is_available() and self.use_cuda:
            mask = mask.cuda()
        self.mask = mask

        y = network(x)
        return y
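The torch.cuda.is_available() and self.use_cuda pattern predates device objects. A sketch of the equivalent modern idiom (a suggestion, not the original code):

    # Derive the device from the model instead of a boolean flag.
    device = next(network.parameters()).device
    mask = mask.to(device)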
Example #4
 def data_forward(self, network, inputs):
     """
     This is only for sequence labeling with CRF decoder.
     :param network: a PyTorch model
     :param inputs: tuple of (x, seq_len)
                     x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
                         after padding.
                     seq_len: list of int, the lengths of sequences before padding.
     :return prediction: Tensor of shape [batch_size, max_len]
     """
     # require the (x, seq_len) pair produced by make_batch
     if not isinstance(inputs, tuple):
         raise RuntimeError(
             "output_length must be true for sequence modeling.")
     # unpack the returned value from make_batch
     x, seq_len = inputs[0], inputs[1]
     batch_size, max_len = x.size(0), x.size(1)
     mask = utils.seq_mask(seq_len, max_len)
     mask = mask.byte().view(batch_size, max_len)
     y = network(x)
     # decode with the CRF; returns one tag sequence per sample
     prediction = network.prediction(y, mask)
     # requires every decoded sequence to share the same length; see the
     # padding sketch below for the ragged case
     return torch.Tensor(prediction)
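torch.Tensor(prediction) only works when every decoded sequence has the same length, but CRF decoders usually return each sequence truncated at its true length. A hedged sketch of padding the ragged output first (pad_predictions and the pad value 0 are assumptions, not part of the original code):

    import torch

    def pad_predictions(prediction, max_len, pad_value=0):
        """Pad ragged decoder output into a rectangular LongTensor.

        :param prediction: list of list of int, one decoded tag sequence
            per sample, each truncated at its true length
        :param max_len: int, padded length of the mini-batch
        :return: LongTensor of shape [batch_size, max_len]
        """
        padded = [seq + [pad_value] * (max_len - len(seq)) for seq in prediction]
        return torch.tensor(padded, dtype=torch.long)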
Example #5
 def data_forward(self, network, inputs):
     """This is only for sequence labeling with CRF decoder.
     :param network: a PyTorch model
     :param inputs: tuple of (x, seq_len)
                     x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
                         after padding.
                     seq_len: list of int, the lengths of sequences before padding.
     :return y: Tensor of shape [batch_size, max_len]
     """
     if not isinstance(inputs, tuple):
         raise RuntimeError(
             "output_length must be true for sequence modeling.")
     # unpack the returned value from make_batch
     x, seq_len = inputs[0], inputs[1]
     batch_size, max_len = x.size(0), x.size(1)
     mask = utils.seq_mask(seq_len, max_len)
     mask = mask.byte().view(batch_size, max_len)
     if torch.cuda.is_available() and self.use_cuda:
         mask = mask.cuda()
     self.mask = mask
     self.seq_len = seq_len
     y = network(x)
     return y
Example #6
 def make_mask(self, x, seq_len):
     # same helper as in Example #1: build a float padding mask on x's device
     batch_size, max_len = x.size(0), x.size(1)
     mask = seq_mask(seq_len, max_len)
     mask = mask.view(batch_size, max_len)
     mask = mask.to(x).float()
     return mask
Example #7
    def forward(self, word_seq, pos_seq, seq_lens, gold_heads=None):
        """
        :param word_seq: [batch_size, seq_len] sequence of word indices
        :param pos_seq: [batch_size, seq_len] sequence of part-of-speech tag indices
        :param seq_lens: [batch_size] true sequence lengths before padding
        :param gold_heads: [batch_size, seq_len] sequence of gold heads
        :return dict: parsing results
            arc_pred: [batch_size, seq_len, seq_len]
            label_pred: [batch_size, seq_len, num_label]
            mask: [batch_size, seq_len]
            head_pred: [batch_size, seq_len], present only when heads are
                predicted rather than taken from gold_heads
        """
        # prepare embeddings
        batch_size, seq_len = word_seq.shape

        # get sequence mask
        mask = seq_mask(seq_lens, seq_len).long()

        word = self.word_embedding(word_seq) # [N,L] -> [N,L,C_0]
        pos = self.pos_embedding(pos_seq) # [N,L] -> [N,L,C_1]

        word, pos = self.word_fc(word), self.pos_fc(pos)
        word, pos = self.word_norm(word), self.pos_norm(pos)
        x = torch.cat([word, pos], dim=2) # -> [N,L,C]

        # encoder, extract features
        if self.encoder_name.endswith('lstm'):
            sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
            x = x[sort_idx]
            x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
            feat, _ = self.encoder(x) # -> [N,L,C]
            feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
            _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
            feat = feat[unsort_idx]
        else:
            seq_range = torch.arange(seq_len, dtype=torch.long, device=x.device)[None,:]
            x = x + self.position_emb(seq_range)
            feat = self.encoder(x, mask.float())

        # for arc biaffine
        # mlp, reduce dim
        feat = self.mlp(feat)
        arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
        arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz]
        label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:]

        # biaffine arc classifier
        arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]

        # use gold or predicted arc to predict label
        if gold_heads is None or not self.training:
            # greedy decoding during training; greedy or MST decoding at inference
            if self.training or self.use_greedy_infer:
                heads = self._greedy_decoder(arc_pred, mask)
            else:
                heads = self._mst_decoder(arc_pred, mask)
            head_pred = heads
        else:
            assert self.training  # must be training mode
            # gold_heads is guaranteed non-None in this branch, so the gold
            # heads are used directly and no head prediction is returned
            head_pred = None
            heads = gold_heads

        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
        label_head = label_head[batch_range, heads].contiguous()
        label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
        res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
        if head_pred is not None:
            res_dict['head_pred'] = head_pred
        return res_dict
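Both variants of this forward (the second, below, adds embedding dropout, a fixed LSTM encoder, and explore_p-based exploration) rely on self.arc_predictor and self.label_predictor, biaffine modules that are not shown. A minimal sketch of a biaffine arc scorer in the Dozat & Manning style (shapes match the [N, L, L] comment above; the class itself is an assumption):

    import torch
    import torch.nn as nn

    class BiaffineArcScorer(nn.Module):
        """Sketch of a biaffine arc scorer, not the original module."""

        def __init__(self, hidden_size):
            super().__init__()
            # the extra input column on the head side implements a per-head bias
            self.U = nn.Parameter(torch.zeros(hidden_size + 1, hidden_size))
            nn.init.xavier_uniform_(self.U)

        def forward(self, head, dep):
            # head, dep: [N, L, C] -> scores: [N, L, L]
            N, L, _ = head.shape
            head = torch.cat([head, head.new_ones(N, L, 1)], dim=2)  # [N, L, C+1]
            # [N, L, C+1] @ [C+1, C] @ [N, C, L] -> [N, L, L]
            return head @ self.U @ dep.transpose(1, 2)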
    def forward(self, word_seq, pos_seq, seq_lens, gold_heads=None):
        """
        :param word_seq: [batch_size, seq_len] sequence of word indices
        :param pos_seq: [batch_size, seq_len] sequence of part-of-speech tag indices
        :param seq_lens: [batch_size] true sequence lengths before padding
        :param gold_heads: [batch_size, seq_len] sequence of gold heads
        :return dict: parsing results
            arc_pred: [batch_size, seq_len, seq_len]
            label_pred: [batch_size, seq_len, num_label]
            mask: [batch_size, seq_len]
            head_pred: [batch_size, seq_len], present only when heads are
                predicted rather than taken from gold_heads
        """
        # prepare embeddings
        device = self.parameters().__next__().device
        word_seq = word_seq.long().to(device)
        pos_seq = pos_seq.long().to(device)
        seq_lens = seq_lens.long().to(device).view(-1)
        batch_size, seq_len = word_seq.shape

        # get sequence mask
        mask = seq_mask(seq_lens, seq_len).long()

        word = self.normal_dropout(
            self.word_embedding(word_seq))  # [N,L] -> [N,L,C_0]
        pos = self.normal_dropout(
            self.pos_embedding(pos_seq))  # [N,L] -> [N,L,C_1]
        word, pos = self.word_fc(word), self.pos_fc(pos)
        word, pos = self.word_norm(word), self.pos_norm(pos)
        x = torch.cat([word, pos], dim=2)  # -> [N,L,C]
        del word, pos

        # lstm, extract features
        sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
        x = x[sort_idx]
        x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
        feat, _ = self.lstm(x)  # -> [N,L,C]
        feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
        _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
        feat = feat[unsort_idx]

        # for arc biaffine
        # mlp, reduce dim
        arc_dep = self.arc_dep_mlp(feat)
        arc_head = self.arc_head_mlp(feat)
        label_dep = self.label_dep_mlp(feat)
        label_head = self.label_head_mlp(feat)
        del feat

        # biaffine arc classifier
        arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]

        # use gold or predicted arc to predict label
        if gold_heads is None or not self.training:
            # greedy decoding during training; greedy or MST decoding at inference
            if self.training or self.use_greedy_infer:
                heads = self._greedy_decoder(arc_pred, mask)
            else:
                heads = self._mst_decoder(arc_pred, mask)
            head_pred = heads
        else:
            assert self.training  # must be training mode
            # scheduled-sampling-style exploration: with probability explore_p,
            # train the label predictor on its own predicted heads
            if torch.rand(1).item() < self.explore_p:
                heads = self._greedy_decoder(arc_pred, mask)
                head_pred = heads
            else:
                head_pred = None
                heads = gold_heads

        batch_range = torch.arange(start=0,
                                   end=batch_size,
                                   dtype=torch.long,
                                   device=word_seq.device).unsqueeze(1)
        label_head = label_head[batch_range, heads].contiguous()
        label_pred = self.label_predictor(label_head,
                                          label_dep)  # [N, L, num_label]
        res_dict = {
            'arc_pred': arc_pred,
            'label_pred': label_pred,
            'mask': mask
        }
        if head_pred is not None:
            res_dict['head_pred'] = head_pred
        return res_dict
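A sketch of driving this forward pass with toy tensors; the parser construction and vocabulary sizes are assumptions, only the call signature comes from the code above:

    import torch

    # `parser` is assumed to be a constructed module exposing the forward above
    word_seq = torch.randint(0, 100, (2, 5))   # [batch_size, seq_len]
    pos_seq = torch.randint(0, 20, (2, 5))
    seq_lens = torch.tensor([5, 3])            # true lengths before padding
    parser.eval()                              # inference: heads are decoded
    out = parser(word_seq, pos_seq, seq_lens)  # no gold_heads
    arc_pred, head_pred = out['arc_pred'], out['head_pred']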