Example #1
0
def _deduce_annotated_return_type(f: Function, ctx: deduce.DeduceCtx):
    """Returns the deduced type of f's return annotation, or NIL if it has none."""
    return (deduce.deduce(f.return_type, ctx)
            if f.return_type else ConcreteType.NIL)


def _check_function(f: Function, ctx: deduce.DeduceCtx) -> None:
    """Validates type annotations on parameters/return type of f are consistent.

    A parametric function that is not the function currently on top of
    ``ctx.fn_stack`` only has its signature recorded in ``ctx.node_to_type``
    (so the invocation can be instantiated); its body is type checked when it
    is visited again as the function on top of the stack.

    Args:
      f: The function to type check.
      ctx: Wraps a node_to_type, a mapping of AST node to its deduced type;
        (free-variable) references are resolved via this dictionary.

    Raises:
      XlsTypeError: When the return type deduced is inconsistent with the return
        type annotation on "f".
    """
    fn_name, _ = ctx.fn_stack[-1]
    # First, get the types of the function's parameters (and, for a parametric
    # function that is not currently being evaluated, record its signature).
    if f.is_parametric() and f.name.identifier == fn_name:
        # Parametric functions are evaluated per invocation. If we're currently
        # inside of this function, it must mean that we already have the type
        # signature and now we just need to evaluate the body.
        assert f in ctx.node_to_type, f
        # Only the parameter types are taken from the recorded signature; the
        # return type annotation is deduced afresh below and then resolved.
        param_types = list(ctx.node_to_type[f].params)  # pytype: disable=attribute-error
    else:
        logging.vlog(1, 'Type-checking sig for function: %s', f)
        param_types = _check_function_params(f, ctx)
        if f.is_parametric():
            # We just needed the type signature so that we can instantiate this
            # invocation. Let's return this for now and typecheck the body once we
            # have symbolic bindings.
            ctx.node_to_type[f.name] = ctx.node_to_type[f] = FunctionType(
                tuple(param_types), _deduce_annotated_return_type(f, ctx))
            return

    logging.vlog(1, 'Type-checking body for function: %s', f)

    # Second, typecheck the return type of the function.
    # NOTE: if there is no annotated return type, we assume NIL.
    resolved_return_type = deduce.resolve(
        _deduce_annotated_return_type(f, ctx), ctx)

    # Third, typecheck the body of the function.
    body_return_type = deduce.deduce(f.body, ctx)
    resolved_body_type = deduce.resolve(body_return_type, ctx)

    # Finally, assert type consistency between body and annotated return type.
    if resolved_return_type != resolved_body_type:
        raise XlsTypeError(
            f.body.span,
            resolved_body_type,
            resolved_return_type,
            suffix='Return type of function body for "{}" did not match the '
            'annotated return type.'.format(f.name.identifier))

    # NOTE(review): the recorded signature uses the body type as returned by
    # deduce.deduce (not the deduce.resolve result) -- confirm this is intended.
    ctx.node_to_type[f.name] = ctx.node_to_type[f] = FunctionType(
        tuple(param_types), body_return_type)
Example #2
0
def _instantiate(builtin_name: ast.BuiltinNameDef, invocation: ast.Invocation,
                 ctx: deduce.DeduceCtx) -> Optional[ast.NameDef]:
    """Instantiates a builtin parametric invocation; e.g. 'update'.

    Computes the builtin's signature from the resolved argument types and
    records the invocation's symbolic bindings and types in ``ctx.type_info``.
    For ``map`` whose mapped function is a (non-builtin) parametric function,
    it additionally arranges for that function's body to be type checked under
    the bindings just computed:
      * an imported function is checked immediately in a context built for its
        own module;
      * a function from this module is pushed onto ``ctx.fn_stack`` with a new
        child TypeInfo installed as ``ctx.type_info``, and its NameDef is
        returned so the caller can check its body.

    Args:
      builtin_name: Name definition of the builtin being invoked.
      invocation: The invocation node of the builtin.
      ctx: Deduction context; updated in place as described above.

    Returns:
      The NameDef of a same-module parametric function mapped by ``map`` that
      still needs its body type checked for these bindings; otherwise None.
    """
    builtin_identifier = builtin_name.identifier
    is_map = builtin_identifier == 'map'

    resolved_arg_types = tuple([
        deduce.resolve(ctx.type_info[arg], ctx) for arg in invocation.args
    ])

    mapped_fn = None
    mapped_fn_name = None
    callee_parametric_bindings = None
    if is_map:
        mapped_fn_ref = invocation.args[1]
        is_imported_fn = isinstance(mapped_fn_ref, ast.ModRef)
        if is_imported_fn:
            imported_module, imported_type_info = ctx.type_info.get_imported(
                mapped_fn_ref.mod)
            mapped_fn_name = mapped_fn_ref.value
            mapped_fn = imported_module.get_function(mapped_fn_name)
            callee_parametric_bindings = mapped_fn.parametric_bindings
        else:
            assert isinstance(mapped_fn_ref, ast.NameRef), mapped_fn_ref
            mapped_fn_name = mapped_fn_ref.identifier
            # Builtins are not looked up as module functions.
            if mapped_fn_name not in dslx_builtins.PARAMETRIC_BUILTIN_NAMES:
                mapped_fn = ctx.module.get_function(mapped_fn_name)
                callee_parametric_bindings = mapped_fn.parametric_bindings

    signature_fn = dslx_builtins.get_fsignature(builtin_identifier)
    fn_type, symbolic_bindings = signature_fn(resolved_arg_types,
                                              builtin_identifier,
                                              invocation.span, ctx,
                                              callee_parametric_bindings)

    _caller_name, caller_bindings = ctx.fn_stack[-1]
    ctx.type_info.add_invocation_symbolic_bindings(
        invocation, tuple(caller_bindings.items()), symbolic_bindings)
    ctx.type_info[invocation.callee] = fn_type
    ctx.type_info[invocation] = fn_type.return_type  # pytype: disable=attribute-error

    if not is_map:
        return None

    assert isinstance(mapped_fn_name, str), mapped_fn_name
    # Builtin higher-order parametrics were already typechecked while visiting
    # this invocation's arguments, and non-parametric functions need no
    # per-instantiation check.
    if (mapped_fn_name in dslx_builtins.PARAMETRIC_BUILTIN_NAMES
            or not mapped_fn.is_parametric()):
        return None

    # Either way, nothing to do if this instantiation was already checked.
    if ctx.type_info.has_instantiation(invocation, symbolic_bindings):
        return None

    if not is_imported_fn:
        # Same-module callee: push it onto the typechecking stack and hand it
        # back to the caller with a fresh child TypeInfo in place.
        ctx.fn_stack.append((mapped_fn_name, dict(symbolic_bindings)))
        child_type_info = type_info.TypeInfo(ctx.module, parent=ctx.type_info)
        ctx.type_info.add_instantiation(invocation, symbolic_bindings,
                                        child_type_info)
        ctx.type_info = child_type_info
        return mapped_fn_ref.name_def

    # Imported callee: typecheck it now, relative to its own module.
    instantiation_type_info = type_info.TypeInfo(imported_module,
                                                 parent=imported_type_info)
    imported_ctx = deduce.DeduceCtx(instantiation_type_info, imported_module,
                                    ctx.interpret_expr,
                                    ctx.check_function_in_module)
    imported_ctx.fn_stack.append((mapped_fn_name, dict(symbolic_bindings)))
    ctx.check_function_in_module(mapped_fn, imported_ctx)
    ctx.type_info.add_instantiation(invocation, symbolic_bindings,
                                    instantiation_type_info)
    return None