def infer_type(self, env: Env, sigma: InferenceType):
    """Infer types for each statement of this block, in order.

    Every statement is inferred against ``sigma`` with the substitution
    accumulated so far applied, and ``env`` is updated after each statement
    so that later statements see the refined types.

    Returns the accumulated substitution.
    """
    # NOTE(review): this composes as ``accumulated.compose(step)``, whereas
    # most other infer_type methods in this file use
    # ``step.compose(accumulated)``; confirm which order Subst.compose expects.
    accumulated = Subst.empty()
    for statement in self.statements:
        expected = sigma.substitute(accumulated)
        step = statement.infer_type(env, expected)
        accumulated = accumulated.compose(step)
        env.substitute(accumulated)
    return accumulated
 def infer_type(self, env: Env, sigma: InferenceType):
     """Infer the type of a tuple expression built from ``fst`` and ``snd``.

     Fresh type variables are created for both components (first, then
     second), each component is inferred against its variable, and finally
     ``sigma`` is unified with the tuple of the two (substituted) variables.

     Returns the composed substitution. A mismatch is reported through
     ``unify_or_type_error`` (presumably by raising a type error).
     """
     left_tv = env.fresh_type_var()
     right_tv = env.fresh_type_var()

     left_subst = self.fst.infer_type(env, left_tv)
     env.substitute(left_subst)
     both_subst = self.snd.infer_type(env, right_tv).compose(left_subst)

     expected = sigma.substitute(both_subst)
     actual = InferenceTuple(left_tv, right_tv).substitute(both_subst)
     unifier = expected.unify_or_type_error(actual, self.code_range)
     return unifier.compose(both_subst)
 def infer_type(self, env: Env, sigma: InferenceType):
     """Infer the type of a variable declaration.

     The declared variable's type is looked up in ``env`` by ``id_number``.
     If an explicit type annotation (``var_type``) is present, it is inferred
     against that type first; the initialiser expression is then inferred
     against the (possibly refined) variable type. ``sigma`` is not used.

     Returns the composed substitution.
     """
     var_type = env.get_var(self.id_number)
     annotation_subst = (self.var_type.infer_type(env, var_type)
                         if self.var_type else Subst.empty())
     env.substitute(annotation_subst)
     refined = var_type.substitute(annotation_subst)
     init_subst = self.expression.infer_type(env, refined)
     return init_subst.compose(annotation_subst)
 def infer_type(self, env: Env, sigma: InferenceType):
     """Infer types for an ``if`` statement with an optional ``else`` block.

     The condition is inferred against bool. The ``then`` block and, when
     present, the ``else`` block are each inferred against ``sigma`` with the
     substitution accumulated so far applied.

     Returns the accumulated substitution.
     """
     cond_subst = self.expression.infer_type(env, InferenceBool())
     env.substitute(cond_subst)
     then_subst = self.then_block.infer_type(env, sigma.substitute(cond_subst))
     accumulated = then_subst.compose(cond_subst)

     if self.else_block is None:
         return accumulated

     env.substitute(accumulated)
     else_subst = self.else_block.infer_type(env, sigma.substitute(accumulated))
     return else_subst.compose(accumulated)
    def infer_type(self, env: Env, sigma: InferenceType):
        """Infer and generalise the type of this function declaration.

        Steps:
        1. Register the function in ``env`` with no quantifiers.
        2. If an explicit type signature is given, check that its arity
           matches the argument list, then infer each declared argument type
           and the declared return type against the function's types.
        3. Infer the body against the (substituted) return type.
        4. Quantify over the type variables of the argument types that are
           not free elsewhere in the environment.
        5. Check call sites that used this function before its declaration
           (``env.postponed_functions``) against a fresh instance.

        ``sigma`` is not used.

        Returns the accumulated substitution.

        Raises:
            AssertionError: if binding analysis has not run, or if ``main`` is
                declared with arguments.
            FunArgsTypesMismatch: if the declared signature's arity differs
                from the number of arguments.
            FunCallArgsMismatch: if a postponed call site passes the wrong
                number of arguments (reported at that call site).
            Whatever ``unify_or_type_error`` raises on a type mismatch.
        """
        name = self.name.value
        Logger.debug(f'* Start typing function {name}')
        assert isinstance(self.arg_ids, list), 'Binding analysis must be done before type inference'
        if name == 'main' and len(self.arg_ids) != 0:
            # This validates the user's program, so raise explicitly instead of
            # using ``assert`` (asserts are stripped under ``python -O``).
            raise AssertionError(
                f"Function 'main' cannot take arguments, but is defined with {len(self.arg_ids)}")
        env.add_fun(name, self.arg_ids)
        f = env.functions.get(name)
        # Quantifiers stay empty while the body is inferred; they are set below.
        env.update_fun_quants(name, [])

        star = Subst.empty()
        if self.fun_type is not None:
            args_len = len(f.usage.arg_types)
            types_len = len(self.fun_type.args.args)
            if args_len != types_len:
                raise FunArgsTypesMismatch(self.code_range, name, args_len, types_len)

            # Apply the accumulated substitution to each type before use, so
            # bindings from earlier declared argument types carry over.
            for arg_tv, arg_type_def in zip(f.usage.arg_types, self.fun_type.args.args):
                star = arg_type_def.infer_type(env, arg_tv.substitute(star)).compose(star)
                env.substitute(star)
            star = self.fun_type.return_type.infer_type(
                env, f.usage.return_type.substitute(star)).compose(star)
            env.substitute(star)

        star = self.block.infer_type(env, f.usage.return_type.substitute(star)).compose(star)
        # Re-fetch the entry; presumably env.substitute may have replaced it.
        f = env.functions.get(name)

        # Collect the type variables of the (substituted) argument types.
        # NOTE(review): type variables occurring only in the return type are
        # not collected, so the function is not generalised over them —
        # confirm this is intended.
        type_vars = []
        for arg_tv in f.usage.arg_types:
            type_vars = arg_tv.substitute(star).collect_type_vars(type_vars)
        Logger.debug(f'TVs in resulting function type before removing free TVs= {f.usage}: {type_vars}')

        # Remove free variables in env from the TVs we're going to quantify over
        free_env_type_vars = env.free_type_vars(
            lambda fun_name: fun_name != name
        )
        Logger.debug(f'Free TVs in env: {free_env_type_vars}')
        type_vars = [x for x in type_vars if x not in free_env_type_vars]
        Logger.debug(f'TVs after removing free TVs: {type_vars}')
        env.update_fun_quants(name, type_vars)

        # Check the calls that were made before this declaration was seen.
        f = env.functions.get(name)
        postponed = env.postponed_functions.pop(name, None)
        if postponed is not None:
            for instance_type, inst_code_range in postponed:
                ft = f.instantiate(env)
                args_len = len(instance_type.arg_types)
                tv_len = len(ft.usage.arg_types)
                if tv_len != args_len:
                    # Report the mismatch at the offending call site.
                    raise FunCallArgsMismatch(inst_code_range, name, args_len, tv_len)
                for actual, instance in zip(ft.usage.arg_types, instance_type.arg_types):
                    Logger.debug(f'Postponed function signature check: {actual} <-> {instance}')
                    # Substitute both sides so bindings made by earlier
                    # arguments of this call are taken into account.
                    star = actual.substitute(star) \
                        .unify_or_type_error(instance.substitute(star), inst_code_range) \
                        .compose(star)
                    env.substitute(star)
                # Compose with ``star`` so the accumulated substitution is kept
                # (previously it was replaced by this single unifier).
                star = instance_type.return_type.substitute(star) \
                    .unify_or_type_error(ft.usage.return_type.substitute(star), inst_code_range) \
                    .compose(star)
                env.substitute(star)
        Logger.debug(f'- Finished typing function {name}\n')
        return star
