def process_retdata(
            self, ret_struct_ptr: Expression, ret_struct_type: CairoType,
            struct_def: StructDefinition,
            location: Optional[Location]) -> Tuple[Expression, Expression]:
        """
        Emits the code that prepares the return data of a wrapped function.

        A 'retdata_ptr' reference is created for the cell at [ap], a hint fills that cell with
        'segments.add()', and the return struct is then asserted equal to
        '[cast(retdata_ptr, ret_struct_type*)]'.

        Returns the pair (retdata_size, retdata_ptr) as expressions.
        Raises PreprocessorError if a member of struct_def is not a felt.
        """

        # Every returned member must be a felt.
        for member_def in struct_def.members.values():
            member_type = member_def.cairo_type
            if isinstance(member_type, TypeFelt):
                continue
            raise PreprocessorError(
                f'Unsupported argument type {member_type.format()}.',
                location=member_type.location)

        # 'retdata_ptr' refers to [ap], typed as felt*.
        retdata_ptr_value = ExprDeref(addr=ExprReg(reg=Register.AP), location=location)
        self.add_reference(
            name=self.current_scope + 'retdata_ptr',
            value=retdata_ptr_value,
            cairo_type=TypePointer(TypeFelt()),
            require_future_definition=False,
            location=location)

        # The hint writes the result of 'segments.add()' into [ap].
        segment_hint = ExprHint(
            hint_code='memory[ap] = segments.add()',
            n_prefix_newlines=0,
            location=location)
        self.visit(CodeElementHint(hint=segment_hint, location=location))

        # Advance ap by one. The parent implementation is called directly in order to skip
        # the hint whitelist check, which fails before the workaround below is applied.
        advance_ap = InstructionAst(
            body=AddApInstruction(ExprConst(1)),
            inc_ap=False,
            location=location)
        super().visit_CodeElementInstruction(CodeElementInstruction(advance_ap))

        # Workaround: drop the references from the flow tracking data attached to the hint of
        # the last instruction. The hint does not need them and they make the hint whitelist
        # fail.
        last_instruction = self.instructions[-1]
        assert len(last_instruction.hints) == 1
        hint, hint_flow_tracking_data = last_instruction.hints[0]
        stripped_flow_tracking_data = dataclasses.replace(
            hint_flow_tracking_data, reference_ids={})
        last_instruction.hints[0] = hint, stripped_flow_tracking_data

        # Assert '[cast(retdata_ptr, ret_struct_type*)] = ret_struct_ptr'.
        typed_retdata_ptr = ExprCast(
            ExprIdentifier('retdata_ptr'), TypePointer(ret_struct_type))
        self.visit(CodeElementCompoundAssertEq(ExprDeref(typed_retdata_ptr), ret_struct_ptr))

        retdata_size = ExprConst(struct_def.size)
        return (retdata_size, ExprIdentifier('retdata_ptr'))
Beispiel #2
0
    def handle_ReferenceDefinition(self, name: str,
                                   identifier: ReferenceDefinition,
                                   scope: ScopedName,
                                   set_value: Optional[MaybeRelocatable]):
        # In set mode, take the address of the given reference instead.
        reference = self._context.flow_tracking_data.resolve_reference(
            reference_manager=self._context.reference_manager,
            name=identifier.full_name)

        if set_value is None:
            expr = reference.eval(self._context.flow_tracking_data.ap_tracking)
            expr, expr_type = simplify_type_system(expr)
            if isinstance(expr_type, TypeStruct):
                # If the reference is of type T, take its address and treat it as T*.
                assert isinstance(expr, ExprDeref), \
                    f"Expected expression of type '{expr_type.format()}' to have an address."
                expr = expr.addr
                expr_type = TypePointer(pointee=expr_type)
            val = self._context.evaluator(expr)

            # Check if the type is felt* or any_type**.
            is_pointer_to_felt_or_pointer = (
                isinstance(expr_type, TypePointer)
                and isinstance(expr_type.pointee, (TypePointer, TypeFelt)))
            if isinstance(expr_type,
                          TypeFelt) or is_pointer_to_felt_or_pointer:
                return val
            else:
                # Typed reference, return VmConstsReference which allows accessing members.
                assert isinstance(expr_type, TypePointer) and \
                    isinstance(expr_type.pointee, TypeStruct), \
                    'Type must be of the form T*.'
                return VmConstsReference(
                    context=self._context,
                    accessible_scopes=[expr_type.pointee.scope],
                    reference_value=val,
                    add_addr_var=True)
        else:
            assert str(scope[-1:]) == name, 'Expecting scope to end with name.'
            value, value_type = simplify_type_system(reference.value)
            assert isinstance(value, ExprDeref), f"""\
{scope} (= {value.format()}) does not reference memory and cannot be assigned."""

            value_ref = Reference(
                pc=reference.pc,
                value=ExprCast(expr=value.addr, dest_type=value_type),
                ap_tracking_data=reference.ap_tracking_data,
            )

            addr = self._context.evaluator(
                value_ref.eval(self._context.flow_tracking_data.ap_tracking))
            self._context.memory[addr] = set_value
