def visit_Assign(self, node):
    """Rewrite a simple ``name = rhs`` assignment into a masked update.

    The right-hand side is replaced by a conditional expression equivalent to::

        lhs._update(rhs, matchbox.EXECUTION_MASK)
            if (("lhs" in vars() or "lhs" in globals())
                and isinstance(lhs, (matchbox.MaskedBatch,
                                     matchbox.TENSOR_TYPE)))
            else rhs

    Args:
        node: the ``gast.Assign`` node being visited; it is mutated in place.

    Returns:
        The same ``node`` with ``node.value`` replaced.

    Raises:
        NotImplementedError: if the assignment has more than one target, or
            its single target is not a plain ``gast.Name``.
    """
    if len(node.targets) > 1:
        raise NotImplementedError("cannot process multiple assignment")
    target = node.targets[0]
    if not isinstance(target, gast.Name):
        raise NotImplementedError("cannot process indexed assignment")

    lhs_id = target.id
    rhs = node.value

    def load(identifier):
        # Every read gets its own Name with an *instance* of gast.Load; the
        # previous code passed the gast.Load class itself for vars/globals.
        return gast.Name(identifier, gast.Load(), None)

    def matchbox_attr(attribute):
        # NOTE(review): ``attr`` is kept as a Name node, exactly as before;
        # the standard AST shape uses a plain string here - confirm what the
        # downstream code generator expects before changing it.
        return gast.Attribute(load('matchbox'), load(attribute), gast.Load())

    def bound_in(scope_builtin):
        # "<lhs>" in vars()   /   "<lhs>" in globals()
        return gast.Compare(
            gast.Str(lhs_id), [gast.In()],
            [gast.Call(load(scope_builtin), [], [])])

    already_bound = gast.BoolOp(
        gast.Or(), [bound_in('vars'), bound_in('globals')])

    # The target node carries a Store context, so a fresh Load-context Name
    # is used as the isinstance() argument instead of aliasing the target.
    is_maskable = gast.Call(
        load('isinstance'),
        [
            load(lhs_id),
            gast.Tuple(
                [matchbox_attr('MaskedBatch'), matchbox_attr('TENSOR_TYPE')],
                gast.Load()),
        ],
        [])

    masked_update = gast.Call(
        gast.Attribute(load(lhs_id), load('_update'), gast.Load()),
        [rhs, matchbox_attr('EXECUTION_MASK')],
        [])

    # ``rhs`` deliberately appears in both branches, as in the original.
    node.value = gast.IfExp(
        gast.BoolOp(gast.And(), [already_bound, is_maskable]),
        masked_update,
        rhs)
    return node
def visit_For(self, node):
    """Visit a ``for`` loop and add the flag bookkeeping for its body.

    After the generic visit, the per-loop "continued" and "breaked" entries
    are popped from their stacks. For each entry that is set, a
    ``<flag><depth> = False`` reset is inserted at the top of the loop body.
    If a break flag and/or an enclosing function's return flag is active,
    the body is extended with ``<keepgoing> = not <cond>`` followed by
    ``if <cond>: break``.

    Returns whatever ``self.generic_visit(node)`` returned.
    """
    visited = self.generic_visit(node)

    def flag_ref(identifier, context):
        return gast.Name(id=identifier,
                         ctx=context,
                         annotation=None,
                         type_comment=None)

    def reset_stmt(identifier):
        return gast.Assign(targets=[flag_ref(identifier, gast.Store())],
                           value=gast.Constant(value=False, kind=None))

    # Depths are read before popping: the flag suffix is the nesting level.
    continue_depth = len(self.for_continued_stack)
    if self.for_continued_stack.pop():
        node.body.insert(
            0, reset_stmt(self.continued_flag + str(continue_depth)))

    break_depth = len(self.for_breaked_stack)
    exit_flags = []
    if self.for_breaked_stack.pop():
        break_flag_id = self.breaked_flag + str(break_depth)
        node.body.insert(0, reset_stmt(break_flag_id))
        exit_flags.append(flag_ref(break_flag_id, gast.Load()))

    if self.func_returned_stack and self.func_returned_stack[-1]:
        return_depth = len(self.func_returned_stack)
        exit_flags.append(
            flag_ref(self.returned_flag + str(return_depth), gast.Load()))

    if not exit_flags:
        return visited

    if len(exit_flags) == 1:
        exit_cond = exit_flags[0]
    else:
        exit_cond = gast.BoolOp(op=gast.Or(), values=exit_flags)

    keepgoing_target = flag_ref(self.keepgoing_flag, gast.Store())
    node.body.append(
        gast.Assign(targets=[keepgoing_target],
                    value=gast.UnaryOp(op=gast.Not(), operand=exit_cond)))
    node.body.append(
        gast.If(test=exit_cond, body=[gast.Break()], orelse=[]))
    return visited
