def _create_cond_expr(self, results, test, body_name, orelse_name,
                      state_getter_name, state_setter_name):
    """Builds the `ag__.if_stmt(...)` call that stands in for an `if`.

    Args:
      results: assignment target(s) for the call's return value, or None to
        emit the call as a bare expression statement.
      test: the condition expression.
      body_name: name of the function holding the `if` branch.
      orelse_name: name of the function holding the `else` branch.
      state_getter_name: name of the state getter passed to `ag__.if_stmt`.
      state_setter_name: name of the state setter passed to `ag__.if_stmt`.

    Returns:
      The statements produced by `templates.replace`.
    """
    if results is None:
        # Result is unused: the call is emitted on its own.
        discard_template = """
          ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name)
        """
        return templates.replace(
            discard_template,
            test=test,
            body_name=body_name,
            orelse_name=orelse_name,
            getter_name=state_getter_name,
            setter_name=state_setter_name)

    assign_template = """
      results = ag__.if_stmt(test, body_name, orelse_name, state_getter_name, state_setter_name)
    """
    return templates.replace(
        assign_template,
        results=results,
        test=test,
        body_name=body_name,
        orelse_name=orelse_name,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name)
def _create_cond_branch(self, body_name, aliased_orig_names,
                        aliased_new_names, body, returns):
    """Builds the function definition for one branch of a conditional.

    The generated function is named `body_name`. It runs `body` and then
    returns the value(s) in `returns`. When `aliased_orig_names` is non-empty,
    the function first binds `aliased_new_names` to `aliased_orig_names`.

    Args:
      body_name: name of the generated function.
      aliased_orig_names: original names to alias; may be empty.
      aliased_new_names: names bound to `aliased_orig_names`.
      body: the branch statements.
      returns: the value(s) the branch returns.

    Returns:
      The statements produced by `templates.replace`.
    """
    # Exactly one output is returned directly; otherwise the outputs are
    # emitted as a tuple expression.
    if len(returns) == 1:
        return_stmt = templates.replace(
            """
              return retval
            """,
            retval=returns[0])
    else:
        return_stmt = templates.replace(
            """
              return (retvals,)
            """,
            retvals=returns)

    if not aliased_orig_names:
        plain_template = """
          def body_name():
            body
            return_stmt
        """
        return templates.replace(
            plain_template,
            body_name=body_name,
            body=body,
            return_stmt=return_stmt)

    aliasing_template = """
      def body_name():
        aliased_new_names, = aliased_orig_names,
        body
        return_stmt
    """
    return templates.replace(
        aliasing_template,
        body_name=body_name,
        body=body,
        aliased_orig_names=aliased_orig_names,
        aliased_new_names=aliased_new_names,
        return_stmt=return_stmt)
def _create_state_functions(self, composites, state_getter_name,
                            state_setter_name):
    """Builds a getter/setter function pair for the given composite symbols.

    The getter returns the composites as a tuple and the setter unpacks its
    `vals` argument back into them. With no composites, the getter returns
    `()` and the setter ignores its argument.

    Args:
      composites: iterable of symbols to expose; may be empty.
      state_getter_name: name of the generated getter.
      state_setter_name: name of the generated setter.

    Returns:
      The statements produced by `templates.replace`.
    """
    if not composites:
        empty_template = """
          def state_getter_name():
            return ()
          def state_setter_name(_):
            pass
        """
        return templates.replace(
            empty_template,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name)

    tuple_template = """
      def state_getter_name():
        return composite_tuple,
      def state_setter_name(vals):
        composite_tuple, = vals
    """
    return templates.replace(
        tuple_template,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        composite_tuple=tuple(composites))
def visit_For(self, node):
    """Converts a `for` loop into a call to `ag__.for_stmt`.

    The loop body and the extra test become local functions. The extra test
    is the node's `extra_test` annotation, or `True` when the annotation is
    absent. Loop-state symbols are renamed through the SSF map returned by
    `self._state_constructs`. Without loop state, the generated functions take
    no state argument and the call's result is discarded.

    Args:
      node: the `for` statement node.

    Returns:
      The replacement statements produced by `templates.replace`.
    """
    self.generic_visit(node)
    loop_state, reserved_symbols = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    renamed_body = ast_util.rename_symbols(node.body, ssf_map)

    if anno.hasanno(node, 'extra_test'):
        extra_test = ast_util.rename_symbols(
            anno.getanno(node, 'extra_test'), ssf_map)
    else:
        extra_test = parser.parse_expression('True')

    # Helper names are generated in a fixed order: extra test, then body.
    namer = self.ctx.namer
    extra_test_name = namer.new_symbol('extra_test', reserved_symbols)
    body_name = namer.new_symbol('loop_body', reserved_symbols)

    if not loop_state:
        stateless_template = """
          def extra_test_name():
            return extra_test_expr
          def body_name(loop_vars):
            # Workaround for PEP-3113
            iterate = loop_vars
            body
            return ()
          ag__.for_stmt(iter_, extra_test_name, body_name, ())
        """
        return templates.replace(
            stateless_template,
            iter_=node.iter,
            iterate=node.target,
            extra_test_name=extra_test_name,
            extra_test_expr=extra_test,
            body_name=body_name,
            body=renamed_body)

    stateful_template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    return templates.replace(
        stateful_template,
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=extra_test_name,
        extra_test_expr=extra_test,
        body_name=body_name,
        body=renamed_body)
