예제 #1
0
 def infer_type(self, env: Env, sigma: InferenceType):
     """Infer the statements in order, threading one substitution through them.

     Each statement is inferred against ``sigma`` with the substitution
     accumulated so far applied to it. After every statement the accumulated
     substitution is applied to ``env`` so that later statements see the
     refined types.

     Returns:
         The substitution accumulated over all statements.
     """
     accumulated = Subst.empty()
     for statement in self.statements:
         expected = sigma.substitute(accumulated)
         step = statement.infer_type(env, expected)
         accumulated = accumulated.compose(step)
         env.substitute(accumulated)
     return accumulated
예제 #2
0
    def infer_type(self, env: Env, sigma: InferenceType):
        """Infer and generalise the type of this function declaration.

        Steps:
          1. Register the function in ``env`` with an empty quantifier list
             before typing the body, so calls inside the body can look it up
             through ``env.functions``.
          2. If an explicit signature (``self.fun_type``) is present, check its
             arity and infer each argument/return annotation against the
             function's argument/return type variables.
          3. Infer the body against the (substituted) return type.
          4. Quantify over the type variables found in the argument types that
             are not reported free by ``env.free_type_vars``.
          5. Check every postponed usage of this function (calls seen before
             the declaration, presumably recorded by ``env.add_fun_usage``)
             against a fresh instantiation of the generalised type.

        Args:
            env: the typing environment; mutated through ``add_fun``,
                ``update_fun_quants``, ``substitute`` and by popping this
                function's entry from ``postponed_functions``.
            sigma: not used by this method.

        Returns:
            The substitution accumulated while typing the declaration and its
            postponed usages.

        Raises:
            FunArgsTypesMismatch: the signature's arity differs from the
                function's argument count.
            FunCallArgsMismatch: a postponed call passes the wrong number of
                arguments; reported at that call's code range.
            AssertionError: binding analysis has not run, or ``main`` declares
                arguments.
        """
        name = self.name.value
        Logger.debug(f'* Start typing function {name}')
        assert isinstance(self.arg_ids, list), 'Binding analysis must be done before type inference'
        if name == 'main':
            # NOTE(review): user-facing validation via `assert` disappears under
            # `python -O`; consider raising a dedicated error instead.
            assert len(self.arg_ids) == 0, \
                f"Function 'main' cannot take arguments, but is defined with {len(self.arg_ids)}"
        env.add_fun(name, self.arg_ids)
        f = env.functions.get(name)
        # No quantifiers while the body is typed: instantiate() then renames
        # nothing, so recursive calls share the function's own type variables.
        env.update_fun_quants(name, [])

        star = Subst.empty()
        if self.fun_type is not None:
            args_len = len(f.usage.arg_types)
            types_len = len(self.fun_type.args.args)
            if args_len != types_len:
                raise FunArgsTypesMismatch(self.code_range, name, args_len, types_len)

            for arg_tv, arg_type_def in zip(f.usage.arg_types, self.fun_type.args.args):
                star = arg_type_def.infer_type(env, arg_tv).compose(star)
                env.substitute(star)
            star = self.fun_type.return_type.infer_type(env, f.usage.return_type).compose(star)
            env.substitute(star)

        star = self.block.infer_type(env, f.usage.return_type.substitute(star)).compose(star)
        # Re-fetch: the env entry may have been replaced while typing the body.
        f = env.functions.get(name)

        # Collect candidate quantifiers from the (substituted) argument types.
        # NOTE(review): only argument types are scanned; type variables that
        # occur solely in the return type are never quantified.
        type_vars = []
        for arg_tv in f.usage.arg_types:
            type_vars = arg_tv.substitute(star).collect_type_vars(type_vars)
        Logger.debug(f'TVs in resulting function type before removing free TVs= {f.usage}: {type_vars}')

        # Type variables still free elsewhere in the environment (the predicate
        # excludes this function's own entry) must not be generalised.
        free_env_type_vars = env.free_type_vars(
            lambda fun_name: fun_name != name
        )
        Logger.debug(f'Free TVs in env: {free_env_type_vars}')
        type_vars = [x for x in type_vars if x not in free_env_type_vars]
        Logger.debug(f'TVs after removing free TVs: {type_vars}')
        env.update_fun_quants(name, type_vars)

        f = env.functions.get(name)
        postponed = env.postponed_functions.pop(name, None)
        if postponed is not None:
            for (instance_type, inst_code_range) in postponed:
                # Each call site gets its own fresh instance of the generic type.
                ft = f.instantiate(env)
                args_len = len(instance_type.arg_types)
                tv_len = len(ft.usage.arg_types)
                if tv_len != args_len:
                    # Report at the offending call site, not at the declaration.
                    raise FunCallArgsMismatch(inst_code_range, name, args_len, tv_len)
                for actual, instance in zip(ft.usage.arg_types, instance_type.arg_types):
                    Logger.debug(f'Postponed function signature check: {actual} <-> {instance}')
                    star = actual.unify_or_type_error(instance.substitute(star), inst_code_range).compose(star)
                    env.substitute(star)
                # Compose with the accumulated substitution rather than
                # overwriting it, so earlier bindings are kept.
                star = instance_type.return_type.substitute(star) \
                    .unify_or_type_error(ft.usage.return_type.substitute(star), inst_code_range) \
                    .compose(star)
                env.substitute(star)
        Logger.debug(f'- Finished typing function {name}\n')
        return star
