def test_division_by_zero(prime):
    # Every expression below divides by something that is zero in the field (a literal 0, or the
    # prime itself when one is configured) and must be rejected by the simplifier.
    simplifier = ExpressionSimplifier(prime)
    zero_divisions = ['fp / 0', '5 / 0']
    if prime is not None:
        zero_divisions.append(f'fp / {prime}')
    for code in zero_divisions:
        with pytest.raises(SimplifierError, match='Division by zero'):
            simplifier.visit(parse_expr(code))
def test_modulo():
    PRIME = 19
    simplifier = ExpressionSimplifier(PRIME)

    def simplified(code):
        return simplifier.visit(parse_expr(code)).format()

    cases = [
        # Constants are represented in the symmetric range (-PRIME/2, PRIME/2).
        ('-9', '-9'),
        ('-10', '9'),
        ('9', '9'),
        ('10', '-9'),
        # A constant larger than PRIME wraps around.
        ('20', '1'),
        # Arithmetic operators are reduced modulo PRIME (2 / 3 == 2 * 13 == 7 mod 19).
        ('10 + 10', '1'),
        ('10 - 30', '-1'),
        ('10 * 10', '5'),
        ('2 / 3', '7'),
    ]
    for code, expected in cases:
        assert simplified(code) == expected, code
def test_rotation(prime):
    simplifier = ExpressionSimplifier(prime)
    # (input, expected) pairs: however the constants and `fp` are nested, the simplified form
    # is `fp + <folded constant>`.
    expectations = [
        ('(fp + 10) + 1', 'fp + 11'),
        ('(fp + 10) - 1', 'fp + 9'),
        ('(fp - 10) + 1', 'fp + (-9)'),
        ('(fp - 10) - 1', 'fp + (-11)'),
        ('(10 + fp) - 1', 'fp + 9'),
        ('10 + (fp - 1)', 'fp + 9'),
        ('10 + (1 + fp)', 'fp + 11'),
        ('10 + (1 + fp) + 100', 'fp + 111'),
        ('10 + (1 + (fp + 100))', 'fp + 111'),
    ]
    for code, expected in expectations:
        assert simplifier.visit(parse_expr(code)).format() == expected, code
Esempio n. 4
0
def test_operator_precedence():
    source = '(5 + 2) - (3 - 9) * (7 + (-8)) - 10 * (-2) * 5 + (((7)))'
    parsed = parse_expr(source)

    # Formatting must reproduce the source text exactly, redundant parentheses included.
    assert parsed.format() == source

    # Folding the tree to a single constant must agree with Python's own evaluation of the
    # same (hard-coded, trusted) text.
    field_prime = 3 * 2**30 + 1
    folded = ExpressionSimplifier(field_prime).visit(parsed)
    assert isinstance(folded, ExprConst)
    assert folded.val == eval(source)
Esempio n. 5
0
def test_div_expr():
    # Chained divisions are formatted with spaces around each operator.
    formatted = parse_expr('[ap]/[fp]/3/[ap+1]').format()
    assert formatted == '[ap] / [fp] / 3 / [ap + 1]'

    # Folding ((120 / 2) / 3) / 4 must give exactly 5.
    field_prime = 3 * 2**30 + 1
    parsed = parse_expr('120 / 2 / 3 / 4')
    folded = ExpressionSimplifier(field_prime).visit(parsed)
    assert isinstance(folded, ExprConst)
    assert folded.val == 5
