def check_module(module: ast.Module,
                 f_import: Optional[ImportFn]) -> type_info.TypeInfo:
  """Type checks the top-level members of `module`.

  Members are processed in phases: imports, constants and enums are recorded
  first; then quickcheck functions, structs, typedefs, non-parametric
  functions and finally tests are each checked in turn.

  Args:
    module: The module whose top-level members are type checked.
    f_import: Callback that resolves an import (as an import statement would)
      into a (module, type_info) pair. May be None, e.g. in unit tests where
      the module is guaranteed to contain no import statements.

  Returns:
    The type information for the module: a mapping from AST node to its
    deduced/checked type.

  Raises:
    PositionalError: If a quickcheck function is parametric, or if a test
      function takes parameters or is parametric.
    XlsTypeError: If a quickcheck function body does not have type
      ConcreteType.U1, or if any checked member has typecheck errors.
  """
  root_type_info = type_info.TypeInfo(module)
  ctx = deduce.DeduceCtx(
      root_type_info, module,
      functools.partial(interpret_expr, f_import=f_import),
      check_top_node_in_module)

  # Phase 1: record resolved imports and deduce constants/enums. The global
  # scope carries no symbolic bindings.
  ctx.fn_stack.append(('top', {}))
  for member in ctx.module.top:
    if isinstance(member, ast.Import):
      assert isinstance(member.name, tuple), member.name
      dep_module, dep_type_info = f_import(member.name)
      ctx.type_info.add_import(member, (dep_module, dep_type_info))
      continue
    if isinstance(member, (ast.Constant, ast.Enum)):
      deduce.deduce(member, ctx)
      continue
    # Everything else is handled by one of the later phases.
    assert isinstance(member, (ast.Function, ast.Test, ast.Struct,
                               ast.QuickCheck, ast.TypeDef)), member
  ctx.fn_stack.pop()

  # Phase 2: quickcheck functions, keyed by the wrapped function's name.
  quickchecks_by_name = {}
  for quickcheck in ctx.module.get_quickchecks():
    quickchecks_by_name[quickcheck.f.name.identifier] = quickcheck
  for quickcheck in quickchecks_by_name.values():
    assert isinstance(quickcheck, ast.QuickCheck), quickcheck
    qc_fn = quickcheck.f
    assert isinstance(qc_fn, ast.Function), qc_fn
    if qc_fn.is_parametric():
      # TODO(cdleary): 2020-08-09 See https://github.com/google/xls/issues/81
      raise PositionalError(
          'Quickchecking parametric '
          'functions is unsupported.', qc_fn.span)

    logging.vlog(2, 'Typechecking function: %s', qc_fn)
    # Checked without any symbolic bindings.
    ctx.fn_stack.append((qc_fn.name.identifier, {}))
    check_top_node_in_module(qc_fn, ctx)

    body_type = ctx.type_info[qc_fn.body]
    if body_type != ConcreteType.U1:
      raise XlsTypeError(
          qc_fn.span,
          body_type,
          ConcreteType.U1,
          suffix='QuickCheck functions must return a bool.')
    logging.vlog(2, 'Finished typechecking function: %s', qc_fn)

  # Phase 3: structs. These go through check_top_node_in_module() so that
  # function calls in parametric bindings (if any) get typechecked too.
  structs_by_name = {
      struct.name.identifier: struct for struct in ctx.module.get_structs()
  }
  for struct in structs_by_name.values():
    assert isinstance(struct, ast.Struct), struct
    logging.vlog(2, 'Typechecking struct %s', struct)
    ctx.fn_stack.append(('top', {}))  # No symbolic bindings.
    check_top_node_in_module(struct, ctx)
    logging.vlog(2, 'Finished typechecking struct: %s', struct)

  # Phase 4: typedefs found among the module's top-level members.
  typedefs_by_name = {
      member.name.identifier: member
      for member in ctx.module.top
      if isinstance(member, ast.TypeDef)
  }
  for typedef in typedefs_by_name.values():
    assert isinstance(typedef, ast.TypeDef), typedef
    logging.vlog(2, 'Typechecking typedef %s', typedef)
    ctx.fn_stack.append(('top', {}))  # No symbolic bindings.
    check_top_node_in_module(typedef, ctx)
    logging.vlog(2, 'Finished typechecking typedef: %s', typedef)

  # Phase 5: functions. Parametric ones are skipped here because they are
  # typechecked per invocation instead.
  functions_by_name = {
      fn.name.identifier: fn for fn in ctx.module.get_functions()
  }
  for fn in functions_by_name.values():
    assert isinstance(fn, ast.Function), fn
    if not fn.is_parametric():
      logging.vlog(2, 'Typechecking function: %s', fn)
      ctx.fn_stack.append((fn.name.identifier, {}))  # No symbolic bindings.
      check_top_node_in_module(fn, ctx)
      logging.vlog(2, 'Finished typechecking function: %s', fn)

  # Phase 6: tests.
  tests_by_name = {
      test.name.identifier: test for test in ctx.module.get_tests()
  }
  for test in tests_by_name.values():
    assert isinstance(test, ast.Test), test
    # New-style tests are specified using (and wrapped in) a function; that
    # function must take no arguments and must not be parametric.
    is_fn_style = isinstance(test, ast.TestFunction)
    if is_fn_style:
      if test.fn.params:
        raise PositionalError("Test functions shouldn't take arguments.",
                              test.fn.span)
      if test.fn.is_parametric():
        raise PositionalError("Test functions shouldn't be parametric.",
                              test.fn.span)

    # No symbolic bindings inside of a test.
    ctx.fn_stack.append(('{}_test'.format(test.name.identifier), {}))
    logging.vlog(2, 'Typechecking test: %s', test)
    # Old-style tests are a construct with a body and are checked directly;
    # new-style tests are checked through their wrapping function.
    check_top_node_in_module(test.fn if is_fn_style else test, ctx)
    logging.vlog(2, 'Finished typechecking test: %s', test)

  return ctx.type_info