def test_replace_name_mixed_attr_subscript(self, expression_source):
    """A replacement gets Store ctx as an assign target and Load ctx as a value."""
    template = 'foo = bar'
    replacement = _parse_with_unset_ctx(expression_source)

    as_target = templates.replace(template, foo=replacement)[0]
    self.assertExpectedCtxSet(as_target.targets[0], gast.Store)

    as_value = templates.replace(template, bar=replacement)[0]
    self.assertExpectedCtxSet(as_value.value, gast.Load)
def visit_With(self, node):
    """Lifts a trailing `return` out of a `with` block.

    If the `with` body ends in `return <value>`, that statement becomes
    `<self.common_return_name> = <value>`. The method then returns the `with`
    node followed by `return <self.common_return_name>` and sets
    `self.changes_made`. Otherwise it returns the visited node unchanged.

    Args:
      node: the `with` statement node.

    Returns:
      The visited node, or `[with_node, return_node]` when a return was lifted.
    """
    # Depth-first traversal of syntax
    node = self.generic_visit(node)
    trailing = node.body[-1]
    if not isinstance(trailing, gast.Return):
        return node

    node.body[-1] = templates.replace(
        'a = b', a=self.common_return_name, b=trailing.value)[0]
    return_node = templates.replace('return a', a=self.common_return_name)[0]
    # The rewritten node is visited once more before being emitted.
    node = self.generic_visit(node)
    self.changes_made = True
    return [node, return_node]
def test_replace_attribute(self):
    """An attribute name in a template can be replaced by a string name."""
    # Local import: `types.ModuleType` replaces the deprecated
    # `imp.new_module` (the `imp` module was removed in Python 3.12).
    import types

    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _ = compiler.ast_to_object(node)
    mod = types.ModuleType('test')
    mod.b = 3
    # `assertEquals` is a deprecated alias (removed in Python 3.12).
    self.assertEqual(3, result.test_fn(mod))

    # A non-name replacement for an attribute is rejected.
    with self.assertRaises(ValueError):
        templates.replace(template, foo=1)
def _for_loop_with_extra_test(self, loop_state, state_ssf, state_ast_tuple,
                              original_node, extra_test_name, extra_test,
                              body_name, loop_body, ssf_map):
    """Emits an `ag__.for_stmt` call whose loop carries state and an extra test.

    The loop target of `original_node` is renamed through `ssf_map` before it
    is bound inside the generated body function.

    Args:
      loop_state: the loop-carried state symbol(s).
      state_ssf: state parameter name(s) used by the generated functions.
      state_ast_tuple: assignment target receiving the final loop state.
      original_node: the original `for` node; supplies `iter` and `target`.
      extra_test_name: name of the generated extra-test function.
      extra_test: the extra-test expression.
      body_name: name of the generated body function.
      loop_body: the (renamed) loop body statements.
      ssf_map: mapping used to rename symbols in the loop target.

    Returns:
      The statements produced by `templates.replace`.
    """
    loop_template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        target = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    return templates.replace(
        loop_template,
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=original_node.iter,
        target=ast_util.rename_symbols(original_node.target, ssf_map),
        extra_test_name=extra_test_name,
        extra_test_expr=extra_test,
        body_name=body_name,
        body=loop_body)
def visit_Continue(self, node):
    """Replaces `continue` with setting the control variable to `tf.constant(True)`.

    Also records, through `CONTINUE_USED`, that a continue was seen.
    """
    self.set_local(CONTINUE_USED, True)
    control_var = self.get_local(CONTROL_VAR_NAME)
    flag_template = """
      var_name = tf.constant(True)
    """
    return templates.replace(flag_template, var_name=control_var)
def visit_Assert(self, node):
    """Converts `assert test, msg` into `tf.Assert(test, (msg,))`.

    A missing message defaults to 'Assertion error'.

    Raises:
      NotImplementedError: if the message is not a string literal.
    """
    self.generic_visit(node)
    message = node.msg
    if message is None:
        message = gast.Str('Assertion error')
    elif not isinstance(message, gast.Str):
        raise NotImplementedError('can only convert string messages for now.')

    # Note: The lone tf.Assert call will be wrapped with control_dependencies
    # by side_effect_guards.
    template = """
      tf.Assert(test, (msg,))
    """
    return templates.replace(template, test=node.test, msg=message)
def visit_While(self, node):
    """Rewrites `break` inside a `while` loop into a boolean control flag.

    When the body uses `break`, the loop is preceded by `<flag> = False`. The
    loop condition becomes `ag__.and_(lambda: test, lambda: ag__.not_(flag))`,
    and the `else` clause is guarded by the flag. The loop's DIRECTIVES
    annotation is copied onto the rewritten `while` node.

    Args:
      node: the `while` statement node (modified in place).

    Returns:
      The node itself when no break is used; otherwise the replacement
      statements.
    """
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

    node.test = self.visit(node.test)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
        return node

    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)
    loop_template = """
      var_name = False
      while ag__.and_(lambda: test, lambda: ag__.not_(var_name)):
        body
      else:
        orelse
    """
    new_nodes = templates.replace(
        loop_template,
        var_name=break_var,
        test=node.test,
        body=node.body,
        orelse=guarded_orelse)
    # new_nodes[0] is the flag initialization; new_nodes[1] is the new loop.
    anno.copyanno(node, new_nodes[1], anno.Basic.DIRECTIVES)
    return new_nodes
