def test_division_by_zero(prime):
    """Check that dividing by a zero constant raises a SimplifierError.

    When a prime is supplied, dividing by the prime itself is also expected
    to be reported as a division by zero.
    """
    simplifier = ExpressionSimplifier(prime)

    zero_divisions = ['fp / 0', '5 / 0']
    if prime is not None:
        # The divisor equals the prime, so it is expected to be treated as 0.
        zero_divisions.append(f'fp / {prime}')

    for code in zero_divisions:
        with pytest.raises(SimplifierError, match='Division by zero'):
            simplifier.visit(parse_expr(code))
def test_modulo():
    """Check that constants and operator results are reduced modulo the prime."""
    modulus = 19
    simplifier = ExpressionSimplifier(modulus)

    def simplified(code):
        """Parse, simplify and format the given expression text."""
        return simplifier.visit(parse_expr(code)).format()

    cases = [
        # Check that the range is (-PRIME/2, PRIME/2).
        ('-9', '-9'),
        ('-10', '9'),
        ('9', '9'),
        ('10', '-9'),
        # Check value which is bigger than PRIME.
        ('20', '1'),
        # Check operators.
        ('10 + 10', '1'),
        ('10 - 30', '-1'),
        ('10 * 10', '5'),
        ('2 / 3', '7'),
    ]
    for code, expected in cases:
        assert simplified(code) == expected
def test_rotation(prime):
    """Check that constants around `fp` are gathered into a single offset."""
    simplifier = ExpressionSimplifier(prime)

    # Pairs of (input expression, expected formatted result).
    cases = [
        ('(fp + 10) + 1', 'fp + 11'),
        ('(fp + 10) - 1', 'fp + 9'),
        ('(fp - 10) + 1', 'fp + (-9)'),
        ('(fp - 10) - 1', 'fp + (-11)'),
        ('(10 + fp) - 1', 'fp + 9'),
        ('10 + (fp - 1)', 'fp + 9'),
        ('10 + (1 + fp)', 'fp + 11'),
        ('10 + (1 + fp) + 100', 'fp + 111'),
        ('10 + (1 + (fp + 100))', 'fp + 111'),
    ]
    for code, expected in cases:
        result = simplifier.visit(parse_expr(code))
        assert result.format() == expected
def test_operator_precedence():
    """Check that parsing and simplification follow Python's operator precedence."""
    code = '(5 + 2) - (3 - 9) * (7 + (-8)) - 10 * (-2) * 5 + (((7)))'
    parsed = parse_expr(code)

    # Formatting the parsed tree must reproduce the source text exactly.
    assert parsed.format() == code

    # Simplify the tree to a constant and compare it with the value Python
    # computes for the same text. eval() only ever sees the literal above.
    modulus = 3 * 2**30 + 1
    result = ExpressionSimplifier(modulus).visit(parsed)
    assert isinstance(result, ExprConst)
    assert result.val == eval(code)
def test_div_expr():
    """Check formatting of chained divisions and their simplified value."""
    formatted = parse_expr('[ap]/[fp]/3/[ap+1]').format()
    assert formatted == '[ap] / [fp] / 3 / [ap + 1]'

    # Simplify a chain of constant divisions and compare with the expected
    # value: ((120 / 2) / 3) / 4 == 5.
    modulus = 3 * 2**30 + 1
    result = ExpressionSimplifier(modulus).visit(parse_expr('120 / 2 / 3 / 4'))
    assert isinstance(result, ExprConst)
    assert result.val == 5
