def visit_ExprDot(self, expr: ExprDot) -> Tuple[ExprDeref, CairoType]:
    """
    Lowers the member access 'expr.member' into a dereference of the member's address.

    The visited inner expression must be typed either as a pointer to a struct (then '.' acts
    like '->', one level only) or as a struct, whose address is taken using get_expr_addr().
    Returns the resulting ExprDeref together with the member's type.
    Raises CairoTypeError for any other inner type, for struct-typed ExprTuple values, when the
    struct definition cannot be retrieved, or when the member is not part of the struct.
    """
    loc = expr.location
    self.verify_identifier_manager_initialized(location=loc)

    base_expr, base_type = self.visit(expr.expr)
    if isinstance(base_type, TypePointer):
        struct_type = base_type.pointee
        if not isinstance(struct_type, TypeStruct):
            raise CairoTypeError(
                f'Cannot apply dot-operator to pointer-to-non-struct type '
                f"'{base_type.format()}'.",
                location=loc)
        # The pointer value already is the struct's address: '.' behaves as '->', once.
    elif isinstance(base_type, TypeStruct):
        if isinstance(base_expr, ExprTuple):
            raise CairoTypeError(
                'Accessing struct members for r-value structs is not supported yet.',
                location=loc)
        struct_type = base_type
        # Switch to the struct's address, so that '.' can be evaluated as '->'.
        base_expr = get_expr_addr(base_expr)
    else:
        raise CairoTypeError(
            f"Cannot apply dot-operator to non-struct type '{base_type.format()}'.",
            location=loc)

    try:
        struct_def = get_struct_definition(
            struct_name=struct_type.resolved_scope, identifier_manager=self.identifiers)
    except Exception as exc:
        raise CairoTypeError(str(exc), location=loc)

    member_name = expr.member.name
    if member_name not in struct_def.members:
        raise CairoTypeError(
            f"Member '{member_name}' does not appear in definition of struct "
            f"'{struct_type.format()}'.",
            location=loc)
    member_definition = struct_def.members[member_name]

    # A member at offset 0 is located at the struct's address itself, so no addition is needed.
    member_addr = base_expr
    if member_definition.offset != 0:
        member_addr = ExprOperator(
            a=base_expr, op='+',
            b=ExprConst(val=member_definition.offset, location=loc),
            location=loc)

    return ExprDeref(addr=member_addr, location=loc), member_definition.cairo_type
def process_retdata(
        self, ret_struct_ptr: Expression, ret_struct_type: CairoType,
        struct_def: StructDefinition,
        location: Optional[Location]) -> Tuple[Expression, Expression]:
    """
    Generates the code that exposes the return struct through a newly added segment, and
    returns the expressions for retdata_size and retdata_ptr.
    Raises PreprocessorError if a member of the return struct is not a felt.
    """
    # Only felt members are supported.
    for member_def in struct_def.members.values():
        member_type = member_def.cairo_type
        if isinstance(member_type, TypeFelt):
            continue
        raise PreprocessorError(
            f'Unsupported argument type {member_type.format()}.',
            location=member_type.location)

    # Define the reference 'retdata_ptr' as [ap], typed as felt*.
    retdata_ptr_value = ExprDeref(
        addr=ExprReg(reg=Register.AP),
        location=location,
    )
    self.add_reference(
        name=self.current_scope + 'retdata_ptr',
        value=retdata_ptr_value,
        cairo_type=TypePointer(TypeFelt()),
        require_future_definition=False,
        location=location)

    # The hint stores the address of a new segment in memory[ap].
    new_segment_hint = ExprHint(
        hint_code='memory[ap] = segments.add()',
        n_prefix_newlines=0,
        location=location,
    )
    self.visit(CodeElementHint(hint=new_segment_hint, location=location))

    # Skip check of hint whitelist as it fails before the workaround below.
    advance_ap = InstructionAst(
        body=AddApInstruction(ExprConst(1)),
        inc_ap=False,
        location=location)
    super().visit_CodeElementInstruction(CodeElementInstruction(advance_ap))

    # Remove the references from the last instruction's flow tracking as they are
    # not needed by the hint and they cause the hint whitelist to fail.
    last_instruction = self.instructions[-1]
    assert len(last_instruction.hints) == 1
    hint, hint_flow_tracking_data = last_instruction.hints[0]
    stripped_flow_tracking_data = dataclasses.replace(
        hint_flow_tracking_data, reference_ids={})
    last_instruction.hints[0] = (hint, stripped_flow_tracking_data)

    # Assert that [cast(retdata_ptr, ret_struct_type*)] equals the given return struct.
    retdata_as_struct_ptr = ExprCast(
        ExprIdentifier('retdata_ptr'), TypePointer(ret_struct_type))
    self.visit(CodeElementCompoundAssertEq(ExprDeref(retdata_as_struct_ptr), ret_struct_ptr))

    retdata_size = ExprConst(struct_def.size)
    return retdata_size, ExprIdentifier('retdata_ptr')
