def visit_AugAssign(self, node):
    """Rewrite an augmented assignment for vectorization along ``self.loop_var``.

    The right-hand side is always visited first. Then:

    * If the target mentions ``self.loop_var``, the statement becomes a
      ``store_ps(dest, target <op> value)`` call. When the outermost
      subscript of the target does not mention the loop variable, the
      subscript that does is moved to the outermost position, the buffer
      is renamed with a ``_transposed`` suffix and the move is recorded
      in ``self.transposed_buffers``.
    * If the statement is ``target += <function call>``, it becomes
      ``target = _mm256_add_ps(value, target)``.
    * Any other array-subscript target is rejected.
    * Otherwise the target is visited and the node returned.

    Raises:
        NotImplementedError: if a buffer was already transposed at a
            different subscript depth, or if the target is an array
            subscript that matches none of the supported forms.
    """
    node.value = self.visit(node.value)

    if util.contains_symbol(node.target, self.loop_var):
        needs_transpose = not util.contains_symbol(node.target.right,
                                                   self.loop_var)
        if needs_transpose:
            # Visit a copy first: the in-place surgery below must not be
            # seen twice by the visitor.
            target = self.visit(deepcopy(node.target))

            # Walk inwards until the *next* subscript is the one indexed
            # by the loop variable; idx is its depth from the outside.
            curr_node = node.target
            idx = 1
            while curr_node.left.right.name != self.loop_var:
                curr_node = curr_node.left
                idx += 1
            # Splice that subscript out and re-apply it outermost.
            curr_node.left = curr_node.left.left
            node.target = C.ArrayRef(node.target, C.SymbolRef(self.loop_var))

            # Descend to the buffer symbol at the base of the chain.
            while not isinstance(curr_node, C.SymbolRef):
                curr_node = curr_node.left
            if curr_node.name in self.transposed_buffers and \
                    self.transposed_buffers[curr_node.name] != idx:
                raise NotImplementedError()
            self.transposed_buffers[curr_node.name] = idx
            curr_node.name += "_transposed"

        # A constant zero outermost index means the address is the inner
        # expression itself. The original read ``node.target.value``,
        # which does not exist on a subscript node; the index is
        # ``node.target.right``.
        index = node.target.right
        if isinstance(index, C.Constant) and index.value == 0.0:
            dest = node.target.left
        else:
            dest = C.Ref(node.target)

        if not needs_transpose:
            target = self.visit(node.target)
        return store_ps(dest, C.BinaryOp(target, node.op, node.value))

    elif isinstance(node.op, C.Op.Add) and isinstance(
            node.value, C.FunctionCall):
        # TODO: Verify it's a vector intrinsic
        return C.Assign(
            node.target,
            C.FunctionCall(C.SymbolRef("_mm256_add_ps"),
                           [node.value, node.target]))
    elif isinstance(node.target, C.BinaryOp) and isinstance(
            node.target.op, C.Op.ArrayRef):
        raise NotImplementedError(node)

    node.target = self.visit(node.target)
    return node
def visit_BinaryOp(self, node):
    """Rewrite array subscripts relative to ``self.loop_var``.

    For a subscript expression that mentions the loop variable, record in
    ``self.vectorized_buffers`` (keyed by the base symbol name) how many
    subscript levels lie outside the one indexed by exactly
    ``self.loop_var``. The result is then ``mm256_load_ps(node)`` when
    ``self.vectorize`` is set, otherwise the node with an extra
    ``_neuron_index_1_inner`` subscript.

    A subscript expression that does not mention the loop variable
    becomes ``mm256_set1_ps(node)`` when vectorizing and is otherwise
    returned untouched. Any other binary operation has both operands
    visited.
    """
    if not isinstance(node.op, C.Op.ArrayRef):
        node.left = self.visit(node.left)
        node.right = self.visit(node.right)
        return node

    if not util.contains_symbol(node, self.loop_var):
        return simd_macros.mm256_set1_ps(node) if self.vectorize else node

    # Peel subscripts from the outside until the index is exactly the
    # loop variable symbol.
    depth = 0
    cursor = node
    while not (isinstance(cursor.right, C.SymbolRef)
               and cursor.right.name == self.loop_var):
        depth += 1
        cursor = cursor.left

    # Continue down to the buffer symbol at the base of the chain.
    while not isinstance(cursor, C.SymbolRef):
        cursor = cursor.left
    self.vectorized_buffers[cursor.name] = depth

    if self.vectorize:
        return simd_macros.mm256_load_ps(node)
    return C.ArrayRef(node, C.SymbolRef("_neuron_index_1_inner"))
