Beispiel #1
0
def reverse_ad(node, wrt, preserve_result):
  """Build the naive primal and adjoint of a function via reverse-mode AD.

  Activity analysis is run over the function first (`cfg.forward` with
  `cfg.Active(wrt)`). A `ReverseAD` transformer then produces the primal and
  adjoint definitions, which are wrapped in a `Module` and passed through
  `annotate.find_stacks` before being returned.

  Args:
    node: A `FunctionDef` AST node.
    wrt: A tuple of argument indices with respect to which we take the
        derivative.
    preserve_result: A boolean indicating whether the generated
        derivative function should also return the original return value.

  Returns:
    mod: A `Module` node containing the naive primal and adjoint of the
        function which can be fed to the `split` and `joint` functions.
    required: A list of tuples of functions and argument indices. These
        functions were called by the function but did not have an adjoint.
    stack: The `stack` attribute of the `ReverseAD` transformer as it stands
        after visiting `node`.

  Raises:
    TypeError: If `node` is not a `gast.FunctionDef`.
  """
  if not isinstance(node, gast.FunctionDef):
    raise TypeError

  # Activity analysis must run before the transformer visits the function.
  cfg.forward(node, cfg.Active(wrt))

  transformer = ReverseAD(wrt, preserve_result)
  primal, adjoint = transformer.visit(node)
  annotated = annotate.find_stacks(gast.Module(body=[primal, adjoint]))
  return annotated, transformer.required, transformer.stack
Beispiel #2
0
    def prepare(self, node):
        """Build the evaluation environment used for constant folding.

        `self.env` is populated with the `builtins` module and with every
        module listed in `MODULES` (except the fake `__dispatch__` entry),
        each stored under its mangled name. The module body, stripped of
        its `import` statements, is then executed inside that environment
        so that user-defined functions are available for evaluation.

        Args:
            node: The `ast.Module` being analysed.
        """
        assert isinstance(node, ast.Module)
        self.env = {
            'builtins': __import__('builtins'),
        }

        for module_name in MODULES:
            # __dispatch__ is the only fake top-level module
            if module_name == '__dispatch__':
                continue
            alias_module_name = mangle(module_name)
            self.env[alias_module_name] = __import__(module_name)

        # we need to parse the whole code to be able to apply user-defined pure
        # function but import are resolved before so we remove them to avoid
        # ImportError (for operator_ for example)
        dummy_module = ast.Module([s for s in node.body
                                   if not isinstance(s, ast.Import)],
                                  [])
        # NOTE(review): this executes the analysed module's own top-level code.
        eval(compile(ast.gast_to_ast(dummy_module),
                     '<constant_folding>', 'exec'),
             self.env)

        super(ConstantFolding, self).prepare(node)
