def test_runtime_error_rewriting(self):
    """Runtime errors raised inside converted code map back to user source.

    ``g`` floor-divides by zero inside a while loop whose condition depends on
    a tensor. The custom traceback attached to the resulting
    ``ag.TfRuntimeError`` must:
      * reference this test file,
      * contain no frame whose filename contains '/tmp/' or
        'control_flow.py', nor any function name containing 'ag__.',
      * contain exactly one ``test_fn`` frame showing 'return g(x, 10)' and
        exactly one ``g`` frame showing 'x //= 0'.
    """

    def g(x, s):
        while tf.reduce_sum(x) > s:
            x //= 0
        return x

    def test_fn(x):
        return g(x, 10)

    compiled_fn = ag.to_graph(test_fn)

    with self.assertRaises(ag.TfRuntimeError) as error:
        with self.cached_session() as sess:
            x = compiled_fn(tf.constant([4, 8]))
            with ag.improved_errors(compiled_fn):
                sess.run(x)

    expected = error.exception
    custom_traceback = expected.custom_traceback
    found_correct_filename = False
    num_test_fn_frames = 0
    num_g_frames = 0
    for frame in custom_traceback:
        filename, _, fn_name, source_code = frame
        # assertNotIn reports the offending value on failure, unlike
        # assertFalse(a in b) which only says "True is not false".
        self.assertNotIn('/tmp/', filename)
        self.assertNotIn('control_flow.py', filename)
        self.assertNotIn('ag__.', fn_name)
        found_correct_filename |= __file__ in filename
        num_test_fn_frames += int(
            fn_name == 'test_fn' and 'return g(x, 10)' in source_code)
        num_g_frames += int(fn_name == 'g' and 'x //= 0' in source_code)
    self.assertTrue(found_correct_filename)
    self.assertEqual(num_test_fn_frames, 1)
    self.assertEqual(num_g_frames, 1)
def test_graph_construction_error_rewriting_call_tree(self):
    """Graph-construction errors in a call chain keep the original names.

    ``caller`` -> ``inner_caller`` -> ``test_fn`` ends in a
    ``tf.random_normal`` call with ``dtype=tf.int32``, which this test expects
    to raise ``ag.GraphConstructionError``. The rewritten traceback must
    reference this file, contain no filename with '/tmp/', and show each
    function exactly once under its original name (never the converted
    'tf__'-prefixed name).
    """

    def test_fn():
        return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)

    def inner_caller():
        return test_fn()

    def caller():
        return inner_caller()

    # Convert outside assertRaises so that only the call which builds the
    # graph can satisfy the expected exception.
    graph = ag.to_graph(caller)
    with self.assertRaises(ag.GraphConstructionError) as error:
        graph()

    expected = error.exception
    custom_traceback = expected.custom_traceback
    found_correct_filename = False
    num_test_fn_names = 0
    num_inner_caller_names = 0
    num_caller_names = 0
    for frame in custom_traceback:
        filename, _, fn_name, _ = frame
        self.assertNotIn('/tmp/', filename)
        found_correct_filename |= __file__ in filename
        # Converted names must be mapped back to the user-visible ones.
        self.assertNotEqual('tf__test_fn', fn_name)
        num_test_fn_names += int(fn_name == 'test_fn')
        self.assertNotEqual('tf__inner_caller', fn_name)
        num_inner_caller_names += int(fn_name == 'inner_caller')
        self.assertNotEqual('tf__caller', fn_name)
        num_caller_names += int(fn_name == 'caller')
    self.assertTrue(found_correct_filename)
    self.assertEqual(num_test_fn_names, 1)
    self.assertEqual(num_inner_caller_names, 1)
    self.assertEqual(num_caller_names, 1)
