def test_func_args(): code = """\ struct T: member s = 0 member t = 1 const SIZE = 2 end func f(x, y : T, z : T*): x = 1; ap++ y.s = 2; ap++ z.t = y.t; ap++ ret end """ program = preprocess_str(code=code, prime=PRIME) reference_x = program.instructions[ -1].flow_tracking_data.resolve_reference( reference_manager=program.reference_manager, name=ScopedName.from_string('f.x')) assert reference_x.value.format() == 'cast([fp + (-6)], felt)' reference_y = program.instructions[ -1].flow_tracking_data.resolve_reference( reference_manager=program.reference_manager, name=ScopedName.from_string('f.y')) assert reference_y.value.format() == 'cast([fp + (-5)], T)' reference_z = program.instructions[ -1].flow_tracking_data.resolve_reference( reference_manager=program.reference_manager, name=ScopedName.from_string('f.z')) assert reference_z.value.format() == 'cast([fp + (-3)], T*)' assert program.format() == """\
def visit_CodeElementImport(self, code_elm: CodeElementImport):
    """
    Registers every name introduced by an import statement as a resolved identifier.

    For each imported item, the registered name is the module path followed by the item's
    original (un-aliased) identifier, and the import statement's location is attached.
    """
    for item in code_elm.import_items:
        module_path = ScopedName.from_string(code_elm.path.name)
        imported_name = ScopedName.from_string(item.orig_identifier.name)
        self.add_identifier(
            module_path + imported_name,
            is_resolved=True,
            location=code_elm.location)
def run(self, context: PassManagerContext):
    """
    Collects every module reachable from self.additional_modules and from each (code, filename)
    pair in context.codes. Each distinct module name is appended to context.modules once; the
    first occurrence wins.

    Modules reached from additional modules are scoped by their own module name. For a code unit,
    the root module (the one whose name equals the filename) gets context.main_scope and its code
    is served from context.codes. Every other module gets a scope derived from its name and is
    read through self.read_module.
    """
    visited_modules = set()

    for additional_module in self.additional_modules:
        module_asts = collect_imports(additional_module, read_file=self.read_module)
        for module_name, ast in module_asts.items():
            if module_name not in visited_modules:
                visited_modules.add(module_name)
                context.modules.append(CairoModule(
                    cairo_file=ast, module_name=ScopedName.from_string(module_name)))

    for code, filename in context.codes:
        def read_file_fixed(name):
            # The root module (filename) is already in memory; any other module is read normally.
            # collect_imports is called right below, so code/filename are this iteration's values.
            if name == filename:
                return code, filename
            return self.read_module(name)

        module_asts = collect_imports(filename, read_file=read_file_fixed)
        for module_name, ast in module_asts.items():
            # Only the root module of this code unit lives in the main scope.
            scope = (
                context.main_scope if module_name == filename
                else ScopedName.from_string(module_name))
            if module_name in visited_modules:
                continue
            visited_modules.add(module_name)
            context.modules.append(CairoModule(cairo_file=ast, module_name=scope))
def test_program_start_property():
    """
    Program.start is the pc of the 'some.main.__start__' label when it exists, and 0 otherwise.
    """
    main_scope = ScopedName.from_string('some.main')
    reference_manager = ReferenceManager()

    def build_program(identifiers):
        return Program(
            prime=0, data=[], hints={}, builtins=[], main_scope=main_scope,
            identifiers=identifiers, reference_manager=reference_manager)

    # The '__start__' label is defined under the main scope.
    identifiers_with_start = IdentifierManager.from_dict({
        ScopedName.from_string('some.main.__start__'): LabelDefinition(3),
    })
    assert build_program(identifiers_with_start).start == 3

    # No '__start__' label at all.
    assert build_program(IdentifierManager()).start == 0
def check_main_args(program: Program):
    """
    Makes sure that for every builtin included in the program an appropriate ptr was passed as an
    argument to main() and is subsequently returned.

    The members of '__main__.main.Args' and of '__main__.main.Return' must each equal
    ['<builtin>_ptr', ...], in the order of program.builtins. A struct whose lookup raises
    IdentifierError is skipped.
    """
    expected_builtin_ptrs = [f'{builtin}_ptr' for builtin in program.builtins]

    struct_checks = (
        ('__main__.main.Args', 'contain the following arguments'),
        ('__main__.main.Return', 'return the following values'),
    )
    for struct_full_name, description in struct_checks:
        try:
            members = list(get_struct_members(
                struct_name=ScopedName.from_string(struct_full_name),
                identifier_manager=program.identifiers))
        except IdentifierError:
            continue
        assert members == expected_builtin_ptrs, \
            f'Expected main to {description} (in this order): ' \
            f'{expected_builtin_ptrs}. Found: {members}.'
def _get_full_name(self, name: ScopedName):
    """
    Returns the canonical full name of `name`.

    A name already present in self.resolved_identifiers is answered from there. Otherwise it is
    searched in self.identifiers with '__main__' and the root scope as accessible scopes.
    """
    resolved = self.resolved_identifiers.get(name)
    if resolved is None:
        search_scopes = [ScopedName.from_string('__main__'), ScopedName()]
        resolved = self.identifiers.search(
            accessible_scopes=search_scopes, name=name).get_canonical_name()
    return resolved
def test_startswith():
    parse = ScopedName.from_string

    # Prefix given as a ScopedName.
    assert parse('a.b').startswith(parse('a'))
    assert not parse('x.a.b').startswith(parse('a'))
    assert not parse('a.b').startswith(parse('x'))

    # Prefix given as a dotted string; matching is on whole path components.
    assert not parse('a.b').startswith('b')
    assert parse('x.a.b').startswith('')
    assert parse('x.a.b').startswith('x.a')
    assert not parse('abc').startswith('a')
def test_flow_tracking_labels(changes):
    """
    Flows into label 'a' twice (with changes.label0 / changes.label1, each followed by a body ap
    change), then converges with 'a'. The tracking data must be unchanged by the convergence
    exactly when changes.valid is True.
    """
    label_a = ScopedName.from_string('a')
    tracking = FlowTracking()

    tracking.add_flow_to_label(label_a, changes.label0)
    tracking.add_ap(changes.body0)
    tracking.add_flow_to_label(label_a, changes.label1)
    tracking.add_ap(changes.body1)

    data_before_converge = tracking.get()
    tracking.converge_with_label(label_a)
    assert (tracking.get() == data_before_converge) is changes.valid
def test_imports():
    """An aliased import registers the alias as an AliasDefinition pointing at the full name."""
    collector = IdentifierCollector()
    collector.identifiers.add_identifier(
        ScopedName.from_string('foo.bar'), ConstDefinition(value=0))

    ast = parse_file("""
