def make_il(self, il_code, symbol_table, c):
    """Emit IL for an increment/decrement operator applied to self.expr.

    The operand must be a modifiable lvalue whose type is arithmetic or a
    pointer to a complete type. The step is 1 for arithmetic types and the
    pointee size for pointers. When ``self.return_new`` is true the updated
    value is returned; otherwise a copy of the value taken before the update
    is returned.

    Raises CompilerError for a non-modifiable operand, a pointer to an
    incomplete type, or any other operand type.
    """
    target = self.expr.lvalue(il_code, symbol_table, c)
    if not (target and target.modable()):
        raise CompilerError(
            f"operand of {self.descr} operator not a modifiable lvalue",
            self.expr.r)

    current = self.expr.make_il(il_code, symbol_table, c)
    ctype = current.ctype

    # Amount added/subtracted by self.cmd.
    step = ILValue(ctype)
    if ctype.is_arith():
        il_code.register_literal_var(step, 1)
    elif not ctype.is_pointer():
        raise CompilerError(f"invalid type for {self.descr} operator",
                            self.expr.r)
    elif not ctype.arg.is_complete():
        raise CompilerError("invalid arithmetic on pointer to incomplete type",
                            self.expr.r)
    else:
        il_code.register_literal_var(step, ctype.arg.size)

    updated = ILValue(ctype)
    if self.return_new:
        il_code.add(self.cmd(updated, current, step))
        target.set_to(updated, il_code, self.expr.r)
        return updated

    # Keep a copy of the pre-update value to return it.
    saved = ILValue(ctype)
    il_code.add(value_cmds.Set(saved, current))
    il_code.add(self.cmd(updated, current, step))
    target.set_to(updated, il_code, self.expr.r)
    return saved
def nonarith(self, left, right, il_code): """ Check equality of non-arithmetic expressions """ # If either operand is a null pointer constant, cast it to the other's pointer type. if left.ctype.is_pointer() and right.null_ptr_const: right = set_type(right, left.ctype, il_code) elif right.ctype.is_pointer() and left.null_ptr_const: left = set_type(left, right.ctype, il_code) # If both operands are not pointer types, quit now if not left.ctype.is_pointer() or not right.ctype.is_pointer(): with report_err(): err = "comparison between incomparable types" raise CompilerError(err, self.op.r) # If one side is pointer to void, cast the other to same. elif left.ctype.arg.is_void(): check_cast(right, left.ctype, self.op.r) right = set_type(right, left.ctype, il_code) elif right.ctype.arg.is_void(): check_cast(left, right.ctype, self.op.r) left = set_type(left, right.ctype, il_code) # If both types are still incompatible, warn! elif not left.ctype.compatible(right.ctype): with report_err(): err = "comparison between distinct pointer types" raise CompilerError(err, self.op.r) # Now, we can do comparison out = ILValue(ctypes.integer) il_code.add(self.eq_il_cmd(out, left, right)) return out
def make_il(self, il_code, symbol_table, c):
    """Emit short-circuit IL for a binary boolean operator.

    The result starts as ``self.initial_value``. Each operand is tested with
    ``self.jump_cmd`` (presumably a conditional jump); if the jump is taken,
    control skips any remaining operand and the result is overwritten with
    ``1 - self.initial_value``. Both operands must be scalar, otherwise
    CompilerError is raised.
    """
    scalar_err = f"'{str(self.op)}' operator requires scalar operands"

    result = ILValue(ctypes.integer)
    start_value = ILValue(ctypes.integer)
    il_code.register_literal_var(start_value, self.initial_value)
    flipped_value = ILValue(ctypes.integer)
    il_code.register_literal_var(flipped_value, 1 - self.initial_value)

    # flip_label precedes the store of the flipped value; done_label skips it.
    flip_label = il_code.get_label()
    done_label = il_code.get_label()

    lhs = self.left.make_il(il_code, symbol_table, c)
    if not lhs.ctype.is_scalar():
        raise CompilerError(scalar_err, self.left.r)
    il_code.add(value_cmds.Set(result, start_value))
    il_code.add(self.jump_cmd(lhs, flip_label))

    rhs = self.right.make_il(il_code, symbol_table, c)
    if not rhs.ctype.is_scalar():
        raise CompilerError(scalar_err, self.right.r)
    il_code.add(self.jump_cmd(rhs, flip_label))

    # Neither operand triggered the jump: keep the initial value.
    il_code.add(control_cmds.Jump(done_label))
    il_code.add(control_cmds.Label(flip_label))
    il_code.add(value_cmds.Set(result, flipped_value))
    il_code.add(control_cmds.Label(done_label))
    return result
