def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Lambdas are delegated to `_parse_lambda`, which is called without
  `future_features`; for every other entity the dedented source is prefixed
  with one `from __future__ import ...` line per requested feature before it
  is parsed.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_function'). See
      https://docs.python.org/2/reference/simple_stmts.html#future

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have
    added).

  Raises:
    ValueError: if the source code of `entity` cannot be retrieved.
  """
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    original_source = inspect_utils.getimmediatesource(entity)
  except (IOError, OSError) as e:
    raise ValueError(
        'Unable to locate the source code of {}. Note that functions defined'
        ' in certain environments, like the interactive Python shell, do not'
        ' expose their source code. If that is the case, you should define'
        ' them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        ' @tf.autograph.experimental.do_not_convert. Original error: {}'.format(
            entity, e))

  source = dedent_block(original_source)

  future_statements = tuple(
      'from __future__ import {}'.format(name) for name in future_features)
  source = '\n'.join(future_statements + (source,))

  # Count the materialized statements instead of calling len() on the
  # argument: future_features is only required to be iterable, so it may be a
  # generator, which has no len() and has already been consumed above.
  return parse(source, preamble_len=len(future_statements)), source
def compiled_function_name(self,
                           original_fqn,
                           live_entity=None,
                           owner_type=None):
  """See call_trees.FunctionNamer.compiled_function_name.

  Returns:
    A two-element tuple: `(None, False)` on every early-exit path, otherwise
    `(name, True)` where `name` is the `tf__`-prefixed name chosen for the
    function.
  """
  skip = (None, False)
  entity_known = live_entity is not None

  # Nothing is renamed unless the namer is in recursive mode.
  if not self.recursive:
    return skip

  # Lambdas are never given a new name.
  if entity_known and inspect_utils.islambda(live_entity):
    return skip

  # Members are not renamed when part of an entire converted class.
  if not (owner_type is None or owner_type in self.partial_types):
    return skip

  # Hand back the name that was already assigned to this entity, if any.
  if entity_known and live_entity in self.renamed_calls:
    return self.renamed_calls[live_entity], True

  symbol = self._as_symbol_name(original_fqn, style=_NamingStyle.SNAKE)
  base_name = 'tf__%s' % symbol

  # Append an increasing numeric suffix until the candidate is not present in
  # the global namespace.
  candidate = base_name
  suffix = 0
  while candidate in self.global_namespace:
    suffix += 1
    candidate = '%s_%d' % (base_name, suffix)

  if entity_known:
    self.renamed_calls[live_entity] = candidate
  self.generated_names.add(candidate)

  return candidate, True
def compiled_function_name(self,
                           original_fqn,
                           live_entity=None,
                           owner_type=None):
  """See call_trees.FunctionNamer.compiled_function_name.

  Yields `(None, False)` when no new name is produced and `(name, True)` when
  a `tf__`-prefixed name is returned for the function.
  """
  if not self.recursive:
    return None, False

  if live_entity is not None:
    if inspect_utils.islambda(live_entity):
      return None, False

  if owner_type is not None:
    if owner_type not in self.partial_types:
      # Members are not renamed when part of an entire converted class.
      return None, False

  if live_entity is not None:
    if live_entity in self.renamed_calls:
      # This entity was named on an earlier call; stay consistent with it.
      return self.renamed_calls[live_entity], True

  root = 'tf__%s' % self._as_symbol_name(
      original_fqn, style=_NamingStyle.SNAKE)

  # Probe root, root_1, root_2, ... and keep the first one that is absent
  # from the global namespace.
  unique_name = root
  attempt = 0
  while True:
    if unique_name not in self.global_namespace:
      break
    attempt += 1
    unique_name = '%s_%d' % (root, attempt)

  if live_entity is not None:
    self.renamed_calls[live_entity] = unique_name
  self.generated_names.add(unique_name)

  return unique_name, True
def compiled_function_name(self,
                           original_fqn,
                           live_entity=None,
                           owner_type=None):
  """Builds a replacement name for a function from its qualified name.

  Args:
    original_fqn: iterable of name components; they are joined with '_'.
    live_entity: object tested with `inspect_utils.islambda`.
    owner_type: when not None, no name is produced.

  Returns:
    `(None, False)` if `live_entity` is a lambda or `owner_type` is given;
    otherwise `('renamed_' + joined components, True)`.
  """
  if inspect_utils.islambda(live_entity) or owner_type is not None:
    return None, False

  renamed = 'renamed_' + '_'.join(original_fqn)
  return renamed, True
def _should_compile(self, node, fqn):
  """Determines whether an entity should be compiled in the context.

  Args:
    node: call node; its `func` attribute is resolved to the live target.
    fqn: sequence of name components of the called entity.

  Returns:
    False when the call falls under an uncompiled module, when the target
    carries the `__ag_compiled` attribute, or when one of its decorators is
    listed in `strip_decorators`; True otherwise.
  """
  # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
  root_module = fqn[0]
  if any(
      root_module.startswith(entry[0] + '.')
      for entry in self.ctx.program.uncompiled_modules):
    return False

  # Any proper prefix of the qualified name may itself be marked uncompiled.
  if any(
      fqn[:end] in self.ctx.program.uncompiled_modules
      for end in range(1, len(fqn))):
    return False

  target_entity = self._try_resolve_target(node.func)
  if target_entity is None:
    return True

  # Currently, lambdas are always converted.
  # TODO(mdan): Allow markers of the kind f = ag.do_not_convert(lambda: ...)
  if inspect_utils.islambda(target_entity):
    return True

  # This may be reached when "calling" a callable attribute of an object.
  # For example:
  #
  #   self.fc = tf.keras.layers.Dense()
  #   self.fc()
  #
  if any(
      target_entity.__module__.startswith(entry[0] + '.')
      for entry in self.ctx.program.uncompiled_modules):
    return False

  # Inspect the target function decorators. If any include a @convert
  # or @do_not_convert annotation, then they must be called as they are.
  # TODO(mdan): This may be quite heavy. Perhaps always dynamically convert?
  # To parse and re-analyze each function for every call site could be quite
  # wasteful. Maybe we could cache the parsed AST?
  try:
    parsed, _ = parser.parse_entity(target_entity)
    target_node = parsed.body[0]
  except TypeError:
    # Functions whose source we cannot access are compilable (e.g. wrapped
    # to py_func).
    return True

  # This attribute is set when the decorator was applied before the
  # function was parsed. See api.py.
  if hasattr(target_entity, '__ag_compiled'):
    return False

  for decorator in target_node.decorator_list:
    resolved = self._resolve_decorator_name(decorator)
    if resolved is None:
      continue
    if resolved in self.ctx.program.options.strip_decorators:
      return False

  return True
def test_islambda(self):
  """islambda is True for a lambda expression and False for a def function."""

  def regular_fn():
    pass

  anonymous_fn = lambda x: x

  self.assertTrue(inspect_utils.islambda(anonymous_fn))
  self.assertFalse(inspect_utils.islambda(regular_fn))
def _errors_are_normally_possible(entity, error):
  """Returns True only when `entity` is a lambda and `error` is a ValueError."""
  return bool(
      inspect_utils.islambda(entity) and isinstance(error, ValueError))
def test_islambda_renamed_lambda(self):
  """A lambda is still reported as a lambda after its __name__ is replaced."""
  renamed_lambda = lambda x: 1
  renamed_lambda.__name__ = 'f'

  self.assertTrue(inspect_utils.islambda(renamed_lambda))