Beispiel #3
0
def to_graph(o, arg_value_hints=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Supported entities are functions and classes. A class is handled by
  converting all of its methods into a new class.

  Args:
    o: A Python function or class.
    arg_value_hints: A dict mapping parameter names to objects that can hint
        at the type of those parameters.

  Returns:
    A function with a signature identical to `o`, but which, when executed,
    creates a TF graph that has the same functionality as the original entity.
  """
  conversion_map = conversion.ConversionMap()
  _, name = conversion.object_to_graph(o, conversion_map, arg_value_hints)

  # Assemble one module: the import preamble first, then every converted
  # dependency recorded in the conversion map.
  module = gast.Module([])
  module.body.extend(
      parser.parse_str(line) for line in config.COMPILED_IMPORT_STATEMENTS)
  module.body.extend(conversion_map.dependency_cache.values())
  compiled_node = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(o):
    entry_globals = six.get_function_globals(o)
    compiled_node.__dict__.update(entry_globals)

  return getattr(compiled_node, name)
Beispiel #4
0
def filter_code_typevars(module, duc, ancestors):
    """Return filtered source holding what is needed to create the annotations.

    Walks the top-level statements of `module` and keeps:

    * `import` statements whose first alias is "transonic" or "numpy",
    * `from ... import` statements from "transonic" or "numpy",
    * plain and augmented assignments.

    For any other `import`, the statements using the imported name
    (transitively, via the def-use chains in `duc`) are marked as suppressed
    and skipped when they are reached later in the walk.

    The kept statements are placed in a fresh module, which is unparsed with
    `extast.unparse`.
    """
    allowed_modules = ("transonic", "numpy")

    filtered = ast.Module()
    filtered.body = retained = []
    filtered.type_ignores = []
    dropped = set()

    def mark_users(definition):
        # Index 1 of the ancestor chain is presumably the module-level
        # statement containing the use -- TODO confirm against `ancestors`.
        for user in definition.users():
            dropped.add(ancestors.parents(user.node)[1])
            mark_users(user)

    for stmt in module.body:
        if stmt in dropped:
            continue

        if isinstance(stmt, ast.Import):
            first_alias = stmt.names[0]
            if first_alias.name in allowed_modules:
                retained.append(stmt)
            else:
                mark_users(duc.chains[first_alias])
        elif isinstance(stmt, ast.ImportFrom):
            if stmt.module in allowed_modules:
                retained.append(stmt)
        elif isinstance(stmt, (ast.Assign, ast.AugAssign)):
            retained.append(stmt)

    return extast.unparse(filtered)
Beispiel #5
0
    def prepare(self, node):
        """Set up `self.env`, the namespace in which folding is evaluated.

        The environment starts with the `builtins` module, then receives
        every importable module named in `MODULES` (skipping the fake
        `__dispatch__` entry) under its mangled name. Finally the module
        body, minus its `import` statements, is executed in that namespace.
        """
        assert isinstance(node, ast.Module)
        self.env = {'builtins': __import__('builtins')}

        for module_name in MODULES:
            # __dispatch__ is the only fake top-level module
            if module_name == '__dispatch__':
                continue
            mangled = mangle(module_name)
            try:
                self.env[mangled] = __import__(module_name)
            except ImportError:
                # Modules that cannot be imported are left out of the env.
                continue

        # The whole code has to be evaluated so that user-defined pure
        # functions can be applied. Imports are already resolved above, so
        # they are dropped here to avoid ImportError (e.g. for operator_).
        non_import_stmts = [stmt for stmt in node.body
                            if not isinstance(stmt, ast.Import)]
        dummy_module = ast.Module(non_import_stmts, [])
        code = compile(ast.gast_to_ast(dummy_module),
                       '<constant_folding>', 'exec')
        eval(code, self.env)

        super(ConstantFolding, self).prepare(node)
Beispiel #6
0
def to_graph(f, arg_value_hints=None):
    """Compile a Python function into equivalent TensorFlow code.

    Args:
      f: A Python function with arbitrary arguments and return values.
      arg_value_hints: A dict mapping parameter names to objects that can
          hint at the type of those parameters.

    Returns:
      A function with a signature identical to `f`, but which, when executed,
      creates a TF graph that has the same functionality as the original
      function.
    """
    conversion_map = conversion.ConversionMap()
    _, name = conversion.object_to_graph(f, conversion_map, arg_value_hints)

    # Import preamble first, then every converted dependency.
    module = gast.Module([])
    module.body.extend(
        parser.parse_str(line) for line in config.COMPILED_IMPORT_STATEMENTS)
    module.body.extend(conversion_map.dependency_cache.values())
    compiled_node = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    entry_globals = six.get_function_globals(f)
    compiled_node.__dict__.update(entry_globals)

    return getattr(compiled_node, name)
Beispiel #7
0
    def visit_Lambda(self, node):
        """Replace a lambda with a reference to an equivalent named callable.

        Three outcomes are possible:

        * `issimpleoperator(node)` recognises the lambda: an attribute
          access on the mangled `operator` module is returned.
        * The lambda matches a pattern already stored in `self.patterns`:
          the name of that earlier forged function is reused.
        * Otherwise a new `FunctionDef` is forged, registered in
          `self.lambda_functions` / `self.global_declarations`, and its
          name is used.

        In the last two cases, identifiers gathered by `ImportedIds` are
        prepended to the lambda's parameters; when there are any, the
        result is a `functools.partial` call binding them.
        """
        op = issimpleoperator(node)
        if op is not None:
            # Make sure the (mangled) operator module is imported once.
            if mangle('operator') not in self.global_declarations:
                import_ = ast.Import(
                    [ast.alias('operator', mangle('operator'))])
                self.imports.append(import_)
                operator_module = MODULES['operator']
                self.global_declarations[mangle('operator')] = operator_module
            return ast.Attribute(
                ast.Name(mangle('operator'), ast.Load(), None, None), op,
                ast.Load())

        # Transform nested nodes first, before the lambda itself is rewritten.
        self.generic_visit(node)
        forged_name = "{0}_lambda{1}".format(self.prefix,
                                             len(self.lambda_functions))

        # Presumably the identifiers the lambda captures from its enclosing
        # scope -- TODO confirm against the ImportedIds analysis.
        ii = self.gather(ImportedIds, node)
        ii.difference_update(self.lambda_functions)  # remove current lambdas

        # Sorted so that bound arguments and extra parameters line up in the
        # same order.
        binded_args = [
            ast.Name(iin, ast.Load(), None, None) for iin in sorted(ii)
        ]
        node.args.args = (
            [ast.Name(iin, ast.Param(), None, None)
             for iin in sorted(ii)] + node.args.args)
        # Reuse a previously forged function when an identical lambda was
        # already seen.
        for patternname, pattern in self.patterns.items():
            if issamelambda(pattern, node):
                proxy_call = ast.Name(patternname, ast.Load(), None, None)
                break
        else:
            # No match: record this lambda as a new pattern (a deep copy,
            # together with its def-use chains) and forge a function for it.
            duc = ExtendedDefUseChains()
            nodepattern = deepcopy(node)
            duc.visit(ast.Module([ast.Expr(nodepattern)], []))
            self.patterns[forged_name] = nodepattern, duc

            forged_fdef = ast.FunctionDef(forged_name, copy(node.args),
                                          [ast.Return(node.body)], [], None,
                                          None)
            metadata.add(forged_fdef, metadata.Local())
            self.lambda_functions.append(forged_fdef)
            self.global_declarations[forged_name] = forged_fdef
            proxy_call = ast.Name(forged_name, ast.Load(), None, None)

        if binded_args:
            # NOTE(review): the operator import above is detected by key,
            # while this one is detected by value -- confirm that the
            # difference is intentional.
            if MODULES['functools'] not in self.global_declarations.values():
                import_ = ast.Import(
                    [ast.alias('functools', mangle('functools'))])
                self.imports.append(import_)
                functools_module = MODULES['functools']
                self.global_declarations[mangle(
                    'functools')] = functools_module

            # functools.partial(<forged function>, <captured identifiers>...)
            return ast.Call(
                ast.Attribute(
                    ast.Name(mangle('functools'), ast.Load(), None, None),
                    "partial", ast.Load()), [proxy_call] + binded_args, [])
        else:
            return proxy_call
Beispiel #8
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Supported entities are functions and classes. A class is handled by
    converting all of its methods into a new class.

    Args:
      e: A Python entity.
      recursive: Whether to recursively convert any functions that the
          decorator function may call.
      verbose: Whether to output the compiled code in the logs.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.
      partial_types: A set of types (e.g. classes) that will not be converted
          entirely. Calls to member functions for these types will be renamed
          independently.

    Returns:
      A function with a signature identical to `e`, but which, when executed,
      creates a TF graph that has the same functionality as the original
      entity.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    # Import preamble statements first, then every converted dependency.
    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        module.body.extend(parser.parse_str(import_line).body)
    module.body.extend(conversion_map.dependency_cache.values())
    compiled_node, compiled_src = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        compiled_globals = compiled_node.__dict__
        for key, val in inspect_utils.getnamespace(e).items():
            # setdefault leaves entities that have been transformed untouched.
            compiled_globals.setdefault(key, val)
    compiled_fn = getattr(compiled_node, name)

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled_fn
Beispiel #9
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Supported entities are functions and classes. A class is handled by
  converting all of its methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorator
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which, when executed,
    creates a TF graph that has the same functionality as the original entity.
  """
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                  arg_types)

  # Converted dependencies are emitted in reverse cache order.
  module = gast.Module([])
  module.body.extend(reversed(program_ctx.dependency_cache.values()))
  compiled_node, compiled_src = compiler.ast_to_object(
      module, source_prefix=program_ctx.required_imports)

  # The compiled code should see everything the entry entity saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  compiled_globals = compiled_node.__dict__
  for key, val in namespace.items():
    # setdefault leaves entities that have been transformed untouched.
    compiled_globals.setdefault(key, val)
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
Beispiel #10
0
def autodiff_tree(func, wrt, motion, mode, preserve_result, check_dims,
                  verbose):
  """Differentiate `func` and every function it requires.

  The entry function is differentiated with the caller's settings. Every
  `(function, wrt)` pair reported as required is then differentiated in turn
  (always with `motion='split'`, `preserve_result=True` and
  `check_dims=False`) until no unprocessed pair remains. The global
  namespaces of all visited functions are merged along the way.

  The `tangent` and `numpy` packages are added to the namespace here, so that
  the gradient templates can assume that they are present.

  Args:
    See `grad`.

  Returns:
    final: A single module which contains the primals and adjoints of all the
        functions in the call tree.
    namespace: A merged dictionary with all the variables in the global
        namespaces of each function. The primals and adjoints need access to
        these in order to execute.
  """
  # Imported here to avoid circular imports
  import tangent

  namespace = {'tangent': tangent, 'numpy': numpy}
  namespace.update(six.get_function_globals(func))

  final = gast.Module(body=[])
  processed = set()

  entry_module, entry_required = autodiff_ast(
      func, wrt, motion, mode, preserve_result, check_dims, verbose)
  final.body.extend(entry_module.body)

  pending = set(entry_required)
  if motion == 'split' and mode == 'reverse':
    # The entry function already has the form used for the dependencies,
    # so it does not need to be generated a second time.
    processed.add((func, wrt))
    pending.difference_update(processed)

  while pending:
    func, wrt = pending.pop()
    namespace.update(six.get_function_globals(func))

    dep_module, dep_required = autodiff_ast(
        func=func,
        wrt=wrt,
        motion='split',
        mode=mode,
        preserve_result=True,
        check_dims=False,
        verbose=verbose)
    final.body.extend(dep_module.body)

    processed.add((func, wrt))
    pending.update(dep_required)
    pending.difference_update(processed)

  return final, namespace
