def test_import():
    """
    Parses, compares and re-formats single-name `from <module> import <name> [as <alias>]`
    statements, and checks that dotted imported/local names are rejected by the parser.
    """
    # Test module names without periods.
    res = parse_code_element('from a import b')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a'),
        orig_identifier=ExprIdentifier(name='b'),
        local_name=None)
    assert res.format(allowed_line_length=100) == 'from a import b'

    # Test module names without periods, with aliasing.
    res = parse_code_element('from a import b as c')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a'),
        orig_identifier=ExprIdentifier(name='b'),
        local_name=ExprIdentifier(name='c'))
    assert res.format(allowed_line_length=100) == 'from a import b as c'

    # Test module names with periods.
    res = parse_code_element('from a.b12.c4 import lib345')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a.b12.c4'),
        orig_identifier=ExprIdentifier(name='lib345'))
    assert res.format(allowed_line_length=100) == 'from a.b12.c4 import lib345'

    # The negative cases below must go through parse_code_element(): an import statement is
    # not an expression, so parse_expr() would reject it regardless of the dotted name and
    # the checks would pass vacuously.

    # Test module with bad identifier (with periods).
    with pytest.raises(ParserError):
        parse_code_element('from a.b import c.d')

    # Test module with bad local name (with periods).
    with pytest.raises(ParserError):
        parse_code_element('from a.b import c as d.d')
def simplify_type_system_test(orig_expr: str, simplified_expr: str, simplified_type: CairoType,
                              identifiers: Optional[IdentifierManager] = None):
    """
    Asserts that simplify_type_system() maps `orig_expr` (parsed, with its types marked as
    resolved) to the pair (parse_expr(simplified_expr), simplified_type).

    `identifiers` is forwarded to simplify_type_system() unchanged.
    """
    resolved = mark_types_in_expr_resolved(parse_expr(orig_expr))
    actual = simplify_type_system(resolved, identifiers=identifiers)
    expected = (parse_expr(simplified_expr), simplified_type)
    assert actual == expected
def test_import():
    """
    Parses, compares and re-formats `from <module> import ...` statements, including
    aliased and multi-name imports, and checks that dotted imported/local names are
    rejected by the parser.
    """
    # Test module names without periods.
    res = parse_code_element('from a import b')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a'),
        import_items=[
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='b'), local_name=None),
        ])
    assert res.format(allowed_line_length=100) == 'from a import b'

    # Test module names without periods, with aliasing.
    res = parse_code_element('from a import b as c')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a'),
        import_items=[
            AliasedIdentifier(
                orig_identifier=ExprIdentifier(name='b'),
                local_name=ExprIdentifier(name='c')),
        ])
    assert res.format(allowed_line_length=100) == 'from a import b as c'

    # Test module names with periods.
    res = parse_code_element('from a.b12.c4 import lib345')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='a.b12.c4'),
        import_items=[
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='lib345'), local_name=None),
        ])
    assert res.format(allowed_line_length=100) == 'from a.b12.c4 import lib345'

    # Test multiple imports.
    res = parse_code_element('from lib import a,b as b2, c')
    assert res == CodeElementImport(
        path=ExprIdentifier(name='lib'),
        import_items=[
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='a'), local_name=None),
            AliasedIdentifier(
                orig_identifier=ExprIdentifier(name='b'),
                local_name=ExprIdentifier(name='b2')),
            AliasedIdentifier(orig_identifier=ExprIdentifier(name='c'), local_name=None),
        ])
    assert res.format(allowed_line_length=100) == 'from lib import a, b as b2, c'
    # With a short line limit the import list is wrapped in parentheses; the wrapped form
    # must parse back to the same code element.
    assert res.format(allowed_line_length=20) == 'from lib import (\n a, b as b2, c)'
    assert res == parse_code_element('from lib import (\n a, b as b2, c)')

    # The negative cases below must go through parse_code_element(): an import statement is
    # not an expression, so parse_expr() would reject it regardless of the dotted name and
    # the checks would pass vacuously.

    # Test module with bad identifier (with periods).
    with pytest.raises(ParserError):
        parse_code_element('from a.b import c.d')

    # Test module with bad local name (with periods).
    with pytest.raises(ParserError):
        parse_code_element('from a.b import c as d.d')