def test_deref_expr():
    """
    Checks that a nested dereference is parsed into the expected tree and formatted back.
    """
    # Build the expected tree from the innermost address outwards.
    inner_addr = ExprOperator(a=ExprReg(reg=Register.FP), op='-', b=ExprConst(val=7))
    outer_addr = ExprOperator(a=ExprDeref(addr=inner_addr), op='+', b=ExprConst(val=3))
    expected = ExprDeref(addr=outer_addr)

    parsed = parse_expr('[[fp - 7] + 3]')
    assert parsed == expected
    assert parsed.format() == '[[fp - 7] + 3]'
def test_add_expr():
    """
    Checks parsing and formatting of a sum of two dereference expressions.
    """
    lhs = ExprDeref(
        addr=ExprOperator(a=ExprReg(reg=Register.FP), op='+', b=ExprConst(val=1)))
    rhs = ExprDeref(
        addr=ExprOperator(a=ExprReg(reg=Register.AP), op='-', b=ExprIdentifier(name='x')))
    expected = ExprOperator(a=lhs, op='+', b=rhs)

    parsed = parse_expr('[fp + 1] + [ap - x]')
    assert parsed == expected
    assert parsed.format() == '[fp + 1] + [ap - x]'

    # Formatting adds the spaces around the operators.
    compact = parse_expr('[ap-7]+37')
    assert compact.format() == '[ap - 7] + 37'
def rewrite_ExprDeref(self, expr: ExprDeref, sim: SimplicityLevel):
    """
    Rewrites a dereference expression to the requested simplicity level.
    An expression accepted by is_simple_deref() is returned unchanged. Otherwise the address is
    rewritten to SimplicityLevel.DEREF_OFFSET, and the new dereference is passed through
    self.wrap() unless the requested level is SimplicityLevel.OPERATION.
    """
    if is_simple_deref(expr):
        # Nothing to do.
        return expr

    rewritten_addr = self.rewrite(expr.addr, SimplicityLevel.DEREF_OFFSET)
    rewritten = ExprDeref(addr=rewritten_addr, location=expr.location)
    if sim is SimplicityLevel.OPERATION:
        return rewritten
    return self.wrap(rewritten)
def test_subscript_expr():
    """
    Checks parsing and formatting of subscript expressions, including invalid inputs.
    """
    # Pairs of (code, expected formatted code).
    format_cases = [
        ('x[y]', 'x[y]'),
        ('[x][y][z][w]', '[x][y][z][w]'),
        (' x [ [ y[z[w]] ] ]', 'x[[y[z[w]]]]'),
        (' (x+y)[z+w] ', '(x + y)[z + w]'),
        ('(&x)[3][(a-b)*2][&c]', '(&x)[3][(a - b) * 2][&c]'),
        ('x[i+n*j]', 'x[i + n * j]'),
        ('x+[y][z]', 'x + [y][z]'),
    ]
    for code, formatted in format_cases:
        assert parse_expr(code).format() == formatted

    # Subscripts nest to the left: ([x][y])[[z]].
    inner_subscript = ExprSubscript(
        expr=ExprDeref(addr=ExprIdentifier(name='x')),
        offset=ExprIdentifier(name='y'))
    expected = ExprSubscript(
        expr=inner_subscript,
        offset=ExprDeref(addr=ExprIdentifier(name='z')))
    assert parse_expr('[x][y][[z]]') == expected

    for invalid_code in ('x[)]', 'x[]'):
        with pytest.raises(ParserError):
            parse_expr(invalid_code)
