def rank_rule_scores(dnf_rule: Rule, X_train, y_train, use_rl: bool):
    """
    Score every clause of a class's DNF rule against the training data.

    Each clause of the DNF rule is treated as a standalone rule for the
    class. For each clause we count correctly covered (cc) and incorrectly
    covered (ic) training samples, then store two scores on the clause:

    - accuracy: cc / (cc + ic)
    - rank:     (cc - ic)/(cc + ic) + cc/(ic + k)  [+ cc/rl when use_rl]

    Args:
        dnf_rule: DNF rule extracted for one output class.
        X_train: training inputs; each row is one sample's feature vector.
        y_train: training labels (class encodings) aligned with X_train rows.
        use_rl: if True, also takes rule length into account by adding
            cc / rl to the rank score (rewards shorter rules).

    Returns:
        None. Scores are written onto each clause in place via
        set_accuracy_score / set_rank_score.
    """
    # Each run of rule extraction returns a DNF rule per output class;
    # the conclusion carries the class encoding we compare labels against.
    rule_output = dnf_rule.get_conclusion()

    # Each clause in the DNF rule is considered a rule for this output class
    for clause in dnf_rule.get_premise():
        cc = ic = 0  # correctly / incorrectly covered sample counts
        rl = len(clause.get_terms())  # rule length

        # Count how often the clause fires on the training data, split by
        # whether the sample's label matches this rule's conclusion.
        for sample, label in zip(X_train, y_train):
            # Map of Neuron objects (input layer 0) to values from the input
            # data. This is the form of data a clause expects for evaluation.
            neuron_to_value_map = {
                Neuron(layer=0, index=j): value
                for j, value in enumerate(sample)
            }

            if clause.evaluate(data=neuron_to_value_map):
                if rule_output.encoding == label:
                    cc += 1
                else:
                    ic += 1

        # Compute rule scores; a clause that never fires gets zero for both.
        if cc + ic == 0:
            accuracy_score = rank_score = 0
        else:
            accuracy_score = cc / (cc + ic)
            # NOTE(review): k is assumed to be a module-level smoothing
            # constant defined outside this chunk — confirm it exists.
            rank_score = ((cc - ic) / (cc + ic)) + cc / (ic + k)

        # Guard rl == 0 (a clause with no terms) against division by zero.
        if use_rl and rl:
            rank_score += cc / rl

        # Save scores on the clause itself
        clause.set_accuracy_score(accuracy_score)
        clause.set_rank_score(rank_score)
def substitute(total_rule: Rule, intermediate_rules: Ruleset) -> Rule:
    """
    Substitute the intermediate rules from the previous layer into the total rule
    """
    merged_premise_clauses = set()

    print('  Rule Premise Length: ', len(total_rule.get_premise()))

    # Process each clause of the total rule independently.
    for clause_idx, source_clause in enumerate(total_rule.get_premise(), start=1):
        print('    premise: %d' % clause_idx)

        # For every term in the clause, gather the sets of conjunctive
        # clauses from the previous layer whose conclusion is that term.
        replacement_sets = []
        for term in source_clause.get_terms():
            candidates = intermediate_rules.get_rule_premises_by_conclusion(term)
            if candidates:
                replacement_sets.append(candidates)

        # Progress bar sized by the number of clause combinations to visit.
        total_combinations = 1
        for candidate_set in replacement_sets:
            total_combinations *= len(candidate_set)
        if total_combinations > 10000:
            print('.' * (total_combinations // 10000), end='', flush=True)
            print()

        # The cartesian product enumerates every way of replacing all terms
        # at once; itertools streams it without materializing intermediates.
        combinations = itertools.product(*replacement_sets)

        # Union each tuple of conjunctive clauses into one merged clause.
        for combo_idx, clause_tuple in enumerate(combinations, start=1):
            merged = ConjunctiveClause()
            for part in clause_tuple:
                merged = merged.union(part)
            merged_premise_clauses.add(merged)

            if combo_idx % 10000 == 0:
                print('.', end='', flush=True)

    return Rule(premise=merged_premise_clauses,
                conclusion=total_rule.get_conclusion())