from foo import bar as bar0
""")
    root_scope = ScopedName()
    with collector.scoped(root_scope, parent=ast):
        collector.visit(ast.code_block)

    expected = {
        'bar0': AliasDefinition(destination=ScopedName.from_string('foo.bar')),
    }
    assert collector.identifiers.get_scope(root_scope).identifiers == expected
def run(self, func_name: str, *args, hint_locals: Optional[Dict[str, Any]] = None,
        static_locals: Optional[Dict[str, Any]] = None, verify_secure: Optional[bool] = None,
        trace_on_failure: bool = False, apply_modulo_to_args: Optional[bool] = None,
        use_full_name: bool = False, **kwargs):
    """
    Runs func_name(*args, **kwargs). The arguments are converted to Cairo-friendly values using
    gen_arg.

    Additional params:
    verify_secure - Run verify_secure_runner to do extra verifications.
    trace_on_failure - Run the tracer in case of failure to help debugging.
    apply_modulo_to_args - Apply modulo operation on integer arguments.
    use_full_name - Treat func_name as a fully qualified identifier name, instead of a relative
      one.

    VmException, SecurityError and AssertionError raised while running are re-raised (after
    printing and tracing when trace_on_failure is set).
    """
    assert isinstance(self.program, Program)
    func_scoped_name = ScopedName.from_string(scope=func_name)

    # Pack the positional and keyword arguments into the function's argument struct.
    args_struct = CairoStructFactory.from_program(program=self.program).build_func_args(
        func=func_scoped_name)
    all_args = args_struct(*args, **kwargs)

    entrypoint: Union[str, int] = func_name
    if use_full_name:
        # Resolve the fully qualified name to its label pc directly.
        label = self.program.identifiers.get_by_full_name(name=func_scoped_name)
        assert isinstance(label, LabelDefinition)
        entrypoint = label.pc

    try:
        self.run_from_entrypoint(
            entrypoint, *all_args, hint_locals=hint_locals, static_locals=static_locals,
            verify_secure=verify_secure, apply_modulo_to_args=apply_modulo_to_args)
    except (VmException, SecurityError, AssertionError) as ex:
        if trace_on_failure:
            print(f"""\