def create_assignment(self, target, expression):
    """Returns the statement list for `target = expression`."""
    assignment_template = """
      target = expression
    """
    return templates.replace(
        assignment_template, expression=expression, target=target)
def test_replace_code_block(self):
    """A block placeholder can be replaced by a list of statements.

    The inserted nodes use a placeholder `ctx` class (`ShouldBeReplaced`).
    The generated function must still load, and must add 1 twice.
    """
    template = """
      def test_fn(a):
        block
        return a
    """

    class ShouldBeReplaced(object):
        pass

    increment_a = gast.Assign(
        [gast.Name('a', ctx=ShouldBeReplaced, annotation=None,
                   type_comment=None)],
        gast.BinOp(
            gast.Name('a', ctx=ShouldBeReplaced, annotation=None,
                      type_comment=None),
            gast.Add(),
            gast.Constant(1, kind=None)),
    )
    # The same statement object appears twice, exactly like `[stmt] * 2`.
    fn_node = templates.replace(template, block=[increment_a, increment_a])[0]
    module, _, _ = loader.load_ast(fn_node)
    self.assertEqual(3, module.test_fn(1))
def test_replace_call_keyword(self):
    """A keyword placeholder can be replaced by a list of keywords."""
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _, _ = loader.load_ast(node)
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement needs its own context manager. Inside a single
    # `assertRaises`, the first raising call would skip the second entirely.
    with self.assertRaises(ValueError):
        templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
        templates.replace(template, kws=1)
def _insert_dynamic_conversion(self, node):
    """Inlines a dynamic conversion for a dynamic function.

    Builds `ag__.converted_call(func,
    ag__.ConversionOptions.new(recursive=<program.recursive>), args)` from
    the call `node`, then carries over the original call's keywords.

    Args:
      node: the call node to convert.

    Returns:
      The new call expression node.
    """
    # TODO(mdan): Pass information on the statically compiled functions.
    # Having access to the statically compiled functions can help avoid
    # unnecessary compilation.
    # For example, this would lead to function `a` being compiled twice:
    #
    #   def a():
    #     v = b
    #     b()
    #   def b():
    #     a()
    #
    # This is really a problem with recursive calls, which currently can
    # only be gated by a static condition, and should be rare.
    # TODO(mdan): It probably makes sense to use dynamic conversion every time.
    # Before we could convert all the time though, we'd need a reasonable
    # caching mechanism.
    recursive_flag = parser.parse_expression(str(self.ctx.program.recursive))
    call_template = """
      ag__.converted_call(
          func, ag__.ConversionOptions.new(recursive=recursive_val), args)
    """
    statements = templates.replace(
        call_template,
        func=node.func,
        recursive_val=recursive_flag,
        args=node.args)
    converted = statements[0].value
    # TODO(mdan): Improve the template mechanism to better support this.
    converted.keywords = node.keywords
    return converted
def _insert_dynamic_conversion(self, node):
    """Inlines a dynamic conversion for a dynamic function.

    Builds `ag__.converted_call(func, options, args)` from the call `node`.
    `options` is the AST produced by the program options' `to_ast` for the
    current namespace. The original call's keywords are carried over.

    Args:
      node: the call node to convert.

    Returns:
      The new call expression node.
    """
    # TODO(mdan): Pass information on the statically compiled functions.
    # Having access to the statically compiled functions can help avoid
    # unnecessary compilation.
    # For example, this would lead to function `a` being compiled twice:
    #
    #   def a():
    #     v = b
    #     b()
    #   def b():
    #     a()
    #
    # This is really a problem with recursive calls, which currently can
    # only be gated by a static condition, and should be rare.
    # TODO(mdan): It probably makes sense to use dynamic conversion every time.
    # Before we could convert all the time though, we'd need a reasonable
    # caching mechanism.
    options_ast = self.ctx.program.options.to_ast(self.ctx.info.namespace)
    call_template = """
      ag__.converted_call(func, options, args)
    """
    converted = templates.replace(
        call_template,
        func=node.func,
        options=options_ast,
        args=node.args)[0].value
    # TODO(mdan): Improve the template mechanism to better support this.
    converted.keywords = node.keywords
    return converted
