Example #1
0
    def add_name_definition(self,
                            name: ScopedName,
                            identifier_definition: IdentifierDefinition,
                            location,
                            require_future_definition=True):
        """
        Registers 'identifier_definition' under the full name 'name' and records 'location'
        for it in self.identifier_locations.

        Normally 'name' was announced earlier as a FutureIdentifierDefinition, and the new
        definition must be an instance of the announced identifier_type. If there is no
        entry for 'name' at all:
          * with require_future_definition=True, self.handle_missing_future_definition()
            is called;
          * with require_future_definition=False (e.g., temporary variables), the
            definition is added without any check.

        Raises PreprocessorError if 'name' already has a non-future definition, or if the
        definition's type differs from the announced type.
        """
        existing = self.identifiers.get_by_full_name(name)

        if existing is not None:
            if not isinstance(existing, FutureIdentifierDefinition):
                raise PreprocessorError(f"Redefinition of '{name}'.",
                                        location=location)
            expected_type = existing.identifier_type
            if not isinstance(identifier_definition, expected_type):
                raise PreprocessorError(
                    f"Identifier '{name}' expected to be of type "
                    f"'{expected_type.__name__}', not "
                    f"'{type(identifier_definition).__name__}'.",
                    location=location)
        elif require_future_definition:
            self.handle_missing_future_definition(name=name, location=location)

        self.identifiers.add_identifier(name, identifier_definition)
        self.identifier_locations[name] = location
    def handle_struct_definition(
            self, struct_name: ScopedName, code_block: CodeBlock, location):
        """
        Collects the 'member' statements of a struct body and registers the struct
        'struct_name' through self.add_struct_definition().

        Empty lines in 'code_block' are skipped. Raises PreprocessorError for:
          * any other statement that is not a member declaration,
          * a member that carries a modifier (via assert_no_modifier),
          * a member without an explicit type.
        """
        members_list: List[MemberInfo] = []
        for commented_elm in code_block.code_elements:
            member = commented_elm.code_elm
            if isinstance(member, CodeElementEmptyLine):
                continue

            if not isinstance(member, CodeElementMember):
                # Not every code element has a location; fall back to the struct's one.
                raise PreprocessorError(
                    'Unexpected statement inside a struct definition.',
                    location=getattr(member, 'location', location))

            typed_id = member.typed_identifier
            assert_no_modifier(typed_id)
            if typed_id.expr_type is None:
                raise PreprocessorError(
                    'Struct members must be explicitly typed (e.g., member x : felt).',
                    location=typed_id.location)

            member_id = typed_id.identifier
            members_list.append(MemberInfo(
                name=member_id.name,
                cairo_type=typed_id.get_type(),
                location=member_id.location))

        self.add_struct_definition(
            members_list=members_list, struct_name=struct_name, location=location)
Example #3
0
 def get_size(self, cairo_type: CairoType):
     """
     Returns the size (number of memory cells) of a value of type 'cairo_type'.

     felt and pointer types have size 1. A tuple's size is the sum of its members' sizes.
     A struct's size comes from its StructDefinition: it is looked up directly when the
     struct type is fully resolved, and through self.get_struct_size() otherwise.

     Raises PreprocessorError (at the type's location) when looking up a resolved struct
     fails with DefinitionError, and NotImplementedError for any other type.
     """
     if isinstance(cairo_type, (TypeFelt, TypePointer)):
         return 1

     if isinstance(cairo_type, TypeStruct):
         if not cairo_type.is_fully_resolved:
             return self.get_struct_size(struct_name=cairo_type.scope,
                                         location=cairo_type.location)
         try:
             struct_size = get_struct_definition(
                 struct_name=cairo_type.scope,
                 identifier_manager=self.identifiers).size
         except DefinitionError as exc:
             raise PreprocessorError(str(exc), location=cairo_type.location)
         return struct_size

     if isinstance(cairo_type, TypeTuple):
         return sum(map(self.get_size, cairo_type.members))

     raise NotImplementedError(
         f'Type {type(cairo_type).__name__} is not supported.')
Example #4
0
    def get_canonical_struct_name(self, scoped_name: ScopedName,
                                  location: Optional[Location]):
        """
        Returns the canonical name for the struct given by scoped_name in the current
        accessible_scopes.
        This function also works for structs that do not have a StructDefinition yet.

        For example when parsing:
            struct S:
                member a : S*
            end
        We have to lookup S before S is defined in the identifier manager.

        location is used if there is an error.
        """
        result = self.identifiers.search(self.accessible_scopes, scoped_name)
        canonical_name = result.get_canonical_name()
        identifier_def = result.identifier_definition

        identifier_type = identifier_def.TYPE
        if isinstance(identifier_def, FutureIdentifierDefinition):
            identifier_type = identifier_def.identifier_type.TYPE  # type: ignore

        if identifier_type != StructDefinition.TYPE:
            raise PreprocessorError(f"""\
Expected '{scoped_name}' to be a {StructDefinition.TYPE}. Found: '{identifier_type}'.""",
                                    location=location)

        return canonical_name
