def visit_While(self, node):
  """Rewrites a `while` statement as a call to `py2tf_utils.run_while`.

  The loop state is every symbol that the body modifies without creating it.
  Inside the generated test and body functions, each state symbol is renamed
  to a fresh name requested from the namer, using the symbol's `ssf()` form as
  the base. The loop result is assigned back to the original state symbols.

  Args:
    node: the `While` node. Its children are transformed first.

  Returns:
    The statements produced by `templates.replace` for the loop template.

  Raises:
    ValueError: if the body modifies no symbol that it does not also create.
  """
  self.generic_visit(node)
  scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  referenced = scope.referenced
  namer = self.context.namer

  loop_vars = list(scope.modified - scope.created)
  if not loop_vars:
    # TODO(mdan): Implement this properly.
    # To complete this statement, we need to check whether any variable
    # created inside the body scope is used before being modified outside the
    # scope. This should be done during activity analysis, and in general
    # should cover the case where variables may not be initialized.
    raise ValueError('cannot convert while loop: no outputs')

  fresh_names = [namer.new_symbol(qn.ssf(), referenced) for qn in loop_vars]
  # Only symbols whose fresh name actually differs need to be rewritten.
  rename_map = {}
  for qn, fresh in zip(loop_vars, fresh_names):
    if str(qn) != fresh:
      rename_map[qn] = fresh

  if len(loop_vars) != 1:
    state_arg, fresh_arg = loop_vars, fresh_names
    assign_target = gast.Tuple([qn.ast() for qn in loop_vars], None)
  else:
    # A single state symbol is threaded through directly, not as a tuple.
    state_arg, fresh_arg = loop_vars[0], fresh_names[0]
    assign_target = state_arg

  renamed_body = ast_util.rename_symbols(node.body, rename_map)
  renamed_test = ast_util.rename_symbols(node.test, rename_map)

  template = """
    def test_name(state_ssf):
      return test
    def body_name(state_ssf):
      body
      return state_ssf,
    state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state])
  """
  test_fn_name = namer.new_symbol('loop_test', referenced)
  body_fn_name = namer.new_symbol('loop_body', referenced)
  return templates.replace(
      template,
      state=state_arg,
      state_ssf=fresh_arg,
      state_ast_tuple=assign_target,
      test_name=test_fn_name,
      test=renamed_test,
      body_name=body_fn_name,
      body=renamed_body)
def visit_While(self, node):
  """Rewrites a `while` statement as a call to `py2tf_utils.run_while`.

  The loop state is every symbol that the body modifies without creating it.
  Inside the generated test and body functions, each state symbol is renamed
  to a fresh name requested from the namer, using the symbol's `ssf()` form as
  the base. The loop result is assigned back to the original state symbols.

  Args:
    node: the `While` node. Its children are transformed first.

  Returns:
    The statements produced by `templates.replace` for the loop template.

  Raises:
    ValueError: if the body modifies no symbol that it does not also create.
  """
  self.generic_visit(node)
  scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  referenced = scope.referenced
  namer = self.context.namer

  loop_vars = list(scope.modified - scope.created)
  if not loop_vars:
    # TODO(mdan): Implement this properly.
    # To complete this statement, we need to check whether any variable
    # created inside the body scope is used before being modified outside the
    # scope. This should be done during activity analysis, and in general
    # should cover the case where variables may not be initialized.
    raise ValueError('cannot convert while loop: no outputs')

  fresh_names = [namer.new_symbol(qn.ssf(), referenced) for qn in loop_vars]
  # Only symbols whose fresh name actually differs need to be rewritten.
  rename_map = {}
  for qn, fresh in zip(loop_vars, fresh_names):
    if str(qn) != fresh:
      rename_map[qn] = fresh

  if len(loop_vars) != 1:
    state_arg, fresh_arg = loop_vars, fresh_names
    assign_target = gast.Tuple([qn.ast() for qn in loop_vars], None)
  else:
    # A single state symbol is threaded through directly, not as a tuple.
    state_arg, fresh_arg = loop_vars[0], fresh_names[0]
    assign_target = state_arg

  renamed_body = ast_util.rename_symbols(node.body, rename_map)
  renamed_test = ast_util.rename_symbols(node.test, rename_map)

  template = """
    def test_name(state_ssf):
      return test
    def body_name(state_ssf):
      body
      return state_ssf,
    state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state])
  """
  test_fn_name = namer.new_symbol('loop_test', referenced)
  body_fn_name = namer.new_symbol('loop_body', referenced)
  return templates.replace(
      template,
      state=state_arg,
      state_ssf=fresh_arg,
      state_ast_tuple=assign_target,
      test_name=test_fn_name,
      test=renamed_test,
      body_name=body_fn_name,
      body=renamed_body)
