def test_convert_entity_to_ast_class_hierarchy(self):
    """Converting a user-defined class hierarchy raises NotImplementedError.

    The error message is expected to mention that classes must be whitelisted.
    """

    class TestBase(object):

      def __init__(self, x='base'):
        self.x = x

      def foo(self):
        return self.x

      def bar(self):
        return self.x

    class TestSubclass(TestBase):

      def __init__(self, y):
        super(TestSubclass, self).__init__('sub')
        self.y = y

      def foo(self):
        return self.y

      def baz(self):
        return self.y

    expected_message = 'classes.*whitelisted'
    ctx = self._simple_program_ctx()
    with self.assertRaisesRegex(NotImplementedError, expected_message):
      conversion.convert_entity_to_ast(TestSubclass, ctx)
  def test_convert_entity_to_ast_multiple_lambdas_ambiguous_definitions(self):
    """Expects ValueError for two same-signature lambdas on one source line."""
    a, b = 1, 2
    f, _ = (lambda x: a * x, lambda x: b * x)

    ctx = self._simple_program_ctx()
    self.assertRaises(ValueError, conversion.convert_entity_to_ast, f, ctx)
  def test_convert_entity_to_ast_multiple_lambdas_ambiguous_definitions(self):
    """Expects ValueError for two same-signature lambdas on one source line."""
    a, b = 1, 2
    f, _ = (lambda x: a * x, lambda x: b * x)

    ctx = self._simple_program_ctx()
    self.assertRaises(ValueError, conversion.convert_entity_to_ast, f, ctx)
Example #4
    def test_convert_entity_to_ast_class_hierarchy(self):
        """Converting a user-defined class hierarchy raises NotImplementedError.

        The error message is expected to mention that classes must be
        whitelisted.
        """
        class TestBase(object):
            def __init__(self, x='base'):
                self.x = x

            def foo(self):
                return self.x

            def bar(self):
                return self.x

        class TestSubclass(TestBase):
            def __init__(self, y):
                super(TestSubclass, self).__init__('sub')
                self.y = y

            def foo(self):
                return self.y

            def baz(self):
                return self.y

        ctx = self._simple_program_ctx()
        pattern = 'classes.*whitelisted'
        with self.assertRaisesRegex(NotImplementedError, pattern):
            conversion.convert_entity_to_ast(TestSubclass, ctx)
  def test_convert_entity_to_ast_multiple_lambdas(self):
    """Converts the first of two lambdas on one line whose parameters differ.

    The result is a single gast.Assign wrapping a gast.Lambda named
    'tf__lambda', and the free variable `a` is recorded in the namespace.
    """
    a, b = 1, 2
    f, _ = (lambda x: a * x, lambda y: b * y)

    ctx = self._simple_program_ctx()
    nodes, name, entity_info = conversion.convert_entity_to_ast(f, ctx)
    (fn_node,) = nodes
    self.assertIsInstance(fn_node, gast.Assign)
    self.assertIsInstance(fn_node.value, gast.Lambda)
    self.assertEqual(name, 'tf__lambda')
    self.assertIs(entity_info.namespace['a'], a)
Example #6
    def test_convert_entity_to_ast_lambda(self):
        """A lone lambda converts to a gast.Assign wrapping a gast.Lambda.

        The free variable `b` must appear in the resulting entity namespace.
        """
        b = 2
        f = lambda x: b * x if x > 0 else -x

        ctx = self._simple_program_ctx()
        nodes, name, entity_info = conversion.convert_entity_to_ast(f, ctx)
        fn_node, = nodes
        self.assertIsInstance(fn_node, gast.Assign)
        self.assertIsInstance(fn_node.value, gast.Lambda)
        self.assertEqual(name, 'tf__lambda')
        self.assertIs(entity_info.namespace['b'], b)
Example #7
    def test_convert_entity_to_ast_multiple_lambdas(self):
        """Converts the first of two lambdas on one line whose parameters differ.

        The result is a single gast.Assign wrapping a gast.Lambda named
        'tf__lambda', and the free variable `a` is recorded in the namespace.
        """
        a, b = 1, 2
        f, _ = (lambda x: a * x, lambda y: b * y)

        ctx = self._simple_program_ctx()
        nodes, name, entity_info = conversion.convert_entity_to_ast(f, ctx)
        (fn_node, ) = nodes
        self.assertIsInstance(fn_node, gast.Assign)
        self.assertIsInstance(fn_node.value, gast.Lambda)
        self.assertEqual(name, 'tf__lambda')
        self.assertIs(entity_info.namespace['a'], a)