Example #5
0
    def visit_CodeElementFunction(self, elm: CodeElementFunction):
        storage_var, storage_var_location = is_storage_var(elm)
        if storage_var:
            if self.file_lang != STARKNET_LANG_DIRECTIVE:
                raise PreprocessorError(
                    '@storage_var can only be used in source files that contain the '
                    '"%lang starknet" directive.',
                    location=storage_var_location)
            # Add dummy references and calls that will be visited by the identifier collector
            # and the dependency graph.
            # Those statements will later be replaced by the real implementation.
            addr_func_body = """
let res = 0
call hash2
call normalize_address
"""
            read_func_body = """
let storage_addr = 0
call addr
call storage_read
"""
            write_func_body = """
let storage_addr = 0
call addr
call storage_write
"""
            return generate_storage_var_functions(
                elm,
                addr_func_body=addr_func_body,
                read_func_body=read_func_body,
                write_func_body=write_func_body,
                is_impl=False)

        return elm
    def add_struct_definition(
            self, members_list: List[MemberInfo], struct_name: ScopedName,
            location: Optional[Location]):
        """
        Lays out the members in 'members_list' one after another and registers the
        resulting StructDefinition for 'struct_name' via self.add_name_definition().

        Each member's type is resolved with self.resolve_type(). Offsets start at 0 and
        advance by self.get_size() of each member, and the final offset becomes the struct
        size. Raises PreprocessorError if two members share the same name.
        """
        members: Dict[str, MemberDefinition] = {}
        next_offset = 0
        for info in members_list:
            resolved_type = self.resolve_type(info.cairo_type)

            if info.name in members:
                raise PreprocessorError(
                    f"Redefinition of '{struct_name + info.name}'.",
                    location=info.location)

            members[info.name] = MemberDefinition(
                offset=next_offset, cairo_type=resolved_type, location=info.location)
            next_offset += self.get_size(resolved_type)

        struct_definition = StructDefinition(
            full_name=struct_name,
            members=members,
            size=next_offset,
            location=location,
        )
        self.add_name_definition(struct_name, struct_definition, location=location)
Example #7
0
 def resolve_type(self, cairo_type: CairoType) -> CairoType:
     """
     Returns 'cairo_type' with every struct name replaced by its fully qualified
     (canonical) name.

     Pointer and tuple types are resolved recursively through dataclasses.replace().
     felt types and struct types already marked is_fully_resolved are returned as is.

     Raises PreprocessorError (at the struct type's location) when the struct name cannot
     be resolved (IdentifierError), and NotImplementedError for unsupported types.
     """
     if isinstance(cairo_type, TypeFelt):
         return cairo_type

     if isinstance(cairo_type, TypePointer):
         resolved_pointee = self.resolve_type(cairo_type.pointee)
         return dataclasses.replace(cairo_type, pointee=resolved_pointee)

     if isinstance(cairo_type, TypeStruct):
         if cairo_type.is_fully_resolved:
             return cairo_type
         try:
             canonical_scope = self.get_canonical_struct_name(
                 scoped_name=cairo_type.scope,
                 location=cairo_type.location)
             return dataclasses.replace(
                 cairo_type, scope=canonical_scope, is_fully_resolved=True)
         except IdentifierError as exc:
             raise PreprocessorError(str(exc), location=cairo_type.location)

     if isinstance(cairo_type, TypeTuple):
         resolved_members = [self.resolve_type(member) for member in cairo_type.members]
         return dataclasses.replace(cairo_type, members=resolved_members)

     raise NotImplementedError(
         f'Type {type(cairo_type).__name__} is not supported.')
Example #8
0
def assert_no_modifier(typed_identifier: TypedIdentifier):
    """
    Ensures 'typed_identifier' has no modifier.

    Returns None if the modifier is absent. Otherwise raises PreprocessorError located at
    the modifier, quoting its formatted text.
    """
    modifier = typed_identifier.modifier
    if modifier is None:
        return
    raise PreprocessorError(
        f"Unexpected modifier '{modifier.format()}'.",
        location=modifier.location)
    def visit_BuiltinsDirective(self, directive: BuiltinsDirective):
        """
        Handles a %builtins directive by delegating to the base class, then checks that the
        declared builtins form a subsequence of SUPPORTED_BUILTINS.

        The base-class visit is expected to populate self.builtins (asserted below).
        Raises PreprocessorError if the builtins are not such a subsequence.
        """
        super().visit_BuiltinsDirective(directive)
        declared_builtins = self.builtins
        assert declared_builtins is not None

        if is_subsequence(declared_builtins, SUPPORTED_BUILTINS):
            return
        raise PreprocessorError(
            f'{declared_builtins} is not a subsequence of {SUPPORTED_BUILTINS}.',
            location=directive.location)
 def rewrite_ExprReg(self, expr: ExprReg, sim: SimplicityLevel):
     """
     Rewrites a bare register expression.

     fp is replaced with self.context.get_fp_val(expr.location), which is then rewritten at
     simplicity level 'sim'. A bare ap raises PreprocessorError, because ap may only appear
     in the form [ap + <const>]. Any other register raises NotImplementedError.
     """
     register = expr.reg
     if register is Register.FP:
         fp_value = self.context.get_fp_val(expr.location)
         return self.rewrite(expr=fp_value, sim=sim)

     if register is Register.AP:
         raise PreprocessorError(
             'ap may only be used in an expression of the form [ap + <const>].',
             location=expr.location)

     raise NotImplementedError(f'Unknown register {register}.')
