def call(self,
         logits=None,
         input_length=None,
         labels=None,
         label_length=None,
         soft_labels=None):
    """Compute the CRF loss for a batch of logits.

    The logits are reshaped to ``[-1, self.max_seq_len, self.num_classes]``
    and handed, together with ``labels``, ``input_length`` and
    ``self.transitions``, to ``crf_log_likelihood``.

    Args:
        logits: Tensor reshapeable to
            ``[-1, self.max_seq_len, self.num_classes]``.
        input_length: Forwarded unchanged to ``crf_log_likelihood``
            (presumably the per-sequence lengths -- confirm with that
            helper).
        labels: Forwarded unchanged to ``crf_log_likelihood``.
        label_length: Not used by this method.
        soft_labels: Not used by this method.

    Returns:
        The first of the two values returned by ``crf_log_likelihood``.
    """
    score_shape = [-1, self.max_seq_len, self.num_classes]
    emission_scores = tf.reshape(logits, score_shape, name="scores")
    # The helper returns a pair; only the first element (the loss) is used.
    crf_loss, _ = crf_log_likelihood(
        emission_scores, labels, input_length, self.transitions)
    return crf_loss
def call(self,
         logits=None,
         input_length=None,
         labels=None,
         label_length=None,
         **kwargs):
    """Compute the CRF loss using parameters owned by the passed-in model.

    The model is taken from ``kwargs["model"]``; its ``max_len``,
    ``seq_num_classes`` and ``transitions`` attributes are read here.

    Args:
        logits: Tensor reshapeable to
            ``[-1, model.max_len, model.seq_num_classes]``.
        input_length: Forwarded unchanged to ``crf_log_likelihood``
            (presumably the per-sequence lengths -- confirm with that
            helper).
        labels: Forwarded unchanged to ``crf_log_likelihood``.
        label_length: Not used by this method.
        **kwargs: Must contain the key ``"model"``.

    Returns:
        The first of the two values returned by ``crf_log_likelihood``.

    Raises:
        AssertionError: If ``"model"`` is missing from ``kwargs``.
    """
    assert "model" in kwargs
    owner = kwargs["model"]
    score_shape = [-1, owner.max_len, owner.seq_num_classes]
    emission_scores = tf.reshape(logits, score_shape, name="scores")
    # The helper returns a pair; only the first element (the loss) is used.
    crf_loss, _ = crf_log_likelihood(
        emission_scores, labels, input_length, owner.transitions)
    return crf_loss
def test_crf_loss(self):
    """Check ``loss_utils.crf_log_likelihood`` against a reference loss."""
    # Scores for one sequence of 5 steps over 3 classes: shape [1, 5, 3].
    logits = np.asarray(
        [[[0.3, 0.4, 0.3],
          [0.1, 0.9, 0.0],
          [0.2, 0.7, 0.1],
          [0.3, 0.2, 0.5],
          [0.6, 0.2, 0.2]]],
        dtype=np.float32)
    labels = np.asarray([[0, 1, 2, 0, 1]])  # shape [1, 5]
    sequence_lengths = np.asarray([5])  # shape [1]
    loss_true = np.float32(5.5096426)

    with self.cached_session():
        # Uniform 3x3 transition matrix filled with 0.5.
        trans_params = tf.fill([3, 3], 0.5, name='trans_params')
        loss, _ = loss_utils.crf_log_likelihood(
            tf.constant(logits), tf.constant(labels),
            tf.constant(sequence_lengths), trans_params)
        # Compare with a tolerance instead of exact equality: the float32
        # result can differ in its last bits across TF versions, kernels
        # and hardware, which made the exact comparison brittle.
        np.testing.assert_allclose(loss.eval(), loss_true, rtol=1e-5)