def test_import():
    """Checks parsing and formatting of import statements ('from <module> import ...')."""
    # Test module names without periods.
    res = parse_code_element('from a import b')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a'),
        import_items=[
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='b'), local_name=None)
        ])
    assert res.format(allowed_line_length=100) == 'from a import b'

    # Test module names without periods, with aliasing.
    res = parse_code_element('from a import b as c')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a'),
        import_items=[
            AliasedIdentifier(
                orig_identifier=ExprIdentifier(name='b'),
                local_name=ExprIdentifier(name='c'))
        ])
    assert res.format(allowed_line_length=100) == 'from a import b as c'

    # Test module names with periods.
    res = parse_code_element('from a.b12.c4 import lib345')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a.b12.c4'),
        import_items=[
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='lib345'), local_name=None)
        ])
    assert res.format(allowed_line_length=100) == 'from a.b12.c4 import lib345'

    # Test multiple imports.
    res = parse_code_element('from lib import a,b as b2, c')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='lib'),
        import_items=[
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='a'), local_name=None),
            AliasedIdentifier(
                orig_identifier=ExprIdentifier(name='b'),
                local_name=ExprIdentifier(name='b2')),
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='c'), local_name=None),
        ])
    assert res.format(
        allowed_line_length=100) == 'from lib import a, b as b2, c'
    # With a short line limit the import list is wrapped in parentheses...
    assert res.format(
        allowed_line_length=20) == 'from lib import (\n a, b as b2, c)'
    # ...and the wrapped form parses back to the same element.
    assert res == parse_code_element('from lib import (\n a, b as b2, c)')

    # The negative cases below must go through parse_code_element: parse_expr rejects any
    # import statement (it is not an expression), so using it here would make these checks
    # pass regardless of how import statements are parsed.

    # Test module with bad identifier (with periods).
    with pytest.raises(ParserError):
        parse_code_element('from a.b import c.d')

    # Test module with bad local name (with periods).
    with pytest.raises(ParserError):
        parse_code_element('from a.b import c as d.d')