Example #11
0
def get_return_type(elm: CodeElementFunction) -> CairoType:
    """
    Returns the declared (unresolved) type of the single return value of the storage
    variable function 'elm'.

    Raises PreprocessorError unless 'elm' declares exactly one return value. The error
    points at the return list when one exists, and at the function name otherwise.
    """
    error_message = 'Storage variables must return exactly one value.'
    returns = elm.returns
    if returns is None:
        raise PreprocessorError(error_message, location=elm.identifier.location)
    if len(returns.identifiers) != 1:
        raise PreprocessorError(error_message, location=returns.location)
    return returns.identifiers[0].get_type()
 def search_identifier(
         self, name: str, location: Optional[Location]) -> Optional[IdentifierDefinition]:
     """
     Looks up 'name' (parsed with ScopedName.from_string) from the current
     accessible_scopes and returns the IdentifierDefinition produced by
     resolve_search_result().

     Any IdentifierError is re-raised as a PreprocessorError at 'location'.
     """
     try:
         scoped_name = ScopedName.from_string(name)
         search_result = self.identifiers.search(self.accessible_scopes, scoped_name)
         return resolve_search_result(search_result, identifiers=self.identifiers)
     except IdentifierError as exc:
         raise PreprocessorError(str(exc), location=location)
    def add_identifier(self, name: ScopedName,
                       identifier_definition: IdentifierDefinition,
                       location: Optional[Location]):
        """
        Adds 'identifier_definition' under the full name 'name'. 'location' is used for
        error reporting.

        A name that already exists may only be declared again when both the existing and
        the new definitions are FutureIdentifierDefinitions of ReferenceDefinition
        (reference rebinding). Any other redefinition raises PreprocessorError.
        """
        existing_definition = self.identifiers.get_by_full_name(name)
        if existing_definition is not None:
            is_reference_rebinding = (
                isinstance(existing_definition, FutureIdentifierDefinition) and
                isinstance(identifier_definition, FutureIdentifierDefinition) and
                existing_definition.identifier_type == ReferenceDefinition and
                identifier_definition.identifier_type == ReferenceDefinition)
            if not is_reference_rebinding:
                raise PreprocessorError(f"Redefinition of '{name}'.",
                                        location=location)

        self.identifiers.add_identifier(name, identifier_definition)
    def process_retdata(
            self, ret_struct_ptr: Expression, ret_struct_type: CairoType,
            struct_def: StructDefinition,
            location: Optional[Location]) -> Tuple[Expression, Expression]:
        """
        Processes the return values and returns (retdata_size, retdata_ptr).

        The generated code:
          * defines the reference 'retdata_ptr' (type felt*) as [ap] in the current scope;
          * emits the hint 'memory[ap] = segments.add()' followed by 'ap += 1';
          * asserts [cast(retdata_ptr, ret_struct_type*)] == ret_struct_ptr.

        Returns the pair (ExprConst(struct_def.size), ExprIdentifier('retdata_ptr')).
        Raises PreprocessorError if any member of struct_def is not a felt.
        """

        # Verify all of the return types are felts (member names are not needed here).
        for member_def in struct_def.members.values():
            cairo_type = member_def.cairo_type
            if not isinstance(cairo_type, TypeFelt):
                raise PreprocessorError(
                    f'Unsupported argument type {cairo_type.format()}.',
                    location=cairo_type.location)

        self.add_reference(
            name=self.current_scope + 'retdata_ptr',
            value=ExprDeref(
                addr=ExprReg(reg=Register.AP),
                location=location,
            ),
            cairo_type=TypePointer(TypeFelt()),
            require_future_definition=False,
            location=location)

        self.visit(CodeElementHint(
            hint=ExprHint(
                hint_code='memory[ap] = segments.add()',
                n_prefix_newlines=0,
                location=location,
            ),
            location=location,
        ))

        # Skip check of hint whitelist as it fails before the workaround below.
        super().visit_CodeElementInstruction(CodeElementInstruction(InstructionAst(
            body=AddApInstruction(ExprConst(1)),
            inc_ap=False,
            location=location)))

        # Remove the references from the last instruction's flow tracking as they are
        # not needed by the hint and they cause the hint whitelist to fail.
        assert len(self.instructions[-1].hints) == 1
        hint, hint_flow_tracking_data = self.instructions[-1].hints[0]
        self.instructions[-1].hints[0] = hint, dataclasses.replace(
            hint_flow_tracking_data, reference_ids={})
        self.visit(CodeElementCompoundAssertEq(
            ExprDeref(
                ExprCast(ExprIdentifier('retdata_ptr'), TypePointer(ret_struct_type))),
            ret_struct_ptr))

        return (ExprConst(struct_def.size), ExprIdentifier('retdata_ptr'))
    def visit_CodeElementFunction(self, elm: CodeElementFunction):
        """
        Visits the function 'elm' and, if it carries an external decorator, generates a
        wrapper function for it.

        The wrapper is registered as WRAPPER_SCOPE + elm.name with a FunctionDefinition at
        the current pc that carries the original decorators. Inside the wrapper's scope, the
        alias '__wrapped_func' points to the original function, and
        self.create_func_wrapper() emits the wrapper body.

        Raises PreprocessorError if an external decorator is used in a file without the
        "%lang starknet" directive. L1 handlers are also checked by
        self.validate_l1_handler_signature().
        """
        super().visit_CodeElementFunction(elm)

        external_decorator = self.get_external_decorator(elm)
        if external_decorator is None:
            return

        if self.file_lang != STARKNET_LANG_DIRECTIVE:
            raise PreprocessorError(
                'External decorators can only be used in source files that contain the '
                '"%lang starknet" directive.',
                location=external_decorator.location)

        location = elm.identifier.location

        # Retrieve the canonical name of the function before switching scopes.
        _, func_canonical_name = self.get_label(elm.name, location=location)
        assert func_canonical_name is not None

        if external_decorator.name == L1_HANDLER_DECORATOR:
            self.validate_l1_handler_signature(elm)

        # The wrapper starts with revoked flow tracking and an empty set of reference states.
        self.flow_tracking.revoke()
        with self.scoped(WRAPPER_SCOPE, parent=elm), self.set_reference_states({}):
            current_wrapper_scope = self.current_scope + elm.name

            self.add_name_definition(
                current_wrapper_scope,
                FunctionDefinition(  # type: ignore
                    pc=self.current_pc,
                    decorators=[identifier.name for identifier in elm.decorators],
                ),
                location=location,
                require_future_definition=False)

            with self.scoped(current_wrapper_scope, parent=elm):
                # Generate an alias that will allow us to call the original function.
                func_alias_name = '__wrapped_func'
                alias_canonical_name = current_wrapper_scope + func_alias_name
                self.add_future_definition(
                    name=alias_canonical_name,
                    future_definition=FutureIdentifierDefinition(
                        identifier_type=AliasDefinition),
                )

                self.add_name_definition(
                    name=alias_canonical_name,
                    identifier_definition=AliasDefinition(destination=func_canonical_name),
                    location=location)

                self.create_func_wrapper(elm=elm, func_alias_name=func_alias_name)
