Example #1
0
def check_module(module: ast.Module,
                 f_import: Optional[ImportFn]) -> type_info.TypeInfo:
    """Type checks the top-level members of "module" and returns the results.

    Work proceeds in a fixed order: imports, constants and enums first, then
    quickchecks, structs, typedefs, non-parametric functions and finally
    tests. Parametric functions are skipped here; they are left to be checked
    per invocation.

    Args:
      module: The module whose top-level members are type checked.
      f_import: Callback used to import a module; it is called with the name
        tuple of each import statement. May be None when the module is
        guaranteed to contain no import statements (e.g. in unit tests).

    Returns:
      The type information held by the deduction context once every pass has
      run (mapping from AST node to its deduced/checked type).

    Raises:
      PositionalError: If a quickcheck function is parametric, or if a test
        function takes parameters or is parametric.
      XlsTypeError: If the body of a quickcheck function does not have type
        ConcreteType.U1. Errors raised by the deduction/checking helpers
        called from here propagate unchanged.
    """
    ctx = deduce.DeduceCtx(
        type_info.TypeInfo(module), module,
        functools.partial(interpret_expr, f_import=f_import),
        check_top_node_in_module)

    # Pass 1: record resolved imports and deduce constants/enums. This runs
    # in a 'top' frame, as the global scope has no symbolic bindings.
    ctx.fn_stack.append(('top', {}))
    for node in ctx.module.top:
        if isinstance(node, ast.Import):
            assert isinstance(node.name, tuple), node.name
            dep_module, dep_type_info = f_import(node.name)
            ctx.type_info.add_import(node, (dep_module, dep_type_info))
            continue
        if isinstance(node, (ast.Constant, ast.Enum)):
            deduce.deduce(node, ctx)
            continue
        # Every other kind of member is handled by one of the later passes.
        assert isinstance(node, (ast.Function, ast.Test, ast.Struct,
                                 ast.QuickCheck, ast.TypeDef)), node
    ctx.fn_stack.pop()

    # Pass 2: quickchecks, keyed by the name of the function they wrap.
    quickchecks_by_name = {
        quickcheck.f.name.identifier: quickcheck
        for quickcheck in ctx.module.get_quickchecks()
    }
    for quickcheck in quickchecks_by_name.values():
        assert isinstance(quickcheck, ast.QuickCheck), quickcheck

        qc_fn = quickcheck.f
        assert isinstance(qc_fn, ast.Function), qc_fn
        if qc_fn.is_parametric():
            # TODO(cdleary): 2020-08-09 See https://github.com/google/xls/issues/81
            raise PositionalError(
                'Quickchecking parametric functions is unsupported.',
                qc_fn.span)

        logging.vlog(2, 'Typechecking function: %s', qc_fn)
        # The frame for the quickcheck function has no symbolic bindings.
        ctx.fn_stack.append((qc_fn.name.identifier, {}))
        check_top_node_in_module(qc_fn, ctx)

        # The body of a quickcheck function must come out as a single bit.
        body_type = ctx.type_info[qc_fn.body]
        if body_type != ConcreteType.U1:
            raise XlsTypeError(
                qc_fn.span,
                body_type,
                ConcreteType.U1,
                suffix='QuickCheck functions must return a bool.')

        logging.vlog(2, 'Finished typechecking function: %s', qc_fn)

    # Pass 3: structs. They go through check_top_node_in_module() so that
    # function calls in parametric bindings, if any, get typechecked too.
    structs_by_name = {
        struct.name.identifier: struct
        for struct in ctx.module.get_structs()
    }
    for struct in structs_by_name.values():
        assert isinstance(struct, ast.Struct), struct
        logging.vlog(2, 'Typechecking struct %s', struct)
        ctx.fn_stack.append(('top', {}))  # No symbolic bindings.
        check_top_node_in_module(struct, ctx)
        logging.vlog(2, 'Finished typechecking struct: %s', struct)

    # Pass 4: typedefs found among the top-level members.
    typedefs_by_name = {
        node.name.identifier: node
        for node in ctx.module.top
        if isinstance(node, ast.TypeDef)
    }
    for typedef in typedefs_by_name.values():
        assert isinstance(typedef, ast.TypeDef), typedef
        logging.vlog(2, 'Typechecking typedef %s', typedef)
        ctx.fn_stack.append(('top', {}))  # No symbolic bindings.
        check_top_node_in_module(typedef, ctx)
        logging.vlog(2, 'Finished typechecking typedef: %s', typedef)

    # Pass 5: non-parametric functions.
    functions_by_name = {
        fn.name.identifier: fn
        for fn in ctx.module.get_functions()
    }
    for fn in functions_by_name.values():
        assert isinstance(fn, ast.Function), fn
        if fn.is_parametric():
            # Parametric functions are typechecked per invocation instead.
            continue

        logging.vlog(2, 'Typechecking function: %s', fn)
        ctx.fn_stack.append((fn.name.identifier, {}))  # No symbolic bindings.
        check_top_node_in_module(fn, ctx)
        logging.vlog(2, 'Finished typechecking function: %s', fn)

    # Pass 6: tests.
    tests_by_name = {
        test.name.identifier: test
        for test in ctx.module.get_tests()
    }
    for test in tests_by_name.values():
        assert isinstance(test, ast.Test), test

        # New-style tests are specified using a function, which must take no
        # arguments and must not be parametric.
        is_test_function = isinstance(test, ast.TestFunction)
        if is_test_function:
            if test.fn.params:
                raise PositionalError(
                    "Test functions shouldn't take arguments.", test.fn.span)

            if test.fn.is_parametric():
                raise PositionalError(
                    "Test functions shouldn't be parametric.", test.fn.span)

        # No symbolic bindings inside of a test.
        ctx.fn_stack.append(('{}_test'.format(test.name.identifier), {}))
        logging.vlog(2, 'Typechecking test: %s', test)
        # For a new-style test the wrapped function is what gets checked; an
        # old-style test is a construct with a body and is checked directly
        # (see check_test()).
        check_top_node_in_module(test.fn if is_test_function else test, ctx)
        logging.vlog(2, 'Finished typechecking test: %s', test)

    return ctx.type_info