Beispiel #11
0
    def visit_If(self, node):
        """Intercepts if statements.

        Each `if` becomes up to two separate `with` statements, guarded by
        `_tfp_autobatching_context_.if_(cond)` and
        `_tfp_autobatching_context_.else_()` respectively.

        Args:
          node: An `ast.AST` node representing the `if` statement to convert.

        Returns:
          The `with` node for the consequent branch when the `if` has no
          `else` arm; otherwise a two-element list holding the consequent
          `with` node followed by the alternate `with` node.
        """
        with_template = 'with header: body'

        # NOTE: wrapping a list of statements in a Module is a small hack so
        # that the AST-visiting machinery (and prepending) works on it; the
        # machinery would choke on a bare list.
        then_body = self.generic_visit(gast.Module(node.body)).body

        # Header expression that goes into the `with`.
        then_header = templates.replace_as_expression(
            '_tfp_autobatching_context_.if_(cond)',
            cond=self._to_reference(node.test))

        # TODO(axch): Test that this form actually works with multiline bodies.
        then_node = templates.replace(
            with_template, header=then_header, body=then_body)[0]

        if not node.orelse:
            return then_node

        else_body = self.generic_visit(gast.Module(node.orelse)).body
        else_header = templates.replace_as_expression(
            '_tfp_autobatching_context_.else_()')
        else_node = templates.replace(
            with_template, header=else_header, body=else_body)[0]
        return [then_node, else_node]
 def build_example(size):
     """Build a graph bundle from a module of `size` integer constants.

     The module body holds the constants 0..size-1. One edge of type 1 is
     added from each odd-indexed constant to the constant just before it.
     """
     constants = [gast.Constant(value=i, kind=None) for i in range(size)]
     tree = gast.Module(body=constants, type_ignores=[])
     py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)
     edges = [
         (ast_to_node_id[id(tree.body[odd])],
          ast_to_node_id[id(tree.body[odd - 1])],
          1)
         for odd in range(1, size, 2)
     ]
     return graph_bundle.convert_graph_with_edges(
         py_graph, edges, py_ast_graphs.BUILDER)
