Ejemplo n.º 1
0
    def test_disable_op_2(self, invalid_op_list):
        """Check that an invalid disabled-op list makes the nGraph run fail.

        ``invalid_op_list`` is passed to ``ngraph_bridge.set_disabled_ops``;
        every value supplied to this test is expected to contain at least one
        invalid op name, so running the graph through nGraph must raise. The
        same graph must still compute ``ones + ones`` correctly without nGraph.
        The disabled-op list is always reset to '' afterwards.
        """
        # This test is disabled for grappler because grappler fails silently and
        # TF continues to run with the unoptimized graph.
        # Note: tried setting fail_on_optimizer_errors, but grappler still
        # failed silently.
        # TODO: enable this test for grappler as well.
        if ngraph_bridge.is_grappler_enabled():
            return

        ngraph_bridge.set_disabled_ops(invalid_op_list)
        try:
            a = tf.placeholder(tf.int32, shape=(5,))
            b = tf.constant(np.ones((5,)), dtype=tf.int32)
            c = a + b

            def run_test(sess):
                return sess.run(c, feed_dict={a: np.ones((5,))})

            assert (self.without_ngraph(run_test) == np.ones(5) * 2).all()

            # This run is expected to fail, since all the strings passed to
            # set_disabled_ops have invalid ops in them. Catch Exception
            # (not a bare except) so KeyboardInterrupt/SystemExit propagate.
            raised = False
            try:
                self.with_ngraph(run_test)
            except Exception:
                raised = True
            assert raised, 'Had expected test to raise error'
        finally:
            # Clean up unconditionally so a failing assertion above cannot
            # leak the invalid disabled-op list into later tests.
            ngraph_bridge.set_disabled_ops('')
Ejemplo n.º 2
0
 def test_disable_op_1(self, op_list):
     """Check that set_disabled_ops / get_disabled_ops round-trip ``op_list``.

     get_disabled_ops is compared against the UTF-8 encoding of ``op_list``,
     i.e. the getter is expected to return bytes. The disabled-op list is
     always reset to '' afterwards.
     """
     expected = op_list.encode("utf-8")
     ngraph_bridge.set_disabled_ops(op_list)
     try:
         assert ngraph_bridge.get_disabled_ops() == expected
         # Running get_disabled_ops twice to see nothing has changed between
         # 2 consecutive calls.
         assert ngraph_bridge.get_disabled_ops() == expected
     finally:
         # Clean up even when an assertion fails, so later tests start with
         # an empty disabled-op list.
         ngraph_bridge.set_disabled_ops('')
Ejemplo n.º 3
0
    def test_disable_3(self):
        """Check that disabling Conv2D in nGraph makes an NCHW conv2d fail.

        With the nGraph backend set to 'CPU', the test expects an NCHW
        ``conv2d`` to:
          1. raise when run natively by TF (without nGraph),
          2. succeed when run through nGraph,
          3. raise again through nGraph once 'Conv2D' is a disabled op.
        The previous backend and an empty disabled-op list are restored on
        exit, whether or not the checks pass.
        """
        old_backend = ngraph_bridge.get_backend()
        ngraph_bridge.set_backend('CPU')
        try:
            N = 1
            C = 4
            H = 10
            W = 10
            FW = 3
            FH = 3
            O = 6
            inp = tf.compat.v1.placeholder(tf.float32, shape=(N, C, H, W))
            filt = tf.constant(np.ones((FH, FW, C, O)), dtype=tf.float32)
            conv = nn_ops.conv2d(inp,
                                 filt,
                                 strides=[1, 1, 1, 2],
                                 padding="SAME",
                                 data_format='NCHW')

            def run_test(sess):
                return sess.run(conv, feed_dict={inp: np.ones((N, C, H, W))})

            def raises_error(runner):
                # True if runner(run_test) raises. Catch Exception rather than
                # a bare except so KeyboardInterrupt/SystemExit propagate.
                try:
                    runner(run_test)
                except Exception:
                    return True
                return False

            # Ensure that NCHW does not run on TF CPU backend natively
            assert raises_error(self.without_ngraph), \
                'Had expected test to raise error, since NCHW in conv2d is not supported'

            # We have asserted the above network would not run on TF natively.
            # Now ensure it runs with ngraph
            assert not raises_error(self.with_ngraph), \
                'Had expected test to pass, since NCHW in conv2d is supported through ngraph'

            # Now disabling Conv2D. Expecting the test to fail, even when run
            # through ngraph
            ngraph_bridge.set_disabled_ops('Conv2D')
            assert raises_error(self.with_ngraph), \
                'Had expected test to raise error, since conv2D is disabled in ngraph'
        finally:
            # Clean up on every exit path: previously a failure in the last
            # phase left 'Conv2D' disabled for subsequent tests.
            ngraph_bridge.set_backend(old_backend)
            ngraph_bridge.set_disabled_ops('')
Ejemplo n.º 4
0
    def test_disable_op_env(self):
        """Check that NGRAPH_TF_DISABLED_OPS overrides set_disabled_ops.

        After ``set_disabled_ops('Select,Where')`` the getter is expected to
        report that list; once the NGRAPH_TF_DISABLED_OPS environment variable
        is set to 'Squeeze', the getter is expected to report the environment
        value instead. The environment variable and the disabled-op list are
        restored/reset on every exit path.
        """
        op_list = 'Select,Where'
        ngraph_bridge.set_disabled_ops(op_list)
        try:
            assert ngraph_bridge.get_disabled_ops() == op_list.encode("utf-8")

            # Presumably snapshots the variable's current value so it can be
            # put back by restore_env_variables below — confirm in the helper.
            env_map = self.store_env_variables('NGRAPH_TF_DISABLED_OPS')
            env_list = 'Squeeze'
            self.set_env_variable('NGRAPH_TF_DISABLED_OPS', env_list)
            try:
                assert ngraph_bridge.get_disabled_ops() == env_list.encode("utf-8")
            finally:
                # Undo the environment change even if the assertion fails, so
                # the variable does not leak into other tests.
                self.unset_env_variable('NGRAPH_TF_DISABLED_OPS')
                self.restore_env_variables(env_map)
        finally:
            # Clean up
            ngraph_bridge.set_disabled_ops('')