def process_calldata(
        calldata_ptr: Expression, calldata_size: Expression, identifiers: IdentifierManager,
        struct_def: StructDefinition, has_range_check_builtin: bool,
        location: Location) -> Tuple[List[CodeElement], ArgList]:
    """
    Generates the code that unpacks the calldata into the members of 'struct_def'.

    Returns a pair (code_elements, arg_list):
      code_elements - code elements that define a reference '__calldata_arg_<name>' for each
        member of 'struct_def', advance '__calldata_ptr' past it, and finally assert that the
        number of consumed calldata cells equals 'calldata_size'.
      arg_list - an ArgList with one entry '<name>=__calldata_arg_<name>' per member,
        suitable to be used as call arguments.

    Supported member types:
      felt - consumes a single calldata cell.
      felt* - an array argument. It must be immediately preceded by a felt member named
        '<name>_len', consumes '<name>_len' cells, and requires has_range_check_builtin=True.
    Any other member type raises PreprocessorError.

    Note: 'identifiers' is currently unused by this function.
    """

    def parse_code_element(code: str, parent_location: ParentLocation):
        # The file name is derived from a hash of the code, so it is deterministic.
        filename = f'autogen/starknet/arg_parser/{hashlib.sha256(code.encode()).hexdigest()}.cairo'
        return parse(
            filename=filename, code=code, code_type='code_element', expected_type=CodeElement,
            parser_context=ParserContext(parent_location=parent_location))

    struct_parent_location = (location, 'While handling calldata of')
    code_elements = [
        parse_code_element(
            f'let __calldata_ptr : felt* = cast({calldata_ptr.format()}, felt*)',
            parent_location=struct_parent_location)
    ]
    args = []

    prev_member: Optional[Tuple[str, MemberDefinition]] = None
    for member_name, member_def in struct_def.members.items():
        member_location = member_def.location
        assert member_location is not None
        member_parent_location = (
            member_location, f"While handling calldata argument '{member_name}'")
        cairo_type = member_def.cairo_type
        if isinstance(cairo_type, TypePointer) and isinstance(cairo_type.pointee, TypeFelt):
            # An array argument: its length is taken from the immediately preceding member.
            has_len = (
                prev_member is not None and prev_member[0] == f'{member_name}_len' and
                isinstance(prev_member[1].cairo_type, TypeFelt))
            if not has_len:
                # NOTE(review): 'preceeded' is a typo, kept as-is since tests may match this
                # message.
                raise PreprocessorError(
                    f'Array argument "{member_name}" must be preceeded by a length argument '
                    f'named "{member_name}_len" of type felt.',
                    location=member_location)
            if not has_range_check_builtin:
                raise PreprocessorError(
                    "The 'range_check' builtin must be declared in the '%builtins' directive "
                    'when using array arguments in external functions.',
                    location=member_location)

            code_element_strs = [
                # Check that the length is non-negative (an empty array is allowed) by
                # writing it to the range-check builtin.
                f'assert [range_check_ptr] = __calldata_arg_{member_name}_len',
                'let range_check_ptr = range_check_ptr + 1',
                # Create the reference.
                f'let __calldata_arg_{member_name} : felt* = __calldata_ptr',
                # Use 'tempvar' instead of 'let' to avoid repeating this computation for the
                # following arguments.
                f'tempvar __calldata_ptr = __calldata_ptr + __calldata_arg_{member_name}_len',
            ]
            code_elements.extend(
                parse_code_element(code_element_str, parent_location=member_parent_location)
                for code_element_str in code_element_strs)
        elif isinstance(cairo_type, TypeFelt):
            code_elements.append(
                parse_code_element(
                    f'let __calldata_arg_{member_name} = [__calldata_ptr]',
                    parent_location=member_parent_location))
            code_elements.append(
                parse_code_element(
                    'let __calldata_ptr = __calldata_ptr + 1',
                    parent_location=member_parent_location))
        else:
            raise PreprocessorError(
                f'Unsupported argument type {cairo_type.format()}.',
                location=cairo_type.location)

        args.append(ExprAssignment(
            identifier=ExprIdentifier(name=member_name, location=member_location),
            expr=ExprIdentifier(name=f'__calldata_arg_{member_name}', location=member_location),
            location=member_location))

        prev_member = member_name, member_def

    # Verify that exactly 'calldata_size' cells were consumed.
    code_elements.append(
        parse_code_element(
            f'let __calldata_actual_size = __calldata_ptr - cast({calldata_ptr.format()}, felt*)',
            parent_location=struct_parent_location))
    code_elements.append(
        parse_code_element(
            f'assert {calldata_size.format()} = __calldata_actual_size',
            parent_location=struct_parent_location))

    # Build a separate Notes object per slot; '[Notes()] * n' would share one mutable object.
    return code_elements, ArgList(
        args=args, notes=[Notes() for _ in range(len(args) + 1)], has_trailing_comma=True,
        location=location)