예제 #3
0
 def infer_type(self, env: Env, sigma: InferenceType):
     """Type a variable declaration.

     Fetches the type bound to this declaration's id from ``env``. If an
     explicit type annotation is present, it is first inferred against that
     type. The initialiser expression is then inferred against the
     (substituted) declared type. ``sigma`` is not used.

     Returns:
         The expression's substitution composed with the annotation's.
     """
     declared = env.get_var(self.id_number)
     if self.var_type:
         annotation_subst = self.var_type.infer_type(env, declared)
     else:
         annotation_subst = Subst.empty()
     env.substitute(annotation_subst)
     expr_subst = self.expression.infer_type(env, declared.substitute(annotation_subst))
     return expr_subst.compose(annotation_subst)
예제 #4
0
    def infer_type(self, env: Env, sigma: InferenceType):
        """Type a function call whose result is expected to have type ``sigma``.

        If the callee is already in ``env.functions``, it is instantiated and
        its arity is checked. Each argument is inferred against the matching
        (substituted) parameter type, and the return type is unified with
        ``sigma``.

        If the callee is not declared yet, each argument is inferred against a
        fresh type variable and the usage is recorded with
        ``env.add_fun_usage`` so it can be checked later.

        In both branches, the substitution produced while typing the arguments
        is applied to ``sigma`` (and, when postponing, to the recorded argument
        types) before use. Otherwise bindings made during argument inference
        would be ignored.

        Returns:
            The substitution accumulated while typing the call.

        Raises:
            FunCallArgsMismatch: the call's argument count differs from the
                callee's arity.
        """
        name = self.function_name.value
        f = env.functions.get(name)
        if f is not None:
            f = f.instantiate(env)
            args_len = len(self.expressions)
            tv_len = len(f.usage.arg_types)
            if args_len != tv_len:
                raise FunCallArgsMismatch(self.code_range, name, args_len, tv_len)

            subst = Subst.empty()
            for exp, tv in zip(self.expressions, f.usage.arg_types):
                subst = exp.infer_type(env, tv.substitute(subst)).compose(subst)
                env.substitute(subst)
            # sigma's type variables may have been bound while typing the
            # arguments; unifying the stale sigma could miss such a conflict.
            return sigma.substitute(subst) \
                .unify_or_type_error(f.usage.return_type.substitute(subst), self.code_range) \
                .compose(subst)

        # Function was not yet declared: type the arguments against fresh
        # variables and postpone the signature check.
        type_vars = []
        subst = Subst.empty()
        for arg in self.expressions:
            tv = env.fresh_type_var()
            type_vars.append(tv)
            subst = arg.infer_type(env, tv).compose(subst)
            env.substitute(subst)
        env.add_fun_usage(
            name,
            [tv.substitute(subst) for tv in type_vars],
            sigma.substitute(subst),
            self.code_range,
        )
        return subst