# Example #6
    def infer_type(self, env: Env, sigma: InferenceType):
        """Infer types for every top-level declaration of the program.

        Records the ids of all ``VarDecl`` declarations in
        ``env.global_var_ids``. Each declaration is then inferred in source
        order against ``sigma``, with ``env`` updated after each one. Finally,
        the globals reported by ``env.get_globals_with_tv()`` (presumably
        those whose type is still a type variable) are rejected.

        Returns the accumulated substitution.

        Raises:
            UnknownVarTypeError: for the first reported global that matches a
                ``VarDecl`` of this file.
        """
        env.global_var_ids = [decl.id_number for decl in self.declarations
                              if isinstance(decl, VarDecl)]

        accumulated = Subst.empty()
        for decl in self.declarations:
            accumulated = accumulated.compose(decl.infer_type(env, sigma))
            env.substitute(accumulated)

        for global_id, tv in env.get_globals_with_tv():
            offending = next(
                (decl for decl in self.declarations
                 if isinstance(decl, VarDecl) and decl.id_number == global_id),
                None,
            )
            if offending is not None:
                raise UnknownVarTypeError(offending.code_range, tv, offending.name.value)

        return accumulated
# Example #7
 def infer_type(self, env: Env, sigma: InferenceType):
     """Infer the type of a field access on ``self.field``.

     - ``Fst``: the accessed expression is inferred against a tuple whose
       first component is ``sigma`` (second component fresh).
     - ``Snd``: likewise, with ``sigma`` as the second component.
     - ``Hd``: the accessed expression is inferred against a list of ``sigma``.
     - ``Tl``: ``sigma`` is unified with a list of a fresh element type, and
       the accessed expression is inferred against the substituted ``sigma``.

     Returns the resulting substitution.

     Raises:
         Exception: for an unrecognised ``field_type``.
     """
     kind = self.field_type
     if kind == FieldType.Fst:
         expected = InferenceTuple(sigma, env.fresh_type_var())
     elif kind == FieldType.Snd:
         expected = InferenceTuple(env.fresh_type_var(), sigma)
     elif kind == FieldType.Hd:
         expected = InferenceList(sigma)
     elif kind == FieldType.Tl:
         list_subst = sigma.unify_or_type_error(InferenceList(env.fresh_type_var()), self.code_range)
         env.substitute(list_subst)
         return self.field.infer_type(env, sigma.substitute(list_subst)).compose(list_subst)
     else:
         raise Exception('Unknown field accessor')
     return self.field.infer_type(env, expected)
    def infer_type(self, env: Env, sigma: InferenceType):
        """Infer the type of the binary operation ``expr1 <op_type> expr2``.

        Expected operand and result types per ``BinaryOpType``:
        - ``Sub``, ``Mul``, ``Div``, ``Mod``: int, int -> int
        - ``Eq``, ``Neq``, ``Geq``, ``Leq``, ``Lt``, ``Gt``: a, a -> bool
          (both operands share one fresh type variable)
        - ``Add``: a, a -> a (one fresh type variable)
        - ``And``, ``Or``: bool, bool -> bool
        - ``Cons``: a, [a] -> [a]

        The operands are inferred left to right, then ``sigma`` is unified
        with the (substituted) result type.

        Returns the composed substitution.

        Raises:
            Exception: if ``op_type`` is none of the operators above.
        """
        op = self.op_type
        if op in (BinaryOpType.Sub, BinaryOpType.Mul, BinaryOpType.Div, BinaryOpType.Mod):
            e1_type, e2_type, result_type = InferenceInt(), InferenceInt(), InferenceInt()
        elif op in (BinaryOpType.Eq, BinaryOpType.Neq, BinaryOpType.Geq, BinaryOpType.Leq,
                    BinaryOpType.Lt, BinaryOpType.Gt):
            tv = env.fresh_type_var()
            e1_type, e2_type, result_type = tv, tv, InferenceBool()
        elif op == BinaryOpType.Add:
            # Add is overloaded for int, char and list.
            # NOTE(review): a bare type variable accepts any operand type here;
            # confirm whether another phase restricts it.
            tv = env.fresh_type_var()
            e1_type, e2_type, result_type = tv, tv, tv
        elif op in (BinaryOpType.And, BinaryOpType.Or):
            e1_type, e2_type, result_type = InferenceBool(), InferenceBool(), InferenceBool()
        elif op == BinaryOpType.Cons:
            tv = env.fresh_type_var()
            e1_type, e2_type, result_type = tv, InferenceList(tv), InferenceList(tv)
        else:
            # Without this, inference would continue with ``None`` as the
            # expected operand types and fail later with an unrelated error.
            raise Exception('Unknown binary operator')

        star1 = self.expr1.infer_type(env, e1_type)
        env.substitute(star1)
        star2 = self.expr2.infer_type(env, e2_type.substitute(star1)).compose(star1)
        return sigma.substitute(star2) \
            .unify_or_type_error(result_type.substitute(star2), self.code_range) \
            .compose(star2)
    def infer_type(self, env: Env, sigma: InferenceType):
        """Infer the type of a function call ``function_name(expressions...)``.

        If the function is already in ``env``, a fresh instance of its type
        is taken and the argument count is checked. Each argument is then
        inferred against its (substituted) parameter type, and ``sigma``
        (with the argument substitution applied) is unified with the
        instantiated return type.

        If the function is not yet known (called before its declaration),
        each argument is inferred against a fresh type variable. The call is
        recorded with ``env.add_fun_usage``, using the substituted argument
        types and ``sigma``, so it can be checked once the declaration is
        inferred.

        Returns the accumulated substitution.

        Raises:
            FunCallArgsMismatch: if the number of arguments differs from the
                known function's arity.
        """
        fun_name = self.function_name.value
        f = env.functions.get(fun_name)
        if f is not None:
            f = f.instantiate(env)
            args_len = len(self.expressions)
            tv_len = len(f.usage.arg_types)
            if args_len != tv_len:
                raise FunCallArgsMismatch(self.code_range, fun_name, args_len, tv_len)

            subst = Subst.empty()
            for exp, tv in zip(self.expressions, f.usage.arg_types):
                subst = exp.infer_type(env, tv.substitute(subst)).compose(subst)
                env.substitute(subst)
            # Apply the argument substitution to sigma too: inferring the
            # arguments may have bound type variables that occur in sigma.
            return sigma.substitute(subst) \
                .unify_or_type_error(f.usage.return_type.substitute(subst), self.code_range) \
                .compose(subst)

        # Function was not yet declared: record the usage for a later check.
        type_vars = []
        subst = Subst.empty()
        for arg in self.expressions:
            tv = env.fresh_type_var()
            type_vars.append(tv)
            subst = arg.infer_type(env, tv).compose(subst)
            env.substitute(subst)
        # Record the argument types and sigma with the final substitution
        # applied, so bindings made while inferring the arguments are kept.
        env.add_fun_usage(fun_name,
                          [t.substitute(subst) for t in type_vars],
                          sigma.substitute(subst),
                          self.code_range)
        return subst
 def infer_type(self, env: Env, sigma: InferenceType):
     """Infer types for a condition-plus-body statement (presumably ``while``).

     The condition is inferred against bool; the body is then inferred
     against ``sigma`` with the condition's substitution applied.

     Returns the composed substitution.
     """
     cond_subst = self.expression.infer_type(env, InferenceBool())
     env.substitute(cond_subst)
     expected = sigma.substitute(cond_subst)
     body_subst = self.body.infer_type(env, expected)
     return body_subst.compose(cond_subst)
