示例#1
0
def test_import():
    """
    Checks parsing and formatting of 'from ... import ...' code elements: plain and aliased
    imports, dotted module paths, multiple (optionally parenthesized) import items, and the
    rejection of dotted names in an imported item or in its alias.
    """
    # Test module names without periods.
    res = parse_code_element('from   a    import  b')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a'),
        import_items=[
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='b'),
                              local_name=None)
        ])
    assert res.format(allowed_line_length=100) == 'from a import b'

    # Test module names without periods, with aliasing.
    res = parse_code_element('from   a    import  b   as   c')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a'),
        import_items=[
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='b'),
                              local_name=ExprIdentifier(name='c'))
        ])
    assert res.format(allowed_line_length=100) == 'from a import b as c'

    # Test module names with periods.
    res = parse_code_element('from   a.b12.c4    import  lib345')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a.b12.c4'),
        import_items=[
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='lib345'),
                              local_name=None)
        ])
    assert res.format(allowed_line_length=100) == 'from a.b12.c4 import lib345'

    # Test multiple imports.
    res = parse_code_element('from  lib    import  a,b as    b2,   c')

    assert res == CodeElementImport(
        path=ExprIdentifier(name='lib'),
        import_items=[
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='a'),
                              local_name=None),
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='b'),
                              local_name=ExprIdentifier(name='b2')),
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='c'),
                              local_name=None),
        ])
    assert res.format(
        allowed_line_length=100) == 'from lib import a, b as b2, c'
    # A short line limit forces the parenthesized, multi-line form.
    assert res.format(
        allowed_line_length=20) == 'from lib import (\n    a, b as b2, c)'

    # The parenthesized form must parse back to the same AST.
    assert res == parse_code_element('from lib import (\n    a, b as b2, c)')

    # The negative cases below go through parse_code_element(), like the positive cases above.
    # An import statement is never a valid expression, so feeding these strings to the
    # expression parser would raise for any import and the checks would be vacuous.

    # Test module with bad identifier (with periods).
    with pytest.raises(ParserError):
        parse_code_element('from a.b import c.d')

    # Test module with bad local name (with periods).
    with pytest.raises(ParserError):
        parse_code_element('from a.b import c as d.d')