def test_runtime_error_rewriting(self):
    """Runtime errors raised inside converted code map back to user source.

    ``g`` floor-divides by zero inside a while loop whose condition depends on
    a tensor. The custom traceback attached to the resulting
    ``ag.TfRuntimeError`` must:
      * reference this test file,
      * contain no frame whose filename contains '/tmp/' or
        'control_flow.py', nor any function name containing 'ag__.',
      * contain exactly one ``test_fn`` frame showing 'return g(x, 10)' and
        exactly one ``g`` frame showing 'x //= 0'.
    """

    def g(x, s):
        while tf.reduce_sum(x) > s:
            x //= 0
        return x

    def test_fn(x):
        return g(x, 10)

    compiled_fn = ag.to_graph(test_fn)

    with self.assertRaises(ag.TfRuntimeError) as error:
        with self.cached_session() as sess:
            x = compiled_fn(tf.constant([4, 8]))
            with ag.improved_errors(compiled_fn):
                sess.run(x)

    expected = error.exception
    custom_traceback = expected.custom_traceback
    found_correct_filename = False
    num_test_fn_frames = 0
    num_g_frames = 0
    for frame in custom_traceback:
        filename, _, fn_name, source_code = frame
        # assertNotIn reports the offending value on failure, unlike
        # assertFalse(a in b) which only says "True is not false".
        self.assertNotIn('/tmp/', filename)
        self.assertNotIn('control_flow.py', filename)
        self.assertNotIn('ag__.', fn_name)
        found_correct_filename |= __file__ in filename
        num_test_fn_frames += int(
            fn_name == 'test_fn' and 'return g(x, 10)' in source_code)
        num_g_frames += int(fn_name == 'g' and 'x //= 0' in source_code)
    self.assertTrue(found_correct_filename)
    self.assertEqual(num_test_fn_frames, 1)
    self.assertEqual(num_g_frames, 1)
def test_runtime_error_rewriting_nested(self):
    """Runtime errors from a nested helper are reported under its converted name.

    The inner ``g`` floor-divides by zero; running the converted ``test_fn``
    must raise ``ag.TfRuntimeError``. The traceback is expected to contain
    exactly one 'tf__g' frame and no plain 'g' frame (see the TODO below).
    """

    def test_fn(x):
        def g(y):
            return y**2 // 0

        s = 0
        for xi in x:
            s += g(xi)
        return s

    compiled_fn = ag.to_graph(test_fn)

    # TODO(b/111408261): Nested functions currently do not rewrite correctly,
    # when they do we should change this test to check for the same traceback
    # properties as the other tests. This should throw a runtime error with a
    # frame with "g" as the function name but because we don't yet add
    # try/except blocks to inner functions the name is "tf__g".
    with self.assertRaises(ag.TfRuntimeError) as error:
        with self.cached_session() as sess:
            x = compiled_fn(tf.constant([4, 8]))
            with ag.improved_errors(compiled_fn):
                sess.run(x)

    # Materialize once: the frames are walked twice below.
    frames = list(error.exception.custom_traceback)
    for _, _, frame_fn_name, _ in frames:
        self.assertNotEqual('g', frame_fn_name)
    converted_g_hits = sum(
        1 for _, _, frame_fn_name, _ in frames if frame_fn_name == 'tf__g')
    self.assertEqual(converted_g_hits, 1)
def test_graph_construction_error_rewriting_call_tree(self):
    """Graph-construction errors in a call chain keep the original names.

    ``caller`` -> ``inner_caller`` -> ``test_fn`` ends in a
    ``tf.random_normal`` call with ``dtype=tf.int32``, which this test expects
    to raise ``ag.GraphConstructionError``. The rewritten traceback must
    reference this file, contain no filename with '/tmp/', and show each
    function exactly once under its original name (never the converted
    'tf__'-prefixed name).
    """

    def test_fn():
        return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)

    def inner_caller():
        return test_fn()

    def caller():
        return inner_caller()

    # Convert outside assertRaises so that only the call which builds the
    # graph can satisfy the expected exception.
    graph = ag.to_graph(caller)
    with self.assertRaises(ag.GraphConstructionError) as error:
        graph()

    expected = error.exception
    custom_traceback = expected.custom_traceback
    found_correct_filename = False
    num_test_fn_names = 0
    num_inner_caller_names = 0
    num_caller_names = 0
    for frame in custom_traceback:
        filename, _, fn_name, _ = frame
        self.assertNotIn('/tmp/', filename)
        found_correct_filename |= __file__ in filename
        # Converted names must be mapped back to the user-visible ones.
        self.assertNotEqual('tf__test_fn', fn_name)
        num_test_fn_names += int(fn_name == 'test_fn')
        self.assertNotEqual('tf__inner_caller', fn_name)
        num_inner_caller_names += int(fn_name == 'inner_caller')
        self.assertNotEqual('tf__caller', fn_name)
        num_caller_names += int(fn_name == 'caller')
    self.assertTrue(found_correct_filename)
    self.assertEqual(num_test_fn_names, 1)
    self.assertEqual(num_inner_caller_names, 1)
    self.assertEqual(num_caller_names, 1)