def _instantiate(builtin_name: ast.BuiltinNameDef, invocation: ast.Invocation,
                 ctx: deduce.DeduceCtx) -> Optional[ast.NameDef]:
  """Instantiates an invocation of a parametric builtin (e.g. 'update').

  The builtin's signature is applied to the resolved argument types, and the
  resulting function type, return type and symbolic bindings are recorded in
  ctx.type_info. For 'map', the mapped function (second argument) may itself
  be parametric, in which case its body must also be typechecked with the
  bindings computed here.

  Args:
    builtin_name: Name definition of the builtin being invoked.
    invocation: The invocation node being instantiated.
    ctx: Deduction context. When a name definition is returned, this function
      has pushed a frame onto ctx.fn_stack and replaced ctx.type_info with a
      child TypeInfo for the instantiation.

  Returns:
    The name definition of the mapped function when 'map' is given a
    parametric function from the current module that has no instantiation
    for these bindings yet; None in every other case.
  """
  builtin_id = builtin_name.identifier
  is_map = builtin_id == 'map'

  resolved_arg_types = tuple(
      deduce.resolve(ctx.type_info[arg], ctx) for arg in invocation.args)

  # For 'map', find the mapped function and its parametric bindings so the
  # builtin signature can take them into account.
  mapped_fn_bindings = None
  map_fn_name = None
  if is_map:
    map_fn_ref = invocation.args[1]
    if isinstance(map_fn_ref, ast.ModRef):
      imported_module, imported_type_info = ctx.type_info.get_imported(
          map_fn_ref.mod)
      map_fn_name = map_fn_ref.value
      map_fn = imported_module.get_function(map_fn_name)
      mapped_fn_bindings = map_fn.parametric_bindings
    else:
      assert isinstance(map_fn_ref, ast.NameRef), map_fn_ref
      map_fn_name = map_fn_ref.identifier
      if map_fn_name not in dslx_builtins.PARAMETRIC_BUILTIN_NAMES:
        map_fn = ctx.module.get_function(map_fn_name)
        mapped_fn_bindings = map_fn.parametric_bindings

  signature_fn = dslx_builtins.get_fsignature(builtin_id)
  fn_type, symbolic_bindings = signature_fn(resolved_arg_types, builtin_id,
                                            invocation.span, ctx,
                                            mapped_fn_bindings)

  # Record the bindings and types against the caller's symbolic bindings.
  _, caller_bindings = ctx.fn_stack[-1]
  ctx.type_info.add_invocation_symbolic_bindings(
      invocation, tuple(caller_bindings.items()), symbolic_bindings)
  ctx.type_info[invocation.callee] = fn_type
  ctx.type_info[invocation] = fn_type.return_type  # pytype: disable=attribute-error

  if not is_map:
    return None

  assert isinstance(map_fn_name, str), map_fn_name
  if map_fn_name in dslx_builtins.PARAMETRIC_BUILTIN_NAMES:
    # A builtin higher-order parametric fn would've been typechecked while
    # going through the arguments of this invocation.
    return None
  if not map_fn.is_parametric():
    # Nothing further to instantiate for a non-parametric mapped function.
    return None
  if ctx.type_info.has_instantiation(invocation, symbolic_bindings):
    # The mapped function was already typechecked with these bindings.
    return None

  # The mapped function is parametric: its body needs typechecking with the
  # symbolic bindings computed above.
  if isinstance(map_fn_ref, ast.ModRef):
    # Imported functions are typechecked right away, with respect to their
    # own module, using a dedicated context.
    instantiation_type_info = type_info.TypeInfo(
        imported_module, parent=imported_type_info)
    imported_ctx = deduce.DeduceCtx(instantiation_type_info, imported_module,
                                    ctx.interpret_expr,
                                    ctx.check_function_in_module)
    imported_ctx.fn_stack.append((map_fn_name, dict(symbolic_bindings)))
    ctx.check_function_in_module(map_fn, imported_ctx)
    ctx.type_info.add_instantiation(invocation, symbolic_bindings,
                                    instantiation_type_info)
    return None

  # The mapped function lives in this module: push it onto the typechecking
  # stack and switch to a child TypeInfo for this instantiation.
  ctx.fn_stack.append((map_fn_name, dict(symbolic_bindings)))
  instantiation_type_info = type_info.TypeInfo(
      ctx.module, parent=ctx.type_info)
  ctx.type_info.add_instantiation(invocation, symbolic_bindings,
                                  instantiation_type_info)
  ctx.type_info = instantiation_type_info
  return map_fn_ref.name_def