def test_division_by_zero(prime):
    """
    Dividing by zero raises SimplifierError('Division by zero'). When a prime is given,
    dividing by the prime itself is expected to be rejected the same way.
    """
    simplifier = ExpressionSimplifier(prime)
    zero_division_cases = ['fp / 0', '5 / 0']
    if prime is not None:
        zero_division_cases.append(f'fp / {prime}')

    for code in zero_division_cases:
        with pytest.raises(SimplifierError, match='Division by zero'):
            simplifier.visit(parse_expr(code))
def test_eval_registers_and_memory():
    """ExpressionEvaluator resolves ap/fp and follows (nested) memory dereferences mod prime."""
    ap, fp = 5, 10
    prime = 13
    base_addr = (2 * ap + 3 * fp - 5) % prime
    memory = {base_addr: 7, 7: 5, 6: 0}
    evaluator = ExpressionEvaluator(prime=prime, ap=ap, fp=fp, memory=memory)

    # Single dereference: memory[base_addr].
    assert evaluator.eval(parse_expr('[2 * ap + 3 * fp - 5]')) == 7
    # Double dereference plus a register term: memory[memory[base_addr]] + 3 * ap.
    assert evaluator.eval(parse_expr('[[2 * ap + 3 * fp - 5]] + 3 * ap')) == \
        (memory[7] + 3 * ap) % prime
    # Triple dereference with an offset: memory[memory[memory[base_addr]] + 1].
    assert evaluator.eval(parse_expr('[[[2 * ap + 3 * fp - 5]]+1]')) == 0
def test_tuple_expr():
    """Parenthesized expressions and tuples are re-formatted into their canonical spelling."""
    cases = [
        ('( )', '()'),
        ('( 2)', '(2)'),  # Not a tuple.
        ('(a= 2)', '(a=2)'),  # Tuple.
        ('( 2,)', '(2,)'),
        ('( ...,1)', '(..., 1)'),
        ('( ...,)', '(...,)'),
        ('( ... )', '(...)'),
        ('( 1 , ap)', '(1, ap)'),
        ('( 1 , ap, )', '(1, ap,)'),
        ('( 1 , a=2, b=(c=()))', '(1, a=2, b=(c=()))'),
    ]
    for source, expected in cases:
        assert parse_expr(source).format() == expected
def test_div_expr():
    """
    Division chains are formatted with spaces around each '/', and a constant chain
    simplifies to ((120 / 2) / 3) / 4 == 5.
    """
    formatted = parse_expr('[ap]/[fp]/3/[ap+1]').format()
    assert formatted == '[ap] / [fp] / 3 / [ap + 1]'

    # Compute the value of the expression from the tree and compare it with the known value.
    field_prime = 3 * 2**30 + 1
    result = ExpressionSimplifier(field_prime).visit(parse_expr('120 / 2 / 3 / 4'))
    assert isinstance(result, ExprConst)
    assert result.val == 5
def test_hex_int():
    """Hex literals keep their source spelling (including leading zeros) when formatted."""
    positive = parse_expr(' 0x1234 ')
    assert positive == ExprConst(val=0x1234)
    assert positive.format_str == '0x1234'
    assert positive.format() == '0x1234'

    negative = parse_expr('-0x01234')
    assert negative == ExprNeg(val=ExprConst(val=0x1234))
    assert negative.val.format_str == '0x01234'
    assert negative.format() == '-0x01234'

    # Whitespace between the minus sign and the literal does not change the parse result.
    assert parse_expr('-0x1234') == parse_expr('- 0x1234')
def test_pow(prime):
    """
    '**' is right-associative, its exponent is not reduced modulo the prime, and negative
    exponents are rejected.
    """
    simplifier = ExpressionSimplifier(prime)

    # 4 ** (3 ** 2) == 4 ** 9 == 262144.
    assert simplifier.visit(parse_expr('4 ** 3 ** 2')).format() == '262144'

    if prime is not None:
        # The exponent must not be computed modulo prime (if it were, the result would
        # have been 1).
        huge_power = parse_expr('(3 * 2**30 + 4) ** (3 * 2**30 + 1)')
        assert simplifier.visit(huge_power).format() == '3'

    with pytest.raises(
            SimplifierError, match='Power is not supported with a negative exponent'):
        simplifier.visit(parse_expr('2 ** (-1)'))