def visit_AugAssign(self, node):
    """Rewrite an augmented assignment into SIMD form when vectorizing.

    The value is always visited. With ``self.vectorize`` unset, only the
    target is visited and the node is returned. Otherwise, in order:

    * target mentions ``self.loop_var``: ``mm256_store_ps(target,
      visited_target <op> value)``;
    * ``target += a * b``: ``target = _mm256_fmadd_ps(a, b, target)``;
    * ``target += <function call>``: ``target = _mm256_add_ps(value,
      target)``;
    * any other array-subscript target: ``NotImplementedError``;
    * anything else: visit the target and return the node.
    """
    node.value = self.visit(node.value)

    if not self.vectorize:
        node.target = self.visit(node.target)
        return node

    if util.contains_symbol(node.target, self.loop_var):
        combined = C.BinaryOp(self.visit(node.target), node.op, node.value)
        return simd_macros.mm256_store_ps(node.target, combined)

    value = node.value
    if isinstance(node.op, C.Op.Add) and isinstance(value, C.BinaryOp) and \
            isinstance(value.op, C.Op.Mul):
        # NOTE: an earlier, disabled variant wrapped the value in a
        # "vsum" call when the target was not a plain symbol.
        fused = C.FunctionCall(C.SymbolRef("_mm256_fmadd_ps"),
                               [value.left, value.right, node.target])
        return C.Assign(node.target, fused)

    if isinstance(node.op, C.Op.Add) and isinstance(value, C.FunctionCall):
        # TODO: Verify it's a vector intrinsic
        summed = C.FunctionCall(C.SymbolRef("_mm256_add_ps"),
                                [value, node.target])
        return C.Assign(node.target, summed)

    if isinstance(node.target, C.BinaryOp) and isinstance(
            node.target.op, C.Op.ArrayRef):
        raise NotImplementedError()

    node.target = self.visit(node.target)
    return node
def visit_If(self, node):
    """Unroll an ``if`` statement that depends on an unrolled symbol.

    If the statement mentions ``self.target_var`` or any name in
    ``self.unrolled_vars``, return a list of ``self.factor`` deep copies.
    In copy ``i`` every unrolled variable ``v`` is renamed ``v_i`` and
    ``self.target_var`` is replaced by

    * ``target_var + i`` when ``self.unroll_type == 0``;
    * ``factor * target_var + i`` when ``self.unroll_type == 1``.

    Statements that mention none of those symbols are returned unchanged.

    Raises:
        AssertionError: if ``self.unroll_type`` is neither 0 nor 1.
    """
    tracked = list(self.unrolled_vars) + [self.target_var]
    if not any(util.contains_symbol(node, var) for var in tracked):
        return node

    def index_expr(i):
        # Expression substituted for the loop variable in copy ``i``.
        if self.unroll_type == 0:
            return C.Add(C.SymbolRef(self.target_var), C.Constant(i))
        if self.unroll_type == 1:
            return C.Add(
                C.Mul(C.Constant(self.factor),
                      C.SymbolRef(self.target_var)),
                C.Constant(i))
        # The original used ``assert (false)``: an undefined name, and
        # stripped under ``python -O``. Fail explicitly instead.
        raise AssertionError(
            "unsupported unroll_type: {!r}".format(self.unroll_type))

    body = []
    for i in range(self.factor):
        stmt = deepcopy(node)
        for var in self.unrolled_vars:
            stmt = util.replace_symbol(
                var, C.SymbolRef(var + "_" + str(i)), stmt)
        body.append(util.replace_symbol(self.target_var, index_expr(i), stmt))
    return body