def generic_visit(self, node):
    """Visit ``node``; wrap statements in an ``if`` while any flag is live.

    Non-statement nodes are delegated straight to the parent class. For a
    statement, one ``not <flag><depth>`` guard is collected for each of the
    continued / breaked / returned stacks for which ``stack_has_flags`` is
    true. A ``For`` pushes a fresh ``False`` on both loop stacks and a
    ``FunctionDef`` pushes one on the return stack before the children are
    visited. If any guard was collected, the visited statement is returned
    wrapped in ``if <guards>: <stmt>``.
    """
    if not isinstance(node, gast.stmt):
        return super().generic_visit(node)

    def flag_is_clear(prefix, stack):
        flag = gast.Name(id=prefix + str(len(stack)),
                         ctx=gast.Load(),
                         annotation=None,
                         type_comment=None)
        return gast.UnaryOp(op=gast.Not(), operand=flag)

    # Guards are built before pushing, so the depth suffix refers to the
    # enclosing construct rather than to ``node`` itself.
    guards = []
    if self.stack_has_flags(self.for_continued_stack):
        guards.append(
            flag_is_clear(self.continued_flag, self.for_continued_stack))
    if self.stack_has_flags(self.for_breaked_stack):
        guards.append(
            flag_is_clear(self.breaked_flag, self.for_breaked_stack))
    if self.stack_has_flags(self.func_returned_stack):
        guards.append(
            flag_is_clear(self.returned_flag, self.func_returned_stack))

    if isinstance(node, gast.For):
        self.for_continued_stack.append(False)
        self.for_breaked_stack.append(False)
    elif isinstance(node, gast.FunctionDef):
        self.func_returned_stack.append(False)

    visited = super().generic_visit(node)
    if not guards:
        return visited

    if len(guards) == 1:
        test = guards[0]
    else:
        test = gast.BoolOp(op=gast.And(), values=guards)
    wrapper = gast.If(test=test, body=[visited], orelse=[])
    return gast.copy_location(wrapper, node)