def visit_For(self, node: ast.For):
    """Lower a Python ``for`` statement into an ``ir.ForLoop``.

    ``for ... else`` is rejected. The target/iterable pairing is validated up
    front, but full lowering is deferred: it injects additional arithmetic,
    and not all variable types may be known at this point.

    Raises CompilerError if the loop has an else clause, if the target cannot
    be unpacked against the iterable, or if a name appears in both the target
    and the iterable sequences.
    """
    if node.orelse:
        raise CompilerError(
            "or else clause not supported for for statements")
    iter_node = self.visit(node.iter)
    target_node = self.visit(node.target)
    assert iter_node is not None
    assert target_node is not None
    pos = extract_positional_info(node)
    targets = set()
    iterables = set()
    # Do initial checks for weird issues that may arise here.
    # We don't lower it fully at this point, because it injects
    # additional arithmetic and not all variable types may be fully known
    # at this point.
    try:
        for target, iterable in unpack_iterated(
                target_node, iter_node, include_enumerate_indices=True):
            targets.add(target)
            iterables.add(iterable)
    except ValueError as e:
        # Generator will throw an error on bad unpacking
        msg = f"Cannot safely unpack for loop expression, line: {pos.line_begin}"
        raise CompilerError(msg) from e
    conflicts = targets.intersection(iterables)
    if conflicts:
        # str() so join() works whether the entries are plain strings or IR
        # nodes; sorted() so the message does not depend on set ordering.
        conflict_names = ", ".join(sorted(str(name) for name in conflicts))
        msg = f"{conflict_names} appear in both the target and iterable sequences of a for loop, " \
              f"line {pos.line_begin}. This is not supported."
        raise CompilerError(msg)
    # Statements visited inside loop_region are collected into the loop's own
    # body (self.body while the region is active); the finished loop is then
    # appended to the enclosing body.
    with self.loop_region(node):
        for stmt in node.body:
            self.visit(stmt)
        loop = ir.ForLoop(target_node, iter_node, self.body, pos)
    self.body.append(loop)
def array_arg_from_spec(ndims, dtype, fixed_dims=(), evol=None):
    """Build a parameterized array type suitable for use as an argument.

    ndims - number of array dimensions.
    dtype - numpy element type; converted with scalar_type_from_numpy_type.
    fixed_dims - iterable of (index, value) pairs fixing the extent of
        individual dimensions (a dense map as a tuple of key/value pairs).
        Each index must be an integer in ``range(ndims)`` and may appear at
        most once.
    evol - None, sliding window, or iterated (advance the iterator by one
        each time); any subscript applied to a sliding window is folded into
        the variable's evolution.

    Raises CompilerError for a non-integral, negative, out-of-range, or
    duplicated dimension index.
    """
    dtype = scalar_type_from_numpy_type(dtype)
    # Materialize once so a one-shot iterator is both validated and stored.
    fixed_dims = tuple(fixed_dims)
    seen = set()
    for index, _value in fixed_dims:
        # Check the type first: an unhashable index would otherwise fail in
        # the set membership test with a TypeError.
        if not isinstance(index, numbers.Integral):
            msg = f"dims can only be used to specify fixed dimensions, received: {index}."
            raise CompilerError(msg)
        elif index < 0:
            msg = f"Negative dim {index} specified"
            raise CompilerError(msg)
        elif index >= ndims:
            msg = f"dim {index} specified for array with {ndims} dimensions."
            raise CompilerError(msg)
        if index in seen:
            msg = f"index {index} is duplicated."
            raise CompilerError(msg)
        seen.add(index)
    return ir.ArrayArg(ndims, dtype, fixed_dims, evol)
def parse_struct_spec(self, node, redec):
    """Parse struct ctype from the given decl_nodes.Struct node.

    node (decl_nodes.Struct) - the Struct node to parse
    redec (bool) - Whether this declaration is alone like so:

       struct S;

    or declares variables/has storage specifiers:

       struct S *p;
       extern struct S;

    If it's the first, then this is always a forward declaration for a
    new `struct S` but if it's the second and a `struct S` already
    exists in higher scope, it's just using the higher scope struct.

    Returns the StructCType; it is filled in with members only when the
    declaration has a member list. Raises CompilerError if the tag names a
    different kind of tag, or if a member list redefines a complete struct.
    Errors in individual member declarations are raised inside report_err().
    """
    # node.members is None when the declaration has no member list.
    has_members = node.members is not None

    ctype_req = StructCType
    if node.tag:
        tag = str(node.tag)
        ctype = self.symbol_table.lookup_struct(tag)

        # The tag exists but is not a struct type.
        if ctype and not isinstance(ctype, ctype_req):
            err = f"defined as wrong kind of tag '{node.kind} {tag}'"
            raise CompilerError(err, node.r)

        # Register a new struct type unless this declaration merely uses an
        # existing struct (no members, not a standalone redeclaration).
        if not ctype or has_members or redec:
            ctype = self.symbol_table.add_struct(tag, ctype_req(tag))

        # NOTE(review): add_struct can evidently return an already-complete
        # type (presumably the one already defined in this scope) — confirm.
        if has_members and ctype.is_complete():
            err = f"redefinition of '{node.kind} {tag}'"
            raise CompilerError(err, node.r)

    else:
        # Anonymous struct: never registered in the symbol table.
        ctype = ctype_req(None)

    if not has_members:
        return ctype

    # Struct does have members
    members = []
    members_set = set()
    for member in node.members:
        decl_infos = []  # needed in case get_decl_infos below fails
        with report_err():
            decl_infos = self.get_decl_infos(member)

        for decl_info in decl_infos:
            # A member whose check raises is skipped: the lines after the
            # check are in the same report_err() block.
            with report_err():
                self.check_struct_member_decl_info(
                    decl_info, node.kind, members_set)

                name = decl_info.identifier.content
                members_set.add(name)
                members.append((name, decl_info.ctype))

    ctype.set_members(members)
    return ctype
