コード例 #1
0
 def testAccuracy1DString(self):
     """Checks `classification.accuracy` on 1-D string tensors.

     Two of the four fed predictions equal their labels (positions 0 and 3),
     and the test expects the evaluated result to be 0.5.
     """
     predicted_values = ['a', 'b', 'a', 'c']
     label_values = ['a', 'c', 'b', 'c']
     with self.cached_session() as sess:
         predictions_ph = array_ops.placeholder(dtypes.string, shape=[None])
         labels_ph = array_ops.placeholder(dtypes.string, shape=[None])
         accuracy_op = classification.accuracy(predictions_ph, labels_ph)
         # Shapes are left unknown ([None]); concrete values arrive via feeds.
         feeds = {predictions_ph: predicted_values, labels_ph: label_values}
         accuracy_value = sess.run(accuracy_op, feed_dict=feeds)
         self.assertEqual(accuracy_value, 0.5)
コード例 #2
0
 def testAccuracy1DInt64(self):
     """Checks `classification.accuracy` on 1-D int64 tensors.

     Two of the four fed predictions equal their labels (positions 0 and 3),
     and the test expects the evaluated result to be 0.5.
     """
     predicted_values = [1, 0, 1, 0]
     label_values = [1, 1, 0, 0]
     with self.cached_session() as sess:
         predictions_ph = array_ops.placeholder(dtypes.int64, shape=[None])
         labels_ph = array_ops.placeholder(dtypes.int64, shape=[None])
         accuracy_op = classification.accuracy(predictions_ph, labels_ph)
         # Shapes are left unknown ([None]); concrete values arrive via feeds.
         feeds = {predictions_ph: predicted_values, labels_ph: label_values}
         accuracy_value = sess.run(accuracy_op, feed_dict=feeds)
         self.assertEqual(accuracy_value, 0.5)
コード例 #3
0
 def testAccuracy1DWeighted(self):
     """Checks `classification.accuracy` on 1-D int32 tensors with weights.

     The `weights` placeholder is handed to `classification.accuracy`, so the
     fed weights take part in the computation rather than only being fed to
     an otherwise unused placeholder.

     Predictions equal labels at positions 0 and 3, whose weights are 3.0 and
     0.0; the total weight is 6.0. Assuming weighted accuracy is
     sum(weight * correct) / sum(weight), the expected result is
     3.0 / 6.0 == 0.5.
     """
     with self.cached_session() as session:
         pred = array_ops.placeholder(dtypes.int32, shape=[None])
         labels = array_ops.placeholder(dtypes.int32, shape=[None])
         weights = array_ops.placeholder(dtypes.float32, shape=[None])
         # Without `weights=weights` this test would only exercise the
         # unweighted code path, despite its name.
         acc = classification.accuracy(pred, labels, weights=weights)
         result = session.run(acc,
                              feed_dict={
                                  pred: [1, 0, 1, 1],
                                  labels: [1, 1, 0, 1],
                                  weights: [3.0, 1.0, 2.0, 0.0]
                              })
         self.assertEqual(result, 0.5)
コード例 #4
0
 def testAccuracyFloatLabels(self):
     """Expects ValueError when int32 predictions meet float32 labels.

     The graph construction happens inside the `assertRaises` context; the
     error is presumably raised by `classification.accuracy` because of the
     mismatched dtypes (to be confirmed against its implementation).
     """
     with self.assertRaises(ValueError):
         int_predictions = array_ops.placeholder(dtypes.int32, shape=[None])
         float_labels = array_ops.placeholder(dtypes.float32, shape=[None])
         classification.accuracy(int_predictions, float_labels)