Example #16
0
    def get_struct_definition(
            self, name: ScopedName,
            location: Optional[Location]) -> StructDefinition:
        """
        Returns the struct definition that corresponds to the given identifier.
        location is used if there is an error.
        """

        try:
            res = self.identifiers.search(
                accessible_scopes=self.accessible_scopes, name=name)
            res.assert_fully_parsed()
        except IdentifierError as exc:
            raise PreprocessorError(str(exc), location=location)

        struct_def = res.identifier_definition
        if not isinstance(struct_def, StructDefinition):
            raise PreprocessorError(f"""\
Expected '{res.canonical_name}' to be a {StructDefinition.TYPE}. Found: '{struct_def.TYPE}'.""",
                                    location=location)

        return struct_def
    def validate_l1_handler_signature(self, elm: CodeElementFunction):
        """
        Checks that 'elm' has a valid L1 handler signature:
          * the first argument exists and is named 'from_address';
          * that argument has type felt;
          * the function declares no return values.

        Raises PreprocessorError for the first rule that is violated.
        """
        first_arg_error = "The first argument of an L1 handler must be named 'from_address'."
        arguments = elm.arguments.identifiers
        if len(arguments) == 0:
            # An empty argument list has no location so we point to the identifier.
            raise PreprocessorError(first_arg_error, location=elm.identifier.location)

        first_arg = arguments[0]
        if first_arg.name != 'from_address':
            raise PreprocessorError(first_arg_error, location=first_arg.location)

        first_arg_type = first_arg.get_type()
        if not isinstance(first_arg_type, TypeFelt):
            raise PreprocessorError(
                "The type of 'from_address' must be felt.",
                location=first_arg_type.location)

        if elm.returns is None:
            return
        raise PreprocessorError(
            'An L1 handler can not have a return value.',
            location=elm.returns.location)
    def add_identifier(
            self, name: ScopedName, location: Optional[Location], is_resolved: bool = False):
        """
        Records 'name' in self.visited_identifiers as an identifier referenced by the
        function currently being visited.

        Names whose last path component is '_' are ignored. Unless 'is_resolved' is True,
        'name' is first resolved to its canonical name from the current accessible_scopes.
        A MissingIdentifierError from that lookup is re-raised as PreprocessorError at
        'location'. The lookup happens even when there is no current function, but nothing
        is recorded in that case.
        """
        if name.path[-1] == '_':
            return

        canonical_name = name
        if not is_resolved:
            try:
                canonical_name = self.identifiers.search(
                    accessible_scopes=self.accessible_scopes, name=name).canonical_name
            except MissingIdentifierError as e:
                raise PreprocessorError(str(e), location=location)

        current_function = self.current_function
        if current_function is None:
            return
        self.visited_identifiers.setdefault(current_function, []).append(canonical_name)
Example #19
0
    def visit_CodeElementImport(self, elm: CodeElementImport):
        """
        Adds an alias named elm.identifier.name in the current scope. The alias points to
        the full name elm.path.name + elm.orig_identifier.name.

        Raises PreprocessorError (at the original identifier) if that full name is not a
        known identifier.
        """
        destination = ScopedName.from_string(elm.path.name) + elm.orig_identifier.name

        # The imported name must already be known to the identifier manager.
        if self.identifiers.get_by_full_name(destination) is None:
            raise PreprocessorError(
                f"Scope '{elm.path.name}' does not include identifier "
                f"'{elm.orig_identifier.name}'.",
                location=elm.orig_identifier.location)

        local_name = elm.identifier
        self.add_identifier(
            name=self.current_scope + local_name.name,
            identifier_definition=AliasDefinition(destination=destination),
            location=local_name.location)
        def handle_function_arguments(
                identifier_list: Optional[IdentifierList],
                struct_name: ScopedName):
            """
            Passes 'identifier_list' to handle_struct_def() (from the enclosing scope) under
            'struct_name'. Then declares each argument as a future ReferenceDefinition in
            function_scope, so the arguments can be accessed directly inside the function.

            Only the handle_struct_def() call is made when identifier_list is None.
            Raises PreprocessorError if an argument uses the reserved name N_LOCALS_CONSTANT.
            """
            handle_struct_def(identifier_list=identifier_list,
                              struct_name=struct_name)

            arguments = () if identifier_list is None else identifier_list.identifiers
            for argument in arguments:
                if argument.name == N_LOCALS_CONSTANT:
                    raise PreprocessorError(
                        f"The name '{N_LOCALS_CONSTANT}' is reserved and cannot be used as an "
                        'argument name.',
                        location=argument.location)
                self.add_future_identifier(
                    function_scope + argument.name, ReferenceDefinition, argument.location)
    def visit_CodeElementInstruction(self, elm: CodeElementInstruction):
        if self.hint_whitelist is not None:
            for hint, flow_tracking_data in self.next_instruction_hints:
                try:
                    self.hint_whitelist.verify_hint_secure(
                        hint=CairoHint(
                            code=hint.hint_code,
                            accessible_scopes=self.accessible_scopes,
                            flow_tracking_data=flow_tracking_data,
                        ),
                        reference_manager=self.flow_tracking.reference_manager)
                except InsecureHintError:
                    raise PreprocessorError(
                        """\
Hint is not whitelisted.
This may indicate that this library function cannot be used in StarkNet contracts.""",
                        location=hint.location)

        super().visit_CodeElementInstruction(elm)
Example #22
0
    def visit_CodeElementImport(self, elm: CodeElementImport):
        """
        For each item of the import statement, adds an alias in the current scope, named
        after the item's local identifier. The alias points to
        elm.path.name + the item's original identifier.

        The destination must be a known identifier or an existing scope. Otherwise a
        PreprocessorError is raised at the item's original identifier.
        """
        for item in elm.import_items:
            original_name = item.orig_identifier.name
            destination = ScopedName.from_string(elm.path.name) + original_name

            if self.identifiers.get_by_full_name(destination) is None:
                # Not an identifier; it is still valid if it names a scope.
                try:
                    self.identifiers.get_scope(destination)
                except IdentifierError:
                    raise PreprocessorError(
                        f"Cannot import '{item.orig_identifier.name}' "
                        f"from '{elm.path.name}'.",
                        location=item.orig_identifier.location)

            alias = AliasDefinition(destination=destination)
            self.add_identifier(
                name=self.current_scope + item.identifier.name,
                identifier_definition=alias,
                location=item.identifier.location)