예제 #5
0
    def infer_type(self, env: Env, sigma: InferenceType):
        """Type every top-level declaration of the file, in order.

        First stores the ids of all ``VarDecl`` declarations on
        ``env.global_var_ids``. Each declaration is then inferred with
        ``sigma`` passed through, and the accumulated substitution is applied
        to ``env`` after every one.

        Finally, the first global reported by ``env.get_globals_with_tv()``
        that corresponds to a ``VarDecl`` here raises ``UnknownVarTypeError``.

        Returns:
            The substitution accumulated over all declarations.
        """
        env.global_var_ids = [decl.id_number for decl in self.declarations
                              if isinstance(decl, VarDecl)]

        total = Subst.empty()
        for decl in self.declarations:
            total = total.compose(decl.infer_type(env, sigma))
            env.substitute(total)

        for global_id, tv in env.get_globals_with_tv():
            offending = next(
                (decl for decl in self.declarations
                 if isinstance(decl, VarDecl) and decl.id_number == global_id),
                None,
            )
            if offending is not None:
                raise UnknownVarTypeError(offending.code_range, tv, offending.name.value)

        return total
예제 #6
0
 def generate_function_impls(self):
     """Generate code for every needed function instance until none remain.

     Instances are popped from ``self.context.needed_fun_instances``. The
     queue is re-checked on every iteration, so instances added while
     generating code are processed too. Keys already in
     ``self.processed_instances`` are skipped.

     Each instance gets its own ``OpCodeBuilder`` over a deep copy of
     ``self.env``, so polymorphic functions can be specialised per instance.
     Declared functions (``self.function_asts``) have their argument types
     unified with the instance's, their arguments registered as locals, and
     their block generated. Otherwise the last builtin in ``self.builtins``
     with a matching name generates the code.

     A ``RetNoValue`` is appended when the code does not already end with a
     return. The result is appended to ``self.functions``.

     Raises:
         Exception: the name is neither a declared function nor a builtin.
     """
     while self.context.needed_fun_instances:
         fun_key, fun_inst = self.context.needed_fun_instances.popitem()
         if fun_key in self.processed_instances:
             continue  # this instance has been generated before

         builder = OpCodeBuilder(self.context, copy.deepcopy(self.env))
         fun_decl = self.function_asts.get(fun_inst.name, None)
         if fun_decl is not None:
             # Specialise the argument types to this instance.
             subst = Subst.empty()
             for arg_id, arg_type in zip(fun_decl.arg_ids, fun_inst.arg_types):
                 declared = builder.env.get_var(arg_id)
                 subst = declared.substitute(subst).unify(arg_type).compose(subst)
             builder.env.substitute(subst)
             # Arguments get local offsets -n, ..., -1 in declaration order.
             arg_count = len(fun_decl.arg_ids)
             for offset, arg_id in enumerate(fun_decl.arg_ids, start=-arg_count):
                 builder.add_local(arg_id, offset)
             fun_decl.block.generate_code(builder)
         else:
             candidates = [b for b in self.builtins if b.name == fun_inst.name]
             if not candidates:
                 raise Exception(
                     f'Unknown function \'{fun_inst.name}\' encountered while generating code'
                 )
             candidates[-1].generate_code(fun_inst.arg_types, builder)

         if not builder.ends_with_return():
             # Implicit return for bodies that do not end with a return statement.
             builder.add(codes.RetNoValue())
         self.functions.append((fun_inst, gen_utils.FunctionImpl(builder.ops)))
         self.processed_instances.add(fun_key)
