예제 #1
0
    def _wrap_to_py_func_no_return(self, node):
        """Build a `tf.py_func` wrapper for a call whose result is unused.

        A fresh symbol is reserved for the wrapper function and for each of
        the call's positional arguments. The template then defines a wrapper
        that forwards its arguments to the original callee and returns the
        constant 1, and invokes that wrapper through `tf.py_func`.

        Args:
          node: The call node. It must carry a `NodeAnno.ARGS_SCOPE`
            annotation, and `node.func` must carry an `anno.Basic.QN`
            annotation.

        Returns:
          A `(wrapper_def, call_expr)` tuple as produced by the template
          expansion; `wrapper_def` is annotated with
          `anno.Basic.SKIP_PROCESSING`.
        """
        func_qn = anno.getanno(node.func, anno.Basic.QN)
        args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
        namer = self.context.namer

        def fresh_name(qn):
            # Reserve a symbol derived from `qn` that does not clash with
            # anything referenced in the arguments scope.
            return namer.new_symbol(qn.ssf(), args_scope.referenced)

        wrapper_name = fresh_name(func_qn)
        # Arguments without a qualified name fall back to the generic 'arg'.
        wrapper_args = [
            fresh_name(
                anno.getanno(arg, anno.Basic.QN)
                if anno.hasanno(arg, anno.Basic.QN)
                else qual_names.QN('arg'))
            for arg in node.args
        ]

        # TODO(mdan): Properly handle varargs, kwargs, etc.
        # TODO(mdan): This is best handled as a dynamic dispatch.
        # That way we can separate tensors from non-tensor args.
        template = """
      def wrapper(wrapper_args):
        call(wrapper_args)
        return 1
      tf.py_func(wrapper, original_args, [tf.int64])
    """
        original_args = gast.List(elts=node.args, ctx=None)
        wrapper_def, call_expr = templates.replace(
            template,
            call=node.func,
            wrapper=wrapper_name,
            original_args=original_args,
            wrapper_args=wrapper_args)

        # The generated wrapper must not be converted again.
        anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
        return wrapper_def, call_expr
  def test_basic(self):
    """Exercise construction, rendering, parenthood and equality of QN.

    Covers a simple name, an attribute name built on top of it, and copies
    of both made by passing an existing QN to the constructor.
    """
    # Simple (non-composite) name.
    a = qual_names.QN('a')
    self.assertEqual(a.qn, ('a',))
    self.assertEqual(str(a), 'a')
    self.assertEqual(a.ssf(), 'a')
    self.assertEqual(a.ast().id, 'a')
    self.assertFalse(a.is_composite())
    with self.assertRaises(ValueError):
      _ = a.parent

    # Attribute (composite) name built from a parent QN.
    a_b = qual_names.QN(a, 'b')
    self.assertEqual(a_b.qn, ('a', 'b'))
    self.assertEqual(str(a_b), 'a.b')
    self.assertEqual(a_b.ssf(), 'a_b')
    self.assertEqual(a_b.ast().value.id, 'a')
    self.assertEqual(a_b.ast().attr, 'b')
    self.assertTrue(a_b.is_composite())
    self.assertEqual(a_b.parent.qn, ('a',))

    # Copy of a simple name: the copy itself must have no parent either.
    # (Previously this re-checked `a.parent`, leaving the copy untested.)
    a2 = qual_names.QN(a)
    self.assertEqual(a2.qn, ('a',))
    with self.assertRaises(ValueError):
      _ = a2.parent

    # Copy of a composite name keeps the same parent chain.
    a_b2 = qual_names.QN(a_b)
    self.assertEqual(a_b2.qn, ('a', 'b'))
    self.assertEqual(a_b2.parent.qn, ('a',))

    # Copies compare equal but are distinct objects. The explicit `==`
    # is kept so that the equality operator itself is what gets exercised.
    self.assertTrue(a2 == a)
    self.assertIsNot(a2, a)

    self.assertTrue(a_b.parent == a)
    self.assertTrue(a_b2.parent == a)

    self.assertTrue(a_b2 == a_b)
    self.assertIsNot(a_b2, a_b)
    self.assertFalse(a_b2 == a)

    # An attribute QN cannot be built on a plain string base.
    with self.assertRaises(ValueError):
      qual_names.QN('a', 'b')