def process_calldata(calldata_ptr: Expression, calldata_size: Expression,
                     identifiers: IdentifierManager,
                     struct_def: StructDefinition,
                     has_range_check_builtin: bool,
                     location: Location) -> Tuple[List[CodeElement], ArgList]:
    """
    Generates the code elements that read the members of 'struct_def' out of the calldata.

    Arguments:
    calldata_ptr - an expression for the pointer to the calldata array.
    calldata_size - an expression for the calldata size. The generated code asserts that it is
      equal to the number of cells actually consumed by the arguments.
    identifiers - not used by the current implementation.
    struct_def - the struct whose members describe the expected arguments.
    has_range_check_builtin - whether 'range_check_ptr' is available; required for array
      arguments.
    location - the location attached to the generated argument list and used as the parent
      location of the struct-level generated code.

    Returns the generated code elements and an ArgList that corresponds to 'struct_def'
    (one 'member_name=__calldata_arg_member_name' assignment per member).

    Supported member types are felt, and felt* when immediately preceded by a felt member named
    '<member_name>_len'. Any other case raises PreprocessorError.
    """
    def parse_code_element(code: str, parent_location: ParentLocation):
        # The file name of the auto-generated snippet is derived from the hash of its code.
        filename = f'autogen/starknet/arg_parser/{hashlib.sha256(code.encode()).hexdigest()}.cairo'
        return parse(
            filename=filename,
            code=code,
            code_type='code_element',
            expected_type=CodeElement,
            parser_context=ParserContext(parent_location=parent_location))

    struct_parent_location = (location, 'While handling calldata of')
    code_elements = [
        parse_code_element(
            f'let __calldata_ptr : felt* = cast({calldata_ptr.format()}, felt*)',
            parent_location=struct_parent_location)
    ]
    args = []
    # The previously visited member, used to match an array with its '<name>_len' argument.
    prev_member: Optional[Tuple[str, MemberDefinition]] = None
    for member_name, member_def in struct_def.members.items():
        member_location = member_def.location
        assert member_location is not None
        member_parent_location = (
            member_location,
            f"While handling calldata argument '{member_name}'")
        cairo_type = member_def.cairo_type
        if isinstance(cairo_type, TypePointer) and isinstance(
                cairo_type.pointee, TypeFelt):
            # An array argument: the member right before it must be its felt length.
            has_len = prev_member is not None and prev_member[0] == f'{member_name}_len' and \
                isinstance(prev_member[1].cairo_type, TypeFelt)
            if not has_len:
                raise PreprocessorError(
                    f'Array argument "{member_name}" must be preceeded by a length argument '
                    f'named "{member_name}_len" of type felt.',
                    location=member_location)
            if not has_range_check_builtin:
                raise PreprocessorError(
                    "The 'range_check' builtin must be declared in the '%builtins' directive "
                    'when using array arguments in external functions.',
                    location=member_location)

            code_element_strs = [
                # Range-check the length, so that it cannot be negative.
                f'assert [range_check_ptr] = __calldata_arg_{member_name}_len',
                'let range_check_ptr = range_check_ptr + 1',
                # Create the reference.
                f'let __calldata_arg_{member_name} : felt* = __calldata_ptr',
                # Use 'tempvar' instead of 'let' to avoid repeating this computation for the
                # following arguments.
                f'tempvar __calldata_ptr = __calldata_ptr + __calldata_arg_{member_name}_len',
            ]
            for code_element_str in code_element_strs:
                code_elements.append(
                    parse_code_element(code_element_str,
                                       parent_location=member_parent_location))
        elif isinstance(cairo_type, TypeFelt):
            # A felt argument occupies a single calldata cell.
            code_elements.append(
                parse_code_element(
                    f'let __calldata_arg_{member_name} = [__calldata_ptr]',
                    parent_location=member_parent_location))
            code_elements.append(
                parse_code_element('let __calldata_ptr = __calldata_ptr + 1',
                                   parent_location=member_parent_location))
        else:
            raise PreprocessorError(
                f'Unsupported argument type {cairo_type.format()}.',
                location=cairo_type.location)

        args.append(
            ExprAssignment(identifier=ExprIdentifier(name=member_name,
                                                     location=member_location),
                           expr=ExprIdentifier(
                               name=f'__calldata_arg_{member_name}',
                               location=member_location),
                           location=member_location))

        prev_member = member_name, member_def

    # Verify that the declared calldata size matches the size consumed by the arguments.
    code_elements.append(
        parse_code_element(
            f'let __calldata_actual_size =  __calldata_ptr - cast({calldata_ptr.format()}, felt*)',
            parent_location=struct_parent_location))
    code_elements.append(
        parse_code_element(
            f'assert {calldata_size.format()} = __calldata_actual_size',
            parent_location=struct_parent_location))

    # Create a separate Notes object per slot (one more than the number of arguments), rather
    # than repeating a single instance, so that the slots never alias each other.
    notes = [Notes() for _ in range(len(args) + 1)]
    return code_elements, ArgList(args=args,
                                  notes=notes,
                                  has_trailing_comma=True,
                                  location=location)