def test_tuple_expr_with_notes(): assert parse_expr("""\ ( 1 , # a. ( # c. ) #b. , (fp,[3]))""").format() == """\ (1, # a. ( # c. ), # b. (fp, [3]))""" assert parse_expr("""\ ( 1 # b. , # a. )""").format() == """\
def test_revoked_reference():
    """
    Reading a reference created under ap-tracking group 0 while the flow is in group 1
    raises FlowTrackingError('Failed to deduce ap.').
    """
    reference_manager = ReferenceManager()
    x_ref_id = reference_manager.alloc_id(reference=Reference(
        pc=0,
        value=parse_expr('[ap + 1]'),
        ap_tracking_data=RegTrackingData(group=0, offset=2),
    ))

    identifier_values = {
        scope('x'): ReferenceDefinition(
            full_name=scope('x'), cairo_type=TypeFelt(), references=[]),
    }
    prime, ap, fp = 2**64 + 13, 100, 200
    memory = {}
    # The current flow is in a different ap-tracking group than the reference above.
    flow_tracking_data = FlowTrackingDataActual(
        ap_tracking=RegTrackingData(group=1, offset=4),
        reference_ids={scope('x'): x_ref_id},
    )
    context = VmConstsContext(
        identifiers=IdentifierManager.from_dict(identifier_values),
        evaluator=ExpressionEvaluator(prime, ap, fp, memory).eval,
        reference_manager=reference_manager,
        flow_tracking_data=flow_tracking_data,
        memory=memory,
        pc=0,
    )
    consts = VmConsts(context=context, accessible_scopes=[ScopedName()])

    with pytest.raises(FlowTrackingError, match='Failed to deduce ap.'):
        assert consts.x
def test_reference_rebinding():
    """
    A reference that was never added to the flow tracking data is reported as revoked;
    after add_reference() binds it, the same identifier evaluates to the bound value.
    """
    identifier_values = {
        scope('ref'): ReferenceDefinition(
            full_name=scope('ref'),
            cairo_type=TypeFelt(),
            references=[],
        ),
    }
    reference_manager = ReferenceManager()

    unbound_flow = FlowTrackingDataActual(ap_tracking=RegTrackingData())
    consts = get_vm_consts(identifier_values, reference_manager, unbound_flow)
    with pytest.raises(FlowTrackingError, match='Reference ref revoked'):
        consts.ref

    bound_flow = unbound_flow.add_reference(
        reference_manager=reference_manager,
        name=scope('ref'),
        ref=Reference(
            pc=10,
            value=parse_expr('10'),
            ap_tracking_data=RegTrackingData(group=0, offset=2),
        ),
    )
    consts = get_vm_consts(identifier_values, reference_manager, bound_flow)
    assert consts.ref == 10
def test_reference_to_structs():
    """
    A reference cast to struct T exposes its address via `address_` and reads member
    T.x (offset 3) from memory at address + 3.
    """
    struct_t = TypeStruct(scope=scope('T'), is_fully_resolved=True)
    pointer_to_t = TypePointer(pointee=struct_t)
    identifier_values = {
        scope('ref'): ReferenceDefinition(full_name=scope('ref'), references=[]),
        scope('T.x'): MemberDefinition(offset=3, cairo_type=pointer_to_t),
    }
    reference_manager = ReferenceManager()

    flow = FlowTrackingDataActual(ap_tracking=RegTrackingData())
    flow = flow.add_reference(
        reference_manager=reference_manager,
        name=scope('ref'),
        ref=Reference(
            pc=0,
            value=mark_types_in_expr_resolved(parse_expr('cast([100], T)')),
            ap_tracking_data=RegTrackingData(group=0, offset=2),
        ),
    )

    consts = get_vm_consts(identifier_values, reference_manager, flow, memory={103: 200})
    # 'ref' is located at 100 and T.x is at offset 3, so ref.x reads memory[103].
    assert consts.ref.address_ == 100
    assert consts.ref.x == 200
def test_offset_reference_definition_typed_members():
    """
    OffsetReferenceDefinition.eval() renders the member-access expression a.x.y.z on top of
    the parent reference, with ap adjusted for the current ap-tracking offset.
    """
    struct_t = TypeStruct(scope=scope('T'), is_fully_resolved=True)
    reference_manager = ReferenceManager()
    main_reference = ReferenceDefinition(
        full_name=scope('a'), cairo_type=TypePointer(pointee=struct_t), references=[])

    a_id = reference_manager.alloc_id(
        Reference(
            pc=0,
            value=mark_types_in_expr_resolved(parse_expr('cast(ap, T*)')),
            ap_tracking_data=RegTrackingData(group=0, offset=0),
        ))
    # ap has advanced by one since 'a' was defined (offset 0 -> 1).
    flow_tracking_data = FlowTrackingDataActual(
        ap_tracking=RegTrackingData(group=0, offset=1),
        reference_ids={scope('a'): a_id},
    )

    # An OffsetReferenceDefinition stands for an expression of the form "a.<member_path>";
    # here the member path is x.y.z.
    definition = OffsetReferenceDefinition(parent=main_reference, member_path=scope('x.y.z'))
    evaluated = definition.eval(
        reference_manager=reference_manager, flow_tracking_data=flow_tracking_data)
    assert evaluated.format() == 'cast(ap - 1, T*).x.y.z'
def test_eval_with_types():
    """ExpressionEvaluator ignores casts: cast(ap, T*) evaluates to the value of ap."""
    ap_value = 5
    evaluator = ExpressionEvaluator(prime=13, ap=ap_value, fp=10, memory={})
    assert evaluator.eval(parse_expr('cast(ap, T*)')) == ap_value
def test_pow_expr():
    """'**' round-trips through format(); a spaced-out '* *' is rejected via verify_exception."""
    assert parse_expr('2 ** 3').format() == '2 ** 3'

    # NOTE(review): the expected message below is a single-line literal that reads like an
    # error line, a source line and a caret marker run together; it may have lost its
    # original line breaks and indentation. Confirm against the actual error output.
    expected_error = """ file:?:?: Unexpected operator. Did you mean "**"? let x = 2 * * 3 ^*^ """
    verify_exception('let x = 2 * * 3', expected_error)
def test_eval_registers():
    """ExpressionEvaluator substitutes ap and fp and reduces the result modulo the prime."""
    prime, ap, fp = 13, 5, 10
    evaluator = ExpressionEvaluator(prime=prime, ap=ap, fp=fp, memory={})
    result = evaluator.eval(parse_expr('2 * ap + 3 * fp - 5'))
    assert result == (2 * ap + 3 * fp - 5) % prime
def test_add_expr():
    """A sum of two dereferences parses into the expected tree and is formatted with spaces."""
    expr = parse_expr('[fp + 1] + [ap - x]')

    fp_plus_one = ExprOperator(a=ExprReg(reg=Register.FP), op='+', b=ExprConst(val=1))
    ap_minus_x = ExprOperator(a=ExprReg(reg=Register.AP), op='-', b=ExprIdentifier(name='x'))
    expected = ExprOperator(
        a=ExprDeref(addr=fp_plus_one),
        op='+',
        b=ExprDeref(addr=ap_minus_x))

    assert expr == expected
    assert expr.format() == '[fp + 1] + [ap - x]'
    assert parse_expr('[ap-7]+37').format() == '[ap - 7] + 37'
def verify_exception(expr_str: str, error: str):
    """
    Asserts that simplify_type_system() on the parsed `expr_str` raises CairoTypeError whose
    message, after replacing every ':<line>:<col>: ' with 'file:?:?: ', equals
    `error.strip()`.
    """
    with pytest.raises(CairoTypeError) as exc_info:
        simplify_type_system(parse_expr(expr_str))
    # Mask the line/column numbers so the expected text does not depend on them.
    normalized_message = re.sub(':[0-9]+:[0-9]+: ', 'file:?:?: ', str(exc_info.value))
    assert normalized_message == error.strip()
def test_eval_reference():
    """
    Reference.eval() fails for a different ap-tracking group, and within the same group
    rewrites ap to account for how far it has advanced.
    """
    ref = Reference(
        pc=0,
        value=parse_expr('2 * ap + 3 * fp - 5'),
        ap_tracking_data=RegTrackingData(group=1, offset=5))

    with pytest.raises(FlowTrackingError):
        ref.eval(RegTrackingData(group=2, offset=5))

    # Same group, offset 5 -> 8: ap advanced by 3, so the original ap is 'ap - 3'.
    rebased = ref.eval(RegTrackingData(group=1, offset=8))
    assert rebased.format() == '2 * (ap - 3) + 3 * fp - 5'
def test_type_tuples():
    """simplify_type_system() strips casts inside tuples and returns a matching TypeTuple."""
    t = TypeStruct(scope=scope('T'), is_fully_resolved=False)
    t_star = TypePointer(pointee=t)

    # Simple tuple.
    simple_result = simplify_type_system(parse_expr('(fp, [cast(fp, T*)], cast(fp,T*))'))
    assert simple_result == (
        parse_expr('(fp, [fp], fp)'),
        TypeTuple(members=[TypeFelt(), t, t_star]),
    )

    # Nested.
    nested_result = simplify_type_system(parse_expr('(fp, (), ([cast(fp, T*)],))'))
    nested_type = TypeTuple(members=[
        TypeFelt(),
        TypeTuple(members=[]),
        TypeTuple(members=[t]),
    ])
    assert nested_result == (parse_expr('(fp, (), ([fp],))'), nested_type)
def test_parent_location(): parent_location = (parse_expr('1 + 2').location, 'An error ocurred while processing:') location = parse_code_element( 'let x = 3 + 4', parser_context=ParserContext( parent_location=parent_location)).expr.location location_err = LocationError(message='Error', location=location) assert str(location_err) == """\
def test_rotation(prime):
    """The simplifier folds constants around a register into the form 'fp + <const>'."""
    simplifier = ExpressionSimplifier(prime)
    cases = [
        ('(fp + 10) + 1', 'fp + 11'),
        ('(fp + 10) - 1', 'fp + 9'),
        ('(fp - 10) + 1', 'fp + (-9)'),
        ('(fp - 10) - 1', 'fp + (-11)'),
        ('(10 + fp) - 1', 'fp + 9'),
        ('10 + (fp - 1)', 'fp + 9'),
        ('10 + (1 + fp)', 'fp + 11'),
        ('10 + (1 + fp) + 100', 'fp + 111'),
        ('10 + (1 + (fp + 100))', 'fp + 111'),
    ]
    for source, expected in cases:
        assert simplifier.visit(parse_expr(source)).format() == expected