Esempio n. 6
0
def test_simplifier(prime):
    assignments = {'x': 10, 'y': 3, 'z': -2, 'w': -60}
    simplifier = ExpressionSimplifier(prime)

    def simplified(code):
        # Replace every identifier with its constant from `assignments`, then simplify and format.
        with_constants = substitute_identifiers(
            parse_expr(code), lambda var: assignments[var.name])
        return simplifier.visit(with_constants).format()

    common_cases = [
        ('fp + x * (y + -1)', 'fp + 20'),
        ('[fp + x] + [ap - (-z)]', '[fp + 10] + [ap + (-2)]'),
        ('fp + x - y', 'fp + 7'),
        ('[1 + fp + 5]', '[fp + 6]'),
        ('[fp] - 3', '[fp] + (-3)'),
    ]
    if prime is not None:
        # With a prime, divisions by constants are folded as well.
        division_cases = [
            ('fp * (x - 1) / y', 'fp * 3'),
            ('fp * w / x / y / z', 'fp'),
        ]
    else:
        # Without a prime, divisions are kept in the output.
        division_cases = [
            ('fp * (x - 1) / y', 'fp * 9 / 3'),
            ('fp * w / x / y / z', 'fp * (-60) / 10 / 3 / (-2)'),
        ]
    identity_cases = [
        ('fp * 1', 'fp'),
        ('1 * fp', 'fp'),
    ]
    for code, expected in common_cases + division_cases + identity_cases:
        assert simplified(code) == expected, code
Esempio n. 7
0
    def converge(
            self, reference_manager: ReferenceManager, other: 'FlowTrackingData',
            group_alloc: Callable):
        """
        Merges this flow-tracking state with `other` at a control-flow join point.

        If `other` is not a FlowTrackingDataActual, the merge is delegated to
        `other.converge(reference_manager, self, group_alloc)`.

        Otherwise the two ap-tracking states are converged via `self.ap_tracking.converge`.
        A reference name is kept only if it appears on both sides and its expression,
        evaluated under each side's own ap tracking, simplifies to the same result.
        If the merged ap tracking differs from `self.ap_tracking`, the kept reference is
        re-registered on the merged ap tracking through `reference_manager.get_id`.
        A name whose processing raises FlowTrackingError is dropped.
        """
        if not isinstance(other, FlowTrackingDataActual):
            return other.converge(reference_manager, self, group_alloc)

        merged_ap_tracking = self.ap_tracking.converge(other.ap_tracking, group_alloc)
        simplifier = ExpressionSimplifier()

        merged_reference_ids = {}
        for name, own_ref_id in self.reference_ids.items():
            other_ref_id = other.reference_ids.get(name)
            if other_ref_id is None:
                # Only this branch defines the name, so it does not survive the join.
                continue
            own_ref = reference_manager.get_ref(own_ref_id)
            other_ref = reference_manager.get_ref(other_ref_id)
            try:
                own_expr = own_ref.eval(self.ap_tracking)
                own_simplified = simplifier.visit(own_expr)
                other_simplified = simplifier.visit(other_ref.eval(other.ap_tracking))
                if own_simplified == other_simplified:
                    kept_ref_id = own_ref_id
                    if self.ap_tracking != merged_ap_tracking:
                        # Re-anchor the reference on the merged ap tracking.
                        kept_ref_id = reference_manager.get_id(Reference(
                            pc=own_ref.pc,
                            value=own_expr,
                            ap_tracking_data=merged_ap_tracking,
                        ))
                    merged_reference_ids[name] = kept_ref_id
            except FlowTrackingError:
                # The reference cannot be evaluated on one of the branches; drop it.
                pass

        return FlowTrackingDataActual(
            ap_tracking=merged_ap_tracking,
            reference_ids=merged_reference_ids,
        )
