Ejemplo n.º 1
0
def test_type_visitor():
    t = TypeStruct(scope=scope('T'), is_fully_resolved=False)
    t_star = TypePointer(pointee=t)
    t_star2 = TypePointer(pointee=t_star)

    # Each entry: (source expression, expected simplified expression, expected type).
    cases = [
        ('fp + 3 + [ap]', 'fp + 3 + [ap]', TypeFelt()),
        ('cast(fp + 3 + [ap], T*)', 'fp + 3 + [ap]', t_star),
        # Two casts.
        ('cast(cast(fp, T*), felt)', 'fp', TypeFelt()),
        # Cast from T to T.
        ('cast([cast(fp, T*)], T)', '[fp]', t),
        # Dereference.
        ('[cast(fp, T**)]', '[fp]', t_star),
        ('[[cast(fp, T**)]]', '[[fp]]', t),
        # Address of.
        ('&([[cast(fp, T**)]])', '[fp]', t_star),
        ('&&[[cast(fp, T**)]]', 'fp', t_star2),
    ]
    for source, simplified, expected_type in cases:
        assert simplify_type_system(parse_expr(source)) == (
            parse_expr(simplified), expected_type)
def test_struct_collector():
    module_source = """
struct S:
    member x : S*
    member y : S*
end
"""
    main_source = """
from module import S

func foo{z}(a : S, b) -> (c : S):
    struct T:
        member x : S*
    end
    const X = 5
    return (c=a + X)
end
const Y = 1 + 1
"""
    modules = {'module': module_source, '__main__': main_source}

    scope = ScopedName.from_string

    def s_struct():
        # A fresh, fully-resolved reference to the struct 'module.S'.
        return TypeStruct(scope=scope('module.S'), is_fully_resolved=True)

    def s_pointer():
        # A fresh 'module.S*' type.
        return TypePointer(pointee=s_struct())

    struct_defs = _collect_struct_definitions(modules)

    expected_def = {
        # S has two S* members, one cell each.
        'module.S': StructDefinition(
            full_name=scope('module.S'),
            members={
                'x': MemberDefinition(offset=0, cairo_type=s_pointer()),
                'y': MemberDefinition(offset=1, cairo_type=s_pointer()),
            },
            size=2),
        # 'from module import S' produces an alias in __main__.
        '__main__.S': AliasDefinition(destination=scope('module.S')),
        # 'a : S' takes 2 cells, so 'b' lands at offset 2.
        '__main__.foo.Args': StructDefinition(
            full_name=scope('__main__.foo.Args'),
            members={
                'a': MemberDefinition(offset=0, cairo_type=s_struct()),
                'b': MemberDefinition(offset=2, cairo_type=TypeFelt()),
            },
            size=3),
        '__main__.foo.ImplicitArgs': StructDefinition(
            full_name=scope('__main__.foo.ImplicitArgs'),
            members={'z': MemberDefinition(offset=0, cairo_type=TypeFelt())},
            size=1),
        '__main__.foo.Return': StructDefinition(
            full_name=scope('__main__.foo.Return'),
            members={'c': MemberDefinition(offset=0, cairo_type=s_struct())},
            size=2),
        # The struct declared inside foo refers to the imported S.
        '__main__.foo.T': StructDefinition(
            full_name=scope('__main__.foo.T'),
            members={'x': MemberDefinition(offset=0, cairo_type=s_pointer())},
            size=1),
    }

    assert struct_defs == expected_def
    def process_retdata(
            self, ret_struct_ptr: Expression, ret_struct_type: CairoType,
            struct_def: StructDefinition,
            location: Optional[Location]) -> Tuple[Expression, Expression]:
        """
        Processes the return values and returns retdata_size and retdata_ptr.

        Emits code that allocates a new memory segment (through a hint), binds its address
        to the local reference 'retdata_ptr', and asserts
        '[cast(retdata_ptr, ret_struct_type*)] = ret_struct_ptr'.

        Args:
            ret_struct_ptr: the expression asserted equal to the struct at retdata_ptr.
            ret_struct_type: Cairo type of the return-values struct.
            struct_def: definition of that struct; every member must be a felt.
            location: location attached to the generated code elements.

        Returns:
            A pair (retdata_size, retdata_ptr): the struct size as a constant expression
            and an identifier expression for 'retdata_ptr'.

        Raises:
            PreprocessorError: if some member of the struct is not of type felt.
        """
        # Verify all of the return types are felts.
        for member_def in struct_def.members.values():
            cairo_type = member_def.cairo_type
            if not isinstance(cairo_type, TypeFelt):
                # Prefer the member type's own location; fall back to the function's
                # location when the type carries none.
                error_location = cairo_type.location
                if error_location is None:
                    error_location = location
                raise PreprocessorError(
                    f'Unsupported argument type {cairo_type.format()}.',
                    location=error_location)

        # 'retdata_ptr' is defined as [ap]; the hint below writes the address of a new
        # segment into memory[ap].
        self.add_reference(
            name=self.current_scope + 'retdata_ptr',
            value=ExprDeref(
                addr=ExprReg(reg=Register.AP),
                location=location,
            ),
            cairo_type=TypePointer(TypeFelt()),
            require_future_definition=False,
            location=location)

        self.visit(CodeElementHint(
            hint=ExprHint(
                hint_code='memory[ap] = segments.add()',
                n_prefix_newlines=0,
                location=location,
            ),
            location=location,
        ))

        # Emit an 'add ap' instruction by 1 to which the hint above is attached.
        # Skip check of hint whitelist as it fails before the workaround below.
        super().visit_CodeElementInstruction(CodeElementInstruction(InstructionAst(
            body=AddApInstruction(ExprConst(1)),
            inc_ap=False,
            location=location)))

        # Remove the references from the last instruction's flow tracking as they are
        # not needed by the hint and they cause the hint whitelist to fail.
        assert len(self.instructions[-1].hints) == 1
        hint, hint_flow_tracking_data = self.instructions[-1].hints[0]
        self.instructions[-1].hints[0] = hint, dataclasses.replace(
            hint_flow_tracking_data, reference_ids={})
        # Copy the return values into the new segment.
        self.visit(CodeElementCompoundAssertEq(
            ExprDeref(
                ExprCast(ExprIdentifier('retdata_ptr'), TypePointer(ret_struct_type))),
            ret_struct_ptr))

        return (ExprConst(struct_def.size), ExprIdentifier('retdata_ptr'))