def eval(self, expr):
    """
    Evaluates the expression string `expr` and returns its string representation.

    The string 'null' yields ''. Otherwise identifiers are substituted via
    self.get_variable and the type system is simplified before visiting. A constant
    result is rendered with field_element_repr() using self.tracer_data.program.prime;
    any other result is returned as result.format().

    Raises NotImplementedError if the simplified expression has a struct type.
    """
    if expr == 'null':
        return ''

    substituted = substitute_identifiers(parse_expr(expr), self.get_variable)
    simplified, simplified_type = simplify_type_system(substituted)
    if isinstance(simplified_type, TypeStruct):
        raise NotImplementedError('Structs are not supported.')

    result = self.visit(simplified)
    if not isinstance(result, ExprConst):
        return result.format()
    return field_element_repr(result.val, self.tracer_data.program.prime)
def test_operator_precedence():
    """
    An expression with mixed operators and redundant parentheses round-trips through
    format() unchanged and simplifies to the value Python computes for the same text.
    """
    code = '(5 + 2) - (3 - 9) * (7 + (-8)) - 10 * (-2) * 5 + (((7)))'
    expr = parse_expr(code)

    # Test formatting.
    assert expr.format() == code

    # Compare the simplifier's value with Python's evaluation of the same constant literal
    # (eval() is safe here: `code` is a hard-coded arithmetic expression).
    field_prime = 3 * 2**30 + 1
    simplified = ExpressionSimplifier(field_prime).visit(expr)
    assert isinstance(simplified, ExprConst)
    assert simplified.val == eval(code)