def replace_len(node: ir.Call):
    """Lower a call to the builtin ``len`` into an ``ir.Length`` node.

    Exactly one positional argument and no keyword arguments are accepted;
    anything else raises CompilerError.
    """
    if node.keywords:
        raise CompilerError("'len' does not allow keyword arguments.")
    arg_count = len(node.args)
    if arg_count != 1:
        raise CompilerError(f"'len' accepts a single argument. {arg_count} provided")
    (sequence,) = node.args
    return ir.Length(sequence)
def check_cast(il_value, ctype, span): """Emit warnings/errors of casting il_value to given ctype. This method does not actually cast the values. If values cannot be cast, an error is raised by this method. il_value - ILValue to convert. ctype - CType to convert to. span - Range for error reporting. """ # Cast between compatible types is always okay if il_value.ctype.weak_compat(ctype): return # Cast between arithmetic types is always okay if ctype.is_arith() and il_value.ctype.is_arith(): return # Cast between weak compatible structs is okay if ctype.is_struct() and il_value.ctype.is_struct( ) and il_value.ctype.weak_compat(ctype): return elif ctype.is_pointer() and il_value.ctype.is_pointer(): # both operands are pointers to qualified or unqualified versions of compatible types, and the type pointed to # by the left has all the qualifiers of the type pointed to by the right if ctype.arg.weak_compat( il_value.ctype.arg) and (not il_value.ctype.arg.const or ctype.arg.const): return # Cast between void pointer and pointer to object type okay elif (ctype.arg.is_void() and il_value.ctype.arg.is_object() and (not il_value.ctype.arg.const or ctype.arg.const)): return elif (ctype.arg.is_object() and il_value.ctype.arg.is_void() and (not il_value.ctype.arg.const or ctype.arg.const)): return # error on any other kind of pointer cast else: with report_err(): err = "conversion from incompatible pointer type" raise CompilerError(err, span) return # Cast from null pointer constant to pointer okay elif ctype.is_pointer() and il_value.null_ptr_const: return # Cast from pointer to boolean okay elif ctype.is_bool() and il_value.ctype.is_pointer(): return else: err = "invalid conversion between types" raise CompilerError(err, span)
def process_typedef(self, symbol_table):
    """Register this declaration as a typedef in ``symbol_table``.

    A typedef may have neither an initializer nor a function body; either
    one raises CompilerError at ``self.span``.
    """
    problem = None
    if self.init:
        problem = "typedef cannot have initializer"
    elif self.body:
        problem = "function definition cannot be a typedef"
    if problem is not None:
        raise CompilerError(problem, self.span)
    symbol_table.add_typedef(self.identifier, self.ctype)
def nonarith(self, left, right, il_code):
    """Emit IL for a relational comparison of non-arithmetic operands.

    Both operands must be pointers of compatible type; otherwise a
    CompilerError is raised at the operator's range. Returns an integer
    ILValue holding the result of ``self.comp_cmd``.
    """
    both_pointers = left.ctype.is_pointer() and right.ctype.is_pointer()
    if not both_pointers:
        raise CompilerError("comparison between incomparable types", self.op.r)
    if not left.ctype.compatible(right.ctype):
        raise CompilerError("comparison between distinct pointer types", self.op.r)

    result = ILValue(ctypes.integer)
    il_code.add(self.comp_cmd(result, left, right))
    return result
