Example #1
0
 def _compare_values(self, x, y=None):
     """Assert that ``math_ops.rint(x)`` matches ``y`` in value and shape.

     Args:
       x: Input passed to ``math_ops.rint``.
       y: Expected result. When ``None``, ``np.rint(x)`` is used instead;
         otherwise it is converted with ``np.asarray``.
     """
     y = np.rint(x) if y is None else np.asarray(y)

     # Use self.evaluate rather than an explicit cached_session()/sess.run
     # pair: it evaluates the tensor in both graph and eager mode, and it is
     # how the other tests here fetch results.
     tf_rint = math_ops.rint(x)
     np_rint = self.evaluate(tf_rint)

     self.assertAllEqual(y, np_rint)
     self.assertShapeEqual(y, tf_rint)
  def _compare_values(self, x, y=None):
      """Check ``math_ops.rint(x)`` against an expected array.

      The expected array is ``np.asarray(y)`` or, when ``y`` is omitted,
      NumPy's own ``np.rint(x)``. Both the evaluated values and the
      tensor's shape must match it.
      """
      if y is None:
          expected = np.rint(x)
      else:
          expected = np.asarray(y)
      rounded = math_ops.rint(x)
      self.assertAllEqual(expected, self.evaluate(rounded))
      self.assertShapeEqual(expected, rounded)
 def testRint(self):
     """Compare ``math_ops.rint`` with ``np.rint`` for floating-point dtypes."""
     # Starting at -5.0 with step 0.25 hits exact .5 ties (-4.5, ..., 4.5),
     # so this also checks that TF and NumPy agree on round-half-to-even.
     x = np.arange(-5.0, 5.0, .25)
     # The Rint op only accepts floating-point inputs (bfloat16, half, float,
     # double); integer dtypes such as np.int32 are rejected when the op is
     # built, so they are not part of this loop.
     for dtype in [np.float32, np.double]:
         x_np = np.array(x, dtype=dtype)
         x_tf = constant_op.constant(x_np, shape=x_np.shape)
         y_tf = math_ops.rint(x_tf)
         y_tf_np = self.evaluate(y_tf)
         y_np = np.rint(x_np)
         # Rounding to integral values is exact, so require equality rather
         # than closeness within a tolerance.
         self.assertAllEqual(y_tf_np, y_np)
         self.assertShapeEqual(y_np, y_tf)