Ejemplo n.º 4
0
def test_reference_to_structs():
    pointer_t = TypePointer(
        pointee=TypeStruct(scope=scope('T'), is_fully_resolved=True))
    # 'ref' is declared without a cairo_type; T.x is a T* member at offset 3.
    identifier_values = {
        scope('ref'): ReferenceDefinition(full_name=scope('ref'), references=[]),
        scope('T.x'): MemberDefinition(offset=3, cairo_type=pointer_t),
    }
    reference_manager = ReferenceManager()
    ref = Reference(
        pc=0,
        value=mark_types_in_expr_resolved(parse_expr('cast([100], T)')),
        ap_tracking_data=RegTrackingData(group=0, offset=2),
    )
    initial_flow = FlowTrackingDataActual(ap_tracking=RegTrackingData())
    flow_tracking_data = initial_flow.add_reference(
        reference_manager=reference_manager, name=scope('ref'), ref=ref)

    consts = get_vm_consts(
        identifier_values, reference_manager, flow_tracking_data, memory={103: 200})

    # The struct starts at address 100, so member x (offset 3) is read from 103.
    assert consts.ref.address_ == 100
    assert consts.ref.x == 200
Ejemplo n.º 5
0
def test_offset_reference_definition_typed_members():
    t_star = TypePointer(
        pointee=TypeStruct(scope=scope('T'), is_fully_resolved=True))
    reference_manager = ReferenceManager()

    main_reference = ReferenceDefinition(
        full_name=scope('a'), cairo_type=t_star, references=[])
    ref_id = reference_manager.alloc_id(
        Reference(
            pc=0,
            value=mark_types_in_expr_resolved(parse_expr('cast(ap, T*)')),
            ap_tracking_data=RegTrackingData(group=0, offset=0),
        ))
    flow_tracking_data = FlowTrackingDataActual(
        ap_tracking=RegTrackingData(group=0, offset=1),
        reference_ids={scope('a'): ref_id},
    )

    # Evaluate an OffsetReferenceDefinition for "a.x.y.z". The reference was created at
    # ap offset 0 and is read at ap offset 1, so 'ap' becomes 'ap - 1'.
    definition = OffsetReferenceDefinition(
        parent=main_reference, member_path=scope('x.y.z'))
    evaluated = definition.eval(
        reference_manager=reference_manager, flow_tracking_data=flow_tracking_data)
    assert evaluated.format() == 'cast(ap - 1, T*).x.y.z'
def test_type_visitor_pointer_arithmetic():
    pointer_to_t = TypePointer(
        pointee=TypeStruct(scope=scope('T'), is_fully_resolved=True))

    # Adding or subtracting a felt keeps the pointer type.
    for op in ('+', '-'):
        simplify_type_system_test(f'cast(fp, T*) {op} 3', f'fp {op} 3', pointer_to_t)
    # The difference of two pointers is a felt.
    simplify_type_system_test('cast(fp, T*) - cast(3, T*)', 'fp - 3', TypeFelt())
Ejemplo n.º 7
0
def test_type_visitor_pointer_arithmetic():
    t_star = TypePointer(
        pointee=TypeStruct(scope=scope('T'), is_fully_resolved=False))
    # Each entry: (source expression, expected simplified expression, expected type).
    expectations = [
        ('cast(fp, T*) + 3', 'fp + 3', t_star),
        ('cast(fp, T*) - 3', 'fp - 3', t_star),
        # Pointer minus pointer yields a felt.
        ('cast(fp, T*) - cast(3, T*)', 'fp - 3', TypeFelt()),
    ]
    for source, simplified, expected_type in expectations:
        assert simplify_type_system(parse_expr(source)) == (
            parse_expr(simplified), expected_type)
def test_type_visitor():
    t = TypeStruct(scope=scope('T'), is_fully_resolved=True)
    t_star = TypePointer(pointee=t)
    t_star2 = TypePointer(pointee=t_star)

    # Each entry: (source expression, expected simplified expression, expected type).
    for source, simplified, expected_type in [
            ('fp + 3 + [ap]', 'fp + 3 + [ap]', TypeFelt()),
            ('cast(fp + 3 + [ap], T*)', 'fp + 3 + [ap]', t_star),
            # Two casts.
            ('cast(cast(fp, T*), felt)', 'fp', TypeFelt()),
            # Cast from T to T.
            ('cast([cast(fp, T*)], T)', '[fp]', t),
            # Dereference.
            ('[cast(fp, T**)]', '[fp]', t_star),
            ('[[cast(fp, T**)]]', '[[fp]]', t),
            # Address of.
            ('&([[cast(fp, T**)]])', '[fp]', t_star),
            ('&&[[cast(fp, T**)]]', 'fp', t_star2),
    ]:
        simplify_type_system_test(source, simplified, expected_type)