Example #23
0
    def visit_CodeElementFunction(self, elm: CodeElementFunction):
        """
        Declares future identifiers for the function 'elm', then lets the base class visit
        the function's code block.

        Declared names, relative to function_scope = current_scope + elm.name:
          * function_scope itself, as a LabelDefinition;
          * ARGUMENT_SCOPE.SIZE_CONSTANT and RETURN_SCOPE.SIZE_CONSTANT, as ConstDefinitions;
          * ARGUMENT_SCOPE.<arg> as a MemberDefinition, plus <arg> as a ReferenceDefinition,
            for every argument;
          * RETURN_SCOPE.<ret> as a MemberDefinition, for every return value;
          * N_LOCALS_CONSTANT, as a ConstDefinition.

        Raises PreprocessorError if an argument uses the reserved name N_LOCALS_CONSTANT.
        """
        function_scope = self.current_scope + elm.name
        args_scope = function_scope + CodeElementFunction.ARGUMENT_SCOPE
        rets_scope = function_scope + CodeElementFunction.RETURN_SCOPE
        function_location = elm.identifier.location

        header_identifiers = [
            (function_scope, LabelDefinition),
            (args_scope + SIZE_CONSTANT, ConstDefinition),
            (rets_scope + SIZE_CONSTANT, ConstDefinition),
        ]
        for scoped_name, definition_type in header_identifiers:
            self.add_future_identifier(scoped_name, definition_type, function_location)

        for arg in elm.arguments.identifiers:
            if arg.name == N_LOCALS_CONSTANT:
                raise PreprocessorError(
                    f"The name '{N_LOCALS_CONSTANT}' is reserved and cannot be used as an "
                    'argument name.',
                    location=arg.location)
            self.add_future_identifier(args_scope + arg.name, MemberDefinition, arg.location)
            # Within a function, arguments are also accessible directly.
            self.add_future_identifier(
                function_scope + arg.name, ReferenceDefinition, arg.location)

        return_ids = () if elm.returns is None else elm.returns.identifiers
        for ret in return_ids:
            self.add_future_identifier(rets_scope + ret.name, MemberDefinition, ret.location)

        # Add SIZEOF_LOCALS for current block at identifier definition location if available.
        self.add_future_identifier(
            function_scope + N_LOCALS_CONSTANT, ConstDefinition, function_location)
        super().visit_CodeElementFunction(elm)
 def visit_LangDirective(self, directive: LangDirective):
     """
     Accepts only the %lang directive named STARKNET_LANG_DIRECTIVE.

     Raises PreprocessorError at the directive for any other language name.
     """
     if directive.name != STARKNET_LANG_DIRECTIVE:
         raise PreprocessorError(
             'Unsupported %lang directive. Are you using the correct compiler?',
             location=directive.location,
         )
Example #25
0
def process_storage_var(visitor: IdentifierAwareVisitor,
                        elm: CodeElementFunction):
    """
    Validates a @storage_var declaration and generates the Cairo source of its helper functions.

    The declaration must have an empty body, no implicit arguments, no decorators other than
    @STORAGE_VAR_DECORATOR, only felt arguments, and a felts-only return type whose size does not
    exceed MAX_STORAGE_ITEM_SIZE. These are checked in that order, and a PreprocessorError is
    raised for the first violation.

    Three function bodies are produced and handed to generate_storage_var_functions
    (with is_impl=True):
      * addr  - starts from storage_var_name_to_base_addr(<name>), folds every argument in with
                hash2 and, when there is at least one argument, applies normalize_address.
      * read  - reads var_size consecutive cells starting at addr(args), copies the implicit
                pointers and the read values into a contiguous run of tempvars, and returns them
                cast to the declared return type.
      * write - writes var_size consecutive cells of 'value' starting at addr(args).

    Returns whatever generate_storage_var_functions returns.
    """
    # The body of a storage variable may only contain empty lines.
    for commented_code_elm in elm.code_block.code_elements:
        body_elm = commented_code_elm.code_elm
        if isinstance(body_elm, CodeElementEmptyLine):
            continue
        # Not every code element carries a location; fall back to the variable's name.
        raise PreprocessorError(
            'Storage variables must have an empty body.',
            location=getattr(body_elm, 'location', elm.identifier.location))

    if elm.implicit_arguments is not None:
        raise PreprocessorError(
            'Storage variables must have no implicit arguments.',
            location=elm.implicit_arguments.location)

    foreign_decorators = [
        decorator for decorator in elm.decorators if decorator.name != STORAGE_VAR_DECORATOR]
    if foreign_decorators:
        raise PreprocessorError(
            'Storage variables must have no decorators in addition to '
            f'@{STORAGE_VAR_DECORATOR}.',
            location=foreign_decorators[0].location)

    for typed_arg in elm.arguments.identifiers:
        typed_arg_type = typed_arg.get_type()
        if not isinstance(typed_arg_type, TypeFelt):
            raise PreprocessorError(
                'Only felt arguments are supported in storage variables.',
                location=typed_arg_type.location)

    unresolved_return_type = get_return_type(elm=elm)
    return_type = visitor.resolve_type(unresolved_return_type)
    # Return-type errors point at the 'returns' clause, or at the name when there is none.
    return_location = (
        elm.identifier.location if elm.returns is None else elm.returns.location)
    if not check_felts_only_type(cairo_type=return_type,
                                 identifier_manager=visitor.identifiers):
        raise PreprocessorError(
            'The return type of storage variables must consist of felts.',
            location=return_location)
    var_size = visitor.get_size(return_type)
    if var_size > MAX_STORAGE_ITEM_SIZE:
        raise PreprocessorError(
            f'The storage variable size ({var_size}) exceeds the maximum value '
            f'({MAX_STORAGE_ITEM_SIZE}).',
            location=return_location)

    arg_names = [typed_arg.identifier.name for typed_arg in elm.arguments.identifiers]
    call_args = ', '.join(arg_names)
    cell_indices = range(var_size)

    base_addr = storage_var_name_to_base_addr(elm.identifier.name)
    addr_func_body = ''.join([
        f'let res = {base_addr}\n',
        *(f'let (res) = hash2{{hash_ptr=pedersen_ptr}}(res, {name})\n' for name in arg_names),
        *(['let (res) = normalize_address(addr=res)\n'] if arg_names else []),
        'return (res=res)\n',
    ])

    return_ptr_type = TypePointer(pointee=unresolved_return_type)
    read_func_body = ''.join([
        f'let (storage_addr) = addr({call_args})\n',
        *(f'let (__storage_var_temp{i}) = storage_read(address=storage_addr + {i})\n'
          for i in cell_indices),
        # Copy the return implicit args and the return values to a contiguous segment.
        '\n',
        'tempvar storage_ptr = storage_ptr\n',
        'tempvar range_check_ptr = range_check_ptr\n',
        'tempvar pedersen_ptr = pedersen_ptr\n',
        *(f'tempvar __storage_var_temp{i} : felt = __storage_var_temp{i}\n'
          for i in cell_indices),
        f'return ([cast(&__storage_var_temp0, {return_ptr_type.format()})])',
    ])

    write_func_body = ''.join([
        f'let (storage_addr) = addr({call_args})\n',
        *(f'storage_write(address=storage_addr + {i}, value=[cast(&value, felt) + {i}])\n'
          for i in cell_indices),
        'return ()\n',
    ])

    return generate_storage_var_functions(
        elm,
        addr_func_body=addr_func_body,
        read_func_body=read_func_body,
        write_func_body=write_func_body,
        is_impl=True)