Beispiel #13
0
 def root_build(body):
     """Wrap statements in a function `random_function(a, b)` inside a module."""
     arguments = _make_arguments(
         python_numbers_control_flow.make_name("a"),
         python_numbers_control_flow.make_name("b"))
     function_def = gast.FunctionDef(
         name="random_function",
         args=arguments,
         body=body,
         decorator_list=[],
         returns=None,
         type_comment=None)
     return gast.Module(body=[function_def], type_ignores=[])
Beispiel #14
0
def to_graph(e,
             recursive=True,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Supported entities are functions and classes. A class is handled by
    converting all of its methods into a new class.

    Args:
      e: A Python entity.
      recursive: Whether to recursively convert any functions that the
          decorator function may call.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.
      partial_types: A set of types (e.g. classes) that will not be converted
          entirely. Calls to member functions for these types will be renamed
          independently.

    Returns:
      A function with a signature identical to `e`, but which, when executed,
      creates a TF graph that has the same functionality as the original
      entity.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types)
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    # Import preamble first, then every converted dependency.
    module = gast.Module([])
    module.body.extend(
        parser.parse_str(line) for line in config.COMPILED_IMPORT_STATEMENTS)
    module.body.extend(conversion_map.dependency_cache.values())
    compiled_node = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        entry_globals = six.get_function_globals(e)
        compiled_node.__dict__.update(entry_globals)

    return getattr(compiled_node, name)