def test_offset_reference_definition_typed_members():
    t = TypeStruct(scope=scope('T'), is_fully_resolved=True)
    s_star = TypePointer(
        pointee=TypeStruct(scope=scope('S'), is_fully_resolved=True))
    reference_manager = ReferenceManager()
    # T.x is an S* at offset 3, T.flt a felt at offset 4, and S.x a T at offset 10.
    identifiers = {
        scope('T'): ScopeDefinition(),
        scope('T.x'): MemberDefinition(offset=3, cairo_type=s_star),
        scope('T.flt'): MemberDefinition(offset=4, cairo_type=TypeFelt()),
        scope('S'): ScopeDefinition(),
        scope('S.x'): MemberDefinition(offset=10, cairo_type=t),
    }
    main_reference = ReferenceDefinition(full_name=scope('a'), references=[])
    references = {
        scope('a'):
        reference_manager.get_id(
            Reference(
                pc=0,
                value=mark_types_in_expr_resolved(parse_expr('cast(ap, T*)')),
                ap_tracking_data=RegTrackingData(group=0, offset=0),
            )),
    }

    flow_tracking_data = FlowTrackingDataActual(
        ap_tracking=RegTrackingData(group=0, offset=1),
        reference_ids=references,
    )

    # Create OffsetReferenceDefinition instances for expressions of the form "a.<member_path>",
    # such as a.x and a.x.x, and check the result of evaluating those expressions.
    for member_path, expected_result in [
        ('x', 'cast([ap - 1 + 3], S*)'),
        ('x.x', 'cast([[ap - 1 + 3] + 10], T)'),
        ('x.x.x', 'cast([&[[ap - 1 + 3] + 10] + 3], S*)'),
        ('x.x.flt', 'cast([&[[ap - 1 + 3] + 10] + 4], felt)')
    ]:
        definition = OffsetReferenceDefinition(parent=main_reference,
                                               identifier_values=identifiers,
                                               member_path=scope(member_path))
        assert definition.eval(
            reference_manager=reference_manager,
            flow_tracking_data=flow_tracking_data).format() == expected_result

    # 'a.x.x.flt' is a felt, so accessing a member of it must fail. The call is expected
    # to raise, so its result is not compared against anything. The message is matched
    # literally ('*' and '.' are escaped since 'match' is a regular expression).
    definition = OffsetReferenceDefinition(parent=main_reference,
                                           identifier_values=identifiers,
                                           member_path=scope('x.x.flt.x'))
    with pytest.raises(
            DefinitionError,
            match=r'Member access requires a type of the form Struct\*\.'):
        definition.eval(
            reference_manager=reference_manager,
            flow_tracking_data=flow_tracking_data).format()
Ejemplo n.º 10
0
    def handle_ReferenceDefinition(self, name: str,
                                   identifier: ReferenceDefinition,
                                   scope: ScopedName,
                                   set_value: Optional[MaybeRelocatable]):
        # In set mode, take the address of the given reference instead.
        reference = self._context.flow_tracking_data.resolve_reference(
            reference_manager=self._context.reference_manager,
            name=identifier.full_name)

        if set_value is None:
            expr = reference.eval(self._context.flow_tracking_data.ap_tracking)
            expr, expr_type = simplify_type_system(expr)
            if isinstance(expr_type, TypeStruct):
                # If the reference is of type T, take its address and treat it as T*.
                assert isinstance(expr, ExprDeref), \
                    f"Expected expression of type '{expr_type.format()}' to have an address."
                expr = expr.addr
                expr_type = TypePointer(pointee=expr_type)
            val = self._context.evaluator(expr)

            # Check if the type is felt* or any_type**.
            is_pointer_to_felt_or_pointer = (
                isinstance(expr_type, TypePointer)
                and isinstance(expr_type.pointee, (TypePointer, TypeFelt)))
            if isinstance(expr_type,
                          TypeFelt) or is_pointer_to_felt_or_pointer:
                return val
            else:
                # Typed reference, return VmConstsReference which allows accessing members.
                assert isinstance(expr_type, TypePointer) and \
                    isinstance(expr_type.pointee, TypeStruct), \
                    'Type must be of the form T*.'
                return VmConstsReference(
                    context=self._context,
                    accessible_scopes=[expr_type.pointee.scope],
                    reference_value=val,
                    add_addr_var=True)
        else:
            assert str(scope[-1:]) == name, 'Expecting scope to end with name.'
            value, value_type = simplify_type_system(reference.value)
            assert isinstance(value, ExprDeref), f"""\
{scope} (= {value.format()}) does not reference memory and cannot be assigned."""

            value_ref = Reference(
                pc=reference.pc,
                value=ExprCast(expr=value.addr, dest_type=value_type),
                ap_tracking_data=reference.ap_tracking_data,
            )

            addr = self._context.evaluator(
                value_ref.eval(self._context.flow_tracking_data.ap_tracking))
            self._context.memory[addr] = set_value
Ejemplo n.º 11
0
def create_simple_ref_expr(reg: Register, offset: int, cairo_type: CairoType,
                           location: Optional[Location]) -> Expression:
    """
    Builds the expression '[cast(reg + offset, cairo_type*)]'.

    Every node created, including the pointer type, is tagged with 'location'.
    """
    address = ExprOperator(
        a=ExprReg(reg=reg, location=location),
        op='+',
        b=ExprConst(val=offset, location=location),
        location=location)
    pointer_type = TypePointer(pointee=cairo_type, location=location)
    typed_address = ExprCast(expr=address, dest_type=pointer_type, location=location)
    return ExprDeref(addr=typed_address, location=location)
def test_type_tuples():
    t = TypeStruct(scope=scope('T'), is_fully_resolved=True)
    t_star = TypePointer(pointee=t)

    # Simple tuple.
    flat_type = TypeTuple(members=[TypeFelt(), t, t_star])
    simplify_type_system_test(
        '(fp, [cast(fp, T*)], cast(fp,T*))', '(fp, [fp], fp)', flat_type)

    # Nested.
    nested_type = TypeTuple(members=[
        TypeFelt(),
        TypeTuple(members=[]),
        TypeTuple(members=[t]),
    ])
    simplify_type_system_test(
        '(fp, (), ([cast(fp, T*)],))', '(fp, (), ([fp],))', nested_type)
Ejemplo n.º 13
0
def test_type_tuples():
    t = TypeStruct(scope=scope('T'), is_fully_resolved=False)
    t_star = TypePointer(pointee=t)

    def check(source, simplified, expected_type):
        # Simplifying 'source' must yield 'simplified' with type 'expected_type'.
        assert simplify_type_system(parse_expr(source)) == (
            parse_expr(simplified), expected_type)

    # Simple tuple.
    check('(fp, [cast(fp, T*)], cast(fp,T*))', '(fp, [fp], fp)',
          TypeTuple(members=[TypeFelt(), t, t_star]))

    # Nested.
    check('(fp, (), ([cast(fp, T*)],))', '(fp, (), ([fp],))',
          TypeTuple(members=[
              TypeFelt(),
              TypeTuple(members=[]),
              TypeTuple(members=[t]),
          ]))