def test_runtime_error_rewriting_nested(self):
    """Runtime errors from a nested helper are reported under its converted name.

    The inner ``g`` floor-divides by zero; running the converted ``test_fn``
    must raise ``ag.TfRuntimeError``. The traceback is expected to contain
    exactly one 'tf__g' frame and no plain 'g' frame (see the TODO below).
    """

    def test_fn(x):
        def g(y):
            return y**2 // 0

        s = 0
        for xi in x:
            s += g(xi)
        return s

    compiled_fn = ag.to_graph(test_fn)

    # TODO(b/111408261): Nested functions currently do not rewrite correctly,
    # when they do we should change this test to check for the same traceback
    # properties as the other tests. This should throw a runtime error with a
    # frame with "g" as the function name but because we don't yet add
    # try/except blocks to inner functions the name is "tf__g".
    with self.assertRaises(ag.TfRuntimeError) as error:
        with self.cached_session() as sess:
            x = compiled_fn(tf.constant([4, 8]))
            with ag.improved_errors(compiled_fn):
                sess.run(x)

    # Materialize once: the frames are walked twice below.
    frames = list(error.exception.custom_traceback)
    for _, _, frame_fn_name, _ in frames:
        self.assertNotEqual('g', frame_fn_name)
    converted_g_hits = sum(
        1 for _, _, frame_fn_name, _ in frames if frame_fn_name == 'tf__g')
    self.assertEqual(converted_g_hits, 1)
def test_graph_construction_error_rewriting_class(self):
    """Graph-construction errors in a converted class are not rewritten.

    ``TestClass.caller`` -> ``inner_caller`` -> ``test_fn`` ends in a
    ``tf.random_normal`` call with ``dtype=tf.int32``; because tracebacks are
    not rewritten for classes, the raw ``TypeError`` is expected to surface.
    """

    class TestClass(object):

        def test_fn(self):
            return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)

        def inner_caller(self):
            return self.test_fn()

        def caller(self):
            return self.inner_caller()

    # Convert outside assertRaises so a TypeError raised by the conversion
    # itself cannot make this test pass by accident.
    converted_cls = ag.to_graph(TestClass)
    # Note we expect a TypeError here because the traceback will not be
    # rewritten for classes.
    with self.assertRaises(TypeError):
        converted_cls().caller()
def test_graph_construction_error_rewriting_class(self):
    """Graph-construction errors in a converted class are not rewritten.

    ``TestClass.caller`` -> ``inner_caller`` -> ``test_fn`` ends in a
    ``tf.random_normal`` call with ``dtype=tf.int32``; because tracebacks are
    not rewritten for classes, the raw ``TypeError`` is expected to surface.
    """

    class TestClass(object):

        def test_fn(self):
            return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)

        def inner_caller(self):
            return self.test_fn()

        def caller(self):
            return self.inner_caller()

    # Convert outside assertRaises so a TypeError raised by the conversion
    # itself cannot make this test pass by accident.
    converted_cls = ag.to_graph(TestClass)
    # Note we expect a TypeError here because the traceback will not be
    # rewritten for classes.
    with self.assertRaises(TypeError):
        converted_cls().caller()