def visit_Return(self, node):
    """Replace a ``return`` with flag-based control flow in its ancestors.

    A fresh return-flag name is generated and recorded for the current
    function. Walking the ancestor chain from the innermost ancestor
    outwards:

    * in whichever of ``body`` / ``orelse`` holds the current child, the
      return statement itself is replaced (only at the level where the child
      is the return node) and the statements after the child are rewritten
      through ``_replace_after_node_to_if_in_stmt_list``;
    * a ``while`` ancestor gets ``and not <flag>`` appended to its test;
    * a ``for`` ancestor is converted via ``ForToWhileTransformer`` using the
      same ``not <flag>`` condition, and the ancestor entry is replaced by
      the last statement that transformer returns;
    * the walk stops once the enclosing function definition is processed.

    Nothing is returned explicitly.
    """
    func_node = self.function_def[-1]
    flag_name = unique_name.generate(RETURN_PREFIX)
    self.return_name[func_node].append(flag_name)
    max_return_len = self.pre_analysis.get_func_max_return_length(func_node)
    return_parent = self.ancestor_nodes[-2]

    def not_returned():
        # A fresh ``not <flag>`` expression for every loop it is attached to.
        return gast.UnaryOp(op=gast.Not(),
                            operand=gast.Name(id=flag_name,
                                              ctx=gast.Load(),
                                              annotation=None,
                                              type_comment=None))

    for level in reversed(range(len(self.ancestor_nodes) - 1)):
        ancestor = self.ancestor_nodes[level]
        child = self.ancestor_nodes[level + 1]

        # Find which statement list of the ancestor holds the child.
        holder = None
        if hasattr(ancestor, "body") and index_in_list(ancestor.body,
                                                       child) != -1:
            holder = "body"
        elif hasattr(ancestor, "orelse") and index_in_list(
                ancestor.orelse, child) != -1:
            holder = "orelse"

        if holder is not None:
            # The attribute is re-read for each call on purpose, exactly as
            # the original did.
            if child == node:
                self._replace_return_in_stmt_list(
                    getattr(ancestor, holder), child, flag_name,
                    max_return_len, return_parent)
            self._replace_after_node_to_if_in_stmt_list(
                getattr(ancestor, holder), child, flag_name, return_parent)

        # Return inside a while loop: add `not <flag>` to the loop test.
        if isinstance(ancestor, gast.While):
            ancestor.test = gast.BoolOp(
                op=gast.And(), values=[ancestor.test, not_returned()])
            continue

        # Return inside a for loop: convert it to a while loop that also
        # checks `not <flag>`.
        if isinstance(ancestor, gast.For):
            loop_parent = self.ancestor_nodes[level - 1]
            converter = ForToWhileTransformer(loop_parent, ancestor,
                                              not_returned())
            converted_stmts = converter.transform()
            self.ancestor_nodes[level] = converted_stmts[-1]

        if ancestor == func_node:
            break
def visit_BoolOp(self, node):
    """Rewrite a boolean operation using De Morgan's law.

    After visiting the children, ``a <op> b ...`` becomes
    ``not (not a <flipped-op> not b ...)``, where the flipped operator class
    is looked up in ``BOOL_INVERSIONS`` by the type of ``node.op``.
    """
    self.generic_visit(node)
    dual_op_cls = BOOL_INVERSIONS[type(node.op)]

    negated_operands = []
    for operand in node.values:
        negated_operands.append(
            gast.UnaryOp(op=gast.Not(), operand=operand))

    inner = gast.BoolOp(op=dual_op_cls(), values=negated_operands)
    return gast.UnaryOp(op=gast.Not(), operand=inner)
def visit_For(self, node):
    """Visit a ``for`` loop and emit resets/exits for its collected flags.

    The flag lists for this loop are popped from ``for_continue_stack`` and
    ``for_breaked_stack``. Every flag gets a ``<flag> = False`` statement
    inserted at the top of the loop body. If any break flags exist, the loop
    (found either as the visited node itself or as the first statement of a
    wrapping ``if``) is extended with ``<keepgoing> = not <cond>`` and
    ``if <cond>: break``.

    Returns whatever ``self.generic_visit(node)`` returned.
    """
    visited = self.generic_visit(node)

    def reset_stmt(flag):
        return gast.Assign(
            targets=[gast.Name(id=flag, ctx=gast.Store(), annotation=None)],
            value=gast.NameConstant(value=False))

    for flag in self.for_continue_stack.pop():
        node.body.insert(0, reset_stmt(flag))

    exit_flags = []
    for flag in self.for_breaked_stack.pop():
        node.body.insert(0, reset_stmt(flag))
        exit_flags.append(
            gast.Name(id=flag, ctx=gast.Load(), annotation=None))

    if not exit_flags:
        return visited

    if len(exit_flags) == 1:
        exit_cond = exit_flags[0]
    else:
        exit_cond = gast.BoolOp(op=gast.Or(), values=exit_flags)

    # The generic visit may have wrapped the loop in a guarding ``if``.
    loop = None
    if isinstance(visited, gast.For):
        loop = visited
    elif isinstance(visited, gast.If):
        if isinstance(visited.body[0], gast.For):
            loop = visited.body[0]

    if loop is not None:
        loop.body.append(
            gast.Assign(targets=[
                gast.Name(id=self.keepgoing_flag,
                          ctx=gast.Store(),
                          annotation=None)
            ],
                        value=gast.UnaryOp(op=gast.Not(),
                                           operand=exit_cond)))
        loop.body.append(
            gast.If(test=exit_cond, body=[gast.Break()], orelse=[]))
    return visited
