예제 #1
0
    def testWhileLoop(self):
        """Calls a Rate instance from inside a while_loop body and checks the output.

        Each iteration adds 2 to `value` and 1 to `denom`; the loop runs while
        the step counter is below 100, and the rate tensor carried out of the
        loop is expected to evaluate to [[2]].
        """
        with self.cached_session():
            rate_metric = rate.Rate()

            def keep_going(unused_value, unused_denom, step, unused_rate):
                # Only the step counter decides whether the loop continues.
                return math_ops.less(step, 100)

            def advance(value, denom, step, unused_rate):
                step = step + 1
                current_rate = rate_metric(value, denom)
                # Build the increments under a control dependency on the rate
                # tensor so they are sequenced after the rate computation.
                with ops.control_dependencies([current_rate]):
                    next_value = math_ops.add(value, 2)
                    next_denom = math_ops.add(denom, 1)
                return [next_value, next_denom, step, current_rate]

            counter = constant_op.constant(0)
            start_value = constant_op.constant([1], dtype=dtypes.float64)
            start_denom = constant_op.constant([1], dtype=dtypes.float64)
            # Call the metric once outside the loop, before the initializers
            # are run.
            initial_rate = rate_metric(start_value, start_denom)
            self.evaluate(variables.global_variables_initializer())
            self.evaluate(variables.local_variables_initializer())

            loop_vars = [start_value, start_denom, counter, initial_rate]
            loop_outputs = control_flow_ops.while_loop(
                keep_going, advance, loop_vars)
            final_rate = loop_outputs[3]
            self.assertEqual([[2]], self.evaluate(final_rate))
예제 #2
0
 def testBuildRate(self):
     """Checks that calling a Rate after an explicit build() keeps `numer`.

     The `numer` attribute captured right after `build()` must be the very
     same object (identity, not equality) once the instance has been called.
     """
     m = rate.Rate()
     m.build(constant_op.constant([1], dtype=dtypes.float32),
             constant_op.constant([2], dtype=dtypes.float32))
     old_numer = m.numer
     m(constant_op.constant([2], dtype=dtypes.float32),
       constant_op.constant([2], dtype=dtypes.float32))
     # assertIs reports both operands on failure, unlike assertTrue(a is b).
     self.assertIs(old_numer, m.numer)
예제 #3
0
 def testBasic(self):
     """Calls one Rate instance several times and checks each returned value."""
     with self.cached_session():
         metric = rate.Rate()
         first = metric(array_ops.ones([1]), denominator=array_ops.ones([1]))
         # The initializers are run only after the first call has been built.
         self.evaluate(variables.global_variables_initializer())
         self.evaluate(variables.local_variables_initializer())
         self.assertEqual([[1]], self.evaluate(first))

         # (numerator, denominator, expected result) for each follow-up call,
         # applied in order to the same metric instance.
         follow_up_cases = (
             (2, 2, 1),
             (4, 3, 2),
             (16, 3, 0),  # divide by 0
         )
         for numer, denom, expected in follow_up_cases:
             result = metric(constant_op.constant([numer]),
                             denominator=constant_op.constant([denom]))
             self.assertEqual([[expected]], self.evaluate(result))
예제 #4
0
 def testNamesWithSpaces(self):
     """Checks the names reported by a Rate whose name contains a space.

     The instance's `name` is expected to keep the space, while the name of
     its `prev_values` attribute is expected to use the underscored prefix
     asserted below.
     """
     spaced_name = "has space"
     metric = rate.Rate(name=spaced_name)
     # Call the metric once before inspecting `prev_values`.
     metric(array_ops.ones([1]), array_ops.ones([1]))
     self.assertEqual(metric.name, spaced_name)
     self.assertEqual(metric.prev_values.name, "has_space_1/prev_values:0")