def visit_BinaryOp(self, node):
    """Unroll an assignment whose right-hand side uses an unrolled symbol.

    Only ``Assign`` operations are considered; every other binary
    operation is returned unchanged. When the right-hand side mentions
    ``self.target_var`` or a name in ``self.unrolled_vars``:

    * if the left-hand side has a non-``None`` ``type`` attribute, its
      name is added to ``self.unrolled_vars`` so later uses are renamed
      as well;
    * a list of ``self.factor`` deep copies is returned, where copy ``i``
      renames every unrolled variable ``v`` to ``v_i`` and replaces
      ``self.target_var`` by ``target_var + i`` (``unroll_type == 0``)
      or ``factor * target_var + i`` (``unroll_type == 1``).

    Raises:
        AssertionError: if ``self.unroll_type`` is neither 0 nor 1.
    """
    if not isinstance(node.op, C.Op.Assign):
        return node

    tracked = list(self.unrolled_vars) + [self.target_var]
    if not any(util.contains_symbol(node.right, var) for var in tracked):
        return node

    if hasattr(node.left, 'type') and node.left.type is not None:
        self.unrolled_vars.add(node.left.name)

    def index_expr(i):
        # Expression substituted for the loop variable in copy ``i``.
        if self.unroll_type == 0:
            return C.Add(C.SymbolRef(self.target_var), C.Constant(i))
        if self.unroll_type == 1:
            return C.Add(
                C.Mul(C.Constant(self.factor),
                      C.SymbolRef(self.target_var)),
                C.Constant(i))
        # The original used ``assert (false)``: an undefined name, and
        # stripped under ``python -O``. Fail explicitly instead.
        raise AssertionError(
            "unsupported unroll_type: {!r}".format(self.unroll_type))

    body = []
    for i in range(self.factor):
        stmt = deepcopy(node)
        for var in self.unrolled_vars:
            stmt = util.replace_symbol(
                var, C.SymbolRef(var + "_" + str(i)), stmt)
        body.append(util.replace_symbol(self.target_var, index_expr(i), stmt))
    return body
def visit_BinaryOp(self, node):
    """Reorder adjacent array subscripts by position in ``self.loop_vars``.

    Both operands are visited first. For a subscript whose base is itself
    a binary operation, each of the two adjacent indices is ranked by the
    position of the first entry of ``self.loop_vars`` it mentions. If the
    outer index ranks strictly lower than the inner one, the two indices
    are swapped and the inner expression is visited again so the swap can
    propagate further inwards.

    An index that mentions none of the loop variables ranks as the last
    position; this mirrors the original loop falling off its end.
    With an empty ``self.loop_vars`` the node is returned unchanged; the
    original raised ``UnboundLocalError`` there.
    """
    node.left = self.visit(node.left)
    node.right = self.visit(node.right)

    if not (isinstance(node.op, C.Op.ArrayRef)
            and isinstance(node.left, C.BinaryOp)):
        return node

    def rank(expr):
        # Position of the first loop variable mentioned by ``expr``; the
        # last position if none matches; None when there are no loop
        # variables at all.
        position = None
        for position, var in enumerate(self.loop_vars):
            if util.contains_symbol(expr, var):
                break
        return position

    curr_index = rank(node.right)
    left_index = rank(node.left.right)
    if curr_index is None or left_index is None:
        return node

    if curr_index < left_index:
        node.left.right, node.right = node.right, node.left.right
        node.left = self.visit(node.left)
    return node
def visit(self, node):
    """Visit ``node`` and move ``_mm256_broadcast_ss`` assignments in its body.

    After the superclass visit, a node with a ``body`` attribute has that
    body rebuilt back to front. Ordinary statements keep their relative
    order. An assignment whose right-hand side is a call to
    ``_mm256_broadcast_ss`` is placed directly before the first later
    statement that mentions its left-hand symbol; if no later statement
    mentions it, the assignment is not re-inserted.
    """
    node = super().visit(node)
    if not hasattr(node, 'body'):
        return node

    rebuilt = []
    for stmt in reversed(node.body):
        is_broadcast = (
            isinstance(stmt, C.BinaryOp)
            and isinstance(stmt.op, C.Op.Assign)
            and isinstance(stmt.right, C.FunctionCall)
            and stmt.right.func.name == "_mm256_broadcast_ss"
        )
        if not is_broadcast:
            rebuilt.insert(0, stmt)
            continue

        # ``rebuilt`` currently holds the statements that follow ``stmt``.
        value = stmt.left.name
        for position, later in enumerate(rebuilt):
            if util.contains_symbol(later, value):
                rebuilt.insert(position, stmt)
                break

    node.body = rebuilt
    return node
def visit_For(self, node):
    """Visit a loop body and regroup it: declarations, loads, the rest.

    Each statement is visited, then the body is partitioned while
    preserving relative order within each group:

    1. typed-symbol assignments (left side is a ``SymbolRef`` whose
       ``type`` is not ``None``) that do not mention ``_mm256_load_ps``;
    2. statements that mention the symbol ``_mm256_load_ps``;
    3. everything else, including every statement that itself has a
       ``body`` attribute.
    """
    node.body = [self.visit(child) for child in node.body]

    declarations, loads, remainder = [], [], []
    for stmt in node.body:
        if hasattr(stmt, 'body'):
            # Nested blocks are never reordered ahead of other code.
            bucket = remainder
        elif util.contains_symbol(stmt, "_mm256_load_ps"):
            bucket = loads
        elif (isinstance(stmt, C.BinaryOp)
              and isinstance(stmt.op, C.Op.Assign)
              and isinstance(stmt.left, C.SymbolRef)
              and stmt.left.type is not None):
            bucket = declarations
        else:
            bucket = remainder
        bucket.append(stmt)

    node.body = declarations + loads + remainder
    return node