def create_simple_ref_expr(reg: Register, offset: int, cairo_type: CairoType,
                           location: Optional[Location]) -> Expression:
    """
    Builds the expression '[cast(reg + offset, cairo_type*)]'.

    The given location is attached to every node of the resulting expression, including the
    pointer type used in the cast.
    """
    # reg + offset
    address = ExprOperator(
        a=ExprReg(reg=reg, location=location),
        op='+',
        b=ExprConst(val=offset, location=location),
        location=location)
    # cast(reg + offset, cairo_type*)
    pointer_type = TypePointer(pointee=cairo_type, location=location)
    typed_address = ExprCast(expr=address, dest_type=pointer_type, location=location)
    # [cast(reg + offset, cairo_type*)]
    return ExprDeref(addr=typed_address, location=location)
    def visit_ExprFuncCall(self, expr: ExprFuncCall):
        """
        Rewrites a function-call expression as a cast of its argument tuple to the struct type
        named by the called identifier, and returns the result of visiting that cast.

        Raises CairoTypeError if the call carries implicit arguments.
        """
        call = expr.rvalue
        implicit_args = call.implicit_arguments
        if implicit_args is not None:
            raise CairoTypeError(
                'Implicit arguments cannot be used with struct constructors.',
                location=implicit_args.location)

        # The called name is interpreted as a (not yet resolved) struct type.
        unresolved_type = TypeStruct(
            scope=ScopedName.from_string(call.func_ident.name),
            is_fully_resolved=False,
            location=expr.location)
        struct_type = self.resolve_type_callback(unresolved_type)

        # Convert ExprFuncCall to ExprCast: cast((args...), StructType).
        argument_tuple = ExprTuple(call.arguments, location=expr.location)
        cast_expr = ExprCast(
            expr=argument_tuple,
            dest_type=struct_type,
            location=expr.location)
        return self.visit(cast_expr)
    def eval(
            self, reference_manager: ReferenceManager, flow_tracking_data: FlowTrackingData) -> \
            Expression:
        """
        Evaluates the parent reference and follows self.member_path from it, one member at a
        time.

        Returns the resulting expression cast to the type of the last member (or to the type
        of the reference itself when the path is empty).
        Raises DefinitionError if a step is applied to something that is neither a struct nor
        a struct pointer, or if the member does not exist or is not a MemberDefinition.
        """
        reference = flow_tracking_data.resolve_reference(
            reference_manager=reference_manager,
            name=self.parent.full_name)
        assert isinstance(flow_tracking_data, FlowTrackingDataActual), \
            'Resolved references can only come from FlowTrackingDataActual.'
        evaluated = reference.eval(flow_tracking_data.ap_tracking)
        expr, expr_type = simplify_type_system(evaluated)

        for member_name in self.member_path.path:
            # A value of type Struct is accessed through its address (Struct*); a value that
            # is already a pointer is used as the address directly.
            take_address = isinstance(expr_type, TypeStruct)
            if take_address:
                expr_type = expr_type.get_pointer_type()

            is_struct_pointer = (
                isinstance(expr_type, TypePointer)
                and isinstance(expr_type.pointee, TypeStruct))
            if not is_struct_pointer:
                raise DefinitionError('Member access requires a type of the form Struct*.')

            qualified_member = expr_type.pointee.resolved_scope + member_name
            if qualified_member not in self.identifier_values:
                raise DefinitionError(f"Member '{qualified_member}' was not found.")

            member_definition = self.identifier_values[qualified_member]
            if not isinstance(member_definition, MemberDefinition):
                raise DefinitionError(
                    f"Expected reference offset '{qualified_member}' to be a member, "
                    f'found {member_definition.TYPE}.')

            # The new expression is [base_address + member offset], typed as the member.
            base_address = ExprAddressOf(expr=expr) if take_address else expr
            member_address = ExprOperator(
                a=base_address, op='+', b=ExprConst(member_definition.offset))
            expr = ExprDeref(addr=member_address)
            expr_type = member_definition.cairo_type

        return ExprCast(expr=expr, dest_type=expr_type)
 def atom_cast(self, value, meta):
     """
     Builds an ExprCast from the given children: value[0] is passed as the notes, value[1]
     as the expression being cast and value[2] as the destination type. The location is
     derived from meta.
     """
     inner_expr = value[1]
     cast_notes = value[0]
     target_type = value[2]
     location = self.meta2loc(meta)
     return ExprCast(
         expr=inner_expr, notes=cast_notes, dest_type=target_type, location=location)
 def visit_ExprCast(self, expr: ExprCast):
     """
     Returns a new ExprCast whose inner expression is the visited one and whose location is
     passed through self.location_modifier. The destination type and the cast type are
     copied unchanged; expr.notes is not forwarded.
     """
     return ExprCast(
         expr=self.visit(expr.expr),
         dest_type=expr.dest_type,
         cast_type=expr.cast_type,
         location=self.location_modifier(expr.location))
 def visit_ExprCast(self, expr: ExprCast):
     """
     Returns a new ExprCast whose inner expression is the visited one and whose destination
     type is passed through self.resolve_type_callback. The cast type, notes and location
     are copied unchanged.
     """
     visited_expr = self.visit(expr.expr)
     resolved_type = self.resolve_type_callback(expr.dest_type)
     return ExprCast(
         expr=visited_expr,
         dest_type=resolved_type,
         cast_type=expr.cast_type,
         notes=expr.notes,
         location=expr.location)
    def create_func_wrapper(self, elm: CodeElementFunction, func_alias_name: str):
        """
        Generates a wrapper that converts between the StarkNet contract ABI and the
        Cairo calling convention.

        The wrapper reads the os_context pointer, the calldata size and the calldata pointer
        from the stack, calls the wrapped function, and returns the os_context values followed
        by retdata_size and retdata_ptr. An entry for the function is added to the ABI.

        Arguments:
        elm - the CodeElementFunction of the wrapped function.
        func_alias_name - an alias for the FunctionDefinition in the current scope.

        Raises PreprocessorError if the function has an implicit argument that is not part of
        the os_context, and NotImplementedError if its external decorator is not supported.
        """

        # Maps each os_context pointer name to its index in the os_context struct.
        os_context = self.get_os_context()

        func_location = elm.identifier.location
        assert func_location is not None

        # We expect the call stack to look as follows:
        # pointer to os_context struct.
        # calldata size.
        # pointer to the call data array.
        # ret_fp.
        # ret_pc.
        os_context_ptr = ExprDeref(
            addr=ExprOperator(
                ExprReg(reg=Register.FP, location=func_location),
                '+',
                ExprConst(-5, location=func_location),
                location=func_location),
            location=func_location)

        calldata_size = ExprDeref(
            addr=ExprOperator(
                ExprReg(reg=Register.FP, location=func_location),
                '+',
                ExprConst(-4, location=func_location),
                location=func_location),
            location=func_location)

        calldata_ptr = ExprDeref(
            addr=ExprOperator(
                ExprReg(reg=Register.FP, location=func_location),
                '+',
                ExprConst(-3, location=func_location),
                location=func_location),
            location=func_location)

        # Stays None when the wrapped function declares no implicit arguments.
        implicit_arguments = None

        implicit_arguments_identifiers: Dict[str, TypedIdentifier] = {}
        if elm.implicit_arguments is not None:
            args = []
            for typed_identifier in elm.implicit_arguments.identifiers:
                ptr_name = typed_identifier.name
                if ptr_name not in os_context:
                    raise PreprocessorError(
                        f"Unexpected implicit argument '{ptr_name}' in an external function.",
                        location=typed_identifier.identifier.location)

                implicit_arguments_identifiers[ptr_name] = typed_identifier

                # Add the assignment expression 'ptr_name = ptr_name' to the implicit arg list.
                args.append(ExprAssignment(
                    identifier=typed_identifier.identifier,
                    expr=typed_identifier.identifier,
                    location=typed_identifier.location,
                ))

            implicit_arguments = ArgList(
                args=args, notes=[], has_trailing_comma=True,
                location=elm.implicit_arguments.location)

        # The expressions the wrapper returns, in order: the os_context values first, then
        # retdata_size and retdata_ptr.
        return_args_exprs: List[Expression] = []

        # Create references.
        for ptr_name, index in os_context.items():
            ref_name = self.current_scope + ptr_name

            # Pointers that the function does not take as implicit arguments are typed as
            # felt; the others use the type declared by the function.
            arg_identifier = implicit_arguments_identifiers.get(ptr_name)
            if arg_identifier is None:
                location: Optional[Location] = func_location
                cairo_type: CairoType = TypeFelt(location=location)
            else:
                location = arg_identifier.location
                cairo_type = self.resolve_type(arg_identifier.get_type())

            # Add a reference of the form
            # 'let ref_name = [cast(os_context_ptr + index, cairo_type*)]'.
            self.add_reference(
                name=ref_name,
                value=ExprDeref(
                    addr=ExprCast(
                        ExprOperator(
                            os_context_ptr, '+', ExprConst(index, location=location),
                            location=location),
                        dest_type=TypePointer(pointee=cairo_type, location=cairo_type.location),
                        location=cairo_type.location),
                    location=location),
                cairo_type=cairo_type,
                location=location,
                require_future_definition=False)

            # The os_context values must be returned in the order of their indices.
            assert index == len(return_args_exprs), 'Unexpected index.'

            return_args_exprs.append(ExprIdentifier(name=ptr_name, location=func_location))

        arg_struct_def = self.get_struct_definition(
            name=ScopedName.from_string(func_alias_name) + CodeElementFunction.ARGUMENT_SCOPE,
            location=func_location)

        # Generate the code that processes the calldata and the arguments for the call.
        code_elements, call_args = process_calldata(
            calldata_ptr=calldata_ptr,
            calldata_size=calldata_size,
            identifiers=self.identifiers,
            struct_def=arg_struct_def,
            has_range_check_builtin='range_check_ptr' in os_context,
            location=func_location,
        )

        for code_element in code_elements:
            self.visit(code_element)

        # Call the wrapped function.
        self.visit(CodeElementFuncCall(
            func_call=RvalueFuncCall(
                func_ident=ExprIdentifier(name=func_alias_name, location=func_location),
                arguments=call_args,
                implicit_arguments=implicit_arguments,
                location=func_location)))

        # Add a 'ret_struct' reference to the return struct, located at
        # [ap - ret_struct_def.size].
        ret_struct_name = ScopedName.from_string(func_alias_name) + CodeElementFunction.RETURN_SCOPE
        ret_struct_type = self.resolve_type(TypeStruct(ret_struct_name, False))
        ret_struct_def = self.get_struct_definition(
            name=ret_struct_name,
            location=func_location)
        ret_struct_expr = create_simple_ref_expr(
            reg=Register.AP, offset=-ret_struct_def.size, cairo_type=ret_struct_type,
            location=func_location)
        self.add_reference(
            name=self.current_scope + 'ret_struct',
            value=ret_struct_expr,
            cairo_type=ret_struct_type,
            require_future_definition=False,
            location=func_location)

        # Add function return values.
        retdata_size, retdata_ptr = self.process_retdata(
            ret_struct_ptr=ExprIdentifier(name='ret_struct'),
            ret_struct_type=ret_struct_type, struct_def=ret_struct_def,
            location=func_location,
        )
        return_args_exprs += [retdata_size, retdata_ptr]

        # Push the return values.
        self.push_compound_expressions(
            compound_expressions=[self.simplify_expr_as_felt(expr) for expr in return_args_exprs],
            location=func_location,
        )

        # Add a ret instruction.
        self.visit(CodeElementInstruction(
            instruction=InstructionAst(
                body=RetInstruction(),
                inc_ap=False,
                location=func_location)))

        # Add an entry to the ABI.
        external_decorator = self.get_external_decorator(elm)
        assert external_decorator is not None
        decorator_name = external_decorator.name
        is_view = decorator_name == 'view'

        # The entry type is computed once: L1 handlers use the decorator name itself, external
        # and view functions are plain functions, and anything else is rejected.
        if decorator_name == L1_HANDLER_DECORATOR:
            entry_type = L1_HANDLER_DECORATOR
        elif decorator_name in [EXTERNAL_DECORATOR, VIEW_DECORATOR]:
            entry_type = 'function'
        else:
            raise NotImplementedError(f'Unsupported decorator {external_decorator.name}')

        self.add_abi_entry(
            name=elm.name, arg_struct_def=arg_struct_def, ret_struct_def=ret_struct_def,
            is_view=is_view, entry_type=entry_type)