コード例 #1
0
    def test_value(self):
        """Check smart_cond with a constant predicate and plain Python results.

        Each input is wrapped with constant_op.constant and used as the
        predicate; truthy inputs are expected to produce fn1's return value
        and falsy inputs fn2's.
        """
        def fn1():
            return 'fn1'

        def fn2():
            return 'fn2'

        for flag in (True, False, 1, 0):
            wanted = 'fn1' if flag else 'fn2'
            result = utils.smart_cond(constant_op.constant(flag), fn1, fn2)
            self.assertEqual(result, wanted)
コード例 #2
0
    def test_tensors(self):
        """Check smart_cond with a constant predicate and tensor-valued branches.

        fn1 builds 0 - 1 and fn2 builds 0 - 2 from constants; the evaluated
        output is compared against -1 for truthy inputs and -2 for falsy ones.
        """
        def fn1():
            return constant_op.constant(0) - constant_op.constant(1)

        def fn2():
            return constant_op.constant(0) - constant_op.constant(2)

        cases = ((True, -1), (False, -2), (1, -1), (0, -2))
        for flag, wanted in cases:
            predicate = constant_op.constant(flag)
            result = utils.smart_cond(predicate, fn1, fn2)
            with self.cached_session():
                self.assertEqual(result.eval(), wanted)
コード例 #3
0
    def test_constant(self):
        """Check smart_cond with a constant predicate and string-constant branches.

        Both branches return constant_op.constant strings; the evaluated
        output is compared against the bytes b'fn1' or b'fn2'.
        """
        def fn1():
            return constant_op.constant('fn1')

        def fn2():
            return constant_op.constant('fn2')

        for flag in (True, False, 1, 0):
            wanted = b'fn1' if flag else b'fn2'
            result = utils.smart_cond(constant_op.constant(flag), fn1, fn2)
            with self.cached_session():
                actual = result.eval()
                self.assertEqual(actual, wanted)
コード例 #4
0
    def test_tensors(self):
        """Check smart_cond with a placeholder predicate and tensor-valued branches.

        The predicate is a scalar bool placeholder that is fed at evaluation
        time; the result is compared against -1 for truthy feeds and -2 for
        falsy ones.
        """
        def fn1():
            return constant_op.constant(0) - constant_op.constant(1)

        def fn2():
            return constant_op.constant(0) - constant_op.constant(2)

        pred = array_ops.placeholder(dtypes.bool, [])
        cases = ((True, -1), (False, -2), (1, -1), (0, -2))
        for fed_value, wanted in cases:
            # smart_cond is invoked on every iteration, as in the original.
            result = utils.smart_cond(pred, fn1, fn2)
            with self.cached_session():
                actual = result.eval(feed_dict={pred: fed_value})
                self.assertEqual(actual, wanted)
コード例 #5
0
    def test_value(self):
        """Check smart_cond with a placeholder predicate and converted string branches.

        Both branches return ops.convert_to_tensor of a string; the predicate
        is a scalar bool placeholder fed at evaluation time, and the result is
        compared against b'fn1' or b'fn2'.
        """
        def fn1():
            return ops.convert_to_tensor('fn1')

        def fn2():
            return ops.convert_to_tensor('fn2')

        pred = array_ops.placeholder(dtypes.bool, [])
        for fed_value in (True, False, 1, 0):
            wanted = b'fn1' if fed_value else b'fn2'
            result = utils.smart_cond(pred, fn1, fn2)
            with self.cached_session():
                actual = result.eval(feed_dict={pred: fed_value})
                self.assertEqual(actual, wanted)
コード例 #6
0
    def test_variable(self):
        """Check smart_cond with a placeholder predicate and variable-creating branches.

        Runs inside a fresh tf.Graph. Each branch returns a new
        variables.Variable; global variables are initialized in the session
        before the output is evaluated with the placeholder fed.
        """
        with tf.Graph().as_default():
            def fn1():
                return variables.Variable('fn1')

            def fn2():
                return variables.Variable('fn2')

            pred = array_ops.placeholder(dtypes.bool, [])
            cases = ((True, b'fn1'), (False, b'fn2'), (1, b'fn1'), (0, b'fn2'))
            for fed_value, wanted in cases:
                result = utils.smart_cond(pred, fn1, fn2)
                with self.cached_session() as sess:
                    init_op = variables.global_variables_initializer()
                    sess.run(init_op)
                    actual = result.eval(feed_dict={pred: fed_value})
                    self.assertEqual(actual, wanted)
コード例 #7
0
    def test_variable(self):
        """Check smart_cond with a constant predicate and variable-creating branches.

        Runs inside a fresh tf.Graph. Each branch returns a new
        variables.Variable; global variables are initialized in the session
        before the output is evaluated.
        """
        with tf.Graph().as_default():
            def fn1():
                return variables.Variable('fn1')

            def fn2():
                return variables.Variable('fn2')

            for flag in (True, False, 1, 0):
                wanted = b'fn1' if flag else b'fn2'
                predicate = constant_op.constant(flag)
                result = utils.smart_cond(predicate, fn1, fn2)

                with self.cached_session() as sess:
                    init_op = variables.global_variables_initializer()
                    sess.run(init_op)
                    self.assertEqual(result.eval(), wanted)