def visit_For(self, node):
    """Lift loop-invariant statements out of a ``for`` loop.

    The body is visited and flattened first. The loop whose induction
    variable is ``_neuron_index_0`` is returned as is. For any other
    loop, a statement is moved out when it mentions neither the loop
    variable nor any typed symbol declared by a statement that stayed
    inside the loop, and it is one of:

    * a call whose function name contains both ``_mm`` and ``_store``
      (moved after the loop);
    * an assignment from a call whose function name contains ``_load``
      (moved before the loop);
    * an assignment declaring a typed symbol (moved before the loop).

    Returns:
        The node itself for the ``_neuron_index_0`` loop; otherwise the
        list ``pre_stmts + [node] + post_stmts``.
    """
    node.body = util.flatten([self.visit(child) for child in node.body])

    loop_var = node.init.left.name
    if loop_var == "_neuron_index_0":
        # Don't lift out of outer most loop
        return node

    # Typed symbols declared by statements that remain inside the loop.
    deps = set()

    def is_assign(stmt):
        return isinstance(stmt, C.BinaryOp) and \
            isinstance(stmt.op, C.Op.Assign)

    def declares_typed_symbol(stmt):
        return is_assign(stmt) and isinstance(stmt.left, C.SymbolRef) and \
            stmt.left.type is not None

    def is_store_call(stmt):
        return isinstance(stmt, C.FunctionCall) and \
            "_mm" in stmt.func.name and "_store" in stmt.func.name

    def is_load_assign(stmt):
        return is_assign(stmt) and \
            isinstance(stmt.right, C.FunctionCall) and \
            "_load" in stmt.right.func.name

    def is_invariant(stmt):
        if util.contains_symbol(stmt, loop_var):
            return False
        return not any(util.contains_symbol(stmt, dep) for dep in deps)

    before, kept, after = [], [], []
    for stmt in node.body:
        if is_store_call(stmt) and is_invariant(stmt):
            after.append(stmt)
        elif is_load_assign(stmt) and is_invariant(stmt):
            before.append(stmt)
        elif declares_typed_symbol(stmt) and is_invariant(stmt):
            before.append(stmt)
        else:
            kept.append(stmt)
            if declares_typed_symbol(stmt):
                deps.add(stmt.left.name)

    node.body = kept
    return before + [node] + after
