예제 #1
0
 def testRegression(self):
     """Checks the unweighted regression loss against a hand-computed value.

     Per-example squared errors are (1-0)^2 = 1, (1-1)^2 = 0 and
     (3-1)^2 = 4, so the expected mean is (1 + 0 + 4) / 3 = 5/3.
     """
     column = target_column_lib.regression_target()
     expected_loss = 5. / 3
     with ops.Graph().as_default(), session.Session() as sess:
         predictions = constant_op.constant([[1.], [1.], [3.]])
         targets = constant_op.constant([[0.], [1.], [1.]])
         # No weight column is configured, so an empty features dict is passed.
         loss_tensor = column.loss(predictions, targets, {})
         actual_loss = sess.run(loss_tensor)
         self.assertAlmostEqual(expected_loss, actual_loss)
예제 #2
0
 def testRegression(self):
   """Unweighted regression loss over three examples should equal 5/3.

   Errors (prediction - label) are 1, 0 and 2; their squares average to
   (1 + 0 + 4) / 3.
   """
   regression_column = target_column_lib.regression_target()
   with ops.Graph().as_default(), session.Session() as sess:
     preds = constant_op.constant([[1.], [1.], [3.]])
     gold = constant_op.constant([[0.], [1.], [1.]])
     no_features = {}  # the target column has no weight column here
     computed = sess.run(regression_column.loss(preds, gold, no_features))
     self.assertAlmostEqual(5. / 3, computed)
예제 #3
0
 def testRegressionWithWeights(self):
   """Checks weighted regression losses against hand-computed values.

   Per-example squared errors are [1, 0, 4] and the weights are [2, 5, 0],
   so the weighted error sum is 2*1 + 5*0 + 0*4 = 2. The assertions expect
   `loss` to equal that sum over the total weight (2 / 7), and
   `training_loss` to equal it over the example count (2 / 3).
   """
   column = target_column_lib.regression_target(
       weight_column_name="label_weight")
   with ops.Graph().as_default(), session.Session() as sess:
     weights = constant_op.constant([[2.], [5.], [0.]])
     features = {"label_weight": weights}
     predictions = constant_op.constant([[1.], [1.], [3.]])
     targets = constant_op.constant([[0.], [1.], [1.]])

     weighted_loss = sess.run(column.loss(predictions, targets, features))
     self.assertAlmostEqual(2. / 7, weighted_loss, places=3)

     train_loss = sess.run(
         column.training_loss(predictions, targets, features))
     self.assertAlmostEqual(2. / 3, train_loss, places=3)
예제 #4
0
 def testRegressionWithWeights(self):
     """Weighted regression: `loss` should be 2/7, `training_loss` 2/3.

     With squared errors [1, 0, 4] and weights [2, 5, 0] the weighted error
     sum is 2. Dividing by the total weight (7) gives the expected `loss`;
     dividing by the number of examples (3) gives the expected
     `training_loss`.
     """
     weighted_column = target_column_lib.regression_target(
         weight_column_name="label_weight")
     with ops.Graph().as_default(), session.Session() as sess:
         feature_map = {
             "label_weight": constant_op.constant([[2.], [5.], [0.]]),
         }
         preds = constant_op.constant([[1.], [1.], [3.]])
         gold = constant_op.constant([[0.], [1.], [1.]])
         # Each loss op is built and evaluated in turn, loss first.
         checks = (
             (2. / 7, weighted_column.loss),
             (2. / 3, weighted_column.training_loss),
         )
         for expected, loss_fn in checks:
             observed = sess.run(loss_fn(preds, gold, feature_map))
             self.assertAlmostEqual(expected, observed, places=3)