Example #26
0
 def handle_missing_future_definition(self, name: ScopedName, location):
     """
     Fails for an identifier that has no registered future definition.

     Always raises PreprocessorError, pointing at 'location'.
     """
     message = f"Identifier '{name}' not found by IdentifierCollector."
     raise PreprocessorError(message, location=location)
    def visit_CodeElementFunction(self, elm: CodeElementFunction):
        """
        Registers the future identifiers introduced by a function or struct declaration, then
        visits the declaration through the parent visitor.

        A struct declaration (element_type == 'struct') only registers the struct name, and its
        body is not visited. For a function, the following are registered, in this order:
          1. the argument struct, and every argument as a reference directly under the function
             scope;
          2. the implicit-argument struct, and every implicit argument in the same way;
          3. the return-value struct;
          4. the function label;
          5. the N_LOCALS_CONSTANT constant.

        Raises PreprocessorError when an argument or implicit argument is named
        N_LOCALS_CONSTANT, or when an argument or return value reuses an implicit argument's
        name.
        """
        func_scope = self.current_scope + elm.name
        func_location = elm.identifier.location

        if elm.element_type == 'struct':
            self.add_future_identifier(func_scope, StructDefinition, func_location)
            return

        def declare_member_struct(members: Optional[IdentifierList], struct_scope: ScopedName):
            # Prefer the member list's own location; an absent list falls back to the name.
            self.add_future_identifier(
                name=struct_scope,
                identifier_type=StructDefinition,
                location=func_location if members is None else members.location)

        def declare_arguments(members: Optional[IdentifierList], struct_scope: ScopedName):
            declare_member_struct(members, struct_scope)
            if members is None:
                return
            for member in members.identifiers:
                if member.name == N_LOCALS_CONSTANT:
                    raise PreprocessorError(
                        f"The name '{N_LOCALS_CONSTANT}' is reserved and cannot be used as an "
                        'argument name.',
                        location=member.location)
                # Inside the function body, arguments are reachable without any prefix.
                self.add_future_identifier(func_scope + member.name,
                                           ReferenceDefinition,
                                           member.location)

        declare_arguments(elm.arguments, func_scope + CodeElementFunction.ARGUMENT_SCOPE)
        declare_arguments(
            elm.implicit_arguments, func_scope + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE)
        declare_member_struct(elm.returns, func_scope + CodeElementFunction.RETURN_SCOPE)

        # Explicit arguments and return values must not shadow an implicit argument.
        if elm.implicit_arguments is not None:
            implicit_names = {member.name for member in elm.implicit_arguments.identifiers}
            explicit_members = [
                *elm.arguments.identifiers,
                *(elm.returns.identifiers if elm.returns is not None else ()),
            ]
            for member in explicit_members:
                if member.name in implicit_names:
                    raise PreprocessorError(
                        'Arguments and return values cannot have the same name of an implicit '
                        'argument.',
                        location=member.location)

        self.add_future_identifier(func_scope, LabelDefinition, func_location)

        # Add SIZEOF_LOCALS for current block at identifier definition location if available.
        self.add_future_identifier(func_scope + N_LOCALS_CONSTANT,
                                   ConstDefinition, func_location)
        super().visit_CodeElementFunction(elm)
    def create_func_wrapper(self, elm: CodeElementFunction, func_alias_name: str):
        """
        Generates a wrapper that converts between the StarkNet contract ABI and the
        Cairo calling convention.

        The generated code:
          * reads the OS context pointer, the calldata size and the calldata pointer from the
            caller's frame;
          * exposes each OS-context entry as a reference;
          * decodes the calldata into the wrapped function's arguments (see process_calldata);
          * calls the function;
          * pushes the OS-context references followed by the retdata size and pointer, and
            returns.
        Finally, an ABI entry is recorded for the function.

        Arguments:
        elm - the CodeElementFunction of the wrapped function.
        func_alias_name - an alias for the FunctionDefinition in the current scope.

        Raises PreprocessorError for an implicit argument that is not part of the OS context,
        and NotImplementedError if the external decorator is not L1_HANDLER_DECORATOR,
        EXTERNAL_DECORATOR or VIEW_DECORATOR.
        """

        os_context = self.get_os_context()

        func_location = elm.identifier.location
        assert func_location is not None

        # We expect the call stack to look as follows:
        # pointer to os_context struct.
        # calldata size.
        # pointer to the call data array.
        # ret_fp.
        # ret_pc.
        os_context_ptr = ExprDeref(
            addr=ExprOperator(
                ExprReg(reg=Register.FP, location=func_location),
                '+',
                ExprConst(-5, location=func_location),
                location=func_location),
            location=func_location)

        calldata_size = ExprDeref(
            addr=ExprOperator(
                ExprReg(reg=Register.FP, location=func_location),
                '+',
                ExprConst(-4, location=func_location),
                location=func_location),
            location=func_location)

        calldata_ptr = ExprDeref(
            addr=ExprOperator(
                ExprReg(reg=Register.FP, location=func_location),
                '+',
                ExprConst(-3, location=func_location),
                location=func_location),
            location=func_location)

        implicit_arguments = None

        implicit_arguments_identifiers: Dict[str, TypedIdentifier] = {}
        if elm.implicit_arguments is not None:
            args = []
            for typed_identifier in elm.implicit_arguments.identifiers:
                ptr_name = typed_identifier.name
                if ptr_name not in os_context:
                    raise PreprocessorError(
                        f"Unexpected implicit argument '{ptr_name}' in an external function.",
                        location=typed_identifier.identifier.location)

                implicit_arguments_identifiers[ptr_name] = typed_identifier

                # Add the assignment expression 'ptr_name = ptr_name' to the implicit arg list.
                args.append(ExprAssignment(
                    identifier=typed_identifier.identifier,
                    expr=typed_identifier.identifier,
                    location=typed_identifier.location,
                ))

            implicit_arguments = ArgList(
                args=args, notes=[], has_trailing_comma=True,
                location=elm.implicit_arguments.location)

        return_args_exprs: List[Expression] = []

        # Create references.
        for ptr_name, index in os_context.items():
            ref_name = self.current_scope + ptr_name

            # OS-context entries not declared as implicit arguments default to felt.
            arg_identifier = implicit_arguments_identifiers.get(ptr_name)
            if arg_identifier is None:
                location: Optional[Location] = func_location
                cairo_type: CairoType = TypeFelt(location=location)
            else:
                location = arg_identifier.location
                cairo_type = self.resolve_type(arg_identifier.get_type())

            # Add a reference of the form
            # 'let ref_name = [cast(os_context_ptr + index, cairo_type*)]'.
            self.add_reference(
                name=ref_name,
                value=ExprDeref(
                    addr=ExprCast(
                        ExprOperator(
                            os_context_ptr, '+', ExprConst(index, location=location),
                            location=location),
                        dest_type=TypePointer(pointee=cairo_type, location=cairo_type.location),
                        location=cairo_type.location),
                    location=location),
                cairo_type=cairo_type,
                location=location,
                require_future_definition=False)

            # The OS-context entries are returned in index order, so indices must be dense.
            assert index == len(return_args_exprs), 'Unexpected index.'

            return_args_exprs.append(ExprIdentifier(name=ptr_name, location=func_location))

        arg_struct_def = self.get_struct_definition(
            name=ScopedName.from_string(func_alias_name) + CodeElementFunction.ARGUMENT_SCOPE,
            location=func_location)

        # Decode the calldata into the arguments of the wrapped function.
        code_elements, call_args = process_calldata(
            calldata_ptr=calldata_ptr,
            calldata_size=calldata_size,
            identifiers=self.identifiers,
            struct_def=arg_struct_def,
            has_range_check_builtin='range_check_ptr' in os_context,
            location=func_location,
        )

        for code_element in code_elements:
            self.visit(code_element)

        self.visit(CodeElementFuncCall(
            func_call=RvalueFuncCall(
                func_ident=ExprIdentifier(name=func_alias_name, location=func_location),
                arguments=call_args,
                implicit_arguments=implicit_arguments,
                location=func_location)))

        # The wrapped function's return struct sits just below ap after the call.
        ret_struct_name = ScopedName.from_string(func_alias_name) + CodeElementFunction.RETURN_SCOPE
        ret_struct_type = self.resolve_type(TypeStruct(ret_struct_name, False))
        ret_struct_def = self.get_struct_definition(
            name=ret_struct_name,
            location=func_location)
        ret_struct_expr = create_simple_ref_expr(
            reg=Register.AP, offset=-ret_struct_def.size, cairo_type=ret_struct_type,
            location=func_location)
        self.add_reference(
            name=self.current_scope + 'ret_struct',
            value=ret_struct_expr,
            cairo_type=ret_struct_type,
            require_future_definition=False,
            location=func_location)

        # Add function return values.
        retdata_size, retdata_ptr = self.process_retdata(
            ret_struct_ptr=ExprIdentifier(name='ret_struct'),
            ret_struct_type=ret_struct_type, struct_def=ret_struct_def,
            location=func_location,
        )
        return_args_exprs += [retdata_size, retdata_ptr]

        # Push the return values.
        self.push_compound_expressions(
            compound_expressions=[self.simplify_expr_as_felt(expr) for expr in return_args_exprs],
            location=func_location,
        )

        # Add a ret instruction.
        self.visit(CodeElementInstruction(
            instruction=InstructionAst(
                body=RetInstruction(),
                inc_ap=False,
                location=func_location)))

        # Add an entry to the ABI.
        external_decorator = self.get_external_decorator(elm)
        assert external_decorator is not None
        # Compare against the decorator constant, exactly like the entry-type dispatch below.
        is_view = external_decorator.name == VIEW_DECORATOR

        # Map the decorator to the ABI entry type; any other decorator is unsupported.
        if external_decorator.name == L1_HANDLER_DECORATOR:
            entry_type = 'l1_handler'
        elif external_decorator.name in [EXTERNAL_DECORATOR, VIEW_DECORATOR]:
            entry_type = 'function'
        else:
            raise NotImplementedError(f'Unsupported decorator {external_decorator.name}')

        self.add_abi_entry(
            name=elm.name, arg_struct_def=arg_struct_def, ret_struct_def=ret_struct_def,
            is_view=is_view, entry_type=entry_type)
 def rewrite_ExprPow(self, expr: 'ExprPow', sim: SimplicityLevel):
     """
     Rejects a power ('**') expression handed to this rewriter.

     Always raises PreprocessorError at the expression's location, stating that '**' is only
     supported for constant values. 'sim' is not used.
     """
     # The annotation is a string forward reference so no extra import is needed at def time.
     raise PreprocessorError(
         "Operator '**' is only supported for constant values.",
         location=expr.location)