def get_ctype_tag(type_):
    """Return the short mangling tag for ``type_``.

    Arrays (``ir.ArrayType`` / ``ir.ArrayArg``) are tagged
    ``a<ndims><element tag>``; scalars use their ``scalar_type_mangling``
    entry directly. Raises CompilerError when the element or scalar type has
    no entry in ``scalar_type_mangling``.
    """
    # TODO: add something to distinguish uniform, by dim, sliding window
    if not isinstance(type_, (ir.ArrayType, ir.ArrayArg)):
        scalar_tag = scalar_type_mangling.get(type_)
        if scalar_tag is None:
            raise CompilerError(f"Scalar has unrecognized type: {type_}")
        return scalar_tag

    element_tag = scalar_type_mangling.get(type_.dtype)
    if element_tag is None:
        raise CompilerError(f"Array has unrecognized element type: {type_.dtype}")
    return f"a{type_.ndims}{element_tag}"
def replace_enumerate(node: ir.Call):
    """Lower a call to the builtin ``enumerate`` into an ``ir.Enumerate`` node.

    Accepts ``enumerate(iterable[, start])``, with ``iterable`` and ``start``
    given positionally or by keyword. ``node.keywords`` is a tuple of
    (name, value) pairs. A missing ``start`` defaults to ``ir.Zero``.

    Raises CompilerError for a wrong argument count, an unrecognized keyword,
    a keyword that shadows a positional argument, or a keyword combination
    that does not supply both 'iterable' and 'start'.
    """
    nargs = len(node.args) + len(node.keywords)
    if not (1 <= nargs <= 2):
        msg = f"'enumerate' accepts either 1 or 2 arguments, {nargs} provided."
        raise CompilerError(msg)
    if node.keywords:
        if node.args:
            # Exactly one positional and one keyword argument. keywords is a
            # tuple holding a single (name, value) pair, so unpack the pair.
            (key, value), = node.keywords
            if key != "start":
                if key == "iterable":
                    msg = "keyword 'iterable' in call to enumerate shadows a positional argument."
                else:
                    msg = f"Unrecognized keyword {key} in call to enumerate"
                raise CompilerError(msg)
            start = value
            iterable, = node.args
        elif len(node.keywords) == 1:
            (key, iterable), = node.keywords
            if key != "iterable":
                msg = "Missing 'iterable' argument in call to enumerate"
                raise CompilerError(msg)
            start = ir.Zero
        else:
            # Two keyword arguments. dictionaries are not hashable, thus we
            # have a tuple with explicit pairings; duplicates are caught by
            # the None check below.
            kwargs = node.keywords
            iterable = None
            start = None
            for key, value in kwargs:
                if key == "iterable":
                    iterable = value
                elif key == "start":
                    start = value
                else:
                    msg = f"Unrecognized keyword {key} in call to enumerate."
                    raise CompilerError(msg)
            if iterable is None or start is None:
                names = ", ".join(str(key) for key, _ in kwargs)
                msg = f"bad keyword combination {names} in call to enumerate, expected " \
                      f"'iterable' and 'start'"
                raise CompilerError(msg)
    elif nargs == 2:
        iterable, start = node.args
    else:
        iterable, = node.args
        start = ir.Zero
    return ir.Enumerate(iterable, start)
def compile_module(file_path, types, verbose=False, print_result=True, out=None):
    """Compile the module at ``file_path`` and generate code for it.

    file_path - path of the module source (str or Path).
    types - forwarded to build_module_ir_and_symbols.
    verbose - print the path being compiled.
    print_result - pretty print each function after loop lowering.
    out - directory passed to codegen; defaults to the current working
        directory.

    Raises CompilerError if no module name can be derived from file_path.
    """
    # Accept str as well as Path; .name below requires a Path.
    file_path = Path(file_path)
    if verbose:
        print(f"Compiling: {file_path}:")
    modname, _ = os.path.splitext(file_path.name)
    if not modname:
        msg = "No module specified"
        raise CompilerError(msg)
    mod_ir, symbols = build_module_ir_and_symbols(file_path, types)
    funcs = []
    norm_paths = NormalizePaths()
    for func in mod_ir.functions:
        s = symbols.get(func.name)
        ll = loop_lowering(s)
        func = norm_paths(func)
        func = ll(func)
        funcs.append(func)
        if print_result:
            # Imported lazily so pretty_printing is only needed when printing.
            from pretty_printing import pretty_printer
            pp = pretty_printer()
            pp(func, s)
    if out is None:
        # Default to the current working directory.
        out = Path.cwd()
    codegen(out, funcs, symbols, modname)