Beispiel #15
0
def joint(node):
  """Merge the bodies of primal and adjoint into a single function.

  The merged body is the primal body without its last statement, followed by
  the whole adjoint body. The resulting function takes its name from the
  primal and its arguments from the adjoint.

  Args:
    node: A module with the primal and adjoint function definitions as returned
        by `reverse_ad`.

  Returns:
    func: A `Module` node with a single function definition containing the
        combined primal and adjoint.
  """
  fixed, _, _ = _fix(node)
  primal = fixed.body[0]
  adjoint = fixed.body[1]

  merged_body = primal.body[:-1] + adjoint.body
  merged_def = gast.FunctionDef(
      name=primal.name, args=adjoint.args, body=merged_body,
      decorator_list=[], returns=None)
  func = gast.Module(body=[merged_def])

  # Clean up
  anno.clearanno(func)
  return func
Beispiel #16
0
    def prepare(self, node, ctx):
        """Set up `self.env`, the namespace in which folding is evaluated.

        The environment starts with `__builtin__`, then receives every module
        named in `MODULES` (skipping the fake `__dispatch__` entry) under its
        mangled name. Functions listed for a module but missing from the real
        module are added as aliases of the attribute whose name is the listed
        name with its underscores stripped. Finally the module body, minus
        its `import` statements, is executed in that namespace.
        """
        assert isinstance(node, ast.Module)
        self.env = {'__builtin__': __import__('__builtin__')}

        # Entries that have no counterpart in Python.
        not_in_python = ("__theitemgetter__", "pythran")

        for module_name in MODULES:
            # __dispatch__ is the only fake top-level module
            if module_name == '__dispatch__':
                continue

            # A trailing underscore marks a module renamed to avoid a clash
            # with a C++ keyword; import it under its real name.
            clashes_with_keyword = (module_name.endswith("_")
                                    and module_name[:-1] in cxx_keywords)
            import_name = module_name[:-1] if clashes_with_keyword else module_name

            mangled = mangle(module_name)
            imported = __import__(import_name)
            self.env[mangled] = imported

            # Same treatment for functions clashing with C++ keywords, e.g.
            # __builtin__.int_ is made to point at __builtin__.int.
            for fun in MODULES[module_name]:
                if fun in not_in_python:
                    continue
                if not hasattr(imported, fun):
                    setattr(imported, fun, getattr(imported, fun.strip("_")))

        # The whole code has to be evaluated so that user-defined pure
        # functions can be applied. Imports are already resolved above, so
        # they are dropped here to avoid ImportError (e.g. for operator_).
        non_import_stmts = [stmt for stmt in node.body
                            if not isinstance(stmt, ast.Import)]
        dummy_module = ast.Module(non_import_stmts)
        code = compile(ast.gast_to_ast(dummy_module), '<constant_folding>',
                       'exec')
        eval(code, self.env)

        super(ConstantFolding, self).prepare(node, ctx)