def test_basic(self):
    """The converted ``list_used_as_tuple`` evaluates to [1, 2, 3]."""
    converted = ag.to_graph(list_used_as_tuple)
    result = converted()
    # The session context is kept for self.evaluate(); the previously bound
    # session object was never used.
    with self.cached_session():
        self.assertAllEqual(self.evaluate(result), [1, 2, 3])
def _createEqnModel(self, inputVar, outputVar, mdlName, eqMdl):
    """
    Build a K-CHAIN model using input and output variables from the KG and the physics equation.

    The equation snippet is wrapped in a generated Python function named
    ``mdlName`` that takes a single argument ``inArg`` and unpacks it
    positionally into the input variable names. The source is written to disk
    with ``self._makePyFile``, picked up by reloading the ``eqnModels``
    package, converted with AutoGraph, and exported as a MetaGraph.

    Arguments:
        inputVar (JSON array): array of JSON variable objects with name, type, value, and unit fields
        outputVar (JSON array): array of JSON variable objects with name, type, value, and unit fields.
            Only the first entry's name is returned by the generated function.
        mdlName (string): Name to assign to the final model (E.g.: 'Newtons2ndLaw')
        eqMdl (string): Equation relating inputs to output (E.g.: "c = a * b").
            May span several lines; every line is indented into the generated function body.

    Returns:
        (TensorFlow Graph, string):
            * TensorFlow Graph: Computational graph of the physics equation
            * metagraphLoc: string of location on disk where computational model was stored
              (the MetaGraph file itself is ``metagraphLoc + '.meta'``)
    """
    # importlib.reload replaces imp.reload: the imp module is deprecated and
    # was removed in Python 3.12.
    import importlib
    import textwrap

    indent = '    '  # 4 spaces is ideal for indentation

    # One "name = inArg[i]" assignment per input variable. join() also copes
    # with an empty inputVar, where indexing inputVar[0] used to raise.
    in_assignments = '\n'.join(
        '{} = inArg[{}]'.format(var['name'], ii)
        for ii, var in enumerate(inputVar))
    body = (in_assignments + '\n' + eqMdl + '\n'
            + 'return ' + outputVar[0]['name'])
    # textwrap.indent indents every line, so multi-line equation snippets stay
    # inside the generated function instead of breaking out of it.
    stringfun = ('import tensorflow as tf'
                 + '\ndef ' + mdlName + '(inArg):\n'
                 + textwrap.indent(body, indent) + '\n\n')
    if self.debug:
        print(stringfun)

    # NOTE(review): eqMdl is executed as Python code once the generated module
    # is imported -- only feed equations from trusted sources.
    # write the python code into a file where AutoGraph can read it
    self._makePyFile(stringfun)
    # reload the eqnModels package as a method was newly added
    importlib.reload(eqnModels)
    # get the method created by the code
    tmp_method = getattr(eqnModels, mdlName)
    if self.debug:
        print(tmp_method)

    # NOTE(review): relative to the current working directory.
    metagraphLoc = "../models/" + mdlName

    tf.reset_default_graph()
    mdl = tf.Graph()
    with mdl.as_default():
        # One placeholder per input variable, named after the KG variable.
        invars = [
            tf.placeholder(self._getVarType(var['type']), name=var['name'])
            for var in inputVar
        ]
        # create TensorFlow graph from python code
        tf_model = ag.to_graph(tmp_method)
        # obtain TensorFlow model of input and outputs
        output = tf_model(invars)
        tf.add_to_collection("output", output)
        for node in invars:
            tf.add_to_collection("input", node)

    # save model locally as a MetaGraph (graph passed explicitly)
    tf.train.export_meta_graph(filename=metagraphLoc + '.meta', graph=mdl)
    return mdl, metagraphLoc
def test_basic(self):
    """The converted ``list_used_as_tuple`` evaluates to [1, 2, 3]."""
    converted = ag.to_graph(list_used_as_tuple)
    result = converted()
    # The session context is kept for self.evaluate(); the previously bound
    # session object was never used.
    with self.cached_session():
        self.assertAllEqual(self.evaluate(result), [1, 2, 3])