def test_modulo():
    """With prime 19, simplified constants are reduced into the range (-19/2, 19/2)."""
    prime = 19
    simplifier = ExpressionSimplifier(prime)

    def simplified(source):
        return simplifier.visit(parse_expr(source)).format()

    # Check that the range is (-PRIME/2, PRIME/2).
    assert simplified('-9') == '-9'
    assert simplified('-10') == '9'
    assert simplified('9') == '9'
    assert simplified('10') == '-9'

    # Check value which is bigger than PRIME.
    assert simplified('20') == '1'

    # Check operators.
    assert simplified('10 + 10') == '1'
    assert simplified('10 - 30') == '-1'
    assert simplified('10 * 10') == '5'
    assert simplified('2 / 3') == '7'
def test_format_parentheses_notes():
    """Checks that format() maps each `before` expression containing comments to `after`."""
    # NOTE(review): as written, the first three before/after pairs are textually identical
    # and every literal begins with a backslash followed by a space. The literals appear to
    # have lost their original line breaks and indentation; restore them from version control.
    before = """\ ( # Comment. a + b)"""
    after = """\ ( # Comment. a + b)"""
    assert parse_expr(before).format() == after
    before = """\ ( a + b)"""
    after = """\ ( a + b)"""
    assert parse_expr(before).format() == after
    before = """\ ( # Comment. a + b)"""
    after = """\ ( # Comment. a + b)"""
    assert parse_expr(before).format() == after
    before = """\ (# Comment. # # x. # y. a + b)"""
    after = """\ ( # Comment. # # x. # y. a + b)"""
    assert parse_expr(before).format() == after