Ejemplo n.º 14
0
def test_reference_to_structs():
    struct_t = TypeStruct(scope=scope('T'), is_fully_resolved=True)
    pointer_t = TypePointer(pointee=struct_t)
    # 'ref' has type T; T is a 4-cell struct whose member x (a T*) sits at offset 3.
    identifier_values = {
        scope('ref'): ReferenceDefinition(
            full_name=scope('ref'), cairo_type=struct_t, references=[]),
        scope('T'): StructDefinition(
            full_name=scope('T'),
            members={'x': MemberDefinition(offset=3, cairo_type=pointer_t)},
            size=4,
        ),
    }
    reference_manager = ReferenceManager()
    ref_value = mark_types_in_expr_resolved(parse_expr('[cast(100, T*)]'))
    initial_flow = FlowTrackingDataActual(ap_tracking=RegTrackingData())
    flow_tracking_data = initial_flow.add_reference(
        reference_manager=reference_manager,
        name=scope('ref'),
        ref=Reference(
            pc=0,
            value=ref_value,
            ap_tracking_data=RegTrackingData(group=0, offset=2),
        ),
    )
    memory = {103: 200}
    consts = get_vm_consts(
        identifier_values, reference_manager, flow_tracking_data, memory=memory)

    # ref lives at 100, so ref.x is read from 103 and points to 200.
    assert consts.ref.address_ == 100
    assert consts.ref.x.address_ == 200
    # Set the pointer ref.x.x (stored at 200 + 3) to 300.
    consts.ref.x.x = 300
    assert memory[203] == 300
    assert consts.ref.x.x.address_ == 300

    assert consts.ref.type_ == consts.T
Ejemplo n.º 15
0
import pytest

from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer
from starkware.cairo.lang.compiler.ast.expr import ExprIdentifier
from starkware.cairo.lang.compiler.error_handling import InputFile, Location
from starkware.cairo.lang.compiler.identifier_definition import (
    IdentifierDefinition, MemberDefinition, StructDefinition)
from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager
from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition
from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError
from starkware.cairo.lang.compiler.scoped_name import ScopedName
from starkware.cairo.lang.compiler.type_casts import FELT_STAR
from starkware.starknet.compiler.calldata_parser import process_calldata

# Shorthand for building a ScopedName from a dotted string, e.g. scope('MyStruct').
scope = ScopedName.from_string
# Pointer to FELT_STAR, i.e. the type felt**.
FELT_STAR_STAR = TypePointer(pointee=FELT_STAR)


def dummy_location():
    """Returns a fixed placeholder Location (lines 1-3, columns 2-4) in an empty, unnamed input file."""
    placeholder_file = InputFile(filename=None, content='')
    return Location(
        start_line=1, start_col=2, end_line=3, end_col=4, input_file=placeholder_file)