def test_simplifier(prime):
    """Check simplification of expressions after identifier substitution.

    Division results depend on whether a prime was supplied, so those cases
    have separate expectations for each mode.
    """
    assignments = {'x': 10, 'y': 3, 'z': -2, 'w': -60}
    simplifier = ExpressionSimplifier(prime)

    def simplified(code):
        """Parse `code`, replace identifiers by their values, simplify, format."""
        substituted = substitute_identifiers(
            parse_expr(code), lambda var: assignments[var.name])
        return simplifier.visit(substituted).format()

    common_cases = [
        ('fp + x * (y + -1)', 'fp + 20'),
        ('[fp + x] + [ap - (-z)]', '[fp + 10] + [ap + (-2)]'),
        ('fp + x - y', 'fp + 7'),
        ('[1 + fp + 5]', '[fp + 6]'),
        ('[fp] - 3', '[fp] + (-3)'),
    ]

    if prime is None:
        division_cases = [
            ('fp * (x - 1) / y', 'fp * 9 / 3'),
            ('fp * w / x / y / z', 'fp * (-60) / 10 / 3 / (-2)'),
        ]
    else:
        division_cases = [
            ('fp * (x - 1) / y', 'fp * 3'),
            ('fp * w / x / y / z', 'fp'),
        ]

    identity_cases = [
        ('fp * 1', 'fp'),
        ('1 * fp', 'fp'),
    ]

    for code, expected in common_cases + division_cases + identity_cases:
        assert simplified(code) == expected
def converge(
        self, reference_manager: ReferenceManager, other: 'FlowTrackingData',
        group_alloc: Callable):
    """Merge this flow tracking data with `other` at a point where flows meet.

    If `other` is not a FlowTrackingDataActual, the merge is delegated to
    `other.converge` with the roles swapped. Otherwise the ap tracking of both
    sides is converged, and a reference is kept only if both sides define the
    same name and their expressions simplify to equal values.

    Returns a FlowTrackingDataActual holding the converged ap tracking and the
    surviving references.
    """
    if not isinstance(other, FlowTrackingDataActual):
        return other.converge(reference_manager, self, group_alloc)

    simplifier = ExpressionSimplifier()
    new_ap_tracking = self.ap_tracking.converge(other.ap_tracking, group_alloc)

    # Allow different references from different branches to unite if they have the same name
    # and the same expression at the converged ap_tracking.
    reference_ids = {}
    for name, own_id in self.reference_ids.items():
        other_id = other.reference_ids.get(name)
        if other_id is None:
            continue

        own_ref = reference_manager.get_ref(own_id)
        other_ref = reference_manager.get_ref(other_id)
        try:
            own_expr = own_ref.eval(self.ap_tracking)
            own_simplified = simplifier.visit(own_expr)
            other_simplified = simplifier.visit(other_ref.eval(other.ap_tracking))
            if not (own_simplified == other_simplified):
                # Different expressions - the reference is dropped.
                continue

            merged_id = own_id
            if self.ap_tracking != new_ap_tracking:
                # Create a new reference on the new ap tracking.
                merged_id = reference_manager.get_id(Reference(
                    pc=own_ref.pc,
                    value=own_expr,
                    ap_tracking_data=new_ap_tracking,
                ))
            reference_ids[name] = merged_id
        except FlowTrackingError:
            # A FlowTrackingError from the steps above means the reference
            # cannot be carried over; skip it.
            pass

    return FlowTrackingDataActual(
        ap_tracking=new_ap_tracking,
        reference_ids=reference_ids,
    )
def test_pow(prime):
    """Check simplification of the power operator."""
    simplifier = ExpressionSimplifier(prime)

    def simplified(code):
        """Parse, simplify and format the given expression text."""
        return simplifier.visit(parse_expr(code)).format()

    # 4 ** (3 ** 2) == 4 ** 9 == 262144.
    assert simplified('4 ** 3 ** 2') == '262144'

    if prime is not None:
        # Make sure the exponent is not computed modulo prime (if it were,
        # the result would have been 1).
        assert simplified('(3 * 2**30 + 4) ** (3 * 2**30 + 1)') == '3'

    negative_exponent_error = 'Power is not supported with a negative exponent'
    with pytest.raises(SimplifierError, match=negative_exponent_error):
        simplifier.visit(parse_expr('2 ** (-1)'))
