class TypeChecker:
    """Type-check a COOL AST whose classes are arranged as an inheritance tree.

    Construction walks the inheritance tree once (``_dfs``) to:

    * stamp every class node with DFS entry/exit times (``td``/``tf``);
    * set each child's ``parent`` and ``level``;
    * reject overriding methods whose signature differs from the inherited one;
    * precompute the static type of every method formal.

    Afterwards conformance is an O(1) interval test (``_conforms``).

    The ``visit_*`` methods check individual AST nodes and record their static
    type through ``node.set_static_type``. When a rule is violated they raise
    ``SemanticError``, ``TypeError``, ``NameError`` or ``AttributeError``. These
    are called with ``line, col, message``, so they are presumably
    project-defined classes rather than the builtins -- confirm.
    """

    def __init__(self, root, cls_refs):
        """Build the checker and annotate the inheritance tree.

        Args:
            root: root class node of the inheritance tree; the DFS starts here.
            cls_refs: mapping from class name to its class node.

        Raises:
            SemanticError: an overriding method changes the inherited
                signature, or a formal is declared with ``SELF_TYPE``.
            TypeError: a formal's declared type is not in ``cls_refs``.
        """
        self.logger = init_logger('TypeChecker')

        self.root = root
        self.cls_refs = cls_refs
        self._t = 0  # DFS clock shared by every _dfs call
        self._dfs(root, Environment())

        # Traversal state maintained by the visit_* methods.
        self.cur_env = None
        self.cur_cls = None

    def _dfs(self, u, env):
        """Annotate the subtree rooted at class node ``u``.

        Sets ``u.td`` (entry time) and ``u.tf`` (largest entry time inside the
        subtree), plus ``parent`` and ``level`` of every child. ``env`` holds
        the method nodes inherited by ``u``, keyed by name. Each child is
        processed with ``Environment(env)``, presumably a nested scope over
        ``env``.
        """
        self._t += 1
        u.td = self._t

        for name, ref in u.methods.items():
            old = env.get(name)

            # An override must keep exactly the inherited signature.
            if old and old.get_signature() != ref.get_signature():
                raise SemanticError(
                    ref.id.line, ref.id.col,
                    f'{ref} of {u} is not compatible to {old} for inheritance')

            env.define(name, ref)

            for formal in ref.formal_list:
                if formal.type.value == 'SELF_TYPE':
                    raise SemanticError(
                        formal.type.line, formal.type.col,
                        f'Tried to declare {formal} with {formal.type}')

                # Precalculate formal static types before the visitor runs, so
                # visit_Dispatch can check arguments against any method.
                formal.set_static_type(
                    self._get_correct_type(formal, u.self_type))

        for v in u.children:
            v.parent = u
            v.level = u.level + 1

            self._dfs(v, Environment(env))

        u.tf = self._t

        self.logger.debug(
            f'{u}, td={u.td}, tf={u.tf}, level={u.level}, parent={u.parent}')

    def _conforms(self, u, v):
        """Return whether class node ``u`` is ``v`` or a descendant of ``v``.

        Relies on the DFS timestamps from ``_dfs``: ``u`` lies in ``v``'s
        subtree exactly when ``u.td`` falls inside ``[v.td, v.tf]``.
        """
        return v.td <= u.td <= v.tf

    def _lca(self, u, v):
        """Return the lowest common ancestor (least upper bound) of ``u`` and ``v``.

        The deeper node (by ``level``) is moved to its parent; on a tie ``v``
        moves. This repeats until both nodes carry the same type name.
        """
        self.logger.debug(f'LCA query between {u} and {v}')

        while u.type.value != v.type.value:
            if u.level > v.level:
                u = u.parent

            else:
                v = v.parent

        self.logger.debug(f'LCA is: {u}')

        return u

    def _dispatch(self, u, name):
        """Find method ``name`` on class node ``u`` or its nearest ancestor.

        Returns the method node, or ``None`` when no class on the parent chain
        defines it. This assumes the root's ``parent`` is falsy; that value is
        set outside this class -- confirm.
        """
        while u and name not in u.methods:
            u = u.parent

        return u.methods[name] if u else None

    def _get_correct_type(self, node, default_type):
        """Resolve the declared type ``node.type`` to a class node.

        ``SELF_TYPE`` resolves to ``default_type``. Any other name must be a
        key of ``cls_refs``.

        Raises:
            TypeError: the declared type name is unknown.
        """
        if node.type.value == 'SELF_TYPE':
            return default_type

        if node.type.value not in self.cls_refs:
            raise TypeError(node.type.line, node.type.col,
                            f'{Class(node.type)} doesnt exists')

        return self.cls_refs[node.type.value]

    def visit(self, node):
        """Dispatch to ``visit_<ClassName>`` for ``node`` and return its result."""
        self.logger.debug(f'On {node}')

        fn = getattr(self, 'visit_' + node.__class__.__name__)
        res = fn(node)

        if hasattr(node, 'static_type'):
            self.logger.debug(f'{node}, static_type: {node.static_type}')

        return res

    def visit_Class(self, node):
        """Check a class, then its subclasses, inside a new scope.

        The scope binds ``self`` and the class's own attributes. An attribute
        whose name is already visible (``self``, an ancestor's attribute, or
        an earlier attribute of this class) is rejected. Subclasses are
        visited while this scope is active, so they see these attributes.
        """
        old_env = self.cur_env
        self.cur_env = Environment(old_env)

        old_cls = self.cur_cls
        self.cur_cls = node

        self.cur_env.define('self',
                            Attribute(Id('self'), Type('SELF_TYPE'), None))

        self.logger.info(f'{node} Created new environment with self')

        # Bind every attribute before visiting any feature, so initializers
        # and method bodies can reference attributes declared later.
        for feature in node.feature_list:
            if isinstance(feature, Attribute):
                if self.cur_env.get(feature.id.value):
                    raise SemanticError(
                        feature.id.line, feature.id.col,
                        f'Tried to redefine {feature} by inheritance')

                self.cur_env.define(feature.id.value, feature)
                self.logger.info(f'{node} defined {feature}')

        for feature in node.feature_list:
            self.visit(feature)

        for cls in node.children:
            self.visit(cls)

        self.cur_env = old_env
        self.cur_cls = old_cls

        self.logger.info(f'{node} Restoring previous environment')

    def visit_SELF_TYPE(self, node):
        """No-op: SELF_TYPE(C) nodes appear among class children for consistency."""
        pass

    def visit_Formal(self, node):
        """Bind a formal parameter in the current (method) scope.

        Raises:
            SemanticError: the formal is named ``self`` or repeats another
                formal of the same method.
        """
        if node.id.value == 'self':
            raise SemanticError(node.id.line, node.id.col,
                                f'Tried to assign to {node.id}')

        # Check only the innermost environment: formals may shadow attributes.
        if node.id.value in self.cur_env.map:
            raise SemanticError(node.id.line, node.id.col,
                                f'Tried to redefine {node}')

        self.cur_env.define(node.id.value, node)

        self.visit(node.id)
        node.set_static_type(node.id.static_type)

    def visit_Method(self, node):
        """Check a method: bind its formals, then check the body if present.

        For non-native methods (``node.expr`` present) the body's static type
        must conform to the declared return type. A declared ``SELF_TYPE``
        resolves to the current class's ``self_type``.
        """
        old_env = self.cur_env
        self.cur_env = Environment(old_env)

        for formal in node.formal_list:
            self.visit(formal)

        if node.expr:  # native methods have no body
            self.visit(node.expr)

            _static_type = self._get_correct_type(node, self.cur_cls.self_type)

            self.logger.debug(f'{node} static type: {_static_type}')

            if not self._conforms(node.expr.static_type, _static_type):
                raise TypeError(
                    node.expr.line, node.expr.col,
                    f'{node.expr} with {node.expr.static_type} doesnt conform to {node} with {_static_type}'
                )

        self.cur_env = old_env

    def visit_Attribute(self, node):
        """Type an attribute and check that its initializer, if any, conforms."""
        self.visit(node.id)
        node.set_static_type(node.id.static_type)

        if node.opt_expr_init:
            self.logger.info(f'{node} has expr')

            expr = node.opt_expr_init
            self.visit(expr)

            if not self._conforms(expr.static_type, node.static_type):
                raise TypeError(
                    node.line, node.col,
                    f'{expr} with {expr.static_type} doesnt conform to {node} with {node.static_type}'
                )

    def visit_Dispatch(self, node):
        """Check a (possibly static) dispatch ``e0[@T].f(e1, ..., en)``.

        The method is looked up on ``T`` (static dispatch) or on the static
        type of ``e0``. SELF_TYPE(C) is looked up on the current class. The
        argument count and the conformance of each argument to its formal are
        checked. The result type is the method's return type, with SELF_TYPE
        replaced by the static type of ``e0``.
        """
        for expr in node.expr_list:
            self.visit(expr)

        self.visit(node.expr)
        cls = None

        if node.opt_type:  # static dispatch
            if node.opt_type.value == 'SELF_TYPE':
                raise SemanticError(
                    node.opt_type.line, node.opt_type.col,
                    f'Cant perform static dispatch on {node.opt_type}')

            if node.opt_type.value not in self.cls_refs:
                raise TypeError(node.opt_type.line, node.opt_type.col,
                                f'{Class(node.opt_type, None)} doesnt exists')

            cls = self.cls_refs[node.opt_type.value]

            if not self._conforms(node.expr.static_type, cls):
                raise TypeError(
                    node.line, node.col,
                    f'Dispatch failed, {node.expr} with {node.expr.static_type} doenst conform to {cls}'
                )

        else:
            cls = node.expr.static_type

            # The static type of node.expr is always a node of the tree and
            # NEVER a declared SELF_TYPE; it can be SELF_TYPE(C) though.
            assert node.expr.static_type.td > 0

            if isinstance(node.expr.static_type, SELF_TYPE):
                cls = self.cur_cls

        self.logger.debug(
            f'{node}: finding method {node.id} on class {cls} or some ancestor'
        )

        method = self._dispatch(cls, node.id.value)

        if not method:
            raise AttributeError(
                node.line, node.col,
                f'Dispatch failed: couldnt find a method with {node.id} in {cls} or any ancestor'
            )

        self.logger.debug(f'{node}, found {method}')

        formals = list(method.formal_list)

        if len(node.expr_list) != len(formals):
            raise SemanticError(node.line, node.col, (
                f'Dispatch failed, number of arguments of dispatch is {len(node.expr_list)}, '
                f'number of formals is {len(formals)}'))

        for expr, formal in zip(node.expr_list, formals):
            self.logger.debug(f'Checking conformance of {expr} and {formal}')

            if not self._conforms(expr.static_type, formal.static_type):
                raise TypeError(
                    expr.line, expr.col,
                    f'{expr} with {expr.static_type} doesnt conform to {formal} with {formal.static_type}'
                )

        node.set_static_type(
            self._get_correct_type(method, node.expr.static_type))

    def visit_Assignment(self, node):
        """``id <- e``: ``e`` must conform to ``id``; the result has ``e``'s type."""
        if node.id.value == 'self':
            raise SemanticError(node.id.line, node.id.col,
                                f'Tried to assign to {node.id}')

        self.visit(node.id)
        self.visit(node.expr)

        if not self._conforms(node.expr.static_type, node.id.static_type):
            raise TypeError(
                node.line, node.col,
                f'{node.expr} with {node.expr.static_type} doesnt conform to {node.id} with {node.id.static_type}'
            )

        node.set_static_type(node.expr.static_type)

    def visit_Block(self, node):
        """``{ e1; ...; en; }``: the type is that of the last expression."""
        for expr in node.expr_list:
            self.visit(expr)

        node.set_static_type(node.expr_list[-1].static_type)

    def visit_New(self, node):
        """``new T``: the type is ``T``; ``new SELF_TYPE`` gets the current self type."""
        node.set_static_type(
            self._get_correct_type(node, self.cur_cls.self_type))

    def _check_bool_predicate(self, node):
        """Visit ``node.predicate`` and require its static type to be Bool.

        Shared by ``visit_If`` and ``visit_While``.

        Raises:
            TypeError: the predicate is not of type Bool.
        """
        self.visit(node.predicate)

        _static_type = node.predicate.static_type

        if _static_type.type.value != 'Bool':
            raise TypeError(
                node.predicate.line, node.predicate.col,
                f'{node} predicate must have {self.cls_refs["Bool"]}, not {_static_type}'
            )

    def visit_If(self, node):
        """``if p then a else b fi``: the type is the LCA of both branch types."""
        self._check_bool_predicate(node)

        self.visit(node.if_branch)
        self.visit(node.else_branch)

        node.set_static_type(
            self._lca(node.if_branch.static_type,
                      node.else_branch.static_type))

    def visit_While(self, node):
        """``while p loop b pool``: the type is always Object."""
        self._check_bool_predicate(node)

        self.visit(node.body)

        node.set_static_type(self.cls_refs['Object'])

    def visit_LetVar(self, node):
        """Bind one let variable in the current (let) scope.

        The initializer, if any, is checked BEFORE the variable is bound, so
        it cannot see the variable it initializes. Its type must conform to
        the declared type.
        """
        if node.id.value == 'self':
            raise SemanticError(node.id.line, node.id.col,
                                f'Tried to assign to {node.id}')

        expr = node.opt_expr_init

        if expr:
            self.logger.info(f'{node} has expr')
            self.visit(expr)

        self.cur_env.define(node.id.value, node)
        self.visit(node.id)
        node.set_static_type(node.id.static_type)

        if expr and not self._conforms(expr.static_type, node.static_type):
            raise TypeError(
                node.line, node.col,
                f'{expr} with {expr.static_type} doesnt conform to {node} with {node.static_type}'
            )

    def visit_Let(self, node):
        """``let ... in body``: bind the variables in a new scope; type of body."""
        old_env = self.cur_env
        self.cur_env = Environment(old_env)

        for let_var in node.let_list:
            self.visit(let_var)

        self.visit(node.body)
        node.set_static_type(node.body.static_type)

        self.cur_env = old_env

    def visit_CaseVar(self, node):
        """Bind a case-branch variable; it may be neither ``self`` nor SELF_TYPE."""
        if node.id.value == 'self':
            raise SemanticError(node.id.line, node.id.col,
                                f'Tried to assign to {node.id}')

        if node.type.value == 'SELF_TYPE':
            raise SemanticError(node.type.line, node.type.col,
                                f'Tried to declare {node} with {node.type}')

        self.cur_env.define(node.id.value, node)
        self.visit(node.id)
        node.set_static_type(node.id.static_type)

    def visit_Case(self, node):
        """``case e of ... esac``: the type is the LCA of all branch body types.

        Raises:
            SemanticError: two branches declare the same type.
        """
        self.visit(node.expr)

        mp = {}
        lca = None

        for branch in node.case_list:
            if branch.case_var.type.value in mp:
                raise SemanticError(
                    branch.case_var.type.line, branch.case_var.type.col,
                    f'{branch.case_var.type} appears in other branch of {node}'
                )

            mp[branch.case_var.type.value] = True

            # Each branch binds its variable in its own scope.
            old_env = self.cur_env
            self.cur_env = Environment(old_env)

            self.visit(branch.case_var)
            self.visit(branch.expr)

            # Explicit sentinel check: a class node must never be mistaken
            # for "no type yet" because of its truthiness.
            if lca is None:
                lca = branch.expr.static_type

            else:
                lca = self._lca(lca, branch.expr.static_type)

            self.cur_env = old_env

        node.set_static_type(lca)

    def visit_Plus(self, node):
        """Arithmetic: both operands must be Int; the result is Int."""
        self.visit(node.left)
        self.visit(node.right)

        if node.left.static_type.type.value != 'Int' or node.right.static_type.type.value != 'Int':
            raise TypeError(
                node.line, node.col,
                f'{node.left} has {node.left.static_type}; {node.right} has {node.right.static_type}, they both should have {self.cls_refs["Int"]}'
            )

        node.set_static_type(self.cls_refs['Int'])

    def visit_Minus(self, node):
        """Same rule as ``visit_Plus``."""
        self.visit_Plus(node)

    def visit_Mult(self, node):
        """Same rule as ``visit_Plus``."""
        self.visit_Plus(node)

    def visit_Div(self, node):
        """Same rule as ``visit_Plus``."""
        self.visit_Plus(node)

    def visit_Eq(self, node):
        """``a = b``: if either side is Int/String/Bool, both must match; Bool."""
        self.visit(node.left)
        self.visit(node.right)

        types = ['Int', 'String', 'Bool']

        lft_type = node.left.static_type.type.value
        rgt_type = node.right.static_type.type.value

        if lft_type in types or rgt_type in types:
            if lft_type != rgt_type:
                raise TypeError(
                    node.line, node.col,
                    f'{node.left} with {node.left.static_type} and {node.right} with {node.right.static_type} must both have the same type'
                )

        node.set_static_type(self.cls_refs['Bool'])

    def visit_Less(self, node):
        """Comparison: both operands must be Int; the result is Bool."""
        self.visit(node.left)
        self.visit(node.right)

        if node.left.static_type.type.value != 'Int' or node.right.static_type.type.value != 'Int':
            raise TypeError(
                node.line, node.col,
                f'{node.left} and {node.right} must both have {self.cls_refs["Int"]}'
            )

        node.set_static_type(self.cls_refs['Bool'])

    def visit_LessEq(self, node):
        """Same rule as ``visit_Less``."""
        self.visit_Less(node)

    def visit_IntComp(self, node):
        """``~e``: ``e`` must be Int; the result is Int."""
        self.visit(node.expr)

        if node.expr.static_type.type.value != 'Int':
            raise TypeError(node.line, node.col,
                            f'{node.expr} must have {self.cls_refs["Int"]}')

        node.set_static_type(self.cls_refs['Int'])

    def visit_Not(self, node):
        """``not e``: ``e`` must be Bool; the result is Bool."""
        self.visit(node.expr)

        if node.expr.static_type.type.value != 'Bool':
            raise TypeError(node.line, node.col,
                            f'{node.expr} must have {self.cls_refs["Bool"]}')

        node.set_static_type(self.cls_refs['Bool'])

    def visit_IsVoid(self, node):
        """``isvoid e``: type-check ``e``; the result is always Bool."""
        # The operand must be checked too, as visit_Not/visit_IntComp do.
        # Otherwise errors inside it go unreported and it gets no static type.
        self.visit(node.expr)
        node.set_static_type(self.cls_refs['Bool'])

    def visit_Id(self, node):
        """Resolve an identifier in scope and give it its declared type.

        Raises:
            NameError: the identifier is not bound in any enclosing scope.
        """
        ref = self.cur_env.get(node.value)

        if not ref:
            raise NameError(node.line, node.col,
                            f'{node} doesnt exists in this environment')

        self.logger.info(
            f'{node}, asked for reference, got {ref} with declared type : {ref.type}'
        )

        node.set_static_type(
            self._get_correct_type(ref, self.cur_cls.self_type))

    def visit_Int(self, node):
        """Integer literal: Int."""
        node.set_static_type(self.cls_refs['Int'])

    def visit_Bool(self, node):
        """Boolean literal: Bool."""
        node.set_static_type(self.cls_refs['Bool'])

    def visit_String(self, node):
        """String literal: String."""
        node.set_static_type(self.cls_refs['String'])