def to_ast(self):
    """Returns an AST expression that reconstructs these options.

    The standard options are emitted as the shorthand `ag__.STD`. Any other
    configuration becomes an explicit `ag__.ConversionOptions(...)` call.

    Returns:
      ast.Node
    """
    if self == STANDARD_OPTIONS:
        return parser.parse_expression('ag__.STD')

    def literal(value):
        # Values are emitted through their `str()` form.
        return parser.parse_expression(str(value))

    def list_of_features(values):
        qualified = ', '.join('ag__.' + str(v) for v in values)
        return parser.parse_expression('(' + qualified + ')')

    constructor_template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          user_requested=user_requested_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """
    statements = templates.replace(
        constructor_template,
        recursive_val=literal(self.recursive),
        user_requested_val=literal(self.user_requested),
        internal_convert_user_code_val=literal(
            self.internal_convert_user_code),
        optional_features_val=list_of_features(self.optional_features))
    return statements[0].value
def test_replace_call_keyword(self):
    """A keyword placeholder can be replaced by a list of keywords."""
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    # `assertEquals` is a deprecated alias (removed in Python 3.12).
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement needs its own context manager. Inside a single
    # `assertRaises`, the first raising call would skip the second entirely.
    with self.assertRaises(ValueError):
        templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
        templates.replace(template, kws=1)
def visit_Delete(self, node):
    """Rewrites `del name` into `name = ag__.Undefined('name')`.

    Only plain-name targets are rewritten. Other targets (e.g. `del a[0]`)
    stay in a residual `del` statement emitted after the assignments.

    Args:
      node: the `del` statement node.

    Returns:
      The visited node when no plain names are deleted; otherwise a list of
      replacement statements.
    """
    node = self.generic_visit(node)
    # Don't rewrite composites like `del a[0]`.
    name_targets = [t for t in node.targets if isinstance(t, gast.Name)]
    if not name_targets:
        return node

    undefine_template = """
      var_ = ag__.Undefined(var_name)
    """
    replacement = []
    for name_node in name_targets:
        replacement.extend(
            templates.replace(
                undefine_template,
                var_=name_node,
                var_name=gast.Constant(name_node.id, kind=None)))

    leftover = [t for t in node.targets if t not in name_targets]
    if leftover:
        replacement.append(gast.Delete(targets=leftover))
    return replacement
def _rename_compilable_function(self, node):
    """Points a call at the compiled version of its target, when applicable.

    The target's live value and fully qualified name come from the
    `live_val` / `fqn` annotations on `node.func`. Calls annotated
    `is_constructor` always use the compiled class name. Other calls ask the
    namer, which may decline the rename. When the target is a method, its
    receiver is prepended to the positional arguments.

    Args:
      node: a call node whose `func` carries `live_val` and `fqn` annotations.

    Returns:
      The (possibly modified) call node.
    """
    assert anno.hasanno(node.func, 'live_val')
    assert anno.hasanno(node.func, 'fqn')
    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = anno.getanno(node.func, 'fqn')

    if not self._should_compile(node, target_fqn):
        return node

    namer = self.ctx.namer
    if anno.hasanno(node, 'is_constructor'):
        new_name = namer.compiled_class_name(
            target_fqn, live_entity=target_entity)
        do_rename = True
    else:
        if anno.hasanno(node.func, 'parent_type'):
            owner_type = anno.getanno(node.func, 'parent_type')
        else:
            # Fallback - not reliable.
            owner_type = inspect_utils.getmethodclass(target_entity)
        new_name, do_rename = namer.compiled_function_name(
            target_fqn, live_entity=target_entity, owner_type=owner_type)

    if not do_rename:
        return node

    if target_entity is not None and tf_inspect.ismethod(target_entity):
        # The renaming process will transform it into a regular function.
        # TODO(mdan): Is this complete? How does it work with nested members?
        node.args = [node.func.value] + node.args
    node.func = templates.replace('func_name', func_name=new_name)[0]
    return node
def visit_Continue(self, node):
    """Replaces `continue` with an assignment of `True` to the control variable.

    Also marks the enclosing `_Continue` state as used.
    """
    self.state[_Continue].used = True
    control_var = self.state[_Continue].control_var_name
    flag_template = """
      var_name = True
    """
    return templates.replace(flag_template, var_name=control_var)
def visit_While(self, node):
    """Rewrites `break` inside a `while` loop into a boolean control flag.

    When the body uses `break`, the loop is preceded by
    `<flag> = tf.constant(False)` and its condition becomes
    `test and not <flag>`. The `else` clause is guarded by the flag.

    Args:
      node: the `while` statement node (modified in place).

    Returns:
      The node itself when no break is used; otherwise the replacement
      statements.
    """
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

    node.test = self.visit(node.test)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
        return node

    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)
    loop_template = """
      var_name = tf.constant(False)
      while test and not var_name:
        body
      else:
        orelse
    """
    return templates.replace(
        loop_template,
        var_name=break_var,
        test=node.test,
        body=node.body,
        orelse=guarded_orelse)
def to_ast(self, ctx, internal_convert_user_code=None):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents.

    Args:
      ctx: EntityContext, the entity with which this AST needs to be
        consistent.
      internal_convert_user_code: Optional[bool], allows overriding the
        corresponding value. When None, this object's own
        `internal_convert_user_code` is used.

    Returns:
      ast.Node
    """
    template = """
      constructor_name(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorators_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def as_qualified_name(o):
        # Resolve `o` through the entity's namespace. If it is not reachable
        # there, register it on the program under a fresh symbol name.
        name = inspect_utils.getqualifiedname(ctx.info.namespace, o, max_depth=1)
        if not name:
            # TODO(mdan): This needs to account for the symbols defined locally.
            name = ctx.namer.new_symbol(o.__name__, ())
            ctx.program.add_symbol(name, weakref.ref(o))
        return name

    def list_of_names(values):
        return parser.parse_expression('({})'.format(', '.join(
            tuple(as_qualified_name(v) for v in values))))

    def list_of_features(values):
        # Iterates `Feature.__members__` names, keeping those found in `values`.
        return parser.parse_expression('({})'.format(', '.join(
            'ag__.Feature.{}'.format(v)
            for v in Feature.__members__
            if v in values)))

    # Fall back to the stored setting only when no override was provided.
    # The inverted check (`is not None`) discarded explicit overrides and
    # emitted a literal `None` when the argument was omitted.
    if internal_convert_user_code is None:
        internal_convert_user_code = self.internal_convert_user_code

    expr_ast = templates.replace(
        template,
        constructor_name=parser.parse_expression(
            as_qualified_name(ConversionOptions)),
        recursive_val=parser.parse_expression(str(self.recursive)),
        verbose_val=parser.parse_expression(str(int(self.verbose))),
        strip_decorators_val=list_of_names(self._strip_decorators),
        force_conversion_val=parser.parse_expression(
            str(self.force_conversion)),
        internal_convert_user_code_val=parser.parse_expression(
            str(internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))

    return expr_ast[0].value
def visit_If(self, node):
    """Converts an `if` statement into branch functions plus `ag__.if_stmt`.

    The output contains, in order:
      * the state getter/setter functions for the conditional's variables;
      * one function per branch, each starting with the nonlocal declarations
        (an empty `else` becomes `pass`);
      * assignments for symbols that may be undefined;
      * the `ag__.if_stmt(...)` call, which receives the origin info of `node`.

    Args:
      node: the `if` statement node.

    Returns:
      The list of replacement statements.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)

    cond_vars, undefined, nouts = self._get_block_vars(
        node, body_scope.bound | orelse_scope.bound)

    undefined_assigns = self._create_undefined_assigns(undefined)
    nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)

    reserved = body_scope.referenced | orelse_scope.referenced
    namer = self.ctx.namer
    state_getter_name = namer.new_symbol('get_state', reserved)
    state_setter_name = namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        cond_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    # An `if` without `else` still needs a (trivial) else function.
    orelse_body = node.orelse or [gast.Pass()]

    body_name = namer.new_symbol('if_body', reserved)
    orelse_name = namer.new_symbol('else_body', reserved)

    if_template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def orelse_name():
        nonlocal_declarations
        orelse
      undefined_assigns
      ag__.if_stmt(
          test,
          body_name,
          orelse_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          nouts)
    """
    new_nodes = templates.replace(
        if_template,
        body=node.body,
        body_name=body_name,
        orelse=orelse_body,
        orelse_name=orelse_name,
        nonlocal_declarations=nonlocal_declarations,
        nouts=gast.Constant(nouts, kind=None),
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        symbol_names=tuple(
            gast.Constant(str(s), kind=None) for s in cond_vars),
        test=node.test,
        undefined_assigns=undefined_assigns)
    origin_info.copy_origin(node, new_nodes[-1])
    return new_nodes
def to_ast(self, ctx, internal_convert_user_code=None):
    """Returns an AST node for an `ag__.ConversionOptions(...)` call recreating this object.

    Args:
      ctx: EntityContext, the entity with which this AST needs to be
        consistent.
      internal_convert_user_code: Optional[bool], overrides the corresponding
        value stored on this object when not None.

    Returns:
      ast.Node
    """
    if internal_convert_user_code is None:
        internal_convert_user_code = self.internal_convert_user_code

    def as_qualified_name(o):
        name = inspect_utils.getqualifiedname(ctx.info.namespace, o, max_depth=1)
        if name:
            return name
        if isinstance(o, weakref.ref):
            # `o` might already be a weak reference, if this object was
            # constructed from code generated by `to_ast` itself; unwrap it.
            o = o()
        # TODO(mdan): This needs to account for the symbols defined locally.
        name = ctx.namer.new_symbol(o.__name__, ())
        ctx.program.add_symbol(name, weakref.ref(o))
        return name

    def literal(value):
        return parser.parse_expression(str(value))

    def list_of_names(values):
        joined = ', '.join(as_qualified_name(v) for v in values)
        return parser.parse_expression('(' + joined + ')')

    def list_of_features(values):
        joined = ', '.join('ag__.' + str(v) for v in values)
        return parser.parse_expression('(' + joined + ')')

    constructor_template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorators_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """
    statements = templates.replace(
        constructor_template,
        recursive_val=literal(self.recursive),
        verbose_val=literal(int(self.verbose)),
        strip_decorators_val=list_of_names(self._strip_decorators),
        force_conversion_val=literal(self.force_conversion),
        internal_convert_user_code_val=literal(internal_convert_user_code),
        optional_features_val=list_of_features(self.optional_features))
    return statements[0].value