def do_body(self, il_code, symbol_table, c):
    """Emit IL for this function's body.

    The caller must check that this function has a body. Every parameter
    must be named, otherwise CompilerError is raised. ``main`` additionally
    has its type checked via ``check_main_type``.

    Parameters are declared as automatic variables in a new scope and loaded
    with LoadArg. If the body does not always return, an implicit
    ``return 0`` is added for ``main`` and a value-less return otherwise.
    """
    func_name = self.identifier.content
    is_main = func_name == "main"

    if not all(self.param_names):
        raise CompilerError("function definition missing parameter name", self.span)

    if is_main:
        self.check_main_type()

    c = c.set_return(self.ctype.ret)
    il_code.start_func(func_name)

    symbol_table.new_scope()
    for index, (ctype, param) in enumerate(zip(self.ctype.args, self.param_names)):
        arg = symbol_table.add_variable(
            param, ctype, symbol_table.DEFINED, None, symbol_table.AUTOMATIC)
        il_code.add(value_cmds.LoadArg(arg, index))

    # The parameter scope opened above also serves as the body's scope.
    self.body.make_il(il_code, symbol_table, c, no_scope=True)

    if not il_code.always_returns():
        if is_main:
            zero = ILValue(ctypes.integer)
            il_code.register_literal_var(zero, 0)
            il_code.add(control_cmds.Return(zero))
        else:
            il_code.add(control_cmds.Return(None))

    symbol_table.end_scope()
def do_init(self, var, storage, il_code, symbol_table, c):
    """Emit IL that initializes ``var`` from this declaration's initializer.

    The caller must check that this object has an initializer. With static
    storage, the initializer must be a literal; its ``val`` attribute (or
    None if it has none) is passed to ``il_code.static_initialize``.
    Otherwise only arithmetic and pointer variables can be initialized, by
    assigning through a DirectLValue. Violations raise CompilerError.
    """
    init = self.init.make_il(il_code, symbol_table, c)

    if storage == symbol_table.STATIC:
        if not init.literal:
            err = ("non-constant initializer for variable with static "
                   "storage duration")
            raise CompilerError(err, self.init.r)
        il_code.static_initialize(var, getattr(init.literal, "val", None))
        return

    if not (var.ctype.is_arith() or var.ctype.is_pointer()):
        raise CompilerError("declared variable is not of assignable type", self.span)
    DirectLValue(var).set_to(init, il_code, self.identifier.r)
def make_il(self, il_code, symbol_table, c):
    """Emit IL for the logical-not operator ``!expr``.

    The operand must be scalar, otherwise CompilerError is raised. Returns an
    integer ILValue that is 1 when the operand is zero and 0 otherwise.
    """
    expr = self.expr.make_il(il_code, symbol_table, c)
    if not expr.ctype.is_scalar():
        err = "'!' operator requires scalar operand"
        raise CompilerError(err, self.r)

    # ILValue for storing the output
    out = ILValue(ctypes.integer)

    # Literal 0 and 1, registered as Python ints like the other
    # register_literal_var calls in this file (previously strings).
    zero = ILValue(ctypes.integer)
    il_code.register_literal_var(zero, 0)
    one = ILValue(ctypes.integer)
    il_code.register_literal_var(one, 1)

    # Label which skips the line which sets out to 0.
    end = il_code.get_label()

    # out = 1; if expr == 0 goto end; out = 0; end:
    il_code.add(value_cmds.Set(out, one))
    il_code.add(control_cmds.JumpZero(expr, end))
    il_code.add(value_cmds.Set(out, zero))
    il_code.add(control_cmds.Label(end))
    return out
def make_il(self, il_code, symbol_table, c):
    """Emit IL for a ``return`` statement.

    A void function may not return a value and a non-void function must
    return one; violations raise CompilerError. A returned value is checked
    with check_cast and converted to ``c.return_type`` before the Return
    command is emitted.
    """
    returns_void = c.return_type.is_void()
    if self.return_value:
        if returns_void:
            raise CompilerError(
                "function with void return type cannot return value", self.r)
        value = self.return_value.make_il(il_code, symbol_table, c)
        check_cast(value, c.return_type, self.return_value.r)
        converted = set_type(value, c.return_type, il_code)
        il_code.add(control_cmds.Return(converted))
    elif not returns_void:
        raise CompilerError(
            "function with non-void return type must return value", self.r)
    else:
        il_code.add(control_cmds.Return())