def visit_While(self, node):
  """Rewrites a `while` statement as a `tf.while_loop` call.

  The loop state is every symbol that the body modifies without creating it.
  Inside the generated test and body functions, each state symbol is renamed
  to a fresh name requested from the namer, using the symbol's `ssf()` form as
  the base. The loop result is assigned back to the original state symbols.

  Args:
    node: the `While` node. Its children are transformed first.

  Returns:
    The statements produced by `templates.replace` for the loop template.

  Raises:
    ValueError: if the body modifies no symbol that it does not also create,
      so that the loop would have no state to carry.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  body_closure = body_scope.modified - body_scope.created
  all_referenced = body_scope.referenced

  state = list(body_closure)
  if not state:
    # Without state, the template below would pass an empty list of loop
    # variables and assign the loop result to an empty target, which does not
    # yield a working loop. Reject it here with an explicit message instead.
    # Supporting this properly needs activity analysis to check whether
    # variables created in the body are used before being modified outside.
    raise ValueError('cannot convert while loop: no outputs')

  state_ssf = [
      self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  # Only symbols whose fresh name actually differs need to be rewritten.
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  if len(state) == 1:
    # A single state symbol is threaded through directly, not as a tuple.
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  test = ast_util.rename_symbols(node.test, ssf_map)

  template = """
    def test_name(state_ssf):
      return test
    def body_name(state_ssf):
      body
      return state_ssf,
    state_ast_tuple = tf.while_loop(test_name, body_name, [state])
  """
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      test_name=self.context.namer.new_symbol('loop_test', all_referenced),
      test=test,
      body_name=self.context.namer.new_symbol('loop_body', all_referenced),
      body=node_body)
  return node
def visit_While(self, node):
  """Rewrites a `while` statement as a call to `py2tf_utils.run_while`.

  The loop state is every symbol that the body modifies without creating it.
  Inside the generated test and body functions, each state symbol is renamed
  to a fresh name requested from the namer, using the symbol's `ssf()` form as
  the base. The loop result is assigned back to the original state symbols.

  Args:
    node: the `While` node. Its children are transformed first.

  Returns:
    The statements produced by `templates.replace` for the loop template.

  Raises:
    ValueError: if the body modifies no symbol that it does not also create,
      so that the loop would have no state to carry.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  body_closure = body_scope.modified - body_scope.created
  all_referenced = body_scope.referenced

  state = list(body_closure)
  if not state:
    # Without state, the template below would pass an empty list of loop
    # variables and assign the loop result to an empty target, which does not
    # yield a working loop. Reject it here with an explicit message instead.
    # Supporting this properly needs activity analysis to check whether
    # variables created in the body are used before being modified outside.
    raise ValueError('cannot convert while loop: no outputs')

  state_ssf = [
      self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
  ]
  # Only symbols whose fresh name actually differs need to be rewritten.
  ssf_map = {
      name: ssf
      for name, ssf in zip(state, state_ssf)
      if str(name) != ssf
  }

  if len(state) == 1:
    # A single state symbol is threaded through directly, not as a tuple.
    state = state[0]
    state_ssf = state_ssf[0]
    state_ast_tuple = state
  else:
    state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

  node_body = ast_util.rename_symbols(node.body, ssf_map)
  test = ast_util.rename_symbols(node.test, ssf_map)

  template = """
    def test_name(state_ssf):
      return test
    def body_name(state_ssf):
      body
      return state_ssf,
    state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state])
  """
  node = templates.replace(
      template,
      state=state,
      state_ssf=state_ssf,
      state_ast_tuple=state_ast_tuple,
      test_name=self.context.namer.new_symbol('loop_test', all_referenced),
      test=test,
      body_name=self.context.namer.new_symbol('loop_body', all_referenced),
      body=node_body)
  return node
