예제 #1
0
파일: crf.py 프로젝트: jandyu/models-1
    def _trans_score(self, labels, lengths):
        """Sum the transition scores along each label sequence of the batch.

        Each adjacent tag pair ``(prev, next)`` is encoded as the flat index
        ``prev * self.num_tags + next`` and looked up in the flattened
        ``self.transitions``. The gathered values are masked and summed per
        batch row.

        Args:
            labels: 2-D tensor of tag ids, unpacked as
                ``(batch_size, seq_len)``.
            lengths: Sequence lengths handed to ``sequence_mask``
                (presumably one integer per batch row -- TODO confirm).

        Returns:
            The masked transition scores summed over axis 1, i.e. one value
            per batch row.
        """
        batch_size, seq_len = labels.shape

        if self.with_start_stop_tag:
            # Add START and STOP on either side of the labels
            start_tensor, stop_tensor = self._get_start_stop_tensor(batch_size)
            labels_ext = paddle.concat([start_tensor, labels, stop_tensor],
                                       axis=1)
            # `lengths + 1` accounts for the START column prepended above.
            mask = paddle.cast(
                sequence_mask(self._get_batch_seq_index(batch_size, seq_len),
                              lengths + 1), 'int32')
            pad_stop = paddle.full((batch_size, seq_len + 2),
                                   dtype='int64',
                                   fill_value=self.stop_idx)
            # Wherever the mask is 0 the label is replaced by the STOP tag.
            labels_ext = (1 - mask) * pad_stop + mask * labels_ext
        else:
            # `mask` is applied to the gathered scores below, so it has to be
            # defined on this path as well; without it the method would raise
            # UnboundLocalError whenever START/STOP tags are disabled.
            mask = paddle.cast(
                sequence_mask(self._get_batch_seq_index(batch_size, seq_len),
                              lengths), 'int32')
            # The index helper may return more than `seq_len` columns
            # (TODO confirm); trim so that `mask[:, 1:]` lines up with the
            # `seq_len - 1` transitions of the unextended labels.
            mask = mask[:, :seq_len]
            labels_ext = labels

        # Pair every tag with its successor.
        start_tag_indices = labels_ext[:, :-1]
        stop_tag_indices = labels_ext[:, 1:]

        # Encode the indices in a flattened representation.
        transition_indices = start_tag_indices * self.num_tags + stop_tag_indices
        flattened_transition_indices = transition_indices.reshape([-1])
        flattened_transition_params = self.transitions.reshape([-1])

        scores = paddle.gather(flattened_transition_params,
                               flattened_transition_indices).reshape(
                                   [batch_size, -1])
        # Column j of `scores` is the transition into position j + 1, hence
        # the mask is shifted by one.
        mask_scores = scores * mask[:, 1:]

        # Accumulate the transition score
        score = paddle.sum(mask_scores, 1)

        return score
예제 #2
0
파일: crf.py 프로젝트: jeff41404/PaddleNLP
    def _point_score(self, inputs, labels, lengths):
        """Sum, per sequence, the input value found at each labelled tag.

        Args:
            inputs: 3-D tensor unpacked as
                ``(batch_size, seq_len, n_labels)``.
            labels: Tag ids added to the per-position offsets; presumably
                shaped ``(batch_size, seq_len)`` -- TODO confirm.
            lengths: Sequence lengths forwarded to ``sequence_mask``.

        Returns:
            The masked per-position scores summed over axis 1.
        """
        batch_size, seq_len, n_labels = inputs.shape

        # Flatten the inputs so that a single gather can pick every
        # (batch, step, tag) entry by its flat position.
        flat_inputs = inputs.reshape([-1])

        # Flat position = batch offset + step offset + tag id (assuming the
        # index helpers return 0-based batch / step indices -- TODO confirm).
        batch_base = self._get_batch_index(batch_size) * seq_len * n_labels
        flat_index = paddle.unsqueeze(batch_base, 1)
        step_base = self._get_seq_index(seq_len) * n_labels
        flat_index += paddle.unsqueeze(step_base, 0)
        flat_index = paddle.reshape(flat_index + labels, [-1])

        picked = paddle.gather(flat_inputs, flat_index)
        picked = picked.reshape([batch_size, seq_len])

        # Build the length mask and trim it to the sequence width before
        # applying it to the gathered scores.
        seq_ids = self._get_batch_seq_index(batch_size, seq_len)
        valid = paddle.cast(sequence_mask(seq_ids, lengths), 'float32')
        valid = valid[:, :seq_len]

        return paddle.sum(picked * valid, 1)