def visit_AnnAssign(self, node: ast.AnnAssign):
    """Lower an annotated assignment ``target: annotation [= value]``.

    If the annotation names a known type and the target is a name, the IR
    type from the hint must match the type already recorded for the target,
    otherwise CompilerError is raised. An unrecognized annotation only emits
    a warning. An ``ir.Assign`` is appended only when a value is present.
    """
    target = self.visit(node.target)
    # "var: annotation" with no value gives node.value = None; there is
    # nothing to visit then (ast's generic visitor cannot handle None).
    value = self.visit(node.value) if node.value is not None else None
    pos = extract_positional_info(node)
    annotation = self.visit(node.annotation)
    if isinstance(annotation, ir.NameRef):
        # Check if type is recognized by name
        type_ = self.symbols.type_by_name.get(annotation.name)
        if type_ is None:
            msg = f"Ignoring unrecognized annotation: {annotation}, line: {pos.line_begin}"
            warnings.warn(msg)
        else:
            ir_type = self.symbols.get_ir_type(type_)
            if isinstance(target, ir.NameRef):
                sym = self.symbols.lookup(target)
                existing_type = sym.type_
                # This is an error, since it's an actual conflict.
                if existing_type != ir_type:
                    msg = f"IR type from type hint conflicts with existing " \
                          f"(possibly inferred) type {existing_type}, line: {pos.line_begin}"
                    raise CompilerError(msg)
    if node.value is not None:
        # CPython will turn the syntax "var: annotation" into an AnnAssign node
        # with node.value = None. If executed, this won't bind or update the value of var.
        assign = ir.Assign(target, value, pos)
        self.body.append(assign)
def rewrite_pow(expr):
    """Rewrite ``base ** coeff`` into cheaper operations where possible.

    Handles exponents 0, 1, -1, -2, 2 and 0.5 (rewritten to ``sqrt``), and a
    zero base with a constant exponent. Returns ``expr`` unchanged when no
    rewrite applies. Raises CompilerError when a zero base is raised to a
    negative constant power.
    """
    coeff = expr.right
    base = expr.left
    if coeff == ir.Zero:
        return ir.One
    elif base == ir.Zero:
        # checking for weird errors more than anything
        if coeff.constant:
            if coeff.value < 0:
                # this isn't intended to catch all edge cases, just an obvious
                # one that may come up after folding
                msg = f"raises 0 to a negative power {expr}."
                raise CompilerError(msg)
            return ir.Zero
        # The exponent is not a known constant, so 0 ** coeff cannot be
        # folded; without this return the function would yield None here.
        return expr
    elif coeff == ir.One:
        return expr.left
    elif coeff == ir.IntConst(-1):
        # NOTE(review): for in-place forms this builds a "/=" node whose left
        # operand is the constant 1 — confirm how in-place BinOps are used.
        op = "/=" if expr.in_place else "/"
        return ir.BinOp(ir.One, expr.left, op)
    elif coeff == ir.IntConst(-2):
        op = "/=" if expr.in_place else "/"
        return ir.BinOp(ir.One, ir.BinOp(expr.left, expr.left, "*"), op)
    elif coeff == ir.IntConst(2):
        op = "*=" if expr.in_place else "*"
        return ir.BinOp(expr.left, expr.left, op)
    elif coeff == ir.FloatConst(0.5):
        return ir.Call("sqrt", (expr.left,), ())
    else:
        return expr
def add_block(block, tokens):
    """Convert ``block`` into a token, if possible, and append it to ``tokens``.

    The block is tried, in order, as a keyword, a number, and an identifier.
    An empty block is ignored. A non-empty block matching none of these
    raises CompilerError. Symbol-kind tokens need no handling here because
    they are converted before they are shifted into the block.

    block - block to convert into a token, as list of Tagged characters.
    tokens (List[Token]) - List of the tokens parsed so far.
    """
    if not block:
        return
    span = Range(block[0].p, block[-1].p)

    kind = match_keyword_kind(block)
    if kind:
        tokens.append(Token(kind, r=span))
        return

    number = match_number_string(block)
    if number:
        tokens.append(Token(token_kinds.number, number, r=span))
        return

    ident = match_identifier_name(block)
    if ident:
        tokens.append(Token(token_kinds.identifier, ident, r=span))
        return

    raise CompilerError(f"unrecognized token at '{block_to_str(block)}'", span)
def read_file(arguments):
    """Read the file named by ``arguments.filename``.

    Returns a ``(contents, filename)`` tuple on success. If the file cannot
    be read, a CompilerError describing the failure is added to
    ``error_collector`` and None is returned.
    """
    filename = arguments.filename
    try:
        with open(filename) as source:
            return source.read(), filename
    except OSError:  # IOError is an alias of OSError
        error_collector.add(
            CompilerError(f"could not read file: '{filename}'"))
        return None
def _lvalue(self, il_code, symbol_table, c):
    """Return the lvalue designated by a unary ``*`` expression.

    The operand is evaluated and must have pointer type (otherwise
    CompilerError); the result is an IndirectLValue through that pointer.
    """
    pointer = self.expr.make_il(il_code, symbol_table, c)
    if pointer.ctype.is_pointer():
        return IndirectLValue(pointer)
    raise CompilerError("operand of unary '*' must have pointer type", self.expr.r)