def _visit_and_reindent(self, nodes):
  """Visits `nodes` in order, re-nesting the statements after a gating node.

  Visited results are appended to a destination list, which starts as the
  returned list. A visited node may carry an
  `anno.Basic.INDENT_BLOCK_REMAINDER` annotation holding a pair
  `(new_dest, new_alias_map)`. When it does, the annotation is removed, every
  following visited statement is appended to `new_dest` instead, and those
  statements are renamed with `ast_util.rename_symbols` using
  `new_alias_map` merged with the aliases already in effect.

  Args:
    nodes: the statements to visit.

  Returns:
    The initial destination list: the statements appended before the first
    re-indentation, including the node that requested it. Later statements
    go into the annotation-supplied destination lists.

  Raises:
    ValueError: if a re-indentation was requested but the last destination
      list received no statements, i.e. there was nothing for the requesting
      statement to gate.
  """
  new_nodes = []
  current_dest = new_nodes
  alias_map = {}
  reindent_requested = False
  for n in nodes:
    n = self.visit(n)
    # NOTE: the order in which these statements execute is important; in
    # particular, watch out for ending up with cycles in the AST.
    if alias_map:
      n = ast_util.rename_symbols(n, alias_map)
    if isinstance(n, (list, tuple)):
      current_dest.extend(n)
    else:
      current_dest.append(n)
    # NOTE(review): when visit() returned a list/tuple, the annotation is
    # looked up on that sequence itself rather than on its elements -- confirm
    # this is intended.
    if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
      reindent_requested = True
      new_dest, new_alias_map = anno.getanno(
          n, anno.Basic.INDENT_BLOCK_REMAINDER)
      anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
      # Aliases already in effect win over the new ones on key conflicts.
      # This updates the map object taken from the annotation in place.
      new_alias_map.update(alias_map)
      alias_map = new_alias_map
      current_dest = new_dest
  if reindent_requested and not current_dest:
    # TODO(mdan): There may still be something that could be done.
    raise ValueError('Unable to insert statement into the computation flow: '
                     'it is not followed by any computation which '
                     'the statement could gate.')
  return new_nodes
def _visit_and_reindent(self, nodes):
  """Visits `nodes` in order, re-nesting the statements after a gating node.

  Visited results are appended to a destination list, which starts as the
  returned list. A visited node may carry an
  `anno.Basic.INDENT_BLOCK_REMAINDER` annotation holding a pair
  `(new_dest, new_alias_map)`. When it does, the annotation is removed, every
  following visited statement is appended to `new_dest` instead, and those
  statements are renamed with `ast_util.rename_symbols` using
  `new_alias_map` merged with the aliases already in effect.

  Args:
    nodes: the statements to visit.

  Returns:
    The initial destination list: the statements appended before the first
    re-indentation, including the node that requested it. Later statements
    go into the annotation-supplied destination lists.

  Raises:
    ValueError: if a re-indentation was requested but the last destination
      list received no statements, i.e. there was nothing for the requesting
      statement to gate.
  """
  new_nodes = []
  current_dest = new_nodes
  alias_map = {}
  reindent_requested = False
  for n in nodes:
    n = self.visit(n)
    # NOTE: the order in which these statements execute is important; in
    # particular, watch out for ending up with cycles in the AST.
    if alias_map:
      n = ast_util.rename_symbols(n, alias_map)
    if isinstance(n, (list, tuple)):
      current_dest.extend(n)
    else:
      current_dest.append(n)
    # NOTE(review): when visit() returned a list/tuple, the annotation is
    # looked up on that sequence itself rather than on its elements -- confirm
    # this is intended.
    if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
      reindent_requested = True
      new_dest, new_alias_map = anno.getanno(
          n, anno.Basic.INDENT_BLOCK_REMAINDER)
      anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
      # Aliases already in effect win over the new ones on key conflicts.
      # This updates the map object taken from the annotation in place.
      new_alias_map.update(alias_map)
      alias_map = new_alias_map
      current_dest = new_dest
  if reindent_requested and not current_dest:
    # TODO(mdan): There may still be something that could be done.
    raise ValueError(
        'Unable to insert statement into the computation flow: '
        'it is not followed by any computation which '
        'the statement could gate.')
  return new_nodes
