def _deduce_annotated_return_type(f: Function,
                                  ctx: deduce.DeduceCtx) -> ConcreteType:
  """Deduces the type of f's return-type annotation; NIL if it has none."""
  if f.return_type:
    return deduce.deduce(f.return_type, ctx)
  return ConcreteType.NIL


def _check_function(f: Function, ctx: deduce.DeduceCtx) -> None:
  """Validates type annotations on parameters/return type of f are consistent.

  Parametric functions are handled in two phases. When f is parametric and is
  not the function on top of ctx.fn_stack, only its signature (parameter
  types and annotated return type) is deduced and recorded, and the body is
  not checked. When f is parametric and is the function on top of
  ctx.fn_stack, the previously recorded parameter types are reused and the
  body is checked. Non-parametric functions have both signature and body
  checked in a single call.

  Args:
    f: The function to type check.
    ctx: Wraps a node_to_type, a mapping of AST node to its deduced type;
      (free-variable) references are resolved via this dictionary.

  Raises:
    XlsTypeError: When the return type deduced is inconsistent with the return
      type annotation on "f".
  """
  fn_name, _ = ctx.fn_stack[-1]

  # First, get the types of the function's parameters.
  if f.is_parametric() and f.name.identifier == fn_name:
    # Parametric functions are evaluated per invocation. If we're currently
    # inside of this function, the signature was already recorded by an
    # earlier call and now we just need to evaluate the body. (The return
    # type is re-deduced below, so only the parameter types are reused.)
    assert f in ctx.node_to_type, f
    param_types = list(ctx.node_to_type[f].params)  # pytype: disable=attribute-error
  else:
    logging.vlog(1, 'Type-checking sig for function: %s', f)
    param_types = _check_function_params(f, ctx)
    if f.is_parametric():
      # We just needed the type signature so that we can instantiate this
      # invocation. Record it for now and typecheck the body once we have
      # symbolic bindings.
      annotated_return_type = _deduce_annotated_return_type(f, ctx)
      ctx.node_to_type[f.name] = ctx.node_to_type[f] = FunctionType(
          tuple(param_types), annotated_return_type)
      return

  logging.vlog(1, 'Type-checking body for function: %s', f)

  # Second, typecheck the return type of the function.
  # NOTE: if there is no annotated return type, we assume NIL.
  annotated_return_type = _deduce_annotated_return_type(f, ctx)
  resolved_return_type = deduce.resolve(annotated_return_type, ctx)

  # Third, typecheck the body of the function.
  body_return_type = deduce.deduce(f.body, ctx)
  resolved_body_type = deduce.resolve(body_return_type, ctx)

  # Finally, assert type consistency between body and annotated return type.
  if resolved_return_type != resolved_body_type:
    raise XlsTypeError(
        f.body.span,
        resolved_body_type,
        resolved_return_type,
        suffix='Return type of function body for "{}" did not match the '
        'annotated return type.'.format(f.name.identifier))

  # Record the final signature on both the function node and its name node.
  ctx.node_to_type[f.name] = ctx.node_to_type[f] = FunctionType(
      tuple(param_types), body_return_type)