Got {type(ex).__name__} exception during the execution of {func_name}:
{str(ex)}
""")
            trace_runner(runner=self)
        raise
def test_scoped_name():
    """ScopedName construction, str(), from_string, concatenation with dotted strings, slicing."""
    path = ('some', 'thing')
    assert ScopedName(path).path == path
    assert str(ScopedName(path)) == 'some.thing'
    assert ScopedName.from_string('some.thing').path == path

    # Adding a dotted string appends each of its components separately.
    extended = ScopedName(path) + 'el.se'
    assert extended == ScopedName(('some', 'thing', 'el', 'se'))
    assert extended != ScopedName(('some', 'thing', 'else'))

    assert ScopedName.from_string('aa.bb.cc.dd')[1:3] == ScopedName.from_string('bb.cc')

    # A single path component may not contain a dot.
    with pytest.raises(AssertionError):
        ScopedName(('some', 'thing.else'))
def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]):
    """
    Returns the value of the attribute `name` when set_value is None; otherwise sets the attribute
    to set_value (so an attribute cannot be set to None through this method).

    `name` is searched from self._accessible_scopes. The work is delegated to
    'handle_<definition class name>' (or 'handle_scope' for a scope), called as
    handler(name, value, scope, set_value). If no such handler exists,
    self.raise_unsupported_error is called.

    Raises MissingIdentifierError (with the name prefixed by self._path) when the lookup fails.
    """
    try:
        # Handle attributes representing program scopes and constants.
        search_result = search_identifier_or_scope(
            identifiers=self._context.identifiers,
            accessible_scopes=self._accessible_scopes,
            name=ScopedName.from_string(name))
    except MissingIdentifierError as exc:
        # Report the missing name relative to this object's path.
        raise MissingIdentifierError(self._path + exc.fullname) from None

    value: Optional[IdentifierDefinition]
    if isinstance(search_result, IdentifierSearchResult):
        value = search_result.identifier_definition
        handler_name, scope, identifier_type = (
            f'handle_{type(value).__name__}',
            search_result.get_canonical_name(),
            value.TYPE)
    elif isinstance(search_result, IdentifierScope):
        value = None
        handler_name, scope, identifier_type = 'handle_scope', search_result.fullname, 'scope'
    else:
        raise NotImplementedError(
            f'Unexpected type {type(search_result).__name__}.')

    # NOTE(review): dir() is presumably used instead of hasattr() so that a custom __getattr__
    # is not consulted for the handler name - confirm before changing.
    if handler_name not in dir(self):
        self.raise_unsupported_error(name=self._path + name, identifier_type=identifier_type)

    handler = getattr(self, handler_name)
    return handler(name, value, scope, set_value)
def test_revoked_reference():
    """
    The reference for 'x' was allocated with ap tracking group 0, while the current flow tracking
    data is in group 1. The test expects accessing consts.x to fail with
    FlowTrackingError('Failed to deduce ap.').
    """
    reference_manager = ReferenceManager()
    x_reference = Reference(
        pc=0,
        value=parse_expr('[ap + 1]'),
        ap_tracking_data=RegTrackingData(group=0, offset=2),
    )
    ref_id = reference_manager.alloc_id(reference=x_reference)

    identifier_values = {
        scope('x'): ReferenceDefinition(
            full_name=scope('x'), cairo_type=TypeFelt(), references=[]),
    }
    prime, ap, fp = 2**64 + 13, 100, 200
    memory = {}
    current_flow = FlowTrackingDataActual(
        ap_tracking=RegTrackingData(group=1, offset=4),
        reference_ids={scope('x'): ref_id},
    )
    context = VmConstsContext(
        identifiers=IdentifierManager.from_dict(identifier_values),
        evaluator=ExpressionEvaluator(prime, ap, fp, memory).eval,
        reference_manager=reference_manager,
        flow_tracking_data=current_flow,
        memory=memory,
        pc=0)
    vm_consts = VmConsts(context=context, accessible_scopes=[ScopedName()])

    with pytest.raises(FlowTrackingError, match='Failed to deduce ap.'):
        assert vm_consts.x
def test_missing_attributes():
    context = VmConstsContext(
        identifiers=IdentifierManager.from_dict({
            scope('x.y'): ConstDefinition(1),
            scope('z'): AliasDefinition(scope('x')),
            scope('x.missing'): AliasDefinition(scope('nothing')),
        }),
        evaluator=dummy_evaluator,
        reference_manager=ReferenceManager(),
        flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()),
        memory={},
        pc=0)
    vm_consts = VmConsts(context=context, accessible_scopes=[ScopedName()])

    # The identifier does not exist anywhere.
    with pytest.raises(MissingIdentifierError, match="Unknown identifier 'xx'."):
        vm_consts.xx

    # The identifier exists ('x.y') but not in an accessible scope.
    with pytest.raises(MissingIdentifierError, match="Unknown identifier 'y'."):
        vm_consts.y

    # The missing name is reported with the path used to reach it.
    with pytest.raises(MissingIdentifierError, match="Unknown identifier 'x.z'."):
        vm_consts.x.z

    # Same, when the path goes through an alias.
    with pytest.raises(MissingIdentifierError, match="Unknown identifier 'z.x'."):
        vm_consts.z.x

    # An alias whose destination does not exist.
    with pytest.raises(
            IdentifierError,
            match="Alias resolution failed: x.missing -> nothing. Unknown identifier 'nothing'."):
        vm_consts.x.missing.y
def test_get_struct_definition():
    struct_t = StructDefinition(
        full_name=scope('T'),
        members={
            'a': MemberDefinition(offset=0, cairo_type=TypeFelt()),
            'b': MemberDefinition(offset=1, cairo_type=TypeFelt()),
        },
        size=2,
    )
    manager = IdentifierManager.from_dict({
        scope('T'): struct_t,
        scope('MyConst'): ConstDefinition(value=5),
    })

    struct_def = get_struct_definition(ScopedName.from_string('T'), manager)
    # Compare as a list so that the order of the members is checked too.
    expected_members = [
        ('a', MemberDefinition(offset=0, cairo_type=TypeFelt())),
        ('b', MemberDefinition(offset=1, cairo_type=TypeFelt())),
    ]
    assert list(struct_def.members.items()) == expected_members
    assert struct_def.size == 2

    # A non-struct identifier.
    with pytest.raises(
            DefinitionError, match="Expected 'MyConst' to be a struct. Found: 'const'."):
        get_struct_definition(scope('MyConst'), manager)

    # An unknown identifier.
    with pytest.raises(MissingIdentifierError, match=re.escape("Unknown identifier 'abc'.")):
        get_struct_definition(scope('abc'), manager)
def test_labels():
    """
    Checks the pc of every label defined under a custom main scope, given the instruction sizes
    noted in the code's comments.
    """
    main_scope = ScopedName.from_string('my.cool.scope')
    program = preprocess_str("""