def test_rename_symbols(self):
  """Checks that rename_symbols renames plain and qualified names.

  `a` and the qualified name `b.c` are renamed, while a bare `b` is left
  alone. The Load/Store context of each renamed node must be preserved,
  including for `b.c` nested inside the attribute access `b.c.d`.
  """
  node = ast.Tuple([
      ast.Name('a', ast.Load()),
      ast.Name('b', ast.Load()),
      ast.Attribute(ast.Name('b', None), 'c', ast.Store()),
      ast.Attribute(
          ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None)
  ], None)
  node = qual_names.resolve(node)

  node = ast_util.rename_symbols(
      node, {
          qual_names.QN('a'): qual_names.QN('renamed_a'),
          qual_names.QN('b.c'): qual_names.QN('renamed_b_c'),
      })

  # assertIsInstance reports the actual type on failure, unlike
  # assertTrue(isinstance(...)), which only reports "False is not true".
  self.assertEqual(node.elts[0].id, 'renamed_a')
  self.assertIsInstance(node.elts[0].ctx, ast.Load)
  self.assertEqual(node.elts[1].id, 'b')
  self.assertEqual(node.elts[2].id, 'renamed_b_c')
  self.assertIsInstance(node.elts[2].ctx, ast.Store)
  self.assertEqual(node.elts[3].value.id, 'renamed_b_c')
  self.assertIsInstance(node.elts[3].value.ctx, ast.Load)
def test_rename_symbols(self):
  """Checks that rename_symbols renames plain and qualified names.

  `a` and the qualified name `b.c` are renamed, while a bare `b` is left
  alone. The Load/Store context of each renamed node must be preserved,
  including for `b.c` nested inside the attribute access `b.c.d`.
  """
  node = ast.Tuple([
      ast.Name('a', ast.Load()),
      ast.Name('b', ast.Load()),
      ast.Attribute(ast.Name('b', None), 'c', ast.Store()),
      ast.Attribute(
          ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None)
  ], None)
  node = qual_names.resolve(node)

  node = ast_util.rename_symbols(
      node, {
          qual_names.QN('a'): qual_names.QN('renamed_a'),
          qual_names.QN('b.c'): qual_names.QN('renamed_b_c'),
      })

  # assertIsInstance reports the actual type on failure, unlike
  # assertTrue(isinstance(...)), which only reports "False is not true".
  self.assertEqual(node.elts[0].id, 'renamed_a')
  self.assertIsInstance(node.elts[0].ctx, ast.Load)
  self.assertEqual(node.elts[1].id, 'b')
  self.assertEqual(node.elts[2].id, 'renamed_b_c')
  self.assertIsInstance(node.elts[2].ctx, ast.Store)
  self.assertEqual(node.elts[3].value.id, 'renamed_b_c')
  self.assertIsInstance(node.elts[3].value.ctx, ast.Load)