def test_instruction():
    """
    Checks parsing and formatting of each instruction kind.

    Each section parses a canonical instruction and compares it with the expected AST. It then
    verifies the formatted output and makes sure malformed variants raise ParserError.
    """

    def assert_formats_as(code: str, expected: str):
        # Parsing 'code' and formatting it back must yield 'expected'.
        assert parse_instruction(code).format() == expected

    def assert_rejected(*codes: str):
        # Each snippet must be rejected by the parser.
        for code in codes:
            with pytest.raises(ParserError):
                parse_instruction(code)

    # AssertEq.
    parsed = parse_instruction('[ap] = [fp]; ap++')
    assert parsed == InstructionAst(
        body=AssertEqInstruction(
            a=ExprDeref(addr=ExprReg(reg=Register.AP)),
            b=ExprDeref(addr=ExprReg(reg=Register.FP))),
        inc_ap=True)
    assert parsed.format() == '[ap] = [fp]; ap++'
    assert_formats_as('[ap+5] = [fp]+[ap] - 5', '[ap + 5] = [fp] + [ap] - 5')
    assert_formats_as('[ap+5]+3= [fp]*7;ap ++ ', '[ap + 5] + 3 = [fp] * 7; ap++')

    # Jump.
    parsed = parse_instruction('jmp rel [ap] + x; ap++')
    assert parsed == InstructionAst(
        body=JumpInstruction(
            val=ExprOperator(
                a=ExprDeref(addr=ExprReg(reg=Register.AP)), op='+', b=ExprIdentifier(name='x')),
            relative=True),
        inc_ap=True)
    assert parsed.format() == 'jmp rel [ap] + x; ap++'
    assert_formats_as(' jmp abs[ap]+x', 'jmp abs [ap] + x')
    # Malformed variants.
    assert_rejected('jmp abs', 'jmpabs[ap]')

    # JumpToLabel.
    parsed = parse_instruction('jmp label')
    assert parsed == InstructionAst(
        body=JumpToLabelInstruction(label=ExprIdentifier(name='label'), condition=None),
        inc_ap=False)
    assert parsed.format() == 'jmp label'
    # Malformed variants.
    assert_rejected('jmp [fp]', 'jmp 7')

    # Jnz.
    parsed = parse_instruction('jmp rel [ap] + x if [fp + 3] != 0')
    assert parsed == InstructionAst(
        body=JnzInstruction(
            jump_offset=ExprOperator(
                a=ExprDeref(addr=ExprReg(reg=Register.AP)), op='+', b=ExprIdentifier(name='x')),
            condition=ExprDeref(
                addr=ExprOperator(a=ExprReg(reg=Register.FP), op='+', b=ExprConst(val=3)))),
        inc_ap=False)
    assert parsed.format() == 'jmp rel [ap] + x if [fp + 3] != 0'
    assert_formats_as(' jmp rel 17 if[fp]!=0;ap++', 'jmp rel 17 if [fp] != 0; ap++')
    # Malformed variants.
    assert_rejected('jmprel 17 if x != 0', 'jmp 17 if x')
    # Comparing against anything but 0 must produce an error that mentions '!= 0'.
    with pytest.raises(ParserError, match='!= 0'):
        parse_instruction('jmp rel 17 if x != 2')
    assert_rejected('jmp rel [fp] ifx != 0')

    # Jnz to label.
    parsed = parse_instruction('jmp label if [fp] != 0')
    assert parsed == InstructionAst(
        body=JumpToLabelInstruction(
            label=ExprIdentifier(name='label'),
            condition=ExprDeref(addr=ExprReg(reg=Register.FP))),
        inc_ap=False)
    assert parsed.format() == 'jmp label if [fp] != 0'
    # Malformed variants.
    assert_rejected('jmp [fp] if [fp] != 0', 'jmp 7 if [fp] != 0')

    # Call abs.
    parsed = parse_instruction('call abs [fp] + x')
    assert parsed == InstructionAst(
        body=CallInstruction(
            val=ExprOperator(
                a=ExprDeref(addr=ExprReg(reg=Register.FP)), op='+', b=ExprIdentifier(name='x')),
            relative=False),
        inc_ap=False)
    assert parsed.format() == 'call abs [fp] + x'
    assert_formats_as('call abs 17;ap++', 'call abs 17; ap++')
    # Malformed variants.
    assert_rejected('call abs', 'callabs 7')

    # Call rel.
    parsed = parse_instruction('call rel [ap] + x')
    assert parsed == InstructionAst(
        body=CallInstruction(
            val=ExprOperator(
                a=ExprDeref(addr=ExprReg(reg=Register.AP)), op='+', b=ExprIdentifier(name='x')),
            relative=True),
        inc_ap=False)
    assert parsed.format() == 'call rel [ap] + x'
    assert_formats_as('call rel 17;ap++', 'call rel 17; ap++')
    # Malformed variants.
    assert_rejected('call rel', 'callrel 7')

    # Call label.
    parsed = parse_instruction('call label')
    assert parsed == InstructionAst(
        body=CallLabelInstruction(label=ExprIdentifier(name='label')),
        inc_ap=False)
    assert parsed.format() == 'call label'
    assert_formats_as('call label ;ap++', 'call label; ap++')
    # Malformed variants.
    assert_rejected('call [fp]', 'call 7')

    # Ret.
    parsed = parse_instruction('ret')
    assert parsed == InstructionAst(body=RetInstruction(), inc_ap=False)
    assert parsed.format() == 'ret'

    # AddAp.
    parsed = parse_instruction('ap += [fp] + 2')
    assert parsed == InstructionAst(
        body=AddApInstruction(
            expr=ExprOperator(
                a=ExprDeref(addr=ExprReg(reg=Register.FP)), op='+', b=ExprConst(val=2))),
        inc_ap=False)
    assert parsed.format() == 'ap += [fp] + 2'
    assert_formats_as('ap +=[ fp]+ 2', 'ap += [fp] + 2')
    assert_formats_as('ap +=[ fp]+ 2;ap ++', 'ap += [fp] + 2; ap++')