def get_offset_info(self, struct_ctype):
    """Look up ``self.member`` in ``struct_ctype``; return ``(offset, ctype)``.

    Raises CompilerError when ``struct_ctype`` is falsy (e.g. None) or not a
    struct, or when the struct has no member with that name. If the struct
    type is const, the returned member ctype is made const too.
    """
    if struct_ctype and struct_ctype.is_struct():
        member_name = self.member.content
        offset, ctype = struct_ctype.get_offset(member_name)
        if offset is None:
            raise CompilerError(f"structure has no member '{member_name}'", self.r)
        member_ctype = ctype.make_const() if struct_ctype.is_const() else ctype
        return offset, member_ctype
    raise CompilerError("request for member in something not a structure", self.r)
def visit_Attribute(self, node: ast.Attribute) -> ir.ShapeRef:
    """Lower an attribute access; only ``<name>.shape`` is supported.

    Returns an ``ir.ShapeRef`` wrapping the visited base name when the
    attribute is ``shape``. Any other attribute, or a non-name base, raises
    CompilerError.
    """
    value = self.visit(node.value)
    # For `a.shape`, node.value is the base `a` and node.attr is the string
    # "shape"; ast.Attribute has no `base` field.
    if node.attr == "shape" and isinstance(value, ir.NameRef):
        return ir.ShapeRef(value)
    msg = "The only currently supported attribute is shape for arrays."
    raise CompilerError(msg)
def make_il(self, il_code, symbol_table, c):
    """Emit IL for an explicit cast operation.

    The target type is built from the cast's specifiers and first declarator
    and must be void or scalar; the operand must be scalar. Either violation
    raises CompilerError. Returns the operand converted with set_type.
    """
    self.set_self_vars(il_code, symbol_table, c)
    declarator = self.node.decls[0]
    specs_type, _ = self.make_specs_ctype(self.node.specs, False)
    target_type, _ = self.make_ctype(declarator, specs_type)

    if not (target_type.is_void() or target_type.is_scalar()):
        raise CompilerError("can only cast to scalar or void type", declarator.r)

    operand = self.expr.make_il(il_code, symbol_table, c)
    if not operand.ctype.is_scalar():
        raise CompilerError("can only cast from scalar type", self.r)
    return set_type(operand, target_type, il_code)
def make_il(self, il_code, symbol_table, c):
    """Emit IL for the address-of operator ``&expr``.

    The operand must be an lvalue; the result of its ``addr`` method is
    returned. Otherwise CompilerError is raised at the operand's range.
    """
    target = self.expr.lvalue(il_code, symbol_table, c)
    if not target:
        raise CompilerError("operand of unary '&' must be lvalue", self.expr.r)
    return target.addr(il_code)
def _(self, node: ir.BinOp):
    """Apply algebraic identities to simplify a binary operation.

    Covers powers with exponent 0/1/2, addition and subtraction involving
    zero or unary negation, floor division by one or of zero, and
    multiplication by a left-hand 0, 1 or -1. Raises CompilerError for a
    division by a literal zero. When no rule applies the function returns
    None. NOTE(review): presumably the caller treats None as "no change" —
    confirm.
    """
    left = node.left
    right = node.right
    # if a constant expression shows up here, treat it as an error since
    # it's weirder to handle than it seems
    assert not (left.constant and right.constant)
    two = ir.IntConst(2)
    negative_one = ir.IntConst(-1)
    if is_pow(node):
        if right == ir.Zero:
            return ir.One
        elif right == ir.One:
            return left
        elif right == two:
            return ir.BinOp(left, left, "*=" if node.in_place else "*")
        # square roots shouldn't come up here, given the associative qualifier
    elif is_addition(node):
        if right == ir.Zero:
            return left
        elif equals_unary_negate(right):
            # a + (-b) -> a - b
            return ir.BinOp(left, right.operand, "-=" if node.in_place else "-")
        elif equals_unary_negate(left):
            # (-a) + b -> b - a
            assert not node.in_place
            return ir.BinOp(right, left.operand, "-")
    elif is_subtraction(node):
        if left == ir.Zero:
            return ir.UnaryOp(right, "-")
        elif right == ir.Zero:
            return left
        elif equals_unary_negate(right):
            # Todo: this is not entirely correct... as it may not be a unary node
            # need something like extract unary operand..
            return ir.BinOp(left, right.operand, "+=" if node.in_place else "+")
        elif equals_unary_negate(left):
            # (-a) - b == -(a + b). Rewriting it as b + a would flip the sign.
            assert not node.in_place
            return ir.UnaryOp(ir.BinOp(left.operand, right, "+"), "-")
    elif is_division(node):
        if right == ir.Zero:
            msg = f"Divide by zero error in expression {node}."
            raise CompilerError(msg)
        elif node.op in ("//", "//="):
            # only safe to fold floor divide, ignore left == right since these might
            # be zero. Constant cases should be handled by the const folder.
            if left == ir.Zero or right == ir.One:
                return left
    elif is_multiplication(node):
        if left == ir.Zero:
            return ir.Zero
        elif left == ir.One:
            return right
        elif left == negative_one:
            if equals_unary_negate(right):
                # -(-something)) is safe in Python but possibly unsafe in a fixed width
                # destination. Folding it should be considered safe.
                return right.operand
            else:
                return ir.UnaryOp(right, "-")
