Exemplo n.º 1
0
  def test_runtime_error_rewriting(self):
    """Runtime errors from converted code are reported against user frames.

    `g` divides by zero inside a `while` loop whose condition depends on a
    tensor, and the error is expected to surface as `ag.TfRuntimeError` while
    the graph is built/run. The rewritten traceback must not expose
    temp-file paths, `control_flow.py` frames or `ag__.` helpers, must point
    at this file, and must contain exactly one frame each for `test_fn` and
    `g`.
    """

    def g(x, s):
      while tf.reduce_sum(x) > s:
        x //= 0
      return x

    def test_fn(x):
      return g(x, 10)

    compiled_fn = ag.to_graph(test_fn)

    with self.assertRaises(ag.TfRuntimeError) as error:
      with self.cached_session() as sess:
        x = compiled_fn(tf.constant([4, 8]))
        with ag.improved_errors(compiled_fn):
          sess.run(x)
    custom_traceback = error.exception.custom_traceback
    found_correct_filename = False
    num_test_fn_frames = 0
    num_g_frames = 0
    for filename, _, fn_name, source_code in custom_traceback:
      # assertNotIn reports the offending value on failure, unlike
      # assertFalse(... in ...).
      self.assertNotIn('/tmp/', filename)
      self.assertNotIn('control_flow.py', filename)
      self.assertNotIn('ag__.', fn_name)
      found_correct_filename |= __file__ in filename
      # The source-text checks rely on test_fn/g above staying verbatim.
      num_test_fn_frames += int('test_fn' == fn_name and
                                'return g(x, 10)' in source_code)
      num_g_frames += int('g' == fn_name and 'x //= 0' in source_code)
    self.assertTrue(found_correct_filename)
    self.assertEqual(num_test_fn_frames, 1)
    self.assertEqual(num_g_frames, 1)
Exemplo n.º 2
0
  def test_graph_construction_error_rewriting_call_tree(self):
    """Graph-construction errors are reported against the original call tree.

    `tf.random_normal` with an integer dtype is expected to raise
    `ag.GraphConstructionError` when the converted `caller` is invoked. The
    rewritten traceback must point at this file, must not expose temp-file
    paths or `tf__`-prefixed converted names, and must contain exactly one
    frame each for `test_fn`, `inner_caller` and `caller`.
    """

    def test_fn():
      return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)

    def inner_caller():
      return test_fn()

    def caller():
      return inner_caller()

    with self.assertRaises(ag.GraphConstructionError) as error:
      graph = ag.to_graph(caller)
      graph()
    custom_traceback = error.exception.custom_traceback
    found_correct_filename = False
    num_test_fn_names = 0
    num_inner_caller_names = 0
    num_caller_names = 0
    for filename, _, fn_name, _ in custom_traceback:
      # assertNotIn reports the offending value on failure.
      self.assertNotIn('/tmp/', filename)
      found_correct_filename |= __file__ in filename
      # Converted names (`tf__<name>`) must have been mapped back to originals.
      self.assertNotEqual('tf__test_fn', fn_name)
      num_test_fn_names += int('test_fn' == fn_name)
      self.assertNotEqual('tf__inner_caller', fn_name)
      num_inner_caller_names += int('inner_caller' == fn_name)
      self.assertNotEqual('tf__caller', fn_name)
      num_caller_names += int('caller' == fn_name)
    self.assertTrue(found_correct_filename)
    self.assertEqual(num_test_fn_names, 1)
    self.assertEqual(num_inner_caller_names, 1)
    self.assertEqual(num_caller_names, 1)
