Exemplo n.º 1
0
  def test_focal_loss(self):
    """Checks focal_loss on a single 3-class example for gamma 0, 2 and 5.

    With gamma = 0 the result is compared both with a fixed reference value
    and with loss_utils.cross_entropy evaluated on the same inputs. The
    gamma = 2 and gamma = 5 results are compared with fixed reference values.
    """
    with self.cached_session():
      logits_np = np.array([[22, 23, 24]], dtype=np.float32)
      labels_np = np.array([2], dtype=np.int32)

      # Tolerance used for every comparison against a hard-coded reference.
      ref_tol = dict(rtol=1e-06, atol=1e-6)

      def build_focal_loss(gamma):
        # Each loss op is built from its own pair of constant input tensors.
        return loss_utils.focal_loss(
            logits=tf.constant(logits_np),
            labels=tf.constant(labels_np),
            gamma=gamma)

      cross_entropy = loss_utils.cross_entropy(
          logits=tf.constant(logits_np),
          input_length=None,
          labels=tf.constant(labels_np),
          label_length=None)

      focal_g0 = build_focal_loss(0)

      self.assertAllClose(focal_g0.eval(), 0.407606, **ref_tol)
      # gamma = 0 is expected to reproduce the cross-entropy value, so this
      # comparison uses a tighter tolerance than the reference checks.
      self.assertAllClose(
          focal_g0.eval(), cross_entropy.eval(), rtol=1e-07, atol=1e-7)

      focal_g2 = build_focal_loss(2)
      focal_g5 = build_focal_loss(5)

      self.assertAllClose(focal_g2.eval(), 0.045677, **ref_tol)
      self.assertAllClose(focal_g5.eval(), 0.001713, **ref_tol)
Exemplo n.º 2
0
  def call(self,
           logits=None,
           input_length=None,
           labels=None,
           label_length=None,
           **kwargs):
    """Computes the focal loss for `logits` against `labels`.

    Args:
      logits: forwarded unchanged to `focal_loss`.
      input_length: accepted but not used by this loss.
      labels: forwarded unchanged to `focal_loss`.
      label_length: accepted but not used by this loss.
      **kwargs: extra keyword arguments; not read here.

    Returns:
      Whatever `focal_loss` returns for the given inputs and `self.gamma`.
    """
    # The length arguments play no part in this loss; drop them explicitly.
    del input_length, label_length

    loss = focal_loss(
        logits=logits, labels=labels, gamma=self.gamma, name='focal_loss')
    return loss