예제 #3
0
    def test_rename_symbols(self):
        """Check that rename_symbols rewrites both plain and attribute names.

        The tuple under test holds `a`, `b`, `b.c` (store) and `b.c.d`. `a`
        and `b.c` are renamed; `b` alone must be left untouched, and the
        expression contexts of the renamed nodes must be preserved.
        """
        elements = [
            ast.Name('a', ast.Load()),
            ast.Name('b', ast.Load()),
            ast.Attribute(ast.Name('b', None), 'c', ast.Store()),
            ast.Attribute(
                ast.Attribute(ast.Name('b', None), 'c', ast.Load()),
                'd', None),
        ]
        tree = qual_names.resolve(ast.Tuple(elements, None))
        renames = {
            qual_names.QN('a'): qual_names.QN('renamed_a'),
            qual_names.QN('b.c'): qual_names.QN('renamed_b_c'),
        }
        tree = ast_util.rename_symbols(tree, renames)

        # `a` -> `renamed_a`, context kept.
        self.assertEqual(tree.elts[0].id, 'renamed_a')
        self.assertIsInstance(tree.elts[0].ctx, ast.Load)
        # `b` has no mapping of its own.
        self.assertEqual(tree.elts[1].id, 'b')
        # `b.c` collapses to the single name `renamed_b_c`, context kept.
        self.assertEqual(tree.elts[2].id, 'renamed_b_c')
        self.assertIsInstance(tree.elts[2].ctx, ast.Store)
        # `b.c.d` keeps its `.d` attribute on top of the renamed base.
        self.assertEqual(tree.elts[3].value.id, 'renamed_b_c')
        self.assertIsInstance(tree.elts[3].value.ctx, ast.Load)
예제 #4
0
    def visit_Expr(self, node):
        """Replace a bare call statement with a control-dependency guard.

        Expression statements that are not calls are returned unchanged
        (after visiting their children). For calls, the statement is
        replaced by a `with py2tf_utils.control_dependency_on_returns(...)`
        block which is annotated with `anno.Basic.INDENT_BLOCK_REMAINDER`,
        carrying the block's body and the map from guarded symbols to
        their aliases.

        Args:
          node: The expression-statement node being visited.

        Returns:
          The original node, or the generated guard node that replaces it.
        """
        self.generic_visit(node)
        call_node = node.value
        if not isinstance(call_node, gast.Call):
            return node

        # Patterns of single function calls, like:
        #   opt.minimize(loss)
        # or:
        #   tf.py_func(...)

        # First, attempt to gate future evaluation of args. If that's not
        # possible, gate all remaining statements (and that may fail too, see
        # _visit_and_reindent).
        args_scope = anno.getanno(call_node, NodeAnno.ARGS_SCOPE)
        # NOTE: We can't guard object attributes because they may not be writable.
        guarded_args = tuple(
            sym for sym in args_scope.used if not sym.is_composite())

        # TODO(mdan): Include all arguments which depended on guarded_args too.
        # For example, the following will still cause a race:
        #   tf.assign(a, a + 1)
        #   b = a + 1
        #   tf.assign(a, a + 1)  # Control deps here should include `b`
        #   c = b + 1
        # Or maybe we should just raise an "unsafe assign" error?

        if not guarded_args:
            # Nothing to alias: emit an empty guard whose body is filled in
            # later with the remaining statements.
            alias_map = {}

            template = """
          with py2tf_utils.control_dependency_on_returns(tf, call):
            pass
        """
            control_deps_guard = templates.replace(template,
                                                   call=call_node)[-1]
            control_deps_guard.body = []
        else:
            parent_scope = args_scope.parent
            # The aliases may need new names to avoid incorrectly making them local.
            # TODO(mdan): This is brutal. It will even rename modules - any fix?
            need_alias = [sym for sym in guarded_args
                          if sym not in parent_scope.modified]
            alias_map = {}
            for sym in need_alias:
                fresh = self.context.namer.new_symbol(
                    sym.ssf(), parent_scope.referenced)
                alias_map[sym] = qual_names.QN(fresh)

            # Symbols without an alias are passed through under their own name.
            targets = [alias_map.get(sym, sym) for sym in guarded_args]
            if len(targets) == 1:
                aliased_guarded_args = targets[0]
            else:
                aliased_guarded_args = gast.Tuple(
                    [target.ast() for target in targets], None)

            template = """
          with py2tf_utils.control_dependency_on_returns(tf, call):
            aliased_guarded_args = py2tf_utils.alias_tensors(tf, guarded_args)
        """
            control_deps_guard = templates.replace(
                template,
                call=call_node,
                aliased_guarded_args=aliased_guarded_args,
                guarded_args=guarded_args)[-1]

        anno.setanno(control_deps_guard, anno.Basic.INDENT_BLOCK_REMAINDER,
                     (control_deps_guard.body, alias_map))
        return control_deps_guard
  def test_hashable(self):
    """QN instances work as dict keys: equal names hit the same entry."""
    lookup = {qual_names.QN('a'): 'a', qual_names.QN('b'): 'b'}

    # A freshly built, equal QN must find the stored value.
    for name in ('a', 'b'):
      self.assertEqual(lookup[qual_names.QN(name)], name)
    # A name that was never inserted must not be found.
    self.assertNotIn(qual_names.QN('c'), lookup)