const x = 7
a0:
[ap] = x; ap++ # Size: 2.
[ap] = [fp] + 123 # Size: 2.
a1:
[ap] = [fp] # Size: 1.
jmp rel [fp] # Size: 1.
a2:
jmp rel x # Size: 2.
jmp a3 # Size: 2.
jmp a3 if [ap] != 0 # Size: 2.
call a3 # Size: 2.
a3:
""", prime=PRIME, main_scope=main_scope)

    scope_identifiers = program.identifiers.get_scope(main_scope).identifiers
    program_labels = {}
    for label_name, definition in scope_identifiers.items():
        if isinstance(definition, LabelDefinition):
            program_labels[label_name] = definition.pc

    assert program_labels == {'a0': 0, 'a1': 4, 'a2': 6, 'a3': 14}
def _deserialize(self, value, attr, data, **kwargs) -> IdentifierManager:
    """
    Builds an IdentifierManager from `value`, a mapping of full-name strings to serialized
    identifier definitions. Each definition is loaded with IdentifierDefinitionSchema.
    """
    schema = IdentifierDefinitionSchema()
    identifiers = {}
    for full_name, serialized in value.items():
        key = ScopedName.from_string(full_name)
        identifiers[key] = schema.load(serialized)
    return IdentifierManager.from_dict(identifiers)
class CodeElementFunction(CodeElement):
    """
    Represents either a 'func', 'namespace' or 'struct' statement.
    For example:

      func foo(x, y) -> (z, w):
          return (z=x, w=y)
      end

    'struct' and 'namespace' headers are formatted as '<type> <name>:'. Any other element type
    gets an argument list and, when `returns` is set, a return list.
    """

    # The type of the code element. Either 'func', 'namespace' or 'struct'.
    element_type: str
    identifier: ExprIdentifier
    arguments: IdentifierList
    returns: Optional[IdentifierList]
    code_block: CodeBlock

    # Sub-scope names for a function's arguments ('Args') and return values ('Return').
    ARGUMENT_SCOPE = ScopedName.from_string('Args')
    RETURN_SCOPE = ScopedName.from_string('Return')

    @property
    def name(self):
        return self.identifier.name

    def format(self, allowed_line_length):
        """
        Returns the header (wrapped to allowed_line_length), then the code block indented by
        INDENTATION, then 'end'.
        """
        inner_line_length = allowed_line_length - INDENTATION
        body = indent(
            self.code_block.format(allowed_line_length=inner_line_length), INDENTATION)

        if self.element_type in ('struct', 'namespace'):
            particles = [f'{self.element_type} {self.name}:']
        else:
            particles = [f'{self.element_type} {self.name}(']
            if self.returns is None:
                particles.append(create_particle_sublist(self.arguments.get_particles(), '):'))
            else:
                particles.append(
                    create_particle_sublist(self.arguments.get_particles(), ') -> ('))
                particles.append(create_particle_sublist(self.returns.get_particles(), '):'))

        formatting_config = ParticleFormattingConfig(
            allowed_line_length=allowed_line_length, line_indent=INDENTATION * 2)
        header = particles_in_lines(particles=particles, config=formatting_config)
        return f'{header}\n{body}end'

    def get_children(self) -> Sequence[Optional[AstNode]]:
        return [self.identifier, self.arguments, self.returns, self.code_block]
def get_reference_type(name):
    """
    Returns the simplified type of the single reference bound to the full identifier `name` in
    `program`, a free variable taken from the enclosing scope.
    """
    definition = program.identifiers.get_by_full_name(ScopedName.from_string(name))
    assert isinstance(definition, ReferenceDefinition)
    references = definition.references
    assert len(references) == 1
    # simplify_type_system returns a pair; only the type (second element) is needed.
    return simplify_type_system(references[0].value)[1]
def test_flow_tracking_labels_diverge(changes):
    """
    Tests a case of divergence.
    Diverge to a, b with different ap diffs, then converge at c.
    """
    label_a, label_b, label_c = (ScopedName.from_string(n) for n in ('a', 'b', 'c'))
    tracking = FlowTracking()
    tracking.add_flow_to_label(label_a, changes.to_a)
    tracking.add_flow_to_label(label_b, changes.to_b)

    def visit_label(label, ap_change):
        """
        Revokes the current flow, converges with `label`, adds `ap_change`, then flows to 'c'
        with 0. Returns the tracking data captured just before flowing to 'c'.
        """
        tracking.revoke()
        tracking.converge_with_label(label)
        tracking.add_ap(ap_change)
        snapshot = tracking.get()
        tracking.add_flow_to_label(label_c, 0)
        return snapshot

    data_after_a = visit_label(label_a, changes.at_a)
    data_after_b = visit_label(label_b, changes.at_b)

    # Label c.
    tracking.revoke()
    tracking.converge_with_label(label_c)
    data_at_c = tracking.get()

    if changes.valid:
        assert data_after_a == data_after_b == data_at_c
    else:
        assert data_after_a != data_at_c and data_after_b != data_at_c
def test_process_file_scope():
    """
    Preprocessing with a non-trivial main scope places the program's identifiers under that
    scope.
    """
    # The original also built an unused CairoModule(cairo_file=program, ...). It passed a
    # preprocessed Program as `cairo_file` and was never asserted on, so it has been removed.
    valid_scope = ScopedName.from_string('some.valid.scope')
    program = preprocess_str('const x = 4', prime=PRIME, main_scope=valid_scope)
    assert program.identifiers.as_dict() == {
        valid_scope + 'x': ConstDefinition(4)
    }
def _extract_identifiers(code):
    """
    Parses `code`, collects the identifiers it defines starting from the root scope, and returns
    them as a list of (name string, identifier type) pairs.
    """
    parsed = parse_file(code)
    collector = IdentifierCollector()
    with collector.scoped(ScopedName(), parent=parsed):
        collector.visit(parsed.code_block)

    extracted = []
    for full_name, definition in collector.identifiers.as_dict().items():
        extracted.append((str(full_name), definition.identifier_type))
    return extracted
def read(self, module_name: str) -> Tuple[str, str]:
    """
    Given a module name, translates it to a file path to read the module from, and returns the
    module code and filename.

    Before the file is opened, the path is added to self.source_files and
    (path, module scope) is added to self.source_files_with_scopes.
    """
    file_path = self.module_to_file_path(module_name)
    self.source_files.add(file_path)
    module_scope = ScopedName.from_string(module_name)
    self.source_files_with_scopes.add((file_path, module_scope))

    with open(file_path, 'r') as source_file:
        module_code = source_file.read()
    return module_code, file_path
def get_vm_consts(identifier_values, reference_manager, flow_tracking_data, memory=None):
    """
    Creates a simple VmConsts object.

    identifier_values: mapping passed to IdentifierManager.from_dict.
    reference_manager, flow_tracking_data: passed through to the VmConstsContext.
    memory: dict given to both the ExpressionEvaluator and the VmConstsContext. When omitted, a
      new empty dict is created for each call. A `{}` default would be one dict shared by every
      call, so any write through one VmConsts could leak into later calls.

    The evaluator uses prime 2**64 + 13 with ap = fp = 0, and the context's pc is 9. Lookups are
    made from the root scope.
    """
    if memory is None:
        memory = {}
    identifiers = IdentifierManager.from_dict(identifier_values)
    context = VmConstsContext(
        identifiers=identifiers,
        evaluator=ExpressionEvaluator(2**64 + 13, 0, 0, memory, identifiers).eval,
        reference_manager=reference_manager,
        flow_tracking_data=flow_tracking_data,
        memory=memory,
        pc=9)
    return VmConsts(context=context, accessible_scopes=[ScopedName()])
def search_identifier(
        self, name: str, location: Optional[Location]) -> Optional[IdentifierDefinition]:
    """
    Searches for `name` from self.accessible_scopes in self.identifiers and returns the
    IdentifierDefinition produced by resolve_search_result.

    Any IdentifierError raised during the search or the resolution is re-raised as a
    PreprocessorError with the same message and the given location.
    """
    try:
        search_result = self.identifiers.search(
            self.accessible_scopes, ScopedName.from_string(name))
        definition = resolve_search_result(search_result, identifiers=self.identifiers)
    except IdentifierError as exc:
        raise PreprocessorError(str(exc), location=location)
    return definition
def test_get_struct_size():
    identifiers = {
        scope('T'): ScopeDefinition(),
        scope('T.SIZE'): ConstDefinition(value=2),
        scope('S'): ScopeDefinition(),
        scope('S.SIZE'): ScopeDefinition(),
    }

    def struct_size(struct_name):
        return get_struct_size(ScopedName.from_string(struct_name), identifiers)

    assert struct_size('T') == 2

    # No SIZE identifier at all.
    with pytest.raises(
            DefinitionError,
            match=re.escape("The identifier 'abc.SIZE' was not found.")):
        struct_size('abc')

    # SIZE exists but is not a const.
    with pytest.raises(
            DefinitionError,
            match=re.escape("Expected 'S.SIZE' to be a const, but it is a scope.")):
        struct_size('S')
def test_scope_order():
    """
    With accessible scopes [root, x], the test expects 'y' to resolve to x.y (1) rather than the
    top-level y (2).
    """
    identifiers = IdentifierManager.from_dict({
        scope('x.y'): ConstDefinition(1),
        scope('y'): ConstDefinition(2),
    })
    context = VmConstsContext(
        identifiers=identifiers,
        evaluator=dummy_evaluator,
        reference_manager=ReferenceManager(),
        flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()),
        memory={},
        pc=0)
    vm_consts = VmConsts(context=context, accessible_scopes=[ScopedName(), scope('x')])

    assert vm_consts.y == 1
    assert vm_consts.x.y == 1
def get_identifier(
        self, name: Union[str, ScopedName], expected_type: Type[IdentifierDefinition],
        full_name_lookup: bool = False):
    """
    Returns the definition of the identifier `name`, which must be an instance of expected_type.

    name: a ScopedName or a dotted string.
    expected_type: the IdentifierDefinition subclass the result must be an instance of.
    full_name_lookup: if False (default), `name` is resolved relative to self.main_scope. If
      True, `name` is treated as a fully qualified name and resolved from the root scope.

    Raises whatever self.identifiers.search raises for an unknown name (for example
    MissingIdentifierError), or what result.assert_fully_parsed() raises. Fails the assertion
    if the definition is not an expected_type.
    """
    scoped_name = name if isinstance(name, ScopedName) else ScopedName.from_string(name)
    # Searching from the root scope (ScopedName()) makes `name` resolve as a full name.
    search_scope = ScopedName() if full_name_lookup else self.main_scope
    result = self.identifiers.search(accessible_scopes=[search_scope], name=scoped_name)
    result.assert_fully_parsed()
    identifier_definition = result.identifier_definition
    assert isinstance(identifier_definition, expected_type), (
        f"'{scoped_name}' is expected to be {expected_type.TYPE}, " +  # type: ignore
        f'found {identifier_definition.TYPE}.')  # type: ignore
    return identifier_definition
def test_tail_call(): code = """\ func f(a) -> (a): return f(a) end func g(a, b) -> (a): return f(a) end """ program = preprocess_str( code=code, prime=PRIME, main_scope=ScopedName.from_string('test_scope')) assert program.format() == """\
def test_main_scope():
    identifiers = IdentifierManager.from_dict({
        ScopedName.from_string('a.b'): ConstDefinition(value=1),
        ScopedName.from_string('x.y.z'): ConstDefinition(value=2),
    })
    program = Program(
        prime=0, data=[], hints={}, builtins=[], main_scope=ScopedName.from_string('a'),
        identifiers=identifiers, reference_manager=ReferenceManager())

    # Relative lookup from the main scope 'a'.
    assert program.get_identifier('b', ConstDefinition).value == 1

    # Names are resolved relative to 'a', so these are not reachable.
    for relative_name, reported_missing in (('a.b', 'a'), ('x.y', 'x'), ('y', 'y')):
        with pytest.raises(
                MissingIdentifierError, match=f"Unknown identifier '{reported_missing}'."):
            program.get_identifier(relative_name, ConstDefinition)

    # Full name lookup.
    for full_name, expected_value in (('a.b', 1), ('x.y.z', 2)):
        definition = program.get_identifier(full_name, ConstDefinition, full_name_lookup=True)
        assert definition.value == expected_value