# Example #11
    def analysis(self, ast: SPLFile):
        """Run the semantic-analysis pipeline on ``ast`` and return the type environment.

        Phases, in order:
        1. Return value checking with ``ReturnValueChecker``.
        2. Binding analysis, after adding ``self.builtins`` to a fresh ``Context``.
        3. Type inference, after adding ``self.builtins`` to a fresh ``Env``.

        Warnings from phases 1 and 2 are logged. If either phase reports
        errors, they are logged and the process exits with status 1. Any
        exception raised during type inference is logged and also exits with
        status 1.

        Returns:
            The ``Env`` with the final substitution applied.
        """
        Logger.info(
            '-------------------------------------------------------------')
        Logger.info(
            '------------------ Starting analysis phase ------------------')
        Logger.info(
            '-------------------------------------------------------------')

        Logger.info('* Starting return value checking')
        rvc = ReturnValueChecker()
        return_warnings, return_errors = rvc.check_spl_file(ast)
        Logger.info('- Return value checking DONE')
        self._report_feedback(return_warnings, return_errors)

        context = Context()
        for b in self.builtins:
            b.add_to_context(context)
        Logger.info(
            f'- Added {len(self.builtins)} builtin functions to binding context: {self.get_builtin_str()}'
        )

        binding_feedback = {'errors': [], 'warnings': []}
        Logger.info('* Starting binding analysis')
        ast.binding_analysis(context, binding_feedback)
        Logger.info('- Binding analysis DONE')

        Logger.info(
            '*** Pretty printing AST with identifier IDs after binding analysis: ***'
        )
        Logger.info('\n' + ast.indented_print())

        self._report_feedback(binding_feedback['warnings'], binding_feedback['errors'])

        env = Env()
        for b in self.builtins:
            b.add_to_env(env)
        Logger.info(
            f'- Added {len(self.builtins)} builtin functions to type environment: {self.get_builtin_str()}'
        )
        Logger.info('* Starting type inference')
        try:
            subst = ast.infer_type(env, InferenceVoid())
        except Exception as e:
            # Top-level boundary: report the (type) error and stop compilation.
            Logger.error(str(e))
            sys.exit(1)
        env.substitute(subst)

        Logger.debug('* Inferred function types after inference:')
        for name, f in env.functions.items():
            Logger.debug(
                f'- {name} :: args: [{", ".join(str(a) for a in f.usage.arg_types)}], ret: {str(f.usage.return_type)}'
            )
        Logger.debug('* Inferred variable types after inference:')
        for num, v in env.variables.items():
            Logger.debug(f'- {num} :: {str(v)}')

        Logger.info('- Typing DONE')
        return env

    @staticmethod
    def _report_feedback(warning_msgs, error_msgs):
        """Log all warnings, then all errors; exit with status 1 if any error was reported."""
        for w in warning_msgs:
            Logger.warning(w)
        if error_msgs:
            for e in error_msgs:
                Logger.error(e)
            sys.exit(1)
# Example #12
 def infer_type(self, env: Env, sigma: InferenceType):
     """Infer a list type node (presumably a ``[T]`` annotation).

     A fresh type variable stands for the element type. The element node
     (``list_type``) is inferred against it, and ``sigma`` is unified with a
     list of the substituted element type.

     Returns the composed substitution.
     """
     elem_tv = env.fresh_type_var()
     elem_subst = self.list_type.infer_type(env, elem_tv)
     env.substitute(elem_subst)
     list_of_elem = InferenceList(elem_tv).substitute(elem_subst)
     unifier = sigma.unify_or_type_error(list_of_elem, self.code_range)
     return unifier.compose(elem_subst)
# Example #13
 def infer_type(self, env: Env, sigma: InferenceType):
     """Infer the type of a variable reference.

     The variable's type is looked up in ``env`` by ``id_number`` and
     unified with ``sigma``. The unifier is applied to ``env`` and returned.
     """
     unifier = sigma.unify_or_type_error(env.get_var(self.id_number), self.code_range)
     env.substitute(unifier)
     return unifier