def make_Iterator(self, gen):
    """Build the iterable expression for one comprehension generator.

    Without ``if`` clauses this is ``gen.iter`` unchanged. Otherwise it is a
    call ``MODULE.IFILTER(lambda <target>: <conditions>, gen.iter)``, where
    multiple conditions are joined with ``and``.
    """
    if not gen.ifs:
        return gen.iter

    lambda_params = ast.arguments(
        [ast.Name(gen.target.id, ast.Param(), None)],
        None, [], [], None, [])
    if len(gen.ifs) > 1:
        predicate = ast.BoolOp(ast.And(), gen.ifs)
    else:
        predicate = gen.ifs[0]
    filter_lambda = ast.Lambda(lambda_params, predicate)

    module_ref = ast.Name(id=MODULE, ctx=ast.Load(), annotation=None)
    filter_func = ast.Attribute(value=module_ref,
                                attr=IFILTER,
                                ctx=ast.Load())
    return ast.Call(filter_func, [filter_lambda, gen.iter], [])
def get_for_stmt_nodes(self, node):
    """Return the statements that replace ``node`` with a ``while`` loop.

    ``ForNodeVisitor`` parses the loop into init statements, a condition and
    body statements. If it returns ``None`` the loop is left alone and
    ``[node]`` is returned. Otherwise the result is the init statements
    followed by a ``gast.While`` whose test is
    ``<parsed condition> and self.condition_node``, reusing the original
    ``orelse``.
    """
    assert isinstance(
        node, gast.For), "Input node is NOT gast.For in get_for_stmt_nodes"

    parsed = ForNodeVisitor(node).parse()
    if parsed is None:
        return [node]
    init_stmts, loop_cond, loop_body = parsed

    # Combine the parsed loop condition with the extra exit condition.
    guarded_test = gast.BoolOp(op=gast.And(),
                               values=[loop_cond, self.condition_node])
    init_stmts.append(
        gast.While(test=guarded_test, body=loop_body, orelse=node.orelse))
    return init_stmts
def get_for_args_stmts(self, iter_name, args_list):
    '''
    Build the three gast nodes that emulate ``range(*args_list)``:
        1. Initialization of the iteration variable
        2. Loop condition: ``iter + step <= stop and self.condition_node``
        3. Per-iteration update of the iteration variable (``iter += step``)
    '''
    arg_count = len(args_list)
    assert 1 <= arg_count <= 3, "range() function takes 1 to 3 arguments"

    def iter_ref(context):
        return gast.Name(id=iter_name,
                         ctx=context,
                         annotation=None,
                         type_comment=None)

    if arg_count == 1:
        # range(stop): start at 0.
        init_stmt = get_constant_variable_node(iter_name, 0)
        stop_node = args_list[0]
    else:
        init_stmt = gast.Assign(targets=[iter_ref(gast.Store())],
                                value=args_list[0])
        stop_node = args_list[1]

    if arg_count == 3:
        step_node = args_list[2]
    else:
        step_node = gast.Constant(value=1, kind=None)

    next_value = gast.BinOp(left=iter_ref(gast.Load()),
                            op=gast.Add(),
                            right=step_node)
    in_range = gast.Compare(left=next_value,
                            ops=[gast.LtE()],
                            comparators=[stop_node])
    cond_stmt = gast.BoolOp(op=gast.And(),
                            values=[in_range, self.condition_node])

    change_stmt = gast.AugAssign(target=iter_ref(gast.Store()),
                                 op=gast.Add(),
                                 value=step_node)
    return init_stmt, cond_stmt, change_stmt