def create_simple_ref_expr(
        reg: Register, offset: int, cairo_type: CairoType,
        location: Optional[Location]) -> Expression:
    """
    Builds the expression '[cast(reg + offset, cairo_type*)]'.
    The given location is attached to every node that is created.
    """
    reg_plus_offset = ExprOperator(
        a=ExprReg(reg=reg, location=location),
        op='+',
        b=ExprConst(val=offset, location=location),
        location=location)
    pointer_type = TypePointer(pointee=cairo_type, location=location)
    typed_addr = ExprCast(expr=reg_plus_offset, dest_type=pointer_type, location=location)
    return ExprDeref(addr=typed_addr, location=location)
def eval(
        self, reference_manager: ReferenceManager, flow_tracking_data: FlowTrackingData) -> \
        Expression:
    """
    Evaluates the parent reference and follows self.member_path over it, one member at a time.
    Returns the resulting expression, cast to the type of the last member (or to the type of
    the reference itself if the path is empty).
    Raises DefinitionError if a step is applied to a value which is neither a struct nor a
    pointer to a struct, or if the qualified member name is missing or is not a member.
    """
    reference = flow_tracking_data.resolve_reference(
        reference_manager=reference_manager,
        name=self.parent.full_name)
    assert isinstance(flow_tracking_data, FlowTrackingDataActual), \
        'Resolved references can only come from FlowTrackingDataActual.'

    value, value_type = simplify_type_system(reference.eval(flow_tracking_data.ap_tracking))
    for member_name in self.member_path.path:
        # For a struct value, the member address is computed from the value's address.
        is_struct_value = isinstance(value_type, TypeStruct)
        if is_struct_value:
            value_type = value_type.get_pointer_type()

        is_struct_pointer = isinstance(value_type, TypePointer) and \
            isinstance(value_type.pointee, TypeStruct)
        if not is_struct_pointer:
            raise DefinitionError('Member access requires a type of the form Struct*.')

        qualified_member = value_type.pointee.resolved_scope + member_name
        if qualified_member not in self.identifier_values:
            raise DefinitionError(f"Member '{qualified_member}' was not found.")
        member_definition = self.identifier_values[qualified_member]

        if not isinstance(member_definition, MemberDefinition):
            raise DefinitionError(
                f"Expected reference offset '{qualified_member}' to be a member, "
                f'found {member_definition.TYPE}.')

        base_addr = ExprAddressOf(expr=value) if is_struct_value else value
        value_type = member_definition.cairo_type
        value = ExprDeref(
            addr=ExprOperator(a=base_addr, op='+', b=ExprConst(member_definition.offset)))

    return ExprCast(expr=value, dest_type=value_type)
def remove_parentheses(expr):
    """
    Returns a copy of the given arithmetic expression without its ExprParentheses nodes.
    Operator, address-of, negation, dereference, dot and subscript nodes are rebuilt from
    their processed children; any other expression is returned as is.
    """
    # Peel off any number of directly nested parentheses.
    while isinstance(expr, ExprParentheses):
        expr = expr.val

    if isinstance(expr, ExprOperator):
        lhs = remove_parentheses(expr.a)
        rhs = remove_parentheses(expr.b)
        return ExprOperator(a=lhs, op=expr.op, b=rhs)
    if isinstance(expr, ExprAddressOf):
        inner = remove_parentheses(expr.expr)
        return ExprAddressOf(expr=inner)
    if isinstance(expr, ExprNeg):
        inner = remove_parentheses(expr.val)
        return ExprNeg(val=inner)
    if isinstance(expr, ExprDeref):
        inner = remove_parentheses(expr.addr)
        return ExprDeref(addr=inner)
    if isinstance(expr, ExprDot):
        inner = remove_parentheses(expr.expr)
        return ExprDot(expr=inner, member=expr.member)
    if isinstance(expr, ExprSubscript):
        base = remove_parentheses(expr.expr)
        index = remove_parentheses(expr.offset)
        return ExprSubscript(expr=base, offset=index)

    # Not a node with sub-expressions handled above.
    return expr