def visit_BinaryOp(self, node):
    """Turn array reads and writes into SIMD load/broadcast/store calls.

    * Array subscript mentioning ``self.loop_var``: becomes
      ``load_ps(...)``. If the outermost index does not mention the loop
      variable, the subscript indexed by it is first moved to the
      outermost position (see ``move_loop_var_outermost``).
    * Array subscript not mentioning the loop variable: becomes
      ``broadcast_ss(&node)``.
    * Assignment: the right side is visited. If the result is a call
      whose name contains ``load_ps`` or ``broadcast_ss`` and the left
      side is a typed symbol, the symbol's type is replaced by
      ``get_simd_type()()`` and recorded in ``self.symbol_table``. If the
      left side is a binary operation mentioning the loop variable, the
      assignment becomes ``store_epi32`` when ``self.get_type`` yields a
      ``ctypes.c_int`` instance and ``store_ps`` otherwise. Any other
      assignment has its left side visited.
    * Any other binary operation: both operands are visited.

    Raises:
        NotImplementedError: if a buffer was already transposed at a
            different subscript depth.
    """

    def move_loop_var_outermost(ref):
        # Splice the subscript indexed by the loop variable out of
        # ``ref`` and re-apply it outermost; rename and record the buffer.
        cursor = ref
        idx = 1
        while cursor.left.right.name != self.loop_var:
            cursor = cursor.left
            idx += 1
        cursor.left = cursor.left.left
        moved = C.ArrayRef(ref, C.SymbolRef(self.loop_var))

        # Descend to the buffer symbol at the base of the chain.
        while not isinstance(cursor, C.SymbolRef):
            cursor = cursor.left
        if cursor.name in self.transposed_buffers and \
                self.transposed_buffers[cursor.name] != idx:
            raise NotImplementedError()
        self.transposed_buffers[cursor.name] = idx
        cursor.name += "_transposed"
        return moved

    def is_zero_constant(index):
        # The original read ``node.target.value`` here, but BinaryOp
        # nodes in this method only expose ``left``/``right``.
        return isinstance(index, C.Constant) and index.value == 0.0

    if isinstance(node.op, C.Op.ArrayRef):
        if not util.contains_symbol(node, self.loop_var):
            return broadcast_ss(C.Ref(node))
        if not util.contains_symbol(node.right, self.loop_var):
            node = move_loop_var_outermost(node)
        if is_zero_constant(node.right):
            return load_ps(node.left)
        return load_ps(C.Ref(node))

    if isinstance(node.op, C.Op.Assign):
        node.right = self.visit(node.right)

        if isinstance(node.right, C.FunctionCall) and \
                ("load_ps" in node.right.func.name or
                 "broadcast_ss" in node.right.func.name) and \
                isinstance(node.left, C.SymbolRef) and \
                node.left.type is not None:
            node.left.type = get_simd_type()()
            self.symbol_table[node.left.name] = node.left.type
            return node

        if isinstance(node.left, C.BinaryOp) and util.contains_symbol(
                node.left, self.loop_var):
            if node.left.right.name != self.loop_var:
                # The original started the walk at the assignment node
                # and wrapped the whole assignment in a subscript.
                # Transpose the store target instead, matching the
                # augmented-assignment handling.
                node.left = move_loop_var_outermost(node.left)

            elem_type = self.get_type(node.left)
            # c_float and unrecognised types both use the float store.
            if isinstance(elem_type, ctypes.c_int):
                store = store_epi32
            else:
                store = store_ps

            if is_zero_constant(node.left.right):
                return store(node.left.left, node.right)
            return store(C.Ref(node.left), node.right)

        node.left = self.visit(node.left)
        return node

    node.left = self.visit(node.left)
    node.right = self.visit(node.right)
    return node
def visit_AugAssign(self, node):
    """Unroll an augmented assignment whose value uses an unrolled symbol.

    When the value mentions ``self.target_var`` or a name in
    ``self.unrolled_vars``:

    * a plain-symbol target has ``self._get_name(target.name)`` added to
      ``self.unrolled_vars``, and a list of ``self.factor`` deep copies
      is returned. Copy ``i`` renames every unrolled variable ``v`` to
      ``v_i`` and replaces ``self.target_var`` by ``target_var + i``
      (``unroll_type == 0``) or ``factor * target_var + i``
      (``unroll_type == 1``);
    * an array-subscript target is rejected. The original guarded this
      branch with ``assert False`` ahead of a copy of the unrolling loop;
      that loop only ran under ``python -O`` and has been removed;
    * any other target raises ``NotImplementedError``.

    A statement whose value mentions none of those symbols is returned
    unchanged.

    Raises:
        AssertionError: for an array-subscript target, or when
            ``self.unroll_type`` is neither 0 nor 1.
        NotImplementedError: for any other unsupported target.
    """
    tracked = list(self.unrolled_vars) + [self.target_var]
    if not any(util.contains_symbol(node.value, var) for var in tracked):
        return node

    if isinstance(node.target, C.BinaryOp) and isinstance(
            node.target.op, C.Op.ArrayRef):
        # Explicit raise so the rejection survives ``python -O``.
        raise AssertionError(
            "unrolling an augmented assignment to an array element "
            "is not supported")
    if not isinstance(node.target, C.SymbolRef):
        raise NotImplementedError()

    self.unrolled_vars.add(self._get_name(node.target.name))

    def index_expr(i):
        # Expression substituted for the loop variable in copy ``i``.
        if self.unroll_type == 0:
            return C.Add(C.SymbolRef(self.target_var), C.Constant(i))
        if self.unroll_type == 1:
            return C.Add(
                C.Mul(C.Constant(self.factor),
                      C.SymbolRef(self.target_var)),
                C.Constant(i))
        # The original used ``assert (false)``: an undefined name, and
        # stripped under ``python -O``. Fail explicitly instead.
        raise AssertionError(
            "unsupported unroll_type: {!r}".format(self.unroll_type))

    body = []
    for i in range(self.factor):
        stmt = deepcopy(node)
        for var in self.unrolled_vars:
            stmt = util.replace_symbol(
                var, C.SymbolRef(var + "_" + str(i)), stmt)
        body.append(util.replace_symbol(self.target_var, index_expr(i), stmt))
    return body