예제 #7
0
 def infer_type(self, env: Env, sigma: InferenceType):
     """Contribute no type constraints.

     Neither ``env`` nor ``sigma`` is inspected.

     Returns:
         A fresh empty substitution.
     """
     no_constraints = Subst.empty()
     return no_constraints
예제 #8
0
    def analysis(self, ast: SPLFile):
        """Run the front-end analyses on ``ast`` and return the typed environment.

        Phases, in order:
          1. Return-value checking with ``ReturnValueChecker``.
          2. Binding analysis, with the builtins registered in a ``Context``;
             the AST with identifier ids is logged afterwards.
          3. Type inference, with the builtins registered in an ``Env``.

        Warnings from phases 1 and 2 are logged. If either reports errors, or
        type inference raises, the errors are logged and the process exits
        with status 1 via ``sys.exit``.

        Returns:
            The ``Env`` after type inference, with the final substitution applied.
        """
        Logger.info(
            '-------------------------------------------------------------')
        Logger.info(
            '------------------ Starting analysis phase ------------------')
        Logger.info(
            '-------------------------------------------------------------')

        Logger.info('* Starting return value checking')
        rvc = ReturnValueChecker()
        return_warnings, return_errors = rvc.check_spl_file(ast)
        Logger.info('- Return value checking DONE')
        self._report_feedback(return_warnings, return_errors)

        context = Context()
        for b in self.builtins:
            b.add_to_context(context)
        Logger.info(
            f'- Added {len(self.builtins)} builtin functions to binding context: {self.get_builtin_str()}'
        )

        binding_feedback = {'errors': [], 'warnings': []}
        Logger.info('* Starting binding analysis')
        ast.binding_analysis(context, binding_feedback)
        Logger.info('- Binding analysis DONE')

        Logger.info(
            '*** Pretty printing AST with identifier IDs after binding analysis: ***'
        )
        Logger.info('\n' + ast.indented_print())

        self._report_feedback(binding_feedback['warnings'], binding_feedback['errors'])

        env = Env()
        for b in self.builtins:
            b.add_to_env(env)
        Logger.info(
            f'- Added {len(self.builtins)} builtin functions to type environment: {self.get_builtin_str()}'
        )
        Logger.info('* Starting type inference')
        try:
            subst = ast.infer_type(env, InferenceVoid())
        except Exception as e:
            # Top-level boundary: any inference failure is reported as a
            # compile error and terminates the process.
            Logger.error(str(e))
            sys.exit(1)
        env.substitute(subst)

        Logger.debug('* Inferred function types after inference:')
        for name, f in env.functions.items():
            Logger.debug(
                f'- {name} :: args: [{", ".join(str(a) for a in f.usage.arg_types)}], ret: {str(f.usage.return_type)}'
            )
        Logger.debug('* Inferred variable types after inference:')
        for num, v in env.variables.items():
            Logger.debug(f'- {num} :: {str(v)}')

        Logger.info('- Typing DONE')
        return env

    @staticmethod
    def _report_feedback(warnings, errors):
        """Log every warning; if there are errors, log them and exit with status 1."""
        for w in warnings:
            Logger.warning(w)
        if errors:
            for e in errors:
                Logger.error(e)
            sys.exit(1)
예제 #9
0
 def instantiate(self, env):
     """Return ``self`` with every quantified type variable replaced by a fresh one.

     Each entry of ``self.quantified_type_vars`` is mapped to a new variable
     from ``env.fresh_type_var()``, and the mapping is applied with
     ``self.substitute``. With no quantified variables the substitution is
     empty.
     """
     renaming = Subst({tv: env.fresh_type_var() for tv in self.quantified_type_vars})
     return self.substitute(renaming)