class GenCIL:
    """Lower a type-checked COOL AST into CIL code.

    Per the original author's note, ``Type``, ``Let``, ``LetVar``, ``CaseVar``
    and ``Class`` AST nodes are not generated as such. NOTE(review):
    ``visit_Let`` still returns a ``Let`` node, presumably a CIL-level
    construct of the same name -- confirm.

    Each class yields one ``FuncInit`` plus one ``Function`` per method.

    Locals (formals, let variables, case variables) live in numbered slots:

    * ``pos`` is the most recently allocated slot (``-1`` means none).
    * ``max_idx`` is the largest slot reached in the function or attribute
      initializer being generated. ``max_idx + 1`` is passed on as its
      locals size.

    Identifiers that are not locals are resolved as attributes through
    ``attr_dict``.
    """

    def __init__(self, cls_refs):
        """Create an empty generator.

        Args:
            cls_refs: mapping from class name to its class node; used to read
                ``td``/``tf``/``level`` of case-branch types.
        """
        self.logger = init_logger('GenCIL')
        self.cls_refs = cls_refs
        self.attrs = List()  # AttrDecls visible in the current class (inherited + own)
        self.attr_dict = {}  # attribute name -> slot in the current class's attr table
        self.cil_code = CILCode(List(), List(), defaultdict(list), {})
        self.pos = -1  # slot of the last allocated local; -1 means none
        self.max_idx = -1  # largest local slot used by the current function
        self.cur_env = None  # environment for locals only
        self.cur_cls = None

        # save empty string literal
        self._save_str_literal('')

        # save 0 int literal
        self._save_int_literal(0)

    @staticmethod
    def get_default_value(_type: str):
        """Return the default-initializer expression for a variable of ``_type``.

        ``Bool`` -> ``Bool('false')``, ``String`` -> ``String('')``,
        ``Int`` -> ``Int('0')``; any other type -> ``Void()``.
        """
        expr = Void()

        if _type == 'Bool':
            expr = Bool('false')

        elif _type == 'String':
            expr = String('')

        elif _type == 'Int':
            expr = Int('0')

        return expr

    def _save_str_literal(self, value: str):
        """Register a string literal (and its length as an int literal) once.

        Labels are numbered in first-seen order.
        """
        if value not in self.cil_code.str_literals:
            label = f'{LABEL_STR_LITERAL}{len(self.cil_code.str_literals)}'
            self.cil_code.str_literals[value] = label

            # save length int literal
            self._save_int_literal(len(value))

    def _save_int_literal(self, value: int):
        """Register an int literal once; labels are numbered in first-seen order."""
        if value not in self.cil_code.int_literals:
            label = f'{LABEL_INT_LITERAL}{len(self.cil_code.int_literals)}'
            self.cil_code.int_literals[value] = label

    def _get_declaration_expr(self, node):
        """Return the CIL initializer for an attribute or let variable.

        This is the lowered ``opt_expr_init`` when present, otherwise the
        type's default value.
        """
        expr = GenCIL.get_default_value(node.type.value)

        if node.opt_expr_init:
            expr = self.visit(node.opt_expr_init)

        return expr

    def visit(self, node):
        """Dispatch to ``visit_<ClassName>`` for ``node`` and return its result."""
        if isinstance(node, Class):
            self.logger.debug('.' * 200)

        self.logger.debug(f'On {node}')

        fn = getattr(self, 'visit_' + node.__class__.__name__)
        res = fn(node)

        if isinstance(node, Class):
            self.logger.debug('.' * 200)

        return res

    def visit_Class(self, node):
        """Generate a class's methods and init function, then its subclasses.

        The attribute table is laid out as follows:

        * reserved attributes first, with ``_type_info`` at slot 0;
        * then inherited attributes;
        * then this class's own attributes.

        All traversal state is restored afterwards.
        """
        old_attrs = self.attrs
        self.attrs = List(old_attrs[:])

        # Saved and restored like the other traversal state, so the parent's
        # table is current again once the subclasses have been generated.
        old_attr_dict = self.attr_dict

        old_env = self.cur_env
        self.cur_env = Environment(old_env)

        old_cls = self.cur_cls
        self.cur_cls = node

        # let's save each class type as a String object in data segment
        self._save_str_literal(node.type.value)

        #own attrs
        own_attrs = [
            feature for feature in node.feature_list
            if isinstance(feature, Attribute)
        ]

        self.attr_dict = {}

        #filling attr_dict, it is needed so that references know what attr they are refering to
        p = 0

        #I ensure that _type_info attr will always be at position 0 in "attr table"
        assert node.reserved_attrs[0].ref.name == '_type_info'

        for attr in node.reserved_attrs:  #reserved attributes
            self.attr_dict[attr.ref.name] = p
            p += 1

        for decl in self.attrs:  #declarations of attributes from inheritance, these are instance of AttrDecl
            self.attr_dict[decl.ref.name] = p
            p += 1

        for attr in own_attrs:  #own attributes, note that these are instance of Attribute right now
            self.attr_dict[attr.id.value] = p
            p += 1

        for feature in node.feature_list:
            self.visit(feature)

        func_init = FuncInit(node.type.value, self.attrs, self.attr_dict,
                             f'{node.type.value}_Init',
                             List(node.reserved_attrs), node.type_obj)

        #needed for static data segment of the type
        func_init.td = self.cur_cls.td
        func_init.tf = self.cur_cls.tf

        self.cil_code.init_functions.append(func_init)
        self.cil_code.dict_init_func[func_init.name] = func_init

        self.logger.debug(func_init)
        self.logger.debug(f'Attrs: {list(self.attrs)}')

        for cls in node.children:
            self.visit(cls)

        self.pos -= self.cur_env.definitions  #undo
        self.cur_env = old_env
        self.attrs = old_attrs
        self.attr_dict = old_attr_dict
        self.cur_cls = old_cls

    def visit_Formal(self, node):
        """Allocate the next local slot for a formal; return its reference."""
        self.pos += 1  #do
        self.max_idx = max(self.max_idx, self.pos)
        self.cur_env.define(node.id.value, self.pos)

        return self.visit(node.id)

    def visit_Method(self, node):
        """Generate a ``Function`` for a method and register it in ``cil_code``.

        Native methods (``node.expr`` falsy) get ``None`` as their body. The
        function records the owner's ``td``/``tf``/``level`` for dispatch and
        the label ``<Class>.<method>``.
        """
        old_env = self.cur_env
        self.cur_env = Environment(old_env)

        # Methods start with no locals allocated.
        assert self.pos == -1
        assert self.max_idx == -1

        formals = List([self.visit(formal) for formal in node.formal_list])
        body = self.visit(node.expr) if node.expr else None

        new_func = Function(node.id.value, formals, body, self.max_idx + 1)

        #needed for fast dispatch
        new_func.td = self.cur_cls.td
        new_func.tf = self.cur_cls.tf
        new_func.level = self.cur_cls.level
        new_func.label = f'{self.cur_cls.type.value}.{new_func.name}'

        self.logger.debug(
            f'{new_func}, td={new_func.td}, tf={new_func.tf}, level={new_func.level}, locals_size={new_func.locals_size}, label={new_func.label}'
        )

        self.cil_code.functions.append(new_func)
        self.cil_code.dict_func[new_func.name].append(new_func)

        self.max_idx = -1
        self.pos -= self.cur_env.definitions  #undo
        self.cur_env = old_env

    def visit_Attribute(self, node):
        """Append an ``AttrDecl`` for an own attribute to ``self.attrs``.

        The declaration records how many local slots its initializer needs.
        """
        assert self.pos == -1
        assert self.max_idx == -1

        ref = self.visit(node.id)
        expr = self._get_declaration_expr(node)

        dec = AttrDecl(ref, node.type.value, expr, self.max_idx + 1)
        self.attrs.append(dec)

        self.max_idx = -1

    def visit_Dispatch(self, node):
        """Lower a dispatch to ``FunctionCall``.

        The static-dispatch class is passed as a str, or None when absent.
        """
        expr = self.visit(node.expr)
        opt_class = node.opt_type  #can be none or a str

        if opt_class:
            opt_class = opt_class.value

        name = node.id.value
        args = List([self.visit(e) for e in node.expr_list])

        return FunctionCall(expr, opt_class, name, args)

    def visit_SELF_TYPE(self, node):
        """No-op: SELF_TYPE(C) nodes among class children generate nothing."""
        pass

    def visit_Assignment(self, node):
        """Lower ``id <- e`` to ``Binding``."""
        return Binding(self.visit(node.id), self.visit(node.expr))

    def visit_If(self, node):
        """Lower a conditional; the predicate, then-branch and else-branch are visited in that order."""
        pred = self.visit(node.predicate)
        if_branch = self.visit(node.if_branch)
        else_branch = self.visit(node.else_branch)
        return If(pred, if_branch, else_branch)

    def visit_While(self, node):
        """Lower a loop; the predicate is visited before the body."""
        pred, body = self.visit(node.predicate), self.visit(node.body)
        return While(pred, body)

    def visit_Block(self, node):
        """Lower a block; the CIL ``Block`` takes a ``List`` of expressions."""
        return Block(List([self.visit(expr) for expr in node.expr_list]))

    def visit_LetVar(self, node):
        """Allocate a local slot for a let variable; return its ``Binding``.

        The initializer is lowered BEFORE the variable is bound, so it does
        not see the new variable.
        """
        expr = self._get_declaration_expr(node)

        self.pos += 1  #do
        self.max_idx = max(self.max_idx, self.pos)
        self.cur_env.define(node.id.value, self.pos)

        ref = self.visit(node.id)

        return Binding(ref, expr)

    def visit_Let(self, node):
        """Lower a let in a new scope; its slots are released afterwards."""
        old_env = self.cur_env
        self.cur_env = Environment(old_env)

        lets = [self.visit(let_var) for let_var in node.let_list]
        body = self.visit(node.body)

        self.pos -= self.cur_env.definitions  #undo
        self.cur_env = old_env

        return Let(List(lets), body)

    def visit_CaseVar(self, node):
        """Allocate a slot for a case variable.

        Returns ``(ref, td, tf, level)`` of the declared type.
        """
        self.pos += 1  #do
        self.max_idx = max(self.max_idx, self.pos)
        self.cur_env.define(node.id.value, self.pos)

        cls = self.cls_refs[node.type.value]

        return self.visit(node.id), cls.td, cls.tf, cls.level

    def visit_CaseBranch(self, node):
        """Lower one case branch in its own scope, tagged with type times/level."""
        old_env = self.cur_env
        self.cur_env = Environment(old_env)

        ref, td, tf, level = self.visit(node.case_var)
        expr = self.visit(node.expr)

        branch = CaseBranch(ref, expr)
        branch.set_times(td, tf)
        branch.level = level

        self.pos -= self.cur_env.definitions  #undo
        self.cur_env = old_env

        return branch

    def visit_Case(self, node):
        """Lower a case expression.

        Branches are sorted by decreasing type ``level``, deepest type first.
        """
        expr = self.visit(node.expr)
        branches = List([self.visit(branch) for branch in node.case_list])
        branches.sort(key=lambda x: x.level, reverse=True)  #sort by greater level

        case = Case(expr, branches)

        return case

    def visit_New(self, node):
        """Lower ``new T``; the CIL ``New`` takes the type name as a str."""
        return New(node.type.value)

    def visit_IsVoid(self, node):
        """Lower ``isvoid e``."""
        return IsVoid(self.visit(node.expr))

    def visit_IntComp(self, node):
        """Lower ``~e``."""
        return IntComp(self.visit(node.expr))

    def visit_Not(self, node):
        """Lower ``not e``."""
        return Not(self.visit(node.expr))

    def visit_Plus(self, node):
        """Lower ``a + b``."""
        return Plus(self.visit(node.left), self.visit(node.right))

    def visit_Minus(self, node):
        """Lower ``a - b``."""
        return Minus(self.visit(node.left), self.visit(node.right))

    def visit_Mult(self, node):
        """Lower ``a * b``."""
        return Mult(self.visit(node.left), self.visit(node.right))

    def visit_Div(self, node):
        """Lower ``a / b``."""
        return Div(self.visit(node.left), self.visit(node.right))

    def visit_Less(self, node):
        """Lower ``a < b``."""
        return Less(self.visit(node.left), self.visit(node.right))

    def visit_LessEq(self, node):
        """Lower ``a <= b``."""
        return LessEq(self.visit(node.left), self.visit(node.right))

    def visit_Eq(self, node):
        """Lower ``a = b``."""
        return Eq(self.visit(node.left), self.visit(node.right))

    def visit_Id(self, node):
        """Lower an identifier to a ``Reference``.

        ``refers_to`` is set to ``('local', slot)`` or ``('attr', slot)``.
        ``self`` is left unresolved; per the original note it is kept in a
        fixed register at code generation.
        """
        ref = Reference(node.value)

        if ref.name == 'self':  #ignore self for now, it will be always saved in some fixed register at CG
            return ref

        to = self.cur_env.get(ref.name)

        if to is None:  #must be an attr variable
            assert ref.name in self.attr_dict
            ref.refers_to = ('attr', self.attr_dict[ref.name])

        else:
            ref.refers_to = (
                'local', to)  #local variable (ie. formal, let_var or case_var)

        self.logger.debug(f'{ref} refers to: {ref.refers_to}')

        return ref

    def visit_Int(self, node):
        """Lower an integer literal and register it in the data segment."""
        ref = Int(node.value)
        self._save_int_literal(int(ref.value))
        return ref

    def visit_Bool(self, node):
        """Lower a boolean literal."""
        return Bool(node.value)

    def visit_String(self, node):
        """Lower a string literal and register it in the data segment."""
        ref = String(node.value)
        self._save_str_literal(ref.value)
        return ref