def visit_If(self, node):
  """Rewrites an `if` statement as two branch functions plus a cond call.

  Both branches must create the same new symbols. Symbols modified by either
  branch but created by neither are renamed inside the branch bodies to fresh
  names from the namer. This avoids errors caused by the local variables
  created in the branch functions. Each branch returns every modified symbol,
  using its fresh name where one was assigned. When nothing is modified, the
  branches return a dummy `tf.ones(())` and the cond is emitted without
  result targets.

  Args:
    node: the `If` node. Its children are transformed first.

  Returns:
    The concatenation of the true-branch definition, the false-branch
    definition and the cond expression built by the `_create_cond_*` helpers.

  Raises:
    ValueError: if one branch creates symbols that the other does not.
  """
  self.generic_visit(node)
  true_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  false_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

  if true_scope.created - false_scope.created:
    raise ValueError(
        'The if branch creates new symbols that the else branch does not.')
  if false_scope.created - true_scope.created:
    raise ValueError(
        'The else branch creates new symbols that the if branch does not.')

  written = true_scope.modified | false_scope.modified
  outputs = tuple(written)
  referenced = true_scope.referenced | false_scope.referenced
  namer = self.context.namer

  # Closure variables written in a branch get fresh names inside the branch
  # functions, so that those functions do not trip over their own locals.
  aliased = tuple(written - (true_scope.created | false_scope.created))
  fresh = tuple(namer.new_symbol(qn.ssf(), referenced) for qn in aliased)
  renames = dict(zip(aliased, fresh))
  true_body = ast_util.rename_symbols(node.body, renames)
  false_body = ast_util.rename_symbols(node.orelse, renames)

  if not outputs:
    # No outputs: the cond is called without results, which should trigger
    # the side effect guards; the branches return a dummy value (see below)
    # so that the cond still has something to return.
    results = None
  elif len(outputs) == 1:
    (results,) = outputs
  else:
    results = gast.Tuple([qn.ast() for qn in outputs], None)

  true_name = namer.new_symbol('if_true', referenced)
  false_name = namer.new_symbol('if_false', referenced)

  if not outputs:
    returns = templates.replace('tf.ones(())')[0].value
  else:
    returns = tuple(renames.get(qn, qn) for qn in outputs)

  true_def = self._create_cond_branch(
      true_name,
      aliased_orig_names=aliased,
      aliased_new_names=fresh,
      body=true_body,
      returns=returns)
  false_def = self._create_cond_branch(
      false_name,
      aliased_orig_names=aliased,
      aliased_new_names=fresh,
      body=false_body,
      returns=returns)
  cond = self._create_cond_expr(results, node.test, true_name, false_name)
  return true_def + false_def + cond
def visit_If(self, node):
  """Rewrites an `if` statement as two branch functions plus a cond call.

  Both branches must create the same new symbols. Symbols modified by either
  branch but created by neither are renamed inside the branch bodies to fresh
  names from the namer. This avoids errors caused by the local variables
  created in the branch functions. Each branch returns every modified symbol,
  using its fresh name where one was assigned. When nothing is modified, the
  branches return a dummy `tf.ones(())` and the cond is emitted without
  result targets.

  Args:
    node: the `If` node. Its children are transformed first.

  Returns:
    The concatenation of the true-branch definition, the false-branch
    definition and the cond expression built by the `_create_cond_*` helpers.

  Raises:
    ValueError: if one branch creates symbols that the other does not.
  """
  self.generic_visit(node)
  true_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  false_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

  if true_scope.created - false_scope.created:
    raise ValueError(
        'The if branch creates new symbols that the else branch does not.')
  if false_scope.created - true_scope.created:
    raise ValueError(
        'The else branch creates new symbols that the if branch does not.')

  written = true_scope.modified | false_scope.modified
  outputs = tuple(written)
  referenced = true_scope.referenced | false_scope.referenced
  namer = self.context.namer

  # Closure variables written in a branch get fresh names inside the branch
  # functions, so that those functions do not trip over their own locals.
  aliased = tuple(written - (true_scope.created | false_scope.created))
  fresh = tuple(namer.new_symbol(qn.ssf(), referenced) for qn in aliased)
  renames = dict(zip(aliased, fresh))
  true_body = ast_util.rename_symbols(node.body, renames)
  false_body = ast_util.rename_symbols(node.orelse, renames)

  if not outputs:
    # No outputs: the cond is called without results, which should trigger
    # the side effect guards; the branches return a dummy value (see below)
    # so that the cond still has something to return.
    results = None
  elif len(outputs) == 1:
    (results,) = outputs
  else:
    results = gast.Tuple([qn.ast() for qn in outputs], None)

  true_name = namer.new_symbol('if_true', referenced)
  false_name = namer.new_symbol('if_false', referenced)

  if not outputs:
    returns = templates.replace('tf.ones(())')[0].value
  else:
    returns = tuple(renames.get(qn, qn) for qn in outputs)

  true_def = self._create_cond_branch(
      true_name,
      aliased_orig_names=aliased,
      aliased_new_names=fresh,
      body=true_body,
      returns=returns)
  false_def = self._create_cond_branch(
      false_name,
      aliased_orig_names=aliased,
      aliased_new_names=fresh,
      body=false_body,
      returns=returns)
  cond = self._create_cond_expr(results, node.test, true_name, false_name)
  return true_def + false_def + cond
