예제 #1
0
 def my_loss(y_true, y_pred):
     """Return the masked, weighted cross-entropy normalized by a mask count.

     The per-element weighted cross-entropy between ``y_true`` and ``y_pred``
     is multiplied by the mask that ``mask_from`` builds from ``y_true``,
     summed, and divided by the second value ``mask_from`` returns.

     ``loss_weights`` and ``mask_val`` are read from the enclosing scope.

     NOTE(review): presumably ``mask_from`` returns a 0/1 mask plus the
     number of unmasked entries, which would make this a mean over the
     unmasked entries -- confirm against its definition.
     """
     pos_weights = K.constant(loss_weights, dtype='float32')
     per_element = weighted_cross_entropy_with_logits(
         y_true, y_pred, pos_weights)
     keep, n_kept = mask_from(y_true, mask_val)
     return K.sum(per_element * keep) / n_kept
예제 #2
0
 def testConstructionNamed(self):
     """Passing ``name`` should set the name of the op behind the loss."""
     with self.test_session():
         logits, targets, pos_weight, _ = self._Inputs()
         named_loss = nn_impl.weighted_cross_entropy_with_logits(
             targets, logits, pos_weight, name="mybce")
     # The op name is graph metadata, so it is checked after the session
     # context has been left.
     self.assertEqual("mybce", named_loss.op.name)
예제 #3
0
 def testGradient(self):
   """Assert the gradient-checker error of the loss w.r.t. logits is tiny."""
   shape = [4, 2]
   with self.cached_session():
     logits, targets, pos_weight, _ = self._Inputs(sizes=shape)
     weighted_loss = nn_impl.weighted_cross_entropy_with_logits(
         targets=targets, logits=logits, pos_weight=pos_weight)
     # Input (logits) and output (loss) are both passed with the same shape.
     grad_err = gradient_checker.compute_gradient_error(
         logits, shape, weighted_loss, shape)
   print("logistic loss gradient err = ", grad_err)
   self.assertLess(grad_err, 1e-7)
예제 #4
0
 def testOutput(self):
   """Compare the evaluated loss with the reference values from _Inputs.

   The comparison runs twice: once with ``use_gpu=True`` and once with
   ``use_gpu=False``.
   """
   for use_gpu in (True, False):
     with self.test_session(use_gpu=use_gpu):
       logits, targets, pos_weight, reference = self._Inputs(
           dtype=dtypes.float32)
       loss_op = nn_impl.weighted_cross_entropy_with_logits(
           targets=targets, logits=logits, pos_weight=pos_weight)
       expected = np.array(reference).astype(np.float32)
       actual = loss_op.eval()
     self.assertAllClose(expected, actual, atol=0.001)
예제 #5
0
 def testShapeError(self):
   """Targets and logits of different shapes must raise a ValueError.

   The error message is expected to contain "must have the same shape".
   """
   # assertRaisesRegexp is a deprecated unittest alias that Python 3.12
   # removed; assertRaisesRegex is the supported spelling.
   with self.assertRaisesRegex(ValueError, "must have the same shape"):
     nn_impl.weighted_cross_entropy_with_logits(
         targets=[1, 2, 3], logits=[[2, 1]], pos_weight=2.0)
예제 #6
0
 def testConstructionNamed(self):
   """The ``name`` argument should become the name of the resulting op."""
   op_name = "mybce"
   with self.cached_session():
     logits, targets, pos_weight, _ = self._Inputs()
     named_loss = nn_impl.weighted_cross_entropy_with_logits(
         targets=targets,
         logits=logits,
         pos_weight=pos_weight,
         name=op_name)
   self.assertEqual(op_name, named_loss.op.name)