def _create_undefined_assigns(self, undefined_symbols):
    """Returns `sym = ag__.UNDEFINED` statements, one per symbol, in order."""
    undefined_template = '''
      var = ag__.UNDEFINED
    '''
    statements = []
    for symbol in undefined_symbols:
        statements.extend(templates.replace(undefined_template, var=symbol))
    return statements
def visit_Continue(self, node):
    """Replaces `continue` with setting the control variable to `tf.constant(True)`.

    Marks the `_Continue` state as used and resets the `_Block` guard state.
    """
    # NOTE(review): `used` is recorded on `self.state[_Continue]`, but the
    # variable name is read through `get_local(CONTROL_VAR_NAME)`. Confirm
    # that both mechanisms are kept in sync by the enclosing loop visitors.
    self.state[_Continue].used = True
    self.state[_Block].reset_guard_state()
    control_var = self.get_local(CONTROL_VAR_NAME)
    flag_template = """
      var_name = tf.constant(True)
    """
    return templates.replace(flag_template, var_name=control_var)
def _create_cond_expr(self, results, test, body_name, orelse_name):
    """Builds the `ag__.utils.run_cond(...)` call for a conditional.

    Args:
      results: target(s) that receive the call's value, or None to emit the
        call as a bare expression statement.
      test: the condition expression.
      body_name: name of the `if`-branch function.
      orelse_name: name of the `else`-branch function.

    Returns:
      The statements produced by `templates.replace`.
    """
    replacements = dict(test=test, body_name=body_name, orelse_name=orelse_name)
    if results is None:
        return templates.replace(
            """
              ag__.utils.run_cond(test, body_name, orelse_name)
            """,
            **replacements)

    replacements['results'] = results
    return templates.replace(
        """
          results = ag__.utils.run_cond(test, body_name, orelse_name)
        """,
        **replacements)