Beispiel #17
0
def forward_ad(node, wrt, preserve_result=False, check_dims=True):
    """Perform forward-mode AD on an AST.

    A control-flow graph is built for the function and activity analysis is
    run over it with every positional argument index marked as active. The
    function is then transformed by `ForwardAD`, passed through
    `annotate.find_stacks`, wrapped in a `Module`, and stripped of
    annotations.

    Args:
      node: A `FunctionDef` AST node.
      wrt: A tuple of argument indices with respect to which we take the
          derivative.
      preserve_result: A boolean indicating whether the original
          non-differentiated function value should be returned.
      check_dims: A boolean indicating whether the provided derivatives should
          have the same shape as their corresponding arguments.

    Returns:
      mod: A `Module` node containing the forward-mode function.
      required: A list of tuples of functions and argument indices. These
          functions were called by the function but did not have a derivative
          available.

    Raises:
      TypeError: If `node` is not a `gast.FunctionDef`.
    """
    if not isinstance(node, gast.FunctionDef):
        raise TypeError

    # Activity analysis over the function's control-flow graph.
    graph = cfg.CFG.build_cfg(node)
    all_arg_indices = range(len(node.args.args))
    cfg.Active(all_arg_indices).visit(graph.entry)

    # Build the forward-mode function, then annotate its stacks.
    transformer = ForwardAD(wrt, preserve_result, check_dims)
    derivative = annotate.find_stacks(transformer.visit(node))

    # Wrap in a module and clean up the naive forward-mode code.
    module = gast.Module([derivative])
    anno.clearanno(module)

    return module, transformer.required
 def test_usub(self):
     """Canonicalizing the source "-3" yields a single `Num(-3)` expression."""
     parsed = gast.ast_to_gast(ast.parse("-3"))
     expected = gast.Module(body=[gast.Expr(value=gast.Num(n=-3))])
     canonical = self.canonicalizer.visit(parsed)
     assert compare_ast(canonical, expected)
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Every method defined directly on `c` is converted with
    `function_to_graph`. The result is a module holding an import line for
    each whitelisted base class, followed by the definition of the converted
    class; references to `c` and its bases are renamed to their new names.

    Args:
      c: The class to convert.
      program_ctx: The conversion context. It is used here to create a namer
          and to record that namer's name map.

    Returns:
      A tuple `(node, class_name, class_namespace)`: the generated `Module`
      node, the name given to the converted class, and the merged namespaces
      returned by `function_to_graph` for the converted methods.

    Raises:
      ValueError: If `c` has no function or method members.
    """
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m
                                                                              )
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        class_namespace.update(namespace)
        converted_members[m] = node
    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    bases = []
    for base in c.__bases__:
        # Identity check: `isinstance(object, base)` would also be true for
        # `type`, wrongly turning a metaclass base into `object`.
        if base is object:
            bases.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        bases.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class.
    output_nodes.append(
        gast.ClassDef(class_name,
                      bases=bases,
                      keywords=[],
                      body=list(converted_members.values()),
                      decorator_list=[]))
    node = gast.Module(output_nodes)

    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    node = qual_names.resolve(node)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    node = ast_util.rename_symbols(node, renames)

    return node, class_name, class_namespace
Beispiel #20
0
 def visit_Module(self, node):
     """Return a new `gast.Module` built from the visited body of `node`."""
     visited_body = self._visit(node.body)
     # The second positional argument is the (empty) type_ignores list.
     return gast.Module(visited_body, [])
 def visit_Module(self, node):
     """Add itertools import for imap, izip or ifilter iterator.

     The children of `node` are visited first; the returned module is a new
     one whose body is `import itertools` followed by the body of `node`.
     """
     self.generic_visit(node)
     itertools_alias = ast.alias(name='itertools', asname=None)
     itertools_import = ast.Import(names=[itertools_alias])
     new_body = [itertools_import] + node.body
     return ast.Module(body=new_body)
 def test_usub(self):
     """Canonicalizing the source "-3" yields a single `Constant(-3)` expression."""
     parsed = gast.ast_to_gast(ast.parse("-3"))
     minus_three = gast.Constant(value=-3, kind=None)
     expected = gast.Module(
         body=[gast.Expr(value=minus_three)],
         type_ignores=[])
     canonical = self.canonicalizer.visit(parsed)
     assert compare_ast(canonical, expected)