Exemplo n.º 3
0
    def test_runtime_error_rewriting(self):
        """Runtime errors from converted code are reported against user frames.

        `g` divides by zero inside a `while` loop whose condition depends on a
        tensor, and the error is expected to surface as `ag.TfRuntimeError`
        while the graph is built/run. The rewritten traceback must not expose
        temp-file paths, `control_flow.py` frames or `ag__.` helpers, must
        point at this file, and must contain exactly one frame each for
        `test_fn` and `g`.
        """
        def g(x, s):
            while tf.reduce_sum(x) > s:
                x //= 0
            return x

        def test_fn(x):
            return g(x, 10)

        compiled_fn = ag.to_graph(test_fn)

        with self.assertRaises(ag.TfRuntimeError) as error:
            with self.cached_session() as sess:
                x = compiled_fn(tf.constant([4, 8]))
                with ag.improved_errors(compiled_fn):
                    sess.run(x)
        custom_traceback = error.exception.custom_traceback
        found_correct_filename = False
        num_test_fn_frames = 0
        num_g_frames = 0
        for filename, _, fn_name, source_code in custom_traceback:
            # assertNotIn reports the offending value on failure, unlike
            # assertFalse(... in ...).
            self.assertNotIn('/tmp/', filename)
            self.assertNotIn('control_flow.py', filename)
            self.assertNotIn('ag__.', fn_name)
            found_correct_filename |= __file__ in filename
            # The source-text checks rely on test_fn/g above staying verbatim.
            num_test_fn_frames += int('test_fn' == fn_name
                                      and 'return g(x, 10)' in source_code)
            num_g_frames += int('g' == fn_name and 'x //= 0' in source_code)
        self.assertTrue(found_correct_filename)
        self.assertEqual(num_test_fn_frames, 1)
        self.assertEqual(num_g_frames, 1)
Exemplo n.º 4
0
  def test_runtime_error_rewriting_nested(self):
    """A runtime error inside a nested function keeps its converted name.

    Nested functions are not rewritten yet (see the TODO below), so the
    traceback is expected to contain exactly one `tf__g` frame and no frame
    named plain `g`.
    """

    def test_fn(x):

      def g(y):
        return y**2 // 0

      s = 0
      for xi in x:
        s += g(xi)
      return s

    compiled_fn = ag.to_graph(test_fn)

    # TODO(b/111408261): Nested functions currently do not rewrite correctly,
    # when they do we should change this test to check for the same traceback
    # properties as the other tests.  This should throw a runtime error with a
    # frame with "g" as the function name but because we don't yet add
    # try/except blocks to inner functions the name is "tf__g".
    with self.assertRaises(ag.TfRuntimeError) as error:
      with self.cached_session() as sess:
        graph_out = compiled_fn(tf.constant([4, 8]))
        with ag.improved_errors(compiled_fn):
          sess.run(graph_out)
    frame_fn_names = [
        fn_name for _, _, fn_name, _ in error.exception.custom_traceback
    ]
    for name in frame_fn_names:
      self.assertNotEqual('g', name)
    self.assertEqual(frame_fn_names.count('tf__g'), 1)
Exemplo n.º 5
0
    def test_graph_construction_error_rewriting_call_tree(self):
        """Graph-construction errors are reported against the original call tree.

        `tf.random_normal` with an integer dtype is expected to raise
        `ag.GraphConstructionError` when the converted `caller` is invoked.
        The rewritten traceback must point at this file, must not expose
        temp-file paths or `tf__`-prefixed converted names, and must contain
        exactly one frame each for `test_fn`, `inner_caller` and `caller`.
        """
        def test_fn():
            return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)

        def inner_caller():
            return test_fn()

        def caller():
            return inner_caller()

        with self.assertRaises(ag.GraphConstructionError) as error:
            graph = ag.to_graph(caller)
            graph()
        custom_traceback = error.exception.custom_traceback
        found_correct_filename = False
        num_test_fn_names = 0
        num_inner_caller_names = 0
        num_caller_names = 0
        for filename, _, fn_name, _ in custom_traceback:
            # assertNotIn reports the offending value on failure.
            self.assertNotIn('/tmp/', filename)
            found_correct_filename |= __file__ in filename
            # Converted names (`tf__<name>`) must have been mapped back.
            self.assertNotEqual('tf__test_fn', fn_name)
            num_test_fn_names += int('test_fn' == fn_name)
            self.assertNotEqual('tf__inner_caller', fn_name)
            num_inner_caller_names += int('inner_caller' == fn_name)
            self.assertNotEqual('tf__caller', fn_name)
            num_caller_names += int('caller' == fn_name)
        self.assertTrue(found_correct_filename)
        self.assertEqual(num_test_fn_names, 1)
        self.assertEqual(num_inner_caller_names, 1)
        self.assertEqual(num_caller_names, 1)