def visit_While(self, node):
    """Converts a `while` loop into a call to `ag__.while_stmt`.

    The test and body become local functions. Loop-state symbols are renamed
    through the SSF map from `self._state_constructs`. Assignments for
    possibly-undefined loop symbols are placed before the loop.

    Args:
      node: the `while` statement node.

    Returns:
      The list of replacement statements.
    """
    self.generic_visit(node)
    modified = anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(
        node, modified)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    renamed_test = ast_util.rename_symbols(node.test, ssf_map)

    # Helper names are generated in a fixed order: test, then body.
    namer = self.ctx.namer
    test_name = namer.new_symbol('loop_test', reserved_symbols)
    body_name = namer.new_symbol('loop_body', reserved_symbols)

    if loop_state:
        stateful_template = """
          def test_name(state_ssf):
            return test
          def body_name(state_ssf):
            body
            return state_ssf,
          state_ast_tuple = ag__.while_stmt(test_name, body_name, (state,))
        """
        loop_nodes = templates.replace(
            stateful_template,
            state=loop_state,
            state_ssf=state_ssf,
            state_ast_tuple=state_ast_tuple,
            test_name=test_name,
            test=renamed_test,
            body_name=body_name,
            body=renamed_body)
    else:
        stateless_template = """
          def test_name():
            return test
          def body_name():
            body
            return ()
          ag__.while_stmt(test_name, body_name, ())
        """
        loop_nodes = templates.replace(
            stateless_template,
            test_name=test_name,
            test=renamed_test,
            body_name=body_name,
            body=renamed_body)

    return self._create_undefined_assigns(possibly_undefs) + loop_nodes
def visit_Continue(self, node):
    """Replaces `continue` with an assignment of `True` to the control variable.

    Marks the `_Continue` state as used and resets the `_Block` guard state
    before emitting the assignment.
    """
    self.state[_Continue].used = True
    self.state[_Block].reset_guard_state()
    control_var = self.state[_Continue].control_var_name
    return templates.replace(
        """
          var_name = True
        """,
        var_name=control_var)
def visit_Return(self, node):
    """Wraps a returned value in `<function context>.mark_return_value(...)`.

    A bare `return` is left unchanged.
    """
    if node.value is None:
        return node

    node = self.generic_visit(node)
    context_name = self.state[_Function].context_name
    return templates.replace(
        'return function_context_name.mark_return_value(value)',
        function_context_name=context_name,
        value=node.value)