def visit_If(self, node):
  """Rewrites an `if` statement as two branch functions and a `tf.cond` call.

  Both branches must create the same new symbols. Every symbol modified by
  either branch is an output of the cond and is assigned from its result.
  Symbols modified but created by neither branch are aliased to fresh names
  inside the branch functions. This avoids errors caused by the local
  variables created in those functions.

  Args:
    node: the `If` node. Its children are transformed first.

  Returns:
    The statements produced by `templates.replace`: the two branch function
    definitions followed by the `tf.cond` assignment.

  Raises:
    ValueError: if one branch creates symbols that the other does not, or if
      neither branch modifies any symbol, so the cond would have no outputs.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

  if body_scope.created - orelse_scope.created:
    raise ValueError(
        'The if branch creates new symbols that the else branch does not.')
  if orelse_scope.created - body_scope.created:
    raise ValueError(
        'The else branch creates new symbols that the if branch does not.')

  modified = body_scope.modified | orelse_scope.modified
  all_modified = tuple(modified)
  if not all_modified:
    # The templates below always return and assign `all_results`. With no
    # modified symbols, both branches would return an empty tuple and the
    # cond would be assigned to an empty target. tf.cond needs its branches
    # to produce at least one output, so fail here with a clear message
    # rather than emit code that breaks when it runs.
    raise ValueError('cannot convert if statement: no outputs')
  all_referenced = body_scope.referenced | orelse_scope.referenced

  # Alias the closure variables inside the conditional functions
  # to avoid errors caused by the local variables created in the branch
  # functions.
  need_alias = modified - (body_scope.created | orelse_scope.created)
  aliased_orig_names = tuple(need_alias)
  aliased_new_names = tuple(
      self.context.namer.new_symbol(s.ssf(), all_referenced)
      for s in aliased_orig_names)
  alias_map = dict(zip(aliased_orig_names, aliased_new_names))
  node_body = ast_util.rename_symbols(node.body, alias_map)
  node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

  if len(all_modified) == 1:
    results = all_modified[0]
  else:
    results = gast.Tuple([s.ast() for s in all_modified], None)

  # Both branch-function names are reserved (true branch first) before
  # either template is built.
  body_name = self.context.namer.new_symbol('if_true', all_referenced)
  orelse_name = self.context.namer.new_symbol('if_false', all_referenced)

  if aliased_orig_names:
    # Each branch first copies the aliased closure variables into their fresh
    # names, then runs its renamed body.
    template = """
      def body_name():
        aliased_new_names, = aliased_orig_names,
        body
        return (all_results,)
      def orelse_name():
        aliased_new_names, = aliased_orig_names,
        orelse
        return (all_results,)
      results = tf.cond(test, body_name, orelse_name)
    """
    return templates.replace(
        template,
        test=node.test,
        body_name=body_name,
        body=node_body,
        orelse_name=orelse_name,
        orelse=node_orelse,
        aliased_orig_names=aliased_orig_names,
        aliased_new_names=aliased_new_names,
        all_results=tuple(alias_map.get(s, s) for s in all_modified),
        results=results)

  template = """
    def body_name():
      body
      return (all_results,)
    def orelse_name():
      orelse
      return (all_results,)
    results = tf.cond(test, body_name, orelse_name)
  """
  return templates.replace(
      template,
      test=node.test,
      body_name=body_name,
      body=node_body,
      orelse_name=orelse_name,
      orelse=node_orelse,
      all_results=all_modified,
      results=results)
def visit_If(self, node):
  """Rewrites an `if` statement as two branch functions and a `tf.cond` call.

  Both branches must create the same new symbols. Every symbol modified by
  either branch is an output of the cond and is assigned from its result.
  Symbols modified but created by neither branch are aliased to fresh names
  inside the branch functions. This avoids errors caused by the local
  variables created in those functions.

  Args:
    node: the `If` node. Its children are transformed first.

  Returns:
    The statements produced by `templates.replace`: the two branch function
    definitions followed by the `tf.cond` assignment.

  Raises:
    ValueError: if one branch creates symbols that the other does not, or if
      neither branch modifies any symbol, so the cond would have no outputs.
  """
  self.generic_visit(node)
  body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
  orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

  if body_scope.created - orelse_scope.created:
    raise ValueError(
        'The if branch creates new symbols that the else branch does not.')
  if orelse_scope.created - body_scope.created:
    raise ValueError(
        'The else branch creates new symbols that the if branch does not.')

  modified = body_scope.modified | orelse_scope.modified
  all_modified = tuple(modified)
  if not all_modified:
    # The templates below always return and assign `all_results`. With no
    # modified symbols, both branches would return an empty tuple and the
    # cond would be assigned to an empty target. tf.cond needs its branches
    # to produce at least one output, so fail here with a clear message
    # rather than emit code that breaks when it runs.
    raise ValueError('cannot convert if statement: no outputs')
  all_referenced = body_scope.referenced | orelse_scope.referenced

  # Alias the closure variables inside the conditional functions
  # to avoid errors caused by the local variables created in the branch
  # functions.
  need_alias = modified - (body_scope.created | orelse_scope.created)
  aliased_orig_names = tuple(need_alias)
  aliased_new_names = tuple(
      self.context.namer.new_symbol(s.ssf(), all_referenced)
      for s in aliased_orig_names)
  alias_map = dict(zip(aliased_orig_names, aliased_new_names))
  node_body = ast_util.rename_symbols(node.body, alias_map)
  node_orelse = ast_util.rename_symbols(node.orelse, alias_map)

  if len(all_modified) == 1:
    results = all_modified[0]
  else:
    results = gast.Tuple([s.ast() for s in all_modified], None)

  # Both branch-function names are reserved (true branch first) before
  # either template is built.
  body_name = self.context.namer.new_symbol('if_true', all_referenced)
  orelse_name = self.context.namer.new_symbol('if_false', all_referenced)

  if aliased_orig_names:
    # Each branch first copies the aliased closure variables into their fresh
    # names, then runs its renamed body.
    template = """
      def body_name():
        aliased_new_names, = aliased_orig_names,
        body
        return (all_results,)
      def orelse_name():
        aliased_new_names, = aliased_orig_names,
        orelse
        return (all_results,)
      results = tf.cond(test, body_name, orelse_name)
    """
    return templates.replace(
        template,
        test=node.test,
        body_name=body_name,
        body=node_body,
        orelse_name=orelse_name,
        orelse=node_orelse,
        aliased_orig_names=aliased_orig_names,
        aliased_new_names=aliased_new_names,
        all_results=tuple(alias_map.get(s, s) for s in all_modified),
        results=results)

  template = """
    def body_name():
      body
      return (all_results,)
    def orelse_name():
      orelse
      return (all_results,)
    results = tf.cond(test, body_name, orelse_name)
  """
  return templates.replace(
      template,
      test=node.test,
      body_name=body_name,
      body=node_body,
      orelse_name=orelse_name,
      orelse=node_orelse,
      all_results=all_modified,
      results=results)