示例#3
0
def test_instruction():
    """
    Checks parse_instruction() for every instruction kind: the resulting AST, the canonical
    output of format() (including whitespace normalization) and the rejection of malformed
    inputs.
    """
    def reformat(code: str) -> str:
        # Parses the instruction and returns its canonical formatting.
        return parse_instruction(code).format()

    def assert_rejected(*codes: str):
        # Each of the given inputs must fail to parse.
        for code in codes:
            with pytest.raises(ParserError):
                parse_instruction(code)

    # AssertEq.
    parsed = parse_instruction('[ap] = [fp]; ap++')
    expected = InstructionAst(
        body=AssertEqInstruction(
            a=ExprDeref(addr=ExprReg(reg=Register.AP)),
            b=ExprDeref(addr=ExprReg(reg=Register.FP))),
        inc_ap=True)
    assert parsed == expected
    assert parsed.format() == '[ap] = [fp]; ap++'
    assert reformat('[ap+5] = [fp]+[ap] - 5') == '[ap + 5] = [fp] + [ap] - 5'
    assert reformat('[ap+5]+3= [fp]*7;ap  ++ ') == '[ap + 5] + 3 = [fp] * 7; ap++'

    # Jump.
    parsed = parse_instruction('jmp rel [ap] + x; ap++')
    expected = InstructionAst(
        body=JumpInstruction(
            val=ExprOperator(
                a=ExprDeref(addr=ExprReg(reg=Register.AP)),
                op='+',
                b=ExprIdentifier(name='x')),
            relative=True),
        inc_ap=True)
    assert parsed == expected
    assert parsed.format() == 'jmp rel [ap] + x; ap++'
    assert reformat(' jmp   abs[ap]+x') == 'jmp abs [ap] + x'
    # Make sure the following are not OK.
    assert_rejected('jmp abs', 'jmpabs[ap]')

    # JumpToLabel.
    parsed = parse_instruction('jmp label')
    expected = InstructionAst(
        body=JumpToLabelInstruction(
            label=ExprIdentifier(name='label'),
            condition=None),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'jmp label'
    # Make sure the following are not OK.
    assert_rejected('jmp [fp]', 'jmp 7')

    # Jnz.
    parsed = parse_instruction('jmp rel [ap] + x if [fp + 3] != 0')
    expected = InstructionAst(
        body=JnzInstruction(
            jump_offset=ExprOperator(
                a=ExprDeref(addr=ExprReg(reg=Register.AP)),
                op='+',
                b=ExprIdentifier(name='x')),
            condition=ExprDeref(
                addr=ExprOperator(
                    a=ExprReg(reg=Register.FP),
                    op='+',
                    b=ExprConst(val=3)))),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'jmp rel [ap] + x if [fp + 3] != 0'
    assert reformat(' jmp   rel 17  if[fp]!=0;ap++') == 'jmp rel 17 if [fp] != 0; ap++'
    # Make sure the following are not OK.
    assert_rejected('jmprel 17 if x != 0', 'jmp 17 if x')
    # Comparing against a value other than 0 must produce an error mentioning '!= 0'.
    with pytest.raises(ParserError, match='!= 0'):
        parse_instruction('jmp rel 17 if x != 2')
    assert_rejected('jmp rel [fp] ifx != 0')

    # Jnz to label.
    parsed = parse_instruction('jmp label if [fp] != 0')
    expected = InstructionAst(
        body=JumpToLabelInstruction(
            label=ExprIdentifier('label'),
            condition=ExprDeref(addr=ExprReg(reg=Register.FP))),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'jmp label if [fp] != 0'
    # Make sure the following are not OK.
    assert_rejected('jmp [fp] if [fp] != 0', 'jmp 7 if [fp] != 0')

    # Call abs.
    parsed = parse_instruction('call abs [fp] + x')
    expected = InstructionAst(
        body=CallInstruction(
            val=ExprOperator(
                a=ExprDeref(addr=ExprReg(reg=Register.FP)),
                op='+',
                b=ExprIdentifier(name='x')),
            relative=False),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'call abs [fp] + x'
    assert reformat('call   abs   17;ap++') == 'call abs 17; ap++'
    # Make sure the following are not OK.
    assert_rejected('call abs', 'callabs 7')

    # Call rel.
    parsed = parse_instruction('call rel [ap] + x')
    expected = InstructionAst(
        body=CallInstruction(
            val=ExprOperator(
                a=ExprDeref(addr=ExprReg(reg=Register.AP)),
                op='+',
                b=ExprIdentifier(name='x')),
            relative=True),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'call rel [ap] + x'
    assert reformat('call   rel   17;ap++') == 'call rel 17; ap++'
    # Make sure the following are not OK.
    assert_rejected('call rel', 'callrel 7')

    # Call label.
    parsed = parse_instruction('call label')
    expected = InstructionAst(
        body=CallLabelInstruction(label=ExprIdentifier(name='label')),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'call label'
    assert reformat('call   label ;ap++') == 'call label; ap++'
    # Make sure the following are not OK.
    assert_rejected('call [fp]', 'call 7')

    # Ret.
    parsed = parse_instruction('ret')
    expected = InstructionAst(body=RetInstruction(), inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'ret'

    # AddAp.
    parsed = parse_instruction('ap += [fp] + 2')
    expected = InstructionAst(
        body=AddApInstruction(
            expr=ExprOperator(
                a=ExprDeref(addr=ExprReg(reg=Register.FP)),
                op='+',
                b=ExprConst(val=2))),
        inc_ap=False)
    assert parsed == expected
    assert parsed.format() == 'ap += [fp] + 2'
    assert reformat('ap  +=[ fp]+   2') == 'ap += [fp] + 2'
    assert reformat('ap  +=[ fp]+   2;ap ++') == 'ap += [fp] + 2; ap++'
    def create_func_wrapper(self, elm: CodeElementFunction, func_alias_name: str):
        """
        Generates a wrapper that converts between the StarkNet contract ABI and the
        Cairo calling convention.

        The wrapper binds the os_context entries to references, parses the calldata into the
        arguments of the wrapped function, calls it, pushes the return values (the os_context
        entries followed by the retdata size and pointer), emits a 'ret' instruction and
        registers an ABI entry for the function.

        Arguments:
        elm - the CodeElementFunction of the wrapped function.
        func_alias_name - an alias for the function definition in the current scope.

        Raises PreprocessorError if the function declares an implicit argument that is not part
        of the os_context, and NotImplementedError for an unsupported decorator.
        """

        os_context = self.get_os_context()

        func_location = elm.identifier.location
        assert func_location is not None

        def stack_cell(offset: int) -> ExprDeref:
            # Builds the expression '[fp + offset]', located at the function's identifier.
            return ExprDeref(
                addr=ExprOperator(
                    ExprReg(reg=Register.FP, location=func_location),
                    '+',
                    ExprConst(offset, location=func_location),
                    location=func_location),
                location=func_location)

        # We expect the call stack to look as follows:
        # pointer to os_context struct.
        # calldata size.
        # pointer to the call data array.
        # ret_fp.
        # ret_pc.
        os_context_ptr = stack_cell(-5)
        calldata_size = stack_cell(-4)
        calldata_ptr = stack_cell(-3)

        implicit_arguments = None

        implicit_arguments_identifiers: Dict[str, TypedIdentifier] = {}
        if elm.implicit_arguments is not None:
            args = []
            for typed_identifier in elm.implicit_arguments.identifiers:
                ptr_name = typed_identifier.name
                if ptr_name not in os_context:
                    raise PreprocessorError(
                        f"Unexpected implicit argument '{ptr_name}' in an external function.",
                        location=typed_identifier.identifier.location)

                implicit_arguments_identifiers[ptr_name] = typed_identifier

                # Add the assignment expression 'ptr_name = ptr_name' to the implicit arg list.
                args.append(ExprAssignment(
                    identifier=typed_identifier.identifier,
                    expr=typed_identifier.identifier,
                    location=typed_identifier.location,
                ))

            implicit_arguments = ArgList(
                args=args, notes=[], has_trailing_comma=True,
                location=elm.implicit_arguments.location)

        return_args_exprs: List[Expression] = []

        # Create references.
        for ptr_name, index in os_context.items():
            ref_name = self.current_scope + ptr_name

            arg_identifier = implicit_arguments_identifiers.get(ptr_name)
            if arg_identifier is None:
                # The function does not declare this entry; treat it as a plain felt.
                location: Optional[Location] = func_location
                cairo_type: CairoType = TypeFelt(location=location)
            else:
                location = arg_identifier.location
                cairo_type = self.resolve_type(arg_identifier.get_type())

            # Add a reference of the form
            # 'let ref_name = [cast(os_context_ptr + index, cairo_type*)]'.
            self.add_reference(
                name=ref_name,
                value=ExprDeref(
                    addr=ExprCast(
                        ExprOperator(
                            os_context_ptr, '+', ExprConst(index, location=location),
                            location=location),
                        dest_type=TypePointer(pointee=cairo_type, location=cairo_type.location),
                        location=cairo_type.location),
                    location=location),
                cairo_type=cairo_type,
                location=location,
                require_future_definition=False)

            # The os_context entries are returned in the order of their indices.
            assert index == len(return_args_exprs), 'Unexpected index.'

            return_args_exprs.append(ExprIdentifier(name=ptr_name, location=func_location))

        arg_struct_def = self.get_struct_definition(
            name=ScopedName.from_string(func_alias_name) + CodeElementFunction.ARGUMENT_SCOPE,
            location=func_location)

        code_elements, call_args = process_calldata(
            calldata_ptr=calldata_ptr,
            calldata_size=calldata_size,
            identifiers=self.identifiers,
            struct_def=arg_struct_def,
            has_range_check_builtin='range_check_ptr' in os_context,
            location=func_location,
        )

        for code_element in code_elements:
            self.visit(code_element)

        # Call the wrapped function with the parsed arguments.
        self.visit(CodeElementFuncCall(
            func_call=RvalueFuncCall(
                func_ident=ExprIdentifier(name=func_alias_name, location=func_location),
                arguments=call_args,
                implicit_arguments=implicit_arguments,
                location=func_location)))

        ret_struct_name = ScopedName.from_string(func_alias_name) + CodeElementFunction.RETURN_SCOPE
        ret_struct_type = self.resolve_type(TypeStruct(ret_struct_name, False))
        ret_struct_def = self.get_struct_definition(
            name=ret_struct_name,
            location=func_location)
        # Reference the return struct of the call, located 'ret_struct_def.size' cells below ap.
        ret_struct_expr = create_simple_ref_expr(
            reg=Register.AP, offset=-ret_struct_def.size, cairo_type=ret_struct_type,
            location=func_location)
        self.add_reference(
            name=self.current_scope + 'ret_struct',
            value=ret_struct_expr,
            cairo_type=ret_struct_type,
            require_future_definition=False,
            location=func_location)

        # Add function return values.
        retdata_size, retdata_ptr = self.process_retdata(
            ret_struct_ptr=ExprIdentifier(name='ret_struct'),
            ret_struct_type=ret_struct_type, struct_def=ret_struct_def,
            location=func_location,
        )
        return_args_exprs += [retdata_size, retdata_ptr]

        # Push the return values.
        self.push_compound_expressions(
            compound_expressions=[self.simplify_expr_as_felt(expr) for expr in return_args_exprs],
            location=func_location,
        )

        # Add a ret instruction.
        self.visit(CodeElementInstruction(
            instruction=InstructionAst(
                body=RetInstruction(),
                inc_ap=False,
                location=func_location)))

        # Add an entry to the ABI.
        external_decorator = self.get_external_decorator(elm)
        assert external_decorator is not None
        is_view = external_decorator.name == 'view'

        # The entry type is computed exactly once; unsupported decorators are rejected.
        if external_decorator.name == L1_HANDLER_DECORATOR:
            entry_type = L1_HANDLER_DECORATOR
        elif external_decorator.name in [EXTERNAL_DECORATOR, VIEW_DECORATOR]:
            entry_type = 'function'
        else:
            raise NotImplementedError(f'Unsupported decorator {external_decorator.name}')

        self.add_abi_entry(
            name=elm.name, arg_struct_def=arg_struct_def, ret_struct_def=ret_struct_def,
            is_view=is_view, entry_type=entry_type)
 def visit_ExprIdentifier(self, expr: ExprIdentifier):
     """
     Returns a new ExprIdentifier with the name of 'expr' and with the location obtained by
     passing the original location through self.location_modifier.
     """
     name = expr.name
     new_location = self.location_modifier(expr.location)
     return ExprIdentifier(name=name, location=new_location)
 def identifier_def(self, value, meta):
     """
     Returns an ExprIdentifier named after the 'value' attribute of the first item of 'value',
     located at self.meta2loc(meta).
     """
     first_item = value[0]
     name = first_item.value
     return ExprIdentifier(name=name, location=self.meta2loc(meta))
 def identifier(self, value, meta):
     """
     Returns an ExprIdentifier whose name is the 'value' attributes of all the items of 'value'
     joined by periods, located at self.meta2loc(meta).
     """
     parts = [item.value for item in value]
     dotted_name = '.'.join(parts)
     return ExprIdentifier(name=dotted_name, location=self.meta2loc(meta))