def to_ast(self, namespace, internal_convert_user_code=None):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents.

    Args:
      namespace: Dict[str, Any], the namespace to use when serializing values to
        names.
      internal_convert_user_code: Optional[bool], allows overriding the
        corresponding value. When None, this object's own
        `internal_convert_user_code` is used.

    Returns:
      ast.Node

    Raises:
      ValueError: if `ConversionOptions` or an entry of `strip_decorators`
        cannot be located in `namespace`.
    """
    template = """
      constructor_name(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorators_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def as_qualified_name(o):
        name = inspect_utils.getqualifiedname(namespace, o)
        if not name:
            raise ValueError('Could not locate entity {} in {}'.format(
                o, namespace))
        return name

    def list_of_names(values):
        return parser.parse_expression('({})'.format(', '.join(
            tuple(as_qualified_name(v) for v in values))))

    def list_of_features(values):
        # Iterates `Feature.__members__` names, keeping those found in `values`.
        return parser.parse_expression('({})'.format(', '.join(
            'ag__.Feature.{}'.format(v)
            for v in Feature.__members__
            if v in values)))

    # Fall back to the stored setting only when no override was provided.
    # The inverted check (`is not None`) discarded explicit overrides and
    # emitted a literal `None` when the argument was omitted.
    if internal_convert_user_code is None:
        internal_convert_user_code = self.internal_convert_user_code

    expr_ast = templates.replace(
        template,
        constructor_name=parser.parse_expression(
            as_qualified_name(ConversionOptions)),
        recursive_val=parser.parse_expression(str(self.recursive)),
        verbose_val=parser.parse_expression(str(int(self.verbose))),
        strip_decorators_val=list_of_names(self.strip_decorators),
        force_conversion_val=parser.parse_expression(
            str(self.force_conversion)),
        internal_convert_user_code_val=parser.parse_expression(
            str(internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))

    return expr_ast[0].value
def visit_For(self, node):
    """Converts a `for` loop into a call to `ag__.for_stmt`.

    Loop state is the set of symbols modified but not created in the body.
    Each state symbol gets a fresh SSF name, and the body and extra test are
    renamed to use those names. With a single state symbol, the state is used
    directly instead of as a tuple.

    Args:
      node: the `for` statement node.

    Returns:
      The replacement statements produced by `templates.replace`.
    """
    self.generic_visit(node)
    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    all_referenced = body_scope.referenced
    state = list(body_scope.modified - body_scope.created)

    namer = self.ctx.namer
    state_ssf = [namer.new_symbol(s.ssf(), all_referenced) for s in state]
    # Only symbols whose SSF name differs from their own name need renaming.
    ssf_map = {}
    for symbol, ssf in zip(state, state_ssf):
        if str(symbol) != ssf:
            ssf_map[symbol] = ssf

    if len(state) == 1:
        state, = state
        state_ssf, = state_ssf
        state_ast_tuple = state
    else:
        state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    if anno.hasanno(node, 'extra_test'):
        extra_test = ast_util.rename_symbols(
            anno.getanno(node, 'extra_test'), ssf_map)
    else:
        extra_test = parser.parse_expression('True')

    extra_test_name = namer.new_symbol('extra_test', all_referenced)
    body_name = namer.new_symbol('loop_body', all_referenced)

    loop_template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    return templates.replace(
        loop_template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=extra_test_name,
        extra_test_expr=extra_test,
        body_name=body_name,
        body=renamed_body)
def to_ast(self, namespace, internal_convert_user_code=None):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents.

    Args:
      namespace: Dict[str, Any], the namespace to use when serializing values to
        names.
      internal_convert_user_code: Optional[bool], allows overriding the
        corresponding value. When None, this object's own
        `internal_convert_user_code` is used.

    Returns:
      ast.Node

    Raises:
      ValueError: if `ConversionOptions` or an entry of `_strip_decorators`
        cannot be located in `namespace`.
    """
    template = """
      constructor_name(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorators_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def as_qualified_name(o):
        name = inspect_utils.getqualifiedname(namespace, o)
        if not name:
            raise ValueError('Could not locate entity {} in {}'.format(
                o, namespace))
        return name

    def list_of_names(values):
        return parser.parse_expression('({})'.format(', '.join(
            tuple(as_qualified_name(v) for v in values))))

    def list_of_features(values):
        # Iterates `Feature.__members__` names, keeping those found in `values`.
        return parser.parse_expression('({})'.format(', '.join(
            'ag__.Feature.{}'.format(v)
            for v in Feature.__members__
            if v in values)))

    # Fall back to the stored setting only when no override was provided.
    # The inverted check (`is not None`) discarded explicit overrides and
    # emitted a literal `None` when the argument was omitted.
    if internal_convert_user_code is None:
        internal_convert_user_code = self.internal_convert_user_code

    expr_ast = templates.replace(
        template,
        constructor_name=parser.parse_expression(
            as_qualified_name(ConversionOptions)),
        recursive_val=parser.parse_expression(str(self.recursive)),
        verbose_val=parser.parse_expression(str(int(self.verbose))),
        strip_decorators_val=list_of_names(self._strip_decorators),
        force_conversion_val=parser.parse_expression(
            str(self.force_conversion)),
        internal_convert_user_code_val=parser.parse_expression(
            str(internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))

    return expr_ast[0].value
def test_replace_name_with_subscript(self):
    """A subscript QN used as an assign target gets Store ctx; its base gets Load."""
    template = """
      foo = bar
    """
    subscript_qn = qn.QN(qn.QN('dictionary'), subscript=qn.QN('key'))
    assignment = templates.replace(template, foo=subscript_qn)[0]
    target = assignment.targets[0]
    self.assertIsInstance(target.ctx, gast.Store)
    self.assertIsInstance(target.value.ctx, gast.Load)
def test_lambda_in_function_call(self):
    """Contexts are set inside a lambda nested in a call argument."""
    template = """
      a = foo(arg)
    """
    lambda_list = parser.parse_expression('[lambda i: i]')
    assignment = templates.replace(template, arg=lambda_list)[0]
    call = assignment.value
    lambda_node = call.args[0].elts[0]
    self.assertIsInstance(lambda_node.args.args[0].ctx, gast.Param)
    self.assertIsInstance(lambda_node.body.ctx, gast.Load)
def test_star_comprehension_in_function_call(self):
    """Contexts are set inside a starred comprehension passed as call args."""
    template = """
      a = foo(func, args)
    """
    call_source = parser.parse_expression('bar(*[i for i in range(j)])')
    assignment = templates.replace(
        template, func=call_source.func, args=call_source.args)[0]
    comprehension = assignment.value.args[1].value
    self.assertIsInstance(comprehension.generators[0].target.ctx, gast.Store)
    self.assertIsInstance(comprehension.elt.ctx, gast.Load)
def replace_as_expression(self):
    """`templates.replace_as_expression` yields the bare call expression.

    Fixes to the previous version:
      * It now calls `replace_as_expression`; `replace` returns a statement
        list.
      * It uses `assertIsInstance`; `node is gast.Call` compared an instance
        to a class.
      * It reads the call's arguments from `node.args`, not `node.func.args`.
    """
    template = """
      foo(a)
    """

    node = templates.replace_as_expression(template, foo='bar', a='baz')
    self.assertIsInstance(node, gast.Call)
    self.assertEqual(node.func.id, 'bar')
    self.assertEqual(node.args[0].id, 'baz')


# unittest only collects methods prefixed with `test_`. This alias makes the
# check above actually run while keeping the original name available.
test_replace_as_expression = replace_as_expression
def visit_Assert(self, node):
    """Converts `assert test, msg` into `tf.Assert(test, (msg,))`.

    A missing message defaults to 'Assertion error'.

    Raises:
      NotImplementedError: if the message is present but not a string literal.
    """
    self.generic_visit(node)
    msg = node.msg
    if msg is not None and not isinstance(msg, gast.Str):
        raise NotImplementedError('can only convert string messages for now.')
    if msg is None:
        msg = gast.Str('Assertion error')

    # Note: The lone tf.Assert call will be wrapped with control_dependencies
    # by side_effect_guards.
    return templates.replace(
        """
          tf.Assert(test, (msg,))
        """,
        test=node.test,
        msg=msg)
def _replace_append_call(self, node):
    """Rewrites `target.append(element)` as `target = ag__.list_append(target, element)`.

    Expects a single-argument call on an attribute.
    """
    assert len(node.args) == 1
    assert isinstance(node.func, gast.Attribute)
    appended_list = node.func.value
    appended_item = node.args[0]
    append_template = """
      target = ag__.list_append(target, element)
    """
    return templates.replace(
        append_template, target=appended_list, element=appended_item)
def visit_Break(self, node):
    """Replaces `break` with setting the break flag to `tf.constant(True)` plus `continue`.

    Also marks the `_Break` state as used.
    """
    self.state[_Break].used = True
    # TODO(mdan): This will fail when expanded inside a top-level else block.
    return templates.replace(
        """
          var_name = tf.constant(True)
          continue
        """,
        var_name=self.state[_Break].control_var_name)
def visit_If(self, node):
    """Intercepts if statements.

    Converts each `if` into up to two `with` statements. The first is guarded
    by `_tfp_autobatching_context_.if_(cond)`; when an `else` branch exists,
    the second is guarded by `_tfp_autobatching_context_.else_()`.

    Args:
      node: An `ast.AST` node representing the `if` statement to convert.

    Returns:
      The `with`-guarded consequent branch alone when the `if` has no `else`;
      otherwise a list `[then_node, else_node]`.
    """

    def transform_block(statements):
        # NOTE: wrapping the statements in a Module keeps the AST-visiting
        # machinery from choking on a bare list, so prepending works properly.
        return self.generic_visit(gast_util.Module(statements)).body

    def guarded_by(header, body):
        # TODO(axch): Test that this form actually works with multiline bodies.
        return templates.replace(
            'with header: body', header=header, body=body)[0]

    then_body = transform_block(node.body)
    then_header = templates.replace_as_expression(
        '_tfp_autobatching_context_.if_(cond)',
        cond=self._to_reference(node.test))
    then_node = guarded_by(then_header, then_body)

    if not node.orelse:
        return then_node

    orelse_body = transform_block(node.orelse)
    orelse_header = templates.replace_as_expression(
        '_tfp_autobatching_context_.else_()')
    return [then_node, guarded_by(orelse_header, orelse_body)]
def test_replace_expression_context(self):
    """Names inside a replaced expression statement get Load ctx."""
    template = """
      def test_fn():
        foo
    """
    expression = parser.parse_expression('a + 2 * b / -c')
    fn_node = templates.replace(template, foo=expression)[0]
    top = fn_node.body[0]
    self.assertIsInstance(top.left.ctx, gast.Load)
    self.assertIsInstance(top.right.left.right.ctx, gast.Load)
def _create_undefined_assigns(self, undefined_symbols):
    """Returns `sym = ag__.Undefined('<ssf name>')` statements, one per symbol.

    Each symbol's `ssf()` string is passed to `ag__.Undefined`.
    """
    undefined_template = '''
      var = ag__.Undefined(symbol_name)
    '''
    return [
        statement
        for symbol in undefined_symbols
        for statement in templates.replace(
            undefined_template, var=symbol, symbol_name=gast.Str(symbol.ssf()))
    ]
def test_replace_name_with_dict(self):
    """A name can be replaced by a dict literal expression."""
    template = """
      def test_fn():
        return foo['bar']
    """
    dict_literal = parser.parse_expression("{'bar': 3}")
    fn_node = templates.replace(template, foo=dict_literal)[0]
    module, _, _ = loader.load_ast(fn_node)
    self.assertEqual(3, module.test_fn())
def test_replace_tuple(self):
    """A name can be replaced by a tuple of names."""
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result, _ = compiler.ast_to_object(node)
    # `assertEquals` is a deprecated alias (removed in Python 3.12).
    self.assertEqual((2, 3), result.test_fn(2, 3))
def test_replace_name_with_dict(self):
    """A name can be replaced by a dict literal expression."""
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # `assertEquals` is a deprecated alias (removed in Python 3.12).
    self.assertEqual(3, result.test_fn())
def test_replace_tuple(self):
    """A name can be replaced by a tuple of names."""
    template = """
      def test_fn(a, c):
        return b,
    """
    fn_node = templates.replace(template, b=('a', 'c'))[0]
    module, _, _ = loader.load_ast(fn_node)
    self.assertEqual((2, 3), module.test_fn(2, 3))