Exemplo n.º 6
0
    def test_runtime_error_rewriting_nested(self):
        """A runtime error inside a nested function keeps its converted name.

        Nested functions are not rewritten yet (see the TODO below), so the
        traceback is expected to contain exactly one `tf__g` frame and no
        frame named plain `g`.
        """
        def test_fn(x):
            def g(y):
                return y**2 // 0

            s = 0
            for xi in x:
                s += g(xi)
            return s

        compiled_fn = ag.to_graph(test_fn)

        # TODO(b/111408261): Nested functions currently do not rewrite correctly,
        # when they do we should change this test to check for the same traceback
        # properties as the other tests.  This should throw a runtime error with a
        # frame with "g" as the function name but because we don't yet add
        # try/except blocks to inner functions the name is "tf__g".
        with self.assertRaises(ag.TfRuntimeError) as error:
            with self.cached_session() as sess:
                graph_out = compiled_fn(tf.constant([4, 8]))
                with ag.improved_errors(compiled_fn):
                    sess.run(graph_out)
        frame_fn_names = [
            fn_name for _, _, fn_name, _ in error.exception.custom_traceback
        ]
        for name in frame_fn_names:
            self.assertNotEqual('g', name)
        self.assertEqual(frame_fn_names.count('tf__g'), 1)
Exemplo n.º 7
0
    def test_graph_construction_error_rewriting_class(self):
        """Errors raised through a converted class are not rewritten.

        Traceback rewriting does not cover classes, so the error from calling
        `caller` on an instance of the converted class is expected to stay a
        plain `TypeError` (presumably from `tf.random_normal` with an integer
        dtype) rather than an `ag.GraphConstructionError`.
        """
        class TestClass(object):
            def test_fn(self):
                return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)

            def inner_caller(self):
                return self.test_fn()

            def caller(self):
                return self.inner_caller()

        # Note we expect a TypeError here because the traceback will not be
        # rewritten for classes.
        with self.assertRaises(TypeError):
            instance = ag.to_graph(TestClass)()
            instance.caller()
Exemplo n.º 8
0
  def test_graph_construction_error_rewriting_class(self):
    """Errors raised through a converted class are not rewritten.

    Traceback rewriting does not cover classes, so the error from calling
    `caller` on an instance of the converted class is expected to stay a
    plain `TypeError` (presumably from `tf.random_normal` with an integer
    dtype) rather than an `ag.GraphConstructionError`.
    """

    class TestClass(object):

      def test_fn(self):
        return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)

      def inner_caller(self):
        return self.test_fn()

      def caller(self):
        return self.inner_caller()

    # Note we expect a TypeError here because the traceback will not be
    # rewritten for classes.
    with self.assertRaises(TypeError):
      instance = ag.to_graph(TestClass)()
      instance.caller()
  def test_basic(self):
    """The converted `list_used_as_tuple` evaluates to [1, 2, 3]."""
    converted = ag.to_graph(list_used_as_tuple)
    result = converted()

    # The session object itself is unused; the context only keeps a session
    # active while `self.evaluate` runs.
    with self.cached_session():
      self.assertAllEqual(self.evaluate(result), [1, 2, 3])
