예제 #1
0
 def test_accuracy(self):
     """Check that `_Accuracy` reports an output shape equal to its second input's shape.

     The first input carries one extra trailing axis of size 20 (presumably
     per-class scores -- TODO confirm against `_Accuracy`), and the second
     input matches the leading (29, 4, 4) dimensions.
     """
     leading_dims = (29, 4, 4)
     # Build the signature before the layer, mirroring the original evaluation order.
     signature = (
         ShapeDtype(leading_dims + (20,)),
         ShapeDtype(leading_dims),
     )
     accuracy_layer = metrics._Accuracy()
     output_shape = base.check_shape_agreement(accuracy_layer, signature)
     self.assertEqual(output_shape, leading_dims)
예제 #2
0
 def test_accuracy(self):
     """Call `_Accuracy` on all-ones arrays and check the result's shape.

     The first array has an extra trailing axis of size 20, and the result
     is expected to keep only the shared (9, 4, 4) leading dimensions.
     """
     accuracy_layer = metrics._Accuracy()
     leading_dims = (9, 4, 4)
     model_output = np.ones(leading_dims + (20,))
     targets = np.ones(leading_dims)
     result = accuracy_layer([model_output, targets])
     self.assertEqual(result.shape, leading_dims)