def generic_visit(self, node):
    """Visit ``node``; guard statements with the enclosing loop's flags.

    Non-statement nodes go straight to the parent class. For a statement,
    the flag names on top of ``for_continue_stack`` and
    ``for_breaked_stack`` (in that order) each yield a ``not <flag>`` guard.
    A ``For`` node pushes a fresh empty flag list on both stacks before its
    children are visited. When at least one guard exists, the visited
    statement is returned wrapped in ``if <guards>: <stmt>``.
    """
    if not isinstance(node, gast.stmt):
        return super().generic_visit(node)

    # Collected before pushing, so these are the *enclosing* loop's flags.
    live_flags = []
    for stack in (self.for_continue_stack, self.for_breaked_stack):
        if stack and stack[-1]:
            live_flags.extend(stack[-1])
    guards = [
        gast.UnaryOp(op=gast.Not(),
                     operand=gast.Name(id=flag,
                                       ctx=gast.Load(),
                                       annotation=None))
        for flag in live_flags
    ]

    if isinstance(node, gast.For):
        self.for_continue_stack.append([])
        self.for_breaked_stack.append([])

    visited = super().generic_visit(node)
    if not guards:
        return visited

    if len(guards) == 1:
        test = guards[0]
    else:
        test = gast.BoolOp(op=gast.And(), values=guards)
    wrapper = gast.If(test=test, body=[visited], orelse=[])
    return gast.copy_location(wrapper, visited)
def visit_While(self, node): self.generic_visit(node.test) scope = anno.getanno(node, 'body_scope') break_var = self.namer.new_symbol('break_requested', scope.referenced) self.break_uses.append([False, break_var]) node.body = self._manual_visit_list(node.body) if self.break_uses[-1][0]: node.test = gast.BoolOp(gast.And(), [ node.test, gast.UnaryOp(gast.Not(), gast.Name(break_var, gast.Load(), None)) ]) final_nodes = [self._create_break_init(), node] else: final_nodes = node self.break_uses.pop() for n in node.orelse: self.generic_visit(n) return final_nodes
def visit_Break(self, node):
    """Replace a ``break`` statement with a boolean flag on its loop.

    Steps, in order:
      1. Generate a unique flag name for this ``break``.
      2. Via ``_remove_stmts_after_break_continue``, handle the statements
         following the ``break`` in its first enclosing block.
      3. Via ``_replace_if_stmt``, guard the statements in the ancestor
         blocks between that block and the loop.
      4. Insert a ``<flag> = False`` node before the loop and add
         ``not <flag>`` to the loop condition: directly on a ``while`` test,
         or through ``ForToWhileTransformer`` for a ``for`` loop.
    """
    loop_index = self._find_ancestor_loop_index(node)
    assert loop_index != -1, "SyntaxError: 'break' outside loop"
    enclosing_loop = self.ancestor_nodes[loop_index]

    flag_name = unique_name.generate(BREAK_NAME_PREFIX)

    first_block_index = self._remove_stmts_after_break_continue(
        node, flag_name, loop_index)
    self._replace_if_stmt(loop_index, first_block_index, flag_name)

    # The flag starts out False before the loop.
    flag_init = create_fill_constant_node(flag_name, False)
    self._add_stmt_before_cur_node(loop_index, flag_init)

    flag_ref = gast.Name(id=flag_name,
                         ctx=gast.Load(),
                         annotation=None,
                         type_comment=None)
    not_broken = gast.UnaryOp(op=gast.Not(), operand=flag_ref)

    if isinstance(enclosing_loop, gast.While):
        enclosing_loop.test = gast.BoolOp(
            op=gast.And(), values=[enclosing_loop.test, not_broken])
    elif isinstance(enclosing_loop, gast.For):
        loop_parent = self.ancestor_nodes[loop_index - 1]
        ForToWhileTransformer(loop_parent, enclosing_loop,
                              not_broken).transform()