def create_func_wrapper(self, elm: CodeElementFunction, func_alias_name: str):
    """
    Generates a wrapper that converts between the StarkNet contract ABI and the Cairo calling
    convention, and adds an entry for the function to the ABI.

    The wrapper does the following:
      1. Reads the os_context pointer, the calldata size and the calldata pointer from the
         caller's frame.
      2. Unpacks the calldata into arguments.
      3. Calls the wrapped function.
      4. Returns the os_context pointers followed by (retdata_size, retdata_ptr).

    Arguments:
    elm - the CodeElementFunction of the wrapped function.
    func_alias_name - an alias for the FunctionDefinition in the current scope.

    Raises PreprocessorError if the function declares an implicit argument that is not part of
    the os_context. Raises NotImplementedError if its external decorator is not supported.
    """

    os_context = self.get_os_context()
    func_location = elm.identifier.location
    assert func_location is not None

    def fp_deref(offset: int):
        # Returns the expression '[fp + offset]', with every node located at the function.
        return ExprDeref(
            addr=ExprOperator(
                ExprReg(reg=Register.FP, location=func_location),
                '+',
                ExprConst(offset, location=func_location),
                location=func_location),
            location=func_location)

    # We expect the call stack to look as follows:
    # pointer to os_context struct.
    # calldata size.
    # pointer to the call data array.
    # ret_fp.
    # ret_pc.
    os_context_ptr = fp_deref(-5)
    calldata_size = fp_deref(-4)
    calldata_ptr = fp_deref(-3)

    implicit_arguments = None

    implicit_arguments_identifiers: Dict[str, TypedIdentifier] = {}
    if elm.implicit_arguments is not None:
        args = []
        for typed_identifier in elm.implicit_arguments.identifiers:
            ptr_name = typed_identifier.name
            if ptr_name not in os_context:
                raise PreprocessorError(
                    f"Unexpected implicit argument '{ptr_name}' in an external function.",
                    location=typed_identifier.identifier.location)

            implicit_arguments_identifiers[ptr_name] = typed_identifier

            # Add the assignment expression 'ptr_name = ptr_name' to the implicit arg list.
            args.append(ExprAssignment(
                identifier=typed_identifier.identifier,
                expr=typed_identifier.identifier,
                location=typed_identifier.location,
            ))

        implicit_arguments = ArgList(
            args=args, notes=[], has_trailing_comma=True,
            location=elm.implicit_arguments.location)

    return_args_exprs: List[Expression] = []

    # Create references.
    for ptr_name, index in os_context.items():
        ref_name = self.current_scope + ptr_name

        arg_identifier = implicit_arguments_identifiers.get(ptr_name)
        if arg_identifier is None:
            # The function does not take this pointer as an implicit argument; use felt.
            location: Optional[Location] = func_location
            cairo_type: CairoType = TypeFelt(location=location)
        else:
            location = arg_identifier.location
            cairo_type = self.resolve_type(arg_identifier.get_type())

        # Add a reference of the form
        # 'let ref_name = [cast(os_context_ptr + index, cairo_type*)]'.
        self.add_reference(
            name=ref_name,
            value=ExprDeref(
                addr=ExprCast(
                    ExprOperator(
                        os_context_ptr, '+', ExprConst(index, location=location),
                        location=location),
                    dest_type=TypePointer(pointee=cairo_type, location=cairo_type.location),
                    location=cairo_type.location),
                location=location),
            cairo_type=cairo_type,
            location=location,
            require_future_definition=False)

        # The pointers are returned in os_context order, so each index must equal its
        # position among the returned values.
        assert index == len(return_args_exprs), 'Unexpected index.'
        return_args_exprs.append(ExprIdentifier(name=ptr_name, location=func_location))

    arg_struct_def = self.get_struct_definition(
        name=ScopedName.from_string(func_alias_name) + CodeElementFunction.ARGUMENT_SCOPE,
        location=func_location)

    code_elements, call_args = process_calldata(
        calldata_ptr=calldata_ptr,
        calldata_size=calldata_size,
        identifiers=self.identifiers,
        struct_def=arg_struct_def,
        has_range_check_builtin='range_check_ptr' in os_context,
        location=func_location,
    )

    for code_element in code_elements:
        self.visit(code_element)

    self.visit(CodeElementFuncCall(
        func_call=RvalueFuncCall(
            func_ident=ExprIdentifier(name=func_alias_name, location=func_location),
            arguments=call_args,
            implicit_arguments=implicit_arguments,
            location=func_location)))

    ret_struct_name = ScopedName.from_string(func_alias_name) + CodeElementFunction.RETURN_SCOPE
    ret_struct_type = self.resolve_type(TypeStruct(ret_struct_name, False))
    ret_struct_def = self.get_struct_definition(
        name=ret_struct_name,
        location=func_location)
    # The return struct of the call is referenced at [ap - size].
    ret_struct_expr = create_simple_ref_expr(
        reg=Register.AP, offset=-ret_struct_def.size, cairo_type=ret_struct_type,
        location=func_location)
    self.add_reference(
        name=self.current_scope + 'ret_struct',
        value=ret_struct_expr,
        cairo_type=ret_struct_type,
        require_future_definition=False,
        location=func_location)

    # Add function return values.
    retdata_size, retdata_ptr = self.process_retdata(
        ret_struct_ptr=ExprIdentifier(name='ret_struct'),
        ret_struct_type=ret_struct_type,
        struct_def=ret_struct_def,
        location=func_location,
    )
    return_args_exprs += [retdata_size, retdata_ptr]

    # Push the return values.
    self.push_compound_expressions(
        compound_expressions=[self.simplify_expr_as_felt(expr) for expr in return_args_exprs],
        location=func_location,
    )

    # Add a ret instruction.
    self.visit(CodeElementInstruction(
        instruction=InstructionAst(
            body=RetInstruction(),
            inc_ap=False,
            location=func_location)))

    # Add an entry to the ABI.
    external_decorator = self.get_external_decorator(elm)
    assert external_decorator is not None
    is_view = external_decorator.name == VIEW_DECORATOR

    # The ABI entry type is determined (and the decorator validated) in one place.
    if external_decorator.name == L1_HANDLER_DECORATOR:
        entry_type = 'l1_handler'
    elif external_decorator.name in [EXTERNAL_DECORATOR, VIEW_DECORATOR]:
        entry_type = 'function'
    else:
        raise NotImplementedError(f'Unsupported decorator {external_decorator.name}')

    self.add_abi_entry(
        name=elm.name, arg_struct_def=arg_struct_def, ret_struct_def=ret_struct_def,
        is_view=is_view, entry_type=entry_type)
def visit_ExprIdentifier(self, expr: ExprIdentifier):
    """Returns a copy of 'expr' whose location has been passed through self.location_modifier."""
    identifier_name = expr.name
    new_location = self.location_modifier(expr.location)
    return ExprIdentifier(name=identifier_name, location=new_location)
def identifier_def(self, value, meta):
    """Builds an ExprIdentifier named after the '.value' of the first element of 'value'."""
    first_token = value[0]
    identifier_name = first_token.value
    return ExprIdentifier(name=identifier_name, location=self.meta2loc(meta))
def identifier(self, value, meta):
    """Builds an ExprIdentifier whose name joins the '.value' of each element with '.'."""
    parts = [token.value for token in value]
    dotted_name = '.'.join(parts)
    return ExprIdentifier(name=dotted_name, location=self.meta2loc(meta))