Esempio n. 1
0
    def call(self,
             logits=None,
             input_length=None,
             labels=None,
             label_length=None,
             soft_labels=None):
        """Compute the CRF loss for a batch of tag sequences.

        Args:
            logits: Per-token class scores. They are reshaped to
                ``[-1, self.max_seq_len, self.num_classes]`` before use.
            input_length: Sequence lengths, forwarded to ``crf_log_likelihood``.
            labels: Gold tag ids, forwarded to ``crf_log_likelihood``.
            label_length: Unused by this implementation.
            soft_labels: Unused by this implementation.

        Returns:
            The first element of the pair returned by ``crf_log_likelihood``.
            This is presumably the (negative log-likelihood) loss; confirm
            against that helper. The second element is discarded.
        """
        score_shape = [-1, self.max_seq_len, self.num_classes]
        emission_scores = tf.reshape(logits, score_shape, name="scores")
        # crf_log_likelihood yields a 2-tuple; only the first item is the loss.
        crf_loss, _unused_params = crf_log_likelihood(
            emission_scores, labels, input_length, self.transitions)
        return crf_loss
Esempio n. 2
0
  def call(self,
           logits=None,
           input_length=None,
           labels=None,
           label_length=None,
           **kwargs):
    """Compute the CRF loss using parameters taken from the owning model.

    Args:
      logits: Per-token class scores. They are reshaped to
        ``[-1, model.max_len, model.seq_num_classes]`` before use.
      input_length: Sequence lengths, forwarded to ``crf_log_likelihood``.
      labels: Gold tag ids, forwarded to ``crf_log_likelihood``.
      label_length: Unused by this implementation.
      **kwargs: Must contain ``model``, an object exposing ``max_len``,
        ``seq_num_classes`` and ``transitions``.

    Returns:
      The first element of the pair returned by ``crf_log_likelihood``.
      This is presumably the (negative log-likelihood) loss; confirm against
      that helper.

    Raises:
      TypeError: If no ``model`` keyword argument is supplied.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``. In that case a missing model would surface only as an
    # unexplained KeyError.
    if "model" not in kwargs:
      raise TypeError("call() requires a 'model' keyword argument")
    model = kwargs["model"]
    tags_scores = tf.reshape(
        logits, [-1, model.max_len, model.seq_num_classes], name="scores")
    loss, _ = crf_log_likelihood(tags_scores, labels, input_length,
                                 model.transitions)

    return loss
Esempio n. 3
0
  def test_crf_loss(self):
    """Check crf_log_likelihood against a hand-computed loss.

    Every transition score is 0.5, so the log-partition function reduces to
    the sum of the per-step log-sum-exp of the logits (~7.30966) plus
    4 * 0.5. The gold path [0, 1, 2, 0, 1] scores 1.8 from emissions and
    2.0 from transitions. The difference, ~5.50966, is the expected loss.
    """
    with self.cached_session():
      loss_true = np.float32(5.5096426)
      logits = np.asarray([[[0.3, 0.4, 0.3], [0.1, 0.9, 0.0], [0.2, 0.7, 0.1],
                            [0.3, 0.2, 0.5], [0.6, 0.2, 0.2]]],
                          dtype=np.float32)  # [1,5,3]
      trans_params = tf.fill([3, 3], 0.5, name='trans_params')
      labels = np.asarray([[0, 1, 2, 0, 1]])  # shape=[1,5]
      sequence_lengths = np.asarray([5])  # shape=[1,]
      loss, _ = loss_utils.crf_log_likelihood(
          tf.constant(logits), tf.constant(labels),
          tf.constant(sequence_lengths), trans_params)

      # Compare with a tolerance. The loss is a float32 result of exp/log
      # reductions, so bit-exact equality is brittle across kernels and
      # platforms.
      self.assertAllClose(loss.eval(), loss_true)