Exemplo n.º 10
0
    def _createEqnModel(self, inputVar, outputVar, mdlName, eqMdl):
        """
        Build a K-CHAIN model using input and output variables from the KG and
        the physics equation.

        The equation is wrapped in a generated Python function named
        ``mdlName``, written out with ``self._makePyFile``, loaded from the
        reloaded ``eqnModels`` package, converted with AutoGraph, and applied
        to one placeholder per input inside a fresh ``tf.Graph``. The graph is
        then exported as a MetaGraph.

        Arguments:
            inputVar (JSON array):
                array of JSON variable objects with name, type, value, and unit
                fields. Each entry becomes a placeholder named after its
                ``name``; an empty array yields a function with no inputs.
            outputVar (JSON array):
                array of JSON variable objects with name, type, value, and unit
                fields. Only the first entry is used: its ``name`` is the
                variable returned by the generated function.
            mdlName (string):
                Name to assign to the final model (E.g.: 'Newtons2ndLaw')
            eqMdl (string):
                Equation relating inputs to output (E.g.: "c = a * b")

        Returns:
            (TensorFlow Graph, string):
                * TensorFlow Graph: Computational graph of the physics equation
                * metagraphLoc: string of location on disk where computational
                  model was stored; the file written is ``metagraphLoc + '.meta'``

        Warning:
            ``mdlName``, the variable names and ``eqMdl`` are pasted verbatim
            into Python source that is then imported (i.e. executed). Only pass
            trusted input.
        """
        # ``imp`` is deprecated (removed in Python 3.12); ``importlib.reload``
        # is its documented replacement with the same semantics.
        import importlib

        # One "name = inArg[i]" line per input, joined by a newline plus the
        # 4-space indent of the generated function body.
        inStr = '\n    '.join(
            var['name'] + ' = inArg[' + str(ii) + ']'
            for ii, var in enumerate(inputVar))

        # Construct the Python function around the equation snippet
        # (4 spaces is the indentation used inside the generated body).
        stringfun = ('import tensorflow as tf'
                     + '\ndef ' + mdlName + '(inArg):'
                     + '\n    ' + inStr
                     + '\n    ' + eqMdl
                     + '\n    return ' + outputVar[0]['name'] + '\n\n')

        # Diagnostic output, gated like the other debug print below.
        if self.debug:
            print(stringfun)

        # Write the Python code into a file where AutoGraph can read it.
        self._makePyFile(stringfun)

        # Reload the eqnModels package because a function was newly added.
        importlib.reload(eqnModels)

        # Get the function created by the generated code.
        tmp_method = getattr(eqnModels, mdlName)

        if self.debug:
            print(tmp_method)

        metagraphLoc = "../models/" + mdlName

        tf.reset_default_graph()
        mdl = tf.Graph()

        with mdl.as_default():
            # One placeholder per input, named after the input variable.
            invars = [
                tf.placeholder(self._getVarType(var['type']), name=var['name'])
                for var in inputVar
            ]

            # NOTE(review): the output dtype was looked up but never used;
            # the call is kept in case _getVarType validates the type.
            self._getVarType(outputVar[0]['type'])

            # Create a TensorFlow graph from the Python code and apply it to
            # the input placeholders.
            tf_model = ag.to_graph(tmp_method)
            output = tf_model(invars)

            tf.add_to_collection("output", output)
            for node in invars:
                tf.add_to_collection("input", node)

            # Save the model locally as a MetaGraph.
            tf.train.export_meta_graph(filename=metagraphLoc + '.meta',
                                       graph=mdl)

        return mdl, metagraphLoc
Exemplo n.º 11
0
    def test_basic(self):
        """The converted `list_used_as_tuple` evaluates to [1, 2, 3]."""
        converted = ag.to_graph(list_used_as_tuple)
        result = converted()

        # The session object itself is unused; the context only keeps a
        # session active while `self.evaluate` runs.
        with self.cached_session():
            self.assertAllEqual(self.evaluate(result), [1, 2, 3])