def test_graph_to_function_v2_given_graph(self):
    """graph_to_function_v2 should accept a tf.Graph.

    Loads the sample graph for SIMPLE_MODEL_FILE_NAME, converts it with
    api.graph_to_function_v2, and checks that the resulting callable
    returns approximately 5 * x for a 1x1 float32 input.
    """
    graph = testutils.get_sample_graph(testutils.SIMPLE_MODEL_FILE_NAME)
    estimate = api.graph_to_function_v2(graph)
    x_ = 12
    x = tf.constant([[x_]], dtype=tf.float32)
    y = as_scalar(estimate(x))
    # places=1: passes when round(y - x_ * 5, 1) == 0.
    self.assertAlmostEqual(y, x_ * 5, places=1)
def test_graph_to_function_v2_given_graph_def(self):
    """graph_to_function_v2 should accept a GraphDef.

    Loads the sample graph def for SIMPLE_MODEL_FILE_NAME, converts it
    with api.graph_to_function_v2, and checks that the resulting callable
    returns approximately 5 * x for a 1x1 float32 input.
    """
    graph_def = testutils.get_sample_graph_def(
        testutils.SIMPLE_MODEL_FILE_NAME)
    estimate = api.graph_to_function_v2(graph_def)
    x_ = 20
    x = tf.constant([[x_]], dtype=tf.float32)
    y = as_scalar(estimate(x))
    # delta=0.1: passes when abs(y - x_ * 5) <= 0.1.
    self.assertAlmostEqual(y, x_ * 5, delta=0.1)