def process_test_calldata(members: Dict[str, MemberDefinition],
                          has_range_check_builtin=True):
    identifier_values: Dict[ScopedName, IdentifierDefinition] = {
        scope('MyStruct'):
        StructDefinition(
Ejemplo n.º 16
0
def test_references():
    reference_manager = ReferenceManager()
    # Each entry: (name, expression, ap offset at definition, whether the expression
    # contains struct types that must be marked as resolved).
    reference_specs = [
        ('x.ref', '[ap + 1]', 2, False),
        ('x.ref2', '[ap + 1] + 0', 2, False),
        ('x.ref3', 'ap + 1', 2, False),
        ('x.typeref', 'cast(ap + 1, MyStruct*)', 3, True),
        ('x.typeref2', 'cast([ap + 1], MyStruct*)', 3, True),
    ]
    references = {}
    for ref_name, ref_expr, ap_offset, is_typed in reference_specs:
        parsed = parse_expr(ref_expr)
        if is_typed:
            parsed = mark_types_in_expr_resolved(parsed)
        references[scope(ref_name)] = reference_manager.alloc_id(
            Reference(
                pc=0,
                value=parsed,
                ap_tracking_data=RegTrackingData(group=0, offset=ap_offset),
            ))

    my_struct_star = TypePointer(
        pointee=TypeStruct(scope=scope('MyStruct'), is_fully_resolved=True))
    identifier_values = {}
    for ref_name, _, _, is_typed in reference_specs:
        identifier_values[scope(ref_name)] = ReferenceDefinition(
            full_name=scope(ref_name),
            cairo_type=my_struct_star if is_typed else TypeFelt(),
            references=[])
    identifier_values[scope('MyStruct')] = StructDefinition(
        full_name=scope('MyStruct'),
        members={'member': MemberDefinition(offset=10, cairo_type=TypeFelt())},
        size=11,
    )

    prime = 2**64 + 13
    ap = 100
    fp = 200
    # The references are read at ap offset 4: those defined at offset 2 see
    # 'ap' as ap - 2, those defined at offset 3 see it as ap - 1.
    memory = {
        (ap - 2) + 1: 1234,
        (ap - 1) + 1: 1000,
        (ap - 1) + 1 + 2: 13,
        (ap - 1) + 1 + 10: 17,
    }

    flow_tracking_data = FlowTrackingDataActual(
        ap_tracking=RegTrackingData(group=0, offset=4),
        reference_ids=references,
    )
    context = VmConstsContext(
        identifiers=IdentifierManager.from_dict(identifier_values),
        evaluator=ExpressionEvaluator(prime, ap, fp, memory).eval,
        reference_manager=reference_manager,
        flow_tracking_data=flow_tracking_data,
        memory=memory,
        pc=0)
    consts = VmConsts(context=context, accessible_scopes=[ScopedName()])

    # Reading references.
    assert consts.x.ref == memory[(ap - 2) + 1]
    assert consts.x.typeref.address_ == (ap - 1) + 1
    assert consts.x.typeref.member == memory[(ap - 1) + 1 + 10]
    for missing_member in ('abc', 'SIZE'):
        with pytest.raises(IdentifierError,
                           match=f"'{missing_member}' is not a member of 'MyStruct'."):
            getattr(consts.x.typeref, missing_member)

    # Struct definitions and their constants are read-only.
    with pytest.raises(
            AssertionError,
            match='Cannot change the value of a struct definition.'):
        consts.MyStruct = 13

    for const_name, const_value in (('member', 10), ('SIZE', 11)):
        assert getattr(consts.MyStruct, const_name) == const_value
        with pytest.raises(AssertionError,
                           match='Cannot change the value of a constant.'):
            setattr(consts.MyStruct, const_name, 13)

    with pytest.raises(IdentifierError,
                       match="'abc' is not a member of 'MyStruct'."):
        consts.MyStruct.abc

    # Test that VmConsts can be used to assign values to references of the form '[...]'.
    memory.clear()
    consts.x.ref = 1234
    assert memory == {(ap - 2) + 1: 1234}

    memory.clear()
    consts.x.typeref.member = 1001
    assert memory == {(ap - 1) + 1 + 10: 1001}

    memory.clear()
    consts.x.typeref2 = 4321
    assert memory == {(ap - 1) + 1: 4321}

    # typeref2 now points to 4321, so its member is written at 4321 + 10.
    consts.x.typeref2.member = 1
    assert memory == {
        (ap - 1) + 1: 4321,
        4321 + 10: 1,
    }

    with pytest.raises(AssertionError,
                       match='Cannot change the value of a scope definition'):
        consts.x = 1000
    with pytest.raises(
            AssertionError,
            match=r'x.ref2 \(= \[ap \+ 1\] \+ 0\) does not reference memory and cannot be assigned.',
    ):
        consts.x.ref2 = 1000
    with pytest.raises(
            AssertionError,
            match=r'x.typeref \(= ap \+ 1\) does not reference memory and cannot be assigned.',
    ):
        consts.x.typeref = 1000
 def visit_ExprAddressOf(
         self, expr: ExprAddressOf) -> Tuple[Expression, TypePointer]:
     """Simplifies '&x': returns the address of the simplified inner expression, typed as a pointer to the inner type."""
     simplified_inner, pointee_type = self.visit(expr.expr)
     address = get_expr_addr(simplified_inner)
     return address, TypePointer(pointee=pointee_type)
Ejemplo n.º 18
0
import itertools
from typing import Optional

from starkware.cairo.lang.compiler.ast.cairo_types import (CairoType, CastType,
                                                           TypeFelt,
                                                           TypePointer,
                                                           TypeStruct,
                                                           TypeTuple)
from starkware.cairo.lang.compiler.ast.expr import ExprDeref, Expression, ExprTuple
from starkware.cairo.lang.compiler.error_handling import LocationError
from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager
from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition

# The type felt* (pointer to felt).
FELT_STAR = TypePointer(pointee=TypeFelt())


class CairoTypeError(LocationError):
    """Type error in a Cairo expression (a LocationError subclass)."""


def check_cast(src_type: CairoType,
               dest_type: CairoType,
               identifier_manager: IdentifierManager,
               expr: Optional[Expression] = None,
               cast_type: CastType = CastType.EXPLICIT) -> bool:
    """
    Returns true if the given expression can be casted from src_type to dest_type
    according to the given 'cast_type'.
    In some cases of cast failure, an exception with more specific details is raised.

    'expr' must be specified (not None) when CastType.EXPLICIT (or above) is used.
def test_type_subscript_op():
    felt_star_star = TypePointer(pointee=TypePointer(pointee=TypeFelt()))
    t = TypeStruct(scope=scope('T'), is_fully_resolved=True)
    t_star = TypePointer(pointee=t)

    # T has size 7, so subscripting a T* scales the offset by 7.
    identifiers = IdentifierManager.from_dict({
        scope('T'): StructDefinition(full_name=scope('T'), members={}, size=7)
    })
    with_identifiers = {'identifiers': identifiers}

    # Pointer subscripts: (expression, expected simplification, expected type, extra kwargs).
    pointer_cases = [
        ('cast(fp, felt*)[3]', '[fp + 3 * 1]', TypeFelt(), {}),
        ('cast(fp, felt***)[0]', '[fp + 0 * 1]', felt_star_star, {}),
        ('[cast(fp, T****)][ap][ap]', '[[[fp] + ap * 1] + ap * 1]', t_star, {}),
        ('cast(fp, T**)[1][2]', '[[fp + 1 * 1] + 2 * 7]', t, with_identifiers),
    ]
    for source, simplified, expected_type, kwargs in pointer_cases:
        simplify_type_system_test(source, simplified, expected_type, **kwargs)

    # Test that 'cast(fp, T*)[2 * ap + 3]' simplifies into '[fp + (2 * ap + 3) * 7]', but without
    # the parentheses.
    expected_expr = remove_parentheses(parse_expr('[fp + (2 * ap + 3) * 7]'))
    assert simplify_type_system(
        mark_types_in_expr_resolved(parse_expr('cast(fp, T*)[2 * ap + 3]')),
        identifiers=identifiers) == (expected_expr, t)

    # Subscript operator for tuples (the last case: T*, T, felt occupy 1 + 7 + 1 = 9 cells).
    tuple_cases = [
        ('(cast(fp, felt**), fp, cast(fp, T*))[2]', 'fp', t_star, {}),
        ('(cast(fp, felt**), fp, cast(fp, T*))[0]', 'fp', felt_star_star, {}),
        ('(cast(fp, felt**), ap, cast(fp, T*))[3*4 - 11]', 'ap', TypeFelt(), {}),
        ('[cast(ap, (felt, felt)*)][0]', '[ap + 0]', TypeFelt(), {}),
        ('[cast(ap, (T*, T, felt, T*, felt*)*)][3]', '[ap + 9]', t_star,
         with_identifiers),
    ]
    for source, simplified, expected_type, kwargs in tuple_cases:
        simplify_type_system_test(source, simplified, expected_type, **kwargs)

    # Test failures: (expression, expected error output, extra kwargs).
    failure_cases = [
        ('(fp, fp, fp)[cast(ap, felt*)]', """
file:?:?: Cannot apply subscript-operator with offset of non-felt type 'felt*'.
(fp, fp, fp)[cast(ap, felt*)]
             ^*************^
""", {}),
        ('(fp, fp, fp)[[ap]]', """
file:?:?: Subscript-operator for tuples supports only constant offsets, found 'ExprDeref'.
(fp, fp, fp)[[ap]]
             ^**^
""", {}),
        # The simplifier in TypeSystemVisitor cannot access PRIME, so PyConsts are unsimplified.
        ('(fp, fp, fp)[%[1%]]', """
file:?:?: Subscript-operator for tuples supports only constant offsets, found 'ExprPyConst'.
(fp, fp, fp)[%[1%]]
             ^***^
""", {}),
        ('(fp, fp, fp)[3]', """
file:?:?: Tuple index 3 is out of range [0, 3).
(fp, fp, fp)[3]
^*************^
""", {}),
        ('[cast(fp, (T*, T, felt)*)][-1]', """
file:?:?: Tuple index -1 is out of range [0, 3).
[cast(fp, (T*, T, felt)*)][-1]
^****************************^
""", {}),
        ('cast(fp, felt)[0]', """
file:?:?: Cannot apply subscript-operator to non-pointer, non-tuple type 'felt'.
cast(fp, felt)[0]
^***************^
""", {}),
        ('[cast(fp, T*)][0]', """
file:?:?: Cannot apply subscript-operator to non-pointer, non-tuple type 'T'.
[cast(fp, T*)][0]
^***************^
""", {}),
        ('cast(fp, felt*)[[cast(ap, T*)]]', """
file:?:?: Cannot apply subscript-operator with offset of non-felt type 'T'.
cast(fp, felt*)[[cast(ap, T*)]]
                ^************^
""", {}),
        ('cast(fp, Z*)[0]', """
file:?:?: Unknown identifier 'Z'.
cast(fp, Z*)[0]
^*************^
""", with_identifiers),
        ('cast(fp, T*)[0]', """
file:?:?: Unknown identifier 'T'.
cast(fp, T*)[0]
^*************^
""", {'identifiers': None}),
    ]
    for source, expected_error, kwargs in failure_cases:
        verify_exception(source, expected_error, **kwargs)
Ejemplo n.º 20
0
def process_storage_var(visitor: IdentifierAwareVisitor,
                        elm: CodeElementFunction):
    """
    Validates a @storage_var declaration and generates the Cairo source of its
    accessor function bodies.

    The declaration must have an empty body, no implicit arguments, no decorators
    other than @STORAGE_VAR_DECORATOR, only felt arguments, and a return type that
    consists of felts and whose size does not exceed MAX_STORAGE_ITEM_SIZE.

    Returns the result of generate_storage_var_functions() for the generated
    addr/read/write function bodies (with is_impl=True).

    Raises PreprocessorError when one of the constraints above is violated.
    """
    # The declaration body may only contain empty lines.
    for commented_elm in elm.code_block.code_elements:
        inner_elm = commented_elm.code_elm
        if isinstance(inner_elm, CodeElementEmptyLine):
            continue
        # Not every code element carries a location; fall back to the function name.
        raise PreprocessorError(
            'Storage variables must have an empty body.',
            location=getattr(inner_elm, 'location', elm.identifier.location))

    if elm.implicit_arguments is not None:
        raise PreprocessorError(
            'Storage variables must have no implicit arguments.',
            location=elm.implicit_arguments.location)

    extra_decorator = next(
        (decorator for decorator in elm.decorators
         if decorator.name != STORAGE_VAR_DECORATOR),
        None)
    if extra_decorator is not None:
        raise PreprocessorError(
            'Storage variables must have no decorators in addition to '
            f'@{STORAGE_VAR_DECORATOR}.',
            location=extra_decorator.location)

    for typed_arg in elm.arguments.identifiers:
        typed_arg_type = typed_arg.get_type()
        if isinstance(typed_arg_type, TypeFelt):
            continue
        raise PreprocessorError(
            'Only felt arguments are supported in storage variables.',
            location=typed_arg_type.location)

    unresolved_return_type = get_return_type(elm=elm)
    return_type = visitor.resolve_type(unresolved_return_type)
    # Return-type errors point at the 'returns' clause when there is one.
    returns_location = (
        elm.identifier.location if elm.returns is None else elm.returns.location)

    felts_only = check_felts_only_type(
        cairo_type=return_type, identifier_manager=visitor.identifiers)
    if not felts_only:
        raise PreprocessorError(
            'The return type of storage variables must consist of felts.',
            location=returns_location)

    var_size = visitor.get_size(return_type)
    if var_size > MAX_STORAGE_ITEM_SIZE:
        raise PreprocessorError(
            f'The storage variable size ({var_size}) exceeds the maximum value '
            f'({MAX_STORAGE_ITEM_SIZE}).',
            location=returns_location)

    arg_names = [typed_arg.identifier.name for typed_arg in elm.arguments.identifiers]

    # addr(): start from the variable's base address and fold every argument into it
    # with hash2; when there are arguments, the result is passed to normalize_address.
    base_addr = storage_var_name_to_base_addr(elm.identifier.name)
    addr_lines = [f'let res = {base_addr}']
    addr_lines.extend(
        f'let (res) = hash2{{hash_ptr=pedersen_ptr}}(res, {name})' for name in arg_names)
    if arg_names:
        addr_lines.append('let (res) = normalize_address(addr=res)')
    addr_lines.append('return (res=res)')
    addr_func_body = '\n'.join(addr_lines) + '\n'

    storage_addr_line = f'let (storage_addr) = addr({", ".join(arg_names)})\n'

    # read(): read var_size consecutive cells starting at storage_addr.
    read_parts = [storage_addr_line]
    read_parts.extend(
        f'let (__storage_var_temp{i}) = storage_read(address=storage_addr + {i})\n'
        for i in range(var_size))
    # Copy the return implicit args and the return values to a contiguous segment.
    read_parts.append("""
tempvar storage_ptr = storage_ptr
tempvar range_check_ptr = range_check_ptr
tempvar pedersen_ptr = pedersen_ptr
""")
    read_parts.extend(
        f'tempvar __storage_var_temp{i} : felt = __storage_var_temp{i}\n'
        for i in range(var_size))
    # Reinterpret the copied cells as a value of the declared return type.
    return_ptr_type = TypePointer(pointee=unresolved_return_type)
    read_parts.append(
        f'return ([cast(&__storage_var_temp0, {return_ptr_type.format()})])')
    read_func_body = ''.join(read_parts)

    # write(): store each cell of 'value' at the matching storage offset.
    write_func_body = (
        storage_addr_line
        + ''.join(
            f'storage_write(address=storage_addr + {i}, value=[cast(&value, felt) + {i}])\n'
            for i in range(var_size))
        + 'return ()\n')

    return generate_storage_var_functions(
        elm,
        addr_func_body=addr_func_body,
        read_func_body=read_func_body,
        write_func_body=write_func_body,
        is_impl=True)
    def create_func_wrapper(self, elm: CodeElementFunction, func_alias_name: str):
        """
        Generates a wrapper that converts between the StarkNet contract ABI and the
        Cairo calling convention.

        The wrapper does the following, in order:
        - reads the OS context pointer, the calldata size and the calldata pointer from
          the caller's frame;
        - exposes every OS context entry as a reference;
        - decodes the calldata and calls the wrapped function;
        - encodes the function's return values as retdata and returns;
        - registers an ABI entry for the function.

        Arguments:
        elm - the CodeElementFunction of the wrapped function.
        func_alias_name - an alias for the FunctionDefinition in the current scope.

        Raises PreprocessorError if the wrapped function declares an implicit argument
        that is not part of the OS context, and NotImplementedError if its external
        decorator is not one of the supported kinds.
        """

        os_context = self.get_os_context()

        func_location = elm.identifier.location
        assert func_location is not None

        # We expect the call stack to look as follows:
        # pointer to os_context struct.
        # calldata size.
        # pointer to the call data array.
        # ret_fp.
        # ret_pc.
        # Hence the first three entries are read at [fp - 5], [fp - 4] and [fp - 3].
        os_context_ptr = ExprDeref(
            addr=ExprOperator(
                ExprReg(reg=Register.FP, location=func_location),
                '+',
                ExprConst(-5, location=func_location),
                location=func_location),
            location=func_location)

        calldata_size = ExprDeref(
            addr=ExprOperator(
                ExprReg(reg=Register.FP, location=func_location),
                '+',
                ExprConst(-4, location=func_location),
                location=func_location),
            location=func_location)

        calldata_ptr = ExprDeref(
            addr=ExprOperator(
                ExprReg(reg=Register.FP, location=func_location),
                '+',
                ExprConst(-3, location=func_location),
                location=func_location),
            location=func_location)

        implicit_arguments: Optional[ArgList] = None

        # Implicit arguments of the wrapped function, keyed by name. Each of them must be
        # an OS context entry.
        implicit_arguments_identifiers: Dict[str, TypedIdentifier] = {}
        if elm.implicit_arguments is not None:
            args = []
            for typed_identifier in elm.implicit_arguments.identifiers:
                ptr_name = typed_identifier.name
                if ptr_name not in os_context:
                    raise PreprocessorError(
                        f"Unexpected implicit argument '{ptr_name}' in an external function.",
                        location=typed_identifier.identifier.location)

                implicit_arguments_identifiers[ptr_name] = typed_identifier

                # Add the assignment expression 'ptr_name = ptr_name' to the implicit arg list.
                args.append(ExprAssignment(
                    identifier=typed_identifier.identifier,
                    expr=typed_identifier.identifier,
                    location=typed_identifier.location,
                ))

            implicit_arguments = ArgList(
                args=args, notes=[], has_trailing_comma=True,
                location=elm.implicit_arguments.location)

        return_args_exprs: List[Expression] = []

        # Create references.
        for ptr_name, index in os_context.items():
            ref_name = self.current_scope + ptr_name

            # OS context entries that the function does not take as implicit arguments
            # are exposed as felts. Entries it does take get the declared (resolved) type.
            arg_identifier = implicit_arguments_identifiers.get(ptr_name)
            if arg_identifier is None:
                location: Optional[Location] = func_location
                cairo_type: CairoType = TypeFelt(location=location)
            else:
                location = arg_identifier.location
                cairo_type = self.resolve_type(arg_identifier.get_type())

            # Add a reference of the form
            # 'let ref_name = [cast(os_context_ptr + index, cairo_type*)]'.
            self.add_reference(
                name=ref_name,
                value=ExprDeref(
                    addr=ExprCast(
                        ExprOperator(
                            os_context_ptr, '+', ExprConst(index, location=location),
                            location=location),
                        dest_type=TypePointer(pointee=cairo_type, location=cairo_type.location),
                        location=cairo_type.location),
                    location=location),
                cairo_type=cairo_type,
                location=location,
                require_future_definition=False)

            # The OS context values are pushed back in iteration order, so the indices
            # must be consecutive (0, 1, 2, ...) in that order.
            assert index == len(return_args_exprs), 'Unexpected index.'

            return_args_exprs.append(ExprIdentifier(name=ptr_name, location=func_location))

        arg_struct_def = self.get_struct_definition(
            name=ScopedName.from_string(func_alias_name) + CodeElementFunction.ARGUMENT_SCOPE,
            location=func_location)

        code_elements, call_args = process_calldata(
            calldata_ptr=calldata_ptr,
            calldata_size=calldata_size,
            identifiers=self.identifiers,
            struct_def=arg_struct_def,
            has_range_check_builtin='range_check_ptr' in os_context,
            location=func_location,
        )

        for code_element in code_elements:
            self.visit(code_element)

        self.visit(CodeElementFuncCall(
            func_call=RvalueFuncCall(
                func_ident=ExprIdentifier(name=func_alias_name, location=func_location),
                arguments=call_args,
                implicit_arguments=implicit_arguments,
                location=func_location)))

        # After the call, the return struct is referenced at ap - ret_struct_def.size.
        ret_struct_name = ScopedName.from_string(func_alias_name) + CodeElementFunction.RETURN_SCOPE
        ret_struct_type = self.resolve_type(TypeStruct(ret_struct_name, False))
        ret_struct_def = self.get_struct_definition(
            name=ret_struct_name,
            location=func_location)
        ret_struct_expr = create_simple_ref_expr(
            reg=Register.AP, offset=-ret_struct_def.size, cairo_type=ret_struct_type,
            location=func_location)
        self.add_reference(
            name=self.current_scope + 'ret_struct',
            value=ret_struct_expr,
            cairo_type=ret_struct_type,
            require_future_definition=False,
            location=func_location)

        # Add function return values.
        retdata_size, retdata_ptr = self.process_retdata(
            ret_struct_ptr=ExprIdentifier(name='ret_struct'),
            ret_struct_type=ret_struct_type, struct_def=ret_struct_def,
            location=func_location,
        )
        return_args_exprs += [retdata_size, retdata_ptr]

        # Push the return values.
        self.push_compound_expressions(
            compound_expressions=[self.simplify_expr_as_felt(expr) for expr in return_args_exprs],
            location=func_location,
        )

        # Add a ret instruction.
        self.visit(CodeElementInstruction(
            instruction=InstructionAst(
                body=RetInstruction(),
                inc_ap=False,
                location=func_location)))

        # Add an entry to the ABI.
        external_decorator = self.get_external_decorator(elm)
        assert external_decorator is not None
        is_view = external_decorator.name == VIEW_DECORATOR

        # L1 handlers use the decorator name itself as the ABI entry type. External and
        # view functions are both 'function' entries. Any other decorator is unsupported.
        if external_decorator.name == L1_HANDLER_DECORATOR:
            entry_type = L1_HANDLER_DECORATOR
        elif external_decorator.name in (EXTERNAL_DECORATOR, VIEW_DECORATOR):
            entry_type = 'function'
        else:
            raise NotImplementedError(f'Unsupported decorator {external_decorator.name}')

        self.add_abi_entry(
            name=elm.name, arg_struct_def=arg_struct_def, ret_struct_def=ret_struct_def,
            is_view=is_view, entry_type=entry_type)
Ejemplo n.º 22
0
 def type_pointer(self, value, meta):
     """
     Returns a TypePointer whose pointee is value[0] and whose location is
     self.meta2loc(meta).

     NOTE(review): this looks like a parse-tree transformer callback (value holding
     the child nodes, meta the source position). Confirm against the enclosing class.
     """
     pointee_type = value[0]
     pointer_location = self.meta2loc(meta)
     return TypePointer(pointee=pointee_type, location=pointer_location)
def test_type_dot_op():
    """
    Tests type_system_visitor for ExprDot-s, in the following struct architecture:

    struct S:
        member x : felt
        member y : felt
    end

    struct T:
        member t : felt
        member s : S
        member sp : S*
    end

    struct R:
        member r : R*
    end
    """
    s = TypeStruct(scope=scope('S'), is_fully_resolved=True)
    s_star = TypePointer(pointee=s)
    r = TypeStruct(scope=scope('R'), is_fully_resolved=True)
    r_star = TypePointer(pointee=r)

    # Struct definitions matching the docstring. S occupies two cells, so in T the
    # member 's' spans offsets 1-2 and 'sp' starts at offset 3.
    identifier_dict = {
        scope('T'):
        StructDefinition(
            full_name=scope('T'),
            members={
                't': MemberDefinition(offset=0, cairo_type=TypeFelt()),
                's': MemberDefinition(offset=1, cairo_type=s),
                'sp': MemberDefinition(offset=3, cairo_type=s_star),
            },
            size=4,
        ),
        scope('S'):
        StructDefinition(
            full_name=scope('S'),
            members={
                'x': MemberDefinition(offset=0, cairo_type=TypeFelt()),
                'y': MemberDefinition(offset=1, cairo_type=TypeFelt()),
            },
            size=2,
        ),
        scope('R'):
        StructDefinition(
            full_name=scope('R'),
            members={
                'r': MemberDefinition(offset=0, cairo_type=r_star),
            },
            size=1,
        ),
    }

    identifiers = IdentifierManager.from_dict(identifier_dict)

    # Successful simplifications: (original expression, expected simplified
    # expression, expected type).
    for (orig_expr, simplified_expr, simplified_type) in [
        ('[cast(fp, T*)].t', '[fp]', TypeFelt()),
        ('[cast(fp, T*)].s', '[fp + 1]', s),
        ('[cast(fp, T*)].sp', '[fp + 3]', s_star),
        ('[cast(fp, T*)].s.x', '[fp + 1]', TypeFelt()),
        ('[cast(fp, T*)].s.y', '[fp + 1 + 1]', TypeFelt()),
        ('[[cast(fp, T*)].sp].x', '[[fp + 3]]', TypeFelt()),
        ('[cast(fp, R*)]', '[fp]', r),
        ('[cast(fp, R*)].r', '[fp]', r_star),
        ('[[[cast(fp, R*)].r].r].r', '[[[fp]]]', r_star),
        # Test . as ->
        ('cast(fp, T*).t', '[fp]', TypeFelt()),
        ('cast(fp, T*).sp.y', '[[fp + 3] + 1]', TypeFelt()),
        ('cast(fp, R*).r.r.r', '[[[fp]]]', r_star),
        # More tests.
        ('(cast(fp, T*).s)', '[fp + 1]', s),
        ('(cast(fp, T*).s).x', '[fp + 1]', TypeFelt()),
        ('(&(cast(fp, T*).s)).x', '[fp + 1]', TypeFelt())
    ]:
        simplify_type_system_test(orig_expr,
                                  simplified_expr,
                                  simplified_type,
                                  identifiers=identifiers)

    # Test failures. Each case pins the exact error message together with the caret
    # span that marks the offending expression.

    verify_exception('cast(fp, felt).x',
                     """
file:?:?: Cannot apply dot-operator to non-struct type 'felt'.
cast(fp, felt).x
^**************^
""",
                     identifiers=identifiers)

    verify_exception('cast(fp, felt*).x',
                     """
file:?:?: Cannot apply dot-operator to pointer-to-non-struct type 'felt*'.
cast(fp, felt*).x
^***************^
""",
                     identifiers=identifiers)

    verify_exception('cast(fp, T*).x',
                     """
file:?:?: Member 'x' does not appear in definition of struct 'T'.
cast(fp, T*).x
^************^
""",
                     identifiers=identifiers)

    verify_exception('cast(fp, Z*).x',
                     """
file:?:?: Unknown identifier 'Z'.
cast(fp, Z*).x
^************^
""",
                     identifiers=identifiers)

    # Without an identifier manager, dot-operator expressions cannot be simplified.
    verify_exception('cast(fp, T*).x',
                     """
file:?:?: Identifiers must be initialized for type-simplification of dot-operator expressions.
cast(fp, T*).x
^************^
""",
                     identifiers=None)

    # With type resolution disabled, the unresolved 'Z' reaches the visitor as-is.
    verify_exception('cast(fp, Z*).x',
                     """
file:?:?: Type is expected to be fully resolved at this point.
cast(fp, Z*).x
^************^
""",
                     identifiers=identifiers,
                     resolve_types=False)