class ParameterEstimation:
    """Estimate smoothed (conditional) probability tables for a network.

    For every node of ``network`` the relative frequency of each of its
    values, and of each value conditioned on every combination of its
    parents' values, is computed from the data provided by
    ``DataExtractor(network.name)``. Every estimate is offset by
    ``_SMOOTHING`` so that no entry is exactly zero.
    """

    # Additive offset applied to every estimated probability.
    _SMOOTHING = 0.001

    def __init__(self, network):
        self._network = network
        self._data = DataExtractor(network.name)
        # Possible values per variable. _get_probabilities reads this; it was
        # never assigned before, so every call raised AttributeError.
        self._values_dict = self._data.get_variable_values_sets()

    @staticmethod
    def _ratio(count, total):
        # Relative frequency. Returns 0.0 when the conditioning event never
        # occurs (previously a ZeroDivisionError).
        if not total:
            return 0.0
        return count / float(total)

    def _get_probabilities(self, X, S, S_combinations):
        """Return smoothed probability estimates for variable ``X``.

        :param X: name of the variable whose distribution is estimated.
        :param S: list of conditioning variable names (the parents of ``X``).
        :param S_combinations: value combinations for ``S``, each listing one
            value per variable in ``S``, in the same order as ``S``.
        :returns: dict with keys ``'P(X=x)'`` for every value ``x`` of ``X``
            and ``'P(X=x|S1,S2=s1,s2)'`` for every ``x`` and combination.
        """
        data_vectors = self._data.get_data_vectors()
        # Assumes every data vector has one entry per sample (equal lengths).
        N = len(data_vectors[next(iter(data_vectors))])
        X_values = data_vectors[X]
        S_name = ','.join(S)
        # The rows matching each combination of S do not depend on x, so
        # find them once instead of once per value of X.
        z_by_combination = [
            (','.join(S_combination), PGMUtils.get_z_indices(S, S_combination, data_vectors))
            for S_combination in S_combinations
        ]
        observed_prob_dict = {}
        for x in self._values_dict[X]:
            x_indices = set(index for index, element in enumerate(X_values) if element == x)
            # _ratio uses float division. The old `len(x_indices) / N` was
            # integer division under Python 2 and yielded 0 for most values.
            observed_prob_dict['P(' + X + '=' + x + ')'] = self._ratio(len(x_indices), N) + self._SMOOTHING
            for S_value, z in z_by_combination:
                x_z = x_indices.intersection(z)
                key = 'P(' + X + '=' + x + '|' + S_name + '=' + S_value + ')'
                observed_prob_dict[key] = self._ratio(len(x_z), len(z)) + self._SMOOTHING
        return observed_prob_dict

    def get_estimated_cpds(self):
        """Return one probability dict (see ``_get_probabilities``) per node,
        in the order in which the network iterates its nodes."""
        cpds = []
        for node in self._network:
            # Use list() so the parents can be traversed repeatedly even if
            # predecessors() returns an iterator (networkx >= 2.0).
            parents = list(self._network.predecessors(node))
            value_combinations = PGMUtils.get_combinations(parents, self._values_dict)
            cpds.append(self._get_probabilities(node, parents, value_combinations))
        return cpds