Example #8
    def test_convert_entity_to_ast_call_tree(self):
        """Converting `f`, which calls `g`, yields exactly one node, tf__f."""
        def g(a):
            return a

        def f(a):
            return g(a)

        ctx = self._simple_program_ctx()
        (f_node, ), _, _ = conversion.convert_entity_to_ast(f, ctx)
        self.assertEqual(f_node.name, 'tf__f')
  def test_convert_entity_to_ast_lambda(self):
    """A lone lambda converts to a gast.Assign wrapping a gast.Lambda.

    The free variable `b` must appear in the resulting entity namespace.
    """
    b = 2
    f = lambda x: b * x if x > 0 else -x

    ctx = self._simple_program_ctx()
    nodes, name, entity_info = conversion.convert_entity_to_ast(f, ctx)
    fn_node, = nodes
    self.assertIsInstance(fn_node, gast.Assign)
    self.assertIsInstance(fn_node.value, gast.Lambda)
    self.assertEqual(name, 'tf__lambda')
    self.assertIs(entity_info.namespace['b'], b)
Example #10
    def test_convert_entity_to_ast_callable(self):
        """A plain function converts to a single gast.FunctionDef named tf__f.

        The closure variable `b` must appear in the resulting namespace.
        """
        b = 2

        def f(a):
            return a + b

        ctx = self._simple_program_ctx()
        (fn_node, ), name, info = conversion.convert_entity_to_ast(f, ctx)
        self.assertIsInstance(fn_node, gast.FunctionDef)
        self.assertEqual(name, 'tf__f')
        self.assertIs(info.namespace['b'], b)
  def test_convert_entity_to_ast_call_tree(self):
    """Converting `f`, which calls `g`, yields exactly one node, tf__f."""

    def g(a):
      return a

    def f(a):
      return g(a)

    ctx = self._simple_program_ctx()
    (f_node,), _, _ = conversion.convert_entity_to_ast(f, ctx)
    self.assertEqual(f_node.name, 'tf__f')
  def test_convert_entity_to_ast_callable(self):
    """A plain function converts to a single gast.FunctionDef named tf__f.

    The closure variable `b` must appear in the resulting namespace.
    """
    b = 2

    def f(a):
      return a + b

    ctx = self._simple_program_ctx()
    (fn_node,), name, info = conversion.convert_entity_to_ast(f, ctx)
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual(name, 'tf__f')
    self.assertIs(info.namespace['b'], b)
Example #13
    def test_convert_entity_to_ast_lambda_code_with_garbage(self):
        """A lambda wrapped across lines amid other tokens still converts.

        The lambda's source shares its lines with a tuple display, a
        subscript and comments. Conversion must still yield a gast.Assign of
        a gast.Lambda.
        """
        # pylint:disable=g-long-lambda
        f = (  # intentional wrap
            lambda x: (
                x  # intentional wrap
                + 1), )[0]
        # pylint:enable=g-long-lambda

        ctx = self._simple_program_ctx()
        converted, name, _ = conversion.convert_entity_to_ast(f, ctx)
        fn_node, = converted
        self.assertIsInstance(fn_node, gast.Assign)
        self.assertIsInstance(fn_node.value, gast.Lambda)
        self.assertEqual(name, 'tf__lambda')
  def test_convert_entity_to_ast_lambda_code_with_garbage(self):
    """A lambda wrapped across lines amid other tokens still converts.

    The lambda's source shares its lines with a tuple display, a subscript
    and comments. Conversion must still yield a gast.Assign of a gast.Lambda.
    """
    # pylint:disable=g-long-lambda
    f = (  # intentional wrap
        lambda x: (
            x  # intentional wrap
            + 1),)[0]
    # pylint:enable=g-long-lambda

    ctx = self._simple_program_ctx()
    converted, name, _ = conversion.convert_entity_to_ast(f, ctx)
    fn_node, = converted
    self.assertIsInstance(fn_node, gast.Assign)
    self.assertIsInstance(fn_node.value, gast.Lambda)
    self.assertEqual(name, 'tf__lambda')
Example #15
    def test_convert_entity_to_ast_function_with_defaults(self):
        """Default argument expressions become `None` in the converted AST."""
        b = 2
        c = 1

        def f(a, d=c + 1):
            return a + b + d

        ctx = self._simple_program_ctx()
        (fn_node, ), name, _ = conversion.convert_entity_to_ast(f, ctx)
        self.assertIsInstance(fn_node, gast.FunctionDef)
        self.assertEqual(name, 'tf__f')
        default_src = compiler.ast_to_source(fn_node.args.defaults[0])
        self.assertEqual(default_src.strip(), 'None')
  def test_convert_entity_to_ast_function_with_defaults(self):
    """Default argument expressions become `None` in the converted AST."""
    b = 2
    c = 1

    def f(a, d=c + 1):
      return a + b + d

    ctx = self._simple_program_ctx()
    (fn_node,), name, _ = conversion.convert_entity_to_ast(f, ctx)
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual(name, 'tf__f')
    default_src = compiler.ast_to_source(fn_node.args.defaults[0])
    self.assertEqual(default_src.strip(), 'None')
