예제 #1
0
def test_get_struct_definition():
    """
    Checks that get_struct_definition returns a struct's members in declaration order together
    with its size, and raises for identifiers that are not structs or that do not exist.
    """
    identifier_dict = {
        scope('T'):
        StructDefinition(
            full_name=scope('T'),
            members={
                'a': MemberDefinition(offset=0, cairo_type=TypeFelt()),
                'b': MemberDefinition(offset=1, cairo_type=TypeFelt()),
            },
            size=2,
        ),
        scope('MyConst'):
        ConstDefinition(value=5),
    }

    manager = IdentifierManager.from_dict(identifier_dict)

    struct_def = get_struct_definition(ScopedName.from_string('T'), manager)

    # Convert to a list, to check the order of the elements in the dict.
    assert list(struct_def.members.items()) == [
        ('a', MemberDefinition(offset=0, cairo_type=TypeFelt())),
        ('b', MemberDefinition(offset=1, cairo_type=TypeFelt())),
    ]

    assert struct_def.size == 2

    # 'match' is a regular expression; the expected messages contain '.', so escape them to
    # match the literal text only.
    with pytest.raises(
            DefinitionError,
            match=re.escape("Expected 'MyConst' to be a struct. Found: 'const'.")):
        get_struct_definition(scope('MyConst'), manager)

    with pytest.raises(MissingIdentifierError,
                       match=re.escape("Unknown identifier 'abc'.")):
        get_struct_definition(scope('abc'), manager)
예제 #2
0
def check_felts_only_type(cairo_type: CairoType,
                          identifier_manager: IdentifierManager) -> bool:
    """
    Returns whether 'cairo_type' is a felts-only type.

    A felts-only type is either felt, or a struct or tuple whose members are all felts-only
    types (checked recursively). Every other type is not felts-only.
    """
    if isinstance(cairo_type, TypeFelt):
        return True

    if isinstance(cairo_type, TypeStruct):
        definition = get_struct_definition(
            cairo_type.resolved_scope, identifier_manager=identifier_manager)
        member_types = (member.cairo_type for member in definition.members.values())
    elif isinstance(cairo_type, TypeTuple):
        member_types = cairo_type.members
    else:
        return False

    # all() stops at the first member that is not felts-only.
    return all(
        check_felts_only_type(member_type, identifier_manager=identifier_manager)
        for member_type in member_types)
예제 #3
0
 def get_size(self, cairo_type: CairoType):
     """
     Returns the size of 'cairo_type'.

     Felts and pointers have size 1, structs have the size recorded in their definition, and
     tuples have the sum of their members' sizes. Raises NotImplementedError for other types.
     """
     if isinstance(cairo_type, (TypeFelt, TypePointer)):
         return 1

     if isinstance(cairo_type, TypeStruct):
         if not cairo_type.is_fully_resolved:
             # Not-yet-resolved struct names are delegated to get_struct_size.
             return self.get_struct_size(struct_name=cairo_type.scope,
                                         location=cairo_type.location)
         try:
             struct_def = get_struct_definition(
                 struct_name=cairo_type.scope,
                 identifier_manager=self.identifiers)
             return struct_def.size
         except DefinitionError as exc:
             # Report definition problems as preprocessor errors at the type's location.
             raise PreprocessorError(str(exc), location=cairo_type.location)

     if isinstance(cairo_type, TypeTuple):
         return sum(map(self.get_size, cairo_type.members))

     raise NotImplementedError(
         f'Type {type(cairo_type).__name__} is not supported.')
예제 #4
0
 def build_struct(self, name: ScopedName):
     """
     Returns a namedtuple class for the Cairo struct 'name'.

     The class is named after the last component of the struct's full name, and its fields are
     the struct's member names, in the order the struct definition lists them.
     """
     resolved_name = self._get_full_name(name)
     struct_members = get_struct_definition(resolved_name, self.identifiers).members
     type_name = resolved_name.path[-1]
     return namedtuple(type_name, [*struct_members])