class GreedyHillClimber:
    """Greedy hill-climbing structure search guided by a tabu list.

    The search starts from ``initial_bayesian_network``. Each iteration:

    - picks a node pair via ``GraphUtils.get_random_edge``;
    - tries the 'add', 'remove' and 'reverse' actions on it;
    - moves to the best-scoring candidate that is feasible and not in the
      tabu list (scores come from ``BDeuScoreUtil``).

    Every visited solution is appended to a FIFO tabu list holding at most
    ``tabu_list_size`` entries. The search stops after ``max_change_count``
    consecutive iterations without a new best score.
    """
    _max_change_count = 20

    def __init__(self, hyperparameter, initial_bayesian_network, tabu_list_size, max_change_count):
        self._bayesian_network = initial_bayesian_network
        self._best_score = -float('inf')
        self._best_solution = initial_bayesian_network
        self._actions_list = ['add', 'remove', 'reverse']
        # Maps edge-set signature -> placeholder. Insertion order gives the
        # FIFO eviction order.
        self._tabu_list = OrderedDict()
        self._tabulist_size = tabu_list_size
        self._max_change_count = max_change_count
        self._data = DataExtractor(initial_bayesian_network.name)
        values_sets = self._data.get_variable_values_sets()
        # list() keeps this a real list on Python 3 as well as Python 2.
        self._node_names = list(values_sets)
        data_vectors = self._data.get_data_vectors()
        self._score_util = BDeuScoreUtil(hyperparameter, self._bayesian_network, data_vectors, values_sets)

    def _get_score(self, action=None, edge=None):
        """Return the BDeu score reported by the score utility for ``action``
        on ``edge`` (both are forwarded unchanged)."""
        return self._score_util.get_score(action, edge)

    def _equals(self, bayesian_network_A, bayesian_network_B):
        """Return True if the two networks have the same set of edges."""
        return self._get_bn_signature(bayesian_network_A) == self._get_bn_signature(bayesian_network_B)

    def _tabu_list_contains(self, bayesian_network):
        """Return True if the network's edge set is in the tabu list."""
        return self._get_bn_signature(bayesian_network) in self._tabu_list

    def _get_bn_signature(self, bayesian_network):
        """Return a string that identifies the network's edge set.

        Edges are rendered as 'source-target' and sorted. This ensures two
        networks with the same edges get the same signature regardless of the
        order in which ``edges()`` yields them (the unsorted version could
        give equal edge sets different signatures).
        """
        edge_strings = sorted(str(edge[0]) + '-' + str(edge[1]) for edge in bayesian_network.edges())
        return ' '.join(edge_strings)

    def _add_solution_to_tabu_list(self, bayesian_network):
        """Append the network to the tabu list, evicting the oldest entries
        so that at most ``self._tabulist_size`` remain."""
        if self._tabulist_size <= 0:
            # A non-positive size means no tabu memory. The old code crashed
            # here trying to evict from an empty list.
            return
        signature = self._get_bn_signature(bayesian_network)
        if signature in self._tabu_list:
            # Refresh its position instead of evicting some other entry.
            del self._tabu_list[signature]
        while len(self._tabu_list) >= self._tabulist_size:
            self._tabu_list.popitem(last=False)
        self._tabu_list[signature] = 'dummy'

    def _get_feasible_local_solutions(self, bayesian_network, undirected_graph, edge):
        """Return ``(candidate_network, action)`` pairs for every action that
        is feasible on ``edge`` and whose result is not in the tabu list."""
        local_solutions_action_pairs = []
        for action in self._actions_list:
            # Work on fresh copies for every action. apply_action modifies
            # temp_bn in place (the modified copy is what gets kept). The old
            # code only re-copied after a success, so a feasible-but-tabu
            # result leaked into the next action's starting point.
            temp_bn = deepcopy(bayesian_network)
            temp_graph = deepcopy(undirected_graph)
            if not GraphUtils.apply_action(temp_bn, temp_graph, edge, action, 2):
                continue
            if self._tabu_list_contains(temp_bn):
                continue
            local_solutions_action_pairs.append((temp_bn, action))
        return local_solutions_action_pairs

    def _get_best_local_solution(self, bayesian_network, undirected_graph, edge):
        """Return ``(score, network)`` for the best-scoring candidate around
        ``edge``. If no candidate exists, return the unchanged network and
        its score."""
        candidates = self._get_feasible_local_solutions(bayesian_network, undirected_graph, edge)
        if not candidates:
            # NOTE(review): like perform_GHC, this passes the network itself
            # as _get_score's `action` argument. Confirm that
            # BDeuScoreUtil.get_score expects that.
            return self._get_score(bayesian_network), bayesian_network
        scores = [self._get_score(action, edge) for _, action in candidates]
        # Index of the first maximal score (same choice as scores.index(max)).
        best_index = max(range(len(scores)), key=scores.__getitem__)
        return scores[best_index], candidates[best_index][0]

    def perform_GHC(self):
        """Run the search. The result is available through ``get_solution``."""
        current_solution = self._bayesian_network
        self._best_score = current_score = self._get_score(current_solution)
        print('Initial score : %s' % (self._best_score,))
        undirected_graph = current_solution.to_undirected()
        change_count = 0
        max_count = self._max_change_count
        while True:
            # Pick a random node pair and move to the best action applied to it.
            random_edge = GraphUtils.get_random_edge(self._node_names)
            current_score, current_solution = self._get_best_local_solution(
                current_solution, undirected_graph, random_edge)
            undirected_graph = current_solution.to_undirected()

            if current_score > self._best_score:
                change_count = 0
                self._best_solution = deepcopy(current_solution)
                self._best_score = current_score
                print('----------- %s ------------------' % (self._best_score,))
            else:
                change_count += 1

            self._add_solution_to_tabu_list(current_solution)

            # Use >= rather than == so a non-positive limit cannot spin forever.
            if change_count >= max_count:
                break

    def get_solution(self):
        """Return ``(best_network, best_score)`` found by ``perform_GHC``."""
        return self._best_solution, self._best_score
