def _trans_score(self, labels, lengths):
    batch_size, seq_len = labels.shape

    if self.with_start_stop_tag:
        # Add START and STOP on either side of the labels
        start_tensor, stop_tensor = self._get_start_stop_tensor(batch_size)
        labels_ext = paddle.concat([start_tensor, labels, stop_tensor], axis=1)
        # Cast to int64 so the mask can be combined with the int64 label tensors.
        mask = paddle.cast(
            sequence_mask(
                self._get_batch_seq_index(batch_size, seq_len), lengths + 1),
            'int64')
        # Pad every position beyond the sequence end with the STOP tag.
        pad_stop = paddle.full(
            (batch_size, seq_len + 2), dtype='int64', fill_value=self.stop_idx)
        labels_ext = (1 - mask) * pad_stop + mask * labels_ext
    else:
        # The mask is used below in both branches, so define it here as well.
        mask = paddle.cast(
            sequence_mask(
                self._get_batch_seq_index(batch_size, seq_len), lengths),
            'int64')
        labels_ext = labels

    start_tag_indices = labels_ext[:, :-1]
    stop_tag_indices = labels_ext[:, 1:]

    # Encode the indices in a flattened representation.
    transition_indices = start_tag_indices * self.num_tags + stop_tag_indices
    flattened_transition_indices = transition_indices.reshape([-1])
    flattened_transition_params = self.transitions.reshape([-1])
    scores = paddle.gather(
        flattened_transition_params,
        flattened_transition_indices).reshape([batch_size, -1])
    mask_scores = scores * mask[:, 1:]

    # Accumulate the transition score
    score = paddle.sum(mask_scores, 1)
    return score
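
# A minimal standalone sketch (illustrative only, not part of the original
# class) of the flattened-gather trick used in `_trans_score` above: index
# (i, j) of the [num_tags, num_tags] transition matrix maps to
# i * num_tags + j after flattening, so a single `paddle.gather` fetches
# every transitions[start, stop] score at once. The 3-tag matrix and the
# label sequence here are hypothetical.
def _demo_trans_score_gather():
    import paddle

    num_tags = 3
    transitions = paddle.arange(
        num_tags * num_tags, dtype='float32').reshape([num_tags, num_tags])
    labels = paddle.to_tensor([[0, 2, 1, 1]])  # shape: [batch_size=1, seq_len=4]

    start_tags = labels[:, :-1]  # transition sources per step: [0, 2, 1]
    stop_tags = labels[:, 1:]    # transition targets per step: [2, 1, 1]

    # Linearize the (start, stop) pairs into indices of the flattened matrix.
    flat_indices = (start_tags * num_tags + stop_tags).reshape([-1])
    flat_scores = paddle.gather(transitions.reshape([-1]), flat_indices)

    # Same result as summing the pairwise lookups directly.
    assert float(flat_scores.sum()) == float(
        transitions[0, 2] + transitions[2, 1] + transitions[1, 1])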
def _point_score(self, inputs, labels, lengths):
    batch_size, seq_len, n_labels = inputs.shape

    # Get the true label logit value
    flattened_inputs = inputs.reshape([-1])
    offsets = paddle.unsqueeze(
        self._get_batch_index(batch_size) * seq_len * n_labels, 1)
    offsets += paddle.unsqueeze(self._get_seq_index(seq_len) * n_labels, 0)
    flattened_tag_indices = paddle.reshape(offsets + labels, [-1])

    scores = paddle.gather(flattened_inputs, flattened_tag_indices).reshape(
        [batch_size, seq_len])

    mask = paddle.cast(
        sequence_mask(
            self._get_batch_seq_index(batch_size, seq_len), lengths),
        'float32')
    mask = mask[:, :seq_len]

    mask_scores = scores * mask
    score = paddle.sum(mask_scores, 1)
    return score
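
# A minimal standalone sketch (illustrative only, not part of the original
# class) of the offset arithmetic used in `_point_score` above: the triple
# (batch, step, tag) is linearized to b * seq_len * n_labels + t * n_labels
# + tag, so one `paddle.gather` over the flattened emissions picks out every
# inputs[b, t, labels[b, t]]. The shapes below are hypothetical.
def _demo_point_score_gather():
    import paddle

    batch_size, seq_len, n_labels = 2, 3, 4
    inputs = paddle.rand([batch_size, seq_len, n_labels])
    labels = paddle.randint(0, n_labels, [batch_size, seq_len])

    # Linearize (batch, step) into offsets of the flattened [B*T*L] tensor.
    offsets = paddle.arange(batch_size).unsqueeze(1) * seq_len * n_labels
    offsets = offsets + paddle.arange(seq_len).unsqueeze(0) * n_labels
    flat_indices = (offsets + labels).reshape([-1])

    gathered = paddle.gather(inputs.reshape([-1]), flat_indices).reshape(
        [batch_size, seq_len])

    # Same values as direct per-position indexing.
    for b in range(batch_size):
        for t in range(seq_len):
            assert float(gathered[b, t]) == float(
                inputs[b, t, int(labels[b, t])])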