예제 #5
0
    def build_func_args(self, func: ScopedName):
        """
        Returns a namedtuple class holding the implicit arguments of 'func' followed by its
        explicit arguments.

        The class is named '<last component of func>_full_args'. A name that appears in both
        argument structs becomes a single field, at its position among the implicit arguments.
        """
        full_name = self._get_full_name(func)

        def arg_members(arg_scope):
            # Members of the argument struct nested under the function's full name.
            return get_struct_definition(full_name + arg_scope, self.identifiers).members

        implicit_args = arg_members(CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE)
        explicit_args = arg_members(CodeElementFunction.ARGUMENT_SCOPE)
        field_names = list({**implicit_args, **explicit_args})
        return namedtuple(f'{func[-1:]}_full_args', field_names)

    def visit_ExprDot(self, expr: ExprDot) -> Tuple[ExprDeref, CairoType]:
        """
        Simplifies a member access 'x.member' into a dereference of the member's address.

        The inner expression may be a struct, whose address is taken with get_expr_addr, or a
        pointer to a struct, in which case '.' acts as '->' (one level only). Returns the
        resulting ExprDeref together with the member's type.

        Raises CairoTypeError for non-struct operands, r-value (tuple) structs, struct lookup
        failures and unknown members.
        """
        self.verify_identifier_manager_initialized(location=expr.location)

        struct_addr, struct_type = self.visit(expr.expr)
        if isinstance(struct_type, TypeStruct):
            if isinstance(struct_addr, ExprTuple):
                raise CairoTypeError(
                    'Accessing struct members for r-value structs is not supported yet.',
                    location=expr.location)
            # Use the struct's address, so that '.' can be evaluated like '->'.
            struct_addr = get_expr_addr(struct_addr)
        elif isinstance(struct_type, TypePointer):
            if not isinstance(struct_type.pointee, TypeStruct):
                raise CairoTypeError(
                    f'Cannot apply dot-operator to pointer-to-non-struct type '
                    f"'{struct_type.format()}'.",
                    location=expr.location)
            # The pointer value already is the struct's address ('.' as '->', once).
            struct_type = struct_type.pointee
        else:
            raise CairoTypeError(
                f"Cannot apply dot-operator to non-struct type '{struct_type.format()}'.",
                location=expr.location)

        try:
            struct_def = get_struct_definition(
                struct_name=struct_type.resolved_scope,
                identifier_manager=self.identifiers)
        except Exception as exc:
            raise CairoTypeError(str(exc), location=expr.location)

        member = struct_def.members.get(expr.member.name)
        if member is None:
            raise CairoTypeError(
                f"Member '{expr.member.name}' does not appear in definition of struct "
                f"'{struct_type.format()}'.",
                location=expr.location)

        # Offset 0 dereferences the struct address directly; otherwise add the offset to it.
        member_addr = struct_addr if member.offset == 0 else ExprOperator(
            a=struct_addr,
            op='+',
            b=ExprConst(val=member.offset, location=expr.location),
            location=expr.location)
        return ExprDeref(addr=member_addr, location=expr.location), member.cairo_type
예제 #7
0
def check_main_args(program: Program):
    """
    Makes sure that for every builtin included in the program an appropriate ptr was passed as an
    argument to main() and is subsequently returned.

    The expected list is '<builtin>_ptr' for each builtin in program.builtins, in order. It is
    compared with main's implicit arguments followed by its explicit arguments, and with main's
    implicit arguments followed by its return values.

    If '__main__.main.ImplicitArgs' cannot be found, nothing is checked. If '__main__.main.Args'
    or '__main__.main.Return' cannot be found, the corresponding comparison is skipped.

    Raises AssertionError on a mismatch. This is an explicit raise rather than an 'assert'
    statement, so the check still runs when Python is started with -O.
    """
    expected_builtin_ptrs = [f'{builtin_name}_ptr' for builtin_name in program.builtins]

    try:
        implicit_args = list(
            get_struct_definition(
                struct_name=ScopedName.from_string('__main__.main.ImplicitArgs'),
                identifier_manager=program.identifiers).members)
    except IdentifierError:
        # NOTE(review): presumably there is no main() to validate here; confirm intent.
        return

    # Pairs of (struct name under __main__.main, expectation text used in the error message).
    checks = (
        ('Args', 'contain the following arguments'),
        ('Return', 'return the following values'),
    )
    for struct_suffix, expectation in checks:
        try:
            found = implicit_args + list(
                get_struct_definition(
                    struct_name=ScopedName.from_string(f'__main__.main.{struct_suffix}'),
                    identifier_manager=program.identifiers).members)
        except IdentifierError:
            continue
        if found != expected_builtin_ptrs:
            raise AssertionError(
                f'Expected main to {expectation} (in this order): '
                f'{expected_builtin_ptrs}. Found: {found}.')