def test_offset_reference_definition_typed_members():
    """
    OffsetReferenceDefinition.eval() follows typed struct members through pointers
    (T.x: S*, S.x: T, T.flt: felt) and raises DefinitionError when a member is accessed on
    a non-struct-pointer type (here, the felt member 'flt').
    """
    import re

    t = TypeStruct(scope=scope('T'), is_fully_resolved=True)
    s_star = TypePointer(pointee=TypeStruct(scope=scope('S'), is_fully_resolved=True))
    reference_manager = ReferenceManager()
    identifiers = {
        scope('T'): ScopeDefinition(),
        scope('T.x'): MemberDefinition(offset=3, cairo_type=s_star),
        scope('T.flt'): MemberDefinition(offset=4, cairo_type=TypeFelt()),
        scope('S'): ScopeDefinition(),
        scope('S.x'): MemberDefinition(offset=10, cairo_type=t),
    }
    main_reference = ReferenceDefinition(full_name=scope('a'), references=[])
    references = {
        scope('a'): reference_manager.get_id(
            Reference(
                pc=0,
                value=mark_types_in_expr_resolved(parse_expr('cast(ap, T*)')),
                ap_tracking_data=RegTrackingData(group=0, offset=0),
            )),
    }
    flow_tracking_data = FlowTrackingDataActual(
        ap_tracking=RegTrackingData(group=0, offset=1),
        reference_ids=references,
    )

    # Create OffsetReferenceDefinition instances for expressions of the form "a.<member_path>",
    # such as a.x and a.x.x, and check the result of evaluating those expressions.
    for member_path, expected_result in [
        ('x', 'cast([ap - 1 + 3], S*)'),
        ('x.x', 'cast([[ap - 1 + 3] + 10], T)'),
        ('x.x.x', 'cast([&[[ap - 1 + 3] + 10] + 3], S*)'),
        ('x.x.flt', 'cast([&[[ap - 1 + 3] + 10] + 4], felt)'),
    ]:
        definition = OffsetReferenceDefinition(
            parent=main_reference, identifier_values=identifiers,
            member_path=scope(member_path))
        assert definition.eval(
            reference_manager=reference_manager,
            flow_tracking_data=flow_tracking_data).format() == expected_result

    # 'flt' is a felt, so accessing '.x' on it must be rejected. Only the exception matters
    # here; the call is not compared against any value. The message contains regex
    # metacharacters ('*', '.'), so it is escaped before being used as a `match` pattern.
    definition = OffsetReferenceDefinition(
        parent=main_reference, identifier_values=identifiers, member_path=scope('x.x.flt.x'))
    with pytest.raises(
            DefinitionError,
            match=re.escape('Member access requires a type of the form Struct*.')):
        definition.eval(
            reference_manager=reference_manager, flow_tracking_data=flow_tracking_data)
def test_deref_expr():
    """Nested dereferences parse into nested ExprDeref nodes and format canonically."""
    expr = parse_expr('[[fp - 7] + 3]')

    fp_minus_seven = ExprOperator(a=ExprReg(reg=Register.FP), op='-', b=ExprConst(val=7))
    inner_plus_three = ExprOperator(
        a=ExprDeref(addr=fp_minus_seven), op='+', b=ExprConst(val=3))

    assert expr == ExprDeref(addr=inner_plus_three)
    assert expr.format() == '[[fp - 7] + 3]'
def test_type_visitor_pointer_arithmetic():
    """
    Pointer +/- constant keeps the pointer type; pointer - pointer yields a felt. Casts are
    removed from the simplified expression.
    """
    t_star = TypePointer(pointee=TypeStruct(scope=scope('T'), is_fully_resolved=False))
    cases = [
        ('cast(fp, T*) + 3', 'fp + 3', t_star),
        ('cast(fp, T*) - 3', 'fp - 3', t_star),
        ('cast(fp, T*) - cast(3, T*)', 'fp - 3', TypeFelt()),
    ]
    for source, simplified_source, expected_type in cases:
        expected = (parse_expr(simplified_source), expected_type)
        assert simplify_type_system(parse_expr(source)) == expected