def visit_Compare(self, node):
    """Rewrite a comparison as the negation of its inverted form.

    * ``a < b``      becomes ``not a >= b``
    * ``a < b < c``  becomes ``not (a >= b or b >= c)``

    Inverted operator classes are looked up in ``OP_INVERSIONS`` by the type
    of each original operator.
    """
    self.generic_visit(node)

    def inverted(op):
        return OP_INVERSIONS[type(op)]()

    operands = [node.left] + node.comparators
    if len(operands) == 2:
        # Single comparison: flip the operator, keep the operands.
        flipped = gast.Compare(left=node.left,
                               ops=[inverted(node.ops[0])],
                               comparators=node.comparators)
    else:
        # Chained comparison: one inverted clause per adjacent pair.
        clauses = [
            gast.Compare(left=lhs, ops=[inverted(op)], comparators=[rhs])
            for lhs, op, rhs in zip(operands, node.ops, operands[1:])
        ]
        flipped = gast.BoolOp(op=gast.Or(), values=clauses)
    return gast.UnaryOp(op=gast.Not(), operand=flipped)
def negate(node): if isinstance(node, ast.Name): # Not type info, could be anything :( raise UnsupportedExpression() if isinstance(node, ast.UnaryOp): # !~x <> ~x == 0 <> x == ~0 <> x == -1 if isinstance(node.op, ast.Invert): return ast.Compare(node.operand, [ast.Eq()], [ast.Constant(-1, None)]) # !!x <> x if isinstance(node.op, ast.Not): return node.operand # !+x <> +x == 0 <> x == 0 <> !x if isinstance(node.op, ast.UAdd): return node.operand # !-x <> -x == 0 <> x == 0 <> !x if isinstance(node.op, ast.USub): return node.operand if isinstance(node, ast.BoolOp): new_values = [ast.UnaryOp(ast.Not(), v) for v in node.values] # !(x or y) <> !x and !y if isinstance(node.op, ast.Or): return ast.BoolOp(ast.And(), new_values) # !(x and y) <> !x or !y if isinstance(node.op, ast.And): return ast.BoolOp(ast.Or(), new_values) if isinstance(node, ast.Compare): cmps = [ast.Compare(x, [negate(o)], [y]) for x, o, y in zip([node.left] + node.comparators[:-1], node.ops, node.comparators)] if len(cmps) == 1: return cmps[0] return ast.BoolOp(ast.Or(), cmps) if isinstance(node, ast.Eq): return ast.NotEq() if isinstance(node, ast.NotEq): return ast.Eq() if isinstance(node, ast.Gt): return ast.LtE() if isinstance(node, ast.GtE): return ast.Lt() if isinstance(node, ast.Lt): return ast.GtE() if isinstance(node, ast.LtE): return ast.Gt() if isinstance(node, ast.In): return ast.NotIn() if isinstance(node, ast.NotIn): return ast.In() if isinstance(node, ast.Attribute): if node.attr == 'False': return ast.Constant(True, None) if node.attr == 'True': return ast.Constant(False, None) raise UnsupportedExpression()
def build(left, right):
    """Combine two expressions into one ``gast.BoolOp``.

    The operator comes from ``op``, a name that is not defined in this
    function and must be supplied by the surrounding scope.
    """
    operands = [left, right]
    return gast.BoolOp(op=op, values=operands)