def process_calldata(calldata_ptr: Expression, calldata_size: Expression,
                     identifiers: IdentifierManager,
                     struct_def: StructDefinition,
                     has_range_check_builtin: bool,
                     location: Location) -> Tuple[List[CodeElement], ArgList]:
    """
    Processes the calldata.

    Returns a pair (code_elements, arg_list):
      * code_elements decode the felts pointed to by 'calldata_ptr' into one reference per
        member of 'struct_def' (named '__calldata_arg_<member>'), then assert that exactly
        'calldata_size' felts were consumed;
      * arg_list passes each of those references as the keyword argument '<member>'.

    Only felt and felt* members are supported. A felt* member '<name>' must directly follow a
    felt member named '<name>_len' holding its length, and requires the range_check builtin
    ('has_range_check_builtin'). Violations raise PreprocessorError.

    'identifiers' is not used by the current implementation.
    """
    def parse_code_element(code: str, parent_location: ParentLocation):
        # The autogen filename is derived from the code text, so identical code maps to the
        # same virtual file.
        filename = f'autogen/starknet/arg_parser/{hashlib.sha256(code.encode()).hexdigest()}.cairo'
        return parse(
            filename=filename,
            code=code,
            code_type='code_element',
            expected_type=CodeElement,
            parser_context=ParserContext(parent_location=parent_location))

    struct_parent_location = (location, 'While handling calldata of')
    code_elements = [
        parse_code_element(
            f'let __calldata_ptr : felt* = cast({calldata_ptr.format()}, felt*)',
            parent_location=struct_parent_location)
    ]
    args = []
    prev_member: Optional[Tuple[str, MemberDefinition]] = None
    for member_name, member_def in struct_def.members.items():
        member_location = member_def.location
        assert member_location is not None
        member_parent_location = (
            member_location,
            f"While handling calldata argument '{member_name}'")
        cairo_type = member_def.cairo_type
        if isinstance(cairo_type, TypePointer) and isinstance(
                cairo_type.pointee, TypeFelt):
            # An array must be immediately preceded by its felt length member.
            has_len = (
                prev_member is not None and prev_member[0] == f'{member_name}_len' and
                isinstance(prev_member[1].cairo_type, TypeFelt))
            if not has_len:
                raise PreprocessorError(
                    f'Array argument "{member_name}" must be preceeded by a length argument '
                    f'named "{member_name}_len" of type felt.',
                    location=member_location)
            if not has_range_check_builtin:
                raise PreprocessorError(
                    "The 'range_check' builtin must be declared in the '%builtins' directive "
                    'when using array arguments in external functions.',
                    location=member_location)

            code_element_strs = [
                # Check that the length is non-negative, using the range_check builtin.
                f'assert [range_check_ptr] = __calldata_arg_{member_name}_len',
                'let range_check_ptr = range_check_ptr + 1',
                # Create the reference.
                f'let __calldata_arg_{member_name} : felt* = __calldata_ptr',
                # Use 'tempvar' instead of 'let' to avoid repeating this computation for the
                # following arguments.
                f'tempvar __calldata_ptr = __calldata_ptr + __calldata_arg_{member_name}_len',
            ]
            for code_element_str in code_element_strs:
                code_elements.append(
                    parse_code_element(code_element_str,
                                       parent_location=member_parent_location))
        elif isinstance(cairo_type, TypeFelt):
            code_elements.append(
                parse_code_element(
                    f'let __calldata_arg_{member_name} = [__calldata_ptr]',
                    parent_location=member_parent_location))
            code_elements.append(
                parse_code_element('let __calldata_ptr = __calldata_ptr + 1',
                                   parent_location=member_parent_location))
        else:
            raise PreprocessorError(
                f'Unsupported argument type {cairo_type.format()}.',
                location=cairo_type.location)

        args.append(
            ExprAssignment(identifier=ExprIdentifier(name=member_name,
                                                     location=member_location),
                           expr=ExprIdentifier(
                               name=f'__calldata_arg_{member_name}',
                               location=member_location),
                           location=member_location))

        prev_member = member_name, member_def

    # Verify that the decoding consumed exactly the declared calldata size.
    code_elements.append(
        parse_code_element(
            f'let __calldata_actual_size =  __calldata_ptr - cast({calldata_ptr.format()}, felt*)',
            parent_location=struct_parent_location))
    code_elements.append(
        parse_code_element(
            f'assert {calldata_size.format()} = __calldata_actual_size',
            parent_location=struct_parent_location))

    # len(args) + 1 Notes entries, each a distinct object: list multiplication would make every
    # slot an alias of one shared Notes instance.
    return code_elements, ArgList(args=args,
                                  notes=[Notes() for _ in range(len(args) + 1)],
                                  has_trailing_comma=True,
                                  location=location)