Esempio n. 8
0
def test_pow(prime):
    simplifier = ExpressionSimplifier(prime)

    def simplified(code):
        return simplifier.visit(parse_expr(code)).format()

    # `**` is right-associative: 4 ** (3 ** 2) == 4 ** 9 == 262144.
    assert simplified('4 ** 3 ** 2') == '262144'

    if prime is not None:
        # The exponent must not be reduced modulo the prime; if it were, the result would
        # have been 1 instead of 3.
        assert simplified('(3 * 2**30 + 4) ** (3 * 2**30 + 1)') == '3'

    negative_exponent_error = 'Power is not supported with a negative exponent'
    with pytest.raises(SimplifierError, match=negative_exponent_error):
        simplifier.visit(parse_expr('2 ** (-1)'))
    def visit_ExprSubscript(
            self, expr: ExprSubscript) -> Tuple[Expression, CairoType]:
        """
        Type-checks a subscript expression `inner[offset]` and lowers it.

        Supported inner types:
        * TypeTuple: the offset is simplified and must become an ExprConst in
          [0, len(members)). If the inner expression is an ExprTuple, the selected member
          expression is returned with the subscript's location. If it is an ExprDeref, a
          dereference of the address advanced by the total size of the preceding members is
          returned.
        * TypePointer: returns `[inner + offset * size(pointee)]` typed as the pointee.

        In both cases `verify_offset_is_felt` is called on the offset first.

        Returns:
            A pair (lowered expression, its CairoType).

        Raises:
            CairoTypeError: for a non-constant tuple offset, an out-of-range tuple index, an
                unexpected expression typed as a tuple, a pointee whose size cannot be
                computed, or any inner type that is neither a tuple nor a pointer.
        """
        inner_expr, inner_type = self.visit(expr.expr)
        offset_expr, offset_type = self.visit(expr.offset)

        if isinstance(inner_type, TypeTuple):
            self.verify_offset_is_felt(offset_type, offset_expr.location)
            # Tuple members may have different types, so the index must be known at compile time.
            offset_expr = ExpressionSimplifier().visit(offset_expr)
            if not isinstance(offset_expr, ExprConst):
                raise CairoTypeError(
                    'Subscript-operator for tuples supports only constant offsets, found '
                    f"'{type(offset_expr).__name__}'.",
                    location=offset_expr.location)
            offset_value = offset_expr.val

            tuple_len = len(inner_type.members)
            if not 0 <= offset_value < tuple_len:
                raise CairoTypeError(
                    f'Tuple index {offset_value} is out of range [0, {tuple_len}).',
                    location=expr.location)

            item_type = inner_type.members[offset_value]

            if isinstance(inner_expr, ExprTuple):
                assert len(inner_expr.members.args) == tuple_len
                return (
                    # Take the inner item, but keep the original expression's location.
                    dataclasses.replace(
                        inner_expr.members.args[offset_value].expr,
                        location=expr.location),
                    item_type)
            elif isinstance(inner_expr, ExprDeref):
                # Handles pointers cast as tuples*, e.g. `[cast(ap, (felt, felt)*][0]`.
                # The member's address is the base address plus the sizes of all earlier members.
                addr = inner_expr.addr
                offset_in_felts = ExprConst(
                    val=sum(map(self.get_size, inner_type.members[:offset_value])),
                    location=offset_expr.location)
                addr_with_offset = ExprOperator(
                    a=addr, op='+', b=offset_in_felts, location=expr.location)
                return ExprDeref(addr=addr_with_offset, location=expr.location), item_type
            else:
                raise CairoTypeError(
                    'Unexpected expression typed as TypeTuple. Expected ExprTuple or ExprDeref, '
                    f"found '{type(inner_expr).__name__}'.",
                    location=expr.location)
        elif isinstance(inner_type, TypePointer):
            self.verify_offset_is_felt(offset_type, offset_expr.location)
            try:
                # If pointed type is struct, get_size could throw IdentifierErrors. We catch and
                # convert them to CairoTypeErrors, chaining the original error as the cause.
                element_size = self.get_size(inner_type.pointee)
            except Exception as exc:
                raise CairoTypeError(str(exc), location=expr.location) from exc

            # Scale the element offset by the element size to get an offset in felts.
            element_size_expr = ExprConst(val=element_size, location=expr.location)
            modified_offset_expr = ExprOperator(
                a=offset_expr, op='*', b=element_size_expr, location=expr.location)
            simplified_expr = ExprDeref(
                addr=ExprOperator(
                    a=inner_expr,
                    op='+',
                    b=modified_offset_expr,
                    location=expr.location),
                location=expr.location)

            return simplified_expr, inner_type.pointee
        else:
            raise CairoTypeError(
                'Cannot apply subscript-operator to non-pointer, non-tuple type '
                f"'{inner_type.format()}'.",
                location=expr.location)