def atom_deref(self, value, meta):
    """
    Builds an ExprDeref node: value[1] is used as the address and value[0] as the notes.
    The location is computed from meta using self.meta2loc().
    """
    addr = value[1]
    notes = value[0]
    location = self.meta2loc(meta)
    return ExprDeref(addr=addr, notes=notes, location=location)
def visit_ExprDeref(self, expr: ExprDeref):
    """
    Returns a new ExprDeref whose address is the visited address of the given expression and
    whose location is the result of self.location_modifier() on the original location.
    """
    visited_addr = self.visit(expr.addr)
    new_location = self.location_modifier(expr.location)
    return ExprDeref(addr=visited_addr, location=new_location)
def test_instruction():
    """
    Checks parsing and formatting of every instruction form: assert-eq, jumps, conditional
    jumps, calls, ret and 'ap +='; and verifies that malformed instructions are rejected.
    """

    def deref_reg(reg):
        # The expression '[reg]'.
        return ExprDeref(addr=ExprReg(reg=reg))

    def deref_reg_plus_x(reg):
        # The expression '[reg] + x'.
        return ExprOperator(a=deref_reg(reg), op='+', b=ExprIdentifier(name='x'))

    def assert_rejected(*codes):
        # Each of the given instructions must fail to parse.
        for code in codes:
            with pytest.raises(ParserError):
                parse_instruction(code)

    # AssertEq.
    parsed = parse_instruction('[ap] = [fp]; ap++')
    expected = InstructionAst(
        body=AssertEqInstruction(a=deref_reg(Register.AP), b=deref_reg(Register.FP)),
        inc_ap=True)
    assert parsed == expected
    assert parsed.format() == '[ap] = [fp]; ap++'
    assert parse_instruction(
        '[ap+5] = [fp]+[ap] - 5').format() == '[ap + 5] = [fp] + [ap] - 5'
    assert parse_instruction('[ap+5]+3= [fp]*7;ap ++ ').format() == \
        '[ap + 5] + 3 = [fp] * 7; ap++'

    # Jump.
    parsed = parse_instruction('jmp rel [ap] + x; ap++')
    expected = InstructionAst(
        body=JumpInstruction(val=deref_reg_plus_x(Register.AP), relative=True),
        inc_ap=True)
    assert parsed == expected
    assert parsed.format() == 'jmp rel [ap] + x; ap++'
    assert parse_instruction(' jmp abs[ap]+x').format() == 'jmp abs [ap] + x'
    # Make sure the following are not OK.
    assert_rejected('jmp abs', 'jmpabs[ap]')

    # JumpToLabel.
    parsed = parse_instruction('jmp label')
    expected = InstructionAst(
        body=JumpToLabelInstruction(label=ExprIdentifier(name='label'), condition=None),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'jmp label'
    # Make sure the following are not OK.
    assert_rejected('jmp [fp]', 'jmp 7')

    # Jnz.
    parsed = parse_instruction('jmp rel [ap] + x if [fp + 3] != 0')
    fp_plus_3 = ExprOperator(a=ExprReg(reg=Register.FP), op='+', b=ExprConst(val=3))
    expected = InstructionAst(
        body=JnzInstruction(
            jump_offset=deref_reg_plus_x(Register.AP),
            condition=ExprDeref(addr=fp_plus_3)),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'jmp rel [ap] + x if [fp + 3] != 0'
    assert parse_instruction(' jmp rel 17 if[fp]!=0;ap++').format() == \
        'jmp rel 17 if [fp] != 0; ap++'
    # Make sure the following are not OK.
    assert_rejected('jmprel 17 if x != 0', 'jmp 17 if x')
    with pytest.raises(ParserError, match='!= 0'):
        parse_instruction('jmp rel 17 if x != 2')
    assert_rejected('jmp rel [fp] ifx != 0')

    # Jnz to label.
    parsed = parse_instruction('jmp label if [fp] != 0')
    expected = InstructionAst(
        body=JumpToLabelInstruction(
            label=ExprIdentifier('label'),
            condition=deref_reg(Register.FP)),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'jmp label if [fp] != 0'
    # Make sure the following are not OK.
    assert_rejected('jmp [fp] if [fp] != 0', 'jmp 7 if [fp] != 0')

    # Call abs.
    parsed = parse_instruction('call abs [fp] + x')
    expected = InstructionAst(
        body=CallInstruction(val=deref_reg_plus_x(Register.FP), relative=False),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'call abs [fp] + x'
    assert parse_instruction(
        'call abs 17;ap++').format() == 'call abs 17; ap++'
    # Make sure the following are not OK.
    assert_rejected('call abs', 'callabs 7')

    # Call rel.
    parsed = parse_instruction('call rel [ap] + x')
    expected = InstructionAst(
        body=CallInstruction(val=deref_reg_plus_x(Register.AP), relative=True),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'call rel [ap] + x'
    assert parse_instruction(
        'call rel 17;ap++').format() == 'call rel 17; ap++'
    # Make sure the following are not OK.
    assert_rejected('call rel', 'callrel 7')

    # Call label.
    parsed = parse_instruction('call label')
    expected = InstructionAst(
        body=CallLabelInstruction(label=ExprIdentifier(name='label')),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'call label'
    assert parse_instruction(
        'call label ;ap++').format() == 'call label; ap++'
    # Make sure the following are not OK.
    assert_rejected('call [fp]', 'call 7')

    # Ret.
    parsed = parse_instruction('ret')
    expected = InstructionAst(body=RetInstruction(), inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'ret'

    # AddAp.
    parsed = parse_instruction('ap += [fp] + 2')
    expected = InstructionAst(
        body=AddApInstruction(
            expr=ExprOperator(a=deref_reg(Register.FP), op='+', b=ExprConst(val=2))),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'ap += [fp] + 2'
    assert parse_instruction('ap +=[ fp]+ 2').format() == 'ap += [fp] + 2'
    assert parse_instruction(
        'ap +=[ fp]+ 2;ap ++').format() == 'ap += [fp] + 2; ap++'
def visit_ExprSubscript(
        self, expr: ExprSubscript) -> Tuple[Expression, CairoType]:
    """
    Lowers the subscript expression 'expr[offset]' and returns it with its type.

    For a pointer, the result is a dereference of 'expr + offset * size', where size is the
    size of the pointed type. For a tuple, the offset must simplify to a constant within range;
    the result is either the chosen item of an ExprTuple, or a dereference at the item's offset
    when the tuple is given as an ExprDeref.
    Raises CairoTypeError in any other case.
    """
    base_expr, base_type = self.visit(expr.expr)
    index_expr, index_type = self.visit(expr.offset)

    is_tuple = isinstance(base_type, TypeTuple)
    if not is_tuple and not isinstance(base_type, TypePointer):
        raise CairoTypeError(
            'Cannot apply subscript-operator to non-pointer, non-tuple type '
            f"'{base_type.format()}'.",
            location=expr.location)

    self.verify_offset_is_felt(index_type, index_expr.location)

    if not is_tuple:
        # Pointer: scale the offset by the size of the pointed type.
        try:
            # If pointed type is struct, get_size could throw IdentifierErrors. We catch and
            # convert them to CairoTypeErrors.
            element_size = self.get_size(base_type.pointee)
        except Exception as exc:
            raise CairoTypeError(str(exc), location=expr.location)

        scaled_index = ExprOperator(
            a=index_expr, op='*',
            b=ExprConst(val=element_size, location=expr.location),
            location=expr.location)
        element_addr = ExprOperator(
            a=base_expr, op='+', b=scaled_index, location=expr.location)
        return ExprDeref(addr=element_addr, location=expr.location), base_type.pointee

    # Tuple: only constant offsets are supported.
    index_expr = ExpressionSimplifier().visit(index_expr)
    if not isinstance(index_expr, ExprConst):
        raise CairoTypeError(
            'Subscript-operator for tuples supports only constant offsets, found '
            f"'{type(index_expr).__name__}'.",
            location=index_expr.location)

    index = index_expr.val
    n_members = len(base_type.members)
    if index < 0 or index >= n_members:
        raise CairoTypeError(
            f'Tuple index {index} is out of range [0, {n_members}).',
            location=expr.location)
    item_type = base_type.members[index]

    if isinstance(base_expr, ExprTuple):
        assert len(base_expr.members.args) == n_members
        # Take the inner item, but keep the original expression's location.
        chosen_item = base_expr.members.args[index].expr
        return dataclasses.replace(chosen_item, location=expr.location), item_type

    if isinstance(base_expr, ExprDeref):
        # Handles pointers cast as tuples*, e.g. `[cast(ap, (felt, felt)*)][0]`.
        # The item's offset is the total size of the members that precede it.
        preceding_size = sum(
            self.get_size(member_type) for member_type in base_type.members[:index])
        item_addr = ExprOperator(
            a=base_expr.addr, op='+',
            b=ExprConst(val=preceding_size, location=index_expr.location),
            location=expr.location)
        return ExprDeref(addr=item_addr, location=expr.location), item_type

    raise CairoTypeError(
        'Unexpected expression typed as TypeTuple. Expected ExprTuple or ExprDeref, '
        f"found '{type(base_expr).__name__}'.",
        location=expr.location)
def create_func_wrapper(self, elm: CodeElementFunction, func_alias_name: str):
    """
    Generates a wrapper that converts between the StarkNet contract ABI and the Cairo calling
    convention.

    Arguments:
    elm - the CodeElementFunction of the wrapped function.
    func_alias_name - an alias for the FunctionDefinition in the current scope.

    Raises PreprocessorError if the function has an implicit argument which is not part of the
    os context, and NotImplementedError if the function's external decorator is not supported.
    """

    os_context = self.get_os_context()

    func_location = elm.identifier.location
    assert func_location is not None

    def fp_relative(offset: int):
        # Creates the expression '[fp + offset]', located at the function's identifier.
        return ExprDeref(
            addr=ExprOperator(
                ExprReg(reg=Register.FP, location=func_location),
                '+',
                ExprConst(offset, location=func_location),
                location=func_location),
            location=func_location)

    # We expect the call stack to look as follows:
    # pointer to os_context struct.
    # calldata size.
    # pointer to the call data array.
    # ret_fp.
    # ret_pc.
    os_context_ptr = fp_relative(-5)
    calldata_size = fp_relative(-4)
    calldata_ptr = fp_relative(-3)

    implicit_arguments = None

    implicit_arguments_identifiers: Dict[str, TypedIdentifier] = {}
    if elm.implicit_arguments is not None:
        args = []
        for typed_identifier in elm.implicit_arguments.identifiers:
            ptr_name = typed_identifier.name
            if ptr_name not in os_context:
                raise PreprocessorError(
                    f"Unexpected implicit argument '{ptr_name}' in an external function.",
                    location=typed_identifier.identifier.location)

            implicit_arguments_identifiers[ptr_name] = typed_identifier

            # Add the assignment expression 'ptr_name = ptr_name' to the implicit arg list.
            args.append(ExprAssignment(
                identifier=typed_identifier.identifier,
                expr=typed_identifier.identifier,
                location=typed_identifier.location,
            ))

        implicit_arguments = ArgList(
            args=args, notes=[], has_trailing_comma=True,
            location=elm.implicit_arguments.location)

    return_args_exprs: List[Expression] = []

    # Create references.
    for ptr_name, index in os_context.items():
        ref_name = self.current_scope + ptr_name

        arg_identifier = implicit_arguments_identifiers.get(ptr_name)
        if arg_identifier is None:
            # The pointer is not an implicit argument of the function: treat it as a felt.
            location: Optional[Location] = func_location
            cairo_type: CairoType = TypeFelt(location=location)
        else:
            location = arg_identifier.location
            cairo_type = self.resolve_type(arg_identifier.get_type())

        # Add a reference of the form
        # 'let ref_name = [cast(os_context_ptr + index, cairo_type*)]'.
        self.add_reference(
            name=ref_name,
            value=ExprDeref(
                addr=ExprCast(
                    ExprOperator(
                        os_context_ptr, '+', ExprConst(index, location=location),
                        location=location),
                    dest_type=TypePointer(pointee=cairo_type, location=cairo_type.location),
                    location=cairo_type.location),
                location=location),
            cairo_type=cairo_type,
            location=location,
            require_future_definition=False)

        # The os context pointers are returned in the order of their indices.
        assert index == len(return_args_exprs), 'Unexpected index.'

        return_args_exprs.append(ExprIdentifier(name=ptr_name, location=func_location))

    arg_struct_def = self.get_struct_definition(
        name=ScopedName.from_string(func_alias_name) + CodeElementFunction.ARGUMENT_SCOPE,
        location=func_location)

    code_elements, call_args = process_calldata(
        calldata_ptr=calldata_ptr,
        calldata_size=calldata_size,
        identifiers=self.identifiers,
        struct_def=arg_struct_def,
        has_range_check_builtin='range_check_ptr' in os_context,
        location=func_location,
    )

    for code_element in code_elements:
        self.visit(code_element)

    # Call the wrapped function.
    self.visit(CodeElementFuncCall(
        func_call=RvalueFuncCall(
            func_ident=ExprIdentifier(name=func_alias_name, location=func_location),
            arguments=call_args,
            implicit_arguments=implicit_arguments,
            location=func_location)))

    ret_struct_name = ScopedName.from_string(func_alias_name) + CodeElementFunction.RETURN_SCOPE
    ret_struct_type = self.resolve_type(TypeStruct(ret_struct_name, False))
    ret_struct_def = self.get_struct_definition(
        name=ret_struct_name,
        location=func_location)
    # The reference 'ret_struct' is '[cast(ap - size, ret_struct_type*)]'.
    ret_struct_expr = create_simple_ref_expr(
        reg=Register.AP, offset=-ret_struct_def.size, cairo_type=ret_struct_type,
        location=func_location)
    self.add_reference(
        name=self.current_scope + 'ret_struct',
        value=ret_struct_expr,
        cairo_type=ret_struct_type,
        require_future_definition=False,
        location=func_location)

    # Add function return values.
    retdata_size, retdata_ptr = self.process_retdata(
        ret_struct_ptr=ExprIdentifier(name='ret_struct'),
        ret_struct_type=ret_struct_type,
        struct_def=ret_struct_def,
        location=func_location,
    )
    return_args_exprs += [retdata_size, retdata_ptr]

    # Push the return values.
    self.push_compound_expressions(
        compound_expressions=[self.simplify_expr_as_felt(expr) for expr in return_args_exprs],
        location=func_location,
    )

    # Add a ret instruction.
    self.visit(CodeElementInstruction(
        instruction=InstructionAst(
            body=RetInstruction(),
            inc_ap=False,
            location=func_location)))

    # Add an entry to the ABI.
    external_decorator = self.get_external_decorator(elm)
    assert external_decorator is not None
    is_view = external_decorator.name == 'view'

    # Determine the entry type once, rejecting decorators which are not supported.
    if external_decorator.name == L1_HANDLER_DECORATOR:
        entry_type = L1_HANDLER_DECORATOR
    elif external_decorator.name in [EXTERNAL_DECORATOR, VIEW_DECORATOR]:
        entry_type = 'function'
    else:
        raise NotImplementedError(f'Unsupported decorator {external_decorator.name}')

    self.add_abi_entry(
        name=elm.name, arg_struct_def=arg_struct_def, ret_struct_def=ret_struct_def,
        is_view=is_view, entry_type=entry_type)