예제 #8
0
    def __init__(self,
                 identifiers: IdentifierManager,
                 additional_imports: Optional[List[str]] = None):
        """
        Creates a CairoStructFactory that converts Cairo structs to python namedtuples.

        identifiers - an identifier manager holding the structs.
        additional_imports - An optional list of fully qualified names of structs to preload.
          Each one must name a struct, and is registered in resolved_identifiers under its last
          name component. Useful for importing absolute paths, rather than relative.
        """
        self.identifiers = identifiers

        # Maps a last-name-component ScopedName to the full struct name. WriteOnceDict presumably
        # rejects re-assigning a key, so two imports sharing a last component would fail.
        self.resolved_identifiers: MutableMapping[
            ScopedName, ScopedName] = WriteOnceDict()
        if additional_imports is None:
            return

        for identifier_path in additional_imports:
            full_name = ScopedName.from_string(identifier_path)
            # Called only for validation: raises if full_name is not a struct.
            get_struct_definition(struct_name=full_name,
                                  identifier_manager=identifiers)
            self.resolved_identifiers[full_name[-1:]] = full_name
예제 #9
0
    def __init__(self, *, reference_value, struct_name: ScopedName, **kw):
        """
        Constructs a VmConstsReference which allows accessing the fields of a typed reference.

        The definition of 'struct_name' is looked up in the context's identifiers, and
        'reference_value' is stored both as '_reference_value' and as 'address_'.
        """
        super().__init__(**kw)

        struct_definition = get_struct_definition(
            struct_name=struct_name,
            identifier_manager=self._context.identifiers)

        # object.__setattr__ bypasses this class's own attribute assignment logic (presumably an
        # overridden __setattr__).
        for attr_name, attr_value in (
                ('_struct_definition', struct_definition),
                ('_reference_value', reference_value),
                ('address_', reference_value),
        ):
            object.__setattr__(self, attr_name, attr_value)
예제 #10
0
    def __init__(self, *, reference_value, struct_name: ScopedName,
                 add_addr_var: bool, **kw):
        """
        Constructs a VmConstsReference which allows accessing the fields of a typed reference.

        The definition of 'struct_name' is looked up in the context's identifiers.
        If add_addr_var, the value of the reference itself can be accessed using self.address_.
        """
        super().__init__(**kw)

        struct_definition = get_struct_definition(
            struct_name=struct_name,
            identifier_manager=self._context.identifiers)

        attributes = [
            ('_struct_definition', struct_definition),
            ('_reference_value', reference_value),
        ]
        if add_addr_var:
            attributes.append(('address_', reference_value))

        # object.__setattr__ bypasses this class's own attribute assignment logic (presumably an
        # overridden __setattr__).
        for attr_name, attr_value in attributes:
            object.__setattr__(self, attr_name, attr_value)
예제 #11
0
def process_test_calldata(members: Dict[str, MemberDefinition],
                          has_range_check_builtin=True):
    """
    Runs process_calldata on a struct named 'MyStruct' with the given members (declared with
    size 0). The calldata is read through the identifiers 'calldata_ptr' and 'calldata_size'.

    Returns whatever process_calldata returns.
    """
    struct_name = scope('MyStruct')
    identifier_values: Dict[ScopedName, IdentifierDefinition] = {
        struct_name: StructDefinition(full_name=struct_name, members=members, size=0),
    }
    identifiers = IdentifierManager.from_dict(identifier_values)

    return process_calldata(
        calldata_ptr=ExprIdentifier('calldata_ptr'),
        calldata_size=ExprIdentifier('calldata_size'),
        identifiers=identifiers,
        struct_def=get_struct_definition(struct_name=struct_name,
                                         identifier_manager=identifiers),
        has_range_check_builtin=has_range_check_builtin,
        location=dummy_location())