Example #17
    def test_convert_entity_to_ast_function_with_defaults(self):
        """Default argument expressions become `None` in the converted AST."""
        b = 2
        c = 1

        def f(a, d=c + 1):
            return a + b + d

        ctx = self._simple_program_ctx()
        (fn_node, ), name, _ = conversion.convert_entity_to_ast(f, ctx)
        self.assertIsInstance(fn_node, gast.FunctionDef)
        self.assertEqual(name, 'tf__f')
        default_src = parser.unparse(
            fn_node.args.defaults[0], include_encoding_marker=False)
        self.assertEqual(default_src.strip(), 'None')
Example #18
    def test_convert_entity_to_ast_class_hierarchy_whitelisted(self):
        """A training.Model subclass converts to an import plus a class node.

        The first node imports `Model`; the second is the class, renamed to
        'TfTestSubclass'.
        """
        class TestSubclass(training.Model):
            def __init__(self, y):
                super(TestSubclass, self).__init__()
                self.built = False

            def call(self, x):
                return 3 * x

        ctx = self._simple_program_ctx()
        nodes, name, _ = conversion.convert_entity_to_ast(TestSubclass, ctx)
        import_node, class_node = nodes
        self.assertEqual(import_node.names[0].name, 'Model')
        self.assertEqual(name, 'TfTestSubclass')
        self.assertEqual(class_node.name, 'TfTestSubclass')
Example #19
    def test_convert_entity_to_ast_nested_functions(self):
        """A function with a nested helper converts to one FunctionDef, tf__f.

        The free variable `b`, used only by the inner function, must appear
        in the resulting entity namespace.
        """
        b = 2

        def f(x):
            def g(x):
                return b * x

            return g(x)

        ctx = self._simple_program_ctx()
        nodes, name, entity_info = conversion.convert_entity_to_ast(f, ctx)
        (fn_node, ) = nodes
        self.assertIsInstance(fn_node, gast.FunctionDef)
        self.assertEqual(fn_node.name, 'tf__f')
        self.assertEqual(name, 'tf__f')
        self.assertIs(entity_info.namespace['b'], b)
  def test_convert_entity_to_ast_nested_functions(self):
    """A function with a nested helper converts to one FunctionDef, tf__f.

    The free variable `b`, used only by the inner function, must appear in
    the resulting entity namespace.
    """
    b = 2

    def f(x):

      def g(x):
        return b * x

      return g(x)

    ctx = self._simple_program_ctx()
    nodes, name, entity_info = conversion.convert_entity_to_ast(f, ctx)
    (fn_node,) = nodes
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual(fn_node.name, 'tf__f')
    self.assertEqual(name, 'tf__f')
    self.assertIs(entity_info.namespace['b'], b)
  def test_convert_entity_to_ast_class_hierarchy_whitelisted(self):
    """A training.Model subclass converts to an import plus a class node.

    The first node imports `Model`; the second is the class, renamed to
    'TfTestSubclass'.
    """

    class TestSubclass(training.Model):

      def __init__(self, y):
        super(TestSubclass, self).__init__()
        self.built = False

      def call(self, x):
        return 3 * x

    ctx = self._simple_program_ctx()
    nodes, name, _ = conversion.convert_entity_to_ast(TestSubclass, ctx)
    import_node, class_node = nodes
    self.assertEqual(import_node.names[0].name, 'Model')
    self.assertEqual(name, 'TfTestSubclass')
    self.assertEqual(class_node.name, 'TfTestSubclass')
 def test_convert_entity_to_ast_unsupported_types(self):
   """Converting a plain string entity raises NotImplementedError."""
   # Build the context outside the assertRaises scope so that a
   # NotImplementedError raised during setup cannot make this test pass
   # spuriously; only the conversion call itself is expected to raise.
   program_ctx = self._simple_program_ctx()
   with self.assertRaises(NotImplementedError):
     conversion.convert_entity_to_ast('dummy', program_ctx)
Example #23
 def test_convert_entity_to_ast_unsupported_types(self):
     """Converting a plain string entity raises NotImplementedError."""
     # Build the context outside the assertRaises scope so that a
     # NotImplementedError raised during setup cannot make this test pass
     # spuriously; only the conversion call itself is expected to raise.
     program_ctx = self._simple_program_ctx()
     with self.assertRaises(NotImplementedError):
         conversion.convert_entity_to_ast('dummy', program_ctx)