def visit_ExprSubscript(
        self, expr: ExprSubscript) -> Tuple[Expression, CairoType]:
    """Resolve a subscript expression and compute its type.

    Both the subscripted expression and the offset are visited first. Then:

    * Tuple type: the offset must simplify to a constant within
      [0, number of members). For an ExprTuple the selected member expression
      is returned (carrying the location of `expr`); for an ExprDeref a
      dereference of the address plus the total size of the preceding members
      is returned.
    * Pointer type: a dereference of `inner + offset * size(pointee)` is
      returned, typed as the pointee.

    Returns:
        A pair of (resulting expression, its Cairo type).

    Raises:
        CairoTypeError: for a non-constant or out-of-range tuple offset, an
            unexpected expression typed as a tuple, a failure to compute the
            pointee size, or a type that is neither a tuple nor a pointer.
    """
    inner_expr, inner_type = self.visit(expr.expr)
    offset_expr, offset_type = self.visit(expr.offset)

    if isinstance(inner_type, TypeTuple):
        self.verify_offset_is_felt(offset_type, offset_expr.location)
        offset_expr = ExpressionSimplifier().visit(offset_expr)
        if not isinstance(offset_expr, ExprConst):
            raise CairoTypeError(
                'Subscript-operator for tuples supports only constant offsets, found '
                f"'{type(offset_expr).__name__}'.",
                location=offset_expr.location)
        offset_value = offset_expr.val

        tuple_len = len(inner_type.members)
        if not 0 <= offset_value < tuple_len:
            raise CairoTypeError(
                f'Tuple index {offset_value} is out of range [0, {tuple_len}).',
                location=expr.location)

        item_type = inner_type.members[offset_value]

        if isinstance(inner_expr, ExprTuple):
            assert len(inner_expr.members.args) == tuple_len
            return (
                # Take the inner item, but keep the original expression's location.
                dataclasses.replace(
                    inner_expr.members.args[offset_value].expr, location=expr.location),
                item_type)
        elif isinstance(inner_expr, ExprDeref):
            # Handles pointers cast as tuples*, e.g. `[cast(ap, (felt, felt)*)][0]`.
            addr = inner_expr.addr
            # The member's offset is the total size of the members preceding it.
            offset_in_felts = ExprConst(
                val=sum(map(self.get_size, inner_type.members[:offset_value])),
                location=offset_expr.location)
            addr_with_offset = ExprOperator(
                a=addr, op='+', b=offset_in_felts, location=expr.location)
            return ExprDeref(addr=addr_with_offset, location=expr.location), item_type
        else:
            raise CairoTypeError(
                'Unexpected expression typed as TypeTuple. Expected ExprTuple or ExprDeref, '
                f"found '{type(inner_expr).__name__}'.",
                location=expr.location)
    elif isinstance(inner_type, TypePointer):
        self.verify_offset_is_felt(offset_type, offset_expr.location)
        try:
            # If pointed type is struct, get_size could throw IdentifierErrors. We catch and
            # convert them to CairoTypeErrors.
            element_size = self.get_size(inner_type.pointee)
        except Exception as exc:
            # Chain explicitly so the original failure is kept as the cause.
            raise CairoTypeError(str(exc), location=expr.location) from exc

        element_size_expr = ExprConst(val=element_size, location=expr.location)
        modified_offset_expr = ExprOperator(
            a=offset_expr, op='*', b=element_size_expr, location=expr.location)
        simplified_expr = ExprDeref(
            addr=ExprOperator(
                a=inner_expr, op='+', b=modified_offset_expr, location=expr.location),
            location=expr.location)

        return simplified_expr, inner_type.pointee
    else:
        raise CairoTypeError(
            'Cannot apply subscript-operator to non-pointer, non-tuple type '
            f"'{inner_type.format()}'.",
            location=expr.location)