예제 #12
0
def check_cast(src_type: CairoType,
               dest_type: CairoType,
               identifier_manager: IdentifierManager,
               expr: Optional[Expression] = None,
               cast_type: CastType = CastType.EXPLICIT) -> bool:
    """
    Returns true if the given expression can be casted from src_type to dest_type
    according to the given 'cast_type'.
    In some cases of cast failure, an exception with more specific details is raised.

    The cast types form increasingly permissive levels, checked in the order ASSIGN, UNPACKING,
    EXPLICIT, FORCED. Each level accepts everything the previous levels accept, plus the cases
    in its own section below.

    'expr' must be specified (not None) when CastType.EXPLICIT (or above) is used.

    Raises CairoTypeError when a tuple-to-struct cast fails (member count mismatch, or a member
    that cannot be cast). Errors from get_struct_definition propagate unchanged.
    """

    # CastType.ASSIGN checks:

    if src_type == dest_type:
        return True

    # Allow implicit cast from pointers to felt*.
    if isinstance(src_type, TypePointer) and dest_type == FELT_STAR:
        return True

    if cast_type is CastType.ASSIGN:
        return False

    # CastType.UNPACKING checks:

    # Allow explicit cast between felts and pointers.
    if isinstance(src_type, (TypeFelt, TypePointer)) and \
            isinstance(dest_type, (TypeFelt, TypePointer)):
        return True

    if cast_type is CastType.UNPACKING:
        return False

    # CastType.EXPLICIT checks:
    assert expr is not None, 'CastType.EXPLICIT requires expr != None.'

    # Allow casting a tuple to a struct with the same number of members, provided every tuple
    # member can be cast to the corresponding struct member.
    if isinstance(src_type, TypeTuple) and isinstance(dest_type, TypeStruct):
        struct_def = get_struct_definition(
            struct_name=dest_type.resolved_scope,
            identifier_manager=identifier_manager)

        n_src_members = len(src_type.members)
        n_dest_members = len(struct_def.members)
        if n_src_members != n_dest_members:
            raise CairoTypeError(f"""\
Cannot cast an expression of type '{src_type.format()}' to '{dest_type.format()}'.
The former has {n_src_members} members while the latter has {n_dest_members} members.""",
                                 location=expr.location)

        # For a tuple literal, each member is checked against its own sub-expression; otherwise
        # the whole expression is reused for every member. The sub-expression is passed to the
        # recursive check and provides the error location.
        src_exprs = ([arg.expr for arg in expr.members.args] if isinstance(
            expr, ExprTuple) else itertools.repeat(expr))

        # Members are cast with ASSIGN rules, unless the whole cast is FORCED.
        member_cast_type = CastType.FORCED if cast_type is CastType.FORCED else CastType.ASSIGN

        for (src_expr, src_member_type,
             dest_member) in zip(src_exprs, src_type.members,
                                 struct_def.members.values()):
            dest_member_type = dest_member.cairo_type
            if not check_cast(
                    src_type=src_member_type,
                    dest_type=dest_member_type,
                    identifier_manager=identifier_manager,
                    expr=src_expr,
                    cast_type=member_cast_type):
                raise CairoTypeError(
                    f"Cannot cast '{src_member_type.format()}' to '{dest_member_type.format()}'.",
                    location=src_expr.location)

        return True

    if cast_type is CastType.EXPLICIT:
        return False

    # CastType.FORCED checks:
    # Allow forcing a felt into a struct when the expression is a dereference (ExprDeref).
    if isinstance(src_type, TypeFelt) and isinstance(
            dest_type, TypeStruct) and isinstance(expr, ExprDeref):
        return True

    assert cast_type is CastType.FORCED, f'Unsupported cast type: {cast_type}.'
    return False
예제 #13
0
 def get_struct_size(self, name: ScopedName) -> int:
     """
     Returns the size recorded in the definition of struct 'name', after resolving 'name' to
     its full name.
     """
     struct_def = get_struct_definition(self._get_full_name(name), self.identifiers)
     return struct_def.size