def _(self, node: ir.BinOp):
    """Infer the result type of a binary operation.

    Both operands are re-visited on every call (no caching, since their
    types may change) and the pair is resolved against the operator with
    resolve_binop_type. Raises CompilerError when no signature matches.
    """
    left_type = self.visit(node.left)
    right_type = self.visit(node.right)
    result_type = resolve_binop_type(left_type, right_type, node.op)
    if result_type is not None:
        return result_type
    msg = f"No signature match for operator {node.op} with candidate signatures: ({left_type}, {right_type})."
    raise CompilerError(msg)
def optimize(node: lll_ast.LLLNode, version: str = LATEST_VERSION, num_iterations: int = 200) -> lll_ast.LLLNode:
    """Run ``optimizer.update`` over ``node`` for ``num_iterations`` passes.

    ``version`` selects the opcode table in ``evm_opcodes``. Raises
    CompilerError for an unsupported fork, when a pass raises
    ExplorationError (chained as the cause), or when a pass returns None.
    """
    if version not in evm_opcodes:
        raise CompilerError(f"Fork '{version}' is not supported")
    opcodes = evm_opcodes[version]

    for _ in range(num_iterations):
        try:
            updated = optimizer.update(node, opcodes)
        except ExplorationError as exc:
            raise CompilerError("Optimization Failed") from exc
        if updated is None:
            raise CompilerError("Optimization Failed")
        node = updated
    return node
def read_string(line, start, delim, null):
    """Lex a string or character literal.

    Returns ``(chars, end)``: ``chars`` is the lexed list of integer
    character codes and ``end`` is the index in ``line`` of the closing
    delimiter. The number of input characters read therefore differs from
    ``len(chars)`` whenever escape sequences occur.

    Each entry in ``chars`` is ``ord()`` of an unescaped input character, or
    the value of an escape sequence. Values are not range-checked: octal
    escapes (at most three digits) can reach 0o777 = 511, hex escapes take
    every following hex digit and are unbounded, and non-ASCII input
    characters yield their code point. A backslash that starts no
    recognized escape is kept as a literal backslash, and the character
    after it is then lexed normally. When ``null`` is true a 0 terminator is
    appended.

    line - List of Tagged objects for each character in the line.
    start - Index of the first character after the opening quote; the
        opening quote (``line[start - 1]``) is used for the error range.
    delim - Delimiter with which the string ends, like `"` or `'`
    null - Whether to add a null-terminator to the returned character list

    Raises CompilerError("missing terminating quote") if the line ends
    before an unescaped delimiter.
    """
    i = start
    chars = []

    # Single-character escapes and their character codes.
    escapes = {"'": 39, '"': 34, "?": 63, "\\": 92, "a": 7, "b": 8, "f": 12, "n": 10, "r": 13, "t": 9, "v": 11}
    octdigits = "01234567"
    hexdigits = "0123456789abcdefABCDEF"

    while True:
        if i >= len(line):
            descr = "missing terminating quote"
            raise CompilerError(descr, line[start - 1].r)
        elif line[i].c == delim:
            if null:
                chars.append(0)
            return chars, i
        # Simple escape such as \n or \" (consumes two characters).
        elif i + 1 < len(line) and line[i].c == "\\" and line[i + 1].c in escapes:
            chars.append(escapes[line[i + 1].c])
            i += 2
        # Octal escape: one to three octal digits.
        elif i + 1 < len(line) and line[i].c == "\\" and line[i + 1].c in octdigits:
            octal = line[i + 1].c
            i += 2
            while i < len(line) and len(octal) < 3 and line[i].c in octdigits:
                octal += line[i].c
                i += 1
            chars.append(int(octal, 8))
        # Hex escape: \x followed by at least one hex digit, unbounded length.
        elif i + 2 < len(line) and line[i].c == "\\" and line[i + 1].c == "x" and line[i + 2].c in hexdigits:
            hexa = line[i + 2].c
            i += 3
            while i < len(line) and line[i].c in hexdigits:
                hexa += line[i].c
                i += 1
            chars.append(int(hexa, 16))
        # Ordinary character (including an unrecognized backslash).
        else:
            chars.append(ord(line[i].c))
            i += 1