Example #2
0
def _instantiate(builtin_name: ast.BuiltinNameDef, invocation: ast.Invocation,
                 ctx: deduce.DeduceCtx) -> Optional[ast.NameDef]:
    """Type checks one invocation of a parametric builtin (e.g. 'update').

    The builtin's signature is evaluated against the resolved argument types,
    and the resulting function type, return type and symbolic bindings are
    recorded in ctx.type_info. The 'map' builtin gets extra handling because
    the function it is given may itself be parametric.

    Args:
      builtin_name: Name definition of the builtin being invoked.
      invocation: The invocation node being checked.
      ctx: Deduction context; its type_info is updated, and for a parametric
        function passed to 'map' from the current module its fn_stack gets a
        new frame and its type_info is replaced by a fresh child TypeInfo.

    Returns:
      The name definition of the function passed to 'map' when that function
      is parametric, is defined in the current module, and has no recorded
      instantiation for the computed symbolic bindings. None in every other
      case.
    """
    resolved_arg_types = tuple(
        deduce.resolve(ctx.type_info[arg], ctx) for arg in invocation.args)

    builtin_id = builtin_name.identifier
    is_map = builtin_id == 'map'

    # For 'map', look up the mapped function so that its parametric bindings
    # can be handed to the builtin's signature function.
    mapped_fn_bindings = None
    mapped_fn_name = None
    if is_map:
        mapped_fn_ref = invocation.args[1]
        if isinstance(mapped_fn_ref, ast.ModRef):
            # The mapped function is referenced through an imported module.
            imported_module, imported_type_info = ctx.type_info.get_imported(
                mapped_fn_ref.mod)
            mapped_fn_name = mapped_fn_ref.value
            mapped_fn = imported_module.get_function(mapped_fn_name)
            mapped_fn_bindings = mapped_fn.parametric_bindings
        else:
            assert isinstance(mapped_fn_ref, ast.NameRef), mapped_fn_ref
            mapped_fn_name = mapped_fn_ref.identifier
            if mapped_fn_name not in dslx_builtins.PARAMETRIC_BUILTIN_NAMES:
                mapped_fn = ctx.module.get_function(mapped_fn_name)
                mapped_fn_bindings = mapped_fn.parametric_bindings

    signature_fn = dslx_builtins.get_fsignature(builtin_id)
    fn_type, symbolic_bindings = signature_fn(resolved_arg_types, builtin_id,
                                              invocation.span, ctx,
                                              mapped_fn_bindings)

    # Record the results against the bindings of the enclosing function.
    _, caller_bindings = ctx.fn_stack[-1]
    ctx.type_info.add_invocation_symbolic_bindings(
        invocation, tuple(caller_bindings.items()), symbolic_bindings)
    ctx.type_info[invocation.callee] = fn_type
    ctx.type_info[invocation] = fn_type.return_type  # pytype: disable=attribute-error

    if not is_map:
        return None

    assert isinstance(mapped_fn_name, str), mapped_fn_name
    if mapped_fn_name in dslx_builtins.PARAMETRIC_BUILTIN_NAMES:
        # A builtin higher-order parametric fn would've been typechecked when
        # we were going through the arguments of this invocation.
        return None
    if not mapped_fn.is_parametric():
        # A non-parametric mapped function needs nothing further.
        return None

    # The mapped function is parametric, so its body has to be typechecked
    # with the symbolic bindings just computed -- unless that was already done
    # for this invocation with these bindings.
    if ctx.type_info.has_instantiation(invocation, symbolic_bindings):
        return None

    if isinstance(mapped_fn_ref, ast.ModRef):
        # An imported function is typechecked with respect to its own module,
        # using a separate context and a child of the imported type info.
        instantiation_type_info = type_info.TypeInfo(
            imported_module, parent=imported_type_info)
        imported_ctx = deduce.DeduceCtx(instantiation_type_info,
                                        imported_module,
                                        ctx.interpret_expr,
                                        ctx.check_function_in_module)
        imported_ctx.fn_stack.append(
            (mapped_fn_name, dict(symbolic_bindings)))
        ctx.check_function_in_module(mapped_fn, imported_ctx)
        ctx.type_info.add_instantiation(invocation, symbolic_bindings,
                                        instantiation_type_info)
        return None

    # The mapped function is in this module: push it onto the typechecking
    # stack, switch to a child type info for the instantiation, and return
    # the function's name definition.
    ctx.fn_stack.append((mapped_fn_name, dict(symbolic_bindings)))
    instantiation_type_info = type_info.TypeInfo(ctx.module,
                                                 parent=ctx.type_info)
    ctx.type_info.add_instantiation(invocation, symbolic_bindings,
                                    instantiation_type_info)
    ctx.type_info = instantiation_type_info
    return mapped_fn_ref.name_def