# @author: himanshu
import json
from networkx import DiGraph, draw
from libpgm.nodedata import NodeData
from libpgm.graphskeleton import GraphSkeleton
from libpgm.discretebayesiannetwork import DiscreteBayesianNetwork
from libpgm.pgmlearner import PGMLearner
import matplotlib.pyplot as plt
from data_extractor import DataExtractor


#  Demo script: load the 'genome' data set, let libpgm's constraint-based
#  learner estimate a structure, then print and plot the learned edges.
data_ext = DataExtractor('genome', format='json')
data = data_ext.get_data_vectors()
print('Got data with  %s  vectors' % len(data))

learner = PGMLearner()
print('learning the structure')
result = learner.discrete_constraint_estimatestruct(data, pvalparam=0.02)

#  Report the learned edge list, then visualise it as a directed graph.
print(json.dumps(result.E, indent=2))
graph = DiGraph()
graph.add_edges_from(result.E)
draw(graph)
plt.show()
class PC:
    """PC-style structure learning for the variables of one data set.

    ``perform_PC`` first builds a complete undirected graph over the variables
    returned by ``DataExtractor(network_name)``. It then removes edges between
    variables that ``_are_dseparated`` judges independent, and finally orients
    the remaining edges. networkx has no partially directed graph type, so an
    orientation is recorded as an edge attribute ``'direction'`` holding a
    string of the form ``'A->B'``.
    """

    # Threshold used by _are_dseparated for conditioning-set sizes 0, 1, 2, 3.
    _mutual_info_thresholds = [0.0005, 0.005, 0.025, 0.025]

    def __init__(self, network_name):
        self._network_name = network_name
        self._data = DataExtractor(network_name)
        self._values_dict = self._data.get_variable_values_sets()
        self._node_names = list(self._values_dict.keys())
        self._graph = None
        # Score computed by _are_dseparated, keyed by 'X,Y|S1,S2'.
        self._nmis = {}

    @staticmethod
    def _ratio(count, total):
        # Relative frequency. Returns 0.0 when the conditioning event never
        # occurs (previously a ZeroDivisionError).
        if not total:
            return 0.0
        return count / float(total)

    def _get_probabilities(self, X, Y, S, S_combinations):
        """Return observed (unsmoothed) probabilities for X, Y and S.

        Keys, for every value x of X, y of Y and combination s of S:
        ``'P(X=x)'``, ``'P(Y=y)'``, ``'P(X=x,Y=y)'``, ``'P(X=x|S=s)'``,
        ``'P(Y=y|S=s)'``, ``'P(S=s)'`` and ``'P(X=x,Y=y,S=s)'``, where the
        S names and values are comma-joined.
        """
        data_vectors = self._data.get_data_vectors()
        N = len(data_vectors[next(iter(data_vectors))])
        X_values = data_vectors[X]
        Y_values = data_vectors[Y]
        S_name = ','.join(S)
        # Matching rows per value of Y and per combination of S do not depend
        # on x (or y), so compute them once instead of inside the loops.
        y_indices_by_value = [
            (y, set(index for index, element in enumerate(Y_values) if element == y))
            for y in self._values_dict[Y]
        ]
        z_by_combination = [
            (','.join(S_combination), PGMUtils.get_z_indices(S, S_combination, data_vectors))
            for S_combination in S_combinations
        ]
        observed_prob_dict = {}
        for x in self._values_dict[X]:
            x_indices = set(index for index, element in enumerate(X_values) if element == x)
            observed_prob_dict['P(' + X + '=' + x + ')'] = self._ratio(len(x_indices), N)
            for y, y_indices in y_indices_by_value:
                observed_prob_dict['P(' + Y + '=' + y + ')'] = self._ratio(len(y_indices), N)
                xy = x_indices.intersection(y_indices)
                observed_prob_dict['P(' + X + '=' + x + ',' + Y + '=' + y + ')'] = self._ratio(len(xy), N)
                for S_value, z in z_by_combination:
                    condition = S_name + '=' + S_value
                    observed_prob_dict['P(' + X + '=' + x + '|' + condition + ')'] = \
                        self._ratio(len(x_indices.intersection(z)), len(z))
                    observed_prob_dict['P(' + condition + ')'] = self._ratio(len(z), N)
                    observed_prob_dict['P(' + Y + '=' + y + '|' + condition + ')'] = \
                        self._ratio(len(y_indices.intersection(z)), len(z))
                    observed_prob_dict['P(' + X + '=' + x + ',' + Y + '=' + y + ',' + condition + ')'] = \
                        self._ratio(len(xy.intersection(z)), N)
        return observed_prob_dict

    def _are_dseparated(self, X, Y, S, n):
        """Decide whether X and Y are treated as independent given S.

        The method accumulates ``-log(p + 0.001)`` terms for X, Y and their
        joint: plain marginals when S has no value combinations, conditional
        and joint-with-S terms otherwise. It normalises the three sums and
        compares ``|H_X/n_X + H_Y/n_Y - H_XY/n_XY|`` against the threshold
        for conditioning-set size ``n``. The score is stored in
        ``self._nmis``.

        :returns: True when the score is below the threshold.
        """
        H_X = H_Y = H_XY = 0
        S_combinations = PGMUtils.get_combinations(S, self._values_dict)
        probability_dict = self._get_probabilities(X, Y, S, S_combinations)
        S_name = ','.join(S)
        for x in self._values_dict[X]:
            p_x = probability_dict['P(' + X + '=' + x + ')']
            for y in self._values_dict[Y]:
                p_y = probability_dict['P(' + Y + '=' + y + ')']
                if len(S_combinations) == 0:
                    # Zero-order test: marginal and joint terms only.
                    H_Y += -log(p_y + 0.001)
                    H_X += -log(p_x + 0.001)
                    p_xy = probability_dict['P(' + X + '=' + x + ',' + Y + '=' + y + ')']
                    H_XY += -log(p_xy + 0.001)
                else:
                    for S_combination in S_combinations:
                        condition = S_name + '=' + ','.join(S_combination)
                        p_y_z = probability_dict['P(' + Y + '=' + y + '|' + condition + ')']
                        p_x_z = probability_dict['P(' + X + '=' + x + '|' + condition + ')']
                        p_xyz = probability_dict['P(' + X + '=' + x + ',' + Y + '=' + y + ',' + condition + ')']
                        p_z = probability_dict['P(' + condition + ')']
                        H_X += -log(p_x_z + 0.001)
                        H_Y += -log(p_y_z + 0.001)
                        H_XY += -log(p_xyz * p_z + 0.001)

        n_X = 2 * len(self._values_dict[X])
        n_Y = 2 * len(self._values_dict[Y])
        n_XY = 4 * len(S_combinations)
        if n_XY == 0:
            n_XY = 4
        MI = abs((H_X / n_X) + (H_Y / n_Y) - (H_XY / n_XY))
        self._nmis[X + ',' + Y + '|' + S_name] = MI
        # Below the threshold means X and Y are treated as independent given S.
        return MI < self._mutual_info_thresholds[n]

    def _eliminate_edges(self, Sep):
        """Build the skeleton by removing edges between independent variables.

        The search starts from a complete graph over the variables. For each
        conditioning-set size n (one per entry of _mutual_info_thresholds),
        every adjacent pair (X, Y) is tested against each size-n subset S of
        X's other neighbours. When ``_are_dseparated`` returns True, X-Y is
        removed and S is recorded in ``Sep`` under both 'X,Y' and 'Y,X'.

        Pairs are only tested while ``GraphUtils.is_degree_greater`` holds
        and the graph is connected.
        """
        num_nodes = len(self._values_dict)
        graph = complete_graph(num_nodes, Graph())
        self._graph = GraphUtils.rename_nodes(graph, self._node_names)
        # Per-network setting. This was previously hard-coded to 'genome'
        # whatever network was being learned.
        max_allowed_degree = settings.networks_settings[self._network_name]['max_allowed_degree']
        for n in range(len(self._mutual_info_thresholds)):
            print('--------------------------------------------------------')
            for X in self._graph:
                for Y in self._graph:
                    if X == Y or not self._graph.has_edge(X, Y):
                        continue
                    if not (GraphUtils.is_degree_greater(self._graph, max_allowed_degree)
                            and is_connected(self._graph)):
                        continue
                    # All neighbours of X except Y. A fresh list, so the graph
                    # is never mutated through it.
                    neighbors = [node for node in self._graph.neighbors(X) if node != Y]
                    if len(neighbors) < n:
                        continue
                    for S in combinations(neighbors, n):
                        S = sorted(S)
                        if self._are_dseparated(X, Y, S, n):
                            self._graph.remove_edge(X, Y)
                            print('Removed %s - %s' % (X, Y))
                            Sep[X + ',' + Y] = S
                            Sep[Y + ',' + X] = S
                            # The edge is gone; testing more subsets is pointless.
                            break

    def _has_directed_path(self, A, B):
        """Return True if some simple path from A to B consists only of edges
        whose 'direction' attribute points along the path."""
        for path in all_simple_paths(self._graph, A, B):
            if all(self._graph[u][v].get('direction') == u + '->' + v
                   for u, v in zip(path, path[1:])):
                return True
        return False

    def _all_edges_oriented(self):
        """Return True if every edge carries a 'direction' attribute."""
        return all('direction' in self._graph[u][v] for u, v in self._graph.edges())

    def _orient_edges(self, Sep):
        """Orient edges of the skeleton by setting their 'direction' attribute.

        1. For each non-adjacent pair (X, Z) and EVERY common neighbour Y that
           is not in Sep['X,Z'], orient X->Y<-Z. Previously only the first
           common neighbour was considered.
        2. Repeat until every edge is oriented or a full pass changes nothing:
           - orient head-C as head->C when some tail->head exists and tail, C
             are not adjacent;
           - orient an undirected A-B along any directed path between them.
           The old loop compared against the literal 'B' and could spin
           forever when no rule applied.
        """
        nodes = list(self._graph.nodes())
        for X in nodes:
            for Z in nodes:
                if X == Z or self._graph.has_edge(X, Z):
                    continue
                # X and Z are non-adjacent, so every path with cutoff 2 is X-Y-Z.
                for path in all_simple_paths(self._graph, X, Z, 2):
                    Y = path[1]
                    if Y not in Sep[X + ',' + Z]:
                        self._graph[X][Y]['direction'] = X + '->' + Y
                        self._graph[Z][Y]['direction'] = Z + '->' + Y

        changed = True
        while changed and not self._all_edges_oriented():
            changed = False
            for A, B in self._graph.edges():
                edgeAB = self._graph[A][B]
                direction = edgeAB.get('direction')
                if direction is None:
                    if self._has_directed_path(A, B):
                        edgeAB['direction'] = A + '->' + B
                        changed = True
                    elif self._has_directed_path(B, A):
                        edgeAB['direction'] = B + '->' + A
                        changed = True
                    continue
                # The edge may be yielded as (head, tail); normalise it.
                tail, head = (A, B) if direction == A + '->' + B else (B, A)
                for C in list(self._graph.neighbors(head)):
                    if C == tail or self._graph.has_edge(tail, C):
                        continue
                    edgeHC = self._graph[head][C]
                    if 'direction' not in edgeHC:
                        edgeHC['direction'] = head + '->' + C
                        changed = True

    def perform_PC(self):
        """Run edge elimination and orientation, printing and plotting the
        intermediate skeleton.

        Follows the PC algorithm described in
        http://www.lowcaliber.org/influence/spirtes-causation-prediction-search.pdf
        """
        Sep = {}
        self._eliminate_edges(Sep)
        pprint(sorted(self._nmis.items(), key=itemgetter(1), reverse=True))
        print(self._graph.edges())
        draw(self._graph)
        plt.show()
        if is_connected(self._graph):
            print('The graph is connected')
        else:
            print('The graph is not connected')

        self._orient_edges(Sep)
        pprint(self._graph.edges())

    def get_skeleton(self):
        """Convert the learned graph to a directed graph named after the
        network and return it."""
        self._graph = GraphUtils.convert_to_directed(self._graph)
        self._graph.name = self._network_name
        return self._graph