def _instantiate(builtin_name: ast.BuiltinNameDef, invocation: ast.Invocation,
                 ctx: deduce.DeduceCtx) -> Optional[ast.NameDef]:
  """Instantiates a builtin parametric invocation; e.g. 'update'.

  Deduces the builtin's function type and symbolic bindings from the resolved
  argument types. It then records the bindings, the callee's function type,
  and the invocation's return type on ctx.type_info.

  For 'map', if the mapped function (invocation.args[1]) is a user-defined
  parametric function, it is additionally instantiated with the computed
  symbolic bindings, at most once per (invocation, bindings) pair:
    * An imported function (ast.ModRef) is typechecked immediately, in a fresh
      DeduceCtx for its own module.
    * A function in the current module (ast.NameRef) is pushed onto
      ctx.fn_stack, and ctx.type_info is switched to a new child TypeInfo.
      Its NameDef is returned; presumably the caller typechecks it and
      restores the stack/type_info. NOTE(review): confirm against the caller.

  Args:
    builtin_name: The builtin being invoked.
    invocation: The invocation node; its args must already have types in
      ctx.type_info.
    ctx: The deduction context; ctx.type_info (and, in the local-map case,
      ctx.fn_stack) are updated in place.

  Returns:
    The NameDef of a local parametric function mapped by 'map' that still
    needs its body typechecked, otherwise None.
  """
  arg_types = tuple(
      deduce.resolve(ctx.type_info[arg], ctx) for arg in invocation.args)

  # Populated only for 'map'. map_fn stays None exactly when the mapped
  # function is itself a parametric builtin (which has no Function node).
  higher_order_parametric_bindings = None
  map_fn_ref = None
  map_fn_name = None
  map_fn = None
  imported_module = None
  imported_type_info = None
  if builtin_name.identifier == 'map':
    map_fn_ref = invocation.args[1]
    if isinstance(map_fn_ref, ast.ModRef):
      imported_module, imported_type_info = ctx.type_info.get_imported(
          map_fn_ref.mod)
      map_fn_name = map_fn_ref.value
      map_fn = imported_module.get_function(map_fn_name)
      higher_order_parametric_bindings = map_fn.parametric_bindings
    else:
      assert isinstance(map_fn_ref, ast.NameRef), map_fn_ref
      map_fn_name = map_fn_ref.identifier
      if map_fn_name not in dslx_builtins.PARAMETRIC_BUILTIN_NAMES:
        map_fn = ctx.module.get_function(map_fn_name)
        higher_order_parametric_bindings = map_fn.parametric_bindings

  fsignature = dslx_builtins.get_fsignature(builtin_name.identifier)
  fn_type, symbolic_bindings = fsignature(arg_types, builtin_name.identifier,
                                          invocation.span, ctx,
                                          higher_order_parametric_bindings)

  _, fn_symbolic_bindings = ctx.fn_stack[-1]
  ctx.type_info.add_invocation_symbolic_bindings(
      invocation, tuple(fn_symbolic_bindings.items()), symbolic_bindings)
  ctx.type_info[invocation.callee] = fn_type
  ctx.type_info[invocation] = fn_type.return_type  # pytype: disable=attribute-error

  if builtin_name.identifier != 'map':
    return None

  assert isinstance(map_fn_name, str), map_fn_name
  if map_fn is None or not map_fn.is_parametric():
    # A builtin higher-order parametric fn would've been typechecked when we
    # were going through the arguments of this invocation.
    # If the function wasn't parametric, then we're good to go.
    return None

  # The higher order function is parametric, so we need to typecheck its body
  # with the symbolic bindings we just computed -- unless that was already
  # done for this invocation with these bindings.
  if ctx.type_info.has_instantiation(invocation, symbolic_bindings):
    return None

  if isinstance(map_fn_ref, ast.ModRef):
    # Typecheck the imported function with respect to its own module, in a
    # fresh context whose TypeInfo is a child of the imported module's.
    invocation_imported_type_info = type_info.TypeInfo(
        imported_module, parent=imported_type_info)
    imported_ctx = deduce.DeduceCtx(invocation_imported_type_info,
                                    imported_module, ctx.interpret_expr,
                                    ctx.check_function_in_module)
    imported_ctx.fn_stack.append((map_fn_name, dict(symbolic_bindings)))
    ctx.check_function_in_module(map_fn, imported_ctx)
    ctx.type_info.add_instantiation(invocation, symbolic_bindings,
                                    invocation_imported_type_info)
    return None

  # The higher-order parametric fn is in this module: push it onto the
  # typechecking stack and hand its NameDef back to the caller.
  ctx.fn_stack.append((map_fn_name, dict(symbolic_bindings)))
  invocation_type_info = type_info.TypeInfo(ctx.module, parent=ctx.type_info)
  ctx.type_info.add_instantiation(invocation, symbolic_bindings,
                                  invocation_type_info)
  ctx.type_info = invocation_type_info
  return map_fn_ref.name_def